gaoqiong / composable_kernel_ROCM · Commit 333176c5
Authored May 21, 2024 by Adam Osewski
Parent: be48abdb

Draft changes to run gridwise gemm through multiple SplitK tiles
Showing 4 changed files with 320 additions and 258 deletions.
example/15_grouped_gemm/grouped_gemm_multiple_d_splitk_xdl_single_kernel_fp16.cpp (+23, -10)
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_tile_loop.hpp (+90, -60)
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp (+2, -0)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle_v2.hpp (+205, -188)
example/15_grouped_gemm/grouped_gemm_multiple_d_splitk_xdl_single_kernel_fp16.cpp

@@ -37,10 +37,11 @@ using BDataType = F16;
 using AccDataType      = F32;
 using CShuffleDataType = F32;
 using DsDataType       = ck::Tuple<>;
-using EDataType        = F32;
+using EDataType        = F16;
 
 using ALayout  = Row;
 using BLayout  = Col;
+// using BLayout = Row;
 using DsLayout = ck::Tuple<>;
 using ELayout  = Row;
@@ -56,7 +57,9 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemmMultip
 //######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
 //######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
 //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-        < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>;
+        < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmMNKPadding, 1, 128, 64, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, ck::PipelineVersion::v1>;
+// < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>;
+// < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmMNKPadding, 1, 128, 192, 32, 32, 8, 8, 32, 32, 3, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>;
 // clang-format on
 
 struct ProblemSize final
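For orientation, the numeric arguments of the new DeviceGemmInstance line up with the table header in the comments above: after the GEMM specialization and prefetch stage come the workgroup size and the per-block M/N/K tile sizes. A minimal reading of the new tuning choice, using invented constant names (this column mapping is inferred from the comment header, not from CK documentation):

```cpp
// Illustrative only: values copied from the new DeviceGemmInstance above,
// names invented to mirror the "Size | Block | Block | Block" header columns.
constexpr int kBlockSize   = 128; // threads per workgroup (the old instance used 256)
constexpr int kMPerBlock   = 64;  // output tile height per workgroup
constexpr int kNPerBlock   = 64;  // output tile width per workgroup (was 128)
constexpr int kKPerBlock   = 32;  // K slice consumed per main-loop iteration
constexpr int kMXdlPerWave = 2;   // XDL tiles per wave along M (was 1)
constexpr int kNXdlPerWave = 1;   // XDL tiles per wave along N (was 2)
```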
@@ -76,8 +79,9 @@ struct ExecutionConfig final
 {
     bool do_verification = true;
     int init_method      = 1;
-    int k_batch          = 128;
+    // int k_batch = 128;
+    int k_batch      = 1;
     bool time_kernel = false;
 };
 
 bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
@@ -158,9 +162,15 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
             a_tensors[i].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
             b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
             break;
+        case 3:
+            ck::utils::FillConstant<ADataType>{1}(a_tensors[i]);
+            ck::utils::FillConstant<BDataType>{1}(b_tensors[i]);
+            break;
         default:
-            a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
-            b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
+            // a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{});
+            // b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
+            ck::utils::FillMonotonicSeq<ADataType>{0, 1}(a_tensors[i]);
+            ck::utils::FillMonotonicSeq<BDataType>{1, 1}(b_tensors[i]);
         }
     }
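The example now defaults to constant and monotonic-sequence fills instead of the sequential generators, which makes SplitK partial sums easy to predict by hand (with B filled with ones, each output element is just the sum over K of one row of A). The standalone sketch below only illustrates what such fill patterns produce; it is not the ck::utils implementation and the functor names are invented:

```cpp
// Standalone sketch (not ck::utils) of constant and monotonic-sequence fills.
#include <iostream>
#include <vector>

template <typename T>
struct FillConstantSketch
{
    T value;
    void operator()(std::vector<T>& t) const
    {
        for(auto& x : t) { x = value; }
    }
};

template <typename T>
struct FillMonotonicSeqSketch
{
    T init;
    T step;
    void operator()(std::vector<T>& t) const
    {
        T v = init;
        for(auto& x : t) { x = v; v += step; }
    }
};

int main()
{
    std::vector<float> a(8), b(8);
    FillMonotonicSeqSketch<float>{0.f, 1.f}(a); // 0, 1, 2, ...
    FillConstantSketch<float>{1.f}(b);          // all ones
    for(auto x : a) { std::cout << x << ' '; }
    std::cout << '\n';
    for(auto x : b) { std::cout << x << ' '; }
    std::cout << '\n';
}
```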
@@ -309,17 +319,20 @@ int main(int argc, char* argv[])
     if(argc < 11)
     {
-        std::vector<ck::index_t> Ms{64, 127, 255, 129, 260, 190, 77};
+        // std::vector<ck::index_t> Ms{64, 127, 255, 129, 260, 190, 77};
+        std::vector<ck::index_t> Ms{64};
 
         problem_size.group_count = Ms.size();
 
         for(int i = 0; i < problem_size.group_count; i++)
         {
             problem_size.Ms.push_back(Ms[i]);
-            problem_size.Ns.push_back(252);
+            // problem_size.Ns.push_back(252);
+            problem_size.Ns.push_back(256);
             problem_size.Ks.push_back(4608);
 
             problem_size.stride_As.push_back(problem_size.Ks[i]);
             problem_size.stride_Bs.push_back(problem_size.Ks[i]);
+            // problem_size.stride_Bs.push_back(problem_size.Ns[i]);
             problem_size.stride_Cs.push_back(problem_size.Ns[i]);
         }
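For reference, the hard-coded single test group above is M = 64, N = 256, K = 4608, with both strides equal to K. With the 64x64x32 per-block tile of the new DeviceGemmInstance, a rough tile count works out as follows. This is illustrative arithmetic only; how the b2c tile map actually groups K tiles depends on k_batch and is not visible in this diff:

```cpp
// Rough tile-count arithmetic for the hard-coded problem above (illustrative only).
#include <cstdio>

int main()
{
    const int M = 64, N = 256, K = 4608;                       // single group from the example
    const int MPerBlock = 64, NPerBlock = 64, KPerBlock = 32;  // new instance's tile
    const int m_tiles = (M + MPerBlock - 1) / MPerBlock;       // 1
    const int n_tiles = (N + NPerBlock - 1) / NPerBlock;       // 4
    const int k_tiles = (K + KPerBlock - 1) / KPerBlock;       // 144
    std::printf("[M,N] tiles: %d, K slices per [M,N] tile: %d\n", m_tiles * n_tiles, k_tiles);
}
```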
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_tile_loop.hpp

@@ -131,30 +131,40 @@ __global__ void
         const auto StrideA = gemm_desc_ptr[group_id].StrideA;
         const auto StrideB = gemm_desc_ptr[group_id].StrideB;
 
-        results_buffer.Clear();
+        // results_buffer.Clear();
         b2c_tile_map.CalculateBottomIndex(work_scheduler.tile_id_ - offset);
 
         // Iterate over K dimension for this [M,N] tile
         // still in the same GEMM && the same [M,N] tile
         // TODO: change desc so that few K-tiles will be done in single GEMM.
-        do
-        {
-            // just accumulate results in registers!
-            GridwiseGemm::template RunGEMM<HasMainKBlockLoop>(p_a_grid,
-                                                              p_b_grid,
-                                                              static_cast<void*>(p_shared),
-                                                              a_element_op,
-                                                              b_element_op,
-                                                              M,
-                                                              N,
-                                                              K,
-                                                              StrideA,
-                                                              StrideB,
-                                                              k_batch,
-                                                              b2c_tile_map,
-                                                              results_buffer);
-        } while(work_scheduler.GetNextTile() && b2c_tile_map.GetNextKTileIdx());
+        // do
+        // {
+        auto k_tiles = work_scheduler.GetNextKTiles(k_batch, b2c_tile_map.GetTileKIdx());
+        // if (blockIdx.x < 4 && ck::debug::is_thread_local_1d_id_idx<0>())
+        // {
+        //     printf("bid: %d, k_tiles: %d\n",
+        //            static_cast<index_t>(blockIdx.x),
+        //            k_tiles);
+        // }
+        // just accumulate results in registers!
+        GridwiseGemm::template RunGEMM<HasMainKBlockLoop>(p_a_grid,
+                                                          p_b_grid,
+                                                          static_cast<void*>(p_shared),
+                                                          a_element_op,
+                                                          b_element_op,
+                                                          M,
+                                                          N,
+                                                          K,
+                                                          StrideA,
+                                                          StrideB,
+                                                          k_batch,
+                                                          b2c_tile_map,
+                                                          results_buffer,
+                                                          k_tiles);
+        // Move to the last processed k-tile
+        b2c_tile_map.AdvanceTileKIdx(k_tiles - 1);
+        // } while(work_scheduler.GetNextTile() && b2c_tile_map.GetNextKTileIdx());
 
         // if (changed group_id || next [M,N] tile)
         // With cshuffle at store partials all workgroups have to store
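The visible intent of this hunk: instead of looping one K tile per RunGEMM call, the workgroup now asks the scheduler for a run of consecutive K tiles of the current [M,N] tile (GetNextKTiles), folds them into a single RunGEMM call, and then advances the tile map past them (AdvanceTileKIdx(k_tiles - 1)). Below is a host-side model of that control flow, under the assumption that tiles are laid out [M,N]-major with k_batch K tiles each; it is my reading of the draft, not the actual WorkScheduler or tile map:

```cpp
// Host-side model (not the CK implementation) of the batched K-tile loop:
// a persistent workgroup grabs a run of consecutive K tiles belonging to the
// current [M,N] tile, processes them in one RunGEMM-style call, then jumps past them.
#include <algorithm>
#include <cstdio>

int main()
{
    const int k_batch         = 8;  // K tiles per [M,N] output tile (assumed layout)
    const int mn_tiles        = 3;  // [M,N] tiles in this (single) group
    const int total_tiles     = mn_tiles * k_batch;
    const int tiles_per_block = 5;  // contiguous tiles assigned to this workgroup
    const int first_tile      = 6;  // flat id of the first tile this workgroup owns

    int tile      = first_tile;
    const int end = std::min(first_tile + tiles_per_block, total_tiles);
    while(tile < end)
    {
        const int mn_idx  = tile / k_batch; // which [M,N] tile
        const int k_idx   = tile % k_batch; // first K tile inside it
        // "GetNextKTiles": stay inside this [M,N] tile and inside our allotment.
        const int k_tiles = std::min(k_batch - k_idx, end - tile);
        std::printf("[M,N] tile %d: K tiles [%d, %d) folded into one RunGEMM call\n",
                    mn_idx, k_idx, k_idx + k_tiles);
        tile += k_tiles; // mirrors AdvanceTileKIdx(k_tiles - 1) plus moving to the next tile
    }
}
```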
@@ -164,7 +174,7 @@ __global__ void
         // do CShuffle in flight with loading partials products of other peer workgroups.
         GridwiseGemm::StorePartials(p_workspace, static_cast<void*>(p_shared), results_buffer);
 
-#if 0
+#if 1
         // make sure all writes to gmem has finished.
         __builtin_amdgcn_s_waitcnt(0x0f70); // s_waitcnt vmcnt(0)
         // __builtin_amdgcn_s_waitcnt(0x0070); // s_waitcnt vmcnt(0) lgkmcnt(0)
@@ -212,6 +222,11 @@ __global__ void
             p_ds_grid(i) = static_cast<const DDataType*>(gemm_desc_ptr[group_id].p_ds_grid[i]);
         });
 
+        // if (threadIdx.x == 0)
+        // {
+        //     p_e_grid[blockIdx.x] = 0;
+        // }
+
         GridwiseGemm::template RunWrite(p_ds_grid,
                                         p_e_grid,
                                         acc_buff,
@@ -497,29 +512,29 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
         {
             bool all_have_main_k_block_loop;
             {
-                const auto a_grid_desc_kbatch_ak0_m_ak1 =
-                    GridwiseGemm::MakeAGridDescriptor_KBatch_AK0_M_AK1(gemm_kernel_args_[0].M,
+                const auto a_grid_desc_ak0_m_ak1 =
+                    GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(gemm_kernel_args_[0].M,
                                                                 gemm_kernel_args_[0].K,
                                                                 gemm_kernel_args_[0].StrideA,
                                                                 K_BATCH);
 
                 all_have_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(
-                    a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) *
-                    a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3));
+                    a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2) /
+                    K_BATCH);
             }
 
             for(std::size_t i = 0; i < gemm_kernel_args_.size(); ++i)
             {
                 const auto& gemm_arg = gemm_kernel_args_[i];
                 auto kbatch          = K_BATCH;
 
-                const auto a_grid_desc_kbatch_ak0_m_ak1 =
-                    GridwiseGemm::MakeAGridDescriptor_KBatch_AK0_M_AK1(
-                        gemm_arg.M, gemm_arg.K, gemm_arg.StrideA, kbatch);
+                const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
+                    gemm_arg.M, gemm_arg.K, gemm_arg.StrideA, kbatch);
 
                 bool not_all_have_main_k_block_loop_same =
-                    all_have_main_k_block_loop xor GridwiseGemm::CalculateHasMainKBlockLoop(
-                                                       a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) *
-                                                       a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3));
+                    all_have_main_k_block_loop xor
+                    GridwiseGemm::CalculateHasMainKBlockLoop(a_grid_desc_ak0_m_ak1.GetLength(I0) *
+                                                             a_grid_desc_ak0_m_ak1.GetLength(I2) /
+                                                             K_BATCH);
 
                 if(not_all_have_main_k_block_loop_same)
                 {
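Both sides of this hunk feed the same quantity, the K extent of one split, into CalculateHasMainKBlockLoop: the old code multiplied AK0 and AK1 of the kbatch-aware descriptor (lengths I1 and I3), the new code multiplies AK0 and AK1 of the plain descriptor (I0 and I2) and divides by the split count. The small arithmetic check below assumes AK1 is the K packing factor (the 8 in the instance) and that the kbatch descriptor's AK0 covers only K / k_batch; these assumptions are inferred from the diff, not from the CK headers:

```cpp
// Illustrative arithmetic only: both expressions aim at "K elements per split".
#include <cassert>
#include <cstdio>

int main()
{
    const int K       = 4608;
    const int AK1     = 8; // assumed K packing factor
    const int k_batch = 4;

    // Plain descriptor: AK0 covers the whole K.
    const int AK0_full      = K / AK1;                   // GetLength(I0)
    const int per_split_new = AK0_full * AK1 / k_batch;  // new expression

    // KBatch-aware descriptor: AK0 covers only K / k_batch (assumed).
    const int AK0_split     = K / k_batch / AK1;         // GetLength(I1) in the old layout
    const int per_split_old = AK0_split * AK1;           // old expression

    std::printf("K per split: %d (new) vs %d (old)\n", per_split_new, per_split_old);
    assert(per_split_new == per_split_old); // both equal K / k_batch = 1152
}
```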
@@ -616,7 +631,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
                                void* dev_gemm_workspace,
                                const StreamConfig& stream_config = StreamConfig{})
         {
-            auto [all_have_kbatch_gt_one, all_have_main_k_block_loop] =
+            [[maybe_unused]] auto [all_have_kbatch_gt_one, all_have_main_k_block_loop] =
                 CheckArgument(arg, stream_config);
 
             if(dev_gemm_args == nullptr)
@@ -698,17 +713,16 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
             bool all_have_kbatch_gt_one, all_have_main_k_block_loop;
             {
-                const auto a_grid_desc_kbatch_ak0_m_ak1 =
-                    GridwiseGemm::MakeAGridDescriptor_KBatch_AK0_M_AK1(arg.gemm_kernel_args_[0].M,
+                const auto a_grid_desc_ak0_m_ak1 =
+                    GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(arg.gemm_kernel_args_[0].M,
                                                                 arg.gemm_kernel_args_[0].K,
                                                                 arg.gemm_kernel_args_[0].StrideA,
                                                                 arg.K_BATCH);
 
                 all_have_kbatch_gt_one     = arg.K_BATCH > 1;
                 all_have_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(
-                    a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) *
-                    a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3));
+                    a_grid_desc_ak0_m_ak1.GetLength(I0) *
+                    a_grid_desc_ak0_m_ak1.GetLength(I2) / kbatch);
             }
 
             for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
@@ -737,14 +751,14 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
                     throw std::runtime_error(err.str());
                 }
 
-                const auto a_grid_desc_kbatch_ak0_m_ak1 =
-                    GridwiseGemm::MakeAGridDescriptor_KBatch_AK0_M_AK1(
-                        gemm_arg.M, gemm_arg.K, gemm_arg.StrideA, kbatch);
+                const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
+                    gemm_arg.M, gemm_arg.K, gemm_arg.StrideA, kbatch);
 
                 bool not_all_have_main_k_block_loop_same =
-                    all_have_main_k_block_loop xor GridwiseGemm::CalculateHasMainKBlockLoop(
-                                                       a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) *
-                                                       a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3));
+                    all_have_main_k_block_loop xor
+                    GridwiseGemm::CalculateHasMainKBlockLoop(a_grid_desc_ak0_m_ak1.GetLength(I0) *
+                                                             a_grid_desc_ak0_m_ak1.GetLength(I2) /
+                                                             kbatch);
 
                 bool not_all_have_kbatch_value_same = all_have_kbatch_gt_one xor (kbatch > 1);
 
                 if(not_all_have_main_k_block_loop_same)
@@ -853,8 +867,16 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
             }
 
             auto preprocess = [&]() {
+                // std::cout << "[preprocess] p_flags: " << p_flags
+                //           << ", flag count: " << flag_count
+                //           << ", bytes: " << flag_count * sizeof(uint32_t)
+                //           << ", stream id: " << stream_config.stream_id_
+                //           << std::endl;
                 hip_check_error(hipMemsetAsync(
                     p_flags, 0, flag_count * sizeof(uint32_t), stream_config.stream_id_));
+                // TODO: For debug only!
+                hip_check_error(hipMemsetAsync(
+                    dev_gemm_workspace, 2, acc_workspace_size_bytes, stream_config.stream_id_));
             };
 
             return launch_and_time_kernel_with_preprocess(
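The preprocess lambda now also pattern-fills the accumulation workspace (marked "For debug only") in addition to zeroing the synchronization flags before every launch. Below is a minimal standalone HIP sketch of the flag reset, with placeholder sizes and names; only the hipMemsetAsync-on-the-launch-stream pattern mirrors the diff:

```cpp
// Minimal sketch of the pre-launch reset done in `preprocess` above: zero a flag
// buffer on the kernel's stream so workgroups start from a clean sync state.
#include <hip/hip_runtime.h>
#include <cstdint>
#include <cstdio>

#define HIP_CHECK(cmd)                                                                 \
    do {                                                                               \
        hipError_t e_ = (cmd);                                                         \
        if(e_ != hipSuccess) { std::printf("%s\n", hipGetErrorString(e_)); return 1; } \
    } while(0)

int main()
{
    const size_t flag_count = 1024;   // placeholder
    uint32_t* p_flags       = nullptr;
    hipStream_t stream;

    HIP_CHECK(hipStreamCreate(&stream));
    HIP_CHECK(hipMalloc(&p_flags, flag_count * sizeof(uint32_t)));
    // Same pattern as the diff: asynchronous memset on the stream the kernel will use.
    HIP_CHECK(hipMemsetAsync(p_flags, 0, flag_count * sizeof(uint32_t), stream));
    HIP_CHECK(hipStreamSynchronize(stream));

    HIP_CHECK(hipFree(p_flags));
    HIP_CHECK(hipStreamDestroy(stream));
    return 0;
}
```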
@@ -890,11 +912,12 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
         if((ck::type_convert<ck::index_t>(arg.gemm_kernel_args_.size()) +
             arg.skipped_group_count_) != arg.group_count_)
         {
-#if DEBUG_LOG
-            std::cout << "The group count is not equal to sum of skipped groups "
-                         "and kernel args size!"
-                      << std::endl;
-#endif // DEBUG_LOG
+            if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
+            {
+                std::cout << "The group count is not equal to sum of skipped groups "
+                             "and kernel args size!"
+                          << std::endl;
+            }
             return false;
         }
@@ -913,11 +936,12 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
                                                        arg.K_BATCH);
 
             if(not group_arg_valid)
             {
-#if DEBUG_LOG
-                std::cout << "[" << __func__ << "] group id: " << i
-                          << " has invalid GridwiseGemm settings!" << std::endl;
-                gemm_arg.Print();
-#endif // DEBUG_LOG
+                if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
+                {
+                    std::cout << "[" << __func__ << "] group id: " << i
+                              << " has invalid GridwiseGemm settings!" << std::endl;
+                    gemm_arg.Print();
+                }
             }
             supported = supported && group_arg_valid;
         }
@@ -1043,6 +1067,12 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
             size_t size_bytes =
                 Block2ETileMapKSplit::GetAccWorkspaceSize(sizeof(CShuffleDataType), grid_size) +
                 flag_count * sizeof(uint32_t);
 
+            std::cout << "[GetWorkspaceSize]: "
+                      << "occ_grid_size: " << occ_grid_size << ", grid_size: " << grid_size
+                      << ", tiles_per_block: " << tiles_per_block
+                      << ", flag_count: " << flag_count << ", size_bytes: " << size_bytes
+                      << std::endl;
+
             return size_bytes;
         }
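The added logging makes the workspace composition explicit: an accumulation buffer sized by the tile map plus one uint32_t slot per flag. GetAccWorkspaceSize's formula is not visible in this diff, so the sketch below only reproduces the sum with placeholder numbers and an assumed per-block [M,N] tile of partial results:

```cpp
// Workspace-size arithmetic mirroring the hunk above, with placeholder values
// (the real grid_size / flag_count come from the tile map and an occupancy query).
#include <cstdint>
#include <cstdio>

int main()
{
    const size_t grid_size           = 240;     // hypothetical persistent-kernel grid
    const size_t flag_count          = 240;     // hypothetical: one flag per workgroup
    const size_t cshuffle_elem_bytes = 4;       // sizeof(CShuffleDataType) == sizeof(float)
    const size_t mn_tile_elems       = 64 * 64; // MPerBlock * NPerBlock of the new instance

    // Assumption: the accumulation workspace holds one [M,N] tile of partials per block.
    const size_t acc_bytes  = grid_size * mn_tile_elems * cshuffle_elem_bytes;
    const size_t size_bytes = acc_bytes + flag_count * sizeof(uint32_t);
    std::printf("workspace bytes: %zu\n", size_bytes);
}
```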
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp

@@ -1531,6 +1531,8 @@ struct BlockToCTileMap_LinearKSplit
         return false;
     }
 
+    __host__ __device__ void AdvanceTileKIdx(index_t k_tiles) { K0_idx_ += k_tiles; }
+
     ///
     /// @brief Determines whether the current workgroup processed first tile in K dimension
     ///
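The two added lines give the linear SplitK tile map a way to jump its K index forward by more than one tile, matching the kernel-side AdvanceTileKIdx(k_tiles - 1) call above. A toy stand-in (not the CK class) showing the effect:

```cpp
// Toy stand-in (not BlockToCTileMap_LinearKSplit) for the new AdvanceTileKIdx:
// the map tracks the current K tile; advancing by k_tiles skips the ones already
// accumulated in a single RunGEMM call.
#include <cstdio>

struct TileKIndexSketch
{
    int K0_idx_ = 0;
    int GetTileKIdx() const { return K0_idx_; }
    void AdvanceTileKIdx(int k_tiles) { K0_idx_ += k_tiles; }
};

int main()
{
    TileKIndexSketch map;             // starts at K tile 0
    const int k_tiles = 3;            // K tiles folded into one RunGEMM call
    map.AdvanceTileKIdx(k_tiles - 1); // "Move to the last processed k-tile"
    std::printf("current K tile index: %d\n", map.GetTileKIdx()); // 2
}
```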
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle_v2.hpp (+205, -188)

(This diff is collapsed on the page and is not shown here.)