Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
55ab4687
Commit
55ab4687
authored
Mar 05, 2022
by
Jing Zhang
Browse files
perf test
parent
bbe5c0c7
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
156 additions
and
91 deletions
+156
-91
composable_kernel/include/tensor_operation/gridwise_grouped_gemm_xdlops_v2r3.hpp
...de/tensor_operation/gridwise_grouped_gemm_xdlops_v2r3.hpp
+120
-58
device_operation/include/device_grouped_gemm_xdl.hpp
device_operation/include/device_grouped_gemm_xdl.hpp
+17
-28
example/12_grouped_gemm_xdl/grouped_gemm_xdl.cpp
example/12_grouped_gemm_xdl/grouped_gemm_xdl.cpp
+19
-5
No files found.
composable_kernel/include/tensor_operation/gridwise_grouped_gemm_xdlops_v2r3.hpp
View file @
55ab4687
...
...
@@ -29,19 +29,20 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
kernel_gemm_xdlops_v2r3
(
kernel_
grouped_
gemm_xdlops_v2r3
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
const
AGridDesc_K0_M_K1
a_grid_desc_k0_m_k1
,
const
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1
,
const
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
const
GemmDesc
gemm_shapes
,
const
StaticallyIndexedArray
<
AGridDesc_K0_M_K1
,
MaxGroupCount
>
a_grid_desc_k0_m_k1
,
const
StaticallyIndexedArray
<
BGridDesc_K0_N_K1
,
MaxGroupCount
>
b_grid_desc_k0_n_k1
,
const
StaticallyIndexedArray
<
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
,
MaxGroupCount
>
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
const
StaticallyIndexedArray
<
GemmDesc
,
MaxGroupCount
>
gemm_shapes
,
const
index_t
group_count
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
CElementwiseOperation
c_element_op
,
const
Block2CTileMap
block_2_ctile_map
)
const
StaticallyIndexedArray
<
Block2CTileMap
,
MaxGroupCount
>
block_2_ctile_map
)
{
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
...
...
@@ -54,60 +55,122 @@ __global__ void
index_t
c_offset_grp
=
0
;
static_for
<
0
,
MaxGroupCount
,
1
>
{}([
&
](
auto
i
)
{
if
(
i
<
group_count
)
if
(
block_id
>=
gemm_shapes
[
i
].
BlockStart
&&
block_id
<
(
gemm_shapes
[
i
].
BlockStart
+
gemm_shapes
[
i
].
BlockSize
))
{
if
(
block_id
>=
gemm_shapes
[
i
].
BlockStart
&&
block_id
<
(
gemm_shapes
[
i
].
BlockStart
+
gemm_shapes
[
i
].
BlockSize
))
{
group_id
=
i
;
block_id_grp
=
block_id
-
gemm_shapes
[
i
].
BlockStart
;
a_offset_grp
=
gemm_shapes
[
i
].
OffsetA
;
b_offset_grp
=
gemm_shapes
[
i
].
OffsetB
;
c_offset_grp
=
gemm_shapes
[
i
].
OffsetC
;
// if(get_thread_local_1d_id() == 0)
// printf("%d %d %d %d %d %d\n",
// block_id,
// group_id,
// block_id_grp,
// a_offset_grp,
// b_offset_grp,
// c_offset_grp);
}
group_id
=
i
;
block_id_grp
=
block_id
-
gemm_shapes
[
i
].
BlockStart
;
a_offset_grp
=
gemm_shapes
[
i
].
OffsetA
;
b_offset_grp
=
gemm_shapes
[
i
].
OffsetB
;
c_offset_grp
=
gemm_shapes
[
i
].
OffsetC
;
GridwiseGemm
::
template
Run
<
HasMainK0BlockLoop
>(
p_a_grid
+
a_offset_grp
,
p_b_grid
+
b_offset_grp
,
p_c_grid
+
c_offset_grp
,
p_shared
,
a_grid_desc_k0_m_k1
[
i
],
b_grid_desc_k0_n_k1
[
i
],
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
[
i
],
a_element_op
,
b_element_op
,
c_element_op
,
block_2_ctile_map
[
i
],
block_id_grp
);
// if(get_thread_local_1d_id() == 0)
// printf("%d %d %d %d %d %d\n",
// block_id,
// group_id,
// block_id_grp,
// a_offset_grp,
// b_offset_grp,
// c_offset_grp);
}
});
}
template
<
typename
GridwiseGemm
,
typename
FloatAB
,
typename
FloatC
,
typename
AGridDesc_K0_M_K1
,
typename
BGridDesc_K0_N_K1
,
typename
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
,
typename
GemmDesc
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
Block2CTileMap
,
bool
HasMainK0BlockLoop
,
index_t
MaxGroupCount
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
kernel_grouped_gemm_xdlops_v2r4
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
const
StaticallyIndexedArray
<
AGridDesc_K0_M_K1
,
MaxGroupCount
>
a_grid_desc_k0_m_k1
,
const
StaticallyIndexedArray
<
BGridDesc_K0_N_K1
,
MaxGroupCount
>
b_grid_desc_k0_n_k1
,
const
StaticallyIndexedArray
<
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
,
MaxGroupCount
>
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
const
StaticallyIndexedArray
<
GemmDesc
,
MaxGroupCount
>
gemm_shapes
,
const
index_t
group_count
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
CElementwiseOperation
c_element_op
,
const
StaticallyIndexedArray
<
Block2CTileMap
,
MaxGroupCount
>
block_2_ctile_map
)
{
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
const
index_t
block_id
=
get_block_1d_id
();
__shared__
AGridDesc_K0_M_K1
a_grid_desc_k0_m_k1_
[
MaxGroupCount
];
__shared__
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1_
[
MaxGroupCount
];
__shared__
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
[
MaxGroupCount
];
__shared__
Block2CTileMap
block_2_ctile_map_
[
MaxGroupCount
];
__shared__
GemmDesc
gemm_shapes_
[
MaxGroupCount
];
if
(
get_thread_local_1d_id
())
{
static_for
<
0
,
MaxGroupCount
,
1
>
{}([
&
](
auto
i
)
{
a_grid_desc_k0_m_k1_
[
i
]
=
a_grid_desc_k0_m_k1
[
i
];
b_grid_desc_k0_n_k1_
[
i
]
=
b_grid_desc_k0_n_k1
[
i
];
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
[
i
]
=
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
[
i
];
block_2_ctile_map_
[
i
]
=
block_2_ctile_map
[
i
];
gemm_shapes_
[
i
]
=
gemm_shapes
[
i
];
});
}
block_sync_lds
();
index_t
group_id
=
0
;
static_for
<
0
,
MaxGroupCount
,
1
>
{}([
&
](
auto
i
)
{
group_id
=
(
block_id
>=
gemm_shapes
[
i
].
BlockStart
&&
block_id
<
(
gemm_shapes
[
i
].
BlockStart
+
gemm_shapes
[
i
].
BlockSize
))
?
i
:
group_id
;
});
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
if
(
group_id
==
0
)
GridwiseGemm
::
template
Run
<
HasMainK0BlockLoop
>(
p_a_grid
+
a_offset_grp
,
p_b_grid
+
b_offset_grp
,
p_c_grid
+
c_offset_grp
,
p_shared
,
a_grid_desc_k0_m_k1
[
I0
],
b_grid_desc_k0_n_k1
[
I0
],
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
[
I0
],
a_element_op
,
b_element_op
,
c_element_op
,
block_2_ctile_map
[
I0
],
block_id_grp
,
group_id
);
else
if
(
group_id
==
1
)
GridwiseGemm
::
template
Run
<
HasMainK0BlockLoop
>(
p_a_grid
+
a_offset_grp
,
p_b_grid
+
b_offset_grp
,
p_c_grid
+
c_offset_grp
,
p_shared
,
a_grid_desc_k0_m_k1
[
I1
],
b_grid_desc_k0_n_k1
[
I1
],
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
[
I1
],
a_element_op
,
b_element_op
,
c_element_op
,
block_2_ctile_map
[
I1
],
block_id_grp
,
group_id
);
const
index_t
block_id_grp
=
block_id
-
gemm_shapes_
[
group_id
].
BlockStart
;
const
index_t
a_offset_grp
=
gemm_shapes_
[
group_id
].
OffsetA
;
const
index_t
b_offset_grp
=
gemm_shapes_
[
group_id
].
OffsetB
;
const
index_t
c_offset_grp
=
gemm_shapes_
[
group_id
].
OffsetC
;
GridwiseGemm
::
template
Run
<
HasMainK0BlockLoop
>(
p_a_grid
+
a_offset_grp
,
p_b_grid
+
b_offset_grp
,
p_c_grid
+
c_offset_grp
,
p_shared
,
a_grid_desc_k0_m_k1_
[
group_id
],
b_grid_desc_k0_n_k1_
[
group_id
],
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
[
group_id
],
a_element_op
,
b_element_op
,
c_element_op
,
block_2_ctile_map_
[
group_id
],
block_id_grp
);
}
template
<
index_t
BlockSize
,
...
...
@@ -407,8 +470,7 @@ struct GridwiseGroupedGemm_k0mk1_k0nk1_mn_xdlops_v2r3
const
BElementwiseOperation
&
b_element_op
,
const
CElementwiseOperation
&
c_element_op
,
const
Block2CTileMap
&
block_2_ctile_map
,
const
index_t
block_id
,
const
index_t
group_id
)
const
index_t
block_id
)
{
const
auto
a_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_a_grid
,
a_grid_desc_k0_m_k1
.
GetElementSpaceSize
());
...
...
device_operation/include/device_grouped_gemm_xdl.hpp
View file @
55ab4687
...
...
@@ -350,6 +350,11 @@ struct DeviceGroupedGemmXdl
std
::
cout
<<
"group_id "
<<
i
<<
" BlockStart "
<<
gemm_shapes
(
i
).
BlockStart
<<
" BlockSize "
<<
gemm_shapes
(
i
).
BlockSize
<<
std
::
endl
;
}
else
{
gemm_shapes
(
i
).
BlockStart
=
-
1
;
gemm_shapes
(
i
).
BlockSize
=
-
1
;
}
});
const
auto
K0
=
arg
.
a_grid_desc_k0_m_k1_
[
Number
<
0
>
{}].
GetLength
(
I0
);
...
...
@@ -361,26 +366,18 @@ struct DeviceGroupedGemmXdl
#if 1
if
(
has_main_k0_block_loop
)
{
const
auto
kernel
=
kernel_gemm_xdlops_v2r3
<
const
auto
kernel
=
kernel_
grouped_
gemm_xdlops_v2r3
<
GridwiseGemm
,
ADataType
,
// TODO: distinguish A/B datatype
CDataType
,
remove_reference_t
<
StaticallyIndexedArray
<
DeviceGroupedGemmXdl
::
AGridDesc_K0_M_K1
,
MaxGroupCount
>>
,
remove_reference_t
<
StaticallyIndexedArray
<
DeviceGroupedGemmXdl
::
BGridDesc_K0_N_K1
,
MaxGroupCount
>>
,
remove_reference_t
<
StaticallyIndexedArray
<
typename
GridwiseGemm
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
,
MaxGroupCount
>>
,
remove_reference_t
<
StaticallyIndexedArray
<
gemm_desc
,
MaxGroupCount
>>
,
remove_reference_t
<
DeviceGroupedGemmXdl
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceGroupedGemmXdl
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
typename
GridwiseGemm
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
remove_reference_t
<
gemm_desc
>
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
remove_reference_t
<
StaticallyIndexedArray
<
typename
GridwiseGemm
::
DefaultBlock2CTileMap
,
MaxGroupCount
>>
,
remove_reference_t
<
typename
GridwiseGemm
::
DefaultBlock2CTileMap
>
,
true
,
MaxGroupCount
>
;
...
...
@@ -404,26 +401,18 @@ struct DeviceGroupedGemmXdl
}
else
{
const
auto
kernel
=
kernel_gemm_xdlops_v2r3
<
const
auto
kernel
=
kernel_
grouped_
gemm_xdlops_v2r3
<
GridwiseGemm
,
ADataType
,
// TODO: distinguish A/B datatype
CDataType
,
remove_reference_t
<
StaticallyIndexedArray
<
DeviceGroupedGemmXdl
::
AGridDesc_K0_M_K1
,
MaxGroupCount
>>
,
remove_reference_t
<
StaticallyIndexedArray
<
DeviceGroupedGemmXdl
::
BGridDesc_K0_N_K1
,
MaxGroupCount
>>
,
remove_reference_t
<
StaticallyIndexedArray
<
typename
GridwiseGemm
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
,
MaxGroupCount
>>
,
remove_reference_t
<
StaticallyIndexedArray
<
gemm_desc
,
MaxGroupCount
>>
,
remove_reference_t
<
DeviceGroupedGemmXdl
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceGroupedGemmXdl
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
typename
GridwiseGemm
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
remove_reference_t
<
gemm_desc
>
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
remove_reference_t
<
StaticallyIndexedArray
<
typename
GridwiseGemm
::
DefaultBlock2CTileMap
,
MaxGroupCount
>>
,
remove_reference_t
<
typename
GridwiseGemm
::
DefaultBlock2CTileMap
>
,
false
,
MaxGroupCount
>
;
...
...
example/12_grouped_gemm_xdl/grouped_gemm_xdl.cpp
View file @
55ab4687
...
...
@@ -76,7 +76,7 @@ int main(int argc, char* argv[])
exit
(
0
);
}
int
group_count
=
2
;
int
group_count
=
3
;
// GEMM shape
std
::
vector
<
ck
::
gemm_desc
>
gemm_shapes
;
...
...
@@ -85,11 +85,11 @@ int main(int argc, char* argv[])
for
(
int
i
=
0
;
i
<
group_count
;
i
++
)
{
int
M
=
2
56
*
(
i
+
1
)
;
int
N
=
5
12
*
(
i
+
1
)
;
int
K
=
1024
*
(
i
+
1
)
;
int
M
=
2
048
+
256
*
i
;
int
N
=
2048
+
12
8
*
i
;
int
K
=
256
+
128
*
i
;
gemm_shapes
.
push_back
({
M
,
N
,
K
,
K
,
K
,
N
,
A_size
,
B_size
,
C_size
,
0
,
0
});
gemm_shapes
.
push_back
({
M
,
N
,
K
,
K
,
K
,
N
,
A_size
,
B_size
,
C_size
,
-
1
,
-
1
});
A_size
+=
M
*
K
;
B_size
+=
N
*
K
;
...
...
@@ -115,6 +115,8 @@ int main(int argc, char* argv[])
std
::
vector
<
Tensor
<
CDataType
>>
c_host_tensors
;
std
::
vector
<
Tensor
<
CDataType
>>
c_device_tensors
;
std
::
size_t
flop
=
0
,
num_btype
=
0
;
for
(
int
i
=
0
;
i
<
gemm_shapes
.
size
();
i
++
)
{
a_tensors
.
push_back
(
Tensor
<
ADataType
>
(
f_host_tensor_descriptor
(
...
...
@@ -129,6 +131,11 @@ int main(int argc, char* argv[])
std
::
cout
<<
"gemm["
<<
i
<<
"] a_m_k: "
<<
a_tensors
[
i
].
mDesc
<<
" b_k_n: "
<<
b_tensors
[
i
].
mDesc
<<
" c_m_n: "
<<
c_device_tensors
[
i
].
mDesc
<<
std
::
endl
;
flop
+=
std
::
size_t
(
2
)
*
gemm_shapes
[
i
].
M
*
gemm_shapes
[
i
].
K
*
gemm_shapes
[
i
].
N
;
num_btype
+=
sizeof
(
ADataType
)
*
gemm_shapes
[
i
].
M
*
gemm_shapes
[
i
].
K
+
sizeof
(
BDataType
)
*
gemm_shapes
[
i
].
K
*
gemm_shapes
[
i
].
N
+
sizeof
(
CDataType
)
*
gemm_shapes
[
i
].
M
*
gemm_shapes
[
i
].
N
;
}
for
(
int
i
=
0
;
i
<
gemm_shapes
.
size
();
i
++
)
...
...
@@ -192,6 +199,13 @@ int main(int argc, char* argv[])
float
ave_time
=
invoker
.
Run
(
argument
,
nrepeat
);
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
gemm
.
GetTypeString
()
<<
std
::
endl
;
c_tensors_data
.
resize
(
C_size
);
c_tensors_device_buf
.
FromDevice
(
c_tensors_data
.
data
());
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment