"vscode:/vscode.git/clone" did not exist on "44902202583f8a13cd0c6bf58c9bdc526d5a1ca2"
Unverified commit 11553c1a, authored by libra, committed by GitHub

Add pipeline parallelism for Qwen2 and Qwen3 Model (#6250)

parent 01dd39ba
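
Across the four model files changed below, the pipeline-parallel pattern is the same: the first rank owns the embedding (other ranks get a PPMissingLayer), each rank runs only the decoder layers between its start_layer and end_layer, intermediate ranks hand hidden_states and residual to the next rank as PPProxyTensors, and only the last rank applies the final norm and the LM head or pooler. The following is a minimal self-contained sketch of that control flow, not SGLang code: the plain-dict ProxyTensors stand-in, the simplified layer call signature, and pp_stage_forward itself are illustrative assumptions.

from typing import Optional

from torch import nn


class ProxyTensors(dict):
    """Stand-in for PPProxyTensors: a bag of tensors handed to the next pipeline rank."""


def pp_stage_forward(
    layers: nn.ModuleList,
    start_layer: int,
    end_layer: int,
    is_first_rank: bool,
    is_last_rank: bool,
    embed=None,  # embedding module, present only on the first rank
    norm=None,  # fused add+RMSNorm, present only on the last rank
    input_ids=None,
    proxy: Optional[ProxyTensors] = None,
):
    if is_first_rank:
        hidden_states = embed(input_ids)
        residual = None
    else:
        # Non-first ranks resume from the tensors produced by the previous stage.
        assert proxy is not None
        hidden_states, residual = proxy["hidden_states"], proxy["residual"]

    # Each rank executes only the decoder layers it owns.
    for i in range(start_layer, end_layer):
        hidden_states, residual = layers[i](hidden_states, residual)

    if not is_last_rank:
        # Intermediate ranks return proxy tensors instead of logits;
        # the runner ships them to the next rank.
        return ProxyTensors(hidden_states=hidden_states, residual=residual)

    # Last rank: the fused norm returns (output, residual), as in the diff below.
    hidden_states, _ = norm(hidden_states, residual)
    return hidden_states
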
sglang/srt/models/qwen2.py

@@ -15,12 +15,14 @@
# Adapted from llama2.py
# Modify details for the adaptation of Qwen2 model.
"""Inference-only Qwen2 model compatible with HuggingFace weights."""
import logging
from typing import Any, Dict, Iterable, Optional, Tuple, Union

import torch
from torch import nn

from sglang.srt.distributed import (
    get_pp_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)

@@ -36,11 +38,12 @@
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import (
    default_weight_loader,
    kv_cache_scales_loader,

@@ -50,6 +53,9 @@
Qwen2Config = None

logger = logging.getLogger(__name__)


class Qwen2MLP(nn.Module):
    def __init__(
        self,

@@ -245,15 +251,21 @@
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.pp_group = get_pp_group()

        if self.pp_group.is_first_rank:
            self.embed_tokens = VocabParallelEmbedding(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=add_prefix("embed_tokens", prefix),
            )
        else:
            self.embed_tokens = PPMissingLayer()

        # Use the provided decoder layer type or default to Qwen2DecoderLayer
        decoder_layer_type = decoder_layer_type or Qwen2DecoderLayer
        self.layers, self.start_layer, self.end_layer = make_layers(
            config.num_hidden_layers,
            lambda idx, prefix: decoder_layer_type(
                layer_id=idx,

@@ -261,9 +273,14 @@
                quant_config=quant_config,
                prefix=prefix,
            ),
            pp_rank=self.pp_group.rank_in_group,
            pp_size=self.pp_group.world_size,
            prefix=add_prefix("layers", prefix),
        )
        if self.pp_group.is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer(return_tuple=True)

    def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
        if hasattr(self.config, "scale_emb"):

@@ -280,13 +297,20 @@
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ) -> Union[torch.Tensor, PPProxyTensors]:
        if self.pp_group.is_first_rank:
            if input_embeds is None:
                hidden_states = self.embed_tokens(input_ids)
            else:
                hidden_states = input_embeds
            residual = None
        else:
            assert pp_proxy_tensors is not None
            hidden_states = pp_proxy_tensors["hidden_states"]
            residual = pp_proxy_tensors["residual"]

        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,

@@ -294,6 +318,14 @@
                forward_batch,
                residual,
            )
        if not self.pp_group.is_last_rank:
            return PPProxyTensors(
                {
                    "hidden_states": hidden_states,
                    "residual": residual,
                }
            )
        else:
            hidden_states, _ = self.norm(hidden_states, residual)
            return hidden_states

@@ -348,6 +380,7 @@
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.pp_group = get_pp_group()
        self.config = config
        self.quant_config = quant_config
        self.model = Qwen2Model(

@@ -379,14 +412,33 @@
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        get_embedding: bool = False,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ) -> torch.Tensor:
        hidden_states = self.model(
            input_ids,
            positions,
            forward_batch,
            input_embeds,
            pp_proxy_tensors=pp_proxy_tensors,
        )

        if self.pp_group.is_last_rank:
            if not get_embedding:
                return self.logits_processor(
                    input_ids, hidden_states, self.lm_head, forward_batch
                )
            else:
                return self.pooler(hidden_states, forward_batch)
        else:
            return hidden_states

    @property
    def start_layer(self):
        return self.model.start_layer

    @property
    def end_layer(self):
        return self.model.end_layer

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [

@@ -400,6 +452,17 @@
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            layer_id = get_layer_id(name)
            if (
                layer_id is not None
                and hasattr(self.model, "start_layer")
                and (
                    layer_id < self.model.start_layer
                    or layer_id >= self.model.end_layer
                )
            ):
                continue

            if "rotary_emb.inv_freq" in name or "projector" in name:
                continue
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:

@@ -426,9 +489,15 @@
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if name in params_dict.keys():
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
                else:
                    logger.warning(f"Parameter {name} not found in params_dict")

    def get_embed_and_head(self):
        return self.model.embed_tokens.weight, self.lm_head.weight
...
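
The load_weights change above skips checkpoint tensors whose decoder-layer index falls outside [start_layer, end_layer), so each pipeline rank only materializes the weights it owns. A rough stand-alone illustration of that filter follows; layer_id_from_name and is_owned_by_this_rank are hypothetical helpers written for this sketch, not the actual get_layer_id from sglang.srt.layers.utils.

import re
from typing import Optional


def layer_id_from_name(name: str) -> Optional[int]:
    # Hypothetical helper: pull the decoder-layer index out of a weight name
    # such as "model.layers.17.mlp.gate_proj.weight"; return None for weights
    # that are not tied to a specific layer (embeddings, lm_head, final norm).
    match = re.search(r"\blayers\.(\d+)\.", name)
    return int(match.group(1)) if match else None


def is_owned_by_this_rank(name: str, start_layer: int, end_layer: int) -> bool:
    layer_id = layer_id_from_name(name)
    # Weights without a layer index are not filtered here; whether this rank
    # holds them is decided later by the params_dict membership check.
    return layer_id is None or start_layer <= layer_id < end_layer


assert is_owned_by_this_rank("model.layers.17.mlp.gate_proj.weight", 0, 16) is False
assert is_owned_by_this_rank("lm_head.weight", 0, 16) is True
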
sglang/srt/models/qwen2_moe.py

@@ -16,9 +16,10 @@
# https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/qwen2_moe.py
"""Inference-only Qwen2MoE model compatible with HuggingFace weights."""
import logging
from dataclasses import dataclass
from enum import Enum, auto
from typing import Any, Dict, Iterable, Optional, Tuple, Union

import torch
import torch.nn.functional as F

@@ -26,6 +27,7 @@
from torch import nn
from transformers import PretrainedConfig

from sglang.srt.distributed import (
    get_pp_group,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_reduce,
)

@@ -52,18 +54,21 @@
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix, make_layers

expert_distribution_recorder = ExpertDistributionRecorder()

logger = logging.getLogger(__name__)


class Qwen2MoeMLP(nn.Module):
    def __init__(

@@ -535,16 +540,21 @@
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.pp_group = get_pp_group()

        if self.pp_group.is_first_rank:
            self.embed_tokens = VocabParallelEmbedding(
                config.vocab_size,
                config.hidden_size,
                enable_tp=not global_server_args_dict["enable_dp_attention"],
                prefix=add_prefix("embed_tokens", prefix),
            )
        else:
            self.embed_tokens = PPMissingLayer()

        # Use the provided decoder layer type or default to Qwen2MoeDecoderLayer
        decoder_layer_type = decoder_layer_type or Qwen2MoeDecoderLayer
        self.layers, self.start_layer, self.end_layer = make_layers(
            config.num_hidden_layers,
            lambda idx, prefix: decoder_layer_type(
                layer_id=idx,

@@ -552,9 +562,14 @@
                quant_config=quant_config,
                prefix=prefix,
            ),
            pp_rank=self.pp_group.rank_in_group,
            pp_size=self.pp_group.world_size,
            prefix=add_prefix("layers", prefix),
        )
        if self.pp_group.is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer(return_tuple=True)

    def forward(
        self,

@@ -562,18 +577,33 @@
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ) -> Union[torch.Tensor, PPProxyTensors]:
        if self.pp_group.is_first_rank:
            if input_embeds is None:
                hidden_states = self.embed_tokens(input_ids)
            else:
                hidden_states = input_embeds
            residual = None
        else:
            assert pp_proxy_tensors is not None
            hidden_states = pp_proxy_tensors["hidden_states"]
            residual = pp_proxy_tensors["residual"]

        for i in range(self.start_layer, self.end_layer):
            expert_distribution_recorder.set_current_layer(i)
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions, hidden_states, forward_batch, residual
            )
        if not self.pp_group.is_last_rank:
            return PPProxyTensors(
                {
                    "hidden_states": hidden_states,
                    "residual": residual,
                }
            )
        else:
            if hidden_states.shape[0] != 0:
                hidden_states, _ = self.norm(hidden_states, residual)
            return hidden_states

@@ -589,6 +619,7 @@
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.pp_group = get_pp_group()
        self.config = config
        self.quant_config = quant_config
        self.model = Qwen2MoeModel(

@@ -609,11 +640,29 @@
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ) -> torch.Tensor:
        hidden_states = self.model(
            input_ids,
            positions,
            forward_batch,
            input_embeds,
            pp_proxy_tensors=pp_proxy_tensors,
        )

        if self.pp_group.is_last_rank:
            return self.logits_processor(
                input_ids, hidden_states, self.lm_head, forward_batch
            )
        else:
            return hidden_states

    @property
    def start_layer(self):
        return self.model.start_layer

    @property
    def end_layer(self):
        return self.model.end_layer

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [

@@ -636,6 +685,16 @@
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            layer_id = get_layer_id(name)
            if (
                layer_id is not None
                and hasattr(self.model, "start_layer")
                and (
                    layer_id < self.model.start_layer
                    or layer_id >= self.model.end_layer
                )
            ):
                continue
            if "rotary_emb.inv_freq" in name:
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:

@@ -684,11 +743,14 @@
                if name not in params_dict:
                    continue

                if name in params_dict.keys():
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
                else:
                    logger.warning(f"Parameter {name} not found in params_dict")


EntryClass = Qwen2MoeForCausalLM
sglang/srt/models/qwen3.py

# Adapted from qwen2.py
import logging
from functools import partial
from typing import Any, Dict, Iterable, Optional, Tuple

@@ -7,6 +8,7 @@
import torch
from torch import nn

from sglang.srt.distributed import (
    get_pp_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    split_tensor_along_last_dim,

@@ -19,8 +21,9 @@
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.utils import get_layer_id
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP
from sglang.srt.models.qwen2 import Qwen2Model

@@ -28,6 +31,8 @@
Qwen3Config = None

logger = logging.getLogger(__name__)


class Qwen3Attention(nn.Module):
    def __init__(

@@ -238,6 +243,7 @@
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.pp_group = get_pp_group()
        self.config = config
        self.quant_config = quant_config
        self.model = Qwen3Model(

@@ -266,14 +272,33 @@
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        get_embedding: bool = False,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ) -> torch.Tensor:
        hidden_states = self.model(
            input_ids,
            positions,
            forward_batch,
            input_embeds,
            pp_proxy_tensors=pp_proxy_tensors,
        )

        if self.pp_group.is_last_rank:
            if not get_embedding:
                return self.logits_processor(
                    input_ids, hidden_states, self.lm_head, forward_batch
                )
            else:
                return self.pooler(hidden_states, forward_batch)
        else:
            return hidden_states

    @property
    def start_layer(self):
        return self.model.start_layer

    @property
    def end_layer(self):
        return self.model.end_layer

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [

@@ -287,6 +312,17 @@
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            layer_id = get_layer_id(name)
            if (
                layer_id is not None
                and hasattr(self.model, "start_layer")
                and (
                    layer_id < self.model.start_layer
                    or layer_id >= self.model.end_layer
                )
            ):
                continue

            if "rotary_emb.inv_freq" in name or "projector" in name:
                continue
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:

@@ -313,9 +349,15 @@
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if name in params_dict.keys():
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
                else:
                    logger.warning(f"Parameter {name} not found in params_dict")

    def get_embed_and_head(self):
        return self.model.embed_tokens.weight, self.lm_head.weight
...
sglang/srt/models/qwen3_moe.py

@@ -17,6 +17,7 @@
"""Inference-only Qwen3MoE model compatible with HuggingFace weights."""
import logging
from dataclasses import dataclass
from enum import Enum, auto
from functools import partial

@@ -28,6 +29,7 @@
from transformers.configuration_utils import PretrainedConfig

from sglang.srt.distributed import (
    get_pp_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    split_tensor_along_last_dim,

@@ -57,12 +59,13 @@
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.utils import get_layer_id
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
from sglang.srt.models.qwen2_moe import Qwen2MoeModel

@@ -70,6 +73,8 @@
Qwen3MoeConfig = None

logger = logging.getLogger(__name__)


class Qwen3MoeSparseMoeBlock(nn.Module):
    def __init__(

@@ -516,6 +521,7 @@
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.pp_group = get_pp_group()
        self.config = config
        self.quant_config = quant_config
        self.model = Qwen3MoeModel(

@@ -536,11 +542,30 @@
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ) -> torch.Tensor:
        hidden_states = self.model(
            input_ids,
            positions,
            forward_batch,
            input_embeds,
            pp_proxy_tensors=pp_proxy_tensors,
        )

        if self.pp_group.is_last_rank:
            return self.logits_processor(
                input_ids, hidden_states, self.lm_head, forward_batch
            )
        else:
            return hidden_states

    @property
    def start_layer(self):
        return self.model.start_layer

    @property
    def end_layer(self):
        return self.model.end_layer

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [

@@ -563,6 +588,17 @@
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            layer_id = get_layer_id(name)
            if (
                layer_id is not None
                and hasattr(self.model, "start_layer")
                and (
                    layer_id < self.model.start_layer
                    or layer_id >= self.model.end_layer
                )
            ):
                continue
            if "rotary_emb.inv_freq" in name:
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:

@@ -611,11 +647,14 @@
                if name not in params_dict:
                    continue

                if name in params_dict.keys():
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
                else:
                    logger.warning(f"Parameter {name} not found in params_dict")


EntryClass = Qwen3MoeForCausalLM
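
The new test below exercises the feature end to end by launching a server with --pp-size 2 and comparing GSM8K accuracy against a single-stage baseline. A condensed launch-only sketch follows, built from the same helpers the test uses (popen_launch_server, kill_process_tree, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH); the import paths shown for those helpers are assumptions, and the model name and port are placeholders.

# Hedged sketch: start a 2-stage pipeline-parallel Qwen3 server, then tear it down.
from sglang.test.test_utils import (  # assumed import path
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    popen_launch_server,
)
from sglang.utils import kill_process_tree  # assumed import path

process = popen_launch_server(
    "Qwen/Qwen3-8B",
    "http://127.0.0.1:23334",
    timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    other_args=["--pp-size", 2, "--chunked-prefill-size", 256],
)
try:
    ...  # send requests or run an eval against the server here
finally:
    kill_process_tree(process.pid)
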
""" """
Usage: Usage:
python3 -m unittest test_pp_single_node.TestPPAccuracy.test_gsm8k python3 -m unittest test_pp_single_node.TestPPAccuracy.test_gsm8k
python3 -m unittest test_pp_single_node.TestQwenPPAccuracy.test_pp_consistency
python3 -m unittest test_pp_single_node.TestFixedBugs.test_chunked_prefill_with_small_bs python3 -m unittest test_pp_single_node.TestFixedBugs.test_chunked_prefill_with_small_bs
""" """
...@@ -61,6 +62,60 @@ class TestPPAccuracy(unittest.TestCase): ...@@ -61,6 +62,60 @@ class TestPPAccuracy(unittest.TestCase):
time.sleep(5) time.sleep(5)
class TestQwenPPAccuracy(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.base_url = "http://127.0.0.1:23334" # different ports to avoid conflicts
cls.model_name = "Qwen/Qwen3-8B" # replace with your Qwen Model if needed
def run_gsm8k_test(self, pp_size):
process = popen_launch_server(
self.model_name,
self.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--pp-size",
pp_size,
"--chunked-prefill-size",
256,
],
)
try:
args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=200,
max_new_tokens=512,
parallel=128,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)
metrics = run_eval(args)
time.sleep(5)
return metrics
finally:
kill_process_tree(process.pid)
def test_baseline_accuracy(self):
metrics = self.run_gsm8k_test(pp_size=1)
print(f"[Qwen Baseline] {metrics=}")
self.assertGreater(metrics["accuracy"], 0.74)
def test_pp_consistency(self):
baseline = self.run_gsm8k_test(pp_size=1)
pp_metrics = self.run_gsm8k_test(pp_size=2)
print(f"[Qwen PP Comparison] Baseline: {baseline} | PP: {pp_metrics}")
self.assertAlmostEqual(
pp_metrics["accuracy"],
baseline["accuracy"],
delta=0.01,
msg=f"PP accuracy exceeds 1% (baseline: {baseline['accuracy']}, pp: {pp_metrics['accuracy']})",
)
class TestFixedBugs(unittest.TestCase): class TestFixedBugs(unittest.TestCase):
def test_chunked_prefill_with_small_bs(self): def test_chunked_prefill_with_small_bs(self):
model = DEFAULT_MODEL_NAME_FOR_TEST model = DEFAULT_MODEL_NAME_FOR_TEST
......