"vscode:/vscode.git/clone" did not exist on "469a1965c249aa424fb36b1d9d2a6eaa480d7b95"
Commit 3c4fb1dd authored by Umang Yadav's avatar Umang Yadav
Browse files

Merge remote-tracking branch 'origin/develop' into migx_merge

parents 57cdd70b e8cddfdc
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include "ck/utility/data_type.hpp" #include "ck/utility/data_type.hpp"
#include "ck/utility/type.hpp" #include "ck/utility/type.hpp"
#include "ck/utility/type_convert.hpp"
namespace ck { namespace ck {
namespace math { namespace math {
...@@ -93,15 +94,97 @@ static inline __host__ float sqrt(float x) { return std::sqrt(x); }; ...@@ -93,15 +94,97 @@ static inline __host__ float sqrt(float x) { return std::sqrt(x); };
static inline __host__ double sqrt(double x) { return std::sqrt(x); }; static inline __host__ double sqrt(double x) { return std::sqrt(x); };
static inline __host__ half_t tanh(half_t x) template <typename T>
inline __host__ T tanh(T x)
{ {
return static_cast<half_t>(std::tanh(static_cast<float>(x))); return ck::type_convert<T>(std::tanhf(ck::type_convert<float>(x)));
}; };
static inline __host__ float tanh(float x) { return std::tanh(x); }; template <>
inline __host__ float tanh<float>(float x)
{
return std::tanhf(x);
};
template <>
inline __host__ double tanh<double>(double x)
{
return std::tanh(x);
};
template <typename T>
inline __host__ T exp(T x)
{
return ck::type_convert<T>(std::expf(ck::type_convert<float>(x)));
}
template <>
inline __host__ float exp<float>(float x)
{
return std::expf(x);
}
template <>
inline __host__ double exp<double>(double x)
{
return std::exp(x);
}
template <typename T>
inline __host__ T log(T x)
{
return ck::type_convert<T>(std::logf(ck::type_convert<float>(x)));
}
template <>
inline __host__ float log<float>(float x)
{
return std::logf(x);
}
template <>
inline __host__ double log<double>(double x)
{
return std::log(x);
}
template <typename T>
inline __host__ T pow(T x, T gamma)
{
return ck::type_convert<T>(
std::powf(ck::type_convert<float>(x), ck::type_convert<float>(gamma)));
}
template <>
inline __host__ float pow<float>(float x, float gamma)
{
return std::powf(x, gamma);
}
template <>
inline __host__ double pow<double>(double x, double gamma)
{
return std::pow(x, gamma);
}
template <typename T>
inline __host__ T expm1(T x)
{
return ck::type_convert<T>(std::expm1f(ck::type_convert<float>(x)));
}
template <>
inline __host__ float expm1<float>(float x)
{
return std::expm1f(x);
}
template <>
inline __host__ double expm1<double>(double x)
{
return std::expm1(x);
}
static inline __host__ double tanh(double x) { return std::tanh(x); };
#endif
// math functions for the HIP kernel, some are implemented by calling hip builtin functions // math functions for the HIP kernel, some are implemented by calling hip builtin functions
static inline __device__ float abs(float x) { return ::abs(x); }; static inline __device__ float abs(float x) { return ::abs(x); };
...@@ -182,14 +265,107 @@ static inline __device__ float sqrt(float x) { return __builtin_amdgcn_sqrtf(x); ...@@ -182,14 +265,107 @@ static inline __device__ float sqrt(float x) { return __builtin_amdgcn_sqrtf(x);
static inline __device__ double sqrt(double x) { return __builtin_amdgcn_sqrt(x); }; static inline __device__ double sqrt(double x) { return __builtin_amdgcn_sqrt(x); };
static inline __device__ half_t tanh(half_t x) template <typename T>
inline __device__ T tanh(T x)
{
return ck::type_convert<T>(::tanhf(ck::type_convert<float>(x)));
};
template <>
inline __device__ float tanh<float>(float x)
{
return ::tanhf(x);
};
template <>
inline __device__ double tanh<double>(double x)
{
return ::tanh(x);
};
template <typename T>
inline __device__ T exp(T x)
{
return ck::type_convert<T>(__expf(ck::type_convert<float>(x)));
};
template <>
inline __device__ half_t exp<half_t>(half_t x)
{
return hexp(x);
};
template <>
inline __device__ float exp<float>(float x)
{
return __expf(x);
};
template <>
inline __device__ double exp<double>(double x)
{
return exp(x);
};
template <typename T>
inline __device__ T log(T x)
{
return ck::type_convert<T>(__logf(ck::type_convert<float>(x)));
};
template <>
inline __device__ half_t log<half_t>(half_t x)
{ {
return static_cast<half_t>(::tanhf(static_cast<float>(x))); return hlog(x);
}; };
static inline __device__ float tanh(float x) { return ::tanhf(x); }; template <>
inline __device__ float log<float>(float x)
{
return __logf(x);
};
template <>
inline __device__ double log<double>(double x)
{
return log(x);
};
template <typename T>
inline __device__ T pow(T x, T gamma)
{
return ck::type_convert<T>(powf(ck::type_convert<float>(x), ck::type_convert<float>(gamma)));
};
template <>
inline __device__ float pow<float>(float x, float gamma)
{
return powf(x, gamma);
};
static inline __device__ double tanh(double x) { return ::tanh(x); }; template <>
inline __device__ double pow<double>(double x, double gamma)
{
return pow(x, gamma);
};
template <typename T>
inline __device__ T expm1(T x)
{
return ck::type_convert<T>(expm1f(ck::type_convert<float>(x)));
};
template <>
inline __device__ float expm1<float>(float x)
{
return expm1f(x);
};
template <>
inline __device__ double expm1<double>(double x)
{
return expm1(x);
};
} // namespace math } // namespace math
} // namespace ck } // namespace ck
...@@ -116,7 +116,15 @@ struct Max ...@@ -116,7 +116,15 @@ struct Max
template <typename T> template <typename T>
__host__ __device__ static constexpr T GetIdentityValue() __host__ __device__ static constexpr T GetIdentityValue()
{ {
return NumericLimits<T>::Lowest(); if constexpr(is_same_v<T, bhalf_t>)
{
float val = NumericLimits<float>::Lowest();
return type_convert<bhalf_t>(val);
}
else
{
return NumericLimits<T>::Lowest();
}
}; };
__host__ __device__ static constexpr bool __host__ __device__ static constexpr bool
...@@ -138,6 +146,15 @@ struct Max ...@@ -138,6 +146,15 @@ struct Max
a = b; a = b;
} }
__host__ __device__ inline constexpr void operator()(bhalf_t& a, bhalf_t b) const
{
float a_ = type_convert<float>(a);
float b_ = type_convert<float>(b);
if(a_ < b_)
a = b;
}
template <typename T> template <typename T>
__host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
{ {
...@@ -152,6 +169,18 @@ struct Max ...@@ -152,6 +169,18 @@ struct Max
changed = true; changed = true;
} }
} }
__host__ __device__ inline constexpr void operator()(bhalf_t& a, bhalf_t b, bool& changed) const
{
float a_ = type_convert<float>(a);
float b_ = type_convert<float>(b);
if(a_ < b_)
{
a = b;
changed = true;
}
}
}; };
struct Min struct Min
...@@ -159,6 +188,15 @@ struct Min ...@@ -159,6 +188,15 @@ struct Min
template <typename T> template <typename T>
__host__ __device__ static constexpr T GetIdentityValue() __host__ __device__ static constexpr T GetIdentityValue()
{ {
if constexpr(is_same_v<T, bhalf_t>)
{
float val = NumericLimits<float>::Max();
return type_convert<bhalf_t>(val);
}
else
{
return NumericLimits<T>::Max();
}
return NumericLimits<T>::Max(); return NumericLimits<T>::Max();
}; };
...@@ -181,6 +219,15 @@ struct Min ...@@ -181,6 +219,15 @@ struct Min
a = b; a = b;
} }
__host__ __device__ inline constexpr void operator()(bhalf_t& a, bhalf_t b) const
{
float a_ = type_convert<float>(a);
float b_ = type_convert<float>(b);
if(a_ > b_)
a = b;
}
template <typename T> template <typename T>
__host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
{ {
...@@ -195,6 +242,18 @@ struct Min ...@@ -195,6 +242,18 @@ struct Min
changed = true; changed = true;
} }
} }
__host__ __device__ inline constexpr void operator()(bhalf_t& a, bhalf_t b, bool& changed) const
{
float a_ = type_convert<float>(a);
float b_ = type_convert<float>(b);
if(a_ > b_)
{
a = b;
changed = true;
}
}
}; };
struct AMax struct AMax
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#define CK_STATICALLY_INDEXED_ARRAY_MULTI_INDEX_HPP #define CK_STATICALLY_INDEXED_ARRAY_MULTI_INDEX_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "ck/utility/math_v2.hpp"
namespace ck { namespace ck {
......
...@@ -177,6 +177,8 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X ...@@ -177,6 +177,8 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
} }
__host__ __device__ static constexpr bool IsStaticBuffer() { return true; } __host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
__host__ __device__ static constexpr bool IsTuple() { return true; }
}; };
template <> template <>
......
...@@ -9,8 +9,10 @@ ...@@ -9,8 +9,10 @@
namespace ck { namespace ck {
// Convert X to Y // Convert X to Y, both X and Y are non-const data types.
template <typename Y, typename X> template <typename Y,
typename X,
ck::enable_if_t<!(ck::is_const_v<Y> || ck::is_const_v<X>), bool> = false>
__host__ __device__ constexpr Y type_convert(X x) __host__ __device__ constexpr Y type_convert(X x)
{ {
static_assert(!ck::is_reference_v<Y> && !ck::is_reference_v<X>); static_assert(!ck::is_reference_v<Y> && !ck::is_reference_v<X>);
...@@ -18,6 +20,19 @@ __host__ __device__ constexpr Y type_convert(X x) ...@@ -18,6 +20,19 @@ __host__ __device__ constexpr Y type_convert(X x)
return static_cast<Y>(x); return static_cast<Y>(x);
} }
// Convert X to Y, either X or Y is a const data type.
template <typename Y,
typename X,
std::enable_if_t<ck::is_const_v<Y> || ck::is_const_v<X>, bool> = false>
__host__ __device__ constexpr Y type_convert(X x)
{
static_assert(!ck::is_reference_v<Y> && !ck::is_reference_v<X>);
using NonConstY = ck::remove_const_t<Y>;
using NonConstX = ck::remove_const_t<X>;
return static_cast<Y>(type_convert<NonConstY, NonConstX>(x));
}
// convert bfp16 to fp32 // convert bfp16 to fp32
template <> template <>
inline __host__ __device__ constexpr float type_convert<float, bhalf_t>(bhalf_t x) inline __host__ __device__ constexpr float type_convert<float, bhalf_t>(bhalf_t x)
...@@ -84,40 +99,180 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_ ...@@ -84,40 +99,180 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_
template <> template <>
inline __host__ __device__ f8_t type_convert<f8_t, float>(float x) inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
{ {
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
float max_fp8 = 240.0f;
x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
union
{
float fval;
uint32_t i32val;
uint8_t i8val[4]; // not endian independent
} val;
val.fval = x;
uint32_t ival = 0;
ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false); // false -> WORD0
val.i32val = ival;
return val.i8val[0];
#else
constexpr bool negative_zero_nan = true; constexpr bool negative_zero_nan = true;
constexpr bool clip = true; constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard; constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0; constexpr uint32_t rng = 0;
return utils::cast_to_f8<float, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>( return utils::
x, rng); cast_to_f8<float, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x,
rng);
#endif
} }
// convert fp8 to fp32 // convert fp8 to fp32
template <> template <>
inline __host__ __device__ float type_convert<float, f8_t>(f8_t x) inline __host__ __device__ float type_convert<float, f8_t>(f8_t x)
{ {
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
float fval;
uint32_t i32val = static_cast<uint32_t>(x);
fval = __builtin_amdgcn_cvt_f32_fp8(i32val, 0);
// asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
return fval;
#else
constexpr bool negative_zero_nan = true; constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<float, negative_zero_nan>(x); return utils::cast_from_f8<f8_t, float, negative_zero_nan>(x);
#endif
}
template <>
inline __host__ __device__ float2_t type_convert<float2_t, f8x2_t>(f8x2_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
const auto i16val = bit_cast<uint16_t>(x);
return __builtin_amdgcn_cvt_pk_f32_fp8(i16val, 0);
#else
constexpr bool negative_zero_nan = true;
const auto f8x2_v = vector_type<f8_t, 2>(x);
vector_type<float, 2> f32x2_v;
f32x2_v.template AsType<float>()(Number<0>{}) =
utils::cast_from_f8<f8_t, float, negative_zero_nan>(
f8x2_v.template AsType<f8_t>()[Number<0>{}]);
f32x2_v.template AsType<float>()(Number<1>{}) =
utils::cast_from_f8<f8_t, float, negative_zero_nan>(
f8x2_v.template AsType<f8_t>()[Number<1>{}]);
return f32x2_v.template AsType<float2_t>()[Number<0>{}];
#endif
}
template <>
inline __host__ __device__ half2_t type_convert<half2_t, float2_t>(float2_t x)
{
const vector_type<float, 2> f32x2_v(x);
const auto y = __builtin_amdgcn_cvt_pkrtz(f32x2_v.template AsType<float>()[Number<0>{}],
f32x2_v.template AsType<float>()[Number<1>{}]);
return bit_cast<half2_t>(y);
} }
// convert fp16 to fp8 // convert fp16 to fp8
template <> template <>
inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x) inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
{ {
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// convert to float and use native converion
return type_convert<f8_t>(type_convert<float>(x));
#else
constexpr bool negative_zero_nan = true; constexpr bool negative_zero_nan = true;
constexpr bool clip = true; constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard; constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0; constexpr uint32_t rng = 0;
return utils::cast_to_f8<half_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>( return utils::
x, rng); cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
#endif
} }
// convert fp8 to fp16 // convert fp8 to fp16
template <> template <>
inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x) inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x)
{ {
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// use native conversion to float and convert to fp16
return type_convert<half_t>(type_convert<float>(x));
#else
constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<f8_t, half_t, negative_zero_nan>(x);
#endif
}
// convert fp32 to bf8
template <>
inline __host__ __device__ bf8_t type_convert<bf8_t, float>(float x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
union
{
float fval;
uint32_t i32val;
uint8_t i8val[4]; // not endian independent
} val;
val.fval = x;
uint32_t ival = 0;
ival = __builtin_amdgcn_cvt_pk_bf8_f32(val.fval, val.fval, ival, false); // false -> WORD0
val.i32val = ival;
return val.i8val[0];
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0;
return utils::
cast_to_f8<float, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
#endif
}
// convert bf8 to fp32
template <>
inline __host__ __device__ float type_convert<float, bf8_t>(bf8_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
float fval;
uint32_t i32val = static_cast<uint32_t>(x);
fval = __builtin_amdgcn_cvt_f32_bf8(i32val, 0);
// asm volatile("v_cvt_f32_bf8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
return fval;
#else
constexpr bool negative_zero_nan = true; constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<half_t, negative_zero_nan>(x); return utils::cast_from_f8<bf8_t, float, negative_zero_nan>(x);
#endif
}
// convert fp16 to bf8
template <>
inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// convert to float and use native converion
return type_convert<bf8_t>(type_convert<float>(x));
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
constexpr uint32_t rng = 0;
return utils::
cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
#endif
}
// convert bf8 to fp16
template <>
inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// use native conversion to float and convert to fp16
return type_convert<half_t>(type_convert<float>(x));
#else
constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<bf8_t, half_t, negative_zero_nan>(x);
#endif
} }
// Declare a template function for bf16 conversion using RTN // Declare a template function for bf16 conversion using RTN
...@@ -185,28 +340,95 @@ __host__ __device__ constexpr Y f8_convert_sr(X x); ...@@ -185,28 +340,95 @@ __host__ __device__ constexpr Y f8_convert_sr(X x);
template <> template <>
inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x) inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
{ {
constexpr int seed = 42;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
union
{
float fval;
uint32_t i32val;
uint8_t i8val[4]; // not endian independent
} val;
val.fval = x;
uint32_t ival = 0;
ival = __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0); // 0 pos
val.i32val = ival;
return val.i8val[0]; // little endian
#else
constexpr bool negative_zero_nan = true; constexpr bool negative_zero_nan = true;
constexpr bool clip = true; constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic; constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
constexpr int seed = 42; return utils::
// as thread id is not available on host, use 0 for prn generation cast_to_f8<float, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(x,
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&x), x); rng);
return utils::cast_to_f8<float, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>( #endif
x, rng);
} }
// convert fp16 to fp8 with stochastic rounding // convert fp16 to fp8 with stochastic rounding
template <> template <>
inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x) inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
{ {
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// convert to float and use native converion
return f8_convert_sr<f8_t>(type_convert<float>(x));
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
constexpr int seed = 42;
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
return utils::
cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
#endif
}
// convert fp32 to bf8 with stochastic rounding
template <>
inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x)
{
constexpr int seed = 42;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
union
{
float fval;
uint32_t i32val;
uint8_t i8val[4]; // not endian independent
} val;
val.fval = x;
uint32_t ival = 0;
ival = __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos
val.i32val = ival;
return val.i8val[0]; // little endian
#else
constexpr bool negative_zero_nan = true;
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
return utils::
cast_to_f8<float, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
#endif
}
// convert fp16 to bf8 with stochastic rounding
template <>
inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// convert to float and use native converion
return f8_convert_sr<f8_t>(type_convert<float>(x));
#else
constexpr bool negative_zero_nan = true; constexpr bool negative_zero_nan = true;
constexpr bool clip = true; constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic; constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
constexpr int seed = 42; constexpr int seed = 42;
// as thread id is not available on host, use 0 for prn generation // as thread id is not available on host, use 0 for prn generation
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<size_t>(&x), x); uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
return utils::cast_to_f8<half_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>( return utils::
x, rng); cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
#endif
} }
} // namespace ck } // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <type_traits>
#include <sstream>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/library/utility/host_tensor.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
/**
* \brief Reference implementation for column to image.
*
* Input tensor descriptor has [N * Do * Ho * Wo, Z * Y * X * C] data layout.
* Output tensor descriptor has [G, N, C, Di, Hi, Wi] data layout.
*
* \tparam NDimSpatial Number of spatial dimensions.
* \tparam ImageLayout Image Layout.
* \tparam InDataType Input Data Type.
* \tparam OutDataType Output Data Type.
*/
template <ck::index_t NDimSpatial,
typename ImageLayout,
typename InDataType,
typename OutDataType,
typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false>
struct ReferenceColumnToImage : public device::BaseOperator
{
// Argument
struct Argument : public device::BaseArgument
{
public:
Argument(const Tensor<InDataType>& input,
Tensor<OutDataType>& output,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads)
: input_{input},
output_{output},
conv_strides_{conv_filter_strides},
conv_dilations_{conv_filter_dilations},
in_left_pads_{input_left_pads},
in_right_pads_{input_right_pads},
filter_spatial_lengths_{filter_spatial_lengths}
{
initOutputSpatialLengths();
}
const Tensor<InDataType>& input_;
Tensor<OutDataType>& output_;
std::vector<index_t> conv_strides_;
std::vector<index_t> conv_dilations_;
std::vector<index_t> in_left_pads_;
std::vector<index_t> in_right_pads_;
std::vector<index_t> filter_spatial_lengths_;
std::vector<index_t> output_spatial_lengths_;
private:
void initOutputSpatialLengths()
{
constexpr auto input_offset_to_spatial = 3;
for(ck::index_t i = 0; i < NDimSpatial; ++i)
{
// XEff = (X - 1) * conv_dilation_w + 1;
// Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
const ck::index_t x_eff = (filter_spatial_lengths_[i] - 1) * conv_dilations_[i] + 1;
output_spatial_lengths_.push_back(
(output_.GetLengths()[i + input_offset_to_spatial] + in_left_pads_[i] +
in_right_pads_[i] - x_eff) /
conv_strides_[i] +
1);
}
}
};
struct Invoker : public device::BaseInvoker
{
using Argument = ReferenceColumnToImage::Argument;
float Run(const Argument& arg)
{
if(!(arg.output_.GetNumOfDimension() == NDimSpatial + 3 &&
arg.input_.GetNumOfDimension() == 3))
{
throw std::runtime_error("wrong! inconsistent dimension");
}
const index_t G = arg.output_.GetLengths()[0];
const index_t N = arg.output_.GetLengths()[1];
const index_t C = arg.output_.GetLengths()[2];
if constexpr(NDimSpatial == 1)
{
const index_t Wo = arg.output_spatial_lengths_[0];
auto func = [&](auto g, auto n) {
for(index_t wo = 0; wo < Wo; ++wo)
{
index_t row = n * Wo + wo;
index_t column = 0;
for(index_t x = 0; x < arg.filter_spatial_lengths_[0]; ++x)
{
auto wi = static_cast<ck::long_index_t>(wo * arg.conv_strides_[0]) +
static_cast<ck::long_index_t>(x * arg.conv_dilations_[0]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
for(index_t c = 0; c < C; ++c)
{
if(wi >= 0 &&
ck::type_convert<std::size_t>(wi) < arg.output_.GetLengths()[3])
{
float v_in =
ck::type_convert<float>(arg.input_(g, row, column));
float v_out = ck::type_convert<float>(arg.output_(g, n, c, wi));
arg.output_(g, n, c, wi) =
ck::type_convert<OutDataType>(v_in + v_out);
}
column++;
}
}
}
};
make_ParallelTensorFunctor(func, G, N)(std::thread::hardware_concurrency());
return 0;
}
else if constexpr(NDimSpatial == 2)
{
const index_t Ho = arg.output_spatial_lengths_[0];
const index_t Wo = arg.output_spatial_lengths_[1];
auto func = [&](auto g, auto n) {
for(index_t ho = 0; ho < Ho; ++ho)
{
for(index_t wo = 0; wo < Wo; ++wo)
{
index_t row = n * Ho * Wo + ho * Wo + wo;
index_t column = 0;
for(index_t y = 0; y < arg.filter_spatial_lengths_[0]; ++y)
{
auto hi =
static_cast<ck::long_index_t>(ho * arg.conv_strides_[0]) +
static_cast<ck::long_index_t>(y * arg.conv_dilations_[0]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
for(index_t x = 0; x < arg.filter_spatial_lengths_[1]; ++x)
{
auto wi =
static_cast<ck::long_index_t>(wo * arg.conv_strides_[1]) +
static_cast<ck::long_index_t>(x * arg.conv_dilations_[1]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[1]);
for(index_t c = 0; c < C; ++c)
{
if(hi >= 0 &&
ck::type_convert<std::size_t>(hi) <
arg.output_.GetLengths()[3] &&
wi >= 0 &&
ck::type_convert<std::size_t>(wi) <
arg.output_.GetLengths()[4])
{
float v_in =
ck::type_convert<float>(arg.input_(g, row, column));
float v_out = ck::type_convert<float>(
arg.output_(g, n, c, hi, wi));
arg.output_(g, n, c, hi, wi) =
ck::type_convert<OutDataType>(v_in + v_out);
}
column++;
}
}
}
}
}
};
make_ParallelTensorFunctor(func, G, N)(std::thread::hardware_concurrency());
return 0;
}
else if constexpr(NDimSpatial == 3)
{
const index_t Do = arg.output_spatial_lengths_[0];
const index_t Ho = arg.output_spatial_lengths_[1];
const index_t Wo = arg.output_spatial_lengths_[2];
auto func = [&](auto g, auto n) {
for(index_t d_o = 0; d_o < Do; ++d_o)
{
for(index_t ho = 0; ho < Ho; ++ho)
{
for(index_t wo = 0; wo < Wo; ++wo)
{
index_t row = n * Do * Ho * Wo + d_o * Ho * Wo + ho * Wo + wo;
index_t column = 0;
for(index_t z = 0; z < arg.filter_spatial_lengths_[0]; ++z)
{
auto di =
static_cast<ck::long_index_t>(d_o * arg.conv_strides_[0]) +
static_cast<ck::long_index_t>(z * arg.conv_dilations_[0]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
for(index_t y = 0; y < arg.filter_spatial_lengths_[1]; ++y)
{
auto hi =
static_cast<ck::long_index_t>(ho *
arg.conv_strides_[1]) +
static_cast<ck::long_index_t>(y *
arg.conv_dilations_[1]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[1]);
for(index_t x = 0; x < arg.filter_spatial_lengths_[2]; ++x)
{
auto wi =
static_cast<ck::long_index_t>(
wo * arg.conv_strides_[2]) +
static_cast<ck::long_index_t>(
x * arg.conv_dilations_[2]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[2]);
for(index_t c = 0; c < C; ++c)
{
if(di >= 0 &&
ck::type_convert<std::size_t>(di) <
arg.output_.GetLengths()[3] &&
hi >= 0 &&
ck::type_convert<std::size_t>(hi) <
arg.output_.GetLengths()[4] &&
wi >= 0 &&
ck::type_convert<std::size_t>(wi) <
arg.output_.GetLengths()[5])
{
float v_in = ck::type_convert<float>(
arg.input_(g, row, column));
float v_out = ck::type_convert<float>(
arg.output_(g, n, c, di, hi, wi));
arg.output_(g, n, c, di, hi, wi) =
ck::type_convert<OutDataType>(v_in + v_out);
}
column++;
}
}
}
}
}
}
}
};
make_ParallelTensorFunctor(func, G, N)(std::thread::hardware_concurrency());
return 0;
}
}
float Run(const device::BaseArgument* p_arg,
const StreamConfig& /*stream_config*/ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
static constexpr bool IsValidCompilationParameter()
{
using namespace tensor_layout::convolution;
if constexpr(!(std::is_same_v<ImageLayout, GNWC> || std::is_same_v<ImageLayout, GNHWC> ||
std::is_same_v<ImageLayout, GNDHWC>))
{
return false;
}
if constexpr(!(NDimSpatial >= 1 && NDimSpatial <= 3))
{
return false;
}
return true;
}
bool IsSupportedArgument(const Argument& arg)
{
const ck::index_t G = arg.output_.GetLengths()[0];
const ck::index_t N = arg.output_.GetLengths()[1];
const ck::index_t C = arg.output_.GetLengths()[2];
const index_t NDoHoWo =
N * ck::accumulate_n<index_t>(
arg.output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
const index_t CZYX =
C * ck::accumulate_n<index_t>(
arg.filter_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
if(!(arg.input_.GetLengths()[0] == static_cast<std::size_t>(G) &&
arg.input_.GetLengths()[1] == static_cast<std::size_t>(NDoHoWo) &&
arg.input_.GetLengths()[2] == static_cast<std::size_t>(CZYX)))
{
return false;
}
if(G != 1)
{
return false;
}
return true;
}
bool IsSupportedArgument(const device::BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const Tensor<InDataType>& input,
Tensor<OutDataType>& output,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads)
{
return Argument{input,
output,
filter_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads};
}
static auto MakeInvoker() { return Invoker{}; }
virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "ReferenceColumnToImage"
<< std::endl;
// clang-format on
return str.str();
}
};
} // namespace host
} // namespace tensor_operation
} // namespace ck
...@@ -23,6 +23,7 @@ template <ck::index_t NumDimM, ...@@ -23,6 +23,7 @@ template <ck::index_t NumDimM,
typename BDataType, typename BDataType,
typename CDataType, typename CDataType,
typename AccDataType, typename AccDataType,
typename ComputeDataType,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
ck::enable_if_t<NumDimM == 2 && NumDimN == 2 && NumDimK == 2, bool> = false> ck::enable_if_t<NumDimM == 2 && NumDimN == 2 && NumDimK == 2, bool> = false>
...@@ -69,19 +70,24 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base ...@@ -69,19 +70,24 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base
{ {
for(ck::index_t k1 = 0; k1 < K1; ++k1) for(ck::index_t k1 = 0; k1 < K1; ++k1)
{ {
// Simulate the possible casting when ComputeDataType is different than the
// A/B data types
ComputeDataType v_a_compute_input =
ck::type_convert<ComputeDataType>(arg.a_ms_ks_(m0, m1, k0, k1));
ComputeDataType v_b_compute_input =
ck::type_convert<ComputeDataType>(arg.b_ns_ks_(n0, n1, k0, k1));
AccDataType v_a; AccDataType v_a;
AccDataType v_b; AccDataType v_b;
arg.a_element_op_( arg.a_element_op_(v_a, ck::type_convert<AccDataType>(v_a_compute_input));
v_a, ck::type_convert<const AccDataType>(arg.a_ms_ks_(m0, m1, k0, k1))); arg.b_element_op_(v_b, ck::type_convert<AccDataType>(v_b_compute_input));
arg.b_element_op_(
v_b, ck::type_convert<const AccDataType>(arg.b_ns_ks_(n0, n1, k0, k1)));
v_acc += v_a * v_b; v_acc += v_a * v_b;
} }
} }
arg.c_ms_ns_(m0, m1, n0, n1) = v_acc; arg.c_ms_ns_(m0, m1, n0, n1) = ck::type_convert<CDataType>(v_acc);
}; };
make_ParallelTensorFunctor(f_ms_ns, make_ParallelTensorFunctor(f_ms_ns,
......
...@@ -25,6 +25,8 @@ template <ck::index_t NDimSpatial, ...@@ -25,6 +25,8 @@ template <ck::index_t NDimSpatial,
typename InElementwiseOperation, typename InElementwiseOperation,
typename WeiElementwiseOperation, typename WeiElementwiseOperation,
typename OutElementwiseOperation, typename OutElementwiseOperation,
typename ComputeTypeA = OutDataType,
typename ComputeTypeB = InDataType,
typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false> typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false>
struct ReferenceConvBwdWeight : public device::BaseOperator struct ReferenceConvBwdWeight : public device::BaseOperator
{ {
...@@ -98,8 +100,8 @@ struct ReferenceConvBwdWeight : public device::BaseOperator ...@@ -98,8 +100,8 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
if(wi >= 0 && if(wi >= 0 &&
ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[3]) ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[3])
{ {
float v_out; ComputeTypeA v_out;
float v_in; ComputeTypeB v_in;
arg.out_element_op_( arg.out_element_op_(
v_out, ck::type_convert<float>(arg.output_(g, n, k, wo))); v_out, ck::type_convert<float>(arg.output_(g, n, k, wo)));
...@@ -107,7 +109,7 @@ struct ReferenceConvBwdWeight : public device::BaseOperator ...@@ -107,7 +109,7 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
arg.in_element_op_( arg.in_element_op_(
v_in, ck::type_convert<float>(arg.input_(g, n, c, wi))); v_in, ck::type_convert<float>(arg.input_(g, n, c, wi)));
v_acc += v_out * v_in; v_acc += type_convert<float>(v_out) * type_convert<float>(v_in);
} }
} }
} }
...@@ -158,8 +160,8 @@ struct ReferenceConvBwdWeight : public device::BaseOperator ...@@ -158,8 +160,8 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
wi >= 0 && wi >= 0 &&
ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[4]) ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[4])
{ {
float v_out; ComputeTypeA v_out;
float v_in; ComputeTypeB v_in;
arg.out_element_op_( arg.out_element_op_(
v_out, v_out,
...@@ -168,7 +170,7 @@ struct ReferenceConvBwdWeight : public device::BaseOperator ...@@ -168,7 +170,7 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
arg.in_element_op_( arg.in_element_op_(
v_in, ck::type_convert<float>(arg.input_(g, n, c, hi, wi))); v_in, ck::type_convert<float>(arg.input_(g, n, c, hi, wi)));
v_acc += v_out * v_in; v_acc += type_convert<float>(v_out) * type_convert<float>(v_in);
} }
} }
} }
...@@ -226,8 +228,8 @@ struct ReferenceConvBwdWeight : public device::BaseOperator ...@@ -226,8 +228,8 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
ck::type_convert<std::size_t>(wi) < ck::type_convert<std::size_t>(wi) <
arg.input_.GetLengths()[5]) arg.input_.GetLengths()[5])
{ {
float v_out; ComputeTypeA v_out;
float v_in; ComputeTypeB v_in;
arg.out_element_op_(v_out, arg.out_element_op_(v_out,
ck::type_convert<float>( ck::type_convert<float>(
...@@ -237,7 +239,8 @@ struct ReferenceConvBwdWeight : public device::BaseOperator ...@@ -237,7 +239,8 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
ck::type_convert<float>( ck::type_convert<float>(
arg.input_(g, n, c, di, hi, wi))); arg.input_(g, n, c, di, hi, wi)));
v_acc += v_out * v_in; v_acc +=
type_convert<float>(v_out) * type_convert<float>(v_in);
} }
} }
} }
......
...@@ -3,12 +3,23 @@ ...@@ -3,12 +3,23 @@
#pragma once #pragma once
#include <iostream> #include <cmath>
#include <cstdlib>
#include <numeric>
#include <type_traits> #include <type_traits>
#include <sstream> #include <vector>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
...@@ -22,6 +33,7 @@ namespace host { ...@@ -22,6 +33,7 @@ namespace host {
// Supports both GNCHW/NGCHW as well as GNHWC/NHWGC physical layout // Supports both GNCHW/NGCHW as well as GNHWC/NHWGC physical layout
// as long as dimensions in tensor descriptor is in GNCHW order // as long as dimensions in tensor descriptor is in GNCHW order
// //
// @tparam NDimSpatial Number of spatial dimensions.
// @tparam InDataType Input tensor data type. // @tparam InDataType Input tensor data type.
// @tparam WeiDataType Weights tensor data type. // @tparam WeiDataType Weights tensor data type.
// @tparam OutDataType Output tensor data type. // @tparam OutDataType Output tensor data type.
...@@ -29,7 +41,9 @@ namespace host { ...@@ -29,7 +41,9 @@ namespace host {
// operation. // operation.
// @tparam WeiElementwiseOperation Functor for weights tensor elementwise // @tparam WeiElementwiseOperation Functor for weights tensor elementwise
// operation. // operation.
// @tparam NDimSpatial Number of spatial dimensions. // @tparam NumAElementwiseTensor Number of A elementwise tensors.
// @tparam NumBElementwiseTensor Number of B elementwise tensors.
// @tparam NumDElementwiseTensor Number of D elementwise tensors.
// //
// input descriptor in [G, N, C, Do, Ho, Wo] order // input descriptor in [G, N, C, Do, Ho, Wo] order
// weight descriptor in [G, K, C, Z, Y, X] order // weight descriptor in [G, K, C, Z, Y, X] order
...@@ -42,25 +56,35 @@ template <ck::index_t NDimSpatial, ...@@ -42,25 +56,35 @@ template <ck::index_t NDimSpatial,
typename InElementwiseOperation, typename InElementwiseOperation,
typename WeiElementwiseOperation, typename WeiElementwiseOperation,
typename OutElementwiseOperation, typename OutElementwiseOperation,
ck::index_t NumAElementwiseTensor = 0,
ck::index_t NumBElementwiseTensor = 0,
ck::index_t NumDElementwiseTensor = 0,
typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false> typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false>
struct ReferenceConvFwd : public device::BaseOperator struct ReferenceConvFwd : public device::BaseOperator
{ {
// Argument // Argument
struct Argument : public device::BaseArgument struct Argument : public device::BaseArgument
{ {
Argument(const Tensor<InDataType>& input, Argument(
const Tensor<WeiDataType>& weight, const Tensor<InDataType>& input,
Tensor<OutDataType>& output, const Tensor<WeiDataType>& weight,
std::vector<ck::index_t> conv_filter_strides, Tensor<OutDataType>& output,
std::vector<ck::index_t> conv_filter_dilations, std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> input_left_pads, std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_right_pads, std::vector<ck::index_t> input_left_pads,
InElementwiseOperation in_element_op, std::vector<ck::index_t> input_right_pads,
WeiElementwiseOperation wei_element_op, InElementwiseOperation in_element_op,
OutElementwiseOperation out_element_op) WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op,
const std::array<Tensor<InDataType>, NumAElementwiseTensor>& elementwise_a_tensors,
const std::array<Tensor<WeiDataType>, NumBElementwiseTensor>& elementwise_b_tensors,
const std::array<Tensor<OutDataType>, NumDElementwiseTensor>& elementwise_d_tensors)
: input_{input}, : input_{input},
weight_{weight}, weight_{weight},
output_{output}, output_{output},
elementwise_a_tensors_{elementwise_a_tensors},
elementwise_b_tensors_{elementwise_b_tensors},
elementwise_d_tensors_{elementwise_d_tensors},
conv_strides_{conv_filter_strides}, conv_strides_{conv_filter_strides},
conv_dilations_{conv_filter_dilations}, conv_dilations_{conv_filter_dilations},
in_left_pads_{input_left_pads}, in_left_pads_{input_left_pads},
...@@ -75,6 +99,10 @@ struct ReferenceConvFwd : public device::BaseOperator ...@@ -75,6 +99,10 @@ struct ReferenceConvFwd : public device::BaseOperator
const Tensor<WeiDataType>& weight_; const Tensor<WeiDataType>& weight_;
Tensor<OutDataType>& output_; Tensor<OutDataType>& output_;
const std::array<Tensor<InDataType>, NumAElementwiseTensor>& elementwise_a_tensors_;
const std::array<Tensor<WeiDataType>, NumBElementwiseTensor>& elementwise_b_tensors_;
const std::array<Tensor<OutDataType>, NumDElementwiseTensor>& elementwise_d_tensors_;
std::vector<index_t> conv_strides_; std::vector<index_t> conv_strides_;
std::vector<index_t> conv_dilations_; std::vector<index_t> conv_dilations_;
std::vector<index_t> in_left_pads_; std::vector<index_t> in_left_pads_;
...@@ -114,25 +142,43 @@ struct ReferenceConvFwd : public device::BaseOperator ...@@ -114,25 +142,43 @@ struct ReferenceConvFwd : public device::BaseOperator
if(wi >= 0 && if(wi >= 0 &&
ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[3]) ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[3])
{ {
float v_in; InDataType v_in;
float v_wei; WeiDataType v_wei;
arg.in_element_op_( ExecuteElementwiseOp(arg.in_element_op_,
v_in, ck::type_convert<float>(arg.input_(g, n, c, wi))); arg.elementwise_a_tensors_,
Number<NumAElementwiseTensor>{},
arg.wei_element_op_( v_in,
v_wei, ck::type_convert<float>(arg.weight_(g, k, c, x))); arg.input_(g, n, c, wi),
g,
v_acc += v_in * v_wei; n,
c,
wi);
ExecuteElementwiseOp(arg.wei_element_op_,
arg.elementwise_b_tensors_,
Number<NumBElementwiseTensor>{},
v_wei,
arg.weight_(g, k, c, x),
g,
k,
c,
x);
v_acc +=
ck::type_convert<float>(v_in) * ck::type_convert<float>(v_wei);
} }
} }
} }
OutDataType v_acc_converted = ck::type_convert<OutDataType>(v_acc);
float v_out; OutDataType& v_out = arg.output_(g, n, k, wo);
ExecuteElementwiseOp(arg.out_element_op_,
arg.out_element_op_(v_out, v_acc); arg.elementwise_d_tensors_,
Number<NumDElementwiseTensor>{},
arg.output_(g, n, k, wo) = ck::type_convert<OutDataType>(v_out); v_out,
v_acc_converted,
g,
n,
k,
wo);
}; };
make_ParallelTensorFunctor(func, make_ParallelTensorFunctor(func,
...@@ -169,26 +215,47 @@ struct ReferenceConvFwd : public device::BaseOperator ...@@ -169,26 +215,47 @@ struct ReferenceConvFwd : public device::BaseOperator
wi >= 0 && wi >= 0 &&
ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[4]) ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[4])
{ {
float v_in; InDataType v_in;
float v_wei; WeiDataType v_wei;
arg.in_element_op_( ExecuteElementwiseOp(arg.in_element_op_,
v_in, ck::type_convert<float>(arg.input_(g, n, c, hi, wi))); arg.elementwise_a_tensors_,
Number<NumAElementwiseTensor>{},
arg.wei_element_op_( v_in,
v_wei, ck::type_convert<float>(arg.weight_(g, k, c, y, x))); arg.input_(g, n, c, hi, wi),
g,
v_acc += v_in * v_wei; n,
c,
hi,
wi);
ExecuteElementwiseOp(arg.wei_element_op_,
arg.elementwise_b_tensors_,
Number<NumBElementwiseTensor>{},
v_wei,
arg.weight_(g, k, c, y, x),
g,
k,
c,
y,
x);
v_acc += ck::type_convert<float>(v_in) *
ck::type_convert<float>(v_wei);
} }
} }
} }
} }
OutDataType v_acc_converted = ck::type_convert<OutDataType>(v_acc);
float v_out; OutDataType& v_out = arg.output_(g, n, k, ho, wo);
ExecuteElementwiseOp(arg.out_element_op_,
arg.out_element_op_(v_out, v_acc); arg.elementwise_d_tensors_,
Number<NumDElementwiseTensor>{},
arg.output_(g, n, k, ho, wo) = ck::type_convert<OutDataType>(v_out); v_out,
v_acc_converted,
g,
n,
k,
ho,
wo);
}; };
make_ParallelTensorFunctor(func, make_ParallelTensorFunctor(func,
...@@ -235,29 +302,51 @@ struct ReferenceConvFwd : public device::BaseOperator ...@@ -235,29 +302,51 @@ struct ReferenceConvFwd : public device::BaseOperator
ck::type_convert<std::size_t>(wi) < ck::type_convert<std::size_t>(wi) <
arg.input_.GetLengths()[5]) arg.input_.GetLengths()[5])
{ {
float v_in; InDataType v_in;
float v_wei; WeiDataType v_wei;
arg.in_element_op_(v_in, ExecuteElementwiseOp(arg.in_element_op_,
ck::type_convert<float>( arg.elementwise_a_tensors_,
arg.input_(g, n, c, di, hi, wi))); Number<NumAElementwiseTensor>{},
v_in,
arg.wei_element_op_( arg.input_(g, n, c, di, hi, wi),
v_wei, g,
ck::type_convert<float>(arg.weight_(g, k, c, z, y, x))); n,
c,
v_acc += v_in * v_wei; di,
hi,
wi);
ExecuteElementwiseOp(arg.wei_element_op_,
arg.elementwise_b_tensors_,
Number<NumBElementwiseTensor>{},
v_wei,
arg.weight_(g, k, c, z, y, x),
g,
k,
c,
z,
y,
x);
v_acc += ck::type_convert<float>(v_in) *
ck::type_convert<float>(v_wei);
} }
} }
} }
} }
} }
OutDataType v_acc_converted = ck::type_convert<OutDataType>(v_acc);
float v_out; OutDataType& v_out = arg.output_(g, n, k, d_o, ho, wo);
ExecuteElementwiseOp(arg.out_element_op_,
arg.out_element_op_(v_out, v_acc); arg.elementwise_d_tensors_,
Number<NumDElementwiseTensor>{},
arg.output_(g, n, k, d_o, ho, wo) = ck::type_convert<OutDataType>(v_out); v_out,
v_acc_converted,
g,
n,
k,
d_o,
ho,
wo);
}; };
make_ParallelTensorFunctor(func, make_ParallelTensorFunctor(func,
...@@ -280,6 +369,36 @@ struct ReferenceConvFwd : public device::BaseOperator ...@@ -280,6 +369,36 @@ struct ReferenceConvFwd : public device::BaseOperator
} }
}; };
template <typename... Args,
typename ElementwiseOp,
typename ElementwiseTensor,
typename NumTensor,
typename T>
static void ExecuteElementwiseOp(ElementwiseOp& elementwise_op,
ElementwiseTensor& elementwise_tensors,
NumTensor,
T& y,
const T& x,
Args... dims)
{
if constexpr(NumTensor::value == 0)
{
elementwise_op(y, x);
}
else if constexpr(NumTensor::value == 1)
{
elementwise_op(y, x, elementwise_tensors[0](dims...));
}
else if constexpr(NumTensor::value == 2)
{
elementwise_op(y, x, elementwise_tensors[0](dims...), elementwise_tensors[1](dims...));
}
else
{
throw std::runtime_error("ElementOp not supported in reference.");
}
}
static constexpr bool IsValidCompilationParameter() static constexpr bool IsValidCompilationParameter()
{ {
// TODO: properly implement this check // TODO: properly implement this check
...@@ -291,16 +410,20 @@ struct ReferenceConvFwd : public device::BaseOperator ...@@ -291,16 +410,20 @@ struct ReferenceConvFwd : public device::BaseOperator
return NDimSpatial >= 1 && NDimSpatial <= 3; return NDimSpatial >= 1 && NDimSpatial <= 3;
} }
static auto MakeArgument(const Tensor<InDataType>& input, static auto MakeArgument(
const Tensor<WeiDataType>& weight, const Tensor<InDataType>& input,
Tensor<OutDataType>& output, const Tensor<WeiDataType>& weight,
std::vector<ck::index_t> conv_filter_strides, Tensor<OutDataType>& output,
std::vector<ck::index_t> conv_filter_dilations, std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> input_left_pads, std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_right_pads, std::vector<ck::index_t> input_left_pads,
InElementwiseOperation in_element_op, std::vector<ck::index_t> input_right_pads,
WeiElementwiseOperation wei_element_op, InElementwiseOperation in_element_op,
OutElementwiseOperation out_element_op) WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op,
const std::array<Tensor<InDataType>, NumAElementwiseTensor>& elementwise_a_tensors = {},
const std::array<Tensor<WeiDataType>, NumBElementwiseTensor>& elementwise_b_tensors = {},
const std::array<Tensor<OutDataType>, NumDElementwiseTensor>& elementwise_d_tensors = {})
{ {
return Argument{input, return Argument{input,
weight, weight,
...@@ -311,7 +434,10 @@ struct ReferenceConvFwd : public device::BaseOperator ...@@ -311,7 +434,10 @@ struct ReferenceConvFwd : public device::BaseOperator
input_right_pads, input_right_pads,
in_element_op, in_element_op,
wei_element_op, wei_element_op,
out_element_op}; out_element_op,
elementwise_a_tensors,
elementwise_b_tensors,
elementwise_d_tensors};
} }
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
......
...@@ -20,7 +20,9 @@ template <typename ADataType, ...@@ -20,7 +20,9 @@ template <typename ADataType,
typename AccDataType, typename AccDataType,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation> typename CElementwiseOperation,
typename ComputeTypeA = ADataType,
typename ComputeTypeB = ComputeTypeA>
struct ReferenceGemm : public device::BaseOperator struct ReferenceGemm : public device::BaseOperator
{ {
// Argument // Argument
...@@ -64,8 +66,8 @@ struct ReferenceGemm : public device::BaseOperator ...@@ -64,8 +66,8 @@ struct ReferenceGemm : public device::BaseOperator
for(int k = 0; k < K; ++k) for(int k = 0; k < K; ++k)
{ {
ADataType v_a; ComputeTypeA v_a;
BDataType v_b; ComputeTypeB v_b;
// use PassThrough instead of ConvertBF16RTN for reference calculation // use PassThrough instead of ConvertBF16RTN for reference calculation
if constexpr(is_same_v<AElementwiseOperation, if constexpr(is_same_v<AElementwiseOperation,
...@@ -92,11 +94,11 @@ struct ReferenceGemm : public device::BaseOperator ...@@ -92,11 +94,11 @@ struct ReferenceGemm : public device::BaseOperator
ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b); ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
} }
AccDataType v_c; CDataType v_c;
arg.c_element_op_(v_c, v_acc); arg.c_element_op_(v_c, v_acc);
arg.c_m_n_(m, n) = ck::type_convert<CDataType>(v_c); arg.c_m_n_(m, n) = v_c;
}; };
make_ParallelTensorFunctor( make_ParallelTensorFunctor(
......
...@@ -20,8 +20,9 @@ template <typename XDataType, ...@@ -20,8 +20,9 @@ template <typename XDataType,
typename GammaDataType, typename GammaDataType,
typename BetaDataType, typename BetaDataType,
typename YDataType, typename YDataType,
typename AccDataType, typename SaveMeanInvStdDataType,
typename AccElementwiseOperation> typename ComputeDataType,
typename YElementwiseOperation>
struct ReferenceGroupnorm : public device::BaseOperator struct ReferenceGroupnorm : public device::BaseOperator
{ {
// x = [N, H, W, G, C] // x = [N, H, W, G, C]
...@@ -35,14 +36,18 @@ struct ReferenceGroupnorm : public device::BaseOperator ...@@ -35,14 +36,18 @@ struct ReferenceGroupnorm : public device::BaseOperator
const Tensor<GammaDataType>& gamma, const Tensor<GammaDataType>& gamma,
const Tensor<BetaDataType>& beta, const Tensor<BetaDataType>& beta,
Tensor<YDataType>& y, Tensor<YDataType>& y,
AccElementwiseOperation acc_elementwise_op, Tensor<SaveMeanInvStdDataType>& save_mean,
Tensor<SaveMeanInvStdDataType>& save_inv_std,
YElementwiseOperation y_elementwise_op,
const std::vector<index_t> lengths, const std::vector<index_t> lengths,
AccDataType epsilon) ComputeDataType epsilon)
: x_(x), : x_(x),
gamma_(gamma), gamma_(gamma),
beta_(beta), beta_(beta),
y_(y), y_(y),
acc_elementwise_op_(acc_elementwise_op), save_mean_(save_mean),
save_inv_std_(save_inv_std),
y_elementwise_op_(y_elementwise_op),
lengths_(lengths), lengths_(lengths),
epsilon_(epsilon) epsilon_(epsilon)
{ {
...@@ -52,9 +57,11 @@ struct ReferenceGroupnorm : public device::BaseOperator ...@@ -52,9 +57,11 @@ struct ReferenceGroupnorm : public device::BaseOperator
const Tensor<XDataType> gamma_; const Tensor<XDataType> gamma_;
const Tensor<XDataType> beta_; const Tensor<XDataType> beta_;
Tensor<YDataType>& y_; Tensor<YDataType>& y_;
AccElementwiseOperation acc_elementwise_op_; Tensor<SaveMeanInvStdDataType>& save_mean_;
Tensor<SaveMeanInvStdDataType>& save_inv_std_;
YElementwiseOperation y_elementwise_op_;
std::vector<index_t> lengths_; std::vector<index_t> lengths_;
AccDataType epsilon_; ComputeDataType epsilon_;
}; };
// Invoker // Invoker
...@@ -68,8 +75,8 @@ struct ReferenceGroupnorm : public device::BaseOperator ...@@ -68,8 +75,8 @@ struct ReferenceGroupnorm : public device::BaseOperator
int G = arg.lengths_[3]; int G = arg.lengths_[3];
int C = arg.lengths_[4]; int C = arg.lengths_[4];
Tensor<AccDataType> mean({N, G}); Tensor<ComputeDataType> mean({N, G});
Tensor<AccDataType> var({N, G}); Tensor<ComputeDataType> var({N, G});
// Compute mean & var in [H, W, C] by Welford Algorithm // Compute mean & var in [H, W, C] by Welford Algorithm
// TODO - parallel for each HWC // TODO - parallel for each HWC
...@@ -78,9 +85,9 @@ struct ReferenceGroupnorm : public device::BaseOperator ...@@ -78,9 +85,9 @@ struct ReferenceGroupnorm : public device::BaseOperator
{ {
for(int g = 0; g < G; ++g) for(int g = 0; g < G; ++g)
{ {
AccDataType mean_val = type_convert<AccDataType>(0.0f); ComputeDataType mean_val = type_convert<ComputeDataType>(0.0f);
AccDataType var_val = type_convert<AccDataType>(0.0f); ComputeDataType var_val = type_convert<ComputeDataType>(0.0f);
int32_t curr_count = 0; int32_t curr_count = 0;
for(int h = 0; h < H; ++h) for(int h = 0; h < H; ++h)
{ {
...@@ -89,10 +96,11 @@ struct ReferenceGroupnorm : public device::BaseOperator ...@@ -89,10 +96,11 @@ struct ReferenceGroupnorm : public device::BaseOperator
for(int c = 0; c < C; ++c) for(int c = 0; c < C; ++c)
{ {
curr_count++; curr_count++;
AccDataType x = type_convert<AccDataType>(arg.x_(n, h, w, g, c)); ComputeDataType x =
AccDataType delta = x - mean_val; type_convert<ComputeDataType>(arg.x_(n, h, w, g, c));
ComputeDataType delta = x - mean_val;
mean_val += delta / curr_count; mean_val += delta / curr_count;
AccDataType delta2 = x - mean_val; ComputeDataType delta2 = x - mean_val;
var_val += delta * delta2; var_val += delta * delta2;
} }
} }
...@@ -100,6 +108,12 @@ struct ReferenceGroupnorm : public device::BaseOperator ...@@ -100,6 +108,12 @@ struct ReferenceGroupnorm : public device::BaseOperator
mean(n, g) = mean_val; mean(n, g) = mean_val;
var(n, g) = var_val / curr_count; var(n, g) = var_val / curr_count;
arg.save_mean_(n, g) = ck::type_convert<SaveMeanInvStdDataType>(mean(n, g));
ComputeDataType divisor =
static_cast<ComputeDataType>(1) / ck::math::sqrt(var(n, g) + arg.epsilon_);
arg.save_inv_std_(n, g) = ck::type_convert<SaveMeanInvStdDataType>(divisor);
} }
} }
...@@ -114,15 +128,19 @@ struct ReferenceGroupnorm : public device::BaseOperator ...@@ -114,15 +128,19 @@ struct ReferenceGroupnorm : public device::BaseOperator
{ {
for(int c = 0; c < C; ++c) for(int c = 0; c < C; ++c)
{ {
AccDataType x = type_convert<AccDataType>(arg.x_(n, h, w, g, c)); ComputeDataType x =
AccDataType gamma = type_convert<AccDataType>(arg.gamma_(g, c)); type_convert<ComputeDataType>(arg.x_(n, h, w, g, c));
AccDataType beta = type_convert<AccDataType>(arg.beta_(g, c)); ComputeDataType gamma =
AccDataType mean_val = type_convert<AccDataType>(mean(n, g)); type_convert<ComputeDataType>(arg.gamma_(g, c));
AccDataType var_val = type_convert<AccDataType>(var(n, g)); ComputeDataType beta =
AccDataType y = gamma * (x - mean_val) / type_convert<ComputeDataType>(arg.beta_(g, c));
ck::math::sqrt(arg.epsilon_ + var_val) + ComputeDataType mean_val =
beta; type_convert<ComputeDataType>(mean(n, g));
arg.acc_elementwise_op_(y, y); ComputeDataType var_val = type_convert<ComputeDataType>(var(n, g));
ComputeDataType y = gamma * (x - mean_val) /
ck::math::sqrt(arg.epsilon_ + var_val) +
beta;
arg.y_elementwise_op_(y, y);
arg.y_(n, h, w, g, c) = type_convert<YDataType>(y); arg.y_(n, h, w, g, c) = type_convert<YDataType>(y);
} }
} }
...@@ -159,11 +177,14 @@ struct ReferenceGroupnorm : public device::BaseOperator ...@@ -159,11 +177,14 @@ struct ReferenceGroupnorm : public device::BaseOperator
const Tensor<GammaDataType>& gamma, const Tensor<GammaDataType>& gamma,
const Tensor<BetaDataType>& beta, const Tensor<BetaDataType>& beta,
Tensor<YDataType>& y, Tensor<YDataType>& y,
AccElementwiseOperation acc_elementwise_op, Tensor<SaveMeanInvStdDataType>& save_mean,
Tensor<SaveMeanInvStdDataType>& save_inv_std,
YElementwiseOperation y_elementwise_op,
const std::vector<index_t> lengths, const std::vector<index_t> lengths,
AccDataType epsilon) ComputeDataType epsilon)
{ {
return Argument{x, gamma, beta, y, acc_elementwise_op, lengths, epsilon}; return Argument{
x, gamma, beta, y, save_mean, save_inv_std, y_elementwise_op, lengths, epsilon};
} }
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include <vector>
#include <algorithm>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
template <typename DYDataType,
typename XDataType,
typename GammaDataType,
typename MeanInvStdDataType,
typename DGammaDataType,
typename DBetaDataType,
typename DXDataType,
typename ComputeDataType>
struct ReferenceGroupnormBwd : public device::BaseOperator
{
// Argument
struct Argument : public device::BaseArgument
{
Argument(const Tensor<DYDataType>& dy_nhwgc,
const Tensor<XDataType>& x_nhwgc,
const Tensor<GammaDataType>& gamma_gc,
const Tensor<MeanInvStdDataType>& mean_ng,
const Tensor<MeanInvStdDataType>& inv_std_ng,
Tensor<DGammaDataType>& dgamma_gc,
Tensor<DBetaDataType>& dbeta_gc,
Tensor<DXDataType>& dx_nhwgc,
const std::vector<index_t> lengths)
: dy_nhwgc_(dy_nhwgc),
x_nhwgc_(x_nhwgc),
gamma_gc_(gamma_gc),
mean_ng_(mean_ng),
inv_std_ng_(inv_std_ng),
dgamma_gc_(dgamma_gc),
dbeta_gc_(dbeta_gc),
dx_nhwgc_(dx_nhwgc),
lengths_(lengths)
{
}
const Tensor<DYDataType>& dy_nhwgc_;
const Tensor<XDataType>& x_nhwgc_;
const Tensor<GammaDataType>& gamma_gc_;
const Tensor<MeanInvStdDataType>& mean_ng_;
const Tensor<MeanInvStdDataType>& inv_std_ng_;
Tensor<DGammaDataType>& dgamma_gc_;
Tensor<DBetaDataType>& dbeta_gc_;
Tensor<DXDataType>& dx_nhwgc_;
std::vector<index_t> lengths_;
};
// Invoker
struct Invoker : public device::BaseInvoker
{
float Run(const Argument& arg)
{
int N = arg.lengths_[0];
int H = arg.lengths_[1];
int W = arg.lengths_[2];
int G = arg.lengths_[3];
int C = arg.lengths_[4];
// Calculate dgamma and dbeta
for(int g = 0; g < G; ++g)
for(int c = 0; c < C; ++c)
{
ComputeDataType dgamma = 0;
ComputeDataType dbeta = 0;
for(int n = 0; n < N; ++n)
for(int h = 0; h < H; ++h)
for(int w = 0; w < W; ++w)
{
ComputeDataType dy =
ck::type_convert<ComputeDataType>(arg.dy_nhwgc_(n, h, w, g, c));
ComputeDataType x =
ck::type_convert<ComputeDataType>(arg.x_nhwgc_(n, h, w, g, c));
ComputeDataType mean =
ck::type_convert<ComputeDataType>(arg.mean_ng_(n, g));
ComputeDataType rstd =
ck::type_convert<ComputeDataType>(arg.inv_std_ng_(n, g));
dgamma += dy * rstd * (x - mean);
dbeta += dy;
}
arg.dgamma_gc_(g, c) = ck::type_convert<DGammaDataType>(dgamma);
arg.dbeta_gc_(g, c) = ck::type_convert<DBetaDataType>(dbeta);
}
// Calculate dx
int reduce_size = H * W * C;
for(int n = 0; n < N; ++n)
for(int g = 0; g < G; ++g)
{
ComputeDataType ds = 0;
ComputeDataType db = 0;
ComputeDataType mean = ck::type_convert<ComputeDataType>(arg.mean_ng_(n, g));
ComputeDataType rstd = ck::type_convert<ComputeDataType>(arg.inv_std_ng_(n, g));
for(int h = 0; h < H; ++h)
for(int w = 0; w < W; ++w)
for(int c = 0; c < C; ++c)
{
ComputeDataType dy =
ck::type_convert<ComputeDataType>(arg.dy_nhwgc_(n, h, w, g, c));
ComputeDataType x =
ck::type_convert<ComputeDataType>(arg.x_nhwgc_(n, h, w, g, c));
ComputeDataType gamma =
ck::type_convert<ComputeDataType>(arg.gamma_gc_(g, c));
ds += dy * gamma * x;
db += dy * gamma;
}
for(int h = 0; h < H; ++h)
for(int w = 0; w < W; ++w)
for(int c = 0; c < C; ++c)
{
ComputeDataType dy =
ck::type_convert<ComputeDataType>(arg.dy_nhwgc_(n, h, w, g, c));
ComputeDataType x =
ck::type_convert<ComputeDataType>(arg.x_nhwgc_(n, h, w, g, c));
ComputeDataType gamma =
ck::type_convert<ComputeDataType>(arg.gamma_gc_(g, c));
ComputeDataType b =
(db * mean - ds) * rstd * rstd * rstd / reduce_size;
ComputeDataType c1 = -b * mean - db * rstd / reduce_size;
arg.dx_nhwgc_(n, h, w, g, c) =
ck::type_convert<DXDataType>(dy * gamma * rstd + b * x + c1);
}
}
return 0;
}
float Run(const device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
static auto MakeArgument(const Tensor<DYDataType>& dy_nhwgc,
const Tensor<XDataType>& x_nhwgc,
const Tensor<GammaDataType>& gamma_gc,
const Tensor<MeanInvStdDataType>& mean_ng,
const Tensor<MeanInvStdDataType>& inv_std_ng,
Tensor<DGammaDataType>& dgamma_gc,
Tensor<DBetaDataType>& dbeta_gc,
Tensor<DXDataType>& dx_nhwgc,
const std::vector<index_t> lengths)
{
return Argument{dy_nhwgc,
x_nhwgc,
gamma_gc,
mean_ng,
inv_std_ng,
dgamma_gc,
dbeta_gc,
dx_nhwgc,
lengths};
}
static auto MakeInvoker() { return Invoker{}; }
virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "ReferenceGroupnormBwd"
<< std::endl;
// clang-format on
return str.str();
}
};
} // namespace host
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <type_traits>
#include <sstream>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/library/utility/host_tensor.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
/**
* \brief Reference implementation for image to column.
*
* Input tensor descriptor has [G, N, C, Di, Hi, Wi] data layout.
* Output tensor descriptor has [G * N * Do * Ho * Wo, Z * Y * X * C] data layout.
*
* \tparam NDimSpatial Number of spatial dimensions.
* \tparam ImageLayout Image Layout.
* \tparam InDataType Input Data Type.
* \tparam OutDataType Output Data Type.
*/
template <ck::index_t NDimSpatial,
typename ImageLayout,
typename InDataType,
typename OutDataType,
typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false>
struct ReferenceImageToColumn : public device::BaseOperator
{
// Argument
struct Argument : public device::BaseArgument
{
public:
Argument(const Tensor<InDataType>& input,
Tensor<OutDataType>& output,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads)
: input_{input},
output_{output},
conv_strides_{conv_filter_strides},
conv_dilations_{conv_filter_dilations},
in_left_pads_{input_left_pads},
in_right_pads_{input_right_pads},
filter_spatial_lengths_{filter_spatial_lengths}
{
initOutputSpatialLengths();
}
const Tensor<InDataType>& input_;
Tensor<OutDataType>& output_;
std::vector<index_t> conv_strides_;
std::vector<index_t> conv_dilations_;
std::vector<index_t> in_left_pads_;
std::vector<index_t> in_right_pads_;
std::vector<index_t> filter_spatial_lengths_;
std::vector<index_t> output_spatial_lengths_;
private:
void initOutputSpatialLengths()
{
constexpr auto input_offset_to_spatial = 3;
for(ck::index_t i = 0; i < NDimSpatial; ++i)
{
// XEff = (X - 1) * conv_dilation_w + 1;
// Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
const ck::index_t x_eff = (filter_spatial_lengths_[i] - 1) * conv_dilations_[i] + 1;
output_spatial_lengths_.push_back(
(input_.GetLengths()[i + input_offset_to_spatial] + in_left_pads_[i] +
in_right_pads_[i] - x_eff) /
conv_strides_[i] +
1);
}
}
};
struct Invoker : public device::BaseInvoker
{
using Argument = ReferenceImageToColumn::Argument;
float Run(const Argument& arg)
{
if(!(arg.input_.GetNumOfDimension() == NDimSpatial + 3 &&
arg.output_.GetNumOfDimension() == 3))
{
throw std::runtime_error("wrong! inconsistent dimension");
}
const index_t G = arg.input_.GetLengths()[0];
const index_t N = arg.input_.GetLengths()[1];
const index_t C = arg.input_.GetLengths()[2];
if constexpr(NDimSpatial == 1)
{
const index_t Wo = arg.output_spatial_lengths_[0];
auto func = [&](auto g, auto n, auto wo) {
index_t row = n * Wo + wo;
index_t column = 0;
for(index_t x = 0; x < arg.filter_spatial_lengths_[0]; ++x)
{
auto wi = static_cast<ck::long_index_t>(wo * arg.conv_strides_[0]) +
static_cast<ck::long_index_t>(x * arg.conv_dilations_[0]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
for(index_t c = 0; c < C; ++c)
{
if(wi >= 0 &&
ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[3])
{
InDataType v_in = arg.input_(g, n, c, wi);
arg.output_(g, row, column) = ck::type_convert<OutDataType>(v_in);
}
column++;
}
}
};
make_ParallelTensorFunctor(func, G, N, Wo)(std::thread::hardware_concurrency());
return 0;
}
else if constexpr(NDimSpatial == 2)
{
const index_t Ho = arg.output_spatial_lengths_[0];
const index_t Wo = arg.output_spatial_lengths_[1];
auto func = [&](auto g, auto n, auto ho, auto wo) {
index_t row = n * Ho * Wo + ho * Wo + wo;
index_t column = 0;
for(index_t y = 0; y < arg.filter_spatial_lengths_[0]; ++y)
{
auto hi = static_cast<ck::long_index_t>(ho * arg.conv_strides_[0]) +
static_cast<ck::long_index_t>(y * arg.conv_dilations_[0]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
for(index_t x = 0; x < arg.filter_spatial_lengths_[1]; ++x)
{
auto wi = static_cast<ck::long_index_t>(wo * arg.conv_strides_[1]) +
static_cast<ck::long_index_t>(x * arg.conv_dilations_[1]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[1]);
for(index_t c = 0; c < C; ++c)
{
if(hi >= 0 &&
ck::type_convert<std::size_t>(hi) < arg.input_.GetLengths()[3] &&
wi >= 0 &&
ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[4])
{
InDataType v_in = arg.input_(g, n, c, hi, wi);
arg.output_(g, row, column) =
ck::type_convert<OutDataType>(v_in);
}
column++;
}
}
}
};
make_ParallelTensorFunctor(func, G, N, Ho, Wo)(std::thread::hardware_concurrency());
return 0;
}
else if constexpr(NDimSpatial == 3)
{
const index_t Do = arg.output_spatial_lengths_[0];
const index_t Ho = arg.output_spatial_lengths_[1];
const index_t Wo = arg.output_spatial_lengths_[2];
auto func = [&](auto g, auto n, auto d_o, auto ho, auto wo) {
index_t row = n * Do * Ho * Wo + d_o * Ho * Wo + ho * Wo + wo;
index_t column = 0;
for(index_t z = 0; z < arg.filter_spatial_lengths_[0]; ++z)
{
auto di = static_cast<ck::long_index_t>(d_o * arg.conv_strides_[0]) +
static_cast<ck::long_index_t>(z * arg.conv_dilations_[0]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
for(index_t y = 0; y < arg.filter_spatial_lengths_[1]; ++y)
{
auto hi = static_cast<ck::long_index_t>(ho * arg.conv_strides_[1]) +
static_cast<ck::long_index_t>(y * arg.conv_dilations_[1]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[1]);
for(index_t x = 0; x < arg.filter_spatial_lengths_[2]; ++x)
{
auto wi =
static_cast<ck::long_index_t>(wo * arg.conv_strides_[2]) +
static_cast<ck::long_index_t>(x * arg.conv_dilations_[2]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[2]);
for(index_t c = 0; c < C; ++c)
{
if(di >= 0 &&
ck::type_convert<std::size_t>(di) <
arg.input_.GetLengths()[3] &&
hi >= 0 &&
ck::type_convert<std::size_t>(hi) <
arg.input_.GetLengths()[4] &&
wi >= 0 &&
ck::type_convert<std::size_t>(wi) <
arg.input_.GetLengths()[5])
{
InDataType v_in = arg.input_(g, n, c, di, hi, wi);
arg.output_(g, row, column) =
ck::type_convert<OutDataType>(v_in);
}
column++;
}
}
}
}
};
make_ParallelTensorFunctor(func, G, N, Do, Ho, Wo)(
std::thread::hardware_concurrency());
return 0;
}
}
float Run(const device::BaseArgument* p_arg,
const StreamConfig& /*stream_config*/ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
static constexpr bool IsValidCompilationParameter()
{
using namespace tensor_layout::convolution;
if constexpr(!(std::is_same_v<ImageLayout, GNWC> || std::is_same_v<ImageLayout, GNHWC> ||
std::is_same_v<ImageLayout, GNDHWC>))
{
return false;
}
if constexpr(!(NDimSpatial >= 1 && NDimSpatial <= 3))
{
return false;
}
return true;
}
bool IsSupportedArgument(const Argument& arg)
{
const ck::index_t G = arg.input_.GetLengths()[0];
const ck::index_t N = arg.input_.GetLengths()[1];
const ck::index_t C = arg.input_.GetLengths()[2];
const index_t NDoHoWo =
N * ck::accumulate_n<index_t>(
arg.output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
const index_t CZYX =
C * ck::accumulate_n<index_t>(
arg.filter_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
if(!(arg.output_.GetLengths()[0] == static_cast<std::size_t>(G) &&
arg.output_.GetLengths()[1] == static_cast<std::size_t>(NDoHoWo) &&
arg.output_.GetLengths()[2] == static_cast<std::size_t>(CZYX)))
{
return false;
}
if(G != 1)
{
return false;
}
return true;
}
bool IsSupportedArgument(const device::BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const Tensor<InDataType>& input,
Tensor<OutDataType>& output,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads)
{
return Argument{input,
output,
filter_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads};
}
static auto MakeInvoker() { return Invoker{}; }
virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "ReferenceImageToColumn"
<< std::endl;
// clang-format on
return str.str();
}
};
} // namespace host
} // namespace tensor_operation
} // namespace ck
...@@ -20,14 +20,16 @@ template <typename XDataType, ...@@ -20,14 +20,16 @@ template <typename XDataType,
typename GammaDataType, typename GammaDataType,
typename BetaDataType, typename BetaDataType,
typename YDataType, typename YDataType,
typename AccDataType, typename SaveMeanInvStdDataType,
typename AccElementwiseOperation, typename ComputeDataType,
typename YElementwiseOperation,
index_t Rank, index_t Rank,
index_t NumReduceDim> index_t NumReduceDim>
struct ReferenceLayernorm : public device::BaseOperator struct ReferenceLayernorm : public device::BaseOperator
{ {
// TODO - support generic layernorm // TODO - support generic layernorm
static_assert((Rank == 2 && NumReduceDim == 1), "Only support 2D version so far"); static_assert((Rank == 2 && NumReduceDim == 1) || (Rank == 4 && NumReduceDim == 3),
"Only support 2D & 4D version so far");
// Argument // Argument
struct Argument : public device::BaseArgument struct Argument : public device::BaseArgument
...@@ -36,15 +38,19 @@ struct ReferenceLayernorm : public device::BaseOperator ...@@ -36,15 +38,19 @@ struct ReferenceLayernorm : public device::BaseOperator
const Tensor<GammaDataType>& gamma_n, const Tensor<GammaDataType>& gamma_n,
const Tensor<BetaDataType>& beta_n, const Tensor<BetaDataType>& beta_n,
Tensor<YDataType>& y_m_n, Tensor<YDataType>& y_m_n,
AccElementwiseOperation acc_elementwise_op, Tensor<SaveMeanInvStdDataType>& save_mean_m,
Tensor<SaveMeanInvStdDataType>& save_inv_std_m,
YElementwiseOperation y_elementwise_op,
const std::vector<index_t> lengths, const std::vector<index_t> lengths,
const std::vector<index_t> reduceDims, const std::vector<index_t> reduceDims,
AccDataType epsilon) ComputeDataType epsilon)
: x_m_n_(x_m_n), : x_m_n_(x_m_n),
gamma_n_(gamma_n), gamma_n_(gamma_n),
beta_n_(beta_n), beta_n_(beta_n),
y_m_n_(y_m_n), y_m_n_(y_m_n),
acc_elementwise_op_(acc_elementwise_op), save_mean_m_(save_mean_m),
save_inv_std_m_(save_inv_std_m),
y_elementwise_op_(y_elementwise_op),
lengths_(lengths), lengths_(lengths),
reduceDims_(reduceDims), reduceDims_(reduceDims),
epsilon_(epsilon) epsilon_(epsilon)
...@@ -55,22 +61,24 @@ struct ReferenceLayernorm : public device::BaseOperator ...@@ -55,22 +61,24 @@ struct ReferenceLayernorm : public device::BaseOperator
const Tensor<XDataType> gamma_n_; const Tensor<XDataType> gamma_n_;
const Tensor<XDataType> beta_n_; const Tensor<XDataType> beta_n_;
Tensor<YDataType>& y_m_n_; Tensor<YDataType>& y_m_n_;
AccElementwiseOperation acc_elementwise_op_; Tensor<SaveMeanInvStdDataType>& save_mean_m_;
Tensor<SaveMeanInvStdDataType>& save_inv_std_m_;
YElementwiseOperation y_elementwise_op_;
std::vector<index_t> lengths_; std::vector<index_t> lengths_;
std::vector<index_t> reduceDims_; std::vector<index_t> reduceDims_;
AccDataType epsilon_; ComputeDataType epsilon_;
}; };
// Invoker // Invoker
struct Invoker : public device::BaseInvoker struct Invoker : public device::BaseInvoker
{ {
float Run(const Argument& arg) float Run2D(const Argument& arg)
{ {
int M = arg.lengths_[0]; int M = arg.lengths_[0];
int N = arg.lengths_[1]; int N = arg.lengths_[1];
Tensor<AccDataType> mean({M}); Tensor<ComputeDataType> mean({M});
Tensor<AccDataType> var({M}); Tensor<ComputeDataType> var({M});
for(int m = 0; m < M; ++m) for(int m = 0; m < M; ++m)
{ {
...@@ -79,7 +87,7 @@ struct ReferenceLayernorm : public device::BaseOperator ...@@ -79,7 +87,7 @@ struct ReferenceLayernorm : public device::BaseOperator
for(int n = 0; n < N; ++n) for(int n = 0; n < N; ++n)
{ {
auto x_val = ck::type_convert<AccDataType>(arg.x_m_n_(m, n)); auto x_val = ck::type_convert<ComputeDataType>(arg.x_m_n_(m, n));
mean(m) += x_val; mean(m) += x_val;
var(m) += x_val * x_val; var(m) += x_val * x_val;
} }
...@@ -90,22 +98,91 @@ struct ReferenceLayernorm : public device::BaseOperator ...@@ -90,22 +98,91 @@ struct ReferenceLayernorm : public device::BaseOperator
for(int m = 0; m < M; ++m) for(int m = 0; m < M; ++m)
{ {
AccDataType divisor = ComputeDataType divisor =
static_cast<AccDataType>(1) / ck::math::sqrt(var(m) + arg.epsilon_); static_cast<ComputeDataType>(1) / ck::math::sqrt(var(m) + arg.epsilon_);
for(int n = 0; n < N; ++n) for(int n = 0; n < N; ++n)
{ {
auto x_val = ck::type_convert<AccDataType>(arg.x_m_n_(m, n)); auto x_val = ck::type_convert<ComputeDataType>(arg.x_m_n_(m, n));
auto y_val = (x_val - mean(m)) * divisor; auto gamma_val = ck::type_convert<ComputeDataType>(arg.gamma_n_(n));
y_val = (y_val * arg.gamma_n_(n)) + arg.beta_n_(n); auto beta_val = ck::type_convert<ComputeDataType>(arg.beta_n_(n));
arg.acc_elementwise_op_(y_val, y_val); auto y_val = (x_val - mean(m)) * divisor;
y_val = (y_val * gamma_val) + beta_val;
arg.y_elementwise_op_(y_val, y_val);
arg.y_m_n_(m, n) = ck::type_convert<YDataType>(y_val); arg.y_m_n_(m, n) = ck::type_convert<YDataType>(y_val);
} }
arg.save_mean_m_(m) = ck::type_convert<SaveMeanInvStdDataType>(mean(m));
arg.save_inv_std_m_(m) = ck::type_convert<SaveMeanInvStdDataType>(divisor);
} }
return 0; return 0;
} }
float Run4D(const Argument& arg)
{
int N = arg.lengths_[0];
int H = arg.lengths_[1];
int W = arg.lengths_[2];
int C = arg.lengths_[3];
Tensor<ComputeDataType> mean({N});
Tensor<ComputeDataType> var({N});
int reduce_length = H * W * C;
for(int n = 0; n < N; ++n)
{
mean(n) = 0;
var(n) = 0;
for(int h = 0; h < H; ++h)
for(int w = 0; w < W; ++w)
for(int c = 0; c < C; ++c)
{
auto x_val = ck::type_convert<ComputeDataType>(arg.x_m_n_(n, h, w, c));
mean(n) += x_val;
var(n) += x_val * x_val;
}
mean(n) = mean(n) / reduce_length;
var(n) = (var(n) / reduce_length) - (mean(n) * mean(n));
}
for(int n = 0; n < N; ++n)
{
ComputeDataType divisor =
static_cast<ComputeDataType>(1) / ck::math::sqrt(var(n) + arg.epsilon_);
for(int h = 0; h < H; ++h)
for(int w = 0; w < W; ++w)
for(int c = 0; c < C; ++c)
{
auto x_val = ck::type_convert<ComputeDataType>(arg.x_m_n_(n, h, w, c));
auto gamma_val =
ck::type_convert<ComputeDataType>(arg.gamma_n_(h, w, c));
auto beta_val = ck::type_convert<ComputeDataType>(arg.beta_n_(h, w, c));
auto y_val = (x_val - mean(n)) * divisor;
y_val = (y_val * gamma_val) + beta_val;
arg.y_elementwise_op_(y_val, y_val);
arg.y_m_n_(n, h, w, c) = ck::type_convert<YDataType>(y_val);
}
arg.save_mean_m_(n) = ck::type_convert<SaveMeanInvStdDataType>(mean(n));
arg.save_inv_std_m_(n) = ck::type_convert<SaveMeanInvStdDataType>(divisor);
}
return 0;
}
float Run(const Argument& arg)
{
if(arg.lengths_.size() == 2)
return Run2D(arg);
else if(arg.lengths_.size() == 4)
return Run4D(arg);
return 0;
}
float Run(const device::BaseArgument* p_arg, float Run(const device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override const StreamConfig& /* stream_config */ = StreamConfig{}) override
{ {
...@@ -123,30 +200,39 @@ struct ReferenceLayernorm : public device::BaseOperator ...@@ -123,30 +200,39 @@ struct ReferenceLayernorm : public device::BaseOperator
{ {
const Argument* p_arg_ = dynamic_cast<const Argument*>(p_arg); const Argument* p_arg_ = dynamic_cast<const Argument*>(p_arg);
// TODO - support generic layernorm if(p_arg_->lengths_.size() == 2 && p_arg_->reduceDims_.size() == 1 &&
if(p_arg_->lengths_.size() != 2) p_arg_->reduceDims_[0] == 1)
return false; return true;
if(p_arg_->reduceDims_.size() != 1)
return false;
if(p_arg_->reduceDims_[0] != 1) else if(p_arg_->lengths_.size() == 4 && p_arg_->reduceDims_.size() == 3 &&
return false; p_arg_->reduceDims_[0] == 1 && p_arg_->reduceDims_[1] == 2 &&
p_arg_->reduceDims_[2] == 3)
return true;
return true; return false;
} }
static auto MakeArgument(const Tensor<XDataType>& x_m_n, static auto MakeArgument(const Tensor<XDataType>& x_m_n,
const Tensor<GammaDataType>& gamma_n, const Tensor<GammaDataType>& gamma_n,
const Tensor<BetaDataType>& beta_n, const Tensor<BetaDataType>& beta_n,
Tensor<YDataType>& y_m_n, Tensor<YDataType>& y_m_n,
AccElementwiseOperation acc_elementwise_op, Tensor<SaveMeanInvStdDataType>& save_mean_m,
Tensor<SaveMeanInvStdDataType>& save_inv_std_m,
YElementwiseOperation y_elementwise_op,
const std::vector<index_t> lengths, const std::vector<index_t> lengths,
const std::vector<index_t> reduceDims, const std::vector<index_t> reduceDims,
AccDataType epsilon) ComputeDataType epsilon)
{ {
return Argument{ return Argument{x_m_n,
x_m_n, gamma_n, beta_n, y_m_n, acc_elementwise_op, lengths, reduceDims, epsilon}; gamma_n,
beta_n,
y_m_n,
save_mean_m,
save_inv_std_m,
y_elementwise_op,
lengths,
reduceDims,
epsilon};
} }
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include <vector>
#include <algorithm>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
template <typename DYDataType,
typename XDataType,
typename GammaDataType,
typename MeanInvStdDataType,
typename DGammaDataType,
typename DBetaDataType,
typename DXDataType,
typename ComputeDataType>
struct ReferenceLayernormBwd : public device::BaseOperator
{
// Argument
struct Argument : public device::BaseArgument
{
Argument(const Tensor<DYDataType>& dy_m_n,
const Tensor<XDataType>& x_m_n,
const Tensor<GammaDataType>& gamma_n,
const Tensor<MeanInvStdDataType>& mean_m,
const Tensor<MeanInvStdDataType>& inv_std_m,
Tensor<DGammaDataType>& dgamma_n,
Tensor<DBetaDataType>& dbeta_n,
Tensor<DXDataType>& dx_m_n,
const std::vector<index_t> lengths)
: dy_m_n_(dy_m_n),
x_m_n_(x_m_n),
gamma_n_(gamma_n),
mean_m_(mean_m),
inv_std_m_(inv_std_m),
dgamma_n_(dgamma_n),
dbeta_n_(dbeta_n),
dx_m_n_(dx_m_n),
lengths_(lengths)
{
}
const Tensor<DYDataType>& dy_m_n_;
const Tensor<XDataType>& x_m_n_;
const Tensor<GammaDataType>& gamma_n_;
const Tensor<MeanInvStdDataType>& mean_m_;
const Tensor<MeanInvStdDataType>& inv_std_m_;
Tensor<DGammaDataType>& dgamma_n_;
Tensor<DBetaDataType>& dbeta_n_;
Tensor<DXDataType>& dx_m_n_;
std::vector<index_t> lengths_;
};
// Invoker
struct Invoker : public device::BaseInvoker
{
float Run(const Argument& arg)
{
int M = arg.lengths_[0];
int N = arg.lengths_[1];
// Calculate dgamma and dbeta
for(int n = 0; n < N; ++n)
{
ComputeDataType dgamma = 0;
ComputeDataType dbeta = 0;
for(int m = 0; m < M; ++m)
{
ComputeDataType dy = ck::type_convert<ComputeDataType>(arg.dy_m_n_(m, n));
ComputeDataType x = ck::type_convert<ComputeDataType>(arg.x_m_n_(m, n));
ComputeDataType mean = ck::type_convert<ComputeDataType>(arg.mean_m_(m));
ComputeDataType rstd = ck::type_convert<ComputeDataType>(arg.inv_std_m_(m));
dgamma += dy * rstd * (x - mean);
dbeta += dy;
}
arg.dgamma_n_(n) = ck::type_convert<DGammaDataType>(dgamma);
arg.dbeta_n_(n) = ck::type_convert<DBetaDataType>(dbeta);
}
// Calculate dx
for(int m = 0; m < M; ++m)
{
ComputeDataType ds = 0;
ComputeDataType db = 0;
ComputeDataType mean = ck::type_convert<ComputeDataType>(arg.mean_m_(m));
ComputeDataType rstd = ck::type_convert<ComputeDataType>(arg.inv_std_m_(m));
for(int n = 0; n < N; ++n)
{
ComputeDataType dy = ck::type_convert<ComputeDataType>(arg.dy_m_n_(m, n));
ComputeDataType x = ck::type_convert<ComputeDataType>(arg.x_m_n_(m, n));
ComputeDataType gamma = ck::type_convert<ComputeDataType>(arg.gamma_n_(n));
ds += dy * gamma * x;
db += dy * gamma;
}
for(int n = 0; n < N; ++n)
{
ComputeDataType dy = ck::type_convert<ComputeDataType>(arg.dy_m_n_(m, n));
ComputeDataType x = ck::type_convert<ComputeDataType>(arg.x_m_n_(m, n));
ComputeDataType gamma = ck::type_convert<ComputeDataType>(arg.gamma_n_(n));
ComputeDataType b = (db * mean - ds) * rstd * rstd * rstd / N;
ComputeDataType c = -b * mean - db * rstd / N;
arg.dx_m_n_(m, n) = ck::type_convert<DXDataType>(dy * gamma * rstd + b * x + c);
}
}
return 0;
}
float Run(const device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
static auto MakeArgument(const Tensor<DYDataType>& dy_m_n,
const Tensor<XDataType>& x_m_n,
const Tensor<GammaDataType>& gamma_n,
const Tensor<MeanInvStdDataType>& mean_m,
const Tensor<MeanInvStdDataType>& inv_std_m,
Tensor<DGammaDataType>& dgamma_n,
Tensor<DBetaDataType>& dbeta_n,
Tensor<DXDataType>& dx_m_n,
const std::vector<index_t> lengths)
{
return Argument{
dy_m_n, x_m_n, gamma_n, mean_m, inv_std_m, dgamma_n, dbeta_n, dx_m_n, lengths};
}
static auto MakeInvoker() { return Invoker{}; }
virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "ReferenceLayernormBwd"
<< std::endl;
// clang-format on
return str.str();
}
};
} // namespace host
} // namespace tensor_operation
} // namespace ck
...@@ -53,7 +53,16 @@ struct ReferenceMaxPoolBwd : public device::BaseOperator ...@@ -53,7 +53,16 @@ struct ReferenceMaxPoolBwd : public device::BaseOperator
{ {
int index = arg.indices_.mData[i]; int index = arg.indices_.mData[i];
if(index >= 0 && index < din_length) if(index >= 0 && index < din_length)
buf[index] += ck::type_convert<ConputeDataType>(arg.dout_.mData[i]); {
if constexpr(is_same_v<ConputeDataType, bhalf_t>)
{
float buf_val = ck::type_convert<float>(buf[index]);
buf_val += ck::type_convert<float>(arg.dout_.mData[i]);
buf[index] = ck::type_convert<ConputeDataType>(buf_val);
}
else
buf[index] += ck::type_convert<ConputeDataType>(arg.dout_.mData[i]);
}
} }
for(int i = 0; i < din_length; ++i) for(int i = 0; i < din_length; ++i)
......
...@@ -256,10 +256,12 @@ struct ReferencePoolingFwd : public device::BaseOperator ...@@ -256,10 +256,12 @@ struct ReferencePoolingFwd : public device::BaseOperator
for(ck::index_t y = 0; y < arg.window_spatial_lengths_[0]; ++y) for(ck::index_t y = 0; y < arg.window_spatial_lengths_[0]; ++y)
{ {
ck::index_t hi = ho * arg.window_strides_[0] + y - arg.in_left_pads_[0]; ck::index_t hi = ho * arg.window_strides_[0] +
y * arg.window_dilations_[0] - arg.in_left_pads_[0];
for(ck::index_t x = 0; x < arg.window_spatial_lengths_[1]; ++x) for(ck::index_t x = 0; x < arg.window_spatial_lengths_[1]; ++x)
{ {
ck::index_t wi = wo * arg.window_strides_[1] + x - arg.in_left_pads_[1]; ck::index_t wi = wo * arg.window_strides_[1] +
x * arg.window_dilations_[1] - arg.in_left_pads_[1];
if(hi >= 0 && if(hi >= 0 &&
hi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[2]) && hi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[2]) &&
wi >= 0 && wi >= 0 &&
......
...@@ -17,13 +17,16 @@ namespace instance { ...@@ -17,13 +17,16 @@ namespace instance {
using F64 = double; using F64 = double;
using F32 = float; using F32 = float;
using F16 = ck::half_t; using F16 = ck::half_t;
using F8 = ck::f8_t;
using BF16 = ck::bhalf_t; using BF16 = ck::bhalf_t;
using I8 = int8_t; using I8 = int8_t;
using I32 = int32_t; using I32 = int32_t;
using F8 = ck::f8_t;
using BF8 = ck::bf8_t;
using Empty_Tuple = ck::Tuple<>; using Empty_Tuple = ck::Tuple<>;
using BF16_Tuple = ck::Tuple<BF16>;
using F16_Tuple = ck::Tuple<F16>; using F16_Tuple = ck::Tuple<F16>;
using F16_F16_Tuple = ck::Tuple<F16, F16>; using F16_F16_Tuple = ck::Tuple<F16, F16>;
...@@ -31,6 +34,9 @@ using F64_Tuple = ck::Tuple<F64>; ...@@ -31,6 +34,9 @@ using F64_Tuple = ck::Tuple<F64>;
using F32_Tuple = ck::Tuple<F32>; using F32_Tuple = ck::Tuple<F32>;
using I32_Tuple = ck::Tuple<I32>; using I32_Tuple = ck::Tuple<I32>;
using I32_F32_Tuple = ck::Tuple<I32, F32>; using I32_F32_Tuple = ck::Tuple<I32, F32>;
using I8_Tuple = ck::Tuple<I8>;
using F32_F32_Tuple = ck::Tuple<F32, F32>;
// GEMM layout // GEMM layout
using Row = ck::tensor_layout::gemm::RowMajor; using Row = ck::tensor_layout::gemm::RowMajor;
...@@ -95,9 +101,11 @@ using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; ...@@ -95,9 +101,11 @@ using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd; using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd;
using FastGelu = ck::tensor_operation::element_wise::FastGelu; using FastGelu = ck::tensor_operation::element_wise::FastGelu;
using AddMultiply = ck::tensor_operation::element_wise::AddMultiply; using AddMultiply = ck::tensor_operation::element_wise::AddMultiply;
using MultiplyAdd = ck::tensor_operation::element_wise::MultiplyAdd;
using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd; using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd;
using Gelu = ck::tensor_operation::element_wise::Gelu; using Gelu = ck::tensor_operation::element_wise::Gelu;
using Swish = ck::tensor_operation::element_wise::Swish; using Swish = ck::tensor_operation::element_wise::Swish;
using Add = ck::tensor_operation::element_wise::Add;
template <typename Activation> template <typename Activation>
using Activation_Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp<Activation>; using Activation_Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp<Activation>;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/device/device_avgpool_bwd.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#ifdef CK_ENABLE_FP16
void add_device_avgpool_bwd_ndhwc_f16_instances(
std::vector<std::unique_ptr<DeviceAvgPoolBwd<3, F16, F16, NDHWC, NDHWC>>>&);
#endif
#ifdef CK_ENABLE_BF16
void add_device_avgpool_bwd_ndhwc_bf16_instances(
std::vector<std::unique_ptr<DeviceAvgPoolBwd<3, BF16, BF16, NDHWC, NDHWC>>>&);
#endif
#ifdef CK_ENABLE_FP32
void add_device_avgpool_bwd_ndhwc_f32_instances(
std::vector<std::unique_ptr<DeviceAvgPoolBwd<3, F32, F32, NDHWC, NDHWC>>>&);
#endif
template <typename DOutDataType, typename DInDataType, typename InLayout, typename OutLayout>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::
DeviceAvgPoolBwd<3, DOutDataType, DInDataType, InLayout, OutLayout>>
{
using DeviceOp = DeviceAvgPoolBwd<3, DOutDataType, DInDataType, InLayout, OutLayout>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<InLayout, NDHWC> && is_same_v<OutLayout, NDHWC>)
{
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<DOutDataType, F16> && is_same_v<DInDataType, F16>)
add_device_avgpool_bwd_ndhwc_f16_instances(op_ptrs);
#endif
#ifdef CK_ENABLE_BF16
else if constexpr(is_same_v<DOutDataType, BF16> && is_same_v<DInDataType, BF16>)
add_device_avgpool_bwd_ndhwc_bf16_instances(op_ptrs);
#endif
#ifdef CK_ENABLE_FP32
else if constexpr(is_same_v<DOutDataType, F32> && is_same_v<DInDataType, F32>)
add_device_avgpool_bwd_ndhwc_f32_instances(op_ptrs);
#endif
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -16,26 +16,26 @@ namespace tensor_operation { ...@@ -16,26 +16,26 @@ namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
// FP16 #ifdef CK_ENABLE_FP16
void add_device_batchnorm_backward_rank_4_3_f16_instances( void add_device_batchnorm_backward_rank_4_3_f16_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DeviceBatchNormBwd<F16, F32, F32, F32, F16, F32, F32, PassThrough, 4, 3>>>&); DeviceBatchNormBwd<F16, F32, F32, F32, F16, F32, F32, PassThrough, 4, 3>>>&);
#endif
// FP32 #ifdef CK_ENABLE_FP32
void add_device_batchnorm_backward_rank_4_3_f32_instances( void add_device_batchnorm_backward_rank_4_3_f32_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DeviceBatchNormBwd<F32, F32, F32, F32, F32, F32, F32, PassThrough, 4, 3>>>&); DeviceBatchNormBwd<F32, F32, F32, F32, F32, F32, F32, PassThrough, 4, 3>>>&);
#endif
// BF16 #ifdef CK_ENABLE_BF16
void add_device_batchnorm_backward_rank_4_3_bf16_instances( void add_device_batchnorm_backward_rank_4_3_bf16_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DeviceBatchNormBwd<BF16, F32, F32, F32, BF16, F32, F32, PassThrough, 4, 3>>>&); DeviceBatchNormBwd<BF16, F32, F32, F32, BF16, F32, F32, PassThrough, 4, 3>>>&);
#endif
// FP64 #ifdef CK_ENABLE_FP64
void add_device_batchnorm_backward_rank_4_3_f64_instances( void add_device_batchnorm_backward_rank_4_3_f64_instances(
std::vector<std::unique_ptr< std::vector<std::unique_ptr<
DeviceBatchNormBwd<F64, F64, F64, F64, F64, F64, F64, PassThrough, 4, 3>>>&); DeviceBatchNormBwd<F64, F64, F64, F64, F64, F64, F64, PassThrough, 4, 3>>>&);
#endif
template <typename XDataType, template <typename XDataType,
typename DxDataType, typename DxDataType,
typename DyDataType, typename DyDataType,
...@@ -72,7 +72,7 @@ struct DeviceOperationInstanceFactory< ...@@ -72,7 +72,7 @@ struct DeviceOperationInstanceFactory<
static auto GetInstances() static auto GetInstances()
{ {
std::vector<std::unique_ptr<DeviceOp>> op_ptrs; std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<XDataType, F16> && is_same_v<DxDataType, F32> && if constexpr(is_same_v<XDataType, F16> && is_same_v<DxDataType, F32> &&
is_same_v<DyDataType, F32> && is_same_v<AccDataType, F32> && is_same_v<DyDataType, F32> && is_same_v<AccDataType, F32> &&
is_same_v<ScaleDataType, F16> && is_same_v<DscaleDbiasDataType, F32> && is_same_v<ScaleDataType, F16> && is_same_v<DscaleDbiasDataType, F32> &&
...@@ -83,37 +83,43 @@ struct DeviceOperationInstanceFactory< ...@@ -83,37 +83,43 @@ struct DeviceOperationInstanceFactory<
add_device_batchnorm_backward_rank_4_3_f16_instances(op_ptrs); add_device_batchnorm_backward_rank_4_3_f16_instances(op_ptrs);
} }
} }
else if constexpr(is_same_v<XDataType, F32> && is_same_v<DxDataType, F32> && #endif
is_same_v<DyDataType, F32> && is_same_v<AccDataType, F32> && #ifdef CK_ENABLE_FP32
is_same_v<ScaleDataType, F32> && is_same_v<DscaleDbiasDataType, F32> && if constexpr(is_same_v<XDataType, F32> && is_same_v<DxDataType, F32> &&
is_same_v<MeanVarDataType, F32>) is_same_v<DyDataType, F32> && is_same_v<AccDataType, F32> &&
is_same_v<ScaleDataType, F32> && is_same_v<DscaleDbiasDataType, F32> &&
is_same_v<MeanVarDataType, F32>)
{ {
if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<DyElementwiseOp, PassThrough>) if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<DyElementwiseOp, PassThrough>)
{ {
add_device_batchnorm_backward_rank_4_3_f32_instances(op_ptrs); add_device_batchnorm_backward_rank_4_3_f32_instances(op_ptrs);
} }
} }
else if constexpr(is_same_v<XDataType, BF16> && is_same_v<DxDataType, F32> && #endif
is_same_v<DyDataType, F32> && is_same_v<AccDataType, F32> && #ifdef CK_ENABLE_BF16
is_same_v<ScaleDataType, BF16> && is_same_v<DscaleDbiasDataType, F32> && if constexpr(is_same_v<XDataType, BF16> && is_same_v<DxDataType, F32> &&
is_same_v<MeanVarDataType, F32>) is_same_v<DyDataType, F32> && is_same_v<AccDataType, F32> &&
is_same_v<ScaleDataType, BF16> && is_same_v<DscaleDbiasDataType, F32> &&
is_same_v<MeanVarDataType, F32>)
{ {
if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<DyElementwiseOp, PassThrough>) if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<DyElementwiseOp, PassThrough>)
{ {
add_device_batchnorm_backward_rank_4_3_bf16_instances(op_ptrs); add_device_batchnorm_backward_rank_4_3_bf16_instances(op_ptrs);
} }
} }
else if constexpr(is_same_v<XDataType, F64> && is_same_v<DxDataType, F64> && #endif
is_same_v<DyDataType, F64> && is_same_v<AccDataType, F64> && #ifdef CK_ENABLE_FP64
is_same_v<ScaleDataType, F64> && is_same_v<DscaleDbiasDataType, F64> && if constexpr(is_same_v<XDataType, F64> && is_same_v<DxDataType, F64> &&
is_same_v<MeanVarDataType, F64>) is_same_v<DyDataType, F64> && is_same_v<AccDataType, F64> &&
is_same_v<ScaleDataType, F64> && is_same_v<DscaleDbiasDataType, F64> &&
is_same_v<MeanVarDataType, F64>)
{ {
if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<DyElementwiseOp, PassThrough>) if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<DyElementwiseOp, PassThrough>)
{ {
add_device_batchnorm_backward_rank_4_3_f64_instances(op_ptrs); add_device_batchnorm_backward_rank_4_3_f64_instances(op_ptrs);
} }
} }
#endif
return op_ptrs; return op_ptrs;
} }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment