gaoqiong / composable_kernel / Commits

Commit ad2f82ac, authored Nov 30, 2022 by rocking
Parent: 1d7290fb

    Sync code, prepare to test on MI200

Showing 3 changed files with 253 additions and 31 deletions (+253 −31)
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp   +21 −12
include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp   +13 −10
include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp   +219 −9
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
@@ -284,7 +284,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
     }

     template <typename LayOut>
-    static auto MakeGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t Stride)
+    static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t Stride)
     {
         const auto grid_desc_mraw_nraw = [&]() {
             if constexpr(is_same<tensor_layout::gemm::RowMajor, LayOut>::value)
@@ -308,11 +308,19 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
             [&](auto i) {
                 using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;

-                return DeviceOp::MakeGridDescriptor_M_N<DLayout>(MRaws[i], NRaws[i], DsStride[i]);
+                return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(MRaws[i], NRaws[i], DsStride[i]);
             },
             Number<NumDTensor>{});
     }

+    static auto MakeMeanVarCountGridDescriptor_M_NBlock(index_t M, index_t NBlock)
+    {
+        const auto grid_desc_m_n = make_naive_tensor_descriptor_packed(make_tuple(M, NBlock));
+
+        // TODO - padding according to MNperBlock of Gemm and Layernorm
+        return grid_desc_m_n;
+    }
+
     static auto MakeDescriptor_M(index_t MRaw)
     {
         const auto grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw));
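For orientation: make_naive_tensor_descriptor_packed builds a dense view over the given lengths, so the new (M, NBlock) descriptor holds one mean/var/count slot per (row, N-block) pair. A minimal plain-C++ sketch of the indexing this implies, assuming row-major packing (illustrative only, not the CK descriptor API):

// Packed (dense) 2-D indexing as implied by
// make_naive_tensor_descriptor_packed(make_tuple(M, NBlock)):
// element (m, nblock) lives at linear offset m * NBlock + nblock,
// and the element space size is M * NBlock.
struct PackedDesc2D
{
    int M;
    int NBlock;

    int CalculateOffset(int m, int nblock) const { return m * NBlock + nblock; }
    int GetElementSpaceSize() const { return M * NBlock; }
};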
@@ -366,9 +374,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
     using AGridDesc_M_K  = decltype(MakeAGridDescriptor_M_K(1, 1, 1));
     using BGridDesc_N_K  = decltype(MakeBGridDescriptor_N_K(1, 1, 1));
     using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
-    using MeanVarCountGridDesc_M_N = decltype(MakeGridDescriptor_M_N<ELayout>(1, 1, 1));
+    using MeanVarCountGridDesc_M_N = decltype(MakeMeanVarCountGridDescriptor_M_NBlock(1, 1));
     using GammaBetaGridDesc_N = decltype(MakeDescriptor_N(1));
-    using EHGridDesc_M_N      = decltype(MakeGridDescriptor_M_N<HLayout>(1, 1, 1));
+    using EHGridDesc_M_N      = decltype(MakeEGridDescriptor_M_N<HLayout>(1, 1, 1));

     using GridwiseGemmWelford = GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle<
         ADataType, // TODO: distinguish A/B datatype
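The dummy-argument decltype pattern in this hunk is how the device op names descriptor types before any real problem size exists. A self-contained illustration of the same idiom, with a hypothetical MakeDesc standing in for the CK factory:

#include <tuple>

// Hypothetical stand-in for a CK-style descriptor factory: the concrete
// return type is awkward to spell out, so type aliases are derived from a
// decltype of a call with dummy "1" arguments, exactly as the hunk above does.
template <typename Layout>
auto MakeDesc(int MRaw, int NRaw, int Stride)
{
    return std::make_tuple(MRaw, NRaw, Stride); // placeholder for a real descriptor
}

struct RowMajor
{
};

// Only the type of the unevaluated dummy call is used; nothing runs here.
using DescType = decltype(MakeDesc<RowMajor>(1, 1, 1));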
@@ -479,11 +487,11 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
               a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(MRaw, KRaw, StrideA)},
               b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(KRaw, NRaw, StrideB)},
               ds_grid_desc_m_n_{},
-              e_grid_desc_m_n_{DeviceOp::MakeGridDescriptor_M_N<ELayout>(MRaw, NRaw, StrideH)},
+              e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(MRaw, NRaw, StrideH)},
               mean_var_count_grid_desc_m_n_{},
               gamma_grid_desc_n_{DeviceOp::MakeDescriptor_N(NRaw)},
               beta_grid_desc_n_{DeviceOp::MakeDescriptor_N(NRaw)},
-              h_grid_desc_m_n_{DeviceOp::MakeGridDescriptor_M_N<HLayout>(MRaw, NRaw, StrideH)},
+              h_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<HLayout>(MRaw, NRaw, StrideH)},
               a_grid_desc_ak0_m_ak1_{
                   GridwiseGemmWelford::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
               b_grid_desc_bk0_n_bk1_{
@@ -497,7 +505,10 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
               epsilon_{epsilon}
         {
             mean_var_count_grid_desc_m_n_ =
-                DeviceOp::MakeGridDescriptor_M_N<ELayout>(MRaw, gemm_nblock_, gemm_nblock_);
+                DeviceOp::MakeMeanVarCountGridDescriptor_M_NBlock(MRaw, gemm_nblock_);
+
+            int s = mean_var_count_grid_desc_m_n_.GetElementSpaceSize();
+            printf("mean_var_count_grid_desc_m_n.GetElementSpaceSize() = %d\n", s);

             hip_check_error(hipMalloc(&p_e_grid_, sizeof(EDataType) * MRaw * NRaw));
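The hip_check_error wrapper around hipMalloc turns a failed allocation into a hard error instead of a silently bad pointer. A plausible shape for such a guard, as a sketch rather than CK's actual definition:

#include <hip/hip_runtime.h>
#include <cstdio>
#include <cstdlib>

// Illustrative error guard in the spirit of hip_check_error; aborts with a
// readable message when a HIP runtime call does not return hipSuccess.
inline void check_hip(hipError_t err, const char* what)
{
    if(err != hipSuccess)
    {
        std::fprintf(stderr, "%s failed: %s\n", what, hipGetErrorString(err));
        std::abort();
    }
}

// Usage mirroring the allocation in the constructor above (names assumed):
// check_hip(hipMalloc(&p_e_grid_, sizeof(EDataType) * MRaw * NRaw), "hipMalloc");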
@@ -518,7 +529,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
                 // D desc
                 ds_grid_desc_m_n_(i) =
-                    DeviceOp::MakeGridDescriptor_M_N<DLayout>(MRaw, NRaw, StrideDs[i]);
+                    DeviceOp::MakeEGridDescriptor_M_N<DLayout>(MRaw, NRaw, StrideDs[i]);
             });

             // populate desc for Ds/E/F/G
@@ -526,7 +537,6 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
                                                    b_grid_desc_n_k_,
                                                    ds_grid_desc_m_n_,
                                                    e_grid_desc_m_n_,
-                                                   mean_var_count_grid_desc_m_n_,
                                                    block_2_etile_map_))
             {
                 ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
@@ -612,7 +622,6 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
                                                    arg.b_grid_desc_n_k_,
                                                    arg.ds_grid_desc_m_n_,
                                                    arg.e_grid_desc_m_n_,
-                                                   arg.mean_var_count_grid_desc_m_n_,
                                                    arg.block_2_etile_map_))
             {
                 throw std::runtime_error("wrong! GridwiseGemmWelford has invalid setting");
@@ -694,7 +703,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
                                              LayernormThreadClusterSize_M_N::At(I1)) /
                 LayernormThreadClusterSize_M_N::At(I1);

-            index_t numXBlockTileIteration_N =
+            index_t numEBlockTileIteration_N =
                 math::integer_least_multiple(N, LayernormBlockTileSize_M_N::At(I1)) /
                 LayernormBlockTileSize_M_N::At(I1);
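A quick check of the iteration-count arithmetic, assuming integer_least_multiple(a, b) rounds a up to the next multiple of b (illustrative sizes, not a tuned configuration):

// Worked example: dividing the rounded-up extent by the tile size yields
// the number of N-direction block-tile iterations, i.e. ceil(N / tile).
constexpr int N    = 1000;
constexpr int tile = 256;

constexpr int least_multiple = ((N + tile - 1) / tile) * tile; // 1024
constexpr int numIterations  = least_multiple / tile;          // 4

static_assert(numIterations == 4, "1000 elements need 4 tiles of 256");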
@@ -717,7 +726,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
                                               arg.beta_grid_desc_n_,
                                               arg.gemm_nblock_,
                                               numMeanVarCountBlockTileIteration_N,
-                                              numXBlockTileIteration_N,
+                                              numEBlockTileIteration_N,
                                               arg.epsilon_);

             return avg_time;
include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp
@@ -269,13 +269,11 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
     // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
     template <typename Block2ETileMap>
     __host__ __device__ static constexpr bool
     CheckValidity(const AGridDesc_M_K& a_grid_desc_m_k,
                   const BGridDesc_N_K& b_grid_desc_n_k,
                   const DsGridDesc_M_N& ds_grid_desc_m_n,
                   const EGridDesc_M_N& e_grid_desc_m_n,
-                  const MeanVarCountGridDesc_M_N& mean_var_count_grid_desc_m_n,
                   const Block2ETileMap& block_2_etile_map)
     {
         static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
                           (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
@@ -286,9 +284,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
         const auto K = a_grid_desc_m_k.GetLength(I1);

         // check consistency of desc
-        if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) &&
-             M == mean_var_count_grid_desc_m_n.GetLength(I0) &&
-             N / NPerBlock == mean_var_count_grid_desc_m_n.GetLength(I1)))
+        if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1)))
         {
             return false;
         }
@@ -997,6 +993,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
             static_for<0, PostShuffleThreadSliceSize_M, 1>{}([&](auto j) {
                 block_sync_lds();
+                count_thread_buf = threadwise_welfords(i).cur_count_;
                 BlockwiseWelford::Run(
                     mean_thread_buf(j), var_thread_buf(j), count_thread_buf(j));
             });
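The BlockwiseWelford::Run call above merges per-thread partial (mean, var, count) triples across the block. For reference, the underlying pairwise combination (Chan et al.) looks like the following; a minimal CPU sketch with illustrative names, not the CK BlockwiseWelford API:

// Merge Welford partial result (mean_b, var_b, count_b) into
// (mean_a, var_a, count_a); var here is the population variance M2 / count.
void welford_merge(float& mean_a, float& var_a, int& count_a,
                   float mean_b, float var_b, int count_b)
{
    const int count = count_a + count_b;
    if(count == 0)
        return;

    const float delta = mean_b - mean_a;

    // Recover the sums of squared deviations, add the cross-term that
    // accounts for the gap between the two means, then renormalize.
    var_a = (var_a * count_a + var_b * count_b +
             delta * delta * count_a * count_b / count) /
            count;

    mean_a += delta * count_b / count;
    count_a = count;
}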
@@ -1083,6 +1080,12 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
                     count_thread_buf,
                     mean_var_count_grid_desc_mblock_mperblock_nblock,
                     welford_count_grid_buf);

+                float mean = static_cast<float>(mean_thread_buf(I0));
+                float var  = static_cast<float>(var_thread_buf(I0));
+                int count  = count_thread_buf(I0);
+
+                if(i == 0 && get_thread_global_1d_id() == 0)
+                    printf("1st kernel mean = %f, var = %f, count = %d\n", mean, var, count);
             });
         } // shuffle C + Ds + welford + write out
include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp
@@ -34,15 +34,23 @@ template <typename EDataType,
           index_t NThreadClusterSize,
           index_t MThreadSliceSize,
           index_t NThreadSliceSize,
-          index_t ESrcYDstVectorDim,
+          index_t ESrcHDstVectorDim,
           index_t ESrcVectorSize,
-          index_t YDstVectorSize,
+          index_t HDstVectorSize,
           index_t GammaSrcVectorSize,
           index_t BetaSrcVectorSize,
           index_t MeanVarSrcDstVectorSize>
 struct GridwiseWelfordSecondHalfLayernorm2d
 {
-    static constexpr bool reorder_thread_cluster = (ESrcYDstVectorDim == 0);
+    static_assert((ESrcHDstVectorDim == 0 && MThreadSliceSize % ESrcVectorSize == 0) ||
+                      (ESrcHDstVectorDim == 1 && NThreadSliceSize % ESrcVectorSize == 0),
+                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");
+
+    static_assert((ESrcHDstVectorDim == 0 && MThreadSliceSize % HDstVectorSize == 0) ||
+                      (ESrcHDstVectorDim == 1 && NThreadSliceSize % HDstVectorSize == 0),
+                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");
+
+    static constexpr bool reorder_thread_cluster = (ESrcHDstVectorDim == 0);

     using ThreadClusterLengths_M_N = Sequence<MThreadClusterSize, NThreadClusterSize>;
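Reading the new guards with concrete numbers: vectorized access along dim 1 (N) requires the per-thread N slice to split evenly into vectors, while an M-direction vector dim constrains the M slice instead. A standalone restatement with illustrative sizes:

// Example instantiation satisfying the static_asserts above:
// an 8-wide N slice splits into two 4-wide source vectors.
constexpr int ESrcHDstVectorDim = 1;
constexpr int MThreadSliceSize  = 4;
constexpr int NThreadSliceSize  = 8;
constexpr int ESrcVectorSize    = 4;

static_assert((ESrcHDstVectorDim == 0 && MThreadSliceSize % ESrcVectorSize == 0) ||
                  (ESrcHDstVectorDim == 1 && NThreadSliceSize % ESrcVectorSize == 0),
              "slice size must be a whole number of vectors");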
@@ -73,8 +81,14 @@ struct GridwiseWelfordSecondHalfLayernorm2d
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};

     static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
     static constexpr index_t N_BlockTileSize = NThreadClusterSize * NThreadSliceSize;
+    static constexpr index_t N_BlockTileStepSize = NThreadClusterSize * ESrcVectorSize;
+
+    static constexpr auto EThreadBufferNumber     = Number<NThreadSliceSize / ESrcVectorSize>{};
+    static constexpr auto GammaThreadBufferNumber = Number<NThreadSliceSize / ESrcVectorSize>{};
+    static constexpr auto BetaThreadBufferNumber  = Number<NThreadSliceSize / ESrcVectorSize>{};
+    static constexpr auto HThreadBufferNumber     = Number<NThreadSliceSize / ESrcVectorSize>{};

     __device__ static void Run(const EDataType* __restrict__ p_e_grid,
                                const MeanDataType* __restrict__ p_in_welford_mean_grid,
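As a concrete reading of the new tile bookkeeping (made-up sizes, not a tuned configuration), the constants fit together like this:

// Illustrative check: each thread covers its N slice in
// EThreadBufferNumber vector-wide steps of N_BlockTileStepSize elements.
constexpr int NThreadClusterSize = 8;
constexpr int NThreadSliceSize   = 8;
constexpr int ESrcVectorSize     = 4;

constexpr int N_BlockTileSize     = NThreadClusterSize * NThreadSliceSize; // 64 elements per block tile in N
constexpr int N_BlockTileStepSize = NThreadClusterSize * ESrcVectorSize;   // 32 elements advanced per vector step
constexpr int EThreadBufferNumber = NThreadSliceSize / ESrcVectorSize;     // 2 vector loads per thread slice

static_assert(N_BlockTileSize == EThreadBufferNumber * N_BlockTileStepSize,
              "the vector steps exactly tile the N block tile");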
@@ -89,8 +103,8 @@ struct GridwiseWelfordSecondHalfLayernorm2d
                                const GammaBetaGridDesc_N& gamma_grid_desc_n,
                                const GammaBetaGridDesc_N& beta_grid_desc_n,
                                index_t gemm_nblock_,
-                               index_t num_mean_var_count_k_block_tile_iteration,
-                               index_t num_xy_k_block_tile_iteration,
+                               index_t numMeanVarCountBlockTileIteration_N,
+                               index_t numEBlockTileIteration_N,
                                ComputeDataType epsilon)
     {
         ignore = p_e_grid;
@@ -106,10 +120,206 @@ struct GridwiseWelfordSecondHalfLayernorm2d
         ignore = gamma_grid_desc_n;
         ignore = beta_grid_desc_n;
         ignore = gemm_nblock_;
-        ignore = num_mean_var_count_k_block_tile_iteration;
-        ignore = num_xy_k_block_tile_iteration;
+        ignore = numMeanVarCountBlockTileIteration_N;
+        ignore = numEBlockTileIteration_N;
         ignore = epsilon;

+        // float mean = static_cast<float>(p_in_welford_mean_grid[0]);
+        // float var = static_cast<float>(p_in_welford_var_grid[0]);
+        // int count = p_in_welford_count_grid[0];
+        // if(get_thread_global_1d_id() == 0)
+        //     printf("kernel mean = %f, var = %f, count = %d\n", mean, var, count);
+
+        float mean = static_cast<float>(p_in_welford_mean_grid[0]);
+        if(get_thread_global_1d_id() == 0)
+            printf("mean = %f\n", mean);
+
+        int s = static_cast<int>(mean_var_count_grid_desc_m_n.GetElementSpaceSize());
+        if(get_thread_global_1d_id() == 0)
+            printf("mean_var_count_grid_desc_m_n.GetElementSpaceSize() = %d\n", s);
// using ThreadBufferLengths_1_1 = Sequence<1, 1>;
// constexpr auto thread_buffer_desc_1_1 =
// make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<1>{}));
// constexpr auto grid_desc_1_1 =
// make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<1>{}));
// const auto mean_grid = make_dynamic_buffer<AddressSpaceEnum::Global>(
// p_in_welford_mean_grid, mean_var_count_grid_desc_m_n.GetElementSpaceSize());
// StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, 1, true> mean_thread;
// float mean1 = (mean_grid.template Get<MeanDataType>(0, true));
// if(get_thread_global_1d_id() == 0)
// printf("global mean = %f\n", mean1);
// auto threadwise_mean_load_m_k =
// ThreadwiseTensorSliceTransfer_v2<MeanDataType,
// ComputeDataType,
// decltype(mean_var_count_grid_desc_m_n),
// decltype(thread_buffer_desc_1_1),
// ThreadBufferLengths_1_1,
// Sequence<0, 1>,
// 1,
// 1,
// 1,
// true>(mean_var_count_grid_desc_m_n,
// make_multi_index(0, 0));
// threadwise_mean_load_m_k.Run(mean_var_count_grid_desc_m_n,
// mean_grid,
// thread_buffer_desc_1_1,
// make_tuple(I0, I0),
// mean_thread);
// if(get_thread_global_1d_id() == 0)
// printf("threadwise mean = %f\n", mean_thread(Number<0>{}));
// // Thread/Block id
// const index_t thread_local_id = get_thread_local_1d_id();
// const index_t block_global_id = get_block_1d_id();
// const auto thread_cluster_idx =
// thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
// const auto thread_m_cluster_id = thread_cluster_idx[I0];
// const auto thread_n_cluster_id = thread_cluster_idx[I1];
// // step1: Merge mean and variance
// using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;
// constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed(
// make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));
// auto threadwise_mean_load_m_k =
// ThreadwiseTensorSliceTransfer_v2<MeanDataType,
// ComputeDataType,
// MeanVarCountGridDesc_M_N,
// decltype(thread_buffer_desc_m_1),
// ThreadBufferLengths_M_1,
// Sequence<0, 1>,
// 1,
// 1,
// 1,
// true>(mean_var_count_grid_desc_m_n,
// make_multi_index(0, 0));
// auto threadwise_var_load_m_k =
// ThreadwiseTensorSliceTransfer_v2<VarDataType,
// ComputeDataType,
// MeanVarCountGridDesc_M_N,
// decltype(thread_buffer_desc_m_1),
// ThreadBufferLengths_M_1,
// Sequence<0, 1>,
// 1,
// 1,
// 1,
// true>(
// mean_var_count_grid_desc_m_n,
// make_multi_index(block_global_id * M_BlockTileSize +
// thread_m_cluster_id * MThreadSliceSize,
// thread_n_cluster_id));
// auto threadwise_count_load_m_k =
// ThreadwiseTensorSliceTransfer_v2<int32_t,
// int32_t,
// MeanVarCountGridDesc_M_N,
// decltype(thread_buffer_desc_m_1),
// ThreadBufferLengths_M_1,
// Sequence<0, 1>,
// 1,
// 1,
// 1,
// true>(
// mean_var_count_grid_desc_m_n,
// make_multi_index(block_global_id * M_BlockTileSize +
// thread_m_cluster_id * MThreadSliceSize,
// thread_n_cluster_id));
// const auto welford_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
// p_in_welford_mean_grid, mean_var_count_grid_desc_m_n.GetElementSpaceSize());
// const auto welford_var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
// p_in_welford_var_grid, mean_var_count_grid_desc_m_n.GetElementSpaceSize());
// const auto welford_count_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
// p_in_welford_count_grid, mean_var_count_grid_desc_m_n.GetElementSpaceSize());
// StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
// in_welford_mean_thread_buf;
// StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
// in_welford_var_thread_buf;
// StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
// in_welford_count_thread_buf;
// StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
// welford_mean_thread_buf;
// StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize, true>
// welford_var_thread_buf;
// StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
// welford_count_thread_buf;
// constexpr auto mean_var_count_thread_copy_step_m_n =
// make_multi_index(0, NThreadClusterSize);
// static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
// welford_mean_thread_buf(I) = type_convert<ComputeDataType>(0.0f);
// welford_var_thread_buf(I) = type_convert<ComputeDataType>(0.0f);
// welford_count_thread_buf(I) = 0;
// });
// for(index_t reducedTiles = 0; reducedTiles < numMeanVarCountBlockTileIteration_N;
// ++reducedTiles)
// {
// threadwise_mean_load_m_k.Run(mean_var_count_grid_desc_m_n,
// welford_mean_global_val_buf,
// thread_buffer_desc_m_1,
// make_tuple(I0, I0),
// in_welford_mean_thread_buf);
// // threadwise_var_load_m_k.Run(mean_var_count_grid_desc_m_n,
// // welford_var_global_val_buf,
// // thread_buffer_desc_m_1,
// // make_tuple(I0, I0),
// // in_welford_var_thread_buf);
// // threadwise_count_load_m_k.Run(mean_var_count_grid_desc_m_n,
// // welford_count_global_val_buf,
// // thread_buffer_desc_m_1,
// // make_tuple(I0, I0),
// // in_welford_count_thread_buf);
// // ThreadwiseWelford::Run(in_welford_mean_thread_buf,
// // in_welford_var_thread_buf,
// // in_welford_count_thread_buf,
// // welford_mean_thread_buf,
// // welford_var_thread_buf,
// // welford_count_thread_buf);
// // threadwise_mean_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_n,
// // mean_var_count_thread_copy_step_m_n);
// // threadwise_var_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_n,
// // mean_var_count_thread_copy_step_m_n);
// // threadwise_count_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_n,
// // mean_var_count_thread_copy_step_m_n);
// static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
// if(get_thread_global_1d_id() == 0)
// printf("mean = %f, var = %f, count = %d\n",
// in_welford_mean_thread_buf(I),
// in_welford_var_thread_buf(I),
// in_welford_count_thread_buf(I));
// });
// }
// static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
// if constexpr(I > 0)
// block_sync_lds();
// if(get_thread_global_1d_id() == 0)
// printf("count = %d\n", welford_count_thread_buf(I));
// BlockwiseWelford::Run(
// welford_mean_thread_buf(I), welford_var_thread_buf(I),
// welford_count_thread_buf(I));
// });
     } // run
 };
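For orientation, once the debug scaffolding above is replaced by the real implementation, the second-half kernel's job is the standard layernorm transform over each row of E, using the merged Welford statistics. A plain CPU reference of that semantics, assuming per-row mean/var and per-column gamma/beta as in the descriptors above (names illustrative):

#include <cmath>

// Reference: normalize each row of e with its merged Welford mean/var,
// then scale by gamma and shift by beta, writing the result to h.
void layernorm2d_reference(const float* e, const float* gamma, const float* beta,
                           float* h, const float* mean, const float* var,
                           int M, int N, float epsilon)
{
    for(int m = 0; m < M; ++m)
    {
        const float inv_std = 1.0f / std::sqrt(var[m] + epsilon);

        for(int n = 0; n < N; ++n)
            h[m * N + n] = gamma[n] * (e[m * N + n] - mean[m]) * inv_std + beta[n];
    }
}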