Commit 26a7a33b authored by Tyler Michael Smith, committed by simon-mo
Browse files

[Bugfix][WideEP] Apply TP Attn + EP MoE fix to other models (#24982)


Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Signed-off-by: simon-mo <simon.mo@hey.com>
parent 5aa5811a
...@@ -17,7 +17,8 @@ from vllm.config import (CacheConfig, ModelConfig, SpeculativeConfig, ...@@ -17,7 +17,8 @@ from vllm.config import (CacheConfig, ModelConfig, SpeculativeConfig,
VllmConfig, get_current_vllm_config) VllmConfig, get_current_vllm_config)
from vllm.distributed import (divide, get_ep_group, get_pp_group, from vllm.distributed import (divide, get_ep_group, get_pp_group,
get_tensor_model_parallel_rank, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather)
from vllm.forward_context import ForwardContext, get_forward_context from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fla.ops import ( from vllm.model_executor.layers.fla.ops import (
...@@ -47,6 +48,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -47,6 +48,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, sharded_weight_loader) default_weight_loader, sharded_weight_loader)
from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP
from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
...@@ -69,14 +71,13 @@ KVCache = tuple[torch.Tensor, torch.Tensor] ...@@ -69,14 +71,13 @@ KVCache = tuple[torch.Tensor, torch.Tensor]
class Qwen3NextSparseMoeBlock(nn.Module): class Qwen3NextSparseMoeBlock(nn.Module):
def __init__( def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
self,
config: Qwen3NextConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
enable_eplb: bool = False,
):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
parallel_config = vllm_config.parallel_config
quant_config = vllm_config.quant_config
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
self.ep_group = get_ep_group().device_group self.ep_group = get_ep_group().device_group
...@@ -84,6 +85,8 @@ class Qwen3NextSparseMoeBlock(nn.Module): ...@@ -84,6 +85,8 @@ class Qwen3NextSparseMoeBlock(nn.Module):
self.ep_size = self.ep_group.size() self.ep_size = self.ep_group.size()
self.n_routed_experts = config.num_experts self.n_routed_experts = config.num_experts
self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
if self.tp_size > config.num_experts: if self.tp_size > config.num_experts:
raise ValueError( raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than " f"Tensor parallel size {self.tp_size} is greater than "
...@@ -92,7 +95,7 @@ class Qwen3NextSparseMoeBlock(nn.Module): ...@@ -92,7 +95,7 @@ class Qwen3NextSparseMoeBlock(nn.Module):
# Load balancing settings. # Load balancing settings.
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
eplb_config = vllm_config.parallel_config.eplb_config eplb_config = vllm_config.parallel_config.eplb_config
self.enable_eplb = enable_eplb self.enable_eplb = parallel_config.enable_eplb
self.n_logical_experts = self.n_routed_experts self.n_logical_experts = self.n_routed_experts
self.n_redundant_experts = eplb_config.num_redundant_experts self.n_redundant_experts = eplb_config.num_redundant_experts
...@@ -114,7 +117,8 @@ class Qwen3NextSparseMoeBlock(nn.Module): ...@@ -114,7 +117,8 @@ class Qwen3NextSparseMoeBlock(nn.Module):
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.experts", prefix=f"{prefix}.experts",
enable_eplb=self.enable_eplb, enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts) num_redundant_experts=self.n_redundant_experts,
is_sequence_parallel=self.is_sequence_parallel)
self.gate = ReplicatedLinear(config.hidden_size, self.gate = ReplicatedLinear(config.hidden_size,
config.num_experts, config.num_experts,
...@@ -141,9 +145,12 @@ class Qwen3NextSparseMoeBlock(nn.Module): ...@@ -141,9 +145,12 @@ class Qwen3NextSparseMoeBlock(nn.Module):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# NOTE: hidden_states can have either 1D or 2D shape. # NOTE: hidden_states can have either 1D or 2D shape.
orig_shape = hidden_states.shape orig_shape = hidden_states.shape
hidden_dim = hidden_states.shape[-1] num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim) hidden_states = hidden_states.view(-1, hidden_dim)
if self.is_sequence_parallel:
hidden_states = sequence_parallel_chunk(hidden_states)
shared_output = None shared_output = None
if self.shared_expert is not None: if self.shared_expert is not None:
shared_output = self.shared_expert(hidden_states) shared_output = self.shared_expert(hidden_states)
...@@ -158,7 +165,12 @@ class Qwen3NextSparseMoeBlock(nn.Module): ...@@ -158,7 +165,12 @@ class Qwen3NextSparseMoeBlock(nn.Module):
if shared_output is not None: if shared_output is not None:
final_hidden_states = final_hidden_states + shared_output final_hidden_states = final_hidden_states + shared_output
if self.tp_size > 1:
if self.is_sequence_parallel:
final_hidden_states = tensor_model_parallel_all_gather(
final_hidden_states, 0)
final_hidden_states = final_hidden_states[:num_tokens]
elif self.tp_size > 1:
final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501 final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501
final_hidden_states) final_hidden_states)
...@@ -719,17 +731,17 @@ class Qwen3NextDecoderLayer(nn.Module): ...@@ -719,17 +731,17 @@ class Qwen3NextDecoderLayer(nn.Module):
def __init__( def __init__(
self, self,
config: Qwen3NextConfig, vllm_config: VllmConfig,
layer_type: str, layer_type: str,
model_config: Optional[ModelConfig] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
speculative_config: Optional[SpeculativeConfig] = None,
prefix: str = "", prefix: str = "",
enable_eplb: bool = False,
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config
config = vllm_config.model_config.hf_config
model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
speculative_config = vllm_config.speculative_config
self.layer_type = layer_type self.layer_type = layer_type
self.layer_idx = extract_layer_index(prefix) self.layer_idx = extract_layer_index(prefix)
...@@ -759,10 +771,8 @@ class Qwen3NextDecoderLayer(nn.Module): ...@@ -759,10 +771,8 @@ class Qwen3NextDecoderLayer(nn.Module):
config.num_experts > 0 and config.num_experts > 0 and
(self.layer_idx + 1) % config.decoder_sparse_step == 0): (self.layer_idx + 1) % config.decoder_sparse_step == 0):
self.mlp = Qwen3NextSparseMoeBlock( self.mlp = Qwen3NextSparseMoeBlock(
config=config, vllm_config=vllm_config,
quant_config=quant_config,
prefix=f"{prefix}.mlp", prefix=f"{prefix}.mlp",
enable_eplb=enable_eplb,
) )
else: else:
self.mlp = Qwen3NextMLP( self.mlp = Qwen3NextMLP(
...@@ -783,14 +793,14 @@ class Qwen3NextDecoderLayer(nn.Module): ...@@ -783,14 +793,14 @@ class Qwen3NextDecoderLayer(nn.Module):
torch.zeros( torch.zeros(
1, 1,
1, 1,
self.config.hidden_size, config.hidden_size,
dtype=config.torch_dtype, dtype=config.torch_dtype,
), ) ), )
self.ffn_layer_scale = torch.nn.Parameter( self.ffn_layer_scale = torch.nn.Parameter(
torch.zeros( torch.zeros(
1, 1,
1, 1,
self.config.hidden_size, config.hidden_size,
dtype=config.torch_dtype, dtype=config.torch_dtype,
), ) ), )
...@@ -858,13 +868,8 @@ class Qwen3NextModel(nn.Module): ...@@ -858,13 +868,8 @@ class Qwen3NextModel(nn.Module):
super().__init__() super().__init__()
config: Qwen3NextConfig = vllm_config.model_config.hf_config config: Qwen3NextConfig = vllm_config.model_config.hf_config
model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
parallel_config = vllm_config.parallel_config parallel_config = vllm_config.parallel_config
lora_config = vllm_config.lora_config lora_config = vllm_config.lora_config
speculative_config = vllm_config.speculative_config
enable_eplb = parallel_config.enable_eplb
eplb_config = parallel_config.eplb_config eplb_config = parallel_config.eplb_config
self.num_redundant_experts = eplb_config.num_redundant_experts self.num_redundant_experts = eplb_config.num_redundant_experts
...@@ -881,14 +886,9 @@ class Qwen3NextModel(nn.Module): ...@@ -881,14 +886,9 @@ class Qwen3NextModel(nn.Module):
def get_layer(prefix: str): def get_layer(prefix: str):
return Qwen3NextDecoderLayer( return Qwen3NextDecoderLayer(
config, vllm_config,
layer_type=config.layer_types[extract_layer_index(prefix)], layer_type=config.layer_types[extract_layer_index(prefix)],
model_config=model_config,
cache_config=cache_config,
quant_config=quant_config,
speculative_config=speculative_config,
prefix=prefix, prefix=prefix,
enable_eplb=enable_eplb,
) )
self.start_layer, self.end_layer, self.layers = make_layers( self.start_layer, self.end_layer, self.layers = make_layers(
......
...@@ -38,7 +38,6 @@ class Qwen3NextMultiTokenPredictor(nn.Module): ...@@ -38,7 +38,6 @@ class Qwen3NextMultiTokenPredictor(nn.Module):
super().__init__() super().__init__()
model_config = vllm_config.model_config model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config lora_config = vllm_config.lora_config
config: Qwen3NextConfig = model_config.hf_config config: Qwen3NextConfig = model_config.hf_config
...@@ -68,11 +67,8 @@ class Qwen3NextMultiTokenPredictor(nn.Module): ...@@ -68,11 +67,8 @@ class Qwen3NextMultiTokenPredictor(nn.Module):
self.layers = torch.nn.ModuleList( self.layers = torch.nn.ModuleList(
Qwen3NextDecoderLayer( Qwen3NextDecoderLayer(
config, vllm_config,
layer_type="full_attention", layer_type="full_attention",
model_config=model_config,
cache_config=cache_config,
quant_config=quant_config,
prefix=f'{prefix}.layers.{idx}', prefix=f'{prefix}.layers.{idx}',
) for idx in range(self.num_mtp_layers)) ) for idx in range(self.num_mtp_layers))
......
...@@ -13,11 +13,14 @@ from transformers import PretrainedConfig ...@@ -13,11 +13,14 @@ from transformers import PretrainedConfig
import vllm.envs as envs import vllm.envs as envs
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal import NestedTensors from vllm.multimodal import NestedTensors
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import (get_cuda_view_from_cpu_tensor, is_pin_memory_available, from vllm.utils import (cdiv, direct_register_custom_op,
get_cuda_view_from_cpu_tensor, is_pin_memory_available,
is_uva_available) is_uva_available)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -760,3 +763,46 @@ def get_model_hidden_size(hf_config: PretrainedConfig) -> int: ...@@ -760,3 +763,46 @@ def get_model_hidden_size(hf_config: PretrainedConfig) -> int:
return hf_config.hidden_size return hf_config.hidden_size
text_config = hf_config.get_text_config() text_config = hf_config.get_text_config()
return text_config.hidden_size return text_config.hidden_size
# Split x along the num_tokens axis for sequence parallelism.
# NOTE: The chunking goes through a torch custom op as a workaround: otherwise
# the output tensor can end up with a sequence length of 0 at small input
# sequence lengths, even though we explicitly pad to avoid this.
def sequence_parallel_chunk(x: torch.Tensor) -> torch.Tensor:
    """Dispatch ``x`` to the ``sequence_parallel_chunk_impl`` custom op.

    Args:
        x: Tensor to be chunked along dimension 0.

    Returns:
        Whatever ``torch.ops.vllm.sequence_parallel_chunk_impl`` returns
        for ``x``.
    """
    chunk_op = torch.ops.vllm.sequence_parallel_chunk_impl
    return chunk_op(x)
def sequence_parallel_chunk_impl(x: torch.Tensor) -> torch.Tensor:
    """Return the current tensor-parallel rank's slice of ``x`` along dim 0.

    When ``x.size(0)`` is not a multiple of the tensor-parallel world size,
    ``x`` is first padded so that the length divides evenly; the padded
    tensor is then cut into ``tp_size`` equal pieces and the piece indexed
    by this rank is returned.

    NOTE(review): the pad spec ``(0, 0, 0, n)`` targets the second-to-last
    dimension, which is dim 0 only for 2D input — presumably ``x`` is
    ``(num_tokens, hidden)``; confirm against callers.

    Args:
        x: Tensor to chunk along dimension 0.

    Returns:
        A ``torch.narrow`` slice of the (possibly padded) tensor.
    """
    world_size = get_tensor_model_parallel_world_size()
    rank = get_tensor_model_parallel_rank()

    # all_gather needs the sequence length to be divisible by tp_size
    leftover = x.size(0) % world_size
    padded = x
    if leftover != 0:
        padded = nn.functional.pad(x, (0, 0, 0, world_size - leftover))

    rows_per_rank = padded.shape[0] // world_size
    return torch.narrow(padded, 0, rank * rows_per_rank, rows_per_rank)
def sequence_parallel_chunk_impl_fake(x: torch.Tensor) -> torch.Tensor:
    """Shape-only counterpart of ``sequence_parallel_chunk_impl``.

    Allocates an uninitialized tensor with the same dtype, device and
    trailing dimensions as ``x``, whose leading dimension is
    ``cdiv(x.size(0), tp_size)``.

    Args:
        x: Tensor whose chunked shape should be mirrored.

    Returns:
        A ``torch.empty`` tensor with the chunked shape.
    """
    world_size = get_tensor_model_parallel_world_size()
    out_shape = [cdiv(x.size(0), world_size), *x.shape[1:]]
    return torch.empty(out_shape, dtype=x.dtype, device=x.device)
# Register the chunking implementation and its fake (shape-only) counterpart
# as the custom op "sequence_parallel_chunk_impl"; sequence_parallel_chunk()
# calls it through torch.ops.vllm.sequence_parallel_chunk_impl.
direct_register_custom_op(
op_name="sequence_parallel_chunk_impl",
op_func=sequence_parallel_chunk_impl,
fake_impl=sequence_parallel_chunk_impl_fake,
tags=(torch.Tag.needs_fixed_stride_order, ),
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment