Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
0dda6f18
Commit
0dda6f18
authored
Jan 24, 2025
by
Andriy Roshchenko
Browse files
Verify f8f6f4 MFMA Instructions
parent
0ef27d53
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
52 additions
and
44 deletions
+52
-44
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
+52
-44
No files found.
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
View file @
0dda6f18
...
...
@@ -784,17 +784,19 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32bf8f8>
template
<
>
struct
mfma_type
<
MfmaInstr
::
mfma_f32_32x32x64f8f6f4
>
{
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_per_blk
=
4
;
static
constexpr
index_t
num_regs_per_blk
=
16
;
static
constexpr
index_t
num_threads_per_blk
=
32
;
static
constexpr
index_t
wave_size
=
64
;
static
constexpr
index_t
num_input_blks
=
2
;
static
constexpr
index_t
num_output_blks
=
1
;
static
constexpr
index_t
m_per_blk
=
32
;
static
constexpr
index_t
n_per_blk
=
32
;
static
constexpr
index_t
k_per_blk
=
8
;
static
constexpr
bool
is_k_reduction
=
true
;
// clang-format off
static
constexpr
index_t
group_size
=
4
;
// ??? group_size * num_groups_per_blk == num_regs_per_blk
static
constexpr
index_t
num_groups_per_blk
=
4
;
// ??? group_size * num_groups_per_blk == num_regs_per_blk
static
constexpr
index_t
num_regs_per_blk
=
16
;
// m_per_blk * n_per_blk / wave_size
static
constexpr
index_t
num_threads_per_blk
=
32
;
// n_per_blk
static
constexpr
index_t
wave_size
=
64
;
// fixed
static
constexpr
index_t
num_input_blks
=
2
;
// m_per_blk / num_regs_per_blk
static
constexpr
index_t
num_output_blks
=
1
;
// (is_k_reduction == true) ???
static
constexpr
index_t
m_per_blk
=
32
;
// from the instruction
static
constexpr
index_t
n_per_blk
=
32
;
// from the instruction
static
constexpr
index_t
k_per_blk
=
32
;
// (is_k_reduction == true) ? 64 / num_input_blks
static
constexpr
bool
is_k_reduction
=
true
;
// ???
// clang-format on
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
FloatC
&
reg_c
)
const
...
...
@@ -806,17 +808,19 @@ struct mfma_type<MfmaInstr::mfma_f32_32x32x64f8f6f4>
template
<
>
struct
mfma_type
<
MfmaInstr
::
mfma_f32_16x16x128f8f6f4
>
{
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_per_blk
=
1
;
static
constexpr
index_t
num_regs_per_blk
=
4
;
static
constexpr
index_t
num_threads_per_blk
=
16
;
static
constexpr
index_t
wave_size
=
64
;
static
constexpr
index_t
num_input_blks
=
4
;
static
constexpr
index_t
num_output_blks
=
1
;
static
constexpr
index_t
m_per_blk
=
16
;
static
constexpr
index_t
n_per_blk
=
16
;
static
constexpr
index_t
k_per_blk
=
8
;
static
constexpr
bool
is_k_reduction
=
true
;
// clang-format off
static
constexpr
index_t
group_size
=
4
;
// ??? group_size * num_groups_per_blk == num_regs_per_blk
static
constexpr
index_t
num_groups_per_blk
=
1
;
// ??? group_size * num_groups_per_blk == num_regs_per_blk
static
constexpr
index_t
num_regs_per_blk
=
4
;
// m_per_blk * n_per_blk / wave_size
static
constexpr
index_t
num_threads_per_blk
=
16
;
// == n_per_blk
static
constexpr
index_t
wave_size
=
64
;
// fixed
static
constexpr
index_t
num_input_blks
=
4
;
// m_per_blk / num_regs_per_blk
static
constexpr
index_t
num_output_blks
=
1
;
// (is_k_reduction == true) ???
static
constexpr
index_t
m_per_blk
=
16
;
// from the instruction
static
constexpr
index_t
n_per_blk
=
16
;
// from the instruction
static
constexpr
index_t
k_per_blk
=
32
;
// (is_k_reduction == true) ? 128 / num_input_blks
static
constexpr
bool
is_k_reduction
=
true
;
// ???
// clang-format on
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
FloatC
&
reg_c
)
const
...
...
@@ -828,17 +832,19 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x128f8f6f4>
template
<
>
struct
mfma_type
<
MfmaInstr
::
mfma_scale_f32_32x32x64f8f6f4
>
{
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_per_blk
=
4
;
static
constexpr
index_t
num_regs_per_blk
=
16
;
static
constexpr
index_t
num_threads_per_blk
=
32
;
static
constexpr
index_t
wave_size
=
64
;
static
constexpr
index_t
num_input_blks
=
2
;
static
constexpr
index_t
num_output_blks
=
1
;
static
constexpr
index_t
m_per_blk
=
32
;
static
constexpr
index_t
n_per_blk
=
32
;
static
constexpr
index_t
k_per_blk
=
8
;
static
constexpr
bool
is_k_reduction
=
true
;
// clang-format off
static
constexpr
index_t
group_size
=
4
;
// ??? group_size * num_groups_per_blk == num_regs_per_blk
static
constexpr
index_t
num_groups_per_blk
=
4
;
// ??? group_size * num_groups_per_blk == num_regs_per_blk
static
constexpr
index_t
num_regs_per_blk
=
16
;
// m_per_blk * n_per_blk / wave_size
static
constexpr
index_t
num_threads_per_blk
=
32
;
// n_per_blk
static
constexpr
index_t
wave_size
=
64
;
// fixed
static
constexpr
index_t
num_input_blks
=
2
;
// m_per_blk / num_regs_per_blk
static
constexpr
index_t
num_output_blks
=
1
;
// (is_k_reduction == true) ???
static
constexpr
index_t
m_per_blk
=
32
;
// from the instruction
static
constexpr
index_t
n_per_blk
=
32
;
// from the instruction
static
constexpr
index_t
k_per_blk
=
32
;
// (is_k_reduction == true) ? 64 / num_input_blks
static
constexpr
bool
is_k_reduction
=
true
;
// ???
// clang-format on
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
FloatC
&
reg_c
)
const
...
...
@@ -850,17 +856,19 @@ struct mfma_type<MfmaInstr::mfma_scale_f32_32x32x64f8f6f4>
template
<
>
struct
mfma_type
<
MfmaInstr
::
mfma_scale_f32_16x16x128f8f6f4
>
{
static
constexpr
index_t
group_size
=
4
;
static
constexpr
index_t
num_groups_per_blk
=
1
;
static
constexpr
index_t
num_regs_per_blk
=
4
;
static
constexpr
index_t
num_threads_per_blk
=
16
;
static
constexpr
index_t
wave_size
=
64
;
static
constexpr
index_t
num_input_blks
=
4
;
static
constexpr
index_t
num_output_blks
=
1
;
static
constexpr
index_t
m_per_blk
=
16
;
static
constexpr
index_t
n_per_blk
=
16
;
static
constexpr
index_t
k_per_blk
=
8
;
static
constexpr
bool
is_k_reduction
=
true
;
// clang-format off
static
constexpr
index_t
group_size
=
4
;
// ??? group_size * num_groups_per_blk == num_regs_per_blk
static
constexpr
index_t
num_groups_per_blk
=
1
;
// ??? group_size * num_groups_per_blk == num_regs_per_blk
static
constexpr
index_t
num_regs_per_blk
=
4
;
// m_per_blk * n_per_blk / wave_size
static
constexpr
index_t
num_threads_per_blk
=
16
;
// == n_per_blk
static
constexpr
index_t
wave_size
=
64
;
// fixed
static
constexpr
index_t
num_input_blks
=
4
;
// m_per_blk / num_regs_per_blk
static
constexpr
index_t
num_output_blks
=
1
;
// (is_k_reduction == true) ???
static
constexpr
index_t
m_per_blk
=
16
;
// from the instruction
static
constexpr
index_t
n_per_blk
=
16
;
// from the instruction
static
constexpr
index_t
k_per_blk
=
32
;
// (is_k_reduction == true) ? 128 / num_input_blks
static
constexpr
bool
is_k_reduction
=
true
;
// ???
// clang-format on
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
FloatC
&
reg_c
)
const
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment