Unverified commit 800cf897, authored by Illia Silin, committed by GitHub

Merge from internal (#1857)



* enable batched_gemm_softmax_gemm_perm_wmma for gfx12

* disable instances with blocksize=256 in attention examples

* debugging

* debug

* fixed lds_enabled

* debugging

* Fix and add limit to skiplds feature

* Enable skipLds feature and fix compilation bugs

* add ck_tile definitions for gfx12

* fix clang format and test/wmma_op

* update instances cmake for gfx12

* disable the test_wmma_op on gfx12

* fix the builds for gfx950

* add gfx12 and gfx950 to default target list

* clean-up cmake file

* Initial introduction of OFP8 data types.

* Renamed FP8 and BF8 tests into FP8_FNUZ and BF8_FNUZ.

* Implementation of ConvertFP32Nearest in test_fp8_ocp.

* Remove dependence on possibly undeclared alias.

* Implement FP8OCP test for stochastic rounding mode.

* Implement FP8OCP tests for half_t type conversions.

* enable bf16 atomic add on gfx950

* Implement ConvertFP32Nearest test.

* Implement ConvertFP32Stochastic test.

* Implement ConvertFP16Nearest and ConvertFP16Stochastic tests.

* Refactoring. Move FP8 definitions into a separate header file.

* Enable easy switching between architectures.

* Fix compilation error for gfx942 architecture.

* Add fp4 type with constants

* only build gfx950 branch for gfx950 target by default

* Enable OCP build of example_gemm_xdl_fp8.

* Fix formatting.

* fix the build logic for gfx950

* Improve GEMM example verbosity.

* Add constexpr where applicable.

* fix the logic of enabling XDL and WMMA instances

* Improve GEMM example verbosity.

* Enable build of example_gemm_xdl_fp8_bf8 test.

* Fix tests for gfx1101 architecture.

* Build DPP examples only on gfx103 and gfx11 architectures.

* Optionally run either CPU or GPU verification with GEMM examples.

* Extend GeneratorTensor_Sequential to produce values of prescribed data types.

* Add missing constructor.

* Add scale type and mxfp conversions

* Update conversions

* Add conversion tests

* Fix typo

* Improve infrastructure for OFP8 data type support.

* BUGFIX. Should not use FP8 as Compute/Accum data type.

* Add custom target for grouped_convnd_bwd_weight tests.

* Can build `tests` target on gfx950.

* Bugfixes on gfx1101 architecture.

* Fix dependencies.

* Add stochastic rounding tests

* Provide single point of truth for FP8 INF and NAN checks

* Prevent instantiation of operators that are not supported by FP8 data types

* Add FP8 type selection into client_example CMakeLists.txt

* Prevent sccache server from shutting down during build

* Fix test success reporting logic

* Change default verification method to CPU.

GPU verification takes too much time to complete on the emulator.

* Add scale <-> float conversions

* Add scaled conversions with tests

* Add device conversions

* Make sure all tests and examples are built for gfx950

* Facilitate testing of FP8 data types on the emulator

* Introduce two new tensor generators

* Enable instances built for gfx94 to be built on gfx950

* Verify 35_splitk_gemm on floating point numbers.

splitk gemm appears to lose precision vs. the reference implementation when FP numbers are involved (see the short illustration below).
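
Split-K repartitions the K-dimension sum across workgroups, and float addition is not associative, so the accumulation order (and hence rounding) differs from the serial reference. A minimal standalone illustration of the effect (editor's sketch, not code from this change):

// Reordered partial sums round differently: the ULP of float at 1.0e8 is 8,
// so addends below 4 vanish individually but survive when summed first.
#include <cstdio>

int main()
{
    float big = 1.0e8f, small = 3.0f;
    float serial  = (big + small) + small; // 1.0e8f: each 3.0f is rounded away
    float split_k = big + (small + small); // 1.00000008e8f: 6.0f rounds up one ULP
    std::printf("serial=%.0f split_k=%.0f\n", serial, split_k);
    return 0;
}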

* Format

* Verify 04_gemm_add_add_fastgelu on floating point numbers

* Verify 20_grouped_conv_bwd_weight on floating point numbers

* Verify 38_grouped_conv_bwd_data_multiple_d on floating point numbers

* Verify more tests on floating point data

* Fix data types and improve testing verbosity.

* Add fp4 vectors

* Add debug tests

* Upgrade to NPI 573 build docker.

* Skip gemm_universal tests.

The tests take too long to complete on the emulator.
Need to see if it is possible to reduce the scope of the testing to just FP8 data types.

* Add new mfma instructions and examples

* Add preprocessor directives for gfx950 specific code

* Fix gfx1101 build

* Document test availability

* Re-enable fp8 gemms for gfx94/95

* Cherry-pick GEMM Universal tests for FP8 data types

* Cleanup

* Add vector types and tests

* Add check_err function

* Add tensor generators

* CK_USE_GFX94 has already been set on this branch

* Fix

* Address formatting issues and leftovers

* Make fail/pass logic consistent within 01_gemm folder

Removed multiple negations in fail/pass logic to propagate `true` as the success indicator.

* Fix GPU verification reporting logic.

* Update year in copyright notice.

* Cleanup

* Use `enum class` instead of `enum`

* Remove set_property for FP8 tests

* Add vector conversions

* Fix

* Fix linker error

* Clean up

* Fix gfx950 conversions

* Clean up

* Fix more gfx950 conversions

* Fix even more gfx950 conversions

* Narrowing the scope of PR to OCP FP8 enablement only

* Add tests for OCP FP8 vector_type storage

* Fix client examples build

* Fix typo

* Update e8m0 casting

* Rename E8M0 type

* Update unpack method

* Cleanup merge artifacts

* Enable gemm kernel on all gfx9 architectures (#227)

* clean-up

* Implement `non_native_vector_base` with `ext_vector_type` array. (#232)

* Enable support of 1, 2, 4, and 8-byte custom types in CK.

* Fix pool tests for OCP FP8 data type

* Fix build

* Add ckProfiler gemm instances for new mfma instructions and fix ckProfiler build on MI350

* fix clang format

* Add new mfma instructions and examples

* Add preprocessor directives for gfx950 specific code

* Add ckProfiler gemm instances for new mfma instructions and fix ckProfiler build on MI350

* fix clang format

* Fix clang format for the newly merged files

* Use the existing example instances for fp16, bf16 and int8

* Remove comment on new mfma instructions in MfmaInstr

* Update include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp
Co-authored-by: Andriy Roshchenko <107577548+andriy-ca@users.noreply.github.com>

* merge from public repo

* Fix ck build

* Fix ck build

* Use double for max_abs_in_val

* Move scaled_type_convert functions to a separate header (#251)

* re-enable building mha lib and gemm_universal_f8 instances for gfx950

* Update library/src/tensor_operation_instance/gpu/CMakeLists.txt
Co-authored-by: Andriy Roshchenko <107577548+andriy-ca@users.noreply.github.com>

* fix typo for CK_USE_OCP_FP8

* fix typo for CK_USE_OCP_FP8

* Add FP6 and BF6 types (#261)

* Add a rounding flag

* Add FP6 and BF6

* Add tests
Co-authored-by: Andriy Roshchenko <107577548+andriy-ca@users.noreply.github.com>

* Clean up

---------
Co-authored-by: Andriy Roshchenko <107577548+andriy-ca@users.noreply.github.com>

* fix one more typo

* Refactor E8M0 scale implementation (#262)

* Refactor E8M0 scale implementation

* Add MXFP6 and MXBF6 conversion methods (#270)

* Add conversions

* Add tests

* Add docstrings

* Add scaled conversions

* Add fp6/bf6 tests

* Remove misleading fp4 test case

* Add docstrings

* Clean up

* Address comments

* Set stricter tolerances for RNE tests

* Add missing tests

* Add native conversions to float

* Revert "Add native conversions to float"

This reverts commit 09467111f73b753c8cc3d597533b187940353dab.

* Update copyright years

* replace the fp6 with bf6 convert calls in test_bf6

* fix test_bf6

* enable smfmac test

* [MX FP8] Add Scaled Type Convert Functions for OCP FP8/BF8 data types (#271)

* Move scaled_type_convert functions to a separate header

* Introduce MX data tests

* Build MX tests only on relevant architectures

* Refactor E8M0 scale implementation

* Fix `config.h` typo

* Cleanup deprecated symbols

* Refactor `amd_ck_fp8.hpp`

* `scaled_type_convert` for `f8_ocp_t`

* Implement test for MX FP8 scaled type convert

* Implement test for MX BF8 scaled type convert

* Scaled type convert for vectors of 2 FP8 elements

* Scaled type convert for vectors of 16 FP8 elements

* Implementation of scaled conversion from F32 to F8

* Add tests for scaled conversions from FP32 to FP8

* Add documentation to the test functions

* Implementation of scaled conversion from F32x2 to F8x2

* Implementation of scaled conversion from F32x16 to F8x16

* Implementation of scaled conversion from F32x32 to F8x32

* Implementation of scaled conversion from F8x32 to F32x32

* Verified on the emulator

* MX FP GEMM - Example Template (#277)

Temporarily uses `DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3` kernel and 128x128 scaling matrices.
Must be modified to use an MX-native GEMM kernel with 16 or 32 component vectors per scale.

Verified on the emulator.
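
The semantics being verified can be pictured with a host-side reference loop: every block of BLK elements along K shares one power-of-two (e8m0) scale per operand. A rough sketch with hypothetical names and an assumed row-major layout (the temporary kernel above uses BLK = 128; MX-native kernels use 16 or 32), not code from this change:

#include <cmath>
#include <cstddef>
#include <cstdint>

// C[m][n] = sum_k (2^(sa-127) * A[m][k]) * (2^(sb-127) * B[k][n]),
// where sa/sb are the e8m0 scales of the K-block containing k.
void mx_gemm_ref(const float* A, const std::uint8_t* scale_a, // [M][K/BLK]
                 const float* B, const std::uint8_t* scale_b, // [K/BLK][N]
                 float* C, std::size_t M, std::size_t N, std::size_t K, std::size_t BLK)
{
    for(std::size_t m = 0; m < M; ++m)
        for(std::size_t n = 0; n < N; ++n)
        {
            double acc = 0.0;
            for(std::size_t k = 0; k < K; ++k)
                acc += std::ldexp(static_cast<double>(A[m * K + k]),
                                  scale_a[m * (K / BLK) + k / BLK] - 127) *
                       std::ldexp(static_cast<double>(B[k * N + n]),
                                  scale_b[(k / BLK) * N + n] - 127);
            C[m * N + n] = static_cast<float>(acc);
        }
}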

* Add vector support

* Add tests

* Add missing type aliases

* Fix test naming

* only build mx example for gfx950

* disable CK_USE_AMD_MFMA_GFX950 by default

* fix build for multiple archs

* fix typo

* fix typo

* Update unpack signature

* Fix merge

* Add size checks in pack function

* Add a flag

* Add conversions

* Fix build logic

* Update pack/unpack methods

* Remove unneeded AsType accessors

* Add docstrings

* Add a flag to config file

* Test the functionality of V_MFMA_F32_16X16X128_F8F6F4 and  V_MFMA_F32_32X32X64_F8F6F4 instructions. (#293)

* Introduced MFMA tests

* Verified f8f6f4 MFMA Instructions

* Move flag logic to scaled_type_convert header

* Use pointers instead of array indices

* Fix a typo

* Update tests and pack functions

* Fix gemm gemm on gfx950

* Fix clang format

* restore the default gpu target lists

* fix the jenkinsfile

* add missing ifdef

---------
Co-authored-by: Jing Zhang <jizhan@amd.com>
Co-authored-by: aska-0096 <haocwang@amd.com>
Co-authored-by: Jun Liu <Liu.Jun@amd.com>
Co-authored-by: Andriy Roshchenko <andriy.roshchenko@amd.com>
Co-authored-by: Rostyslav Geyyer <rosty.geyyer@amd.com>
Co-authored-by: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com>
Co-authored-by: root <root@banff-cyxtera-s83-2.ctr.dcgpu>
Co-authored-by: Andriy Roshchenko <107577548+andriy-ca@users.noreply.github.com>
Co-authored-by: jefyang1 <146495389+jefyang1@users.noreply.github.com>
Co-authored-by: jefyang1 <Jeffreyj.Yang@amd.com>
parent 85d6fcd3
@@ -607,6 +607,7 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
// with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
// cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
// therefore we may just as well assign Gemm1KPack = group_size
constexpr index_t Gemm1KPack =
MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.group_size;
...
@@ -856,11 +856,18 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
static_cast<A0B0B1DataType*>(p_shared) + SharedMemTrait::b1_block_space_offset,
b1_block_desc_bk0_n_bk1.GetElementSpaceSize());
// selected_mfma.group_size or B1K1 <= Gemm1KPack <= selected_mfma.group_size
// selected_mfma.k_per_blk <= Gemm1KPack
//
// Following similar rationale behind Gemm0KPack, let Gemm1KPack be the lowest common
// multiples of A1K1 (predetermined by selected_mfma.group_size) and B1K1. But in this case
// Gemm1KPack can't be higher than A1K1 itself because A1 matrix is distributed in VGPRs
// with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
// cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
// therefore we may just as well assign Gemm1KPack = group_size
constexpr index_t Gemm1KPack =
MfmaSelector<A0B0B1DataType, Gemm0MPerXdl, Gemm0NPerXdl>::selected_mfma.group_size;
auto blockwise_gemm1 = BlockwiseGemmXdlops_v2<
BlockSize,
...
@@ -773,6 +773,7 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle
// with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
// cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
// therefore we may just as well assign Gemm1KPack = group_size
constexpr index_t Gemm1KPack =
MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.group_size;
...
@@ -628,6 +628,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
// with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
// cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
// therefore we may just as well assign Gemm1KPack = group_size
constexpr index_t Gemm1KPack =
MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.group_size;
...
@@ -37,7 +37,17 @@ enum struct MfmaInstr
mfma_f32_32x32x16f8bf8,
mfma_f32_16x16x32f8bf8,
mfma_f32_32x32x16bf8f8,
mfma_f32_16x16x32bf8f8,
mfma_f32_32x32x16f16,
mfma_f32_16x16x32f16,
mfma_f32_32x32x16bf16,
mfma_f32_16x16x32bf16,
mfma_i32_32x32x32i8,
mfma_i32_16x16x64i8,
mfma_f32_32x32x64f8f6f4,
mfma_f32_16x16x128f8f6f4,
mfma_scale_f32_32x32x64f8f6f4,
mfma_scale_f32_16x16x128f8f6f4
};
template <MfmaInstr instr>
@@ -198,6 +208,50 @@ struct mfma_type<MfmaInstr::mfma_f32_32x32x8f16>
}
};
template <>
struct mfma_type<MfmaInstr::mfma_f32_32x32x16f16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 4;
static constexpr index_t num_regs_per_blk = 16;
static constexpr index_t num_threads_per_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 2;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 32;
static constexpr index_t n_per_blk = 32;
static constexpr index_t k_per_blk = 8;
static constexpr bool is_k_reduction = true;
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_f32_32x32x16f16<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
}
};
template <>
struct mfma_type<MfmaInstr::mfma_f32_16x16x32f16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_per_blk = 4;
static constexpr index_t num_threads_per_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 4;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 16;
static constexpr index_t n_per_blk = 16;
static constexpr index_t k_per_blk = 8;
static constexpr bool is_k_reduction = true;
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_f32_16x16x32f16<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
}
};
template <>
struct mfma_type<MfmaInstr::mfma_f32_16x16x16f16>
{
@@ -264,6 +318,28 @@ struct mfma_type<MfmaInstr::mfma_f32_4x4x4f16>
}
};
template <>
struct mfma_type<MfmaInstr::mfma_f32_32x32x16bf16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 4;
static constexpr index_t num_regs_per_blk = 16;
static constexpr index_t num_threads_per_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 2;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 32;
static constexpr index_t n_per_blk = 32;
static constexpr index_t k_per_blk = 8;
static constexpr bool is_k_reduction = true;
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_f32_32x32x16bf16<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
}
};
template <>
struct mfma_type<MfmaInstr::mfma_f32_32x32x8bf16_1k>
{
@@ -286,6 +362,28 @@ struct mfma_type<MfmaInstr::mfma_f32_32x32x8bf16_1k>
}
};
template <>
struct mfma_type<MfmaInstr::mfma_f32_16x16x32bf16>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_per_blk = 4;
static constexpr index_t num_threads_per_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 4;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 16;
static constexpr index_t n_per_blk = 16;
static constexpr index_t k_per_blk = 8;
static constexpr bool is_k_reduction = true;
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_f32_16x16x32bf16<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
}
};
template <>
struct mfma_type<MfmaInstr::mfma_f32_16x16x16bf16_1k>
{
@@ -440,6 +538,50 @@ struct mfma_type<MfmaInstr::mfma_i32_16x16x32i8>
}
};
template <>
struct mfma_type<MfmaInstr::mfma_i32_32x32x32i8>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 4;
static constexpr index_t num_regs_per_blk = 16;
static constexpr index_t num_threads_per_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 2;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 32;
static constexpr index_t n_per_blk = 32;
static constexpr index_t k_per_blk = 16;
static constexpr bool is_k_reduction = true;
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_i32_32x32x32i8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
}
};
template <>
struct mfma_type<MfmaInstr::mfma_i32_16x16x64i8>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_per_blk = 4;
static constexpr index_t num_threads_per_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 4;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 16;
static constexpr index_t n_per_blk = 16;
static constexpr index_t k_per_blk = 16;
static constexpr bool is_k_reduction = true;
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_i32_16x16x64i8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
}
};
template <>
struct mfma_type<MfmaInstr::mfma_f64_16x16x4f64>
{
@@ -638,16 +780,115 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32bf8f8>
}
};
// TODO: fix mfma...f8f6f4 instructions
template <>
struct mfma_type<MfmaInstr::mfma_f32_32x32x64f8f6f4>
{
// clang-format off
static constexpr index_t group_size = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk
static constexpr index_t num_groups_per_blk = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk
static constexpr index_t num_regs_per_blk = 16; // m_per_blk * n_per_blk / wave_size
static constexpr index_t num_threads_per_blk = 32; // n_per_blk
static constexpr index_t wave_size = 64; // fixed
static constexpr index_t num_input_blks = 2; // m_per_blk / num_regs_per_blk
static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ???
static constexpr index_t m_per_blk = 32; // from the instruction
static constexpr index_t n_per_blk = 32; // from the instruction
static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? 64 / num_input_blks
static constexpr bool is_k_reduction = true; // ???
// clang-format on
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_f32_32x32x64f8f6f4<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
}
};
template <>
struct mfma_type<MfmaInstr::mfma_f32_16x16x128f8f6f4>
{
// clang-format off
static constexpr index_t group_size = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk
static constexpr index_t num_groups_per_blk = 1; // ??? group_size * num_groups_per_blk == num_regs_per_blk
static constexpr index_t num_regs_per_blk = 4; // m_per_blk * n_per_blk / wave_size
static constexpr index_t num_threads_per_blk = 16; // == n_per_blk
static constexpr index_t wave_size = 64; // fixed
static constexpr index_t num_input_blks = 4; // m_per_blk / num_regs_per_blk
static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ???
static constexpr index_t m_per_blk = 16; // from the instruction
static constexpr index_t n_per_blk = 16; // from the instruction
static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? 128 / num_input_blks
static constexpr bool is_k_reduction = true; // ???
// clang-format on
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_f32_16x16x128f8f6f4<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
}
};
template <>
struct mfma_type<MfmaInstr::mfma_scale_f32_32x32x64f8f6f4>
{
// clang-format off
static constexpr index_t group_size = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk
static constexpr index_t num_groups_per_blk = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk
static constexpr index_t num_regs_per_blk = 16; // m_per_blk * n_per_blk / wave_size
static constexpr index_t num_threads_per_blk = 32; // n_per_blk
static constexpr index_t wave_size = 64; // fixed
static constexpr index_t num_input_blks = 2; // m_per_blk / num_regs_per_blk
static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ???
static constexpr index_t m_per_blk = 32; // from the instruction
static constexpr index_t n_per_blk = 32; // from the instruction
static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? 64 / num_input_blks
static constexpr bool is_k_reduction = true; // ???
// clang-format on
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_scale_f32_32x32x64f8f6f4<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
}
};
template <>
struct mfma_type<MfmaInstr::mfma_scale_f32_16x16x128f8f6f4>
{
// clang-format off
static constexpr index_t group_size = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk
static constexpr index_t num_groups_per_blk = 1; // ??? group_size * num_groups_per_blk == num_regs_per_blk
static constexpr index_t num_regs_per_blk = 4; // m_per_blk * n_per_blk / wave_size
static constexpr index_t num_threads_per_blk = 16; // == n_per_blk
static constexpr index_t wave_size = 64; // fixed
static constexpr index_t num_input_blks = 4; // m_per_blk / num_regs_per_blk
static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ???
static constexpr index_t m_per_blk = 16; // from the instruction
static constexpr index_t n_per_blk = 16; // from the instruction
static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? 128 / num_input_blks
static constexpr bool is_k_reduction = true; // ???
// clang-format on
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_scale_f32_16x16x128f8f6f4<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
}
};
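
The derivation comments above can be cross-checked numerically; for the 32x32x64 f8f6f4 tile the usual invariants hold. A sanity sketch (illustrative only, not code from this change):

static_assert(4 * 4 == 16, "group_size * num_groups_per_blk == num_regs_per_blk");
static_assert(32 * 32 / 64 == 16, "num_regs_per_blk == m_per_blk * n_per_blk / wave_size");
static_assert(64 / 2 == 32, "k_per_blk == K (64) / num_input_blks under k-reduction");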
template <typename base_type,
index_t MPerXdlops,
index_t NPerXdlops,
typename additional_type = base_type,
bool is_single_rate_mfma = false>
struct MfmaSelector
{
template <typename base_type_,
index_t MPerXdlops_,
index_t NPerXdlops_,
typename additional_type_ = base_type_,
bool is_single_rate_mfma_ = false>
static constexpr auto GetMfma();
template <>
@@ -711,13 +952,32 @@ struct MfmaSelector
}
template <>
constexpr auto GetMfma<half_t, 32, 32, half_t, false>()
{
#if defined(__gfx950__)
return MfmaInstr::mfma_f32_32x32x16f16;
#else
return MfmaInstr::mfma_f32_32x32x8f16;
#endif
}
template <>
constexpr auto GetMfma<half_t, 32, 32, half_t, true>()
{
return MfmaInstr::mfma_f32_32x32x8f16;
}
template <>
constexpr auto GetMfma<half_t, 16, 16, half_t, false>()
{
#if defined(__gfx950__)
return MfmaInstr::mfma_f32_16x16x32f16;
#else
return MfmaInstr::mfma_f32_16x16x16f16;
#endif
}
template <>
constexpr auto GetMfma<half_t, 16, 16, half_t, true>()
{
return MfmaInstr::mfma_f32_16x16x16f16;
}
@@ -741,7 +1001,19 @@ struct MfmaSelector
}
template <>
constexpr auto GetMfma<bhalf_t, 32, 32, bhalf_t, false>()
{
#if defined(__gfx950__)
return MfmaInstr::mfma_f32_32x32x16bf16;
#elif defined(CK_USE_AMD_MFMA_BF16_1K_OP)
return MfmaInstr::mfma_f32_32x32x8bf16_1k;
#else
return MfmaInstr::mfma_f32_32x32x4bf16;
#endif
}
template <>
constexpr auto GetMfma<bhalf_t, 32, 32, bhalf_t, true>()
{
#if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
return MfmaInstr::mfma_f32_32x32x8bf16_1k;
@@ -751,7 +1023,19 @@ struct MfmaSelector
}
template <>
constexpr auto GetMfma<bhalf_t, 16, 16, bhalf_t, false>()
{
#if defined(__gfx950__)
return MfmaInstr::mfma_f32_16x16x32bf16;
#elif defined(CK_USE_AMD_MFMA_BF16_1K_OP)
return MfmaInstr::mfma_f32_16x16x16bf16_1k;
#else
return MfmaInstr::mfma_f32_16x16x8bf16;
#endif
}
template <>
constexpr auto GetMfma<bhalf_t, 16, 16, bhalf_t, true>()
{
#if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
return MfmaInstr::mfma_f32_16x16x16bf16_1k;
@@ -760,7 +1044,18 @@ struct MfmaSelector
#endif
}
#if defined(__gfx950__)
template <>
constexpr auto GetMfma<int8_t, 32, 32>()
{
return MfmaInstr::mfma_i32_32x32x32i8;
}
template <>
constexpr auto GetMfma<int8_t, 16, 16>()
{
return MfmaInstr::mfma_i32_16x16x64i8;
}
#elif defined(__gfx942__)
template <>
constexpr auto GetMfma<int8_t, 32, 32>()
{
@@ -832,8 +1127,8 @@ struct MfmaSelector
return MfmaInstr::mfma_f32_16x16x32bf8f8;
}
static constexpr auto selected_mfma = mfma_type<
GetMfma<base_type, MPerXdlops, NPerXdlops, additional_type, is_single_rate_mfma>()>{};
__host__ __device__ constexpr MfmaSelector()
{
@@ -1135,7 +1430,13 @@ struct XdlopsGemm
return TransposeC ? CIndex4D{blk_td, I0, blk_id, I0} : CIndex4D{I0, blk_id, I0, blk_td};
}
// Falls back to single rate instruction on gfx950 if KPack <= 4; no change on gfx942-
static constexpr auto
mfma = MfmaSelector < base_type,
MPerXdlops, NPerXdlops, additional_type,
((is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value) && KPack <= 4)
? true
: false > {};
static constexpr auto mfma_instr = mfma.selected_mfma;
...
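
The fallback comment above amounts to a simple predicate: the dual-rate gfx950 f16/bf16 MFMAs consume 8-element K packs, so a buffer with KPack <= 4 has to keep the legacy single-rate instruction. A standalone sketch of that rule (assumed rationale, names hypothetical, not code from this change):

template <bool IsF16OrBf16, int KPack>
constexpr bool use_single_rate_mfma = IsF16OrBf16 && (KPack <= 4);

static_assert(use_single_rate_mfma<true, 4>, "half_t, KPack = 4 -> single-rate mfma");
static_assert(!use_single_rate_mfma<true, 8>, "half_t, KPack = 8 -> dual-rate on gfx950");
static_assert(!use_single_rate_mfma<false, 4>, "other types keep their default selection");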
@@ -581,7 +581,7 @@ __device__ void amd_global_atomic_add_impl(const typename vector_type<T, N>::typ
tmp.template AsType<half2_t>()[i]);
});
}
#if defined(__gfx942__) || defined(__gfx950__)
else if constexpr(is_same<T, bhalf_t>::value)
{
vector_type<bhalf_t, N> tmp{src_thread_data};
...
@@ -20,39 +20,25 @@
#define CK_USE_OCP_FP8 0
#endif
namespace {
// https://en.cppreference.com/w/cpp/types/conditional
template <bool B, class T, class F>
struct conditional
{
using type = T;
};
template <class T, class F>
struct conditional<false, T, F>
{
using type = F;
};
} // namespace
namespace ck {
using f8_fnuz_t = _BitInt(8);
using bf8_fnuz_t = unsigned _BitInt(8);
#if(defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx1200__) || \
defined(__gfx1201__) || defined(__gfx950__)) && \
__HIP_DEVICE_COMPILE__
#define CK_FP8_CVT_FAST_PATH 1
#else
#define CK_FP8_CVT_FAST_PATH 0
#endif
#if(defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx950__)) && __HIP_DEVICE_COMPILE__
#define CK_OCP_FP8_CVT_FAST_PATH 1
#else
#define CK_OCP_FP8_CVT_FAST_PATH 0
#endif
namespace ck {
using f8_fnuz_t = _BitInt(8);
using bf8_fnuz_t = unsigned _BitInt(8);
typedef unsigned char fp8_storage_t;
/**
@@ -207,10 +193,11 @@ __host__ __device__ static inline T cast_from_f8(fp8_storage_t x)
}
}
typename std::conditional<
sizeof(T) == 2,
unsigned short int,
typename std::conditional<sizeof(T) == 4, unsigned int, unsigned long long>::type>::type
retval;
if constexpr(we == 5 && is_half && !is_fnuz)
{
@@ -303,7 +290,6 @@ static __device__ float2_t cast_to_f32x2_from_f8x2(fp8x2_storage_t v)
return __builtin_amdgcn_cvt_pk_f32_bf8(i16val, false);
}
}
#endif
} // namespace fp8_impl
@@ -378,7 +364,7 @@ struct bf8_ocp_t
__host__ explicit operator float() const
#endif
{
#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__)
return fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data);
#else
return fp8_impl::cast_from_f8<float, wm, we, false>(
@@ -392,7 +378,7 @@ struct bf8_ocp_t
__host__ explicit operator _Float16() const
#endif
{
#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__)
return static_cast<_Float16>(fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data));
#else
return fp8_impl::cast_from_f8<_Float16, wm, we, false>(
@@ -553,10 +539,10 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn
constexpr int mfmt = (sizeof(T) == 8) ? 52 : ((sizeof(T) == 4) ? 23 : 10);
using T_bitwise = typename std::conditional<
sizeof(T) == 2,
unsigned short int,
typename std::conditional<sizeof(T) == 4, unsigned int, unsigned long long>::type>::type;
T_bitwise x_bitwise = bit_cast<T_bitwise>(_x);
unsigned long long x{x_bitwise};
...
@@ -5,7 +5,7 @@
namespace ck {
// Define the common macro for MI300 models
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx950__)
#define __gfx94__
#endif
@@ -134,6 +134,46 @@ struct intrin_mfma_f32_32x32x4f16<32, 64>
}
};
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16f16;
template <>
struct intrin_mfma_f32_32x32x16f16<32, 32>
{
template <class FloatC>
__device__ static void Run(const half8_t& reg_a, const half8_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx950__)
reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x16_f16(
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif // defined(__gfx950__)
}
};
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x32f16;
template <>
struct intrin_mfma_f32_16x16x32f16<16, 16>
{
template <class FloatC>
__device__ static void Run(const half8_t& reg_a, const half8_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx950__)
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_f16(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif // defined(__gfx950__)
}
};
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x8f16;
@@ -204,6 +244,46 @@ struct intrin_mfma_f32_4x4x4f16<8, 64>
};
// bfp16
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16bf16;
template <>
struct intrin_mfma_f32_32x32x16bf16<32, 32>
{
template <class FloatC>
__device__ static void Run(const bhalf8_t& reg_a, const bhalf8_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx950__)
reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x16_bf16(
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif // defined(__gfx950__)
}
};
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x32bf16;
template <>
struct intrin_mfma_f32_16x16x32bf16<16, 16>
{
template <class FloatC>
__device__ static void Run(const bhalf8_t& reg_a, const bhalf8_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx950__)
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf16(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif // defined(__gfx950__)
}
};
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x8bf16_1k;
@@ -298,6 +378,46 @@ struct intrin_mfma_i32_16x16x16i8<16, 16>
}
};
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_i32_32x32x32i8;
template <>
struct intrin_mfma_i32_32x32x32i8<32, 32>
{
template <class FloatC>
__device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx950__)
reg_c.template AsType<int32x16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_i32_32x32x32_i8(
reg_a, reg_b, reg_c.template AsType<int32x16_t>()[Number<0>{}], 0, 0, 0);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif // defined(__gfx950__)
}
};
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_i32_16x16x64i8;
template <>
struct intrin_mfma_i32_16x16x64i8<16, 16>
{
template <class FloatC>
__device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx950__)
reg_c.template AsType<int32x4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_i32_16x16x64_i8(
reg_a, reg_b, reg_c.template AsType<int32x4_t>()[Number<0>{}], 0, 0, 0);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif // defined(__gfx950__)
}
};
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_i32_32x32x16i8;
@@ -356,6 +476,149 @@ struct intrin_mfma_f64_16x16x4f64<16, 16>
}
};
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x64f8f6f4;
/// @brief Performs a matrix fused multiply-accumulate operation on 32x32x64 submatrices for f8, f6,
/// and f4 data types.
///
/// @note Calls scaled version of the instruction as the original instruction is not supported in
/// the backend. That is the intended use. There is a backend optimization to select the unscaled
/// operation if the scale is 0.
template <>
struct intrin_mfma_f32_32x32x64f8f6f4<32, 32>
{
template <class FloatC>
__device__ static void Run(const f8x32_t& reg_a, const f8x32_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx950__)
reg_c.template AsType<float16_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
reg_a,
reg_b,
reg_c.template AsType<float16_t>()[Number<0>{}],
0, // cbsz
0, // blgp
0,
0,
0,
0);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif
}
};
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_scale_f32_32x32x64f8f6f4;
template <>
struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32>
{
template <class FloatC>
__device__ static void Run(const f8x32_t& reg_a,
const int32_t scale_a,
const f8x32_t& reg_b,
const int32_t scale_b,
FloatC& reg_c)
{
#if defined(__gfx950__)
// https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
reg_c.template AsType<float16_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
reg_a,
reg_b,
reg_c.template AsType<float16_t>()[Number<0>{}],
0, // cbsz
0, // blgp
0, // { OPSEL_HI[0], OPSEL[0] }?
scale_a,
0, // { OPSEL_HI[1], OPSEL[1] }?
scale_b);
#else
ignore = reg_a;
ignore = scale_a;
ignore = reg_b;
ignore = scale_b;
ignore = reg_c;
#endif
}
};
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_scale_f32_16x16x128f8f6f4;
template <>
struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>
{
template <class FloatC>
__device__ static void Run(const f8x32_t& reg_a,
const int32_t scale_a,
const f8x32_t& reg_b,
const int32_t scale_b,
FloatC& reg_c)
{
#if defined(__gfx950__)
// https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
reg_c.template AsType<float4_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
reg_a,
reg_b,
reg_c.template AsType<float4_t>()[Number<0>{}],
0, // cbsz
0, // blgp
0, // { OPSEL_HI[0], OPSEL[0] }?
scale_a,
0, // { OPSEL_HI[1], OPSEL[1] }?
scale_b);
#else
ignore = reg_a;
ignore = scale_a;
ignore = reg_b;
ignore = scale_b;
ignore = reg_c;
#endif
}
};
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x128f8f6f4;
/// @brief Performs a matrix fused multiply-accumulate operation on 16x16x128 submatrices for f8f6f4
/// data types.
///
/// @note Calls scaled version of the instruction as the original instruction is not supported in
/// the backend. That is the intended use. There is a backend optimization to select the unscaled
/// operation if the scale is 0.
template <>
struct intrin_mfma_f32_16x16x128f8f6f4<16, 16>
{
template <class FloatC>
__device__ static void Run(const f8x32_t& reg_a, const f8x32_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx950__)
// https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
reg_c.template AsType<float4_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
reg_a,
reg_b,
reg_c.template AsType<float4_t>()[Number<0>{}],
0, // cbsz
0, // blgp
0,
0,
0,
0);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif
}
};
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16f8f8;
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/type.hpp"
namespace ck {
/**
* @brief Unsigned representation of a conventional biased Float32 exponent.
*
* bias = 127;
*
* E8M0_1 = 0b01111111; => 2^(127-127) = 1
* E8M0_2 = 0b10000000; => 2^(128-127) = 2^1 = 2
* E8M0_3 = 0b10000010; => 2^(130-127) = 2^3 = 8
* E8M0_135 = 0b10000111; => 2^(135-127) = 2^8 = 256
* E8M0_142 = 0b10001110; => 2^(142-127) = 2^15 = 32768
* E8M0_MIN = 0b00000000; => 2^-127
* E8M0_MAX = 0b11111110; => 2^127
* E8M0_NAN = 0b11111111; => NaN
*/
struct e8m0_bexp_t
{
using type = uint8_t;
type data;
constexpr static type bias = 127;
constexpr static type nan_mask = 0xFF;
__host__ __device__ constexpr e8m0_bexp_t() : data{type{}} {}
__host__ __device__ constexpr e8m0_bexp_t(type init) : data{init} {}
__host__ __device__ constexpr e8m0_bexp_t(int init) : data{static_cast<type>(init & nan_mask)}
{
}
__host__ __device__ explicit constexpr e8m0_bexp_t(float scale)
: data{static_cast<type>((bit_cast<uint32_t>(scale) & (nan_mask << 23)) >> 23)}
{
}
__host__ __device__ explicit constexpr operator float() const
{
if(data == nan_mask || data == 0)
{
uint32_t bits = data << 1;
bits |= 1;
bits <<= 22;
return bit_cast<float>(bits);
}
else
{
uint32_t bits = data << 23;
return bit_cast<float>(bits);
}
}
__host__ __device__ constexpr bool operator==(const e8m0_bexp_t& other) const
{
// strict IEEE compliance for NaN
return data == other.data && data != nan_mask;
}
__host__ __device__ constexpr bool is_nan() const { return data == nan_mask; }
};
namespace utils {
template <typename T>
__host__ __device__ inline int get_exponent_value(T x);
template <>
__host__ __device__ inline int get_exponent_value<e8m0_bexp_t>(e8m0_bexp_t x)
{
return x.data;
}
} // namespace utils
} // namespace ck
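
For reference, the decode path of operator float() above can be replayed on the host; the special cases exist because 2^-127 and NaN have no normal-range float encoding with a zero mantissa. A minimal standalone sketch, assuming the layout documented above (illustrative, not part of the diff):

#include <cstdint>
#include <cstring>

float e8m0_to_float(std::uint8_t data)
{
    std::uint32_t bits;
    if(data == 0xFF || data == 0) // NaN and 2^-127 need mantissa bits set
        bits = ((static_cast<std::uint32_t>(data) << 1) | 1u) << 22;
    else
        bits = static_cast<std::uint32_t>(data) << 23; // biased exponent, zero mantissa
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f; // e8m0_to_float(0x7F) == 1.0f, e8m0_to_float(0x80) == 2.0f
}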
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/mxfp_utils.hpp"
namespace ck::utils {
template <>
__host__ __device__ inline bool is_nan<f4_t>(e8m0_bexp_t const scale,
f4_t const dataBytes [[maybe_unused]])
{
// no need to check for data as it does not have NaN representation
return scale.is_nan();
}
// no infinity representation in ocp_e2m1_mxfp4 will always return false
template <>
__host__ __device__ inline bool is_inf<f4_t>(e8m0_bexp_t const scale [[maybe_unused]],
f4_t const data [[maybe_unused]])
{
// no inf representation for ocp_e2m1_mxfp4
return false;
}
template <>
__host__ __device__ inline bool is_zero<f4_t>(e8m0_bexp_t const scale, f4_t const data)
{
if(is_nan<f4_t>(scale, data))
return false;
// no need to check for scale as it does not have a 0 representation
f4_t result = (data & 0b00001111) & NumericUtils<f4_t>::set_sign_mask;
return result == 0b0;
}
template <>
__host__ __device__ inline float to_float<f4_t>(e8m0_bexp_t const scale, f4_t const data)
{
if(is_nan<f4_t>(scale, data))
return std::numeric_limits<float>::quiet_NaN();
if(is_zero<f4_t>(scale, data))
return 0.0f;
f4_t prepared_data = data & 0b00001111;
int scale_exp = get_exponent_value<e8m0_bexp_t>(scale);
return convert_to_float<f4_t>(prepared_data, scale_exp);
}
template <>
__host__ __device__ inline f4_t sat_convert_to_type<f4_t>(float value)
{
cvt t;
t.value_float = value;
uint32_t sign = t.value_bitwise >> 31;
if(std::isnan(value))
{
return sign ? NumericUtils<f4_t>::data_max_negative_normal_mask
: NumericUtils<f4_t>::data_max_positive_normal_mask;
}
if(std::abs(value) > NumericLimits<f4_t>::Max()) // covers inf case as well
return sign ? NumericUtils<f4_t>::data_max_negative_normal_mask
: NumericUtils<f4_t>::data_max_positive_normal_mask;
f4_t res = convert_to_type<f4_t>(value);
if(std::abs(to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), res)) <
NumericLimits<f4_t>::DataMinSubnorm())
return value < 0 ? NumericUtils<f4_t>::negative_zero_mask
: NumericUtils<f4_t>::positive_zero_mask;
return res;
}
template <>
__host__ __device__ inline f4_t sat_convert_to_type_sr<f4_t>(float value, uint32_t seed)
{
cvt t;
t.value_float = value;
uint32_t sign = t.value_bitwise >> 31;
if(std::isnan(value))
return sign ? NumericUtils<f4_t>::data_max_negative_normal_mask
: NumericUtils<f4_t>::data_max_positive_normal_mask;
if(std::abs(value) > NumericLimits<f4_t>::Max()) // covers inf case as well
return sign ? NumericUtils<f4_t>::data_max_negative_normal_mask
: NumericUtils<f4_t>::data_max_positive_normal_mask;
f4_t res = convert_to_type_sr<f4_t>(value, seed);
if(std::abs(to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), res)) <
NumericLimits<f4_t>::DataMinSubnorm())
return value < 0 ? NumericUtils<f4_t>::negative_zero_mask
: NumericUtils<f4_t>::positive_zero_mask;
return res;
}
} // namespace ck::utils
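
The helpers above lean on convert_to_float and NumericUtils masks defined elsewhere; the underlying e2m1 decode is small enough to spell out. A self-contained sketch (assumed OCP e2m1 layout: 1 sign, 2 exponent, 1 mantissa bit, bias 1; illustrative, not part of the diff):

#include <cmath>
#include <cstdint>

float decode_e2m1(std::uint8_t nibble, int scale /* biased e8m0; 127 means x1 */)
{
    const int sign = (nibble >> 3) & 0x1;
    const int exp  = (nibble >> 1) & 0x3;
    const int man  = nibble & 0x1;
    // exp == 0 encodes the subnormal 0.5 * man; otherwise 2^(exp-1) * (1 + man/2)
    float mag = (exp == 0) ? 0.5f * man : std::ldexp(1.0f + 0.5f * man, exp - 1);
    return (sign ? -mag : mag) * std::ldexp(1.0f, scale - 127);
}
// decode_e2m1(0b0111, 127) == 6.0f (the f4_t maximum); decode_e2m1(0b1001, 127) == -0.5f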
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/mxfp_utils.hpp"
namespace ck::utils {
/**
* @brief Checks if an f6_t value is NaN based on the provided scale.
*
* For f6_t data, NaN cannot be represented directly. Instead, this function
* determines NaN by checking if the scale is set to a quiet NaN.
*
* @param scale The exponent scale factor (e8m0_bexp_t) used for f6_t.
* @param dataBytes The f6_t value to check (unused in this implementation).
* @return true if the scale indicates a NaN value, false otherwise.
*/
template <>
__host__ __device__ inline bool is_nan<f6_t>(e8m0_bexp_t const scale,
f6_t const dataBytes [[maybe_unused]])
{
// no need to check for data as it does not have NaN representation
return scale.is_nan();
}
/**
* @brief Checks if an bf6_t value is NaN based on the provided scale.
*
* For bf6_t data, NaN cannot be represented directly. Instead, this function
* determines NaN by checking if the scale is set to a quiet NaN.
*
* @param scale The exponent scale factor (e8m0_bexp_t) used for bf6_t.
* @param dataBytes The bf6_t value to check (unused in this implementation).
* @return true if the scale indicates a NaN value, false otherwise.
*/
template <>
__host__ __device__ inline bool is_nan<bf6_t>(e8m0_bexp_t const scale,
bf6_t const dataBytes [[maybe_unused]])
{
// no need to check for data as it does not have NaN representation
return scale.is_nan();
}
/**
* @brief Checks if an f6_t value is infinite.
*
* Because f6_t does not support infinite values, this function always returns false.
*
* @param scale The exponent scale factor (e8m0_bexp_t) used for f6_t.
* @param data The f6_t value to check.
* @return Always false, as infinity is not represented in f6_t.
*/
template <>
__host__ __device__ inline bool is_inf<f6_t>(e8m0_bexp_t const scale [[maybe_unused]],
f6_t const data [[maybe_unused]])
{
// no inf representation for fp6
return false;
}
/**
* @brief Checks if an bf6_t value is infinite.
*
* Because bf6_t does not support infinite values, this function always returns false.
*
* @param scale The exponent scale factor (e8m0_bexp_t) used for bf6_t.
* @param data The bf6_t value to check.
* @return Always false, as infinity is not represented in bf6_t.
*/
template <>
__host__ __device__ inline bool is_inf<bf6_t>(e8m0_bexp_t const scale [[maybe_unused]],
bf6_t const data [[maybe_unused]])
{
// no inf representation for bf6
return false;
}
/**
* @brief Checks whether an f6_t value is zero.
*
* If the specified f6_t is NaN, this function returns false.
* Otherwise, it masks out the sign bits and checks if the remaining bits
* are zero.
*
* @param scale The exponent scale factor (e8m0_bexp_t) used for f6_t.
* @param data The f6_t value to check.
* @return true if the value is zero; otherwise false.
*/
template <>
__host__ __device__ inline bool is_zero<f6_t>(e8m0_bexp_t const scale, f6_t const data)
{
if(is_nan<f6_t>(scale, data))
return false;
// no need to check for scale as it does not have a 0 representation
f6_t result = (data & 0b00111111) & NumericUtils<f6_t>::set_sign_mask;
return result == 0b0;
}
/**
* @brief Checks whether an bf6_t value is zero.
*
* If the specified bf6_t is NaN, this function returns false.
* Otherwise, it masks out the sign bits and checks if the remaining bits
* are zero.
*
* @param scale The exponent scale factor (e8m0_bexp_t) used for bf6_t.
* @param data The bf6_t value to check.
* @return true if the value is zero; otherwise false.
*/
template <>
__host__ __device__ inline bool is_zero<bf6_t>(e8m0_bexp_t const scale, bf6_t const data)
{
if(is_nan<bf6_t>(scale, data))
return false;
// no need to check for scale as it does not have a 0 representation
bf6_t result = (data & 0b00111111) & NumericUtils<bf6_t>::set_sign_mask;
return result == 0b0;
}
/**
* @brief Converts an f6_t value to a float based on an e8m0_bexp_t scale factor.
*
* Checks if the f6_t value is NaN or zero before performing the conversion.
* Applies the exponent from the scale to compute the final float result.
*
* @param scale The exponent scale factor (e8m0_bexp_t) used for f6_t.
* @param data The f6_t value to convert.
* @return The converted float value.
*/
template <>
__host__ __device__ inline float to_float<f6_t>(e8m0_bexp_t const scale, f6_t const data)
{
if(is_nan<f6_t>(scale, data))
return std::numeric_limits<float>::quiet_NaN();
if(is_zero<f6_t>(scale, data))
return 0.0f;
f6_t prepared_data = data & 0b00111111;
int scale_exp = get_exponent_value<e8m0_bexp_t>(scale);
return convert_to_float<f6_t>(prepared_data, scale_exp);
}
/**
* @brief Converts an bf6_t value to a float based on an e8m0_bexp_t scale factor.
*
* Checks if the bf6_t value is NaN or zero before performing the conversion.
* Applies the exponent from the scale to compute the final float result.
*
* @param scale The exponent scale factor (e8m0_bexp_t) used for bf6_t.
* @param data The bf6_t value to convert.
* @return The converted float value.
*/
template <>
__host__ __device__ inline float to_float<bf6_t>(e8m0_bexp_t const scale, bf6_t const data)
{
if(is_nan<bf6_t>(scale, data))
return std::numeric_limits<float>::quiet_NaN();
if(is_zero<bf6_t>(scale, data))
return 0.0f;
bf6_t prepared_data = data & 0b00111111;
int scale_exp = get_exponent_value<e8m0_bexp_t>(scale);
return convert_to_float<bf6_t>(prepared_data, scale_exp);
}
/**
* @brief Converts a float to f6_t with saturation.
*
* If the input is NaN or exceeds the representable range for f6_t, returns
* the corresponding max normal mask. Handles subnormal cases by returning
* zero with the appropriate sign.
*
* @param value The float value to be converted.
* @return The saturated f6_t value.
*/
template <>
__host__ __device__ inline f6_t sat_convert_to_type<f6_t>(float value)
{
cvt t;
t.value_float = value;
uint32_t sign = t.value_bitwise >> 31;
if(std::isnan(value))
{
return sign ? NumericUtils<f6_t>::data_max_negative_normal_mask
: NumericUtils<f6_t>::data_max_positive_normal_mask;
}
if(std::abs(value) > NumericLimits<f6_t>::Max()) // covers inf case as well
return sign ? NumericUtils<f6_t>::data_max_negative_normal_mask
: NumericUtils<f6_t>::data_max_positive_normal_mask;
f6_t res = convert_to_type<f6_t>(value);
if(std::abs(to_float<f6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), res)) <
NumericLimits<f6_t>::DataMinSubnorm())
return sign ? NumericUtils<f6_t>::negative_zero_mask
: NumericUtils<f6_t>::positive_zero_mask;
return res;
}
/**
* @brief Converts a float to bf6_t with saturation.
*
* If the input is NaN or exceeds the representable range for bf6_t, returns
* the corresponding max normal mask. Handles subnormal cases by returning
* zero with the appropriate sign.
*
* @param value The float value to be converted.
* @return The saturated bf6_t value.
*/
template <>
__host__ __device__ inline bf6_t sat_convert_to_type<bf6_t>(float value)
{
cvt t;
t.value_float = value;
uint32_t sign = t.value_bitwise >> 31;
if(std::isnan(value))
{
return sign ? NumericUtils<bf6_t>::data_max_negative_normal_mask
: NumericUtils<bf6_t>::data_max_positive_normal_mask;
}
if(std::abs(value) > NumericLimits<bf6_t>::Max()) // covers inf case as well
return sign ? NumericUtils<bf6_t>::data_max_negative_normal_mask
: NumericUtils<bf6_t>::data_max_positive_normal_mask;
bf6_t res = convert_to_type<bf6_t>(value);
if(std::abs(to_float<bf6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), res)) <
NumericLimits<bf6_t>::DataMinSubnorm())
return sign ? NumericUtils<bf6_t>::negative_zero_mask
: NumericUtils<bf6_t>::positive_zero_mask;
return res;
}
/**
* @brief Converts a float to f6_t with saturation and stochastic rounding.
*
* If the input is NaN or exceeds the representable range for f6_t, returns
* the corresponding max normal mask. Handles subnormal cases by returning
* zero with the appropriate sign.
*
* @param value The float value to be converted.
* @param seed Seed used for the stochastic rounding decision.
* @return The saturated f6_t value.
*/
template <>
__host__ __device__ inline f6_t sat_convert_to_type_sr<f6_t>(float value, uint32_t seed)
{
cvt t;
t.value_float = value;
uint32_t sign = t.value_bitwise >> 31;
if(std::isnan(value))
return sign ? NumericUtils<f6_t>::data_max_negative_normal_mask
: NumericUtils<f6_t>::data_max_positive_normal_mask;
if(std::abs(value) > NumericLimits<f6_t>::Max()) // covers inf case as well
return sign ? NumericUtils<f6_t>::data_max_negative_normal_mask
: NumericUtils<f6_t>::data_max_positive_normal_mask;
f6_t res = convert_to_type_sr<f6_t>(value, seed);
if(std::abs(to_float<f6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), res)) <
NumericLimits<f6_t>::DataMinSubnorm())
return sign ? NumericUtils<f6_t>::negative_zero_mask
: NumericUtils<f6_t>::positive_zero_mask;
return res;
}
/**
 * @brief Converts a float to bf6_t with saturation and stochastic rounding.
 *
 * If the input is NaN or exceeds the representable range for bf6_t, returns
 * the corresponding max normal mask. Handles subnormal cases by returning
 * zero with the appropriate sign.
 *
 * @param value The float value to be converted.
 * @param seed The random value that drives the stochastic rounding decision.
 * @return The saturated bf6_t value.
 */
template <>
__host__ __device__ inline bf6_t sat_convert_to_type_sr<bf6_t>(float value, uint32_t seed)
{
cvt t;
t.value_float = value;
uint32_t sign = t.value_bitwise >> 31;
if(std::isnan(value))
return sign ? NumericUtils<bf6_t>::data_max_negative_normal_mask
: NumericUtils<bf6_t>::data_max_positive_normal_mask;
if(std::abs(value) > NumericLimits<bf6_t>::Max()) // covers inf case as well
return sign ? NumericUtils<bf6_t>::data_max_negative_normal_mask
: NumericUtils<bf6_t>::data_max_positive_normal_mask;
bf6_t res = convert_to_type_sr<bf6_t>(value, seed);
if(std::abs(to_float<bf6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), res)) <
NumericLimits<bf6_t>::DataMinSubnorm())
return sign ? NumericUtils<bf6_t>::negative_zero_mask
: NumericUtils<bf6_t>::positive_zero_mask;
return res;
}
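// Example (illustrative sketch, assuming f6_t is the e2m3 format): 1.5f is
// exactly representable, so a unit-scale round trip is lossless:
//   f6_t q     = sat_convert_to_type<f6_t>(1.5f);
//   float back = to_float<f6_t>(NumericLimits<e8m0_bexp_t>::Binary_1(), q); // == 1.5f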
} // namespace ck::utils
#include "ck/utility/data_type.hpp"
#include "ck/utility/mxfp_utils.hpp"
#if defined(__gfx950__) && __HIP_DEVICE_COMPILE__
#define CK_MX_FP8_CVT_FAST_PATH 1
#else
#define CK_MX_FP8_CVT_FAST_PATH 0
#endif
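// CK_MX_FP8_CVT_FAST_PATH selects between the two implementations below: when
// compiling device code for gfx950, the hardware scaled-conversion builtins
// (__builtin_amdgcn_cvt_scalef32_*) are used directly; on all other targets
// the conversions fall back to the generic cast_to_f8 software path, which
// divides by the scale before rounding.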
namespace ck {
namespace fp8_impl {
#if CK_MX_FP8_CVT_FAST_PATH
template <ck_fp8_interpretation_t interpret>
static __device__ float cast_to_f32_from_f8_scaled(float scale, fp8_storage_t v)
{
union
{
unsigned int i32val;
unsigned char i8val[4];
} val;
val.i8val[0] = v;
static_assert(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP ||
interpret == ck_fp8_interpretation_t::CK_E5M2_OCP,
"Only OCP interpretations are supported");
if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
{
return __builtin_amdgcn_cvt_scalef32_f32_fp8(val.i32val, scale, 0);
}
else
{
return __builtin_amdgcn_cvt_scalef32_f32_bf8(val.i32val, scale, 0);
}
}
template <ck_fp8_interpretation_t interpret>
static __device__ float2_t cast_to_f32x2_from_f8x2_scaled(float scale, fp8x2_storage_t v)
{
const auto i16val = bit_cast<uint16_t>(v);
static_assert(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP ||
interpret == ck_fp8_interpretation_t::CK_E5M2_OCP,
"Only OCP interpretations are supported");
if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
{
return __builtin_amdgcn_cvt_scalef32_pk_f32_fp8(i16val, scale, 0);
}
else
{
return __builtin_amdgcn_cvt_scalef32_pk_f32_bf8(i16val, scale, 0);
}
}
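// Note on the builtins above: the trailing 0 is the source-select operand of
// the scalef32 conversion instructions; it picks which lane of the source
// register to convert (byte 0 for the scalar form, the low word for the
// packed form).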
template <ck_fp8_interpretation_t interpret, bool stochastic_rounding = false>
static __device__ fp8_storage_t cast_to_f8_from_f32_scaled(float v,
unsigned int rng = 0,
float scale = 1.0f)
{
fp8_storage_t i8data;
union
{
float fval;
unsigned int i32val;
} val;
union
{
uint32_t ival;
vector_type<int16_t, 2>::type v2i16;
fp8_storage_t v4i8[4];
} ret{};
val.fval = v;
if constexpr(stochastic_rounding)
{
ret.ival =
(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
? __builtin_amdgcn_cvt_scalef32_sr_fp8_f32(ret.ival, val.fval, rng, scale, 0)
: __builtin_amdgcn_cvt_scalef32_sr_bf8_f32(ret.ival, val.fval, rng, scale, 0);
i8data = ret.v4i8[0];
}
else
{
// RNE CVT
// llvm.amdgcn.cvt.scalef32.pk.fp8.f32
// v2i16 old_vdst, float srcA, float srcB, float scale, bool dst_lo_hi_sel
if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
{
// If fval / scale > max fp8, returns Nan
ret.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_fp8_f32(/*old_vdst*/ ret.v2i16,
val.fval,
val.fval,
scale,
/*dst_lo_hi_sel*/ false);
}
else
{
// If fval / scale > max bf8, returns Inf
ret.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_bf8_f32(/*old_vdst*/ ret.v2i16,
val.fval,
val.fval,
scale,
/*dst_lo_hi_sel*/ false);
}
i8data = ret.v4i8[0];
}
return i8data;
}
template <ck_fp8_interpretation_t interpret, bool stochastic_rounding = false>
static __device__ fp8x2_storage_t cast_to_f8_from_f32_scaled(float2_t v,
unsigned int rng = 0,
float scale = 1.0f)
{
union
{
uint32_t ival;
vector_type<int16_t, 2>::type v2i16;
StaticallyIndexedArray<fp8x2_storage_t, 2> v2f8x2;
} ret{};
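    // The SR builtin converts a single float per call, so the two lanes are
    // converted one at a time; the RNE builtin in the else-branch packs both
    // lanes into one instruction.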
if constexpr(stochastic_rounding)
{
fp8x2_storage_t f8x2;
if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
{
ret.ival = __builtin_amdgcn_cvt_scalef32_sr_fp8_f32(ret.ival, v[0], rng, scale, 0);
f8x2[0] = ret.v2f8x2(Number<0>{})[0];
ret.ival = __builtin_amdgcn_cvt_scalef32_sr_fp8_f32(ret.ival, v[1], rng, scale, 0);
f8x2[1] = ret.v2f8x2(Number<0>{})[0];
}
else
{
ret.ival = __builtin_amdgcn_cvt_scalef32_sr_bf8_f32(ret.ival, v[0], rng, scale, 0);
f8x2[0] = ret.v2f8x2(Number<0>{})[0];
ret.ival = __builtin_amdgcn_cvt_scalef32_sr_bf8_f32(ret.ival, v[1], rng, scale, 0);
f8x2[1] = ret.v2f8x2(Number<0>{})[0];
}
return f8x2;
}
else
{
// RNE CVT
// llvm.amdgcn.cvt.scalef32.pk.fp8.f32
// v2i16 old_vdst, float srcA, float srcB, float scale, bool dst_lo_hi_sel
if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
{
// If fval / scale > max fp8, returns Nan
ret.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_fp8_f32(/*old_vdst*/ ret.v2i16,
v[0],
v[1],
scale,
/*dst_lo_hi_sel*/ false);
}
else
{
// If fval / scale > max bf8, returns Inf
ret.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_bf8_f32(/*old_vdst*/ ret.v2i16,
v[0],
v[1],
scale,
/*dst_lo_hi_sel*/ false);
}
return ret.v2f8x2(Number<0>{});
}
}
#endif // CK_MX_FP8_CVT_FAST_PATH
#if CK_MX_FP8_CVT_FAST_PATH
/**
* \brief convert float to @p fp8_storage_t with scaling
*
* This version is used when the fast path (MX FP8 hardware) is available
*
 * \tparam interp interpretation of fp8
 * \tparam stochastic_rounding use stochastic rounding when true, RNE otherwise
* \param f float number
* \param scale scaling factor
* \return fp8_storage_t
*/
template <ck_fp8_interpretation_t interp, bool stochastic_rounding = false>
__host__ __device__ static inline fp8_storage_t cvt_float_to_fp8_scaled(const float f, float scale)
{
__is_interpret_supported(interp);
uint32_t rng = 0;
if constexpr(stochastic_rounding)
{
constexpr int seed = 1254739;
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
}
return cast_to_f8_from_f32_scaled<interp, stochastic_rounding>(f, rng, scale);
}
/**
* \brief convert 2xfloat to @p 2xfp8_storage_t with scaling
*
* This version is used when the fast path (MX FP8 hardware) is available
*
 * \tparam interp interpretation of fp8
 * \tparam stochastic_rounding use stochastic rounding when true, RNE otherwise
* \param f 2xfloat
* \param scale scaling factor
* \return 2xfp8_storage_t
*/
template <ck_fp8_interpretation_t interp, bool stochastic_rounding = false>
__host__ __device__ static inline fp8x2_storage_t cvt_float_to_fp8_scaled(const float2_t f,
float scale)
{
__is_interpret_supported(interp);
uint32_t rng = 0;
if constexpr(stochastic_rounding)
{
constexpr int seed = 1254739;
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f[0]);
}
return cast_to_f8_from_f32_scaled<interp, stochastic_rounding>(f, rng, scale);
}
#else
/**
* \brief convert float to @p fp8_storage_t with scaling
*
* This version is used when the fast path (MX FP8 hardware) is not available
*
 * \tparam interp interpretation of fp8
 * \tparam stochastic_rounding use stochastic rounding when true, RNE otherwise
* \param f float number
* \param scale scaling factor
* \return fp8_storage_t
*/
template <ck_fp8_interpretation_t interp, bool stochastic_rounding = false>
__host__ __device__ static inline fp8_storage_t cvt_float_to_fp8_scaled(const float f, float scale)
{
static_assert(interp == ck_fp8_interpretation_t::CK_E4M3_OCP ||
interp == ck_fp8_interpretation_t::CK_E5M2_OCP,
"Only OCP interpretations are supported");
uint32_t rng = 0;
if constexpr(stochastic_rounding)
{
constexpr int seed = 1254739;
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
}
if constexpr(interp == ck_fp8_interpretation_t::CK_E4M3_OCP)
{
return cast_to_f8<float, 3, 4, false, true, stochastic_rounding>(f / scale, rng);
}
else if constexpr(interp == ck_fp8_interpretation_t::CK_E5M2_OCP)
{
return cast_to_f8<float, 2, 5, false, true, stochastic_rounding>(f / scale, rng);
}
else
{
__hip_assert(false && "FP8 type is not supported by current target device");
return 0;
}
}
/**
 * \brief convert two floats to @p 2xfp8_storage_t with scaling
*
* This version is used when the fast path (MX FP8 hardware) is not available
*
 * \tparam interp interpretation of fp8
 * \tparam stochastic_rounding use stochastic rounding when true, RNE otherwise
* \param f 2xfloat
* \param scale scaling factor
* \return 2xfp8_storage_t
*/
template <ck_fp8_interpretation_t interp, bool stochastic_rounding = false>
__host__ __device__ static inline fp8x2_storage_t cvt_float_to_fp8_scaled(const float2_t f,
float scale)
{
static_assert(interp == ck_fp8_interpretation_t::CK_E4M3_OCP ||
interp == ck_fp8_interpretation_t::CK_E5M2_OCP,
"Only OCP interpretations are supported");
uint32_t rng = 0;
if constexpr(stochastic_rounding)
{
constexpr int seed = 1254739;
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f[0]);
}
if constexpr(interp == ck_fp8_interpretation_t::CK_E4M3_OCP)
{
return {cast_to_f8<float, 3, 4, false, true, stochastic_rounding>(f[0] / scale, rng),
cast_to_f8<float, 3, 4, false, true, stochastic_rounding>(f[1] / scale, rng)};
}
else if constexpr(interp == ck_fp8_interpretation_t::CK_E5M2_OCP)
{
return {cast_to_f8<float, 2, 5, false, true, stochastic_rounding>(f[0] / scale, rng),
cast_to_f8<float, 2, 5, false, true, stochastic_rounding>(f[1] / scale, rng)};
}
else
{
__hip_assert(false && "FP8 type is not supported by current target device");
return 0;
}
}
#endif // CK_MX_FP8_CVT_FAST_PATH
} // namespace fp8_impl
// Declare a template function for fp8 conversion using SR
template <typename Y, typename X>
__host__ __device__ constexpr Y mxf8_convert_sr(X x, float scale);
// Declare a template function for fp8 conversion using RNE
template <typename Y, typename X>
__host__ __device__ constexpr Y mxf8_convert_rne(X x, float scale);
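// Example (illustrative sketch): the scale argument is the decoded float
// scale (e.g. 2^e for an e8m0 exponent e), not the raw e8m0 byte.
//   float scale   = 4.0f;                                    // hypothetical block scale
//   f8_ocp_t q    = mxf8_convert_rne<f8_ocp_t>(3.0f, scale); // encodes 3.0/4.0 with RNE
//   f8_ocp_t q_sr = mxf8_convert_sr<f8_ocp_t>(3.0f, scale);  // same, stochastic rounding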
// convert fp32 to fp8 with rounding to nearest even
template <>
inline __host__ __device__ f8_ocp_t mxf8_convert_rne<f8_ocp_t, float>(float x, float scale)
{
return f8_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<f8_ocp_t::default_interpret>(x, scale)};
}
// convert fp32 to bf8 with rounding to nearest even
template <>
inline __host__ __device__ bf8_ocp_t mxf8_convert_rne<bf8_ocp_t, float>(float x, float scale)
{
return bf8_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<bf8_ocp_t::default_interpret>(x, scale)};
}
// convert fp32x2 to fp8x2 with rounding to nearest even
template <>
inline __host__ __device__ f8x2_ocp_t mxf8_convert_rne<f8x2_ocp_t, float2_t>(float2_t x,
float scale)
{
return f8x2_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<f8_ocp_t::default_interpret>(x, scale)};
}
// convert fp32x2 to bf8x2 with rounding to nearest even
template <>
inline __host__ __device__ bf8x2_ocp_t mxf8_convert_rne<bf8x2_ocp_t, float2_t>(float2_t x,
float scale)
{
return bf8x2_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<bf8_ocp_t::default_interpret>(x, scale)};
}
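// The 16- and 32-wide specializations below reinterpret the input through a
// union and convert two lanes at a time with static_for, so each pair of
// floats maps onto one packed conversion above.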
// convert fp32x16 to fp8x16 with rounding to nearest even
template <>
inline __host__ __device__ f8x16_ocp_t mxf8_convert_rne<f8x16_ocp_t, float16_t>(float16_t x,
float scale)
{
union
{
float16_t float_1x16;
float2_t float_2x8[8];
} in{x};
union
{
f8x16_ocp_t fp8_1x16;
f8x2_ocp_t fp8_2x8[8];
} out{};
ck::static_for<0, 8, 1>{}(
[&](auto i) { out.fp8_2x8[i] = mxf8_convert_rne<f8x2_ocp_t>(in.float_2x8[i], scale); });
return out.fp8_1x16;
}
// convert fp32x16 to bf8x16 with rounding to nearest even
template <>
inline __host__ __device__ bf8x16_ocp_t mxf8_convert_rne<bf8x16_ocp_t, float16_t>(float16_t x,
float scale)
{
union
{
float16_t float_1x16;
float2_t float_2x8[8];
} in{x};
union
{
bf8x16_ocp_t bf8_1x16;
bf8x2_ocp_t bf8_2x8[8];
} out{};
ck::static_for<0, 8, 1>{}(
[&](auto i) { out.bf8_2x8[i] = mxf8_convert_rne<bf8x2_ocp_t>(in.float_2x8[i], scale); });
return out.bf8_1x16;
}
// convert fp32x32 to fp8x32 with rounding to nearest even
template <>
inline __host__ __device__ f8x32_ocp_t mxf8_convert_rne<f8x32_ocp_t, float32_t>(float32_t x,
float scale)
{
union
{
float32_t float_1x32;
float16_t float_16x2[2];
} in{x};
union
{
f8x32_ocp_t fp8_1x32;
f8x16_ocp_t fp8_16x2[2];
} out{};
ck::static_for<0, 2, 1>{}(
[&](auto i) { out.fp8_16x2[i] = mxf8_convert_rne<f8x16_ocp_t>(in.float_16x2[i], scale); });
return out.fp8_1x32;
}
// convert fp32x32 to bf8x32 with rounding to nearest even
template <>
inline __host__ __device__ bf8x32_ocp_t mxf8_convert_rne<bf8x32_ocp_t, float32_t>(float32_t x,
float scale)
{
union
{
float32_t float_1x32;
float16_t float_16x2[2];
} in{x};
union
{
bf8x32_ocp_t bf8_1x32;
bf8x16_ocp_t bf8_16x2[2];
} out{};
ck::static_for<0, 2, 1>{}(
[&](auto i) { out.bf8_16x2[i] = mxf8_convert_rne<bf8x16_ocp_t>(in.float_16x2[i], scale); });
return out.bf8_1x32;
}
// convert fp32 to fp8 with stochastic rounding
template <>
inline __host__ __device__ f8_ocp_t mxf8_convert_sr<f8_ocp_t, float>(float x, float scale)
{
return f8_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<f8_ocp_t::default_interpret, true>(x, scale)};
}
// convert fp32 to bf8 with stochastic rounding
template <>
inline __host__ __device__ bf8_ocp_t mxf8_convert_sr<bf8_ocp_t, float>(float x, float scale)
{
return bf8_ocp_t{
fp8_impl::cvt_float_to_fp8_scaled<bf8_ocp_t::default_interpret, true>(x, scale)};
}
// convert fp32x2 to fp8x2 with stochastic rounding
template <>
inline __host__ __device__ f8x2_ocp_t mxf8_convert_sr<f8x2_ocp_t, float2_t>(float2_t x, float scale)
{
return f8x2_ocp_t{
fp8_impl::cvt_float_to_fp8_scaled<f8_ocp_t::default_interpret, true>(x, scale)};
}
// convert fp32x2 to bf8x2 with stochastic rounding
template <>
inline __host__ __device__ bf8x2_ocp_t mxf8_convert_sr<bf8x2_ocp_t, float2_t>(float2_t x,
float scale)
{
return bf8x2_ocp_t{
fp8_impl::cvt_float_to_fp8_scaled<bf8_ocp_t::default_interpret, true>(x, scale)};
}
// convert fp32x16 to fp8x16 with stochastic rounding
template <>
inline __host__ __device__ f8x16_ocp_t mxf8_convert_sr<f8x16_ocp_t, float16_t>(float16_t x,
float scale)
{
union
{
float16_t float_1x16;
float2_t float_2x8[8];
} in{x};
union
{
f8x16_ocp_t fp8_1x16;
f8x2_ocp_t fp8_2x8[8];
} out{};
ck::static_for<0, 8, 1>{}(
[&](auto i) { out.fp8_2x8[i] = mxf8_convert_sr<f8x2_ocp_t>(in.float_2x8[i], scale); });
return out.fp8_1x16;
}
// convert fp32x16 to bf8x16 with stochastic rounding
template <>
inline __host__ __device__ bf8x16_ocp_t mxf8_convert_sr<bf8x16_ocp_t, float16_t>(float16_t x,
float scale)
{
union
{
float16_t float_1x16;
float2_t float_2x8[8];
} in{x};
union
{
bf8x16_ocp_t bf8_1x16;
bf8x2_ocp_t bf8_2x8[8];
} out{};
ck::static_for<0, 8, 1>{}(
[&](auto i) { out.bf8_2x8[i] = mxf8_convert_sr<bf8x2_ocp_t>(in.float_2x8[i], scale); });
return out.bf8_1x16;
}
// convert fp32x32 to fp8x32 with stochastic rounding
template <>
inline __host__ __device__ f8x32_ocp_t mxf8_convert_sr<f8x32_ocp_t, float32_t>(float32_t x,
float scale)
{
union
{
float32_t float_1x32;
float16_t float_16x2[2];
} in{x};
union
{
f8x32_ocp_t fp8_1x32;
f8x16_ocp_t fp8_16x2[2];
} out{};
ck::static_for<0, 2, 1>{}(
[&](auto i) { out.fp8_16x2[i] = mxf8_convert_sr<f8x16_ocp_t>(in.float_16x2[i], scale); });
return out.fp8_1x32;
}
// convert fp32x32 to bf8x32 with stochastic rounding
template <>
inline __host__ __device__ bf8x32_ocp_t mxf8_convert_sr<bf8x32_ocp_t, float32_t>(float32_t x,
float scale)
{
union
{
float32_t float_1x32;
float16_t float_16x2[2];
} in{x};
union
{
bf8x32_ocp_t bf8_1x32;
bf8x16_ocp_t bf8_16x2[2];
} out{};
ck::static_for<0, 2, 1>{}(
[&](auto i) { out.bf8_16x2[i] = mxf8_convert_sr<bf8x16_ocp_t>(in.float_16x2[i], scale); });
return out.bf8_1x32;
}
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck::utils {
union cvt
{
float value_float;
uint32_t value_bitwise;
};
template <typename DTYPE>
inline bool getDataHasInf()
{
return DTYPE::dataInfo.hasInf;
}
template <typename T>
__host__ __device__ inline bool is_zero(e8m0_bexp_t const scale, T const data);
template <typename T>
__host__ __device__ inline bool is_nan(e8m0_bexp_t const scale, T const data);
template <typename T>
__host__ __device__ inline bool is_inf(e8m0_bexp_t const scale, T const data);
template <typename T>
__host__ __device__ inline int get_exponent_value(T x)
{
x >>= NumericUtils<T>::mant;
x &= ((1 << NumericUtils<T>::exp) - 1);
return static_cast<int>(x);
}
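// Example (assuming bf6_t is the e3m2 format): for x = 0b101100 the two
// mantissa bits are shifted out and the three exponent bits 0b011 remain,
// so get_exponent_value<bf6_t>(x) == 3.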
template <typename T>
__host__ __device__ inline bool is_subnormal(T x)
{
return get_exponent_value<T>(x) == 0;
}
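// Reconstructs the significand as a real number: 1.m for normal values, 0.m
// for subnormals, accumulating each stored mantissa bit at its
// binary-fraction weight.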
template <typename T>
__host__ __device__ inline double get_mantissa_value(T x)
{
    double mantissa = is_subnormal<T>(x) ? 0.0 : 1.0;
    for(uint32_t i = 0; i < NumericUtils<T>::mant; i++)
{
mantissa += std::pow(2, -int32_t((NumericUtils<T>::mant - i))) * (x & 0b1);
x >>= 1;
}
return mantissa;
}
template <typename T>
__host__ __device__ inline bool get_data_has_inf()
{
return NumericUtils<T>::has_inf;
}
template <typename T>
__host__ __device__ float convert_to_float(T data, int scale_exp)
{
float d_sign =
std::pow(-1, static_cast<float>(data >> (NumericUtils<T>::exp + NumericUtils<T>::mant)));
float d_exp;
if(is_subnormal<T>(data))
d_exp = std::pow(2, 1 - static_cast<int>(NumericUtils<T>::bias));
else
d_exp = std::pow(2, get_exponent_value<T>(data) - static_cast<int>(NumericUtils<T>::bias));
float d_mant = get_mantissa_value<T>(data);
float data_value = d_sign * d_exp * d_mant;
float scale_value = std::pow(
2, static_cast<float>((scale_exp - static_cast<int>(NumericUtils<e8m0_bexp_t>::bias))));
return data_value * scale_value;
}
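// Example (assuming bf6_t is e3m2 with bias 3): data = 0b001101 decodes as
// sign +, exponent 0b011 -> 2^(3-3) = 1, mantissa 0b01 -> 1.25; with
// scale_exp = 127 (the e8m0 bias) the scale is 2^0, so the result is 1.25f.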
template <typename T>
__host__ __device__ inline float to_float(e8m0_bexp_t const scale, T const data);
template <typename T>
__host__ __device__ T sat_convert_to_type(float value);
template <typename T>
__host__ __device__ T sat_convert_to_type_sr(float value, uint32_t seed);
template <typename T>
inline T convert_to_type(float value)
{
using bitwise_type = typename NumericUtils<T>::bitwise_type;
if(std::abs(value) > NumericLimits<T>::Max())
{
float max_value = NumericLimits<T>::Max();
cvt t;
// cppcheck-suppress redundantAssignment
t.value_float = max_value;
uint32_t max_bitwise = t.value_bitwise;
// cppcheck-suppress redundantAssignment
t.value_float = value;
bitwise_type sign =
t.value_bitwise >> (NumericUtils<float>::exp + NumericUtils<float>::mant);
bitwise_type exp =
((max_bitwise >> NumericUtils<float>::mant) & NumericUtils<float>::exp_mask) -
(NumericUtils<float>::bias - NumericUtils<T>::bias);
bitwise_type mantissa = max_bitwise >> (NumericUtils<float>::mant - NumericUtils<T>::mant);
uint32_t mant_prev = max_bitwise >> (NumericUtils<float>::mant - NumericUtils<T>::mant);
mant_prev &= ((1 << NumericUtils<T>::mant) - 1);
mant_prev--;
mant_prev <<= (NumericUtils<float>::mant - NumericUtils<T>::mant);
uint32_t prev_bit =
((max_bitwise >> NumericUtils<float>::mant) << NumericUtils<float>::mant) | mant_prev;
t.value_bitwise = prev_bit;
float prev_val = t.value_float;
float diff = max_value - prev_val;
float actual_max = max_value + (diff / 2);
if(std::abs(value) < actual_max)
{
return sign << ((NumericUtils<T>::exp + NumericUtils<T>::mant)) |
(exp << NumericUtils<T>::mant) | mantissa;
}
else
{
if(!get_data_has_inf<T>())
{
return (1 << (NumericUtils<T>::mant + NumericUtils<T>::exp)) - 1;
}
else
{
exp++;
return sign << ((NumericUtils<T>::exp + NumericUtils<T>::mant)) |
(exp << NumericUtils<T>::mant);
}
}
}
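    // Past this point |value| <= Max: decompose the float into sign /
    // exponent / mantissa and re-encode with round-to-nearest-even,
    // right-shifting the mantissa when the result is subnormal in T.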
const int mfmt = NumericUtils<float>::mant;
uint32_t x;
x = bit_cast<uint32_t>(value);
uint32_t head, mantissa;
int32_t exponent, bias;
uint32_t sign;
head = x & NumericUtils<float>::head_mask;
mantissa = x & NumericUtils<float>::mant_mask;
exponent = (head >> NumericUtils<float>::mant) & NumericUtils<float>::exp_mask;
sign = head >> (NumericUtils<float>::mant + NumericUtils<float>::exp);
bias = NumericUtils<float>::bias;
if(x == 0)
{
return 0b0;
}
const int mini_bias = NumericUtils<T>::bias;
const int mini_denormal_act_exponent = 1 - mini_bias;
int act_exponent, out_exponent, exponent_diff;
bool is_subnorm = false;
if(exponent == 0)
{
act_exponent = exponent - bias + 1;
exponent_diff = mini_denormal_act_exponent - act_exponent;
is_subnorm = true;
}
else
{
act_exponent = exponent - bias;
if(act_exponent <= mini_denormal_act_exponent)
{
exponent_diff = mini_denormal_act_exponent - act_exponent;
is_subnorm = true;
}
else
{
exponent_diff = 0;
}
mantissa += (1UL << mfmt);
}
auto shift_amount = (mfmt - NumericUtils<T>::mant + exponent_diff);
shift_amount = (shift_amount >= 64) ? 63 : shift_amount;
bool midpoint = (mantissa & ((1UL << shift_amount) - 1)) == (1UL << (shift_amount - 1));
float min_subnorm = NumericLimits<T>::DataMinSubnorm() * (sign ? -1 : 1);
if(is_subnorm && std::abs(value) < std::abs(min_subnorm))
{
// closer to 0
if(std::abs(value) <= std::abs(min_subnorm - value))
return 0;
else
return 1 | (sign << (NumericUtils<T>::exp + NumericUtils<T>::mant));
}
if(exponent_diff > 0)
mantissa >>= exponent_diff;
else if(exponent_diff == -1)
mantissa <<= -exponent_diff;
bool implicit_one = mantissa & (1 << mfmt);
out_exponent = (act_exponent + exponent_diff) + mini_bias - (implicit_one ? 0 : 1);
uint32_t drop_mask = (1UL << (mfmt - NumericUtils<T>::mant)) - 1;
bool odd = mantissa & (1UL << (mfmt - NumericUtils<T>::mant));
mantissa += (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa) & drop_mask;
if(out_exponent == 0)
{
if((1UL << mfmt) & mantissa)
{
out_exponent = 1;
}
}
else
{
if((1UL << (mfmt + 1)) & mantissa)
{
mantissa >>= 1;
out_exponent++;
}
}
mantissa >>= (mfmt - NumericUtils<T>::mant);
if(out_exponent == 0 && mantissa == 0)
{
return 0;
}
mantissa &= (1UL << NumericUtils<T>::mant) - 1;
return (sign << (NumericUtils<T>::exp + NumericUtils<T>::mant)) |
(out_exponent << NumericUtils<T>::mant) | mantissa;
}
template <typename T>
inline T convert_to_type_sr(float value, uint32_t seed)
{
if(std::abs(value) > NumericLimits<T>::Max())
{
float max_value = NumericLimits<T>::Max();
cvt t;
// cppcheck-suppress redundantAssignment
t.value_float = max_value;
        uint32_t max_bitwise = t.value_bitwise;
// cppcheck-suppress redundantAssignment
t.value_float = value;
T sign = t.value_bitwise >> (NumericUtils<float>::exp + NumericUtils<float>::mant);
T exp = ((max_bitwise >> NumericUtils<float>::mant) & NumericUtils<float>::exp_mask) -
(NumericUtils<float>::bias - NumericUtils<T>::bias);
uint32_t mant_prev = max_bitwise >> (NumericUtils<float>::mant - NumericUtils<T>::mant);
mant_prev &= ((1UL << NumericUtils<T>::mant) - 1);
mant_prev--;
mant_prev <<= (NumericUtils<float>::mant - NumericUtils<T>::mant);
uint32_t prev_bit =
((max_bitwise >> NumericUtils<float>::mant) << NumericUtils<float>::mant) | mant_prev;
t.value_bitwise = prev_bit;
float prev_val = t.value_float;
float diff = max_value - prev_val;
float actual_max = max_value + (diff / 2);
if(std::abs(value) < actual_max)
{
double d_max_value = static_cast<double>(max_value);
double d_actual_max = static_cast<double>(actual_max);
double d_value = static_cast<double>(value);
double d_is = std::abs(d_max_value - d_actual_max);
double d_seed = static_cast<double>(seed);
            double d_prob = 1.0 - (std::abs(d_value - d_max_value) / d_is); // prob to round down
            double thresh = UINT_MAX * d_prob;
            if(!get_data_has_inf<T>() || d_seed <= thresh)
                // round down to the type's max normal value
                return sign == 0 ? NumericUtils<T>::data_max_positive_normal_mask
                                 : NumericUtils<T>::data_max_negative_normal_mask;
else
{
exp++;
return sign << ((NumericUtils<T>::exp + NumericUtils<T>::mant)) // inf
| (exp << NumericUtils<T>::mant);
}
}
else
{
if(!get_data_has_inf<T>())
return (1 << (NumericUtils<T>::mant + NumericUtils<T>::exp)) - 1;
else
{
exp++;
return sign << ((NumericUtils<T>::exp + NumericUtils<T>::mant)) // inf
| (exp << NumericUtils<T>::mant);
}
}
}
uint32_t f32 = bit_cast<uint32_t>(value);
auto f32_mant = f32 & NumericUtils<float>::mant_mask;
auto head = f32 & NumericUtils<float>::head_mask;
auto f32_exp = (head >> NumericUtils<float>::mant) & NumericUtils<float>::exp_mask;
auto sign_bit = head >> (NumericUtils<float>::mant + NumericUtils<float>::exp);
auto sign = sign_bit << (NumericUtils<T>::exp + NumericUtils<T>::mant);
f32_exp = static_cast<int32_t>(f32_exp) - NumericUtils<float>::bias;
int32_t exp = f32_exp;
auto mant = f32_mant;
bool subnorm = false;
if(f32 == 0)
return 0b0;
if(exp >= NumericUtils<T>::unbiased_exp_min)
{
mant = f32_mant;
}
    // Subnormal in T. The extra guard only excludes the degenerate case where T
    // has as many exponent bits as f32 (8), in which case its subnormals would
    // coincide with f32's and need no mantissa shift.
else if(exp < NumericUtils<T>::unbiased_exp_min &&
NumericUtils<T>::exp < NumericUtils<float>::exp)
{
subnorm = true;
auto diff = static_cast<uint32_t>(NumericUtils<T>::unbiased_exp_min - exp);
if(diff >= 32)
{
mant = 0;
f32_mant = 0;
}
else
{
f32_mant |= static_cast<uint32_t>(1) << NumericUtils<float>::mant;
f32_mant >>= diff;
}
exp = 0;
mant = f32_mant;
}
uint32_t sr_shift = NumericUtils<T>::sr_shift;
// For stochastic-rounding we add the aligned random value to the
// mantissa and then truncate (RTZ).
mant += seed >> sr_shift;
// Increment exponent when mantissa overflows due to rounding
if(mant >= static_cast<uint32_t>(1) << NumericUtils<float>::mant)
++exp;
mant >>= (NumericUtils<float>::mant - NumericUtils<T>::mant);
mant &= ((1 << NumericUtils<T>::mant) - 1);
auto biased_exp = static_cast<uint32_t>(exp);
if(!subnorm)
biased_exp = static_cast<uint32_t>(exp + NumericUtils<T>::bias);
biased_exp &= ((1 << NumericUtils<T>::exp) - 1);
auto val = sign | biased_exp << NumericUtils<T>::mant | mant;
return val;
}
} // namespace ck::utils