Commit 667047b9 authored by carlushuang's avatar carlushuang
Browse files

topk-softmax

parent 840cba8e
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
struct fused_moe_traits
{
int hdim_q;
int hdim_v;
std::string data_type;
bool is_group_mode;
bool is_v_rowmajor;
mask_enum mask_type;
bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
bool has_lse;
bool has_dropout;
bool do_fp8_static_quant;
// TODO: padding check is inside this api
};
float fused_moe(fused_moe_traits, fused_moe_args, const ck_tile::stream_config&);
#include "topk_softmax_api.hpp"
float topk_softmax(topk_softmax_trait t, topk_softmax_kargs a, ck_tile::stream_config s)
{
if(t.input_type == "fp16" && t.weight_type == "fp32")
{
using ts_input_type = ck_tile::fp16_t;
using ts_weight_type = float;
using ts_index_type = ck_tile::index_t;
constexpr ck_tile::index_t ts_experts = 8;
using ts_problem = ck_tile::
TopkSoftmaxWarpPerRowProblem<ts_input_type, ts_weight_type, ts_index_type, ts_experts>;
using ts_pipeline = ck_tile::TopkSoftmaxWarpPerRowPipeline<ts_problem>;
using kernel = ck_tile::TopkSoftmaxKernel<ts_pipeline>;
auto kargs = kernel::MakeKargs(a);
const dim3 grids = kernel::GridSize(a);
constexpr dim3 blocks = kernel::BlockSize();
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, 1>(kernel{}, grids, blocks, 0, kargs));
return ave_time;
}
else if(t.input_type == "bf16" && t.weight_type == "fp32")
{
using ts_input_type = ck_tile::bf16_t;
using ts_weight_type = float;
using ts_index_type = ck_tile::index_t;
constexpr ck_tile::index_t ts_experts = 8;
using ts_problem = ck_tile::
TopkSoftmaxWarpPerRowProblem<ts_input_type, ts_weight_type, ts_index_type, ts_experts>;
using ts_pipeline = ck_tile::TopkSoftmaxWarpPerRowPipeline<ts_problem>;
using kernel = ck_tile::TopkSoftmaxKernel<ts_pipeline>;
auto kargs = kernel::MakeKargs(a);
const dim3 grids = kernel::GridSize(a);
constexpr dim3 blocks = kernel::BlockSize();
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, 1>(kernel{}, grids, blocks, 0, kargs));
return ave_time;
}
return -1;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/ops/topk_softmax.hpp"
#include <string>
struct topk_softmax_trait
{
std::string input_type;
std::string weight_type; // currently always float
int experts;
};
struct topk_softmax_kargs : public ck_tile::TopkSoftmaxHostArgs
{
};
float topk_softmax(topk_softmax_trait t, topk_softmax_kargs a, ck_tile::stream_config s);
......@@ -120,4 +120,47 @@ CK_TILE_DEVICE T warp_shuffle_down(const T& v_local, uint32_t lane_delta)
#endif
}
template <typename T>
CK_TILE_DEVICE T warp_shuffle(const T& v_local, uint32_t src_lane)
{
#if 0
return __shfl(v_local, src_lane);
#elif 1
if constexpr(sizeof(int32_t) > sizeof(T))
{
union packet
{
int32_t x;
T v;
};
packet p;
p.v = v_local;
packet p_remote;
p_remote.x = __builtin_amdgcn_ds_bpermute(src_lane << 2, bit_cast<int32_t>(p));
return p_remote.v;
}
else if constexpr(sizeof(int32_t) == sizeof(T))
{
const int32_t v_remote_tmp =
__builtin_amdgcn_ds_bpermute(src_lane << 2, bit_cast<int32_t>(v_local));
return bit_cast<T>(v_remote_tmp);
}
else
{
static_assert(sizeof(T) % sizeof(int32_t) == 0, "wrong!");
constexpr index_t elm = sizeof(T) / sizeof(int32_t);
using vector_type = thread_buffer<int32_t, elm>;
auto vs = bit_cast<vector_type>(v_local);
auto vs_remote = vector_type{};
static_for<0, elm, 1>{}([&](auto i_e) {
int32_t tmp = __builtin_amdgcn_ds_bpermute(src_lane << 2, bit_cast<int32_t>(vs[i_e]));
vs_remote(i_e) = tmp;
});
return bit_cast<T>(vs_remote);
}
#endif
}
} // namespace ck_tile
......@@ -1406,7 +1406,8 @@ CK_TILE_DEVICE T rcp(T x)
#if !CK_TILE_WORKAROUND_SWDEV_383542
return __frcp_rn(x);
#else
return __ocml_native_recip_f32(x);
// return __ocml_native_recip_f32(x);
return __builtin_amdgcn_rcpf(x);
#endif
};
......
......@@ -9,43 +9,81 @@
namespace ck_tile {
template <typename ADataType, typename AccDataType, typename BDataType>
CK_TILE_HOST void reference_softmax(const HostTensor<ADataType>& a_m_n,
HostTensor<BDataType>& b_m_n)
template <typename InputType, typename ComputeType, typename OutputType = ComputeType>
CK_TILE_HOST void
reference_softmax(const HostTensor<InputType>& x, HostTensor<OutputType>& y, index_t dim = -1)
{
auto f = [&](auto m) {
const int N = a_m_n.mDesc.get_lengths()[1];
index_t rank = x.get_num_of_dimension();
assert(rank == y.get_num_of_dimension());
assert(dim == -1 || dim < rank);
AccDataType v_max = ck_tile::numeric<ADataType>::Lowest();
index_t target_dim = dim == -1 ? (rank - 1) : dim;
index_t softmax_len = x.get_length(target_dim);
index_t n_parallel = x.get_element_size() / softmax_len;
auto x_len = x.get_lengths();
// max
for(int n = 0; n < N; ++n)
{
const ADataType v_a = a_m_n(m, n);
auto f = [&](auto i_element) {
std::vector<size_t> coord = [&]() {
std::vector<size_t> t_(rank, 0);
size_t r = i_element;
for(index_t i = rank - 1; i >= 0; i--)
{
if(i == target_dim)
continue;
t_[i] = r % x_len[i];
r = r / x_len[i];
}
return t_;
}();
ComputeType v_max = -ck_tile::numeric<ComputeType>::infinity();
v_max = v_max < v_a ? v_a : v_max;
// compute max
for(auto idx = 0; idx < softmax_len; idx++)
{
auto c_ = coord;
c_[target_dim] = idx;
const ComputeType v_x = ck_tile::type_convert<ComputeType>(x(c_));
v_max = v_max < v_x ? v_x : v_max;
}
AccDataType v_exp_sum = 0;
ComputeType v_exp_sum = static_cast<ComputeType>(0);
// sum
for(int n = 0; n < N; ++n)
for(auto idx = 0; idx < softmax_len; idx++)
{
const ADataType v_a = a_m_n(m, n);
auto c_ = coord;
c_[target_dim] = idx;
v_exp_sum += ck_tile::exp(v_a - v_max);
const ComputeType v_x = ck_tile::type_convert<ComputeType>(x(c_));
v_exp_sum += ck_tile::exp(v_x - v_max);
}
// elementwise
for(int n = 0; n < N; ++n)
for(auto idx = 0; idx < softmax_len; idx++)
{
const ADataType v_a = a_m_n(m, n);
auto c_ = coord;
c_[target_dim] = idx;
const ComputeType v_x = ck_tile::type_convert<ComputeType>(x(c_));
auto out = ck_tile::exp(v_x - v_max) / v_exp_sum;
b_m_n(m, n) = ck_tile::exp(v_a - v_max) / v_exp_sum;
y(c_) = ck_tile::type_convert<OutputType>(out);
}
};
make_ParallelTensorFunctor(f,
b_m_n.mDesc.get_lengths()[0])(std::thread::hardware_concurrency());
make_ParallelTensorFunctor(f, n_parallel)(std::thread::hardware_concurrency());
}
template <typename InputType, typename ComputeType, typename OutputType = ComputeType>
CK_TILE_HOST auto reference_softmax(const HostTensor<InputType>& x, index_t dim = -1)
{
HostTensor<OutputType> y(x.get_lengths(), x.get_strides());
reference_softmax<InputType, ComputeType, OutputType>(x, y, dim);
return y;
}
} // namespace ck_tile
......@@ -100,4 +100,25 @@ CK_TILE_HOST void reference_topk(const HostTensor<DataType>& x,
make_ParallelTensorFunctor(f, n_parallel)(std::thread::hardware_concurrency());
}
// TODO: if using this method, the return tensor would be dense(no stride)
template <typename DataType, typename IndexType = index_t>
CK_TILE_HOST auto reference_topk(const HostTensor<DataType>& x,
index_t k,
index_t dim = -1,
bool largest = true,
bool sorted = true)
{
auto lens = x.get_lengths();
index_t target_dim = (dim == -1) ? (lens.size() - 1) : dim;
assert(target_dim < lens.size());
assert(k <= lens[target_dim]);
lens[target_dim] = k;
HostTensor<DataType> y_values(lens);
HostTensor<IndexType> y_indices(lens);
reference_topk<DataType, IndexType>(x, y_values, y_indices, k, dim, largest, sorted);
return ck_tile::make_tuple(y_values, y_indices);
}
} // namespace ck_tile
......@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include <type_traits>
namespace ck_tile {
namespace element_wise {
......@@ -258,10 +259,10 @@ struct ConvertBF16RTN
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
{
// check Y datatype
static_assert(ck_tile::is_same<Y, ck_tile::bf16_t>::value, "Data type is not supported by this operation!");
static_assert(std::is_same_v<Y, ck_tile::bf16_t>, "Data type is not supported by this operation!");
// check X datatype
static_assert(ck_tile::is_same<X, float>::value || ck_tile::is_same<X, ck_tile::fp16_t>::value,
static_assert(std::is_same_v<X, float> || std::is_same_v<X, ck_tile::fp16_t>,
"Data type is not supported by this operation!");
y = bf16_convert_rtn<Y>(x);
......@@ -275,11 +276,11 @@ struct ConvertF8SR
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
{
// check Y datatype
static_assert(ck_tile::is_same<Y, ck_tile::fp8_t>::value || ck_tile::is_same<Y, ck_tile::bf8_t>::value,
static_assert(std::is_same_v<Y, ck_tile::fp8_t> || std::is_same_v<Y, ck_tile::bf8_t>,
"Data type is not supported by this operation!");
// check X datatype
static_assert(ck_tile::is_same<X, float>::value || ck_tile::is_same<X, ck_tile::fp16_t>::value,
static_assert(std::is_same_v<X, float> || std::is_same_v<X, ck_tile::fp16_t>,
"Data type is not supported by this operation!");
y = f8_convert_sr<Y>(x);
......@@ -293,11 +294,11 @@ struct ConvertF8RNE
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
{
// check Y datatype
static_assert(ck_tile::is_same<Y, ck_tile::fp8_t>::value || ck_tile::is_same<Y, ck_tile::bf8_t>::value,
static_assert(std::is_same_v<Y, ck_tile::fp8_t> || std::is_same_v<Y, ck_tile::bf8_t>,
"Data type is not supported by this operation!");
// check X datatype
static_assert(ck_tile::is_same<X, float>::value || ck_tile::is_same<X, ck_tile::fp16_t>::value,
static_assert(std::is_same_v<X, float> || std::is_same_v<X, ck_tile::fp16_t>,
"Data type is not supported by this operation!");
y = f8_convert_rne<Y>(x);
......@@ -362,7 +363,7 @@ struct ScaleAndResetNaNToMinusInfinity
template <>
CK_TILE_HOST_DEVICE void operator()<float, float>(float& y, const float& x) const
{
y = ck_tile::isnan(x) ? -ck_tile::NumericLimits<float>::Infinity() : scale_ * x;
y = ck_tile::isnan(x) ? -numeric<float>::infinity() : scale_ * x;
};
float scale_;
......@@ -375,8 +376,8 @@ struct UnaryDivide
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value ||
ck_tile::is_same<T, int32_t>::value,
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = x / type_convert<T>(divider_);
......@@ -390,10 +391,11 @@ struct UnarySquare
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(is_same_v<T, float> || is_same_v<T, ck_tile::fp16_t> ||
is_same_v<T, double> || is_same_v<T, int32_t> || is_same_v<T, int8_t>
static_assert(std::is_same_v<T, float> || std::is_same_v<T, ck_tile::fp16_t> ||
std::is_same_v<T, double> || std::is_same_v<T, int32_t> ||
std::is_same_v<T, int8_t>
#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|| is_same_v<T, int4_t>
|| std::is_same_v<T, int4_t>
#endif
,
"Data type is not supported by this operation!");
......@@ -406,9 +408,9 @@ struct UnaryAbs
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value ||
ck_tile::is_same<T, ck_tile::fp16_t>::value ||
ck_tile::is_same<T, int32_t>::value || ck_tile::is_same<T, int8_t>::value,
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int32_t> ||
std::is_same_v<T, int8_t>,
"Data type is not supported by this operation!");
y = ck_tile::abs(x);
......@@ -420,7 +422,7 @@ struct UnarySqrt
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value,
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double>,
"Data type is not supported by this operation!");
y = ck_tile::sqrt(x);
......@@ -432,9 +434,9 @@ struct Relu
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value ||
ck_tile::is_same<T, ck_tile::fp16_t>::value ||
ck_tile::is_same<T, int32_t>::value || ck_tile::is_same<T, int8_t>::value,
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int32_t> ||
std::is_same_v<T, int8_t>,
"Data type is not supported by this operation!");
y = x > 0 ? x : 0;
}
......@@ -597,9 +599,9 @@ struct Sigmoid
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value ||
ck_tile::is_same<T, ck_tile::fp16_t>::value ||
ck_tile::is_same<T, int8_t>::value || ck_tile::is_same<T, int32_t>::value,
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
constexpr T one = type_convert<T>(1);
y = one / (one + ck_tile::exp(-x));
......@@ -611,9 +613,9 @@ struct Silu
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(is_same_v<T, float> || is_same_v<T, double> ||
is_same_v<T, ck_tile::fp16_t> || is_same_v<T, int8_t> ||
is_same_v<T, int32_t>,
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
constexpr T one = type_convert<T>(1);
y = x * (one / (one + ck_tile::exp(-x)));
......@@ -625,9 +627,9 @@ struct TanH
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value ||
ck_tile::is_same<T, ck_tile::fp16_t>::value ||
ck_tile::is_same<T, int8_t>::value || ck_tile::is_same<T, int32_t>::value,
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::tanh(x);
......@@ -639,9 +641,9 @@ struct ACos
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value ||
ck_tile::is_same<T, ck_tile::fp16_t>::value ||
ck_tile::is_same<T, int8_t>::value || ck_tile::is_same<T, int32_t>::value,
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::acos(x);
......@@ -653,9 +655,9 @@ struct Neg
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value ||
ck_tile::is_same<T, ck_tile::fp16_t>::value ||
ck_tile::is_same<T, int8_t>::value || ck_tile::is_same<T, int32_t>::value,
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::neg(x);
......@@ -667,9 +669,9 @@ struct ATan
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value ||
ck_tile::is_same<T, ck_tile::fp16_t>::value ||
ck_tile::is_same<T, int8_t>::value || ck_tile::is_same<T, int32_t>::value,
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::atan(x);
......@@ -681,9 +683,9 @@ struct Sin
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value ||
ck_tile::is_same<T, ck_tile::fp16_t>::value ||
ck_tile::is_same<T, int8_t>::value || ck_tile::is_same<T, int32_t>::value,
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::sin(x);
......@@ -695,9 +697,9 @@ struct ASinH
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value ||
ck_tile::is_same<T, ck_tile::fp16_t>::value ||
ck_tile::is_same<T, int8_t>::value || ck_tile::is_same<T, int32_t>::value,
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::asinh(x);
......@@ -709,9 +711,9 @@ struct Cos
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value ||
ck_tile::is_same<T, ck_tile::fp16_t>::value ||
ck_tile::is_same<T, int8_t>::value || ck_tile::is_same<T, int32_t>::value,
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::cos(x);
......@@ -723,9 +725,9 @@ struct ACosH
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value ||
ck_tile::is_same<T, ck_tile::fp16_t>::value ||
ck_tile::is_same<T, int8_t>::value || ck_tile::is_same<T, int32_t>::value,
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::acosh(x);
......@@ -737,9 +739,9 @@ struct Tan
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value ||
ck_tile::is_same<T, ck_tile::fp16_t>::value ||
ck_tile::is_same<T, int8_t>::value || ck_tile::is_same<T, int32_t>::value,
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::tan(x);
......@@ -751,9 +753,9 @@ struct ATanH
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value ||
ck_tile::is_same<T, ck_tile::fp16_t>::value ||
ck_tile::is_same<T, int8_t>::value || ck_tile::is_same<T, int32_t>::value,
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::atanh(x);
......@@ -765,9 +767,9 @@ struct SinH
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value ||
ck_tile::is_same<T, ck_tile::fp16_t>::value ||
ck_tile::is_same<T, int8_t>::value || ck_tile::is_same<T, int32_t>::value,
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::sinh(x);
......@@ -779,9 +781,9 @@ struct Ceil
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value ||
ck_tile::is_same<T, ck_tile::fp16_t>::value ||
ck_tile::is_same<T, int8_t>::value || ck_tile::is_same<T, int32_t>::value,
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::ceil(x);
......@@ -793,9 +795,9 @@ struct Exp
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value ||
ck_tile::is_same<T, ck_tile::fp16_t>::value ||
ck_tile::is_same<T, int8_t>::value || ck_tile::is_same<T, int32_t>::value,
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::exp(x);
......@@ -807,9 +809,9 @@ struct CosH
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value ||
ck_tile::is_same<T, ck_tile::fp16_t>::value ||
ck_tile::is_same<T, int8_t>::value || ck_tile::is_same<T, int32_t>::value,
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::cosh(x);
......@@ -821,9 +823,9 @@ struct Floor
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value ||
ck_tile::is_same<T, ck_tile::fp16_t>::value ||
ck_tile::is_same<T, int8_t>::value || ck_tile::is_same<T, int32_t>::value,
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::floor(x);
......@@ -835,9 +837,9 @@ struct Log
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value ||
ck_tile::is_same<T, ck_tile::fp16_t>::value ||
ck_tile::is_same<T, int8_t>::value || ck_tile::is_same<T, int32_t>::value,
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::log(x);
......@@ -849,9 +851,9 @@ struct ASin
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value ||
ck_tile::is_same<T, ck_tile::fp16_t>::value ||
ck_tile::is_same<T, int8_t>::value || ck_tile::is_same<T, int32_t>::value,
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::asin(x);
......@@ -863,9 +865,9 @@ struct Rcp
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value ||
ck_tile::is_same<T, ck_tile::fp16_t>::value ||
ck_tile::is_same<T, int8_t>::value || ck_tile::is_same<T, int32_t>::value,
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::rcp(x);
......@@ -879,12 +881,12 @@ struct Swish
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
{
static_assert(ck_tile::is_same<X, float>::value || ck_tile::is_same<X, double>::value ||
ck_tile::is_same<X, ck_tile::fp16_t>::value,
static_assert(std::is_same_v<X, float> || std::is_same_v<X, double> ||
std::is_same_v<X, ck_tile::fp16_t>,
"Data type is not supported by this operation!");
static_assert(ck_tile::is_same<Y, float>::value || ck_tile::is_same<Y, double>::value ||
ck_tile::is_same<Y, ck_tile::fp16_t>::value,
static_assert(std::is_same_v<Y, float> || std::is_same_v<Y, double> ||
std::is_same_v<Y, ck_tile::fp16_t>,
"Data type is not supported by this operation!");
float bx = -beta_ * type_convert<float>(x);
......@@ -901,9 +903,9 @@ struct SoftRelu
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value ||
ck_tile::is_same<T, ck_tile::fp16_t>::value ||
ck_tile::is_same<T, int32_t>::value || ck_tile::is_same<T, int8_t>::value,
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int32_t> ||
std::is_same_v<T, int8_t>,
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_);
constexpr T one = type_convert<T>(1);
......@@ -920,9 +922,9 @@ struct Power
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value ||
ck_tile::is_same<T, ck_tile::fp16_t>::value ||
ck_tile::is_same<T, int32_t>::value || ck_tile::is_same<T, int8_t>::value,
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int32_t> ||
std::is_same_v<T, int8_t>,
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_);
T casted_beta = type_convert<T>(beta_);
......@@ -942,9 +944,9 @@ struct ClippedRelu
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value ||
ck_tile::is_same<T, ck_tile::fp16_t>::value ||
ck_tile::is_same<T, int32_t>::value || ck_tile::is_same<T, int8_t>::value,
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int32_t> ||
std::is_same_v<T, int8_t>,
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_);
T casted_beta = type_convert<T>(beta_);
......@@ -961,9 +963,9 @@ struct LeakyRelu
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value ||
ck_tile::is_same<T, ck_tile::fp16_t>::value ||
ck_tile::is_same<T, int32_t>::value || ck_tile::is_same<T, int8_t>::value,
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int32_t> ||
std::is_same_v<T, int8_t>,
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_);
y = x >= 0 ? x : x * casted_alpha;
......@@ -978,9 +980,9 @@ struct Elu
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value ||
ck_tile::is_same<T, ck_tile::fp16_t>::value ||
ck_tile::is_same<T, int32_t>::value || ck_tile::is_same<T, int8_t>::value,
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int32_t> ||
std::is_same_v<T, int8_t>,
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_);
y = x > 0 ? x : casted_alpha * ck_tile::expm1(x);
......@@ -995,9 +997,9 @@ struct Logistic
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value ||
ck_tile::is_same<T, ck_tile::fp16_t>::value ||
ck_tile::is_same<T, int32_t>::value || ck_tile::is_same<T, int8_t>::value,
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int32_t> ||
std::is_same_v<T, int8_t>,
"Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_);
constexpr T one = type_convert<T>(1);
......@@ -1078,7 +1080,7 @@ struct ConvScaleRelu
};
// support fastconvert of int8 to fp16
#if 0
template <typename InputDataType, typename OutputDataType, index_t RegPackNumber>
struct FastNumericArrayConverter
{
......@@ -1146,6 +1148,6 @@ struct FastNumericArrayConverter<uint8_t, ck_tile::fp16_t, N>
CK_TILE_DEVICE OutputArray operator()(InputArray const& Input) { return convert(Input); }
};
#endif
} // namespace element_wise
} // namespace ck_tile
......@@ -7,6 +7,10 @@
namespace ck_tile {
/*
* TODO: block_tile_reduce_sync() currently has a limitation
* Y dim must have at least one dim not been reduced
*/
// synchronize reduce result (cross lane reduction and broadcast on replicated dimension)
template <typename AccDistributedTensor_, typename ReduceFunc, bool WithBroadcast = true>
CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor,
......@@ -55,7 +59,17 @@ CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor,
// pull data from remote lane
const auto v_remote = warp_shuffle_down(v_local, lid_delta);
#if 0
if constexpr(Verbose_)
{
printf("warp_shuffle_down : %d - %d, %d (%.3f, %.3f)\n",
static_cast<int>(threadIdx.x),
static_cast<int>(lid_over_rid_derivative),
static_cast<int>(lid_delta),
v_local,
v_remote);
}
#endif
// reduce
v_local = reduce_func(v_local, v_remote);
});
......@@ -104,6 +118,76 @@ CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor,
});
}
/*
* this version is faster, using xor to do reduce, no need broadcast anymore
* TODO: the limitation is to-be-reduced P dim can only mapping to one R dim?
*/
template <typename AccDistributedTensor_, typename ReduceFunc>
CK_TILE_DEVICE void block_tile_reduce_xor_sync(AccDistributedTensor_& acc_tensor,
const ReduceFunc& reduce_func)
{
using Dstr = typename AccDistributedTensor_::StaticTileDistribution;
using DstrEncode = typename Dstr::DstrEncode;
using DstrEncodeDetail = typename DstrEncode::detail;
constexpr index_t NDimP = Dstr::get_num_of_dimension_p();
constexpr index_t NDimR = Dstr::get_num_of_dimension_r();
constexpr index_t idim_p_lane = NDimP - 1;
constexpr index_t thread_buf_size = AccDistributedTensor_::get_thread_buffer_size();
// loop over thread data
static_for<0, thread_buf_size, 1>{}([&](auto i) {
auto v_local = acc_tensor.get_thread_buffer()[i];
// cross-lane reduce for replication
// only reduce on R dimension correspond to lane
// (lane id maps to this R dimension)
static_for<0, NDimR, 1>{}([&](auto idim_r) {
// FIXME: nasty to use does_p_own_r_
if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r])
{
constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];
constexpr index_t lid_over_rid_derivative =
DstrEncodeDetail::ps_over_rs_derivative_[idim_p_lane][idim_r];
static_assert(is_power_of_two_integer(r_length),
"wrong! only support power of 2 reduction");
constexpr index_t nstage = integer_log2_floor(r_length);
// reduction sweep forward
static_for<0, nstage, 1>{}([&](auto istage) {
// TODO: lid_over_rid_derivative not ok in xor? maybe need limit the usage of
// xor
index_t src_lane = (__lane_id() * lid_over_rid_derivative) ^
(number<1 << istage.value>{}.value);
// pull data from remote lane
const auto v_remote = warp_shuffle(v_local, src_lane);
#if 0
if constexpr(Verbose_)
{
printf("block_tile_reduce_xor_sync : %d - %d, %d (%.3f, %.3f)\n",
static_cast<int>(threadIdx.x),
static_cast<int>(istage),
static_cast<int>(src_lane),
v_local,
v_remote);
}
#endif
// reduce
v_local = reduce_func(v_local, v_remote);
});
}
});
acc_tensor.get_thread_buffer()(i) = v_local;
});
}
// FIXME: this is for 2D to 1D reduce only, need to support n-D
template <typename AccDistributedTensor_,
typename InDistributedTensor_,
......@@ -175,6 +259,10 @@ CK_TILE_DEVICE void block_tile_reduce(AccDistributedTensor_& acc_tensor,
#endif
}
/*
* TODO: block_tile_reduce() currently has a limitation
* Y dim must have at least one dim not been reduced
*/
template <typename AccDataType_,
typename InDistributedTensor_,
index_t... InReduceDims,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/softmax/block/block_softmax_2d.hpp"
#include "ck_tile/ops/softmax/block/block_softmax_2d_problem.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/reduce.hpp"
namespace ck_tile {
/*
simple 2d softmax implementation, along row (dim=1)
requirement:
1). each row is within a warp
2). data type must be a dword
*/
template <typename Problem_, typename Policy_ = void>
struct BlockSoftmax2D
{
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using DataType = typename Problem::DataType;
template <typename DistributedTensor, index_t dim = 1>
CK_TILE_DEVICE void
operator()(const DistributedTensor& x, DistributedTensor& y, number<dim> = {})
{
const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
// compute row max
auto row_max =
block_tile_reduce<DataType>(x, sequence<dim>{}, f_max, -numeric<DataType>::infinity());
block_tile_reduce_xor_sync(row_max, f_max);
// compute elementwise softmax
constexpr auto span_2d = DistributedTensor::get_distributed_spans();
sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
y(i_j_idx) = exp(x[i_j_idx] - row_max(i_idx));
});
});
// compute row sum
auto row_sum = block_tile_reduce<DataType>(y, sequence<dim>{}, f_sum, DataType{0});
block_tile_reduce_xor_sync(row_sum, f_sum);
// reciprocal
auto r = make_static_distributed_tensor<DataType>(row_sum.get_tile_distribution());
constexpr auto span_1d = decltype(r)::get_distributed_spans();
sweep_tile_span(span_1d[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
r(i_idx) = DataType{1} / row_sum(i_idx);
});
// scale
sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
y(i_j_idx) = y(i_j_idx) * r(i_idx);
});
});
}
template <typename DistributedTensor, index_t dim = 1>
CK_TILE_DEVICE decltype(auto) operator()(const DistributedTensor& x, number<dim> = {})
{
auto y = DistributedTensor{}; // distributed tensor
operator()(x, y, number<dim>{});
return y;
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
template <typename DataType_>
struct BlockSoftmax2DProblem
{
using DataType = remove_cvref_t<DataType_>;
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/topk/block/block_topk_stream_2d.hpp"
#include "ck_tile/ops/topk/block/block_topk_stream_2d_problem.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
/*
simple 2d topk implementation, along row (dim=1)
requirement:
1). each row is within a warp
*/
template <typename Problem_, typename Policy_ = void>
struct BlockTopkStream2D
{
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using DataType = typename Problem::DataType;
using IndexType = typename Problem::IndexType;
// TODO: if DataType is subdword, need pack into single dword to use argmax
struct ArgmaxPacket
{
DataType arg;
index_t value;
};
template <typename DistributedTensor, typename OutWindow, typename IdxWindow, index_t dim = 1>
CK_TILE_DEVICE void operator()(const DistributedTensor& x,
OutWindow& out_window,
IdxWindow& idx_window,
index_t k,
number<dim> = {})
{
// static_assert(OutWindow::get_window_lengths()[number<1>] == 1);
static_assert(
std::is_same_v<typename DistributedTensor::DataType, typename OutWindow::DataType> &&
std::is_same_v<typename DistributedTensor::DataType, DataType>);
static_assert(std::is_same_v<typename IdxWindow::DataType, IndexType>);
DistributedTensor x_tmp = x;
constexpr auto dst_dist = typename IdxWindow::TileDstr{};
// argmax for topk
const auto f_argmax = [](ArgmaxPacket e0, ArgmaxPacket e1) {
return e0.arg > e1.arg ? e0 : e1;
};
for(index_t i_k = 0; i_k < k; i_k++)
{
constexpr auto span_2d = DistributedTensor::get_distributed_spans();
auto packet = [&]() {
auto tmp = make_static_distributed_tensor<ArgmaxPacket>(x.get_tile_distribution());
sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
tmp.get_tile_distribution(), make_tuple(idx0, idx1));
constexpr auto i_j_idx = make_tuple(idx0, idx1);
ArgmaxPacket t;
t.arg = x_tmp(i_j_idx); // !!! we reference x here
t.value = tile_idx.at(number<1>{});
tmp(i_j_idx) = t;
});
});
return tmp;
}();
auto argmax_init = ArgmaxPacket{-numeric<DataType>::infinity(), 0};
auto r = block_tile_reduce<ArgmaxPacket>(packet, sequence<1>{}, f_argmax, argmax_init);
block_tile_reduce_xor_sync(r, f_argmax);
auto o = make_static_distributed_tensor<DataType>(dst_dist);
auto i = make_static_distributed_tensor<IndexType>(dst_dist);
sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
ArgmaxPacket tmp = r(i_j_idx);
o(i_j_idx) = tmp.arg;
i(i_j_idx) = tmp.value;
});
});
// update value
sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
x.get_tile_distribution(), make_tuple(idx0, idx1));
auto col_id = tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
x_tmp(i_j_idx) = (col_id == r(i_j_idx).value) ? -numeric<DataType>::infinity()
: x_tmp(i_j_idx);
});
});
if(threadIdx.x % Problem::ColLanes == 0)
{
store_tile(out_window, o);
store_tile(idx_window, i);
}
move_tile_window(out_window, {number<0>{}, number<1>{}});
move_tile_window(idx_window, {number<0>{}, number<1>{}});
}
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
/*
simple 2d topk implementation, along row (dim=1)
requirement:
1). each row is within a warp
*/
template <typename DataType_, typename IndexType_, index_t ColLanes_>
struct BlockTopkStream2DProblem
{
using DataType = remove_cvref_t<DataType_>;
using IndexType = remove_cvref_t<IndexType_>;
static constexpr index_t ColLanes = ColLanes_;
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp"
#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_pipeline.hpp"
#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp"
#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_problem.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/elementwise.hpp"
#include <string>
#include <type_traits>
namespace ck_tile {
struct TopkSoftmaxHostArgs
{
const void* p_input;
void* p_output;
void* p_indices;
index_t num_rows;
index_t num_experts;
index_t topk;
};
template <typename Pipeline_>
struct TopkSoftmaxKernel
{
using Pipeline = remove_cvref_t<Pipeline_>;
using Problem = remove_cvref_t<typename Pipeline::Problem>;
using InputType = typename Problem::InputType;
using WeightType = typename Problem::WeightType;
using IndexType = typename Problem::IndexType;
struct TopkSoftmaxKargs
{
const void* p_input;
void* p_output;
void* p_indices;
index_t num_rows;
index_t num_experts;
index_t topk;
};
using Kargs = TopkSoftmaxKargs;
using Hargs = TopkSoftmaxHostArgs;
CK_TILE_HOST static constexpr auto GridSize(const Hargs& h)
{
const int num_warps = (h.num_rows + Problem::RowsPerWarp - 1) / Problem::RowsPerWarp;
const int num_blocks = (num_warps + Problem::WarpsPerBlock - 1) / Problem::WarpsPerBlock;
return dim3(num_blocks);
}
CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
{
Kargs k;
k.p_input = h.p_input;
k.p_output = h.p_output;
k.p_indices = h.p_indices;
k.num_rows = h.num_rows;
k.num_experts = h.num_experts;
k.topk = h.topk;
return k;
}
CK_TILE_HOST_DEVICE static constexpr auto BlockSize() { return Problem::BlockSize; }
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
index_t block_row_id = static_cast<index_t>(blockIdx.x * Problem::RowsPerBlock);
const auto input_window = [&]() {
const InputType* p_input = reinterpret_cast<const InputType*>(kargs.p_input) +
blockIdx.x * Problem::RowsPerBlock * kargs.num_experts;
auto tmp = make_naive_tensor_view_packed<address_space_enum::global>(
p_input,
make_tuple(kargs.num_rows, kargs.num_experts),
number<Problem::VectorSize>{});
auto view = pad_tensor_view(
tmp,
make_tuple(number<Problem::RowsPerBlock>{}, number<Problem::Experts>{}),
sequence<1, 1>{});
return make_tile_window(
view,
make_tuple(number<Problem::RowsPerBlock>{}, number<Problem::Experts>{}),
{block_row_id, 0});
}();
auto output_window = [&]() {
WeightType* p_output = reinterpret_cast<WeightType*>(kargs.p_output) +
blockIdx.x * Problem::RowsPerBlock * kargs.topk;
auto tmp = make_naive_tensor_view_packed<address_space_enum::global>(
p_output, make_tuple(kargs.num_rows, kargs.topk), number<Problem::VectorSize>{});
auto view = pad_tensor_view(
tmp, make_tuple(number<Problem::RowsPerBlock>{}, number<1>{}), sequence<1, 0>{});
return make_tile_window(
view, make_tuple(number<Problem::RowsPerBlock>{}, number<1>{}), {block_row_id, 0});
}();
auto indices_window = [&]() {
IndexType* p_indices = reinterpret_cast<IndexType*>(kargs.p_indices) +
blockIdx.x * Problem::RowsPerBlock * kargs.topk;
auto tmp = make_naive_tensor_view_packed<address_space_enum::global>(
p_indices, make_tuple(kargs.num_rows, kargs.topk), number<Problem::VectorSize>{});
auto view = pad_tensor_view(
tmp, make_tuple(number<Problem::RowsPerBlock>{}, number<1>{}), sequence<1, 0>{});
return make_tile_window(
view, make_tuple(number<Problem::RowsPerBlock>{}, number<1>{}), {block_row_id, 0});
}();
Pipeline{}(input_window, output_window, indices_window, kargs.topk);
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp"
#include <string>
#include <type_traits>
namespace ck_tile {
template <typename Problem_, typename Policy_ = TopkSoftmaxWarpPerRowPolicy>
struct TopkSoftmaxWarpPerRowPipeline
{
// TODO: this kernel only support warp per row
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
template <typename InputWindow, typename OutputWindow, typename IndexWindow>
CK_TILE_DEVICE auto operator()(const InputWindow& input_window,
OutputWindow& out_window,
IndexWindow& idx_window,
index_t k)
{
auto input_win = make_tile_window(input_window.get_bottom_tensor_view(),
input_window.get_window_lengths(),
input_window.get_window_origin(),
Policy::template MakeInputDistribution<Problem>());
auto x = load_tile(input_win);
auto w = cast_tile<typename Problem::WeightType>(x);
auto softmax = Policy::template GetSoftmax<Problem>();
// softmax
auto y = softmax(w);
auto topk = Policy::template GetTopk<Problem>();
auto out_win = make_tile_window(out_window.get_bottom_tensor_view(),
out_window.get_window_lengths(),
out_window.get_window_origin(),
Policy::template MakeOutputDistribution<Problem>());
auto idx_win = make_tile_window(idx_window.get_bottom_tensor_view(),
idx_window.get_window_lengths(),
idx_window.get_window_origin(),
Policy::template MakeOutputDistribution<Problem>());
topk(y, out_win, idx_win, k);
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/softmax.hpp"
#include "ck_tile/ops/topk.hpp"
namespace ck_tile {
struct TopkSoftmaxWarpPerRowPolicy
{
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeInputDistribution()
{
// TODO: Y dim must have one dim that is not reduced
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<1>,
tuple<sequence<Problem::IssuesPerCol, Problem::WarpsPerBlock, Problem::RowsPerWarp>,
sequence<Problem::IssuesPerRow, Problem::LanesPerRow, Problem::VectorSize>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 1>>,
sequence<1, 2, 2>,
sequence<0, 0, 2>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeOutputDistribution()
{
return make_static_tile_distribution(
tile_distribution_encoding<sequence<Problem::LanesPerRow>, // repeat this one
tuple<sequence<Problem::WarpsPerBlock, Problem::RowsPerWarp>,
sequence<1>>, // each row write out single element
tuple<sequence<1>, sequence<1, 0>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<2>,
sequence<0>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSoftmax()
{
using softmax_problem = BlockSoftmax2DProblem<typename Problem::WeightType>;
return BlockSoftmax2D<softmax_problem>{};
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetTopk()
{
using topk_problem = BlockTopkStream2DProblem<typename Problem::WeightType,
typename Problem::IndexType,
Problem::LanesPerRow>;
// Note: replicate is LanesPerRow
return BlockTopkStream2D<topk_problem>{};
}
};
} // namespace ck_tile
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment