Commit e44e7a95 authored by Po Yen, Chen's avatar Po Yen, Chen
Browse files

Run remod.py under include/ck_tile & example/ck_tile directories

parent 9964919d
......@@ -23,12 +23,12 @@
#include "ck_tile/host/reference/reference_gemm.hpp"
#include "ck_tile/host/reference/reference_im2col.hpp"
#include "ck_tile/host/reference/reference_layernorm2d_fwd.hpp"
#include "ck_tile/host/reference/reference_moe_sorting.hpp"
#include "ck_tile/host/reference/reference_permute.hpp"
#include "ck_tile/host/reference/reference_reduce.hpp"
#include "ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp"
#include "ck_tile/host/reference/reference_rowwise_quantization2d.hpp"
#include "ck_tile/host/reference/reference_softmax.hpp"
#include "ck_tile/host/reference/reference_topk.hpp"
#include "ck_tile/host/reference/reference_moe_sorting.hpp"
#include "ck_tile/host/stream_config.hpp"
#include "ck_tile/host/timer.hpp"
......@@ -22,8 +22,8 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
const index_t topk = topk_ids.mDesc.get_lengths()[1];
std::vector<std::vector<IndexType>> expert_tokens(experts,
std::vector<IndexType>(unit_size, num_token));
std::vector<std::vector<WeightType>> expert_token_weights(experts,
std::vector<WeightType>(unit_size, 0));
std::vector<std::vector<WeightType>> expert_token_weights(
experts, std::vector<WeightType>(unit_size, 0));
std::vector<IndexType> expert_slices(experts, 1);
std::vector<IndexType> expert_slice_idxs(experts, 0);
......@@ -60,8 +60,9 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
{
memcpy(out_tokens, expert_tokens[e].data(), sizeof(index_t) * expert_slices[e] * unit_size);
out_tokens += expert_slices[e] * unit_size;
memcpy(
out_weights, expert_token_weights[e].data(), sizeof(WeightType) * expert_slices[e] * unit_size);
memcpy(out_weights,
expert_token_weights[e].data(),
sizeof(WeightType) * expert_slices[e] * unit_size);
out_weights += expert_slices[e] * unit_size;
for(index_t s = 0; s < expert_slices[e]; s++)
......
......@@ -71,7 +71,7 @@ struct MoeSortingKernel
tokens_cnts[calc_index(num_experts, threadIdx.x + 1, i)] = 0;
}
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i)
for(int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i)
{
++tokens_cnts[calc_index(num_experts, threadIdx.x + 1, topk_id[i])];
}
......@@ -95,7 +95,8 @@ struct MoeSortingKernel
{
cumsum[i] =
cumsum[i - 1] +
max(integer_divide_ceil(tokens_cnts[calc_index(num_experts, blockDim.x, i - 1)], unit_size),
max(integer_divide_ceil(tokens_cnts[calc_index(num_experts, blockDim.x, i - 1)],
unit_size),
1) *
unit_size;
}
......@@ -137,12 +138,12 @@ struct MoeSortingKernel
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
const size_t numel = kargs.tokens * kargs.topk;
return moe_align_block_size_kernel(static_cast<const IndexType *>(kargs.p_topk_ids),
static_cast<const WeightType *>(kargs.p_weights),
static_cast<IndexType *>(kargs.sorted_token_ids),
static_cast<WeightType *>(kargs.sorted_weights),
static_cast<IndexType *>(kargs.expert_ids),
static_cast<IndexType *>(kargs.total_tokens_post_pad),
return moe_align_block_size_kernel(static_cast<const IndexType*>(kargs.p_topk_ids),
static_cast<const WeightType*>(kargs.p_weights),
static_cast<IndexType*>(kargs.sorted_token_ids),
static_cast<WeightType*>(kargs.sorted_weights),
static_cast<IndexType*>(kargs.expert_ids),
static_cast<IndexType*>(kargs.total_tokens_post_pad),
kargs.num_experts,
kargs.unit_size,
numel,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment