gaoqiong / composable_kernel / Commits

Commit 5d5891b0, authored Dec 15, 2022 by aska-0096
clang format
Parent: cfb397b1
Showing 7 changed files with 144 additions and 118 deletions (+144, -118)
example/01_gemm/gemm_wmma_fp16.cpp                                  +0   -1
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp       +84  -76
include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp    +13  -8
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp         +14  -14
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp                  +24  -9
include/ck/utility/amd_inline_asm.hpp                               +2   -6
include/ck/utility/amd_wmma.hpp                                     +7   -4
example/01_gemm/gemm_wmma_fp16.cpp
...
...
@@ -30,7 +30,6 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
    < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
      AElementOp, BElementOp, CElementOp, GemmDefault,
      256, 128, 128, 8, 8, 16, 16, 4, 2,
      S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
      S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
      1, 1, S<1, 32, 1, 8>, 8>;
// clang-format on

using ReferenceGemmInstance = ck::tensor_operation::host::
    ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
...
...
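As a quick sanity check on the numeric template arguments in this instance: assuming the conventional ordering (BlockSize = 256, MPerBlock = NPerBlock = 128, MPerWMMA = NPerWMMA = 16, MRepeat = 4, NRepeat = 2; this name-to-position mapping is an assumption for illustration, not taken from the header), the wave decomposition used later in blockwise_gemm_wmma.hpp comes out consistent with the block size. A host-side sketch:

    // Hypothetical names for the positional arguments of DeviceGemmInstance above;
    // the mapping of names to positions is an assumption made for illustration.
    #include <cstdio>

    int main()
    {
        constexpr int BlockSize = 256;
        constexpr int MPerBlock = 128, NPerBlock = 128;
        constexpr int MPerWMMA  = 16,  NPerWMMA  = 16;
        constexpr int MRepeat   = 4,   NRepeat   = 2;
        constexpr int WaveSize  = 32; // assumed wavefront size for RDNA3 WMMA

        // Same formulas as MWaves/NWaves in blockwise_gemm_wmma.hpp below.
        constexpr int MWaves = MPerBlock / (MRepeat * MPerWMMA); // = 2
        constexpr int NWaves = NPerBlock / (NRepeat * NPerWMMA); // = 4

        static_assert(MWaves * NWaves * WaveSize == BlockSize,
                      "waves per block should match the thread block size");

        std::printf("MWaves=%d NWaves=%d threads=%d\n", MWaves, NWaves, MWaves * NWaves * WaveSize);
        return 0;
    }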
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
...
...
@@ -52,7 +52,8 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
    static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
    static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2);

    static constexpr auto wmma_gemm =
        WmmaGemm<FloatA, FloatB, FloatAcc, MPerWMMA, NPerWMMA, KPack>{};

    static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA);
    static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA);
...
...
@@ -279,7 +280,6 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
    BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()};
};

// Block-wise level pipeline designed for inline asm
template <index_t BlockSize,
          typename FloatA,
...
...
@@ -321,7 +321,8 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
    static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
    static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2);

    static constexpr auto wmma_gemm =
        WmmaGemm<FloatA, FloatB, FloatAcc, MPerWMMA, NPerWMMA, KPack>{};

    static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA);
    static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA);
...
...
@@ -512,7 +513,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
            constexpr auto RepeatDiff = MRepeat - NRepeat;

            // Read all MRepeat, NRepeat
            static_for<0, NRepeat, 1>{}([&](auto iN) {
                b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
                                   make_tuple(I0, Number<iN>{}, I0, I0, I0),
                                   b_block_buf,
...
...
@@ -521,7 +522,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
                                   b_thread_buf);
            });

            static_for<0, MRepeat, 1>{}([&](auto iM) {
                a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
                                   make_tuple(I0, Number<iM>{}, I0, I0, I0),
                                   a_block_buf,
...
...
@@ -531,24 +532,24 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
            });

            // Stage 1: Cut the Repeat rectangle to a square, assume MRepeat > NRepeat
            static_for<0, RepeatDiff, 1>{}([&](auto iCut) {
                static_for<0, NRepeat, 1>{}([&](auto iN) {
                    vector_type<FloatA, WmmaK> a_thread_vec;
                    vector_type<FloatB, WmmaK> b_thread_vec;

                    static_for<0, WmmaK, 1>{}([&](auto iK) {
                        a_thread_vec.template AsType<FloatA>()(iK) =
                            a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                make_tuple(iK / A_K1, iCut, 0, 0, iK % A_K1))>{}];
                        b_thread_vec.template AsType<FloatB>()(iK) =
                            b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                make_tuple(iK / B_K1, iN, 0, 0, iK % B_K1))>{}];
                    });

                    using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
                    using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type;

                    constexpr index_t c_offset =
                        c_thread_desc_.CalculateOffset(make_tuple(iCut, iN, 0));

                    s_nop();
                    wmma_gemm.template Run(
                        a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}),
...
...
@@ -556,10 +557,11 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
                        c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                    s_nop();
                });

                if constexpr(KPerBlock > WmmaK)
                {
                    // Read the next inner-loop A into the slot just consumed
                    a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
                                       make_tuple(Number<WmmaK / A_K1>{}, Number<iCut>{}, I0, I0, I0),
                                       a_block_buf,
                                       a_thread_desc_,
                                       make_tuple(I0, Number<iCut>{}, I0, I0, I0),
...
...
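The if constexpr(KPerBlock > WmmaK) branch above is a simple consume-then-refill scheme: once the A fragment of the current WmmaK slice has been fed to the WMMA instruction, the same register slot is reloaded with the data for the next slice so the copy overlaps the remaining math. A host-side sketch of that schedule, with purely illustrative sizes and names:

    #include <cstdio>

    // Host-side sketch of the consume-then-refill pattern: compute on k-slice k,
    // then refill the same registers with slice k + WmmaK (if one remains).
    int main()
    {
        constexpr int KPerBlock = 64;
        constexpr int WmmaK     = 16;

        for(int k = 0; k < KPerBlock; k += WmmaK)
        {
            std::printf("compute with fragment of k-slice [%d, %d)\n", k, k + WmmaK);
            if(k + WmmaK < KPerBlock) // mirrors `if constexpr(KPerBlock > WmmaK)` in the loop body
                std::printf("  refill the same registers with k-slice [%d, %d)\n",
                            k + WmmaK, k + 2 * WmmaK);
        }
        return 0;
    }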
@@ -567,27 +569,27 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
                }
            });

            static_for<WmmaK, KPerBlock, WmmaK>{}([&](auto iWmmaK) {
                // Stage 2: Run FIFO-fashion loop over the square
                static_for<0, NRepeat, 1>{}([&](auto WmmaInnerloop) {
                    // Row repetition
                    static_for<WmmaInnerloop, NRepeat, 1>{}([&](auto iN) {
                        vector_type<FloatA, WmmaK> a_thread_vec;
                        vector_type<FloatB, WmmaK> b_thread_vec;

                        static_for<0, WmmaK, 1>{}([&](auto iK) {
                            a_thread_vec.template AsType<FloatA>()(iK) =
                                a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
                                    iK / A_K1, WmmaInnerloop + RepeatDiff, 0, 0, iK % A_K1))>{}];
                            b_thread_vec.template AsType<FloatB>()(iK) =
                                b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                    make_tuple(iK / B_K1, iN, 0, 0, iK % B_K1))>{}];
                        });

                        using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
                        using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type;

                        constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
                            make_tuple(WmmaInnerloop + RepeatDiff, iN, 0));

                        s_nop();
                        wmma_gemm.template Run(
                            a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}),
...
...
@@ -597,25 +599,27 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
                    });

                    // Read the next inner-loop A into the slot just consumed
                    a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
                                       make_tuple(Number<iWmmaK / A_K1>{},
                                                  Number<WmmaInnerloop + RepeatDiff>{},
                                                  I0,
                                                  I0,
                                                  I0),
                                       a_block_buf,
                                       a_thread_desc_,
                                       make_tuple(I0, Number<WmmaInnerloop + RepeatDiff>{}, I0, I0, I0),
                                       a_thread_buf);

                    // Column repetition
                    static_for<WmmaInnerloop + 1 + RepeatDiff, MRepeat, 1>{}([&](auto iM) {
                        vector_type<FloatA, WmmaK> a_thread_vec;
                        vector_type<FloatB, WmmaK> b_thread_vec;

                        static_for<0, WmmaK, 1>{}([&](auto iK) {
                            a_thread_vec.template AsType<FloatA>()(iK) =
                                a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                    make_tuple(iK / A_K1, iM, 0, 0, iK % A_K1))>{}];
                            b_thread_vec.template AsType<FloatB>()(iK) =
                                b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                    make_tuple(iK / B_K1, WmmaInnerloop, 0, 0, iK % B_K1))>{}];
                        });

                        using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
                        using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type;
...
...
@@ -630,8 +634,9 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
                        s_nop();
                    });

                    // Read the next inner-loop B into the slot just consumed
                    b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
                                       make_tuple(Number<iWmmaK / B_K1>{}, Number<WmmaInnerloop>{}, I0, I0, I0),
                                       b_block_buf,
                                       b_thread_desc_,
                                       make_tuple(I0, Number<WmmaInnerloop>{}, I0, I0, I0),
...
...
@@ -639,23 +644,24 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
            });

            // Stage 1: Cut the Repeat rectangle to a square, assume MRepeat > NRepeat
            static_for<0, RepeatDiff, 1>{}([&](auto iCut) {
                static_for<0, NRepeat, 1>{}([&](auto iN) {
                    vector_type<FloatA, WmmaK> a_thread_vec;
                    vector_type<FloatB, WmmaK> b_thread_vec;

                    static_for<0, WmmaK, 1>{}([&](auto iK) {
                        a_thread_vec.template AsType<FloatA>()(iK) =
                            a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                make_tuple(iK / A_K1, iCut, 0, 0, iK % A_K1))>{}];
                        b_thread_vec.template AsType<FloatB>()(iK) =
                            b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                make_tuple(iK / B_K1, iN, 0, 0, iK % B_K1))>{}];
                    });

                    using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
                    using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type;

                    constexpr index_t c_offset =
                        c_thread_desc_.CalculateOffset(make_tuple(iCut, iN, 0));

                    s_nop();
                    wmma_gemm.template Run(
                        a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}),
...
...
@@ -663,9 +669,11 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
                        c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                    s_nop();
                });

                if constexpr(KPerBlock > WmmaK)
                {
                    a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
                                       make_tuple(Number<(iWmmaK + WmmaK) / A_K1>{},
                                                  Number<iCut>{},
                                                  I0,
                                                  I0,
                                                  I0),
                                       a_block_buf,
                                       a_thread_desc_,
                                       make_tuple(I0, Number<iCut>{}, I0, I0, I0),
...
...
@@ -675,25 +683,25 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
                });

                // Stage 2: Run FIFO-fashion loop over the square
                static_for<0, NRepeat, 1>{}([&](auto WmmaInnerloop) {
                    // Row repetition
                    static_for<WmmaInnerloop, NRepeat, 1>{}([&](auto iN) {
                        vector_type<FloatA, WmmaK> a_thread_vec;
                        vector_type<FloatB, WmmaK> b_thread_vec;

                        static_for<0, WmmaK, 1>{}([&](auto iK) {
                            a_thread_vec.template AsType<FloatA>()(iK) =
                                a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
                                    iK / A_K1, WmmaInnerloop + RepeatDiff, 0, 0, iK % A_K1))>{}];
                            b_thread_vec.template AsType<FloatB>()(iK) =
                                b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                    make_tuple(iK / B_K1, iN, 0, 0, iK % B_K1))>{}];
                        });

                        using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
                        using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type;

                        constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
                            make_tuple(WmmaInnerloop + RepeatDiff, iN, 0));

                        s_nop();
                        wmma_gemm.template Run(
                            a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}),
...
...
@@ -703,17 +711,17 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
                    });

                    // Column repetition
                    static_for<WmmaInnerloop + 1 + RepeatDiff, MRepeat, 1>{}([&](auto iM) {
                        vector_type<FloatA, WmmaK> a_thread_vec;
                        vector_type<FloatB, WmmaK> b_thread_vec;

                        static_for<0, WmmaK, 1>{}([&](auto iK) {
                            a_thread_vec.template AsType<FloatA>()(iK) =
                                a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                    make_tuple(iK / A_K1, iM, 0, 0, iK % A_K1))>{}];
                            b_thread_vec.template AsType<FloatB>()(iK) =
                                b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                    make_tuple(iK / B_K1, WmmaInnerloop, 0, 0, iK % B_K1))>{}];
                        });

                        using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
                        using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type;
...
...
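Taken together, the FIFO pipeline's loops visit the MRepeat x NRepeat grid of C sub-tiles in two stages: stage 1 cuts the rectangle down to an NRepeat x NRepeat square, and stage 2 then alternates a row repetition and a column repetition per inner-loop index. A host-side sketch of the resulting visit order, assuming MRepeat >= NRepeat as the code itself does (the loop bounds mirror the hunks above):

    #include <cstdio>

    // Host-side sketch of the C-tile visit order in the FIFO pipeline:
    // stage 1 rectangle-to-square cut, then stage 2 row/column repetition.
    int main()
    {
        constexpr int MRepeat = 4, NRepeat = 2;
        constexpr int RepeatDiff = MRepeat - NRepeat;

        // Stage 1: cut the MRepeat x NRepeat rectangle down to a square.
        for(int iCut = 0; iCut < RepeatDiff; ++iCut)
            for(int iN = 0; iN < NRepeat; ++iN)
                std::printf("stage1 C tile (m=%d, n=%d)\n", iCut, iN);

        // Stage 2: FIFO-style loop over the remaining NRepeat x NRepeat square.
        for(int inner = 0; inner < NRepeat; ++inner)
        {
            for(int iN = inner; iN < NRepeat; ++iN) // row repetition
                std::printf("stage2 C tile (m=%d, n=%d)\n", inner + RepeatDiff, iN);
            for(int iM = inner + 1 + RepeatDiff; iM < MRepeat; ++iM) // column repetition
                std::printf("stage2 C tile (m=%d, n=%d)\n", iM, inner);
        }
        return 0; // together the two stages touch every one of the MRepeat * NRepeat tiles once
    }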
include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp
...
...
@@ -276,8 +276,10 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
              b_element_op_{b_element_op},
              c_element_op_{c_element_op}
        {
            a_grid_desc_k0_m_k1_ =
                DeviceGemmWmma_CShuffle::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
            b_grid_desc_k0_n_k1_ =
                DeviceGemmWmma_CShuffle::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
            c_grid_desc_m_n_ = DeviceGemmWmma_CShuffle::MakeCGridDescriptor_M_N(M, N, StrideC);
            block_2_ctile_map_ =
...
...
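MakeAGridDescriptor_K0_M_K1 and MakeBGridDescriptor_K0_N_K1 reinterpret the 2-D operands as K0 x (M or N) x K1 volumes with K = K0 * K1. Below is a minimal sketch of the index arithmetic behind such a view for a row-major A; the layout choice and the helper name are assumptions for illustration, and the real descriptors also carry padding and vectorization information:

    #include <cassert>
    #include <cstdio>

    // Sketch: view a row-major M x K matrix as K0 x M x K1 with K = K0 * K1.
    int offset_k0_m_k1(int k0, int m, int k1, int StrideA, int K1)
    {
        const int k = k0 * K1 + k1; // recombine the split K index
        return m * StrideA + k;     // row-major A: element (m, k)
    }

    int main()
    {
        constexpr int M = 4, K = 16, K1 = 8, K0 = K / K1, StrideA = K;
        for(int k0 = 0; k0 < K0; ++k0)
            for(int m = 0; m < M; ++m)
                for(int k1 = 0; k1 < K1; ++k1)
                    assert(offset_k0_m_k1(k0, m, k1, StrideA, K1) == m * K + k0 * K1 + k1);
        std::printf("K0=%d, K1=%d: the 3-D view covers all %d elements exactly once\n", K0, K1, M * K);
        return 0;
    }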
@@ -289,7 +291,8 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
                                                     block_2_ctile_map_))
            {
                c_grid_desc_mblock_mperblock_nblock_nperblock =
                    GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                        c_grid_desc_m_n_);
            }
        }
...
...
@@ -359,7 +362,8 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
                CDataType,
                remove_reference_t<DeviceGemmWmma_CShuffle::AGridDesc_K0_M_K1>,
                remove_reference_t<DeviceGemmWmma_CShuffle::BGridDesc_K0_N_K1>,
                remove_reference_t<
                    typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
                AElementwiseOperation,
                BElementwiseOperation,
                CElementwiseOperation,
...
...
@@ -391,7 +395,8 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
                CDataType,
                remove_reference_t<DeviceGemmWmma_CShuffle::AGridDesc_K0_M_K1>,
                remove_reference_t<DeviceGemmWmma_CShuffle::BGridDesc_K0_N_K1>,
                remove_reference_t<
                    typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
                AElementwiseOperation,
                BElementwiseOperation,
                CElementwiseOperation,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp
...
...
@@ -218,7 +218,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
        constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
            b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize(), max_lds_align);

        return (a_block_space_size_aligned * sizeof(FloatA) +
                b_block_space_size_aligned * sizeof(FloatB));
    }

    // block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
...
...
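The LDS budget above is computed by rounding each block tile's element count up to max_lds_align before summing the A and B byte counts. A standalone sketch of that arithmetic with made-up tile sizes:

    #include <cstdio>

    // Round `n` up to the nearest multiple of `align`: the role played by
    // math::integer_least_multiple in the hunk above.
    constexpr int integer_least_multiple(int n, int align) { return ((n + align - 1) / align) * align; }

    int main()
    {
        // Illustrative numbers only; the real sizes come from the block descriptors.
        constexpr int a_block_elems = 128 * 8 * 8; // e.g. MPerBlock * K0PerBlock * K1 (assumed)
        constexpr int b_block_elems = 128 * 8 * 8;
        constexpr int max_lds_align = 8;           // alignment, in elements
        constexpr int sizeof_half   = 2;           // FloatA = FloatB = fp16

        constexpr int a_aligned = integer_least_multiple(a_block_elems, max_lds_align);
        constexpr int b_aligned = integer_least_multiple(b_block_elems, max_lds_align);

        std::printf("LDS bytes = %d\n", a_aligned * sizeof_half + b_aligned * sizeof_half);
        return 0;
    }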
@@ -305,8 +306,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
        remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;

    template <bool HasMainKBlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
    __device__ static void Run(const FloatA* __restrict__ p_a_grid,
                               const FloatB* __restrict__ p_b_grid,
                               FloatC* __restrict__ p_c_grid,
                               void* __restrict__ p_shared,
...
...
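On the {M01, N01}-controlled block-to-tile map referenced above: one common grouping scheme walks M01 consecutive m-tiles before advancing in n, so that blocks with nearby ids touch nearby C tiles and reuse cached A/B data. The sketch below only illustrates that idea; it is not the exact mapping built by MakeDefaultBlock2CTileMap:

    #include <cstdio>

    // Illustrative block_id -> (m0, n0) grouping; M01 consecutive m-tiles share one n-sweep.
    int main()
    {
        constexpr int MBlocks = 8, NBlocks = 8, M01 = 4;

        for(int block_id = 0; block_id < MBlocks * NBlocks; ++block_id)
        {
            const int group    = block_id / (M01 * NBlocks); // which band of m-tiles
            const int in_group = block_id % (M01 * NBlocks);
            const int m0       = group * M01 + in_group % M01;
            const int n0       = in_group / M01;
            std::printf("block %2d -> tile (m0=%d, n0=%d)\n", block_id, m0, n0);
        }
        return 0;
    }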
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
...
...
@@ -283,10 +283,18 @@ struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8,
    }
};

template <typename src_type_a,
          typename src_type_b,
          typename dst_type,
          index_t MPerWmma,
          index_t NPerWmma>
struct WmmaSelector
{
    template <typename src_type_a_,
              typename src_type_b_,
              typename dst_type_,
              index_t MPerWmma_,
              index_t NPerWmma_>
    static constexpr auto GetWmma();

    template <>
...
...
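WmmaSelector resolves the WMMA instruction at compile time from the (source type, destination type) couple via specializations of GetWmma. A minimal standalone sketch of the same dispatch pattern; the tag types and enum values here are illustrative, not CK's actual tables:

    #include <cstdio>

    // Illustrative stand-ins; the real selector keys on CK's element types and tile sizes.
    enum class WmmaInstr { f32_16x16x16_f16, i32_16x16x16_iu8, none };

    struct half_tag {};
    struct int8_tag {};

    template <typename SrcT, typename DstT>
    struct SelectWmma { static constexpr WmmaInstr value = WmmaInstr::none; };

    template <>
    struct SelectWmma<half_tag, float> { static constexpr WmmaInstr value = WmmaInstr::f32_16x16x16_f16; };

    template <>
    struct SelectWmma<int8_tag, int> { static constexpr WmmaInstr value = WmmaInstr::i32_16x16x16_iu8; };

    int main()
    {
        static_assert(SelectWmma<half_tag, float>::value == WmmaInstr::f32_16x16x16_f16,
                      "fp16 inputs with fp32 accumulation map to the f32 instruction");
        static_assert(SelectWmma<double, double>::value == WmmaInstr::none,
                      "unsupported couples fall back to 'none'");
        std::printf("instruction selection resolves at compile time\n");
        return 0;
    }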
@@ -424,13 +432,19 @@ struct WmmaGemm
    __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
    {
        static_assert(
            (is_same<src_type_a, half_t>::value && is_same<src_type_b, half_t>::value &&
             is_same<dst_type, float>::value) ||
                (is_same<src_type_a, bhalf_t>::value && is_same<src_type_b, bhalf_t>::value &&
                 is_same<dst_type, float>::value) ||
                (is_same<src_type_a, half_t>::value && is_same<src_type_b, half_t>::value &&
                 is_same<dst_type, half_t>::value) ||
                (is_same<src_type_a, bhalf_t>::value && is_same<src_type_b, bhalf_t>::value &&
                 is_same<dst_type, bhalf_t>::value) ||
                (is_same<src_type_a, int8_t>::value && is_same<src_type_b, int8_t>::value &&
                 is_same<dst_type, int32_t>::value)
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
                || (is_same<src_type_a, int4_t>::value && is_same<src_type_b, int4_t>::value &&
                    is_same<dst_type, int32_t>::value)
#endif
                ,
            "base type couple must be (half, float), (bhalf, float), (half, half), (bhalf, bhalf), "
...
...
@@ -479,7 +493,8 @@ struct WmmaGemm
        return TransposeC ? CIndex{n_offset, m_offset} : CIndex{m_offset, n_offset};
    }

    static constexpr auto wmma =
        WmmaSelector<src_type_a, src_type_b, dst_type, MPerWmma, NPerWmma>{};
    static constexpr auto wmma_instr = wmma.selected_wmma;

    __host__ __device__ static constexpr auto
...
...
include/ck/utility/amd_inline_asm.hpp
...
...
@@ -356,13 +356,9 @@ __device__ void amd_assembly_outer_product_1x4(int8x16_t a,
}

// Ranged input operand
__device__ void amd_assembly_wmma_f32_16x16x16_f16_w32(half16_t a, half16_t b, float8_t& c)
{
    asm volatile("v_wmma_f32_16x16x16_f16 %0, %1, %2, %0" : "=v"(c) : "v"(a), "v"(b), "0"(c));
}

} // namespace ck
...
...
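For readers unfamiliar with the constraint syntax above: the accumulator c is listed both as the output ("=v") and, through the matching-digit constraint "0", as an input tied to the same registers, which is what lets v_wmma_f32_16x16x16_f16 accumulate in place. An annotated restatement of the wrapper (device code; the comments are the only addition):

    __device__ void amd_assembly_wmma_f32_16x16x16_f16_w32(half16_t a, half16_t b, float8_t& c)
    {
        asm volatile("v_wmma_f32_16x16x16_f16 %0, %1, %2, %0" // D = A * B + D
                     : "=v"(c)  // %0: destination VGPRs holding the accumulator
                     : "v"(a),  // %1: A fragment
                       "v"(b),  // %2: B fragment
                       "0"(c)); // tie the input accumulator to operand %0 so the same
                                // registers are read and written in place
    }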
include/ck/utility/amd_wmma.hpp
...
...
@@ -21,10 +21,13 @@ struct intrin_wmma_f32_16x16x16_f16_w32<16, 16>
    template <class FloatC>
    __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
    {
        // * Inline assembly is needed to eliminate the duplicated data loads; the compiler
        //   won't delete them for you.
        amd_assembly_wmma_f32_16x16x16_f16_w32(
            reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{}));
        // reg_c.template AsType<float8_t>()(Number<0>{}) =
        //     __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(
        //         reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
    }
};
...
...