Commit 9964919d authored by dummycoderfe
Browse files

fix comments & typo

parent 2bf0057a
...@@ -158,34 +158,27 @@ bool test_moe_sorting(ck_tile::ArgParser args) ...@@ -158,34 +158,27 @@ bool test_moe_sorting(ck_tile::ArgParser args)
bool rtn = true; bool rtn = true;
if(validate) if(validate)
{ {
ck_tile::HostTensor<IndexType> sorted_ids_ref({max_output_ids}, {1}); ck_tile::HostTensor<IndexType> sorted_ids_ref({max_output_ids}, {1});
ck_tile::HostTensor<WeightType> sorted_weights_ref({max_output_ids}, {1}); ck_tile::HostTensor<WeightType> sorted_weights_ref({max_output_ids}, {1});
ck_tile::HostTensor<IndexType> expert_ids_ref({max_output_ids / unit_size}, {1}); ck_tile::HostTensor<IndexType> expert_ids_ref({max_output_ids / unit_size}, {1});
int32_t total_tokens_post_pad = 0; int32_t total_tokens_post_pad = 0;
ck_tile::reference_moe_sorting<WeightType, IndexType>(sorted_ids_ref.data(), ck_tile::reference_moe_sorting<WeightType, IndexType>(topk_ids_host,
sorted_weights_ref.data(), weights_host,
expert_ids_ref.data(), sorted_ids_ref,
sorted_weights_ref,
expert_ids_ref,
total_tokens_post_pad, total_tokens_post_pad,
weights_host.data(),
topk_ids_host.data(),
topk_ids_host.size() / topk,
experts, experts,
topk,
unit_size); unit_size);
float atol = 1e-6;
float rtol = 1e-6;
rtn &= ck_tile::check_err( rtn &= ck_tile::check_err(
sorted_ids_host, sorted_ids_ref, std::string("OUT Error: Incorrect ids!"), rtol, atol); sorted_ids_host, sorted_ids_ref, std::string("OUT Error: Incorrect ids!"), 1e-6, 1e-6);
rtn &= ck_tile::check_err(sorted_weights_host, rtn &= ck_tile::check_err(sorted_weights_host,
sorted_weights_ref, sorted_weights_ref,
std::string("OUT Error: Incorrect w!"), std::string("OUT Error: Incorrect w!"),
rtol, 1e-6,
atol); 1e-6);
rtn &= ck_tile::check_err( rtn &= ck_tile::check_err(
expert_ids_host, expert_ids_ref, std::string("OUT Error: Incorrect eid!"), rtol, atol); expert_ids_host, expert_ids_ref, std::string("OUT Error: Incorrect eid!"), 1e-6, 1e-6);
rtn &= total_tokens_post_pad == sorted_id_cnt_host.mData[0]; rtn &= total_tokens_post_pad == sorted_id_cnt_host.mData[0];
} }
......
...@@ -10,8 +10,8 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_kargs a, ck_tile::stream_conf ...@@ -10,8 +10,8 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_kargs a, ck_tile::stream_conf
using index_t = ck_tile::index_t; using index_t = ck_tile::index_t;
using ms_weight_type = float; using ms_weight_type = float;
using ms_problem = ck_tile::MoeSortingProblem<index_t, ms_weight_type>; using ms_problem = ck_tile::MoeSortingProblem<index_t, ms_weight_type>;
using ms_pipeline = ck_tile::MoeSortingPipeline<ms_problem>; // using ms_pipeline = ck_tile::MoeSortingPipeline<ms_problem>;
using kernel = ck_tile::MoeSortingKernel<ms_pipeline>; using kernel = ck_tile::MoeSortingKernel<ms_problem>;
auto kargs = kernel::MakeKargs(a); auto kargs = kernel::MakeKargs(a);
const dim3 grids = 1; const dim3 grids = 1;
const dim3 blocks = ck_tile::max(t.experts, ck_tile::get_warp_size()); const dim3 blocks = ck_tile::max(t.experts, ck_tile::get_warp_size());
......
...@@ -9,21 +9,21 @@ ...@@ -9,21 +9,21 @@
namespace ck_tile { namespace ck_tile {
template <typename WeightType, typename IndexType = index_t> template <typename WeightType, typename IndexType = index_t>
CK_TILE_HOST void reference_moe_sorting(IndexType* sorted_token_ids_ptr, CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
WeightType* sorted_weight_buf, const HostTensor<WeightType>& weights,
IndexType* sorted_expert_ids_ptr, HostTensor<IndexType>& sorted_token_ids,
index_t& sub_x_cnt, HostTensor<WeightType>& sorted_weight,
const WeightType* weights_ptr, HostTensor<IndexType>& sorted_expert_ids,
const IndexType* topk_ids_ptr, index_t& unit_cnt,
const index_t num_token,
const index_t experts, const index_t experts,
const index_t topk, const index_t unit_size)
const index_t sub_x)
{ {
const index_t num_token = topk_ids.mDesc.get_lengths()[0];
const index_t topk = topk_ids.mDesc.get_lengths()[1];
std::vector<std::vector<IndexType>> expert_tokens(experts, std::vector<std::vector<IndexType>> expert_tokens(experts,
std::vector<IndexType>(sub_x, num_token)); std::vector<IndexType>(unit_size, num_token));
std::vector<std::vector<WeightType>> expert_token_weights(experts, std::vector<std::vector<WeightType>> expert_token_weights(experts,
std::vector<WeightType>(sub_x, 0)); std::vector<WeightType>(unit_size, 0));
std::vector<IndexType> expert_slices(experts, 1); std::vector<IndexType> expert_slices(experts, 1);
std::vector<IndexType> expert_slice_idxs(experts, 0); std::vector<IndexType> expert_slice_idxs(experts, 0);
...@@ -31,16 +31,16 @@ CK_TILE_HOST void reference_moe_sorting(IndexType* sorted_token_ids_ptr, ...@@ -31,16 +31,16 @@ CK_TILE_HOST void reference_moe_sorting(IndexType* sorted_token_ids_ptr,
{ {
for(index_t k = 0; k < topk; k++) for(index_t k = 0; k < topk; k++)
{ {
index_t e = *(topk_ids_ptr + t * topk + k); IndexType e = topk_ids(t, k);
WeightType w = *(weights_ptr + t * topk + k); WeightType w = weights(t, k);
index_t idx = expert_slice_idxs[e]; index_t idx = expert_slice_idxs[e];
if(idx > expert_slices[e] * sub_x - 1) if(idx > expert_slices[e] * unit_size - 1)
{ {
expert_slices[e]++; expert_slices[e]++;
index_t new_size = expert_slices[e] * sub_x; index_t new_size = expert_slices[e] * unit_size;
expert_tokens[e].resize(new_size); expert_tokens[e].resize(new_size);
expert_token_weights[e].resize(new_size); expert_token_weights[e].resize(new_size);
for(index_t idx = (expert_slices[e] - 1) * sub_x; idx < new_size; idx++) for(index_t idx = (expert_slices[e] - 1) * unit_size; idx < new_size; idx++)
{ {
expert_tokens[e][idx] = num_token; expert_tokens[e][idx] = num_token;
expert_token_weights[e][idx] = 0; expert_token_weights[e][idx] = 0;
...@@ -53,23 +53,23 @@ CK_TILE_HOST void reference_moe_sorting(IndexType* sorted_token_ids_ptr, ...@@ -53,23 +53,23 @@ CK_TILE_HOST void reference_moe_sorting(IndexType* sorted_token_ids_ptr,
} }
} }
IndexType* tokens = sorted_token_ids_ptr; IndexType* out_tokens = sorted_token_ids.data();
WeightType* weights = sorted_weight_buf; WeightType* out_weights = sorted_weight.data();
IndexType* erp_ids = sorted_expert_ids_ptr; IndexType* out_expert_id = sorted_expert_ids.data();
for(index_t e = 0; e < experts; e++) for(index_t e = 0; e < experts; e++)
{ {
memcpy(tokens, expert_tokens[e].data(), sizeof(index_t) * expert_slices[e] * sub_x); memcpy(out_tokens, expert_tokens[e].data(), sizeof(index_t) * expert_slices[e] * unit_size);
tokens += expert_slices[e] * sub_x; out_tokens += expert_slices[e] * unit_size;
memcpy( memcpy(
weights, expert_token_weights[e].data(), sizeof(WeightType) * expert_slices[e] * sub_x); out_weights, expert_token_weights[e].data(), sizeof(WeightType) * expert_slices[e] * unit_size);
weights += expert_slices[e] * sub_x; out_weights += expert_slices[e] * unit_size;
for(index_t s = 0; s < expert_slices[e]; s++) for(index_t s = 0; s < expert_slices[e]; s++)
{ {
erp_ids[s] = e; out_expert_id[s] = e;
sub_x_cnt++; unit_cnt++;
} }
erp_ids += expert_slices[e]; out_expert_id += expert_slices[e];
} }
return; return;
......
...@@ -26,11 +26,11 @@ struct MoeSortingHostArgs ...@@ -26,11 +26,11 @@ struct MoeSortingHostArgs
index_t topk; index_t topk;
}; };
template <typename Pipeline_> template <typename Problem_>
struct MoeSortingKernel struct MoeSortingKernel
{ {
using Pipeline = remove_cvref_t<Pipeline_>; // using Pipeline = remove_cvref_t<Pipeline_>;
using Problem = remove_cvref_t<typename Pipeline::Problem>; using Problem = remove_cvref_t<Problem_>;
using IndexType = typename Problem::IndexType; using IndexType = typename Problem::IndexType;
using WeightType = typename Problem::WeightType; using WeightType = typename Problem::WeightType;
...@@ -47,8 +47,6 @@ struct MoeSortingKernel ...@@ -47,8 +47,6 @@ struct MoeSortingKernel
return row * total_col + col; return row * total_col + col;
} }
#define CEILDIV(x, y) (((x) + (y) - 1) / (y))
#define MAX(x, y) (((x) > (y)) ? (x) : (y))
CK_TILE_DEVICE void moe_align_block_size_kernel(const IndexType* __restrict__ topk_id, CK_TILE_DEVICE void moe_align_block_size_kernel(const IndexType* __restrict__ topk_id,
const WeightType* __restrict__ weights, const WeightType* __restrict__ weights,
index_t* sorted_token_ids, index_t* sorted_token_ids,
...@@ -60,7 +58,7 @@ struct MoeSortingKernel ...@@ -60,7 +58,7 @@ struct MoeSortingKernel
const size_t numel, const size_t numel,
const index_t topk) const const index_t topk) const
{ {
const size_t tokens_per_thread = CEILDIV(numel, blockDim.x); const size_t tokens_per_thread = integer_divide_ceil(numel, blockDim.x);
const size_t start_idx = threadIdx.x * tokens_per_thread; const size_t start_idx = threadIdx.x * tokens_per_thread;
extern __shared__ index_t shared_mem[]; extern __shared__ index_t shared_mem[];
...@@ -73,6 +71,10 @@ struct MoeSortingKernel ...@@ -73,6 +71,10 @@ struct MoeSortingKernel
tokens_cnts[calc_index(num_experts, threadIdx.x + 1, i)] = 0; tokens_cnts[calc_index(num_experts, threadIdx.x + 1, i)] = 0;
} }
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i)
{
++tokens_cnts[calc_index(num_experts, threadIdx.x + 1, topk_id[i])];
}
__syncthreads(); __syncthreads();
if(threadIdx.x < num_experts) if(threadIdx.x < num_experts)
...@@ -93,7 +95,7 @@ struct MoeSortingKernel ...@@ -93,7 +95,7 @@ struct MoeSortingKernel
{ {
cumsum[i] = cumsum[i] =
cumsum[i - 1] + cumsum[i - 1] +
MAX(CEILDIV(tokens_cnts[calc_index(num_experts, blockDim.x, i - 1)], unit_size), max(integer_divide_ceil(tokens_cnts[calc_index(num_experts, blockDim.x, i - 1)], unit_size),
1) * 1) *
unit_size; unit_size;
} }
......
...@@ -14,26 +14,26 @@ ...@@ -14,26 +14,26 @@
namespace ck_tile { namespace ck_tile {
template <typename Problem_, typename Policy_ = MoeSortingPolicy> // template <typename Problem_, typename Policy_ = MoeSortingPolicy>
struct MoeSortingPipeline // struct MoeSortingPipeline
{ // {
// TODO: this kernel only support warp per row // // TODO: this kernel only support warp per row
using Problem = remove_cvref_t<Problem_>; // using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>; // using Policy = remove_cvref_t<Policy_>;
using WeightType = typename Problem::WeightType; // using WeightType = typename Problem::WeightType;
// template <typename TopkIdWindow, typename WeightWindow> // template <typename TopkIdWindow, typename WeightWindow>
// CK_TILE_DEVICE auto operator()(const TopkIdWindow& topk_id_window, // CK_TILE_DEVICE auto operator()(const TopkIdWindow& topk_id_window,
// const WeightWindow& weight_window, // const WeightWindow& weight_window,
// index_t* sorted_token_ids, // index_t* sorted_token_ids,
// WeightType* sorted_weights, // WeightType* sorted_weights,
// index_t* expert_ids, // index_t* expert_ids,
// index_t* total_tokens_post_pad, // index_t* total_tokens_post_pad,
// const index_t num_experts, // const index_t num_experts,
// const index_t unit_size, // const index_t unit_size,
// const size_t numel, // const size_t numel,
// const index_t topk) // const index_t topk)
// { // {
// } // }
}; // };
} // namespace ck_tile } // namespace ck_tile
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment