Commit e44e7a95 authored by Po Yen, Chen's avatar Po Yen, Chen
Browse files

Run remod.py under include/ck_tile & example/ck_tile directories

parent 9964919d
...@@ -7,9 +7,9 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_kargs a, ck_tile::stream_conf ...@@ -7,9 +7,9 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_kargs a, ck_tile::stream_conf
{ {
if(t.weight_type == "fp32") if(t.weight_type == "fp32")
{ {
using index_t = ck_tile::index_t; using index_t = ck_tile::index_t;
using ms_weight_type = float; using ms_weight_type = float;
using ms_problem = ck_tile::MoeSortingProblem<index_t, ms_weight_type>; using ms_problem = ck_tile::MoeSortingProblem<index_t, ms_weight_type>;
// using ms_pipeline = ck_tile::MoeSortingPipeline<ms_problem>; // using ms_pipeline = ck_tile::MoeSortingPipeline<ms_problem>;
using kernel = ck_tile::MoeSortingKernel<ms_problem>; using kernel = ck_tile::MoeSortingKernel<ms_problem>;
auto kargs = kernel::MakeKargs(a); auto kargs = kernel::MakeKargs(a);
......
...@@ -23,12 +23,12 @@ ...@@ -23,12 +23,12 @@
#include "ck_tile/host/reference/reference_gemm.hpp" #include "ck_tile/host/reference/reference_gemm.hpp"
#include "ck_tile/host/reference/reference_im2col.hpp" #include "ck_tile/host/reference/reference_im2col.hpp"
#include "ck_tile/host/reference/reference_layernorm2d_fwd.hpp" #include "ck_tile/host/reference/reference_layernorm2d_fwd.hpp"
#include "ck_tile/host/reference/reference_moe_sorting.hpp"
#include "ck_tile/host/reference/reference_permute.hpp" #include "ck_tile/host/reference/reference_permute.hpp"
#include "ck_tile/host/reference/reference_reduce.hpp" #include "ck_tile/host/reference/reference_reduce.hpp"
#include "ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp" #include "ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp"
#include "ck_tile/host/reference/reference_rowwise_quantization2d.hpp" #include "ck_tile/host/reference/reference_rowwise_quantization2d.hpp"
#include "ck_tile/host/reference/reference_softmax.hpp" #include "ck_tile/host/reference/reference_softmax.hpp"
#include "ck_tile/host/reference/reference_topk.hpp" #include "ck_tile/host/reference/reference_topk.hpp"
#include "ck_tile/host/reference/reference_moe_sorting.hpp"
#include "ck_tile/host/stream_config.hpp" #include "ck_tile/host/stream_config.hpp"
#include "ck_tile/host/timer.hpp" #include "ck_tile/host/timer.hpp"
...@@ -19,11 +19,11 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids, ...@@ -19,11 +19,11 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
const index_t unit_size) const index_t unit_size)
{ {
const index_t num_token = topk_ids.mDesc.get_lengths()[0]; const index_t num_token = topk_ids.mDesc.get_lengths()[0];
const index_t topk = topk_ids.mDesc.get_lengths()[1]; const index_t topk = topk_ids.mDesc.get_lengths()[1];
std::vector<std::vector<IndexType>> expert_tokens(experts, std::vector<std::vector<IndexType>> expert_tokens(experts,
std::vector<IndexType>(unit_size, num_token)); std::vector<IndexType>(unit_size, num_token));
std::vector<std::vector<WeightType>> expert_token_weights(experts, std::vector<std::vector<WeightType>> expert_token_weights(
std::vector<WeightType>(unit_size, 0)); experts, std::vector<WeightType>(unit_size, 0));
std::vector<IndexType> expert_slices(experts, 1); std::vector<IndexType> expert_slices(experts, 1);
std::vector<IndexType> expert_slice_idxs(experts, 0); std::vector<IndexType> expert_slice_idxs(experts, 0);
...@@ -31,7 +31,7 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids, ...@@ -31,7 +31,7 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
{ {
for(index_t k = 0; k < topk; k++) for(index_t k = 0; k < topk; k++)
{ {
IndexType e = topk_ids(t, k); IndexType e = topk_ids(t, k);
WeightType w = weights(t, k); WeightType w = weights(t, k);
index_t idx = expert_slice_idxs[e]; index_t idx = expert_slice_idxs[e];
if(idx > expert_slices[e] * unit_size - 1) if(idx > expert_slices[e] * unit_size - 1)
...@@ -53,15 +53,16 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids, ...@@ -53,15 +53,16 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
} }
} }
IndexType* out_tokens = sorted_token_ids.data(); IndexType* out_tokens = sorted_token_ids.data();
WeightType* out_weights = sorted_weight.data(); WeightType* out_weights = sorted_weight.data();
IndexType* out_expert_id = sorted_expert_ids.data(); IndexType* out_expert_id = sorted_expert_ids.data();
for(index_t e = 0; e < experts; e++) for(index_t e = 0; e < experts; e++)
{ {
memcpy(out_tokens, expert_tokens[e].data(), sizeof(index_t) * expert_slices[e] * unit_size); memcpy(out_tokens, expert_tokens[e].data(), sizeof(index_t) * expert_slices[e] * unit_size);
out_tokens += expert_slices[e] * unit_size; out_tokens += expert_slices[e] * unit_size;
memcpy( memcpy(out_weights,
out_weights, expert_token_weights[e].data(), sizeof(WeightType) * expert_slices[e] * unit_size); expert_token_weights[e].data(),
sizeof(WeightType) * expert_slices[e] * unit_size);
out_weights += expert_slices[e] * unit_size; out_weights += expert_slices[e] * unit_size;
for(index_t s = 0; s < expert_slices[e]; s++) for(index_t s = 0; s < expert_slices[e]; s++)
......
...@@ -30,7 +30,7 @@ template <typename Problem_> ...@@ -30,7 +30,7 @@ template <typename Problem_>
struct MoeSortingKernel struct MoeSortingKernel
{ {
// using Pipeline = remove_cvref_t<Pipeline_>; // using Pipeline = remove_cvref_t<Pipeline_>;
using Problem = remove_cvref_t<Problem_>; using Problem = remove_cvref_t<Problem_>;
using IndexType = typename Problem::IndexType; using IndexType = typename Problem::IndexType;
using WeightType = typename Problem::WeightType; using WeightType = typename Problem::WeightType;
...@@ -71,7 +71,7 @@ struct MoeSortingKernel ...@@ -71,7 +71,7 @@ struct MoeSortingKernel
tokens_cnts[calc_index(num_experts, threadIdx.x + 1, i)] = 0; tokens_cnts[calc_index(num_experts, threadIdx.x + 1, i)] = 0;
} }
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) for(int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i)
{ {
++tokens_cnts[calc_index(num_experts, threadIdx.x + 1, topk_id[i])]; ++tokens_cnts[calc_index(num_experts, threadIdx.x + 1, topk_id[i])];
} }
...@@ -95,7 +95,8 @@ struct MoeSortingKernel ...@@ -95,7 +95,8 @@ struct MoeSortingKernel
{ {
cumsum[i] = cumsum[i] =
cumsum[i - 1] + cumsum[i - 1] +
max(integer_divide_ceil(tokens_cnts[calc_index(num_experts, blockDim.x, i - 1)], unit_size), max(integer_divide_ceil(tokens_cnts[calc_index(num_experts, blockDim.x, i - 1)],
unit_size),
1) * 1) *
unit_size; unit_size;
} }
...@@ -137,12 +138,12 @@ struct MoeSortingKernel ...@@ -137,12 +138,12 @@ struct MoeSortingKernel
CK_TILE_DEVICE void operator()(Kargs kargs) const CK_TILE_DEVICE void operator()(Kargs kargs) const
{ {
const size_t numel = kargs.tokens * kargs.topk; const size_t numel = kargs.tokens * kargs.topk;
return moe_align_block_size_kernel(static_cast<const IndexType *>(kargs.p_topk_ids), return moe_align_block_size_kernel(static_cast<const IndexType*>(kargs.p_topk_ids),
static_cast<const WeightType *>(kargs.p_weights), static_cast<const WeightType*>(kargs.p_weights),
static_cast<IndexType *>(kargs.sorted_token_ids), static_cast<IndexType*>(kargs.sorted_token_ids),
static_cast<WeightType *>(kargs.sorted_weights), static_cast<WeightType*>(kargs.sorted_weights),
static_cast<IndexType *>(kargs.expert_ids), static_cast<IndexType*>(kargs.expert_ids),
static_cast<IndexType *>(kargs.total_tokens_post_pad), static_cast<IndexType*>(kargs.total_tokens_post_pad),
kargs.num_experts, kargs.num_experts,
kargs.unit_size, kargs.unit_size,
numel, numel,
......
...@@ -16,8 +16,8 @@ namespace ck_tile { ...@@ -16,8 +16,8 @@ namespace ck_tile {
// synchronize reduce result (cross lane reduction and broadcast on replicated dimension) // synchronize reduce result (cross lane reduction and broadcast on replicated dimension)
template <typename AccDistributedTensor_, typename ReduceFunc, bool WithBroadcast = true> template <typename AccDistributedTensor_, typename ReduceFunc, bool WithBroadcast = true>
CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor, CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor,
const ReduceFunc& reduce_func, const ReduceFunc& reduce_func,
bool_constant<WithBroadcast> = {}) bool_constant<WithBroadcast> = {})
{ {
using Dstr = typename AccDistributedTensor_::StaticTileDistribution; using Dstr = typename AccDistributedTensor_::StaticTileDistribution;
using DstrEncode = typename Dstr::DstrEncode; using DstrEncode = typename Dstr::DstrEncode;
...@@ -116,7 +116,7 @@ CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor, ...@@ -116,7 +116,7 @@ CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor,
*/ */
template <typename AccDistributedTensor_, typename ReduceFunc> template <typename AccDistributedTensor_, typename ReduceFunc>
CK_TILE_DEVICE void block_tile_reduce_xor_sync(AccDistributedTensor_& acc_tensor, CK_TILE_DEVICE void block_tile_reduce_xor_sync(AccDistributedTensor_& acc_tensor,
const ReduceFunc& reduce_func) const ReduceFunc& reduce_func)
{ {
using Dstr = typename AccDistributedTensor_::StaticTileDistribution; using Dstr = typename AccDistributedTensor_::StaticTileDistribution;
using DstrEncode = typename Dstr::DstrEncode; using DstrEncode = typename Dstr::DstrEncode;
...@@ -175,9 +175,9 @@ template <typename AccDistributedTensor_, ...@@ -175,9 +175,9 @@ template <typename AccDistributedTensor_,
index_t... InReduceDims, index_t... InReduceDims,
typename ReduceFunc> typename ReduceFunc>
CK_TILE_DEVICE void block_tile_reduce(AccDistributedTensor_& acc_tensor, CK_TILE_DEVICE void block_tile_reduce(AccDistributedTensor_& acc_tensor,
const InDistributedTensor_& in_tensor, const InDistributedTensor_& in_tensor,
sequence<InReduceDims...>, sequence<InReduceDims...>,
const ReduceFunc& reduce_func) const ReduceFunc& reduce_func)
{ {
constexpr auto I0 = number<0>{}; constexpr auto I0 = number<0>{};
constexpr auto I1 = number<1>{}; constexpr auto I1 = number<1>{};
...@@ -250,9 +250,9 @@ template <typename AccDataType_, ...@@ -250,9 +250,9 @@ template <typename AccDataType_,
typename ReduceFunc, typename ReduceFunc,
typename InDataType_> typename InDataType_>
CK_TILE_DEVICE auto block_tile_reduce(const InDistributedTensor_& in_tensor, CK_TILE_DEVICE auto block_tile_reduce(const InDistributedTensor_& in_tensor,
sequence<InReduceDims...> in_reduce_dims, sequence<InReduceDims...> in_reduce_dims,
const ReduceFunc& reduce_func, const ReduceFunc& reduce_func,
const InDataType_& reduce_init) const InDataType_& reduce_init)
{ {
using InDataType = typename InDistributedTensor_::DataType; using InDataType = typename InDistributedTensor_::DataType;
using AccDataType = remove_cvref_t<AccDataType_>; using AccDataType = remove_cvref_t<AccDataType_>;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment