Unverified commit 56a724eb authored by Qubitium-ModelCloud, committed by GitHub

[QUANT] Add GPTQModel Dynamic Quantization + `lm_head` Quantization (#3790)


Signed-off-by: ZX-ModelCloud <zx@modelcloud.ai>
Co-authored-by: ZX-ModelCloud <zx@modelcloud.ai>
parent 583d6af7
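The diff below threads a prefix argument through every model constructor and builds it with a new add_prefix helper instead of f-strings, so each quantizable submodule knows its fully qualified name (e.g. "model.layers.0.self_attn.qkv_proj", "lm_head"). The helper itself is not shown in this excerpt; the sketch below is an assumption about its behavior, inferred purely from how it is called in the diff (join with a dot, but drop the dot when the parent prefix is empty):

# Hypothetical sketch of sglang.srt.utils.add_prefix, inferred from usage in this diff.
def add_prefix(name: str, prefix: str) -> str:
    """Qualify `name` with `prefix`: add_prefix("qkv_proj", "model.layers.0.self_attn")
    presumably yields "model.layers.0.self_attn.qkv_proj"; an empty prefix yields `name`."""
    return name if not prefix else f"{prefix}.{name}"

With this, a call such as GraniteDecoderLayer(..., prefix=add_prefix(f"layers.{i}", prefix)) produces the same dotted names that appear in a checkpoint's state dict, which is what per-module quantization rules can match against.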
@@ -42,6 +42,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix
from sglang.utils import get_exception_traceback
logger = logging.getLogger(__name__)
@@ -62,14 +63,14 @@ class GraniteMLP(nn.Module):
[intermediate_size] * 2,
bias=False,
quant_config=quant_config,
-prefix=f"{prefix}.gate_up_proj",
+prefix=add_prefix("gate_up_proj", prefix),
)
self.down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
-prefix=f"{prefix}.down_proj",
+prefix=add_prefix("down_proj", prefix),
)
if hidden_act != "silu":
raise ValueError(
@@ -133,14 +134,14 @@ class GraniteAttention(nn.Module):
self.total_num_kv_heads,
bias=False,
quant_config=quant_config,
-prefix=f"{prefix}.qkv_proj",
+prefix=add_prefix("qkv_proj", prefix),
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
-prefix=f"{prefix}.o_proj",
+prefix=add_prefix("o_proj", prefix),
)
self.rotary_emb = get_rope(
@@ -157,6 +158,7 @@ class GraniteAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
layer_id=layer_id,
+prefix=add_prefix("attn", prefix),
)
def forward(
@@ -205,14 +207,14 @@ class GraniteDecoderLayer(nn.Module):
rope_is_neox_style=rope_is_neox_style,
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
-prefix=f"{prefix}.self_attn",
+prefix=add_prefix("self_attn", prefix),
)
self.mlp = GraniteMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
-prefix=f"{prefix}.mlp",
+prefix=add_prefix("mlp", prefix),
)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(
@@ -252,6 +254,7 @@ class GraniteModel(nn.Module):
self,
config: GraniteConfig,
quant_config: Optional[QuantizationConfig] = None,
+prefix: str = "",
) -> None:
super().__init__()
self.config = config
@@ -263,7 +266,10 @@ class GraniteModel(nn.Module):
self.layers = nn.ModuleList(
[
GraniteDecoderLayer(
-config, i, quant_config=quant_config, prefix=f"model.layers.{i}"
+config,
+i,
+quant_config=quant_config,
+prefix=add_prefix(f"layers.{i}", prefix),
)
for i in range(config.num_hidden_layers)
]
@@ -300,17 +306,23 @@ class GraniteForCausalLM(nn.Module):
self,
config: GraniteConfig,
quant_config: Optional[QuantizationConfig] = None,
+prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.quant_config = quant_config
-self.model = GraniteModel(config, quant_config=quant_config)
+self.model = GraniteModel(
+config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+)
# If tie_word_embeddings == True, then input and output embeddings are
# the same tensor. Enforce during object creation so that weights will
# load correctly even if the LM head weights don't have a separate entry
# in the state dict.
self.lm_head = ParallelLMHead(
-config.vocab_size, config.hidden_size, quant_config=quant_config
+config.vocab_size,
+config.hidden_size,
+quant_config=quant_config,
+prefix=add_prefix("lm_head", prefix),
)
if self.config.tie_word_embeddings:
self.lm_head.tie_weights(self.model.embed_tokens)
...
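The commit title refers to GPTQModel-style "dynamic" quantization, i.e. per-module overrides keyed on module names, plus optional quantization of lm_head. The prefixes added above are what such rules match against. A minimal, self-contained sketch of the matching idea follows; the rule syntax and fields are illustrative assumptions, not necessarily GPTQModel's exact schema:

import re
from typing import Optional

# Hypothetical per-module override table: regex over the fully qualified
# module name -> override dict, or None to skip quantization for that module.
DYNAMIC_RULES = {
    r"^lm_head$": {"bits": 8, "group_size": 32},   # quantize the LM head
    r"\.mlp\.down_proj$": None,                    # leave down_proj unquantized
    r"^model\.layers\.(0|1)\.": {"bits": 8},       # different settings for the first blocks
}

def resolve_override(prefix: str) -> Optional[dict]:
    """Return the override for `prefix`: a dict of overrides, {} for
    'use the global settings', or None for 'skip this module'."""
    for pattern, override in DYNAMIC_RULES.items():
        if re.search(pattern, prefix):
            return override
    return {}

print(resolve_override("model.layers.0.self_attn.qkv_proj"))  # {'bits': 8}
print(resolve_override("model.layers.5.mlp.down_proj"))       # None (skipped)
print(resolve_override("lm_head"))                            # {'bits': 8, 'group_size': 32}

Without the prefix plumbing, quantized layers would only see local names like "gate_up_proj", and rules like these could not distinguish layer 0 from layer 5, or lm_head from any other projection.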
@@ -47,6 +47,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.loader import DefaultModelLoader
from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix
class Grok1MLP(nn.Module):
@@ -65,7 +66,7 @@ class Grok1MLP(nn.Module):
[intermediate_size] * 2,
bias=False,
quant_config=quant_config,
-prefix=f"{prefix}.gate_up_proj",
+prefix=add_prefix("gate_up_proj", prefix),
use_presharded_weights=use_presharded_weights,
)
self.down_proj = RowParallelLinear(
@@ -73,7 +74,7 @@ class Grok1MLP(nn.Module):
hidden_size,
bias=False,
quant_config=quant_config,
-prefix=f"{prefix}.down_proj",
+prefix=add_prefix("down_proj", prefix),
reduce_results=reduce_results,
use_presharded_weights=use_presharded_weights,
)
@@ -107,6 +108,7 @@ class Grok1MoE(nn.Module):
tp_size: Optional[int] = None,
reduce_results=True,
use_presharded_weights: bool = False,
+prefix: str = "",
):
super().__init__()
self.hidden_size = hidden_size
@@ -118,6 +120,7 @@ class Grok1MoE(nn.Module):
bias=False,
params_dtype=params_dtype,
quant_config=None,
+prefix=add_prefix("gate", prefix),
)
self.router_logit_softcapping = getattr(
@@ -135,6 +138,7 @@ class Grok1MoE(nn.Module):
tp_size=tp_size,
activation="gelu",
use_presharded_weights=use_presharded_weights,
+prefix=add_prefix("experts", prefix),
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -163,6 +167,7 @@ class Grok1Attention(nn.Module):
rope_theta: float = 10000,
quant_config: Optional[QuantizationConfig] = None,
reduce_results: bool = True,
+prefix: str = "",
) -> None:
super().__init__()
self.config = config
@@ -195,6 +200,7 @@ class Grok1Attention(nn.Module):
self.total_num_kv_heads,
bias=False,
quant_config=quant_config,
+prefix=add_prefix("qkv_proj", prefix),
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
@@ -202,6 +208,7 @@ class Grok1Attention(nn.Module):
bias=False,
quant_config=quant_config,
reduce_results=reduce_results,
+prefix=add_prefix("o_proj", prefix),
)
self.rotary_emb = get_rope(
self.head_dim,
@@ -220,6 +227,7 @@ class Grok1Attention(nn.Module):
num_kv_heads=self.num_kv_heads,
layer_id=layer_id,
logit_cap=logit_cap,
+prefix=add_prefix("attn", prefix),
)
def forward(
@@ -243,6 +251,7 @@ class Grok1DecoderLayer(nn.Module):
layer_id: int = 0,
quant_config: Optional[QuantizationConfig] = None,
use_presharded_weights: bool = False,
+prefix: str = "",
) -> None:
super().__init__()
self.num_experts = config.num_local_experts
@@ -259,6 +268,7 @@ class Grok1DecoderLayer(nn.Module):
layer_id=layer_id,
rope_theta=rope_theta,
quant_config=quant_config,
+prefix=add_prefix("attn", prefix),
)
self.block_sparse_moe = Grok1MoE(
config=config,
@@ -273,6 +283,7 @@ class Grok1DecoderLayer(nn.Module):
quant_config=quant_config,
reduce_results=True,
use_presharded_weights=use_presharded_weights,
+prefix=add_prefix("block_sparse_moe", prefix),
)
self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -311,6 +322,7 @@ class Grok1Model(nn.Module):
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
use_presharded_weights: bool = False,
+prefix: str = "",
) -> None:
super().__init__()
self.config = config
@@ -320,6 +332,7 @@ class Grok1Model(nn.Module):
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
+prefix=add_prefix("embed_tokens", prefix),
)
self.layers = nn.ModuleList(
[
@@ -328,6 +341,7 @@ class Grok1Model(nn.Module):
i,
quant_config=quant_config,
use_presharded_weights=use_presharded_weights,
+prefix=add_prefix(f"layers.{i}", prefix),
)
for i in range(config.num_hidden_layers)
]
@@ -359,6 +373,7 @@ class Grok1ForCausalLM(nn.Module):
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
+prefix: str = "",
) -> None:
super().__init__()
self.config = config
@@ -377,8 +392,11 @@ class Grok1ForCausalLM(nn.Module):
config,
quant_config=quant_config,
use_presharded_weights=self.use_presharded_weights,
+prefix=add_prefix("model", prefix),
+)
+self.lm_head = ParallelLMHead(
+config.vocab_size, config.hidden_size, prefix=add_prefix("lm_head", prefix)
)
-self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)
def forward(
...
@@ -38,6 +38,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix
class InternLM2MLP(nn.Module):
@@ -47,13 +48,22 @@ class InternLM2MLP(nn.Module):
intermediate_size: int,
hidden_act: str,
quant_config: Optional[QuantizationConfig] = None,
+prefix: str = "",
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
-hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
+hidden_size,
+[intermediate_size] * 2,
+bias=False,
+quant_config=quant_config,
+prefix=add_prefix("gate_up_proj", prefix),
)
self.w2 = RowParallelLinear(
-intermediate_size, hidden_size, bias=False, quant_config=quant_config
+intermediate_size,
+hidden_size,
+bias=False,
+quant_config=quant_config,
+prefix=add_prefix("w2", prefix),
)
if hidden_act != "silu":
raise ValueError(
@@ -80,6 +90,7 @@ class InternLM2Attention(nn.Module):
max_position_embeddings: int = 8192,
layer_id: int = 0,
quant_config: Optional[QuantizationConfig] = None,
+prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = hidden_size
@@ -111,12 +122,14 @@ class InternLM2Attention(nn.Module):
self.total_num_kv_heads,
bias=False,
quant_config=quant_config,
+prefix=add_prefix("wqkv", prefix),
)
self.wo = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
+prefix=add_prefix("wo", prefix),
)
self.rotary_emb = get_rope(
@@ -127,7 +140,12 @@ class InternLM2Attention(nn.Module):
rope_scaling=rope_scaling,
)
self.attn = RadixAttention(
-self.num_heads, self.head_dim, self.scaling, self.num_kv_heads, layer_id
+self.num_heads,
+self.head_dim,
+self.scaling,
+self.num_kv_heads,
+layer_id,
+prefix=add_prefix("attn", prefix),
)
def forward(
@@ -150,6 +168,7 @@ class InternLMDecoderLayer(nn.Module):
config: PretrainedConfig,
layer_id: int = 0,
quant_config: Optional[QuantizationConfig] = None,
+prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
@@ -165,12 +184,14 @@ class InternLMDecoderLayer(nn.Module):
max_position_embeddings=max_position_embeddings,
layer_id=layer_id,
quant_config=quant_config,
+prefix=add_prefix("attention", prefix),
)
self.feed_forward = InternLM2MLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
+prefix=add_prefix("feed_forward", prefix),
)
self.attention_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.ffn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -205,6 +226,7 @@ class InternLM2Model(nn.Module):
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
+prefix: str = "",
) -> None:
super().__init__()
self.config = config
@@ -213,10 +235,13 @@ class InternLM2Model(nn.Module):
self.tok_embeddings = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
+prefix=add_prefix("tok_embeddings", prefix),
)
self.layers = nn.ModuleList(
[
-InternLMDecoderLayer(config, i, quant_config)
+InternLMDecoderLayer(
+config, i, quant_config, prefix=add_prefix(f"layers.{i}", prefix)
+)
for i in range(config.num_hidden_layers)
]
)
@@ -251,12 +276,17 @@ class InternLM2ForCausalLM(nn.Module):
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
+prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.quant_config = quant_config
-self.model = InternLM2Model(config, quant_config)
-self.output = ParallelLMHead(config.vocab_size, config.hidden_size)
+self.model = InternLM2Model(
+config, quant_config, prefix=add_prefix("model", prefix)
+)
+self.output = ParallelLMHead(
+config.vocab_size, config.hidden_size, prefix=add_prefix("output", prefix)
+)
self.logits_processor = LogitsProcessor(config)
@torch.no_grad()
...
@@ -22,6 +22,7 @@ from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.models.internlm2 import InternLM2ForCausalLM, InternLM2Model
+from sglang.srt.utils import add_prefix
class InternLM2ForRewardModel(nn.Module):
@@ -29,12 +30,15 @@ class InternLM2ForRewardModel(nn.Module):
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
+prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.quant_config = quant_config
self.vocab_size = config.vocab_size
-self.model = InternLM2Model(config, quant_config)
+self.model = InternLM2Model(
+config, quant_config, prefix=add_prefix("model", prefix)
+)
self.v_head = nn.Linear(config.hidden_size, 1, bias=False)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=False)
...
@@ -49,7 +49,7 @@ from sglang.srt.model_loader.weight_utils import (
kv_cache_scales_loader,
maybe_remap_kv_scale_name,
)
-from sglang.srt.utils import make_layers
+from sglang.srt.utils import add_prefix, make_layers
from sglang.utils import get_exception_traceback
logger = logging.getLogger(__name__)
@@ -70,14 +70,14 @@ class LlamaMLP(nn.Module):
[intermediate_size] * 2,
bias=False,
quant_config=quant_config,
-prefix=f"{prefix}.gate_up_proj",
+prefix=add_prefix("gate_up_proj", prefix),
)
self.down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
-prefix=f"{prefix}.down_proj",
+prefix=add_prefix("down_proj", prefix),
)
if hidden_act != "silu":
raise ValueError(
@@ -142,14 +142,14 @@ class LlamaAttention(nn.Module):
self.total_num_kv_heads,
bias=bias,
quant_config=quant_config,
-prefix=f"{prefix}.qkv_proj",
+prefix=add_prefix("qkv_proj", prefix),
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=bias,
quant_config=quant_config,
-prefix=f"{prefix}.o_proj",
+prefix=add_prefix("o_proj", prefix),
)
self.rotary_emb = get_rope(
@@ -166,6 +166,7 @@ class LlamaAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
layer_id=layer_id,
+prefix=add_prefix("attn", prefix),
)
def forward(
@@ -218,7 +219,7 @@ class LlamaDecoderLayer(nn.Module):
rope_is_neox_style=rope_is_neox_style,
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
-prefix=f"{prefix}.self_attn",
+prefix=add_prefix("self_attn", prefix),
bias=attention_bias,
)
self.mlp = LlamaMLP(
@@ -226,7 +227,7 @@ class LlamaDecoderLayer(nn.Module):
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
-prefix=f"{prefix}.mlp",
+prefix=add_prefix("mlp", prefix),
)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(
@@ -263,6 +264,7 @@ class LlamaModel(nn.Module):
self,
config: LlamaConfig,
quant_config: Optional[QuantizationConfig] = None,
+prefix: str = "",
) -> None:
super().__init__()
self.config = config
@@ -272,6 +274,7 @@ class LlamaModel(nn.Module):
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
+prefix=add_prefix("embed_tokens", prefix),
)
self.layers = make_layers(
config.num_hidden_layers,
@@ -358,18 +361,24 @@ class LlamaForCausalLM(nn.Module):
self,
config: LlamaConfig,
quant_config: Optional[QuantizationConfig] = None,
+prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.quant_config = quant_config
-self.model = LlamaModel(config, quant_config=quant_config)
+self.model = LlamaModel(
+config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+)
# Llama 3.2 1B Instruct set tie_word_embeddings to True
# Llama 3.1 8B Instruct set tie_word_embeddings to False
if self.config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
self.lm_head = ParallelLMHead(
-config.vocab_size, config.hidden_size, quant_config=quant_config
+config.vocab_size,
+config.hidden_size,
+quant_config=quant_config,
+prefix=add_prefix("lm_head", prefix),
)
self.logits_processor = LogitsProcessor(config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
...
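The same pattern repeats in every file below: the top-level wrapper passes prefix=add_prefix("model", prefix) (or "language_model", "lm_head", ...) and each child appends its own attribute name. A tiny runnable stand-in (toy modules, not the sglang classes) shows that this reproduces exactly the dotted names PyTorch itself assigns, so the prefixes line up with checkpoint keys:

import torch.nn as nn

class ToyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.qkv_proj = nn.Linear(8, 24)   # stands in for QKVParallelLinear
        self.o_proj = nn.Linear(8, 8)      # stands in for RowParallelLinear

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # "model.layers.<i>.<submodule>" mirrors the prefixes built in this diff.
        self.model = nn.ModuleDict(
            {"layers": nn.ModuleList([ToyBlock() for _ in range(2)])}
        )
        self.lm_head = nn.Linear(8, 32)

for name, _ in ToyModel().named_modules():
    if name:
        print(name)  # model, model.layers, model.layers.0, model.layers.0.qkv_proj, ..., lm_head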
@@ -23,6 +23,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel
+from sglang.srt.utils import add_prefix
class LlamaForClassification(nn.Module):
@@ -30,11 +31,14 @@ class LlamaForClassification(nn.Module):
self,
config: LlamaConfig,
quant_config: Optional[QuantizationConfig] = None,
+prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.quant_config = quant_config
-self.model = LlamaModel(config, quant_config=quant_config)
+self.model = LlamaModel(
+config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+)
self.classification_head = nn.Linear(
config.hidden_size, config.classification_out_size, bias=False
...
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
+from sglang.srt.utils import add_prefix
# Adapted from
# https://github.com/SafeAILab/EAGLE/blob/main/eagle/model/cnets.py
"""Inference-only LLaMA-EAGLE model compatible with HuggingFace weights."""
@@ -55,6 +57,7 @@ class LlamaModel(nn.Module):
self,
config: LlamaConfig,
quant_config: Optional[QuantizationConfig] = None,
+prefix: str = "",
) -> None:
super().__init__()
self.config = config
@@ -62,11 +65,15 @@ class LlamaModel(nn.Module):
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
+prefix=add_prefix("embed_tokens", prefix),
)
self.layers = nn.ModuleList(
[
LlamaDecoderLayer(
-config, i, quant_config=quant_config, prefix=f"model.layers.{i}"
+config,
+i,
+quant_config=quant_config,
+prefix=add_prefix(f"layers.{i}", prefix),
)
for i in range(config.num_hidden_layers)
]
@@ -106,24 +113,26 @@ class LlamaForCausalLMEagle(LlamaForCausalLM):
self,
config: LlamaConfig,
quant_config: Optional[QuantizationConfig] = None,
+prefix: str = "",
) -> None:
nn.Module.__init__(self)
self.config = config
self.quant_config = quant_config
-self.model = LlamaModel(config, quant_config=quant_config)
+self.model = LlamaModel(
+config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+)
# Llama 3.2 1B Instruct set tie_word_embeddings to True
# Llama 3.1 8B Instruct set tie_word_embeddings to False
if self.config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
-if hasattr(config, "hot_vocab_size"):
-self.lm_head = ParallelLMHead(
-config.hot_vocab_size, config.hidden_size, quant_config=quant_config
-)
-else:
-self.lm_head = ParallelLMHead(
-config.vocab_size, config.hidden_size, quant_config=quant_config
-)
+self.lm_head = ParallelLMHead(
+getattr(config, "hot_vocab_size", config.vocab_size),
+config.hidden_size,
+quant_config=quant_config,
+prefix=add_prefix("lm_head", prefix),
+)
self.logits_processor = LogitsProcessor(config)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
...
@@ -8,6 +8,7 @@ from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
from sglang.srt.model_executor.model_runner import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.llama import LlamaModel
+from sglang.srt.utils import add_prefix
class LlamaEmbeddingModel(nn.Module):
@@ -15,9 +16,12 @@ class LlamaEmbeddingModel(nn.Module):
self,
config: LlamaConfig,
quant_config=None,
+prefix: str = "",
) -> None:
super().__init__()
-self.model = LlamaModel(config, quant_config=quant_config)
+self.model = LlamaModel(
+config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
@torch.no_grad()
...
@@ -22,6 +22,7 @@ from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel
+from sglang.srt.utils import add_prefix
class LlamaForSequenceClassification(nn.Module):
@@ -29,12 +30,15 @@ class LlamaForSequenceClassification(nn.Module):
self,
config: LlamaConfig,
quant_config: Optional[QuantizationConfig] = None,
+prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.quant_config = quant_config
self.num_labels = config.num_labels
-self.model = LlamaModel(config, quant_config=quant_config)
+self.model = LlamaModel(
+config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+)
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=False)
@@ -82,8 +86,9 @@ class LlamaForSequenceClassificationWithNormal_Weights(LlamaForSequenceClassific
self,
config: LlamaConfig,
quant_config: Optional[QuantizationConfig] = None,
+prefix: str = "",
) -> None:
-super().__init__(config, quant_config)
+super().__init__(config, quant_config, prefix=prefix)
self.weights = self.Weights(config.hidden_size, self.num_labels)
@torch.no_grad()
...
@@ -42,6 +42,7 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.llama import LlamaForCausalLM
from sglang.srt.models.mistral import MistralForCausalLM
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
+from sglang.srt.utils import add_prefix
class LlavaBaseForCausalLM(nn.Module):
@@ -475,6 +476,7 @@ class LlavaLlamaForCausalLM(LlavaBaseForCausalLM):
self,
config: LlavaConfig,
quant_config: Optional[QuantizationConfig] = None,
+prefix: str = "",
) -> None:
super().__init__()
@@ -484,7 +486,11 @@ class LlavaLlamaForCausalLM(LlavaBaseForCausalLM):
self.config.text_config.hidden_size = config.hidden_size
self.multi_modal_projector = LlavaMultiModalProjector(config)
-self.language_model = LlamaForCausalLM(config, quant_config=quant_config)
+self.language_model = LlamaForCausalLM(
+config,
+quant_config=quant_config,
+prefix=add_prefix("language_model", prefix),
+)
if "unpad" in getattr(config, "mm_patch_merge_type", ""):
self.language_model.model.image_newline = nn.Parameter(
torch.empty(config.text_config.hidden_size, dtype=torch.float16)
@@ -496,6 +502,7 @@ class LlavaQwenForCausalLM(LlavaBaseForCausalLM):
self,
config: LlavaConfig,
quant_config: Optional[QuantizationConfig] = None,
+prefix: str = "",
) -> None:
super().__init__()
@@ -516,7 +523,11 @@ class LlavaQwenForCausalLM(LlavaBaseForCausalLM):
self.config.image_token_index = 151646
self.multi_modal_projector = LlavaMultiModalProjector(config)
-self.language_model = Qwen2ForCausalLM(config, quant_config=quant_config)
+self.language_model = Qwen2ForCausalLM(
+config,
+quant_config=quant_config,
+prefix=add_prefix("language_model", prefix),
+)
if "unpad" in getattr(config, "mm_patch_merge_type", ""):
self.language_model.model.image_newline = nn.Parameter(
torch.empty(config.text_config.hidden_size, dtype=torch.float16)
@@ -528,6 +539,7 @@ class LlavaMistralForCausalLM(LlavaBaseForCausalLM):
self,
config: LlavaConfig,
quant_config: Optional[QuantizationConfig] = None,
+prefix: str = "",
) -> None:
super().__init__()
@@ -548,7 +560,11 @@ class LlavaMistralForCausalLM(LlavaBaseForCausalLM):
self.config.image_token_index = 32000
self.multi_modal_projector = LlavaMultiModalProjector(config)
-self.language_model = MistralForCausalLM(config, quant_config=quant_config)
+self.language_model = MistralForCausalLM(
+config,
+quant_config=quant_config,
+prefix=add_prefix("language_model", prefix),
+)
if "unpad" in getattr(config, "mm_patch_merge_type", ""):
self.language_model.model.image_newline = nn.Parameter(
torch.empty(config.text_config.hidden_size, dtype=torch.float16)
...
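In these multimodal wrappers the whole language model is itself nested under a prefix, so the inner projections pick up one more leading path component. Assuming the hypothetical add_prefix sketch near the top of this page, chaining the calls the way the LLaVA constructors do would yield names such as:

# Reuses the hypothetical add_prefix(name, prefix) sketch defined earlier on this page.
prefix = ""
for part in ("language_model", "model", "layers.0", "self_attn", "qkv_proj"):
    prefix = add_prefix(part, prefix)
print(prefix)  # language_model.model.layers.0.self_attn.qkv_proj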
@@ -26,6 +26,7 @@ from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.llama import LlamaForCausalLM
+from sglang.srt.utils import add_prefix
class LlavaVidForCausalLM(nn.Module):
@@ -33,6 +34,7 @@ class LlavaVidForCausalLM(nn.Module):
self,
config: LlavaConfig,
quant_config: Optional[QuantizationConfig] = None,
+prefix: str = "",
) -> None:
super().__init__()
self.config = config
@@ -44,7 +46,11 @@ class LlavaVidForCausalLM(nn.Module):
self.resampler = nn.AvgPool2d(
kernel_size=self.mm_spatial_pool_stride, stride=self.mm_spatial_pool_stride
)
-self.language_model = LlamaForCausalLM(config, quant_config=quant_config)
+self.language_model = LlamaForCausalLM(
+config,
+quant_config=quant_config,
+prefix=add_prefix("language_model", prefix),
+)
self.num_frames = getattr(self.config, "num_frames", 16)
if "unpad" in getattr(config, "mm_patch_merge_type", ""):
self.language_model.model.image_newline = nn.Parameter(
...
@@ -37,6 +37,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix
class MiniCPMMLP(nn.Module):
@@ -46,6 +47,7 @@ class MiniCPMMLP(nn.Module):
intermediate_size: int,
hidden_act: str,
quant_config: Optional[QuantizationConfig] = None,
+prefix: str = "",
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
@@ -53,12 +55,14 @@ class MiniCPMMLP(nn.Module):
[intermediate_size] * 2,
bias=False,
quant_config=quant_config,
+prefix=add_prefix("gate_up_proj", prefix),
)
self.down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
+prefix=add_prefix("down_proj", prefix),
)
if hidden_act != "silu":
raise ValueError(
@@ -85,6 +89,7 @@ class MiniCPMAttention(nn.Module):
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None,
+prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = hidden_size
@@ -116,12 +121,14 @@ class MiniCPMAttention(nn.Module):
self.total_num_kv_heads,
bias=False,
quant_config=quant_config,
+prefix=add_prefix("qkv_proj", prefix),
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
+prefix=add_prefix("o_proj", prefix),
)
self.rotary_emb = get_rope(
@@ -139,6 +146,7 @@ class MiniCPMAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
layer_id=layer_id,
+prefix=add_prefix("attn", prefix),
)
def forward(
@@ -164,6 +172,7 @@ class MiniCPMDecoderLayer(nn.Module):
config,
layer_id: int = 0,
quant_config: Optional[QuantizationConfig] = None,
+prefix: str = "",
) -> None:
super().__init__()
self.config = config
@@ -180,12 +189,14 @@ class MiniCPMDecoderLayer(nn.Module):
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
+prefix=add_prefix("self_attn", prefix),
)
self.mlp = MiniCPMMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
+prefix=add_prefix("mlp", prefix),
)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(
@@ -227,6 +238,7 @@ class MiniCPMModel(nn.Module):
self,
config,
quant_config: Optional[QuantizationConfig] = None,
+prefix: str = "",
) -> None:
super().__init__()
self.config = config
@@ -236,10 +248,16 @@ class MiniCPMModel(nn.Module):
self.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
+prefix=add_prefix("embed_tokens", prefix),
)
self.layers = nn.ModuleList(
[
-MiniCPMDecoderLayer(config, i, quant_config=quant_config)
+MiniCPMDecoderLayer(
+config,
+i,
+quant_config=quant_config,
+prefix=add_prefix(f"layers.{i}", prefix),
+)
for i in range(config.num_hidden_layers)
]
)
@@ -275,19 +293,23 @@ class MiniCPMForCausalLM(nn.Module):
self,
config,
quant_config: Optional[QuantizationConfig] = None,
+prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.num_experts = getattr(self.config, "num_experts", 0)
self.quant_config = quant_config
-self.model = MiniCPMModel(config, quant_config=quant_config)
+self.model = MiniCPMModel(
+config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+)
# self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
if not self.config.tie_word_embeddings:
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
+prefix=add_prefix("lm_head", prefix),
)
self.scale_width = self.config.hidden_size / self.config.dim_model_base
...
@@ -40,7 +40,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import is_cuda_available
+from sglang.srt.utils import add_prefix, is_cuda_available
if is_cuda_available():
from sgl_kernel import bmm_fp8
@@ -53,6 +53,7 @@ class MiniCPM3MLP(nn.Module):
intermediate_size: int,
hidden_act: str,
quant_config: Optional[QuantizationConfig] = None,
+prefix: str = "",
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
@@ -60,12 +61,14 @@ class MiniCPM3MLP(nn.Module):
[intermediate_size] * 2,
bias=False,
quant_config=quant_config,
+prefix=add_prefix("gate_up_proj", prefix),
)
self.down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
+prefix=add_prefix("down_proj", prefix),
)
if hidden_act != "silu":
raise ValueError(
@@ -107,6 +110,7 @@ class MiniCPM3Attention(nn.Module):
max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None,
layer_id=None,
+prefix: str = "",
) -> None:
super().__init__()
self.layer_id = layer_id
@@ -131,6 +135,7 @@ class MiniCPM3Attention(nn.Module):
self.q_lora_rank,
bias=False,
quant_config=quant_config,
+prefix=add_prefix("q_a_proj", prefix),
)
self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
self.q_b_proj = ColumnParallelLinear(
@@ -138,6 +143,7 @@ class MiniCPM3Attention(nn.Module):
self.num_heads * self.qk_head_dim,
bias=False,
quant_config=quant_config,
+prefix=add_prefix("q_b_proj", prefix),
)
else:
self.q_proj = ColumnParallelLinear(
@@ -145,6 +151,7 @@ class MiniCPM3Attention(nn.Module):
self.num_heads * self.qk_head_dim,
bias=False,
quant_config=quant_config,
+prefix=add_prefix("q_proj", prefix),
)
self.kv_a_proj_with_mqa = ReplicatedLinear(
@@ -152,6 +159,7 @@ class MiniCPM3Attention(nn.Module):
self.kv_lora_rank + self.qk_rope_head_dim,
bias=False,
quant_config=quant_config,
+prefix=add_prefix("kv_a_proj_with_mqa", prefix),
)
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
self.kv_b_proj = ColumnParallelLinear(
@@ -159,6 +167,7 @@ class MiniCPM3Attention(nn.Module):
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
bias=False,
quant_config=quant_config,
+prefix=add_prefix("kv_b_proj", prefix),
)
# O projection.
self.o_proj = RowParallelLinear(
@@ -166,6 +175,7 @@ class MiniCPM3Attention(nn.Module):
self.hidden_size,
bias=False,
quant_config=quant_config,
+prefix=add_prefix("o_proj", prefix),
)
self.rotary_emb = get_rope(
qk_rope_head_dim,
@@ -182,6 +192,7 @@ class MiniCPM3Attention(nn.Module):
self.scaling,
num_kv_heads=self.num_local_heads,
layer_id=layer_id,
+prefix=add_prefix("attn", prefix),
)
def forward(
@@ -250,6 +261,7 @@ class MiniCPM3AttentionMLA(nn.Module):
max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None,
layer_id=None,
+prefix: str = "",
) -> None:
super().__init__()
self.layer_id = layer_id
@@ -274,6 +286,7 @@ class MiniCPM3AttentionMLA(nn.Module):
self.q_lora_rank,
bias=False,
quant_config=quant_config,
+prefix=add_prefix("q_a_proj", prefix),
)
self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
self.q_b_proj = ColumnParallelLinear(
@@ -281,6 +294,7 @@ class MiniCPM3AttentionMLA(nn.Module):
self.num_heads * self.qk_head_dim,
bias=False,
quant_config=quant_config,
+prefix=add_prefix("q_b_proj", prefix),
)
else:
self.q_proj = ColumnParallelLinear(
@@ -288,6 +302,7 @@ class MiniCPM3AttentionMLA(nn.Module):
self.num_heads * self.qk_head_dim,
bias=False,
quant_config=quant_config,
+prefix=add_prefix("q_proj", prefix),
)
self.kv_a_proj_with_mqa = ReplicatedLinear(
@@ -295,6 +310,7 @@ class MiniCPM3AttentionMLA(nn.Module):
self.kv_lora_rank + self.qk_rope_head_dim,
bias=False,
quant_config=quant_config,
+prefix=add_prefix("kv_a_proj_with_mqa", prefix),
)
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
self.kv_b_proj = ColumnParallelLinear(
@@ -302,6 +318,7 @@ class MiniCPM3AttentionMLA(nn.Module):
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
bias=False,
quant_config=quant_config,
+prefix=add_prefix("kv_b_proj", prefix),
)
# O projection.
self.o_proj = RowParallelLinear(
@@ -309,6 +326,7 @@ class MiniCPM3AttentionMLA(nn.Module):
self.hidden_size,
bias=False,
quant_config=quant_config,
+prefix=add_prefix("o_proj", prefix),
)
self.rotary_emb = get_rope(
qk_rope_head_dim,
@@ -325,6 +343,7 @@ class MiniCPM3AttentionMLA(nn.Module):
num_kv_heads=1,
layer_id=layer_id,
v_head_dim=self.kv_lora_rank,
+prefix=add_prefix("attn", prefix),
)
self.w_kc = None
@@ -405,6 +424,7 @@ class MiniCPM3DecoderLayer(nn.Module):
config: PretrainedConfig,
layer_id: int,
quant_config: Optional[QuantizationConfig] = None,
+prefix: str = "",
) -> None:
super().__init__()
self.config = config
@@ -429,6 +449,7 @@ class MiniCPM3DecoderLayer(nn.Module):
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
layer_id=layer_id,
+prefix=add_prefix("self_attn", prefix),
)
else:
self.self_attn = MiniCPM3Attention(
@@ -447,12 +468,14 @@ class MiniCPM3DecoderLayer(nn.Module):
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
layer_id=layer_id,
+prefix=add_prefix("self_attn", prefix),
)
self.mlp = MiniCPM3MLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
+prefix=add_prefix("mlp", prefix),
)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(
@@ -494,6 +517,7 @@ class MiniCPM3Model(nn.Module):
self, self,
config: PretrainedConfig, config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -503,10 +527,16 @@ class MiniCPM3Model(nn.Module): ...@@ -503,10 +527,16 @@ class MiniCPM3Model(nn.Module):
self.vocab_size, self.vocab_size,
config.hidden_size, config.hidden_size,
org_num_embeddings=config.vocab_size, org_num_embeddings=config.vocab_size,
prefix=add_prefix("embed_tokens", prefix),
) )
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
[ [
MiniCPM3DecoderLayer(config, i, quant_config=quant_config) MiniCPM3DecoderLayer(
config,
i,
quant_config=quant_config,
prefix=add_prefix(f"layers.{i}", prefix),
)
for i in range(config.num_hidden_layers) for i in range(config.num_hidden_layers)
] ]
) )
...@@ -542,19 +572,23 @@ class MiniCPM3ForCausalLM(nn.Module): ...@@ -542,19 +572,23 @@ class MiniCPM3ForCausalLM(nn.Module):
self, self,
config: PretrainedConfig, config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.num_experts = getattr(self.config, "num_experts", 0) self.num_experts = getattr(self.config, "num_experts", 0)
self.quant_config = quant_config self.quant_config = quant_config
self.model = MiniCPM3Model(config, quant_config=quant_config) self.model = MiniCPM3Model(
config, quant_config=quant_config, prefix=add_prefix("model", prefix)
)
# self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) # self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
if not self.config.tie_word_embeddings: if not self.config.tie_word_embeddings:
self.lm_head = ParallelLMHead( self.lm_head = ParallelLMHead(
config.vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
org_num_embeddings=config.vocab_size, org_num_embeddings=config.vocab_size,
prefix=add_prefix("lm_head", prefix),
) )
self.scale_width = self.config.hidden_size / self.config.dim_model_base self.scale_width = self.config.hidden_size / self.config.dim_model_base
......
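Note: across these model files the hard-coded `f"{prefix}.child"` strings are replaced by the `add_prefix` helper imported from `sglang.srt.utils`, and every constructor gains a `prefix` argument so the full dotted module path reaches each quantizable layer. A minimal sketch of such a helper is shown below; it is an assumption about its behavior (the real implementation lives in `sglang/srt/utils.py`), not a copy of it.

```python
def add_prefix(name: str, prefix: str) -> str:
    """Join a child module name onto its parent's dotted prefix.

    Assumed behavior: with an empty parent prefix the child name is
    returned unchanged, so top-level modules never get a leading dot.
    """
    return name if not prefix else f"{prefix}.{name}"


# add_prefix("qkv_proj", "model.layers.0.self_attn")
#   -> "model.layers.0.self_attn.qkv_proj"
# add_prefix("model", "") -> "model"
```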
...@@ -56,6 +56,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch ...@@ -56,6 +56,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.utils import set_default_torch_dtype from sglang.srt.model_loader.utils import set_default_torch_dtype
from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen2 import Qwen2Config, Qwen2ForCausalLM from sglang.srt.models.qwen2 import Qwen2Config, Qwen2ForCausalLM
from sglang.srt.utils import add_prefix
RawImageType = Union[Image.Image, torch.Tensor] RawImageType = Union[Image.Image, torch.Tensor]
...@@ -158,14 +159,14 @@ class Idefics2VisionMLP(nn.Module): ...@@ -158,14 +159,14 @@ class Idefics2VisionMLP(nn.Module):
config.intermediate_size, config.intermediate_size,
bias=True, bias=True,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.fc1", prefix=add_prefix("fc1", prefix),
) )
self.fc2 = RowParallelLinear( self.fc2 = RowParallelLinear(
config.intermediate_size, config.intermediate_size,
config.hidden_size, config.hidden_size,
bias=True, bias=True,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.fc2", prefix=add_prefix("fc2", prefix),
) )
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
...@@ -199,10 +200,14 @@ class Idefics2EncoderLayer(nn.Module): ...@@ -199,10 +200,14 @@ class Idefics2EncoderLayer(nn.Module):
use_context_forward=False, use_context_forward=False,
use_full_precision_softmax=True, use_full_precision_softmax=True,
flatten_batch=False, flatten_batch=False,
prefix=f"{prefix}.self_attn", prefix=add_prefix("self_attn", prefix),
) )
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = Idefics2VisionMLP(config, quant_config=quant_config) self.mlp = Idefics2VisionMLP(
config,
quant_config=quant_config,
prefix=add_prefix("mlp", prefix),
)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
def forward( def forward(
...@@ -242,6 +247,7 @@ class Idefics2Encoder(nn.Module): ...@@ -242,6 +247,7 @@ class Idefics2Encoder(nn.Module):
self, self,
config: PretrainedConfig, config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -251,8 +257,9 @@ class Idefics2Encoder(nn.Module): ...@@ -251,8 +257,9 @@ class Idefics2Encoder(nn.Module):
Idefics2EncoderLayer( Idefics2EncoderLayer(
config, config,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix(f"layers.{i}", prefix),
) )
for _ in range(config.num_hidden_layers) for i in range(config.num_hidden_layers)
] ]
) )
...@@ -379,13 +386,18 @@ class Idefics2VisionTransformer(nn.Module): ...@@ -379,13 +386,18 @@ class Idefics2VisionTransformer(nn.Module):
self, self,
config: PretrainedConfig, config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
embed_dim = config.hidden_size embed_dim = config.hidden_size
self.config = config self.config = config
self.embeddings = Idefics2VisionEmbeddings(config) self.embeddings = Idefics2VisionEmbeddings(config)
self.encoder = Idefics2Encoder(config=config, quant_config=quant_config) self.encoder = Idefics2Encoder(
config=config,
quant_config=quant_config,
prefix=add_prefix("encoder", prefix),
)
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
def get_input_embeddings(self): def get_input_embeddings(self):
...@@ -503,7 +515,7 @@ class BaseResampler(nn.Module): ...@@ -503,7 +515,7 @@ class BaseResampler(nn.Module):
embed_dim, embed_dim,
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.kv_proj", prefix=add_prefix("kv_proj", prefix),
) )
else: else:
# Maintain the same return value with ReplicatedLinear.forward # Maintain the same return value with ReplicatedLinear.forward
...@@ -660,6 +672,7 @@ class MiniCPMVBaseModel(nn.Module): ...@@ -660,6 +672,7 @@ class MiniCPMVBaseModel(nn.Module):
*, *,
config: PretrainedConfig, config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
# All MiniCPM-V models disable `tie_word_embeddings` but # All MiniCPM-V models disable `tie_word_embeddings` but
...@@ -669,8 +682,12 @@ class MiniCPMVBaseModel(nn.Module): ...@@ -669,8 +682,12 @@ class MiniCPMVBaseModel(nn.Module):
self.config = config self.config = config
self.version = get_version_by_config(self.config) self.version = get_version_by_config(self.config)
self.llm = self.init_llm(config=config, quant_config=quant_config) self.llm = self.init_llm(
self.vpm = self.init_vision_module(config, quant_config) config=config, quant_config=quant_config, prefix=add_prefix("llm", prefix)
)
self.vpm = self.init_vision_module(
config, quant_config, add_prefix("vpm", prefix)
)
self.vision_dim = ( self.vision_dim = (
self.vpm.embed_dim self.vpm.embed_dim
if self.version == (2, 0) if self.version == (2, 0)
...@@ -679,7 +696,10 @@ class MiniCPMVBaseModel(nn.Module): ...@@ -679,7 +696,10 @@ class MiniCPMVBaseModel(nn.Module):
self.embed_dim = self.config.hidden_size self.embed_dim = self.config.hidden_size
self.resampler = self.init_resampler( self.resampler = self.init_resampler(
self.embed_dim, self.vision_dim, quant_config=quant_config self.embed_dim,
self.vision_dim,
quant_config=quant_config,
prefix=add_prefix("resampler", prefix),
) )
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
...@@ -937,6 +957,7 @@ class MiniCPMVBaseModel(nn.Module): ...@@ -937,6 +957,7 @@ class MiniCPMVBaseModel(nn.Module):
self, self,
config: Qwen2Config, config: Qwen2Config,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> nn.Module: ) -> nn.Module:
raise NotImplementedError raise NotImplementedError
...@@ -944,6 +965,7 @@ class MiniCPMVBaseModel(nn.Module): ...@@ -944,6 +965,7 @@ class MiniCPMVBaseModel(nn.Module):
self, self,
config: PretrainedConfig, config: PretrainedConfig,
quant_config: Optional[QuantizationConfig], quant_config: Optional[QuantizationConfig],
prefix: str = "",
) -> nn.Module: ) -> nn.Module:
raise NotImplementedError raise NotImplementedError
...@@ -952,6 +974,7 @@ class MiniCPMVBaseModel(nn.Module): ...@@ -952,6 +974,7 @@ class MiniCPMVBaseModel(nn.Module):
embed_dim: int, embed_dim: int,
vision_dim: int, vision_dim: int,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> nn.Module: ) -> nn.Module:
raise NotImplementedError raise NotImplementedError
...@@ -1011,24 +1034,27 @@ class MiniCPMV2_6(MiniCPMVBaseModel): ...@@ -1011,24 +1034,27 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
self, self,
config: PretrainedConfig, config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__(config=config, quant_config=quant_config) super().__init__(config=config, quant_config=quant_config, prefix=prefix)
assert self.version == (2, 6) assert self.version == (2, 6)
def init_llm( def init_llm(
self, self,
config: Qwen2Config, config: Qwen2Config,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> nn.Module: ) -> nn.Module:
return Qwen2ForCausalLM(config=config, quant_config=quant_config) return Qwen2ForCausalLM(config=config, quant_config=quant_config, prefix=prefix)
def init_vision_module( def init_vision_module(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
quant_config: Optional[QuantizationConfig], quant_config: Optional[QuantizationConfig],
prefix: str = "",
) -> nn.Module: ) -> nn.Module:
model = Idefics2VisionTransformer( model = Idefics2VisionTransformer(
config=config.vision_config, quant_config=quant_config config=config.vision_config, quant_config=quant_config, prefix=prefix
) )
if self.config.drop_vision_last_layer: if self.config.drop_vision_last_layer:
model.encoder.layers = model.encoder.layers[:-1] model.encoder.layers = model.encoder.layers[:-1]
...@@ -1042,6 +1068,7 @@ class MiniCPMV2_6(MiniCPMVBaseModel): ...@@ -1042,6 +1068,7 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
embed_dim: int, embed_dim: int,
vision_dim: int, vision_dim: int,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> nn.Module: ) -> nn.Module:
with set_default_torch_dtype(torch.float16): with set_default_torch_dtype(torch.float16):
# The resampler in 2.6 remains consistent with the one in 2.5. # The resampler in 2.6 remains consistent with the one in 2.5.
...@@ -1051,6 +1078,7 @@ class MiniCPMV2_6(MiniCPMVBaseModel): ...@@ -1051,6 +1078,7 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
num_heads=embed_dim // 128, num_heads=embed_dim // 128,
kv_dim=vision_dim, kv_dim=vision_dim,
quant_config=quant_config, quant_config=quant_config,
prefix=prefix,
) )
return resampler.to(device="cuda", dtype=torch.get_default_dtype()) return resampler.to(device="cuda", dtype=torch.get_default_dtype())
...@@ -1207,6 +1235,7 @@ class MiniCPMV: ...@@ -1207,6 +1235,7 @@ class MiniCPMV:
self, self,
config: PretrainedConfig, config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -1221,7 +1250,9 @@ class MiniCPMV: ...@@ -1221,7 +1250,9 @@ class MiniCPMV:
raise ValueError("Currently, MiniCPMV only supports versions 2.6") raise ValueError("Currently, MiniCPMV only supports versions 2.6")
try: try:
minicpmv = instance_class(config=config, quant_config=quant_config) minicpmv = instance_class(
config=config, quant_config=quant_config, prefix=prefix
)
self.minicpmv = minicpmv self.minicpmv = minicpmv
except Exception as e: except Exception as e:
print(f"Failed to instantiate MiniCPMV: {e}") print(f"Failed to instantiate MiniCPMV: {e}")
......
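Note: in the multimodal wrapper the prefix now threads through `init_llm`, `init_vision_module`, and `init_resampler`, and the `Idefics2Encoder` loop switches from `for _` to `for i` so each encoder layer receives its own index-based prefix. Composed through the nested constructors, this yields fully qualified names such as `vpm.encoder.layers.3.mlp.fc1`, which is what a per-module quantization override needs in order to target a single layer. A small illustration of the composition, reusing the sketched `add_prefix` from above (toy classes, not the real modules):

```python
def add_prefix(name, prefix):  # assumed behavior, as sketched earlier
    return name if not prefix else f"{prefix}.{name}"

class ToyMLP:
    def __init__(self, prefix=""):
        self.fc1_name = add_prefix("fc1", prefix)

class ToyLayer:
    def __init__(self, prefix=""):
        self.mlp = ToyMLP(prefix=add_prefix("mlp", prefix))

class ToyEncoder:
    def __init__(self, num_layers, prefix=""):
        self.layers = [
            ToyLayer(prefix=add_prefix(f"layers.{i}", prefix))
            for i in range(num_layers)
        ]

enc = ToyEncoder(4, prefix=add_prefix("encoder", "vpm"))
print(enc.layers[3].mlp.fc1_name)  # "vpm.encoder.layers.3.mlp.fc1"
```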
...@@ -45,6 +45,7 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ...@@ -45,6 +45,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix
class MixtralMoE(nn.Module): class MixtralMoE(nn.Module):
...@@ -78,7 +79,7 @@ class MixtralMoE(nn.Module): ...@@ -78,7 +79,7 @@ class MixtralMoE(nn.Module):
bias=False, bias=False,
params_dtype=params_dtype, params_dtype=params_dtype,
quant_config=None, quant_config=None,
prefix=f"{prefix}.gate", prefix=add_prefix("gate", prefix),
) )
MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
self.experts = MoEImpl( self.experts = MoEImpl(
...@@ -90,7 +91,7 @@ class MixtralMoE(nn.Module): ...@@ -90,7 +91,7 @@ class MixtralMoE(nn.Module):
renormalize=True, renormalize=True,
quant_config=quant_config, quant_config=quant_config,
tp_size=tp_size, tp_size=tp_size,
prefix=f"{prefix}.experts", prefix=add_prefix("experts", prefix),
) )
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
...@@ -146,14 +147,14 @@ class MixtralAttention(nn.Module): ...@@ -146,14 +147,14 @@ class MixtralAttention(nn.Module):
self.total_num_kv_heads, self.total_num_kv_heads,
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.qkv_proj", prefix=add_prefix("qkv_proj", prefix),
) )
self.o_proj = RowParallelLinear( self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim, self.total_num_heads * self.head_dim,
hidden_size, hidden_size,
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.o_proj", prefix=add_prefix("o_proj", prefix),
) )
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
...@@ -168,6 +169,7 @@ class MixtralAttention(nn.Module): ...@@ -168,6 +169,7 @@ class MixtralAttention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
layer_id=layer_id, layer_id=layer_id,
prefix=add_prefix("attn", prefix),
) )
def forward( def forward(
...@@ -204,7 +206,7 @@ class MixtralDecoderLayer(nn.Module): ...@@ -204,7 +206,7 @@ class MixtralDecoderLayer(nn.Module):
layer_id=layer_id, layer_id=layer_id,
rope_theta=rope_theta, rope_theta=rope_theta,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.self_attn", prefix=add_prefix("self_attn", prefix),
) )
self.block_sparse_moe = MixtralMoE( self.block_sparse_moe = MixtralMoE(
num_experts=config.num_local_experts, num_experts=config.num_local_experts,
...@@ -212,7 +214,7 @@ class MixtralDecoderLayer(nn.Module): ...@@ -212,7 +214,7 @@ class MixtralDecoderLayer(nn.Module):
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size, intermediate_size=config.intermediate_size,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.block_sparse_moe", prefix=add_prefix("block_sparse_moe", prefix),
) )
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm( self.post_attention_layernorm = RMSNorm(
...@@ -258,11 +260,15 @@ class MixtralModel(nn.Module): ...@@ -258,11 +260,15 @@ class MixtralModel(nn.Module):
self.embed_tokens = VocabParallelEmbedding( self.embed_tokens = VocabParallelEmbedding(
config.vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
prefix=add_prefix("embed_tokens", prefix),
) )
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
[ [
MixtralDecoderLayer( MixtralDecoderLayer(
config, i, quant_config=quant_config, prefix=f"{prefix}.layers" config,
i,
quant_config=quant_config,
prefix=add_prefix(f"layers.{i}", prefix),
) )
for i in range(config.num_hidden_layers) for i in range(config.num_hidden_layers)
] ]
...@@ -296,12 +302,17 @@ class MixtralForCausalLM(nn.Module): ...@@ -296,12 +302,17 @@ class MixtralForCausalLM(nn.Module):
self, self,
config: MixtralConfig, config: MixtralConfig,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.model = MixtralModel(config, quant_config=quant_config, prefix="model") self.model = MixtralModel(
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) config, quant_config=quant_config, prefix=add_prefix("model", prefix)
)
self.lm_head = ParallelLMHead(
config.vocab_size, config.hidden_size, prefix=add_prefix("lm_head", prefix)
)
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
def forward( def forward(
......
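Note: the point of threading the full dotted path into every linear layer is that a dynamic quantization config can then be matched against it, for example to give the MoE experts different bit widths than the attention projections, or to skip a module entirely. The sketch below is a hypothetical matcher (`DYNAMIC_RULES` and `resolve_override` are invented names), not GPTQModel's or SGLang's actual API; it only illustrates how prefix-keyed overrides could be resolved.

```python
import re
from typing import Optional

# Hypothetical per-module overrides keyed by regex on the dotted prefix.
DYNAMIC_RULES = {
    r"model\.layers\.\d+\.block_sparse_moe\.experts.*": {"bits": 8, "group_size": 64},
    r"model\.layers\.\d+\.self_attn\..*_proj": {"bits": 4, "group_size": 128},
    r"lm_head": {"bits": 8, "group_size": 32},
}

def resolve_override(prefix: str) -> Optional[dict]:
    """Return the first rule whose regex fully matches the module's prefix."""
    for pattern, cfg in DYNAMIC_RULES.items():
        if re.fullmatch(pattern, prefix):
            return cfg
    return None  # fall back to the global quantization settings

# resolve_override("model.layers.7.self_attn.qkv_proj")
#   -> {"bits": 4, "group_size": 128}
```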
...@@ -45,6 +45,7 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ...@@ -45,6 +45,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
) )
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix
class MixtralMLP(nn.Module): class MixtralMLP(nn.Module):
...@@ -54,6 +55,7 @@ class MixtralMLP(nn.Module): ...@@ -54,6 +55,7 @@ class MixtralMLP(nn.Module):
hidden_size: int, hidden_size: int,
intermediate_size: int, intermediate_size: int,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.num_experts = num_experts self.num_experts = num_experts
...@@ -61,13 +63,25 @@ class MixtralMLP(nn.Module): ...@@ -61,13 +63,25 @@ class MixtralMLP(nn.Module):
self.hidden_dim = hidden_size self.hidden_dim = hidden_size
self.w1 = ReplicatedLinear( self.w1 = ReplicatedLinear(
self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config self.hidden_dim,
self.ffn_dim,
bias=False,
quant_config=quant_config,
prefix=add_prefix("w1", prefix),
) )
self.w2 = ReplicatedLinear( self.w2 = ReplicatedLinear(
self.ffn_dim, self.hidden_dim, bias=False, quant_config=quant_config self.ffn_dim,
self.hidden_dim,
bias=False,
quant_config=quant_config,
prefix=add_prefix("w2", prefix),
) )
self.w3 = ReplicatedLinear( self.w3 = ReplicatedLinear(
self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config self.hidden_dim,
self.ffn_dim,
bias=False,
quant_config=quant_config,
prefix=add_prefix("w3", prefix),
) )
# TODO: Use vllm's SiluAndMul # TODO: Use vllm's SiluAndMul
...@@ -87,6 +101,7 @@ class MixtralMoE(nn.Module): ...@@ -87,6 +101,7 @@ class MixtralMoE(nn.Module):
self, self,
config: MixtralConfig, config: MixtralConfig,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -114,6 +129,7 @@ class MixtralMoE(nn.Module): ...@@ -114,6 +129,7 @@ class MixtralMoE(nn.Module):
config.hidden_size, config.hidden_size,
config.intermediate_size, config.intermediate_size,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix(f"experts.{idx}", prefix),
) )
if idx in self.expert_indicies if idx in self.expert_indicies
else None else None
...@@ -122,7 +138,11 @@ class MixtralMoE(nn.Module): ...@@ -122,7 +138,11 @@ class MixtralMoE(nn.Module):
] ]
) )
self.gate = ReplicatedLinear( self.gate = ReplicatedLinear(
config.hidden_size, self.num_total_experts, bias=False, quant_config=None config.hidden_size,
self.num_total_experts,
bias=False,
quant_config=None,
prefix=add_prefix("gate", prefix),
) )
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
...@@ -159,6 +179,7 @@ class MixtralAttention(nn.Module): ...@@ -159,6 +179,7 @@ class MixtralAttention(nn.Module):
max_position: int = 4096 * 32, max_position: int = 4096 * 32,
rope_theta: float = 10000, rope_theta: float = 10000,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
...@@ -189,12 +210,14 @@ class MixtralAttention(nn.Module): ...@@ -189,12 +210,14 @@ class MixtralAttention(nn.Module):
self.total_num_kv_heads, self.total_num_kv_heads,
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("qkv_proj", prefix),
) )
self.o_proj = RowParallelLinear( self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim, self.total_num_heads * self.head_dim,
hidden_size, hidden_size,
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("o_proj", prefix),
) )
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
...@@ -209,6 +232,7 @@ class MixtralAttention(nn.Module): ...@@ -209,6 +232,7 @@ class MixtralAttention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
layer_id=layer_id, layer_id=layer_id,
prefix=add_prefix("attn", prefix),
) )
def forward( def forward(
...@@ -231,6 +255,7 @@ class MixtralDecoderLayer(nn.Module): ...@@ -231,6 +255,7 @@ class MixtralDecoderLayer(nn.Module):
config: MixtralConfig, config: MixtralConfig,
layer_id: int = 0, layer_id: int = 0,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
...@@ -244,8 +269,13 @@ class MixtralDecoderLayer(nn.Module): ...@@ -244,8 +269,13 @@ class MixtralDecoderLayer(nn.Module):
layer_id=layer_id, layer_id=layer_id,
rope_theta=rope_theta, rope_theta=rope_theta,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("self_attn", prefix),
)
self.block_sparse_moe = MixtralMoE(
config=config,
quant_config=quant_config,
prefix=add_prefix("block_sparse_moe", prefix),
) )
self.block_sparse_moe = MixtralMoE(config=config, quant_config=quant_config)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm( self.post_attention_layernorm = RMSNorm(
config.hidden_size, eps=config.rms_norm_eps config.hidden_size, eps=config.rms_norm_eps
...@@ -281,6 +311,7 @@ class MixtralModel(nn.Module): ...@@ -281,6 +311,7 @@ class MixtralModel(nn.Module):
self, self,
config: MixtralConfig, config: MixtralConfig,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
...@@ -289,10 +320,16 @@ class MixtralModel(nn.Module): ...@@ -289,10 +320,16 @@ class MixtralModel(nn.Module):
self.embed_tokens = VocabParallelEmbedding( self.embed_tokens = VocabParallelEmbedding(
config.vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
prefix=add_prefix("embed_tokens", prefix),
) )
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
[ [
MixtralDecoderLayer(config, i, quant_config=quant_config) MixtralDecoderLayer(
config,
i,
quant_config=quant_config,
prefix=add_prefix(f"layers.{i}", prefix),
)
for i in range(config.num_hidden_layers) for i in range(config.num_hidden_layers)
] ]
) )
...@@ -324,12 +361,17 @@ class QuantMixtralForCausalLM(nn.Module): ...@@ -324,12 +361,17 @@ class QuantMixtralForCausalLM(nn.Module):
self, self,
config: MixtralConfig, config: MixtralConfig,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.model = MixtralModel(config, quant_config=quant_config) self.model = MixtralModel(
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) config, quant_config=quant_config, prefix=add_prefix("model", prefix)
)
self.lm_head = ParallelLMHead(
config.vocab_size, config.hidden_size, prefix=add_prefix("lm_head", prefix)
)
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
@torch.no_grad() @torch.no_grad()
......
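Note: `ParallelLMHead` now also receives a prefix (`"lm_head"`), which is what lets the quantization config decide whether the output projection should be quantized at all. One caveat visible in several of these models (MiniCPM3, OLMo, OLMo2): when `tie_word_embeddings` is set, `lm_head` is just an alias of `embed_tokens`, so there is no separate weight to quantize. The gate below is purely illustrative; `lm_head_quant_enabled` and its rule dict are invented for this sketch and are not SGLang's API.

```python
def lm_head_quant_enabled(config, dynamic_rules: dict) -> bool:
    """Hypothetical gate; names are illustrative only."""
    if getattr(config, "tie_word_embeddings", False):
        # lm_head aliases embed_tokens -> nothing separate to quantize.
        return False
    # Quantize only when the dynamic config carries an explicit rule for
    # the "lm_head" prefix (keys are exact module names in this toy).
    return "lm_head" in dynamic_rules

# lm_head_quant_enabled(cfg, {"lm_head": {"bits": 8}}) -> True for untied models
```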
...@@ -36,6 +36,7 @@ from sglang.srt.managers.schedule_batch import ImageInputs ...@@ -36,6 +36,7 @@ from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.llama import LlamaDecoderLayer, LlamaMLP from sglang.srt.models.llama import LlamaDecoderLayer, LlamaMLP
from sglang.srt.utils import add_prefix
class ColumnParallelConv2dPatch(torch.nn.Module): class ColumnParallelConv2dPatch(torch.nn.Module):
...@@ -147,7 +148,12 @@ class MllamaPrecomputedPositionEmbedding(nn.Module): ...@@ -147,7 +148,12 @@ class MllamaPrecomputedPositionEmbedding(nn.Module):
class MllamaVisionMLP(nn.Module): class MllamaVisionMLP(nn.Module):
def __init__(self, config, quant_config: Optional[QuantizationConfig] = None): def __init__(
self,
config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__() super().__init__()
self.config = config self.config = config
self.activation_fn = get_act_fn(config.hidden_act) self.activation_fn = get_act_fn(config.hidden_act)
...@@ -156,12 +162,14 @@ class MllamaVisionMLP(nn.Module): ...@@ -156,12 +162,14 @@ class MllamaVisionMLP(nn.Module):
config.intermediate_size, config.intermediate_size,
bias=True, bias=True,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("fc1", prefix),
) )
self.fc2 = RowParallelLinear( self.fc2 = RowParallelLinear(
config.intermediate_size, config.intermediate_size,
config.hidden_size, config.hidden_size,
bias=True, bias=True,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("fc2", prefix),
) )
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
...@@ -174,7 +182,10 @@ class MllamaVisionMLP(nn.Module): ...@@ -174,7 +182,10 @@ class MllamaVisionMLP(nn.Module):
class MllamaVisionEncoderLayer(nn.Module): class MllamaVisionEncoderLayer(nn.Module):
def __init__( def __init__(
self, config: config_mllama.MllamaVisionConfig, is_gated: bool = False self,
config: config_mllama.MllamaVisionConfig,
is_gated: bool = False,
prefix: str = "",
): ):
super().__init__() super().__init__()
...@@ -193,8 +204,9 @@ class MllamaVisionEncoderLayer(nn.Module): ...@@ -193,8 +204,9 @@ class MllamaVisionEncoderLayer(nn.Module):
use_context_forward=False, use_context_forward=False,
use_full_precision_softmax=False, use_full_precision_softmax=False,
flatten_batch=False, flatten_batch=False,
prefix=add_prefix("self_attn", prefix),
) )
self.mlp = MllamaVisionMLP(config) self.mlp = MllamaVisionMLP(config, prefix=add_prefix("mlp", prefix))
self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.norm_eps) self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.norm_eps)
self.post_attention_layernorm = nn.LayerNorm( self.post_attention_layernorm = nn.LayerNorm(
...@@ -235,11 +247,17 @@ class MllamaVisionEncoder(nn.Module): ...@@ -235,11 +247,17 @@ class MllamaVisionEncoder(nn.Module):
num_layers=32, num_layers=32,
is_gated=False, is_gated=False,
output_hidden_states=None, output_hidden_states=None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.config = config self.config = config
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
[MllamaVisionEncoderLayer(config, is_gated) for _ in range(num_layers)] [
MllamaVisionEncoderLayer(
config, is_gated, prefix=add_prefix(f"layers.{i}", prefix)
)
for i in range(num_layers)
]
) )
self.output_hidden_states = output_hidden_states or [] self.output_hidden_states = output_hidden_states or []
...@@ -265,7 +283,7 @@ class MllamaVisionEncoder(nn.Module): ...@@ -265,7 +283,7 @@ class MllamaVisionEncoder(nn.Module):
class MllamaVisionModel(nn.Module): class MllamaVisionModel(nn.Module):
def __init__(self, config: config_mllama.MllamaVisionConfig): def __init__(self, config: config_mllama.MllamaVisionConfig, prefix: str = ""):
super().__init__() super().__init__()
self.image_size = config.image_size self.image_size = config.image_size
self.patch_size = config.patch_size self.patch_size = config.patch_size
...@@ -305,9 +323,13 @@ class MllamaVisionModel(nn.Module): ...@@ -305,9 +323,13 @@ class MllamaVisionModel(nn.Module):
config.num_hidden_layers, config.num_hidden_layers,
is_gated=False, is_gated=False,
output_hidden_states=config.intermediate_layers_indices, output_hidden_states=config.intermediate_layers_indices,
prefix=add_prefix("transformer", prefix),
) )
self.global_transformer = MllamaVisionEncoder( self.global_transformer = MllamaVisionEncoder(
config, config.num_global_layers, is_gated=True config,
config.num_global_layers,
is_gated=True,
prefix=add_prefix("global_transformer", prefix),
) )
def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor: def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor:
...@@ -464,6 +486,7 @@ class MllamaTextCrossAttention(nn.Module): ...@@ -464,6 +486,7 @@ class MllamaTextCrossAttention(nn.Module):
config: Optional[config_mllama.MllamaTextConfig] = None, config: Optional[config_mllama.MllamaTextConfig] = None,
layer_id: Optional[int] = None, layer_id: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -489,6 +512,7 @@ class MllamaTextCrossAttention(nn.Module): ...@@ -489,6 +512,7 @@ class MllamaTextCrossAttention(nn.Module):
self.num_key_value_heads, self.num_key_value_heads,
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("qkv_proj", prefix),
) )
self.o_proj = RowParallelLinear( self.o_proj = RowParallelLinear(
self.num_heads * self.head_dim, self.num_heads * self.head_dim,
...@@ -496,6 +520,7 @@ class MllamaTextCrossAttention(nn.Module): ...@@ -496,6 +520,7 @@ class MllamaTextCrossAttention(nn.Module):
bias=False, bias=False,
input_is_parallel=True, input_is_parallel=True,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("o_proj", prefix),
) )
# vllm.model_executor.layers.layernorm.RMSNorm has precision issue, # vllm.model_executor.layers.layernorm.RMSNorm has precision issue,
# use huggingface's instead # use huggingface's instead
...@@ -510,6 +535,7 @@ class MllamaTextCrossAttention(nn.Module): ...@@ -510,6 +535,7 @@ class MllamaTextCrossAttention(nn.Module):
self.num_local_key_value_heads, self.num_local_key_value_heads,
layer_id=layer_id, layer_id=layer_id,
is_cross_attention=True, is_cross_attention=True,
prefix=add_prefix("attn", prefix),
) )
def forward( def forward(
...@@ -551,6 +577,7 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module): ...@@ -551,6 +577,7 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
config: config_mllama.MllamaTextConfig, config: config_mllama.MllamaTextConfig,
layer_id: int, layer_id: int,
quant_config: Optional[QuantizationConfig], quant_config: Optional[QuantizationConfig],
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.layer_id = layer_id self.layer_id = layer_id
...@@ -558,6 +585,7 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module): ...@@ -558,6 +585,7 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
config=config, config=config,
layer_id=layer_id, layer_id=layer_id,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("cross_attn", prefix),
) )
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
...@@ -568,6 +596,7 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module): ...@@ -568,6 +596,7 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
intermediate_size=config.intermediate_size, intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act, hidden_act=config.hidden_act,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("mlp", prefix),
) )
self.post_attention_layernorm = RMSNorm( self.post_attention_layernorm = RMSNorm(
config.hidden_size, eps=config.rms_norm_eps config.hidden_size, eps=config.rms_norm_eps
...@@ -610,12 +639,15 @@ class MllamaTextModel(nn.Module): ...@@ -610,12 +639,15 @@ class MllamaTextModel(nn.Module):
self, self,
config: config_mllama.MllamaTextConfig, config: config_mllama.MllamaTextConfig,
quant_config: Optional[QuantizationConfig], quant_config: Optional[QuantizationConfig],
prefix: str = "",
): ):
super().__init__() super().__init__()
self.padding_id = config.pad_token_id self.padding_id = config.pad_token_id
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding( self.embed_tokens = VocabParallelEmbedding(
config.vocab_size + 8, config.hidden_size config.vocab_size + 8,
config.hidden_size,
prefix=add_prefix("embed_tokens", prefix),
) )
self.cross_attention_layers = config.cross_attention_layers self.cross_attention_layers = config.cross_attention_layers
...@@ -624,14 +656,20 @@ class MllamaTextModel(nn.Module): ...@@ -624,14 +656,20 @@ class MllamaTextModel(nn.Module):
if layer_id in self.cross_attention_layers: if layer_id in self.cross_attention_layers:
layers.append( layers.append(
MllamaCrossAttentionDecoderLayer( MllamaCrossAttentionDecoderLayer(
config, layer_id, quant_config=quant_config config,
layer_id,
quant_config=quant_config,
prefix=add_prefix(f"layers.{layer_id}", prefix),
) )
) )
else: else:
# TODO: force LlamaDecoderLayer to config.attention_bias=False # TODO: force LlamaDecoderLayer to config.attention_bias=False
layers.append( layers.append(
LlamaDecoderLayer( LlamaDecoderLayer(
config, quant_config=quant_config, layer_id=layer_id config,
quant_config=quant_config,
layer_id=layer_id,
prefix=add_prefix(f"layers.{layer_id}", prefix),
) )
) )
...@@ -687,16 +725,20 @@ class MllamaForCausalLM(nn.Module): ...@@ -687,16 +725,20 @@ class MllamaForCausalLM(nn.Module):
self, self,
config: config_mllama.MllamaTextConfig, config: config_mllama.MllamaTextConfig,
quant_config: Optional[QuantizationConfig], quant_config: Optional[QuantizationConfig],
prefix: str = "",
): ):
super().__init__() super().__init__()
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
self.model = MllamaTextModel(config, quant_config) self.model = MllamaTextModel(
config, quant_config, prefix=add_prefix("model", prefix)
)
self.lm_head = ParallelLMHead( self.lm_head = ParallelLMHead(
config.vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
org_num_embeddings=config.vocab_size, org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE, padding_size=DEFAULT_VOCAB_PADDING_SIZE,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
) )
def forward( def forward(
...@@ -726,6 +768,7 @@ class MllamaForConditionalGeneration(nn.Module): ...@@ -726,6 +768,7 @@ class MllamaForConditionalGeneration(nn.Module):
self, self,
config: config_mllama.MllamaConfig, config: config_mllama.MllamaConfig,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.vocab_size = config.text_config.vocab_size self.vocab_size = config.text_config.vocab_size
...@@ -737,10 +780,13 @@ class MllamaForConditionalGeneration(nn.Module): ...@@ -737,10 +780,13 @@ class MllamaForConditionalGeneration(nn.Module):
) )
self.image_size = config.vision_config.image_size self.image_size = config.vision_config.image_size
self.vision_model = MllamaVisionModel(config.vision_config) self.vision_model = MllamaVisionModel(
config.vision_config, prefix=add_prefix("vision_model", prefix)
)
self.language_model = MllamaForCausalLM( self.language_model = MllamaForCausalLM(
config.text_config, config.text_config,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("language_model", prefix),
) )
self.multi_modal_projector = nn.Linear( self.multi_modal_projector = nn.Linear(
config.vision_config.vision_output_dim, config.vision_config.vision_output_dim,
......
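Note: in Mllama the same `layers.{layer_id}` scheme is applied to both the interleaved cross-attention layers and the plain Llama decoder layers, and the vision encoders now index their layers as well, so the constructor-supplied prefixes line up with the checkpoint's parameter names (e.g. `language_model.model.layers.3.cross_attn.qkv_proj`). A quick way to sanity-check that alignment is sketched below; it assumes each layer stores the prefix it was built with under a `.prefix` attribute, which is a hypothetical convention, not something this diff guarantees.

```python
import torch.nn as nn

def check_prefixes(model: nn.Module) -> None:
    """Compare each submodule's registered dotted name with the prefix it
    was constructed with (assumed to be stored as `.prefix`)."""
    for name, module in model.named_modules():
        built_with = getattr(module, "prefix", None)  # hypothetical attribute
        if built_with is not None and built_with != name:
            print(f"mismatch: registered as {name!r}, built with {built_with!r}")
```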
...@@ -38,7 +38,7 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ...@@ -38,7 +38,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
) )
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import make_layers from sglang.srt.utils import add_prefix, make_layers
class OlmoAttention(nn.Module): class OlmoAttention(nn.Module):
...@@ -53,6 +53,7 @@ class OlmoAttention(nn.Module): ...@@ -53,6 +53,7 @@ class OlmoAttention(nn.Module):
config: OlmoConfig, config: OlmoConfig,
layer_id: int = 0, layer_id: int = 0,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -75,6 +76,7 @@ class OlmoAttention(nn.Module): ...@@ -75,6 +76,7 @@ class OlmoAttention(nn.Module):
self.head_dim, self.head_dim,
self.total_num_heads, self.total_num_heads,
bias=config.attention_bias, bias=config.attention_bias,
prefix=add_prefix("qkv_proj", prefix),
) )
# Rotary embeddings. # Rotary embeddings.
...@@ -91,6 +93,7 @@ class OlmoAttention(nn.Module): ...@@ -91,6 +93,7 @@ class OlmoAttention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_heads, num_kv_heads=self.num_heads,
layer_id=layer_id, layer_id=layer_id,
prefix=add_prefix("attn", prefix),
) )
# Attention output projection. # Attention output projection.
...@@ -98,6 +101,7 @@ class OlmoAttention(nn.Module): ...@@ -98,6 +101,7 @@ class OlmoAttention(nn.Module):
self.hidden_size, self.hidden_size,
self.hidden_size, self.hidden_size,
bias=config.attention_bias, bias=config.attention_bias,
prefix=add_prefix("o_proj", prefix),
) )
def forward( def forward(
...@@ -127,6 +131,7 @@ class OlmoMLP(nn.Module): ...@@ -127,6 +131,7 @@ class OlmoMLP(nn.Module):
self, self,
config: OlmoConfig, config: OlmoConfig,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -139,6 +144,7 @@ class OlmoMLP(nn.Module): ...@@ -139,6 +144,7 @@ class OlmoMLP(nn.Module):
[self.intermediate_size] * 2, [self.intermediate_size] * 2,
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("gate_up_proj", prefix),
) )
# Activation function. # Activation function.
...@@ -150,6 +156,7 @@ class OlmoMLP(nn.Module): ...@@ -150,6 +156,7 @@ class OlmoMLP(nn.Module):
self.hidden_size, self.hidden_size,
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("down_proj", prefix),
) )
def forward( def forward(
...@@ -174,13 +181,23 @@ class OlmoDecoderLayer(nn.Module): ...@@ -174,13 +181,23 @@ class OlmoDecoderLayer(nn.Module):
config: OlmoConfig, config: OlmoConfig,
layer_id: int = 0, layer_id: int = 0,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
# Attention block. # Attention block.
self.self_attn = OlmoAttention(config, layer_id, quant_config) self.self_attn = OlmoAttention(
config,
layer_id,
quant_config,
prefix=add_prefix("self_attn", prefix),
)
# MLP block. # MLP block.
self.mlp = OlmoMLP(config, quant_config) self.mlp = OlmoMLP(
config,
quant_config,
prefix=add_prefix("mlp", prefix),
)
# LayerNorm # LayerNorm
self.input_layernorm = nn.LayerNorm( self.input_layernorm = nn.LayerNorm(
...@@ -213,13 +230,18 @@ class OlmoDecoderLayer(nn.Module): ...@@ -213,13 +230,18 @@ class OlmoDecoderLayer(nn.Module):
class OlmoModel(nn.Module): class OlmoModel(nn.Module):
def __init__( def __init__(
self, config: OlmoConfig, quant_config: Optional[QuantizationConfig] = None self,
config: OlmoConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.config = config self.config = config
self.embed_tokens = VocabParallelEmbedding( self.embed_tokens = VocabParallelEmbedding(
config.vocab_size, config.hidden_size config.vocab_size,
config.hidden_size,
prefix=add_prefix("embed_tokens", prefix),
) )
self.layers = make_layers( self.layers = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
...@@ -227,7 +249,9 @@ class OlmoModel(nn.Module): ...@@ -227,7 +249,9 @@ class OlmoModel(nn.Module):
layer_id=idx, layer_id=idx,
config=config, config=config,
quant_config=quant_config, quant_config=quant_config,
prefix=prefix,
), ),
prefix=add_prefix("layers", prefix),
) )
self.norm = nn.LayerNorm( self.norm = nn.LayerNorm(
config.hidden_size, elementwise_affine=False, bias=False config.hidden_size, elementwise_affine=False, bias=False
...@@ -275,10 +299,11 @@ class OlmoForCausalLM(nn.Module): ...@@ -275,10 +299,11 @@ class OlmoForCausalLM(nn.Module):
self, self,
config: OlmoConfig, config: OlmoConfig,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.config = config self.config = config
self.model = OlmoModel(config, quant_config) self.model = OlmoModel(config, quant_config, prefix=add_prefix("model", prefix))
if config.tie_word_embeddings: if config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens self.lm_head = self.model.embed_tokens
else: else:
...@@ -288,6 +313,7 @@ class OlmoForCausalLM(nn.Module): ...@@ -288,6 +313,7 @@ class OlmoForCausalLM(nn.Module):
config.hidden_size, config.hidden_size,
org_num_embeddings=config.vocab_size, org_num_embeddings=config.vocab_size,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
) )
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
......
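Note: the OLMo-family models build their decoder stacks via `make_layers` rather than an explicit list comprehension; the diff passes `add_prefix("layers", prefix)` as the base prefix and forwards the `prefix` argument that the factory lambda receives. The helper below shows one way such a utility could hand each layer its per-index prefix; it is an assumption about `make_layers`' contract (hence the `_sketch` suffix), not its actual implementation.

```python
from typing import Callable

import torch.nn as nn

def make_layers_sketch(
    num_layers: int,
    layer_fn: Callable[..., nn.Module],
    prefix: str = "",
) -> nn.ModuleList:
    """Assumed contract: call the factory once per index and pass it the
    dotted per-layer prefix (e.g. "model.layers.5")."""
    return nn.ModuleList(
        layer_fn(idx=i, prefix=f"{prefix}.{i}" if prefix else str(i))
        for i in range(num_layers)
    )
```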
...@@ -45,7 +45,7 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ...@@ -45,7 +45,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
) )
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import make_layers from sglang.srt.utils import add_prefix, make_layers
class Olmo2Attention(nn.Module): class Olmo2Attention(nn.Module):
...@@ -60,6 +60,7 @@ class Olmo2Attention(nn.Module): ...@@ -60,6 +60,7 @@ class Olmo2Attention(nn.Module):
config: PretrainedConfig, config: PretrainedConfig,
layer_id: int = 0, layer_id: int = 0,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -93,6 +94,8 @@ class Olmo2Attention(nn.Module): ...@@ -93,6 +94,8 @@ class Olmo2Attention(nn.Module):
self.head_dim, self.head_dim,
self.total_num_heads, self.total_num_heads,
bias=config.attention_bias, bias=config.attention_bias,
quant_config=quant_config,
prefix=add_prefix("qkv_proj", prefix),
) )
self.tp_rank = get_tensor_model_parallel_rank() self.tp_rank = get_tensor_model_parallel_rank()
...@@ -115,6 +118,7 @@ class Olmo2Attention(nn.Module): ...@@ -115,6 +118,7 @@ class Olmo2Attention(nn.Module):
self.scaling, self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
layer_id=layer_id, layer_id=layer_id,
prefix=add_prefix("attn", prefix),
) )
# Attention output projection. # Attention output projection.
...@@ -122,6 +126,8 @@ class Olmo2Attention(nn.Module): ...@@ -122,6 +126,8 @@ class Olmo2Attention(nn.Module):
self.head_dim * self.total_num_heads, self.head_dim * self.total_num_heads,
self.hidden_size, self.hidden_size,
bias=config.attention_bias, bias=config.attention_bias,
quant_config=quant_config,
prefix=add_prefix("o_proj", prefix),
) )
def _apply_qk_norm( def _apply_qk_norm(
...@@ -164,6 +170,7 @@ class Olmo2MLP(nn.Module): ...@@ -164,6 +170,7 @@ class Olmo2MLP(nn.Module):
self, self,
config: PretrainedConfig, config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -176,6 +183,7 @@ class Olmo2MLP(nn.Module): ...@@ -176,6 +183,7 @@ class Olmo2MLP(nn.Module):
[self.intermediate_size] * 2, [self.intermediate_size] * 2,
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("gate_up_proj", prefix),
) )
# Activation function. # Activation function.
...@@ -187,6 +195,7 @@ class Olmo2MLP(nn.Module): ...@@ -187,6 +195,7 @@ class Olmo2MLP(nn.Module):
self.hidden_size, self.hidden_size,
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("down_proj", prefix),
) )
def forward( def forward(
...@@ -211,13 +220,16 @@ class Olmo2DecoderLayer(nn.Module): ...@@ -211,13 +220,16 @@ class Olmo2DecoderLayer(nn.Module):
config: PretrainedConfig, config: PretrainedConfig,
layer_id: int = 0, layer_id: int = 0,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
# Attention block. # Attention block.
self.self_attn = Olmo2Attention(config, layer_id, quant_config) self.self_attn = Olmo2Attention(
config, layer_id, quant_config, prefix=add_prefix("self_attn", prefix)
)
# MLP block. # MLP block.
self.mlp = Olmo2MLP(config, quant_config) self.mlp = Olmo2MLP(config, quant_config, prefix=add_prefix("mlp", prefix))
# RMSNorm # RMSNorm
self.post_attention_layernorm = RMSNorm( self.post_attention_layernorm = RMSNorm(
...@@ -254,12 +266,15 @@ class Olmo2Model(nn.Module): ...@@ -254,12 +266,15 @@ class Olmo2Model(nn.Module):
self, self,
config: PretrainedConfig, config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.config = config self.config = config
self.embed_tokens = VocabParallelEmbedding( self.embed_tokens = VocabParallelEmbedding(
config.vocab_size, config.hidden_size config.vocab_size,
config.hidden_size,
prefix=add_prefix("embed_tokens", prefix),
) )
self.layers = make_layers( self.layers = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
...@@ -267,7 +282,9 @@ class Olmo2Model(nn.Module): ...@@ -267,7 +282,9 @@ class Olmo2Model(nn.Module):
layer_id=idx, layer_id=idx,
config=config, config=config,
quant_config=quant_config, quant_config=quant_config,
prefix=prefix,
), ),
prefix=add_prefix("layers", prefix),
) )
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
...@@ -313,10 +330,13 @@ class Olmo2ForCausalLM(nn.Module): ...@@ -313,10 +330,13 @@ class Olmo2ForCausalLM(nn.Module):
self, self,
config: PretrainedConfig, config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.config = config self.config = config
self.model = Olmo2Model(config, quant_config) self.model = Olmo2Model(
config, quant_config, prefix=add_prefix("model", prefix)
)
if config.tie_word_embeddings: if config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens self.lm_head = self.model.embed_tokens
else: else:
...@@ -326,6 +346,7 @@ class Olmo2ForCausalLM(nn.Module): ...@@ -326,6 +346,7 @@ class Olmo2ForCausalLM(nn.Module):
config.hidden_size, config.hidden_size,
org_num_embeddings=config.vocab_size, org_num_embeddings=config.vocab_size,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
) )
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
......
...@@ -41,7 +41,7 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ...@@ -41,7 +41,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
) )
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import make_layers, print_warning_once from sglang.srt.utils import add_prefix, make_layers, print_warning_once
class OlmoeMoE(nn.Module): class OlmoeMoE(nn.Module):
...@@ -69,7 +69,11 @@ class OlmoeMoE(nn.Module): ...@@ -69,7 +69,11 @@ class OlmoeMoE(nn.Module):
# Gate always runs at half / full precision for now. # Gate always runs at half / full precision for now.
self.gate = ReplicatedLinear( self.gate = ReplicatedLinear(
hidden_size, num_experts, bias=False, quant_config=None hidden_size,
num_experts,
bias=False,
quant_config=None,
prefix=add_prefix("gate", prefix),
) )
self.experts = FusedMoE( self.experts = FusedMoE(
...@@ -81,6 +85,7 @@ class OlmoeMoE(nn.Module): ...@@ -81,6 +85,7 @@ class OlmoeMoE(nn.Module):
renormalize=False, renormalize=False,
quant_config=quant_config, quant_config=quant_config,
tp_size=tp_size, tp_size=tp_size,
prefix=add_prefix("experts", prefix),
) )
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
...@@ -107,6 +112,7 @@ class OlmoeAttention(nn.Module): ...@@ -107,6 +112,7 @@ class OlmoeAttention(nn.Module):
rope_scaling: Optional[Dict[str, Any]] = None, rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 4096, max_position_embeddings: int = 4096,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
...@@ -138,6 +144,7 @@ class OlmoeAttention(nn.Module): ...@@ -138,6 +144,7 @@ class OlmoeAttention(nn.Module):
self.total_num_kv_heads, self.total_num_kv_heads,
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("qkv_proj", prefix),
) )
self.q_norm = RMSNorm(hidden_size, eps=1e-5) self.q_norm = RMSNorm(hidden_size, eps=1e-5)
self.k_norm = RMSNorm(hidden_size, eps=1e-5) self.k_norm = RMSNorm(hidden_size, eps=1e-5)
...@@ -146,6 +153,7 @@ class OlmoeAttention(nn.Module): ...@@ -146,6 +153,7 @@ class OlmoeAttention(nn.Module):
hidden_size, hidden_size,
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("o_proj", prefix),
) )
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
...@@ -162,6 +170,7 @@ class OlmoeAttention(nn.Module): ...@@ -162,6 +170,7 @@ class OlmoeAttention(nn.Module):
self.scaling, self.scaling,
layer_id=layer_id, layer_id=layer_id,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_kv_heads,
prefix=add_prefix("attn", prefix),
) )
def forward( def forward(
...@@ -186,6 +195,7 @@ class OlmoeDecoderLayer(nn.Module): ...@@ -186,6 +195,7 @@ class OlmoeDecoderLayer(nn.Module):
config: PretrainedConfig, config: PretrainedConfig,
layer_id: int = 0, layer_id: int = 0,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
...@@ -202,6 +212,7 @@ class OlmoeDecoderLayer(nn.Module): ...@@ -202,6 +212,7 @@ class OlmoeDecoderLayer(nn.Module):
rope_scaling=rope_scaling, rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings, max_position_embeddings=max_position_embeddings,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("self_attn", prefix),
) )
self.mlp = OlmoeMoE( self.mlp = OlmoeMoE(
...@@ -210,6 +221,7 @@ class OlmoeDecoderLayer(nn.Module): ...@@ -210,6 +221,7 @@ class OlmoeDecoderLayer(nn.Module):
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size, intermediate_size=config.intermediate_size,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("mlp", prefix),
) )
self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5) self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=1e-5) self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
...@@ -246,6 +258,7 @@ class OlmoeModel(nn.Module): ...@@ -246,6 +258,7 @@ class OlmoeModel(nn.Module):
self, self,
config: PretrainedConfig, config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
...@@ -254,6 +267,7 @@ class OlmoeModel(nn.Module): ...@@ -254,6 +267,7 @@ class OlmoeModel(nn.Module):
self.embed_tokens = VocabParallelEmbedding( self.embed_tokens = VocabParallelEmbedding(
config.vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
prefix=add_prefix("embed_tokens", prefix),
) )
self.layers = make_layers( self.layers = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
...@@ -261,7 +275,9 @@ class OlmoeModel(nn.Module): ...@@ -261,7 +275,9 @@ class OlmoeModel(nn.Module):
config=config, config=config,
quant_config=quant_config, quant_config=quant_config,
layer_id=idx, layer_id=idx,
prefix=prefix,
), ),
prefix=add_prefix("layers", prefix),
) )
self.norm = RMSNorm(config.hidden_size, eps=1e-5) self.norm = RMSNorm(config.hidden_size, eps=1e-5)
...@@ -294,13 +310,19 @@ class OlmoeForCausalLM(nn.Module): ...@@ -294,13 +310,19 @@ class OlmoeForCausalLM(nn.Module):
self, self,
config: PretrainedConfig, config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.model = OlmoeModel(config, quant_config) self.model = OlmoeModel(
config, quant_config, prefix=add_prefix("model", prefix)
)
self.lm_head = ParallelLMHead( self.lm_head = ParallelLMHead(
config.vocab_size, config.hidden_size, quant_config=quant_config config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
) )
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
......
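With every linear module now carrying a stable fully qualified name, a quantization config can resolve per-module settings at construction time, which is what the GPTQModel dynamic quantization and lm_head quantization added here rely on. The snippet below is a hedged illustration of that idea rather than the GPTQModel or SGLang API: the DYNAMIC_OVERRIDES table, the resolve_quant_settings helper, and the regex key convention are assumptions made up for this example.

    import re
    from typing import Optional

    # Hypothetical per-module override table, for illustration only. Keys are
    # regex patterns matched against the fully qualified prefix built via
    # add_prefix(); a value of None means "leave this module unquantized",
    # any other dict overrides the default settings for matching modules.
    DYNAMIC_OVERRIDES = {
        r"lm_head$": {"bits": 8, "group_size": 32},  # quantize lm_head with its own settings
        r"\.mlp\.gate$": None,                       # keep the MoE router in full precision
    }

    def resolve_quant_settings(prefix: str, default: dict) -> Optional[dict]:
        """Return the quant settings for a module, or None to skip quantization."""
        for pattern, override in DYNAMIC_OVERRIDES.items():
            if re.search(pattern, prefix):
                return None if override is None else {**default, **override}
        return default

    # resolve_quant_settings("model.layers.3.mlp.gate", {"bits": 4}) -> None
    # resolve_quant_settings("lm_head", {"bits": 4})                 -> {"bits": 8, "group_size": 32}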