Unverified Commit b098b71b authored by carlushuang's avatar carlushuang Committed by GitHub
Browse files

topk_softmax (#1592)

* topk_softmax

* remove some file

* fix atomix linear_offset

* address various comment, and change sfc get_index api to static(tuple)
parent 31bf253a
......@@ -59,8 +59,16 @@ struct magic_division32_bit_range
CK_TILE_DEVICE static constexpr uint32_t
do_magic_division(uint32_t dividend, uint32_t multiplier, uint32_t shift)
{
    // __umulhi() is a device intrinsic and cannot run during constant
    // evaluation, so emulate the high half of the 32x32->64 product with
    // plain 64-bit arithmetic there; use the intrinsic at runtime.
    if(__builtin_is_constant_evaluated())
    {
        const uint32_t prod_hi = (static_cast<uint64_t>(dividend) * multiplier) >> 32;
        return (prod_hi + dividend) >> shift;
    }
    else
    {
        const uint32_t prod_hi = __umulhi(dividend, multiplier);
        return (prod_hi + dividend) >> shift;
    }
}
CK_TILE_HOST static constexpr uint32_t
......@@ -77,9 +85,18 @@ struct magic_division32_bit_range
CK_TILE_DEVICE static constexpr int32_t
do_magic_division(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
{
    // the magic-number scheme works on the raw unsigned representation of
    // the signed dividend
    const uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
    if(__builtin_is_constant_evaluated())
    {
        // constexpr-safe emulation of __umulhi()
        const uint32_t prod_hi = (static_cast<uint64_t>(dividend_u32) * multiplier) >> 32;
        return (prod_hi + dividend_u32) >> shift;
    }
    else
    {
        const uint32_t prod_hi = __umulhi(dividend_u32, multiplier);
        return (prod_hi + dividend_u32) >> shift;
    }
}
CK_TILE_HOST static constexpr int32_t
......
......@@ -24,5 +24,6 @@
#include "ck_tile/host/reference/reference_layernorm2d_fwd.hpp"
#include "ck_tile/host/reference/reference_reduce.hpp"
#include "ck_tile/host/reference/reference_softmax.hpp"
#include "ck_tile/host/reference/reference_topk.hpp"
#include "ck_tile/host/stream_config.hpp"
#include "ck_tile/host/timer.hpp"
......@@ -10,6 +10,7 @@
#include <random>
#include <type_traits>
#include <utility>
#include <unordered_set>
#include "ck_tile/core.hpp"
......@@ -41,6 +42,73 @@ struct FillUniformDistribution
}
};
namespace impl {
// Map sizeof(T) to the unsigned integer type of identical width, so a value of
// any type can be reinterpreted (via bit_cast) as raw bits — e.g. to dedup
// floating-point values by exact bit pattern.
// clang-format off
template<index_t bytes> struct RawIntegerType_ {};
template<> struct RawIntegerType_<1> { using type = uint8_t;};
template<> struct RawIntegerType_<2> { using type = uint16_t;};
template<> struct RawIntegerType_<4> { using type = uint32_t;};
template<> struct RawIntegerType_<8> { using type = uint64_t;};
// clang-format on
// convenience alias: unsigned integer type with the same size as T
template <typename T>
using RawIntegerType = typename RawIntegerType_<sizeof(T)>::type;
} // namespace impl
// Note: operator() of this struct is intentionally non-const — it advances the RNG state and records every generated value while producing random numbers
template <typename T>
struct FillUniformDistribution_Unique
{
float a_{-5.f};
float b_{5.f};
std::optional<uint32_t> seed_{11939};
std::mt19937 gen_{};
std::unordered_set<impl::RawIntegerType<T>> set_{};
FillUniformDistribution_Unique(float a = -5.f,
float b = 5.f,
std::optional<uint32_t> seed = {11939})
: a_(a),
b_(b),
seed_(seed),
gen_{seed_.has_value() ? *seed_ : std::random_device{}()},
set_{}
{
}
template <typename ForwardIter>
void operator()(ForwardIter first, ForwardIter last)
{
std::mt19937& gen = gen_;
std::uniform_real_distribution<float> dis(a_, b_);
auto& set = set_;
std::generate(first, last, [&dis, &gen, &set]() {
T v = static_cast<T>(0);
do
{
v = ck_tile::type_convert<T>(dis(gen));
} while(set.count(bit_cast<impl::RawIntegerType<T>>(v)) == 1);
set.insert(bit_cast<impl::RawIntegerType<T>>(v));
return v;
});
}
template <typename ForwardRange>
auto operator()(ForwardRange&& range)
-> std::void_t<decltype(std::declval<FillUniformDistribution_Unique&>()(
std::begin(std::forward<ForwardRange>(range)),
std::end(std::forward<ForwardRange>(range))))>
{
(*this)(std::begin(std::forward<ForwardRange>(range)),
std::end(std::forward<ForwardRange>(range)));
}
void clear() { set_.clear(); }
};
template <typename T>
struct FillNormalDistribution
{
......
......@@ -11,6 +11,7 @@
#include <thread>
#include <utility>
#include <vector>
#include <functional>
#include "ck_tile/core.hpp"
#include "ck_tile/host/ranges.hpp"
......@@ -545,6 +546,28 @@ struct HostTensor
// total number of elements held in the underlying storage
typename Data::size_type size() const { return mData.size(); }
// return a slice of this tensor
// for simplicity we just copy the data and return a new tensor
// Copy the sub-tensor [s_begin, s_end) (half-open, per dimension) into a new,
// densely packed HostTensor. Both vectors must have one entry per dimension.
auto slice(std::vector<size_t> s_begin, std::vector<size_t> s_end) const
{
    assert(s_begin.size() == s_end.size());
    assert(s_begin.size() == get_num_of_dimension());
    // guard each dimension: size_t subtraction below would silently wrap to a
    // huge length if s_end < s_begin
    for(std::size_t d = 0; d < s_begin.size(); d++)
    {
        assert(s_begin[d] <= s_end[d]);
    }
    std::vector<size_t> s_len(s_begin.size());
    std::transform(
        s_end.begin(), s_end.end(), s_begin.begin(), s_len.begin(), std::minus<size_t>{});
    HostTensor<T> sliced_tensor(s_len);
    // for each destination coordinate, read the source at (idx + s_begin)
    sliced_tensor.ForEach([&](auto& self, auto idx) {
        std::vector<size_t> src_idx(idx.size());
        std::transform(
            idx.begin(), idx.end(), s_begin.begin(), src_idx.begin(), std::plus<size_t>{});
        self(idx) = operator()(src_idx);
    });
    return sliced_tensor;
}
template <typename U = T>
auto AsSpan() const
{
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -9,43 +9,81 @@
namespace ck_tile {
// Reference (host) softmax of x into y along dimension `dim` (-1 = last dim).
// InputType:   element type of x
// ComputeType: precision used for max/exp/sum accumulation
// OutputType:  element type of y (defaults to ComputeType)
// Each line along `dim` is reduced with the numerically stable 3-pass scheme
// (max, sum of exp(x - max), normalize); lines run in parallel host threads.
// NOTE(review): the dump interleaved the pre-diff 2D implementation with the
// new N-D one; this is the reconstructed post-diff version.
template <typename InputType, typename ComputeType, typename OutputType = ComputeType>
CK_TILE_HOST void
reference_softmax(const HostTensor<InputType>& x, HostTensor<OutputType>& y, index_t dim = -1)
{
    index_t rank = x.get_num_of_dimension();
    assert(rank == y.get_num_of_dimension());
    assert(dim == -1 || dim < rank);

    index_t target_dim  = dim == -1 ? (rank - 1) : dim;
    index_t softmax_len = x.get_length(target_dim);
    index_t n_parallel  = x.get_element_size() / softmax_len;
    auto x_len          = x.get_lengths();

    auto f = [&](auto i_element) {
        // decompose the flat parallel index into a coordinate over every
        // dimension except target_dim (which stays 0 and is varied below)
        std::vector<size_t> coord = [&]() {
            std::vector<size_t> t_(rank, 0);
            size_t r = i_element;
            for(index_t i = rank - 1; i >= 0; i--)
            {
                if(i == target_dim)
                    continue;
                t_[i] = r % x_len[i];
                r     = r / x_len[i];
            }
            return t_;
        }();

        // pass 1: running max for numerical stability
        ComputeType v_max = -ck_tile::numeric<ComputeType>::infinity();
        for(index_t idx = 0; idx < softmax_len; idx++)
        {
            auto c_               = coord;
            c_[target_dim]        = idx;
            const ComputeType v_x = ck_tile::type_convert<ComputeType>(x(c_));
            v_max                 = v_max < v_x ? v_x : v_max;
        }
        // pass 2: sum of exp(x - max)
        ComputeType v_exp_sum = static_cast<ComputeType>(0);
        for(index_t idx = 0; idx < softmax_len; idx++)
        {
            auto c_               = coord;
            c_[target_dim]        = idx;
            const ComputeType v_x = ck_tile::type_convert<ComputeType>(x(c_));
            v_exp_sum += ck_tile::exp(v_x - v_max);
        }
        // pass 3: normalize and write out
        for(index_t idx = 0; idx < softmax_len; idx++)
        {
            auto c_               = coord;
            c_[target_dim]        = idx;
            const ComputeType v_x = ck_tile::type_convert<ComputeType>(x(c_));
            y(c_) = ck_tile::type_convert<OutputType>(ck_tile::exp(v_x - v_max) / v_exp_sum);
        }
    };
    make_ParallelTensorFunctor(f, n_parallel)(std::thread::hardware_concurrency());
}
// Convenience overload: allocates the output tensor (same lengths/strides as
// x) and returns it by value.
template <typename InputType, typename ComputeType, typename OutputType = ComputeType>
CK_TILE_HOST auto reference_softmax(const HostTensor<InputType>& x, index_t dim = -1)
{
    HostTensor<OutputType> y(x.get_lengths(), x.get_strides());
    reference_softmax<InputType, ComputeType, OutputType>(x, y, dim);
    return y;
}
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <thread>
#include <numeric>
#include <functional>
#include <utility>
#include <algorithm>
namespace ck_tile {
/*
similar to torch.topk()
x (Tensor) – the input tensor.
k (int) – the k in “top-k”
dim (int, optional) – the dimension to sort along
largest (bool, optional) – largest or smallest elements
sorted (bool, optional) – elements in sorted order or not
output:
y_values
y_indices
https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/TopKImpl.h
*/
// Host reference of torch.topk(): select the k largest (or smallest) elements
// along `dim` (-1 = last dim) of x into y_values / y_indices.
// k:       number of elements to select per line (0 < k <= length of dim)
// largest: pick maxima (true) or minima (false)
// sorted:  emit the selected elements in sorted order
// y_values / y_indices must already have the topk dimension sized to k.
template <typename DataType, typename IndexType = index_t>
CK_TILE_HOST void reference_topk(const HostTensor<DataType>& x,
                                 HostTensor<DataType>& y_values,
                                 HostTensor<IndexType>& y_indices,
                                 index_t k,
                                 index_t dim = -1,
                                 bool largest = true,
                                 bool sorted = true)
{
    // rank must be the same
    index_t rank = x.get_num_of_dimension();
    assert(rank == y_values.get_num_of_dimension());
    assert(rank == y_indices.get_num_of_dimension());
    assert(dim == -1 || dim < rank);

    index_t topk_dim     = dim == -1 ? (rank - 1) : dim;
    index_t topk_src_len = x.get_length(topk_dim);
    auto x_len           = x.get_lengths();

    assert(k <= topk_src_len);
    assert(k == y_values.get_length(topk_dim) && k == y_indices.get_length(topk_dim));
    // k == 0 would make the nth_element/sort ranges below invalid
    // (q.begin() + k - 1 precedes begin()) — nothing to select, so return
    if(k <= 0)
        return;

    index_t n_parallel = x.get_element_size() / topk_src_len;

    auto f = [&](auto i_element) {
        // coordinate over all dimensions except topk_dim (left at 0)
        std::vector<size_t> topk_coord = [&]() {
            std::vector<size_t> t_(rank, 0);
            size_t r = i_element;
            for(index_t i = rank - 1; i >= 0; i--)
            {
                if(i == topk_dim)
                    continue; // topk dim should be zero
                t_[i] = r % x_len[i];
                r     = r / x_len[i];
            }
            return t_;
        }();

        // gather the source line together with its original indices
        using elem_t = std::pair<DataType, IndexType>;
        std::vector<elem_t> q(topk_src_len);
        for(index_t i = 0; i < topk_src_len; i++)
        {
            auto c_      = topk_coord;
            c_[topk_dim] = i;
            q[i].first   = x(c_);
            q[i].second  = i;
        }

        // nth_element puts the partition element at q[k-1] with all selected
        // elements before it, so sorting only [0, k-1) still leaves the whole
        // window [0, k) fully sorted
        auto cmp_desc = [](const elem_t& lhs, const elem_t& rhs) { return lhs.first > rhs.first; };
        auto cmp_asc  = [](const elem_t& lhs, const elem_t& rhs) { return lhs.first < rhs.first; };
        if(largest)
        {
            std::nth_element(q.begin(), q.begin() + k - 1, q.end(), cmp_desc);
            if(sorted)
                std::sort(q.begin(), q.begin() + k - 1, cmp_desc);
        }
        else
        {
            std::nth_element(q.begin(), q.begin() + k - 1, q.end(), cmp_asc);
            if(sorted)
                std::sort(q.begin(), q.begin() + k - 1, cmp_asc);
        }

        // write out the selected values and their source indices
        for(index_t i = 0; i < k; i++)
        {
            auto c_       = topk_coord;
            c_[topk_dim]  = i;
            y_values(c_)  = q[i].first;
            y_indices(c_) = q[i].second;
        }
    };
    make_ParallelTensorFunctor(f, n_parallel)(std::thread::hardware_concurrency());
}
// Convenience overload: allocates y_values / y_indices (lengths of x with the
// topk dimension shrunk to k) and returns {values, indices} as a tuple.
// TODO: if using this method, the return tensor would be dense(no stride)
template <typename DataType, typename IndexType = index_t>
CK_TILE_HOST auto reference_topk(const HostTensor<DataType>& x,
                                 index_t k,
                                 index_t dim = -1,
                                 bool largest = true,
                                 bool sorted = true)
{
    auto lens = x.get_lengths();
    // resolve -1 to the last dimension
    index_t target_dim = (dim == -1) ? (lens.size() - 1) : dim;
    assert(target_dim < lens.size());
    assert(k <= lens[target_dim]);
    lens[target_dim] = k;
    HostTensor<DataType> y_values(lens);
    HostTensor<IndexType> y_indices(lens);
    reference_topk<DataType, IndexType>(x, y_values, y_indices, k, dim, largest, sorted);
    return ck_tile::make_tuple(y_values, y_indices);
}
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include <type_traits>
namespace ck_tile {
namespace element_wise {
#if 0
struct PassThroughPack2
{
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const;
CK_TILE_HOST_DEVICE constexpr void operator()(ck_tile::half2_t& y, const ck_tile::f8x2_t& x) const
{
auto t = type_convert<float2_t>(x);
y = type_convert<half2_t>(t);
}
constexpr const static bool is_pack2_invocable = true;
};
#endif
// Elementwise identity / type-cast functor: y = type_convert<Y>(x) (or plain
// assignment when Y == X). The primary template is declared but not defined;
// only the (Y, X) pairs explicitly specialized below are supported — using an
// unlisted pair fails at link/instantiation time.
struct PassThrough
{
    template <typename Y, typename X>
    CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const;

    template <>
    CK_TILE_HOST_DEVICE void operator()<double, double>(double& y, const double& x) const
    {
        y = x;
    }

    template <>
    CK_TILE_HOST_DEVICE void operator()<float, double>(float& y, const double& x) const
    {
        y = type_convert<float>(x);
    }

    template <>
    CK_TILE_HOST_DEVICE void operator()<double, float>(double& y, const float& x) const
    {
        y = type_convert<double>(x);
    }

    template <>
    CK_TILE_HOST_DEVICE void operator()<float, float>(float& y, const float& x) const
    {
        y = x;
    }

    template <>
    CK_TILE_HOST_DEVICE void
    operator()<ck_tile::fp16_t, ck_tile::fp16_t>(ck_tile::fp16_t& y, const ck_tile::fp16_t& x) const
    {
        y = x;
    }

    template <>
    CK_TILE_HOST_DEVICE void operator()<ck_tile::fp16_t, float>(ck_tile::fp16_t& y,
                                                                const float& x) const
    {
        y = type_convert<ck_tile::fp16_t>(x);
    }

    template <>
    CK_TILE_HOST_DEVICE void
    operator()<ck_tile::bf16_t, ck_tile::bf16_t>(ck_tile::bf16_t& y, const ck_tile::bf16_t& x) const
    {
        y = x;
    }

    template <>
    CK_TILE_HOST_DEVICE void operator()<int32_t, int32_t>(int32_t& y, const int32_t& x) const
    {
        y = x;
    }

    template <>
    CK_TILE_HOST_DEVICE void operator()<ck_tile::bf16_t, float>(ck_tile::bf16_t& y,
                                                                const float& x) const
    {
        y = type_convert<ck_tile::bf16_t>(x);
    }

    template <>
    CK_TILE_HOST_DEVICE void operator()<float, ck_tile::bf16_t>(float& y,
                                                                const ck_tile::bf16_t& x) const
    {
        y = type_convert<float>(x);
    }

    template <>
    CK_TILE_HOST_DEVICE void
    operator()<ck_tile::bf16_t, ck_tile::fp16_t>(ck_tile::bf16_t& y, const ck_tile::fp16_t& x) const
    {
        y = type_convert<ck_tile::bf16_t>(x);
    }

    template <>
    CK_TILE_HOST_DEVICE void operator()<float, ck_tile::fp16_t>(float& y,
                                                                const ck_tile::fp16_t& x) const
    {
        y = type_convert<float>(x);
    }

    template <>
    CK_TILE_HOST_DEVICE void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const
    {
        y = x;
    }

    template <>
    CK_TILE_HOST_DEVICE void operator()<ck_tile::fp16_t, int8_t>(ck_tile::fp16_t& y,
                                                                 const int8_t& x) const
    {
        y = type_convert<ck_tile::fp16_t>(x);
    }

    template <>
    CK_TILE_HOST_DEVICE void operator()<ck_tile::bf16_t, int8_t>(ck_tile::bf16_t& y,
                                                                 const int8_t& x) const
    {
        y = type_convert<ck_tile::bf16_t>(x);
    }

    template <>
    CK_TILE_HOST_DEVICE void operator()<uint8_t, uint8_t>(uint8_t& y, const uint8_t& x) const
    {
        y = x;
    }

    template <>
    CK_TILE_HOST_DEVICE void operator()<int8_t, int32_t>(int8_t& y, const int32_t& x) const
    {
        y = type_convert<int8_t>(x);
    }

    template <>
    CK_TILE_HOST_DEVICE void operator()<int32_t, int8_t>(int32_t& y, const int8_t& x) const
    {
        y = type_convert<int32_t>(x);
    }

    template <>
    CK_TILE_HOST_DEVICE void operator()<int8_t, float>(int8_t& y, const float& x) const
    {
        y = type_convert<int8_t>(x);
    }

    template <>
    CK_TILE_HOST_DEVICE void operator()<float, int8_t>(float& y, const int8_t& x) const
    {
        y = type_convert<float>(x);
    }

// int4 support is only compiled in when the experimental extension is enabled
#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
    template <>
    CK_TILE_HOST_DEVICE void operator()<int4_t, int4_t>(int4_t& y, const int4_t& x) const
    {
        y = x;
    }

    template <>
    CK_TILE_HOST_DEVICE void operator()<int4_t, int>(int4_t& y, const int& x) const
    {
        y = type_convert<int4_t>(x);
    }
#endif

    template <>
    CK_TILE_HOST_DEVICE void
    operator()<ck_tile::fp8_t, ck_tile::fp8_t>(ck_tile::fp8_t& y, const ck_tile::fp8_t& x) const
    {
        y = x;
    }

    template <>
    CK_TILE_HOST_DEVICE void operator()<float, ck_tile::fp8_t>(float& y,
                                                               const ck_tile::fp8_t& x) const
    {
        y = type_convert<float>(x);
    }

    template <>
    CK_TILE_HOST_DEVICE void operator()<ck_tile::fp8_t, float>(ck_tile::fp8_t& y,
                                                               const float& x) const
    {
        y = type_convert<ck_tile::fp8_t>(x);
    }

    template <>
    CK_TILE_HOST_DEVICE void
    operator()<ck_tile::fp16_t, ck_tile::fp8_t>(ck_tile::fp16_t& y, const ck_tile::fp8_t& x) const
    {
        y = type_convert<ck_tile::fp16_t>(x);
    }

    template <>
    CK_TILE_HOST_DEVICE void
    operator()<ck_tile::fp8_t, ck_tile::fp16_t>(ck_tile::fp8_t& y, const ck_tile::fp16_t& x) const
    {
        y = type_convert<ck_tile::fp8_t>(x);
    }

    template <>
    CK_TILE_HOST_DEVICE void
    operator()<ck_tile::bf8_t, ck_tile::bf8_t>(ck_tile::bf8_t& y, const ck_tile::bf8_t& x) const
    {
        y = x;
    }

    template <>
    CK_TILE_HOST_DEVICE void operator()<float, ck_tile::bf8_t>(float& y,
                                                               const ck_tile::bf8_t& x) const
    {
        y = type_convert<float>(x);
    }

    template <>
    CK_TILE_HOST_DEVICE void operator()<ck_tile::bf8_t, float>(ck_tile::bf8_t& y,
                                                               const float& x) const
    {
        y = type_convert<ck_tile::bf8_t>(x);
    }

    template <>
    CK_TILE_HOST_DEVICE void
    operator()<ck_tile::fp16_t, ck_tile::bf8_t>(ck_tile::fp16_t& y, const ck_tile::bf8_t& x) const
    {
        y = type_convert<ck_tile::fp16_t>(x);
    }

    template <>
    CK_TILE_HOST_DEVICE void
    operator()<ck_tile::bf8_t, ck_tile::fp16_t>(ck_tile::bf8_t& y, const ck_tile::fp16_t& x) const
    {
        y = ck_tile::type_convert<ck_tile::bf8_t>(x);
    }
};
#if 0
struct UnaryConvert
{
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
{
y = type_convert<Y>(x);
}
};
struct ConvertBF16RTN
{
// convert to bf16 using round to nearest (rtn)
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
{
// check Y datatype
static_assert(std::is_same_v<Y, ck_tile::bf16_t>, "Data type is not supported by this operation!");
// check X datatype
static_assert(std::is_same_v<X, float> || std::is_same_v<X, ck_tile::fp16_t>,
"Data type is not supported by this operation!");
y = bf16_convert_rtn<Y>(x);
}
};
struct ConvertF8SR
{
// convert to fp8 using stochastic rounding (SR)
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
{
// check Y datatype
static_assert(std::is_same_v<Y, ck_tile::fp8_t> || std::is_same_v<Y, ck_tile::bf8_t>,
"Data type is not supported by this operation!");
// check X datatype
static_assert(std::is_same_v<X, float> || std::is_same_v<X, ck_tile::fp16_t>,
"Data type is not supported by this operation!");
y = f8_convert_sr<Y>(x);
}
};
struct ConvertF8RNE
{
// convert to fp8 using rounding to nearest even
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
{
// check Y datatype
static_assert(std::is_same_v<Y, ck_tile::fp8_t> || std::is_same_v<Y, ck_tile::bf8_t>,
"Data type is not supported by this operation!");
// check X datatype
static_assert(std::is_same_v<X, float> || std::is_same_v<X, ck_tile::fp16_t>,
"Data type is not supported by this operation!");
y = f8_convert_rne<Y>(x);
}
};
#endif
// y = type_convert<Y>(float(x) * scale_). The specializations below pin the
// compute precision per type pair (fp16 multiplies in fp16; bf16 multiplies
// in float and rounds back).
struct Scale
{
    CK_TILE_HOST_DEVICE Scale(float scale = 1.f) : scale_(scale) {}

    template <typename Y, typename X>
    CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
    {
        y = ck_tile::type_convert<Y>(ck_tile::type_convert<float>(x) * scale_);
    }

    // fp16: scale converted once to fp16, multiply carried out in fp16
    template <>
    CK_TILE_HOST_DEVICE void
    operator()<ck_tile::fp16_t, ck_tile::fp16_t>(ck_tile::fp16_t& y, const ck_tile::fp16_t& x) const
    {
        y = ck_tile::type_convert<ck_tile::fp16_t>(scale_) * x;
    };

    // bf16: multiply carried out in float, result rounded back to bf16
    template <>
    CK_TILE_HOST_DEVICE void
    operator()<ck_tile::bf16_t, ck_tile::bf16_t>(ck_tile::bf16_t& y, const ck_tile::bf16_t& x) const
    {
        const float x_tmp = ck_tile::type_convert<float>(x);
        const float y_tmp = scale_ * x_tmp;
        y = ck_tile::type_convert<ck_tile::bf16_t>(y_tmp);
    };

    template <>
    CK_TILE_HOST_DEVICE void operator()<float, float>(float& y, const float& x) const
    {
        y = scale_ * x;
    };

    template <>
    CK_TILE_HOST_DEVICE void operator()<double, double>(double& y, const double& x) const
    {
        y = scale_ * x;
    };

    template <>
    CK_TILE_HOST_DEVICE void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const
    {
        y = ck_tile::type_convert<int8_t>(scale_ * ck_tile::type_convert<float>(x));
    };

    float scale_;
};
// y = scale_ * x, with NaN inputs mapped to -inf (float-only specialization).
struct ScaleAndResetNaNToMinusInfinity
{
    CK_TILE_HOST_DEVICE ScaleAndResetNaNToMinusInfinity(float scale) : scale_(scale) {}

    template <typename Y, typename X>
    CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const;

    template <>
    CK_TILE_HOST_DEVICE void operator()<float, float>(float& y, const float& x) const
    {
        y = ck_tile::isnan(x) ? -numeric<float>::infinity() : scale_ * x;
    };

    float scale_;
};

// y = x / divider_ (float / double / int32 only; int32 truncates).
struct UnaryDivide
{
    CK_TILE_HOST_DEVICE UnaryDivide(const int32_t divider = 1) : divider_(divider) {}

    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
                          std::is_same_v<T, int32_t>,
                      "Data type is not supported by this operation!");
        y = x / type_convert<T>(divider_);
    };

    int32_t divider_ = 1;
};

// y = x * x
struct UnarySquare
{
    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        static_assert(std::is_same_v<T, float> || std::is_same_v<T, ck_tile::fp16_t> ||
                          std::is_same_v<T, double> || std::is_same_v<T, int32_t> ||
                          std::is_same_v<T, int8_t>
#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
                          || std::is_same_v<T, int4_t>
#endif
                      ,
                      "Data type is not supported by this operation!");
        y = x * x;
    };
};

// y = |x|
struct UnaryAbs
{
    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
                          std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int32_t> ||
                          std::is_same_v<T, int8_t>,
                      "Data type is not supported by this operation!");
        y = ck_tile::abs(x);
    };
};

// y = sqrt(x)
struct UnarySqrt
{
    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        static_assert(std::is_same_v<T, float> || std::is_same_v<T, double>,
                      "Data type is not supported by this operation!");
        y = ck_tile::sqrt(x);
    };
};

// y = max(x, 0); the bf16 overload computes in float and rounds back.
struct Relu
{
    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
                          std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int32_t> ||
                          std::is_same_v<T, int8_t>,
                      "Data type is not supported by this operation!");
        y = x > 0 ? x : 0;
    }

    template <>
    CK_TILE_HOST_DEVICE void operator()(ck_tile::bf16_t& y, const ck_tile::bf16_t& x) const
    {
        float x_f32 = ck_tile::type_convert<float>(x);
        float y_f32 = x_f32 > 0 ? x_f32 : 0;
        y = ck_tile::type_convert<ck_tile::bf16_t>(y_f32);
    }
};
// Fast GeLU
// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
// host code use higher accuracy "exp" and "div"
// gpu code use lower accuracy "_ocml_exp_f32" and "rcp" function
struct FastGelu
{
    template <typename Y, typename X>
    CK_TILE_HOST void operator()(Y& y, const X& x) const;

    template <typename Y, typename X>
    CK_TILE_DEVICE void operator()(Y& y, const X& x) const;

    // host: the tanh form above is algebraically rewritten as
    // x / (1 + exp(u)) with u = -2*sqrt(2/pi)*(x + 0.044715*x^3)
    // (c1 = -2*0.044715*sqrt(2/pi), c2 = -2*sqrt(2/pi))
    template <>
    CK_TILE_HOST void operator()<float, float>(float& y, const float& x) const
    {
        // const float u = -2.f * x * (0.035677f * x * x + 0.797885f);
        const float c1  = -2.0 * 0.035677f;
        const float c2  = -2.0 * 0.797885f;
        const float u   = x * (c1 * x * x + c2);
        const float emu = exp(u);
        y = x / (1.f + emu);
    }

    // device code, use lower precision "__ocml_exp_f32" and "rcp"
    template <>
    CK_TILE_DEVICE void operator()<float, float>(float& y, const float& x) const
    {
        // const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
        const float c1  = -2.0 * 0.035677f;
        const float c2  = -2.0 * 0.797885f;
        const float u   = x * (c1 * x * x + c2);
        const float emu = __ocml_exp_f32(u);
        y = x * ck_tile::rcp(1.f + emu);
    }

    // the remaining overloads convert to float, run the float kernel above,
    // and convert the result back
    template <>
    CK_TILE_HOST void operator()<ck_tile::fp16_t, ck_tile::fp16_t>(ck_tile::fp16_t& y,
                                                                   const ck_tile::fp16_t& x) const
    {
        float y_f;
        this->operator()<float, float>(y_f, type_convert<float>(x));
        y = type_convert<ck_tile::fp16_t>(y_f);
    }

    template <>
    CK_TILE_DEVICE void operator()<ck_tile::fp16_t, ck_tile::fp16_t>(ck_tile::fp16_t& y,
                                                                     const ck_tile::fp16_t& x) const
    {
        float y_f;
        this->operator()<float, float>(y_f, type_convert<float>(x));
        y = type_convert<ck_tile::fp16_t>(y_f);
    }

    template <>
    CK_TILE_HOST void operator()<ck_tile::fp16_t, float>(ck_tile::fp16_t& y, const float& x) const
    {
        float y_f;
        this->operator()<float, float>(y_f, x);
        y = type_convert<ck_tile::fp16_t>(y_f);
    }

    template <>
    CK_TILE_DEVICE void operator()<ck_tile::fp16_t, float>(ck_tile::fp16_t& y, const float& x) const
    {
        float y_f;
        this->operator()<float, float>(y_f, x);
        y = type_convert<ck_tile::fp16_t>(y_f);
    }

    template <>
    CK_TILE_HOST void operator()<ck_tile::bf16_t, float>(ck_tile::bf16_t& y, const float& x) const
    {
        float y_f;
        this->operator()<float, float>(y_f, x);
        y = type_convert<ck_tile::bf16_t>(y_f);
    }

    template <>
    CK_TILE_DEVICE void operator()<ck_tile::bf16_t, float>(ck_tile::bf16_t& y, const float& x) const
    {
        float y_f;
        this->operator()<float, float>(y_f, x);
        y = type_convert<ck_tile::bf16_t>(y_f);
    }

    template <>
    CK_TILE_DEVICE void operator()<ck_tile::bf16_t, ck_tile::bf16_t>(ck_tile::bf16_t& y,
                                                                     const ck_tile::bf16_t& x) const
    {
        float y_f;
        this->operator()<float, float>(y_f, type_convert<float>(x));
        y = type_convert<ck_tile::bf16_t>(y_f);
    }

    template <>
    CK_TILE_HOST void operator()<ck_tile::bf16_t, ck_tile::bf16_t>(ck_tile::bf16_t& y,
                                                                   const ck_tile::bf16_t& x) const
    {
        float y_f;
        this->operator()<float, float>(y_f, type_convert<float>(x));
        y = type_convert<ck_tile::bf16_t>(y_f);
    }
};
// Exact GeLU (erf form, as opposed to FastGelu's tanh approximation):
// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+erf(x/sqrt(2)))
struct Gelu
{
    template <typename Y, typename X>
    CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const;

    template <>
    CK_TILE_HOST_DEVICE void operator()<float, float>(float& y, const float& x) const
    {
        // 0.70710678118f == 1/sqrt(2)
        y = 0.5f * x * (1.f + erf(float(0.70710678118f * x)));
    }

    template <>
    CK_TILE_HOST_DEVICE void
    operator()<ck_tile::fp16_t, ck_tile::fp16_t>(ck_tile::fp16_t& y, const ck_tile::fp16_t& x) const
    {
        y = ck_tile::fp16_t(0.5) * x *
            (ck_tile::fp16_t(1) + ck_tile::fp16_t(erf(float(0.70710678118f * x))));
    }
};
struct Sigmoid
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
constexpr T one = type_convert<T>(1);
y = one / (one + ck_tile::exp(-x));
};
};
struct Silu
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
constexpr T one = type_convert<T>(1);
y = x * (one / (one + ck_tile::exp(-x)));
};
};
struct TanH
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::tanh(x);
};
};
struct ACos
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::acos(x);
};
};
struct Neg
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::neg(x);
};
};
struct ATan
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::atan(x);
};
};
struct Sin
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::sin(x);
};
};
struct ASinH
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::asinh(x);
};
};
struct Cos
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::cos(x);
};
};
struct ACosH
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::acosh(x);
};
};
struct Tan
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::tan(x);
};
};
struct ATanH
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::atanh(x);
};
};
struct SinH
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::sinh(x);
};
};
struct Ceil
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::ceil(x);
};
};
struct Exp
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::exp(x);
};
};
struct CosH
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::cosh(x);
};
};
struct Floor
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::floor(x);
};
};
struct Log
{
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
std::is_same_v<T, ck_tile::fp16_t> || std::is_same_v<T, int8_t> ||
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
y = ck_tile::log(x);
};
};
// Element-wise arcsine: y = asin(x).
struct ASin
{
    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        constexpr bool type_ok = std::is_same_v<T, float> || std::is_same_v<T, double> ||
                                 std::is_same_v<T, ck_tile::fp16_t> ||
                                 std::is_same_v<T, int8_t> || std::is_same_v<T, int32_t>;
        static_assert(type_ok, "Data type is not supported by this operation!");
        y = ck_tile::asin(x);
    }
};
// Element-wise reciprocal: y = 1 / x.
struct Rcp
{
    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        constexpr bool type_ok = std::is_same_v<T, float> || std::is_same_v<T, double> ||
                                 std::is_same_v<T, ck_tile::fp16_t> ||
                                 std::is_same_v<T, int8_t> || std::is_same_v<T, int32_t>;
        static_assert(type_ok, "Data type is not supported by this operation!");
        y = ck_tile::rcp(x);
    }
};
// Swish activation: y = x * sigmoid(beta * x) = x / (1 + exp(-beta * x)).
struct Swish
{
    Swish(float beta = 1.0f) : beta_(beta) {}

    template <typename Y, typename X>
    CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
    {
        constexpr bool x_ok = std::is_same_v<X, float> || std::is_same_v<X, double> ||
                              std::is_same_v<X, ck_tile::fp16_t>;
        constexpr bool y_ok = std::is_same_v<Y, float> || std::is_same_v<Y, double> ||
                              std::is_same_v<Y, ck_tile::fp16_t>;
        static_assert(x_ok, "Data type is not supported by this operation!");
        static_assert(y_ok, "Data type is not supported by this operation!");
        // exponent is evaluated in fp32, then the result is converted to Y
        const float neg_bx = -beta_ * type_convert<float>(x);
        y = type_convert<Y>(x / (1.f + ck_tile::exp(neg_bx)));
    }

    const float beta_;
};
// Scaled softplus: y = log(1 + exp(alpha * x)) / alpha.
struct SoftRelu
{
    SoftRelu(float alpha = 1.f) : alpha_(alpha) {}

    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        constexpr bool type_ok = std::is_same_v<T, float> || std::is_same_v<T, double> ||
                                 std::is_same_v<T, ck_tile::fp16_t> ||
                                 std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>;
        static_assert(type_ok, "Data type is not supported by this operation!");
        const T a       = type_convert<T>(alpha_);
        constexpr T one = type_convert<T>(1);
        y = ck_tile::log(one + ck_tile::exp(x * a)) / a;
    }

    const float alpha_;
};
// Power activation: y = (alpha + beta * x) ^ gamma.
struct Power
{
    Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f)
        : alpha_(alpha), beta_(beta), gamma_(gamma)
    {
    }

    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        constexpr bool type_ok = std::is_same_v<T, float> || std::is_same_v<T, double> ||
                                 std::is_same_v<T, ck_tile::fp16_t> ||
                                 std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>;
        static_assert(type_ok, "Data type is not supported by this operation!");
        const T a = type_convert<T>(alpha_);
        const T b = type_convert<T>(beta_);
        const T g = type_convert<T>(gamma_);
        y = ck_tile::pow(a + b * x, g);
    }

    const float alpha_;
    const float beta_;
    const float gamma_;
};
// Clipped ReLU: y = min(beta, max(alpha, x)), i.e. clamp x into [alpha, beta].
struct ClippedRelu
{
    ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta) {}

    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        constexpr bool type_ok = std::is_same_v<T, float> || std::is_same_v<T, double> ||
                                 std::is_same_v<T, ck_tile::fp16_t> ||
                                 std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>;
        static_assert(type_ok, "Data type is not supported by this operation!");
        const T lo = type_convert<T>(alpha_);
        const T hi = type_convert<T>(beta_);
        y = ck_tile::min(hi, ck_tile::max(lo, x));
    }

    const float alpha_;
    const float beta_;
};
// Leaky ReLU: y = x for x >= 0, alpha * x otherwise.
struct LeakyRelu
{
    LeakyRelu(float alpha = 0.01f) : alpha_(alpha) {}

    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        constexpr bool type_ok = std::is_same_v<T, float> || std::is_same_v<T, double> ||
                                 std::is_same_v<T, ck_tile::fp16_t> ||
                                 std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>;
        static_assert(type_ok, "Data type is not supported by this operation!");
        const T slope = type_convert<T>(alpha_);
        y = (x >= 0) ? x : slope * x;
    }

    const float alpha_;
};
// ELU: y = x for x > 0, alpha * (exp(x) - 1) otherwise (expm1 for accuracy near 0).
struct Elu
{
    Elu(float alpha = 1.f) : alpha_(alpha) {}

    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        constexpr bool type_ok = std::is_same_v<T, float> || std::is_same_v<T, double> ||
                                 std::is_same_v<T, ck_tile::fp16_t> ||
                                 std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>;
        static_assert(type_ok, "Data type is not supported by this operation!");
        const T a = type_convert<T>(alpha_);
        y = (x > 0) ? x : a * ck_tile::expm1(x);
    }

    const float alpha_;
};
// Scaled logistic: y = alpha / (1 + alpha * exp(-x)).
struct Logistic
{
    Logistic(float alpha = 1.f) : alpha_(alpha) {}

    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
    {
        constexpr bool type_ok = std::is_same_v<T, float> || std::is_same_v<T, double> ||
                                 std::is_same_v<T, ck_tile::fp16_t> ||
                                 std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>;
        static_assert(type_ok, "Data type is not supported by this operation!");
        const T a       = type_convert<T>(alpha_);
        constexpr T one = type_convert<T>(1);
        y = a / (one + ck_tile::exp(-x) * a);
    }

    const float alpha_;
};
// Dequant-style epilogue for conv: e = c / (scale_in * scale_wei * scale_out),
// converted to fp8. Only the (fp8_t, float) combination is implemented.
struct ConvInvscale
{
    CK_TILE_HOST_DEVICE
    ConvInvscale(float scale_in = 1.f, float scale_wei = 1.f, float scale_out = 1.f)
        : scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out)
    {
    }

    // primary template intentionally left undefined; unsupported (E, C)
    // combinations fail at link/instantiation time
    template <typename E, typename C>
    CK_TILE_HOST_DEVICE void operator()(E& e, const C& c) const;

    // fp8 output from float accumulator
    template <>
    CK_TILE_HOST_DEVICE void operator()<ck_tile::fp8_t, float>(ck_tile::fp8_t& e,
                                                               const float& c) const
    {
        e = type_convert<ck_tile::fp8_t>(c / scale_in_ / scale_wei_ / scale_out_);
    };

    float scale_in_;
    float scale_wei_;
    float scale_out_;
};
// Quant-style epilogue for conv: e = c * scale_in * scale_wei * scale_out,
// converted to fp8. Only the (fp8_t, float) combination is implemented.
struct ConvScale
{
    CK_TILE_HOST_DEVICE
    ConvScale(float scale_in = 1.f, float scale_wei = 1.f, float scale_out = 1.f)
        : scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out)
    {
    }

    // primary template intentionally left undefined; unsupported (E, C)
    // combinations fail at link/instantiation time
    template <typename E, typename C>
    CK_TILE_HOST_DEVICE void operator()(E& e, const C& c) const;

    // fp8 output from float accumulator
    template <>
    CK_TILE_HOST_DEVICE void operator()<ck_tile::fp8_t, float>(ck_tile::fp8_t& e,
                                                               const float& c) const
    {
        e = type_convert<ck_tile::fp8_t>(c * scale_in_ * scale_wei_ * scale_out_);
    };

    float scale_in_;
    float scale_wei_;
    float scale_out_;
};
// Conv epilogue with activation: e = relu(c * scale_in * scale_wei) * scale_out,
// converted to fp8. Only the (fp8_t, float) combination is implemented.
struct ConvScaleRelu
{
    CK_TILE_HOST_DEVICE
    ConvScaleRelu(float scale_in = 1.f, float scale_wei = 1.f, float scale_out = 1.f)
        : scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out)
    {
    }

    // primary template intentionally left undefined; unsupported (E, C)
    // combinations fail at link/instantiation time
    template <typename E, typename C>
    CK_TILE_HOST_DEVICE void operator()(E& e, const C& c) const;

    // fp8 output from float accumulator: relu is applied before the output scale
    template <>
    CK_TILE_HOST_DEVICE void operator()<ck_tile::fp8_t, float>(ck_tile::fp8_t& e,
                                                               const float& c) const
    {
        float x;
        Relu{}.template operator()<float>(x, c * scale_in_ * scale_wei_);
        e = type_convert<ck_tile::fp8_t>(x * scale_out_);
    };

    float scale_in_;
    float scale_wei_;
    float scale_out_;
};
// Element-wise type cast: y = type_convert<DstType>(x).
// NOTE(review): the inner template parameter T is unused and cannot be deduced
// from the arguments — presumably kept so the call-site shape matches the
// other element-wise functors; confirm how callers invoke this.
template <typename DstType, typename SrcType>
struct Cast
{
    template <typename T>
    CK_TILE_HOST_DEVICE void operator()(DstType& y, const SrcType& x) const
    {
        y = ck_tile::type_convert<DstType>(x);
    };
};
// support fastconvert of int8 to fp16
// (disabled) fast uint8 -> fp16 vector converter: packs bytes next to the fp16
// exponent bias via byte-permute, then subtracts the magic constant with a
// packed fp16 add — kept for reference, not currently compiled
#if 0
template <typename InputDataType, typename OutputDataType, index_t RegPackNumber>
struct FastNumericArrayConverter
{
};

template <>
struct FastNumericArrayConverter<uint8_t, ck_tile::fp16_t, 4>
{
    using InputArray  = vector_type<uint8_t, 4>;
    using OutputArray = vector_type<ck_tile::fp16_t, 4>;

    CK_TILE_DEVICE static OutputArray convert(InputArray const& Input)
    {
        OutputArray Output;
        uint32_t* half_2 = reinterpret_cast<uint32_t*>(&Output);
        uint32_t const uint8_4 = reinterpret_cast<uint32_t const&>(Input);
        static constexpr uint32_t byte_selector_01 = 0x05010500;
        static constexpr uint32_t byte_selector_23 = 0x05030502;
        static constexpr uint32_t fp16_adder = 0x64646464;
        half_2[0] = __builtin_amdgcn_perm(fp16_adder, uint8_4, byte_selector_01);
        half_2[1] = __builtin_amdgcn_perm(fp16_adder, uint8_4, byte_selector_23);
        static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;
        asm volatile("v_pk_add_f16 %0, %1, %2 neg_lo:[0,1] neg_hi:[0,1]"
                     : "=v"(half_2[0])
                     : "v"(half_2[0]), "s"(I8s_TO_F16s_MAGIC_NUM));
        asm volatile("v_pk_add_f16 %0, %1, %2 neg_lo:[0,1] neg_hi:[0,1]"
                     : "=v"(half_2[1])
                     : "v"(half_2[1]), "s"(I8s_TO_F16s_MAGIC_NUM));
        return Output;
    }

    CK_TILE_DEVICE OutputArray operator()(InputArray const& Input) { return convert(Input); }
};

template <index_t N>
struct FastNumericArrayConverter<uint8_t, ck_tile::fp16_t, N>
{
    static constexpr int VEC_WIDTH = 4;
    static_assert(!(N % VEC_WIDTH), "N must be multiple of 4.");

    using InputArray  = vector_type<uint8_t, N>;
    using OutputArray = vector_type<ck_tile::fp16_t, N>;

    CK_TILE_DEVICE static OutputArray convert(InputArray const& Input)
    {
        FastNumericArrayConverter<uint8_t, ck_tile::fp16_t, 4> converter;
        OutputArray Output;
        using Vec_InputArray  = vector_type<uint8_t, 4>;
        using Vec_OutputArray = vector_type<ck_tile::fp16_t, 4>;
        Vec_OutputArray* half_4_ptr = reinterpret_cast<Vec_OutputArray*>(&Output);
        Vec_InputArray const* uint8_4_ptr = reinterpret_cast<Vec_InputArray const*>(&Input);
        static_for<0, N / VEC_WIDTH, 1>{}(
            [&](auto i) { half_4_ptr[i] = converter(uint8_4_ptr[i]); });
        return Output;
    }

    CK_TILE_DEVICE OutputArray operator()(InputArray const& Input) { return convert(Input); }
};
#endif
} // namespace element_wise
} // namespace ck_tile
......@@ -334,7 +334,7 @@ struct BlockFmhaPipelineQRKSVSAsync
move_tile_window(k_dram_window, {0, kK0});
__builtin_amdgcn_sched_barrier(0);
buffer_load_fence(k_dram_window.get_num_access(), q.get_thread_buffer());
buffer_load_fence(k_dram_window.get_num_of_access(), q.get_thread_buffer());
(void)q_element_func; // ??? rocm-6.x if use q element func will have scratch on hdim=64/32
// auto q_tile = q; // tile_elementwise_in(q_element_func, q);
......@@ -359,7 +359,7 @@ struct BlockFmhaPipelineQRKSVSAsync
if constexpr(i_k0 < k0_loops - 1)
move_tile_window(k_dram_window, {0, kK0});
async_load_fence(k_dram_window.get_num_access());
async_load_fence(k_dram_window.get_num_of_access());
__builtin_amdgcn_s_barrier();
__builtin_amdgcn_sched_barrier(0);
gemm_0(s_acc,
......
......@@ -4,9 +4,14 @@
#pragma once
#include "ck_tile/core.hpp"
#include <tuple>
namespace ck_tile {
/*
* TODO: block_tile_reduce_sync() currently has a limitation
* Y dim must have at least one dim not been reduced
*/
// synchronize reduce result (cross lane reduction and broadcast on replicated dimension)
template <typename AccDistributedTensor_, typename ReduceFunc, bool WithBroadcast = true>
CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor,
......@@ -104,6 +109,65 @@ CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor,
});
}
/*
 * this version is faster, using xor to do reduce, no need broadcast anymore
 * TODO: the limitation is to-be-reduced P dim can only mapping to one R dim?
 *
 * Butterfly cross-lane reduction: for every R (replication) dimension that the
 * lane-level P dimension owns, exchange values with lanes whose id differs by
 * one XOR'd bit per stage, reducing pairwise. After log2(r_length) stages every
 * participating lane holds the full reduction, so no separate broadcast step
 * is needed.
 */
template <typename AccDistributedTensor_, typename ReduceFunc>
CK_TILE_DEVICE void block_tile_reduce_xor_sync(AccDistributedTensor_& acc_tensor,
                                               const ReduceFunc& reduce_func)
{
    using Dstr             = typename AccDistributedTensor_::StaticTileDistribution;
    using DstrEncode       = typename Dstr::DstrEncode;
    using DstrEncodeDetail = typename DstrEncode::detail;

    constexpr index_t NDimP = Dstr::get_num_of_dimension_p();
    constexpr index_t NDimR = Dstr::get_num_of_dimension_r();

    // the last P dimension is the one mapped onto lanes within a warp
    constexpr index_t idim_p_lane = NDimP - 1;

    constexpr index_t thread_buf_size = AccDistributedTensor_::get_thread_buffer_size();

    // loop over thread data
    static_for<0, thread_buf_size, 1>{}([&](auto i) {
        auto v_local = acc_tensor.get_thread_buffer()[i];

        // cross-lane reduce for replication
        // only reduce on R dimension correspond to lane
        // (lane id maps to this R dimension)
        static_for<0, NDimR, 1>{}([&](auto idim_r) {
            // FIXME: nasty to use does_p_own_r_
            if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r])
            {
                constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];

                // stride (in lanes) of one step along this R dimension
                constexpr index_t lid_over_rid_derivative =
                    DstrEncodeDetail::ps_over_rs_derivative_[idim_p_lane][idim_r];

                static_assert(is_power_of_two_integer(r_length),
                              "wrong! only support power of 2 reduction");

                constexpr index_t nstage = integer_log2_floor(r_length);

                // reduction sweep forward
                static_for<0, nstage, 1>{}([&](auto istage) {
                    // xor: partner lane for this butterfly stage
                    index_t src_lane =
                        __lane_id() ^ (number<lid_over_rid_derivative << istage.value>{}.value);

                    // pull data from remote lane
                    const auto v_remote = warp_shuffle(v_local, src_lane);

                    // reduce
                    v_local = reduce_func(v_local, v_remote);
                });
            }
        });

        acc_tensor.get_thread_buffer()(i) = v_local;
    });
}
// FIXME: this is for 2D to 1D reduce only, need to support n-D
template <typename AccDistributedTensor_,
typename InDistributedTensor_,
......@@ -175,6 +239,10 @@ CK_TILE_DEVICE void block_tile_reduce(AccDistributedTensor_& acc_tensor,
#endif
}
/*
* TODO: block_tile_reduce() currently has a limitation
* Y dim must have at least one dim not been reduced
*/
template <typename AccDataType_,
typename InDistributedTensor_,
index_t... InReduceDims,
......@@ -208,4 +276,106 @@ CK_TILE_DEVICE auto block_tile_reduce(const InDistributedTensor_& in_tensor,
return acc_tensor;
}
// this version only support 2D->1D reduce (reduce-dim=seq<0, 1>)
// this version only support in/acc/out datatypes are the same
// this version will call thread/warp+sync in one function call
//
template <typename InDistributedTensor_>
struct BlockReduce2D
{
    using InDistributedTensor = remove_cvref_t<InDistributedTensor_>;
    using InDataType          = typename InDistributedTensor::DataType;

    // t: tile to be reduced; reduce_init: identity element of the reduction
    // (e.g. -inf for max, 0 for sum)
    CK_TILE_HOST_DEVICE BlockReduce2D(const InDistributedTensor& t_, const InDataType& reduce_init_)
        : t(t_), reduce_init(reduce_init_)
    {
    }

    // build the 1D destination tile by removing the reduced dim (dim=1) from
    // the input tile distribution
    CK_TILE_HOST_DEVICE constexpr auto MakeDstBlockTile() const
    {
        using ReduceDim = sequence<1>; // hard coded

        constexpr auto acc_dstr =
            make_static_tile_distribution(ck_tile::detail::make_reduce_tile_distribution_encoding(
                InDistributedTensor::get_tile_distribution()
                    .get_static_tile_distribution_encoding(),
                ReduceDim{}));

        return make_static_distributed_tensor<InDataType>(acc_dstr);
    }

    // return number of pixels each lane need to reduce
    // (bugfix: this previously computed nothing and implicitly returned void;
    // now returns the product of the Y-span lengths along dim=1)
    CK_TILE_HOST_DEVICE constexpr auto get_reduce_length_y() const
    {
        constexpr auto spans         = InDistributedTensor::get_distributed_spans();
        constexpr auto row_y_lengths = typename decltype(spans[number<1>{}])::Impl{};
        return reduce_on_sequence(row_y_lengths, multiplies{}, number<1>{});
    }

    // Here ReducePacksPerXDim is not the same meaning as that in static_uford/sweep_tile_uspan
    // this is number of packs along the X-dim. We need to compute the Unpacks along the Y dim
    // internally
    // For simplicity, we just support along the row dimension, ReducePacksPerXDim is always 2
    // element , and the first element is always ignored For simplicity, will always try from
    // right-to-left to find alone which Y dim to split
    template <typename ReduceFunc,
              typename ReduceSyncFunc,
              typename ReducePacksPerXDim = uniform_sequence_gen_t<2, 1>>
    CK_TILE_HOST_DEVICE auto operator()(const ReduceFunc& reduce_func,
                                        const ReduceSyncFunc& reduce_sync_func,
                                        ReducePacksPerXDim = {}) const
    {
        constexpr auto spans = InDistributedTensor::get_distributed_spans();

        // compute how the dim-1 Y span is sliced so the reduce lambda can be
        // fed row_y_packs elements at a time
        constexpr auto row_y_unpacks = [&]() {
            constexpr auto row_y_lengths = typename decltype(spans[number<1>{}])::Impl{};
            constexpr auto row_y_size =
                reduce_on_sequence(row_y_lengths, multiplies{}, number<1>{});
            constexpr auto row_y_packs = ReducePacksPerXDim{}.at(number<1>{});
            static_assert(row_y_size % row_y_packs == 0);
            constexpr auto row_y_slice_size = row_y_size / row_y_packs;
            constexpr auto slice_info = slice_sequence(row_y_lengths, number<row_y_slice_size>{});
            constexpr auto unpacks    = slice_info[number<1>{}];
            return unpacks;
        }();

        auto acc_tensor = MakeDstBlockTile();

        // in-thread reduction
        // FIXME: hard coded to be 2D to 1D reduction
        sweep_tile_span(spans[number<0>{}], [&](auto dstr_idx_i0) {
            constexpr auto acc_dstr_idx = make_tuple(dstr_idx_i0);

            // seed the accumulator from the user-provided identity element
            // (bugfix: previously read the *uninitialized* acc_tensor buffer;
            // the stored reduce_init member was never used)
            InDataType acc = reduce_init;

            sweep_tile_uspan(
                spans[number<1>{}],
                [&](auto... dstr_idx_i1) {
                    acc = reduce_func(acc, t[make_tuple(dstr_idx_i0, dstr_idx_i1)]...);
                },
                row_y_unpacks);

            acc_tensor(acc_dstr_idx) = acc;
        });

        // TODO: always use xor to do cross-lane reduce
        block_tile_reduce_xor_sync(acc_tensor, reduce_sync_func);

        return acc_tensor;
    }

    // convenience overload: same functor for in-thread and cross-lane reduce
    template <typename ReduceFunc>
    CK_TILE_HOST_DEVICE auto operator()(const ReduceFunc& reduce_func) const
    {
        return operator()(reduce_func, reduce_func);
    }

    InDistributedTensor t;
    InDataType reduce_init;
};
// deduction guide: lets `BlockReduce2D{tensor, init}` deduce the tensor type
template <typename T>
CK_TILE_HOST_DEVICE_EXTERN BlockReduce2D(const T&, const typename T::DataType&)->BlockReduce2D<T>;
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/softmax/block/block_softmax_2d.hpp"
#include "ck_tile/ops/softmax/block/block_softmax_2d_problem.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/reduce.hpp"
#define _BLOCK_SOFTMAX_USE_UNPACK2 0
namespace ck_tile {
/*
simple 2d softmax implementation, along row (dim=1)
requirement:
1). each row is within a warp
2). data type must be a dword
*/
template <typename Problem_, typename Policy_ = void>
struct BlockSoftmax2D
{
    using Problem  = remove_cvref_t<Problem_>;
    using Policy   = remove_cvref_t<Policy_>;
    using DataType = typename Problem::DataType;

    // Numerically-stable softmax along dim=1 (row): y = exp(x - rowmax) / rowsum.
    // Row max/sum are computed with BlockReduce2D (thread + xor cross-lane reduce).
    template <typename DistributedTensor, index_t dim = 1>
    CK_TILE_DEVICE void
    operator()(const DistributedTensor& x, DistributedTensor& y, number<dim> = {})
    {
        const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
        const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
#if _BLOCK_SOFTMAX_USE_UNPACK2
        // 3-operand variants so the unpacked reduce can consume 2 inputs + acc
        // per step; v_max3_f32 does the 3-way max in one instruction
        const auto f_max3 = [](auto e0, auto e1, auto e2) {
            float rtn;
            asm volatile("v_max3_f32 %0, %1, %2, %3" : "=v"(rtn) : "v"(e0), "v"(e1), "v"(e2));
            return rtn;
        };
        const auto f_sum3 = [](auto e0, auto e1, auto e2) { return e0 + e1 + e2; };
#endif

        // compute row max
        auto reduce_row_max = BlockReduce2D{x, -numeric<DataType>::infinity()};
#if _BLOCK_SOFTMAX_USE_UNPACK2
        auto row_max = reduce_row_max(f_max3, f_max, sequence<1, 2>{});
#else
        auto row_max = reduce_row_max(f_max);
#endif

        // subtract row max before exp for numerical stability
        sweep_tile<DistributedTensor>([&](auto idx) {
            constexpr auto row_id = make_tuple(idx[number<0>{}]);
            y(idx) = exp(x[idx] - row_max[row_id]);
        });

        // compute row sum
        auto reduce_row_sum = BlockReduce2D<decltype(y)>{y, DataType{0}};
#if _BLOCK_SOFTMAX_USE_UNPACK2
        auto row_sum = reduce_row_sum(f_sum3, f_sum, sequence<1, 2>{});
#else
        auto row_sum = reduce_row_sum(f_sum);
#endif

        // reciprocal (one divide per row, then multiply per element)
        auto r = make_static_distributed_tensor<DataType>(row_sum.get_tile_distribution());
        sweep_tile(row_sum, [&](auto idx) { r(idx) = DataType{1} / row_sum(idx); });

        // scale
        sweep_tile<DistributedTensor>([&](auto idx) {
            constexpr auto row_id = make_tuple(idx[number<0>{}]);
            y(idx) = y(idx) * r(row_id);
        });
    }

    // out-of-place convenience overload: returns a fresh tensor
    template <typename DistributedTensor, index_t dim = 1>
    CK_TILE_DEVICE decltype(auto) operator()(const DistributedTensor& x, number<dim> = {})
    {
        auto y = DistributedTensor{}; // distributed tensor
        operator()(x, y, number<dim>{});
        return y;
    }
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
// Problem descriptor for BlockSoftmax2D: only carries the element data type.
template <typename DataType_>
struct BlockSoftmax2DProblem
{
    using DataType = remove_cvref_t<DataType_>;
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/topk/block/block_topk_stream_2d.hpp"
#include "ck_tile/ops/topk/block/block_topk_stream_2d_problem.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
/*
simple 2d topk implementation, along row (dim=1)
requirement:
1). each row is within a warp
*/
template <typename Problem_, typename Policy_ = void>
struct BlockTopkStream2D
{
    using Problem   = remove_cvref_t<Problem_>;
    using Policy    = remove_cvref_t<Policy_>;
    using DataType  = typename Problem::DataType;
    using IndexType = typename Problem::IndexType;

    // TODO: if DataType is subdword, need pack into single dword to use argmax
    // (value, source-column) pair carried through the argmax reduction
    struct ArgmaxPacket
    {
        DataType arg;
        index_t value; // column index; converted to IndexType on store
    };

    // Streaming top-k along dim=1: runs k rounds of row-wise argmax, writes the
    // winning (value, index) to out/idx windows (advancing one column per
    // round), and masks the winner with -inf so the next round finds the next
    // largest. Ties resolve via `>` — first occurrence wins in each pairwise
    // compare (NOTE(review): exact tie order across lanes depends on the
    // reduction tree; confirm if strict ordering matters).
    template <typename DistributedTensor, typename OutWindow, typename IdxWindow, index_t dim = 1>
    CK_TILE_DEVICE void operator()(const DistributedTensor& x,
                                   const OutWindow& out_window,
                                   const IdxWindow& idx_window,
                                   index_t k,
                                   number<dim> = {})
    {
        // local copies so the windows can be advanced per round
        OutWindow out_window_tmp = out_window;
        IdxWindow idx_window_tmp = idx_window;
        static_assert(
            std::is_same_v<typename DistributedTensor::DataType, typename OutWindow::DataType> &&
            std::is_same_v<typename DistributedTensor::DataType, DataType>);
        static_assert(std::is_same_v<typename IdxWindow::DataType, IndexType>);
        DistributedTensor x_tmp = x; // mutable copy; winners get masked to -inf
        constexpr auto dst_dist = typename IdxWindow::TileDstr{};

        // argmax for topk
        const auto f_argmax = [](ArgmaxPacket e0, ArgmaxPacket e1) {
            return e0.arg > e1.arg ? e0 : e1;
        };

        for(index_t i_k = 0; i_k < k; i_k++)
        {
            constexpr auto span_2d = DistributedTensor::get_distributed_spans();

            // pack each element with its global column index
            auto packet = [&]() {
                auto tmp = make_static_distributed_tensor<ArgmaxPacket>(x.get_tile_distribution());
                sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
                    sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
                        const auto tile_idx = get_x_indices_from_distributed_indices(
                            tmp.get_tile_distribution(), make_tuple(idx0, idx1));
                        constexpr auto i_j_idx = make_tuple(idx0, idx1);
                        ArgmaxPacket t;
                        t.arg = x_tmp(i_j_idx); // !!! we reference x here
                        t.value = tile_idx.at(number<1>{});
                        tmp(i_j_idx) = t;
                    });
                });
                return tmp;
            }();

            // in-thread argmax along dim=1, then cross-lane xor reduce
            auto argmax_init = ArgmaxPacket{-numeric<DataType>::infinity(), 0};
            auto r = block_tile_reduce<ArgmaxPacket>(packet, sequence<1>{}, f_argmax, argmax_init);
            block_tile_reduce_xor_sync(r, f_argmax);

            // unpack the winning value/index into output-shaped tiles
            auto o = make_static_distributed_tensor<DataType>(dst_dist);
            auto i = make_static_distributed_tensor<IndexType>(dst_dist);
            sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
                sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
                    constexpr auto i_j_idx = make_tuple(idx0, idx1);
                    ArgmaxPacket tmp = r(i_j_idx);
                    o(i_j_idx) = tmp.arg;
                    i(i_j_idx) = tmp.value;
                });
            });

            // update value: mask this round's winner with -inf so it cannot win again
            sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
                sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
                    const auto tile_idx = get_x_indices_from_distributed_indices(
                        x.get_tile_distribution(), make_tuple(idx0, idx1));
                    auto col_id = tile_idx.at(number<1>{});
                    constexpr auto i_j_idx = make_tuple(idx0, idx1);
                    x_tmp(i_j_idx) = (col_id == r(i_j_idx).value) ? -numeric<DataType>::infinity()
                                                                  : x_tmp(i_j_idx);
                });
            });

            // after the reduce every lane of a row holds the same result;
            // only one lane per row (lane % ColLanes == 0) writes it out
            if(threadIdx.x % Problem::ColLanes == 0)
            {
                store_tile(out_window_tmp, o);
                store_tile(idx_window_tmp, i);
            }
            // advance one column for the next top-k element
            move_tile_window(out_window_tmp, {number<0>{}, number<1>{}});
            move_tile_window(idx_window_tmp, {number<0>{}, number<1>{}});
        }
    }
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
/*
simple 2d topk implementation, along row (dim=1)
requirement:
1). each row is within a warp
*/
// Problem descriptor for BlockTopkStream2D: element type, index type, and the
// number of lanes that share one row (used to pick the single writer lane).
template <typename DataType_, typename IndexType_, index_t ColLanes_>
struct BlockTopkStream2DProblem
{
    using DataType  = remove_cvref_t<DataType_>;
    using IndexType = remove_cvref_t<IndexType_>;
    static constexpr index_t ColLanes = ColLanes_;
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp"
#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_pipeline.hpp"
#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp"
#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_problem.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/elementwise.hpp"
#include "ck_tile/host/hip_check_error.hpp"
#include <string>
#include <type_traits>
namespace ck_tile {
// Host-side argument bundle for the topk-softmax kernel; copied field-for-field
// into the device Kargs by TopkSoftmaxKernel::MakeKargs.
struct TopkSoftmaxHostArgs
{
    const void* p_input;   // [num_rows, num_experts] input values
    void* p_output;        // [num_rows, topk] softmax weights of the selected experts
    void* p_indices;       // [num_rows, topk] indices of the selected experts
    index_t num_rows;
    index_t num_experts;
    index_t topk;
    index_t stride_input;  // row stride for input, at least num_experts
    index_t stride_output; // row stride for output/indices, at least topk
};
// Kernel wrapper for topk-softmax: builds padded input/output/index tile
// windows for this block's rows and hands them to the pipeline.
template <typename Pipeline_>
struct TopkSoftmaxKernel
{
    using Pipeline   = remove_cvref_t<Pipeline_>;
    using Problem    = remove_cvref_t<typename Pipeline::Problem>;
    using InputType  = typename Problem::InputType;
    using WeightType = typename Problem::WeightType;
    using IndexType  = typename Problem::IndexType;

    // device-side copy of TopkSoftmaxHostArgs
    struct TopkSoftmaxKargs
    {
        const void* p_input;
        void* p_output;
        void* p_indices;
        index_t num_rows;
        index_t num_experts;
        index_t topk;
        index_t stride_input;  // row stride for input, at least experts
        index_t stride_output; // row stride for output/indices, at least topk
    };

    using Kargs = TopkSoftmaxKargs;
    using Hargs = TopkSoftmaxHostArgs;

    // grid size: persistent mode launches LaunchType blocks per CU; otherwise
    // one block per RowsPerBlock rows
    CK_TILE_HOST static constexpr auto GridSize(const Hargs& h)
    {
        if constexpr(Problem::LaunchType > 0)
        {
            int num_cu = [&]() {
                hipDeviceProp_t dev_prop;
                hipDevice_t dev;
                HIP_CHECK_ERROR(hipGetDevice(&dev));
                HIP_CHECK_ERROR(hipGetDeviceProperties(&dev_prop, dev));
                return dev_prop.multiProcessorCount;
            }();
            return dim3(num_cu * Problem::LaunchType);
        }
        else
        {
            const int num_warps = (h.num_rows + Problem::RowsPerWarp - 1) / Problem::RowsPerWarp;
            const int num_blocks =
                (num_warps + Problem::WarpsPerBlock - 1) / Problem::WarpsPerBlock;
            return dim3(num_blocks);
        }
    }

    // copy host args into the POD kernel-argument struct
    CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
    {
        Kargs k;
        k.p_input       = h.p_input;
        k.p_output      = h.p_output;
        k.p_indices     = h.p_indices;
        k.num_rows      = h.num_rows;
        k.num_experts   = h.num_experts;
        k.topk          = h.topk;
        k.stride_input  = h.stride_input;
        k.stride_output = h.stride_output;
        return k;
    }

    CK_TILE_HOST_DEVICE static constexpr auto BlockSize() { return Problem::BlockSize; }

    CK_TILE_DEVICE void operator()(Kargs kargs) const
    {
        index_t block_row_id = static_cast<index_t>(blockIdx.x * Problem::RowsPerBlock);
        // bugfix: was `>`; when block_row_id == num_rows this block has zero
        // rows left and must exit instead of building empty windows
        if(block_row_id >= kargs.num_rows)
            return;

        // scalarize the per-block offsets/remainder (uniform across the wave)
        index_t block_os_inp = __builtin_amdgcn_readfirstlane(block_row_id * kargs.stride_input);
        index_t block_os_out = __builtin_amdgcn_readfirstlane(block_row_id * kargs.stride_output);
        index_t num_rows_rem = __builtin_amdgcn_readfirstlane(kargs.num_rows - block_row_id);

        // [num_rows_rem, num_experts] input, padded up to the tile shape
        const auto input_window = [&]() {
            const InputType* p_input =
                reinterpret_cast<const InputType*>(kargs.p_input) + block_os_inp;
            auto tmp = make_naive_tensor_view<address_space_enum::global>(
                p_input,
                make_tuple(num_rows_rem, kargs.num_experts),
                make_tuple(kargs.stride_input, 1),
                number<Problem::VectorSize>{},
                number<1>{});
            auto view = pad_tensor_view(
                tmp,
                make_tuple(number<Problem::RowsPerBlock>{}, number<Problem::Experts>{}),
                sequence<0, 1>{}); // out-most dim no need pad(leverage oob)
            return make_tile_window(
                view,
                make_tuple(number<Problem::RowsPerBlock>{}, number<Problem::Experts>{}),
                {0, 0});
        }();

        // [num_rows_rem, topk] weight output, one column written per topk round
        auto output_window = [&]() {
            WeightType* p_output = reinterpret_cast<WeightType*>(kargs.p_output) + block_os_out;
            auto tmp = make_naive_tensor_view<address_space_enum::global>(
                p_output,
                make_tuple(num_rows_rem, kargs.topk),
                make_tuple(kargs.stride_output, 1),
                number<Problem::VectorSize>{},
                number<1>{});
            auto view =
                pad_tensor_view(tmp,
                                make_tuple(number<Problem::RowsPerBlock>{}, number<1>{}),
                                sequence<0, 0>{}); // 1. out-most dim no need pad(leverage oob)
                                                   // 2. we loop over topk 1-1, no need padding
            return make_tile_window(
                view, make_tuple(number<Problem::RowsPerBlock>{}, number<1>{}), {0, 0});
        }();

        // [num_rows_rem, topk] index output, same layout as the weight output
        auto indices_window = [&]() {
            IndexType* p_indices = reinterpret_cast<IndexType*>(kargs.p_indices) + block_os_out;
            auto tmp = make_naive_tensor_view<address_space_enum::global>(
                p_indices,
                make_tuple(num_rows_rem, kargs.topk),
                make_tuple(kargs.stride_output, 1),
                number<Problem::VectorSize>{},
                number<1>{});
            auto view =
                pad_tensor_view(tmp,
                                make_tuple(number<Problem::RowsPerBlock>{}, number<1>{}),
                                sequence<0, 0>{}); // 1. out-most dim no need pad(leverage oob)
                                                   // 2. we loop over topk 1-1, no need padding
            return make_tile_window(
                view, make_tuple(number<Problem::RowsPerBlock>{}, number<1>{}), {0, 0});
        }();

        Pipeline{}(input_window,
                   output_window,
                   indices_window,
                   kargs.num_rows,
                   kargs.num_experts,
                   kargs.topk,
                   block_row_id);
    }
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp"
#include <string>
#include <type_traits>
#ifndef TOPK_SOFTMAX_USE_RAW_TILE_WINDOW
#define TOPK_SOFTMAX_USE_RAW_TILE_WINDOW 0
#endif
namespace ck_tile {
template <typename Problem_, typename Policy_ = TopkSoftmaxWarpPerRowPolicy>
struct TopkSoftmaxWarpPerRowPipeline
{
    // TODO: this kernel only support warp per row
    using Problem = remove_cvref_t<Problem_>;
    using Policy  = remove_cvref_t<Policy_>;

    using WeightType = typename Problem::WeightType;

    // Load a tile of rows, pad out-of-range experts with -inf, softmax each
    // row, then stream out the top-k (value, index) pairs. In persistent mode
    // (LaunchType != 0) the windows advance grid-stride-wise until all rows
    // are consumed; otherwise a single iteration is done.
    template <typename InputWindow, typename OutputWindow, typename IndexWindow>
    CK_TILE_DEVICE auto operator()(const InputWindow& input_window,
                                   OutputWindow& out_window,
                                   IndexWindow& idx_window,
                                   index_t rows,
                                   index_t experts,
                                   index_t k,
                                   index_t block_row_id)
    {
#if TOPK_SOFTMAX_USE_RAW_TILE_WINDOW
        auto inp_win = make_tile_window_linear_raw(
            input_window, Policy::template MakeInputDistribution<Problem>(), sequence<0, 1>{});
#else
        auto inp_win = make_tile_window_linear(
            input_window, Policy::template MakeInputDistribution<Problem>(), sequence<0, 1>{});
#endif
        // rebuild out/idx windows with the policy's output distribution
        auto out_win = make_tile_window_linear(out_window.get_bottom_tensor_view(),
                                               out_window.get_window_lengths(),
                                               out_window.get_window_origin(),
                                               Policy::template MakeOutputDistribution<Problem>());

        auto idx_win = make_tile_window_linear(idx_window.get_bottom_tensor_view(),
                                               idx_window.get_window_lengths(),
                                               idx_window.get_window_origin(),
                                               Policy::template MakeOutputDistribution<Problem>());

        auto softmax = Policy::template GetSoftmax<Problem>();
        auto topk    = Policy::template GetTopk<Problem>();

        // rows consumed by the whole grid per persistent-loop iteration
        const index_t grid_rows_per_loop = gridDim.x * Problem::RowsPerBlock;

        while(1)
        {
#if TOPK_SOFTMAX_USE_RAW_TILE_WINDOW
            // raw load path: keep the scheduler from reordering around the load
            __builtin_amdgcn_sched_barrier(0);
            auto x =
                load_tile_raw(inp_win, number<-1>{}, bool_constant<true>{}, bool_constant<true>{});
            buffer_load_fence(number<0>{});
            __builtin_amdgcn_sched_barrier(0);
#else
            auto x = load_tile(inp_win);
#endif
            // cast and pad input data
            auto w = [&]() {
#if 0
                auto w_ = cast_tile<WeightType>(x);
                constexpr auto span_2d = decltype(w_)::get_distributed_spans();
                sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
                    sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
                        constexpr auto i_j_idx = make_tuple(idx0, idx1);
                        const auto x_indices = get_x_indices_from_distributed_indices(
                            w_.get_tile_distribution(), i_j_idx);
                        const auto current_expert = x_indices.at(number<1>{});
                        // set to -INF if OOB so that later softmax can work properly
                        w_(i_j_idx) = current_expert >= experts ? -numeric<WeightType>::infinity()
                                                                : w_(i_j_idx);
                    });
                });
                return w_;
#else
                auto w_ = make_static_distributed_tensor<WeightType>(x.get_tile_distribution());
                // convert each element to WeightType; lanes covering experts
                // beyond the real count get -inf so softmax ignores them
                auto w_f = [&](auto idx) {
                    w_(idx) = type_convert<WeightType>(x(idx));
                    const auto x_indices =
                        get_x_indices_from_distributed_indices(w_.get_tile_distribution(), idx);
                    const auto current_expert = x_indices.at(number<1>{});
                    w_(idx) =
                        current_expert >= experts ? -numeric<WeightType>::infinity() : w_(idx);
                };
                tile_sweeper ts{w_, w_f};
                ts();
                return w_;
#endif
            }();

            // softmax
            auto y = softmax(w);

            topk(y, out_win, idx_win, k);

            // check exit
            if constexpr(Problem::LaunchType == 0)
            {
                break;
            }
            else
            {
                block_row_id += grid_rows_per_loop;
                if(block_row_id >= rows)
                    break;
            }
            // advance all windows to this block's next chunk of rows
            move_tile_window(inp_win, {grid_rows_per_loop, number<0>{}});
            move_tile_window(out_win, {grid_rows_per_loop, number<0>{}});
            move_tile_window(idx_win, {grid_rows_per_loop, number<0>{}});
        }
    }
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/softmax.hpp"
#include "ck_tile/ops/topk.hpp"
namespace ck_tile {
// Policy for the warp-per-row topk-softmax pipeline: supplies the input/output
// tile distributions and the softmax/topk building blocks.
struct TopkSoftmaxWarpPerRowPolicy
{
    // input tile: rows spread over (IssuesPerCol, WarpsPerBlock,
    // RowsPerWarpPerColIssue), experts over (IssuesPerRow, LanesPerRow,
    // VectorSize)
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeInputDistribution()
    {
        // TODO: Y dim must have one dim that is not reduced
        return make_static_tile_distribution(
            tile_distribution_encoding<
                sequence<1>,
                tuple<sequence<Problem::IssuesPerCol,
                               Problem::WarpsPerBlock,
                               Problem::RowsPerWarpPerColIssue>,
                      sequence<Problem::IssuesPerRow, Problem::LanesPerRow, Problem::VectorSize>>,
                tuple<sequence<1>, sequence<1, 2>>,
                tuple<sequence<1>, sequence<2, 1>>,
                sequence<1, 2, 2>,
                sequence<0, 0, 2>>{});
    }

    // output tile: one element per row, replicated across the LanesPerRow
    // lanes that shared the row during the reduction
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeOutputDistribution()
    {
        return make_static_tile_distribution(
            tile_distribution_encoding<sequence<Problem::LanesPerRow>, // repeat this one
                                       tuple<sequence<Problem::IssuesPerCol,
                                                      Problem::WarpsPerBlock,
                                                      Problem::RowsPerWarpPerColIssue>,
                                             sequence<1>>, // each row write out single element
                                       tuple<sequence<1>, sequence<1, 0>>,
                                       tuple<sequence<1>, sequence<2, 0>>,
                                       sequence<1, 2>,
                                       sequence<0, 0>>{});
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetSoftmax()
    {
        using softmax_problem = BlockSoftmax2DProblem<typename Problem::WeightType>;
        return BlockSoftmax2D<softmax_problem>{};
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetTopk()
    {
        using topk_problem = BlockTopkStream2DProblem<typename Problem::WeightType,
                                                      typename Problem::IndexType,
                                                      Problem::LanesPerRow>;
        // Note: replicate is LanesPerRow
        return BlockTopkStream2D<topk_problem>{};
    }
};
} // namespace ck_tile
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment