gaoqiong / composable_kernel

Commit 78ff5f81 authored Dec 02, 2022 by rocking

Implement layernorm

parent a4e34d88

Showing 3 changed files with 273 additions and 144 deletions
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp (+38 -39)
include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp (+2 -2)
include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp (+233 -103)
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
@@ -111,8 +111,9 @@ template <typename GridwiseWelfordLayernorm,
           typename BetaDataType,
           typename ComputeDataType,
           typename EHGridDesc_M_N,
-          typename MeanVarCountGridDesc_M_N,
-          typename GammaBetaGridDesc_N>
+          typename MeanVarCountGridDesc_M_NBlock,
+          typename GammaBetaGridDesc_N,
+          typename HElementwiseOperation>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
@@ -127,13 +128,13 @@ __global__ void
                                                HDataType* __restrict__ p_h_grid,
                                                const EHGridDesc_M_N e_grid_desc_m_n,
                                                const EHGridDesc_M_N h_grid_desc_m_n,
-                                               const MeanVarCountGridDesc_M_N mean_var_count_grid_desc_m_n,
+                                               const MeanVarCountGridDesc_M_NBlock mean_var_count_grid_desc_m_nblock,
                                                const GammaBetaGridDesc_N gamma_grid_desc_n,
                                                const GammaBetaGridDesc_N beta_grid_desc_n,
                                                index_t blkgroup_size,
-                                               index_t num_mean_var_count_k_block_tile_iteration,
-                                               index_t num_xy_k_block_tile_iteration,
-                                               ComputeDataType epsilon)
+                                               index_t numMeanVarCountBlockTileIteration_N,
+                                               index_t numNormBlockTileIteration_N,
+                                               ComputeDataType epsilon,
+                                               HElementwiseOperation h_element_op)
 {
     GridwiseWelfordLayernorm::Run(p_e_grid,
                                   p_in_welford_mean_grid,
@@ -144,13 +145,13 @@ __global__ void
                                   p_h_grid,
                                   e_grid_desc_m_n,
                                   h_grid_desc_m_n,
-                                  mean_var_count_grid_desc_m_n,
+                                  mean_var_count_grid_desc_m_nblock,
                                   gamma_grid_desc_n,
                                   beta_grid_desc_n,
                                   blkgroup_size,
-                                  num_mean_var_count_k_block_tile_iteration,
-                                  num_xy_k_block_tile_iteration,
-                                  epsilon);
+                                  numMeanVarCountBlockTileIteration_N,
+                                  numNormBlockTileIteration_N,
+                                  epsilon,
+                                  h_element_op);
 }
 } // namespace ck
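The kernel wrapper above is the second half of the fused operation: it merges the per-N-block Welford partials and applies the row-wise normalization spelled out in a comment later in this commit, h[m, n] = [(e[m, n] - mean[m]) / sqrt(var[m] + eps)] * gamma[n] + beta[n]. As a reading aid, here is a minimal host-side sketch of that math in plain C++ (no CK types; the function name, row-major layout, and biased row variance are assumptions for illustration, not the device implementation):

#include <cmath>
#include <vector>

// Hypothetical reference: e is the M x N GEMM output, h the layernorm result.
void reference_layernorm2d(const std::vector<float>& e,
                           const std::vector<float>& gamma,
                           const std::vector<float>& beta,
                           std::vector<float>& h,
                           int M, int N, float epsilon)
{
    for(int m = 0; m < M; ++m)
    {
        // First half: mean/variance of row m (the kernels build this with Welford).
        float mean = 0.f, var = 0.f;
        for(int n = 0; n < N; ++n)
            mean += e[m * N + n];
        mean /= N;
        for(int n = 0; n < N; ++n)
        {
            float d = e[m * N + n] - mean;
            var += d * d;
        }
        var /= N;
        // Second half: normalize, then per-column scale and shift.
        float inv_std = 1.f / std::sqrt(var + epsilon);
        for(int n = 0; n < N; ++n)
            h[m * N + n] = (e[m * N + n] - mean) * inv_std * gamma[n] + beta[n];
    }
}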
@@ -371,12 +372,12 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
         }
     };

     using AGridDesc_M_K  = decltype(MakeAGridDescriptor_M_K(1, 1, 1));
     using BGridDesc_N_K  = decltype(MakeBGridDescriptor_N_K(1, 1, 1));
     using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
-    using MeanVarCountGridDesc_M_N = decltype(MakeMeanVarCountGridDescriptor_M_NBlock(1, 1));
+    using MeanVarCountGridDesc_M_NBlock = decltype(MakeMeanVarCountGridDescriptor_M_NBlock(1, 1));
     using GammaBetaGridDesc_N = decltype(MakeDescriptor_N(1));
     using EHGridDesc_M_N      = decltype(MakeEGridDescriptor_M_N<HLayout>(1, 1, 1));

     using GridwiseGemmWelford = GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle<
         ADataType, // TODO: distinguish A/B datatype
@@ -394,7 +395,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
         BGridDesc_N_K,
         DsGridDesc_M_N,
         EHGridDesc_M_N,
-        MeanVarCountGridDesc_M_N,
+        MeanVarCountGridDesc_M_NBlock,
         NumGemmKPrefetchStage,
         BlockSize,
         MPerBlock,
@@ -439,8 +440,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
                                              BetaDataType,
                                              AccDataType,
                                              EHGridDesc_M_N,
-                                             MeanVarCountGridDesc_M_N,
+                                             MeanVarCountGridDesc_M_NBlock,
                                              GammaBetaGridDesc_N,
+                                             HElementwiseOperation,
                                              BlockSize,
                                              LayernormThreadClusterSize_M_N::At(I0),
                                              LayernormThreadClusterSize_M_N::At(I1),
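The new HElementwiseOperation parameter threads a user-supplied functor through to the final store (the gridwise kernel passes h_element_op into its ThreadwiseTensorSliceTransfer_v1r3, as shown further down). A trivially valid operation is a pass-through, sketched here in the usual CK functor shape, modeled on the PassThrough alias this commit removes from the gridwise struct (a sketch, not the exact library definition):

// Minimal pass-through elementwise op in the usual CK functor shape (sketch).
struct PassThrough
{
    template <typename Y, typename X>
    __host__ __device__ constexpr void operator()(Y& y, const X& x) const
    {
        y = x; // leave the layernorm result unchanged on store
    }
};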
@@ -488,7 +490,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
               b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(KRaw, NRaw, StrideB)},
               ds_grid_desc_m_n_{},
               e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(MRaw, NRaw, StrideH)},
-              mean_var_count_grid_desc_m_n_{},
+              mean_var_count_grid_desc_m_nblock_{},
               gamma_grid_desc_n_{DeviceOp::MakeDescriptor_N(NRaw)},
               beta_grid_desc_n_{DeviceOp::MakeDescriptor_N(NRaw)},
               h_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<HLayout>(MRaw, NRaw, StrideH)},
@@ -504,7 +506,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
               gemm_nblock_{math::integer_divide_ceil(NRaw, NPerBlock)},
               epsilon_{epsilon}
         {
-            mean_var_count_grid_desc_m_n_ =
+            mean_var_count_grid_desc_m_nblock_ =
                 DeviceOp::MakeMeanVarCountGridDescriptor_M_NBlock(MRaw, gemm_nblock_);

             hip_check_error(hipMalloc(&p_e_grid_, sizeof(EDataType) * MRaw * NRaw));
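For orientation: gemm_nblock_ is the number of GEMM tiles along N, and the descriptor built just above appears to describe one partial (mean, var, count) triple per (row, N-block) for the second-half kernel to merge. A small numeric sketch, with values assumed purely for illustration:

#include <cstdio>

int main()
{
    // Assumed example sizes, mirroring the init list above.
    const int MRaw = 256, NRaw = 1000, NPerBlock = 128;
    const int gemm_nblock = (NRaw + NPerBlock - 1) / NPerBlock; // math::integer_divide_ceil
    // MakeMeanVarCountGridDescriptor_M_NBlock(MRaw, gemm_nblock) then describes a
    // 256 x 8 grid of partial (mean, var, count) results.
    std::printf("welford workspace: %d x %d partials\n", MRaw, gemm_nblock);
    return 0;
}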
@@ -546,7 +548,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
             mean_var_count_grid_desc_mblock_mperblock_nblock_ =
                 GridwiseGemmWelford::MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(
-                    mean_var_count_grid_desc_m_n_);
+                    mean_var_count_grid_desc_m_nblock_);
             }
         }
@@ -578,7 +580,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
         BGridDesc_N_K b_grid_desc_n_k_;
         DsGridDesc_M_N ds_grid_desc_m_n_;
         EHGridDesc_M_N e_grid_desc_m_n_;
-        MeanVarCountGridDesc_M_N mean_var_count_grid_desc_m_n_;
+        MeanVarCountGridDesc_M_NBlock mean_var_count_grid_desc_m_nblock_;
         GammaBetaGridDesc_N gamma_grid_desc_n_;
         GammaBetaGridDesc_N beta_grid_desc_n_;
         EHGridDesc_M_N h_grid_desc_m_n_;
@@ -666,8 +668,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
                                                        BetaDataType,
                                                        AccDataType,
                                                        EHGridDesc_M_N,
-                                                       MeanVarCountGridDesc_M_N,
-                                                       GammaBetaGridDesc_N>;
+                                                       MeanVarCountGridDesc_M_NBlock,
+                                                       GammaBetaGridDesc_N,
+                                                       HElementwiseOperation>;

             avg_time += launch_and_time_kernel(stream_config,
@@ -692,17 +695,13 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
                 arg.mean_var_count_grid_desc_mblock_mperblock_nblock_,
                 arg.block_2_etile_map_);

-            grid_size = math::integer_least_multiple(M, LayernormBlockTileSize_M_N::At(0)) /
-                        LayernormBlockTileSize_M_N::At(0);
+            grid_size = math::integer_divide_ceil(M, LayernormBlockTileSize_M_N::At(0));

             index_t numMeanVarCountBlockTileIteration_N =
-                math::integer_least_multiple(arg.gemm_nblock_,
-                                             LayernormThreadClusterSize_M_N::At(I1)) /
-                LayernormThreadClusterSize_M_N::At(I1);
+                math::integer_divide_ceil(arg.gemm_nblock_, LayernormThreadClusterSize_M_N::At(I1));

-            index_t numEBlockTileIteration_N =
-                math::integer_least_multiple(N, LayernormBlockTileSize_M_N::At(I1)) /
-                LayernormBlockTileSize_M_N::At(I1);
+            index_t numNormBlockTileIteration_N =
+                math::integer_divide_ceil(N, LayernormBlockTileSize_M_N::At(I1));

             avg_time += launch_and_time_kernel(stream_config,
                                                kernel_welford_layernorm,
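The replaced integer_least_multiple(x, tile) / tile and the new integer_divide_ceil(x, tile) compute the same tile count; the new form just drops the intermediate round-up multiply. A compile-time check of that identity, with literal values assumed for illustration:

// ceil(1000 / 128) computed both ways: via the least multiple (old) and directly (new).
constexpr int x             = 1000;
constexpr int tile          = 128;
constexpr int least_multiple = (x + tile - 1) / tile * tile; // integer_least_multiple
constexpr int old_way        = least_multiple / tile;
constexpr int new_way        = (x + tile - 1) / tile;        // integer_divide_ceil
static_assert(old_way == new_way && new_way == 8, "same grid size / iteration count");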
@@ -718,13 +717,13 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
                                                arg.p_h_grid_,
                                                arg.e_grid_desc_m_n_,
                                                arg.h_grid_desc_m_n_,
-                                               arg.mean_var_count_grid_desc_m_n_,
+                                               arg.mean_var_count_grid_desc_m_nblock_,
                                                arg.gamma_grid_desc_n_,
                                                arg.beta_grid_desc_n_,
                                                arg.gemm_nblock_,
                                                numMeanVarCountBlockTileIteration_N,
-                                               numEBlockTileIteration_N,
-                                               arg.epsilon_);
+                                               numNormBlockTileIteration_N,
+                                               arg.epsilon_,
+                                               arg.h_element_op_);

             return avg_time;
         };
include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp
@@ -47,7 +47,7 @@ template <typename ABDataType,
           typename BGridDesc_N_K,
           typename DsGridDesc_M_N,
           typename EGridDesc_M_N,
-          typename MeanVarCountGridDesc_M_N,
+          typename MeanVarCountGridDesc_M_NBlock,
           index_t NumGemmKPrefetchStage,
           index_t BlockSize,
           index_t MPerBlock,
@@ -349,7 +349,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
     using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
         MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;
     using MeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock = remove_cvref_t<decltype(
-        MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(MeanVarCountGridDesc_M_N{}))>;
+        MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(MeanVarCountGridDesc_M_NBlock{}))>;
     using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
         MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp
@@ -27,8 +27,9 @@ template <typename EDataType,
           typename BetaDataType,
           typename ComputeDataType,
           typename EHGridDesc_M_N,
-          typename MeanVarCountGridDesc_M_N,
+          typename MeanVarCountGridDesc_M_NBlock,
           typename GammaBetaGridDesc_N,
+          typename HElementwiseOperation,
           index_t BlockSize,
           index_t MThreadClusterSize,
           index_t NThreadClusterSize,
@@ -42,32 +43,34 @@ template <typename EDataType,
           index_t MeanVarSrcDstVectorSize>
 struct GridwiseWelfordSecondHalfLayernorm2d
 {
-    static_assert((ESrcHDstVectorDim == 0 && MThreadSliceSize % ESrcVectorSize == 0) ||
-                      (ESrcHDstVectorDim == 1 && NThreadSliceSize % ESrcVectorSize == 0),
+    // TODO - Support ESrcHDstVectorDim == 0
+    static_assert(ESrcHDstVectorDim == 1 && NThreadSliceSize % ESrcVectorSize == 0 &&
+                      NThreadSliceSize % GammaSrcVectorSize == 0 &&
+                      NThreadSliceSize % BetaSrcVectorSize == 0,
                   "Invalid thread slice sizes and/or vector sizes configuration, please check!");

-    static_assert((ESrcHDstVectorDim == 0 && MThreadSliceSize % HDstVectorSize == 0) ||
-                      (ESrcHDstVectorDim == 1 && NThreadSliceSize % HDstVectorSize == 0),
+    static_assert(ESrcHDstVectorDim == 1 && NThreadSliceSize % HDstVectorSize == 0,
                   "Invalid thread slice sizes and/or vector sizes configuration, please check!");

-    static constexpr bool reorder_thread_cluster = (ESrcHDstVectorDim == 0);
-
-    using ThreadClusterLengths_M_N = Sequence<MThreadClusterSize, NThreadClusterSize>;
-
-    using ThreadBufferDimAccessOrder =
-        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
-    using ThreadClusterArrangeOrder =
-        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
+    using ThreadClusterLengths_M_N   = Sequence<MThreadClusterSize, NThreadClusterSize>;
+    using ThreadBufferDimAccessOrder = Sequence<0, 1>;
+    using ThreadClusterArrangeOrder  = Sequence<0, 1>;

-    static constexpr auto thread_cluster_desc =
+    static constexpr auto thread_cluster_desc_m_n =
         make_cluster_descriptor(ThreadClusterLengths_M_N{}, ThreadClusterArrangeOrder{});

+    using ThreadBufferLengths_M_N = Sequence<MThreadSliceSize, NThreadSliceSize>;
+    static constexpr auto thread_buffer_desc_m_n = make_naive_tensor_descriptor_packed(
+        make_tuple(Number<MThreadSliceSize>{}, Number<NThreadSliceSize>{}));
+
     using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;
     static constexpr auto thread_buffer_desc_m_1 =
         make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));

+    using ThreadBufferLengths_N = Sequence<NThreadSliceSize>;
+    static constexpr auto thread_buffer_desc_n =
+        make_naive_tensor_descriptor_packed(make_tuple(Number<NThreadSliceSize>{}));
+
     using ThreadReduceSrcDesc_M_1 = decltype(thread_buffer_desc_m_1);
     using ThreadReduceDstDesc_M =
         decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
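The tightened static_assert pins ESrcHDstVectorDim to 1 (vectorize along N) and now also requires the gamma and beta vector widths to tile the per-thread N slice, since the rewritten kernel loads gamma and beta with their own vectorized transfers. An accepted configuration, with example values assumed for illustration:

// Per-thread N slice must be a whole number of vector loads for E, gamma and beta.
constexpr int NThreadSliceSize   = 8;
constexpr int ESrcVectorSize     = 4; // 8 % 4 == 0 -> two vector loads per slice
constexpr int GammaSrcVectorSize = 2;
constexpr int BetaSrcVectorSize  = 2;
static_assert(NThreadSliceSize % ESrcVectorSize == 0 &&
              NThreadSliceSize % GammaSrcVectorSize == 0 &&
              NThreadSliceSize % BetaSrcVectorSize == 0,
              "slice must tile into vectors");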
@@ -80,19 +83,11 @@ struct GridwiseWelfordSecondHalfLayernorm2d
                                                ThreadClusterLengths_M_N,
                                                ThreadClusterArrangeOrder>;

-    using PassThroughOp = tensor_operation::element_wise::PassThrough;
-
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};

     static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
     static constexpr index_t N_BlockTileSize = NThreadClusterSize * NThreadSliceSize;
-    static constexpr index_t N_BlockTileStepSize = NThreadClusterSize * ESrcVectorSize;
-
-    static constexpr auto EThreadBufferNumber     = Number<NThreadSliceSize / ESrcVectorSize>{};
-    static constexpr auto GammaThreadBufferNumber = Number<NThreadSliceSize / ESrcVectorSize>{};
-    static constexpr auto BetaThreadBufferNumber  = Number<NThreadSliceSize / ESrcVectorSize>{};
-    static constexpr auto HThreadBufferNumber     = Number<NThreadSliceSize / ESrcVectorSize>{};

     __device__ static void Run(const EDataType* __restrict__ p_e_grid,
                                const MeanDataType* __restrict__ p_in_welford_mean_grid,
@@ -103,47 +98,88 @@ struct GridwiseWelfordSecondHalfLayernorm2d
                                HDataType* __restrict__ p_h_grid,
                                const EHGridDesc_M_N& e_grid_desc_m_n,
                                const EHGridDesc_M_N& h_grid_desc_m_n,
-                               const MeanVarCountGridDesc_M_N& mean_var_count_grid_desc_m_n,
+                               const MeanVarCountGridDesc_M_NBlock& mean_var_count_grid_desc_m_n,
                                const GammaBetaGridDesc_N& gamma_grid_desc_n,
                                const GammaBetaGridDesc_N& beta_grid_desc_n,
                                index_t gemm_nblock_,
                                index_t numMeanVarCountBlockTileIteration_N,
-                               index_t numEBlockTileIteration_N,
-                               ComputeDataType epsilon)
+                               index_t numNormBlockTileIteration_N,
+                               ComputeDataType epsilon,
+                               HElementwiseOperation h_element_op)
     {
-        ignore = p_e_grid;
-        ignore = p_in_welford_mean_grid;
-        ignore = p_in_welford_var_grid;
-        ignore = p_in_welford_count_grid;
-        ignore = p_gamma_grid;
-        ignore = p_beta_grid;
-        ignore = p_h_grid;
-        ignore = e_grid_desc_m_n;
-        ignore = h_grid_desc_m_n;
-        ignore = mean_var_count_grid_desc_m_n;
-        ignore = gamma_grid_desc_n;
-        ignore = beta_grid_desc_n;
-        ignore = gemm_nblock_;
-        ignore = numMeanVarCountBlockTileIteration_N;
-        ignore = numEBlockTileIteration_N;
-        ignore = epsilon;
-
         // Thread/Block id
         const index_t thread_local_id = get_thread_local_1d_id();
         const index_t block_global_id = get_block_1d_id();

         const auto thread_cluster_idx =
-            thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
+            thread_cluster_desc_m_n.CalculateBottomIndex(make_multi_index(thread_local_id));
         const auto thread_m_cluster_id = thread_cluster_idx[I0];
         const auto thread_n_cluster_id = thread_cluster_idx[I1];

-        // step1: Merge mean and variance
-        auto threadwise_mean_load_m_k =
+        // Global Memory
+        const auto e_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_e_grid, e_grid_desc_m_n.GetElementSpaceSize());
+        const auto welford_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_in_welford_mean_grid, mean_var_count_grid_desc_m_n.GetElementSpaceSize());
+        const auto welford_var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_in_welford_var_grid, mean_var_count_grid_desc_m_n.GetElementSpaceSize());
+        const auto welford_count_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_in_welford_count_grid, mean_var_count_grid_desc_m_n.GetElementSpaceSize());
+        const auto gamma_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_gamma_grid, gamma_grid_desc_n.GetElementSpaceSize());
+        const auto beta_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_beta_grid, beta_grid_desc_n.GetElementSpaceSize());
+        auto h_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_h_grid, h_grid_desc_m_n.GetElementSpaceSize());
+
+        // VGPR
+        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
+            in_welford_mean_thread_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
+            in_welford_var_thread_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
+            in_welford_count_thread_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
+            welford_mean_thread_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
+            welford_var_thread_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
+            welford_count_thread_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize * ESrcVectorSize, true>
+            e_thread_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize * GammaSrcVectorSize, true>
+            gamma_thread_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize * BetaSrcVectorSize, true>
+            beta_thread_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize * HDstVectorSize, true>
+            h_thread_buf;
+
+        // IO
+        auto threadwise_mean_load_m_nblock =
             ThreadwiseTensorSliceTransfer_v2<MeanDataType,
                                              ComputeDataType,
-                                             MeanVarCountGridDesc_M_N,
+                                             MeanVarCountGridDesc_M_NBlock,
                                              decltype(thread_buffer_desc_m_1),
                                              ThreadBufferLengths_M_1,
-                                             Sequence<0, 1>,
+                                             ThreadBufferDimAccessOrder,
                                              1,
                                              1,
                                              1,
@@ -153,13 +189,13 @@ struct GridwiseWelfordSecondHalfLayernorm2d
                 thread_m_cluster_id * MThreadSliceSize,
                 thread_n_cluster_id));

-        auto threadwise_var_load_m_k =
+        auto threadwise_var_load_m_nblock =
             ThreadwiseTensorSliceTransfer_v2<VarDataType,
                                              ComputeDataType,
-                                             MeanVarCountGridDesc_M_N,
+                                             MeanVarCountGridDesc_M_NBlock,
                                              decltype(thread_buffer_desc_m_1),
                                              ThreadBufferLengths_M_1,
-                                             Sequence<0, 1>,
+                                             ThreadBufferDimAccessOrder,
                                              1,
                                              1,
                                              1,
@@ -169,13 +205,13 @@ struct GridwiseWelfordSecondHalfLayernorm2d
                 thread_m_cluster_id * MThreadSliceSize,
                 thread_n_cluster_id));

-        auto threadwise_count_load_m_k =
+        auto threadwise_count_load_m_nblock =
             ThreadwiseTensorSliceTransfer_v2<int32_t,
                                              int32_t,
-                                             MeanVarCountGridDesc_M_N,
+                                             MeanVarCountGridDesc_M_NBlock,
                                              decltype(thread_buffer_desc_m_1),
                                              ThreadBufferLengths_M_1,
-                                             Sequence<0, 1>,
+                                             ThreadBufferDimAccessOrder,
                                              1,
                                              1,
                                              1,
@@ -185,29 +221,68 @@ struct GridwiseWelfordSecondHalfLayernorm2d
                 thread_m_cluster_id * MThreadSliceSize,
                 thread_n_cluster_id));

-        const auto welford_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_in_welford_mean_grid, mean_var_count_grid_desc_m_n.GetElementSpaceSize());
-
-        const auto welford_var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_in_welford_var_grid, mean_var_count_grid_desc_m_n.GetElementSpaceSize());
-
-        const auto welford_count_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_in_welford_count_grid, mean_var_count_grid_desc_m_n.GetElementSpaceSize());
-
-        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
-            in_welford_mean_thread_buf;
-        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
-            in_welford_var_thread_buf;
-        StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
-            in_welford_count_thread_buf;
-
-        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
-            welford_mean_thread_buf;
-        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
-            welford_var_thread_buf;
-        StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
-            welford_count_thread_buf;
+        auto threadwise_e_load_m_n =
+            ThreadwiseTensorSliceTransfer_v2<EDataType,
+                                             ComputeDataType,
+                                             decltype(e_grid_desc_m_n),
+                                             decltype(thread_buffer_desc_m_n),
+                                             ThreadBufferLengths_M_N,
+                                             ThreadBufferDimAccessOrder,
+                                             ESrcHDstVectorDim,
+                                             ESrcVectorSize,
+                                             1,
+                                             true>(
+                e_grid_desc_m_n,
+                make_multi_index(block_global_id * M_BlockTileSize +
+                                     thread_m_cluster_id * MThreadSliceSize,
+                                 thread_n_cluster_id * NThreadSliceSize));
+
+        auto threadwise_gamma_load_m_n =
+            ThreadwiseTensorSliceTransfer_v2<GammaDataType,
+                                             ComputeDataType,
+                                             decltype(gamma_grid_desc_n),
+                                             decltype(thread_buffer_desc_n),
+                                             ThreadBufferLengths_N,
+                                             Sequence<0>, // DimAccessOrder,
+                                             0,           // SrcVectorDim,
+                                             GammaSrcVectorSize,
+                                             1,
+                                             true>(
+                gamma_grid_desc_n, make_multi_index(thread_n_cluster_id * NThreadSliceSize));
+
+        auto threadwise_beta_load_m_n =
+            ThreadwiseTensorSliceTransfer_v2<BetaDataType,
+                                             ComputeDataType,
+                                             decltype(beta_grid_desc_n),
+                                             decltype(thread_buffer_desc_n),
+                                             ThreadBufferLengths_N,
+                                             Sequence<0>, // DimAccessOrder,
+                                             0,           // SrcVectorDim,
+                                             BetaSrcVectorSize,
+                                             1,
+                                             true>(
+                beta_grid_desc_n, make_multi_index(thread_n_cluster_id * NThreadSliceSize));
+
+        auto threadwise_h_store_m_n =
+            ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
+                                               HDataType,
+                                               decltype(thread_buffer_desc_m_n),
+                                               decltype(h_grid_desc_m_n),
+                                               HElementwiseOperation,
+                                               ThreadBufferLengths_M_N,
+                                               ThreadBufferDimAccessOrder,
+                                               ESrcHDstVectorDim,
+                                               HDstVectorSize,
+                                               InMemoryDataOperationEnum::Set,
+                                               1,
+                                               true>(
+                h_grid_desc_m_n,
+                make_multi_index(block_global_id * M_BlockTileSize +
+                                     thread_m_cluster_id * MThreadSliceSize,
+                                 thread_n_cluster_id * NThreadSliceSize),
+                h_element_op);

+        // step1: Merge mean and variance
         constexpr auto mean_var_count_thread_copy_step_m_n =
             make_multi_index(0, NThreadClusterSize);
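The copy step make_multi_index(0, NThreadClusterSize) means each thread strides NThreadClusterSize columns of the NBlock axis per iteration of the merge loop below, so the cluster sweeps all gemm_nblock_ partials in ceil(gemm_nblock_ / NThreadClusterSize) steps, which is exactly how numMeanVarCountBlockTileIteration_N is computed on the device side. With example values assumed for illustration:

// Thread t reads partial columns t, t + NThreadClusterSize, t + 2*NThreadClusterSize, ...
constexpr int NThreadClusterSize = 8;
constexpr int gemm_nblock        = 20; // partials per row
constexpr int iterations         = (gemm_nblock + NThreadClusterSize - 1) / NThreadClusterSize;
static_assert(iterations == 3, "thread 0 visits columns 0, 8 and 16");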
@@ -220,23 +295,23 @@ struct GridwiseWelfordSecondHalfLayernorm2d
         for(index_t reducedTiles = 0; reducedTiles < numMeanVarCountBlockTileIteration_N;
             ++reducedTiles)
         {
-            threadwise_mean_load_m_k.Run(mean_var_count_grid_desc_m_n,
-                                         welford_mean_global_val_buf,
-                                         thread_buffer_desc_m_1,
-                                         make_tuple(I0, I0),
-                                         in_welford_mean_thread_buf);
-
-            threadwise_var_load_m_k.Run(mean_var_count_grid_desc_m_n,
-                                        welford_var_global_val_buf,
-                                        thread_buffer_desc_m_1,
-                                        make_tuple(I0, I0),
-                                        in_welford_var_thread_buf);
-
-            threadwise_count_load_m_k.Run(mean_var_count_grid_desc_m_n,
-                                          welford_count_global_val_buf,
-                                          thread_buffer_desc_m_1,
-                                          make_tuple(I0, I0),
-                                          in_welford_count_thread_buf);
+            threadwise_mean_load_m_nblock.Run(mean_var_count_grid_desc_m_n,
+                                              welford_mean_global_val_buf,
+                                              thread_buffer_desc_m_1,
+                                              make_tuple(I0, I0),
+                                              in_welford_mean_thread_buf);
+
+            threadwise_var_load_m_nblock.Run(mean_var_count_grid_desc_m_n,
+                                             welford_var_global_val_buf,
+                                             thread_buffer_desc_m_1,
+                                             make_tuple(I0, I0),
+                                             in_welford_var_thread_buf);
+
+            threadwise_count_load_m_nblock.Run(mean_var_count_grid_desc_m_n,
+                                               welford_count_global_val_buf,
+                                               thread_buffer_desc_m_1,
+                                               make_tuple(I0, I0),
+                                               in_welford_count_thread_buf);

             ThreadwiseWelford::Run(in_welford_mean_thread_buf,
                                    in_welford_var_thread_buf,
@@ -245,12 +320,12 @@ struct GridwiseWelfordSecondHalfLayernorm2d
                                    welford_var_thread_buf,
                                    welford_count_thread_buf);

-            threadwise_mean_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_n,
-                                                        mean_var_count_thread_copy_step_m_n);
-            threadwise_var_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_n,
-                                                       mean_var_count_thread_copy_step_m_n);
-            threadwise_count_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_n,
-                                                         mean_var_count_thread_copy_step_m_n);
+            threadwise_mean_load_m_nblock.MoveSrcSliceWindow(mean_var_count_grid_desc_m_n,
+                                                             mean_var_count_thread_copy_step_m_n);
+            threadwise_var_load_m_nblock.MoveSrcSliceWindow(mean_var_count_grid_desc_m_n,
+                                                            mean_var_count_thread_copy_step_m_n);
+            threadwise_count_load_m_nblock.MoveSrcSliceWindow(mean_var_count_grid_desc_m_n,
+                                                              mean_var_count_thread_copy_step_m_n);
         }

         static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
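ThreadwiseWelford::Run in the loop above folds each freshly loaded (mean, var, count) column into the running per-row accumulator. The combine rule is the standard parallel (Chan et al.) Welford merge; a scalar C++ sketch of that rule, assuming for illustration that partials carry a biased variance, i.e. M2/count (this is not the CK ThreadwiseWelford API):

#include <cstdint>

struct WelfordPartial
{
    float   mean;
    float   var; // biased variance of this partial, i.e. M2 / count
    int32_t count;
};

// Merge two Welford partials into one (Chan et al. combine).
inline WelfordPartial welford_merge(WelfordPartial a, WelfordPartial b)
{
    const int32_t count = a.count + b.count;
    if(count == 0)
        return {0.f, 0.f, 0};
    const float delta = b.mean - a.mean;
    const float mean  = a.mean + delta * b.count / count;
    // Combine M2 = var * count from both sides, then renormalize.
    const float m2 = a.var * a.count + b.var * b.count +
                     delta * delta * (static_cast<float>(a.count) * b.count / count);
    return {mean, m2 / count, count};
}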
@@ -262,9 +337,64 @@ struct GridwiseWelfordSecondHalfLayernorm2d
         });

         // step2: normalization
-        for(index_t reducedTiles = 0; reducedTiles < numEBlockTileIteration_N; ++reducedTiles)
+        for(index_t reducedTiles = 0; reducedTiles < numNormBlockTileIteration_N; ++reducedTiles)
         {
-            // TODO
-            // h[m, n] = [(e[m, n] - mean[m]) / sqrt(var[m] + eps)] * gamma[n] + beta[n]
+            // h[m, n] = [(e[m, n] - mean[m]) / sqrt(var[m] + eps)] * gamma[n] + beta[n]
+            threadwise_e_load_m_n.Run(e_grid_desc_m_n,
+                                      e_global_val_buf,
+                                      thread_buffer_desc_m_n,
+                                      make_tuple(I0, I0),
+                                      e_thread_buf);
+
+            static_for<0, MThreadSliceSize, 1>{}([&](auto m) {
+                auto divisor = 1 / __builtin_amdgcn_sqrtf(welford_var_thread_buf(m) + epsilon);
+                static_for<0, NThreadSliceSize, 1>{}([&](auto n) {
+                    constexpr auto m_n = thread_buffer_desc_m_n.CalculateOffset(make_tuple(m, n));
+                    h_thread_buf(Number<m_n>{}) =
+                        (e_thread_buf(Number<m_n>{}) - welford_mean_thread_buf(m)) * divisor;
+                });
+            });
+
+            threadwise_gamma_load_m_n.Run(gamma_grid_desc_n,
+                                          gamma_global_val_buf,
+                                          thread_buffer_desc_n,
+                                          make_tuple(I0),
+                                          gamma_thread_buf);
+
+            static_for<0, MThreadSliceSize, 1>{}([&](auto m) {
+                static_for<0, NThreadSliceSize, 1>{}([&](auto n) {
+                    constexpr auto m_n = thread_buffer_desc_m_n.CalculateOffset(make_tuple(m, n));
+                    h_thread_buf(Number<m_n>{}) =
+                        h_thread_buf(Number<m_n>{}) * gamma_thread_buf(n);
+                });
+            });
+
+            threadwise_beta_load_m_n.Run(beta_grid_desc_n,
+                                         beta_global_val_buf,
+                                         thread_buffer_desc_n,
+                                         make_tuple(I0),
+                                         beta_thread_buf);
+
+            static_for<0, MThreadSliceSize, 1>{}([&](auto m) {
+                static_for<0, NThreadSliceSize, 1>{}([&](auto n) {
+                    constexpr auto m_n = thread_buffer_desc_m_n.CalculateOffset(make_tuple(m, n));
+                    h_thread_buf(Number<m_n>{}) =
+                        h_thread_buf(Number<m_n>{}) + beta_thread_buf(n);
+                });
+            });
+
+            threadwise_h_store_m_n.Run(thread_buffer_desc_m_n,
+                                       make_tuple(I0, I0),
+                                       h_thread_buf,
+                                       h_grid_desc_m_n,
+                                       h_global_val_buf);
+
+            threadwise_e_load_m_n.MoveSrcSliceWindow(e_grid_desc_m_n,
+                                                     make_multi_index(0, N_BlockTileSize));
+            threadwise_gamma_load_m_n.MoveSrcSliceWindow(gamma_grid_desc_n,
+                                                         make_multi_index(N_BlockTileSize));
+            threadwise_beta_load_m_n.MoveSrcSliceWindow(beta_grid_desc_n,
+                                                        make_multi_index(N_BlockTileSize));
+            threadwise_h_store_m_n.MoveDstSliceWindow(h_grid_desc_m_n,
+                                                      make_multi_index(0, N_BlockTileSize));
         }
     } // run