Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
8749678a
Commit
8749678a
authored
Nov 22, 2022
by
rocking
Browse files
Rename F and G to mean and var
parent
9a25afe4
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
118 additions
and
115 deletions
+118
-115
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
.../device/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
+69
-67
include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp
...dwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp
+49
-48
No files found.
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
View file @
8749678a
...
@@ -24,8 +24,8 @@ template <typename GridwiseGemm,
...
@@ -24,8 +24,8 @@ template <typename GridwiseGemm,
typename
ABDataType
,
typename
ABDataType
,
typename
DsPointer
,
typename
DsPointer
,
typename
EDataType
,
typename
EDataType
,
typename
F
DataType
,
typename
Mean
DataType
,
typename
G
DataType
,
typename
Var
DataType
,
typename
AElementwiseOperation
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
,
typename
CDEElementwiseOperation
,
...
@@ -33,8 +33,8 @@ template <typename GridwiseGemm,
...
@@ -33,8 +33,8 @@ template <typename GridwiseGemm,
typename
BGridDesc_BK0_N_BK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
F
GridDescriptor_MBlock_MPerBlock_NBlock
,
typename
Mean
GridDescriptor_MBlock_MPerBlock_NBlock
,
typename
G
GridDescriptor_MBlock_MPerBlock_NBlock
,
typename
Var
GridDescriptor_MBlock_MPerBlock_NBlock
,
typename
Block2ETileMap
,
typename
Block2ETileMap
,
bool
HasMainKBlockLoop
>
bool
HasMainKBlockLoop
>
__global__
void
__global__
void
...
@@ -46,8 +46,8 @@ __global__ void
...
@@ -46,8 +46,8 @@ __global__ void
const
ABDataType
*
__restrict__
p_b_grid
,
const
ABDataType
*
__restrict__
p_b_grid
,
DsPointer
p_ds_grid
,
DsPointer
p_ds_grid
,
EDataType
*
__restrict__
p_e_grid
,
EDataType
*
__restrict__
p_e_grid
,
F
DataType
*
__restrict__
p_
f
_grid
,
Mean
DataType
*
__restrict__
p_
mean
_grid
,
G
DataType
*
__restrict__
p_
g
_grid
,
Var
DataType
*
__restrict__
p_
var
_grid
,
const
AElementwiseOperation
a_element_op
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
BElementwiseOperation
b_element_op
,
const
CDEElementwiseOperation
cde_element_op
,
const
CDEElementwiseOperation
cde_element_op
,
...
@@ -57,8 +57,8 @@ __global__ void
...
@@ -57,8 +57,8 @@ __global__ void
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
const
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
const
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock
,
e_grid_desc_mblock_mperblock_nblock_nperblock
,
const
F
GridDescriptor_MBlock_MPerBlock_NBlock
f
_grid_desc_mblock_mperblock_nblock
,
const
Mean
GridDescriptor_MBlock_MPerBlock_NBlock
mean
_grid_desc_mblock_mperblock_nblock
,
const
G
GridDescriptor_MBlock_MPerBlock_NBlock
g
_grid_desc_mblock_mperblock_nblock
,
const
Var
GridDescriptor_MBlock_MPerBlock_NBlock
var
_grid_desc_mblock_mperblock_nblock
,
const
Block2ETileMap
block_2_etile_map
)
const
Block2ETileMap
block_2_etile_map
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
...
@@ -68,8 +68,8 @@ __global__ void
...
@@ -68,8 +68,8 @@ __global__ void
p_b_grid
,
p_b_grid
,
p_ds_grid
,
p_ds_grid
,
p_e_grid
,
p_e_grid
,
p_
f
_grid
,
p_
mean
_grid
,
p_
g
_grid
,
p_
var
_grid
,
p_shared
,
p_shared
,
a_element_op
,
a_element_op
,
b_element_op
,
b_element_op
,
...
@@ -78,16 +78,16 @@ __global__ void
...
@@ -78,16 +78,16 @@ __global__ void
b_grid_desc_bk0_n_bk1
,
b_grid_desc_bk0_n_bk1
,
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
e_grid_desc_mblock_mperblock_nblock_nperblock
,
e_grid_desc_mblock_mperblock_nblock_nperblock
,
f
_grid_desc_mblock_mperblock_nblock
,
mean
_grid_desc_mblock_mperblock_nblock
,
g
_grid_desc_mblock_mperblock_nblock
,
var
_grid_desc_mblock_mperblock_nblock
,
block_2_etile_map
);
block_2_etile_map
);
#else
#else
ignore
=
p_a_grid
;
ignore
=
p_a_grid
;
ignore
=
p_b_grid
;
ignore
=
p_b_grid
;
ignore
=
p_ds_grid
;
ignore
=
p_ds_grid
;
ignore
=
p_e_grid
;
ignore
=
p_e_grid
;
ignore
=
p_
f
_grid
;
ignore
=
p_
mean
_grid
;
ignore
=
p_
g
_grid
;
ignore
=
p_
var
_grid
;
ignore
=
a_element_op
;
ignore
=
a_element_op
;
ignore
=
b_element_op
;
ignore
=
b_element_op
;
ignore
=
cde_element_op
;
ignore
=
cde_element_op
;
...
@@ -95,8 +95,8 @@ __global__ void
...
@@ -95,8 +95,8 @@ __global__ void
ignore
=
b_grid_desc_bk0_n_bk1
;
ignore
=
b_grid_desc_bk0_n_bk1
;
ignore
=
ds_grid_desc_mblock_mperblock_nblock_nperblock
;
ignore
=
ds_grid_desc_mblock_mperblock_nblock_nperblock
;
ignore
=
e_grid_desc_mblock_mperblock_nblock_nperblock
;
ignore
=
e_grid_desc_mblock_mperblock_nblock_nperblock
;
ignore
=
f
_grid_desc_mblock_mperblock_nblock
;
ignore
=
mean
_grid_desc_mblock_mperblock_nblock
;
ignore
=
g
_grid_desc_mblock_mperblock_nblock
;
ignore
=
var
_grid_desc_mblock_mperblock_nblock
;
ignore
=
block_2_etile_map
;
ignore
=
block_2_etile_map
;
#endif
#endif
}
}
...
@@ -185,9 +185,9 @@ template <typename ALayout,
...
@@ -185,9 +185,9 @@ template <typename ALayout,
LoopScheduler
LoopSched
=
make_default_loop_scheduler
()>
LoopScheduler
LoopSched
=
make_default_loop_scheduler
()>
struct
DeviceGemmMultipleDLayernorm_Xdl_CShuffle
:
public
BaseOperator
struct
DeviceGemmMultipleDLayernorm_Xdl_CShuffle
:
public
BaseOperator
{
{
using
DeviceOp
=
DeviceGemmMultipleDLayernorm_Xdl_CShuffle
;
using
DeviceOp
=
DeviceGemmMultipleDLayernorm_Xdl_CShuffle
;
using
F
DataType
=
CShuffleDataType
;
using
Mean
DataType
=
CShuffleDataType
;
using
G
DataType
=
CShuffleDataType
;
using
Var
DataType
=
CShuffleDataType
;
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
...
@@ -264,13 +264,13 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -264,13 +264,13 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
Number
<
NumDTensor
>
{});
Number
<
NumDTensor
>
{});
}
}
using
AGridDesc_M_K
=
decltype
(
MakeAGridDescriptor_M_K
(
1
,
1
,
1
));
using
AGridDesc_M_K
=
decltype
(
MakeAGridDescriptor_M_K
(
1
,
1
,
1
));
using
BGridDesc_N_K
=
decltype
(
MakeBGridDescriptor_N_K
(
1
,
1
,
1
));
using
BGridDesc_N_K
=
decltype
(
MakeBGridDescriptor_N_K
(
1
,
1
,
1
));
using
DsGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeDsGridDescriptor_M_N
({},
{},
{}))
>
;
using
DsGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeDsGridDescriptor_M_N
({},
{},
{}))
>
;
using
EGridDesc_M_N
=
decltype
(
MakeGridDescriptor_M_N
<
ELayout
>
(
1
,
1
,
1
));
using
EGridDesc_M_N
=
decltype
(
MakeGridDescriptor_M_N
<
ELayout
>
(
1
,
1
,
1
));
using
F
GridDesc_M_N
=
decltype
(
MakeGridDescriptor_M_N
<
ELayout
>
(
1
,
1
,
1
));
using
Mean
GridDesc_M_N
=
decltype
(
MakeGridDescriptor_M_N
<
ELayout
>
(
1
,
1
,
1
));
using
G
GridDesc_M_N
=
decltype
(
MakeGridDescriptor_M_N
<
ELayout
>
(
1
,
1
,
1
));
using
Var
GridDesc_M_N
=
decltype
(
MakeGridDescriptor_M_N
<
ELayout
>
(
1
,
1
,
1
));
using
HGridDesc_M_N
=
decltype
(
MakeGridDescriptor_M_N
<
HLayout
>
(
1
,
1
,
1
));
using
HGridDesc_M_N
=
decltype
(
MakeGridDescriptor_M_N
<
HLayout
>
(
1
,
1
,
1
));
// GridwiseGemm
// GridwiseGemm
using
GridwiseGemm
=
GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
<
using
GridwiseGemm
=
GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
<
...
@@ -279,8 +279,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -279,8 +279,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
CShuffleDataType
,
CShuffleDataType
,
DsDataType
,
DsDataType
,
EDataType
,
EDataType
,
F
DataType
,
Mean
DataType
,
G
DataType
,
Var
DataType
,
AElementwiseOperation
,
AElementwiseOperation
,
BElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
,
CDEElementwiseOperation
,
...
@@ -289,8 +289,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -289,8 +289,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
BGridDesc_N_K
,
BGridDesc_N_K
,
DsGridDesc_M_N
,
DsGridDesc_M_N
,
EGridDesc_M_N
,
EGridDesc_M_N
,
F
GridDesc_M_N
,
Mean
GridDesc_M_N
,
G
GridDesc_M_N
,
Var
GridDesc_M_N
,
NumGemmKPrefetchStage
,
NumGemmKPrefetchStage
,
BlockSize
,
BlockSize
,
MPerBlock
,
MPerBlock
,
...
@@ -328,7 +328,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -328,7 +328,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
using
Block2ETileMap
=
typename
GridwiseGemm
::
DefaultBlock2ETileMap
;
using
Block2ETileMap
=
typename
GridwiseGemm
::
DefaultBlock2ETileMap
;
using
GridwiseWelfordLayernorm
=
using
GridwiseWelfordLayernorm
=
GridwiseWelfordSecondHalfLayernorm2d
<
EDataType
,
HDataType
,
F
DataType
,
G
DataType
>
;
GridwiseWelfordSecondHalfLayernorm2d
<
EDataType
,
HDataType
,
Mean
DataType
,
Var
DataType
>
;
// Argument
// Argument
struct
Argument
:
public
BaseArgument
struct
Argument
:
public
BaseArgument
...
@@ -354,15 +354,15 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -354,15 +354,15 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
p_b_grid_
{
static_cast
<
const
BDataType
*>
(
p_b_grid
)},
p_b_grid_
{
static_cast
<
const
BDataType
*>
(
p_b_grid
)},
p_ds_grid_
{},
p_ds_grid_
{},
p_e_grid_
{
static_cast
<
EDataType
*>
(
p_e_grid
)},
p_e_grid_
{
static_cast
<
EDataType
*>
(
p_e_grid
)},
p_
f
_grid_
{
nullptr
},
p_
mean
_grid_
{
nullptr
},
p_
g
_grid_
{
nullptr
},
p_
var
_grid_
{
nullptr
},
p_h_grid_
{
static_cast
<
HDataType
*>
(
p_h_grid
)},
p_h_grid_
{
static_cast
<
HDataType
*>
(
p_h_grid
)},
a_grid_desc_m_k_
{
DeviceOp
::
MakeAGridDescriptor_M_K
(
MRaw
,
KRaw
,
StrideA
)},
a_grid_desc_m_k_
{
DeviceOp
::
MakeAGridDescriptor_M_K
(
MRaw
,
KRaw
,
StrideA
)},
b_grid_desc_n_k_
{
DeviceOp
::
MakeBGridDescriptor_N_K
(
KRaw
,
NRaw
,
StrideB
)},
b_grid_desc_n_k_
{
DeviceOp
::
MakeBGridDescriptor_N_K
(
KRaw
,
NRaw
,
StrideB
)},
ds_grid_desc_m_n_
{},
ds_grid_desc_m_n_
{},
e_grid_desc_m_n_
{
DeviceOp
::
MakeGridDescriptor_M_N
<
ELayout
>
(
MRaw
,
NRaw
,
StrideE
)},
e_grid_desc_m_n_
{
DeviceOp
::
MakeGridDescriptor_M_N
<
ELayout
>
(
MRaw
,
NRaw
,
StrideE
)},
f
_grid_desc_m_n_
{},
mean
_grid_desc_m_n_
{},
g
_grid_desc_m_n_
{},
var
_grid_desc_m_n_
{},
h_grid_desc_m_n_
{
DeviceOp
::
MakeGridDescriptor_M_N
<
HLayout
>
(
MRaw
,
NRaw
,
StrideH
)},
h_grid_desc_m_n_
{
DeviceOp
::
MakeGridDescriptor_M_N
<
HLayout
>
(
MRaw
,
NRaw
,
StrideH
)},
block_2_etile_map_
{
GridwiseGemm
::
MakeDefaultBlock2ETileMap
(
e_grid_desc_m_n_
)},
block_2_etile_map_
{
GridwiseGemm
::
MakeDefaultBlock2ETileMap
(
e_grid_desc_m_n_
)},
a_element_op_
{
a_element_op
},
a_element_op_
{
a_element_op
},
...
@@ -371,14 +371,14 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -371,14 +371,14 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
h_element_op_
{
h_element_op
},
h_element_op_
{
h_element_op
},
blkGroupSize_
{
math
::
integer_divide_ceil
(
NRaw
,
NPerBlock
)}
blkGroupSize_
{
math
::
integer_divide_ceil
(
NRaw
,
NPerBlock
)}
{
{
f
_grid_desc_m_n_
=
mean
_grid_desc_m_n_
=
DeviceOp
::
MakeGridDescriptor_M_N
<
ELayout
>
(
MRaw
,
blkGroupSize_
,
blkGroupSize_
);
DeviceOp
::
MakeGridDescriptor_M_N
<
ELayout
>
(
MRaw
,
blkGroupSize_
,
blkGroupSize_
);
g
_grid_desc_m_n_
=
var
_grid_desc_m_n_
=
DeviceOp
::
MakeGridDescriptor_M_N
<
ELayout
>
(
MRaw
,
blkGroupSize_
,
blkGroupSize_
);
DeviceOp
::
MakeGridDescriptor_M_N
<
ELayout
>
(
MRaw
,
blkGroupSize_
,
blkGroupSize_
);
int
welford_size
=
MRaw
*
blkGroupSize_
;
int
welford_size
=
MRaw
*
blkGroupSize_
;
hip_check_error
(
hipMalloc
(
&
p_
f
_grid_
,
sizeof
(
F
DataType
)
*
welford_size
));
hip_check_error
(
hipMalloc
(
&
p_
mean
_grid_
,
sizeof
(
Mean
DataType
)
*
welford_size
));
hip_check_error
(
hipMalloc
(
&
p_
g
_grid_
,
sizeof
(
G
DataType
)
*
welford_size
));
hip_check_error
(
hipMalloc
(
&
p_
var
_grid_
,
sizeof
(
Var
DataType
)
*
welford_size
));
// populate pointer, desc for Ds
// populate pointer, desc for Ds
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
...
@@ -398,8 +398,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -398,8 +398,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
b_grid_desc_n_k_
,
b_grid_desc_n_k_
,
ds_grid_desc_m_n_
,
ds_grid_desc_m_n_
,
e_grid_desc_m_n_
,
e_grid_desc_m_n_
,
f
_grid_desc_m_n_
,
mean
_grid_desc_m_n_
,
g
_grid_desc_m_n_
,
var
_grid_desc_m_n_
,
block_2_etile_map_
))
block_2_etile_map_
))
{
{
ds_grid_desc_mblock_mperblock_nblock_nperblock_
=
ds_grid_desc_mblock_mperblock_nblock_nperblock_
=
...
@@ -410,11 +410,13 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -410,11 +410,13 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
e_grid_desc_m_n_
);
e_grid_desc_m_n_
);
f_grid_desc_mblock_mperblock_nblock_
=
mean_grid_desc_mblock_mperblock_nblock_
=
GridwiseGemm
::
MakeFGGridDescriptor_MBlock_MPerBlock_NBlock
(
f_grid_desc_m_n_
);
GridwiseGemm
::
MakeMeanVarGridDescriptor_MBlock_MPerBlock_NBlock
(
mean_grid_desc_m_n_
);
g_grid_desc_mblock_mperblock_nblock_
=
var_grid_desc_mblock_mperblock_nblock_
=
GridwiseGemm
::
MakeFGGridDescriptor_MBlock_MPerBlock_NBlock
(
g_grid_desc_m_n_
);
GridwiseGemm
::
MakeMeanVarGridDescriptor_MBlock_MPerBlock_NBlock
(
var_grid_desc_m_n_
);
}
}
// TODO - H
// TODO - H
...
@@ -436,8 +438,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -436,8 +438,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
const
BDataType
*
p_b_grid_
;
const
BDataType
*
p_b_grid_
;
typename
GridwiseGemm
::
DsGridPointer
p_ds_grid_
;
typename
GridwiseGemm
::
DsGridPointer
p_ds_grid_
;
EDataType
*
p_e_grid_
;
EDataType
*
p_e_grid_
;
F
DataType
*
p_
f
_grid_
;
// mean
Mean
DataType
*
p_
mean
_grid_
;
// mean
G
DataType
*
p_
g
_grid_
;
// variance * count
Var
DataType
*
p_
var
_grid_
;
// variance * count
HDataType
*
p_h_grid_
;
HDataType
*
p_h_grid_
;
// tensor descriptors for problem definiton
// tensor descriptors for problem definiton
...
@@ -445,8 +447,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -445,8 +447,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
BGridDesc_N_K
b_grid_desc_n_k_
;
BGridDesc_N_K
b_grid_desc_n_k_
;
DsGridDesc_M_N
ds_grid_desc_m_n_
;
DsGridDesc_M_N
ds_grid_desc_m_n_
;
EGridDesc_M_N
e_grid_desc_m_n_
;
EGridDesc_M_N
e_grid_desc_m_n_
;
F
GridDesc_M_N
f
_grid_desc_m_n_
;
Mean
GridDesc_M_N
mean
_grid_desc_m_n_
;
G
GridDesc_M_N
g
_grid_desc_m_n_
;
Var
GridDesc_M_N
var
_grid_desc_m_n_
;
HGridDesc_M_N
h_grid_desc_m_n_
;
HGridDesc_M_N
h_grid_desc_m_n_
;
// tensor descriptors for block/thread-wise copy
// tensor descriptors for block/thread-wise copy
...
@@ -456,10 +458,10 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -456,10 +458,10 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
ds_grid_desc_mblock_mperblock_nblock_nperblock_
;
ds_grid_desc_mblock_mperblock_nblock_nperblock_
;
typename
GridwiseGemm
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
typename
GridwiseGemm
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_
;
e_grid_desc_mblock_mperblock_nblock_nperblock_
;
typename
GridwiseGemm
::
F
GridDescriptor_MBlock_MPerBlock_NBlock
typename
GridwiseGemm
::
Mean
GridDescriptor_MBlock_MPerBlock_NBlock
f
_grid_desc_mblock_mperblock_nblock_
;
mean
_grid_desc_mblock_mperblock_nblock_
;
typename
GridwiseGemm
::
G
GridDescriptor_MBlock_MPerBlock_NBlock
typename
GridwiseGemm
::
Var
GridDescriptor_MBlock_MPerBlock_NBlock
g
_grid_desc_mblock_mperblock_nblock_
;
var
_grid_desc_mblock_mperblock_nblock_
;
// block-to-e-tile map
// block-to-e-tile map
Block2ETileMap
block_2_etile_map_
;
Block2ETileMap
block_2_etile_map_
;
...
@@ -486,8 +488,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -486,8 +488,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
arg
.
b_grid_desc_n_k_
,
arg
.
b_grid_desc_n_k_
,
arg
.
ds_grid_desc_m_n_
,
arg
.
ds_grid_desc_m_n_
,
arg
.
e_grid_desc_m_n_
,
arg
.
e_grid_desc_m_n_
,
arg
.
f
_grid_desc_m_n_
,
arg
.
mean
_grid_desc_m_n_
,
arg
.
g
_grid_desc_m_n_
,
arg
.
var
_grid_desc_m_n_
,
arg
.
block_2_etile_map_
))
arg
.
block_2_etile_map_
))
{
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemm has invalid setting"
);
throw
std
::
runtime_error
(
"wrong! GridwiseGemm has invalid setting"
);
...
@@ -508,8 +510,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -508,8 +510,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
ADataType
,
// TODO: distiguish A/B datatype
ADataType
,
// TODO: distiguish A/B datatype
typename
GridwiseGemm
::
DsGridPointer
,
typename
GridwiseGemm
::
DsGridPointer
,
EDataType
,
EDataType
,
F
DataType
,
Mean
DataType
,
G
DataType
,
Var
DataType
,
AElementwiseOperation
,
AElementwiseOperation
,
BElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
,
CDEElementwiseOperation
,
...
@@ -517,8 +519,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -517,8 +519,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
typename
GridwiseGemm
::
DefaultBGridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
DefaultBGridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
F
GridDescriptor_MBlock_MPerBlock_NBlock
,
typename
GridwiseGemm
::
Mean
GridDescriptor_MBlock_MPerBlock_NBlock
,
typename
GridwiseGemm
::
G
GridDescriptor_MBlock_MPerBlock_NBlock
,
typename
GridwiseGemm
::
Var
GridDescriptor_MBlock_MPerBlock_NBlock
,
typename
GridwiseGemm
::
DefaultBlock2ETileMap
,
typename
GridwiseGemm
::
DefaultBlock2ETileMap
,
has_main_loop
>
;
has_main_loop
>
;
...
@@ -526,8 +528,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -526,8 +528,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
kernel_welford_layernorm2d_second_half
<
GridwiseWelfordLayernorm
,
kernel_welford_layernorm2d_second_half
<
GridwiseWelfordLayernorm
,
EDataType
,
EDataType
,
HDataType
,
HDataType
,
F
DataType
,
Mean
DataType
,
G
DataType
>
;
Var
DataType
>
;
avg_time
+=
avg_time
+=
launch_and_time_kernel
(
stream_config
,
launch_and_time_kernel
(
stream_config
,
...
@@ -539,8 +541,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -539,8 +541,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
arg
.
p_b_grid_
,
arg
.
p_b_grid_
,
arg
.
p_ds_grid_
,
arg
.
p_ds_grid_
,
arg
.
p_e_grid_
,
arg
.
p_e_grid_
,
arg
.
p_
f
_grid_
,
arg
.
p_
mean
_grid_
,
arg
.
p_
g
_grid_
,
arg
.
p_
var
_grid_
,
arg
.
a_element_op_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
b_element_op_
,
arg
.
cde_element_op_
,
arg
.
cde_element_op_
,
...
@@ -548,8 +550,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -548,8 +550,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
ds_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
ds_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
e_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
e_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
f
_grid_desc_mblock_mperblock_nblock_
,
arg
.
mean
_grid_desc_mblock_mperblock_nblock_
,
arg
.
g
_grid_desc_mblock_mperblock_nblock_
,
arg
.
var
_grid_desc_mblock_mperblock_nblock_
,
arg
.
block_2_etile_map_
);
arg
.
block_2_etile_map_
);
avg_time
+=
launch_and_time_kernel
(
stream_config
,
avg_time
+=
launch_and_time_kernel
(
stream_config
,
...
@@ -558,8 +560,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
...
@@ -558,8 +560,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
arg
.
p_e_grid_
,
arg
.
p_e_grid_
,
arg
.
p_
f
_grid_
,
arg
.
p_
mean
_grid_
,
arg
.
p_
g
_grid_
,
arg
.
p_
var
_grid_
,
arg
.
p_h_grid_
);
arg
.
p_h_grid_
);
return
avg_time
;
return
avg_time
;
...
...
include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp
View file @
8749678a
...
@@ -37,8 +37,8 @@ template <typename ABDataType,
...
@@ -37,8 +37,8 @@ template <typename ABDataType,
typename
CShuffleDataType
,
typename
CShuffleDataType
,
typename
DsDataType
,
typename
DsDataType
,
typename
EDataType
,
typename
EDataType
,
typename
F
DataType
,
typename
Mean
DataType
,
typename
G
DataType
,
typename
Var
DataType
,
typename
AElementwiseOperation
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
,
typename
CDEElementwiseOperation
,
...
@@ -47,8 +47,8 @@ template <typename ABDataType,
...
@@ -47,8 +47,8 @@ template <typename ABDataType,
typename
BGridDesc_N_K
,
typename
BGridDesc_N_K
,
typename
DsGridDesc_M_N
,
typename
DsGridDesc_M_N
,
typename
EGridDesc_M_N
,
typename
EGridDesc_M_N
,
typename
F
GridDesc_M_N
,
typename
Mean
GridDesc_M_N
,
typename
G
GridDesc_M_N
,
typename
Var
GridDesc_M_N
,
index_t
NumGemmKPrefetchStage
,
index_t
NumGemmKPrefetchStage
,
index_t
BlockSize
,
index_t
BlockSize
,
index_t
MPerBlock
,
index_t
MPerBlock
,
...
@@ -242,10 +242,10 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
...
@@ -242,10 +242,10 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
Number
<
NumDTensor
>
{});
Number
<
NumDTensor
>
{});
}
}
// TODO - Make
FG
GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
// TODO - Make
MeanVar
GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
template
<
typename
GridDescriptor_M_N
>
template
<
typename
GridDescriptor_M_N
>
__host__
__device__
static
constexpr
auto
__host__
__device__
static
constexpr
auto
Make
FG
GridDescriptor_MBlock_MPerBlock_NBlock
(
const
GridDescriptor_M_N
&
grid_desc_m_n
)
Make
MeanVar
GridDescriptor_MBlock_MPerBlock_NBlock
(
const
GridDescriptor_M_N
&
grid_desc_m_n
)
{
{
const
auto
M
=
grid_desc_m_n
.
GetLength
(
I0
);
const
auto
M
=
grid_desc_m_n
.
GetLength
(
I0
);
const
auto
NBlock
=
grid_desc_m_n
.
GetLength
(
I1
);
const
auto
NBlock
=
grid_desc_m_n
.
GetLength
(
I1
);
...
@@ -271,13 +271,14 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
...
@@ -271,13 +271,14 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template
<
typename
Block2ETileMap
>
template
<
typename
Block2ETileMap
>
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
AGridDesc_M_K
&
a_grid_desc_m_k
,
__host__
__device__
static
constexpr
bool
const
BGridDesc_N_K
&
b_grid_desc_n_k
,
CheckValidity
(
const
AGridDesc_M_K
&
a_grid_desc_m_k
,
const
DsGridDesc_M_N
&
ds_grid_desc_m_n
,
const
BGridDesc_N_K
&
b_grid_desc_n_k
,
const
EGridDesc_M_N
&
e_grid_desc_m_n
,
const
DsGridDesc_M_N
&
ds_grid_desc_m_n
,
const
FGridDesc_M_N
&
f_grid_desc_m_n
,
const
EGridDesc_M_N
&
e_grid_desc_m_n
,
const
GGridDesc_M_N
&
g_grid_desc_m_n
,
const
MeanGridDesc_M_N
&
mean_grid_desc_m_n
,
const
Block2ETileMap
&
block_2_etile_map
)
const
VarGridDesc_M_N
&
var_grid_desc_m_n
,
const
Block2ETileMap
&
block_2_etile_map
)
{
{
static_assert
((
MPerBlock
%
(
MPerXdl
*
MXdlPerWave
)
==
0
)
&&
static_assert
((
MPerBlock
%
(
MPerXdl
*
MXdlPerWave
)
==
0
)
&&
(
NPerBlock
%
(
NXdlPerWave
*
NPerXdl
))
==
0
,
(
NPerBlock
%
(
NXdlPerWave
*
NPerXdl
))
==
0
,
...
@@ -289,9 +290,9 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
...
@@ -289,9 +290,9 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
// check consistency of desc
// check consistency of desc
if
(
!
(
M
==
e_grid_desc_m_n
.
GetLength
(
I0
)
&&
N
==
e_grid_desc_m_n
.
GetLength
(
I1
)
&&
if
(
!
(
M
==
e_grid_desc_m_n
.
GetLength
(
I0
)
&&
N
==
e_grid_desc_m_n
.
GetLength
(
I1
)
&&
M
==
f
_grid_desc_m_n
.
GetLength
(
I0
)
&&
M
==
g
_grid_desc_m_n
.
GetLength
(
I0
)
&&
M
==
mean
_grid_desc_m_n
.
GetLength
(
I0
)
&&
M
==
var
_grid_desc_m_n
.
GetLength
(
I0
)
&&
N
/
NPerBlock
==
f
_grid_desc_m_n
.
GetLength
(
I1
)
&&
N
/
NPerBlock
==
mean
_grid_desc_m_n
.
GetLength
(
I1
)
&&
N
/
NPerBlock
==
g
_grid_desc_m_n
.
GetLength
(
I1
)))
N
/
NPerBlock
==
var
_grid_desc_m_n
.
GetLength
(
I1
)))
{
{
return
false
;
return
false
;
}
}
...
@@ -353,12 +354,12 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
...
@@ -353,12 +354,12 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
remove_cvref_t
<
decltype
(
MakeDefaultAGridDescriptor_AK0_M_AK1
(
AGridDesc_M_K
{}))
>
;
remove_cvref_t
<
decltype
(
MakeDefaultAGridDescriptor_AK0_M_AK1
(
AGridDesc_M_K
{}))
>
;
using
DefaultBGridDesc_BK0_N_BK1
=
using
DefaultBGridDesc_BK0_N_BK1
=
remove_cvref_t
<
decltype
(
MakeDefaultBGridDescriptor_BK0_N_BK1
(
BGridDesc_N_K
{}))
>
;
remove_cvref_t
<
decltype
(
MakeDefaultBGridDescriptor_BK0_N_BK1
(
BGridDesc_N_K
{}))
>
;
using
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
using
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
EGridDesc_M_N
{}))
>
;
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
EGridDesc_M_N
{}))
>
;
using
F
GridDescriptor_MBlock_MPerBlock_NBlock
=
using
Mean
GridDescriptor_MBlock_MPerBlock_NBlock
=
remove_cvref_t
<
decltype
(
remove_cvref_t
<
decltype
(
MakeFG
GridDescriptor_MBlock_MPerBlock_NBlock
(
F
GridDesc_M_N
{}))
>
;
MakeMeanVar
GridDescriptor_MBlock_MPerBlock_NBlock
(
Mean
GridDesc_M_N
{}))
>
;
using
G
GridDescriptor_MBlock_MPerBlock_NBlock
=
using
Var
GridDescriptor_MBlock_MPerBlock_NBlock
=
remove_cvref_t
<
decltype
(
remove_cvref_t
<
decltype
(
MakeFG
GridDescriptor_MBlock_MPerBlock_NBlock
(
G
GridDesc_M_N
{}))
>
;
MakeMeanVar
GridDescriptor_MBlock_MPerBlock_NBlock
(
Var
GridDesc_M_N
{}))
>
;
using
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
using
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
DsGridDesc_M_N
{}))
>
;
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
DsGridDesc_M_N
{}))
>
;
...
@@ -376,8 +377,8 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
...
@@ -376,8 +377,8 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
const
ABDataType
*
__restrict__
p_b_grid
,
const
ABDataType
*
__restrict__
p_b_grid
,
DsGridPointer
p_ds_grid
,
DsGridPointer
p_ds_grid
,
EDataType
*
__restrict__
p_e_grid
,
EDataType
*
__restrict__
p_e_grid
,
F
DataType
*
__restrict__
p_
f
_grid
,
Mean
DataType
*
__restrict__
p_
mean
_grid
,
G
DataType
*
__restrict__
p_
g
_grid
,
Var
DataType
*
__restrict__
p_
var
_grid
,
void
*
__restrict__
p_shared
,
void
*
__restrict__
p_shared
,
const
AElementwiseOperation
&
a_element_op
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
BElementwiseOperation
&
b_element_op
,
...
@@ -388,8 +389,8 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
...
@@ -388,8 +389,8 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
const
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
&
const
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
&
e_grid_desc_mblock_mperblock_nblock_nperblock
,
e_grid_desc_mblock_mperblock_nblock_nperblock
,
const
F
GridDescriptor_MBlock_MPerBlock_NBlock
&
f
_grid_desc_mblock_mperblock_nblock
,
const
Mean
GridDescriptor_MBlock_MPerBlock_NBlock
&
mean
_grid_desc_mblock_mperblock_nblock
,
const
G
GridDescriptor_MBlock_MPerBlock_NBlock
&
g
_grid_desc_mblock_mperblock_nblock
,
const
Var
GridDescriptor_MBlock_MPerBlock_NBlock
&
var
_grid_desc_mblock_mperblock_nblock
,
const
Block2ETileMap
&
block_2_etile_map
)
const
Block2ETileMap
&
block_2_etile_map
)
{
{
const
auto
a_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
const
auto
a_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
...
@@ -409,11 +410,11 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
...
@@ -409,11 +410,11 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
auto
e_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
auto
e_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_e_grid
,
e_grid_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
p_e_grid
,
e_grid_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
auto
f
_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
auto
mean
_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_
f
_grid
,
f
_grid_desc_mblock_mperblock_nblock
.
GetElementSpaceSize
());
p_
mean
_grid
,
mean
_grid_desc_mblock_mperblock_nblock
.
GetElementSpaceSize
());
auto
g
_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
auto
var
_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_
g
_grid
,
g
_grid_desc_mblock_mperblock_nblock
.
GetElementSpaceSize
());
p_
var
_grid
,
var
_grid_desc_mblock_mperblock_nblock
.
GetElementSpaceSize
());
// divide block work by [M, N]
// divide block work by [M, N]
const
auto
block_work_idx
=
const
auto
block_work_idx
=
...
@@ -989,11 +990,11 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
...
@@ -989,11 +990,11 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
.
GetLength
(
I1
);
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
.
GetLength
(
I1
);
static_assert
(
mreduce_per_thread
%
FGTransferScalarPerVector
==
0
);
static_assert
(
mreduce_per_thread
%
FGTransferScalarPerVector
==
0
);
auto
f
_thread_copy_vgpr_to_global
=
ThreadwiseTensorSliceTransfer_v1r3
<
auto
mean
_thread_copy_vgpr_to_global
=
ThreadwiseTensorSliceTransfer_v1r3
<
AccDataType
,
AccDataType
,
F
DataType
,
Mean
DataType
,
decltype
(
thread_welford_desc_I_m_I
),
decltype
(
thread_welford_desc_I_m_I
),
decltype
(
f
_grid_desc_mblock_mperblock_nblock
),
decltype
(
mean
_grid_desc_mblock_mperblock_nblock
),
tensor_operation
::
element_wise
::
PassThrough
,
tensor_operation
::
element_wise
::
PassThrough
,
Sequence
<
1
,
mreduce_per_thread
,
1
>
,
Sequence
<
1
,
mreduce_per_thread
,
1
>
,
Sequence
<
0
,
1
,
2
>
,
Sequence
<
0
,
1
,
2
>
,
...
@@ -1001,18 +1002,18 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
...
@@ -1001,18 +1002,18 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
FGTransferScalarPerVector
,
FGTransferScalarPerVector
,
InMemoryDataOperationEnum
::
Set
,
InMemoryDataOperationEnum
::
Set
,
1
,
1
,
false
>
{
f
_grid_desc_mblock_mperblock_nblock
,
false
>
{
mean
_grid_desc_mblock_mperblock_nblock
,
make_multi_index
(
block_work_idx
[
I0
],
// mblock
make_multi_index
(
block_work_idx
[
I0
],
// mblock
shuffleMPerBlock
*
i
+
shuffleMPerBlock
*
i
+
c_reduce_thread_data_idx_begin
[
I0
],
// mperblock
c_reduce_thread_data_idx_begin
[
I0
],
// mperblock
block_work_idx
[
I1
]),
// nblock
block_work_idx
[
I1
]),
// nblock
tensor_operation
::
element_wise
::
PassThrough
{}};
tensor_operation
::
element_wise
::
PassThrough
{}};
auto
g
_thread_copy_vgpr_to_global
=
ThreadwiseTensorSliceTransfer_v1r3
<
auto
var
_thread_copy_vgpr_to_global
=
ThreadwiseTensorSliceTransfer_v1r3
<
AccDataType
,
AccDataType
,
G
DataType
,
Var
DataType
,
decltype
(
thread_welford_desc_I_m_I
),
decltype
(
thread_welford_desc_I_m_I
),
decltype
(
g
_grid_desc_mblock_mperblock_nblock
),
decltype
(
var
_grid_desc_mblock_mperblock_nblock
),
tensor_operation
::
element_wise
::
PassThrough
,
tensor_operation
::
element_wise
::
PassThrough
,
Sequence
<
1
,
mreduce_per_thread
,
1
>
,
Sequence
<
1
,
mreduce_per_thread
,
1
>
,
Sequence
<
0
,
1
,
2
>
,
Sequence
<
0
,
1
,
2
>
,
...
@@ -1020,24 +1021,24 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
...
@@ -1020,24 +1021,24 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
FGTransferScalarPerVector
,
FGTransferScalarPerVector
,
InMemoryDataOperationEnum
::
Set
,
InMemoryDataOperationEnum
::
Set
,
1
,
1
,
false
>
{
g
_grid_desc_mblock_mperblock_nblock
,
false
>
{
var
_grid_desc_mblock_mperblock_nblock
,
make_multi_index
(
block_work_idx
[
I0
],
// mblock
make_multi_index
(
block_work_idx
[
I0
],
// mblock
shuffleMPerBlock
*
i
+
shuffleMPerBlock
*
i
+
c_reduce_thread_data_idx_begin
[
I0
],
// mperblock
c_reduce_thread_data_idx_begin
[
I0
],
// mperblock
block_work_idx
[
I1
]),
// nblock
block_work_idx
[
I1
]),
// nblock
tensor_operation
::
element_wise
::
PassThrough
{}};
tensor_operation
::
element_wise
::
PassThrough
{}};
f
_thread_copy_vgpr_to_global
.
Run
(
thread_welford_desc_I_m_I
,
mean
_thread_copy_vgpr_to_global
.
Run
(
thread_welford_desc_I_m_I
,
make_tuple
(
I0
,
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
),
mean_thread_buf
,
mean_thread_buf
,
f
_grid_desc_mblock_mperblock_nblock
,
mean
_grid_desc_mblock_mperblock_nblock
,
f
_grid_buf
);
mean
_grid_buf
);
g
_thread_copy_vgpr_to_global
.
Run
(
thread_welford_desc_I_m_I
,
var
_thread_copy_vgpr_to_global
.
Run
(
thread_welford_desc_I_m_I
,
make_tuple
(
I0
,
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
),
var_thread_buf
,
var_thread_buf
,
g
_grid_desc_mblock_mperblock_nblock
,
var
_grid_desc_mblock_mperblock_nblock
,
g
_grid_buf
);
var
_grid_buf
);
});
});
}
// shuffle C + Ds + welford + write out
}
// shuffle C + Ds + welford + write out
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment