Unverified commit 1fb632fd, authored by Lain and committed by GitHub
Browse files

[Perf] Improve fp8 quant in mla; replace ReduceSum with ReduceScatterSum (#29795)


Signed-off-by: Siyuan Fu <siyuanf@nvidia.com>
parent 6af70e11
...@@ -225,7 +225,7 @@ class CudaCommunicator(DeviceCommunicatorBase): ...@@ -225,7 +225,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
output_shape, dtype=input_tensor.dtype, device=input_tensor.device output_shape, dtype=input_tensor.dtype, device=input_tensor.device
) )
if sizes is not None: if sizes is not None and sizes.count(sizes[0]) != len(sizes):
pynccl_comm.reduce_scatterv(output, input_tensor, sizes=sizes) pynccl_comm.reduce_scatterv(output, input_tensor, sizes=sizes)
else: else:
pynccl_comm.reduce_scatter(output, input_tensor) pynccl_comm.reduce_scatter(output, input_tensor)
......
...@@ -2037,21 +2037,30 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ...@@ -2037,21 +2037,30 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
if fp8_attention: if fp8_attention:
ql_nope_shape = decode_ql_nope.shape ql_nope_shape = decode_ql_nope.shape
decode_ql_nope, _ = ops.scaled_fp8_quant(
decode_ql_nope.reshape(
[ql_nope_shape[0], ql_nope_shape[1] * ql_nope_shape[2]]
),
layer._q_scale,
)
decode_ql_nope = decode_ql_nope.reshape(ql_nope_shape)
q_pe_shape = decode_q_pe.shape q_pe_shape = decode_q_pe.shape
decode_q_pe, _ = ops.scaled_fp8_quant( assert decode_ql_nope.shape[0] == decode_q_pe.shape[0]
decode_q_pe.reshape([q_pe_shape[0], q_pe_shape[1] * q_pe_shape[2]]), assert decode_ql_nope.shape[1] == decode_q_pe.shape[1]
layer._q_scale, decode_q_shape = (
ql_nope_shape[0],
ql_nope_shape[1],
ql_nope_shape[2] + q_pe_shape[2],
)
# Using empty and copy since torch.cat introduces significant overhead.
decode_q0 = torch.empty(
decode_q_shape,
device=decode_ql_nope.device,
dtype=decode_ql_nope.dtype,
) )
decode_q_pe = decode_q_pe.reshape(q_pe_shape) decode_q0[..., : ql_nope_shape[2]].copy_(decode_ql_nope)
decode_q0[..., ql_nope_shape[2] :].copy_(decode_q_pe)
decode_q = (decode_ql_nope, decode_q_pe) decode_q, _ = ops.scaled_fp8_quant(
decode_q0.view(decode_q_shape[0], -1),
layer._q_scale,
)
decode_q = decode_q.view(decode_q_shape)
else:
decode_q = (decode_ql_nope, decode_q_pe)
if self.dcp_world_size > 1: if self.dcp_world_size > 1:
assert not fp8_attention, "DCP not support fp8 kvcache now." assert not fp8_attention, "DCP not support fp8 kvcache now."
# concatenate decode_ql_nope and decode_q_pe -> (B, N, L + P) # concatenate decode_ql_nope and decode_q_pe -> (B, N, L + P)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment