Unverified commit 34f7564d, authored by Kaixi Hou, committed by GitHub
Browse files

[NVIDIA] Fix wrong symmetric sizes for fp4 cases (#12640)

parent 1cfbbc42
......@@ -1602,8 +1602,13 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
x_sf = nvfp4_block_scale_interleave(x_sf)
with use_symmetric_memory(get_tp_group()) as sm:
# The x might be packed in the case of fp4. So, use the output dim of the
# weight of the second GEMM.
symm_output = torch.empty(
x.shape[0], x.shape[1], dtype=output_dtype, device=x.device
x.shape[0],
layer.w2_weight.shape[1],
dtype=output_dtype,
device=x.device,
)
sm.tag(symm_output)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment