gaoqiong / composable_kernel_ROCM · Commits

Commit a619e3f5, authored Jan 29, 2025 by Andriy Roshchenko

Fix f8f6f4 MFMA instructions

parent 0dda6f18

Showing 1 changed file with 71 additions and 29 deletions:
include/ck/utility/amd_xdlops.hpp (+71, -29)
include/ck/utility/amd_xdlops.hpp
...
@@ -476,21 +476,31 @@ struct intrin_mfma_f64_16x16x4f64<16, 16>
     }
 };
 
-// TODO: fix ...f8f6f4 instructions
 template <index_t MPerWave, index_t NPerWave>
 struct intrin_mfma_f32_32x32x64f8f6f4;
 
+/// @brief Performs a matrix fused multiply-accumulate operation on 32x32x64 submatrices for f8, f6,
+/// and f4 data types.
+///
+/// @note Calls scaled version of the instruction as the original instruction is not supported on
+/// the backend. As per Matthew Arsenault: "Use the scaled versions. It's not a workaround, that is
+/// the intended use. There is a backend optimization to select to the unscaled if you use 0
+/// scales."
 template <>
 struct intrin_mfma_f32_32x32x64f8f6f4<32, 32>
 {
     template <class FloatC>
-    __device__ static void Run(const f8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
+    __device__ static void Run(const f8x32_t& reg_a, const f8x32_t& reg_b, FloatC& reg_c)
     {
 #if defined(__gfx950__)
-        reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x64_f8f6f4(
-            bit_cast<long>(reg_a),
-            bit_cast<long>(reg_b),
-            reg_c.template AsType<float16_t>()[Number<0>{}],
-            0, // cbsz
-            0, // blgp
-            0);
+        reg_c.template AsType<float16_t>()(Number<0>{}) =
+            __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
+                reg_a,
+                reg_b,
+                reg_c.template AsType<float16_t>()[Number<0>{}],
+                0, // cbsz
+                0, // blgp
+                0,
+                0,
+                0,
+                0);
...
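The trailing zeros map onto the scaled builtin's cbsz, blgp, and per-operand (opsel, scale) slots. A minimal standalone sketch of the calling convention, assuming a gfx950 target and a Clang/HIP compiler that exposes this builtin; the vector typedefs are illustrative stand-ins for CK's f8x32_t and float16_t, not taken from this commit:

    // Illustrative typedefs: 8 x i32 carries 32 packed f8 elements per lane;
    // 16 x float is one lane's slice of the 32x32 f32 accumulator.
    typedef int   v8i  __attribute__((ext_vector_type(8)));
    typedef float v16f __attribute__((ext_vector_type(16)));

    __device__ v16f mfma_f8_32x32x64(v8i reg_a, v8i reg_b, v16f reg_c)
    {
        // All-zero cbsz/blgp and (opsel, scale) pairs; per the note above,
        // the backend then selects the unscaled instruction.
        return __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
            reg_a, reg_b, reg_c, 0 /*cbsz*/, 0 /*blgp*/, 0, 0, 0, 0);
    }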
@@ -509,20 +519,30 @@ template <>
 struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32>
 {
     template <class FloatC>
-    __device__ static void Run(const f8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
+    __device__ static void Run(const f8x32_t& reg_a,
+                               const int32_t scale_a,
+                               const f8x32_t& reg_b,
+                               const int32_t scale_b,
+                               FloatC& reg_c)
     {
 #if defined(__gfx950__)
         // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
         reg_c.template AsType<float16_t>()(Number<0>{}) =
             __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
-                bit_cast<long>(reg_a),
-                bit_cast<long>(reg_b),
+                reg_a,
+                reg_b,
                 reg_c.template AsType<float16_t>()[Number<0>{}],
-                0,
-                0,
-                0);
+                0, // cbsz
+                0, // blgp
+                0, // { OPSEL_HI[0], OPSEL[0] }?
+                scale_a,
+                0, // { OPSEL_HI[1], OPSEL[1] }?
+                scale_b);
 #else
         ignore = reg_a;
+        ignore = scale_a;
         ignore = reg_b;
+        ignore = scale_b;
         ignore = reg_c;
 #endif
     }
...
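In the scaled wrapper, scale_a and scale_b occupy the non-immediate i32 slots after each OPSEL-style selector (the trailing ? in the comments marks the selector naming as unverified). A hypothetical caller, assuming the low byte of each scale operand is an E8M0 exponent with bias 127, so 0x7F encodes a scale of 1.0; that encoding is an assumption, not something this commit states:

    // Hypothetical usage of the fixed scaled wrapper with unit scales.
    template <class FloatC>
    __device__ void
    run_with_unit_scales(const f8x32_t& reg_a, const f8x32_t& reg_b, FloatC& reg_c)
    {
        const int32_t unit_scale = 0x7F; // assumed E8M0 encoding of 1.0
        intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32>::Run(
            reg_a, unit_scale, reg_b, unit_scale, reg_c);
    }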
@@ -535,20 +555,30 @@ template <>
 struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>
 {
     template <class FloatC>
-    __device__ static void Run(const f8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
+    __device__ static void Run(const f8x32_t& reg_a,
+                               const int32_t scale_a,
+                               const f8x32_t& reg_b,
+                               const int32_t scale_b,
+                               FloatC& reg_c)
     {
 #if defined(__gfx950__)
         // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
         reg_c.template AsType<float4_t>()(Number<0>{}) =
             __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
-                bit_cast<long>(reg_a),
-                bit_cast<long>(reg_b),
+                reg_a,
+                reg_b,
                 reg_c.template AsType<float4_t>()[Number<0>{}],
-                0,
-                0,
-                0);
+                0, // cbsz
+                0, // blgp
+                0, // { OPSEL_HI[0], OPSEL[0] }?
+                scale_a,
+                0, // { OPSEL_HI[1], OPSEL[1] }?
+                scale_b);
 #else
         ignore = reg_a;
+        ignore = scale_a;
         ignore = reg_b;
+        ignore = scale_b;
         ignore = reg_c;
 #endif
     }
...
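The 16x16x128 shape consumes the same 32 f8 elements per lane as the 32x32x64 shape (16·128/64 = 32·64/64 = 32), but each lane holds only 16·16/64 = 4 accumulator floats, hence float4_t in place of float16_t. A sketch mirroring the one above, under the same assumptions, with the scale operands wired through:

    typedef int   v8i __attribute__((ext_vector_type(8))); // 32 f8 elements per lane
    typedef float v4f __attribute__((ext_vector_type(4))); // 4 f32 accumulators per lane

    __device__ v4f
    mfma_f8_16x16x128_scaled(v8i reg_a, int scale_a, v8i reg_b, int scale_b, v4f reg_c)
    {
        // Same operand order as above: a, b, c, cbsz, blgp, then an
        // (opsel, scale) pair for each of A and B.
        return __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
            reg_a, reg_b, reg_c, 0 /*cbsz*/, 0 /*blgp*/, 0, scale_a, 0, scale_b);
    }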
@@ -557,17 +587,29 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>
 template <index_t MPerWave, index_t NPerWave>
 struct intrin_mfma_f32_16x16x128f8f6f4;
 
+/// @brief Performs a matrix fused multiply-accumulate operation on 16x16x128 submatrices for f8f6f4
+/// data types.
+///
+/// @note Calls scaled version of the instruction as the original instruction is not supported on
+/// the backend. As per Matthew Arsenault: "Use the scaled versions. It's not a workaround, that is
+/// the intended use. There is a backend optimization to select to the unscaled if you use 0
+/// scales."
 template <>
 struct intrin_mfma_f32_16x16x128f8f6f4<16, 16>
 {
     template <class FloatC>
-    __device__ static void Run(const f8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
+    __device__ static void Run(const f8x32_t& reg_a, const f8x32_t& reg_b, FloatC& reg_c)
     {
 #if defined(__gfx950__)
-        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x128_f8f6f4(
-            bit_cast<long>(reg_a),
-            bit_cast<long>(reg_b),
-            reg_c.template AsType<float4_t>()[Number<0>{}],
-            0, // cbsz
-            0, // blgp
-            0);
+        // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
+        reg_c.template AsType<float4_t>()(Number<0>{}) =
+            __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
+                reg_a,
+                reg_b,
+                reg_c.template AsType<float4_t>()[Number<0>{}],
+                0, // cbsz
+                0, // blgp
+                0,
+                0,
+                0,
+                0);
...
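The common thread across all four hunks is the operand width: each lane feeds the instruction 32 bytes of packed f8 data, so the old signatures taking f8x8_t (8 bytes, then bit_cast to long) supplied only a quarter of each operand. A compile-time illustration, assuming CK's vector typedefs are in scope and that long is 8 bytes for AMDGPU device code:

    // Size check for the type mismatch the commit fixes (illustrative only).
    static_assert(sizeof(f8x32_t) == 32, "32 packed f8 elements per lane");
    static_assert(sizeof(f8x8_t)  == 8,  "old operand type held 8 elements");
    static_assert(sizeof(long)    == 8,  "bit_cast<long> kept 8 of 32 bytes");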