Unverified commit 9f8ab221 authored by zjing14, committed by GitHub

Merge branch 'develop' into add_int8_wmma_example_instance

parents 755ace59 b4fc4d0b
...@@ -9,15 +9,9 @@ namespace ck { ...@@ -9,15 +9,9 @@ namespace ck {
using bhalf_t = ushort; using bhalf_t = ushort;
using half_t = _Float16; using half_t = _Float16;
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 using int4_t = _BitInt(4);
using int4_t = _BitInt(4); using f8_t = _BitInt(8);
#endif using bf8_t = unsigned _BitInt(8);
#if defined CK_ENABLE_FP8
using f8_t = _BitInt(8);
#endif
#if defined CK_ENABLE_BF8
using bf8_t = unsigned _BitInt(8);
#endif
// vector_type // vector_type
template <typename T, index_t N> template <typename T, index_t N>
...@@ -148,23 +142,19 @@ struct scalar_type<int4_t> ...@@ -148,23 +142,19 @@ struct scalar_type<int4_t>
}; };
#endif #endif
#if defined CK_ENABLE_FP8
template <> template <>
struct scalar_type<f8_t> struct scalar_type<f8_t>
{ {
using type = f8_t; using type = f8_t;
static constexpr index_t vector_size = 1; static constexpr index_t vector_size = 1;
}; };
#endif
#if defined CK_ENABLE_BF8
template <> template <>
struct scalar_type<bf8_t> struct scalar_type<bf8_t>
{ {
using type = bf8_t; using type = bf8_t;
static constexpr index_t vector_size = 1; static constexpr index_t vector_size = 1;
}; };
#endif
template <typename T> template <typename T>
struct vector_type<T, 1> struct vector_type<T, 1>
...@@ -968,24 +958,20 @@ using int8x32_t = typename vector_type<int8_t, 32>::type; ...@@ -968,24 +958,20 @@ using int8x32_t = typename vector_type<int8_t, 32>::type;
using int8x64_t = typename vector_type<int8_t, 64>::type; using int8x64_t = typename vector_type<int8_t, 64>::type;
// f8 // f8
#if defined CK_ENABLE_FP8
using f8x2_t = typename vector_type<f8_t, 2>::type; using f8x2_t = typename vector_type<f8_t, 2>::type;
using f8x4_t = typename vector_type<f8_t, 4>::type; using f8x4_t = typename vector_type<f8_t, 4>::type;
using f8x8_t = typename vector_type<f8_t, 8>::type; using f8x8_t = typename vector_type<f8_t, 8>::type;
using f8x16_t = typename vector_type<f8_t, 16>::type; using f8x16_t = typename vector_type<f8_t, 16>::type;
using f8x32_t = typename vector_type<f8_t, 32>::type; using f8x32_t = typename vector_type<f8_t, 32>::type;
using f8x64_t = typename vector_type<f8_t, 64>::type; using f8x64_t = typename vector_type<f8_t, 64>::type;
#endif
// bf8 // bf8
#if defined CK_ENABLE_BF8
using bf8x2_t = typename vector_type<bf8_t, 2>::type; using bf8x2_t = typename vector_type<bf8_t, 2>::type;
using bf8x4_t = typename vector_type<bf8_t, 4>::type; using bf8x4_t = typename vector_type<bf8_t, 4>::type;
using bf8x8_t = typename vector_type<bf8_t, 8>::type; using bf8x8_t = typename vector_type<bf8_t, 8>::type;
using bf8x16_t = typename vector_type<bf8_t, 16>::type; using bf8x16_t = typename vector_type<bf8_t, 16>::type;
using bf8x32_t = typename vector_type<bf8_t, 32>::type; using bf8x32_t = typename vector_type<bf8_t, 32>::type;
using bf8x64_t = typename vector_type<bf8_t, 64>::type; using bf8x64_t = typename vector_type<bf8_t, 64>::type;
#endif
template <typename T> template <typename T>
struct NumericLimits struct NumericLimits
...@@ -1033,7 +1019,6 @@ struct NumericLimits<int4_t> ...@@ -1033,7 +1019,6 @@ struct NumericLimits<int4_t>
}; };
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#if defined CK_ENABLE_FP8
template <> template <>
struct NumericLimits<f8_t> struct NumericLimits<f8_t>
{ {
...@@ -1056,9 +1041,7 @@ struct NumericLimits<f8_t> ...@@ -1056,9 +1041,7 @@ struct NumericLimits<f8_t>
__host__ __device__ static constexpr f8_t QuietNaN() { return f8_t(binary_qnan); } __host__ __device__ static constexpr f8_t QuietNaN() { return f8_t(binary_qnan); }
}; };
#endif
#if defined CK_ENABLE_BF8
template <> template <>
struct NumericLimits<bf8_t> struct NumericLimits<bf8_t>
{ {
...@@ -1081,7 +1064,6 @@ struct NumericLimits<bf8_t> ...@@ -1081,7 +1064,6 @@ struct NumericLimits<bf8_t>
__host__ __device__ static constexpr bf8_t QuietNaN() { return bf8_t(binary_qnan); } __host__ __device__ static constexpr bf8_t QuietNaN() { return bf8_t(binary_qnan); }
}; };
#endif
template <typename T> template <typename T>
struct NumericUtils struct NumericUtils
...@@ -1120,22 +1102,18 @@ struct NumericUtils<half_t> ...@@ -1120,22 +1102,18 @@ struct NumericUtils<half_t>
using bitwise_type = uint16_t; using bitwise_type = uint16_t;
}; };
#if defined CK_ENABLE_FP8
template <> template <>
struct NumericUtils<f8_t> struct NumericUtils<f8_t>
{ {
static constexpr int exp = 4; static constexpr int exp = 4;
static constexpr int mant = 3; static constexpr int mant = 3;
}; };
#endif
#if defined CK_ENABLE_BF8
template <> template <>
struct NumericUtils<bf8_t> struct NumericUtils<bf8_t>
{ {
static constexpr int exp = 5; static constexpr int exp = 5;
static constexpr int mant = 2; static constexpr int mant = 2;
}; };
#endif //
} // namespace ck } // namespace ck
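The hunks above drop the CK_ENABLE_FP8 / CK_ENABLE_BF8 guards, so f8_t (_BitInt(8), e4m3 per NumericUtils: exp = 4, mant = 3) and bf8_t (unsigned _BitInt(8), e5m2: exp = 5, mant = 2), together with their scalar_type, vector aliases, NumericLimits and NumericUtils, are now always compiled. As a minimal standalone illustration of what those exponent/mantissa splits mean (IEEE-style bias, none of CK's negative_zero_nan handling), the sketch below decodes an 8-bit value into a double:

// Illustration (not part of this commit): decode a 1-sign / E-exponent / M-mantissa
// byte the way NumericUtils<f8_t>{exp=4, mant=3} and NumericUtils<bf8_t>{exp=5, mant=2}
// describe the formats. IEEE-style bias, no NaN or negative-zero special handling.
#include <cmath>
#include <cstdint>
#include <cstdio>

double decode_fp8(uint8_t bits, int exp_bits, int mant_bits)
{
    const int bias     = (1 << (exp_bits - 1)) - 1;
    const int sign     = bits >> (exp_bits + mant_bits);
    const int exponent = (bits >> mant_bits) & ((1 << exp_bits) - 1);
    const int mantissa = bits & ((1 << mant_bits) - 1);

    double value;
    if(exponent == 0) // subnormal: no implicit leading 1
        value = std::ldexp(static_cast<double>(mantissa), 1 - bias - mant_bits);
    else // normal: implicit leading 1
        value = std::ldexp(1.0 + mantissa / static_cast<double>(1 << mant_bits),
                           exponent - bias);
    return sign ? -value : value;
}

int main()
{
    // 0x3F in e4m3 is s=0, e=7 (== bias), m=7 -> 1.875
    std::printf("e4m3 0x3F = %g\n", decode_fp8(0x3F, 4, 3));
    // 0x3C in e5m2 is s=0, e=15 (== bias), m=0 -> 1.0
    std::printf("e5m2 0x3C = %g\n", decode_fp8(0x3C, 5, 2));
}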
...@@ -140,10 +140,36 @@ struct DynamicBuffer ...@@ -140,10 +140,36 @@ struct DynamicBuffer
} }
else if constexpr(Op == InMemoryDataOperationEnum::Add) else if constexpr(Op == InMemoryDataOperationEnum::Add)
{ {
auto tmp = this->template Get<X>(i, is_valid_element); auto tmp = this->template Get<X>(i, is_valid_element);
this->template Set<X>(i, is_valid_element, x + tmp); using scalar_t = typename scalar_type<remove_cvref_t<T>>::type;
// tmp += x; // handle bfloat addition
// this->template Set<X>(i, is_valid_element, tmp); if constexpr(is_same_v<scalar_t, bhalf_t>)
{
if constexpr(is_scalar_type<X>::value)
{
// Scalar type
auto result =
type_convert<X>(type_convert<float>(x) + type_convert<float>(tmp));
this->template Set<X>(i, is_valid_element, result);
}
else
{
// Vector type
constexpr auto vector_size = scalar_type<remove_cvref_t<X>>::vector_size;
const vector_type<scalar_t, vector_size> a_vector{tmp};
const vector_type<scalar_t, vector_size> b_vector{x};
static_for<0, vector_size, 1>{}([&](auto idx) {
auto result = type_convert<scalar_t>(
type_convert<float>(a_vector.template AsType<scalar_t>()[idx]) +
type_convert<float>(b_vector.template AsType<scalar_t>()[idx]));
this->template Set<scalar_t>(i + idx, is_valid_element, result);
});
}
}
else
{
this->template Set<X>(i, is_valid_element, x + tmp);
}
} }
} }
......
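The new Add path above special-cases bhalf_t: because bf16 has no usable operator+ here, both operands are widened to float with type_convert, added, and narrowed back, element by element for vector types. A standalone sketch of that widen-add-narrow round trip, assuming the usual "bf16 = high 16 bits of an fp32" representation and plain truncation on the way back (CK's type_convert may round differently):

// Sketch (assumption: bf16 stored as the high 16 bits of an IEEE fp32).
// Mirrors the idea of the new Add path: widen to float, add, narrow back.
#include <cstdint>
#include <cstring>
#include <cstdio>

using bf16_bits = uint16_t;

float bf16_to_float(bf16_bits x)
{
    uint32_t u = static_cast<uint32_t>(x) << 16;
    float f;
    std::memcpy(&f, &u, sizeof(f));
    return f;
}

bf16_bits float_to_bf16(float f) // truncation; a real implementation may round-to-nearest
{
    uint32_t u;
    std::memcpy(&u, &f, sizeof(u));
    return static_cast<bf16_bits>(u >> 16);
}

int main()
{
    bf16_bits a = float_to_bf16(1.5f), b = float_to_bf16(2.25f);
    bf16_bits c = float_to_bf16(bf16_to_float(a) + bf16_to_float(b));
    std::printf("%g\n", bf16_to_float(c)); // 3.75
}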
...@@ -6,8 +6,6 @@ ...@@ -6,8 +6,6 @@
#include "ck/utility/data_type.hpp" #include "ck/utility/data_type.hpp"
// these conversions are disabled if native conversions available // these conversions are disabled if native conversions available
#if !defined(__gfx940__) && !defined(__gfx941__) && !defined(__gfx942__)
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
namespace ck { namespace ck {
// fp8 rounding modes // fp8 rounding modes
...@@ -244,5 +242,3 @@ __host__ __device__ Y cast_from_f8(X x) ...@@ -244,5 +242,3 @@ __host__ __device__ Y cast_from_f8(X x)
} }
} // namespace ck::utils } // namespace ck::utils
#endif // #if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
#endif // #if !defined(__gfx940__) && !defined(__gfx941__) && !defined(__gfx942__)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck {
namespace detail {
template <class Default, class AlwaysVoid, template <class...> class Op, class... Args>
struct detector
{
using value_t = std::false_type;
using type = Default;
};
template <class Default, template <class...> class Op, class... Args>
struct detector<Default, std::void_t<Op<Args...>>, Op, Args...>
{
using value_t = std::true_type;
using type = Op<Args...>;
};
} // namespace detail
struct nonesuch
{
~nonesuch() = delete;
nonesuch(nonesuch const&) = delete;
void operator=(nonesuch const&) = delete;
};
template <template <class...> class Op, class... Args>
using is_detected = typename detail::detector<nonesuch, void, Op, Args...>::value_t;
template <typename T>
using is_pack2_invocable_t = decltype(std::declval<T&>().is_pack2_invocable);
template <typename T>
using is_pack4_invocable_t = decltype(std::declval<T&>().is_pack4_invocable);
template <typename T>
using is_pack8_invocable_t = decltype(std::declval<T&>().is_pack8_invocable);
} // namespace ck
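The header above is the standard detection idiom: is_detected<Op, Args...> is std::true_type exactly when Op<Args...> is a well-formed type, and is_pack2/4/8_invocable_t probe an elementwise operator for the corresponding member. A self-contained usage sketch of the same idiom (the *_demo names and the PackedOp/ScalarOp structs are illustrative, not CK code):

// Standalone sketch of the detection idiom used above (illustrative names).
#include <cstdio>
#include <type_traits>
#include <utility>

struct nonesuch_demo
{
    ~nonesuch_demo()                     = delete;
    nonesuch_demo(nonesuch_demo const&)  = delete;
    void operator=(nonesuch_demo const&) = delete;
};

template <class Default, class AlwaysVoid, template <class...> class Op, class... Args>
struct detector_demo
{
    using value_t = std::false_type;
};

template <class Default, template <class...> class Op, class... Args>
struct detector_demo<Default, std::void_t<Op<Args...>>, Op, Args...>
{
    using value_t = std::true_type;
};

template <template <class...> class Op, class... Args>
using is_detected_demo = typename detector_demo<nonesuch_demo, void, Op, Args...>::value_t;

// Detect a member named is_pack2_invocable, as the CK traits above do.
template <typename T>
using has_pack2_t = decltype(std::declval<T&>().is_pack2_invocable);

struct PackedOp { static constexpr bool is_pack2_invocable = true; };
struct ScalarOp {};

int main()
{
    std::printf("%d %d\n",
                int(is_detected_demo<has_pack2_t, PackedOp>::value),  // 1
                int(is_detected_demo<has_pack2_t, ScalarOp>::value)); // 0
}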
...@@ -150,28 +150,6 @@ __host__ __device__ constexpr T clamp(const T& x, const T& lowerbound, const T& ...@@ -150,28 +150,6 @@ __host__ __device__ constexpr T clamp(const T& x, const T& lowerbound, const T&
return min(max(x, lowerbound), upperbound); return min(max(x, lowerbound), upperbound);
} }
// disallow implicit type casting
template <typename T>
__device__ T exp(T x);
// TODO: add f16 support using v_exp_f16
template <>
__device__ float exp<float>(float x)
{
return __expf(x);
}
template <>
__device__ double exp<double>(double x)
{
return exp(x);
}
static inline __host__ float exp(float x) { return std::expf(x); }
static inline __host__ double exp(double x) { return std::exp(x); }
// greatest common divisor, aka highest common factor // greatest common divisor, aka highest common factor
__host__ __device__ constexpr index_t gcd(index_t x, index_t y) __host__ __device__ constexpr index_t gcd(index_t x, index_t y)
{ {
......
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include "ck/utility/data_type.hpp" #include "ck/utility/data_type.hpp"
#include "ck/utility/type.hpp" #include "ck/utility/type.hpp"
#include "ck/utility/type_convert.hpp"
namespace ck { namespace ck {
namespace math { namespace math {
...@@ -92,14 +93,96 @@ static inline __host__ float sqrt(float x) { return std::sqrt(x); }; ...@@ -92,14 +93,96 @@ static inline __host__ float sqrt(float x) { return std::sqrt(x); };
static inline __host__ double sqrt(double x) { return std::sqrt(x); }; static inline __host__ double sqrt(double x) { return std::sqrt(x); };
static inline __host__ half_t tanh(half_t x) template <typename T>
inline __host__ T tanh(T x)
{ {
return static_cast<half_t>(std::tanh(static_cast<float>(x))); return ck::type_convert<T>(std::tanhf(ck::type_convert<float>(x)));
}; };
static inline __host__ float tanh(float x) { return std::tanh(x); }; template <>
inline __host__ float tanh<float>(float x)
{
return std::tanhf(x);
};
template <>
inline __host__ double tanh<double>(double x)
{
return std::tanh(x);
};
template <typename T>
inline __host__ T exp(T x)
{
return ck::type_convert<T>(std::expf(ck::type_convert<float>(x)));
}
template <>
inline __host__ float exp<float>(float x)
{
return std::expf(x);
}
static inline __host__ double tanh(double x) { return std::tanh(x); }; template <>
inline __host__ double exp<double>(double x)
{
return std::exp(x);
}
template <typename T>
inline __host__ T log(T x)
{
return ck::type_convert<T>(std::logf(ck::type_convert<float>(x)));
}
template <>
inline __host__ float log<float>(float x)
{
return std::logf(x);
}
template <>
inline __host__ double log<double>(double x)
{
return std::log(x);
}
template <typename T>
inline __host__ T pow(T x, T gamma)
{
return ck::type_convert<T>(
std::powf(ck::type_convert<float>(x), ck::type_convert<float>(gamma)));
}
template <>
inline __host__ float pow<float>(float x, float gamma)
{
return std::powf(x, gamma);
}
template <>
inline __host__ double pow<double>(double x, double gamma)
{
return std::pow(x, gamma);
}
template <typename T>
inline __host__ T expm1(T x)
{
return ck::type_convert<T>(std::expm1f(ck::type_convert<float>(x)));
}
template <>
inline __host__ float expm1<float>(float x)
{
return std::expm1f(x);
}
template <>
inline __host__ double expm1<double>(double x)
{
return std::expm1(x);
}
// math functions for the HIP kernel, some are implemented by calling hip builtin functions // math functions for the HIP kernel, some are implemented by calling hip builtin functions
...@@ -181,14 +264,107 @@ static inline __device__ float sqrt(float x) { return __builtin_amdgcn_sqrtf(x); ...@@ -181,14 +264,107 @@ static inline __device__ float sqrt(float x) { return __builtin_amdgcn_sqrtf(x);
static inline __device__ double sqrt(double x) { return __builtin_amdgcn_sqrt(x); }; static inline __device__ double sqrt(double x) { return __builtin_amdgcn_sqrt(x); };
static inline __device__ half_t tanh(half_t x) template <typename T>
inline __device__ T tanh(T x)
{
return ck::type_convert<T>(::tanhf(ck::type_convert<float>(x)));
};
template <>
inline __device__ float tanh<float>(float x)
{ {
return static_cast<half_t>(::tanhf(static_cast<float>(x))); return ::tanhf(x);
}; };
static inline __device__ float tanh(float x) { return ::tanhf(x); }; template <>
inline __device__ double tanh<double>(double x)
{
return ::tanh(x);
};
template <typename T>
inline __device__ T exp(T x)
{
return ck::type_convert<T>(__expf(ck::type_convert<float>(x)));
};
template <>
inline __device__ half_t exp<half_t>(half_t x)
{
return hexp(x);
};
template <>
inline __device__ float exp<float>(float x)
{
return __expf(x);
};
static inline __device__ double tanh(double x) { return ::tanh(x); }; template <>
inline __device__ double exp<double>(double x)
{
return exp(x);
};
template <typename T>
inline __device__ T log(T x)
{
return ck::type_convert<T>(__logf(ck::type_convert<float>(x)));
};
template <>
inline __device__ half_t log<half_t>(half_t x)
{
return hlog(x);
};
template <>
inline __device__ float log<float>(float x)
{
return __logf(x);
};
template <>
inline __device__ double log<double>(double x)
{
return log(x);
};
template <typename T>
inline __device__ T pow(T x, T gamma)
{
return ck::type_convert<T>(powf(ck::type_convert<float>(x), ck::type_convert<float>(gamma)));
};
template <>
inline __device__ float pow<float>(float x, float gamma)
{
return powf(x, gamma);
};
template <>
inline __device__ double pow<double>(double x, double gamma)
{
return pow(x, gamma);
};
template <typename T>
inline __device__ T expm1(T x)
{
return ck::type_convert<T>(expm1f(ck::type_convert<float>(x)));
};
template <>
inline __device__ float expm1<float>(float x)
{
return expm1f(x);
};
template <>
inline __device__ double expm1<double>(double x)
{
return expm1(x);
};
} // namespace math } // namespace math
} // namespace ck } // namespace ck
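All of the new host-side tanh/exp/log/pow/expm1 wrappers above follow one pattern: a generic template widens the argument to float via type_convert, calls the float routine, and converts back, while exact specializations for float and double call the standard functions directly (the device versions swap in __expf, hexp, and friends). A minimal standalone version of that pattern for a made-up 8.8 fixed-point type, so it compiles without CK headers:

// Sketch of the fallback pattern above: compute in float for narrow types,
// keep exact specializations for float/double. "q8_8" is a made-up 8.8
// fixed-point type used purely for illustration.
#include <cmath>
#include <cstdint>
#include <cstdio>

struct q8_8
{
    int16_t raw;
    explicit q8_8(float x) : raw(static_cast<int16_t>(std::lround(x * 256.0f))) {}
    explicit operator float() const { return raw / 256.0f; }
};

template <typename T>
T tanh_fallback(T x) // generic: widen to float, compute, narrow back
{
    return T(std::tanh(static_cast<float>(x)));
}

template <>
float tanh_fallback<float>(float x) { return std::tanh(x); }

template <>
double tanh_fallback<double>(double x) { return std::tanh(x); }

int main()
{
    q8_8 x(0.5f);
    // tanh(0.5) ~= 0.4621, quantized to 1/256 steps -> ~0.4609
    std::printf("%g\n", static_cast<float>(tanh_fallback(x)));
}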
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#define CK_STATICALLY_INDEXED_ARRAY_MULTI_INDEX_HPP #define CK_STATICALLY_INDEXED_ARRAY_MULTI_INDEX_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "ck/utility/math_v2.hpp"
namespace ck { namespace ck {
......
...@@ -177,6 +177,8 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X ...@@ -177,6 +177,8 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
} }
__host__ __device__ static constexpr bool IsStaticBuffer() { return true; } __host__ __device__ static constexpr bool IsStaticBuffer() { return true; }
__host__ __device__ static constexpr bool IsTuple() { return true; }
}; };
template <> template <>
......
...@@ -9,8 +9,10 @@ ...@@ -9,8 +9,10 @@
namespace ck { namespace ck {
// Convert X to Y // Convert X to Y, both X and Y are non-const data types.
template <typename Y, typename X> template <typename Y,
typename X,
std::enable_if_t<!(std::is_const_v<Y> || std::is_const_v<X>), bool> = false>
__host__ __device__ constexpr Y type_convert(X x) __host__ __device__ constexpr Y type_convert(X x)
{ {
static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>); static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
...@@ -18,6 +20,19 @@ __host__ __device__ constexpr Y type_convert(X x) ...@@ -18,6 +20,19 @@ __host__ __device__ constexpr Y type_convert(X x)
return static_cast<Y>(x); return static_cast<Y>(x);
} }
// Convert X to Y, either X or Y is a const data type.
template <typename Y,
typename X,
std::enable_if_t<std::is_const_v<Y> || std::is_const_v<X>, bool> = false>
__host__ __device__ constexpr Y type_convert(X x)
{
static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
using NonConstY = std::remove_const_t<Y>;
using NonConstX = std::remove_const_t<X>;
return static_cast<Y>(type_convert<NonConstY, NonConstX>(x));
}
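The added overload above routes any call where X or Y carries a const qualifier through the non-const type_convert, so const-qualified sources still reach the specialized conversions instead of falling back to the generic static_cast. A standalone sketch of the same enable_if dispatch (convert_demo is a stand-in name, and the int/double specialization is just a toy "special path"):

// Standalone sketch of the const-forwarding overload pattern
// (convert_demo is a stand-in for ck::type_convert).
#include <cstdio>
#include <type_traits>

// Primary path: both types non-const.
template <typename Y,
          typename X,
          std::enable_if_t<!(std::is_const_v<Y> || std::is_const_v<X>), bool> = false>
constexpr Y convert_demo(X x)
{
    return static_cast<Y>(x);
}

// A "special" conversion that const callers should reach as well.
template <>
constexpr int convert_demo<int, double>(double x)
{
    return static_cast<int>(x + 0.5); // toy round-to-nearest for positive x
}

// Forwarding path: strip const from X and/or Y, then reuse the overload above.
template <typename Y,
          typename X,
          std::enable_if_t<std::is_const_v<Y> || std::is_const_v<X>, bool> = false>
constexpr Y convert_demo(X x)
{
    using NonConstY = std::remove_const_t<Y>;
    using NonConstX = std::remove_const_t<X>;
    return static_cast<Y>(convert_demo<NonConstY, NonConstX>(x));
}

int main()
{
    std::printf("%d\n", convert_demo<int, const double>(1.7)); // 2, via the specialization
}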
// convert bfp16 to fp32 // convert bfp16 to fp32
template <> template <>
inline __host__ __device__ constexpr float type_convert<float, bhalf_t>(bhalf_t x) inline __host__ __device__ constexpr float type_convert<float, bhalf_t>(bhalf_t x)
...@@ -80,7 +95,6 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_ ...@@ -80,7 +95,6 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_
return type_convert<bhalf_t>(x_fp32); return type_convert<bhalf_t>(x_fp32);
} }
#if defined CK_ENABLE_FP8
// convert fp32 to fp8 // convert fp32 to fp8
template <> template <>
inline __host__ __device__ f8_t type_convert<f8_t, float>(float x) inline __host__ __device__ f8_t type_convert<f8_t, float>(float x)
...@@ -131,7 +145,7 @@ inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x) ...@@ -131,7 +145,7 @@ inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// convert to float and use native conversion // convert to float and use native conversion
return type_convert<f8_t>(type_convert<float>(x)); return type_convert<f8_t>(type_convert<float>(x));
#else #elif 0
constexpr bool negative_zero_nan = true; constexpr bool negative_zero_nan = true;
constexpr bool clip = true; constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard; constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
...@@ -139,6 +153,8 @@ inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x) ...@@ -139,6 +153,8 @@ inline __host__ __device__ f8_t type_convert<f8_t, half_t>(half_t x)
return utils:: return utils::
cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>( cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng); x, rng);
#else
return type_convert<f8_t>(type_convert<float>(x));
#endif #endif
} }
...@@ -149,14 +165,14 @@ inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x) ...@@ -149,14 +165,14 @@ inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// use native conversion to float and convert to fp16 // use native conversion to float and convert to fp16
return type_convert<half_t>(type_convert<float>(x)); return type_convert<half_t>(type_convert<float>(x));
#else #elif 0
constexpr bool negative_zero_nan = true; constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<f8_t, half_t, negative_zero_nan>(x); return utils::cast_from_f8<f8_t, half_t, negative_zero_nan>(x);
#else
return type_convert<half_t>(type_convert<float>(x));
#endif #endif
} }
#endif
#if defined CK_ENABLE_BF8
// convert fp32 to bf8 // convert fp32 to bf8
template <> template <>
inline __host__ __device__ bf8_t type_convert<bf8_t, float>(float x) inline __host__ __device__ bf8_t type_convert<bf8_t, float>(float x)
...@@ -206,8 +222,8 @@ inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x) ...@@ -206,8 +222,8 @@ inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x)
{ {
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// convert to float and use native conversion // convert to float and use native conversion
return type_convert<f8_t>(type_convert<float>(x)); return type_convert<bf8_t>(type_convert<float>(x));
#else #elif 0
constexpr bool negative_zero_nan = true; constexpr bool negative_zero_nan = true;
constexpr bool clip = true; constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::standard; constexpr f8_rounding_mode rm = f8_rounding_mode::standard;
...@@ -215,6 +231,8 @@ inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x) ...@@ -215,6 +231,8 @@ inline __host__ __device__ bf8_t type_convert<bf8_t, half_t>(half_t x)
return utils:: return utils::
cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>( cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng); x, rng);
#else
return type_convert<bf8_t>(type_convert<float>(x));
#endif #endif
} }
...@@ -225,12 +243,13 @@ inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x) ...@@ -225,12 +243,13 @@ inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// use native conversion to float and convert to fp16 // use native conversion to float and convert to fp16
return type_convert<half_t>(type_convert<float>(x)); return type_convert<half_t>(type_convert<float>(x));
#else #elif 0
constexpr bool negative_zero_nan = true; constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<bf8_t, half_t, negative_zero_nan>(x); return utils::cast_from_f8<bf8_t, half_t, negative_zero_nan>(x);
#else
return type_convert<half_t>(type_convert<float>(x));
#endif #endif
} }
#endif
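A recurring edit in this file: on gfx940/941/942 the native float path is taken, and everywhere else the old software cast_to_f8 / cast_from_f8 branch is demoted from #else to #elif 0 while a new #else also routes through float, so the emulation code stays in the source but is compiled nowhere. The skeleton below shows just that branch structure (the function body is illustrative):

// Shape of the branch structure after this change (bodies are illustrative).
#include <cstdio>

float convert_like(float x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
    return x; // native conversion path on gfx94x
#elif 0
    return x; // software cast path: kept in source, never compiled
#else
    return x; // new fallback: round-trip through float everywhere else
#endif
}

int main() { std::printf("%g\n", convert_like(1.0f)); }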
// Declare a template function for bf16 conversion using RTN // Declare a template function for bf16 conversion using RTN
template <typename Y, typename X> template <typename Y, typename X>
...@@ -293,7 +312,6 @@ inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(h ...@@ -293,7 +312,6 @@ inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(h
template <typename Y, typename X> template <typename Y, typename X>
__host__ __device__ constexpr Y f8_convert_sr(X x); __host__ __device__ constexpr Y f8_convert_sr(X x);
#if defined CK_ENABLE_FP8
// convert fp32 to fp8 with stochastic rounding // convert fp32 to fp8 with stochastic rounding
template <> template <>
inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x) inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
...@@ -329,7 +347,7 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x) ...@@ -329,7 +347,7 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// convert to float and use native conversion // convert to float and use native conversion
return f8_convert_sr<f8_t>(type_convert<float>(x)); return f8_convert_sr<f8_t>(type_convert<float>(x));
#else #elif 0
constexpr bool negative_zero_nan = true; constexpr bool negative_zero_nan = true;
constexpr bool clip = true; constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic; constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
...@@ -338,11 +356,11 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x) ...@@ -338,11 +356,11 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
return utils:: return utils::
cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>( cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng); x, rng);
#else
return f8_convert_sr<f8_t>(type_convert<float>(x));
#endif #endif
} }
#endif
#if defined CK_ENABLE_BF8
// convert fp32 to bf8 with stochastic rounding // convert fp32 to bf8 with stochastic rounding
template <> template <>
inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x) inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x)
...@@ -378,7 +396,7 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x) ...@@ -378,7 +396,7 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// convert to float and use native conversion // convert to float and use native conversion
return f8_convert_sr<f8_t>(type_convert<float>(x)); return f8_convert_sr<f8_t>(type_convert<float>(x));
#else #elif 0
constexpr bool negative_zero_nan = true; constexpr bool negative_zero_nan = true;
constexpr bool clip = true; constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic; constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
...@@ -388,8 +406,9 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x) ...@@ -388,8 +406,9 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
return utils:: return utils::
cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>( cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng); x, rng);
#else
return f8_convert_sr<bf8_t>(type_convert<float>(x));
#endif #endif
} }
#endif
} // namespace ck } // namespace ck
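f8_convert_sr implements stochastic rounding: a value lying between two representable fp8 numbers is rounded up with probability proportional to how close it is to the upper one, which makes the rounding error zero in expectation. A standalone sketch of stochastic rounding onto a coarse uniform grid (steps of 0.25, not the real fp8 grid):

// Standalone sketch of stochastic rounding to a fixed grid (0.25 steps),
// not the actual fp8 quantizer used by cast_to_f8.
#include <cmath>
#include <cstdio>
#include <random>

double round_stochastic(double x, double step, double u /* uniform in [0,1) */)
{
    const double lo   = std::floor(x / step) * step;
    const double frac = (x - lo) / step; // in [0,1): probability of rounding up
    return (u < frac) ? lo + step : lo;
}

int main()
{
    std::mt19937 rng(42);
    std::uniform_real_distribution<double> uni(0.0, 1.0);

    // 0.30 sits 20% of the way from 0.25 to 0.50, so ~20% of draws round up.
    double sum = 0.0;
    const int n = 100000;
    for(int i = 0; i < n; ++i)
        sum += round_stochastic(0.30, 0.25, uni(rng));
    std::printf("mean of rounded values ~= %g (unbiased target 0.30)\n", sum / n);
}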
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <type_traits>
#include <sstream>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/library/utility/host_tensor.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
/**
* \brief Reference implementation for column to image.
*
* Input tensor descriptor has [N * Do * Ho * Wo, Z * Y * X * C] data layout.
* Memory layout is the same.
* Output tensor descriptor has [G, N, C, Di, Hi, Wi] data layout.
* G must be equal to 1. Memory layout is [G, N, Di, Hi, Wi, C].
*
* \tparam NDimSpatial Number of spatial dimensions.
* \tparam ImageLayout Image Layout.
* \tparam InDataType Input Data Type.
* \tparam OutDataType Output Data Type.
*/
template <ck::index_t NDimSpatial,
typename ImageLayout,
typename InDataType,
typename OutDataType,
typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false>
struct ReferenceColumnToImage : public device::BaseOperator
{
// Argument
struct Argument : public device::BaseArgument
{
public:
Argument(const Tensor<InDataType>& input,
Tensor<OutDataType>& output,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads)
: input_{input},
output_{output},
conv_strides_{conv_filter_strides},
conv_dilations_{conv_filter_dilations},
in_left_pads_{input_left_pads},
in_right_pads_{input_right_pads},
filter_spatial_lengths_{filter_spatial_lengths}
{
initOutputSpatialLengths();
}
const Tensor<InDataType>& input_;
Tensor<OutDataType>& output_;
std::vector<index_t> conv_strides_;
std::vector<index_t> conv_dilations_;
std::vector<index_t> in_left_pads_;
std::vector<index_t> in_right_pads_;
std::vector<index_t> filter_spatial_lengths_;
std::vector<index_t> output_spatial_lengths_;
private:
void initOutputSpatialLengths()
{
constexpr auto input_offset_to_spatial = 3;
for(ck::index_t i = 0; i < NDimSpatial; ++i)
{
// XEff = (X - 1) * conv_dilation_w + 1;
// Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
const ck::index_t x_eff = (filter_spatial_lengths_[i] - 1) * conv_dilations_[i] + 1;
output_spatial_lengths_.push_back(
(output_.GetLengths()[i + input_offset_to_spatial] + in_left_pads_[i] +
in_right_pads_[i] - x_eff) /
conv_strides_[i] +
1);
}
}
};
struct Invoker : public device::BaseInvoker
{
using Argument = ReferenceColumnToImage::Argument;
float Run(const Argument& arg)
{
if(!(arg.output_.GetNumOfDimension() == NDimSpatial + 3 &&
arg.input_.GetNumOfDimension() == 2))
{
throw std::runtime_error("wrong! inconsistent dimension");
}
const index_t N = arg.output_.GetLengths()[1];
const index_t C = arg.output_.GetLengths()[2];
if constexpr(NDimSpatial == 1)
{
const index_t Wo = arg.output_spatial_lengths_[0];
auto func = [&](auto n) {
for(index_t wo = 0; wo < Wo; ++wo)
{
index_t row = n * Wo + wo;
index_t column = 0;
for(index_t x = 0; x < arg.filter_spatial_lengths_[0]; ++x)
{
auto wi = static_cast<ck::long_index_t>(wo * arg.conv_strides_[0]) +
static_cast<ck::long_index_t>(x * arg.conv_dilations_[0]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
for(index_t c = 0; c < C; ++c)
{
if(wi >= 0 &&
ck::type_convert<std::size_t>(wi) < arg.output_.GetLengths()[3])
{
float v_in = ck::type_convert<float>(arg.input_(row, column));
float v_out = ck::type_convert<float>(arg.output_(0, n, c, wi));
arg.output_(0, n, c, wi) =
ck::type_convert<OutDataType>(v_in + v_out);
}
column++;
}
}
}
};
make_ParallelTensorFunctor(func, N)(std::thread::hardware_concurrency());
return 0;
}
else if constexpr(NDimSpatial == 2)
{
const index_t Ho = arg.output_spatial_lengths_[0];
const index_t Wo = arg.output_spatial_lengths_[1];
auto func = [&](auto n) {
for(index_t ho = 0; ho < Ho; ++ho)
{
for(index_t wo = 0; wo < Wo; ++wo)
{
index_t row = n * Ho * Wo + ho * Wo + wo;
index_t column = 0;
for(index_t y = 0; y < arg.filter_spatial_lengths_[0]; ++y)
{
auto hi =
static_cast<ck::long_index_t>(ho * arg.conv_strides_[0]) +
static_cast<ck::long_index_t>(y * arg.conv_dilations_[0]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
for(index_t x = 0; x < arg.filter_spatial_lengths_[1]; ++x)
{
auto wi =
static_cast<ck::long_index_t>(wo * arg.conv_strides_[1]) +
static_cast<ck::long_index_t>(x * arg.conv_dilations_[1]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[1]);
for(index_t c = 0; c < C; ++c)
{
if(hi >= 0 &&
ck::type_convert<std::size_t>(hi) <
arg.output_.GetLengths()[3] &&
wi >= 0 &&
ck::type_convert<std::size_t>(wi) <
arg.output_.GetLengths()[4])
{
float v_in =
ck::type_convert<float>(arg.input_(row, column));
float v_out = ck::type_convert<float>(
arg.output_(0, n, c, hi, wi));
arg.output_(0, n, c, hi, wi) =
ck::type_convert<OutDataType>(v_in + v_out);
}
column++;
}
}
}
}
}
};
make_ParallelTensorFunctor(func, N)(std::thread::hardware_concurrency());
return 0;
}
else if constexpr(NDimSpatial == 3)
{
const index_t Do = arg.output_spatial_lengths_[0];
const index_t Ho = arg.output_spatial_lengths_[1];
const index_t Wo = arg.output_spatial_lengths_[2];
auto func = [&](auto n) {
for(index_t d_o = 0; d_o < Do; ++d_o)
{
for(index_t ho = 0; ho < Ho; ++ho)
{
for(index_t wo = 0; wo < Wo; ++wo)
{
index_t row = n * Do * Ho * Wo + d_o * Ho * Wo + ho * Wo + wo;
index_t column = 0;
for(index_t z = 0; z < arg.filter_spatial_lengths_[0]; ++z)
{
auto di =
static_cast<ck::long_index_t>(d_o * arg.conv_strides_[0]) +
static_cast<ck::long_index_t>(z * arg.conv_dilations_[0]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
for(index_t y = 0; y < arg.filter_spatial_lengths_[1]; ++y)
{
auto hi =
static_cast<ck::long_index_t>(ho *
arg.conv_strides_[1]) +
static_cast<ck::long_index_t>(y *
arg.conv_dilations_[1]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[1]);
for(index_t x = 0; x < arg.filter_spatial_lengths_[2]; ++x)
{
auto wi =
static_cast<ck::long_index_t>(
wo * arg.conv_strides_[2]) +
static_cast<ck::long_index_t>(
x * arg.conv_dilations_[2]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[2]);
for(index_t c = 0; c < C; ++c)
{
if(di >= 0 &&
ck::type_convert<std::size_t>(di) <
arg.output_.GetLengths()[3] &&
hi >= 0 &&
ck::type_convert<std::size_t>(hi) <
arg.output_.GetLengths()[4] &&
wi >= 0 &&
ck::type_convert<std::size_t>(wi) <
arg.output_.GetLengths()[5])
{
float v_in = ck::type_convert<float>(
arg.input_(row, column));
float v_out = ck::type_convert<float>(
arg.output_(0, n, c, di, hi, wi));
arg.output_(0, n, c, di, hi, wi) =
ck::type_convert<OutDataType>(v_in + v_out);
}
column++;
}
}
}
}
}
}
}
};
make_ParallelTensorFunctor(func, N)(std::thread::hardware_concurrency());
return 0;
}
}
float Run(const device::BaseArgument* p_arg,
const StreamConfig& /*stream_config*/ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
static constexpr bool IsValidCompilationParameter()
{
using namespace tensor_layout::convolution;
if constexpr(!(std::is_same_v<ImageLayout, GNWC> || std::is_same_v<ImageLayout, GNHWC> ||
std::is_same_v<ImageLayout, GNDHWC>))
{
return false;
}
if constexpr(!(NDimSpatial >= 1 && NDimSpatial <= 3))
{
return false;
}
return true;
}
bool IsSupportedArgument(const Argument& arg)
{
const ck::index_t G = arg.output_.GetLengths()[0];
const ck::index_t N = arg.output_.GetLengths()[1];
const ck::index_t C = arg.output_.GetLengths()[2];
const index_t NDoHoWo =
N * ck::accumulate_n<index_t>(
arg.output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
const index_t CZYX =
C * ck::accumulate_n<index_t>(
arg.filter_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
if(!(arg.input_.GetLengths()[0] == static_cast<std::size_t>(NDoHoWo) &&
arg.input_.GetLengths()[1] == static_cast<std::size_t>(CZYX)))
{
return false;
}
if(G != 1)
{
return false;
}
return true;
}
bool IsSupportedArgument(const device::BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const Tensor<InDataType>& input,
Tensor<OutDataType>& output,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads)
{
return Argument{input,
output,
filter_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads};
}
static auto MakeInvoker() { return Invoker{}; }
virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "ReferenceColumnToImage"
<< std::endl;
// clang-format on
return str.str();
}
};
} // namespace host
} // namespace tensor_operation
} // namespace ck
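initOutputSpatialLengths above applies the usual convolution shape formula per spatial dimension, XEff = (X - 1) * dilation + 1 and Wo = (Wi + pad_left + pad_right - XEff) / stride + 1, and IsSupportedArgument then requires the column matrix to have N * Do * Ho * Wo rows and C * Z * Y * X columns. A quick standalone check of those formulas with made-up sizes:

// Standalone check of the shape formulas used above.
#include <cstdio>

int conv_out_len(int wi, int x, int stride, int dilation, int pad_l, int pad_r)
{
    const int x_eff = (x - 1) * dilation + 1;
    return (wi + pad_l + pad_r - x_eff) / stride + 1;
}

int main()
{
    // 2D example: 8x8 image (N=1, C=3), 3x3 filter, stride 1, dilation 1, pad 1.
    const int N = 1, C = 3, Y = 3, X = 3;
    const int Ho = conv_out_len(8, Y, 1, 1, 1, 1); // 8
    const int Wo = conv_out_len(8, X, 1, 1, 1, 1); // 8
    std::printf("column matrix: [%d, %d]\n", N * Ho * Wo, C * Y * X); // [64, 27]
}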
...@@ -25,6 +25,8 @@ template <ck::index_t NDimSpatial, ...@@ -25,6 +25,8 @@ template <ck::index_t NDimSpatial,
typename InElementwiseOperation, typename InElementwiseOperation,
typename WeiElementwiseOperation, typename WeiElementwiseOperation,
typename OutElementwiseOperation, typename OutElementwiseOperation,
typename ComputeTypeA = OutDataType,
typename ComputeTypeB = InDataType,
typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false> typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false>
struct ReferenceConvBwdWeight : public device::BaseOperator struct ReferenceConvBwdWeight : public device::BaseOperator
{ {
...@@ -98,8 +100,8 @@ struct ReferenceConvBwdWeight : public device::BaseOperator ...@@ -98,8 +100,8 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
if(wi >= 0 && if(wi >= 0 &&
ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[3]) ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[3])
{ {
float v_out; ComputeTypeA v_out;
float v_in; ComputeTypeB v_in;
arg.out_element_op_( arg.out_element_op_(
v_out, ck::type_convert<float>(arg.output_(g, n, k, wo))); v_out, ck::type_convert<float>(arg.output_(g, n, k, wo)));
...@@ -107,7 +109,7 @@ struct ReferenceConvBwdWeight : public device::BaseOperator ...@@ -107,7 +109,7 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
arg.in_element_op_( arg.in_element_op_(
v_in, ck::type_convert<float>(arg.input_(g, n, c, wi))); v_in, ck::type_convert<float>(arg.input_(g, n, c, wi)));
v_acc += v_out * v_in; v_acc += type_convert<float>(v_out) * type_convert<float>(v_in);
} }
} }
} }
...@@ -158,8 +160,8 @@ struct ReferenceConvBwdWeight : public device::BaseOperator ...@@ -158,8 +160,8 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
wi >= 0 && wi >= 0 &&
ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[4]) ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[4])
{ {
float v_out; ComputeTypeA v_out;
float v_in; ComputeTypeB v_in;
arg.out_element_op_( arg.out_element_op_(
v_out, v_out,
...@@ -168,7 +170,7 @@ struct ReferenceConvBwdWeight : public device::BaseOperator ...@@ -168,7 +170,7 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
arg.in_element_op_( arg.in_element_op_(
v_in, ck::type_convert<float>(arg.input_(g, n, c, hi, wi))); v_in, ck::type_convert<float>(arg.input_(g, n, c, hi, wi)));
v_acc += v_out * v_in; v_acc += type_convert<float>(v_out) * type_convert<float>(v_in);
} }
} }
} }
...@@ -226,8 +228,8 @@ struct ReferenceConvBwdWeight : public device::BaseOperator ...@@ -226,8 +228,8 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
ck::type_convert<std::size_t>(wi) < ck::type_convert<std::size_t>(wi) <
arg.input_.GetLengths()[5]) arg.input_.GetLengths()[5])
{ {
float v_out; ComputeTypeA v_out;
float v_in; ComputeTypeB v_in;
arg.out_element_op_(v_out, arg.out_element_op_(v_out,
ck::type_convert<float>( ck::type_convert<float>(
...@@ -237,7 +239,8 @@ struct ReferenceConvBwdWeight : public device::BaseOperator ...@@ -237,7 +239,8 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
ck::type_convert<float>( ck::type_convert<float>(
arg.input_(g, n, c, di, hi, wi))); arg.input_(g, n, c, di, hi, wi)));
v_acc += v_out * v_in; v_acc +=
type_convert<float>(v_out) * type_convert<float>(v_in);
} }
} }
} }
......
...@@ -128,11 +128,9 @@ struct ReferenceConvFwd : public device::BaseOperator ...@@ -128,11 +128,9 @@ struct ReferenceConvFwd : public device::BaseOperator
} }
} }
float v_out; OutDataType v_out;
arg.out_element_op_(v_out, ck::type_convert<OutDataType>(v_acc));
arg.out_element_op_(v_out, v_acc); arg.output_(g, n, k, wo) = v_out;
arg.output_(g, n, k, wo) = ck::type_convert<OutDataType>(v_out);
}; };
make_ParallelTensorFunctor(func, make_ParallelTensorFunctor(func,
...@@ -184,11 +182,9 @@ struct ReferenceConvFwd : public device::BaseOperator ...@@ -184,11 +182,9 @@ struct ReferenceConvFwd : public device::BaseOperator
} }
} }
float v_out; OutDataType v_out;
arg.out_element_op_(v_out, ck::type_convert<OutDataType>(v_acc));
arg.out_element_op_(v_out, v_acc); arg.output_(g, n, k, ho, wo) = v_out;
arg.output_(g, n, k, ho, wo) = ck::type_convert<OutDataType>(v_out);
}; };
make_ParallelTensorFunctor(func, make_ParallelTensorFunctor(func,
...@@ -253,11 +249,9 @@ struct ReferenceConvFwd : public device::BaseOperator ...@@ -253,11 +249,9 @@ struct ReferenceConvFwd : public device::BaseOperator
} }
} }
float v_out; OutDataType v_out;
arg.out_element_op_(v_out, ck::type_convert<OutDataType>(v_acc));
arg.out_element_op_(v_out, v_acc); arg.output_(g, n, k, d_o, ho, wo) = v_out;
arg.output_(g, n, k, d_o, ho, wo) = ck::type_convert<OutDataType>(v_out);
}; };
make_ParallelTensorFunctor(func, make_ParallelTensorFunctor(func,
......
...@@ -21,7 +21,8 @@ template <typename ADataType, ...@@ -21,7 +21,8 @@ template <typename ADataType,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
typename ComputType = ADataType> typename ComputeTypeA = ADataType,
typename ComputeTypeB = ComputeTypeA>
struct ReferenceGemm : public device::BaseOperator struct ReferenceGemm : public device::BaseOperator
{ {
// Argument // Argument
...@@ -65,8 +66,8 @@ struct ReferenceGemm : public device::BaseOperator ...@@ -65,8 +66,8 @@ struct ReferenceGemm : public device::BaseOperator
for(int k = 0; k < K; ++k) for(int k = 0; k < K; ++k)
{ {
ComputType v_a; ComputeTypeA v_a;
ComputType v_b; ComputeTypeB v_b;
// use PassThrough instead of ConvertBF16RTN for reference calculation // use PassThrough instead of ConvertBF16RTN for reference calculation
if constexpr(is_same_v<AElementwiseOperation, if constexpr(is_same_v<AElementwiseOperation,
......
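ReferenceGemm above gains independent ComputeTypeA and ComputeTypeB (defaulting to ADataType and ComputeTypeA), so the A and B operands can be converted through different compute types before the accumulation, which is what mixed-precision cases such as f8 x f16 need. A minimal standalone reference GEMM in the same shape (plain static_cast instead of CK's elementwise operators):

// Minimal reference GEMM sketch with separate compute types for A and B
// (illustration only; CK's version also applies elementwise operators).
#include <cstdio>
#include <vector>

template <typename ADataType,
          typename BDataType,
          typename CDataType,
          typename AccDataType,
          typename ComputeTypeA = ADataType,
          typename ComputeTypeB = ComputeTypeA>
void reference_gemm(const std::vector<ADataType>& a, // M x K, row-major
                    const std::vector<BDataType>& b, // K x N, row-major
                    std::vector<CDataType>& c,       // M x N, row-major
                    int M, int N, int K)
{
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
        {
            AccDataType acc = 0;
            for(int k = 0; k < K; ++k)
            {
                const auto v_a = static_cast<ComputeTypeA>(a[m * K + k]);
                const auto v_b = static_cast<ComputeTypeB>(b[k * N + n]);
                acc += static_cast<AccDataType>(v_a) * static_cast<AccDataType>(v_b);
            }
            c[m * N + n] = static_cast<CDataType>(acc);
        }
}

int main()
{
    const int M = 2, N = 2, K = 3;
    std::vector<float> a{1, 2, 3, 4, 5, 6}, b{1, 0, 0, 1, 1, 1}, c(M * N);
    reference_gemm<float, float, float, double>(a, b, c, M, N, K);
    std::printf("%g %g %g %g\n", c[0], c[1], c[2], c[3]); // 4 5 10 11
}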
...@@ -20,8 +20,9 @@ template <typename XDataType, ...@@ -20,8 +20,9 @@ template <typename XDataType,
typename GammaDataType, typename GammaDataType,
typename BetaDataType, typename BetaDataType,
typename YDataType, typename YDataType,
typename AccDataType, typename SaveMeanInvStdDataType,
typename AccElementwiseOperation> typename ComputeDataType,
typename YElementwiseOperation>
struct ReferenceGroupnorm : public device::BaseOperator struct ReferenceGroupnorm : public device::BaseOperator
{ {
// x = [N, H, W, G, C] // x = [N, H, W, G, C]
...@@ -35,14 +36,18 @@ struct ReferenceGroupnorm : public device::BaseOperator ...@@ -35,14 +36,18 @@ struct ReferenceGroupnorm : public device::BaseOperator
const Tensor<GammaDataType>& gamma, const Tensor<GammaDataType>& gamma,
const Tensor<BetaDataType>& beta, const Tensor<BetaDataType>& beta,
Tensor<YDataType>& y, Tensor<YDataType>& y,
AccElementwiseOperation acc_elementwise_op, Tensor<SaveMeanInvStdDataType>& save_mean,
Tensor<SaveMeanInvStdDataType>& save_inv_std,
YElementwiseOperation y_elementwise_op,
const std::vector<index_t> lengths, const std::vector<index_t> lengths,
AccDataType epsilon) ComputeDataType epsilon)
: x_(x), : x_(x),
gamma_(gamma), gamma_(gamma),
beta_(beta), beta_(beta),
y_(y), y_(y),
acc_elementwise_op_(acc_elementwise_op), save_mean_(save_mean),
save_inv_std_(save_inv_std),
y_elementwise_op_(y_elementwise_op),
lengths_(lengths), lengths_(lengths),
epsilon_(epsilon) epsilon_(epsilon)
{ {
...@@ -52,9 +57,11 @@ struct ReferenceGroupnorm : public device::BaseOperator ...@@ -52,9 +57,11 @@ struct ReferenceGroupnorm : public device::BaseOperator
const Tensor<XDataType> gamma_; const Tensor<XDataType> gamma_;
const Tensor<XDataType> beta_; const Tensor<XDataType> beta_;
Tensor<YDataType>& y_; Tensor<YDataType>& y_;
AccElementwiseOperation acc_elementwise_op_; Tensor<SaveMeanInvStdDataType>& save_mean_;
Tensor<SaveMeanInvStdDataType>& save_inv_std_;
YElementwiseOperation y_elementwise_op_;
std::vector<index_t> lengths_; std::vector<index_t> lengths_;
AccDataType epsilon_; ComputeDataType epsilon_;
}; };
// Invoker // Invoker
...@@ -68,8 +75,8 @@ struct ReferenceGroupnorm : public device::BaseOperator ...@@ -68,8 +75,8 @@ struct ReferenceGroupnorm : public device::BaseOperator
int G = arg.lengths_[3]; int G = arg.lengths_[3];
int C = arg.lengths_[4]; int C = arg.lengths_[4];
Tensor<AccDataType> mean({N, G}); Tensor<ComputeDataType> mean({N, G});
Tensor<AccDataType> var({N, G}); Tensor<ComputeDataType> var({N, G});
// Compute mean & var in [H, W, C] by Welford Algorithm // Compute mean & var in [H, W, C] by Welford Algorithm
// TODO - parallel for each HWC // TODO - parallel for each HWC
...@@ -78,9 +85,9 @@ struct ReferenceGroupnorm : public device::BaseOperator ...@@ -78,9 +85,9 @@ struct ReferenceGroupnorm : public device::BaseOperator
{ {
for(int g = 0; g < G; ++g) for(int g = 0; g < G; ++g)
{ {
AccDataType mean_val = type_convert<AccDataType>(0.0f); ComputeDataType mean_val = type_convert<ComputeDataType>(0.0f);
AccDataType var_val = type_convert<AccDataType>(0.0f); ComputeDataType var_val = type_convert<ComputeDataType>(0.0f);
int32_t curr_count = 0; int32_t curr_count = 0;
for(int h = 0; h < H; ++h) for(int h = 0; h < H; ++h)
{ {
...@@ -89,10 +96,11 @@ struct ReferenceGroupnorm : public device::BaseOperator ...@@ -89,10 +96,11 @@ struct ReferenceGroupnorm : public device::BaseOperator
for(int c = 0; c < C; ++c) for(int c = 0; c < C; ++c)
{ {
curr_count++; curr_count++;
AccDataType x = type_convert<AccDataType>(arg.x_(n, h, w, g, c)); ComputeDataType x =
AccDataType delta = x - mean_val; type_convert<ComputeDataType>(arg.x_(n, h, w, g, c));
ComputeDataType delta = x - mean_val;
mean_val += delta / curr_count; mean_val += delta / curr_count;
AccDataType delta2 = x - mean_val; ComputeDataType delta2 = x - mean_val;
var_val += delta * delta2; var_val += delta * delta2;
} }
} }
...@@ -100,6 +108,12 @@ struct ReferenceGroupnorm : public device::BaseOperator ...@@ -100,6 +108,12 @@ struct ReferenceGroupnorm : public device::BaseOperator
mean(n, g) = mean_val; mean(n, g) = mean_val;
var(n, g) = var_val / curr_count; var(n, g) = var_val / curr_count;
arg.save_mean_(n, g) = ck::type_convert<SaveMeanInvStdDataType>(mean(n, g));
ComputeDataType divisor =
static_cast<ComputeDataType>(1) / ck::math::sqrt(var(n, g) + arg.epsilon_);
arg.save_inv_std_(n, g) = ck::type_convert<SaveMeanInvStdDataType>(divisor);
} }
} }
...@@ -114,15 +128,19 @@ struct ReferenceGroupnorm : public device::BaseOperator ...@@ -114,15 +128,19 @@ struct ReferenceGroupnorm : public device::BaseOperator
{ {
for(int c = 0; c < C; ++c) for(int c = 0; c < C; ++c)
{ {
AccDataType x = type_convert<AccDataType>(arg.x_(n, h, w, g, c)); ComputeDataType x =
AccDataType gamma = type_convert<AccDataType>(arg.gamma_(g, c)); type_convert<ComputeDataType>(arg.x_(n, h, w, g, c));
AccDataType beta = type_convert<AccDataType>(arg.beta_(g, c)); ComputeDataType gamma =
AccDataType mean_val = type_convert<AccDataType>(mean(n, g)); type_convert<ComputeDataType>(arg.gamma_(g, c));
AccDataType var_val = type_convert<AccDataType>(var(n, g)); ComputeDataType beta =
AccDataType y = gamma * (x - mean_val) / type_convert<ComputeDataType>(arg.beta_(g, c));
ck::math::sqrt(arg.epsilon_ + var_val) + ComputeDataType mean_val =
beta; type_convert<ComputeDataType>(mean(n, g));
arg.acc_elementwise_op_(y, y); ComputeDataType var_val = type_convert<ComputeDataType>(var(n, g));
ComputeDataType y = gamma * (x - mean_val) /
ck::math::sqrt(arg.epsilon_ + var_val) +
beta;
arg.y_elementwise_op_(y, y);
arg.y_(n, h, w, g, c) = type_convert<YDataType>(y); arg.y_(n, h, w, g, c) = type_convert<YDataType>(y);
} }
} }
...@@ -159,11 +177,14 @@ struct ReferenceGroupnorm : public device::BaseOperator ...@@ -159,11 +177,14 @@ struct ReferenceGroupnorm : public device::BaseOperator
const Tensor<GammaDataType>& gamma, const Tensor<GammaDataType>& gamma,
const Tensor<BetaDataType>& beta, const Tensor<BetaDataType>& beta,
Tensor<YDataType>& y, Tensor<YDataType>& y,
AccElementwiseOperation acc_elementwise_op, Tensor<SaveMeanInvStdDataType>& save_mean,
Tensor<SaveMeanInvStdDataType>& save_inv_std,
YElementwiseOperation y_elementwise_op,
const std::vector<index_t> lengths, const std::vector<index_t> lengths,
AccDataType epsilon) ComputeDataType epsilon)
{ {
return Argument{x, gamma, beta, y, acc_elementwise_op, lengths, epsilon}; return Argument{
x, gamma, beta, y, save_mean, save_inv_std, y_elementwise_op, lengths, epsilon};
} }
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
......
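The groupnorm reference above accumulates mean and variance with Welford's single-pass update (the delta / delta2 steps), then stores the mean and 1 / sqrt(var + epsilon) into the new save_mean / save_inv_std tensors. A standalone Welford sketch over a flat array, using the same update rule:

// Standalone Welford mean/variance sketch, matching the update rule above.
#include <cmath>
#include <cstdio>

int main()
{
    const double x[] = {1.0, 2.0, 4.0, 8.0};
    const double eps = 1e-5;

    double mean = 0.0, m2 = 0.0;
    int count = 0;
    for(double v : x)
    {
        ++count;
        const double delta = v - mean;
        mean += delta / count;
        const double delta2 = v - mean;
        m2 += delta * delta2;
    }
    const double var     = m2 / count; // population variance, as in the reference
    const double inv_std = 1.0 / std::sqrt(var + eps);

    // mean = 3.75, var = 7.1875
    std::printf("mean=%g var=%g inv_std=%g\n", mean, var, inv_std);
}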
...@@ -18,16 +18,18 @@ namespace host { ...@@ -18,16 +18,18 @@ namespace host {
/** /**
* \brief Reference implementation for image to column. * \brief Reference implementation for image to column.
* *
* Tensor descriptor has [G, N, C, Di, Hi, Wi] data layout. * Input tensor descriptor has [G, N, C, Di, Hi, Wi] data layout.
* G must be equal to 1. Memory layout is [G, N, Di, Hi, Wi, C]. * G must be equal to 1. Memory layout is [G, N, Di, Hi, Wi, C].
* Output tensor descriptor has [N * Do * Ho * Wo, Z * Y * X * C] data layout.
* Memory layout is the same.
* *
* \tparam NDimSpatial Number of spatial dimensions. * \tparam NDimSpatial Number of spatial dimensions.
* \tparam InputLayout Input Layout. * \tparam ImageLayout Image Layout.
* \tparam InDataType Input Data Type. * \tparam InDataType Input Data Type.
* \tparam OutDataType Output Data Type. * \tparam OutDataType Output Data Type.
*/ */
template <ck::index_t NDimSpatial, template <ck::index_t NDimSpatial,
typename InputLayout, typename ImageLayout,
typename InDataType, typename InDataType,
typename OutDataType, typename OutDataType,
typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false> typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false>
...@@ -240,8 +242,8 @@ struct ReferenceImageToColumn : public device::BaseOperator ...@@ -240,8 +242,8 @@ struct ReferenceImageToColumn : public device::BaseOperator
{ {
using namespace tensor_layout::convolution; using namespace tensor_layout::convolution;
if constexpr(!(std::is_same_v<InputLayout, GNWC> || std::is_same_v<InputLayout, GNHWC> || if constexpr(!(std::is_same_v<ImageLayout, GNWC> || std::is_same_v<ImageLayout, GNHWC> ||
std::is_same_v<InputLayout, GNDHWC>)) std::is_same_v<ImageLayout, GNDHWC>))
{ {
return false; return false;
} }
......
...@@ -20,8 +20,9 @@ template <typename XDataType, ...@@ -20,8 +20,9 @@ template <typename XDataType,
typename GammaDataType, typename GammaDataType,
typename BetaDataType, typename BetaDataType,
typename YDataType, typename YDataType,
typename AccDataType, typename SaveMeanInvStdDataType,
typename AccElementwiseOperation, typename ComputeDataType,
typename YElementwiseOperation,
index_t Rank, index_t Rank,
index_t NumReduceDim> index_t NumReduceDim>
struct ReferenceLayernorm : public device::BaseOperator struct ReferenceLayernorm : public device::BaseOperator
...@@ -36,15 +37,19 @@ struct ReferenceLayernorm : public device::BaseOperator ...@@ -36,15 +37,19 @@ struct ReferenceLayernorm : public device::BaseOperator
const Tensor<GammaDataType>& gamma_n, const Tensor<GammaDataType>& gamma_n,
const Tensor<BetaDataType>& beta_n, const Tensor<BetaDataType>& beta_n,
Tensor<YDataType>& y_m_n, Tensor<YDataType>& y_m_n,
AccElementwiseOperation acc_elementwise_op, Tensor<SaveMeanInvStdDataType>& save_mean_m,
Tensor<SaveMeanInvStdDataType>& save_inv_std_m,
YElementwiseOperation y_elementwise_op,
const std::vector<index_t> lengths, const std::vector<index_t> lengths,
const std::vector<index_t> reduceDims, const std::vector<index_t> reduceDims,
AccDataType epsilon) ComputeDataType epsilon)
: x_m_n_(x_m_n), : x_m_n_(x_m_n),
gamma_n_(gamma_n), gamma_n_(gamma_n),
beta_n_(beta_n), beta_n_(beta_n),
y_m_n_(y_m_n), y_m_n_(y_m_n),
acc_elementwise_op_(acc_elementwise_op), save_mean_m_(save_mean_m),
save_inv_std_m_(save_inv_std_m),
y_elementwise_op_(y_elementwise_op),
lengths_(lengths), lengths_(lengths),
reduceDims_(reduceDims), reduceDims_(reduceDims),
epsilon_(epsilon) epsilon_(epsilon)
...@@ -55,10 +60,12 @@ struct ReferenceLayernorm : public device::BaseOperator ...@@ -55,10 +60,12 @@ struct ReferenceLayernorm : public device::BaseOperator
const Tensor<XDataType> gamma_n_; const Tensor<XDataType> gamma_n_;
const Tensor<XDataType> beta_n_; const Tensor<XDataType> beta_n_;
Tensor<YDataType>& y_m_n_; Tensor<YDataType>& y_m_n_;
AccElementwiseOperation acc_elementwise_op_; Tensor<SaveMeanInvStdDataType>& save_mean_m_;
Tensor<SaveMeanInvStdDataType>& save_inv_std_m_;
YElementwiseOperation y_elementwise_op_;
std::vector<index_t> lengths_; std::vector<index_t> lengths_;
std::vector<index_t> reduceDims_; std::vector<index_t> reduceDims_;
AccDataType epsilon_; ComputeDataType epsilon_;
}; };
// Invoker // Invoker
...@@ -69,8 +76,8 @@ struct ReferenceLayernorm : public device::BaseOperator ...@@ -69,8 +76,8 @@ struct ReferenceLayernorm : public device::BaseOperator
int M = arg.lengths_[0]; int M = arg.lengths_[0];
int N = arg.lengths_[1]; int N = arg.lengths_[1];
Tensor<AccDataType> mean({M}); Tensor<ComputeDataType> mean({M});
Tensor<AccDataType> var({M}); Tensor<ComputeDataType> var({M});
for(int m = 0; m < M; ++m) for(int m = 0; m < M; ++m)
{ {
...@@ -79,7 +86,7 @@ struct ReferenceLayernorm : public device::BaseOperator ...@@ -79,7 +86,7 @@ struct ReferenceLayernorm : public device::BaseOperator
for(int n = 0; n < N; ++n) for(int n = 0; n < N; ++n)
{ {
auto x_val = ck::type_convert<AccDataType>(arg.x_m_n_(m, n)); auto x_val = ck::type_convert<ComputeDataType>(arg.x_m_n_(m, n));
mean(m) += x_val; mean(m) += x_val;
var(m) += x_val * x_val; var(m) += x_val * x_val;
} }
...@@ -90,17 +97,21 @@ struct ReferenceLayernorm : public device::BaseOperator ...@@ -90,17 +97,21 @@ struct ReferenceLayernorm : public device::BaseOperator
for(int m = 0; m < M; ++m) for(int m = 0; m < M; ++m)
{ {
AccDataType divisor = ComputeDataType divisor =
static_cast<AccDataType>(1) / ck::math::sqrt(var(m) + arg.epsilon_); static_cast<ComputeDataType>(1) / ck::math::sqrt(var(m) + arg.epsilon_);
for(int n = 0; n < N; ++n) for(int n = 0; n < N; ++n)
{ {
auto x_val = ck::type_convert<AccDataType>(arg.x_m_n_(m, n)); auto x_val = ck::type_convert<ComputeDataType>(arg.x_m_n_(m, n));
auto y_val = (x_val - mean(m)) * divisor; auto gamma_val = ck::type_convert<ComputeDataType>(arg.gamma_n_(n));
y_val = (y_val * arg.gamma_n_(n)) + arg.beta_n_(n); auto beta_val = ck::type_convert<ComputeDataType>(arg.beta_n_(n));
arg.acc_elementwise_op_(y_val, y_val); auto y_val = (x_val - mean(m)) * divisor;
y_val = (y_val * gamma_val) + beta_val;
arg.y_elementwise_op_(y_val, y_val);
arg.y_m_n_(m, n) = ck::type_convert<YDataType>(y_val); arg.y_m_n_(m, n) = ck::type_convert<YDataType>(y_val);
} }
arg.save_mean_m_(m) = ck::type_convert<SaveMeanInvStdDataType>(mean(m));
arg.save_inv_std_m_(m) = ck::type_convert<SaveMeanInvStdDataType>(divisor);
} }
return 0; return 0;
...@@ -140,13 +151,23 @@ struct ReferenceLayernorm : public device::BaseOperator ...@@ -140,13 +151,23 @@ struct ReferenceLayernorm : public device::BaseOperator
const Tensor<GammaDataType>& gamma_n, const Tensor<GammaDataType>& gamma_n,
const Tensor<BetaDataType>& beta_n, const Tensor<BetaDataType>& beta_n,
Tensor<YDataType>& y_m_n, Tensor<YDataType>& y_m_n,
AccElementwiseOperation acc_elementwise_op, Tensor<SaveMeanInvStdDataType>& save_mean_m,
Tensor<SaveMeanInvStdDataType>& save_inv_std_m,
YElementwiseOperation y_elementwise_op,
const std::vector<index_t> lengths, const std::vector<index_t> lengths,
const std::vector<index_t> reduceDims, const std::vector<index_t> reduceDims,
AccDataType epsilon) ComputeDataType epsilon)
{ {
return Argument{ return Argument{x_m_n,
x_m_n, gamma_n, beta_n, y_m_n, acc_elementwise_op, lengths, reduceDims, epsilon}; gamma_n,
beta_n,
y_m_n,
save_mean_m,
save_inv_std_m,
y_elementwise_op,
lengths,
reduceDims,
epsilon};
} }
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
......
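Unlike the groupnorm reference, the layernorm reference above gathers per-row sums of x and x^2 and now also writes save_mean and save_inv_std = 1 / sqrt(var + epsilon), so a backward pass can reuse them. A one-row standalone sketch of that normalization:

// Standalone sketch: layernorm of one row with saved mean / inv_std
// (naive sum and sum-of-squares moments, as in the reference above).
#include <cmath>
#include <cstdio>

int main()
{
    const float x[]     = {1.0f, 2.0f, 3.0f, 4.0f};
    const float gamma[] = {1.0f, 1.0f, 1.0f, 1.0f};
    const float beta[]  = {0.0f, 0.0f, 0.0f, 0.0f};
    const int   N       = 4;
    const float eps     = 1e-5f;

    float sum = 0.0f, sum_sq = 0.0f;
    for(int n = 0; n < N; ++n)
    {
        sum += x[n];
        sum_sq += x[n] * x[n];
    }
    const float mean    = sum / N;
    const float var     = sum_sq / N - mean * mean;
    const float inv_std = 1.0f / std::sqrt(var + eps); // what save_inv_std would hold

    for(int n = 0; n < N; ++n)
    {
        const float y = gamma[n] * (x[n] - mean) * inv_std + beta[n];
        std::printf("%g ", y);
    }
    std::printf("\nsave_mean=%g save_inv_std=%g\n", mean, inv_std);
}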
...@@ -20,12 +20,8 @@ using F16 = ck::half_t; ...@@ -20,12 +20,8 @@ using F16 = ck::half_t;
using BF16 = ck::bhalf_t; using BF16 = ck::bhalf_t;
using I8 = int8_t; using I8 = int8_t;
using I32 = int32_t; using I32 = int32_t;
#if defined CK_ENABLE_FP8 using F8 = ck::f8_t;
using F8 = ck::f8_t; using BF8 = ck::bf8_t;
#endif
#if defined CK_ENABLE_BF8
using BF8 = ck::bf8_t;
#endif
using Empty_Tuple = ck::Tuple<>; using Empty_Tuple = ck::Tuple<>;
......
@@ -16,26 +16,26 @@ namespace tensor_operation {
 namespace device {
 namespace instance {
-// FP16
+#ifdef CK_ENABLE_FP16
 void add_device_batchnorm_backward_rank_4_3_f16_instances(
     std::vector<std::unique_ptr<
         DeviceBatchNormBwd<F16, F32, F32, F32, F16, F32, F32, PassThrough, 4, 3>>>&);
+#endif
-// FP32
+#ifdef CK_ENABLE_FP32
 void add_device_batchnorm_backward_rank_4_3_f32_instances(
     std::vector<std::unique_ptr<
         DeviceBatchNormBwd<F32, F32, F32, F32, F32, F32, F32, PassThrough, 4, 3>>>&);
+#endif
-// BF16
+#ifdef CK_ENABLE_BF16
 void add_device_batchnorm_backward_rank_4_3_bf16_instances(
     std::vector<std::unique_ptr<
         DeviceBatchNormBwd<BF16, F32, F32, F32, BF16, F32, F32, PassThrough, 4, 3>>>&);
+#endif
-// FP64
+#ifdef CK_ENABLE_FP64
 void add_device_batchnorm_backward_rank_4_3_f64_instances(
     std::vector<std::unique_ptr<
         DeviceBatchNormBwd<F64, F64, F64, F64, F64, F64, F64, PassThrough, 4, 3>>>&);
+#endif
 template <typename XDataType,
           typename DxDataType,
           typename DyDataType,
@@ -72,7 +72,7 @@ struct DeviceOperationInstanceFactory<
     static auto GetInstances()
     {
         std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
+#ifdef CK_ENABLE_FP16
         if constexpr(is_same_v<XDataType, F16> && is_same_v<DxDataType, F32> &&
                      is_same_v<DyDataType, F32> && is_same_v<AccDataType, F32> &&
                      is_same_v<ScaleDataType, F16> && is_same_v<DscaleDbiasDataType, F32> &&
@@ -83,37 +83,43 @@ struct DeviceOperationInstanceFactory<
                 add_device_batchnorm_backward_rank_4_3_f16_instances(op_ptrs);
             }
         }
-        else if constexpr(is_same_v<XDataType, F32> && is_same_v<DxDataType, F32> &&
-                          is_same_v<DyDataType, F32> && is_same_v<AccDataType, F32> &&
-                          is_same_v<ScaleDataType, F32> && is_same_v<DscaleDbiasDataType, F32> &&
-                          is_same_v<MeanVarDataType, F32>)
+#endif
+#ifdef CK_ENABLE_FP32
+        if constexpr(is_same_v<XDataType, F32> && is_same_v<DxDataType, F32> &&
+                     is_same_v<DyDataType, F32> && is_same_v<AccDataType, F32> &&
+                     is_same_v<ScaleDataType, F32> && is_same_v<DscaleDbiasDataType, F32> &&
+                     is_same_v<MeanVarDataType, F32>)
         {
             if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<DyElementwiseOp, PassThrough>)
             {
                 add_device_batchnorm_backward_rank_4_3_f32_instances(op_ptrs);
             }
         }
-        else if constexpr(is_same_v<XDataType, BF16> && is_same_v<DxDataType, F32> &&
-                          is_same_v<DyDataType, F32> && is_same_v<AccDataType, F32> &&
-                          is_same_v<ScaleDataType, BF16> && is_same_v<DscaleDbiasDataType, F32> &&
-                          is_same_v<MeanVarDataType, F32>)
+#endif
+#ifdef CK_ENABLE_BF16
+        if constexpr(is_same_v<XDataType, BF16> && is_same_v<DxDataType, F32> &&
+                     is_same_v<DyDataType, F32> && is_same_v<AccDataType, F32> &&
+                     is_same_v<ScaleDataType, BF16> && is_same_v<DscaleDbiasDataType, F32> &&
+                     is_same_v<MeanVarDataType, F32>)
         {
             if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<DyElementwiseOp, PassThrough>)
             {
                 add_device_batchnorm_backward_rank_4_3_bf16_instances(op_ptrs);
             }
         }
-        else if constexpr(is_same_v<XDataType, F64> && is_same_v<DxDataType, F64> &&
-                          is_same_v<DyDataType, F64> && is_same_v<AccDataType, F64> &&
-                          is_same_v<ScaleDataType, F64> && is_same_v<DscaleDbiasDataType, F64> &&
-                          is_same_v<MeanVarDataType, F64>)
+#endif
+#ifdef CK_ENABLE_FP64
+        if constexpr(is_same_v<XDataType, F64> && is_same_v<DxDataType, F64> &&
+                     is_same_v<DyDataType, F64> && is_same_v<AccDataType, F64> &&
+                     is_same_v<ScaleDataType, F64> && is_same_v<DscaleDbiasDataType, F64> &&
+                     is_same_v<MeanVarDataType, F64>)
         {
             if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<DyElementwiseOp, PassThrough>)
             {
                 add_device_batchnorm_backward_rank_4_3_f64_instances(op_ptrs);
             }
         }
+#endif
         return op_ptrs;
     }
 };
...
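From the caller's side the backward factory is unchanged; a disabled precision now simply contributes no instances instead of requiring the declarations to exist. A usage sketch for the FP16 branch above, where the include path and GetTypeString() follow the usual CK instance-factory interface and are assumptions here:

#include <iostream>
// Assumed header; the actual path may differ.
#include "ck/library/tensor_operation_instance/gpu/batchnorm_backward.hpp"

using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using DeviceOp    = ck::tensor_operation::device::DeviceBatchNormBwd<
    ck::half_t, float, float, float, ck::half_t, float, float, PassThrough, 4, 3>;

int main()
{
    // Matches the FP16 branch of GetInstances(); if the library was built without
    // CK_ENABLE_FP16, op_ptrs is simply empty rather than failing to link.
    const auto op_ptrs = ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

    for(const auto& op : op_ptrs)
        std::cout << op->GetTypeString() << '\n';
}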
@@ -16,26 +16,26 @@ namespace tensor_operation {
 namespace device {
 namespace instance {
-// FP16
+#ifdef CK_ENABLE_FP16
 void add_device_batchnorm_forward_rank_4_3_f16_instances(
     std::vector<
         std::unique_ptr<DeviceBatchNormFwd<F16, F16, F32, F16, F16, F32, PassThrough, 4, 3>>>&);
+#endif
-// FP32
+#ifdef CK_ENABLE_FP32
 void add_device_batchnorm_forward_rank_4_3_f32_instances(
     std::vector<
         std::unique_ptr<DeviceBatchNormFwd<F32, F32, F32, F32, F32, F32, PassThrough, 4, 3>>>&);
+#endif
-// BF16
+#ifdef CK_ENABLE_BF16
 void add_device_batchnorm_forward_rank_4_3_bf16_instances(
     std::vector<
         std::unique_ptr<DeviceBatchNormFwd<BF16, BF16, F32, BF16, BF16, F32, PassThrough, 4, 3>>>&);
+#endif
-// FP64
+#ifdef CK_ENABLE_FP64
 void add_device_batchnorm_forward_rank_4_3_f64_instances(
     std::vector<
         std::unique_ptr<DeviceBatchNormFwd<F64, F64, F64, F64, F64, F64, PassThrough, 4, 3>>>&);
+#endif
 template <typename XDataType,
           typename YDataType,
           typename AccDataType,
@@ -69,7 +69,7 @@ struct DeviceOperationInstanceFactory<
     static auto GetInstances()
     {
         std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
+#ifdef CK_ENABLE_FP16
         if constexpr(is_same_v<XDataType, F16> && is_same_v<YDataType, F16> &&
                      is_same_v<AccDataType, F32> && is_same_v<ScaleDataType, F16> &&
                      is_same_v<BiasDataType, F16> && is_same_v<MeanVarDataType, F32>)
@@ -79,34 +79,40 @@ struct DeviceOperationInstanceFactory<
                 add_device_batchnorm_forward_rank_4_3_f16_instances(op_ptrs);
             }
         }
-        else if constexpr(is_same_v<XDataType, F32> && is_same_v<YDataType, F32> &&
-                          is_same_v<AccDataType, F32> && is_same_v<ScaleDataType, F32> &&
-                          is_same_v<BiasDataType, F32> && is_same_v<MeanVarDataType, F32>)
+#endif
+#ifdef CK_ENABLE_FP32
+        if constexpr(is_same_v<XDataType, F32> && is_same_v<YDataType, F32> &&
+                     is_same_v<AccDataType, F32> && is_same_v<ScaleDataType, F32> &&
+                     is_same_v<BiasDataType, F32> && is_same_v<MeanVarDataType, F32>)
        {
             if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<YElementwiseOp, PassThrough>)
             {
                 add_device_batchnorm_forward_rank_4_3_f32_instances(op_ptrs);
             }
         }
-        else if constexpr(is_same_v<XDataType, BF16> && is_same_v<YDataType, BF16> &&
-                          is_same_v<AccDataType, F32> && is_same_v<ScaleDataType, BF16> &&
-                          is_same_v<BiasDataType, BF16> && is_same_v<MeanVarDataType, F32>)
+#endif
+#ifdef CK_ENABLE_BF16
+        if constexpr(is_same_v<XDataType, BF16> && is_same_v<YDataType, BF16> &&
+                     is_same_v<AccDataType, F32> && is_same_v<ScaleDataType, BF16> &&
+                     is_same_v<BiasDataType, BF16> && is_same_v<MeanVarDataType, F32>)
         {
             if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<YElementwiseOp, PassThrough>)
             {
                 add_device_batchnorm_forward_rank_4_3_bf16_instances(op_ptrs);
             }
         }
-        else if constexpr(is_same_v<XDataType, F64> && is_same_v<YDataType, F64> &&
-                          is_same_v<AccDataType, F64> && is_same_v<ScaleDataType, F64> &&
-                          is_same_v<BiasDataType, F64> && is_same_v<MeanVarDataType, F64>)
+#endif
+#ifdef CK_ENABLE_FP64
+        if constexpr(is_same_v<XDataType, F64> && is_same_v<YDataType, F64> &&
+                     is_same_v<AccDataType, F64> && is_same_v<ScaleDataType, F64> &&
+                     is_same_v<BiasDataType, F64> && is_same_v<MeanVarDataType, F64>)
         {
             if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<YElementwiseOp, PassThrough>)
             {
                 add_device_batchnorm_forward_rank_4_3_f64_instances(op_ptrs);
             }
         }
+#endif
         return op_ptrs;
     }
 };
...
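The forward factory follows the same restructuring as the backward one: every precision becomes an independent, macro-guarded if constexpr branch, so dropping a data type from the build removes the whole branch instead of orphaning an else. A standalone toy version of the pattern (not CK code) showing why any subset of the CK_ENABLE_* macros still compiles:

#include <string>
#include <type_traits>
#include <vector>

// Toy factory mirroring the guard pattern above: each precision is compiled only when
// its macro is defined, and at most one if constexpr branch fires for a given T.
template <typename T>
std::vector<std::string> get_instance_names()
{
    std::vector<std::string> names;
#ifdef CK_ENABLE_FP16
    if constexpr(std::is_same_v<T, _Float16>)
        names.push_back("fp16 instance");
#endif
#ifdef CK_ENABLE_FP32
    if constexpr(std::is_same_v<T, float>)
        names.push_back("fp32 instance");
#endif
#ifdef CK_ENABLE_FP64
    if constexpr(std::is_same_v<T, double>)
        names.push_back("fp64 instance");
#endif
    return names; // empty when T's precision is disabled at build time
}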
@@ -16,38 +16,38 @@ namespace tensor_operation {
 namespace device {
 namespace instance {
-// FP16
+#ifdef CK_ENABLE_FP16
 void add_device_batchnorm_infer_rank_4_f16_instances(
     std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceElementwise<
         ck::Tuple<F16, F32, F32, F16, F16>,
         ck::Tuple<F16>,
         ck::tensor_operation::element_wise::NormalizeInInfer,
         4>>>&);
+#endif
-// FP32
+#ifdef CK_ENABLE_FP32
 void add_device_batchnorm_infer_rank_4_f32_instances(
     std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceElementwise<
         ck::Tuple<F32, F32, F32, F32, F32>,
         ck::Tuple<F32>,
         ck::tensor_operation::element_wise::NormalizeInInfer,
         4>>>&);
+#endif
-// BF16
+#ifdef CK_ENABLE_BF16
 void add_device_batchnorm_infer_rank_4_bf16_instances(
     std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceElementwise<
         ck::Tuple<BF16, F32, F32, BF16, BF16>,
         ck::Tuple<BF16>,
         ck::tensor_operation::element_wise::NormalizeInInfer,
         4>>>&);
+#endif
-// FP64
+#ifdef CK_ENABLE_FP64
 void add_device_batchnorm_infer_rank_4_f64_instances(
     std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceElementwise<
         ck::Tuple<F64, F64, F64, F64, F64>,
         ck::Tuple<F64>,
         ck::tensor_operation::element_wise::NormalizeInInfer,
         4>>>&);
+#endif
 template <typename XDataType,
           typename YDataType,
           typename ScaleDataType,
@@ -69,7 +69,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceElemen
     static auto GetInstances()
     {
         std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
+#ifdef CK_ENABLE_FP16
         if constexpr(is_same_v<XDataType, F16> && is_same_v<YDataType, F16> &&
                      is_same_v<ScaleDataType, F16> && is_same_v<BiasDataType, F16> &&
                      is_same_v<MeanVarDataType, F32>)
@@ -79,34 +79,40 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceElemen
                 add_device_batchnorm_infer_rank_4_f16_instances(op_ptrs);
             }
         }
-        else if constexpr(is_same_v<XDataType, F32> && is_same_v<YDataType, F32> &&
-                          is_same_v<ScaleDataType, F32> && is_same_v<BiasDataType, F32> &&
-                          is_same_v<MeanVarDataType, F32>)
+#endif
+#ifdef CK_ENABLE_FP32
+        if constexpr(is_same_v<XDataType, F32> && is_same_v<YDataType, F32> &&
+                     is_same_v<ScaleDataType, F32> && is_same_v<BiasDataType, F32> &&
+                     is_same_v<MeanVarDataType, F32>)
         {
             if constexpr(Rank == 4)
             {
                 add_device_batchnorm_infer_rank_4_f32_instances(op_ptrs);
             }
         }
-        else if constexpr(is_same_v<XDataType, BF16> && is_same_v<YDataType, BF16> &&
-                          is_same_v<ScaleDataType, BF16> && is_same_v<BiasDataType, BF16> &&
-                          is_same_v<MeanVarDataType, F32>)
+#endif
+#ifdef CK_ENABLE_BF16
+        if constexpr(is_same_v<XDataType, BF16> && is_same_v<YDataType, BF16> &&
+                     is_same_v<ScaleDataType, BF16> && is_same_v<BiasDataType, BF16> &&
+                     is_same_v<MeanVarDataType, F32>)
        {
             if constexpr(Rank == 4)
             {
                 add_device_batchnorm_infer_rank_4_bf16_instances(op_ptrs);
             }
         }
-        else if constexpr(is_same_v<XDataType, F64> && is_same_v<YDataType, F64> &&
-                          is_same_v<ScaleDataType, F64> && is_same_v<BiasDataType, F64> &&
-                          is_same_v<MeanVarDataType, F64>)
+#endif
+#ifdef CK_ENABLE_FP64
+        if constexpr(is_same_v<XDataType, F64> && is_same_v<YDataType, F64> &&
+                     is_same_v<ScaleDataType, F64> && is_same_v<BiasDataType, F64> &&
+                     is_same_v<MeanVarDataType, F64>)
         {
             if constexpr(Rank == 4)
             {
                 add_device_batchnorm_infer_rank_4_f64_instances(op_ptrs);
             }
         }
+#endif
         return op_ptrs;
     }
 };
...
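For inference the same macros key off the generic DeviceElementwise op specialised with NormalizeInInfer (inputs x, mean, variance, scale, bias; one output y), so there is no dedicated batchnorm-infer class to guard. A hedged sketch of the FP16 lookup implied by the declarations above; the type parameters are taken from the diff, everything else is an assumption:

// x, mean, variance, scale, bias -> y, matching the FP16 tuple in the diff.
using DeviceInferOp = ck::tensor_operation::device::DeviceElementwise<
    ck::Tuple<ck::half_t, float, float, ck::half_t, ck::half_t>,
    ck::Tuple<ck::half_t>,
    ck::tensor_operation::element_wise::NormalizeInInfer,
    4>;

const auto infer_ptrs = ck::tensor_operation::device::instance::
    DeviceOperationInstanceFactory<DeviceInferOp>::GetInstances();
// As with forward/backward, building the library without CK_ENABLE_FP16 leaves infer_ptrs empty.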