Commit e44e7a95 authored by Po Yen, Chen's avatar Po Yen, Chen
Browse files

Run remod.py under include/ck_tile & example/ck_tile directories

parent 9964919d
...@@ -23,12 +23,12 @@ ...@@ -23,12 +23,12 @@
#include "ck_tile/host/reference/reference_gemm.hpp" #include "ck_tile/host/reference/reference_gemm.hpp"
#include "ck_tile/host/reference/reference_im2col.hpp" #include "ck_tile/host/reference/reference_im2col.hpp"
#include "ck_tile/host/reference/reference_layernorm2d_fwd.hpp" #include "ck_tile/host/reference/reference_layernorm2d_fwd.hpp"
#include "ck_tile/host/reference/reference_moe_sorting.hpp"
#include "ck_tile/host/reference/reference_permute.hpp" #include "ck_tile/host/reference/reference_permute.hpp"
#include "ck_tile/host/reference/reference_reduce.hpp" #include "ck_tile/host/reference/reference_reduce.hpp"
#include "ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp" #include "ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp"
#include "ck_tile/host/reference/reference_rowwise_quantization2d.hpp" #include "ck_tile/host/reference/reference_rowwise_quantization2d.hpp"
#include "ck_tile/host/reference/reference_softmax.hpp" #include "ck_tile/host/reference/reference_softmax.hpp"
#include "ck_tile/host/reference/reference_topk.hpp" #include "ck_tile/host/reference/reference_topk.hpp"
#include "ck_tile/host/reference/reference_moe_sorting.hpp"
#include "ck_tile/host/stream_config.hpp" #include "ck_tile/host/stream_config.hpp"
#include "ck_tile/host/timer.hpp" #include "ck_tile/host/timer.hpp"
...@@ -22,8 +22,8 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids, ...@@ -22,8 +22,8 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
const index_t topk = topk_ids.mDesc.get_lengths()[1]; const index_t topk = topk_ids.mDesc.get_lengths()[1];
std::vector<std::vector<IndexType>> expert_tokens(experts, std::vector<std::vector<IndexType>> expert_tokens(experts,
std::vector<IndexType>(unit_size, num_token)); std::vector<IndexType>(unit_size, num_token));
std::vector<std::vector<WeightType>> expert_token_weights(experts, std::vector<std::vector<WeightType>> expert_token_weights(
std::vector<WeightType>(unit_size, 0)); experts, std::vector<WeightType>(unit_size, 0));
std::vector<IndexType> expert_slices(experts, 1); std::vector<IndexType> expert_slices(experts, 1);
std::vector<IndexType> expert_slice_idxs(experts, 0); std::vector<IndexType> expert_slice_idxs(experts, 0);
...@@ -60,8 +60,9 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids, ...@@ -60,8 +60,9 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
{ {
memcpy(out_tokens, expert_tokens[e].data(), sizeof(index_t) * expert_slices[e] * unit_size); memcpy(out_tokens, expert_tokens[e].data(), sizeof(index_t) * expert_slices[e] * unit_size);
out_tokens += expert_slices[e] * unit_size; out_tokens += expert_slices[e] * unit_size;
memcpy( memcpy(out_weights,
out_weights, expert_token_weights[e].data(), sizeof(WeightType) * expert_slices[e] * unit_size); expert_token_weights[e].data(),
sizeof(WeightType) * expert_slices[e] * unit_size);
out_weights += expert_slices[e] * unit_size; out_weights += expert_slices[e] * unit_size;
for(index_t s = 0; s < expert_slices[e]; s++) for(index_t s = 0; s < expert_slices[e]; s++)
......
...@@ -71,7 +71,7 @@ struct MoeSortingKernel ...@@ -71,7 +71,7 @@ struct MoeSortingKernel
tokens_cnts[calc_index(num_experts, threadIdx.x + 1, i)] = 0; tokens_cnts[calc_index(num_experts, threadIdx.x + 1, i)] = 0;
} }
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) for(int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i)
{ {
++tokens_cnts[calc_index(num_experts, threadIdx.x + 1, topk_id[i])]; ++tokens_cnts[calc_index(num_experts, threadIdx.x + 1, topk_id[i])];
} }
...@@ -95,7 +95,8 @@ struct MoeSortingKernel ...@@ -95,7 +95,8 @@ struct MoeSortingKernel
{ {
cumsum[i] = cumsum[i] =
cumsum[i - 1] + cumsum[i - 1] +
max(integer_divide_ceil(tokens_cnts[calc_index(num_experts, blockDim.x, i - 1)], unit_size), max(integer_divide_ceil(tokens_cnts[calc_index(num_experts, blockDim.x, i - 1)],
unit_size),
1) * 1) *
unit_size; unit_size;
} }
...@@ -137,12 +138,12 @@ struct MoeSortingKernel ...@@ -137,12 +138,12 @@ struct MoeSortingKernel
CK_TILE_DEVICE void operator()(Kargs kargs) const CK_TILE_DEVICE void operator()(Kargs kargs) const
{ {
const size_t numel = kargs.tokens * kargs.topk; const size_t numel = kargs.tokens * kargs.topk;
return moe_align_block_size_kernel(static_cast<const IndexType *>(kargs.p_topk_ids), return moe_align_block_size_kernel(static_cast<const IndexType*>(kargs.p_topk_ids),
static_cast<const WeightType *>(kargs.p_weights), static_cast<const WeightType*>(kargs.p_weights),
static_cast<IndexType *>(kargs.sorted_token_ids), static_cast<IndexType*>(kargs.sorted_token_ids),
static_cast<WeightType *>(kargs.sorted_weights), static_cast<WeightType*>(kargs.sorted_weights),
static_cast<IndexType *>(kargs.expert_ids), static_cast<IndexType*>(kargs.expert_ids),
static_cast<IndexType *>(kargs.total_tokens_post_pad), static_cast<IndexType*>(kargs.total_tokens_post_pad),
kargs.num_experts, kargs.num_experts,
kargs.unit_size, kargs.unit_size,
numel, numel,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment