Commit 667047b9 authored by carlushuang's avatar carlushuang
Browse files

topk-softmax

parent 840cba8e
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
// Runtime trait bundle used to select a fused-MoE kernel instantiation.
// NOTE(review): the field set (hdim_q/hdim_v, mask/bias, lse, dropout) mirrors
// the attention (fmha) traits rather than typical MoE parameters — confirm
// these are the intended knobs for the fused_moe API.
struct fused_moe_traits
{
int hdim_q; // presumably Q head dimension (fmha-style) — TODO confirm
int hdim_v; // presumably V head dimension (fmha-style) — TODO confirm
std::string data_type; // element data type name — TODO confirm accepted values
bool is_group_mode;
bool is_v_rowmajor;
mask_enum mask_type;
bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
bool has_lse;
bool has_dropout;
bool do_fp8_static_quant;
// TODO: padding check is inside this api
};
float fused_moe(fused_moe_traits, fused_moe_args, const ck_tile::stream_config&);
#include "topk_softmax_api.hpp"
float topk_softmax(topk_softmax_trait t, topk_softmax_kargs a, ck_tile::stream_config s)
{
if(t.input_type == "fp16" && t.weight_type == "fp32")
{
using ts_input_type = ck_tile::fp16_t;
using ts_weight_type = float;
using ts_index_type = ck_tile::index_t;
constexpr ck_tile::index_t ts_experts = 8;
using ts_problem = ck_tile::
TopkSoftmaxWarpPerRowProblem<ts_input_type, ts_weight_type, ts_index_type, ts_experts>;
using ts_pipeline = ck_tile::TopkSoftmaxWarpPerRowPipeline<ts_problem>;
using kernel = ck_tile::TopkSoftmaxKernel<ts_pipeline>;
auto kargs = kernel::MakeKargs(a);
const dim3 grids = kernel::GridSize(a);
constexpr dim3 blocks = kernel::BlockSize();
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, 1>(kernel{}, grids, blocks, 0, kargs));
return ave_time;
}
else if(t.input_type == "bf16" && t.weight_type == "fp32")
{
using ts_input_type = ck_tile::bf16_t;
using ts_weight_type = float;
using ts_index_type = ck_tile::index_t;
constexpr ck_tile::index_t ts_experts = 8;
using ts_problem = ck_tile::
TopkSoftmaxWarpPerRowProblem<ts_input_type, ts_weight_type, ts_index_type, ts_experts>;
using ts_pipeline = ck_tile::TopkSoftmaxWarpPerRowPipeline<ts_problem>;
using kernel = ck_tile::TopkSoftmaxKernel<ts_pipeline>;
auto kargs = kernel::MakeKargs(a);
const dim3 grids = kernel::GridSize(a);
constexpr dim3 blocks = kernel::BlockSize();
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, 1>(kernel{}, grids, blocks, 0, kargs));
return ave_time;
}
return -1;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/ops/topk_softmax.hpp"
#include <string>
// Host-side trait describing which topk-softmax kernel instantiation to run.
struct topk_softmax_trait
{
std::string input_type; // input element type name; "fp16" and "bf16" are dispatched by the launcher
std::string weight_type; // currently always float
int experts; // number of experts; NOTE(review): the launcher hard-codes 8 and ignores this — confirm
};
// Kernel-argument wrapper; currently adds nothing beyond the ck_tile host args.
struct topk_softmax_kargs : public ck_tile::TopkSoftmaxHostArgs
{
};
// Launch topk-softmax for the dtypes named in the trait; returns the timing
// value from launch_kernel, or -1 when the dtype combination is unsupported.
float topk_softmax(topk_softmax_trait t, topk_softmax_kargs a, ck_tile::stream_config s);
...@@ -120,4 +120,47 @@ CK_TILE_DEVICE T warp_shuffle_down(const T& v_local, uint32_t lane_delta) ...@@ -120,4 +120,47 @@ CK_TILE_DEVICE T warp_shuffle_down(const T& v_local, uint32_t lane_delta)
#endif #endif
} }
// Cross-lane gather: each lane receives v_local from the lane indexed by
// src_lane, implemented with ds_bpermute (which addresses lanes in bytes,
// hence the << 2). Handles sub-dword, dword, and multi-dword T.
template <typename T>
CK_TILE_DEVICE T warp_shuffle(const T& v_local, uint32_t src_lane)
{
#if 0
    return __shfl(v_local, src_lane);
#elif 1
    if constexpr(sizeof(int32_t) > sizeof(T))
    {
        // Sub-dword types: widen to 32 bits for the permute.
        union packet
        {
            int32_t x;
            T v;
        };
        packet p;
        // Zero-fill first: writing only p.v would leave the bytes beyond
        // sizeof(T) indeterminate, and bit_cast below reads all 4 bytes.
        p.x = 0;
        p.v = v_local;
        packet p_remote;
        p_remote.x = __builtin_amdgcn_ds_bpermute(src_lane << 2, bit_cast<int32_t>(p));
        return p_remote.v;
    }
    else if constexpr(sizeof(int32_t) == sizeof(T))
    {
        // Exactly one dword: permute it directly.
        const int32_t v_remote_tmp =
            __builtin_amdgcn_ds_bpermute(src_lane << 2, bit_cast<int32_t>(v_local));
        return bit_cast<T>(v_remote_tmp);
    }
    else
    {
        // Wider types: permute dword-by-dword.
        static_assert(sizeof(T) % sizeof(int32_t) == 0, "wrong!");
        constexpr index_t elm = sizeof(T) / sizeof(int32_t);
        using vector_type     = thread_buffer<int32_t, elm>;
        auto vs        = bit_cast<vector_type>(v_local);
        auto vs_remote = vector_type{};
        static_for<0, elm, 1>{}([&](auto i_e) {
            int32_t tmp = __builtin_amdgcn_ds_bpermute(src_lane << 2, bit_cast<int32_t>(vs[i_e]));
            vs_remote(i_e) = tmp;
        });
        return bit_cast<T>(vs_remote);
    }
#endif
}
} // namespace ck_tile } // namespace ck_tile
...@@ -1406,7 +1406,8 @@ CK_TILE_DEVICE T rcp(T x) ...@@ -1406,7 +1406,8 @@ CK_TILE_DEVICE T rcp(T x)
#if !CK_TILE_WORKAROUND_SWDEV_383542 #if !CK_TILE_WORKAROUND_SWDEV_383542
return __frcp_rn(x); return __frcp_rn(x);
#else #else
return __ocml_native_recip_f32(x); // return __ocml_native_recip_f32(x);
return __builtin_amdgcn_rcpf(x);
#endif #endif
}; };
......
...@@ -9,43 +9,81 @@ ...@@ -9,43 +9,81 @@
namespace ck_tile { namespace ck_tile {
template <typename ADataType, typename AccDataType, typename BDataType> template <typename InputType, typename ComputeType, typename OutputType = ComputeType>
CK_TILE_HOST void reference_softmax(const HostTensor<ADataType>& a_m_n, CK_TILE_HOST void
HostTensor<BDataType>& b_m_n) reference_softmax(const HostTensor<InputType>& x, HostTensor<OutputType>& y, index_t dim = -1)
{ {
auto f = [&](auto m) { index_t rank = x.get_num_of_dimension();
const int N = a_m_n.mDesc.get_lengths()[1]; assert(rank == y.get_num_of_dimension());
assert(dim == -1 || dim < rank);
AccDataType v_max = ck_tile::numeric<ADataType>::Lowest(); index_t target_dim = dim == -1 ? (rank - 1) : dim;
index_t softmax_len = x.get_length(target_dim);
index_t n_parallel = x.get_element_size() / softmax_len;
auto x_len = x.get_lengths();
// max auto f = [&](auto i_element) {
for(int n = 0; n < N; ++n) std::vector<size_t> coord = [&]() {
{ std::vector<size_t> t_(rank, 0);
const ADataType v_a = a_m_n(m, n); size_t r = i_element;
for(index_t i = rank - 1; i >= 0; i--)
{
if(i == target_dim)
continue;
t_[i] = r % x_len[i];
r = r / x_len[i];
}
return t_;
}();
ComputeType v_max = -ck_tile::numeric<ComputeType>::infinity();
v_max = v_max < v_a ? v_a : v_max; // compute max
for(auto idx = 0; idx < softmax_len; idx++)
{
auto c_ = coord;
c_[target_dim] = idx;
const ComputeType v_x = ck_tile::type_convert<ComputeType>(x(c_));
v_max = v_max < v_x ? v_x : v_max;
} }
AccDataType v_exp_sum = 0; ComputeType v_exp_sum = static_cast<ComputeType>(0);
// sum // sum
for(int n = 0; n < N; ++n) for(auto idx = 0; idx < softmax_len; idx++)
{ {
const ADataType v_a = a_m_n(m, n); auto c_ = coord;
c_[target_dim] = idx;
v_exp_sum += ck_tile::exp(v_a - v_max); const ComputeType v_x = ck_tile::type_convert<ComputeType>(x(c_));
v_exp_sum += ck_tile::exp(v_x - v_max);
} }
// elementwise // elementwise
for(int n = 0; n < N; ++n) for(auto idx = 0; idx < softmax_len; idx++)
{ {
const ADataType v_a = a_m_n(m, n); auto c_ = coord;
c_[target_dim] = idx;
const ComputeType v_x = ck_tile::type_convert<ComputeType>(x(c_));
auto out = ck_tile::exp(v_x - v_max) / v_exp_sum;
b_m_n(m, n) = ck_tile::exp(v_a - v_max) / v_exp_sum; y(c_) = ck_tile::type_convert<OutputType>(out);
} }
}; };
make_ParallelTensorFunctor(f, make_ParallelTensorFunctor(f, n_parallel)(std::thread::hardware_concurrency());
b_m_n.mDesc.get_lengths()[0])(std::thread::hardware_concurrency()); }
template <typename InputType, typename ComputeType, typename OutputType = ComputeType>
CK_TILE_HOST auto reference_softmax(const HostTensor<InputType>& x, index_t dim = -1)
{
HostTensor<OutputType> y(x.get_lengths(), x.get_strides());
reference_softmax<InputType, ComputeType, OutputType>(x, y, dim);
return y;
} }
} // namespace ck_tile } // namespace ck_tile
...@@ -100,4 +100,25 @@ CK_TILE_HOST void reference_topk(const HostTensor<DataType>& x, ...@@ -100,4 +100,25 @@ CK_TILE_HOST void reference_topk(const HostTensor<DataType>& x,
make_ParallelTensorFunctor(f, n_parallel)(std::thread::hardware_concurrency()); make_ParallelTensorFunctor(f, n_parallel)(std::thread::hardware_concurrency());
} }
// Convenience overload that allocates and returns (values, indices) as a
// tuple instead of filling caller-provided tensors.
// TODO: if using this method, the return tensor would be dense(no stride)
template <typename DataType, typename IndexType = index_t>
CK_TILE_HOST auto reference_topk(const HostTensor<DataType>& x,
                                 index_t k,
                                 index_t dim = -1,
                                 bool largest = true,
                                 bool sorted = true)
{
    auto out_lens = x.get_lengths();
    // -1 selects the innermost dimension, matching the out-param overload.
    const index_t reduce_dim = (dim == -1) ? (out_lens.size() - 1) : dim;
    assert(reduce_dim < out_lens.size());
    assert(k <= out_lens[reduce_dim]);

    // Output shrinks to k along the reduced dimension.
    out_lens[reduce_dim] = k;
    HostTensor<DataType> values(out_lens);
    HostTensor<IndexType> indices(out_lens);

    reference_topk<DataType, IndexType>(x, values, indices, k, dim, largest, sorted);
    return ck_tile::make_tuple(values, indices);
}
} // namespace ck_tile } // namespace ck_tile
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#pragma once #pragma once
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include <type_traits>
namespace ck_tile { namespace ck_tile {
namespace element_wise { namespace element_wise {
...@@ -258,10 +259,10 @@ struct ConvertBF16RTN ...@@ -258,10 +259,10 @@ struct ConvertBF16RTN
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
{ {
// check Y datatype // check Y datatype
static_assert(ck_tile::is_same<Y, ck_tile::bf16_t>::value, "Data type is not supported by this operation!"); static_assert(std::is_same_v<Y, ck_tile::bf16_t>, "Data type is not supported by this operation!");
// check X datatype // check X datatype
static_assert(ck_tile::is_same<X, float>::value || ck_tile::is_same<X, ck_tile::fp16_t>::value, static_assert(std::is_same_v<X, float> || std::is_same_v<X, ck_tile::fp16_t>,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = bf16_convert_rtn<Y>(x); y = bf16_convert_rtn<Y>(x);
...@@ -275,11 +276,11 @@ struct ConvertF8SR ...@@ -275,11 +276,11 @@ struct ConvertF8SR
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
{ {
// check Y datatype // check Y datatype
static_assert(ck_tile::is_same<Y, ck_tile::fp8_t>::value || ck_tile::is_same<Y, ck_tile::bf8_t>::value, static_assert(std::is_same_v<Y, ck_tile::fp8_t> || std::is_same_v<Y, ck_tile::bf8_t>,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
// check X datatype // check X datatype
static_assert(ck_tile::is_same<X, float>::value || ck_tile::is_same<X, ck_tile::fp16_t>::value, static_assert(std::is_same_v<X, float> || std::is_same_v<X, ck_tile::fp16_t>,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = f8_convert_sr<Y>(x); y = f8_convert_sr<Y>(x);
...@@ -293,11 +294,11 @@ struct ConvertF8RNE ...@@ -293,11 +294,11 @@ struct ConvertF8RNE
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
{ {
// check Y datatype // check Y datatype
static_assert(ck_tile::is_same<Y, ck_tile::fp8_t>::value || ck_tile::is_same<Y, ck_tile::bf8_t>::value, static_assert(std::is_same_v<Y, ck_tile::fp8_t> || std::is_same_v<Y, ck_tile::bf8_t>,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
// check X datatype // check X datatype
static_assert(ck_tile::is_same<X, float>::value || ck_tile::is_same<X, ck_tile::fp16_t>::value, static_assert(std::is_same_v<X, float> || std::is_same_v<X, ck_tile::fp16_t>,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = f8_convert_rne<Y>(x); y = f8_convert_rne<Y>(x);
...@@ -362,7 +363,7 @@ struct ScaleAndResetNaNToMinusInfinity ...@@ -362,7 +363,7 @@ struct ScaleAndResetNaNToMinusInfinity
template <> template <>
CK_TILE_HOST_DEVICE void operator()<float, float>(float& y, const float& x) const CK_TILE_HOST_DEVICE void operator()<float, float>(float& y, const float& x) const
{ {
y = ck_tile::isnan(x) ? -ck_tile::NumericLimits<float>::Infinity() : scale_ * x; y = ck_tile::isnan(x) ? -numeric<float>::infinity() : scale_ * x;
}; };
float scale_; float scale_;
...@@ -375,8 +376,8 @@ struct UnaryDivide ...@@ -375,8 +376,8 @@ struct UnaryDivide
template <typename T> template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{ {
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value || static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
ck_tile::is_same<T, int32_t>::value, std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = x / type_convert<T>(divider_); y = x / type_convert<T>(divider_);
...@@ -390,10 +391,11 @@ struct UnarySquare ...@@ -390,10 +391,11 @@ struct UnarySquare
template <typename T> template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{ {
static_assert(is_same_v<T, float> || is_same_v<T, ck_tile::fp16_t> || static_assert(std::is_same_v<T, float> || std::is_same_v<T, ck_tile::fp16_t> ||
is_same_v<T, double> || is_same_v<T, int32_t> || is_same_v<T, int8_t> std::is_same_v<T, double> || std::is_same_v<T, int32_t> ||
std::is_same_v<T, int8_t>
#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 #ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|| is_same_v<T, int4_t> || std::is_same_v<T, int4_t>
#endif #endif
, ,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
...@@ -406,9 +408,9 @@ struct UnaryAbs ...@@ -406,9 +408,9 @@ struct UnaryAbs
template <typename T> template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{ {
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value || static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
ck_tile::is_same<T, ck_tile::fp16_t>::value || std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int32_t> ||
ck_tile::is_same<T, int32_t>::value || ck_tile::is_same<T, int8_t>::value, std::is_same_v<T, int8_t>,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = ck_tile::abs(x); y = ck_tile::abs(x);
...@@ -420,7 +422,7 @@ struct UnarySqrt ...@@ -420,7 +422,7 @@ struct UnarySqrt
template <typename T> template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{ {
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value, static_assert(std::is_same_v<T, float> || std::is_same_v<T, double>,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = ck_tile::sqrt(x); y = ck_tile::sqrt(x);
...@@ -432,9 +434,9 @@ struct Relu ...@@ -432,9 +434,9 @@ struct Relu
template <typename T> template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{ {
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value || static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
ck_tile::is_same<T, ck_tile::fp16_t>::value || std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int32_t> ||
ck_tile::is_same<T, int32_t>::value || ck_tile::is_same<T, int8_t>::value, std::is_same_v<T, int8_t>,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = x > 0 ? x : 0; y = x > 0 ? x : 0;
} }
...@@ -597,9 +599,9 @@ struct Sigmoid ...@@ -597,9 +599,9 @@ struct Sigmoid
template <typename T> template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{ {
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value || static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
ck_tile::is_same<T, ck_tile::fp16_t>::value || std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
ck_tile::is_same<T, int8_t>::value || ck_tile::is_same<T, int32_t>::value, std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
constexpr T one = type_convert<T>(1); constexpr T one = type_convert<T>(1);
y = one / (one + ck_tile::exp(-x)); y = one / (one + ck_tile::exp(-x));
...@@ -611,9 +613,9 @@ struct Silu ...@@ -611,9 +613,9 @@ struct Silu
template <typename T> template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{ {
static_assert(is_same_v<T, float> || is_same_v<T, double> || static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
is_same_v<T, ck_tile::fp16_t> || is_same_v<T, int8_t> || std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
is_same_v<T, int32_t>, std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
constexpr T one = type_convert<T>(1); constexpr T one = type_convert<T>(1);
y = x * (one / (one + ck_tile::exp(-x))); y = x * (one / (one + ck_tile::exp(-x)));
...@@ -625,9 +627,9 @@ struct TanH ...@@ -625,9 +627,9 @@ struct TanH
template <typename T> template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{ {
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value || static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
ck_tile::is_same<T, ck_tile::fp16_t>::value || std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
ck_tile::is_same<T, int8_t>::value || ck_tile::is_same<T, int32_t>::value, std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = ck_tile::tanh(x); y = ck_tile::tanh(x);
...@@ -639,9 +641,9 @@ struct ACos ...@@ -639,9 +641,9 @@ struct ACos
template <typename T> template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{ {
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value || static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
ck_tile::is_same<T, ck_tile::fp16_t>::value || std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
ck_tile::is_same<T, int8_t>::value || ck_tile::is_same<T, int32_t>::value, std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = ck_tile::acos(x); y = ck_tile::acos(x);
...@@ -653,9 +655,9 @@ struct Neg ...@@ -653,9 +655,9 @@ struct Neg
template <typename T> template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{ {
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value || static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
ck_tile::is_same<T, ck_tile::fp16_t>::value || std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
ck_tile::is_same<T, int8_t>::value || ck_tile::is_same<T, int32_t>::value, std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = ck_tile::neg(x); y = ck_tile::neg(x);
...@@ -667,9 +669,9 @@ struct ATan ...@@ -667,9 +669,9 @@ struct ATan
template <typename T> template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{ {
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value || static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
ck_tile::is_same<T, ck_tile::fp16_t>::value || std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
ck_tile::is_same<T, int8_t>::value || ck_tile::is_same<T, int32_t>::value, std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = ck_tile::atan(x); y = ck_tile::atan(x);
...@@ -681,9 +683,9 @@ struct Sin ...@@ -681,9 +683,9 @@ struct Sin
template <typename T> template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{ {
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value || static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
ck_tile::is_same<T, ck_tile::fp16_t>::value || std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
ck_tile::is_same<T, int8_t>::value || ck_tile::is_same<T, int32_t>::value, std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = ck_tile::sin(x); y = ck_tile::sin(x);
...@@ -695,9 +697,9 @@ struct ASinH ...@@ -695,9 +697,9 @@ struct ASinH
template <typename T> template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{ {
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value || static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
ck_tile::is_same<T, ck_tile::fp16_t>::value || std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
ck_tile::is_same<T, int8_t>::value || ck_tile::is_same<T, int32_t>::value, std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = ck_tile::asinh(x); y = ck_tile::asinh(x);
...@@ -709,9 +711,9 @@ struct Cos ...@@ -709,9 +711,9 @@ struct Cos
template <typename T> template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{ {
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value || static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
ck_tile::is_same<T, ck_tile::fp16_t>::value || std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
ck_tile::is_same<T, int8_t>::value || ck_tile::is_same<T, int32_t>::value, std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = ck_tile::cos(x); y = ck_tile::cos(x);
...@@ -723,9 +725,9 @@ struct ACosH ...@@ -723,9 +725,9 @@ struct ACosH
template <typename T> template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{ {
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value || static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
ck_tile::is_same<T, ck_tile::fp16_t>::value || std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
ck_tile::is_same<T, int8_t>::value || ck_tile::is_same<T, int32_t>::value, std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = ck_tile::acosh(x); y = ck_tile::acosh(x);
...@@ -737,9 +739,9 @@ struct Tan ...@@ -737,9 +739,9 @@ struct Tan
template <typename T> template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{ {
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value || static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
ck_tile::is_same<T, ck_tile::fp16_t>::value || std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
ck_tile::is_same<T, int8_t>::value || ck_tile::is_same<T, int32_t>::value, std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = ck_tile::tan(x); y = ck_tile::tan(x);
...@@ -751,9 +753,9 @@ struct ATanH ...@@ -751,9 +753,9 @@ struct ATanH
template <typename T> template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{ {
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value || static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
ck_tile::is_same<T, ck_tile::fp16_t>::value || std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
ck_tile::is_same<T, int8_t>::value || ck_tile::is_same<T, int32_t>::value, std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = ck_tile::atanh(x); y = ck_tile::atanh(x);
...@@ -765,9 +767,9 @@ struct SinH ...@@ -765,9 +767,9 @@ struct SinH
template <typename T> template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{ {
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value || static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
ck_tile::is_same<T, ck_tile::fp16_t>::value || std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
ck_tile::is_same<T, int8_t>::value || ck_tile::is_same<T, int32_t>::value, std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = ck_tile::sinh(x); y = ck_tile::sinh(x);
...@@ -779,9 +781,9 @@ struct Ceil ...@@ -779,9 +781,9 @@ struct Ceil
template <typename T> template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{ {
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value || static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
ck_tile::is_same<T, ck_tile::fp16_t>::value || std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
ck_tile::is_same<T, int8_t>::value || ck_tile::is_same<T, int32_t>::value, std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = ck_tile::ceil(x); y = ck_tile::ceil(x);
...@@ -793,9 +795,9 @@ struct Exp ...@@ -793,9 +795,9 @@ struct Exp
template <typename T> template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{ {
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value || static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
ck_tile::is_same<T, ck_tile::fp16_t>::value || std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
ck_tile::is_same<T, int8_t>::value || ck_tile::is_same<T, int32_t>::value, std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = ck_tile::exp(x); y = ck_tile::exp(x);
...@@ -807,9 +809,9 @@ struct CosH ...@@ -807,9 +809,9 @@ struct CosH
template <typename T> template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{ {
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value || static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
ck_tile::is_same<T, ck_tile::fp16_t>::value || std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
ck_tile::is_same<T, int8_t>::value || ck_tile::is_same<T, int32_t>::value, std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = ck_tile::cosh(x); y = ck_tile::cosh(x);
...@@ -821,9 +823,9 @@ struct Floor ...@@ -821,9 +823,9 @@ struct Floor
template <typename T> template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{ {
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value || static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
ck_tile::is_same<T, ck_tile::fp16_t>::value || std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
ck_tile::is_same<T, int8_t>::value || ck_tile::is_same<T, int32_t>::value, std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = ck_tile::floor(x); y = ck_tile::floor(x);
...@@ -835,9 +837,9 @@ struct Log ...@@ -835,9 +837,9 @@ struct Log
template <typename T> template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{ {
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value || static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
ck_tile::is_same<T, ck_tile::fp16_t>::value || std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
ck_tile::is_same<T, int8_t>::value || ck_tile::is_same<T, int32_t>::value, std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = ck_tile::log(x); y = ck_tile::log(x);
...@@ -849,9 +851,9 @@ struct ASin ...@@ -849,9 +851,9 @@ struct ASin
template <typename T> template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{ {
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value || static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
ck_tile::is_same<T, ck_tile::fp16_t>::value || std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
ck_tile::is_same<T, int8_t>::value || ck_tile::is_same<T, int32_t>::value, std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = ck_tile::asin(x); y = ck_tile::asin(x);
...@@ -863,9 +865,9 @@ struct Rcp ...@@ -863,9 +865,9 @@ struct Rcp
template <typename T> template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{ {
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value || static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
ck_tile::is_same<T, ck_tile::fp16_t>::value || std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
ck_tile::is_same<T, int8_t>::value || ck_tile::is_same<T, int32_t>::value, std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = ck_tile::rcp(x); y = ck_tile::rcp(x);
...@@ -879,12 +881,12 @@ struct Swish ...@@ -879,12 +881,12 @@ struct Swish
template <typename Y, typename X> template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
{ {
static_assert(ck_tile::is_same<X, float>::value || ck_tile::is_same<X, double>::value || static_assert(std::is_same_v<X, float> || std::is_same_v<X, double> ||
ck_tile::is_same<X, ck_tile::fp16_t>::value, std::is_same_v<X, ck_tile::fp16_t>,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
static_assert(ck_tile::is_same<Y, float>::value || ck_tile::is_same<Y, double>::value || static_assert(std::is_same_v<Y, float> || std::is_same_v<Y, double> ||
ck_tile::is_same<Y, ck_tile::fp16_t>::value, std::is_same_v<Y, ck_tile::fp16_t>,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
float bx = -beta_ * type_convert<float>(x); float bx = -beta_ * type_convert<float>(x);
...@@ -901,9 +903,9 @@ struct SoftRelu ...@@ -901,9 +903,9 @@ struct SoftRelu
template <typename T> template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{ {
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value || static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
ck_tile::is_same<T, ck_tile::fp16_t>::value || std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int32_t> ||
ck_tile::is_same<T, int32_t>::value || ck_tile::is_same<T, int8_t>::value, std::is_same_v<T, int8_t>,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_); T casted_alpha = type_convert<T>(alpha_);
constexpr T one = type_convert<T>(1); constexpr T one = type_convert<T>(1);
...@@ -920,9 +922,9 @@ struct Power ...@@ -920,9 +922,9 @@ struct Power
template <typename T> template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{ {
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value || static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
ck_tile::is_same<T, ck_tile::fp16_t>::value || std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int32_t> ||
ck_tile::is_same<T, int32_t>::value || ck_tile::is_same<T, int8_t>::value, std::is_same_v<T, int8_t>,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_); T casted_alpha = type_convert<T>(alpha_);
T casted_beta = type_convert<T>(beta_); T casted_beta = type_convert<T>(beta_);
...@@ -942,9 +944,9 @@ struct ClippedRelu ...@@ -942,9 +944,9 @@ struct ClippedRelu
template <typename T> template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{ {
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value || static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
ck_tile::is_same<T, ck_tile::fp16_t>::value || std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int32_t> ||
ck_tile::is_same<T, int32_t>::value || ck_tile::is_same<T, int8_t>::value, std::is_same_v<T, int8_t>,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_); T casted_alpha = type_convert<T>(alpha_);
T casted_beta = type_convert<T>(beta_); T casted_beta = type_convert<T>(beta_);
...@@ -961,9 +963,9 @@ struct LeakyRelu ...@@ -961,9 +963,9 @@ struct LeakyRelu
template <typename T> template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{ {
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value || static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
ck_tile::is_same<T, ck_tile::fp16_t>::value || std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int32_t> ||
ck_tile::is_same<T, int32_t>::value || ck_tile::is_same<T, int8_t>::value, std::is_same_v<T, int8_t>,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_); T casted_alpha = type_convert<T>(alpha_);
y = x >= 0 ? x : x * casted_alpha; y = x >= 0 ? x : x * casted_alpha;
...@@ -978,9 +980,9 @@ struct Elu ...@@ -978,9 +980,9 @@ struct Elu
template <typename T> template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{ {
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value || static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
ck_tile::is_same<T, ck_tile::fp16_t>::value || std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int32_t> ||
ck_tile::is_same<T, int32_t>::value || ck_tile::is_same<T, int8_t>::value, std::is_same_v<T, int8_t>,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_); T casted_alpha = type_convert<T>(alpha_);
y = x > 0 ? x : casted_alpha * ck_tile::expm1(x); y = x > 0 ? x : casted_alpha * ck_tile::expm1(x);
...@@ -995,9 +997,9 @@ struct Logistic ...@@ -995,9 +997,9 @@ struct Logistic
template <typename T> template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{ {
static_assert(ck_tile::is_same<T, float>::value || ck_tile::is_same<T, double>::value || static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
ck_tile::is_same<T, ck_tile::fp16_t>::value || std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int32_t> ||
ck_tile::is_same<T, int32_t>::value || ck_tile::is_same<T, int8_t>::value, std::is_same_v<T, int8_t>,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
T casted_alpha = type_convert<T>(alpha_); T casted_alpha = type_convert<T>(alpha_);
constexpr T one = type_convert<T>(1); constexpr T one = type_convert<T>(1);
...@@ -1078,7 +1080,7 @@ struct ConvScaleRelu ...@@ -1078,7 +1080,7 @@ struct ConvScaleRelu
}; };
// support fastconvert of int8 to fp16 // support fastconvert of int8 to fp16
#if 0
template <typename InputDataType, typename OutputDataType, index_t RegPackNumber> template <typename InputDataType, typename OutputDataType, index_t RegPackNumber>
struct FastNumericArrayConverter struct FastNumericArrayConverter
{ {
...@@ -1146,6 +1148,6 @@ struct FastNumericArrayConverter<uint8_t, ck_tile::fp16_t, N> ...@@ -1146,6 +1148,6 @@ struct FastNumericArrayConverter<uint8_t, ck_tile::fp16_t, N>
CK_TILE_DEVICE OutputArray operator()(InputArray const& Input) { return convert(Input); } CK_TILE_DEVICE OutputArray operator()(InputArray const& Input) { return convert(Input); }
}; };
#endif
} // namespace element_wise } // namespace element_wise
} // namespace ck_tile } // namespace ck_tile
...@@ -7,6 +7,10 @@ ...@@ -7,6 +7,10 @@
namespace ck_tile { namespace ck_tile {
/*
* TODO: block_tile_reduce_sync() currently has a limitation
* Y dim must have at least one dim not been reduced
*/
// synchronize reduce result (cross lane reduction and broadcast on replicated dimension) // synchronize reduce result (cross lane reduction and broadcast on replicated dimension)
template <typename AccDistributedTensor_, typename ReduceFunc, bool WithBroadcast = true> template <typename AccDistributedTensor_, typename ReduceFunc, bool WithBroadcast = true>
CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor, CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor,
...@@ -55,7 +59,17 @@ CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor, ...@@ -55,7 +59,17 @@ CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor,
// pull data from remote lane // pull data from remote lane
const auto v_remote = warp_shuffle_down(v_local, lid_delta); const auto v_remote = warp_shuffle_down(v_local, lid_delta);
#if 0
if constexpr(Verbose_)
{
printf("warp_shuffle_down : %d - %d, %d (%.3f, %.3f)\n",
static_cast<int>(threadIdx.x),
static_cast<int>(lid_over_rid_derivative),
static_cast<int>(lid_delta),
v_local,
v_remote);
}
#endif
// reduce // reduce
v_local = reduce_func(v_local, v_remote); v_local = reduce_func(v_local, v_remote);
}); });
...@@ -104,6 +118,76 @@ CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor, ...@@ -104,6 +118,76 @@ CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor,
}); });
} }
/*
* this version is faster, using xor to do reduce, no need broadcast anymore
* TODO: the limitation is to-be-reduced P dim can only mapping to one R dim?
*/
template <typename AccDistributedTensor_, typename ReduceFunc>
CK_TILE_DEVICE void block_tile_reduce_xor_sync(AccDistributedTensor_& acc_tensor,
const ReduceFunc& reduce_func)
{
using Dstr = typename AccDistributedTensor_::StaticTileDistribution;
using DstrEncode = typename Dstr::DstrEncode;
using DstrEncodeDetail = typename DstrEncode::detail;
constexpr index_t NDimP = Dstr::get_num_of_dimension_p();
constexpr index_t NDimR = Dstr::get_num_of_dimension_r();
constexpr index_t idim_p_lane = NDimP - 1;
constexpr index_t thread_buf_size = AccDistributedTensor_::get_thread_buffer_size();
// loop over thread data
static_for<0, thread_buf_size, 1>{}([&](auto i) {
auto v_local = acc_tensor.get_thread_buffer()[i];
// cross-lane reduce for replication
// only reduce on R dimension correspond to lane
// (lane id maps to this R dimension)
static_for<0, NDimR, 1>{}([&](auto idim_r) {
// FIXME: nasty to use does_p_own_r_
if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r])
{
constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];
constexpr index_t lid_over_rid_derivative =
DstrEncodeDetail::ps_over_rs_derivative_[idim_p_lane][idim_r];
static_assert(is_power_of_two_integer(r_length),
"wrong! only support power of 2 reduction");
constexpr index_t nstage = integer_log2_floor(r_length);
// reduction sweep forward
static_for<0, nstage, 1>{}([&](auto istage) {
// TODO: lid_over_rid_derivative not ok in xor? maybe need limit the usage of
// xor
index_t src_lane = (__lane_id() * lid_over_rid_derivative) ^
(number<1 << istage.value>{}.value);
// pull data from remote lane
const auto v_remote = warp_shuffle(v_local, src_lane);
#if 0
if constexpr(Verbose_)
{
printf("block_tile_reduce_xor_sync : %d - %d, %d (%.3f, %.3f)\n",
static_cast<int>(threadIdx.x),
static_cast<int>(istage),
static_cast<int>(src_lane),
v_local,
v_remote);
}
#endif
// reduce
v_local = reduce_func(v_local, v_remote);
});
}
});
acc_tensor.get_thread_buffer()(i) = v_local;
});
}
// FIXME: this is for 2D to 1D reduce only, need to support n-D // FIXME: this is for 2D to 1D reduce only, need to support n-D
template <typename AccDistributedTensor_, template <typename AccDistributedTensor_,
typename InDistributedTensor_, typename InDistributedTensor_,
...@@ -175,6 +259,10 @@ CK_TILE_DEVICE void block_tile_reduce(AccDistributedTensor_& acc_tensor, ...@@ -175,6 +259,10 @@ CK_TILE_DEVICE void block_tile_reduce(AccDistributedTensor_& acc_tensor,
#endif #endif
} }
/*
* TODO: block_tile_reduce() currently has a limitation
* Y dim must have at least one dim not been reduced
*/
template <typename AccDataType_, template <typename AccDataType_,
typename InDistributedTensor_, typename InDistributedTensor_,
index_t... InReduceDims, index_t... InReduceDims,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/softmax/block/block_softmax_2d.hpp"
#include "ck_tile/ops/softmax/block/block_softmax_2d_problem.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/reduce.hpp"
namespace ck_tile {
/*
simple 2d softmax implementation, along row (dim=1)
requirement:
1). each row is within a warp
2). data type must be a dword
*/
template <typename Problem_, typename Policy_ = void>
struct BlockSoftmax2D
{
    using Problem  = remove_cvref_t<Problem_>;
    using Policy   = remove_cvref_t<Policy_>;
    using DataType = typename Problem::DataType;

    /// Numerically-stable softmax along dim=1 (each row lives inside a warp).
    /// @param x  input distributed tensor
    /// @param y  output distributed tensor (same distribution as x)
    template <typename DistributedTensor, index_t dim = 1>
    CK_TILE_DEVICE void
    operator()(const DistributedTensor& x, DistributedTensor& y, number<dim> = {})
    {
        const auto max_op = [](auto a, auto b) { return max(a, b); };
        const auto sum_op = [](auto a, auto b) { return a + b; };

        // per-row maximum: thread-local partial reduce, then cross-lane xor reduce
        auto row_peak =
            block_tile_reduce<DataType>(x, sequence<dim>{}, max_op, -numeric<DataType>::infinity());
        block_tile_reduce_xor_sync(row_peak, max_op);

        // numerator: exp(x - row_max), element by element
        constexpr auto spans = DistributedTensor::get_distributed_spans();
        sweep_tile_span(spans[number<0>{}], [&](auto row) {
            constexpr auto row_idx = make_tuple(row);
            sweep_tile_span(spans[number<1>{}], [&](auto col) {
                constexpr auto rc_idx = make_tuple(row, col);
                y(rc_idx)             = exp(x[rc_idx] - row_peak(row_idx));
            });
        });

        // denominator: per-row sum of the numerators (same two-step reduce)
        auto denom = block_tile_reduce<DataType>(y, sequence<dim>{}, sum_op, DataType{0});
        block_tile_reduce_xor_sync(denom, sum_op);

        // invert once per row so the scaling loop multiplies instead of divides
        auto inv_denom = make_static_distributed_tensor<DataType>(denom.get_tile_distribution());
        constexpr auto spans_1d = decltype(inv_denom)::get_distributed_spans();
        sweep_tile_span(spans_1d[number<0>{}], [&](auto row) {
            constexpr auto row_idx = make_tuple(row);
            inv_denom(row_idx)     = DataType{1} / denom(row_idx);
        });

        // normalize every element by its row's reciprocal sum
        sweep_tile_span(spans[number<0>{}], [&](auto row) {
            constexpr auto row_idx = make_tuple(row);
            sweep_tile_span(spans[number<1>{}], [&](auto col) {
                constexpr auto rc_idx = make_tuple(row, col);
                y(rc_idx)             = y(rc_idx) * inv_denom(row_idx);
            });
        });
    }

    /// Convenience overload returning a fresh tensor instead of writing in place.
    template <typename DistributedTensor, index_t dim = 1>
    CK_TILE_DEVICE decltype(auto) operator()(const DistributedTensor& x, number<dim> = {})
    {
        DistributedTensor y{}; // distributed tensor with the same static distribution
        operator()(x, y, number<dim>{});
        return y;
    }
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
// Compile-time problem descriptor consumed by BlockSoftmax2D.
template <typename DataType_>
struct BlockSoftmax2DProblem
{
    // element/compute type of the softmax (BlockSoftmax2D requires a dword-sized type)
    using DataType = remove_cvref_t<DataType_>;
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/topk/block/block_topk_stream_2d.hpp"
#include "ck_tile/ops/topk/block/block_topk_stream_2d_problem.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
/*
simple 2d topk implementation, along row (dim=1)
requirement:
1). each row is within a warp
*/
template <typename Problem_, typename Policy_ = void>
struct BlockTopkStream2D
{
    using Problem = remove_cvref_t<Problem_>;
    using Policy = remove_cvref_t<Policy_>;
    using DataType = typename Problem::DataType;
    using IndexType = typename Problem::IndexType;
    // TODO: if DataType is subdword, need pack into single dword to use argmax
    // Pair folded by the argmax reduction. Note the slightly inverted naming:
    // 'arg' holds the candidate value being maximized, 'value' holds its column index.
    struct ArgmaxPacket
    {
        DataType arg;   // candidate value
        index_t value;  // column index of that value
    };
    // Streaming top-k along dim=1: k rounds of (pack value+index -> warp argmax ->
    // write winner out -> mask winner to -inf), one output column per round.
    // x          : distributed [rows, cols] input tensor
    // out_window : destination for the k selected values
    // idx_window : destination for the k selected column indices
    // k          : runtime count of elements to emit per row
    template <typename DistributedTensor, typename OutWindow, typename IdxWindow, index_t dim = 1>
    CK_TILE_DEVICE void operator()(const DistributedTensor& x,
                                   OutWindow& out_window,
                                   IdxWindow& idx_window,
                                   index_t k,
                                   number<dim> = {})
    {
        // static_assert(OutWindow::get_window_lengths()[number<1>] == 1);
        static_assert(
            std::is_same_v<typename DistributedTensor::DataType, typename OutWindow::DataType> &&
            std::is_same_v<typename DistributedTensor::DataType, DataType>);
        static_assert(std::is_same_v<typename IdxWindow::DataType, IndexType>);
        // mutable working copy: each round masks the selected element to -inf
        DistributedTensor x_tmp = x;
        constexpr auto dst_dist = typename IdxWindow::TileDstr{};
        // argmax for topk (on equal values the left/first operand wins)
        const auto f_argmax = [](ArgmaxPacket e0, ArgmaxPacket e1) {
            return e0.arg > e1.arg ? e0 : e1;
        };
        for(index_t i_k = 0; i_k < k; i_k++)
        {
            constexpr auto span_2d = DistributedTensor::get_distributed_spans();
            // pack each element together with its global column index
            auto packet = [&]() {
                auto tmp = make_static_distributed_tensor<ArgmaxPacket>(x.get_tile_distribution());
                sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
                    sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
                        const auto tile_idx = get_x_indices_from_distributed_indices(
                            tmp.get_tile_distribution(), make_tuple(idx0, idx1));
                        constexpr auto i_j_idx = make_tuple(idx0, idx1);
                        ArgmaxPacket t;
                        t.arg = x_tmp(i_j_idx); // !!! we reference x here
                        t.value = tile_idx.at(number<1>{});
                        tmp(i_j_idx) = t;
                    });
                });
                return tmp;
            }();
            auto argmax_init = ArgmaxPacket{-numeric<DataType>::infinity(), 0};
            // thread-local argmax along dim=1, then cross-lane xor reduce
            // (xor variant leaves the result in every participating lane)
            auto r = block_tile_reduce<ArgmaxPacket>(packet, sequence<1>{}, f_argmax, argmax_init);
            block_tile_reduce_xor_sync(r, f_argmax);
            // unpack the winner into separate value/index tensors shaped for the windows
            auto o = make_static_distributed_tensor<DataType>(dst_dist);
            auto i = make_static_distributed_tensor<IndexType>(dst_dist);
            sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
                sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
                    constexpr auto i_j_idx = make_tuple(idx0, idx1);
                    ArgmaxPacket tmp = r(i_j_idx);
                    o(i_j_idx) = tmp.arg;
                    i(i_j_idx) = tmp.value;
                });
            });
            // update value: mask the just-selected column to -inf so the next
            // round finds the next-largest element
            sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
                sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
                    const auto tile_idx = get_x_indices_from_distributed_indices(
                        x.get_tile_distribution(), make_tuple(idx0, idx1));
                    auto col_id = tile_idx.at(number<1>{});
                    constexpr auto i_j_idx = make_tuple(idx0, idx1);
                    x_tmp(i_j_idx) = (col_id == r(i_j_idx).value) ? -numeric<DataType>::infinity()
                                                                  : x_tmp(i_j_idx);
                });
            });
            // only the first lane of each ColLanes group stores (result is
            // replicated across the group after the xor reduce)
            if(threadIdx.x % Problem::ColLanes == 0)
            {
                store_tile(out_window, o);
                store_tile(idx_window, i);
            }
            // advance both windows by one column for the next of the k results
            move_tile_window(out_window, {number<0>{}, number<1>{}});
            move_tile_window(idx_window, {number<0>{}, number<1>{}});
        }
    }
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
/*
simple 2d topk implementation, along row (dim=1)
requirement:
1). each row is within a warp
*/
// Compile-time problem descriptor consumed by BlockTopkStream2D.
template <typename DataType_, typename IndexType_, index_t ColLanes_>
struct BlockTopkStream2DProblem
{
    // element type of the values being ranked
    using DataType = remove_cvref_t<DataType_>;
    // type used for the emitted column indices
    using IndexType = remove_cvref_t<IndexType_>;
    // lanes sharing one row; only the lane with threadIdx.x % ColLanes == 0 stores
    static constexpr index_t ColLanes = ColLanes_;
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp"
#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_pipeline.hpp"
#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp"
#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_problem.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/elementwise.hpp"
#include <string>
#include <type_traits>
namespace ck_tile {
// Host-side argument bundle for the topk-softmax kernel; copied field-for-field
// into the device-side kargs by TopkSoftmaxKernel::MakeKargs.
struct TopkSoftmaxHostArgs
{
    const void* p_input;  // [num_rows, num_experts] input values (Problem::InputType)
    void* p_output;       // [num_rows, topk] selected softmax weights (Problem::WeightType)
    void* p_indices;      // [num_rows, topk] selected expert indices (Problem::IndexType)
    index_t num_rows;     // number of rows (tokens)
    index_t num_experts;  // columns per row
    index_t topk;         // elements to keep per row
};
// Kernel entry point: builds padded global-memory tile windows over the input,
// output, and index buffers, then delegates the math to the pipeline.
template <typename Pipeline_>
struct TopkSoftmaxKernel
{
    using Pipeline = remove_cvref_t<Pipeline_>;
    using Problem = remove_cvref_t<typename Pipeline::Problem>;
    using InputType = typename Problem::InputType;   // element type of the input
    using WeightType = typename Problem::WeightType; // softmax compute/output type
    using IndexType = typename Problem::IndexType;   // type of the emitted indices
    // Device-side arguments; mirrors TopkSoftmaxHostArgs field-for-field.
    struct TopkSoftmaxKargs
    {
        const void* p_input;
        void* p_output;
        void* p_indices;
        index_t num_rows;
        index_t num_experts;
        index_t topk;
    };
    using Kargs = TopkSoftmaxKargs;
    using Hargs = TopkSoftmaxHostArgs;
    // Launch enough blocks to cover every row:
    // ceil(num_rows / RowsPerWarp) warps, then ceil(warps / WarpsPerBlock) blocks.
    CK_TILE_HOST static constexpr auto GridSize(const Hargs& h)
    {
        const int num_warps = (h.num_rows + Problem::RowsPerWarp - 1) / Problem::RowsPerWarp;
        const int num_blocks = (num_warps + Problem::WarpsPerBlock - 1) / Problem::WarpsPerBlock;
        return dim3(num_blocks);
    }
    // Straight copy of host args into the device-side karg struct.
    CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
    {
        Kargs k;
        k.p_input = h.p_input;
        k.p_output = h.p_output;
        k.p_indices = h.p_indices;
        k.num_rows = h.num_rows;
        k.num_experts = h.num_experts;
        k.topk = h.topk;
        return k;
    }
    CK_TILE_HOST_DEVICE static constexpr auto BlockSize() { return Problem::BlockSize; }
    CK_TILE_DEVICE void operator()(Kargs kargs) const
    {
        // first row this block is responsible for
        index_t block_row_id = static_cast<index_t>(blockIdx.x * Problem::RowsPerBlock);
        // input window over [RowsPerBlock, Experts], padded on both dims.
        // NOTE(review): the base pointer is already advanced by this block's rows
        // while the window origin is also {block_row_id, 0} — this looks like the
        // per-block row offset is applied twice; confirm against the
        // make_naive_tensor_view_packed / make_tile_window origin semantics.
        const auto input_window = [&]() {
            const InputType* p_input = reinterpret_cast<const InputType*>(kargs.p_input) +
                                       blockIdx.x * Problem::RowsPerBlock * kargs.num_experts;
            auto tmp = make_naive_tensor_view_packed<address_space_enum::global>(
                p_input,
                make_tuple(kargs.num_rows, kargs.num_experts),
                number<Problem::VectorSize>{});
            // pad so partial tiles at the boundary are safe to access
            auto view = pad_tensor_view(
                tmp,
                make_tuple(number<Problem::RowsPerBlock>{}, number<Problem::Experts>{}),
                sequence<1, 1>{});
            return make_tile_window(
                view,
                make_tuple(number<Problem::RowsPerBlock>{}, number<Problem::Experts>{}),
                {block_row_id, 0});
        }();
        // output window over [RowsPerBlock, 1]; the pipeline's topk stage advances
        // it one column per selected element (row dim padded, column dim not)
        auto output_window = [&]() {
            WeightType* p_output = reinterpret_cast<WeightType*>(kargs.p_output) +
                                   blockIdx.x * Problem::RowsPerBlock * kargs.topk;
            auto tmp = make_naive_tensor_view_packed<address_space_enum::global>(
                p_output, make_tuple(kargs.num_rows, kargs.topk), number<Problem::VectorSize>{});
            auto view = pad_tensor_view(
                tmp, make_tuple(number<Problem::RowsPerBlock>{}, number<1>{}), sequence<1, 0>{});
            return make_tile_window(
                view, make_tuple(number<Problem::RowsPerBlock>{}, number<1>{}), {block_row_id, 0});
        }();
        // indices window over [RowsPerBlock, 1], same layout as the output window
        auto indices_window = [&]() {
            IndexType* p_indices = reinterpret_cast<IndexType*>(kargs.p_indices) +
                                   blockIdx.x * Problem::RowsPerBlock * kargs.topk;
            auto tmp = make_naive_tensor_view_packed<address_space_enum::global>(
                p_indices, make_tuple(kargs.num_rows, kargs.topk), number<Problem::VectorSize>{});
            auto view = pad_tensor_view(
                tmp, make_tuple(number<Problem::RowsPerBlock>{}, number<1>{}), sequence<1, 0>{});
            return make_tile_window(
                view, make_tuple(number<Problem::RowsPerBlock>{}, number<1>{}), {block_row_id, 0});
        }();
        Pipeline{}(input_window, output_window, indices_window, kargs.topk);
    }
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp"
#include <string>
#include <type_traits>
namespace ck_tile {
template <typename Problem_, typename Policy_ = TopkSoftmaxWarpPerRowPolicy>
struct TopkSoftmaxWarpPerRowPipeline
{
    // NOTE: this pipeline only supports the warp-per-row work mapping
    using Problem = remove_cvref_t<Problem_>;
    using Policy  = remove_cvref_t<Policy_>;

    /// Load a tile, run row-wise softmax, stream the top-k values/indices out.
    template <typename InputWindow, typename OutputWindow, typename IndexWindow>
    CK_TILE_DEVICE auto operator()(const InputWindow& input_window,
                                   OutputWindow& out_window,
                                   IndexWindow& idx_window,
                                   index_t k)
    {
        // re-wrap the caller's window with this pipeline's load distribution,
        // then pull the tile into registers
        auto raw_tile =
            load_tile(make_tile_window(input_window.get_bottom_tensor_view(),
                                       input_window.get_window_lengths(),
                                       input_window.get_window_origin(),
                                       Policy::template MakeInputDistribution<Problem>()));

        // widen the raw input to the weight (compute) type
        auto weights = cast_tile<typename Problem::WeightType>(raw_tile);

        // row-wise softmax
        auto probs = Policy::template GetSoftmax<Problem>()(weights);

        // value and index outputs share one distribution
        auto val_win = make_tile_window(out_window.get_bottom_tensor_view(),
                                        out_window.get_window_lengths(),
                                        out_window.get_window_origin(),
                                        Policy::template MakeOutputDistribution<Problem>());
        auto ind_win = make_tile_window(idx_window.get_bottom_tensor_view(),
                                        idx_window.get_window_lengths(),
                                        idx_window.get_window_origin(),
                                        Policy::template MakeOutputDistribution<Problem>());

        // stream out the k largest probabilities and their column indices
        Policy::template GetTopk<Problem>()(probs, val_win, ind_win, k);
    }
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/softmax.hpp"
#include "ck_tile/ops/topk.hpp"
namespace ck_tile {
// Policy bundling the tile distributions and building blocks used by
// TopkSoftmaxWarpPerRowPipeline (one warp processes one or more whole rows).
struct TopkSoftmaxWarpPerRowPolicy
{
    // Distribution for loading the [rows, experts] input tile. Row dim splits
    // into IssuesPerCol x WarpsPerBlock x RowsPerWarp; column dim splits into
    // IssuesPerRow x LanesPerRow x VectorSize.
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeInputDistribution()
    {
        // TODO: Y dim must have one dim that is not reduced
        return make_static_tile_distribution(
            tile_distribution_encoding<
                sequence<1>,
                tuple<sequence<Problem::IssuesPerCol, Problem::WarpsPerBlock, Problem::RowsPerWarp>,
                      sequence<Problem::IssuesPerRow, Problem::LanesPerRow, Problem::VectorSize>>,
                tuple<sequence<1>, sequence<1, 2>>,
                tuple<sequence<1>, sequence<2, 1>>,
                sequence<1, 2, 2>,
                sequence<0, 0, 2>>{});
    }
    // Distribution for the single-column output/index tiles: one element per
    // row, replicated across the LanesPerRow lanes that share a row.
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeOutputDistribution()
    {
        return make_static_tile_distribution(
            tile_distribution_encoding<sequence<Problem::LanesPerRow>, // repeat this one
                                       tuple<sequence<Problem::WarpsPerBlock, Problem::RowsPerWarp>,
                                             sequence<1>>, // each row write out single element
                                       tuple<sequence<1>, sequence<1, 0>>,
                                       tuple<sequence<0>, sequence<1, 0>>,
                                       sequence<2>,
                                       sequence<0>>{});
    }
    // Row-wise softmax building block, computed in WeightType.
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetSoftmax()
    {
        using softmax_problem = BlockSoftmax2DProblem<typename Problem::WeightType>;
        return BlockSoftmax2D<softmax_problem>{};
    }
    // Streaming top-k building block; its ColLanes matches LanesPerRow above.
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetTopk()
    {
        using topk_problem = BlockTopkStream2DProblem<typename Problem::WeightType,
                                                      typename Problem::IndexType,
                                                      Problem::LanesPerRow>;
        // Note: replicate is LanesPerRow
        return BlockTopkStream2D<topk_problem>{};
    }
};
} // namespace ck_tile
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment