Commit dec32dc6 authored by ThomasNing's avatar ThomasNing
Browse files

Finish the feature and merge with develop on the computeV2

parents 71352c44 c5fff071
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <vector>
#include <iostream>
#include <numeric>
#include <cassert>
#include <cstdlib>
#include <iostream>
#include <time.h>
#include <unordered_set>
#include "batched_transpose_example.hpp"
#if 0
template <typename T>
void dump_host_tensor_4d(const ck_tile::HostTensor<T>& x)
{
auto len = x.get_lengths();
assert(len.size() == 4);
std::cout << "[";
for(size_t i = 0; i < len[0]; i++)
{
std::cout << i << ": [";
for(size_t j = 0; j < len[1]; j++)
{
std::cout << j << ": [";
for(size_t k = 0; k < len[2]; k++)
{
std::cout << k << ": [";
for(size_t v = 0; v < len[3]; v++)
{
if constexpr(std::is_same_v<T, ck_tile::fp16_t>)
{
auto m =
ck_tile::type_convert<float>(x(std::vector<std::size_t>{i, j, k, v}));
std::cout << m;
if(v != len[3] - 1)
std::cout << ",";
}
else
{
std::cout << x(std::vector<std::size_t>{i, j, k, v}) << " ";
}
}
std::cout << "]" << std::endl;
}
std::cout << "]" << std::endl;
}
std::cout << std::endl;
}
std::cout << "--------------------" << std::endl;
}
#endif
// different threshold for different dtype
template <typename DataType>
auto get_elimit(std::string /*init_method*/)
{
double rtol = 1e-3;
double atol = 1e-3;
return ck_tile::make_tuple(rtol, atol);
}
template <>
auto get_elimit<ck_tile::bf16_t>(std::string /*init_method*/)
{
double rtol = 1e-2;
double atol = 1e-2;
return ck_tile::make_tuple(rtol, atol);
}
template <>
auto get_elimit<ck_tile::fp8_t>(std::string init_method)
{
if(init_method == "ui" || init_method == "ni")
{
unsigned max_rounding_point_distance = 0;
double atol = 2e-3;
return ck_tile::make_tuple(max_rounding_point_distance, atol);
}
else
{
unsigned max_rounding_point_distance = 1;
double atol = 0.0625;
return ck_tile::make_tuple(max_rounding_point_distance, atol);
}
}
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("v", "1", "whether do CPU validation or not")
.insert("pr", "fp16", "input data type. fp16/fp32 (representing 8/16/32 bit data)")
.insert("N", "2", "input batch size. ")
.insert("C", "16", "input channel size.")
.insert("H", "1", "input height size.")
.insert("W", "16", "input width size. ")
.insert("layout_in", "NCHW", "input tensor data layout - NCHW by default")
.insert("layout_out", "NHWC", "output tensor data layout - NHWC by default ")
.insert("seed", "-1", "seed to be used, -1 means random every time")
.insert("kname", "0", "t to 1 will print kernel name");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
template <typename Type>
bool run_batched_transpose(ck_tile::ArgParser args)
{
int validate = args.get_int("v");
std::string prec = args.get_str("pr");
int N = args.get_int("N");
int C = args.get_int("C");
int H = args.get_int("H");
int W = args.get_int("W");
std::string layout_in = args.get_str("layout_in");
std::string layout_out = args.get_str("layout_out");
int seed = args.get_int("seed");
int dim_in[4], dim_out[4];
int stride_dim_in[4], stride_dim_out[4];
bool nchw2nhwc = layout_in == "NCHW" && layout_out == "NHWC";
bool nhwc2nchw = layout_in == "NHWC" && layout_out == "NCHW";
assert(nchw2nhwc != nhwc2nchw);
(void)nhwc2nchw;
dim_in[0] = N;
dim_in[1] = nchw2nhwc ? C : H;
dim_in[2] = nchw2nhwc ? H : W;
dim_in[3] = nchw2nhwc ? W : C;
dim_out[0] = N;
dim_out[1] = nchw2nhwc ? H : C;
dim_out[2] = nchw2nhwc ? W : H;
dim_out[3] = nchw2nhwc ? C : W;
stride_dim_in[0] = C * H * W;
stride_dim_in[1] = nchw2nhwc ? H * W : C * W;
stride_dim_in[2] = nchw2nhwc ? W : C;
stride_dim_in[3] = 1;
stride_dim_out[0] = C * H * W;
stride_dim_out[1] = nchw2nhwc ? C * W : H * W;
stride_dim_out[2] = nchw2nhwc ? C : W;
stride_dim_out[3] = 1;
if(seed < 0)
{
seed = std::time(nullptr);
}
ck_tile::HostTensor<Type> x_host(
{dim_in[0], dim_in[1], dim_in[2], dim_in[3]},
{stride_dim_in[0], stride_dim_in[1], stride_dim_in[2], stride_dim_in[3]});
ck_tile::HostTensor<Type> y_host(
{dim_out[0], dim_out[1], dim_out[2], dim_out[3]},
{stride_dim_out[0], stride_dim_out[1], stride_dim_out[2], stride_dim_out[3]});
ck_tile::FillUniformDistribution<Type>{-.5f, .5f}(x_host);
ck_tile::DeviceMem x_dev(x_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem y_dev(y_host.get_element_space_size_in_bytes());
x_dev.ToDevice(x_host.data());
auto trait = batched_transpose_trait{prec, layout_in};
uint32_t height = nchw2nhwc ? C : H * W;
uint32_t width = nchw2nhwc ? H * W : C;
batched_transpose_kargs karg = [&]() {
batched_transpose_kargs a_;
a_.p_input = x_dev.GetDeviceBuffer();
a_.p_output = y_dev.GetDeviceBuffer();
a_.batch = N;
a_.height = height;
a_.width = width;
return a_;
}();
ck_tile::stream_config sc{nullptr, true};
auto ms = batched_transpose(trait, karg, sc);
std::size_t num_operations = N * C * H * (W - 1);
std::size_t num_bytes = N * C * H * W * sizeof(Type);
float ave_time = ms * 1E-3;
float gb_per_sec = num_bytes / ms * 1.E-6;
float tflops = static_cast<float>(num_operations) / ms * 1.E-6;
std::cout << "Run Batched Transpose kernel with N=" << N << ", C=" << C << ", H=" << H
<< ", W=" << W << ", layout_in=" << layout_in << ", layout_out=" << layout_out
<< " : " << ms << " ms (" << ave_time << " ave_time), " << tflops << " TFlops"
<< gb_per_sec << " GB/s, " << std::endl;
printf("[%s]N:%d, C:%d, H:%d, W:%d, layout_in:%s, %f\n",
prec.c_str(),
N,
C,
H,
W,
layout_in.c_str(),
ms);
if(ms < 0)
printf("not supported\n");
fflush(stdout);
if(ms < 0)
{
return false;
}
y_dev.FromDevice(y_host.data());
bool rtn = true;
if(validate)
{
// this host buffer will not copy to GPU, so no need use stride
ck_tile::HostTensor<Type> y_ref(
{dim_out[0], dim_out[1], dim_out[2], dim_out[3]},
{stride_dim_out[0], stride_dim_out[1], stride_dim_out[2], stride_dim_out[3]});
ck_tile::reference_batched_transpose<Type>(x_host, y_ref, layout_in, layout_out);
auto [rtol, atol] = get_elimit<Type>("");
rtn &= ck_tile::check_err(
y_host, y_ref, std::string("y Error: Incorrect results!"), rtol, atol);
}
printf("valid:%s\n", rtn ? "y" : "n");
fflush(stdout);
return rtn;
}
int main(int argc, char** argv)
{
auto [result, args] = create_args(argc, argv);
if(!result)
return -1;
std::string prec = args.get_str("pr");
bool r = true;
if(prec.compare("fp32") == 0)
{
r &= run_batched_transpose<float>(args);
}
else if(prec.compare("fp16") == 0)
{
r &= run_batched_transpose<ck_tile::fp16_t>(args);
}
else if(prec.compare("bf16") == 0)
{
r &= run_batched_transpose<ck_tile::bf16_t>(args);
}
else if(prec.compare("int8") == 0)
{
r &= run_batched_transpose<ck_tile::int8_t>(args);
}
return r ? 0 : -1;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/ops/reduce.hpp"
#include "ck_tile/ops/batched_transpose.hpp"
#include <vector>
#include <string>
#pragma once
struct batched_transpose_trait
{
std::string type;
std::string layout;
};
struct batched_transpose_kargs : public ck_tile::BatchedTransposeHostArgs
{
};
float batched_transpose(batched_transpose_trait t,
batched_transpose_kargs a,
ck_tile::stream_config s);
#!/bin/sh
EXE=./build/bin/tile_example_batched_transpose
for pr in "fp32" "fp16" "int8" ; do
$EXE -pr=$pr -N=1 -C=32 -H=1 -W=32 -layout_in='NCHW' -layout_out='NHWC'
$EXE -pr=$pr -N=2 -C=12 -H=1 -W=32 -layout_in='NHWC' -layout_out='NCHW'
$EXE -pr=$pr -N=3 -C=1334 -H=1 -W=37 -layout_in='NHWC' -layout_out='NCHW'
$EXE -pr=$pr -N=4 -C=27 -H=1 -W=32 -layout_in='NCHW' -layout_out='NHWC'
$EXE -pr=$pr -N=5 -C=1234 -H=1 -W=12 -layout_in='NCHW' -layout_out='NHWC'
done
...@@ -17,3 +17,4 @@ add_subdirectory(14_moe_smoothquant) ...@@ -17,3 +17,4 @@ add_subdirectory(14_moe_smoothquant)
add_subdirectory(15_fused_moe) add_subdirectory(15_fused_moe)
add_subdirectory(16_batched_gemm) add_subdirectory(16_batched_gemm)
add_subdirectory(17_grouped_gemm) add_subdirectory(17_grouped_gemm)
add_subdirectory(35_batched_transpose)
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -17,7 +17,9 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) ...@@ -17,7 +17,9 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
// to do: add various levels of logging with CK_LOG_LEVEL // to do: add various levels of logging with CK_LOG_LEVEL
#ifndef CK_TIME_KERNEL
#define CK_TIME_KERNEL 1 #define CK_TIME_KERNEL 1
#endif
// constant address space for kernel parameter // constant address space for kernel parameter
// https://llvm.org/docs/AMDGPUUsage.html#address-spaces // https://llvm.org/docs/AMDGPUUsage.html#address-spaces
...@@ -155,6 +157,9 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) ...@@ -155,6 +157,9 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
// LDS direct loads using inline assembly // LDS direct loads using inline assembly
#define CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM 0 #define CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM 0
// set rounding to nearest even as default for bf16 conversions
#define CK_USE_RNE_BF16_CONVERSION 1
// set rounding to nearest even as default for f8 conversions // set rounding to nearest even as default for f8 conversions
#define CK_USE_SR_F8_CONVERSION 0 #define CK_USE_SR_F8_CONVERSION 0
...@@ -230,13 +235,18 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) ...@@ -230,13 +235,18 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
// workaround: compiler issue on gfx908 // workaround: compiler issue on gfx908
#define CK_WORKAROUND_SWDEV_388832 1 #define CK_WORKAROUND_SWDEV_388832 1
// denorm test fix, required to work around dissue // denorm test fix, necessary for gfx90a
#ifndef CK_WORKAROUND_DENORM_FIX #ifndef CK_GFX90A_DENORM_WORKAROUND
#define CK_WORKAROUND_DENORM_FIX 0 #define CK_GFX90A_DENORM_WORKAROUND 0
#endif // CK_GFX90A_DENORM_WORKAROUND
// Enable only for gfx90a
#if defined(__gfx90a__)
#if CK_GFX90A_DENORM_WORKAROUND
#define CK_GFX90A_DENORM_WORKAROUND 1
#endif // CK_GFX90A_DENORM_WORKAROUND is set to 1
#else #else
// enable only for gfx90a #define CK_GFX90A_DENORM_WORKAROUND 0
#define CK_WORKAROUND_DENORM_FIX = CK_WORKAROUND_DENORM_FIX && defined(__gfx90a__) #endif // gfx90a
#endif // CK_WORKAROUND_DENORM_FIX
// set flag to 1 to build deprecated instances // set flag to 1 to build deprecated instances
#define CK_BUILD_DEPRECATED 1 #define CK_BUILD_DEPRECATED 1
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -121,19 +121,6 @@ __global__ void ...@@ -121,19 +121,6 @@ __global__ void
static_for<0, NumDTensor, 1>{}( static_for<0, NumDTensor, 1>{}(
[&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_group_offset[i]; }); [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_group_offset[i]; });
if constexpr(is_same_v<AElementwiseOperation, element_wise::DynamicUnaryOp>)
{
a_element_op.InitUnaryOpPtrOnDevice();
}
if constexpr(is_same_v<BElementwiseOperation, element_wise::DynamicUnaryOp>)
{
b_element_op.InitUnaryOpPtrOnDevice();
}
if constexpr(is_same_v<CDEElementwiseOperation, element_wise::DynamicUnaryOp>)
{
cde_element_op.InitUnaryOpPtrOnDevice();
}
if constexpr(isMultiA || isMultiB) if constexpr(isMultiA || isMultiB)
{ {
AsPointer p_as_grid_grp; AsPointer p_as_grid_grp;
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -247,32 +247,6 @@ struct DequantPack8 ...@@ -247,32 +247,6 @@ struct DequantPack8
constexpr const static bool is_pack8_invocable = true; constexpr const static bool is_pack8_invocable = true;
}; };
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wnon-virtual-dtor"
struct UnaryOpBase
{
public:
__host__ __device__ ~UnaryOpBase() = default;
__host__ __device__ constexpr UnaryOpBase() = default;
__host__ __device__ constexpr UnaryOpBase(const UnaryOpBase&) = default;
__host__ __device__ constexpr UnaryOpBase(UnaryOpBase&&) = default;
__host__ __device__ UnaryOpBase& operator=(const UnaryOpBase&) = default;
__host__ __device__ UnaryOpBase& operator=(UnaryOpBase&&) = default;
__host__ __device__ virtual inline void operator()(float& y, const float& x) const = 0;
__host__ __device__ virtual inline void operator()(double& y, const double& x) const = 0;
__host__ __device__ virtual inline void operator()(int32_t& y, const int32_t& x) const = 0;
__host__ __device__ virtual inline void operator()(int8_t& y, const int8_t& x) const = 0;
__host__ __device__ virtual inline void operator()(half_t& y, const half_t& x) const = 0;
__host__ __device__ virtual inline void operator()(bhalf_t& y, const bhalf_t& x) const = 0;
};
struct PassThroughPack2 struct PassThroughPack2
{ {
template <typename Y, typename X> template <typename Y, typename X>
...@@ -304,27 +278,8 @@ struct PassThroughPack2 ...@@ -304,27 +278,8 @@ struct PassThroughPack2
constexpr const static bool is_pack2_invocable = true; constexpr const static bool is_pack2_invocable = true;
}; };
struct PassThrough final : public UnaryOpBase struct PassThrough
{ {
__host__ __device__ constexpr PassThrough() = default;
__host__ __device__ constexpr PassThrough(const PassThrough&) = default;
__host__ __device__ constexpr PassThrough(PassThrough&&) = default;
__host__ __device__ PassThrough& operator=(const PassThrough&) = default;
__host__ __device__ PassThrough& operator=(PassThrough&&) = default;
__host__ __device__ ~PassThrough() = default;
__host__ __device__ inline void operator()(float& y, const float& x) const final { y = x; }
__host__ __device__ inline void operator()(double& y, const double& x) const final { y = x; }
__host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final { y = x; }
__host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final { y = x; }
__host__ __device__ inline void operator()(half_t& y, const half_t& x) const final { y = x; }
__host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final { y = x; }
template <typename Y, typename X> template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const; __host__ __device__ void operator()(Y& y, const X& x) const;
...@@ -334,6 +289,12 @@ struct PassThrough final : public UnaryOpBase ...@@ -334,6 +289,12 @@ struct PassThrough final : public UnaryOpBase
y = x; y = x;
} }
template <>
__host__ __device__ void operator()<double, double>(double& y, const double& x) const
{
y = x;
}
template <> template <>
__host__ __device__ void operator()<float, double>(float& y, const double& x) const __host__ __device__ void operator()<float, double>(float& y, const double& x) const
{ {
...@@ -346,12 +307,36 @@ struct PassThrough final : public UnaryOpBase ...@@ -346,12 +307,36 @@ struct PassThrough final : public UnaryOpBase
y = type_convert<double>(x); y = type_convert<double>(x);
} }
template <>
__host__ __device__ void operator()<float, float>(float& y, const float& x) const
{
y = x;
}
template <>
__host__ __device__ void operator()<half_t, half_t>(half_t& y, const half_t& x) const
{
y = x;
}
template <> template <>
__host__ __device__ void operator()<half_t, float>(half_t& y, const float& x) const __host__ __device__ void operator()<half_t, float>(half_t& y, const float& x) const
{ {
y = type_convert<half_t>(x); y = type_convert<half_t>(x);
} }
template <>
__host__ __device__ void operator()<bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x) const
{
y = x;
}
template <>
__host__ __device__ void operator()<int32_t, int32_t>(int32_t& y, const int32_t& x) const
{
y = x;
}
template <> template <>
__host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const __host__ __device__ void operator()<bhalf_t, float>(bhalf_t& y, const float& x) const
{ {
...@@ -376,6 +361,12 @@ struct PassThrough final : public UnaryOpBase ...@@ -376,6 +361,12 @@ struct PassThrough final : public UnaryOpBase
y = type_convert<float>(x); y = type_convert<float>(x);
} }
template <>
__host__ __device__ void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const
{
y = x;
}
template <> template <>
__host__ __device__ void operator()<half_t, int8_t>(half_t& y, const int8_t& x) const __host__ __device__ void operator()<half_t, int8_t>(half_t& y, const int8_t& x) const
{ {
...@@ -675,45 +666,20 @@ struct UnarySquare ...@@ -675,45 +666,20 @@ struct UnarySquare
}; };
}; };
struct UnaryAbs final : public UnaryOpBase struct UnaryAbs
{ {
__host__ __device__ constexpr UnaryAbs() = default; template <typename T>
__host__ __device__ constexpr UnaryAbs(const UnaryAbs&) = default; __host__ __device__ void operator()(T& y, const T& x) const
__host__ __device__ constexpr UnaryAbs(UnaryAbs&&) = default;
__host__ __device__ UnaryAbs& operator=(const UnaryAbs&) = default;
__host__ __device__ UnaryAbs& operator=(UnaryAbs&&) = default;
__host__ __device__ ~UnaryAbs() = default;
__host__ __device__ inline void operator()(float& y, const float& x) const final
{
y = ck::math::abs(x);
}
__host__ __device__ inline void operator()(double& y, const double& x) const final
{
y = ck::math::abs(x);
}
__host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final
{
y = ck::math::abs(x);
}
__host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final
{
y = ck::math::abs(x);
}
__host__ __device__ inline void operator()(half_t& y, const half_t& x) const final
{ {
y = ck::math::abs(x); static_assert(is_same<T, float>::value || is_same<T, double>::value ||
} is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"Data type is not supported by this operation!");
__host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final
{
y = ck::math::abs(x); y = ck::math::abs(x);
} };
template <>
__host__ __device__ void operator()(f8_t& y, const f8_t& x) const __host__ __device__ void operator()(f8_t& y, const f8_t& x) const
{ {
y = ck::type_convert<f8_t>(ck::math::abs(ck::type_convert<float>(x))); y = ck::type_convert<f8_t>(ck::math::abs(ck::type_convert<float>(x)));
...@@ -732,41 +698,20 @@ struct UnarySqrt ...@@ -732,41 +698,20 @@ struct UnarySqrt
}; };
}; };
struct Relu final : public UnaryOpBase struct Relu
{ {
__host__ __device__ constexpr Relu() = default; template <typename T>
__host__ __device__ constexpr Relu(const Relu&) = default; __host__ __device__ void operator()(T& y, const T& x) const
__host__ __device__ constexpr Relu(Relu&&) = default;
__host__ __device__ Relu& operator=(const Relu&) = default;
__host__ __device__ Relu& operator=(Relu&&) = default;
__host__ __device__ ~Relu() = default;
__host__ __device__ inline void operator()(float& y, const float& x) const final
{
y = x > 0 ? x : 0;
}
__host__ __device__ inline void operator()(double& y, const double& x) const final
{
y = x > 0 ? x : 0;
}
__host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final
{
y = x > 0 ? x : 0;
}
__host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final
{
y = x > 0 ? x : 0;
}
__host__ __device__ inline void operator()(half_t& y, const half_t& x) const final
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"Data type is not supported by this operation!");
y = x > 0 ? x : 0; y = x > 0 ? x : 0;
} }
__host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final template <>
__host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const
{ {
float x_f32 = ck::type_convert<float>(x); float x_f32 = ck::type_convert<float>(x);
float y_f32 = x_f32 > 0 ? x_f32 : 0; float y_f32 = x_f32 > 0 ? x_f32 : 0;
...@@ -913,52 +858,18 @@ struct Gelu ...@@ -913,52 +858,18 @@ struct Gelu
} }
}; };
struct Sigmoid final : public UnaryOpBase struct Sigmoid
{ {
__host__ __device__ constexpr Sigmoid() = default; template <typename T>
__host__ __device__ constexpr Sigmoid(const Sigmoid&) = default; __host__ __device__ void operator()(T& y, const T& x) const
__host__ __device__ constexpr Sigmoid(Sigmoid&&) = default;
__host__ __device__ Sigmoid& operator=(const Sigmoid&) = default;
__host__ __device__ Sigmoid& operator=(Sigmoid&&) = default;
__host__ __device__ ~Sigmoid() = default;
__host__ __device__ inline void operator()(float& y, const float& x) const final
{
constexpr float one = type_convert<float>(1);
y = one / (one + ck::math::exp(-x));
}
__host__ __device__ inline void operator()(double& y, const double& x) const final
{
constexpr double one = type_convert<double>(1);
y = one / (one + ck::math::exp(-x));
}
__host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final
{
constexpr int32_t one = type_convert<int32_t>(1);
y = one / (one + ck::math::exp(-x));
}
__host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final
{
constexpr int8_t one = type_convert<int8_t>(1);
y = one / (one + ck::math::exp(-x));
}
__host__ __device__ inline void operator()(half_t& y, const half_t& x) const final
{
constexpr half_t one = type_convert<half_t>(1);
y = one / (one + ck::math::exp(-x));
}
__host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final
{ {
constexpr float one = type_convert<float>(1); static_assert(is_same<T, float>::value || is_same<T, double>::value ||
float x_f32 = ck::type_convert<float>(x); is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
float y_f32 = one / (one + ck::math::exp(x_f32)); is_same<T, int32_t>::value,
y = ck::type_convert<bhalf_t>(y_f32); "Data type is not supported by this operation!");
} constexpr T one = type_convert<T>(1);
y = one / (one + ck::math::exp(-x));
};
}; };
struct Silu struct Silu
...@@ -974,44 +885,18 @@ struct Silu ...@@ -974,44 +885,18 @@ struct Silu
}; };
}; };
struct TanH final : public UnaryOpBase struct TanH
{ {
__host__ __device__ constexpr TanH() = default; template <typename T>
__host__ __device__ constexpr TanH(const TanH&) = default; __host__ __device__ void operator()(T& y, const T& x) const
__host__ __device__ constexpr TanH(TanH&&) = default;
__host__ __device__ TanH& operator=(const TanH&) = default;
__host__ __device__ TanH& operator=(TanH&&) = default;
__host__ __device__ ~TanH() = default;
__host__ __device__ inline void operator()(float& y, const float& x) const final
{
y = ck::math::tanh(x);
}
__host__ __device__ inline void operator()(double& y, const double& x) const final
{
y = ck::math::tanh(x);
}
__host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final
{
y = ck::math::tanh(x);
}
__host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final
{
y = ck::math::tanh(x);
}
__host__ __device__ inline void operator()(half_t& y, const half_t& x) const final
{ {
y = ck::math::tanh(x); static_assert(is_same<T, float>::value || is_same<T, double>::value ||
} is_same<T, ck::half_t>::value || is_same<T, int8_t>::value ||
is_same<T, int32_t>::value,
"Data type is not supported by this operation!");
__host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final
{
y = ck::math::tanh(x); y = ck::math::tanh(x);
} };
}; };
struct ACos struct ACos
...@@ -1252,418 +1137,138 @@ struct Rcp ...@@ -1252,418 +1137,138 @@ struct Rcp
}; };
}; };
struct Swish final : public UnaryOpBase struct Swish
{ {
__host__ __device__ constexpr Swish(const Swish&) = default; Swish(float beta = 1.0f) : beta_(beta) {}
__host__ __device__ constexpr Swish(Swish&&) = default;
__host__ __device__ ~Swish() = default;
__host__ __device__ Swish(float beta = 1.0f) : beta_(beta) {}
__host__ __device__ float get_beta() const { return beta_; }
const float beta_;
__host__ __device__ inline void operator()(float& y, const float& x) const final
{
float bx = -beta_ * type_convert<float>(x);
y = type_convert<float>(x / (1.f + ck::math::exp(bx)));
}
__host__ __device__ inline void operator()(double& y, const double& x) const final
{
float bx = -beta_ * type_convert<float>(x);
y = type_convert<double>(x / (1.f + ck::math::exp(bx)));
}
__host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final
{
float bx = -beta_ * type_convert<float>(x);
y = type_convert<int32_t>(x / (1.f + ck::math::exp(bx)));
}
__host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final
{
float bx = -beta_ * type_convert<float>(x);
y = type_convert<int8_t>(x / (1.f + ck::math::exp(bx)));
}
__host__ __device__ inline void operator()(half_t& y, const half_t& x) const final
{
float bx = -beta_ * type_convert<float>(x);
y = type_convert<half_t>(x / (1.f + ck::math::exp(bx)));
}
__host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final
{
float bx = -beta_ * type_convert<float>(x);
y = type_convert<bhalf_t>(x / (1.f + ck::math::exp(bx)));
}
template <typename Y, typename X> template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const __host__ __device__ void operator()(Y& y, const X& x) const
{ {
static_assert(is_same<X, float>::value || is_same<X, double>::value || static_assert(is_same<X, float>::value || is_same<X, double>::value ||
is_same<X, half_t>::value, is_same<X, ck::half_t>::value || is_same<X, int8_t>::value,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
static_assert(is_same<Y, float>::value || is_same<Y, double>::value || static_assert(is_same<Y, float>::value || is_same<Y, double>::value ||
is_same<Y, half_t>::value, is_same<Y, ck::half_t>::value || is_same<Y, int8_t>::value,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
float bx = -beta_ * type_convert<float>(x); float bx = -beta_ * type_convert<float>(x);
y = type_convert<Y>(x / (1.f + ck::math::exp(bx))); y = type_convert<Y>(x / (1.f + ck::math::exp(bx)));
} };
const float beta_;
}; };
struct SoftRelu final : public UnaryOpBase struct SoftRelu
{ {
__host__ __device__ constexpr SoftRelu(const SoftRelu&) = default; SoftRelu(float alpha = 1.f) : alpha_(alpha){};
__host__ __device__ constexpr SoftRelu(SoftRelu&&) = default;
__host__ __device__ ~SoftRelu() = default;
__host__ __device__ SoftRelu(float alpha = 1.0f) : alpha_(alpha) {}
__host__ __device__ float get_alpha() const { return alpha_; } template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
const float alpha_;
__host__ __device__ inline void operator()(float& y, const float& x) const final
{
float casted_alpha = type_convert<float>(alpha_);
constexpr float one = type_convert<float>(1);
y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha;
}
__host__ __device__ inline void operator()(double& y, const double& x) const final
{
double casted_alpha = type_convert<double>(alpha_);
constexpr double one = type_convert<double>(1);
y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha;
}
__host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final
{
int32_t casted_alpha = type_convert<int32_t>(alpha_);
constexpr int32_t one = type_convert<int32_t>(1);
y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha;
}
__host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final
{
int8_t casted_alpha = type_convert<int8_t>(alpha_);
constexpr int8_t one = type_convert<int8_t>(1);
y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha;
}
__host__ __device__ inline void operator()(half_t& y, const half_t& x) const final
{
half_t casted_alpha = type_convert<half_t>(alpha_);
constexpr half_t one = type_convert<half_t>(1);
y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha;
}
__host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final
{ {
bhalf_t casted_alpha = type_convert<bhalf_t>(alpha_); static_assert(is_same<T, float>::value || is_same<T, double>::value ||
constexpr bhalf_t one = type_convert<bhalf_t>(1); is_same<T, half_t>::value || is_same<T, int32_t>::value ||
y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha; is_same<T, int8_t>::value,
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_);
constexpr T one = type_convert<T>(1);
y = ck::math::log(one + ck::math::exp(x * casted_alpha)) / casted_alpha;
} }
const float alpha_;
}; };
struct Power final : public UnaryOpBase struct Power
{ {
__host__ __device__ constexpr Power(const Power&) = default; Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f)
__host__ __device__ constexpr Power(Power&&) = default; : alpha_(alpha), beta_(beta), gamma_(gamma){};
__host__ __device__ ~Power() = default;
__host__ __device__ Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f) template <typename T>
: alpha_(alpha), beta_(beta), gamma_(gamma) __host__ __device__ void operator()(T& y, const T& x) const
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_);
T casted_beta = type_convert<T>(beta_);
T casted_gamma = type_convert<T>(gamma_);
T shifted_scaled_x = casted_alpha + casted_beta * x;
y = ck::math::pow(shifted_scaled_x, casted_gamma);
} }
__host__ __device__ float get_alpha() const { return alpha_; }
__host__ __device__ float get_beta() const { return beta_; }
__host__ __device__ float get_gamma() const { return gamma_; }
const float alpha_; const float alpha_;
const float beta_; const float beta_;
const float gamma_; const float gamma_;
__host__ __device__ inline void operator()(float& y, const float& x) const final
{
float casted_alpha = type_convert<float>(alpha_);
float casted_beta = type_convert<float>(beta_);
float casted_gamma = type_convert<float>(gamma_);
float shifted_scaled_x = casted_alpha + casted_beta * x;
y = ck::math::pow(shifted_scaled_x, casted_gamma);
}
__host__ __device__ inline void operator()(double& y, const double& x) const final
{
double casted_alpha = type_convert<double>(alpha_);
double casted_beta = type_convert<double>(beta_);
double casted_gamma = type_convert<double>(gamma_);
double shifted_scaled_x = casted_alpha + casted_beta * x;
y = ck::math::pow(shifted_scaled_x, casted_gamma);
}
__host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final
{
int32_t casted_alpha = type_convert<int32_t>(alpha_);
int32_t casted_beta = type_convert<int32_t>(beta_);
int32_t casted_gamma = type_convert<int32_t>(gamma_);
int32_t shifted_scaled_x = casted_alpha + casted_beta * x;
y = ck::math::pow(shifted_scaled_x, casted_gamma);
}
__host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final
{
int8_t casted_alpha = type_convert<int8_t>(alpha_);
int8_t casted_beta = type_convert<int8_t>(beta_);
int8_t casted_gamma = type_convert<int8_t>(gamma_);
int8_t shifted_scaled_x = casted_alpha + casted_beta * x;
y = ck::math::pow(shifted_scaled_x, casted_gamma);
}
__host__ __device__ inline void operator()(half_t& y, const half_t& x) const final
{
half_t casted_alpha = type_convert<half_t>(alpha_);
half_t casted_beta = type_convert<half_t>(beta_);
half_t casted_gamma = type_convert<half_t>(gamma_);
half_t shifted_scaled_x = casted_alpha + casted_beta * x;
y = ck::math::pow(shifted_scaled_x, casted_gamma);
}
__host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final
{
bhalf_t casted_alpha = type_convert<bhalf_t>(alpha_);
bhalf_t casted_beta = type_convert<bhalf_t>(beta_);
bhalf_t casted_gamma = type_convert<bhalf_t>(gamma_);
bhalf_t shifted_scaled_x = casted_alpha + casted_beta * x;
y = ck::math::pow(shifted_scaled_x, casted_gamma);
}
}; };
struct ClippedRelu final : public UnaryOpBase struct ClippedRelu
{ {
__host__ __device__ constexpr ClippedRelu(const ClippedRelu&) = default; ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta){};
__host__ __device__ constexpr ClippedRelu(ClippedRelu&&) = default;
__host__ __device__ ~ClippedRelu() = default;
__host__ __device__ ClippedRelu(float alpha = 0.f, float beta = 1.f) template <typename T>
: alpha_(alpha), beta_(beta) __host__ __device__ void operator()(T& y, const T& x) const
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_);
T casted_beta = type_convert<T>(beta_);
y = ck::math::min(casted_beta, ck::math::max(casted_alpha, x));
} }
__host__ __device__ float get_alpha() const { return alpha_; }
__host__ __device__ float get_beta() const { return beta_; }
const float alpha_; const float alpha_;
const float beta_; const float beta_;
__host__ __device__ inline void operator()(float& y, const float& x) const final
{
float casted_alpha = type_convert<float>(alpha_);
float casted_beta = type_convert<float>(beta_);
y = ck::math::min(casted_beta, ck::math::max(casted_alpha, x));
}
__host__ __device__ inline void operator()(double& y, const double& x) const final
{
double casted_alpha = type_convert<double>(alpha_);
double casted_beta = type_convert<double>(beta_);
y = ck::math::min(casted_beta, ck::math::max(casted_alpha, x));
}
__host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final
{
int32_t casted_alpha = type_convert<int32_t>(alpha_);
int32_t casted_beta = type_convert<int32_t>(beta_);
y = ck::math::min(casted_beta, ck::math::max(casted_alpha, x));
}
__host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final
{
int8_t casted_alpha = type_convert<int8_t>(alpha_);
int8_t casted_beta = type_convert<int8_t>(beta_);
y = ck::math::min(casted_beta, ck::math::max(casted_alpha, x));
}
__host__ __device__ inline void operator()(half_t& y, const half_t& x) const final
{
half_t casted_alpha = type_convert<half_t>(alpha_);
half_t casted_beta = type_convert<half_t>(beta_);
y = ck::math::min(casted_beta, ck::math::max(casted_alpha, x));
}
__host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final
{
bhalf_t casted_alpha = type_convert<bhalf_t>(alpha_);
bhalf_t casted_beta = type_convert<bhalf_t>(beta_);
y = ck::math::min(casted_beta, ck::math::max(casted_alpha, x));
}
}; };
struct LeakyRelu final : public UnaryOpBase struct LeakyRelu
{ {
__host__ __device__ constexpr LeakyRelu(const LeakyRelu&) = default; LeakyRelu(float alpha = 0.01f) : alpha_(alpha){};
__host__ __device__ constexpr LeakyRelu(LeakyRelu&&) = default;
__host__ __device__ ~LeakyRelu() = default;
__host__ __device__ LeakyRelu(float alpha = 0.f) : alpha_(alpha) {}
__host__ __device__ float get_alpha() const { return alpha_; }
const float alpha_;
__host__ __device__ inline void operator()(float& y, const float& x) const final
{
float casted_alpha = type_convert<float>(alpha_);
y = x >= 0 ? x : x * casted_alpha;
}
__host__ __device__ inline void operator()(double& y, const double& x) const final
{
double casted_alpha = type_convert<double>(alpha_);
y = x >= 0 ? x : x * casted_alpha;
}
__host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final
{
int32_t casted_alpha = type_convert<int32_t>(alpha_);
y = x >= 0 ? x : x * casted_alpha;
}
__host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final template <typename T>
{ __host__ __device__ void operator()(T& y, const T& x) const
int8_t casted_alpha = type_convert<int8_t>(alpha_);
y = x >= 0 ? x : x * casted_alpha;
}
__host__ __device__ inline void operator()(half_t& y, const half_t& x) const final
{
half_t casted_alpha = type_convert<half_t>(alpha_);
y = x >= 0 ? x : x * casted_alpha;
}
__host__ __device__ inline void operator()([[maybe_unused]] bhalf_t& y,
[[maybe_unused]] const bhalf_t& x) const final
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_);
y = x >= 0 ? x : x * casted_alpha;
} }
const float alpha_;
}; };
struct Elu final : public UnaryOpBase struct Elu
{ {
__host__ __device__ constexpr Elu(const Elu&) = default; Elu(float alpha = 1.f) : alpha_(alpha){};
__host__ __device__ constexpr Elu(Elu&&) = default;
__host__ __device__ ~Elu() = default;
__host__ __device__ Elu(float alpha = 1.f) : alpha_(alpha) {}
__host__ __device__ float get_alpha() const { return alpha_; }
const float alpha_;
__host__ __device__ inline void operator()(float& y, const float& x) const final
{
float casted_alpha = type_convert<float>(alpha_);
y = x > 0 ? x : casted_alpha * ck::math::expm1(x);
}
__host__ __device__ inline void operator()(double& y, const double& x) const final
{
double casted_alpha = type_convert<double>(alpha_);
y = x > 0 ? x : casted_alpha * ck::math::expm1(x);
}
__host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final
{
int32_t casted_alpha = type_convert<int32_t>(alpha_);
y = x > 0 ? x : casted_alpha * ck::math::expm1(x);
}
__host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final template <typename T>
{ __host__ __device__ void operator()(T& y, const T& x) const
int8_t casted_alpha = type_convert<int8_t>(alpha_);
y = x > 0 ? x : casted_alpha * ck::math::expm1(x);
}
__host__ __device__ inline void operator()(half_t& y, const half_t& x) const final
{
half_t casted_alpha = type_convert<half_t>(alpha_);
y = x > 0 ? x : casted_alpha * ck::math::expm1(x);
}
__host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final
{ {
bhalf_t casted_alpha = type_convert<bhalf_t>(alpha_); static_assert(is_same<T, float>::value || is_same<T, double>::value ||
y = x > 0 ? x : casted_alpha * ck::math::expm1(x); is_same<T, half_t>::value || is_same<T, int32_t>::value ||
is_same<T, int8_t>::value,
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_);
y = x > 0 ? x : casted_alpha * ck::math::expm1(x);
} }
const float alpha_;
}; };
struct Logistic final : public UnaryOpBase struct Logistic
{ {
__host__ __device__ constexpr Logistic(const Logistic&) = default; Logistic(float alpha = 1.f) : alpha_(alpha){};
__host__ __device__ constexpr Logistic(Logistic&&) = default;
__host__ __device__ ~Logistic() = default;
__host__ __device__ Logistic(float alpha = 1.0f) : alpha_(alpha) {}
__host__ __device__ float get_alpha() const { return alpha_; } template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
const float alpha_;
__host__ __device__ inline void operator()(float& y, const float& x) const final
{
float casted_alpha = type_convert<float>(alpha_);
constexpr float one = type_convert<float>(1);
y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha);
}
__host__ __device__ inline void operator()(double& y, const double& x) const final
{
double casted_alpha = type_convert<double>(alpha_);
constexpr double one = type_convert<double>(1);
y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha);
}
__host__ __device__ inline void operator()(int32_t& y, const int32_t& x) const final
{
int32_t casted_alpha = type_convert<int32_t>(alpha_);
constexpr int32_t one = type_convert<int32_t>(1);
y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha);
}
__host__ __device__ inline void operator()(int8_t& y, const int8_t& x) const final
{
int8_t casted_alpha = type_convert<int8_t>(alpha_);
constexpr int8_t one = type_convert<int8_t>(1);
y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha);
}
__host__ __device__ inline void operator()(half_t& y, const half_t& x) const final
{
half_t casted_alpha = type_convert<half_t>(alpha_);
constexpr half_t one = type_convert<half_t>(1);
y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha);
}
__host__ __device__ inline void operator()(bhalf_t& y, const bhalf_t& x) const final
{ {
bhalf_t casted_alpha = type_convert<bhalf_t>(alpha_); static_assert(is_same<T, float>::value || is_same<T, double>::value ||
constexpr bhalf_t one = type_convert<bhalf_t>(1); is_same<T, half_t>::value || is_same<T, int32_t>::value ||
y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha); is_same<T, int8_t>::value,
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_);
constexpr T one = type_convert<T>(1);
y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha);
} }
const float alpha_;
}; };
struct ConvInvscale struct ConvInvscale
...@@ -1728,7 +1333,7 @@ struct ConvScaleRelu ...@@ -1728,7 +1333,7 @@ struct ConvScaleRelu
__host__ __device__ void operator()<f8_t, float>(f8_t& e, const float& c) const __host__ __device__ void operator()<f8_t, float>(f8_t& e, const float& c) const
{ {
float x; float x;
Relu{}(x, c * scale_in_ * scale_wei_); Relu{}.template operator()<float>(x, c * scale_in_ * scale_wei_);
e = type_convert<f8_t>(x * scale_out_); e = type_convert<f8_t>(x * scale_out_);
}; };
...@@ -1809,225 +1414,138 @@ struct FastNumericArrayConverter<uint8_t, ck::half_t, N> ...@@ -1809,225 +1414,138 @@ struct FastNumericArrayConverter<uint8_t, ck::half_t, N>
struct DynamicUnaryOp struct DynamicUnaryOp
{ {
DynamicUnaryOp& operator=(const DynamicUnaryOp& other)
{
if(this != &other)
{
unary_op_ptr_ = other.unary_op_ptr_;
unary_op_type_ = other.unary_op_type_;
}
return *this;
}
__host__ __device__ DynamicUnaryOp() = delete; __host__ __device__ DynamicUnaryOp() = delete;
__host__ __device__ DynamicUnaryOp(const Swish& swish) __host__ __device__ DynamicUnaryOp(const Swish& swish)
: unary_op_type_(UnaryOpType::Swish), swish_{swish.beta_}
{ {
unary_op_type_ = UnaryOpType::Swish;
beta = swish.get_beta();
} }
__host__ __device__ DynamicUnaryOp(const Swish&& swish) __host__ __device__ DynamicUnaryOp(const Swish&& swish)
: unary_op_type_(UnaryOpType::Swish), swish_{swish.beta_}
{ {
unary_op_type_ = UnaryOpType::Swish;
beta = swish.get_beta();
} }
__host__ __device__ DynamicUnaryOp(const Sigmoid&) { unary_op_type_ = UnaryOpType::Sigmoid; } __host__ __device__ DynamicUnaryOp(const Sigmoid&) : unary_op_type_(UnaryOpType::Sigmoid) {}
__host__ __device__ DynamicUnaryOp(const Sigmoid&&) { unary_op_type_ = UnaryOpType::Sigmoid; } __host__ __device__ DynamicUnaryOp(const Sigmoid&&) : unary_op_type_(UnaryOpType::Sigmoid) {}
__host__ __device__ DynamicUnaryOp(const PassThrough&) __host__ __device__ DynamicUnaryOp(const PassThrough&)
: unary_op_type_(UnaryOpType::PassThrough)
{ {
unary_op_type_ = UnaryOpType::PassThrough;
} }
__host__ __device__ DynamicUnaryOp(const PassThrough&&) __host__ __device__ DynamicUnaryOp(const PassThrough&&)
: unary_op_type_(UnaryOpType::PassThrough)
{ {
unary_op_type_ = UnaryOpType::PassThrough;
} }
__host__ __device__ DynamicUnaryOp(const Logistic& logistic) __host__ __device__ DynamicUnaryOp(const Logistic& logistic)
: unary_op_type_(UnaryOpType::Logistic), logistic_{logistic.alpha_}
{ {
unary_op_type_ = UnaryOpType::Logistic;
alpha = logistic.get_alpha();
} }
__host__ __device__ DynamicUnaryOp(const Logistic&& logistic) __host__ __device__ DynamicUnaryOp(const Logistic&& logistic)
: unary_op_type_(UnaryOpType::Logistic), logistic_{logistic.alpha_}
{ {
unary_op_type_ = UnaryOpType::Logistic;
alpha = logistic.get_alpha();
} }
__host__ __device__ DynamicUnaryOp(const TanH&) { unary_op_type_ = UnaryOpType::TanH; } __host__ __device__ DynamicUnaryOp(const TanH&) : unary_op_type_(UnaryOpType::TanH) {}
__host__ __device__ DynamicUnaryOp(const TanH&&) { unary_op_type_ = UnaryOpType::TanH; } __host__ __device__ DynamicUnaryOp(const TanH&&) : unary_op_type_(UnaryOpType::TanH) {}
__host__ __device__ DynamicUnaryOp(const Relu&) { unary_op_type_ = UnaryOpType::Relu; } __host__ __device__ DynamicUnaryOp(const Relu&) : unary_op_type_(UnaryOpType::Relu) {}
__host__ __device__ DynamicUnaryOp(const Relu&&) { unary_op_type_ = UnaryOpType::Relu; } __host__ __device__ DynamicUnaryOp(const Relu&&) : unary_op_type_(UnaryOpType::Relu) {}
__host__ __device__ DynamicUnaryOp(const SoftRelu& softrelu) __host__ __device__ DynamicUnaryOp(const SoftRelu& softrelu)
: unary_op_type_(UnaryOpType::SoftRelu), soft_relu_{softrelu.alpha_}
{ {
unary_op_type_ = UnaryOpType::SoftRelu;
alpha = softrelu.get_alpha();
} }
__host__ __device__ DynamicUnaryOp(const SoftRelu&& softrelu) __host__ __device__ DynamicUnaryOp(const SoftRelu&& softrelu)
: unary_op_type_(UnaryOpType::SoftRelu), soft_relu_{softrelu.alpha_}
{ {
unary_op_type_ = UnaryOpType::SoftRelu;
alpha = softrelu.get_alpha();
} }
__host__ __device__ DynamicUnaryOp(const UnaryAbs&) { unary_op_type_ = UnaryOpType::UnaryAbs; } __host__ __device__ DynamicUnaryOp(const UnaryAbs&) : unary_op_type_(UnaryOpType::UnaryAbs) {}
__host__ __device__ DynamicUnaryOp(const UnaryAbs&&) { unary_op_type_ = UnaryOpType::UnaryAbs; } __host__ __device__ DynamicUnaryOp(const UnaryAbs&&) : unary_op_type_(UnaryOpType::UnaryAbs) {}
__host__ __device__ DynamicUnaryOp(const Power& pow) __host__ __device__ DynamicUnaryOp(const Power& pow)
: unary_op_type_(UnaryOpType::Power), power_(pow.alpha_, pow.beta_, pow.gamma_)
{ {
unary_op_type_ = UnaryOpType::Power;
alpha = pow.get_alpha();
beta = pow.get_beta();
gamma = pow.get_gamma();
} }
__host__ __device__ DynamicUnaryOp(const Power&& pow) __host__ __device__ DynamicUnaryOp(const Power&& pow)
: unary_op_type_(UnaryOpType::Power), power_(pow.alpha_, pow.beta_, pow.gamma_)
{ {
unary_op_type_ = UnaryOpType::Power;
alpha = pow.get_alpha();
beta = pow.get_beta();
gamma = pow.get_gamma();
} }
__host__ __device__ DynamicUnaryOp(const ClippedRelu& clippedrelu) __host__ __device__ DynamicUnaryOp(const ClippedRelu& clippedrelu)
: unary_op_type_(UnaryOpType::ClippedRelu),
clipped_relu_{clippedrelu.alpha_, clippedrelu.beta_}
{ {
unary_op_type_ = UnaryOpType::ClippedRelu;
alpha = clippedrelu.get_alpha();
beta = clippedrelu.get_beta();
} }
__host__ __device__ DynamicUnaryOp(const ClippedRelu&& clippedrelu) __host__ __device__ DynamicUnaryOp(const ClippedRelu&& clippedrelu)
: unary_op_type_(UnaryOpType::ClippedRelu),
clipped_relu_{clippedrelu.alpha_, clippedrelu.beta_}
{ {
unary_op_type_ = UnaryOpType::ClippedRelu;
alpha = clippedrelu.get_alpha();
beta = clippedrelu.get_beta();
} }
__host__ __device__ DynamicUnaryOp(const LeakyRelu& leakyrelu) __host__ __device__ DynamicUnaryOp(const LeakyRelu& leakyrelu)
: unary_op_type_(UnaryOpType::LeakyRelu), leaky_relu_{leakyrelu.alpha_}
{ {
unary_op_type_ = UnaryOpType::LeakyRelu;
alpha = leakyrelu.get_alpha();
} }
__host__ __device__ DynamicUnaryOp(const LeakyRelu&& leakyrelu) __host__ __device__ DynamicUnaryOp(const LeakyRelu&& leakyrelu)
: unary_op_type_(UnaryOpType::LeakyRelu), leaky_relu_{leakyrelu.alpha_}
{ {
unary_op_type_ = UnaryOpType::LeakyRelu;
alpha = leakyrelu.get_alpha();
} }
__host__ __device__ DynamicUnaryOp(const Elu& elu) __host__ __device__ DynamicUnaryOp(const Elu& elu)
: unary_op_type_(UnaryOpType::Elu), elu_{elu.alpha_}
{ {
unary_op_type_ = UnaryOpType::Elu;
alpha = elu.get_alpha();
} }
__host__ __device__ DynamicUnaryOp(const Elu&& elu) __host__ __device__ DynamicUnaryOp(const Elu&& elu)
: unary_op_type_(UnaryOpType::Elu), elu_{elu.alpha_}
{ {
unary_op_type_ = UnaryOpType::Elu;
alpha = elu.get_alpha();
}
__host__ __device__ DynamicUnaryOp(const DynamicUnaryOp& dynamic_op)
: unary_op_type_(dynamic_op.unary_op_type_),
unary_op_ptr_(dynamic_op.unary_op_ptr_),
alpha(dynamic_op.alpha),
beta(dynamic_op.beta),
gamma(dynamic_op.gamma)
{
}
__host__ __device__ ~DynamicUnaryOp()
{
switch(unary_op_type_)
{
case(UnaryOpType::Swish): delete static_cast<Swish*>(unary_op_ptr_); break;
case(UnaryOpType::Sigmoid): delete static_cast<Sigmoid*>(unary_op_ptr_); break;
case(UnaryOpType::PassThrough): delete static_cast<PassThrough*>(unary_op_ptr_); break;
case(UnaryOpType::Logistic): delete static_cast<Logistic*>(unary_op_ptr_); break;
case(UnaryOpType::TanH): delete static_cast<TanH*>(unary_op_ptr_); break;
case(UnaryOpType::Relu): delete static_cast<Relu*>(unary_op_ptr_); break;
case(UnaryOpType::SoftRelu): delete static_cast<SoftRelu*>(unary_op_ptr_); break;
case(UnaryOpType::UnaryAbs): delete static_cast<UnaryAbs*>(unary_op_ptr_); break;
case(UnaryOpType::Power): delete static_cast<Power*>(unary_op_ptr_); break;
case(UnaryOpType::ClippedRelu): delete static_cast<ClippedRelu*>(unary_op_ptr_); break;
case(UnaryOpType::LeakyRelu): delete static_cast<LeakyRelu*>(unary_op_ptr_); break;
case(UnaryOpType::Elu): delete static_cast<Elu*>(unary_op_ptr_); break;
default: break;
}
} }
__device__ void InitUnaryOpPtrOnDevice() __host__ __device__ DynamicUnaryOp(const DynamicUnaryOp& dynamic_op) = default;
{
switch(unary_op_type_)
{
case(UnaryOpType::Swish): unary_op_ptr_ = new Swish(beta); break;
case(UnaryOpType::Sigmoid): unary_op_ptr_ = new Sigmoid; break;
case(UnaryOpType::PassThrough): unary_op_ptr_ = new PassThrough; break;
case(UnaryOpType::Logistic): unary_op_ptr_ = new Logistic(alpha); break;
case(UnaryOpType::TanH): unary_op_ptr_ = new TanH; break;
case(UnaryOpType::Relu): unary_op_ptr_ = new Relu; break;
case(UnaryOpType::SoftRelu): unary_op_ptr_ = new SoftRelu(alpha); break;
case(UnaryOpType::UnaryAbs): unary_op_ptr_ = new UnaryAbs; break;
case(UnaryOpType::Power): unary_op_ptr_ = new Power(alpha, beta, gamma); break;
case(UnaryOpType::ClippedRelu): unary_op_ptr_ = new ClippedRelu(alpha, beta); break;
case(UnaryOpType::LeakyRelu): unary_op_ptr_ = new LeakyRelu(alpha); break;
case(UnaryOpType::Elu): unary_op_ptr_ = new Elu(alpha); break;
default: unary_op_ptr_ = nullptr; break;
}
}
template <typename Y, typename X> __host__ __device__ ~DynamicUnaryOp() {}
__device__ void operator()(Y& y, const X& x) const
{
isSupported<X, Y>();
unary_op_ptr_->operator()(y, x);
}
template <typename Y, typename X> template <typename Y, typename X>
__host__ void operator()(Y& y, const X& x) const __host__ __device__ void operator()(Y& y, const X& x) const
{ {
isSupported<X, Y>();
switch(unary_op_type_) switch(unary_op_type_)
{ {
case(UnaryOpType::Swish): Swish{}.operator()(y, x); break; case(UnaryOpType::Swish): swish_(y, x); break;
case(UnaryOpType::Sigmoid): Sigmoid{}.operator()(y, x); break; case(UnaryOpType::Sigmoid): sigmoid_(y, x); break;
case(UnaryOpType::PassThrough): PassThrough{}.operator()(y, x); break; case(UnaryOpType::PassThrough): pass_through_(y, x); break;
case(UnaryOpType::Logistic): Logistic{}.operator()(y, x); break; case(UnaryOpType::Logistic): logistic_(y, x); break;
case(UnaryOpType::TanH): TanH{}.operator()(y, x); break; case(UnaryOpType::TanH): tanh_(y, x); break;
case(UnaryOpType::Relu): Relu{}.operator()(y, x); break; case(UnaryOpType::Relu): relu_(y, x); break;
case(UnaryOpType::SoftRelu): SoftRelu{}.operator()(y, x); break; case(UnaryOpType::SoftRelu): soft_relu_(y, x); break;
case(UnaryOpType::UnaryAbs): UnaryAbs{}.operator()(y, x); break; case(UnaryOpType::UnaryAbs): unary_abs_(y, x); break;
case(UnaryOpType::Power): Power{}.operator()(y, x); break; case(UnaryOpType::Power): power_(y, x); break;
case(UnaryOpType::ClippedRelu): ClippedRelu{}.operator()(y, x); break; case(UnaryOpType::ClippedRelu): clipped_relu_(y, x); break;
case(UnaryOpType::LeakyRelu): LeakyRelu{}.operator()(y, x); break; case(UnaryOpType::LeakyRelu): leaky_relu_(y, x); break;
case(UnaryOpType::Elu): Elu{}.operator()(y, x); break; case(UnaryOpType::Elu): elu_(y, x); break;
default: break; default: break;
} }
} }
template <typename X, typename Y> template <>
__device__ __host__ constexpr void isSupported() const __host__ __device__ void operator()<bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x) const
{ {
float y_float;
static_assert(std::is_same<X, Y>::value, "X and Y must be of the same type"); float x_float = type_convert<float>(x);
this->operator()(y_float, x_float);
static_assert(is_same<X, float>::value || is_same<X, double>::value || y = type_convert<bhalf_t>(y_float);
is_same<X, bhalf_t>::value || is_same<X, half_t>::value ||
is_same<X, int32_t>::value || is_same<X, int8_t>::value,
"Data type is not supported by this operation!");
} }
private: private:
...@@ -2049,12 +1567,20 @@ struct DynamicUnaryOp ...@@ -2049,12 +1567,20 @@ struct DynamicUnaryOp
public: public:
UnaryOpType unary_op_type_; UnaryOpType unary_op_type_;
UnaryOpBase* unary_op_ptr_ = nullptr;
float alpha; Swish swish_;
float beta; Sigmoid sigmoid_;
float gamma; PassThrough pass_through_;
Logistic logistic_;
TanH tanh_;
Relu relu_;
SoftRelu soft_relu_;
UnaryAbs unary_abs_;
Power power_;
ClippedRelu clipped_relu_;
LeakyRelu leaky_relu_;
Elu elu_;
}; };
#pragma clang diagnostic pop
} // namespace element_wise } // namespace element_wise
} // namespace tensor_operation } // namespace tensor_operation
......
...@@ -101,7 +101,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle ...@@ -101,7 +101,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
using GridwiseGemmPipe = remove_cvref_t< using GridwiseGemmPipe = remove_cvref_t<
decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>; decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
#if CK_WORKAROUND_DENORM_FIX #if CK_GFX90A_DENORM_WORKAROUND
using AComputeDataType = using AComputeDataType =
conditional_t<is_same_v<AComputeDataType_, ck::half_t>, ck::bhalf_t, AComputeDataType_>; conditional_t<is_same_v<AComputeDataType_, ck::half_t>, ck::bhalf_t, AComputeDataType_>;
using BComputeDataType = using BComputeDataType =
......
...@@ -100,7 +100,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -100,7 +100,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
using GridwiseGemmPipe = remove_cvref_t< using GridwiseGemmPipe = remove_cvref_t<
decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>; decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
#if CK_WORKAROUND_DENORM_FIX #if CK_GFX90A_DENORM_WORKAROUND
using AComputeDataType = using AComputeDataType =
conditional_t<is_same_v<AComputeDataType_, ck::half_t>, ck::bhalf_t, AComputeDataType_>; conditional_t<is_same_v<AComputeDataType_, ck::half_t>, ck::bhalf_t, AComputeDataType_>;
using BComputeDataType = using BComputeDataType =
......
...@@ -164,7 +164,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad ...@@ -164,7 +164,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
using GridwiseGemmPipe = remove_cvref_t< using GridwiseGemmPipe = remove_cvref_t<
decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>; decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
#if CK_WORKAROUND_DENORM_FIX #if CK_GFX90A_DENORM_WORKAROUND
using AComputeDataType = using AComputeDataType =
conditional_t<is_same_v<AComputeDataType_, ck::half_t>, ck::bhalf_t, AComputeDataType_>; conditional_t<is_same_v<AComputeDataType_, ck::half_t>, ck::bhalf_t, AComputeDataType_>;
#else #else
......
...@@ -271,7 +271,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight ...@@ -271,7 +271,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
// when mfma if fixed, remove this section and update // when mfma if fixed, remove this section and update
// FloatAAdjusted -> ComputeTypeA, FloatBAdjusted -> ComputeTypeB, // FloatAAdjusted -> ComputeTypeA, FloatBAdjusted -> ComputeTypeB,
// throughout this file // throughout this file
#if CK_WORKAROUND_DENORM_FIX #if CK_GFX90A_DENORM_WORKAROUND
using FloatAAdjusted = using FloatAAdjusted =
conditional_t<is_same_v<ComputeTypeA, ck::half_t>, ck::bhalf_t, ComputeTypeA>; conditional_t<is_same_v<ComputeTypeA, ck::half_t>, ck::bhalf_t, ComputeTypeA>;
using FloatBAdjusted = using FloatBAdjusted =
......
...@@ -254,7 +254,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -254,7 +254,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// we convert fp16->fp32->bf16 and execute bf16 mfma instruction // we convert fp16->fp32->bf16 and execute bf16 mfma instruction
// when mfma if fixed, remove this section and update // when mfma if fixed, remove this section and update
// FloatABAdjusted -> FloatAB throughout this file // FloatABAdjusted -> FloatAB throughout this file
#if CK_WORKAROUND_DENORM_FIX #if CK_GFX90A_DENORM_WORKAROUND
using FloatABAdjusted = conditional_t<is_same_v<FloatAB, ck::half_t>, ck::bhalf_t, FloatAB>; using FloatABAdjusted = conditional_t<is_same_v<FloatAB, ck::half_t>, ck::bhalf_t, FloatAB>;
#else #else
using FloatABAdjusted = FloatAB; using FloatABAdjusted = FloatAB;
......
...@@ -19,8 +19,6 @@ struct pk_i4_t ...@@ -19,8 +19,6 @@ struct pk_i4_t
type data; type data;
__host__ __device__ constexpr pk_i4_t() : data{type{}} {} __host__ __device__ constexpr pk_i4_t() : data{type{}} {}
__host__ __device__ constexpr pk_i4_t(type init) : data{init} {} __host__ __device__ constexpr pk_i4_t(type init) : data{init} {}
__host__ __device__ constexpr operator float() const { return static_cast<int8_t>(data); }
}; };
inline constexpr auto next_pow2(uint32_t x) inline constexpr auto next_pow2(uint32_t x)
......
...@@ -29,6 +29,13 @@ struct DynamicBuffer ...@@ -29,6 +29,13 @@ struct DynamicBuffer
ElementSpaceSize element_space_size_; ElementSpaceSize element_space_size_;
T invalid_element_value_ = T{0}; T invalid_element_value_ = T{0};
static constexpr index_t PackedSize = []() {
if constexpr(is_same_v<remove_cvref_t<T>, pk_i4_t>)
return 2;
else
return 1;
}();
__host__ __device__ constexpr DynamicBuffer(T* p_data, ElementSpaceSize element_space_size) __host__ __device__ constexpr DynamicBuffer(T* p_data, ElementSpaceSize element_space_size)
: p_data_{p_data}, element_space_size_{element_space_size} : p_data_{p_data}, element_space_size_{element_space_size}
{ {
...@@ -82,14 +89,18 @@ struct DynamicBuffer ...@@ -82,14 +89,18 @@ struct DynamicBuffer
return amd_buffer_load_invalid_element_return_zero<remove_cvref_t<T>, return amd_buffer_load_invalid_element_return_zero<remove_cvref_t<T>,
t_per_x, t_per_x,
coherence>( coherence>(
p_data_, i, is_valid_element, element_space_size_); p_data_, i, is_valid_element, element_space_size_ / PackedSize);
} }
else else
{ {
return amd_buffer_load_invalid_element_return_customized_value<remove_cvref_t<T>, return amd_buffer_load_invalid_element_return_customized_value<remove_cvref_t<T>,
t_per_x, t_per_x,
coherence>( coherence>(
p_data_, i, is_valid_element, element_space_size_, invalid_element_value_); p_data_,
i,
is_valid_element,
element_space_size_ / PackedSize,
invalid_element_value_);
} }
} }
else else
...@@ -191,7 +202,7 @@ struct DynamicBuffer ...@@ -191,7 +202,7 @@ struct DynamicBuffer
dst_buf.p_data_, dst_buf.p_data_,
dst_offset, dst_offset,
is_valid_element, is_valid_element,
element_space_size_); element_space_size_ / PackedSize);
} }
template <typename X, template <typename X,
...@@ -226,7 +237,7 @@ struct DynamicBuffer ...@@ -226,7 +237,7 @@ struct DynamicBuffer
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_buffer_store<remove_cvref_t<T>, t_per_x, coherence>( amd_buffer_store<remove_cvref_t<T>, t_per_x, coherence>(
x, p_data_, i, is_valid_element, element_space_size_); x, p_data_, i, is_valid_element, element_space_size_ / PackedSize);
} }
else if constexpr(GetAddressSpace() == AddressSpaceEnum::Lds && else if constexpr(GetAddressSpace() == AddressSpaceEnum::Lds &&
is_same<typename scalar_type<remove_cvref_t<T>>::type, int8_t>::value && is_same<typename scalar_type<remove_cvref_t<T>>::type, int8_t>::value &&
...@@ -378,7 +389,7 @@ struct DynamicBuffer ...@@ -378,7 +389,7 @@ struct DynamicBuffer
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_buffer_atomic_add<remove_cvref_t<T>, t_per_x>( amd_buffer_atomic_add<remove_cvref_t<T>, t_per_x>(
x, p_data_, i, is_valid_element, element_space_size_); x, p_data_, i, is_valid_element, element_space_size_ / PackedSize);
} }
else else
{ {
...@@ -417,7 +428,7 @@ struct DynamicBuffer ...@@ -417,7 +428,7 @@ struct DynamicBuffer
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_buffer_atomic_max<remove_cvref_t<T>, t_per_x>( amd_buffer_atomic_max<remove_cvref_t<T>, t_per_x>(
x, p_data_, i, is_valid_element, element_space_size_); x, p_data_, i, is_valid_element, element_space_size_ / PackedSize);
} }
else if(is_valid_element) else if(is_valid_element)
{ {
......
...@@ -14,6 +14,41 @@ namespace ck { ...@@ -14,6 +14,41 @@ namespace ck {
#define __gfx94__ #define __gfx94__
#endif #endif
// Declare a template function for bf16 conversion using RTN
template <typename Y, typename X>
__host__ __device__ constexpr Y bf16_convert_rtn(X x);
// Convert fp32 to bf16 with RTN if higher precision is needed
template <>
inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, float>(float x)
{
// Nan check
if(x != x)
{
return uint16_t(0x7FC0);
}
union
{
float fp32;
uint32_t int32;
} u = {x};
const uint32_t first_bf16_mantisa_bit = ((u.int32 >> 16) & 1);
constexpr uint32_t rounding_bias = uint32_t((1 << 15) - 1);
return uint16_t((u.int32 + first_bf16_mantisa_bit + rounding_bias) >> 16);
}
// convert fp16 to bfp16 via fp32 with RTN if higher precision is needed
template <>
inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(half_t x)
{
float x_fp32 = static_cast<float>(x);
return bf16_convert_rtn<bhalf_t>(x_fp32);
}
// Convert X to Y, both X and Y are non-const data types. // Convert X to Y, both X and Y are non-const data types.
template <typename Y, template <typename Y,
typename X, typename X,
...@@ -51,17 +86,15 @@ inline __host__ __device__ constexpr float type_convert<float, bhalf_t>(bhalf_t ...@@ -51,17 +86,15 @@ inline __host__ __device__ constexpr float type_convert<float, bhalf_t>(bhalf_t
return u.fp32; return u.fp32;
} }
// convert fp32 to bfp16 // convert fp32 to bfp16, round to nearest even
template <> template <>
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float x) inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float x)
{ {
union #if CK_USE_RNE_BF16_CONVERSION
{ return bf16_convert_rtn<bhalf_t>(x);
float fp32; #else
uint32_t int32;
} u = {x};
return uint16_t(u.int32 >> 16); return uint16_t(u.int32 >> 16);
#endif
} }
// convert bfp16 to fp16 via fp32 // convert bfp16 to fp16 via fp32
...@@ -615,60 +648,4 @@ inline __host__ __device__ void array_convert(Array<Y, NumElems>& y, const Array ...@@ -615,60 +648,4 @@ inline __host__ __device__ void array_convert(Array<Y, NumElems>& y, const Array
} }
} }
// Declare a template function for bf16 conversion using RTN
template <typename Y, typename X>
__host__ __device__ constexpr Y bf16_convert_rtn(X x);
// Convert fp32 to bf16 with RTN if higher precision is needed
template <>
inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, float>(float x)
{
union
{
float fp32;
uint32_t int32;
} u = {x};
// When the exponent bits are not all 1s, then the value is zero, normal,
// or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
// 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
// This causes the bfloat16's mantissa to be incremented by 1 if the 16
// least significant bits of the float mantissa are greater than 0x8000,
// or if they are equal to 0x8000 and the least significant bit of the
// bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
// the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
// has the value 0x7f, then incrementing it causes it to become 0x00 and
// the exponent is incremented by one, which is the next higher FP value
// to the unrounded bfloat16 value. When the bfloat16 value is subnormal
// with an exponent of 0x00 and a mantissa of 0x7f, it may be rounded up
// to a normal value with an exponent of 0x01 and a mantissa of 0x00.
// When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
// incrementing it causes it to become an exponent of 0xFF and a mantissa
// of 0x00, which is Inf, the next higher value to the unrounded value.
bool flag0 = ~u.int32 & 0x7f800000;
// When all of the exponent bits are 1, the value is Inf or NaN.
// Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
// mantissa bit. Quiet NaN is indicated by the most significant mantissa
// bit being 1. Signaling NaN is indicated by the most significant
// mantissa bit being 0 but some other bit(s) being 1. If any of the
// lower 16 bits of the mantissa are 1, we set the least significant bit
// of the bfloat16 mantissa, in order to preserve signaling NaN in case
// the bfloat16's mantissa bits are all 0.
bool flag1 = !flag0 && (u.int32 & 0xffff);
u.int32 += flag0 ? 0x7fff + ((u.int32 >> 16) & 1) : 0; // Round to nearest, round to even
u.int32 |= flag1 ? 0x10000 : 0x0; // Preserve signaling NaN
return uint16_t(u.int32 >> 16);
}
// convert fp16 to bfp16 via fp32 with RTN if higher precision is needed
template <>
inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(half_t x)
{
float x_fp32 = static_cast<float>(x);
return bf16_convert_rtn<bhalf_t>(x_fp32);
}
} // namespace ck } // namespace ck
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include "ck_tile/core/algorithm/coordinate_transform.hpp" #include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/algorithm/indexing_adaptor.hpp" #include "ck_tile/core/algorithm/indexing_adaptor.hpp"
#include "ck_tile/core/algorithm/space_filling_curve.hpp" #include "ck_tile/core/algorithm/space_filling_curve.hpp"
#include "ck_tile/core/algorithm/static_encoding_pattern.hpp"
#include "ck_tile/core/arch/amd_buffer_addressing.hpp" #include "ck_tile/core/arch/amd_buffer_addressing.hpp"
#include "ck_tile/core/arch/arch.hpp" #include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/generic_memory_space_atomic.hpp" #include "ck_tile/core/arch/generic_memory_space_atomic.hpp"
...@@ -53,8 +54,8 @@ ...@@ -53,8 +54,8 @@
#include "ck_tile/core/tensor/tile_window.hpp" #include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/tensor/tile_window_linear.hpp" #include "ck_tile/core/tensor/tile_window_linear.hpp"
#include "ck_tile/core/tensor/tile_window_utils.hpp" #include "ck_tile/core/tensor/tile_window_utils.hpp"
#include "ck_tile/core/tensor/transpose_tile.hpp"
#include "ck_tile/core/tensor/update_tile.hpp" #include "ck_tile/core/tensor/update_tile.hpp"
#include "ck_tile/core/utility/amd_address_space.hpp"
#include "ck_tile/core/utility/bit_cast.hpp" #include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/functional.hpp" #include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/functional_with_tuple.hpp" #include "ck_tile/core/utility/functional_with_tuple.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
namespace ck_tile {
/**
* @brief Enumeration describing static tile distribution patterns.
*
*/
enum struct tile_distribution_pattern
{
/**
* @brief Thread raked pattern.
*
*/
thread_raked,
/**
* @brief Warp raked pattern.
*
*/
warp_raked,
/**
* @brief Block raked pattern - aka linear.
*
*/
block_raked,
};
struct TileDistributionEncodingPattern
{
};
/**
* @brief Class creating 2D static tile distribution with different load/store patterns.
*
* @note We always assume that Tile is YPerTile x XPerTile where X dim (rightmost)
* is contiguous and we can do vector load on this dimension.
*
* @tparam BlockSize Number of threads in a workgroup.
* @tparam YPerTile The tile size of outer/leftmost dimension.
* @tparam XPerTile The tile size of inner/rightmost dimension (contiguous).
* @tparam VecSize The vector access size.
* @tparam DistributionPattern The enumeration describing used access pattern.
*/
template <index_t BlockSize,
index_t YPerTile,
index_t XPerTile,
index_t VecSize,
tile_distribution_pattern DistributionPattern>
struct TileDistributionEncodingPattern2D : public TileDistributionEncodingPattern
{
};
// Thread raked
template <index_t BlockSize, index_t YPerTile, index_t XPerTile, index_t VecSize>
struct TileDistributionEncodingPattern2D<BlockSize,
YPerTile,
XPerTile,
VecSize,
tile_distribution_pattern::thread_raked>
: public TileDistributionEncodingPattern
{
// TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
static constexpr index_t warp_size = get_warp_size();
static constexpr index_t num_warps = BlockSize / get_warp_size();
static constexpr index_t X1 = VecSize;
static constexpr index_t X0 = XPerTile / X1; // # of threads in X dim
// # of rows in Y dim accessed by single wavefront in one iteration
static constexpr index_t Y1 = warp_size / X0;
static_assert(X0 * Y1 == warp_size, "X0 * Y1 must cover whole wavefront!");
static constexpr index_t Y0 = num_warps;
// YPerWarp = YPerTile / Y0;
// Y2 = YPerWarp / Y1;
static constexpr index_t Y2 = YPerTile / (Y1 * Y0); // # of iters within wavefront
static_assert(X0 * Y1 * Y0 == BlockSize, "X0 * warp_ys * Y0 must cover whole workgroup!");
static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover whole YPerTile");
CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
{
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<1, 2>,
sequence<2, 1>>{});
}
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffled2DStaticTileDistribution()
{
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<X0, X1>, sequence<Y0, Y1, Y2>>,
tuple<sequence<2>, sequence<2, 1>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<1, 2>,
sequence<1, 2>>{});
}
};
// Warp raked
template <index_t BlockSize, index_t YPerTile, index_t XPerTile, index_t VecSize>
struct TileDistributionEncodingPattern2D<BlockSize,
YPerTile,
XPerTile,
VecSize,
tile_distribution_pattern::warp_raked>
: public TileDistributionEncodingPattern
{
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
static constexpr index_t warp_size = get_warp_size();
static constexpr index_t num_warps = BlockSize / get_warp_size();
static constexpr index_t X1 = VecSize;
static constexpr index_t X0 = XPerTile / X1; // # of threads in X dim
static constexpr index_t Y2 = warp_size / X0; // # of rows in Y dim to cover whole wavefront
static_assert(X0 * Y2 == warp_size, "X0 * Y2 must cover whole wavefront!");
static constexpr index_t Y0 = num_warps;
static_assert(X0 * Y2 * Y0 == BlockSize, "X0 * Y2 * Y1 must cover whole workgroup!");
static constexpr index_t Y1 = YPerTile / (Y2 * Y0); // # of iters within wavefront
static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover whole YPerTile");
CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
{
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<2, 0>>,
sequence<1, 2>,
sequence<1, 1>>{});
}
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffled2DStaticTileDistribution()
{
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<X0, X1>, sequence<Y0, Y1, Y2>>,
tuple<sequence<2>, sequence<2, 1>>,
tuple<sequence<0>, sequence<2, 0>>,
sequence<1, 2>,
sequence<1, 1>>{});
}
};
// Block raked
template <index_t BlockSize, index_t YPerTile, index_t XPerTile, index_t VecSize>
struct TileDistributionEncodingPattern2D<BlockSize,
YPerTile,
XPerTile,
VecSize,
tile_distribution_pattern::block_raked>
: public TileDistributionEncodingPattern
{
// TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
static constexpr index_t warp_size = get_warp_size();
static constexpr index_t num_warps = BlockSize / get_warp_size();
static constexpr index_t X1 = VecSize;
static constexpr index_t X0 = XPerTile / X1; // # of threads in X dim
static constexpr index_t Y2 = warp_size / X0; // # of rows in Y dim to cover whole wavefront
static_assert(X0 * Y2 == warp_size, "X0 * Y2 must cover whole wavefront!");
static constexpr index_t Y1 = num_warps;
static_assert(X0 * Y2 * Y1 == BlockSize, "X0 * Y2 * Y1 must cover whole workgroup!");
static constexpr index_t Y0 = YPerTile / (Y2 * Y1); // # of iters
static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover whole YPerTile");
CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
{
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffled2DStaticTileDistribution()
{
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<X0, X1>, sequence<Y0, Y1, Y2>>,
tuple<sequence<2>, sequence<2, 1>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<1, 0>>{});
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -12,18 +12,37 @@ ...@@ -12,18 +12,37 @@
namespace ck_tile { namespace ck_tile {
enum struct address_space_enum template <typename, bool>
struct safe_underlying_type;
template <typename T>
struct safe_underlying_type<T, true>
{
using type = std::underlying_type_t<T>;
};
template <typename T>
struct safe_underlying_type<T, false>
{
using type = void;
};
template <typename T>
using safe_underlying_type_t = typename safe_underlying_type<T, std::is_enum<T>::value>::type;
enum struct address_space_enum : std::uint16_t
{ {
generic, generic = 0,
global, global,
lds, lds,
sgpr, sgpr,
vgpr, constant,
vgpr
}; };
enum struct memory_operation_enum enum struct memory_operation_enum : std::uint16_t
{ {
set, set = 0,
atomic_add, atomic_add,
atomic_max, atomic_max,
add add
...@@ -109,4 +128,30 @@ CK_TILE_DEVICE void s_nop(index_t cnt = 0) ...@@ -109,4 +128,30 @@ CK_TILE_DEVICE void s_nop(index_t cnt = 0)
#endif #endif
} }
#define CK_CONSTANT_ADDRESS_SPACE \
__attribute__((address_space( \
static_cast<safe_underlying_type_t<address_space_enum>>(address_space_enum::constant))))
template <typename T>
__device__ T* cast_pointer_to_generic_address_space(T CK_CONSTANT_ADDRESS_SPACE* p)
{
// cast a pointer in "Constant" address space (4) to "Generic" address space (0)
// only c-style pointer cast seems be able to be compiled
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
return (T*)(p); // NOLINT(old-style-cast)
#pragma clang diagnostic pop
}
template <typename T>
__host__ __device__ T CK_CONSTANT_ADDRESS_SPACE* cast_pointer_to_constant_address_space(T* p)
{
// cast a pointer in "Generic" address space (0) to "Constant" address space (4)
// only c-style pointer cast seems be able to be compiled;
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
return (T CK_CONSTANT_ADDRESS_SPACE*)p; // NOLINT(old-style-cast)
#pragma clang diagnostic pop
}
} // namespace ck_tile } // namespace ck_tile
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ #if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
defined(__gfx942__) defined(__gfx942__) || defined(__gfx950__)
#define __gfx9__ #define __gfx9__
#endif #endif
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx950__)
#define __gfx94__ #define __gfx94__
#endif #endif
#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || \ #if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || \
...@@ -230,3 +230,15 @@ ...@@ -230,3 +230,15 @@
#ifndef CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID #ifndef CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
#define CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID 1 #define CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID 1
#endif #endif
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#ifdef CK_TILE_USE_OCP_FP8
#define CK_TILE_USE_OCP_FP8 1
#else
#define CK_TILE_USE_OCP_FP8 0
#endif
#elif defined(__gfx950__) || defined(__gfx12__) // for GPU code
#define CK_TILE_USE_OCP_FP8 1
#else // for GPU code
#define CK_TILE_USE_OCP_FP8 0
#endif
...@@ -546,7 +546,7 @@ CK_TILE_HOST_DEVICE constexpr auto tuple_reverse(const tuple<Ts...>& t) ...@@ -546,7 +546,7 @@ CK_TILE_HOST_DEVICE constexpr auto tuple_reverse(const tuple<Ts...>& t)
using Idx = number<tuple<Ts...>::size() - i - 1>; using Idx = number<tuple<Ts...>::size() - i - 1>;
return t.at(Idx{}); return t.at(Idx{});
}, },
number<tuple<Ts...>::size()()>{}); number<tuple<Ts...>::size()>{});
} }
// Reduce tuple values in specific range using Function // Reduce tuple values in specific range using Function
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment