"sgl-kernel/csrc/vscode:/vscode.git/clone" did not exist on "7a4309cc8a56e7a2cffba82a5189b51fd5776259"
Commit 6c7ba556 authored by Umang Yadav's avatar Umang Yadav
Browse files

add type traits for hiprtc

parent d7173bc6
......@@ -38,6 +38,7 @@ tags
# build-in-source directory
build*
.cache/*
# emacs temporary/backup files
.\#*
......
......@@ -6,10 +6,8 @@
#include "ck/ck.hpp"
#include "ck/utility/functional2.hpp"
#include "ck/utility/math.hpp"
#include <cstddef>
#include <cstdint>
#include <type_traits>
#include "ck/utility/type.hpp"
#include "ck/utility/data_type.hpp"
namespace ck {
namespace detail {
......@@ -47,15 +45,15 @@ __device__ inline int32_t amd_wave_read_first_lane(int32_t value)
template <
typename Object,
typename = std::enable_if_t<std::is_class_v<Object> && std::is_trivially_copyable_v<Object>>>
typename = std::enable_if_t<std::is_class<Object>::value && std::is_trivially_copyable<Object>::value>>
__device__ auto amd_wave_read_first_lane(const Object& obj)
{
using Size = unsigned;
constexpr Size SgprSize = 4;
constexpr Size ObjectSize = sizeof(Object);
auto* const from_obj = reinterpret_cast<const std::byte*>(&obj);
alignas(Object) std::byte to_obj[ObjectSize];
auto* const from_obj = reinterpret_cast<const byte*>(&obj);
alignas(Object) byte to_obj[ObjectSize];
constexpr Size RemainedSize = ObjectSize % SgprSize;
constexpr Size CompleteSgprCopyBoundary = ObjectSize - RemainedSize;
......
......@@ -3,8 +3,30 @@
#pragma once
#include "ck/utility/number.hpp"
#include "ck/utility/statically_indexed_array.hpp"
#include "ck/utility/type.hpp"
#ifdef __HIPCC_RTC__
#ifdef WORKAROUND_ISSUE_HIPRTC_TRUE_TYPE
/// Definitions from <cstdint>, <cmath> conflict with
/// /opt/rocm/include/hip/amd_detail/amd_hip_vector_types.h.
using int8_t = signed char;
using uint8_t = unsigned char;
using int16_t = signed short;
using uint16_t = unsigned short;
using int32_t = signed int;
using uint32_t = unsigned int;
using int64_t = signed long long;
using uint64_t = unsigned long long;
using byte = unsigned char;
#include <limits> // std::numeric_limits
#else
#include <cstdint> // int8_t, int16_t
#include <cstddef>
#include <cmath> // float_t
#endif
#endif // __HIPCC_RTC__
namespace ck {
using bhalf_t = ushort;
......
......@@ -3,7 +3,7 @@
#ifndef UTILITY_DEBUG_HPP
#define UTILITY_DEBUG_HPP
#include "type.hpp"
namespace ck {
namespace debug {
......
......@@ -2,6 +2,20 @@
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#ifdef __HIPCC_RTC__
namespace std {
template <bool B, class T = void>
struct enable_if
{
};
template <class T>
struct enable_if<true, T>
{
using type = T;
};
}
#endif
namespace ck {
......
......@@ -47,5 +47,9 @@ __host__ __device__ constexpr auto operator%(integral_constant<TX, X>, integral_
static_assert(Y > 0, "wrong!");
return integral_constant<decltype(X % Y), X % Y>{};
}
template <bool B>
using bool_constant = integral_constant<bool, B>;
using true_type = bool_constant<true>;
using false_type = bool_constant<false>;
} // namespace ck
......@@ -7,6 +7,169 @@
#include "ck/utility/integral_constant.hpp"
#include "ck/utility/enable_if.hpp"
#ifdef __HIPCC_RTC__
#ifdef WORKAROUND_ISSUE_HIPRTC_TRUE_TYPE
/// We need <type_traits> for std::remove_reference and std::remove_cv.
/// But <type_traits> also defines std::true_type, per Standard.
/// However the latter definition conflicts with
/// /opt/rocm/include/hip/amd_detail/amd_hip_vector_types.h,
/// which defines std::true_type as well (which is wrong).
namespace std {
template<class T> struct remove_pointer { typedef T type; };
template<class T> struct remove_pointer<T*> { typedef T type; };
template<class T> struct remove_pointer<T* const> { typedef T type; };
template<class T> struct remove_pointer<T* volatile> { typedef T type; };
template<class T> struct remove_pointer<T* const volatile> { typedef T type; };
// NOLINTNEXTLINE
#define MIGRAPHX_BUILTIN_TYPE_TRAIT1(name) \
template <class T> \
struct name : bool_constant<__##name(T)> \
{ \
}
// NOLINTNEXTLINE
#define MIGRAPHX_BUILTIN_TYPE_TRAIT2(name) \
template <class T, class U> \
struct name : bool_constant<__##name(T, U)> \
{ \
}
// NOLINTNEXTLINE
#define MIGRAPHX_BUILTIN_TYPE_TRAITN(name) \
template <class... Ts> \
struct name : bool_constant<__##name(Ts...)> \
{ \
}
// MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_arithmetic);
// MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_destructible);
// MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_nothrow_destructible);
// MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_pointer);
// MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_scalar);
// MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_signed);
// MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_void);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_abstract);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_aggregate);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_array);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_class);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_compound);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_const);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_empty);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_enum);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_final);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_floating_point);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_function);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_fundamental);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_integral);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_literal_type);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_lvalue_reference);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_member_function_pointer);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_member_object_pointer);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_member_pointer);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_object);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_pod);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_polymorphic);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_reference);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_rvalue_reference);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_standard_layout);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_trivial);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_trivially_copyable);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_trivially_destructible);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_union);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_unsigned);
MIGRAPHX_BUILTIN_TYPE_TRAIT1(is_volatile);
MIGRAPHX_BUILTIN_TYPE_TRAIT2(is_assignable);
MIGRAPHX_BUILTIN_TYPE_TRAIT2(is_base_of);
MIGRAPHX_BUILTIN_TYPE_TRAIT2(is_convertible);
MIGRAPHX_BUILTIN_TYPE_TRAIT2(is_nothrow_assignable);
MIGRAPHX_BUILTIN_TYPE_TRAIT2(is_same);
MIGRAPHX_BUILTIN_TYPE_TRAIT2(is_trivially_assignable);
MIGRAPHX_BUILTIN_TYPE_TRAITN(is_constructible);
MIGRAPHX_BUILTIN_TYPE_TRAITN(is_nothrow_constructible);
MIGRAPHX_BUILTIN_TYPE_TRAITN(is_trivially_constructible);
template <class T>
struct remove_reference
{
typedef T type;
};
template <class T>
struct remove_reference<T&>
{
typedef T type;
};
template <class T>
struct remove_reference<T&&>
{
typedef T type;
};
template <class T>
using remove_reference_t = typename remove_reference<T>::type;
template <class T>
struct remove_const
{
typedef T type;
};
template <class T>
struct remove_const<const T>
{
typedef T type;
};
template <class T>
struct remove_volatile
{
typedef T type;
};
template <class T>
struct remove_volatile<volatile T>
{
typedef T type;
};
template <class T>
struct remove_cv
{
typedef typename remove_volatile<typename remove_const<T>::type>::type type;
};
template <class T>
struct is_pointer_helper : std::false_type
{
};
template <class T>
struct is_pointer_helper<T*> : std::true_type
{
};
template <class T>
struct is_pointer : is_pointer_helper<typename std::remove_cv<T>::type>
{
};
template <typename T>
constexpr T&& forward(typename remove_reference<T>::type& t_) noexcept
{
return static_cast<T&&>(t_);
}
template <typename T>
constexpr T&& forward(typename remove_reference<T>::type&& t_) noexcept
{
return static_cast<T&&>(t_);
}
} // namespace std
#else
#include <type_traits> // std::remove_reference, std::remove_cv, is_pointer
#endif
#endif // __HIPCC_RTC__
namespace ck {
template <typename X, typename Y>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment