gaoqiong / composable_kernel_ROCM / Commits

Commit 1b462ab5 authored Jan 29, 2024 by Adam Osewski

Clean up debug code and reuse new neighbour count func.

parent e954c206
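The thread running through all ten files: the debug-only state left by the parent commit (commented-out RunGEMM/RunWrite calls, printf tracing, ad-hoc flag reads) is removed, and the split-K reduction now consumes the neighbour count returned by the reworked WaitForNeighbours instead of re-reading the synchronization flag. A minimal before/after sketch, with signatures abbreviated from the diffs below (not a drop-in excerpt):

    // Before (debug state, largely commented out in the parent commit): the reduce
    // count came from re-reading the flag word after the wait.
    work_scheduler.WaitForNeighbours(k_batch, output_tile_idx, output_tile_idx_offset);
    const uint32_t flag_v = __builtin_amdgcn_readfirstlane(
        work_scheduler.GetFlagValue(k_batch, output_tile_idx, output_tile_idx_offset));
    gridwise_gemm.AccumulatePartials(p_workspace, flag_v);   // loop ran for i_t < flag_v

    // After: WaitForNeighbours reports how many workgroups stored partials for this
    // [M,N] tile, and that count feeds the reduction directly.
    index_t neighbour_count = work_scheduler.WaitForNeighbours(
        k_batch, b2c_tile_map.GetTileKIdx(), output_tile_idx, output_tile_idx_offset);
    if(neighbour_count > 1)
        gridwise_gemm.AccumulatePartials(p_workspace, neighbour_count); // loop runs for i_t <= neighbour_count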
Changes 10

Showing 10 changed files with 318 additions and 468 deletions (+318 −468)
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_tile_loop.hpp  (+56 −152)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle_v2.hpp  (+1 −56)
library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp  (+150 −150)
library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_multiple_d.hpp  (+14 −14)
library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt  (+9 −9)
library/src/tensor_operation_instance/gpu/grouped_gemm_multiple_d/CMakeLists.txt  (+1 −1)
library/src/tensor_operation_instance/gpu/grouped_gemm_multiple_d/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_f16_f16_f16_mk_kn_mn_irregular_instance.cpp  (+61 −61)
profiler/include/profiler/profile_grouped_gemm_multiple_d_splitk_impl.hpp  (+6 −10)
profiler/src/profile_grouped_gemm_multiple_d_splitk.cpp  (+14 −6)
test/work_scheduling/test_strided_reduction_tile_loop.cpp  (+6 −9)
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_tile_loop.hpp

@@ -63,20 +63,19 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
         kernel_grouped_gemm_xdl_splitk_v2(
             const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
             void* const __restrict__ p_workspace,
             const index_t tile_count,
             const index_t k_batch,
-            [[maybe_unused]] const AElementwiseOperation a_element_op,
-            [[maybe_unused]] const BElementwiseOperation b_element_op,
-            [[maybe_unused]] const CDEElementwiseOperation cde_element_op)
+            const AElementwiseOperation a_element_op,
+            const BElementwiseOperation b_element_op,
+            const CDEElementwiseOperation cde_element_op)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
     defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
     constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
-    [[maybe_unused]] __shared__ uint8_t p_shared[shared_size];
+    __shared__ uint8_t p_shared[shared_size];

     const auto gemm_desc_ptr =
         reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));

@@ -105,12 +104,6 @@ __global__ void
     index_t gemm_tile_id_end = grid_size_grp;

     auto gridwise_gemm = GridwiseGemm();

-    [[maybe_unused]] auto is_thread_local_1d_id_idx = [](auto... Ids) -> bool {
-        const auto tid = get_thread_local_1d_id();
-        return ((tid == Ids) || ...);
-    };
-
     do
     {
         // Find corresponding GEMM group for our tile

@@ -129,12 +122,12 @@ __global__ void
             gemm_tile_id_end = offset + grid_size_grp;
         }

-        [[maybe_unused]] const auto p_a_grid = reinterpret_cast<const FloatA*>(gemm_desc_ptr[group_id].p_a_grid);
-        [[maybe_unused]] const auto p_b_grid = reinterpret_cast<const FloatB*>(gemm_desc_ptr[group_id].p_b_grid);
-        [[maybe_unused]] const auto K        = gemm_desc_ptr[group_id].K;
-        [[maybe_unused]] const auto StrideA  = gemm_desc_ptr[group_id].StrideA;
-        [[maybe_unused]] const auto StrideB  = gemm_desc_ptr[group_id].StrideB;
+        const auto p_a_grid = reinterpret_cast<const FloatA*>(gemm_desc_ptr[group_id].p_a_grid);
+        const auto p_b_grid = reinterpret_cast<const FloatB*>(gemm_desc_ptr[group_id].p_b_grid);
+        const auto K        = gemm_desc_ptr[group_id].K;
+        const auto StrideA  = gemm_desc_ptr[group_id].StrideA;
+        const auto StrideB  = gemm_desc_ptr[group_id].StrideB;

         auto& results_buffer = gridwise_gemm.GetCThreadBuffer();
         b2c_tile_map.CalculateBottomIndex(work_scheduler.tile_id_ - offset);

@@ -143,32 +136,21 @@ __global__ void
         // Iterate over K dimension for this [M,N] tile
         // still in the same GEMM && the same [M,N] tile
         // TODO: change desc so that few K-tiles will be done in single GEMM.
-        // {
-        //     if (is_thread_local_1d_id_idx(0))
-        //     {
-        //         printf("bid: %d, group: %d, accumulate tile id (M,N,K): [%d, %d, %d] \n",
-        //                static_cast<index_t>(blockIdx.x),
-        //                group_id,
-        //                b2c_tile_map.GetTileMIdx(),
-        //                b2c_tile_map.GetTileNIdx(),
-        //                b2c_tile_map.GetTileKIdx());
-        //     }
-        // }
         do
         {
             // just accumulate results in registers!
-            // gridwise_gemm.template RunGEMM<HasMainKBlockLoop>(p_a_grid,
-            //                                                   p_b_grid,
-            //                                                   static_cast<void*>(p_shared),
-            //                                                   a_element_op,
-            //                                                   b_element_op,
-            //                                                   M,
-            //                                                   N,
-            //                                                   K,
-            //                                                   StrideA,
-            //                                                   StrideB,
-            //                                                   k_batch,
-            //                                                   b2c_tile_map);
+            gridwise_gemm.template RunGEMM<HasMainKBlockLoop>(p_a_grid,
+                                                              p_b_grid,
+                                                              static_cast<void*>(p_shared),
+                                                              a_element_op,
+                                                              b_element_op,
+                                                              M,
+                                                              N,
+                                                              K,
+                                                              StrideA,
+                                                              StrideB,
+                                                              k_batch,
+                                                              b2c_tile_map);
         } while(work_scheduler.GetNextTile() && b2c_tile_map.GetNextKTileIdx());

@@ -184,122 +166,47 @@ __global__ void
         work_scheduler.FlagFinished(k_batch, output_tile_idx, output_tile_idx_offset);

-        // {
-        //     // const uint32_t flag_v2 = __builtin_amdgcn_readfirstlane(
-        //     //     work_scheduler.GetFlagValue(k_batch, output_tile_idx, output_tile_idx_offset));
-        if(is_thread_local_1d_id_idx(0))
-        {
-            printf("bid: %d, group: %d, FlagFInished \n",
-                   static_cast<index_t>(blockIdx.x),
-                   group_id);
-            // printf("bid: %d, group: %d, FlagFInished flag_v[%u]: %u\n",
-            //        static_cast<index_t>(blockIdx.x),
-            //        group_id)
-            //        work_scheduler.GetWorkgroupFlagIdx(k_batch, output_tile_idx, output_tile_idx_offset),
-            //        flag_v2);
-        }
-        // }
-
         // The workgroup which processed first K tile accumulates results and stores to GMEM
         if(b2c_tile_map.IsFirstKSplitBlock())
         {
-            if(is_thread_local_1d_id_idx(0))
-            {
-                printf("bid: %d, group: %d, Will wait for neighbours... \n",
-                       static_cast<index_t>(blockIdx.x),
-                       group_id);
-            }
             // Wait untill all other blocks for this [M,N] tile store their results.
-            work_scheduler.WaitForNeighbours(k_batch, output_tile_idx, output_tile_idx_offset);
-            // Accumulate partial results. We can have different # of workgroups to reduce, thus we
-            // read actual flag value.
-            [[maybe_unused]] const uint32_t flag_v = __builtin_amdgcn_readfirstlane(
-                work_scheduler.GetFlagValue(k_batch, output_tile_idx, output_tile_idx_offset));
-            // {
-            //     if (is_thread_local_1d_id_idx(0))
-            //     {
-            //         printf("bid: %d, group: %d, WaitForNeighbours flag_v[%u]: %u\n",
-            //                static_cast<index_t>(blockIdx.x),
-            //                group_id,
-            //                work_scheduler.GetWorkgroupFlagIdx(k_batch, output_tile_idx, output_tile_idx_offset),
-            //                static_cast<index_t>(blockIdx.x));
-            //                // flag_v);
-            //     }
-            // }
-            // using CThreadBuffer = remove_cvref_t<decltype(results_buffer)>;
-            // constexpr index_t n_v = CThreadBuffer::num_of_v_.value;
-            // constexpr index_t s_per_v = CThreadBuffer::s_per_v.value;
-            // static_for<0, n_v, 1>{}([&](auto v) {
-            //     static_for<0, s_per_v, 1>{}([&](auto s) {
-            //         // printf("bid: %d; tid: %d; [Partial results] c_thread_buff[%d, %d]:
-            //         // %f\n",
-            //         //        static_cast<index_t>(blockIdx.x),
-            //         //        static_cast<index_t>(threadIdx.x),
-            //         //        v.value,
-            //         //        s.value,
-            //         //        static_cast<float>(results_buffer[v * Number<s_per_v>{} + s])
-            //         // );
-            //         results_buffer(v * Number<s_per_v>{} + s) = threadIdx.x * v + s;
-            //     });
-            // });
+            index_t neighbour_count = work_scheduler.WaitForNeighbours(
+                k_batch, b2c_tile_map.GetTileKIdx(), output_tile_idx, output_tile_idx_offset);
+
             // Accumulate only when there is at least two workgroups processing splitk data-tiles
             // across same MN-output tile.
-            // if(flag_v > 1)
-            // gridwise_gemm.AccumulatePartials(p_workspace, flag_v);
-            if(is_thread_local_1d_id_idx(0))
-            {
-                printf("bid: %d, group: %d, Reset flag \n",
-                       static_cast<index_t>(blockIdx.x),
-                       group_id);
-            }
+            if(neighbour_count > 1)
+                gridwise_gemm.AccumulatePartials(p_workspace, neighbour_count);
+
             // Signal waiting blocks that they can start use their workspace.
             work_scheduler.Reset(k_batch, output_tile_idx, output_tile_idx_offset);

-            // const auto p_e_grid = reinterpret_cast<FloatC*>(gemm_desc_ptr[group_id].p_e_grid);
-            // const auto stride_e = gemm_desc_ptr[group_id].StrideE;
-            // const auto stride_ds = gemm_desc_ptr[group_id].StrideDs;
-            // constexpr auto NumDTensor = DsDataType::Size();
-            // using DsGridPointer = decltype(GridwiseGemm::MakeDsGridPointer());
-            // DsGridPointer p_ds_grid;
-            // static_for<0, NumDTensor, 1>{}([&](auto i) {
-            //     using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
-            //     p_ds_grid(i) = static_cast<const DDataType*>(gemm_desc_ptr[group_id].p_ds_grid[i]);
-            // });
-            // gridwise_gemm.template RunWrite(p_ds_grid,
-            //                                 p_e_grid,
-            //                                 static_cast<void*>(p_shared),
-            //                                 M,
-            //                                 N,
-            //                                 stride_ds,
-            //                                 stride_e,
-            //                                 cde_element_op,
-            //                                 b2c_tile_map);
+            const auto p_e_grid  = reinterpret_cast<FloatC*>(gemm_desc_ptr[group_id].p_e_grid);
+            const auto stride_e  = gemm_desc_ptr[group_id].StrideE;
+            const auto stride_ds = gemm_desc_ptr[group_id].StrideDs;
+
+            constexpr auto NumDTensor = DsDataType::Size();
+            using DsGridPointer       = decltype(GridwiseGemm::MakeDsGridPointer());
+            DsGridPointer p_ds_grid;
+
+            static_for<0, NumDTensor, 1>{}([&](auto i) {
+                using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
+                p_ds_grid(i) = static_cast<const DDataType*>(gemm_desc_ptr[group_id].p_ds_grid[i]);
+            });
+
+            gridwise_gemm.template RunWrite(p_ds_grid,
+                                            p_e_grid,
+                                            static_cast<void*>(p_shared),
+                                            M,
+                                            N,
+                                            stride_ds,
+                                            stride_e,
+                                            cde_element_op,
+                                            b2c_tile_map);
         }
         else if(work_scheduler.HasTile())
         {
-            {
-                // const uint32_t flag_v2 = __builtin_amdgcn_readfirstlane(
-                const uint32_t flag_v2 =
-                    work_scheduler.GetFlagValue(k_batch, output_tile_idx, output_tile_idx_offset);
-                if(is_thread_local_1d_id_idx(0))
-                {
-                    printf("bid: %d, group: %d, Waiting for Reduction flag_v[%u]: %u \n",
-                           static_cast<index_t>(blockIdx.x),
-                           group_id,
-                           work_scheduler.GetWorkgroupFlagIdx(k_batch, output_tile_idx, output_tile_idx_offset),
-                           // static_cast<index_t>(blockIdx.x));
-                           flag_v2);
-                }
-            }
             work_scheduler.WaitForReduction(k_batch, output_tile_idx, output_tile_idx_offset);
         }
     } while(work_scheduler.HasTile());

@@ -839,8 +746,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
         void* p_flags = reinterpret_cast<char*>(dev_gemm_workspace) +
                         Block2ETileMapKSplit::GetAccWorkspaceSize(
                             sizeof(typename GridwiseGemm::AccType), grid_size);

-        // std::size_t flag_count = (grid_size * tiles_per_block + arg.K_BATCH - 1) / arg.K_BATCH;
-        std::size_t flag_count = arg.tile_count_ / arg.K_BATCH;
+        std::size_t flag_count = (grid_size * tiles_per_block + arg.K_BATCH - 1) / arg.K_BATCH;

         if(stream_config.log_level_ > 0)
         {

@@ -1077,13 +983,11 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
         int grid_size       = std::min(arg.tile_count_, occ_grid_size);
         int tiles_per_block = (arg.tile_count_ + grid_size - 1) / grid_size;

         if(arg.tile_count_ > occ_grid_size &&
            grid_size * tiles_per_block > arg.tile_count_)
         {
             grid_size = (arg.tile_count_ + tiles_per_block - 1) / tiles_per_block;
         }

-        // int flag_count = (grid_size * tiles_per_block + arg.K_BATCH - 1) / arg.K_BATCH;
-        int flag_count = arg.tile_count_ / arg.K_BATCH;
+        int flag_count = (grid_size * tiles_per_block + arg.K_BATCH - 1) / arg.K_BATCH;

         // This would be the maximum needed workspace size. Since actual grid size, which determines
         // the amount of workspace bytes needed, may be less due to the number of available CUs in
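Both flag-count hunks above replace the debug shortcut arg.tile_count_ / arg.K_BATCH with grid-based sizing, which matches the scheduler comment kept in the test file below ("We launch k_batch / tiles_per_block workgroups for each output tile."). A worked example with assumed numbers (not taken from the commit):

    // Illustrative values only: 456 work tiles, K split into 4 batches, and an
    // occupancy-limited grid of 304 workgroups.
    int tile_count      = 456;
    int k_batch         = 4;
    int grid_size       = 304;                                        // std::min(tile_count, occ_grid_size)
    int tiles_per_block = (tile_count + grid_size - 1) / grid_size;   // ceil(456 / 304) = 2
    int flag_count      = (grid_size * tiles_per_block + k_batch - 1) / k_batch; // (608 + 3) / 4 = 152
    // The debug shortcut tile_count / k_batch would allocate only 114 flags. Because
    // grid_size * tiles_per_block >= tile_count by construction, the restored formula
    // rounds up and cannot under-allocate the flag buffer.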
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle_v2.hpp

@@ -106,13 +106,6 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
     using GridwiseGemmPipe = remove_cvref_t<
         decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;

-    template <index_t... Ids>
-    __device__ static bool is_thread_local_1d_id_idx()
-    {
-        const auto tid = get_thread_local_1d_id();
-        return ((tid == Ids) || ...);
-    }
-
     public:
     using AccType = AccDataType;

@@ -913,32 +906,6 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
                                      Sequence<6>{},
                                      Sequence<7>{}));

-        // if (is_thread_local_1d_id_idx<0>())
-        // {
-        //     // printf("bid: %d; tid: %d; [Store Partials] c_block_desc:[%d, %d, %d, %d, %d, %d, %d, %d]\n",
-        //     //        static_cast<index_t>(blockIdx.x),
-        //     //        static_cast<index_t>(threadIdx.x),
-        //     //        M0.value,
-        //     //        N0.value,
-        //     //        M1.value,
-        //     //        N1.value,
-        //     //        M2.value,
-        //     //        M3.value,
-        //     //        M4.value,
-        //     //        N2.value);
-        //     printf("bid: %d; tid: %d; [Store Partials] wrkspace_desc:[%d, %d, %d, %d, %d, %d, %d, %d]\n",
-        //            static_cast<index_t>(blockIdx.x),
-        //            static_cast<index_t>(threadIdx.x),
-        //            workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0),
-        //            workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1),
-        //            workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2).value,
-        //            workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3).value,
-        //            workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4).value,
-        //            workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5).value,
-        //            workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6).value,
-        //            workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7).value);
-        // }
-
         auto p_workspace_grid = reinterpret_cast<AccDataType*>(p_workspace);
         auto w_grid_buf       = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_workspace_grid, workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize());

@@ -996,33 +963,11 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
                                       n_thread_data_on_block_idx[I2]),
             ck::tensor_operation::element_wise::PassThrough{}};

-        // if (is_thread_local_1d_id_idx<0, 64, 223>())
-        // {
-        //     printf("[StorePartials] bid: %d, tid: %d: dst origin idx[%d, %d, %d, %d, %d, %d, %d, %d]\n",
-        //            static_cast<index_t>(blockIdx.x),
-        //            static_cast<index_t>(threadIdx.x),
-        //            (static_cast<index_t>(blockIdx.x)) * MXdlPerWave,
-        //            n_thread_data_on_block_idx[I0],
-        //            m_thread_data_on_block_idx[I1],
-        //            n_thread_data_on_block_idx[I1],
-        //            m_thread_data_on_block_idx[I2],
-        //            m_thread_data_on_block_idx[I3],
-        //            m_thread_data_on_block_idx[I4],
-        //            n_thread_data_on_block_idx[I2]);
-        // }
-
         c_thread_copy_vgpr_to_gmem.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                        make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
                                        c_thread_buf,
                                        workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                        w_grid_buf);
-
-        if(is_thread_local_1d_id_idx<0>())
-        {
-            printf("[StorePartials] done. bid: %d, tid: %d \n",
-                   static_cast<index_t>(blockIdx.x),
-                   static_cast<index_t>(threadIdx.x));
-        }
     }

     __device__ void AccumulatePartials(void* __restrict__ p_workspace, uint32_t reduce_count)

@@ -1158,7 +1103,7 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
         // We do not need to read this workgroup partial results since they're
         // already in c_thread_buff
-        for(uint32_t i_t = 1; i_t < reduce_count; ++i_t)
+        for(uint32_t i_t = 1; i_t <= reduce_count; ++i_t)
         {
             acc_buf.Clear();
             acc_load.Run(workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
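Besides deleting the is_thread_local_1d_id_idx helper and the printf tracing, the one functional change in this file is the loop bound in AccumulatePartials: it now runs for i_t = 1 .. reduce_count inclusive, tracking the value returned by the reworked WaitForNeighbours rather than the raw flag word the previous code iterated below. A scalar sketch of the loop shape, with hypothetical names (own_partial, load_neighbour_partial) standing in for the acc_load/acc_buf machinery:

    // Sketch only; load_neighbour_partial() stands in for acc_load.Run(...) into acc_buf.
    float acc = own_partial;                      // this workgroup's tile already sits in registers
    for(uint32_t i_t = 1; i_t <= reduce_count; ++i_t)
    {
        acc += load_neighbour_partial(i_t);       // one strided workspace slot per contributing workgroup
    }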
library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -16,96 +16,96 @@ namespace tensor_operation {
 namespace device {
 namespace instance {

-// void add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(
-//     std::vector<std::unique_ptr<DeviceGroupedGemm<Row, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple,
-//                                                   F16, PassThrough, PassThrough, PassThrough>>>& instances);
+void add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(
+    std::vector<std::unique_ptr<DeviceGroupedGemm<Row, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple,
+                                                  F16, PassThrough, PassThrough, PassThrough>>>& instances);

-// void add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(
-//     std::vector<std::unique_ptr<DeviceGroupedGemm<Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple,
-//                                                   F16, PassThrough, PassThrough, PassThrough>>>& instances);
+void add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(
+    std::vector<std::unique_ptr<DeviceGroupedGemm<Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple,
+                                                  F16, PassThrough, PassThrough, PassThrough>>>& instances);

-// void add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(
-//     std::vector<std::unique_ptr<DeviceGroupedGemm<Col, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple,
-//                                                   F16, PassThrough, PassThrough, PassThrough>>>& instances);
+void add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(
+    std::vector<std::unique_ptr<DeviceGroupedGemm<Col, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple,
+                                                  F16, PassThrough, PassThrough, PassThrough>>>& instances);

-// void add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(
-//     std::vector<std::unique_ptr<DeviceGroupedGemm<Col, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple,
-//                                                   F16, PassThrough, PassThrough, PassThrough>>>& instances);
+void add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(
+    std::vector<std::unique_ptr<DeviceGroupedGemm<Col, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple,
+                                                  F16, PassThrough, PassThrough, PassThrough>>>& instances);

-// void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(
-//     std::vector<std::unique_ptr<DeviceGroupedGemm<Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple,
-//                                                   F16, PassThrough, PassThrough, PassThrough>>>& instances);
+void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(
+    std::vector<std::unique_ptr<DeviceGroupedGemm<Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple,
+                                                  F16, PassThrough, PassThrough, PassThrough>>>& instances);

-// void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(
-//     std::vector<std::unique_ptr<DeviceGroupedGemm<Row, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple,
-//                                                   F16, PassThrough, PassThrough, PassThrough>>>& instances);
+void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(
+    std::vector<std::unique_ptr<DeviceGroupedGemm<Row, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple,
+                                                  F16, PassThrough, PassThrough, PassThrough>>>& instances);

-// void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances(
-//     std::vector<std::unique_ptr<DeviceGroupedGemm<Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple,
-//                                                   F16, PassThrough, PassThrough, PassThrough>>>& instances);
+void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances(
+    std::vector<std::unique_ptr<DeviceGroupedGemm<Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple,
+                                                  F16, PassThrough, PassThrough, PassThrough>>>& instances);

 void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances(
     std::vector<std::unique_ptr<DeviceGroupedGemm<Row,

@@ -120,31 +120,31 @@ void add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances
                                                   PassThrough, PassThrough, PassThrough>>>& instances);

-// void add_device_grouped_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instances(
-//     std::vector<std::unique_ptr<DeviceGroupedGemm<Row, Row, Empty_Tuple, Row, F16, F8, Empty_Tuple,
-//                                                   F16, PassThrough, PassThrough, PassThrough>>>& instances);
+void add_device_grouped_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instances(
+    std::vector<std::unique_ptr<DeviceGroupedGemm<Row, Row, Empty_Tuple, Row, F16, F8, Empty_Tuple,
+                                                  F16, PassThrough, PassThrough, PassThrough>>>& instances);

-// void add_device_grouped_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_irregular_instances(
-//     std::vector<std::unique_ptr<DeviceGroupedGemm<Row, Row, Empty_Tuple, Row, F8, F16, Empty_Tuple,
-//                                                   F16, PassThrough, PassThrough, PassThrough>>>& instances);
+void add_device_grouped_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_irregular_instances(
+    std::vector<std::unique_ptr<DeviceGroupedGemm<Row, Row, Empty_Tuple, Row, F8, F16, Empty_Tuple,
+                                                  F16, PassThrough, PassThrough, PassThrough>>>& instances);

 template <typename ALayout,
           typename BLayout,

@@ -186,48 +186,48 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
             if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
                          is_same_v<ELayout, Row>)
             {
-                // add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
-                // add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
+                add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
+                add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
                 add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instances(
                     op_ptrs);
             }
             else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
                               is_same_v<ELayout, Row>)
             {
-                // add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
-                // add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
-                // add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances(
-                //     op_ptrs);
+                add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
+                add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
+                add_device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instances(
+                    op_ptrs);
+            }
+            else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
+                              is_same_v<ELayout, Row>)
+            {
+                add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(op_ptrs);
+            }
+            else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
+                              is_same_v<ELayout, Row>)
+            {
+                add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(op_ptrs);
             }
-            // else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
-            //                   is_same_v<ELayout, Row>)
-            // {
-            //     add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(op_ptrs);
-            // }
-            // else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
-            //                   is_same_v<ELayout, Row>)
-            // {
-            //     add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(op_ptrs);
-            // }
         }
+        else if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, f8_t> &&
+                          is_same_v<EDataType, half_t>)
+        {
+            if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
+                         is_same_v<ELayout, Row>)
+            {
+                add_device_grouped_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instances(op_ptrs);
+            }
+        }
+        else if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, half_t> &&
+                          is_same_v<EDataType, half_t>)
+        {
+            if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
+                         is_same_v<ELayout, Row>)
+            {
+                add_device_grouped_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_irregular_instances(op_ptrs);
+            }
+        }
-        // else if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, f8_t> &&
-        //                   is_same_v<EDataType, half_t>)
-        // {
-        //     if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
-        //                  is_same_v<ELayout, Row>)
-        //     {
-        //         add_device_grouped_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instances(op_ptrs);
-        //     }
-        // }
-        // else if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, half_t> &&
-        //                   is_same_v<EDataType, half_t>)
-        // {
-        //     if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
-        //                  is_same_v<ELayout, Row>)
-        //     {
-        //         add_device_grouped_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_irregular_instances(op_ptrs);
-        //     }
-        // }

         return op_ptrs;
     }
 };
library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_multiple_d.hpp

@@ -17,18 +17,18 @@ namespace device {
 namespace instance {

 // MultiD version
-// void add_device_grouped_gemm_multi_d_splitk_cshuffle_f16_f16_f16_mk_nk_mn_irregular_instances(
-//     std::vector<std::unique_ptr<DeviceGroupedGemm<Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple,
-//                                                   F16, PassThrough, PassThrough, PassThrough>>>& instances);
+void add_device_grouped_gemm_multi_d_splitk_cshuffle_f16_f16_f16_mk_nk_mn_irregular_instances(
+    std::vector<std::unique_ptr<DeviceGroupedGemm<Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple,
+                                                  F16, PassThrough, PassThrough, PassThrough>>>& instances);

 void add_device_grouped_gemm_multi_d_splitk_cshuffle_f16_f16_f16_mk_kn_mn_irregular_instances(
     std::vector<std::unique_ptr<DeviceGroupedGemm<Row,

@@ -93,8 +93,8 @@ struct DeviceOperationInstanceFactory<
         else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
                           is_same_v<ELayout, Row>)
         {
-            // add_device_grouped_gemm_multi_d_splitk_cshuffle_f16_f16_f16_mk_nk_mn_irregular_instances(
-            //     op_ptrs);
+            add_device_grouped_gemm_multi_d_splitk_cshuffle_f16_f16_f16_mk_nk_mn_irregular_instances(
+                op_ptrs);
         }
     }

     return op_ptrs;
library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt

 add_instance_library(device_grouped_gemm_instance
-    # device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp
-    # device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp
-    # device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp
-    # device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp
-    # device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp
-    # device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp
+    device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp
+    device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp
+    device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp
+    device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp
+    device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp
+    device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp
     device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instance.cpp
-    # device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instance.cpp
-    # device_grouped_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instance.cpp
-    # device_grouped_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_irregular_instance.cpp
+    device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instance.cpp
+    device_grouped_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instance.cpp
+    device_grouped_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_irregular_instance.cpp
 )
library/src/tensor_operation_instance/gpu/grouped_gemm_multiple_d/CMakeLists.txt

 add_instance_library(device_grouped_gemm_multiple_d_instance
-    # device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_f16_f16_f16_mk_nk_mn_irregular_instance.cpp
+    device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_f16_f16_f16_mk_nk_mn_irregular_instance.cpp
     device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_f16_f16_f16_mk_kn_mn_irregular_instance.cpp
 )
library/src/tensor_operation_instance/gpu/grouped_gemm_multiple_d/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_f16_f16_f16_mk_kn_mn_irregular_instance.cpp

This diff is collapsed in the page view (+61 −61).
profiler/include/profiler/profile_grouped_gemm_multiple_d_splitk_impl.hpp

@@ -39,7 +39,9 @@ bool profile_ggemm_multid_splitk(int do_verification,
                                  const std::vector<int>& StrideAs,
                                  const std::vector<int>& StrideBs,
                                  const std::vector<int>& StrideCs,
-                                 int kbatch = 1)
+                                 int kbatch      = 1,
+                                 int warmup_iter = 1,
+                                 int kernel_iter = 10)
 {
     bool pass = true;

@@ -250,23 +252,18 @@ bool profile_ggemm_multid_splitk(int do_verification,
     for(std::size_t j = 0; j < kbatch_list.size(); j++)
     {
         auto kbatch_curr = kbatch_list[j];
-        // std::cout << ">>> kbatch: " << kbatch_curr << std::endl;
         gptr->SetKBatchSize(argument_ptr.get(), kbatch_curr);

         DeviceMem gemm_desc_workspace(gemm_ptr->GetWorkSpaceSize(argument_ptr.get()));
         gemm_ptr->SetWorkSpacePointer(argument_ptr.get(),
                                       gemm_desc_workspace.GetDeviceBuffer());
-        // std::cout << "WorkspacePointer set!" << std::endl;

         if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
         {
             for(std::size_t i = 0; i < gemm_descs.size(); i++)
                 c_device_buf[i]->SetZero();

-            // invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false, 1});
-            // std::cout << ">>>>>GPU Run end!" << std::endl;
+            invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});

             if(do_verification)
             {

@@ -313,13 +310,12 @@ bool profile_ggemm_multid_splitk(int do_verification,
                 std::cout << ">>>>>CPU verification end!" << std::endl;
             }

             if(time_kernel)
             {
                 std::cout << ">>>>>GPU time profiling start!" << std::endl;
                 float avg_time = invoker_ptr->Run(
-                    // argument_ptr.get(), StreamConfig{nullptr, time_kernel, 1, 5, 30});
                     argument_ptr.get(),
-                    StreamConfig{nullptr, time_kernel, 1, 0, 1});
+                    StreamConfig{nullptr, time_kernel, 0, warmup_iter, kernel_iter});

                 std::size_t flop = 0, num_btype = 0;
                 for(std::size_t i = 0; i < gemm_descs.size(); i++)
                 {
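The profiler implementation now threads the two new parameters into the timing run. Judging from the initializers in this diff and the stream_config.log_level_ check earlier, the StreamConfig aggregate appears to be {stream, time_kernel, log_level, cold/warm-up iterations, timed repetitions}; a hedged usage sketch, to be checked against StreamConfig's definition before reuse:

    // Sketch under the assumed field order; defaults mirror the new signature.
    template <typename Invoker, typename Argument>
    float time_grouped_gemm(Invoker& invoker, Argument* arg, int warmup_iter = 1, int kernel_iter = 10)
    {
        return invoker.Run(arg,
                           StreamConfig{nullptr, /*time_kernel=*/true, /*log_level=*/0,
                                        warmup_iter, kernel_iter});
    }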
profiler/src/profile_grouped_gemm_multiple_d_splitk.cpp

@@ -70,6 +70,8 @@ int profile_grouped_gemm_multiple_d_splitk(int argc, char* argv[])
         << "arg8 to 13: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 "
            "64,64 64,64 128,128)\n"
         << "arg15: kbatch value (default 4)\n"
+        << "arg16: warm-up iterations (default 1)\n"
+        << "arg17: kernel repeat iterations (default 10)\n"
         << std::endl;
     exit(1);

@@ -90,6 +92,8 @@ int profile_grouped_gemm_multiple_d_splitk(int argc, char* argv[])
     const auto StrideBs = argToIntArray(argv[12]);
     const auto StrideCs = argToIntArray(argv[13]);
     const int kbatch    = argc == 15 ? std::stoi(argv[14]) : 1;
+    const int warmup_iter = argc == 16 ? std::stoi(argv[15]) : 1;
+    const int kernel_iter = argc == 17 ? std::stoi(argv[16]) : 10;

 #ifdef CK_ENABLE_FP16
     if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
     {

@@ -110,7 +114,9 @@ int profile_grouped_gemm_multiple_d_splitk(int argc, char* argv[])
             StrideAs,
             StrideBs,
             StrideCs,
-            kbatch);
+            kbatch,
+            warmup_iter,
+            kernel_iter);
     }
     else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
     {

@@ -131,7 +137,9 @@ int profile_grouped_gemm_multiple_d_splitk(int argc, char* argv[])
             StrideAs,
             StrideBs,
             StrideCs,
-            kbatch);
+            kbatch,
+            warmup_iter,
+            kernel_iter);
     }
     else
     {
test/work_scheduling/test_strided_reduction_tile_loop.cpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

 #include <memory>
 #include <vector>

@@ -150,18 +150,16 @@ __global__ void grouped_gemm_naive_strided_tile_loop_reduce(const GemmArgDesc* p
     if(b2c_tile_map.IsFirstKSplitBlock())
     {
         // Wait untill all other blocks for this [M,N] tile store their results.
-        work_scheduler.WaitForNeighbours(k_batch, output_tile_idx, output_tile_idx_offset);
+        index_t neighbour_count = work_scheduler.WaitForNeighbours(
+            k_batch, b2c_tile_map.GetTileKIdx(), output_tile_idx, output_tile_idx_offset);
         // Accumulate partial results. We can have different # of workgroups to reduce, thus we
         // read actual flag value.
-        const uint32_t flag_v = __builtin_amdgcn_readfirstlane(
-            work_scheduler.GetFlagValue(k_batch, output_tile_idx, output_tile_idx_offset));
-        for(uint32_t i = 1; i < flag_v; ++i)
+        for(index_t i = 1; i <= neighbour_count; ++i)
         {
             partial_result += p_workspace[(get_block_1d_id()) * MPerBlock * NPerBlock +
                                           i * MPerBlock * NPerBlock + get_thread_local_1d_id()];
         }
         // Signal waiting blocks that they can start use their workspace.
         work_scheduler.Reset(k_batch, output_tile_idx, output_tile_idx_offset);

@@ -284,11 +282,10 @@ struct GroupedGemmStridedTileLoopReduce
         DeviceMem gemm_workspace, gemm_flags;

-        // const index_t tiles_per_block = (tile_count + grid_size - 1) / grid_size;
+        const index_t tiles_per_block = (tile_count + grid_size - 1) / grid_size;
         // This is the number of MN-output tiles which we cover with workgroups.
         // We launch k_batch / tiles_per_block workgroups for each output tile.
-        // const index_t flag_count = (grid_size * tiles_per_block + k_batch - 1) / k_batch;
-        const index_t flag_count = tile_count / k_batch;
+        const index_t flag_count = (grid_size * tiles_per_block + k_batch - 1) / k_batch;

         gemm_workspace.Realloc(grid_size * MPerBlock * NPerBlock * sizeof(float));
         gemm_flags.Realloc(flag_count * sizeof(uint32_t));
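Taken together with the kernel above, the test exercises one flag-based handshake per [M,N] output tile. A condensed sketch of the per-workgroup flow it validates (names as in the diffs, control flow compressed, not a drop-in excerpt):

    do
    {
        // ... produce (or, in the test, fake) the partial result for the current K-split tile ...
        work_scheduler.FlagFinished(k_batch, output_tile_idx, output_tile_idx_offset);

        if(b2c_tile_map.IsFirstKSplitBlock())
        {
            // Owner of the first K tile: wait for the neighbours' partials, reduce them
            // from p_workspace, then release the slot.
            index_t neighbour_count = work_scheduler.WaitForNeighbours(
                k_batch, b2c_tile_map.GetTileKIdx(), output_tile_idx, output_tile_idx_offset);
            // ... accumulate neighbour_count partial tiles ...
            work_scheduler.Reset(k_batch, output_tile_idx, output_tile_idx_offset);
        }
        else if(work_scheduler.HasTile())
        {
            // Other contributors: block until the reducer has consumed the workspace.
            work_scheduler.WaitForReduction(k_batch, output_tile_idx, output_tile_idx_offset);
        }
    } while(work_scheduler.HasTile());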