Commit d43cd4ad authored by Mirza Halilcevic's avatar Mirza Halilcevic
Browse files

Introduce gemm_softmax_gemm to codegen.

parent 3528a523
...@@ -120,11 +120,11 @@ constexpr auto conditional_expr(X&& x, Y&& y) ...@@ -120,11 +120,11 @@ constexpr auto conditional_expr(X&& x, Y&& y)
{ {
if constexpr(predicate) if constexpr(predicate)
{ {
return std::forward<X>(x); return ck::forward<X>(x);
} }
else else
{ {
return std::forward<Y>(y); return ck::forward<Y>(y);
} }
} }
......
...@@ -21,7 +21,7 @@ struct unpack_impl<Sequence<Is...>> ...@@ -21,7 +21,7 @@ struct unpack_impl<Sequence<Is...>>
template <typename F, typename X> template <typename F, typename X>
__host__ __device__ constexpr auto operator()(F&& f, X&& x) const __host__ __device__ constexpr auto operator()(F&& f, X&& x) const
{ {
return std::forward<F>(f)(std::forward<X>(x).At(Number<Is>{})...); return ck::forward<F>(f)(ck::forward<X>(x).At(Number<Is>{})...);
} }
}; };
...@@ -35,8 +35,8 @@ struct unpack2_impl<Sequence<Is...>, Sequence<Js...>> ...@@ -35,8 +35,8 @@ struct unpack2_impl<Sequence<Is...>, Sequence<Js...>>
template <typename F, typename X, typename Y> template <typename F, typename X, typename Y>
__host__ __device__ constexpr auto operator()(F&& f, X&& x, Y&& y) const __host__ __device__ constexpr auto operator()(F&& f, X&& x, Y&& y) const
{ {
return std::forward<F>(f)(std::forward<X>(x).At(Number<Is>{})..., return ck::forward<F>(f)(ck::forward<X>(x).At(Number<Is>{})...,
std::forward<Y>(y).At(Number<Js>{})...); ck::forward<Y>(y).At(Number<Js>{})...);
} }
}; };
...@@ -47,7 +47,7 @@ __host__ __device__ constexpr auto unpack(F&& f, X&& x) ...@@ -47,7 +47,7 @@ __host__ __device__ constexpr auto unpack(F&& f, X&& x)
{ {
using X_ = remove_reference_t<X>; using X_ = remove_reference_t<X>;
return detail::unpack_impl<typename arithmetic_sequence_gen<0, X_::Size(), 1>::type>{}( return detail::unpack_impl<typename arithmetic_sequence_gen<0, X_::Size(), 1>::type>{}(
std::forward<F>(f), std::forward<X>(x)); ck::forward<F>(f), ck::forward<X>(x));
} }
// TODO: properly implement unpack that takes any number of containers // TODO: properly implement unpack that takes any number of containers
...@@ -58,7 +58,7 @@ __host__ __device__ constexpr auto unpack2(F&& f, X&& x, Y&& y) ...@@ -58,7 +58,7 @@ __host__ __device__ constexpr auto unpack2(F&& f, X&& x, Y&& y)
using Y_ = remove_reference_t<Y>; using Y_ = remove_reference_t<Y>;
return detail::unpack2_impl<typename arithmetic_sequence_gen<0, X_::Size(), 1>::type, return detail::unpack2_impl<typename arithmetic_sequence_gen<0, X_::Size(), 1>::type,
typename arithmetic_sequence_gen<0, Y_::Size(), 1>::type>{}( typename arithmetic_sequence_gen<0, Y_::Size(), 1>::type>{}(
std::forward<F>(f), std::forward<X>(x), std::forward<Y>(y)); ck::forward<F>(f), ck::forward<X>(x), ck::forward<Y>(y));
} }
} // namespace ck } // namespace ck
......
...@@ -9,14 +9,14 @@ namespace detail { ...@@ -9,14 +9,14 @@ namespace detail {
template <class Default, class AlwaysVoid, template <class...> class Op, class... Args> template <class Default, class AlwaysVoid, template <class...> class Op, class... Args>
struct detector struct detector
{ {
using value_t = std::false_type; using value_t = ck::false_type;
using type = Default; using type = Default;
}; };
template <class Default, template <class...> class Op, class... Args> template <class Default, template <class...> class Op, class... Args>
struct detector<Default, std::void_t<Op<Args...>>, Op, Args...> struct detector<Default, ck::void_t<Op<Args...>>, Op, Args...>
{ {
using value_t = std::true_type; using value_t = ck::true_type;
using type = Op<Args...>; using type = Op<Args...>;
}; };
} // namespace detail } // namespace detail
...@@ -32,12 +32,12 @@ template <template <class...> class Op, class... Args> ...@@ -32,12 +32,12 @@ template <template <class...> class Op, class... Args>
using is_detected = typename detail::detector<nonesuch, void, Op, Args...>::value_t; using is_detected = typename detail::detector<nonesuch, void, Op, Args...>::value_t;
template <typename T> template <typename T>
using is_pack2_invocable_t = decltype(std::declval<T&>().is_pack2_invocable); using is_pack2_invocable_t = decltype(ck::declval<T&>().is_pack2_invocable);
template <typename T> template <typename T>
using is_pack4_invocable_t = decltype(std::declval<T&>().is_pack4_invocable); using is_pack4_invocable_t = decltype(ck::declval<T&>().is_pack4_invocable);
template <typename T> template <typename T>
using is_pack8_invocable_t = decltype(std::declval<T&>().is_pack8_invocable); using is_pack8_invocable_t = decltype(ck::declval<T&>().is_pack8_invocable);
} // namespace ck } // namespace ck
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ostream>
#pragma once #pragma once
#ifndef __HIPCC_RTC__
#include <ostream>
#endif
#include "ck/utility/common_header.hpp" #include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp" #include "ck/tensor_description/tensor_adaptor.hpp"
...@@ -26,6 +28,7 @@ constexpr LoopScheduler make_default_loop_scheduler() ...@@ -26,6 +28,7 @@ constexpr LoopScheduler make_default_loop_scheduler()
} // namespace ck } // namespace ck
#ifndef __HIPCC_RTC__
inline std::ostream& operator<<(std::ostream& os, const ck::LoopScheduler& s) inline std::ostream& operator<<(std::ostream& os, const ck::LoopScheduler& s)
{ {
switch(s) switch(s)
...@@ -36,3 +39,4 @@ inline std::ostream& operator<<(std::ostream& os, const ck::LoopScheduler& s) ...@@ -36,3 +39,4 @@ inline std::ostream& operator<<(std::ostream& os, const ck::LoopScheduler& s)
} }
return os; return os;
} }
#endif
...@@ -30,7 +30,7 @@ struct MagicDivision ...@@ -30,7 +30,7 @@ struct MagicDivision
// WARNING: magic division is only applicable for division inside this range. // WARNING: magic division is only applicable for division inside this range.
// You should use the return value of CalculateMagicNumbers, if division is not inside this // You should use the return value of CalculateMagicNumbers, if division is not inside this
// range. The "else" logic below is to quiet down run-time error. // range. The "else" logic below is to quiet down run-time error.
if(divisor >= 1 && divisor <= INT32_MAX) if(divisor >= 1 && divisor <= ck::NumericLimits<int32_t>::Max())
{ {
uint32_t shift = 0; uint32_t shift = 0;
for(shift = 0; shift < 32; ++shift) for(shift = 0; shift < 32; ++shift)
......
...@@ -18,6 +18,7 @@ namespace math { ...@@ -18,6 +18,7 @@ namespace math {
extern "C" __device__ float __ocml_native_recip_f32(float); extern "C" __device__ float __ocml_native_recip_f32(float);
#endif #endif
#ifndef __HIPCC_RTC__
// math functions for the host, some are implemented by calling C++ std functions // math functions for the host, some are implemented by calling C++ std functions
static inline __host__ float abs(float x) { return std::abs(x); }; static inline __host__ float abs(float x) { return std::abs(x); };
...@@ -457,6 +458,7 @@ inline __host__ double expm1<double>(double x) ...@@ -457,6 +458,7 @@ inline __host__ double expm1<double>(double x)
{ {
return std::expm1(x); return std::expm1(x);
} }
#endif
// math functions for the HIP kernel, some are implemented by calling hip builtin functions // math functions for the HIP kernel, some are implemented by calling hip builtin functions
...@@ -920,5 +922,23 @@ inline __device__ double expm1<double>(double x) ...@@ -920,5 +922,23 @@ inline __device__ double expm1<double>(double x)
return expm1(x); return expm1(x);
}; };
template <typename T>
inline __device__ T cos(T x)
{
return ck::type_convert<T>(cosf(ck::type_convert<float>(x)));
};
template <>
inline __device__ float cos<float>(float x)
{
return cosf(x);
};
template <>
inline __device__ double cos<double>(double x)
{
return cos(x);
};
} // namespace math } // namespace math
} // namespace ck } // namespace ck
...@@ -7,7 +7,7 @@ namespace ck { ...@@ -7,7 +7,7 @@ namespace ck {
// Pseudo random number generator // Pseudo random number generator
// version for fp32 // version for fp32
template <typename T, uint32_t seed_t, std::enable_if_t<std::is_same<float, T>{}, bool> = false> template <typename T, uint32_t seed_t, ck::enable_if_t<ck::is_same<float, T>{}, bool> = false>
__host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t) __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t)
{ {
uint32_t x = *(reinterpret_cast<uint32_t*>(&val)); uint32_t x = *(reinterpret_cast<uint32_t*>(&val));
...@@ -23,7 +23,7 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = ...@@ -23,7 +23,7 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed =
} }
// version for fp16 // version for fp16
template <typename T, uint32_t seed_t, std::enable_if_t<std::is_same<half_t, T>{}, bool> = false> template <typename T, uint32_t seed_t, ck::enable_if_t<ck::is_same<half_t, T>{}, bool> = false>
__host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t) __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t)
{ {
uint16_t x = *(reinterpret_cast<uint16_t*>(&val)); uint16_t x = *(reinterpret_cast<uint16_t*>(&val));
...@@ -40,12 +40,18 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = ...@@ -40,12 +40,18 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed =
// return 0 if data is not fp16 or fp32 // return 0 if data is not fp16 or fp32
template <typename T, template <typename T,
uint32_t seed_t, uint32_t seed_t,
std::enable_if_t<!(std::is_same<float, T>{} || std::is_same<half_t, T>{}), bool> = false> ck::enable_if_t<!(ck::is_same<float, T>{} || ck::is_same<half_t, T>{}), bool> = false>
__host__ __device__ uint32_t prand_generator(int id, T val, uint32_t seed = seed_t) __host__ __device__ uint32_t prand_generator(int id, T val, uint32_t seed = seed_t)
{ {
#ifdef __HIPCC_RTC__
static_cast<void>(id);
static_cast<void>(val);
static_cast<void>(seed);
#else
std::ignore = id; std::ignore = id;
std::ignore = val; std::ignore = val;
std::ignore = seed; std::ignore = seed;
#endif
return 0; return 0;
} }
......
...@@ -3,7 +3,9 @@ ...@@ -3,7 +3,9 @@
#pragma once #pragma once
#ifndef __HIPCC_RTC__
#include <ostream> #include <ostream>
#endif
#include "ck/utility/integral_constant.hpp" #include "ck/utility/integral_constant.hpp"
#include "ck/utility/type.hpp" #include "ck/utility/type.hpp"
...@@ -900,6 +902,7 @@ using uniform_sequence_gen_t = typename uniform_sequence_gen<NSize, I>::type; ...@@ -900,6 +902,7 @@ using uniform_sequence_gen_t = typename uniform_sequence_gen<NSize, I>::type;
} // namespace ck } // namespace ck
#ifndef __HIPCC_RTC__
template <ck::index_t... Is> template <ck::index_t... Is>
std::ostream& operator<<(std::ostream& os, const ck::Sequence<Is...>) std::ostream& operator<<(std::ostream& os, const ck::Sequence<Is...>)
{ {
...@@ -910,3 +913,4 @@ std::ostream& operator<<(std::ostream& os, const ck::Sequence<Is...>) ...@@ -910,3 +913,4 @@ std::ostream& operator<<(std::ostream& os, const ck::Sequence<Is...>)
os << S::At(S::Size() - ck::Number<1>{}).value << "}"; os << S::At(S::Size() - ck::Number<1>{}).value << "}";
return os; return os;
} }
#endif
...@@ -32,7 +32,7 @@ struct TupleElementKeyData ...@@ -32,7 +32,7 @@ struct TupleElementKeyData
template <typename T, template <typename T,
typename enable_if<!is_same<remove_cvref_t<T>, TupleElementKeyData>::value, typename enable_if<!is_same<remove_cvref_t<T>, TupleElementKeyData>::value,
bool>::type = false> bool>::type = false>
__host__ __device__ constexpr TupleElementKeyData(T&& v) : mData(std::forward<T>(v)) __host__ __device__ constexpr TupleElementKeyData(T&& v) : mData(ck::forward<T>(v))
{ {
} }
...@@ -67,7 +67,7 @@ get_tuple_element_data_reference(TupleElementKeyData<Key, Data>&& x) ...@@ -67,7 +67,7 @@ get_tuple_element_data_reference(TupleElementKeyData<Key, Data>&& x)
template <typename Key, typename Data> template <typename Key, typename Data>
__host__ __device__ constexpr Data get_tuple_element_data(const TupleElementKeyData<Key, Data>& x) __host__ __device__ constexpr Data get_tuple_element_data(const TupleElementKeyData<Key, Data>& x)
{ {
return std::forward(x.mData); return ck::forward(x.mData);
} }
template <typename Indices, typename... Xs> template <typename Indices, typename... Xs>
...@@ -83,13 +83,13 @@ struct TupleImpl<Sequence<Is...>, Xs...> : TupleElementKeyData<TupleElementKey<I ...@@ -83,13 +83,13 @@ struct TupleImpl<Sequence<Is...>, Xs...> : TupleElementKeyData<TupleElementKey<I
!is_same<remove_cvref_t<Y>, TupleImpl>::value, !is_same<remove_cvref_t<Y>, TupleImpl>::value,
bool>::type = false> bool>::type = false>
__host__ __device__ constexpr TupleImpl(Y&& y) __host__ __device__ constexpr TupleImpl(Y&& y)
: TupleElementKeyData<TupleElementKey<Is>, Xs>(std::forward<Y>(y))... : TupleElementKeyData<TupleElementKey<Is>, Xs>(ck::forward<Y>(y))...
{ {
} }
template <typename... Ys, typename enable_if<sizeof...(Ys) >= 2, bool>::type = false> template <typename... Ys, typename enable_if<sizeof...(Ys) >= 2, bool>::type = false>
__host__ __device__ constexpr TupleImpl(Ys&&... ys) __host__ __device__ constexpr TupleImpl(Ys&&... ys)
: TupleElementKeyData<TupleElementKey<Is>, Xs>(std::forward<Ys>(ys))... : TupleElementKeyData<TupleElementKey<Is>, Xs>(ck::forward<Ys>(ys))...
{ {
static_assert(sizeof...(Is) == sizeof...(Xs) && sizeof...(Is) == sizeof...(Ys), static_assert(sizeof...(Is) == sizeof...(Xs) && sizeof...(Is) == sizeof...(Ys),
"wrong! inconsistent size"); "wrong! inconsistent size");
...@@ -123,14 +123,14 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X ...@@ -123,14 +123,14 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
template <typename Y, template <typename Y,
typename enable_if<sizeof...(Xs) == 1 && !is_same<remove_cvref_t<Y>, Tuple>::value, typename enable_if<sizeof...(Xs) == 1 && !is_same<remove_cvref_t<Y>, Tuple>::value,
bool>::type = false> bool>::type = false>
__host__ __device__ constexpr Tuple(Y&& y) : base(std::forward<Y>(y)) __host__ __device__ constexpr Tuple(Y&& y) : base(ck::forward<Y>(y))
{ {
} }
template <typename... Ys, template <typename... Ys,
typename enable_if<sizeof...(Ys) == sizeof...(Xs) && sizeof...(Ys) >= 2, bool>::type = typename enable_if<sizeof...(Ys) == sizeof...(Xs) && sizeof...(Ys) >= 2, bool>::type =
false> false>
__host__ __device__ constexpr Tuple(Ys&&... ys) : base(std::forward<Ys>(ys)...) __host__ __device__ constexpr Tuple(Ys&&... ys) : base(ck::forward<Ys>(ys)...)
{ {
} }
...@@ -210,7 +210,7 @@ using tuple_element_t = typename tuple_element<I, TTuple>::type; ...@@ -210,7 +210,7 @@ using tuple_element_t = typename tuple_element<I, TTuple>::type;
template <typename... Xs> template <typename... Xs>
__host__ __device__ constexpr auto make_tuple(Xs&&... xs) __host__ __device__ constexpr auto make_tuple(Xs&&... xs)
{ {
return Tuple<remove_cvref_t<Xs>...>(std::forward<Xs>(xs)...); return Tuple<remove_cvref_t<Xs>...>(ck::forward<Xs>(xs)...);
} }
// https://en.cppreference.com/w/cpp/utility/tuple/tie // https://en.cppreference.com/w/cpp/utility/tuple/tie
......
...@@ -29,7 +29,7 @@ __host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple<X&...>& ...@@ -29,7 +29,7 @@ __host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple<X&...>&
const Tuple<Y&...>& ty) const Tuple<Y&...>& ty)
{ {
return unpack2( return unpack2(
[&](auto&&... zs) { return Tuple<decltype(zs)...>{std::forward<decltype(zs)>(zs)...}; }, [&](auto&&... zs) { return Tuple<decltype(zs)...>{ck::forward<decltype(zs)>(zs)...}; },
tx, tx,
ty); ty);
} }
...@@ -38,7 +38,7 @@ template <typename... X, typename... Y> ...@@ -38,7 +38,7 @@ template <typename... X, typename... Y>
__host__ __device__ constexpr auto concat_tuple(const Tuple<X...>& tx, const Tuple<Y...>& ty) __host__ __device__ constexpr auto concat_tuple(const Tuple<X...>& tx, const Tuple<Y...>& ty)
{ {
return unpack2( return unpack2(
[&](auto... zs) { return Tuple<decltype(zs)...>{std::forward<decltype(zs)>(zs)...}; }, [&](auto... zs) { return Tuple<decltype(zs)...>{ck::forward<decltype(zs)>(zs)...}; },
tx, tx,
ty); ty);
} }
...@@ -157,6 +157,7 @@ __host__ __device__ constexpr auto TupleReduce(F&& f, const Tuple<Ts...>& tuple) ...@@ -157,6 +157,7 @@ __host__ __device__ constexpr auto TupleReduce(F&& f, const Tuple<Ts...>& tuple)
} }
} }
#ifndef __HIPCC_RTC__
template <typename T> template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple()); using is_tuple = decltype(std::declval<T&>().IsTuple());
...@@ -165,6 +166,7 @@ __host__ __device__ constexpr auto IsNestedTuple(const Tuple<Ts...>&) ...@@ -165,6 +166,7 @@ __host__ __device__ constexpr auto IsNestedTuple(const Tuple<Ts...>&)
{ {
return (is_detected<is_tuple, Ts>::value || ...); return (is_detected<is_tuple, Ts>::value || ...);
} }
#endif
template <index_t depth = 0, typename T> template <index_t depth = 0, typename T>
__host__ __device__ constexpr auto TupleDepth(const T&) __host__ __device__ constexpr auto TupleDepth(const T&)
......
...@@ -8,6 +8,158 @@ ...@@ -8,6 +8,158 @@
#include "ck/utility/enable_if.hpp" #include "ck/utility/enable_if.hpp"
namespace ck { namespace ck {
#ifdef __HIPCC_RTC__
template <bool B>
using bool_constant = integral_constant<bool, B>;
using true_type = bool_constant<true>;
using false_type = bool_constant<false>;
// NOLINTNEXTLINE
#define CK_BUILTIN_TYPE_TRAIT1(name) \
template <class T> \
struct name : bool_constant<__##name(T)> \
{ \
}
// NOLINTNEXTLINE
#define CK_BUILTIN_TYPE_TRAIT2(name) \
template <class T, class U> \
struct name : bool_constant<__##name(T, U)> \
{ \
}
// NOLINTNEXTLINE
#define CK_BUILTIN_TYPE_TRAITN(name) \
template <class... Ts> \
struct name : bool_constant<__##name(Ts...)> \
{ \
}
CK_BUILTIN_TYPE_TRAIT1(is_class);
CK_BUILTIN_TYPE_TRAIT1(is_pointer);
CK_BUILTIN_TYPE_TRAIT1(is_reference);
CK_BUILTIN_TYPE_TRAIT1(is_trivially_copyable);
CK_BUILTIN_TYPE_TRAIT1(is_unsigned);
CK_BUILTIN_TYPE_TRAIT2(is_base_of);
template <class T>
struct remove_cv
{
using type = T;
};
template <class T>
struct remove_cv<const T> : remove_cv<T>
{
};
template <class T>
struct remove_cv<volatile T> : remove_cv<T>
{
};
template <class T>
struct remove_reference
{
typedef T type;
};
template <class T>
struct remove_reference<T&>
{
typedef T type;
};
template <class T>
struct remove_reference<T&&>
{
typedef T type;
};
template <class T>
struct remove_pointer
{
typedef T type;
};
template <class T>
struct remove_pointer<T*>
{
typedef T type;
};
template <class T>
struct remove_pointer<T* const>
{
typedef T type;
};
template <class T>
struct remove_pointer<T* volatile>
{
typedef T type;
};
template <class T>
struct remove_pointer<T* const volatile>
{
typedef T type;
};
template <typename T>
constexpr T&& forward(typename remove_reference<T>::type& t_) noexcept
{
return static_cast<T&&>(t_);
}
template <typename T>
constexpr T&& forward(typename remove_reference<T>::type&& t_) noexcept
{
return static_cast<T&&>(t_);
}
// TODO
template<class T> struct is_const : false_type {};
template<class T> struct is_const<const T> : true_type {};
template< class T >
inline constexpr bool is_const_v = is_const<T>::value;
template< class T >
inline constexpr bool is_reference_v = is_reference<T>::value;
template<class T> struct remove_const { typedef T type; };
template<class T> struct remove_const<const T> { typedef T type; };
template< class T >
using remove_const_t = typename remove_const<T>::type;
template< class T >
inline constexpr bool is_class_v = is_class<T>::value;
template< class T >
inline constexpr bool is_trivially_copyable_v = is_trivially_copyable<T>::value;
template< class... >
using void_t = void;
using __hip::declval;
#else
#include <utility>
#include <type_traits>
using std::forward;
using std::is_base_of;
using std::is_class;
using std::is_pointer;
using std::is_reference;
using std::is_trivially_copyable;
using std::is_unsigned;
using std::remove_cv;
using std::remove_pointer;
using std::remove_reference;
using std::is_const_v;
using std::is_reference_v;
using std::remove_const_t;
using std::is_class_v;
using std::is_trivially_copyable_v;
using std::void_t;
using std::false_type;
using std::true_type;
using std::declval;
#endif
template <typename X, typename Y> template <typename X, typename Y>
struct is_same : public integral_constant<bool, false> struct is_same : public integral_constant<bool, false>
...@@ -23,19 +175,19 @@ template <typename X, typename Y> ...@@ -23,19 +175,19 @@ template <typename X, typename Y>
inline constexpr bool is_same_v = is_same<X, Y>::value; inline constexpr bool is_same_v = is_same<X, Y>::value;
template <typename T> template <typename T>
using remove_reference_t = typename std::remove_reference<T>::type; using remove_reference_t = typename remove_reference<T>::type;
template <typename T> template <typename T>
using remove_cv_t = typename std::remove_cv<T>::type; using remove_cv_t = typename remove_cv<T>::type;
template <typename T> template <typename T>
using remove_cvref_t = remove_cv_t<std::remove_reference_t<T>>; using remove_cvref_t = remove_cv_t<remove_reference_t<T>>;
template <typename T> template <typename T>
using remove_pointer_t = typename std::remove_pointer<T>::type; using remove_pointer_t = typename remove_pointer<T>::type;
template <typename T> template <typename T>
inline constexpr bool is_pointer_v = std::is_pointer<T>::value; inline constexpr bool is_pointer_v = is_pointer<T>::value;
template <typename Y, typename X, typename enable_if<sizeof(X) == sizeof(Y), bool>::type = false> template <typename Y, typename X, typename enable_if<sizeof(X) == sizeof(Y), bool>::type = false>
__host__ __device__ constexpr Y bit_cast(const X& x) __host__ __device__ constexpr Y bit_cast(const X& x)
......
...@@ -17,10 +17,10 @@ namespace ck { ...@@ -17,10 +17,10 @@ namespace ck {
// Convert X to Y, both X and Y are non-const data types. // Convert X to Y, both X and Y are non-const data types.
template <typename Y, template <typename Y,
typename X, typename X,
std::enable_if_t<!(std::is_const_v<Y> || std::is_const_v<X>), bool> = false> ck::enable_if_t<!(ck::is_const_v<Y> || ck::is_const_v<X>), bool> = false>
__host__ __device__ constexpr Y type_convert(X x) __host__ __device__ constexpr Y type_convert(X x)
{ {
static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>); static_assert(!ck::is_reference_v<Y> && !ck::is_reference_v<X>);
return static_cast<Y>(x); return static_cast<Y>(x);
} }
...@@ -28,13 +28,13 @@ __host__ __device__ constexpr Y type_convert(X x) ...@@ -28,13 +28,13 @@ __host__ __device__ constexpr Y type_convert(X x)
// Convert X to Y, either X or Y is a const data type. // Convert X to Y, either X or Y is a const data type.
template <typename Y, template <typename Y,
typename X, typename X,
std::enable_if_t<std::is_const_v<Y> || std::is_const_v<X>, bool> = false> ck::enable_if_t<ck::is_const_v<Y> || ck::is_const_v<X>, bool> = false>
__host__ __device__ constexpr Y type_convert(X x) __host__ __device__ constexpr Y type_convert(X x)
{ {
static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>); static_assert(!ck::is_reference_v<Y> && !ck::is_reference_v<X>);
using NonConstY = std::remove_const_t<Y>; using NonConstY = ck::remove_const_t<Y>;
using NonConstX = std::remove_const_t<X>; using NonConstX = ck::remove_const_t<X>;
return static_cast<Y>(type_convert<NonConstY, NonConstX>(x)); return static_cast<Y>(type_convert<NonConstY, NonConstX>(x));
} }
...@@ -104,7 +104,7 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_ ...@@ -104,7 +104,7 @@ inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_
template <typename Y, typename X> template <typename Y, typename X>
__host__ __device__ constexpr Y type_convert_sp(X x) __host__ __device__ constexpr Y type_convert_sp(X x)
{ {
static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>); static_assert(!ck::is_reference_v<Y> && !ck::is_reference_v<X>);
return static_cast<Y>(x); return static_cast<Y>(x);
} }
...@@ -166,7 +166,7 @@ template <> ...@@ -166,7 +166,7 @@ template <>
inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x) inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
{ {
constexpr int seed = 1254739; constexpr int seed = 1254739;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x); uint32_t rng = prand_generator<float, seed>(reinterpret_cast<long_index_t>(&x), x);
#if defined(__gfx94__) #if defined(__gfx94__)
union union
{ {
...@@ -206,7 +206,7 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x) ...@@ -206,7 +206,7 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
constexpr bool clip = true; constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic; constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
constexpr int seed = 1254739; constexpr int seed = 1254739;
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x); uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<long_index_t>(&x), x);
return utils:: return utils::
cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>( cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng); x, rng);
...@@ -218,7 +218,7 @@ template <> ...@@ -218,7 +218,7 @@ template <>
inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x) inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x)
{ {
constexpr int seed = 1254739; constexpr int seed = 1254739;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x); uint32_t rng = prand_generator<float, seed>(reinterpret_cast<long_index_t>(&x), x);
#if defined(__gfx94__) #if defined(__gfx94__)
union union
{ {
...@@ -258,7 +258,7 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x) ...@@ -258,7 +258,7 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
constexpr bool clip = true; constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic; constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
constexpr int seed = 1254739; constexpr int seed = 1254739;
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x); uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<long_index_t>(&x), x);
return utils:: return utils::
cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>( cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng); x, rng);
...@@ -501,6 +501,7 @@ inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x) ...@@ -501,6 +501,7 @@ inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x)
#endif #endif
} }
#ifndef __HIPCC_RTC__
template <typename Y, typename X, std::size_t NumElems> template <typename Y, typename X, std::size_t NumElems>
inline __host__ __device__ void array_convert(std::array<Y, NumElems>& y, inline __host__ __device__ void array_convert(std::array<Y, NumElems>& y,
const std::array<X, NumElems>& x) const std::array<X, NumElems>& x)
...@@ -510,6 +511,7 @@ inline __host__ __device__ void array_convert(std::array<Y, NumElems>& y, ...@@ -510,6 +511,7 @@ inline __host__ __device__ void array_convert(std::array<Y, NumElems>& y,
y[i] = type_convert<Y>(x[i]); y[i] = type_convert<Y>(x[i]);
} }
} }
#endif
template <typename Y, typename X, index_t NumElems> template <typename Y, typename X, index_t NumElems>
inline __host__ __device__ void array_convert(Array<Y, NumElems>& y, const Array<X, NumElems>& x) inline __host__ __device__ void array_convert(Array<Y, NumElems>& y, const Array<X, NumElems>& x)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment