Commit 9964919d authored by dummycoderfe's avatar dummycoderfe
Browse files

fix comments & typo

parent 2bf0057a
......@@ -158,34 +158,27 @@ bool test_moe_sorting(ck_tile::ArgParser args)
bool rtn = true;
if(validate)
{
ck_tile::HostTensor<IndexType> sorted_ids_ref({max_output_ids}, {1});
ck_tile::HostTensor<WeightType> sorted_weights_ref({max_output_ids}, {1});
ck_tile::HostTensor<IndexType> expert_ids_ref({max_output_ids / unit_size}, {1});
int32_t total_tokens_post_pad = 0;
ck_tile::reference_moe_sorting<WeightType, IndexType>(sorted_ids_ref.data(),
sorted_weights_ref.data(),
expert_ids_ref.data(),
ck_tile::reference_moe_sorting<WeightType, IndexType>(topk_ids_host,
weights_host,
sorted_ids_ref,
sorted_weights_ref,
expert_ids_ref,
total_tokens_post_pad,
weights_host.data(),
topk_ids_host.data(),
topk_ids_host.size() / topk,
experts,
topk,
unit_size);
float atol = 1e-6;
float rtol = 1e-6;
rtn &= ck_tile::check_err(
sorted_ids_host, sorted_ids_ref, std::string("OUT Error: Incorrect ids!"), rtol, atol);
sorted_ids_host, sorted_ids_ref, std::string("OUT Error: Incorrect ids!"), 1e-6, 1e-6);
rtn &= ck_tile::check_err(sorted_weights_host,
sorted_weights_ref,
std::string("OUT Error: Incorrect w!"),
rtol,
atol);
1e-6,
1e-6);
rtn &= ck_tile::check_err(
expert_ids_host, expert_ids_ref, std::string("OUT Error: Incorrect eid!"), rtol, atol);
expert_ids_host, expert_ids_ref, std::string("OUT Error: Incorrect eid!"), 1e-6, 1e-6);
rtn &= total_tokens_post_pad == sorted_id_cnt_host.mData[0];
}
......
......@@ -10,8 +10,8 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_kargs a, ck_tile::stream_conf
using index_t = ck_tile::index_t;
using ms_weight_type = float;
using ms_problem = ck_tile::MoeSortingProblem<index_t, ms_weight_type>;
using ms_pipeline = ck_tile::MoeSortingPipeline<ms_problem>;
using kernel = ck_tile::MoeSortingKernel<ms_pipeline>;
// using ms_pipeline = ck_tile::MoeSortingPipeline<ms_problem>;
using kernel = ck_tile::MoeSortingKernel<ms_problem>;
auto kargs = kernel::MakeKargs(a);
const dim3 grids = 1;
const dim3 blocks = ck_tile::max(t.experts, ck_tile::get_warp_size());
......
......@@ -9,21 +9,21 @@
namespace ck_tile {
template <typename WeightType, typename IndexType = index_t>
CK_TILE_HOST void reference_moe_sorting(IndexType* sorted_token_ids_ptr,
WeightType* sorted_weight_buf,
IndexType* sorted_expert_ids_ptr,
index_t& sub_x_cnt,
const WeightType* weights_ptr,
const IndexType* topk_ids_ptr,
const index_t num_token,
CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
const HostTensor<WeightType>& weights,
HostTensor<IndexType>& sorted_token_ids,
HostTensor<WeightType>& sorted_weight,
HostTensor<IndexType>& sorted_expert_ids,
index_t& unit_cnt,
const index_t experts,
const index_t topk,
const index_t sub_x)
const index_t unit_size)
{
const index_t num_token = topk_ids.mDesc.get_lengths()[0];
const index_t topk = topk_ids.mDesc.get_lengths()[1];
std::vector<std::vector<IndexType>> expert_tokens(experts,
std::vector<IndexType>(sub_x, num_token));
std::vector<IndexType>(unit_size, num_token));
std::vector<std::vector<WeightType>> expert_token_weights(experts,
std::vector<WeightType>(sub_x, 0));
std::vector<WeightType>(unit_size, 0));
std::vector<IndexType> expert_slices(experts, 1);
std::vector<IndexType> expert_slice_idxs(experts, 0);
......@@ -31,16 +31,16 @@ CK_TILE_HOST void reference_moe_sorting(IndexType* sorted_token_ids_ptr,
{
for(index_t k = 0; k < topk; k++)
{
index_t e = *(topk_ids_ptr + t * topk + k);
WeightType w = *(weights_ptr + t * topk + k);
IndexType e = topk_ids(t, k);
WeightType w = weights(t, k);
index_t idx = expert_slice_idxs[e];
if(idx > expert_slices[e] * sub_x - 1)
if(idx > expert_slices[e] * unit_size - 1)
{
expert_slices[e]++;
index_t new_size = expert_slices[e] * sub_x;
index_t new_size = expert_slices[e] * unit_size;
expert_tokens[e].resize(new_size);
expert_token_weights[e].resize(new_size);
for(index_t idx = (expert_slices[e] - 1) * sub_x; idx < new_size; idx++)
for(index_t idx = (expert_slices[e] - 1) * unit_size; idx < new_size; idx++)
{
expert_tokens[e][idx] = num_token;
expert_token_weights[e][idx] = 0;
......@@ -53,23 +53,23 @@ CK_TILE_HOST void reference_moe_sorting(IndexType* sorted_token_ids_ptr,
}
}
IndexType* tokens = sorted_token_ids_ptr;
WeightType* weights = sorted_weight_buf;
IndexType* erp_ids = sorted_expert_ids_ptr;
IndexType* out_tokens = sorted_token_ids.data();
WeightType* out_weights = sorted_weight.data();
IndexType* out_expert_id = sorted_expert_ids.data();
for(index_t e = 0; e < experts; e++)
{
memcpy(tokens, expert_tokens[e].data(), sizeof(index_t) * expert_slices[e] * sub_x);
tokens += expert_slices[e] * sub_x;
memcpy(out_tokens, expert_tokens[e].data(), sizeof(index_t) * expert_slices[e] * unit_size);
out_tokens += expert_slices[e] * unit_size;
memcpy(
weights, expert_token_weights[e].data(), sizeof(WeightType) * expert_slices[e] * sub_x);
weights += expert_slices[e] * sub_x;
out_weights, expert_token_weights[e].data(), sizeof(WeightType) * expert_slices[e] * unit_size);
out_weights += expert_slices[e] * unit_size;
for(index_t s = 0; s < expert_slices[e]; s++)
{
erp_ids[s] = e;
sub_x_cnt++;
out_expert_id[s] = e;
unit_cnt++;
}
erp_ids += expert_slices[e];
out_expert_id += expert_slices[e];
}
return;
......
......@@ -26,11 +26,11 @@ struct MoeSortingHostArgs
index_t topk;
};
template <typename Pipeline_>
template <typename Problem_>
struct MoeSortingKernel
{
using Pipeline = remove_cvref_t<Pipeline_>;
using Problem = remove_cvref_t<typename Pipeline::Problem>;
// using Pipeline = remove_cvref_t<Pipeline_>;
using Problem = remove_cvref_t<Problem_>;
using IndexType = typename Problem::IndexType;
using WeightType = typename Problem::WeightType;
......@@ -47,8 +47,6 @@ struct MoeSortingKernel
return row * total_col + col;
}
#define CEILDIV(x, y) (((x) + (y) - 1) / (y))
#define MAX(x, y) (((x) > (y)) ? (x) : (y))
CK_TILE_DEVICE void moe_align_block_size_kernel(const IndexType* __restrict__ topk_id,
const WeightType* __restrict__ weights,
index_t* sorted_token_ids,
......@@ -60,7 +58,7 @@ struct MoeSortingKernel
const size_t numel,
const index_t topk) const
{
const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
const size_t tokens_per_thread = integer_divide_ceil(numel, blockDim.x);
const size_t start_idx = threadIdx.x * tokens_per_thread;
extern __shared__ index_t shared_mem[];
......@@ -73,6 +71,10 @@ struct MoeSortingKernel
tokens_cnts[calc_index(num_experts, threadIdx.x + 1, i)] = 0;
}
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i)
{
++tokens_cnts[calc_index(num_experts, threadIdx.x + 1, topk_id[i])];
}
__syncthreads();
if(threadIdx.x < num_experts)
......@@ -93,7 +95,7 @@ struct MoeSortingKernel
{
cumsum[i] =
cumsum[i - 1] +
MAX(CEILDIV(tokens_cnts[calc_index(num_experts, blockDim.x, i - 1)], unit_size),
max(integer_divide_ceil(tokens_cnts[calc_index(num_experts, blockDim.x, i - 1)], unit_size),
1) *
unit_size;
}
......
......@@ -14,26 +14,26 @@
namespace ck_tile {
template <typename Problem_, typename Policy_ = MoeSortingPolicy>
struct MoeSortingPipeline
{
// TODO: this kernel only support warp per row
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using WeightType = typename Problem::WeightType;
// template <typename Problem_, typename Policy_ = MoeSortingPolicy>
// struct MoeSortingPipeline
// {
// // TODO: this kernel only support warp per row
// using Problem = remove_cvref_t<Problem_>;
// using Policy = remove_cvref_t<Policy_>;
// using WeightType = typename Problem::WeightType;
// template <typename TopkIdWindow, typename WeightWindow>
// CK_TILE_DEVICE auto operator()(const TopkIdWindow& topk_id_window,
// const WeightWindow& weight_window,
// index_t* sorted_token_ids,
// WeightType* sorted_weights,
// index_t* expert_ids,
// index_t* total_tokens_post_pad,
// const index_t num_experts,
// const index_t unit_size,
// const size_t numel,
// const index_t topk)
// {
// }
};
// template <typename TopkIdWindow, typename WeightWindow>
// CK_TILE_DEVICE auto operator()(const TopkIdWindow& topk_id_window,
// const WeightWindow& weight_window,
// index_t* sorted_token_ids,
// WeightType* sorted_weights,
// index_t* expert_ids,
// index_t* total_tokens_post_pad,
// const index_t num_experts,
// const index_t unit_size,
// const size_t numel,
// const index_t topk)
// {
// }
// };
} // namespace ck_tile
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment