"vscode:/vscode.git/clone" did not exist on "43c5792592d9beb02eea57730ce5a4647dc0c838"
Unverified Commit 6fbec8ed authored by Jakub Zakrzewski's avatar Jakub Zakrzewski Committed by GitHub
Browse files

[Bugfix][Kernel] nvfp4 cutlass MoE: fix nvfp4 experts quant out-of-bounds read...


[Bugfix][Kernel] nvfp4 cutlass MoE: fix nvfp4 experts quant out-of-bounds read for expert counts not divisible by 4 or 16 (#40351)
Signed-off-by: default avatarJakub Zakrzewski <jzakrzewski@nvidia.com>
parent 5544f8c1
...@@ -277,7 +277,9 @@ void quant_impl(void* output, void* output_scale, void* input, ...@@ -277,7 +277,9 @@ void quant_impl(void* output, void* output_scale, void* input,
(totalWorkSize + block.x * grid.x - 1) / (block.x * grid.x); (totalWorkSize + block.x * grid.x - 1) / (block.x * grid.x);
if (blockRepeat > 1) { if (blockRepeat > 1) {
size_t shared_mem_size = (n_experts + 1) * sizeof(uint32_t); size_t shared_mem_size = (n_experts + 1) * sizeof(uint32_t);
if (n_experts >= 4) { // The shared-memory vectorized offset load only handles full 4-expert
// chunks. Use the scalar specialization for the remainder cases.
if (n_experts >= 4 && n_experts % 4 == 0) {
cvt_fp16_to_fp4<T, FUSE_SILU_MUL, false, false> cvt_fp16_to_fp4<T, FUSE_SILU_MUL, false, false>
<<<grid, block, shared_mem_size, stream>>>( <<<grid, block, shared_mem_size, stream>>>(
m_topk, k, reinterpret_cast<T*>(input), m_topk, k, reinterpret_cast<T*>(input),
...@@ -299,7 +301,9 @@ void quant_impl(void* output, void* output_scale, void* input, ...@@ -299,7 +301,9 @@ void quant_impl(void* output, void* output_scale, void* input,
n_experts); n_experts);
} }
} else { } else {
if (n_experts >= 16) { // The low-latency vectorized expert lookup only handles full 16-expert
// chunks. Fall back to the scalar lookup path for the remainder cases.
if (n_experts >= 16 && n_experts % 16 == 0) {
cvt_fp16_to_fp4<T, FUSE_SILU_MUL, false, false> cvt_fp16_to_fp4<T, FUSE_SILU_MUL, false, false>
<<<grid, block, 0, stream>>>( <<<grid, block, 0, stream>>>(
m_topk, k, reinterpret_cast<T*>(input), m_topk, k, reinterpret_cast<T*>(input),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment