"git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "61f50c45849ee7315da5f7e2df193d4c9fe9713c"
Commit 5002a39c authored by Mateusz Ozga

Merge remote-tracking branch 'origin/develop' into mozga-amd/universal_gemm_weight

parents 2c546b0c 3d15f364
@@ -3,18 +3,42 @@
 #include "moe_sorting_api.hpp"
 
-#define MOE_SORTING_DISPATCH(unroll_num_)                                                  \
+#define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_)                              \
     constexpr ck_tile::index_t unroll_num = unroll_num_;                                   \
-    using ms_problem = ck_tile::MoeSortingProblem<index_t, ms_weight_type, unroll_num>;    \
+    constexpr ck_tile::index_t expert_tile = expert_tile_;                                  \
+    using ms_problem =                                                                      \
+        ck_tile::MoeSortingProblem<index_t, ms_weight_type, unroll_num, expert_tile>;       \
     using kernel = ck_tile::MoeSortingKernel<ms_problem>;                                   \
     auto kargs = kernel::MakeKargs(a);                                                      \
     const dim3 grids = kernel::GridSize(a);                                                 \
     const dim3 blocks = kernel::BlockSize(a);                                               \
     const auto lds_bytes = kernel::GetSmemSize(a);                                          \
     float ave_time = ck_tile::launch_kernel(                                                \
         s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs));                \
     return ave_time;
+
+#define MOE_SORTING_DISPATCH(unroll_num_)               \
+    if(a.num_experts <= 8)                              \
+    {                                                   \
+        MOE_SORTING_DISPATCH_ETILE(unroll_num_, 8)      \
+    }                                                   \
+    else if(a.num_experts <= 16)                        \
+    {                                                   \
+        MOE_SORTING_DISPATCH_ETILE(unroll_num_, 16)     \
+    }                                                   \
+    else if(a.num_experts <= 32)                        \
+    {                                                   \
+        MOE_SORTING_DISPATCH_ETILE(unroll_num_, 32)     \
+    }                                                   \
+    else if(a.num_experts <= 64)                        \
+    {                                                   \
+        MOE_SORTING_DISPATCH_ETILE(unroll_num_, 64)     \
+    }                                                   \
+    else                                                \
+    {                                                   \
+        MOE_SORTING_DISPATCH_ETILE(unroll_num_, 0)      \
+    }
 
 float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s)
 {
     if(t.weight_type == "fp32" && t.index_type == "int32")
@@ -49,21 +73,12 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
     case(6): {
         MOE_SORTING_DISPATCH(6);
     }
-    case(7): {
-        MOE_SORTING_DISPATCH(7);
-    }
     case(8): {
         MOE_SORTING_DISPATCH(8);
     }
-    case(9): {
-        MOE_SORTING_DISPATCH(9);
-    }
     case(10): {
         MOE_SORTING_DISPATCH(10);
     }
-    case(11): {
-        MOE_SORTING_DISPATCH(11);
-    }
     default: {
         MOE_SORTING_DISPATCH(4);
     }
...
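As a side note on the dispatch above: MOE_SORTING_DISPATCH now forwards to MOE_SORTING_DISPATCH_ETILE with the smallest expert-tile bucket (8/16/32/64, or 0 as the generic fallback) that covers a.num_experts. A minimal host-only sketch of that selection policy; the helper name pick_expert_tile is hypothetical and not part of this change:

#include <cstdio>

// Hypothetical illustration of the bucket choice encoded by the macro chain:
// num_experts <= 8/16/32/64 maps to that tile, anything larger falls back to 0.
static int pick_expert_tile(int num_experts)
{
    if(num_experts <= 8)
        return 8;
    if(num_experts <= 16)
        return 16;
    if(num_experts <= 32)
        return 32;
    if(num_experts <= 64)
        return 64;
    return 0; // generic, untiled path (ExpertTile == 0)
}

int main()
{
    for(int e : {4, 8, 17, 64, 99})
        std::printf("num_experts=%d -> expert_tile=%d\n", e, pick_expert_tile(e));
    return 0;
}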
@@ -16,4 +16,5 @@ $EXE -t=127 -e=99 -k=19
 $EXE -t=71 -e=11 -k=11
 $EXE -t=1 -e=1 -k=1
 $EXE -t=99 -e=2 -k=1
-$EXE -t=333 -e=99 -k=13
\ No newline at end of file
+$EXE -t=333 -e=99 -k=13
+$EXE -t=128 -e=32 -k=5 -moe_buf_size=262144
@@ -3,18 +3,42 @@
 #include "fused_moesorting.hpp"
 
-#define MOE_SORTING_DISPATCH(unroll_num_)                                                  \
+#define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_)                              \
     constexpr ck_tile::index_t unroll_num = unroll_num_;                                   \
-    using ms_problem = ck_tile::MoeSortingProblem<index_t, ms_weight_type, unroll_num>;    \
+    constexpr ck_tile::index_t expert_tile = expert_tile_;                                  \
+    using ms_problem =                                                                      \
+        ck_tile::MoeSortingProblem<index_t, ms_weight_type, unroll_num, expert_tile>;       \
     using kernel = ck_tile::MoeSortingKernel<ms_problem>;                                   \
     auto kargs = kernel::MakeKargs(a);                                                      \
     const dim3 grids = kernel::GridSize(a);                                                 \
     const dim3 blocks = kernel::BlockSize(a);                                               \
     const auto lds_bytes = kernel::GetSmemSize(a);                                          \
     float ave_time = ck_tile::launch_kernel(                                                \
         s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs));                \
     return ave_time;
+
+#define MOE_SORTING_DISPATCH(unroll_num_)               \
+    if(a.num_experts <= 8)                              \
+    {                                                   \
+        MOE_SORTING_DISPATCH_ETILE(unroll_num_, 8)      \
+    }                                                   \
+    else if(a.num_experts <= 16)                        \
+    {                                                   \
+        MOE_SORTING_DISPATCH_ETILE(unroll_num_, 16)     \
+    }                                                   \
+    else if(a.num_experts <= 32)                        \
+    {                                                   \
+        MOE_SORTING_DISPATCH_ETILE(unroll_num_, 32)     \
+    }                                                   \
+    else if(a.num_experts <= 64)                        \
+    {                                                   \
+        MOE_SORTING_DISPATCH_ETILE(unroll_num_, 64)     \
+    }                                                   \
+    else                                                \
+    {                                                   \
+        MOE_SORTING_DISPATCH_ETILE(unroll_num_, 0)      \
+    }
 
 float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_tile::stream_config s)
 {
     if(t.weight_type == "fp32" && t.index_type == "int32")
@@ -49,21 +73,12 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
     case(6): {
         MOE_SORTING_DISPATCH(6);
     }
-    case(7): {
-        MOE_SORTING_DISPATCH(7);
-    }
     case(8): {
         MOE_SORTING_DISPATCH(8);
     }
-    case(9): {
-        MOE_SORTING_DISPATCH(9);
-    }
     case(10): {
         MOE_SORTING_DISPATCH(10);
     }
-    case(11): {
-        MOE_SORTING_DISPATCH(11);
-    }
     default: {
         MOE_SORTING_DISPATCH(4);
     }
...
@@ -115,8 +115,8 @@
 #cmakedefine CK_USE_GFX94 @CK_USE_GFX94@
 #endif
 
-#ifndef DCK_USE_OCP_FP8
-#cmakedefine DCK_USE_OCP_FP8 @DCK_USE_OCP_FP8@
+#ifndef CK_USE_OCP_FP8
+#cmakedefine CK_USE_OCP_FP8 @CK_USE_OCP_FP8@
 #endif
 
 #ifndef CK_USE_FNUZ_FP8
...
@@ -130,7 +130,8 @@ struct MoeSortingKernel
     CK_TILE_HOST static constexpr auto GetSmemSize(const Hargs& h)
     {
         const auto blocks = BlockSize(h);
-        return ((blocks.x + 1) * h.num_experts + (h.num_experts + 1)) * sizeof(index_t);
+        // usually num_experts is power of 2, we pad 1 dword here for the row-size
+        return ((blocks.x + 1) * (h.num_experts + 1) + (h.num_experts + 1)) * sizeof(index_t);
     }
 
     CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
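To make the LDS sizing change above concrete, here is a small host-side arithmetic check (the numbers are an example only, not taken from this commit): every row of the (blockDim.x + 1) x num_experts counter table now carries one extra index_t, the usual padding trick so that power-of-two row sizes do not keep hitting the same LDS bank.

#include <cstdint>
#include <cstdio>

int main()
{
    using index_t = int32_t;
    const index_t blocks_x = 64, num_experts = 32; // example values only
    const auto old_bytes = ((blocks_x + 1) * num_experts + (num_experts + 1)) * sizeof(index_t);
    const auto new_bytes = ((blocks_x + 1) * (num_experts + 1) + (num_experts + 1)) * sizeof(index_t);
    std::printf("old smem: %zu bytes, new smem: %zu bytes\n", old_bytes, new_bytes);
    // old: (65*32 + 33)*4 = 8452, new: (65*33 + 33)*4 = 8712
    return 0;
}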
@@ -154,6 +155,75 @@ struct MoeSortingKernel
         return k;
     }
 
+    // [a, b, c, d....] -> [a, a+b, a+b+c, a+b+c+d, ....]
+    template <typename data_t, int wave_size>
+    __device__ inline void wave_cumsum(data_t& thread_data) const
+    {
+        // wave_size must be power of 2
+        constexpr int row_mask    = 0xf;
+        constexpr int bank_mask   = 0xf;
+        constexpr bool bound_ctrl = true; // ! out-of-bound is zero !
+        auto reduce_op = [&](auto x_, auto y_) { return x_ + y_; };
+
+        if constexpr(wave_size > 1)
+        {
+            thread_data = reduce_op(
+                thread_data,
+                __builtin_bit_cast(data_t,
+                                   __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data),
+                                                            0x111, row_mask, bank_mask,
+                                                            bound_ctrl))); // row_shr:1
+        }
+        if constexpr(wave_size > 2)
+        {
+            thread_data = reduce_op(
+                thread_data,
+                __builtin_bit_cast(data_t,
+                                   __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data),
+                                                            0x112, row_mask, bank_mask,
+                                                            bound_ctrl))); // row_shr:2
+        }
+        if constexpr(wave_size > 4)
+        {
+            thread_data = reduce_op(
+                thread_data,
+                __builtin_bit_cast(data_t,
+                                   __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data),
+                                                            0x114, row_mask, bank_mask,
+                                                            bound_ctrl))); // row_shr:4
+        }
+        if constexpr(wave_size > 8)
+        {
+            thread_data = reduce_op(
+                thread_data,
+                __builtin_bit_cast(data_t,
+                                   __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data),
+                                                            0x118, row_mask, bank_mask,
+                                                            bound_ctrl))); // row_shr:8
+        }
+        if constexpr(wave_size > 16)
+        {
+            // now row-0, row-0+row-1, row-1+row-2, row-2+row-3
+            int v_remote_tmp = __builtin_amdgcn_ds_bpermute(
+                ((__lane_id() & 0x30) - 1) << 2, __builtin_bit_cast(int, thread_data));
+            v_remote_tmp = __lane_id() >= 16 ? v_remote_tmp : 0;
+            thread_data  = reduce_op(thread_data, __builtin_bit_cast(data_t, v_remote_tmp));
+        }
+        if constexpr(wave_size > 32)
+        {
+            // lane-id 48...63->31
+            int v_remote_tmp = __builtin_amdgcn_ds_bpermute(
+                ((__lane_id() & 0x30) - 17) << 2, __builtin_bit_cast(int, thread_data));
+            v_remote_tmp = __lane_id() >= 32 ? v_remote_tmp : 0;
+            thread_data  = reduce_op(thread_data, __builtin_bit_cast(data_t, v_remote_tmp));
+        }
+    }
+
     CK_TILE_DEVICE index_t calc_index(index_t total_col, index_t row, index_t col) const
     {
         return row * total_col + col;
@@ -187,48 +257,124 @@ struct MoeSortingKernel
         index_t* shared_mem = reinterpret_cast<index_t*>(smem);
 
         index_t* tokens_cnts = shared_mem; // 2d: (blockDim.x + 1, num_experts)
-        index_t* cumsum = shared_mem + (blockDim.x + 1) * num_experts; // 1: (num_experts + 1)
+        index_t* cumsum = shared_mem + (blockDim.x + 1) * (num_experts+1); // 1: (num_experts + 1)
 
         for(int i = 0; i < num_experts; ++i)
         {
-            tokens_cnts[calc_index(num_experts, tid + 1, i)] = 0;
+            tokens_cnts[calc_index(num_experts+1, tid + 1, i)] = 0;
         }
 
 #pragma unroll Problem_::InternalLoadUnroll
         for(int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i)
         {
-            ++tokens_cnts[calc_index(num_experts, tid + 1, topk_id[i])];
+            ++tokens_cnts[calc_index(num_experts+1, tid + 1, topk_id[i])];
         }
 
         __syncthreads();
 
-        if(tid < num_experts)
-        {
-            tokens_cnts[calc_index(num_experts, 0, tid)] = 0;
-            for(int i = 1; i <= static_cast<index_t>(blockDim.x); ++i)
-            {
-                tokens_cnts[calc_index(num_experts, i, tid)] +=
-                    tokens_cnts[calc_index(num_experts, i - 1, tid)];
-            }
-        }
+#if 1
+        if(tid < num_experts)
+        {
+            tokens_cnts[calc_index(num_experts+1, 0, tid)] = 0;
+            index_t local_c[8];
+            index_t prev_c = 0;
+            // TODO: manually unroll. pragma unroll does not work well when we have dependency
+            for(int i = 1; i <= static_cast<index_t>(blockDim.x); i+= 8)
+            {
+                local_c[0] = tokens_cnts[calc_index(num_experts+1, i + 0, tid)];
+                local_c[1] = tokens_cnts[calc_index(num_experts+1, i + 1, tid)];
+                local_c[2] = tokens_cnts[calc_index(num_experts+1, i + 2, tid)];
+                local_c[3] = tokens_cnts[calc_index(num_experts+1, i + 3, tid)];
+                local_c[4] = tokens_cnts[calc_index(num_experts+1, i + 4, tid)];
+                local_c[5] = tokens_cnts[calc_index(num_experts+1, i + 5, tid)];
+                local_c[6] = tokens_cnts[calc_index(num_experts+1, i + 6, tid)];
+                local_c[7] = tokens_cnts[calc_index(num_experts+1, i + 7, tid)];
+
+                local_c[0] += prev_c;
+                local_c[1] += local_c[0];
+                local_c[2] += local_c[1];
+                local_c[3] += local_c[2];
+                local_c[4] += local_c[3];
+                local_c[5] += local_c[4];
+                local_c[6] += local_c[5];
+                local_c[7] += local_c[6];
+                prev_c = local_c[7];
+
+                tokens_cnts[calc_index(num_experts+1, i + 0, tid)] = local_c[0];
+                tokens_cnts[calc_index(num_experts+1, i + 1, tid)] = local_c[1];
+                tokens_cnts[calc_index(num_experts+1, i + 2, tid)] = local_c[2];
+                tokens_cnts[calc_index(num_experts+1, i + 3, tid)] = local_c[3];
+                tokens_cnts[calc_index(num_experts+1, i + 4, tid)] = local_c[4];
+                tokens_cnts[calc_index(num_experts+1, i + 5, tid)] = local_c[5];
+                tokens_cnts[calc_index(num_experts+1, i + 6, tid)] = local_c[6];
+                tokens_cnts[calc_index(num_experts+1, i + 7, tid)] = local_c[7];
+            }
+        }
+#else
+        // __syncthreads(); // TODO: below code still working, but slow in expert=32/topk=5 case. Put here for future heuristic
+        {
+            if(tid < num_experts)
+                tokens_cnts[calc_index(num_experts+1, 0, tid)] = 0;
+            for(int i = 0; i < num_experts; i+=8) {
+                index_t local_c[8];
+#pragma unroll
+                for(int j = 0; j < 8; j++) {
+                    local_c[j] = tokens_cnts[calc_index(num_experts+1, tid+1, i+j)];
+                }
+#pragma unroll
+                for(int j = 0; j < 8; j++) {
+                    wave_cumsum<int, 64>(local_c[j]);
+                }
+#pragma unroll
+                for(int j = 0; j < 8; j++) {
+                    tokens_cnts[calc_index(num_experts+1, tid+1, i+j)] = local_c[j];
+                }
+            }
+        }
+#endif
 
         __syncthreads();
 
-        if(tid == 0)
-        {
-            cumsum[0] = 0;
-            for(int i = 1; i <= num_experts; ++i)
-            {
-                auto current_units = [&]() {
-                    index_t x_ = tokens_cnts[calc_index(num_experts, blockDim.x, i - 1)] +
-                                 unit_size_mdiv.divisor - 1;
-                    index_t y_ = unit_size_mdiv.div(x_);
-                    return max(y_, 1) * unit_size_mdiv.divisor;
-                }();
-                cumsum[i] = cumsum[i - 1] + current_units;
-            }
-            *p_total_tokens_post_pad = cumsum[num_experts];
-        }
+        if constexpr (Problem::ExpertTile == 0) {
+            if(tid == 0)
+            {
+                cumsum[0] = 0;
+                for(int i = 1; i <= num_experts; ++i)
+                {
+                    auto current_units = [&]() {
+                        index_t x_ = tokens_cnts[calc_index(num_experts+1, blockDim.x, i - 1)] +
+                                     unit_size_mdiv.divisor - 1;
+                        index_t y_ = unit_size_mdiv.div(x_);
+                        return max(y_, 1) * unit_size_mdiv.divisor;
+                    }();
+                    cumsum[i] = cumsum[i - 1] + current_units;
+                }
+                *p_total_tokens_post_pad = cumsum[num_experts];
+            }
+        } else {
+            // TODO: we have out-of-bound read here. But result is still OK (will ignore tid >= expert)
+            // for simplicity, not check experts here.
+            int local_cnt = tokens_cnts[calc_index(num_experts+1, blockDim.x, tid)];
+            int blocks_pers_expert = unit_size_mdiv.div(local_cnt + unit_size_mdiv.divisor - 1);
+            int padded_tokens_per_expert = max(blocks_pers_expert, 1) * unit_size_mdiv.divisor;
+            int local_cumsum = padded_tokens_per_expert;
+            wave_cumsum<int, 64>(local_cumsum);
+            if(tid == (num_experts - 1)) {
+                cumsum[0] = 0;
+                *p_total_tokens_post_pad = local_cumsum;
+            }
+            if(tid < num_experts) {
+                cumsum[tid + 1] = local_cumsum;
+            }
+        }
+
         __syncthreads();
 
         if(tid < num_experts)
         {
-            for(int i = cumsum[tid]; i < cumsum[tid + 1]; i += unit_size_mdiv.divisor)
+            int e_start = cumsum[tid];
+            int e_end   = cumsum[tid + 1];
+            for(int i = e_start; i < e_end; i += unit_size_mdiv.divisor)
             {
                 p_sorted_expert_ids[unit_size_mdiv.div(i)] = tid;
             }
@@ -238,8 +384,8 @@ struct MoeSortingKernel
         for(int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i)
         {
             index_t expert_id = topk_id[i];
-            index_t rank_post_pad =
-                tokens_cnts[calc_index(num_experts, tid, expert_id)] + cumsum[expert_id];
+            index_t local_cnt     = tokens_cnts[calc_index(num_experts+1, tid, expert_id)];
+            index_t rank_post_pad = local_cnt + cumsum[expert_id];
 #if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
             uint32_t curr_token_id, curr_topk_id;
             topk_mdiv.divmod(i, curr_token_id, curr_topk_id);
@@ -247,27 +393,54 @@
 #else
             p_sorted_token_ids[rank_post_pad] = topk_mdiv.div(i);
 #endif
             p_sorted_weights[rank_post_pad] = weights[i];
-            ++tokens_cnts[calc_index(num_experts, tid, expert_id)];
+            tokens_cnts[calc_index(num_experts+1, tid, expert_id)] = local_cnt+1;
         }
 
-        const index_t prefill_token = topk_mdiv.div(numel);
-        if(tid < num_experts)
-        {
-            index_t expert_offset =
-                cumsum[tid] + tokens_cnts[calc_index(num_experts, blockDim.x, tid)];
-            while(expert_offset < cumsum[tid + 1])
-            {
-#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
-                p_sorted_token_ids[expert_offset] =
-                    MOE_SORTING_MOCK_ID(prefill_token, topk_mdiv.divisor);
-#else
-                p_sorted_token_ids[expert_offset] = prefill_token;
-#endif
-                p_sorted_weights[expert_offset] = static_cast<WeightType>(0.0);
-                expert_offset++;
-            }
-        }
+        if constexpr (Problem::ExpertTile == 0) {
+            const index_t prefill_token = topk_mdiv.div(numel);
+            if(tid < num_experts)
+            {
+                index_t expert_offset =
+                    cumsum[tid] + tokens_cnts[calc_index(num_experts+1, blockDim.x, tid)];
+                index_t expert_end = cumsum[tid + 1];
+                while(expert_offset < expert_end)
+                {
+#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
+                    p_sorted_token_ids[expert_offset] =
+                        MOE_SORTING_MOCK_ID(prefill_token, topk_mdiv.divisor);
+#else
+                    p_sorted_token_ids[expert_offset] = prefill_token;
+#endif
+                    p_sorted_weights[expert_offset] = static_cast<WeightType>(0.0);
+                    expert_offset++;
+                }
+            }
+        }
+        else {
+            const index_t prefill_token = topk_mdiv.div(numel);
+            // TODO: only support expert-tile like 8, 16, 32
+            static constexpr index_t experts_per_wave = warpSize / Problem::ExpertTile;
+            {
+                index_t eid = tid / experts_per_wave;
+                index_t expert_offset = cumsum[eid] +
+                    tokens_cnts[calc_index(num_experts+1, blockDim.x, eid)] + tid % experts_per_wave;
+                index_t expert_end = cumsum[eid + 1];
+                if(eid < num_experts) {
+                    while(expert_offset < expert_end)
+                    {
+#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
+                        p_sorted_token_ids[expert_offset] =
+                            MOE_SORTING_MOCK_ID(prefill_token, topk_mdiv.divisor);
+#else
+                        p_sorted_token_ids[expert_offset] = prefill_token;
+#endif
+                        p_sorted_weights[expert_offset] = static_cast<WeightType>(0.0);
+                        expert_offset+=experts_per_wave;
+                    }
+                }
+            }
+        }
     }
 
     CK_TILE_DEVICE void operator()(Kargs kargs) const
...
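The wave_cumsum helper added above turns one value per lane into an inclusive prefix sum ([a, b, c, ...] -> [a, a+b, a+b+c, ...]) using DPP row shifts and ds_bpermute reads across 16-lane rows. As a plain-CPU reference for that result (the array-of-lanes model and the function name are illustrative only; the kernel does this with wavefront intrinsics):

#include <array>
#include <cstdio>

// Reference inclusive scan over "lanes": log2(wave_size) shift-and-add steps,
// where a read below lane 0 contributes zero (mirroring bound_ctrl / the
// masked ds_bpermute reads in the kernel).
template <int wave_size>
void wave_cumsum_ref(std::array<int, wave_size>& lanes)
{
    for(int offset = 1; offset < wave_size; offset *= 2)
    {
        std::array<int, wave_size> prev = lanes; // every lane reads pre-step values
        for(int lane = 0; lane < wave_size; ++lane)
            lanes[lane] = prev[lane] + (lane >= offset ? prev[lane - offset] : 0);
    }
}

int main()
{
    std::array<int, 8> v{1, 2, 3, 4, 5, 6, 7, 8};
    wave_cumsum_ref<8>(v);
    for(int x : v)
        std::printf("%d ", x); // prints: 1 3 6 10 15 21 28 36
    std::printf("\n");
    return 0;
}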
@@ -9,15 +9,20 @@
 namespace ck_tile {
 
-template <typename IndexType_, typename WeightType_, index_t InternalLoadUnroll_>
+template <typename IndexType_,
+          typename WeightType_,
+          index_t InternalLoadUnroll_,
+          index_t ExpertTile_ = 0>
 struct MoeSortingProblem
 {
     // TODO: this kernel only support warp per row
     using WeightType = remove_cvref_t<WeightType_>;
     using IndexType  = remove_cvref_t<IndexType_>;
 
     static constexpr index_t WarpSize      = get_warp_size();
     static constexpr index_t WarpsPerBlock = 1;
-    static constexpr index_t InternalLoadUnroll = InternalLoadUnroll_;
+    static constexpr index_t InternalLoadUnroll =
+        InternalLoadUnroll_;                            // TODO: need better design(like tile size)
+    static constexpr index_t ExpertTile = ExpertTile_;  // TODO: only used in store out
 };
 
 } // namespace ck_tile
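Because ExpertTile_ defaults to 0, existing three-argument MoeSortingProblem instantiations keep compiling and take the old single-thread padding path; the dispatch macros shown earlier pass 8/16/32/64 explicitly. Reading the new store-out branch, experts_per_wave = warpSize / Problem::ExpertTile groups that many consecutive threads onto one expert id, each starting at tid % experts_per_wave and striding by the group size while writing the padded slots. A small hypothetical walk-through of that indexing (values chosen purely for illustration):

#include <cstdio>

int main()
{
    const int warp_size   = 64; // wave64, matching wave_cumsum<int, 64>
    const int expert_tile = 32; // what the dispatch macro picks for num_experts <= 32
    const int experts_per_wave = warp_size / expert_tile; // here: 2 threads share one expert
    for(int tid = 0; tid < 6; ++tid)
    {
        const int eid         = tid / experts_per_wave; // expert this thread pads
        const int lane_offset = tid % experts_per_wave; // first padded slot it touches
        std::printf("tid=%d -> expert=%d, start offset=%d, stride=%d\n",
                    tid, eid, lane_offset, experts_per_wave);
    }
    return 0;
}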