Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
69add6ff
"cacheflow/vscode:/vscode.git/clone" did not exist on "376725ce74d2d75490eed1840b41de00c0e4acd6"
Commit
69add6ff
authored
May 11, 2022
by
Jing Zhang
Browse files
moved gemm_descs_args into const buff
parent
76764d8c
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
74 additions
and
91 deletions
+74
-91
include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp
...k/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp
+74
-91
No files found.
include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp
View file @
69add6ff
...
...
@@ -17,6 +17,10 @@ namespace ck {
namespace
tensor_operation
{
namespace
device
{
#define CK_GEMM_DESCS_CONSTANT_BUFF_SIZE 1048576 // 1MB for 1000 gemm_descs
__constant__
static
char
gemm_descs_const_
[
CK_GEMM_DESCS_CONSTANT_BUFF_SIZE
];
template
<
typename
GridwiseGemm
,
typename
FloatAB
,
typename
FloatC
,
...
...
@@ -24,15 +28,12 @@ template <typename GridwiseGemm,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
bool
HasMainKBlockLoop
,
index_t
MaxGroupCount
>
bool
HasMainKBlockLoop
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
kernel_grouped_gemm_xdlops_v2r3
(
const
StaticallyIndexedArray
<
GemmDesc
,
MaxGroupCount
>
gemm_descs
,
const
index_t
group_count
,
kernel_grouped_gemm_xdlops_v2r3
(
const
index_t
group_count
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
CElementwiseOperation
c_element_op
)
...
...
@@ -42,39 +43,16 @@ __global__ void
const
index_t
block_id
=
get_block_1d_id
();
#if 1
static_for
<
0
,
MaxGroupCount
,
1
>
{}([
&
](
auto
i
)
{
if
(
block_id
>=
gemm_descs
[
i
].
BlockStart_
&&
block_id
<
gemm_descs
[
i
].
BlockEnd_
&&
i
<
group_count
)
{
auto
group_id
=
i
;
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
gemm_descs
[
group_id
].
a_ptr
,
gemm_descs
[
group_id
].
b_ptr
,
gemm_descs
[
group_id
].
c_ptr
,
p_shared
,
gemm_descs
[
group_id
].
a_grid_desc_k0_m_k1_
,
gemm_descs
[
group_id
].
b_grid_desc_k0_n_k1_
,
gemm_descs
[
group_id
].
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
,
a_element_op
,
b_element_op
,
c_element_op
,
gemm_descs
[
group_id
].
grouped_gemm_block_2_ctile_map_
);
}
});
#else
const
auto
gemm_desc_ptr
=
reinterpret_cast
<
const
GemmDesc
*>
(
&
gemm_descs
);
const
auto
gemm_desc_ptr
=
reinterpret_cast
<
const
GemmDesc
*>
(
gemm_descs_const_
);
index_t
group_id
=
0
;
static_for
<
0
,
MaxGroupCount
,
1
>
{}([
&
](
auto
i
)
{
group_id
=
(
block_id
>=
gemm_descs
[
i
].
BlockStart
&&
block_id
<
gemm_descs
[
i
].
BlockEnd
&&
i
<
group_count
)
for
(
index_t
i
=
0
;
i
<
group_count
;
i
++
)
{
group_id
=
(
block_id
>=
gemm_desc_ptr
[
i
].
BlockStart_
&&
block_id
<
gemm_desc_ptr
[
i
].
BlockEnd_
)
?
i
:
group_id
;
});
const
index_t
block_id_grp
=
block_id
-
gemm_desc_ptr
[
group_id
].
BlockStart
;
}
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
gemm_desc_ptr
[
group_id
].
a_ptr
,
...
...
@@ -87,9 +65,7 @@ __global__ void
a_element_op
,
b_element_op
,
c_element_op
,
gemm_desc_ptr
[
group_id
].
block_2_ctile_map_
,
block_id_grp
);
#endif
gemm_desc_ptr
[
group_id
].
grouped_gemm_block_2_ctile_map_
);
#else
ignore
=
gemm_descs
;
ignore
=
group_count
;
...
...
@@ -451,33 +427,28 @@ struct DeviceGroupedGemmXdl
float
Run
(
const
Argument
&
arg
,
int
nrepeat
=
1
)
{
StaticallyIndexedArray
<
GemmDescKernelArg
,
MaxGroupCount
>
gemm_desc_kernel_args
;
bool
has_main_k_block_loop
=
true
;
static_for
<
0
,
MaxGroupCount
,
1
>
{}([
&
](
auto
i
)
{
if
(
i
<
arg
.
gemm_desc_kernel_arg_
.
size
())
for
(
std
::
size_t
i
=
0
;
i
<
arg
.
gemm_desc_kernel_arg_
.
size
();
i
++
)
{
gemm_desc_kernel_args
(
i
)
=
arg
.
gemm_desc_kernel_arg_
[
i
];
std
::
cout
<<
"group: "
<<
i
<<
" arg.a_grid_desc_k0_m_k1_{"
<<
gemm_desc_kernel_arg
s
[
i
].
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
)
<<
", "
<<
gemm_desc_kernel_arg
s
[
i
].
a_grid_desc_k0_m_k1_
.
GetLength
(
I1
)
<<
", "
<<
gemm_desc_kernel_arg
s
[
i
].
a_grid_desc_k0_m_k1_
.
GetLength
(
I2
)
<<
"}"
;
<<
arg
.
gemm_desc_kernel_arg
_
[
i
].
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
gemm_desc_kernel_arg
_
[
i
].
a_grid_desc_k0_m_k1_
.
GetLength
(
I1
)
<<
", "
<<
arg
.
gemm_desc_kernel_arg
_
[
i
].
a_grid_desc_k0_m_k1_
.
GetLength
(
I2
)
<<
"}"
;
std
::
cout
<<
", arg.b_grid_desc_k0_n_k1_{"
<<
gemm_desc_kernel_arg
s
[
i
].
b_grid_desc_k0_n_k1_
.
GetLength
(
I0
)
<<
", "
<<
gemm_desc_kernel_arg
s
[
i
].
b_grid_desc_k0_n_k1_
.
GetLength
(
I1
)
<<
", "
<<
gemm_desc_kernel_arg
s
[
i
].
b_grid_desc_k0_n_k1_
.
GetLength
(
I2
)
<<
"}"
;
<<
arg
.
gemm_desc_kernel_arg
_
[
i
].
b_grid_desc_k0_n_k1_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
gemm_desc_kernel_arg
_
[
i
].
b_grid_desc_k0_n_k1_
.
GetLength
(
I1
)
<<
", "
<<
arg
.
gemm_desc_kernel_arg
_
[
i
].
b_grid_desc_k0_n_k1_
.
GetLength
(
I2
)
<<
"}"
;
std
::
cout
<<
", arg.c_grid_desc_m_n_{ "
<<
gemm_desc_kernel_arg
s
[
i
].
c_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
<<
gemm_desc_kernel_arg
s
[
i
].
c_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
arg
.
gemm_desc_kernel_arg
_
[
i
].
c_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
gemm_desc_kernel_arg
_
[
i
].
c_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
if
(
!
GridwiseGemm
::
CheckValidity
(
gemm_desc_kernel_arg
s
[
i
].
a_grid_desc_k0_m_k1_
,
gemm_desc_kernel_arg
s
[
i
].
b_grid_desc_k0_n_k1_
,
gemm_desc_kernel_arg
s
[
i
].
c_grid_desc_m_n_
,
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
gemm_desc_kernel_arg
_
[
i
].
a_grid_desc_k0_m_k1_
,
arg
.
gemm_desc_kernel_arg
_
[
i
].
b_grid_desc_k0_n_k1_
,
arg
.
gemm_desc_kernel_arg
_
[
i
].
c_grid_desc_m_n_
,
arg
.
M01_
,
arg
.
N01_
))
{
...
...
@@ -485,15 +456,27 @@ struct DeviceGroupedGemmXdl
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"
);
}
const
auto
K
=
gemm_desc_kernel_arg
s
[
i
].
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
)
*
gemm_desc_kernel_arg
s
[
i
].
a_grid_desc_k0_m_k1_
.
GetLength
(
I2
);
const
auto
K
=
arg
.
gemm_desc_kernel_arg
_
[
i
].
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
)
*
arg
.
gemm_desc_kernel_arg
_
[
i
].
a_grid_desc_k0_m_k1_
.
GetLength
(
I2
);
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
)
!=
has_main_k_block_loop
)
{
throw
std
::
runtime_error
(
"wrong! not all gemm has_main_k_block_loop"
);
}
}
});
if
(
sizeof
(
GemmDescKernelArg
)
*
arg
.
gemm_desc_kernel_arg_
.
size
()
>
CK_GEMM_DESCS_CONSTANT_BUFF_SIZE
)
{
throw
std
::
runtime_error
(
"wrong! too many gemms"
);
}
hipGetErrorString
(
hipMemcpyToSymbol
(
HIP_SYMBOL
(
gemm_descs_const_
),
arg
.
gemm_desc_kernel_arg_
.
data
(),
arg
.
gemm_desc_kernel_arg_
.
size
()
*
sizeof
(
GemmDescKernelArg
),
0
,
hipMemcpyHostToDevice
));
float
ave_time
=
0
;
...
...
@@ -503,19 +486,17 @@ struct DeviceGroupedGemmXdl
kernel_grouped_gemm_xdlops_v2r3
<
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
remove_reference_t
<
GemmDescKernelArg
>
,
GemmDescKernelArg
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
true
,
MaxGroupCount
>
;
true
>
;
ave_time
=
launch_and_time_kernel
(
kernel
,
nrepeat
,
dim3
(
arg
.
grid_size_
),
dim3
(
BlockSize
),
0
,
gemm_desc_kernel_args
,
arg
.
gemm_desc_kernel_arg_
.
size
(),
arg
.
a_element_op_
,
arg
.
b_element_op_
,
...
...
@@ -527,19 +508,17 @@ struct DeviceGroupedGemmXdl
kernel_grouped_gemm_xdlops_v2r3
<
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
remove_reference_t
<
GemmDescKernelArg
>
,
GemmDescKernelArg
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
false
,
MaxGroupCount
>
;
false
>
;
ave_time
=
launch_and_time_kernel
(
kernel
,
nrepeat
,
dim3
(
arg
.
grid_size_
),
dim3
(
BlockSize
),
0
,
gemm_desc_kernel_args
,
arg
.
gemm_desc_kernel_arg_
.
size
(),
arg
.
a_element_op_
,
arg
.
b_element_op_
,
...
...
@@ -566,7 +545,11 @@ struct DeviceGroupedGemmXdl
{
if
(
ck
::
type_convert
<
ck
::
index_t
>
(
arg
.
gemm_desc_kernel_arg_
.
size
())
!=
arg
.
group_count_
)
return
false
;
else
if
(
sizeof
(
GemmDescKernelArg
)
*
arg
.
gemm_desc_kernel_arg_
.
size
()
>
CK_GEMM_DESCS_CONSTANT_BUFF_SIZE
)
return
false
;
return
true
;
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment