Commit 5d015452 authored by Chaitanya Inumella's avatar Chaitanya Inumella
Browse files

Rebased the hipTENSOR development branch with the contraction branch

parents b7fa6bb1 ed3feb4d
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_CONTAINER_ELEMENT_PICKER_HPP #ifndef CK_CONTAINER_ELEMENT_PICKER_HPP
#define CK_CONTAINER_ELEMENT_PICKER_HPP #define CK_CONTAINER_ELEMENT_PICKER_HPP
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_CONTAINER_HELPER_HPP #ifndef CK_CONTAINER_HELPER_HPP
#define CK_CONTAINER_HELPER_HPP #define CK_CONTAINER_HELPER_HPP
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include "statically_indexed_array.hpp"
#include "ck/utility/statically_indexed_array.hpp"
namespace ck { namespace ck {
...@@ -928,14 +932,14 @@ using int8x64_t = typename vector_type<int8_t, 64>::type; ...@@ -928,14 +932,14 @@ using int8x64_t = typename vector_type<int8_t, 64>::type;
// Convert X to Y // Convert X to Y
template <typename Y, typename X> template <typename Y, typename X>
__host__ __device__ Y type_convert(X x) __host__ __device__ constexpr Y type_convert(X x)
{ {
return static_cast<Y>(x); return static_cast<Y>(x);
} }
// convert bfp16 to fp32 // convert bfp16 to fp32
template <> template <>
inline __host__ __device__ float type_convert<float, bhalf_t>(bhalf_t x) inline __host__ __device__ constexpr float type_convert<float, bhalf_t>(bhalf_t x)
{ {
union union
{ {
...@@ -948,7 +952,7 @@ inline __host__ __device__ float type_convert<float, bhalf_t>(bhalf_t x) ...@@ -948,7 +952,7 @@ inline __host__ __device__ float type_convert<float, bhalf_t>(bhalf_t x)
// convert fp32 to bfp16 // convert fp32 to bfp16
template <> template <>
inline __host__ __device__ bhalf_t type_convert<bhalf_t, float>(float x) inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float x)
{ {
union union
{ {
...@@ -1000,6 +1004,11 @@ struct NumericLimits ...@@ -1000,6 +1004,11 @@ struct NumericLimits
__host__ __device__ static constexpr T Max() { return std::numeric_limits<T>::max(); } __host__ __device__ static constexpr T Max() { return std::numeric_limits<T>::max(); }
__host__ __device__ static constexpr T Lowest() { return std::numeric_limits<T>::lowest(); } __host__ __device__ static constexpr T Lowest() { return std::numeric_limits<T>::lowest(); }
__host__ __device__ static constexpr T QuietNaN()
{
return std::numeric_limits<T>::quiet_NaN();
}
}; };
template <> template <>
...@@ -1008,12 +1017,15 @@ struct NumericLimits<half_t> ...@@ -1008,12 +1017,15 @@ struct NumericLimits<half_t>
static constexpr unsigned short binary_min = 0x0400; static constexpr unsigned short binary_min = 0x0400;
static constexpr unsigned short binary_max = 0x7BFF; static constexpr unsigned short binary_max = 0x7BFF;
static constexpr unsigned short binary_lowest = 0xFBFF; static constexpr unsigned short binary_lowest = 0xFBFF;
static constexpr unsigned short binary_qnan = 0x7FFF;
__host__ __device__ static constexpr half_t Min() { return bit_cast<half_t>(binary_min); } __host__ __device__ static constexpr half_t Min() { return bit_cast<half_t>(binary_min); }
__host__ __device__ static constexpr half_t Max() { return bit_cast<half_t>(binary_max); } __host__ __device__ static constexpr half_t Max() { return bit_cast<half_t>(binary_max); }
__host__ __device__ static constexpr half_t Lowest() { return bit_cast<half_t>(binary_lowest); } __host__ __device__ static constexpr half_t Lowest() { return bit_cast<half_t>(binary_lowest); }
__host__ __device__ static constexpr half_t QuietNaN() { return bit_cast<half_t>(binary_qnan); }
}; };
} // namespace ck } // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef UTILITY_DEBUG_HPP #ifndef UTILITY_DEBUG_HPP
#define UTILITY_DEBUG_HPP #define UTILITY_DEBUG_HPP
...@@ -9,21 +12,27 @@ template <typename T, typename Enable = void> ...@@ -9,21 +12,27 @@ template <typename T, typename Enable = void>
struct PrintAsType; struct PrintAsType;
template <typename T> template <typename T>
struct PrintAsType<T, typename std::enable_if<std::is_floating_point<T>::value>::value> struct PrintAsType<T, typename std::enable_if<std::is_floating_point<T>::value>::type>
{ {
using type = float; using type = float;
__host__ __device__ static void Print(const T& p) { printf("%.3f ", static_cast<type>(p)); }
}; };
template <> template <>
struct PrintAsType<ck::half_t, void> struct PrintAsType<ck::half_t, void>
{ {
using type = float; using type = float;
__host__ __device__ static void Print(const ck::half_t& p)
{
printf("%.3f ", static_cast<type>(p));
}
}; };
template <typename T> template <typename T>
struct PrintAsType<T, typename std::enable_if<std::is_integral<T>::value>::value> struct PrintAsType<T, typename std::enable_if<std::is_integral<T>::value>::type>
{ {
using type = int; using type = int;
__host__ __device__ static void Print(const T& p) { printf("%d ", static_cast<type>(p)); }
}; };
} // namespace detail } // namespace detail
...@@ -38,7 +47,6 @@ struct PrintAsType<T, typename std::enable_if<std::is_integral<T>::value>::value ...@@ -38,7 +47,6 @@ struct PrintAsType<T, typename std::enable_if<std::is_integral<T>::value>::value
template <typename T, index_t element_stride = 1, index_t row_bytes = 128> template <typename T, index_t element_stride = 1, index_t row_bytes = 128>
__device__ void print_shared(T const* p_shared, index_t num_elements) __device__ void print_shared(T const* p_shared, index_t num_elements)
{ {
using PrintType = typename detail::PrintAsType<T>::type;
constexpr index_t row_elements = row_bytes / sizeof(T); constexpr index_t row_elements = row_bytes / sizeof(T);
static_assert((element_stride >= 1 && element_stride <= row_elements), static_assert((element_stride >= 1 && element_stride <= row_elements),
"element_stride should between [1, row_elements]"); "element_stride should between [1, row_elements]");
...@@ -60,7 +68,7 @@ __device__ void print_shared(T const* p_shared, index_t num_elements) ...@@ -60,7 +68,7 @@ __device__ void print_shared(T const* p_shared, index_t num_elements)
printf("elem %5d: ", i); printf("elem %5d: ", i);
for(index_t j = 0; j < row_elements; j += element_stride) for(index_t j = 0; j < row_elements; j += element_stride)
{ {
printf("%.0f ", static_cast<PrintType>(p_shared[i + j])); detail::PrintAsType<T>::Print(p_shared[i + j]);
} }
printf("\n"); printf("\n");
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include "config.hpp"
#include "ck/ck.hpp"
#include "enable_if.hpp" #include "enable_if.hpp"
#include "c_style_pointer_cast.hpp" #include "c_style_pointer_cast.hpp"
#include "amd_buffer_addressing.hpp" #include "amd_buffer_addressing.hpp"
......
#ifndef CK_ENABLE_IF_HPP // SPDX-License-Identifier: MIT
#define CK_ENABLE_IF_HPP // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck { namespace ck {
...@@ -10,4 +12,3 @@ template <bool B, typename T = void> ...@@ -10,4 +12,3 @@ template <bool B, typename T = void>
using enable_if_t = typename std::enable_if<B, T>::type; using enable_if_t = typename std::enable_if<B, T>::type;
} // namespace ck } // namespace ck
#endif
#ifndef CK_FUNCTIONAL_HPP // SPDX-License-Identifier: MIT
#define CK_FUNCTIONAL_HPP // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "integral_constant.hpp" #pragma once
#include "type.hpp"
#include "ck/utility/integral_constant.hpp"
#include "ck/utility/type.hpp"
namespace ck { namespace ck {
...@@ -113,4 +115,3 @@ template <bool predicate, class X, class Y> ...@@ -113,4 +115,3 @@ template <bool predicate, class X, class Y>
using conditional_t = typename conditional<predicate, X, Y>::type; using conditional_t = typename conditional<predicate, X, Y>::type;
} // namespace ck } // namespace ck
#endif
#ifndef CK_FUNCTIONAL2_HPP // SPDX-License-Identifier: MIT
#define CK_FUNCTIONAL2_HPP // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "functional.hpp" #pragma once
#include "sequence.hpp"
#include "ck/utility/functional.hpp"
#include "ck/utility/sequence.hpp"
namespace ck { namespace ck {
...@@ -45,4 +47,3 @@ struct static_for ...@@ -45,4 +47,3 @@ struct static_for
}; };
} // namespace ck } // namespace ck
#endif
#ifndef CK_FUNCTIONAL3_HPP // SPDX-License-Identifier: MIT
#define CK_FUNCTIONAL3_HPP // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "functional.hpp" #pragma once
#include "functional2.hpp"
#include "sequence.hpp" #include "ck/ck.hpp"
#include "multi_index.hpp" #include "ck/utility/functional.hpp"
#include "ck/utility/functional2.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/utility/multi_index.hpp"
namespace ck { namespace ck {
...@@ -139,4 +142,3 @@ struct ford ...@@ -139,4 +142,3 @@ struct ford
}; };
} // namespace ck } // namespace ck
#endif
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_FUNCTIONAL4_HPP #ifndef CK_FUNCTIONAL4_HPP
#define CK_FUNCTIONAL4_HPP #define CK_FUNCTIONAL4_HPP
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include "data_type.hpp" #include "data_type.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include "config.hpp"
#include "ck/ck.hpp"
namespace ck { namespace ck {
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_IGNORE_HPP #ifndef CK_IGNORE_HPP
#define CK_IGNORE_HPP #define CK_IGNORE_HPP
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include "data_type.hpp" #include "data_type.hpp"
......
#ifndef CK_INTEGRAL_CONSTANT_HPP // SPDX-License-Identifier: MIT
#define CK_INTEGRAL_CONSTANT_HPP // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck { namespace ck {
...@@ -47,4 +49,3 @@ __host__ __device__ constexpr auto operator%(integral_constant<TX, X>, integral_ ...@@ -47,4 +49,3 @@ __host__ __device__ constexpr auto operator%(integral_constant<TX, X>, integral_
} }
} // namespace ck } // namespace ck
#endif
#ifndef IS_KNOWN_AT_COMPILE_TIME_HPP // SPDX-License-Identifier: MIT
#define IS_KNOWN_AT_COMPILE_TIME_HPP // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "config.hpp" #pragma once
#include "ck/ck.hpp"
#include "integral_constant.hpp" #include "integral_constant.hpp"
#include "sequence.hpp" #include "sequence.hpp"
#include "tuple.hpp" #include "tuple.hpp"
...@@ -52,4 +54,3 @@ struct is_known_at_compile_time<Tuple<Ts...>> ...@@ -52,4 +54,3 @@ struct is_known_at_compile_time<Tuple<Ts...>>
}; };
} // namespace ck } // namespace ck
#endif
#ifndef CK_MAGIC_DIVISION_HPP // SPDX-License-Identifier: MIT
#define CK_MAGIC_DIVISION_HPP // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "config.hpp" #pragma once
#include "ck/ck.hpp"
#include "integral_constant.hpp" #include "integral_constant.hpp"
#include "number.hpp" #include "number.hpp"
#include "type.hpp" #include "type.hpp"
...@@ -156,5 +158,3 @@ struct MagicDivision ...@@ -156,5 +158,3 @@ struct MagicDivision
}; };
} // namespace ck } // namespace ck
#endif
#ifndef CK_MATH_HPP // SPDX-License-Identifier: MIT
#define CK_MATH_HPP // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "config.hpp" #pragma once
#include "ck/ck.hpp"
#include "integral_constant.hpp" #include "integral_constant.hpp"
#include "number.hpp" #include "number.hpp"
#include "type.hpp" #include "type.hpp"
...@@ -142,6 +144,24 @@ __host__ __device__ constexpr auto min(X x, Ys... ys) ...@@ -142,6 +144,24 @@ __host__ __device__ constexpr auto min(X x, Ys... ys)
return min(x, min(ys...)); return min(x, min(ys...));
} }
// disallow implicit type casting
template <typename T>
__device__ T exp(T x);
// TODO: add f16 support using v_exp_f16
template <>
__device__ float exp<float>(float x)
{
return __expf(x);
}
template <>
__device__ double exp<double>(double x)
{
return exp(x);
}
// greatest common divisor, aka highest common factor // greatest common divisor, aka highest common factor
__host__ __device__ constexpr index_t gcd(index_t x, index_t y) __host__ __device__ constexpr index_t gcd(index_t x, index_t y)
{ {
...@@ -212,5 +232,3 @@ struct less ...@@ -212,5 +232,3 @@ struct less
} // namespace math } // namespace math
} // namespace ck } // namespace ck
#endif
#ifndef CK_MATH_V2_HPP // SPDX-License-Identifier: MIT
#define CK_MATH_V2_HPP // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cmath> #include <cmath>
#include "data_type.hpp"
#include "half.hpp" #include "ck/utility/data_type.hpp"
#include "ck/utility/type.hpp"
namespace ck { namespace ck {
namespace math { namespace math {
// math functions for the host, some are implemented by calling C++ std functions
static inline __host__ float abs(float x) { return std::abs(x); }; static inline __host__ float abs(float x) { return std::abs(x); };
static inline __host__ double abs(double x) { return std::abs(x); }; static inline __host__ double abs(double x) { return std::abs(x); };
...@@ -28,26 +33,26 @@ static inline __host__ int32_t abs(int32_t x) ...@@ -28,26 +33,26 @@ static inline __host__ int32_t abs(int32_t x)
static inline __host__ half_t abs(half_t x) static inline __host__ half_t abs(half_t x)
{ {
half_float::half xx = *reinterpret_cast<half_float::half*>(&x); uint16_t xx = ck::bit_cast<uint16_t>(x);
half_float::half abs_xx = half_float::abs(xx); uint16_t abs_xx = xx & 0x7fff;
half_t abs_x = *reinterpret_cast<half_t*>(&abs_xx); half_t abs_x = ck::bit_cast<half_t>(abs_xx);
return abs_x; return abs_x;
}; };
static inline __host__ float isnan(float x) { return std::isnan(x); }; static inline __host__ bool isnan(float x) { return std::isnan(x); };
static inline __host__ double isnan(double x) { return std::isnan(x); }; static inline __host__ bool isnan(double x) { return std::isnan(x); };
static inline __host__ int8_t isnan(int8_t x) static inline __host__ bool isnan(int8_t x)
{ {
(void)x; (void)x;
return false; return false;
}; };
static inline __host__ int32_t isnan(int32_t x) static inline __host__ bool isnan(int32_t x)
{ {
(void)x; (void)x;
return false; return false;
...@@ -55,12 +60,58 @@ static inline __host__ int32_t isnan(int32_t x) ...@@ -55,12 +60,58 @@ static inline __host__ int32_t isnan(int32_t x)
static inline __host__ bool isnan(half_t x) static inline __host__ bool isnan(half_t x)
{ {
half_float::half xx = *reinterpret_cast<half_float::half*>(&x); uint16_t xx = ck::bit_cast<uint16_t>(x);
return half_float::isnan(xx); return (xx & 0x7FFF) > 0x7C00;
}; };
static inline __host__ float sqrt(float x) { return std::sqrt(x); };
static inline __host__ double sqrt(double x) { return std::sqrt(x); };
// math functions for the HIP kernel, some are implemented by calling hip builtin functions
static inline __device__ float abs(float x) { return ::abs(x); };
static inline __device__ double abs(double x) { return ::abs(x); };
static inline __device__ int8_t abs(int8_t x)
{
int8_t sgn = x >> (8 - 1);
return (x ^ sgn) - sgn;
};
static inline __device__ int32_t abs(int32_t x)
{
int32_t sgn = x >> (32 - 1);
return (x ^ sgn) - sgn;
};
static inline __device__ half_t abs(half_t x) { return ::__habs(x); };
static inline __device__ bool isnan(float x) { return ::isnan(x); };
static inline __device__ bool isnan(double x) { return ::isnan(x); };
static inline __device__ bool isnan(int8_t x)
{
(void)x;
return false;
};
static inline __device__ bool isnan(int32_t x)
{
(void)x;
return false;
};
static inline __device__ bool isnan(half_t x) { return ::__hisnan(x); };
static inline __device__ float sqrt(float x) { return ::sqrtf(x); };
static inline __device__ double sqrt(double x) { return ::sqrt(x); };
} // namespace math } // namespace math
} // namespace ck } // namespace ck
#endif
#ifndef CK_MULTI_INDEX_HPP // SPDX-License-Identifier: MIT
#define CK_MULTI_INDEX_HPP // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "common_header.hpp" #include "common_header.hpp"
...@@ -8,5 +10,3 @@ ...@@ -8,5 +10,3 @@
#else #else
#include "statically_indexed_array_multi_index.hpp" #include "statically_indexed_array_multi_index.hpp"
#endif #endif
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment