Commit eed60199 authored by carlushuang

more robust api

parent cae751d1
@@ -30,12 +30,13 @@ struct BlockTopkStream2D
     template <typename DistributedTensor, typename OutWindow, typename IdxWindow, index_t dim = 1>
     CK_TILE_DEVICE void operator()(const DistributedTensor& x,
-                                   OutWindow& out_window,
-                                   IdxWindow& idx_window,
+                                   const OutWindow& out_window,
+                                   const IdxWindow& idx_window,
                                    index_t k,
                                    number<dim> = {})
     {
-        // static_assert(OutWindow::get_window_lengths()[number<1>] == 1);
+        OutWindow out_window_tmp = out_window;
+        IdxWindow idx_window_tmp = idx_window;
         static_assert(
             std::is_same_v<typename DistributedTensor::DataType, typename OutWindow::DataType> &&
             std::is_same_v<typename DistributedTensor::DataType, DataType>);
@@ -100,11 +101,11 @@ struct BlockTopkStream2D
             if(threadIdx.x % Problem::ColLanes == 0)
             {
-                store_tile(out_window, o);
-                store_tile(idx_window, i);
+                store_tile(out_window_tmp, o);
+                store_tile(idx_window_tmp, i);
             }
-            move_tile_window(out_window, {number<0>{}, number<1>{}});
-            move_tile_window(idx_window, {number<0>{}, number<1>{}});
+            move_tile_window(out_window_tmp, {number<0>{}, number<1>{}});
+            move_tile_window(idx_window_tmp, {number<0>{}, number<1>{}});
         }
     }
 };
......
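Note: the streaming top-k helper above now takes the output and index windows by const reference and advances local copies, so a call leaves the caller's window origins untouched. A minimal sketch of that pattern with a hypothetical Window type and move() helper (not the real ck_tile window API):

// Sketch only: `Window` stands in for a copyable tile-window-like object with a
// move({row, col}) helper; both are assumptions for illustration.
template <typename Window>
void stream_topk_columns(const Window& out_window, const Window& idx_window, int k)
{
    Window out_tmp = out_window; // local copies: advancing them leaves the caller's
    Window idx_tmp = idx_window; // windows untouched, so the call has no side effects
    for(int i = 0; i < k; ++i)
    {
        // ... compute the i-th top-k column and store it through out_tmp / idx_tmp ...
        out_tmp.move({0, 1}); // step one column to the right
        idx_tmp.move({0, 1});
    }
}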
@@ -6,6 +6,7 @@
 #include "ck_tile/core.hpp"
 #include "ck_tile/ops/common.hpp"
 #include "ck_tile/ops/elementwise.hpp"
+#include "ck_tile/host/hip_check_error.hpp"
 #include <string>
 #include <type_traits>
@@ -19,6 +20,8 @@ struct TopkSoftmaxHostArgs
     index_t num_rows;
     index_t num_experts;
     index_t topk;
+    index_t stride_input;  // row stride for input, at least experts
+    index_t stride_output; // row stride for output/indices, at least topk
 };

 template <typename Pipeline_>
@@ -39,18 +42,34 @@ struct TopkSoftmaxKernel
         index_t num_rows;
         index_t num_experts;
         index_t topk;
+        index_t stride_input;  // row stride for input, at least experts
+        index_t stride_output; // row stride for output/indices, at least topk
     };
     using Kargs = TopkSoftmaxKargs;
     using Hargs = TopkSoftmaxHostArgs;

     CK_TILE_HOST static constexpr auto GridSize(const Hargs& h)
     {
-        const int num_warps = (h.num_rows + Problem::RowsPerWarp - 1) / Problem::RowsPerWarp;
-        const int num_blocks = (num_warps + Problem::WarpsPerBlock - 1) / Problem::WarpsPerBlock;
-        return dim3(num_blocks);
+        if constexpr(Problem::LaunchType > 0)
+        {
+            int num_cu = [&]() {
+                hipDeviceProp_t dev_prop;
+                hipDevice_t dev;
+                HIP_CHECK_ERROR(hipGetDevice(&dev));
+                HIP_CHECK_ERROR(hipGetDeviceProperties(&dev_prop, dev));
+                return dev_prop.multiProcessorCount;
+            }();
+            return dim3(num_cu * Problem::LaunchType);
+        }
+        else
+        {
+            const int num_warps = (h.num_rows + Problem::RowsPerWarp - 1) / Problem::RowsPerWarp;
+            const int num_blocks =
+                (num_warps + Problem::WarpsPerBlock - 1) / Problem::WarpsPerBlock;
+            return dim3(num_blocks);
+        }
     }

     CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
     {
@@ -61,6 +80,8 @@ struct TopkSoftmaxKernel
         k.num_rows = h.num_rows;
         k.num_experts = h.num_experts;
         k.topk = h.topk;
+        k.stride_input = h.stride_input;
+        k.stride_output = h.stride_output;
         return k;
     }
@@ -69,18 +90,29 @@ struct TopkSoftmaxKernel
     CK_TILE_DEVICE void operator()(Kargs kargs) const
     {
         index_t block_row_id = static_cast<index_t>(blockIdx.x * Problem::RowsPerBlock);
+        if(block_row_id > kargs.num_rows)
+            return;
+
+        index_t block_os_inp = __builtin_amdgcn_readfirstlane(block_row_id * kargs.stride_input);
+        index_t block_os_out = __builtin_amdgcn_readfirstlane(block_row_id * kargs.stride_output);
+        index_t num_rows_rem = __builtin_amdgcn_readfirstlane(kargs.num_rows - block_row_id);

         const auto input_window = [&]() {
-            const InputType* p_input = reinterpret_cast<const InputType*>(kargs.p_input) +
-                                       block_row_id * kargs.num_experts;
-            auto tmp = make_naive_tensor_view_packed<address_space_enum::global>(
-                p_input,
-                make_tuple(kargs.num_rows, kargs.num_experts),
-                number<Problem::VectorSize>{});
+            const InputType* p_input =
+                reinterpret_cast<const InputType*>(kargs.p_input) + block_os_inp;
+            auto tmp = make_naive_tensor_view<address_space_enum::global>(
+                p_input,
+                make_tuple(num_rows_rem, kargs.num_experts),
+                make_tuple(kargs.stride_input, 1),
+                number<Problem::VectorSize>{},
+                number<1>{});
             auto view = pad_tensor_view(
                 tmp,
                 make_tuple(number<Problem::RowsPerBlock>{}, number<Problem::Experts>{}),
-                sequence<1, 1>{});
+                sequence<0, 1>{}); // outer-most dim needs no pad (leverage OOB)
             return make_tile_window(
                 view,
@@ -89,29 +121,46 @@ struct TopkSoftmaxKernel
         }();

         auto output_window = [&]() {
-            WeightType* p_output =
-                reinterpret_cast<WeightType*>(kargs.p_output) + block_row_id * kargs.topk;
-            auto tmp = make_naive_tensor_view_packed<address_space_enum::global>(
-                p_output, make_tuple(kargs.num_rows, kargs.topk), number<Problem::VectorSize>{});
-            auto view = pad_tensor_view(
-                tmp, make_tuple(number<Problem::RowsPerBlock>{}, number<1>{}), sequence<1, 0>{});
+            WeightType* p_output = reinterpret_cast<WeightType*>(kargs.p_output) + block_os_out;
+            auto tmp = make_naive_tensor_view<address_space_enum::global>(
+                p_output,
+                make_tuple(num_rows_rem, kargs.topk),
+                make_tuple(kargs.stride_output, 1),
+                number<Problem::VectorSize>{},
+                number<1>{});
+            auto view =
+                pad_tensor_view(tmp,
+                                make_tuple(number<Problem::RowsPerBlock>{}, number<1>{}),
+                                sequence<0, 0>{}); // 1. outer-most dim needs no pad (leverage OOB)
+                                                   // 2. we loop over topk one by one, no pad needed
             return make_tile_window(
                 view, make_tuple(number<Problem::RowsPerBlock>{}, number<1>{}), {0, 0});
         }();

         auto indices_window = [&]() {
-            IndexType* p_indices =
-                reinterpret_cast<IndexType*>(kargs.p_indices) + block_row_id * kargs.topk;
-            auto tmp = make_naive_tensor_view_packed<address_space_enum::global>(
-                p_indices, make_tuple(kargs.num_rows, kargs.topk), number<Problem::VectorSize>{});
-            auto view = pad_tensor_view(
-                tmp, make_tuple(number<Problem::RowsPerBlock>{}, number<1>{}), sequence<1, 0>{});
+            IndexType* p_indices = reinterpret_cast<IndexType*>(kargs.p_indices) + block_os_out;
+            auto tmp = make_naive_tensor_view<address_space_enum::global>(
+                p_indices,
+                make_tuple(num_rows_rem, kargs.topk),
+                make_tuple(kargs.stride_output, 1),
+                number<Problem::VectorSize>{},
+                number<1>{});
+            auto view =
+                pad_tensor_view(tmp,
+                                make_tuple(number<Problem::RowsPerBlock>{}, number<1>{}),
+                                sequence<0, 0>{}); // 1. outer-most dim needs no pad (leverage OOB)
+                                                   // 2. we loop over topk one by one, no pad needed
             return make_tile_window(
                 view, make_tuple(number<Problem::RowsPerBlock>{}, number<1>{}), {0, 0});
         }();

-        Pipeline{}(input_window, output_window, indices_window, kargs.topk, kargs.num_experts);
+        Pipeline{}(input_window,
+                   output_window,
+                   indices_window,
+                   kargs.num_rows,
+                   kargs.num_experts,
+                   kargs.topk,
+                   block_row_id);
     }
 };
 } // namespace ck_tile
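Note: GridSize now has two modes, a streaming launch sized from the row count and a persistent launch sized from the CU count times Problem::LaunchType. A minimal host-side sketch of the same arithmetic with plain integers (parameter names here are illustrative, not the kernel's API):

#include <hip/hip_runtime.h>

// Sketch only: mirrors the sizing rules above.
inline int grid_size(int num_rows, int rows_per_warp, int warps_per_block, int launch_type)
{
    if(launch_type > 0) // persistent: a fixed number of workgroups per CU
    {
        hipDeviceProp_t prop;
        int dev = 0;
        (void)hipGetDevice(&dev);
        (void)hipGetDeviceProperties(&prop, dev);
        return prop.multiProcessorCount * launch_type;
    }
    // streaming: just enough blocks to cover every row once
    const int num_warps  = (num_rows + rows_per_warp - 1) / rows_per_warp;
    const int num_blocks = (num_warps + warps_per_block - 1) / warps_per_block;
    return num_blocks;
}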
@@ -8,6 +8,10 @@
 #include <string>
 #include <type_traits>

+#ifndef TOPK_SOFTMAX_USE_RAW_TILE_WINDOW
+#define TOPK_SOFTMAX_USE_RAW_TILE_WINDOW 0
+#endif
+
 namespace ck_tile {

 template <typename Problem_, typename Policy_ = TopkSoftmaxWarpPerRowPolicy>
@@ -22,16 +26,42 @@ struct TopkSoftmaxWarpPerRowPipeline
     CK_TILE_DEVICE auto operator()(const InputWindow& input_window,
                                    OutputWindow& out_window,
                                    IndexWindow& idx_window,
+                                   index_t rows,
+                                   index_t experts,
                                    index_t k,
-                                   index_t experts)
+                                   index_t block_row_id)
     {
-        auto input_win = make_tile_window(input_window.get_bottom_tensor_view(),
-                                          input_window.get_window_lengths(),
-                                          input_window.get_window_origin(),
-                                          Policy::template MakeInputDistribution<Problem>());
+#if TOPK_SOFTMAX_USE_RAW_TILE_WINDOW
+        auto inp_win = make_tile_window_linear_raw(
+            input_window, Policy::template MakeInputDistribution<Problem>(), sequence<0, 1>{});
+#else
+        auto inp_win = make_tile_window_linear(
+            input_window, Policy::template MakeInputDistribution<Problem>(), sequence<0, 1>{});
+#endif
+        auto out_win = make_tile_window(out_window.get_bottom_tensor_view(),
+                                        out_window.get_window_lengths(),
+                                        out_window.get_window_origin(),
+                                        Policy::template MakeOutputDistribution<Problem>());
+        auto idx_win = make_tile_window(idx_window.get_bottom_tensor_view(),
+                                        idx_window.get_window_lengths(),
+                                        idx_window.get_window_origin(),
+                                        Policy::template MakeOutputDistribution<Problem>());
+
+        auto softmax = Policy::template GetSoftmax<Problem>();
+        auto topk = Policy::template GetTopk<Problem>();

-        auto x = load_tile(input_win);
+        const index_t grid_rows_per_loop = gridDim.x * Problem::RowsPerBlock;
+        while(1)
+        {
+#if TOPK_SOFTMAX_USE_RAW_TILE_WINDOW
+            __builtin_amdgcn_sched_barrier(0);
+            auto x = load_tile_raw(inp_win, bool_constant<true>{}, bool_constant<true>{});
+            buffer_load_fence(number<0>{});
+            __builtin_amdgcn_sched_barrier(0);
+#else
+            auto x = load_tile(inp_win);
+#endif
         // cast and pad input data
         auto w = [&]() {
             auto w_ = cast_tile<WeightType>(x);
@@ -40,34 +70,38 @@ struct TopkSoftmaxWarpPerRowPipeline
             sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
                 sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
                     constexpr auto i_j_idx = make_tuple(idx0, idx1);
-                    const auto x_indices =
-                        get_x_indices_from_distributed_indices(w_.get_tile_distribution(), i_j_idx);
+                    const auto x_indices = get_x_indices_from_distributed_indices(
+                        w_.get_tile_distribution(), i_j_idx);
                     const auto current_expert = x_indices.at(number<1>{});
                     // set to -INF if OOB so that later softmax can work properly
-                    w_(i_j_idx) =
-                        current_expert >= experts ? -numeric<WeightType>::infinity() : w_(i_j_idx);
+                    w_(i_j_idx) = current_expert >= experts ? -numeric<WeightType>::infinity()
                                                             : w_(i_j_idx);
                 });
             });
             return w_;
         }();

-        auto softmax = Policy::template GetSoftmax<Problem>();
         // softmax
         auto y = softmax(w);

-        auto topk = Policy::template GetTopk<Problem>();
-        auto out_win = make_tile_window(out_window.get_bottom_tensor_view(),
-                                        out_window.get_window_lengths(),
-                                        out_window.get_window_origin(),
-                                        Policy::template MakeOutputDistribution<Problem>());
-        auto idx_win = make_tile_window(idx_window.get_bottom_tensor_view(),
-                                        idx_window.get_window_lengths(),
-                                        idx_window.get_window_origin(),
-                                        Policy::template MakeOutputDistribution<Problem>());
-
         topk(y, out_win, idx_win, k);
+
+            // check exit
+            if constexpr(Problem::LaunchType == 0)
+            {
+                break;
+            }
+            else
+            {
+                block_row_id += grid_rows_per_loop;
+                if(block_row_id >= rows)
+                    break;
+            }
+
+            move_tile_window(inp_win, {grid_rows_per_loop, number<0>{}});
+            move_tile_window(out_win, {grid_rows_per_loop, number<0>{}});
+            move_tile_window(idx_win, {grid_rows_per_loop, number<0>{}});
+        }
     }
 };
 } // namespace ck_tile
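Note: with a persistent launch (LaunchType > 0) the pipeline revisits rows in strides of gridDim.x * RowsPerBlock until it runs past the row count. A minimal sketch of just that traversal, with plain integers standing in for the tile windows:

// Sketch only: the row traversal of the persistent loop, no tile machinery.
// rows_per_block and the persistent/streaming switch mirror the pipeline above.
void visit_rows(int block_row_id, int rows, int grid_blocks, int rows_per_block, bool persistent)
{
    const int grid_rows_per_loop = grid_blocks * rows_per_block;
    while(true)
    {
        // ... process rows [block_row_id, block_row_id + rows_per_block) here ...
        if(!persistent)
            break;                     // streaming launch: one tile per block
        block_row_id += grid_rows_per_loop;
        if(block_row_id >= rows)
            break;                     // persistent launch: stop once past the last row
        // ... advance the input/output windows by grid_rows_per_loop rows ...
    }
}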
@@ -18,7 +18,9 @@ struct TopkSoftmaxWarpPerRowPolicy
         return make_static_tile_distribution(
             tile_distribution_encoding<
                 sequence<1>,
-                tuple<sequence<Problem::IssuesPerCol, Problem::WarpsPerBlock, Problem::RowsPerWarp>,
+                tuple<sequence<Problem::IssuesPerCol,
+                               Problem::WarpsPerBlock,
+                               Problem::RowsPerWarpPerColIssue>,
                       sequence<Problem::IssuesPerRow, Problem::LanesPerRow, Problem::VectorSize>>,
                 tuple<sequence<1>, sequence<1, 2>>,
                 tuple<sequence<1>, sequence<2, 1>>,
@@ -31,12 +33,14 @@ struct TopkSoftmaxWarpPerRowPolicy
     {
         return make_static_tile_distribution(
             tile_distribution_encoding<sequence<Problem::LanesPerRow>, // repeat this one
-                                       tuple<sequence<Problem::WarpsPerBlock, Problem::RowsPerWarp>,
+                                       tuple<sequence<Problem::IssuesPerCol,
+                                                      Problem::WarpsPerBlock,
+                                                      Problem::RowsPerWarpPerColIssue>,
                                              sequence<1>>, // each row write out single element
                                        tuple<sequence<1>, sequence<1, 0>>,
-                                       tuple<sequence<0>, sequence<1, 0>>,
-                                       sequence<2>,
-                                       sequence<0>>{});
+                                       tuple<sequence<1>, sequence<2, 0>>,
+                                       sequence<1, 2>,
+                                       sequence<0, 0>>{});
     }

     template <typename Problem>
......
@@ -13,8 +13,9 @@ template <typename InputType_,
           typename WeightType_,
           typename IndexType_,
           index_t Experts_,
-          index_t IssuesPerCol_ = 1, // issue along col, to make sure block_reduce() OK
+          index_t IssuesPerCol_ = 2, // issue along col, to make sure block_reduce() OK
           index_t BytesPerIssue_ = sizeof(InputType_),
+          index_t LaunchType_ = 0, // 0: streaming, >0: persistent, value = #workgroups per CU (occupancy)
           index_t BlockSize_ = 256>
 struct TopkSoftmaxWarpPerRowProblem
 {
@@ -23,8 +24,10 @@ struct TopkSoftmaxWarpPerRowProblem
     using WeightType = remove_cvref_t<WeightType_>;
     using IndexType = remove_cvref_t<IndexType_>;

+    static constexpr index_t LaunchType = LaunchType_;
     static constexpr index_t Experts = Experts_;
     static constexpr index_t BytesPerIssue = BytesPerIssue_;
+    static constexpr index_t IssuesPerCol = IssuesPerCol_;
     static constexpr index_t BlockSize = BlockSize_;
     static constexpr index_t WarpSize = get_warp_size();
@@ -33,11 +36,10 @@ struct TopkSoftmaxWarpPerRowProblem
     static_assert(Experts % VectorSize == 0);
     static constexpr index_t LanesPerRow = min(Experts / VectorSize, WarpSize);
     static_assert(WarpSize % LanesPerRow == 0);
-    static constexpr index_t RowsPerWarp = WarpSize / LanesPerRow;
+    static constexpr index_t RowsPerWarpPerColIssue = WarpSize / LanesPerRow;
+    static constexpr index_t RowsPerWarp = IssuesPerCol * RowsPerWarpPerColIssue;
     static constexpr index_t IssuesPerRow = Experts / (LanesPerRow * VectorSize);
-    static constexpr index_t IssuesPerCol = IssuesPerCol_;
     static constexpr index_t WarpsPerBlock = BlockSize / WarpSize;
     static constexpr index_t RowsPerBlock = RowsPerWarp * WarpsPerBlock;
 };
......
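Note: as a worked example of the constants above (not taken from the commit itself), assume VectorSize = 1, 64 experts, wave64, a 256-thread block, and the new IssuesPerCol default of 2:

// Illustrative only: plugging assumed values into the formulas shown above.
constexpr int Experts      = 64;
constexpr int VectorSize   = 1;   // assumed; derived from BytesPerIssue in the real problem type
constexpr int WarpSize     = 64;  // assumed wave64
constexpr int BlockSize    = 256;
constexpr int IssuesPerCol = 2;

constexpr int LanesPerRow =
    (Experts / VectorSize) < WarpSize ? (Experts / VectorSize) : WarpSize;     // 64 lanes per row
constexpr int RowsPerWarpPerColIssue = WarpSize / LanesPerRow;                 // 1 row per column issue
constexpr int RowsPerWarp   = IssuesPerCol * RowsPerWarpPerColIssue;           // 2 rows per warp
constexpr int IssuesPerRow  = Experts / (LanesPerRow * VectorSize);            // 1 load issue covers a row
constexpr int WarpsPerBlock = BlockSize / WarpSize;                            // 4 warps per block
constexpr int RowsPerBlock  = RowsPerWarp * WarpsPerBlock;                     // 8 token rows per workgroup

static_assert(RowsPerBlock == 8, "with these assumptions each block covers 8 token rows");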
#!/bin/sh
EXE=./build/bin/test_topk_softmax
for pr_i in "fp16" "bf16" ; do
$EXE -pr_i=$pr_i -t=80 -e=17
$EXE -pr_i=$pr_i -t=111 -e=117
$EXE -pr_i=$pr_i -t=1000 -e=55
$EXE -pr_i=$pr_i -t=99 -e=180
$EXE -pr_i=$pr_i -t=175 -e=64 -k=8
$EXE -pr_i=$pr_i -t=65 -e=8 -k=2
$EXE -pr_i=$pr_i -t=1 -e=25
$EXE -pr_i=$pr_i -t=31 -e=19 -k=15
$EXE -pr_i=$pr_i -t=81 -e=37 -k=7
$EXE -pr_i=$pr_i -t=199 -e=128 -k=13
$EXE -pr_i=$pr_i -t=23 -e=1 -k=1
$EXE -pr_i=$pr_i -t=127 -e=99 -k=19 -st_i=233 -st_o=31
$EXE -pr_i=$pr_i -t=71 -e=11 -k=11 -st_i=30 -st_o=12
$EXE -pr_i=$pr_i -t=1 -e=1 -k=1
$EXE -pr_i=$pr_i -t=99 -e=2 -k=1 -st_i=11 -st_o=5
$EXE -pr_i=$pr_i -t=333 -e=99 -k=13 -st_i=191 -st_o=17
done
@@ -18,6 +18,11 @@
 #define TEST_TOPK_SOFTMAX_VERBOSE 1
 #endif

+// set to 1 to verify results token by token (useful when input/output have a row stride)
+#ifndef TEST_TOPK_VERIFY_PER_TOKEN
+#define TEST_TOPK_VERIFY_PER_TOKEN 1
+#endif
+
 template <typename T>
 void dump_host_tensor_2d(const ck_tile::HostTensor<T>& x)
 {
@@ -62,19 +67,32 @@ auto reference_topk_softmax(const ck_tile::HostTensor<InputType>& x,
 {
     using namespace ck_tile;
-    // dump_host_tensor_2d(x);
     auto y = reference_softmax<InputType, WeightType, WeightType>(x, dim);
-    // dump_host_tensor_2d(y);
     auto [y_values, y_indices] = reference_topk(y, k, dim, largest, sorted);
-    // dump_host_tensor_2d(y_values);
-    // dump_host_tensor_2d(y_indices);
     return ck_tile::make_tuple(y_values, y_indices);
 }

+template <typename InputType, typename WeightType, typename IndexType = ck_tile::index_t>
+auto reference_topk_softmax(const ck_tile::HostTensor<InputType>& x,
+                            ck_tile::HostTensor<WeightType>& y_values,
+                            ck_tile::HostTensor<IndexType>& y_indices,
+                            ck_tile::index_t k,
+                            ck_tile::index_t dim = -1,
+                            bool largest = true,
+                            bool sorted = true)
+{
+    using namespace ck_tile;
+    // dump_host_tensor_2d(x);
+    auto y = reference_softmax<InputType, WeightType, WeightType>(x, dim);
+    // dump_host_tensor_2d(y);
+    reference_topk(y, y_values, y_indices, k, dim, largest, sorted);
+}
+
 // different threshold for different dtype
 template <typename DataType>
 auto get_elimit(std::string /*init_method*/)
@@ -113,12 +131,13 @@ auto create_args(int argc, char* argv[])
 {
     ck_tile::ArgParser arg_parser;
     arg_parser.insert("v", "1", "weather do CPU validation or not")
-        .insert(
-            "input_prec", "fp16", "input data type. fp8/fp16/fp32 (representing 8/16/32 bit data)")
-        .insert("weight_prec", "fp32", "weight data type")
+        .insert("pr_i", "fp16", "input data type. fp16/fp32 (representing 16/32 bit data)")
+        .insert("pr_w", "fp32", "weight data type (currently only fp32 supported)")
         .insert("t", "32", "number of input tokens")
         .insert("e", "8", "number of experts")
         .insert("k", "2", "topk")
+        .insert("st_i", "-1", "row stride of input, -1 means same as experts")
+        .insert("st_o", "-1", "row stride of output/indices, -1 means same as topk")
         .insert("seed", "-1", "seed to be used, -1 means random every time")
         .insert("kname", "0", "t to 1 will print kernel name");
@@ -130,12 +149,25 @@ template <typename InputType, typename WeightType, typename IndexType = ck_tile:
 bool test_topk_softmax(ck_tile::ArgParser args)
 {
     int validate = args.get_int("v");
-    std::string input_prec = args.get_str("input_prec");
-    std::string weight_prec = args.get_str("weight_prec");
+    std::string input_prec = args.get_str("pr_i");
+    std::string weight_prec = args.get_str("pr_w");
     int tokens = args.get_int("t");
     int experts = args.get_int("e");
     int topk = args.get_int("k");
     int seed = args.get_int("seed");
+    int stride_input = args.get_int("st_i");
+    int stride_output = args.get_int("st_o");
+    if(stride_input < 0)
+    {
+        stride_input = experts;
+    }
+    if(stride_output < 0)
+    {
+        stride_output = topk;
+    }
+    assert(stride_input >= experts);
+    assert(stride_output >= topk);
     if(seed < 0)
     {
         seed = std::time(nullptr);
@@ -153,9 +185,9 @@ bool test_topk_softmax(ck_tile::ArgParser args)
     }

     // tokens already considered batch size
-    ck_tile::HostTensor<InputType> x_host({tokens, experts});
-    ck_tile::HostTensor<WeightType> value_host({tokens, topk});
-    ck_tile::HostTensor<IndexType> index_host({tokens, topk});
+    ck_tile::HostTensor<InputType> x_host({tokens, experts}, {stride_input, 1});
+    ck_tile::HostTensor<WeightType> value_host({tokens, topk}, {stride_output, 1});
+    ck_tile::HostTensor<IndexType> index_host({tokens, topk}, {stride_output, 1});
     {
         // random require per-row unique
@@ -166,7 +198,7 @@ bool test_topk_softmax(ck_tile::ArgParser args)
         {
             ck_tile::HostTensor<InputType> x_row({experts});
             rand_gen(x_row);
-            std::copy(x_row.begin(), x_row.end(), x_host.begin() + i_t * experts);
+            std::copy(x_row.begin(), x_row.end(), x_host.begin() + i_t * stride_input);
             rand_gen.clear();
         }
     }
@@ -193,24 +225,35 @@ bool test_topk_softmax(ck_tile::ArgParser args)
         a_.num_rows = tokens;
         a_.num_experts = experts;
         a_.topk = topk;
+        a_.stride_input = stride_input;
+        a_.stride_output = stride_output;
         return a_;
     }();

 #if TEST_TOPK_SOFTMAX_VERBOSE
     ck_tile::stream_config sc{nullptr, true};
+    // ck_tile::stream_config sc{nullptr};
     auto ms = topk_softmax(trait, karg, sc);
-    printf("[%s|%s]tokens:%d, experts:%d, topk:%d, ms:%f, ",
+    printf("[%s|%s]tokens:%d, experts:%d, topk:%d, st_i:%d, st_o:%d, ms:%f, ",
            input_prec.c_str(),
            weight_prec.c_str(),
            tokens,
            experts,
            topk,
+           stride_input,
+           stride_output,
            ms);
+    if(ms < 0)
+        printf("not supported\n");
     fflush(stdout);
 #else
     ck_tile::stream_config sc{nullptr};
-    topk_softmax(trait, karg, sc);
+    auto ms = topk_softmax(trait, karg, sc);
 #endif
+    if(ms < 0)
+    {
+        return false;
+    }

     value_dev.FromDevice(value_host.data());
     index_dev.FromDevice(index_host.data());
@@ -218,17 +261,44 @@ bool test_topk_softmax(ck_tile::ArgParser args)
     bool rtn = true;
     if(validate)
     {
-        ck_tile::HostTensor<WeightType> value_host_ref({tokens, topk});
-        ck_tile::HostTensor<IndexType> index_host_ref({tokens, topk});
+        // these host buffers are not copied to the GPU, so the stride is not strictly required
+        ck_tile::HostTensor<WeightType> value_ref({tokens, topk}, {stride_output, 1});
+        ck_tile::HostTensor<IndexType> index_ref({tokens, topk}, {stride_output, 1});

-        auto [value_ref, index_ref] =
-            reference_topk_softmax<InputType, WeightType, IndexType>(x_host, topk);
+        // auto [value_ref, index_ref] =
+        reference_topk_softmax<InputType, WeightType, IndexType>(
+            x_host, value_ref, index_ref, topk);

         auto [rtol, atol] = get_elimit<InputType>("");
+#if TEST_TOPK_VERIFY_PER_TOKEN
+        for(int i_t = 0; i_t < tokens; i_t++)
+        {
+            auto s_begin = std::vector<size_t>{static_cast<size_t>(i_t), static_cast<size_t>(0)};
+            auto s_end =
+                std::vector<size_t>{static_cast<size_t>(i_t + 1), static_cast<size_t>(topk)};
+            auto s_value_host = value_host.slice(s_begin, s_end);
+            auto s_value_ref = value_ref.slice(s_begin, s_end);
+            rtn &= ck_tile::check_err(s_value_host,
+                                      s_value_ref,
+                                      std::string("[") + std::to_string(i_t) +
+                                          std::string("] Value Error:"),
+                                      rtol,
+                                      atol);
+            auto s_index_host = index_host.slice(s_begin, s_end);
+            auto s_index_ref = index_ref.slice(s_begin, s_end);
+            rtn &= ck_tile::check_err(s_index_host,
+                                      s_index_ref,
+                                      std::string("[") + std::to_string(i_t) +
+                                          std::string("] Index Error:"),
+                                      rtol,
+                                      atol);
+        }
+#else
         rtn &= ck_tile::check_err(
             value_host, value_ref, std::string("Value Error: Incorrect results!"), rtol, atol);
         rtn &= ck_tile::check_err(
             index_host, index_ref, std::string("Index Error: Incorrect results!"), rtol, atol);
+#endif
     }
 #if TEST_TOPK_SOFTMAX_VERBOSE
     printf("valid:%s\n", rtn ? "y" : "n");
@@ -242,8 +312,8 @@ int main(int argc, char** argv)
     auto [result, args] = create_args(argc, argv);
     if(!result)
         return -1;
-    std::string input_prec = args.get_str("input_prec");
-    std::string weight_prec = args.get_str("weight_prec");
+    std::string input_prec = args.get_str("pr_i");
+    std::string weight_prec = args.get_str("pr_w");
     bool r = true;
     if(input_prec.compare("fp16") == 0 && weight_prec.compare("fp32") == 0)
......
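Note: the test now allocates its host tensors with an explicit row stride (st_i for the input, st_o for the output/indices), so row r of the input starts at offset r * stride_input and only the first `experts` entries of each row are meaningful. A small standalone sketch of that layout arithmetic (plain C++, not the ck_tile HostTensor API):

#include <cassert>
#include <vector>

// Sketch only: a strided row-major 2D buffer like the one the test builds.
// The stride must be at least the column count; the tail of each row is padding.
struct Strided2D
{
    std::vector<float> data;
    int rows, cols, stride;

    Strided2D(int rows_, int cols_, int stride_)
        : data(static_cast<size_t>(rows_) * stride_), rows(rows_), cols(cols_), stride(stride_)
    {
        assert(stride >= cols);
    }

    float& at(int r, int c) { return data[static_cast<size_t>(r) * stride + c]; }
};

// Usage mirroring one case from the script above (-t=127 -e=99 -st_i=233):
// Strided2D x(/*tokens=*/127, /*experts=*/99, /*stride_input=*/233);
// x.at(5, 0) is expert 0 of token 5, located at element offset 5 * 233.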