Unverified Commit 56a724eb authored by Qubitium-ModelCloud, committed by GitHub

[QUANT] Add GPTQModel Dynamic Quantization + `lm_head` Quantization (#3790)


Signed-off-by: ZX-ModelCloud <zx@modelcloud.ai>
Co-authored-by: ZX-ModelCloud <zx@modelcloud.ai>
parent 583d6af7
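
Reader note: every model file touched below follows one pattern. Each module constructor gains a `prefix: str = ""` parameter, and child-module prefixes are built with the `add_prefix` helper from `sglang.srt.utils` instead of ad-hoc f-strings, so fully qualified weight names (including `lm_head`) can be matched against per-module GPTQModel dynamic-quantization rules. The sketch below is illustrative only: the real helper lives in `sglang.srt.utils` and is not shown in this diff, and the implementation here merely assumes the conventional dot-joining behavior implied by the call sites.

# Sketch only, assuming dot-joining semantics inferred from the call sites in this diff.
def add_prefix(name: str, prefix: str) -> str:
    """Join a child module name onto an optional parent prefix with a dot."""
    return name if not prefix else f"{prefix}.{name}"

# Example of how prefixes compose into fully qualified weight names that a
# per-module quantization override (or lm_head quantization) can match against.
model_prefix = add_prefix("model", "")                # "model"
layer_prefix = add_prefix("layers.0", model_prefix)   # "model.layers.0"
attn_prefix = add_prefix("self_attn", layer_prefix)   # "model.layers.0.self_attn"
print(add_prefix("qkv_proj", attn_prefix))            # "model.layers.0.self_attn.qkv_proj"
print(add_prefix("lm_head", ""))                      # "lm_head"
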
@@ -24,7 +24,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import make_layers
+from sglang.srt.utils import add_prefix, make_layers
 @torch.jit.script
@@ -70,13 +70,14 @@ class Phi3SmallMLP(nn.Module):
             2 * [self.intermediate_size],
             bias=True,
             quant_config=quant_config,
-            prefix=f"{prefix}.up_proj",
+            prefix=add_prefix("up_proj", prefix),
         )
         self.down_proj = RowParallelLinear(
             self.intermediate_size,
             self.hidden_size,
             bias=True,
             quant_config=quant_config,
+            prefix=add_prefix("down_proj", prefix),
         )
     def forward(self, x):
@@ -140,7 +141,7 @@ class Phi3SmallSelfAttention(nn.Module):
             self.num_key_value_heads,
             bias=True,
             quant_config=quant_config,
-            prefix=f"{prefix}.qkv_proj",
+            prefix=add_prefix("qkv_proj", prefix),
         )
         self.dense = RowParallelLinear(
@@ -148,7 +149,7 @@ class Phi3SmallSelfAttention(nn.Module):
             self.hidden_size,
             bias=True,
             quant_config=quant_config,
-            prefix=f"{prefix}.o_proj",
+            prefix=add_prefix("o_proj", prefix),
         )
         if getattr(self.config, "rope_scaling", None) is not None:
@@ -201,6 +202,7 @@ class Phi3SmallSelfAttention(nn.Module):
             self.scale,
             num_kv_heads=self.num_kv_heads_per_partion,
             layer_id=layer_id,
+            prefix=add_prefix("attn", prefix),
         )
     def forward(
@@ -234,13 +236,21 @@ class Phi3SmallDecoderLayer(nn.Module):
         config: PretrainedConfig,
         layer_id: int,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.hidden_size = config.hidden_size
         self.self_attn = Phi3SmallSelfAttention(
-            config, layer_id, quant_config=quant_config
+            config,
+            layer_id,
+            quant_config=quant_config,
+            prefix=add_prefix("self_attn", prefix),
+        )
+        self.mlp = Phi3SmallMLP(
+            config,
+            quant_config,
+            prefix=add_prefix("mlp", prefix),
         )
-        self.mlp = Phi3SmallMLP(config, quant_config)
         self.input_layernorm = nn.LayerNorm(
             config.hidden_size, eps=config.layer_norm_epsilon
@@ -284,15 +294,20 @@ class Phi3SmallModel(nn.Module):
         self.config = config
         self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size, config.hidden_size
+            config.vocab_size,
+            config.hidden_size,
+            prefix=add_prefix("embed_tokens", prefix),
         )
         self.mup_embedding_multiplier = config.mup_embedding_multiplier
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
             lambda prefix: Phi3SmallDecoderLayer(
-                config, int(prefix.split(".")[-1]), quant_config
+                config,
+                int(prefix.split(".")[-1]),
+                quant_config,
+                prefix=prefix,
             ),
-            prefix=f"{prefix}.layers",
+            prefix=add_prefix("layers", prefix),
         )
         self.final_layernorm = nn.LayerNorm(
@@ -335,6 +350,7 @@ class Phi3SmallForCausalLM(nn.Module):
         self,
         config: Phi3Config,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
@@ -344,7 +360,7 @@ class Phi3SmallForCausalLM(nn.Module):
         self.model = Phi3SmallModel(
             config=config,
             quant_config=quant_config,
-            prefix="model",
+            prefix=add_prefix("model", prefix),
         )
         self.vocab_size = config.vocab_size
         self.mup_width_multiplier = config.mup_width_multiplier
@@ -354,6 +370,7 @@ class Phi3SmallForCausalLM(nn.Module):
             org_num_embeddings=config.vocab_size,
             padding_size=DEFAULT_VOCAB_PADDING_SIZE,
             quant_config=quant_config,
+            prefix=add_prefix("lm_head", prefix),
         )
         if self.config.tie_word_embeddings:
             self.lm_head.weight = self.model.embed_tokens.weight
...
@@ -39,6 +39,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix
 class QWenMLP(nn.Module):
@@ -48,6 +49,7 @@ class QWenMLP(nn.Module):
         intermediate_size: int,
         hidden_act: str = "silu",
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
@@ -56,6 +58,7 @@ class QWenMLP(nn.Module):
             bias=False,
             gather_output=False,
             quant_config=quant_config,
+            prefix=add_prefix("gate_up_proj", prefix),
         )
         self.c_proj = RowParallelLinear(
             intermediate_size,
@@ -63,6 +66,7 @@ class QWenMLP(nn.Module):
             bias=False,
             input_is_parallel=True,
             quant_config=quant_config,
+            prefix=add_prefix("c_proj", prefix),
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -88,6 +92,7 @@ class QWenAttention(nn.Module):
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.hidden_size = hidden_size
@@ -104,6 +109,7 @@ class QWenAttention(nn.Module):
             self.total_num_heads,
             bias=True,
             quant_config=quant_config,
+            prefix=add_prefix("c_attn", prefix),
         )
         self.c_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
@@ -111,6 +117,7 @@ class QWenAttention(nn.Module):
             bias=False,
             input_is_parallel=True,
             quant_config=quant_config,
+            prefix=add_prefix("c_proj", prefix),
         )
         self.rotary_emb = get_rope(
             self.head_dim,
@@ -126,6 +133,7 @@ class QWenAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_heads,
             layer_id=layer_id,
+            prefix=add_prefix("attn", prefix),
         )
     def forward(
@@ -148,6 +156,7 @@ class QWenBlock(nn.Module):
         config: PretrainedConfig,
         layer_id,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
@@ -162,6 +171,7 @@ class QWenBlock(nn.Module):
             rope_scaling=rope_scaling,
             layer_id=layer_id,
             quant_config=quant_config,
+            prefix=add_prefix("attn", prefix),
         )
         self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
@@ -170,6 +180,7 @@ class QWenBlock(nn.Module):
             config.hidden_size,
             config.intermediate_size // 2,
             quant_config=quant_config,
+            prefix=add_prefix("mlp", prefix),
         )
     def forward(
@@ -201,6 +212,7 @@ class QWenModel(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.config = config
@@ -210,10 +222,16 @@ class QWenModel(nn.Module):
         self.wte = VocabParallelEmbedding(
             vocab_size,
             config.hidden_size,
+            prefix=add_prefix("wte", prefix),
         )
         self.h = nn.ModuleList(
             [
-                QWenBlock(config, i, quant_config=quant_config)
+                QWenBlock(
+                    config,
+                    i,
+                    quant_config=quant_config,
+                    prefix=add_prefix(f"h.{i}", prefix),
+                )
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -242,12 +260,17 @@ class QWenLMHeadModel(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.config = config
-        self.transformer = QWenModel(config, quant_config=quant_config)
+        self.transformer = QWenModel(
+            config, quant_config=quant_config, prefix=add_prefix("transformer", prefix)
+        )
         vocab_size = ((config.vocab_size + 63) // 64) * 64
-        self.lm_head = ParallelLMHead(vocab_size, config.hidden_size)
+        self.lm_head = ParallelLMHead(
+            vocab_size, config.hidden_size, prefix=add_prefix("lm_head", prefix)
+        )
         self.logits_processor = LogitsProcessor(config)
     @torch.no_grad()
...
@@ -15,7 +15,7 @@
 # Adapted from llama2.py
 # Modify details for the adaptation of Qwen2 model.
 """Inference-only Qwen2 model compatible with HuggingFace weights."""
-from readline import add_history
 from typing import Any, Dict, Iterable, Optional, Tuple
 import torch
@@ -46,7 +46,7 @@ from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
     kv_cache_scales_loader,
 )
-from sglang.srt.utils import make_layers
+from sglang.srt.utils import add_prefix, make_layers
 Qwen2Config = None
@@ -58,6 +58,7 @@ class Qwen2MLP(nn.Module):
         intermediate_size: int,
         hidden_act: str,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
@@ -65,12 +66,14 @@ class Qwen2MLP(nn.Module):
             [intermediate_size] * 2,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("gate_up_proj", prefix),
         )
         self.down_proj = RowParallelLinear(
             intermediate_size,
             hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("down_proj", prefix),
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -97,6 +100,7 @@ class Qwen2Attention(nn.Module):
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 32768,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -128,12 +132,14 @@ class Qwen2Attention(nn.Module):
             self.total_num_kv_heads,
             bias=True,
             quant_config=quant_config,
+            prefix=add_prefix("qkv_proj", prefix),
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("o_proj", prefix),
         )
         self.rotary_emb = get_rope(
@@ -149,6 +155,7 @@ class Qwen2Attention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            prefix=add_prefix("attn", prefix),
         )
     def forward(
@@ -171,6 +178,7 @@ class Qwen2DecoderLayer(nn.Module):
         config: Qwen2Config,
         layer_id: int = 0,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -186,12 +194,14 @@ class Qwen2DecoderLayer(nn.Module):
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
+            prefix=add_prefix("self_attn", prefix),
         )
         self.mlp = Qwen2MLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
             quant_config=quant_config,
+            prefix=add_prefix("mlp", prefix),
         )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
@@ -228,6 +238,7 @@ class Qwen2Model(nn.Module):
         self,
         config: Qwen2Config,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
@@ -237,6 +248,7 @@ class Qwen2Model(nn.Module):
             config.vocab_size,
             config.hidden_size,
             quant_config=quant_config,
+            prefix=add_prefix("embed_tokens", prefix),
         )
         self.layers = make_layers(
             config.num_hidden_layers,
@@ -244,7 +256,9 @@ class Qwen2Model(nn.Module):
                 layer_id=idx,
                 config=config,
                 quant_config=quant_config,
+                prefix=prefix,
             ),
+            prefix=add_prefix("layers", prefix),
         )
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -325,16 +339,22 @@ class Qwen2ForCausalLM(nn.Module):
         self,
         config: Qwen2Config,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = Qwen2Model(config, quant_config=quant_config)
+        self.model = Qwen2Model(
+            config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+        )
         if config.tie_word_embeddings:
             self.lm_head = self.model.embed_tokens
         else:
             self.lm_head = ParallelLMHead(
-                config.vocab_size, config.hidden_size, quant_config=quant_config
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix=add_prefix("lm_head", prefix),
             )
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
...
@@ -52,6 +52,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2Model
 from sglang.srt.models.qwen2_vl import Qwen2VLImageInputs, Qwen2VLVideoInputs
+from sglang.srt.utils import add_prefix
 logger = logging.getLogger(__name__)
@@ -65,16 +66,29 @@ class Qwen2_5_VLMLP(nn.Module):
         bias: bool = True,
         hidden_act="silu",
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.gate_proj = ColumnParallelLinear(
-            in_features, hidden_features, bias=bias, quant_config=quant_config
+            in_features,
+            hidden_features,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=add_prefix("gate_proj", prefix),
         )
         self.up_proj = ColumnParallelLinear(
-            in_features, hidden_features, bias=bias, quant_config=quant_config
+            in_features,
+            hidden_features,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=add_prefix("up_proj", prefix),
         )
         self.down_proj = RowParallelLinear(
-            hidden_features, in_features, bias=bias, quant_config=quant_config
+            hidden_features,
+            in_features,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=add_prefix("down_proj", prefix),
         )
         self.act = ACT2FN[hidden_act]
@@ -98,6 +112,7 @@ class Qwen2_5_VisionBlock(nn.Module):
         norm_layer: Type[nn.Module] = None,
         attn_implementation: Optional[str] = "sdpa",
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         if norm_layer is None:
@@ -123,9 +138,14 @@ class Qwen2_5_VisionBlock(nn.Module):
             use_full_precision_softmax=use_full_precision_softmax,
             flatten_batch=True,
             quant_config=quant_config,
+            prefix=add_prefix("attn", prefix),
         )
         self.mlp = Qwen2_5_VLMLP(
-            dim, intermediate_dim, hidden_act=hidden_act, quant_config=quant_config
+            dim,
+            intermediate_dim,
+            hidden_act=hidden_act,
+            quant_config=quant_config,
+            prefix=add_prefix("mlp", prefix),
         )
     def forward(
@@ -178,6 +198,7 @@ class Qwen2_5_VisionPatchMerger(nn.Module):
         context_dim: int,
         spatial_merge_size: int = 2,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = context_dim * (spatial_merge_size**2)
@@ -189,10 +210,15 @@ class Qwen2_5_VisionPatchMerger(nn.Module):
                     self.hidden_size,
                     bias=True,
                     quant_config=quant_config,
+                    prefix=add_prefix("mlp.0", prefix),
                 ),
                 nn.GELU(),
                 RowParallelLinear(
-                    self.hidden_size, dim, bias=True, quant_config=quant_config
+                    self.hidden_size,
+                    dim,
+                    bias=True,
+                    quant_config=quant_config,
+                    prefix=add_prefix("mlp.2", prefix),
                 ),
             ]
         )
@@ -250,6 +276,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
         vision_config: Qwen2_5_VLVisionConfig,
         norm_eps: float = 1e-6,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
@@ -286,8 +313,9 @@ class Qwen2_5_VisionTransformer(nn.Module):
                     norm_layer=norm_layer,
                     attn_implementation="sdpa",
                     quant_config=quant_config,
+                    prefix=add_prefix(f"blocks.{i}", prefix),
                 )
-                for _ in range(depth)
+                for i in range(depth)
             ]
         )
         self.merger = Qwen2_5_VisionPatchMerger(
@@ -295,6 +323,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
             context_dim=hidden_size,
             spatial_merge_size=spatial_merge_size,
             quant_config=quant_config,
+            prefix=add_prefix("merger", prefix),
         )
     def get_window_index(self, grid_thw):
@@ -447,6 +476,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         self,
         config: Qwen2VLConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
@@ -457,15 +487,23 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
             # NOTE: Qwen2-VL vision encoder does not support any
             # quantization method now.
             quant_config=None,
+            prefix=add_prefix("visual", prefix),
         )
-        self.model = Qwen2Model(config, quant_config)
+        self.model = Qwen2Model(
+            config,
+            quant_config,
+            prefix=add_prefix("model", prefix),
+        )
         if config.tie_word_embeddings:
             self.lm_head = self.model.embed_tokens
         else:
             self.lm_head = ParallelLMHead(
-                config.vocab_size, config.hidden_size, quant_config=quant_config
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix=add_prefix("lm_head", prefix),
             )
         self.logits_processor = LogitsProcessor(config)
...
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
+from sglang.srt.utils import add_prefix
 # Adapted from
 # https://github.com/SafeAILab/EAGLE/blob/main/eagle/model/cnets.py
 """Inference-only LLaMA-EAGLE model compatible with HuggingFace weights."""
@@ -42,7 +44,7 @@ class Qwen2DecoderLayer(Qwen2DecoderLayer):
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ) -> None:
-        super().__init__(config, layer_id, quant_config)
+        super().__init__(config, layer_id, quant_config, prefix=prefix)
         # Skip the input_layernorm
         # https://github.com/SafeAILab/EAGLE/blob/35c78f6cdc19a73e05cf5c330b4c358dad970c6a/eagle/model/cnets.py#L427
@@ -56,6 +58,7 @@ class Qwen2Model(nn.Module):
         self,
         config: Qwen2Config,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
@@ -63,11 +66,15 @@ class Qwen2Model(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
+            prefix=add_prefix("embed_tokens", prefix),
         )
         self.layers = nn.ModuleList(
             [
                 Qwen2DecoderLayer(
-                    config, i, quant_config=quant_config, prefix=f"model.layers.{i}"
+                    config,
+                    i,
+                    quant_config=quant_config,
+                    prefix=add_prefix(f"layers.{i}", prefix),
                 )
                 for i in range(config.num_hidden_layers)
             ]
@@ -107,16 +114,22 @@ class Qwen2ForCausalLMEagle(Qwen2ForCausalLM):
         self,
         config: Qwen2Config,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         nn.Module.__init__(self)
         self.config = config
         self.quant_config = quant_config
-        self.model = Qwen2Model(config, quant_config=quant_config)
+        self.model = Qwen2Model(
+            config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+        )
         if self.config.tie_word_embeddings:
             self.lm_head = self.model.embed_tokens
         else:
             self.lm_head = ParallelLMHead(
-                config.vocab_size, config.hidden_size, quant_config=quant_config
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix=add_prefix("lm_head", prefix),
             )
         self.logits_processor = LogitsProcessor(config)
...
@@ -46,6 +46,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix
 class Qwen2MoeMLP(nn.Module):
@@ -56,10 +57,15 @@ class Qwen2MoeMLP(nn.Module):
         hidden_act: str,
         quant_config: Optional[QuantizationConfig] = None,
         reduce_results: bool = True,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("gate_up_proj", prefix),
         )
         self.down_proj = RowParallelLinear(
             intermediate_size,
@@ -67,6 +73,7 @@ class Qwen2MoeMLP(nn.Module):
             bias=False,
             quant_config=quant_config,
             reduce_results=reduce_results,
+            prefix=add_prefix("down_proj", prefix),
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -87,6 +94,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
@@ -105,10 +113,15 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
             reduce_results=False,
             renormalize=config.norm_topk_prob,
             quant_config=quant_config,
+            prefix=add_prefix("experts", prefix),
         )
         self.gate = ReplicatedLinear(
-            config.hidden_size, config.num_experts, bias=False, quant_config=None
+            config.hidden_size,
+            config.num_experts,
+            bias=False,
+            quant_config=None,
+            prefix=add_prefix("gate", prefix),
         )
         if config.shared_expert_intermediate_size > 0:
             self.shared_expert = Qwen2MoeMLP(
@@ -117,6 +130,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
                 hidden_act=config.hidden_act,
                 quant_config=quant_config,
                 reduce_results=False,
+                prefix=add_prefix("shared_expert", prefix),
             )
         else:
             self.shared_expert = None
@@ -157,6 +171,7 @@ class Qwen2MoeAttention(nn.Module):
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -188,6 +203,7 @@ class Qwen2MoeAttention(nn.Module):
             self.total_num_kv_heads,
             bias=True,
             quant_config=quant_config,
+            prefix=add_prefix("qkv_proj", prefix),
         )
         self.o_proj = RowParallelLinear(
@@ -195,6 +211,7 @@ class Qwen2MoeAttention(nn.Module):
             hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("o_proj", prefix),
         )
         self.rotary_emb = get_rope(
@@ -210,6 +227,7 @@ class Qwen2MoeAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            prefix=add_prefix("attn", prefix),
         )
     def forward(
@@ -232,6 +250,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
         config: PretrainedConfig,
         layer_id: int,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -247,6 +266,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
+            prefix=add_prefix("self_attn", prefix),
         )
         # Note: Qwen/Qwen2-57B-A14B-Instruct does not have
@@ -257,13 +277,18 @@ class Qwen2MoeDecoderLayer(nn.Module):
         if (layer_id not in mlp_only_layers) and (
             config.num_experts > 0 and (layer_id + 1) % config.decoder_sparse_step == 0
         ):
-            self.mlp = Qwen2MoeSparseMoeBlock(config=config, quant_config=quant_config)
+            self.mlp = Qwen2MoeSparseMoeBlock(
+                config=config,
+                quant_config=quant_config,
+                prefix=add_prefix("mlp", prefix),
+            )
         else:
             self.mlp = Qwen2MoeMLP(
                 hidden_size=config.hidden_size,
                 intermediate_size=config.intermediate_size,
                 hidden_act=config.hidden_act,
                 quant_config=quant_config,
+                prefix=add_prefix("mlp", prefix),
             )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
@@ -300,6 +325,7 @@ class Qwen2MoeModel(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.padding_idx = config.pad_token_id
@@ -308,10 +334,16 @@ class Qwen2MoeModel(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
+            prefix=add_prefix("embed_tokens", prefix),
         )
         self.layers = nn.ModuleList(
             [
-                Qwen2MoeDecoderLayer(config, layer_id, quant_config=quant_config)
+                Qwen2MoeDecoderLayer(
+                    config,
+                    layer_id,
+                    quant_config=quant_config,
+                    prefix=add_prefix(f"layers.{layer_id}", prefix),
+                )
                 for layer_id in range(config.num_hidden_layers)
             ]
         )
@@ -346,13 +378,19 @@ class Qwen2MoeForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = Qwen2MoeModel(config, quant_config)
+        self.model = Qwen2MoeModel(
+            config, quant_config, prefix=add_prefix("model", prefix)
+        )
         self.lm_head = ParallelLMHead(
-            config.vocab_size, config.hidden_size, quant_config=quant_config
+            config.vocab_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            prefix=add_prefix("lm_head", prefix),
         )
         self.logits_processor = LogitsProcessor(config)
...
@@ -22,6 +22,7 @@ from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM, Qwen2Model
+from sglang.srt.utils import add_prefix
 class Qwen2ForRewardModel(nn.Module):
@@ -29,12 +30,15 @@ class Qwen2ForRewardModel(nn.Module):
         self,
         config: Qwen2Config,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
         self.num_labels = 1
-        self.model = Qwen2Model(config, quant_config=quant_config)
+        self.model = Qwen2Model(
+            config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+        )
         self.score = nn.Sequential(
             nn.Linear(config.hidden_size, config.hidden_size),
             nn.ReLU(),
...
@@ -46,6 +46,7 @@ from sglang.srt.managers.schedule_batch import ImageInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2Model
+from sglang.srt.utils import add_prefix
 logger = logging.getLogger(__name__)
@@ -91,14 +92,21 @@ class Qwen2VisionMLP(nn.Module):
         hidden_features: int = None,
         act_layer: Type[nn.Module] = QuickGELU,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.fc1 = ColumnParallelLinear(
-            in_features, hidden_features, quant_config=quant_config
+            in_features,
+            hidden_features,
+            quant_config=quant_config,
+            prefix=add_prefix("fc1", prefix),
         )
         self.act = act_layer()
         self.fc2 = RowParallelLinear(
-            hidden_features, in_features, quant_config=quant_config
+            hidden_features,
+            in_features,
+            quant_config=quant_config,
+            prefix=add_prefix("fc2", prefix),
         )
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -119,6 +127,7 @@ class Qwen2VisionBlock(nn.Module):
         norm_layer: Type[nn.Module] = None,
         attn_implementation: Optional[str] = "sdpa",
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         if norm_layer is None:
@@ -145,9 +154,14 @@ class Qwen2VisionBlock(nn.Module):
             use_full_precision_softmax=use_full_precision_softmax,
             flatten_batch=True,
             quant_config=quant_config,
+            prefix=add_prefix("attn", prefix),
         )
         self.mlp = Qwen2VisionMLP(
-            dim, mlp_hidden_dim, act_layer=act_layer, quant_config=quant_config
+            dim,
+            mlp_hidden_dim,
+            act_layer=act_layer,
+            quant_config=quant_config,
+            prefix=add_prefix("mlp", prefix),
         )
     def forward(
@@ -199,6 +213,7 @@ class Qwen2VisionPatchMerger(nn.Module):
         norm_layer: Type[nn.Module] = None,
         spatial_merge_size: int = 2,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = context_dim * (spatial_merge_size**2)
@@ -212,10 +227,15 @@ class Qwen2VisionPatchMerger(nn.Module):
                     self.hidden_size,
                     bias=True,
                     quant_config=quant_config,
+                    prefix=add_prefix("mlp.0", prefix),
                 ),
                 nn.GELU(),
                 RowParallelLinear(
-                    self.hidden_size, d_model, bias=True, quant_config=quant_config
+                    self.hidden_size,
+                    d_model,
+                    bias=True,
+                    quant_config=quant_config,
+                    prefix=add_prefix("mlp.2", prefix),
                 ),
             ]
         )
@@ -273,6 +293,7 @@ class Qwen2VisionTransformer(nn.Module):
         vision_config: Qwen2VLVisionConfig,
         norm_eps: float = 1e-6,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
@@ -307,8 +328,9 @@ class Qwen2VisionTransformer(nn.Module):
                     norm_layer=norm_layer,
                     attn_implementation="sdpa",
                     quant_config=quant_config,
+                    prefix=add_prefix(f"blocks.{i}", prefix),
                 )
-                for _ in range(depth)
+                for i in range(depth)
             ]
         )
         self.merger = Qwen2VisionPatchMerger(
@@ -316,6 +338,7 @@ class Qwen2VisionTransformer(nn.Module):
             context_dim=embed_dim,
             norm_layer=norm_layer,
             quant_config=quant_config,
+            prefix=add_prefix("merger", prefix),
         )
     @property
@@ -440,6 +463,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         self,
         config: Qwen2VLConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
@@ -450,15 +474,21 @@ class Qwen2VLForConditionalGeneration(nn.Module):
             # NOTE: Qwen2-VL vision encoder does not support any
             # quantization method now.
             quant_config=None,
+            prefix=add_prefix("visual", prefix),
         )
-        self.model = Qwen2Model(config, quant_config)
+        self.model = Qwen2Model(
+            config, quant_config, prefix=add_prefix("model", prefix)
+        )
         if config.tie_word_embeddings:
             self.lm_head = self.model.embed_tokens
         else:
             self.lm_head = ParallelLMHead(
-                config.vocab_size, config.hidden_size, quant_config=quant_config
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix=add_prefix("lm_head", prefix),
             )
         self.logits_processor = LogitsProcessor(config)
...
@@ -42,6 +42,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix
 class StablelmMLP(nn.Module):
@@ -49,6 +50,7 @@ class StablelmMLP(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
@@ -59,12 +61,14 @@ class StablelmMLP(nn.Module):
             [config.intermediate_size] * 2,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("gate_up_proj", prefix),
         )
         self.down_proj = RowParallelLinear(
             config.intermediate_size,
             config.hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("down_proj", prefix),
         )
         self.act_fn = SiluAndMul()
@@ -81,6 +85,7 @@ class StablelmAttention(nn.Module):
         config: PretrainedConfig,
         layer_id: int = 0,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
@@ -122,11 +127,15 @@ class StablelmAttention(nn.Module):
             self.total_num_heads,
             self.total_num_key_value_heads,
             self.qkv_bias,
+            quant_config=quant_config,
+            prefix=add_prefix("qkv_proj", prefix),
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             self.hidden_size,
             bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("o_proj", prefix),
         )
         self.rotary_emb = get_rope(
             self.head_dim,
@@ -140,6 +149,7 @@ class StablelmAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_key_value_heads,
             layer_id=layer_id,
+            prefix=add_prefix("attn", prefix),
         )
     def forward(
@@ -162,10 +172,15 @@ class StablelmDecoderLayer(nn.Module):
         config: PretrainedConfig,
         layer_id: int = 0,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
-        self.self_attn = StablelmAttention(config, layer_id=layer_id)
-        self.mlp = StablelmMLP(config, quant_config=quant_config)
+        self.self_attn = StablelmAttention(
+            config, layer_id=layer_id, prefix=add_prefix("self_attn", prefix)
+        )
+        self.mlp = StablelmMLP(
+            config, quant_config=quant_config, prefix=add_prefix("mlp", prefix)
+        )
         norm_eps = getattr(config, "norm_eps", getattr(config, "layer_norm_eps", 1e-05))
         self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
         self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
@@ -200,15 +215,22 @@ class StableLMEpochModel(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
+            prefix=add_prefix("embed_tokens", prefix),
         )
         self.layers = nn.ModuleList(
             [
-                StablelmDecoderLayer(config, i, quant_config=quant_config)
+                StablelmDecoderLayer(
+                    config,
+                    i,
+                    quant_config=quant_config,
+                    prefix=add_prefix(f"layers.{i}", prefix),
+                )
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -242,12 +264,17 @@ class StableLmForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = StableLMEpochModel(config, quant_config=quant_config)
-        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        self.model = StableLMEpochModel(
+            config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+        )
+        self.lm_head = ParallelLMHead(
+            config.vocab_size, config.hidden_size, prefix=add_prefix("lm_head", prefix)
+        )
         self.logits_processor = LogitsProcessor(config)
     @torch.no_grad()
...
@@ -64,6 +64,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix
 tp_size = get_tensor_model_parallel_world_size()
 tp_rank = get_tensor_model_parallel_rank()
@@ -294,14 +295,14 @@ class LlamaDecoderLayer(nn.Module):
             rope_is_neox_style=rope_is_neox_style,
             max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
-            prefix=f"{prefix}.self_attn",
+            prefix=add_prefix("self_attn", prefix),
         )
         self.mlp = LlamaMLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
             quant_config=quant_config,
-            prefix=f"{prefix}.mlp",
+            prefix=add_prefix("mlp", prefix),
         )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
...
@@ -40,6 +40,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.model_executor.model_runner import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix
 class XverseMLP(nn.Module):
@@ -57,14 +58,14 @@ class XverseMLP(nn.Module):
             [intermediate_size] * 2,
             bias=False,
             quant_config=quant_config,
-            prefix=f"{prefix}.gate_up_proj",
+            prefix=add_prefix("gate_up_proj", prefix),
         )
         self.down_proj = RowParallelLinear(
             intermediate_size,
             hidden_size,
             bias=False,
             quant_config=quant_config,
-            prefix=f"{prefix}.down_proj",
+            prefix=add_prefix("down_proj", prefix),
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -128,14 +129,14 @@ class XverseAttention(nn.Module):
             self.total_num_kv_heads,
             bias=False,
             quant_config=quant_config,
-            prefix=f"{prefix}.qkv_proj",
+            prefix=add_prefix("qkv_proj", prefix),
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
             quant_config=quant_config,
-            prefix=f"{prefix}.o_proj",
+            prefix=add_prefix("o_proj", prefix),
         )
         self.rotary_emb = get_rope(
@@ -152,6 +153,7 @@ class XverseAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            prefix=add_prefix("attn", prefix),
         )
     def forward(
@@ -202,14 +204,14 @@ class XverseDecoderLayer(nn.Module):
             rope_is_neox_style=rope_is_neox_style,
             max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
-            prefix=f"{prefix}.self_attn",
+            prefix=add_prefix("self_attn", prefix),
         )
         self.mlp = XverseMLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
             quant_config=quant_config,
-            prefix=f"{prefix}.mlp",
+            prefix=add_prefix("mlp", prefix),
         )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
@@ -246,6 +248,7 @@ class XverseModel(nn.Module):
         self,
         config: LlamaConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
@@ -254,11 +257,15 @@ class XverseModel(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
+            prefix=add_prefix("embed_tokens", prefix),
         )
         self.layers = nn.ModuleList(
             [
                 XverseDecoderLayer(
-                    config, i, quant_config=quant_config, prefix=f"model.layers.{i}"
+                    config,
+                    i,
+                    quant_config=quant_config,
+                    prefix=add_prefix(f"layers.{i}", prefix),
                 )
                 for i in range(config.num_hidden_layers)
             ]
@@ -295,12 +302,17 @@ class XverseForCausalLM(nn.Module):
         self,
         config: LlamaConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = XverseModel(config, quant_config=quant_config)
-        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        self.model = XverseModel(
+            config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+        )
+        self.lm_head = ParallelLMHead(
+            config.vocab_size, config.hidden_size, prefix=add_prefix("lm_head", prefix)
+        )
         self.logits_processor = LogitsProcessor(config)
     @torch.no_grad()
...
...@@ -43,6 +43,7 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ...@@ -43,6 +43,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
) )
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix
class XverseMLP(nn.Module): class XverseMLP(nn.Module):
...@@ -54,10 +55,15 @@ class XverseMLP(nn.Module): ...@@ -54,10 +55,15 @@ class XverseMLP(nn.Module):
hidden_act: str, hidden_act: str,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
reduce_results: bool = True, reduce_results: bool = True,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.gate_up_proj = MergedColumnParallelLinear( self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config hidden_size,
[intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=add_prefix("gate_up_proj", prefix),
) )
self.down_proj = RowParallelLinear( self.down_proj = RowParallelLinear(
intermediate_size, intermediate_size,
...@@ -65,6 +71,7 @@ class XverseMLP(nn.Module): ...@@ -65,6 +71,7 @@ class XverseMLP(nn.Module):
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
reduce_results=reduce_results, reduce_results=reduce_results,
prefix=add_prefix("down_proj", prefix),
) )
if hidden_act != "silu": if hidden_act != "silu":
raise ValueError( raise ValueError(
...@@ -86,6 +93,7 @@ class XverseMoE(nn.Module): ...@@ -86,6 +93,7 @@ class XverseMoE(nn.Module):
self, self,
config: PretrainedConfig, config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -107,14 +115,19 @@ class XverseMoE(nn.Module): ...@@ -107,14 +115,19 @@ class XverseMoE(nn.Module):
hidden_act=config.hidden_act, hidden_act=config.hidden_act,
quant_config=quant_config, quant_config=quant_config,
reduce_results=False, reduce_results=False,
prefix=add_prefix(f"experts.{i}", prefix),
) )
for _ in range(self.n_routed_experts) for i in range(self.n_routed_experts)
] ]
) )
self.pack_params() self.pack_params()
self.router = ReplicatedLinear( self.router = ReplicatedLinear(
config.hidden_size, self.n_routed_experts, bias=False, quant_config=None config.hidden_size,
self.n_routed_experts,
bias=False,
quant_config=None,
prefix=add_prefix("router", prefix),
) )
if config.num_shared_experts is not None: if config.num_shared_experts is not None:
...@@ -125,6 +138,7 @@ class XverseMoE(nn.Module): ...@@ -125,6 +138,7 @@ class XverseMoE(nn.Module):
hidden_act=config.hidden_act, hidden_act=config.hidden_act,
quant_config=quant_config, quant_config=quant_config,
reduce_results=False, reduce_results=False,
prefix=add_prefix("shared_experts", prefix),
) )
def pack_params(self): def pack_params(self):
...@@ -182,6 +196,7 @@ class XverseAttention(nn.Module): ...@@ -182,6 +196,7 @@ class XverseAttention(nn.Module):
rope_scaling: Optional[Dict[str, Any]] = None, rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192, max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
...@@ -213,6 +228,7 @@ class XverseAttention(nn.Module): ...@@ -213,6 +228,7 @@ class XverseAttention(nn.Module):
self.total_num_kv_heads, self.total_num_kv_heads,
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("qkv_proj", prefix),
) )
self.o_proj = RowParallelLinear( self.o_proj = RowParallelLinear(
...@@ -220,6 +236,7 @@ class XverseAttention(nn.Module): ...@@ -220,6 +236,7 @@ class XverseAttention(nn.Module):
hidden_size, hidden_size,
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("o_proj", prefix),
) )
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
...@@ -235,6 +252,7 @@ class XverseAttention(nn.Module): ...@@ -235,6 +252,7 @@ class XverseAttention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
layer_id=layer_id, layer_id=layer_id,
prefix=add_prefix("attn", prefix),
) )
def forward( def forward(
...@@ -258,6 +276,7 @@ class XverseDecoderLayer(nn.Module): ...@@ -258,6 +276,7 @@ class XverseDecoderLayer(nn.Module):
config: PretrainedConfig, config: PretrainedConfig,
layer_id: int, layer_id: int,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
...@@ -276,15 +295,21 @@ class XverseDecoderLayer(nn.Module): ...@@ -276,15 +295,21 @@ class XverseDecoderLayer(nn.Module):
rope_scaling=rope_scaling, rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings, max_position_embeddings=max_position_embeddings,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("self_attn", prefix),
) )
if config.num_experts is not None: if config.num_experts is not None:
self.mlp = XverseMoE(config=config, quant_config=quant_config) self.mlp = XverseMoE(
config=config,
quant_config=quant_config,
prefix=add_prefix("mlp", prefix),
)
else: else:
self.mlp = XverseMLP( self.mlp = XverseMLP(
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size, intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act, hidden_act=config.hidden_act,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("mlp", prefix),
) )
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm( self.post_attention_layernorm = RMSNorm(
...@@ -324,6 +349,7 @@ class XverseModel(nn.Module): ...@@ -324,6 +349,7 @@ class XverseModel(nn.Module):
self, self,
config: PretrainedConfig, config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
...@@ -332,10 +358,16 @@ class XverseModel(nn.Module): ...@@ -332,10 +358,16 @@ class XverseModel(nn.Module):
self.embed_tokens = VocabParallelEmbedding( self.embed_tokens = VocabParallelEmbedding(
config.vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
prefix=add_prefix("embed_tokens", prefix),
) )
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
[ [
XverseDecoderLayer(config, layer_id, quant_config=quant_config) XverseDecoderLayer(
config,
layer_id,
quant_config=quant_config,
prefix=add_prefix(f"layers.{layer_id}", prefix),
)
for layer_id in range(config.num_hidden_layers) for layer_id in range(config.num_hidden_layers)
] ]
) )
...@@ -364,13 +396,19 @@ class XverseMoeForCausalLM(nn.Module): ...@@ -364,13 +396,19 @@ class XverseMoeForCausalLM(nn.Module):
self, self,
config: PretrainedConfig, config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.model = XverseModel(config, quant_config) self.model = XverseModel(
config, quant_config, prefix=add_prefix("model", prefix)
)
self.lm_head = ParallelLMHead( self.lm_head = ParallelLMHead(
config.vocab_size, config.hidden_size, quant_config=quant_config config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
) )
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
......
...@@ -29,8 +29,9 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM): ...@@ -29,8 +29,9 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM):
self, self,
config: LlavaConfig, config: LlavaConfig,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__(config, quant_config) super().__init__(config, quant_config, prefix=prefix)
self.multi_modal_projector = YiVLMultiModalProjector(self.config) self.multi_modal_projector = YiVLMultiModalProjector(self.config)
self.vision_tower_subfolder = self.config.mm_vision_tower.replace( self.vision_tower_subfolder = self.config.mm_vision_tower.replace(
......
...@@ -313,7 +313,7 @@ def make_layers( ...@@ -313,7 +313,7 @@ def make_layers(
"""Make a list of layers with the given layer function""" """Make a list of layers with the given layer function"""
modules = torch.nn.ModuleList( modules = torch.nn.ModuleList(
[ [
maybe_offload_to_cpu(layer_fn(idx=idx, prefix=f"{prefix}.{idx}")) maybe_offload_to_cpu(layer_fn(idx=idx, prefix=add_prefix(str(idx), prefix)))
for idx in range(num_hidden_layers) for idx in range(num_hidden_layers)
] ]
) )
...@@ -1464,3 +1464,16 @@ def set_cuda_arch(): ...@@ -1464,3 +1464,16 @@ def set_cuda_arch():
capability = torch.cuda.get_device_capability() capability = torch.cuda.get_device_capability()
arch = f"{capability[0]}.{capability[1]}" arch = f"{capability[0]}.{capability[1]}"
os.environ["TORCH_CUDA_ARCH_LIST"] = f"{arch}{'+PTX' if arch == '9.0' else ''}" os.environ["TORCH_CUDA_ARCH_LIST"] = f"{arch}{'+PTX' if arch == '9.0' else ''}"
def add_prefix(name: str, prefix: str) -> str:
"""Add a weight path prefix to a module name.
Args:
name: base module name.
prefix: weight prefix string to be added to the front of `name`, joined with `.`.
Returns:
The string `prefix.name` if prefix is non-empty, otherwise just `name`.
"""
return name if not prefix else f"{prefix}.{name}"
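For readers skimming the diff, a minimal usage sketch of `add_prefix` (not part of the patch) shows how nested calls compose the dotted weight path and how an empty prefix degenerates to the bare name; the module names are taken from layers touched elsewhere in this diff:
# Illustrative only; asserts follow directly from the implementation above.
assert add_prefix("qkv_proj", "") == "qkv_proj"
assert add_prefix("qkv_proj", "self_attn") == "self_attn.qkv_proj"
assert add_prefix("qkv_proj", add_prefix("self_attn", "model.layers.0")) == (
    "model.layers.0.self_attn.qkv_proj"
)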
...@@ -12,6 +12,7 @@ suites = { ...@@ -12,6 +12,7 @@ suites = {
"models/test_generation_models.py", "models/test_generation_models.py",
"models/test_qwen_models.py", "models/test_qwen_models.py",
"models/test_reward_models.py", "models/test_reward_models.py",
"test_gptqmodel_dynamic.py",
"test_abort.py", "test_abort.py",
"test_chunked_prefill.py", "test_chunked_prefill.py",
"test_custom_allreduce.py", "test_custom_allreduce.py",
......
import time
import unittest
import requests
import torch
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
popen_launch_server,
)
def check_quant_method(model_path: str, use_marlin_kernel: bool):
from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
from sglang.srt.distributed import (
get_tp_group,
init_distributed_environment,
initialize_model_parallel,
set_custom_all_reduce,
)
from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
from sglang.srt.layers.quantization import get_dynamic_override
from sglang.srt.model_loader import get_model
from sglang.srt.server_args import PortArgs, ServerArgs
try:
init_distributed_environment(
backend="nccl",
world_size=1,
rank=0,
local_rank=0,
distributed_init_method="tcp://127.0.0.1:2646",
)
initialize_model_parallel(tensor_model_parallel_size=1)
monkey_patch_vllm_parallel_state()
except AssertionError:
# ignore this error: tensor model parallel group is already initialized
pass
server_args = ServerArgs(model_path=model_path, dtype=torch.float16)
model_config = ModelConfig(
server_args.model_path,
trust_remote_code=server_args.trust_remote_code,
revision=server_args.revision,
context_length=server_args.context_length,
model_override_args=server_args.json_model_override_args,
is_embedding=server_args.is_embedding,
dtype=server_args.dtype,
quantization=server_args.quantization,
)
load_config = LoadConfig()
device_config = DeviceConfig("cuda")
model = get_model(
model_config=model_config, load_config=load_config, device_config=device_config
)
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinLinearMethod,
)
from sglang.srt.layers.linear import UnquantizedLinearMethod
linear_method_cls = (
GPTQMarlinLinearMethod if use_marlin_kernel else GPTQLinearMethod
)
for name, submodule in model.named_modules():
if name == "lm_head":
assert isinstance(submodule.quant_method, linear_method_cls)
elif name == "model.layers.0.self_attn.qkv_proj":
# The first layer is quantized using bits=4, group_size=128
# desc_act=True
assert isinstance(submodule.quant_method, linear_method_cls)
config = submodule.quant_method.quant_config
assert config.weight_bits == 4
assert config.group_size == 128
assert config.desc_act
elif name == "model.layers.1.self_attn.qkv_proj":
# The second layer is quantized using bits=8, group_size=32
# desc_act=False
assert isinstance(submodule.quant_method, linear_method_cls)
config = submodule.quant_method.quant_config
assert get_dynamic_override(config, layer_name=name, key="bits") == 8
assert get_dynamic_override(config, layer_name=name, key="group_size") == 32
assert not get_dynamic_override(config, layer_name=name, key="desc_act")
elif (
name == "model.layers.2.self_attn.qkv_proj"
or name == "model.layers.2.mlp.gate_up_proj"
):
# All other layers (layer index >= 2) are not quantized
assert isinstance(submodule.quant_method, UnquantizedLinearMethod)
del model
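The per-layer expectations asserted above mirror the `dynamic` override table the test model was quantized with. A minimal sketch of such a GPTQModel `dynamic` dict follows; the `+:`/`-:` regex-key convention and the exact patterns are my reading of GPTQModel's dynamic config, not something stated in this patch, and the base quantize config (bits=4, group_size=128, desc_act=True, plus lm_head quantization) is inferred from the assertions above.
# Hypothetical GPTQModel `dynamic` overrides (assumed syntax, not from this patch):
# layer 0 and lm_head keep the base 4-bit/group_size=128/desc_act=True config,
# layer 1 is overridden to 8-bit/group_size=32/desc_act=False, and layers with
# index >= 2 are excluded from quantization by the negative rule.
dynamic = {
    r"+:model\.layers\.1\..*": {"bits": 8, "group_size": 32, "desc_act": False},
    r"-:model\.layers\.([2-9]|[1-9][0-9]+)\..*": {},
}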
# GPTQ with dynamic per-module quantization control.
# Leverages GPTQModel (PyPI) to produce the `dynamic` quantized models.
# Tests the GPTQ fallback kernel (non-Marlin).
class TestGPTQModelDynamic(unittest.TestCase):
MODEL_PATH = (
"ModelCloud/Qwen1.5-1.8B-Chat-GPTQ-4bits-dynamic-cfg-with-lm_head-symFalse"
)
@classmethod
def setUpClass(cls):
cls.model = cls.MODEL_PATH
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=["--dtype", "float16"],
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def run_decode(self, max_new_tokens):
response = requests.post(
self.base_url + "/generate",
json={
"text": "The capital of France is",
"sampling_params": {
"max_new_tokens": max_new_tokens,
},
},
)
return response.json()
def test_throughput(self):
max_tokens = 256
tic = time.time()
result = self.run_decode(max_tokens)
tok = time.time()
print(f"result = `{result}`")
assert "paris" in result["text"].lower()
throughput = max_tokens / (tok - tic)
print(f"Throughput: {throughput} tokens/s")
assert throughput >= 140
def test_gptq_module(self):
check_quant_method(self.MODEL_PATH, use_marlin_kernel=False)
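For a quick manual sanity check outside the unittest harness, the same decode request can be sent to an already-running server; the launch command, URL, and port below are illustrative assumptions, while the endpoint and payload match `run_decode` above.
# Assumes a server was started separately, e.g.:
#   python -m sglang.launch_server --model-path <MODEL_PATH> --dtype float16 --port 30000
# The URL/port are placeholders; only the /generate payload is taken from this test.
import requests

resp = requests.post(
    "http://127.0.0.1:30000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {"max_new_tokens": 32},
    },
)
print(resp.json()["text"])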
# GPTQ with dynamic per-module quantization control.
# Leverages GPTQModel (PyPI) to produce the `dynamic` quantized models.
# Tests the Marlin kernel.
class TestGPTQModelDynamicWithMarlin(unittest.TestCase):
MODEL_PATH = (
"ModelCloud/Qwen1.5-1.8B-Chat-GPTQ-4bits-dynamic-cfg-with-lm_head-symTrue"
)
@classmethod
def setUpClass(cls):
cls.model = cls.MODEL_PATH
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=["--dtype", "float16"],
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def run_decode(self, max_new_tokens):
response = requests.post(
self.base_url + "/generate",
json={
"text": "The capital of France is",
"sampling_params": {
"max_new_tokens": max_new_tokens,
},
},
)
return response.json()
def test_throughput(self):
max_tokens = 256
tic = time.time()
result = self.run_decode(max_tokens)
tok = time.time()
print(f"result = `{result}`")
assert "paris" in result["text"].lower()
throughput = max_tokens / (tok - tic)
print(f"Throughput: {throughput} tokens/s")
assert throughput >= 140
def test_gptq_marlin_module(self):
check_quant_method(self.MODEL_PATH, use_marlin_kernel=True)
if __name__ == "__main__":
unittest.main()