Commit a619e3f5 authored by Andriy Roshchenko's avatar Andriy Roshchenko
Browse files

Fix f8f6f4 MFMA instructions

parent 0dda6f18
......@@ -476,24 +476,34 @@ struct intrin_mfma_f64_16x16x4f64<16, 16>
}
};
// TODO: fix ...f8f6f4 instructions
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x64f8f6f4;
/// @brief Performs a matrix fused multiply-accumulate operation on 32x32x64 submatrices for f8, f6,
/// and f4 data types.
///
/// @note Calls scaled version of the instruction as the original instruction is not supported on
/// the backend. As per Matthew Arsenault: "Use the scaled versions. It's not a workaround, that is
/// the intended use. There is a backend optimization to select to the unscaled if you use 0
/// scales."
template <>
struct intrin_mfma_f32_32x32x64f8f6f4<32, 32>
{
template <class FloatC>
__device__ static void Run(const f8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
__device__ static void Run(const f8x32_t& reg_a, const f8x32_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx950__)
reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x64_f8f6f4(
bit_cast<long>(reg_a),
bit_cast<long>(reg_b),
reg_c.template AsType<float16_t>()[Number<0>{}],
0,
0,
0);
reg_c.template AsType<float16_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
reg_a,
reg_b,
reg_c.template AsType<float16_t>()[Number<0>{}],
0, // cbsz
0, // blgp
0,
0,
0,
0);
#else
ignore = reg_a;
ignore = reg_b;
......@@ -509,20 +519,30 @@ template <>
struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32>
{
template <class FloatC>
__device__ static void Run(const f8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
__device__ static void Run(const f8x32_t& reg_a,
const int32_t scale_a,
const f8x32_t& reg_b,
const int32_t scale_b,
FloatC& reg_c)
{
#if defined(__gfx950__)
// https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
reg_c.template AsType<float16_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
bit_cast<long>(reg_a),
bit_cast<long>(reg_b),
reg_a,
reg_b,
reg_c.template AsType<float16_t>()[Number<0>{}],
0,
0,
0);
0, // cbsz
0, // blgp
0, // { OPSEL_HI[0], OPSEL[0] }?
scale_a,
0, // { OPSEL_HI[1], OPSEL[1] }?
scale_b);
#else
ignore = reg_a;
ignore = scale_a;
ignore = reg_b;
ignore = scale_b;
ignore = reg_c;
#endif
}
......@@ -535,20 +555,30 @@ template <>
struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>
{
template <class FloatC>
__device__ static void Run(const f8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
__device__ static void Run(const f8x32_t& reg_a,
const int32_t scale_a,
const f8x32_t& reg_b,
const int32_t scale_b,
FloatC& reg_c)
{
#if defined(__gfx950__)
// https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
reg_c.template AsType<float4_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
bit_cast<long>(reg_a),
bit_cast<long>(reg_b),
reg_a,
reg_b,
reg_c.template AsType<float4_t>()[Number<0>{}],
0,
0,
0);
0, // cbsz
0, // blgp
0, // { OPSEL_HI[0], OPSEL[0] }?
scale_a,
0, // { OPSEL_HI[1], OPSEL[1] }?
scale_b);
#else
ignore = reg_a;
ignore = scale_a;
ignore = reg_b;
ignore = scale_b;
ignore = reg_c;
#endif
}
......@@ -557,20 +587,32 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x128f8f6f4;
/// @brief Performs a matrix fused multiply-accumulate operation on 16x16x128 submatrices for f8f6f4
/// data types.
///
/// @note Calls scaled version of the instruction as the original instruction is not supported on
/// the backend. As per Matthew Arsenault: "Use the scaled versions. It's not a workaround, that is
/// the intended use. There is a backend optimization to select to the unscaled if you use 0
/// scales."
template <>
struct intrin_mfma_f32_16x16x128f8f6f4<16, 16>
{
template <class FloatC>
__device__ static void Run(const f8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
__device__ static void Run(const f8x32_t& reg_a, const f8x32_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx950__)
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x128_f8f6f4(
bit_cast<long>(reg_a),
bit_cast<long>(reg_b),
reg_c.template AsType<float4_t>()[Number<0>{}],
0,
0,
0);
// https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
reg_c.template AsType<float4_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
reg_a,
reg_b,
reg_c.template AsType<float4_t>()[Number<0>{}],
0, // cbsz
0, // blgp
0,
0,
0,
0);
#else
ignore = reg_a;
ignore = reg_b;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment