gaoqiong / composable_kernel_ROCM · Commit 9177a207

Commit 9177a207, authored Jun 04, 2024 by Adam Osewski

Fix joining kbatch-tiles.

parent eaa68635

Showing 1 changed file with 33 additions and 134 deletions (+33 -134):
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_tile_loop.hpp
@@ -43,8 +43,6 @@ namespace device {
 /// @tparam FloatC Input tensor C elements' data type.
 /// @tparam Block2ETileMapKSplit The structure providing mapping between workgroup ids,
 ///         the data tiles to process and the output tiles.
-/// @tparam HasMainKBlockLoop Flag indicating whether all GEMM problem configurations
-///         need to loop over tiles in K dimension.
 ///
 template <typename GridwiseGemm,
           typename GemmDesc,
@@ -55,8 +53,7 @@ template <typename GridwiseGemm,
           typename Block2ETileMapKSplit,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
-          typename CDEElementwiseOperation,
-          bool HasMainKBlockLoop>
+          typename CDEElementwiseOperation>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
@@ -147,20 +144,20 @@ __global__ void
         //                   k_tiles);
         // }
         // just accumulate results in registers!
-        GridwiseGemm::template RunGEMM<HasMainKBlockLoop>(p_a_grid,
+        GridwiseGemm::template RunGEMM(p_a_grid,
                                        p_b_grid,
                                        static_cast<void*>(p_shared),
                                        a_element_op,
                                        b_element_op,
                                        M,
                                        N,
                                        K,
                                        StrideA,
                                        StrideB,
                                        k_batch,
                                        b2c_tile_map,
                                        results_buffer,
                                        k_tiles);
         // Move to the last processed k-tile
         b2c_tile_map.AdvanceTileKIdx(k_tiles - 1);
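The kernel-wide HasMainKBlockLoop template flag is gone: RunGEMM now walks the k_tiles k-batch tiles assigned to a workgroup itself and joins their partial results in registers, as the comment above notes. A minimal sketch of that accumulation idea follows; the names (KPerTile, accumulate_k_tiles) are invented for illustration and are not part of the repository:

    // Sketch only: one thread's view of joining k_tiles consecutive k-batch
    // tiles into a single register accumulator, instead of writing each
    // tile's partial result to global memory first.
    template <int KPerTile>
    void accumulate_k_tiles(const float* a, const float* b, float& acc, int k_tiles)
    {
        for(int t = 0; t < k_tiles; ++t)      // this workgroup's k-batch tiles
            for(int k = 0; k < KPerTile; ++k) // inner-product work within a tile
                acc += a[t * KPerTile + k] * b[t * KPerTile + k];
    }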
@@ -175,6 +172,7 @@ __global__ void
         GridwiseGemm::StorePartials(p_workspace, static_cast<void*>(p_shared), results_buffer);
 #if 1
+        __builtin_amdgcn_sched_barrier(0);
         // make sure all writes to gmem has finished.
         __builtin_amdgcn_s_waitcnt(0x0f70); // s_waitcnt vmcnt(0)
         // __builtin_amdgcn_s_waitcnt(0x0070); // s_waitcnt vmcnt(0) lgkmcnt(0)
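The added __builtin_amdgcn_sched_barrier(0) keeps the compiler from scheduling instructions across this point, and the existing s_waitcnt stalls until outstanding vector-memory operations retire. Assuming this targets the gfx9 s_waitcnt immediate layout (vmcnt in bits [3:0], expcnt in bits [6:4], lgkmcnt in bits [11:8]), 0x0f70 decodes to vmcnt(0) with expcnt and lgkmcnt left at their no-wait maxima, matching the comments. A hypothetical helper, not in the repository, makes the encoding explicit:

    // Hypothetical helper assuming the gfx9 s_waitcnt immediate layout:
    // vmcnt[3:0], expcnt[6:4], lgkmcnt[11:8]. Field widths differ on
    // other GPU generations.
    constexpr unsigned make_waitcnt(unsigned vmcnt, unsigned expcnt, unsigned lgkmcnt)
    {
        return (vmcnt & 0xFu) | ((expcnt & 0x7u) << 4) | ((lgkmcnt & 0xFu) << 8);
    }
    static_assert(make_waitcnt(0, 7, 15) == 0x0f70, "wait on vmcnt(0) only");
    static_assert(make_waitcnt(0, 7, 0) == 0x0070, "wait on vmcnt(0) and lgkmcnt(0)");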
@@ -510,73 +508,18 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
     void UpdateOccupancy()
     {
-        bool all_have_main_k_block_loop;
-        {
-            const auto a_grid_desc_ak0_m_ak1 =
-                GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(gemm_kernel_args_[0].M,
-                                                            gemm_kernel_args_[0].K,
-                                                            gemm_kernel_args_[0].StrideA,
-                                                            K_BATCH);
-            all_have_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(
-                a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2) /
-                K_BATCH);
-        }
-        for(std::size_t i = 0; i < gemm_kernel_args_.size(); ++i)
-        {
-            const auto& gemm_arg = gemm_kernel_args_[i];
-            auto kbatch          = K_BATCH;
-            const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
-                gemm_arg.M, gemm_arg.K, gemm_arg.StrideA, kbatch);
-            bool not_all_have_main_k_block_loop_same =
-                all_have_main_k_block_loop xor GridwiseGemm::CalculateHasMainKBlockLoop(
-                    a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2) /
-                    K_BATCH);
-            if(not_all_have_main_k_block_loop_same)
-            {
-                std::ostringstream err;
-                err << "Not all gemms have same value for main_k0_block_loop! in " << __FILE__
-                    << ":" << __LINE__ << ", in function: " << __func__;
-                throw std::runtime_error(err.str());
-            }
-        }
-        if(all_have_main_k_block_loop)
-        {
-            const auto kernel = kernel_grouped_gemm_xdl_splitk_v2<GridwiseGemm,
-                                                                  KernelArguments,
-                                                                  ADataType,
-                                                                  BDataType,
-                                                                  EDataType,
-                                                                  DsDataType,
-                                                                  Block2ETileMapKSplit,
-                                                                  AElementwiseOperation,
-                                                                  BElementwiseOperation,
-                                                                  CDEElementwiseOperation,
-                                                                  true>;
-            hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
-                &occupancy_num_blocks_, kernel, BlockSize, 0));
-        }
-        else
-        {
-            const auto kernel = kernel_grouped_gemm_xdl_splitk_v2<GridwiseGemm,
-                                                                  KernelArguments,
-                                                                  ADataType,
-                                                                  BDataType,
-                                                                  EDataType,
-                                                                  DsDataType,
-                                                                  Block2ETileMapKSplit,
-                                                                  AElementwiseOperation,
-                                                                  BElementwiseOperation,
-                                                                  CDEElementwiseOperation,
-                                                                  false>;
-            hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
-                &occupancy_num_blocks_, kernel, BlockSize, 0));
-        }
+        const auto kernel = kernel_grouped_gemm_xdl_splitk_v2<GridwiseGemm,
+                                                              KernelArguments,
+                                                              ADataType,
+                                                              BDataType,
+                                                              EDataType,
+                                                              DsDataType,
+                                                              Block2ETileMapKSplit,
+                                                              AElementwiseOperation,
+                                                              BElementwiseOperation,
+                                                              CDEElementwiseOperation>;
+        hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
+            &occupancy_num_blocks_, kernel, BlockSize, 0));
     }
     // private:
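With the loop flag removed there is only one kernel instantiation to measure, so UpdateOccupancy collapses to a single hipOccupancyMaxActiveBlocksPerMultiprocessor query. A self-contained sketch of that HIP call, using a dummy kernel in place of kernel_grouped_gemm_xdl_splitk_v2:

    #include <hip/hip_runtime.h>
    #include <cstdio>

    __global__ void dummy_kernel(float* p) { p[threadIdx.x] = 0.0f; }

    int main()
    {
        int num_blocks = 0;
        // How many 256-thread blocks (with no dynamic LDS) can be resident
        // per CU for this kernel?
        hipError_t err = hipOccupancyMaxActiveBlocksPerMultiprocessor(
            &num_blocks, dummy_kernel, /*blockSize=*/256, /*dynSharedMemPerBlk=*/0);
        if(err != hipSuccess)
        {
            std::printf("occupancy query failed: %s\n", hipGetErrorString(err));
            return 1;
        }
        std::printf("max active blocks per CU: %d\n", num_blocks);
        return 0;
    }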
@@ -631,8 +574,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
                       void* dev_gemm_workspace,
                       const StreamConfig& stream_config = StreamConfig{})
         {
-            [[maybe_unused]] auto [all_have_kbatch_gt_one, all_have_main_k_block_loop] =
-                CheckArgument(arg, stream_config);
+            CheckArgument(arg, stream_config);
             if(dev_gemm_args == nullptr)
             {
@@ -650,18 +592,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
                 throw std::runtime_error(err.str());
             }
-            float ave_time = 0;
-            if(all_have_main_k_block_loop)
-            {
-                ave_time =
-                    DispatchKernel<true>(arg, dev_gemm_args, dev_gemm_workspace, stream_config);
-            }
-            else
-            {
-                ave_time =
-                    DispatchKernel<false>(arg, dev_gemm_args, dev_gemm_workspace, stream_config);
-            }
+            float ave_time = DispatchKernel(arg, dev_gemm_args, dev_gemm_workspace, stream_config);
             return ave_time;
         }
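Run() no longer branches to pick a DispatchKernel<true>/DispatchKernel<false> instantiation. For contrast, here is a generic sketch (illustrative names only, not the repository's code) of the removed pattern, which promotes a runtime flag to a template argument and therefore forces two instantiations of everything downstream:

    // Sketch of the removed dispatch pattern: a runtime bool becomes a
    // compile-time template argument, so both branches must be compiled.
    template <bool HasMainKBlockLoop>
    float dispatch_impl() { return HasMainKBlockLoop ? 1.0f : 2.0f; } // stand-in body

    float dispatch(bool has_main_k_block_loop)
    {
        return has_main_k_block_loop ? dispatch_impl<true>() : dispatch_impl<false>();
    }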
@@ -708,22 +639,9 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
     }
     private:
-    auto CheckArgument(const Argument& arg, const StreamConfig& stream_config) const
+    void CheckArgument(const Argument& arg, const StreamConfig& stream_config) const
     {
-        bool all_have_kbatch_gt_one, all_have_main_k_block_loop;
-        {
-            const auto a_grid_desc_ak0_m_ak1 =
-                GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(arg.gemm_kernel_args_[0].M,
-                                                            arg.gemm_kernel_args_[0].K,
-                                                            arg.gemm_kernel_args_[0].StrideA,
-                                                            arg.K_BATCH);
-            all_have_kbatch_gt_one     = arg.K_BATCH > 1;
-            all_have_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(
-                a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2) /
-                kbatch);
-        }
+        bool all_have_kbatch_gt_one = arg.K_BATCH > 1;
         for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
         {
@@ -751,24 +669,8 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
                 throw std::runtime_error(err.str());
             }
-            const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
-                gemm_arg.M, gemm_arg.K, gemm_arg.StrideA, kbatch);
-            bool not_all_have_main_k_block_loop_same =
-                all_have_main_k_block_loop xor GridwiseGemm::CalculateHasMainKBlockLoop(
-                    a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2) /
-                    kbatch);
             bool not_all_have_kbatch_value_same = all_have_kbatch_gt_one xor (kbatch > 1);
-            if(not_all_have_main_k_block_loop_same)
-            {
-                std::ostringstream err;
-                err << "Not all gemms have same value for main_k0_block_loop! in " << __FILE__
-                    << ":" << __LINE__ << ", in function: " << __func__;
-                throw std::runtime_error(err.str());
-            }
             if(not_all_have_kbatch_value_same)
             {
                 std::ostringstream err;
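After this change CheckArgument keeps only the kbatch consistency test: every group must agree with the first group on whether split-K (kbatch > 1) is in use, with the xor flagging any mismatch. A standalone sketch of the same check, using assumed simplified types rather than the repository's Argument structures:

    #include <stdexcept>
    #include <vector>

    // Throws if any group disagrees with the reference on split-K usage,
    // mirroring the not_all_have_kbatch_value_same check in the diff.
    void check_kbatch_consistent(const std::vector<int>& kbatches, int reference_kbatch)
    {
        const bool all_have_kbatch_gt_one = reference_kbatch > 1;
        for(int kbatch : kbatches)
        {
            if(all_have_kbatch_gt_one != (kbatch > 1))
                throw std::runtime_error("Not all gemms have same kbatch value!");
        }
    }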
@@ -779,10 +681,8 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
                 throw std::runtime_error(err.str());
             }
         }
-        return std::make_tuple(all_have_kbatch_gt_one, all_have_main_k_block_loop);
     }
-    template <bool HasMainKBlockLoop>
     float DispatchKernel(const Argument& arg,
                          const void* dev_gemm_args,
                          void* dev_gemm_workspace,
@@ -797,8 +697,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
                                                               Block2ETileMapKSplit,
                                                               AElementwiseOperation,
                                                               BElementwiseOperation,
-                                                              CDEElementwiseOperation,
-                                                              HasMainKBlockLoop>;
+                                                              CDEElementwiseOperation>;
         return LaunchKernel(kernel, arg, dev_gemm_args, dev_gemm_workspace, stream_config);
     }