gaoqiong / composable_kernel · Commits · e43df26a

Commit e43df26a, authored Dec 13, 2022 by aska-0096

temp save, reproduce the v_bfi_b32 issue

Parent: 9739ede0

Showing 5 changed files with 146 additions and 92 deletions.
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp   +97 -81
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp      +3  -3
include/ck/utility/amd_inline_asm.hpp                            +1  -3
include/ck/utility/amd_wmma.hpp                                  +5  -2
test/wmma_op/wmma_op_util.hpp                                   +40  -3
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp

@@ -280,24 +280,24 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
                                const BBlockBuffer& b_block_buf,
                                CThreadBuffer& c_thread_buf) const
     {
-        // auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
-        //     a_thread_desc_.GetElementSpaceSize());
-        // auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
-        //     b_thread_desc_.GetElementSpaceSize());
-        StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
-                                  FloatAB,
-                                  MRepeat,
-                                  WmmaK,
-                                  true>
-            a_thread_buf;
-        StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
-                                  FloatAB,
-                                  NRepeat,
-                                  WmmaK,
-                                  true>
-            b_thread_buf;
+        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
+            a_thread_desc_.GetElementSpaceSize());
+        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
+            b_thread_desc_.GetElementSpaceSize());
+        // StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
+        //                           FloatAB,
+        //                           MRepeat,
+        //                           WmmaK,
+        //                           true>
+        //     a_thread_buf;
+        // StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
+        //                           FloatAB,
+        //                           NRepeat,
+        //                           WmmaK,
+        //                           true>
+        //     b_thread_buf;

         static_for<0, KPerBlock / WmmaK, 1>{}([&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ...
             static_for<0, MRepeat, 1>{}([&](auto m0) {
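The swap above is between two styles of per-thread VGPR buffer, and the rest of the diff follows from it: a flat buffer addressed through a descriptor offset versus a tuple of per-repeat vectors addressed by repeat index. A standalone analogy in plain C++ (illustrative only; these are not CK's types):

    #include <array>
    #include <cstdio>

    constexpr int MRepeat = 2;
    constexpr int WmmaK   = 16;

    int main()
    {
        // make_static_buffer style: one flat array of MRepeat * WmmaK elements;
        // the element for (m0, i) lives at a computed offset, which is what
        // a_thread_desc_.CalculateOffset provides in the real code.
        std::array<float, MRepeat * WmmaK> flat{};
        auto flat_at = [&](int m0, int i) -> float& { return flat[m0 * WmmaK + i]; };

        // StaticBufferTupleOfVector style: MRepeat separate WmmaK-wide vectors;
        // a whole vector can be handed to the WMMA call by repeat index, which is
        // what GetVectorTypeReference(Number<m0 * WmmaK>{}) does in the real code.
        std::array<std::array<float, WmmaK>, MRepeat> chunks{};

        flat_at(1, 3) = 42.0f;
        chunks[1][3]  = 42.0f;
        std::printf("%f %f\n", flat_at(1, 3), chunks[1][3]);
        return 0;
    }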
@@ -306,8 +306,8 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
                     make_tuple(Number<k * WmmaK / A_K1>{}, m0, I0, I0, I0),
                     a_block_buf,
                     a_thread_desc_,
-                    make_tuple(I0, I0, I0, I0, I0),
-                    a_thread_buf.GetVectorTypeReference(Number<m0 * WmmaK>{})
-                        .template AsType<FloatAB>());
+                    make_tuple(I0, m0, I0, I0, I0),
+                    a_thread_buf);

                 static_for<0, NRepeat, 1>{}([&](auto n0) {
                     // read B
@@ -315,28 +315,28 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
                     make_tuple(Number<k * WmmaK / B_K1>{}, n0, I0, I0, I0),
                     b_block_buf,
                     b_thread_desc_,
-                    make_tuple(I0, I0, I0, I0, I0),
-                    b_thread_buf.GetVectorTypeReference(Number<n0 * WmmaK>{})
-                        .template AsType<FloatAB>());
+                    make_tuple(I0, n0, I0, I0, I0),
+                    b_thread_buf);

-                    // vector_type<FloatAB, WmmaK> a_thread_vec;
-                    // vector_type<FloatAB, WmmaK> b_thread_vec;
-                    // static_for<0, WmmaK, 1>{}([&](auto i) {
-                    //     a_thread_vec.template AsType<FloatAB>()(i) =
-                    //         a_thread_buf[Number<a_thread_desc_.CalculateOffset(
-                    //             make_tuple(i / A_K1, m0, 0, 0, i % A_K1))>{}];
-                    //     b_thread_vec.template AsType<FloatAB>()(i) =
-                    //         b_thread_buf[Number<b_thread_desc_.CalculateOffset(
-                    //             make_tuple(i / B_K1, n0, 0, 0, i % B_K1))>{}];
-                    // });
-                    // using wmma_input_type = typename vector_type<FloatAB, WmmaK>::type;
+                    vector_type<FloatAB, WmmaK> a_thread_vec;
+                    vector_type<FloatAB, WmmaK> b_thread_vec;
+                    static_for<0, WmmaK, 1>{}([&](auto i) {
+                        a_thread_vec.template AsType<FloatAB>()(i) =
+                            a_thread_buf[Number<a_thread_desc_.CalculateOffset(
+                                make_tuple(i / A_K1, m0, 0, 0, i % A_K1))>{}];
+                        b_thread_vec.template AsType<FloatAB>()(i) =
+                            b_thread_buf[Number<b_thread_desc_.CalculateOffset(
+                                make_tuple(i / B_K1, n0, 0, 0, i % B_K1))>{}];
+                    });
+                    using wmma_input_type = typename vector_type<FloatAB, WmmaK>::type;

                     constexpr index_t c_offset =
                         c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                     wmma_gemm.template Run(
-                        a_thread_buf.GetVectorTypeReference(Number<m0 * WmmaK>{}),
-                        b_thread_buf.GetVectorTypeReference(Number<n0 * WmmaK>{}),
+                        a_thread_vec.template AsType<wmma_input_type>()(Number<0>{}),
+                        b_thread_vec.template AsType<wmma_input_type>()(Number<0>{}),
                         c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                 });
             });
         });
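One detail worth pinning down in the re-pack loop above: the lane index i is split into (i / A_K1, ..., i % A_K1), i.e. the (K0, K1) coordinates of the packed [WmmaK / A_K1][MRepeat][1][1][A_K1] thread descriptor, so consecutive i values walk K1 first and hop to the next K0 row once A_K1 elements are consumed. A standalone check of the resulting offsets (illustrative values):

    #include <cstdio>

    int main()
    {
        constexpr int WmmaK = 16, A_K1 = 8, MRepeat = 2;
        const int m0 = 1; // repeat slice being gathered
        for(int i = 0; i < WmmaK; ++i)
        {
            const int k0 = i / A_K1; // outer K coordinate
            const int k1 = i % A_K1; // inner K coordinate
            // packed-descriptor offset: (((k0 * MRepeat + m0) * 1) * 1) * A_K1 + k1
            const int offset = (k0 * MRepeat + m0) * A_K1 + k1;
            std::printf("i=%2d -> (k0=%d, m0=%d, k1=%d) -> offset %2d\n", i, k0, m0, k1, offset);
        }
        return 0;
    }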
@@ -346,11 +346,11 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
     protected:
     // A[M0, M1, M2, K0 = WmmaK]
     static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(
-        make_tuple(Number<WmmaK / A_K1>{}, I1, I1, I1, Number<A_K1>{}));
+        make_tuple(Number<WmmaK / A_K1>{}, Number<MRepeat>{}, I1, I1, Number<A_K1>{}));

     // B[N0, N1, N2, K0 = WmmaK]
     static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(
-        make_tuple(Number<WmmaK / B_K1>{}, I1, I1, I1, Number<B_K1>{}));
+        make_tuple(Number<WmmaK / B_K1>{}, Number<NRepeat>{}, I1, I1, Number<B_K1>{}));

     // C[M, N, NumRegWMMA]
     static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
@@ -659,7 +659,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_MNKloop
                 make_tuple(I0, m0, I0, I0, I0),
                 a_block_buf,
                 a_thread_desc_,
-                make_tuple(I0, I0, I0, I0, I0),
+                make_tuple(I0, Number<m0>{}, I0, I0, I0),
                 a_thread_buf);

             static_for<0, NRepeat, 1>{}([&](auto n0) {

@@ -668,7 +668,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_MNKloop
                 make_tuple(I0, n0, I0, I0, I0),
                 b_block_buf,
                 b_thread_desc_,
-                make_tuple(I0, I0, I0, I0, I0),
+                make_tuple(I0, Number<n0>{}, I0, I0, I0),
                 b_thread_buf);

             static_for<0, KPerBlock / WmmaK, 1>{}([&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ...
@@ -678,10 +678,10 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_MNKloop
                 static_for<0, WmmaK, 1>{}([&](auto i) {
                     a_thread_vec.template AsType<FloatAB>()(i) =
                         a_thread_buf[Number<a_thread_desc_.CalculateOffset(
-                            make_tuple((k * WmmaK + i) / A_K1, 0, 0, 0, (k * WmmaK + i) % A_K1))>{}];
+                            make_tuple((k * WmmaK + i) / A_K1, m0, 0, 0, (k * WmmaK + i) % A_K1))>{}];
                     b_thread_vec.template AsType<FloatAB>()(i) =
                         b_thread_buf[Number<b_thread_desc_.CalculateOffset(
-                            make_tuple((k * WmmaK + i) / B_K1, 0, 0, 0, (k * WmmaK + i) % B_K1))>{}];
+                            make_tuple((k * WmmaK + i) / B_K1, n0, 0, 0, (k * WmmaK + i) % B_K1))>{}];
                 });

                 using wmma_input_type = typename vector_type<FloatAB, WmmaK>::type;
@@ -701,11 +701,11 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_MNKloop
     protected:
     // A[M0, M1, M2, K0 = WmmaK]
     static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(
-        make_tuple(Number<KPerBlock / A_K1>{}, I1, I1, I1, Number<A_K1>{}));
+        make_tuple(Number<KPerBlock / A_K1>{}, Number<MRepeat>{}, I1, I1, Number<A_K1>{}));

     // B[N0, N1, N2, K0 = WmmaK]
     static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(
-        make_tuple(Number<KPerBlock / B_K1>{}, I1, I1, I1, Number<B_K1>{}));
+        make_tuple(Number<KPerBlock / B_K1>{}, Number<NRepeat>{}, I1, I1, Number<B_K1>{}));

     // C[M, N, NumRegWMMA]
     static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
@@ -716,7 +716,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_MNKloop
                                          decltype(a_block_desc_k0_m0_m1_m2_k1),
                                          decltype(a_thread_desc_),
                                          Sequence<KPerBlock / A_K1, 1, 1, 1, A_K1>,
-                                         Sequence<3, 0, 1, 2, 4>,
+                                         Sequence<0, 1, 2, 3, 4>,
                                          4,
                                          A_K1,
                                          A_K1>;

@@ -726,7 +726,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_MNKloop
                                          decltype(b_block_desc_k0_n0_n1_n2_k1),
                                          decltype(b_thread_desc_),
                                          Sequence<KPerBlock / B_K1, 1, 1, 1, B_K1>,
-                                         Sequence<3, 0, 1, 2, 4>,
+                                         Sequence<0, 1, 2, 3, 4>,
                                          4,
                                          B_K1,
                                          B_K1>;
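The only change in these two hunks is the dim-access-order Sequence handed to the threadwise copy type: Sequence<3, 0, 1, 2, 4> names dim 3 as the outermost loop of the traversal, Sequence<0, 1, 2, 3, 4> the natural order. With this slice's lengths (1 in dims 1 to 3) both orders happen to visit elements in the same sequence, which fits a debugging commit probing codegen rather than a functional change. A standalone sketch of what such an order permutes (illustrative, not CK's transfer implementation; lengths chosen so the difference is visible):

    #include <array>
    #include <cstdio>

    // Walk a 5-D slice like an odometer whose digit order is 'order'
    // (order[0] outermost, order[4] innermost) and print each multi-index.
    void visit(const std::array<int, 5>& len, const std::array<int, 5>& order)
    {
        std::array<int, 5> idx{};
        int count = 1;
        for(int l : len)
            count *= l;
        for(int n = 0; n < count; ++n)
        {
            std::printf("(%d,%d,%d,%d,%d) ", idx[0], idx[1], idx[2], idx[3], idx[4]);
            for(int d = 4; d >= 0; --d)
            {
                const int dim = order[d];
                if(++idx[dim] < len[dim])
                    break;
                idx[dim] = 0;
            }
        }
        std::puts("");
    }

    int main()
    {
        const std::array<int, 5> len{2, 1, 1, 2, 2}; // tiny illustrative slice
        visit(len, {3, 0, 1, 2, 4}); // old order: dim 3 outermost
        visit(len, {0, 1, 2, 3, 4}); // new order: natural
        return 0;
    }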
@@ -1009,9 +1009,17 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
             b_thread_desc_.GetElementSpaceSize());

         constexpr auto RepeatDiff = MRepeat - NRepeat;

         static_for<0, KPerBlock, WmmaK>{}([&](auto iWmmaK){
+            static_for<0, NRepeat, 1>{}([&](auto iN){
+                b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
+                                   make_tuple(Number<iWmmaK / B_K1>{}, Number<iN>{}, I0, I0, I0),
+                                   b_block_buf,
+                                   b_thread_desc_,
+                                   make_tuple(I0, Number<iN>{}, I0, I0, I0),
+                                   b_thread_buf);
+            });
             // Stage 1: Cut to Repeat Retangle to Square, assume MRepeat > NRepeat
             static_for<0, RepeatDiff, 1>{}([&](auto iCut){
                 a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
@@ -1021,12 +1029,12 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
                                    make_tuple(I0, Number<iCut>{}, I0, I0, I0),
                                    a_thread_buf);
                 static_for<0, NRepeat, 1>{}([&](auto iN){
-                    b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
-                                       make_tuple(Number<iWmmaK / B_K1>{}, Number<iN>{}, I0, I0, I0),
-                                       b_block_buf,
-                                       b_thread_desc_,
-                                       make_tuple(I0, Number<iN>{}, I0, I0, I0),
-                                       b_thread_buf);
+                    // b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
+                    // make_tuple(Number<iWmmaK/B_K1>{}, Number<iN>{}, I0, I0, I0),
+                    // b_block_buf,
+                    // b_thread_desc_,
+                    // make_tuple(I0, Number<iN>{}, I0, I0, I0),
+                    // b_thread_buf);
                     vector_type<FloatAB, WmmaK> a_thread_vec;
                     vector_type<FloatAB, WmmaK> b_thread_vec;

@@ -1042,30 +1050,34 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
                     using wmma_input_type = typename vector_type<FloatAB, WmmaK>::type;
                     constexpr index_t c_offset =
                         c_thread_desc_.CalculateOffset(make_tuple(iCut, iN, 0));
+                    s_nop();
                     wmma_gemm.template Run(
                         a_thread_vec.template AsType<wmma_input_type>()(Number<0>{}),
                         b_thread_vec.template AsType<wmma_input_type>()(Number<0>{}),
                         c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                    s_nop();
                 });
             });
+            static_for<0, NRepeat, 1>{}([&](auto WmmaInnerloop){
+                a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
+                                   make_tuple(Number<iWmmaK / A_K1>{}, Number<WmmaInnerloop + RepeatDiff>{}, I0, I0, I0),
+                                   a_block_buf,
+                                   a_thread_desc_,
+                                   make_tuple(I0, Number<WmmaInnerloop + RepeatDiff>{}, I0, I0, I0),
+                                   a_thread_buf);
+            });
             // Stage 2: Run FIFO fashion loopover in Square
             static_for<0, NRepeat, 1>{}([&](auto WmmaInnerloop){
                 // Row Repeatation
                 static_for<WmmaInnerloop, NRepeat, 1>{}([&](auto iN){
-                    a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
-                                       make_tuple(Number<iWmmaK / A_K1>{}, Number<WmmaInnerloop + RepeatDiff>{}, I0, I0, I0),
-                                       a_block_buf,
-                                       a_thread_desc_,
-                                       make_tuple(I0, Number<WmmaInnerloop + RepeatDiff>{}, I0, I0, I0),
-                                       a_thread_buf);
-                    b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
-                                       make_tuple(Number<iWmmaK / B_K1>{}, Number<iN>{}, I0, I0, I0),
-                                       b_block_buf,
-                                       b_thread_desc_,
-                                       make_tuple(I0, Number<iN>{}, I0, I0, I0),
-                                       b_thread_buf);
+                    // b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
+                    // make_tuple(Number<iWmmaK/B_K1>{}, Number<iN>{}, I0, I0, I0),
+                    // b_block_buf,
+                    // b_thread_desc_,
+                    // make_tuple(I0, Number<iN>{}, I0, I0, I0),
+                    // b_thread_buf);
                     vector_type<FloatAB, WmmaK> a_thread_vec;
                     vector_type<FloatAB, WmmaK> b_thread_vec;
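Each wmma_gemm.template Run(...) above is now bracketed by s_nop() calls. A minimal sketch of what such a helper presumably wraps (an assumption about CK's helper, not the committed implementation): the scalar NOP, used here to pad the instruction stream around v_wmma while reproducing the v_bfi_b32 issue:

    // Hedged sketch: CK's actual s_nop() may take a count or differ in linkage.
    // s_nop's 4-bit immediate inserts N+1 wait states; 0 is the minimum.
    __device__ inline void s_nop()
    {
        asm volatile("s_nop 0");
    }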
@@ -1081,27 +1093,29 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
                     constexpr index_t c_offset =
                         c_thread_desc_.CalculateOffset(make_tuple(WmmaInnerloop + RepeatDiff, iN, 0));
+                    s_nop();
                     wmma_gemm.template Run(
                         a_thread_vec.template AsType<wmma_input_type>()(Number<0>{}),
                         b_thread_vec.template AsType<wmma_input_type>()(Number<0>{}),
                         c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                    s_nop();
                 });
                 // WmmaInnerloop++
                 // Col Repeatation
                 static_for<WmmaInnerloop + 1 + RepeatDiff, MRepeat, 1>{}([&](auto iM){
-                    a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
-                                       make_tuple(Number<iWmmaK / A_K1>{}, Number<iM>{}, I0, I0, I0),
-                                       a_block_buf,
-                                       a_thread_desc_,
-                                       make_tuple(I0, Number<iM>{}, I0, I0, I0),
-                                       a_thread_buf);
-                    b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
-                                       make_tuple(Number<iWmmaK / B_K1>{}, Number<WmmaInnerloop>{}, I0, I0, I0),
-                                       b_block_buf,
-                                       b_thread_desc_,
-                                       make_tuple(I0, Number<WmmaInnerloop>{}, I0, I0, I0),
-                                       b_thread_buf);
+                    // a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
+                    // make_tuple(Number<iWmmaK/A_K1>{}, Number<iM>{}, I0, I0, I0),
+                    // a_block_buf,
+                    // a_thread_desc_,
+                    // make_tuple(I0, Number<iM>{}, I0, I0, I0),
+                    // a_thread_buf);
+                    // b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
+                    // make_tuple(Number<iWmmaK/B_K1>{}, Number<WmmaInnerloop>{}, I0, I0, I0),
+                    // b_block_buf,
+                    // b_thread_desc_,
+                    // make_tuple(I0, Number<WmmaInnerloop>{}, I0, I0, I0),
+                    // b_thread_buf);
                     vector_type<FloatAB, WmmaK> a_thread_vec;
                     vector_type<FloatAB, WmmaK> b_thread_vec;

@@ -1117,10 +1131,12 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
                     constexpr index_t c_offset =
                         c_thread_desc_.CalculateOffset(make_tuple(iM, WmmaInnerloop, 0));
+                    s_nop();
                     wmma_gemm.template Run(
                         a_thread_vec.template AsType<wmma_input_type>()(Number<0>{}),
                         b_thread_vec.template AsType<wmma_input_type>()(Number<0>{}),
                         c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                    s_nop();
                 });
             });
         });
@@ -1144,7 +1160,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
                                          decltype(a_block_desc_k0_m0_m1_m2_k1),
                                          decltype(a_thread_desc_),
                                          Sequence<WmmaK / A_K1, 1, 1, 1, A_K1>,
-                                         Sequence<3, 0, 1, 2, 4>,
+                                         Sequence<0, 1, 2, 3, 4>,
                                          4,
                                          A_K1,
                                          A_K1>;

@@ -1154,7 +1170,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
                                          decltype(b_block_desc_k0_n0_n1_n2_k1),
                                          decltype(b_thread_desc_),
                                          Sequence<WmmaK / B_K1, 1, 1, 1, B_K1>,
-                                         Sequence<3, 0, 1, 2, 4>,
+                                         Sequence<0, 1, 2, 3, 4>,
                                          4,
                                          B_K1,
                                          B_K1>;
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp

@@ -310,7 +310,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
         constexpr auto WmmaK = 16;
         constexpr auto KPack = math::integer_least_multiple(K1, WmmaK);

-        using BlockwiseGemm = BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle<
+        using BlockwiseGemm = BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO<
             BlockSize,
             FloatAB,
             FloatAcc,

@@ -367,7 +367,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
         constexpr auto WmmaK = 16;
         constexpr auto KPack = math::integer_least_multiple(K1, WmmaK);

-        using BlockwiseGemm = BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle<
+        using BlockwiseGemm = BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO<
             BlockSize,
             FloatAB,
             FloatAcc,

@@ -540,7 +540,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
         constexpr auto KPack = math::integer_least_multiple(K1, WmmaK);

         auto blockwise_gemm =
-            BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle<BlockSize,
+            BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO<BlockSize,
                                                                   FloatAB,
                                                                   FloatAcc,
                                                                   decltype(a_block_desc_k0perblock_mperblock_k1),
include/ck/utility/amd_inline_asm.hpp

@@ -360,9 +360,7 @@ __device__ void amd_assembly_wmma_f32_16x16x16_f16_w32(half16_t a,
                                                        half16_t b,
                                                        float8_t& c)
 {
-    asm volatile("\n \
-    v_wmma_f32_16x16x16_f16_w32 %0, %1, %2, %0 \n \
-    "
+    asm volatile("v_wmma_f32_16x16x16_f16 %0, %1, %2, %0"
                  : "=v"(c)
                  : "v"(a), "v"(b), "0"(c));
 }
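One detail the whole wrapper relies on is the matching constraint: "0"(c) binds the input value of c to operand %0, so the accumulator enters and leaves v_wmma in the same VGPRs instead of being copied. The same pattern on the host side, for illustration (x86-64 extended asm; the constraint semantics are identical in GCC and Clang):

    #include <cstdio>

    int main()
    {
        int acc = 40;
        const int addend = 2;
    #if defined(__x86_64__)
        // "0"(acc) reuses operand %0's register for the input value of acc,
        // just as "0"(c) reuses the accumulator registers in the WMMA asm.
        asm volatile("add %1, %0" : "=r"(acc) : "r"(addend), "0"(acc));
    #else
        acc += addend; // keep the example runnable off x86
    #endif
        std::printf("%d\n", acc); // prints 42
        return 0;
    }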
include/ck/utility/amd_wmma.hpp

@@ -4,6 +4,7 @@
 #ifndef CK_AMD_WMMA_HPP
 #define CK_AMD_WMMA_HPP

+#include "ck/utility/amd_inline_asm.hpp"
 #include "data_type.hpp"

 // TODO: Add arch limitation
 namespace ck {

@@ -20,8 +21,10 @@ struct intrin_wmma_f32_16x16x16_f16_w32<16, 16>
     template <class FloatC>
     __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
     {
-        reg_c.template AsType<float8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(
-            reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
+        // * Inline assembly need to elimate the duplicated data load, compiler won't help you delete them.
+        amd_assembly_wmma_f32_16x16x16_f16_w32(
+            reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{}));
+        // reg_c.template AsType<float8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(
+        //     reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
     }
 };
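For comparison, the builtin path this hunk comments out takes both f16 operand vectors and the f32 accumulator by value and returns the updated accumulator. A minimal HIP sketch of that call shape (assumes a gfx11 toolchain; the vector typedefs are illustrative, not CK's):

    #include <hip/hip_runtime.h>

    typedef _Float16 half16 __attribute__((ext_vector_type(16)));
    typedef float float8 __attribute__((ext_vector_type(8)));

    __global__ void wmma_once(const half16* a, const half16* b, float8* c)
    {
    #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
        // builtin returns the updated accumulator by value
        *c = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(*a, *b, *c);
    #endif
    }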
test/wmma_op/wmma_op_util.hpp

@@ -97,6 +97,7 @@ builtin_wmma_naive_selector<int4x16_t,
 template <typename src_t, typename dst_t, typename acc_t, index_t acc_num>
 __global__ void matmul(const src_t* a, const src_t* b, dst_t* c)
 {
+    __shared__ src_t p_shared[16 * 16 * 2];
     const int lIdx = threadIdx.x;
     // a and b fragments are stored in 8 VGPRs each, in packed format, so 16 elements each for a and
     // b a_frag will store one column of the 16x16 matrix tile b_frag will store one row of the

@@ -104,6 +105,9 @@ __global__ void matmul(const src_t* a, const src_t* b, dst_t* c)
     using src_vec = typename vector_type<src_t, 16>::type;
     src_vec a_frag = {};
     src_vec b_frag = {};
+    src_vec a_temp = {};
+    src_vec b_temp = {};

     // initialize c fragment to 0
     using acc_vec = StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, acc_t, 1, acc_num, true>;
     acc_vec c_thread_buf_;

@@ -112,19 +116,52 @@ __global__ void matmul(const src_t* a, const src_t* b, dst_t* c)
     // see https://atlvsp3.amd.com/sp3_gfx11_5_instructions.pdf page 482
     // TODO: remove this dependency in gfx12 https://ontrack-internal.amd.com/browse/DEGFXSP3-101
     const int lane = lIdx % 16;
+    const int lane_lo = lIdx / 2;
+    const int lane_hi = lIdx % 2;
+
+    for(int ele = 0; ele < 8; ++ele)
+    {
+        a_temp[ele] = a[8 * lane_hi + 16 * lane_lo + ele];
+    }
+    for(int ele = 0; ele < 8; ++ele)
+    {
+        b_temp[ele] = b[8 * lane_hi + 16 * lane_lo + ele];
+    }
+    __syncthreads();
+    for(int ele = 0; ele < 8; ++ele)
+    {
+        p_shared[8 * 16 * lane_hi + 8 * lane_lo + ele] = a_temp[ele];
+    }
+    for(int ele = 0; ele < 8; ++ele)
+    {
+        p_shared[8 * 16 * lane_hi + 8 * lane_lo + ele + 16 * 16] = b_temp[ele];
+    }
+    asm volatile("\
+    s_waitcnt lgkmcnt(0) \n \
+    s_barrier \
+    " ::);
+
     for(int ele = 0; ele < 16; ++ele)
     {
-        b_frag[ele] = b[16 * lane + ele];
+        b_frag[ele] = p_shared[(ele / 8) * 16 * 8 + 8 * lane + ele % 8 + 16 * 16];
     }
     // follow origin design
     for(int ele = 0; ele < 16; ++ele)
     {
-        a_frag[ele] = a[16 * lane + ele];
+        a_frag[ele] = p_shared[(ele / 8) * 16 * 8 + 8 * lane + ele % 8];
     }
+    asm volatile("\
+    s_waitcnt lgkmcnt(0) \n \
+    s_barrier \
+    " ::);
     // sync threads, similar to mma_sync
-    __syncthreads();
+    // __syncthreads();
     builtin_wmma_naive_selector<src_vec, acc_vec>(a_frag, b_frag, c_thread_buf_);
     __syncthreads();
     // wait for results, similar to mma_sync
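The staging this hunk introduces is value-preserving: composing the store index (8 * 16 * lane_hi + 8 * lane_lo + ele) with the load index ((ele / 8) * 16 * 8 + 8 * lane + ele % 8) reproduces the old direct read a[16 * lane + ele] (and likewise for b with the + 16 * 16 offset), so the kernel computes the same fragments, just routed through LDS. A standalone host-side check of that index algebra (illustrative):

    #include <cassert>
    #include <cstdio>
    #include <vector>

    int main()
    {
        std::vector<int> a(16 * 16), lds(16 * 16);
        for(int i = 0; i < 16 * 16; ++i)
            a[i] = i;

        // store phase: 32 "threads" write 8 elements each, as in the kernel
        for(int lIdx = 0; lIdx < 32; ++lIdx)
        {
            const int lane_lo = lIdx / 2, lane_hi = lIdx % 2;
            for(int ele = 0; ele < 8; ++ele)
                lds[8 * 16 * lane_hi + 8 * lane_lo + ele] = a[8 * lane_hi + 16 * lane_lo + ele];
        }
        // load phase: each of 16 lanes gathers its 16-element fragment
        for(int lane = 0; lane < 16; ++lane)
            for(int ele = 0; ele < 16; ++ele)
                assert(lds[(ele / 8) * 16 * 8 + 8 * lane + ele % 8] == a[16 * lane + ele]);

        std::puts("LDS staging reproduces the direct gather");
        return 0;
    }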