Commit 253f942b authored by Umang Yadav's avatar Umang Yadav
Browse files

changes to make it compile

parent 8f9c0243
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -391,3 +394,5 @@ __host__ __device__ constexpr auto sequence_to_tuple_of_number(Sequence<Is...>) ...@@ -391,3 +394,5 @@ __host__ __device__ constexpr auto sequence_to_tuple_of_number(Sequence<Is...>)
} // namespace ck } // namespace ck
#endif #endif
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -5,6 +8,19 @@ ...@@ -5,6 +8,19 @@
#include "ck/utility/statically_indexed_array.hpp" #include "ck/utility/statically_indexed_array.hpp"
#ifdef __HIPCC_RTC__
/// Definitions from <cstdint>, <cmath> conflict with
/// /opt/rocm/include/hip/amd_detail/amd_hip_vector_types.h.
using int8_t = signed char;
using uint8_t = unsigned char;
using int16_t = signed short;
using uint16_t = unsigned short;
using float_t = float;
namespace std {
using byte = unsigned char;
}
#endif // __HIPCC_RTC__
namespace ck { namespace ck {
using bhalf_t = ushort; using bhalf_t = ushort;
...@@ -19,21 +35,22 @@ template <typename T, index_t N> ...@@ -19,21 +35,22 @@ template <typename T, index_t N>
struct vector_type; struct vector_type;
// Caution: DO NOT REMOVE // Caution: DO NOT REMOVE
// intentionally have only declaration but no definition to cause compilation failure when trying to // intentionally have only declaration but no definition to cause compilation
// instantiate this template. The purpose is to catch user's mistake when trying to make "vector of // failure when trying to instantiate this template. The purpose is to catch
// vectors" // user's mistake when trying to make "vector of vectors"
template <typename T, index_t V, index_t N> template <typename T, index_t V, index_t N>
struct vector_type<T __attribute__((ext_vector_type(V))), N>; struct vector_type<T __attribute__((ext_vector_type(V))), N>;
// Caution: DO NOT REMOVE // Caution: DO NOT REMOVE
// intentionally have only declaration but no definition to cause compilation failure when trying to // intentionally have only declaration but no definition to cause compilation
// instantiate this template. The purpose is to catch user's mistake when trying to make "vector of // failure when trying to instantiate this template. The purpose is to catch
// vectors" // user's mistake when trying to make "vector of vectors"
template <typename T, index_t V, index_t N> template <typename T, index_t V, index_t N>
struct vector_type<vector_type<T, V>, N>; struct vector_type<vector_type<T, V>, N>;
// vector_type_maker // vector_type_maker
// This is the right way to handle "vector of vectors": making a bigger vector instead // This is the right way to handle "vector of vectors": making a bigger vector
// instead
template <typename T, index_t N> template <typename T, index_t N>
struct vector_type_maker struct vector_type_maker
{ {
...@@ -960,21 +977,233 @@ using f8x16_t = typename vector_type<f8_t, 16>::type; ...@@ -960,21 +977,233 @@ using f8x16_t = typename vector_type<f8_t, 16>::type;
using f8x32_t = typename vector_type<f8_t, 32>::type; using f8x32_t = typename vector_type<f8_t, 32>::type;
using f8x64_t = typename vector_type<f8_t, 64>::type; using f8x64_t = typename vector_type<f8_t, 64>::type;
template <typename T> // Convert X to Y
struct NumericLimits template <typename Y, typename X>
__host__ __device__ constexpr Y type_convert(X x)
{ {
__host__ __device__ static constexpr T Min() { return std::numeric_limits<T>::min(); } static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
__host__ __device__ static constexpr T Max() { return std::numeric_limits<T>::max(); } return static_cast<Y>(x);
}
// convert bfp16 to fp32
template <>
inline __host__ __device__ constexpr float type_convert<float, bhalf_t>(bhalf_t x)
{
union
{
uint32_t int32;
float fp32;
} u = {uint32_t(x) << 16};
__host__ __device__ static constexpr T Lowest() { return std::numeric_limits<T>::lowest(); } return u.fp32;
}
__host__ __device__ static constexpr T QuietNaN() // convert fp32 to bfp16
template <>
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float x)
{
union
{ {
return std::numeric_limits<T>::quiet_NaN(); float fp32;
} uint32_t int32;
} u = {x};
return uint16_t(u.int32 >> 16);
}
// convert bfp16 to fp16 via fp32
template <>
inline __host__ __device__ constexpr half_t type_convert<half_t, bhalf_t>(bhalf_t x)
{
float x_fp32 = type_convert<float>(x);
return static_cast<half_t>(x_fp32);
}
// convert fp16 to bfp16 via fp32
template <>
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, half_t>(half_t x)
{
float x_fp32 = static_cast<float>(x);
return type_convert<bhalf_t>(x_fp32);
}
// convert bfp16 to int32 via fp32
template <>
inline __host__ __device__ constexpr int32_t type_convert<int32_t, bhalf_t>(bhalf_t x)
{
float x_fp32 = type_convert<float>(x);
return static_cast<int32_t>(x_fp32);
}
// convert int32 to bfp16 via fp32
template <>
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int32_t>(int32_t x)
{
float x_fp32 = static_cast<float>(x);
return type_convert<bhalf_t>(x_fp32);
}
// convert bfp16 to int8 via fp32
template <>
inline __host__ __device__ constexpr int8_t type_convert<int8_t, bhalf_t>(bhalf_t x)
{
float x_fp32 = type_convert<float>(x);
return static_cast<int8_t>(x_fp32);
}
// convert int8 to bfp16 via fp32
template <>
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_t x)
{
float x_fp32 = static_cast<float>(x);
return type_convert<bhalf_t>(x_fp32);
}
__host__ __device__ static constexpr T Infinity() { return std::numeric_limits<T>::infinity(); } // Declare a template function for bf16 conversion using RTN
template <typename Y, typename X>
__host__ __device__ constexpr Y bf16_convert_rtn(X x);
// Convert fp32 to bf16 with RTN if higher precision is needed
template <>
inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, float>(float x)
{
union
{
float fp32;
uint32_t int32;
} u = {x};
// When the exponent bits are not all 1s, then the value is zero, normal,
// or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
// 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
// This causes the bfloat16's mantissa to be incremented by 1 if the 16
// least significant bits of the float mantissa are greater than 0x8000,
// or if they are equal to 0x8000 and the least significant bit of the
// bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
// the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
// has the value 0x7f, then incrementing it causes it to become 0x00 and
// the exponent is incremented by one, which is the next higher FP value
// to the unrounded bfloat16 value. When the bfloat16 value is subnormal
// with an exponent of 0x00 and a mantissa of 0x7f, it may be rounded up
// to a normal value with an exponent of 0x01 and a mantissa of 0x00.
// When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
// incrementing it causes it to become an exponent of 0xFF and a mantissa
// of 0x00, which is Inf, the next higher value to the unrounded value.
bool flag0 = ~u.int32 & 0x7f800000;
// When all of the exponent bits are 1, the value is Inf or NaN.
// Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
// mantissa bit. Quiet NaN is indicated by the most significant mantissa
// bit being 1. Signaling NaN is indicated by the most significant
// mantissa bit being 0 but some other bit(s) being 1. If any of the
// lower 16 bits of the mantissa are 1, we set the least significant bit
// of the bfloat16 mantissa, in order to preserve signaling NaN in case
// the bfloat16's mantissa bits are all 0.
bool flag1 = !flag0 && (u.int32 & 0xffff);
u.int32 += flag0 ? 0x7fff + ((u.int32 >> 16) & 1) : 0; // Round to nearest, round to even
u.int32 |= flag1 ? 0x10000 : 0x0; // Preserve signaling NaN
return uint16_t(u.int32 >> 16);
}
// convert fp16 to bfp16 via fp32 with RTN if higher precision is needed
template <>
inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(half_t x)
{
float x_fp32 = static_cast<float>(x);
return bf16_convert_rtn<bhalf_t>(x_fp32);
}
template <typename T>
struct NumericLimits;
template <>
struct NumericLimits<int32_t>
{
__host__ __device__ static constexpr int32_t Lowest() noexcept { return -2147483647 - 1; }
__host__ __device__ static constexpr int32_t Min() noexcept { return -2147483647 - 1; }
__host__ __device__ static constexpr int32_t Max() noexcept { return 2147483647; }
__host__ __device__ static constexpr int32_t Infinity() noexcept { return 0; }
__host__ __device__ static constexpr int32_t QuietNaN() { return 0; }
};
template <>
struct NumericLimits<int16_t>
{
__host__ __device__ static constexpr int16_t Lowest() noexcept { return -32768; }
__host__ __device__ static constexpr int16_t Min() noexcept { return -32768; }
__host__ __device__ static constexpr int16_t Max() noexcept { return 32767; }
__host__ __device__ static constexpr int16_t Infinity() noexcept { return 0; }
__host__ __device__ static constexpr int16_t QuietNaN() { return 0; }
};
template <>
struct NumericLimits<int8_t>
{
__host__ __device__ static constexpr int8_t Lowest() noexcept { return -128; }
__host__ __device__ static constexpr int8_t Min() noexcept { return -128; }
__host__ __device__ static constexpr int8_t Max() noexcept { return 127; }
__host__ __device__ static constexpr int8_t Infinity() noexcept { return 0; }
__host__ __device__ static constexpr int8_t QuietNaN() { return 0; }
};
template <>
struct NumericLimits<uint32_t>
{
__host__ __device__ static constexpr uint32_t Lowest() noexcept { return 0; }
__host__ __device__ static constexpr uint32_t Min() noexcept { return 0; }
__host__ __device__ static constexpr uint32_t Max() noexcept { return 4294967295U; }
__host__ __device__ static constexpr uint32_t Infinity() noexcept { return 0; }
__host__ __device__ static constexpr uint32_t QuietNaN() { return 0; }
};
template <>
struct NumericLimits<uint16_t>
{
__host__ __device__ static constexpr uint16_t Lowest() noexcept { return 0; }
__host__ __device__ static constexpr uint16_t Min() noexcept { return 0; }
__host__ __device__ static constexpr uint16_t Max() noexcept { return 65535U; }
__host__ __device__ static constexpr uint16_t Infinity() noexcept { return 0; }
__host__ __device__ static constexpr uint16_t QuietNaN() { return 0; }
};
template <>
struct NumericLimits<uint8_t>
{
__host__ __device__ static constexpr uint8_t Lowest() noexcept { return 0; }
__host__ __device__ static constexpr uint8_t Min() noexcept { return 0; }
__host__ __device__ static constexpr uint8_t Max() noexcept { return 255U; }
__host__ __device__ static constexpr uint8_t Infinity() noexcept { return 0; }
__host__ __device__ static constexpr uint8_t QuietNaN() { return 0; }
};
template <>
struct NumericLimits<float>
{
static constexpr unsigned int binary_min = 0x00800000;
static constexpr unsigned int binary_max = 0x7F7FFFFF;
static constexpr unsigned int binary_lowest = 0xFF7FFFFF;
static constexpr unsigned int binary_qnan = 0xFFC00001;
static constexpr unsigned int binary_inf = 0x7F8000000;
__host__ __device__ static constexpr float Min() { return bit_cast<float>(binary_min); }
__host__ __device__ static constexpr float Max() { return bit_cast<float>(binary_max); }
__host__ __device__ static constexpr float Lowest() { return bit_cast<float>(binary_lowest); }
__host__ __device__ static constexpr float QuietNaN() { return bit_cast<float>(binary_qnan); }
__host__ __device__ static constexpr float Infinity() { return bit_cast<float>(binary_inf); }
}; };
template <> template <>
...@@ -1024,3 +1253,5 @@ struct NumericLimits<f8_t> ...@@ -1024,3 +1253,5 @@ struct NumericLimits<f8_t>
}; };
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#ifndef UTILITY_DEBUG_HPP #ifndef UTILITY_DEBUG_HPP
#define UTILITY_DEBUG_HPP #define UTILITY_DEBUG_HPP
#include "type.hpp"
namespace ck { namespace ck {
namespace debug { namespace debug {
namespace detail { namespace detail {
template <typename T, typename Enable = void> template <typename T, typename Enable = void> struct PrintAsType;
struct PrintAsType;
template <typename T> template <typename T>
struct PrintAsType<T, typename std::enable_if<std::is_floating_point<T>::value>::type> struct PrintAsType<
{ T, typename std::enable_if<std::is_floating_point<T>::value>::type> {
using type = float; using type = float;
__host__ __device__ static void Print(const T& p) { printf("%.3f ", static_cast<type>(p)); } __host__ __device__ static void Print(const T &p) {
printf("%.3f ", static_cast<type>(p));
}
}; };
template <> template <> struct PrintAsType<ck::half_t, void> {
struct PrintAsType<ck::half_t, void> using type = float;
{ __host__ __device__ static void Print(const ck::half_t &p) {
using type = float; printf("%.3f ", static_cast<type>(p));
__host__ __device__ static void Print(const ck::half_t& p) }
{
printf("%.3f ", static_cast<type>(p));
}
}; };
template <typename T> template <typename T>
struct PrintAsType<T, typename std::enable_if<std::is_integral<T>::value>::type> struct PrintAsType<T,
{ typename std::enable_if<std::is_integral<T>::value>::type> {
using type = int; using type = int;
__host__ __device__ static void Print(const T& p) { printf("%d ", static_cast<type>(p)); } __host__ __device__ static void Print(const T &p) {
printf("%d ", static_cast<type>(p));
}
}; };
} // namespace detail } // namespace detail
// Print at runtime the data in shared memory in 128 bytes per row format given shared mem pointer // Print at runtime the data in shared memory in 128 bytes per row format given
// and the number of elements. Can optionally specify strides between elements and how many bytes' // shared mem pointer and the number of elements. Can optionally specify strides
// worth of data per row. // between elements and how many bytes' worth of data per row.
// //
// Usage example: // Usage example:
// //
// debug::print_shared(a_block_buf.p_data_, index_t(a_block_desc_k0_m_k1.GetElementSpaceSize())); // debug::print_shared(a_block_buf.p_data_,
// index_t(a_block_desc_k0_m_k1.GetElementSpaceSize()));
// //
template <typename T, index_t element_stride = 1, index_t row_bytes = 128> template <typename T, index_t element_stride = 1, index_t row_bytes = 128>
__device__ void print_shared(T const* p_shared, index_t num_elements) __device__ void print_shared(T const *p_shared, index_t num_elements) {
{ constexpr index_t row_elements = row_bytes / sizeof(T);
constexpr index_t row_elements = row_bytes / sizeof(T); static_assert((element_stride >= 1 && element_stride <= row_elements),
static_assert((element_stride >= 1 && element_stride <= row_elements), "element_stride should between [1, row_elements]");
"element_stride should between [1, row_elements]");
index_t wgid = blockIdx.x + blockIdx.y * gridDim.x + gridDim.x * gridDim.y * blockIdx.z; index_t wgid =
index_t tid = blockIdx.x + blockIdx.y * gridDim.x + gridDim.x * gridDim.y * blockIdx.z;
(threadIdx.z * (blockDim.x * blockDim.y)) + (threadIdx.y * blockDim.x) + threadIdx.x; index_t tid = (threadIdx.z * (blockDim.x * blockDim.y)) +
(threadIdx.y * blockDim.x) + threadIdx.x;
__syncthreads(); __syncthreads();
if(tid == 0) if (tid == 0) {
{ printf("\nWorkgroup id %d, bytes per row %d, element stride %d\n\n", wgid,
printf("\nWorkgroup id %d, bytes per row %d, element stride %d\n\n", row_bytes, element_stride);
wgid, for (index_t i = 0; i < num_elements; i += row_elements) {
row_bytes, printf("elem %5d: ", i);
element_stride); for (index_t j = 0; j < row_elements; j += element_stride) {
for(index_t i = 0; i < num_elements; i += row_elements) detail::PrintAsType<T>::Print(p_shared[i + j]);
{ }
printf("elem %5d: ", i);
for(index_t j = 0; j < row_elements; j += element_stride)
{
detail::PrintAsType<T>::Print(p_shared[i + j]);
}
printf("\n"); printf("\n");
}
printf("\n");
} }
printf("\n");
}
__syncthreads(); __syncthreads();
} }
} // namespace debug } // namespace debug
} // namespace ck } // namespace ck
#endif // UTILITY_DEBUG_HPP #endif // UTILITY_DEBUG_HPP
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -405,3 +408,5 @@ make_dynamic_buffer(T* p, ElementSpaceSize element_space_size, X invalid_element ...@@ -405,3 +408,5 @@ make_dynamic_buffer(T* p, ElementSpaceSize element_space_size, X invalid_element
} }
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#ifdef __HIPCC_RTC__
namespace std {
template <bool B, typename T = void>
using enable_if_t = typename enable_if<B, T>::type;
} // namespace std
#endif
namespace ck { namespace ck {
template <bool B, typename T = void> template <bool B, typename T = void> using enable_if = std::enable_if<B, T>;
using enable_if = std::enable_if<B, T>;
template <bool B, typename T = void> template <bool B, typename T = void>
using enable_if_t = typename std::enable_if<B, T>::type; using enable_if_t = typename std::enable_if<B, T>::type;
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -129,3 +132,5 @@ constexpr auto conditional_expr(X&& x, Y&& y) ...@@ -129,3 +132,5 @@ constexpr auto conditional_expr(X&& x, Y&& y)
} }
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -47,3 +50,5 @@ struct static_for ...@@ -47,3 +50,5 @@ struct static_for
}; };
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -142,3 +145,5 @@ struct ford ...@@ -142,3 +145,5 @@ struct ford
}; };
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -63,3 +66,5 @@ __host__ __device__ constexpr auto unpack2(F&& f, X&& x, Y&& y) ...@@ -63,3 +66,5 @@ __host__ __device__ constexpr auto unpack2(F&& f, X&& x, Y&& y)
} // namespace ck } // namespace ck
#endif #endif
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -121,3 +124,5 @@ __device__ float2_t atomic_max<float2_t>(float2_t* p_dst, const float2_t& x) ...@@ -121,3 +124,5 @@ __device__ float2_t atomic_max<float2_t>(float2_t* p_dst, const float2_t& x)
} }
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -26,3 +29,5 @@ __device__ index_t get_grid_size() { return gridDim.x; } ...@@ -26,3 +29,5 @@ __device__ index_t get_grid_size() { return gridDim.x; }
__device__ index_t get_block_size() { return blockDim.x; } __device__ index_t get_block_size() { return blockDim.x; }
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -20,3 +23,5 @@ struct ignore_t ...@@ -20,3 +23,5 @@ struct ignore_t
inline constexpr detail::ignore_t ignore; inline constexpr detail::ignore_t ignore;
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -234,3 +237,5 @@ inner_product<int8x16_t, int8x16_t, int32_t>(const int8x16_t& a, const int8x16_t ...@@ -234,3 +237,5 @@ inner_product<int8x16_t, int8x16_t, int32_t>(const int8x16_t& a, const int8x16_t
} }
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -5,47 +8,50 @@ ...@@ -5,47 +8,50 @@
namespace ck { namespace ck {
template <class T, T v> template <class T, T v> struct integral_constant {
struct integral_constant static constexpr T value = v;
{ typedef T value_type;
static constexpr T value = v; typedef integral_constant type;
typedef T value_type; __host__ __device__ constexpr operator value_type() const noexcept {
typedef integral_constant type; return value;
__host__ __device__ constexpr operator value_type() const noexcept { return value; } }
__host__ __device__ constexpr value_type operator()() const noexcept { return value; } __host__ __device__ constexpr value_type operator()() const noexcept {
return value;
}
}; };
template <typename TX, TX X, typename TY, TY Y> template <typename TX, TX X, typename TY, TY Y>
__host__ __device__ constexpr auto operator+(integral_constant<TX, X>, integral_constant<TY, Y>) __host__ __device__ constexpr auto operator+(integral_constant<TX, X>,
{ integral_constant<TY, Y>) {
return integral_constant<decltype(X + Y), X + Y>{}; return integral_constant<decltype(X + Y), X + Y>{};
} }
template <typename TX, TX X, typename TY, TY Y> template <typename TX, TX X, typename TY, TY Y>
__host__ __device__ constexpr auto operator-(integral_constant<TX, X>, integral_constant<TY, Y>) __host__ __device__ constexpr auto operator-(integral_constant<TX, X>,
{ integral_constant<TY, Y>) {
static_assert(Y <= X, "wrong!"); static_assert(Y <= X, "wrong!");
return integral_constant<decltype(X - Y), X - Y>{}; return integral_constant<decltype(X - Y), X - Y>{};
} }
template <typename TX, TX X, typename TY, TY Y> template <typename TX, TX X, typename TY, TY Y>
__host__ __device__ constexpr auto operator*(integral_constant<TX, X>, integral_constant<TY, Y>) __host__ __device__ constexpr auto operator*(integral_constant<TX, X>,
{ integral_constant<TY, Y>) {
return integral_constant<decltype(X * Y), X * Y>{}; return integral_constant<decltype(X * Y), X * Y>{};
} }
template <typename TX, TX X, typename TY, TY Y> template <typename TX, TX X, typename TY, TY Y>
__host__ __device__ constexpr auto operator/(integral_constant<TX, X>, integral_constant<TY, Y>) __host__ __device__ constexpr auto operator/(integral_constant<TX, X>,
{ integral_constant<TY, Y>) {
static_assert(Y > 0, "wrong!"); static_assert(Y > 0, "wrong!");
return integral_constant<decltype(X / Y), X / Y>{}; return integral_constant<decltype(X / Y), X / Y>{};
} }
template <typename TX, TX X, typename TY, TY Y> template <typename TX, TX X, typename TY, TY Y>
__host__ __device__ constexpr auto operator%(integral_constant<TX, X>, integral_constant<TY, Y>) __host__ __device__ constexpr auto operator%(integral_constant<TX, X>,
{ integral_constant<TY, Y>) {
static_assert(Y > 0, "wrong!"); static_assert(Y > 0, "wrong!");
return integral_constant<decltype(X % Y), X % Y>{}; return integral_constant<decltype(X % Y), X % Y>{};
} }
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -54,3 +57,5 @@ struct is_known_at_compile_time<Tuple<Ts...>> ...@@ -54,3 +57,5 @@ struct is_known_at_compile_time<Tuple<Ts...>>
}; };
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -6,155 +9,144 @@ ...@@ -6,155 +9,144 @@
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "integral_constant.hpp" #include "integral_constant.hpp"
#include "number.hpp" #include "number.hpp"
#include "type.hpp"
#include "tuple.hpp" #include "tuple.hpp"
#include "type.hpp"
#define INT32_MAX 2147483647
namespace ck { namespace ck {
// magic number division // magic number division
// Caution: // Caution:
// 1. For uint32_t as dividend: magic number division implementation being used would produce // 1. For uint32_t as dividend: magic number division implementation being
// correct result if the dividend is uint32_t and its value is within 31-bit value range. // used would produce correct result if the dividend is uint32_t and its value
// 2. For int32_t as dividendd: magic number division for int32_t dividened has not been // is within 31-bit value range.
// implemented, the int32_t dividend would be bit-wise interpreted as uint32_t and magic number // 2. For int32_t as dividendd: magic number division for int32_t dividened
// division implementation for uint32_t is then used. Therefore, dividend value need to be // has not been implemented, the int32_t dividend would be bit-wise
// non-negative. // interpreted as uint32_t and magic number division implementation for
// uint32_t is then used. Therefore, dividend value need to be non-negative.
// TODO: // TODO:
// 1. Implement magic number divison for int32_t // 1. Implement magic number divison for int32_t
// 2. Implement magic number divison for unit32_t with 32-bit value range // 2. Implement magic number divison for unit32_t with 32-bit value range
struct MagicDivision struct MagicDivision {
{ // uint32_t
// uint32_t __host__ __device__ static constexpr auto
__host__ __device__ static constexpr auto CalculateMagicNumbers(uint32_t divisor) CalculateMagicNumbers(uint32_t divisor) {
{ // WARNING: magic division is only applicable for division inside this
// WARNING: magic division is only applicable for division inside this range. // range. You should use the return value of CalculateMagicNumbers, if
// You should use the return value of CalculateMagicNumbers, if division is not inside this // division is not inside this range. The "else" logic below is to quiet
// range. The "else" logic below is to quiet down run-time error. // down run-time error.
if(divisor >= 1 && divisor <= INT32_MAX) if (divisor >= 1 && divisor <= INT32_MAX) {
{ uint32_t shift = 0;
uint32_t shift = 0; for (shift = 0; shift < 32; ++shift) {
for(shift = 0; shift < 32; ++shift) if ((1U << shift) >= divisor) {
{ break;
if((1U << shift) >= divisor)
{
break;
}
}
uint64_t one = 1;
uint64_t multiplier = ((one << 32) * ((one << shift) - divisor)) / divisor + 1;
// assert(multiplier <= 0xffffffffUL);
return make_tuple(uint32_t(multiplier), shift);
}
else
{
return make_tuple(uint32_t(0), uint32_t(0));
} }
} }
__host__ __device__ static constexpr uint32_t CalculateMagicMultiplier(uint32_t divisor) uint64_t one = 1;
{ uint64_t multiplier =
auto tmp = CalculateMagicNumbers(divisor); ((one << 32) * ((one << shift) - divisor)) / divisor + 1;
// assert(multiplier <= 0xffffffffUL);
return tmp[Number<0>{}];
} return make_tuple(uint32_t(multiplier), shift);
} else {
__host__ __device__ static constexpr uint32_t CalculateMagicShift(uint32_t divisor) return make_tuple(uint32_t(0), uint32_t(0));
{ }
auto tmp = CalculateMagicNumbers(divisor); }
return tmp[Number<1>{}]; __host__ __device__ static constexpr uint32_t
} CalculateMagicMultiplier(uint32_t divisor) {
auto tmp = CalculateMagicNumbers(divisor);
// integral_constant<uint32_t, .>
template <uint32_t Divisor> return tmp[Number<0>{}];
__host__ __device__ static constexpr auto }
CalculateMagicNumbers(integral_constant<uint32_t, Divisor>)
{ __host__ __device__ static constexpr uint32_t
constexpr auto tmp = CalculateMagicNumbers(uint32_t{Divisor}); CalculateMagicShift(uint32_t divisor) {
auto tmp = CalculateMagicNumbers(divisor);
constexpr uint32_t multiplier = tmp[Number<0>{}];
constexpr uint32_t shift = tmp[Number<1>{}]; return tmp[Number<1>{}];
}
return make_tuple(integral_constant<uint32_t, multiplier>{},
integral_constant<uint32_t, shift>{}); // integral_constant<uint32_t, .>
} template <uint32_t Divisor>
__host__ __device__ static constexpr auto
template <uint32_t Divisor> CalculateMagicNumbers(integral_constant<uint32_t, Divisor>) {
__host__ __device__ static constexpr auto constexpr auto tmp = CalculateMagicNumbers(uint32_t{Divisor});
CalculateMagicMultiplier(integral_constant<uint32_t, Divisor>)
{ constexpr uint32_t multiplier = tmp[Number<0>{}];
constexpr uint32_t multiplier = CalculateMagicMultiplier(uint32_t{Divisor}); constexpr uint32_t shift = tmp[Number<1>{}];
return integral_constant<uint32_t, multiplier>{}; return make_tuple(integral_constant<uint32_t, multiplier>{},
} integral_constant<uint32_t, shift>{});
}
template <uint32_t Divisor>
__host__ __device__ static constexpr auto template <uint32_t Divisor>
CalculateMagicShift(integral_constant<uint32_t, Divisor>) __host__ __device__ static constexpr auto
{ CalculateMagicMultiplier(integral_constant<uint32_t, Divisor>) {
constexpr uint32_t shift = CalculateMagicShift(uint32_t{Divisor}); constexpr uint32_t multiplier = CalculateMagicMultiplier(uint32_t{Divisor});
return integral_constant<uint32_t, shift>{}; return integral_constant<uint32_t, multiplier>{};
} }
// integral_constant<int32_t, .> template <uint32_t Divisor>
template <int32_t Divisor> __host__ __device__ static constexpr auto
__host__ __device__ static constexpr auto CalculateMagicShift(integral_constant<uint32_t, Divisor>) {
CalculateMagicNumbers(integral_constant<int32_t, Divisor>) constexpr uint32_t shift = CalculateMagicShift(uint32_t{Divisor});
{
return CalculateMagicNumbers(integral_constant<uint32_t, Divisor>{}); return integral_constant<uint32_t, shift>{};
} }
template <int32_t Divisor> // integral_constant<int32_t, .>
__host__ __device__ static constexpr auto template <int32_t Divisor>
CalculateMagicMultiplier(integral_constant<int32_t, Divisor>) __host__ __device__ static constexpr auto
{ CalculateMagicNumbers(integral_constant<int32_t, Divisor>) {
return CalculateMagicMultiplier(integral_constant<uint32_t, Divisor>{}); return CalculateMagicNumbers(integral_constant<uint32_t, Divisor>{});
} }
template <int32_t Divisor> template <int32_t Divisor>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
CalculateMagicShift(integral_constant<int32_t, Divisor>) CalculateMagicMultiplier(integral_constant<int32_t, Divisor>) {
{ return CalculateMagicMultiplier(integral_constant<uint32_t, Divisor>{});
return CalculateMagicShift(integral_constant<uint32_t, Divisor>{}); }
}
template <int32_t Divisor>
// magic division for uint32_t __host__ __device__ static constexpr auto
__device__ static constexpr uint32_t CalculateMagicShift(integral_constant<int32_t, Divisor>) {
DoMagicDivision(uint32_t dividend, uint32_t multiplier, uint32_t shift) return CalculateMagicShift(integral_constant<uint32_t, Divisor>{});
{ }
uint32_t tmp = __umulhi(dividend, multiplier);
return (tmp + dividend) >> shift; // magic division for uint32_t
} __device__ static constexpr uint32_t
DoMagicDivision(uint32_t dividend, uint32_t multiplier, uint32_t shift) {
__host__ static constexpr uint32_t uint32_t tmp = __umulhi(dividend, multiplier);
DoMagicDivision(uint32_t dividend, uint32_t multiplier, uint32_t shift) return (tmp + dividend) >> shift;
{ }
uint32_t tmp = static_cast<uint64_t>(dividend) * multiplier >> 32;
return (tmp + dividend) >> shift; __host__ static constexpr uint32_t
} DoMagicDivision(uint32_t dividend, uint32_t multiplier, uint32_t shift) {
uint32_t tmp = static_cast<uint64_t>(dividend) * multiplier >> 32;
// magic division for int32_t return (tmp + dividend) >> shift;
// HACK: use dividend_i32 as if it's uint32_t, dividend_i32 need to be }
// non-negative for result to be correct
// TODO: figure out how to do magic number divison for int32_t as dividended // magic division for int32_t
__device__ static constexpr int32_t // HACK: use dividend_i32 as if it's uint32_t, dividend_i32 need to be
DoMagicDivision(int32_t dividend_i32, uint32_t multiplier, uint32_t shift) // non-negative for result to be correct
{ // TODO: figure out how to do magic number divison for int32_t as dividended
uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32); __device__ static constexpr int32_t
uint32_t tmp = __umulhi(dividend_u32, multiplier); DoMagicDivision(int32_t dividend_i32, uint32_t multiplier, uint32_t shift) {
return (tmp + dividend_u32) >> shift; uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
} uint32_t tmp = __umulhi(dividend_u32, multiplier);
return (tmp + dividend_u32) >> shift;
__host__ static constexpr int32_t }
DoMagicDivision(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
{ __host__ static constexpr int32_t
uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32); DoMagicDivision(int32_t dividend_i32, uint32_t multiplier, uint32_t shift) {
uint32_t tmp = static_cast<uint64_t>(dividend_u32) * multiplier >> 32; uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
return (tmp + dividend_u32) >> shift; uint32_t tmp = static_cast<uint64_t>(dividend_u32) * multiplier >> 32;
} return (tmp + dividend_u32) >> shift;
}
}; };
struct MDiv struct MDiv
...@@ -230,3 +222,5 @@ struct MDiv2 ...@@ -230,3 +222,5 @@ struct MDiv2
}; };
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "enable_if.hpp"
#include "integral_constant.hpp" #include "integral_constant.hpp"
#include "number.hpp" #include "number.hpp"
#include "type.hpp" #include "type.hpp"
#include "enable_if.hpp"
namespace ck { namespace ck {
namespace math { namespace math {
template <typename T, T s> template <typename T, T s> struct scales {
struct scales __host__ __device__ constexpr T operator()(T a) const { return s * a; }
{
__host__ __device__ constexpr T operator()(T a) const { return s * a; }
}; };
template <typename T> template <typename T> struct plus {
struct plus __host__ __device__ constexpr T operator()(T a, T b) const { return a + b; }
{
__host__ __device__ constexpr T operator()(T a, T b) const { return a + b; }
}; };
template <typename T> template <typename T> struct minus {
struct minus __host__ __device__ constexpr T operator()(T a, T b) const { return a - b; }
{
__host__ __device__ constexpr T operator()(T a, T b) const { return a - b; }
}; };
struct multiplies struct multiplies {
{ template <typename A, typename B>
template <typename A, typename B> __host__ __device__ constexpr auto operator()(const A &a, const B &b) const {
__host__ __device__ constexpr auto operator()(const A& a, const B& b) const return a * b;
{ }
return a * b;
}
}; };
template <typename T> template <typename T> struct maximize {
struct maximize __host__ __device__ constexpr T operator()(T a, T b) const {
{ return a >= b ? a : b;
__host__ __device__ constexpr T operator()(T a, T b) const { return a >= b ? a : b; } }
}; };
template <typename T> template <typename T> struct minimize {
struct minimize __host__ __device__ constexpr T operator()(T a, T b) const {
{ return a <= b ? a : b;
__host__ __device__ constexpr T operator()(T a, T b) const { return a <= b ? a : b; } }
}; };
template <typename T> template <typename T> struct integer_divide_ceiler {
struct integer_divide_ceiler __host__ __device__ constexpr T operator()(T a, T b) const {
{ static_assert(is_same<T, index_t>{} || is_same<T, int>{}, "wrong type");
__host__ __device__ constexpr T operator()(T a, T b) const
{
static_assert(is_same<T, index_t>{} || is_same<T, int>{}, "wrong type");
return (a + b - Number<1>{}) / b; return (a + b - Number<1>{}) / b;
} }
}; };
template <typename X, typename Y> template <typename X, typename Y>
__host__ __device__ constexpr auto integer_divide_floor(X x, Y y) __host__ __device__ constexpr auto integer_divide_floor(X x, Y y) {
{ return x / y;
return x / y;
} }
template <typename X, typename Y> template <typename X, typename Y>
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y) __host__ __device__ constexpr auto integer_divide_ceil(X x, Y y) {
{ return (x + y - Number<1>{}) / y;
return (x + y - Number<1>{}) / y;
} }
template <typename X, typename Y> template <typename X, typename Y>
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y) __host__ __device__ constexpr auto integer_least_multiple(X x, Y y) {
{ return y * integer_divide_ceil(x, y);
return y * integer_divide_ceil(x, y);
} }
template <typename T> template <typename T> __host__ __device__ constexpr T max(T x) { return x; }
__host__ __device__ constexpr T max(T x)
{
return x;
}
template <typename T> template <typename T> __host__ __device__ constexpr T max(T x, T y) {
__host__ __device__ constexpr T max(T x, T y) return x > y ? x : y;
{
return x > y ? x : y;
} }
template <index_t X> template <index_t X>
__host__ __device__ constexpr index_t max(Number<X>, index_t y) __host__ __device__ constexpr index_t max(Number<X>, index_t y) {
{ return X > y ? X : y;
return X > y ? X : y;
} }
template <index_t Y> template <index_t Y>
__host__ __device__ constexpr index_t max(index_t x, Number<Y>) __host__ __device__ constexpr index_t max(index_t x, Number<Y>) {
{ return x > Y ? x : Y;
return x > Y ? x : Y;
} }
template <typename X, typename... Ys> template <typename X, typename... Ys>
__host__ __device__ constexpr auto max(X x, Ys... ys) __host__ __device__ constexpr auto max(X x, Ys... ys) {
{ static_assert(sizeof...(Ys) > 0, "not enough argument");
static_assert(sizeof...(Ys) > 0, "not enough argument");
return max(x, max(ys...)); return max(x, max(ys...));
} }
template <typename T> template <typename T> __host__ __device__ constexpr T min(T x) { return x; }
__host__ __device__ constexpr T min(T x)
{
return x;
}
template <typename T> template <typename T> __host__ __device__ constexpr T min(T x, T y) {
__host__ __device__ constexpr T min(T x, T y) return x < y ? x : y;
{
return x < y ? x : y;
} }
template <index_t X> template <index_t X>
__host__ __device__ constexpr index_t min(Number<X>, index_t y) __host__ __device__ constexpr index_t min(Number<X>, index_t y) {
{ return X < y ? X : y;
return X < y ? X : y;
} }
template <index_t Y> template <index_t Y>
__host__ __device__ constexpr index_t min(index_t x, Number<Y>) __host__ __device__ constexpr index_t min(index_t x, Number<Y>) {
{ return x < Y ? x : Y;
return x < Y ? x : Y;
} }
template <typename X, typename... Ys> template <typename X, typename... Ys>
__host__ __device__ constexpr auto min(X x, Ys... ys) __host__ __device__ constexpr auto min(X x, Ys... ys) {
{ static_assert(sizeof...(Ys) > 0, "not enough argument");
static_assert(sizeof...(Ys) > 0, "not enough argument");
return min(x, min(ys...)); return min(x, min(ys...));
} }
template <typename T> template <typename T>
__host__ __device__ constexpr T clamp(const T& x, const T& lowerbound, const T& upperbound) __host__ __device__ constexpr T clamp(const T &x, const T &lowerbound,
{ const T &upperbound) {
return min(max(x, lowerbound), upperbound); return min(max(x, lowerbound), upperbound);
} }
// disallow implicit type casting // disallow implicit type casting
template <typename T> template <typename T> __device__ T exp(T x);
__device__ T exp(T x);
// TODO: add f16 support using v_exp_f16 // TODO: add f16 support using v_exp_f16
template <> template <> __device__ float exp<float>(float x) { return __expf(x); }
__device__ float exp<float>(float x)
{
return __expf(x);
}
template <> template <> __device__ double exp<double>(double x) { return exp(x); }
__device__ double exp<double>(double x)
{
return exp(x);
}
static inline __host__ float exp(float x) { return ::expf(x); } // static inline __host__ float exp(float x) { return ::expf(x); }
static inline __host__ double exp(double x) { return std::exp(x); } // static inline __host__ double exp(double x) { return std::exp(x); }
// greatest common divisor, aka highest common factor // greatest common divisor, aka highest common factor
__host__ __device__ constexpr index_t gcd(index_t x, index_t y) __host__ __device__ constexpr index_t gcd(index_t x, index_t y) {
{ if (x < 0) {
if(x < 0) return gcd(-x, y);
{ } else if (y < 0) {
return gcd(-x, y); return gcd(x, -y);
} } else if (x == y || x == 0) {
else if(y < 0) return y;
{ } else if (y == 0) {
return gcd(x, -y); return x;
} } else if (x > y) {
else if(x == y || x == 0) return gcd(x % y, y);
{ } else {
return y; return gcd(x, y % x);
} }
else if(y == 0)
{
return x;
}
else if(x > y)
{
return gcd(x % y, y);
}
else
{
return gcd(x, y % x);
}
} }
template <index_t X, index_t Y> template <index_t X, index_t Y>
__host__ __device__ constexpr auto gcd(Number<X>, Number<Y>) __host__ __device__ constexpr auto gcd(Number<X>, Number<Y>) {
{ constexpr auto r = gcd(X, Y);
constexpr auto r = gcd(X, Y);
return Number<r>{}; return Number<r>{};
} }
template <typename X, typename... Ys, typename enable_if<sizeof...(Ys) >= 2, bool>::type = false> template <typename X, typename... Ys,
__host__ __device__ constexpr auto gcd(X x, Ys... ys) typename enable_if<sizeof...(Ys) >= 2, bool>::type = false>
{ __host__ __device__ constexpr auto gcd(X x, Ys... ys) {
return gcd(x, gcd(ys...)); return gcd(x, gcd(ys...));
} }
// least common multiple // least common multiple
template <typename X, typename Y> template <typename X, typename Y>
__host__ __device__ constexpr auto lcm(X x, Y y) __host__ __device__ constexpr auto lcm(X x, Y y) {
{ return (x * y) / gcd(x, y);
return (x * y) / gcd(x, y);
} }
template <typename X, typename... Ys, typename enable_if<sizeof...(Ys) >= 2, bool>::type = false> template <typename X, typename... Ys,
__host__ __device__ constexpr auto lcm(X x, Ys... ys) typename enable_if<sizeof...(Ys) >= 2, bool>::type = false>
{ __host__ __device__ constexpr auto lcm(X x, Ys... ys) {
return lcm(x, lcm(ys...)); return lcm(x, lcm(ys...));
} }
template <typename T> template <typename T> struct equal {
struct equal __host__ __device__ constexpr bool operator()(T x, T y) const {
{ return x == y;
__host__ __device__ constexpr bool operator()(T x, T y) const { return x == y; } }
}; };
template <typename T> template <typename T> struct less {
struct less __host__ __device__ constexpr bool operator()(T x, T y) const {
{ return x < y;
__host__ __device__ constexpr bool operator()(T x, T y) const { return x < y; } }
}; };
template <index_t X> template <index_t X>
...@@ -258,3 +206,5 @@ __host__ __device__ constexpr auto next_power_of_two(Number<X> x) ...@@ -258,3 +206,5 @@ __host__ __device__ constexpr auto next_power_of_two(Number<X> x)
} // namespace math } // namespace math
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -13,177 +16,169 @@ ...@@ -13,177 +16,169 @@
namespace ck { namespace ck {
namespace math { namespace math {
// math functions for the host, some are implemented by calling C++ std functions // math functions for the host, some are implemented by calling C++ std
// functions
static inline __host__ float abs(float x) { return std::abs(x); }; static inline __host__ float abs(float x) { return x < 0 ? x * -1.0 : x; };
static inline __host__ double abs(double x) { return std::abs(x); }; static inline __host__ double abs(double x) { return x < 0 ? x * -1.0 : x; };
static inline __host__ int8_t abs(int8_t x) static inline __host__ int8_t abs(int8_t x) {
{ int8_t sgn = x >> (8 - 1);
int8_t sgn = x >> (8 - 1);
return (x ^ sgn) - sgn; return (x ^ sgn) - sgn;
}; };
static inline __host__ int32_t abs(int32_t x) static inline __host__ int32_t abs(int32_t x) {
{ int32_t sgn = x >> (32 - 1);
int32_t sgn = x >> (32 - 1);
return (x ^ sgn) - sgn; return (x ^ sgn) - sgn;
}; };
static inline __host__ half_t abs(half_t x) static inline __host__ half_t abs(half_t x) {
{ uint16_t xx = ck::bit_cast<uint16_t>(x);
uint16_t xx = ck::bit_cast<uint16_t>(x);
uint16_t abs_xx = xx & 0x7fff; uint16_t abs_xx = xx & 0x7fff;
half_t abs_x = ck::bit_cast<half_t>(abs_xx); half_t abs_x = ck::bit_cast<half_t>(abs_xx);
return abs_x; return abs_x;
}; };
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
static inline __host__ int4_t abs(int4_t x) static inline __host__ int4_t abs(int4_t x) {
{ int4_t sgn = x >> (4 - 1);
int4_t sgn = x >> (4 - 1); return (x ^ sgn) - sgn;
return (x ^ sgn) - sgn;
} }
#endif #endif
static inline __host__ bool isnan(float x) { return std::isnan(x); }; // TODO: to bit arithmetic to figure it out
static inline __host__ bool isnan(float x) {
(void)x;
return false;
};
static inline __host__ bool isnan(double x) { return std::isnan(x); }; static inline __host__ bool isnan(double x) {
(void)x;
return false;
};
static inline __host__ bool isnan(int8_t x) static inline __host__ bool isnan(int8_t x) {
{ (void)x;
(void)x; return false;
return false;
}; };
static inline __host__ bool isnan(int32_t x) static inline __host__ bool isnan(int32_t x) {
{ (void)x;
(void)x; return false;
return false;
}; };
static inline __host__ bool isnan(half_t x) static inline __host__ bool isnan(half_t x) {
{ uint16_t xx = ck::bit_cast<uint16_t>(x);
uint16_t xx = ck::bit_cast<uint16_t>(x);
return (xx & 0x7FFF) > 0x7C00; return (xx & 0x7FFF) > 0x7C00;
}; };
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
static inline __host__ bool isnan(int4_t x) static inline __host__ bool isnan(int4_t x) {
{ (void)x;
(void)x; return false;
return false;
}; };
#endif #endif
static inline __host__ half_t sqrt(half_t x) // MIGRAPHX doesn't care about host compilation, just return identity values for
{ // now
return static_cast<half_t>(std::sqrt(static_cast<float>(x)));
};
static inline __host__ float sqrt(float x) { return std::sqrt(x); }; static inline __host__ half_t sqrt(half_t x) { return x; };
static inline __host__ double sqrt(double x) { return std::sqrt(x); }; static inline __host__ float sqrt(float x) { return x; };
static inline __host__ half_t tanh(half_t x) static inline __host__ double sqrt(double x) { return x; };
{
return static_cast<half_t>(std::tanh(static_cast<float>(x)));
};
static inline __host__ float tanh(float x) { return std::tanh(x); }; static inline __host__ half_t tanh(half_t x) { return x; };
static inline __host__ double tanh(double x) { return std::tanh(x); }; static inline __host__ float tanh(float x) { return x; };
// math functions for the HIP kernel, some are implemented by calling hip builtin functions static inline __host__ double tanh(double x) { return x; };
// math functions for the HIP kernel, some are implemented by calling hip
// builtin functions
static inline __device__ float abs(float x) { return ::abs(x); }; static inline __device__ float abs(float x) { return ::abs(x); };
static inline __device__ double abs(double x) { return ::abs(x); }; static inline __device__ double abs(double x) { return ::abs(x); };
static inline __device__ int8_t abs(int8_t x) static inline __device__ int8_t abs(int8_t x) {
{ int8_t sgn = x >> (8 - 1);
int8_t sgn = x >> (8 - 1);
return (x ^ sgn) - sgn; return (x ^ sgn) - sgn;
}; };
static inline __device__ int32_t abs(int32_t x) static inline __device__ int32_t abs(int32_t x) {
{ int32_t sgn = x >> (32 - 1);
int32_t sgn = x >> (32 - 1);
return (x ^ sgn) - sgn; return (x ^ sgn) - sgn;
}; };
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
static inline __device__ int4_t abs(int4_t x) static inline __device__ int4_t abs(int4_t x) {
{ int4_t sgn = x >> (4 - 1);
int4_t sgn = x >> (4 - 1);
return (x ^ sgn) - sgn; return (x ^ sgn) - sgn;
}; };
#endif #endif
static inline __device__ half_t abs(half_t x) static inline __device__ half_t abs(half_t x) {
{ uint16_t xx = ck::bit_cast<uint16_t>(x);
uint16_t xx = ck::bit_cast<uint16_t>(x);
uint16_t abs_xx = xx & 0x7fff; uint16_t abs_xx = xx & 0x7fff;
half_t abs_x = ck::bit_cast<half_t>(abs_xx); half_t abs_x = ck::bit_cast<half_t>(abs_xx);
return abs_x; return abs_x;
}; };
static inline __device__ bool isnan(float x) { return ::isnan(x); }; static inline __device__ bool isnan(float x) { return ::isnan(x); };
static inline __device__ bool isnan(double x) { return ::isnan(x); }; static inline __device__ bool isnan(double x) { return ::isnan(x); };
static inline __device__ bool isnan(int8_t x) static inline __device__ bool isnan(int8_t x) {
{ (void)x;
(void)x; return false;
return false;
}; };
static inline __device__ bool isnan(int32_t x) static inline __device__ bool isnan(int32_t x) {
{ (void)x;
(void)x; return false;
return false;
}; };
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
static inline __device__ bool isnan(int4_t x) static inline __device__ bool isnan(int4_t x) {
{ (void)x;
(void)x; return false;
return false;
}; };
#endif #endif
static inline __device__ bool isnan(half_t x) static inline __device__ bool isnan(half_t x) {
{ uint16_t xx = ck::bit_cast<uint16_t>(x);
uint16_t xx = ck::bit_cast<uint16_t>(x);
return (xx & 0x7FFF) > 0x7C00; return (xx & 0x7FFF) > 0x7C00;
}; };
static inline __device__ half_t sqrt(half_t x) static inline __device__ half_t sqrt(half_t x) {
{ return static_cast<half_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x)));
return static_cast<half_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x)));
}; };
static inline __device__ float sqrt(float x) { return __builtin_amdgcn_sqrtf(x); }; static inline __device__ float sqrt(float x) {
return __builtin_amdgcn_sqrtf(x);
};
static inline __device__ double sqrt(double x) { return __builtin_amdgcn_sqrt(x); }; static inline __device__ double sqrt(double x) {
return __builtin_amdgcn_sqrt(x);
};
static inline __device__ half_t tanh(half_t x) static inline __device__ half_t tanh(half_t x) {
{ return static_cast<half_t>(::tanhf(static_cast<float>(x)));
return static_cast<half_t>(::tanhf(static_cast<float>(x)));
}; };
static inline __device__ float tanh(float x) { return ::tanhf(x); }; static inline __device__ float tanh(float x) { return ::tanhf(x); };
...@@ -192,3 +187,5 @@ static inline __device__ double tanh(double x) { return ::tanh(x); }; ...@@ -192,3 +187,5 @@ static inline __device__ double tanh(double x) { return ::tanh(x); };
} // namespace math } // namespace math
} // namespace ck } // namespace ck
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -10,3 +13,5 @@ ...@@ -10,3 +13,5 @@
#else #else
#include "statically_indexed_array_multi_index.hpp" #include "statically_indexed_array_multi_index.hpp"
#endif #endif
#pragma clang diagnostic pop
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...@@ -16,3 +19,5 @@ using LongNumber = integral_constant<long_index_t, N>; ...@@ -16,3 +19,5 @@ using LongNumber = integral_constant<long_index_t, N>;
} // namespace ck } // namespace ck
#endif #endif
#pragma clang diagnostic pop
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment