gaoqiong / composable_kernel · Commits

Commit b3cc22a3, authored Nov 24, 2022 by aska-0096
tempsave
parent d16063db

Showing 5 changed files with 425 additions and 736 deletions (+425 −736)
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp        +54   −87
include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp     +33   −38
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_v1r1.hpp     +181  −453
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp                   +143  −157
include/ck/utility/amd_wmma.hpp                                      +14   −1
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp

@@ -30,12 +30,14 @@ template <index_t BlockSize,
           index_t MRepeat,
           index_t NRepeat,
           index_t KPack>
-struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2
+// MRepeat_MWave_MLaneHigh_NRepeat_NWave_NLane_MLanelow
+struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3
 {
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
     static constexpr auto I2 = Number<2>{};
     static constexpr auto I3 = Number<3>{};
+    static constexpr auto I4 = Number<4>{};

     using ThisThreadBlock = ThisThreadBlock<BlockSize>;
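The layout comment above spells out how the renamed accumulator ordering tiles the block: per dimension, the block extent is the repeat count times the wave count times the per-WMMA tile size. A standalone sketch of that arithmetic, using illustrative sizes rather than values from this commit:

    #include <cstdio>

    // Illustrative only: these tile sizes are assumptions, not taken from the commit.
    constexpr int MRepeat = 2, MWaves = 2, MPerWMMA = 16;
    constexpr int NRepeat = 2, NWaves = 2, NPerWMMA = 16;

    // Each wave covers an MPerWMMA x NPerWMMA tile per repeat, so the block tile is:
    constexpr int MPerBlock = MRepeat * MWaves * MPerWMMA; // 64
    constexpr int NPerBlock = NRepeat * NWaves * NPerWMMA; // 64

    int main() { std::printf("block tile = %d x %d\n", MPerBlock, NPerBlock); }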
@@ -85,8 +87,8 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2
         const auto waveId_m = wave_idx[I0];

         const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex();

-        return make_tuple(0, waveId_m, WMMA_a_idx[I1], KPerThread * WMMA_a_idx[I0]);
+        //                |KRepeat |MRepeat |MWave     |MLane       |KPack
+        return make_tuple(0,       0,       waveId_m,  WMMA_a_idx,  0);
     }

     __device__ static auto CalculateBThreadOriginDataIndex()

@@ -96,20 +98,20 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2
         const auto waveId_n = wave_idx[I1];

         const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex();

-        return make_tuple(0, waveId_n, WMMA_b_idx[I1], KPerThread * WMMA_b_idx[I0]);
+        //                |KRepeat |NRepeat |NWave     |NLane       |KPack
+        return make_tuple(0,       0,       waveId_n,  WMMA_b_idx,  0);
     }

-    template <index_t m0, index_t n0, index_t WMMA_i, index_t blk_i>
+    template <index_t m0, index_t n0>
     __device__ static auto
-    CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<WMMA_i>, Number<blk_i>)
+    CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>)
     {
         const auto wave_idx = GetWaveIdx();

         const auto waveId_m = wave_idx[I0];
         const auto waveId_n = wave_idx[I1];

-        const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk(WMMA_i, blk_i);
+        const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk();

         constexpr auto mrepeat_mwave_mperWMMA_to_m_adaptor = make_single_stage_tensor_adaptor(
             make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerWMMA))),
@@ -129,27 +131,6 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2
         return make_tuple(c_thread_m, c_thread_n);
     }

-    template <index_t m0, index_t n0, index_t WMMA_i, index_t blk_i>
-    __device__ static auto
-    CalculateCThreadOriginDataIndex8D(Number<m0>, Number<n0>, Number<WMMA_i>, Number<blk_i>)
-    {
-        const auto wave_idx = GetWaveIdx();
-
-        const auto waveId_m = wave_idx[I0];
-        const auto waveId_n = wave_idx[I1];
-
-        const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk4D(WMMA_i, blk_i);
-
-        return make_tuple(Number<m0>{},
-                          Number<n0>{},
-                          waveId_m,
-                          waveId_n,
-                          blk_idx[I0],
-                          blk_idx[I1],
-                          blk_idx[I2],
-                          blk_idx[I3]);
-    }
-
     __host__ __device__ BlockwiseGemmWMMA_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1()
     {
         static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&

@@ -162,59 +143,31 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2
         static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 &&
                           NPerBlock % (NPerWMMA * NRepeat) == 0,
                       "wrong!");
     }

-    __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
-    {
-        constexpr auto c_m0_m1_m2_n_tblk_lens = wmma_gemm.GetCM0M1M2NThreadBlkLengths();
-
-        constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
-        constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
-        constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
-        constexpr auto N  = c_m0_m1_m2_n_tblk_lens[I3];
-
-        return make_naive_tensor_descriptor_packed(
-            make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
-    }
-
-    __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
+    // Thread level, register descriptor.
+    __host__ __device__ static constexpr auto
+    GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
     {
-        constexpr auto c_m0_m1_m2_n_tblk_lens = wmma_gemm.GetCM0M1M2NThreadBlkLengths();
+        constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
+            wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();

-        constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
-        constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
-        constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
-        constexpr auto N  = c_m0_m1_m2_n_tblk_lens[I3];
+        constexpr auto MSubGroup          = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I0];
+        constexpr auto NThreadPerSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I1];
+        constexpr auto MAccVgprs          = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];

         return make_naive_tensor_descriptor_packed(
-            make_tuple(I1, Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
+            //         |MRepeat          |MWave |MSubGroup |NRepeat          |NWave |NThreadPerSubGroup |MAccVgprs
+            make_tuple(Number<MRepeat>{}, I1,   MSubGroup,  Number<NRepeat>{}, I1,  NThreadPerSubGroup, MAccVgprs));
     }

-    __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
+    __host__ __device__ static constexpr auto
+    GetCBlockDescriptor_MRepeat_Mwave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
     {
-        constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
+        constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma =
             make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
-                                                           Number<NRepeat>{},
                                                            Number<MWaves>{},
-                                                           Number<NWaves>{},
                                                            Number<MPerWMMA>{},
+                                                           Number<NRepeat>{},
+                                                           Number<NWaves>{},
                                                            Number<NPerWMMA>{}));

-        return wmma_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2);
-    }
-
-    __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
-    {
-        constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 =
-            make_naive_tensor_descriptor_packed(make_tuple(I1,
-                                                           Number<MRepeat>{},
-                                                           Number<NRepeat>{},
-                                                           Number<MWaves>{},
-                                                           Number<NWaves>{},
-                                                           Number<MPerWMMA>{},
-                                                           Number<NPerWMMA>{}));
-
-        return wmma_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
-            c_block_desc_g_m0_n0_m1_n1_m2_n2);
+        return wmma_gemm.MakeCDesc_MRepeat_Mwave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
+            c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma);
     }

     template <typename CGridDesc_M_N>
@@ -234,32 +187,46 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2
         return wmma_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2);
     }

-    __host__ __device__ static constexpr auto MakeABlockDescriptor_K0_M0_M1_M2_K1()
+    __host__ __device__ static constexpr auto MakeABlockDescriptor_KRepeat_M0_M1_M2_KPack()
     {
-        return transform_tensor_descriptor(
+        static constexpr auto a_block_desc_temp_km0m1m2 = transform_tensor_descriptor(
             AK0MK1BlockDesc{},
-            make_tuple(make_pass_through_transform(make_tuple(Number<A_K0>{}, Number<A_K1>{})),
+            make_tuple(make_merge_transform(make_tuple(Number<A_K0>{}, Number<A_K1>{})),
                        make_unmerge_transform(
                            make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWMMA>{}))),
             make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
-            make_tuple(Sequence<0, 4>{}, Sequence<1, 2, 3>{}));
+            make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}));
+
+        return transform_tensor_descriptor(
+            a_block_desc_temp_km0m1m2,
+            make_tuple(make_unmerge_transform(
+                           make_tuple(Number<A_K0 * A_K1 / KPack>{}, Number<KPack>{})),
+                       make_pass_through_transform(
+                           make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWMMA>{}))),
+            make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}),
+            make_tuple(Sequence<0, 4>{}, Sequence<1, 2, 3>{}));
     }

-    __host__ __device__ static constexpr auto MakeBBlockDescriptor_K0_N0_N1_N2_K1()
+    __host__ __device__ static constexpr auto MakeBBlockDescriptor_KRepeat_N0_N1_N2_KPack()
     {
-        return transform_tensor_descriptor(
+        static constexpr auto b_block_desc_temp_kn0n1n2 = transform_tensor_descriptor(
             BK0NK1BlockDesc{},
-            make_tuple(make_pass_through_transform(make_tuple(Number<B_K0>{}, Number<B_K1>{})),
+            make_tuple(make_merge_transform(make_tuple(Number<B_K0>{}, Number<B_K1>{})),
                        make_unmerge_transform(
                            make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWMMA>{}))),
             make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
-            make_tuple(Sequence<0, 4>{}, Sequence<1, 2, 3>{}));
+            make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}));
+
+        return transform_tensor_descriptor(
+            b_block_desc_temp_kn0n1n2,
+            make_tuple(make_unmerge_transform(
+                           make_tuple(Number<B_K0 * B_K1 / KPack>{}, Number<KPack>{})),
+                       make_pass_through_transform(
+                           make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWMMA>{}))),
+            make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}),
+            make_tuple(Sequence<0, 4>{}, Sequence<1, 2, 3>{}));
     }

-    static constexpr auto a_block_desc_k0_m0_m1_m2_k1 = MakeABlockDescriptor_K0_M0_M1_M2_K1();
-    static constexpr auto b_block_desc_k0_n0_n1_n2_k1 = MakeBBlockDescriptor_K0_N0_N1_N2_K1();
+    static constexpr auto a_block_desc_krepeat_m0_m1_m2_kpack = MakeABlockDescriptor_KRepeat_M0_M1_M2_KPack();
+    static constexpr auto b_block_desc_krepeat_n0_n1_n2_kpack = MakeBBlockDescriptor_KRepeat_N0_N1_N2_KPack();

     template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
     __device__ void Run(const ABlockBuffer& a_block_buf,
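The reworked A/B block descriptors first merge (K0, K1) into a flat K and then split it into (KRepeat, KPack), so each WMMA issue consumes one KPack-wide slice. A minimal sketch of that index arithmetic, with assumed example sizes (only the K = KRepeat * KPack relationship is taken from the diff):

    #include <cassert>

    // Assumed example sizes; only the relationship K = KRepeat * KPack matters here.
    constexpr int A_K0 = 8, A_K1 = 8;   // block tile holds K0 * K1 = 64 elements along K
    constexpr int KPack = 16;           // elements consumed per WMMA issue

    constexpr int K       = A_K0 * A_K1;  // 64
    constexpr int KRepeat = K / KPack;    // 4 outer iterations per block tile

    static_assert(K % KPack == 0, "K must be divisible by KPack");

    // Flat K index recovered from the (KRepeat, KPack) coordinate pair:
    constexpr int flat_k(int krepeat, int kpack) { return krepeat * KPack + kpack; }
    static_assert(flat_k(3, 15) == 63, "last slice ends at K - 1");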
@@ -298,7 +265,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2
                         b_thread_vec.template AsType<wmma_input_type>(),
                         c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                 });

-                a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
+                a_thread_copy_.Run(a_block_desc_krepeat_m0_m1_m2_kpack,
                                    make_tuple(Number<iWmmaK>{}, iCut, I0, I0, I0),
                                    a_block_buf,
                                    a_thread_desc_,

@@ -328,7 +295,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2
                         b_thread_vec.template AsType<wmma_input_type>(),
                         c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                 });

-                a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
+                a_thread_copy_.Run(a_block_desc_krepeat_m0_m1_m2_kpack,
                                    make_tuple(Number<iWmmaK>{}, WmmaInnerloop + RepeatDiff, I0, I0, I0),
                                    a_block_buf,
                                    a_thread_desc_,

@@ -355,7 +322,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2
                         b_thread_vec.template AsType<wmma_input_type>(),
                         c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                 });

-                b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
+                b_thread_copy_.Run(b_block_desc_krepeat_n0_n1_n2_kpack,
                                    make_tuple(Number<iWmmaK>{}, WmmaInnerloop, I0, I0, I0),
                                    b_block_buf,
                                    b_thread_desc_,

@@ -380,7 +347,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2
     using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
                                                          FloatAB,
-                                                         decltype(a_block_desc_k0_m0_m1_m2_k1),
+                                                         decltype(a_block_desc_krepeat_m0_m1_m2_kpack),
                                                          decltype(a_thread_desc_),
                                                          Sequence<1, 1, 1, WmmaK>,
                                                          Sequence<0, 1, 2, 3>,

@@ -390,7 +357,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2
     using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
                                                          FloatAB,
-                                                         decltype(b_block_desc_k0_n0_n1_n2_k1),
+                                                         decltype(b_block_desc_krepeat_n0_n1_n2_kpack),
                                                          decltype(b_thread_desc_),
                                                          Sequence<1, 1, 1, WmmaK>,
                                                          Sequence<0, 1, 2, 3>,

@@ -413,11 +380,11 @@ template <index_t BlockSize,
           index_t NRepeat,
           index_t KPack,
           LoopScheduler LoopSched>
-constexpr auto BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2_Selector()
+constexpr auto BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_Selector()
 {
     if constexpr(LoopSched == LoopScheduler::Default)
     {
-        return BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2<BlockSize,
+        return BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3<BlockSize,
                                                            FloatAB,
                                                            FloatAcc,
                                                            AK0MK1BlockDesc,
include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp

@@ -72,9 +72,8 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
     static constexpr auto I2 = Number<2>{};

+    // K1 = Max Vector Access Pixels
     static constexpr auto K1Number = Number<K1>{};
-    static constexpr auto M1Number = Number<M1>{};

     static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
     {

@@ -87,10 +86,12 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
             {
                 return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
             }
+#ifdef ENABLE_COLMAJOR
             else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
             {
                 return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
             }
+#endif
         }();

         if constexpr(GemmSpec == GemmSpecialization::MNPadding)
@@ -154,12 +155,8 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
         }
     }

-    static auto MakeCGridDescriptor_M0_N_M1(index_t M, index_t N, index_t StrideC)
+    static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
     {
-        assert(M % M1 == 0);
-
-        const index_t M0 = M / M1;
-
         const auto c_grid_desc_m_n = [&]() {
             if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
             {

@@ -173,8 +170,6 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
         if constexpr(GemmSpec == GemmSpecialization::MNPadding)
         {
+            static_assert(false, "Padding Gemm Not implemented");
+            /* Not implemented yet.
             const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
             const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;

@@ -183,26 +178,25 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
                 make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)),
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
                 make_tuple(Sequence<0>{}, Sequence<1>{}));
+            */
         }
         else
         {
             return transform_tensor_descriptor(
                 c_grid_desc_m_n,
-                make_tuple(make_unmerge_transform(make_tuple(M0, M1Number)),
-                           make_pass_through_transform(N)),
+                make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)),
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+                make_tuple(Sequence<0>{}, Sequence<1>{}));
         }
     }

+    // Gridwise descriptor, mapping to the whole given problem.
     using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1));
     using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
-    using CGridDesc_M0_N_M1 = decltype(MakeCGridDescriptor_M0_N_M1(1, 1, 1));
+    using CGridDesc_M_N     = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
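With the M0/M1 split removed, the C grid descriptor reduces to a plain strided M x N view. A small stand-in sketch of what that row-major addressing means, using assumed sizes and a hypothetical helper rather than the CK descriptor machinery:

    #include <cstddef>

    // Illustrative stand-in for the row-major C descriptor: element (m, n) of an
    // M x N matrix with leading dimension StrideC lives at offset m * StrideC + n.
    struct RowMajorDesc2D
    {
        int M, N, Stride;
        constexpr std::size_t offset(int m, int n) const
        {
            return static_cast<std::size_t>(m) * Stride + n;
        }
    };

    // Example: a 128 x 96 C matrix stored with StrideC = 96 (assumed values).
    constexpr RowMajorDesc2D c_desc{128, 96, 96};
    static_assert(c_desc.offset(1, 0) == 96, "row stride equals StrideC");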
     // GridwiseGemm
-    using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_m0nm1_wmma_v1r1<
+    using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1<
         BlockSize,
         ADataType, // TODO: distinguish A/B datatype
         AccDataType,

@@ -210,7 +204,7 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
         InMemoryDataOperationEnum::Set,
         AGridDesc_K0_M_K1,
         BGridDesc_K0_N_K1,
-        CGridDesc_M0_N_M1,
+        CGridDesc_M_N,
         AElementwiseOperation,
         BElementwiseOperation,
         CElementwiseOperation,

@@ -238,15 +232,16 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
         BBlockTransferDstScalarPerVector_K1,
         false, // BThreadTransferSrcResetCoordinateAfterRun,
         BBlockLdsAddExtraN,
+#if 0
         Sequence<0, 2, 4, 5, 6, 1, 3, 7>, // CThreadTransferSrcDstAccessOrder,
         CThreadTransferSrcDstVectorDim,
         CThreadTransferDstScalarPerVector,
+#endif
         NumPrefetch,
         LoopSched,
         PipelineVer>;

     // Argument
     struct Argument : public BaseArgument
     {
         Argument(const ADataType* p_a_grid,
                  const BDataType* p_b_grid,

@@ -267,8 +262,8 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
               p_c_grid_{p_c_grid},
               a_grid_desc_k0_m_k1_{},
               b_grid_desc_k0_n_k1_{},
-              c_grid_desc_m0_n_m1_{},
-              c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{},
+              c_grid_desc_m_n_{},
+              c_grid_desc_mblock_mwmmaperwave_mwave_mlanehigh_nblock_nwmmaperwave_nwave_nlane_mlanelow_{},
               block_2_ctile_map_{},
               M01_{M01},
               N01_{N01},

@@ -278,18 +273,18 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
         {
             a_grid_desc_k0_m_k1_ = DeviceGemmWmma::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
             b_grid_desc_k0_n_k1_ = DeviceGemmWmma::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
-            c_grid_desc_m0_n_m1_ = DeviceGemmWmma::MakeCGridDescriptor_M0_N_M1(M, N, StrideC);
+            c_grid_desc_m_n_     = DeviceGemmWmma::MakeCGridDescriptor_M_N(M, N, StrideC);

             block_2_ctile_map_ =
-                GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m0_n_m1_, M01, N01);
+                GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);

             if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_,
                                            b_grid_desc_k0_n_k1_,
-                                           c_grid_desc_m0_n_m1_,
+                                           c_grid_desc_m_n_,
                                            block_2_ctile_map_))
             {
-                c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ =
-                    GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n_m1_);
+                c_grid_desc_mblock_mwmmaperwave_mwave_mlanehigh_nblock_nwmmaperwave_nwave_nlane_mlanelow_ =
+                    GridwiseGemm::MakeCGridDescriptor_MBlock_MWmmaPerWave_Mwave_MLaneHigh_NBlock_NWmmaPerWave_Nwave_NLane_MLaneLow(
+                        c_grid_desc_m_n_);
             }
         }

@@ -299,9 +294,9 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
         CDataType* p_c_grid_;
         AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
         BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
-        CGridDesc_M0_N_M1 c_grid_desc_m0_n_m1_;
-        typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
-            c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
+        CGridDesc_M_N c_grid_desc_m_n_;
+        typename GridwiseGemm::CGridDescriptor_MBlock_MWmmaPerWave_Mwave_MLaneHigh_NBlock_NWmmaPerWave_Nwave_NLane_MLaneLow
+            c_grid_desc_mblock_mwmmaperwave_mwave_mlanehigh_nblock_nwmmaperwave_nwave_nlane_mlanelow_;
         typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
         index_t M01_;
         index_t N01_;

@@ -327,15 +322,15 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
                           << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
                           << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;

-                std::cout << "arg.c_grid_desc_m0_n_m1_{ " << arg.c_grid_desc_m0_n_m1_.GetLength(I0)
-                          << ", " << arg.c_grid_desc_m0_n_m1_.GetLength(I1) << ", "
-                          << arg.c_grid_desc_m0_n_m1_.GetLength(I2) << "}" << std::endl;
+                std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0)
+                          << ", " << arg.c_grid_desc_m_n_.GetLength(I1) << ", "
+                          << arg.c_grid_desc_m_n_.GetLength(I2) << "}" << std::endl;
             }
 #endif

             if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
                                             arg.b_grid_desc_k0_n_k1_,
-                                            arg.c_grid_desc_m0_n_m1_,
+                                            arg.c_grid_desc_m_n_,
                                             arg.block_2_ctile_map_))
             {
                 throw std::runtime_error(

@@ -343,7 +338,7 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
             }

             const index_t grid_size =
-                arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m0_n_m1_);
+                arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);

             const auto K =
                 arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);

@@ -358,7 +353,7 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
                     CDataType,
                     remove_reference_t<DeviceGemmWmma::AGridDesc_K0_M_K1>,
                     remove_reference_t<DeviceGemmWmma::BGridDesc_K0_N_K1>,
-                    remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
+                    remove_reference_t<typename GridwiseGemm::CGridDescriptor_MBlock_MWmmaPerWave_Mwave_MLaneHigh_NBlock_NWmmaPerWave_Nwave_NLane_MLaneLow>,
                     AElementwiseOperation,
                     BElementwiseOperation,
                     CElementwiseOperation,

@@ -375,7 +370,7 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
                     arg.p_c_grid_,
                     arg.a_grid_desc_k0_m_k1_,
                     arg.b_grid_desc_k0_n_k1_,
-                    arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
+                    arg.c_grid_desc_mblock_mwmmaperwave_mwave_mlanehigh_nblock_nwmmaperwave_nwave_nlane_mlanelow_,
                     arg.a_element_op_,
                     arg.b_element_op_,
                     arg.c_element_op_,

@@ -389,7 +384,7 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
                     CDataType,
                     remove_reference_t<DeviceGemmWmma::AGridDesc_K0_M_K1>,
                     remove_reference_t<DeviceGemmWmma::BGridDesc_K0_N_K1>,
-                    remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
+                    remove_reference_t<typename GridwiseGemm::CGridDescriptor_MBlock_MWmmaPerWave_Mwave_MLaneHigh_NBlock_NWmmaPerWave_Nwave_NLane_MLaneLow>,
                     AElementwiseOperation,
                     BElementwiseOperation,
                     CElementwiseOperation,

@@ -406,7 +401,7 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
                     arg.p_c_grid_,
                     arg.a_grid_desc_k0_m_k1_,
                     arg.b_grid_desc_k0_n_k1_,
-                    arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
+                    arg.c_grid_desc_mblock_mwmmaperwave_mwave_mlanehigh_nblock_nwmmaperwave_nwave_nlane_mlanelow_,
                     arg.a_element_op_,
                     arg.b_element_op_,
                     arg.c_element_op_,

@@ -447,7 +442,7 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
         return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
                                            arg.b_grid_desc_k0_n_k1_,
-                                           arg.c_grid_desc_m0_n_m1_,
+                                           arg.c_grid_desc_m_n_,
                                            arg.block_2_ctile_map_);
     }
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_v1r1.hpp

(This diff is collapsed in the page view and is not reproduced here: +181 −453.)
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp

@@ -11,35 +11,107 @@ namespace ck {

 enum struct WmmaInstr
 {
-    wmma_f32_16x16x16_f16_w32   = 0,
-    wmma_f32_16x16x16_bf16_w32  = 0,
-    wmma_f16_16x16x16_f16_w32   = 0,
-    wmma_bf16_16x16x16_bf16_w32 = 0,
-    wmma_i32_16x16x16_iu8_w32   = 0,
-    wmma_i32_16x16x16_iu4_w32   = 0
+    wmma_f32_16x16x16_f16   = 0,
+    wmma_f32_16x16x16_bf16  = 0,
+    wmma_f16_16x16x16_f16   = 0,
+    wmma_bf16_16x16x16_bf16 = 0,
+    wmma_i32_16x16x16_iu8   = 0,
+    wmma_i32_16x16x16_iu4   = 0
 };

-template <WmmaInstr instr>
+/*
+ * WMMA Wave Tile is always MxNxK = 16x16x16.
+ * WAVE32
+   -----------------------------------
+   |RC0| | | | | | | | | | | | | | | |    SubGroup 0
+   |RC1| | | | | | | | | | | | | | | |
+   |RC2| | | | | | | | | | | | | | | |
+   |RC3|T|T|T|T|T|T|T|T|T|T|T|T|T|T|T|
+   |RC4|0|0|0|0|0|0|0|0|0|1|1|1|1|1|1|
+   |RC5|1|2|3|4|5|6|7|8|9|0|1|2|3|4|5|
+   |RC6| | | | | | | | | | | | | | | |
+   |RC7| | | | | | | | | | | | | | | |
+   -----------------------------------
+   | T |T|T|T|T|T|T|T|T|T|T|T|T|T|T|T|    SubGroup 1
+   | 1 |1|1|1|2|2|2|2|2|2|2|2|2|2|3|3|
+   | 6 |7|8|9|0|1|2|3|4|5|6|7|8|9|0|1|
+   -----------------------------------
+ * WAVE64
+   -----------------------------------
+   |RC0|T|T|T|T|T|T|T|T|T|T|T|T|T|T|T|    SubGroup 0
+   |RC1|0|0|0|0|0|0|0|0|0|1|1|1|1|1|1|
+   |RC2|1|2|3|4|5|6|7|8|9|0|1|2|3|4|5|
+   |RC3| | | | | | | | | | | | | | | |
+   -----------------------------------
+   | T |T|T|T|T|T|T|T|T|T|T|T|T|T|T|T|    SubGroup 1
+   | 1 |1|1|1|2|2|2|2|2|2|2|2|2|2|3|3|
+   | 6 |7|8|9|0|1|2|3|4|5|6|7|8|9|0|1|
+   -----------------------------------
+   | T |T|T|T|T|T|T|T|T|T|T|T|T|T|T|T|    SubGroup 2
+   | 3 |3|3|3|3|3|3|3|4|4|4|4|4|4|4|4|
+   | 2 |3|4|5|6|7|8|9|0|1|2|3|4|5|6|7|
+   -----------------------------------
+   | T |T|T|T|T|T|T|T|T|T|T|T|T|T|T|T|    SubGroup 3
+   | 4 |4|5|5|5|5|5|5|5|5|5|5|6|6|6|6|
+   | 8 |9|0|1|2|3|4|5|6|7|8|9|0|1|2|3|
+   -----------------------------------
+ * RC = Register for storing accumulated result
+ * T  = Thread ID
+ */
+template <WmmaInstr Instr,
+          index_t WaveSize,
+          typename enable_if<WaveSize == 32 || WaveSize == 64, bool>::type = false>
 struct wmma_type;

-template <>
-struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16_w32>
+// A-swizzled
+template <index_t WaveSize>
+struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16, WaveSize>
 {
-    static constexpr index_t m_per_wmma           = 16;
-    static constexpr index_t n_per_wmma           = 16;
-    static constexpr index_t k_per_wmma           = 16;
-    static constexpr index_t wave_size            = 32;
-    static constexpr index_t lane_size            = 16;
-    static constexpr index_t src_data_size        = 2;
-    static constexpr index_t acc_data_size        = 4;
-    static constexpr index_t num_srcregs_per_wmma = 8;
-    static constexpr index_t num_accregs_per_wmma = 8;
+    // Absolute fixed property
+    // * Data Pixel
+    static constexpr index_t m_per_wmma      = 16;
+    static constexpr index_t n_per_wmma      = 16;
+    static constexpr index_t k_per_wmma      = 16;
+    static constexpr index_t src_a_data_size = 2;
+    static constexpr index_t src_b_data_size = 2;
+    static constexpr index_t acc_data_size   = 4;
+    // * Thread mapping inside wave; num_thread_per_subgroups is always along the N direction
+    static constexpr index_t num_thread_per_subgroups = n_per_wmma;
+
+    // Wave mode dependent property
+    static constexpr index_t wave_size = Number<WaveSize>{};
+    // * Fixed in Navi3x, will be wave mode dependent on Navi4x
+    static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
+    static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
+    // * num_acc_vgprs_per_wave and num_subgroups are along the M direction
+    static constexpr index_t num_acc_vgprs_per_wave =
+        m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
+    static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;

     template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
     __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
     {
-        intrin_wmma_f32_16x16x16_f16_w32<MPerWmma, NPerWmma>::Run(a, b, reg_c);
+        if constexpr(wave_size == 32)
+        {
+            intrin_wmma_f32_16x16x16_f16_w32<MPerWmma, NPerWmma>::Run(a, b, reg_c);
+        }
+        else if constexpr(wave_size == 64)
+        {
+            intrin_wmma_f32_16x16x16_f16_w64<MPerWmma, NPerWmma>::Run(a, b, reg_c);
+        }
     }
 };
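The wave-mode dependent properties above are plain arithmetic; the following standalone sketch restates the same formulas for both wave sizes (it mirrors the expressions in the diff, it is not the CK header itself):

    #include <cstdio>

    // Restates the wmma_type formulas for the f32 = f16 x f16 instruction.
    constexpr int m_per_wmma = 16, n_per_wmma = 16;
    constexpr int acc_data_size = 4;                      // fp32 accumulator, in bytes
    constexpr int num_thread_per_subgroups = n_per_wmma;  // 16 lanes along N

    constexpr int acc_vgprs_per_wave(int wave_size)
    {
        // 16 x 16 fp32 results spread over the wave, 4 bytes per 32-bit VGPR lane.
        return m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
    }

    constexpr int subgroups(int wave_size) { return wave_size / num_thread_per_subgroups; }

    static_assert(acc_vgprs_per_wave(32) == 8 && subgroups(32) == 2, "wave32: 8 acc VGPRs, 2 subgroups");
    static_assert(acc_vgprs_per_wave(64) == 4 && subgroups(64) == 4, "wave64: 4 acc VGPRs, 4 subgroups");

    int main() { std::printf("wave32 -> %d acc VGPRs per lane\n", acc_vgprs_per_wave(32)); }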
@@ -51,54 +123,54 @@ struct WmmaSelector
     template <>
     static constexpr auto GetWmma<half_t, float, 16, 16>()
     {
-        return WmmaInstr::wmma_f32_16x16x16_f16_w32;
+        return WmmaInstr::wmma_f32_16x16x16_f16;
     }

     template <>
     static constexpr auto GetWmma<bhalf_t, float, 16, 16>()
     {
-        return WmmaInstr::wmma_f32_16x16x16_bf16_w32;
+        return WmmaInstr::wmma_f32_16x16x16_bf16;
     }

     template <>
     static constexpr auto GetWmma<half_t, half_t, 16, 16>()
     {
-        return WmmaInstr::wmma_f16_16x16x16_f16_w32;
+        return WmmaInstr::wmma_f16_16x16x16_f16;
     }

     template <>
     static constexpr auto GetWmma<bhalf_t, bhalf_t, 16, 16>()
     {
-        return WmmaInstr::wmma_bf16_16x16x16_bf16_w32;
+        return WmmaInstr::wmma_bf16_16x16x16_bf16;
     }

     template <>
     static constexpr auto GetWmma<int8_t, float, 16, 16>()
     {
-        return WmmaInstr::wmma_i32_16x16x16_iu8_w32;
+        return WmmaInstr::wmma_i32_16x16x16_iu8;
     }

 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
     template <>
     static constexpr auto GetWmma<int4_t, float, 16, 16>()
     {
-        return WmmaInstr::wmma_i32_16x16x16_iu4_w32;
+        return WmmaInstr::wmma_i32_16x16x16_iu4;
     }
 #endif

-    static constexpr auto selected_wmma =
-        wmma_type<GetWmma<src_type, dst_type, MPerWmma, NPerWmma>()>{};
+    static constexpr auto selected_wmma =
+        wmma_type<GetWmma<src_type, dst_type, MPerWmma, NPerWmma>(), get_warp_size()>{};

     __host__ __device__ constexpr WmmaSelector()
     {
-        static_assert(selected_wmma.m_per_wmma == selected_wmma.n_per_wmma,
-                      "WRONG! WMMA_M must equal to WMMA_N");
+        static_assert(selected_wmma.m_per_wmma == 16, "WRONG! WMMA_M must equal to 16");

-        static_assert(selected_wmma.m_per_wmma == selected_wmma.k_per_wmma,
-                      "WRONG! WMMA_M must equal to WMMA_K");
+        static_assert(selected_wmma.m_per_wmma == 16, "WRONG! WMMA_M must equal to 16");

-        static_assert(selected_wmma.k_per_wmma == 16,
-                      "WRONG! WMMA_M must equal to WMMA_N");
+        static_assert(selected_wmma.k_per_wmma == 16, "WRONG! WMMA_K must equal to 16");

-        static_assert(selected_wmma.wave_size * selected_wmma.num_accregs_per_wmma *
+        static_assert(selected_wmma.wave_size * selected_wmma.num_acc_vgprs_per_wave *
                           selected_wmma.acc_data_size ==
                           selected_wmma.m_per_wmma * selected_wmma.n_per_wmma * 4,
                       "WRONG! Number of Accumulator Register");
@@ -135,26 +207,26 @@ struct WmmaGemm
     }

     // XDL output supporting C = A * B
-    // M2_N2 -> M2_M3_M4_N2
-    template <typename CDesc_M0_N0_M1_N1_M2_N2>
+    // MPerWMMA_NPerWMMA -> MSubGroup_..._NPerWMMA_MAccVgprPerWave
+    template <typename CDesc_MRepeat_Mwave_MPerWMMA_NRepeat_NWave_NPerWMMA>
     __host__ __device__ static constexpr auto
-    MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)
+    MakeCDesc_MRepeat_Mwave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
+        const CDesc_MRepeat_Mwave_MPerWMMA_NRepeat_NWave_NPerWMMA&
+            c_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma)
     {
-        const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
-        const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
-        const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
-        const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
+        const auto MRepeat = c_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma.GetLength(I0);
+        const auto MWave   = c_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma.GetLength(I1);
+        const auto NRepeat = c_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma.GetLength(I3);
+        const auto NWave   = c_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma.GetLength(I4);

         return transform_tensor_descriptor(
-            c_desc_m0_n0_m1_n1_m2_n2,
-            make_tuple(make_pass_through_transform(M0),
-                       make_pass_through_transform(N0),
-                       make_pass_through_transform(M1),
-                       make_pass_through_transform(N1),
-                       make_unmerge_transform(make_tuple(Number<wmma_instr.num_groups_per_blk>{},
-                                                         Number<wmma_instr.num_input_blks>{},
-                                                         Number<wmma_instr.group_size>{})),
-                       make_pass_through_transform(Number<wmma_instr.num_threads_per_blk>{})),
+            c_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma,
+            make_tuple(make_pass_through_transform(MRepeat),
+                       make_pass_through_transform(MWave),
+                       make_unmerge_transform(make_tuple(Number<wmma_instr.num_subgroups>{},
+                                                         Number<wmma_instr.num_acc_vgprs_per_wave>{})),
+                       make_pass_through_transform(NRepeat),
+                       make_pass_through_transform(NWave),
+                       make_pass_through_transform(Number<wmma_instr.num_thread_per_subgroups>{})),
             make_tuple(Sequence<0>{},
                        Sequence<1>{},
                        Sequence<2>{},
@@ -163,91 +235,22 @@ struct WmmaGemm
                        Sequence<5>{}),
-            make_tuple(Sequence<0>{},
-                       Sequence<1>{},
-                       Sequence<2>{},
-                       Sequence<3>{},
-                       Sequence<4>{},
-                       Sequence<5, 6, 7>{}));
+            make_tuple(Sequence<0>{},
+                       Sequence<1>{},
+                       Sequence<2, 6>{},
+                       Sequence<3>{},
+                       Sequence<4, 5, 6>{},
+                       Sequence<7>{}));
     }

-    // transposed XDL output supporting C' = B' * A'
-    // M2_N2 -> M2_N2_N3_N4
-    template <typename CDesc_M0_N0_M1_N1_M2_N2>
-    __host__ __device__ static constexpr auto
-    MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)
-    {
-        const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
-        const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
-        const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
-        const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
-
-        return transform_tensor_descriptor(
-            c_desc_m0_n0_m1_n1_m2_n2,
-            make_tuple(make_pass_through_transform(M0),
-                       make_pass_through_transform(N0),
-                       make_pass_through_transform(M1),
-                       make_pass_through_transform(N1),
-                       make_pass_through_transform(Number<wmma_instr.num_threads_per_blk>{}),
-                       make_unmerge_transform(make_tuple(Number<wmma_instr.num_groups_per_blk>{},
-                                                         Number<wmma_instr.num_input_blks>{},
-                                                         Number<wmma_instr.group_size>{}))),
-            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}, Sequence<5>{}),
-            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}, Sequence<5, 6, 7>{}));
-    }
-
-    template <typename CDesc_G_M0_N0_M1_N1_M2_N2>
-    __host__ __device__ static constexpr auto
-    MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_G_M0_N0_M1_N1_M2_N2& c_desc_g_m0_n0_m1_n1_m2_n2)
-    {
-        const auto G  = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I0);
-        const auto M0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I1);
-        const auto N0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I2);
-        const auto M1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I3);
-        const auto N1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I4);
-
-        return transform_tensor_descriptor(
-            c_desc_g_m0_n0_m1_n1_m2_n2,
-            make_tuple(make_pass_through_transform(G),
-                       make_pass_through_transform(M0),
-                       make_pass_through_transform(N0),
-                       make_pass_through_transform(M1),
-                       make_pass_through_transform(N1),
-                       make_unmerge_transform(make_tuple(wmma_instr.num_groups_per_blk,
-                                                         wmma_instr.num_input_blks,
-                                                         wmma_instr.group_size)),
-                       make_pass_through_transform(wmma_instr.num_threads_per_blk)),
-            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}, Sequence<5>{}, Sequence<6>{}),
-            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}, Sequence<5, 6, 7>{}, Sequence<8>{}));
-    }
-
-    __device__ static constexpr index_t GetRegSizePerXdlops()
+    __device__ static constexpr index_t GetRegSizePerWmma()
     {
-        return MPerWmma * NPerWmma / wmma_instr.wave_size;
+        return wmma_instr.num_acc_vgprs_per_wave;
     }

     __device__ static constexpr index_t GetWaveSize()
     {
         return wmma_instr.wave_size;
     }

-    __device__ static constexpr index_t GetWaveSize()
-    {
-        return wmma_instr.wave_size;
-    }
-
     template <class FloatA, class FloatB, class FloatC>
     __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
     {
@@ -272,67 +275,50 @@ struct WmmaGemm
         }
     }

     __device__ static auto GetLaneId() { return get_thread_local_1d_id() % wmma_instr.wave_size; }

-    __device__ static auto GetLaneIdHigh()
+    __device__ static auto GetSubGroupId()
     {
-        return GetLaneId() / 16;
+        return (GetLaneId() / wmma_instr.num_thread_per_subgroups) % wmma_instr.num_subgroups;
     }

-    __device__ static auto GetLaneIdLow()
+    __device__ static auto GetLaneIdUnderSubGroup()
     {
-        return GetLaneId() % 16;
+        return GetLaneId() % wmma_instr.num_thread_per_subgroups;
     }

     __device__ static auto GetSwizzledLaneIdLow()
     {
-        return ((GetLaneIdLow() & 1) << 3) | (GetLaneIdLow() >> 1);
+        return ((GetLaneIdUnderSubGroup() & 1) << 3) | (GetLaneIdUnderSubGroup() >> 1);
     }

     __host__ __device__ static auto CalculateAThreadOriginDataIndex()
     {
-        return make_tuple(0, GetSwizzledLaneIdLow());
+        return GetSwizzledLaneIdLow();
     }

     __host__ __device__ static auto CalculateBThreadOriginDataIndex()
     {
-        return make_tuple(0, GetLaneIdLow());
+        return GetLaneIdUnderSubGroup();
     }

-    __device__ static CIndex GetBeginOfThreadBlk(index_t xdlops_i, index_t blk_i)
+    __device__ static CIndex GetBeginOfThreadBlk()
     {
-        const auto blk_idx = GetBlkIdx();
-
-        const auto blk_id = blk_idx[I0];
-        const auto blk_td = blk_idx[I1];
-
-        index_t n_offset = blk_i * wmma_instr.n_per_blk + blk_td;
-        index_t m_offset = xdlops_i * wmma_instr.m_per_blk + blk_id * wmma_instr.group_size;
+        index_t n_offset = GetLaneIdUnderSubGroup();
+        index_t m_offset = GetSubGroupId() * wmma_instr.num_acc_vgprs_per_wave;

         return TransposeC ? CIndex{n_offset, m_offset} : CIndex{m_offset, n_offset};
     }

-    __device__ static CIndex4D GetBeginOfThreadBlk4D(index_t /* xdlops_i */, index_t /* blk_i */)
-    {
-        const auto blk_idx = GetBlkIdx();
-
-        const auto blk_id = blk_idx[I0];
-        const auto blk_td = blk_idx[I1];
-
-        return TransposeC ? CIndex4D{blk_td, I0, blk_id, I0} : CIndex4D{I0, blk_id, I0, blk_td};
-    }
-
     static constexpr auto wmma       = WmmaSelector<src_type, dst_type, MPerWmma, NPerWmma>{};
     static constexpr auto wmma_instr = wmma.selected_wmma;

-    static constexpr auto KPerXdlops  = wmma.GetKPerXdlops();
-    static constexpr auto K1PerXdlops = wmma.GetK1PerXdlops();
-    static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops;
-
-    __host__ __device__ static constexpr auto GetCM0M1M2NThreadBlkLengths()
+    __host__ __device__ static constexpr auto GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths()
     {
-        return make_tuple(
-            Number<wmma_instr.num_groups_per_blk>{}, I1, Number<wmma_instr.group_size>{}, I1);
+        return make_tuple(I1, I1, Number<wmma_instr.num_acc_vgprs_per_wave>{});
     }
 };
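The renamed lane-id helpers reduce to a few bit and modulo operations. The sketch below mimics them on the host so the mapping can be sanity-checked; the wave32 constants (2 subgroups of 16 lanes, 8 accumulator VGPRs) are the assumed f32_16x16x16_f16 values from above, not read from the headers:

    #include <cassert>

    // Host-side mimic of the device-side lane helpers, assuming wave32.
    constexpr int num_thread_per_subgroups = 16;
    constexpr int num_subgroups            = 2;
    constexpr int num_acc_vgprs_per_wave   = 8;

    constexpr int sub_group_id(int lane)         { return (lane / num_thread_per_subgroups) % num_subgroups; }
    constexpr int lane_under_sub_group(int lane) { return lane % num_thread_per_subgroups; }

    // Swizzle used for the A operand: the low bit becomes bit 3, the rest shifts down.
    constexpr int swizzled_lane_low(int lane)
    {
        const int l = lane_under_sub_group(lane);
        return ((l & 1) << 3) | (l >> 1);
    }

    // C-tile origin of a lane, mirroring GetBeginOfThreadBlk (non-transposed case).
    constexpr int c_row(int lane) { return sub_group_id(lane) * num_acc_vgprs_per_wave; }
    constexpr int c_col(int lane) { return lane_under_sub_group(lane); }

    static_assert(swizzled_lane_low(1) == 8 && swizzled_lane_low(2) == 1, "odd lanes land in the high half");
    static_assert(c_row(16) == 8 && c_col(16) == 0, "lane 16 starts subgroup 1 at row 8");

    int main() { assert(c_row(31) == 8 && c_col(31) == 15); }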
include/ck/utility/amd_wmma.hpp

@@ -8,7 +8,6 @@
 // TODO: Add arch limitation
 namespace ck {

-// wave32 only
 // src: fp16, dst: fp32
 template <index_t MPerWave, index_t NPerWave>
 struct intrin_wmma_f32_16x16x16_f16_w32;

@@ -24,6 +23,20 @@ struct intrin_wmma_f32_16x16x16_f16_w32<16, 16>
     }
 };

+template <index_t MPerWave, index_t NPerWave>
+struct intrin_wmma_f32_16x16x16_f16_w64;
+
+template <>
+struct intrin_wmma_f32_16x16x16_f16_w64<16, 16>
+{
+    template <class FloatC>
+    __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
+    {
+        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w64(
+            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}]);
+    }
+};
+
 // src: bf16, dst: fp32
 template <index_t MPerWave, index_t NPerWave>
 struct intrin_wmma_f32_16x16x16_bf16_w32;
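Functionally, the new wave64 wrapper differs from the existing wave32 one only in accumulator width: 16 x 16 fp32 results over 64 lanes leave 4 floats per lane instead of 8, hence the float4_t view. A tiny check of that arithmetic (assumed constants, no HIP needed):

    // 16 x 16 fp32 accumulator spread across one wave.
    constexpr int acc_elems = 16 * 16;
    static_assert(acc_elems / 32 == 8, "wave32: 8 floats per lane (float8_t accumulator view)");
    static_assert(acc_elems / 64 == 4, "wave64: 4 floats per lane (float4_t accumulator view)");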