Commit 253f942b authored by Umang Yadav's avatar Umang Yadav
Browse files

changes to make it compile

parent 8f9c0243
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
......@@ -391,3 +394,5 @@ __host__ __device__ constexpr auto sequence_to_tuple_of_number(Sequence<Is...>)
} // namespace ck
#endif
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
......@@ -5,6 +8,19 @@
#include "ck/utility/statically_indexed_array.hpp"
#ifdef __HIPCC_RTC__
/// Definitions from <cstdint>, <cmath> conflict with
/// /opt/rocm/include/hip/amd_detail/amd_hip_vector_types.h.
using int8_t = signed char;
using uint8_t = unsigned char;
using int16_t = signed short;
using uint16_t = unsigned short;
using float_t = float;
namespace std {
using byte = unsigned char;
}
#endif // __HIPCC_RTC__
namespace ck {
using bhalf_t = ushort;
......@@ -19,21 +35,22 @@ template <typename T, index_t N>
struct vector_type;
// Caution: DO NOT REMOVE
// intentionally have only declaration but no definition to cause compilation failure when trying to
// instantiate this template. The purpose is to catch user's mistake when trying to make "vector of
// vectors"
// intentionally have only declaration but no definition to cause compilation
// failure when trying to instantiate this template. The purpose is to catch
// user's mistake when trying to make "vector of vectors"
template <typename T, index_t V, index_t N>
struct vector_type<T __attribute__((ext_vector_type(V))), N>;
// Caution: DO NOT REMOVE
// intentionally have only declaration but no definition to cause compilation failure when trying to
// instantiate this template. The purpose is to catch user's mistake when trying to make "vector of
// vectors"
// intentionally have only declaration but no definition to cause compilation
// failure when trying to instantiate this template. The purpose is to catch
// user's mistake when trying to make "vector of vectors"
template <typename T, index_t V, index_t N>
struct vector_type<vector_type<T, V>, N>;
// vector_type_maker
// This is the right way to handle "vector of vectors": making a bigger vector instead
// This is the right way to handle "vector of vectors": making a bigger vector
// instead
template <typename T, index_t N>
struct vector_type_maker
{
......@@ -960,21 +977,233 @@ using f8x16_t = typename vector_type<f8_t, 16>::type;
using f8x32_t = typename vector_type<f8_t, 32>::type;
using f8x64_t = typename vector_type<f8_t, 64>::type;
template <typename T>
struct NumericLimits
// Convert X to Y
template <typename Y, typename X>
__host__ __device__ constexpr Y type_convert(X x)
{
__host__ __device__ static constexpr T Min() { return std::numeric_limits<T>::min(); }
static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
__host__ __device__ static constexpr T Max() { return std::numeric_limits<T>::max(); }
return static_cast<Y>(x);
}
// convert bfp16 to fp32
template <>
inline __host__ __device__ constexpr float type_convert<float, bhalf_t>(bhalf_t x)
{
union
{
uint32_t int32;
float fp32;
} u = {uint32_t(x) << 16};
__host__ __device__ static constexpr T Lowest() { return std::numeric_limits<T>::lowest(); }
return u.fp32;
}
__host__ __device__ static constexpr T QuietNaN()
// convert fp32 to bfp16
template <>
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float x)
{
union
{
return std::numeric_limits<T>::quiet_NaN();
}
float fp32;
uint32_t int32;
} u = {x};
return uint16_t(u.int32 >> 16);
}
// convert bfp16 to fp16 via fp32
template <>
inline __host__ __device__ constexpr half_t type_convert<half_t, bhalf_t>(bhalf_t x)
{
float x_fp32 = type_convert<float>(x);
return static_cast<half_t>(x_fp32);
}
// convert fp16 to bfp16 via fp32
template <>
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, half_t>(half_t x)
{
float x_fp32 = static_cast<float>(x);
return type_convert<bhalf_t>(x_fp32);
}
// convert bfp16 to int32 via fp32
template <>
inline __host__ __device__ constexpr int32_t type_convert<int32_t, bhalf_t>(bhalf_t x)
{
float x_fp32 = type_convert<float>(x);
return static_cast<int32_t>(x_fp32);
}
// convert int32 to bfp16 via fp32
template <>
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int32_t>(int32_t x)
{
float x_fp32 = static_cast<float>(x);
return type_convert<bhalf_t>(x_fp32);
}
// convert bfp16 to int8 via fp32
template <>
inline __host__ __device__ constexpr int8_t type_convert<int8_t, bhalf_t>(bhalf_t x)
{
float x_fp32 = type_convert<float>(x);
return static_cast<int8_t>(x_fp32);
}
// convert int8 to bfp16 via fp32
template <>
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_t x)
{
float x_fp32 = static_cast<float>(x);
return type_convert<bhalf_t>(x_fp32);
}
__host__ __device__ static constexpr T Infinity() { return std::numeric_limits<T>::infinity(); }
// Declare a template function for bf16 conversion using RTN
template <typename Y, typename X>
__host__ __device__ constexpr Y bf16_convert_rtn(X x);
// Convert fp32 to bf16 with RTN if higher precision is needed
template <>
inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, float>(float x)
{
union
{
float fp32;
uint32_t int32;
} u = {x};
// When the exponent bits are not all 1s, then the value is zero, normal,
// or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
// 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
// This causes the bfloat16's mantissa to be incremented by 1 if the 16
// least significant bits of the float mantissa are greater than 0x8000,
// or if they are equal to 0x8000 and the least significant bit of the
// bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
// the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
// has the value 0x7f, then incrementing it causes it to become 0x00 and
// the exponent is incremented by one, which is the next higher FP value
// to the unrounded bfloat16 value. When the bfloat16 value is subnormal
// with an exponent of 0x00 and a mantissa of 0x7f, it may be rounded up
// to a normal value with an exponent of 0x01 and a mantissa of 0x00.
// When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
// incrementing it causes it to become an exponent of 0xFF and a mantissa
// of 0x00, which is Inf, the next higher value to the unrounded value.
bool flag0 = ~u.int32 & 0x7f800000;
// When all of the exponent bits are 1, the value is Inf or NaN.
// Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
// mantissa bit. Quiet NaN is indicated by the most significant mantissa
// bit being 1. Signaling NaN is indicated by the most significant
// mantissa bit being 0 but some other bit(s) being 1. If any of the
// lower 16 bits of the mantissa are 1, we set the least significant bit
// of the bfloat16 mantissa, in order to preserve signaling NaN in case
// the bfloat16's mantissa bits are all 0.
bool flag1 = !flag0 && (u.int32 & 0xffff);
u.int32 += flag0 ? 0x7fff + ((u.int32 >> 16) & 1) : 0; // Round to nearest, round to even
u.int32 |= flag1 ? 0x10000 : 0x0; // Preserve signaling NaN
return uint16_t(u.int32 >> 16);
}
// convert fp16 to bfp16 via fp32 with RTN if higher precision is needed
template <>
inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(half_t x)
{
float x_fp32 = static_cast<float>(x);
return bf16_convert_rtn<bhalf_t>(x_fp32);
}
template <typename T>
struct NumericLimits;
template <>
struct NumericLimits<int32_t>
{
__host__ __device__ static constexpr int32_t Lowest() noexcept { return -2147483647 - 1; }
__host__ __device__ static constexpr int32_t Min() noexcept { return -2147483647 - 1; }
__host__ __device__ static constexpr int32_t Max() noexcept { return 2147483647; }
__host__ __device__ static constexpr int32_t Infinity() noexcept { return 0; }
__host__ __device__ static constexpr int32_t QuietNaN() { return 0; }
};
template <>
struct NumericLimits<int16_t>
{
__host__ __device__ static constexpr int16_t Lowest() noexcept { return -32768; }
__host__ __device__ static constexpr int16_t Min() noexcept { return -32768; }
__host__ __device__ static constexpr int16_t Max() noexcept { return 32767; }
__host__ __device__ static constexpr int16_t Infinity() noexcept { return 0; }
__host__ __device__ static constexpr int16_t QuietNaN() { return 0; }
};
template <>
struct NumericLimits<int8_t>
{
__host__ __device__ static constexpr int8_t Lowest() noexcept { return -128; }
__host__ __device__ static constexpr int8_t Min() noexcept { return -128; }
__host__ __device__ static constexpr int8_t Max() noexcept { return 127; }
__host__ __device__ static constexpr int8_t Infinity() noexcept { return 0; }
__host__ __device__ static constexpr int8_t QuietNaN() { return 0; }
};
template <>
struct NumericLimits<uint32_t>
{
__host__ __device__ static constexpr uint32_t Lowest() noexcept { return 0; }
__host__ __device__ static constexpr uint32_t Min() noexcept { return 0; }
__host__ __device__ static constexpr uint32_t Max() noexcept { return 4294967295U; }
__host__ __device__ static constexpr uint32_t Infinity() noexcept { return 0; }
__host__ __device__ static constexpr uint32_t QuietNaN() { return 0; }
};
template <>
struct NumericLimits<uint16_t>
{
__host__ __device__ static constexpr uint16_t Lowest() noexcept { return 0; }
__host__ __device__ static constexpr uint16_t Min() noexcept { return 0; }
__host__ __device__ static constexpr uint16_t Max() noexcept { return 65535U; }
__host__ __device__ static constexpr uint16_t Infinity() noexcept { return 0; }
__host__ __device__ static constexpr uint16_t QuietNaN() { return 0; }
};
template <>
struct NumericLimits<uint8_t>
{
__host__ __device__ static constexpr uint8_t Lowest() noexcept { return 0; }
__host__ __device__ static constexpr uint8_t Min() noexcept { return 0; }
__host__ __device__ static constexpr uint8_t Max() noexcept { return 255U; }
__host__ __device__ static constexpr uint8_t Infinity() noexcept { return 0; }
__host__ __device__ static constexpr uint8_t QuietNaN() { return 0; }
};
template <>
struct NumericLimits<float>
{
static constexpr unsigned int binary_min = 0x00800000;
static constexpr unsigned int binary_max = 0x7F7FFFFF;
static constexpr unsigned int binary_lowest = 0xFF7FFFFF;
static constexpr unsigned int binary_qnan = 0xFFC00001;
static constexpr unsigned int binary_inf = 0x7F8000000;
__host__ __device__ static constexpr float Min() { return bit_cast<float>(binary_min); }
__host__ __device__ static constexpr float Max() { return bit_cast<float>(binary_max); }
__host__ __device__ static constexpr float Lowest() { return bit_cast<float>(binary_lowest); }
__host__ __device__ static constexpr float QuietNaN() { return bit_cast<float>(binary_qnan); }
__host__ __device__ static constexpr float Infinity() { return bit_cast<float>(binary_inf); }
};
template <>
......@@ -1024,3 +1253,5 @@ struct NumericLimits<f8_t>
};
} // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#ifndef UTILITY_DEBUG_HPP
#define UTILITY_DEBUG_HPP
#include "type.hpp"
namespace ck {
namespace debug {
namespace detail {
template <typename T, typename Enable = void>
struct PrintAsType;
template <typename T, typename Enable = void> struct PrintAsType;
template <typename T>
struct PrintAsType<T, typename std::enable_if<std::is_floating_point<T>::value>::type>
{
using type = float;
__host__ __device__ static void Print(const T& p) { printf("%.3f ", static_cast<type>(p)); }
struct PrintAsType<
T, typename std::enable_if<std::is_floating_point<T>::value>::type> {
using type = float;
__host__ __device__ static void Print(const T &p) {
printf("%.3f ", static_cast<type>(p));
}
};
template <>
struct PrintAsType<ck::half_t, void>
{
using type = float;
__host__ __device__ static void Print(const ck::half_t& p)
{
printf("%.3f ", static_cast<type>(p));
}
template <> struct PrintAsType<ck::half_t, void> {
using type = float;
__host__ __device__ static void Print(const ck::half_t &p) {
printf("%.3f ", static_cast<type>(p));
}
};
template <typename T>
struct PrintAsType<T, typename std::enable_if<std::is_integral<T>::value>::type>
{
using type = int;
__host__ __device__ static void Print(const T& p) { printf("%d ", static_cast<type>(p)); }
struct PrintAsType<T,
typename std::enable_if<std::is_integral<T>::value>::type> {
using type = int;
__host__ __device__ static void Print(const T &p) {
printf("%d ", static_cast<type>(p));
}
};
} // namespace detail
// Print at runtime the data in shared memory in 128 bytes per row format given shared mem pointer
// and the number of elements. Can optionally specify strides between elements and how many bytes'
// worth of data per row.
// Print at runtime the data in shared memory in 128 bytes per row format given
// shared mem pointer and the number of elements. Can optionally specify strides
// between elements and how many bytes' worth of data per row.
//
// Usage example:
//
// debug::print_shared(a_block_buf.p_data_, index_t(a_block_desc_k0_m_k1.GetElementSpaceSize()));
// debug::print_shared(a_block_buf.p_data_,
// index_t(a_block_desc_k0_m_k1.GetElementSpaceSize()));
//
template <typename T, index_t element_stride = 1, index_t row_bytes = 128>
__device__ void print_shared(T const* p_shared, index_t num_elements)
{
constexpr index_t row_elements = row_bytes / sizeof(T);
static_assert((element_stride >= 1 && element_stride <= row_elements),
"element_stride should between [1, row_elements]");
__device__ void print_shared(T const *p_shared, index_t num_elements) {
constexpr index_t row_elements = row_bytes / sizeof(T);
static_assert((element_stride >= 1 && element_stride <= row_elements),
"element_stride should between [1, row_elements]");
index_t wgid = blockIdx.x + blockIdx.y * gridDim.x + gridDim.x * gridDim.y * blockIdx.z;
index_t tid =
(threadIdx.z * (blockDim.x * blockDim.y)) + (threadIdx.y * blockDim.x) + threadIdx.x;
index_t wgid =
blockIdx.x + blockIdx.y * gridDim.x + gridDim.x * gridDim.y * blockIdx.z;
index_t tid = (threadIdx.z * (blockDim.x * blockDim.y)) +
(threadIdx.y * blockDim.x) + threadIdx.x;
__syncthreads();
__syncthreads();
if(tid == 0)
{
printf("\nWorkgroup id %d, bytes per row %d, element stride %d\n\n",
wgid,
row_bytes,
element_stride);
for(index_t i = 0; i < num_elements; i += row_elements)
{
printf("elem %5d: ", i);
for(index_t j = 0; j < row_elements; j += element_stride)
{
detail::PrintAsType<T>::Print(p_shared[i + j]);
}
if (tid == 0) {
printf("\nWorkgroup id %d, bytes per row %d, element stride %d\n\n", wgid,
row_bytes, element_stride);
for (index_t i = 0; i < num_elements; i += row_elements) {
printf("elem %5d: ", i);
for (index_t j = 0; j < row_elements; j += element_stride) {
detail::PrintAsType<T>::Print(p_shared[i + j]);
}
printf("\n");
}
printf("\n");
printf("\n");
}
printf("\n");
}
__syncthreads();
__syncthreads();
}
} // namespace debug
} // namespace ck
#endif // UTILITY_DEBUG_HPP
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
......@@ -405,3 +408,5 @@ make_dynamic_buffer(T* p, ElementSpaceSize element_space_size, X invalid_element
}
} // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#ifdef __HIPCC_RTC__
namespace std {
template <bool B, typename T = void>
using enable_if_t = typename enable_if<B, T>::type;
} // namespace std
#endif
namespace ck {
template <bool B, typename T = void>
using enable_if = std::enable_if<B, T>;
template <bool B, typename T = void> using enable_if = std::enable_if<B, T>;
template <bool B, typename T = void>
using enable_if_t = typename std::enable_if<B, T>::type;
} // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
......@@ -129,3 +132,5 @@ constexpr auto conditional_expr(X&& x, Y&& y)
}
} // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
......@@ -47,3 +50,5 @@ struct static_for
};
} // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
......@@ -142,3 +145,5 @@ struct ford
};
} // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
......@@ -63,3 +66,5 @@ __host__ __device__ constexpr auto unpack2(F&& f, X&& x, Y&& y)
} // namespace ck
#endif
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
......@@ -121,3 +124,5 @@ __device__ float2_t atomic_max<float2_t>(float2_t* p_dst, const float2_t& x)
}
} // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
......@@ -26,3 +29,5 @@ __device__ index_t get_grid_size() { return gridDim.x; }
__device__ index_t get_block_size() { return blockDim.x; }
} // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
......@@ -20,3 +23,5 @@ struct ignore_t
inline constexpr detail::ignore_t ignore;
} // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
......@@ -234,3 +237,5 @@ inner_product<int8x16_t, int8x16_t, int32_t>(const int8x16_t& a, const int8x16_t
}
} // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
......@@ -5,47 +8,50 @@
namespace ck {
template <class T, T v>
struct integral_constant
{
static constexpr T value = v;
typedef T value_type;
typedef integral_constant type;
__host__ __device__ constexpr operator value_type() const noexcept { return value; }
__host__ __device__ constexpr value_type operator()() const noexcept { return value; }
template <class T, T v> struct integral_constant {
static constexpr T value = v;
typedef T value_type;
typedef integral_constant type;
__host__ __device__ constexpr operator value_type() const noexcept {
return value;
}
__host__ __device__ constexpr value_type operator()() const noexcept {
return value;
}
};
template <typename TX, TX X, typename TY, TY Y>
__host__ __device__ constexpr auto operator+(integral_constant<TX, X>, integral_constant<TY, Y>)
{
return integral_constant<decltype(X + Y), X + Y>{};
__host__ __device__ constexpr auto operator+(integral_constant<TX, X>,
integral_constant<TY, Y>) {
return integral_constant<decltype(X + Y), X + Y>{};
}
template <typename TX, TX X, typename TY, TY Y>
__host__ __device__ constexpr auto operator-(integral_constant<TX, X>, integral_constant<TY, Y>)
{
static_assert(Y <= X, "wrong!");
return integral_constant<decltype(X - Y), X - Y>{};
__host__ __device__ constexpr auto operator-(integral_constant<TX, X>,
integral_constant<TY, Y>) {
static_assert(Y <= X, "wrong!");
return integral_constant<decltype(X - Y), X - Y>{};
}
template <typename TX, TX X, typename TY, TY Y>
__host__ __device__ constexpr auto operator*(integral_constant<TX, X>, integral_constant<TY, Y>)
{
return integral_constant<decltype(X * Y), X * Y>{};
__host__ __device__ constexpr auto operator*(integral_constant<TX, X>,
integral_constant<TY, Y>) {
return integral_constant<decltype(X * Y), X * Y>{};
}
template <typename TX, TX X, typename TY, TY Y>
__host__ __device__ constexpr auto operator/(integral_constant<TX, X>, integral_constant<TY, Y>)
{
static_assert(Y > 0, "wrong!");
return integral_constant<decltype(X / Y), X / Y>{};
__host__ __device__ constexpr auto operator/(integral_constant<TX, X>,
integral_constant<TY, Y>) {
static_assert(Y > 0, "wrong!");
return integral_constant<decltype(X / Y), X / Y>{};
}
template <typename TX, TX X, typename TY, TY Y>
__host__ __device__ constexpr auto operator%(integral_constant<TX, X>, integral_constant<TY, Y>)
{
static_assert(Y > 0, "wrong!");
return integral_constant<decltype(X % Y), X % Y>{};
__host__ __device__ constexpr auto operator%(integral_constant<TX, X>,
integral_constant<TY, Y>) {
static_assert(Y > 0, "wrong!");
return integral_constant<decltype(X % Y), X % Y>{};
}
} // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
......@@ -54,3 +57,5 @@ struct is_known_at_compile_time<Tuple<Ts...>>
};
} // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
......@@ -6,155 +9,144 @@
#include "ck/ck.hpp"
#include "integral_constant.hpp"
#include "number.hpp"
#include "type.hpp"
#include "tuple.hpp"
#include "type.hpp"
#define INT32_MAX 2147483647
namespace ck {
// magic number division
// Caution:
// 1. For uint32_t as dividend: magic number division implementation being used would produce
// correct result if the dividend is uint32_t and its value is within 31-bit value range.
// 2. For int32_t as dividendd: magic number division for int32_t dividened has not been
// implemented, the int32_t dividend would be bit-wise interpreted as uint32_t and magic number
// division implementation for uint32_t is then used. Therefore, dividend value need to be
// non-negative.
// 1. For uint32_t as dividend: magic number division implementation being
// used would produce correct result if the dividend is uint32_t and its value
// is within 31-bit value range.
// 2. For int32_t as dividendd: magic number division for int32_t dividened
// has not been implemented, the int32_t dividend would be bit-wise
// interpreted as uint32_t and magic number division implementation for
// uint32_t is then used. Therefore, dividend value need to be non-negative.
// TODO:
// 1. Implement magic number divison for int32_t
// 2. Implement magic number divison for unit32_t with 32-bit value range
struct MagicDivision
{
// uint32_t
__host__ __device__ static constexpr auto CalculateMagicNumbers(uint32_t divisor)
{
// WARNING: magic division is only applicable for division inside this range.
// You should use the return value of CalculateMagicNumbers, if division is not inside this
// range. The "else" logic below is to quiet down run-time error.
if(divisor >= 1 && divisor <= INT32_MAX)
{
uint32_t shift = 0;
for(shift = 0; shift < 32; ++shift)
{
if((1U << shift) >= divisor)
{
break;
}
}
uint64_t one = 1;
uint64_t multiplier = ((one << 32) * ((one << shift) - divisor)) / divisor + 1;
// assert(multiplier <= 0xffffffffUL);
return make_tuple(uint32_t(multiplier), shift);
}
else
{
return make_tuple(uint32_t(0), uint32_t(0));
struct MagicDivision {
// uint32_t
__host__ __device__ static constexpr auto
CalculateMagicNumbers(uint32_t divisor) {
// WARNING: magic division is only applicable for division inside this
// range. You should use the return value of CalculateMagicNumbers, if
// division is not inside this range. The "else" logic below is to quiet
// down run-time error.
if (divisor >= 1 && divisor <= INT32_MAX) {
uint32_t shift = 0;
for (shift = 0; shift < 32; ++shift) {
if ((1U << shift) >= divisor) {
break;
}
}
__host__ __device__ static constexpr uint32_t CalculateMagicMultiplier(uint32_t divisor)
{
auto tmp = CalculateMagicNumbers(divisor);
return tmp[Number<0>{}];
}
__host__ __device__ static constexpr uint32_t CalculateMagicShift(uint32_t divisor)
{
auto tmp = CalculateMagicNumbers(divisor);
return tmp[Number<1>{}];
}
// integral_constant<uint32_t, .>
template <uint32_t Divisor>
__host__ __device__ static constexpr auto
CalculateMagicNumbers(integral_constant<uint32_t, Divisor>)
{
constexpr auto tmp = CalculateMagicNumbers(uint32_t{Divisor});
constexpr uint32_t multiplier = tmp[Number<0>{}];
constexpr uint32_t shift = tmp[Number<1>{}];
return make_tuple(integral_constant<uint32_t, multiplier>{},
integral_constant<uint32_t, shift>{});
}
template <uint32_t Divisor>
__host__ __device__ static constexpr auto
CalculateMagicMultiplier(integral_constant<uint32_t, Divisor>)
{
constexpr uint32_t multiplier = CalculateMagicMultiplier(uint32_t{Divisor});
return integral_constant<uint32_t, multiplier>{};
}
template <uint32_t Divisor>
__host__ __device__ static constexpr auto
CalculateMagicShift(integral_constant<uint32_t, Divisor>)
{
constexpr uint32_t shift = CalculateMagicShift(uint32_t{Divisor});
return integral_constant<uint32_t, shift>{};
}
// integral_constant<int32_t, .>
template <int32_t Divisor>
__host__ __device__ static constexpr auto
CalculateMagicNumbers(integral_constant<int32_t, Divisor>)
{
return CalculateMagicNumbers(integral_constant<uint32_t, Divisor>{});
}
template <int32_t Divisor>
__host__ __device__ static constexpr auto
CalculateMagicMultiplier(integral_constant<int32_t, Divisor>)
{
return CalculateMagicMultiplier(integral_constant<uint32_t, Divisor>{});
}
template <int32_t Divisor>
__host__ __device__ static constexpr auto
CalculateMagicShift(integral_constant<int32_t, Divisor>)
{
return CalculateMagicShift(integral_constant<uint32_t, Divisor>{});
}
// magic division for uint32_t
__device__ static constexpr uint32_t
DoMagicDivision(uint32_t dividend, uint32_t multiplier, uint32_t shift)
{
uint32_t tmp = __umulhi(dividend, multiplier);
return (tmp + dividend) >> shift;
}
__host__ static constexpr uint32_t
DoMagicDivision(uint32_t dividend, uint32_t multiplier, uint32_t shift)
{
uint32_t tmp = static_cast<uint64_t>(dividend) * multiplier >> 32;
return (tmp + dividend) >> shift;
}
// magic division for int32_t
// HACK: use dividend_i32 as if it's uint32_t, dividend_i32 need to be
// non-negative for result to be correct
// TODO: figure out how to do magic number divison for int32_t as dividended
__device__ static constexpr int32_t
DoMagicDivision(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
{
uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
uint32_t tmp = __umulhi(dividend_u32, multiplier);
return (tmp + dividend_u32) >> shift;
}
__host__ static constexpr int32_t
DoMagicDivision(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
{
uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
uint32_t tmp = static_cast<uint64_t>(dividend_u32) * multiplier >> 32;
return (tmp + dividend_u32) >> shift;
}
}
uint64_t one = 1;
uint64_t multiplier =
((one << 32) * ((one << shift) - divisor)) / divisor + 1;
// assert(multiplier <= 0xffffffffUL);
return make_tuple(uint32_t(multiplier), shift);
} else {
return make_tuple(uint32_t(0), uint32_t(0));
}
}
__host__ __device__ static constexpr uint32_t
CalculateMagicMultiplier(uint32_t divisor) {
auto tmp = CalculateMagicNumbers(divisor);
return tmp[Number<0>{}];
}
__host__ __device__ static constexpr uint32_t
CalculateMagicShift(uint32_t divisor) {
auto tmp = CalculateMagicNumbers(divisor);
return tmp[Number<1>{}];
}
// integral_constant<uint32_t, .>
template <uint32_t Divisor>
__host__ __device__ static constexpr auto
CalculateMagicNumbers(integral_constant<uint32_t, Divisor>) {
constexpr auto tmp = CalculateMagicNumbers(uint32_t{Divisor});
constexpr uint32_t multiplier = tmp[Number<0>{}];
constexpr uint32_t shift = tmp[Number<1>{}];
return make_tuple(integral_constant<uint32_t, multiplier>{},
integral_constant<uint32_t, shift>{});
}
template <uint32_t Divisor>
__host__ __device__ static constexpr auto
CalculateMagicMultiplier(integral_constant<uint32_t, Divisor>) {
constexpr uint32_t multiplier = CalculateMagicMultiplier(uint32_t{Divisor});
return integral_constant<uint32_t, multiplier>{};
}
template <uint32_t Divisor>
__host__ __device__ static constexpr auto
CalculateMagicShift(integral_constant<uint32_t, Divisor>) {
constexpr uint32_t shift = CalculateMagicShift(uint32_t{Divisor});
return integral_constant<uint32_t, shift>{};
}
// integral_constant<int32_t, .>
template <int32_t Divisor>
__host__ __device__ static constexpr auto
CalculateMagicNumbers(integral_constant<int32_t, Divisor>) {
return CalculateMagicNumbers(integral_constant<uint32_t, Divisor>{});
}
template <int32_t Divisor>
__host__ __device__ static constexpr auto
CalculateMagicMultiplier(integral_constant<int32_t, Divisor>) {
return CalculateMagicMultiplier(integral_constant<uint32_t, Divisor>{});
}
template <int32_t Divisor>
__host__ __device__ static constexpr auto
CalculateMagicShift(integral_constant<int32_t, Divisor>) {
return CalculateMagicShift(integral_constant<uint32_t, Divisor>{});
}
// magic division for uint32_t
__device__ static constexpr uint32_t
DoMagicDivision(uint32_t dividend, uint32_t multiplier, uint32_t shift) {
uint32_t tmp = __umulhi(dividend, multiplier);
return (tmp + dividend) >> shift;
}
__host__ static constexpr uint32_t
DoMagicDivision(uint32_t dividend, uint32_t multiplier, uint32_t shift) {
uint32_t tmp = static_cast<uint64_t>(dividend) * multiplier >> 32;
return (tmp + dividend) >> shift;
}
// magic division for int32_t
// HACK: use dividend_i32 as if it's uint32_t, dividend_i32 need to be
// non-negative for result to be correct
// TODO: figure out how to do magic number divison for int32_t as dividended
__device__ static constexpr int32_t
DoMagicDivision(int32_t dividend_i32, uint32_t multiplier, uint32_t shift) {
uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
uint32_t tmp = __umulhi(dividend_u32, multiplier);
return (tmp + dividend_u32) >> shift;
}
__host__ static constexpr int32_t
DoMagicDivision(int32_t dividend_i32, uint32_t multiplier, uint32_t shift) {
uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
uint32_t tmp = static_cast<uint64_t>(dividend_u32) * multiplier >> 32;
return (tmp + dividend_u32) >> shift;
}
};
struct MDiv
......@@ -230,3 +222,5 @@ struct MDiv2
};
} // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "enable_if.hpp"
#include "integral_constant.hpp"
#include "number.hpp"
#include "type.hpp"
#include "enable_if.hpp"
namespace ck {
namespace math {
template <typename T, T s>
struct scales
{
__host__ __device__ constexpr T operator()(T a) const { return s * a; }
template <typename T, T s> struct scales {
__host__ __device__ constexpr T operator()(T a) const { return s * a; }
};
template <typename T>
struct plus
{
__host__ __device__ constexpr T operator()(T a, T b) const { return a + b; }
template <typename T> struct plus {
__host__ __device__ constexpr T operator()(T a, T b) const { return a + b; }
};
template <typename T>
struct minus
{
__host__ __device__ constexpr T operator()(T a, T b) const { return a - b; }
template <typename T> struct minus {
__host__ __device__ constexpr T operator()(T a, T b) const { return a - b; }
};
struct multiplies
{
template <typename A, typename B>
__host__ __device__ constexpr auto operator()(const A& a, const B& b) const
{
return a * b;
}
struct multiplies {
template <typename A, typename B>
__host__ __device__ constexpr auto operator()(const A &a, const B &b) const {
return a * b;
}
};
template <typename T>
struct maximize
{
__host__ __device__ constexpr T operator()(T a, T b) const { return a >= b ? a : b; }
template <typename T> struct maximize {
__host__ __device__ constexpr T operator()(T a, T b) const {
return a >= b ? a : b;
}
};
template <typename T>
struct minimize
{
__host__ __device__ constexpr T operator()(T a, T b) const { return a <= b ? a : b; }
template <typename T> struct minimize {
__host__ __device__ constexpr T operator()(T a, T b) const {
return a <= b ? a : b;
}
};
template <typename T>
struct integer_divide_ceiler
{
__host__ __device__ constexpr T operator()(T a, T b) const
{
static_assert(is_same<T, index_t>{} || is_same<T, int>{}, "wrong type");
template <typename T> struct integer_divide_ceiler {
__host__ __device__ constexpr T operator()(T a, T b) const {
static_assert(is_same<T, index_t>{} || is_same<T, int>{}, "wrong type");
return (a + b - Number<1>{}) / b;
}
return (a + b - Number<1>{}) / b;
}
};
template <typename X, typename Y>
__host__ __device__ constexpr auto integer_divide_floor(X x, Y y)
{
return x / y;
__host__ __device__ constexpr auto integer_divide_floor(X x, Y y) {
return x / y;
}
template <typename X, typename Y>
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
{
return (x + y - Number<1>{}) / y;
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y) {
return (x + y - Number<1>{}) / y;
}
template <typename X, typename Y>
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
{
return y * integer_divide_ceil(x, y);
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y) {
return y * integer_divide_ceil(x, y);
}
template <typename T>
__host__ __device__ constexpr T max(T x)
{
return x;
}
template <typename T> __host__ __device__ constexpr T max(T x) { return x; }
template <typename T>
__host__ __device__ constexpr T max(T x, T y)
{
return x > y ? x : y;
template <typename T> __host__ __device__ constexpr T max(T x, T y) {
return x > y ? x : y;
}
template <index_t X>
__host__ __device__ constexpr index_t max(Number<X>, index_t y)
{
return X > y ? X : y;
__host__ __device__ constexpr index_t max(Number<X>, index_t y) {
return X > y ? X : y;
}
template <index_t Y>
__host__ __device__ constexpr index_t max(index_t x, Number<Y>)
{
return x > Y ? x : Y;
__host__ __device__ constexpr index_t max(index_t x, Number<Y>) {
return x > Y ? x : Y;
}
template <typename X, typename... Ys>
__host__ __device__ constexpr auto max(X x, Ys... ys)
{
static_assert(sizeof...(Ys) > 0, "not enough argument");
__host__ __device__ constexpr auto max(X x, Ys... ys) {
static_assert(sizeof...(Ys) > 0, "not enough argument");
return max(x, max(ys...));
return max(x, max(ys...));
}
template <typename T>
__host__ __device__ constexpr T min(T x)
{
return x;
}
template <typename T> __host__ __device__ constexpr T min(T x) { return x; }
template <typename T>
__host__ __device__ constexpr T min(T x, T y)
{
return x < y ? x : y;
template <typename T> __host__ __device__ constexpr T min(T x, T y) {
return x < y ? x : y;
}
template <index_t X>
__host__ __device__ constexpr index_t min(Number<X>, index_t y)
{
return X < y ? X : y;
__host__ __device__ constexpr index_t min(Number<X>, index_t y) {
return X < y ? X : y;
}
template <index_t Y>
__host__ __device__ constexpr index_t min(index_t x, Number<Y>)
{
return x < Y ? x : Y;
__host__ __device__ constexpr index_t min(index_t x, Number<Y>) {
return x < Y ? x : Y;
}
template <typename X, typename... Ys>
__host__ __device__ constexpr auto min(X x, Ys... ys)
{
static_assert(sizeof...(Ys) > 0, "not enough argument");
__host__ __device__ constexpr auto min(X x, Ys... ys) {
static_assert(sizeof...(Ys) > 0, "not enough argument");
return min(x, min(ys...));
return min(x, min(ys...));
}
template <typename T>
__host__ __device__ constexpr T clamp(const T& x, const T& lowerbound, const T& upperbound)
{
return min(max(x, lowerbound), upperbound);
__host__ __device__ constexpr T clamp(const T &x, const T &lowerbound,
const T &upperbound) {
return min(max(x, lowerbound), upperbound);
}
// disallow implicit type casting
template <typename T>
__device__ T exp(T x);
template <typename T> __device__ T exp(T x);
// TODO: add f16 support using v_exp_f16
template <>
__device__ float exp<float>(float x)
{
return __expf(x);
}
template <> __device__ float exp<float>(float x) { return __expf(x); }
template <>
__device__ double exp<double>(double x)
{
return exp(x);
}
template <> __device__ double exp<double>(double x) { return exp(x); }
static inline __host__ float exp(float x) { return ::expf(x); }
// static inline __host__ float exp(float x) { return ::expf(x); }
static inline __host__ double exp(double x) { return std::exp(x); }
// static inline __host__ double exp(double x) { return std::exp(x); }
// greatest common divisor, aka highest common factor
__host__ __device__ constexpr index_t gcd(index_t x, index_t y)
{
if(x < 0)
{
return gcd(-x, y);
}
else if(y < 0)
{
return gcd(x, -y);
}
else if(x == y || x == 0)
{
return y;
}
else if(y == 0)
{
return x;
}
else if(x > y)
{
return gcd(x % y, y);
}
else
{
return gcd(x, y % x);
}
__host__ __device__ constexpr index_t gcd(index_t x, index_t y) {
if (x < 0) {
return gcd(-x, y);
} else if (y < 0) {
return gcd(x, -y);
} else if (x == y || x == 0) {
return y;
} else if (y == 0) {
return x;
} else if (x > y) {
return gcd(x % y, y);
} else {
return gcd(x, y % x);
}
}
template <index_t X, index_t Y>
__host__ __device__ constexpr auto gcd(Number<X>, Number<Y>)
{
constexpr auto r = gcd(X, Y);
__host__ __device__ constexpr auto gcd(Number<X>, Number<Y>) {
constexpr auto r = gcd(X, Y);
return Number<r>{};
return Number<r>{};
}
template <typename X, typename... Ys, typename enable_if<sizeof...(Ys) >= 2, bool>::type = false>
__host__ __device__ constexpr auto gcd(X x, Ys... ys)
{
return gcd(x, gcd(ys...));
template <typename X, typename... Ys,
typename enable_if<sizeof...(Ys) >= 2, bool>::type = false>
__host__ __device__ constexpr auto gcd(X x, Ys... ys) {
return gcd(x, gcd(ys...));
}
// least common multiple
template <typename X, typename Y>
__host__ __device__ constexpr auto lcm(X x, Y y)
{
return (x * y) / gcd(x, y);
__host__ __device__ constexpr auto lcm(X x, Y y) {
return (x * y) / gcd(x, y);
}
template <typename X, typename... Ys, typename enable_if<sizeof...(Ys) >= 2, bool>::type = false>
__host__ __device__ constexpr auto lcm(X x, Ys... ys)
{
return lcm(x, lcm(ys...));
template <typename X, typename... Ys,
typename enable_if<sizeof...(Ys) >= 2, bool>::type = false>
__host__ __device__ constexpr auto lcm(X x, Ys... ys) {
return lcm(x, lcm(ys...));
}
template <typename T>
struct equal
{
__host__ __device__ constexpr bool operator()(T x, T y) const { return x == y; }
template <typename T> struct equal {
__host__ __device__ constexpr bool operator()(T x, T y) const {
return x == y;
}
};
template <typename T>
struct less
{
__host__ __device__ constexpr bool operator()(T x, T y) const { return x < y; }
template <typename T> struct less {
__host__ __device__ constexpr bool operator()(T x, T y) const {
return x < y;
}
};
template <index_t X>
......@@ -258,3 +206,5 @@ __host__ __device__ constexpr auto next_power_of_two(Number<X> x)
} // namespace math
} // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
......@@ -13,177 +16,169 @@
namespace ck {
namespace math {
// math functions for the host, some are implemented by calling C++ std functions
// math functions for the host, some are implemented by calling C++ std
// functions
static inline __host__ float abs(float x) { return std::abs(x); };
static inline __host__ float abs(float x) { return x < 0 ? x * -1.0 : x; };
static inline __host__ double abs(double x) { return std::abs(x); };
static inline __host__ double abs(double x) { return x < 0 ? x * -1.0 : x; };
static inline __host__ int8_t abs(int8_t x)
{
int8_t sgn = x >> (8 - 1);
static inline __host__ int8_t abs(int8_t x) {
int8_t sgn = x >> (8 - 1);
return (x ^ sgn) - sgn;
return (x ^ sgn) - sgn;
};
static inline __host__ int32_t abs(int32_t x)
{
int32_t sgn = x >> (32 - 1);
static inline __host__ int32_t abs(int32_t x) {
int32_t sgn = x >> (32 - 1);
return (x ^ sgn) - sgn;
return (x ^ sgn) - sgn;
};
static inline __host__ half_t abs(half_t x)
{
uint16_t xx = ck::bit_cast<uint16_t>(x);
static inline __host__ half_t abs(half_t x) {
uint16_t xx = ck::bit_cast<uint16_t>(x);
uint16_t abs_xx = xx & 0x7fff;
uint16_t abs_xx = xx & 0x7fff;
half_t abs_x = ck::bit_cast<half_t>(abs_xx);
half_t abs_x = ck::bit_cast<half_t>(abs_xx);
return abs_x;
return abs_x;
};
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
static inline __host__ int4_t abs(int4_t x)
{
int4_t sgn = x >> (4 - 1);
return (x ^ sgn) - sgn;
static inline __host__ int4_t abs(int4_t x) {
int4_t sgn = x >> (4 - 1);
return (x ^ sgn) - sgn;
}
#endif
static inline __host__ bool isnan(float x) { return std::isnan(x); };
// TODO: to bit arithmetic to figure it out
static inline __host__ bool isnan(float x) {
(void)x;
return false;
};
static inline __host__ bool isnan(double x) { return std::isnan(x); };
static inline __host__ bool isnan(double x) {
(void)x;
return false;
};
static inline __host__ bool isnan(int8_t x)
{
(void)x;
return false;
static inline __host__ bool isnan(int8_t x) {
(void)x;
return false;
};
static inline __host__ bool isnan(int32_t x)
{
(void)x;
return false;
static inline __host__ bool isnan(int32_t x) {
(void)x;
return false;
};
static inline __host__ bool isnan(half_t x)
{
uint16_t xx = ck::bit_cast<uint16_t>(x);
static inline __host__ bool isnan(half_t x) {
uint16_t xx = ck::bit_cast<uint16_t>(x);
return (xx & 0x7FFF) > 0x7C00;
return (xx & 0x7FFF) > 0x7C00;
};
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
static inline __host__ bool isnan(int4_t x)
{
(void)x;
return false;
static inline __host__ bool isnan(int4_t x) {
(void)x;
return false;
};
#endif
static inline __host__ half_t sqrt(half_t x)
{
return static_cast<half_t>(std::sqrt(static_cast<float>(x)));
};
// MIGRAPHX doesn't care about host compilation, just return identity values for
// now
static inline __host__ float sqrt(float x) { return std::sqrt(x); };
static inline __host__ half_t sqrt(half_t x) { return x; };
static inline __host__ double sqrt(double x) { return std::sqrt(x); };
static inline __host__ float sqrt(float x) { return x; };
static inline __host__ half_t tanh(half_t x)
{
return static_cast<half_t>(std::tanh(static_cast<float>(x)));
};
static inline __host__ double sqrt(double x) { return x; };
static inline __host__ float tanh(float x) { return std::tanh(x); };
static inline __host__ half_t tanh(half_t x) { return x; };
static inline __host__ double tanh(double x) { return std::tanh(x); };
static inline __host__ float tanh(float x) { return x; };
// math functions for the HIP kernel, some are implemented by calling hip builtin functions
static inline __host__ double tanh(double x) { return x; };
// math functions for the HIP kernel, some are implemented by calling hip
// builtin functions
static inline __device__ float abs(float x) { return ::abs(x); };
static inline __device__ double abs(double x) { return ::abs(x); };
static inline __device__ int8_t abs(int8_t x)
{
int8_t sgn = x >> (8 - 1);
static inline __device__ int8_t abs(int8_t x) {
int8_t sgn = x >> (8 - 1);
return (x ^ sgn) - sgn;
return (x ^ sgn) - sgn;
};
static inline __device__ int32_t abs(int32_t x)
{
int32_t sgn = x >> (32 - 1);
static inline __device__ int32_t abs(int32_t x) {
int32_t sgn = x >> (32 - 1);
return (x ^ sgn) - sgn;
return (x ^ sgn) - sgn;
};
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
static inline __device__ int4_t abs(int4_t x)
{
int4_t sgn = x >> (4 - 1);
static inline __device__ int4_t abs(int4_t x) {
int4_t sgn = x >> (4 - 1);
return (x ^ sgn) - sgn;
return (x ^ sgn) - sgn;
};
#endif
static inline __device__ half_t abs(half_t x)
{
uint16_t xx = ck::bit_cast<uint16_t>(x);
static inline __device__ half_t abs(half_t x) {
uint16_t xx = ck::bit_cast<uint16_t>(x);
uint16_t abs_xx = xx & 0x7fff;
uint16_t abs_xx = xx & 0x7fff;
half_t abs_x = ck::bit_cast<half_t>(abs_xx);
half_t abs_x = ck::bit_cast<half_t>(abs_xx);
return abs_x;
return abs_x;
};
static inline __device__ bool isnan(float x) { return ::isnan(x); };
static inline __device__ bool isnan(double x) { return ::isnan(x); };
static inline __device__ bool isnan(int8_t x)
{
(void)x;
return false;
static inline __device__ bool isnan(int8_t x) {
(void)x;
return false;
};
static inline __device__ bool isnan(int32_t x)
{
(void)x;
return false;
static inline __device__ bool isnan(int32_t x) {
(void)x;
return false;
};
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
static inline __device__ bool isnan(int4_t x)
{
(void)x;
return false;
static inline __device__ bool isnan(int4_t x) {
(void)x;
return false;
};
#endif
static inline __device__ bool isnan(half_t x)
{
uint16_t xx = ck::bit_cast<uint16_t>(x);
static inline __device__ bool isnan(half_t x) {
uint16_t xx = ck::bit_cast<uint16_t>(x);
return (xx & 0x7FFF) > 0x7C00;
return (xx & 0x7FFF) > 0x7C00;
};
static inline __device__ half_t sqrt(half_t x)
{
return static_cast<half_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x)));
static inline __device__ half_t sqrt(half_t x) {
return static_cast<half_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x)));
};
static inline __device__ float sqrt(float x) { return __builtin_amdgcn_sqrtf(x); };
static inline __device__ float sqrt(float x) {
return __builtin_amdgcn_sqrtf(x);
};
static inline __device__ double sqrt(double x) { return __builtin_amdgcn_sqrt(x); };
static inline __device__ double sqrt(double x) {
return __builtin_amdgcn_sqrt(x);
};
static inline __device__ half_t tanh(half_t x)
{
return static_cast<half_t>(::tanhf(static_cast<float>(x)));
static inline __device__ half_t tanh(half_t x) {
return static_cast<half_t>(::tanhf(static_cast<float>(x)));
};
static inline __device__ float tanh(float x) { return ::tanhf(x); };
......@@ -192,3 +187,5 @@ static inline __device__ double tanh(double x) { return ::tanh(x); };
} // namespace math
} // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
......@@ -10,3 +13,5 @@
#else
#include "statically_indexed_array_multi_index.hpp"
#endif
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
......@@ -16,3 +19,5 @@ using LongNumber = integral_constant<long_index_t, N>;
} // namespace ck
#endif
#pragma clang diagnostic pop
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment