"vscode:/vscode.git/clone" did not exist on "cb68807ffcaa157b9ee0826fecd61fb439efe9b3"
Unverified Commit 3c5717df authored by Illia Silin's avatar Illia Silin Committed by GitHub
Browse files

Merge branch 'develop' into gemm_elementwise_gemm

parents 171b9030 d9f1ead3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "data_type.hpp"
......@@ -429,7 +429,8 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
(is_same<T, f8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, uint8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
(is_same<T, uint8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, pk_i4_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
"wrong! not implemented");
using r_t = typename vector_type<T, N>::type;
......@@ -549,8 +550,10 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
(is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, bhalf_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, f8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, f8_fnuz_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, bf8_fnuz_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, fp8_storage_t>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
"wrong! not implemented");
......@@ -578,7 +581,7 @@ __device__ void amd_global_atomic_add_impl(const typename vector_type<T, N>::typ
tmp.template AsType<half2_t>()[i]);
});
}
#if defined(__gfx942__)
#if defined(__gfx942__) || defined(__gfx950__)
else if constexpr(is_same<T, bhalf_t>::value)
{
vector_type<bhalf_t, N> tmp{src_thread_data};
......@@ -843,8 +846,8 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
#else
vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size, coherence>(
src_wave_buffer_resource, src_thread_addr_offset, 0);
vector_t tmp{amd_buffer_load_impl<scalar_t, vector_size, coherence>(
src_wave_buffer_resource, src_thread_addr_offset, 0)};
return src_thread_element_valid ? tmp : vector_t(0);
#endif
}
......@@ -873,8 +876,8 @@ amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave,
constexpr index_t vector_size = scalar_type<vector_t>::vector_size;
vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size, coherence>(
src_wave_buffer_resource, src_thread_addr_offset, 0);
vector_t tmp{amd_buffer_load_impl<scalar_t, vector_size, coherence>(
src_wave_buffer_resource, src_thread_addr_offset, 0)};
return src_thread_element_valid ? tmp : vector_t(customized_value);
}
......@@ -1018,15 +1021,24 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr,
constexpr auto bytes_per_thread = sizeof(T) * NumElemsPerThread;
static_assert(bytes_per_thread == dword_bytes);
#ifndef CK_CODE_GEN_RTC
const uint32_t* global_ptr =
reinterpret_cast<uint32_t*>(reinterpret_cast<uintptr_t>(global_base_ptr));
#else
const uint32_t* global_ptr =
reinterpret_cast<uint32_t*>(reinterpret_cast<size_t>(global_base_ptr));
#endif
const int32x4_t src_resource = make_wave_buffer_resource(global_ptr, src_element_space_size);
const index_t global_offset_bytes = is_valid ? global_offset * sizeof(T) : 0x80000000;
#if CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM
T* lds_ptr = lds_base_ptr + lds_offset;
#ifndef CK_CODE_GEN_RTC
auto const lds_ptr_sgpr =
__builtin_amdgcn_readfirstlane((reinterpret_cast<uintptr_t>(lds_ptr)));
#else
auto const lds_ptr_sgpr = __builtin_amdgcn_readfirstlane((reinterpret_cast<size_t>(lds_ptr)));
#endif
asm volatile("s_mov_b32 m0, %0; \n\t"
"buffer_load_dword %1, %2, 0 offen lds;\n\t" ::"s"(lds_ptr_sgpr),
"v"(global_offset_bytes),
......@@ -1035,8 +1047,13 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr,
#else
// LDS pointer must be attributed with the LDS address space.
__attribute__((address_space(3))) uint32_t* lds_ptr =
#ifndef CK_CODE_GEN_RTC
reinterpret_cast<__attribute__((address_space(3))) uint32_t*>(
reinterpret_cast<uintptr_t>(lds_base_ptr + lds_offset));
#else
reinterpret_cast<__attribute__((address_space(3))) uint32_t*>(
reinterpret_cast<size_t>(lds_base_ptr + lds_offset));
#endif
llvm_amdgcn_raw_buffer_load_lds(
src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0);
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "ck/utility/enable_if.hpp"
#include "ck/utility/random_gen.hpp"
#include "ck/utility/type.hpp"
#ifdef CK_USE_FNUZ_FP8
#define CK_USE_FNUZ_FP8 1
#else
#define CK_USE_FNUZ_FP8 0
#endif
#ifdef CK_USE_OCP_FP8
#define CK_USE_OCP_FP8 1
#else
#define CK_USE_OCP_FP8 0
#endif
#if(defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx1200__) || \
defined(__gfx1201__) || defined(__gfx950__)) && \
__HIP_DEVICE_COMPILE__
#define CK_FP8_CVT_FAST_PATH 1
#else
#define CK_FP8_CVT_FAST_PATH 0
#endif
#if(defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx950__)) && __HIP_DEVICE_COMPILE__
#define CK_OCP_FP8_CVT_FAST_PATH 1
#else
#define CK_OCP_FP8_CVT_FAST_PATH 0
#endif
namespace ck {
using f8_fnuz_t = _BitInt(8);
using bf8_fnuz_t = unsigned _BitInt(8);
typedef unsigned char fp8_storage_t;
/**
* \brief Describes FP8 interpretation
*/
enum class ck_fp8_interpretation_t
{
CK_E4M3_OCP = 0, // OCP E4M3
CK_E5M2_OCP = 1, // OCP E5M2
CK_E4M3_FNUZ = 2, // FP8
CK_E5M2_FNUZ = 3, // BF8
};
/**
* \brief Describes saturation behavior
*/
enum class ck_saturation_t
{
CK_NOSAT = 0, // No saturation - replace with NaN or Inf
CK_SATFINITE = 1, // Saturate to finite
};
namespace fp8_impl {
typedef fp8_storage_t fp8x2_storage_t __attribute__((ext_vector_type(2)));
typedef float float2_t __attribute__((ext_vector_type(2)));
__host__ __device__ static inline constexpr bool fnuz_f8_is_nan(f8_fnuz_t a)
{
return static_cast<unsigned char>(a) == 0x80;
}
__host__ __device__ static inline constexpr bool fnuz_bf8_is_nan(bf8_fnuz_t a)
{
return static_cast<unsigned char>(a) == 0x80;
}
__host__ __device__ static inline constexpr bool ocp_f8_is_nan(fp8_storage_t a)
{
return (a & 0x7f) == 0x7f;
}
__host__ __device__ static inline constexpr bool ocp_bf8_is_nan(fp8_storage_t a)
{
return (a & 0x7f) > 0x7c;
}
// The conversion function is from rocblas
// https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_hip_f8_impl.h#L220
// This has been modified to handle double types as well
template <typename T, int wm, int we, bool is_fnuz, bool clip = false>
__host__ __device__ static inline T cast_from_f8(fp8_storage_t x)
{
constexpr bool is_half = __hip_internal::is_same<T, _Float16>::value;
constexpr bool is_float = __hip_internal::is_same<T, float>::value;
constexpr bool is_double = __hip_internal::is_same<T, double>::value;
static_assert(is_half || is_float || is_double, "only half, float and double are supported");
constexpr int weo = is_half ? 5 : (is_float ? 8 : 11);
constexpr int wmo = is_half ? 10 : (is_float ? 23 : 52);
T fInf, fNegInf, fNaN, fNeg0, fmax, fmin;
if constexpr(is_half)
{
const unsigned short int ihInf = 0x7C00;
const unsigned short int ihNegInf = 0xFC00;
const unsigned short int ihNaN = 0x7C01;
const unsigned short int ihNeg0 = 0x8000;
/* Max number in e5m2 57344*/
const unsigned short int ifmax = 0x7B00;
const unsigned short int ifmin = 0xFB00;
fInf = bit_cast<_Float16>(ihInf);
fNegInf = bit_cast<_Float16>(ihNegInf);
fNaN = bit_cast<_Float16>(ihNaN);
fNeg0 = bit_cast<_Float16>(ihNeg0);
fmax = bit_cast<_Float16>(ifmax);
fmin = bit_cast<_Float16>(ifmin);
}
else if constexpr(is_float)
{
const unsigned int ifInf = 0x7F800000;
const unsigned int ifNegInf = 0xFF800000;
const unsigned int ifNaN = 0x7F800001;
const unsigned int ifNeg0 = 0x80000000;
/* Max number in e5m2 57344*/
const unsigned int ifmax = 0x47600000;
const unsigned int ifmin = 0xC7600000;
fInf = bit_cast<float>(ifInf);
fNegInf = bit_cast<float>(ifNegInf);
fNaN = bit_cast<float>(ifNaN);
fNeg0 = bit_cast<float>(ifNeg0);
fmax = bit_cast<float>(ifmax);
fmin = bit_cast<float>(ifmin);
}
else if constexpr(is_double)
{
const unsigned long long ifInf = 0x7FF0000000000000ull;
const unsigned long long ifNegInf = 0xFFF0000000000000ull;
const unsigned long long ifNaN = 0x7FF0000000000001ull;
const unsigned long long ifNeg0 = 0x8000000000000000ull;
/* Max number in e5m2 57344*/
const unsigned long long ifmax = 0x40EC000000000000ull;
const unsigned long long ifmin = 0xC0EC000000000000ull;
fInf = bit_cast<double>(ifInf);
fNegInf = bit_cast<double>(ifNegInf);
fNaN = bit_cast<double>(ifNaN);
fNeg0 = bit_cast<double>(ifNeg0);
fmax = bit_cast<double>(ifmax);
fmin = bit_cast<double>(ifmin);
}
if(x == 0)
{
return 0;
}
unsigned long long sign = x >> 7;
unsigned long long mantissa = x & ((1 << wm) - 1);
int exponent = (x & 0x7F) >> wm;
if constexpr(is_fnuz)
{
if(x == 0x80)
{
return fNaN;
}
}
else
{
if(x == 0x80)
{
return fNeg0;
}
if constexpr(we == 4)
{ // e4m3
if((x & 0x7F) == 0x7F)
{
return fNaN;
}
}
else if((x & 0x7C) == 0x7C)
{ // e5m2
if((x & 0x3) == 0)
{
if constexpr(clip)
{
return sign ? fmin : fmax;
}
return sign ? fNegInf : fInf;
}
return fNaN;
}
}
typename std::conditional<
sizeof(T) == 2,
unsigned short int,
typename std::conditional<sizeof(T) == 4, unsigned int, unsigned long long>::type>::type
retval;
if constexpr(we == 5 && is_half && !is_fnuz)
{
retval = x << 8;
return bit_cast<T>(retval);
}
const int exp_low_cutoff = (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (is_fnuz ? 1 : 0);
// subnormal input
if(exponent == 0)
{
#if defined(__HIP_DEVICE_COMPILE__) && __HIP_DEVICE_COMPILE__
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
int sh = 1 + __clz(mantissa) - (32 - wm);
#else
int sh = 1 + __builtin_clz(mantissa) - (32 - wm);
#endif
mantissa <<= sh;
exponent += 1 - sh;
mantissa &= ((1ull << wm) - 1);
}
exponent += exp_low_cutoff - 1;
mantissa <<= wmo - wm;
// subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
if(exponent <= 0)
{
mantissa |= 1 << wmo;
mantissa >>= 1 - exponent;
exponent = 0;
}
if constexpr(sizeof(T) == 2)
retval = (sign << 15) | (exponent << 10) | mantissa;
else if constexpr(sizeof(T) == 4)
retval = (sign << 31) | (exponent << 23) | mantissa;
else
retval = (sign << 63) | (static_cast<unsigned long long>(exponent) << 52) | mantissa;
return bit_cast<T>(retval);
}
#if CK_FP8_CVT_FAST_PATH
template <ck_fp8_interpretation_t interpret>
static __device__ float cast_to_f32_from_f8(fp8_storage_t v)
{
union
{
unsigned int i32val;
unsigned char i8val[4];
} val;
val.i8val[0] = v;
static_assert(interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ ||
interpret == ck_fp8_interpretation_t::CK_E4M3_OCP ||
interpret == ck_fp8_interpretation_t::CK_E5M2_FNUZ ||
interpret == ck_fp8_interpretation_t::CK_E5M2_OCP,
"Only FNUZ and OCP interpretations are supported");
if constexpr((interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ) ||
(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP))
{
return __builtin_amdgcn_cvt_f32_fp8(val.i32val, 0);
}
else
{
return __builtin_amdgcn_cvt_f32_bf8(val.i32val, 0);
}
}
template <ck_fp8_interpretation_t interpret>
static __device__ float2_t cast_to_f32x2_from_f8x2(fp8x2_storage_t v)
{
const auto i16val = bit_cast<uint16_t>(v);
static_assert(interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ ||
interpret == ck_fp8_interpretation_t::CK_E4M3_OCP ||
interpret == ck_fp8_interpretation_t::CK_E5M2_FNUZ ||
interpret == ck_fp8_interpretation_t::CK_E5M2_OCP,
"Only FNUZ and OCP interpretations are supported");
if constexpr((interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ) ||
(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP))
{
return __builtin_amdgcn_cvt_pk_f32_fp8(i16val, false);
}
else
{
return __builtin_amdgcn_cvt_pk_f32_bf8(i16val, false);
}
}
#endif
} // namespace fp8_impl
struct f8_ocp_t
{
using data_type = fp8_storage_t;
data_type data;
static constexpr ck_saturation_t default_saturation = ck_saturation_t::CK_SATFINITE;
static constexpr ck_fp8_interpretation_t default_interpret =
ck_fp8_interpretation_t::CK_E4M3_OCP;
static constexpr unsigned int we = 4; // exponent width
static constexpr unsigned int wm = 3; // mantissa width
__host__ __device__ constexpr bool operator==(const f8_ocp_t& other) const
{
return (data == other.data) && (fp8_impl::ocp_f8_is_nan(data) == false); // NaN != NaN
}
#if CK_USE_OCP_FP8
__host__ __device__ explicit operator float() const
#else
__host__ explicit operator float() const
#endif
{
#if CK_OCP_FP8_CVT_FAST_PATH
return fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data);
#else
return fp8_impl::cast_from_f8<float, wm, we, false>(
this->data); // XXX: clip==false must be consistent with operator _Float16
#endif
}
#if CK_USE_OCP_FP8
__host__ __device__ explicit operator _Float16() const
#else
__host__ explicit operator _Float16() const
#endif
{
#if CK_OCP_FP8_CVT_FAST_PATH
return static_cast<_Float16>(fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data));
#else
return fp8_impl::cast_from_f8<_Float16, wm, we, false>(
this->data); // XXX: clip==false must be consistent with operator float
#endif
}
};
struct bf8_ocp_t
{
using data_type = fp8_storage_t;
data_type data;
static constexpr ck_saturation_t default_saturation = ck_saturation_t::CK_SATFINITE;
static constexpr ck_fp8_interpretation_t default_interpret =
ck_fp8_interpretation_t::CK_E5M2_OCP;
static constexpr unsigned int we = 5; // exponent width
static constexpr unsigned int wm = 2; // mantissa width
__host__ __device__ constexpr bool operator==(const bf8_ocp_t& other) const
{
return (data == other.data) && (fp8_impl::ocp_bf8_is_nan(data) == false); // NaN != NaN
}
#if CK_USE_OCP_FP8
__host__ __device__ explicit operator float() const
#else
__host__ explicit operator float() const
#endif
{
#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__)
return fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data);
#else
return fp8_impl::cast_from_f8<float, wm, we, false>(
this->data); // XXX: clip==false must be consistent with operator _Float16
#endif
}
#if CK_USE_OCP_FP8
__host__ __device__ explicit operator _Float16() const
#else
__host__ explicit operator _Float16() const
#endif
{
#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__)
return static_cast<_Float16>(fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data));
#else
return fp8_impl::cast_from_f8<_Float16, wm, we, false>(
this->data); // XXX: clip==false must be consistent with operator float
#endif
}
};
template <typename T>
__host__ __device__ static inline constexpr bool fp8_is_nan(T);
template <>
__host__ __device__ inline constexpr bool fp8_is_nan(f8_ocp_t a)
{
return fp8_impl::ocp_f8_is_nan(a.data);
}
template <>
__host__ __device__ inline constexpr bool fp8_is_nan(bf8_ocp_t a)
{
return fp8_impl::ocp_bf8_is_nan(a.data);
}
template <>
__host__ __device__ inline constexpr bool fp8_is_nan(f8_fnuz_t a)
{
return fp8_impl::fnuz_f8_is_nan(a);
}
template <>
__host__ __device__ inline constexpr bool fp8_is_nan(bf8_fnuz_t a)
{
return fp8_impl::fnuz_bf8_is_nan(a);
}
template <typename T,
ck::enable_if_t<is_same_v<T, bf8_ocp_t> || is_same_v<T, f8_ocp_t> ||
is_same_v<T, bf8_fnuz_t> || is_same_v<T, f8_fnuz_t>,
bool> = true>
__host__ __device__ static inline constexpr bool fp8_is_inf(T)
{
return false;
}
template <>
__host__ __device__ inline constexpr bool fp8_is_inf(bf8_ocp_t a)
{
return (a.data & 0x7f) == 0x7c;
}
namespace fp8_impl {
// Assertions to check for supported conversion types
#define __assert_ocp_support(interp) \
{ \
if(interp != ck_fp8_interpretation_t::CK_E4M3_OCP && \
interp != ck_fp8_interpretation_t::CK_E5M2_OCP) \
{ \
__hip_assert(false && "type is unsupported by current target device"); \
} \
}
#define __assert_fnuz_support(interp) \
{ \
if(interp != ck_fp8_interpretation_t::CK_E4M3_FNUZ && \
interp != ck_fp8_interpretation_t::CK_E5M2_FNUZ) \
{ \
__hip_assert(false && "type is unsupported by current target device"); \
} \
}
__host__ __device__ static inline void
__is_interpret_supported([[maybe_unused]] ck_fp8_interpretation_t interp)
{
#if defined(__HIP_DEVICE_COMPILE__) && __HIP_DEVICE_COMPILE__
#if CK_USE_OCP_FP8
__assert_ocp_support(interp);
#endif
#if CK_USE_FNUZ_FP8
__assert_fnuz_support(interp);
#endif
#endif
}
#if CK_FP8_CVT_FAST_PATH
// The conversion function is from rocblas
// https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_float8.h#L79
template <ck_fp8_interpretation_t interpret, bool saturate, bool stochastic_rounding = false>
static __device__ fp8_storage_t cast_to_f8_from_f32(float v, unsigned int rng = 0)
{
fp8_storage_t i8data;
union
{
float fval;
unsigned int i32val;
unsigned char i8val[4]; // NOTE: not endian independent
} val;
unsigned int ival = 0;
val.fval = v;
if constexpr(saturate)
{
if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ)
{
if((val.i32val & 0x7F800000) != 0x7F800000)
{ /// propagate NAN/INF, no clipping
val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0);
}
}
else if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
{ // OCP type
if((val.i32val & 0x7F800000) != 0x7F800000)
{ /// propagate NAN/INF, no clipping
val.fval = __builtin_amdgcn_fmed3f(val.fval, 448.0, -448.0);
}
}
else
{
if((val.i32val & 0x7F800000) != 0x7F800000)
{ /// propagate NAN/INF, no clipping
val.fval = __builtin_amdgcn_fmed3f(val.fval, 57344.0, -57344.0);
}
}
}
if constexpr(stochastic_rounding)
{
ival = (interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ) ||
(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
? __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0)
: __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos
val.i32val = ival;
i8data = val.i8val[0]; // little endian
}
else
{ // RNE CVT
ival = (interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ) ||
(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)
? __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false)
: __builtin_amdgcn_cvt_pk_bf8_f32(val.fval,
val.fval,
ival,
false); // false -> WORD0
val.i32val = ival;
i8data = val.i8val[0];
}
return i8data;
}
#endif // CK_FP8_CVT_FAST_PATH
// The conversion function is from rocblas
// https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_hip_f8_impl.h#L39
// This has been modified to add double types conversion as well
template <typename T, int wm, int we, bool is_fnuz, bool clip = false, bool stoch = false>
__host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rng = 0)
{
constexpr bool is_half = __hip_internal::is_same<T, _Float16>::value;
constexpr bool is_float = __hip_internal::is_same<T, float>::value;
constexpr bool is_double = __hip_internal::is_same<T, double>::value;
static_assert(is_half || is_float || is_double,
"Only half, float and double can be cast to f8");
constexpr int mfmt = (sizeof(T) == 8) ? 52 : ((sizeof(T) == 4) ? 23 : 10);
using T_bitwise = typename std::conditional<
sizeof(T) == 2,
unsigned short int,
typename std::conditional<sizeof(T) == 4, unsigned int, unsigned long long>::type>::type;
T_bitwise x_bitwise = bit_cast<T_bitwise>(_x);
unsigned long long x{x_bitwise};
unsigned long long head, mantissa;
int exponent, bias;
unsigned int sign;
unsigned long long fInf, mask;
if constexpr(sizeof(T) == 8)
{
head = x & 0xFFF0000000000000ull;
mantissa = x & 0xFFFFFFFFFFFFFull;
exponent = (head >> 52) & 0x7FF;
sign = head >> 63;
bias = 1023;
fInf = 0x7FF0000000000000ull;
mask = 0x7FFFFFFFFFFFFFFFull;
}
else if constexpr(sizeof(T) == 4)
{
head = x & 0xFF800000;
mantissa = x & 0x7FFFFF;
exponent = (head >> 23) & 0xFF;
sign = head >> 31;
bias = 127;
fInf = 0x7F800000;
mask = 0x7FFFFFFF;
}
else
{
head = x & 0xFC00;
mantissa = x & 0x3FF;
exponent = (head >> 10) & 0x1F;
sign = head >> 15;
bias = 15;
fInf = 0x7C00;
mask = 0x7FFF;
}
unsigned int signed_inf = 0;
unsigned int nan = 0;
if constexpr(is_fnuz)
{
signed_inf = clip ? ((sign << 7) + 0x7f) : 0x80;
nan = 0x80;
}
else
{
if constexpr(we == 4)
{ // e4m3
signed_inf = (sign << 7) + (clip ? 0x7e : 0x7f);
}
else
{ // e5m2
signed_inf = (sign << 7) + (clip ? 0x7b : 0x7c);
}
nan = (sign << 7) + 0x7f;
}
// Max values
unsigned long long ifmax = 0;
if constexpr(sizeof(T) == 8)
{
if constexpr(we == 5)
{ // 57344
ifmax = 0x40EC000000000000ull;
}
else
{
if constexpr(is_fnuz)
{ // 240
ifmax = 0x406E000000000000ull;
}
else
{ // 448
ifmax = 0x407C000000000000ull;
}
}
}
else if(sizeof(T) == 4)
{
if constexpr(we == 5)
{
ifmax = 0x47600000;
}
else
{
if constexpr(is_fnuz)
{
ifmax = 0x43700000;
}
else
{
ifmax = 0x43E00000;
}
}
}
else
{
if constexpr(we == 5)
{
ifmax = 0x7B00;
}
else
{
if constexpr(is_fnuz)
{
ifmax = 0x5B80;
}
else
{
ifmax = 0x5F00;
}
}
}
// Deal with inf and NaNs
if((x & fInf) == fInf)
{
if constexpr(is_fnuz)
return signed_inf;
return mantissa != 0 ? nan : signed_inf;
}
if((x & mask) > ifmax)
{
return signed_inf;
}
if(x == 0)
{
return 0;
}
// First need to check if it is normal or denorm as there is a difference of
// implicit 1 Then need to adjust the exponent to align with the F8 exponent,
// in the meanwhile, shift The mantissa. Then for stochastic rounding, add rng
// to mantissa and truncate. And for RNE, no need to add rng. Then probably
// need to check whether there is carry and adjust exponent and mantissa again
// For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent
// bits
const int f8_bias = (1 << (we - 1)) - 1 + (is_fnuz ? 1 : 0);
const int f8_denormal_act_exponent = 1 - f8_bias; // actual exponent of f8 denormal
// act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
// f8_exponent is the converted f8 exponent with bias encoding
// exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
// the difference needs to be adjusted and mantissa shifted
int act_exponent, f8_exponent, exponent_diff;
if(exponent == 0)
{ // fp32/fp16 is in denormal.
/* fp32 denormal is below 2^-127 so it is usually not a concern here, we
mostly concern fp16 here. In this case, f8 is usually in denormal. But there
could be exceptions. fp16 denormal has exponent bias 15 while bf8 with NANOO has
exponent bias 16. It means that there are some numbers in fp16 denormal but they
are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers
where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8
(NANOO) normal. In this case, the fp16 mantissa should be shift left by 1 */
act_exponent = exponent - bias + 1;
exponent_diff = f8_denormal_act_exponent -
act_exponent; // actual exponent is exponent-bias+1 as it is denormal
}
else
{ // fp32/fp16 is normal with implicit 1
act_exponent = exponent - bias;
if(act_exponent <= f8_denormal_act_exponent)
{
/* This is the case where fp32/fp16 is normal but it is in f8 denormal
range. For example fp8 nanoo mode, denormal exponent is -7, but if the fp32/fp16
actual exponent is -7, it is actually larger due to the implicit 1,
Therefore it needs to be adjust to -6 and mantissa shift right by 1.
So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
exponent_diff = f8_denormal_act_exponent - act_exponent;
}
else
{ // both fp32/fp16 and f8 are in normal range
exponent_diff = 0; // exponent_diff=0 does not mean there is no difference
// for this case, act_exponent could be larger. Just
// that it does not need shift mantissa
}
mantissa += (1ull << mfmt); // Add the implicit 1 into mantissa
}
bool midpoint = (mantissa & ((1ull << (mfmt - wm + exponent_diff)) - 1)) ==
(1ull << (mfmt - wm + exponent_diff - 1));
/* This part is a bit tricky. The judgment of whether it is a tie needs to be
done before we shift right as shift right could rip off some residual part and
make something not midpoint look like midpoint. For example, the fp16 number
0x1002 (0 00100 0000000010), it is larger than midpoint, but after shift right
by 4 bits, it would look like midpoint.
*/
if(exponent_diff > 0)
mantissa >>= exponent_diff;
else if(exponent_diff == -1)
mantissa <<= -exponent_diff;
bool implicit_one = mantissa & (1ull << mfmt);
// if there is no implicit 1, it means the f8 is denormal and need to adjust
// to denorm exponent
f8_exponent =
(act_exponent + exponent_diff) /*actual f8 exponent*/ + f8_bias - (implicit_one ? 0 : 1);
// Now we have the exponent and mantissa adjusted
unsigned long long drop_mask = (1ull << (mfmt - wm)) - 1;
bool odd =
mantissa & (1ull << (mfmt - wm)); // if the least significant bit that is not truncated is 1
mantissa +=
(stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1ull) : mantissa)) & drop_mask;
// Now we deal with overflow
if(f8_exponent == 0)
{
if((1ull << mfmt) & mantissa)
{
f8_exponent = 1; // denormal overflow to become normal, promote exponent
}
}
else
{
if((1ull << (mfmt + 1)) & mantissa)
{
mantissa >>= 1;
f8_exponent++;
}
}
mantissa >>= (mfmt - wm);
// above range: quantize to maximum possible float of the same sign
const int max_exp = (1 << we) - 1;
if(f8_exponent > max_exp)
{
if constexpr(clip)
{
mantissa = (1 << wm) - 1;
f8_exponent = max_exp;
}
else
{
return signed_inf;
}
}
if(f8_exponent == 0 && mantissa == 0)
return is_fnuz ? 0 : (sign << 7);
mantissa &= (1 << wm) - 1;
return (sign << 7) | (f8_exponent << wm) | mantissa;
}
/**
* \brief convert float to @p fp8_storage_t
*
* \tparam interp interpretation of fp8
* \tparam sat saturation of fp8
* \param f float number
* \return fp8_storage_t
*/
template <ck_fp8_interpretation_t interp,
ck_saturation_t sat = ck_saturation_t::CK_SATFINITE,
bool stochastic_rounding = false>
#if CK_FP8_CVT_FAST_PATH
__host__ __device__ static inline fp8_storage_t cvt_float_to_fp8(const float f)
{
__is_interpret_supported(interp);
uint32_t rng = 0;
if constexpr(stochastic_rounding)
{
constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
#else
rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&f), f);
#endif
}
return cast_to_f8_from_f32<interp, sat == ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
f, rng);
#else
#if CK_USE_OCP_FP8
__host__ __device__ static inline fp8_storage_t cvt_float_to_fp8(const float f)
{
#else
__host__ static inline fp8_storage_t cvt_float_to_fp8(const float f)
{
#endif
uint32_t rng = 0;
if constexpr(stochastic_rounding)
{
constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
#else
rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&f), f);
#endif
}
if constexpr(interp == ck_fp8_interpretation_t::CK_E4M3_FNUZ)
{
return cast_to_f8<float,
3,
4,
true,
sat == ck_saturation_t::CK_SATFINITE,
stochastic_rounding>(f, rng);
}
else if constexpr(interp == ck_fp8_interpretation_t::CK_E5M2_FNUZ)
{
return cast_to_f8<float,
2,
5,
true,
sat == ck_saturation_t::CK_SATFINITE,
stochastic_rounding>(f, rng);
}
else if constexpr(interp == ck_fp8_interpretation_t::CK_E4M3_OCP)
{
return cast_to_f8<float,
3,
4,
false,
sat == ck_saturation_t::CK_SATFINITE,
stochastic_rounding>(f, rng);
}
else if constexpr(interp == ck_fp8_interpretation_t::CK_E5M2_OCP)
{
return cast_to_f8<float,
2,
5,
false,
sat == ck_saturation_t::CK_SATFINITE,
stochastic_rounding>(f, rng);
}
else
{
__hip_assert(false && "FP8 type is not supported by current target device");
return 0;
}
#endif // CK_FP8_CVT_FAST_PATH
}
/**
* \brief convert _Float16 to @p fp8_storage_t
*
* \tparam sat saturation of fp8
* \tparam interp interpretation of fp8
* \tparam stochastic_rounding switch between RNE and SR
* \param x _Float16 value
* \return fp8_storage_t
*/
template <ck_fp8_interpretation_t interp,
ck_saturation_t sat = ck_saturation_t::CK_SATFINITE,
bool stochastic_rounding = false>
#if CK_FP8_CVT_FAST_PATH || CK_USE_OCP_FP8
__host__ __device__ static inline fp8_storage_t cvt_half_t_to_fp8(const _Float16 x)
#else
__host__ static inline fp8_storage_t cvt_half_t_to_fp8(const _Float16 x)
#endif
{
return cvt_float_to_fp8<interp, sat, stochastic_rounding>(static_cast<float>(x));
}
} // namespace fp8_impl
// Declare a template function for fp8 conversion using RNE
template <typename Y, typename X>
__host__ __device__ constexpr Y f8_convert_rne(X x);
// convert fp32 to fp8 with rounding to nearest even
template <>
inline __host__ __device__ f8_ocp_t f8_convert_rne<f8_ocp_t, float>(float x)
{
return f8_ocp_t{
fp8_impl::cvt_float_to_fp8<f8_ocp_t::default_interpret, f8_ocp_t::default_saturation>(x)};
}
// convert fp32 to bf8 with rounding to nearest even
template <>
inline __host__ __device__ bf8_ocp_t f8_convert_rne<bf8_ocp_t, float>(float x)
{
return bf8_ocp_t{
fp8_impl::cvt_float_to_fp8<bf8_ocp_t::default_interpret, bf8_ocp_t::default_saturation>(x)};
}
// convert _Float16 to fp8 with rounding to nearest even
template <>
inline __host__ __device__ f8_ocp_t f8_convert_rne<f8_ocp_t, _Float16>(_Float16 x)
{
return f8_ocp_t{
fp8_impl::cvt_half_t_to_fp8<f8_ocp_t::default_interpret, f8_ocp_t::default_saturation>(x)};
}
template <>
inline __host__ __device__ bf8_ocp_t f8_convert_rne<bf8_ocp_t, _Float16>(_Float16 x)
{
return bf8_ocp_t{
fp8_impl::cvt_half_t_to_fp8<bf8_ocp_t::default_interpret, bf8_ocp_t::default_saturation>(
x)};
}
// Declare a template function for fp8 conversion using RNE
template <typename Y, typename X>
__host__ __device__ constexpr Y f8_convert_sr(X x);
// convert fp32 to fp8 with stochastic rounding
template <>
inline __host__ __device__ f8_ocp_t f8_convert_sr<f8_ocp_t, float>(float x)
{
return f8_ocp_t{
fp8_impl::cvt_float_to_fp8<f8_ocp_t::default_interpret, f8_ocp_t::default_saturation, true>(
x)};
}
// convert fp32 to bf8 with stochastic rounding
template <>
inline __host__ __device__ bf8_ocp_t f8_convert_sr<bf8_ocp_t, float>(float x)
{
return bf8_ocp_t{fp8_impl::cvt_float_to_fp8<bf8_ocp_t::default_interpret,
bf8_ocp_t::default_saturation,
true>(x)};
}
// convert _Float16 to fp8 with stochastic rounding
template <>
inline __host__ __device__ f8_ocp_t f8_convert_sr<f8_ocp_t, _Float16>(_Float16 x)
{
return f8_ocp_t{fp8_impl::cvt_half_t_to_fp8<f8_ocp_t::default_interpret,
f8_ocp_t::default_saturation,
true>(x)};
}
// convert _Float16 to bf8 with stochastic rounding
template <>
inline __host__ __device__ bf8_ocp_t f8_convert_sr<bf8_ocp_t, _Float16>(_Float16 x)
{
return bf8_ocp_t{fp8_impl::cvt_half_t_to_fp8<bf8_ocp_t::default_interpret,
bf8_ocp_t::default_saturation,
true>(x)};
}
#if CK_USE_OCP_FP8
using f8_t = f8_ocp_t;
using bf8_t = bf8_ocp_t;
#define CK_FP8_TYPE_FNUZ 0
#define CK_FP8_TYPE_OCP 1
#else
using f8_t = f8_fnuz_t;
using bf8_t = bf8_fnuz_t;
#define CK_FP8_TYPE_FNUZ 1
#define CK_FP8_TYPE_OCP 0
#endif
} // namespace ck
......@@ -4,13 +4,34 @@
#ifndef CK_AMD_INLINE_ASM_HPP
#define CK_AMD_INLINE_ASM_HPP
#include "data_type.hpp"
#include "c_style_pointer_cast.hpp"
#include "data_type.hpp"
// TODO: deprecate all amd_assembly_outer_product_xxx
namespace ck {
inline __device__ int amd_assembly_and_or_b32(int a, int b, int d)
{
int c;
asm volatile("v_and_or_b32 %0, %1, %2, %3" : "=v"(c) : "v"(a), "v"(b), "v"(d));
return c;
}
inline __device__ half2_t amd_assembly_pk_fma_f16(half2_t a, half2_t b, half2_t c)
{
half2_t d;
asm volatile("v_pk_fma_f16 %0, %1, %2, %3" : "=v"(d) : "v"(a), "v"(b), "v"(c));
return d;
}
inline __device__ half2_t amd_assembly_pk_add_f16(half2_t a, half2_t b)
{
half2_t c;
asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(c) : "v"(a), "v"(b));
return c;
}
// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
__device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, float& c0, float& c1)
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -7,10 +7,12 @@
#include "ck/utility/functional2.hpp"
#include "ck/utility/math.hpp"
#ifndef CK_CODE_GEN_RTC
#include <array>
#include <cstddef>
#include <cstdint>
#include <type_traits>
#endif
namespace ck {
namespace detail {
......@@ -37,7 +39,7 @@ struct get_carrier<3>
{
using value_type = uint32_t;
std::array<std::byte, 3> bytes;
Array<ck::byte, 3> bytes;
static_assert(sizeof(bytes) <= sizeof(value_type));
// replacement of host std::copy_n()
......@@ -61,22 +63,22 @@ struct get_carrier<3>
// method to trigger template substitution failure
__device__ carrier(const carrier& other) noexcept
{
copy_n(other.bytes.begin(), bytes.size(), bytes.begin());
copy_n(other.bytes.begin(), bytes.Size(), bytes.begin());
}
public:
__device__ carrier& operator=(value_type value) noexcept
{
copy_n(reinterpret_cast<const std::byte*>(&value), bytes.size(), bytes.begin());
copy_n(reinterpret_cast<const ck::byte*>(&value), bytes.Size(), bytes.begin());
return *this;
}
__device__ operator value_type() const noexcept
{
std::byte result[sizeof(value_type)];
ck::byte result[sizeof(value_type)];
copy_n(bytes.begin(), bytes.size(), result);
copy_n(bytes.begin(), bytes.Size(), result);
return *reinterpret_cast<const value_type*>(result);
}
......@@ -109,8 +111,8 @@ __device__ inline int64_t amd_wave_read_first_lane(int64_t value)
{
constexpr unsigned object_size = sizeof(int64_t);
constexpr unsigned second_part_offset = object_size / 2;
auto* const from_obj = reinterpret_cast<const std::byte*>(&value);
alignas(int64_t) std::byte to_obj[object_size];
auto* const from_obj = reinterpret_cast<const ck::byte*>(&value);
alignas(int64_t) ck::byte to_obj[object_size];
using Sgpr = uint32_t;
......@@ -122,17 +124,16 @@ __device__ inline int64_t amd_wave_read_first_lane(int64_t value)
return *reinterpret_cast<int64_t*>(to_obj);
}
template <
typename Object,
typename = std::enable_if_t<std::is_class_v<Object> && std::is_trivially_copyable_v<Object>>>
template <typename Object,
typename = ck::enable_if_t<ck::is_class_v<Object> && ck::is_trivially_copyable_v<Object>>>
__device__ auto amd_wave_read_first_lane(const Object& obj)
{
using Size = unsigned;
constexpr Size SgprSize = 4;
constexpr Size ObjectSize = sizeof(Object);
auto* const from_obj = reinterpret_cast<const std::byte*>(&obj);
alignas(Object) std::byte to_obj[ObjectSize];
auto* const from_obj = reinterpret_cast<const ck::byte*>(&obj);
alignas(Object) ck::byte to_obj[ObjectSize];
constexpr Size RemainedSize = ObjectSize % SgprSize;
constexpr Size CompleteSgprCopyBoundary = ObjectSize - RemainedSize;
......
......@@ -4,8 +4,8 @@
#pragma once
namespace ck {
// Define the common macro for gfx94x models
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// Define the common macro for MI300 models
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx950__)
#define __gfx94__
#endif
......@@ -134,6 +134,46 @@ struct intrin_mfma_f32_32x32x4f16<32, 64>
}
};
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16f16;
template <>
struct intrin_mfma_f32_32x32x16f16<32, 32>
{
template <class FloatC>
__device__ static void Run(const half8_t& reg_a, const half8_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx950__)
reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x16_f16(
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif // defined(__gfx950__)
}
};
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x32f16;
template <>
struct intrin_mfma_f32_16x16x32f16<16, 16>
{
template <class FloatC>
__device__ static void Run(const half8_t& reg_a, const half8_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx950__)
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_f16(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif // defined(__gfx950__)
}
};
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x8f16;
......@@ -204,6 +244,46 @@ struct intrin_mfma_f32_4x4x4f16<8, 64>
};
// bfp16
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16bf16;
template <>
struct intrin_mfma_f32_32x32x16bf16<32, 32>
{
template <class FloatC>
__device__ static void Run(const bhalf8_t& reg_a, const bhalf8_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx950__)
reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x16_bf16(
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif // defined(__gfx950__)
}
};
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x32bf16;
template <>
struct intrin_mfma_f32_16x16x32bf16<16, 16>
{
template <class FloatC>
__device__ static void Run(const bhalf8_t& reg_a, const bhalf8_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx950__)
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf16(
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif // defined(__gfx950__)
}
};
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x8bf16_1k;
......@@ -298,6 +378,46 @@ struct intrin_mfma_i32_16x16x16i8<16, 16>
}
};
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_i32_32x32x32i8;
template <>
struct intrin_mfma_i32_32x32x32i8<32, 32>
{
template <class FloatC>
__device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx950__)
reg_c.template AsType<int32x16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_i32_32x32x32_i8(
reg_a, reg_b, reg_c.template AsType<int32x16_t>()[Number<0>{}], 0, 0, 0);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif // defined(__gfx950__)
}
};
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_i32_16x16x64i8;
template <>
struct intrin_mfma_i32_16x16x64i8<16, 16>
{
template <class FloatC>
__device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx950__)
reg_c.template AsType<int32x4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_i32_16x16x64_i8(
reg_a, reg_b, reg_c.template AsType<int32x4_t>()[Number<0>{}], 0, 0, 0);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif // defined(__gfx950__)
}
};
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_i32_32x32x16i8;
......@@ -356,6 +476,149 @@ struct intrin_mfma_f64_16x16x4f64<16, 16>
}
};
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x64f8f6f4;
/// @brief Performs a matrix fused multiply-accumulate operation on 32x32x64 submatrices for f8, f6,
/// and f4 data types.
///
/// @note Calls scaled version of the instruction as the original instruction is not supported in
/// the backend. That is the intended use. There is a backend optimization to select the unscaled
/// operation if the scale is 0.
template <>
struct intrin_mfma_f32_32x32x64f8f6f4<32, 32>
{
template <class FloatC>
__device__ static void Run(const f8x32_t& reg_a, const f8x32_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx950__)
reg_c.template AsType<float16_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
reg_a,
reg_b,
reg_c.template AsType<float16_t>()[Number<0>{}],
0, // cbsz
0, // blgp
0,
0,
0,
0);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif
}
};
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_scale_f32_32x32x64f8f6f4;
template <>
struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32>
{
template <class FloatC>
__device__ static void Run(const f8x32_t& reg_a,
const int32_t scale_a,
const f8x32_t& reg_b,
const int32_t scale_b,
FloatC& reg_c)
{
#if defined(__gfx950__)
// https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
reg_c.template AsType<float16_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
reg_a,
reg_b,
reg_c.template AsType<float16_t>()[Number<0>{}],
0, // cbsz
0, // blgp
0, // { OPSEL_HI[0], OPSEL[0] }?
scale_a,
0, // { OPSEL_HI[1], OPSEL[1] }?
scale_b);
#else
ignore = reg_a;
ignore = scale_a;
ignore = reg_b;
ignore = scale_b;
ignore = reg_c;
#endif
}
};
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_scale_f32_16x16x128f8f6f4;
template <>
struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>
{
template <class FloatC>
__device__ static void Run(const f8x32_t& reg_a,
const int32_t scale_a,
const f8x32_t& reg_b,
const int32_t scale_b,
FloatC& reg_c)
{
#if defined(__gfx950__)
// https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
reg_c.template AsType<float4_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
reg_a,
reg_b,
reg_c.template AsType<float4_t>()[Number<0>{}],
0, // cbsz
0, // blgp
0, // { OPSEL_HI[0], OPSEL[0] }?
scale_a,
0, // { OPSEL_HI[1], OPSEL[1] }?
scale_b);
#else
ignore = reg_a;
ignore = scale_a;
ignore = reg_b;
ignore = scale_b;
ignore = reg_c;
#endif
}
};
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x128f8f6f4;
/// @brief Performs a matrix fused multiply-accumulate operation on 16x16x128 submatrices for f8f6f4
/// data types.
///
/// @note Calls scaled version of the instruction as the original instruction is not supported in
/// the backend. That is the intended use. There is a backend optimization to select the unscaled
/// operation if the scale is 0.
template <>
struct intrin_mfma_f32_16x16x128f8f6f4<16, 16>
{
template <class FloatC>
__device__ static void Run(const f8x32_t& reg_a, const f8x32_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx950__)
// https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
reg_c.template AsType<float4_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
reg_a,
reg_b,
reg_c.template AsType<float4_t>()[Number<0>{}],
0, // cbsz
0, // blgp
0,
0,
0,
0);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif
}
};
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16f8f8;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_ARRAY_HPP
#define CK_ARRAY_HPP
......@@ -38,6 +38,8 @@ struct Array
}
__host__ __device__ constexpr const TData* begin() const { return &mData[0]; }
__host__ __device__ constexpr const TData* end() const { return &mData[NSize]; }
__host__ __device__ constexpr TData* begin() { return &mData[0]; }
__host__ __device__ constexpr TData* end() { return &mData[NSize]; }
};
// empty Array
......@@ -54,7 +56,7 @@ template <typename X, typename... Xs>
__host__ __device__ constexpr auto make_array(X&& x, Xs&&... xs)
{
using data_type = remove_cvref_t<X>;
return Array<data_type, sizeof...(Xs) + 1>{std::forward<X>(x), std::forward<Xs>(xs)...};
return Array<data_type, sizeof...(Xs) + 1>{ck::forward<X>(x), ck::forward<Xs>(xs)...};
}
// make empty array
......
......@@ -90,14 +90,22 @@ struct BlockwiseGemmXdlops_pipeline_hotloop_inst
KPerXDL);
printf(" A/B buffer load inst: %d, %d\n A/B LDS write inst: %d, %d\n A/B LDS read inst: "
"%d, %d\n C MFMA inst: %d\n",
"%d, %d\n C MFMA inst: %d\n"
"A/B LDS read width: %d, %d, A/B LDS write width: %d, %d, A/B buffer load width: "
"%d/ %d\n",
A_Buffer_Load_Inst_Num,
B_Buffer_Load_Inst_Num,
A_LDS_Write_Inst_Num,
B_LDS_Write_Inst_Num,
A_LDS_Read_Inst_Num,
B_LDS_Read_Inst_Num,
C_MFMA_Inst_Num);
C_MFMA_Inst_Num,
A_LDS_Read_Width,
B_LDS_Read_Width,
ALDSWriteWidth,
BLDSWriteWidth,
ABufferLoadWidth,
BBufferLoadWidth);
}
};
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_CONTAINER_HELPER_HPP
#define CK_CONTAINER_HELPER_HPP
......@@ -326,14 +326,14 @@ template <typename T, index_t NX, index_t NY>
__host__ __device__ constexpr auto container_concat(const Array<T, NX>& ax, const Array<T, NY>& ay)
{
return unpack2(
[&](auto&&... zs) { return make_array(std::forward<decltype(zs)>(zs)...); }, ax, ay);
[&](auto&&... zs) { return make_array(ck::forward<decltype(zs)>(zs)...); }, ax, ay);
}
template <typename... X, typename... Y>
__host__ __device__ constexpr auto container_concat(const Tuple<X...>& tx, const Tuple<Y...>& ty)
{
return unpack2(
[&](auto&&... zs) { return make_tuple(std::forward<decltype(zs)>(zs)...); }, tx, ty);
[&](auto&&... zs) { return make_tuple(ck::forward<decltype(zs)>(zs)...); }, tx, ty);
}
template <typename Container>
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/amd_ck_fp8.hpp"
#include "ck/utility/e8m0.hpp"
#include "ck/utility/statically_indexed_array.hpp"
#ifdef CK_CODE_GEN_RTC
using int8_t = signed char;
using uint8_t = unsigned char;
using int16_t = signed short;
using uint16_t = unsigned short;
using float_t = float;
#endif
namespace ck {
#ifdef CK_CODE_GEN_RTC
using byte = unsigned char;
#else
using std::byte;
#endif
using bhalf_t = ushort;
using half_t = _Float16;
using int4_t = _BitInt(4);
using f8_t = _BitInt(8);
using bf8_t = unsigned _BitInt(8);
using f4_t = unsigned _BitInt(4);
using f6_t = _BitInt(6); // e2m3 format
using bf6_t = unsigned _BitInt(6); // e3m2 format
struct f4x2_pk_t
{
using type = uint8_t;
type data;
f4x2_pk_t() : data{type{}} {}
f4x2_pk_t(type init) : data{init} {}
template <index_t I>
__host__ __device__ inline type unpack(Number<I>) const
{
static_assert(I < 2, "Index is out of range.");
if constexpr(I == 0)
return data & 0b00001111;
else
return (data >> 4);
}
__host__ __device__ inline type pack(const type x0, const type x1)
{
return (x1 << 4) | (x0 & 0b00001111);
}
};
struct f6x16_pk_t
{
// store 16 elements of f6_t in an array of 3 uint32_t
using element_type = uint32_t;
using type = StaticallyIndexedArray_v2<element_type, 3>;
type data;
typedef int8_t test_vec_t __attribute__((ext_vector_type(16)));
f6x16_pk_t() : data{type{}} {}
f6x16_pk_t(type init) : data{init} {}
template <index_t I>
__host__ __device__ inline f6_t unpack(Number<I>)
{
static_assert(I < 16, "Index out of range for 16 f6_t elements.");
constexpr int num_bits_elem = 6;
constexpr int num_bits_vec_elem = 32;
constexpr int vector_size = 3;
constexpr int bit_pos = I * num_bits_elem;
constexpr int arr_idx = bit_pos / num_bits_vec_elem;
constexpr int bit_offset = bit_pos % num_bits_vec_elem;
uint32_t bits = data.At(Number<arr_idx>{}) >> bit_offset;
constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
if constexpr(overhang > 0 && (arr_idx + 1) < vector_size)
{
bits |= (data.At(Number<arr_idx + 1>{}) & ((1u << overhang) - 1))
<< (num_bits_elem - overhang);
}
return static_cast<f6_t>(bits & 0x3F);
}
__host__ __device__ inline type pack(const test_vec_t& x)
{
type packed{};
// for each of the 16 f6_t values, place its 6 bits in the correct position
ck::static_for<0, 16, 1>{}([&](auto i) {
uint32_t bits = static_cast<uint32_t>(x[static_cast<int>(i)]) & 0x3F;
constexpr int num_bits_elem = 6;
constexpr int num_bits_vec_elem = 32;
constexpr int vector_size = 3;
constexpr int bit_pos = i * num_bits_elem;
constexpr int arr_index = bit_pos / num_bits_vec_elem;
constexpr int bit_offset = bit_pos % num_bits_vec_elem;
constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
uint32_t old_value = packed.At(Number<arr_index>{});
// insert bits into the current 32-bit block
old_value |= (bits << bit_offset);
packed.At(Number<arr_index>{}) = old_value;
// if it crosses into the next block, shift the remainder
if constexpr(overhang > 0 && (arr_index + 1) < vector_size)
{
uint32_t next_value = packed.At(Number<arr_index + 1>{});
next_value |= (bits >> (num_bits_elem - overhang));
packed.At(Number<arr_index + 1>{}) = next_value;
}
});
return packed;
}
};
struct f6x32_pk_t
{
// store 32 elements of f6_t in an array of 6 uint32_t
using element_type = uint32_t;
using type = StaticallyIndexedArray_v2<element_type, 6>;
type data;
typedef int8_t test_vec_t __attribute__((ext_vector_type(32)));
f6x32_pk_t() : data{type{}} {}
f6x32_pk_t(type init) : data{init} {}
template <index_t I>
__host__ __device__ inline f6_t unpack(Number<I>)
{
static_assert(I < 32, "Index out of range for 32 f6_t elements.");
constexpr int num_bits_elem = 6;
constexpr int num_bits_vec_elem = 32;
constexpr int vector_size = 6;
constexpr int bit_pos = I * num_bits_elem;
constexpr int arr_idx = bit_pos / num_bits_vec_elem;
constexpr int bit_offset = bit_pos % num_bits_vec_elem;
uint32_t bits = data.At(Number<arr_idx>{}) >> bit_offset;
constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
if constexpr(overhang > 0 && (arr_idx + 1) < vector_size)
{
bits |= (data.At(Number<arr_idx + 1>{}) & ((1u << overhang) - 1))
<< (num_bits_elem - overhang);
}
return static_cast<f6_t>(bits & 0x3F);
}
__host__ __device__ inline type pack(const test_vec_t& x)
{
type packed{};
// for each of the 32 f6_t values, place its 6 bits in the correct position
ck::static_for<0, 32, 1>{}([&](auto i) {
uint32_t bits = static_cast<uint32_t>(x[static_cast<int>(i)]) & 0x3F;
constexpr int num_bits_elem = 6;
constexpr int num_bits_vec_elem = 32;
constexpr int vector_size = 6;
constexpr int bit_pos = i * num_bits_elem;
constexpr int arr_index = bit_pos / num_bits_vec_elem;
constexpr int bit_offset = bit_pos % num_bits_vec_elem;
constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
uint32_t old_value = packed.At(Number<arr_index>{});
// insert bits into the current 32-bit block
old_value |= (bits << bit_offset);
packed.At(Number<arr_index>{}) = old_value;
// if it crosses into the next block, shift the remainder
if constexpr(overhang > 0 && (arr_index + 1) < vector_size)
{
uint32_t next_value = packed.At(Number<arr_index + 1>{});
next_value |= (bits >> (num_bits_elem - overhang));
packed.At(Number<arr_index + 1>{}) = next_value;
}
});
return packed;
}
};
struct bf6x16_pk_t
{
// store 16 elements of bf6_t in an array of 3 uint32_t
using element_type = uint32_t;
using type = StaticallyIndexedArray_v2<element_type, 3>;
type data;
typedef int8_t test_vec_t __attribute__((ext_vector_type(16)));
bf6x16_pk_t() : data{type{}} {}
bf6x16_pk_t(type init) : data{init} {}
template <index_t I>
__host__ __device__ inline bf6_t unpack(Number<I>)
{
static_assert(I < 16, "Index out of range for 16 f6_t elements.");
constexpr int num_bits_elem = 6;
constexpr int num_bits_vec_elem = 32;
constexpr int vector_size = 3;
constexpr int bit_pos = I * num_bits_elem;
constexpr int arr_idx = bit_pos / num_bits_vec_elem;
constexpr int bit_offset = bit_pos % num_bits_vec_elem;
uint32_t bits = data.At(Number<arr_idx>{}) >> bit_offset;
constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
if constexpr(overhang > 0 && (arr_idx + 1) < vector_size)
{
bits |= (data.At(Number<arr_idx + 1>{}) & ((1u << overhang) - 1))
<< (num_bits_elem - overhang);
}
return static_cast<bf6_t>(bits & 0x3F);
}
__host__ __device__ inline type pack(const test_vec_t& x)
{
type packed{};
// for each of the 16 bf6_t values, place its 6 bits in the correct position
ck::static_for<0, 16, 1>{}([&](auto i) {
uint32_t bits = static_cast<uint32_t>(x[static_cast<int>(i)]) & 0x3F;
constexpr int num_bits_elem = 6;
constexpr int num_bits_vec_elem = 32;
constexpr int vector_size = 3;
constexpr int bit_pos = i * num_bits_elem;
constexpr int arr_index = bit_pos / num_bits_vec_elem;
constexpr int bit_offset = bit_pos % num_bits_vec_elem;
constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
uint32_t old_value = packed.At(Number<arr_index>{});
// insert bits into the current 32-bit block
old_value |= (bits << bit_offset);
packed.At(Number<arr_index>{}) = old_value;
// if it crosses into the next block, shift the remainder
if constexpr(overhang > 0 && (arr_index + 1) < vector_size)
{
uint32_t next_value = packed.At(Number<arr_index + 1>{});
next_value |= (bits >> (num_bits_elem - overhang));
packed.At(Number<arr_index + 1>{}) = next_value;
}
});
return packed;
}
};
struct bf6x32_pk_t
{
// store 32 elements of bf6_t in an array of 6 uint32_t
using element_type = uint32_t;
using type = StaticallyIndexedArray_v2<element_type, 6>;
type data;
typedef int8_t test_vec_t __attribute__((ext_vector_type(32)));
bf6x32_pk_t() : data{type{}} {}
bf6x32_pk_t(type init) : data{init} {}
template <index_t I>
__host__ __device__ inline bf6_t unpack(Number<I>)
{
static_assert(I < 32, "Index out of range for 32 f6_t elements.");
constexpr int num_bits_elem = 6;
constexpr int num_bits_vec_elem = 32;
constexpr int vector_size = 6;
constexpr int bit_pos = I * num_bits_elem;
constexpr int arr_idx = bit_pos / num_bits_vec_elem;
constexpr int bit_offset = bit_pos % num_bits_vec_elem;
uint32_t bits = data.At(Number<arr_idx>{}) >> bit_offset;
constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
if constexpr(overhang > 0 && (arr_idx + 1) < vector_size)
{
bits |= (data.At(Number<arr_idx + 1>{}) & ((1u << overhang) - 1))
<< (num_bits_elem - overhang);
}
return static_cast<bf6_t>(bits & 0x3F);
}
__host__ __device__ inline type pack(const test_vec_t& x)
{
type packed{};
// for each of the 32 bf6_t values, place its 6 bits in the correct position
ck::static_for<0, 32, 1>{}([&](auto i) {
uint32_t bits = static_cast<uint32_t>(x[static_cast<int>(i)]) & 0x3F;
constexpr int num_bits_elem = 6;
constexpr int num_bits_vec_elem = 32;
constexpr int vector_size = 6;
constexpr int bit_pos = i * num_bits_elem;
constexpr int arr_index = bit_pos / num_bits_vec_elem;
constexpr int bit_offset = bit_pos % num_bits_vec_elem;
constexpr int overhang = bit_offset + num_bits_elem - num_bits_vec_elem;
uint32_t old_value = packed.At(Number<arr_index>{});
// insert bits into the current 32-bit block
old_value |= (bits << bit_offset);
packed.At(Number<arr_index>{}) = old_value;
// if it crosses into the next block, shift the remainder
if constexpr(overhang > 0 && (arr_index + 1) < vector_size)
{
uint32_t next_value = packed.At(Number<arr_index + 1>{});
next_value |= (bits >> (num_bits_elem - overhang));
packed.At(Number<arr_index + 1>{}) = next_value;
}
});
return packed;
}
};
// custom data type - pack int4 data
struct pk_i4_t
{
using type = int8_t;
type data;
__host__ __device__ constexpr pk_i4_t() : data{type{}} {}
__host__ __device__ constexpr pk_i4_t(type init) : data{init} {}
};
inline constexpr auto next_pow2(uint32_t x)
{
......@@ -19,14 +330,16 @@ inline constexpr auto next_pow2(uint32_t x)
return x > 1u ? (1u << (32u - __builtin_clz(x - 1u))) : x;
}
// native types: double, float, _Float16, ushort, int32_t, int8_t, uint8_t, f8_t, bf8_t, bool
// native types: double, float, _Float16, ushort, int32_t, int8_t, uint8_t, f8_fnuz_t, bf8_fnuz_t,
// native types: bool, f4_t, f6_t, bf6_t
template <typename T>
inline constexpr bool is_native_type()
{
return is_same<T, double>::value || is_same<T, float>::value || is_same<T, half_t>::value ||
is_same<T, bhalf_t>::value || is_same<T, int32_t>::value || is_same<T, int8_t>::value ||
is_same<T, uint8_t>::value || is_same<T, f8_t>::value || is_same<T, bf8_t>::value ||
is_same<T, bool>::value;
is_same<T, uint8_t>::value || is_same<T, f8_fnuz_t>::value ||
is_same<T, bf8_fnuz_t>::value || is_same<T, bool>::value || is_same<T, f4_t>::value ||
is_same<T, f6_t>::value || is_same<T, bf6_t>::value;
}
// vector_type
......@@ -166,16 +479,37 @@ struct scalar_type<int4_t>
#endif
template <>
struct scalar_type<f8_t>
struct scalar_type<pk_i4_t>
{
using type = pk_i4_t;
static constexpr index_t vector_size = 1;
};
template <>
struct scalar_type<f8_fnuz_t>
{
using type = f8_fnuz_t;
static constexpr index_t vector_size = 1;
};
template <>
struct scalar_type<bf8_fnuz_t>
{
using type = bf8_fnuz_t;
static constexpr index_t vector_size = 1;
};
template <>
struct scalar_type<f8_ocp_t>
{
using type = f8_t;
using type = f8_ocp_t::data_type;
static constexpr index_t vector_size = 1;
};
template <>
struct scalar_type<bf8_t>
struct scalar_type<bf8_ocp_t>
{
using type = bf8_t;
using type = bf8_ocp_t::data_type;
static constexpr index_t vector_size = 1;
};
......@@ -187,7 +521,7 @@ struct scalar_type<bool>
};
template <typename T>
struct vector_type<T, 1, typename std::enable_if_t<is_native_type<T>()>>
struct vector_type<T, 1, typename ck::enable_if_t<is_native_type<T>()>>
{
using d1_t = T;
using type = d1_t;
......@@ -223,7 +557,7 @@ struct vector_type<T, 1, typename std::enable_if_t<is_native_type<T>()>>
__device__ int static err = 0;
template <typename T>
struct vector_type<T, 2, typename std::enable_if_t<is_native_type<T>()>>
struct vector_type<T, 2, typename ck::enable_if_t<is_native_type<T>()>>
{
using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2)));
......@@ -283,20 +617,20 @@ struct vector_type<T, 2, typename std::enable_if_t<is_native_type<T>()>>
};
template <typename T>
struct vector_type<T, 4, typename std::enable_if_t<is_native_type<T>()>>
struct vector_type<T, 3, typename ck::enable_if_t<is_native_type<T>()>>
{
using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2)));
typedef T d4_t __attribute__((ext_vector_type(4)));
typedef T d3_t __attribute__((ext_vector_type(3)));
using type = d4_t;
using type = d3_t;
union
{
d4_t d4_;
StaticallyIndexedArray<d1_t, 4> d1x4_;
StaticallyIndexedArray<d2_t, 2> d2x2_;
StaticallyIndexedArray<d4_t, 1> d4x1_;
d3_t d3_;
StaticallyIndexedArray<d1_t, 3> d1x3_;
StaticallyIndexedArray<d2_t, 1> d2x1_;
StaticallyIndexedArray<d3_t, 1> d3x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
......@@ -306,20 +640,20 @@ struct vector_type<T, 4, typename std::enable_if_t<is_native_type<T>()>>
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value,
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d3_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x4_;
return data_.d1x3_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x2_;
return data_.d2x1_;
}
else if constexpr(is_same<X, d4_t>::value)
else if constexpr(is_same<X, d3_t>::value)
{
return data_.d4x1_;
return data_.d3x1_;
}
else
{
......@@ -330,20 +664,20 @@ struct vector_type<T, 4, typename std::enable_if_t<is_native_type<T>()>>
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value,
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d3_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x4_;
return data_.d1x3_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x2_;
return data_.d2x1_;
}
else if constexpr(is_same<X, d4_t>::value)
else if constexpr(is_same<X, d3_t>::value)
{
return data_.d4x1_;
return data_.d3x1_;
}
else
{
......@@ -353,22 +687,20 @@ struct vector_type<T, 4, typename std::enable_if_t<is_native_type<T>()>>
};
template <typename T>
struct vector_type<T, 8, typename std::enable_if_t<is_native_type<T>()>>
struct vector_type<T, 4, typename ck::enable_if_t<is_native_type<T>()>>
{
using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2)));
typedef T d4_t __attribute__((ext_vector_type(4)));
typedef T d8_t __attribute__((ext_vector_type(8)));
using type = d8_t;
using type = d4_t;
union
{
d8_t d8_;
StaticallyIndexedArray<d1_t, 8> d1x8_;
StaticallyIndexedArray<d2_t, 4> d2x4_;
StaticallyIndexedArray<d4_t, 2> d4x2_;
StaticallyIndexedArray<d8_t, 1> d8x1_;
d4_t d4_;
StaticallyIndexedArray<d1_t, 4> d1x4_;
StaticallyIndexedArray<d2_t, 2> d2x2_;
StaticallyIndexedArray<d4_t, 1> d4x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
......@@ -378,25 +710,20 @@ struct vector_type<T, 8, typename std::enable_if_t<is_native_type<T>()>>
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value,
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x8_;
return data_.d1x4_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x4_;
return data_.d2x2_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x2_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x1_;
return data_.d4x1_;
}
else
{
......@@ -407,25 +734,20 @@ struct vector_type<T, 8, typename std::enable_if_t<is_native_type<T>()>>
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value,
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x8_;
return data_.d1x4_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x4_;
return data_.d2x2_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x2_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x1_;
return data_.d4x1_;
}
else
{
......@@ -435,24 +757,20 @@ struct vector_type<T, 8, typename std::enable_if_t<is_native_type<T>()>>
};
template <typename T>
struct vector_type<T, 16, typename std::enable_if_t<is_native_type<T>()>>
struct vector_type<T, 5, typename ck::enable_if_t<is_native_type<T>()>>
{
using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2)));
typedef T d4_t __attribute__((ext_vector_type(4)));
typedef T d8_t __attribute__((ext_vector_type(8)));
typedef T d16_t __attribute__((ext_vector_type(16)));
typedef T d5_t __attribute__((ext_vector_type(5)));
using type = d16_t;
using type = d5_t;
union
{
d16_t d16_;
StaticallyIndexedArray<d1_t, 16> d1x16_;
StaticallyIndexedArray<d2_t, 8> d2x8_;
StaticallyIndexedArray<d4_t, 4> d4x4_;
StaticallyIndexedArray<d8_t, 2> d8x2_;
StaticallyIndexedArray<d16_t, 1> d16x1_;
d5_t d5_;
StaticallyIndexedArray<d1_t, 5> d1x5_;
StaticallyIndexedArray<d4_t, 1> d4x1_;
StaticallyIndexedArray<d5_t, 1> d5x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
......@@ -462,30 +780,20 @@ struct vector_type<T, 16, typename std::enable_if_t<is_native_type<T>()>>
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value,
static_assert(is_same<X, d1_t>::value || is_same<X, d4_t>::value || is_same<X, d5_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x16_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x8_;
return data_.d1x5_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x4_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x2_;
return data_.d4x1_;
}
else if constexpr(is_same<X, d16_t>::value)
else if constexpr(is_same<X, d5_t>::value)
{
return data_.d16x1_;
return data_.d5x1_;
}
else
{
......@@ -496,30 +804,20 @@ struct vector_type<T, 16, typename std::enable_if_t<is_native_type<T>()>>
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value,
static_assert(is_same<X, d1_t>::value || is_same<X, d4_t>::value || is_same<X, d5_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x16_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x8_;
return data_.d1x5_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x4_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x2_;
return data_.d4x1_;
}
else if constexpr(is_same<X, d16_t>::value)
else if constexpr(is_same<X, d5_t>::value)
{
return data_.d16x1_;
return data_.d5x1_;
}
else
{
......@@ -529,26 +827,22 @@ struct vector_type<T, 16, typename std::enable_if_t<is_native_type<T>()>>
};
template <typename T>
struct vector_type<T, 32, typename std::enable_if_t<is_native_type<T>()>>
struct vector_type<T, 7, typename ck::enable_if_t<is_native_type<T>()>>
{
using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2)));
typedef T d4_t __attribute__((ext_vector_type(4)));
typedef T d8_t __attribute__((ext_vector_type(8)));
typedef T d16_t __attribute__((ext_vector_type(16)));
typedef T d32_t __attribute__((ext_vector_type(32)));
typedef T d7_t __attribute__((ext_vector_type(7)));
using type = d32_t;
using type = d7_t;
union
{
d32_t d32_;
StaticallyIndexedArray<d1_t, 32> d1x32_;
StaticallyIndexedArray<d2_t, 16> d2x16_;
StaticallyIndexedArray<d4_t, 8> d4x8_;
StaticallyIndexedArray<d8_t, 4> d8x4_;
StaticallyIndexedArray<d16_t, 2> d16x2_;
StaticallyIndexedArray<d32_t, 1> d32x1_;
d7_t d7_;
StaticallyIndexedArray<d1_t, 7> d1x7_;
StaticallyIndexedArray<d2_t, 3> d2x3_;
StaticallyIndexedArray<d4_t, 1> d4x1_;
StaticallyIndexedArray<d7_t, 1> d7x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
......@@ -559,33 +853,24 @@ struct vector_type<T, 32, typename std::enable_if_t<is_native_type<T>()>>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value,
is_same<X, d4_t>::value || is_same<X, d7_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x32_;
return data_.d1x7_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x16_;
return data_.d2x3_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x8_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x4_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x2_;
return data_.d4x1_;
}
else if constexpr(is_same<X, d32_t>::value)
else if constexpr(is_same<X, d7_t>::value)
{
return data_.d32x1_;
return data_.d7x1_;
}
else
{
......@@ -597,33 +882,24 @@ struct vector_type<T, 32, typename std::enable_if_t<is_native_type<T>()>>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value,
is_same<X, d4_t>::value || is_same<X, d7_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x32_;
return data_.d1x7_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x16_;
return data_.d2x3_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x8_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x4_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x2_;
return data_.d4x1_;
}
else if constexpr(is_same<X, d32_t>::value)
else if constexpr(is_same<X, d7_t>::value)
{
return data_.d32x1_;
return data_.d7x1_;
}
else
{
......@@ -633,28 +909,22 @@ struct vector_type<T, 32, typename std::enable_if_t<is_native_type<T>()>>
};
template <typename T>
struct vector_type<T, 64, typename std::enable_if_t<is_native_type<T>()>>
struct vector_type<T, 8, typename ck::enable_if_t<is_native_type<T>()>>
{
using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2)));
typedef T d4_t __attribute__((ext_vector_type(4)));
typedef T d8_t __attribute__((ext_vector_type(8)));
typedef T d16_t __attribute__((ext_vector_type(16)));
typedef T d32_t __attribute__((ext_vector_type(32)));
typedef T d64_t __attribute__((ext_vector_type(64)));
using type = d64_t;
using type = d8_t;
union
{
d64_t d64_;
StaticallyIndexedArray<d1_t, 64> d1x64_;
StaticallyIndexedArray<d2_t, 32> d2x32_;
StaticallyIndexedArray<d4_t, 16> d4x16_;
StaticallyIndexedArray<d8_t, 8> d8x8_;
StaticallyIndexedArray<d16_t, 4> d16x4_;
StaticallyIndexedArray<d32_t, 2> d32x2_;
StaticallyIndexedArray<d64_t, 1> d64x1_;
d8_t d8_;
StaticallyIndexedArray<d1_t, 8> d1x8_;
StaticallyIndexedArray<d2_t, 4> d2x4_;
StaticallyIndexedArray<d4_t, 2> d4x2_;
StaticallyIndexedArray<d8_t, 1> d8x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
......@@ -665,81 +935,135 @@ struct vector_type<T, 64, typename std::enable_if_t<is_native_type<T>()>>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
is_same<X, d64_t>::value,
is_same<X, d4_t>::value || is_same<X, d8_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x64_;
return data_.d1x8_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x32_;
return data_.d2x4_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x16_;
return data_.d4x2_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x8_;
return data_.d8x1_;
}
else if constexpr(is_same<X, d16_t>::value)
else
{
return data_.d16x4_;
return err;
}
else if constexpr(is_same<X, d32_t>::value)
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d32x2_;
return data_.d1x8_;
}
else if constexpr(is_same<X, d64_t>::value)
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d64x1_;
return data_.d2x4_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x2_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x1_;
}
else
{
return err;
}
}
};
template <typename T>
struct vector_type<T, 13, typename ck::enable_if_t<is_native_type<T>()>>
{
using d1_t = T;
typedef T d4_t __attribute__((ext_vector_type(4)));
typedef T d8_t __attribute__((ext_vector_type(8)));
typedef T d13_t __attribute__((ext_vector_type(13)));
using type = d13_t;
union
{
d13_t d13_;
StaticallyIndexedArray<d1_t, 13> d1x13_;
StaticallyIndexedArray<d4_t, 3> d4x3_;
StaticallyIndexedArray<d8_t, 1> d8x1_;
StaticallyIndexedArray<d13_t, 1> d13x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr auto& AsType()
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
is_same<X, d64_t>::value,
static_assert(is_same<X, d1_t>::value || is_same<X, d4_t>::value ||
is_same<X, d8_t>::value || is_same<X, d13_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x64_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x32_;
return data_.d1x13_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x16_;
return data_.d4x3_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x8_;
return data_.d8x1_;
}
else if constexpr(is_same<X, d16_t>::value)
else if constexpr(is_same<X, d13_t>::value)
{
return data_.d16x4_;
return data_.d13x1_;
}
else if constexpr(is_same<X, d32_t>::value)
else
{
return data_.d32x2_;
return err;
}
else if constexpr(is_same<X, d64_t>::value)
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d4_t>::value ||
is_same<X, d8_t>::value || is_same<X, d13_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d64x1_;
return data_.d1x13_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x3_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x1_;
}
else if constexpr(is_same<X, d13_t>::value)
{
return data_.d13x1_;
}
else
{
......@@ -749,30 +1073,24 @@ struct vector_type<T, 64, typename std::enable_if_t<is_native_type<T>()>>
};
template <typename T>
struct vector_type<T, 128, typename std::enable_if_t<is_native_type<T>()>>
struct vector_type<T, 16, typename ck::enable_if_t<is_native_type<T>()>>
{
using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2)));
typedef T d4_t __attribute__((ext_vector_type(4)));
typedef T d8_t __attribute__((ext_vector_type(8)));
typedef T d16_t __attribute__((ext_vector_type(16)));
typedef T d32_t __attribute__((ext_vector_type(32)));
typedef T d64_t __attribute__((ext_vector_type(64)));
typedef T d128_t __attribute__((ext_vector_type(128)));
using type = d128_t;
using type = d16_t;
union
{
d128_t d128_;
StaticallyIndexedArray<d1_t, 128> d1x128_;
StaticallyIndexedArray<d2_t, 64> d2x64_;
StaticallyIndexedArray<d4_t, 32> d4x32_;
StaticallyIndexedArray<d8_t, 16> d8x16_;
StaticallyIndexedArray<d16_t, 8> d16x8_;
StaticallyIndexedArray<d32_t, 4> d32x4_;
StaticallyIndexedArray<d64_t, 2> d64x2_;
StaticallyIndexedArray<d128_t, 1> d128x1_;
d16_t d16_;
StaticallyIndexedArray<d1_t, 16> d1x16_;
StaticallyIndexedArray<d2_t, 8> d2x8_;
StaticallyIndexedArray<d4_t, 4> d4x4_;
StaticallyIndexedArray<d8_t, 2> d8x2_;
StaticallyIndexedArray<d16_t, 1> d16x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
......@@ -784,41 +1102,28 @@ struct vector_type<T, 128, typename std::enable_if_t<is_native_type<T>()>>
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
is_same<X, d64_t>::value || is_same<X, d128_t>::value,
is_same<X, d16_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x128_;
return data_.d1x16_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x64_;
return data_.d2x8_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x32_;
return data_.d4x4_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x16_;
return data_.d8x2_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x8_;
}
else if constexpr(is_same<X, d32_t>::value)
{
return data_.d32x4_;
}
else if constexpr(is_same<X, d64_t>::value)
{
return data_.d64x2_;
}
else if constexpr(is_same<X, d128_t>::value)
{
return data_.d128x1_;
return data_.d16x1_;
}
else
{
......@@ -831,41 +1136,28 @@ struct vector_type<T, 128, typename std::enable_if_t<is_native_type<T>()>>
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
is_same<X, d64_t>::value || is_same<X, d128_t>::value,
is_same<X, d16_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x128_;
return data_.d1x16_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x64_;
return data_.d2x8_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x32_;
return data_.d4x4_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x16_;
return data_.d8x2_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x8_;
}
else if constexpr(is_same<X, d32_t>::value)
{
return data_.d32x4_;
}
else if constexpr(is_same<X, d64_t>::value)
{
return data_.d64x2_;
}
else if constexpr(is_same<X, d128_t>::value)
{
return data_.d128x1_;
return data_.d16x1_;
}
else
{
......@@ -875,7 +1167,7 @@ struct vector_type<T, 128, typename std::enable_if_t<is_native_type<T>()>>
};
template <typename T>
struct vector_type<T, 256, typename std::enable_if_t<is_native_type<T>()>>
struct vector_type<T, 32, typename ck::enable_if_t<is_native_type<T>()>>
{
using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2)));
......@@ -883,24 +1175,18 @@ struct vector_type<T, 256, typename std::enable_if_t<is_native_type<T>()>>
typedef T d8_t __attribute__((ext_vector_type(8)));
typedef T d16_t __attribute__((ext_vector_type(16)));
typedef T d32_t __attribute__((ext_vector_type(32)));
typedef T d64_t __attribute__((ext_vector_type(64)));
typedef T d128_t __attribute__((ext_vector_type(128)));
typedef T d256_t __attribute__((ext_vector_type(256)));
using type = d256_t;
using type = d32_t;
union
{
d256_t d256_;
StaticallyIndexedArray<d1_t, 256> d1x256_;
StaticallyIndexedArray<d2_t, 128> d2x128_;
StaticallyIndexedArray<d4_t, 64> d4x64_;
StaticallyIndexedArray<d8_t, 32> d8x32_;
StaticallyIndexedArray<d16_t, 16> d16x16_;
StaticallyIndexedArray<d32_t, 8> d32x8_;
StaticallyIndexedArray<d64_t, 4> d64x4_;
StaticallyIndexedArray<d128_t, 2> d128x2_;
StaticallyIndexedArray<d256_t, 1> d256x1_;
d32_t d32_;
StaticallyIndexedArray<d1_t, 32> d1x32_;
StaticallyIndexedArray<d2_t, 16> d2x16_;
StaticallyIndexedArray<d4_t, 8> d4x8_;
StaticallyIndexedArray<d8_t, 4> d8x4_;
StaticallyIndexedArray<d16_t, 2> d16x2_;
StaticallyIndexedArray<d32_t, 1> d32x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
......@@ -910,47 +1196,34 @@ struct vector_type<T, 256, typename std::enable_if_t<is_native_type<T>()>>
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(
is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value ||
is_same<X, d8_t>::value || is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
is_same<X, d64_t>::value || is_same<X, d128_t>::value || is_same<X, d256_t>::value,
"Something went wrong, please check src and dst types.");
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x256_;
return data_.d1x32_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x128_;
return data_.d2x16_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x64_;
return data_.d4x8_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x32_;
return data_.d8x4_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x16_;
return data_.d16x2_;
}
else if constexpr(is_same<X, d32_t>::value)
{
return data_.d32x8_;
}
else if constexpr(is_same<X, d64_t>::value)
{
return data_.d64x4_;
}
else if constexpr(is_same<X, d128_t>::value)
{
return data_.d128x2_;
}
else if constexpr(is_same<X, d256_t>::value)
{
return data_.d256x1_;
return data_.d32x1_;
}
else
{
......@@ -961,47 +1234,34 @@ struct vector_type<T, 256, typename std::enable_if_t<is_native_type<T>()>>
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(
is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value ||
is_same<X, d8_t>::value || is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
is_same<X, d64_t>::value || is_same<X, d128_t>::value || is_same<X, d256_t>::value,
"Something went wrong, please check src and dst types.");
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x256_;
return data_.d1x32_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x128_;
return data_.d2x16_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x64_;
return data_.d4x8_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x32_;
return data_.d8x4_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x16_;
return data_.d16x2_;
}
else if constexpr(is_same<X, d32_t>::value)
{
return data_.d32x8_;
}
else if constexpr(is_same<X, d64_t>::value)
{
return data_.d64x4_;
}
else if constexpr(is_same<X, d128_t>::value)
{
return data_.d128x2_;
}
else if constexpr(is_same<X, d256_t>::value)
{
return data_.d256x1_;
return data_.d32x1_;
}
else
{
......@@ -1010,60 +1270,677 @@ struct vector_type<T, 256, typename std::enable_if_t<is_native_type<T>()>>
}
};
template <typename T, index_t N>
struct non_native_vector_base
{
using type = non_native_vector_base<T, N>;
__host__ __device__ non_native_vector_base() = default;
__host__ __device__ non_native_vector_base(const type&) = default;
__host__ __device__ non_native_vector_base(type&&) = default;
__host__ __device__ ~non_native_vector_base() = default;
T d[N];
};
// non-native vector_type implementation
template <typename T>
struct vector_type<T, 1, typename std::enable_if_t<!is_native_type<T>()>>
struct vector_type<T, 64, typename ck::enable_if_t<is_native_type<T>()>>
{
using d1_t = T;
using type = d1_t;
union alignas(next_pow2(1 * sizeof(T)))
{
typedef T d2_t __attribute__((ext_vector_type(2)));
typedef T d4_t __attribute__((ext_vector_type(4)));
typedef T d8_t __attribute__((ext_vector_type(8)));
typedef T d16_t __attribute__((ext_vector_type(16)));
typedef T d32_t __attribute__((ext_vector_type(32)));
typedef T d64_t __attribute__((ext_vector_type(64)));
using type = d64_t;
union
{
d64_t d64_;
StaticallyIndexedArray<d1_t, 64> d1x64_;
StaticallyIndexedArray<d2_t, 32> d2x32_;
StaticallyIndexedArray<d4_t, 16> d4x16_;
StaticallyIndexedArray<d8_t, 8> d8x8_;
StaticallyIndexedArray<d16_t, 4> d16x4_;
StaticallyIndexedArray<d32_t, 2> d32x2_;
StaticallyIndexedArray<d64_t, 1> d64x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
is_same<X, d64_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x64_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x32_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x16_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x8_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x4_;
}
else if constexpr(is_same<X, d32_t>::value)
{
return data_.d32x2_;
}
else if constexpr(is_same<X, d64_t>::value)
{
return data_.d64x1_;
}
else
{
return err;
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
is_same<X, d64_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x64_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x32_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x16_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x8_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x4_;
}
else if constexpr(is_same<X, d32_t>::value)
{
return data_.d32x2_;
}
else if constexpr(is_same<X, d64_t>::value)
{
return data_.d64x1_;
}
else
{
return err;
}
}
};
template <typename T>
struct vector_type<T, 128, typename ck::enable_if_t<is_native_type<T>()>>
{
using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2)));
typedef T d4_t __attribute__((ext_vector_type(4)));
typedef T d8_t __attribute__((ext_vector_type(8)));
typedef T d16_t __attribute__((ext_vector_type(16)));
typedef T d32_t __attribute__((ext_vector_type(32)));
typedef T d64_t __attribute__((ext_vector_type(64)));
typedef T d128_t __attribute__((ext_vector_type(128)));
using type = d128_t;
union
{
d128_t d128_;
StaticallyIndexedArray<d1_t, 128> d1x128_;
StaticallyIndexedArray<d2_t, 64> d2x64_;
StaticallyIndexedArray<d4_t, 32> d4x32_;
StaticallyIndexedArray<d8_t, 16> d8x16_;
StaticallyIndexedArray<d16_t, 8> d16x8_;
StaticallyIndexedArray<d32_t, 4> d32x4_;
StaticallyIndexedArray<d64_t, 2> d64x2_;
StaticallyIndexedArray<d128_t, 1> d128x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
is_same<X, d64_t>::value || is_same<X, d128_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x128_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x64_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x32_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x16_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x8_;
}
else if constexpr(is_same<X, d32_t>::value)
{
return data_.d32x4_;
}
else if constexpr(is_same<X, d64_t>::value)
{
return data_.d64x2_;
}
else if constexpr(is_same<X, d128_t>::value)
{
return data_.d128x1_;
}
else
{
return err;
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
is_same<X, d64_t>::value || is_same<X, d128_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x128_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x64_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x32_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x16_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x8_;
}
else if constexpr(is_same<X, d32_t>::value)
{
return data_.d32x4_;
}
else if constexpr(is_same<X, d64_t>::value)
{
return data_.d64x2_;
}
else if constexpr(is_same<X, d128_t>::value)
{
return data_.d128x1_;
}
else
{
return err;
}
}
};
template <typename T>
struct vector_type<T, 256, typename ck::enable_if_t<is_native_type<T>()>>
{
using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2)));
typedef T d4_t __attribute__((ext_vector_type(4)));
typedef T d8_t __attribute__((ext_vector_type(8)));
typedef T d16_t __attribute__((ext_vector_type(16)));
typedef T d32_t __attribute__((ext_vector_type(32)));
typedef T d64_t __attribute__((ext_vector_type(64)));
typedef T d128_t __attribute__((ext_vector_type(128)));
typedef T d256_t __attribute__((ext_vector_type(256)));
using type = d256_t;
union
{
d256_t d256_;
StaticallyIndexedArray<d1_t, 256> d1x256_;
StaticallyIndexedArray<d2_t, 128> d2x128_;
StaticallyIndexedArray<d4_t, 64> d4x64_;
StaticallyIndexedArray<d8_t, 32> d8x32_;
StaticallyIndexedArray<d16_t, 16> d16x16_;
StaticallyIndexedArray<d32_t, 8> d32x8_;
StaticallyIndexedArray<d64_t, 4> d64x4_;
StaticallyIndexedArray<d128_t, 2> d128x2_;
StaticallyIndexedArray<d256_t, 1> d256x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(
is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value ||
is_same<X, d8_t>::value || is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
is_same<X, d64_t>::value || is_same<X, d128_t>::value || is_same<X, d256_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x256_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x128_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x64_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x32_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x16_;
}
else if constexpr(is_same<X, d32_t>::value)
{
return data_.d32x8_;
}
else if constexpr(is_same<X, d64_t>::value)
{
return data_.d64x4_;
}
else if constexpr(is_same<X, d128_t>::value)
{
return data_.d128x2_;
}
else if constexpr(is_same<X, d256_t>::value)
{
return data_.d256x1_;
}
else
{
return err;
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(
is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value ||
is_same<X, d8_t>::value || is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
is_same<X, d64_t>::value || is_same<X, d128_t>::value || is_same<X, d256_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x256_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x128_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x64_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x32_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x16_;
}
else if constexpr(is_same<X, d32_t>::value)
{
return data_.d32x8_;
}
else if constexpr(is_same<X, d64_t>::value)
{
return data_.d64x4_;
}
else if constexpr(is_same<X, d128_t>::value)
{
return data_.d128x2_;
}
else if constexpr(is_same<X, d256_t>::value)
{
return data_.d256x1_;
}
else
{
return err;
}
}
};
template <typename T, index_t N, typename Enable = void>
struct non_native_vector_base;
template <typename T>
struct nnvb_data_t_selector
{
using type = unsigned _BitInt(8 * sizeof(T));
};
template <>
struct nnvb_data_t_selector<f8_ocp_t>
{
using type = f8_ocp_t::data_type;
};
template <>
struct nnvb_data_t_selector<bf8_ocp_t>
{
using type = bf8_ocp_t::data_type;
};
template <>
struct nnvb_data_t_selector<f6x16_pk_t>
{
using type = f6x16_pk_t::type;
};
template <>
struct nnvb_data_t_selector<f6x32_pk_t>
{
using type = f6x32_pk_t::type;
};
template <>
struct nnvb_data_t_selector<bf6x16_pk_t>
{
using type = bf6x16_pk_t::type;
};
template <>
struct nnvb_data_t_selector<bf6x32_pk_t>
{
using type = bf6x32_pk_t::type;
};
template <>
struct nnvb_data_t_selector<pk_i4_t>
{
using type = pk_i4_t::type;
};
template <typename T, index_t N>
struct non_native_vector_base<
T,
N,
ck::enable_if_t<sizeof(T) == 1 || sizeof(T) == 2 || sizeof(T) == 4 || sizeof(T) == 8>>
{
using data_t = typename nnvb_data_t_selector<T>::type; // select data_t based on the size of T
static_assert(sizeof(T) == sizeof(data_t), "non_native_vector_base storage size mismatch");
using data_v = data_t __attribute__((ext_vector_type(N)));
using type = non_native_vector_base<T, N>;
union alignas(next_pow2(N * sizeof(T)))
{
data_v dN; // storage vector;
StaticallyIndexedArray<data_t, N> dxN;
StaticallyIndexedArray<T, N> dTxN;
StaticallyIndexedArray<data_v, 1> dNx1;
} data_;
__host__ __device__ constexpr non_native_vector_base(data_t a) : data_{data_v(a)} {}
__host__ __device__ constexpr non_native_vector_base(T f)
: non_native_vector_base(bit_cast<data_t>(f))
{
}
__host__ __device__ constexpr non_native_vector_base() : non_native_vector_base(T{}){};
__host__ __device__ constexpr non_native_vector_base(data_v v) : data_{v} {}
__host__ __device__ constexpr operator data_v() const { return data_.dN; }
__host__ __device__ constexpr operator data_t() const
{
if constexpr(N == 1)
{
return data_.dxN[Number<0>{}];
}
else
{
return data_.dxN; // XXX this should cause an error
}
}
__host__ __device__ constexpr operator T() const
{
if constexpr(N == 1)
{
return data_.dTxN[Number<0>{}];
}
else
{
return data_.dTxN; // XXX this should cause an error
}
}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same_v<X, data_t> || is_same_v<X, T> || is_same_v<X, data_v>,
"Something went wrong, please check src and dst types.");
if constexpr(is_same_v<X, data_t>)
{
return data_.dxN;
}
else if constexpr(is_same_v<X, T>)
{
return data_.dTxN;
}
else if constexpr(is_same_v<X, data_v>)
{
return data_.dNx1;
}
else
{
return err;
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same_v<X, data_t> || is_same_v<X, T> || is_same_v<X, data_v>,
"Something went wrong, please check src and dst types.");
if constexpr(is_same_v<X, data_t>)
{
return data_.dxN;
}
else if constexpr(is_same_v<X, T>)
{
return data_.dTxN;
}
else if constexpr(is_same_v<X, data_v>)
{
return data_.dNx1;
}
else
{
return err;
}
}
};
// implementation for f6x16 and f6x32
template <typename T, index_t N>
struct non_native_vector_base<T, N, std::enable_if_t<sizeof(T) == 12 || sizeof(T) == 24>>
{
using data_t =
typename nnvb_data_t_selector<T>::type; // select data_t based on declared base type
using element_t = typename T::element_type; // select element_t based on declared element type
static_assert(sizeof(T) == sizeof(data_t), "non_native_vector_base storage size mismatch");
static constexpr size_t size_factor =
sizeof(data_t) / sizeof(element_t); // f6x16: 12/4 = 3, f6x32: 24/4 = 6
using data_v = element_t __attribute__((ext_vector_type(N * size_factor)));
using type = non_native_vector_base<T, N>;
union alignas(next_pow2(N * sizeof(T)))
{
data_v dN; // storage vector;
StaticallyIndexedArray<data_t, N> dxN;
StaticallyIndexedArray<T, N> dTxN;
StaticallyIndexedArray<data_v, 1> dNx1;
} data_;
__host__ __device__ constexpr non_native_vector_base(data_t a)
: data_{data_v(a.At(Number<0>{}))}
{
}
__host__ __device__ constexpr non_native_vector_base(T f)
: non_native_vector_base(bit_cast<data_t>(f))
{
}
__host__ __device__ constexpr non_native_vector_base() : non_native_vector_base(T{}){};
__host__ __device__ constexpr non_native_vector_base(data_v v) : data_{v} {}
__host__ __device__ constexpr operator data_v() const { return data_.dN; }
__host__ __device__ constexpr operator data_t() const
{
if constexpr(N == 1)
{
return data_.dxN[Number<0>{}];
}
else
{
return data_.dxN; // XXX this should cause an error
}
}
__host__ __device__ constexpr operator T() const
{
if constexpr(N == 1)
{
return data_.dTxN[Number<0>{}];
}
else
{
return data_.dTxN; // XXX this should cause an error
}
}
};
template <typename T, index_t N>
struct scalar_type<non_native_vector_base<T, N>>;
template <index_t N>
struct scalar_type<non_native_vector_base<f8_ocp_t, N>>
{
using type = typename non_native_vector_base<f8_ocp_t, N>::data_t;
static constexpr index_t vector_size = N;
};
template <index_t N>
struct scalar_type<non_native_vector_base<bf8_ocp_t, N>>
{
using type = typename non_native_vector_base<bf8_ocp_t, N>::data_t;
static constexpr index_t vector_size = N;
};
template <index_t N>
struct scalar_type<non_native_vector_base<pk_i4_t, N>>
{
using type = typename non_native_vector_base<pk_i4_t, N>::data_t;
static constexpr index_t vector_size = N;
};
// non-native vector_type implementation
template <typename T>
struct vector_type<T, 1, typename ck::enable_if_t<!is_native_type<T>()>>
{
using d1_t = T;
using d1_nnv_t = non_native_vector_base<T, 1>;
using type = d1_nnv_t;
union alignas(next_pow2(1 * sizeof(T)))
{
d1_t d1_;
StaticallyIndexedArray<d1_t, 1> d1x1_;
d1_nnv_t d1_nnv_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{}} {}
__host__ __device__ constexpr vector_type() : data_{d1_t{}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value,
static_assert(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value,
"Something went wrong, please check src and dst types.");
return data_.d1x1_;
if constexpr(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value)
{
return data_.d1x1_;
}
else
{
return err;
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value,
static_assert(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value,
"Something went wrong, please check src and dst types.");
return data_.d1x1_;
if constexpr(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value)
{
return data_.d1x1_;
}
else
{
return err;
}
}
};
template <typename T>
struct vector_type<T, 2, typename std::enable_if_t<!is_native_type<T>()>>
struct vector_type<T, 2, typename ck::enable_if_t<!is_native_type<T>()>>
{
using d1_t = T;
using d2_t = non_native_vector_base<T, 2>;
using d1_t = T;
using d1_nnv_t = non_native_vector_base<T, 1>;
using d2_t = non_native_vector_base<T, 2>;
using type = d2_t;
......@@ -1081,10 +1958,11 @@ struct vector_type<T, 2, typename std::enable_if_t<!is_native_type<T>()>>
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value,
static_assert(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value ||
is_same<X, d2_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
if constexpr(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value)
{
return data_.d1x2_;
}
......@@ -1101,10 +1979,11 @@ struct vector_type<T, 2, typename std::enable_if_t<!is_native_type<T>()>>
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value,
static_assert(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value ||
is_same<X, d2_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
if constexpr(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value)
{
return data_.d1x2_;
}
......@@ -1120,11 +1999,12 @@ struct vector_type<T, 2, typename std::enable_if_t<!is_native_type<T>()>>
};
template <typename T>
struct vector_type<T, 4, typename std::enable_if_t<!is_native_type<T>()>>
struct vector_type<T, 4, typename ck::enable_if_t<!is_native_type<T>()>>
{
using d1_t = T;
using d2_t = non_native_vector_base<T, 2>;
using d4_t = non_native_vector_base<T, 4>;
using d1_t = T;
using d1_nnv_t = non_native_vector_base<T, 1>;
using d2_t = non_native_vector_base<T, 2>;
using d4_t = non_native_vector_base<T, 4>;
using type = d4_t;
......@@ -1143,10 +2023,11 @@ struct vector_type<T, 4, typename std::enable_if_t<!is_native_type<T>()>>
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value,
static_assert(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value ||
is_same<X, d2_t>::value || is_same<X, d4_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
if constexpr(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value)
{
return data_.d1x4_;
}
......@@ -1167,10 +2048,11 @@ struct vector_type<T, 4, typename std::enable_if_t<!is_native_type<T>()>>
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value,
static_assert(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value ||
is_same<X, d2_t>::value || is_same<X, d4_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
if constexpr(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value)
{
return data_.d1x4_;
}
......@@ -1190,12 +2072,13 @@ struct vector_type<T, 4, typename std::enable_if_t<!is_native_type<T>()>>
};
template <typename T>
struct vector_type<T, 8, typename std::enable_if_t<!is_native_type<T>()>>
struct vector_type<T, 8, typename ck::enable_if_t<!is_native_type<T>()>>
{
using d1_t = T;
using d2_t = non_native_vector_base<T, 2>;
using d4_t = non_native_vector_base<T, 4>;
using d8_t = non_native_vector_base<T, 8>;
using d1_t = T;
using d1_nnv_t = non_native_vector_base<T, 1>;
using d2_t = non_native_vector_base<T, 2>;
using d4_t = non_native_vector_base<T, 4>;
using d8_t = non_native_vector_base<T, 8>;
using type = d8_t;
......@@ -1215,11 +2098,12 @@ struct vector_type<T, 8, typename std::enable_if_t<!is_native_type<T>()>>
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value,
static_assert(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value ||
is_same<X, d2_t>::value || is_same<X, d4_t>::value ||
is_same<X, d8_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
if constexpr(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value)
{
return data_.d1x8_;
}
......@@ -1244,11 +2128,12 @@ struct vector_type<T, 8, typename std::enable_if_t<!is_native_type<T>()>>
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value,
static_assert(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value ||
is_same<X, d2_t>::value || is_same<X, d4_t>::value ||
is_same<X, d8_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
if constexpr(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value)
{
return data_.d1x8_;
}
......@@ -1272,13 +2157,14 @@ struct vector_type<T, 8, typename std::enable_if_t<!is_native_type<T>()>>
};
template <typename T>
struct vector_type<T, 16, typename std::enable_if_t<!is_native_type<T>()>>
struct vector_type<T, 16, typename ck::enable_if_t<!is_native_type<T>()>>
{
using d1_t = T;
using d2_t = non_native_vector_base<T, 2>;
using d4_t = non_native_vector_base<T, 4>;
using d8_t = non_native_vector_base<T, 8>;
using d16_t = non_native_vector_base<T, 16>;
using d1_t = T;
using d1_nnv_t = non_native_vector_base<T, 1>;
using d2_t = non_native_vector_base<T, 2>;
using d4_t = non_native_vector_base<T, 4>;
using d8_t = non_native_vector_base<T, 8>;
using d16_t = non_native_vector_base<T, 16>;
using type = d16_t;
......@@ -1299,12 +2185,12 @@ struct vector_type<T, 16, typename std::enable_if_t<!is_native_type<T>()>>
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value,
static_assert(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value ||
is_same<X, d2_t>::value || is_same<X, d4_t>::value ||
is_same<X, d8_t>::value || is_same<X, d16_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
if constexpr(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value)
{
return data_.d1x16_;
}
......@@ -1333,12 +2219,12 @@ struct vector_type<T, 16, typename std::enable_if_t<!is_native_type<T>()>>
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value,
static_assert(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value ||
is_same<X, d2_t>::value || is_same<X, d4_t>::value ||
is_same<X, d8_t>::value || is_same<X, d16_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
if constexpr(is_same<X, d1_t>::value || is_same<X, d1_nnv_t>::value)
{
return data_.d1x16_;
}
......@@ -1366,7 +2252,7 @@ struct vector_type<T, 16, typename std::enable_if_t<!is_native_type<T>()>>
};
template <typename T>
struct vector_type<T, 32, typename std::enable_if_t<!is_native_type<T>()>>
struct vector_type<T, 32, typename ck::enable_if_t<!is_native_type<T>()>>
{
using d1_t = T;
using d2_t = non_native_vector_base<T, 2>;
......@@ -1470,7 +2356,7 @@ struct vector_type<T, 32, typename std::enable_if_t<!is_native_type<T>()>>
};
template <typename T>
struct vector_type<T, 64, typename std::enable_if_t<!is_native_type<T>()>>
struct vector_type<T, 64, typename ck::enable_if_t<!is_native_type<T>()>>
{
using d1_t = T;
using d2_t = non_native_vector_base<T, 2>;
......@@ -1541,134 +2427,415 @@ struct vector_type<T, 64, typename std::enable_if_t<!is_native_type<T>()>>
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
is_same<X, d64_t>::value,
"Something went wrong, please check src and dst types.");
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
is_same<X, d64_t>::value,
"Something went wrong, please check src and dst types.");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x64_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x32_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x16_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x8_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x4_;
}
else if constexpr(is_same<X, d32_t>::value)
{
return data_.d32x2_;
}
else if constexpr(is_same<X, d64_t>::value)
{
return data_.d64x1_;
}
else
{
return err;
}
}
};
using int64_t = long;
// fp64
using double2_t = typename vector_type<double, 2>::type;
using double4_t = typename vector_type<double, 4>::type;
// fp32
using float2_t = typename vector_type<float, 2>::type;
using float4_t = typename vector_type<float, 4>::type;
using float8_t = typename vector_type<float, 8>::type;
using float16_t = typename vector_type<float, 16>::type;
using float32_t = typename vector_type<float, 32>::type;
using float64_t = typename vector_type<float, 64>::type;
// fp16
using half2_t = typename vector_type<half_t, 2>::type;
using half4_t = typename vector_type<half_t, 4>::type;
using half8_t = typename vector_type<half_t, 8>::type;
using half16_t = typename vector_type<half_t, 16>::type;
using half32_t = typename vector_type<half_t, 32>::type;
using half64_t = typename vector_type<half_t, 64>::type;
// bfp16
using bhalf2_t = typename vector_type<bhalf_t, 2>::type;
using bhalf4_t = typename vector_type<bhalf_t, 4>::type;
using bhalf8_t = typename vector_type<bhalf_t, 8>::type;
using bhalf16_t = typename vector_type<bhalf_t, 16>::type;
using bhalf32_t = typename vector_type<bhalf_t, 32>::type;
using bhalf64_t = typename vector_type<bhalf_t, 64>::type;
// i32
using int32x2_t = typename vector_type<int32_t, 2>::type;
using int32x4_t = typename vector_type<int32_t, 4>::type;
using int32x8_t = typename vector_type<int32_t, 8>::type;
using int32x16_t = typename vector_type<int32_t, 16>::type;
using int32x32_t = typename vector_type<int32_t, 32>::type;
using int32x64_t = typename vector_type<int32_t, 64>::type;
// i8
using int8x2_t = typename vector_type<int8_t, 2>::type;
using int8x4_t = typename vector_type<int8_t, 4>::type;
using int8x8_t = typename vector_type<int8_t, 8>::type;
using int8x16_t = typename vector_type<int8_t, 16>::type;
using int8x32_t = typename vector_type<int8_t, 32>::type;
using int8x64_t = typename vector_type<int8_t, 64>::type;
// f8
using f8x2_fnuz_t = typename vector_type<f8_fnuz_t, 2>::type;
using f8x4_fnuz_t = typename vector_type<f8_fnuz_t, 4>::type;
using f8x8_fnuz_t = typename vector_type<f8_fnuz_t, 8>::type;
using f8x16_fnuz_t = typename vector_type<f8_fnuz_t, 16>::type;
using f8x32_fnuz_t = typename vector_type<f8_fnuz_t, 32>::type;
using f8x64_fnuz_t = typename vector_type<f8_fnuz_t, 64>::type;
// bf8
using bf8x2_fnuz_t = typename vector_type<bf8_fnuz_t, 2>::type;
using bf8x4_fnuz_t = typename vector_type<bf8_fnuz_t, 4>::type;
using bf8x8_fnuz_t = typename vector_type<bf8_fnuz_t, 8>::type;
using bf8x16_fnuz_t = typename vector_type<bf8_fnuz_t, 16>::type;
using bf8x32_fnuz_t = typename vector_type<bf8_fnuz_t, 32>::type;
using bf8x64_fnuz_t = typename vector_type<bf8_fnuz_t, 64>::type;
// f8
using f8x2_ocp_t = typename vector_type<f8_ocp_t, 2>::type;
using f8x4_ocp_t = typename vector_type<f8_ocp_t, 4>::type;
using f8x8_ocp_t = typename vector_type<f8_ocp_t, 8>::type;
using f8x16_ocp_t = typename vector_type<f8_ocp_t, 16>::type;
using f8x32_ocp_t = typename vector_type<f8_ocp_t, 32>::type;
using f8x64_ocp_t = typename vector_type<f8_ocp_t, 64>::type;
// bf8
using bf8x2_ocp_t = typename vector_type<bf8_ocp_t, 2>::type;
using bf8x4_ocp_t = typename vector_type<bf8_ocp_t, 4>::type;
using bf8x8_ocp_t = typename vector_type<bf8_ocp_t, 8>::type;
using bf8x16_ocp_t = typename vector_type<bf8_ocp_t, 16>::type;
using bf8x32_ocp_t = typename vector_type<bf8_ocp_t, 32>::type;
using bf8x64_ocp_t = typename vector_type<bf8_ocp_t, 64>::type;
#if CK_FP8_TYPE_OCP
// f8
using f8x2_t = f8x2_ocp_t;
using f8x4_t = f8x4_ocp_t;
using f8x8_t = f8x8_ocp_t;
using f8x16_t = f8x16_ocp_t;
using f8x32_t = f8x32_ocp_t;
using f8x64_t = f8x64_ocp_t;
// bf8
using bf8x2_t = bf8x2_ocp_t;
using bf8x4_t = bf8x4_ocp_t;
using bf8x8_t = bf8x8_ocp_t;
using bf8x16_t = bf8x16_ocp_t;
using bf8x32_t = bf8x32_ocp_t;
using bf8x64_t = bf8x64_ocp_t;
#elif CK_FP8_TYPE_FNUZ
// f8
using f8x2_t = f8x2_fnuz_t;
using f8x4_t = f8x4_fnuz_t;
using f8x8_t = f8x8_fnuz_t;
using f8x16_t = f8x16_fnuz_t;
using f8x32_t = f8x32_fnuz_t;
using f8x64_t = f8x64_fnuz_t;
// bf8
using bf8x2_t = bf8x2_fnuz_t;
using bf8x4_t = bf8x4_fnuz_t;
using bf8x8_t = bf8x8_fnuz_t;
using bf8x16_t = bf8x16_fnuz_t;
using bf8x32_t = bf8x32_fnuz_t;
using bf8x64_t = bf8x64_fnuz_t;
#endif
// u8
using uint8x2_t = typename vector_type<uint8_t, 2>::type;
using uint8x4_t = typename vector_type<uint8_t, 4>::type;
using uint8x8_t = typename vector_type<uint8_t, 8>::type;
using uint8x16_t = typename vector_type<uint8_t, 16>::type;
using uint8x32_t = typename vector_type<uint8_t, 32>::type;
using uint8x64_t = typename vector_type<uint8_t, 64>::type;
// f4
using f4x2_t = typename vector_type<f4x2_pk_t, 1>::type;
using f4x4_t = typename vector_type<f4x2_pk_t, 2>::type;
using f4x8_t = typename vector_type<f4x2_pk_t, 4>::type;
using f4x16_t = typename vector_type<f4x2_pk_t, 8>::type;
using f4x32_t = typename vector_type<f4x2_pk_t, 16>::type;
using f4x64_t = typename vector_type<f4x2_pk_t, 32>::type;
// f6
using f6x16_t = typename vector_type<f6x16_pk_t, 1>::type;
using f6x32_t = typename vector_type<f6x32_pk_t, 1>::type;
// bf6
using bf6x16_t = typename vector_type<bf6x16_pk_t, 1>::type;
using bf6x32_t = typename vector_type<bf6x32_pk_t, 1>::type;
// pack int4
using pk_i4x2_t = typename vector_type<pk_i4_t, 2>::type;
using pk_i4x4_t = typename vector_type<pk_i4_t, 4>::type;
using pk_i4x8_t = typename vector_type<pk_i4_t, 8>::type;
#ifdef CK_CODE_GEN_RTC
template <typename T>
struct NumericLimits;
template <>
struct NumericLimits<int32_t>
{
__host__ __device__ static constexpr int32_t Lowest() noexcept { return -2147483647 - 1; }
__host__ __device__ static constexpr int32_t Min() noexcept { return -2147483647 - 1; }
__host__ __device__ static constexpr int32_t Max() noexcept { return 2147483647; }
__host__ __device__ static constexpr int32_t Infinity() noexcept { return 0; }
__host__ __device__ static constexpr int32_t QuietNaN() { return 0; }
};
template <>
struct NumericLimits<int16_t>
{
__host__ __device__ static constexpr int16_t Lowest() noexcept { return -32768; }
__host__ __device__ static constexpr int16_t Min() noexcept { return -32768; }
__host__ __device__ static constexpr int16_t Max() noexcept { return 32767; }
__host__ __device__ static constexpr int16_t Infinity() noexcept { return 0; }
__host__ __device__ static constexpr int16_t QuietNaN() { return 0; }
};
template <>
struct NumericLimits<int8_t>
{
__host__ __device__ static constexpr int8_t Lowest() noexcept { return -128; }
__host__ __device__ static constexpr int8_t Min() noexcept { return -128; }
__host__ __device__ static constexpr int8_t Max() noexcept { return 127; }
__host__ __device__ static constexpr int8_t Infinity() noexcept { return 0; }
__host__ __device__ static constexpr int8_t QuietNaN() { return 0; }
};
template <>
struct NumericLimits<uint32_t>
{
__host__ __device__ static constexpr uint32_t Lowest() noexcept { return 0; }
__host__ __device__ static constexpr uint32_t Min() noexcept { return 0; }
__host__ __device__ static constexpr uint32_t Max() noexcept { return 4294967295U; }
__host__ __device__ static constexpr uint32_t Infinity() noexcept { return 0; }
__host__ __device__ static constexpr uint32_t QuietNaN() { return 0; }
};
template <>
struct NumericLimits<uint16_t>
{
__host__ __device__ static constexpr uint16_t Lowest() noexcept { return 0; }
__host__ __device__ static constexpr uint16_t Min() noexcept { return 0; }
__host__ __device__ static constexpr uint16_t Max() noexcept { return 65535U; }
__host__ __device__ static constexpr uint16_t Infinity() noexcept { return 0; }
__host__ __device__ static constexpr uint16_t QuietNaN() { return 0; }
};
template <>
struct NumericLimits<float>
{
static constexpr unsigned int binary_min = 0x00800000;
static constexpr unsigned int binary_max = 0x7F7FFFFF;
static constexpr unsigned int binary_lowest = 0xFF7FFFFF;
static constexpr unsigned int binary_qnan = 0xFFC00001;
static constexpr unsigned int binary_inf = 0x7F8000000;
__host__ __device__ static constexpr float Min() { return bit_cast<float>(binary_min); }
__host__ __device__ static constexpr float Max() { return bit_cast<float>(binary_max); }
__host__ __device__ static constexpr float Lowest() { return bit_cast<float>(binary_lowest); }
__host__ __device__ static constexpr float QuietNaN() { return bit_cast<float>(binary_qnan); }
__host__ __device__ static constexpr float Infinity() { return bit_cast<float>(binary_inf); }
};
template <>
struct NumericLimits<half_t>
{
static constexpr unsigned short binary_min = 0x0400;
static constexpr unsigned short binary_max = 0x7BFF;
static constexpr unsigned short binary_lowest = 0xFBFF;
static constexpr unsigned short binary_qnan = 0x7FFF;
__host__ __device__ static constexpr half_t Min() { return bit_cast<half_t>(binary_min); }
__host__ __device__ static constexpr half_t Max() { return bit_cast<half_t>(binary_max); }
__host__ __device__ static constexpr half_t Lowest() { return bit_cast<half_t>(binary_lowest); }
__host__ __device__ static constexpr half_t QuietNaN() { return bit_cast<half_t>(binary_qnan); }
};
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template <>
struct NumericLimits<int4_t>
{
__host__ __device__ static constexpr int4_t Min() { return int4_t(-8); }
__host__ __device__ static constexpr int4_t Max() { return int4_t(7); }
__host__ __device__ static constexpr int4_t Lowest() { return int4_t(-8); }
};
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template <>
struct NumericLimits<f8_fnuz_t>
{
// negative zero nan mode with exp bias = 8
static constexpr uint8_t binary_min = 0x08; // 0b00001000
static constexpr uint8_t binary_max = 0x7F; // 0b01111111
static constexpr uint8_t binary_lowest = 0xFF; // 0b11111111
static constexpr uint8_t binary_qnan = 0x80; // 0b10000000
// ieee mode with exp bias = 7
// static constexpr uint8_t binary_min = 0x08; // 0b00001000
// static constexpr uint8_t binary_max = 0x77; // 0b01110111
// static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111
// static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!=0
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x64_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x32_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x16_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x8_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x4_;
}
else if constexpr(is_same<X, d32_t>::value)
{
return data_.d32x2_;
}
else if constexpr(is_same<X, d64_t>::value)
{
return data_.d64x1_;
}
else
{
return err;
}
}
__host__ __device__ static constexpr f8_fnuz_t Min() { return f8_fnuz_t(binary_min); }
__host__ __device__ static constexpr f8_fnuz_t Max() { return f8_fnuz_t(binary_max); }
__host__ __device__ static constexpr f8_fnuz_t Lowest() { return f8_fnuz_t(binary_lowest); }
__host__ __device__ static constexpr f8_fnuz_t QuietNaN() { return f8_fnuz_t(binary_qnan); }
};
using int64_t = long;
template <>
struct NumericLimits<bf8_fnuz_t>
{
// negative zero nan mode with exp bias = 16
static constexpr uint8_t binary_min = 0x04; // 0b00000100
static constexpr uint8_t binary_max = 0x7F; // 0b01111111
static constexpr uint8_t binary_lowest = 0xFF; // 0b11111111
static constexpr uint8_t binary_qnan = 0x80; // 0b10000000
// ieee mode with exp bias = 15
// static constexpr uint8_t binary_min = 0x04; // 0b00000100
// static constexpr uint8_t binary_max = 0x7B; // 0b01111011
// static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011
// static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!=
// fp64
using double2_t = typename vector_type<double, 2>::type;
using double4_t = typename vector_type<double, 4>::type;
__host__ __device__ static constexpr bf8_fnuz_t Min() { return bf8_fnuz_t(binary_min); }
// fp32
using float2_t = typename vector_type<float, 2>::type;
using float4_t = typename vector_type<float, 4>::type;
using float8_t = typename vector_type<float, 8>::type;
using float16_t = typename vector_type<float, 16>::type;
using float32_t = typename vector_type<float, 32>::type;
using float64_t = typename vector_type<float, 64>::type;
__host__ __device__ static constexpr bf8_fnuz_t Max() { return bf8_fnuz_t(binary_max); }
// fp16
using half2_t = typename vector_type<half_t, 2>::type;
using half4_t = typename vector_type<half_t, 4>::type;
using half8_t = typename vector_type<half_t, 8>::type;
using half16_t = typename vector_type<half_t, 16>::type;
using half32_t = typename vector_type<half_t, 32>::type;
using half64_t = typename vector_type<half_t, 64>::type;
__host__ __device__ static constexpr bf8_fnuz_t Lowest() { return bf8_fnuz_t(binary_lowest); }
// bfp16
using bhalf2_t = typename vector_type<bhalf_t, 2>::type;
using bhalf4_t = typename vector_type<bhalf_t, 4>::type;
using bhalf8_t = typename vector_type<bhalf_t, 8>::type;
using bhalf16_t = typename vector_type<bhalf_t, 16>::type;
using bhalf32_t = typename vector_type<bhalf_t, 32>::type;
using bhalf64_t = typename vector_type<bhalf_t, 64>::type;
__host__ __device__ static constexpr bf8_fnuz_t QuietNaN() { return bf8_fnuz_t(binary_qnan); }
};
// i32
using int32x2_t = typename vector_type<int32_t, 2>::type;
using int32x4_t = typename vector_type<int32_t, 4>::type;
using int32x8_t = typename vector_type<int32_t, 8>::type;
using int32x16_t = typename vector_type<int32_t, 16>::type;
using int32x32_t = typename vector_type<int32_t, 32>::type;
using int32x64_t = typename vector_type<int32_t, 64>::type;
template <>
struct NumericLimits<f8_ocp_t>
{
static constexpr uint8_t binary_min = 0x08; // 0b00001000 = 2^-6
static constexpr uint8_t binary_max = 0x7E; // 0b01111110 = 448
static constexpr uint8_t binary_lowest = 0xFE; // 0b11111110 = -448
static constexpr uint8_t binary_qnan = 0x7F; // 0b01111111
// i8
using int8x2_t = typename vector_type<int8_t, 2>::type;
using int8x4_t = typename vector_type<int8_t, 4>::type;
using int8x8_t = typename vector_type<int8_t, 8>::type;
using int8x16_t = typename vector_type<int8_t, 16>::type;
using int8x32_t = typename vector_type<int8_t, 32>::type;
using int8x64_t = typename vector_type<int8_t, 64>::type;
__host__ __device__ static constexpr f8_ocp_t Min() { return bit_cast<f8_ocp_t>(binary_min); }
// f8
using f8x2_t = typename vector_type<f8_t, 2>::type;
using f8x4_t = typename vector_type<f8_t, 4>::type;
using f8x8_t = typename vector_type<f8_t, 8>::type;
using f8x16_t = typename vector_type<f8_t, 16>::type;
using f8x32_t = typename vector_type<f8_t, 32>::type;
using f8x64_t = typename vector_type<f8_t, 64>::type;
__host__ __device__ static constexpr f8_ocp_t Max() { return bit_cast<f8_ocp_t>(binary_max); }
// bf8
using bf8x2_t = typename vector_type<bf8_t, 2>::type;
using bf8x4_t = typename vector_type<bf8_t, 4>::type;
using bf8x8_t = typename vector_type<bf8_t, 8>::type;
using bf8x16_t = typename vector_type<bf8_t, 16>::type;
using bf8x32_t = typename vector_type<bf8_t, 32>::type;
using bf8x64_t = typename vector_type<bf8_t, 64>::type;
__host__ __device__ static constexpr f8_ocp_t Lowest()
{
return bit_cast<f8_ocp_t>(binary_lowest);
}
// u8
using uint8x2_t = typename vector_type<uint8_t, 2>::type;
using uint8x4_t = typename vector_type<uint8_t, 4>::type;
using uint8x8_t = typename vector_type<uint8_t, 8>::type;
using uint8x16_t = typename vector_type<uint8_t, 16>::type;
using uint8x32_t = typename vector_type<uint8_t, 32>::type;
using uint8x64_t = typename vector_type<uint8_t, 64>::type;
__host__ __device__ static constexpr f8_ocp_t QuietNaN()
{
return bit_cast<f8_ocp_t>(binary_qnan);
}
};
template <>
struct NumericLimits<bf8_ocp_t>
{
static constexpr uint8_t binary_min = 0x04; // 0b00000100 = 2^-14
static constexpr uint8_t binary_max = 0x7B; // 0b01111011 = 57344
static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011 = -57344
static constexpr uint8_t binary_qnan = 0x7D; // 0b01111101
__host__ __device__ static constexpr bf8_ocp_t Min() { return bit_cast<bf8_ocp_t>(binary_min); }
__host__ __device__ static constexpr bf8_ocp_t Max() { return bit_cast<bf8_ocp_t>(binary_max); }
__host__ __device__ static constexpr bf8_ocp_t Lowest()
{
return bit_cast<bf8_ocp_t>(binary_lowest);
}
__host__ __device__ static constexpr bf8_ocp_t QuietNaN()
{
return bit_cast<bf8_ocp_t>(binary_qnan);
}
};
#else
template <typename T>
struct NumericLimits
{
__host__ __device__ static constexpr T Min() { return std::numeric_limits<T>::min(); }
__host__ __device__ static constexpr T Max() { return std::numeric_limits<T>::max(); }
__host__ __device__ static constexpr T Lowest() { return std::numeric_limits<T>::lowest(); }
__host__ __device__ static constexpr T QuietNaN()
{
return std::numeric_limits<T>::quiet_NaN();
}
__host__ __device__ static constexpr T Infinity() { return std::numeric_limits<T>::infinity(); }
};
......@@ -1702,7 +2869,7 @@ struct NumericLimits<int4_t>
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template <>
struct NumericLimits<f8_t>
struct NumericLimits<f8_fnuz_t>
{
// negative zero nan mode with exp bias = 8
static constexpr uint8_t binary_min = 0x08; // 0b00001000
......@@ -1715,17 +2882,17 @@ struct NumericLimits<f8_t>
// static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111
// static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!=0
__host__ __device__ static constexpr f8_t Min() { return f8_t(binary_min); }
__host__ __device__ static constexpr f8_fnuz_t Min() { return f8_fnuz_t(binary_min); }
__host__ __device__ static constexpr f8_t Max() { return f8_t(binary_max); }
__host__ __device__ static constexpr f8_fnuz_t Max() { return f8_fnuz_t(binary_max); }
__host__ __device__ static constexpr f8_t Lowest() { return f8_t(binary_lowest); }
__host__ __device__ static constexpr f8_fnuz_t Lowest() { return f8_fnuz_t(binary_lowest); }
__host__ __device__ static constexpr f8_t QuietNaN() { return f8_t(binary_qnan); }
__host__ __device__ static constexpr f8_fnuz_t QuietNaN() { return f8_fnuz_t(binary_qnan); }
};
template <>
struct NumericLimits<bf8_t>
struct NumericLimits<bf8_fnuz_t>
{
// negative zero nan mode with exp bias = 16
static constexpr uint8_t binary_min = 0x04; // 0b00000100
......@@ -1738,13 +2905,172 @@ struct NumericLimits<bf8_t>
// static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011
// static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!=
__host__ __device__ static constexpr bf8_t Min() { return bf8_t(binary_min); }
__host__ __device__ static constexpr bf8_fnuz_t Min() { return bf8_fnuz_t(binary_min); }
__host__ __device__ static constexpr bf8_fnuz_t Max() { return bf8_fnuz_t(binary_max); }
__host__ __device__ static constexpr bf8_fnuz_t Lowest() { return bf8_fnuz_t(binary_lowest); }
__host__ __device__ static constexpr bf8_fnuz_t QuietNaN() { return bf8_fnuz_t(binary_qnan); }
};
template <>
struct NumericLimits<f8_ocp_t>
{
static constexpr uint8_t binary_min = 0x08; // 0b00001000 = 2^-6
static constexpr uint8_t binary_max = 0x7E; // 0b01111110 = 448
static constexpr uint8_t binary_lowest = 0xFE; // 0b11111110 = -448
static constexpr uint8_t binary_qnan = 0x7F; // 0b01111111
__host__ __device__ static constexpr f8_ocp_t Min() { return bit_cast<f8_ocp_t>(binary_min); }
__host__ __device__ static constexpr f8_ocp_t Max() { return bit_cast<f8_ocp_t>(binary_max); }
__host__ __device__ static constexpr f8_ocp_t Lowest()
{
return bit_cast<f8_ocp_t>(binary_lowest);
}
__host__ __device__ static constexpr f8_ocp_t QuietNaN()
{
return bit_cast<f8_ocp_t>(binary_qnan);
}
};
template <>
struct NumericLimits<bf8_ocp_t>
{
static constexpr uint8_t binary_min = 0x04; // 0b00000100 = 2^-14
static constexpr uint8_t binary_max = 0x7B; // 0b01111011 = 57344
static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011 = -57344
static constexpr uint8_t binary_qnan = 0x7D; // 0b01111101
__host__ __device__ static constexpr bf8_ocp_t Min() { return bit_cast<bf8_ocp_t>(binary_min); }
__host__ __device__ static constexpr bf8_ocp_t Max() { return bit_cast<bf8_ocp_t>(binary_max); }
__host__ __device__ static constexpr bf8_t Max() { return bf8_t(binary_max); }
__host__ __device__ static constexpr bf8_ocp_t Lowest()
{
return bit_cast<bf8_ocp_t>(binary_lowest);
}
__host__ __device__ static constexpr bf8_ocp_t QuietNaN()
{
return bit_cast<bf8_ocp_t>(binary_qnan);
}
};
#endif
template <>
struct NumericLimits<f4_t>
{
static constexpr uint8_t binary_min_normal = 0x2; // 0b0010
static constexpr uint8_t binary_max_normal = 0x7; // 0b0111
static constexpr uint8_t binary_lowest_normal = 0xF; // 0b1111
static constexpr uint8_t binary_min_subnorm = 0x1; // 0b0001
static constexpr uint8_t binary_max_subnorm = 0x1; // 0b0001
static constexpr float data_max_normal_number = 6;
static constexpr float data_min_subnormal_number = 0.5;
__host__ __device__ static constexpr f4_t Min() { return f4_t(binary_min_normal); }
__host__ __device__ static constexpr f4_t Max() { return f4_t(binary_max_normal); }
__host__ __device__ static constexpr f4_t Lowest() { return f4_t(binary_lowest_normal); }
__host__ __device__ static constexpr f4_t MinSubnorm() { return f4_t(binary_min_subnorm); }
__host__ __device__ static constexpr f4_t MaxSubnorm() { return f4_t(binary_max_subnorm); }
__host__ __device__ static constexpr float DataMaxNorm() { return data_max_normal_number; }
__host__ __device__ static constexpr float DataMinSubnorm()
{
return data_min_subnormal_number;
}
};
template <>
struct NumericLimits<f6_t>
{
static constexpr uint8_t binary_min_normal = 0x08; // 0b001000
static constexpr uint8_t binary_max_normal = 0x1F; // 0b011111
static constexpr uint8_t binary_lowest_normal = 0x3F; // 0b111111
static constexpr uint8_t binary_min_subnorm = 0x01; // 0b000001
static constexpr uint8_t binary_max_subnorm = 0x07; // 0b000111
static constexpr float data_max_normal_number = 7.5;
static constexpr float data_min_subnormal_number = 0.125;
__host__ __device__ static constexpr f6_t Min() { return f6_t(binary_min_normal & 0b111111); }
__host__ __device__ static constexpr f6_t Max() { return f6_t(binary_max_normal & 0b111111); }
__host__ __device__ static constexpr f6_t Lowest()
{
return f6_t(binary_lowest_normal & 0b111111);
}
__host__ __device__ static constexpr f6_t MinSubnorm()
{
return f6_t(binary_min_subnorm & 0b111111);
}
__host__ __device__ static constexpr f6_t MaxSubnorm()
{
return f6_t(binary_max_subnorm & 0b111111);
}
__host__ __device__ static constexpr float DataMaxNorm() { return data_max_normal_number; }
__host__ __device__ static constexpr float DataMinSubnorm()
{
return data_min_subnormal_number;
}
};
__host__ __device__ static constexpr bf8_t Lowest() { return bf8_t(binary_lowest); }
template <>
struct NumericLimits<bf6_t>
{
static constexpr uint8_t binary_min_normal = 0x08; // 0b001000
static constexpr uint8_t binary_max_normal = 0x1F; // 0b011111
static constexpr uint8_t binary_lowest_normal = 0x3F; // 0b111111
static constexpr uint8_t binary_min_subnorm = 0x01; // 0b000001
static constexpr uint8_t binary_max_subnorm = 0x03; // 0b000011
static constexpr float data_max_normal_number = 28;
static constexpr float data_min_subnormal_number = 0.0625;
__host__ __device__ static constexpr bf6_t Min() { return bf6_t(binary_min_normal); }
__host__ __device__ static constexpr bf6_t Max() { return bf6_t(binary_max_normal); }
__host__ __device__ static constexpr bf6_t Lowest() { return bf6_t(binary_lowest_normal); }
__host__ __device__ static constexpr bf6_t MinSubnorm() { return bf6_t(binary_min_subnorm); }
__host__ __device__ static constexpr bf6_t MaxSubnorm() { return bf6_t(binary_max_subnorm); }
__host__ __device__ static constexpr float DataMaxNorm() { return data_max_normal_number; }
__host__ __device__ static constexpr float DataMinSubnorm()
{
return data_min_subnormal_number;
}
};
__host__ __device__ static constexpr bf8_t QuietNaN() { return bf8_t(binary_qnan); }
template <>
struct NumericLimits<e8m0_bexp_t>
{
static constexpr e8m0_bexp_t binary_min = 0x00; // 0b00000000
static constexpr e8m0_bexp_t binary_max = 0xFE; // 0b11111110
static constexpr e8m0_bexp_t binary_qnan = 0xFF; // 0b11111111
static constexpr e8m0_bexp_t binary_1 = 0x7F; // 0b01111111
static constexpr e8m0_bexp_t binary_2 = 0x80; // 0b10000000
static constexpr e8m0_bexp_t binary_3 = 0x82; // 0b10000010
static constexpr e8m0_bexp_t binary_135 = 0x87; // 0b10000111
static constexpr e8m0_bexp_t binary_142 = 0x8E; // 0b10001110
__host__ __device__ static constexpr e8m0_bexp_t Min() { return e8m0_bexp_t(binary_min); }
__host__ __device__ static constexpr e8m0_bexp_t Max() { return e8m0_bexp_t(binary_max); }
__host__ __device__ static constexpr e8m0_bexp_t QuietNaN() { return e8m0_bexp_t(binary_qnan); }
__host__ __device__ static constexpr e8m0_bexp_t Binary_1() { return e8m0_bexp_t(binary_1); }
__host__ __device__ static constexpr e8m0_bexp_t Binary_2() { return e8m0_bexp_t(binary_2); }
__host__ __device__ static constexpr e8m0_bexp_t Binary_3() { return e8m0_bexp_t(binary_3); }
__host__ __device__ static constexpr e8m0_bexp_t Binary_135()
{
return e8m0_bexp_t(binary_135);
}
__host__ __device__ static constexpr e8m0_bexp_t Binary_142()
{
return e8m0_bexp_t(binary_142);
}
};
template <typename T>
......@@ -1766,6 +3092,7 @@ struct NumericUtils<float>
static constexpr uint32_t NegInf = 0xFF800000;
static constexpr uint32_t NaN = 0x7F800001;
static constexpr uint32_t Neg0 = 0x80000000;
static constexpr bool has_inf = true;
using bitwise_type = uint32_t;
};
......@@ -1783,33 +3110,158 @@ struct NumericUtils<half_t>
static constexpr uint32_t NegInf = 0xFC00;
static constexpr uint32_t NaN = 0x7C01;
static constexpr uint32_t Neg0 = 0x8000;
static constexpr bool has_inf = true;
using bitwise_type = uint16_t;
};
template <>
struct NumericUtils<f8_t>
struct NumericUtils<bhalf_t>
{
static constexpr int exp = 8;
static constexpr int mant = 7;
static constexpr int bias = 128; // negative zero nan mode
// static constexpr int bias = 127; // ieee mode
};
template <>
struct NumericUtils<f8_fnuz_t>
{
static constexpr int exp = 4;
static constexpr int mant = 3;
static constexpr int bias = 8; // negative zero nan mode
// static constexpr int bias = 7; // ieee mode
static constexpr bool has_inf = false;
};
template <>
struct NumericUtils<bf8_t>
struct NumericUtils<bf8_fnuz_t>
{
static constexpr int exp = 5;
static constexpr int mant = 2;
static constexpr int bias = 16; // negative zero nan mode
// static constexpr int bias = 15; // ieee mode
static constexpr bool has_inf = false;
};
template <>
struct NumericUtils<f8_ocp_t>
{
static constexpr int exp = 4;
static constexpr int mant = 3;
static constexpr int bias = 7;
};
template <>
struct NumericUtils<bhalf_t>
struct NumericUtils<bf8_ocp_t>
{
static constexpr int exp = 5;
static constexpr int mant = 2;
static constexpr int bias = 15;
};
template <>
struct NumericUtils<f4_t>
{
static constexpr int exp = 2;
static constexpr int mant = 1;
static constexpr int bias = 1;
static constexpr uint32_t sr_shift = 10;
static constexpr int unbiased_exp_min = 0;
static constexpr int unbiased_exp_max = 2;
static constexpr int biased_exp_min = 1;
static constexpr int biased_exp_max = 3;
static constexpr uint8_t positive_zero_mask = 0b0000;
static constexpr uint8_t negative_zero_mask = 0b1000;
static constexpr uint8_t one_mask = 0b0010;
static constexpr uint8_t set_sign_mask = 0b0111;
static constexpr uint8_t data_max_positive_normal_mask = 0b0111;
static constexpr uint8_t data_max_negative_normal_mask = 0b1111;
static constexpr uint8_t data_max_positive_subnormal_mask = 0b0001;
static constexpr uint8_t data_max_negative_subnormal_mask = 0b1001;
static constexpr bool has_inf = false;
using bitwise_type = uint8_t;
};
template <>
struct NumericUtils<f6_t>
{
static constexpr int exp = 2;
static constexpr int mant = 3;
static constexpr int bias = 1;
static constexpr uint32_t sr_shift = 12;
static constexpr int unbiased_exp_min = 0;
static constexpr int unbiased_exp_max = 2;
static constexpr int biased_exp_min = 1;
static constexpr int biased_exp_max = 3;
static constexpr uint8_t positive_zero_mask = 0b000000;
static constexpr uint8_t negative_zero_mask = 0b100000;
static constexpr uint8_t set_sign_mask = 0b011111;
static constexpr uint8_t data_max_positive_normal_mask = 0b011111;
static constexpr uint8_t data_max_negative_normal_mask = 0b111111;
static constexpr uint8_t data_max_positive_subnormal_mask = 0b000111;
static constexpr uint8_t data_max_negative_subnormal_mask = 0b100111;
static constexpr bool has_inf = false;
static constexpr bool has_nan = false;
static constexpr bool has_zero = true;
using bitwise_type = uint8_t;
};
template <>
struct NumericUtils<bf6_t>
{
static constexpr int exp = 3;
static constexpr int mant = 2;
static constexpr int bias = 3;
static constexpr uint32_t sr_shift = 11;
static constexpr int unbiased_exp_min = -2;
static constexpr int unbiased_exp_max = 4;
static constexpr int biased_exp_min = 1;
static constexpr int biased_exp_max = 7;
static constexpr uint8_t positive_zero_mask = 0b000000;
static constexpr uint8_t negative_zero_mask = 0b100000;
static constexpr uint8_t set_sign_mask = 0b011111;
static constexpr uint8_t data_max_positive_normal_mask = 0b011111;
static constexpr uint8_t data_max_negative_normal_mask = 0b111111;
static constexpr uint8_t data_max_positive_subnormal_mask = 0b000011;
static constexpr uint8_t data_max_negative_subnormal_mask = 0b100011;
static constexpr bool has_inf = false;
static constexpr bool has_nan = false;
static constexpr bool has_zero = true;
using bitwise_type = uint8_t;
};
template <>
struct NumericUtils<e8m0_bexp_t>
{
static constexpr int exp = 8;
static constexpr int mant = 7;
static constexpr int bias = 128; // negative zero nan mode
// static constexpr int bias = 127; // ieee mode
static constexpr int mant = 0;
static constexpr int bias = 127;
static constexpr int unbiased_exp_min = -127;
static constexpr int unbiased_exp_max = 127;
static constexpr int biased_exp_min = 0;
static constexpr int biased_exp_max = 254;
using bitwise_type = uint8_t;
};
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#ifndef UTILITY_DEBUG_HPP
#define UTILITY_DEBUG_HPP
#include "type.hpp"
namespace ck {
namespace debug {
......
......@@ -29,6 +29,13 @@ struct DynamicBuffer
ElementSpaceSize element_space_size_;
T invalid_element_value_ = T{0};
static constexpr index_t PackedSize = []() {
if constexpr(is_same_v<remove_cvref_t<T>, pk_i4_t>)
return 2;
else
return 1;
}();
__host__ __device__ constexpr DynamicBuffer(T* p_data, ElementSpaceSize element_space_size)
: p_data_{p_data}, element_space_size_{element_space_size}
{
......@@ -54,7 +61,8 @@ struct DynamicBuffer
template <typename X,
typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
typename scalar_type<remove_cvref_t<T>>::type>::value,
typename scalar_type<remove_cvref_t<T>>::type>::value ||
!is_native_type<X>(),
bool>::type = false>
__host__ __device__ constexpr auto Get(index_t i, bool is_valid_element) const
{
......@@ -81,14 +89,18 @@ struct DynamicBuffer
return amd_buffer_load_invalid_element_return_zero<remove_cvref_t<T>,
t_per_x,
coherence>(
p_data_, i, is_valid_element, element_space_size_);
p_data_, i, is_valid_element, element_space_size_ / PackedSize);
}
else
{
return amd_buffer_load_invalid_element_return_customized_value<remove_cvref_t<T>,
t_per_x,
coherence>(
p_data_, i, is_valid_element, element_space_size_, invalid_element_value_);
p_data_,
i,
is_valid_element,
element_space_size_ / PackedSize,
invalid_element_value_);
}
}
else
......@@ -190,12 +202,13 @@ struct DynamicBuffer
dst_buf.p_data_,
dst_offset,
is_valid_element,
element_space_size_);
element_space_size_ / PackedSize);
}
template <typename X,
typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
typename scalar_type<remove_cvref_t<T>>::type>::value,
typename scalar_type<remove_cvref_t<T>>::type>::value ||
!is_native_type<X>(),
bool>::type = false>
__host__ __device__ void Set(index_t i, bool is_valid_element, const X& x)
{
......@@ -224,7 +237,7 @@ struct DynamicBuffer
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_buffer_store<remove_cvref_t<T>, t_per_x, coherence>(
x, p_data_, i, is_valid_element, element_space_size_);
x, p_data_, i, is_valid_element, element_space_size_ / PackedSize);
}
else if constexpr(GetAddressSpace() == AddressSpaceEnum::Lds &&
is_same<typename scalar_type<remove_cvref_t<T>>::type, int8_t>::value &&
......@@ -376,7 +389,7 @@ struct DynamicBuffer
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_buffer_atomic_add<remove_cvref_t<T>, t_per_x>(
x, p_data_, i, is_valid_element, element_space_size_);
x, p_data_, i, is_valid_element, element_space_size_ / PackedSize);
}
else
{
......@@ -415,7 +428,7 @@ struct DynamicBuffer
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_buffer_atomic_max<remove_cvref_t<T>, t_per_x>(
x, p_data_, i, is_valid_element, element_space_size_);
x, p_data_, i, is_valid_element, element_space_size_ / PackedSize);
}
else if(is_valid_element)
{
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/type.hpp"
namespace ck {
/**
* @brief Unsigned representation of a conventional biased Float32 exponent.
*
* bias = 127;
*
* E8M0_1 = 0b01111111; => 2^(127-127) = 1
* E8M0_2 = 0b10000000; => 2^(128-127) = 2^1 = 2
* E8M0_3 = 0b10000010; => 2^(130-127) = 2^3 = 8
* E8M0_135 = 0b10000111; => 2^(135-127) = 2^8 = 256
* E8M0_142 = 0b10001110; => 2^(142-127) = 2^15 = 32768
* E8M0_MIN = 0b00000000; => 2^-127
* E8M0_MAX = 0b11111110; => 2^127
* E8M0_NAN = 0b11111111; => NaN
*/
struct e8m0_bexp_t
{
using type = uint8_t;
type data;
constexpr static type bias = 127;
constexpr static type nan_mask = 0xFF;
__host__ __device__ constexpr e8m0_bexp_t() : data{type{}} {}
__host__ __device__ constexpr e8m0_bexp_t(type init) : data{init} {}
__host__ __device__ constexpr e8m0_bexp_t(int init) : data{static_cast<type>(init & nan_mask)}
{
}
__host__ __device__ explicit constexpr e8m0_bexp_t(float scale)
: data{static_cast<type>((bit_cast<uint32_t>(scale) & (nan_mask << 23)) >> 23)}
{
}
__host__ __device__ explicit constexpr operator float() const
{
if(data == nan_mask || data == 0)
{
uint32_t bits = data << 1;
bits |= 1;
bits <<= 22;
return bit_cast<float>(bits);
}
else
{
uint32_t bits = data << 23;
return bit_cast<float>(bits);
}
}
__host__ __device__ constexpr bool operator==(const e8m0_bexp_t& other) const
{
// strict IEEE compliance for NaN
return data == other.data && data != nan_mask;
}
__host__ __device__ constexpr bool is_nan() const { return data == nan_mask; }
};
namespace utils {
template <typename T>
__host__ __device__ inline int get_exponent_value(T x);
template <>
__host__ __device__ inline int get_exponent_value<e8m0_bexp_t>(e8m0_bexp_t x)
{
return x.data;
}
} // namespace utils
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck {
#ifndef CK_CODE_GEN_RTC
template <bool B, typename T = void>
using enable_if = std::enable_if<B, T>;
template <bool B, typename T = void>
using enable_if_t = typename std::enable_if<B, T>::type;
#else
template <bool B, class T = void>
struct enable_if
{
};
template <class T>
struct enable_if<true, T>
{
using type = T;
};
template <bool B, class T = void>
using enable_if_t = typename enable_if<B, T>::type;
#endif
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_CODE_GEN_RTC
#pragma once
#include <cstdlib>
......@@ -183,3 +184,4 @@ void UpdateEnvVar(EnvVar, const std::string_view& val)
}
} // namespace ck
#endif
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -120,11 +120,11 @@ constexpr auto conditional_expr(X&& x, Y&& y)
{
if constexpr(predicate)
{
return std::forward<X>(x);
return ck::forward<X>(x);
}
else
{
return std::forward<Y>(y);
return ck::forward<Y>(y);
}
}
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_FUNCTIONAL4_HPP
#define CK_FUNCTIONAL4_HPP
......@@ -21,7 +21,7 @@ struct unpack_impl<Sequence<Is...>>
template <typename F, typename X>
__host__ __device__ constexpr auto operator()(F&& f, X&& x) const
{
return std::forward<F>(f)(std::forward<X>(x).At(Number<Is>{})...);
return ck::forward<F>(f)(ck::forward<X>(x).At(Number<Is>{})...);
}
};
......@@ -35,8 +35,8 @@ struct unpack2_impl<Sequence<Is...>, Sequence<Js...>>
template <typename F, typename X, typename Y>
__host__ __device__ constexpr auto operator()(F&& f, X&& x, Y&& y) const
{
return std::forward<F>(f)(std::forward<X>(x).At(Number<Is>{})...,
std::forward<Y>(y).At(Number<Js>{})...);
return ck::forward<F>(f)(ck::forward<X>(x).At(Number<Is>{})...,
ck::forward<Y>(y).At(Number<Js>{})...);
}
};
......@@ -47,7 +47,7 @@ __host__ __device__ constexpr auto unpack(F&& f, X&& x)
{
using X_ = remove_reference_t<X>;
return detail::unpack_impl<typename arithmetic_sequence_gen<0, X_::Size(), 1>::type>{}(
std::forward<F>(f), std::forward<X>(x));
ck::forward<F>(f), ck::forward<X>(x));
}
// TODO: properly implement unpack that takes any number of containers
......@@ -58,7 +58,7 @@ __host__ __device__ constexpr auto unpack2(F&& f, X&& x, Y&& y)
using Y_ = remove_reference_t<Y>;
return detail::unpack2_impl<typename arithmetic_sequence_gen<0, X_::Size(), 1>::type,
typename arithmetic_sequence_gen<0, Y_::Size(), 1>::type>{}(
std::forward<F>(f), std::forward<X>(x), std::forward<Y>(y));
ck::forward<F>(f), ck::forward<X>(x), ck::forward<Y>(y));
}
} // namespace ck
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -48,4 +48,9 @@ __host__ __device__ constexpr auto operator%(integral_constant<TX, X>, integral_
return integral_constant<decltype(X % Y), X % Y>{};
}
template <bool B>
using bool_constant = integral_constant<bool, B>;
using true_type = bool_constant<true>;
using false_type = bool_constant<false>;
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/integral_constant.hpp"
namespace ck {
namespace detail {
template <class Default, class AlwaysVoid, template <class...> class Op, class... Args>
struct detector
{
using value_t = std::false_type;
using value_t = integral_constant<bool, false>;
using type = Default;
};
template <class Default, template <class...> class Op, class... Args>
struct detector<Default, std::void_t<Op<Args...>>, Op, Args...>
struct detector<Default, ck::void_t<Op<Args...>>, Op, Args...>
{
using value_t = std::true_type;
using value_t = integral_constant<bool, true>;
using type = Op<Args...>;
};
} // namespace detail
......@@ -32,12 +34,12 @@ template <template <class...> class Op, class... Args>
using is_detected = typename detail::detector<nonesuch, void, Op, Args...>::value_t;
template <typename T>
using is_pack2_invocable_t = decltype(std::declval<T&>().is_pack2_invocable);
using is_pack2_invocable_t = decltype(ck::declval<T&>().is_pack2_invocable);
template <typename T>
using is_pack4_invocable_t = decltype(std::declval<T&>().is_pack4_invocable);
using is_pack4_invocable_t = decltype(ck::declval<T&>().is_pack4_invocable);
template <typename T>
using is_pack8_invocable_t = decltype(std::declval<T&>().is_pack8_invocable);
using is_pack8_invocable_t = decltype(ck::declval<T&>().is_pack8_invocable);
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_CODE_GEN_RTC
#include <ostream>
#endif
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
namespace ck {
......@@ -26,6 +28,7 @@ constexpr LoopScheduler make_default_loop_scheduler()
} // namespace ck
#ifndef CK_CODE_GEN_RTC
inline std::ostream& operator<<(std::ostream& os, const ck::LoopScheduler& s)
{
switch(s)
......@@ -36,3 +39,4 @@ inline std::ostream& operator<<(std::ostream& os, const ck::LoopScheduler& s)
}
return os;
}
#endif
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -9,6 +9,10 @@
#include "type.hpp"
#include "tuple.hpp"
#ifdef CK_CODE_GEN_RTC
#define INT32_MAX 2147483647
#endif
namespace ck {
// magic number division
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment