Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
003ec407
Commit
003ec407
authored
Nov 28, 2022
by
rocking
Browse files
Add welford count
parent
b7f500f0
Changes
2
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
197 additions
and
169 deletions
+197
-169
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
.../device/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
+106
-114
include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp
...dwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp
+91
-55
No files found.
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
View file @
003ec407
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp
View file @
003ec407
...
@@ -47,8 +47,7 @@ template <typename ABDataType,
...
@@ -47,8 +47,7 @@ template <typename ABDataType,
typename
BGridDesc_N_K
,
typename
BGridDesc_N_K
,
typename
DsGridDesc_M_N
,
typename
DsGridDesc_M_N
,
typename
EGridDesc_M_N
,
typename
EGridDesc_M_N
,
typename
MeanGridDesc_M_N
,
typename
MeanVarCountGridDesc_M_N
,
typename
VarGridDesc_M_N
,
index_t
NumGemmKPrefetchStage
,
index_t
NumGemmKPrefetchStage
,
index_t
BlockSize
,
index_t
BlockSize
,
index_t
MPerBlock
,
index_t
MPerBlock
,
...
@@ -80,7 +79,6 @@ template <typename ABDataType,
...
@@ -80,7 +79,6 @@ template <typename ABDataType,
index_t
CShuffleNXdlPerWavePerShuffle
,
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
PostShuffleThreadClusterSize_M_N
,
typename
PostShuffleThreadClusterSize_M_N
,
index_t
PostShuffleScalarPerVector
,
index_t
PostShuffleScalarPerVector
,
index_t
MeanVarTransferScalarPerVector
,
LoopScheduler
LoopSched
>
LoopScheduler
LoopSched
>
struct
GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
struct
GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
{
{
...
@@ -242,10 +240,10 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
...
@@ -242,10 +240,10 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
Number
<
NumDTensor
>
{});
Number
<
NumDTensor
>
{});
}
}
// TODO - MakeMeanVarGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
// TODO - MakeMeanVar
Count
GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
template
<
typename
GridDescriptor_M_N
>
template
<
typename
GridDescriptor_M_N
>
__host__
__device__
static
constexpr
auto
__host__
__device__
static
constexpr
auto
MakeMeanVarGridDescriptor_MBlock_MPerBlock_NBlock
(
const
GridDescriptor_M_N
&
grid_desc_m_n
)
MakeMeanVar
Count
GridDescriptor_MBlock_MPerBlock_NBlock
(
const
GridDescriptor_M_N
&
grid_desc_m_n
)
{
{
const
auto
M
=
grid_desc_m_n
.
GetLength
(
I0
);
const
auto
M
=
grid_desc_m_n
.
GetLength
(
I0
);
const
auto
NBlock
=
grid_desc_m_n
.
GetLength
(
I1
);
const
auto
NBlock
=
grid_desc_m_n
.
GetLength
(
I1
);
...
@@ -276,8 +274,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
...
@@ -276,8 +274,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
const
BGridDesc_N_K
&
b_grid_desc_n_k
,
const
BGridDesc_N_K
&
b_grid_desc_n_k
,
const
DsGridDesc_M_N
&
ds_grid_desc_m_n
,
const
DsGridDesc_M_N
&
ds_grid_desc_m_n
,
const
EGridDesc_M_N
&
e_grid_desc_m_n
,
const
EGridDesc_M_N
&
e_grid_desc_m_n
,
const
MeanGridDesc_M_N
&
mean_grid_desc_m_n
,
const
MeanVarCountGridDesc_M_N
&
mean_var_count_grid_desc_m_n
,
const
VarGridDesc_M_N
&
var_grid_desc_m_n
,
const
Block2ETileMap
&
block_2_etile_map
)
const
Block2ETileMap
&
block_2_etile_map
)
{
{
static_assert
((
MPerBlock
%
(
MPerXdl
*
MXdlPerWave
)
==
0
)
&&
static_assert
((
MPerBlock
%
(
MPerXdl
*
MXdlPerWave
)
==
0
)
&&
...
@@ -290,9 +287,8 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
...
@@ -290,9 +287,8 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
// check consistency of desc
// check consistency of desc
if
(
!
(
M
==
e_grid_desc_m_n
.
GetLength
(
I0
)
&&
N
==
e_grid_desc_m_n
.
GetLength
(
I1
)
&&
if
(
!
(
M
==
e_grid_desc_m_n
.
GetLength
(
I0
)
&&
N
==
e_grid_desc_m_n
.
GetLength
(
I1
)
&&
M
==
mean_grid_desc_m_n
.
GetLength
(
I0
)
&&
M
==
var_grid_desc_m_n
.
GetLength
(
I0
)
&&
M
==
mean_var_count_grid_desc_m_n
.
GetLength
(
I0
)
&&
N
/
NPerBlock
==
mean_grid_desc_m_n
.
GetLength
(
I1
)
&&
N
/
NPerBlock
==
mean_var_count_grid_desc_m_n
.
GetLength
(
I1
)))
N
/
NPerBlock
==
var_grid_desc_m_n
.
GetLength
(
I1
)))
{
{
return
false
;
return
false
;
}
}
...
@@ -356,10 +352,8 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
...
@@ -356,10 +352,8 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
remove_cvref_t
<
decltype
(
MakeDefaultBGridDescriptor_BK0_N_BK1
(
BGridDesc_N_K
{}))
>
;
remove_cvref_t
<
decltype
(
MakeDefaultBGridDescriptor_BK0_N_BK1
(
BGridDesc_N_K
{}))
>
;
using
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
using
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
EGridDesc_M_N
{}))
>
;
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
EGridDesc_M_N
{}))
>
;
using
MeanGridDescriptor_MBlock_MPerBlock_NBlock
=
remove_cvref_t
<
decltype
(
using
MeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock
=
remove_cvref_t
<
decltype
(
MakeMeanVarGridDescriptor_MBlock_MPerBlock_NBlock
(
MeanGridDesc_M_N
{}))
>
;
MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock
(
MeanVarCountGridDesc_M_N
{}))
>
;
using
VarGridDescriptor_MBlock_MPerBlock_NBlock
=
remove_cvref_t
<
decltype
(
MakeMeanVarGridDescriptor_MBlock_MPerBlock_NBlock
(
VarGridDesc_M_N
{}))
>
;
using
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
using
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
DsGridDesc_M_N
{}))
>
;
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
DsGridDesc_M_N
{}))
>
;
...
@@ -372,13 +366,13 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
...
@@ -372,13 +366,13 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
typename
AGridDesc_AK0_M_AK1
,
typename
AGridDesc_AK0_M_AK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
Block2ETileMap
>
typename
Block2ETileMap
>
__device__
static
void
__device__
static
void
Run
(
const
ABDataType
*
__restrict__
p_a_grid
,
Run
(
const
ABDataType
*
__restrict__
p_a_grid
,
const
ABDataType
*
__restrict__
p_b_grid
,
const
ABDataType
*
__restrict__
p_b_grid
,
DsGridPointer
p_ds_grid
,
DsGridPointer
p_ds_grid
,
EDataType
*
__restrict__
p_e_grid
,
EDataType
*
__restrict__
p_e_grid
,
MeanDataType
*
__restrict__
p_mean_grid
,
MeanDataType
*
__restrict__
p_welford_mean_grid
,
VarDataType
*
__restrict__
p_var_grid
,
VarDataType
*
__restrict__
p_welford_var_grid
,
int32_t
*
__restrict__
p_welford_count
,
void
*
__restrict__
p_shared
,
void
*
__restrict__
p_shared
,
const
AElementwiseOperation
&
a_element_op
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
BElementwiseOperation
&
b_element_op
,
...
@@ -389,8 +383,8 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
...
@@ -389,8 +383,8 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
const
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
&
const
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
&
e_grid_desc_mblock_mperblock_nblock_nperblock
,
e_grid_desc_mblock_mperblock_nblock_nperblock
,
const
MeanGridDescriptor_MBlock_MPerBlock_NBlock
&
mean_grid_desc_mblock_mperblock_nblock
,
const
Mean
VarCount
GridDescriptor_MBlock_MPerBlock_NBlock
&
const
VarGridDescriptor_MBlock_MPerBlock_NBlock
&
var
_grid_desc_mblock_mperblock_nblock
,
mean_var_count
_grid_desc_mblock_mperblock_nblock
,
const
Block2ETileMap
&
block_2_etile_map
)
const
Block2ETileMap
&
block_2_etile_map
)
{
{
const
auto
a_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
const
auto
a_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
...
@@ -411,10 +405,16 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
...
@@ -411,10 +405,16 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
p_e_grid
,
e_grid_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
p_e_grid
,
e_grid_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
auto
mean_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
auto
mean_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_mean_grid
,
mean_grid_desc_mblock_mperblock_nblock
.
GetElementSpaceSize
());
p_welford_mean_grid
,
mean_var_count_grid_desc_mblock_mperblock_nblock
.
GetElementSpaceSize
());
auto
var_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
auto
var_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_var_grid
,
var_grid_desc_mblock_mperblock_nblock
.
GetElementSpaceSize
());
p_welford_var_grid
,
mean_var_count_grid_desc_mblock_mperblock_nblock
.
GetElementSpaceSize
());
auto
welford_count_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_welford_count
,
mean_var_count_grid_desc_mblock_mperblock_nblock
.
GetElementSpaceSize
());
// divide block work by [M, N]
// divide block work by [M, N]
const
auto
block_work_idx
=
const
auto
block_work_idx
=
...
@@ -871,9 +871,14 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
...
@@ -871,9 +871,14 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
decltype
(
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
>
(
decltype
(
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
>
(
thread_welford_dst_desc_m
.
GetElementSpaceSize
()));
thread_welford_dst_desc_m
.
GetElementSpaceSize
()));
using
welford_count_vgpr_type
=
decltype
(
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
int32_t
>
(
thread_welford_dst_desc_m
.
GetElementSpaceSize
()));
Array
<
ThreadwiseWelford
,
num_shuffleM
>
threadwise_welfords
;
Array
<
ThreadwiseWelford
,
num_shuffleM
>
threadwise_welfords
;
Array
<
mean_var_vgpr_type
,
num_shuffleM
>
mean_thread_bufs
;
Array
<
mean_var_vgpr_type
,
num_shuffleM
>
mean_thread_bufs
;
Array
<
mean_var_vgpr_type
,
num_shuffleM
>
var_thread_bufs
;
Array
<
mean_var_vgpr_type
,
num_shuffleM
>
var_thread_bufs
;
Array
<
welford_count_vgpr_type
,
num_shuffleM
>
welford_count_thread_bufs
;
static_for
<
0
,
num_shuffleM
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
num_shuffleM
,
1
>
{}([
&
](
auto
i
)
{
// TODO - padding
// TODO - padding
...
@@ -884,9 +889,13 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
...
@@ -884,9 +889,13 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
var_thread_bufs
(
i
)
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
>
(
var_thread_bufs
(
i
)
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
>
(
thread_welford_dst_desc_m
.
GetElementSpaceSize
());
thread_welford_dst_desc_m
.
GetElementSpaceSize
());
welford_count_thread_bufs
(
i
)
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
int32_t
>
(
thread_welford_dst_desc_m
.
GetElementSpaceSize
());
static_for
<
0
,
PostShuffleThreadSliceSize_M
,
1
>
{}([
&
](
auto
j
)
{
static_for
<
0
,
PostShuffleThreadSliceSize_M
,
1
>
{}([
&
](
auto
j
)
{
mean_thread_bufs
(
i
)(
j
)
=
type_convert
<
AccDataType
>
(
0.0
f
);
mean_thread_bufs
(
i
)(
j
)
=
type_convert
<
AccDataType
>
(
0.0
f
);
var_thread_bufs
(
i
)(
j
)
=
type_convert
<
AccDataType
>
(
0.0
f
);
var_thread_bufs
(
i
)(
j
)
=
type_convert
<
AccDataType
>
(
0.0
f
);
welford_count_thread_bufs
(
i
)(
j
)
=
0
;
});
});
});
});
...
@@ -984,11 +993,12 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
...
@@ -984,11 +993,12 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
static_for
<
0
,
num_shuffleM
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
num_shuffleM
,
1
>
{}([
&
](
auto
i
)
{
auto
&
mean_thread_buf
=
mean_thread_bufs
(
i
);
auto
&
mean_thread_buf
=
mean_thread_bufs
(
i
);
auto
&
var_thread_buf
=
var_thread_bufs
(
i
);
auto
&
var_thread_buf
=
var_thread_bufs
(
i
);
int
count
=
threadwise_welfords
(
i
).
cur_count_
;
auto
&
count
_thread_buf
=
welford_count_thread_bufs
(
i
)
;
static_for
<
0
,
PostShuffleThreadSliceSize_M
,
1
>
{}([
&
](
auto
j
)
{
static_for
<
0
,
PostShuffleThreadSliceSize_M
,
1
>
{}([
&
](
auto
j
)
{
block_sync_lds
();
block_sync_lds
();
BlockwiseWelford
::
Run
(
mean_thread_buf
(
j
),
var_thread_buf
(
j
),
count
);
BlockwiseWelford
::
Run
(
mean_thread_buf
(
j
),
var_thread_buf
(
j
),
count_thread_buf
(
j
));
});
});
constexpr
auto
thread_welford_desc_I_m_I
=
make_naive_tensor_descriptor_packed
(
constexpr
auto
thread_welford_desc_I_m_I
=
make_naive_tensor_descriptor_packed
(
...
@@ -997,20 +1007,19 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
...
@@ -997,20 +1007,19 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
constexpr
int
shuffleMPerBlock
=
constexpr
int
shuffleMPerBlock
=
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
.
GetLength
(
I1
);
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
.
GetLength
(
I1
);
static_assert
(
PostShuffleThreadSliceSize_M
%
MeanVarTransferScalarPerVector
==
0
);
auto
mean_thread_copy_vgpr_to_global
=
ThreadwiseTensorSliceTransfer_v1r3
<
auto
mean_thread_copy_vgpr_to_global
=
ThreadwiseTensorSliceTransfer_v1r3
<
AccDataType
,
AccDataType
,
MeanDataType
,
MeanDataType
,
decltype
(
thread_welford_desc_I_m_I
),
decltype
(
thread_welford_desc_I_m_I
),
decltype
(
mean_grid_desc_mblock_mperblock_nblock
),
decltype
(
mean_
var_count_
grid_desc_mblock_mperblock_nblock
),
tensor_operation
::
element_wise
::
PassThrough
,
tensor_operation
::
element_wise
::
PassThrough
,
Sequence
<
1
,
PostShuffleThreadSliceSize_M
,
1
>
,
Sequence
<
1
,
PostShuffleThreadSliceSize_M
,
1
>
,
Sequence
<
0
,
1
,
2
>
,
Sequence
<
0
,
1
,
2
>
,
1
,
1
,
MeanVarTransferScalarPerVector
,
1
,
InMemoryDataOperationEnum
::
Set
,
InMemoryDataOperationEnum
::
Set
,
1
,
1
,
false
>
{
mean_grid_desc_mblock_mperblock_nblock
,
false
>
{
mean_
var_count_
grid_desc_mblock_mperblock_nblock
,
make_multi_index
(
block_work_idx
[
I0
],
// mblock
make_multi_index
(
block_work_idx
[
I0
],
// mblock
shuffleMPerBlock
*
i
+
shuffleMPerBlock
*
i
+
post_shuffle_thread_data_idx_begin
[
I0
],
// mperblock
post_shuffle_thread_data_idx_begin
[
I0
],
// mperblock
...
@@ -1021,32 +1030,59 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
...
@@ -1021,32 +1030,59 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
AccDataType
,
AccDataType
,
VarDataType
,
VarDataType
,
decltype
(
thread_welford_desc_I_m_I
),
decltype
(
thread_welford_desc_I_m_I
),
decltype
(
var
_grid_desc_mblock_mperblock_nblock
),
decltype
(
mean_var_count
_grid_desc_mblock_mperblock_nblock
),
tensor_operation
::
element_wise
::
PassThrough
,
tensor_operation
::
element_wise
::
PassThrough
,
Sequence
<
1
,
PostShuffleThreadSliceSize_M
,
1
>
,
Sequence
<
1
,
PostShuffleThreadSliceSize_M
,
1
>
,
Sequence
<
0
,
1
,
2
>
,
Sequence
<
0
,
1
,
2
>
,
1
,
1
,
MeanVarTransferScalarPerVector
,
1
,
InMemoryDataOperationEnum
::
Set
,
InMemoryDataOperationEnum
::
Set
,
1
,
1
,
false
>
{
var
_grid_desc_mblock_mperblock_nblock
,
false
>
{
mean_var_count
_grid_desc_mblock_mperblock_nblock
,
make_multi_index
(
block_work_idx
[
I0
],
// mblock
make_multi_index
(
block_work_idx
[
I0
],
// mblock
shuffleMPerBlock
*
i
+
shuffleMPerBlock
*
i
+
post_shuffle_thread_data_idx_begin
[
I0
],
// mperblock
post_shuffle_thread_data_idx_begin
[
I0
],
// mperblock
block_work_idx
[
I1
]),
// nblock
block_work_idx
[
I1
]),
// nblock
tensor_operation
::
element_wise
::
PassThrough
{}};
tensor_operation
::
element_wise
::
PassThrough
{}};
mean_thread_copy_vgpr_to_global
.
Run
(
thread_welford_desc_I_m_I
,
auto
count_thread_copy_vgpr_to_global
=
ThreadwiseTensorSliceTransfer_v1r3
<
int32_t
,
int32_t
,
decltype
(
thread_welford_desc_I_m_I
),
decltype
(
mean_var_count_grid_desc_mblock_mperblock_nblock
),
tensor_operation
::
element_wise
::
PassThrough
,
Sequence
<
1
,
PostShuffleThreadSliceSize_M
,
1
>
,
Sequence
<
0
,
1
,
2
>
,
1
,
1
,
InMemoryDataOperationEnum
::
Set
,
1
,
false
>
{
mean_var_count_grid_desc_mblock_mperblock_nblock
,
make_multi_index
(
block_work_idx
[
I0
],
// mblock
shuffleMPerBlock
*
i
+
post_shuffle_thread_data_idx_begin
[
I0
],
// mperblock
block_work_idx
[
I1
]),
// nblock
tensor_operation
::
element_wise
::
PassThrough
{}};
mean_thread_copy_vgpr_to_global
.
Run
(
thread_welford_desc_I_m_I
,
make_tuple
(
I0
,
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
),
mean_thread_buf
,
mean_thread_buf
,
mean_grid_desc_mblock_mperblock_nblock
,
mean_
var_count_
grid_desc_mblock_mperblock_nblock
,
mean_grid_buf
);
mean_grid_buf
);
var_thread_copy_vgpr_to_global
.
Run
(
thread_welford_desc_I_m_I
,
var_thread_copy_vgpr_to_global
.
Run
(
thread_welford_desc_I_m_I
,
make_tuple
(
I0
,
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
),
var_thread_buf
,
var_thread_buf
,
var
_grid_desc_mblock_mperblock_nblock
,
mean_var_count
_grid_desc_mblock_mperblock_nblock
,
var_grid_buf
);
var_grid_buf
);
count_thread_copy_vgpr_to_global
.
Run
(
thread_welford_desc_I_m_I
,
make_tuple
(
I0
,
I0
,
I0
),
count_thread_buf
,
mean_var_count_grid_desc_mblock_mperblock_nblock
,
welford_count_grid_buf
);
});
});
}
// shuffle C + Ds + welford + write out
}
// shuffle C + Ds + welford + write out
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment