Unverified Commit 4c22897a authored by wxzhoucs, committed by GitHub

Feature: support qwen and llama4 reducescatter for dp attention padding (#9101)

parent 1bc183c6
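
Summary of the change: the row-parallel output projections gain a skip_all_reduce flag, the MoE/MLP forward methods gain a use_reduce_scatter flag, and the decoder layers construct their LayerCommunicator with allow_reduce_scatter=True and query should_use_reduce_scatter(forward_batch). When the flag is set, the per-rank partial results are left unreduced inside the MLP/MoE block and the layer communicator reduce-scatters them afterwards, which is valid under DP attention padding because every rank holds the same (padded) number of tokens.

A minimal, self-contained sketch of that control flow, using toy stand-ins for sglang's real classes (only the flag names and should_use_reduce_scatter come from this change; everything else is illustrative):

class ToyLayerCommunicator:
    def __init__(self, allow_reduce_scatter: bool = False):
        self.allow_reduce_scatter = allow_reduce_scatter

    def should_use_reduce_scatter(self, dp_padding_active: bool) -> bool:
        # Hypothetical condition; the real decision lives in sglang's
        # LayerCommunicator and takes the ForwardBatch into account.
        return self.allow_reduce_scatter and dp_padding_active


class ToyMLP:
    def forward(self, x, use_reduce_scatter: bool = False):
        # Stands in for down_proj(x, skip_all_reduce=use_reduce_scatter):
        # when the flag is set, the TP partial sums are returned unreduced.
        return x, (not use_reduce_scatter)  # (output, already_reduced)


def toy_decoder_layer(comm, mlp, hidden_states, dp_padding_active):
    use_reduce_scatter = comm.should_use_reduce_scatter(dp_padding_active)
    out, already_reduced = mlp.forward(hidden_states, use_reduce_scatter)
    if not already_reduced:
        # Here the real layer would reduce-scatter in postprocess_layer().
        pass
    return out


print(toy_decoder_layer(ToyLayerCommunicator(True), ToyMLP(), [0.0], True))
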
@@ -253,7 +253,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
         )
         return lora_output

-    def forward(self, input_: torch.Tensor):
+    def forward(self, input_: torch.Tensor, skip_all_reduce=False):
         # duplicate the logic in RowParallelLinear
         if self.base_layer.input_is_parallel:
             input_parallel = input_
@@ -270,7 +270,11 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
         if self.set_lora:
             output_parallel = self.apply_lora(output_parallel, input_parallel)

-        if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
+        if (
+            self.base_layer.reduce_results
+            and self.base_layer.tp_size > 1
+            and not skip_all_reduce
+        ):
             output_ = tensor_model_parallel_all_reduce(output_parallel)
         else:
             output_ = output_parallel
...
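
The enabling change above is the new skip_all_reduce argument: the base layer still all-reduces across TP ranks by default, but a caller that intends to reduce-scatter later can opt out. A minimal sketch of the gating condition as written in the diff (the helper name is hypothetical, for illustration only):

def should_all_reduce(reduce_results: bool, tp_size: int, skip_all_reduce: bool) -> bool:
    # Mirrors the condition in RowParallelLinearWithLoRA.forward above.
    return reduce_results and tp_size > 1 and not skip_all_reduce


assert should_all_reduce(True, 8, False)        # normal TP path: all-reduce
assert not should_all_reduce(True, 8, True)     # caller will reduce-scatter instead
assert not should_all_reduce(True, 1, False)    # single rank: nothing to reduce
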
@@ -91,10 +91,18 @@ class LlamaMLP(nn.Module):
         )
         self.act_fn = SiluAndMul()

-    def forward(self, x, forward_batch=None):
+    def forward(
+        self,
+        x,
+        forward_batch=None,
+        use_reduce_scatter: bool = False,
+    ):
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
-        x, _ = self.down_proj(x)
+        x, _ = self.down_proj(
+            x,
+            skip_all_reduce=use_reduce_scatter,
+        )
         return x
...
@@ -131,14 +131,19 @@ class Llama4MoE(nn.Module):
             reduce_results=False,  # We need to do scatter before reduce
         )

-    def forward(self, hidden_states, forward_batch: ForwardBatch):
+    def forward(
+        self,
+        hidden_states,
+        forward_batch: ForwardBatch,
+        use_reduce_scatter: bool = False,
+    ):
         shared_out, routed_out = self._forward_core(
             hidden_states, forward_batch.forward_mode
         )

         out_aD = routed_out + shared_out

-        if self.tp_size > 1:
+        if self.tp_size > 1 and not use_reduce_scatter:
             out_aD = tensor_model_parallel_all_reduce(out_aD)

         return out_aD
@@ -412,6 +417,7 @@ class Llama4DecoderLayer(nn.Module):
             layer_scatter_modes=self.layer_scatter_modes,
             input_layernorm=self.input_layernorm,
             post_attention_layernorm=self.post_attention_layernorm,
+            allow_reduce_scatter=True,
         )

     def _is_moe_layer(self, layer_id: int) -> bool:
@@ -441,8 +447,15 @@ class Llama4DecoderLayer(nn.Module):
             hidden_states, residual, forward_batch
         )

+        # For DP with padding, reduce scatter can be used instead of all-reduce.
+        use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
+            forward_batch
+        )
+
         # Fully Connected
-        hidden_states = self.feed_forward(hidden_states, forward_batch)
+        hidden_states = self.feed_forward(
+            hidden_states, forward_batch, use_reduce_scatter
+        )
         hidden_states, residual = self.layer_communicator.postprocess_layer(
             hidden_states, residual, forward_batch
         )
...
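
The Llama4MoE projection is built with reduce_results=False ("We need to do scatter before reduce"), and should_use_reduce_scatter lets the layer fuse those two steps. The sketch below (standalone, not sglang code; run with e.g. torchrun --nproc_per_node=2 and one GPU per rank) checks the identity this relies on: with the token count padded to a multiple of the world size, reduce-scatter yields exactly this rank's chunk of the all-reduced tensor, so a separate all-reduce followed by a scatter is redundant.

import torch
import torch.distributed as dist


def main():
    dist.init_process_group(backend="nccl")  # NCCL, as in GPU serving setups
    rank, world = dist.get_rank(), dist.get_world_size()
    torch.cuda.set_device(rank)
    device = torch.device("cuda", rank)
    torch.manual_seed(rank)

    tokens, hidden = 8 * world, 16                         # padded: tokens % world == 0
    partial = torch.randn(tokens, hidden, device=device)   # per-rank partial sums

    # Path A: all-reduce, then keep only this rank's token slice.
    reduced = partial.clone()
    dist.all_reduce(reduced)
    slice_a = reduced.chunk(world, dim=0)[rank]

    # Path B: reduce-scatter straight into this rank's token slice.
    slice_b = torch.empty(tokens // world, hidden, device=device)
    dist.reduce_scatter_tensor(slice_b, partial)

    assert torch.allclose(slice_a, slice_b, atol=1e-4)
    if rank == 0:
        print("reduce-scatter == all-reduce + local chunk")
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
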
@@ -108,10 +108,14 @@ class Qwen2MoeMLP(nn.Module):
         )
         self.act_fn = SiluAndMul()

-    def forward(self, x):
+    def forward(
+        self,
+        x,
+        use_reduce_scatter: bool = False,
+    ):
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
-        x, _ = self.down_proj(x)
+        x, _ = self.down_proj(x, skip_all_reduce=use_reduce_scatter)
         return x
@@ -176,7 +180,10 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
         self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)

     def forward(
-        self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
+        self,
+        hidden_states: torch.Tensor,
+        forward_batch: Optional[ForwardBatch] = None,
+        use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
@@ -194,6 +201,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
         final_hidden_states = self.experts(hidden_states, topk_output)
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
-        final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+        if self.tp_size > 1 and not use_reduce_scatter:
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states.view(num_tokens, hidden_dim)
@@ -368,6 +376,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
             layer_scatter_modes=self.layer_scatter_modes,
             input_layernorm=self.input_layernorm,
             post_attention_layernorm=self.post_attention_layernorm,
+            allow_reduce_scatter=True,
         )

     def forward(
@@ -393,7 +402,12 @@ class Qwen2MoeDecoderLayer(nn.Module):
             hidden_states, residual, forward_batch
         )

-        hidden_states = self.mlp(hidden_states, forward_batch)
+        # For DP with padding, reduce scatter can be used instead of all-reduce.
+        use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
+            forward_batch
+        )
+        hidden_states = self.mlp(hidden_states, forward_batch, use_reduce_scatter)
+
         hidden_states, residual = self.layer_communicator.postprocess_layer(
             hidden_states, residual, forward_batch
...
@@ -144,11 +144,14 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         self.top_k = config.num_experts_per_tok

     def forward(
-        self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
+        self,
+        hidden_states: torch.Tensor,
+        forward_batch: Optional[ForwardBatch] = None,
+        use_reduce_scatter: bool = False,
     ) -> torch.Tensor:

         if not global_server_args_dict["moe_a2a_backend"].is_deepep():
-            return self.forward_normal(hidden_states)
+            return self.forward_normal(hidden_states, use_reduce_scatter)
         else:
             return self.forward_deepep(hidden_states, forward_batch)
@@ -159,7 +162,11 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             if name not in ["correction_bias"]
         ]

-    def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    def forward_normal(
+        self,
+        hidden_states: torch.Tensor,
+        use_reduce_scatter: bool = False,
+    ) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
@@ -167,7 +174,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         router_logits, _ = self.gate(hidden_states)
         topk_output = self.topk(hidden_states, router_logits)
         final_hidden_states = self.experts(hidden_states, topk_output)
-        if self.tp_size > 1:
+        if self.tp_size > 1 and not use_reduce_scatter:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states.view(num_tokens, hidden_dim)
@@ -521,6 +528,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
             layer_scatter_modes=self.layer_scatter_modes,
             input_layernorm=self.input_layernorm,
             post_attention_layernorm=self.post_attention_layernorm,
+            allow_reduce_scatter=True,
         )

     def forward(
@@ -546,7 +554,12 @@ class Qwen3MoeDecoderLayer(nn.Module):
             hidden_states, residual, forward_batch
         )

-        hidden_states = self.mlp(hidden_states, forward_batch)
+        # For DP with padding, reduce scatter can be used instead of all-reduce.
+        use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
+            forward_batch
+        )
+        hidden_states = self.mlp(hidden_states, forward_batch, use_reduce_scatter)
+
         hidden_states, residual = self.layer_communicator.postprocess_layer(
             hidden_states, residual, forward_batch
...