Commit 3b301468 authored by ThomasNing's avatar ThomasNing
Browse files

pre-merge with the develop branch need to fix the bug

parents 6db81a11 9c5b2f39
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include <numeric> #include <numeric>
#include <sstream> #include <sstream>
#include "ck/library/utility/numeric.hpp"
#include "ck/utility/common_header.hpp" #include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#pragma once #pragma once
#include "ck/library/utility/numeric.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/device_conv_tensor_rearrange.hpp" #include "ck/tensor_operation/gpu/device/device_conv_tensor_rearrange.hpp"
......
...@@ -430,6 +430,7 @@ struct G_NDHW : public BaseTensorLayout ...@@ -430,6 +430,7 @@ struct G_NDHW : public BaseTensorLayout
} // namespace convolution } // namespace convolution
#ifndef CK_CODE_GEN_RTC
template < template <
typename Layout, typename Layout,
typename std::enable_if<std::is_base_of<BaseTensorLayout, Layout>::value, bool>::type = false> typename std::enable_if<std::is_base_of<BaseTensorLayout, Layout>::value, bool>::type = false>
...@@ -438,6 +439,7 @@ std::ostream& operator<<(std::ostream& os, const Layout&) ...@@ -438,6 +439,7 @@ std::ostream& operator<<(std::ostream& os, const Layout&)
os << Layout::name; os << Layout::name;
return os; return os;
} }
#endif
} // namespace tensor_layout } // namespace tensor_layout
} // namespace ck } // namespace ck
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -340,8 +340,8 @@ struct Bilinear ...@@ -340,8 +340,8 @@ struct Bilinear
}; };
template <> template <>
__host__ __device__ constexpr void operator()<std::int8_t, std::int32_t, std::int8_t>( __host__ __device__ constexpr void
std::int8_t& y, const std::int32_t& x0, const std::int8_t& x1) const operator()<int8_t, int32_t, int8_t>(int8_t& y, const int32_t& x0, const int8_t& x1) const
{ {
y = type_convert<int8_t>(alpha_ * type_convert<float>(x0) + y = type_convert<int8_t>(alpha_ * type_convert<float>(x0) +
beta_ * type_convert<float>(x1)); beta_ * type_convert<float>(x1));
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -533,7 +533,7 @@ struct NormalizeInInfer ...@@ -533,7 +533,7 @@ struct NormalizeInInfer
const T3& gamma, const T3& gamma,
const T4& beta) const const T4& beta) const
{ {
static_assert(std::is_same<T2, float>::value || std::is_same<T2, double>::value, static_assert(is_same<T2, float>::value || is_same<T2, double>::value,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
using ck::type_convert; using ck::type_convert;
......
...@@ -252,7 +252,7 @@ struct PassThroughPack2 ...@@ -252,7 +252,7 @@ struct PassThroughPack2
template <typename Y, typename X> template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const; __host__ __device__ void operator()(Y& y, const X& x) const;
__host__ __device__ constexpr void operator()(ck::half2_t& y, const ck::f8x2_t& x) const __host__ __device__ constexpr void operator()(half2_t& y, const f8x2_t& x) const
{ {
auto t = type_convert<float2_t>(x); auto t = type_convert<float2_t>(x);
y = type_convert<half2_t>(t); y = type_convert<half2_t>(t);
...@@ -479,7 +479,7 @@ struct PassThrough ...@@ -479,7 +479,7 @@ struct PassThrough
template <> template <>
__host__ __device__ void operator()<bf8_t, half_t>(bf8_t& y, const half_t& x) const __host__ __device__ void operator()<bf8_t, half_t>(bf8_t& y, const half_t& x) const
{ {
y = ck::type_convert<bf8_t>(x); y = type_convert<bf8_t>(x);
} }
}; };
...@@ -552,21 +552,21 @@ struct Scale ...@@ -552,21 +552,21 @@ struct Scale
template <typename Y, typename X> template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const __host__ __device__ void operator()(Y& y, const X& x) const
{ {
y = ck::type_convert<Y>(ck::type_convert<float>(x) * scale_); y = type_convert<Y>(type_convert<float>(x) * scale_);
} }
template <> template <>
__host__ __device__ void operator()<half_t, half_t>(half_t& y, const half_t& x) const __host__ __device__ void operator()<half_t, half_t>(half_t& y, const half_t& x) const
{ {
y = ck::type_convert<half_t>(scale_) * x; y = type_convert<half_t>(scale_) * x;
}; };
template <> template <>
__host__ __device__ void operator()<bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x) const __host__ __device__ void operator()<bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x) const
{ {
const float x_tmp = ck::type_convert<float>(x); const float x_tmp = type_convert<float>(x);
const float y_tmp = scale_ * x_tmp; const float y_tmp = scale_ * x_tmp;
y = ck::type_convert<bhalf_t>(y_tmp); y = type_convert<bhalf_t>(y_tmp);
}; };
template <> template <>
...@@ -584,7 +584,7 @@ struct Scale ...@@ -584,7 +584,7 @@ struct Scale
template <> template <>
__host__ __device__ void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const __host__ __device__ void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const
{ {
y = ck::type_convert<int8_t>(scale_ * ck::type_convert<float>(x)); y = type_convert<int8_t>(scale_ * type_convert<float>(x));
}; };
float scale_; float scale_;
...@@ -600,7 +600,7 @@ struct ScaleAndResetNaNToMinusInfinity ...@@ -600,7 +600,7 @@ struct ScaleAndResetNaNToMinusInfinity
template <> template <>
__host__ __device__ void operator()<float, float>(float& y, const float& x) const __host__ __device__ void operator()<float, float>(float& y, const float& x) const
{ {
y = ck::math::isnan(x) ? -ck::NumericLimits<float>::Infinity() : scale_ * x; y = math::isnan(x) ? -NumericLimits<float>::Infinity() : scale_ * x;
}; };
float scale_; float scale_;
...@@ -671,12 +671,13 @@ struct UnaryAbs ...@@ -671,12 +671,13 @@ struct UnaryAbs
template <typename T> template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const __host__ __device__ void operator()(T& y, const T& x) const
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value || static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value || is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value, is_same<T, int8_t>::value,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = ck::math::abs(x); y = math::abs(x);
}; };
template <> template <>
...@@ -694,7 +695,7 @@ struct UnarySqrt ...@@ -694,7 +695,7 @@ struct UnarySqrt
static_assert(is_same<T, float>::value || is_same<T, double>::value, static_assert(is_same<T, float>::value || is_same<T, double>::value,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = ck::math::sqrt(x); y = math::sqrt(x);
}; };
}; };
...@@ -713,9 +714,9 @@ struct Relu ...@@ -713,9 +714,9 @@ struct Relu
template <> template <>
__host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const __host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const
{ {
float x_f32 = ck::type_convert<float>(x); float x_f32 = type_convert<float>(x);
float y_f32 = x_f32 > 0 ? x_f32 : 0; float y_f32 = x_f32 > 0 ? x_f32 : 0;
y = ck::type_convert<bhalf_t>(y_f32); y = type_convert<bhalf_t>(y_f32);
} }
}; };
...@@ -731,7 +732,7 @@ struct FastGelu ...@@ -731,7 +732,7 @@ struct FastGelu
template <typename Y, typename X> template <typename Y, typename X>
__device__ void operator()(Y& y, const X& x) const; __device__ void operator()(Y& y, const X& x) const;
#ifndef CK_CODE_GEN_RTC
template <> template <>
__host__ void operator()<float, float>(float& y, const float& x) const __host__ void operator()<float, float>(float& y, const float& x) const
{ {
...@@ -742,6 +743,7 @@ struct FastGelu ...@@ -742,6 +743,7 @@ struct FastGelu
const float emu = exp(u); const float emu = exp(u);
y = x / (1.f + emu); y = x / (1.f + emu);
} }
#endif
// device code, use lower precision "__ocml_exp_f32" and "rcp" // device code, use lower precision "__ocml_exp_f32" and "rcp"
template <> template <>
...@@ -753,7 +755,7 @@ struct FastGelu ...@@ -753,7 +755,7 @@ struct FastGelu
const float u = x * (c1 * x * x + c2); const float u = x * (c1 * x * x + c2);
const float emu = __ocml_exp_f32(u); const float emu = __ocml_exp_f32(u);
y = x * ck::math::rcp(1.f + emu); y = x * math::rcp(1.f + emu);
} }
template <> template <>
...@@ -851,10 +853,9 @@ struct Gelu ...@@ -851,10 +853,9 @@ struct Gelu
} }
template <> template <>
__host__ __device__ void operator()<ck::half_t, ck::half_t>(ck::half_t& y, __host__ __device__ void operator()<half_t, half_t>(half_t& y, const half_t& x) const
const ck::half_t& x) const
{ {
y = ck::half_t(0.5) * x * (ck::half_t(1) + ck::half_t(erf(float(0.70710678118f * x)))); y = half_t(0.5) * x * (half_t(1) + half_t(erf(float(0.70710678118f * x))));
} }
}; };
...@@ -868,7 +869,7 @@ struct Sigmoid ...@@ -868,7 +869,7 @@ struct Sigmoid
is_same<T, int32_t>::value, is_same<T, int32_t>::value,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
constexpr T one = type_convert<T>(1); constexpr T one = type_convert<T>(1);
y = one / (one + ck::math::exp(-x)); y = one / (one + math::exp(-x));
}; };
}; };
...@@ -877,11 +878,11 @@ struct Silu ...@@ -877,11 +878,11 @@ struct Silu
template <typename T> template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const __host__ __device__ void operator()(T& y, const T& x) const
{ {
static_assert(is_same_v<T, float> || is_same_v<T, double> || is_same_v<T, ck::half_t> || static_assert(is_same_v<T, float> || is_same_v<T, double> || is_same_v<T, half_t> ||
is_same_v<T, int8_t> || is_same_v<T, int32_t>, is_same_v<T, int8_t> || is_same_v<T, int32_t>,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
constexpr T one = type_convert<T>(1); constexpr T one = type_convert<T>(1);
y = x * (one / (one + ck::math::exp(-x))); y = x * (one / (one + math::exp(-x)));
}; };
}; };
...@@ -895,7 +896,7 @@ struct TanH ...@@ -895,7 +896,7 @@ struct TanH
is_same<T, int32_t>::value, is_same<T, int32_t>::value,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = ck::math::tanh(x); y = math::tanh(x);
}; };
}; };
...@@ -905,11 +906,11 @@ struct ACos ...@@ -905,11 +906,11 @@ struct ACos
__host__ __device__ void operator()(T& y, const T& x) const __host__ __device__ void operator()(T& y, const T& x) const
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value || static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value || is_same<T, half_t>::value || is_same<T, int8_t>::value ||
is_same<T, int32_t>::value, is_same<T, int32_t>::value,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = ck::math::acos(x); y = math::acos(x);
}; };
}; };
...@@ -919,11 +920,11 @@ struct Neg ...@@ -919,11 +920,11 @@ struct Neg
__host__ __device__ void operator()(T& y, const T& x) const __host__ __device__ void operator()(T& y, const T& x) const
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value || static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value || is_same<T, half_t>::value || is_same<T, int8_t>::value ||
is_same<T, int32_t>::value, is_same<T, int32_t>::value,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = ck::math::neg(x); y = math::neg(x);
}; };
}; };
...@@ -933,11 +934,11 @@ struct ATan ...@@ -933,11 +934,11 @@ struct ATan
__host__ __device__ void operator()(T& y, const T& x) const __host__ __device__ void operator()(T& y, const T& x) const
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value || static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value || is_same<T, half_t>::value || is_same<T, int8_t>::value ||
is_same<T, int32_t>::value, is_same<T, int32_t>::value,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = ck::math::atan(x); y = math::atan(x);
}; };
}; };
...@@ -947,11 +948,11 @@ struct Sin ...@@ -947,11 +948,11 @@ struct Sin
__host__ __device__ void operator()(T& y, const T& x) const __host__ __device__ void operator()(T& y, const T& x) const
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value || static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value || is_same<T, half_t>::value || is_same<T, int8_t>::value ||
is_same<T, int32_t>::value, is_same<T, int32_t>::value,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = ck::math::sin(x); y = math::sin(x);
}; };
}; };
...@@ -961,11 +962,11 @@ struct ASinH ...@@ -961,11 +962,11 @@ struct ASinH
__host__ __device__ void operator()(T& y, const T& x) const __host__ __device__ void operator()(T& y, const T& x) const
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value || static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value || is_same<T, half_t>::value || is_same<T, int8_t>::value ||
is_same<T, int32_t>::value, is_same<T, int32_t>::value,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = ck::math::asinh(x); y = math::asinh(x);
}; };
}; };
...@@ -975,11 +976,11 @@ struct Cos ...@@ -975,11 +976,11 @@ struct Cos
__host__ __device__ void operator()(T& y, const T& x) const __host__ __device__ void operator()(T& y, const T& x) const
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value || static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value || is_same<T, half_t>::value || is_same<T, int8_t>::value ||
is_same<T, int32_t>::value, is_same<T, int32_t>::value,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = ck::math::cos(x); y = cos(x);
}; };
}; };
...@@ -989,11 +990,11 @@ struct ACosH ...@@ -989,11 +990,11 @@ struct ACosH
__host__ __device__ void operator()(T& y, const T& x) const __host__ __device__ void operator()(T& y, const T& x) const
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value || static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value || is_same<T, half_t>::value || is_same<T, int8_t>::value ||
is_same<T, int32_t>::value, is_same<T, int32_t>::value,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = ck::math::acosh(x); y = math::acosh(x);
}; };
}; };
...@@ -1003,11 +1004,11 @@ struct Tan ...@@ -1003,11 +1004,11 @@ struct Tan
__host__ __device__ void operator()(T& y, const T& x) const __host__ __device__ void operator()(T& y, const T& x) const
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value || static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value || is_same<T, half_t>::value || is_same<T, int8_t>::value ||
is_same<T, int32_t>::value, is_same<T, int32_t>::value,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = ck::math::tan(x); y = math::tan(x);
}; };
}; };
...@@ -1017,11 +1018,11 @@ struct ATanH ...@@ -1017,11 +1018,11 @@ struct ATanH
__host__ __device__ void operator()(T& y, const T& x) const __host__ __device__ void operator()(T& y, const T& x) const
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value || static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value || is_same<T, half_t>::value || is_same<T, int8_t>::value ||
is_same<T, int32_t>::value, is_same<T, int32_t>::value,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = ck::math::atanh(x); y = math::atanh(x);
}; };
}; };
...@@ -1031,11 +1032,11 @@ struct SinH ...@@ -1031,11 +1032,11 @@ struct SinH
__host__ __device__ void operator()(T& y, const T& x) const __host__ __device__ void operator()(T& y, const T& x) const
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value || static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value || is_same<T, half_t>::value || is_same<T, int8_t>::value ||
is_same<T, int32_t>::value, is_same<T, int32_t>::value,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = ck::math::sinh(x); y = math::sinh(x);
}; };
}; };
...@@ -1045,11 +1046,11 @@ struct Ceil ...@@ -1045,11 +1046,11 @@ struct Ceil
__host__ __device__ void operator()(T& y, const T& x) const __host__ __device__ void operator()(T& y, const T& x) const
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value || static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value || is_same<T, half_t>::value || is_same<T, int8_t>::value ||
is_same<T, int32_t>::value, is_same<T, int32_t>::value,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = ck::math::ceil(x); y = math::ceil(x);
}; };
}; };
...@@ -1059,11 +1060,11 @@ struct Exp ...@@ -1059,11 +1060,11 @@ struct Exp
__host__ __device__ void operator()(T& y, const T& x) const __host__ __device__ void operator()(T& y, const T& x) const
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value || static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value || is_same<T, half_t>::value || is_same<T, int8_t>::value ||
is_same<T, int32_t>::value, is_same<T, int32_t>::value,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = ck::math::exp(x); y = math::exp(x);
}; };
}; };
...@@ -1073,11 +1074,11 @@ struct CosH ...@@ -1073,11 +1074,11 @@ struct CosH
__host__ __device__ void operator()(T& y, const T& x) const __host__ __device__ void operator()(T& y, const T& x) const
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value || static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value || is_same<T, half_t>::value || is_same<T, int8_t>::value ||
is_same<T, int32_t>::value, is_same<T, int32_t>::value,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = ck::math::cosh(x); y = math::cosh(x);
}; };
}; };
...@@ -1087,11 +1088,11 @@ struct Floor ...@@ -1087,11 +1088,11 @@ struct Floor
__host__ __device__ void operator()(T& y, const T& x) const __host__ __device__ void operator()(T& y, const T& x) const
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value || static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value || is_same<T, half_t>::value || is_same<T, int8_t>::value ||
is_same<T, int32_t>::value, is_same<T, int32_t>::value,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = ck::math::floor(x); y = math::floor(x);
}; };
}; };
...@@ -1101,11 +1102,11 @@ struct Log ...@@ -1101,11 +1102,11 @@ struct Log
__host__ __device__ void operator()(T& y, const T& x) const __host__ __device__ void operator()(T& y, const T& x) const
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value || static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value || is_same<T, half_t>::value || is_same<T, int8_t>::value ||
is_same<T, int32_t>::value, is_same<T, int32_t>::value,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = ck::math::log(x); y = math::log(x);
}; };
}; };
...@@ -1115,11 +1116,11 @@ struct ASin ...@@ -1115,11 +1116,11 @@ struct ASin
__host__ __device__ void operator()(T& y, const T& x) const __host__ __device__ void operator()(T& y, const T& x) const
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value || static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value || is_same<T, half_t>::value || is_same<T, int8_t>::value ||
is_same<T, int32_t>::value, is_same<T, int32_t>::value,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = ck::math::asin(x); y = math::asin(x);
}; };
}; };
...@@ -1129,11 +1130,11 @@ struct Rcp ...@@ -1129,11 +1130,11 @@ struct Rcp
__host__ __device__ void operator()(T& y, const T& x) const __host__ __device__ void operator()(T& y, const T& x) const
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value || static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, ck::half_t>::value || is_same<T, int8_t>::value || is_same<T, half_t>::value || is_same<T, int8_t>::value ||
is_same<T, int32_t>::value, is_same<T, int32_t>::value,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = ck::math::rcp(x); y = math::rcp(x);
}; };
}; };
...@@ -1153,7 +1154,7 @@ struct Swish ...@@ -1153,7 +1154,7 @@ struct Swish
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
float bx = -beta_ * type_convert<float>(x); float bx = -beta_ * type_convert<float>(x);
y = type_convert<Y>(x / (1.f + ck::math::exp(bx))); y = type_convert<Y>(x / (1.f + math::exp(bx)));
}; };
const float beta_; const float beta_;
...@@ -1172,7 +1173,7 @@ struct SoftRelu ...@@ -1172,7 +1173,7 @@ struct SoftRelu
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_); T casted_alpha = type_convert<T>(alpha_);
constexpr T one = type_convert<T>(1); constexpr T one = type_convert<T>(1);
y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha; y = math::log(one + math::exp(x * casted_alpha)) / casted_alpha;
} }
const float alpha_; const float alpha_;
}; };
...@@ -1193,7 +1194,7 @@ struct Power ...@@ -1193,7 +1194,7 @@ struct Power
T casted_beta = type_convert<T>(beta_); T casted_beta = type_convert<T>(beta_);
T casted_gamma = type_convert<T>(gamma_); T casted_gamma = type_convert<T>(gamma_);
T shifted_scaled_x = casted_alpha + casted_beta * x; T shifted_scaled_x = casted_alpha + casted_beta * x;
y = ck::math::pow(shifted_scaled_x, casted_gamma); y = math::pow(shifted_scaled_x, casted_gamma);
} }
const float alpha_; const float alpha_;
const float beta_; const float beta_;
...@@ -1213,7 +1214,7 @@ struct ClippedRelu ...@@ -1213,7 +1214,7 @@ struct ClippedRelu
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_); T casted_alpha = type_convert<T>(alpha_);
T casted_beta = type_convert<T>(beta_); T casted_beta = type_convert<T>(beta_);
y = ck::math::min(casted_beta, ck::math::max(casted_alpha, x)); y = math::min(casted_beta, math::max(casted_alpha, x));
} }
const float alpha_; const float alpha_;
const float beta_; const float beta_;
...@@ -1248,7 +1249,7 @@ struct Elu ...@@ -1248,7 +1249,7 @@ struct Elu
is_same<T, int8_t>::value, is_same<T, int8_t>::value,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_); T casted_alpha = type_convert<T>(alpha_);
y = x > 0 ? x : casted_alpha * ck::math::expm1(x); y = x > 0 ? x : casted_alpha * math::expm1(x);
} }
const float alpha_; const float alpha_;
}; };
...@@ -1350,10 +1351,10 @@ struct FastNumericArrayConverter ...@@ -1350,10 +1351,10 @@ struct FastNumericArrayConverter
}; };
template <> template <>
struct FastNumericArrayConverter<uint8_t, ck::half_t, 4> struct FastNumericArrayConverter<uint8_t, half_t, 4>
{ {
using InputArray = vector_type<uint8_t, 4>; using InputArray = vector_type<uint8_t, 4>;
using OutputArray = vector_type<ck::half_t, 4>; using OutputArray = vector_type<half_t, 4>;
__device__ static OutputArray convert(InputArray const& Input) __device__ static OutputArray convert(InputArray const& Input)
{ {
...@@ -1383,13 +1384,13 @@ struct FastNumericArrayConverter<uint8_t, ck::half_t, 4> ...@@ -1383,13 +1384,13 @@ struct FastNumericArrayConverter<uint8_t, ck::half_t, 4>
}; };
template <index_t N> template <index_t N>
struct FastNumericArrayConverter<uint8_t, ck::half_t, N> struct FastNumericArrayConverter<uint8_t, half_t, N>
{ {
static constexpr int VEC_WIDTH = 4; static constexpr int VEC_WIDTH = 4;
static_assert(!(N % VEC_WIDTH), "N must be multiple of 4."); static_assert(!(N % VEC_WIDTH), "N must be multiple of 4.");
using InputArray = vector_type<uint8_t, N>; using InputArray = vector_type<uint8_t, N>;
using OutputArray = vector_type<ck::half_t, N>; using OutputArray = vector_type<half_t, N>;
__device__ static OutputArray convert(InputArray const& Input) __device__ static OutputArray convert(InputArray const& Input)
{ {
...@@ -1398,7 +1399,7 @@ struct FastNumericArrayConverter<uint8_t, ck::half_t, N> ...@@ -1398,7 +1399,7 @@ struct FastNumericArrayConverter<uint8_t, ck::half_t, N>
OutputArray Output; OutputArray Output;
using Vec_InputArray = vector_type<uint8_t, 4>; using Vec_InputArray = vector_type<uint8_t, 4>;
using Vec_OutputArray = vector_type<ck::half_t, 4>; using Vec_OutputArray = vector_type<half_t, 4>;
Vec_OutputArray* half_4_ptr = reinterpret_cast<Vec_OutputArray*>(&Output); Vec_OutputArray* half_4_ptr = reinterpret_cast<Vec_OutputArray*>(&Output);
Vec_InputArray const* uint8_4_ptr = reinterpret_cast<Vec_InputArray const*>(&Input); Vec_InputArray const* uint8_4_ptr = reinterpret_cast<Vec_InputArray const*>(&Input);
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include "ck/utility/math.hpp" #include "ck/utility/math.hpp"
#include "ck/utility/number.hpp" #include "ck/utility/number.hpp"
#include "ck/utility/tuple.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp" #include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp" #include "ck/tensor_description/multi_index_transform_helper.hpp"
#ifndef CK_CODE_GEN_RTC
#include <limits> #include <limits>
#include <stdlib.h> #include <stdlib.h>
#endif
namespace ck { namespace ck {
...@@ -978,8 +981,7 @@ struct BlockToCTileMap_3DGrid_KSplit ...@@ -978,8 +981,7 @@ struct BlockToCTileMap_3DGrid_KSplit
// Create 3D grid // Create 3D grid
const auto M0 = math::integer_divide_ceil(M, MPerBlock); const auto M0 = math::integer_divide_ceil(M, MPerBlock);
const auto N0 = math::integer_divide_ceil(N, NPerBlock); const auto N0 = math::integer_divide_ceil(N, NPerBlock);
return make_tuple(N0, M0, k_split);
return std::make_tuple(N0, M0, k_split);
} }
template <typename TopIdx> template <typename TopIdx>
...@@ -1103,7 +1105,7 @@ struct BlockToCTileMap_GemmStreamK ...@@ -1103,7 +1105,7 @@ struct BlockToCTileMap_GemmStreamK
uint32_t dp_for_sk_iters = k_iters_per_tile.get(); uint32_t dp_for_sk_iters = k_iters_per_tile.get();
uint32_t best_sk_score = uint32_t best_sk_score =
std::numeric_limits<int>::max(); // we need to find the smallest sk iters NumericLimits<int32_t>::Max(); // we need to find the smallest sk iters
for(uint32_t tentative_sk_blocks = min_sk_tiles; tentative_sk_blocks < max_sk_tiles; for(uint32_t tentative_sk_blocks = min_sk_tiles; tentative_sk_blocks < max_sk_tiles;
tentative_sk_blocks++) tentative_sk_blocks++)
{ {
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -423,10 +423,17 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle ...@@ -423,10 +423,17 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
} }
template <typename AsLayout, GemmSpecialization GemmSpec> template <typename AsLayout, GemmSpecialization GemmSpec>
__host__ __device__ static auto __host__ __device__ static auto MakeAsGridDescriptor_M_K(
MakeAsGridDescriptor_M_K(const std::array<index_t, NumATensor>& MRaws, #ifdef CK_CODE_GEN_RTC
const std::array<index_t, NumATensor>& KRaws, const ck::Array<index_t, NumATensor>& MRaws,
const std::array<index_t, NumATensor>& AsStride) const ck::Array<index_t, NumATensor>& KRaws,
const ck::Array<index_t, NumATensor>& AsStride
#else
const std::array<index_t, NumATensor>& MRaws,
const std::array<index_t, NumATensor>& KRaws,
const std::array<index_t, NumATensor>& AsStride
#endif
)
{ {
return generate_tuple( return generate_tuple(
[&](auto i) { [&](auto i) {
...@@ -462,10 +469,17 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle ...@@ -462,10 +469,17 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
} }
template <typename BsLayout, GemmSpecialization GemmSpec> template <typename BsLayout, GemmSpecialization GemmSpec>
__host__ __device__ static auto __host__ __device__ static auto MakeBsGridDescriptor_N_K(
MakeBsGridDescriptor_N_K(const std::array<index_t, NumBTensor>& NRaws, #ifdef CK_CODE_GEN_RTC
const std::array<index_t, NumBTensor>& KRaws, const ck::Array<index_t, NumBTensor>& NRaws,
const std::array<index_t, NumBTensor>& BsStride) const ck::Array<index_t, NumBTensor>& KRaws,
const ck::Array<index_t, NumBTensor>& BsStride
#else
const std::array<index_t, NumBTensor>& NRaws,
const std::array<index_t, NumBTensor>& KRaws,
const std::array<index_t, NumBTensor>& BsStride
#endif
)
{ {
return generate_tuple( return generate_tuple(
[&](auto i) { [&](auto i) {
...@@ -500,10 +514,17 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle ...@@ -500,10 +514,17 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
} }
template <typename DsLayout, GemmSpecialization GemmSpec> template <typename DsLayout, GemmSpecialization GemmSpec>
__host__ __device__ static auto __host__ __device__ static auto MakeDsGridDescriptor_M_N(
MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws, #ifdef CK_CODE_GEN_RTC
const std::array<index_t, NumDTensor>& NRaws, const ck::Array<index_t, NumDTensor>& MRaws,
const std::array<index_t, NumDTensor>& DsStride) const ck::Array<index_t, NumDTensor>& NRaws,
const ck::Array<index_t, NumDTensor>& DsStride
#else
const std::array<index_t, NumDTensor>& MRaws,
const std::array<index_t, NumDTensor>& NRaws,
const std::array<index_t, NumDTensor>& DsStride
#endif
)
{ {
return generate_tuple( return generate_tuple(
[&](auto i) { [&](auto i) {
...@@ -969,9 +990,15 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle ...@@ -969,9 +990,15 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
const index_t M, const index_t M,
const index_t N, const index_t N,
const index_t K, const index_t K,
#ifdef CK_CODE_GEN_RTC
const ck::Array<index_t, NumATensor> StrideAs,
const ck::Array<index_t, NumBTensor> StrideBs,
const ck::Array<index_t, NumDTensor> StrideDs,
#else
const std::array<index_t, NumATensor> StrideAs, const std::array<index_t, NumATensor> StrideAs,
const std::array<index_t, NumBTensor> StrideBs, const std::array<index_t, NumBTensor> StrideBs,
const std::array<index_t, NumDTensor> StrideDs, const std::array<index_t, NumDTensor> StrideDs,
#endif
const index_t StrideE, const index_t StrideE,
const Block2ETileMap& block_2_etile_map) const Block2ETileMap& block_2_etile_map)
{ {
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -473,11 +473,19 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -473,11 +473,19 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
} }
#ifdef CK_CODE_GEN_RTC
template <typename DsLayout, GemmSpecialization GemmSpec>
__host__ __device__ static auto
MakeDsGridDescriptor_M_N(const ck::Array<index_t, NumDTensor>& MRaws,
const ck::Array<index_t, NumDTensor>& NRaws,
const ck::Array<index_t, NumDTensor>& DsStride)
#else
template <typename DsLayout, GemmSpecialization GemmSpec> template <typename DsLayout, GemmSpecialization GemmSpec>
__host__ __device__ static auto __host__ __device__ static auto
MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws, MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
const std::array<index_t, NumDTensor>& NRaws, const std::array<index_t, NumDTensor>& NRaws,
const std::array<index_t, NumDTensor>& DsStride) const std::array<index_t, NumDTensor>& DsStride)
#endif
{ {
return generate_tuple( return generate_tuple(
[&](auto i) { [&](auto i) {
...@@ -941,7 +949,11 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -941,7 +949,11 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
const index_t K, const index_t K,
const index_t StrideA, const index_t StrideA,
const index_t StrideB, const index_t StrideB,
#ifdef CK_CODE_GEN_RTC
const ck::Array<index_t, NumDTensor> StrideDs,
#else
const std::array<index_t, NumDTensor> StrideDs, const std::array<index_t, NumDTensor> StrideDs,
#endif
const index_t StrideE, const index_t StrideE,
const Block2ETileMap& block_2_etile_map) const Block2ETileMap& block_2_etile_map)
{ {
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#ifndef CK_CODE_GEN_RTC
#include <iostream> #include <iostream>
#include <ostream> #include <ostream>
#endif
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp"
...@@ -53,12 +54,15 @@ constexpr auto GridwiseGemmPipeline_Selector() ...@@ -53,12 +54,15 @@ constexpr auto GridwiseGemmPipeline_Selector()
} }
else else
{ {
#ifndef CK_CODE_GEN_RTC
std::cerr << "GridwiseGemmPipeline configuration is not available" << std::endl; std::cerr << "GridwiseGemmPipeline configuration is not available" << std::endl;
#endif
} }
} }
} // namespace ck } // namespace ck
#ifndef CK_CODE_GEN_RTC
inline std::ostream& operator<<(std::ostream& os, const ck::PipelineVersion& p) inline std::ostream& operator<<(std::ostream& os, const ck::PipelineVersion& p)
{ {
switch(p) switch(p)
...@@ -71,3 +75,4 @@ inline std::ostream& operator<<(std::ostream& os, const ck::PipelineVersion& p) ...@@ -71,3 +75,4 @@ inline std::ostream& operator<<(std::ostream& os, const ck::PipelineVersion& p)
} }
return os; return os;
} }
#endif
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -212,7 +212,7 @@ template <typename SrcData, ...@@ -212,7 +212,7 @@ template <typename SrcData,
typename enable_if<DstDesc::IsKnownAtCompileTime(), bool>::type = false> typename enable_if<DstDesc::IsKnownAtCompileTime(), bool>::type = false>
struct ThreadwiseTensorSliceTransfer_v2 struct ThreadwiseTensorSliceTransfer_v2
{ {
static_assert((InvalidElementAsNaN && !std::is_integral<DstData>::value) || static_assert((InvalidElementAsNaN && !ck::is_integral<DstData>::value) ||
(!InvalidElementAsNaN), (!InvalidElementAsNaN),
"Filling invalid element as NaN is only for floating point types"); "Filling invalid element as NaN is only for floating point types");
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include "ck/library/utility/numeric.hpp"
#include "ck/utility/common_header.hpp" #include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
...@@ -148,8 +147,8 @@ struct TransformConvFwdToGemm ...@@ -148,8 +147,8 @@ struct TransformConvFwdToGemm
template <typename ConvDimsType, template <typename ConvDimsType,
typename ConvSpatialDimsType, typename ConvSpatialDimsType,
index_t NDim = NDimSpatial, index_t NDim = NDimSpatial,
typename std::enable_if<NDim == 1, bool>::type = false> typename ck::enable_if<NDim == 1, bool>::type = false>
__host__ __device__ TransformConvFwdToGemm(const ConvDimsType& a_g_n_c_wis_lengths, __host__ __device__ TransformConvFwdToGemm(const ConvDimsType& a_g_n_c_wis_lengths,
const ConvDimsType& a_g_n_c_wis_strides, const ConvDimsType& a_g_n_c_wis_strides,
const ConvDimsType& b_g_k_c_xs_lengths, const ConvDimsType& b_g_k_c_xs_lengths,
...@@ -201,11 +200,15 @@ struct TransformConvFwdToGemm ...@@ -201,11 +200,15 @@ struct TransformConvFwdToGemm
InRightPadW_{input_right_pads[I0]}, InRightPadW_{input_right_pads[I0]},
ZYX_{X_} ZYX_{X_}
{ {
#ifdef CK_CODE_GEN_RTC
static_assert(is_same_v<ConvSpatialDimsType, ck::Array<IndexType, NDimSpatial>>);
static_assert(is_same_v<ConvDimsType, ck::Array<IndexType, NDimSpatial + I3>>);
#else
static_assert(is_same_v<ConvSpatialDimsType, std::array<IndexType, NDimSpatial>> || static_assert(is_same_v<ConvSpatialDimsType, std::array<IndexType, NDimSpatial>> ||
is_same_v<ConvSpatialDimsType, ck::Array<IndexType, NDimSpatial>>); is_same_v<ConvSpatialDimsType, ck::Array<IndexType, NDimSpatial>>);
static_assert(is_same_v<ConvDimsType, std::array<IndexType, NDimSpatial + I3>> || static_assert(is_same_v<ConvDimsType, std::array<IndexType, NDimSpatial + I3>> ||
is_same_v<ConvDimsType, ck::Array<IndexType, NDimSpatial + I3>>); is_same_v<ConvDimsType, ck::Array<IndexType, NDimSpatial + I3>>);
#endif
if constexpr(SplitN) if constexpr(SplitN)
{ {
N_ = GetSplitedNSize( N_ = GetSplitedNSize(
...@@ -219,8 +222,8 @@ struct TransformConvFwdToGemm ...@@ -219,8 +222,8 @@ struct TransformConvFwdToGemm
template <typename ConvDimsType, template <typename ConvDimsType,
typename ConvSpatialDimsType, typename ConvSpatialDimsType,
index_t NDim = NDimSpatial, index_t NDim = NDimSpatial,
typename std::enable_if<NDim == 2, bool>::type = false> typename ck::enable_if<NDim == 2, bool>::type = false>
__host__ __device__ TransformConvFwdToGemm(const ConvDimsType& a_g_n_c_wis_lengths, __host__ __device__ TransformConvFwdToGemm(const ConvDimsType& a_g_n_c_wis_lengths,
const ConvDimsType& a_g_n_c_wis_strides, const ConvDimsType& a_g_n_c_wis_strides,
const ConvDimsType& b_g_k_c_xs_lengths, const ConvDimsType& b_g_k_c_xs_lengths,
...@@ -272,11 +275,15 @@ struct TransformConvFwdToGemm ...@@ -272,11 +275,15 @@ struct TransformConvFwdToGemm
InRightPadW_{input_right_pads[I1]}, InRightPadW_{input_right_pads[I1]},
ZYX_{Y_ * X_} ZYX_{Y_ * X_}
{ {
#ifdef CK_CODE_GEN_RTC
static_assert(is_same_v<ConvSpatialDimsType, ck::Array<IndexType, NDimSpatial>>);
static_assert(is_same_v<ConvDimsType, ck::Array<IndexType, NDimSpatial + I3>>);
#else
static_assert(is_same_v<ConvSpatialDimsType, std::array<IndexType, NDimSpatial>> || static_assert(is_same_v<ConvSpatialDimsType, std::array<IndexType, NDimSpatial>> ||
is_same_v<ConvSpatialDimsType, ck::Array<IndexType, NDimSpatial>>); is_same_v<ConvSpatialDimsType, ck::Array<IndexType, NDimSpatial>>);
static_assert(is_same_v<ConvDimsType, std::array<IndexType, NDimSpatial + I3>> || static_assert(is_same_v<ConvDimsType, std::array<IndexType, NDimSpatial + I3>> ||
is_same_v<ConvDimsType, ck::Array<IndexType, NDimSpatial + I3>>); is_same_v<ConvDimsType, ck::Array<IndexType, NDimSpatial + I3>>);
#endif
if constexpr(SplitN) if constexpr(SplitN)
{ {
N_ = GetSplitedNSize( N_ = GetSplitedNSize(
...@@ -290,8 +297,8 @@ struct TransformConvFwdToGemm ...@@ -290,8 +297,8 @@ struct TransformConvFwdToGemm
template <typename ConvDimsType, template <typename ConvDimsType,
typename ConvSpatialDimsType, typename ConvSpatialDimsType,
index_t NDim = NDimSpatial, index_t NDim = NDimSpatial,
typename std::enable_if<NDim == 3, bool>::type = false> typename ck::enable_if<NDim == 3, bool>::type = false>
__host__ __device__ TransformConvFwdToGemm(const ConvDimsType& a_g_n_c_wis_lengths, __host__ __device__ TransformConvFwdToGemm(const ConvDimsType& a_g_n_c_wis_lengths,
const ConvDimsType& a_g_n_c_wis_strides, const ConvDimsType& a_g_n_c_wis_strides,
const ConvDimsType& b_g_k_c_xs_lengths, const ConvDimsType& b_g_k_c_xs_lengths,
...@@ -343,11 +350,15 @@ struct TransformConvFwdToGemm ...@@ -343,11 +350,15 @@ struct TransformConvFwdToGemm
InRightPadW_{input_right_pads[I2]}, InRightPadW_{input_right_pads[I2]},
ZYX_{Z_ * Y_ * X_} ZYX_{Z_ * Y_ * X_}
{ {
#ifdef CK_CODE_GEN_RTC
static_assert(is_same_v<ConvSpatialDimsType, ck::Array<IndexType, NDimSpatial>>);
static_assert(is_same_v<ConvDimsType, ck::Array<IndexType, NDimSpatial + I3>>);
#else
static_assert(is_same_v<ConvSpatialDimsType, std::array<IndexType, NDimSpatial>> || static_assert(is_same_v<ConvSpatialDimsType, std::array<IndexType, NDimSpatial>> ||
is_same_v<ConvSpatialDimsType, ck::Array<IndexType, NDimSpatial>>); is_same_v<ConvSpatialDimsType, ck::Array<IndexType, NDimSpatial>>);
static_assert(is_same_v<ConvDimsType, std::array<IndexType, NDimSpatial + I3>> || static_assert(is_same_v<ConvDimsType, std::array<IndexType, NDimSpatial + I3>> ||
is_same_v<ConvDimsType, ck::Array<IndexType, NDimSpatial + I3>>); is_same_v<ConvDimsType, ck::Array<IndexType, NDimSpatial + I3>>);
#endif
if constexpr(SplitN) if constexpr(SplitN)
{ {
N_ = GetSplitedNSize( N_ = GetSplitedNSize(
...@@ -478,11 +489,11 @@ struct TransformConvFwdToGemm ...@@ -478,11 +489,11 @@ struct TransformConvFwdToGemm
// TODO: implement ck::tensor_layout::convolution that describe packed/strided dimemsion as // TODO: implement ck::tensor_layout::convolution that describe packed/strided dimemsion as
// properties // properties
template <typename ALayout, template <typename ALayout,
typename std::enable_if<NDimSpatial == 1 && typename ck::enable_if<NDimSpatial == 1 &&
(is_same_v<ALayout, tensor_layout::convolution::G_NW_C> || (is_same_v<ALayout, tensor_layout::convolution::G_NW_C> ||
is_same_v<ALayout, tensor_layout::convolution::NWGC> || is_same_v<ALayout, tensor_layout::convolution::NWGC> ||
is_same_v<ALayout, tensor_layout::convolution::GNWC>), is_same_v<ALayout, tensor_layout::convolution::GNWC>),
bool>::type = false> bool>::type = false>
__host__ __device__ auto MakeADescriptor_M_K() const __host__ __device__ auto MakeADescriptor_M_K() const
{ {
if constexpr(ConvForwardSpecialization == if constexpr(ConvForwardSpecialization ==
...@@ -691,11 +702,11 @@ struct TransformConvFwdToGemm ...@@ -691,11 +702,11 @@ struct TransformConvFwdToGemm
} }
template <typename ALayout, template <typename ALayout,
typename std::enable_if< typename ck::enable_if<NDimSpatial == 2 &&
NDimSpatial == 2 && (is_same_v<ALayout, tensor_layout::convolution::G_NHW_C> || (is_same_v<ALayout, tensor_layout::convolution::G_NHW_C> ||
is_same_v<ALayout, tensor_layout::convolution::NHWGC> || is_same_v<ALayout, tensor_layout::convolution::NHWGC> ||
is_same_v<ALayout, tensor_layout::convolution::GNHWC>), is_same_v<ALayout, tensor_layout::convolution::GNHWC>),
bool>::type = false> bool>::type = false>
__host__ __device__ auto MakeADescriptor_M_K() const __host__ __device__ auto MakeADescriptor_M_K() const
{ {
...@@ -932,7 +943,7 @@ struct TransformConvFwdToGemm ...@@ -932,7 +943,7 @@ struct TransformConvFwdToGemm
} }
template <typename ALayout, template <typename ALayout,
typename std::enable_if< typename ck::enable_if<
NDimSpatial == 3 && (is_same_v<ALayout, tensor_layout::convolution::G_NDHW_C> || NDimSpatial == 3 && (is_same_v<ALayout, tensor_layout::convolution::G_NDHW_C> ||
is_same_v<ALayout, tensor_layout::convolution::NDHWGC> || is_same_v<ALayout, tensor_layout::convolution::NDHWGC> ||
is_same_v<ALayout, tensor_layout::convolution::GNDHWC>), is_same_v<ALayout, tensor_layout::convolution::GNDHWC>),
...@@ -1242,19 +1253,19 @@ struct TransformConvFwdToGemm ...@@ -1242,19 +1253,19 @@ struct TransformConvFwdToGemm
} }
template <typename BLayout, template <typename BLayout,
typename std::enable_if<is_same_v<BLayout, tensor_layout::convolution::GKXC> || typename ck::enable_if<is_same_v<BLayout, tensor_layout::convolution::GKXC> ||
is_same_v<BLayout, tensor_layout::convolution::GKYXC> || is_same_v<BLayout, tensor_layout::convolution::GKYXC> ||
is_same_v<BLayout, tensor_layout::convolution::GKZYXC>, is_same_v<BLayout, tensor_layout::convolution::GKZYXC>,
bool>::type = false> bool>::type = false>
__host__ __device__ auto MakeBDescriptor_N_K() const __host__ __device__ auto MakeBDescriptor_N_K() const
{ {
if constexpr(ConvForwardSpecialization == if constexpr(ConvForwardSpecialization ==
device::ConvolutionForwardSpecialization::Filter3x3) device::ConvolutionForwardSpecialization::Filter3x3)
{ {
using FilterSizeNumType = using FilterSizeNumType =
std::conditional_t<NDimSpatial == 1, ck::conditional_t<NDimSpatial == 1,
Number<3>, Number<3>,
std::conditional_t<NDimSpatial == 2, Number<9>, Number<27>>>; ck::conditional_t<NDimSpatial == 2, Number<9>, Number<27>>>;
if constexpr(NumGroupsToMerge == 1) if constexpr(NumGroupsToMerge == 1)
{ {
...@@ -1297,13 +1308,13 @@ struct TransformConvFwdToGemm ...@@ -1297,13 +1308,13 @@ struct TransformConvFwdToGemm
template < template <
typename BLayout, typename BLayout,
typename std::enable_if<is_same_v<BLayout, tensor_layout::convolution::G_K_X_C> || typename ck::enable_if<is_same_v<BLayout, tensor_layout::convolution::G_K_X_C> ||
is_same_v<BLayout, tensor_layout::convolution::G_K_YX_C> || is_same_v<BLayout, tensor_layout::convolution::G_K_YX_C> ||
is_same_v<BLayout, tensor_layout::convolution::G_K_ZYX_C> || is_same_v<BLayout, tensor_layout::convolution::G_K_ZYX_C> ||
is_same_v<BLayout, tensor_layout::convolution::KXGC> || is_same_v<BLayout, tensor_layout::convolution::KXGC> ||
is_same_v<BLayout, tensor_layout::convolution::KYXGC> || is_same_v<BLayout, tensor_layout::convolution::KYXGC> ||
is_same_v<BLayout, tensor_layout::convolution::KZYXGC>, is_same_v<BLayout, tensor_layout::convolution::KZYXGC>,
bool>::type = false> bool>::type = false>
__host__ __device__ auto MakeBDescriptor_N_K() const __host__ __device__ auto MakeBDescriptor_N_K() const
{ {
const auto wei_k_yx_c_desc = make_naive_tensor_descriptor( const auto wei_k_yx_c_desc = make_naive_tensor_descriptor(
...@@ -1318,36 +1329,36 @@ struct TransformConvFwdToGemm ...@@ -1318,36 +1329,36 @@ struct TransformConvFwdToGemm
return wei_gemmn_gemmk_desc; return wei_gemmn_gemmk_desc;
} }
template <typename CLayout, template <
index_t NDimSp = NDimSpatial, typename CLayout,
index_t NDimSp = NDimSpatial,
typename std::enable_if<NDimSp == 1 && typename ck::enable_if<NDimSp == 1 && (is_same_v<CLayout, tensor_layout::convolution::G_K>),
(is_same_v<CLayout, tensor_layout::convolution::G_K>), bool>::type = false>
bool>::type = false>
__host__ __device__ auto MakeCDescriptor_M_N() const __host__ __device__ auto MakeCDescriptor_M_N() const
{ {
return make_naive_tensor_descriptor(make_tuple(N_ * Wo_, K_), return make_naive_tensor_descriptor(make_tuple(N_ * Wo_, K_),
make_tuple(I0, KStrideTensorC_)); make_tuple(I0, KStrideTensorC_));
} }
template <typename CLayout, template <
index_t NDimSp = NDimSpatial, typename CLayout,
index_t NDimSp = NDimSpatial,
typename std::enable_if<NDimSp == 2 && typename ck::enable_if<NDimSp == 2 && (is_same_v<CLayout, tensor_layout::convolution::G_K>),
(is_same_v<CLayout, tensor_layout::convolution::G_K>), bool>::type = false>
bool>::type = false>
__host__ __device__ auto MakeCDescriptor_M_N() const __host__ __device__ auto MakeCDescriptor_M_N() const
{ {
return make_naive_tensor_descriptor(make_tuple(N_ * Ho_ * Wo_, K_), return make_naive_tensor_descriptor(make_tuple(N_ * Ho_ * Wo_, K_),
make_tuple(I0, KStrideTensorC_)); make_tuple(I0, KStrideTensorC_));
} }
template <typename CLayout, template <
index_t NDimSp = NDimSpatial, typename CLayout,
index_t NDimSp = NDimSpatial,
typename std::enable_if<NDimSp == 3 && typename ck::enable_if<NDimSp == 3 && (is_same_v<CLayout, tensor_layout::convolution::G_K>),
(is_same_v<CLayout, tensor_layout::convolution::G_K>), bool>::type = false>
bool>::type = false>
__host__ __device__ auto MakeCDescriptor_M_N() const __host__ __device__ auto MakeCDescriptor_M_N() const
{ {
return make_naive_tensor_descriptor(make_tuple(N_ * Do_ * Ho_ * Wo_, K_), return make_naive_tensor_descriptor(make_tuple(N_ * Do_ * Ho_ * Wo_, K_),
...@@ -1355,12 +1366,12 @@ struct TransformConvFwdToGemm ...@@ -1355,12 +1366,12 @@ struct TransformConvFwdToGemm
} }
template <typename CLayout, template <typename CLayout,
index_t NDimSp = NDimSpatial, index_t NDimSp = NDimSpatial,
typename std::enable_if<NDimSp == 1 && typename ck::enable_if<NDimSp == 1 &&
(is_same_v<CLayout, tensor_layout::convolution::G_NW_K> || (is_same_v<CLayout, tensor_layout::convolution::G_NW_K> ||
is_same_v<CLayout, tensor_layout::convolution::NWGK> || is_same_v<CLayout, tensor_layout::convolution::NWGK> ||
is_same_v<CLayout, tensor_layout::convolution::GNWK>), is_same_v<CLayout, tensor_layout::convolution::GNWK>),
bool>::type = false> bool>::type = false>
__host__ __device__ auto MakeCDescriptor_M_N() const __host__ __device__ auto MakeCDescriptor_M_N() const
{ {
const IndexType NDoHoWo = N_ * Wo_; const IndexType NDoHoWo = N_ * Wo_;
...@@ -1410,11 +1421,11 @@ struct TransformConvFwdToGemm ...@@ -1410,11 +1421,11 @@ struct TransformConvFwdToGemm
template <typename CLayout, template <typename CLayout,
index_t NDimSp = NDimSpatial, index_t NDimSp = NDimSpatial,
typename std::enable_if< typename ck::enable_if<NDimSp == 2 &&
NDimSp == 2 && (is_same_v<CLayout, tensor_layout::convolution::G_NHW_K> || (is_same_v<CLayout, tensor_layout::convolution::G_NHW_K> ||
is_same_v<CLayout, tensor_layout::convolution::NHWGK> || is_same_v<CLayout, tensor_layout::convolution::NHWGK> ||
is_same_v<CLayout, tensor_layout::convolution::GNHWK>), is_same_v<CLayout, tensor_layout::convolution::GNHWK>),
bool>::type = false> bool>::type = false>
__host__ __device__ auto MakeCDescriptor_M_N() const __host__ __device__ auto MakeCDescriptor_M_N() const
{ {
const IndexType NDoHoWo = N_ * Ho_ * Wo_; const IndexType NDoHoWo = N_ * Ho_ * Wo_;
...@@ -1467,7 +1478,7 @@ struct TransformConvFwdToGemm ...@@ -1467,7 +1478,7 @@ struct TransformConvFwdToGemm
template <typename CLayout, template <typename CLayout,
index_t NDimSp = NDimSpatial, index_t NDimSp = NDimSpatial,
typename std::enable_if< typename ck::enable_if<
NDimSp == 3 && (is_same_v<CLayout, tensor_layout::convolution::G_NDHW_K> || NDimSp == 3 && (is_same_v<CLayout, tensor_layout::convolution::G_NDHW_K> ||
is_same_v<CLayout, tensor_layout::convolution::NDHWGK> || is_same_v<CLayout, tensor_layout::convolution::NDHWGK> ||
is_same_v<CLayout, tensor_layout::convolution::GNDHWK>), is_same_v<CLayout, tensor_layout::convolution::GNDHWK>),
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include "data_type.hpp" #include "data_type.hpp"
...@@ -1021,15 +1021,24 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr, ...@@ -1021,15 +1021,24 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr,
constexpr auto bytes_per_thread = sizeof(T) * NumElemsPerThread; constexpr auto bytes_per_thread = sizeof(T) * NumElemsPerThread;
static_assert(bytes_per_thread == dword_bytes); static_assert(bytes_per_thread == dword_bytes);
#ifndef CK_CODE_GEN_RTC
const uint32_t* global_ptr = const uint32_t* global_ptr =
reinterpret_cast<uint32_t*>(reinterpret_cast<uintptr_t>(global_base_ptr)); reinterpret_cast<uint32_t*>(reinterpret_cast<uintptr_t>(global_base_ptr));
#else
const uint32_t* global_ptr =
reinterpret_cast<uint32_t*>(reinterpret_cast<size_t>(global_base_ptr));
#endif
const int32x4_t src_resource = make_wave_buffer_resource(global_ptr, src_element_space_size); const int32x4_t src_resource = make_wave_buffer_resource(global_ptr, src_element_space_size);
const index_t global_offset_bytes = is_valid ? global_offset * sizeof(T) : 0x80000000; const index_t global_offset_bytes = is_valid ? global_offset * sizeof(T) : 0x80000000;
#if CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM #if CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM
T* lds_ptr = lds_base_ptr + lds_offset; T* lds_ptr = lds_base_ptr + lds_offset;
#ifndef CK_CODE_GEN_RTC
auto const lds_ptr_sgpr = auto const lds_ptr_sgpr =
__builtin_amdgcn_readfirstlane((reinterpret_cast<uintptr_t>(lds_ptr))); __builtin_amdgcn_readfirstlane((reinterpret_cast<uintptr_t>(lds_ptr)));
#else
auto const lds_ptr_sgpr = __builtin_amdgcn_readfirstlane((reinterpret_cast<size_t>(lds_ptr)));
#endif
asm volatile("s_mov_b32 m0, %0; \n\t" asm volatile("s_mov_b32 m0, %0; \n\t"
"buffer_load_dword %1, %2, 0 offen lds;\n\t" ::"s"(lds_ptr_sgpr), "buffer_load_dword %1, %2, 0 offen lds;\n\t" ::"s"(lds_ptr_sgpr),
"v"(global_offset_bytes), "v"(global_offset_bytes),
...@@ -1038,8 +1047,13 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr, ...@@ -1038,8 +1047,13 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr,
#else #else
// LDS pointer must be attributed with the LDS address space. // LDS pointer must be attributed with the LDS address space.
__attribute__((address_space(3))) uint32_t* lds_ptr = __attribute__((address_space(3))) uint32_t* lds_ptr =
#ifndef CK_CODE_GEN_RTC
reinterpret_cast<__attribute__((address_space(3))) uint32_t*>( reinterpret_cast<__attribute__((address_space(3))) uint32_t*>(
reinterpret_cast<uintptr_t>(lds_base_ptr + lds_offset)); reinterpret_cast<uintptr_t>(lds_base_ptr + lds_offset));
#else
reinterpret_cast<__attribute__((address_space(3))) uint32_t*>(
reinterpret_cast<size_t>(lds_base_ptr + lds_offset));
#endif
llvm_amdgcn_raw_buffer_load_lds( llvm_amdgcn_raw_buffer_load_lds(
src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0); src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0);
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include "ck/ck.hpp"
#include "ck/utility/enable_if.hpp"
#include "ck/utility/random_gen.hpp" #include "ck/utility/random_gen.hpp"
#include "ck/utility/type.hpp" #include "ck/utility/type.hpp"
...@@ -424,9 +426,9 @@ __host__ __device__ inline constexpr bool fp8_is_nan(bf8_fnuz_t a) ...@@ -424,9 +426,9 @@ __host__ __device__ inline constexpr bool fp8_is_nan(bf8_fnuz_t a)
} }
template <typename T, template <typename T,
std::enable_if_t<std::is_same_v<T, bf8_ocp_t> || std::is_same_v<T, f8_ocp_t> || ck::enable_if_t<is_same_v<T, bf8_ocp_t> || is_same_v<T, f8_ocp_t> ||
std::is_same_v<T, bf8_fnuz_t> || std::is_same_v<T, f8_fnuz_t>, is_same_v<T, bf8_fnuz_t> || is_same_v<T, f8_fnuz_t>,
bool> = true> bool> = true>
__host__ __device__ static inline constexpr bool fp8_is_inf(T) __host__ __device__ static inline constexpr bool fp8_is_inf(T)
{ {
return false; return false;
...@@ -823,7 +825,11 @@ __host__ __device__ static inline fp8_storage_t cvt_float_to_fp8(const float f) ...@@ -823,7 +825,11 @@ __host__ __device__ static inline fp8_storage_t cvt_float_to_fp8(const float f)
if constexpr(stochastic_rounding) if constexpr(stochastic_rounding)
{ {
constexpr int seed = 1254739; constexpr int seed = 1254739;
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f); #ifndef CK_CODE_GEN_RTC
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
#else
rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&f), f);
#endif
} }
return cast_to_f8_from_f32<interp, sat == ck_saturation_t::CK_SATFINITE, stochastic_rounding>( return cast_to_f8_from_f32<interp, sat == ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
f, rng); f, rng);
...@@ -839,7 +845,11 @@ __host__ static inline fp8_storage_t cvt_float_to_fp8(const float f) ...@@ -839,7 +845,11 @@ __host__ static inline fp8_storage_t cvt_float_to_fp8(const float f)
if constexpr(stochastic_rounding) if constexpr(stochastic_rounding)
{ {
constexpr int seed = 1254739; constexpr int seed = 1254739;
#ifndef CK_CODE_GEN_RTC
rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f); rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&f), f);
#else
rng = prand_generator<float, seed>(reinterpret_cast<size_t>(&f), f);
#endif
} }
if constexpr(interp == ck_fp8_interpretation_t::CK_E4M3_FNUZ) if constexpr(interp == ck_fp8_interpretation_t::CK_E4M3_FNUZ)
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -7,10 +7,12 @@ ...@@ -7,10 +7,12 @@
#include "ck/utility/functional2.hpp" #include "ck/utility/functional2.hpp"
#include "ck/utility/math.hpp" #include "ck/utility/math.hpp"
#ifndef CK_CODE_GEN_RTC
#include <array> #include <array>
#include <cstddef> #include <cstddef>
#include <cstdint> #include <cstdint>
#include <type_traits> #include <type_traits>
#endif
namespace ck { namespace ck {
namespace detail { namespace detail {
...@@ -37,7 +39,7 @@ struct get_carrier<3> ...@@ -37,7 +39,7 @@ struct get_carrier<3>
{ {
using value_type = uint32_t; using value_type = uint32_t;
std::array<std::byte, 3> bytes; Array<ck::byte, 3> bytes;
static_assert(sizeof(bytes) <= sizeof(value_type)); static_assert(sizeof(bytes) <= sizeof(value_type));
// replacement of host std::copy_n() // replacement of host std::copy_n()
...@@ -61,22 +63,22 @@ struct get_carrier<3> ...@@ -61,22 +63,22 @@ struct get_carrier<3>
// method to trigger template substitution failure // method to trigger template substitution failure
__device__ carrier(const carrier& other) noexcept __device__ carrier(const carrier& other) noexcept
{ {
copy_n(other.bytes.begin(), bytes.size(), bytes.begin()); copy_n(other.bytes.begin(), bytes.Size(), bytes.begin());
} }
public: public:
__device__ carrier& operator=(value_type value) noexcept __device__ carrier& operator=(value_type value) noexcept
{ {
copy_n(reinterpret_cast<const std::byte*>(&value), bytes.size(), bytes.begin()); copy_n(reinterpret_cast<const ck::byte*>(&value), bytes.Size(), bytes.begin());
return *this; return *this;
} }
__device__ operator value_type() const noexcept __device__ operator value_type() const noexcept
{ {
std::byte result[sizeof(value_type)]; ck::byte result[sizeof(value_type)];
copy_n(bytes.begin(), bytes.size(), result); copy_n(bytes.begin(), bytes.Size(), result);
return *reinterpret_cast<const value_type*>(result); return *reinterpret_cast<const value_type*>(result);
} }
...@@ -109,8 +111,8 @@ __device__ inline int64_t amd_wave_read_first_lane(int64_t value) ...@@ -109,8 +111,8 @@ __device__ inline int64_t amd_wave_read_first_lane(int64_t value)
{ {
constexpr unsigned object_size = sizeof(int64_t); constexpr unsigned object_size = sizeof(int64_t);
constexpr unsigned second_part_offset = object_size / 2; constexpr unsigned second_part_offset = object_size / 2;
auto* const from_obj = reinterpret_cast<const std::byte*>(&value); auto* const from_obj = reinterpret_cast<const ck::byte*>(&value);
alignas(int64_t) std::byte to_obj[object_size]; alignas(int64_t) ck::byte to_obj[object_size];
using Sgpr = uint32_t; using Sgpr = uint32_t;
...@@ -122,17 +124,16 @@ __device__ inline int64_t amd_wave_read_first_lane(int64_t value) ...@@ -122,17 +124,16 @@ __device__ inline int64_t amd_wave_read_first_lane(int64_t value)
return *reinterpret_cast<int64_t*>(to_obj); return *reinterpret_cast<int64_t*>(to_obj);
} }
template < template <typename Object,
typename Object, typename = ck::enable_if_t<ck::is_class_v<Object> && ck::is_trivially_copyable_v<Object>>>
typename = std::enable_if_t<std::is_class_v<Object> && std::is_trivially_copyable_v<Object>>>
__device__ auto amd_wave_read_first_lane(const Object& obj) __device__ auto amd_wave_read_first_lane(const Object& obj)
{ {
using Size = unsigned; using Size = unsigned;
constexpr Size SgprSize = 4; constexpr Size SgprSize = 4;
constexpr Size ObjectSize = sizeof(Object); constexpr Size ObjectSize = sizeof(Object);
auto* const from_obj = reinterpret_cast<const std::byte*>(&obj); auto* const from_obj = reinterpret_cast<const ck::byte*>(&obj);
alignas(Object) std::byte to_obj[ObjectSize]; alignas(Object) ck::byte to_obj[ObjectSize];
constexpr Size RemainedSize = ObjectSize % SgprSize; constexpr Size RemainedSize = ObjectSize % SgprSize;
constexpr Size CompleteSgprCopyBoundary = ObjectSize - RemainedSize; constexpr Size CompleteSgprCopyBoundary = ObjectSize - RemainedSize;
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_ARRAY_HPP #ifndef CK_ARRAY_HPP
#define CK_ARRAY_HPP #define CK_ARRAY_HPP
...@@ -38,6 +38,8 @@ struct Array ...@@ -38,6 +38,8 @@ struct Array
} }
__host__ __device__ constexpr const TData* begin() const { return &mData[0]; } __host__ __device__ constexpr const TData* begin() const { return &mData[0]; }
__host__ __device__ constexpr const TData* end() const { return &mData[NSize]; } __host__ __device__ constexpr const TData* end() const { return &mData[NSize]; }
__host__ __device__ constexpr TData* begin() { return &mData[0]; }
__host__ __device__ constexpr TData* end() { return &mData[NSize]; }
}; };
// empty Array // empty Array
...@@ -54,7 +56,7 @@ template <typename X, typename... Xs> ...@@ -54,7 +56,7 @@ template <typename X, typename... Xs>
__host__ __device__ constexpr auto make_array(X&& x, Xs&&... xs) __host__ __device__ constexpr auto make_array(X&& x, Xs&&... xs)
{ {
using data_type = remove_cvref_t<X>; using data_type = remove_cvref_t<X>;
return Array<data_type, sizeof...(Xs) + 1>{std::forward<X>(x), std::forward<Xs>(xs)...}; return Array<data_type, sizeof...(Xs) + 1>{ck::forward<X>(x), ck::forward<Xs>(xs)...};
} }
// make empty array // make empty array
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_CONTAINER_HELPER_HPP #ifndef CK_CONTAINER_HELPER_HPP
#define CK_CONTAINER_HELPER_HPP #define CK_CONTAINER_HELPER_HPP
...@@ -326,14 +326,14 @@ template <typename T, index_t NX, index_t NY> ...@@ -326,14 +326,14 @@ template <typename T, index_t NX, index_t NY>
__host__ __device__ constexpr auto container_concat(const Array<T, NX>& ax, const Array<T, NY>& ay) __host__ __device__ constexpr auto container_concat(const Array<T, NX>& ax, const Array<T, NY>& ay)
{ {
return unpack2( return unpack2(
[&](auto&&... zs) { return make_array(std::forward<decltype(zs)>(zs)...); }, ax, ay); [&](auto&&... zs) { return make_array(ck::forward<decltype(zs)>(zs)...); }, ax, ay);
} }
template <typename... X, typename... Y> template <typename... X, typename... Y>
__host__ __device__ constexpr auto container_concat(const Tuple<X...>& tx, const Tuple<Y...>& ty) __host__ __device__ constexpr auto container_concat(const Tuple<X...>& tx, const Tuple<Y...>& ty)
{ {
return unpack2( return unpack2(
[&](auto&&... zs) { return make_tuple(std::forward<decltype(zs)>(zs)...); }, tx, ty); [&](auto&&... zs) { return make_tuple(ck::forward<decltype(zs)>(zs)...); }, tx, ty);
} }
template <typename Container> template <typename Container>
......
...@@ -5,9 +5,21 @@ ...@@ -5,9 +5,21 @@
#include "ck/utility/amd_ck_fp8.hpp" #include "ck/utility/amd_ck_fp8.hpp"
#include "ck/utility/statically_indexed_array.hpp" #include "ck/utility/statically_indexed_array.hpp"
#ifdef CK_CODE_GEN_RTC
// hipRTC build: <cstdint>/<cmath> are unavailable, so supply the fixed-width
// aliases at global scope. NOTE(review): only the 8/16-bit aliases and
// float_t are defined here; 32/64-bit exact-width types are presumably
// provided by the RTC compiler itself -- confirm under hipRTC.
using int8_t = signed char;
using uint8_t = unsigned char;
using int16_t = signed short;
using uint16_t = unsigned short;
using float_t = float;
#endif
namespace ck { namespace ck {
#ifdef CK_CODE_GEN_RTC
// hipRTC build: <cstddef> is unavailable, so fall back to unsigned char.
// NOTE(review): unsigned char supports arithmetic while std::byte is a
// scoped enum that does not -- code using ck::byte must restrict itself to
// operations valid for both (raw-byte copies, bitwise ops). Confirm callers.
using byte = unsigned char;
#else
// Host/offline build: reuse the standard byte type.
using std::byte;
#endif
using bhalf_t = ushort; using bhalf_t = ushort;
using half_t = _Float16; using half_t = _Float16;
using int4_t = _BitInt(4); using int4_t = _BitInt(4);
...@@ -217,7 +229,7 @@ struct scalar_type<bool> ...@@ -217,7 +229,7 @@ struct scalar_type<bool>
}; };
template <typename T> template <typename T>
struct vector_type<T, 1, typename std::enable_if_t<is_native_type<T>()>> struct vector_type<T, 1, typename ck::enable_if_t<is_native_type<T>()>>
{ {
using d1_t = T; using d1_t = T;
using type = d1_t; using type = d1_t;
...@@ -253,7 +265,7 @@ struct vector_type<T, 1, typename std::enable_if_t<is_native_type<T>()>> ...@@ -253,7 +265,7 @@ struct vector_type<T, 1, typename std::enable_if_t<is_native_type<T>()>>
__device__ int static err = 0; __device__ int static err = 0;
template <typename T> template <typename T>
struct vector_type<T, 2, typename std::enable_if_t<is_native_type<T>()>> struct vector_type<T, 2, typename ck::enable_if_t<is_native_type<T>()>>
{ {
using d1_t = T; using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2))); typedef T d2_t __attribute__((ext_vector_type(2)));
...@@ -313,7 +325,7 @@ struct vector_type<T, 2, typename std::enable_if_t<is_native_type<T>()>> ...@@ -313,7 +325,7 @@ struct vector_type<T, 2, typename std::enable_if_t<is_native_type<T>()>>
}; };
template <typename T> template <typename T>
struct vector_type<T, 3, typename std::enable_if_t<is_native_type<T>()>> struct vector_type<T, 3, typename ck::enable_if_t<is_native_type<T>()>>
{ {
using d1_t = T; using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2))); typedef T d2_t __attribute__((ext_vector_type(2)));
...@@ -383,7 +395,7 @@ struct vector_type<T, 3, typename std::enable_if_t<is_native_type<T>()>> ...@@ -383,7 +395,7 @@ struct vector_type<T, 3, typename std::enable_if_t<is_native_type<T>()>>
}; };
template <typename T> template <typename T>
struct vector_type<T, 4, typename std::enable_if_t<is_native_type<T>()>> struct vector_type<T, 4, typename ck::enable_if_t<is_native_type<T>()>>
{ {
using d1_t = T; using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2))); typedef T d2_t __attribute__((ext_vector_type(2)));
...@@ -453,7 +465,7 @@ struct vector_type<T, 4, typename std::enable_if_t<is_native_type<T>()>> ...@@ -453,7 +465,7 @@ struct vector_type<T, 4, typename std::enable_if_t<is_native_type<T>()>>
}; };
template <typename T> template <typename T>
struct vector_type<T, 5, typename std::enable_if_t<is_native_type<T>()>> struct vector_type<T, 5, typename ck::enable_if_t<is_native_type<T>()>>
{ {
using d1_t = T; using d1_t = T;
typedef T d4_t __attribute__((ext_vector_type(4))); typedef T d4_t __attribute__((ext_vector_type(4)));
...@@ -523,7 +535,7 @@ struct vector_type<T, 5, typename std::enable_if_t<is_native_type<T>()>> ...@@ -523,7 +535,7 @@ struct vector_type<T, 5, typename std::enable_if_t<is_native_type<T>()>>
}; };
template <typename T> template <typename T>
struct vector_type<T, 7, typename std::enable_if_t<is_native_type<T>()>> struct vector_type<T, 7, typename ck::enable_if_t<is_native_type<T>()>>
{ {
using d1_t = T; using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2))); typedef T d2_t __attribute__((ext_vector_type(2)));
...@@ -605,7 +617,7 @@ struct vector_type<T, 7, typename std::enable_if_t<is_native_type<T>()>> ...@@ -605,7 +617,7 @@ struct vector_type<T, 7, typename std::enable_if_t<is_native_type<T>()>>
}; };
template <typename T> template <typename T>
struct vector_type<T, 8, typename std::enable_if_t<is_native_type<T>()>> struct vector_type<T, 8, typename ck::enable_if_t<is_native_type<T>()>>
{ {
using d1_t = T; using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2))); typedef T d2_t __attribute__((ext_vector_type(2)));
...@@ -687,7 +699,7 @@ struct vector_type<T, 8, typename std::enable_if_t<is_native_type<T>()>> ...@@ -687,7 +699,7 @@ struct vector_type<T, 8, typename std::enable_if_t<is_native_type<T>()>>
}; };
template <typename T> template <typename T>
struct vector_type<T, 13, typename std::enable_if_t<is_native_type<T>()>> struct vector_type<T, 13, typename ck::enable_if_t<is_native_type<T>()>>
{ {
using d1_t = T; using d1_t = T;
typedef T d4_t __attribute__((ext_vector_type(4))); typedef T d4_t __attribute__((ext_vector_type(4)));
...@@ -769,7 +781,7 @@ struct vector_type<T, 13, typename std::enable_if_t<is_native_type<T>()>> ...@@ -769,7 +781,7 @@ struct vector_type<T, 13, typename std::enable_if_t<is_native_type<T>()>>
}; };
template <typename T> template <typename T>
struct vector_type<T, 16, typename std::enable_if_t<is_native_type<T>()>> struct vector_type<T, 16, typename ck::enable_if_t<is_native_type<T>()>>
{ {
using d1_t = T; using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2))); typedef T d2_t __attribute__((ext_vector_type(2)));
...@@ -863,7 +875,7 @@ struct vector_type<T, 16, typename std::enable_if_t<is_native_type<T>()>> ...@@ -863,7 +875,7 @@ struct vector_type<T, 16, typename std::enable_if_t<is_native_type<T>()>>
}; };
template <typename T> template <typename T>
struct vector_type<T, 32, typename std::enable_if_t<is_native_type<T>()>> struct vector_type<T, 32, typename ck::enable_if_t<is_native_type<T>()>>
{ {
using d1_t = T; using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2))); typedef T d2_t __attribute__((ext_vector_type(2)));
...@@ -967,7 +979,7 @@ struct vector_type<T, 32, typename std::enable_if_t<is_native_type<T>()>> ...@@ -967,7 +979,7 @@ struct vector_type<T, 32, typename std::enable_if_t<is_native_type<T>()>>
}; };
template <typename T> template <typename T>
struct vector_type<T, 64, typename std::enable_if_t<is_native_type<T>()>> struct vector_type<T, 64, typename ck::enable_if_t<is_native_type<T>()>>
{ {
using d1_t = T; using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2))); typedef T d2_t __attribute__((ext_vector_type(2)));
...@@ -1083,7 +1095,7 @@ struct vector_type<T, 64, typename std::enable_if_t<is_native_type<T>()>> ...@@ -1083,7 +1095,7 @@ struct vector_type<T, 64, typename std::enable_if_t<is_native_type<T>()>>
}; };
template <typename T> template <typename T>
struct vector_type<T, 128, typename std::enable_if_t<is_native_type<T>()>> struct vector_type<T, 128, typename ck::enable_if_t<is_native_type<T>()>>
{ {
using d1_t = T; using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2))); typedef T d2_t __attribute__((ext_vector_type(2)));
...@@ -1209,7 +1221,7 @@ struct vector_type<T, 128, typename std::enable_if_t<is_native_type<T>()>> ...@@ -1209,7 +1221,7 @@ struct vector_type<T, 128, typename std::enable_if_t<is_native_type<T>()>>
}; };
template <typename T> template <typename T>
struct vector_type<T, 256, typename std::enable_if_t<is_native_type<T>()>> struct vector_type<T, 256, typename ck::enable_if_t<is_native_type<T>()>>
{ {
using d1_t = T; using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2))); typedef T d2_t __attribute__((ext_vector_type(2)));
...@@ -1374,7 +1386,7 @@ template <typename T, index_t N> ...@@ -1374,7 +1386,7 @@ template <typename T, index_t N>
struct non_native_vector_base< struct non_native_vector_base<
T, T,
N, N,
std::enable_if_t<sizeof(T) == 1 || sizeof(T) == 2 || sizeof(T) == 4 || sizeof(T) == 8>> ck::enable_if_t<sizeof(T) == 1 || sizeof(T) == 2 || sizeof(T) == 4 || sizeof(T) == 8>>
{ {
using data_t = typename nnvb_data_t_selector<T>::type; // select data_t based on the size of T using data_t = typename nnvb_data_t_selector<T>::type; // select data_t based on the size of T
static_assert(sizeof(T) == sizeof(data_t), "non_native_vector_base storage size mismatch"); static_assert(sizeof(T) == sizeof(data_t), "non_native_vector_base storage size mismatch");
...@@ -1499,7 +1511,7 @@ struct scalar_type<non_native_vector_base<pk_i4_t, N>> ...@@ -1499,7 +1511,7 @@ struct scalar_type<non_native_vector_base<pk_i4_t, N>>
// non-native vector_type implementation // non-native vector_type implementation
template <typename T> template <typename T>
struct vector_type<T, 1, typename std::enable_if_t<!is_native_type<T>()>> struct vector_type<T, 1, typename ck::enable_if_t<!is_native_type<T>()>>
{ {
using d1_t = T; using d1_t = T;
using d1_nnv_t = non_native_vector_base<T, 1>; using d1_nnv_t = non_native_vector_base<T, 1>;
...@@ -1550,7 +1562,7 @@ struct vector_type<T, 1, typename std::enable_if_t<!is_native_type<T>()>> ...@@ -1550,7 +1562,7 @@ struct vector_type<T, 1, typename std::enable_if_t<!is_native_type<T>()>>
}; };
template <typename T> template <typename T>
struct vector_type<T, 2, typename std::enable_if_t<!is_native_type<T>()>> struct vector_type<T, 2, typename ck::enable_if_t<!is_native_type<T>()>>
{ {
using d1_t = T; using d1_t = T;
using d1_nnv_t = non_native_vector_base<T, 1>; using d1_nnv_t = non_native_vector_base<T, 1>;
...@@ -1613,7 +1625,7 @@ struct vector_type<T, 2, typename std::enable_if_t<!is_native_type<T>()>> ...@@ -1613,7 +1625,7 @@ struct vector_type<T, 2, typename std::enable_if_t<!is_native_type<T>()>>
}; };
template <typename T> template <typename T>
struct vector_type<T, 4, typename std::enable_if_t<!is_native_type<T>()>> struct vector_type<T, 4, typename ck::enable_if_t<!is_native_type<T>()>>
{ {
using d1_t = T; using d1_t = T;
using d1_nnv_t = non_native_vector_base<T, 1>; using d1_nnv_t = non_native_vector_base<T, 1>;
...@@ -1686,7 +1698,7 @@ struct vector_type<T, 4, typename std::enable_if_t<!is_native_type<T>()>> ...@@ -1686,7 +1698,7 @@ struct vector_type<T, 4, typename std::enable_if_t<!is_native_type<T>()>>
}; };
template <typename T> template <typename T>
struct vector_type<T, 8, typename std::enable_if_t<!is_native_type<T>()>> struct vector_type<T, 8, typename ck::enable_if_t<!is_native_type<T>()>>
{ {
using d1_t = T; using d1_t = T;
using d1_nnv_t = non_native_vector_base<T, 1>; using d1_nnv_t = non_native_vector_base<T, 1>;
...@@ -1771,7 +1783,7 @@ struct vector_type<T, 8, typename std::enable_if_t<!is_native_type<T>()>> ...@@ -1771,7 +1783,7 @@ struct vector_type<T, 8, typename std::enable_if_t<!is_native_type<T>()>>
}; };
template <typename T> template <typename T>
struct vector_type<T, 16, typename std::enable_if_t<!is_native_type<T>()>> struct vector_type<T, 16, typename ck::enable_if_t<!is_native_type<T>()>>
{ {
using d1_t = T; using d1_t = T;
using d1_nnv_t = non_native_vector_base<T, 1>; using d1_nnv_t = non_native_vector_base<T, 1>;
...@@ -1866,7 +1878,7 @@ struct vector_type<T, 16, typename std::enable_if_t<!is_native_type<T>()>> ...@@ -1866,7 +1878,7 @@ struct vector_type<T, 16, typename std::enable_if_t<!is_native_type<T>()>>
}; };
template <typename T> template <typename T>
struct vector_type<T, 32, typename std::enable_if_t<!is_native_type<T>()>> struct vector_type<T, 32, typename ck::enable_if_t<!is_native_type<T>()>>
{ {
using d1_t = T; using d1_t = T;
using d2_t = non_native_vector_base<T, 2>; using d2_t = non_native_vector_base<T, 2>;
...@@ -1970,7 +1982,7 @@ struct vector_type<T, 32, typename std::enable_if_t<!is_native_type<T>()>> ...@@ -1970,7 +1982,7 @@ struct vector_type<T, 32, typename std::enable_if_t<!is_native_type<T>()>>
}; };
template <typename T> template <typename T>
struct vector_type<T, 64, typename std::enable_if_t<!is_native_type<T>()>> struct vector_type<T, 64, typename ck::enable_if_t<!is_native_type<T>()>>
{ {
using d1_t = T; using d1_t = T;
using d2_t = non_native_vector_base<T, 2>; using d2_t = non_native_vector_base<T, 2>;
...@@ -2210,20 +2222,230 @@ using pk_i4x2_t = typename vector_type<pk_i4_t, 2>::type; ...@@ -2210,20 +2222,230 @@ using pk_i4x2_t = typename vector_type<pk_i4_t, 2>::type;
using pk_i4x4_t = typename vector_type<pk_i4_t, 4>::type; using pk_i4x4_t = typename vector_type<pk_i4_t, 4>::type;
using pk_i4x8_t = typename vector_type<pk_i4_t, 8>::type; using pk_i4x8_t = typename vector_type<pk_i4_t, 8>::type;
#ifdef CK_CODE_GEN_RTC
// hipRTC build: <limits> is unavailable, so provide hand-written
// NumericLimits specializations mirroring std::numeric_limits.
template <typename T>
struct NumericLimits;
// 32-bit signed integer limits. `-2147483647 - 1` avoids spelling the
// literal -2147483648, whose positive part does not fit in int.
// Infinity()/QuietNaN() return 0: integers have no such representations.
template <>
struct NumericLimits<int32_t>
{
    __host__ __device__ static constexpr int32_t Lowest() noexcept { return -2147483647 - 1; }
    __host__ __device__ static constexpr int32_t Min() noexcept { return -2147483647 - 1; }
    __host__ __device__ static constexpr int32_t Max() noexcept { return 2147483647; }
    __host__ __device__ static constexpr int32_t Infinity() noexcept { return 0; }
    __host__ __device__ static constexpr int32_t QuietNaN() { return 0; }
};
// 16-bit signed integer limits: [-32768, 32767].
template <>
struct NumericLimits<int16_t>
{
    __host__ __device__ static constexpr int16_t Lowest() noexcept { return -32768; }
    __host__ __device__ static constexpr int16_t Min() noexcept { return -32768; }
    __host__ __device__ static constexpr int16_t Max() noexcept { return 32767; }
    __host__ __device__ static constexpr int16_t Infinity() noexcept { return 0; }
    __host__ __device__ static constexpr int16_t QuietNaN() { return 0; }
};
// 8-bit signed integer limits: [-128, 127].
template <>
struct NumericLimits<int8_t>
{
    __host__ __device__ static constexpr int8_t Lowest() noexcept { return -128; }
    __host__ __device__ static constexpr int8_t Min() noexcept { return -128; }
    __host__ __device__ static constexpr int8_t Max() noexcept { return 127; }
    __host__ __device__ static constexpr int8_t Infinity() noexcept { return 0; }
    __host__ __device__ static constexpr int8_t QuietNaN() { return 0; }
};
// 32-bit unsigned integer limits: [0, 4294967295].
template <>
struct NumericLimits<uint32_t>
{
    __host__ __device__ static constexpr uint32_t Lowest() noexcept { return 0; }
    __host__ __device__ static constexpr uint32_t Min() noexcept { return 0; }
    __host__ __device__ static constexpr uint32_t Max() noexcept { return 4294967295U; }
    __host__ __device__ static constexpr uint32_t Infinity() noexcept { return 0; }
    __host__ __device__ static constexpr uint32_t QuietNaN() { return 0; }
};
// 16-bit unsigned integer limits: [0, 65535].
template <>
struct NumericLimits<uint16_t>
{
    __host__ __device__ static constexpr uint16_t Lowest() noexcept { return 0; }
    __host__ __device__ static constexpr uint16_t Min() noexcept { return 0; }
    __host__ __device__ static constexpr uint16_t Max() noexcept { return 65535U; }
    __host__ __device__ static constexpr uint16_t Infinity() noexcept { return 0; }
    __host__ __device__ static constexpr uint16_t QuietNaN() { return 0; }
};
// IEEE-754 binary32 limits, stored as raw bit patterns and reinterpreted
// through bit_cast (mirrors the non-RTC std::numeric_limits<float> path).
template <>
struct NumericLimits<float>
{
    static constexpr unsigned int binary_min    = 0x00800000; // 2^-126, smallest positive normal
    static constexpr unsigned int binary_max    = 0x7F7FFFFF; // largest finite value
    static constexpr unsigned int binary_lowest = 0xFF7FFFFF; // most negative finite value
    // NOTE(review): sign bit set with payload 1; the conventional quiet NaN is
    // 0x7FC00000 -- confirm consumers only ever test "is NaN", not the bits.
    static constexpr unsigned int binary_qnan = 0xFFC00001;
    // Bug fix: was 0x7F8000000 (nine hex digits). That value does not fit in
    // unsigned int and wraps to 0xF8000000, which decodes to a large
    // *negative finite* float, so Infinity() did not return +inf.
    // The correct binary32 +infinity pattern is 0x7F800000.
    static constexpr unsigned int binary_inf = 0x7F800000;

    __host__ __device__ static constexpr float Min() { return bit_cast<float>(binary_min); }
    __host__ __device__ static constexpr float Max() { return bit_cast<float>(binary_max); }
    __host__ __device__ static constexpr float Lowest() { return bit_cast<float>(binary_lowest); }
    __host__ __device__ static constexpr float QuietNaN() { return bit_cast<float>(binary_qnan); }
    __host__ __device__ static constexpr float Infinity() { return bit_cast<float>(binary_inf); }
};
// IEEE-754 binary16 (_Float16) limits, encoded as raw bit patterns and
// reinterpreted via bit_cast.
template <>
struct NumericLimits<half_t>
{
    static constexpr unsigned short binary_min = 0x0400;    // 2^-14, smallest positive normal
    static constexpr unsigned short binary_max = 0x7BFF;    // 65504, largest finite value
    static constexpr unsigned short binary_lowest = 0xFBFF; // -65504, most negative finite value
    static constexpr unsigned short binary_qnan = 0x7FFF;   // quiet NaN (all-ones payload)
    __host__ __device__ static constexpr half_t Min() { return bit_cast<half_t>(binary_min); }
    __host__ __device__ static constexpr half_t Max() { return bit_cast<half_t>(binary_max); }
    __host__ __device__ static constexpr half_t Lowest() { return bit_cast<half_t>(binary_lowest); }
    __host__ __device__ static constexpr half_t QuietNaN() { return bit_cast<half_t>(binary_qnan); }
};
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
// 4-bit signed integer (_BitInt(4)) limits: [-8, 7].
// No Infinity()/QuietNaN(): integers have no such representations.
template <>
struct NumericLimits<int4_t>
{
    __host__ __device__ static constexpr int4_t Min() { return int4_t(-8); }
    __host__ __device__ static constexpr int4_t Max() { return int4_t(7); }
    __host__ __device__ static constexpr int4_t Lowest() { return int4_t(-8); }
};
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
// FP8 E4M3 FNUZ limits. In FNUZ encoding the negative-zero pattern (0x80)
// serves as the single NaN and there is no infinity representation.
// Values are raw bit patterns passed through the f8_fnuz_t constructor.
template <>
struct NumericLimits<f8_fnuz_t>
{
    // negative zero nan mode with exp bias = 8
    static constexpr uint8_t binary_min = 0x08;    // 0b00001000
    static constexpr uint8_t binary_max = 0x7F;    // 0b01111111
    static constexpr uint8_t binary_lowest = 0xFF; // 0b11111111
    static constexpr uint8_t binary_qnan = 0x80;   // 0b10000000
    // ieee mode with exp bias = 7
    // static constexpr uint8_t binary_min    = 0x08; // 0b00001000
    // static constexpr uint8_t binary_max    = 0x77; // 0b01110111
    // static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111
    // static constexpr uint8_t binary_qnan   = 0x79; // any sign, exp=1111, mant!=0
    __host__ __device__ static constexpr f8_fnuz_t Min() { return f8_fnuz_t(binary_min); }
    __host__ __device__ static constexpr f8_fnuz_t Max() { return f8_fnuz_t(binary_max); }
    __host__ __device__ static constexpr f8_fnuz_t Lowest() { return f8_fnuz_t(binary_lowest); }
    __host__ __device__ static constexpr f8_fnuz_t QuietNaN() { return f8_fnuz_t(binary_qnan); }
};
// BF8 E5M2 FNUZ limits. As with f8_fnuz_t, the negative-zero pattern (0x80)
// is the single NaN and there is no infinity representation.
// Values are raw bit patterns passed through the bf8_fnuz_t constructor.
template <>
struct NumericLimits<bf8_fnuz_t>
{
    // negative zero nan mode with exp bias = 16
    static constexpr uint8_t binary_min = 0x04;    // 0b00000100
    static constexpr uint8_t binary_max = 0x7F;    // 0b01111111
    static constexpr uint8_t binary_lowest = 0xFF; // 0b11111111
    static constexpr uint8_t binary_qnan = 0x80;   // 0b10000000
    // ieee mode with exp bias = 15
    // static constexpr uint8_t binary_min    = 0x04; // 0b00000100
    // static constexpr uint8_t binary_max    = 0x7B; // 0b01111011
    // static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011
    // static constexpr uint8_t binary_qnan   = 0x79; // any sign, exp=1111, mant!=
    __host__ __device__ static constexpr bf8_fnuz_t Min() { return bf8_fnuz_t(binary_min); }
    __host__ __device__ static constexpr bf8_fnuz_t Max() { return bf8_fnuz_t(binary_max); }
    __host__ __device__ static constexpr bf8_fnuz_t Lowest() { return bf8_fnuz_t(binary_lowest); }
    __host__ __device__ static constexpr bf8_fnuz_t QuietNaN() { return bf8_fnuz_t(binary_qnan); }
};
// FP8 E4M3 limits in the OCP (open compute) encoding: finite range +/-448,
// NaN at exp=1111/mant=111, no infinity. Patterns are reinterpreted with
// bit_cast rather than a converting constructor.
template <>
struct NumericLimits<f8_ocp_t>
{
    static constexpr uint8_t binary_min = 0x08;    // 0b00001000 = 2^-6
    static constexpr uint8_t binary_max = 0x7E;    // 0b01111110 = 448
    static constexpr uint8_t binary_lowest = 0xFE; // 0b11111110 = -448
    static constexpr uint8_t binary_qnan = 0x7F;   // 0b01111111
    __host__ __device__ static constexpr f8_ocp_t Min() { return bit_cast<f8_ocp_t>(binary_min); }
    __host__ __device__ static constexpr f8_ocp_t Max() { return bit_cast<f8_ocp_t>(binary_max); }
    __host__ __device__ static constexpr f8_ocp_t Lowest()
    {
        return bit_cast<f8_ocp_t>(binary_lowest);
    }
    __host__ __device__ static constexpr f8_ocp_t QuietNaN()
    {
        return bit_cast<f8_ocp_t>(binary_qnan);
    }
};
// BF8 E5M2 limits in the OCP encoding: finite range +/-57344.
// NOTE(review): E5M2 OCP reserves exp=11111 patterns for inf/NaN; only the
// quiet-NaN pattern is exposed here, and no Infinity() is provided.
template <>
struct NumericLimits<bf8_ocp_t>
{
    static constexpr uint8_t binary_min = 0x04;    // 0b00000100 = 2^-14
    static constexpr uint8_t binary_max = 0x7B;    // 0b01111011 = 57344
    static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011 = -57344
    static constexpr uint8_t binary_qnan = 0x7D;   // 0b01111101
    __host__ __device__ static constexpr bf8_ocp_t Min() { return bit_cast<bf8_ocp_t>(binary_min); }
    __host__ __device__ static constexpr bf8_ocp_t Max() { return bit_cast<bf8_ocp_t>(binary_max); }
    __host__ __device__ static constexpr bf8_ocp_t Lowest()
    {
        return bit_cast<bf8_ocp_t>(binary_lowest);
    }
    __host__ __device__ static constexpr bf8_ocp_t QuietNaN()
    {
        return bit_cast<bf8_ocp_t>(binary_qnan);
    }
};
#else
template <typename T> template <typename T>
struct NumericLimits struct NumericLimits
{ {
__host__ __device__ static constexpr T Min() { return std::numeric_limits<T>::min(); } __host__ __device__ static constexpr T Min() { return std::numeric_limits<T>::min(); }
__host__ __device__ static constexpr T Max() { return std::numeric_limits<T>::max(); } __host__ __device__ static constexpr T Max() { return std::numeric_limits<T>::max(); }
__host__ __device__ static constexpr T Lowest() { return std::numeric_limits<T>::lowest(); } __host__ __device__ static constexpr T Lowest() { return std::numeric_limits<T>::lowest(); }
__host__ __device__ static constexpr T QuietNaN() __host__ __device__ static constexpr T QuietNaN()
{ {
return std::numeric_limits<T>::quiet_NaN(); return std::numeric_limits<T>::quiet_NaN();
} }
__host__ __device__ static constexpr T Infinity() { return std::numeric_limits<T>::infinity(); } __host__ __device__ static constexpr T Infinity() { return std::numeric_limits<T>::infinity(); }
}; };
...@@ -2347,6 +2569,7 @@ struct NumericLimits<bf8_ocp_t> ...@@ -2347,6 +2569,7 @@ struct NumericLimits<bf8_ocp_t>
return bit_cast<bf8_ocp_t>(binary_qnan); return bit_cast<bf8_ocp_t>(binary_qnan);
} }
}; };
#endif
template <typename T> template <typename T>
struct NumericUtils struct NumericUtils
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#ifndef UTILITY_DEBUG_HPP #ifndef UTILITY_DEBUG_HPP
#define UTILITY_DEBUG_HPP #define UTILITY_DEBUG_HPP
#include "type.hpp"
namespace ck { namespace ck {
namespace debug { namespace debug {
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
namespace ck { namespace ck {
#ifndef CK_CODE_GEN_RTC
template <bool B, typename T = void> template <bool B, typename T = void>
using enable_if = std::enable_if<B, T>; using enable_if = std::enable_if<B, T>;
template <bool B, typename T = void> template <bool B, typename T = void>
using enable_if_t = typename std::enable_if<B, T>::type; using enable_if_t = typename std::enable_if<B, T>::type;
#else
// hipRTC build: <type_traits> is unavailable, so provide the canonical
// enable_if implementation. The primary template intentionally has no
// `type` member; only the true-specialization exposes `type = T`, which is
// what makes SFINAE-based overload removal work.
template <bool B, class T = void>
struct enable_if
{
};
template <class T>
struct enable_if<true, T>
{
    using type = T;
};
// Convenience alias matching std::enable_if_t.
template <bool B, class T = void>
using enable_if_t = typename enable_if<B, T>::type;
#endif
} // namespace ck } // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment