Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
af469e6b
Commit
af469e6b
authored
Mar 19, 2024
by
Adam Osewski
Browse files
Allocate CThreadBuffer on global function level.
* Drop support for MI100. * Make GridwiseGEMM static without members.
parent
9205784f
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
101 additions
and
101 deletions
+101
-101
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_tile_loop.hpp
...grouped_gemm_multiple_d_splitk_xdl_cshuffle_tile_loop.hpp
+12
-11
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle_v2.hpp
.../grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle_v2.hpp
+89
-90
No files found.
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_tile_loop.hpp
View file @
af469e6b
...
...
@@ -12,7 +12,6 @@
#include "ck/host_utility/hip_check_error.hpp"
#include "ck/host_utility/stream_utility.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/utility/tuple.hpp"
#include <ck/utility/work_scheduling.hpp>
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
...
...
@@ -70,8 +69,7 @@ __global__ void
const
BElementwiseOperation
b_element_op
,
const
CDEElementwiseOperation
cde_element_op
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx94__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx90a__) || defined(__gfx94__))
constexpr
index_t
shared_size
=
GridwiseGemm
::
GetSharedMemoryNumberOfByte
();
__shared__
uint8_t
p_shared
[
shared_size
];
...
...
@@ -101,7 +99,12 @@ __global__ void
index_t
gemm_tile_id_start
=
0
;
index_t
gemm_tile_id_end
=
grid_size_grp
;
auto
gridwise_gemm
=
GridwiseGemm
();
StaticBufferTupleOfVector
<
AddressSpaceEnum
::
Vgpr
,
typename
GridwiseGemm
::
AccType
,
GridwiseGemm
::
GetMPerXdl
()
*
GridwiseGemm
::
GetNPerXdl
(),
GridwiseGemm
::
GetCThreadBufferVectorSize
(),
true
>
results_buffer
;
do
{
...
...
@@ -128,10 +131,8 @@ __global__ void
const
auto
StrideA
=
gemm_desc_ptr
[
group_id
].
StrideA
;
const
auto
StrideB
=
gemm_desc_ptr
[
group_id
].
StrideB
;
using
VGPRBufferT
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
GetCThreadBuffer
())
>
;
auto
results_buffer
=
VGPRBufferT
{};
b2c_tile_map
.
CalculateBottomIndex
(
work_scheduler
.
tile_id_
-
offset
);
results_buffer
.
Clear
();
b2c_tile_map
.
CalculateBottomIndex
(
work_scheduler
.
tile_id_
-
offset
);
// Iterate over K dimension for this [M,N] tile
// still in the same GEMM && the same [M,N] tile
...
...
@@ -139,7 +140,7 @@ __global__ void
do
{
// just accumulate results in registers!
g
ridwise
_g
emm
.
template
RunGEMM
<
HasMainKBlockLoop
>(
p_a_grid
,
G
ridwise
G
emm
::
template
RunGEMM
<
HasMainKBlockLoop
>(
p_a_grid
,
p_b_grid
,
static_cast
<
void
*>
(
p_shared
),
a_element_op
,
...
...
@@ -162,7 +163,7 @@ __global__ void
// if (changed group_id || next [M,N] tile)
if
(
!
b2c_tile_map
.
IsFirstKSplitBlock
())
{
g
ridwise
_g
emm
.
StorePartials
(
p_workspace
,
results_buffer
);
G
ridwise
G
emm
::
StorePartials
(
p_workspace
,
results_buffer
);
}
work_scheduler
.
FlagFinished
(
k_batch
,
output_tile_idx
,
output_tile_idx_offset
);
...
...
@@ -177,7 +178,7 @@ __global__ void
// Accumulate only when there is at least two workgroups processing splitk data-tiles
// across same MN-output tile.
if
(
neighbour_count
>
1
)
g
ridwise
_g
emm
.
AccumulatePartials
(
p_workspace
,
results_buffer
,
neighbour_count
);
G
ridwise
G
emm
::
AccumulatePartials
(
p_workspace
,
results_buffer
,
neighbour_count
);
// Signal waiting blocks that they can start use their workspace.
work_scheduler
.
Reset
(
k_batch
,
output_tile_idx
,
output_tile_idx_offset
);
...
...
@@ -196,7 +197,7 @@ __global__ void
p_ds_grid
(
i
)
=
static_cast
<
const
DDataType
*>
(
gemm_desc_ptr
[
group_id
].
p_ds_grid
[
i
]);
});
g
ridwise
_g
emm
.
template
RunWrite
(
p_ds_grid
,
G
ridwise
G
emm
::
template
RunWrite
(
p_ds_grid
,
p_e_grid
,
static_cast
<
void
*
>(
p_shared
),
M
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle_v2.hpp
View file @
af469e6b
...
...
@@ -269,54 +269,6 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
using
BBlockDesc_KBatch_BK0PerB_NPerB_BK1
=
remove_cvref_t
<
decltype
(
GetBBlockDescriptor_KBatch_BK0PerBlock_NPerBlock_BK1
())
>
;
using
ABlockwiseCopy
=
ThreadGroupTensorSliceTransfer_v4r1
<
ThisThreadBlock
,
AElementwiseOperation
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum
::
Set
,
Sequence
<
1
,
AK0PerBlock
,
MPerBlock
,
AK1
>
,
ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1
,
ABlockTransferThreadClusterArrangeOrder
,
ADataType
,
ComputeType
,
AGridDesc_KBatch_AK0_M_AK1
,
ABlockDesc_KBatch_AK0PerB_MPerB_AK1
,
ABlockTransferSrcAccessOrder
,
Sequence
<
2
,
0
,
1
,
3
>
,
ABlockTransferSrcVectorDim
,
3
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_AK1
,
1
,
1
,
AThreadTransferSrcResetCoordinateAfterRun
,
true
,
NumGemmKPrefetchStage
>
;
using
BBlockwiseCopy
=
ThreadGroupTensorSliceTransfer_v4r1
<
ThisThreadBlock
,
BElementwiseOperation
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum
::
Set
,
Sequence
<
1
,
BK0PerBlock
,
NPerBlock
,
BK1
>
,
BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1
,
BBlockTransferThreadClusterArrangeOrder
,
BDataType
,
ComputeType
,
BGridDesc_KBatch_BK0_N_BK1
,
BBlockDesc_KBatch_BK0PerB_NPerB_BK1
,
BBlockTransferSrcAccessOrder
,
Sequence
<
2
,
0
,
1
,
3
>
,
BBlockTransferSrcVectorDim
,
3
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_BK1
,
1
,
1
,
BThreadTransferSrcResetCoordinateAfterRun
,
true
,
NumGemmKPrefetchStage
>
;
public:
__host__
__device__
static
constexpr
auto
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
()
...
...
@@ -664,13 +616,12 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
}
}
// TODO: we should refactor out all those common Make... descriptors to sth like
// gridwise_gemm_utils.hpp
__device__
__host__
static
constexpr
auto
GetMPerBlock
()
{
return
MPerBlock
;
}
__device__
__host__
static
constexpr
auto
GetNPerBlock
()
{
return
NPerBlock
;
}
__device__
__host__
static
constexpr
auto
GetMPerXdl
()
{
return
MPerXdl
;
}
__device__
__host__
static
constexpr
auto
GetNPerXdl
()
{
return
NPerXdl
;
}
__device__
__host__
static
constexpr
auto
&
GetCThreadBuffer
()
__device__
static
constexpr
auto
GetCThreadBuffer
VectorSize
()
{
using
BlockwiseGemmT
=
remove_cvref_t
<
decltype
(
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector
<
...
...
@@ -686,20 +637,19 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
NXdlPerWave
,
KPack
,
LoopSched
>
())
>
;
BlockwiseGemmT
blockwise_gemm
;
return
blockwise_gemm
.
GetCThreadBuffer
();
return
BlockwiseGemmT
::
xdlops_gemm
.
GetRegSizePerXdlops
();
}
template
<
bool
HasMainKBlockLoop
,
typename
Block2ETileMap
,
typename
CThreadBuf
>
__device__
void
RunGEMM
(
const
ADataType
*
__restrict__
p_a_grid
,
const
BDataType
*
__restrict__
p_b_grid
,
void
*
__restrict__
p_shared
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
AGridDesc_KBatch_AK0_M_AK1
&
a_grid_desc_kbatch_ak0_m_ak1
,
const
BGridDesc_KBatch_BK0_N_BK1
&
b_grid_desc_kbatch_bk0_n_bk1
,
const
Block2ETileMap
&
block_2_etile_map
,
CThreadBuf
&
c_thread_buf
)
__device__
static
void
RunGEMM
(
const
ADataType
*
__restrict__
p_a_grid
,
const
BDataType
*
__restrict__
p_b_grid
,
void
*
__restrict__
p_shared
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
AGridDesc_KBatch_AK0_M_AK1
&
a_grid_desc_kbatch_ak0_m_ak1
,
const
BGridDesc_KBatch_BK0_N_BK1
&
b_grid_desc_kbatch_bk0_n_bk1
,
const
Block2ETileMap
&
block_2_etile_map
,
CThreadBuf
&
c_thread_buf
)
{
const
auto
a_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_a_grid
,
a_grid_desc_kbatch_ak0_m_ak1
.
GetElementSpaceSize
());
...
...
@@ -727,6 +677,54 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
constexpr
auto
b_block_desc_kbatch_bk0_n_bk1
=
GetBBlockDescriptor_KBatch_BK0PerBlock_NPerBlock_BK1
();
using
ABlockwiseCopy
=
ThreadGroupTensorSliceTransfer_v4r1
<
ThisThreadBlock
,
AElementwiseOperation
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum
::
Set
,
Sequence
<
1
,
AK0PerBlock
,
MPerBlock
,
AK1
>
,
ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1
,
ABlockTransferThreadClusterArrangeOrder
,
ADataType
,
ComputeType
,
AGridDesc_KBatch_AK0_M_AK1
,
ABlockDesc_KBatch_AK0PerB_MPerB_AK1
,
ABlockTransferSrcAccessOrder
,
Sequence
<
2
,
0
,
1
,
3
>
,
ABlockTransferSrcVectorDim
,
3
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_AK1
,
1
,
1
,
AThreadTransferSrcResetCoordinateAfterRun
,
true
,
NumGemmKPrefetchStage
>
;
using
BBlockwiseCopy
=
ThreadGroupTensorSliceTransfer_v4r1
<
ThisThreadBlock
,
BElementwiseOperation
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum
::
Set
,
Sequence
<
1
,
BK0PerBlock
,
NPerBlock
,
BK1
>
,
BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1
,
BBlockTransferThreadClusterArrangeOrder
,
BDataType
,
ComputeType
,
BGridDesc_KBatch_BK0_N_BK1
,
BBlockDesc_KBatch_BK0PerB_NPerB_BK1
,
BBlockTransferSrcAccessOrder
,
Sequence
<
2
,
0
,
1
,
3
>
,
BBlockTransferSrcVectorDim
,
3
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_BK1
,
1
,
1
,
BThreadTransferSrcResetCoordinateAfterRun
,
true
,
NumGemmKPrefetchStage
>
;
// A matrix blockwise copy
auto
a_blockwise_copy
=
ABlockwiseCopy
(
a_grid_desc_kbatch_ak0_m_ak1
,
...
...
@@ -817,19 +815,19 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
}
template
<
bool
HasMainKBlockLoop
,
typename
Block2ETileMap
,
typename
CThreadBuf
>
__device__
void
RunGEMM
(
const
void
*
__restrict__
p_a_grid_
,
const
void
*
__restrict__
p_b_grid_
,
void
*
__restrict__
p_shared
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
index_t
M
,
const
index_t
N
,
const
index_t
K
,
const
index_t
StrideA
,
const
index_t
StrideB
,
const
index_t
KBatch
,
const
Block2ETileMap
&
block_2_etile_map
,
CThreadBuf
&
c_thread_buf
)
__device__
static
void
RunGEMM
(
const
void
*
__restrict__
p_a_grid_
,
const
void
*
__restrict__
p_b_grid_
,
void
*
__restrict__
p_shared
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
index_t
M
,
const
index_t
N
,
const
index_t
K
,
const
index_t
StrideA
,
const
index_t
StrideB
,
const
index_t
KBatch
,
const
Block2ETileMap
&
block_2_etile_map
,
CThreadBuf
&
c_thread_buf
)
{
const
auto
p_a_grid
=
reinterpret_cast
<
const
ADataType
*>
(
p_a_grid_
);
const
auto
p_b_grid
=
reinterpret_cast
<
const
BDataType
*>
(
p_b_grid_
);
...
...
@@ -854,7 +852,8 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
// TODO Need to do CShuffle already here:
template
<
typename
CThreadBuf
>
__device__
void
StorePartials
(
void
*
__restrict__
p_workspace
,
const
CThreadBuf
&
c_thread_buf
)
__device__
static
void
StorePartials
(
void
*
__restrict__
p_workspace
,
const
CThreadBuf
&
c_thread_buf
)
{
// M0 = grid_size
// N0 = 1
...
...
@@ -999,9 +998,9 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
}
template
<
typename
CThreadBuf
>
__device__
void
AccumulatePartials
(
void
*
__restrict__
p_workspace
,
CThreadBuf
&
c_thread_buf
,
uint32_t
reduce_count
)
__device__
static
void
AccumulatePartials
(
void
*
__restrict__
p_workspace
,
CThreadBuf
&
c_thread_buf
,
uint32_t
reduce_count
)
{
using
BlockwiseGemmT
=
remove_cvref_t
<
decltype
(
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector
<
...
...
@@ -1167,16 +1166,16 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
}
template
<
typename
Block2ETileMap
,
typename
CThreadBuf
>
__device__
void
RunWrite
(
DsGridPointer
p_ds_grid
,
EDataType
*
__restrict__
p_e_grid
,
void
*
__restrict__
p_shared
,
const
index_t
M
,
const
index_t
N
,
const
std
::
array
<
index_t
,
NumDTensor
>
StrideDs
,
const
index_t
StrideE
,
const
CDEElementwiseOperation
&
cde_element_op
,
const
Block2ETileMap
&
block_2_etile_map
,
const
CThreadBuf
&
c_thread_buf
)
__device__
static
void
RunWrite
(
DsGridPointer
p_ds_grid
,
EDataType
*
__restrict__
p_e_grid
,
void
*
__restrict__
p_shared
,
const
index_t
M
,
const
index_t
N
,
const
std
::
array
<
index_t
,
NumDTensor
>
StrideDs
,
const
index_t
StrideE
,
const
CDEElementwiseOperation
&
cde_element_op
,
const
Block2ETileMap
&
block_2_etile_map
,
const
CThreadBuf
&
c_thread_buf
)
{
using
DsGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeDsGridDescriptor_M_N
({},
{},
{}))
>
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment