Unverified Commit bec6fbc6 authored by dummycoderfe's avatar dummycoderfe Committed by GitHub
Browse files

Ck tile/moe sorting (#1624)



* add moe_sorting & check ok

* fix comments & typo

* Run remod.py under include/ck_tile & example/ck_tile directories

* format codes

* fix output ci check bug

* fix moe sorting readme and error commit file

* use magiv div to accelerate compute

* add an loop unroll for moe lds ops

* add extblocksnel to set zeros for moebufs

* [Ck_tile] moe set zero run ok, add size check and fix ref check

* [Ck_tile]fix moe_sorting fuse set_zero remod

* [Ck_tile] change name style, fix zero buffer size err, change folder

* [Ck_tile] moe_sorting: fix name style

* [Ck_tile] moe_sorting, remove useless params in traits

* [Ck_tile] change outputtile cnt * unit_size; change output buf alloc

---------
Co-authored-by: default avatardummycoderfe <noplydummmycoder@163.com>
Co-authored-by: default avatarPo Yen, Chen <PoYen.Chen@amd.com>
Co-authored-by: default avatarcarlushuang <carlus.huang@amd.com>
parent af9546d9
add_executable(tile_example_moe_sorting EXCLUDE_FROM_ALL moe_sorting.cpp moe_sorting_api.cpp)
target_include_directories(tile_example_moe_sorting PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/)
set(EXAMPLE_MOE_SORTING_COMPILE_OPTIONS)
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
list(APPEND EXAMPLE_MOE_SORTING_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
# list(APPEND EXAMPLE_MOE_SORTING_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
target_compile_options(tile_example_moe_sorting PRIVATE ${EXAMPLE_MOE_SORTING_COMPILE_OPTIONS})
# moe-sorting
This folder contains example for moe-sorting kernel using ck_tile tile-programming implementation. This kernel is often used in Moe model, before launching the fused-moe-gemm block. The input&weight is a `token*topk` 2d matrix. The op rearange the input weight ids into different experts and feed into fuse moe gemm kernel.
## build
```
# in the root of ck_tile
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
make tile_example_moe_sorting -j
```
This will result in an executable `build/bin/tile_example_moe_sorting`
## example
```
args:
-v weather do CPU validation or not (default:1)
-pr_i index data type. (currently only fp32 supported now) (default:int32)
-pr_w output weight data type(currently only fp32 supported now) (default:fp32)
-t number of input tokens (default:32)
-e number of experts (default:8)
-k topk (default:2)
-st_i row stride of input, -1 means same as experts (default:-1)
-seed seed to be used, -1 means random every time (default:-1)
-kname when set to 1 it will print kernel name (default:0)
```
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <set>
#include <vector>
#include <iostream>
#include <numeric>
#include <cassert>
#include <cstdlib>
#include <iostream>
#include <time.h>
#include <unordered_set>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/reduce.hpp"
#include "moe_sorting_api.hpp"
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("v", "1", "weather do CPU validation or not")
.insert("pr_i", "int32", "index data type. (currently only int32 supported now)")
.insert("pr_w", "fp32", "output weight data type(currently only fp32 supported now)")
.insert("t", "128", "number of input tokens")
.insert("e", "8", "number of num_experts")
.insert("k", "4", "topk")
.insert("unit", "32", "unit_size")
.insert("moe_buf_size", "0", "moe_buf_size")
.insert("seed", "-1", "seed to be used, -1 means random every time")
.insert("kname", "0", "when set to 1 it will print kernel name")
.insert("warmup", "5", "number of iterations before benchmark the kernel")
.insert("repeat", "20", "number of iterations to benchmark the kernel");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
template <typename IndexType>
void topid_unique_gen(
std::vector<IndexType>& host_tensor, int tokens, int topk, int num_expert, int seed)
{
size_t total_size = topk * tokens;
std::srand(seed);
std::set<IndexType> unique_set;
IndexType current_v;
for(size_t i = 0; i < total_size; i++)
{
if(i % topk == 0)
{
unique_set.clear();
}
current_v = std::rand() % num_expert;
while(unique_set.find(current_v) != unique_set.end())
{
current_v = std::rand() % num_expert;
}
unique_set.insert(current_v);
host_tensor[i] = current_v;
}
}
template <typename WeightType, typename IndexType = ck_tile::index_t>
bool test_moe_sorting(ck_tile::ArgParser args)
{
int validate = args.get_int("v");
std::string index_prec = args.get_str("pr_i");
std::string weight_prec = args.get_str("pr_w");
int tokens = args.get_int("t");
int num_experts = args.get_int("e");
int topk = args.get_int("k");
int seed = args.get_int("seed");
int unit_size = args.get_int("unit");
int moe_buf_size = args.get_int("moe_buf_size");
int kname = args.get_int("kname");
int warmup = args.get_int("warmup");
int repeat = args.get_int("repeat");
int max_output_ids =
ck_tile::integer_least_multiple(topk * tokens + num_experts * unit_size - topk, unit_size);
if(seed < 0)
{
seed = std::time(nullptr);
}
if(topk > num_experts)
{
printf("topk:%d value should be smaller than, or equal to number of num_experts:%d\n",
topk,
num_experts);
return false;
}
// tokens already considered batch size
ck_tile::HostTensor<IndexType> topk_ids_host({tokens, topk}, {topk, 1});
ck_tile::HostTensor<WeightType> weights_host({tokens, topk}, {topk, 1});
ck_tile::HostTensor<IndexType> sorted_ids_host({max_output_ids}, {1});
ck_tile::HostTensor<WeightType> sorted_weights_host({max_output_ids}, {1});
ck_tile::HostTensor<IndexType> sorted_expert_ids_host({max_output_ids / unit_size}, {1});
ck_tile::HostTensor<IndexType> sorted_id_cnt_host({1}, {1});
ck_tile::HostTensor<float> moe_buf_host({moe_buf_size});
ck_tile::FillUniformDistribution<WeightType>{-.5f, .5f}(weights_host);
ck_tile::FillUniformDistribution<WeightType>{-.5f, .5f}(moe_buf_host);
topid_unique_gen<IndexType>(topk_ids_host.mData, tokens, topk, num_experts, seed);
ck_tile::DeviceMem topk_ids_dev(topk_ids_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem weights_dev(weights_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem sorted_ids_dev(sorted_ids_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem sorted_weights_dev(sorted_weights_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem sorted_expert_ids_dev(
sorted_expert_ids_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem sorted_id_cnt_dev(sorted_id_cnt_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem moe_buf_dev(moe_buf_host.get_element_space_size_in_bytes());
topk_ids_dev.ToDevice(topk_ids_host.data());
weights_dev.ToDevice(weights_host.data());
if(moe_buf_size > 0)
{
moe_buf_dev.ToDevice(moe_buf_host.data());
}
moe_sorting_trait trait{index_prec, weight_prec};
moe_sorting_args karg{topk_ids_dev.GetDeviceBuffer(),
weights_dev.GetDeviceBuffer(),
sorted_ids_dev.GetDeviceBuffer(),
sorted_weights_dev.GetDeviceBuffer(),
sorted_expert_ids_dev.GetDeviceBuffer(),
sorted_id_cnt_dev.GetDeviceBuffer(),
moe_buf_size > 0 ? moe_buf_dev.GetDeviceBuffer() : nullptr,
tokens,
unit_size,
num_experts,
topk,
static_cast<ck_tile::index_t>(moe_buf_size * sizeof(float))};
ck_tile::stream_config sc{nullptr,
true,
/* log_level = */ (kname ? 1 : 0),
warmup,
repeat};
auto ms = moe_sorting(trait, karg, sc);
printf("[%s|%s]tokens:%d, num_experts:%d, topk:%d, ms:%f , ",
index_prec.c_str(),
weight_prec.c_str(),
tokens,
num_experts,
topk,
ms);
if(ms < 0)
printf("not supported\n");
fflush(stdout);
if(ms < 0)
{
return false;
}
sorted_ids_dev.FromDevice(sorted_ids_host.data());
sorted_weights_dev.FromDevice(sorted_weights_host.data());
sorted_expert_ids_dev.FromDevice(sorted_expert_ids_host.data());
sorted_id_cnt_dev.FromDevice(sorted_id_cnt_host.data());
if(moe_buf_size > 0)
{
moe_buf_dev.FromDevice(moe_buf_host.data());
}
bool rtn = true;
if(validate)
{
ck_tile::HostTensor<IndexType> sorted_ids_ref({max_output_ids}, {1});
ck_tile::HostTensor<WeightType> sorted_weights_ref({max_output_ids}, {1});
ck_tile::HostTensor<IndexType> sorted_expert_ids_ref({max_output_ids / unit_size}, {1});
int32_t ref_total_tokens_post_pad = 0;
ck_tile::reference_moe_sorting<WeightType, IndexType>(topk_ids_host,
weights_host,
sorted_ids_ref,
sorted_weights_ref,
sorted_expert_ids_ref,
ref_total_tokens_post_pad,
num_experts,
unit_size);
rtn &= ck_tile::check_err(
sorted_ids_host, sorted_ids_ref, std::string("OUT Error: Incorrect ids!"), 1e-6, 1e-6);
rtn &= ck_tile::check_err(sorted_weights_host,
sorted_weights_ref,
std::string("OUT Error: Incorrect w!"),
1e-6,
1e-6);
rtn &= ck_tile::check_err(sorted_expert_ids_host,
sorted_expert_ids_ref,
std::string("OUT Error: Incorrect eid!"),
1e-6,
1e-6);
if(moe_buf_size)
{
ck_tile::HostTensor<WeightType> moe_buf_ref({moe_buf_size});
rtn &= ck_tile::check_err(
moe_buf_host, moe_buf_ref, std::string("OUT Error: Incorrect zero buf!"), 0, 0);
}
rtn &= ref_total_tokens_post_pad == sorted_id_cnt_host.mData[0];
}
printf("valid:%s\n", rtn ? "y" : "n");
fflush(stdout);
return rtn;
}
int main(int argc, char** argv)
{
auto [result, args] = create_args(argc, argv);
if(!result)
return -1;
std::string index_prec = args.get_str("pr_i");
std::string weight_prec = args.get_str("pr_w");
bool r = true;
if(weight_prec.compare("fp32") == 0 && index_prec.compare("int32") == 0)
{
r &= test_moe_sorting<float, ck_tile::index_t>(args);
}
return r ? 0 : -1;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "moe_sorting_api.hpp"
#define MOE_SORTING_DISPATCH(unroll_num_) \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
using ms_problem = ck_tile::MoeSortingProblem<index_t, ms_weight_type, unroll_num>; \
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
const auto lds_bytes = kernel::GetSmemSize(a); \
float ave_time = ck_tile::launch_kernel( \
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
return ave_time;
float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s)
{
if(t.weight_type == "fp32" && t.index_type == "int32")
{
if(a.num_experts > 127)
{
printf("lds size exceed, only support experts <127 \n");
return -1;
}
if(a.moe_buf_bytes % 16)
{
printf("buf set size %d unaligned, must be multiple of 16\n", a.moe_buf_bytes);
return -1;
}
using index_t = ck_tile::index_t;
using ms_weight_type = float;
index_t smem_io_unroll_num = ck_tile::integer_divide_ceil(a.tokens * a.topk, 64);
switch(smem_io_unroll_num)
{
case(1): {
MOE_SORTING_DISPATCH(1);
}
case(2): {
MOE_SORTING_DISPATCH(2);
}
case(3): {
MOE_SORTING_DISPATCH(3);
}
case(5): {
MOE_SORTING_DISPATCH(5);
}
case(6): {
MOE_SORTING_DISPATCH(6);
}
case(7): {
MOE_SORTING_DISPATCH(7);
}
case(8): {
MOE_SORTING_DISPATCH(8);
}
case(9): {
MOE_SORTING_DISPATCH(9);
}
case(10): {
MOE_SORTING_DISPATCH(10);
}
case(11): {
MOE_SORTING_DISPATCH(11);
}
default: {
MOE_SORTING_DISPATCH(4);
}
}
}
return -1;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/ops/moe_sorting.hpp"
struct moe_sorting_trait
{
std::string index_type;
std::string weight_type; // currently always float
};
struct moe_sorting_args : public ck_tile::MoeSortingHostArgs
{
};
float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s);
# #!/bin/sh
EXE=./build/bin/tile_example_moe_sorting
$EXE -t=80 -e=17 -moe_buf_size=16
$EXE -t=111 -e=117 -moe_buf_size=4
$EXE -t=1000 -e=55 -moe_buf_size=1024
$EXE -t=99 -e=120 -moe_buf_size=10244
$EXE -t=175 -e=64 -k=8
$EXE -t=65 -e=8 -k=2
$EXE -t=1 -e=25
$EXE -t=31 -e=19 -k=15
$EXE -t=81 -e=37 -k=7
$EXE -t=23 -e=1 -k=1
$EXE -t=127 -e=99 -k=19
$EXE -t=71 -e=11 -k=11
$EXE -t=1 -e=1 -k=1
$EXE -t=99 -e=2 -k=1
$EXE -t=333 -e=99 -k=13
\ No newline at end of file
......@@ -12,3 +12,4 @@ add_subdirectory(09_topk_softmax)
add_subdirectory(10_rmsnorm2d)
add_subdirectory(11_add_rmsnorm2d_rdquant)
add_subdirectory(12_smoothquant)
add_subdirectory(13_moe_sorting)
......@@ -23,6 +23,7 @@
#include "ck_tile/host/reference/reference_gemm.hpp"
#include "ck_tile/host/reference/reference_im2col.hpp"
#include "ck_tile/host/reference/reference_layernorm2d_fwd.hpp"
#include "ck_tile/host/reference/reference_moe_sorting.hpp"
#include "ck_tile/host/reference/reference_permute.hpp"
#include "ck_tile/host/reference/reference_reduce.hpp"
#include "ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
namespace ck_tile {
template <typename WeightType, typename IndexType = index_t>
CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
const HostTensor<WeightType>& weights,
HostTensor<IndexType>& p_sorted_token_ids,
HostTensor<WeightType>& sorted_weight,
HostTensor<IndexType>& sorted_expert_ids,
index_t& unit_cnt,
const index_t experts,
const index_t unit_size)
{
const index_t num_token = topk_ids.mDesc.get_lengths()[0];
const index_t topk = topk_ids.mDesc.get_lengths()[1];
std::vector<std::vector<IndexType>> expert_tokens(experts,
std::vector<IndexType>(unit_size, num_token));
std::vector<std::vector<WeightType>> expert_token_weights(
experts, std::vector<WeightType>(unit_size, 0));
std::vector<IndexType> expert_slices(experts, 1);
std::vector<IndexType> expert_slice_idxs(experts, 0);
for(index_t t = 0; t < num_token; t++)
{
for(index_t k = 0; k < topk; k++)
{
IndexType e = topk_ids(t, k);
WeightType w = weights(t, k);
index_t idx = expert_slice_idxs[e];
if(idx > expert_slices[e] * unit_size - 1)
{
expert_slices[e]++;
index_t new_size = expert_slices[e] * unit_size;
expert_tokens[e].resize(new_size);
expert_token_weights[e].resize(new_size);
for(index_t i = (expert_slices[e] - 1) * unit_size; i < new_size; i++)
{
expert_tokens[e][i] = num_token;
expert_token_weights[e][i] = 0;
}
}
expert_tokens[e][idx] = t;
expert_token_weights[e][idx] = w;
expert_slice_idxs[e]++;
}
}
IndexType* out_tokens = p_sorted_token_ids.data();
WeightType* out_weights = sorted_weight.data();
IndexType* out_expert_id = sorted_expert_ids.data();
for(index_t e = 0; e < experts; e++)
{
memcpy(out_tokens, expert_tokens[e].data(), sizeof(index_t) * expert_slices[e] * unit_size);
out_tokens += expert_slices[e] * unit_size;
memcpy(out_weights,
expert_token_weights[e].data(),
sizeof(WeightType) * expert_slices[e] * unit_size);
out_weights += expert_slices[e] * unit_size;
for(index_t s = 0; s < expert_slices[e]; s++)
{
out_expert_id[s] = e;
unit_cnt++;
}
out_expert_id += expert_slices[e];
}
unit_cnt *= unit_size;
return;
}
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/elementwise.hpp"
#include "ck_tile/host/hip_check_error.hpp"
#include <string>
#include <type_traits>
namespace ck_tile {
struct MoeSortingHostArgs
{
const void* p_topk_ids;
const void* p_weights;
void* p_sorted_token_ids;
void* p_sorted_weights;
void* p_sorted_expert_ids;
void* p_total_tokens_post_pad;
void* p_moe_buf;
index_t tokens;
index_t unit_size;
index_t num_experts;
index_t topk;
index_t moe_buf_bytes;
};
template <typename Problem_>
struct MoeSortingKernel
{
using Problem = remove_cvref_t<Problem_>;
using IndexType = typename Problem::IndexType;
using WeightType = typename Problem::WeightType;
typedef MoeSortingHostArgs MoeSortingKargs;
using Hargs = MoeSortingHostArgs;
struct Kargs
{
const void* p_topk_ids;
const void* p_weights;
void* p_sorted_token_ids;
void* p_sorted_weights;
void* p_sorted_expert_ids;
void* p_total_tokens_post_pad;
void* p_moe_buf;
index_t tokens;
index_t num_experts;
index_t moe_buf_bytes;
index_t tokens_per_thread;
mdiv unit_size_mdiv;
mdiv topk_mdiv;
};
CK_TILE_HOST static constexpr auto GridSize(const Hargs& h)
{
// TODO: assume num-experts not too much
return dim3(1 + ck_tile::integer_divide_ceil(h.moe_buf_bytes, BlockSize(h).x * 16));
}
CK_TILE_HOST static constexpr auto BlockSize(const Hargs& h)
{
return dim3(ck_tile::integer_least_multiple(h.num_experts, ck_tile::get_warp_size()));
}
// in byte
CK_TILE_HOST static constexpr auto GetSmemSize(const Hargs& h)
{
const auto blocks = BlockSize(h);
return ((blocks.x + 1) * h.num_experts + (h.num_experts + 1)) * sizeof(index_t);
}
CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
{
Kargs k;
k.p_topk_ids = h.p_topk_ids;
k.p_weights = h.p_weights;
k.p_sorted_token_ids = h.p_sorted_token_ids;
k.p_sorted_weights = h.p_sorted_weights;
k.p_sorted_expert_ids = h.p_sorted_expert_ids;
k.p_moe_buf = h.p_moe_buf;
k.p_total_tokens_post_pad = h.p_total_tokens_post_pad;
k.tokens = h.tokens;
k.num_experts = h.num_experts;
k.moe_buf_bytes = h.moe_buf_bytes;
const auto blocks = BlockSize(h);
k.tokens_per_thread = integer_divide_ceil(h.tokens * h.topk, blocks.x);
k.unit_size_mdiv = mdiv{static_cast<uint32_t>(h.unit_size)};
k.topk_mdiv = mdiv{static_cast<uint32_t>(h.topk)};
return k;
}
CK_TILE_DEVICE index_t calc_index(index_t total_col, index_t row, index_t col) const
{
return row * total_col + col;
}
CK_TILE_DEVICE void moe_buf_set_zero_kernel(uint8x16_t* buf, index_t buf_bytes) const
{
const index_t offset = (blockIdx.x - 1) * blockDim.x + threadIdx.x;
if(offset < buf_bytes / 16)
{
buf[offset] = uint8x16_t{0};
}
}
CK_TILE_DEVICE void moe_align_block_size_kernel(const IndexType* __restrict__ topk_id,
const WeightType* __restrict__ weights,
index_t* p_sorted_token_ids,
WeightType* p_sorted_weights,
index_t* p_sorted_expert_ids,
index_t* p_total_tokens_post_pad,
const index_t num_experts,
const index_t tokens_per_thread,
const index_t numel,
const mdiv unit_size_mdiv,
const mdiv topk_mdiv,
void* smem) const
{
const index_t tid = static_cast<index_t>(threadIdx.x);
const index_t start_idx = tid * tokens_per_thread;
index_t* shared_mem = reinterpret_cast<index_t*>(smem);
index_t* tokens_cnts = shared_mem; // 2d: (blockDim.x + 1, num_experts)
index_t* cumsum = shared_mem + (blockDim.x + 1) * num_experts; // 1: (num_experts + 1)
for(int i = 0; i < num_experts; ++i)
{
tokens_cnts[calc_index(num_experts, tid + 1, i)] = 0;
}
#pragma unroll Problem_::InternalLoadUnroll
for(int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i)
{
++tokens_cnts[calc_index(num_experts, tid + 1, topk_id[i])];
}
__syncthreads();
if(tid < num_experts)
{
tokens_cnts[calc_index(num_experts, 0, tid)] = 0;
for(int i = 1; i <= static_cast<index_t>(blockDim.x); ++i)
{
tokens_cnts[calc_index(num_experts, i, tid)] +=
tokens_cnts[calc_index(num_experts, i - 1, tid)];
}
}
// __syncthreads();
if(tid == 0)
{
cumsum[0] = 0;
for(int i = 1; i <= num_experts; ++i)
{
auto current_units = [&]() {
index_t x_ = tokens_cnts[calc_index(num_experts, blockDim.x, i - 1)] +
unit_size_mdiv.divisor - 1;
index_t y_ = unit_size_mdiv.div(x_);
return max(y_, 1) * unit_size_mdiv.divisor;
}();
cumsum[i] = cumsum[i - 1] + current_units;
}
*p_total_tokens_post_pad = cumsum[num_experts];
}
__syncthreads();
if(tid < num_experts)
{
for(int i = cumsum[tid]; i < cumsum[tid + 1]; i += unit_size_mdiv.divisor)
{
p_sorted_expert_ids[unit_size_mdiv.div(i)] = tid;
}
}
#pragma unroll Problem_::InternalLoadUnroll
for(int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i)
{
index_t expert_id = topk_id[i];
index_t rank_post_pad =
tokens_cnts[calc_index(num_experts, tid, expert_id)] + cumsum[expert_id];
p_sorted_token_ids[rank_post_pad] = topk_mdiv.div(i);
p_sorted_weights[rank_post_pad] = weights[i];
++tokens_cnts[calc_index(num_experts, tid, expert_id)];
}
const index_t prefill_token = topk_mdiv.div(numel);
if(tid < num_experts)
{
index_t expert_offset =
cumsum[tid] + tokens_cnts[calc_index(num_experts, blockDim.x, tid)];
while(expert_offset < cumsum[tid + 1])
{
p_sorted_token_ids[expert_offset] = prefill_token;
p_sorted_weights[expert_offset] = static_cast<WeightType>(0.0);
expert_offset++;
}
}
}
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
if(blockIdx.x > 0)
{
if(kargs.p_moe_buf)
{
moe_buf_set_zero_kernel(reinterpret_cast<uint8x16_t*>(kargs.p_moe_buf),
kargs.moe_buf_bytes);
}
return;
}
const size_t numel = kargs.tokens * kargs.topk_mdiv.divisor;
extern __shared__ char smem[];
return moe_align_block_size_kernel(static_cast<const IndexType*>(kargs.p_topk_ids),
static_cast<const WeightType*>(kargs.p_weights),
static_cast<IndexType*>(kargs.p_sorted_token_ids),
static_cast<WeightType*>(kargs.p_sorted_weights),
static_cast<IndexType*>(kargs.p_sorted_expert_ids),
static_cast<IndexType*>(kargs.p_total_tokens_post_pad),
kargs.num_experts,
kargs.tokens_per_thread,
numel,
kargs.unit_size_mdiv,
kargs.topk_mdiv,
smem);
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp"
#include <string>
#include <type_traits>
#ifndef TOPK_SOFTMAX_USE_RAW_TILE_WINDOW
#define TOPK_SOFTMAX_USE_RAW_TILE_WINDOW 0
#endif
namespace ck_tile {
// template <typename Problem_, typename Policy_ = MoeSortingPolicy>
// struct MoeSortingPipeline
// {
// // TODO: this kernel only support warp per row
// using Problem = remove_cvref_t<Problem_>;
// using Policy = remove_cvref_t<Policy_>;
// using WeightType = typename Problem::WeightType;
// template <typename TopkIdWindow, typename WeightWindow>
// CK_TILE_DEVICE auto operator()(const TopkIdWindow& topk_id_window,
// const WeightWindow& weight_window,
// index_t* p_sorted_token_ids,
// WeightType* p_sorted_weights,
// index_t* p_sorted_expert_ids,
// index_t* p_total_tokens_post_pad,
// const index_t num_experts,
// const index_t unit_size,
// const size_t numel,
// const index_t topk)
// {
// }
// };
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/softmax.hpp"
#include "ck_tile/ops/topk.hpp"
namespace ck_tile {
struct MoeSortingPolicy
{
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include <string>
#include <type_traits>
namespace ck_tile {
template <typename IndexType_, typename WeightType_, index_t InternalLoadUnroll_>
struct MoeSortingProblem
{
// TODO: this kernel only support warp per row
using WeightType = remove_cvref_t<WeightType_>;
using IndexType = remove_cvref_t<IndexType_>;
static constexpr index_t WarpSize = get_warp_size();
static constexpr index_t WarpsPerBlock = 1;
static constexpr index_t InternalLoadUnroll = InternalLoadUnroll_;
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment