Unverified commit 85e1a6f3, authored by Yineng Zhang, committed by GitHub

Update model_loader deps and qqq quantization deps (#2220) (#2318)


Co-authored-by: HandH1998 <1335248067@qq.com>
parent 33deca81
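
Note: the diff below applies the same two mechanical changes across the touched model files: default_weight_loader is now imported from SGLang's own sglang.srt.model_loader.weight_utils instead of vllm.model_executor.model_loader.weight_utils, and the unused cache_config (and, where present, lora_config) constructor parameters are dropped. A minimal sketch of the resulting constructor shape, using a hypothetical ExampleForCausalLM class purely for illustration:

from typing import Optional

from torch import nn

from sglang.srt.layers.quantization.base_config import QuantizationConfig
# New import path introduced by this commit (previously
# vllm.model_executor.model_loader.weight_utils).
from sglang.srt.model_loader.weight_utils import default_weight_loader


class ExampleForCausalLM(nn.Module):
    # Hypothetical class, for illustration only: after this change the model
    # constructors take just config and quant_config; cache_config is gone.
    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.quant_config = quant_config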
@@ -27,7 +27,6 @@ from vllm.distributed import (
     tensor_model_parallel_all_reduce,
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.fused_moe_triton import fused_moe
@@ -46,6 +45,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 class DeepseekMLP(nn.Module):
@@ -184,7 +184,6 @@ class DeepseekAttention(nn.Module):
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -261,7 +260,6 @@ class DeepseekDecoderLayer(nn.Module):
         self,
         config: PretrainedConfig,
         layer_id: int,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -277,7 +275,6 @@ class DeepseekDecoderLayer(nn.Module):
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
-            cache_config=cache_config,
             quant_config=quant_config,
         )
         if (
@@ -330,7 +327,6 @@ class DeepseekModel(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -343,9 +339,7 @@ class DeepseekModel(nn.Module):
         )
         self.layers = nn.ModuleList(
             [
-                DeepseekDecoderLayer(
-                    config, layer_id, cache_config, quant_config=quant_config
-                )
+                DeepseekDecoderLayer(config, layer_id, quant_config=quant_config)
                 for layer_id in range(config.num_hidden_layers)
             ]
         )
@@ -373,13 +367,12 @@ class DeepseekForCausalLM(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = DeepseekModel(config, cache_config, quant_config)
+        self.model = DeepseekModel(config, quant_config)
         self.lm_head = ParallelLMHead(
             config.vocab_size, config.hidden_size, quant_config=quant_config
         )
...
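
Note: the load_weights bodies that consume default_weight_loader are not part of this diff; the sketch below shows the usual pattern only for context (assumed, not taken from this commit):

from typing import Iterable, Tuple

import torch
from torch import nn

from sglang.srt.model_loader.weight_utils import default_weight_loader


class ExampleModel(nn.Module):
    # Illustrative only; real models also remap fused/stacked parameter names.
    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            param = params_dict[name]
            # Prefer a parameter-specific loader (attached e.g. by parallel
            # linear layers); fall back to default_weight_loader otherwise.
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)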
@@ -28,7 +28,6 @@ from vllm.distributed import (
     tensor_model_parallel_all_reduce,
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.fused_moe_triton import FusedMoE
@@ -48,6 +47,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import is_flashinfer_available
 if is_flashinfer_available():
@@ -189,7 +189,6 @@ class DeepseekV2Attention(nn.Module):
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
         layer_id=None,
     ) -> None:
@@ -337,7 +336,6 @@ class DeepseekV2AttentionMLA(nn.Module):
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
         layer_id=None,
         use_dp=False,
@@ -568,7 +566,6 @@ class DeepseekV2DecoderLayer(nn.Module):
         self,
         config: PretrainedConfig,
         layer_id: int,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -599,7 +596,6 @@ class DeepseekV2DecoderLayer(nn.Module):
                 rope_theta=rope_theta,
                 rope_scaling=rope_scaling,
                 max_position_embeddings=max_position_embeddings,
-                cache_config=cache_config,
                 quant_config=quant_config,
                 layer_id=layer_id,
                 use_dp=self.enable_dp_attention,
@@ -619,7 +615,6 @@ class DeepseekV2DecoderLayer(nn.Module):
                 rope_theta=rope_theta,
                 rope_scaling=rope_scaling,
                 max_position_embeddings=max_position_embeddings,
-                cache_config=cache_config,
                 quant_config=quant_config,
                 layer_id=layer_id,
             )
@@ -685,7 +680,6 @@ class DeepseekV2Model(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -702,7 +696,6 @@ class DeepseekV2Model(nn.Module):
                 DeepseekV2DecoderLayer(
                     config,
                     layer_id,
-                    cache_config=cache_config,
                     quant_config=quant_config,
                 )
                 for layer_id in range(config.num_hidden_layers)
@@ -733,13 +726,12 @@ class DeepseekV2ForCausalLM(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = DeepseekV2Model(config, cache_config, quant_config)
+        self.model = DeepseekV2Model(config, quant_config)
         if global_server_args_dict["enable_dp_attention"]:
             self.lm_head = ReplicatedLinear(
                 config.hidden_size,
...
@@ -22,7 +22,6 @@ import torch
 from torch import nn
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
@@ -39,6 +38,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 class ExaoneGatedMLP(nn.Module):
@@ -293,7 +293,6 @@ class ExaoneForCausalLM(nn.Module):
         self,
         config,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
...
@@ -21,10 +21,8 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.config import LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.activation import GeluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
@@ -38,6 +36,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 class GemmaMLP(nn.Module):
@@ -278,10 +277,7 @@ class GemmaForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        lora_config: Optional[LoRAConfig] = None,
-        cache_config=None,
     ) -> None:
-        del lora_config  # Unused.
         super().__init__()
         self.config = config
         self.quant_config = quant_config
...
@@ -20,12 +20,8 @@ from typing import Iterable, Optional, Set, Tuple, Union
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.config import LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-# from vllm.model_executor.layers.rotary_embedding import GemmaRotaryEmbedding
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.activation import GeluAndMul
 from sglang.srt.layers.layernorm import GemmaRMSNorm
 from sglang.srt.layers.linear import (
@@ -38,6 +34,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import make_layers
@@ -106,7 +103,6 @@ class Gemma2Attention(nn.Module):
         head_dim: int,
         max_position_embeddings: int,
         rope_theta: float,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -191,7 +187,6 @@ class Gemma2DecoderLayer(nn.Module):
         self,
         layer_id: int,
         config: PretrainedConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -205,7 +200,6 @@ class Gemma2DecoderLayer(nn.Module):
             head_dim=config.head_dim,
             max_position_embeddings=config.max_position_embeddings,
             rope_theta=config.rope_theta,
-            cache_config=cache_config,
             quant_config=quant_config,
         )
         self.hidden_size = config.hidden_size
@@ -258,7 +252,6 @@ class Gemma2Model(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -273,7 +266,6 @@ class Gemma2Model(nn.Module):
             lambda idx, prefix: Gemma2DecoderLayer(
                 layer_id=idx,
                 config=config,
-                cache_config=cache_config,
                 quant_config=quant_config,
             ),
             prefix="",
@@ -342,15 +334,12 @@ class Gemma2ForCausalLM(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
-        lora_config: Optional[LoRAConfig] = None,
     ) -> None:
-        del lora_config  # Unused.
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = Gemma2Model(config, cache_config, quant_config)
+        self.model = Gemma2Model(config, quant_config)
         self.logits_processor = LogitsProcessor(config)
     @torch.no_grad()
...
@@ -29,7 +29,6 @@ class Gemma2ForSequenceClassification(nn.Module):
         self,
         config: Gemma2Config,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
...
@@ -22,11 +22,9 @@ from typing import Iterable, List, Optional, Tuple
 import torch
 from torch import nn
 from transformers import GPT2Config
-from vllm.config import CacheConfig
 from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 # from sglang.srt.layers.activation import get_act_fn
 from sglang.srt.layers.linear import (
@@ -39,6 +37,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 class GPT2Attention(nn.Module):
@@ -47,7 +46,6 @@ class GPT2Attention(nn.Module):
         self,
         layer_id: int,
         config: GPT2Config,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ):
@@ -140,7 +138,6 @@ class GPT2Block(nn.Module):
         self,
         layer_id: int,
         config: GPT2Config,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ):
@@ -150,7 +147,7 @@ class GPT2Block(nn.Module):
         self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
         self.attn = GPT2Attention(
-            layer_id, config, cache_config, quant_config, prefix=f"{prefix}.attn"
+            layer_id, config, quant_config, prefix=f"{prefix}.attn"
         )
         self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
         self.mlp = GPT2MLP(inner_dim, config, quant_config, prefix=f"{prefix}.mlp")
@@ -182,7 +179,6 @@ class GPT2Model(nn.Module):
     def __init__(
         self,
         config: GPT2Config,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ):
@@ -196,7 +192,7 @@ class GPT2Model(nn.Module):
         self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
         self.h = nn.ModuleList(
             [
-                GPT2Block(i, config, cache_config, quant_config)
+                GPT2Block(i, config, quant_config)
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -226,15 +222,12 @@ class GPT2LMHeadModel(nn.Module):
     def __init__(
         self,
         config: GPT2Config,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.transformer = GPT2Model(
-            config, cache_config, quant_config, prefix="transformer"
-        )
+        self.transformer = GPT2Model(config, quant_config, prefix="transformer")
         self.lm_head = self.transformer.wte
         self.logits_processor = LogitsProcessor(config)
...
@@ -21,9 +21,7 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import GPTBigCodeConfig
-from vllm.config import LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.activation import get_act_fn
 from sglang.srt.layers.linear import (
@@ -36,6 +34,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 class GPTBigCodeAttention(nn.Module):
@@ -44,7 +43,6 @@ class GPTBigCodeAttention(nn.Module):
         self,
         layer_id: int,
         config: GPTBigCodeConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -145,7 +143,6 @@ class GPTBigCodeBlock(nn.Module):
         self,
         layer_id: int,
         config: GPTBigCodeConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -153,7 +150,7 @@ class GPTBigCodeBlock(nn.Module):
         inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
         self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-        self.attn = GPTBigCodeAttention(layer_id, config, cache_config, quant_config)
+        self.attn = GPTBigCodeAttention(layer_id, config, quant_config)
         self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
         self.mlp = GPTBigMLP(inner_dim, config, quant_config)
@@ -183,20 +180,14 @@ class GPTBigCodeModel(nn.Module):
     def __init__(
         self,
         config: GPTBigCodeConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
-        lora_config: Optional[LoRAConfig] = None,
     ):
         super().__init__()
         self.config = config
         assert not config.add_cross_attention
         self.embed_dim = config.hidden_size
-        lora_vocab = (
-            (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
-            if lora_config
-            else 0
-        )
+        lora_vocab = 0
         self.vocab_size = config.vocab_size + lora_vocab
         self.wte = VocabParallelEmbedding(
             self.vocab_size, self.embed_dim, org_num_embeddings=config.vocab_size
@@ -204,7 +195,7 @@ class GPTBigCodeModel(nn.Module):
         self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
         self.h = nn.ModuleList(
             [
-                GPTBigCodeBlock(i, config, cache_config, quant_config)
+                GPTBigCodeBlock(i, config, quant_config)
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -243,23 +234,16 @@ class GPTBigCodeForCausalLM(nn.Module):
     def __init__(
         self,
         config: GPTBigCodeConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
-        lora_config: Optional[LoRAConfig] = None,
     ):
         super().__init__()
         self.config = config
-        self.lora_config = lora_config
         self.quant_config = quant_config
-        self.transformer = GPTBigCodeModel(
-            config, cache_config, quant_config, lora_config
-        )
+        self.transformer = GPTBigCodeModel(config, quant_config)
         self.lm_head = self.transformer.wte
         self.unpadded_vocab_size = config.vocab_size
-        if lora_config:
-            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
         self.logits_processor = LogitsProcessor(config)
     @torch.no_grad()
...
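
Note: with the LoRA plumbing removed above, the extra-vocabulary term always evaluates to zero when no lora_config is passed, which is why the new code hard-codes lora_vocab = 0. A standalone sketch of that equivalence (values illustrative, not taken from the commit):

# Old expression, evaluated with lora_config=None as the constructor now assumes.
lora_config = None
lora_vocab_old = (
    (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
    if lora_config
    else 0
)

# New code simply hard-codes the no-LoRA case.
lora_vocab_new = 0

assert lora_vocab_old == lora_vocab_new == 0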
@@ -24,7 +24,6 @@ from torch import nn
 from transformers import PretrainedConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.fused_moe_triton import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
@@ -43,6 +42,8 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.loader import DefaultModelLoader
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 class Grok1MoE(nn.Module):
@@ -285,7 +286,6 @@ class Grok1ForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
...
@@ -21,7 +21,6 @@ from torch import nn
 from transformers import PretrainedConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
@@ -38,6 +37,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 class InternLM2MLP(nn.Module):
@@ -251,7 +251,6 @@ class InternLM2ForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
...
@@ -29,7 +29,6 @@ class InternLM2ForRewardModel(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
...
@@ -24,7 +24,6 @@ from torch import nn
 from transformers import LlamaConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
@@ -44,6 +43,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import make_layers
 from sglang.utils import get_exception_traceback
@@ -300,7 +300,6 @@ class LlamaForCausalLM(nn.Module):
         self,
         config: LlamaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
...
@@ -17,11 +17,11 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import LlamaConfig
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel
@@ -30,7 +30,6 @@ class LlamaForClassification(nn.Module):
         self,
         config: LlamaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
...
@@ -3,10 +3,10 @@ from typing import Iterable, Tuple
 import torch
 from torch import nn
 from transformers import LlamaConfig
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
 from sglang.srt.model_executor.model_runner import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaModel
@@ -15,7 +15,6 @@ class LlamaEmbeddingModel(nn.Module):
         self,
         config: LlamaConfig,
         quant_config=None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.model = LlamaModel(config, quant_config=quant_config)
...
@@ -21,6 +21,7 @@ from transformers import LlamaConfig
 from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel
@@ -29,7 +30,6 @@ class LlamaForSequenceClassification(nn.Module):
         self,
         config: LlamaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -84,9 +84,8 @@ class LlamaForSequenceClassificationWithNormal_Weights(LlamaForSequenceClassific
         self,
         config: LlamaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
-        super().__init__(config, quant_config, cache_config)
+        super().__init__(config, quant_config)
         self.weights = self.Weights(config.hidden_size, self.num_labels)
     @torch.no_grad()
...
@@ -29,7 +29,6 @@ from transformers import (
     SiglipVisionModel,
 )
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.managers.schedule_batch import ImageInputs
@@ -39,6 +38,7 @@ from sglang.srt.mm_utils import (
     unpad_image_shape,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM
 from sglang.srt.models.mistral import MistralForCausalLM
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
@@ -451,7 +451,6 @@ class LlavaLlamaForCausalLM(LlavaBaseForCausalLM):
         self,
         config: LlavaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
@@ -473,7 +472,6 @@ class LlavaQwenForCausalLM(LlavaBaseForCausalLM):
         self,
         config: LlavaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
@@ -506,7 +504,6 @@ class LlavaMistralForCausalLM(LlavaBaseForCausalLM):
         self,
         config: LlavaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
...
@@ -20,11 +20,11 @@ import torch
 from torch import nn
 from transformers import CLIPVisionModel, LlavaConfig
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.managers.schedule_batch import ImageInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM
@@ -33,7 +33,6 @@ class LlavaVidForCausalLM(nn.Module):
         self,
         config: LlavaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
...
@@ -20,7 +20,6 @@ import torch
 from torch import nn
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
@@ -37,6 +36,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 class MiniCPMMLP(nn.Module):
@@ -275,7 +275,6 @@ class MiniCPMForCausalLM(nn.Module):
         self,
         config,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
...
@@ -27,7 +27,6 @@ from vllm.model_executor.layers.linear import (
     RowParallelLinear,
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
@@ -40,6 +39,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import is_flashinfer_available
 if is_flashinfer_available():
@@ -105,7 +105,6 @@ class MiniCPM3Attention(nn.Module):
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
         layer_id=None,
     ) -> None:
@@ -249,7 +248,6 @@ class MiniCPM3AttentionMLA(nn.Module):
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
         layer_id=None,
     ) -> None:
@@ -406,7 +404,6 @@ class MiniCPM3DecoderLayer(nn.Module):
         self,
         config: PretrainedConfig,
         layer_id: int,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -430,7 +427,6 @@ class MiniCPM3DecoderLayer(nn.Module):
                 rope_theta=rope_theta,
                 rope_scaling=rope_scaling,
                 max_position_embeddings=max_position_embeddings,
-                cache_config=cache_config,
                 quant_config=quant_config,
                 layer_id=layer_id,
             )
@@ -449,7 +445,6 @@ class MiniCPM3DecoderLayer(nn.Module):
                 rope_theta=rope_theta,
                 rope_scaling=rope_scaling,
                 max_position_embeddings=max_position_embeddings,
-                cache_config=cache_config,
                 quant_config=quant_config,
                 layer_id=layer_id,
             )
@@ -498,7 +493,6 @@ class MiniCPM3Model(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -512,9 +506,7 @@ class MiniCPM3Model(nn.Module):
         )
         self.layers = nn.ModuleList(
             [
-                MiniCPM3DecoderLayer(
-                    config, i, cache_config=cache_config, quant_config=quant_config
-                )
+                MiniCPM3DecoderLayer(config, i, quant_config=quant_config)
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -549,7 +541,6 @@ class MiniCPM3ForCausalLM(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -557,9 +548,7 @@ class MiniCPM3ForCausalLM(nn.Module):
         self.num_experts = getattr(self.config, "num_experts", 0)
         self.quant_config = quant_config
-        self.model = MiniCPM3Model(
-            config, cache_config=cache_config, quant_config=quant_config
-        )
+        self.model = MiniCPM3Model(config, quant_config=quant_config)
         # self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         if not self.config.tie_word_embeddings:
             self.lm_head = ParallelLMHead(
...
@@ -23,7 +23,6 @@ from torch import nn
 from transformers import MixtralConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.fused_moe_triton import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
@@ -42,6 +41,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 class MixtralMoE(nn.Module):
@@ -291,7 +291,6 @@ class MixtralForCausalLM(nn.Module):
         self,
         config: MixtralConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
...