Commit 667047b9 authored by carlushuang

topk-softmax

parent 840cba8e
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include <string>
#include <type_traits>
namespace ck_tile {
template <typename InputType_,
          typename WeightType_,
          typename IndexType_,
          index_t Experts_,
          index_t IssuesPerCol_  = 1, // issue along col, to make sure block_reduce() OK
          index_t BytesPerIssue_ = sizeof(InputType_),
          index_t BlockSize_     = 256>
struct TopkSoftmaxWarpPerRowProblem
{
    // TODO: this kernel only supports warp-per-row
    using InputType  = remove_cvref_t<InputType_>;
    using WeightType = remove_cvref_t<WeightType_>;
    using IndexType  = remove_cvref_t<IndexType_>;

    static constexpr index_t Experts       = Experts_;
    static constexpr index_t BytesPerIssue = BytesPerIssue_;
    static constexpr index_t BlockSize     = BlockSize_;
    static constexpr index_t WarpSize      = get_warp_size();

    static_assert(BytesPerIssue % sizeof(InputType) == 0);
    static constexpr index_t VectorSize = BytesPerIssue / sizeof(InputType);

    static_assert(Experts % VectorSize == 0);
    static constexpr index_t LanesPerRow = min(Experts / VectorSize, WarpSize);

    static_assert(WarpSize % LanesPerRow == 0);
    static constexpr index_t RowsPerWarp   = WarpSize / LanesPerRow;
    static constexpr index_t IssuesPerRow  = Experts / (LanesPerRow * VectorSize);
    static constexpr index_t IssuesPerCol  = IssuesPerCol_;
    static constexpr index_t WarpsPerBlock = BlockSize / WarpSize;
    static constexpr index_t RowsPerBlock  = RowsPerWarp * WarpsPerBlock;
};
} // namespace ck_tile
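For intuition, here is how the derived constants work out for a typical instantiation. This is a hedged sketch, not part of the commit: it assumes the header above is in scope and that get_warp_size() returns 64 (wave64 hardware).

// Usage sketch (our example, not the commit's code). With fp16 inputs and
// 8 experts, BytesPerIssue defaults to sizeof(fp16_t) == 2, so each issue
// loads a single element and 8 lanes cooperate on one row. All numbers
// below assume get_warp_size() == 64.
using Problem = ck_tile::TopkSoftmaxWarpPerRowProblem<ck_tile::fp16_t, // InputType
                                                      float,           // WeightType
                                                      ck_tile::index_t,
                                                      8>;              // Experts
static_assert(Problem::VectorSize == 1);    // 2 bytes per issue / 2-byte fp16
static_assert(Problem::LanesPerRow == 8);   // min(8 / 1, 64)
static_assert(Problem::RowsPerWarp == 8);   // 64 / 8
static_assert(Problem::IssuesPerRow == 1);  // 8 / (8 * 1)
static_assert(Problem::WarpsPerBlock == 4); // 256 / 64
static_assert(Problem::RowsPerBlock == 32); // 8 rows/warp * 4 warps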
@@ -219,5 +219,6 @@ endif()
 add_subdirectory(position_embedding)
 add_subdirectory(scatter_gather)
 add_subdirectory(topk)
+add_subdirectory(topk_softmax)
 add_subdirectory(tile_reduce)
@@ -50,11 +50,11 @@ __global__ void reduce_row(DataType* p_src, DataType* p_dst)
     constexpr auto src_dist = make_static_tile_distribution(
         tile_distribution_encoding<
             sequence<1>,
-            tuple<sequence<row_repeat, num_warps, row_lanes>, sequence<col_lanes, vec>>,
+            tuple<sequence<row_repeat, num_warps, row_lanes>, sequence<1, col_lanes, vec>>,
             tuple<sequence<1>, sequence<1, 2>>,
-            tuple<sequence<1>, sequence<2, 0>>,
-            sequence<1, 2>,
-            sequence<0, 1>>{});
+            tuple<sequence<1>, sequence<2, 1>>,
+            sequence<1, 2, 2>,
+            sequence<0, 0, 2>>{});
     auto src_view =
         make_naive_tensor_view<address_space_enum::global>(p_src,
@@ -98,7 +98,7 @@ __global__ void reduce_row(DataType* p_src, DataType* p_dst)
         block_tile_reduce<DataType>(data, sequence<1>{}, f_max, -numeric<DataType>::infinity());
     // further reduce cross thread, Note Now the HLength of r is 1D
-    block_tile_reduce_sync(r, f_max, bool_constant<false>{});
+    block_tile_reduce_xor_sync(r, f_max);
     if(threadIdx.x % col_lanes == 0)
     {
@@ -205,7 +205,7 @@ __global__ void reduce_row_argmax(DataType* p_src, DataType* p_dst, int* p_idx)
     auto r = block_tile_reduce<kv>(kv_data, sequence<1>{}, f_arg_max, arg_max_init);
     // further reduce cross thread, Note Now the HLength of r is 1D
-    block_tile_reduce_sync(r, f_arg_max, bool_constant<false>{});
+    block_tile_reduce_xor_sync(r, f_arg_max);
     auto o = make_static_distributed_tensor<DataType>(dst_dist);
     auto i = make_static_distributed_tensor<int>(dst_dist);
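Both hunks above replace the generic block_tile_reduce_sync(r, f, bool_constant<false>{}) with block_tile_reduce_xor_sync(r, f), i.e. a butterfly (XOR) cross-lane exchange. As a rough illustration of the pattern, not ck_tile's implementation (xor_reduce and its signature are our invention), the lane-level idea in HIP is:

// Sketch: butterfly reduction over the `col_lanes` lanes sharing one row.
// After log2(col_lanes) XOR-swizzle steps, every participating lane holds
// the reduced value, so no extra broadcast is needed (unlike a tree
// reduction that funnels the result into lane 0).
template <typename T, typename F>
__device__ T xor_reduce(T v, F f, int col_lanes)
{
    for(int mask = col_lanes / 2; mask > 0; mask /= 2)
        v = f(v, __shfl_xor(v, mask));
    return v;
}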
@@ -368,7 +368,7 @@ int main()
 {
     bool r = true;
     r &= test_tile_reduce<32, 64, float>();
-    r &= test_tile_reduce<32, 16, float, 4>();
+    r &= test_tile_reduce<32, 8, float, 4>();
     r &= test_tile_reduce<32, 16, ck_tile::fp16_t, 4>();
     r &= test_tile_reduce_argmax<32, 16, float, 4>();
...
add_test_executable(test_topk_softmax topk_softmax.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../../example/ck_tile/05_moe/topk_softmax_api.cpp)
target_include_directories(test_topk_softmax PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../example/ck_tile/05_moe)
target_compile_options(test_topk_softmax PRIVATE -v --save-temps -Wno-gnu-line-marker)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cassert>
#include <cstdlib>
#include <ctime>
#include <iostream>
#include <numeric>
#include <unordered_set>
#include <vector>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/reduce.hpp"
#include "topk_softmax_api.hpp"
// #ifndef TEST_TOPK_SOFTMAX_VERBOSE
// #define TEST_TOPK_SOFTMAX_VERBOSE 0
// #endif
// #define BLOCK_SIZE 256
template <typename T>
void dump_host_tensor_2d(const ck_tile::HostTensor<T>& x)
{
    auto len = x.get_lengths();
    assert(len.size() == 2);
    std::cout << "[";
    for(size_t i = 0; i < len[0]; i++)
    {
        std::cout << "[";
        for(size_t j = 0; j < len[1]; j++)
        {
            if constexpr(std::is_same_v<T, ck_tile::fp16_t>)
            {
                auto v = ck_tile::type_convert<float>(x(std::vector<std::size_t>{i, j}));
                std::cout << v;
                if(j != len[1] - 1)
                    std::cout << ",";
            }
            else
            {
                std::cout << x(std::vector<std::size_t>{i, j}) << " ";
            }
        }
        std::cout << "]";
        if(i != len[0] - 1)
            std::cout << ",";
        else
            std::cout << "]";
        std::cout << std::endl;
    }
    std::cout << "--------------------" << std::endl;
}
// CPU reference
template <typename InputType, typename WeightType, typename IndexType = ck_tile::index_t>
auto reference_topk_softmax(const ck_tile::HostTensor<InputType>& x,
                            ck_tile::index_t k,
                            ck_tile::index_t dim = -1,
                            bool largest = true,
                            bool sorted  = true)
{
    using namespace ck_tile;
    // dump_host_tensor_2d(x);
    auto y = reference_softmax<InputType, WeightType, WeightType>(x, dim);
    // dump_host_tensor_2d(y);
    auto [y_values, y_indices] = reference_topk(y, k, dim, largest, sorted);
    // dump_host_tensor_2d(y_values);
    // dump_host_tensor_2d(y_indices);
    return ck_tile::make_tuple(y_values, y_indices);
}
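reference_softmax and reference_topk above are ck_tile host utilities. For readers without them, here is a self-contained sketch of the same per-row math (numerically stable softmax over the experts axis, then the k largest probabilities in descending order); topk_softmax_row and everything inside it is our illustration, not the library API.

#include <algorithm>
#include <cmath>
#include <numeric>
#include <vector>

// Sketch: softmax over `experts` logits of one row, then top-k.
inline void topk_softmax_row(const float* logits, int experts, int k,
                             std::vector<float>& vals, std::vector<int>& idx)
{
    // numerically stable softmax: subtract the row max before exp()
    float m = *std::max_element(logits, logits + experts);
    std::vector<float> p(experts);
    float sum = 0.f;
    for(int e = 0; e < experts; ++e)
        sum += (p[e] = std::exp(logits[e] - m));
    for(float& v : p)
        v /= sum;
    // top-k expert indices by probability, descending
    std::vector<int> order(experts);
    std::iota(order.begin(), order.end(), 0);
    std::partial_sort(order.begin(), order.begin() + k, order.end(),
                      [&](int a, int b) { return p[a] > p[b]; });
    vals.resize(k);
    idx.resize(k);
    for(int i = 0; i < k; ++i)
    {
        idx[i]  = order[i];
        vals[i] = p[order[i]];
    }
}

For example, logits {1, 2, 3, 4} with k = 2 give softmax probabilities of roughly {0.032, 0.087, 0.237, 0.644}, so the sketch returns values {0.644, 0.237} with indices {3, 2}.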
// different threshold for different dtype
template <typename DataType>
auto get_elimit(std::string /*init_method*/)
{
    double rtol = 1e-3;
    double atol = 1e-3;
    return ck_tile::make_tuple(rtol, atol);
}

template <>
auto get_elimit<ck_tile::bf16_t>(std::string /*init_method*/)
{
    double rtol = 1e-2;
    double atol = 1e-2;
    return ck_tile::make_tuple(rtol, atol);
}

template <>
auto get_elimit<ck_tile::fp8_t>(std::string init_method)
{
    if(init_method == "ui" || init_method == "ni")
    {
        unsigned max_rounding_point_distance = 0;
        double atol                          = 2e-3;
        return ck_tile::make_tuple(max_rounding_point_distance, atol);
    }
    else
    {
        unsigned max_rounding_point_distance = 1;
        double atol                          = 0.0625;
        return ck_tile::make_tuple(max_rounding_point_distance, atol);
    }
}
auto create_args(int argc, char* argv[])
{
    ck_tile::ArgParser arg_parser;
    arg_parser.insert("v", "1", "whether to do CPU validation or not")
        .insert(
            "input_prec", "fp16", "input data type. fp8/fp16/fp32 (representing 8/16/32 bit data)")
        .insert("weight_prec", "fp32", "weight data type")
        .insert("t", "32", "number of input tokens")
        .insert("e", "8", "number of experts")
        .insert("k", "2", "topk")
        .insert("kname", "0", "set to 1 to print kernel name");

    bool result = arg_parser.parse(argc, argv);
    return std::make_tuple(result, arg_parser);
}
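Assuming the -name=value command-line syntax that ck_tile::ArgParser uses in the other ck_tile examples (the exact invocation below is our guess, not taken from this commit), a validated run could look like:

./test_topk_softmax -input_prec=fp16 -weight_prec=fp32 -t=128 -e=8 -k=2 -v=1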
template <typename InputType, typename WeightType, typename IndexType = ck_tile::index_t>
bool test_topk_softmax(ck_tile::ArgParser args)
{
    int validate            = args.get_int("v");
    std::string input_prec  = args.get_str("input_prec");
    std::string weight_prec = args.get_str("weight_prec");
    int tokens              = args.get_int("t");
    int experts             = args.get_int("e");
    int topk                = args.get_int("k");
    // int kname  = args.get_int("kname");
    // int warmup = args.get_int("warmup");
    // int repeat = args.get_int("repeat");

    std::srand(std::time(nullptr));

    // tokens already considered batch size
    ck_tile::HostTensor<InputType> x_host({tokens, experts});
    ck_tile::HostTensor<WeightType> value_host({tokens, topk});
    ck_tile::HostTensor<IndexType> index_host({tokens, topk});

    ck_tile::FillUniformDistribution<InputType>{-5.f, 5.f}(x_host);

    ck_tile::DeviceMem x_dev(x_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem value_dev(value_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem index_dev(index_host.get_element_space_size_in_bytes());

    x_dev.ToDevice(x_host.data());

    topk_softmax_trait trait = [&]() {
        topk_softmax_trait t_;
        t_.input_type  = input_prec;
        t_.weight_type = weight_prec;
        t_.experts     = experts;
        return t_;
    }();

    topk_softmax_kargs karg = [&]() {
        topk_softmax_kargs a_;
        a_.p_input     = x_dev.GetDeviceBuffer();
        a_.p_output    = value_dev.GetDeviceBuffer();
        a_.p_indices   = index_dev.GetDeviceBuffer();
        a_.num_rows    = tokens;
        a_.num_experts = experts;
        a_.topk        = topk;
        return a_;
    }();

    ck_tile::stream_config sc{nullptr};
    topk_softmax(trait, karg, sc);

    value_dev.FromDevice(value_host.data());
    index_dev.FromDevice(index_host.data());

    bool rtn = true;
    if(validate)
    {
        auto [value_ref, index_ref] =
            reference_topk_softmax<InputType, WeightType, IndexType>(x_host, topk);
        auto [rtol, atol] = get_elimit<InputType>("");
        rtn &= ck_tile::check_err(
            value_host, value_ref, std::string("Value Error: Incorrect results!"), rtol, atol);
        rtn &= ck_tile::check_err(
            index_host, index_ref, std::string("Index Error: Incorrect results!"), rtol, atol);
    }
    return rtn;
}
int main(int argc, char** argv)
{
    auto [result, args] = create_args(argc, argv);
    if(!result)
        return -1;

    std::string input_prec  = args.get_str("input_prec");
    std::string weight_prec = args.get_str("weight_prec");

    bool r = true;
    if(input_prec.compare("fp16") == 0 && weight_prec.compare("fp32") == 0)
    {
        r &= test_topk_softmax<ck_tile::fp16_t, float, ck_tile::index_t>(args);
    }
    else if(input_prec.compare("bf16") == 0 && weight_prec.compare("fp32") == 0)
    {
        r &= test_topk_softmax<ck_tile::bf16_t, float, ck_tile::index_t>(args);
    }
    return r ? 0 : -1;
}