Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
f278b2a5
Commit
f278b2a5
authored
Dec 26, 2022
by
rocking
Browse files
Add EMeanVarDataType parameter.
parent
d3f2dbbd
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
86 additions
and
92 deletions
+86
-92
example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp
...yernorm/gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp
+9
-8
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
...ce/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
+62
-65
include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp
...dwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp
+8
-10
include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp
...mm_layernorm/gridwise_welford_second_half_layernorm2d.hpp
+7
-9
No files found.
example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp
View file @
f278b2a5
...
...
@@ -39,6 +39,7 @@ using CShuffleDataType = F32;
using
D0DataType
=
F16
;
using
D1DataType
=
F16
;
using
DsDataType
=
ck
::
Tuple
<
D0DataType
,
D1DataType
>
;
using
EMeanVarDataType
=
F16
;
using
GammaDataType
=
F16
;
using
BetaDataType
=
F16
;
using
HDataType
=
F16
;
...
...
@@ -60,11 +61,11 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
// clang-format off
using
DeviceOpInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmMultipleDLayernorm_Xdl_CShuffle
//######| ALayout| BLayout| DsLayout| HLayout| AData| BData| AccData| CShuffle| DsData| GammaData| BetaData| HData| A| B| CDE| H| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| PostShuffle| PostShuffle| Layernorm| Layernorm|
//######| | | | | Type| Type| Type| DataType| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ThreadClusterLengths| ScalarPerVector| ThreadClusterLengths| ThreadSliceSize|
//######| | | | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _M_N| _M_N| _M_N| _M|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
<
ALayout
,
BLayout
,
DsLayout
,
HLayout
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
GammaDataType
,
BetaDataType
,
HDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
,
HElementOp
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
32
,
8
>
,
8
,
S
<
8
,
32
>
,
8
>
;
//######| ALayout| BLayout| DsLayout| HLayout| AData| BData| AccData| CShuffle| DsData|
EMeanVarData|
GammaData| BetaData| HData| A| B| CDE| H| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| PostShuffle| PostShuffle| Layernorm| Layernorm|
//######| | | | | Type| Type| Type| DataType| Type|
Type|
Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ThreadClusterLengths| ScalarPerVector| ThreadClusterLengths| ThreadSliceSize|
//######| | | | | | | | | |
|
| | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _M_N| _M_N| _M_N| _M|
//######| | | | | | | | | |
|
| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
<
ALayout
,
BLayout
,
DsLayout
,
HLayout
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EMeanVarDataType
,
GammaDataType
,
BetaDataType
,
HDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
,
HElementOp
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
1
,
1
,
1
,
S
<
32
,
8
>
,
8
,
S
<
8
,
32
>
,
8
>
;
// clang-format on
auto
f_host_tensor_descriptor1d
=
[](
std
::
size_t
len
,
std
::
size_t
stride
)
{
...
...
@@ -86,7 +87,7 @@ auto f_host_tensor_descriptor2d =
}
};
void
host_gemm_layernorm
(
Tensor
<
H
DataType
>&
e_m_n
,
void
host_gemm_layernorm
(
Tensor
<
EMeanVar
DataType
>&
e_m_n
,
Tensor
<
HDataType
>&
h_m_n
,
const
Tensor
<
ADataType
>&
a_m_k
,
const
Tensor
<
BDataType
>&
b_k_n
,
...
...
@@ -109,7 +110,7 @@ void host_gemm_layernorm(Tensor<HDataType>& e_m_n,
BElementOp
,
PassThrough
>
;
using
ReferenceLayernorm
=
ck
::
tensor_operation
::
host
::
ReferenceLayernorm
<
H
DataType
,
using
ReferenceLayernorm
=
ck
::
tensor_operation
::
host
::
ReferenceLayernorm
<
EMeanVar
DataType
,
GammaDataType
,
BetaDataType
,
HDataType
,
...
...
@@ -229,7 +230,7 @@ int main()
if
(
do_verification
)
{
Tensor
<
H
DataType
>
e_m_n_host
(
HostTensorDescriptor
{
M
,
N
});
Tensor
<
EMeanVar
DataType
>
e_m_n_host
(
HostTensorDescriptor
{
M
,
N
});
Tensor
<
HDataType
>
h_m_n_host
(
HostTensorDescriptor
{
M
,
N
});
host_gemm_layernorm
(
e_m_n_host
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
View file @
f278b2a5
...
...
@@ -23,9 +23,7 @@ namespace ck {
template
<
typename
GridwiseGemmWelford
,
typename
ABDataType
,
typename
DsPointer
,
typename
EDataType
,
typename
MeanDataType
,
typename
VarDataType
,
typename
EMeanVarDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
,
...
...
@@ -45,9 +43,9 @@ __global__ void
const
ABDataType
*
__restrict__
p_a_grid
,
const
ABDataType
*
__restrict__
p_b_grid
,
DsPointer
p_ds_grid
,
EDataType
*
__restrict__
p_e_grid
,
MeanDataType
*
__restrict__
p_welford_mean_grid
,
VarDataType
*
__restrict__
p_welford_var_grid
,
E
MeanVar
DataType
*
__restrict__
p_e_grid
,
E
Mean
Var
DataType
*
__restrict__
p_welford_mean_grid
,
EMean
VarDataType
*
__restrict__
p_welford_var_grid
,
int32_t
*
__restrict__
p_welford_count_grid
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
...
...
@@ -111,10 +109,8 @@ __global__ void
}
template
<
typename
GridwiseWelfordLayernorm
,
typename
EDataType
,
typename
E
MeanVar
DataType
,
typename
HDataType
,
typename
MeanDataType
,
typename
VarDataType
,
typename
GammaDataType
,
typename
BetaDataType
,
typename
ComputeDataType
,
...
...
@@ -128,9 +124,9 @@ __global__ void
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
kernel_welford_layernorm2d_second_half
(
const
EDataType
*
__restrict__
p_e_grid
,
const
MeanDataType
*
__restrict__
p_in_welford_mean_grid
,
const
VarDataType
*
__restrict__
p_in_welford_var_grid
,
const
E
MeanVar
DataType
*
__restrict__
p_e_grid
,
const
E
Mean
Var
DataType
*
__restrict__
p_in_welford_mean_grid
,
const
EMean
VarDataType
*
__restrict__
p_in_welford_var_grid
,
const
int32_t
*
__restrict__
p_in_welford_count_grid
,
const
GammaDataType
*
__restrict__
p_gamma_grid
,
const
BetaDataType
*
__restrict__
p_beta_grid
,
...
...
@@ -192,6 +188,7 @@ template <typename ALayout,
typename
AccDataType
,
typename
CShuffleDataType
,
typename
DsDataType
,
typename
EMeanVarDataType
,
typename
GammaDataType
,
typename
BetaDataType
,
typename
HDataType
,
...
...
@@ -249,16 +246,14 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
CDEElementwiseOperation
,
HElementwiseOperation
>
{
using
DeviceOp
=
DeviceGemmMultipleDLayernorm_Xdl_CShuffle
;
using
ELayout
=
HLayout
;
// EDataType, MeanDataType and VarDataType must be the same.
// eg. M, N, K = [1, 1, 1],
// in case of layernorm, divisor = 1 / sqrt(var + 1e-5) = 316.227783
// if (x - mean) != 0, (x - mean) * divisor * gamma might be too large
// However, (x - mean) * divisor * gamma should be 0 in this case
using
EDataType
=
HDataType
;
using
MeanDataType
=
HDataTyp
e
;
using
VarDataType
=
HDataType
;
using
DeviceOp
=
DeviceGemmMultipleDLayernorm_Xdl_CShuffl
e
;
using
ELayout
=
HLayout
;
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
static
constexpr
index_t
LayernormHDstVectorSize
=
PostShuffleScalarPerVector
;
...
...
@@ -392,9 +387,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
MeanDataType
,
VarDataType
,
EMeanVarDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
,
...
...
@@ -442,10 +435,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
using
Block2ETileMap
=
typename
GridwiseGemmWelford
::
DefaultBlock2ETileMap
;
using
GridwiseWelfordLayernorm
=
GridwiseWelfordSecondHalfLayernorm2d
<
EDataType
,
GridwiseWelfordSecondHalfLayernorm2d
<
E
MeanVar
DataType
,
HDataType
,
MeanDataType
,
VarDataType
,
GammaDataType
,
BetaDataType
,
AccDataType
,
...
...
@@ -488,7 +479,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
:
p_a_grid_
{
static_cast
<
const
ADataType
*>
(
p_a_grid
)},
p_b_grid_
{
static_cast
<
const
BDataType
*>
(
p_b_grid
)},
p_ds_grid_
{},
p_
e_grid_
{
static_cast
<
EDataType
*>
(
p_h_grid
)
},
p_
workspace_e_grid_
{
nullptr
},
p_workspace_mean_
{
nullptr
},
p_workspace_var_
{
nullptr
},
p_workspace_count_
{
nullptr
},
...
...
@@ -611,7 +602,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
const
ADataType
*
p_a_grid_
;
const
BDataType
*
p_b_grid_
;
typename
GridwiseGemmWelford
::
DsGridPointer
p_ds_grid_
;
EDataType
*
p
_e_grid_
;
void
*
p_workspace
_e_grid_
;
void
*
p_workspace_mean_
;
void
*
p_workspace_var_
;
void
*
p_workspace_count_
;
...
...
@@ -694,9 +685,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
GridwiseGemmWelford
,
ADataType
,
// TODO: distinguish A/B datatype
typename
GridwiseGemmWelford
::
DsGridPointer
,
EDataType
,
MeanDataType
,
VarDataType
,
EMeanVarDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
,
...
...
@@ -713,10 +702,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
const
auto
kernel_welford_layernorm
=
kernel_welford_layernorm2d_second_half
<
GridwiseWelfordLayernorm
,
EDataType
,
E
MeanVar
DataType
,
HDataType
,
MeanDataType
,
VarDataType
,
GammaDataType
,
BetaDataType
,
AccDataType
,
...
...
@@ -735,9 +722,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_ds_grid_
,
arg
.
p
_e_grid_
,
static_cast
<
MeanDataType
*>
(
arg
.
p_workspace_mean_
),
static_cast
<
VarDataType
*>
(
arg
.
p_workspace_var_
),
static_cast
<
EMeanVarDataType
*>
(
arg
.
p_workspace
_e_grid_
)
,
static_cast
<
E
Mean
Var
DataType
*>
(
arg
.
p_workspace_mean_
),
static_cast
<
EMean
VarDataType
*>
(
arg
.
p_workspace_var_
),
static_cast
<
int32_t
*>
(
arg
.
p_workspace_count_
),
arg
.
a_element_op_
,
arg
.
b_element_op_
,
...
...
@@ -760,29 +747,29 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
index_t
numMeanVarCountBlockTileIteration_N
=
math
::
integer_divide_ceil
(
arg
.
gemm_nblock_
,
LayernormThreadClusterSize_M_N
::
At
(
I1
));
avg_time
+=
launch_and_time_kernel
(
stream_config
,
kernel_welford_layernorm
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p
_e_grid_
,
static_cast
<
const
MeanDataType
*>
(
arg
.
p_workspace_mean_
),
static_cast
<
const
VarDataType
*>
(
arg
.
p_workspace_var_
),
static_cast
<
const
int32_t
*>
(
arg
.
p_workspace_count_
),
arg
.
p_gamma_grid_
,
arg
.
p_beta_grid_
,
arg
.
p_h_grid_
,
arg
.
layernorm_e_grid_desc_m_n_
,
arg
.
h_grid_desc_m_n_
,
arg
.
layernorm_mean_var_grid_desc_m_nblock_
,
arg
.
layernorm_count_grid_desc_m_nblock_
,
arg
.
gamma_grid_desc_n_
,
arg
.
beta_grid_desc_n_
,
numMeanVarCountBlockTileIteration_N
,
NBlockClusterLength
,
arg
.
epsilon_
,
arg
.
h_element_op_
);
avg_time
+=
launch_and_time_kernel
(
stream_config
,
kernel_welford_layernorm
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
static_cast
<
EMeanVarDataType
*>
(
arg
.
p_workspace
_e_grid_
)
,
static_cast
<
const
E
Mean
Var
DataType
*>
(
arg
.
p_workspace_mean_
),
static_cast
<
const
EMean
VarDataType
*>
(
arg
.
p_workspace_var_
),
static_cast
<
const
int32_t
*>
(
arg
.
p_workspace_count_
),
arg
.
p_gamma_grid_
,
arg
.
p_beta_grid_
,
arg
.
p_h_grid_
,
arg
.
layernorm_e_grid_desc_m_n_
,
arg
.
h_grid_desc_m_n_
,
arg
.
layernorm_mean_var_grid_desc_m_nblock_
,
arg
.
layernorm_count_grid_desc_m_nblock_
,
arg
.
gamma_grid_desc_n_
,
arg
.
beta_grid_desc_n_
,
numMeanVarCountBlockTileIteration_N
,
NBlockClusterLength
,
arg
.
epsilon_
,
arg
.
h_element_op_
);
return
avg_time
;
};
...
...
@@ -814,14 +801,17 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
int
gemm_welford_size
=
pArg_
->
MRaw_
*
pArg_
->
gemm_nblock_
;
// workspace for welford intermediate mean
workspace_size
+=
gemm_welford_size
*
sizeof
(
MeanDataType
)
+
64
;
workspace_size
+=
gemm_welford_size
*
sizeof
(
E
Mean
Var
DataType
)
+
64
;
// workspace for welford intermediate variance
workspace_size
+=
gemm_welford_size
*
sizeof
(
VarDataType
)
+
64
;
workspace_size
+=
gemm_welford_size
*
sizeof
(
EMean
VarDataType
)
+
64
;
// workspace for welford intermediate count
workspace_size
+=
pArg_
->
gemm_nblock_
*
sizeof
(
int32_t
)
+
64
;
if
constexpr
(
!
is_same_v
<
EMeanVarDataType
,
HDataType
>
)
workspace_size
+=
pArg_
->
MRaw_
*
pArg_
->
NRaw_
*
sizeof
(
EMeanVarDataType
);
return
(
workspace_size
);
};
...
...
@@ -836,20 +826,27 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
// setup buffer used for intermediate welford mean
pArg_
->
p_workspace_mean_
=
static_cast
<
char
*>
(
pArg_
->
p_workspace_
);
index_t
mean_space_sz
=
gemm_welford_size
*
sizeof
(
MeanDataType
);
mean_space_sz
=
math
::
integer_least_multiple
(
mean_space_sz
,
64
);
index_t
mean_space_sz
=
gemm_welford_size
*
sizeof
(
EMeanVarDataType
);
mean_space_sz
=
math
::
integer_least_multiple
(
mean_space_sz
,
64
);
// setup buffer used for intermediate welford variance
pArg_
->
p_workspace_var_
=
reinterpret_cast
<
char
*>
(
pArg_
->
p_workspace_mean_
)
+
mean_space_sz
;
index_t
variance_space_sz
=
gemm_welford_size
*
sizeof
(
VarDataType
);
variance_space_sz
=
math
::
integer_least_multiple
(
variance_space_sz
,
64
);
index_t
variance_space_sz
=
gemm_welford_size
*
sizeof
(
EMeanVarDataType
);
variance_space_sz
=
math
::
integer_least_multiple
(
variance_space_sz
,
64
);
// setup buffer used for intermediate welford count
pArg_
->
p_workspace_count_
=
reinterpret_cast
<
char
*>
(
pArg_
->
p_workspace_var_
)
+
variance_space_sz
;
index_t
count_space_sz
=
gemm_welford_size
*
sizeof
(
int32_t
);
count_space_sz
=
math
::
integer_least_multiple
(
count_space_sz
,
64
);
if
constexpr
(
!
is_same_v
<
EMeanVarDataType
,
HDataType
>
)
pArg_
->
p_workspace_e_grid_
=
reinterpret_cast
<
char
*>
(
pArg_
->
p_workspace_count_
)
+
count_space_sz
;
else
pArg_
->
p_workspace_e_grid_
=
static_cast
<
void
*>
(
pArg_
->
p_h_grid_
);
};
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
...
...
include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp
View file @
f278b2a5
...
...
@@ -36,9 +36,7 @@ template <typename ABDataType,
typename
AccDataType
,
typename
CShuffleDataType
,
typename
DsDataType
,
typename
EDataType
,
typename
MeanDataType
,
typename
VarDataType
,
typename
EMeanVarDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
,
...
...
@@ -329,7 +327,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
if
(
!
(
a_grid_desc_m_k
.
GetElementSpaceSize
()
*
sizeof
(
ABDataType
)
<=
TwoGB
&&
b_grid_desc_n_k
.
GetElementSpaceSize
()
*
sizeof
(
ABDataType
)
<=
TwoGB
&&
e_grid_desc_m_n
.
GetElementSpaceSize
()
*
sizeof
(
EDataType
)
<=
TwoGB
))
e_grid_desc_m_n
.
GetElementSpaceSize
()
*
sizeof
(
E
MeanVar
DataType
)
<=
TwoGB
))
{
return
false
;
}
...
...
@@ -370,9 +368,9 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
Run
(
const
ABDataType
*
__restrict__
p_a_grid
,
const
ABDataType
*
__restrict__
p_b_grid
,
DsGridPointer
p_ds_grid
,
EDataType
*
__restrict__
p_e_grid
,
MeanDataType
*
__restrict__
p_welford_mean_grid
,
VarDataType
*
__restrict__
p_welford_var_grid
,
E
MeanVar
DataType
*
__restrict__
p_e_grid
,
E
Mean
Var
DataType
*
__restrict__
p_welford_mean_grid
,
EMean
VarDataType
*
__restrict__
p_welford_var_grid
,
int32_t
*
__restrict__
p_welford_count
,
void
*
__restrict__
p_shared
,
const
AElementwiseOperation
&
a_element_op
,
...
...
@@ -825,7 +823,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
auto
e_thread_copy_vgpr_to_global
=
ThreadwiseTensorSliceTransfer_v1r3
<
AccDataType
,
EDataType
,
E
MeanVar
DataType
,
decltype
(
post_shuffle_thread_desc_I1_mperblock_I1_nperblock
),
decltype
(
e_grid_desc_mblock_mperblock_nblock_nperblock
),
tensor_operation
::
element_wise
::
PassThrough
,
...
...
@@ -1042,7 +1040,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
auto
mean_thread_copy_vgpr_to_global
=
ThreadwiseTensorSliceTransfer_v1r3
<
AccDataType
,
MeanDataType
,
E
Mean
Var
DataType
,
decltype
(
thread_welford_desc_I_m_I
),
decltype
(
mean_var_grid_desc_mblock_mperblock_nblock
),
tensor_operation
::
element_wise
::
PassThrough
,
...
...
@@ -1062,7 +1060,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
auto
var_thread_copy_vgpr_to_global
=
ThreadwiseTensorSliceTransfer_v1r3
<
AccDataType
,
VarDataType
,
EMean
VarDataType
,
decltype
(
thread_welford_desc_I_m_I
),
decltype
(
mean_var_grid_desc_mblock_mperblock_nblock
),
tensor_operation
::
element_wise
::
PassThrough
,
...
...
include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp
View file @
f278b2a5
...
...
@@ -19,10 +19,8 @@
namespace
ck
{
template
<
typename
EDataType
,
template
<
typename
E
MeanVar
DataType
,
typename
HDataType
,
typename
MeanDataType
,
typename
VarDataType
,
typename
GammaDataType
,
typename
BetaDataType
,
typename
ComputeDataType
,
...
...
@@ -87,9 +85,9 @@ struct GridwiseWelfordSecondHalfLayernorm2d
static
constexpr
index_t
M_BlockTileSize
=
MThreadClusterSize
*
MThreadSliceSize
;
static
constexpr
index_t
N_BlockTileSize
=
NThreadClusterSize
*
NThreadSliceSize
;
__device__
static
void
Run
(
const
EDataType
*
__restrict__
p_e_grid
,
const
MeanDataType
*
__restrict__
p_in_welford_mean_grid
,
const
VarDataType
*
__restrict__
p_in_welford_var_grid
,
__device__
static
void
Run
(
const
E
MeanVar
DataType
*
__restrict__
p_e_grid
,
const
E
Mean
Var
DataType
*
__restrict__
p_in_welford_mean_grid
,
const
EMean
VarDataType
*
__restrict__
p_in_welford_var_grid
,
const
int32_t
*
__restrict__
p_in_welford_count_grid
,
const
GammaDataType
*
__restrict__
p_gamma_grid
,
const
BetaDataType
*
__restrict__
p_beta_grid
,
...
...
@@ -176,7 +174,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d
// IO
auto
threadwise_mean_load_m_nblock
=
ThreadwiseTensorSliceTransfer_v2
<
MeanDataType
,
ThreadwiseTensorSliceTransfer_v2
<
E
Mean
Var
DataType
,
ComputeDataType
,
MeanVarGridDesc_M_NBlock
,
decltype
(
thread_buffer_desc_m_1
),
...
...
@@ -192,7 +190,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d
thread_n_cluster_id
));
auto
threadwise_var_load_m_nblock
=
ThreadwiseTensorSliceTransfer_v2
<
VarDataType
,
ThreadwiseTensorSliceTransfer_v2
<
EMean
VarDataType
,
ComputeDataType
,
MeanVarGridDesc_M_NBlock
,
decltype
(
thread_buffer_desc_m_1
),
...
...
@@ -224,7 +222,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d
thread_n_cluster_id
));
auto
threadwise_e_load_m_n
=
ThreadwiseTensorSliceTransfer_v2
<
EDataType
,
ThreadwiseTensorSliceTransfer_v2
<
E
MeanVar
DataType
,
ComputeDataType
,
decltype
(
e_grid_desc_m_n
),
decltype
(
thread_buffer_desc_m_n
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment