Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
370c9245
Commit
370c9245
authored
Aug 18, 2021
by
Jing Zhang
Browse files
change mfma_info
parent
c982e753
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
37 additions
and
38 deletions
+37
-38
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
...kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
+3
-7
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
+34
-31
No files found.
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
View file @
370c9245
...
...
@@ -35,8 +35,6 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static
constexpr
auto
xdlops_gemm
=
XdlopsGemm
<
FloatAB
,
MPerXDL
,
NPerXDL
,
K1
>
{};
static
constexpr
auto
CXdlopsLayout
=
xdlops_gemm
.
GetCXdlopsLayout
();
static
constexpr
index_t
MWaves
=
MPerBlock
/
(
MRepeat
*
MPerXDL
);
static
constexpr
index_t
NWaves
=
NPerBlock
/
(
NRepeat
*
NPerXDL
);
...
...
@@ -116,15 +114,13 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static_assert
(
MPerBlock
%
(
MPerXDL
*
MRepeat
)
==
0
&&
NPerBlock
%
(
NPerXDL
*
NRepeat
)
==
0
,
"wrong!"
);
constexpr
index_t
NumBlks
=
CXdlopsLayout
.
GetNumBlks
();
constexpr
index_t
NumXdlops
=
CXdlopsLayout
.
GetNumXdlops
();
static_assert
(
NumBlks
==
1
&&
NumXdlops
==
1
,
"K Reduction Mfma only"
);
}
__host__
__device__
static
constexpr
auto
GetCM0N0M1N1M2M3M4N2ThreadDescriptor
()
{
///\to-do: hide xdl clayout into xdlops-gemm
constexpr
auto
CXdlopsLayout
=
xdlops_gemm
.
GetCXdlopsLayout
();
constexpr
auto
M0
=
Number
<
CXdlopsLayout
.
M1
()
>
{};
constexpr
auto
M2
=
Number
<
CXdlopsLayout
.
M0
()
>
{};
...
...
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
View file @
370c9245
...
...
@@ -34,10 +34,10 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x1xf32>
{
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_per_blk
=
4
;
static
constexpr
index_t
num_regs_per_blk
=
group_size
*
num_groups_per_blk
;
static
constexpr
index_t
num_regs_per_blk
=
16
;
static
constexpr
index_t
num_threads_per_blk
=
32
;
static
constexpr
index_t
wave_size
=
64
;
static
constexpr
index_t
num_input_blks
=
wave_size
/
num_threads_per_blk
;
static
constexpr
index_t
num_input_blks
=
2
;
static
constexpr
index_t
num_output_blks
=
2
;
static
constexpr
index_t
m_per_blk
=
32
;
static
constexpr
index_t
n_per_blk
=
32
;
...
...
@@ -61,10 +61,10 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x2xf32>
{
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_per_blk
=
4
;
static
constexpr
index_t
num_regs_per_blk
=
group_size
*
num_groups_per_blk
;
static
constexpr
index_t
num_regs_per_blk
=
16
;
static
constexpr
index_t
num_threads_per_blk
=
32
;
static
constexpr
index_t
wave_size
=
64
;
static
constexpr
index_t
num_input_blks
=
wave_size
/
num_threads_per_blk
;
static
constexpr
index_t
num_input_blks
=
2
;
static
constexpr
index_t
num_output_blks
=
1
;
static
constexpr
index_t
m_per_blk
=
32
;
static
constexpr
index_t
n_per_blk
=
32
;
...
...
@@ -88,10 +88,10 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x4xf32>
{
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_per_blk
=
1
;
static
constexpr
index_t
num_regs_per_blk
=
group_size
*
num_groups_per_blk
;
static
constexpr
index_t
num_regs_per_blk
=
4
;
static
constexpr
index_t
num_threads_per_blk
=
16
;
static
constexpr
index_t
wave_size
=
64
;
static
constexpr
index_t
num_input_blks
=
wave_size
/
num_threads_per_blk
;
static
constexpr
index_t
num_input_blks
=
4
;
static
constexpr
index_t
num_output_blks
=
1
;
static
constexpr
index_t
m_per_blk
=
16
;
static
constexpr
index_t
n_per_blk
=
16
;
...
...
@@ -115,10 +115,10 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x1xf32>
{
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_per_blk
=
1
;
static
constexpr
index_t
num_regs_per_blk
=
group_size
*
num_groups_per_blk
;
static
constexpr
index_t
num_regs_per_blk
=
4
;
static
constexpr
index_t
num_threads_per_blk
=
16
;
static
constexpr
index_t
wave_size
=
64
;
static
constexpr
index_t
num_input_blks
=
wave_size
/
num_threads_per_blk
;
static
constexpr
index_t
num_input_blks
=
4
;
static
constexpr
index_t
num_output_blks
=
4
;
static
constexpr
index_t
m_per_blk
=
16
;
static
constexpr
index_t
n_per_blk
=
16
;
...
...
@@ -143,7 +143,7 @@ struct mfma_info<mfma_instr::mfma_f32_4x4x1xf32>
{
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_per_blk
=
1
;
static
constexpr
index_t
num_regs_per_blk
=
group_size
*
num_groups_per_blk
;
static
constexpr
index_t
num_regs_per_blk
=
4
;
static
constexpr
index_t
num_threads_per_blk
=
64
;
static
constexpr
index_t
wave_size
=
64
;
static
constexpr
index_t
num_input_blks
=
1
;
...
...
@@ -170,10 +170,10 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x4f16>
{
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_per_blk
=
4
;
static
constexpr
index_t
num_regs_per_blk
=
group_size
*
num_groups_per_blk
;
static
constexpr
index_t
num_regs_per_blk
=
16
;
static
constexpr
index_t
num_threads_per_blk
=
32
;
static
constexpr
index_t
wave_size
=
64
;
static
constexpr
index_t
num_input_blks
=
wave_size
/
num_threads_per_blk
;
static
constexpr
index_t
num_input_blks
=
2
;
static
constexpr
index_t
num_output_blks
=
2
;
static
constexpr
index_t
m_per_blk
=
32
;
static
constexpr
index_t
n_per_blk
=
32
;
...
...
@@ -197,10 +197,10 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x8f16>
{
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_per_blk
=
4
;
static
constexpr
index_t
num_regs_per_blk
=
group_size
*
num_groups_per_blk
;
static
constexpr
index_t
num_regs_per_blk
=
16
;
static
constexpr
index_t
num_threads_per_blk
=
32
;
static
constexpr
index_t
wave_size
=
64
;
static
constexpr
index_t
num_input_blks
=
wave_size
/
num_threads_per_blk
;
static
constexpr
index_t
num_input_blks
=
2
;
static
constexpr
index_t
num_output_blks
=
1
;
static
constexpr
index_t
m_per_blk
=
32
;
static
constexpr
index_t
n_per_blk
=
32
;
...
...
@@ -224,10 +224,10 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x16f16>
{
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_per_blk
=
1
;
static
constexpr
index_t
num_regs_per_blk
=
group_size
*
num_groups_per_blk
;
static
constexpr
index_t
num_regs_per_blk
=
4
;
static
constexpr
index_t
num_threads_per_blk
=
16
;
static
constexpr
index_t
wave_size
=
64
;
static
constexpr
index_t
num_input_blks
=
wave_size
/
num_threads_per_blk
;
static
constexpr
index_t
num_input_blks
=
4
;
static
constexpr
index_t
num_output_blks
=
1
;
static
constexpr
index_t
m_per_blk
=
16
;
static
constexpr
index_t
n_per_blk
=
16
;
...
...
@@ -251,10 +251,10 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x4f16>
{
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_per_blk
=
1
;
static
constexpr
index_t
num_regs_per_blk
=
group_size
*
num_groups_per_blk
;
static
constexpr
index_t
num_regs_per_blk
=
4
;
static
constexpr
index_t
num_threads_per_blk
=
16
;
static
constexpr
index_t
wave_size
=
64
;
static
constexpr
index_t
num_input_blks
=
wave_size
/
num_threads_per_blk
;
static
constexpr
index_t
num_input_blks
=
4
;
static
constexpr
index_t
num_output_blks
=
4
;
static
constexpr
index_t
m_per_blk
=
16
;
static
constexpr
index_t
n_per_blk
=
16
;
...
...
@@ -278,7 +278,7 @@ struct mfma_info<mfma_instr::mfma_f32_4x4x4f16>
{
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_per_blk
=
1
;
static
constexpr
index_t
num_regs_per_blk
=
group_size
*
num_groups_per_blk
;
static
constexpr
index_t
num_regs_per_blk
=
4
;
static
constexpr
index_t
num_threads_per_blk
=
64
;
static
constexpr
index_t
wave_size
=
64
;
static
constexpr
index_t
num_input_blks
=
1
;
...
...
@@ -306,10 +306,10 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x2bf16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 4;
static constexpr index_t num_regs_per_blk =
group_size * num_groups_per_blk
;
static constexpr index_t num_regs_per_blk =
16
;
static constexpr index_t num_threads_per_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks =
wave_size / num_threads_per_blk
;
static constexpr index_t num_input_blks =
2
;
static constexpr index_t num_output_blks = 2;
static constexpr index_t m_per_blk = 32;
static constexpr index_t n_per_blk = 32;
...
...
@@ -338,10 +338,10 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x4bf16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 4;
static constexpr index_t num_regs_per_blk =
group_size * num_groups_per_blk
;
static constexpr index_t num_regs_per_blk =
16
;
static constexpr index_t num_threads_per_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks =
wave_size / num_threads_per_blk
;
static constexpr index_t num_input_blks =
2
;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 32;
static constexpr index_t n_per_blk = 32;
...
...
@@ -369,10 +369,10 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x8bf16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_per_blk =
group_size * num_groups_per_blk
;
static constexpr index_t num_regs_per_blk =
4
;
static constexpr index_t num_threads_per_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks =
wave_size / num_threads_per_blk
;
static constexpr index_t num_input_blks =
4
;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 16;
static constexpr index_t n_per_blk = 16;
...
...
@@ -400,10 +400,10 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x2bf16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_per_blk =
group_size * num_groups_per_blk
;
static constexpr index_t num_regs_per_blk =
4
;
static constexpr index_t num_threads_per_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks =
wave_size / num_threads_per_blk
;
static constexpr index_t num_input_blks =
4
;
static constexpr index_t num_output_blks = 4;
static constexpr index_t m_per_blk = 16;
static constexpr index_t n_per_blk = 16;
...
...
@@ -431,7 +431,7 @@ struct mfma_info<mfma_instr::mfma_f32_4x4x2bf16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_per_blk =
group_size * num_groups_per_blk
;
static constexpr index_t num_regs_per_blk =
4
;
static constexpr index_t num_threads_per_blk = 64;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 1;
...
...
@@ -659,6 +659,8 @@ struct XdlopsGemm
__host__
__device__
static
void
mfma_info_check
()
{
static_assert
(
mfma_type
.
group_size
*
mfma_type
.
num_groups_per_blk
==
mfma_type
.
num_regs_per_blk
,
"wrong! num_regs_per_blk"
);
static_assert
(
mfma_type
.
num_threads_per_blk
==
mfma_type
.
n_per_blk
,
"n_per_blk != num_threads_per_blk"
);
...
...
@@ -745,8 +747,9 @@ struct XdlopsGemm
__device__
static
auto
GetLaneId
()
{
return
get_thread_local_1d_id
()
%
mfma_type
.
wave_size
;
}
__device__
static
auto
GetBlkIdx
(
const
index_t
laneId
)
__device__
static
auto
GetBlkIdx
()
{
const
auto
laneId
=
GetLaneId
();
const
auto
threadidx_to_blk_idx_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
1
,
mfma_type
.
num_input_blks
,
mfma_type
.
num_threads_per_blk
))),
...
...
@@ -765,7 +768,7 @@ struct XdlopsGemm
__host__
__device__
static
auto
CalculateAThreadOriginDataIndex
()
{
const
auto
laneId
=
GetLaneId
();
const
auto
blk_idx
=
GetBlkIdx
(
laneId
);
const
auto
blk_idx
=
GetBlkIdx
();
const
auto
blk_id
=
blk_idx
[
I0
];
const
auto
blk_td
=
blk_idx
[
I1
];
...
...
@@ -783,7 +786,7 @@ struct XdlopsGemm
__host__
__device__
static
auto
CalculateBThreadOriginDataIndex
()
{
const
auto
laneId
=
GetLaneId
();
const
auto
blk_idx
=
GetBlkIdx
(
laneId
);
const
auto
blk_idx
=
GetBlkIdx
();
const
auto
blk_id
=
blk_idx
[
I0
];
const
auto
blk_td
=
blk_idx
[
I1
];
...
...
@@ -801,7 +804,7 @@ struct XdlopsGemm
__device__
static
CIndex
GetBeginOfThreadBlk
(
index_t
xdlops_i
,
index_t
blk_i
)
{
const
auto
laneId
=
GetLaneId
();
const
auto
blk_idx
=
GetBlkIdx
(
laneId
);
const
auto
blk_idx
=
GetBlkIdx
();
const
auto
blk_id
=
blk_idx
[
I0
];
const
auto
blk_td
=
blk_idx
[
I1
];
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment