Commit 4d914af3 authored by Jun Liu's avatar Jun Liu
Browse files

Merge branch 'amd-develop' into amd-master

parents 223a2abe 4b798833
...@@ -6,3 +6,8 @@ add_subdirectory(01_fmha) ...@@ -6,3 +6,8 @@ add_subdirectory(01_fmha)
add_subdirectory(02_layernorm2d) add_subdirectory(02_layernorm2d)
add_subdirectory(03_gemm) add_subdirectory(03_gemm)
add_subdirectory(04_img2col) add_subdirectory(04_img2col)
# NOTE(review): examples 07_* and 08_* are not added here — confirm the gap
# in the numbering is intentional (the directories may not exist yet).
add_subdirectory(05_reduce)
add_subdirectory(06_permute)
add_subdirectory(09_topk_softmax)
add_subdirectory(10_rmsnorm2d)
add_subdirectory(11_add_rmsnorm2d_rdquant)
...@@ -237,7 +237,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, ...@@ -237,7 +237,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
Args... args) Args... args)
{ {
#if CK_TIME_KERNEL #if CK_TIME_KERNEL
#define MEDIAN 1 #define MEDIAN 0
if(stream_config.time_kernel_) if(stream_config.time_kernel_)
{ {
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
...@@ -275,6 +275,14 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, ...@@ -275,6 +275,14 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
#else #else
float total_time = 0; float total_time = 0;
#endif #endif
hipEvent_t start, stop;
hip_check_error(hipEventCreate(&start));
hip_check_error(hipEventCreate(&stop));
hip_check_error(hipDeviceSynchronize());
hip_check_error(hipEventRecord(start, stream_config.stream_id_));
for(int i = 0; i < nrepeat; ++i) for(int i = 0; i < nrepeat; ++i)
{ {
if constexpr(!TimePreprocess) if constexpr(!TimePreprocess)
...@@ -282,13 +290,13 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, ...@@ -282,13 +290,13 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
preprocess(); preprocess();
} }
hipEvent_t start, stop; // hipEvent_t start, stop;
hip_check_error(hipEventCreate(&start)); // hip_check_error(hipEventCreate(&start));
hip_check_error(hipEventCreate(&stop)); // hip_check_error(hipEventCreate(&stop));
hip_check_error(hipDeviceSynchronize()); // hip_check_error(hipDeviceSynchronize());
hip_check_error(hipEventRecord(start, stream_config.stream_id_)); // hip_check_error(hipEventRecord(start, stream_config.stream_id_));
// calculate preprocess time // calculate preprocess time
if constexpr(TimePreprocess) if constexpr(TimePreprocess)
{ {
...@@ -299,25 +307,34 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, ...@@ -299,25 +307,34 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
hip_check_error(hipGetLastError()); hip_check_error(hipGetLastError());
// end real kernel // end real kernel
hip_check_error(hipEventRecord(stop, stream_config.stream_id_)); // hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
hip_check_error(hipEventSynchronize(stop)); // hip_check_error(hipEventSynchronize(stop));
float cur_time = 0; // float cur_time = 0;
hip_check_error(hipEventElapsedTime(&cur_time, start, stop)); // hip_check_error(hipEventElapsedTime(&cur_time, start, stop));
#if MEDIAN // #if MEDIAN
times.insert(cur_time); // times.insert(cur_time);
#else // #else
total_time += cur_time; // total_time += cur_time;
#endif // #endif
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{ {
std::cout << "i: " << i << " cur_time: " << cur_time << std::endl; // std::cout << "i: " << i << " cur_time: " << cur_time << std::endl;
printf("gemm_args.p_a_grid: %p, gemm_args.p_b_grid:%p\n", printf("gemm_args.p_a_grid: %p, gemm_args.p_b_grid:%p\n",
static_cast<const void*>(gemm_args.p_a_grid), static_cast<const void*>(gemm_args.p_a_grid),
static_cast<const void*>(gemm_args.p_b_grid)); static_cast<const void*>(gemm_args.p_b_grid));
} }
} }
hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
hip_check_error(hipEventSynchronize(stop));
float cur_time = 0;
hip_check_error(hipEventElapsedTime(&cur_time, start, stop));
#if MEDIAN
times.insert(cur_time);
#else
total_time += cur_time;
#endif
#if MEDIAN #if MEDIAN
auto mid = times.begin(); auto mid = times.begin();
...@@ -333,7 +350,11 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, ...@@ -333,7 +350,11 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
return (*mid + *mid_next) / 2; return (*mid + *mid_next) / 2;
} }
#else #else
return total_time / nrepeat; // return total_time / nrepeat;
hipDeviceProp_t deviceProps;
hip_check_error(hipGetDeviceProperties(&deviceProps, 0));
float preprocess_offset = deviceProps.multiProcessorCount == 80 ? 0.005 : 0.01;
return (total_time - preprocess_offset * nrepeat) / nrepeat;
#endif #endif
} }
else else
......
...@@ -352,7 +352,7 @@ struct BlockwiseGemmWMMA ...@@ -352,7 +352,7 @@ struct BlockwiseGemmWMMA
constexpr index_t c_offset = constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
wmma_gemm.template Run( wmma_gemm.template Run<>(
a_thread_vec.template AsType<wmma_input_type_a>(), a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(), b_thread_vec.template AsType<wmma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})); c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
...@@ -406,7 +406,7 @@ struct BlockwiseGemmWMMA ...@@ -406,7 +406,7 @@ struct BlockwiseGemmWMMA
constexpr index_t c_offset = constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
wmma_gemm.template Run( wmma_gemm.template Run<>(
a_thread_vec.template AsType<wmma_input_type_a>(), a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(), b_thread_vec.template AsType<wmma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})); c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
......
...@@ -85,9 +85,9 @@ __global__ void ...@@ -85,9 +85,9 @@ __global__ void
BsPointer p_bs_grid, BsPointer p_bs_grid,
DsPointer p_ds_grid, DsPointer p_ds_grid,
EDataType* __restrict__ p_e_grid, EDataType* __restrict__ p_e_grid,
const AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op, CDEElementwiseOperation cde_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1, const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1,
const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1, const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
...@@ -121,6 +121,19 @@ __global__ void ...@@ -121,6 +121,19 @@ __global__ void
static_for<0, NumDTensor, 1>{}( static_for<0, NumDTensor, 1>{}(
[&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_group_offset[i]; }); [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_group_offset[i]; });
if constexpr(is_same_v<AElementwiseOperation, element_wise::DynamicUnaryOp>)
{
a_element_op.InitUnaryOpPtrOnDevice();
}
if constexpr(is_same_v<BElementwiseOperation, element_wise::DynamicUnaryOp>)
{
b_element_op.InitUnaryOpPtrOnDevice();
}
if constexpr(is_same_v<CDEElementwiseOperation, element_wise::DynamicUnaryOp>)
{
cde_element_op.InitUnaryOpPtrOnDevice();
}
if constexpr(isMultiA || isMultiB) if constexpr(isMultiA || isMultiB)
{ {
AsPointer p_as_grid_grp; AsPointer p_as_grid_grp;
......
...@@ -272,6 +272,26 @@ struct MultiplyMultiply ...@@ -272,6 +272,26 @@ struct MultiplyMultiply
e = ck::type_convert<ck::bhalf_t>(x0_f); e = ck::type_convert<ck::bhalf_t>(x0_f);
} }
// e = c * d0 * d1, accumulating the int32 value and both fp16 scales in fp32
// before rounding back down to fp16.
template <>
__host__ __device__ constexpr void operator()<ck::half_t, int, ck::half_t, ck::half_t>(
    ck::half_t& e, const int& c, const ck::half_t& d0, const ck::half_t& d1) const
{
    const float acc    = ck::type_convert<float>(c);
    const float scaled = acc * ck::type_convert<float>(d0) * ck::type_convert<float>(d1);
    e                  = ck::type_convert<ck::half_t>(scaled);
}
// e = c * d0 * d1 with an int32 accumulator and fp32 scales; the product is
// formed in fp32 and rounded to bf16 at the end.
template <>
__host__ __device__ constexpr void operator()<ck::bhalf_t, int, float, float>(
    ck::bhalf_t& e, const int& c, const float& d0, const float& d1) const
{
    const float acc     = ck::type_convert<float>(c);
    const float product = acc * ck::type_convert<float>(d0) * ck::type_convert<float>(d1);
    e                   = ck::type_convert<ck::bhalf_t>(product);
}
}; };
struct MultiplyAddFastGelu struct MultiplyAddFastGelu
...@@ -385,7 +405,7 @@ struct ScaleAddScaleAddRelu ...@@ -385,7 +405,7 @@ struct ScaleAddScaleAddRelu
const float& d1) const const float& d1) const
{ {
const float x = c * alpha1_ + alpha2_ * d0 + d1; const float x = c * alpha1_ + alpha2_ * d0 + d1;
Relu{}.template operator()<float>(e, x); e = x > 0 ? x : 0;
} }
template <> template <>
...@@ -396,7 +416,7 @@ struct ScaleAddScaleAddRelu ...@@ -396,7 +416,7 @@ struct ScaleAddScaleAddRelu
type_convert<float>(d1); type_convert<float>(d1);
float result = 0; float result = 0;
Relu{}.template operator()<float>(result, x); result = x > 0 ? x : 0;
e = type_convert<half_t>(result); e = type_convert<half_t>(result);
} }
...@@ -409,7 +429,7 @@ struct ScaleAddScaleAddRelu ...@@ -409,7 +429,7 @@ struct ScaleAddScaleAddRelu
type_convert<float>(d1); type_convert<float>(d1);
float result = 0; float result = 0;
Relu{}.template operator()<float>(result, x); result = x > 0 ? x : 0;
e = type_convert<bhalf_t>(result); e = type_convert<bhalf_t>(result);
} }
...@@ -421,7 +441,7 @@ struct ScaleAddScaleAddRelu ...@@ -421,7 +441,7 @@ struct ScaleAddScaleAddRelu
const float x = type_convert<float>(c) * alpha1_ + alpha2_ * d0 + d1; const float x = type_convert<float>(c) * alpha1_ + alpha2_ * d0 + d1;
float result = 0; float result = 0;
Relu{}.template operator()<float>(result, x); result = x > 0 ? x : 0;
e = type_convert<int8_t>(result); e = type_convert<int8_t>(result);
} }
......
...@@ -7,11 +7,38 @@ ...@@ -7,11 +7,38 @@
#include "ck/utility/math.hpp" #include "ck/utility/math.hpp"
#include "ck/utility/math_v2.hpp" #include "ck/utility/math_v2.hpp"
#include "ck/utility/type_convert.hpp" #include "ck/utility/type_convert.hpp"
#include <cassert>
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace element_wise { namespace element_wise {
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wnon-virtual-dtor"
// Abstract interface for elementwise unary operators that are dispatched at
// runtime through a base-class pointer — one pure-virtual overload per
// supported element type.
struct UnaryOpBase
{
    public:
    // NOTE(review): destructor is deliberately NON-virtual on this polymorphic
    // base (the -Wnon-virtual-dtor pragma above silences the warning).
    // Deleting a derived op through a UnaryOpBase* would be undefined
    // behavior; assumes ops are only ever held/passed by value — confirm.
    __host__ __device__ ~UnaryOpBase() = default;
    __host__ __device__ constexpr UnaryOpBase() = default;
    __host__ __device__ constexpr UnaryOpBase(const UnaryOpBase&) = default;
    __host__ __device__ constexpr UnaryOpBase(UnaryOpBase&&) = default;
    __host__ __device__ UnaryOpBase& operator=(const UnaryOpBase&) = default;
    __host__ __device__ UnaryOpBase& operator=(UnaryOpBase&&) = default;
    // y = op(x); derived classes must implement every type overload.
    __host__ __device__ virtual inline void operator()(float& y, const float& x) const = 0;
    __host__ __device__ virtual inline void operator()(double& y, const double& x) const = 0;
    __host__ __device__ virtual inline void operator()(int32_t& y, const int32_t& x) const = 0;
    __host__ __device__ virtual inline void operator()(int8_t& y, const int8_t& x) const = 0;
    __host__ __device__ virtual inline void operator()(half_t& y, const half_t& x) const = 0;
    __host__ __device__ virtual inline void operator()(bhalf_t& y, const bhalf_t& x) const = 0;
};
struct PassThroughPack2 struct PassThroughPack2
{ {
template <typename Y, typename X> template <typename Y, typename X>
...@@ -25,17 +52,30 @@ struct PassThroughPack2 ...@@ -25,17 +52,30 @@ struct PassThroughPack2
constexpr const static bool is_pack2_invocable = true; constexpr const static bool is_pack2_invocable = true;
}; };
struct PassThrough struct PassThrough final : public UnaryOpBase
{ {
__host__ __device__ constexpr PassThrough() = default;
__host__ __device__ constexpr PassThrough(const PassThrough&) = default;
__host__ __device__ constexpr PassThrough(PassThrough&&) = default;
__host__ __device__ PassThrough& operator=(const PassThrough&) = default;
__host__ __device__ PassThrough& operator=(PassThrough&&) = default;
__host__ __device__ ~PassThrough() = default;
__host__ __device__ inline void operator()(float& y, const float& x) const final { y = x; }
__host__ __device__ inline void operator()(double& y, const double& x) const final { y = x; }
__host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final { y = x; }
__host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final { y = x; }
__host__ __device__ inline void operator()(half_t& y, const half_t& x) const final { y = x; }
__host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final { y = x; }
template <typename Y, typename X> template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const; __host__ __device__ void operator()(Y& y, const X& x) const;
// Identity pass-through for double -> double; no conversion needed.
template <>
__host__ __device__ void operator()<double, double>(double& y, const double& x) const
{
    y = x;
}
template <> template <>
__host__ __device__ void operator()<float, double>(float& y, const double& x) const __host__ __device__ void operator()<float, double>(float& y, const double& x) const
{ {
...@@ -48,36 +88,12 @@ struct PassThrough ...@@ -48,36 +88,12 @@ struct PassThrough
y = type_convert<double>(x); y = type_convert<double>(x);
} }
template <>
__host__ __device__ void operator()<float, float>(float& y, const float& x) const
{
y = x;
}
template <>
__host__ __device__ void operator()<half_t, half_t>(half_t& y, const half_t& x) const
{
y = x;
}
template <> template <>
__host__ __device__ void operator()<half_t, float>(half_t& y, const float& x) const __host__ __device__ void operator()<half_t, float>(half_t& y, const float& x) const
{ {
y = type_convert<half_t>(x); y = type_convert<half_t>(x);
} }
template <>
__host__ __device__ void operator()<bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x) const
{
y = x;
}
template <>
__host__ __device__ void operator()<int32_t, int32_t>(int32_t& y, const int32_t& x) const
{
y = x;
}
template <> template <>
__host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const __host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
{ {
...@@ -102,12 +118,6 @@ struct PassThrough ...@@ -102,12 +118,6 @@ struct PassThrough
y = type_convert<float>(x); y = type_convert<float>(x);
} }
template <>
__host__ __device__ void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const
{
y = x;
}
template <> template <>
__host__ __device__ void operator()<half_t, int8_t>(half_t& y, const int8_t& x) const __host__ __device__ void operator()<half_t, int8_t>(half_t& y, const int8_t& x) const
{ {
...@@ -407,20 +417,45 @@ struct UnarySquare ...@@ -407,20 +417,45 @@ struct UnarySquare
}; };
}; };
struct UnaryAbs struct UnaryAbs final : public UnaryOpBase
{ {
template <typename T> __host__ __device__ constexpr UnaryAbs() = default;
__host__ __device__ void operator()(T& y, const T& x) const __host__ __device__ constexpr UnaryAbs(const UnaryAbs&) = default;
__host__ __device__ constexpr UnaryAbs(UnaryAbs&&) = default;
__host__ __device__ UnaryAbs& operator=(const UnaryAbs&) = default;
__host__ __device__ UnaryAbs& operator=(UnaryAbs&&) = default;
__host__ __device__ ~UnaryAbs() = default;
__host__ __device__ inline void operator()(float& y, const float& x) const final
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value || y = ck::math::abs(x);
is_same<T, half_t>::value || is_same<T, int32_t>::value || }
is_same<T, int8_t>::value,
"Data type is not supported by this operation!");
__host__ __device__ inline void operator()(double& y, const double& x) const final
{
y = ck::math::abs(x); y = ck::math::abs(x);
}; }
__host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final
{
y = ck::math::abs(x);
}
__host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final
{
y = ck::math::abs(x);
}
__host__ __device__ inline void operator()(half_t& y, const half_t& x) const final
{
y = ck::math::abs(x);
}
__host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final
{
y = ck::math::abs(x);
}
template <>
__host__ __device__ void operator()(f8_t& y, const f8_t& x) const __host__ __device__ void operator()(f8_t& y, const f8_t& x) const
{ {
y = ck::type_convert<f8_t>(ck::math::abs(ck::type_convert<float>(x))); y = ck::type_convert<f8_t>(ck::math::abs(ck::type_convert<float>(x)));
...@@ -439,20 +474,41 @@ struct UnarySqrt ...@@ -439,20 +474,41 @@ struct UnarySqrt
}; };
}; };
struct Relu struct Relu final : public UnaryOpBase
{ {
template <typename T> __host__ __device__ constexpr Relu() = default;
__host__ __device__ void operator()(T& y, const T& x) const __host__ __device__ constexpr Relu(const Relu&) = default;
__host__ __device__ constexpr Relu(Relu&&) = default;
__host__ __device__ Relu& operator=(const Relu&) = default;
__host__ __device__ Relu& operator=(Relu&&) = default;
__host__ __device__ ~Relu() = default;
__host__ __device__ inline void operator()(float& y, const float& x) const final
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"Data type is not supported by this operation!");
y = x > 0 ? x : 0; y = x > 0 ? x : 0;
} }
template <> __host__ __device__ inline void operator()(double& y, const double& x) const final
__host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const {
y = x > 0 ? x : 0;
}
__host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final
{
y = x > 0 ? x : 0;
}
__host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final
{
y = x > 0 ? x : 0;
}
__host__ __device__ inline void operator()(half_t& y, const half_t& x) const final
{
y = x > 0 ? x : 0;
}
__host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final
{ {
float x_f32 = ck::type_convert<float>(x); float x_f32 = ck::type_convert<float>(x);
float y_f32 = x_f32 > 0 ? x_f32 : 0; float y_f32 = x_f32 > 0 ? x_f32 : 0;
...@@ -599,18 +655,52 @@ struct Gelu ...@@ -599,18 +655,52 @@ struct Gelu
} }
}; };
struct Sigmoid struct Sigmoid final : public UnaryOpBase
{ {
template <typename T> __host__ __device__ constexpr Sigmoid() = default;
__host__ __device__ void operator()(T& y, const T& x) const __host__ __device__ constexpr Sigmoid(const Sigmoid&) = default;
__host__ __device__ constexpr Sigmoid(Sigmoid&&) = default;
__host__ __device__ Sigmoid& operator=(const Sigmoid&) = default;
__host__ __device__ Sigmoid& operator=(Sigmoid&&) = default;
__host__ __device__ ~Sigmoid() = default;
__host__ __device__ inline void operator()(float& y, const float& x) const final
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value || constexpr float one = type_convert<float>(1);
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value || y = one / (one + ck::math::exp(-x));
is_same<T, int32_t>::value, }
"Data type is not supported by this operation!");
constexpr T one = type_convert<T>(1); __host__ __device__ inline void operator()(double& y, const double& x) const final
y = one / (one + ck::math::exp(-x)); {
}; constexpr double one = type_convert<double>(1);
y = one / (one + ck::math::exp(-x));
}
__host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final
{
constexpr int32_t one = type_convert<int32_t>(1);
y = one / (one + ck::math::exp(-x));
}
__host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final
{
constexpr int8_t one = type_convert<int8_t>(1);
y = one / (one + ck::math::exp(-x));
}
__host__ __device__ inline void operator()(half_t& y, const half_t& x) const final
{
constexpr half_t one = type_convert<half_t>(1);
y = one / (one + ck::math::exp(-x));
}
__host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final
{
    // Compute in fp32 (bf16 lacks direct arithmetic here):
    //   sigmoid(x) = 1 / (1 + exp(-x)).
    // BUG FIX: the previous code used exp(x_f32) — i.e. it computed
    // sigmoid(-x) — which disagreed with every other overload of this op.
    constexpr float one = type_convert<float>(1);
    const float x_f32   = ck::type_convert<float>(x);
    const float y_f32   = one / (one + ck::math::exp(-x_f32));
    y                   = ck::type_convert<bhalf_t>(y_f32);
}
}; };
struct Silu struct Silu
...@@ -626,18 +716,44 @@ struct Silu ...@@ -626,18 +716,44 @@ struct Silu
}; };
}; };
struct TanH struct TanH final : public UnaryOpBase
{ {
template <typename T> __host__ __device__ constexpr TanH() = default;
__host__ __device__ void operator()(T& y, const T& x) const __host__ __device__ constexpr TanH(const TanH&) = default;
__host__ __device__ constexpr TanH(TanH&&) = default;
__host__ __device__ TanH& operator=(const TanH&) = default;
__host__ __device__ TanH& operator=(TanH&&) = default;
__host__ __device__ ~TanH() = default;
__host__ __device__ inline void operator()(float& y, const float& x) const final
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value || y = ck::math::tanh(x);
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value || }
is_same<T, int32_t>::value,
"Data type is not supported by this operation!");
__host__ __device__ inline void operator()(double& y, const double& x) const final
{
y = ck::math::tanh(x); y = ck::math::tanh(x);
}; }
__host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final
{
y = ck::math::tanh(x);
}
__host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final
{
y = ck::math::tanh(x);
}
__host__ __device__ inline void operator()(half_t& y, const half_t& x) const final
{
y = ck::math::tanh(x);
}
__host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final
{
y = ck::math::tanh(x);
}
}; };
struct ACos struct ACos
...@@ -878,138 +994,418 @@ struct Rcp ...@@ -878,138 +994,418 @@ struct Rcp
}; };
}; };
struct Swish struct Swish final : public UnaryOpBase
{ {
Swish(float beta = 1.0f) : beta_(beta) {} __host__ __device__ constexpr Swish(const Swish&) = default;
__host__ __device__ constexpr Swish(Swish&&) = default;
__host__ __device__ ~Swish() = default;
__host__ __device__ Swish(float beta = 1.0f) : beta_(beta) {}
__host__ __device__ float get_beta() const { return beta_; }
const float beta_;
// swish(x) = x * sigmoid(beta * x) = x / (1 + exp(-beta * x)), fp32 path.
__host__ __device__ inline void operator()(float& y, const float& x) const final
{
    const float neg_scaled = -beta_ * type_convert<float>(x);
    const float denom      = 1.f + ck::math::exp(neg_scaled);
    y                      = type_convert<float>(x / denom);
}
__host__ __device__ inline void operator()(double& y, const double& x) const final
{
float bx = -beta_ * type_convert<float>(x);
y = type_convert<double>(x / (1.f + ck::math::exp(bx)));
}
__host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final
{
float bx = -beta_ * type_convert<float>(x);
y = type_convert<int32_t>(x / (1.f + ck::math::exp(bx)));
}
__host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final
{
float bx = -beta_ * type_convert<float>(x);
y = type_convert<int8_t>(x / (1.f + ck::math::exp(bx)));
}
__host__ __device__ inline void operator()(half_t& y, const half_t& x) const final
{
float bx = -beta_ * type_convert<float>(x);
y = type_convert<half_t>(x / (1.f + ck::math::exp(bx)));
}
__host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final
{
float bx = -beta_ * type_convert<float>(x);
y = type_convert<bhalf_t>(x / (1.f + ck::math::exp(bx)));
}
template <typename Y, typename X> template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const __host__ __device__ void operator()(Y& y, const X& x) const
{ {
static_assert(is_same<X, float>::value || is_same<X, double>::value || static_assert(is_same<X, float>::value || is_same<X, double>::value ||
is_same<X, ck::half_t>::value, is_same<X, half_t>::value,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
static_assert(is_same<Y, float>::value || is_same<Y, double>::value || static_assert(is_same<Y, float>::value || is_same<Y, double>::value ||
is_same<Y, ck::half_t>::value, is_same<Y, half_t>::value,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
float bx = -beta_ * type_convert<float>(x); float bx = -beta_ * type_convert<float>(x);
y = type_convert<Y>(x / (1.f + ck::math::exp(bx))); y = type_convert<Y>(x / (1.f + ck::math::exp(bx)));
}; }
const float beta_;
}; };
struct SoftRelu struct SoftRelu final : public UnaryOpBase
{ {
SoftRelu(float alpha = 1.f) : alpha_(alpha){}; __host__ __device__ constexpr SoftRelu(const SoftRelu&) = default;
__host__ __device__ constexpr SoftRelu(SoftRelu&&) = default;
__host__ __device__ ~SoftRelu() = default;
template <typename T> __host__ __device__ SoftRelu(float alpha = 1.0f) : alpha_(alpha) {}
__host__ __device__ void operator()(T& y, const T& x) const
__host__ __device__ float get_alpha() const { return alpha_; }
const float alpha_;
__host__ __device__ inline void operator()(float& y, const float& x) const final
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value || float casted_alpha = type_convert<float>(alpha_);
is_same<T, half_t>::value || is_same<T, int32_t>::value || constexpr float one = type_convert<float>(1);
is_same<T, int8_t>::value, y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha;
"Data type is not supported by this operation!"); }
T casted_alpha = type_convert<T>(alpha_);
constexpr T one = type_convert<T>(1); __host__ __device__ inline void operator()(double& y, const double& x) const final
y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha; {
double casted_alpha = type_convert<double>(alpha_);
constexpr double one = type_convert<double>(1);
y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha;
}
__host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final
{
int32_t casted_alpha = type_convert<int32_t>(alpha_);
constexpr int32_t one = type_convert<int32_t>(1);
y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha;
}
__host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final
{
int8_t casted_alpha = type_convert<int8_t>(alpha_);
constexpr int8_t one = type_convert<int8_t>(1);
y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha;
}
__host__ __device__ inline void operator()(half_t& y, const half_t& x) const final
{
half_t casted_alpha = type_convert<half_t>(alpha_);
constexpr half_t one = type_convert<half_t>(1);
y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha;
}
__host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final
{
bhalf_t casted_alpha = type_convert<bhalf_t>(alpha_);
constexpr bhalf_t one = type_convert<bhalf_t>(1);
y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha;
} }
const float alpha_;
}; };
struct Power struct Power final : public UnaryOpBase
{ {
Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f) __host__ __device__ constexpr Power(const Power&) = default;
: alpha_(alpha), beta_(beta), gamma_(gamma){}; __host__ __device__ constexpr Power(Power&&) = default;
__host__ __device__ ~Power() = default;
template <typename T> __host__ __device__ Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f)
__host__ __device__ void operator()(T& y, const T& x) const : alpha_(alpha), beta_(beta), gamma_(gamma)
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_);
T casted_beta = type_convert<T>(beta_);
T casted_gamma = type_convert<T>(gamma_);
T shifted_scaled_x = casted_alpha + casted_beta * x;
y = ck::math::pow(shifted_scaled_x, casted_gamma);
} }
__host__ __device__ float get_alpha() const { return alpha_; }
__host__ __device__ float get_beta() const { return beta_; }
__host__ __device__ float get_gamma() const { return gamma_; }
const float alpha_; const float alpha_;
const float beta_; const float beta_;
const float gamma_; const float gamma_;
// y = (alpha + beta * x) ^ gamma, evaluated entirely in fp32.
__host__ __device__ inline void operator()(float& y, const float& x) const final
{
    const float base     = type_convert<float>(alpha_) + type_convert<float>(beta_) * x;
    const float exponent = type_convert<float>(gamma_);
    y                    = ck::math::pow(base, exponent);
}
__host__ __device__ inline void operator()(double& y, const double& x) const final
{
double casted_alpha = type_convert<double>(alpha_);
double casted_beta = type_convert<double>(beta_);
double casted_gamma = type_convert<double>(gamma_);
double shifted_scaled_x = casted_alpha + casted_beta * x;
y = ck::math::pow(shifted_scaled_x, casted_gamma);
}
__host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final
{
int32_t casted_alpha = type_convert<int32_t>(alpha_);
int32_t casted_beta = type_convert<int32_t>(beta_);
int32_t casted_gamma = type_convert<int32_t>(gamma_);
int32_t shifted_scaled_x = casted_alpha + casted_beta * x;
y = ck::math::pow(shifted_scaled_x, casted_gamma);
}
__host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final
{
int8_t casted_alpha = type_convert<int8_t>(alpha_);
int8_t casted_beta = type_convert<int8_t>(beta_);
int8_t casted_gamma = type_convert<int8_t>(gamma_);
int8_t shifted_scaled_x = casted_alpha + casted_beta * x;
y = ck::math::pow(shifted_scaled_x, casted_gamma);
}
__host__ __device__ inline void operator()(half_t& y, const half_t& x) const final
{
half_t casted_alpha = type_convert<half_t>(alpha_);
half_t casted_beta = type_convert<half_t>(beta_);
half_t casted_gamma = type_convert<half_t>(gamma_);
half_t shifted_scaled_x = casted_alpha + casted_beta * x;
y = ck::math::pow(shifted_scaled_x, casted_gamma);
}
__host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final
{
bhalf_t casted_alpha = type_convert<bhalf_t>(alpha_);
bhalf_t casted_beta = type_convert<bhalf_t>(beta_);
bhalf_t casted_gamma = type_convert<bhalf_t>(gamma_);
bhalf_t shifted_scaled_x = casted_alpha + casted_beta * x;
y = ck::math::pow(shifted_scaled_x, casted_gamma);
}
}; };
struct ClippedRelu struct ClippedRelu final : public UnaryOpBase
{ {
ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta){}; __host__ __device__ constexpr ClippedRelu(const ClippedRelu&) = default;
__host__ __device__ constexpr ClippedRelu(ClippedRelu&&) = default;
__host__ __device__ ~ClippedRelu() = default;
template <typename T> __host__ __device__ ClippedRelu(float alpha = 0.f, float beta = 1.f)
__host__ __device__ void operator()(T& y, const T& x) const : alpha_(alpha), beta_(beta)
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_);
T casted_beta = type_convert<T>(beta_);
y = ck::math::min(casted_beta, ck::math::max(casted_alpha, x));
} }
__host__ __device__ float get_alpha() const { return alpha_; }
__host__ __device__ float get_beta() const { return beta_; }
const float alpha_; const float alpha_;
const float beta_; const float beta_;
// Clamp x into [alpha, beta] (fp32 overload).
__host__ __device__ inline void operator()(float& y, const float& x) const final
{
    const float lo = type_convert<float>(alpha_);
    const float hi = type_convert<float>(beta_);
    const float raised = ck::math::max(lo, x);
    y = ck::math::min(hi, raised);
}
__host__ __device__ inline void operator()(double& y, const double& x) const final
{
double casted_alpha = type_convert<double>(alpha_);
double casted_beta = type_convert<double>(beta_);
y = ck::math::min(casted_beta, ck::math::max(casted_alpha, x));
}
__host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final
{
int32_t casted_alpha = type_convert<int32_t>(alpha_);
int32_t casted_beta = type_convert<int32_t>(beta_);
y = ck::math::min(casted_beta, ck::math::max(casted_alpha, x));
}
__host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final
{
int8_t casted_alpha = type_convert<int8_t>(alpha_);
int8_t casted_beta = type_convert<int8_t>(beta_);
y = ck::math::min(casted_beta, ck::math::max(casted_alpha, x));
}
__host__ __device__ inline void operator()(half_t& y, const half_t& x) const final
{
half_t casted_alpha = type_convert<half_t>(alpha_);
half_t casted_beta = type_convert<half_t>(beta_);
y = ck::math::min(casted_beta, ck::math::max(casted_alpha, x));
}
__host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final
{
bhalf_t casted_alpha = type_convert<bhalf_t>(alpha_);
bhalf_t casted_beta = type_convert<bhalf_t>(beta_);
y = ck::math::min(casted_beta, ck::math::max(casted_alpha, x));
}
}; };
struct LeakyRelu struct LeakyRelu final : public UnaryOpBase
{ {
LeakyRelu(float alpha = 0.01f) : alpha_(alpha){}; __host__ __device__ constexpr LeakyRelu(const LeakyRelu&) = default;
__host__ __device__ constexpr LeakyRelu(LeakyRelu&&) = default;
__host__ __device__ ~LeakyRelu() = default;
template <typename T> __host__ __device__ LeakyRelu(float alpha = 0.f) : alpha_(alpha) {}
__host__ __device__ void operator()(T& y, const T& x) const
__host__ __device__ float get_alpha() const { return alpha_; }
const float alpha_;
__host__ __device__ inline void operator()(float& y, const float& x) const final
{
float casted_alpha = type_convert<float>(alpha_);
y = x >= 0 ? x : x * casted_alpha;
}
__host__ __device__ inline void operator()(double& y, const double& x) const final
{
double casted_alpha = type_convert<double>(alpha_);
y = x >= 0 ? x : x * casted_alpha;
}
__host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final
{
int32_t casted_alpha = type_convert<int32_t>(alpha_);
y = x >= 0 ? x : x * casted_alpha;
}
__host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final
{
int8_t casted_alpha = type_convert<int8_t>(alpha_);
y = x >= 0 ? x : x * casted_alpha;
}
__host__ __device__ inline void operator()(half_t& y, const half_t& x) const final
{
half_t casted_alpha = type_convert<half_t>(alpha_);
y = x >= 0 ? x : x * casted_alpha;
}
__host__ __device__ inline void operator()([[maybe_unused]] bhalf_t& y,
[[maybe_unused]] const bhalf_t& x) const final
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_);
y = x >= 0 ? x : x * casted_alpha;
} }
const float alpha_;
}; };
struct Elu struct Elu final : public UnaryOpBase
{ {
Elu(float alpha = 1.f) : alpha_(alpha){}; __host__ __device__ constexpr Elu(const Elu&) = default;
__host__ __device__ constexpr Elu(Elu&&) = default;
__host__ __device__ ~Elu() = default;
template <typename T> __host__ __device__ Elu(float alpha = 1.f) : alpha_(alpha) {}
__host__ __device__ void operator()(T& y, const T& x) const
__host__ __device__ float get_alpha() const { return alpha_; }
const float alpha_;
__host__ __device__ inline void operator()(float& y, const float& x) const final
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value || float casted_alpha = type_convert<float>(alpha_);
is_same<T, half_t>::value || is_same<T, int32_t>::value || y = x > 0 ? x : casted_alpha * ck::math::expm1(x);
is_same<T, int8_t>::value, }
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_); __host__ __device__ inline void operator()(double& y, const double& x) const final
y = x > 0 ? x : casted_alpha * ck::math::expm1(x); {
double casted_alpha = type_convert<double>(alpha_);
y = x > 0 ? x : casted_alpha * ck::math::expm1(x);
}
__host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final
{
int32_t casted_alpha = type_convert<int32_t>(alpha_);
y = x > 0 ? x : casted_alpha * ck::math::expm1(x);
}
__host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final
{
int8_t casted_alpha = type_convert<int8_t>(alpha_);
y = x > 0 ? x : casted_alpha * ck::math::expm1(x);
}
__host__ __device__ inline void operator()(half_t& y, const half_t& x) const final
{
half_t casted_alpha = type_convert<half_t>(alpha_);
y = x > 0 ? x : casted_alpha * ck::math::expm1(x);
}
__host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final
{
bhalf_t casted_alpha = type_convert<bhalf_t>(alpha_);
y = x > 0 ? x : casted_alpha * ck::math::expm1(x);
} }
const float alpha_;
}; };
struct Logistic struct Logistic final : public UnaryOpBase
{ {
Logistic(float alpha = 1.f) : alpha_(alpha){}; __host__ __device__ constexpr Logistic(const Logistic&) = default;
__host__ __device__ constexpr Logistic(Logistic&&) = default;
__host__ __device__ ~Logistic() = default;
template <typename T> __host__ __device__ Logistic(float alpha = 1.0f) : alpha_(alpha) {}
__host__ __device__ void operator()(T& y, const T& x) const
__host__ __device__ float get_alpha() const { return alpha_; }
const float alpha_;
__host__ __device__ inline void operator()(float& y, const float& x) const final
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value || float casted_alpha = type_convert<float>(alpha_);
is_same<T, half_t>::value || is_same<T, int32_t>::value || constexpr float one = type_convert<float>(1);
is_same<T, int8_t>::value, y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha);
"Data type is not supported by this operation!"); }
T casted_alpha = type_convert<T>(alpha_);
constexpr T one = type_convert<T>(1); __host__ __device__ inline void operator()(double& y, const double& x) const final
y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha); {
double casted_alpha = type_convert<double>(alpha_);
constexpr double one = type_convert<double>(1);
y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha);
}
__host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final
{
int32_t casted_alpha = type_convert<int32_t>(alpha_);
constexpr int32_t one = type_convert<int32_t>(1);
y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha);
}
__host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final
{
int8_t casted_alpha = type_convert<int8_t>(alpha_);
constexpr int8_t one = type_convert<int8_t>(1);
y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha);
}
__host__ __device__ inline void operator()(half_t& y, const half_t& x) const final
{
half_t casted_alpha = type_convert<half_t>(alpha_);
constexpr half_t one = type_convert<half_t>(1);
y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha);
}
__host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final
{
bhalf_t casted_alpha = type_convert<bhalf_t>(alpha_);
constexpr bhalf_t one = type_convert<bhalf_t>(1);
y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha);
} }
const float alpha_;
}; };
struct ConvInvscale struct ConvInvscale
...@@ -1074,7 +1470,7 @@ struct ConvScaleRelu ...@@ -1074,7 +1470,7 @@ struct ConvScaleRelu
__host__ __device__ void operator()<f8_t, float>(f8_t& e, const float& c) const __host__ __device__ void operator()<f8_t, float>(f8_t& e, const float& c) const
{ {
float x; float x;
Relu{}.template operator()<float>(x, c * scale_in_ * scale_wei_); Relu{}(x, c * scale_in_ * scale_wei_);
e = type_convert<f8_t>(x * scale_out_); e = type_convert<f8_t>(x * scale_out_);
}; };
...@@ -1153,6 +1549,255 @@ struct FastNumericArrayConverter<uint8_t, ck::half_t, N> ...@@ -1153,6 +1549,255 @@ struct FastNumericArrayConverter<uint8_t, ck::half_t, N>
__device__ OutputArray operator()(InputArray const& Input) { return convert(Input); } __device__ OutputArray operator()(InputArray const& Input) { return convert(Input); }
}; };
struct DynamicUnaryOp
{
    // Runtime-dispatched (type-erased) unary elementwise operator.
    // The concrete operator is selected by unary_op_type_; alpha/beta/gamma
    // capture whatever parameters the wrapped operator was constructed with,
    // so the operator can be re-created on the device (InitUnaryOpPtrOnDevice)
    // or on the host (operator()).

    DynamicUnaryOp& operator=(const DynamicUnaryOp& other)
    {
        if(this != &other)
        {
            unary_op_ptr_  = other.unary_op_ptr_;
            unary_op_type_ = other.unary_op_type_;
            // Also copy the operator parameters; otherwise the assigned-to
            // object would later re-create the operator with stale values.
            alpha = other.alpha;
            beta  = other.beta;
            gamma = other.gamma;
        }
        // NOTE(review): unary_op_ptr_ is shallow-copied while the destructor
        // deletes it, so two live objects holding the same non-null pointer
        // would double-delete — confirm only one copy ever owns a device-side
        // instance.
        return *this;
    }

    __host__ __device__ DynamicUnaryOp() = delete;

    // Converting constructors (one lvalue + one rvalue overload per supported
    // operator): record the operator tag and capture its parameters.
    __host__ __device__ DynamicUnaryOp(const Swish& swish)
    {
        unary_op_type_ = UnaryOpType::Swish;
        beta           = swish.get_beta();
    }

    __host__ __device__ DynamicUnaryOp(const Swish&& swish)
    {
        unary_op_type_ = UnaryOpType::Swish;
        beta           = swish.get_beta();
    }

    __host__ __device__ DynamicUnaryOp(const Sigmoid&) { unary_op_type_ = UnaryOpType::Sigmoid; }
    __host__ __device__ DynamicUnaryOp(const Sigmoid&&) { unary_op_type_ = UnaryOpType::Sigmoid; }

    __host__ __device__ DynamicUnaryOp(const PassThrough&)
    {
        unary_op_type_ = UnaryOpType::PassThrough;
    }
    __host__ __device__ DynamicUnaryOp(const PassThrough&&)
    {
        unary_op_type_ = UnaryOpType::PassThrough;
    }

    __host__ __device__ DynamicUnaryOp(const Logistic& logistic)
    {
        unary_op_type_ = UnaryOpType::Logistic;
        alpha          = logistic.get_alpha();
    }

    __host__ __device__ DynamicUnaryOp(const Logistic&& logistic)
    {
        unary_op_type_ = UnaryOpType::Logistic;
        alpha          = logistic.get_alpha();
    }

    __host__ __device__ DynamicUnaryOp(const TanH&) { unary_op_type_ = UnaryOpType::TanH; }
    __host__ __device__ DynamicUnaryOp(const TanH&&) { unary_op_type_ = UnaryOpType::TanH; }

    __host__ __device__ DynamicUnaryOp(const Relu&) { unary_op_type_ = UnaryOpType::Relu; }
    __host__ __device__ DynamicUnaryOp(const Relu&&) { unary_op_type_ = UnaryOpType::Relu; }

    __host__ __device__ DynamicUnaryOp(const SoftRelu& softrelu)
    {
        unary_op_type_ = UnaryOpType::SoftRelu;
        alpha          = softrelu.get_alpha();
    }

    __host__ __device__ DynamicUnaryOp(const SoftRelu&& softrelu)
    {
        unary_op_type_ = UnaryOpType::SoftRelu;
        alpha          = softrelu.get_alpha();
    }

    __host__ __device__ DynamicUnaryOp(const UnaryAbs&) { unary_op_type_ = UnaryOpType::UnaryAbs; }
    __host__ __device__ DynamicUnaryOp(const UnaryAbs&&) { unary_op_type_ = UnaryOpType::UnaryAbs; }

    __host__ __device__ DynamicUnaryOp(const Power& pow)
    {
        unary_op_type_ = UnaryOpType::Power;
        alpha          = pow.get_alpha();
        beta           = pow.get_beta();
        gamma          = pow.get_gamma();
    }

    __host__ __device__ DynamicUnaryOp(const Power&& pow)
    {
        unary_op_type_ = UnaryOpType::Power;
        alpha          = pow.get_alpha();
        beta           = pow.get_beta();
        gamma          = pow.get_gamma();
    }

    __host__ __device__ DynamicUnaryOp(const ClippedRelu& clippedrelu)
    {
        unary_op_type_ = UnaryOpType::ClippedRelu;
        alpha          = clippedrelu.get_alpha();
        beta           = clippedrelu.get_beta();
    }

    __host__ __device__ DynamicUnaryOp(const ClippedRelu&& clippedrelu)
    {
        unary_op_type_ = UnaryOpType::ClippedRelu;
        alpha          = clippedrelu.get_alpha();
        beta           = clippedrelu.get_beta();
    }

    __host__ __device__ DynamicUnaryOp(const LeakyRelu& leakyrelu)
    {
        unary_op_type_ = UnaryOpType::LeakyRelu;
        alpha          = leakyrelu.get_alpha();
    }

    __host__ __device__ DynamicUnaryOp(const LeakyRelu&& leakyrelu)
    {
        unary_op_type_ = UnaryOpType::LeakyRelu;
        alpha          = leakyrelu.get_alpha();
    }

    __host__ __device__ DynamicUnaryOp(const Elu& elu)
    {
        unary_op_type_ = UnaryOpType::Elu;
        alpha          = elu.get_alpha();
    }

    __host__ __device__ DynamicUnaryOp(const Elu&& elu)
    {
        unary_op_type_ = UnaryOpType::Elu;
        alpha          = elu.get_alpha();
    }

    __host__ __device__ DynamicUnaryOp(const DynamicUnaryOp& dynamic_op)
        : unary_op_type_(dynamic_op.unary_op_type_),
          unary_op_ptr_(dynamic_op.unary_op_ptr_),
          alpha(dynamic_op.alpha),
          beta(dynamic_op.beta),
          gamma(dynamic_op.gamma)
    {
    }

    // Destroy the operator instance allocated by InitUnaryOpPtrOnDevice (if
    // any). When InitUnaryOpPtrOnDevice was never called, unary_op_ptr_ is
    // still nullptr and the delete is a no-op.
    __host__ __device__ ~DynamicUnaryOp()
    {
        switch(unary_op_type_)
        {
        case(UnaryOpType::Swish): delete static_cast<Swish*>(unary_op_ptr_); break;
        case(UnaryOpType::Sigmoid): delete static_cast<Sigmoid*>(unary_op_ptr_); break;
        case(UnaryOpType::PassThrough): delete static_cast<PassThrough*>(unary_op_ptr_); break;
        case(UnaryOpType::Logistic): delete static_cast<Logistic*>(unary_op_ptr_); break;
        case(UnaryOpType::TanH): delete static_cast<TanH*>(unary_op_ptr_); break;
        case(UnaryOpType::Relu): delete static_cast<Relu*>(unary_op_ptr_); break;
        case(UnaryOpType::SoftRelu): delete static_cast<SoftRelu*>(unary_op_ptr_); break;
        case(UnaryOpType::UnaryAbs): delete static_cast<UnaryAbs*>(unary_op_ptr_); break;
        case(UnaryOpType::Power): delete static_cast<Power*>(unary_op_ptr_); break;
        case(UnaryOpType::ClippedRelu): delete static_cast<ClippedRelu*>(unary_op_ptr_); break;
        case(UnaryOpType::LeakyRelu): delete static_cast<LeakyRelu*>(unary_op_ptr_); break;
        case(UnaryOpType::Elu): delete static_cast<Elu*>(unary_op_ptr_); break;
        default: break;
        }
    }

    // Allocate the concrete operator on the device from the stored tag and
    // parameters. Must stay in sync with the host-side operator() dispatch.
    __device__ void InitUnaryOpPtrOnDevice()
    {
        switch(unary_op_type_)
        {
        case(UnaryOpType::Swish): unary_op_ptr_ = new Swish(beta); break;
        case(UnaryOpType::Sigmoid): unary_op_ptr_ = new Sigmoid; break;
        case(UnaryOpType::PassThrough): unary_op_ptr_ = new PassThrough; break;
        case(UnaryOpType::Logistic): unary_op_ptr_ = new Logistic(alpha); break;
        case(UnaryOpType::TanH): unary_op_ptr_ = new TanH; break;
        case(UnaryOpType::Relu): unary_op_ptr_ = new Relu; break;
        case(UnaryOpType::SoftRelu): unary_op_ptr_ = new SoftRelu(alpha); break;
        case(UnaryOpType::UnaryAbs): unary_op_ptr_ = new UnaryAbs; break;
        case(UnaryOpType::Power): unary_op_ptr_ = new Power(alpha, beta, gamma); break;
        case(UnaryOpType::ClippedRelu): unary_op_ptr_ = new ClippedRelu(alpha, beta); break;
        case(UnaryOpType::LeakyRelu): unary_op_ptr_ = new LeakyRelu(alpha); break;
        case(UnaryOpType::Elu): unary_op_ptr_ = new Elu(alpha); break;
        default: unary_op_ptr_ = nullptr; break;
        }
    }

    // Device-side apply: requires InitUnaryOpPtrOnDevice() to have been called
    // first; dispatches through the virtual call on UnaryOpBase.
    template <typename Y, typename X>
    __device__ void operator()(Y& y, const X& x) const
    {
        isSupported<X, Y>();
        unary_op_ptr_->operator()(y, x);
    }

    // Host-side apply: re-create the operator from the captured parameters
    // (alpha/beta/gamma) instead of default-constructing it, so the host path
    // matches InitUnaryOpPtrOnDevice and the parameters are not silently
    // dropped.
    template <typename Y, typename X>
    __host__ void operator()(Y& y, const X& x) const
    {
        isSupported<X, Y>();

        switch(unary_op_type_)
        {
        case(UnaryOpType::Swish): Swish{beta}.operator()(y, x); break;
        case(UnaryOpType::Sigmoid): Sigmoid{}.operator()(y, x); break;
        case(UnaryOpType::PassThrough): PassThrough{}.operator()(y, x); break;
        case(UnaryOpType::Logistic): Logistic{alpha}.operator()(y, x); break;
        case(UnaryOpType::TanH): TanH{}.operator()(y, x); break;
        case(UnaryOpType::Relu): Relu{}.operator()(y, x); break;
        case(UnaryOpType::SoftRelu): SoftRelu{alpha}.operator()(y, x); break;
        case(UnaryOpType::UnaryAbs): UnaryAbs{}.operator()(y, x); break;
        case(UnaryOpType::Power): Power{alpha, beta, gamma}.operator()(y, x); break;
        case(UnaryOpType::ClippedRelu): ClippedRelu{alpha, beta}.operator()(y, x); break;
        case(UnaryOpType::LeakyRelu): LeakyRelu{alpha}.operator()(y, x); break;
        case(UnaryOpType::Elu): Elu{alpha}.operator()(y, x); break;
        default: break;
        }
    }

    // Compile-time guard: Y must equal X and be one of the supported types.
    template <typename X, typename Y>
    __device__ __host__ constexpr void isSupported() const
    {
        static_assert(std::is_same<X, Y>::value, "X and Y must be of the same type");

        static_assert(is_same<X, float>::value || is_same<X, double>::value ||
                          is_same<X, bhalf_t>::value || is_same<X, half_t>::value ||
                          is_same<X, int32_t>::value || is_same<X, int8_t>::value,
                      "Data type is not supported by this operation!");
    }

    private:
    enum class UnaryOpType
    {
        Swish,
        Sigmoid,
        PassThrough,
        Logistic,
        TanH,
        Relu,
        SoftRelu,
        UnaryAbs,
        Power,
        ClippedRelu,
        LeakyRelu,
        Elu
    };

    public:
    UnaryOpType unary_op_type_;
    UnaryOpBase* unary_op_ptr_ = nullptr;

    float alpha;
    float beta;
    float gamma;
};
#pragma clang diagnostic pop
} // namespace element_wise } // namespace element_wise
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
...@@ -327,12 +327,12 @@ struct intrin_mfma_i32_16x16x32i8<16, 16> ...@@ -327,12 +327,12 @@ struct intrin_mfma_i32_16x16x32i8<16, 16>
__device__ static void Run(const int8x8_t& reg_a, const int8x8_t& reg_b, FloatC& reg_c) __device__ static void Run(const int8x8_t& reg_a, const int8x8_t& reg_b, FloatC& reg_c)
{ {
reg_c.template AsType<int32x4_t>()(Number<0>{}) = reg_c.template AsType<int32x4_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_i32_16x16x32i8(bit_cast<int64_t>(reg_a), __builtin_amdgcn_mfma_i32_16x16x32_i8(bit_cast<int64_t>(reg_a),
bit_cast<int64_t>(reg_b), bit_cast<int64_t>(reg_b),
reg_c.template AsType<int32x4_t>()[Number<0>{}], reg_c.template AsType<int32x4_t>()[Number<0>{}],
0, 0,
0, 0,
0); 0);
} }
}; };
......
...@@ -1803,4 +1803,13 @@ struct NumericUtils<bf8_t> ...@@ -1803,4 +1803,13 @@ struct NumericUtils<bf8_t>
static constexpr int bias = 16; // negative zero nan mode static constexpr int bias = 16; // negative zero nan mode
// static constexpr int bias = 15; // ieee mode // static constexpr int bias = 15; // ieee mode
}; };
template <>
struct NumericUtils<bhalf_t>
{
    // bfloat16 bit layout: 1 sign bit, 8 exponent bits, 7 mantissa bits.
    static constexpr int exp = 8;
    static constexpr int mant = 7;
    // bias 128 follows the "negative zero nan mode" convention used by the
    // neighbouring specializations; the IEEE-style bias would be 127 (below).
    static constexpr int bias = 128; // negative zero nan mode
    // static constexpr int bias = 127; // ieee mode
};
} // namespace ck } // namespace ck
...@@ -653,7 +653,7 @@ inline __device__ double sin<double>(double x) ...@@ -653,7 +653,7 @@ inline __device__ double sin<double>(double x)
template <> template <>
inline __device__ half_t sin<half_t>(half_t x) inline __device__ half_t sin<half_t>(half_t x)
{ {
return ::hsin(x); return hsin(static_cast<__half>(x));
}; };
template <typename T> template <typename T>
...@@ -785,7 +785,7 @@ inline __device__ double ceil<double>(double x) ...@@ -785,7 +785,7 @@ inline __device__ double ceil<double>(double x)
template <> template <>
inline __device__ half_t ceil<half_t>(half_t x) inline __device__ half_t ceil<half_t>(half_t x)
{ {
return ::hceil(x); return hceil(static_cast<__half>(x));
}; };
template <typename T> template <typename T>
...@@ -827,7 +827,7 @@ inline __device__ double floor<double>(double x) ...@@ -827,7 +827,7 @@ inline __device__ double floor<double>(double x)
template <> template <>
inline __device__ half_t floor<half_t>(half_t x) inline __device__ half_t floor<half_t>(half_t x)
{ {
return ::hfloor(x); return hfloor(static_cast<__half>(x));
}; };
template <typename T> template <typename T>
...@@ -849,7 +849,7 @@ inline __device__ T exp(T x) ...@@ -849,7 +849,7 @@ inline __device__ T exp(T x)
template <> template <>
inline __device__ half_t exp<half_t>(half_t x) inline __device__ half_t exp<half_t>(half_t x)
{ {
return hexp(x); return hexp(static_cast<__half>(x));
}; };
template <> template <>
...@@ -873,7 +873,7 @@ inline __device__ T log(T x) ...@@ -873,7 +873,7 @@ inline __device__ T log(T x)
template <> template <>
inline __device__ half_t log<half_t>(half_t x) inline __device__ half_t log<half_t>(half_t x)
{ {
return hlog(x); return hlog(static_cast<__half>(x));
}; };
template <> template <>
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include "ck_tile/core/algorithm/cluster_descriptor.hpp" #include "ck_tile/core/algorithm/cluster_descriptor.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp" #include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/algorithm/indexing_adaptor.hpp"
#include "ck_tile/core/algorithm/space_filling_curve.hpp" #include "ck_tile/core/algorithm/space_filling_curve.hpp"
#include "ck_tile/core/arch/amd_buffer_addressing.hpp" #include "ck_tile/core/arch/amd_buffer_addressing.hpp"
#include "ck_tile/core/arch/arch.hpp" #include "ck_tile/core/arch/arch.hpp"
...@@ -24,6 +25,7 @@ ...@@ -24,6 +25,7 @@
#include "ck_tile/core/numeric/bfloat16.hpp" #include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/numeric/float8.hpp" #include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/half.hpp" #include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/int8.hpp"
#include "ck_tile/core/numeric/integer.hpp" #include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp" #include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/math.hpp" #include "ck_tile/core/numeric/math.hpp"
...@@ -49,13 +51,17 @@ ...@@ -49,13 +51,17 @@
#include "ck_tile/core/tensor/tile_distribution_encoding.hpp" #include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
#include "ck_tile/core/tensor/tile_elementwise.hpp" #include "ck_tile/core/tensor/tile_elementwise.hpp"
#include "ck_tile/core/tensor/tile_window.hpp" #include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/tensor/tile_window_linear.hpp"
#include "ck_tile/core/tensor/update_tile.hpp" #include "ck_tile/core/tensor/update_tile.hpp"
#include "ck_tile/core/utility/bit_cast.hpp" #include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/functional.hpp" #include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/functional_with_tuple.hpp"
#include "ck_tile/core/utility/ignore.hpp" #include "ck_tile/core/utility/ignore.hpp"
#include "ck_tile/core/utility/literals.hpp"
#include "ck_tile/core/utility/magic_div.hpp" #include "ck_tile/core/utility/magic_div.hpp"
#include "ck_tile/core/utility/philox_rand.hpp" #include "ck_tile/core/utility/philox_rand.hpp"
#include "ck_tile/core/utility/random.hpp" #include "ck_tile/core/utility/random.hpp"
#include "ck_tile/core/utility/reduce_operator.hpp"
#include "ck_tile/core/utility/to_sequence.hpp" #include "ck_tile/core/utility/to_sequence.hpp"
#include "ck_tile/core/utility/transpose_vectors.hpp" #include "ck_tile/core/utility/transpose_vectors.hpp"
#include "ck_tile/core/utility/type_traits.hpp" #include "ck_tile/core/utility/type_traits.hpp"
......
...@@ -23,6 +23,7 @@ enum struct coord_transform_enum ...@@ -23,6 +23,7 @@ enum struct coord_transform_enum
replicate, replicate,
xor_t, xor_t,
offset, offset,
indexing,
}; };
template <index_t NDimLow, index_t NDimUp> template <index_t NDimLow, index_t NDimUp>
...@@ -1526,6 +1527,88 @@ struct offset : public base_transform<1, 1> ...@@ -1526,6 +1527,88 @@ struct offset : public base_transform<1, 1>
} }
}; };
template <typename UpLength, typename IndexingAdaptor>
struct indexing : public base_transform<1, 1>
{
    // 1-d -> 1-d coordinate transform that resolves the lower index through a
    // user-supplied indexing adaptor (gather/scatter style index lookup).
    static constexpr index_t NDimUp = 1;

    using LowerIndex = multi_index<1>;
    using UpperIndex = multi_index<1>;

    using UpLengths = decltype(make_tuple(UpLength{}));

    UpLengths up_lengths_;
    IndexingAdaptor iadaptor_;

    CK_TILE_HOST_DEVICE constexpr indexing() = default;

    CK_TILE_HOST_DEVICE constexpr indexing(const UpLength& up_length,
                                           const IndexingAdaptor& iadaptor)
        : up_lengths_{make_tuple(up_length)}, iadaptor_{iadaptor}
    {
    }

    CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
    {
        return coord_transform_enum::indexing;
    }

    CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }

    // Resolve the lower index for idx_up by delegating to the adaptor.
    template <typename LowIdx, typename UpIdx>
    CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
                                                             const UpIdx& idx_up) const
    {
        static_assert(LowIdx::size() == 1 && UpIdx::size() == NDimUp,
                      "wrong! inconsistent # of dimension");

        iadaptor_.calculate_lower_index(idx_low, idx_up);
    }

    // Forward incremental index updates to the adaptor; nothing is cached in
    // the transform itself.
    template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
    CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
                                                const UpIdxDiff& idx_diff_up,
                                                LowIdx& idx_low,
                                                const UpIdx& idx_up) const
    {
        static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == NDimUp &&
                          LowIdx::size() == 1 && UpIdx::size() == NDimUp,
                      "wrong! inconsistent # of dimension");

        iadaptor_.update_lower_index(idx_diff_low, idx_diff_up, idx_low, idx_up);
    }

    CK_TILE_HOST_DEVICE static constexpr bool
    is_valid_upper_index_always_mapped_to_valid_lower_index()
    {
        return true;
    }

    template <typename UpIdx>
    CK_TILE_HOST_DEVICE static constexpr bool
    is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */)
    {
        return true;
    }

    CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
    {
        return ck_tile::is_known_at_compile_time<UpLengths>::value &&
               IndexingAdaptor::is_known_at_compile_time();
    }

    CK_TILE_HOST_DEVICE void print() const
    {
        // fixed: previously printed "embed{" (copy-pasted from the embed
        // transform); print this transform's actual name
        printf("indexing{");

        //
        printf("up_lengths_: ");
        print(up_lengths_);
        printf(", ");

        printf("}");
    }
};
//******************************************************************************************************* //*******************************************************************************************************
template <typename LowLength> template <typename LowLength>
...@@ -1646,3 +1729,24 @@ CK_TILE_HOST_DEVICE constexpr auto make_offset_transform(const LowLength& low_le ...@@ -1646,3 +1729,24 @@ CK_TILE_HOST_DEVICE constexpr auto make_offset_transform(const LowLength& low_le
} }
} // namespace ck_tile } // namespace ck_tile
#include "ck_tile/core/algorithm/indexing_adaptor.hpp"
namespace ck_tile {
template <typename UpLength, typename Indices>
CK_TILE_HOST_DEVICE constexpr auto make_indexing_transform(const UpLength& up_lengths,
                                                           const Indices& indices)
{
    // Default adaptor choice: the simplest one, which caches the per-thread
    // index in a register.
    using adaptor_t = indexing_adaptor_onshot_cached<remove_cvref_t<Indices>>;
    return indexing<UpLength, adaptor_t>{up_lengths, adaptor_t{indices}};
}
template <typename UpLength, typename IndexingAdaptor>
CK_TILE_HOST_DEVICE constexpr auto
make_indexing_transform_with_adaptor(const UpLength& up_lengths, const IndexingAdaptor& iadaptor)
{
    // Build the indexing transform from a caller-supplied adaptor.
    using transform_t = indexing<UpLength, IndexingAdaptor>;
    return transform_t{up_lengths, iadaptor};
}
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/multi_index.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
// Pre-defined indexing adaptor used for indexing (scatter/gather).
// This version caches the index inside a thread register (which is also preferred in real
// scenarios). However, it is the user's responsibility to ensure each thread provides only one
// index, which means moving the coordinate will not change it on this dimension.
template <typename IndexingType>
struct indexing_adaptor_onshot_cached
{
    // Adaptor that resolves the lower index from a single index cached in a
    // thread register. Each thread holds exactly one index, so moving the
    // coordinate must never change this dimension.
    CK_TILE_HOST_DEVICE constexpr indexing_adaptor_onshot_cached() = default;
    CK_TILE_HOST_DEVICE constexpr indexing_adaptor_onshot_cached(const IndexingType& idx)
        : cached_idx_(idx)
    {
    }

    // the one index this thread caches
    IndexingType cached_idx_;

    // The lower index is always the cached index; idx_up is ignored.
    template <typename LowIdx, typename UpIdx>
    CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
                                                             const UpIdx& /*idx_up*/) const
    {
        static_assert(LowIdx::size() == 1 && UpIdx::size() == 1,
                      "wrong! inconsistent # of dimension");
        idx_low(number<0>{}) = cached_idx_;
    }

    template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
    CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
                                                const UpIdxDiff& idx_diff_up,
                                                LowIdx& /*idx_low*/,
                                                const UpIdx& /*idx_up*/) const
    {
        // nothing to recompute here: forward the upper diff to the lower
        // dimension without changing the actual (cached) index
        static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 &&
                          UpIdx::size() == 1,
                      "wrong! inconsistent # of dimension");
        idx_diff_low(number<0>{}) = idx_diff_up[number<0>{}];
    }

    CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
    {
        return ck_tile::is_known_at_compile_time<IndexingType>::value;
    }
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -81,8 +81,10 @@ struct space_filling_curve ...@@ -81,8 +81,10 @@ struct space_filling_curve
return get_step_between(number<AccessIdx1d>{}, number<AccessIdx1d - 1>{}); return get_step_between(number<AccessIdx1d>{}, number<AccessIdx1d - 1>{});
} }
// Do not use this function directly!
// TODO: can refactor into generic lambda in the future
template <index_t AccessIdx1d> template <index_t AccessIdx1d>
static CK_TILE_HOST_DEVICE constexpr Index get_index(number<AccessIdx1d>) static CK_TILE_HOST_DEVICE constexpr Index _get_index(number<AccessIdx1d>)
{ {
#if 0 #if 0
/* /*
...@@ -153,11 +155,11 @@ struct space_filling_curve ...@@ -153,11 +155,11 @@ struct space_filling_curve
return idx_md; return idx_md;
} }
// FIXME: rename this function // FIXME: return tuple of number<>, which is compile time only variable
template <index_t AccessIdx1d> template <index_t AccessIdx1d>
static CK_TILE_HOST_DEVICE constexpr auto get_index_tuple_of_number(number<AccessIdx1d>) static CK_TILE_HOST_DEVICE constexpr auto get_index(number<AccessIdx1d>)
{ {
constexpr auto idx = get_index(number<AccessIdx1d>{}); constexpr auto idx = _get_index(number<AccessIdx1d>{});
return generate_tuple([&](auto i) { return number<idx[i]>{}; }, number<nDim>{}); return generate_tuple([&](auto i) { return number<idx[i]>{}; }, number<nDim>{});
} }
......
...@@ -621,6 +621,99 @@ CK_TILE_DEVICE void buffer_load_fence(index_t cnt = 0) ...@@ -621,6 +621,99 @@ CK_TILE_DEVICE void buffer_load_fence(index_t cnt = 0)
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory"); asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
} }
namespace impl {
// Payload types used to type the "=v" operand of the LDS-load inline asm
// below. 16/8/4-byte loads use a matching fp32 vector/scalar; 2- and 1-byte
// loads still use a full dword register (sub-dword asm operands are
// problematic — see the comment in smem_load<2>).
// clang-format off
template<index_t N, typename T> struct smem_load_trait;
template<typename T> struct smem_load_trait<16, T> { using payload_t = fp32x4_t; };
template<typename T> struct smem_load_trait<8 , T> { using payload_t = fp32x2_t; };
template<typename T> struct smem_load_trait<4 , T> { using payload_t = float; };
template<typename T> struct smem_load_trait<2 , T> { using payload_t = float; };
template<typename T> struct smem_load_trait<1 , T> { using payload_t = float; };
// clang-format on
} // namespace impl
// NOTE: smem load/store no need pre_nop to make sure dependency by sw, happy :)
template <index_t>
struct smem_load;
template <>
struct smem_load<16>
{
    // Load 16 bytes from LDS into `value` with a single ds_read_b128.
    // `v_offset` is a per-thread byte offset held in a VGPR; `i_offset` must
    // be a compile-time constant ("n" asm constraint) folded into the
    // instruction's immediate offset field.
    template <typename T>
    CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
    {
        static_assert(sizeof(T) == 16);
        using mbuf_t = typename impl::smem_load_trait<16, T>::payload_t;
        asm volatile("ds_read_b128 %0, %1 offset:%2"
                     : "=v"(reinterpret_cast<mbuf_t&>(value)) // ! direct write
                     : "v"(v_offset), "n"(i_offset)
                     : "memory");
    }
};
template <>
struct smem_load<8>
{
    // Load 8 bytes from LDS into `value` with ds_read_b64. `v_offset` is a
    // per-thread byte offset in a VGPR; `i_offset` must be a compile-time
    // constant ("n" asm constraint).
    template <typename T>
    CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
    {
        static_assert(sizeof(T) == 8);
        using mbuf_t = typename impl::smem_load_trait<8, T>::payload_t;
        asm volatile("ds_read_b64 %0, %1 offset:%2"
                     : "=v"(reinterpret_cast<mbuf_t&>(value)) // ! direct write
                     : "v"(v_offset), "n"(i_offset)
                     : "memory");
    }
};
template <>
struct smem_load<4>
{
    // Load 4 bytes from LDS into `value` with ds_read_b32. `v_offset` is a
    // per-thread byte offset in a VGPR; `i_offset` must be a compile-time
    // constant ("n" asm constraint).
    template <typename T>
    CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
    {
        static_assert(sizeof(T) == 4);
        using mbuf_t = typename impl::smem_load_trait<4, T>::payload_t;
        asm volatile("ds_read_b32 %0, %1 offset:%2"
                     : "=v"(reinterpret_cast<mbuf_t&>(value)) // ! direct write
                     : "v"(v_offset), "n"(i_offset)
                     : "memory");
    }
};
template <>
struct smem_load<2>
{
    // Load 2 bytes from LDS with ds_read_u16; the result lands in a full
    // dword register, hence T must be 4 bytes and the caller converts.
    // NOTE(review): uses smem_load_trait<1, T> rather than <2, T> — both
    // resolve to a float payload, so behavior is unchanged; confirm intent.
    template <typename T>
    CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
    {
        static_assert(sizeof(T) == 4); // subdword is buggy, use dword buf and convert manually
        using mbuf_t = typename impl::smem_load_trait<1, T>::payload_t;
        asm volatile("ds_read_u16 %0, %1 offset:%2"
                     : "=v"(reinterpret_cast<mbuf_t&>(value)) // ! direct write
                     : "v"(v_offset), "n"(i_offset)
                     : "memory");
    }
};
template <>
struct smem_load<1>
{
    // Load 1 byte from LDS with ds_read_u8; the result lands in a full dword
    // register, hence T must be 4 bytes and the caller converts (same
    // sub-dword workaround as smem_load<2>).
    template <typename T>
    CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
    {
        static_assert(sizeof(T) == 4);
        using mbuf_t = typename impl::smem_load_trait<1, T>::payload_t;
        asm volatile("ds_read_u8 %0, %1 offset:%2"
                     : "=v"(reinterpret_cast<mbuf_t&>(value)) // ! direct write
                     : "v"(v_offset), "n"(i_offset)
                     : "memory");
    }
};
// clang-format off // clang-format off
namespace impl{ namespace impl{
...@@ -976,6 +1069,16 @@ llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata, ...@@ -976,6 +1069,16 @@ llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata,
int soffset, // dst_wave_addr_offset int soffset, // dst_wave_addr_offset
int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64"); int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64");
// Direct loads from global to LDS.
// Compiler intrinsic: raw buffer load whose destination is LDS, bypassing
// VGPRs (maps to llvm.amdgcn.raw.buffer.load.lds).
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc, // 128-bit buffer resource descriptor
                                __attribute__((address_space(3))) uint32_t* lds_ptr, // LDS destination
                                index_t size,    // bytes loaded per lane
                                index_t voffset, // per-thread (VGPR) byte offset
                                index_t soffset, // scalar (SGPR) byte offset
                                index_t offset,  // immediate byte offset
                                // aux: cache-policy bits (glc/slc) — see LLVM AMDGPU docs
                                index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds");
template <bool pre_nop = false> template <bool pre_nop = false>
CK_TILE_DEVICE void async_buffer_load_dword_v(void* smem, CK_TILE_DEVICE void async_buffer_load_dword_v(void* smem,
int32x4_t rsrc, int32x4_t rsrc,
...@@ -1313,6 +1416,7 @@ CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer<T, N>& dst, ...@@ -1313,6 +1416,7 @@ CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer<T, N>& dst,
int32x4_t src_wave_buffer_resource, int32x4_t src_wave_buffer_resource,
index_t src_thread_addr_offset, index_t src_thread_addr_offset,
index_t src_wave_addr_offset, index_t src_wave_addr_offset,
index_t src_linear_addr_offset,
index_t flag = 0, index_t flag = 0,
bool_constant<pre_nop> = {}) bool_constant<pre_nop> = {})
{ {
...@@ -1327,7 +1431,7 @@ CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer<T, N>& dst, ...@@ -1327,7 +1431,7 @@ CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer<T, N>& dst,
src_wave_buffer_resource, src_wave_buffer_resource,
src_thread_addr_offset, src_thread_addr_offset,
src_wave_addr_offset, src_wave_addr_offset,
0, src_linear_addr_offset,
flag, flag,
bool_constant<pre_nop>{}); bool_constant<pre_nop>{});
} }
...@@ -1337,7 +1441,7 @@ CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer<T, N>& dst, ...@@ -1337,7 +1441,7 @@ CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer<T, N>& dst,
src_wave_buffer_resource, src_wave_buffer_resource,
src_thread_addr_offset, src_thread_addr_offset,
src_wave_addr_offset, src_wave_addr_offset,
0, src_linear_addr_offset,
flag, flag,
bool_constant<pre_nop>{}); bool_constant<pre_nop>{});
} }
...@@ -1365,6 +1469,43 @@ CK_TILE_DEVICE void amd_async_buffer_load_impl(T* smem, ...@@ -1365,6 +1469,43 @@ CK_TILE_DEVICE void amd_async_buffer_load_impl(T* smem,
bool_constant<pre_nop>{}); bool_constant<pre_nop>{});
} }
// Async (direct global->LDS) buffer load of one dword per lane.
//
// smem                      : LDS destination pointer
// src_wave_buffer_resource  : 128-bit buffer resource descriptor
// src_thread_addr_offset    : per-thread byte offset (VGPR)
// src_wave_addr_offset      : scalar byte offset (SGPR)
// src_immediate_addr_offset : immediate byte offset
// flag                      : per-thread validity; only used when
//                             oob_conditional_check is true
template <typename T,
          index_t N,
          amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
          bool oob_conditional_check = true>
CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
                                          int32x4_t src_wave_buffer_resource,
                                          index_t src_thread_addr_offset,
                                          index_t src_wave_addr_offset,
                                          index_t src_immediate_addr_offset = 0,
                                          index_t flag = 0,
                                          bool_constant<oob_conditional_check> = {})
{
    static_assert(sizeof(T) * N == 4, "wrong! not implemented vector size");

    if constexpr(oob_conditional_check)
    {
        // BUGFIX: the previous code read `v_offset` in its own initializer
        // (`flag ? v_offset : ...`) — an uninitialized read. When the element
        // is valid we must use the real per-thread offset; when invalid we
        // redirect to num_records (word 2 of the buffer resource) so the
        // hardware treats the access as out-of-range and drops it.
        index_t v_offset = flag ? src_thread_addr_offset : src_wave_buffer_resource[2];
        llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
                                        smem,
                                        sizeof(uint32_t),
                                        v_offset,
                                        src_wave_addr_offset,
                                        src_immediate_addr_offset,
                                        static_cast<index_t>(coherence));
    }
    else
    {
        llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
                                        smem,
                                        sizeof(uint32_t),
                                        src_thread_addr_offset,
                                        src_wave_addr_offset,
                                        src_immediate_addr_offset,
                                        static_cast<index_t>(coherence));
    }
}
template <index_t N, template <index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default> amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default>
CK_TILE_DEVICE void amd_buffer_store_impl_with_bytes(const thread_buffer<int8_t, N> src_thread_data, CK_TILE_DEVICE void amd_buffer_store_impl_with_bytes(const thread_buffer<int8_t, N> src_thread_data,
...@@ -1685,6 +1826,7 @@ CK_TILE_DEVICE void amd_buffer_store_raw_impl(const thread_buffer<T, N>& dst_thr ...@@ -1685,6 +1826,7 @@ CK_TILE_DEVICE void amd_buffer_store_raw_impl(const thread_buffer<T, N>& dst_thr
int32x4_t dst_wave_buffer_resource, int32x4_t dst_wave_buffer_resource,
index_t dst_thread_addr_offset, index_t dst_thread_addr_offset,
index_t dst_wave_addr_offset, index_t dst_wave_addr_offset,
index_t dst_linear_addr_offset,
index_t is_valid_element = 1) index_t is_valid_element = 1)
{ {
constexpr index_t bytes = sizeof(T) * N; constexpr index_t bytes = sizeof(T) * N;
...@@ -1698,7 +1840,7 @@ CK_TILE_DEVICE void amd_buffer_store_raw_impl(const thread_buffer<T, N>& dst_thr ...@@ -1698,7 +1840,7 @@ CK_TILE_DEVICE void amd_buffer_store_raw_impl(const thread_buffer<T, N>& dst_thr
dst_wave_buffer_resource, dst_wave_buffer_resource,
dst_thread_addr_offset, dst_thread_addr_offset,
dst_wave_addr_offset, dst_wave_addr_offset,
0, dst_linear_addr_offset,
is_valid_element); is_valid_element);
} }
else else
...@@ -1707,7 +1849,7 @@ CK_TILE_DEVICE void amd_buffer_store_raw_impl(const thread_buffer<T, N>& dst_thr ...@@ -1707,7 +1849,7 @@ CK_TILE_DEVICE void amd_buffer_store_raw_impl(const thread_buffer<T, N>& dst_thr
dst_wave_buffer_resource, dst_wave_buffer_resource,
dst_thread_addr_offset, dst_thread_addr_offset,
dst_wave_addr_offset, dst_wave_addr_offset,
0); dst_linear_addr_offset);
} }
} }
...@@ -2014,6 +2156,7 @@ template <typename T, ...@@ -2014,6 +2156,7 @@ template <typename T,
CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer<T, N>& dst, CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer<T, N>& dst,
const T* p_src_wave, const T* p_src_wave,
index_t src_thread_element_offset, index_t src_thread_element_offset,
index_t src_linear_element_offset,
index_t src_element_space_size, index_t src_element_space_size,
index_t is_valid_element = 0, index_t is_valid_element = 0,
bool_constant<pre_nop> = {}) bool_constant<pre_nop> = {})
...@@ -2022,12 +2165,14 @@ CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer<T, N>& dst, ...@@ -2022,12 +2165,14 @@ CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer<T, N>& dst,
make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T)); make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T));
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T);
amd_buffer_load_raw_impl<T, N, coherence, oob_conditional_check, pre_nop>( amd_buffer_load_raw_impl<T, N, coherence, oob_conditional_check, pre_nop>(
dst, dst,
src_wave_buffer_resource, src_wave_buffer_resource,
src_thread_addr_offset, src_thread_addr_offset,
0, 0,
src_linear_addr_offset,
is_valid_element, is_valid_element,
bool_constant<pre_nop>{}); bool_constant<pre_nop>{});
} }
...@@ -2041,16 +2186,19 @@ template <typename T, ...@@ -2041,16 +2186,19 @@ template <typename T,
CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer<T, N>& dst, CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer<T, N>& dst,
const int32x4_t src_wave_buffer_resource, const int32x4_t src_wave_buffer_resource,
index_t src_thread_element_offset, index_t src_thread_element_offset,
index_t src_linear_element_offset,
index_t is_valid_element = 0, index_t is_valid_element = 0,
bool_constant<pre_nop> = {}) bool_constant<pre_nop> = {})
{ {
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T);
amd_buffer_load_raw_impl<T, N, coherence, oob_conditional_check, pre_nop>( amd_buffer_load_raw_impl<T, N, coherence, oob_conditional_check, pre_nop>(
dst, dst,
src_wave_buffer_resource, src_wave_buffer_resource,
src_thread_addr_offset, src_thread_addr_offset,
0, 0,
src_linear_addr_offset,
is_valid_element, is_valid_element,
bool_constant<pre_nop>{}); bool_constant<pre_nop>{});
} }
...@@ -2066,6 +2214,7 @@ template <typename T, ...@@ -2066,6 +2214,7 @@ template <typename T,
CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem, CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
const T* p_src_wave, const T* p_src_wave,
index_t src_thread_element_offset, index_t src_thread_element_offset,
index_t src_linear_element_offset,
index_t src_element_space_size, index_t src_element_space_size,
bool_constant<pre_nop> = {}) bool_constant<pre_nop> = {})
{ {
...@@ -2073,9 +2222,14 @@ CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem, ...@@ -2073,9 +2222,14 @@ CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T)); make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T));
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T);
amd_async_buffer_load_impl<T, N, coherence>( amd_async_buffer_load_impl<T, N, coherence>(smem,
smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0, bool_constant<pre_nop>{}); src_wave_buffer_resource,
src_thread_addr_offset,
0,
src_linear_addr_offset,
bool_constant<pre_nop>{});
} }
// This version support buffer resource as input arg // This version support buffer resource as input arg
...@@ -2086,12 +2240,42 @@ template <typename T, ...@@ -2086,12 +2240,42 @@ template <typename T,
CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem, CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
const int32x4_t src_wave_buffer_resource, const int32x4_t src_wave_buffer_resource,
index_t src_thread_element_offset, index_t src_thread_element_offset,
index_t src_linear_element_offset,
bool_constant<pre_nop> = {}) bool_constant<pre_nop> = {})
{ {
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T);
amd_async_buffer_load_impl<T, N, coherence>( amd_async_buffer_load_impl<T, N, coherence>(smem,
smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0, bool_constant<pre_nop>{}); src_wave_buffer_resource,
src_thread_addr_offset,
0,
src_linear_addr_offset,
bool_constant<pre_nop>{});
}
// This version support buffer resource as input arg
// Async (direct-to-LDS) load with optional out-of-bounds guarding.
// Converts element offsets to byte offsets and forwards to amd_async_buffer_load;
// is_valid_element acts as the per-thread validity flag.
template <typename T,
          index_t N,
          amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
          bool oob_conditional_check = false>
CK_TILE_DEVICE void amd_async_buffer_load_with_oob(CK_TILE_LDS_ADDR T* smem,
                                                   const int32x4_t src_wave_buffer_resource,
                                                   index_t src_thread_element_offset,
                                                   index_t src_linear_element_offset,
                                                   bool is_valid_element,
                                                   bool_constant<oob_conditional_check> = {})
{
    // element -> byte offsets
    index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
    index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T);

    // wave (scalar) offset is 0; the linear offset travels as the immediate offset
    amd_async_buffer_load<T, N, coherence>(smem,
                                           src_wave_buffer_resource,
                                           src_thread_addr_offset,
                                           0,
                                           src_linear_addr_offset,
                                           is_valid_element,
                                           bool_constant<oob_conditional_check>{});
}
// buffer_store requires: // buffer_store requires:
...@@ -2146,6 +2330,7 @@ template <typename T, ...@@ -2146,6 +2330,7 @@ template <typename T,
CK_TILE_DEVICE void amd_buffer_store_raw(const thread_buffer<T, N>& src_thread_data, CK_TILE_DEVICE void amd_buffer_store_raw(const thread_buffer<T, N>& src_thread_data,
T* p_dst_wave, T* p_dst_wave,
const index_t dst_thread_element_offset, const index_t dst_thread_element_offset,
const index_t dst_linear_element_offset,
const bool dst_thread_element_valid, const bool dst_thread_element_valid,
const index_t dst_element_space_size) const index_t dst_element_space_size)
{ {
...@@ -2153,11 +2338,13 @@ CK_TILE_DEVICE void amd_buffer_store_raw(const thread_buffer<T, N>& src_thread_d ...@@ -2153,11 +2338,13 @@ CK_TILE_DEVICE void amd_buffer_store_raw(const thread_buffer<T, N>& src_thread_d
make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T)); make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T));
index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T); index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
index_t dst_linear_addr_offset = dst_linear_element_offset * sizeof(T);
amd_buffer_store_raw_impl<T, N, coherence, oob_conditional_check>(src_thread_data, amd_buffer_store_raw_impl<T, N, coherence, oob_conditional_check>(src_thread_data,
dst_wave_buffer_resource, dst_wave_buffer_resource,
dst_thread_addr_offset, dst_thread_addr_offset,
0, 0,
dst_linear_addr_offset,
dst_thread_element_valid); dst_thread_element_valid);
} }
...@@ -2221,16 +2408,6 @@ CK_TILE_DEVICE void amd_buffer_atomic_max(const thread_buffer<T, N>& src_thread_ ...@@ -2221,16 +2408,6 @@ CK_TILE_DEVICE void amd_buffer_atomic_max(const thread_buffer<T, N>& src_thread_
#endif #endif
} }
// Direct loads from global to LDS.
CK_TILE_DEVICE_EXTERN void
llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc,
__attribute__((address_space(3))) uint32_t* lds_ptr,
index_t size,
index_t voffset,
index_t soffset,
index_t offset,
index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds");
template <typename T, index_t NumElemsPerThread> template <typename T, index_t NumElemsPerThread>
CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr, CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr,
const index_t global_offset, const index_t global_offset,
......
...@@ -59,4 +59,47 @@ CK_TILE_DEVICE T warp_shuffle_down(const T& v_local, uint32_t lane_delta) ...@@ -59,4 +59,47 @@ CK_TILE_DEVICE T warp_shuffle_down(const T& v_local, uint32_t lane_delta)
#endif #endif
} }
// Gather a value from an arbitrary source lane of the wavefront using
// ds_bpermute. The hardware indexes lanes in bytes, hence `src_lane << 2`.
// Works for any trivially-copyable T: sub-dword types are packed into a dword,
// dword types go through directly, and larger types are shuffled dword-by-dword.
template <typename T>
CK_TILE_DEVICE T warp_shuffle(const T& v_local, uint32_t src_lane)
{
#if 0
    return __shfl(v_local, src_lane);
#elif 1
    if constexpr(sizeof(int32_t) > sizeof(T))
    {
        // sub-dword: pack into a 32-bit union so bpermute moves a full dword
        // (upper padding bytes may be uninitialized — benign, they are discarded)
        union packet
        {
            int32_t x;
            T v;
        };
        packet p;
        p.v = v_local;
        packet p_remote;
        p_remote.x = __builtin_amdgcn_ds_bpermute(src_lane << 2, bit_cast<int32_t>(p));
        return p_remote.v;
    }
    else if constexpr(sizeof(int32_t) == sizeof(T))
    {
        // exactly one dword: shuffle directly via bit_cast
        const int32_t v_remote_tmp =
            __builtin_amdgcn_ds_bpermute(src_lane << 2, bit_cast<int32_t>(v_local));
        return bit_cast<T>(v_remote_tmp);
    }
    else
    {
        // multi-dword: shuffle each 32-bit chunk independently
        static_assert(sizeof(T) % sizeof(int32_t) == 0, "wrong!");
        constexpr index_t elm = sizeof(T) / sizeof(int32_t);
        using vector_type = thread_buffer<int32_t, elm>;
        auto vs = bit_cast<vector_type>(v_local);
        auto vs_remote = vector_type{};
        static_for<0, elm, 1>{}([&](auto i_e) {
            int32_t tmp = __builtin_amdgcn_ds_bpermute(src_lane << 2, bit_cast<int32_t>(vs[i_e]));
            vs_remote(i_e) = tmp;
        });
        return bit_cast<T>(vs_remote);
    }
#endif
}
} // namespace ck_tile } // namespace ck_tile
...@@ -32,13 +32,28 @@ ...@@ -32,13 +32,28 @@
#define CK_TILE_DEVICE inline __device__ #define CK_TILE_DEVICE inline __device__
#define CK_TILE_HOST_DEVICE inline __host__ __device__ #define CK_TILE_HOST_DEVICE inline __host__ __device__
#define CK_TILE_DEVICE_EXTERN __device__ #define CK_TILE_DEVICE_EXTERN __device__
#define CK_TILE_HOST_DEVICE_EXTERN __host__ __device__
#else #else
#define CK_TILE_HOST inline #define CK_TILE_HOST inline
#define CK_TILE_DEVICE inline #define CK_TILE_DEVICE inline
#define CK_TILE_HOST_DEVICE inline #define CK_TILE_HOST_DEVICE inline
#define CK_TILE_DEVICE_EXTERN #define CK_TILE_DEVICE_EXTERN
#define CK_TILE_HOST_DEVICE_EXTERN
#endif #endif
// implementing the "memory address space" attribute
// https://llvm.org/docs/AMDGPUUsage.html#amdgpu-address-spaces-table
// BUGFIX: the guard was `__HIPCC_` (one trailing underscore), a macro the HIP
// compiler never defines, so every attribute silently expanded to nothing.
// The predefined macro is `__HIPCC__`.
#ifdef __HIPCC__
#define CK_TILE_GENERIC_ADDR __attribute__((address_space(0)))
#define CK_TILE_GLOBAL_ADDR __attribute__((address_space(1)))
#define CK_TILE_LDS_ADDR __attribute__((address_space(3)))
#define CK_TILE_BUF_RES_ADDR __attribute__((address_space(8)))
#else
#define CK_TILE_GENERIC_ADDR
#define CK_TILE_GLOBAL_ADDR
#define CK_TILE_LDS_ADDR
#define CK_TILE_BUF_RES_ADDR
#endif
#ifndef CK_TILE_USE_CUSTOM_DATA_TYPE #ifndef CK_TILE_USE_CUSTOM_DATA_TYPE
#define CK_TILE_USE_CUSTOM_DATA_TYPE 0 // custom data type will generate extra move/bfi code #define CK_TILE_USE_CUSTOM_DATA_TYPE 0 // custom data type will generate extra move/bfi code
#endif #endif
...@@ -203,3 +218,8 @@ ...@@ -203,3 +218,8 @@
#ifndef CK_TILE_BUFFER_LOAD_RAW_BF16_WA #ifndef CK_TILE_BUFFER_LOAD_RAW_BF16_WA
#define CK_TILE_BUFFER_LOAD_RAW_BF16_WA 1 #define CK_TILE_BUFFER_LOAD_RAW_BF16_WA 1
#endif #endif
// workaround: compiler not emitting reciprocal instruction from __frcp_rn()
#ifndef CK_TILE_WORKAROUND_SWDEV_383542
#define CK_TILE_WORKAROUND_SWDEV_383542 1
#endif
...@@ -1111,4 +1111,126 @@ CK_TILE_HOST_DEVICE constexpr auto generate_array(F&& f, number<N>) ...@@ -1111,4 +1111,126 @@ CK_TILE_HOST_DEVICE constexpr auto generate_array(F&& f, number<N>)
typename arithmetic_sequence_gen<0, N, 1>::type{}); typename arithmetic_sequence_gen<0, N, 1>::type{});
} }
namespace impl {

// Recursive scan (right to left) that computes, per dimension: the sliced
// length, the number of slices, and the remaining slice budget. Masked-off
// dimensions (m == 0) are kept whole. See the usage table above
// reverse_slice_sequence for examples.
template <typename, typename, typename, index_t>
struct reverse_slice_sequence_impl;

template <index_t x,
          index_t... xs,
          index_t m,
          index_t... ms,
          index_t id,
          index_t... ids,
          index_t SliceSize>
struct reverse_slice_sequence_impl<sequence<x, xs...>,
                                   sequence<m, ms...>,
                                   sequence<id, ids...>,
                                   SliceSize>
{
    // scan the tail first; its leftover budget feeds this dimension
    using old_scan =
        reverse_slice_sequence_impl<sequence<xs...>, sequence<ms...>, sequence<ids...>, SliceSize>;

    static constexpr auto slice_size = old_scan::remaining_slice_sizes::front().value;
    // maskable dim: take gcd(length, budget); masked dim: keep full length
    static constexpr auto slice_length =
        std::conditional_t<m, number<gcd(x, slice_size)>, number<x>>::value;

    using dim_lengths =
        typename sequence_merge<sequence<slice_length>, typename old_scan::dim_lengths>::type;
    using dim_slices =
        typename sequence_merge<sequence<x / slice_length>, typename old_scan::dim_slices>::type;
    using remaining_slice_sizes = typename sequence_merge<
        std::conditional_t<m, sequence<slice_size / slice_length>, sequence<slice_size>>,
        typename old_scan::remaining_slice_sizes>::type;

    // the first idx that sliced length not equal to original length
    static constexpr index_t _flag =
        slice_length != x && remaining_slice_sizes{}.front().value == 1;
    static constexpr index_t _split_flag = std::conditional_t<m, number<_flag>, number<0>>::value;
    static constexpr index_t _split_idx =
        std::conditional_t<_split_flag, number<id>, number<0>>::value;

    // inner (rightmost) occurrence wins, hence old_scan takes priority
    static constexpr index_t split_flag = _split_flag || old_scan::split_flag;
    static constexpr index_t split_idx = std::
        conditional_t<old_scan::split_flag, number<old_scan::split_idx>, number<_split_idx>>::value;
};

// base case: single remaining dimension, budget is the full SliceSize
template <index_t x, index_t m, index_t id, index_t SliceSize>
struct reverse_slice_sequence_impl<sequence<x>, sequence<m>, sequence<id>, SliceSize>
{
    static constexpr auto slice_size = SliceSize;
    static constexpr auto slice_length =
        std::conditional_t<m, number<gcd(x, slice_size)>, number<x>>::value;

    using dim_lengths = sequence<slice_length>;
    using dim_slices = sequence<x / slice_length>;
    using remaining_slice_sizes =
        std::conditional_t<m, sequence<slice_size / slice_length>, sequence<slice_size>>;

    // the first idx that sliced length not equal to original length
    static constexpr index_t _flag =
        slice_length != x && remaining_slice_sizes{}.front().value == 1;
    static constexpr index_t split_flag = std::conditional_t<m, number<_flag>, number<0>>::value;
    static constexpr index_t split_idx =
        std::conditional_t<split_flag, number<id>, number<0>>::value;
};
} // namespace impl
// clang-format off
// input a sequence(with optional mask), and the SliceSize : size per slice
// output the sequence each slice, and number of slices
//
// e.g. <2, 1, 4, 2>, 8 -> lengths:<1, 1, 4, 2> , nums: <2, 1, 1, 1> : 2 slices , slice_idx: 0
// <4, 2, 4, 1, 2>, 4 -> lengths:<1, 1, 2, 1, 2> , nums: <4, 2, 2, 1, 1> : 16 slices , slice_idx: 2
// <4, 2, 4, 1, 6>, 4 -> lengths:<1, 1, 2, 1, 2> , nums: <4, 2, 2, 1, 3> : 48 slices , slice_idx: 2
// <4, 2, 5, 1, 2>, 10 -> lengths:<1, 1, 5, 1, 2> , nums: <4, 2, 1, 1, 1> : 8 slices , slice_idx: 1
//
// <4, 2, 8>, 64 -> lengths:<4, 2, 8> , nums: <1, 1, 1> : 1 slices , slice_idx: 0
// <4, 2, 8>, 32 -> lengths:<2, 2, 8> , nums: <2, 1, 1> : 2 slices , slice_idx: 0
// <4, 2, 8>, 16 -> lengths:<1, 2, 8> , nums: <4, 1, 1> : 4 slices , slice_idx: 0
// <4, 2, 8>, 8 -> lengths:<1, 1, 8> , nums: <4, 2, 1> : 8 slices , slice_idx: 1
// <4, 2, 8>, 4 -> lengths:<1, 1, 4> , nums: <4, 2, 2> : 16 slices , slice_idx: 2
// <4, 2, 8>, 2 -> lengths:<1, 1, 2> , nums: <4, 2, 4> : 32 slices , slice_idx: 2
// <4, 2, 8>, 1 -> lengths:<1, 1, 1> , nums: <4, 2, 8> : 64 slices , slice_idx: 2
//
// <4, 2, 1, 4, 2> / 4 ->
// mask:<1, 1, 1, 0, 1>, -> lengths:<1, 2, 1, 4, 2> , nums: <4, 1, 1, 1, 1> : 8 slices , slice_idx: 0
//
// return tuple<slice_lengths, slice_nums, slice_index>, slice_index is at which index will start
// have split slices (right -> left)
// or the first index that sliced length is different from the original length
// clang-format on
// Slice Seq (scanning right to left) into pieces of SliceSize elements;
// dimensions with Mask == 0 are never split. Returns
// tuple<slice_lengths, slice_nums, split_idx> — see the worked table above.
// Compile-time error if SliceSize does not evenly divide the sequence.
template <typename Seq,
          index_t SliceSize,
          typename Mask = typename uniform_sequence_gen<Seq::size(), 1>::type>
constexpr auto reverse_slice_sequence(Seq,
                                      number<SliceSize>,
                                      Mask = typename uniform_sequence_gen<Seq::size(), 1>::type{})
{
    static_assert(Seq::size() == Mask::size());
    using sliced_type =
        impl::reverse_slice_sequence_impl<Seq,
                                          Mask,
                                          typename arithmetic_sequence_gen<0, Seq::size(), 1>::type,
                                          SliceSize>;
    // a fully-consumed budget means the division was exact
    static_assert(sliced_type::remaining_slice_sizes::front().value == 1,
                  "can not evenly divide this sequence, please check");
    return make_tuple(typename sliced_type::dim_lengths{},
                      typename sliced_type::dim_slices{},
                      number<sliced_type::split_idx>{});
}
template <typename Seq,
index_t SliceSize,
typename Mask = typename uniform_sequence_gen<Seq::size(), 1>::type>
constexpr auto slice_sequence(Seq,
number<SliceSize>,
Mask = typename uniform_sequence_gen<Seq::size(), 1>::type{})
{
constexpr auto r =
reverse_slice_sequence(Seq{}.reverse(), number<SliceSize>{}, Mask{}.reverse());
return make_tuple(r[number<0>{}].reverse(),
r[number<1>{}].reverse(),
number<Seq::size() - r[number<2>{}] - 1>{});
}
} // namespace ck_tile } // namespace ck_tile
...@@ -488,6 +488,26 @@ CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x, const Y& y, ...@@ -488,6 +488,26 @@ CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x, const Y& y,
f, x, y, z, typename arithmetic_sequence_gen<0, X::size(), 1>::type{}); f, x, y, z, typename arithmetic_sequence_gen<0, X::size(), 1>::type{});
} }
namespace detail {

// Apply f to every element of x and splice all resulting tuples into one flat
// tuple via concat_tuple.
template <typename F, typename X, index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto embed_tuples_impl(F f, const X& x, sequence<Is...>)
{
    return concat_tuple(f(x.at(number<Is>{}))...);
}

} // namespace detail

// make sure F return at least a tuple
// e.g. x : tuple<X, Y>, f will return tuple<Z, W>
// this function will return tuple<Z, W, Z, W> — f applied to each element of x
// in order, with the per-element result tuples concatenated.
template <typename F, typename X>
CK_TILE_HOST_DEVICE constexpr auto embed_tuples(F f, const X& x)
{
    return detail::embed_tuples_impl(
        f, x, typename arithmetic_sequence_gen<0, X::size(), 1>::type{});
}
// By default unroll to the flatten // By default unroll to the flatten
template <index_t Depth = 0, index_t MaxDepth = -1> template <index_t Depth = 0, index_t MaxDepth = -1>
CK_TILE_HOST_DEVICE constexpr auto unroll_nested_tuple(const tuple<>& t) CK_TILE_HOST_DEVICE constexpr auto unroll_nested_tuple(const tuple<>& t)
...@@ -603,7 +623,7 @@ template <typename... Ys, ...@@ -603,7 +623,7 @@ template <typename... Ys,
false> false>
CK_TILE_HOST_DEVICE constexpr auto operator+=(tuple<Ys...>& y, const X& x) CK_TILE_HOST_DEVICE constexpr auto operator+=(tuple<Ys...>& y, const X& x)
{ {
static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same"); static_assert(X::size() == sizeof...(Ys), "wrong! size not the same");
constexpr index_t NSize = sizeof...(Ys); constexpr index_t NSize = sizeof...(Ys);
static_for<0, NSize, 1>{}([&](auto i) { y[i] += x[i]; }); static_for<0, NSize, 1>{}([&](auto i) { y[i] += x[i]; });
return y; return y;
...@@ -615,7 +635,7 @@ template <typename... Ys, ...@@ -615,7 +635,7 @@ template <typename... Ys,
false> false>
CK_TILE_HOST_DEVICE constexpr auto operator-=(tuple<Ys...>& y, const X& x) CK_TILE_HOST_DEVICE constexpr auto operator-=(tuple<Ys...>& y, const X& x)
{ {
static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same"); static_assert(X::size() == sizeof...(Ys), "wrong! size not the same");
constexpr index_t NSize = sizeof...(Ys); constexpr index_t NSize = sizeof...(Ys);
static_for<0, NSize, 1>{}([&](auto i) { y[i] -= x[i]; }); static_for<0, NSize, 1>{}([&](auto i) { y[i] -= x[i]; });
return y; return y;
...@@ -627,7 +647,7 @@ template <typename... Xs, ...@@ -627,7 +647,7 @@ template <typename... Xs,
false> false>
CK_TILE_HOST_DEVICE constexpr auto operator+(const tuple<Xs...>& x, const Y& y) CK_TILE_HOST_DEVICE constexpr auto operator+(const tuple<Xs...>& x, const Y& y)
{ {
static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same"); static_assert(Y::size() == sizeof...(Xs), "wrong! size not the same");
constexpr index_t NSize = sizeof...(Xs); constexpr index_t NSize = sizeof...(Xs);
tuple<Xs...> r; tuple<Xs...> r;
...@@ -635,13 +655,21 @@ CK_TILE_HOST_DEVICE constexpr auto operator+(const tuple<Xs...>& x, const Y& y) ...@@ -635,13 +655,21 @@ CK_TILE_HOST_DEVICE constexpr auto operator+(const tuple<Xs...>& x, const Y& y)
return r; return r;
} }
// Element-wise addition of two equally-sized tuples.
template <typename... Xs, typename... Ys>
CK_TILE_HOST_DEVICE constexpr auto operator+(const tuple<Xs...>& x, const tuple<Ys...>& y)
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong!");
    return generate_tuple([&x, &y](auto idx) { return x[idx] + y[idx]; },
                          number<sizeof...(Xs)>{});
}
template <typename... Xs, template <typename... Xs,
typename Y, typename Y,
std::enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> = std::enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> =
false> false>
CK_TILE_HOST_DEVICE constexpr auto operator-(const tuple<Xs...>& x, const Y& y) CK_TILE_HOST_DEVICE constexpr auto operator-(const tuple<Xs...>& x, const Y& y)
{ {
static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same"); static_assert(Y::size() == sizeof...(Xs), "wrong! size not the same");
constexpr index_t NSize = sizeof...(Xs); constexpr index_t NSize = sizeof...(Xs);
tuple<Xs...> r; tuple<Xs...> r;
...@@ -649,13 +677,21 @@ CK_TILE_HOST_DEVICE constexpr auto operator-(const tuple<Xs...>& x, const Y& y) ...@@ -649,13 +677,21 @@ CK_TILE_HOST_DEVICE constexpr auto operator-(const tuple<Xs...>& x, const Y& y)
return r; return r;
} }
// Element-wise subtraction of two equally-sized tuples.
template <typename... Xs, typename... Ys>
CK_TILE_HOST_DEVICE constexpr auto operator-(const tuple<Xs...>& x, const tuple<Ys...>& y)
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong!");
    return generate_tuple([&x, &y](auto idx) { return x[idx] - y[idx]; },
                          number<sizeof...(Xs)>{});
}
template <typename... Xs, template <typename... Xs,
typename Y, typename Y,
std::enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> = std::enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> =
false> false>
CK_TILE_HOST_DEVICE constexpr auto operator*(const tuple<Xs...>& x, const Y& y) CK_TILE_HOST_DEVICE constexpr auto operator*(const tuple<Xs...>& x, const Y& y)
{ {
static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same"); static_assert(Y::size() == sizeof...(Xs), "wrong! size not the same");
constexpr index_t NSize = sizeof...(Xs); constexpr index_t NSize = sizeof...(Xs);
tuple<Xs...> r; tuple<Xs...> r;
...@@ -686,6 +722,14 @@ CK_TILE_HOST_DEVICE constexpr auto operator*(const tuple<Xs...>& x, Y a) ...@@ -686,6 +722,14 @@ CK_TILE_HOST_DEVICE constexpr auto operator*(const tuple<Xs...>& x, Y a)
return a * x; return a * x;
} }
// Element-wise (Hadamard) product of two equally-sized tuples.
template <typename... Xs, typename... Ys>
CK_TILE_HOST_DEVICE constexpr auto operator*(const tuple<Xs...>& x, const tuple<Ys...>& y)
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong!");
    return generate_tuple([&x, &y](auto idx) { return x[idx] * y[idx]; },
                          number<sizeof...(Xs)>{});
}
template <typename... Xs, typename... Ys> template <typename... Xs, typename... Ys>
CK_TILE_HOST_DEVICE constexpr auto operator/(const tuple<Xs...>& x, const tuple<Ys...>& y) CK_TILE_HOST_DEVICE constexpr auto operator/(const tuple<Xs...>& x, const tuple<Ys...>& y)
{ {
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/random.hpp"
#include <stdint.h>
#include <type_traits>
#pragma once
namespace ck_tile {
// use int8_t directly for int8 arithmetic
// here one can use ck_tile::int8_t to access original int8_t
using int8_t = int8_t;
// limits
template <class T>
struct numeric;
// std::numeric_limits-style traits for int8_t. Float-only notions (epsilon,
// infinity, NaN, denorm) return 1 as a harmless placeholder.
template <>
struct numeric<int8_t>
{
    // minimum finite value, or minimum positive normalized value for float
    CK_TILE_HOST_DEVICE static constexpr int8_t min() { return int8_t(-128); }

    // minimum finite value
    CK_TILE_HOST_DEVICE static constexpr int8_t lowest() { return int8_t(-128); }

    // maximum finite value
    CK_TILE_HOST_DEVICE static constexpr int8_t max() { return int8_t(127); }

    // difference between 1.0 and next value representable by float
    CK_TILE_HOST_DEVICE static constexpr int8_t epsilon()
    {
        return 1; // not used
    }

    CK_TILE_HOST_DEVICE static constexpr int8_t round_error()
    {
        return 1; // not used
    }

    // positive infinity value
    CK_TILE_HOST_DEVICE static constexpr int8_t infinity()
    {
        return 1; // not used
    }

    // quiet NaN
    CK_TILE_HOST_DEVICE static constexpr int8_t quiet_NaN()
    {
        return 1; // not used
    }

    // signaling NaN
    CK_TILE_HOST_DEVICE static constexpr int8_t signaling_NaN()
    {
        return 1; // not used
    }

    // smallest positive subnormal value
    CK_TILE_HOST_DEVICE static constexpr int8_t denorm_min()
    {
        return 1; // not used
    }

    CK_TILE_HOST_DEVICE static constexpr int8_t zero() { return 0; }
};
#if 0
template <typename T>
struct numeric_traits;
template <>
struct numeric_traits<int8_t>
{
static constexpr int exp = 5;
static constexpr int mant = 10;
static constexpr int bias = 15;
static constexpr uint16_t nan_mask = 0x7C00;
static constexpr uint16_t head_mask = 0xFC00;
static constexpr uint16_t mant_mask = 0x3FF;
static constexpr uint16_t exp_mask = 0x1F;
static constexpr uint32_t Inf = 0x7C00;
static constexpr uint32_t NegInf = 0xFC00;
static constexpr uint32_t NaN = 0x7C01;
static constexpr uint32_t Neg0 = 0x8000;
using bitwise_type = uint16_t;
};
#endif
// Widening conversion: every int8_t value is exactly representable as float.
CK_TILE_HOST_DEVICE
constexpr float int8_to_float(const int8_t& x) { return static_cast<float>(x); }

// Narrowing conversion; out-of-range inputs are implementation-defined (C++17).
CK_TILE_HOST_DEVICE
constexpr int8_t float_to_int8(const float& x) { return static_cast<int8_t>(x); }
} // namespace ck_tile
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -487,55 +487,12 @@ struct log2e<float> ...@@ -487,55 +487,12 @@ struct log2e<float>
template <typename T = double> template <typename T = double>
constexpr T log2e_v = log2e<T>::value; constexpr T log2e_v = log2e<T>::value;
// math
CK_TILE_HOST_DEVICE
float abs(const float& x)
{
union
{
float f32;
uint32_t u32;
} y;
y.f32 = x;
y.u32 = y.u32 & 0x7fffffff;
return y.f32;
}
CK_TILE_HOST_DEVICE
bool isnan(const float& x)
{
uint32_t xx = bit_cast<uint32_t>(x);
return (xx & 0x7fffffff) > 0x7F800000;
}
CK_TILE_HOST float sqrt(float x) { return std::sqrt(x); };
CK_TILE_HOST double sqrt(double x) { return std::sqrt(x); };
CK_TILE_DEVICE
float sqrt(float x) { return __builtin_amdgcn_sqrtf(x); };
CK_TILE_DEVICE
double sqrt(double x) { return __builtin_amdgcn_sqrt(x); };
CK_TILE_DEVICE
float exp(float x) { return __ocml_exp_f32(x); };
CK_TILE_HOST
float exp(float x) { return std::expf(x); }
CK_TILE_DEVICE CK_TILE_DEVICE
float exp2(float x) { return exp2f(x); }; float exp2(float x) { return exp2f(x); };
CK_TILE_HOST CK_TILE_HOST
float exp2(float x) { return std::exp2f(x); }; float exp2(float x) { return std::exp2f(x); };
CK_TILE_DEVICE
float log(float x) { return __logf(x); };
CK_TILE_HOST
float log(float x) { return std::logf(x); };
CK_TILE_DEVICE uint16_t sad_u16(uint16_t x, uint16_t y, uint16_t acc) CK_TILE_DEVICE uint16_t sad_u16(uint16_t x, uint16_t y, uint16_t acc)
{ {
return __builtin_amdgcn_sad_u16(x, y, acc); return __builtin_amdgcn_sad_u16(x, y, acc);
...@@ -554,4 +511,933 @@ CK_TILE_HOST uint32_t sad_u32(uint32_t x, uint32_t y, uint32_t acc) ...@@ -554,4 +511,933 @@ CK_TILE_HOST uint32_t sad_u32(uint32_t x, uint32_t y, uint32_t acc)
return (x > y ? (x - y) : (y - x)) + acc; return (x > y ? (x - y) : (y - x)) + acc;
} }
///////////////////////////////////////////////////////////////
} // namespace ck_tile
// blow function need data type pre-defined
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/type_convert.hpp"
#ifndef __HIP_DEVICE_COMPILE__
#include <cmath>
#endif
namespace ck_tile {
#if CK_TILE_WORKAROUND_SWDEV_383542
extern "C" CK_TILE_DEVICE float __ocml_native_recip_f32(float);
#endif
// math functions for the host, some are implemented by calling C++ std functions
// Host abs overloads. float/double forward to std::abs; the integer versions
// use the branchless sign-extend trick (sgn is all-ones for negative x).
CK_TILE_HOST float abs(float x) { return std::abs(x); };
CK_TILE_HOST double abs(double x) { return std::abs(x); };
CK_TILE_HOST int8_t abs(int8_t x)
{
    int8_t sgn = x >> (8 - 1);
    return (x ^ sgn) - sgn;
};
CK_TILE_HOST int32_t abs(int32_t x)
{
    int32_t sgn = x >> (32 - 1);
    return (x ^ sgn) - sgn;
};
// fp16: clear the sign bit (bit 15) and reinterpret.
CK_TILE_HOST fp16_t abs(fp16_t x)
{
    uint16_t xx = bit_cast<uint16_t>(x);
    uint16_t abs_xx = xx & 0x7fff;
    fp16_t abs_x = bit_cast<fp16_t>(abs_xx);
    return abs_x;
};
#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
CK_TILE_HOST int4_t abs(int4_t x)
{
    int4_t sgn = x >> (4 - 1);
    return (x ^ sgn) - sgn;
}
#endif
// Host isnan overloads. Integer types can never be NaN, so they return false
// unconditionally ((void)x silences the unused-parameter warning).
CK_TILE_HOST bool isnan(float x) { return std::isnan(x); };
CK_TILE_HOST bool isnan(double x) { return std::isnan(x); };
CK_TILE_HOST bool isnan(int8_t x)
{
    (void)x;
    return false;
};
CK_TILE_HOST bool isnan(int32_t x)
{
    (void)x;
    return false;
};
// fp16 NaN: exponent all ones (0x7C00) with a non-zero mantissa; sign ignored.
CK_TILE_HOST bool isnan(fp16_t x)
{
    uint16_t xx = bit_cast<uint16_t>(x);
    return (xx & 0x7FFF) > 0x7C00;
};
#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
CK_TILE_HOST bool isnan(int4_t x)
{
    (void)x;
    return false;
};
#endif
// Host sqrt overloads; the fp16 version computes at fp32 precision.
CK_TILE_HOST fp16_t sqrt(fp16_t x)
{
    return static_cast<fp16_t>(std::sqrt(static_cast<float>(x)));
};
CK_TILE_HOST float sqrt(float x) { return std::sqrt(x); };
CK_TILE_HOST double sqrt(double x) { return std::sqrt(x); };
// Host tanh: generic version round-trips through fp32; float/double call the
// std library directly (double stays at full precision).
template <typename T>
CK_TILE_HOST T tanh(T x)
{
    return type_convert<T>(std::tanhf(type_convert<float>(x)));
};
template <>
CK_TILE_HOST float tanh<float>(float x)
{
    return std::tanhf(x);
};
template <>
CK_TILE_HOST double tanh<double>(double x)
{
    return std::tanh(x);
};
// Host acos: same dispatch pattern (generic via fp32, float/double via libm).
template <typename T>
CK_TILE_HOST T acos(T x)
{
    return type_convert<T>(std::acosf(type_convert<float>(x)));
};
template <>
CK_TILE_HOST float acos<float>(float x)
{
    return std::acosf(x);
};
template <>
CK_TILE_HOST double acos<double>(double x)
{
    return std::acos(x);
};
// Host negation. NOTE(review): the generic version round-trips through fp32,
// which loses precision/range for wide types — int32/int8 are specialized
// below to avoid that; confirm other integer widths are not used here.
template <typename T>
CK_TILE_HOST T neg(T x)
{
    return type_convert<T>(-(type_convert<float>(x)));
};
template <>
CK_TILE_HOST float neg<float>(float x)
{
    return -x;
};
template <>
CK_TILE_HOST double neg<double>(double x)
{
    return -x;
};
template <>
CK_TILE_HOST int32_t neg<int32_t>(int32_t x)
{
    return -x;
};
template <>
CK_TILE_HOST int8_t neg<int8_t>(int8_t x)
{
    return -x;
};
// Host atan/sin/asin: generic versions round-trip through fp32; the
// float/double specializations forward to the corresponding libm call.
template <typename T>
CK_TILE_HOST T atan(T x)
{
    return type_convert<T>(std::atanf(type_convert<float>(x)));
};
template <>
CK_TILE_HOST float atan<float>(float x)
{
    return std::atanf(x);
};
template <>
CK_TILE_HOST double atan<double>(double x)
{
    return std::atan(x);
};
template <typename T>
CK_TILE_HOST T sin(T x)
{
    return type_convert<T>(std::sinf(type_convert<float>(x)));
};
template <>
CK_TILE_HOST float sin<float>(float x)
{
    return std::sinf(x);
};
template <>
CK_TILE_HOST double sin<double>(double x)
{
    return std::sin(x);
};
template <typename T>
CK_TILE_HOST T asin(T x)
{
    return type_convert<T>(std::asinf(type_convert<float>(x)));
};
template <>
CK_TILE_HOST float asin<float>(float x)
{
    return std::asinf(x);
};
template <>
CK_TILE_HOST double asin<double>(double x)
{
    return std::asin(x);
};
// Host asinh/cos/acosh: generic versions round-trip through fp32; the
// float/double specializations forward to the corresponding libm call.
template <typename T>
CK_TILE_HOST T asinh(T x)
{
    return type_convert<T>(std::asinhf(type_convert<float>(x)));
};
template <>
CK_TILE_HOST float asinh<float>(float x)
{
    return std::asinhf(x);
};
template <>
CK_TILE_HOST double asinh<double>(double x)
{
    return std::asinh(x);
};
template <typename T>
CK_TILE_HOST T cos(T x)
{
    return type_convert<T>(std::cosf(type_convert<float>(x)));
};
template <>
CK_TILE_HOST float cos<float>(float x)
{
    return std::cosf(x);
};
template <>
CK_TILE_HOST double cos<double>(double x)
{
    return std::cos(x);
};
template <typename T>
CK_TILE_HOST T acosh(T x)
{
    return type_convert<T>(std::acoshf(type_convert<float>(x)));
};
template <>
CK_TILE_HOST float acosh<float>(float x)
{
    return std::acoshf(x);
};
template <>
CK_TILE_HOST double acosh<double>(double x)
{
    return std::acosh(x);
};
// Host tan/atanh/sinh: generic versions round-trip through fp32; the
// float/double specializations forward to the corresponding libm call.
template <typename T>
CK_TILE_HOST T tan(T x)
{
    return type_convert<T>(std::tanf(type_convert<float>(x)));
};
template <>
CK_TILE_HOST float tan<float>(float x)
{
    return std::tanf(x);
};
template <>
CK_TILE_HOST double tan<double>(double x)
{
    return std::tan(x);
};
template <typename T>
CK_TILE_HOST T atanh(T x)
{
    return type_convert<T>(std::atanhf(type_convert<float>(x)));
};
template <>
CK_TILE_HOST float atanh<float>(float x)
{
    return std::atanhf(x);
};
template <>
CK_TILE_HOST double atanh<double>(double x)
{
    return std::atanh(x);
};
template <typename T>
CK_TILE_HOST T sinh(T x)
{
    return type_convert<T>(std::sinhf(type_convert<float>(x)));
};
template <>
CK_TILE_HOST float sinh<float>(float x)
{
    return std::sinhf(x);
};
template <>
CK_TILE_HOST double sinh<double>(double x)
{
    return std::sinh(x);
};
// Host ceil/cosh/floor: generic versions round-trip through fp32; the
// float/double specializations forward to the corresponding libm call.
template <typename T>
CK_TILE_HOST T ceil(T x)
{
    return type_convert<T>(std::ceilf(type_convert<float>(x)));
};
template <>
CK_TILE_HOST float ceil<float>(float x)
{
    return std::ceilf(x);
};
template <>
CK_TILE_HOST double ceil<double>(double x)
{
    return std::ceil(x);
};
template <typename T>
CK_TILE_HOST T cosh(T x)
{
    return type_convert<T>(std::coshf(type_convert<float>(x)));
};
template <>
CK_TILE_HOST float cosh<float>(float x)
{
    return std::coshf(x);
};
template <>
CK_TILE_HOST double cosh<double>(double x)
{
    return std::cosh(x);
};
template <typename T>
CK_TILE_HOST T floor(T x)
{
    return type_convert<T>(std::floorf(type_convert<float>(x)));
};
template <>
CK_TILE_HOST float floor<float>(float x)
{
    return std::floorf(x);
};
template <>
CK_TILE_HOST double floor<double>(double x)
{
    return std::floor(x);
};
// Host reciprocal: computed in fp32 regardless of T, then converted back
// (no double-precision specialization is provided).
template <typename T>
CK_TILE_HOST T rcp(T x)
{
    const float inv = 1.f / type_convert<float>(x);
    return type_convert<T>(inv);
};
// Host exp/log/pow/expm1: generic versions round-trip through fp32; the
// float/double specializations forward to libm.
// NOTE(review): the f-suffixed std:: names (std::expf, std::logf, std::powf,
// std::expm1f) are in the C++17 <cmath> synopsis but have historically been
// missing from some libstdc++ releases — confirm all supported toolchains
// accept them (std::exp(float) etc. would be the portable spelling).
template <typename T>
CK_TILE_HOST T exp(T x)
{
    return type_convert<T>(std::expf(type_convert<float>(x)));
}
template <>
CK_TILE_HOST float exp<float>(float x)
{
    return std::expf(x);
}
template <>
CK_TILE_HOST double exp<double>(double x)
{
    return std::exp(x);
}
template <typename T>
CK_TILE_HOST T log(T x)
{
    return type_convert<T>(std::logf(type_convert<float>(x)));
}
template <>
CK_TILE_HOST float log<float>(float x)
{
    return std::logf(x);
}
template <>
CK_TILE_HOST double log<double>(double x)
{
    return std::log(x);
}
template <typename T>
CK_TILE_HOST T pow(T x, T gamma)
{
    return type_convert<T>(std::powf(type_convert<float>(x), type_convert<float>(gamma)));
}
template <>
CK_TILE_HOST float pow<float>(float x, float gamma)
{
    return std::powf(x, gamma);
}
template <>
CK_TILE_HOST double pow<double>(double x, double gamma)
{
    return std::pow(x, gamma);
}
template <typename T>
CK_TILE_HOST T expm1(T x)
{
    return type_convert<T>(std::expm1f(type_convert<float>(x)));
}
template <>
CK_TILE_HOST float expm1<float>(float x)
{
    return std::expm1f(x);
}
template <>
CK_TILE_HOST double expm1<double>(double x)
{
    return std::expm1(x);
}
// math functions for the HIP kernel, some are implemented by calling hip builtin functions
// Device abs overloads. float clears the sign bit through a type-punning
// union; integers use the branchless sign-extend trick.
CK_TILE_DEVICE float abs(float x)
{
    union
    {
        float f32;
        uint32_t u32;
    } y;
    y.f32 = x;
    y.u32 = y.u32 & 0x7fffffff;
    return y.f32;
};
CK_TILE_DEVICE double abs(double x) { return ::abs(x); };
CK_TILE_DEVICE int8_t abs(int8_t x)
{
    int8_t sgn = x >> (8 - 1);
    return (x ^ sgn) - sgn;
};
CK_TILE_DEVICE int32_t abs(int32_t x)
{
    int32_t sgn = x >> (32 - 1);
    return (x ^ sgn) - sgn;
};
#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
CK_TILE_DEVICE int4_t abs(int4_t x)
{
    int4_t sgn = x >> (4 - 1);
    return (x ^ sgn) - sgn;
};
#endif
// fp16: clear the sign bit (bit 15) and reinterpret.
CK_TILE_DEVICE fp16_t abs(fp16_t x)
{
    uint16_t xx = bit_cast<uint16_t>(x);
    uint16_t abs_xx = xx & 0x7fff;
    fp16_t abs_x = bit_cast<fp16_t>(abs_xx);
    return abs_x;
};
// Device isnan overloads; integer types are never NaN.
CK_TILE_DEVICE bool isnan(float x) { return ::isnan(x); };
CK_TILE_DEVICE bool isnan(double x) { return ::isnan(x); };
CK_TILE_DEVICE bool isnan(int8_t x)
{
    (void)x;
    return false;
};
CK_TILE_DEVICE bool isnan(int32_t x)
{
    (void)x;
    return false;
};
#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
CK_TILE_DEVICE bool isnan(int4_t x)
{
    (void)x;
    return false;
};
#endif
// fp16 NaN: exponent all ones (0x7C00) with a non-zero mantissa; sign ignored.
CK_TILE_DEVICE bool isnan(fp16_t x)
{
    uint16_t xx = bit_cast<uint16_t>(x);
    return (xx & 0x7FFF) > 0x7C00;
};
// Device sqrt via the AMDGCN hardware builtin; fp16 computes at fp32 precision.
CK_TILE_DEVICE fp16_t sqrt(fp16_t x)
{
    return static_cast<fp16_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x)));
};
CK_TILE_DEVICE float sqrt(float x) { return __builtin_amdgcn_sqrtf(x); };
CK_TILE_DEVICE double sqrt(double x) { return __builtin_amdgcn_sqrt(x); };
// Device tanh: generic version round-trips through fp32; float/double call
// the device libm directly (::-qualified so lookup does not recurse into the
// ck_tile overloads).
template <typename T>
CK_TILE_DEVICE T tanh(T x)
{
    return type_convert<T>(::tanhf(type_convert<float>(x)));
};
template <>
CK_TILE_DEVICE float tanh<float>(float x)
{
    return ::tanhf(x);
};
template <>
CK_TILE_DEVICE double tanh<double>(double x)
{
    return ::tanh(x);
};
// Device acos: same dispatch pattern.
template <typename T>
CK_TILE_DEVICE T acos(T x)
{
    return type_convert<T>(::acosf(type_convert<float>(x)));
};
template <>
CK_TILE_DEVICE float acos<float>(float x)
{
    return ::acosf(x);
};
template <>
CK_TILE_DEVICE double acos<double>(double x)
{
    return ::acos(x);
};
// Device negation. NOTE(review): the generic version round-trips through
// fp32 — wide integer types other than the specialized ones would lose range.
template <typename T>
CK_TILE_DEVICE T neg(T x)
{
    return type_convert<T>(-(type_convert<float>(x)));
};
template <>
CK_TILE_DEVICE float neg<float>(float x)
{
    return -x;
};
template <>
CK_TILE_DEVICE double neg<double>(double x)
{
    return -x;
};
template <>
CK_TILE_DEVICE int32_t neg<int32_t>(int32_t x)
{
    return -x;
};
template <>
CK_TILE_DEVICE int8_t neg<int8_t>(int8_t x)
{
    return -x;
};
template <>
CK_TILE_DEVICE fp16_t neg<fp16_t>(fp16_t x)
{
    return -x;
};
// Device atan/sin/asin: generic versions round-trip through fp32;
// float/double call the ::-qualified device libm; sin has a native fp16
// builtin specialization.
template <typename T>
CK_TILE_DEVICE T atan(T x)
{
    return type_convert<T>(::atanf(type_convert<float>(x)));
};
template <>
CK_TILE_DEVICE float atan<float>(float x)
{
    return ::atanf(x);
};
template <>
CK_TILE_DEVICE double atan<double>(double x)
{
    return ::atan(x);
};
template <typename T>
CK_TILE_DEVICE T sin(T x)
{
    return type_convert<T>(::sinf(type_convert<float>(x)));
};
template <>
CK_TILE_DEVICE float sin<float>(float x)
{
    return ::sinf(x);
};
template <>
CK_TILE_DEVICE double sin<double>(double x)
{
    return ::sin(x);
};
// fp16: native half-precision OCML builtin (no fp32 round-trip).
template <>
CK_TILE_DEVICE fp16_t sin<fp16_t>(fp16_t x)
{
    return __ocml_sin_f16(x);
};
template <typename T>
CK_TILE_DEVICE T asin(T x)
{
    return type_convert<T>(::asinf(type_convert<float>(x)));
};
template <>
CK_TILE_DEVICE float asin<float>(float x)
{
    return ::asinf(x);
};
template <>
CK_TILE_DEVICE double asin<double>(double x)
{
    return ::asin(x);
};
// Device asinh/acosh/tan/atanh: generic versions round-trip through fp32;
// float/double call the ::-qualified device libm.
template <typename T>
CK_TILE_DEVICE T asinh(T x)
{
    return type_convert<T>(::asinhf(type_convert<float>(x)));
};
template <>
CK_TILE_DEVICE float asinh<float>(float x)
{
    return ::asinhf(x);
};
template <>
CK_TILE_DEVICE double asinh<double>(double x)
{
    return ::asinh(x);
};
template <typename T>
CK_TILE_DEVICE T acosh(T x)
{
    return type_convert<T>(::acoshf(type_convert<float>(x)));
};
template <>
CK_TILE_DEVICE float acosh<float>(float x)
{
    return ::acoshf(x);
};
template <>
CK_TILE_DEVICE double acosh<double>(double x)
{
    return ::acosh(x);
};
template <typename T>
CK_TILE_DEVICE T tan(T x)
{
    return type_convert<T>(::tanf(type_convert<float>(x)));
};
template <>
CK_TILE_DEVICE float tan<float>(float x)
{
    return ::tanf(x);
};
template <>
CK_TILE_DEVICE double tan<double>(double x)
{
    return ::tan(x);
};
template <typename T>
CK_TILE_DEVICE T atanh(T x)
{
    return type_convert<T>(::atanhf(type_convert<float>(x)));
};
template <>
CK_TILE_DEVICE float atanh<float>(float x)
{
    return ::atanhf(x);
};
template <>
CK_TILE_DEVICE double atanh<double>(double x)
{
    return ::atanh(x);
};
// Device sinh/ceil/cosh/floor: generic versions round-trip through fp32;
// float/double call the ::-qualified device libm; ceil and floor have native
// fp16 OCML builtin specializations.
template <typename T>
CK_TILE_DEVICE T sinh(T x)
{
    return type_convert<T>(::sinhf(type_convert<float>(x)));
};
template <>
CK_TILE_DEVICE float sinh<float>(float x)
{
    return ::sinhf(x);
};
template <>
CK_TILE_DEVICE double sinh<double>(double x)
{
    return ::sinh(x);
};
template <typename T>
CK_TILE_DEVICE T ceil(T x)
{
    return type_convert<T>(::ceilf(type_convert<float>(x)));
};
template <>
CK_TILE_DEVICE float ceil<float>(float x)
{
    return ::ceilf(x);
};
template <>
CK_TILE_DEVICE double ceil<double>(double x)
{
    return ::ceil(x);
};
template <>
CK_TILE_DEVICE fp16_t ceil<fp16_t>(fp16_t x)
{
    return __ocml_ceil_f16(x);
};
template <typename T>
CK_TILE_DEVICE T cosh(T x)
{
    return type_convert<T>(::coshf(type_convert<float>(x)));
};
template <>
CK_TILE_DEVICE float cosh<float>(float x)
{
    return ::coshf(x);
};
template <>
CK_TILE_DEVICE double cosh<double>(double x)
{
    return ::cosh(x);
};
template <typename T>
CK_TILE_DEVICE T floor(T x)
{
    return type_convert<T>(::floorf(type_convert<float>(x)));
};
template <>
CK_TILE_DEVICE float floor<float>(float x)
{
    return ::floorf(x);
};
template <>
CK_TILE_DEVICE double floor<double>(double x)
{
    return ::floor(x);
};
template <>
CK_TILE_DEVICE fp16_t floor<fp16_t>(fp16_t x)
{
    return __ocml_floor_f16(x);
};
// Device reciprocal. NOTE(review): both branches compute in fp32 (the
// builtins take/return float regardless of T), so non-float T is converted —
// confirm callers only instantiate this with float.
template <typename T>
CK_TILE_DEVICE T rcp(T x)
{
#if !CK_TILE_WORKAROUND_SWDEV_383542
    return __frcp_rn(x);
#else
    // return __ocml_native_recip_f32(x);
    return __builtin_amdgcn_rcpf(x);
#endif
};
// Device exp: generic version converts through fp32 and uses the OCML builtin.
template <typename T>
CK_TILE_DEVICE T exp(T x)
{
    return type_convert<T>(__ocml_exp_f32(type_convert<float>(x)));
};
// fp16: native half-precision builtin (no fp32 round-trip).
template <>
CK_TILE_DEVICE fp16_t exp<fp16_t>(fp16_t x)
{
    return __ocml_exp_f16(x);
};
template <>
CK_TILE_DEVICE float exp<float>(float x)
{
    return __ocml_exp_f32(x);
};
// Double precision: call the global libm ::exp. The previous unqualified
// `exp(x)` resolved back to ck_tile::exp<double> itself (the ck_tile
// template hides ::exp during unqualified lookup), i.e. infinite recursion.
template <>
CK_TILE_DEVICE double exp<double>(double x)
{
    return ::exp(x);
};
// Device log: generic version converts through fp32 and uses the fast __logf.
template <typename T>
CK_TILE_DEVICE T log(T x)
{
    return type_convert<T>(__logf(type_convert<float>(x)));
};
// fp16: native half-precision builtin (no fp32 round-trip).
template <>
CK_TILE_DEVICE fp16_t log<fp16_t>(fp16_t x)
{
    return __ocml_log_f16(x);
};
template <>
CK_TILE_DEVICE float log<float>(float x)
{
    return __logf(x);
};
// Double precision: call the global libm ::log. The previous unqualified
// `log(x)` resolved back to this same specialization (ck_tile::log hides
// ::log during unqualified lookup), i.e. infinite recursion.
template <>
CK_TILE_DEVICE double log<double>(double x)
{
    return ::log(x);
};
// Device pow: generic version computes in fp32 via powf.
template <typename T>
CK_TILE_DEVICE T pow(T x, T gamma)
{
    return type_convert<T>(powf(type_convert<float>(x), type_convert<float>(gamma)));
};
template <>
CK_TILE_DEVICE float pow<float>(float x, float gamma)
{
    return powf(x, gamma);
};
// Double precision: call the global libm ::pow. The previous unqualified
// `pow(x, gamma)` resolved back to this same specialization (infinite
// recursion), since ck_tile::pow hides ::pow during unqualified lookup.
template <>
CK_TILE_DEVICE double pow<double>(double x, double gamma)
{
    return ::pow(x, gamma);
};
// Device expm1: generic version computes in fp32 via expm1f (accurate for
// small |x| where exp(x)-1 would cancel).
template <typename T>
CK_TILE_DEVICE T expm1(T x)
{
    return type_convert<T>(expm1f(type_convert<float>(x)));
};
template <>
CK_TILE_DEVICE float expm1<float>(float x)
{
    return expm1f(x);
};
// Double precision: call the global libm ::expm1. The previous unqualified
// `expm1(x)` resolved back to this same specialization (infinite recursion),
// since ck_tile::expm1 hides ::expm1 during unqualified lookup.
template <>
CK_TILE_DEVICE double expm1<double>(double x)
{
    return ::expm1(x);
};
} // namespace ck_tile } // namespace ck_tile
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment