gaoqiong / composable_kernel_ROCM · Commits

Commit 98def248, authored Apr 23, 2024 by Adam Osewski
Parent: bbd26e10

    Rework RunWrite.

Showing 2 changed files with 463 additions and 478 deletions:

- include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_tile_loop.hpp (+22 -10)
- include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle_v2.hpp (+441 -468)

include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_tile_loop.hpp
```diff
@@ -157,10 +157,12 @@ __global__ void
         } while(work_scheduler.GetNextTile() && b2c_tile_map.GetNextKTileIdx());

-        // if (changed group_id || next [M,N] tile)
-        if(!b2c_tile_map.IsFirstKSplitBlock())
-        {
-            GridwiseGemm::StorePartials(p_workspace, results_buffer);
-        }
+        // With cshuffle at store partials all workgroups have to store
+        // their partials to workspace gmem.
+        // TODO: The reduction workgroup don't have to store it's own results to GMEM!
+        // Would be enough to keep it in registers and during AccumulatePartials
+        // do CShuffle in flight with loading partials products of other peer workgroups.
+        GridwiseGemm::StorePartials(
+            p_workspace, static_cast<void*>(p_shared), results_buffer);

         work_scheduler.FlagFinished();
```
```diff
@@ -171,10 +173,20 @@ __global__ void
             index_t neighbour_count =
                 work_scheduler.WaitForNeighbours(k_batch, b2c_tile_map.GetTileKIdx());

+            constexpr auto workspace_thread_desc_m0m1_n0n1n2 =
+                GridwiseGemm::MakeReductionThreadDesc_M0M1_N0N1N2();
+            StaticBuffer<AddressSpaceEnum::Vgpr,
+                         typename GridwiseGemm::CShuffleDataT,
+                         workspace_thread_desc_m0m1_n0n1n2.GetElementSpaceSize(),
+                         true>
+                acc_buff{};
+            acc_buff.Clear();
+
             // Accumulate only when there is at least two workgroups processing splitk data-tiles
             // across same MN-output tile.
             if(neighbour_count > 0)
-                GridwiseGemm::AccumulatePartials(p_workspace, results_buffer, neighbour_count + 1);
+                GridwiseGemm::AccumulatePartials(p_workspace, acc_buff, neighbour_count + 1);

             // Signal waiting blocks that they can start use their workspace.
             work_scheduler.Reset(neighbour_count);
```
```diff
@@ -195,17 +207,17 @@ __global__ void
             GridwiseGemm::template RunWrite(p_ds_grid,
                                             p_e_grid,
-                                            static_cast<void*>(p_shared),
+                                            acc_buff,
                                             M,
                                             N,
                                             stride_ds,
                                             stride_e,
                                             cde_element_op,
-                                            b2c_tile_map,
-                                            results_buffer);
+                                            b2c_tile_map);
         }
         else if(work_scheduler.HasTile())
         {
             // TODO Move this just before StorePartials!
             work_scheduler.WaitForReduction();
         }
     } while(work_scheduler.HasTile());
```
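The three hunks above change only how the reducing workgroup obtains its data (a register-resident `acc_buff` of the CShuffle type instead of the raw `results_buffer` and LDS pointer), while the overall split-K coordination stays the same. A hedged, plain host-code sketch of that coordination pattern, with illustrative names only (none of the scheduler calls below are the library's API):

```cpp
// Hedged sketch: every "workgroup" owning one K-slice of an [M,N] tile stores a partial
// tile to a workspace; the first K-split workgroup then accumulates all peers' partials
// and writes the final output tile. Single-threaded stand-in for the device-side flow.
#include <cstdio>
#include <vector>

int main()
{
    const int k_batch    = 4; // K-slices per output tile
    const int tile_elems = 8; // toy MN tile
    std::vector<std::vector<float>> workspace(k_batch, std::vector<float>(tile_elems));

    // every workgroup stores its partial results (StorePartials + FlagFinished)
    for(int k = 0; k < k_batch; ++k)
        for(int e = 0; e < tile_elems; ++e)
            workspace[k][e] = 1.0f * (k + 1); // stand-in for a partial GEMM tile

    // the reducing workgroup accumulates its neighbours (WaitForNeighbours + AccumulatePartials)
    std::vector<float> acc = workspace[0];
    for(int k = 1; k < k_batch; ++k)
        for(int e = 0; e < tile_elems; ++e)
            acc[e] += workspace[k][e];

    // ...then applies the CDE elementwise op and writes E (RunWrite)
    std::printf("e[0] = %f\n", acc[0]); // 1+2+3+4 = 10
}
```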
```diff
@@ -757,7 +769,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
                       << ", grid_size: " << grid_size << ", flag_count: " << flag_count
                       << ", p_flags: " << p_flags << ", workspace_ptr: " << dev_gemm_workspace
-                      << ", acc_workspace_size_bytes: " << acc_workspace_size_bytes << std::endl;
+                      << ", acc_workspace_size_bytes: " << acc_workspace_size_bytes
+                      << ", kbatch: " << arg.K_BATCH << std::endl;
         }

         auto preprocess = [&]() {
```
```diff
@@ -995,7 +1007,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
         // the amount of workspace bytes needed, may be less due to the number of available CUs in
         // stream used to launch kernel.
         size_t size_bytes =
-            Block2ETileMapKSplit::GetAccWorkspaceSize(sizeof(AccDataType), grid_size) +
+            Block2ETileMapKSplit::GetAccWorkspaceSize(sizeof(CShuffleDataType), grid_size) +
             flag_count * sizeof(uint32_t);

         return size_bytes;
     }
```
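The workspace sizing above reads as: one per-workgroup accumulation tile of the intermediate data type plus one `uint32_t` synchronization flag per output tile. A minimal host-side sketch, assuming `GetAccWorkspaceSize` amounts to `grid_size * MPerBlock * NPerBlock` elements (an assumption; the helper's real definition is not shown in this diff):

```cpp
// Hedged sketch of the device workspace sizing; names and the per-tile layout are
// illustrative, not the library's API.
#include <cstddef>
#include <cstdint>
#include <iostream>

std::size_t workspace_size_bytes(std::size_t grid_size,
                                 std::size_t flag_count,
                                 std::size_t m_per_block,
                                 std::size_t n_per_block,
                                 std::size_t partial_elem_size) // e.g. sizeof(_Float16) stand-in
{
    // one MPerBlock x NPerBlock partial tile per workgroup
    const std::size_t tile_bytes = grid_size * m_per_block * n_per_block * partial_elem_size;
    // one semaphore flag per output tile for the work scheduler
    const std::size_t flag_bytes = flag_count * sizeof(std::uint32_t);
    return tile_bytes + flag_bytes;
}

int main()
{
    // 120 workgroups, 30 output tiles, 256x128 tiles, 2-byte partials
    std::cout << workspace_size_bytes(120, 30, 256, 128, 2) << " bytes\n";
}
```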
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle_v2.hpp
```diff
@@ -10,12 +10,13 @@
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
 #include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
 #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
-#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp"
+#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
 #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
 #include "ck/utility/reduction_functions_accumulate.hpp"

 namespace ck {
```
```diff
@@ -80,6 +81,13 @@ template <typename ADataType,
           PipelineVersion PipelineVer>
 class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
 {
+    template <index_t... Ids>
+    __device__ static bool is_thread_local_1d_id_idx()
+    {
+        const auto tid = get_thread_local_1d_id();
+        return ((tid == Ids) || ...);
+    }
+
     static constexpr index_t NumDTensor = DsDataType::Size();

     using GemmSpecialization = ck::tensor_operation::device::GemmSpecialization;
```
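The helper added above is a C++17 unary fold over the non-type parameter pack. A hedged host-side analogue, purely to show the fold expansion (the names below are illustrative):

```cpp
// Checks whether an id matches any of the compile-time Ids.
#include <cstdio>

template <int... Ids>
static bool is_id_in_set(int id)
{
    return ((id == Ids) || ...); // expands to (id == I0) || (id == I1) || ...
}

int main()
{
    std::printf("%d %d\n", is_id_in_set<0, 63, 128>(63), is_id_in_set<0, 63, 128>(7)); // 1 0
}
```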
```diff
@@ -106,7 +114,8 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
         decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;

 public:
-    using AccType = AccDataType;
+    using AccType       = AccDataType;
+    using CShuffleDataT = CShuffleDataType;

     __host__ __device__ static auto CalculateMPadded(index_t M)
     {
```
```diff
@@ -327,10 +336,14 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
                                        c_block_size * sizeof(CShuffleDataType));
     }

     // E desc for destination in blockwise copy
+    // M0 - MBlock
+    // M1 - MPerBlock
+    // N0 - NBlock
+    // N1 - NVecPerThread
+    // N2 - NVecSize
     template <typename EGridDesc_M_N>
     __host__ __device__ static constexpr auto
-    MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n)
+    MakeEGridDescriptor_M0M1_N0N1N2(const EGridDesc_M_N& e_grid_desc_m_n)
     {
         const auto M = e_grid_desc_m_n.GetLength(I0);
         const auto N = e_grid_desc_m_n.GetLength(I1);
```
```diff
@@ -345,18 +358,49 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
             make_tuple(Sequence<0>{}, Sequence<1>{}),
             make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));

-        return e_grid_desc_mblock_mperblock_nblock_nperblock;
+        constexpr auto cluster_length_reduce = GetClusterLengthReduction_M0_N0N1();
+        constexpr auto workspace_thread_desc_m0m1_n0n1n2 = MakeReductionThreadDesc_M0M1_N0N1N2();
+        // # of threads in NDim * vector load size * # repeats per thread
+        constexpr auto NPerBlockPadded = cluster_length_reduce.At(I2) *
+                                         workspace_thread_desc_m0m1_n0n1n2.GetLength(I3) *
+                                         workspace_thread_desc_m0m1_n0n1n2.GetLength(I4);
+        constexpr auto NPerBlockPad = NPerBlockPadded - Number<NPerBlock>{};
+
+        const auto e_grid_desc_m0m1_n0n1pad = transform_tensor_descriptor(
+            e_grid_desc_mblock_mperblock_nblock_nperblock,
+            make_tuple(
+                make_pass_through_transform(
+                    e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0)),
+                make_pass_through_transform(
+                    e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I1)),
+                make_pass_through_transform(
+                    e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)),
+                make_right_pad_transform(Number<NPerBlock>{}, NPerBlockPad)),
+            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
+            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
+
+        const auto e_grid_desc_m0m1_n0n1n2 = transform_tensor_descriptor(
+            e_grid_desc_m0m1_n0n1pad,
+            make_tuple(
+                make_pass_through_transform(e_grid_desc_m0m1_n0n1pad.GetLength(I0)),
+                make_pass_through_transform(e_grid_desc_m0m1_n0n1pad.GetLength(I1)),
+                make_pass_through_transform(e_grid_desc_m0m1_n0n1pad.GetLength(I2)),
+                make_unmerge_transform(make_tuple(
+                    workspace_thread_desc_m0m1_n0n1n2.GetLength(I3) * cluster_length_reduce.At(I2),
+                    workspace_thread_desc_m0m1_n0n1n2.GetLength(I4)))),
+            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
+            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4>{}));
+
+        return e_grid_desc_m0m1_n0n1n2;
     }

     // Ds desc for source in blockwise copy
     template <typename DsGridDesc_M_N>
     __host__ __device__ static constexpr auto
-    MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc_M_N& ds_grid_desc_m_n)
+    MakeDsGridDescriptor_M0M1_N0N1N2(const DsGridDesc_M_N& ds_grid_desc_m_n)
     {
         return generate_tuple(
-            [&](auto i) {
-                return MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[i]);
-            },
+            [&](auto i) { return MakeEGridDescriptor_M0M1_N0N1N2(ds_grid_desc_m_n[i]); },
             Number<NumDTensor>{});
     }
```
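The new E descriptor is built in two steps: the block's N extent is first right-padded up to `cluster_n * n_repeats * vec_size`, then unmerged so that the last dimension is the contiguous vector lane. A hedged sketch of the resulting index mapping in plain host code (the parameter names are illustrative, not CK's):

```cpp
// Effect of the pad + unmerge transforms on a column index n within one output block.
#include <cstdio>

struct N1N2 { int n1; int n2; bool in_pad; };

N1N2 map_n(int n, int n_per_block, int cluster_n, int n_repeats, int vec_size)
{
    const int n_padded = cluster_n * n_repeats * vec_size; // NPerBlockPadded
    const int n1 = n / vec_size;                           // which vector along N
    const int n2 = n % vec_size;                           // lane inside the vector
    return {n1, n2, n >= n_per_block && n < n_padded};     // padded tail is never written
}

int main()
{
    // e.g. NPerBlock = 128, 32 threads along N, 1 repeat, 4-wide vectors -> padded to 128
    auto c = map_n(70, 128, 32, 1, 4);
    std::printf("n1=%d n2=%d pad=%d\n", c.n1, c.n2, c.in_pad); // n1=17 n2=2 pad=0
}
```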
```diff
@@ -600,20 +644,11 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
     }

     __host__ __device__ static auto
-    MakeWorkspaceGridDesc_GridSize_I1_MPerBlock_NPerBlock(index_t grid_size)
+    MakeWorkspaceGridDesc_GridSize_MPerBlock_I1_NPerBlock(index_t grid_size)
     {
-        if constexpr(is_same<tensor_layout::gemm::RowMajor, ELayout>::value)
-        {
-            return make_naive_tensor_descriptor(
-                make_tuple(grid_size, I1.value, MPerBlock, NPerBlock),
-                make_tuple(MPerBlock * NPerBlock, MPerBlock * NPerBlock, NPerBlock, I1.value));
-        }
-        else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ELayout>::value)
-        {
-            return make_naive_tensor_descriptor(
-                make_tuple(grid_size, I1.value, MPerBlock, NPerBlock),
-                make_tuple(MPerBlock * NPerBlock, MPerBlock * NPerBlock, I1.value, MPerBlock));
-        }
+        return make_naive_tensor_descriptor(
+            make_tuple(grid_size, MPerBlock, I1.value, NPerBlock),
+            make_tuple(MPerBlock * NPerBlock, NPerBlock, NPerBlock, I1.value));
     }

     __device__ __host__ static constexpr auto GetMPerBlock() { return MPerBlock; }
```
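With the new layout, the workspace is simply one row-major MPerBlock x NPerBlock partial tile per workgroup. A hedged sketch of the linear offset implied by the dimensions `(grid_size, MPerBlock, 1, NPerBlock)` and strides `(MPerBlock*NPerBlock, NPerBlock, NPerBlock, 1)` above, in plain host code with illustrative names:

```cpp
#include <cstddef>
#include <cstdio>

std::size_t workspace_offset(std::size_t block_id, std::size_t m, std::size_t n,
                             std::size_t m_per_block, std::size_t n_per_block)
{
    // block-major, then row-major inside the block's tile
    return block_id * m_per_block * n_per_block + m * n_per_block + n;
}

int main()
{
    // element (m=3, n=5) of workgroup 2's tile with 256x128 per-block tiles
    std::printf("%zu\n", workspace_offset(2, 3, 5, 256, 128)); // 65925
}
```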
```cpp
@@ -850,21 +885,11 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
                          c_thread_buf);
    }

    // TODO Need to do CShuffle already here:
    template <typename CThreadBuf>
    __device__ static void StorePartials(void* __restrict__ p_workspace,
                                         void* __restrict__ p_shared,
                                         const CThreadBuf& c_thread_buf)
    {
        // M0 = grid_size
        // N0 = 1
        // M1 = MPerBlock
        // N1 = NPerBlock
        const auto workspace_grid_desc_m0_n0_m1_n1 =
            MakeWorkspaceGridDesc_GridSize_I1_MPerBlock_NPerBlock(get_grid_size());

        const auto w_grid_m0 = workspace_grid_desc_m0_n0_m1_n1.GetLength(I0);
        const auto w_grid_n0 = workspace_grid_desc_m0_n0_m1_n1.GetLength(I1);

        using BlockwiseGemmT = remove_cvref_t<decltype(
            BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<BlockSize,
```
```cpp
@@ -880,161 +905,10 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
                                                                      KPack,
                                                                      LoopSched>())>;

        constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
            BlockwiseGemmT::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

        constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
        constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
        constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
        constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
        constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
        constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
        constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
        constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);

        // M0 = grid_size -> MRepeats (MXdlPerWave)
        // N0 = 1         -> NRepeats (NXdlPerWave)
        const auto workspace_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3 = transform_tensor_descriptor(
            workspace_grid_desc_m0_n0_m1_n1,
            make_tuple(make_pass_through_transform(w_grid_m0),
                       make_pass_through_transform(w_grid_n0),
                       make_unmerge_transform(make_tuple(M0, M1, M2, M3, M4)),
                       make_unmerge_transform(make_tuple(N0, N1, N2))),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(
                Sequence<0>{}, Sequence<1>{}, Sequence<2, 4, 6, 7, 8>{}, Sequence<3, 5, 9>{}));

        const auto workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
            workspace_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
            make_tuple(make_merge_transform(make_tuple(w_grid_m0, M0)), // MRepeats (grid)
                       make_merge_transform(make_tuple(w_grid_n0, N0)), // NRepeats (grid)
                       make_pass_through_transform(M1),                 // MWave
                       make_pass_through_transform(N1),                 // NWave
                       make_pass_through_transform(M2),  // mfma_instr.num_groups_per_blk
                       make_pass_through_transform(M3),  // mfma_instr.num_input_blks
                       make_pass_through_transform(M4),  // mfma_instr.group_size
                       make_pass_through_transform(N2)), // mfma_instr.num_threads_per_blk
            make_tuple(Sequence<0, 2>{},
                       Sequence<1, 3>{},
                       Sequence<4>{},
                       Sequence<5>{},
                       Sequence<6>{},
                       Sequence<7>{},
                       Sequence<8>{},
                       Sequence<9>{}),
            make_tuple(Sequence<0>{},
                       Sequence<1>{},
                       Sequence<2>{},
                       Sequence<3>{},
                       Sequence<4>{},
                       Sequence<5>{},
                       Sequence<6>{},
                       Sequence<7>{}));

        auto p_workspace_grid = reinterpret_cast<AccDataType*>(p_workspace);
        auto w_grid_buf =
            make_dynamic_buffer<AddressSpaceEnum::Global, AmdBufferCoherenceEnum::GLC>(
                p_workspace_grid, workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize());

        constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
            BlockwiseGemmT::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

        const auto c_thread_mtx_on_block =
            BlockwiseGemmT::CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

        const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
        const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

        const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
                make_tuple(Sequence<0, 1, 2, 3, 4>{}),
                make_tuple(Sequence<0>{}));

        const auto m_thread_data_on_block_idx =
            m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
                make_multi_index(m_thread_data_on_block));

        const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
                make_tuple(Sequence<0, 1, 2>{}),
                make_tuple(Sequence<0>{}));

        const auto n_thread_data_on_block_idx =
            n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
                make_multi_index(n_thread_data_on_block));

        auto c_thread_copy_vgpr_to_gmem = ThreadwiseTensorSliceTransfer_v1r3<
            AccDataType,
            AccDataType,
            decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
            decltype(workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2),
            ck::tensor_operation::element_wise::PassThrough,
            decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLengths()), // SliceLengths
            // N -> then M dims
            Sequence<0, 2, 4, 5, 6, 1, 3, 7>, // DimAccessOrder
            7,                                // DstVectorDim,
            1,                                // DstScalarPerVector
            InMemoryDataOperationEnum::Set,
            1,    // DstScalarStrideInVector
            true> // DstResetCoordinateAfterRun
            {workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
             make_multi_index((static_cast<index_t>(blockIdx.x)) * MXdlPerWave,
                              n_thread_data_on_block_idx[I0],
                              m_thread_data_on_block_idx[I1],
                              n_thread_data_on_block_idx[I1],
                              m_thread_data_on_block_idx[I2],
                              m_thread_data_on_block_idx[I3],
                              m_thread_data_on_block_idx[I4],
                              n_thread_data_on_block_idx[I2]),
             ck::tensor_operation::element_wise::PassThrough{}};

        c_thread_copy_vgpr_to_gmem.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                       make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
                                       c_thread_buf,
                                       workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                       w_grid_buf);
    }

    template <typename CThreadBuf>
    __device__ static void AccumulatePartials(void* __restrict__ p_workspace,
                                              CThreadBuf& c_thread_buf,
                                              uint32_t reduce_count)
    {
        using BlockwiseGemmT = remove_cvref_t<decltype(
            BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
                BlockSize,
                ComputeType,
                ComputeType,
                AccDataType,
                decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()),
                decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()),
                MPerXdl,
                NPerXdl,
                MXdlPerWave,
                NXdlPerWave,
                KPack,
                LoopSched>())>;

        constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
            BlockwiseGemmT::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

        StaticBuffer<AddressSpaceEnum::Vgpr,
                     AccDataType,
                     c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize(),
                     true>
            acc_buf{};

        // M0 = grid_size
        // N0 = 1
        // M1 = MPerBlock
        // N1 = NPerBlock
        const auto workspace_grid_desc_m0_n0_m1_n1 =
            MakeWorkspaceGridDesc_GridSize_I1_MPerBlock_NPerBlock(get_grid_size());

        const auto w_grid_m0 = workspace_grid_desc_m0_n0_m1_n1.GetLength(I0);
        const auto w_grid_n0 = workspace_grid_desc_m0_n0_m1_n1.GetLength(I1);

        static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                          NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
                      "MXdlPerWave % CShuffleMXdlPerWavePerShuffle != 0 or "
                      "NXdlPerWave % CShuffleNXdlPerWavePerShuffle != 0,");

        constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
            BlockwiseGemmT::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
```
```cpp
@@ -1048,213 +922,9 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
        constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
        constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);

        // M0 = grid_size -> MRepeats (MXdlPerWave)
        // N0 = 1         -> NRepeats (NXdlPerWave)
        const auto workspace_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3 = transform_tensor_descriptor(
            workspace_grid_desc_m0_n0_m1_n1,
            make_tuple(make_pass_through_transform(w_grid_m0),
                       make_pass_through_transform(w_grid_n0),
                       make_unmerge_transform(make_tuple(M0, M1, M2, M3, M4)),
                       make_unmerge_transform(make_tuple(N0, N1, N2))),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(
                Sequence<0>{}, Sequence<1>{}, Sequence<2, 4, 6, 7, 8>{}, Sequence<3, 5, 9>{}));

        const auto workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
            workspace_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
            make_tuple(make_merge_transform(make_tuple(w_grid_m0, M0)), // MRepeats (grid)
                       make_merge_transform(make_tuple(w_grid_n0, N0)), // NRepeats (grid)
                       make_pass_through_transform(M1),                 // MWave
                       make_pass_through_transform(N1),                 // NWave
                       make_pass_through_transform(M2),  // mfma_instr.num_groups_per_blk
                       make_pass_through_transform(M3),  // mfma_instr.num_input_blks
                       make_pass_through_transform(M4),  // mfma_instr.group_size
                       make_pass_through_transform(N2)), // mfma_instr.num_threads_per_blk
            make_tuple(Sequence<0, 2>{},
                       Sequence<1, 3>{},
                       Sequence<4>{},
                       Sequence<5>{},
                       Sequence<6>{},
                       Sequence<7>{},
                       Sequence<8>{},
                       Sequence<9>{}),
            make_tuple(Sequence<0>{},
                       Sequence<1>{},
                       Sequence<2>{},
                       Sequence<3>{},
                       Sequence<4>{},
                       Sequence<5>{},
                       Sequence<6>{},
                       Sequence<7>{}));

        const auto c_thread_mtx_on_block =
            BlockwiseGemmT::CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

        const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
        const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

        const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
                make_tuple(Sequence<0, 1, 2, 3, 4>{}),
                make_tuple(Sequence<0>{}));

        const auto m_thread_data_on_block_idx =
            m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
                make_multi_index(m_thread_data_on_block));

        const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
                make_tuple(Sequence<0, 1, 2>{}),
                make_tuple(Sequence<0>{}));

        const auto n_thread_data_on_block_idx =
            n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
                make_multi_index(n_thread_data_on_block));

        auto p_workspace_grid = reinterpret_cast<AccDataType*>(p_workspace);
        auto w_grid_buf =
            make_dynamic_buffer<AddressSpaceEnum::Global, AmdBufferCoherenceEnum::GLC>(
                p_workspace_grid, workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize());

        auto acc_load = ThreadwiseTensorSliceTransfer_v2<
            AccDataType,                                                  // SrcData,
            AccDataType,                                                  // DstData,
            decltype(workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2),        // SrcDesc,
            decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),              // DstDesc,
            decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLengths()), // SliceLengths,
            Sequence<0, 2, 4, 5, 6, 1, 3, 7>,                             // DimAccessOrder,
            7,                                                            // SrcVectorDim,
            1,                                                            // SrcScalarPerVector,
            1,                                                            // SrcScalarStrideInVector,
            false                                                         // SrcResetCoordinateAfterRun,
            >{workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
              // We do not need to read this workgroup partial results since they're
              // already in c_thread_buff
              make_multi_index((static_cast<index_t>(blockIdx.x) + 1) * MXdlPerWave,
                               n_thread_data_on_block_idx[I0],
                               m_thread_data_on_block_idx[I1],
                               n_thread_data_on_block_idx[I1],
                               m_thread_data_on_block_idx[I2],
                               m_thread_data_on_block_idx[I3],
                               m_thread_data_on_block_idx[I4],
                               n_thread_data_on_block_idx[I2])};

        using Accumulation =
            ck::detail::AccumulateWithNanCheck<false /*PropagateNan*/, reduce::Add, AccDataType>;

        constexpr auto partial_acc_load_step =
            make_multi_index(MXdlPerWave, I0, I0, I0, I0, I0, I0, I0);

        // We do not need to read this workgroup partial results since they're
        // already in c_thread_buff
        for(uint32_t i_t = 1; i_t < reduce_count; ++i_t)
        {
            acc_buf.Clear();
            acc_load.Run(workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                         w_grid_buf,
                         c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                         make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
                         acc_buf);

            static_for<0, c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize(), 1>{}(
                [&](auto i_vec) { Accumulation::Calculate(c_thread_buf(i_vec), acc_buf[i_vec]); });

            acc_load.MoveSrcSliceWindow(workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                        partial_acc_load_step);
        }
    }

    template <typename Block2ETileMap, typename CThreadBuf>
    __device__ static void RunWrite(DsGridPointer p_ds_grid,
                                    EDataType* __restrict__ p_e_grid,
                                    void* __restrict__ p_shared,
                                    const index_t M,
                                    const index_t N,
                                    const std::array<index_t, NumDTensor> StrideDs,
                                    const index_t StrideE,
                                    const CDEElementwiseOperation& cde_element_op,
                                    const Block2ETileMap& block_2_etile_map,
                                    const CThreadBuf& c_thread_buf)
    {
        using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
        using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
            MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;

        DsGridDesc_M_N ds_grid_desc_m_n;
        DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock;

        static_for<0, NumDTensor, 1>{}([&](auto j) {
            using DLayout       = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
            ds_grid_desc_m_n(j) = MakeEGridDescriptor_M_N<DLayout>(M, N, StrideDs[j]);
        });

        static_for<0, NumDTensor, 1>{}([&](auto j) {
            ds_grid_desc_mblock_mperblock_nblock_nperblock(j) =
                MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[j]);
        });

        const auto ds_grid_buf = generate_tuple(
            [&](auto i) {
                return make_dynamic_buffer<AddressSpaceEnum::Global>(
                    p_ds_grid[i],
                    ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
            },
            Number<NumDTensor>{});

        const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N<ELayout>(M, N, StrideE);
        const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
            MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n);

        auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

        // shuffle C and write out
        static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                          NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
                      "wrong!");

        constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
        constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);

        // divide block work by [M, N, K]
        const auto block_work_idx = block_2_etile_map.GetBottomIndex();

        using BlockwiseGemmT = remove_cvref_t<decltype(
            BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
                BlockSize,
                ComputeType,
                ComputeType,
                AccDataType,
                decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()),
                decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()),
                MPerXdl,
                NPerXdl,
                MXdlPerWave,
                NXdlPerWave,
                KPack,
                LoopSched>())>;

        // TODO: hacky, fix it!
        constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
            BlockwiseGemmT::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

        // TODO: hacky, fix it!
        // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
        constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
            BlockwiseGemmT::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

        constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
        constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
        constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
        constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
        constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
        constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
        constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
        constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);

        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
            GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
```
```cpp
@@ -1281,6 +951,9 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
        // calculate origin of thread output tensor on global memory
        // blockwise GEMM c matrix starting index
        constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
            BlockwiseGemmT::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

        const auto c_thread_mtx_on_block =
            BlockwiseGemmT::CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
```
```cpp
@@ -1338,32 +1011,44 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
                              n_thread_data_on_block_idx[I2]),
             ck::tensor_operation::element_wise::PassThrough{}};

        // tuple of reference to C/Ds tensor descriptors
        const auto c_ds_desc_refs = concat_tuple_of_reference(
            tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
            generate_tie(
                [&](auto i) -> const auto& // return type should be reference
                { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
                Number<NumDTensor>{}));

        // tuple of reference to C/Ds tensor descriptors
        const auto c_ds_buf_refs = concat_tuple_of_reference(
            tie(c_shuffle_block_buf),
            generate_tie(
                [&](auto i) -> const auto& // return type should be reference
                { return ds_grid_buf[i]; },
                Number<NumDTensor>{}));

        // tuple of starting index of C/Ds blockwise copy
        const auto idx_c_ds_block_begin = container_concat(
            make_tuple(make_multi_index(0, 0, 0, 0)),
            generate_tuple(
                [&](auto) {
                    return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0);
                },
                Number<NumDTensor>{}));

        // space filling curve for threadwise C in VGPR before shuffle
        // M0 = grid_size
        // M1 = MPerBlock
        // N0 = 1
        // N1 = NPerBlock
        const auto workspace_grid_desc_m0_m1_n0_n1 =
            MakeWorkspaceGridDesc_GridSize_MPerBlock_I1_NPerBlock(get_grid_size());

        auto p_workspace_grid = reinterpret_cast<AccDataType*>(p_workspace);
        auto w_grid_buf =
            make_dynamic_buffer<AddressSpaceEnum::Global, AmdBufferCoherenceEnum::GLC>(
                p_workspace_grid, workspace_grid_desc_m0_m1_n0_n1.GetElementSpaceSize());

        // shuffle: blockwise copy C from LDS to workspace
        auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
            ThisThreadBlock,                                 // ThreadGroup
            ck::tensor_operation::element_wise::PassThrough, // ElementwiseOperation,
            InMemoryDataOperationEnum::Set,                  // DstInMemOp,
            Sequence<1,
                     CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                     1,
                     CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
            CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
            Sequence<0, 1, 2, 3>,                             // typename ThreadClusterArrangeOrder,
            CShuffleDataType,                                 // typename SrcData,
            CShuffleDataType,                                 // typename DstData,
            decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
            decltype(workspace_grid_desc_m0_m1_n0_n1),
            Sequence<0, 1, 2, 3>,                             // typename DimAccessOrder,
            3,                                                // index_t VectorDim,
            CDEShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
            true,  // bool ThreadTransferSrcResetCoordinateAfterRun,
            false> // bool ThreadTransferDstResetCoordinateAfterRun>
            {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
             make_multi_index(0, 0, 0, 0),
             workspace_grid_desc_m0_m1_n0_n1,
             make_multi_index(static_cast<index_t>(blockIdx.x), 0, 0, 0),
             ck::tensor_operation::element_wise::PassThrough{}};

        // space filling curve for threadwise C in VGPR
        constexpr auto sfc_c_vgpr =
            SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
                              Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
```
```diff
@@ -1376,8 +1061,8 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
                                          M4,
                                          1>>{};

-        // space filling curve for shuffled blockwise C/D/E
-        constexpr auto sfc_cde_block =
+        // space filling curve for shuffled blockwise W in global mem
+        constexpr auto sfc_w_global =
             SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
                               Sequence<0, 2, 1, 3>,
                               Sequence<1,
```
```diff
@@ -1386,39 +1071,7 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
                                         CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};

         constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
-        static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
-
-        // blockwise copy C/D/E between LDS and global
-        auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7<
-            ThisThreadBlock,
-            decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
-            Tuple<EDataType>,
-            decltype(c_ds_desc_refs),
-            decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
-            CDEElementwiseOperation,
-            Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make
-                                                                        // Sequence support
-                                                                        // arbitray type
-            Sequence<1,
-                     CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
-                     1,
-                     CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
-            CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
-            Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
-            Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
-            3,                    // index_t VectorDim,
-            CDEShuffleBlockTransferScalarPerVector_NPerBlock,
-            sequence_merge_t<Sequence<true>,
-                             uniform_sequence_gen_t<NumDTensor, false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
-            Sequence<false>> // ThreadTransferDstResetCoordinateAfterRunFlags
-            {c_ds_desc_refs,
-             idx_c_ds_block_begin,
-             tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
-             make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)),
-             cde_element_op};
+        static_assert(num_access == sfc_w_global.GetNumOfAccess(), "wrong!");

         static_for<0, num_access, 1>{}([&](auto access_id) {
             // make sure it's safe to write to LDS
```
```diff
@@ -1435,28 +1088,348 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
             block_sync_lds();

             // each block copy its data from LDS to global
-            cde_block_copy_lds_and_global.Run(c_ds_desc_refs,
-                                              c_ds_buf_refs,
-                                              tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
-                                              tie(e_grid_buf));
+            c_shuffle_block_copy_lds_to_global.Run(
+                c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
+                c_shuffle_block_buf,
+                workspace_grid_desc_m0_m1_n0_n1,
+                w_grid_buf);

             if constexpr(access_id < num_access - 1)
             {
-                constexpr auto cde_lds_and_global_step = sfc_cde_block.GetForwardStep(access_id);
+                constexpr auto w_global_step = sfc_w_global.GetForwardStep(access_id);

                 // move on C
                 c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
                     workspace_grid_desc_m0_m1_n0_n1, w_global_step);
             }
         });
     }
```
```cpp
    __device__ static constexpr auto GetClusterLengthReduction_M0_N0N1()
    {
        return Sequence<CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(I1),
                        I1.value,
                        CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(I3)>{};
    }
```
```cpp
    __device__ static constexpr auto MakeReductionThreadDesc_M0M1_N0N1N2()
    {
        constexpr auto cluster_lengths = GetClusterLengthReduction_M0_N0N1();
        constexpr auto N1_elems =
            math::integer_divide_ceil(Number<NPerBlock>{}, cluster_lengths.At(I2));

        static_assert(N1_elems % CDEShuffleBlockTransferScalarPerVector_NPerBlock == 0,
                      "Invalid ReductionThreadDesc M0M1_N0N1N2! N1_elems have to be a multiple of "
                      "CDEShuffleBlockTransferScalarPerVector_NPerBlock!");

        constexpr auto N2 = Number<CDEShuffleBlockTransferScalarPerVector_NPerBlock>{};
        constexpr auto N1 = math::integer_divide_ceil(N1_elems, N2);
        constexpr auto M1 = math::integer_divide_ceil(Number<MPerBlock>{}, cluster_lengths.At(I0));

        static_assert(Number<M1>{} * cluster_lengths.At(I0) >= Number<MPerBlock>{},
                      "Invalid ReductionThreadDesc M0M1_N0N1N2! M1 * cluster_length[0] have to be "
                      "grater or equal to MPerBlock.");
        static_assert(Number<N1>{} * Number<N2>{} * cluster_lengths.At(I2) >= Number<NPerBlock>{},
                      "Invalid ReductionThreadDesc M0M1_N0N1N2! N1 * N2 * cluster_length[2] have "
                      "to be grater or equal to NPerBlock.");

        return make_naive_tensor_descriptor_packed(make_tuple(I1, Number<M1>{}, I1, N1, N2));
    }
```
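The per-thread reduction descriptor above boils down to a small piece of integer arithmetic. A hedged constexpr sketch in plain host code; `cluster_m`, `cluster_n` and `vec` stand in for the cluster lengths and `CDEShuffleBlockTransferScalarPerVector_NPerBlock`, and the concrete numbers are examples only:

```cpp
#include <cstdio>

constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

int main()
{
    constexpr int m_per_block = 256, n_per_block = 128;
    constexpr int cluster_m = 64, cluster_n = 4, vec = 4;

    constexpr int n1_elems = ceil_div(n_per_block, cluster_n); // N elements owned per thread
    static_assert(n1_elems % vec == 0, "per-thread N extent must be vector-aligned");

    constexpr int N2 = vec;                               // contiguous vector lanes
    constexpr int N1 = ceil_div(n1_elems, N2);            // vector repeats per thread along N
    constexpr int M1 = ceil_div(m_per_block, cluster_m);  // rows owned per thread along M

    static_assert(M1 * cluster_m >= m_per_block, "M coverage");
    static_assert(N1 * N2 * cluster_n >= n_per_block, "N coverage");

    // thread descriptor (M0, M1, N0, N1, N2) = (1, M1, 1, N1, N2)
    std::printf("M1=%d N1=%d N2=%d\n", M1, N1, N2); // M1=4 N1=8 N2=4
}
```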
```diff
     template <typename AccumulationBuffer>
     __device__ static void AccumulatePartials(void* __restrict__ p_workspace,
                                               AccumulationBuffer& acc_buff,
                                               uint32_t reduce_count)
     {
         constexpr auto cluster_length_reduce = GetClusterLengthReduction_M0_N0N1();
         constexpr auto reduce_cluster_desc   = make_cluster_descriptor(cluster_length_reduce);
         const auto reduce_thread_cluster_idx =
             reduce_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
         const auto thread_m_cluster_id  = reduce_thread_cluster_idx[I0];
         const auto thread_n0_cluster_id = reduce_thread_cluster_idx[I1];
         // Should be I0
         const auto thread_n1_cluster_id = reduce_thread_cluster_idx[I2];

         constexpr auto workspace_thread_desc_m0m1_n0n1n2 = MakeReductionThreadDesc_M0M1_N0N1N2();
         const auto workspace_grid_desc_m0m1_n0n1 =
             MakeWorkspaceGridDesc_GridSize_MPerBlock_I1_NPerBlock(get_grid_size());

         // # of threads in NDim * vector load size * # repeats per thread
         constexpr auto NPerBlockPadded = cluster_length_reduce.At(I2) *
                                          workspace_thread_desc_m0m1_n0n1n2.GetLength(I3) *
                                          workspace_thread_desc_m0m1_n0n1n2.GetLength(I4);
         constexpr auto NPerBlockPad = NPerBlockPadded - Number<NPerBlock>{};

         const auto workspace_grid_desc_m0m1_n0n1pad = transform_tensor_descriptor(
             workspace_grid_desc_m0m1_n0n1,
             make_tuple(make_pass_through_transform(workspace_grid_desc_m0m1_n0n1.GetLength(I0)),
                        make_pass_through_transform(workspace_grid_desc_m0m1_n0n1.GetLength(I1)),
                        make_pass_through_transform(workspace_grid_desc_m0m1_n0n1.GetLength(I2)),
                        make_right_pad_transform(Number<NPerBlock>{}, NPerBlockPad)),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

         const auto workspace_grid_desc_m0m1_n0n1n2 = transform_tensor_descriptor(
             workspace_grid_desc_m0m1_n0n1pad,
             make_tuple(
                 make_pass_through_transform(workspace_grid_desc_m0m1_n0n1pad.GetLength(I0)),
                 make_pass_through_transform(workspace_grid_desc_m0m1_n0n1pad.GetLength(I1)),
                 make_pass_through_transform(workspace_grid_desc_m0m1_n0n1pad.GetLength(I2)),
                 make_unmerge_transform(
                     make_tuple(workspace_thread_desc_m0m1_n0n1n2.GetLength(I3),
                                workspace_thread_desc_m0m1_n0n1n2.GetLength(I4) *
                                    cluster_length_reduce.At(I2)))),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4>{}));

         StaticBuffer<AddressSpaceEnum::Vgpr,
                      CShuffleDataType,
                      workspace_thread_desc_m0m1_n0n1n2.GetElementSpaceSize(),
                      true>
             partial_acc_buf{};

         auto p_workspace_grid = reinterpret_cast<CShuffleDataType*>(p_workspace);
         auto w_grid_buf =
             make_dynamic_buffer<AddressSpaceEnum::Global, AmdBufferCoherenceEnum::GLC>(
                 p_workspace_grid, workspace_grid_desc_m0m1_n0n1n2.GetElementSpaceSize());

         auto acc_load = ThreadwiseTensorSliceTransfer_v2<
             CShuffleDataType,                                         // SrcData,
             CShuffleDataType,                                         // DstData,
             decltype(workspace_grid_desc_m0m1_n0n1n2),                // SrcDesc,
             decltype(workspace_thread_desc_m0m1_n0n1n2),              // DstDesc,
             decltype(workspace_thread_desc_m0m1_n0n1n2.GetLengths()), // SliceLengths,
             Sequence<0, 1, 2, 3, 4>,                                  // DimAccessOrder,
             4,                                                        // SrcVectorDim,
             CDEShuffleBlockTransferScalarPerVector_NPerBlock,         // SrcScalarPerVector,
             1,                                                        // SrcScalarStrideInVector,
             false                                                     // SrcResetCoordinateAfterRun,
             >{workspace_grid_desc_m0m1_n0n1n2,
               // We do not need to read this workgroup partial results since they're
               // already in c_thread_buff
               // make_multi_index((static_cast<index_t>(blockIdx.x) + 1),
               // We want to have a thread raked access pattern
               make_multi_index(static_cast<index_t>(blockIdx.x),
                                thread_m_cluster_id *
                                    workspace_thread_desc_m0m1_n0n1n2.GetLength(I1),
                                I0,
                                thread_n0_cluster_id,
                                thread_n1_cluster_id *
                                    workspace_thread_desc_m0m1_n0n1n2.GetLength(I4))};

         using Accumulation = ck::detail::
             AccumulateWithNanCheck<false /*PropagateNan*/, reduce::Add, CShuffleDataType>;

         constexpr auto partial_acc_load_step = make_multi_index(I1, I0, I0, I0, I0);

         // TODO: We do not need to read this workgroup partial results since they're
         // already in c_thread_buff
         for(uint32_t i_t = 0; i_t < reduce_count; ++i_t)
         {
             partial_acc_buf.Clear();
             acc_load.Run(workspace_grid_desc_m0m1_n0n1n2,
                          w_grid_buf,
                          workspace_thread_desc_m0m1_n0n1n2,
                          make_tuple(I0, I0, I0, I0, I0),
                          partial_acc_buf);

-            // move on Ds
-            static_for<0, NumDTensor, 1>{}([&](auto i) {
-                cde_block_copy_lds_and_global.MoveSrcSliceWindow(
-                    c_ds_desc_refs, i + I1, cde_lds_and_global_step);

             static_for<0, workspace_thread_desc_m0m1_n0n1n2.GetElementSpaceSize(), 1>{}(
                 [&](auto i_vec) {
                     Accumulation::Calculate(acc_buff(i_vec), partial_acc_buf[i_vec]);
                 });

-            // move on E
-            cde_block_copy_lds_and_global.MoveDstSliceWindow(
-                tie(e_grid_desc_mblock_mperblock_nblock_nperblock), I0, cde_lds_and_global_step);
-            });

             acc_load.MoveSrcSliceWindow(workspace_grid_desc_m0m1_n0n1n2, partial_acc_load_step);
         }
     }
```
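Stripped of the descriptor machinery, the loop above walks the partial tiles left in the workspace by the peer workgroups that processed the same [M,N] tile and adds them element-wise into the accumulation registers. A hedged host-code sketch with illustrative names:

```cpp
#include <cstddef>
#include <cstdio>
#include <vector>

void accumulate_partials(float* acc, const std::vector<std::vector<float>>& workspace_tiles,
                         unsigned reduce_count, std::size_t elems)
{
    for(unsigned peer = 0; peer < reduce_count; ++peer)  // acc_load loop over peer tiles
        for(std::size_t i = 0; i < elems; ++i)           // static_for over the thread buffer
            acc[i] += workspace_tiles[peer][i];           // Accumulation::Calculate (reduce::Add)
}

int main()
{
    std::vector<std::vector<float>> ws = {{1, 1, 1, 1}, {2, 2, 2, 2}, {3, 3, 3, 3}};
    float acc[4] = {};
    accumulate_partials(acc, ws, 3, 4);
    std::printf("%f\n", acc[0]); // 6
}
```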
```cpp
    template <typename Block2ETileMap, typename AccumulationBuffer>
    __device__ static void RunWrite(DsGridPointer p_ds_grid,
                                    EDataType* __restrict__ p_e_grid,
                                    /* void* __restrict__ p_shared, */
                                    const AccumulationBuffer& acc_buff,
                                    const index_t M,
                                    const index_t N,
                                    const std::array<index_t, NumDTensor> StrideDs,
                                    const index_t StrideE,
                                    const CDEElementwiseOperation& cde_element_op,
                                    const Block2ETileMap& block_2_etile_map)
    {
        using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
        using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
            remove_cvref_t<decltype(MakeDsGridDescriptor_M0M1_N0N1N2(DsGridDesc_M_N{}))>;

        constexpr index_t ScalarPerVector = CDEShuffleBlockTransferScalarPerVector_NPerBlock;

        DsGridDesc_M_N ds_grid_desc_m_n;
        DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_m0m1_n0n1n2;

        const auto e_grid_desc_m_n         = MakeEGridDescriptor_M_N<ELayout>(M, N, StrideE);
        const auto e_grid_desc_m0m1_n0n1n2 = MakeEGridDescriptor_M0M1_N0N1N2(e_grid_desc_m_n);

        auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_e_grid, e_grid_desc_m0m1_n0n1n2.GetElementSpaceSize());

        static_for<0, NumDTensor, 1>{}([&](auto j) {
            using DLayout       = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
            ds_grid_desc_m_n(j) = MakeEGridDescriptor_M_N<DLayout>(M, N, StrideDs[j]);
        });

        static_for<0, NumDTensor, 1>{}([&](auto j) {
            ds_grid_desc_m0m1_n0n1n2(j) = MakeEGridDescriptor_M0M1_N0N1N2(ds_grid_desc_m_n[j]);
        });

        // TODO: on MI300 we could use NonTemporal load, MI200 streaming mode?
        auto ds_grid_buf = generate_tuple(
            [&](auto i) {
                return make_dynamic_buffer<AddressSpaceEnum::Global>(
                    p_ds_grid[i], ds_grid_desc_m0m1_n0n1n2[i].GetElementSpaceSize());
            },
            Number<NumDTensor>{});

        constexpr auto ds_thread_buf = generate_tuple(
            [&](auto i) {
                using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
                return StaticBuffer<AddressSpaceEnum::Vgpr, DDataType, ScalarPerVector, true>{};
            },
            Number<NumDTensor>{});

        auto aux_vgpr_buf =
            StaticBuffer<AddressSpaceEnum::Vgpr, EDataType, ScalarPerVector, true>{};

        constexpr auto d_vgpr_buf_desc = make_naive_tensor_descriptor_packed(
            make_tuple(I1, I1, I1, I1, Number<ScalarPerVector>{}));

        // divide block work by [M, N, K]
        const auto block_work_idx = block_2_etile_map.GetBottomIndex();

        constexpr auto cluster_length_reduce = GetClusterLengthReduction_M0_N0N1();
        constexpr auto reduce_cluster_desc   = make_cluster_descriptor(cluster_length_reduce);
        const auto reduce_thread_cluster_idx =
            reduce_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
        const auto thread_m_cluster_id  = reduce_thread_cluster_idx[I0];
        const auto thread_n0_cluster_id = reduce_thread_cluster_idx[I1];
        // Should be I0
        const auto thread_n1_cluster_id = reduce_thread_cluster_idx[I2];

        constexpr auto workspace_thread_desc_m0m1_n0n1n2 = MakeReductionThreadDesc_M0M1_N0N1N2();

        auto ds_grid_load = generate_tuple(
            [&](auto i) {
                using DDataType    = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
                using SliceLengths = Sequence<I1, I1, I1, I1, ScalarPerVector>;

                return ThreadwiseTensorSliceTransfer_v2<DDataType,
                                                        DDataType,
                                                        decltype(ds_grid_desc_m0m1_n0n1n2(i)),
                                                        decltype(d_vgpr_buf_desc),
                                                        SliceLengths,
                                                        Sequence<0, 1, 2, 3, 4>,
                                                        4,
                                                        ScalarPerVector,
                                                        1,
                                                        false>{
                    ds_grid_desc_m_n(i),
                    make_multi_index(block_work_idx[I0],
                                     thread_m_cluster_id *
                                         workspace_thread_desc_m0m1_n0n1n2.GetLength(I1),
                                     block_work_idx[I1],
                                     thread_n0_cluster_id,
                                     thread_n1_cluster_id *
                                         workspace_thread_desc_m0m1_n0n1n2.GetLength(I4))};
            },
            Number<NumDTensor>{});

        auto e_grid_store = ThreadwiseTensorSliceTransfer_v1r3<
            EDataType,
            EDataType,
            decltype(workspace_thread_desc_m0m1_n0n1n2),
            decltype(e_grid_desc_m0m1_n0n1n2),
            ck::tensor_operation::element_wise::PassThrough,
            Sequence<I1, I1, I1, I1, ScalarPerVector>,
            Sequence<0, 1, 2, 3, 4>,
            4,
            ScalarPerVector,
            EGlobalMemoryDataOperation,
            1,
            false>{e_grid_desc_m0m1_n0n1n2,
                   make_multi_index(block_work_idx[I0],
                                    thread_m_cluster_id *
                                        workspace_thread_desc_m0m1_n0n1n2.GetLength(I1),
                                    block_work_idx[I1],
                                    thread_n0_cluster_id,
                                    thread_n1_cluster_id *
                                        workspace_thread_desc_m0m1_n0n1n2.GetLength(I4)),
                   ck::tensor_operation::element_wise::PassThrough{}};

        constexpr auto MIter = workspace_thread_desc_m0m1_n0n1n2.GetLength(I1);
        constexpr auto NIter = workspace_thread_desc_m0m1_n0n1n2.GetLength(I3);

        constexpr auto n1_step            = cluster_length_reduce.At(I2);
        constexpr auto d_grid_M1_fwd_step = make_multi_index(I0, I1, I0, I0, I0);
        constexpr auto d_grid_N1_fwd_step = make_multi_index(I0, I0, I0, n1_step, I0);
        constexpr auto d_grid_N1_bwd_step =
            make_multi_index(I0, I0, I0, -1 * n1_step * (NIter - 1), I0);

        constexpr auto thr_buf_N1_offset = Number<ScalarPerVector>{};
        constexpr auto thr_buf_M1_offset = NIter * thr_buf_N1_offset;

        static_for<0, MIter, 1>{}([&](auto m_idx) {
            static_for<0, NIter, 1>{}([&](auto n_idx) {
                // load multiple Ds:
                static_for<0, NumDTensor, 1>{}([&](auto d_idx) {
                    ds_grid_load(d_idx).Run(ds_grid_desc_m0m1_n0n1n2(d_idx),
                                            ds_grid_buf(d_idx),
                                            d_vgpr_buf_desc,
                                            make_tuple(I0, I0, I0, I0, I0),
                                            ds_thread_buf(d_idx));
                });

                constexpr auto acc_buf_offset =
                    m_idx * thr_buf_M1_offset + n_idx * thr_buf_N1_offset;

                // apply pointwise function
                static_for<0, ScalarPerVector, 1>{}([&](auto I) {
                    // get reference to src data
                    const auto src_data_ds_refs = generate_tie(
                        // return type should be lvalue
                        [&](auto iSrc) -> const auto& { return ds_thread_buf[iSrc][I]; },
                        Number<NumDTensor>{});

                    const auto src_data_refs = concat_tuple_of_reference(
                        tie(acc_buff[acc_buf_offset + I]), src_data_ds_refs);

                    // apply pointwise function
                    // pointwise function signature:
                    // element_op_(dst_data_refs[I0],
                    //             dst_data_refs[I1],
                    //             ...,
                    //             src_data_refs[I0],
                    //             src_data_refs[I1],
                    //             ...)
                    unpack2(cde_element_op, tie(aux_vgpr_buf(I)), src_data_refs);
                });

                e_grid_store.Run(workspace_thread_desc_m0m1_n0n1n2,
                                 make_tuple(I0, I0, I0, I0, I0),
                                 aux_vgpr_buf,
                                 e_grid_desc_m0m1_n0n1n2,
                                 e_grid_buf);

                if constexpr(n_idx != (NIter - 1))
                {
                    static_for<0, NumDTensor, 1>{}([&](auto d_idx) {
                        ds_grid_load(d_idx).MoveSrcSliceWindow(ds_grid_desc_m0m1_n0n1n2(d_idx),
                                                               d_grid_N1_fwd_step);
                    });
                    e_grid_store.MoveDstSliceWindow(e_grid_desc_m0m1_n0n1n2, d_grid_N1_fwd_step);
                }
                else
                {
                    static_for<0, NumDTensor, 1>{}([&](auto d_idx) {
                        ds_grid_load(d_idx).MoveSrcSliceWindow(ds_grid_desc_m0m1_n0n1n2(d_idx),
                                                               d_grid_N1_bwd_step);
                    });
                    e_grid_store.MoveDstSliceWindow(e_grid_desc_m0m1_n0n1n2, d_grid_N1_bwd_step);
                }
            }); // NIter

            static_for<0, NumDTensor, 1>{}([&](auto d_idx) {
                ds_grid_load(d_idx).MoveSrcSliceWindow(ds_grid_desc_m0m1_n0n1n2(d_idx),
                                                       d_grid_M1_fwd_step);
            });
            e_grid_store.MoveDstSliceWindow(e_grid_desc_m0m1_n0n1n2, d_grid_M1_fwd_step);
        }); // MIter
    }
};
```
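The reworked RunWrite is, at its core, a per-thread epilogue: for every vector it owns, the reduced accumulator and the corresponding D values are combined by the CDE elementwise operation and the result is written straight to E, with no LDS round trip. A hedged host-code sketch with the tensor-transfer machinery removed (the lambda is an example op, not the library's CDEElementwiseOperation):

```cpp
#include <cstddef>
#include <cstdio>
#include <vector>

template <typename CDEOp>
void run_write(std::vector<float>& e, const std::vector<float>& acc,
               const std::vector<float>& d0, CDEOp cde_element_op)
{
    for(std::size_t i = 0; i < e.size(); ++i)
        cde_element_op(e[i], acc[i], d0[i]); // e = op(acc, d)
}

int main()
{
    std::vector<float> e(4), acc = {1, 2, 3, 4}, bias = {10, 10, 10, 10};
    run_write(e, acc, bias, [](float& y, float c, float d) { y = c + d; }); // add-bias epilogue
    std::printf("%f %f\n", e[0], e[3]); // 11 14
}
```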