gaoqiong / composable_kernel

Commit 12a4ea69, authored May 19, 2023 by aska-0096
Parent: 3ccfb0ae

Commit message: (3/5) Batched GEMM passes. Perf bug: skipping the A-side LDS has lower performance than skipping the B-side LDS.
Showing 3 changed files, with 36 additions and 104 deletions:

- example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp (+10, -10)
- include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp (+1, -79)
- include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp (+25, -15)
example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp

@@ -67,33 +67,33 @@ using DeviceOpInstanceKKNN =
         ASpec,
         BSpec,
         DESpec,
-        2,
+        1,
         128,
         64,
-        128,
-        32,
-        8,
+        64,
+        64,
+        4,
         16,
         16,
-        2,
+        1,
         4,
         S<4, 32, 1>,
         S<1, 0, 2>,
         S<1, 0, 2>,
         2,
-        8,
-        8,
+        4,
+        4,
         true,
         S<4, 32, 1>,
         S<1, 0, 2>,
         S<1, 0, 2>,
         2,
-        8,
-        8,
+        4,
+        4,
         true,
         1,
         1,
-        S<1, 32, 1, 4>,
+        S<1, 64, 1, 2>,
         8>;

 using DeviceOpInstance = DeviceOpInstanceKKNN;
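For readers skimming the instance above: in CK's example sources, S<...> is conventionally a local alias for ck::Sequence, along the lines of the sketch below. The alias is declared in the example's preamble, outside this hunk, so its exact form here is an assumption.

    // Hypothetical preamble sketch: S<4, 32, 1> then names a compile-time
    // integer sequence used for thread-cluster lengths and access orders in
    // the template arguments above.
    #include "ck/utility/sequence.hpp" // assumed header location

    template <ck::index_t... Is>
    using S = ck::Sequence<Is...>;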
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp

@@ -305,7 +305,7 @@ struct BlockwiseGemmWMMA
         static_for<0, KPerBlock / WmmaK, 1>{}(
             [&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ...
                 static_for<0, MRepeat, 1>{}([&](auto m0) {
                     // read A
                     a_thread_copy_.Run(
                         a_block_desc_k0_m0_m1_m2_k1,
                         make_tuple(Number<k * WmmaK / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
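The Number<k * WmmaK / A_K1 / A_KRow>{} index above is the K0 coordinate handed to each per-k copy. A minimal standalone sketch of that arithmetic, with assumed example values (WmmaK = 16, A_K1 = 8, A_KRow = 2; none of these are fixed by the hunk):

    // Sketch only: prints the K0 source offset that a_thread_copy_.Run would
    // receive for successive k, illustrating the "k=0,1,2" stepping noted in
    // the comment above.
    #include <cstdio>

    int main()
    {
        constexpr int WmmaK = 16, A_K1 = 8, A_KRow = 2; // assumed values
        for(int k = 0; k < 4; ++k)
            std::printf("k=%d -> K0 offset %d\n", k, k * WmmaK / A_K1 / A_KRow);
        return 0;
    }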
@@ -346,84 +346,6 @@ struct BlockwiseGemmWMMA
                                      i % B_K1))>{}];
                     });
-#if 0
-                    if (get_thread_local_1d_id() == 0){
-                        printf("repeat: m,n,k:(%02d, %02d, %02d) a_thread_buf: %04x %04x %04x %04x %04x %04x %04x %04x | %04x %04x %04x %04x %04x %04x %04x %04x\n",
-                        m0.value, n0.value, k.value,
-                        *(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
-                            make_tuple(0 / A_K1, m0, 0, 0, 0 % A_K1))>{}]))),
-                        *(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
-                            make_tuple(1 / A_K1, m0, 0, 0, 1 % A_K1))>{}]))),
-                        *(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
-                            make_tuple(2 / A_K1, m0, 0, 0, 2 % A_K1))>{}]))),
-                        *(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
-                            make_tuple(3 / A_K1, m0, 0, 0, 3 % A_K1))>{}]))),
-                        *(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
-                            make_tuple(4 / A_K1, m0, 0, 0, 4 % A_K1))>{}]))),
-                        *(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
-                            make_tuple(5 / A_K1, m0, 0, 0, 5 % A_K1))>{}]))),
-                        *(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
-                            make_tuple(6 / A_K1, m0, 0, 0, 6 % A_K1))>{}]))),
-                        *(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
-                            make_tuple(7 / A_K1, m0, 0, 0, 7 % A_K1))>{}]))),
-                        *(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
-                            make_tuple(8 / A_K1, m0, 0, 0, 8 % A_K1))>{}]))),
-                        *(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
-                            make_tuple(9 / A_K1, m0, 0, 0, 9 % A_K1))>{}]))),
-                        *(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
-                            make_tuple(10 / A_K1, m0, 0, 0, 10 % A_K1))>{}]))),
-                        *(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
-                            make_tuple(11 / A_K1, m0, 0, 0, 11 % A_K1))>{}]))),
-                        *(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
-                            make_tuple(12 / A_K1, m0, 0, 0, 12 % A_K1))>{}]))),
-                        *(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
-                            make_tuple(13 / A_K1, m0, 0, 0, 13 % A_K1))>{}]))),
-                        *(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
-                            make_tuple(14 / A_K1, m0, 0, 0, 14 % A_K1))>{}]))),
-                        *(reinterpret_cast<const uint16_t*>(&(a_thread_buf[Number<a_thread_desc_.CalculateOffset(
-                            make_tuple(15 / A_K1, m0, 0, 0, 15 % A_K1))>{}])))
-                        );
-                    }
-                    // if (get_thread_local_1d_id() == 0){
-                    // printf("repeat: m,n,k:(%02d, %02d, %02d) b_thread_buf: %04x %04x %04x %04x %04x %04x %04x %04x | %04x %04x %04x %04x %04x %04x %04x %04x\n",
-                    // m0.value, n0.value, k.value,
-                    // *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
-                    //     make_tuple(0 / B_K1, n0, 0, 0, 0 % B_K1))>{}]))),
-                    // *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
-                    //     make_tuple(1 / B_K1, n0, 0, 0, 1 % B_K1))>{}]))),
-                    // *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
-                    //     make_tuple(2 / B_K1, n0, 0, 0, 2 % B_K1))>{}]))),
-                    // *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
-                    //     make_tuple(3 / B_K1, n0, 0, 0, 3 % B_K1))>{}]))),
-                    // *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
-                    //     make_tuple(4 / B_K1, n0, 0, 0, 4 % B_K1))>{}]))),
-                    // *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
-                    //     make_tuple(5 / B_K1, n0, 0, 0, 5 % B_K1))>{}]))),
-                    // *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
-                    //     make_tuple(6 / B_K1, n0, 0, 0, 6 % B_K1))>{}]))),
-                    // *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
-                    //     make_tuple(7 / B_K1, n0, 0, 0, 7 % B_K1))>{}]))),
-                    // *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
-                    //     make_tuple(8 / B_K1, n0, 0, 0, 8 % B_K1))>{}]))),
-                    // *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
-                    //     make_tuple(9 / B_K1, n0, 0, 0, 9 % B_K1))>{}]))),
-                    // *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
-                    //     make_tuple(10 / B_K1, n0, 0, 0, 10 % B_K1))>{}]))),
-                    // *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
-                    //     make_tuple(11 / B_K1, n0, 0, 0, 11 % B_K1))>{}]))),
-                    // *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
-                    //     make_tuple(12 / B_K1, n0, 0, 0, 12 % B_K1))>{}]))),
-                    // *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
-                    //     make_tuple(13 / B_K1, n0, 0, 0, 13 % B_K1))>{}]))),
-                    // *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
-                    //     make_tuple(14 / B_K1, n0, 0, 0, 14 % B_K1))>{}]))),
-                    // *(reinterpret_cast<const uint16_t*>(&(b_thread_buf[Number<b_thread_desc_.CalculateOffset(
-                    //     make_tuple(15 / B_K1, n0, 0, 0, 15 % B_K1))>{}])))
-                    // );
-                    // }
-#endif

         using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
         using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type;
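The wmma_input_type_a/b aliases kept above bundle WmmaK elements of the input type into one register-resident vector for the WMMA intrinsic. CK's vector_type is the library's own construct; a rough stand-in using Clang's vector extension, purely for illustration (the element type and lane count are assumptions, not CK's actual definition):

    // Illustration only: a 16-lane fp16 vector comparable in shape to
    // vector_type<half_t, 16>::type.
    using half_t = _Float16;
    typedef half_t half16_t __attribute__((ext_vector_type(16)));
    static_assert(sizeof(half16_t) == 16 * sizeof(half_t), "16 fp16 lanes");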
include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp

@@ -125,6 +125,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
     static constexpr auto I3 = Number<3>{};
     static constexpr auto I4 = Number<4>{};
     static constexpr auto I5 = Number<5>{};
+    static constexpr auto I6 = Number<6>{};

     // K1 = Max Vector Access Pixels
     static constexpr auto K1Number = Number<K1>{};
@@ -136,9 +137,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
     static constexpr auto BEnableLds_auto = MWaves == 1 ? false : true;

     // If true, LDS is used unconditionally
-    // Bug, MNK vector load check not implemented correctly
-    static constexpr auto AEnableLds_manu = false;
-    static constexpr auto BEnableLds_manu = false;
+    static constexpr auto AEnableLds_manu = true;
+    static constexpr auto BEnableLds_manu = true;

     static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1);
     static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1);
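The net effect of the hunk above is that LDS staging is now forced on for both operands: the manual switch overrides the auto rule, and deep prefetch forces LDS regardless. A self-contained sketch of the gating (the MWaves and NumPrefetch values are assumed; the OR-combination is verbatim from the hunk):

    // Sketch: how the effective LDS switch is derived for one operand.
    #include <cstdio>

    int main()
    {
        constexpr int MWaves      = 2; // assumed wave split
        constexpr int NumPrefetch = 1; // assumed prefetch depth

        constexpr bool BEnableLds_auto = MWaves == 1 ? false : true; // from the hunk
        constexpr bool BEnableLds_manu = true;                       // from the hunk
        constexpr bool BEnableLds =
            BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1);

        std::printf("BEnableLds = %d\n", BEnableLds); // prints 1
        return 0;
    }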
@@ -220,18 +220,21 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
         }
         else
         {
-            constexpr auto A_KRow = WmmaK / K1;
-            const auto A_KWmma    = K / WmmaK;
+            constexpr auto A_KRow      = 2;
+            constexpr auto A_K0PerWmma = WmmaK / A_KRow / K1Number;
+            const auto A_KWmma         = K / WmmaK;
             const auto M0 = M / MPerBlock;
+            //              0              1           2            3          4          5       6
+            // M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1
             return transform_tensor_descriptor(
                 a_grid_desc_m_k,
                 make_tuple(make_unmerge_transform(
-                               make_tuple(A_KWmma, Number<A_KRow>{}, K1Number)),
+                               make_tuple(A_KWmma, Number<A_K0PerWmma>{}, Number<A_KRow>{}, K1Number)),
                            make_unmerge_transform(
                                make_tuple(M0 * MRepeat, Number<MWaves>{}, Number<MPerWmma>{}))),
                 make_tuple(Sequence<1>{}, Sequence<0>{}),
-                make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{}));
+                make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
         }
     }
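The reworked unmerge splits K into four factors (A_KWmma, A_K0PerWmma, A_KRow, A_K1) instead of three, which is why the output-dimension sequences grow from {0,3,5}/{1,2,4} to {0,3,4,6}/{1,2,5}. A quick arithmetic check with assumed sizes (WmmaK = 16, K1 = 4, K = 128; only A_KRow = 2 is fixed by the hunk):

    // Sketch: the four K-factors multiply back to the original K extent.
    #include <cassert>

    int main()
    {
        constexpr int WmmaK = 16, K1 = 4, K = 128;       // assumed sizes
        constexpr int A_KRow      = 2;                   // from the hunk
        constexpr int A_K0PerWmma = WmmaK / A_KRow / K1; // = 2, as in the hunk
        const int A_KWmma         = K / WmmaK;           // = 8
        assert(A_KWmma * A_K0PerWmma * A_KRow * K1 == K);
        return 0;
    }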
@@ -309,18 +312,21 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
         }
         else
         {
-            constexpr auto B_KRow = WmmaK / K1;
-            const auto B_KWmma    = K / WmmaK;
+            constexpr auto B_KRow      = 2;
+            constexpr auto B_K0PerWmma = WmmaK / B_KRow / K1Number;
+            const auto B_KWmma         = K / WmmaK;
             const auto N0 = N / NPerBlock;
+            //              0              1           2            3          4          5       6
+            // M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1
             return transform_tensor_descriptor(
                 b_grid_desc_n_k,
                 make_tuple(make_unmerge_transform(
-                               make_tuple(B_KWmma, Number<B_KRow>{}, K1Number)),
+                               make_tuple(B_KWmma, Number<B_K0PerWmma>{}, Number<B_KRow>{}, K1Number)),
                            make_unmerge_transform(
                                make_tuple(N0 * NRepeat, Number<NWaves>{}, Number<NPerWmma>{}))),
                 make_tuple(Sequence<1>{}, Sequence<0>{}),
-                make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{}));
+                make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
         }
     }
@@ -752,7 +758,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
             else
             {
                 return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I3) *
-                       arg.a_grid_desc_.GetLength(I5);
+                       arg.a_grid_desc_.GetLength(I4) * arg.a_grid_desc_.GetLength(I6);
             }
         }();
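With the seven-dimensional A layout documented above (A_KWmma, MBlock*MRepeat, MWaves, A_K0PerWmma, A_KRow, MPerWmma, A_K1), dimensions 0, 3, 4, and 6 are exactly the K-factors, so the new product recovers the full K extent where the old code multiplied only dims 0, 3, and 5 of the previous layout. A toy check reusing the assumed sizes from the earlier sketch:

    // Sketch: the product of dims 0, 3, 4, 6 of the (assumed) 7-d A
    // descriptor lengths equals K.
    #include <cassert>

    int main()
    {
        // {A_KWmma, MBlock*MRepeat, MWaves, A_K0PerWmma, A_KRow, MPerWmma, A_K1}
        const int len[7] = {8, 4, 2, 2, 2, 16, 4}; // assumed example lengths
        assert(len[0] * len[3] * len[4] * len[6] == 128); // == K
        return 0;
    }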
@@ -1036,7 +1042,11 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
             << MRepeat << ", "
             << NRepeat
             << ">"
-            << " NumPrefetch: "
-            << NumPrefetch << ", "
+            << " AEnableLds: "
+            << AEnableLds << ", "
+            << "BEnableLds: "
+            << BEnableLds << ", "
+            << "NumPrefetch: "
+            << NumPrefetch << ", "
             << "LoopScheduler: "
             << LoopSchedToString[LoopSched] << ", "