Commit 0385cded authored by dummycoderfe
Browse files

[Ck_tile] MoE buffer set-zero now runs OK; add buffer size check and fix reference check

parent bfe0120a
......@@ -96,7 +96,7 @@ bool test_moe_sorting(ck_tile::ArgParser args)
ck_tile::HostTensor<WeightType> sorted_weights_host({max_output_ids}, {1});
ck_tile::HostTensor<IndexType> expert_ids_host({max_output_ids / unit_size}, {1});
ck_tile::HostTensor<IndexType> sorted_id_cnt_host({1}, {1});
ck_tile::HostTensor<IndexType> moe_buf_host({moe_buf_size}, {1});
ck_tile::HostTensor<float> moe_buf_host({moe_buf_size}, {1});
ck_tile::FillUniformDistribution<WeightType>{-.5f, .5f}(weights_host);
ck_tile::FillUniformDistribution<WeightType>{-.5f, .5f}(moe_buf_host);
......@@ -127,7 +127,7 @@ bool test_moe_sorting(ck_tile::ArgParser args)
unit_size,
experts,
topk,
moe_buf_size};
static_cast<ck_tile::index_t>(moe_buf_host.get_element_space_size_in_bytes())};
ck_tile::stream_config sc{nullptr,
true,
......@@ -162,7 +162,7 @@ bool test_moe_sorting(ck_tile::ArgParser args)
ck_tile::HostTensor<IndexType> sorted_ids_ref({max_output_ids}, {1});
ck_tile::HostTensor<WeightType> sorted_weights_ref({max_output_ids}, {1});
ck_tile::HostTensor<IndexType> expert_ids_ref({max_output_ids / unit_size}, {1});
ck_tile::HostTensor<IndexType> moe_buf_ref({moe_buf_size}, {1});
ck_tile::HostTensor<WeightType> moe_buf_ref({moe_buf_size}, {1});
moe_buf_ref.SetZero();
int32_t total_tokens_post_pad = 0;
......
......@@ -24,6 +24,11 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
printf("lds size exceed, only support experts <127 \n");
return -1;
}
if(a.moe_buf_set_bytes % 16)
{
printf("buf set size %d unaligned, must be multiple of 16\n", a.moe_buf_set_bytes);
return -1;
}
using index_t = ck_tile::index_t;
using ms_weight_type = float;
index_t smem_io_unroll_num = ck_tile::integer_divide_ceil(a.tokens * a.topk, 64);
......
......@@ -107,7 +107,7 @@ struct MoeSortingKernel
const index_t offset = (blockIdx.x - 1) * blockDim.x + threadIdx.x;
if(offset < buf_bytes / 16)
{
buf[offset] = uint8x16_t(0);
buf[offset] = uint8x16_t{0};
}
}
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment