Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
f3111877
Commit
f3111877
authored
Mar 09, 2024
by
Jing Zhang
Browse files
fixed
parent
76bb51f4
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
44 additions
and
8 deletions
+44
-8
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
...ude/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
+43
-7
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp
+1
-1
No files found.
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
View file @
f3111877
...
@@ -67,11 +67,11 @@ struct BlockwiseGemmWMMA
...
@@ -67,11 +67,11 @@ struct BlockwiseGemmWMMA
// When not use LDS, each Row read half of whole data from source buffer, exchange the data via
// When not use LDS, each Row read half of whole data from source buffer, exchange the data via
// permutation
// permutation
#ifdef __gfx12__
#ifdef __gfx12__
static
constexpr
index_t
A_KRow
=
2
;
static
constexpr
index_t
B_KRow
=
2
;
#else
static
constexpr
index_t
A_KRow
=
1
;
static
constexpr
index_t
A_KRow
=
1
;
static
constexpr
index_t
B_KRow
=
1
;
static
constexpr
index_t
B_KRow
=
1
;
#else
static
constexpr
index_t
A_KRow
=
AEnableLds
?
1
:
2
;
static
constexpr
index_t
B_KRow
=
BEnableLds
?
1
:
2
;
#endif
#endif
static
constexpr
index_t
A_K1
=
ABlockDesc
{}.
GetLength
(
I5
);
static
constexpr
index_t
A_K1
=
ABlockDesc
{}.
GetLength
(
I5
);
...
@@ -563,6 +563,7 @@ struct BlockwiseGemmWMMA
...
@@ -563,6 +563,7 @@ struct BlockwiseGemmWMMA
#endif
#endif
protected:
protected:
#ifdef __gfx12__
static
constexpr
auto
a_thread_desc_
=
make_naive_tensor_descriptor
(
static
constexpr
auto
a_thread_desc_
=
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
KPack
/
A_K1
/
A_KRow
>
{},
Number
<
MRepeat
>
{},
I1
,
I1
,
I1
,
Number
<
A_K1
>
{}),
make_tuple
(
Number
<
KPack
/
A_K1
/
A_KRow
>
{},
Number
<
MRepeat
>
{},
I1
,
I1
,
I1
,
Number
<
A_K1
>
{}),
make_tuple
(
Number
<
A_K1
>
{},
make_tuple
(
Number
<
A_K1
>
{},
...
@@ -580,6 +581,35 @@ struct BlockwiseGemmWMMA
...
@@ -580,6 +581,35 @@ struct BlockwiseGemmWMMA
Number
<
B_K1
>
{},
Number
<
B_K1
>
{},
Number
<
B_K1
>
{},
Number
<
B_K1
>
{},
Number
<
1
>
{}));
Number
<
1
>
{}));
#else
static
constexpr
auto
a_thread_desc_
=
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
KPack
/
A_K1
/
A_KRow
>
{},
Number
<
MRepeat
>
{},
I1
,
Number
<
A_KRow
>
{},
I1
,
Number
<
A_K1
>
{}),
make_tuple
(
Number
<
A_K1
*
A_KRow
>
{},
Number
<
KPack
>
{},
Number
<
A_K1
*
A_KRow
>
{},
Number
<
A_K1
>
{},
Number
<
A_K1
>
{},
Number
<
1
>
{}));
static
constexpr
auto
b_thread_desc_
=
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
KPack
/
B_K1
/
B_KRow
>
{},
Number
<
NRepeat
>
{},
I1
,
Number
<
B_KRow
>
{},
I1
,
Number
<
B_K1
>
{}),
make_tuple
(
Number
<
B_K1
*
B_KRow
>
{},
Number
<
KPack
>
{},
Number
<
B_K1
*
B_KRow
>
{},
Number
<
B_K1
>
{},
Number
<
B_K1
>
{},
Number
<
1
>
{}));
#endif
// C[M, N, NumRegWMMA]
// C[M, N, NumRegWMMA]
static
constexpr
auto
c_thread_desc_
=
make_naive_tensor_descriptor_packed
(
static
constexpr
auto
c_thread_desc_
=
make_naive_tensor_descriptor_packed
(
...
@@ -610,7 +640,7 @@ struct BlockwiseGemmWMMA
...
@@ -610,7 +640,7 @@ struct BlockwiseGemmWMMA
template
<
>
template
<
>
struct
AThreadCopySelector
<
false
>
struct
AThreadCopySelector
<
false
>
{
{
using
type
=
ThreadwiseTensorSliceTransfer_StaticToStatic
<
using
type
=
ThreadwiseTensorSliceTransfer_StaticToStatic
_InterRow
<
FloatA
,
FloatA
,
FloatA
,
FloatA
,
decltype
(
a_block_desc_k0_m0_m1_m2_k1
),
decltype
(
a_block_desc_k0_m0_m1_m2_k1
),
...
@@ -619,7 +649,10 @@ struct BlockwiseGemmWMMA
...
@@ -619,7 +649,10 @@ struct BlockwiseGemmWMMA
Sequence
<
KPack
/
A_K1
/
A_KRow
,
1
,
1
,
1
,
1
,
A_K1
>
,
Sequence
<
KPack
/
A_K1
/
A_KRow
,
1
,
1
,
1
,
1
,
A_K1
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
5
,
A_K1
>
;
A_K1
,
0x76543210
,
0xfedcba98
,
TransposeC
?
false
:
true
>
;
};
};
template
<
bool
EnableLds
>
template
<
bool
EnableLds
>
...
@@ -647,7 +680,7 @@ struct BlockwiseGemmWMMA
...
@@ -647,7 +680,7 @@ struct BlockwiseGemmWMMA
template
<
>
template
<
>
struct
BThreadCopySelector
<
false
>
struct
BThreadCopySelector
<
false
>
{
{
using
type
=
ThreadwiseTensorSliceTransfer_StaticToStatic
<
using
type
=
ThreadwiseTensorSliceTransfer_StaticToStatic
_InterRow
<
FloatB
,
FloatB
,
FloatB
,
FloatB
,
decltype
(
b_block_desc_k0_n0_n1_n2_k1
),
decltype
(
b_block_desc_k0_n0_n1_n2_k1
),
...
@@ -656,7 +689,10 @@ struct BlockwiseGemmWMMA
...
@@ -656,7 +689,10 @@ struct BlockwiseGemmWMMA
Sequence
<
KPack
/
B_K1
/
B_KRow
,
1
,
1
,
1
,
1
,
B_K1
>
,
Sequence
<
KPack
/
B_K1
/
B_KRow
,
1
,
1
,
1
,
1
,
B_K1
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
5
,
B_K1
>
;
B_K1
,
0x76543210
,
0xfedcba98
,
TransposeC
?
true
:
false
>
;
};
};
typename
AThreadCopySelector
<
AEnableLds
>::
type
a_thread_copy_
;
typename
AThreadCopySelector
<
AEnableLds
>::
type
a_thread_copy_
;
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp
View file @
f3111877
...
@@ -135,7 +135,7 @@ struct GridwiseGemm_Wmma
...
@@ -135,7 +135,7 @@ struct GridwiseGemm_Wmma
static
constexpr
auto
MWaves
=
MPerBlock
/
(
MRepeat
*
MPerWmma
);
static
constexpr
auto
MWaves
=
MPerBlock
/
(
MRepeat
*
MPerWmma
);
static
constexpr
auto
NWaves
=
NPerBlock
/
(
NRepeat
*
NPerWmma
);
static
constexpr
auto
NWaves
=
NPerBlock
/
(
NRepeat
*
NPerWmma
);
static
constexpr
auto
WmmaK
=
16
;
static
constexpr
auto
WmmaK
=
K1
==
16
?
32
:
16
;
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment