Unverified Commit 400d3b97 authored by kk's avatar kk Committed by GitHub
Browse files

Fix run time error in dsv3-fp8 model on mi35x (#10104)


Co-authored-by: default avatarwunhuang <wunhuang@amd.com>
Co-authored-by: default avatarHaiShaw <hixiao@gmail.com>
Co-authored-by: default avatarLianmin Zheng <lianminzheng@gmail.com>
parent 37d83c6e
......@@ -249,7 +249,11 @@ class DeepseekV2MLP(nn.Module):
if (self.tp_size == 1) and x.shape[0] == 0:
return x
if gemm_output_zero_allocator != None and x.shape[0] <= 256:
if (
gemm_output_zero_allocator is not None
and x.shape[0] <= 256
and self.gate_up_proj.weight.dtype == torch.uint8
):
y = gemm_output_zero_allocator.allocate(
x.shape[0] * self.gate_up_proj.output_size_per_partition
).view(x.shape[0], self.gate_up_proj.output_size_per_partition)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment