Unverified Commit a5354b3e authored by Tyler Michael Smith's avatar Tyler Michael Smith Committed by GitHub
Browse files

[Bugfix][WideEP] Apply TP Attn + EP MoE fix to other models (#24982)


Signed-off-by: default avatarTyler Michael Smith <tlrmchlsmth@gmail.com>
parent f9df8b4a
......@@ -17,7 +17,8 @@ from vllm.config import (CacheConfig, ModelConfig, SpeculativeConfig,
VllmConfig, get_current_vllm_config)
from vllm.distributed import (divide, get_ep_group, get_pp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.fla.ops import (
......@@ -47,6 +48,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, sharded_weight_loader)
from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP
from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
......@@ -69,14 +71,13 @@ KVCache = tuple[torch.Tensor, torch.Tensor]
class Qwen3NextSparseMoeBlock(nn.Module):
def __init__(
self,
config: Qwen3NextConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
enable_eplb: bool = False,
):
def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
parallel_config = vllm_config.parallel_config
quant_config = vllm_config.quant_config
self.tp_size = get_tensor_model_parallel_world_size()
self.ep_group = get_ep_group().device_group
......@@ -84,6 +85,8 @@ class Qwen3NextSparseMoeBlock(nn.Module):
self.ep_size = self.ep_group.size()
self.n_routed_experts = config.num_experts
self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
if self.tp_size > config.num_experts:
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than "
......@@ -92,7 +95,7 @@ class Qwen3NextSparseMoeBlock(nn.Module):
# Load balancing settings.
vllm_config = get_current_vllm_config()
eplb_config = vllm_config.parallel_config.eplb_config
self.enable_eplb = enable_eplb
self.enable_eplb = parallel_config.enable_eplb
self.n_logical_experts = self.n_routed_experts
self.n_redundant_experts = eplb_config.num_redundant_experts
......@@ -114,7 +117,8 @@ class Qwen3NextSparseMoeBlock(nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.experts",
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts)
num_redundant_experts=self.n_redundant_experts,
is_sequence_parallel=self.is_sequence_parallel)
self.gate = ReplicatedLinear(config.hidden_size,
config.num_experts,
......@@ -141,9 +145,12 @@ class Qwen3NextSparseMoeBlock(nn.Module):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# NOTE: hidden_states can have either 1D or 2D shape.
orig_shape = hidden_states.shape
hidden_dim = hidden_states.shape[-1]
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
if self.is_sequence_parallel:
hidden_states = sequence_parallel_chunk(hidden_states)
shared_output = None
if self.shared_expert is not None:
shared_output = self.shared_expert(hidden_states)
......@@ -158,7 +165,12 @@ class Qwen3NextSparseMoeBlock(nn.Module):
if shared_output is not None:
final_hidden_states = final_hidden_states + shared_output
if self.tp_size > 1:
if self.is_sequence_parallel:
final_hidden_states = tensor_model_parallel_all_gather(
final_hidden_states, 0)
final_hidden_states = final_hidden_states[:num_tokens]
elif self.tp_size > 1:
final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501
final_hidden_states)
......@@ -719,17 +731,17 @@ class Qwen3NextDecoderLayer(nn.Module):
def __init__(
self,
config: Qwen3NextConfig,
vllm_config: VllmConfig,
layer_type: str,
model_config: Optional[ModelConfig] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
speculative_config: Optional[SpeculativeConfig] = None,
prefix: str = "",
enable_eplb: bool = False,
) -> None:
super().__init__()
self.config = config
config = vllm_config.model_config.hf_config
model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
speculative_config = vllm_config.speculative_config
self.layer_type = layer_type
self.layer_idx = extract_layer_index(prefix)
......@@ -759,10 +771,8 @@ class Qwen3NextDecoderLayer(nn.Module):
config.num_experts > 0 and
(self.layer_idx + 1) % config.decoder_sparse_step == 0):
self.mlp = Qwen3NextSparseMoeBlock(
config=config,
quant_config=quant_config,
vllm_config=vllm_config,
prefix=f"{prefix}.mlp",
enable_eplb=enable_eplb,
)
else:
self.mlp = Qwen3NextMLP(
......@@ -783,14 +793,14 @@ class Qwen3NextDecoderLayer(nn.Module):
torch.zeros(
1,
1,
self.config.hidden_size,
config.hidden_size,
dtype=config.torch_dtype,
), )
self.ffn_layer_scale = torch.nn.Parameter(
torch.zeros(
1,
1,
self.config.hidden_size,
config.hidden_size,
dtype=config.torch_dtype,
), )
......@@ -858,13 +868,8 @@ class Qwen3NextModel(nn.Module):
super().__init__()
config: Qwen3NextConfig = vllm_config.model_config.hf_config
model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
parallel_config = vllm_config.parallel_config
lora_config = vllm_config.lora_config
speculative_config = vllm_config.speculative_config
enable_eplb = parallel_config.enable_eplb
eplb_config = parallel_config.eplb_config
self.num_redundant_experts = eplb_config.num_redundant_experts
......@@ -881,14 +886,9 @@ class Qwen3NextModel(nn.Module):
def get_layer(prefix: str):
return Qwen3NextDecoderLayer(
config,
vllm_config,
layer_type=config.layer_types[extract_layer_index(prefix)],
model_config=model_config,
cache_config=cache_config,
quant_config=quant_config,
speculative_config=speculative_config,
prefix=prefix,
enable_eplb=enable_eplb,
)
self.start_layer, self.end_layer, self.layers = make_layers(
......
......@@ -38,7 +38,6 @@ class Qwen3NextMultiTokenPredictor(nn.Module):
super().__init__()
model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
config: Qwen3NextConfig = model_config.hf_config
......@@ -68,11 +67,8 @@ class Qwen3NextMultiTokenPredictor(nn.Module):
self.layers = torch.nn.ModuleList(
Qwen3NextDecoderLayer(
config,
vllm_config,
layer_type="full_attention",
model_config=model_config,
cache_config=cache_config,
quant_config=quant_config,
prefix=f'{prefix}.layers.{idx}',
) for idx in range(self.num_mtp_layers))
......
......@@ -13,11 +13,14 @@ from transformers import PretrainedConfig
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.logger import init_logger
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal import NestedTensors
from vllm.sequence import IntermediateTensors
from vllm.utils import (get_cuda_view_from_cpu_tensor, is_pin_memory_available,
from vllm.utils import (cdiv, direct_register_custom_op,
get_cuda_view_from_cpu_tensor, is_pin_memory_available,
is_uva_available)
logger = init_logger(__name__)
......@@ -743,3 +746,46 @@ def get_model_hidden_size(hf_config: PretrainedConfig) -> int:
return hf_config.hidden_size
text_config = hf_config.get_text_config()
return text_config.hidden_size
# Chunk x along the num_tokens axis for sequence parallelism
# NOTE: This is wrapped in a torch custom op to work around the following issue:
# The output tensor can have a sequence length 0 at small input sequence lengths
# even though we explicitly pad to avoid this.
def sequence_parallel_chunk(x: torch.Tensor) -> torch.Tensor:
return torch.ops.vllm.sequence_parallel_chunk_impl(x)
def sequence_parallel_chunk_impl(x: torch.Tensor) -> torch.Tensor:
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
# all_gather needs the sequence length to be divisible by tp_size
seq_len = x.size(0)
remainder = seq_len % tp_size
if remainder != 0:
pad_len = tp_size - remainder
y = nn.functional.pad(x, (0, 0, 0, pad_len))
else:
y = x
chunk = y.shape[0] // tp_size
start = tp_rank * chunk
return torch.narrow(y, 0, start, chunk)
def sequence_parallel_chunk_impl_fake(x: torch.Tensor) -> torch.Tensor:
tp_size = get_tensor_model_parallel_world_size()
seq_len = cdiv(x.size(0), tp_size)
shape = list(x.shape)
shape[0] = seq_len
out = torch.empty(shape, dtype=x.dtype, device=x.device)
return out
direct_register_custom_op(
op_name="sequence_parallel_chunk_impl",
op_func=sequence_parallel_chunk_impl,
fake_impl=sequence_parallel_chunk_impl_fake,
tags=(torch.Tag.needs_fixed_stride_order, ),
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment