gaoqiong / composable_kernel_ROCM · Commits

Commit 333176c5
Authored May 21, 2024 by Adam Osewski
Parent: be48abdb

Draft changes to run gridwise gemm through multiple SplitK tiles

Showing 4 changed files with 320 additions and 258 deletions
example/15_grouped_gemm/grouped_gemm_multiple_d_splitk_xdl_single_kernel_fp16.cpp    +23  -10
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_tile_loop.hpp    +90  -60
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp    +2  -0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle_v2.hpp    +205  -188
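The hunks below read more easily with the overall scheme in mind: this draft moves the grouped split-K GEMM from "one K tile per gridwise-GEMM invocation" to "a run of consecutive K tiles folded into a single invocation per workgroup", with partial results stored to workspace and reduced afterwards. A minimal, hypothetical sketch of that flow follows; the WorkScheduler/TileMap types and the elided run_gemm call are illustrative stand-ins read off this diff, not the actual CK API.

// Simplified model of the per-workgroup tile loop targeted by this draft.
struct TileMap
{
    int k_idx = 0;
    void AdvanceTileKIdx(int k_tiles) { k_idx += k_tiles; } // mirrors block_to_ctile_map.hpp below
};

struct WorkScheduler
{
    // Stand-in: how many consecutive K chunks this workgroup should process next.
    int GetNextKTiles(int /*k_batch*/, int /*k_idx*/) const { return 1; }
};

void workgroup_step(WorkScheduler& scheduler, TileMap& map, int k_batch)
{
    // 1. Take a run of K tiles that still belong to the current [M,N] output tile.
    const int k_tiles = scheduler.GetNextKTiles(k_batch, map.k_idx);
    // 2. Accumulate all k_tiles chunks in registers in one RunGEMM call
    //    (previously: one do/while iteration per chunk).
    // 3. Jump the tile map to the last chunk processed.
    map.AdvanceTileKIdx(k_tiles - 1);
    // 4. Store the partial accumulator to workspace; peer workgroups reduce and write E.
}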
example/15_grouped_gemm/grouped_gemm_multiple_d_splitk_xdl_single_kernel_fp16.cpp

@@ -37,10 +37,11 @@ using BDataType = F16;
 using AccDataType      = F32;
 using CShuffleDataType = F32;
 using DsDataType       = ck::Tuple<>;
-using EDataType        = F32;
+using EDataType        = F16;
 using ALayout = Row;
 using BLayout = Col;
+// using BLayout = Row;
 using DsLayout = ck::Tuple<>;
 using ELayout  = Row;
...
...
@@ -56,7 +57,9 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemmMultip
 //######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
 //######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
 //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-        < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>;
+        < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmMNKPadding, 1, 128, 64, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, ck::PipelineVersion::v1>;
+// < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>;
+// < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmMNKPadding, 1, 128, 192, 32, 32, 8, 8, 32, 32, 3, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>;
 // clang-format on

 struct ProblemSize final
...
...
@@ -76,8 +79,9 @@ struct ExecutionConfig final
 {
     bool do_verification = true;
     int init_method      = 1;
-    int k_batch          = 128;
-    bool time_kernel     = false;
+    // int k_batch = 128;
+    int k_batch      = 1;
+    bool time_kernel = false;
 };

 bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
...
...
@@ -158,9 +162,15 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
         a_tensors[i].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
         b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
         break;
+    case 3:
+        ck::utils::FillConstant<ADataType>{1}(a_tensors[i]);
+        ck::utils::FillConstant<BDataType>{1}(b_tensors[i]);
+        break;
     default:
-        a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
-        b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
+        // a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
+        // b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
+        ck::utils::FillMonotonicSeq<ADataType>{0, 1}(a_tensors[i]);
+        ck::utils::FillMonotonicSeq<BDataType>{1, 1}(b_tensors[i]);
     }
 }
...
...
@@ -309,17 +319,20 @@ int main(int argc, char* argv[])
     if(argc < 11)
     {
-        std::vector<ck::index_t> Ms{64, 127, 255, 129, 260, 190, 77};
+        // std::vector<ck::index_t> Ms{64, 127, 255, 129, 260, 190, 77};
+        std::vector<ck::index_t> Ms{64};
         problem_size.group_count = Ms.size();

         for(int i = 0; i < problem_size.group_count; i++)
         {
             problem_size.Ms.push_back(Ms[i]);
-            problem_size.Ns.push_back(252);
+            // problem_size.Ns.push_back(252);
+            problem_size.Ns.push_back(256);
             problem_size.Ks.push_back(4608);
             problem_size.stride_As.push_back(problem_size.Ks[i]);
             problem_size.stride_Bs.push_back(problem_size.Ks[i]);
+            // problem_size.stride_Bs.push_back(problem_size.Ns[i]);
             problem_size.stride_Cs.push_back(problem_size.Ns[i]);
         }
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_tile_loop.hpp
...
...
@@ -131,30 +131,40 @@ __global__ void
         const auto StrideA = gemm_desc_ptr[group_id].StrideA;
         const auto StrideB = gemm_desc_ptr[group_id].StrideB;

-        results_buffer.Clear();
+        // results_buffer.Clear();
         b2c_tile_map.CalculateBottomIndex(work_scheduler.tile_id_ - offset);

         // Iterate over K dimension for this [M,N] tile
-        // still in the same GEMM && the same [M,N] tile
-        // TODO: change desc so that few K-tiles will be done in single GEMM.
-        do
-        {
-            // just accumulate results in registers!
-            GridwiseGemm::template RunGEMM<HasMainKBlockLoop>(p_a_grid,
-                                                              p_b_grid,
-                                                              static_cast<void*>(p_shared),
-                                                              a_element_op,
-                                                              b_element_op,
-                                                              M, N, K,
-                                                              StrideA, StrideB,
-                                                              k_batch,
-                                                              b2c_tile_map,
-                                                              results_buffer);
-        } while(work_scheduler.GetNextTile() && b2c_tile_map.GetNextKTileIdx());
+        // do
+        // {
+        auto k_tiles = work_scheduler.GetNextKTiles(k_batch, b2c_tile_map.GetTileKIdx());
+        // if (blockIdx.x < 4 && ck::debug::is_thread_local_1d_id_idx<0>())
+        // {
+        //     printf("bid: %d, k_tiles: %d\n",
+        //            static_cast<index_t>(blockIdx.x),
+        //            k_tiles);
+        // }
+        // just accumulate results in registers!
+        GridwiseGemm::template RunGEMM<HasMainKBlockLoop>(p_a_grid,
+                                                          p_b_grid,
+                                                          static_cast<void*>(p_shared),
+                                                          a_element_op,
+                                                          b_element_op,
+                                                          M, N, K,
+                                                          StrideA, StrideB,
+                                                          k_batch,
+                                                          b2c_tile_map,
+                                                          results_buffer,
+                                                          k_tiles);
+        // Move to the last processed k-tile
+        b2c_tile_map.AdvanceTileKIdx(k_tiles - 1);
+        // } while(work_scheduler.GetNextTile() && b2c_tile_map.GetNextKTileIdx());

         // if (changed group_id || next [M,N] tile)
         // With cshuffle at store partials all workgroups have to store
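A minimal scalar model of what the rewritten loop computes, assuming the scheduler semantics visible in this hunk (GetNextKTiles returns how many consecutive K chunks the workgroup folds into one RunGEMM call; elsewhere in this commit clear_c_thread_buf flips to true and num_k_block_main_loop becomes k_tiles, so the accumulator is cleared inside the call). The function below is an illustrative stand-in, not CK code.

// Stand-in for the per-thread accumulator behaviour of the new single RunGEMM call.
#include <vector>
float run_gemm_model(const std::vector<float>& k_chunk_partials, int k_start, int k_tiles)
{
    float acc = 0.0f;                          // accumulator cleared inside the call (clear_c_thread_buf = true)
    for(int t = 0; t < k_tiles; ++t)           // num_k_block_main_loop == k_tiles
        acc += k_chunk_partials[k_start + t];  // one KPerBlock-sized chunk per iteration
    return acc;                                // later stored to workspace as a split-K partial
}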
...
...
@@ -164,7 +174,7 @@ __global__ void
         // do CShuffle in flight with loading partials products of other peer workgroups.
         GridwiseGemm::StorePartials(p_workspace, static_cast<void*>(p_shared), results_buffer);

-#if 0
+#if 1
         // make sure all writes to gmem has finished.
         __builtin_amdgcn_s_waitcnt(0x0f70); // s_waitcnt vmcnt(0)
         // __builtin_amdgcn_s_waitcnt(0x0070); // s_waitcnt vmcnt(0) lgkmcnt(0)
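A small sketch of where the two waitcnt immediates above come from, assuming the gfx9-style s_waitcnt immediate layout (vmcnt[3:0] in bits 3:0, expcnt in bits 6:4, lgkmcnt in bits 11:8); under that assumption 0x0f70 waits only for vmcnt(0) and 0x0070 additionally waits for lgkmcnt(0), matching the inline comments in the diff. The helper is illustrative, not part of the commit.

// Hedged helper: rebuilds the immediates under the assumed gfx9 encoding.
constexpr unsigned waitcnt_imm(unsigned vmcnt, unsigned expcnt, unsigned lgkmcnt)
{
    return (vmcnt & 0xf) | ((expcnt & 0x7) << 4) | ((lgkmcnt & 0xf) << 8);
}
static_assert(waitcnt_imm(0, 7, 15) == 0x0f70, "wait on vmcnt(0) only");
static_assert(waitcnt_imm(0, 7, 0) == 0x0070, "wait on vmcnt(0) and lgkmcnt(0)");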
...
...
@@ -212,6 +222,11 @@ __global__ void
             p_ds_grid(i) = static_cast<const DDataType*>(gemm_desc_ptr[group_id].p_ds_grid[i]);
         });

+        // if (threadIdx.x == 0)
+        // {
+        //     p_e_grid[blockIdx.x] = 0;
+        // }
+
         GridwiseGemm::template RunWrite(p_ds_grid,
                                         p_e_grid,
                                         acc_buff,
...
...
@@ -497,29 +512,29 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
     {
         bool all_have_main_k_block_loop;
         {
-            const auto a_grid_desc_kbatch_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_KBatch_AK0_M_AK1(
-                gemm_kernel_args_[0].M, gemm_kernel_args_[0].K, gemm_kernel_args_[0].StrideA, K_BATCH);
+            const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
+                gemm_kernel_args_[0].M, gemm_kernel_args_[0].K, gemm_kernel_args_[0].StrideA, K_BATCH);
             all_have_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(
-                a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) * a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3));
+                a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2) / K_BATCH);
         }

         for(std::size_t i = 0; i < gemm_kernel_args_.size(); ++i)
         {
-            const auto& gemm_arg = gemm_kernel_args_[i];
-            auto kbatch          = K_BATCH;
-            const auto a_grid_desc_kbatch_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_KBatch_AK0_M_AK1(
-                gemm_arg.M, gemm_arg.K, gemm_arg.StrideA, kbatch);
+            const auto& gemm_arg = gemm_kernel_args_[i];
+            auto kbatch          = K_BATCH;
+            const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
+                gemm_arg.M, gemm_arg.K, gemm_arg.StrideA, kbatch);

             bool not_all_have_main_k_block_loop_same =
-                all_have_main_k_block_loop xor GridwiseGemm::CalculateHasMainKBlockLoop(
-                    a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) * a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3));
+                all_have_main_k_block_loop xor GridwiseGemm::CalculateHasMainKBlockLoop(
+                    a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2) / K_BATCH);

             if(not_all_have_main_k_block_loop_same)
             {
...
...
@@ -616,7 +631,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
                     void* dev_gemm_workspace,
                     const StreamConfig& stream_config = StreamConfig{})
     {
-        auto [all_have_kbatch_gt_one, all_have_main_k_block_loop] =
+        [[maybe_unused]] auto [all_have_kbatch_gt_one, all_have_main_k_block_loop] =
             CheckArgument(arg, stream_config);

         if(dev_gemm_args == nullptr)
...
...
@@ -698,17 +713,16 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
         bool all_have_kbatch_gt_one, all_have_main_k_block_loop;
         {
-            const auto a_grid_desc_kbatch_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_KBatch_AK0_M_AK1(
-                arg.gemm_kernel_args_[0].M, arg.gemm_kernel_args_[0].K, arg.gemm_kernel_args_[0].StrideA, arg.K_BATCH);
+            const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
+                arg.gemm_kernel_args_[0].M, arg.gemm_kernel_args_[0].K, arg.gemm_kernel_args_[0].StrideA, arg.K_BATCH);
             all_have_kbatch_gt_one     = arg.K_BATCH > 1;
             all_have_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(
-                a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) * a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3));
+                a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2) / kbatch);
         }

         for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
...
...
@@ -737,14 +751,14 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
                 throw std::runtime_error(err.str());
             }

-            const auto a_grid_desc_kbatch_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_KBatch_AK0_M_AK1(
-                gemm_arg.M, gemm_arg.K, gemm_arg.StrideA, kbatch);
+            const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
+                gemm_arg.M, gemm_arg.K, gemm_arg.StrideA, kbatch);

             bool not_all_have_main_k_block_loop_same =
-                all_have_main_k_block_loop xor GridwiseGemm::CalculateHasMainKBlockLoop(
-                    a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) * a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3));
+                all_have_main_k_block_loop xor GridwiseGemm::CalculateHasMainKBlockLoop(
+                    a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2) / kbatch);
             bool not_all_have_kbatch_value_same = all_have_kbatch_gt_one xor (kbatch > 1);

             if(not_all_have_main_k_block_loop_same)
...
...
@@ -853,8 +867,16 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
         }

         auto preprocess = [&]() {
+            // std::cout << "[preprocess] p_flags: " << p_flags
+            //           << ", flag count: " << flag_count
+            //           << ", bytes: " << flag_count * sizeof(uint32_t)
+            //           << ", stream id: " << stream_config.stream_id_
+            //           << std::endl;
             hip_check_error(hipMemsetAsync(
                 p_flags, 0, flag_count * sizeof(uint32_t), stream_config.stream_id_));
+            // TODO: For debug only!
+            hip_check_error(hipMemsetAsync(
+                dev_gemm_workspace, 2, acc_workspace_size_bytes, stream_config.stream_id_));
         };

         return launch_and_time_kernel_with_preprocess(
...
...
@@ -890,11 +912,12 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
         if((ck::type_convert<ck::index_t>(arg.gemm_kernel_args_.size()) + arg.skipped_group_count_) !=
            arg.group_count_)
         {
-#if DEBUG_LOG
-            std::cout << "The group count is not equal to sum of skipped groups "
-                         "and kernel args size!"
-                      << std::endl;
-#endif // DEBUG_LOG
+            if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
+            {
+                std::cout << "The group count is not equal to sum of skipped groups "
+                             "and kernel args size!"
+                          << std::endl;
+            }
             return false;
         }
@@ -913,11 +936,12 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
                                                              arg.K_BATCH);
             if(not group_arg_valid)
             {
-#if DEBUG_LOG
-                std::cout << "[" << __func__ << "] group id: " << i
-                          << " has invalid GridwiseGemm settings!" << std::endl;
-                gemm_arg.Print();
-#endif // DEBUG_LOG
+                if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
+                {
+                    std::cout << "[" << __func__ << "] group id: " << i
+                              << " has invalid GridwiseGemm settings!" << std::endl;
+                    gemm_arg.Print();
+                }
             }
             supported = supported && group_arg_valid;
         }
...
...
@@ -1043,6 +1067,12 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
         size_t size_bytes =
             Block2ETileMapKSplit::GetAccWorkspaceSize(sizeof(CShuffleDataType), grid_size) +
             flag_count * sizeof(uint32_t);

+        std::cout << "[GetWorkspaceSize]: "
+                  << "occ_grid_size: " << occ_grid_size << ", grid_size: " << grid_size
+                  << ", tiles_per_block: " << tiles_per_block << ", flag_count: " << flag_count
+                  << ", size_bytes: " << size_bytes << std::endl;
+
         return size_bytes;
     }
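The workspace the added log line reports is the sum of the split-K accumulation buffer and one uint32_t flag per synchronization slot. A minimal sketch of that arithmetic follows; acc_bytes is left as an opaque input because Block2ETileMapKSplit::GetAccWorkspaceSize's internals are not shown in this diff.

#include <cstddef>
#include <cstdint>
// acc_bytes stands in for Block2ETileMapKSplit::GetAccWorkspaceSize(sizeof(CShuffleDataType), grid_size).
std::size_t device_workspace_bytes(std::size_t acc_bytes, std::size_t flag_count)
{
    return acc_bytes + flag_count * sizeof(std::uint32_t); // partial-acc buffer + peer flags
}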
...
...
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
...
...
@@ -1531,6 +1531,8 @@ struct BlockToCTileMap_LinearKSplit
         return false;
     }

+    __host__ __device__ void AdvanceTileKIdx(index_t k_tiles) { K0_idx_ += k_tiles; }
+
     ///
     /// @brief Determines whether the current workgroup processed first tile in K dimension
     ///
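A small model of how the new AdvanceTileKIdx is used by the kernel in the tile-loop header above; the struct below is a simplified stand-in that only tracks the K0 index, not the real BlockToCTileMap_LinearKSplit.

// Simplified stand-in for the K-index bookkeeping of BlockToCTileMap_LinearKSplit.
struct LinearKSplitMapModel
{
    int K0_idx_ = 0;
    int GetTileKIdx() const { return K0_idx_; }
    void AdvanceTileKIdx(int k_tiles) { K0_idx_ += k_tiles; }
};

// Usage pattern from this commit: after folding k_tiles chunks that started at
// GetTileKIdx(), jump the map to the last chunk processed.
void after_run_gemm(LinearKSplitMapModel& map, int k_tiles) { map.AdvanceTileKIdx(k_tiles - 1); }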
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle_v2.hpp
...
...
@@ -57,7 +57,7 @@ template <typename ADataType,
           index_t NPerXdl,
           index_t MXdlPerWave,
           index_t NXdlPerWave,
-          typename ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1,
+          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
           typename ABlockTransferThreadClusterArrangeOrder,
           typename ABlockTransferSrcAccessOrder,
           index_t ABlockTransferSrcVectorDim,
@@ -65,7 +65,7 @@ template <typename ADataType,
           index_t ABlockTransferDstScalarPerVector_AK1,
           bool AThreadTransferSrcResetCoordinateAfterRun,
           index_t ABlockLdsExtraM,
-          typename BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1,
+          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
           typename BBlockTransferThreadClusterArrangeOrder,
           typename BBlockTransferSrcAccessOrder,
           index_t BBlockTransferSrcVectorDim,
...
...
@@ -81,13 +81,6 @@ template <typename ADataType,
           PipelineVersion PipelineVer>
 class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
 {
-    template <index_t... Ids>
-    __device__ static bool is_thread_local_1d_id_idx()
-    {
-        const auto tid = get_thread_local_1d_id();
-        return ((tid == Ids) || ...);
-    }
-
     static constexpr index_t NumDTensor = DsDataType::Size();

     using GemmSpecialization = ck::tensor_operation::device::GemmSpecialization;
...
...
@@ -132,28 +125,6 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
         return math::integer_least_multiple(K, KPerBlock * K_Batch);
     }

-    __host__ __device__ static constexpr auto GetABlockDescriptor_KBatch_AK0PerBlock_MPerBlock_AK1()
-    {
-        // A matrix in LDS memory, dst of blockwise copy
-        return make_naive_tensor_descriptor(
-            make_tuple(I1, AK0PerBlock, Number<MPerBlock>{}, AK1),
-            make_tuple(AK0PerBlock * Number<MPerBlock + ABlockLdsExtraM>{} * AK1,
-                       Number<MPerBlock + ABlockLdsExtraM>{} * AK1,
-                       AK1,
-                       I1));
-    }
-
-    __host__ __device__ static constexpr auto GetBBlockDescriptor_KBatch_BK0PerBlock_NPerBlock_BK1()
-    {
-        // B matrix in LDS memory, dst of blockwise copy
-        return make_naive_tensor_descriptor(
-            make_tuple(I1, BK0PerBlock, Number<NPerBlock>{}, BK1),
-            make_tuple(BK0PerBlock * Number<NPerBlock + BBlockLdsExtraN>{} * BK1,
-                       Number<NPerBlock + BBlockLdsExtraN>{} * BK1,
-                       BK1,
-                       I1));
-    }
-
     __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
     {
         // A matrix in LDS memory, dst of blockwise copy
@@ -171,7 +142,7 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
     }

     __host__ __device__ static auto
-    MakeAGridDescriptor_KBatch_AK0_M_AK1(index_t M, index_t K, index_t StrideA, index_t KBatch)
+    MakeAGridDescriptor_AK0_M_AK1(index_t M, index_t K, index_t StrideA, index_t KBatch)
     {
         const auto a_grid_desc_m_k = [&]() {
             if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
...
...
@@ -184,7 +155,6 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
             }
         }();

-        const auto MPad = CalculateMPadded(M);
         const auto KPad = CalculateKPadded(K, KBatch);

         const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
@@ -193,33 +163,34 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
             make_tuple(Sequence<0>{}, Sequence<1>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}));

-        const auto AK0 = KPad / (KBatch * AK1);
+        const auto AK0 = KPad / AK1;

         if constexpr(GemmSpec == GemmSpecialization::MPadding ||
                      GemmSpec == GemmSpecialization::MNPadding ||
                      GemmSpec == GemmSpecialization::MKPadding ||
                      GemmSpec == GemmSpecialization::MNKPadding)
         {
+            const auto MPad = CalculateMPadded(M);
             return transform_tensor_descriptor(
                 a_grid_desc_m_kpad,
-                make_tuple(make_unmerge_transform(make_tuple(KBatch, AK0, AK1)),
+                make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
                            make_right_pad_transform(M, MPad - M)),
                 make_tuple(Sequence<1>{}, Sequence<0>{}),
-                make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
+                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
         }
         else
         {
             return transform_tensor_descriptor(
                 a_grid_desc_m_kpad,
-                make_tuple(make_unmerge_transform(make_tuple(KBatch, AK0, AK1)),
+                make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
                            make_pass_through_transform(M)),
                 make_tuple(Sequence<1>{}, Sequence<0>{}),
-                make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
+                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
         }
     }

     __host__ __device__ static auto
-    MakeBGridDescriptor_KBatch_BK0_N_BK1(index_t K, index_t N, index_t StrideB, index_t KBatch)
+    MakeBGridDescriptor_BK0_N_BK1(index_t K, index_t N, index_t StrideB, index_t KBatch)
     {
         const auto b_grid_desc_k_n = [&]() {
             if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
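The net effect of the descriptor change above is that the grid descriptor loses its explicit KBatch dimension: the old 4-D [KBatch, AK0, M, AK1] view with AK0 = KPad / (KBatch * AK1) becomes a 3-D [AK0, M, AK1] view with AK0 = KPad / AK1, and the split-K offset is applied later through the blockwise copy's start index instead. A small, self-contained sketch of that index arithmetic (the numbers are illustrative, not taken from the commit):

#include <cassert>

int main()
{
    const int KPad = 256, AK1 = 8, KBatch = 4;
    const int AK0_old = KPad / (KBatch * AK1); // per-batch K0 extent of the old 4-D descriptor
    const int AK0_new = KPad / AK1;            // full K0 extent of the new 3-D descriptor

    // Old addressing: element (kb, k0, m, k1); new addressing: (kb * AK0_old + k0, m, k1).
    for(int kb = 0; kb < KBatch; ++kb)
        for(int k0 = 0; k0 < AK0_old; ++k0)
            assert(kb * AK0_old + k0 < AK0_new); // same K elements, one flattened K0 axis

    return 0;
}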
...
...
@@ -241,7 +212,7 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
             make_tuple(Sequence<0>{}, Sequence<1>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}));

-        const auto BK0 = KPad / (KBatch * BK1);
+        const auto BK0 = KPad / BK1;

         if constexpr(GemmSpec == GemmSpecialization::NPadding ||
                      GemmSpec == GemmSpecialization::MNPadding ||
@@ -251,32 +222,30 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
             // const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
             return transform_tensor_descriptor(
                 b_grid_desc_kpad_n,
-                make_tuple(make_unmerge_transform(make_tuple(KBatch, BK0, BK1)),
+                make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
                            make_right_pad_transform(N, NPad - N)),
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
+                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
         }
         else
         {
             return transform_tensor_descriptor(
                 b_grid_desc_kpad_n,
-                make_tuple(make_unmerge_transform(make_tuple(KBatch, BK0, BK1)),
+                make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
                            make_pass_through_transform(N)),
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
+                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
         }
     }

     private:
-    using AGridDesc_KBatch_AK0_M_AK1 = remove_cvref_t<decltype(MakeAGridDescriptor_KBatch_AK0_M_AK1(1, 1, 1, 1))>;
-    using BGridDesc_KBatch_BK0_N_BK1 = remove_cvref_t<decltype(MakeBGridDescriptor_KBatch_BK0_N_BK1(1, 1, 1, 1))>;
+    using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1, 1))>;
+    using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1, 1))>;

-    using ABlockDesc_KBatch_AK0PerB_MPerB_AK1 =
-        remove_cvref_t<decltype(GetABlockDescriptor_KBatch_AK0PerBlock_MPerBlock_AK1())>;
-    using BBlockDesc_KBatch_BK0PerB_NPerB_BK1 =
-        remove_cvref_t<decltype(GetBBlockDescriptor_KBatch_BK0PerBlock_NPerBlock_BK1())>;
+    using ABlockDesc_AK0PerB_MPerB_AK1 =
+        remove_cvref_t<decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())>;
+    using BBlockDesc_BK0PerB_NPerB_BK1 =
+        remove_cvref_t<decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())>;

     public:
     __host__ __device__ static constexpr auto
...
...
@@ -423,11 +392,9 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
                           const index_t StrideE,
                           const index_t KBatch)
     {
-        const auto a_grid_desc_kbatch_ak0_m_ak1 = MakeAGridDescriptor_KBatch_AK0_M_AK1(M, K, StrideA, KBatch);
-        const auto b_grid_desc_kbatch_bk0_n_bk1 = MakeBGridDescriptor_KBatch_BK0_N_BK1(K, N, StrideB, KBatch);
-        const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N<ELayout>(M, N, StrideE);
+        const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(M, K, StrideA, KBatch);
+        const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(K, N, StrideB, KBatch);
+        const auto e_grid_desc_m_n       = MakeEGridDescriptor_M_N<ELayout>(M, N, StrideE);

         if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
                        GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
...
...
@@ -436,12 +403,12 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
         {
             if(!(M % MPerBlock == 0))
             {
-#if DEBUG_LOG
-                std::cout << "Arg M value is not a multiple of MPerBlock! M: " << M << " " << __FILE__
-                          << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
-#endif // DEBUG_LOG
+                if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
+                {
+                    std::cout << "Arg M value is not a multiple of MPerBlock! M: " << M << " "
+                              << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
+                              << std::endl;
+                }
                 return false;
             }
         }
@@ -453,12 +420,13 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
         {
             if(!(N % NPerBlock == 0))
             {
-#if DEBUG_LOG
-                std::cout << "Arg N value is not a multiple of NPerBlock! N: " << N << " " << __FILE__
-                          << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
-#endif // DEBUG_LOG
+                if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
+                {
+                    std::cout << "Arg N value is not a multiple of NPerBlock! N: " << N << " "
+                              << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
+                              << std::endl;
+                }
                 return false;
             }
         }
@@ -471,12 +439,12 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
             auto K_t = KBatch * KPerBlock;
             if(!(K % K_t == 0))
             {
-#if DEBUG_LOG
-                std::cout << "Arg K value is not a multiple of ! KBatch * KPerBlock: " << K << " " << __FILE__
-                          << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
-#endif // DEBUG_LOG
+                if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
+                {
+                    std::cout << "Arg K value is not a multiple of ! KBatch * KPerBlock: " << K << " "
+                              << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
+                              << std::endl;
+                }
                 return false;
             }
         }
@@ -485,13 +453,13 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
         {
             if(K % ABlockTransferSrcScalarPerVector != 0)
             {
-#if DEBUG_LOG
-                std::cout << "Arg K (" << K
-                          << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
-                          << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" << __LINE__
-                          << ", in function: " << __func__ << std::endl;
-#endif // DEBUG_LOG
+                if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
+                {
+                    std::cout << "Arg K (" << K
+                              << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
+                              << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
+                              << __LINE__ << ", in function: " << __func__ << std::endl;
+                }
                 return false;
             }
         }
@@ -499,13 +467,13 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
         {
             if(M % ABlockTransferSrcScalarPerVector != 0)
             {
-#if DEBUG_LOG
-                std::cout << "Arg M (" << M
-                          << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
-                          << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" << __LINE__
-                          << ", in function: " << __func__ << std::endl;
-#endif // DEBUG_LOG
+                if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
+                {
+                    std::cout << "Arg M (" << M
+                              << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
+                              << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
+                              << __LINE__ << ", in function: " << __func__ << std::endl;
+                }
                 return false;
             }
         }
@@ -514,13 +482,13 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
         {
             if(N % BBlockTransferSrcScalarPerVector != 0)
             {
-#if DEBUG_LOG
-                std::cout << "Arg N (" << N
-                          << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
-                          << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" << __LINE__
-                          << ", in function: " << __func__ << std::endl;
-#endif // DEBUG_LOG
+                if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
+                {
+                    std::cout << "Arg N (" << N
+                              << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
+                              << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
+                              << __LINE__ << ", in function: " << __func__ << std::endl;
+                }
                 return false;
             }
         }
@@ -528,13 +496,13 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
         {
             if(K % BBlockTransferSrcScalarPerVector != 0)
             {
-#if DEBUG_LOG
-                std::cout << "Arg K (" << K
-                          << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
-                          << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" << __LINE__
-                          << ", in function: " << __func__ << std::endl;
-#endif // DEBUG_LOG
+                if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
+                {
+                    std::cout << "Arg K (" << K
+                              << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
+                              << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
+                              << __LINE__ << ", in function: " << __func__ << std::endl;
+                }
                 return false;
             }
         }
@@ -543,14 +511,15 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
         {
             if(N % CDEShuffleBlockTransferScalarPerVector_NPerBlock != 0)
             {
-#if DEBUG_LOG
-                std::cout << "Arg N (" << N
-                          << ") value is not a multiple of "
-                             "CDEShuffleBlockTransferScalarPerVector_NPerBlock ("
-                          << CDEShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__
-                          << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
-#endif // DEBUG_LOG
+                if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
+                {
+                    std::cout << "Arg N (" << N
+                              << ") value is not a multiple of "
+                                 "CDEShuffleBlockTransferScalarPerVector_NPerBlock ("
+                              << CDEShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__
+                              << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
+                }
                 return false;
             }
         }
...
...
@@ -558,31 +527,33 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
         {
             if(M % CDEShuffleBlockTransferScalarPerVector_NPerBlock != 0)
             {
-#if DEBUG_LOG
-                std::cout << "Arg M (" << M
-                          << ") value is not a multiple of "
-                             "CDEShuffleBlockTransferScalarPerVector_NPerBlock ("
-                          << CDEShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__
-                          << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
-#endif // DEBUG_LOG
+                if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
+                {
+                    std::cout << "Arg M (" << M
+                              << ") value is not a multiple of "
+                                 "CDEShuffleBlockTransferScalarPerVector_NPerBlock ("
+                              << CDEShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__
+                              << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
+                }
                 return false;
             }
         }

         // check gridwise gemm pipeline
-        const auto num_k_loop = (a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) *
-                                 a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3)) /
-                                KPerBlock;
+        const auto num_k_loop = (a_grid_desc_ak0_m_ak1.GetLength(I0) *
+                                 a_grid_desc_ak0_m_ak1.GetLength(I2)) /
+                                (KPerBlock * KBatch);

         if(!GridwiseGemmPipe::IsSupported(num_k_loop))
         {
-#if DEBUG_LOG
-            std::cout << "The number of k loops (" << num_k_loop
-                      << ") value is not supported by GridwiseGemm Pipeline."
-                      << " K0Padded: " << a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) << __FILE__ << ":"
-                      << __LINE__ << ", in function: " << __func__ << std::endl;
-#endif // DEBUG_LOG
+            if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
+            {
+                std::cout << "The number of k loops (" << num_k_loop
+                          << ") value is not supported by GridwiseGemm Pipeline."
+                          << " K0Padded: " << a_grid_desc_ak0_m_ak1.GetLength(I1) << __FILE__ << ":"
+                          << __LINE__ << ", in function: " << __func__ << std::endl;
+            }
             return false;
         }
...
...
@@ -590,8 +561,8 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
         // check tensor size: cannot be larger than 2GB each
         constexpr long_index_t TwoGB = (long_index_t{1} << 31);

-        if(!(a_grid_desc_kbatch_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
-             b_grid_desc_kbatch_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB &&
+        if(!(a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
+             b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB &&
              e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB))
         {
             return false;
...
...
@@ -681,16 +652,17 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
                                        void* __restrict__ p_shared,
                                        const AElementwiseOperation& a_element_op,
                                        const BElementwiseOperation& b_element_op,
-                                       const AGridDesc_KBatch_AK0_M_AK1& a_grid_desc_kbatch_ak0_m_ak1,
-                                       const BGridDesc_KBatch_BK0_N_BK1& b_grid_desc_kbatch_bk0_n_bk1,
+                                       const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
+                                       const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
                                        const Block2ETileMap& block_2_etile_map,
-                                       CThreadBuf& c_thread_buf)
+                                       CThreadBuf& c_thread_buf,
+                                       const index_t k_tiles)
     {
         const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_a_grid, a_grid_desc_kbatch_ak0_m_ak1.GetElementSpaceSize());
+            p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
         const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_b_grid, b_grid_desc_kbatch_bk0_n_bk1.GetElementSpaceSize());
+            p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());

         // divide block work by [M, N, K]
         const auto block_work_idx = block_2_etile_map.GetBottomIndex();
...
...
@@ -701,33 +673,27 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
         const index_t n_block_data_idx_on_grid =
             __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);

-        // lds max alignment
-        constexpr auto max_lds_align = math::lcm(AK1, BK1);
-
         // A matrix in LDS memory, dst of blockwise copy
-        constexpr auto a_block_desc_kbatch_ak0_m_ak1 = GetABlockDescriptor_KBatch_AK0PerBlock_MPerBlock_AK1();
+        constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();

         // B matrix in LDS memory, dst of blockwise copy
-        constexpr auto b_block_desc_kbatch_bk0_n_bk1 = GetBBlockDescriptor_KBatch_BK0PerBlock_NPerBlock_BK1();
+        constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();

         using ABlockwiseCopy = ThreadGroupTensorSliceTransfer_v4r1<
             ThisThreadBlock,
             AElementwiseOperation,
             ck::tensor_operation::element_wise::PassThrough,
             InMemoryDataOperationEnum::Set,
-            Sequence<1, AK0PerBlock, MPerBlock, AK1>,
-            ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1,
+            Sequence<AK0PerBlock, MPerBlock, AK1>,
+            ABlockTransferThreadClusterLengths_AK0_M_AK1,
             ABlockTransferThreadClusterArrangeOrder,
             ADataType,
             ComputeType,
-            AGridDesc_KBatch_AK0_M_AK1,
-            ABlockDesc_KBatch_AK0PerB_MPerB_AK1,
+            AGridDesc_AK0_M_AK1,
+            ABlockDesc_AK0PerB_MPerB_AK1,
             ABlockTransferSrcAccessOrder,
-            Sequence<2, 0, 1, 3>,
+            Sequence<1, 0, 2>,
             ABlockTransferSrcVectorDim,
-            3,
+            2,
             ABlockTransferSrcScalarPerVector,
             ABlockTransferDstScalarPerVector_AK1,
             1,
@@ -741,17 +707,17 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
             BElementwiseOperation,
             ck::tensor_operation::element_wise::PassThrough,
             InMemoryDataOperationEnum::Set,
-            Sequence<1, BK0PerBlock, NPerBlock, BK1>,
-            BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1,
+            Sequence<BK0PerBlock, NPerBlock, BK1>,
+            BBlockTransferThreadClusterLengths_BK0_N_BK1,
             BBlockTransferThreadClusterArrangeOrder,
             BDataType,
             ComputeType,
-            BGridDesc_KBatch_BK0_N_BK1,
-            BBlockDesc_KBatch_BK0PerB_NPerB_BK1,
+            BGridDesc_BK0_N_BK1,
+            BBlockDesc_BK0PerB_NPerB_BK1,
             BBlockTransferSrcAccessOrder,
-            Sequence<2, 0, 1, 3>,
+            Sequence<1, 0, 2>,
             BBlockTransferSrcVectorDim,
-            3,
+            2,
             BBlockTransferSrcScalarPerVector,
             BBlockTransferDstScalarPerVector_BK1,
             1,
...
...
@@ -760,30 +726,35 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
             true,
             NumGemmKPrefetchStage>;

+        const index_t ak0_start_idx = kbatch_id * AK0PerBlock;
+        const index_t bk0_start_idx = kbatch_id * BK0PerBlock;
+
+        if(blockIdx.x < 4 && ck::debug::is_thread_local_1d_id_idx<0>())
+        {
+            printf("[RunGEMM] bid: %d, ak0_start_idx: %d, bk0_start_idx: %d\n",
+                   static_cast<index_t>(blockIdx.x),
+                   ak0_start_idx,
+                   bk0_start_idx);
+        }
+
         // A matrix blockwise copy
         auto a_blockwise_copy =
-            ABlockwiseCopy(a_grid_desc_kbatch_ak0_m_ak1,
-                           make_multi_index(kbatch_id, 0, m_block_data_idx_on_grid, 0),
+            ABlockwiseCopy(a_grid_desc_ak0_m_ak1,
+                           make_multi_index(ak0_start_idx, m_block_data_idx_on_grid, 0),
                            a_element_op,
-                           a_block_desc_kbatch_ak0_m_ak1,
-                           make_multi_index(0, 0, 0, 0),
+                           a_block_desc_ak0_m_ak1,
+                           make_multi_index(0, 0, 0),
                            ck::tensor_operation::element_wise::PassThrough{});

         // B matrix blockwise copy
         auto b_blockwise_copy =
-            BBlockwiseCopy(b_grid_desc_kbatch_bk0_n_bk1,
-                           make_multi_index(kbatch_id, 0, n_block_data_idx_on_grid, 0),
+            BBlockwiseCopy(b_grid_desc_bk0_n_bk1,
+                           make_multi_index(bk0_start_idx, n_block_data_idx_on_grid, 0),
                            b_element_op,
-                           b_block_desc_kbatch_bk0_n_bk1,
-                           make_multi_index(0, 0, 0, 0),
+                           b_block_desc_bk0_n_bk1,
+                           make_multi_index(0, 0, 0),
                            ck::tensor_operation::element_wise::PassThrough{});

-        // A matrix in LDS memory, dst of blockwise copy
-        constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
-
-        // B matrix in LDS memory, dst of blockwise copy
-        constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
-
         // GEMM definition
         // c_mtx += transpose(a_mtx) * b_mtx
         // a_mtx[K0PerBlock, MPerBlock] is in LDS
...
...
@@ -792,6 +763,9 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
         // register
         // auto& c_thread_buf = blockwise_gemm_.GetCThreadBuffer();

+        // lds max alignment
+        constexpr auto max_lds_align = math::lcm(AK1, BK1);
+
         // LDS allocation for A and B: be careful of alignment
         constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
             a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
@@ -803,19 +777,27 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
             static_cast<ComputeType*>(p_shared) + a_block_space_size_aligned,
             b_block_desc_bk0_n_bk1.GetElementSpaceSize());

-        constexpr auto a_block_slice_copy_step = make_multi_index(0, KPerBlock / AK1, 0, 0);
-        constexpr auto b_block_slice_copy_step = make_multi_index(0, KPerBlock / BK1, 0, 0);
+        constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
+        constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);

         // gridwise GEMM pipeline
         const auto gridwise_gemm_pipeline =
             GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>();

-        const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
-            (a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) * a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3)) /
-            KPerBlock);
+        // TODO: what if AK1 != BK1 ???
+        const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(k_tiles);
+        // __builtin_amdgcn_readfirstlane((a_grid_desc_ak0_m_ak1.GetLength(I1) *
+        //                                 a_grid_desc_ak0_m_ak1.GetLength(I3)) /
+        //                                KPerBlock);
+
+        if(blockIdx.x < 4 && ck::debug::is_thread_local_1d_id_idx<0>())
+        {
+            printf("[RunGEMM] bid: %d, num_k_block_main_loop %d\n",
+                   static_cast<index_t>(blockIdx.x),
+                   num_k_block_main_loop);
+        }

-        bool clear_c_thread_buf = false;
+        bool clear_c_thread_buf = true;

         auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<BlockSize,
...
...
@@ -831,14 +813,14 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
                                                         KPack,
                                                         LoopSched>();

-        gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_kbatch_ak0_m_ak1,
-                                                               a_block_desc_kbatch_ak0_m_ak1,
+        gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
+                                                               a_block_desc_ak0_m_ak1,
                                                                a_blockwise_copy,
                                                                a_grid_buf,
                                                                a_block_buf,
                                                                a_block_slice_copy_step,
-                                                               b_grid_desc_kbatch_bk0_n_bk1,
-                                                               b_block_desc_kbatch_bk0_n_bk1,
+                                                               b_grid_desc_bk0_n_bk1,
+                                                               b_block_desc_bk0_n_bk1,
                                                                b_blockwise_copy,
                                                                b_grid_buf,
                                                                b_block_buf,
@@ -862,27 +844,26 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
                                       const index_t StrideB,
                                       const index_t KBatch,
                                       const Block2ETileMap& block_2_etile_map,
-                                      CThreadBuf& c_thread_buf)
+                                      CThreadBuf& c_thread_buf,
+                                      const index_t k_tiles)
     {
         const auto p_a_grid = reinterpret_cast<const ADataType*>(p_a_grid_);
         const auto p_b_grid = reinterpret_cast<const BDataType*>(p_b_grid_);

         // tensor descriptors for block/thread-wise copy
-        const auto a_grid_desc_kbatch_ak0_m_ak1 = MakeAGridDescriptor_KBatch_AK0_M_AK1(M, K, StrideA, KBatch);
-        const auto b_grid_desc_kbatch_bk0_n_bk1 = MakeBGridDescriptor_KBatch_BK0_N_BK1(K, N, StrideB, KBatch);
+        const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(M, K, StrideA, KBatch);
+        const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(K, N, StrideB, KBatch);

         RunGEMM<HasMainKBlockLoop>(p_a_grid,
                                    p_b_grid,
                                    p_shared,
                                    a_element_op,
                                    b_element_op,
-                                   a_grid_desc_kbatch_ak0_m_ak1,
-                                   b_grid_desc_kbatch_bk0_n_bk1,
+                                   a_grid_desc_ak0_m_ak1,
+                                   b_grid_desc_bk0_n_bk1,
                                    block_2_etile_map,
-                                   c_thread_buf);
+                                   c_thread_buf,
+                                   k_tiles);
     }

     template <typename CThreadBuf>
...
...
@@ -1247,6 +1228,27 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
             acc_load.MoveSrcSliceWindow(workspace_grid_desc_m0m1_n0n1n2, partial_acc_load_step);
         }
+
+        // if(is_thread_local_1d_id_idx<0, 1, 8, 39>())
+        // {
+        //     printf("[bid: %d, tid: %d], {Accumulate Partials} AccBuf v[0, 0, 0, 0, 0-3]: [%f, %f,"
+        //            "%f, %f]\n",
+        //            static_cast<index_t>(blockIdx.x),
+        //            static_cast<index_t>(threadIdx.x),
+        //            static_cast<float>(acc_buff[Number<0>{}]),
+        //            static_cast<float>(acc_buff[Number<1>{}]),
+        //            static_cast<float>(acc_buff[Number<2>{}]),
+        //            static_cast<float>(acc_buff[Number<3>{}]));
+        //     printf("[bid: %d, tid: %d], {Accumulate Partials} AccBuf v[0, 0, 0, 1, 0-3]: [%f, %f,"
+        //            "%f, %f]\n",
+        //            static_cast<index_t>(blockIdx.x),
+        //            static_cast<index_t>(threadIdx.x),
+        //            static_cast<float>(acc_buff[Number<8>{}]),
+        //            static_cast<float>(acc_buff[Number<9>{}]),
+        //            static_cast<float>(acc_buff[Number<10>{}]),
+        //            static_cast<float>(acc_buff[Number<11>{}]));
+        // }
     }

     template <typename Block2ETileMap, typename AccumulationBuffer>
@@ -1411,6 +1413,21 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
             unpack2(cde_element_op, tie(aux_vgpr_buf(I)), src_data_refs);
         });

+        // if(is_thread_local_1d_id_idx<0, 1, 8, 39>())
+        // {
+        //     printf("[bid: %d, tid: %d, m_iter: %d, n_iter: %d], {RunWrite} AuxBuf v[0-3]: "
+        //            " [%f, %f, %f, %f]\n",
+        //            static_cast<index_t>(blockIdx.x),
+        //            static_cast<index_t>(threadIdx.x),
+        //            m_idx.value,
+        //            n_idx.value,
+        //            static_cast<float>(aux_vgpr_buf[Number<0>{}]),
+        //            static_cast<float>(aux_vgpr_buf[Number<1>{}]),
+        //            static_cast<float>(aux_vgpr_buf[Number<2>{}]),
+        //            static_cast<float>(aux_vgpr_buf[Number<3>{}]));
+        // }
+
         e_grid_store.Run(workspace_thread_desc_m0m1_n0n1n2,
                          make_tuple(I0, I0, I0, I0, I0),
                          aux_vgpr_buf,
...
...