Unverified commit 56a724eb authored by Qubitium-ModelCloud, committed by GitHub

[QUANT] Add GPTQModel Dynamic Quantization + `lm_head` Quantization (#3790)


Signed-off-by: ZX-ModelCloud <zx@modelcloud.ai>
Co-authored-by: ZX-ModelCloud <zx@modelcloud.ai>
parent 583d6af7
python/sglang/srt/models/phi3_small.py

@@ -24,7 +24,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import make_layers
+from sglang.srt.utils import add_prefix, make_layers

 @torch.jit.script
@@ -70,13 +70,14 @@ class Phi3SmallMLP(nn.Module):
             2 * [self.intermediate_size],
             bias=True,
             quant_config=quant_config,
-            prefix=f"{prefix}.up_proj",
+            prefix=add_prefix("up_proj", prefix),
         )
         self.down_proj = RowParallelLinear(
             self.intermediate_size,
             self.hidden_size,
             bias=True,
             quant_config=quant_config,
+            prefix=add_prefix("down_proj", prefix),
         )

     def forward(self, x):
@@ -140,7 +141,7 @@ class Phi3SmallSelfAttention(nn.Module):
             self.num_key_value_heads,
             bias=True,
             quant_config=quant_config,
-            prefix=f"{prefix}.qkv_proj",
+            prefix=add_prefix("qkv_proj", prefix),
         )
         self.dense = RowParallelLinear(
@@ -148,7 +149,7 @@ class Phi3SmallSelfAttention(nn.Module):
             self.hidden_size,
             bias=True,
             quant_config=quant_config,
-            prefix=f"{prefix}.o_proj",
+            prefix=add_prefix("o_proj", prefix),
         )

         if getattr(self.config, "rope_scaling", None) is not None:
@@ -201,6 +202,7 @@ class Phi3SmallSelfAttention(nn.Module):
             self.scale,
             num_kv_heads=self.num_kv_heads_per_partion,
             layer_id=layer_id,
+            prefix=add_prefix("attn", prefix),
         )

     def forward(
@@ -234,13 +236,21 @@ class Phi3SmallDecoderLayer(nn.Module):
         config: PretrainedConfig,
         layer_id: int,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.hidden_size = config.hidden_size
         self.self_attn = Phi3SmallSelfAttention(
-            config, layer_id, quant_config=quant_config
+            config,
+            layer_id,
+            quant_config=quant_config,
+            prefix=add_prefix("self_attn", prefix),
+        )
+        self.mlp = Phi3SmallMLP(
+            config,
+            quant_config,
+            prefix=add_prefix("mlp", prefix),
         )
-        self.mlp = Phi3SmallMLP(config, quant_config)
         self.input_layernorm = nn.LayerNorm(
             config.hidden_size, eps=config.layer_norm_epsilon
@@ -284,15 +294,20 @@ class Phi3SmallModel(nn.Module):
         self.config = config
         self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size, config.hidden_size
+            config.vocab_size,
+            config.hidden_size,
+            prefix=add_prefix("embed_tokens", prefix),
         )
         self.mup_embedding_multiplier = config.mup_embedding_multiplier
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
             lambda prefix: Phi3SmallDecoderLayer(
-                config, int(prefix.split(".")[-1]), quant_config
+                config,
+                int(prefix.split(".")[-1]),
+                quant_config,
+                prefix=prefix,
             ),
-            prefix=f"{prefix}.layers",
+            prefix=add_prefix("layers", prefix),
         )
         self.final_layernorm = nn.LayerNorm(
@@ -335,6 +350,7 @@ class Phi3SmallForCausalLM(nn.Module):
         self,
         config: Phi3Config,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
@@ -344,7 +360,7 @@ class Phi3SmallForCausalLM(nn.Module):
         self.model = Phi3SmallModel(
             config=config,
             quant_config=quant_config,
-            prefix="model",
+            prefix=add_prefix("model", prefix),
         )
         self.vocab_size = config.vocab_size
         self.mup_width_multiplier = config.mup_width_multiplier
@@ -354,6 +370,7 @@ class Phi3SmallForCausalLM(nn.Module):
             org_num_embeddings=config.vocab_size,
             padding_size=DEFAULT_VOCAB_PADDING_SIZE,
             quant_config=quant_config,
+            prefix=add_prefix("lm_head", prefix),
         )
         if self.config.tie_word_embeddings:
             self.lm_head.weight = self.model.embed_tokens.weight
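The change repeated throughout this commit is the same in every file: thread a `prefix` argument through each module constructor and build it with the `add_prefix` helper from `sglang.srt.utils` instead of ad-hoc f-strings such as `f"{prefix}.up_proj"`. The helper itself is not part of this excerpt; the sketch below is an assumption about its behavior, inferred only from how it is called here (join a child name onto a possibly empty dotted prefix without producing a leading dot):

def add_prefix(name: str, prefix: str) -> str:
    """Sketch only: the real helper lives in sglang.srt.utils and may differ.

    Joins a submodule name onto a dotted parent prefix. An empty prefix must
    not yield a leading dot, which is the case the plain f-string spelling
    (f"{prefix}.up_proj") gets wrong at the top of the module tree.
    """
    return name if not prefix else f"{prefix}.{name}"


# Prefixes produced while building Phi3SmallForCausalLM with prefix="":
assert add_prefix("model", "") == "model"
assert add_prefix("layers", "model") == "model.layers"
assert add_prefix("qkv_proj", "model.layers.0.self_attn") == "model.layers.0.self_attn.qkv_proj"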
python/sglang/srt/models/qwen.py

@@ -39,6 +39,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix

 class QWenMLP(nn.Module):
@@ -48,6 +49,7 @@ class QWenMLP(nn.Module):
         intermediate_size: int,
         hidden_act: str = "silu",
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
@@ -56,6 +58,7 @@ class QWenMLP(nn.Module):
             bias=False,
             gather_output=False,
             quant_config=quant_config,
+            prefix=add_prefix("gate_up_proj", prefix),
         )
         self.c_proj = RowParallelLinear(
             intermediate_size,
@@ -63,6 +66,7 @@ class QWenMLP(nn.Module):
             bias=False,
             input_is_parallel=True,
             quant_config=quant_config,
+            prefix=add_prefix("c_proj", prefix),
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -88,6 +92,7 @@ class QWenAttention(nn.Module):
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.hidden_size = hidden_size
@@ -104,6 +109,7 @@ class QWenAttention(nn.Module):
             self.total_num_heads,
             bias=True,
             quant_config=quant_config,
+            prefix=add_prefix("c_attn", prefix),
         )
         self.c_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
@@ -111,6 +117,7 @@ class QWenAttention(nn.Module):
             bias=False,
             input_is_parallel=True,
             quant_config=quant_config,
+            prefix=add_prefix("c_proj", prefix),
         )
         self.rotary_emb = get_rope(
             self.head_dim,
@@ -126,6 +133,7 @@ class QWenAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_heads,
             layer_id=layer_id,
+            prefix=add_prefix("attn", prefix),
         )

     def forward(
@@ -148,6 +156,7 @@ class QWenBlock(nn.Module):
         config: PretrainedConfig,
         layer_id,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
@@ -162,6 +171,7 @@ class QWenBlock(nn.Module):
             rope_scaling=rope_scaling,
             layer_id=layer_id,
             quant_config=quant_config,
+            prefix=add_prefix("attn", prefix),
         )
         self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
@@ -170,6 +180,7 @@ class QWenBlock(nn.Module):
             config.hidden_size,
             config.intermediate_size // 2,
             quant_config=quant_config,
+            prefix=add_prefix("mlp", prefix),
         )

     def forward(
@@ -201,6 +212,7 @@ class QWenModel(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.config = config
@@ -210,10 +222,16 @@ class QWenModel(nn.Module):
         self.wte = VocabParallelEmbedding(
             vocab_size,
             config.hidden_size,
+            prefix=add_prefix("wte", prefix),
         )
         self.h = nn.ModuleList(
             [
-                QWenBlock(config, i, quant_config=quant_config)
+                QWenBlock(
+                    config,
+                    i,
+                    quant_config=quant_config,
+                    prefix=add_prefix(f"h.{i}", prefix),
+                )
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -242,12 +260,17 @@ class QWenLMHeadModel(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.config = config
-        self.transformer = QWenModel(config, quant_config=quant_config)
+        self.transformer = QWenModel(
+            config, quant_config=quant_config, prefix=add_prefix("transformer", prefix)
+        )
         vocab_size = ((config.vocab_size + 63) // 64) * 64
-        self.lm_head = ParallelLMHead(vocab_size, config.hidden_size)
+        self.lm_head = ParallelLMHead(
+            vocab_size, config.hidden_size, prefix=add_prefix("lm_head", prefix)
+        )
         self.logits_processor = LogitsProcessor(config)

     @torch.no_grad()
python/sglang/srt/models/qwen2.py

@@ -15,7 +15,7 @@
 # Adapted from llama2.py
 # Modify details for the adaptation of Qwen2 model.
 """Inference-only Qwen2 model compatible with HuggingFace weights."""
-from readline import add_history
 from typing import Any, Dict, Iterable, Optional, Tuple

 import torch
@@ -46,7 +46,7 @@ from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
     kv_cache_scales_loader,
 )
-from sglang.srt.utils import make_layers
+from sglang.srt.utils import add_prefix, make_layers

 Qwen2Config = None
@@ -58,6 +58,7 @@ class Qwen2MLP(nn.Module):
         intermediate_size: int,
         hidden_act: str,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
@@ -65,12 +66,14 @@ class Qwen2MLP(nn.Module):
             [intermediate_size] * 2,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("gate_up_proj", prefix),
         )
         self.down_proj = RowParallelLinear(
             intermediate_size,
             hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("down_proj", prefix),
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -97,6 +100,7 @@ class Qwen2Attention(nn.Module):
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 32768,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -128,12 +132,14 @@ class Qwen2Attention(nn.Module):
             self.total_num_kv_heads,
             bias=True,
             quant_config=quant_config,
+            prefix=add_prefix("qkv_proj", prefix),
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("o_proj", prefix),
         )

         self.rotary_emb = get_rope(
@@ -149,6 +155,7 @@ class Qwen2Attention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            prefix=add_prefix("attn", prefix),
         )

     def forward(
@@ -171,6 +178,7 @@ class Qwen2DecoderLayer(nn.Module):
         config: Qwen2Config,
         layer_id: int = 0,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -186,12 +194,14 @@ class Qwen2DecoderLayer(nn.Module):
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
+            prefix=add_prefix("self_attn", prefix),
         )
         self.mlp = Qwen2MLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
             quant_config=quant_config,
+            prefix=add_prefix("mlp", prefix),
         )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
@@ -228,6 +238,7 @@ class Qwen2Model(nn.Module):
         self,
         config: Qwen2Config,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
@@ -237,6 +248,7 @@ class Qwen2Model(nn.Module):
             config.vocab_size,
             config.hidden_size,
             quant_config=quant_config,
+            prefix=add_prefix("embed_tokens", prefix),
         )
         self.layers = make_layers(
             config.num_hidden_layers,
@@ -244,7 +256,9 @@ class Qwen2Model(nn.Module):
                 layer_id=idx,
                 config=config,
                 quant_config=quant_config,
+                prefix=prefix,
             ),
+            prefix=add_prefix("layers", prefix),
         )
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -325,16 +339,22 @@ class Qwen2ForCausalLM(nn.Module):
         self,
         config: Qwen2Config,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = Qwen2Model(config, quant_config=quant_config)
+        self.model = Qwen2Model(
+            config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+        )
         if config.tie_word_embeddings:
             self.lm_head = self.model.embed_tokens
         else:
             self.lm_head = ParallelLMHead(
-                config.vocab_size, config.hidden_size, quant_config=quant_config
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix=add_prefix("lm_head", prefix),
             )
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
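Why do the full module paths matter? Per the commit title, the goal is GPTQModel dynamic quantization plus `lm_head` quantization: once every linear layer and the `lm_head` receive both a `quant_config` and their fully qualified name, a quantization backend can apply per-module rules by matching that name. The helper below is purely illustrative (the `DYNAMIC_RULES` and `resolve_overrides` names are hypothetical, not the GPTQModel or sglang API, and GPTQModel's real `dynamic` schema may differ); it only shows how name-based overrides could be resolved against the prefixes introduced in this diff:

import re
from typing import Dict, Optional

# Hypothetical override table keyed by regexes over fully qualified module names.
DYNAMIC_RULES: Dict[str, Optional[dict]] = {
    r"lm_head": {"bits": 8},                     # quantize the LM head (normally skipped)
    r"model\.layers\.(0|1)\..*": None,           # leave the first two layers unquantized
    r"model\.layers\..*\.mlp\..*": {"bits": 4},  # 4-bit for every MLP projection
}


def resolve_overrides(prefix: str) -> Optional[dict]:
    """Return per-module overrides for a path, or None to skip quantization.

    `prefix` is exactly the string threaded through the constructors above,
    e.g. "model.layers.0.self_attn.qkv_proj" or "lm_head".
    """
    for pattern, override in DYNAMIC_RULES.items():
        if re.fullmatch(pattern, prefix):
            return override
    return {}  # no rule matched: fall back to the global quantization settings


print(resolve_overrides("model.layers.0.self_attn.qkv_proj"))  # None -> skipped
print(resolve_overrides("model.layers.5.mlp.gate_up_proj"))    # {'bits': 4}
print(resolve_overrides("lm_head"))                            # {'bits': 8}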
python/sglang/srt/models/qwen2_5_vl.py

@@ -52,6 +52,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2Model
 from sglang.srt.models.qwen2_vl import Qwen2VLImageInputs, Qwen2VLVideoInputs
+from sglang.srt.utils import add_prefix

 logger = logging.getLogger(__name__)
@@ -65,16 +66,29 @@ class Qwen2_5_VLMLP(nn.Module):
         bias: bool = True,
         hidden_act="silu",
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.gate_proj = ColumnParallelLinear(
-            in_features, hidden_features, bias=bias, quant_config=quant_config
+            in_features,
+            hidden_features,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=add_prefix("gate_proj", prefix),
         )
         self.up_proj = ColumnParallelLinear(
-            in_features, hidden_features, bias=bias, quant_config=quant_config
+            in_features,
+            hidden_features,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=add_prefix("up_proj", prefix),
         )
         self.down_proj = RowParallelLinear(
-            hidden_features, in_features, bias=bias, quant_config=quant_config
+            hidden_features,
+            in_features,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=add_prefix("down_proj", prefix),
         )
         self.act = ACT2FN[hidden_act]
@@ -98,6 +112,7 @@ class Qwen2_5_VisionBlock(nn.Module):
         norm_layer: Type[nn.Module] = None,
         attn_implementation: Optional[str] = "sdpa",
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         if norm_layer is None:
@@ -123,9 +138,14 @@ class Qwen2_5_VisionBlock(nn.Module):
             use_full_precision_softmax=use_full_precision_softmax,
             flatten_batch=True,
             quant_config=quant_config,
+            prefix=add_prefix("attn", prefix),
         )
         self.mlp = Qwen2_5_VLMLP(
-            dim, intermediate_dim, hidden_act=hidden_act, quant_config=quant_config
+            dim,
+            intermediate_dim,
+            hidden_act=hidden_act,
+            quant_config=quant_config,
+            prefix=add_prefix("mlp", prefix),
         )

     def forward(
@@ -178,6 +198,7 @@ class Qwen2_5_VisionPatchMerger(nn.Module):
         context_dim: int,
         spatial_merge_size: int = 2,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.hidden_size = context_dim * (spatial_merge_size**2)
@@ -189,10 +210,15 @@ class Qwen2_5_VisionPatchMerger(nn.Module):
                     self.hidden_size,
                     bias=True,
                     quant_config=quant_config,
+                    prefix=add_prefix("mlp.0", prefix),
                 ),
                 nn.GELU(),
                 RowParallelLinear(
-                    self.hidden_size, dim, bias=True, quant_config=quant_config
+                    self.hidden_size,
+                    dim,
+                    bias=True,
+                    quant_config=quant_config,
+                    prefix=add_prefix("mlp.2", prefix),
                 ),
             ]
         )
@@ -250,6 +276,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
         vision_config: Qwen2_5_VLVisionConfig,
         norm_eps: float = 1e-6,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
@@ -286,8 +313,9 @@ class Qwen2_5_VisionTransformer(nn.Module):
                     norm_layer=norm_layer,
                     attn_implementation="sdpa",
                     quant_config=quant_config,
+                    prefix=add_prefix(f"blocks.{i}", prefix),
                 )
-                for _ in range(depth)
+                for i in range(depth)
             ]
         )
         self.merger = Qwen2_5_VisionPatchMerger(
@@ -295,6 +323,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
             context_dim=hidden_size,
             spatial_merge_size=spatial_merge_size,
             quant_config=quant_config,
+            prefix=add_prefix("merger", prefix),
         )

     def get_window_index(self, grid_thw):
@@ -447,6 +476,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         self,
         config: Qwen2VLConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
@@ -457,15 +487,23 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
             # NOTE: Qwen2-VL vision encoder does not support any
             # quantization method now.
             quant_config=None,
+            prefix=add_prefix("visual", prefix),
         )
-        self.model = Qwen2Model(config, quant_config)
+        self.model = Qwen2Model(
+            config,
+            quant_config,
+            prefix=add_prefix("model", prefix),
+        )
         if config.tie_word_embeddings:
             self.lm_head = self.model.embed_tokens
         else:
             self.lm_head = ParallelLMHead(
-                config.vocab_size, config.hidden_size, quant_config=quant_config
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix=add_prefix("lm_head", prefix),
             )
         self.logits_processor = LogitsProcessor(config)
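With the prefixes above, the Qwen2.5-VL weights resolve to distinct fully qualified names for the vision and language towers, for example `visual.blocks.0.mlp.gate_proj` versus `model.layers.0.mlp.gate_up_proj` (assuming `add_prefix` joins names with a dot, as sketched earlier), so name-based rules cannot accidentally cross between them. Note that the vision transformer is still constructed with `quant_config=None`, so it stays unquantized regardless of any dynamic rules.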
Two additional file diffs are collapsed.
python/sglang/srt/models/qwen2_rm.py

@@ -22,6 +22,7 @@ from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM, Qwen2Model
+from sglang.srt.utils import add_prefix

 class Qwen2ForRewardModel(nn.Module):
@@ -29,12 +30,15 @@ class Qwen2ForRewardModel(nn.Module):
         self,
         config: Qwen2Config,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
         self.num_labels = 1
-        self.model = Qwen2Model(config, quant_config=quant_config)
+        self.model = Qwen2Model(
+            config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+        )
         self.score = nn.Sequential(
             nn.Linear(config.hidden_size, config.hidden_size),
             nn.ReLU(),
Eight additional file diffs are collapsed.