Commit dec32dc6 authored by ThomasNing's avatar ThomasNing
Browse files

Finish the feature and merge with develop on the computeV2

parents 71352c44 c5fff071
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
......@@ -376,6 +376,16 @@ struct numeric<bfloat16_t>
}
};
template <typename T>
struct numeric_traits;
template <>
struct numeric_traits<bfloat16_t>
{
static constexpr int exp = 8;
static constexpr int mant = 7;
};
#if CK_TILE_USE_CUSTOM_DATA_TYPE
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, bfloat16_t)
#endif
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
......@@ -14,6 +14,12 @@
#pragma once
#if(defined(__gfx94__) || defined(__gfx12__)) && __HIP_DEVICE_COMPILE__
#define CK_TILE_FP8_CVT_DEVICE 1
#else
#define CK_TILE_FP8_CVT_DEVICE 0
#endif
namespace ck_tile {
// fp8 rounding modes
......@@ -25,15 +31,26 @@ enum class fp8_rounding_mode
stochastic
};
/**
* \brief FP8 interpretation used in conversion algorithms
*/
enum class fp8_interpretation
{
E4M3_OCP = 0, // OCP FP8 E4M3
E5M2_OCP = 1, // OCP BF8 E5M2
E4M3_FNUZ = 2, // FNUZ FP8 E4M3
E5M2_FNUZ = 3, // FNUZ BF8 E5M2
};
/*
* ______________NANOO_________________ | ______________IEEE________________
* ______________FNUZ_________________ | ______________OCP________________
* e4m3 e5m2 | e4m3 e5m2
* bias : 8 16 | 7 15
* inf : 1.0000.000 1.00000.00 | N/A s.11111.00
* Nan : 1.0000.000 1.00000.00 | s.1111.111 s.11111.{01, 10, 11}
* zero : 0.0000.000 0.00000.00 | s.0000.000 s.00000.00
* Max(norm) : s.1111.111 (240) s.11111.11(57344) | s.1111.110(448) s.11110.11(57344)
* Max(snorm): s.0000.111 s.00000.11 | s.0000.111(448) s.00000.11(57344)
* Max(snorm): s.0000.111 s.00000.11 | s.0000.111 s.00000.11
* 0.0068359375 2.288818e-05 | 0.013671875 4.57763671875e-05
* Min(norm) : s.0001.000 s.00001.00 | s.0001.000 s.00001.00
* 2^-7(0.00078125) 2^-15(3.05176e-05) | 2^-6(0.015625) 2^-14(6.10352e-05)
......@@ -55,10 +72,10 @@ struct alignas(1) float8_e4m3_t
{
static constexpr int exponent = 4;
static constexpr int mantissa = 3;
#if defined(__gfx94__)
static constexpr int bias = 1 << (exponent - 1); // NANOO
#if CK_TILE_USE_OCP_FP8
static constexpr int bias = 7; // OCP
#else
static constexpr int bias = (1 << (exponent - 1)) - 1; // IEEE
static constexpr int bias = 8; // FNUZ
#endif
using raw_type = uint8_t;
raw_type data;
......@@ -113,10 +130,10 @@ struct alignas(1) float8_e5m2_t
{
static constexpr int exponent = 5;
static constexpr int mantissa = 2;
#if defined(__gfx94__)
static constexpr int bias = 1 << (exponent - 1); // NANOO
#if CK_TILE_USE_OCP_FP8
static constexpr int bias = 15; // OCP
#else
static constexpr int bias = (1 << (exponent - 1)) - 1; // IEEE
static constexpr int bias = 16; // FNUZ
#endif
using raw_type = uint8_t;
raw_type data;
......@@ -183,501 +200,727 @@ struct native_t<bf8_t>
};
#else
using fp8_t = _BitInt(8);
using fp8_raw_t = uint8_t;
using bf8_t = unsigned _BitInt(8);
using bf8_raw_t = uint8_t;
#endif
// below is sw fp8 conversion, not utilizing hw instruction
namespace impl {
template <typename T>
struct numeric_traits;
template <typename X, typename Y, bool negative_zero_nan, bool clip, bool stoch>
CK_TILE_HOST_DEVICE Y run_cast_to_f8(X x, uint32_t rng)
template <>
struct numeric_traits<fp8_t>
{
// fp8/bf8 exponent/mantissa layout
constexpr int out_exp = numeric_traits<Y>::exp;
constexpr int out_mant = numeric_traits<Y>::mant;
using bitwise_type = fp8_raw_t;
static constexpr int exp = 4;
static constexpr int mant = 3;
#if CK_TILE_USE_OCP_FP8
static constexpr int bias = 7;
static constexpr fp8_interpretation f8_interpret = fp8_interpretation::E4M3_OCP;
#else
static constexpr int bias = 8;
static constexpr fp8_interpretation f8_interpret = fp8_interpretation::E4M3_FNUZ;
#endif
static constexpr uint8_t abs_mask = 0x7F;
};
// original type exponent/mantissa layout
constexpr int in_exp = numeric_traits<X>::exp;
constexpr int in_mant = numeric_traits<X>::mant;
template <>
struct numeric_traits<bf8_t>
{
using bitwise_type = bf8_raw_t;
int exponent, bias;
uint32_t head, mantissa, sign;
// nan code is same for float and half
#if CK_TILE_USE_CUSTOM_DATA_TYPE
constexpr Y nan_code =
numeric<Y>::quiet_NaN(); // __builtin_bit_cast(Y, static_cast<uint8_t>(0x80));
static constexpr int exp = 5;
static constexpr int mant = 2;
#if CK_TILE_USE_OCP_FP8
static constexpr int bias = 15;
static constexpr fp8_interpretation f8_interpret = fp8_interpretation::E5M2_OCP;
#else
constexpr Y nan_code = 0x80;
static constexpr int bias = 16;
static constexpr fp8_interpretation f8_interpret = fp8_interpretation::E5M2_FNUZ;
#endif
static constexpr uint8_t abs_mask = 0x7F;
};
// below is sw fp8 conversion, not utilizing hw instruction
namespace impl {
template <typename SrcT, typename DstT, bool clip = true, bool stoch = false>
CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0)
{
static_assert(std::is_same<DstT, fp8_t>::value || std::is_same<DstT, bf8_t>::value,
"DstT type must be fp8 or bf8.");
constexpr uint32_t nan_mask = numeric_traits<X>::nan_mask;
constexpr bool is_half = std::is_same<SrcT, half_t>::value;
constexpr bool is_float = std::is_same<SrcT, float>::value;
static_assert(is_half || is_float, "Only half and float can be cast to f8");
// convert to bitwise
using T_bitwise = typename numeric_traits<X>::bitwise_type;
T_bitwise x_bitwise = *(reinterpret_cast<T_bitwise*>(&x));
// fp8/bf8 type exponent/mantissa layout
constexpr int DstT_exp = numeric_traits<DstT>::exp; // exponent width of the destination type
constexpr int DstT_mant = numeric_traits<DstT>::mant; // mantissa width of the destination type
constexpr bool is_fnuz =
(numeric_traits<DstT>::f8_interpret == fp8_interpretation::E4M3_FNUZ) ||
(numeric_traits<DstT>::f8_interpret == fp8_interpretation::E5M2_FNUZ);
// unpack the input, depends on datatype
head = x_bitwise & numeric_traits<X>::head_mask;
mantissa = x_bitwise & numeric_traits<X>::mant_mask;
exponent = (head >> in_mant) & numeric_traits<X>::exp_mask;
sign = head >> (in_exp + in_mant);
bias = numeric_traits<X>::bias;
constexpr int SrcT_exp = numeric_traits<SrcT>::exp;
constexpr int SrcT_mant = numeric_traits<SrcT>::mant;
uint32_t signed_inf = (sign << (in_exp + in_mant)) + (((1 << in_exp) - 1) << in_mant);
uint32_t drop_mask = (1 << (in_mant - out_mant)) - 1;
constexpr int max_exp = (1 << out_exp) - (negative_zero_nan ? 1 : 2);
using SrcT_bitwise = typename numeric_traits<SrcT>::bitwise_type;
SrcT_bitwise src_bitwise = bit_cast<SrcT_bitwise>(src);
if constexpr(negative_zero_nan)
unsigned long long head, mantissa;
int exponent, bias;
unsigned int sign;
unsigned long long fInf, abs_mask;
head = src_bitwise & numeric_traits<SrcT>::head_mask;
mantissa = src_bitwise & numeric_traits<SrcT>::mant_mask;
exponent = (head >> SrcT_mant) & numeric_traits<SrcT>::exp_mask;
sign = head >> (SrcT_exp + SrcT_mant);
bias = numeric_traits<SrcT>::bias;
fInf = numeric_traits<SrcT>::Inf;
abs_mask = numeric_traits<SrcT>::abs_mask;
unsigned int signed_inf = 0;
unsigned int nan = 0;
if constexpr(is_fnuz)
{
if((x_bitwise & nan_mask) == nan_mask)
return nan_code;
signed_inf = clip ? ((sign << 7) + 0x7f) : 0x80;
nan = 0x80;
}
else
{
if((x_bitwise & nan_mask) == nan_mask)
return signed_inf + (mantissa != 0 ? 1 : 0);
if constexpr(DstT_exp == 4)
{ // e4m3
signed_inf = (sign << 7) + (clip ? 0x7e : 0x7f);
}
else
{ // e5m2
signed_inf = (sign << 7) + (clip ? 0x7b : 0x7c);
}
nan = (sign << 7) + 0x7f;
}
// Max values
unsigned long long ifmax = 0;
if constexpr(is_float)
{
if constexpr(DstT_exp == 5)
{
ifmax = 0x47600000;
}
else
{
if constexpr(is_fnuz)
{
ifmax = 0x43700000;
}
else
{
ifmax = 0x43E00000;
}
}
}
else if constexpr(is_half)
{
if constexpr(DstT_exp == 5)
{
ifmax = 0x7B00;
}
else
{
if constexpr(is_fnuz)
{
ifmax = 0x5B80;
}
else
{
ifmax = 0x5F00;
}
}
}
// check if x is 0.0
if(x_bitwise == 0)
return __builtin_bit_cast(Y, static_cast<uint8_t>(0));
// Deal with inf and NaNs
if((src_bitwise & fInf) == fInf)
{
if constexpr(is_fnuz)
return signed_inf;
return mantissa != 0 ? nan : signed_inf;
}
if((src_bitwise & abs_mask) > ifmax)
{
return signed_inf;
}
if(src_bitwise == 0)
{
return 0;
}
// First need to check if it is normal or denorm as there is a difference of implict 1
// Then need to adjust the exponent to align with the F8 exponent, in the meanwhile, shift
// The mantissa. Then for stochastic rounding, add rng to mantissa and truncate. And for
// RNE, no need to add rng. Then probably need to check whether there is carry and adjust
// exponent and mantissa again3
// First need to check if it is normal or denorm as there is a difference of
// implicit 1 Then need to adjust the exponent to align with the F8 exponent,
// in the meanwhile, shift The mantissa. Then for stochastic rounding, add rng
// to mantissa and truncate. And for RNE, no need to add rng. Then probably
// need to check whether there is carry and adjust exponent and mantissa again
// For IEEE bias mode, the bias is 2^(k-1)-1 where k is the width of exponent bits
const int out_bias = (1 << (out_exp - 1)) - 1 + (negative_zero_nan ? 1 : 0);
const int out_denormal_act_exponent = 1 - out_bias; // actual exponent of f8 denormal
// For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent
// bits
const int f8_bias = (1 << (DstT_exp - 1)) - 1 + (is_fnuz ? 1 : 0);
const int f8_denormal_act_exponent = 1 - f8_bias; // actual exponent of f8 denormal
// act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
// out_exponent is the converted f8 exponent with bias encoding
// f8_exponent is the converted f8 exponent with bias encoding
// exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
// the difference needs to be adjusted and mantissa shifted
int act_exponent, out_exponent, exponent_diff;
int act_exponent, f8_exponent, exponent_diff;
if(exponent == 0)
{ // fp32/fp16 is in denormal.
/* fp32 denormal is below 2^-127 so it is usually not a concern here, we mostly concern fp16
here. In this case, f8 is usually in denormal. But there could be exceptions. fp16 denormal has
exponent bias 15 while bf8 with NANOO has exponent bias 16. It means that there are some numbers in
fp16 denormal but they are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers
where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8 (NANOO) normal.
In this case, the fp16 mantissa should be shift left by 1 */
/* fp32 denormal is below 2^-127 so it is usually not a concern here, we
mostly concern fp16 here. In this case, f8 is usually in denormal. But there
could be exceptions. fp16 denormal has exponent bias 15 while bf8 with NANOO has
exponent bias 16. It means that there are some numbers in fp16 denormal but they
are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers
where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8
(NANOO) normal. In this case, the fp16 mantissa should be shift left by 1 */
act_exponent = exponent - bias + 1;
exponent_diff = out_denormal_act_exponent -
exponent_diff = f8_denormal_act_exponent -
act_exponent; // actual exponent is exponent-bias+1 as it is denormal
}
else
{ // fp32/fp16 is normal with implicit 1
act_exponent = exponent - bias;
if(act_exponent <= out_denormal_act_exponent)
if(act_exponent <= f8_denormal_act_exponent)
{
/* This is the case where fp32/fp16 is normal but it is in f8 denormal range.
For example fp8 nanoo mode, denormal exponent is -7, but if the fp32/fp16
actual exponent is -7, it is actually larger due to the implict 1,
Therefore it needs to be adjust to -6 and mantissa shift right by 1.
So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
exponent_diff = out_denormal_act_exponent - act_exponent;
/* This is the case where fp32/fp16 is normal but it is in f8 denormal
range. For example fp8 nanoo mode, denormal exponent is -7, but if the fp32/fp16
actual exponent is -7, it is actually larger due to the implicit 1,
Therefore it needs to be adjust to -6 and mantissa shift right by 1.
So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
exponent_diff = f8_denormal_act_exponent - act_exponent;
}
else
{ // both fp32/fp16 and f8 are in normal range
exponent_diff =
0; // exponent_diff=0 does not mean there is no difference for this case,
// act_exponent could be larger. Just that it does not need shift mantissa
{ // both fp32/fp16 and f8 are in normal range
exponent_diff = 0; // exponent_diff=0 does not mean there is no difference
// for this case, act_exponent could be larger. Just
// that it does not need shift mantissa
}
mantissa += (1 << in_mant); // Add the implicit 1 into mantissa
mantissa += (1ull << SrcT_mant); // Add the implicit 1 into mantissa
}
bool midpoint = (mantissa & ((1 << (in_mant - out_mant + exponent_diff)) - 1)) ==
(1 << (in_mant - out_mant + exponent_diff - 1));
/* This part is a bit tricky. The judgment of whether it is a tie needs to be done before we
shift right as shift right could rip off some residual part and make something not midpoint look
like midpoint. For example, the fp16 number 0x1002 (0 00100 0000000010), it is larger than
midpoint, but after shift right by 4 bits, it would look like midpoint. */
bool midpoint = (mantissa & ((1ull << (SrcT_mant - DstT_mant + exponent_diff)) - 1)) ==
(1ull << (SrcT_mant - DstT_mant + exponent_diff - 1));
/* This part is a bit tricky. The judgment of whether it is a tie needs to be
done before we shift right as shift right could rip off some residual part and
make something not midpoint look like midpoint. For example, the fp16 number
0x1002 (0 00100 0000000010), it is larger than midpoint, but after shift right
by 4 bits, it would look like midpoint.
*/
if(exponent_diff > 0)
mantissa >>= exponent_diff;
else if(exponent_diff == -1)
mantissa <<= -exponent_diff;
bool implicit_one = mantissa & (1 << in_mant);
// if there is no implict 1, it means the f8 is denormal and need to adjust to denorm exponent
out_exponent =
(act_exponent + exponent_diff) /*actual f8 exponent*/ + out_bias - (implicit_one ? 0 : 1);
bool implicit_one = mantissa & (1ull << SrcT_mant);
// if there is no implicit 1, it means the f8 is denormal and need to adjust
// to denorm exponent
f8_exponent =
(act_exponent + exponent_diff) /*actual f8 exponent*/ + f8_bias - (implicit_one ? 0 : 1);
// Now we have the exponent and mantissa adjusted
unsigned long long drop_mask = (1ull << (SrcT_mant - DstT_mant)) - 1;
bool odd =
mantissa &
(1 << (in_mant - out_mant)); // if the least significant bit that is not truncated is 1
mantissa += (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) & drop_mask;
mantissa & (1ull << (SrcT_mant -
DstT_mant)); // if the least significant bit that is not truncated is 1
mantissa +=
(stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1ull) : mantissa)) & drop_mask;
// Now we deal with overflow
if(out_exponent == 0)
if(f8_exponent == 0)
{
if((1 << in_mant) & mantissa)
if((1ull << SrcT_mant) & mantissa)
{
out_exponent = 1; // denormal overflow to become normal, promote exponent
// No need to make 1 implicit now as it will be addressed later
f8_exponent = 1; // denormal overflow to become normal, promote exponent
}
}
else
{
if((1 << (in_mant + 1)) & mantissa)
if((1ull << (SrcT_mant + 1)) & mantissa)
{
mantissa >>= 1;
out_exponent++;
// No need to make 1 implicit now as it will be addressed later
f8_exponent++;
}
}
mantissa >>= (in_mant - out_mant);
mantissa >>= (SrcT_mant - DstT_mant);
if(out_exponent > max_exp)
// above range: quantize to maximum possible float of the same sign
const int max_exp = (1 << DstT_exp) - 1;
if(f8_exponent > max_exp)
{
if(clip)
if constexpr(clip)
{
mantissa = (1 << out_mant) - 1;
out_exponent = max_exp;
mantissa = (1 << DstT_mant) - 1;
f8_exponent = max_exp;
}
else
{
return __builtin_bit_cast(Y, static_cast<uint8_t>(signed_inf));
return signed_inf;
}
}
// check if x is 0.0 or -0.0
if(out_exponent == 0 && mantissa == 0)
return __builtin_bit_cast(
Y, static_cast<uint8_t>(negative_zero_nan ? 0 : (sign << (out_exp + out_mant))));
mantissa &= (1 << out_mant) - 1;
return __builtin_bit_cast(Y,
static_cast<uint8_t>((sign << (out_exp + out_mant)) |
(out_exponent << out_mant) | mantissa));
if(f8_exponent == 0 && mantissa == 0)
return is_fnuz ? 0 : (sign << 7);
mantissa &= (1 << DstT_mant) - 1;
return (sign << 7) | (f8_exponent << DstT_mant) | mantissa;
}
template <typename X, typename Y, bool negative_zero_nan>
CK_TILE_HOST_DEVICE Y run_cast_from_f8(X x)
template <typename SrcT, typename DstT, bool clip = true>
CK_TILE_HOST_DEVICE DstT run_cast_from_f8(SrcT x)
{
// fp8/bf8 exponent/mantissa layout
constexpr int in_exp = numeric_traits<X>::exp;
constexpr int in_mant = numeric_traits<X>::mant;
// resulting type exponent/mantissa layout
constexpr int out_exp = numeric_traits<Y>::exp;
constexpr int out_mant = numeric_traits<Y>::mant;
uint8_t x_raw = __builtin_bit_cast(uint8_t, x);
// prepare the codes
constexpr uint8_t nan_code = 0x80;
Y Inf, NegInf, NaN, Neg0;
using T_bitwise = typename numeric_traits<Y>::bitwise_type;
constexpr T_bitwise Inf_bitwise = numeric_traits<Y>::Inf;
constexpr T_bitwise NegInf_bitwise = numeric_traits<Y>::NegInf;
constexpr T_bitwise NaN_bitwise = numeric_traits<Y>::NaN;
constexpr T_bitwise Neg0_bitwise = numeric_traits<Y>::Neg0;
Inf = *(reinterpret_cast<const Y*>(&Inf_bitwise));
NegInf = *(reinterpret_cast<const Y*>(&NegInf_bitwise));
NaN = *(reinterpret_cast<const Y*>(&NaN_bitwise));
Neg0 = *(reinterpret_cast<const Y*>(&Neg0_bitwise));
// check if x is 0.0
if(x_raw == 0)
return static_cast<Y>(0);
// unpack the input
uint32_t sign = x_raw >> (in_exp + in_mant);
uint32_t mantissa = x_raw & ((1 << in_mant) - 1);
int exponent = (x_raw & 0x7F) >> in_mant;
static_assert(std::is_same<SrcT, fp8_t>::value || std::is_same<SrcT, bf8_t>::value,
"SrcT type must be fp8 or bf8.");
constexpr int SrcT_exp = numeric_traits<SrcT>::exp;
constexpr int SrcT_mant = numeric_traits<SrcT>::mant;
constexpr bool is_fnuz =
(numeric_traits<SrcT>::f8_interpret == fp8_interpretation::E4M3_FNUZ) ||
(numeric_traits<SrcT>::f8_interpret == fp8_interpretation::E5M2_FNUZ);
constexpr bool is_half = std::is_same<DstT, half_t>::value;
constexpr bool is_float = std::is_same<DstT, float>::value;
static_assert(is_half || is_float, "DstT type must be half_t or float.");
// destination type exponent/mantissa layout
constexpr int DstT_exp = numeric_traits<DstT>::exp; // exponent width of the destination type
constexpr int DstT_mant = numeric_traits<DstT>::mant; // mantissa width of the destination type
constexpr DstT fInf = bit_cast<DstT>(numeric_traits<DstT>::Inf);
constexpr DstT fNegInf = bit_cast<DstT>(numeric_traits<DstT>::NegInf);
constexpr DstT fNaN = bit_cast<DstT>(numeric_traits<DstT>::NaN);
constexpr DstT fNeg0 = bit_cast<DstT>(numeric_traits<DstT>::Neg0);
DstT fmax{0}, fmin{0};
// Max number in e5m2 57344
if constexpr(is_half)
{
fmax = bit_cast<DstT>(static_cast<typename numeric_traits<DstT>::bitwise_type>(0x7B00));
fmin = bit_cast<DstT>(static_cast<typename numeric_traits<DstT>::bitwise_type>(0xFB00));
}
else if constexpr(is_float)
{
fmax = bit_cast<DstT>(static_cast<typename numeric_traits<DstT>::bitwise_type>(0x47600000));
fmin = bit_cast<DstT>(static_cast<typename numeric_traits<DstT>::bitwise_type>(0xC7600000));
}
constexpr int exp_low_cutoff =
(1 << (out_exp - 1)) - (1 << (in_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);
T_bitwise retval;
if(x == 0)
{
return 0;
}
if constexpr(negative_zero_nan)
unsigned long long sign = x >> 7;
unsigned long long mantissa = x & ((1 << SrcT_mant) - 1);
int exponent = (x & 0x7F) >> SrcT_mant;
if constexpr(is_fnuz)
{
if(x_raw == nan_code)
return NaN;
if(x == 0x80)
{
return fNaN;
}
}
else
{
if(x_raw == nan_code)
return Neg0;
if(exponent == ((1 << in_exp) - 1))
return (mantissa == 0) ? (sign ? NegInf : Inf) : NaN;
if(x == 0x80)
{
return fNeg0;
}
if constexpr(SrcT_exp == 4)
{ // e4m3
if((x & 0x7F) == 0x7F)
{
return fNaN;
}
}
else if((x & 0x7C) == 0x7C)
{ // e5m2
if((x & 0x3) == 0)
{
if constexpr(clip)
{
return sign ? fmin : fmax;
}
return sign ? fNegInf : fInf;
}
return fNaN;
}
}
if((numeric_traits<Y>::mant == 10) && (numeric_traits<X>::mant == 2) && !negative_zero_nan)
typename numeric_traits<DstT>::bitwise_type retval;
if constexpr(SrcT_exp == 5 && is_half && !is_fnuz)
{
retval = x_raw;
retval <<= 8;
return *(reinterpret_cast<const Y*>(&retval));
retval = x << 8;
return bit_cast<DstT>(retval);
}
const int exp_low_cutoff =
(1 << (DstT_exp - 1)) - (1 << (SrcT_exp - 1)) + 1 - (is_fnuz ? 1 : 0);
// subnormal input
if(exponent == 0)
{
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
int sh = 1 + clz(mantissa) - (32 - in_mant);
int sh = 1 + clz(mantissa) - (32 - SrcT_mant);
mantissa <<= sh;
exponent += 1 - sh;
mantissa &= ((1 << in_mant) - 1);
mantissa &= ((1ull << SrcT_mant) - 1);
}
exponent += exp_low_cutoff - 1;
mantissa <<= out_mant - in_mant;
mantissa <<= DstT_mant - SrcT_mant;
// subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
// subnormal output (occurs when DstT is half_t, we=5, is_fnuz=true)
if(exponent <= 0)
{
mantissa |= 1 << out_mant;
mantissa |= 1 << DstT_mant;
mantissa >>= 1 - exponent;
exponent = 0;
}
retval = (sign << (out_exp + out_mant)) | (exponent << out_mant) | mantissa;
return *(reinterpret_cast<const Y*>(&retval));
}
template <typename X, typename Y, bool negative_zero_nan, bool clip, bool stoch>
CK_TILE_HOST_DEVICE Y cast_to_f8(X x, uint32_t rng)
{
// check datatypes
constexpr bool is_half = std::is_same<X, half_t>::value;
constexpr bool is_float = std::is_same<X, float>::value;
static_assert(is_half || is_float, "Only half and float can be casted.");
retval = (sign << (DstT_exp + DstT_mant)) | (exponent << DstT_mant) | mantissa;
return run_cast_to_f8<X, Y, negative_zero_nan, clip, stoch>(x, rng);
return bit_cast<DstT>(retval);
}
template <typename X, typename Y, bool negative_zero_nan>
CK_TILE_HOST_DEVICE Y cast_from_f8(X x)
template <typename X, typename Y, bool clip, bool stoch>
CK_TILE_HOST_DEVICE Y cast_to_f8(X x, uint32_t rng)
{
// check datatype
constexpr bool is_half = std::is_same<Y, half_t>::value;
constexpr bool is_float = std::is_same<Y, float>::value;
static_assert(is_half || is_float, "only half and float are supported.");
return run_cast_from_f8<X, Y, negative_zero_nan>(x);
return bit_cast<Y>(run_cast_to_f8<X, Y, clip, stoch>(x, rng));
}
} // namespace impl
CK_TILE_HOST_DEVICE fp8_raw_t float_to_fp8_sr_raw(float x)
#if CK_TILE_FP8_CVT_DEVICE
/**
* @brief Cast float to fp8/bf8 using device conversion instructions
*/
template <fp8_interpretation interpret, bool saturate, bool stochastic_rounding = false>
CK_TILE_DEVICE uint8_t cast_to_f8_from_f32(float v, unsigned int rng = 0)
{
constexpr int seed = 42;
uint32_t rng = prand_generator_t<float, seed>{}(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx94__)
float max_fp8 = 240.0f;
x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
uint8_t i8data;
union
{
float fval;
uint32_t i32val;
uint8_t i8val[4]; // not endian independent
unsigned int i32val;
unsigned char i8val[4]; // NOTE: not endian independent
} val;
val.fval = x;
uint32_t ival = 0;
ival = __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0); // 0 pos
val.i32val = ival;
return val.i8val[0]; // little endian
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr fp8_rounding_mode rm = fp8_rounding_mode::stochastic;
return bit_cast<fp8_raw_t>(impl::cast_to_f8<float,
fp8_t,
negative_zero_nan,
clip,
(rm == fp8_rounding_mode::stochastic)>(x, rng));
#endif
}
CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_sr_raw(float x)
{
constexpr int seed = 42;
uint32_t rng = prand_generator_t<float, seed>{}(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx94__)
union
unsigned int ival = 0;
val.fval = v;
if constexpr(saturate)
{
float fval;
uint32_t i32val;
uint8_t i8val[4]; // not endian independent
} val;
val.fval = x;
uint32_t ival = 0;
ival = __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos
val.i32val = ival;
return val.i8val[0]; // little endian
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr fp8_rounding_mode rm = fp8_rounding_mode::stochastic;
return bit_cast<bf8_raw_t>(impl::cast_to_f8<float,
bf8_t,
negative_zero_nan,
clip,
(rm == fp8_rounding_mode::stochastic)>(x, rng));
#endif
if constexpr(interpret == fp8_interpretation::E4M3_FNUZ)
{
if((val.i32val & 0x7F800000) != 0x7F800000)
{ /// propagate NAN/INF, no clipping
val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0);
}
}
else if constexpr(interpret == fp8_interpretation::E4M3_OCP)
{ // OCP type
if((val.i32val & 0x7F800000) != 0x7F800000)
{ /// propagate NAN/INF, no clipping
val.fval = __builtin_amdgcn_fmed3f(val.fval, 448.0, -448.0);
}
}
else
{
if((val.i32val & 0x7F800000) != 0x7F800000)
{ /// propagate NAN/INF, no clipping
val.fval = __builtin_amdgcn_fmed3f(val.fval, 57344.0, -57344.0);
}
}
}
if constexpr(stochastic_rounding)
{
ival = (interpret == fp8_interpretation::E4M3_FNUZ) ||
(interpret == fp8_interpretation::E4M3_OCP)
? __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0)
: __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos
val.i32val = ival;
i8data = val.i8val[0]; // little endian
}
else
{ // RNE CVT
ival = (interpret == fp8_interpretation::E4M3_FNUZ) ||
(interpret == fp8_interpretation::E4M3_OCP)
? __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false)
: __builtin_amdgcn_cvt_pk_bf8_f32(val.fval,
val.fval,
ival,
false); // false -> WORD0
val.i32val = ival;
i8data = val.i8val[0];
}
return i8data;
}
#endif // CK_TILE_FP8_CVT_DEVICE
CK_TILE_HOST_DEVICE fp8_raw_t float_to_fp8_rtn_raw(float x)
} // namespace impl
/**
* @brief Converts a floating-point value to an 8-bit floating-point representation with stochastic
* rounding.
*
* This function converts a floating-point value (float or half_t) to an 8-bit floating-point
* representation of type fp8_t or bf8_t. The conversion process may
* involve clipping and uses a pseudo-random number generator for the stochastic rounding.
*
* @tparam DstT The destination type (fp8_t or bf8_t).
* @tparam SrcT The source type (float or half_t) to be converted.
* @param x The floating-point value to be converted.
* @return The 8-bit floating-point representation of the input value.
*/
template <typename SrcT, typename DstT>
CK_TILE_HOST_DEVICE typename numeric_traits<DstT>::bitwise_type float_to_fp8_sr_raw(SrcT x)
{
#if defined(__gfx94__)
float max_fp8 = 240.0f;
x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
union
{
float fval;
uint32_t i32val;
uint8_t i8val[4]; // not endian independent
} val;
val.fval = x;
uint32_t ival = 0;
ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false); // false -> WORD0
val.i32val = ival;
return val.i8val[0];
constexpr bool clip = true;
constexpr int seed = 42;
uint32_t rng = prand_generator_t<SrcT, seed>{}(reinterpret_cast<uintptr_t>(&x), x);
#if CK_TILE_FP8_CVT_DEVICE
return impl::cast_to_f8_from_f32<numeric_traits<DstT>::f8_interpret, clip, true>(x, rng);
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr fp8_rounding_mode rm = fp8_rounding_mode::standard;
constexpr uint32_t rng = 0;
return bit_cast<fp8_raw_t>(impl::cast_to_f8<float,
fp8_t,
negative_zero_nan,
clip,
(rm == fp8_rounding_mode::stochastic)>(x, rng));
return bit_cast<typename numeric_traits<DstT>::bitwise_type>(
impl::cast_to_f8<SrcT, DstT, clip, true>(x, rng));
#endif
}
CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_rtn_raw(float x)
/**
* @brief Converts a floating-point value to an 8-bit floating-point representation with rounding to
* nearest even.
*
* This function converts a floating-point value (float or half_t) to an 8-bit floating-point
* representation of type fp8_t or bf8_t. The conversion process may involve clipping.
*
* @tparam DstT The destination type (fp8_t or bf8_t).
* @tparam SrcT The source type (float or half_t) to be converted.
* @param x The floating-point value to be converted.
* @return The 8-bit floating-point representation of the input value.
*/
template <typename SrcT, typename DstT>
CK_TILE_HOST_DEVICE typename numeric_traits<DstT>::bitwise_type float_to_fp8_rtn_raw(SrcT x)
{
#if defined(__gfx94__)
union
{
float fval;
uint32_t i32val;
uint8_t i8val[4]; // not endian independent
} val;
val.fval = x;
uint32_t ival = 0;
ival = __builtin_amdgcn_cvt_pk_bf8_f32(val.fval, val.fval, ival, false); // false -> WORD0
val.i32val = ival;
return val.i8val[0];
constexpr bool clip = true;
#if CK_TILE_FP8_CVT_DEVICE
return impl::cast_to_f8_from_f32<numeric_traits<DstT>::f8_interpret, clip, false>(x, 0);
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr fp8_rounding_mode rm = fp8_rounding_mode::standard;
constexpr uint32_t rng = 0;
return bit_cast<bf8_raw_t>(impl::cast_to_f8<float,
bf8_t,
negative_zero_nan,
clip,
(rm == fp8_rounding_mode::stochastic)>(x, rng));
return bit_cast<typename numeric_traits<DstT>::bitwise_type>(
impl::cast_to_f8<SrcT, DstT, clip, false>(x, 0));
#endif
}
// clang-format off
template<fp8_rounding_mode rounding>
template <fp8_rounding_mode rounding>
CK_TILE_HOST_DEVICE fp8_raw_t float_to_fp8_raw(float x, constant<rounding>)
{
if constexpr (rounding == fp8_rounding_mode::standard) return float_to_fp8_rtn_raw(x);
else if constexpr (rounding == fp8_rounding_mode::stochastic) return float_to_fp8_sr_raw(x);
else return fp8_raw_t{0};
if constexpr(rounding == fp8_rounding_mode::standard)
{
return float_to_fp8_rtn_raw<float, fp8_t>(x);
}
else if constexpr(rounding == fp8_rounding_mode::stochastic)
{
return float_to_fp8_sr_raw<float, fp8_t>(x);
}
else
{
return fp8_raw_t{0};
}
}
template<fp8_rounding_mode rounding>
template <fp8_rounding_mode rounding>
CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_raw(float x, constant<rounding>)
{
if constexpr (rounding == fp8_rounding_mode::standard) return float_to_bf8_rtn_raw(x);
else if constexpr (rounding == fp8_rounding_mode::stochastic) return float_to_bf8_sr_raw(x);
else return bf8_raw_t{0};
if constexpr(rounding == fp8_rounding_mode::standard)
{
return float_to_fp8_rtn_raw<float, bf8_t>(x);
}
else if constexpr(rounding == fp8_rounding_mode::stochastic)
{
return float_to_fp8_sr_raw<float, bf8_t>(x);
}
else
{
return bf8_raw_t{0};
}
}
CK_TILE_HOST_DEVICE float fp8_to_float_raw(fp8_raw_t x)
{
#if defined(__gfx94__)
#if CK_TILE_FP8_CVT_DEVICE
float fval;
uint32_t i32val = static_cast<uint32_t>(x);
fval = __builtin_amdgcn_cvt_f32_fp8(i32val, 0);
// asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
return fval;
#else
constexpr bool negative_zero_nan = true;
return impl::cast_from_f8<fp8_t, float, negative_zero_nan>(bit_cast<fp8_t>(x));
return impl::run_cast_from_f8<fp8_t, float>(bit_cast<fp8_t>(x));
#endif
}
CK_TILE_HOST_DEVICE float bf8_to_float_raw(bf8_raw_t x)
{
#if defined(__gfx94__)
#if CK_TILE_FP8_CVT_DEVICE
float fval;
uint32_t i32val = static_cast<uint32_t>(x);
fval = __builtin_amdgcn_cvt_f32_bf8(i32val, 0);
// asm volatile("v_cvt_f32_bf8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
return fval;
#else
constexpr bool negative_zero_nan = true;
return impl::cast_from_f8<bf8_t, float, negative_zero_nan>(bit_cast<bf8_t>(x));
return impl::run_cast_from_f8<bf8_t, float>(bit_cast<bf8_t>(x));
#endif
}
template<fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
template <fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
CK_TILE_HOST_DEVICE fp8_t float_to_fp8(float x, constant<rounding> = {})
{
return bit_cast<fp8_t>(float_to_fp8_raw(x, constant<rounding>{}));
}
template<fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
template <fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
CK_TILE_HOST_DEVICE bf8_t float_to_bf8(float x, constant<rounding> = {})
{
return bit_cast<bf8_t>(float_to_bf8_raw(x, constant<rounding>{}));
}
CK_TILE_HOST_DEVICE float fp8_to_float(fp8_t x)
{
return fp8_to_float_raw(bit_cast<fp8_raw_t>(x));
}
CK_TILE_HOST_DEVICE float fp8_to_float(fp8_t x) { return fp8_to_float_raw(bit_cast<fp8_raw_t>(x)); }
CK_TILE_HOST_DEVICE float bf8_to_float(bf8_t x)
{
return bf8_to_float_raw(bit_cast<bf8_raw_t>(x));
}
CK_TILE_HOST_DEVICE float bf8_to_float(bf8_t x) { return bf8_to_float_raw(bit_cast<bf8_raw_t>(x)); }
// clang-format on
template <typename T>
struct numeric_traits;
template <class T>
struct numeric;
#if CK_TILE_USE_OCP_FP8
template <>
struct numeric_traits<fp8_t>
struct numeric<fp8_t>
{
static constexpr int exp = 4;
static constexpr int mant = 3;
#if defined(__gfx94__)
static constexpr int bias = 8;
#else
static constexpr int bias = 7;
#endif
// minimum finite value, or minimum positive normal value
CK_TILE_HOST_DEVICE static constexpr fp8_t min()
{
return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x08)); // 0b00001000 = 2^-6
}
// minumum finite value
CK_TILE_HOST_DEVICE static constexpr fp8_t lowest()
{
return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0xfe)); // 0b11111110 = -448
}
// maximum finite value
CK_TILE_HOST_DEVICE static constexpr fp8_t max()
{
return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x7e)); // 0b01111110 = 448
}
// difference between 1.0 and next representable f8 value (1.125)
// returns fp8_t(0.125)
CK_TILE_HOST_DEVICE static constexpr fp8_t epsilon()
{
return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x20)); // 0.125
}
// rounding error (0.0625)
// half of epsilon
CK_TILE_HOST_DEVICE static constexpr fp8_t round_error()
{
return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x18)); // 0.0625
}
// quiet NaN
CK_TILE_HOST_DEVICE static constexpr fp8_t quiet_NaN()
{
return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x7F)); // 0b01111111
}
// signaling NaN
CK_TILE_HOST_DEVICE static constexpr fp8_t signaling_NaN()
{
return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0xFF)); // 0b11111111
}
// smallest positive subnormal value
CK_TILE_HOST_DEVICE static constexpr fp8_t denorm_min()
{
return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x01));
}
CK_TILE_HOST_DEVICE static constexpr fp8_t zero()
{
return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0));
}
};
template <>
struct numeric_traits<bf8_t>
struct numeric<bf8_t>
{
static constexpr int exp = 5;
static constexpr int mant = 2;
#if defined(__gfx94__)
static constexpr int bias = 16;
#else
static constexpr int bias = 15; // IEEE
#endif
};
// minimum finite value, or minimum positive normalized value for float
CK_TILE_HOST_DEVICE static constexpr bf8_t min()
{
return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x04)); // 0b00000100 = 2^-14
}
template <class T>
struct numeric;
// minumum finite value
CK_TILE_HOST_DEVICE static constexpr bf8_t lowest()
{
return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0xfb)); // 0b11111011 = -57344
}
// maximum finite value
CK_TILE_HOST_DEVICE static constexpr bf8_t max()
{
return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x7b)); // 0b01111011 = 57344
}
// difference between 1.0 and next representable bf8 value (1.25)
CK_TILE_HOST_DEVICE static constexpr bf8_t epsilon()
{
return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x34)); // 0.25
}
// rounding error (0.125)
// half of epsilon
CK_TILE_HOST_DEVICE static constexpr bf8_t round_error()
{
return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x30)); // 0.125
}
// positive infinity value
CK_TILE_HOST_DEVICE static constexpr bf8_t infinity()
{
return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x7c)); // 0b01111100
}
// quiet NaN
CK_TILE_HOST_DEVICE static constexpr bf8_t quiet_NaN()
{
return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x7F)); // 0b01111111
}
// signaling NaN
CK_TILE_HOST_DEVICE static constexpr bf8_t signaling_NaN()
{
return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0xFF));
}
// smallest positive subnormal value
CK_TILE_HOST_DEVICE static constexpr bf8_t denorm_min()
{
return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x01));
}
CK_TILE_HOST_DEVICE static constexpr bf8_t zero()
{
return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0));
}
};
#else
template <>
struct numeric<fp8_t>
{
......@@ -811,6 +1054,7 @@ struct numeric<bf8_t>
return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0));
}
};
#endif
#if CK_TILE_USE_CUSTOM_DATA_TYPE
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, fp8_t)
......@@ -818,19 +1062,26 @@ CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, bf8_t)
#endif
// math
CK_TILE_HOST_DEVICE
fp8_t abs(const fp8_t& x)
template <typename T>
CK_TILE_HOST_DEVICE T abs(const T& x)
{
return bit_cast<fp8_t>(static_cast<fp8_raw_t>(bit_cast<fp8_raw_t>(x) & 0x7f));
static_assert(std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>,
"Only fp8_t and bf8_t are supported");
return bit_cast<T>(static_cast<uint8_t>(bit_cast<uint8_t>(x) & numeric_traits<T>::abs_mask));
}
CK_TILE_HOST_DEVICE
bool isnan(const fp8_t& x)
{
uint8_t xx = bit_cast<fp8_raw_t>(x);
return xx == 0x80; // TODO: NANOO
}
#if CK_TILE_USE_OCP_FP8
return (xx & 0x7f) == 0x7f;
#else
return xx == 0x80;
#endif
}
#if CK_TILE_USE_CUSTOM_DATA_TYPE
CK_TILE_DEVICE
fp8_t sqrt(fp8_t x) { return static_cast<fp8_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x))); };
......@@ -842,20 +1093,21 @@ fp8_t exp2(fp8_t x) { return static_cast<fp8_t>(exp2f(static_cast<float>(x))); }
CK_TILE_DEVICE
fp8_t log(fp8_t x) { return static_cast<fp8_t>(__logf(static_cast<float>(x))); };
CK_TILE_HOST_DEVICE
bf8_t abs(const bf8_t& x)
{
return bit_cast<bf8_t>(static_cast<fp8_raw_t>(bit_cast<bf8_raw_t>(x) & 0x7f));
}
#endif
CK_TILE_HOST_DEVICE
bool isnan(const bf8_t& x)
{
uint8_t xx = bit_cast<bf8_raw_t>(x);
return xx == 0x80; // TODO: NANOO
#if CK_TILE_USE_OCP_FP8
return (xx & 0x7f) > 0x7c;
#else
return xx == 0x80;
#endif
}
#if CK_TILE_USE_CUSTOM_DATA_TYPE
CK_TILE_DEVICE
bf8_t sqrt(bf8_t x) { return static_cast<bf8_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x))); };
......@@ -867,5 +1119,6 @@ bf8_t exp2(bf8_t x) { return static_cast<bf8_t>(exp2f(static_cast<float>(x))); }
CK_TILE_DEVICE
bf8_t log(bf8_t x) { return static_cast<bf8_t>(__logf(static_cast<float>(x))); };
#endif
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
......@@ -236,10 +236,11 @@ struct numeric_traits<half_t>
static constexpr uint16_t head_mask = 0xFC00;
static constexpr uint16_t mant_mask = 0x3FF;
static constexpr uint16_t exp_mask = 0x1F;
static constexpr uint32_t Inf = 0x7C00;
static constexpr uint32_t NegInf = 0xFC00;
static constexpr uint32_t NaN = 0x7C01;
static constexpr uint32_t Neg0 = 0x8000;
static constexpr uint16_t abs_mask = 0x7FFF;
static constexpr uint16_t Inf = 0x7C00;
static constexpr uint16_t NegInf = 0xFC00;
static constexpr uint16_t NaN = 0x7C01;
static constexpr uint16_t Neg0 = 0x8000;
using bitwise_type = uint16_t;
};
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -89,6 +89,7 @@ struct numeric_traits<float>
static constexpr uint32_t head_mask = 0xFF800000;
static constexpr uint32_t mant_mask = 0x7FFFFF;
static constexpr uint32_t exp_mask = 0xFF;
static constexpr uint32_t abs_mask = 0x7FFFFFFF;
static constexpr uint32_t Inf = 0x7F800000;
static constexpr uint32_t NegInf = 0xFF800000;
static constexpr uint32_t NaN = 0x7F800001;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -18,8 +18,17 @@
namespace ck_tile {
// Note: this tile window do not support single issue
// you need to use tile_window_linear structure for this purpose
/**
* @brief This class provides tile (windowed) view and access to the device memory.
*
* @note This tile window does not support single issue you need to use tile_window_linear
* structure for this purpose
*
* @tparam BottomTensorView_ Class describing & holding device tensor memory.
* @tparam WindowLengths_ Spatial sizes of windowed view on tensor.
* @tparam StaticTileDistribution_ Thread distribution (mapping) into Tile dimensions
* @tparam NumCoord TBD
*/
template <typename BottomTensorView_,
typename WindowLengths_,
typename StaticTileDistribution_,
......@@ -1009,6 +1018,14 @@ CK_TILE_DEVICE void move_tile_window(
window.move(step);
}
/**
* @brief This class provides description of tile windowed view on the device memory.
*
* @note This class does not provide any functions to read or modify device memory.
*
* @tparam BottomTensorView_ Class describing & holding device tensor memory.
* @tparam WindowLengths_ Spatial sizes of windowed view on tensor.
*/
template <typename BottomTensorView_, typename WindowLengths_>
struct tile_window_with_static_lengths
{
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/container/thread_buffer.hpp"
#include "ck_tile/core/container/statically_indexed_array.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/tensor/tile_elementwise.hpp"
#include "ck_tile/core/utility/transpose_vectors.hpp"
namespace ck_tile {
namespace detail {
template <typename OutTensor, typename InTensor>
CK_TILE_DEVICE void transpose_tile2d_impl_in_thread(OutTensor& out_tensor,
const InTensor& in_tensor)
{
constexpr auto I0 = number<0>{};
static_assert(std::is_same_v<typename InTensor::DataType, typename OutTensor::DataType>,
"Data type for InTensor and OutTensor must be the same!");
using DataType = typename InTensor::DataType;
constexpr auto y_in_desc = InTensor::get_tile_distribution().get_ys_to_d_descriptor();
constexpr auto y_out_desc = OutTensor::get_tile_distribution().get_ys_to_d_descriptor();
// y_dim_out_to_in
// For swapped Hs tile case I need only get_rh_minor_to_y
// since rh_major are already swapped due to swapped Hs.
constexpr auto get_rh_minor_to_y = [](auto dstr_tensor) {
using DstrEncode = typename decltype(dstr_tensor.get_tile_distribution())::DstrEncode;
map<index_t, index_t> rh_minor_to_y_;
static_for<0, DstrEncode::NDimY, 1>{}([&](auto i) {
constexpr index_t rh_minor = DstrEncode::ys_to_rhs_minor_[i];
rh_minor_to_y_(rh_minor) = i;
});
return rh_minor_to_y_;
};
// In swapped Hs case <Y,X> -> <X,Y> tile
// we have same rh_major, but reversed rh_minor!
constexpr auto rh_minor_to_y_in = get_rh_minor_to_y(InTensor{});
constexpr auto rh_minor_to_y_out = get_rh_minor_to_y(OutTensor{});
// Is this really needed?? Should we have simple reverse here??
constexpr auto y_dim_out_to_in = [&] {
map<index_t, index_t> y_dim_out_to_in_;
for(const auto& [rh_minor, y_out] : rh_minor_to_y_out)
{
y_dim_out_to_in_(y_out) = rh_minor_to_y_in[rh_minor];
}
return y_dim_out_to_in_;
}();
constexpr index_t NDimY = InTensor::get_tile_distribution().get_num_of_dimension_y();
constexpr auto y_lengths = to_sequence(y_in_desc.get_lengths());
// input and output vector dim in the order of input Y dims
constexpr index_t y_dim_vec_in = NDimY - 1;
constexpr index_t y_dim_vec_out = y_dim_out_to_in[NDimY - 1];
// vector lengths
constexpr index_t vec_length_in = y_lengths[y_dim_vec_in];
constexpr index_t vec_length_out = y_lengths[y_dim_vec_out];
// # of vectors
constexpr index_t num_vec_in = vec_length_out;
constexpr index_t num_vec_out = vec_length_in;
using InVec = array<DataType, vec_length_in>;
using OutVec = array<DataType, vec_length_out>;
// SFC
constexpr auto scalars_per_access_arr = generate_array(
[&](auto i) { return (i == y_dim_vec_in or i == y_dim_vec_out) ? y_lengths[i] : 1; },
number<NDimY>{});
constexpr auto scalars_per_access = TO_SEQUENCE(scalars_per_access_arr, NDimY);
using SFC_Y = space_filling_curve<decltype(y_lengths),
typename arithmetic_sequence_gen<0, NDimY, 1>::type,
decltype(scalars_per_access)>;
constexpr index_t num_access = SFC_Y::get_num_of_access();
static_assert(num_access > 0, "wrong! num_access should be larger than 0");
// in/out vectors to be transposed
thread_buffer<InVec, num_vec_in> in_vectors;
thread_buffer<OutVec, num_vec_out> out_vectors;
// loop over SFC and do transpose
static_for<0, num_access, 1>{}([&](auto iAccess) {
// data index [y0, y1, ...] in the order of input tensor
constexpr auto idx_y_start = SFC_Y::get_index(iAccess);
// get input vectors
static_for<0, num_vec_in, 1>{}([&](auto i) {
constexpr auto idx_y_in = generate_tuple(
[&](auto ii) {
return ii == y_dim_vec_out ? idx_y_start[ii] + i : idx_y_start[ii];
},
number<NDimY>{});
constexpr index_t in_offset = y_in_desc.calculate_offset(idx_y_in);
static_assert(in_offset % vec_length_in == 0);
in_vectors(i).template get_as<InVec>()(I0) =
in_tensor.get_thread_buffer()
.template get_as<InVec>()[number<in_offset / vec_length_in>{}];
});
// transpose
transpose_vectors<DataType, num_vec_in, num_vec_out>{}(in_vectors, out_vectors);
// set output vectors
static_for<0, num_vec_out, 1>{}([&](auto i) {
constexpr auto idx_y_out_tmp = generate_array(
[&](auto ii) { return ii == y_dim_vec_in ? idx_y_start[ii] + i : idx_y_start[ii]; },
number<NDimY>{});
constexpr auto idx_y_out =
container_reorder_given_new2old(idx_y_out_tmp, y_dim_out_to_in);
constexpr index_t out_offset = y_out_desc.calculate_offset(idx_y_out);
static_assert(out_offset % vec_length_out == 0);
out_tensor.get_thread_buffer().template set_as<OutVec>(
number<out_offset / vec_length_out>{},
out_vectors[i].template get_as<OutVec>()[I0]);
});
});
}
} // namespace detail
template <typename OutTensor, typename InTensor>
CK_TILE_DEVICE void transpose_tile2d(OutTensor& out, const InTensor& in)
{
using InDataType = typename InTensor::DataType;
using OutDataType = typename OutTensor::DataType;
using InTileDistr = typename InTensor::StaticTileDistribution;
using OutTileDistr = typename OutTensor::StaticTileDistribution;
using InDstrEncode = typename InTileDistr::DstrEncode;
using OutDstrEncode = typename OutTileDistr::DstrEncode;
using InThreadTensorDesc = typename InTensor::ThreadTensorDesc;
using OutThreadTensorDesc = typename OutTensor::ThreadTensorDesc;
// Ys:
constexpr auto in_thread_desc_lengths = InThreadTensorDesc{}.get_lengths();
constexpr auto out_thread_desc_lengths = OutThreadTensorDesc{}.get_lengths();
// type convert
const auto in_tmp = [&]() {
if constexpr(std::is_same_v<OutDataType, InDataType>)
{
return in;
}
else
{
return tile_elementwise_in(type_convert<OutDataType, InDataType>, in);
}
}();
// Scenario where we switch from tile <Y, X> -> <X, Y> - only 2D tiles!
// we preserve Ps but swap Ys: <Y1, Y0> -> <Y0, Y1>
if constexpr(InDstrEncode::rs_lengths_ == OutDstrEncode::rs_lengths_ &&
InDstrEncode::hs_lengthss_ == tuple_reverse(OutDstrEncode::hs_lengthss_) &&
InDstrEncode::NDimY == OutDstrEncode::NDimY && InDstrEncode::NDimY == 2 &&
in_thread_desc_lengths == tuple_reverse(out_thread_desc_lengths))
// Any condition on Ps ??
// InDstrEncode::ps_to_rhss_major_ == OutDstrEncode::ps_to_rhss_major_ &&
// InDstrEncode::ps_to_rhss_minor_ == OutDstrEncode::ps_to_rhss_minor_ &&
{
detail::transpose_tile2d_impl_in_thread(out, in_tmp);
}
else
{
static_assert(false, "Provided tensors could not be transposed!");
}
}
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
// Address Space for AMDGCN
// https://llvm.org/docs/AMDGPUUsage.html#address-space
namespace ck_tile {
#define CK_CONSTANT_ADDRESS_SPACE __attribute__((address_space(4)))
template <typename T>
__device__ T* cast_pointer_to_generic_address_space(T CK_CONSTANT_ADDRESS_SPACE* p)
{
// cast a pointer in "Constant" address space (4) to "Generic" address space (0)
// only c-style pointer cast seems be able to be compiled
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
return (T*)p; // NOLINT(old-style-cast)
#pragma clang diagnostic pop
}
template <typename T>
__host__ __device__ T CK_CONSTANT_ADDRESS_SPACE* cast_pointer_to_constant_address_space(T* p)
{
// cast a pointer in "Generic" address space (0) to "Constant" address space (4)
// only c-style pointer cast seems be able to be compiled
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
return (T CK_CONSTANT_ADDRESS_SPACE*)p; // NOLINT(old-style-cast)
#pragma clang diagnostic pop
}
} // namespace ck_tile
......@@ -109,4 +109,22 @@ CK_TILE_HOST_DEVICE PY c_style_pointer_cast(PX p_x)
#pragma clang diagnostic pop
}
template <typename CompareTo, typename... Rest>
struct is_any_of : std::false_type
{
};
template <typename CompareTo, typename FirstType>
struct is_any_of<CompareTo, FirstType> : std::is_same<CompareTo, FirstType>
{
};
template <typename CompareTo, typename FirstType, typename... Rest>
struct is_any_of<CompareTo, FirstType, Rest...>
: std::integral_constant<bool,
std::is_same<CompareTo, FirstType>::value ||
is_any_of<CompareTo, Rest...>::value>
{
};
} // namespace ck_tile
......@@ -51,16 +51,18 @@ struct composes<F>
template <typename... Ts>
__host__ __device__ composes(Ts&&...)->composes<remove_cvref_t<Ts>...>;
template <typename To>
template <typename SaturateType>
struct saturates
{
template <typename From>
CK_TILE_HOST_DEVICE constexpr auto operator()(const From& from) const
-> std::enable_if_t<std::is_arithmetic_v<From>, From>
// NOTE: this function does not return SaturateType value
// it is user's responsiblity to do further cast or not
template <typename AccType>
CK_TILE_HOST_DEVICE constexpr auto operator()(const AccType& a_) const
-> std::enable_if_t<std::is_arithmetic_v<AccType>, AccType>
{
return clamp(from,
type_convert<From>(numeric<To>::lowest()),
type_convert<From>(numeric<To>::max()));
return clamp(a_,
type_convert<AccType>(numeric<SaturateType>::lowest()),
type_convert<AccType>(numeric<SaturateType>::max()));
}
};
......
......@@ -34,3 +34,4 @@
#include "ck_tile/host/reference/reference_topk.hpp"
#include "ck_tile/host/stream_config.hpp"
#include "ck_tile/host/timer.hpp"
#include "ck_tile/host/reference/reference_batched_transpose.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -18,6 +18,112 @@
namespace ck_tile {
template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType>
double get_relative_threshold(const int number_of_accumulations = 1)
{
using F8 = ck_tile::fp8_t;
using F16 = ck_tile::half_t;
using BF16 = ck_tile::bf16_t;
using F32 = float;
using I8 = int8_t;
using I32 = int32_t;
static_assert(is_any_of<ComputeDataType, F8, F16, BF16, F32, I8, I32, int>::value,
"Warning: Unhandled ComputeDataType for setting up the relative threshold!");
double compute_error = 0;
if constexpr(is_any_of<ComputeDataType, I8, I32, int>::value)
{
return 0;
}
else
{
compute_error = std::pow(2, -numeric_traits<ComputeDataType>::mant) * 0.5;
}
static_assert(is_any_of<OutDataType, F8, F16, BF16, F32, I8, I32, int>::value,
"Warning: Unhandled OutDataType for setting up the relative threshold!");
double output_error = 0;
if constexpr(is_any_of<OutDataType, I8, I32, int>::value)
{
return 0;
}
else
{
output_error = std::pow(2, -numeric_traits<OutDataType>::mant) * 0.5;
}
double midway_error = std::max(compute_error, output_error);
static_assert(is_any_of<AccDataType, F8, F16, BF16, F32, I8, I32, int>::value,
"Warning: Unhandled AccDataType for setting up the relative threshold!");
double acc_error = 0;
if constexpr(is_any_of<AccDataType, I8, I32, int>::value)
{
return 0;
}
else
{
acc_error = std::pow(2, -numeric_traits<AccDataType>::mant) * 0.5 * number_of_accumulations;
}
return std::max(acc_error, midway_error);
}
template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType>
double get_absolute_threshold(const double max_possible_num, const int number_of_accumulations = 1)
{
using F8 = ck_tile::fp8_t;
using F16 = ck_tile::half_t;
using BF16 = ck_tile::bf16_t;
using F32 = float;
using I8 = int8_t;
using I32 = int32_t;
static_assert(is_any_of<ComputeDataType, F8, F16, BF16, F32, I8, I32, int>::value,
"Warning: Unhandled ComputeDataType for setting up the absolute threshold!");
auto expo = std::log2(std::abs(max_possible_num));
double compute_error = 0;
if constexpr(is_any_of<ComputeDataType, I8, I32, int>::value)
{
return 0;
}
else
{
compute_error = std::pow(2, expo - numeric_traits<ComputeDataType>::mant) * 0.5;
}
static_assert(is_any_of<OutDataType, F8, F16, BF16, F32, I8, I32, int>::value,
"Warning: Unhandled OutDataType for setting up the absolute threshold!");
double output_error = 0;
if constexpr(is_any_of<OutDataType, I8, I32, int>::value)
{
return 0;
}
else
{
output_error = std::pow(2, expo - numeric_traits<OutDataType>::mant) * 0.5;
}
double midway_error = std::max(compute_error, output_error);
static_assert(is_any_of<AccDataType, F8, F16, BF16, F32, I8, I32, int>::value,
"Warning: Unhandled AccDataType for setting up the absolute threshold!");
double acc_error = 0;
if constexpr(is_any_of<AccDataType, I8, I32, int>::value)
{
return 0;
}
else
{
acc_error =
std::pow(2, expo - numeric_traits<AccDataType>::mant) * 0.5 * number_of_accumulations;
}
return std::max(acc_error, midway_error);
}
template <typename T>
std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
{
......@@ -337,7 +443,11 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
}
if(!res)
{
std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
const float error_percent =
static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
std::cerr << "max err: " << max_err;
std::cerr << ", number of errors: " << err_count;
std::cerr << ", " << error_percent << "% wrong values" << std::endl;
}
return res;
}
......
......@@ -14,57 +14,41 @@ namespace detail {
template <typename OldLayout>
CK_TILE_HOST std::vector<std::size_t> get_layout_transpose_gnchw_to_old()
{
if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNCW> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKCX> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNKW>)
using namespace ck_tile::tensor_layout::convolution;
if constexpr(is_any_of<OldLayout, GNCW, GKCX, GNKW>::value)
{
return {0, 1, 2, 3};
}
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNCHW> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKCYX> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNKHW>)
else if constexpr(is_any_of<OldLayout, GNCHW, GKCYX, GNKHW>::value)
{
return {0, 1, 2, 3, 4};
}
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNCDHW> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKCZYX> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNKDHW>)
else if constexpr(is_any_of<OldLayout, GNCDHW, GKCZYX, GNKDHW>::value)
{
return {0, 1, 2, 3, 4, 5};
}
if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNWC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKXC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNWK>)
if constexpr(is_any_of<OldLayout, GNWC, GKXC, GNWK>::value)
{
return {0, 1, 3, 2};
}
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNHWC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKYXC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNHWK>)
else if constexpr(is_any_of<OldLayout, GNHWC, GKYXC, GNHWK>::value)
{
return {0, 1, 4, 2, 3};
}
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNDHWC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKZYXC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNDHWK>)
else if constexpr(is_any_of<OldLayout, GNDHWC, GKZYXC, GNDHWK>::value)
{
return {0, 1, 5, 2, 3, 4};
}
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NWGC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::KXGC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NWGK>)
else if constexpr(is_any_of<OldLayout, NWGC, KXGC, NWGK>::value)
{
return {2, 0, 3, 1};
}
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NHWGC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::KYXGC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NHWGK>)
else if constexpr(is_any_of<OldLayout, NHWGC, KYXGC, NHWGK>::value)
{
return {3, 0, 4, 1, 2};
}
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NDHWGC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::KZYXGC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NDHWGK>)
else if constexpr(is_any_of<OldLayout, NDHWGC, KZYXGC, NDHWGK>::value)
{
return {4, 0, 5, 1, 2, 3};
}
......@@ -83,11 +67,11 @@ template <typename InLayout>
CK_TILE_HOST HostTensorDescriptor
make_input_host_tensor_descriptor_g_n_c_wis_packed(const ck_tile::conv::ConvParam& param)
{
using namespace ck_tile::tensor_layout::convolution;
std::vector<std::size_t> physical_lengths;
if constexpr(std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNCW> ||
std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNCHW> ||
std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNCDHW>)
if constexpr(is_any_of<InLayout, GNCW, GNCHW, GNCDHW>::value)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.N_),
......@@ -97,9 +81,7 @@ make_input_host_tensor_descriptor_g_n_c_wis_packed(const ck_tile::conv::ConvPara
param.input_spatial_lengths_.begin(),
param.input_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else if constexpr(std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNWC> ||
std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNHWC> ||
std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNDHWC>)
else if constexpr(is_any_of<InLayout, GNWC, GNHWC, GNDHWC>::value)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.N_),
......@@ -109,9 +91,7 @@ make_input_host_tensor_descriptor_g_n_c_wis_packed(const ck_tile::conv::ConvPara
param.input_spatial_lengths_.begin(),
param.input_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else if constexpr(std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::NWGC> ||
std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::NHWGC> ||
std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::NDHWGC>)
else if constexpr(is_any_of<InLayout, NWGC, NHWGC, NDHWGC>::value)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.N_),
static_cast<std::size_t>(param.G_),
......@@ -139,11 +119,11 @@ template <typename WeiLayout>
CK_TILE_HOST HostTensorDescriptor
make_weight_host_tensor_descriptor_g_k_c_xs_packed(const ck_tile::conv::ConvParam& param)
{
using namespace ck_tile::tensor_layout::convolution;
std::vector<std::size_t> physical_lengths;
if constexpr(std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KXC> ||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KYXC> ||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KZYXC>)
if constexpr(is_any_of<WeiLayout, KXC, KYXC, KZYXC>::value)
{
if(param.G_ != 1)
{
......@@ -157,9 +137,7 @@ make_weight_host_tensor_descriptor_g_k_c_xs_packed(const ck_tile::conv::ConvPara
param.filter_spatial_lengths_.begin(),
param.filter_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else if constexpr(std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKCX> ||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKCYX> ||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKCZYX>)
else if constexpr(is_any_of<WeiLayout, GKCX, GKCYX, GKCZYX>::value)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.K_),
......@@ -169,9 +147,7 @@ make_weight_host_tensor_descriptor_g_k_c_xs_packed(const ck_tile::conv::ConvPara
param.filter_spatial_lengths_.begin(),
param.filter_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else if constexpr(std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKXC> ||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKYXC> ||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKZYXC>)
else if constexpr(is_any_of<WeiLayout, GKXC, GKYXC, GKZYXC>::value)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.K_),
......@@ -181,9 +157,7 @@ make_weight_host_tensor_descriptor_g_k_c_xs_packed(const ck_tile::conv::ConvPara
param.filter_spatial_lengths_.begin(),
param.filter_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else if constexpr(std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KXGC> ||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KYXGC> ||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KZYXGC>)
else if constexpr(is_any_of<WeiLayout, KXGC, KYXGC, KZYXGC>::value)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.K_),
static_cast<std::size_t>(param.G_),
......@@ -211,11 +185,11 @@ template <typename OutLayout>
CK_TILE_HOST HostTensorDescriptor
make_output_host_tensor_descriptor_g_n_k_wos_packed(const ck_tile::conv::ConvParam& param)
{
using namespace ck_tile::tensor_layout::convolution;
std::vector<std::size_t> physical_lengths;
if constexpr(std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNKW> ||
std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNKHW> ||
std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNKDHW>)
if constexpr(is_any_of<OutLayout, GNKW, GNKHW, GNKDHW>::value)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.N_),
......@@ -226,9 +200,7 @@ make_output_host_tensor_descriptor_g_n_k_wos_packed(const ck_tile::conv::ConvPar
param.output_spatial_lengths_.begin() + param.num_dim_spatial_);
}
// separate from legacy code above
else if constexpr(std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNWK> ||
std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNHWK> ||
std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNDHWK>)
else if constexpr(is_any_of<OutLayout, GNWK, GNHWK, GNDHWK>::value)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.N_),
......@@ -238,9 +210,7 @@ make_output_host_tensor_descriptor_g_n_k_wos_packed(const ck_tile::conv::ConvPar
param.output_spatial_lengths_.begin(),
param.output_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else if constexpr(std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::NWGK> ||
std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::NHWGK> ||
std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::NDHWGK>)
else if constexpr(is_any_of<OutLayout, NWGK, NHWGK, NDHWGK>::value)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.N_),
static_cast<std::size_t>(param.G_),
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -678,4 +678,43 @@ struct HostTensor
Descriptor mDesc;
Data mData;
};
template <bool is_row_major>
auto host_tensor_descriptor(std::size_t row,
std::size_t col,
std::size_t stride,
bool_constant<is_row_major>)
{
using namespace ck_tile::literals;
if constexpr(is_row_major)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
}
template <bool is_row_major>
auto get_default_stride(std::size_t row,
std::size_t col,
std::size_t stride,
bool_constant<is_row_major>)
{
if(stride == 0)
{
if constexpr(is_row_major)
{
return col;
}
else
{
return row;
}
}
else
return stride;
}
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <thread>
namespace ck_tile {
template <typename Type>
CK_TILE_HOST void reference_batched_transpose(const HostTensor<Type>& x,
HostTensor<Type>& y,
std::string layout_in = "NCHW",
std::string layout_out = "NHWC")
{
const int N = x.mDesc.get_lengths()[0];
auto f = [&](auto batch) {
if(layout_in == "NCHW" && layout_out == "NHWC")
{
const int C = x.mDesc.get_lengths()[1];
const int H = x.mDesc.get_lengths()[2];
const int W = x.mDesc.get_lengths()[3];
for(int c = 0; c < C; ++c)
{
for(int h = 0; h < H; ++h)
{
for(int w = 0; w < W; ++w)
{
Type v_x = x(batch, c, h, w);
y(batch, h, w, c) = v_x;
}
}
}
}
else if(layout_in == "NHWC" && layout_out == "NCHW")
{
const int H = x.mDesc.get_lengths()[1];
const int W = x.mDesc.get_lengths()[2];
const int C = x.mDesc.get_lengths()[3];
for(int h = 0; h < H; ++h)
{
for(int w = 0; w < W; ++w)
{
for(int c = 0; c < C; ++c)
{
Type v_x = x(batch, h, w, c);
y(batch, c, h, w) = v_x;
}
}
}
}
};
make_ParallelTensorFunctor(f, N)(std::thread::hardware_concurrency());
}
} // namespace ck_tile
......@@ -73,7 +73,7 @@ void reference_fused_moe(
ck_tile::index_t tokens,
ck_tile::index_t experts,
ck_tile::index_t hidden_size,
ck_tile::index_t intermediate_size, // this size is for gate/up
ck_tile::index_t intermediate_size, // this size is for gate/up/down
ck_tile::index_t topk,
ck_tile::index_t gate_only)
{
......@@ -82,19 +82,8 @@ void reference_fused_moe(
assert(sorted_expert_ids_host.get_num_of_dimension() == 1);
assert(num_sorted_tiles_host.get_element_size() == 1);
ck_tile::index_t num_sorted_tiles = num_sorted_tiles_host.mData[0] / block_m;
ck_tile::index_t intermediate_size_0 = intermediate_size;
ck_tile::index_t intermediate_size_1 = intermediate_size / (gate_only ? 1 : 2);
// TODO: better remove this in the future, or modify the token_id value
auto get_topk_id = [&](ck_tile::index_t token_id_, ck_tile::index_t expert_id_) {
for(ck_tile::index_t i_ = 0; i_ < topk; i_++)
{
if(token_ids_host(token_id_, i_) == expert_id_)
return i_;
}
throw std::runtime_error("not correct token/expert pair\n");
return -1; // TODO: not correct!!
};
ck_tile::index_t intermediate_size_0 = intermediate_size * (gate_only ? 1 : 2);
ck_tile::index_t intermediate_size_1 = intermediate_size;
ck_tile::HostTensor<AccDataType> out_topk_tokens({tokens, topk, hidden_size});
......@@ -105,11 +94,31 @@ void reference_fused_moe(
if(i_tile >= num_sorted_tiles)
return;
ck_tile::index_t i_expert = sorted_expert_ids_host.mData[i_tile];
ck_tile::index_t i_token = sorted_token_ids_host.mData[i_flatten];
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
ck_tile::index_t i_token = sorted_token_ids_host.mData[i_flatten];
ck_tile::index_t i_topk = i_token >> 24;
i_token &= 0xffffff;
if(i_token >= tokens)
return;
(void)token_ids_host;
#else
// TODO: better remove this in the future, or modify the token_id value
auto get_topk_id = [&](ck_tile::index_t token_id_, ck_tile::index_t expert_id_) {
for(ck_tile::index_t i_ = 0; i_ < topk; i_++)
{
if(token_ids_host(token_id_, i_) == expert_id_)
return i_;
}
throw std::runtime_error("not correct token/expert pair\n");
return -1; // TODO: not correct!!
};
ck_tile::index_t i_token = sorted_token_ids_host.mData[i_flatten];
if(i_token >= tokens)
return;
ck_tile::index_t i_topk = get_topk_id(i_token, i_expert); // TODO: ugly
auto weight = sorted_weight_host.mData[i_flatten];
#endif
auto weight = sorted_weight_host.mData[i_flatten];
ck_tile::HostTensor<AccDataType> acc_0({1, intermediate_size_0});
// first gemm
......
......@@ -8,16 +8,40 @@
namespace ck_tile {
// Note: for simplicity, each functor only care about single M
struct reference_rmsnorm2d_default_epilogue
{
template <typename OutDataType, typename AccDataType>
void operator()(int m, HostTensor<OutDataType>& o, const HostTensor<AccDataType>& acc)
{
const int N = acc.mDesc.get_lengths()[1];
for(int n = 0; n < N; ++n)
{
o(m, n) = ck_tile::type_convert<OutDataType>(acc(m, n));
}
}
template <typename OutDataType, typename AccDataType>
auto operator()(int m, const HostTensor<AccDataType>& acc)
{
HostTensor<OutDataType> o(acc.get_lengths(), acc.get_strides());
operator()(m, o, acc);
return o;
}
};
template <typename XDataType,
typename GammaDataType,
typename ComputeDataType,
typename YDataType,
typename InvRmsDataType>
typename InvRmsDataType,
typename Epilogue = reference_rmsnorm2d_default_epilogue>
void reference_rmsnorm2d_fwd(const HostTensor<XDataType>& x_m_n,
const HostTensor<GammaDataType>& gamma_n,
HostTensor<YDataType>& y_m_n,
HostTensor<InvRmsDataType>& invRms_m,
ComputeDataType epsilon)
ComputeDataType epsilon,
Epilogue epilogue_functor = {})
{
auto rmsnorm2d_fwd_func = [&](auto m) {
const int N = x_m_n.mDesc.get_lengths()[1];
......@@ -37,13 +61,15 @@ void reference_rmsnorm2d_fwd(const HostTensor<XDataType>& x_m_n,
if constexpr(!std::is_same_v<InvRmsDataType, ck_tile::null_type>)
invRms_m(m) = ck_tile::type_convert<InvRmsDataType>(divisor);
HostTensor<ComputeDataType> acc(x_m_n.get_lengths(), x_m_n.get_strides());
for(int n = 0; n < N; ++n)
{
ComputeDataType x = ck_tile::type_convert<ComputeDataType>(x_m_n(m, n));
ComputeDataType gamma = ck_tile::type_convert<ComputeDataType>(gamma_n(n));
auto y = x * divisor * gamma;
y_m_n(m, n) = ck_tile::type_convert<YDataType>(y);
acc(m, n) = x * divisor * gamma;
}
epilogue_functor(m, y_m_n, acc);
};
make_ParallelTensorFunctor(rmsnorm2d_fwd_func, invRms_m.mDesc.get_lengths()[0])(
......
......@@ -22,7 +22,7 @@ CK_TILE_HOST void reference_rowwise_quantization2d(const HostTensor<XDataType>&
// scale = amax / 127 for int8
auto v_scale = type_convert<XDataType>(scale_m(m));
auto v_qx = v_x / v_scale;
qx_m_n(m, n) = saturates<QXDataType>{}(v_qx);
qx_m_n(m, n) = type_convert<QXDataType>(saturates<QXDataType>{}(v_qx));
}
};
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp"
#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_pipeline.hpp"
#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp"
#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_problem.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/elementwise.hpp"
#include "ck_tile/host/hip_check_error.hpp"
#include <string>
#include <type_traits>
namespace ck_tile {
struct BatchedTransposeHostArgs
{
const void* p_input;
void* p_output;
index_t batch;
index_t height;
index_t width;
// index_t dim_blocks;
index_t dim_stride;
index_t dim_block_h;
index_t dim_block_w;
};
template <typename Pipeline_>
struct BatchedTransposeKernel
{
using Pipeline = remove_cvref_t<Pipeline_>;
using Problem = remove_cvref_t<typename Pipeline::Problem>;
using Type = typename Problem::InputType;
struct BatchedTransposeKargs
{
const void* p_input;
void* p_output;
index_t batch;
index_t height;
index_t width;
index_t dim_stride;
};
using Kargs = BatchedTransposeKargs;
using Hargs = BatchedTransposeHostArgs;
CK_TILE_HOST static constexpr auto GridSize(const Hargs& h)
{
size_t grid_size_x = (h.width + h.dim_block_w - 1) / h.dim_block_w;
size_t grid_size_y = (h.height + h.dim_block_h - 1) / h.dim_block_h;
size_t grid_size_z = h.batch;
return dim3(grid_size_x, grid_size_y, grid_size_z);
}
CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
{
Kargs k;
k.p_input = h.p_input;
k.p_output = h.p_output;
k.batch = h.batch;
k.height = h.height;
k.width = h.width;
k.dim_stride = h.dim_stride;
return k;
}
CK_TILE_HOST_DEVICE static constexpr auto BlockSize() { return Problem::kBlockSize; }
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
static constexpr ck_tile::index_t kMPerBlock = Problem::kMPerBlock;
static constexpr ck_tile::index_t kNPerBlock = Problem::kNPerBlock;
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
static constexpr ck_tile::index_t kMPerThread = Problem::kMPerThread;
static constexpr ck_tile::index_t kNPerThread = Problem::kNPerThread;
static_assert(kMPerThread == 1 && kNPerThread == 1);
const auto iDim = blockIdx.z;
const auto x_m_n = [&]() {
const auto x_dram_naive = make_naive_tensor_view<address_space_enum::global>(
static_cast<const Type*>(kargs.p_input) + iDim * kargs.dim_stride,
make_tuple(kargs.height, kargs.width),
make_tuple(kargs.width, 1),
number<kNPerThread>{}, // TODO thread load value
number<1>{});
return pad_tensor_view(x_dram_naive,
make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
sequence<kPadM, kPadN>{});
}();
const auto iM = __builtin_amdgcn_readfirstlane(blockIdx.x * kMPerBlock);
const auto iN = __builtin_amdgcn_readfirstlane(blockIdx.y * kNPerBlock);
const auto y_n_m = [&]() {
const auto y_dram_naive = make_naive_tensor_view<address_space_enum::global>(
static_cast<Type*>(kargs.p_output) + iDim * kargs.dim_stride,
make_tuple(kargs.width, kargs.height),
make_tuple(kargs.height, 1),
number<kMPerThread>{},
number<1>{});
return pad_tensor_view(y_dram_naive,
make_tuple(number<kNPerBlock>{}, number<kMPerBlock>{}),
sequence<kPadN, kPadM>{});
}();
auto x_block_window =
make_tile_window(x_m_n,
make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
{static_cast<ck_tile::index_t>(iM * kMPerBlock),
static_cast<ck_tile::index_t>(iN * kNPerBlock)});
auto y_block_window =
make_tile_window(y_n_m,
make_tuple(number<kNPerBlock>{}, number<kMPerBlock>{}),
{static_cast<ck_tile::index_t>(iN * kNPerBlock),
static_cast<ck_tile::index_t>(iM * kMPerBlock)});
Pipeline{}(x_block_window, y_block_window);
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp"
#include <string>
#include <type_traits>
namespace ck_tile {
template <typename Problem_, typename Policy_ = BatchedTransposePolicy>
struct BatchedTransposePipeline
{
// TODO: this kernel only support warp per row
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using InputType = ck_tile::remove_cvref_t<typename Problem::InputType>;
static constexpr ck_tile::index_t kMPerBlock = Problem::kMPerBlock;
static constexpr ck_tile::index_t kNPerBlock = Problem::kNPerBlock;
static constexpr index_t AlignmentM = Problem::AlignmentM;
static constexpr index_t AlignmentN = Problem::AlignmentN;
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
template <typename InputWindow, typename OutputWindow>
CK_TILE_DEVICE auto operator()(const InputWindow& input_window, OutputWindow& out_window)
{
auto inp_win =
make_tile_window(input_window, Policy::template MakeInputDistribution<Problem>());
auto out_win =
make_tile_window(out_window, Policy::template MakeOutputDistribution<Problem>());
auto x = load_tile(inp_win); // x->thread input_win->block
auto y = make_static_distributed_tensor<InputType>(
Policy::template MakeOutputDistribution<Problem>());
constexpr auto span_2d_x = decltype(x)::get_distributed_spans();
sweep_tile_span(span_2d_x[number<0>{}], [&](auto idx0) {
sweep_tile_span(span_2d_x[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx1, idx0);
y(i_j_idx) = x(i_j_idx);
});
});
store_tile(out_win, y);
}
};
} // namespace ck_tile
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment