"...git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "d71edf6c771c3bb0ac51a52b4987843a06fb6bee"
Commit b7ab5e92 authored by Umang Yadav's avatar Umang Yadav
Browse files

merge latest develop into migraphx

parent 3c4fb1dd
......@@ -209,10 +209,10 @@ struct Bilinear
};
template <>
__host__ __device__ constexpr void operator()<std::int8_t, std::int32_t, std::int8_t>(
std::int8_t& y, const std::int32_t& x0, const std::int8_t& x1) const
__host__ __device__ constexpr void
operator()<int8_t, int32_t, int8_t>(int8_t& y, const int32_t& x0, const int8_t& x1) const
{
y = type_convert<std::int8_t>(x0 + ck::type_convert<std::int32_t>(x1));
y = type_convert<int8_t>(x0 + ck::type_convert<int32_t>(x1));
};
float alpha_;
......
......@@ -411,9 +411,9 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
template <typename DsLayout, GemmSpecialization GemmSpec>
__host__ __device__ static auto
MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
const std::array<index_t, NumDTensor>& NRaws,
const std::array<index_t, NumDTensor>& DsStride)
MakeDsGridDescriptor_M_N(const ck::Array<index_t, NumDTensor>& MRaws,
const ck::Array<index_t, NumDTensor>& NRaws,
const ck::Array<index_t, NumDTensor>& DsStride)
{
return generate_tuple(
[&](auto i) {
......@@ -877,7 +877,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
const index_t K,
const index_t StrideA,
const index_t StrideB,
const std::array<index_t, NumDTensor> StrideDs,
const ck::Array<index_t, NumDTensor> StrideDs,
const index_t StrideE,
const Block2ETileMap& block_2_etile_map)
{
......
......@@ -2,21 +2,22 @@
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/integral_constant.hpp"
#include "ck/utility/type.hpp"
namespace ck {
namespace detail {
template <class Default, class AlwaysVoid, template <class...> class Op, class... Args>
struct detector
{
using value_t = std::false_type;
using value_t = ck::false_type;
using type = Default;
};
template <class Default, template <class...> class Op, class... Args>
struct detector<Default, std::void_t<Op<Args...>>, Op, Args...>
struct detector<Default, ck::void_t<Op<Args...>>, Op, Args...>
{
using value_t = std::true_type;
using value_t = ck::true_type;
using type = Op<Args...>;
};
} // namespace detail
......@@ -32,12 +33,12 @@ template <template <class...> class Op, class... Args>
using is_detected = typename detail::detector<nonesuch, void, Op, Args...>::value_t;
template <typename T>
using is_pack2_invocable_t = decltype(std::declval<T&>().is_pack2_invocable);
using is_pack2_invocable_t = decltype(ck::declval<T&>().is_pack2_invocable);
template <typename T>
using is_pack4_invocable_t = decltype(std::declval<T&>().is_pack4_invocable);
using is_pack4_invocable_t = decltype(ck::declval<T&>().is_pack4_invocable);
template <typename T>
using is_pack8_invocable_t = decltype(std::declval<T&>().is_pack8_invocable);
using is_pack8_invocable_t = decltype(ck::declval<T&>().is_pack8_invocable);
} // namespace ck
......@@ -184,6 +184,7 @@ inline __host__ double expm1<double>(double x)
{
return std::expm1(x);
}
#endif // __HIPCC_RTC__
// math functions for the HIP kernel, some are implemented by calling hip builtin functions
......
......@@ -31,12 +31,24 @@ namespace ck {
}
CK_BUILTIN_TYPE_TRAIT1(is_class);
CK_BUILTIN_TYPE_TRAIT1(is_const);
CK_BUILTIN_TYPE_TRAIT1(is_pointer);
CK_BUILTIN_TYPE_TRAIT1(is_reference);
CK_BUILTIN_TYPE_TRAIT1(is_trivially_copyable);
CK_BUILTIN_TYPE_TRAIT1(is_unsigned);
CK_BUILTIN_TYPE_TRAIT2(is_base_of);
template <class T>
struct remove_const
{
typedef T type;
};
template <class T>
struct remove_const<const T>
{
typedef T type;
};
template <class T>
struct remove_cv
{
......@@ -106,19 +118,71 @@ constexpr T&& forward(typename remove_reference<T>::type&& t_) noexcept
{
return static_cast<T&&>(t_);
}
template <typename... Ts>
struct make_void
{
typedef void type;
};
template <typename... Ts>
using void_t = typename make_void<Ts...>::type;
// namespace detail {
// template <class T>
// struct type_identity
// {
// using type = T;
// };
// template <class T> // Note that `cv void&` is a substitution failure
// auto try_add_lvalue_reference(int) -> type_identity<T&>;
// template <class T> // Handle T = cv void case
// auto try_add_lvalue_reference(...) -> type_identity<T>;
// template <class T>
// auto try_add_rvalue_reference(int) -> type_identity<T&&>;
// template <class T>
// auto try_add_rvalue_reference(...) -> type_identity<T>;
// } // namespace detail
// template <class T>
// struct add_lvalue_reference : decltype(detail::try_add_lvalue_reference<T>(0))
// {
// };
// template <class T>
// struct add_rvalue_reference : decltype(detail::try_add_rvalue_reference<T>(0))
// {
// };
// template <class T>
// typename add_rvalue_reference<T>::type declval();
template <class T, class U = T&&>
U private_declval(int);
template <class T>
T private_declval(long);
template <class T>
auto declval() noexcept -> decltype(private_declval<T>(0));
#else
#include <utility>
#include <type_traits>
using std::declval;
using std::forward;
using std::is_base_of;
using std::is_class;
using std::is_const;
using std::is_pointer;
using std::is_reference;
using std::is_trivially_copyable;
using std::is_unsigned;
using std::remove_const;
using std::remove_cv;
using std::remove_pointer;
using std::remove_reference;
using std::void_t;
#endif
template <typename X, typename Y>
......@@ -140,9 +204,15 @@ inline constexpr bool is_same_v = is_same<X, Y>::value;
template <typename X, typename Y>
inline constexpr bool is_base_of_v = is_base_of<X, Y>::value;
template <typename T>
inline constexpr bool is_const_v = is_const<T>::value;
template <typename T>
inline constexpr bool is_unsigned_v = is_unsigned<T>::value;
template <class T>
using remove_const_t = typename remove_const<T>::type;
template <typename T>
using remove_reference_t = typename remove_reference<T>::type;
......
......@@ -4,6 +4,8 @@
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/type.hpp"
#include "ck/utility/enable_if.hpp"
#include "ck/utility/f8_utils.hpp"
#include "ck/utility/random_gen.hpp"
......@@ -23,7 +25,7 @@ __host__ __device__ constexpr Y type_convert(X x)
// Convert X to Y, either X or Y is a const data type.
template <typename Y,
typename X,
std::enable_if_t<ck::is_const_v<Y> || ck::is_const_v<X>, bool> = false>
ck::enable_if_t<ck::is_const_v<Y> || ck::is_const_v<X>, bool> = false>
__host__ __device__ constexpr Y type_convert(X x)
{
static_assert(!ck::is_reference_v<Y> && !ck::is_reference_v<X>);
......@@ -341,7 +343,7 @@ template <>
inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
{
constexpr int seed = 42;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uint64_t>(&x), x);
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
union
{
......@@ -376,7 +378,7 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
constexpr bool clip = true;
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
constexpr int seed = 42;
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uint64_t>(&x), x);
return utils::
cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
......@@ -388,7 +390,7 @@ template <>
inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x)
{
constexpr int seed = 42;
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uint64_t>(&x), x);
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
union
{
......@@ -424,7 +426,7 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic;
constexpr int seed = 42;
// as thread id is not available on host, use 0 for prn generation
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uint64_t>(&x), x);
return utils::
cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(
x, rng);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment