Commit cd4d4629 authored by danyao12

Merge branch 'develop' into ck_tile/fa_bwd_v3

parents 21d12bb7 888317e6
@@ -54,7 +54,8 @@ struct DynamicBuffer
     template <typename X,
               typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
-                                         typename scalar_type<remove_cvref_t<T>>::type>::value,
+                                         typename scalar_type<remove_cvref_t<T>>::type>::value ||
+                                     !is_native_type<X>(),
                                  bool>::type = false>
     __host__ __device__ constexpr auto Get(index_t i, bool is_valid_element) const
     {
@@ -195,7 +196,8 @@ struct DynamicBuffer
     template <typename X,
               typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
-                                         typename scalar_type<remove_cvref_t<T>>::type>::value,
+                                         typename scalar_type<remove_cvref_t<T>>::type>::value ||
+                                     !is_native_type<X>(),
                                  bool>::type = false>
     __host__ __device__ void Set(index_t i, bool is_valid_element, const X& x)
     {
...
@@ -5,7 +5,6 @@
 #pragma once
 #include "ck/utility/common_header.hpp"
-#include "ck/tensor_description/tensor_adaptor.hpp"
 namespace ck {
...
@@ -80,7 +80,7 @@ static inline __host__ bool isnan(half_t x)
     return (xx & 0x7FFF) > 0x7C00;
 };
-static inline __host__ bool isnan(f8_t x) { return (x & 0x80); };
+static inline __host__ bool isnan(f8_t x) { return ck::fp8_is_nan(x); };
 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
 static inline __host__ bool isnan(int4_t x)
@@ -531,7 +531,7 @@ static inline __device__ bool isnan(half_t x)
     return (xx & 0x7FFF) > 0x7C00;
 };
-static inline __device__ bool isnan(f8_t x) { return (x & 0x80); };
+static inline __device__ bool isnan(f8_t x) { return ck::fp8_is_nan(x); };
 static inline __device__ half_t sqrt(half_t x)
 {
@@ -611,7 +611,7 @@ inline __device__ int8_t neg<int8_t>(int8_t x)
 template <>
 inline __device__ half_t neg<half_t>(half_t x)
 {
-    return __hneg(x);
+    return __hneg(static_cast<__half>(x));
 };
 template <typename T>
@@ -653,7 +653,7 @@ inline __device__ double sin<double>(double x)
 template <>
 inline __device__ half_t sin<half_t>(half_t x)
 {
-    return ::hsin(x);
+    return hsin(static_cast<__half>(x));
 };
 template <typename T>
@@ -785,7 +785,7 @@ inline __device__ double ceil<double>(double x)
 template <>
 inline __device__ half_t ceil<half_t>(half_t x)
 {
-    return ::hceil(x);
+    return hceil(static_cast<__half>(x));
 };
 template <typename T>
@@ -827,7 +827,7 @@ inline __device__ double floor<double>(double x)
 template <>
 inline __device__ half_t floor<half_t>(half_t x)
 {
-    return ::hfloor(x);
+    return hfloor(static_cast<__half>(x));
 };
 template <typename T>
@@ -849,7 +849,7 @@ inline __device__ T exp(T x)
 template <>
 inline __device__ half_t exp<half_t>(half_t x)
 {
-    return hexp(x);
+    return hexp(static_cast<__half>(x));
 };
 template <>
@@ -873,7 +873,7 @@ inline __device__ T log(T x)
 template <>
 inline __device__ half_t log<half_t>(half_t x)
 {
-    return hlog(x);
+    return hlog(static_cast<__half>(x));
 };
 template <>
...
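A note on the casts introduced above: half_t is CK's own fp16 storage type, so the HIP __half intrinsics need an explicit cast once half_t is backed by _Float16. A minimal sketch of the pattern (the wrapper name my_half_exp is illustrative, not part of CK):

__device__ ck::half_t my_half_exp(ck::half_t x)
{
    // static_cast selects the __half overload of the intrinsic explicitly,
    // mirroring the change above and avoiding overload ambiguity
    return hexp(static_cast<__half>(x));
}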
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
+#include "ck/ck.hpp"
 namespace ck {
 // Pseudo random number generator
@@ -23,7 +25,7 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed =
 }
 // version for fp16
-template <typename T, uint32_t seed_t, std::enable_if_t<std::is_same<half_t, T>{}, bool> = false>
+template <typename T, uint32_t seed_t, std::enable_if_t<std::is_same<_Float16, T>{}, bool> = false>
 __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t)
 {
     uint16_t x = *(reinterpret_cast<uint16_t*>(&val));
@@ -38,9 +40,10 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed =
 }
 // return 0 if data is not fp16 or fp32
-template <typename T,
-          uint32_t seed_t,
-          std::enable_if_t<!(std::is_same<float, T>{} || std::is_same<half_t, T>{}), bool> = false>
+template <
+    typename T,
+    uint32_t seed_t,
+    std::enable_if_t<!(std::is_same<float, T>{} || std::is_same<_Float16, T>{}), bool> = false>
 __host__ __device__ uint32_t prand_generator(int id, T val, uint32_t seed = seed_t)
 {
     std::ignore = id;
...
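These prand_generator overloads feed the stochastic-rounding converters later in this diff. A minimal sketch of a call site, mirroring the f8_convert_sr functions below (sample_rng is an illustrative name):

__host__ __device__ uint32_t sample_rng(float x)
{
    constexpr int seed = 1254739; // same seed the SR converters in this diff use
    // the value's address serves as a per-element id, its bit pattern as the payload
    return ck::prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
}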
@@ -116,7 +116,8 @@ struct StaticBufferTupleOfVector
     // i is offset of S, not X. i should be aligned to X
     template <typename X,
               index_t I,
-              typename enable_if<has_same_scalar_type<S, X>::value, bool>::type = false>
+              typename enable_if<has_same_scalar_type<S, X>::value || !is_native_type<S>(),
+                                 bool>::type = false>
     __host__ __device__ constexpr auto GetAsType(Number<I> i) const
     {
         constexpr auto s_per_x = Number<scalar_type<remove_cvref_t<X>>::vector_size>{};
@@ -134,7 +135,8 @@ struct StaticBufferTupleOfVector
     // i is offset of S, not X. i should be aligned to X
     template <typename X,
               index_t I,
-              typename enable_if<has_same_scalar_type<S, X>::value, bool>::type = false>
+              typename enable_if<has_same_scalar_type<S, X>::value || !is_native_type<S>(),
+                                 bool>::type = false>
     __host__ __device__ constexpr void SetAsType(Number<I> i, X x)
     {
         constexpr auto s_per_x = Number<scalar_type<remove_cvref_t<X>>::vector_size>{};
...
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
@@ -9,7 +9,7 @@
 #include "ck/utility/array.hpp"
 namespace ck {
-// Define the common macro for gfx94x models
+// Define the common macro for MI300 models
 #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
 #define __gfx94__
 #endif
@@ -100,6 +100,18 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_
     return type_convert<bhalf_t>(x_fp32);
 }
+template <>
+inline __host__ __device__ constexpr f8_ocp_t type_convert<f8_ocp_t, int>(int x)
+{
+    return f8_ocp_t{type_convert<f8_ocp_t::data_type>(x)};
+}
+template <>
+inline __host__ __device__ constexpr bf8_ocp_t type_convert<bf8_ocp_t, int>(int x)
+{
+    return bf8_ocp_t{type_convert<bf8_ocp_t::data_type>(x)};
+}
 // Convert X to Y
 template <typename Y, typename X>
 __host__ __device__ constexpr Y type_convert_sp(X x)
@@ -163,7 +175,7 @@ __host__ __device__ constexpr Y f8_convert_sr(X x);
 // convert fp32 to fp8 with stochastic rounding
 template <>
-inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
+inline __host__ __device__ f8_fnuz_t f8_convert_sr<f8_fnuz_t, float>(float x)
 {
     constexpr int seed = 1254739;
     uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
@@ -189,33 +201,35 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
     constexpr bool clip = true;
     constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
-    return utils::
-        cast_to_f8<float, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x,
-                                                                                               rng);
+    return utils::
+        cast_to_f8<float, f8_fnuz_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
+            x, rng);
 #endif
 }
 // convert fp16 to fp8 with stochastic rounding
 template <>
-inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
+inline __host__ __device__ f8_fnuz_t f8_convert_sr<f8_fnuz_t, half_t>(half_t x)
 {
 #if defined(__gfx94__)
     // convert to float and use native conversion
-    return f8_convert_sr<f8_t>(type_convert<float>(x));
+    return f8_convert_sr<f8_fnuz_t>(type_convert<float>(x));
 #else
     constexpr bool negative_zero_nan = true;
     constexpr bool clip = true;
     constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
     constexpr int seed = 1254739;
     uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
-    return utils::
-        cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
-            x, rng);
+    return utils::cast_to_f8<half_t,
+                             f8_fnuz_t,
+                             negative_zero_nan,
+                             clip,
+                             (rm == f8_rounding_mode::stochastic)>(x, rng);
 #endif
 }
 // convert fp32 to bf8 with stochastic rounding
 template <>
-inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x)
+inline __host__ __device__ bf8_fnuz_t f8_convert_sr<bf8_fnuz_t, float>(float x)
 {
     constexpr int seed = 1254739;
     uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
@@ -240,28 +254,32 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x)
     constexpr bool negative_zero_nan = true;
     constexpr bool clip = true;
     constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
-    return utils::
-        cast_to_f8<float, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
-            x, rng);
+    return utils::cast_to_f8<float,
+                             bf8_fnuz_t,
+                             negative_zero_nan,
+                             clip,
+                             (rm == f8_rounding_mode::stochastic)>(x, rng);
 #endif
 }
 // convert fp16 to bf8 with stochastic rounding
 template <>
-inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
+inline __host__ __device__ bf8_fnuz_t f8_convert_sr<bf8_fnuz_t, half_t>(half_t x)
 {
 #if defined(__gfx94__)
     // convert to float and use native conversion
-    return f8_convert_sr<bf8_t>(type_convert<float>(x));
+    return f8_convert_sr<bf8_fnuz_t>(type_convert<float>(x));
 #else
     constexpr bool negative_zero_nan = true;
     constexpr bool clip = true;
     constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
     constexpr int seed = 1254739;
     uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
-    return utils::
-        cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
-            x, rng);
+    return utils::cast_to_f8<half_t,
+                             bf8_fnuz_t,
+                             negative_zero_nan,
+                             clip,
+                             (rm == f8_rounding_mode::stochastic)>(x, rng);
 #endif
 }
@@ -271,7 +289,7 @@ __host__ __device__ constexpr Y f8_convert_rne(X x);
 // convert fp32 to fp8 with rounding to nearest even
 template <>
-inline __host__ __device__ f8_t f8_convert_rne<f8_t, float>(float x)
+inline __host__ __device__ f8_fnuz_t f8_convert_rne<f8_fnuz_t, float>(float x)
 {
 #if defined(__gfx94__)
     union
@@ -296,32 +314,34 @@ inline __host__ __device__ f8_t f8_convert_rne<f8_t, float>(float x)
     constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
     constexpr uint32_t rng = 0;
-    return utils::
-        cast_to_f8<float, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x,
-                                                                                               rng);
+    return utils::
+        cast_to_f8<float, f8_fnuz_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
+            x, rng);
 #endif
 }
 // convert fp16 to fp8 with rounding to nearest even
 template <>
-inline __host__ __device__ f8_t f8_convert_rne<f8_t, half_t>(half_t x)
+inline __host__ __device__ f8_fnuz_t f8_convert_rne<f8_fnuz_t, half_t>(half_t x)
 {
 #if defined(__gfx94__)
     // convert to float and use native conversion
-    return f8_convert_rne<f8_t>(type_convert<float>(x));
+    return f8_convert_rne<f8_fnuz_t>(type_convert<float>(x));
 #else
     constexpr bool negative_zero_nan = true;
     constexpr bool clip = true;
     constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
     constexpr uint32_t rng = 0;
-    return utils::
-        cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
-            x, rng);
+    return utils::cast_to_f8<half_t,
+                             f8_fnuz_t,
+                             negative_zero_nan,
+                             clip,
+                             (rm == f8_rounding_mode::stochastic)>(x, rng);
 #endif
 }
 // convert fp32 to bf8 with rounding to nearest even
 template <>
-inline __host__ __device__ bf8_t f8_convert_rne<bf8_t, float>(float x)
+inline __host__ __device__ bf8_fnuz_t f8_convert_rne<bf8_fnuz_t, float>(float x)
 {
 #if defined(__gfx94__)
     union
@@ -345,44 +365,59 @@ inline __host__ __device__ bf8_t f8_convert_rne<bf8_t, float>(float x)
     constexpr bool clip = true;
     constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
     constexpr uint32_t rng = 0;
-    return utils::
-        cast_to_f8<float, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
-            x, rng);
+    return utils::cast_to_f8<float,
+                             bf8_fnuz_t,
+                             negative_zero_nan,
+                             clip,
+                             (rm == f8_rounding_mode::stochastic)>(x, rng);
 #endif
 }
 // convert fp16 to bf8 with rounding to nearest even
 template <>
-inline __host__ __device__ bf8_t f8_convert_rne<bf8_t, half_t>(half_t x)
+inline __host__ __device__ bf8_fnuz_t f8_convert_rne<bf8_fnuz_t, half_t>(half_t x)
 {
 #if defined(__gfx94__)
     // convert to float and use native conversion
-    return f8_convert_rne<bf8_t>(type_convert<float>(x));
+    return f8_convert_rne<bf8_fnuz_t>(type_convert<float>(x));
 #else
     constexpr bool negative_zero_nan = true;
     constexpr bool clip = true;
     constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
     constexpr uint32_t rng = 0;
-    return utils::
-        cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
-            x, rng);
+    return utils::cast_to_f8<half_t,
+                             bf8_fnuz_t,
+                             negative_zero_nan,
+                             clip,
+                             (rm == f8_rounding_mode::stochastic)>(x, rng);
 #endif
 }
+// convert fp32 to fp8
+template <>
+inline __host__ __device__ f8_fnuz_t type_convert<f8_fnuz_t, float>(float x)
+{
+#if CK_USE_SR_F8_CONVERSION
+    return f8_convert_sr<f8_fnuz_t>(x);
+#else
+    return f8_convert_rne<f8_fnuz_t>(x);
+#endif
+}
 // convert fp32 to fp8
 template <>
-inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
+inline __host__ __device__ f8_ocp_t type_convert<f8_ocp_t, float>(float x)
 {
 #if CK_USE_SR_F8_CONVERSION
-    return f8_convert_sr<f8_t>(x);
+    return f8_convert_sr<f8_ocp_t>(x);
 #else
-    return f8_convert_rne<f8_t>(x);
+    return f8_convert_rne<f8_ocp_t>(x);
 #endif
 }
 // convert fp8 to fp32
 template <>
-inline __host__ __device__ float type_convert<float, f8_t>(f8_t x)
+inline __host__ __device__ float type_convert<float, f8_fnuz_t>(f8_fnuz_t x)
 {
 #if defined(__gfx94__)
     float fval;
@@ -392,30 +427,57 @@ inline __host__ __device__ float type_convert<float, f8_t>(f8_t x)
     return fval;
 #else
     constexpr bool negative_zero_nan = true;
-    return utils::cast_from_f8<f8_t, float, negative_zero_nan>(x);
+    return utils::cast_from_f8<f8_fnuz_t, float, negative_zero_nan>(x);
 #endif
 }
 template <>
-inline __host__ __device__ float2_t type_convert<float2_t, f8x2_t>(f8x2_t x)
+inline __host__ __device__ float2_t type_convert<float2_t, f8x2_fnuz_t>(f8x2_fnuz_t x)
 {
 #if defined(__gfx94__)
     const auto i16val = bit_cast<uint16_t>(x);
     return __builtin_amdgcn_cvt_pk_f32_fp8(i16val, 0);
 #else
     constexpr bool negative_zero_nan = true;
-    const auto f8x2_v = vector_type<f8_t, 2>(x);
+    const auto f8x2_v = vector_type<f8_fnuz_t, 2>(x);
     vector_type<float, 2> f32x2_v;
     f32x2_v.template AsType<float>()(Number<0>{}) =
-        utils::cast_from_f8<f8_t, float, negative_zero_nan>(
-            f8x2_v.template AsType<f8_t>()[Number<0>{}]);
+        utils::cast_from_f8<f8_fnuz_t, float, negative_zero_nan>(
+            f8x2_v.template AsType<f8_fnuz_t>()[Number<0>{}]);
     f32x2_v.template AsType<float>()(Number<1>{}) =
-        utils::cast_from_f8<f8_t, float, negative_zero_nan>(
-            f8x2_v.template AsType<f8_t>()[Number<1>{}]);
+        utils::cast_from_f8<f8_fnuz_t, float, negative_zero_nan>(
+            f8x2_v.template AsType<f8_fnuz_t>()[Number<1>{}]);
     return f32x2_v.template AsType<float2_t>()[Number<0>{}];
 #endif
 }
+template <>
+inline __host__ __device__ float2_t type_convert<float2_t, f8x2_ocp_t>(f8x2_ocp_t x)
+{
+#if CK_OCP_FP8_CVT_FAST_PATH
+    return fp8_impl::cast_to_f32x2_from_f8x2<f8_ocp_t::default_interpret>(
+        x.AsType<fp8_impl::fp8x2_storage_t>()[Number<0>{}]);
+#else
+    return float2_t{fp8_impl::cast_from_f8<float, f8_ocp_t::wm, f8_ocp_t::we, false>(
+                        x.AsType<fp8_storage_t>()[Number<0>{}]),
+                    fp8_impl::cast_from_f8<float, f8_ocp_t::wm, f8_ocp_t::we, false>(
+                        x.AsType<fp8_storage_t>()[Number<1>{}])};
+#endif
+}
+template <>
+inline __host__ __device__ float2_t type_convert<float2_t, pk_i4_t>(pk_i4_t x)
+{
+    uint8_t x_u8 = ck::bit_cast<uint8_t>(x);
+    uint8_t x_l  = (x_u8 & 0x0f) >> 0;
+    uint8_t x_h  = (x_u8 & 0xf0) >> 4;
+    auto l_f32 = ck::type_convert<float>(x_l);
+    auto h_f32 = ck::type_convert<float>(x_h);
+    return {l_f32, h_f32};
+}
 template <>
 inline __host__ __device__ half2_t type_convert<half2_t, float2_t>(float2_t x)
 {
@@ -428,42 +490,64 @@ inline __host__ __device__ half2_t type_convert<half2_t, float2_t>(float2_t x)
 // convert fp16 to fp8
 template <>
-inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
+inline __host__ __device__ f8_fnuz_t type_convert<f8_fnuz_t, half_t>(half_t x)
+{
+#if CK_USE_SR_F8_CONVERSION
+    return f8_convert_sr<f8_fnuz_t>(x);
+#else
+    return f8_convert_rne<f8_fnuz_t>(x);
+#endif
+}
+// convert fp16 to fp8
+template <>
+inline __host__ __device__ f8_ocp_t type_convert<f8_ocp_t, half_t>(half_t x)
 {
 #if CK_USE_SR_F8_CONVERSION
-    return f8_convert_sr<f8_t>(x);
+    return f8_convert_sr<f8_ocp_t>(x);
 #else
-    return f8_convert_rne<f8_t>(x);
+    return f8_convert_rne<f8_ocp_t>(x);
 #endif
 }
 // convert fp8 to fp16
 template <>
-inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x)
+inline __host__ __device__ half_t type_convert<half_t, f8_fnuz_t>(f8_fnuz_t x)
 {
 #if defined(__gfx94__)
     // use native conversion to float and convert to fp16
     return type_convert<half_t>(type_convert<float>(x));
 #else
     constexpr bool negative_zero_nan = true;
-    return utils::cast_from_f8<f8_t, half_t, negative_zero_nan>(x);
+    return utils::cast_from_f8<f8_fnuz_t, half_t, negative_zero_nan>(x);
 #endif
 }
 // convert fp32 to bf8
 template <>
-inline __host__ __device__ bf8_t type_convert<bf8_t, float>(float x)
+inline __host__ __device__ bf8_fnuz_t type_convert<bf8_fnuz_t, float>(float x)
+{
+#if CK_USE_SR_F8_CONVERSION
+    return f8_convert_sr<bf8_fnuz_t>(x);
+#else
+    return f8_convert_rne<bf8_fnuz_t>(x);
+#endif
+}
+// convert fp32 to bf8
+template <>
+inline __host__ __device__ bf8_ocp_t type_convert<bf8_ocp_t, float>(float x)
 {
 #if CK_USE_SR_F8_CONVERSION
-    return f8_convert_sr<bf8_t>(x);
+    return f8_convert_sr<bf8_ocp_t>(x);
 #else
-    return f8_convert_rne<bf8_t>(x);
+    return f8_convert_rne<bf8_ocp_t>(x);
 #endif
 }
 // convert bf8 to fp32
 template <>
-inline __host__ __device__ float type_convert<float, bf8_t>(bf8_t x)
+inline __host__ __device__ float type_convert<float, bf8_fnuz_t>(bf8_fnuz_t x)
 {
 #if defined(__gfx94__)
     float fval;
@@ -473,31 +557,42 @@ inline __host__ __device__ float type_convert<float, bf8_t>(bf8_t x)
     return fval;
 #else
     constexpr bool negative_zero_nan = true;
-    return utils::cast_from_f8<bf8_t, float, negative_zero_nan>(x);
+    return utils::cast_from_f8<bf8_fnuz_t, float, negative_zero_nan>(x);
 #endif
 }
+// convert fp16 to bf8
+template <>
+inline __host__ __device__ bf8_fnuz_t type_convert<bf8_fnuz_t, half_t>(half_t x)
+{
+#if CK_USE_SR_F8_CONVERSION
+    return f8_convert_sr<bf8_fnuz_t>(x);
+#else
+    return f8_convert_rne<bf8_fnuz_t>(x);
+#endif
+}
 // convert fp16 to bf8
 template <>
-inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x)
+inline __host__ __device__ bf8_ocp_t type_convert<bf8_ocp_t, half_t>(half_t x)
 {
 #if CK_USE_SR_F8_CONVERSION
-    return f8_convert_sr<bf8_t>(x);
+    return f8_convert_sr<bf8_ocp_t>(x);
 #else
-    return f8_convert_rne<bf8_t>(x);
+    return f8_convert_rne<bf8_ocp_t>(x);
 #endif
 }
 // convert bf8 to fp16
 template <>
-inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x)
+inline __host__ __device__ half_t type_convert<half_t, bf8_fnuz_t>(bf8_fnuz_t x)
 {
 #if defined(__gfx94__)
     // use native conversion to float and convert to fp16
     return type_convert<half_t>(type_convert<float>(x));
 #else
     constexpr bool negative_zero_nan = true;
-    return utils::cast_from_f8<bf8_t, half_t, negative_zero_nan>(x);
+    return utils::cast_from_f8<bf8_fnuz_t, half_t, negative_zero_nan>(x);
 #endif
 }
...
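After this change the fp8 conversions are split by representation: the fnuz specializations keep the gfx94x-native paths, the new ocp specializations cover the OCP format, and CK_USE_SR_F8_CONVERSION still selects stochastic rounding versus round-to-nearest-even. A hedged usage sketch (roundtrip_f8 is an illustrative name):

__host__ __device__ float roundtrip_f8(float x)
{
    ck::f8_fnuz_t a = ck::type_convert<ck::f8_fnuz_t>(x); // SR or RNE per CK_USE_SR_F8_CONVERSION
    ck::f8_ocp_t b  = ck::type_convert<ck::f8_ocp_t>(x);  // OCP flavor added in this diff
    (void)b;
    return ck::type_convert<float>(a); // fnuz -> fp32, native on __gfx94__
}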
-# ck_tile
+[Back to the main page](../../README.md)
+# Composable Kernel Tile
 ## concept
 `ck_tile` provides a programming model with templated abstractions to enable users to implement performance-critical kernels for machine learning workloads. It introduces the following basic concepts to help users build their own operators:
 - tensor coordinate transformation: the core concept of the layout/index transform abstraction, at both compile time and run time.
@@ -44,5 +45,8 @@ our implementation of different device operators.
 **[ops/epilogue]**
 the epilogue part of our kernel. We may extend this part to let users build their own customized epilogues.
+**[ref]**
+reference implementations on CPU or GPU. This folder is supposed to include specific headers on demand.
 ## examples
 currently we put all ck_tile related examples under the [/example/ck_tile](/example/ck_tile/) folder. Please check each example's subfolder.
@@ -5,6 +5,7 @@
 #include "ck_tile/core/algorithm/cluster_descriptor.hpp"
 #include "ck_tile/core/algorithm/coordinate_transform.hpp"
+#include "ck_tile/core/algorithm/indexing_adaptor.hpp"
 #include "ck_tile/core/algorithm/space_filling_curve.hpp"
 #include "ck_tile/core/arch/amd_buffer_addressing.hpp"
 #include "ck_tile/core/arch/arch.hpp"
@@ -24,6 +25,7 @@
 #include "ck_tile/core/numeric/bfloat16.hpp"
 #include "ck_tile/core/numeric/float8.hpp"
 #include "ck_tile/core/numeric/half.hpp"
+#include "ck_tile/core/numeric/int8.hpp"
 #include "ck_tile/core/numeric/integer.hpp"
 #include "ck_tile/core/numeric/integral_constant.hpp"
 #include "ck_tile/core/numeric/math.hpp"
@@ -49,13 +51,20 @@
 #include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
 #include "ck_tile/core/tensor/tile_elementwise.hpp"
 #include "ck_tile/core/tensor/tile_window.hpp"
+#include "ck_tile/core/tensor/tile_window_linear.hpp"
+#include "ck_tile/core/tensor/tile_window_utils.hpp"
 #include "ck_tile/core/tensor/update_tile.hpp"
+#include "ck_tile/core/utility/amd_address_space.hpp"
 #include "ck_tile/core/utility/bit_cast.hpp"
 #include "ck_tile/core/utility/functional.hpp"
+#include "ck_tile/core/utility/functional_with_tuple.hpp"
 #include "ck_tile/core/utility/ignore.hpp"
+#include "ck_tile/core/utility/literals.hpp"
 #include "ck_tile/core/utility/magic_div.hpp"
 #include "ck_tile/core/utility/philox_rand.hpp"
 #include "ck_tile/core/utility/random.hpp"
+#include "ck_tile/core/utility/reduce_operator.hpp"
+#include "ck_tile/core/utility/static_counter.hpp"
 #include "ck_tile/core/utility/to_sequence.hpp"
 #include "ck_tile/core/utility/transpose_vectors.hpp"
 #include "ck_tile/core/utility/type_traits.hpp"
...
@@ -23,6 +23,7 @@ enum struct coord_transform_enum
     replicate,
     xor_t,
     offset,
+    indexing,
 };
 template <index_t NDimLow, index_t NDimUp>
@@ -1526,6 +1527,88 @@ struct offset : public base_transform<1, 1>
     }
 };
+template <typename UpLength, typename IndexingAdaptor>
+struct indexing : public base_transform<1, 1>
+{
+    static constexpr index_t NDimUp = 1;
+    using LowerIndex = multi_index<1>;
+    using UpperIndex = multi_index<1>;
+    using UpLengths  = decltype(make_tuple(UpLength{}));
+    UpLengths up_lengths_;
+    IndexingAdaptor iadaptor_;
+    CK_TILE_HOST_DEVICE constexpr indexing() = default;
+    CK_TILE_HOST_DEVICE constexpr indexing(const UpLength& up_length,
+                                           const IndexingAdaptor& iadaptor)
+        : up_lengths_{make_tuple(up_length)}, iadaptor_{iadaptor}
+    {
+    }
+    CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
+    {
+        return coord_transform_enum::indexing;
+    }
+    CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
+    template <typename LowIdx, typename UpIdx>
+    CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
+                                                             const UpIdx& idx_up) const
+    {
+        static_assert(LowIdx::size() == 1 && UpIdx::size() == NDimUp,
+                      "wrong! inconsistent # of dimension");
+        iadaptor_.calculate_lower_index(idx_low, idx_up);
+    }
+    template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
+    CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
+                                                const UpIdxDiff& idx_diff_up,
+                                                LowIdx& idx_low,
+                                                const UpIdx& idx_up) const
+    {
+        // TODO: nothing changed here
+        static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == NDimUp &&
+                          LowIdx::size() == 1 && UpIdx::size() == NDimUp,
+                      "wrong! inconsistent # of dimension");
+        iadaptor_.update_lower_index(idx_diff_low, idx_diff_up, idx_low, idx_up);
+    }
+    CK_TILE_HOST_DEVICE static constexpr bool
+    is_valid_upper_index_always_mapped_to_valid_lower_index()
+    {
+        return true;
+    }
+    template <typename UpIdx>
+    CK_TILE_HOST_DEVICE static constexpr bool
+    is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */)
+    {
+        return true;
+    }
+    CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
+    {
+        return ck_tile::is_known_at_compile_time<UpLengths>::value &&
+               IndexingAdaptor::is_known_at_compile_time();
+    }
+    CK_TILE_HOST_DEVICE void print() const
+    {
+        printf("indexing{");
+        //
+        printf("up_lengths_: ");
+        print(up_lengths_);
+        printf(", ");
+        printf("}");
+    }
+};
 //*******************************************************************************************************
 template <typename LowLength>
@@ -1646,3 +1729,24 @@ CK_TILE_HOST_DEVICE constexpr auto make_offset_transform(const LowLength& low_le
 }
 } // namespace ck_tile
+#include "ck_tile/core/algorithm/indexing_adaptor.hpp"
+namespace ck_tile {
+template <typename UpLength, typename Indices>
+CK_TILE_HOST_DEVICE constexpr auto make_indexing_transform(const UpLength& up_lengths,
+                                                           const Indices& indices)
+{
+    // by default we use the simplest one
+    return indexing<UpLength, indexing_adaptor_onshot_cached<remove_cvref_t<Indices>>>{
+        up_lengths, indexing_adaptor_onshot_cached<remove_cvref_t<Indices>>{indices}};
+}
+template <typename UpLength, typename IndexingAdaptor>
+CK_TILE_HOST_DEVICE constexpr auto
+make_indexing_transform_with_adaptor(const UpLength& up_lengths, const IndexingAdaptor& iadaptor)
+{
+    return indexing<UpLength, IndexingAdaptor>{up_lengths, iadaptor};
+}
+} // namespace ck_tile
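A hedged sketch of the new factory: it builds a gather-style dimension whose upper index is remapped through a per-thread cached index (the surrounding tensor-descriptor plumbing, the length 8, and the gather-table source of row_id are illustrative assumptions):

CK_TILE_DEVICE auto make_gather_dim(ck_tile::index_t row_id)
{
    // row_id: this thread's gather index; each thread must supply exactly one
    // (see indexing_adaptor_onshot_cached below)
    return ck_tile::make_indexing_transform(ck_tile::number<8>{}, row_id);
}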
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+#pragma once
+#include "ck_tile/core/config.hpp"
+#include "ck_tile/core/container/multi_index.hpp"
+#include "ck_tile/core/container/container_helper.hpp"
+#include "ck_tile/core/utility/functional.hpp"
+#include "ck_tile/core/utility/type_traits.hpp"
+namespace ck_tile {
+// pre-defined indexing adaptor used for indexing (scatter/gather)
+// this version caches the index inside a thread register (which is also preferred in real
+// scenarios); however, it is the user's responsibility that each thread provides only one
+// index, which means moving the coordinate will not change on this dim
+template <typename IndexingType>
+struct indexing_adaptor_onshot_cached
+{
+    CK_TILE_HOST_DEVICE constexpr indexing_adaptor_onshot_cached() = default;
+    CK_TILE_HOST_DEVICE constexpr indexing_adaptor_onshot_cached(const IndexingType& idx)
+        : cached_idx_(idx)
+    {
+    }
+    IndexingType cached_idx_;
+    template <typename LowIdx, typename UpIdx>
+    CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
+                                                             const UpIdx& /*idx_up*/) const
+    {
+        static_assert(LowIdx::size() == 1 && UpIdx::size() == 1,
+                      "wrong! inconsistent # of dimension");
+        idx_low(number<0>{}) = cached_idx_;
+    }
+    template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
+    CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
+                                                const UpIdxDiff& idx_diff_up,
+                                                LowIdx& /*idx_low*/,
+                                                const UpIdx& /*idx_up*/) const
+    {
+        // TODO: nothing changed here
+        static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 &&
+                          UpIdx::size() == 1,
+                      "wrong! inconsistent # of dimension");
+        idx_diff_low(number<0>{}) = idx_diff_up[number<0>{}];
+        // pass the diff to lower, but do not change the actual index
+    }
+    CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
+    {
+        return ck_tile::is_known_at_compile_time<IndexingType>::value;
+    }
+};
+} // namespace ck_tile
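The same transform can also be built with an explicitly constructed adaptor, which is the hook for swapping in a different caching strategy. A minimal sketch under the same illustrative assumptions as above:

CK_TILE_DEVICE auto make_gather_dim_explicit(ck_tile::index_t row_id)
{
    ck_tile::indexing_adaptor_onshot_cached<ck_tile::index_t> adaptor{row_id};
    return ck_tile::make_indexing_transform_with_adaptor(ck_tile::number<8>{}, adaptor);
}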
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
@@ -81,8 +81,10 @@ struct space_filling_curve
         return get_step_between(number<AccessIdx1d>{}, number<AccessIdx1d - 1>{});
     }
+    // Do not use this function directly!
+    // TODO: can refactor into a generic lambda in the future
     template <index_t AccessIdx1d>
-    static CK_TILE_HOST_DEVICE constexpr Index get_index(number<AccessIdx1d>)
+    static CK_TILE_HOST_DEVICE constexpr Index _get_index(number<AccessIdx1d>)
     {
 #if 0
         /*
@@ -153,11 +155,11 @@ struct space_filling_curve
         return idx_md;
     }
-    // FIXME: rename this function
+    // FIXME: return tuple of number<>, which is a compile-time-only variable
     template <index_t AccessIdx1d>
-    static CK_TILE_HOST_DEVICE constexpr auto get_index_tuple_of_number(number<AccessIdx1d>)
+    static CK_TILE_HOST_DEVICE constexpr auto get_index(number<AccessIdx1d>)
     {
-        constexpr auto idx = get_index(number<AccessIdx1d>{});
+        constexpr auto idx = _get_index(number<AccessIdx1d>{});
         return generate_tuple([&](auto i) { return number<idx[i]>{}; }, number<nDim>{});
     }
...
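After the rename, get_index is the public compile-time entry point (returning a tuple of number<>) and _get_index is internal. A sketch of the intended use; the instantiation form of space_filling_curve is an assumption based on its use elsewhere in ck_tile:

using sfc = ck_tile::space_filling_curve<ck_tile::sequence<4, 4>,  // tensor lengths (assumed form)
                                         ck_tile::sequence<0, 1>,  // dim access order
                                         ck_tile::sequence<1, 1>>; // scalars per access
constexpr auto idx0 = sfc::get_index(ck_tile::number<0>{}); // tuple of number<>, usable at compile time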
@@ -621,6 +621,158 @@ CK_TILE_DEVICE void buffer_load_fence(index_t cnt = 0)
     asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
 }
+CK_TILE_DEVICE void lds_load_fence(index_t cnt = 0)
+{
+    asm volatile("s_waitcnt lgkmcnt(%0)" : : "n"(cnt) : "memory");
+}
+template <typename scalar_type, index_t N, bool pre_nop = false>
+struct buffer_atomic_add_if;
+template <bool pre_nop>
+struct buffer_atomic_add_if<bf16_t, 2, pre_nop>
+{
+    template <typename T>
+    CK_TILE_DEVICE void operator()(const T& value,
+                                   int32x4_t res /*buffer resource*/,
+                                   index_t v_offset,
+                                   index_t /*s_offset*/,
+                                   index_t i_offset /*max 0xFFF*/,
+                                   index_t flag = 1)
+    {
+        static_assert(sizeof(T) == 4);
+        auto save_exec = __builtin_amdgcn_read_exec();
+        using mbuf_t   = float;
+        asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
+                     "global_atomic_pk_add_bf16 %0, %1, %2 offset:%3\n"
+                     "s_mov_b64 exec %5"
+                     :
+                     : "v"(v_offset),
+                       "v"(bit_cast<mbuf_t>(value)),
+                       "s"(res.xy),
+                       "n"(i_offset),
+                       "v"(flag),
+                       "s"(save_exec)
+                     : "memory");
+    }
+};
+template <typename scalar_type, index_t N, bool pre_nop = false>
+struct buffer_atomic_add;
+template <bool pre_nop>
+struct buffer_atomic_add<bf16_t, 2, pre_nop>
+{
+    template <typename T>
+    CK_TILE_DEVICE void operator()(const T& value,
+                                   int32x4_t res /*buffer resource*/,
+                                   index_t v_offset,
+                                   index_t /*s_offset*/,
+                                   index_t i_offset /*max 0xFFF*/,
+                                   index_t /*flag = 1*/)
+    {
+        static_assert(sizeof(T) == 4);
+        using mbuf_t = float;
+        asm volatile("global_atomic_pk_add_bf16 %0, %1, %2 offset:%3"
+                     :
+                     : "v"(v_offset), "v"(bit_cast<mbuf_t>(value)), "s"(res.xy), "n"(i_offset)
+                     : "memory");
+    }
+};
+namespace impl {
+// the types below indicate the payload type used for the smem load inline asm
+// clang-format off
+template<index_t N, typename T> struct smem_load_trait;
+template<typename T> struct smem_load_trait<16, T> { using payload_t = fp32x4_t; };
+template<typename T> struct smem_load_trait<8 , T> { using payload_t = fp32x2_t; };
+template<typename T> struct smem_load_trait<4 , T> { using payload_t = float; };
+template<typename T> struct smem_load_trait<2 , T> { using payload_t = float; };
+template<typename T> struct smem_load_trait<1 , T> { using payload_t = float; };
+// clang-format on
+} // namespace impl
+// NOTE: smem load/store needs no pre_nop; software dependency tracking is sufficient
+template <index_t>
+struct smem_load;
+template <>
+struct smem_load<16>
+{
+    template <typename T>
+    CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
+    {
+        static_assert(sizeof(T) == 16);
+        using mbuf_t = typename impl::smem_load_trait<16, T>::payload_t;
+        asm volatile("ds_read_b128 %0, %1 offset:%2"
+                     : "=v"(reinterpret_cast<mbuf_t&>(value)) // ! direct write
+                     : "v"(v_offset), "n"(i_offset)
+                     : "memory");
+    }
+};
+template <>
+struct smem_load<8>
+{
+    template <typename T>
+    CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
+    {
+        static_assert(sizeof(T) == 8);
+        using mbuf_t = typename impl::smem_load_trait<8, T>::payload_t;
+        asm volatile("ds_read_b64 %0, %1 offset:%2"
+                     : "=v"(reinterpret_cast<mbuf_t&>(value)) // ! direct write
+                     : "v"(v_offset), "n"(i_offset)
+                     : "memory");
+    }
+};
+template <>
+struct smem_load<4>
+{
+    template <typename T>
+    CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
+    {
+        static_assert(sizeof(T) == 4);
+        using mbuf_t = typename impl::smem_load_trait<4, T>::payload_t;
+        asm volatile("ds_read_b32 %0, %1 offset:%2"
+                     : "=v"(reinterpret_cast<mbuf_t&>(value)) // ! direct write
+                     : "v"(v_offset), "n"(i_offset)
+                     : "memory");
+    }
+};
+template <>
+struct smem_load<2>
+{
+    template <typename T>
+    CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
+    {
+        static_assert(sizeof(T) == 4); // subdword is buggy, use a dword buffer and convert manually
+        using mbuf_t = typename impl::smem_load_trait<1, T>::payload_t;
+        asm volatile("ds_read_u16 %0, %1 offset:%2"
+                     : "=v"(reinterpret_cast<mbuf_t&>(value)) // ! direct write
+                     : "v"(v_offset), "n"(i_offset)
+                     : "memory");
+    }
+};
+template <>
+struct smem_load<1>
+{
+    template <typename T>
+    CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
+    {
+        static_assert(sizeof(T) == 4);
+        using mbuf_t = typename impl::smem_load_trait<1, T>::payload_t;
+        asm volatile("ds_read_u8 %0, %1 offset:%2"
+                     : "=v"(reinterpret_cast<mbuf_t&>(value)) // ! direct write
+                     : "v"(v_offset), "n"(i_offset)
+                     : "memory");
+    }
+};
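A hedged usage sketch of the smem_load helpers above: the i_offset operand feeds an inline-asm "n" constraint, so it must constant-fold to an immediate, and the destination register is written directly by the ds_read (load_from_lds is an illustrative name):

CK_TILE_DEVICE void load_from_lds(ck_tile::fp32x4_t& value, ck_tile::index_t v_offset)
{
    // 16-byte LDS read via ds_read_b128; 64 is an immediate byte offset
    ck_tile::smem_load<16>{}(value, v_offset, 64);
}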
 // clang-format off
 namespace impl{
@@ -717,6 +869,11 @@ CK_TILE_DEVICE void buffer_store_fence(index_t cnt = 0)
     asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
 }
+CK_TILE_DEVICE auto async_load_fence_raw(index_t cnt = 0)
+{
+    asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
+}
 // buffer load i8
 CK_TILE_DEVICE_EXTERN int8_t
 llvm_amdgcn_raw_buffer_load_i8(int32x4_t srsrc,
@@ -976,6 +1133,16 @@ llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata,
                                        int soffset, // dst_wave_addr_offset
                                        int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64");
+// Direct loads from global to LDS.
+CK_TILE_DEVICE_EXTERN void
+llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc,
+                                __attribute__((address_space(3))) uint32_t* lds_ptr,
+                                index_t size,
+                                index_t voffset,
+                                index_t soffset,
+                                index_t offset,
+                                index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds");
 template <bool pre_nop = false>
 CK_TILE_DEVICE void async_buffer_load_dword_v(void* smem,
                                               int32x4_t rsrc,
@@ -1136,8 +1303,8 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
     static_assert(
         (std::is_same<T, double>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
             (std::is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
-            (std::is_same<T, fp16_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
-            (std::is_same<T, bf16_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
+            (std::is_same<T, fp16_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
+            (std::is_same<T, bf16_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
             (std::is_same<T, int32_t>::value &&
              (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
             (std::is_same<T, fp8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
@@ -1313,6 +1480,7 @@ CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer<T, N>& dst,
                                              int32x4_t src_wave_buffer_resource,
                                              index_t src_thread_addr_offset,
                                              index_t src_wave_addr_offset,
+                                             index_t src_linear_addr_offset,
                                              index_t flag = 0,
                                              bool_constant<pre_nop> = {})
 {
@@ -1327,7 +1495,7 @@ CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer<T, N>& dst,
             src_wave_buffer_resource,
             src_thread_addr_offset,
             src_wave_addr_offset,
-            0,
+            src_linear_addr_offset,
             flag,
             bool_constant<pre_nop>{});
     }
@@ -1337,7 +1505,7 @@ CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer<T, N>& dst,
             src_wave_buffer_resource,
             src_thread_addr_offset,
             src_wave_addr_offset,
-            0,
+            src_linear_addr_offset,
             flag,
             bool_constant<pre_nop>{});
     }
@@ -1365,6 +1533,43 @@ CK_TILE_DEVICE void amd_async_buffer_load_impl(T* smem,
         bool_constant<pre_nop>{});
 }
+template <typename T,
+          index_t N,
+          amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
+          bool oob_conditional_check          = true>
+CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
+                                          int32x4_t src_wave_buffer_resource,
+                                          index_t src_thread_addr_offset,
+                                          index_t src_wave_addr_offset,
+                                          index_t src_immediate_addr_offset = 0,
+                                          index_t flag                      = 0,
+                                          bool_constant<oob_conditional_check> = {})
+{
+    static_assert(sizeof(T) * N == 4, "wrong! not implemented vector size");
+    if constexpr(oob_conditional_check)
+    {
+        // redirect invalid lanes past the end of the buffer (num_records in rsrc word 2),
+        // so the hardware out-of-bounds check drops the load
+        index_t v_offset = flag ? src_thread_addr_offset : src_wave_buffer_resource[2];
+        llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
+                                        smem,
+                                        sizeof(uint32_t),
+                                        v_offset,
+                                        src_wave_addr_offset,
+                                        src_immediate_addr_offset,
+                                        static_cast<index_t>(coherence));
+    }
+    else
+    {
+        llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
+                                        smem,
+                                        sizeof(uint32_t),
+                                        src_thread_addr_offset,
+                                        src_wave_addr_offset,
+                                        src_immediate_addr_offset,
+                                        static_cast<index_t>(coherence));
+    }
+}
 template <index_t N,
           amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default>
 CK_TILE_DEVICE void amd_buffer_store_impl_with_bytes(const thread_buffer<int8_t, N> src_thread_data,
@@ -1685,6 +1890,7 @@ CK_TILE_DEVICE void amd_buffer_store_raw_impl(const thread_buffer<T, N>& dst_thr
                                               int32x4_t dst_wave_buffer_resource,
                                               index_t dst_thread_addr_offset,
                                               index_t dst_wave_addr_offset,
+                                              index_t dst_linear_addr_offset,
                                               index_t is_valid_element = 1)
 {
     constexpr index_t bytes = sizeof(T) * N;
@@ -1698,7 +1904,7 @@ CK_TILE_DEVICE void amd_buffer_store_raw_impl(const thread_buffer<T, N>& dst_thr
             dst_wave_buffer_resource,
             dst_thread_addr_offset,
             dst_wave_addr_offset,
-            0,
+            dst_linear_addr_offset,
             is_valid_element);
     }
     else
@@ -1707,7 +1913,7 @@ CK_TILE_DEVICE void amd_buffer_store_raw_impl(const thread_buffer<T, N>& dst_thr
             dst_wave_buffer_resource,
             dst_thread_addr_offset,
             dst_wave_addr_offset,
-            0);
+            dst_linear_addr_offset);
     }
 }
@@ -2014,6 +2220,7 @@ template <typename T,
 CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer<T, N>& dst,
                                         const T* p_src_wave,
                                         index_t src_thread_element_offset,
+                                        index_t src_linear_element_offset,
                                         index_t src_element_space_size,
                                         index_t is_valid_element = 0,
                                         bool_constant<pre_nop> = {})
@@ -2022,12 +2229,14 @@ CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer<T, N>& dst,
         make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T));
     index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
+    index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T);
     amd_buffer_load_raw_impl<T, N, coherence, oob_conditional_check, pre_nop>(
         dst,
         src_wave_buffer_resource,
         src_thread_addr_offset,
         0,
+        src_linear_addr_offset,
         is_valid_element,
         bool_constant<pre_nop>{});
 }
@@ -2041,16 +2250,19 @@ template <typename T,
 CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer<T, N>& dst,
                                         const int32x4_t src_wave_buffer_resource,
                                         index_t src_thread_element_offset,
+                                        index_t src_linear_element_offset,
                                         index_t is_valid_element = 0,
                                         bool_constant<pre_nop> = {})
 {
     index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
+    index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T);
     amd_buffer_load_raw_impl<T, N, coherence, oob_conditional_check, pre_nop>(
         dst,
         src_wave_buffer_resource,
         src_thread_addr_offset,
         0,
+        src_linear_addr_offset,
         is_valid_element,
         bool_constant<pre_nop>{});
 }
@@ -2066,6 +2278,7 @@ template <typename T,
 CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
                                                        const T* p_src_wave,
                                                        index_t src_thread_element_offset,
+                                                       index_t src_linear_element_offset,
                                                        index_t src_element_space_size,
                                                        bool_constant<pre_nop> = {})
 {
@@ -2073,9 +2286,14 @@ CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
         make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T));
     index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
+    index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T);
-    amd_async_buffer_load_impl<T, N, coherence>(
-        smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0, bool_constant<pre_nop>{});
+    amd_async_buffer_load_impl<T, N, coherence>(smem,
+                                                src_wave_buffer_resource,
+                                                src_thread_addr_offset,
+                                                0,
+                                                src_linear_addr_offset,
+                                                bool_constant<pre_nop>{});
 }
 // This version supports buffer resource as input arg
@@ -2086,12 +2304,42 @@ template <typename T,
 CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
                                                        const int32x4_t src_wave_buffer_resource,
                                                        index_t src_thread_element_offset,
+                                                       index_t src_linear_element_offset,
                                                        bool_constant<pre_nop> = {})
 {
     index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
+    index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T);
-    amd_async_buffer_load_impl<T, N, coherence>(
-        smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0, bool_constant<pre_nop>{});
+    amd_async_buffer_load_impl<T, N, coherence>(smem,
+                                                src_wave_buffer_resource,
+                                                src_thread_addr_offset,
+                                                0,
+                                                src_linear_addr_offset,
+                                                bool_constant<pre_nop>{});
 }
+// This version supports buffer resource as input arg
+template <typename T,
+          index_t N,
+          amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
+          bool oob_conditional_check          = false>
+CK_TILE_DEVICE void amd_async_buffer_load_with_oob(CK_TILE_LDS_ADDR T* smem,
+                                                   const int32x4_t src_wave_buffer_resource,
+                                                   index_t src_thread_element_offset,
+                                                   index_t src_linear_element_offset,
+                                                   bool is_valid_element,
+                                                   bool_constant<oob_conditional_check> = {})
+{
+    index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
+    index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T);
+    amd_async_buffer_load<T, N, coherence>(smem,
+                                           src_wave_buffer_resource,
+                                           src_thread_addr_offset,
+                                           0,
+                                           src_linear_addr_offset,
+                                           is_valid_element,
+                                           bool_constant<oob_conditional_check>{});
+}
// buffer_store requires: // buffer_store requires:
...@@ -2146,6 +2394,7 @@ template <typename T, ...@@ -2146,6 +2394,7 @@ template <typename T,
CK_TILE_DEVICE void amd_buffer_store_raw(const thread_buffer<T, N>& src_thread_data,
                                         T* p_dst_wave,
                                         const index_t dst_thread_element_offset,
                                         const index_t dst_linear_element_offset,
                                         const bool dst_thread_element_valid,
                                         const index_t dst_element_space_size)
{
@@ -2153,11 +2402,13 @@ CK_TILE_DEVICE void amd_buffer_store_raw(const thread_buffer<T, N>& src_thread_d
        make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T));

    index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
    index_t dst_linear_addr_offset = dst_linear_element_offset * sizeof(T);

    amd_buffer_store_raw_impl<T, N, coherence, oob_conditional_check>(src_thread_data,
                                                                      dst_wave_buffer_resource,
                                                                      dst_thread_addr_offset,
                                                                      0,
                                                                      dst_linear_addr_offset,
                                                                      dst_thread_element_valid);
}
@@ -2191,6 +2442,45 @@ CK_TILE_DEVICE void amd_buffer_atomic_add(const thread_buffer<T, N>& src_thread_
#endif
}
template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE void amd_buffer_atomic_add_raw(const thread_buffer<T, N>& src_thread_data,
T* p_dst_wave,
const index_t dst_thread_element_offset,
const index_t dst_linear_element_offset,
const bool dst_thread_element_valid,
const index_t dst_element_space_size,
bool_constant<pre_nop> = {})
{
const int32x4_t dst_wave_buffer_resource =
make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T));
index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
index_t dst_linear_addr_offset = dst_linear_element_offset * sizeof(T);
if constexpr(oob_conditional_check)
{
buffer_atomic_add_if<T, N, pre_nop>{}(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
0,
dst_linear_addr_offset,
dst_thread_element_valid);
}
else
{
buffer_atomic_add<T, N, pre_nop>{}(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
0,
dst_linear_addr_offset,
1);
}
}
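A hedged usage sketch of the new raw atomic-add path (the wrapper name is my illustration; it assumes p_dst_wave is a wavewise global pointer, as the requirements below spell out for the other atomics):

// hypothetical sketch: atomically accumulate a per-thread float4 into global
// memory, with the out-of-bounds check enabled (oob_conditional_check defaults
// to true in the template above)
CK_TILE_DEVICE void atomic_accumulate_sketch(const thread_buffer<float, 4>& acc,
                                             float* p_dst_wave,
                                             index_t dst_elem_offset,
                                             bool is_valid,
                                             index_t dst_elem_space)
{
    // the linear element offset is passed as 0 here; split it out when the
    // address has a known linear component
    amd_buffer_atomic_add_raw<float, 4>(
        acc, p_dst_wave, dst_elem_offset, 0, is_valid, dst_elem_space);
}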
// buffer_atomic_max requires:
// 1) p_dst_wave must point to global memory
// 2) p_dst_wave must be a wavewise pointer.
@@ -2221,16 +2511,6 @@ CK_TILE_DEVICE void amd_buffer_atomic_max(const thread_buffer<T, N>& src_thread_
#endif
}
// Direct loads from global to LDS.
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc,
__attribute__((address_space(3))) uint32_t* lds_ptr,
index_t size,
index_t voffset,
index_t soffset,
index_t offset,
index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds");
template <typename T, index_t NumElemsPerThread>
CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr,
                                                  const index_t global_offset,
...
@@ -73,6 +73,24 @@ CK_TILE_DEVICE void block_sync_lds()
#endif
}
CK_TILE_DEVICE void block_sync_load_raw(index_t cnt = 0)
{
#ifdef __gfx12__
asm volatile("s_wait_loadcnt %0 \n"
"s_barrier_signal -1 \n"
"s_barrier_wait -1"
:
: "n"(cnt)
: "memory");
#else
asm volatile("s_waitcnt vmcnt(%0) \n"
"s_barrier"
:
: "n"(cnt)
: "memory");
#endif
}
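A usage sketch of the sync above (my inference from the wait-count semantics, not documented behavior): cnt is the number of outstanding vector-memory loads allowed to remain in flight, and it must resolve to a compile-time constant because it feeds the "n" asm constraint.

// hypothetical sketch: drain all outstanding buffer loads, then barrier, so
// every lane can safely read what the workgroup just loaded into LDS
CK_TILE_DEVICE void sync_after_loads_sketch()
{
    block_sync_load_raw(0); // vmcnt(0) / loadcnt 0, then a workgroup barrier
}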
CK_TILE_DEVICE void block_sync_lds_direct_load()
{
    asm volatile("\
...
@@ -59,4 +59,71 @@ CK_TILE_DEVICE T warp_shuffle_down(const T& v_local, uint32_t lane_delta)
#endif
}
template <typename T>
CK_TILE_DEVICE T warp_shuffle(const T& v_local, uint32_t src_lane)
{
#if 0
return __shfl(v_local, src_lane);
#elif 1
if constexpr(sizeof(int32_t) > sizeof(T))
{
union packet
{
int32_t x;
T v;
};
packet p;
p.v = v_local;
packet p_remote;
p_remote.x = __builtin_amdgcn_ds_bpermute(src_lane << 2, bit_cast<int32_t>(p));
return p_remote.v;
}
else if constexpr(sizeof(int32_t) == sizeof(T))
{
const int32_t v_remote_tmp =
__builtin_amdgcn_ds_bpermute(src_lane << 2, bit_cast<int32_t>(v_local));
return bit_cast<T>(v_remote_tmp);
}
else
{
static_assert(sizeof(T) % sizeof(int32_t) == 0, "wrong!");
constexpr index_t elm = sizeof(T) / sizeof(int32_t);
using vector_type = thread_buffer<int32_t, elm>;
auto vs = bit_cast<vector_type>(v_local);
auto vs_remote = vector_type{};
static_for<0, elm, 1>{}([&](auto i_e) {
int32_t tmp = __builtin_amdgcn_ds_bpermute(src_lane << 2, bit_cast<int32_t>(vs[i_e]));
vs_remote(i_e) = tmp;
});
return bit_cast<T>(vs_remote);
}
#endif
}
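A small usage sketch (my illustration): broadcasting lane 0's value across the wavefront, which exercises whichever of the three size branches above matches T.

// hypothetical sketch: every lane reads lane 0's value
template <typename T>
CK_TILE_DEVICE T warp_broadcast_sketch(const T& v_local)
{
    return warp_shuffle(v_local, 0); // same src_lane for all lanes
}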
template <typename T>
CK_TILE_DEVICE auto flag_to_exec(const T& v_flag)
{
static_assert(sizeof(T) == 4);
    // compare the per-thread v_flag and store the result into 2x SGPRs (exec-mask layout)
uint32x2_t exec_flag;
asm volatile("v_cmp_ge_u32 %[s_exec_flag], %[v_flag], 1"
: [s_exec_flag] "=s"(exec_flag)
: [v_flag] "v"(v_flag));
return exec_flag;
}
template <typename X, typename Y>
CK_TILE_DEVICE auto cmp_lt_to_exec(const X& x, const Y& y)
{
static_assert(sizeof(X) == 4 && sizeof(Y) == 4);
    // compare the per-thread values and store the result into 2x SGPRs (exec-mask layout)
uint32x2_t exec_flag;
asm volatile("v_cmp_lt_u32 %[s_exec_flag], %[v_x], %[v_y]"
: [s_exec_flag] "=s"(exec_flag)
: [v_x] "v"(x), [v_y] "v"(y));
return exec_flag;
}
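An illustrative sketch of consuming these masks (assumes a wave64 target, where the two SGPRs form one 64-bit lane mask; the bit_cast/__builtin_popcountll combination is my assumption, not library code):

// hypothetical sketch: count how many lanes have flag >= 1
CK_TILE_DEVICE int count_flagged_lanes_sketch(uint32_t v_flag)
{
    uint32x2_t m = flag_to_exec(v_flag);
    return __builtin_popcountll(bit_cast<uint64_t>(m));
}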
} // namespace ck_tile
@@ -11,13 +11,15 @@
#define __gfx94__
#endif
#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || \
    defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || \
    defined(__gfx10_3_generic__)
#define __gfx103__
#endif
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || \
    defined(__gfx1103__) || defined(__gfx11_generic__)
#define __gfx11__
#endif
#if defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx12_generic__)
#define __gfx12__
#endif
@@ -32,13 +34,28 @@
#define CK_TILE_DEVICE inline __device__
#define CK_TILE_HOST_DEVICE inline __host__ __device__
#define CK_TILE_DEVICE_EXTERN __device__
#define CK_TILE_HOST_DEVICE_EXTERN __host__ __device__
#else
#define CK_TILE_HOST inline
#define CK_TILE_DEVICE inline
#define CK_TILE_HOST_DEVICE inline
#define CK_TILE_DEVICE_EXTERN
#define CK_TILE_HOST_DEVICE_EXTERN
#endif
// implementing the "memory address space" attribute
// https://llvm.org/docs/AMDGPUUsage.html#amdgpu-address-spaces-table
#ifdef __HIPCC__
#define CK_TILE_GENERIC_ADDR __attribute__((address_space(0)))
#define CK_TILE_GLOBAL_ADDR __attribute__((address_space(1)))
#define CK_TILE_LDS_ADDR __attribute__((address_space(3)))
#define CK_TILE_BUF_RES_ADDR __attribute__((address_space(8)))
#else
#define CK_TILE_GENERIC_ADDR
#define CK_TILE_GLOBAL_ADDR
#define CK_TILE_LDS_ADDR
#define CK_TILE_BUF_RES_ADDR
#endif
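A minimal sketch of how these annotations are meant to be used (the function is hypothetical; the attributes map to the LLVM AMDGPU address spaces linked above):

// hypothetical sketch: an LDS pointer carries address_space(3) when compiled
// with hipcc, and decays to a plain pointer on other compilers
CK_TILE_DEVICE void zero_lds_sketch(CK_TILE_LDS_ADDR float* p_lds, int n)
{
    for(int i = 0; i < n; ++i)
        p_lds[i] = 0.0f;
}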
#ifndef CK_TILE_USE_CUSTOM_DATA_TYPE
#define CK_TILE_USE_CUSTOM_DATA_TYPE 0 // custom data type will generate extra move/bfi code
#endif
@@ -47,6 +64,7 @@
#define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE_WITH_NAN 1
#define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE 2
#define CK_TILE_FLOAT_TO_BFLOAT16_STANDARD_ASM 3
#define CK_TILE_FLOAT_TO_BFLOAT16_RTA_ASM 4
#ifndef CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT
#define CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE
@@ -203,3 +221,12 @@
#ifndef CK_TILE_BUFFER_LOAD_RAW_BF16_WA
#define CK_TILE_BUFFER_LOAD_RAW_BF16_WA 1
#endif
// workaround: compiler not emitting the reciprocal instruction from __frcp_rn()
#ifndef CK_TILE_WORKAROUND_SWDEV_383542
#define CK_TILE_WORKAROUND_SWDEV_383542 1
#endif
#ifndef CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
#define CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID 1
#endif
@@ -30,7 +30,7 @@ struct meta_data_buffer
    {
        constexpr index_t size = sizeof(T);

        auto tmp = ck_tile::bit_cast<array<std::byte, size>>(data);

        for(int i = 0; i < size; i++)
        {
@@ -66,7 +66,7 @@ struct meta_data_buffer
            pos++;
        }

        data = ck_tile::bit_cast<T>(tmp);
    }

    return data;
@@ -86,7 +86,7 @@ struct meta_data_buffer
            pos++;
        }

        auto data = ck_tile::bit_cast<T>(tmp);

        return data;
    }
...
@@ -1111,4 +1111,126 @@ CK_TILE_HOST_DEVICE constexpr auto generate_array(F&& f, number<N>)
        typename arithmetic_sequence_gen<0, N, 1>::type{});
}
namespace impl {
template <typename, typename, typename, index_t>
struct reverse_slice_sequence_impl;
template <index_t x,
index_t... xs,
index_t m,
index_t... ms,
index_t id,
index_t... ids,
index_t SliceSize>
struct reverse_slice_sequence_impl<sequence<x, xs...>,
sequence<m, ms...>,
sequence<id, ids...>,
SliceSize>
{
using old_scan =
reverse_slice_sequence_impl<sequence<xs...>, sequence<ms...>, sequence<ids...>, SliceSize>;
static constexpr auto slice_size = old_scan::remaining_slice_sizes::front().value;
static constexpr auto slice_length =
std::conditional_t<m, number<gcd(x, slice_size)>, number<x>>::value;
using dim_lengths =
typename sequence_merge<sequence<slice_length>, typename old_scan::dim_lengths>::type;
using dim_slices =
typename sequence_merge<sequence<x / slice_length>, typename old_scan::dim_slices>::type;
using remaining_slice_sizes = typename sequence_merge<
std::conditional_t<m, sequence<slice_size / slice_length>, sequence<slice_size>>,
typename old_scan::remaining_slice_sizes>::type;
    // the first idx at which the sliced length is not equal to the original length
static constexpr index_t _flag =
slice_length != x && remaining_slice_sizes{}.front().value == 1;
static constexpr index_t _split_flag = std::conditional_t<m, number<_flag>, number<0>>::value;
static constexpr index_t _split_idx =
std::conditional_t<_split_flag, number<id>, number<0>>::value;
static constexpr index_t split_flag = _split_flag || old_scan::split_flag;
static constexpr index_t split_idx = std::
conditional_t<old_scan::split_flag, number<old_scan::split_idx>, number<_split_idx>>::value;
};
template <index_t x, index_t m, index_t id, index_t SliceSize>
struct reverse_slice_sequence_impl<sequence<x>, sequence<m>, sequence<id>, SliceSize>
{
static constexpr auto slice_size = SliceSize;
static constexpr auto slice_length =
std::conditional_t<m, number<gcd(x, slice_size)>, number<x>>::value;
using dim_lengths = sequence<slice_length>;
using dim_slices = sequence<x / slice_length>;
using remaining_slice_sizes =
std::conditional_t<m, sequence<slice_size / slice_length>, sequence<slice_size>>;
    // the first idx at which the sliced length is not equal to the original length
static constexpr index_t _flag =
slice_length != x && remaining_slice_sizes{}.front().value == 1;
static constexpr index_t split_flag = std::conditional_t<m, number<_flag>, number<0>>::value;
static constexpr index_t split_idx =
std::conditional_t<split_flag, number<id>, number<0>>::value;
};
} // namespace impl
// clang-format off
// input: a sequence (with an optional mask) and SliceSize: the size per slice
// output the sequence each slice, and number of slices
//
// e.g. <2, 1, 4, 2>, 8 -> lengths:<1, 1, 4, 2> , nums: <2, 1, 1, 1> : 2 slices , slice_idx: 0
// <4, 2, 4, 1, 2>, 4 -> lengths:<1, 1, 2, 1, 2> , nums: <4, 2, 2, 1, 1> : 16 slices , slice_idx: 2
// <4, 2, 4, 1, 6>, 4 -> lengths:<1, 1, 2, 1, 2> , nums: <4, 2, 2, 1, 3> : 48 slices , slice_idx: 2
// <4, 2, 5, 1, 2>, 10 -> lengths:<1, 1, 5, 1, 2> , nums: <4, 2, 1, 1, 1> : 8 slices , slice_idx: 1
//
// <4, 2, 8>, 64 -> lengths:<4, 2, 8> , nums: <1, 1, 1> : 1 slices , slice_idx: 0
// <4, 2, 8>, 32 -> lengths:<2, 2, 8> , nums: <2, 1, 1> : 2 slices , slice_idx: 0
// <4, 2, 8>, 16 -> lengths:<1, 2, 8> , nums: <4, 1, 1> : 4 slices , slice_idx: 0
// <4, 2, 8>, 8 -> lengths:<1, 1, 8> , nums: <4, 2, 1> : 8 slices , slice_idx: 1
// <4, 2, 8>, 4 -> lengths:<1, 1, 4> , nums: <4, 2, 2> : 16 slices , slice_idx: 2
// <4, 2, 8>, 2 -> lengths:<1, 1, 2> , nums: <4, 2, 4> : 32 slices , slice_idx: 2
// <4, 2, 8>, 1 -> lengths:<1, 1, 1> , nums: <4, 2, 8> : 64 slices , slice_idx: 2
//
// <4, 2, 1, 4, 2> / 4 ->
// mask:<1, 1, 1, 0, 1>, -> lengths:<1, 2, 1, 4, 2> , nums: <4, 1, 1, 1, 1> : 8 slices , slice_idx: 0
//
// return tuple<slice_lengths, slice_nums, slice_index>, where slice_index is the index at
// which the slices start to split (scanning right -> left), i.e. the first index whose
// sliced length differs from the original length
// clang-format on
template <typename Seq,
index_t SliceSize,
typename Mask = typename uniform_sequence_gen<Seq::size(), 1>::type>
constexpr auto reverse_slice_sequence(Seq,
number<SliceSize>,
Mask = typename uniform_sequence_gen<Seq::size(), 1>::type{})
{
static_assert(Seq::size() == Mask::size());
using sliced_type =
impl::reverse_slice_sequence_impl<Seq,
Mask,
typename arithmetic_sequence_gen<0, Seq::size(), 1>::type,
SliceSize>;
static_assert(sliced_type::remaining_slice_sizes::front().value == 1,
"can not evenly divide this sequence, please check");
return make_tuple(typename sliced_type::dim_lengths{},
typename sliced_type::dim_slices{},
number<sliced_type::split_idx>{});
}
template <typename Seq,
index_t SliceSize,
typename Mask = typename uniform_sequence_gen<Seq::size(), 1>::type>
constexpr auto slice_sequence(Seq,
number<SliceSize>,
Mask = typename uniform_sequence_gen<Seq::size(), 1>::type{})
{
constexpr auto r =
reverse_slice_sequence(Seq{}.reverse(), number<SliceSize>{}, Mask{}.reverse());
return make_tuple(r[number<0>{}].reverse(),
r[number<1>{}].reverse(),
number<Seq::size() - r[number<2>{}] - 1>{});
}
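A compile-time usage sketch checking one row of the table above (my illustration; it assumes remove_cvref_t and <type_traits> are available here, as they are used elsewhere in the library):

// hypothetical sketch: <4, 2, 8> sliced by 8 -> lengths <1, 1, 8>, nums <4, 2, 1>, slice_idx 1
constexpr auto sliced = reverse_slice_sequence(sequence<4, 2, 8>{}, number<8>{});
static_assert(std::is_same_v<remove_cvref_t<decltype(sliced.at(number<0>{}))>,
                             sequence<1, 1, 8>>);
static_assert(std::is_same_v<remove_cvref_t<decltype(sliced.at(number<1>{}))>,
                             sequence<4, 2, 1>>);
static_assert(remove_cvref_t<decltype(sliced.at(number<2>{}))>::value == 1);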
} // namespace ck_tile
@@ -488,6 +488,26 @@ CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x, const Y& y,
        f, x, y, z, typename arithmetic_sequence_gen<0, X::size(), 1>::type{});
}
namespace detail {
template <typename F, typename X, index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto embed_tuples_impl(F f, const X& x, sequence<Is...>)
{
return concat_tuple(f(x.at(number<Is>{}))...);
}
} // namespace detail
// make sure F returns at least a tuple
// e.g. if x is tuple<X, Y> and f returns tuple<Z, W> for each element,
// this function will return tuple<Z, W, Z, W> (f applied per element, results concatenated)
template <typename F, typename X>
CK_TILE_HOST_DEVICE constexpr auto embed_tuples(F f, const X& x)
{
return detail::embed_tuples_impl(
f, x, typename arithmetic_sequence_gen<0, X::size(), 1>::type{});
}
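A usage sketch of embed_tuples (my illustration):

// hypothetical sketch: f maps each element to a 2-tuple, and the results are
// concatenated: (1, 2.0f) -> (1, 1, 2.0f, 2.0f)
constexpr auto embedded = embed_tuples([](auto e) { return make_tuple(e, e); },
                                       make_tuple(1, 2.0f));
static_assert(embedded.size() == 4);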
// By default, unroll to the fully flattened form
template <index_t Depth = 0, index_t MaxDepth = -1>
CK_TILE_HOST_DEVICE constexpr auto unroll_nested_tuple(const tuple<>& t)
@@ -603,7 +623,7 @@ template <typename... Ys,
          false>
CK_TILE_HOST_DEVICE constexpr auto operator+=(tuple<Ys...>& y, const X& x)
{
    static_assert(X::size() == sizeof...(Ys), "wrong! size not the same");
    constexpr index_t NSize = sizeof...(Ys);
    static_for<0, NSize, 1>{}([&](auto i) { y[i] += x[i]; });
    return y;
@@ -615,7 +635,7 @@ template <typename... Ys,
          false>
CK_TILE_HOST_DEVICE constexpr auto operator-=(tuple<Ys...>& y, const X& x)
{
    static_assert(X::size() == sizeof...(Ys), "wrong! size not the same");
    constexpr index_t NSize = sizeof...(Ys);
    static_for<0, NSize, 1>{}([&](auto i) { y[i] -= x[i]; });
    return y;
@@ -627,7 +647,7 @@ template <typename... Xs,
          false>
CK_TILE_HOST_DEVICE constexpr auto operator+(const tuple<Xs...>& x, const Y& y)
{
    static_assert(Y::size() == sizeof...(Xs), "wrong! size not the same");
    constexpr index_t NSize = sizeof...(Xs);

    tuple<Xs...> r;
@@ -635,13 +655,21 @@ CK_TILE_HOST_DEVICE constexpr auto operator+(const tuple<Xs...>& x, const Y& y)
    return r;
}
template <typename... Xs, typename... Ys>
CK_TILE_HOST_DEVICE constexpr auto operator+(const tuple<Xs...>& x, const tuple<Ys...>& y)
{
static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong!");
constexpr index_t NSize = sizeof...(Xs);
return generate_tuple([&](auto i) { return x[i] + y[i]; }, number<NSize>{});
}
template <typename... Xs,
          typename Y,
          std::enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> =
              false>
CK_TILE_HOST_DEVICE constexpr auto operator-(const tuple<Xs...>& x, const Y& y)
{
    static_assert(Y::size() == sizeof...(Xs), "wrong! size not the same");
    constexpr index_t NSize = sizeof...(Xs);

    tuple<Xs...> r;
@@ -649,13 +677,21 @@ CK_TILE_HOST_DEVICE constexpr auto operator-(const tuple<Xs...>& x, const Y& y)
    return r;
}
template <typename... Xs, typename... Ys>
CK_TILE_HOST_DEVICE constexpr auto operator-(const tuple<Xs...>& x, const tuple<Ys...>& y)
{
static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong!");
constexpr index_t NSize = sizeof...(Xs);
return generate_tuple([&](auto i) { return x[i] - y[i]; }, number<NSize>{});
}
template <typename... Xs,
          typename Y,
          std::enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> =
              false>
CK_TILE_HOST_DEVICE constexpr auto operator*(const tuple<Xs...>& x, const Y& y)
{
    static_assert(Y::size() == sizeof...(Xs), "wrong! size not the same");
    constexpr index_t NSize = sizeof...(Xs);

    tuple<Xs...> r;
@@ -686,6 +722,14 @@ CK_TILE_HOST_DEVICE constexpr auto operator*(const tuple<Xs...>& x, Y a)
    return a * x;
}
template <typename... Xs, typename... Ys>
CK_TILE_HOST_DEVICE constexpr auto operator*(const tuple<Xs...>& x, const tuple<Ys...>& y)
{
static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong!");
constexpr index_t NSize = sizeof...(Xs);
return generate_tuple([&](auto i) { return x[i] * y[i]; }, number<NSize>{});
}
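With the elementwise tuple/tuple overloads added above, tuples combine component by component; a small compile-time sketch (my illustration):

// hypothetical sketch: elementwise tuple arithmetic
constexpr auto lhs = make_tuple(4, 10, 18);
constexpr auto rhs = make_tuple(4, 5, 6);
static_assert((lhs + rhs).at(number<0>{}) == 8);  // new elementwise operator+
static_assert((lhs - rhs).at(number<0>{}) == 0);  // new elementwise operator-
static_assert((rhs * rhs).at(number<2>{}) == 36); // new elementwise operator*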
template <typename... Xs, typename... Ys>
CK_TILE_HOST_DEVICE constexpr auto operator/(const tuple<Xs...>& x, const tuple<Ys...>& y)
{
...
@@ -18,6 +18,7 @@ enum class bf16_rounding_mode
    truncate_with_nan,
    truncate,
    standard_asm,
rta_asm, // round to nearest away
};

template <bf16_rounding_mode rounding =
@@ -180,6 +181,39 @@ uint16_t float_to_bf16_rtn_asm(float f)
    return uint16_t(u.int32);
}
// TODO: do we need this on host?
CK_TILE_HOST
uint16_t float_to_bf16_rta_asm(float f) { return float_to_bf16_rtn_raw(f); }
CK_TILE_DEVICE
uint16_t float_to_bf16_rta_asm(float f)
{
union
{
float fp32;
struct
{
uint16_t lo;
uint16_t hi;
};
} u = {f};
const uint32_t low_nan = 0x7fff;
const uint32_t hi_nan = 0x7fff0000;
using uint32x2_t = uint32_t __attribute__((ext_vector_type(2)));
uint32x2_t check_nan;
asm volatile("v_cmp_u_f32 %[s_cnan], %[v_x], %[v_x] \n"
"v_add3_u32 %[v_x], %[v_x], %[v_blo], 1 \n"
"v_cndmask_b32 %[v_x], %[v_x], %[v_bhi], %[s_cnan]"
: [s_cnan] "+s"(check_nan), [v_x] "+v"(u.fp32)
: [v_blo] "v"(low_nan), [v_bhi] "v"(hi_nan));
    // Note: in the above code snippet, we use the high 16 bits
return u.hi;
}
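For reference, a portable sketch of the same rounding without inline asm (my illustration, not library code): v_add3_u32 adds 0x7fff + 1 = 0x8000 to the float bits, i.e. round-to-nearest-away on the truncated low 16 bits, and the v_cmp_u/v_cndmask pair substitutes a quiet-NaN pattern when the input is NaN.

// hypothetical reference for float->bf16 round-to-nearest-away
CK_TILE_HOST uint16_t float_to_bf16_rta_ref_sketch(float f)
{
    uint32_t bits = bit_cast<uint32_t>(f);
    if((bits & 0x7fffffffu) > 0x7f800000u) // NaN input
        return uint16_t(0x7fff);           // matches hi_nan >> 16 above
    bits += 0x8000u;                       // 0x7fff + 1, as in v_add3_u32
    return uint16_t(bits >> 16);
}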
// Truncate instead of rounding, preserving SNaN
CK_TILE_HOST_DEVICE
constexpr uint16_t float_to_bf16_truc_nan_raw(float f)
@@ -213,6 +247,8 @@ CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_raw(float f, constant<round
        return float_to_bf16_rtn_asm(f);
    else if constexpr(rounding == bf16_rounding_mode::truncate_with_nan)
        return float_to_bf16_truc_nan_raw(f);
else if constexpr(rounding == bf16_rounding_mode::rta_asm)
return float_to_bf16_rta_asm(f);
    else
        return float_to_bf16_truc_raw(f);
}
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/random.hpp"
#include <stdint.h>
#include <type_traits>
#pragma once
namespace ck_tile {
// use int8_t directly for int8 arithmetic
// here one can use ck_tile::int8_t to access the original int8_t
using int8_t = int8_t;
// limits
template <class T>
struct numeric;
template <>
struct numeric<int8_t>
{
// minimum finite value, or minimum positive normalized value for float
CK_TILE_HOST_DEVICE static constexpr int8_t min() { return int8_t(-128); }
    // minimum finite value
CK_TILE_HOST_DEVICE static constexpr int8_t lowest() { return int8_t(-128); }
// maximum finite value
CK_TILE_HOST_DEVICE static constexpr int8_t max() { return int8_t(127); }
// difference between 1.0 and next value representable by float
CK_TILE_HOST_DEVICE static constexpr int8_t epsilon()
{
return 1; // not used
}
CK_TILE_HOST_DEVICE static constexpr int8_t round_error()
{
return 1; // not used
}
// positive infinity value
CK_TILE_HOST_DEVICE static constexpr int8_t infinity()
{
return 1; // not used
}
// quiet NaN
CK_TILE_HOST_DEVICE static constexpr int8_t quiet_NaN()
{
return 1; // not used
}
// signaling NaN
CK_TILE_HOST_DEVICE static constexpr int8_t signaling_NaN()
{
return 1; // not used
}
// smallest positive subnormal value
CK_TILE_HOST_DEVICE static constexpr int8_t denorm_min()
{
return 1; // not used
}
CK_TILE_HOST_DEVICE static constexpr int8_t zero() { return 0; }
};
#if 0
template <typename T>
struct numeric_traits;
template <>
struct numeric_traits<int8_t>
{
static constexpr int exp = 5;
static constexpr int mant = 10;
static constexpr int bias = 15;
static constexpr uint16_t nan_mask = 0x7C00;
static constexpr uint16_t head_mask = 0xFC00;
static constexpr uint16_t mant_mask = 0x3FF;
static constexpr uint16_t exp_mask = 0x1F;
static constexpr uint32_t Inf = 0x7C00;
static constexpr uint32_t NegInf = 0xFC00;
static constexpr uint32_t NaN = 0x7C01;
static constexpr uint32_t Neg0 = 0x8000;
using bitwise_type = uint16_t;
};
#endif
CK_TILE_HOST_DEVICE
constexpr float int8_to_float(const int8_t& x) { return static_cast<float>(x); }
CK_TILE_HOST_DEVICE
constexpr int8_t float_to_int8(const float& x) { return static_cast<int8_t>(x); }
} // namespace ck_tile