gaoqiong / composable_kernel · Commits

Commit 593c2909
Authored Jun 30, 2023 by Jing Zhang; committed by root on Jun 30, 2023
Parent: 6819fc4c

    add simple kernel arg

Showing 1 changed file with 87 additions and 46 deletions (+87 -46):
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp

@@ -31,18 +31,20 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-        kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
-                                       const index_t group_count)
+        kernel_grouped_gemm_xdl_splitk(const void* gemm_desc_const,
+                                       const index_t group_count,
+                                       const index_t k_batch)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
     defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
     constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
 
     __shared__ uint8_t p_shared[shared_size];
 
+    const auto gemm_desc_ptr = reinterpret_cast<const GemmDesc*>(gemm_desc_const);
+
     const index_t block_id = get_block_1d_id();
 
-    const auto gemm_desc_ptr = reinterpret_cast<const GemmDesc*>(
-        cast_pointer_to_generic_address_space(gemm_descs_const));
-
+#if 0
     index_t left     = 0;
     index_t right    = group_count;
     index_t group_id = index_t((left + right) / 2);
@@ -60,18 +62,35 @@ __global__ void
         }
         group_id = index_t((left + right) / 2);
     }
+#else
+    if(block_id >= gemm_desc_ptr[group_count - 1].block_end_)
+        return;
+
+    index_t group_id = 0;
+    for(; group_id < group_count; group_id++)
+    {
+        if(block_id >= gemm_desc_ptr[group_id].block_start_ &&
+           block_id < gemm_desc_ptr[group_id].block_end_)
+        {
+            break;
+        }
+    }
+#endif
 
-    const auto M       = gemm_desc_ptr[group_id].karg_.M;
-    const auto N       = gemm_desc_ptr[group_id].karg_.N;
-    const auto K       = gemm_desc_ptr[group_id].karg_.K;
-    const auto StrideA = gemm_desc_ptr[group_id].karg_.StrideA;
-    const auto StrideB = gemm_desc_ptr[group_id].karg_.StrideB;
-    const auto StrideC = gemm_desc_ptr[group_id].karg_.StrideC;
-    const auto MPadded = gemm_desc_ptr[group_id].karg_.MPadded;
-    const auto NPadded = gemm_desc_ptr[group_id].karg_.NPadded;
-    const auto KPadded = gemm_desc_ptr[group_id].karg_.KPadded;
-    const auto K0      = gemm_desc_ptr[group_id].karg_.K0;
-    const auto k_batch = gemm_desc_ptr[group_id].karg_.k_batch;
+    const auto p_a_grid = gemm_desc_ptr[group_id].p_a_grid;
+    const auto p_b_grid = gemm_desc_ptr[group_id].p_b_grid;
+    const auto p_c_grid = gemm_desc_ptr[group_id].p_c_grid;
+
+    const auto M       = gemm_desc_ptr[group_id].M;
+    const auto N       = gemm_desc_ptr[group_id].N;
+    const auto K       = gemm_desc_ptr[group_id].K;
+    const auto StrideA = gemm_desc_ptr[group_id].StrideA;
+    const auto StrideB = gemm_desc_ptr[group_id].StrideB;
+    const auto StrideC = gemm_desc_ptr[group_id].StrideC;
+
+    const auto MPadded = GridwiseGemm::CalculateMPadded(M);
+    const auto NPadded = GridwiseGemm::CalculateNPadded(N);
+    const auto KPadded = GridwiseGemm::CalculateKPadded(K, k_batch);
+    const auto K0      = GridwiseGemm::CalculateK0(K, k_batch);
 
     static constexpr index_t MPerBlock = GridwiseGemm::GetMPerBlock();
     static constexpr index_t NPerBlock = GridwiseGemm::GetNPerBlock();
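Note: after this change the kernel derives MPadded, NPadded, KPadded and K0 on the device from M, N, K and the k_batch launch parameter instead of reading precomputed values out of karg_. Below is a minimal, illustrative sketch of the kind of round-up arithmetic such helpers typically perform; the tile sizes and the exact formulas inside GridwiseGemm are assumptions for this sketch, not taken from this file.

    // Illustrative only: assumed stand-ins for GridwiseGemm::CalculateMPadded / CalculateK0.
    #include <cstdint>

    using index_t = int32_t;

    constexpr index_t MPerBlock  = 128; // assumed tile sizes, for the sketch only
    constexpr index_t NPerBlock  = 128;
    constexpr index_t K0PerBlock = 4;
    constexpr index_t K1         = 8;

    constexpr index_t RoundUp(index_t x, index_t multiple)
    {
        return ((x + multiple - 1) / multiple) * multiple;
    }

    constexpr index_t CalcMPadded(index_t M) { return RoundUp(M, MPerBlock); }
    constexpr index_t CalcNPadded(index_t N) { return RoundUp(N, NPerBlock); }

    // Split-K view: K is divided into k_batch slices of K1-wide chunks, and each
    // slice's chunk count is rounded up to a multiple of K0PerBlock.
    constexpr index_t CalcK0(index_t K, index_t k_batch)
    {
        return RoundUp((K + K1 * k_batch - 1) / (K1 * k_batch), K0PerBlock);
    }

    constexpr index_t CalcKPadded(index_t K, index_t k_batch)
    {
        return CalcK0(K, k_batch) * K1 * k_batch;
    }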
@@ -88,9 +107,9 @@ __global__ void
             GroupedGemmBlock2ETileMap(local_b2c_tile_map, gemm_desc_ptr[group_id].block_start_);
 
         GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
-            gemm_desc_ptr[group_id].karg_.p_a_grid,
-            gemm_desc_ptr[group_id].karg_.p_b_grid,
-            gemm_desc_ptr[group_id].karg_.p_c_grid,
+            p_a_grid,
+            p_b_grid,
+            p_c_grid,
             M,
             N,
             K,
@@ -277,20 +296,12 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
             gemm_kernel_args_.reserve(group_count_);
 
-            skipped_group_count_ = 0;
-
             for(std::size_t i = 0; i < gemm_descs.size(); ++i)
             {
                 const index_t M = gemm_descs[i].M_;
                 const index_t N = gemm_descs[i].N_;
                 const index_t K = gemm_descs[i].K_;
 
-                if(M == 0)
-                {
-                    skipped_group_count_++;
-                    continue;
-                }
-
                 const index_t stride_a = gemm_descs[i].stride_A_;
                 const index_t stride_b = gemm_descs[i].stride_B_;
                 const index_t stride_c = gemm_descs[i].stride_C_;
@@ -379,7 +390,6 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
         // private:
         index_t K_BATCH;
         index_t group_count_;
-        index_t skipped_group_count_;
 
         std::vector<GemmTransKernelArg> gemm_kernel_args_;
         index_t grid_size_;
@@ -388,8 +398,28 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
     // Invoker
     struct Invoker : public BaseInvoker
     {
+        struct SimpleGemmArgument
+        {
+            const ADataType* p_a_grid;
+            const BDataType* p_b_grid;
+            EDataType* p_c_grid;
+
+            index_t M;
+            index_t N;
+            index_t K;
+
+            index_t StrideA;
+            index_t StrideB;
+            index_t StrideC;
+
+            index_t block_start_;
+            index_t block_end_;
+        };
+
         float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
         {
+            std::vector<SimpleGemmArgument> simple_gemm_kernel_args_;
+            simple_gemm_kernel_args_.reserve(arg.gemm_kernel_args_.size());
+
             index_t K0                  = arg.gemm_kernel_args_[0].karg_.K0;
             bool all_have_kbatch_gt_one = arg.gemm_kernel_args_[0].karg_.k_batch > 1;
             bool all_have_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
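Note: SimpleGemmArgument is a plain aggregate of pointers and index_t fields, which is what lets the invoker copy the whole vector to the device workspace as raw bytes and reinterpret it on the device. A small, hypothetical compile-time guard one could place next to the struct (not part of this commit):

    #include <type_traits>

    // Hypothetical sanity checks: the struct is memcpy'd to GPU memory and then
    // reinterpret_cast on the device, so it must stay trivially copyable and keep
    // an identical member layout on host and device.
    static_assert(std::is_trivially_copyable<SimpleGemmArgument>::value,
                  "SimpleGemmArgument is transferred to the device as raw bytes");
    static_assert(std::is_standard_layout<SimpleGemmArgument>::value,
                  "host and device must agree on the member layout");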
@@ -434,12 +464,26 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
                         << " in " << __FILE__ << ":" << __LINE__
                         << ", in function: " << __func__;
                     throw std::runtime_error(err.str());
                 }
+
+                simple_gemm_kernel_args_.push_back({karg.p_a_grid,
+                                                    karg.p_b_grid,
+                                                    karg.p_c_grid,
+                                                    karg.M,
+                                                    karg.N,
+                                                    karg.K,
+                                                    karg.StrideA,
+                                                    karg.StrideB,
+                                                    karg.StrideC,
+                                                    arg.gemm_kernel_args_[i].block_start_,
+                                                    arg.gemm_kernel_args_[i].block_end_});
             }
 
+            using GemmArgumentType = SimpleGemmArgument;
+
             hip_check_error(
                 hipMemcpyWithStream(arg.p_workspace_,
-                                    arg.gemm_kernel_args_.data(),
-                                    arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg),
+                                    simple_gemm_kernel_args_.data(),
+                                    simple_gemm_kernel_args_.size() * sizeof(GemmArgumentType),
                                     hipMemcpyHostToDevice,
                                     stream_config.stream_id_));
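Note: the host now packs one SimpleGemmArgument per group and ships the array to the preallocated workspace before the launch. A standalone sketch of that transfer, assuming the workspace was already sized for the argument array elsewhere; the helper name and error handling here are illustrative, not part of the change:

    #include <hip/hip_runtime.h>
    #include <stdexcept>
    #include <vector>

    // Illustrative host-side copy of the per-group kernel arguments.
    template <typename SimpleGemmArgument>
    void copy_args_to_workspace(void* p_workspace,
                                const std::vector<SimpleGemmArgument>& simple_args,
                                hipStream_t stream)
    {
        const std::size_t nbytes = simple_args.size() * sizeof(SimpleGemmArgument);

        // hipMemcpyWithStream issues the host-to-device copy on the given stream,
        // so a kernel launched afterwards on the same stream sees the arguments.
        const hipError_t err = hipMemcpyWithStream(
            p_workspace, simple_args.data(), nbytes, hipMemcpyHostToDevice, stream);

        if(err != hipSuccess)
        {
            throw std::runtime_error(hipGetErrorString(err));
        }
    }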
@@ -456,14 +500,14 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
                     }
                 }
 
                 ave_time =
                     launch_and_time_kernel(stream_config,
                                            kernel,
                                            dim3(arg.grid_size_),
                                            dim3(BlockSize),
                                            0,
-                                           cast_pointer_to_constant_address_space(arg.p_workspace_),
-                                           arg.gemm_kernel_args_.size());
+                                           arg.p_workspace_,
+                                           arg.gemm_kernel_args_.size(),
+                                           arg.gemm_kernel_args_[0].karg_.k_batch);
             };
 
             if(all_have_main_k0_block_loop)
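Note: with the new signature the kernel receives the raw workspace pointer, the group count, and a single k_batch value taken from the first group. A minimal sketch of an equivalent plain HIP launch, shown only to make the new argument list explicit; the real code keeps using CK's launch_and_time_kernel helper, and `kernel`, `arg`, `BlockSize` and `stream` are assumed to be in scope:

    // Hypothetical direct launch mirroring the launch_and_time_kernel call above.
    hipLaunchKernelGGL(kernel,
                       dim3(arg.grid_size_), // one workgroup per output tile across all groups
                       dim3(BlockSize),
                       0,                    // no dynamic LDS; the kernel declares __shared__ statically
                       stream,
                       arg.p_workspace_,     // device array of SimpleGemmArgument
                       static_cast<index_t>(arg.gemm_kernel_args_.size()),
                       arg.gemm_kernel_args_[0].karg_.k_batch);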
@@ -472,7 +516,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
             {
                 const auto kernel =
                     kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
-                                                   GemmTransKernelArg,
+                                                   GemmArgumentType,
                                                    true,
                                                    InMemoryDataOperationEnum::AtomicAdd>;
@@ -482,7 +526,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
             {
                 const auto kernel =
                     kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
-                                                   GemmTransKernelArg,
+                                                   GemmArgumentType,
                                                    true,
                                                    InMemoryDataOperationEnum::Set>;
@@ -495,7 +539,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
             {
                 const auto kernel =
                     kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
-                                                   GemmTransKernelArg,
+                                                   GemmArgumentType,
                                                    false,
                                                    InMemoryDataOperationEnum::AtomicAdd>;
@@ -505,7 +549,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
             {
                 const auto kernel =
                     kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
-                                                   GemmTransKernelArg,
+                                                   GemmArgumentType,
                                                    false,
                                                    InMemoryDataOperationEnum::Set>;
@@ -532,13 +576,10 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
     static bool IsSupportedArgument(const Argument& arg)
     {
-        if((ck::type_convert<ck::index_t>(arg.gemm_kernel_args_.size()) +
-            arg.skipped_group_count_) !=
-           arg.group_count_)
+        if(ck::type_convert<ck::index_t>(arg.gemm_kernel_args_.size()) != arg.group_count_)
         {
 #if DEBUG_LOG
-            std::cout << "The group count is not equal to sum of skipped groups "
-                         "and kernel args size!"
-                      << std::endl;
+            std::cout << "The group count is not equal to kernel args size!"
+                      << std::endl;
 #endif // DEBUG_LOG
             return false;
         }