Unverified commit a199815a authored by Muhammed Emin Ozturk, committed by GitHub

Merge branch 'develop' into muozturk_sk_padding

parents a9a3a3e2 2bef5501
......@@ -38,8 +38,7 @@ __global__ void
const BElementwiseOperation b_element_op,
const CDEElementwiseOperation c_element_op)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx94__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t block_id = get_block_1d_id();
......
......@@ -50,8 +50,7 @@ __global__ void
const BElementwiseOperation b_element_op,
const CDEElementwiseOperation c_element_op)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx94__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t block_id = get_block_1d_id();
......
......@@ -40,8 +40,7 @@ __global__ void
const BElementwiseOperation b_element_op,
const CElementwiseOperation c_element_op)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx94__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
__shared__ uint8_t p_shared[shared_size];
......@@ -80,7 +79,7 @@ __global__ void
ignore = a_element_op;
ignore = b_element_op;
ignore = c_element_op;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
#endif // end of if (defined(__gfx9__))
}
template <typename ALayout,
......
......@@ -56,8 +56,7 @@ __global__ void
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
const Block2ETileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx94__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t num_blocks_per_batch =
......
......@@ -16,7 +16,8 @@ namespace ck {
// [Who Says Elephants Can't Run: Bringing Large Scale MoE Models into Cloud Scale Production]
// (https://arxiv.org/abs/2211.10017) and implementation:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
__host__ __device__ inline half4_t pki4_to_half4(int q)
// Convert the lower packed int4 values (int4 -> half); produces four halves
__device__ inline half4_t i4_to_half4(int q)
{
const int LO = 0x000f000f;
const int HI = 0x00f000f0;
......@@ -44,7 +45,7 @@ __host__ __device__ inline half4_t pki4_to_half4(int q)
return res.template AsType<half4_t>()[Number<0>{}];
}
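// Editorial sketch of what the masked fast path above computes (the body is
// elided in this diff, so this follows the FasterTransformer trick cited
// above): each 4-bit field stores value+8. ORing a LO nibble n into the
// mantissa of 0x6400 (half 1024.0) yields the half 1024+n, and adding
// 0xE408 (half -1032.0) leaves n-8; a HI nibble holds 1024+16n and is
// recovered with an fma by 1/16 and -72, since (1024+16n)/16 - 72 = n-8.
// Scalar reference, lane order illustrative only:
__host__ inline void i4_to_half4_reference(int q, float out[4])
{
    out[0] = static_cast<float>((q >> 0) & 0xf) - 8.f;  // LO mask, byte 0
    out[1] = static_cast<float>((q >> 4) & 0xf) - 8.f;  // HI mask, byte 0
    out[2] = static_cast<float>((q >> 16) & 0xf) - 8.f; // LO mask, byte 2
    out[3] = static_cast<float>((q >> 20) & 0xf) - 8.f; // HI mask, byte 2
}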
__host__ __device__ inline half4_t pki4_to_half4_scale(int q, const ck::half2_t& scale)
__device__ inline half4_t i4_to_half4_scale(int q, const ck::half2_t& scale)
{
const int LO = 0x000f000f;
const int HI = 0x00f000f0;
......@@ -78,34 +79,7 @@ __host__ __device__ inline half4_t pki4_to_half4_scale(int q, const ck::half2_t&
return res.template AsType<half4_t>()[Number<0>{}];
}
__host__ __device__ inline half2_t pki4_to_half2(pk_i4_t q)
{
#if 1
uint8_t x_u8 = ck::bit_cast<uint8_t>(q);
uint32_t i4s = ((x_u8 & 0x0f) << 16) | ((x_u8 & 0xf0) >> 4);
const int EX = 0x64006400;
const int SUB = 0xE408E408; //-8
int lo = i4s | EX;
return amd_assembly_pk_add_f16(bit_cast<half2_t>(lo), bit_cast<half2_t>(SUB));
#else
uint8_t x_u8 = ck::bit_cast<uint8_t>(q);
vector_type<half_t, 2> res;
half_t x_h = (x_u8 & 0x0f) - 8;
half_t x_l = ((x_u8 & 0xf0) >> 4) - 8;
res.template AsType<half_t>()(Number<0>{}) = x_l;
res.template AsType<half_t>()(Number<1>{}) = x_h;
return res.template AsType<half2_t>()[Number<0>{}];
#endif
}
__host__ __device__ inline bhalf4_t pki4_to_bhalf4(int q)
__device__ inline bhalf4_t i4_to_bhalf4(int q)
{
uint32_t i8s = (q & 0xf) | ((q & 0xf0) << 4) | ((q & 0xf00) << 8) | ((q & 0xf000) << 12);
......@@ -134,21 +108,6 @@ __host__ __device__ inline bhalf4_t pki4_to_bhalf4(int q)
return res.template AsType<bhalf4_t>()[Number<0>{}];
}
__host__ __device__ inline bhalf2_t pki4_to_bhalf2(pk_i4_t q)
{
uint8_t x_u8 = ck::bit_cast<uint8_t>(q);
float x_h = ((x_u8 & 0x0f) >> 0) - 8.f;
float x_l = ((x_u8 & 0xf0) >> 4) - 8.f;
vector_type<bhalf_t, 2> res;
res.template AsType<bhalf_t>()(Number<0>{}) = type_convert<bhalf_t>(x_l);
res.template AsType<bhalf_t>()(Number<1>{}) = type_convert<bhalf_t>(x_h);
return res.template AsType<bhalf2_t>()[Number<0>{}];
}
namespace tensor_operation {
namespace element_wise {
......@@ -159,11 +118,11 @@ struct PassThroughPack8
__host__ __device__ constexpr void operator()(ck::half8_t& y, const ck::pk_i4x4_t& x) const
{
#if 1
#if CK_USE_PK4_LAYOUT_SHUFFLE
vector_type<half_t, 8> result;
result.template AsType<half4_t>()(Number<0>{}) = pki4_to_half4(bit_cast<int>(x));
result.template AsType<half4_t>()(Number<1>{}) = pki4_to_half4(bit_cast<int>(x) >> 8);
result.template AsType<half4_t>()(Number<0>{}) = i4_to_half4(bit_cast<int>(x));
result.template AsType<half4_t>()(Number<1>{}) = i4_to_half4(bit_cast<int>(x) >> 8);
y = result.template AsType<half8_t>()[Number<0>{}];
#else
......@@ -171,13 +130,13 @@ struct PassThroughPack8
vector_type<pk_i4_t, 4> src{x};
dst.template AsType<half2_t>()(Number<0>{}) =
pki4_to_half2(src.template AsType<pk_i4_t>()[Number<0>{}]);
type_convert<half2_t>(src.template AsType<pk_i4_t>()[Number<0>{}]);
dst.template AsType<half2_t>()(Number<1>{}) =
pki4_to_half2(src.template AsType<pk_i4_t>()[Number<1>{}]);
type_convert<half2_t>(src.template AsType<pk_i4_t>()[Number<1>{}]);
dst.template AsType<half2_t>()(Number<2>{}) =
pki4_to_half2(src.template AsType<pk_i4_t>()[Number<2>{}]);
type_convert<half2_t>(src.template AsType<pk_i4_t>()[Number<2>{}]);
dst.template AsType<half2_t>()(Number<3>{}) =
pki4_to_half2(src.template AsType<pk_i4_t>()[Number<3>{}]);
type_convert<half2_t>(src.template AsType<pk_i4_t>()[Number<3>{}]);
y = dst.template AsType<half8_t>()[Number<0>{}];
#endif
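// (editorial note) CK_USE_PK4_LAYOUT_SHUFFLE selects the masked fast path
// above, which assumes the packed int4 weights were pre-shuffled into the
// interleaved layout i4_to_half4 expects; the #else branch converts the
// plain packed layout one pk_i4_t at a time.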
......@@ -185,11 +144,11 @@ struct PassThroughPack8
__host__ __device__ constexpr void operator()(ck::bhalf8_t& y, const ck::pk_i4x4_t& x) const
{
#if 1
#if CK_USE_PK4_LAYOUT_SHUFFLE
vector_type<bhalf_t, 8> result;
result.template AsType<bhalf4_t>()(Number<0>{}) = pki4_to_bhalf4(bit_cast<int>(x));
result.template AsType<bhalf4_t>()(Number<1>{}) = pki4_to_bhalf4(bit_cast<int>(x) >> 16);
result.template AsType<bhalf4_t>()(Number<0>{}) = i4_to_bhalf4(bit_cast<int>(x));
result.template AsType<bhalf4_t>()(Number<1>{}) = i4_to_bhalf4(bit_cast<int>(x) >> 16);
y = result.template AsType<bhalf8_t>()[Number<0>{}];
#else
......@@ -197,13 +156,13 @@ struct PassThroughPack8
vector_type<pk_i4_t, 4> src{x};
dst.template AsType<bhalf2_t>()(Number<0>{}) =
pki4_to_bhalf2(src.template AsType<pk_i4_t>()[Number<0>{}]);
type_convert<bhalf2_t>(src.template AsType<pk_i4_t>()[Number<0>{}]);
dst.template AsType<bhalf2_t>()(Number<1>{}) =
pki4_to_bhalf2(src.template AsType<pk_i4_t>()[Number<1>{}]);
type_convert<bhalf2_t>(src.template AsType<pk_i4_t>()[Number<1>{}]);
dst.template AsType<bhalf2_t>()(Number<2>{}) =
pki4_to_bhalf2(src.template AsType<pk_i4_t>()[Number<2>{}]);
type_convert<bhalf2_t>(src.template AsType<pk_i4_t>()[Number<2>{}]);
dst.template AsType<bhalf2_t>()(Number<3>{}) =
pki4_to_bhalf2(src.template AsType<pk_i4_t>()[Number<3>{}]);
type_convert<bhalf2_t>(src.template AsType<pk_i4_t>()[Number<3>{}]);
y = dst.template AsType<bhalf8_t>()[Number<0>{}];
#endif
......@@ -219,12 +178,12 @@ struct DequantPack8
__host__ __device__ constexpr void
operator()(ck::half8_t& y, const ck::pk_i4x4_t& x, const ck::half2_t& z) const
{
#if 1
#if CK_USE_PK4_LAYOUT_SHUFFLE
vector_type<half_t, 8> result;
result.template AsType<half4_t>()(Number<0>{}) = pki4_to_half4_scale(bit_cast<int>(x), z);
result.template AsType<half4_t>()(Number<0>{}) = i4_to_half4_scale(bit_cast<int>(x), z);
result.template AsType<half4_t>()(Number<1>{}) =
pki4_to_half4_scale(bit_cast<int>(x) >> 8, z);
i4_to_half4_scale(bit_cast<int>(x) >> 8, z);
y = result.template AsType<half8_t>()[Number<0>{}];
#else
......@@ -232,13 +191,13 @@ struct DequantPack8
vector_type<pk_i4_t, 4> src{x};
dst.template AsType<half2_t>()(Number<0>{}) =
pki4_to_half2(src.template AsType<pk_i4_t>()[Number<0>{}]);
type_convert<half2_t>(src.template AsType<pk_i4_t>()[Number<0>{}]);
dst.template AsType<half2_t>()(Number<1>{}) =
pki4_to_half2(src.template AsType<pk_i4_t>()[Number<1>{}]);
type_convert<half2_t>(src.template AsType<pk_i4_t>()[Number<1>{}]);
dst.template AsType<half2_t>()(Number<2>{}) =
pki4_to_half2(src.template AsType<pk_i4_t>()[Number<2>{}]);
type_convert<half2_t>(src.template AsType<pk_i4_t>()[Number<2>{}]);
dst.template AsType<half2_t>()(Number<3>{}) =
pki4_to_half2(src.template AsType<pk_i4_t>()[Number<3>{}]);
type_convert<half2_t>(src.template AsType<pk_i4_t>()[Number<3>{}]);
y = dst.template AsType<half8_t>()[Number<0>{}];
#endif
......@@ -260,7 +219,7 @@ struct PassThroughPack2
__host__ __device__ constexpr void operator()(ck::half2_t& y, const ck::pk_i4_t& x) const
{
#if 1
#if CK_USE_PK4_LAYOUT_SHUFFLE
uint8_t x_u8 = ck::bit_cast<uint8_t>(x);
uint8_t x_l = (x_u8 & 0x0f) >> 0;
uint8_t x_h = (x_u8 & 0xf0) >> 4;
......
......@@ -607,6 +607,7 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
// with 'group_size' contiguous elements. Having Gemm1KPack greater than A1K1 will
// cause a mismatch in the summation index, for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
// Therefore we may just as well assign Gemm1KPack = group_size.
constexpr index_t Gemm1KPack =
MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.group_size;
......
......@@ -856,11 +856,18 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
static_cast<A0B0B1DataType*>(p_shared) + SharedMemTrait::b1_block_space_offset,
b1_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr index_t Gemm1KPack = math::max(
math::lcm(
MfmaSelector<A0B0B1DataType, Gemm0MPerXdl, Gemm0NPerXdl>::selected_mfma.group_size,
B1K1),
MfmaSelector<A0B0B1DataType, Gemm0MPerXdl, Gemm0NPerXdl>::selected_mfma.k_per_blk);
// selected_mfma.group_size or B1K1 <= Gemm1KPack <= selected_mfma.group_size
// selected_mfma.k_per_blk <= Gemm1KPack
//
// Following the rationale behind Gemm0KPack, let Gemm1KPack be the lowest common
// multiple of A1K1 (predetermined by selected_mfma.group_size) and B1K1. But in this case
// Gemm1KPack can't be higher than A1K1 itself because the A1 matrix is distributed in VGPRs
// with 'group_size' contiguous elements. Having Gemm1KPack greater than A1K1 will
// cause a mismatch in the summation index, for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
// Therefore we may just as well assign Gemm1KPack = group_size.
constexpr index_t Gemm1KPack =
MfmaSelector<A0B0B1DataType, Gemm0MPerXdl, Gemm0NPerXdl>::selected_mfma.group_size;
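// Illustrative numbers (editorial): for the 32x32 f16 MFMA variants in the
// tables below, selected_mfma.group_size == 4, so Gemm1KPack becomes 4; the
// previous max(lcm(group_size, B1K1), k_per_blk) expression could exceed A1K1
// and hit the summation-index mismatch described above.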
auto blockwise_gemm1 = BlockwiseGemmXdlops_v2<
BlockSize,
......
......@@ -773,6 +773,7 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle
// with 'group_size' contiguous elements. Having Gemm1KPack greater than A1K1 will
// cause a mismatch in the summation index, for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
// Therefore we may just as well assign Gemm1KPack = group_size.
constexpr index_t Gemm1KPack =
MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.group_size;
......
......@@ -628,6 +628,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
// with 'group_size' contiguous elements. Having Gemm1KPack greater than A1K1 will
// cause a mismatch in the summation index, for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
// Therefore we may just as well assign Gemm1KPack = group_size.
constexpr index_t Gemm1KPack =
MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.group_size;
......
......@@ -37,7 +37,17 @@ enum struct MfmaInstr
mfma_f32_32x32x16f8bf8,
mfma_f32_16x16x32f8bf8,
mfma_f32_32x32x16bf8f8,
mfma_f32_16x16x32bf8f8
mfma_f32_16x16x32bf8f8,
mfma_f32_32x32x16f16,
mfma_f32_16x16x32f16,
mfma_f32_32x32x16bf16,
mfma_f32_16x16x32bf16,
mfma_i32_32x32x32i8,
mfma_i32_16x16x64i8,
mfma_f32_32x32x64f8f6f4,
mfma_f32_16x16x128f8f6f4,
mfma_scale_f32_32x32x64f8f6f4,
mfma_scale_f32_16x16x128f8f6f4
};
template <MfmaInstr instr>
......@@ -198,6 +208,50 @@ struct mfma_type<MfmaInstr::mfma_f32_32x32x8f16>
}
};
template <>
struct mfma_type<MfmaInstr::mfma_f32_32x32x16f16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 4;
static constexpr index_t num_regs_per_blk = 16;
static constexpr index_t num_threads_per_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 2;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 32;
static constexpr index_t n_per_blk = 32;
static constexpr index_t k_per_blk = 8;
static constexpr bool is_k_reduction = true;
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_f32_32x32x16f16<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
}
};
template <>
struct mfma_type<MfmaInstr::mfma_f32_16x16x32f16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_per_blk = 4;
static constexpr index_t num_threads_per_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 4;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 16;
static constexpr index_t n_per_blk = 16;
static constexpr index_t k_per_blk = 8;
static constexpr bool is_k_reduction = true;
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_f32_16x16x32f16<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
}
};
template <>
struct mfma_type<MfmaInstr::mfma_f32_16x16x16f16>
{
......@@ -264,6 +318,28 @@ struct mfma_type<MfmaInstr::mfma_f32_4x4x4f16>
}
};
template <>
struct mfma_type<MfmaInstr::mfma_f32_32x32x16bf16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 4;
static constexpr index_t num_regs_per_blk = 16;
static constexpr index_t num_threads_per_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 2;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 32;
static constexpr index_t n_per_blk = 32;
static constexpr index_t k_per_blk = 8;
static constexpr bool is_k_reduction = true;
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_f32_32x32x16bf16<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
}
};
template <>
struct mfma_type<MfmaInstr::mfma_f32_32x32x8bf16_1k>
{
......@@ -286,6 +362,28 @@ struct mfma_type<MfmaInstr::mfma_f32_32x32x8bf16_1k>
}
};
template <>
struct mfma_type<MfmaInstr::mfma_f32_16x16x32bf16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_per_blk = 4;
static constexpr index_t num_threads_per_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 4;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 16;
static constexpr index_t n_per_blk = 16;
static constexpr index_t k_per_blk = 8;
static constexpr bool is_k_reduction = true;
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_f32_16x16x32bf16<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
}
};
template <>
struct mfma_type<MfmaInstr::mfma_f32_16x16x16bf16_1k>
{
......@@ -440,6 +538,50 @@ struct mfma_type<MfmaInstr::mfma_i32_16x16x32i8>
}
};
template <>
struct mfma_type<MfmaInstr::mfma_i32_32x32x32i8>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 4;
static constexpr index_t num_regs_per_blk = 16;
static constexpr index_t num_threads_per_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 2;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 32;
static constexpr index_t n_per_blk = 32;
static constexpr index_t k_per_blk = 16;
static constexpr bool is_k_reduction = true;
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_i32_32x32x32i8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
}
};
template <>
struct mfma_type<MfmaInstr::mfma_i32_16x16x64i8>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_per_blk = 4;
static constexpr index_t num_threads_per_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 4;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 16;
static constexpr index_t n_per_blk = 16;
static constexpr index_t k_per_blk = 16;
static constexpr bool is_k_reduction = true;
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_i32_16x16x64i8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
}
};
template <>
struct mfma_type<MfmaInstr::mfma_f64_16x16x4f64>
{
......@@ -638,16 +780,115 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32bf8f8>
}
};
// TODO: fix mfma...f8f6f4 instructions
template <>
struct mfma_type<MfmaInstr::mfma_f32_32x32x64f8f6f4>
{
// clang-format off
static constexpr index_t group_size = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk
static constexpr index_t num_groups_per_blk = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk
static constexpr index_t num_regs_per_blk = 16; // m_per_blk * n_per_blk / wave_size
static constexpr index_t num_threads_per_blk = 32; // n_per_blk
static constexpr index_t wave_size = 64; // fixed
static constexpr index_t num_input_blks = 2; // m_per_blk / num_regs_per_blk
static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ???
static constexpr index_t m_per_blk = 32; // from the instruction
static constexpr index_t n_per_blk = 32; // from the instruction
static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? 64 / num_input_blks
static constexpr bool is_k_reduction = true; // ???
// clang-format on
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_f32_32x32x64f8f6f4<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
}
};
template <>
struct mfma_type<MfmaInstr::mfma_f32_16x16x128f8f6f4>
{
// clang-format off
static constexpr index_t group_size = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk
static constexpr index_t num_groups_per_blk = 1; // ??? group_size * num_groups_per_blk == num_regs_per_blk
static constexpr index_t num_regs_per_blk = 4; // m_per_blk * n_per_blk / wave_size
static constexpr index_t num_threads_per_blk = 16; // == n_per_blk
static constexpr index_t wave_size = 64; // fixed
static constexpr index_t num_input_blks = 4; // m_per_blk / num_regs_per_blk
static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ???
static constexpr index_t m_per_blk = 16; // from the instruction
static constexpr index_t n_per_blk = 16; // from the instruction
static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? 128 / num_input_blks
static constexpr bool is_k_reduction = true; // ???
// clang-format on
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_f32_16x16x128f8f6f4<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
}
};
template <>
struct mfma_type<MfmaInstr::mfma_scale_f32_32x32x64f8f6f4>
{
// clang-format off
static constexpr index_t group_size = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk
static constexpr index_t num_groups_per_blk = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk
static constexpr index_t num_regs_per_blk = 16; // m_per_blk * n_per_blk / wave_size
static constexpr index_t num_threads_per_blk = 32; // n_per_blk
static constexpr index_t wave_size = 64; // fixed
static constexpr index_t num_input_blks = 2; // m_per_blk / num_regs_per_blk
static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ???
static constexpr index_t m_per_blk = 32; // from the instruction
static constexpr index_t n_per_blk = 32; // from the instruction
static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? 64 / num_input_blks
static constexpr bool is_k_reduction = true; // ???
// clang-format on
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_scale_f32_32x32x64f8f6f4<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
}
};
template <>
struct mfma_type<MfmaInstr::mfma_scale_f32_16x16x128f8f6f4>
{
// clang-format off
static constexpr index_t group_size = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk
static constexpr index_t num_groups_per_blk = 1; // ??? group_size * num_groups_per_blk == num_regs_per_blk
static constexpr index_t num_regs_per_blk = 4; // m_per_blk * n_per_blk / wave_size
static constexpr index_t num_threads_per_blk = 16; // == n_per_blk
static constexpr index_t wave_size = 64; // fixed
static constexpr index_t num_input_blks = 4; // m_per_blk / num_regs_per_blk
static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ???
static constexpr index_t m_per_blk = 16; // from the instruction
static constexpr index_t n_per_blk = 16; // from the instruction
static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? 128 / num_input_blks
static constexpr bool is_k_reduction = true; // ???
// clang-format on
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_scale_f32_16x16x128f8f6f4<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
}
};
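// Editorial sanity-check sketch for the relations stated in the comments
// above (not part of the diff; names taken from the specializations here):
namespace mfma_f8f6f4_checks {
using T32 = mfma_type<MfmaInstr::mfma_f32_32x32x64f8f6f4>;
static_assert(T32::group_size * T32::num_groups_per_blk == T32::num_regs_per_blk, "");
static_assert(T32::m_per_blk * T32::n_per_blk / T32::wave_size == T32::num_regs_per_blk, "");
static_assert(T32::num_threads_per_blk == T32::n_per_blk, "");
static_assert(T32::k_per_blk * T32::num_input_blks == 64, ""); // K of the instruction
} // namespace mfma_f8f6f4_checks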
template <typename base_type,
index_t MPerXdlops,
index_t NPerXdlops,
typename additional_type = base_type>
typename additional_type = base_type,
bool is_single_rate_mfma = false>
struct MfmaSelector
{
template <typename base_type_,
index_t MPerXdlops_,
index_t NPerXdlops_,
typename additional_type_ = base_type_>
typename additional_type_ = base_type_,
bool is_single_rate_mfma_ = false>
static constexpr auto GetMfma();
template <>
......@@ -711,13 +952,32 @@ struct MfmaSelector
}
template <>
constexpr auto GetMfma<half_t, 32, 32>()
constexpr auto GetMfma<half_t, 32, 32, half_t, false>()
{
#if defined(__gfx950__)
return MfmaInstr::mfma_f32_32x32x16f16;
#else
return MfmaInstr::mfma_f32_32x32x8f16;
#endif
}
template <>
constexpr auto GetMfma<half_t, 32, 32, half_t, true>()
{
return MfmaInstr::mfma_f32_32x32x8f16;
}
template <>
constexpr auto GetMfma<half_t, 16, 16>()
constexpr auto GetMfma<half_t, 16, 16, half_t, false>()
{
#if defined(__gfx950__)
return MfmaInstr::mfma_f32_16x16x32f16;
#else
return MfmaInstr::mfma_f32_16x16x16f16;
#endif
}
template <>
constexpr auto GetMfma<half_t, 16, 16, half_t, true>()
{
return MfmaInstr::mfma_f32_16x16x16f16;
}
......@@ -741,7 +1001,19 @@ struct MfmaSelector
}
template <>
constexpr auto GetMfma<bhalf_t, 32, 32>()
constexpr auto GetMfma<bhalf_t, 32, 32, bhalf_t, false>()
{
#if defined(__gfx950__)
return MfmaInstr::mfma_f32_32x32x16bf16;
#elif defined(CK_USE_AMD_MFMA_BF16_1K_OP)
return MfmaInstr::mfma_f32_32x32x8bf16_1k;
#else
return MfmaInstr::mfma_f32_32x32x4bf16;
#endif
}
template <>
constexpr auto GetMfma<bhalf_t, 32, 32, bhalf_t, true>()
{
#if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
return MfmaInstr::mfma_f32_32x32x8bf16_1k;
......@@ -751,7 +1023,19 @@ struct MfmaSelector
}
template <>
constexpr auto GetMfma<bhalf_t, 16, 16>()
constexpr auto GetMfma<bhalf_t, 16, 16, bhalf_t, false>()
{
#if defined(__gfx950__)
return MfmaInstr::mfma_f32_16x16x32bf16;
#elif defined(CK_USE_AMD_MFMA_BF16_1K_OP)
return MfmaInstr::mfma_f32_16x16x16bf16_1k;
#else
return MfmaInstr::mfma_f32_16x16x8bf16;
#endif
}
template <>
constexpr auto GetMfma<bhalf_t, 16, 16, bhalf_t, true>()
{
#if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
return MfmaInstr::mfma_f32_16x16x16bf16_1k;
......@@ -760,7 +1044,18 @@ struct MfmaSelector
#endif
}
#if defined(CK_USE_AMD_MFMA_GFX940)
#if defined(__gfx950__)
template <>
constexpr auto GetMfma<int8_t, 32, 32>()
{
return MfmaInstr::mfma_i32_32x32x32i8;
}
template <>
constexpr auto GetMfma<int8_t, 16, 16>()
{
return MfmaInstr::mfma_i32_16x16x64i8;
}
#elif defined(__gfx942__)
template <>
constexpr auto GetMfma<int8_t, 32, 32>()
{
......@@ -832,8 +1127,8 @@ struct MfmaSelector
return MfmaInstr::mfma_f32_16x16x32bf8f8;
}
static constexpr auto selected_mfma =
mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops, additional_type>()>{};
static constexpr auto selected_mfma = mfma_type<
GetMfma<base_type, MPerXdlops, NPerXdlops, additional_type, is_single_rate_mfma>()>{};
__host__ __device__ constexpr MfmaSelector()
{
......@@ -1135,7 +1430,13 @@ struct XdlopsGemm
return TransposeC ? CIndex4D{blk_td, I0, blk_id, I0} : CIndex4D{I0, blk_id, I0, blk_td};
}
static constexpr auto mfma = MfmaSelector<base_type, MPerXdlops, NPerXdlops, additional_type>{};
// Falls back to the single-rate instruction on gfx950 if KPack <= 4; no change on gfx942 and older
static constexpr auto mfma =
    MfmaSelector<base_type,
                 MPerXdlops,
                 NPerXdlops,
                 additional_type,
                 (is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value) &&
                     KPack <= 4>{};
static constexpr auto mfma_instr = mfma.selected_mfma;
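// Editorial example: with base_type = half_t and MPerXdlops = NPerXdlops = 32,
// KPack <= 4 instantiates MfmaSelector<half_t, 32, 32, half_t, true>, which
// returns the single-rate mfma_f32_32x32x8f16 even on gfx950; KPack > 4 keeps
// the double-rate mfma_f32_32x32x16f16 there.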
......
......@@ -581,7 +581,7 @@ __device__ void amd_global_atomic_add_impl(const typename vector_type<T, N>::typ
tmp.template AsType<half2_t>()[i]);
});
}
#if defined(__gfx942__)
#if defined(__gfx942__) || defined(__gfx950__)
else if constexpr(is_same<T, bhalf_t>::value)
{
vector_type<bhalf_t, N> tmp{src_thread_data};
......
......@@ -20,39 +20,25 @@
#define CK_USE_OCP_FP8 0
#endif
namespace {
// https://en.cppreference.com/w/cpp/types/conditional
template <bool B, class T, class F>
struct conditional
{
using type = T;
};
template <class T, class F>
struct conditional<false, T, F>
{
using type = F;
};
} // namespace
namespace ck {
using f8_fnuz_t = _BitInt(8);
using bf8_fnuz_t = unsigned _BitInt(8);
#if(defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx1200__) || \
defined(__gfx1201__)) && \
defined(__gfx1201__) || defined(__gfx950__)) && \
__HIP_DEVICE_COMPILE__
#define CK_FP8_CVT_FAST_PATH 1
#else
#define CK_FP8_CVT_FAST_PATH 0
#endif
#if(defined(__gfx1200__) || defined(__gfx1201__)) && __HIP_DEVICE_COMPILE__
#if(defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx950__)) && __HIP_DEVICE_COMPILE__
#define CK_OCP_FP8_CVT_FAST_PATH 1
#else
#define CK_OCP_FP8_CVT_FAST_PATH 0
#endif
namespace ck {
using f8_fnuz_t = _BitInt(8);
using bf8_fnuz_t = unsigned _BitInt(8);
typedef unsigned char fp8_storage_t;
/**
......@@ -207,10 +193,11 @@ __host__ __device__ static inline T cast_from_f8(fp8_storage_t x)
}
}
typename conditional<
typename std::conditional<
sizeof(T) == 2,
unsigned short int,
typename conditional<sizeof(T) == 4, unsigned int, unsigned long long>::type>::type retval;
typename std::conditional<sizeof(T) == 4, unsigned int, unsigned long long>::type>::type
retval;
if constexpr(we == 5 && is_half && !is_fnuz)
{
......@@ -303,7 +290,6 @@ static __device__ float2_t cast_to_f32x2_from_f8x2(fp8x2_storage_t v)
return __builtin_amdgcn_cvt_pk_f32_bf8(i16val, false);
}
}
#endif
} // namespace fp8_impl
......@@ -378,7 +364,7 @@ struct bf8_ocp_t
__host__ explicit operator float() const
#endif
{
#if defined(__gfx1200__) || defined(__gfx1201__)
#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__)
return fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data);
#else
return fp8_impl::cast_from_f8<float, wm, we, false>(
......@@ -392,7 +378,7 @@ struct bf8_ocp_t
__host__ explicit operator _Float16() const
#endif
{
#if defined(__gfx1200__) || defined(__gfx1201__)
#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__)
return static_cast<_Float16>(fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data));
#else
return fp8_impl::cast_from_f8<_Float16, wm, we, false>(
......@@ -553,10 +539,10 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn
constexpr int mfmt = (sizeof(T) == 8) ? 52 : ((sizeof(T) == 4) ? 23 : 10);
using T_bitwise = typename conditional<
using T_bitwise = typename std::conditional<
sizeof(T) == 2,
unsigned short int,
typename conditional<sizeof(T) == 4, unsigned int, unsigned long long>::type>::type;
typename std::conditional<sizeof(T) == 4, unsigned int, unsigned long long>::type>::type;
T_bitwise x_bitwise = bit_cast<T_bitwise>(_x);
unsigned long long x{x_bitwise};
......
......@@ -5,7 +5,7 @@
namespace ck {
// Define the common macro for gfx94x and gfx950 targets
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx950__)
#define __gfx94__
#endif
......@@ -134,6 +134,46 @@ struct intrin_mfma_f32_32x32x4f16<32, 64>
}
};
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16f16;
template <>
struct intrin_mfma_f32_32x32x16f16<32, 32>
{
template <class FloatC>
__device__ static void Run(const half8_t& reg_a, const half8_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx950__)
reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x16_f16(
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif // defined(__gfx950__)
}
};
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x32f16;
template <>
struct intrin_mfma_f32_16x16x32f16<16, 16>
{
template <class FloatC>
__device__ static void Run(const half8_t& reg_a, const half8_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx950__)
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_f16(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif // defined(__gfx950__)
}
};
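// Editorial per-thread usage sketch (assumes a gfx950 build and the
// vector_type accumulator layout the wrappers above index with AsType):
__device__ inline void mfma_f16_16x16x32_example(const half8_t& a, const half8_t& b)
{
    vector_type<float, 4> c{}; // 4 accumulator VGPRs per lane for one 16x16 tile
    intrin_mfma_f32_16x16x32f16<16, 16>::Run(a, b, c); // one K = 32 step
}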
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x8f16;
......@@ -204,6 +244,46 @@ struct intrin_mfma_f32_4x4x4f16<8, 64>
};
// bf16
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16bf16;
template <>
struct intrin_mfma_f32_32x32x16bf16<32, 32>
{
template <class FloatC>
__device__ static void Run(const bhalf8_t& reg_a, const bhalf8_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx950__)
reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x16_bf16(
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif // defined(__gfx950__)
}
};
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x32bf16;
template <>
struct intrin_mfma_f32_16x16x32bf16<16, 16>
{
template <class FloatC>
__device__ static void Run(const bhalf8_t& reg_a, const bhalf8_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx950__)
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf16(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif // defined(__gfx950__)
}
};
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x8bf16_1k;
......@@ -298,6 +378,46 @@ struct intrin_mfma_i32_16x16x16i8<16, 16>
}
};
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_i32_32x32x32i8;
template <>
struct intrin_mfma_i32_32x32x32i8<32, 32>
{
template <class FloatC>
__device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx950__)
reg_c.template AsType<int32x16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_i32_32x32x32_i8(
reg_a, reg_b, reg_c.template AsType<int32x16_t>()[Number<0>{}], 0, 0, 0);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif // defined(__gfx950__)
}
};
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_i32_16x16x64i8;
template <>
struct intrin_mfma_i32_16x16x64i8<16, 16>
{
template <class FloatC>
__device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx950__)
reg_c.template AsType<int32x4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_i32_16x16x64_i8(
reg_a, reg_b, reg_c.template AsType<int32x4_t>()[Number<0>{}], 0, 0, 0);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif // defined(__gfx950__)
}
};
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_i32_32x32x16i8;
......@@ -356,6 +476,149 @@ struct intrin_mfma_f64_16x16x4f64<16, 16>
}
};
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x64f8f6f4;
/// @brief Performs a matrix fused multiply-accumulate operation on 32x32x64 submatrices for f8, f6,
/// and f4 data types.
///
/// @note Calls the scaled version of the instruction because the unscaled instruction is not
/// supported in the backend; that is the intended use. A backend optimization selects the
/// unscaled operation when the scale is 0.
template <>
struct intrin_mfma_f32_32x32x64f8f6f4<32, 32>
{
template <class FloatC>
__device__ static void Run(const f8x32_t& reg_a, const f8x32_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx950__)
reg_c.template AsType<float16_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
reg_a,
reg_b,
reg_c.template AsType<float16_t>()[Number<0>{}],
0, // cbsz
0, // blgp
0,
0,
0,
0);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif
}
};
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_scale_f32_32x32x64f8f6f4;
template <>
struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32>
{
template <class FloatC>
__device__ static void Run(const f8x32_t& reg_a,
const int32_t scale_a,
const f8x32_t& reg_b,
const int32_t scale_b,
FloatC& reg_c)
{
#if defined(__gfx950__)
// https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
reg_c.template AsType<float16_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
reg_a,
reg_b,
reg_c.template AsType<float16_t>()[Number<0>{}],
0, // cbsz
0, // blgp
0, // { OPSEL_HI[0], OPSEL[0] }?
scale_a,
0, // { OPSEL_HI[1], OPSEL[1] }?
scale_b);
#else
ignore = reg_a;
ignore = scale_a;
ignore = reg_b;
ignore = scale_b;
ignore = reg_c;
#endif
}
};
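// (editorial, best-effort) In the LLVM definition of these builtins the two
// scale operands carry packed E8M0 block scales and the preceding immediate
// operands select which byte of each i32 is consumed, which is what the
// "{ OPSEL_HI, OPSEL }?" notes above are asking about.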
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_scale_f32_16x16x128f8f6f4;
template <>
struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>
{
template <class FloatC>
__device__ static void Run(const f8x32_t& reg_a,
const int32_t scale_a,
const f8x32_t& reg_b,
const int32_t scale_b,
FloatC& reg_c)
{
#if defined(__gfx950__)
// https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
reg_c.template AsType<float4_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
reg_a,
reg_b,
reg_c.template AsType<float4_t>()[Number<0>{}],
0, // cbsz
0, // blgp
0, // { OPSEL_HI[0], OPSEL[0] }?
scale_a,
0, // { OPSEL_HI[1], OPSEL[1] }?
scale_b);
#else
ignore = reg_a;
ignore = scale_a;
ignore = reg_b;
ignore = scale_b;
ignore = reg_c;
#endif
}
};
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x128f8f6f4;
/// @brief Performs a matrix fused multiply-accumulate operation on 16x16x128 submatrices for f8f6f4
/// data types.
///
/// @note Calls the scaled version of the instruction because the unscaled instruction is not
/// supported in the backend; that is the intended use. A backend optimization selects the
/// unscaled operation when the scale is 0.
template <>
struct intrin_mfma_f32_16x16x128f8f6f4<16, 16>
{
template <class FloatC>
__device__ static void Run(const f8x32_t& reg_a, const f8x32_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx950__)
// https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
reg_c.template AsType<float4_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
reg_a,
reg_b,
reg_c.template AsType<float4_t>()[Number<0>{}],
0, // cbsz
0, // blgp
0,
0,
0,
0);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif
}
};
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16f8f8;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/type.hpp"
namespace ck {
/**
* @brief Unsigned representation of a conventional biased Float32 exponent.
*
* bias = 127;
*
* E8M0_1 = 0b01111111; => 2^(127-127) = 1
* E8M0_2 = 0b10000000; => 2^(128-127) = 2^1 = 2
* E8M0_3 = 0b10000010; => 2^(130-127) = 2^3 = 8
* E8M0_135 = 0b10000111; => 2^(135-127) = 2^8 = 256
* E8M0_142 = 0b10001110; => 2^(142-127) = 2^15 = 32768
* E8M0_MIN = 0b00000000; => 2^-127
* E8M0_MAX = 0b11111110; => 2^127
* E8M0_NAN = 0b11111111; => NaN
*/
struct e8m0_bexp_t
{
using type = uint8_t;
type data;
constexpr static type bias = 127;
constexpr static type nan_mask = 0xFF;
__host__ __device__ constexpr e8m0_bexp_t() : data{type{}} {}
__host__ __device__ constexpr e8m0_bexp_t(type init) : data{init} {}
__host__ __device__ constexpr e8m0_bexp_t(int init) : data{static_cast<type>(init & nan_mask)}
{
}
__host__ __device__ explicit constexpr e8m0_bexp_t(float scale)
: data{static_cast<type>((bit_cast<uint32_t>(scale) & (nan_mask << 23)) >> 23)}
{
}
__host__ __device__ explicit constexpr operator float() const
{
if(data == nan_mask || data == 0)
{
uint32_t bits = data << 1;
bits |= 1;
bits <<= 22;
return bit_cast<float>(bits);
}
else
{
uint32_t bits = data << 23;
return bit_cast<float>(bits);
}
}
__host__ __device__ constexpr bool operator==(const e8m0_bexp_t& other) const
{
// strict IEEE compliance for NaN
return data == other.data && data != nan_mask;
}
__host__ __device__ constexpr bool is_nan() const { return data == nan_mask; }
};
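// Editorial round-trip illustration:
//   e8m0_bexp_t{0b10000010} -> float == 2^(130-127) == 8.0f
//   e8m0_bexp_t(256.0f)     -> data == 0b10000111 (biased exponent 135)
//   e8m0_bexp_t{0xFF}       -> is_nan() == true, and it compares unequal to
//                              itself under operator== (strict IEEE NaN).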
namespace utils {
template <typename T>
__host__ __device__ inline int get_exponent_value(T x);
template <>
__host__ __device__ inline int get_exponent_value<e8m0_bexp_t>(e8m0_bexp_t x)
{
return x.data;
}
} // namespace utils
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/mxfp_utils.hpp"
namespace ck::utils {
template <>
__host__ __device__ inline bool is_nan<f4_t>(e8m0_bexp_t const scale,
f4_t const dataBytes [[maybe_unused]])
{
    // no need to check for data as it does not have NaN representation;
    // note e8m0_bexp_t::operator== is strict about NaN (NaN != NaN), so
    // comparing against QuietNaN() would always be false; query the bit
    // pattern directly instead, as the f6/bf6 overloads do
    return scale.is_nan();
}
// no infinity representation in ocp_e2m1_mxfp4, so this always returns false
template <>
__host__ __device__ inline bool is_inf<f4_t>(e8m0_bexp_t const scale [[maybe_unused]],
f4_t const data [[maybe_unused]])
{
// no inf representation for ocp_e2m1_mxfp4
return false;
}
template <>
__host__ __device__ inline bool is_zero<f4_t>(e8m0_bexp_t const scale, f4_t const data)
{
if(is_nan<f4_t>(scale, data))
return false;
// no need to check for scale as it does not have a 0 representation
f4_t result = (data & 0b00001111) & NumericUtils<f4_t>::set_sign_mask;
return result == 0b0;
}
template <>
__host__ __device__ inline float to_float<f4_t>(e8m0_bexp_t const scale, f4_t const data)
{
if(is_nan<f4_t>(scale, data))
return std::numeric_limits<float>::quiet_NaN();
if(is_zero<f4_t>(scale, data))
return 0.0f;
f4_t prepared_data = data & 0b00001111;
int scale_exp = get_exponent_value<e8m0_bexp_t>(scale);
return convert_to_float<f4_t>(prepared_data, scale_exp);
}
template <>
__host__ __device__ inline f4_t sat_convert_to_type<f4_t>(float value)
{
cvt t;
t.value_float = value;
uint32_t sign = t.value_bitwise >> 31;
if(std::isnan(value))
{
return sign ? NumericUtils<f4_t>::data_max_negative_normal_mask
: NumericUtils<f4_t>::data_max_positive_normal_mask;
}
if(std::abs(value) > NumericLimits<f4_t>::Max()) // covers inf case as well
return sign ? NumericUtils<f4_t>::data_max_negative_normal_mask
: NumericUtils<f4_t>::data_max_positive_normal_mask;
f4_t res = convert_to_type<f4_t>(value);
if(std::abs(to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), res)) <
NumericLimits<f4_t>::DataMinSubnorm())
return value < 0 ? NumericUtils<f4_t>::negative_zero_mask
: NumericUtils<f4_t>::positive_zero_mask;
return res;
}
template <>
__host__ __device__ inline f4_t sat_convert_to_type_sr<f4_t>(float value, uint32_t seed)
{
cvt t;
t.value_float = value;
uint32_t sign = t.value_bitwise >> 31;
if(std::isnan(value))
return sign ? NumericUtils<f4_t>::data_max_negative_normal_mask
: NumericUtils<f4_t>::data_max_positive_normal_mask;
if(std::abs(value) > NumericLimits<f4_t>::Max()) // covers inf case as well
return sign ? NumericUtils<f4_t>::data_max_negative_normal_mask
: NumericUtils<f4_t>::data_max_positive_normal_mask;
f4_t res = convert_to_type_sr<f4_t>(value, seed);
if(std::abs(to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), res)) <
NumericLimits<f4_t>::DataMinSubnorm())
return value < 0 ? NumericUtils<f4_t>::negative_zero_mask
: NumericUtils<f4_t>::positive_zero_mask;
return res;
}
} // namespace ck::utils
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/mxfp_utils.hpp"
namespace ck::utils {
/**
* @brief Checks if an f6_t value is NaN based on the provided scale.
*
* For f6_t data, NaN cannot be represented directly. Instead, this function
* determines NaN by checking if the scale is set to a quiet NaN.
*
* @param scale The exponent scale factor (e8m0_bexp_t) used for f6_t.
* @param dataBytes The f6_t value to check (unused in this implementation).
* @return true if the scale indicates a NaN value, false otherwise.
*/
template <>
__host__ __device__ inline bool is_nan<f6_t>(e8m0_bexp_t const scale,
f6_t const dataBytes [[maybe_unused]])
{
// no need to check for data as it does not have NaN representation
return scale.is_nan();
}
/**
 * @brief Checks if a bf6_t value is NaN based on the provided scale.
*
* For bf6_t data, NaN cannot be represented directly. Instead, this function
* determines NaN by checking if the scale is set to a quiet NaN.
*
* @param scale The exponent scale factor (e8m0_bexp_t) used for bf6_t.
* @param dataBytes The bf6_t value to check (unused in this implementation).
* @return true if the scale indicates a NaN value, false otherwise.
*/
template <>
__host__ __device__ inline bool is_nan<bf6_t>(e8m0_bexp_t const scale,
bf6_t const dataBytes [[maybe_unused]])
{
// no need to check for data as it does not have NaN representation
return scale.is_nan();
}
/**
* @brief Checks if an f6_t value is infinite.
*
* Because f6_t does not support infinite values, this function always returns false.
*
* @param scale The exponent scale factor (e8m0_bexp_t) used for f6_t.
* @param data The f6_t value to check.
* @return Always false, as infinity is not represented in f6_t.
*/
template <>
__host__ __device__ inline bool is_inf<f6_t>(e8m0_bexp_t const scale [[maybe_unused]],
f6_t const data [[maybe_unused]])
{
// no inf representation for fp6
return false;
}
/**
 * @brief Checks if a bf6_t value is infinite.
*
* Because bf6_t does not support infinite values, this function always returns false.
*
* @param scale The exponent scale factor (e8m0_bexp_t) used for bf6_t.
* @param data The bf6_t value to check.
* @return Always false, as infinity is not represented in bf6_t.
*/
template <>
__host__ __device__ inline bool is_inf<bf6_t>(e8m0_bexp_t const scale [[maybe_unused]],
bf6_t const data [[maybe_unused]])
{
// no inf representation for bf6
return false;
}
/**
* @brief Checks whether an f6_t value is zero.
*
* If the specified f6_t is NaN, this function returns false.
* Otherwise, it masks out the sign bits and checks if the remaining bits
* are zero.
*
* @param scale The exponent scale factor (e8m0_bexp_t) used for f6_t.
* @param data The f6_t value to check.
* @return true if the value is zero; otherwise false.
*/
template <>
__host__ __device__ inline bool is_zero<f6_t>(e8m0_bexp_t const scale, f6_t const data)
{
if(is_nan<f6_t>(scale, data))
return false;
// no need to check for scale as it does not have a 0 representation
f6_t result = (data & 0b00111111) & NumericUtils<f6_t>::set_sign_mask;
return result == 0b0;
}
/**
 * @brief Checks whether a bf6_t value is zero.
*
* If the specified bf6_t is NaN, this function returns false.
* Otherwise, it masks out the sign bits and checks if the remaining bits
* are zero.
*
* @param scale The exponent scale factor (e8m0_bexp_t) used for bf6_t.
* @param data The bf6_t value to check.
* @return true if the value is zero; otherwise false.
*/
template <>
__host__ __device__ inline bool is_zero<bf6_t>(e8m0_bexp_t const scale, bf6_t const data)
{
if(is_nan<bf6_t>(scale, data))
return false;
// no need to check for scale as it does not have a 0 representation
bf6_t result = (data & 0b00111111) & NumericUtils<bf6_t>::set_sign_mask;
return result == 0b0;
}
/**
* @brief Converts an f6_t value to a float based on an e8m0_bexp_t scale factor.
*
* Checks if the f6_t value is NaN or zero before performing the conversion.
* Applies the exponent from the scale to compute the final float result.
*
* @param scale The exponent scale factor (e8m0_bexp_t) used for f6_t.
* @param data The f6_t value to convert.
* @return The converted float value.
*/
template <>
__host__ __device__ inline float to_float<f6_t>(e8m0_bexp_t const scale, f6_t const data)
{
if(is_nan<f6_t>(scale, data))
return std::numeric_limits<float>::quiet_NaN();
if(is_zero<f6_t>(scale, data))
return 0.0f;
f6_t prepared_data = data & 0b00111111;
int scale_exp = get_exponent_value<e8m0_bexp_t>(scale);
return convert_to_float<f6_t>(prepared_data, scale_exp);
}
/**
 * @brief Converts a bf6_t value to a float based on an e8m0_bexp_t scale factor.
*
* Checks if the bf6_t value is NaN or zero before performing the conversion.
* Applies the exponent from the scale to compute the final float result.
*
* @param scale The exponent scale factor (e8m0_bexp_t) used for bf6_t.
* @param data The bf6_t value to convert.
* @return The converted float value.
*/
template <>
__host__ __device__ inline float to_float<bf6_t>(e8m0_bexp_t const scale, bf6_t const data)
{
if(is_nan<bf6_t>(scale, data))
return std::numeric_limits<float>::quiet_NaN();
if(is_zero<bf6_t>(scale, data))
return 0.0f;
bf6_t prepared_data = data & 0b00111111;
int scale_exp = get_exponent_value<e8m0_bexp_t>(scale);
return convert_to_float<bf6_t>(prepared_data, scale_exp);
}
/**
* @brief Converts a float to f6_t with saturation.
*
* If the input is NaN or exceeds the representable range for f6_t, returns
* the corresponding max normal mask. Handles subnormal cases by returning
* zero with the appropriate sign.
*
* @param value The float value to be converted.
* @return The saturated f6_t value.
*/
template <>
__host__ __device__ inline f6_t sat_convert_to_type<f6_t>(float value)
{
cvt t;
t.value_float = value;
uint32_t sign = t.value_bitwise >> 31;
if(std::isnan(value))
{
return sign ? NumericUtils<f6_t>::data_max_negative_normal_mask
: NumericUtils<f6_t>::data_max_positive_normal_mask;
}
if(std::abs(value) > NumericLimits<f6_t>::Max()) // covers inf case as well
return sign ? NumericUtils<f6_t>::data_max_negative_normal_mask
: NumericUtils<f6_t>::data_max_positive_normal_mask;
f6_t res = convert_to_type<f6_t>(value);
if(std::abs(to_float<f6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), res)) <
NumericLimits<f6_t>::DataMinSubnorm())
return sign ? NumericUtils<f6_t>::negative_zero_mask
: NumericUtils<f6_t>::positive_zero_mask;
return res;
}
/**
* @brief Converts a float to bf6_t with saturation.
*
* If the input is NaN or exceeds the representable range for bf6_t, returns
* the corresponding max normal mask. Handles subnormal cases by returning
* zero with the appropriate sign.
*
* @param value The float value to be converted.
* @return The saturated bf6_t value.
*/
template <>
__host__ __device__ inline bf6_t sat_convert_to_type<bf6_t>(float value)
{
cvt t;
t.value_float = value;
uint32_t sign = t.value_bitwise >> 31;
if(std::isnan(value))
{
return sign ? NumericUtils<bf6_t>::data_max_negative_normal_mask
: NumericUtils<bf6_t>::data_max_positive_normal_mask;
}
if(std::abs(value) > NumericLimits<bf6_t>::Max()) // covers inf case as well
return sign ? NumericUtils<bf6_t>::data_max_negative_normal_mask
: NumericUtils<bf6_t>::data_max_positive_normal_mask;
bf6_t res = convert_to_type<bf6_t>(value);
if(std::abs(to_float<bf6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), res)) <
NumericLimits<bf6_t>::DataMinSubnorm())
return sign ? NumericUtils<bf6_t>::negative_zero_mask
: NumericUtils<bf6_t>::positive_zero_mask;
return res;
}
/**
* @brief Converts a float to f6_t with saturation and stochastic rounding.
*
* If the input is NaN or exceeds the representable range for f6_t, returns
* the corresponding max normal mask. Handles subnormal cases by returning
* zero with the appropriate sign.
*
 * @param value The float value to be converted.
 * @param seed  Seed for stochastic rounding.
 * @return The saturated f6_t value.
*/
template <>
__host__ __device__ inline f6_t sat_convert_to_type_sr<f6_t>(float value, uint32_t seed)
{
cvt t;
t.value_float = value;
uint32_t sign = t.value_bitwise >> 31;
if(std::isnan(value))
return sign ? NumericUtils<f6_t>::data_max_negative_normal_mask
: NumericUtils<f6_t>::data_max_positive_normal_mask;
if(std::abs(value) > NumericLimits<f6_t>::Max()) // covers inf case as well
return sign ? NumericUtils<f6_t>::data_max_negative_normal_mask
: NumericUtils<f6_t>::data_max_positive_normal_mask;
f6_t res = convert_to_type_sr<f6_t>(value, seed);
if(std::abs(to_float<f6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), res)) <
NumericLimits<f6_t>::DataMinSubnorm())
return sign ? NumericUtils<f6_t>::negative_zero_mask
: NumericUtils<f6_t>::positive_zero_mask;
return res;
}
/**
 * @brief Converts a float to bf6_t with saturation and stochastic rounding.
 *
 * If the input is NaN or exceeds the representable range for bf6_t, returns
 * the corresponding max normal mask. Handles subnormal cases by returning
 * zero with the appropriate sign.
 *
 * @param value The float value to be converted.
 * @param seed  Seed for stochastic rounding.
 * @return The saturated bf6_t value.
*/
template <>
__host__ __device__ inline bf6_t sat_convert_to_type_sr<bf6_t>(float value, uint32_t seed)
{
cvt t;
t.value_float = value;
uint32_t sign = t.value_bitwise >> 31;
if(std::isnan(value))
return sign ? NumericUtils<bf6_t>::data_max_negative_normal_mask
: NumericUtils<bf6_t>::data_max_positive_normal_mask;
if(std::abs(value) > NumericLimits<bf6_t>::Max()) // covers inf case as well
return sign ? NumericUtils<bf6_t>::data_max_negative_normal_mask
: NumericUtils<bf6_t>::data_max_positive_normal_mask;
bf6_t res = convert_to_type_sr<bf6_t>(value, seed);
if(std::abs(to_float<bf6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), res)) <
NumericLimits<bf6_t>::DataMinSubnorm())
return sign ? NumericUtils<bf6_t>::negative_zero_mask
: NumericUtils<bf6_t>::positive_zero_mask;
return res;
}
} // namespace ck::utils
#include "ck/utility/data_type.hpp"
#include "ck/utility/mxfp_utils.hpp"
#if defined(__gfx950__) && __HIP_DEVICE_COMPILE__
#define CK_MX_FP8_CVT_FAST_PATH 1
#else
#define CK_MX_FP8_CVT_FAST_PATH 0
#endif
namespace ck {
namespace fp8_impl {
#if CK_MX_FP8_CVT_FAST_PATH
template <ck_fp8_interpretation_t interpret>
static __device__ float cast_to_f32_from_f8_scaled(float scale, fp8_storage_t v)
{
union
{
unsigned int i32val;
unsigned char i8val[4];
} val;
val.i8val[0] = v;
static_assert(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP ||
interpret == ck_fp8_interpretation_t::CK_E5M2_OCP,
"Only OCP interpretations are supported");
if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
{
return __builtin_amdgcn_cvt_scalef32_f32_fp8(val.i32val, scale, 0);
}
else
{
return __builtin_amdgcn_cvt_scalef32_f32_bf8(val.i32val, scale, 0);
}
}
template <ck_fp8_interpretation_t interpret>
static __device__ float2_t cast_to_f32x2_from_f8x2_scaled(float scale, fp8x2_storage_t v)
{
const auto i16val = bit_cast<uint16_t>(v);
static_assert(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP ||
interpret == ck_fp8_interpretation_t::CK_E5M2_OCP,
"Only OCP interpretations are supported");
if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
{
return __builtin_amdgcn_cvt_scalef32_pk_f32_fp8(i16val, scale, 0);
}
else
{
return __builtin_amdgcn_cvt_scalef32_pk_f32_bf8(i16val, scale, 0);
}
}
template <ck_fp8_interpretation_t interpret, bool stochastic_rounding = false>
static __device__ fp8_storage_t cast_to_f8_from_f32_scaled(float v,
unsigned int rng = 0,
float scale = 1.0f)
{
fp8_storage_t i8data;
union
{
float fval;
unsigned int i32val;
} val;
union
{
uint32_t ival;
vector_type<int16_t, 2>::type v2i16;
fp8_storage_t v4i8[4];
} ret{};
// unsigned int ival = 0;
val.fval = v;
if constexpr(stochastic_rounding)
{
ret.ival =
(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
? __builtin_amdgcn_cvt_scalef32_sr_fp8_f32(ret.ival, val.fval, rng, scale, 0)
: __builtin_amdgcn_cvt_scalef32_sr_bf8_f32(ret.ival, val.fval, rng, scale, 0);
i8data = ret.v4i8[0];
}
else
{
// RNE CVT
// llvm.amdgcn.cvt.scalef32.pk.fp8.f32
// v2i16 old_vdst, float srcA, float srcB, float scale, bool dst_lo_hi_sel
if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
{
// If fval / scale > max fp8, returns NaN
ret.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_fp8_f32(/*old_vdst*/ ret.v2i16,
val.fval,
val.fval,
scale,
/*dst_lo_hi_sel*/ false);
}
else
{
// If fval / scale > max bf8, returns Inf
ret.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_bf8_f32(/*old_vdst*/ ret.v2i16,
val.fval,
val.fval,
scale,
/*dst_lo_hi_sel*/ false);
}
i8data = ret.v4i8[0];
}
return i8data;
}
template <ck_fp8_interpretation_t interpret, bool stochastic_rounding = false>
static __device__ fp8x2_storage_t cast_to_f8_from_f32_scaled(float2_t v,
unsigned int rng = 0,
float scale = 1.0f)
{
union
{
uint32_t ival;
vector_type<int16_t, 2>::type v2i16;
StaticallyIndexedArray<fp8x2_storage_t, 2> v2f8x2;
} ret{};
if constexpr(stochastic_rounding)
{
fp8x2_storage_t f8x2;
if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
{
ret.ival = __builtin_amdgcn_cvt_scalef32_sr_fp8_f32(ret.ival, v[0], rng, scale, 0);
f8x2[0] = ret.v2f8x2(Number<0>{})[0];
ret.ival = __builtin_amdgcn_cvt_scalef32_sr_fp8_f32(ret.ival, v[1], rng, scale, 0);
f8x2[1] = ret.v2f8x2(Number<0>{})[0];
}
else
{
ret.ival = __builtin_amdgcn_cvt_scalef32_sr_bf8_f32(ret.ival, v[0], rng, scale, 0);
f8x2[0] = ret.v2f8x2(Number<0>{})[0];
ret.ival = __builtin_amdgcn_cvt_scalef32_sr_bf8_f32(ret.ival, v[1], rng, scale, 0);
f8x2[1] = ret.v2f8x2(Number<0>{})[0];
}
return f8x2;
}
else
{
// RNE CVT
// llvm.amdgcn.cvt.scalef32.pk.fp8.f32
// v2i16 old_vdst, float srcA, float srcB, float scale, bool dst_lo_hi_sel
if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
{
// If fval / scale > max fp8, returns NaN
ret.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_fp8_f32(/*old_vdst*/ ret.v2i16,
v[0],
v[1],
scale,
/*dst_lo_hi_sel*/ false);
}
else
{
// If fval / scale > max bf8, returns Inf
ret.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_bf8_f32(/*old_vdst*/ ret.v2i16,
v[0],
v[1],
scale,
/*dst_lo_hi_sel*/ false);
}
return ret.v2f8x2(Number<0>{});
}
}
#endif // CK_MX_FP8_CVT_FAST_PATH
#if CK_MX_FP8_CVT_FAST_PATH
/**
* \brief convert float to @p fp8_storage_t with scaling
*
* This version is used when the fast path (MX FP8 hardware) is available
*
* \tparam interp interpretation of fp8
* \param f float number
* \param scale scaling factor
* \return fp8_storage_t
*/
template <ck_fp8_interpretation_t interp, bool stochastic_rounding = false>
__host__ __device__ static inline fp8_storage_t cvt_float_to_fp8_scaled(const float f, float scale)
{
__is_interpret_supported(interp);
uint32_t rng = 0;
if constexpr(stochastic_rounding)
{
constexpr int seed = 1254739;
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
}
return cast_to_f8_from_f32_scaled<interp, stochastic_rounding>(f, rng, scale);
}
/**
 * \brief convert 2xfloat to @p 2xfp8_storage_t with scaling
 *
 * This version is used when the fast path (MX FP8 hardware) is available
 *
 * \tparam interp interpretation of fp8
 * \tparam stochastic_rounding use stochastic rounding instead of RNE
 * \param f 2xfloat
 * \param scale scaling factor
 * \return 2xfp8_storage_t
 */
template <ck_fp8_interpretation_t interp, bool stochastic_rounding = false>
__host__ __device__ static inline fp8x2_storage_t cvt_float_to_fp8_scaled(const float2_t f,
float scale)
{
__is_interpret_supported(interp);
uint32_t rng = 0;
if constexpr(stochastic_rounding)
{
constexpr int seed = 1254739;
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f[0]);
}
return cast_to_f8_from_f32_scaled<interp, stochastic_rounding>(f, rng, scale);
}
#else
/**
 * \brief convert float to @p fp8_storage_t with scaling
 *
 * This version is used when the fast path (MX FP8 hardware) is not available
 *
 * \tparam interp interpretation of fp8
 * \tparam stochastic_rounding use stochastic rounding instead of RNE
 * \param f float number
 * \param scale scaling factor
 * \return fp8_storage_t
 */
template <ck_fp8_interpretation_t interp, bool stochastic_rounding = false>
__host__ __device__ static inline fp8_storage_t cvt_float_to_fp8_scaled(const float f, float scale)
{
static_assert(interp == ck_fp8_interpretation_t::CK_E4M3_OCP ||
interp == ck_fp8_interpretation_t::CK_E5M2_OCP,
"Only OCP interpretations are supported");
uint32_t rng = 0;
if constexpr(stochastic_rounding)
{
constexpr int seed = 1254739;
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
}
if constexpr(interp == ck_fp8_interpretation_t::CK_E4M3_OCP)
{
return cast_to_f8<float, 3, 4, false, true, stochastic_rounding>(f / scale, rng);
}
else if constexpr(interp == ck_fp8_interpretation_t::CK_E5M2_OCP)
{
return cast_to_f8<float, 2, 5, false, true, stochastic_rounding>(f / scale, rng);
}
else
{
__hip_assert(false && "FP8 type is not supported by current target device");
return 0;
}
}
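// For reference: in the cast_to_f8<float, wm, we, ...> calls above, wm is the
// mantissa width and we the exponent width, so <float, 3, 4, ...> emulates
// OCP E4M3 and <float, 2, 5, ...> emulates OCP E5M2 in software.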
/**
 * \brief convert 2xfloat to @p 2xfp8_storage_t with scaling
 *
 * This version is used when the fast path (MX FP8 hardware) is not available
 *
 * \tparam interp interpretation of fp8
 * \tparam stochastic_rounding use stochastic rounding instead of RNE
 * \param f 2xfloat
 * \param scale scaling factor
 * \return 2xfp8_storage_t
 */
template <ck_fp8_interpretation_t interp, bool stochastic_rounding = false>
__host__ __device__ static inline fp8x2_storage_t cvt_float_to_fp8_scaled(const float2_t f,
float scale)
{
static_assert(interp == ck_fp8_interpretation_t::CK_E4M3_OCP ||
interp == ck_fp8_interpretation_t::CK_E5M2_OCP,
"Only OCP interpretations are supported");
uint32_t rng = 0;
if constexpr(stochastic_rounding)
{
constexpr int seed = 1254739;
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f[0]);
}
if constexpr(interp == ck_fp8_interpretation_t::CK_E4M3_OCP)
{
return {cast_to_f8<float, 3, 4, false, true, stochastic_rounding>(f[0] / scale, rng),
cast_to_f8<float, 3, 4, false, true, stochastic_rounding>(f[1] / scale, rng)};
}
else if constexpr(interp == ck_fp8_interpretation_t::CK_E5M2_OCP)
{
return {cast_to_f8<float, 2, 5, false, true, stochastic_rounding>(f[0] / scale, rng),
cast_to_f8<float, 2, 5, false, true, stochastic_rounding>(f[1] / scale, rng)};
}
else
{
__hip_assert(false && "FP8 type is not supported by current target device");
return 0;
}
}
#endif // CK_MX_FP8_CVT_FAST_PATH
} // namespace fp8_impl
// Declare a template function for fp8 conversion using SR
template <typename Y, typename X>
__host__ __device__ constexpr Y mxf8_convert_sr(X x, float scale);
// Declare a template function for fp8 conversion using RNE
template <typename Y, typename X>
__host__ __device__ constexpr Y mxf8_convert_rne(X x, float scale);
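// Hedged usage sketch of the two entry points declared above:
//   f8_ocp_t  y_rne = mxf8_convert_rne<f8_ocp_t>(x, scale);  // round to nearest even
//   bf8_ocp_t y_sr  = mxf8_convert_sr<bf8_ocp_t>(x, scale);  // stochastic rounding
// Both quantize x / scale into the requested OCP FP8 format.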
// convert fp32 to fp8 with rounding to nearest even
template <>
inline __host__ __device__ f8_ocp_t mxf8_convert_rne<f8_ocp_t, float>(float x, float scale)
{
return f8_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<f8_ocp_t::default_interpret>(x, scale)};
}
// convert fp32 to bf8 with rounding to nearest even
template <>
inline __host__ __device__ bf8_ocp_t mxf8_convert_rne<bf8_ocp_t, float>(float x, float scale)
{
return bf8_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<bf8_ocp_t::default_interpret>(x, scale)};
}
// convert fp32x2 to fp8x2 with rounding to nearest even
template <>
inline __host__ __device__ f8x2_ocp_t mxf8_convert_rne<f8x2_ocp_t, float2_t>(float2_t x,
float scale)
{
return f8x2_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<f8_ocp_t::default_interpret>(x, scale)};
}
// convert fp32x2 to bf8x2 with rounding to nearest even
template <>
inline __host__ __device__ bf8x2_ocp_t mxf8_convert_rne<bf8x2_ocp_t, float2_t>(float2_t x,
float scale)
{
return bf8x2_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<bf8_ocp_t::default_interpret>(x, scale)};
}
// convert fp32x16 to fp8x16 with rounding to nearest even
template <>
inline __host__ __device__ f8x16_ocp_t mxf8_convert_rne<f8x16_ocp_t, float16_t>(float16_t x,
float scale)
{
union
{
float16_t float_1x16;
float2_t float_2x8[8];
} in{x};
union
{
f8x16_ocp_t fp8_1x16;
f8x2_ocp_t fp8_2x8[8];
} out{};
ck::static_for<0, 8, 1>{}(
[&](auto i) { out.fp8_2x8[i] = mxf8_convert_rne<f8x2_ocp_t>(in.float_2x8[i], scale); });
return out.fp8_1x16;
}
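// Design note: the unions above reinterpret the 16-wide vector as eight
// float2_t / f8x2_ocp_t lanes so the packed 2-wide converter can be reused;
// the bf8 and 32-wide specializations below repeat the same split.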
// convert fp32x16 to bf8x16 with rounding to nearest even
template <>
inline __host__ __device__ bf8x16_ocp_t mxf8_convert_rne<bf8x16_ocp_t, float16_t>(float16_t x,
float scale)
{
union
{
float16_t float_1x16;
float2_t float_2x8[8];
} in{x};
union
{
bf8x16_ocp_t bf8_1x16;
bf8x2_ocp_t bf8_2x8[8];
} out{};
ck::static_for<0, 8, 1>{}(
[&](auto i) { out.bf8_2x8[i] = mxf8_convert_rne<bf8x2_ocp_t>(in.float_2x8[i], scale); });
return out.bf8_1x16;
}
// convert fp32x32 to fp8x32 with rounding to nearest even
template <>
inline __host__ __device__ f8x32_ocp_t mxf8_convert_rne<f8x32_ocp_t, float32_t>(float32_t x,
float scale)
{
union
{
float32_t float_1x32;
float16_t float_16x2[2];
} in{x};
union
{
f8x32_ocp_t fp8_1x32;
f8x16_ocp_t fp8_16x2[2];
} out{};
ck::static_for<0, 2, 1>{}(
[&](auto i) { out.fp8_16x2[i] = mxf8_convert_rne<f8x16_ocp_t>(in.float_16x2[i], scale); });
return out.fp8_1x32;
}
// convert fp32x32 to bf8x32 with rounding to nearest even
template <>
inline __host__ __device__ bf8x32_ocp_t mxf8_convert_rne<bf8x32_ocp_t, float32_t>(float32_t x,
float scale)
{
union
{
float32_t float_1x32;
float16_t float_16x2[2];
} in{x};
union
{
bf8x32_ocp_t bf8_1x32;
bf8x16_ocp_t bf8_16x2[2];
} out{};
ck::static_for<0, 2, 1>{}(
[&](auto i) { out.bf8_16x2[i] = mxf8_convert_rne<bf8x16_ocp_t>(in.float_16x2[i], scale); });
return out.bf8_1x32;
}
// convert fp32 to fp8 with stochastic rounding
template <>
inline __host__ __device__ f8_ocp_t mxf8_convert_sr<f8_ocp_t, float>(float x, float scale)
{
return f8_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<f8_ocp_t::default_interpret, true>(x, scale)};
}
// convert fp32 to bf8 with stochastic rounding
template <>
inline __host__ __device__ bf8_ocp_t mxf8_convert_sr<bf8_ocp_t, float>(float x, float scale)
{
return bf8_ocp_t{
fp8_impl::cvt_float_to_fp8_scaled<bf8_ocp_t::default_interpret, true>(x, scale)};
}
// convert fp32x2 to fp8x2 with stochastic rounding
template <>
inline __host__ __device__ f8x2_ocp_t mxf8_convert_sr<f8x2_ocp_t, float2_t>(float2_t x, float scale)
{
return f8x2_ocp_t{
fp8_impl::cvt_float_to_fp8_scaled<f8_ocp_t::default_interpret, true>(x, scale)};
}
// convert fp32x2 to bf8x2 with stochastic rounding
template <>
inline __host__ __device__ bf8x2_ocp_t mxf8_convert_sr<bf8x2_ocp_t, float2_t>(float2_t x,
float scale)
{
return bf8x2_ocp_t{
fp8_impl::cvt_float_to_fp8_scaled<bf8_ocp_t::default_interpret, true>(x, scale)};
}
// convert fp32x16 to fp8x16 with stochastic rounding
template <>
inline __host__ __device__ f8x16_ocp_t mxf8_convert_sr<f8x16_ocp_t, float16_t>(float16_t x,
float scale)
{
union
{
float16_t float_1x16;
float2_t float_2x8[8];
} in{x};
union
{
f8x16_ocp_t fp8_1x16;
f8x2_ocp_t fp8_2x8[8];
} out{};
ck::static_for<0, 8, 1>{}(
[&](auto i) { out.fp8_2x8[i] = mxf8_convert_sr<f8x2_ocp_t>(in.float_2x8[i], scale); });
return out.fp8_1x16;
}
// convert fp32x16 to bf8x16 with stochastic rounding
template <>
inline __host__ __device__ bf8x16_ocp_t mxf8_convert_sr<bf8x16_ocp_t, float16_t>(float16_t x,
float scale)
{
union
{
float16_t float_1x16;
float2_t float_2x8[8];
} in{x};
union
{
bf8x16_ocp_t bf8_1x16;
bf8x2_ocp_t bf8_2x8[8];
} out{};
ck::static_for<0, 8, 1>{}(
[&](auto i) { out.bf8_2x8[i] = mxf8_convert_sr<bf8x2_ocp_t>(in.float_2x8[i], scale); });
return out.bf8_1x16;
}
// convert fp32x32 to fp8x32 with stochastic rounding
template <>
inline __host__ __device__ f8x32_ocp_t mxf8_convert_sr<f8x32_ocp_t, float32_t>(float32_t x,
float scale)
{
union
{
float32_t float_1x32;
float16_t float_16x2[2];
} in{x};
union
{
f8x32_ocp_t fp8_1x32;
f8x16_ocp_t fp8_16x2[2];
} out{};
ck::static_for<0, 2, 1>{}(
[&](auto i) { out.fp8_16x2[i] = mxf8_convert_sr<f8x16_ocp_t>(in.float_16x2[i], scale); });
return out.fp8_1x32;
}
// convert fp32x32 to bf8x32 with stochastic rounding
template <>
inline __host__ __device__ bf8x32_ocp_t mxf8_convert_sr<bf8x32_ocp_t, float32_t>(float32_t x,
float scale)
{
union
{
float32_t float_1x32;
float16_t float_16x2[2];
} in{x};
union
{
bf8x32_ocp_t bf8_1x32;
bf8x16_ocp_t bf8_16x2[2];
} out{};
ck::static_for<0, 2, 1>{}(
[&](auto i) { out.bf8_16x2[i] = mxf8_convert_sr<bf8x16_ocp_t>(in.float_16x2[i], scale); });
return out.bf8_1x32;
}
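// Hedged usage sketch for the widest variant (types as specialized above):
//   float32_t block = ...;  // 32 values sharing one scale
//   f8x32_ocp_t q = mxf8_convert_sr<f8x32_ocp_t>(block, scale);
// matches the 32-element block size used by MX-scaled formats.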
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck::utils {
union cvt
{
float value_float;
uint32_t value_bitwise;
};
template <typename DTYPE>
inline bool getDataHasInf()
{
return DTYPE::dataInfo.hasInf;
}
template <typename T>
__host__ __device__ inline bool is_zero(e8m0_bexp_t const scale, T const data);
template <typename T>
__host__ __device__ inline bool is_nan(e8m0_bexp_t const scale, T const data);
template <typename T>
__host__ __device__ inline bool is_inf(e8m0_bexp_t const scale, T const data);
template <typename T>
__host__ __device__ inline int get_exponent_value(T x)
{
x >>= NumericUtils<T>::mant;
x &= ((1 << NumericUtils<T>::exp) - 1);
return static_cast<int>(x);
}
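// Worked example (float: mant = 23, exp = 8): for
// x = bit_cast<uint32_t>(1.0f) = 0x3F800000, x >> 23 == 0x7F, and masking
// with (1 << 8) - 1 yields 127, the biased exponent of 1.0f.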
template <typename T>
__host__ __device__ inline bool is_subnormal(T x)
{
return get_exponent_value<T>(x) == 0;
}
template <typename T>
__host__ __device__ inline double get_mantissa_value(T x)
{
    double mantissa = is_subnormal<T>(x) ? 0.0 : 1.0;
for(uint i = 0; i < NumericUtils<T>::mant; i++)
{
mantissa += std::pow(2, -int32_t((NumericUtils<T>::mant - i))) * (x & 0b1);
x >>= 1;
}
return mantissa;
}
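// Worked example (E4M3, mant = 3): mantissa bits 0b100 on a normal value
// contribute 2^-1, so the function returns 1.0 + 0.5 = 1.5.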
template <typename T>
__host__ __device__ inline bool get_data_has_inf()
{
return NumericUtils<T>::has_inf;
}
template <typename T>
__host__ __device__ float convert_to_float(T data, int scale_exp)
{
float d_sign =
std::pow(-1, static_cast<float>(data >> (NumericUtils<T>::exp + NumericUtils<T>::mant)));
float d_exp;
if(is_subnormal<T>(data))
d_exp = std::pow(2, 1 - static_cast<int>(NumericUtils<T>::bias));
else
d_exp = std::pow(2, get_exponent_value<T>(data) - static_cast<int>(NumericUtils<T>::bias));
float d_mant = get_mantissa_value<T>(data);
float data_value = d_sign * d_exp * d_mant;
float scale_value = std::pow(
2, static_cast<float>((scale_exp - static_cast<int>(NumericUtils<e8m0_bexp_t>::bias))));
return data_value * scale_value;
}
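// Worked example (E4M3, bias 7): data = 0x3C (sign 0, exponent 0b0111,
// mantissa 0b100) with scale_exp = 127 gives 1.5 * 2^(127 - 127) = 1.5f.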
template <typename T>
__host__ __device__ inline float to_float(e8m0_bexp_t const scale, T const data);
template <typename T>
__host__ __device__ T sat_convert_to_type(float value);
template <typename T>
__host__ __device__ T sat_convert_to_type_sr(float value, uint32_t seed);
template <typename T>
inline T convert_to_type(float value)
{
using bitwise_type = typename NumericUtils<T>::bitwise_type;
if(std::abs(value) > NumericLimits<T>::Max())
{
float max_value = NumericLimits<T>::Max();
cvt t;
// cppcheck-suppress redundantAssignment
t.value_float = max_value;
uint32_t max_bitwise = t.value_bitwise;
// cppcheck-suppress redundantAssignment
t.value_float = value;
bitwise_type sign =
t.value_bitwise >> (NumericUtils<float>::exp + NumericUtils<float>::mant);
bitwise_type exp =
((max_bitwise >> NumericUtils<float>::mant) & NumericUtils<float>::exp_mask) -
(NumericUtils<float>::bias - NumericUtils<T>::bias);
bitwise_type mantissa = max_bitwise >> (NumericUtils<float>::mant - NumericUtils<T>::mant);
uint32_t mant_prev = max_bitwise >> (NumericUtils<float>::mant - NumericUtils<T>::mant);
mant_prev &= ((1 << NumericUtils<T>::mant) - 1);
mant_prev--;
mant_prev <<= (NumericUtils<float>::mant - NumericUtils<T>::mant);
uint32_t prev_bit =
((max_bitwise >> NumericUtils<float>::mant) << NumericUtils<float>::mant) | mant_prev;
t.value_bitwise = prev_bit;
float prev_val = t.value_float;
float diff = max_value - prev_val;
float actual_max = max_value + (diff / 2);
if(std::abs(value) < actual_max)
{
return sign << ((NumericUtils<T>::exp + NumericUtils<T>::mant)) |
(exp << NumericUtils<T>::mant) | mantissa;
}
else
{
if(!get_data_has_inf<T>())
{
return (1 << (NumericUtils<T>::mant + NumericUtils<T>::exp)) - 1;
}
else
{
exp++;
return sign << ((NumericUtils<T>::exp + NumericUtils<T>::mant)) |
(exp << NumericUtils<T>::mant);
}
}
}
const int mfmt = NumericUtils<float>::mant;
uint32_t x;
x = bit_cast<uint32_t>(value);
uint32_t head, mantissa;
int32_t exponent, bias;
uint32_t sign;
head = x & NumericUtils<float>::head_mask;
mantissa = x & NumericUtils<float>::mant_mask;
exponent = (head >> NumericUtils<float>::mant) & NumericUtils<float>::exp_mask;
sign = head >> (NumericUtils<float>::mant + NumericUtils<float>::exp);
bias = NumericUtils<float>::bias;
if(x == 0)
{
return 0b0;
}
const int mini_bias = NumericUtils<T>::bias;
const int mini_denormal_act_exponent = 1 - mini_bias;
int act_exponent, out_exponent, exponent_diff;
bool is_subnorm = false;
if(exponent == 0)
{
act_exponent = exponent - bias + 1;
exponent_diff = mini_denormal_act_exponent - act_exponent;
is_subnorm = true;
}
else
{
act_exponent = exponent - bias;
if(act_exponent <= mini_denormal_act_exponent)
{
exponent_diff = mini_denormal_act_exponent - act_exponent;
is_subnorm = true;
}
else
{
exponent_diff = 0;
}
mantissa += (1UL << mfmt);
}
auto shift_amount = (mfmt - NumericUtils<T>::mant + exponent_diff);
shift_amount = (shift_amount >= 64) ? 63 : shift_amount;
bool midpoint = (mantissa & ((1UL << shift_amount) - 1)) == (1UL << (shift_amount - 1));
float min_subnorm = NumericLimits<T>::DataMinSubnorm() * (sign ? -1 : 1);
if(is_subnorm && std::abs(value) < std::abs(min_subnorm))
{
// closer to 0
if(std::abs(value) <= std::abs(min_subnorm - value))
return 0;
else
return 1 | (sign << (NumericUtils<T>::exp + NumericUtils<T>::mant));
}
if(exponent_diff > 0)
mantissa >>= exponent_diff;
else if(exponent_diff == -1)
mantissa <<= -exponent_diff;
bool implicit_one = mantissa & (1 << mfmt);
out_exponent = (act_exponent + exponent_diff) + mini_bias - (implicit_one ? 0 : 1);
uint32_t drop_mask = (1UL << (mfmt - NumericUtils<T>::mant)) - 1;
bool odd = mantissa & (1UL << (mfmt - NumericUtils<T>::mant));
mantissa += (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa) & drop_mask;
if(out_exponent == 0)
{
if((1UL << mfmt) & mantissa)
{
out_exponent = 1;
}
}
else
{
if((1UL << (mfmt + 1)) & mantissa)
{
mantissa >>= 1;
out_exponent++;
}
}
mantissa >>= (mfmt - NumericUtils<T>::mant);
if(out_exponent == 0 && mantissa == 0)
{
return 0;
}
mantissa &= (1UL << NumericUtils<T>::mant) - 1;
return (sign << (NumericUtils<T>::exp + NumericUtils<T>::mant)) |
(out_exponent << NumericUtils<T>::mant) | mantissa;
}
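// Worked RNE example (E4M3, mant = 3): value = 1.0625f sits exactly halfway
// between the representable 1.0 (mantissa 0b000) and 1.125 (mantissa 0b001);
// midpoint && !odd rounds to the even mantissa, so the result encodes 1.0.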
template <typename T>
inline T convert_to_type_sr(float value, uint32_t seed)
{
if(std::abs(value) > NumericLimits<T>::Max())
{
float max_value = NumericLimits<T>::Max();
cvt t;
// cppcheck-suppress redundantAssignment
t.value_float = max_value;
        uint32_t max_bitwise = t.value_bitwise;
// cppcheck-suppress redundantAssignment
t.value_float = value;
T sign = t.value_bitwise >> (NumericUtils<float>::exp + NumericUtils<float>::mant);
T exp = ((max_bitwise >> NumericUtils<float>::mant) & NumericUtils<float>::exp_mask) -
(NumericUtils<float>::bias - NumericUtils<T>::bias);
uint32_t mant_prev = max_bitwise >> (NumericUtils<float>::mant - NumericUtils<T>::mant);
mant_prev &= ((1UL << NumericUtils<T>::mant) - 1);
mant_prev--;
mant_prev <<= (NumericUtils<float>::mant - NumericUtils<T>::mant);
uint32_t prev_bit =
((max_bitwise >> NumericUtils<float>::mant) << NumericUtils<float>::mant) | mant_prev;
t.value_bitwise = prev_bit;
float prev_val = t.value_float;
float diff = max_value - prev_val;
float actual_max = max_value + (diff / 2);
if(std::abs(value) < actual_max)
{
double d_max_value = static_cast<double>(max_value);
double d_actual_max = static_cast<double>(actual_max);
double d_value = static_cast<double>(value);
double d_is = std::abs(d_max_value - d_actual_max);
double d_seed = static_cast<double>(seed);
        double d_prob = 1.0 - (std::abs(d_value - d_max_value) / d_is); // probability of rounding down
        double thresh = UINT_MAX * d_prob;
        if(!get_data_has_inf<T>() || d_seed <= thresh)
            // round down to the largest representable normal value
            return sign == 0 ? NumericUtils<T>::data_max_positive_normal_mask
                             : NumericUtils<T>::data_max_negative_normal_mask;
else
{
exp++;
return sign << ((NumericUtils<T>::exp + NumericUtils<T>::mant)) // inf
| (exp << NumericUtils<T>::mant);
}
}
else
{
if(!get_data_has_inf<T>())
return (1 << (NumericUtils<T>::mant + NumericUtils<T>::exp)) - 1;
else
{
exp++;
return sign << ((NumericUtils<T>::exp + NumericUtils<T>::mant)) // inf
| (exp << NumericUtils<T>::mant);
}
}
}
uint32_t f32 = bit_cast<uint32_t>(value);
auto f32_mant = f32 & NumericUtils<float>::mant_mask;
auto head = f32 & NumericUtils<float>::head_mask;
auto f32_exp = (head >> NumericUtils<float>::mant) & NumericUtils<float>::exp_mask;
auto sign_bit = head >> (NumericUtils<float>::mant + NumericUtils<float>::exp);
auto sign = sign_bit << (NumericUtils<T>::exp + NumericUtils<T>::mant);
f32_exp = static_cast<int32_t>(f32_exp) - NumericUtils<float>::bias;
int32_t exp = f32_exp;
auto mant = f32_mant;
bool subnorm = false;
if(f32 == 0)
return 0b0;
if(exp >= NumericUtils<T>::unbiased_exp_min)
{
mant = f32_mant;
}
    // If T has as many exponent bits as f32 (8), a subnormal T matches f32
    // bit-for-bit and needs no shift; otherwise renormalize the value into
    // T's subnormal range.
    else if(exp < NumericUtils<T>::unbiased_exp_min &&
            NumericUtils<T>::exp < NumericUtils<float>::exp)
{
subnorm = true;
auto diff = static_cast<uint32_t>(NumericUtils<T>::unbiased_exp_min - exp);
if(diff >= 32)
{
mant = 0;
f32_mant = 0;
}
else
{
f32_mant |= static_cast<uint32_t>(1) << NumericUtils<float>::mant;
f32_mant >>= diff;
}
exp = 0;
mant = f32_mant;
}
uint32_t sr_shift = NumericUtils<T>::sr_shift;
// For stochastic-rounding we add the aligned random value to the
// mantissa and then truncate (RTZ).
mant += seed >> sr_shift;
// Increment exponent when mantissa overflows due to rounding
if(mant >= static_cast<uint32_t>(1) << NumericUtils<float>::mant)
++exp;
mant >>= (NumericUtils<float>::mant - NumericUtils<T>::mant);
mant &= ((1 << NumericUtils<T>::mant) - 1);
auto biased_exp = static_cast<uint32_t>(exp);
if(!subnorm)
biased_exp = static_cast<uint32_t>(exp + NumericUtils<T>::bias);
biased_exp &= ((1 << NumericUtils<T>::exp) - 1);
auto val = sign | biased_exp << NumericUtils<T>::mant | mant;
return val;
}
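// Worked SR example (E4M3, mant = 3): 23 - 3 = 20 mantissa bits are
// discarded, so with sr_shift presumably 32 - 20 = 12 the addend
// seed >> sr_shift is a 20-bit value aligned below the kept bits; the
// truncation after the add then rounds up with probability equal to the
// discarded fraction, and exp is bumped when the mantissa carries out.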
} // namespace ck::utils