gaoqiong / composable_kernel / Commits / 644df335

Commit 644df335, authored Jan 30, 2023 by rocking
Merge branch 'develop' into gemm_layernorm_instance
Parents: d99640ab, 7494c1c6

Changes: 254. Showing 20 changed files with 1897 additions and 218 deletions (+1897, -218).
include/ck/tensor_operation/gpu/device/device_reduce.hpp (+28, -8)
include/ck/tensor_operation/gpu/device/device_softmax.hpp (+4, -6)
include/ck/tensor_operation/gpu/device/impl/device_elementwise_2d_impl.hpp (+5, -5)
include/ck/tensor_operation/gpu/device/impl/device_elementwise_impl.hpp (+3, -3)
include/ck/tensor_operation/gpu/device/impl/device_elementwise_normalization_impl.hpp (+4, -4)
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp (+669, -0)
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp (+2, -3)
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp (+0, -1)
include/ck/tensor_operation/gpu/device/impl/device_multiple_reduce_multiblock.hpp (+8, -8)
include/ck/tensor_operation/gpu/device/impl/device_multiple_reduce_threadwise.hpp (+6, -6)
include/ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp (+5, -4)
include/ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp (+16, -8)
include/ck/tensor_operation/gpu/device/impl/device_reduce_threadwise.hpp (+15, -6)
include/ck/tensor_operation/gpu/device/impl/device_softmax_impl.hpp (+12, -11)
include/ck/tensor_operation/gpu/device/impl/device_sparse_embeddings_forward_layernorm.hpp (+44, -61)
include/ck/tensor_operation/gpu/element/element_wise_operation.hpp (+70, -0)
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp (+44, -0)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_waveletmodel.hpp (+157, -0)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp (+744, -0)
include/ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm.hpp (+61, -84)
include/ck/tensor_operation/gpu/device/device_reduce.hpp

@@ -13,10 +13,16 @@ namespace ck {
 namespace tensor_operation {
 namespace device {

-template <index_t Rank,
+template <typename InDataType,
+          typename AccDataType,
+          typename OutDataType,
+          index_t Rank,
           index_t NumReduceDim,
+          typename ReduceOperation,
           typename InElementwiseOperation,
-          typename AccElementwiseOperation>
+          typename AccElementwiseOperation,
+          bool PropagateNan,
+          bool OutputIndex>
 struct DeviceReduce : public BaseOperator
 {
     static constexpr index_t NumOutDim = (Rank - NumReduceDim == 0) ? 1 : Rank - NumReduceDim;

@@ -27,8 +33,8 @@ struct DeviceReduce : public BaseOperator
                         const std::array<index_t, NumOutDim> outLengths,
                         const std::array<index_t, NumOutDim> outStrides,
                         const std::array<int, NumReduceDim> reduceDims,
-                        float alpha,
-                        float beta,
+                        double alpha,
+                        double beta,
                         const void* in_dev,
                         const void* in_index_dev,
                         void* out_dev,

@@ -39,12 +45,26 @@ struct DeviceReduce : public BaseOperator
     virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
 };

-template <index_t Rank,
+template <typename InDataType,
+          typename AccDataType,
+          typename OutDataType,
+          index_t Rank,
           index_t NumReduceDim,
+          typename ReduceOperation,
           typename InElementwiseOperation,
-          typename AccElementwiseOperation>
-using DeviceReducePtr = std::unique_ptr<
-    DeviceReduce<Rank, NumReduceDim, InElementwiseOperation, AccElementwiseOperation>>;
+          typename AccElementwiseOperation,
+          bool PropagateNan,
+          bool OutputIndex>
+using DeviceReducePtr = std::unique_ptr<DeviceReduce<InDataType,
+                                                     AccDataType,
+                                                     OutDataType,
+                                                     Rank,
+                                                     NumReduceDim,
+                                                     ReduceOperation,
+                                                     InElementwiseOperation,
+                                                     AccElementwiseOperation,
+                                                     PropagateNan,
+                                                     OutputIndex>>;

 } // namespace device
 } // namespace tensor_operation
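For reference, a minimal sketch (not part of the diff) of how the widened DeviceReduce interface is spelled out after this change; the concrete data types and operations below are illustrative assumptions, not taken from this commit.

// Hypothetical instantiation of the expanded interface.
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

using ReducePtr = ck::tensor_operation::device::DeviceReducePtr<
    ck::half_t,      // InDataType
    float,           // AccDataType
    ck::half_t,      // OutDataType
    4,               // Rank
    2,               // NumReduceDim
    ck::reduce::Add, // ReduceOperation
    PassThrough,     // InElementwiseOperation
    PassThrough,     // AccElementwiseOperation
    false,           // PropagateNan
    false>;          // OutputIndex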
include/ck/tensor_operation/gpu/device/device_softmax.hpp

@@ -27,10 +27,8 @@ struct DeviceSoftmax : public BaseOperator
     // @param[in] inLengths  Input tensor extent(s) from high to low dimension
     // @param[in] inStrides  Input tensor stride(s) from high to low dimension
     // @param[in] reduceDims The dimension(s) the normalization operation is applied
-    // @param[in] alpha      Typeless pointer in host memory storing the alpha scaling
-    //                       value as type AccDataType
-    // @param[in] beta       Typeless pointer in host memory storing the beta scaling
-    //                       value as type AccDataType
+    // @param[in] alpha      double type value
+    // @param[in] beta       double type value
     // @param[in] in_dev     Typeless const pointer in device memory storing the input
     //                       tensor
     // @param out_dev        Typeless pointer in device memory storing the output tensor

@@ -43,8 +41,8 @@ struct DeviceSoftmax : public BaseOperator
     MakeArgumentPointer(const std::vector<index_t> inLengths,
                         const std::vector<index_t> inStrides,
                         const std::vector<int> reduceDims,
-                        const void* alpha,
-                        const void* beta,
+                        double alpha,
+                        double beta,
                         const void* in_dev,
                         void* out_dev,
                         InElementwiseOp in_elementwise_op,
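After this change the scaling factors cross the softmax API boundary by value rather than as typeless host pointers. A hedged call-site sketch; softmax_ptr, the lengths/strides, the device buffers and the elementwise operations are assumed to exist:

// Hypothetical call site; only the alpha/beta arguments differ from before.
double alpha = 1.0;
double beta  = 0.0;

// Before: MakeArgumentPointer(..., &alpha_as_acc_type, &beta_as_acc_type, ...);
// After:  the doubles are passed directly and converted to AccDataType internally.
auto argument_ptr = softmax_ptr->MakeArgumentPointer(inLengths,
                                                     inStrides,
                                                     reduceDims,
                                                     alpha,
                                                     beta,
                                                     in_dev,
                                                     out_dev,
                                                     in_elementwise_op,
                                                     acc_elementwise_op);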
include/ck/tensor_operation/gpu/device/device_elementwise_2d.hpp → include/ck/tensor_operation/gpu/device/impl/device_elementwise_2d_impl.hpp

@@ -8,7 +8,7 @@
 #include "ck/utility/math.hpp"
 #include "ck/utility/sequence.hpp"
-#include "ck/tensor_operation/gpu/device/device_elementwise_base.hpp"
+#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"

@@ -26,10 +26,10 @@ template <typename InDataTypeTuple,
           index_t NPerThread,
           typename InScalarPerVectorSeq,
           typename OutScalarPerVectorSeq>
-struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple,
-                                                        OutDataTypeTuple,
-                                                        ElementwiseOperation,
-                                                        NumDim_m + NumDim_n>
+struct DeviceElementwise2dImpl : public DeviceElementwise<InDataTypeTuple,
+                                                          OutDataTypeTuple,
+                                                          ElementwiseOperation,
+                                                          NumDim_m + NumDim_n>
 {
     static constexpr index_t NumDim = NumDim_m + NumDim_n;
include/ck/tensor_operation/gpu/device/impl/device_elementwise.hpp → include/ck/tensor_operation/gpu/device/impl/device_elementwise_impl.hpp

@@ -8,7 +8,7 @@
 #include "ck/utility/math.hpp"
 #include "ck/utility/sequence.hpp"
-#include "ck/tensor_operation/gpu/device/device_elementwise_base.hpp"
+#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"

@@ -25,8 +25,8 @@ template <typename InDataTypeTuple,
           index_t MPerThread,
           typename InScalarPerVectorSeq,
           typename OutScalarPerVectorSeq>
-struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple,
-                                                        OutDataTypeTuple,
-                                                        ElementwiseOperation,
-                                                        NumDim>
+struct DeviceElementwiseImpl : public DeviceElementwise<InDataTypeTuple,
+                                                        OutDataTypeTuple,
+                                                        ElementwiseOperation,
+                                                        NumDim>
 {
     static constexpr int NumInput  = InDataTypeTuple::Size();
     static constexpr int NumOutput = OutDataTypeTuple::Size();
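The two renames above split the elementwise operation into an abstract interface (DeviceElementwise, declared in device_elementwise.hpp) and concrete implementations under impl/ (DeviceElementwiseImpl, DeviceElementwise2dImpl). A hedged sketch of what client code would now name; the data types and NumDim are illustrative, and only the four interface template parameters visible in the diff are relied on:

// Illustrative only: the interface has parameters
// <InDataTypeTuple, OutDataTypeTuple, ElementwiseOperation, NumDim>.
#include <memory>
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

using InTuple     = ck::Tuple<float>;
using OutTuple    = ck::Tuple<float>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

// Client code and instance factories hold pointers to the interface;
// DeviceElementwiseImpl / DeviceElementwise2dImpl are the concrete types behind it.
using DeviceElementwisePtr = std::unique_ptr<
    ck::tensor_operation::device::DeviceElementwise<InTuple, OutTuple, PassThrough, 1>>;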
include/ck/tensor_operation/gpu/device/impl/device_elementwise_normalization_impl.hpp

@@ -270,18 +270,18 @@ struct DeviceElementwiseNormalizationImpl
                  const std::vector<index_t> reduceDims,
                  XElementwiseOperation x_elementwise_op,
                  YElementwiseOperation y_elementwise_op,
-                 AccDataType epsilon,
+                 double epsilon,
                  const std::array<const void*, NumInput> in_dev_buffers,
                  const GammaDataType* p_gamma,
                  const BetaDataType* p_beta,
                  YDataType* p_y)
-            : epsilon_(epsilon),
-              p_gamma_(p_gamma),
+            : p_gamma_(p_gamma),
              p_beta_(p_beta),
              p_y_(p_y),
              x_elementwise_op_(x_elementwise_op),
              y_elementwise_op_(y_elementwise_op)
        {
+            epsilon_ = static_cast<AccDataType>(epsilon);
+
            Lengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(lengths, reduceDims);

            for(int i = 0; i < NumInput; i++)

@@ -543,7 +543,7 @@ struct DeviceElementwiseNormalizationImpl
                        const std::vector<index_t> betaStrides,
                        const std::vector<index_t> yStrides,
                        const std::vector<index_t> reduceDims,
-                       AccDataType epsilon,
+                       double epsilon,
                        const std::array<const void*, NumInput> in_dev_buffers,
                        const void* p_gamma,
                        const void* p_beta,
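The same convention recurs throughout this commit: epsilon (and, in other files, the alpha/beta scaling factors) is carried as double at the public API boundary and narrowed to AccDataType inside the Argument constructor. The pattern in isolation, with an illustrative float accumulation type:

// Pattern sketch only, not library code: the public interface takes double,
// the stored member uses the accumulation data type of the instance.
struct ExampleArgument
{
    explicit ExampleArgument(double epsilon)
    {
        epsilon_ = static_cast<float>(epsilon); // AccDataType == float in this example
    }
    float epsilon_;
};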
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp (new file, mode 100644)

This diff is collapsed (+669, -0).
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp

@@ -431,9 +431,6 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
            const index_t grid_size = arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);

-           const auto K =
-               arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
-
            auto launch_kernel = [&](auto has_main_k_block_loop) {
                constexpr bool has_main_loop = has_main_k_block_loop.value;

@@ -471,6 +468,8 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
                    arg.block_2_etile_map_);
            };

+           const auto K = arg.a_grid_desc_m_k_.GetLength(I1);
+
            if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
            {
                return launch_kernel(integral_constant<bool, true>{});
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp

@@ -486,7 +486,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
            const index_t grid_size = arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
-
            const auto K =
                arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
include/ck/tensor_operation/gpu/device/impl/device_multiple_reduce_multiblock.hpp

@@ -73,8 +73,8 @@ struct DeviceMultipleReduceMultiBlock : public DeviceMultipleReduce<Rank,
        static_for<0, NumReduction, 1>{}([&](auto I) {
            using OutDataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;

-           flag = flag && ck::reduce::InMemoryDataOperatonSupportedOnDataType<OutMemoryDataOperation,
-                                                                              OutDataType>::value;
+           flag = flag && ck::reduce::InMemoryDataOperationSupportedOnDataType<OutMemoryDataOperation,
+                                                                               OutDataType>::value;
        });

        return flag;

@@ -270,8 +270,8 @@ struct DeviceMultipleReduceMultiBlock : public DeviceMultipleReduce<Rank,
                 const std::array<index_t, NumOutputDim>& outLengths,
                 const std::array<std::array<index_t, NumOutputDim>, NumReduction>& outStridesArray,
                 const std::array<int, NumReduceDim>& reduceDims,
-                const std::array<const void*, NumReduction>& alphas,
-                const std::array<const void*, NumReduction>& betas,
+                const std::array<double, NumReduction>& alphas,
+                const std::array<double, NumReduction>& betas,
                 const void* in_dev,
                 const std::array<void*, NumReduction>& out_dev_buffers,
                 const InElementwiseOperationTuple in_elementwise_op_tuple,

@@ -286,8 +286,8 @@ struct DeviceMultipleReduceMultiBlock : public DeviceMultipleReduce<Rank,
            for(size_t i = 0; i < NumReduction; i++)
            {
-               alpha_values_(i) = *static_cast<const AccDataType*>(alphas[i]);
-               beta_values_(i)  = *static_cast<const AccDataType*>(betas[i]);
+               alpha_values_(i) = static_cast<AccDataType>(alphas[i]);
+               beta_values_(i)  = static_cast<AccDataType>(betas[i]);
            };

            in_dev_ = static_cast<const InDataType*>(in_dev);

@@ -547,8 +547,8 @@ struct DeviceMultipleReduceMultiBlock : public DeviceMultipleReduce<Rank,
                        const std::array<index_t, NumOutputDim> outLengths,
                        const std::array<std::array<index_t, NumOutputDim>, NumReduction> outStridesArray,
                        const std::array<int, NumReduceDim> reduceDims,
-                       const std::array<const void*, NumReduction> alphas,
-                       const std::array<const void*, NumReduction> betas,
+                       const std::array<double, NumReduction> alphas,
+                       const std::array<double, NumReduction> betas,
                        const void* in_dev,
                        const std::array<void*, NumReduction> out_dev_buffers,
                        const InElementwiseOperationTuple in_elementwise_op_tuple,
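Callers of the multiple-reduction devices now hand over the per-reduction scaling factors by value. A hedged call-site sketch for a NumReduction = 2 instance; the leading length/stride arguments are elided and the surrounding objects are assumed to exist:

// Hypothetical call site; only the alphas/betas arguments changed in this commit.
std::array<double, 2> alphas{1.0, 1.0};
std::array<double, 2> betas{0.0, 0.0};

// Previously: std::array<const void*, 2> pointing at AccDataType host scalars;
// now the doubles are converted to AccDataType inside the Argument constructor.
auto argument_ptr = multiple_reduce_ptr->MakeArgumentPointer(/* ...lengths, strides,... */
                                                             reduceDims,
                                                             alphas,
                                                             betas,
                                                             in_dev,
                                                             out_dev_buffers,
                                                             in_elementwise_op_tuple,
                                                             acc_elementwise_op_tuple);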
include/ck/tensor_operation/gpu/device/impl/device_multiple_reduce_threadwise.hpp

@@ -195,8 +195,8 @@ struct DeviceMultipleReduceThreadWise : public DeviceMultipleReduce<Rank,
                 const std::array<index_t, NumOutputDim>& outLengths,
                 const std::array<std::array<index_t, NumOutputDim>, NumReduction>& outStridesArray,
                 const std::array<int, NumReduceDim>& reduceDims,
-                const std::array<const void*, NumReduction>& alphas,
-                const std::array<const void*, NumReduction>& betas,
+                const std::array<double, NumReduction>& alphas,
+                const std::array<double, NumReduction>& betas,
                 const void* in_dev,
                 const std::array<void*, NumReduction>& out_dev_buffers,
                 const InElementwiseOperationTuple in_elementwise_op_tuple,

@@ -211,8 +211,8 @@ struct DeviceMultipleReduceThreadWise : public DeviceMultipleReduce<Rank,
            for(size_t i = 0; i < NumReduction; i++)
            {
-               alpha_values_(i) = *static_cast<const AccDataType*>(alphas[i]);
-               beta_values_(i)  = *static_cast<const AccDataType*>(betas[i]);
+               alpha_values_(i) = static_cast<AccDataType>(alphas[i]);
+               beta_values_(i)  = static_cast<AccDataType>(betas[i]);
            };

            in_dev_ = static_cast<const InDataType*>(in_dev);

@@ -374,8 +374,8 @@ struct DeviceMultipleReduceThreadWise : public DeviceMultipleReduce<Rank,
                        const std::array<index_t, NumOutputDim> outLengths,
                        const std::array<std::array<index_t, NumOutputDim>, NumReduction> outStridesArray,
                        const std::array<int, NumReduceDim> reduceDims,
-                       const std::array<const void*, NumReduction> alphas,
-                       const std::array<const void*, NumReduction> betas,
+                       const std::array<double, NumReduction> alphas,
+                       const std::array<double, NumReduction> betas,
                        const void* in_dev,
                        const std::array<void*, NumReduction> out_dev_buffers,
                        const InElementwiseOperationTuple in_elementwise_op_tuple,
include/ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp

@@ -221,18 +221,19 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
                 const std::vector<index_t> yStrides,
                 const std::vector<index_t> reduceDims,
                 AccElementwiseOperation acc_elementwise_op,
-                AccDataType epsilon,
+                double epsilon,
                 const XDataType* p_x,
                 const GammaDataType* p_gamma,
                 const BetaDataType* p_beta,
                 YDataType* p_y)
-            : epsilon_(epsilon),
-              p_x_(p_x),
+            : p_x_(p_x),
              p_gamma_(p_gamma),
              p_beta_(p_beta),
              p_y_(p_y),
              acc_elementwise_op_(acc_elementwise_op)
        {
+            epsilon_ = static_cast<AccDataType>(epsilon);
+
            Lengths_  = shuffle_tensor_dimensions<Rank, NumReduceDim>(lengths, reduceDims);
            xStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(xStrides, reduceDims);
            yStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(yStrides, reduceDims);

@@ -421,7 +422,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
                        const std::vector<index_t> betaStrides,
                        const std::vector<index_t> yStrides,
                        const std::vector<index_t> reduceDims,
-                       AccDataType epsilon,
+                       double epsilon,
                        const void* p_x,
                        const void* p_gamma,
                        const void* p_beta,
include/ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp

@@ -40,8 +40,16 @@ template <typename InDataType,
           index_t InSrcVectorDim,
           index_t InSrcVectorSize,
           index_t OutDstVectorSize>
-struct DeviceReduceMultiBlock
-    : public DeviceReduce<Rank, NumReduceDim, InElementwiseOperation, AccElementwiseOperation>
+struct DeviceReduceMultiBlock : public DeviceReduce<InDataType,
+                                                    AccDataType,
+                                                    OutDataType,
+                                                    Rank,
+                                                    NumReduceDim,
+                                                    ReduceOperation,
+                                                    InElementwiseOperation,
+                                                    AccElementwiseOperation,
+                                                    PropagateNan,
+                                                    OutputIndex>
 {
     static_assert(Rank <= 6, "Bigger Rank size is not supported!");
     static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,

@@ -67,8 +75,8 @@ struct DeviceReduceMultiBlock
    static constexpr bool use_multiblock =
        (OutMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd);

-   static_assert(ck::reduce::InMemoryDataOperatonSupportedOnDataType<OutMemoryDataOperation,
-                                                                     OutDataType>::value,
+   static_assert(ck::reduce::InMemoryDataOperationSupportedOnDataType<OutMemoryDataOperation,
+                                                                      OutDataType>::value,
                  "The OutDataType must support the specified OutMemoryDataOperation!");

    static_assert(!use_multiblock || (use_multiblock && !OutputIndex),

@@ -209,8 +217,8 @@ struct DeviceReduceMultiBlock
                 const std::array<index_t, NumDstDim> outLengths,
                 const std::array<index_t, NumDstDim> outStrides,
                 const std::array<int, NumReduceDim> reduceDims,
-                float alpha,
-                float beta,
+                double alpha,
+                double beta,
                 const InDataType* in_dev,
                 const IndexDataType* in_index_dev,
                 OutDataType* out_dev,

@@ -494,8 +502,8 @@ struct DeviceReduceMultiBlock
                        const std::array<index_t, NumDstDim> outLengths,
                        const std::array<index_t, NumDstDim> outStrides,
                        const std::array<int, NumReduceDim> reduceDims,
-                       float alpha,
-                       float beta,
+                       double alpha,
+                       double beta,
                        const void* in_dev,
                        const void* in_index_dev,
                        void* out_dev,
include/ck/tensor_operation/gpu/device/impl/device_reduce_threadwise.hpp

@@ -35,8 +35,17 @@ template <typename InDataType,
           index_t InSrcVectorDim,
           index_t InSrcVectorSize,
           index_t OutDstVectorSize>
-struct DeviceReduceThreadWise
-    : public DeviceReduce<Rank, NumReduceDim, InElementwiseOperation, AccElementwiseOperation>
+struct DeviceReduceThreadWise : public DeviceReduce<InDataType,
+                                                    AccDataType,
+                                                    OutDataType,
+                                                    Rank,
+                                                    NumReduceDim,
+                                                    ReduceOperation,
+                                                    InElementwiseOperation,
+                                                    AccElementwiseOperation,
+                                                    PropagateNan,
+                                                    OutputIndex>
 {
     static_assert(Rank <= 6, "Bigger Rank size is not supported!");

@@ -156,8 +165,8 @@ struct DeviceReduceThreadWise
                 const std::array<index_t, NumDstDim> outLengths,
                 const std::array<index_t, NumDstDim> outStrides,
                 const std::array<int, NumReduceDim> reduceDims,
-                float alpha,
-                float beta,
+                double alpha,
+                double beta,
                 const InDataType* in_dev,
                 OutDataType* out_dev,
                 IndexDataType* out_index_dev,

@@ -332,8 +341,8 @@ struct DeviceReduceThreadWise
                        const std::array<index_t, NumDstDim> outLengths,
                        const std::array<index_t, NumDstDim> outStrides,
                        const std::array<int, NumReduceDim> reduceDims,
-                       float alpha,
-                       float beta,
+                       double alpha,
+                       double beta,
                        const void* in_dev,
                        const void* in_index_dev,
                        void* out_dev,
include/ck/tensor_operation/gpu/device/impl/device_softmax_impl.hpp

@@ -156,19 +156,20 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
        Argument(const std::vector<index_t> inLengths,
                 const std::vector<index_t> inStrides,
                 const std::vector<index_t> reduceDims,
-                AccDataType alpha,
-                AccDataType beta,
+                double alpha,
+                double beta,
                 const InDataType* in_dev,
                 OutDataType* out_dev,
                 InElementwiseOp in_elementwise_op,
                 AccElementwiseOp acc_elementwise_op)
-            : alpha_{alpha},
-              beta_{beta},
-              in_dev_{in_dev},
+            : in_dev_{in_dev},
              out_dev_{out_dev},
              in_elementwise_op_{in_elementwise_op},
              acc_elementwise_op_{acc_elementwise_op}
        {
+            alpha_ = static_cast<AccDataType>(alpha);
+            beta_  = static_cast<AccDataType>(beta);
+
            if(Rank != inLengths.size() || Rank != inStrides.size() ||
               NumReduceDim != reduceDims.size())
            {

@@ -336,8 +337,8 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
    static auto MakeArgument(const std::vector<index_t> inLengths,
                             const std::vector<index_t> inStrides,
                             const std::vector<int> reduceDims,
-                            const AccDataType alpha,
-                            const AccDataType beta,
+                            double alpha,
+                            double beta,
                             const InDataType* in_dev,
                             OutDataType* out_dev,
                             InElementwiseOp in_elementwise_op,

@@ -375,8 +376,8 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
    std::unique_ptr<BaseArgument> MakeArgumentPointer(const std::vector<index_t> inLengths,
                                                      const std::vector<index_t> inStrides,
                                                      const std::vector<int> reduceDims,
-                                                     const void* alpha,
-                                                     const void* beta,
+                                                     double alpha,
+                                                     double beta,
                                                      const void* in_dev,
                                                      void* out_dev,
                                                      InElementwiseOp in_elementwise_op,

@@ -385,8 +386,8 @@ struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
        return std::make_unique<Argument>(inLengths,
                                          inStrides,
                                          reduceDims,
-                                         *static_cast<const AccDataType*>(alpha),
-                                         *static_cast<const AccDataType*>(beta),
+                                         alpha,
+                                         beta,
                                          static_cast<const InDataType*>(in_dev),
                                          static_cast<OutDataType*>(out_dev),
                                          in_elementwise_op,
include/ck/tensor_operation/gpu/device/impl/device_sparse_embedding3_forward_layernorm.hpp → include/ck/tensor_operation/gpu/device/impl/device_sparse_embeddings_forward_layernorm.hpp

@@ -12,7 +12,7 @@
 #include "ck/utility/common_header.hpp"
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
-#include "ck/tensor_operation/gpu/grid/gridwise_sparse_embedding3_forward_layernorm.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm.hpp"

 namespace ck {
 namespace tensor_operation {

@@ -24,16 +24,17 @@ template <typename EmbType,
           typename BetaDataType,
           typename AccDataType,
           typename OutType,
           typename EmbElementwiseOperation,
           ck::index_t BlockSize,
           ck::index_t DimClusterSize,
           ck::index_t RowClusterSize,
           ck::index_t DimPerBlock,
           ck::index_t RowPerBlock,
           ck::index_t DimThreadSize,
-          ck::index_t RowVectorSize>
-struct DeviceSparseEmbedding3ForwardLayernorm : public BaseOperator
+          ck::index_t RowVectorSize,
+          ck::index_t NumEmbeddings>
+struct DeviceSparseEmbeddingsForwardLayernorm : public BaseOperator
 {
     static auto MakeOutputDescriptor(const index_t index_length, const index_t rows)
     {
         return make_naive_tensor_descriptor_packed(make_tuple(index_length, rows));

@@ -42,96 +43,79 @@ struct DeviceSparseEmbedding3ForwardLayernorm : public BaseOperator
    struct Argument : public BaseArgument
    {
        Argument(OutType* p_out,
-                const EmbType* p_emb_a,
-                const EmbType* p_emb_b,
-                const EmbType* p_emb_c,
-                const IndexType* p_index_a,
-                const IndexType* p_index_b,
-                const IndexType* p_index_c,
+                const ck::Array<EmbType*, NumEmbeddings>& p_embs,
+                const ck::Array<IndexType*, NumEmbeddings>& p_indexs,
                 const GammaDataType* p_gamma,
                 const BetaDataType* p_beta,
                 const ck::index_t NumRows,
                 const ck::index_t EmbeddingDim,
                 const ck::index_t IndexLength,
-                const AccDataType epsilon)
+                const AccDataType epsilon,
+                const EmbElementwiseOperation emb_elementwise_op)
            : p_out_(p_out),
-             p_emb_a_(p_emb_a),
-             p_emb_b_(p_emb_b),
-             p_emb_c_(p_emb_c),
-             p_index_a_(p_index_a),
-             p_index_b_(p_index_b),
-             p_index_c_(p_index_c),
+             p_embs_(p_embs),
+             p_indexs_(p_indexs),
              p_gamma_(p_gamma),
              p_beta_(p_beta),
              NumRows_(NumRows),
              EmbeddingDim_(EmbeddingDim),
              IndexLength_(IndexLength),
-             epsilon_(epsilon)
+             epsilon_(epsilon),
+             emb_elementwise_op_(emb_elementwise_op)
        {
            grid_size_ = (IndexLength + DimClusterSize - 1) / DimClusterSize;
        }

        OutType* p_out_;
-       const EmbType* p_emb_a_;
-       const EmbType* p_emb_b_;
-       const EmbType* p_emb_c_;
-       const IndexType* p_index_a_;
-       const IndexType* p_index_b_;
-       const IndexType* p_index_c_;
+       ck::Array<EmbType*, NumEmbeddings> p_embs_;
+       ck::Array<IndexType*, NumEmbeddings> p_indexs_;
        const GammaDataType* p_gamma_;
        const BetaDataType* p_beta_;
        ck::index_t NumRows_;
        ck::index_t EmbeddingDim_;
        ck::index_t IndexLength_;
        AccDataType epsilon_;
+       EmbElementwiseOperation emb_elementwise_op_;
        size_t grid_size_;
    };

    virtual std::unique_ptr<BaseArgument>
-   MakeArgumentPointer(void* p_out,
-                       const void* p_emb_a,
-                       const void* p_emb_b,
-                       const void* p_emb_c,
-                       const void* p_index_a,
-                       const void* p_index_b,
-                       const void* p_index_c,
-                       const void* p_gamma,
-                       const void* p_beta,
-                       ck::index_t NumRows,
-                       ck::index_t EmbeddingDim,
-                       ck::index_t IndexLength,
-                       const AccDataType epsilon)
+   MakeArgumentPointer(void* p_out,
+                       const ck::Array<EmbType*, NumEmbeddings>& p_embs,
+                       const ck::Array<IndexType*, NumEmbeddings>& p_indexs,
+                       const void* p_gamma,
+                       const void* p_beta,
+                       ck::index_t EmbeddingDim,
+                       ck::index_t IndexLength,
+                       const AccDataType epsilon,
+                       const EmbElementwiseOperation emb_elementwise_op)
    {
        return std::make_unique<Argument>(reinterpret_cast<OutType*>(p_out),
-                                         reinterpret_cast<const EmbType*>(p_emb_a),
-                                         reinterpret_cast<const EmbType*>(p_emb_b),
-                                         reinterpret_cast<const EmbType*>(p_emb_c),
-                                         reinterpret_cast<const IndexType*>(p_index_a),
-                                         reinterpret_cast<const IndexType*>(p_index_b),
-                                         reinterpret_cast<const IndexType*>(p_index_c),
+                                         p_embs,
+                                         p_indexs,
                                          reinterpret_cast<const GammaDataType*>(p_gamma),
                                          reinterpret_cast<const BetaDataType*>(p_beta),
                                          NumRows,
                                          EmbeddingDim,
                                          IndexLength,
-                                         epsilon);
+                                         epsilon,
+                                         emb_elementwise_op);
    }

-   using GridwiseSparseEmbedding =
-       GridwiseSparseEmbedding3ForwardLayernorm<EmbType,
+   using GridwiseSparseEmbedding =
+       GridwiseSparseEmbeddingsForwardLayernorm<EmbType,
                                                 IndexType,
                                                 GammaDataType,
                                                 BetaDataType,
                                                 AccDataType,
                                                 OutType,
                                                 decltype(MakeOutputDescriptor(1, 1)),
                                                 EmbElementwiseOperation,
                                                 BlockSize,
                                                 DimClusterSize,
                                                 RowClusterSize,
                                                 DimPerBlock,
                                                 RowPerBlock,
                                                 DimThreadSize,
-                                                RowVectorSize>;
+                                                RowVectorSize,
+                                                NumEmbeddings>;

    struct Invoker : public BaseInvoker
    {

@@ -139,14 +123,16 @@ struct DeviceSparseEmbedding3ForwardLayernorm : public BaseOperator
        {
            auto out_desc = MakeOutputDescriptor(arg.IndexLength_, arg.EmbeddingDim_);
-           const auto kernel_main = kernel_sparse_embedding3_forward_layernorm<GridwiseSparseEmbedding,
+           const auto kernel_main = kernel_sparse_embeddings_forward_layernorm<GridwiseSparseEmbedding,
                                                                                EmbType,
                                                                                IndexType,
                                                                                GammaDataType,
                                                                                BetaDataType,
                                                                                AccDataType,
                                                                                OutType,
-                                                                               decltype(out_desc)>;
+                                                                               decltype(out_desc),
+                                                                               EmbElementwiseOperation,
+                                                                               NumEmbeddings>;

            float avg_time = 0;

            avg_time += launch_and_time_kernel(stream_config,
                                               kernel_main,

@@ -154,16 +140,13 @@ struct DeviceSparseEmbedding3ForwardLayernorm : public BaseOperator
                                               dim3(BlockSize),
                                               0,
                                               arg.p_out_,
-                                              arg.p_emb_a_,
-                                              arg.p_emb_b_,
-                                              arg.p_emb_c_,
-                                              arg.p_index_a_,
-                                              arg.p_index_b_,
-                                              arg.p_index_c_,
+                                              arg.p_embs_,
+                                              arg.p_indexs_,
                                               arg.p_gamma_,
                                               arg.p_beta_,
                                               out_desc,
-                                              arg.epsilon_);
+                                              arg.epsilon_,
+                                              arg.emb_elementwise_op_);

            return (avg_time);
        }

@@ -177,7 +160,7 @@ struct DeviceSparseEmbedding3ForwardLayernorm : public BaseOperator
    static bool IsSupportedArgument(const Argument* p_arg)
    {
-       return (RowPerBlock == p_arg->EmbeddingDim_) && (p_arg->NumRows_ % DimPerBlock == 0);
+       return (RowPerBlock == p_arg->EmbeddingDim_);
    }

    bool IsSupportedArgument(const BaseArgument* p_arg) override

@@ -195,7 +178,7 @@ struct DeviceSparseEmbedding3ForwardLayernorm : public BaseOperator
        auto str = std::stringstream();
        // clang-format off
-       str << "DeviceSparseEmbedding3ForwardLayernorm_" << BlockSize << "_" <<
+       str << "DeviceSparseEmbeddingsForwardLayernorm_" << BlockSize << "_" <<
            DimClusterSize << "x" << RowClusterSize << "_" <<
            DimPerBlock << "x" << RowPerBlock << "_" <<
            DimThreadSize << "x" << RowVectorSize;
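With the new NumEmbeddings template parameter, the fixed trio of embedding/index pointers becomes a pair of ck::Array arguments. A hedged call-site fragment; the EmbType/IndexType choices and the device pointers below are assumptions:

// Hypothetical fragment: bundling three embedding tables and their index
// buffers for a NumEmbeddings = 3 instance. emb_dev_* and idx_dev_* are
// assumed device pointers of type ck::half_t* / int64_t*.
constexpr ck::index_t NumEmbeddings = 3;

ck::Array<ck::half_t*, NumEmbeddings> p_embs   = {emb_dev_0, emb_dev_1, emb_dev_2};
ck::Array<int64_t*, NumEmbeddings>    p_indexs = {idx_dev_0, idx_dev_1, idx_dev_2};

// auto arg = device_op.MakeArgumentPointer(p_out, p_embs, p_indexs, p_gamma, p_beta,
//                                          EmbeddingDim, IndexLength, epsilon,
//                                          emb_elementwise_op);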
include/ck/tensor_operation/gpu/element/element_wise_operation.hpp

@@ -172,6 +172,42 @@ struct AddAdd
     }
 };

+// C = A * B
+// E = (C + D0) x D1
+struct AddMultiply
+{
+    template <typename E, typename C, typename D0, typename D1>
+    __host__ __device__ void operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
+
+    template <>
+    __host__ __device__ void operator()<half_t, half_t, half_t, half_t>(half_t& e,
+                                                                        const half_t& c,
+                                                                        const half_t& d0,
+                                                                        const half_t& d1) const
+    {
+        const half_t y = (c + d0) * d1;
+        e              = y;
+    }
+
+    template <>
+    __host__ __device__ void operator()<half_t, float, half_t, half_t>(half_t& e,
+                                                                       const float& c,
+                                                                       const half_t& d0,
+                                                                       const half_t& d1) const
+    {
+        const half_t y = (type_convert<half_t>(c) + d0) * d1;
+        e              = y;
+    }
+
+    template <>
+    __host__ __device__ void operator()<float, float, half_t, half_t>(float& e,
+                                                                      const float& c,
+                                                                      const half_t& d0,
+                                                                      const half_t& d1) const
+    {
+        const float y = (c + d0) * d1;
+        e             = y;
+    }
+};
+
 // C = A * B
 // E = FastGelu(C + D0 + D1)
 struct AddAddFastGelu

@@ -278,6 +314,40 @@ struct Normalize
     double epsilon_;
 };

+// used by BatchNorm inference
+// y = gamma * (x - mean) / sqrt(epsilon + variance) + beta
+// The data type of mean and variance is used as AccDataType
+struct NormalizeInInfer
+{
+    NormalizeInInfer(double epsilon = 1e-4) : epsilon_(epsilon) {}
+
+    template <typename T1, typename T2, typename T3, typename T4>
+    __host__ __device__ constexpr void operator()(T1& y,
+                                                  const T1& x,
+                                                  const T2& mean,
+                                                  const T2& variance,
+                                                  const T3& gamma,
+                                                  const T4& beta) const
+    {
+        static_assert(std::is_same<T2, float>::value || std::is_same<T2, double>::value,
+                      "Data type is not supported by this operation!");
+
+        using ck::type_convert;
+        using ck::math::sqrt;
+
+        T2 tmp_x, tmp_y;
+
+        tmp_x = type_convert<T2>(x);
+
+        tmp_y = ((tmp_x - mean) / sqrt(variance + type_convert<T2>(epsilon_))) *
+                    type_convert<T2>(gamma) +
+                type_convert<T2>(beta);
+        y = type_convert<T1>(tmp_y);
+    };
+
+    double epsilon_;
+};
+
 template <typename Y, typename X>
 struct UnaryTypeConvert;
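A minimal host-side sketch of how these two new functors apply per element; the scalar inputs are made-up values, and in practice the same operator() is invoked for each element inside the device kernels:

// Illustration only; values are arbitrary.
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

void example()
{
    ck::tensor_operation::element_wise::AddMultiply add_multiply;
    float e = 0.f;
    // e = (c + d0) * d1, with c accumulated in float and d0/d1 as half_t
    add_multiply(e, 2.0f, ck::half_t(1.0f), ck::half_t(0.5f)); // e becomes 1.5f

    ck::tensor_operation::element_wise::NormalizeInInfer bn_infer{1e-4};
    float y = 0.f;
    // y = gamma * (x - mean) / sqrt(epsilon + variance) + beta
    bn_infer(y, /*x=*/3.0f, /*mean=*/1.0f, /*variance=*/4.0f, /*gamma=*/2.0f, /*beta=*/0.5f);
}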
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp

@@ -154,6 +154,50 @@ struct BlockToCTileMap_M00_N0_M01Adapt
         index_t idx_M01          = idx_M0 % M01_;
         index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;

+        /**
+         * idxN0
+         *
+         * |<             mtx   N            >|
+         *
+         *     NPerBlock   NPerBlock   NPerBlock   NPerBlock
+         *       N_0         N_1         N_2         N_3
+         *  -  |-----------|-----------|-----------|-----|-----|-
+         *  ^  |   -  -  0 |/----> 2   |           |     |     |
+         *  |  |  |     /  |           |           |     |     |  M_0  MPerBlock
+         *  |  |  M    /|  |           |           |     |     |
+         *     |-0---/--|--|-----------|-----------|-----|-----|-
+         *     |  1 /   |  |           |  blockid  |     |     |
+         * idxM0   /    V  |           |     5     |     |     |  M_1  MPerBlock
+         *  |  | - V  1    |  -  3     |           |     |     |
+         *     |-----------|-----------|-----------|-----|-----|-
+         * mtx M           |           |           |     |     |
+         *  |  |           |           |           |     |     |  M_2  MPerBlock
+         *  |  |           |           |           |     |     |
+         *     |-----------|-----------|-----------|-----|-----|-
+         *     |           |           |           |     |     |
+         *  |  |           |           |           |     |     |  M_3  MPerBlock
+         *  |  |           |           |           |     |     |
+         *     |-----------|-----------|-----------|-----|-----|-
+         *  V  |           |           |           |     |     |
+         *  -  |-----------|-----------|-----------|-----|-----|-  M_4  MPerBlock
+         *     |           |           |           |     |     |
+         *     |-----------|-----------|-----------|-----|-----|-
+         * Example:
+         *  assume:
+         *      M0 = 5
+         *      N0 = 4
+         *      block_1d_id = 5
+         *      M01 = 2
+         *
+         *  idx_N0 = 1
+         *  idx_M0 = 1
+         *  M01_adapt = 2
+         *  idx_M00 = 0
+         *  idx_M01 = 1
+         *  idx_N0_M01_local = 5
+         *  output {1, 2}
+         */
+
         return make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
                           idx_N0_M01_local / M01_adapt);
     }
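A small standalone re-derivation of the worked example from the new comment (M0 = 5, N0 = 4, block_1d_id = 5, M01 = 2). The helper below is a simplified illustration, not part of composable_kernel, and in particular the M01_adapt handling of the ragged last group is an assumption that happens to match the documented example:

#include <cstdio>
#include <utility>

std::pair<int, int> block_to_ctile(int block_1d_id, int M0, int N0, int M01)
{
    const int idx_N0  = block_1d_id % N0;
    const int idx_M0  = block_1d_id / N0;
    const int idx_M00 = idx_M0 / M01;
    const int idx_M01 = idx_M0 % M01;
    // shrink the group height for the last (possibly partial) M01 group
    const int M01_adapt = (idx_M00 < M0 / M01) ? M01 : M0 % M01;
    const int idx_N0_M01_local = idx_N0 + idx_M01 * N0;
    return {idx_N0_M01_local % M01_adapt + idx_M00 * M01, // M tile index
            idx_N0_M01_local / M01_adapt};                // N tile index
}

int main()
{
    auto [m, n] = block_to_ctile(/*block_1d_id=*/5, /*M0=*/5, /*N0=*/4, /*M01=*/2);
    std::printf("{%d, %d}\n", m, n); // prints {1, 2}, matching the comment's example
}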
include/ck/tensor_operation/gpu/grid/gridwise_gemm_waveletmodel.hpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"

namespace ck {

template <typename TileLoadThreadGroup, index_t NumGemmKPrefetchStage>
struct GridwiseGemmLoadWave;

// 1-stage prefetch
template <typename TileLoadThreadGroup>
struct GridwiseGemmLoadWave<TileLoadThreadGroup, 1>
{
    __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */)
    {
        // TODO: improve applicability
        return true;
    }

    __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
    {
        return num_loop > 1;
    }

    template <bool HasMainLoop,
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep>
    static __device__ void RunLoadWavePipeline(const AGridDesc& a_grid_desc,
                                               const ABlockDesc& a_block_desc,
                                               ABlockTransfer& a_blockwise_copy,
                                               const AGridBuffer& a_grid_buf,
                                               ABlockBuffer& a_block_buf,
                                               const ABlockTransferStep& a_block_copy_step,
                                               const BGridDesc& b_grid_desc,
                                               const BBlockDesc& b_block_desc,
                                               BBlockTransfer& b_blockwise_copy,
                                               const BGridBuffer& b_grid_buf,
                                               BBlockBuffer& b_block_buf,
                                               const BBlockTransferStep& b_block_copy_step,
                                               index_t num_loop)
    {
        // global read 0
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        // move to 1
        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        // LDS write 0
        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

        if constexpr(HasMainLoop)
        {
            index_t i = 0;

            do
            {
                // sync for Load threads()
                block_sync_lds();

                // global read i + 1
                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
                b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

                // move to i + 2
                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                // sync with math threads()
                block_sync_lds();

                // LDS write i + 1
                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
                b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

                ++i;
            } while(i < (num_loop - 1));
        }

        // tail
        {
            block_sync_lds();
            // GEMM num_loop - 1
        }
    }
};

template <typename TileMathThreadGroup, index_t NumGemmKPrefetchStage>
struct GridwiseGemmMathWave;

// 1-stage prefetch
template <typename TileMathThreadGroup>
struct GridwiseGemmMathWave<TileMathThreadGroup, 1>
{
    __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }

    __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
    {
        return num_loop > 1;
    }

    template <bool HasMainLoop,
              typename ABlockBuffer,
              typename BBlockBuffer,
              typename BlockwiseGemm,
              typename CThreadBuffer>
    static __device__ void RunMathWavePipeline(ABlockBuffer& a_block_buf,
                                               BBlockBuffer& b_block_buf,
                                               const BlockwiseGemm& block_gemm,
                                               CThreadBuffer& c_thread_buf,
                                               index_t num_loop)
    {
        // Initialize C
        c_thread_buf.Clear();

        // main body
        if constexpr(HasMainLoop)
        {
            index_t i = 0;

            do
            {
                block_sync_lds();

                // GEMM i
                block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);

                block_sync_lds();

                ++i;
            } while(i < (num_loop - 1));
        }

        // tail
        {
            block_sync_lds();

            // GEMM num_loop - 1
            block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
        }
    }
};

} // namespace ck
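The new header splits a GEMM work-group into a load wave that streams A/B tiles from global memory into LDS and a math wave that consumes them, with both sides meeting at block_sync_lds(). A toy CPU model of the resulting 1-stage prefetch schedule, for illustration only (this is not the GPU code and makes no claim about the collapsed cshuffle kernel):

#include <cstdio>

int main()
{
    const int num_loop = 4;
    int lds_tile = 0; // prologue: "global read 0" followed by "LDS write 0"

    for(int i = 0; i < num_loop - 1; ++i)
    {
        // the load wave fetches tile i + 1 while the math wave works on tile i
        std::printf("math on tile %d, prefetch of tile %d in flight\n", lds_tile, i + 1);
        lds_tile = i + 1; // "LDS write i + 1" once both waves have synchronized
    }
    std::printf("tail: math on tile %d\n", lds_tile); // "GEMM num_loop - 1"
    return 0;
}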
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp (new file, mode 100644)

This diff is collapsed (+744, -0).
include/ck/tensor_operation/gpu/grid/gridwise_sparse_embedding3_forward_layernorm.hpp → include/ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm.hpp

This diff is collapsed (+61, -84).