Commit 3c0e74be authored by wujl5, committed by wangmin6
Browse files

[fix] MLP传入量化参数,MLP allGather通讯优化

parent c637d1aa
...@@ -183,6 +183,44 @@ class DeepseekAttention(nn.Module): ...@@ -183,6 +183,44 @@ class DeepseekAttention(nn.Module):
return output return output
def eff_2d_iqis_all_gather(
    iqis: tuple[torch.Tensor, torch.Tensor],
    tp_size: int | None = None,
    tp_rank: int | None = None
) -> tuple[torch.Tensor, torch.Tensor]:
    """All-gather a quantized activation and its per-row scale in one call.

    The float32 scale column is reinterpreted as 4 int8 bytes per row and
    appended to the int8 activation, so a single
    ``tensor_model_parallel_all_gather`` along dim 0 moves both tensors
    instead of issuing one collective per tensor.

    Args:
        iqis: ``(iq_tensor, is_tensor)`` where ``iq_tensor`` is a 2-D int8
            tensor of shape ``[m_local, n]`` and ``is_tensor`` is a 2-D
            float32 tensor of shape ``[m_local, 1]``.
        tp_size: Optional group size. When given, the gathered row count is
            checked against ``m_local * tp_size``; when ``None`` that sanity
            check is skipped.
        tp_rank: Accepted for interface compatibility; not used here.

    Returns:
        ``(iq_gathered, is_gathered)``: the gathered int8 tensor with ``n``
        columns and the gathered float32 scale with one column.

    Raises:
        AssertionError: If the inputs or the gathered result do not have the
            dtypes/shapes described above.
    """
    assert iqis is not None
    iq_tensor, is_tensor = iqis
    assert isinstance(iq_tensor, torch.Tensor)
    assert isinstance(is_tensor, torch.Tensor)
    assert iq_tensor.dtype == torch.int8, f"iq_tensor dtype is {iq_tensor.dtype}"
    assert is_tensor.dtype == torch.float32, f"is_tensor dtype is {is_tensor.dtype}"
    assert iq_tensor.dim() == 2
    assert is_tensor.dim() == 2
    m_local, n = iq_tensor.shape
    assert is_tensor.shape[0] == m_local, f"{is_tensor.shape[0]} != {iq_tensor.shape[0]}"
    assert is_tensor.shape[1] == 1, f"is_tensor dim 1 ={is_tensor.shape[1]}"

    # Reinterpret each float32 scale as 4 raw int8 bytes: [m_local, 1] -> [m_local, 4].
    is_int8_2d = is_tensor.view(torch.int8)
    # Pack activation and scale bytes side by side: [m_local, n + 4].
    # .contiguous() is a no-op when the concatenated tensor is already contiguous.
    combined_2d = torch.cat([iq_tensor, is_int8_2d], dim=1).contiguous()

    combined_gathered = tensor_model_parallel_all_gather(combined_2d, dim=0)

    # Column slices of the packed buffer are strided; copy them out so the
    # scale bytes can be reinterpreted as float32 below.
    iq_gathered = combined_gathered[:, :n].contiguous()
    is_gathered_int8 = combined_gathered[:, n:].contiguous()

    # Only compare against m_local * tp_size when the caller supplied tp_size;
    # multiplying by the default None would raise TypeError.
    if tp_size is not None:
        assert iq_gathered.shape[0] == m_local * tp_size, f"iq_gathered dim0= {iq_gathered.shape[0]}, expected {m_local * tp_size}"
        # is_gathered_int8 should be [m_local*tp_size, 4]
        assert is_gathered_int8.shape[0] == m_local * tp_size, f"is_gathered_int8 dim0= {is_gathered_int8.shape[0]}, expected {m_local * tp_size}"
    assert is_gathered_int8.shape[1] == 4, f"is_gathered_int8 dim1= {is_gathered_int8.shape[1]}"

    # Reassemble the 4 bytes per row back into one float32 scale: [rows, 4] -> [rows, 1].
    is_gathered = is_gathered_int8.view(torch.float32)
    return (iq_gathered, is_gathered)
class DeepseekV2MLP(nn.Module): class DeepseekV2MLP(nn.Module):
def __init__( def __init__(
self, self,
...@@ -232,13 +270,14 @@ class DeepseekV2MLP(nn.Module): ...@@ -232,13 +270,14 @@ class DeepseekV2MLP(nn.Module):
enable_mla_cp = get_forward_context().enable_mla_cp #envs.VLLM_MLA_CP# and not get_forward_context().draft_model enable_mla_cp = get_forward_context().enable_mla_cp #envs.VLLM_MLA_CP# and not get_forward_context().draft_model
if enable_mla_cp: if enable_mla_cp:
if iqis is not None and iqis[0] is not None and iqis[1] is not None: if iqis is not None and iqis[0] is not None and iqis[1] is not None:
if False:
i_q_gahter = tensor_model_parallel_all_gather(iqis[0].contiguous(), 0) i_q_gahter = tensor_model_parallel_all_gather(iqis[0].contiguous(), 0)
i_s_gather = tensor_model_parallel_all_gather(iqis[1].contiguous(), 0) i_s_gather = tensor_model_parallel_all_gather(iqis[1].contiguous(), 0)
iqis = (i_q_gahter, i_s_gather) iqis = (i_q_gahter, i_s_gather)
else: else:
x = tensor_model_parallel_all_gather( iqis = eff_2d_iqis_all_gather(iqis, tp_size=self.tp_size, tp_rank=get_tensor_model_parallel_rank())
x.contiguous(), 0 else:
) x = tensor_model_parallel_all_gather(x.contiguous(), 0)
if envs.USE_FUSED_RMS_QUANT: if envs.USE_FUSED_RMS_QUANT:
gate_up, _ = self.gate_up_proj(x, iqis=iqis) gate_up, _ = self.gate_up_proj(x, iqis=iqis)
...@@ -1233,7 +1272,7 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -1233,7 +1272,7 @@ class DeepseekV2DecoderLayer(nn.Module):
update_input=update_hs update_input=update_hs
) )
new_resi = residual new_resi = residual
if skip_moe_large_batch_size: if skip_moe_large_batch_size and isinstance(self.mlp, DeepseekV2MoE):
hidden_states = self.mlp(hidden_states) hidden_states = self.mlp(hidden_states)
else: else:
hidden_states = self.mlp(hidden_states, iqis=(_i_q, _i_s)) hidden_states = self.mlp(hidden_states, iqis=(_i_q, _i_s))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment