"...git@developer.sourcefind.cn:modelzoo/solov2-pytorch.git" did not exist on "ddfb38efa733f52ced8d02b03c9fd913e5d7e044"
Commit 2baf9422 authored by letaoqin's avatar letaoqin
Browse files

add moe general

parent c918cd4f
......@@ -19,13 +19,13 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile:
if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1)
{
using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>;
using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 128, 32, 32>, S<1, 4, 1>, S<32, 32, 8>, 1, 0>;
r = fused_moegemm_<t_>(s, a);
}
else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" &&
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1)
{
using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>;
using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 128, 32, 32>, S<1, 4, 1>, S<32, 32, 8>, 1, 0>;
r = fused_moegemm_<t_>(s, a);
}
// clang-format on
......
......@@ -19,8 +19,8 @@ float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a)
typename Ts_::WarpPerBlock_0,
typename Ts_::WarpTile_0,
typename Ts_::BlockTile_1,
typename Ts_::WarpPerBlock_0,
typename Ts_::WarpTile_0>;
typename Ts_::WarpPerBlock_1,
typename Ts_::WarpTile_1>;
using f_problem =
ck_tile::FusedMoeGemmPipelineProblem<typename Ts_::ADataType,
typename Ts_::GDataType,
......@@ -38,9 +38,9 @@ float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a)
f_traits>;
// using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmEx<f_problem>;
using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmUk<f_problem>;
using f_pipeline = ck_tile::FusedMoeGemmPipeline_General<f_problem>;
using f_partitioner = ck_tile::FusedMoeGemmTilePartitioner_Linear<f_shape>;
using f_kernel = ck_tile::FusedMoeGemmKernel<f_partitioner, f_pipeline, void>;
using f_kernel = ck_tile::FusedMoeGemmGlKernel<f_partitioner, f_pipeline, void>;
const dim3 grids = f_kernel::GridSize(a);
constexpr dim3 blocks = f_kernel::BlockSize();
......
......@@ -44,8 +44,8 @@ struct fmoe_ // traits, ugly name, only used for internal
using WarpPerBlock_0 = ck_tile::remove_cvref_t<WarpPerBlock_>;
using WarpTile_0 = ck_tile::remove_cvref_t<WarpTile_>;
using BlockTile_1 = ck_tile::sequence<BT_, BD_, BI_ / (GateOnly_ ? 1 : 2)>;
using WarpPerBlock_1 = ck_tile::remove_cvref_t<WarpPerBlock_>;
using BlockTile_1 = ck_tile::sequence<BT_, BD_, BI_>;
using WarpPerBlock_1 = ck_tile::sequence<1, 1, 4>;//ck_tile::remove_cvref_t<WarpPerBlock_>;
using WarpTile_1 = ck_tile::remove_cvref_t<WarpTile_>;
static constexpr ck_tile::index_t GateOnly = GateOnly_;
......
......@@ -8,7 +8,7 @@
// clang-format off
template float fused_moegemm_<
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 128, 32, 32>, S<1, 4, 1>, S<32, 32, 8>, 1, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a);
// clang-format on
......@@ -8,7 +8,7 @@
// clang-format off
template float fused_moegemm_<
fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>
fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 128, 32, 32>, S<1, 4, 1>, S<32, 32, 8>, 1, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a);
// clang-format on
......@@ -261,8 +261,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
// permute weight
ck_tile::HostTensor<GDataType> g_perm_host = shuffle_moe_weight(g_host, prec_w, 1);
ck_tile::HostTensor<DDataType> d_perm_host = shuffle_moe_weight(d_host, prec_w, 1);
// ck_tile::HostTensor<GDataType> g_perm_host = shuffle_moe_weight(g_host, prec_w, 1);
// ck_tile::HostTensor<DDataType> d_perm_host = shuffle_moe_weight(d_host, prec_w, 1);
// do moe sorting
if(balance)
......@@ -345,8 +345,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
// done, preparing GPU buffer
ck_tile::DeviceMem a_buf(a_host);
ck_tile::DeviceMem g_perm_buf(g_perm_host);
ck_tile::DeviceMem d_perm_buf(d_perm_host);
ck_tile::DeviceMem g_perm_buf(g_host);
ck_tile::DeviceMem d_perm_buf(d_host);
ck_tile::DeviceMem sa_buf(sa_host);
ck_tile::DeviceMem sg_buf(sg_host);
ck_tile::DeviceMem sd_buf(sd_host);
......@@ -390,7 +390,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
tokens,
experts,
topk,
stride};
stride,
max_num_tokens_padded};
float ave_time = fused_moegemm(
traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
......
......@@ -57,4 +57,76 @@ struct indexing_adaptor_onshot_cached
return ck_tile::is_known_at_compile_time<IndexingType>::value;
}
};
#define Using_Gather 1
template <typename IndexingType>
struct indexing_adaptor
{
CK_TILE_HOST_DEVICE constexpr indexing_adaptor() = default;
CK_TILE_HOST_DEVICE constexpr indexing_adaptor(const IndexingType* idx) : cached_idx_(idx) {}
const IndexingType* cached_idx_;
#if Using_Gather
mutable index_t pre_up_index_ = 0;
mutable index_t pre_low_index_ = 0;
#endif
template <typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
const UpIdx& idx_up) const
{
static_assert(LowIdx::size() == 1 && UpIdx::size() == 1,
"wrong! inconsistent # of dimension");
idx_low(number<0>{}) = *(cached_idx_ + idx_up[number<0>{}]);
#if Using_Gather
pre_up_index_ = idx_up[number<0>{}];
pre_low_index_ = idx_low(number<0>{});
#if 0
if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
{
printf("\n first index from %d to %d \n", idx_up[number<0>{}], idx_low(number<0>{}));
}
#endif
#endif
}
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
const UpIdxDiff& idx_diff_up,
LowIdx& /*idx_low*/,
const UpIdx& /*idx_up*/) const
{
// TODO: nonthing changed here
static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 &&
UpIdx::size() == 1,
"wrong! inconsistent # of dimension");
#if !Using_Gather
idx_diff_low(number<0>{}) = idx_diff_up[number<0>{}];
#else
int up_index = idx_diff_up[number<0>{}] + pre_up_index_;
int low_index = *(cached_idx_ + up_index);
idx_diff_low(number<0>{}) = low_index - pre_low_index_;
pre_up_index_ = up_index;
pre_low_index_ = low_index;
#if 0
if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
{
printf("\n index form %d to %d, diff from %d to %d \n",
up_index,
low_index,
idx_diff_up[number<0>{}],
idx_diff_low(number<0>{}));
}
#endif
#endif
// pass the diff to lower, but not changing the actually index
}
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
{
return ck_tile::is_known_at_compile_time<IndexingType>::value;
}
};
} // namespace ck_tile
......@@ -213,9 +213,9 @@ struct FusedMoeGemmGlKernel
CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs)
{
constexpr index_t block_m = BlockShape::Block_M0;
int max_num_tokens_padded =
hargs.topk * hargs.num_tokens + hargs.num_experts * block_m - hargs.topk;
//constexpr index_t block_m = BlockShape::Block_M0;
int max_num_tokens_padded = hargs.max_num_tokens_padded;
//hargs.topk * hargs.num_tokens + hargs.num_experts * block_m - hargs.topk;
// printf("xxx max_num_tokens_padded:%d\n", max_num_tokens_padded);
return Partitioner::GridSize(max_num_tokens_padded, hargs.intermediate_size);
}
......
......@@ -117,6 +117,7 @@ struct FusedMoeGemmHostArgs
index_t topk; // need this?
index_t stride_token; // for input/output, stride for each row, should >= hidden_size
index_t max_num_tokens_padded; // size of sorted_token_ids_ptr
};
// This is scatter/gather b2b group-gemm
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment