Commit 7971bb5b authored by carlushuang's avatar carlushuang
Browse files

add test for scatter/gather

parent d311c953
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include "ck_tile/core/algorithm/cluster_descriptor.hpp" #include "ck_tile/core/algorithm/cluster_descriptor.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp" #include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/algorithm/indexing_adaptor.hpp"
#include "ck_tile/core/algorithm/space_filling_curve.hpp" #include "ck_tile/core/algorithm/space_filling_curve.hpp"
#include "ck_tile/core/arch/amd_buffer_addressing.hpp" #include "ck_tile/core/arch/amd_buffer_addressing.hpp"
#include "ck_tile/core/arch/arch.hpp" #include "ck_tile/core/arch/arch.hpp"
......
...@@ -23,6 +23,7 @@ enum struct coord_transform_enum ...@@ -23,6 +23,7 @@ enum struct coord_transform_enum
replicate, replicate,
xor_t, xor_t,
offset, offset,
indexing,
}; };
template <index_t NDimLow, index_t NDimUp> template <index_t NDimLow, index_t NDimUp>
...@@ -1549,6 +1550,184 @@ struct offset : public base_transform<1, 1> ...@@ -1549,6 +1550,184 @@ struct offset : public base_transform<1, 1>
} }
}; };
#if 0
template <typename UpLength,
typename Index>
struct indexing : public base_transform<1, 1>
{
static constexpr index_t NDimUp = 1;
using LowerIndex = multi_index<1>;
using UpperIndex = multi_index<1>;
using UpLengths = decltype(make_tuple(UpLength{}));
using Indices = decltype(make_tuple(Index{}));
UpLengths up_lengths_;
Indices indices_;
CK_TILE_HOST_DEVICE constexpr indexing() = default;
CK_TILE_HOST_DEVICE constexpr indexing(const UpLength& up_length,
const Index& index)
: up_lengths_{make_tuple(up_length)}, indices_{make_tuple(indices)}
{
}
CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
{
return coord_transform_enum::indexing;
}
CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
template <typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
const UpIdx& /*idx_up*/) const
{
static_assert(LowIdx::size() == 1 && UpIdx::size() == NDimUp,
"wrong! inconsistent # of dimension");
idx_low(number<0>{}) = indices_[number<0>{}];
}
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
const UpIdxDiff& /*idx_diff_up*/,
LowIdx& /*idx_low*/,
const UpIdx& /*idx_up*/) const
{
// TODO: nonthing changed here
static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == NDimUp &&
LowIdx::size() == 1 && UpIdx::size() == NDimUp,
"wrong! inconsistent # of dimension");
idx_diff_low(number<0>{}) = 0;
//static_for<0, NDimUp, 1>{}(
// [&](auto i) { idx_diff_low(number<0>{}) += idx_diff_up[i] * coefficients_[i]; });
// idx_low += idx_up;
}
CK_TILE_HOST_DEVICE static constexpr bool
is_valid_upper_index_always_mapped_to_valid_lower_index()
{
return true;
}
template <typename UpIdx>
CK_TILE_HOST_DEVICE static constexpr bool
is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */)
{
return true;
}
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
{
return ck_tile::is_known_at_compile_time<UpLengths>::value &&
ck_tile::is_known_at_compile_time<Indices>::value;
}
CK_TILE_HOST_DEVICE void print() const
{
printf("embed{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf(", ");
//
printf("indices_: ");
print(indices_);
printf("}");
}
};
#endif
template <typename UpLength, typename IndexingAdaptor>
struct indexing : public base_transform<1, 1>
{
static constexpr index_t NDimUp = 1;
using LowerIndex = multi_index<1>;
using UpperIndex = multi_index<1>;
using UpLengths = decltype(make_tuple(UpLength{}));
UpLengths up_lengths_;
IndexingAdaptor iadaptor_;
CK_TILE_HOST_DEVICE constexpr indexing() = default;
CK_TILE_HOST_DEVICE constexpr indexing(const UpLength& up_length,
const IndexingAdaptor& iadaptor)
: up_lengths_{make_tuple(up_length)}, iadaptor_{iadaptor}
{
}
CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
{
return coord_transform_enum::indexing;
}
CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
template <typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
const UpIdx& idx_up) const
{
static_assert(LowIdx::size() == 1 && UpIdx::size() == NDimUp,
"wrong! inconsistent # of dimension");
iadaptor_.calculate_lower_index(idx_low, idx_up);
}
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
const UpIdxDiff& idx_diff_up,
LowIdx& idx_low,
const UpIdx& idx_up) const
{
// TODO: nonthing changed here
static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == NDimUp &&
LowIdx::size() == 1 && UpIdx::size() == NDimUp,
"wrong! inconsistent # of dimension");
iadaptor_.update_lower_index(idx_diff_low, idx_diff_up, idx_low, idx_up);
}
CK_TILE_HOST_DEVICE static constexpr bool
is_valid_upper_index_always_mapped_to_valid_lower_index()
{
return true;
}
template <typename UpIdx>
CK_TILE_HOST_DEVICE static constexpr bool
is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */)
{
return true;
}
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
{
return ck_tile::is_known_at_compile_time<UpLengths>::value &&
IndexingAdaptor::is_known_at_compile_time();
}
CK_TILE_HOST_DEVICE void print() const
{
printf("embed{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf(", ");
printf("}");
}
};
//******************************************************************************************************* //*******************************************************************************************************
template <typename LowLength> template <typename LowLength>
...@@ -1670,3 +1849,24 @@ CK_TILE_HOST_DEVICE constexpr auto make_offset_transform(const LowLength& low_le ...@@ -1670,3 +1849,24 @@ CK_TILE_HOST_DEVICE constexpr auto make_offset_transform(const LowLength& low_le
} }
} // namespace ck_tile } // namespace ck_tile
#include "ck_tile/core/algorithm/indexing_adaptor.hpp"
namespace ck_tile {
template <typename UpLength, typename Indices>
CK_TILE_HOST_DEVICE constexpr auto make_indexing_transform(const UpLength& up_lengths,
const Indices& indices)
{
// by default we use the simplest one
return indexing<UpLength, indexing_adaptor_onshot_cached<remove_cvref_t<Indices>>>{
up_lengths, indexing_adaptor_onshot_cached<remove_cvref_t<Indices>>{indices}};
}
template <typename UpLength, typename IndexingAdaptor>
CK_TILE_HOST_DEVICE constexpr auto
make_indexing_transform_with_adaptor(const UpLength& up_lengths, const IndexingAdaptor& iadaptor)
{
return indexing<UpLength, IndexingAdaptor>{up_lengths, iadaptor};
}
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/multi_index.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
// pre-defined indexing adaptor used for indexing(scatter/gather)
// this version cache the index inside thread register(which is also prefered in real senario)
// however it's user's responsibility that each thread only provide one indexing, which means
// move coordinate will not change on this dim
template <typename IndexingType>
struct indexing_adaptor_onshot_cached
{
CK_TILE_HOST_DEVICE constexpr indexing_adaptor_onshot_cached() = default;
CK_TILE_HOST_DEVICE constexpr indexing_adaptor_onshot_cached(const IndexingType& idx)
: cached_idx_(idx)
{
}
IndexingType cached_idx_;
template <typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
const UpIdx& /*idx_up*/) const
{
static_assert(LowIdx::size() == 1 && UpIdx::size() == 1,
"wrong! inconsistent # of dimension");
idx_low(number<0>{}) = cached_idx_;
}
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
const UpIdxDiff& idx_diff_up,
LowIdx& /*idx_low*/,
const UpIdx& /*idx_up*/) const
{
// TODO: nonthing changed here
static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 &&
UpIdx::size() == 1,
"wrong! inconsistent # of dimension");
idx_diff_low(number<0>{}) = idx_diff_up[number<0>{}];
// pass the diff to lower, but not changing the actually index
}
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
{
return ck_tile::is_known_at_compile_time<IndexingType>::value;
}
};
} // namespace ck_tile
...@@ -31,12 +31,16 @@ ...@@ -31,12 +31,16 @@
#define CK_TILE_HOST inline __host__ #define CK_TILE_HOST inline __host__
#define CK_TILE_DEVICE inline __device__ #define CK_TILE_DEVICE inline __device__
#define CK_TILE_HOST_DEVICE inline __host__ __device__ #define CK_TILE_HOST_DEVICE inline __host__ __device__
#define CK_TILE_HOST_EXTERN __host__
#define CK_TILE_DEVICE_EXTERN __device__ #define CK_TILE_DEVICE_EXTERN __device__
#define CK_TILE_HOST_DEVICE_EXTERN __host__ __device__
#else #else
#define CK_TILE_HOST inline #define CK_TILE_HOST inline
#define CK_TILE_DEVICE inline #define CK_TILE_DEVICE inline
#define CK_TILE_HOST_DEVICE inline #define CK_TILE_HOST_DEVICE inline
#define CK_TILE_HOST_EXTERN
#define CK_TILE_DEVICE_EXTERN #define CK_TILE_DEVICE_EXTERN
#define CK_TILE_HOST_DEVICE_EXTERN
#endif #endif
#ifndef CK_TILE_USE_CUSTOM_DATA_TYPE #ifndef CK_TILE_USE_CUSTOM_DATA_TYPE
...@@ -191,3 +195,8 @@ ...@@ -191,3 +195,8 @@
#ifndef CK_TILE_BUFFER_LOAD_RAW_BF16_WA #ifndef CK_TILE_BUFFER_LOAD_RAW_BF16_WA
#define CK_TILE_BUFFER_LOAD_RAW_BF16_WA 1 #define CK_TILE_BUFFER_LOAD_RAW_BF16_WA 1
#endif #endif
// workaround: compiler not emiting reciprocal instruction frm __frcp_rn()
#ifndef CK_TILE_WORKAROUND_SWDEV_383542
#define CK_TILE_WORKAROUND_SWDEV_383542 1
#endif
...@@ -41,9 +41,8 @@ struct scales ...@@ -41,9 +41,8 @@ struct scales
Scale lhs_; Scale lhs_;
}; };
/// FIXME: create macro to replace '__host__ __device__' and nothing more
template <typename Scale> template <typename Scale>
__host__ __device__ scales(Scale)->scales<Scale>; CK_TILE_HOST_DEVICE_EXTERN scales(Scale)->scales<Scale>;
template <typename Left = void, typename Right = Left> template <typename Left = void, typename Right = Left>
struct plus struct plus
...@@ -66,8 +65,7 @@ struct plus<void, void> ...@@ -66,8 +65,7 @@ struct plus<void, void>
} }
}; };
/// FIXME: create macro to replace '__host__ __device__' and nothing more CK_TILE_HOST_DEVICE_EXTERN plus()->plus<void, void>;
__host__ __device__ plus()->plus<void, void>;
template <typename Left = void, typename Right = Left> template <typename Left = void, typename Right = Left>
struct minus struct minus
...@@ -90,8 +88,7 @@ struct minus<void, void> ...@@ -90,8 +88,7 @@ struct minus<void, void>
} }
}; };
/// FIXME: create macro to replace '__host__ __device__' and nothing more CK_TILE_HOST_DEVICE_EXTERN minus()->minus<void, void>;
__host__ __device__ minus()->minus<void, void>;
template <typename Left = void, typename Right = Left> template <typename Left = void, typename Right = Left>
struct multiplies struct multiplies
...@@ -114,8 +111,7 @@ struct multiplies<void, void> ...@@ -114,8 +111,7 @@ struct multiplies<void, void>
} }
}; };
/// FIXME: create macro to replace '__host__ __device__' and nothing more CK_TILE_HOST_DEVICE_EXTERN multiplies()->multiplies<void, void>;
__host__ __device__ multiplies()->multiplies<void, void>;
template <typename T> template <typename T>
struct maximize struct maximize
...@@ -345,8 +341,7 @@ struct equal<void, void> ...@@ -345,8 +341,7 @@ struct equal<void, void>
} }
}; };
/// FIXME: create macro to replace '__host__ __device__' and nothing more CK_TILE_HOST_DEVICE_EXTERN equal()->equal<void, void>;
__host__ __device__ equal()->equal<void, void>;
template <> template <>
struct equal<float, float> struct equal<float, float>
...@@ -387,8 +382,7 @@ struct less<void, void> ...@@ -387,8 +382,7 @@ struct less<void, void>
} }
}; };
/// FIXME: create macro to replace '__host__ __device__' and nothing more CK_TILE_HOST_DEVICE_EXTERN less()->less<void, void>;
__host__ __device__ less()->less<void, void>;
template <typename Left = void, typename Right = Left> template <typename Left = void, typename Right = Left>
struct less_equal struct less_equal
...@@ -411,8 +405,7 @@ struct less_equal<void, void> ...@@ -411,8 +405,7 @@ struct less_equal<void, void>
} }
}; };
/// FIXME: create macro to replace '__host__ __device__' and nothing more CK_TILE_HOST_DEVICE_EXTERN less_equal()->less_equal<void, void>;
__host__ __device__ less_equal()->less_equal<void, void>;
template <> template <>
struct less_equal<float, float> struct less_equal<float, float>
...@@ -488,19 +481,19 @@ template <typename T = double> ...@@ -488,19 +481,19 @@ template <typename T = double>
constexpr T log2e_v = log2e<T>::value; constexpr T log2e_v = log2e<T>::value;
// math // math
CK_TILE_HOST_DEVICE // CK_TILE_HOST_DEVICE
float abs(const float& x) // float abs(const float& x)
{ // {
union // union
{ // {
float f32; // float f32;
uint32_t u32; // uint32_t u32;
} y; // } y;
y.f32 = x; // y.f32 = x;
y.u32 = y.u32 & 0x7fffffff; // y.u32 = y.u32 & 0x7fffffff;
return y.f32; // return y.f32;
} // }
#if 0
CK_TILE_HOST_DEVICE CK_TILE_HOST_DEVICE
bool isnan(const float& x) bool isnan(const float& x)
{ {
...@@ -523,18 +516,20 @@ float exp(float x) { return __ocml_exp_f32(x); }; ...@@ -523,18 +516,20 @@ float exp(float x) { return __ocml_exp_f32(x); };
CK_TILE_HOST CK_TILE_HOST
float exp(float x) { return std::expf(x); } float exp(float x) { return std::expf(x); }
#endif
CK_TILE_DEVICE CK_TILE_DEVICE
float exp2(float x) { return exp2f(x); }; float exp2(float x) { return exp2f(x); };
CK_TILE_HOST CK_TILE_HOST
float exp2(float x) { return std::exp2f(x); }; float exp2(float x) { return std::exp2f(x); };
#if 0
CK_TILE_DEVICE CK_TILE_DEVICE
float log(float x) { return __logf(x); }; float log(float x) { return __logf(x); };
CK_TILE_HOST CK_TILE_HOST
float log(float x) { return std::logf(x); }; float log(float x) { return std::logf(x); };
#endif
CK_TILE_DEVICE uint32_t sad(uint32_t x, uint32_t y, uint32_t acc) CK_TILE_DEVICE uint32_t sad(uint32_t x, uint32_t y, uint32_t acc)
{ {
...@@ -547,4 +542,932 @@ CK_TILE_HOST uint32_t sad(uint32_t x, uint32_t y, uint32_t acc) ...@@ -547,4 +542,932 @@ CK_TILE_HOST uint32_t sad(uint32_t x, uint32_t y, uint32_t acc)
return (x > y ? (x - y) : (y - x)) + acc; return (x > y ? (x - y) : (y - x)) + acc;
} }
///////////////////////////////////////////////////////////////
} // namespace ck_tile
// blow function need data type pre-defined
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/type_convert.hpp"
#ifndef __HIP_DEVICE_COMPILE__
#include <cmath>
#endif
namespace ck_tile {
#if CK_TILE_WORKAROUND_SWDEV_383542
extern "C" CK_TILE_DEVICE float __ocml_native_recip_f32(float);
#endif
// math functions for the host, some are implemented by calling C++ std functions
CK_TILE_HOST float abs(float x) { return std::abs(x); };
CK_TILE_HOST double abs(double x) { return std::abs(x); };
CK_TILE_HOST int8_t abs(int8_t x)
{
int8_t sgn = x >> (8 - 1);
return (x ^ sgn) - sgn;
};
CK_TILE_HOST int32_t abs(int32_t x)
{
int32_t sgn = x >> (32 - 1);
return (x ^ sgn) - sgn;
};
CK_TILE_HOST fp16_t abs(fp16_t x)
{
uint16_t xx = bit_cast<uint16_t>(x);
uint16_t abs_xx = xx & 0x7fff;
fp16_t abs_x = bit_cast<fp16_t>(abs_xx);
return abs_x;
};
#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
CK_TILE_HOST int4_t abs(int4_t x)
{
int4_t sgn = x >> (4 - 1);
return (x ^ sgn) - sgn;
}
#endif
CK_TILE_HOST bool isnan(float x) { return std::isnan(x); };
CK_TILE_HOST bool isnan(double x) { return std::isnan(x); };
CK_TILE_HOST bool isnan(int8_t x)
{
(void)x;
return false;
};
CK_TILE_HOST bool isnan(int32_t x)
{
(void)x;
return false;
};
CK_TILE_HOST bool isnan(fp16_t x)
{
uint16_t xx = bit_cast<uint16_t>(x);
return (xx & 0x7FFF) > 0x7C00;
};
#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
CK_TILE_HOST bool isnan(int4_t x)
{
(void)x;
return false;
};
#endif
CK_TILE_HOST fp16_t sqrt(fp16_t x)
{
return static_cast<fp16_t>(std::sqrt(static_cast<float>(x)));
};
CK_TILE_HOST float sqrt(float x) { return std::sqrt(x); };
CK_TILE_HOST double sqrt(double x) { return std::sqrt(x); };
template <typename T>
CK_TILE_HOST T tanh(T x)
{
return type_convert<T>(std::tanhf(type_convert<float>(x)));
};
template <>
CK_TILE_HOST float tanh<float>(float x)
{
return std::tanhf(x);
};
template <>
CK_TILE_HOST double tanh<double>(double x)
{
return std::tanh(x);
};
template <typename T>
CK_TILE_HOST T acos(T x)
{
return type_convert<T>(std::acosf(type_convert<float>(x)));
};
template <>
CK_TILE_HOST float acos<float>(float x)
{
return std::acosf(x);
};
template <>
CK_TILE_HOST double acos<double>(double x)
{
return std::acos(x);
};
template <typename T>
CK_TILE_HOST T neg(T x)
{
return type_convert<T>(-(type_convert<float>(x)));
};
template <>
CK_TILE_HOST float neg<float>(float x)
{
return -x;
};
template <>
CK_TILE_HOST double neg<double>(double x)
{
return -x;
};
template <>
CK_TILE_HOST int32_t neg<int32_t>(int32_t x)
{
return -x;
};
template <>
CK_TILE_HOST int8_t neg<int8_t>(int8_t x)
{
return -x;
};
template <typename T>
CK_TILE_HOST T atan(T x)
{
return type_convert<T>(std::atanf(type_convert<float>(x)));
};
template <>
CK_TILE_HOST float atan<float>(float x)
{
return std::atanf(x);
};
template <>
CK_TILE_HOST double atan<double>(double x)
{
return std::atan(x);
};
template <typename T>
CK_TILE_HOST T sin(T x)
{
return type_convert<T>(std::sinf(type_convert<float>(x)));
};
template <>
CK_TILE_HOST float sin<float>(float x)
{
return std::sinf(x);
};
template <>
CK_TILE_HOST double sin<double>(double x)
{
return std::sin(x);
};
template <typename T>
CK_TILE_HOST T asin(T x)
{
return type_convert<T>(std::asinf(type_convert<float>(x)));
};
template <>
CK_TILE_HOST float asin<float>(float x)
{
return std::asinf(x);
};
template <>
CK_TILE_HOST double asin<double>(double x)
{
return std::asin(x);
};
template <typename T>
CK_TILE_HOST T asinh(T x)
{
return type_convert<T>(std::asinhf(type_convert<float>(x)));
};
template <>
CK_TILE_HOST float asinh<float>(float x)
{
return std::asinhf(x);
};
template <>
CK_TILE_HOST double asinh<double>(double x)
{
return std::asinh(x);
};
template <typename T>
CK_TILE_HOST T cos(T x)
{
return type_convert<T>(std::cosf(type_convert<float>(x)));
};
template <>
CK_TILE_HOST float cos<float>(float x)
{
return std::cosf(x);
};
template <>
CK_TILE_HOST double cos<double>(double x)
{
return std::cos(x);
};
template <typename T>
CK_TILE_HOST T acosh(T x)
{
return type_convert<T>(std::acoshf(type_convert<float>(x)));
};
template <>
CK_TILE_HOST float acosh<float>(float x)
{
return std::acoshf(x);
};
template <>
CK_TILE_HOST double acosh<double>(double x)
{
return std::acosh(x);
};
template <typename T>
CK_TILE_HOST T tan(T x)
{
return type_convert<T>(std::tanf(type_convert<float>(x)));
};
template <>
CK_TILE_HOST float tan<float>(float x)
{
return std::tanf(x);
};
template <>
CK_TILE_HOST double tan<double>(double x)
{
return std::tan(x);
};
template <typename T>
CK_TILE_HOST T atanh(T x)
{
return type_convert<T>(std::atanhf(type_convert<float>(x)));
};
template <>
CK_TILE_HOST float atanh<float>(float x)
{
return std::atanhf(x);
};
template <>
CK_TILE_HOST double atanh<double>(double x)
{
return std::atanh(x);
};
template <typename T>
CK_TILE_HOST T sinh(T x)
{
return type_convert<T>(std::sinhf(type_convert<float>(x)));
};
template <>
CK_TILE_HOST float sinh<float>(float x)
{
return std::sinhf(x);
};
template <>
CK_TILE_HOST double sinh<double>(double x)
{
return std::sinh(x);
};
template <typename T>
CK_TILE_HOST T ceil(T x)
{
return type_convert<T>(std::ceilf(type_convert<float>(x)));
};
template <>
CK_TILE_HOST float ceil<float>(float x)
{
return std::ceilf(x);
};
template <>
CK_TILE_HOST double ceil<double>(double x)
{
return std::ceil(x);
};
template <typename T>
CK_TILE_HOST T cosh(T x)
{
return type_convert<T>(std::coshf(type_convert<float>(x)));
};
template <>
CK_TILE_HOST float cosh<float>(float x)
{
return std::coshf(x);
};
template <>
CK_TILE_HOST double cosh<double>(double x)
{
return std::cosh(x);
};
template <typename T>
CK_TILE_HOST T floor(T x)
{
return type_convert<T>(std::floorf(type_convert<float>(x)));
};
template <>
CK_TILE_HOST float floor<float>(float x)
{
return std::floorf(x);
};
template <>
CK_TILE_HOST double floor<double>(double x)
{
return std::floor(x);
};
template <typename T>
CK_TILE_HOST T rcp(T x)
{
return type_convert<T>(1.f / type_convert<float>(x));
};
template <typename T>
CK_TILE_HOST T exp(T x)
{
return type_convert<T>(std::expf(type_convert<float>(x)));
}
template <>
CK_TILE_HOST float exp<float>(float x)
{
return std::expf(x);
}
template <>
CK_TILE_HOST double exp<double>(double x)
{
return std::exp(x);
}
template <typename T>
CK_TILE_HOST T log(T x)
{
return type_convert<T>(std::logf(type_convert<float>(x)));
}
template <>
CK_TILE_HOST float log<float>(float x)
{
return std::logf(x);
}
template <>
CK_TILE_HOST double log<double>(double x)
{
return std::log(x);
}
template <typename T>
CK_TILE_HOST T pow(T x, T gamma)
{
return type_convert<T>(std::powf(type_convert<float>(x), type_convert<float>(gamma)));
}
template <>
CK_TILE_HOST float pow<float>(float x, float gamma)
{
return std::powf(x, gamma);
}
template <>
CK_TILE_HOST double pow<double>(double x, double gamma)
{
return std::pow(x, gamma);
}
template <typename T>
CK_TILE_HOST T expm1(T x)
{
return type_convert<T>(std::expm1f(type_convert<float>(x)));
}
template <>
CK_TILE_HOST float expm1<float>(float x)
{
return std::expm1f(x);
}
template <>
CK_TILE_HOST double expm1<double>(double x)
{
return std::expm1(x);
}
// math functions for the HIP kernel, some are implemented by calling hip builtin functions
CK_TILE_DEVICE float abs(float x)
{
union
{
float f32;
uint32_t u32;
} y;
y.f32 = x;
y.u32 = y.u32 & 0x7fffffff;
return y.f32;
};
CK_TILE_DEVICE double abs(double x) { return ::abs(x); };
CK_TILE_DEVICE int8_t abs(int8_t x)
{
int8_t sgn = x >> (8 - 1);
return (x ^ sgn) - sgn;
};
CK_TILE_DEVICE int32_t abs(int32_t x)
{
int32_t sgn = x >> (32 - 1);
return (x ^ sgn) - sgn;
};
#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
CK_TILE_DEVICE int4_t abs(int4_t x)
{
int4_t sgn = x >> (4 - 1);
return (x ^ sgn) - sgn;
};
#endif
CK_TILE_DEVICE fp16_t abs(fp16_t x)
{
uint16_t xx = bit_cast<uint16_t>(x);
uint16_t abs_xx = xx & 0x7fff;
fp16_t abs_x = bit_cast<fp16_t>(abs_xx);
return abs_x;
};
CK_TILE_DEVICE bool isnan(float x) { return ::isnan(x); };
CK_TILE_DEVICE bool isnan(double x) { return ::isnan(x); };
CK_TILE_DEVICE bool isnan(int8_t x)
{
(void)x;
return false;
};
CK_TILE_DEVICE bool isnan(int32_t x)
{
(void)x;
return false;
};
#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
CK_TILE_DEVICE bool isnan(int4_t x)
{
(void)x;
return false;
};
#endif
CK_TILE_DEVICE bool isnan(fp16_t x)
{
uint16_t xx = bit_cast<uint16_t>(x);
return (xx & 0x7FFF) > 0x7C00;
};
CK_TILE_DEVICE fp16_t sqrt(fp16_t x)
{
return static_cast<fp16_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x)));
};
CK_TILE_DEVICE float sqrt(float x) { return __builtin_amdgcn_sqrtf(x); };
CK_TILE_DEVICE double sqrt(double x) { return __builtin_amdgcn_sqrt(x); };
template <typename T>
CK_TILE_DEVICE T tanh(T x)
{
return type_convert<T>(::tanhf(type_convert<float>(x)));
};
template <>
CK_TILE_DEVICE float tanh<float>(float x)
{
return ::tanhf(x);
};
template <>
CK_TILE_DEVICE double tanh<double>(double x)
{
return ::tanh(x);
};
template <typename T>
CK_TILE_DEVICE T acos(T x)
{
return type_convert<T>(::acosf(type_convert<float>(x)));
};
template <>
CK_TILE_DEVICE float acos<float>(float x)
{
return ::acosf(x);
};
template <>
CK_TILE_DEVICE double acos<double>(double x)
{
return ::acos(x);
};
template <typename T>
CK_TILE_DEVICE T neg(T x)
{
return type_convert<T>(-(type_convert<float>(x)));
};
template <>
CK_TILE_DEVICE float neg<float>(float x)
{
return -x;
};
template <>
CK_TILE_DEVICE double neg<double>(double x)
{
return -x;
};
template <>
CK_TILE_DEVICE int32_t neg<int32_t>(int32_t x)
{
return -x;
};
template <>
CK_TILE_DEVICE int8_t neg<int8_t>(int8_t x)
{
return -x;
};
template <>
CK_TILE_DEVICE fp16_t neg<fp16_t>(fp16_t x)
{
return __hneg(x);
};
template <typename T>
CK_TILE_DEVICE T atan(T x)
{
return type_convert<T>(::atanf(type_convert<float>(x)));
};
template <>
CK_TILE_DEVICE float atan<float>(float x)
{
return ::atanf(x);
};
template <>
CK_TILE_DEVICE double atan<double>(double x)
{
return ::atan(x);
};
template <typename T>
CK_TILE_DEVICE T sin(T x)
{
return type_convert<T>(::sinf(type_convert<float>(x)));
};
template <>
CK_TILE_DEVICE float sin<float>(float x)
{
return ::sinf(x);
};
template <>
CK_TILE_DEVICE double sin<double>(double x)
{
return ::sin(x);
};
template <>
CK_TILE_DEVICE fp16_t sin<fp16_t>(fp16_t x)
{
return ::hsin(x);
};
template <typename T>
CK_TILE_DEVICE T asin(T x)
{
return type_convert<T>(::asinf(type_convert<float>(x)));
};
template <>
CK_TILE_DEVICE float asin<float>(float x)
{
return ::asinf(x);
};
template <>
CK_TILE_DEVICE double asin<double>(double x)
{
return ::asin(x);
};
template <typename T>
CK_TILE_DEVICE T asinh(T x)
{
return type_convert<T>(::asinhf(type_convert<float>(x)));
};
template <>
CK_TILE_DEVICE float asinh<float>(float x)
{
return ::asinhf(x);
};
template <>
CK_TILE_DEVICE double asinh<double>(double x)
{
return ::asinh(x);
};
template <typename T>
CK_TILE_DEVICE T acosh(T x)
{
return type_convert<T>(::acoshf(type_convert<float>(x)));
};
template <>
CK_TILE_DEVICE float acosh<float>(float x)
{
return ::acoshf(x);
};
template <>
CK_TILE_DEVICE double acosh<double>(double x)
{
return ::acosh(x);
};
template <typename T>
CK_TILE_DEVICE T tan(T x)
{
return type_convert<T>(::tanf(type_convert<float>(x)));
};
template <>
CK_TILE_DEVICE float tan<float>(float x)
{
return ::tanf(x);
};
template <>
CK_TILE_DEVICE double tan<double>(double x)
{
return ::tan(x);
};
template <typename T>
CK_TILE_DEVICE T atanh(T x)
{
return type_convert<T>(::atanhf(type_convert<float>(x)));
};
template <>
CK_TILE_DEVICE float atanh<float>(float x)
{
return ::atanhf(x);
};
template <>
CK_TILE_DEVICE double atanh<double>(double x)
{
return ::atanh(x);
};
template <typename T>
CK_TILE_DEVICE T sinh(T x)
{
return type_convert<T>(::sinhf(type_convert<float>(x)));
};
template <>
CK_TILE_DEVICE float sinh<float>(float x)
{
return ::sinhf(x);
};
template <>
CK_TILE_DEVICE double sinh<double>(double x)
{
return ::sinh(x);
};
template <typename T>
CK_TILE_DEVICE T ceil(T x)
{
return type_convert<T>(::ceilf(type_convert<float>(x)));
};
template <>
CK_TILE_DEVICE float ceil<float>(float x)
{
return ::ceilf(x);
};
template <>
CK_TILE_DEVICE double ceil<double>(double x)
{
return ::ceil(x);
};
template <>
CK_TILE_DEVICE fp16_t ceil<fp16_t>(fp16_t x)
{
return ::hceil(x);
};
template <typename T>
CK_TILE_DEVICE T cosh(T x)
{
return type_convert<T>(::coshf(type_convert<float>(x)));
};
template <>
CK_TILE_DEVICE float cosh<float>(float x)
{
return ::coshf(x);
};
template <>
CK_TILE_DEVICE double cosh<double>(double x)
{
return ::cosh(x);
};
template <typename T>
CK_TILE_DEVICE T floor(T x)
{
return type_convert<T>(::floorf(type_convert<float>(x)));
};
template <>
CK_TILE_DEVICE float floor<float>(float x)
{
return ::floorf(x);
};
template <>
CK_TILE_DEVICE double floor<double>(double x)
{
return ::floor(x);
};
template <>
CK_TILE_DEVICE fp16_t floor<fp16_t>(fp16_t x)
{
return ::hfloor(x);
};
template <typename T>
CK_TILE_DEVICE T rcp(T x)
{
#if !CK_TILE_WORKAROUND_SWDEV_383542
return __frcp_rn(x);
#else
return __ocml_native_recip_f32(x);
#endif
};
template <typename T>
CK_TILE_DEVICE T exp(T x)
{
return type_convert<T>(__ocml_exp_f32(type_convert<float>(x)));
};
template <>
CK_TILE_DEVICE fp16_t exp<fp16_t>(fp16_t x)
{
return hexp(x);
};
template <>
CK_TILE_DEVICE float exp<float>(float x)
{
return __ocml_exp_f32(x);
};
template <>
CK_TILE_DEVICE double exp<double>(double x)
{
return exp(x);
};
template <typename T>
CK_TILE_DEVICE T log(T x)
{
return type_convert<T>(__logf(type_convert<float>(x)));
};
template <>
CK_TILE_DEVICE fp16_t log<fp16_t>(fp16_t x)
{
return hlog(x);
};
template <>
CK_TILE_DEVICE float log<float>(float x)
{
return __logf(x);
};
template <>
CK_TILE_DEVICE double log<double>(double x)
{
return log(x);
};
template <typename T>
CK_TILE_DEVICE T pow(T x, T gamma)
{
return type_convert<T>(powf(type_convert<float>(x), type_convert<float>(gamma)));
};
template <>
CK_TILE_DEVICE float pow<float>(float x, float gamma)
{
return powf(x, gamma);
};
template <>
CK_TILE_DEVICE double pow<double>(double x, double gamma)
{
return pow(x, gamma);
};
template <typename T>
CK_TILE_DEVICE T expm1(T x)
{
return type_convert<T>(expm1f(type_convert<float>(x)));
};
template <>
CK_TILE_DEVICE float expm1<float>(float x)
{
return expm1f(x);
};
template <>
CK_TILE_DEVICE double expm1<double>(double x)
{
return expm1(x);
};
} // namespace ck_tile } // namespace ck_tile
...@@ -17,6 +17,14 @@ ...@@ -17,6 +17,14 @@
namespace ck_tile { namespace ck_tile {
namespace detail {
template <typename Distribution>
CK_TILE_HOST_DEVICE auto get_partition_index(Distribution)
{
return Distribution::_get_partition_index();
}
} // namespace detail
// distributed span // distributed span
template <index_t... PartialHsLengths> template <index_t... PartialHsLengths>
struct tile_distributed_span struct tile_distributed_span
...@@ -83,6 +91,21 @@ struct tile_distribution ...@@ -83,6 +91,21 @@ struct tile_distribution
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_p() { return NDimP; } CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_p() { return NDimP; }
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_r() { return NDimR; } CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_r() { return NDimR; }
CK_TILE_HOST_DEVICE static auto _get_partition_index()
{
// only support warp-tile and block-tile
static_assert(NDimP == 1 or NDimP == 2, "wrong!");
if constexpr(NDimP == 1)
{
return array<index_t, 1>{get_lane_id()};
}
else if constexpr(NDimP == 2)
{
return array<index_t, 2>{get_warp_id(), get_lane_id()};
}
}
CK_TILE_HOST_DEVICE static constexpr auto get_lengths() CK_TILE_HOST_DEVICE static constexpr auto get_lengths()
{ {
#if 0 #if 0
...@@ -149,6 +172,16 @@ struct tile_distribution ...@@ -149,6 +172,16 @@ struct tile_distribution
} }
#endif #endif
template <typename PartitionIndex = decltype(_get_partition_index())>
CK_TILE_HOST_DEVICE auto
calculate_index(const PartitionIndex& ps_idx = _get_partition_index()) const
{
const auto ps_ys_idx = container_concat(ps_idx, array<index_t, NDimY>{0});
const auto window_adaptor_thread_coord_tmp =
make_tensor_adaptor_coordinate(ps_ys_to_xs_, ps_ys_idx);
return window_adaptor_thread_coord_tmp.get_bottom_index();
}
CK_TILE_HOST_DEVICE static constexpr auto get_distributed_spans() CK_TILE_HOST_DEVICE static constexpr auto get_distributed_spans()
{ {
constexpr auto distributed_spans_impl = DstrEncode::detail::distributed_spans_lengthss_; constexpr auto distributed_spans_impl = DstrEncode::detail::distributed_spans_lengthss_;
...@@ -500,22 +533,6 @@ CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistr ...@@ -500,22 +533,6 @@ CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistr
namespace detail { namespace detail {
template <typename Distribution>
CK_TILE_HOST_DEVICE auto get_partition_index(Distribution)
{
// only support warp-tile and block-tile
static_assert(Distribution::NDimP == 1 or Distribution::NDimP == 2, "wrong!");
if constexpr(Distribution::NDimP == 1)
{
return array<index_t, 1>{get_lane_id()};
}
else if constexpr(Distribution::NDimP == 2)
{
return array<index_t, 2>{get_warp_id(), get_lane_id()};
}
}
template <typename, typename, typename, index_t> template <typename, typename, typename, index_t>
struct reverse_slice_sequence_impl; struct reverse_slice_sequence_impl;
......
...@@ -41,6 +41,7 @@ struct tile_window_with_static_distribution ...@@ -41,6 +41,7 @@ struct tile_window_with_static_distribution
static constexpr auto I0 = number<0>{}; static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{}; static constexpr auto I1 = number<1>{};
static_assert(NumCoord == 1);
// TODO: check WindowLengths and StaticTileDistribution are consistent // TODO: check WindowLengths and StaticTileDistribution are consistent
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
namespace element_wise {
#if 0
struct PassThroughPack2
{
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const;
CK_TILE_HOST_DEVICE constexpr void operator()(ck_tile::half2_t& y, const ck_tile::f8x2_t& x) const
{
auto t = type_convert<float2_t>(x);
y = type_convert<half2_t>(t);
}
constexpr const static bool is_pack2_invocable = true;
};
#endif
struct PassThrough
{
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const;
template <>
CK_TILE_HOST_DEVICE void operator()<double, double>(double& y, const double& x) const
{
y = x;
}
template <>
CK_TILE_HOST_DEVICE void operator()<float, double>(float& y, const double& x) const
{
y = type_convert<float>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<double, float>(double& y, const float& x) const
{
y = type_convert<double>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<float, float>(float& y, const float& x) const
{
y = x;
}
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::fp16_t, ck_tile::fp16_t>(ck_tile::fp16_t& y, const ck_tile::fp16_t& x) const
{
y = x;
}
template <>
CK_TILE_HOST_DEVICE void operator()<ck_tile::fp16_t, float>(ck_tile::fp16_t& y,
const float& x) const
{
y = type_convert<ck_tile::fp16_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::bf16_t, ck_tile::bf16_t>(ck_tile::bf16_t& y, const ck_tile::bf16_t& x) const
{
y = x;
}
template <>
CK_TILE_HOST_DEVICE void operator()<int32_t, int32_t>(int32_t& y, const int32_t& x) const
{
y = x;
}
template <>
CK_TILE_HOST_DEVICE void operator()<ck_tile::bf16_t, float>(ck_tile::bf16_t& y,
const float& x) const
{
y = type_convert<ck_tile::bf16_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<float, ck_tile::bf16_t>(float& y,
const ck_tile::bf16_t& x) const
{
y = type_convert<float>(x);
}
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::bf16_t, ck_tile::fp16_t>(ck_tile::bf16_t& y, const ck_tile::fp16_t& x) const
{
y = type_convert<ck_tile::bf16_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<float, ck_tile::fp16_t>(float& y,
const ck_tile::fp16_t& x) const
{
y = type_convert<float>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const
{
y = x;
}
template <>
CK_TILE_HOST_DEVICE void operator()<ck_tile::fp16_t, int8_t>(ck_tile::fp16_t& y,
const int8_t& x) const
{
y = type_convert<ck_tile::fp16_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<ck_tile::bf16_t, int8_t>(ck_tile::bf16_t& y,
const int8_t& x) const
{
y = type_convert<ck_tile::bf16_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<uint8_t, uint8_t>(uint8_t& y, const uint8_t& x) const
{
y = x;
}
template <>
CK_TILE_HOST_DEVICE void operator()<int8_t, int32_t>(int8_t& y, const int32_t& x) const
{
y = type_convert<int8_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<int32_t, int8_t>(int32_t& y, const int8_t& x) const
{
y = type_convert<int32_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<int8_t, float>(int8_t& y, const float& x) const
{
y = type_convert<int8_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<float, int8_t>(float& y, const int8_t& x) const
{
y = type_convert<float>(x);
}
#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template <>
CK_TILE_HOST_DEVICE void operator()<int4_t, int4_t>(int4_t& y, const int4_t& x) const
{
y = x;
}
template <>
CK_TILE_HOST_DEVICE void operator()<int4_t, int>(int4_t& y, const int& x) const
{
y = type_convert<int4_t>(x);
}
#endif
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::fp8_t, ck_tile::fp8_t>(ck_tile::fp8_t& y, const ck_tile::fp8_t& x) const
{
y = x;
}
template <>
CK_TILE_HOST_DEVICE void operator()<float, ck_tile::fp8_t>(float& y,
const ck_tile::fp8_t& x) const
{
y = type_convert<float>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<ck_tile::fp8_t, float>(ck_tile::fp8_t& y,
const float& x) const
{
y = type_convert<ck_tile::fp8_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::fp16_t, ck_tile::fp8_t>(ck_tile::fp16_t& y, const ck_tile::fp8_t& x) const
{
y = type_convert<ck_tile::fp16_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::fp8_t, ck_tile::fp16_t>(ck_tile::fp8_t& y, const ck_tile::fp16_t& x) const
{
y = type_convert<ck_tile::fp8_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::bf8_t, ck_tile::bf8_t>(ck_tile::bf8_t& y, const ck_tile::bf8_t& x) const
{
y = x;
}
template <>
CK_TILE_HOST_DEVICE void operator()<float, ck_tile::bf8_t>(float& y,
const ck_tile::bf8_t& x) const
{
y = type_convert<float>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<ck_tile::bf8_t, float>(ck_tile::bf8_t& y,
const float& x) const
{
y = type_convert<ck_tile::bf8_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::fp16_t, ck_tile::bf8_t>(ck_tile::fp16_t& y, const ck_tile::bf8_t& x) const
{
y = type_convert<ck_tile::fp16_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::bf8_t, ck_tile::fp16_t>(ck_tile::bf8_t& y, const ck_tile::fp16_t& x) const
{
y = ck_tile::type_convert<ck_tile::bf8_t>(x);
}
};
#if 0
struct UnaryConvert
{
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
{
y = type_convert<Y>(x);
}
};
struct ConvertBF16RTN
{
// convert to bf16 using round to nearest (rtn)
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
{
// check Y datatype
static_assert(ck_tile::is_same<Y, ck_tile::bf16_t>::value, "Data type is not supported by this operation!");
// check X datatype
static_assert(ck_tile::is_same<X, float>::value || ck_tile::is_same<X, ck_tile::fp16_t>::value,
"Data type is not supported by this operation!");
y = bf16_convert_rtn<Y>(x);
}
};
struct ConvertF8SR
{
// convert to fp8 using stochastic rounding (SR)
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
{
// check Y datatype
static_assert(ck_tile::is_same<Y, ck_tile::fp8_t>::value || ck_tile::is_same<Y, ck_tile::bf8_t>::value,
"Data type is not supported by this operation!");
// check X datatype
static_assert(ck_tile::is_same<X, float>::value || ck_tile::is_same<X, ck_tile::fp16_t>::value,
"Data type is not supported by this operation!");
y = f8_convert_sr<Y>(x);
}
};
struct ConvertF8RNE
{
// convert to fp8 using rounding to nearest even
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
{
// check Y datatype
static_assert(ck_tile::is_same<Y, ck_tile::fp8_t>::value || ck_tile::is_same<Y, ck_tile::bf8_t>::value,
"Data type is not supported by this operation!");
// check X datatype
static_assert(ck_tile::is_same<X, float>::value || ck_tile::is_same<X, ck_tile::fp16_t>::value,
"Data type is not supported by this operation!");
y = f8_convert_rne<Y>(x);
}
};
#endif
struct Scale
{
CK_TILE_HOST_DEVICE Scale(float scale = 1.f) : scale_(scale) {}
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
{
y = ck_tile::type_convert<Y>(ck_tile::type_convert<float>(x) * scale_);
}
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::fp16_t, ck_tile::fp16_t>(ck_tile::fp16_t& y, const ck_tile::fp16_t& x) const
{
y = ck_tile::type_convert<ck_tile::fp16_t>(scale_) * x;
};
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::bf16_t, ck_tile::bf16_t>(ck_tile::bf16_t& y, const ck_tile::bf16_t& x) const
{
const float x_tmp = ck_tile::type_convert<float>(x);
const float y_tmp = scale_ * x_tmp;
y = ck_tile::type_convert<ck_tile::bf16_t>(y_tmp);
};
template <>
CK_TILE_HOST_DEVICE void operator()<float, float>(float& y, const float& x) const
{
y = scale_ * x;
};
template <>
CK_TILE_HOST_DEVICE void operator()<double, double>(double& y, const double& x) const
{
y = scale_ * x;
};
template <>
CK_TILE_HOST_DEVICE void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const
{
y = ck_tile::type_convert<int8_t>(scale_ * ck_tile::type_convert<float>(x));
};
float scale_;
};
struct ScaleAndResetNaNToMinusInfinity
{
CK_TILE_HOST_DEVICE ScaleAndResetNaNToMinusInfinity(float scale) : scale_(scale) {}
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const;
template <>
CK_TILE_HOST_DEVICE void operator()<float, float>(float& y, const float& x) const
{
y = ck_tile::isnan(x) ? -ck_tile::NumericLimits<float>::Infinity() : scale_ * x;
};
float scale_;
};
struct UnaryDivide
{
CK_TILE_HOST_DEVICE UnaryDivide(const int32_t divider = 1) : divider_(divider) {}
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value ||
ck_tile::is_same<T, int32_t>::value,
"Data type is not supported by this operation!");
y = x / type_convert<T>(divider_);
};
int32_t divider_ = 1;
};
struct UnarySquare
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(is_same_v<T, float> || is_same_v<T, ck_tile::fp16_t> ||
is_same_v<T, double> || is_same_v<T, int32_t> || is_same_v<T, int8_t>
#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|| is_same_v<T, int4_t>
#endif
,
"Data type is not supported by this operation!");
y = x * x;
};
};
struct UnaryAbs
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value ||
ck_tile::is_same<T, ck_tile::fp16_t>::value ||
ck_tile::is_same<T, int32_t>::value || ck_tile::is_same<T, int8_t>::value,
"Data type is not supported by this operation!");
y = ck_tile::abs(x);
};
};
struct UnarySqrt
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value,
"Data type is not supported by this operation!");
y = ck_tile::sqrt(x);
};
};
struct Relu
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value ||
ck_tile::is_same<T, ck_tile::fp16_t>::value ||
ck_tile::is_same<T, int32_t>::value || ck_tile::is_same<T, int8_t>::value,
"Data type is not supported by this operation!");
y = x > 0 ? x : 0;
}
template <>
CK_TILE_HOST_DEVICE void operator()(ck_tile::bf16_t& y, const ck_tile::bf16_t& x) const
{
float x_f32 = ck_tile::type_convert<float>(x);
float y_f32 = x_f32 > 0 ? x_f32 : 0;
y = ck_tile::type_convert<ck_tile::bf16_t>(y_f32);
}
};
// Fast GeLU
// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
// host code use higher accuracy "exp" and "div"
// gpu code use lower accuracy "_ocml_exp_f32" and "rcp" function
struct FastGelu
{
template <typename Y, typename X>
CK_TILE_HOST void operator()(Y& y, const X& x) const;
template <typename Y, typename X>
CK_TILE_DEVICE void operator()(Y& y, const X& x) const;
template <>
CK_TILE_HOST void operator()<float, float>(float& y, const float& x) const
{
// const float u = -2.f * x * (0.035677f * x * x + 0.797885f);
const float c1 = -2.0 * 0.035677f;
const float c2 = -2.0 * 0.797885f;
const float u = x * (c1 * x * x + c2);
const float emu = exp(u);
y = x / (1.f + emu);
}
// device code, use lower precision "__ocml_exp_f32" and "rcp"
template <>
CK_TILE_DEVICE void operator()<float, float>(float& y, const float& x) const
{
// const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
const float c1 = -2.0 * 0.035677f;
const float c2 = -2.0 * 0.797885f;
const float u = x * (c1 * x * x + c2);
const float emu = __ocml_exp_f32(u);
y = x * ck_tile::rcp(1.f + emu);
}
template <>
CK_TILE_HOST void operator()<ck_tile::fp16_t, ck_tile::fp16_t>(ck_tile::fp16_t& y,
const ck_tile::fp16_t& x) const
{
float y_f;
this->operator()<float, float>(y_f, type_convert<float>(x));
y = type_convert<ck_tile::fp16_t>(y_f);
}
template <>
CK_TILE_DEVICE void operator()<ck_tile::fp16_t, ck_tile::fp16_t>(ck_tile::fp16_t& y,
const ck_tile::fp16_t& x) const
{
float y_f;
this->operator()<float, float>(y_f, type_convert<float>(x));
y = type_convert<ck_tile::fp16_t>(y_f);
}
template <>
CK_TILE_HOST void operator()<ck_tile::fp16_t, float>(ck_tile::fp16_t& y, const float& x) const
{
float y_f;
this->operator()<float, float>(y_f, x);
y = type_convert<ck_tile::fp16_t>(y_f);
}
template <>
CK_TILE_DEVICE void operator()<ck_tile::fp16_t, float>(ck_tile::fp16_t& y, const float& x) const
{
float y_f;
this->operator()<float, float>(y_f, x);
y = type_convert<ck_tile::fp16_t>(y_f);
}
template <>
CK_TILE_HOST void operator()<ck_tile::bf16_t, float>(ck_tile::bf16_t& y, const float& x) const
{
float y_f;
this->operator()<float, float>(y_f, x);
y = type_convert<ck_tile::bf16_t>(y_f);
}
template <>
CK_TILE_DEVICE void operator()<ck_tile::bf16_t, float>(ck_tile::bf16_t& y, const float& x) const
{
float y_f;
this->operator()<float, float>(y_f, x);
y = type_convert<ck_tile::bf16_t>(y_f);
}
template <>
CK_TILE_DEVICE void operator()<ck_tile::bf16_t, ck_tile::bf16_t>(ck_tile::bf16_t& y,
const ck_tile::bf16_t& x) const
{
float y_f;
this->operator()<float, float>(y_f, type_convert<float>(x));
y = type_convert<ck_tile::bf16_t>(y_f);
}
template <>
CK_TILE_HOST void operator()<ck_tile::bf16_t, ck_tile::bf16_t>(ck_tile::bf16_t& y,
const ck_tile::bf16_t& x) const
{
float y_f;
this->operator()<float, float>(y_f, type_convert<float>(x));
y = type_convert<ck_tile::bf16_t>(y_f);
}
};
// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+erf(x/sqrt(2)))
struct Gelu
{
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const;
template <>
CK_TILE_HOST_DEVICE void operator()<float, float>(float& y, const float& x) const
{
y = 0.5f * x * (1.f + erf(float(0.70710678118f * x)));
}
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::fp16_t, ck_tile::fp16_t>(ck_tile::fp16_t& y, const ck_tile::fp16_t& x) const
{
y = ck_tile::fp16_t(0.5) * x *
(ck_tile::fp16_t(1) + ck_tile::fp16_t(erf(float(0.70710678118f * x))));
}
};
struct Sigmoid
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value ||
ck_tile::is_same<T, ck_tile::fp16_t>::value ||
ck_tile::is_same<T, int8_t>::value || ck_tile::is_same<T, int32_t>::value,
"Data type is not supported by this operation!");
constexpr T one = type_convert<T>(1);
y = one / (one + ck_tile::exp(-x));
};
};
struct Silu
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(is_same_v<T, float> || is_same_v<T, double> ||
is_same_v<T, ck_tile::fp16_t> || is_same_v<T, int8_t> ||
is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
constexpr T one = type_convert<T>(1);
y = x * (one / (one + ck_tile::exp(-x)));
};
};
struct TanH
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value ||
ck_tile::is_same<T, ck_tile::fp16_t>::value ||
ck_tile::is_same<T, int8_t>::value || ck_tile::is_same<T, int32_t>::value,
"Data type is not supported by this operation!");
y = ck_tile::tanh(x);
};
};
struct ACos
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value ||
ck_tile::is_same<T, ck_tile::fp16_t>::value ||
ck_tile::is_same<T, int8_t>::value || ck_tile::is_same<T, int32_t>::value,
"Data type is not supported by this operation!");
y = ck_tile::acos(x);
};
};
struct Neg
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value ||
ck_tile::is_same<T, ck_tile::fp16_t>::value ||
ck_tile::is_same<T, int8_t>::value || ck_tile::is_same<T, int32_t>::value,
"Data type is not supported by this operation!");
y = ck_tile::neg(x);
};
};
struct ATan
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value ||
ck_tile::is_same<T, ck_tile::fp16_t>::value ||
ck_tile::is_same<T, int8_t>::value || ck_tile::is_same<T, int32_t>::value,
"Data type is not supported by this operation!");
y = ck_tile::atan(x);
};
};
struct Sin
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value ||
ck_tile::is_same<T, ck_tile::fp16_t>::value ||
ck_tile::is_same<T, int8_t>::value || ck_tile::is_same<T, int32_t>::value,
"Data type is not supported by this operation!");
y = ck_tile::sin(x);
};
};
struct ASinH
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value ||
ck_tile::is_same<T, ck_tile::fp16_t>::value ||
ck_tile::is_same<T, int8_t>::value || ck_tile::is_same<T, int32_t>::value,
"Data type is not supported by this operation!");
y = ck_tile::asinh(x);
};
};
struct Cos
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value ||
ck_tile::is_same<T, ck_tile::fp16_t>::value ||
ck_tile::is_same<T, int8_t>::value || ck_tile::is_same<T, int32_t>::value,
"Data type is not supported by this operation!");
y = ck_tile::cos(x);
};
};
struct ACosH
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value ||
ck_tile::is_same<T, ck_tile::fp16_t>::value ||
ck_tile::is_same<T, int8_t>::value || ck_tile::is_same<T, int32_t>::value,
"Data type is not supported by this operation!");
y = ck_tile::acosh(x);
};
};
struct Tan
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value ||
ck_tile::is_same<T, ck_tile::fp16_t>::value ||
ck_tile::is_same<T, int8_t>::value || ck_tile::is_same<T, int32_t>::value,
"Data type is not supported by this operation!");
y = ck_tile::tan(x);
};
};
struct ATanH
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value ||
ck_tile::is_same<T, ck_tile::fp16_t>::value ||
ck_tile::is_same<T, int8_t>::value || ck_tile::is_same<T, int32_t>::value,
"Data type is not supported by this operation!");
y = ck_tile::atanh(x);
};
};
struct SinH
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value ||
ck_tile::is_same<T, ck_tile::fp16_t>::value ||
ck_tile::is_same<T, int8_t>::value || ck_tile::is_same<T, int32_t>::value,
"Data type is not supported by this operation!");
y = ck_tile::sinh(x);
};
};
struct Ceil
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value ||
ck_tile::is_same<T, ck_tile::fp16_t>::value ||
ck_tile::is_same<T, int8_t>::value || ck_tile::is_same<T, int32_t>::value,
"Data type is not supported by this operation!");
y = ck_tile::ceil(x);
};
};
struct Exp
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value ||
ck_tile::is_same<T, ck_tile::fp16_t>::value ||
ck_tile::is_same<T, int8_t>::value || ck_tile::is_same<T, int32_t>::value,
"Data type is not supported by this operation!");
y = ck_tile::exp(x);
};
};
struct CosH
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value ||
ck_tile::is_same<T, ck_tile::fp16_t>::value ||
ck_tile::is_same<T, int8_t>::value || ck_tile::is_same<T, int32_t>::value,
"Data type is not supported by this operation!");
y = ck_tile::cosh(x);
};
};
struct Floor
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value ||
ck_tile::is_same<T, ck_tile::fp16_t>::value ||
ck_tile::is_same<T, int8_t>::value || ck_tile::is_same<T, int32_t>::value,
"Data type is not supported by this operation!");
y = ck_tile::floor(x);
};
};
struct Log
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value ||
ck_tile::is_same<T, ck_tile::fp16_t>::value ||
ck_tile::is_same<T, int8_t>::value || ck_tile::is_same<T, int32_t>::value,
"Data type is not supported by this operation!");
y = ck_tile::log(x);
};
};
struct ASin
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value ||
ck_tile::is_same<T, ck_tile::fp16_t>::value ||
ck_tile::is_same<T, int8_t>::value || ck_tile::is_same<T, int32_t>::value,
"Data type is not supported by this operation!");
y = ck_tile::asin(x);
};
};
struct Rcp
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value ||
ck_tile::is_same<T, ck_tile::fp16_t>::value ||
ck_tile::is_same<T, int8_t>::value || ck_tile::is_same<T, int32_t>::value,
"Data type is not supported by this operation!");
y = ck_tile::rcp(x);
};
};
struct Swish
{
Swish(float beta = 1.0f) : beta_(beta) {}
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
{
static_assert(ck_tile::is_same<X, float>::value || ck_tile::is_same<X, double>::value ||
ck_tile::is_same<X, ck_tile::fp16_t>::value,
"Data type is not supported by this operation!");
static_assert(ck_tile::is_same<Y, float>::value || ck_tile::is_same<Y, double>::value ||
ck_tile::is_same<Y, ck_tile::fp16_t>::value,
"Data type is not supported by this operation!");
float bx = -beta_ * type_convert<float>(x);
y = type_convert<Y>(x / (1.f + ck_tile::exp(bx)));
};
const float beta_;
};
struct SoftRelu
{
SoftRelu(float alpha = 1.f) : alpha_(alpha){};
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value ||
ck_tile::is_same<T, ck_tile::fp16_t>::value ||
ck_tile::is_same<T, int32_t>::value || ck_tile::is_same<T, int8_t>::value,
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_);
constexpr T one = type_convert<T>(1);
y = ck_tile::log(one + ck_tile::exp(x * casted_alpha)) / casted_alpha;
}
const float alpha_;
};
struct Power
{
Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f)
: alpha_(alpha), beta_(beta), gamma_(gamma){};
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value ||
ck_tile::is_same<T, ck_tile::fp16_t>::value ||
ck_tile::is_same<T, int32_t>::value || ck_tile::is_same<T, int8_t>::value,
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_);
T casted_beta = type_convert<T>(beta_);
T casted_gamma = type_convert<T>(gamma_);
T shifted_scaled_x = casted_alpha + casted_beta * x;
y = ck_tile::pow(shifted_scaled_x, casted_gamma);
}
const float alpha_;
const float beta_;
const float gamma_;
};
struct ClippedRelu
{
ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta){};
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value ||
ck_tile::is_same<T, ck_tile::fp16_t>::value ||
ck_tile::is_same<T, int32_t>::value || ck_tile::is_same<T, int8_t>::value,
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_);
T casted_beta = type_convert<T>(beta_);
y = ck_tile::min(casted_beta, ck_tile::max(casted_alpha, x));
}
const float alpha_;
const float beta_;
};
struct LeakyRelu
{
LeakyRelu(float alpha = 0.01f) : alpha_(alpha){};
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value ||
ck_tile::is_same<T, ck_tile::fp16_t>::value ||
ck_tile::is_same<T, int32_t>::value || ck_tile::is_same<T, int8_t>::value,
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_);
y = x >= 0 ? x : x * casted_alpha;
}
const float alpha_;
};
struct Elu
{
Elu(float alpha = 1.f) : alpha_(alpha){};
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value ||
ck_tile::is_same<T, ck_tile::fp16_t>::value ||
ck_tile::is_same<T, int32_t>::value || ck_tile::is_same<T, int8_t>::value,
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_);
y = x > 0 ? x : casted_alpha * ck_tile::expm1(x);
}
const float alpha_;
};
struct Logistic
{
Logistic(float alpha = 1.f) : alpha_(alpha){};
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value ||
ck_tile::is_same<T, ck_tile::fp16_t>::value ||
ck_tile::is_same<T, int32_t>::value || ck_tile::is_same<T, int8_t>::value,
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_);
constexpr T one = type_convert<T>(1);
y = casted_alpha / (one + ck_tile::exp(-x) * casted_alpha);
}
const float alpha_;
};
struct ConvInvscale
{
CK_TILE_HOST_DEVICE
ConvInvscale(float scale_in = 1.f, float scale_wei = 1.f, float scale_out = 1.f)
: scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out)
{
}
template <typename E, typename C>
CK_TILE_HOST_DEVICE void operator()(E& e, const C& c) const;
template <>
CK_TILE_HOST_DEVICE void operator()<ck_tile::fp8_t, float>(ck_tile::fp8_t& e,
const float& c) const
{
e = type_convert<ck_tile::fp8_t>(c / scale_in_ / scale_wei_ / scale_out_);
};
float scale_in_;
float scale_wei_;
float scale_out_;
};
struct ConvScale
{
CK_TILE_HOST_DEVICE
ConvScale(float scale_in = 1.f, float scale_wei = 1.f, float scale_out = 1.f)
: scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out)
{
}
template <typename E, typename C>
CK_TILE_HOST_DEVICE void operator()(E& e, const C& c) const;
template <>
CK_TILE_HOST_DEVICE void operator()<ck_tile::fp8_t, float>(ck_tile::fp8_t& e,
const float& c) const
{
e = type_convert<ck_tile::fp8_t>(c * scale_in_ * scale_wei_ * scale_out_);
};
float scale_in_;
float scale_wei_;
float scale_out_;
};
struct ConvScaleRelu
{
CK_TILE_HOST_DEVICE
ConvScaleRelu(float scale_in = 1.f, float scale_wei = 1.f, float scale_out = 1.f)
: scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out)
{
}
template <typename E, typename C>
CK_TILE_HOST_DEVICE void operator()(E& e, const C& c) const;
template <>
CK_TILE_HOST_DEVICE void operator()<ck_tile::fp8_t, float>(ck_tile::fp8_t& e,
const float& c) const
{
float x;
Relu{}.template operator()<float>(x, c * scale_in_ * scale_wei_);
e = type_convert<ck_tile::fp8_t>(x * scale_out_);
};
float scale_in_;
float scale_wei_;
float scale_out_;
};
// support fastconvert of int8 to fp16
template <typename InputDataType, typename OutputDataType, index_t RegPackNumber>
struct FastNumericArrayConverter
{
};
template <>
struct FastNumericArrayConverter<uint8_t, ck_tile::fp16_t, 4>
{
using InputArray = vector_type<uint8_t, 4>;
using OutputArray = vector_type<ck_tile::fp16_t, 4>;
CK_TILE_DEVICE static OutputArray convert(InputArray const& Input)
{
OutputArray Output;
uint32_t* half_2 = reinterpret_cast<uint32_t*>(&Output);
uint32_t const uint8_4 = reinterpret_cast<uint32_t const&>(Input);
static constexpr uint32_t byte_selector_01 = 0x05010500;
static constexpr uint32_t byte_selector_23 = 0x05030502;
static constexpr uint32_t fp16_adder = 0x64646464;
half_2[0] = __builtin_amdgcn_perm(fp16_adder, uint8_4, byte_selector_01);
half_2[1] = __builtin_amdgcn_perm(fp16_adder, uint8_4, byte_selector_23);
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
asm volatile("v_pk_add_f16 %0, %1, %2 neg_lo:[0,1] neg_hi:[0,1]"
: "=v"(half_2[0])
: "v"(half_2[0]), "s"(I8s_TO_F16s_MAGIC_NUM));
asm volatile("v_pk_add_f16 %0, %1, %2 neg_lo:[0,1] neg_hi:[0,1]"
: "=v"(half_2[1])
: "v"(half_2[1]), "s"(I8s_TO_F16s_MAGIC_NUM));
return Output;
}
CK_TILE_DEVICE OutputArray operator()(InputArray const& Input) { return convert(Input); }
};
template <index_t N>
struct FastNumericArrayConverter<uint8_t, ck_tile::fp16_t, N>
{
static constexpr int VEC_WIDTH = 4;
static_assert(!(N % VEC_WIDTH), "N must be multiple of 4.");
using InputArray = vector_type<uint8_t, N>;
using OutputArray = vector_type<ck_tile::fp16_t, N>;
CK_TILE_DEVICE static OutputArray convert(InputArray const& Input)
{
FastNumericArrayConverter<uint8_t, ck_tile::fp16_t, 4> converter;
OutputArray Output;
using Vec_InputArray = vector_type<uint8_t, 4>;
using Vec_OutputArray = vector_type<ck_tile::fp16_t, 4>;
Vec_OutputArray* half_4_ptr = reinterpret_cast<Vec_OutputArray*>(&Output);
Vec_InputArray const* uint8_4_ptr = reinterpret_cast<Vec_InputArray const*>(&Input);
static_for<0, N / VEC_WIDTH, 1>{}(
[&](auto i) { half_4_ptr[i] = converter(uint8_4_ptr[i]); });
return Output;
}
CK_TILE_DEVICE OutputArray operator()(InputArray const& Input) { return convert(Input); }
};
} // namespace element_wise
} // namespace ck_tile
...@@ -217,3 +217,5 @@ if(GPU_TARGETS MATCHES "gfx942" AND CK_HIP_VERSION_MAJOR GREATER_EQUAL 6 AND CK_ ...@@ -217,3 +217,5 @@ if(GPU_TARGETS MATCHES "gfx942" AND CK_HIP_VERSION_MAJOR GREATER_EQUAL 6 AND CK_
add_subdirectory(smfmac_op) add_subdirectory(smfmac_op)
endif() endif()
add_subdirectory(position_embedding) add_subdirectory(position_embedding)
add_subdirectory(scatter_gather)
add_test_executable(test_scatter_gather scatter_gather.cpp)
# target_compile_options(test_scatter_gather PRIVATE -v --save-temps -Wno-gnu-line-marker)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <vector>
#include <iostream>
#include <numeric>
#include <cassert>
#include <cstdlib>
#include <iostream>
#include <time.h>
#include <unordered_set>
#include "ck_tile/core.hpp"
#ifndef TEST_SCATTER_GATHER_VERBOSE
#define TEST_SCATTER_GATHER_VERBOSE 0
#endif
#define HIP_CALL(call) \
do \
{ \
hipError_t err = call; \
if(err != hipSuccess) \
{ \
printf("[hiperror](%d) fail to call %s", static_cast<int>(err), #call); \
exit(0); \
} \
} while(0)
template <ck_tile::index_t ROW_TILE_SIZE = 8,
ck_tile::index_t COL_TILE_SIZE = 32 * 8,
ck_tile::index_t BLOCK_SIZE = 256,
ck_tile::index_t ALIGNMENT = 8,
typename INDEX_BUF_TYPE = ck_tile::index_t,
typename DATA_TYPE = ck_tile::fp16_t>
__global__ void row_scatter_gather(const INDEX_BUF_TYPE* src_row_idx_ptr,
const INDEX_BUF_TYPE* dst_row_idx_ptr,
const DATA_TYPE* src_ptr,
DATA_TYPE* dst_ptr,
ck_tile::index_t n_row_total,
ck_tile::index_t /*n_row_select*/,
ck_tile::index_t n_cols)
{
using namespace ck_tile;
// some constexpr vars
constexpr index_t vec = ALIGNMENT;
static_assert(COL_TILE_SIZE % vec == 0);
constexpr index_t col_lanes = COL_TILE_SIZE / vec;
constexpr index_t warp_size = ck_tile::get_warp_size();
static_assert(warp_size % col_lanes == 0);
constexpr index_t row_lanes = warp_size / col_lanes;
constexpr index_t num_warps = BLOCK_SIZE / warp_size;
static_assert(ROW_TILE_SIZE % (num_warps * row_lanes) == 0);
constexpr index_t row_repeat = ROW_TILE_SIZE / (num_warps * row_lanes);
static_assert(
row_repeat == 1,
"currently indexing not support(and would be not performant) if row_repeat has more");
// tile partitioner
index_t tile_col_idx = 0;
index_t tile_row_idx = blockIdx.x * ROW_TILE_SIZE;
// create our tild distribution, which tell us the location of different threads
constexpr auto src_dist = make_static_tile_distribution(
tile_distribution_encoding<
sequence<1>,
tuple<sequence<row_repeat, num_warps, row_lanes>, sequence<col_lanes, vec>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
const auto coord = src_dist.calculate_index();
const auto row_coord = coord[number<0>{}] + tile_row_idx;
// load the current row index from the indexing buffer. we do not use ck_tile utility here
INDEX_BUF_TYPE src_row_id = src_row_idx_ptr[row_coord];
INDEX_BUF_TYPE dst_row_id = dst_row_idx_ptr[row_coord];
// printf("-- tid:%d, src_row_id:%d, dst_row_id:%d\n", static_cast<int>(threadIdx.x),
// static_cast<int>(src_row_id), static_cast<int>(dst_row_id));
const auto src_view =
make_naive_tensor_view<address_space_enum::global>(src_ptr,
make_tuple(n_row_total, n_cols),
make_tuple(n_cols, 1),
number<vec>{}, // alignement
number<1>{});
const auto src_gather_view = transform_tensor_view(
src_view,
make_tuple(make_indexing_transform(
n_row_total,
src_row_id), // here we replace row_idx which is loaded from another buffer
make_pass_through_transform(n_cols)),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
auto src_tile = make_tile_window(src_gather_view,
make_tuple(number<ROW_TILE_SIZE>{}, number<COL_TILE_SIZE>{}),
{tile_row_idx, tile_col_idx},
src_dist);
const auto dst_view =
make_naive_tensor_view<address_space_enum::global>(dst_ptr,
make_tuple(n_row_total, n_cols),
make_tuple(n_cols, 1),
number<vec>{},
number<1>{});
const auto dst_scatter_view = transform_tensor_view(
dst_view,
make_tuple(make_indexing_transform(
n_row_total,
dst_row_id), // here we replace row_idx which is loaded from another buffer
make_pass_through_transform(n_cols)),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
auto dst_tile = make_tile_window(dst_scatter_view,
make_tuple(number<ROW_TILE_SIZE>{}, number<COL_TILE_SIZE>{}),
{tile_row_idx, tile_col_idx},
src_dist /*reuse distribution*/);
// we finished descriptor construction and index calculation, now start load/store
for(auto i = 0; i < n_cols; i += COL_TILE_SIZE)
{
// note that scatter/gather are just the same API when doing load store as normal memory
// operation
auto data = load_tile(src_tile);
store_tile(dst_tile, data);
move_tile_window(src_tile, {0, COL_TILE_SIZE});
move_tile_window(dst_tile, {0, COL_TILE_SIZE});
}
}
union pixel
{
struct __attribute__((packed))
{
unsigned int r : 6;
unsigned int c : 10;
};
ushort data;
};
struct unique_linear_rand
{
unique_linear_rand(int capacity_) : capacity(capacity_) {}
std::unordered_set<int> set;
int gen()
{
if(static_cast<int>(set.size()) >= capacity)
{
printf("overflow, but will give you an number as well\n");
return std::rand() % capacity;
}
while(1)
{
int r = std::rand() % capacity;
if(set.count(r) == 1)
{
continue;
}
set.insert(r);
return r;
}
}
int capacity;
};
int main()
{
int row_total = 64;
int row_select = 8 * 2;
int col = 256 * 2;
using fp16_t = ck_tile::fp16_t;
constexpr int row_tile = 8;
constexpr int col_tile = 256;
fp16_t* src = reinterpret_cast<fp16_t*>(malloc(row_total * col * sizeof(fp16_t)));
for(int i_r = 0; i_r < row_total; i_r++)
{
for(int i_c = 0; i_c < col; i_c++)
{
int i = i_r * col + i_c;
pixel p;
p.r = i_r;
p.c = i_c;
ushort d = p.data;
src[i] = ck_tile::bit_cast<fp16_t>(d); // for simplicity, just cast
}
}
fp16_t* dst = reinterpret_cast<fp16_t*>(malloc(row_total * col * sizeof(fp16_t)));
int* src_idx = reinterpret_cast<int*>(malloc(row_select * sizeof(int)));
int* dst_idx = reinterpret_cast<int*>(malloc(row_select * sizeof(int)));
// std::srand(std::time(std::nullptr));
// std::srand(11935);
std::srand(std::time(nullptr));
auto src_gen = unique_linear_rand(row_total);
auto dst_gen = unique_linear_rand(row_total); // dst index must be unique. src is fine
for(int i_r = 0; i_r < row_select; i_r++)
{
src_idx[i_r] = src_gen.gen();
dst_idx[i_r] = dst_gen.gen();
}
void* dev_src;
void* dev_dst;
void* dev_src_idx;
void* dev_dst_idx;
HIP_CALL(hipMalloc(&dev_src, row_total * col * sizeof(fp16_t)));
HIP_CALL(hipMalloc(&dev_dst, row_total * col * sizeof(fp16_t)));
HIP_CALL(hipMalloc(&dev_src_idx, row_select * sizeof(int)));
HIP_CALL(hipMalloc(&dev_dst_idx, row_select * sizeof(int)));
HIP_CALL(hipMemcpy(dev_src, src, row_total * col * sizeof(fp16_t), hipMemcpyHostToDevice));
HIP_CALL(hipMemcpy(dev_src_idx, src_idx, row_select * sizeof(int), hipMemcpyHostToDevice));
HIP_CALL(hipMemcpy(dev_dst_idx, dst_idx, row_select * sizeof(int), hipMemcpyHostToDevice));
constexpr int bdim = 256;
int gdim = (row_select + row_tile - 1) / row_tile;
row_scatter_gather<row_tile, col_tile><<<gdim, bdim>>>(reinterpret_cast<int*>(dev_src_idx),
reinterpret_cast<int*>(dev_dst_idx),
reinterpret_cast<fp16_t*>(dev_src),
reinterpret_cast<fp16_t*>(dev_dst),
row_total,
row_select,
col);
HIP_CALL(hipMemcpy(dst, dev_dst, row_total * col * sizeof(fp16_t), hipMemcpyDeviceToHost));
#if TEST_SCATTER_GATHER_VERBOSE
printf("select row:");
for(int i_r = 0; i_r < row_select; i_r++)
{
printf("%d->%d->%d ", i_r, src_idx[i_r], dst_idx[i_r]);
}
printf("\n");
#endif
int err_cnt = 0;
for(int i_r = 0; i_r < row_select; i_r++)
{
for(int i_c = 0; i_c < col; i_c++)
{
int i = dst_idx[i_r] * col + i_c;
pixel p = ck_tile::bit_cast<pixel>(dst[i]);
bool is_ok = p.r == src_idx[i_r] && p.c == i_c;
if(!is_ok)
{
if(i_c == 0)
printf("(%d)pixel: %dx%d -> %d\n", i_r, p.r, p.c, dst_idx[i_r]);
err_cnt++;
}
}
}
#if TEST_SCATTER_GATHER_VERBOSE
printf("err:%d\n", err_cnt);
#endif
free(src);
free(dst);
free(src_idx);
free(dst_idx);
return err_cnt == 0 ? 0 : -1;
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment