Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
9adf2e60
Commit
9adf2e60
authored
Nov 30, 2022
by
aska-0096
Browse files
runtime bug, cannot find symbol
parent
b3cc22a3
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
201 additions
and
215 deletions
+201
-215
example/01_gemm/gemm_wmma_fp16.cpp
example/01_gemm/gemm_wmma_fp16.cpp
+1
-1
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
...ude/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
+64
-98
include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp
.../ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp
+25
-24
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp
+90
-69
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
+21
-23
No files found.
example/01_gemm/gemm_wmma_fp16.cpp
View file @
9adf2e60
...
...
@@ -27,7 +27,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma
// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
<
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
ALayout
,
BLayout
,
CLayout
,
AElementOp
,
BElementOp
,
CElementOp
,
GemmDefault
,
256
,
128
,
128
,
4
,
8
,
16
,
16
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
tru
e
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
tru
e
,
7
,
1
>
;
<
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
ALayout
,
BLayout
,
CLayout
,
AElementOp
,
BElementOp
,
CElementOp
,
GemmDefault
,
256
,
128
,
128
,
4
,
8
,
16
,
16
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
fals
e
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
fals
e
,
6
,
1
>
;
// clang-format on
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
View file @
9adf2e60
...
...
@@ -10,16 +10,6 @@
namespace
ck
{
enum
struct
LoopScheduler
{
Default
,
};
constexpr
LoopScheduler
make_default_loop_scheduler
()
{
return
LoopScheduler
::
Default
;
}
template
<
index_t
BlockSize
,
typename
FloatAB
,
typename
FloatAcc
,
...
...
@@ -30,18 +20,22 @@ template <index_t BlockSize,
index_t
MRepeat
,
index_t
NRepeat
,
index_t
KPack
>
// MRepeat_MWave_MLaneHigh_NRepeat_NWave_NLane_MLanelow
/* A: K0PerBlock x MPerBlock x K1
* B: K0PerBlock x NPerBlock x K1
* C: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs
*/
struct
BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I3
=
Number
<
4
>
{};
static
constexpr
auto
I4
=
Number
<
4
>
{};
static
constexpr
auto
WmmaK
=
Number
<
16
>
{};
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
static
constexpr
index_t
WaveSize
=
get_warp_size
()
;
static
constexpr
index_t
WaveSize
=
32
;
static
constexpr
index_t
MPerBlock
=
AK0MK1BlockDesc
{}.
GetLength
(
I1
);
static
constexpr
index_t
NPerBlock
=
BK0NK1BlockDesc
{}.
GetLength
(
I1
);
...
...
@@ -52,7 +46,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
static
constexpr
index_t
A_K1
=
AK0MK1BlockDesc
{}.
GetLength
(
I2
);
static
constexpr
index_t
B_K1
=
BK0NK1BlockDesc
{}.
GetLength
(
I2
);
static
constexpr
auto
wmma_gemm
=
W
MMA
Gemm
<
FloatAB
,
MPerWMMA
,
NPerWMMA
,
KPack
>
{};
static
constexpr
auto
wmma_gemm
=
W
mma
Gemm
<
FloatAB
,
FloatAcc
,
MPerWMMA
,
NPerWMMA
,
KPack
>
{};
static
constexpr
index_t
KPerThread
=
KPerBlock
/
wmma_gemm
.
K0PerWMMA
;
...
...
@@ -62,7 +56,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
StaticBufferTupleOfVector
<
AddressSpaceEnum
::
Vgpr
,
FloatAcc
,
MRepeat
*
NRepeat
,
wmma_gemm
.
GetRegSizePerW
MMA
(),
wmma_gemm
.
GetRegSizePerW
mma
(),
true
>
c_thread_buf_
;
...
...
@@ -87,7 +81,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
const
auto
waveId_m
=
wave_idx
[
I0
];
const
auto
WMMA_a_idx
=
wmma_gemm
.
CalculateAThreadOriginDataIndex
();
// |KRepeat |MRepeat|M
w
ave |MLane |KPack
// |KRepeat |MRepeat|M
W
ave |MLane |KPack
return
make_tuple
(
0
,
0
,
waveId_m
,
WMMA_a_idx
,
0
);
}
...
...
@@ -131,7 +125,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
return
make_tuple
(
c_thread_m
,
c_thread_n
);
}
__host__
__device__
BlockwiseGemmWMMA_k0mk1_k0nk1_m0
n0
m1n1
m
2m3
m4n2_v1
()
__host__
__device__
BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1
m2n0
n1
n
2m3
()
{
static_assert
(
AK0MK1BlockDesc
::
IsKnownAtCompileTime
()
&&
BK0NK1BlockDesc
::
IsKnownAtCompileTime
(),
...
...
@@ -157,76 +151,49 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
make_tuple
(
Number
<
MRepeat
>
{},
I1
,
MSubGroup
,
Number
<
NRepeat
>
{},
I1
,
NThreadPerSubGroup
,
MAccVgprs
));
}
__host__
__device__
static
constexpr
auto
GetCBlockDescriptor_MRepeat_Mwave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs
()
{
constexpr
auto
c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
MWaves
>
{},
Number
<
MPerWMMA
>
{},
Number
<
NRepeat
>
{},
Number
<
NWaves
>
{},
Number
<
NPerWMMA
>
{}));
return
wmma_gemm
.
MakeCDesc_MRepeat_Mwave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs
(
c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma
);
}
template
<
typename
CGridDesc_M_N
>
__host__
__device__
static
constexpr
auto
MakeCGridDescriptor_M
0_N0_M1_N1_M2_M3_M4_N2
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
MakeCGridDescriptor_M
BlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
{
const
auto
M
=
c_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
c_grid_desc_m_n
.
GetLength
(
I1
);
const
auto
c_grid_desc_m
0_n0_m1_n1_m2_n2
=
transform_tensor_descriptor
(
const
auto
c_grid_desc_m
blockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma
=
transform_tensor_descriptor
(
c_grid_desc_m_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
M
/
(
MWaves
*
MPerWMMA
),
MWaves
,
MPerWMMA
)),
make_unmerge_transform
(
make_tuple
(
N
/
(
NWaves
*
NPerWMMA
),
NWaves
,
NPerWMMA
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
1
,
3
,
5
>
{}));
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
,
4
,
5
>
{}));
return
wmma_gemm
.
MakeCDesc
riptor_M0_N0_M1_N1_M2_M3_M4_N2
(
c_grid_desc_m0_n0_m1_n1_m2_n2
);
return
wmma_gemm
.
MakeCDesc
_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs
(
c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma
);
}
__host__
__device__
static
constexpr
auto
MakeABlockDescriptor_K
Repeat
_M0_M1_M2_K
Pack
()
__host__
__device__
static
constexpr
auto
MakeABlockDescriptor_K
0
_M0_M1_M2_K
1
()
{
static
constexpr
auto
a_block_desc_temp_km0m1m2
=
transform_tensor_descriptor
(
AK0MK1BlockDesc
{},
make_tuple
(
make_merge_transform
(
make_tuple
(
Number
<
A_K0
>
{},
Number
<
A_K1
>
{})),
make_unmerge_transform
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
MWaves
>
{},
Number
<
MPerWMMA
>
{}))),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
,
3
>
{}));
return
transform_tensor_descriptor
(
a_b
lock
_d
esc
_temp_km0m1m2
,
AK0MK1B
lock
D
esc
{}
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
Number
<
A_K0
*
A_K1
/
KPack
>
{},
Number
<
KPack
>
{})),
make_pass_through_transform
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
MWaves
>
{},
Number
<
MPerWMMA
>
{}))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
,
3
>
{}),
make_tuple
(
Sequence
<
0
,
4
>
{},
Sequence
<
1
,
2
,
3
>
{}));
make_pass_through_transform
(
Number
<
A_K0
>
{}),
make_unmerge_transform
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
MWaves
>
{},
Number
<
MPerWMMA
>
{})),
make_pass_through_transform
(
Number
<
A_K1
>
{})),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
,
3
>
{},
Sequence
<
4
>
{}));
}
__host__
__device__
static
constexpr
auto
MakeBBlockDescriptor_K
Repeat
_N0_N1_N2_K
Pack
()
__host__
__device__
static
constexpr
auto
MakeBBlockDescriptor_K
0
_N0_N1_N2_K
1
()
{
static
constexpr
auto
b_block_desc_temp_kn0n1n2
=
transform_tensor_descriptor
(
BK0NK1BlockDesc
{},
make_tuple
(
make_merge_transform
(
make_tuple
(
Number
<
B_K0
>
{},
Number
<
B_K1
>
{})),
make_unmerge_transform
(
make_tuple
(
Number
<
NRepeat
>
{},
Number
<
NWaves
>
{},
Number
<
NPerWMMA
>
{}))),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
,
3
>
{}));
return
transform_tensor_descriptor
(
b_b
lock
_d
esc
_temp_kn0n1n2
,
BK0NK1B
lock
D
esc
{}
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
Number
<
B_K0
*
B_K1
/
KPack
>
{},
Number
<
KPack
>
{})),
make_pass_through_transform
(
make_tuple
(
Number
<
NRepeat
>
{},
Number
<
NWaves
>
{},
Number
<
NPerWMMA
>
{}))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
,
3
>
{}),
make_tuple
(
Sequence
<
0
,
4
>
{},
Sequence
<
1
,
2
,
3
>
{}));
make_pass_through_transform
(
Number
<
B_K0
>
{}),
make_unmerge_transform
(
make_tuple
(
Number
<
NRepeat
>
{},
Number
<
NWaves
>
{},
Number
<
NPerWMMA
>
{})),
make_pass_through_transform
(
Number
<
B_K1
>
{})),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
,
3
>
{},
Sequence
<
4
>
{}));
}
static
constexpr
auto
a_block_desc_k
repeat
_m0_m1_m2_k
pack
=
MakeABlockDescriptor_K
Repeat
_M0_M1_M2_K
Pack
();
static
constexpr
auto
b_block_desc_k
repeat
_n0_n1_n2_k
pack
=
MakeBBlockDescriptor_K
Repeat
_N0_N1_N2_K
Pack
();
static
constexpr
auto
a_block_desc_k
0
_m0_m1_m2_k
1
=
MakeABlockDescriptor_K
0
_M0_M1_M2_K
1
();
static
constexpr
auto
b_block_desc_k
0
_n0_n1_n2_k
1
=
MakeBBlockDescriptor_K
0
_N0_N1_N2_K
1
();
template
<
typename
ABlockBuffer
,
typename
BBlockBuffer
,
typename
CThreadBuffer
>
__device__
void
Run
(
const
ABlockBuffer
&
a_block_buf
,
...
...
@@ -239,9 +206,8 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
b_thread_desc_
.
GetElementSpaceSize
());
constexpr
auto
RepeatDiff
=
MRepeat
-
NRepeat
;
constexpr
auto
WmmaK
=
wmma_gemm
.
k_per_wmma
;
static_for
<
0
,
KPerBlock
/
WmmaK
,
1
>
{}([
&
](
auto
iWmmaK
){
static_for
<
0
,
KPerBlock
,
WmmaK
>
{}([
&
](
auto
iWmmaK
){
// Cut the repeat rectangle down to a square; assumes MRepeat > NRepeat
static_for
<
0
,
RepeatDiff
,
1
>
{}([
&
](
auto
iCut
){
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
iN
){
...
...
@@ -251,25 +217,25 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
static_for
<
0
,
WmmaK
,
1
>
{}([
&
](
auto
iK
)
{
a_thread_vec
.
template
AsType
<
FloatAB
>()(
iK
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
iCut
,
0
,
0
,
iK
))
>
{}];
make_tuple
(
iK
/
A_K1
,
iCut
,
0
,
0
,
iK
%
A_K1
))
>
{}];
b_thread_vec
.
template
AsType
<
FloatAB
>()(
iK
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
iN
,
0
,
0
,
iK
))
>
{}];
make_tuple
(
iK
/
B_K1
,
iN
,
0
,
0
,
iK
%
B_K1
))
>
{}];
});
using
wmma_input_type
=
typename
vector_type
<
FloatAB
,
WmmaK
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
iCut
,
iN
,
0
));
wmma_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
wmma_input_type
>(),
b_thread_vec
.
template
AsType
<
wmma_input_type
>(),
a_thread_vec
.
template
AsType
<
wmma_input_type
>()
(
Number
<
0
>{})
,
b_thread_vec
.
template
AsType
<
wmma_input_type
>()
(
Number
<
0
>
{})
,
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
});
a_thread_copy_
.
Run
(
a_block_desc_k
repeat
_m0_m1_m2_k
pack
,
make_tuple
(
Number
<
iWmmaK
>
{},
iCut
,
I0
,
I0
,
I0
),
a_thread_copy_
.
Run
(
a_block_desc_k
0
_m0_m1_m2_k
1
,
make_tuple
(
Number
<
iWmmaK
/
A_K1
>
{},
Number
<
iCut
>
{}
,
I0
,
I0
,
Number
<
iWmmaK
%
A_K1
>
{}
),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
make_tuple
(
I0
,
Number
<
iCut
>
{},
I0
,
I0
,
I0
),
a_thread_buf
);
});
// Run a FIFO-fashion loop over the square
...
...
@@ -281,25 +247,25 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
static_for
<
0
,
WmmaK
,
1
>
{}([
&
](
auto
iK
)
{
a_thread_vec
.
template
AsType
<
FloatAB
>()(
iK
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
WmmaInnerloop
+
RepeatDiff
,
0
,
0
,
iK
))
>
{}];
make_tuple
(
iK
/
A_K1
,
WmmaInnerloop
+
RepeatDiff
,
0
,
0
,
iK
%
A_K1
))
>
{}];
b_thread_vec
.
template
AsType
<
FloatAB
>()(
iK
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
iN
,
0
,
0
,
iK
))
>
{}];
make_tuple
(
iK
/
B_K1
,
iN
,
0
,
0
,
iK
%
B_K1
))
>
{}];
});
using
wmma_input_type
=
typename
vector_type
<
FloatAB
,
WmmaK
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
WmmaInnerloop
+
RepeatDiff
,
iN
,
0
));
wmma_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
wmma_input_type
>(),
b_thread_vec
.
template
AsType
<
wmma_input_type
>(),
a_thread_vec
.
template
AsType
<
wmma_input_type
>()
(
Number
<
0
>{})
,
b_thread_vec
.
template
AsType
<
wmma_input_type
>()
(
Number
<
0
>
{})
,
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
});
a_thread_copy_
.
Run
(
a_block_desc_k
repeat
_m0_m1_m2_k
pack
,
make_tuple
(
Number
<
iWmmaK
>
{},
WmmaInnerloop
+
RepeatDiff
,
I0
,
I0
,
I0
),
a_thread_copy_
.
Run
(
a_block_desc_k
0
_m0_m1_m2_k
1
,
make_tuple
(
Number
<
iWmmaK
/
A_K1
>
{},
Number
<
WmmaInnerloop
+
RepeatDiff
>
{}
,
I0
,
I0
,
Number
<
iWmmaK
%
A_K1
>
{}
),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
make_tuple
(
I0
,
Number
<
WmmaInnerloop
+
RepeatDiff
>
{},
I0
,
I0
,
I0
),
a_thread_buf
);
static_for
<
WmmaInnerloop
+
RepeatDiff
,
MRepeat
,
1
>
{}([
&
](
auto
iM
){
vector_type
<
FloatAB
,
WmmaK
>
a_thread_vec
;
...
...
@@ -308,25 +274,25 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
static_for
<
0
,
WmmaK
,
1
>
{}([
&
](
auto
iK
)
{
a_thread_vec
.
template
AsType
<
FloatAB
>()(
iK
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
iM
,
0
,
0
,
iK
))
>
{}];
make_tuple
(
iK
/
A_K1
,
iM
,
0
,
0
,
iK
%
A_K1
))
>
{}];
b_thread_vec
.
template
AsType
<
FloatAB
>()(
iK
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
WmmaInnerloop
,
0
,
0
,
iK
))
>
{}];
make_tuple
(
iK
/
B_K1
,
WmmaInnerloop
,
0
,
0
,
iK
%
B_K1
))
>
{}];
});
using
wmma_input_type
=
typename
vector_type
<
FloatAB
,
WmmaK
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
iM
,
WmmaInnerloop
,
0
));
wmma_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
wmma_input_type
>(),
b_thread_vec
.
template
AsType
<
wmma_input_type
>(),
a_thread_vec
.
template
AsType
<
wmma_input_type
>()
(
Number
<
0
>{})
,
b_thread_vec
.
template
AsType
<
wmma_input_type
>()
(
Number
<
0
>
{})
,
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
});
b_thread_copy_
.
Run
(
b_block_desc_k
repeat
_n0_n1_n2_k
pack
,
make_tuple
(
Number
<
iWmmaK
>
{},
WmmaInnerloop
,
I0
,
I0
,
I0
),
b_thread_copy_
.
Run
(
b_block_desc_k
0
_n0_n1_n2_k
1
,
make_tuple
(
Number
<
iWmmaK
/
B_K1
>
{},
Number
<
WmmaInnerloop
>
{}
,
I0
,
I0
,
Number
<
iWmmaK
%
B_K1
>
{}
),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
make_tuple
(
I0
,
Number
<
WmmaInnerloop
>
{},
I0
,
I0
,
I0
),
b_thread_buf
);
});
});
...
...
@@ -335,33 +301,33 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
protected:
// A[M0, M1, M2, K0 = WmmaK]
static
constexpr
auto
a_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
I1
,
I1
,
Number
<
WmmaK
>
{}));
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
WmmaK
/
A_K1
>
{},
Number
<
MRepeat
>
{}
,
I1
,
I1
,
Number
<
A_K1
>
{}));
// B[N0, N1, N2, K0 = WmmaK]
static
constexpr
auto
b_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
I1
,
I1
,
Number
<
WmmaK
>
{}));
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
WmmaK
/
B_K1
>
{},
Number
<
MRepeat
>
{}
,
I1
,
I1
,
Number
<
B_K1
>
{}));
// C[M, N, NumRegWMMA]
static
constexpr
auto
c_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
wmma_gemm
.
GetRegSizePerW
MMA
()));
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
wmma_gemm
.
GetRegSizePerW
mma
()));
using
AThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
FloatAB
,
FloatAB
,
decltype
(
a_block_desc_k
repeat
_m0_m1_m2_k
pack
),
decltype
(
a_block_desc_k
0
_m0_m1_m2_k
1
),
decltype
(
a_thread_desc_
),
Sequence
<
1
,
1
,
1
,
WmmaK
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
Sequence
<
WmmaK
/
A_K
1
,
1
,
1
,
1
,
A_K1
>
,
Sequence
<
3
,
0
,
1
,
2
,
4
>
,
4
,
A_K1
,
A_K1
>
;
using
BThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
FloatAB
,
FloatAB
,
decltype
(
b_block_desc_k
repeat
_n0_n1_n2_k
pack
),
decltype
(
b_block_desc_k
0
_n0_n1_n2_k
1
),
decltype
(
b_thread_desc_
),
Sequence
<
1
,
1
,
1
,
WmmaK
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
Sequence
<
WmmaK
/
B_K
1
,
1
,
1
,
1
,
B_K1
>
,
Sequence
<
3
,
0
,
1
,
2
,
4
>
,
4
,
B_K1
,
B_K1
>
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp
View file @
9adf2e60
...
...
@@ -12,7 +12,7 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma
_v1r1
.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
...
...
@@ -38,8 +38,8 @@ template <typename ADataType,
ck
::
index_t
K1
,
ck
::
index_t
MPerWMMA
,
ck
::
index_t
NPerWMMA
,
ck
::
index_t
M
WmmaPerWave
,
ck
::
index_t
N
WmmaPerWave
,
ck
::
index_t
M
Repeat
,
ck
::
index_t
N
Repeat
,
typename
ABlockTransferThreadClusterLengths_K0_M_K1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
...
...
@@ -196,7 +196,7 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
using
CGridDesc_M_N
=
decltype
(
MakeCGridDescriptor_M_N
(
1
,
1
,
1
));
// GridwiseGemm
using
GridwiseGemm
=
GridwiseGemm_k0mk1_k0nk1_mn_wmma
_v1
<
using
GridwiseGemm
=
GridwiseGemm_k0mk1_k0nk1_mn_wmma
<
BlockSize
,
ADataType
,
// TODO: distinguish A/B datatype
AccDataType
,
...
...
@@ -214,8 +214,8 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
MPerWMMA
,
NPerWMMA
,
K1
,
M
WmmaPerWave
,
N
WmmaPerWave
,
M
Repeat
,
N
Repeat
,
ABlockTransferThreadClusterLengths_K0_M_K1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferSrcAccessOrder
,
...
...
@@ -232,16 +232,15 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
BBlockTransferDstScalarPerVector_K1
,
false
,
// BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsAddExtraN
,
#if 0
Sequence<0, 2, 4, 5, 6, 1, 3, 7>, // CThreadTransferSrcDstAccessOrder,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
>
,
// CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim
,
CThreadTransferDstScalarPerVector
,
#endif
NumPrefetch
,
LoopSched
,
PipelineVer
>
;
// Argument
struct
Argument
:
public
BaseArgument
W
struct
Argument
:
public
BaseArgument
{
Argument
(
const
ADataType
*
p_a_grid
,
const
BDataType
*
p_b_grid
,
...
...
@@ -263,7 +262,7 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
a_grid_desc_k0_m_k1_
{},
b_grid_desc_k0_n_k1_
{},
c_grid_desc_m_n_
{},
c_grid_desc_mblock
_mwmmaperwave_mwave_mlanehigh_nblock_nwmmaperwave_nwave_nlane_mlanelow
_
{},
c_grid_desc_mblock
xrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs
_
{},
block_2_ctile_map_
{},
M01_
{
M01
},
N01_
{
N01
},
...
...
@@ -283,8 +282,8 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
c_grid_desc_m_n_
,
block_2_ctile_map_
))
{
c_grid_desc_mblock
_mwmmaperwave_mwave_mlanehigh_nblock_nwmmaperwave_nwave_nlane_mlanelow
_
=
GridwiseGemm
::
MakeCGridDescriptor_MBlock
_MWmmaPerWave
_M
w
ave_M
LaneHigh_NBlock_NWmmaPerWave
_N
w
ave_N
Lane_MLaneLow
(
c_grid_desc_m_n_
);
c_grid_desc_mblock
xrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs
_
=
GridwiseGemm
::
MakeCGridDescriptor_MBlock
xRepeat
_M
W
ave_M
SubGroup_NBlockxRepeat
_N
W
ave_N
ThreadPerSubGroup_MAccVgprs
(
c_grid_desc_m_n_
);
}
}
...
...
@@ -295,8 +294,8 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
AGridDesc_K0_M_K1
a_grid_desc_k0_m_k1_
;
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
typename
GridwiseGemm
::
CGridDescriptor_MBlock
_MWmmaPerWave
_M
w
ave_M
LaneHigh_NBlock_NWmmaPerWave
_N
w
ave_N
Lane_MLaneLow
c_grid_desc_mblock
_mwmmaperwave_mwave_mlanehigh_nblock_nwmmaperwave_nwave_nlane_mlanelow
_
;
typename
GridwiseGemm
::
CGridDescriptor_MBlock
xRepeat
_M
W
ave_M
SubGroup_NBlockxRepeat
_N
W
ave_N
ThreadPerSubGroup_MAccVgprs
c_grid_desc_mblock
xrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs
_
;
typename
GridwiseGemm
::
DefaultBlock2CTileMap
block_2_ctile_map_
;
index_t
M01_
;
index_t
N01_
;
...
...
@@ -347,19 +346,21 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
{
const
auto
kernel
=
kernel_gemm_wmma
_v1r1
<
const
auto
kernel
=
kernel_gemm_wmma
<
GridwiseGemm
,
ADataType
,
// TODO: distinguish A/B datatype
CDataType
,
remove_reference_t
<
DeviceGemmWmma
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceGemmWmma
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
typename
GridwiseGemm
::
CGridDescriptor_MBlock
_MWmmaPerWave
_M
w
ave_M
LaneHigh_NBlock_NWmmaPerWave
_N
w
ave_N
Lane_MLaneLow
>
,
remove_reference_t
<
typename
GridwiseGemm
::
CGridDescriptor_MBlock
xRepeat
_M
W
ave_M
SubGroup_NBlockxRepeat
_N
W
ave_N
ThreadPerSubGroup_MAccVgprs
>
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
remove_reference_t
<
typename
GridwiseGemm
::
DefaultBlock2CTileMap
>
,
true
>
;
// Last Option is W/O
std
::
cout
<<
"Host kernel type is "
<<
type_name
<
decltype
(
kernel
)
>
()
<<
std
::
endl
;
printf
(
"---------------------Crush before kernel launch-------------------
\n
"
);
ave_time
=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
grid_size
),
...
...
@@ -370,7 +371,7 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
arg
.
p_c_grid_
,
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
c_grid_desc_mblock
_mwmmaperwave_mwave_mlanehigh_nblock_nwmmaperwave_nwave_nlane_mlanelow
_
,
arg
.
c_grid_desc_mblock
xrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs
_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
...
...
@@ -378,13 +379,13 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
}
else
{
const
auto
kernel
=
kernel_gemm_wmma
_v1r1
<
const
auto
kernel
=
kernel_gemm_wmma
<
GridwiseGemm
,
ADataType
,
// TODO: distinguish A/B datatype
CDataType
,
remove_reference_t
<
DeviceGemmWmma
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceGemmWmma
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
typename
GridwiseGemm
::
CGridDescriptor_MBlock
_MWmmaPerWave
_M
w
ave_M
LaneHigh_NBlock_NWmmaPerWave
_N
w
ave_N
Lane_MLaneLow
>
,
remove_reference_t
<
typename
GridwiseGemm
::
CGridDescriptor_MBlock
xRepeat
_M
W
ave_M
SubGroup_NBlockxRepeat
_N
W
ave_N
ThreadPerSubGroup_MAccVgprs
>
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
...
...
@@ -401,7 +402,7 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
arg
.
p_c_grid_
,
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
c_grid_desc_mblock
_mwmmaperwave_mwave_mlanehigh_nblock_nwmmaperwave_nwave_nlane_mlanelow
_
,
arg
.
c_grid_desc_mblock
xrepeat_mwave_msubgroup_nblockxrepeat_nwave_nthreadpersubgroup_maccvgprs
_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
...
...
@@ -540,8 +541,8 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
<<
K1
<<
", "
<<
MPerWMMA
<<
", "
<<
NPerWMMA
<<
", "
<<
M
WmmaPerWave
<<
", "
<<
N
WmmaPerWave
<<
M
Repeat
<<
", "
<<
N
Repeat
<<
">"
<<
" NumPrefetch: "
<<
NumPrefetch
<<
", "
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma
_v1r1
.hpp
→
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp
View file @
9adf2e60
...
...
@@ -22,7 +22,7 @@ template <typename GridwiseGemm,
typename
FloatC
,
typename
AGridDesc_K0_M_K1
,
typename
BGridDesc_K0_N_K1
,
typename
CGridDescriptor_MBlock
_M
Repeat_M
w
ave_MSubGroup_NBlock
_N
Repeat_NWave_NThreadPerSubGroup_MAccVgprs
,
typename
CGridDescriptor_MBlock
x
Repeat_M
W
ave_MSubGroup_NBlock
x
Repeat_NWave_NThreadPerSubGroup_MAccVgprs
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
...
...
@@ -32,14 +32,14 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
kernel_gemm_wmma
_v1r1
(
kernel_gemm_wmma
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
const
AGridDesc_K0_M_K1
a_grid_desc_k0_m_k1
,
const
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1
,
const
CGridDescriptor_MBlock
_M
Repeat_M
w
ave_MSubGroup_NBlock
_N
Repeat_NWave_NThreadPerSubGroup_MAccVgprs
c_grid_desc_mblock
_m
repeat_mwave_msubgroup_n
_
block
_n
repeat_nwave_nthreadpersubgroup_maccvgprs
,
const
CGridDescriptor_MBlock
x
Repeat_M
W
ave_MSubGroup_NBlock
x
Repeat_NWave_NThreadPerSubGroup_MAccVgprs
c_grid_desc_mblock
x
repeat_mwave_msubgroup_nblock
x
repeat_nwave_nthreadpersubgroup_maccvgprs
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
CElementwiseOperation
c_element_op
,
...
...
@@ -55,7 +55,7 @@ __global__ void
p_shared
,
a_grid_desc_k0_m_k1
,
b_grid_desc_k0_n_k1
,
c_grid_desc_mblock
_m
repeat_mwave_msubgroup_n
_
block
_n
repeat_nwave_nthreadpersubgroup_maccvgprs
,
c_grid_desc_mblock
x
repeat_mwave_msubgroup_nblock
x
repeat_nwave_nthreadpersubgroup_maccvgprs
,
a_element_op
,
b_element_op
,
c_element_op
,
...
...
@@ -66,7 +66,7 @@ __global__ void
ignore
=
p_c_grid
;
ignore
=
a_grid_desc_k0_m_k1
;
ignore
=
b_grid_desc_k0_n_k1
;
ignore
=
c_grid_desc_mblock
_m
repeat_mwave_msubgroup_n
_
block
_n
repeat_nwave_nthreadpersubgroup_maccvgprs
;
ignore
=
c_grid_desc_mblock
x
repeat_mwave_msubgroup_nblock
x
repeat_nwave_nthreadpersubgroup_maccvgprs
;
ignore
=
a_element_op
;
ignore
=
b_element_op
;
ignore
=
c_element_op
;
...
...
@@ -92,8 +92,8 @@ template <
index_t
MPerWmma
,
index_t
NPerWmma
,
index_t
K1Value
,
index_t
M
WmmaPerWave
,
index_t
N
WmmaPerWave
,
index_t
M
Repeat
,
index_t
N
Repeat
,
typename
ABlockTransferThreadClusterLengths_K0_M_K1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
...
...
@@ -114,8 +114,9 @@ template <
index_t
CThreadTransferSrcDstVectorDim
,
index_t
CThreadTransferDstScalarPerVector
,
index_t
NumGemmKPrefetchStage
=
1
,
LoopScheduler
LoopSched
=
make_default_loop_scheduler
(),
PipelineVersion
PipelineVer
=
PipelineVersion
::
v1
>
struct
GridwiseGemm_k0mk1_k0nk1_mn_wmma
_v1
struct
GridwiseGemm_k0mk1_k0nk1_mn_wmma
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
...
...
@@ -132,7 +133,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
using
GridwiseGemmPipe
=
remove_cvref_t
<
decltype
(
GridwiseGemmPipeline_Selector
<
PipelineVer
,
NumGemmKPrefetchStage
>
())
>
;
GridwiseGemmPipeline_Selector
<
PipelineVer
,
NumGemmKPrefetchStage
,
LoopSched
>
())
>
;
__host__
__device__
static
constexpr
auto
GetABlockDescriptor_K0PerBlock_MPerBlock_K1
()
{
...
...
@@ -207,8 +208,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1
static_assert
(
is_known_at_compile_time
<
remove_cv_t
<
decltype
(
K1
)
>>::
value
,
"wrong! K1 need to be known at compile-time"
);
static_assert
((
MPerBlock
%
(
MPerWmma
*
M
WmmaPerWave
)
==
0
)
&&
(
NPerBlock
%
(
N
WmmaPerWave
*
NPerWmma
))
==
0
,
static_assert
((
MPerBlock
%
(
MPerWmma
*
M
Repeat
)
==
0
)
&&
(
NPerBlock
%
(
N
Repeat
*
NPerWmma
))
==
0
,
"Invalid tuning param!"
);
const
auto
M
=
a_grid_desc_k0_m_k1
.
GetLength
(
I1
);
...
...
@@ -247,35 +248,57 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1
return
GridwiseGemmPipe
::
CalculateHasMainLoop
(
num_loop
);
}
template
<
typename
CGridDesc_M_N_
>
__host__
__device__
static
constexpr
auto
MakeCGridDescriptor_MBlock
_M
Repeat_M
w
ave_MSubGroup_NBlock
_N
Repeat_NWave_NThreadPerSubGroup_MAccVgprs
(
const
CGridDesc_M_N
_
&
c_grid_desc_m_n
)
MakeCGridDescriptor_MBlock
x
Repeat_M
W
ave_MSubGroup_NBlock
x
Repeat_NWave_NThreadPerSubGroup_MAccVgprs
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
{
const
auto
M
=
c_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
c_grid_desc_m_n
.
GetLength
(
I1
);
const
auto
MBlock
=
M
/
MPerBlock
;
const
auto
NBlock
=
N
/
NPerBlock
;
constexpr
index_t
MWave
=
MPerBlock
/
(
MWmmaPerWave
*
MPerWmma
);
constexpr
index_t
NWave
=
NPerBlock
/
(
NWmmaPerWave
*
NPerWmma
);
constexpr
index_t
MLaneHigh
=
2
;
constexpr
index_t
MLaneLow
=
NWmmaPerWave
/
MLaneHigh
;
constexpr
index_t
NLane
=
NWmmaPerWave
;
const
auto
c_grid_desc_mblock_mrepeat_mwave_msubgroup_n_block_nrepeat_nwave_nthreadpersubgroup_maccvgprs
=
transform_tensor_descriptor
(
c_grid_desc_m_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
MBlock
,
Number
<
MWmmaPerWave
>
{},
Number
<
MWave
>
{},
Number
<
MLaneHigh
>
{},
Number
<
MLaneLow
>
{})),
make_unmerge_transform
(
make_tuple
(
NBlock
,
Number
<
NWmmaPerWave
>
{},
Number
<
NWave
>
{},
Number
<
NLane
>
{}))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
2
,
3
,
8
>
{},
Sequence
<
4
,
5
,
6
,
7
>
{}));
return
c_grid_desc_mblock_mrepeat_mwave_msubgroup_n_block_nrepeat_nwave_nthreadpersubgroup_maccvgprs
;
constexpr
auto
max_lds_align
=
K1
;
// A matrix in LDS memory, dst of blockwise copy
constexpr
auto
a_block_desc_k0perblock_mperblock_k1
=
[
&
]()
{
if
constexpr
(
ABlockLdsExtraM
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
make_tuple
(
Number
<
MPerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
}
else
{
return
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
max_lds_align
);
}
}();
// B matrix in LDS memory, dst of blockwise copy
constexpr
auto
b_block_desc_k0perblock_nperblock_k1
=
[
&
]()
{
if
constexpr
(
BBlockLdsExtraN
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
make_tuple
(
Number
<
NPerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
}
else
{
return
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
max_lds_align
);
}
}();
constexpr
auto
WmmaK
=
16
;
constexpr
auto
KPack
=
math
::
integer_least_multiple
(
K1
,
WmmaK
);
using
BlockwiseGemm
=
BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
<
BlockSize
,
FloatAB
,
FloatAcc
,
decltype
(
a_block_desc_k0perblock_mperblock_k1
),
decltype
(
b_block_desc_k0perblock_nperblock_k1
),
MPerWmma
,
NPerWmma
,
MRepeat
,
NRepeat
,
KPack
>
;
return
BlockwiseGemm
::
MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs
(
c_grid_desc_m_n
);
}
// return block_id to C matrix tile idx (m0, n0) mapping
...
...
@@ -285,9 +308,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1
return
BlockToCTileMap_M00_N0_M01Adapt
<
MPerBlock
,
NPerBlock
,
CGridDesc_M_N
>
(
c_grid_desc_m_n
);
}
using
CGridDescriptor_MBlock
_M
Repeat_M
w
ave_MSubGroup_NBlock
_N
Repeat_NWave_NThreadPerSubGroup_MAccVgprs
=
using
CGridDescriptor_MBlock
x
Repeat_M
W
ave_MSubGroup_NBlock
x
Repeat_NWave_NThreadPerSubGroup_MAccVgprs
=
remove_cvref_t
<
decltype
(
MakeCGridDescriptor_MBlock
_M
Repeat_M
w
ave_MSubGroup_NBlock
_N
Repeat_NWave_NThreadPerSubGroup_MAccVgprs
(
MakeCGridDescriptor_MBlock
x
Repeat_M
W
ave_MSubGroup_NBlock
x
Repeat_NWave_NThreadPerSubGroup_MAccVgprs
(
CGridDesc_M_N
{}))
>
;
using
DefaultBlock2CTileMap
=
remove_cvref_t
<
decltype
(
MakeDefaultBlock2CTileMap
(
CGridDesc_M_N
{},
1
,
1
))
>
;
...
...
@@ -300,8 +323,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1
void
*
__restrict__
p_shared
,
const
AGridDesc_K0_M_K1
&
a_grid_desc_k0_m_k1
,
const
BGridDesc_K0_N_K1
&
b_grid_desc_k0_n_k1
,
const
CGridDescriptor_MBlock
_M
Repeat_M
w
ave_MSubGroup_NBlock
_N
Repeat_NWave_NThreadPerSubGroup_MAccVgprs
&
c_grid_desc_mblock
_m
repeat_mwave_msubgroup_n
_
block
_n
repeat_nwave_nthreadpersubgroup_maccvgprs
,
const
CGridDescriptor_MBlock
x
Repeat_M
W
ave_MSubGroup_NBlock
x
Repeat_NWave_NThreadPerSubGroup_MAccVgprs
&
c_grid_desc_mblock
x
repeat_mwave_msubgroup_nblock
x
repeat_nwave_nthreadpersubgroup_maccvgprs
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CElementwiseOperation
&
c_element_op
,
...
...
@@ -315,15 +338,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1
const
auto
b_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_b_grid
,
b_grid_desc_k0_n_k1
.
GetElementSpaceSize
());
auto
c_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_c_grid
,
c_grid_desc_mblock
_m
repeat_mwave_msubgroup_n
_
block
_n
repeat_nwave_nthreadpersubgroup_maccvgprs
.
GetElementSpaceSize
());
p_c_grid
,
c_grid_desc_mblock
x
repeat_mwave_msubgroup_nblock
x
repeat_nwave_nthreadpersubgroup_maccvgprs
.
GetElementSpaceSize
());
/*******************************************************************************/
// BlockIdx.x -> [BlockId.m, BlockId.n]
const
auto
block_work_idx
=
block_2_ctile_map
.
CalculateBottomIndex
(
make_multi_index
(
get_block_1d_id
()));
if
(
!
block_2_ctile_map
.
ValidCTileIndex
(
block_work_idx
,
make_tuple
(
c_grid_desc_mblock
_m
repeat_mwave_msubgroup_n
_
block
_n
repeat_nwave_nthreadpersubgroup_maccvgprs
.
GetLength
(
I0
),
c_grid_desc_mblock
_m
repeat_mwave_msubgroup_n
_
block
_n
repeat_nwave_nthreadpersubgroup_maccvgprs
.
GetLength
(
I4
))))
make_tuple
(
c_grid_desc_mblock
x
repeat_mwave_msubgroup_nblock
x
repeat_nwave_nthreadpersubgroup_maccvgprs
.
GetLength
(
I0
),
c_grid_desc_mblock
x
repeat_mwave_msubgroup_nblock
x
repeat_nwave_nthreadpersubgroup_maccvgprs
.
GetLength
(
I4
))))
{
return
;
}
// Store BlockId into SGPR
...
...
@@ -415,8 +438,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1
decltype
(
b_block_desc_k0perblock_nperblock_k1
),
MPerWmma
,
NPerWmma
,
M
WmmaPerWave
,
N
WmmaPerWave
,
M
Repeat
,
N
Repeat
,
KPack
>
{};
// Prepare Register for C matrix
...
...
@@ -450,20 +473,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1
blockwise_gemm
,
c_thread_buf
,
K0BlockMainLoop
);
// NO C-shuffle, direct write
/*******************************************************************************/
// write out C matrix, c shuffle not implemented
{
constexpr
c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs
=
blockwise_gemm
.
GetCThreadDescriptor_MRepeat_MWave_MLaneHigh_NRepeat_NWave_NLane_MLaneLow
();
constexpr
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs
=
blockwise_gemm
.
MakeCDesc_MRepeat_Mwave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs
();
constexpr
auto
MRepeat
=
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs
.
GetLength
(
I0
);
constexpr
auto
MWave
=
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs
.
GetLength
(
I1
);
constexpr
auto
MSubGroup
=
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs
.
GetLength
(
I2
);
constexpr
auto
NRepeat
=
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs
.
GetLength
(
I3
);
constexpr
auto
Nwave
=
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs
.
GetLength
(
I4
);
constexpr
auto
NThreadPerSubGroup
=
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs
.
GetLength
(
I5
);
constexpr
auto
MAccVgprs
=
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs
.
GetLength
(
I6
);
constexpr
auto
c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs
=
blockwise_gemm
.
GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs
();
constexpr
auto
MWave
=
c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs
.
GetLength
(
I1
);
constexpr
auto
MSubGroup
=
c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs
.
GetLength
(
I2
);
constexpr
auto
Nwave
=
c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs
.
GetLength
(
I4
);
constexpr
auto
NThreadPerSubGroup
=
c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs
.
GetLength
(
I5
);
constexpr
auto
MAccVgprs
=
c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs
.
GetLength
(
I6
);
// Mapping
const
auto
c_thread_mtx_on_block
=
blockwise_gemm
.
CalculateCThreadOriginDataIndex
(
I0
,
I0
);
...
...
@@ -476,16 +496,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1
make_tuple
(
Sequence
<
0
,
1
,
2
,
3
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
n_thread_data_on_grid_to_nrepeat_nwave_nthreadpersubgroup
=
const
auto
n_thread_data_on_grid_to_nrepeat_nwave_nthreadpersubgroup
_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
NRepeat
,
Nwave
,
NThreadPerSubGroup
))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
m_thread_data_on_grid_idx
=
m_thread_data_on_grid_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor
(
const
auto
m_thread_data_on_grid_idx
=
m_thread_data_on_grid_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
m_thread_data_on_grid
));
const
auto
n_thread_data_on_grid_idx
=
n_thread_data_on_grid_to_nrepeat_nwave_nthreadpersubgroup
(
const
auto
n_thread_data_on_grid_idx
=
n_thread_data_on_grid_to_nrepeat_nwave_nthreadpersubgroup
_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
n_thread_data_on_grid
));
...
...
@@ -494,8 +514,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1
/* typename SrcData */
FloatAcc
,
/* typename DstData */
FloatC
,
/* typename SrcDesc */
decltype
(
c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs
),
/* typename DstDesc */
decltype
(
c_grid_desc_mblock
_m
repeat_mwave_msubgroup_n
_
block
_n
repeat_nwave_nthreadpersubgroup_maccvgprs
),
/* typename DstDesc */
decltype
(
c_grid_desc_mblock
x
repeat_mwave_msubgroup_nblock
x
repeat_nwave_nthreadpersubgroup_maccvgprs
),
/* typename ElementwiseOperation */
CElementwiseOperation
,
// Thread register Mapping
/* typename SliceLengths */
Sequence
<
MRepeat
,
I1
,
I1
,
NRepeat
,
I1
,
I1
,
MAccVgprs
>
,
/* typename DimAccessOrder */
CThreadTransferSrcDstAccessOrder
,
/* index_t DstVectorDim */
CThreadTransferSrcDstVectorDim
,
...
...
@@ -504,7 +525,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1
/* index_t DstScalarStrideInVector */
1
,
/* bool DstResetCoordinateAfterRun */
true
>
{
/* dst_desc */
c_grid_desc_mblock
_m
repeat_mwave_msubgroup_n
_
block
_n
repeat_nwave_nthreadpersubgroup_maccvgprs
,
/* dst_desc */
c_grid_desc_mblock
x
repeat_mwave_msubgroup_nblock
x
repeat_nwave_nthreadpersubgroup_maccvgprs
,
/* dst_slice_origin_idx */
make_multi_index
(
m_thread_data_on_grid_idx
[
I0
],
m_thread_data_on_grid_idx
[
I1
],
m_thread_data_on_grid_idx
[
I2
],
...
...
@@ -517,9 +538,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1
c_thread_copy
.
Run
(
/* c_thread_desc */
c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs
,
/* c_
start point
*/
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
/* c_
buffer
*/
c_thread_buf
,
/* c_grid_desc */
c_grid_desc_mblock
_m
repeat_mwave_msubgroup_n
_
block
_n
repeat_nwave_nthreadpersubgroup_maccvgprs
,
/* c_
register_beginning
*/
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
/* c_
local(register)
*/
c_thread_buf
,
/* c_grid_desc */
c_grid_desc_mblock
x
repeat_mwave_msubgroup_nblock
x
repeat_nwave_nthreadpersubgroup_maccvgprs
,
/* c_grid_buf */
c_grid_buf
);
}
// clang-format on
...
...
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
View file @
9adf2e60
...
...
@@ -72,12 +72,14 @@ enum struct WmmaInstr
template
<
WmmaInstr
Instr
,
index_t
WaveSize
,
typename
enable_if
<
WaveSize
==
32
||
WaveSize
==
64
,
bool
>
::
=
false
>
struct
wmma_type
;
typename
=
void
>
struct
wmma_type
{}
;
// A-swizzled
template
<
index_t
WaveSize
>
struct
wmma_type
<
WmmaInstr
::
wmma_f32_16x16x16_f16
,
WaveSize
>
struct
wmma_type
<
WmmaInstr
::
wmma_f32_16x16x16_f16
,
WaveSize
,
typename
std
::
enable_if_t
<
WaveSize
==
32
||
WaveSize
==
64
>>
{
// Absolute fixing property
// * Data Pixel
...
...
@@ -172,11 +174,7 @@ struct WmmaSelector
static_assert
(
selected_wmma
.
wave_size
*
selected_wmma
.
num_acc_vgprs_per_wave
*
selected_wmma
.
acc_data_size
==
selected_wmma
.
m_per_wmma
*
selected_wmma
.
n_per_wmma
*
4
,
"WRONG! Number of Accumulator Register"
);
static_assert
(
selected_wmma
.
lane_size
*
selected_wmma
.
num_srcregs_per_wmma
*
selected_wmma
.
src_data_size
==
selected_wmma
.
m_per_wmma
*
selected_wmma
.
k_per_wmma
*
4
,
"WRONG! Number of Source Register"
);
"WRONG! Invalid Number of Accumulator Register"
);
}
};
...
...
@@ -206,25 +204,25 @@ struct WmmaGemm
static_assert
(
KPack
==
wmma_instr
.
k_per_wmma
,
"KPack should be k_per_wmma"
);
}
//
XDL
output supporting C = A * B
//
WMMA
output supporting C = A * B
// MPerWMMA_NPerWMMA -> MSubGroup_..._NPerWMMA_MAccVgprPerWave
template
<
typename
CDesc_MRepeat_M
w
ave_MPerWMMA_NRepeat_NWave_NPerWMMA
>
template
<
typename
CDesc_M
Blockx
Repeat_M
W
ave_MPerWMMA_N
Blockx
Repeat_NWave_NPerWMMA
>
__host__
__device__
static
constexpr
auto
MakeCDesc_MRepeat_M
w
ave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs
(
const
CDesc_MRepeat_M
w
ave_MPerWMMA_NRepeat_NWave_NPerWMMA
&
c_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma
)
MakeCDesc_M
Blockx
Repeat_M
W
ave_MSubGroup_N
Blockx
Repeat_NWave_NThreadPerSubGroup_MAccVgprs
(
const
CDesc_M
Blockx
Repeat_M
W
ave_MPerWMMA_N
Blockx
Repeat_NWave_NPerWMMA
&
c_desc_m
blockx
repeat_mwave_mperwmma_n
blockx
repeat_nwave_nperwmma
)
{
const
auto
MRepeat
=
c_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma
.
GetLength
(
I0
);
const
auto
NRepeat
=
c_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma
.
GetLength
(
I3
);
const
auto
MWave
=
c_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma
.
GetLength
(
I1
);
const
auto
NWave
=
c_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma
.
GetLength
(
I4
);
const
auto
M
Blockx
Repeat
=
c_desc_m
blockx
repeat_mwave_mperwmma_n
blockx
repeat_nwave_nperwmma
.
GetLength
(
I0
);
const
auto
N
Blockx
Repeat
=
c_desc_m
blockx
repeat_mwave_mperwmma_n
blockx
repeat_nwave_nperwmma
.
GetLength
(
I3
);
const
auto
MWave
=
c_desc_m
blockx
repeat_mwave_mperwmma_n
blockx
repeat_nwave_nperwmma
.
GetLength
(
I1
);
const
auto
NWave
=
c_desc_m
blockx
repeat_mwave_mperwmma_n
blockx
repeat_nwave_nperwmma
.
GetLength
(
I4
);
return
transform_tensor_descriptor
(
c_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma
,
make_tuple
(
make_pass_through_transform
(
MRepeat
),
make_pass_through_transform
(
M
w
ave
),
c_desc_m
blockx
repeat_mwave_mperwmma_n
blockx
repeat_nwave_nperwmma
,
make_tuple
(
make_pass_through_transform
(
M
Blockx
Repeat
),
make_pass_through_transform
(
M
W
ave
),
make_unmerge_transform
(
make_tuple
(
Number
<
wmma_instr
.
num_subgroups
>
{},
Number
<
wmma_instr
.
num_acc_vgprs_per_wave
>
{})),
make_pass_through_transform
(
NRepeat
),
make_pass_through_transform
(
N
Blockx
Repeat
),
make_pass_through_transform
(
NWave
),
make_pass_through_transform
(
Number
<
wmma_instr
.
num_thread_per_subgroups
>
{})),
make_tuple
(
Sequence
<
0
>
{},
...
...
@@ -266,12 +264,12 @@ struct WmmaGemm
if
constexpr
(
!
TransposeC
)
{
wmma_instr
.
template
run
<
MPerWmma
,
NPerWmma
>(
p_a_wave
[
0
]
,
p_b_wave
[
0
]
,
p_c_thread
);
p_a_wave
,
p_b_wave
,
p_c_thread
);
}
else
{
wmma_instr
.
template
run
<
MPerWmma
,
NPerWmma
>(
p_b_wave
[
0
]
,
p_a_wave
[
0
]
,
p_c_thread
);
p_b_wave
,
p_a_wave
,
p_c_thread
);
}
}
...
...
@@ -318,7 +316,7 @@ struct WmmaGemm
__host__
__device__
static
constexpr
auto
GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths
()
{
return
make_tuple
(
Number
<
I1
,
I1
,
Number
<
wmma_instr
.
num_acc_vgprs_per_wave
>
{});
I1
,
I1
,
Number
<
wmma_instr
.
num_acc_vgprs_per_wave
>
{});
}
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment