Commit a848e002 authored by carlushuang

support topk-softmax up to 64 experts

parent a24c5694
#include "topk_softmax_api.hpp" #include "topk_softmax_api.hpp"
#define TOPK_SOFTMAX_DISPATCH(experts_) \
constexpr ck_tile::index_t ts_experts = experts_; \
using ts_problem = ck_tile:: \
TopkSoftmaxWarpPerRowProblem<ts_input_type, ts_weight_type, ts_index_type, ts_experts>; \
using ts_pipeline = ck_tile::TopkSoftmaxWarpPerRowPipeline<ts_problem>; \
\
using kernel = ck_tile::TopkSoftmaxKernel<ts_pipeline>; \
\
auto kargs = kernel::MakeKargs(a); \
\
const dim3 grids = kernel::GridSize(a); \
constexpr dim3 blocks = kernel::BlockSize(); \
\
float ave_time = ck_tile::launch_kernel( \
s, ck_tile::make_kernel<blocks.x, 1>(kernel{}, grids, blocks, 0, kargs)); \
\
return ave_time;
float topk_softmax(topk_softmax_trait t, topk_softmax_kargs a, ck_tile::stream_config s) float topk_softmax(topk_softmax_trait t, topk_softmax_kargs a, ck_tile::stream_config s)
{ {
if(t.input_type == "fp16" && t.weight_type == "fp32") if(t.input_type == "fp16" && t.weight_type == "fp32")
{ {
using ts_input_type = ck_tile::fp16_t; using ts_input_type = ck_tile::fp16_t;
using ts_weight_type = float; using ts_weight_type = float;
using ts_index_type = ck_tile::index_t; using ts_index_type = ck_tile::index_t;
constexpr ck_tile::index_t ts_experts = 8;
using ts_problem = ck_tile::
TopkSoftmaxWarpPerRowProblem<ts_input_type, ts_weight_type, ts_index_type, ts_experts>;
using ts_pipeline = ck_tile::TopkSoftmaxWarpPerRowPipeline<ts_problem>;
using kernel = ck_tile::TopkSoftmaxKernel<ts_pipeline>;
auto kargs = kernel::MakeKargs(a);
const dim3 grids = kernel::GridSize(a);
constexpr dim3 blocks = kernel::BlockSize();
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, 1>(kernel{}, grids, blocks, 0, kargs));
return ave_time; if(t.experts <= 8)
{
TOPK_SOFTMAX_DISPATCH(8)
}
else if(t.experts <= 16)
{
TOPK_SOFTMAX_DISPATCH(16)
}
else if(t.experts <= 32)
{
TOPK_SOFTMAX_DISPATCH(32)
}
else if(t.experts <= 64)
{
TOPK_SOFTMAX_DISPATCH(64)
}
} }
else if(t.input_type == "bf16" && t.weight_type == "fp32") else if(t.input_type == "bf16" && t.weight_type == "fp32")
{ {
using ts_input_type = ck_tile::bf16_t; using ts_input_type = ck_tile::bf16_t;
using ts_weight_type = float; using ts_weight_type = float;
using ts_index_type = ck_tile::index_t; using ts_index_type = ck_tile::index_t;
constexpr ck_tile::index_t ts_experts = 8; if(t.experts <= 8)
using ts_problem = ck_tile:: {
TopkSoftmaxWarpPerRowProblem<ts_input_type, ts_weight_type, ts_index_type, ts_experts>; TOPK_SOFTMAX_DISPATCH(8)
using ts_pipeline = ck_tile::TopkSoftmaxWarpPerRowPipeline<ts_problem>; }
else if(t.experts <= 16)
using kernel = ck_tile::TopkSoftmaxKernel<ts_pipeline>; {
TOPK_SOFTMAX_DISPATCH(16)
auto kargs = kernel::MakeKargs(a); }
else if(t.experts <= 32)
const dim3 grids = kernel::GridSize(a); {
constexpr dim3 blocks = kernel::BlockSize(); TOPK_SOFTMAX_DISPATCH(32)
}
float ave_time = ck_tile::launch_kernel( else if(t.experts <= 64)
s, ck_tile::make_kernel<blocks.x, 1>(kernel{}, grids, blocks, 0, kargs)); {
TOPK_SOFTMAX_DISPATCH(64)
return ave_time; }
} }
return -1; return -1;
} }
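
The new TOPK_SOFTMAX_DISPATCH macro maps the runtime expert count onto the smallest compile-time instantiation that covers it (8/16/32/64); each expansion ends in a return, so only the first matching branch executes. Below is a minimal standalone sketch of this runtime-to-compile-time dispatch pattern; run_kernel and dispatch are hypothetical stand-ins, not ck_tile APIs:

#include <cstdio>

// Hypothetical stand-in for a kernel templated on a compile-time expert count.
template <int MaxExperts>
float run_kernel(int experts)
{
    std::printf("runtime experts=%d -> compiled for MaxExperts=%d\n", experts, MaxExperts);
    return 0.f; // would return the measured kernel time
}

// Round the runtime value up to the nearest supported power of two, then
// call the matching template instantiation; -1 means unsupported.
float dispatch(int experts)
{
    if(experts <= 8)  { return run_kernel<8>(experts); }
    if(experts <= 16) { return run_kernel<16>(experts); }
    if(experts <= 32) { return run_kernel<32>(experts); }
    if(experts <= 64) { return run_kernel<64>(experts); }
    return -1.f;
}

int main()
{
    dispatch(48); // selects the MaxExperts=64 instantiation
}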
@@ -10,6 +10,7 @@
 #include <random>
 #include <type_traits>
 #include <utility>
+#include <unordered_set>
 
 #include "ck_tile/core.hpp"
@@ -41,6 +42,73 @@ struct FillUniformDistribution
     }
 };
 
+namespace impl {
+// clang-format off
+template<index_t bytes> struct RawIntegerType_ {};
+template<> struct RawIntegerType_<1> { using type = uint8_t;};
+template<> struct RawIntegerType_<2> { using type = uint16_t;};
+template<> struct RawIntegerType_<4> { using type = uint32_t;};
+template<> struct RawIntegerType_<8> { using type = uint64_t;};
+// clang-format on
+template <typename T>
+using RawIntegerType = typename RawIntegerType_<sizeof(T)>::type;
+} // namespace impl
+
+// Note: this struct is stateful (operator() is non-const); it records every
+// generated value so that repeated calls produce unique random values.
+template <typename T>
+struct FillUniformDistribution_Unique
+{
+    float a_{-5.f};
+    float b_{5.f};
+    std::optional<uint32_t> seed_{11939};
+    std::mt19937 gen_{}; // (seed_.has_value() ? *seed_ : std::random_device{}());
+    std::unordered_set<impl::RawIntegerType<T>> set_{};
+
+    FillUniformDistribution_Unique(float a                      = -5.f,
+                                   float b                      = 5.f,
+                                   std::optional<uint32_t> seed = {11939})
+        : a_(a),
+          b_(b),
+          seed_(seed),
+          gen_{seed_.has_value() ? *seed_ : std::random_device{}()},
+          set_{}
+    {
+    }
+
+    template <typename ForwardIter>
+    void operator()(ForwardIter first, ForwardIter last)
+    {
+        std::mt19937& gen = gen_;
+        std::uniform_real_distribution<float> dis(a_, b_);
+        auto& set = set_;
+        std::generate(first, last, [&dis, &gen, &set]() {
+            T v = static_cast<T>(0);
+            do
+            {
+                v = ck_tile::type_convert<T>(dis(gen));
+            } while(set.count(bit_cast<impl::RawIntegerType<T>>(v)) == 1);
+            set.insert(bit_cast<impl::RawIntegerType<T>>(v));
+            return v;
+        });
+    }
+
+    template <typename ForwardRange>
+    auto operator()(ForwardRange&& range)
+        -> std::void_t<decltype(std::declval<FillUniformDistribution_Unique&>()(
+            std::begin(std::forward<ForwardRange>(range)),
+            std::end(std::forward<ForwardRange>(range))))>
+    {
+        (*this)(std::begin(std::forward<ForwardRange>(range)),
+                std::end(std::forward<ForwardRange>(range)));
+    }
+
+    void clear() { set_.clear(); }
+};
+
 template <typename T>
 struct FillNormalDistribution
 {
...
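
One subtlety above: uniqueness is tracked on the raw bit pattern (bit_cast to a same-sized unsigned integer) rather than on the floating-point value, so a single unordered_set works for fp16/bf16/fp32 alike and sidesteps floating-point comparison quirks. With low-precision types such as fp16, a plain uniform fill over [-5, 5] can easily repeat a value within a row, which would make the top-k index order ambiguous. A standalone, float-only re-implementation of the same idea (a sketch, not the ck_tile code, which generalizes T via RawIntegerType<T>):

#include <algorithm>
#include <cstdint>
#include <cstring>
#include <random>
#include <unordered_set>
#include <vector>

int main()
{
    const int tokens = 4, experts = 8;
    std::vector<float> x(tokens * experts);

    std::mt19937 gen{11939};
    std::uniform_real_distribution<float> dis(-5.f, 5.f);
    std::unordered_set<uint32_t> seen; // raw bit patterns, not float compares

    for(int row = 0; row < tokens; ++row)
    {
        seen.clear(); // uniqueness only needs to hold within one row
        std::generate(x.begin() + row * experts, x.begin() + (row + 1) * experts, [&] {
            float v;
            uint32_t bits;
            do
            {
                v = dis(gen);
                std::memcpy(&bits, &v, sizeof(bits)); // portable bit_cast
            } while(!seen.insert(bits).second);       // retry until unseen
            return v;
        });
    }
}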
@@ -71,7 +71,7 @@ struct TopkSoftmaxKernel
        index_t block_row_id = static_cast<index_t>(blockIdx.x * Problem::RowsPerBlock);
 
        const auto input_window = [&]() {
            const InputType* p_input = reinterpret_cast<const InputType*>(kargs.p_input) +
-                                       blockIdx.x * Problem::RowsPerBlock * kargs.num_experts;
+                                       block_row_id * kargs.num_experts;
 
            auto tmp = make_naive_tensor_view_packed<address_space_enum::global>(
                p_input,
@@ -85,33 +85,33 @@ struct TopkSoftmaxKernel
            return make_tile_window(
                view,
                make_tuple(number<Problem::RowsPerBlock>{}, number<Problem::Experts>{}),
-               {block_row_id, 0});
+               {0, 0});
        }();
 
        auto output_window = [&]() {
-           WeightType* p_output = reinterpret_cast<WeightType*>(kargs.p_output) +
-                                  blockIdx.x * Problem::RowsPerBlock * kargs.topk;
+           WeightType* p_output =
+               reinterpret_cast<WeightType*>(kargs.p_output) + block_row_id * kargs.topk;
 
            auto tmp = make_naive_tensor_view_packed<address_space_enum::global>(
                p_output, make_tuple(kargs.num_rows, kargs.topk), number<Problem::VectorSize>{});
            auto view = pad_tensor_view(
                tmp, make_tuple(number<Problem::RowsPerBlock>{}, number<1>{}), sequence<1, 0>{});
            return make_tile_window(
-               view, make_tuple(number<Problem::RowsPerBlock>{}, number<1>{}), {block_row_id, 0});
+               view, make_tuple(number<Problem::RowsPerBlock>{}, number<1>{}), {0, 0});
        }();
 
        auto indices_window = [&]() {
-           IndexType* p_indices = reinterpret_cast<IndexType*>(kargs.p_indices) +
-                                  blockIdx.x * Problem::RowsPerBlock * kargs.topk;
+           IndexType* p_indices =
+               reinterpret_cast<IndexType*>(kargs.p_indices) + block_row_id * kargs.topk;
 
            auto tmp = make_naive_tensor_view_packed<address_space_enum::global>(
                p_indices, make_tuple(kargs.num_rows, kargs.topk), number<Problem::VectorSize>{});
            auto view = pad_tensor_view(
                tmp, make_tuple(number<Problem::RowsPerBlock>{}, number<1>{}), sequence<1, 0>{});
            return make_tile_window(
-               view, make_tuple(number<Problem::RowsPerBlock>{}, number<1>{}), {block_row_id, 0});
+               view, make_tuple(number<Problem::RowsPerBlock>{}, number<1>{}), {0, 0});
        }();
 
-       Pipeline{}(input_window, output_window, indices_window, kargs.topk);
+       Pipeline{}(input_window, output_window, indices_window, kargs.topk, kargs.num_experts);
    }
 };
 
 } // namespace ck_tile
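
The kernel change is a single refactor applied three times: the per-block row offset (block_row_id = blockIdx.x * RowsPerBlock) moves from each tile-window origin into the raw base pointer, so every window now starts at {0, 0}. Both addressing schemes reach the same global element, as this plain C++ sketch with illustrative names shows:

#include <cassert>

int main()
{
    const int rows_per_block = 4, topk = 2, block_id = 3;
    const int block_row_id   = block_id * rows_per_block; // first row this block owns
    float buf[64]            = {};

    // Old scheme: base pointer untouched, window origin carries the row offset.
    float* elem_old = buf + (block_row_id + 1) * topk; // window {block_row_id, 0}, local row 1

    // New scheme: base pointer pre-offset by block_row_id rows, window origin {0, 0}.
    float* base_new = buf + block_row_id * topk;
    float* elem_new = base_new + 1 * topk; // same local row 1

    assert(elem_old == elem_new); // both schemes address the same element
    return 0;
}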
@@ -14,14 +14,16 @@ template <typename Problem_, typename Policy_ = TopkSoftmaxWarpPerRowPolicy>
 struct TopkSoftmaxWarpPerRowPipeline
 {
     // TODO: this kernel only support warp per row
     using Problem = remove_cvref_t<Problem_>;
     using Policy  = remove_cvref_t<Policy_>;
+    using WeightType = typename Problem::WeightType;
 
     template <typename InputWindow, typename OutputWindow, typename IndexWindow>
     CK_TILE_DEVICE auto operator()(const InputWindow& input_window,
                                    OutputWindow& out_window,
                                    IndexWindow& idx_window,
-                                   index_t k)
+                                   index_t k,
+                                   index_t experts)
     {
         auto input_win = make_tile_window(input_window.get_bottom_tensor_view(),
                                           input_window.get_window_lengths(),
@@ -29,7 +31,25 @@ struct TopkSoftmaxWarpPerRowPipeline
                                           Policy::template MakeInputDistribution<Problem>());
 
         auto x = load_tile(input_win);
-        auto w = cast_tile<typename Problem::WeightType>(x);
+
+        // cast and pad input data
+        auto w = [&]() {
+            auto w_ = cast_tile<WeightType>(x);
+            constexpr auto span_2d = decltype(w_)::get_distributed_spans();
+            sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
+                sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
+                    constexpr auto i_j_idx = make_tuple(idx0, idx1);
+                    const auto x_indices   = get_x_indices_from_distributed_indices(
+                        w_.get_tile_distribution(), i_j_idx);
+                    const auto current_expert = x_indices.at(number<1>{});
+                    // set to -INF if OOB so that the later softmax can work properly
+                    w_(i_j_idx) = current_expert >= experts ? -numeric<WeightType>::infinity()
+                                                            : w_(i_j_idx);
+                });
+            });
+            return w_;
+        }();
 
         auto softmax = Policy::template GetSoftmax<Problem>();
...
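
The -INF padding is what lets a kernel instantiated for, say, 64 experts handle any smaller expert count: padded lanes contribute exp(-inf) = 0 to the row sum, so the softmax over the padded tile equals the softmax over just the real experts. A small host-side sketch of the idea:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <limits>
#include <vector>

// Softmax over the real values in a tile padded to `padded_n` entries,
// with the padding pre-filled with -inf (as the pipeline does).
std::vector<float> padded_softmax(std::vector<float> row, int padded_n)
{
    row.resize(padded_n, -std::numeric_limits<float>::infinity());
    float m = *std::max_element(row.begin(), row.end()); // subtract max for stability
    float sum = 0.f;
    for(float& v : row)
    {
        v = std::exp(v - m); // exp(-inf - m) == 0: padding cannot perturb the sum
        sum += v;
    }
    for(float& v : row)
        v /= sum; // padding lanes stay exactly 0
    return row;
}

int main()
{
    auto y = padded_softmax({1.f, 2.f, 3.f}, 8); // 3 real experts, tile width 8
    for(float v : y)
        std::printf("%f ", v); // last 5 values print as 0.000000
    std::printf("\n");
}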
@@ -14,11 +14,9 @@
 #include "ck_tile/ops/reduce.hpp"
 #include "topk_softmax_api.hpp"
 
-// #ifndef TEST_TOPK_SOFTMAX_VERBOSE
-// #define TEST_TOPK_SOFTMAX_VERBOSE 0
-// #endif
-
-// #define BLOCK_SIZE 256
+#ifndef TEST_TOPK_SOFTMAX_VERBOSE
+#define TEST_TOPK_SOFTMAX_VERBOSE 1
+#endif
 
 template <typename T>
 void dump_host_tensor_2d(const ck_tile::HostTensor<T>& x)
@@ -28,7 +26,7 @@ void dump_host_tensor_2d(const ck_tile::HostTensor<T>& x)
     std::cout << "[";
     for(size_t i = 0; i < len[0]; i++)
     {
-        std::cout << "[";
+        std::cout << i << ": [";
         for(size_t j = 0; j < len[1]; j++)
         {
             if constexpr(std::is_same_v<T, ck_tile::fp16_t>)
@@ -121,6 +119,7 @@ auto create_args(int argc, char* argv[])
         .insert("t", "32", "number of input tokens")
         .insert("e", "8", "number of experts")
         .insert("k", "2", "topk")
+        .insert("seed", "-1", "seed to be used, -1 means random every time")
         .insert("kname", "0", "set to 1 will print kernel name");
 
     bool result = arg_parser.parse(argc, argv);
@@ -136,17 +135,41 @@ bool test_topk_softmax(ck_tile::ArgParser args)
     int tokens  = args.get_int("t");
     int experts = args.get_int("e");
     int topk    = args.get_int("k");
+    int seed    = args.get_int("seed");
+    if(seed < 0)
+    {
+        seed = std::time(nullptr);
+    }
     // int kname = args.get_int("kname");
     // int warmup = args.get_int("warmup");
     // int repeat = args.get_int("repeat");
 
-    std::srand(std::time(nullptr));
+    if(topk > experts)
+    {
+#if TEST_TOPK_SOFTMAX_VERBOSE
+        printf("topk:%d should be smaller than (or equal to) experts:%d\n", topk, experts);
+#endif
+        return false;
+    }
 
     // tokens already considered batch size
     ck_tile::HostTensor<InputType> x_host({tokens, experts});
     ck_tile::HostTensor<WeightType> value_host({tokens, topk});
     ck_tile::HostTensor<IndexType> index_host({tokens, topk});
 
-    ck_tile::FillUniformDistribution<InputType>{-5.f, 5.f}(x_host);
+    {
+        // random values must be unique within each row
+        auto rand_gen = ck_tile::FillUniformDistribution_Unique<InputType>{
+            -5.f, 5.f, static_cast<uint32_t>(seed)};
+
+        for(int i_t = 0; i_t < tokens; i_t++)
+        {
+            ck_tile::HostTensor<InputType> x_row({experts});
+            rand_gen(x_row);
+            std::copy(x_row.begin(), x_row.end(), x_host.begin() + i_t * experts);
+            rand_gen.clear();
+        }
+    }
 
     ck_tile::DeviceMem x_dev(x_host.get_element_space_size_in_bytes());
     ck_tile::DeviceMem value_dev(value_host.get_element_space_size_in_bytes());
@@ -173,9 +196,21 @@ bool test_topk_softmax(ck_tile::ArgParser args)
         return a_;
     }();
 
+#if TEST_TOPK_SOFTMAX_VERBOSE
+    ck_tile::stream_config sc{nullptr, true};
+    auto ms = topk_softmax(trait, karg, sc);
+    printf("[%s|%s]tokens:%d, experts:%d, topk:%d, ms:%f, ",
+           input_prec.c_str(),
+           weight_prec.c_str(),
+           tokens,
+           experts,
+           topk,
+           ms);
+    fflush(stdout);
+#else
     ck_tile::stream_config sc{nullptr};
     topk_softmax(trait, karg, sc);
+#endif
 
     value_dev.FromDevice(value_host.data());
     index_dev.FromDevice(index_host.data());
@@ -195,7 +230,10 @@ bool test_topk_softmax(ck_tile::ArgParser args)
         rtn &= ck_tile::check_err(
             index_host, index_ref, std::string("Index Error: Incorrect results!"), rtol, atol);
     }
 
+#if TEST_TOPK_SOFTMAX_VERBOSE
+    printf("valid:%s\n", rtn ? "y" : "n");
+    fflush(stdout);
+#endif
     return rtn;
 }
...
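
The check_err calls above compare the device output against a host reference that is not shown in this diff. A simplified sketch of what such a reference computes per row: softmax over the logits, then the k largest probabilities with their expert indices, largest first; the unique per-row inputs generated above guarantee there are no ties:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <utility>
#include <vector>

std::vector<std::pair<float, int>> topk_softmax_ref(const std::vector<float>& logits, int k)
{
    const float m = *std::max_element(logits.begin(), logits.end()); // stability shift
    float sum = 0.f;
    std::vector<std::pair<float, int>> p(logits.size());
    for(int i = 0; i < static_cast<int>(logits.size()); ++i)
    {
        p[i] = {std::exp(logits[i] - m), i};
        sum += p[i].first;
    }
    for(auto& e : p)
        e.first /= sum; // normalize to probabilities

    // Only the first k entries need to be ordered.
    std::partial_sort(p.begin(), p.begin() + k, p.end(),
                      [](const auto& a, const auto& b) { return a.first > b.first; });
    p.resize(k);
    return p;
}

int main()
{
    auto r = topk_softmax_ref({0.1f, 2.0f, -1.0f, 0.7f}, 2);
    for(auto& [w, idx] : r)
        std::printf("expert %d weight %f\n", idx, w);
}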