Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
ddc76a8b
Commit
ddc76a8b
authored
May 31, 2022
by
Anthony Chang
Browse files
AccElemOp for gemm outputs prior to feeding to layernorm
parent
12db5b6d
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
55 additions
and
11 deletions
+55
-11
example/21_gemm_layernorm/gemm_xdl_layernorm_single_kernel_fp16.cpp
..._gemm_layernorm/gemm_xdl_layernorm_single_kernel_fp16.cpp
+20
-7
include/ck/tensor_operation/gpu/device/device_gemm_xdl_layernorm_cshuffle.hpp
...eration/gpu/device/device_gemm_xdl_layernorm_cshuffle.hpp
+13
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp
...tion/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp
+8
-2
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_layernorm.hpp
...ference_tensor_operation/cpu/reference_gemm_layernorm.hpp
+14
-2
No files found.
example/21_gemm_layernorm/gemm_xdl_layernorm_single_kernel_fp16.cpp
View file @
ddc76a8b
...
@@ -36,20 +36,29 @@ using ALayout = ck::tensor_layout::gemm::RowMajor;
...
@@ -36,20 +36,29 @@ using ALayout = ck::tensor_layout::gemm::RowMajor;
using
BLayout
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
BLayout
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
CLayout
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
CLayout
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
struct
Relu
{
template
<
typename
OutT
,
typename
InT
>
__host__
__device__
void
operator
()(
OutT
&
y
,
const
InT
&
x
)
const
{
y
=
x
>
0
?
x
:
0
;
}
};
using
AElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
AElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
BElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
BElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
C
ElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
Acc
ElementOp
=
Relu
;
//
using
Acc
ElementOp =
ck::tensor_operation::element_wise::PassThrough
;
using
C
ElementOp
=
Relu
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
// clang-format off
// clang-format off
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmLayerNorm_Xdl_CShuffle
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmLayerNorm_Xdl_CShuffle
//######| ALayout| BLayout| CLayout| AData| BData| CData| C0Data| GemmAcc| CShuffle| ReduceAcc| A| B|
C|
GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//######| ALayout| BLayout| CLayout| AData| BData| CData| C0Data| GemmAcc| CShuffle| ReduceAcc| A| B|
Acc| C|
GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
//######| | | | Type| Type| Type| Type| DataType| DataType| DataType| Elementwise| Elementwise| Elementwise|
Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | Type| Type| Type| Type| DataType| DataType| DataType| Elementwise| Elementwise|
Elementwise|
Elementwise|
Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation|
| Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | | | Operation| Operation|
Operation|
Operation|
| Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//######| | | | | | | | | | | | |
|
|
| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
<
Row
,
Col
,
Row
,
ADataType
,
BDataType
,
CDataType
,
C0DataType
,
AccDataType
,
AccDataType
,
AccDataType
,
AElementOp
,
BElementOp
,
C
ElementOp
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
;
<
Row
,
Col
,
Row
,
ADataType
,
BDataType
,
CDataType
,
C0DataType
,
AccDataType
,
AccDataType
,
AccDataType
,
AElementOp
,
BElementOp
,
Acc
ElementOp
,
CElementOp
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
8
,
S
<
64
,
4
>
,
4
,
1
>
;
// clang-format on
// clang-format on
using
ReferenceInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemmLayernorm
<
ADataType
,
using
ReferenceInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemmLayernorm
<
ADataType
,
...
@@ -59,6 +68,7 @@ using ReferenceInstance = ck::tensor_operation::host::ReferenceGemmLayernorm<ADa
...
@@ -59,6 +68,7 @@ using ReferenceInstance = ck::tensor_operation::host::ReferenceGemmLayernorm<ADa
AccDataType
,
AccDataType
,
AElementOp
,
AElementOp
,
BElementOp
,
BElementOp
,
AccElementOp
,
CElementOp
>
;
CElementOp
>
;
int
main
(
int
argc
,
char
*
argv
[])
int
main
(
int
argc
,
char
*
argv
[])
...
@@ -176,6 +186,7 @@ int main(int argc, char* argv[])
...
@@ -176,6 +186,7 @@ int main(int argc, char* argv[])
auto
a_element_op
=
AElementOp
{};
auto
a_element_op
=
AElementOp
{};
auto
b_element_op
=
BElementOp
{};
auto
b_element_op
=
BElementOp
{};
auto
acc_element_op
=
AccElementOp
{};
auto
c_element_op
=
CElementOp
{};
auto
c_element_op
=
CElementOp
{};
// do GEMM
// do GEMM
...
@@ -195,6 +206,7 @@ int main(int argc, char* argv[])
...
@@ -195,6 +206,7 @@ int main(int argc, char* argv[])
StrideC
,
StrideC
,
a_element_op
,
a_element_op
,
b_element_op
,
b_element_op
,
acc_element_op
,
c_element_op
);
c_element_op
);
if
(
!
gemm
.
IsSupportedArgument
(
argument
))
if
(
!
gemm
.
IsSupportedArgument
(
argument
))
...
@@ -235,6 +247,7 @@ int main(int argc, char* argv[])
...
@@ -235,6 +247,7 @@ int main(int argc, char* argv[])
c_m_n_host_result
,
c_m_n_host_result
,
a_element_op
,
a_element_op
,
b_element_op
,
b_element_op
,
acc_element_op
,
c_element_op
);
c_element_op
);
ref_invoker
.
Run
(
ref_argument
);
ref_invoker
.
Run
(
ref_argument
);
...
...
include/ck/tensor_operation/gpu/device/device_gemm_xdl_layernorm_cshuffle.hpp
View file @
ddc76a8b
...
@@ -33,6 +33,7 @@ template <typename ALayout,
...
@@ -33,6 +33,7 @@ template <typename ALayout,
typename
ReduceAccDataType
,
typename
ReduceAccDataType
,
typename
AElementwiseOperation
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
BElementwiseOperation
,
typename
AccElementwiseOperation
,
typename
CElementwiseOperation
,
typename
CElementwiseOperation
,
GemmSpecialization
GemmSpec
,
GemmSpecialization
GemmSpec
,
index_t
NumGemmKPrefetchStage
,
index_t
NumGemmKPrefetchStage
,
...
@@ -380,6 +381,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
...
@@ -380,6 +381,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
ReduceAccDataType
,
ReduceAccDataType
,
AElementwiseOperation
,
AElementwiseOperation
,
BElementwiseOperation
,
BElementwiseOperation
,
AccElementwiseOperation
,
CElementwiseOperation
,
CElementwiseOperation
,
InMemoryDataOperationEnum
::
Set
,
InMemoryDataOperationEnum
::
Set
,
AGridDesc_AK0_M_AK1
,
AGridDesc_AK0_M_AK1
,
...
@@ -439,6 +441,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
...
@@ -439,6 +441,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
index_t
StrideC
,
index_t
StrideC
,
AElementwiseOperation
a_element_op
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
AccElementwiseOperation
acc_element_op
,
CElementwiseOperation
c_element_op
)
CElementwiseOperation
c_element_op
)
:
p_a_grid_
{
p_a_grid
},
:
p_a_grid_
{
p_a_grid
},
p_b_grid_
{
p_b_grid
},
p_b_grid_
{
p_b_grid
},
...
@@ -455,6 +458,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
...
@@ -455,6 +458,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
block_2_ctile_map_
{},
block_2_ctile_map_
{},
a_element_op_
{
a_element_op
},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
b_element_op_
{
b_element_op
},
acc_element_op_
{
acc_element_op
},
c_element_op_
{
c_element_op
}
c_element_op_
{
c_element_op
}
{
{
if
(
GridwiseGemm
::
CheckValidity
(
if
(
GridwiseGemm
::
CheckValidity
(
...
@@ -489,6 +493,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
...
@@ -489,6 +493,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
typename
GridwiseGemm
::
DefaultBlock2CTileMap
block_2_ctile_map_
;
typename
GridwiseGemm
::
DefaultBlock2CTileMap
block_2_ctile_map_
;
AElementwiseOperation
a_element_op_
;
AElementwiseOperation
a_element_op_
;
BElementwiseOperation
b_element_op_
;
BElementwiseOperation
b_element_op_
;
AccElementwiseOperation
acc_element_op_
;
CElementwiseOperation
c_element_op_
;
CElementwiseOperation
c_element_op_
;
};
};
...
@@ -538,6 +543,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
...
@@ -538,6 +543,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
C0DataType
,
C0DataType
,
AElementwiseOperation
,
AElementwiseOperation
,
BElementwiseOperation
,
BElementwiseOperation
,
AccElementwiseOperation
,
CElementwiseOperation
,
CElementwiseOperation
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
...
@@ -560,6 +566,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
...
@@ -560,6 +566,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
arg
.
p_c0_beta_
,
arg
.
p_c0_beta_
,
arg
.
a_element_op_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
b_element_op_
,
arg
.
acc_element_op_
,
arg
.
c_element_op_
,
arg
.
c_element_op_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
...
@@ -576,6 +583,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
...
@@ -576,6 +583,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
C0DataType
,
C0DataType
,
AElementwiseOperation
,
AElementwiseOperation
,
BElementwiseOperation
,
BElementwiseOperation
,
AccElementwiseOperation
,
CElementwiseOperation
,
CElementwiseOperation
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
...
@@ -597,6 +605,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
...
@@ -597,6 +605,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
arg
.
p_c0_beta_
,
arg
.
p_c0_beta_
,
arg
.
a_element_op_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
b_element_op_
,
arg
.
acc_element_op_
,
arg
.
c_element_op_
,
arg
.
c_element_op_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
...
@@ -648,6 +657,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
...
@@ -648,6 +657,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
index_t
StrideC
,
index_t
StrideC
,
AElementwiseOperation
a_element_op
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
AccElementwiseOperation
acc_element_op
,
CElementwiseOperation
c_element_op
)
CElementwiseOperation
c_element_op
)
{
{
return
Argument
{
p_a
,
return
Argument
{
p_a
,
...
@@ -664,6 +674,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
...
@@ -664,6 +674,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
StrideC
,
StrideC
,
a_element_op
,
a_element_op
,
b_element_op
,
b_element_op
,
acc_element_op
,
c_element_op
};
c_element_op
};
}
}
...
@@ -683,6 +694,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
...
@@ -683,6 +694,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
index_t
StrideC
,
index_t
StrideC
,
AElementwiseOperation
a_element_op
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
AccElementwiseOperation
acc_element_op
,
CElementwiseOperation
c_element_op
,
CElementwiseOperation
c_element_op
,
index_t
/* KBatch */
=
1
)
index_t
/* KBatch */
=
1
)
{
{
...
@@ -700,6 +712,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
...
@@ -700,6 +712,7 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
StrideC
,
StrideC
,
a_element_op
,
a_element_op
,
b_element_op
,
b_element_op
,
acc_element_op
,
c_element_op
);
c_element_op
);
}
}
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp
View file @
ddc76a8b
...
@@ -21,6 +21,7 @@ template <typename GridwiseGemm,
...
@@ -21,6 +21,7 @@ template <typename GridwiseGemm,
typename
FloatC0
,
typename
FloatC0
,
typename
AElementwiseOperation
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
BElementwiseOperation
,
typename
AccElementwiseOperation
,
typename
CElementwiseOperation
,
typename
CElementwiseOperation
,
typename
AGridDesc_AK0_M_AK1
,
typename
AGridDesc_AK0_M_AK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
BGridDesc_BK0_N_BK1
,
...
@@ -41,6 +42,7 @@ __global__ void
...
@@ -41,6 +42,7 @@ __global__ void
const
FloatC0
*
__restrict__
p_c0_beta_grid
,
// 1xN
const
FloatC0
*
__restrict__
p_c0_beta_grid
,
// 1xN
const
AElementwiseOperation
a_element_op
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
BElementwiseOperation
b_element_op
,
const
AccElementwiseOperation
acc_element_op
,
const
CElementwiseOperation
c_element_op
,
const
CElementwiseOperation
c_element_op
,
const
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
,
const
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
,
const
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
,
...
@@ -62,6 +64,7 @@ __global__ void
...
@@ -62,6 +64,7 @@ __global__ void
p_shared
,
p_shared
,
a_element_op
,
a_element_op
,
b_element_op
,
b_element_op
,
acc_element_op
,
c_element_op
,
c_element_op
,
a_grid_desc_ak0_m_ak1
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
b_grid_desc_bk0_n_bk1
,
...
@@ -79,6 +82,7 @@ __global__ void
...
@@ -79,6 +82,7 @@ __global__ void
ignore
=
p_c0_beta_grid
;
ignore
=
p_c0_beta_grid
;
ignore
=
a_element_op
;
ignore
=
a_element_op
;
ignore
=
b_element_op
;
ignore
=
b_element_op
;
ignore
=
acc_element_op
;
ignore
=
c_element_op
;
ignore
=
c_element_op
;
ignore
=
a_grid_desc_ak0_m_ak1
;
ignore
=
a_grid_desc_ak0_m_ak1
;
ignore
=
b_grid_desc_bk0_n_bk1
;
ignore
=
b_grid_desc_bk0_n_bk1
;
...
@@ -99,6 +103,7 @@ template <typename FloatAB,
...
@@ -99,6 +103,7 @@ template <typename FloatAB,
typename
FloatReduceAcc
,
// Data type after shuffle
typename
FloatReduceAcc
,
// Data type after shuffle
typename
AElementwiseOperation
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
BElementwiseOperation
,
typename
AccElementwiseOperation
,
typename
CElementwiseOperation
,
typename
CElementwiseOperation
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
typename
AGridDesc_AK0_M_AK1
,
typename
AGridDesc_AK0_M_AK1
,
...
@@ -377,6 +382,7 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
...
@@ -377,6 +382,7 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
void
*
__restrict__
p_shared
,
void
*
__restrict__
p_shared
,
const
AElementwiseOperation
&
a_element_op
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
AccElementwiseOperation
&
acc_element_op
,
const
CElementwiseOperation
&
c_element_op
,
const
CElementwiseOperation
&
c_element_op
,
const
AGridDesc_AK0_M_AK1
&
a_grid_desc_ak0_m_ak1
,
const
AGridDesc_AK0_M_AK1
&
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
&
b_grid_desc_bk0_n_bk1
,
const
BGridDesc_BK0_N_BK1
&
b_grid_desc_bk0_n_bk1
,
...
@@ -630,7 +636,7 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
...
@@ -630,7 +636,7 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
FloatCShuffle
,
FloatCShuffle
,
decltype
(
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2
),
decltype
(
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2
),
decltype
(
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
),
decltype
(
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
),
ck
::
tensor_operation
::
e
lement
_
wise
::
PassThrough
,
AccE
lementwise
Operation
,
Sequence
<
CShuffleMXdlPerWavePerShuffle
,
Sequence
<
CShuffleMXdlPerWavePerShuffle
,
CShuffleNXdlPerWavePerShuffle
,
CShuffleNXdlPerWavePerShuffle
,
I1
,
I1
,
...
@@ -654,7 +660,7 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
...
@@ -654,7 +660,7 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
m_thread_data_on_block_idx
[
I3
],
m_thread_data_on_block_idx
[
I3
],
m_thread_data_on_block_idx
[
I4
],
m_thread_data_on_block_idx
[
I4
],
n_thread_data_on_block_idx
[
I2
]),
n_thread_data_on_block_idx
[
I2
]),
ck
::
tensor_operation
::
element_wise
::
PassThrough
{}
};
acc_element_op
};
// shuffle: blockwise copy C from LDS to global
// shuffle: blockwise copy C from LDS to global
auto
c_shuffle_block_copy_lds_to_global
=
ThreadGroupTensorSliceTransfer_v6r1
<
auto
c_shuffle_block_copy_lds_to_global
=
ThreadGroupTensorSliceTransfer_v6r1
<
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_layernorm.hpp
View file @
ddc76a8b
...
@@ -16,6 +16,7 @@ template <typename ADataType,
...
@@ -16,6 +16,7 @@ template <typename ADataType,
typename
AccDataType
,
typename
AccDataType
,
typename
AElementwiseOperation
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
BElementwiseOperation
,
typename
AccElementwiseOperation
,
typename
CElementwiseOperation
>
typename
CElementwiseOperation
>
struct
ReferenceGemmLayernorm
:
public
device
::
BaseOperator
struct
ReferenceGemmLayernorm
:
public
device
::
BaseOperator
{
{
...
@@ -25,7 +26,7 @@ struct ReferenceGemmLayernorm : public device::BaseOperator
...
@@ -25,7 +26,7 @@ struct ReferenceGemmLayernorm : public device::BaseOperator
AccDataType
,
AccDataType
,
AElementwiseOperation
,
AElementwiseOperation
,
BElementwiseOperation
,
BElementwiseOperation
,
C
ElementwiseOperation
>
;
Acc
ElementwiseOperation
>
;
// D = Layernorm(acc + broadcast(bias)) * broadcast(gamma) + broadcast(beta)
// D = Layernorm(acc + broadcast(bias)) * broadcast(gamma) + broadcast(beta)
template
<
typename
InDataType
,
typename
OutDataType
,
typename
ComputeDataType
>
template
<
typename
InDataType
,
typename
OutDataType
,
typename
ComputeDataType
>
...
@@ -95,6 +96,7 @@ struct ReferenceGemmLayernorm : public device::BaseOperator
...
@@ -95,6 +96,7 @@ struct ReferenceGemmLayernorm : public device::BaseOperator
Tensor
<
CDataType
>&
c_m_n
,
Tensor
<
CDataType
>&
c_m_n
,
AElementwiseOperation
a_element_op
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
AccElementwiseOperation
acc_element_op
,
CElementwiseOperation
c_element_op
,
CElementwiseOperation
c_element_op
,
const
CDataType
epsilon
=
1e-5
)
const
CDataType
epsilon
=
1e-5
)
:
a_m_k_
{
a_m_k
},
:
a_m_k_
{
a_m_k
},
...
@@ -105,6 +107,7 @@ struct ReferenceGemmLayernorm : public device::BaseOperator
...
@@ -105,6 +107,7 @@ struct ReferenceGemmLayernorm : public device::BaseOperator
c_m_n_
{
c_m_n
},
c_m_n_
{
c_m_n
},
a_element_op_
{
a_element_op
},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
b_element_op_
{
b_element_op
},
acc_element_op_
{
acc_element_op
},
c_element_op_
{
c_element_op
},
c_element_op_
{
c_element_op
},
epsilon_
{
epsilon
}
epsilon_
{
epsilon
}
{
{
...
@@ -119,7 +122,9 @@ struct ReferenceGemmLayernorm : public device::BaseOperator
...
@@ -119,7 +122,9 @@ struct ReferenceGemmLayernorm : public device::BaseOperator
AElementwiseOperation
a_element_op_
;
AElementwiseOperation
a_element_op_
;
BElementwiseOperation
b_element_op_
;
BElementwiseOperation
b_element_op_
;
AccElementwiseOperation
acc_element_op_
;
CElementwiseOperation
c_element_op_
;
CElementwiseOperation
c_element_op_
;
const
CDataType
epsilon_
;
const
CDataType
epsilon_
;
};
};
...
@@ -140,11 +145,16 @@ struct ReferenceGemmLayernorm : public device::BaseOperator
...
@@ -140,11 +145,16 @@ struct ReferenceGemmLayernorm : public device::BaseOperator
acc_m_n
,
acc_m_n
,
arg
.
a_element_op_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
);
arg
.
ac
c_element_op_
);
ref_invoker
.
Run
(
ref_argument
);
ref_invoker
.
Run
(
ref_argument
);
RunLayernorm
(
arg
.
c_m_n_
,
acc_m_n
,
arg
.
c0_n_bias_
,
arg
.
c0_n_gamma_
,
arg
.
c0_n_beta_
);
RunLayernorm
(
arg
.
c_m_n_
,
acc_m_n
,
arg
.
c0_n_bias_
,
arg
.
c0_n_gamma_
,
arg
.
c0_n_beta_
);
arg
.
c_m_n_
.
ForEach
([
&
](
auto
&
self
,
auto
idx
){
arg
.
c_element_op_
(
self
(
idx
[
0
],
idx
[
1
]),
self
(
idx
[
0
],
idx
[
1
]));
});
return
0
;
return
0
;
}
}
...
@@ -171,6 +181,7 @@ struct ReferenceGemmLayernorm : public device::BaseOperator
...
@@ -171,6 +181,7 @@ struct ReferenceGemmLayernorm : public device::BaseOperator
Tensor
<
CDataType
>&
c_m_n
,
Tensor
<
CDataType
>&
c_m_n
,
AElementwiseOperation
a_element_op
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
AccElementwiseOperation
acc_element_op
,
CElementwiseOperation
c_element_op
,
CElementwiseOperation
c_element_op
,
const
CDataType
epsilon
=
1e-5
)
const
CDataType
epsilon
=
1e-5
)
{
{
...
@@ -182,6 +193,7 @@ struct ReferenceGemmLayernorm : public device::BaseOperator
...
@@ -182,6 +193,7 @@ struct ReferenceGemmLayernorm : public device::BaseOperator
c_m_n
,
c_m_n
,
a_element_op
,
a_element_op
,
b_element_op
,
b_element_op
,
acc_element_op
,
c_element_op
,
c_element_op
,
epsilon
};
epsilon
};
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment