Unverified Commit f89d18ff authored by youkaichao's avatar youkaichao Committed by GitHub
Browse files

[6/N] pass whole config to inner model (#10205)


Signed-off-by: default avataryoukaichao <youkaichao@gmail.com>
parent f0f2e563
...@@ -25,7 +25,7 @@ from transformers import GPTBigCodeConfig ...@@ -25,7 +25,7 @@ from transformers import GPTBigCodeConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...@@ -189,15 +189,14 @@ class GPTBigCodeBlock(nn.Module): ...@@ -189,15 +189,14 @@ class GPTBigCodeBlock(nn.Module):
@support_torch_compile @support_torch_compile
class GPTBigCodeModel(nn.Module): class GPTBigCodeModel(nn.Module):
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
config: GPTBigCodeConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
prefix: str = "",
):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config self.config = config
assert not config.add_cross_attention assert not config.add_cross_attention
...@@ -265,7 +264,6 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -265,7 +264,6 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
): ):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config lora_config = vllm_config.lora_config
...@@ -273,8 +271,8 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -273,8 +271,8 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.lora_config = lora_config self.lora_config = lora_config
self.quant_config = quant_config self.quant_config = quant_config
self.transformer = GPTBigCodeModel(config, cache_config, quant_config, self.transformer = GPTBigCodeModel(vllm_config=vllm_config,
lora_config) prefix=prefix)
if self.config.tie_word_embeddings: if self.config.tie_word_embeddings:
self.lm_head = self.transformer.wte self.lm_head = self.transformer.wte
else: else:
......
...@@ -42,7 +42,8 @@ from vllm.sequence import IntermediateTensors ...@@ -42,7 +42,8 @@ from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter, from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers) make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class GPTJAttention(nn.Module): class GPTJAttention(nn.Module):
...@@ -177,14 +178,13 @@ class GPTJBlock(nn.Module): ...@@ -177,14 +178,13 @@ class GPTJBlock(nn.Module):
@support_torch_compile @support_torch_compile
class GPTJModel(nn.Module): class GPTJModel(nn.Module):
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
config: GPTJConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config self.config = config
self.embed_dim = config.n_embd self.embed_dim = config.n_embd
self.wte = VocabParallelEmbedding( self.wte = VocabParallelEmbedding(
...@@ -236,12 +236,13 @@ class GPTJForCausalLM(nn.Module, SupportsPP): ...@@ -236,12 +236,13 @@ class GPTJForCausalLM(nn.Module, SupportsPP):
): ):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
assert not config.tie_word_embeddings assert not config.tie_word_embeddings
self.transformer = GPTJModel(config, cache_config, quant_config) self.transformer = GPTJModel(vllm_config=vllm_config,
prefix=maybe_prefix(
prefix, "transformer"))
self.lm_head = ParallelLMHead( self.lm_head = ParallelLMHead(
config.vocab_size, config.vocab_size,
config.n_embd, config.n_embd,
......
...@@ -41,7 +41,8 @@ from vllm.sequence import IntermediateTensors ...@@ -41,7 +41,8 @@ from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter, from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers) make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class GPTNeoXAttention(nn.Module): class GPTNeoXAttention(nn.Module):
...@@ -189,14 +190,13 @@ class GPTNeoXLayer(nn.Module): ...@@ -189,14 +190,13 @@ class GPTNeoXLayer(nn.Module):
@support_torch_compile @support_torch_compile
class GPTNeoXModel(nn.Module): class GPTNeoXModel(nn.Module):
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
config: GPTNeoXConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config self.config = config
self.embed_in = VocabParallelEmbedding( self.embed_in = VocabParallelEmbedding(
...@@ -249,11 +249,11 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP): ...@@ -249,11 +249,11 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP):
): ):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.gpt_neox = GPTNeoXModel(config, cache_config, quant_config) self.gpt_neox = GPTNeoXModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "gpt_neox"))
self.embed_out = ParallelLMHead( self.embed_out = ParallelLMHead(
config.vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
......
...@@ -28,7 +28,7 @@ from transformers import GraniteConfig ...@@ -28,7 +28,7 @@ from transformers import GraniteConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
...@@ -52,7 +52,8 @@ from vllm.platforms import current_platform ...@@ -52,7 +52,8 @@ from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers from .utils import (PPMissingLayer, is_pp_missing_parameter, make_layers,
maybe_prefix)
class GraniteMLP(nn.Module): class GraniteMLP(nn.Module):
...@@ -257,15 +258,14 @@ class GraniteDecoderLayer(nn.Module): ...@@ -257,15 +258,14 @@ class GraniteDecoderLayer(nn.Module):
@support_torch_compile @support_torch_compile
class GraniteModel(nn.Module): class GraniteModel(nn.Module):
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
config: GraniteConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config self.config = config
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
lora_vocab = (lora_config.lora_extra_vocab_size * lora_vocab = (lora_config.lora_extra_vocab_size *
...@@ -370,25 +370,17 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -370,25 +370,17 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"up_proj": ("gate_up_proj", 1), "up_proj": ("gate_up_proj", 1),
} }
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config lora_config = vllm_config.lora_config
self.config = config self.config = config
self.lora_config = lora_config self.lora_config = lora_config
self.model = GraniteModel(config, self.model = GraniteModel(vllm_config=vllm_config,
cache_config, prefix=maybe_prefix(prefix, "model"))
quant_config,
lora_config=lora_config,
prefix="model")
if get_pp_group().is_last_rank: if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size self.unpadded_vocab_size = config.vocab_size
if lora_config: if lora_config:
......
...@@ -28,7 +28,7 @@ from transformers.models.granitemoe import GraniteMoeConfig ...@@ -28,7 +28,7 @@ from transformers.models.granitemoe import GraniteMoeConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
...@@ -47,7 +47,7 @@ from vllm.sequence import IntermediateTensors ...@@ -47,7 +47,7 @@ from vllm.sequence import IntermediateTensors
from . import mixtral from . import mixtral
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
from .utils import make_layers from .utils import make_layers, maybe_prefix
class GraniteMoeMoE(nn.Module): class GraniteMoeMoE(nn.Module):
...@@ -247,15 +247,14 @@ class GraniteMoeDecoderLayer(nn.Module): ...@@ -247,15 +247,14 @@ class GraniteMoeDecoderLayer(nn.Module):
@support_torch_compile @support_torch_compile
class GraniteMoeModel(nn.Module): class GraniteMoeModel(nn.Module):
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
config: GraniteMoeConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
lora_vocab = (lora_config.lora_extra_vocab_size * lora_vocab = (lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0 (lora_config.max_loras or 1)) if lora_config else 0
...@@ -333,25 +332,17 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -333,25 +332,17 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
} }
embedding_padding_modules = ["lm_head"] embedding_padding_modules = ["lm_head"]
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config lora_config = vllm_config.lora_config
self.config = config self.config = config
self.lora_config = lora_config self.lora_config = lora_config
self.model = GraniteMoeModel(config, self.model = GraniteMoeModel(vllm_config=vllm_config,
cache_config, prefix=maybe_prefix(prefix, "model"))
quant_config,
lora_config=lora_config,
prefix="model")
self.unpadded_vocab_size = config.vocab_size self.unpadded_vocab_size = config.vocab_size
if lora_config: if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
......
...@@ -22,17 +22,15 @@ import torch.utils.checkpoint ...@@ -22,17 +22,15 @@ import torch.utils.checkpoint
from PIL import Image from PIL import Image
from torch import nn from torch import nn
# Temporary solution for transformers below 4.46.0. # Temporary solution for transformers below 4.46.0.
from transformers import PretrainedConfig as Idefics3Config
from transformers import ProcessorMixin as Idefics3ImageProcessor from transformers import ProcessorMixin as Idefics3ImageProcessor
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, VllmConfig from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs) InputContext, token_inputs)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
...@@ -48,7 +46,8 @@ from .idefics2_vision_model import ( ...@@ -48,7 +46,8 @@ from .idefics2_vision_model import (
# yapf: enable # yapf: enable
from .interfaces import SupportsMultiModal from .interfaces import SupportsMultiModal
from .llama import LlamaModel from .llama import LlamaModel
from .utils import AutoWeightsLoader, flatten_bn, merge_multimodal_embeddings from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
merge_multimodal_embeddings)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -417,13 +416,13 @@ class Idefics3Connector(nn.Module): ...@@ -417,13 +416,13 @@ class Idefics3Connector(nn.Module):
class Idefics3Model(nn.Module): class Idefics3Model(nn.Module):
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
config: Idefics3Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config self.config = config
self.padding_idx = self.config.text_config.pad_token_id self.padding_idx = self.config.text_config.pad_token_id
self.vocab_size = self.config.text_config.vocab_size self.vocab_size = self.config.text_config.vocab_size
...@@ -613,22 +612,18 @@ class Idefics3Model(nn.Module): ...@@ -613,22 +612,18 @@ class Idefics3Model(nn.Module):
@INPUT_REGISTRY.register_input_processor(input_processor_for_idefics3) @INPUT_REGISTRY.register_input_processor(input_processor_for_idefics3)
class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal): class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal):
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config multimodal_config = vllm_config.model_config.multimodal_config
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
self.model = Idefics3Model(config, cache_config, quant_config) self.model = Idefics3Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.image_token_id = self.config.image_token_id self.image_token_id = self.config.image_token_id
self.lm_head = ParallelLMHead( self.lm_head = ParallelLMHead(
......
...@@ -250,14 +250,13 @@ class InternLMDecoderLayer(nn.Module): ...@@ -250,14 +250,13 @@ class InternLMDecoderLayer(nn.Module):
@support_torch_compile @support_torch_compile
class InternLM2Model(nn.Module): class InternLM2Model(nn.Module):
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config self.config = config
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
...@@ -317,20 +316,13 @@ class InternLM2Model(nn.Module): ...@@ -317,20 +316,13 @@ class InternLM2Model(nn.Module):
class InternLM2ForCausalLM(nn.Module, SupportsPP): class InternLM2ForCausalLM(nn.Module, SupportsPP):
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.model = InternLM2Model(config, self.model = InternLM2Model(vllm_config=vllm_config,
cache_config,
quant_config,
prefix=maybe_prefix(prefix, "model")) prefix=maybe_prefix(prefix, "model"))
self.output = ParallelLMHead(config.vocab_size, self.output = ParallelLMHead(config.vocab_size,
config.hidden_size, config.hidden_size,
......
...@@ -104,14 +104,13 @@ class InternLM2VEDecoderLayer(nn.Module): ...@@ -104,14 +104,13 @@ class InternLM2VEDecoderLayer(nn.Module):
class InternLM2VEModel(InternLM2Model): class InternLM2VEModel(InternLM2Model):
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self, super().__init__(vllm_config=vllm_config, prefix=prefix)
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None, config = vllm_config.model_config.hf_config
quant_config: Optional[QuantizationConfig] = None, cache_config = vllm_config.cache_config
prefix: str = "", quant_config = vllm_config.quant_config
) -> None:
super().__init__(config, cache_config, quant_config)
self.start_layer, self.end_layer, self.layers = make_layers( self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
lambda prefix: InternLM2VEDecoderLayer( lambda prefix: InternLM2VEDecoderLayer(
...@@ -159,12 +158,8 @@ class InternLM2VEModel(InternLM2Model): ...@@ -159,12 +158,8 @@ class InternLM2VEModel(InternLM2Model):
class InternLM2VEForCausalLM(InternLM2ForCausalLM): class InternLM2VEForCausalLM(InternLM2ForCausalLM):
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self, super().__init__(vllm_config=vllm_config, prefix=prefix)
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
super().__init__(vllm_config, prefix=prefix)
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config cache_config = vllm_config.cache_config
......
...@@ -35,7 +35,7 @@ from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip, ...@@ -35,7 +35,7 @@ from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
get_clip_num_patches) get_clip_num_patches)
from .interfaces import SupportsMultiModal, SupportsPP from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
merge_multimodal_embeddings) maybe_prefix, merge_multimodal_embeddings)
IMG_START = '<img>' IMG_START = '<img>'
IMG_END = '</img>' IMG_END = '</img>'
...@@ -435,13 +435,13 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -435,13 +435,13 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
config, config,
quant_config=quant_config, quant_config=quant_config,
is_mono=self.is_mono, is_mono=self.is_mono,
prefix="vision_model", prefix=maybe_prefix(prefix, "vision_model"),
) )
self.language_model = init_vllm_registered_model( self.language_model = init_vllm_registered_model(
config.text_config, config.text_config,
vllm_config=vllm_config, vllm_config=vllm_config,
prefix="language_model") prefix=maybe_prefix(prefix, "language_model"))
self.mlp1 = self._init_mlp1(config) self.mlp1 = self._init_mlp1(config)
......
...@@ -44,7 +44,8 @@ from vllm.transformers_utils.configs import JAISConfig ...@@ -44,7 +44,8 @@ from vllm.transformers_utils.configs import JAISConfig
from .interfaces import SupportsPP from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter, from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers) make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class SwiGLUActivation(nn.Module): class SwiGLUActivation(nn.Module):
...@@ -215,14 +216,13 @@ class JAISBlock(nn.Module): ...@@ -215,14 +216,13 @@ class JAISBlock(nn.Module):
@support_torch_compile @support_torch_compile
class JAISModel(nn.Module): class JAISModel(nn.Module):
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
config: JAISConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config self.config = config
assert not config.add_cross_attention assert not config.add_cross_attention
assert not config.scale_attn_by_inverse_layer_idx assert not config.scale_attn_by_inverse_layer_idx
...@@ -293,11 +293,12 @@ class JAISLMHeadModel(nn.Module, SupportsPP): ...@@ -293,11 +293,12 @@ class JAISLMHeadModel(nn.Module, SupportsPP):
): ):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.transformer = JAISModel(config, cache_config, quant_config) self.transformer = JAISModel(vllm_config=vllm_config,
prefix=maybe_prefix(
prefix, "transformer"))
if self.config.tie_word_embeddings: if self.config.tie_word_embeddings:
self.lm_head = self.transformer.wte self.lm_head = self.transformer.wte
else: else:
......
...@@ -7,7 +7,7 @@ from transformers import JambaConfig ...@@ -7,7 +7,7 @@ from transformers import JambaConfig
from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.config import CacheConfig, LoRAConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
...@@ -29,6 +29,7 @@ from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE, ...@@ -29,6 +29,7 @@ from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE,
_get_graph_batch_size) _get_graph_batch_size)
from .interfaces import HasInnerState, SupportsLoRA from .interfaces import HasInnerState, SupportsLoRA
from .utils import maybe_prefix
KVCache = Tuple[torch.Tensor, torch.Tensor] KVCache = Tuple[torch.Tensor, torch.Tensor]
...@@ -258,14 +259,14 @@ ALL_DECODER_LAYER_TYPES = { ...@@ -258,14 +259,14 @@ ALL_DECODER_LAYER_TYPES = {
class JambaModel(nn.Module): class JambaModel(nn.Module):
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
config: JambaConfig,
quant_config: Optional[QuantizationConfig] = None,
cache_config: Optional[CacheConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config self.config = config
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
lora_vocab = ((lora_config.lora_extra_vocab_size * lora_vocab = ((lora_config.lora_extra_vocab_size *
...@@ -348,14 +349,9 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA): ...@@ -348,14 +349,9 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
} }
embedding_padding_modules = ["lm_head"] embedding_padding_modules = ["lm_head"]
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config lora_config = vllm_config.lora_config
scheduler_config = vllm_config.scheduler_config scheduler_config = vllm_config.scheduler_config
assert not cache_config.enable_prefix_caching, \ assert not cache_config.enable_prefix_caching, \
...@@ -364,10 +360,8 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA): ...@@ -364,10 +360,8 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
super().__init__() super().__init__()
self.config = config self.config = config
self.scheduler_config = scheduler_config self.scheduler_config = scheduler_config
self.model = JambaModel(config, self.model = JambaModel(vllm_config=vllm_config,
cache_config=cache_config, prefix=maybe_prefix(prefix, "model"))
quant_config=quant_config,
lora_config=lora_config)
self.unpadded_vocab_size = config.vocab_size self.unpadded_vocab_size = config.vocab_size
if lora_config: if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
......
...@@ -28,7 +28,7 @@ from transformers import LlamaConfig ...@@ -28,7 +28,7 @@ from transformers import LlamaConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
...@@ -271,15 +271,14 @@ class LlamaDecoderLayer(nn.Module): ...@@ -271,15 +271,14 @@ class LlamaDecoderLayer(nn.Module):
@support_torch_compile @support_torch_compile
class LlamaModel(nn.Module): class LlamaModel(nn.Module):
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
config: LlamaConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config self.config = config
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
lora_vocab = (lora_config.lora_extra_vocab_size * lora_vocab = (lora_config.lora_extra_vocab_size *
...@@ -492,24 +491,16 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -492,24 +491,16 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"norm": "model.norm" "norm": "model.norm"
} }
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config lora_config = vllm_config.lora_config
pooler_config = vllm_config.model_config.pooler_config pooler_config = vllm_config.model_config.pooler_config
self.config = config self.config = config
self.lora_config = lora_config self.lora_config = lora_config
self.model = LlamaModel(config, self.model = LlamaModel(vllm_config=vllm_config,
cache_config,
quant_config,
lora_config=lora_config,
prefix=maybe_prefix(prefix, "model")) prefix=maybe_prefix(prefix, "model"))
if get_pp_group().is_last_rank: if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size self.unpadded_vocab_size = config.vocab_size
...@@ -652,23 +643,12 @@ class LlamaEmbeddingModel(nn.Module, SupportsLoRA, SupportsPP): ...@@ -652,23 +643,12 @@ class LlamaEmbeddingModel(nn.Module, SupportsLoRA, SupportsPP):
} }
embedding_padding_modules = [] embedding_padding_modules = []
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
pooler_config = vllm_config.model_config.pooler_config pooler_config = vllm_config.model_config.pooler_config
self.model = LlamaModel(config, self.model = LlamaModel(vllm_config=vllm_config,
cache_config,
quant_config,
lora_config,
prefix=maybe_prefix(prefix, "model")) prefix=maybe_prefix(prefix, "model"))
self._pooler = Pooler.from_config_with_defaults( self._pooler = Pooler.from_config_with_defaults(
pooler_config, pooler_config,
......
...@@ -32,7 +32,7 @@ from .siglip import (SiglipVisionModel, dummy_image_for_siglip, ...@@ -32,7 +32,7 @@ from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip, get_max_siglip_image_tokens, dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
input_processor_for_siglip) input_processor_for_siglip)
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
merge_multimodal_embeddings) maybe_prefix, merge_multimodal_embeddings)
class LlavaImagePixelInputs(TypedDict): class LlavaImagePixelInputs(TypedDict):
...@@ -282,7 +282,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -282,7 +282,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
config, config,
quant_config, quant_config,
require_post_norm=False, require_post_norm=False,
prefix="vision_tower") prefix=maybe_prefix(prefix, "vision_tower"))
self.multi_modal_projector = LlavaMultiModalProjector( self.multi_modal_projector = LlavaMultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size, vision_hidden_size=config.vision_config.hidden_size,
text_hidden_size=config.text_config.hidden_size, text_hidden_size=config.text_config.hidden_size,
...@@ -291,7 +291,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -291,7 +291,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
self.language_model = init_vllm_registered_model( self.language_model = init_vllm_registered_model(
config.text_config, config.text_config,
vllm_config=vllm_config, vllm_config=vllm_config,
prefix="language_model") prefix=maybe_prefix(prefix, "language_model"))
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors) self.language_model.make_empty_intermediate_tensors)
......
...@@ -31,7 +31,7 @@ from .siglip import (SiglipVisionModel, dummy_image_for_siglip, ...@@ -31,7 +31,7 @@ from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip, get_siglip_image_feature_size, dummy_seq_data_for_siglip, get_siglip_image_feature_size,
get_siglip_patch_grid_length, input_processor_for_siglip) get_siglip_patch_grid_length, input_processor_for_siglip)
from .utils import (AutoWeightsLoader, embed_multimodal, flatten_bn, from .utils import (AutoWeightsLoader, embed_multimodal, flatten_bn,
init_vllm_registered_model) init_vllm_registered_model, maybe_prefix)
class LlavaNextImagePixelInputs(TypedDict): class LlavaNextImagePixelInputs(TypedDict):
...@@ -296,7 +296,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -296,7 +296,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
config, config,
quant_config, quant_config,
require_post_norm=False, require_post_norm=False,
prefix="vision_tower") prefix=maybe_prefix(prefix, "vision_tower"))
self.image_newline = nn.Parameter( self.image_newline = nn.Parameter(
torch.empty(config.text_config.hidden_size)) torch.empty(config.text_config.hidden_size))
self.multi_modal_projector = LlavaMultiModalProjector( self.multi_modal_projector = LlavaMultiModalProjector(
...@@ -307,7 +307,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -307,7 +307,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
self.language_model = init_vllm_registered_model( self.language_model = init_vllm_registered_model(
config.text_config, config.text_config,
vllm_config=vllm_config, vllm_config=vllm_config,
prefix="language_model") prefix=maybe_prefix(prefix, "language_model"))
# The same model class supports both language generation and embedding # The same model class supports both language generation and embedding
# because the architecture name is the same # because the architecture name is the same
......
...@@ -29,7 +29,7 @@ from .llava import init_vision_tower_for_llava ...@@ -29,7 +29,7 @@ from .llava import init_vision_tower_for_llava
from .siglip import (SiglipVisionModel, dummy_image_for_siglip, from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip) dummy_seq_data_for_siglip)
from .utils import (AutoWeightsLoader, init_vllm_registered_model, from .utils import (AutoWeightsLoader, init_vllm_registered_model,
merge_multimodal_embeddings) maybe_prefix, merge_multimodal_embeddings)
# For profile run # For profile run
_MAX_FRAMES_PER_VIDEO = 32 _MAX_FRAMES_PER_VIDEO = 32
...@@ -267,7 +267,7 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -267,7 +267,7 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
config, config,
quant_config, quant_config,
require_post_norm=False, require_post_norm=False,
prefix="vision_tower") prefix=maybe_prefix(prefix, "vision_tower"))
self.vision_resampler = LlavaNextVideoPooler(config) self.vision_resampler = LlavaNextVideoPooler(config)
self.multi_modal_projector = LlavaNextMultiModalProjector( self.multi_modal_projector = LlavaNextMultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size, vision_hidden_size=config.vision_config.hidden_size,
...@@ -276,7 +276,7 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -276,7 +276,7 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
self.language_model = init_vllm_registered_model( self.language_model = init_vllm_registered_model(
config.text_config, config.text_config,
vllm_config=vllm_config, vllm_config=vllm_config,
prefix="language_model") prefix=maybe_prefix(prefix, "language_model"))
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.model.make_empty_intermediate_tensors) self.language_model.model.make_empty_intermediate_tensors)
......
...@@ -35,7 +35,7 @@ from .siglip import (SiglipVisionModel, dummy_seq_data_for_siglip, ...@@ -35,7 +35,7 @@ from .siglip import (SiglipVisionModel, dummy_seq_data_for_siglip,
dummy_video_for_siglip, get_siglip_image_feature_size, dummy_video_for_siglip, get_siglip_image_feature_size,
get_siglip_patch_grid_length, input_processor_for_siglip) get_siglip_patch_grid_length, input_processor_for_siglip)
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
merge_multimodal_embeddings) maybe_prefix, merge_multimodal_embeddings)
# Result in the max possible feature size (2x2 grid of 336x336px tiles) # Result in the max possible feature size (2x2 grid of 336x336px tiles)
MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448 MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448
...@@ -418,12 +418,12 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -418,12 +418,12 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
config, config,
quant_config, quant_config,
require_post_norm=False, require_post_norm=False,
prefix="vision_tower") prefix=maybe_prefix(prefix, "vision_tower"))
self.multi_modal_projector = LlavaOnevisionMultiModalProjector(config) self.multi_modal_projector = LlavaOnevisionMultiModalProjector(config)
self.language_model = init_vllm_registered_model( self.language_model = init_vllm_registered_model(
config.text_config, config.text_config,
vllm_config=vllm_config, vllm_config=vllm_config,
prefix="language_model") prefix=maybe_prefix(prefix, "language_model"))
self.image_newline = nn.Parameter( self.image_newline = nn.Parameter(
torch.empty(config.text_config.hidden_size)) torch.empty(config.text_config.hidden_size))
......
...@@ -6,7 +6,7 @@ from torch import nn ...@@ -6,7 +6,7 @@ from torch import nn
from transformers import MambaConfig from transformers import MambaConfig
from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
...@@ -26,6 +26,8 @@ from vllm.sequence import IntermediateTensors ...@@ -26,6 +26,8 @@ from vllm.sequence import IntermediateTensors
from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE, from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE,
_get_graph_batch_size) _get_graph_batch_size)
from .utils import maybe_prefix
KVCache = Tuple[torch.Tensor, torch.Tensor] KVCache = Tuple[torch.Tensor, torch.Tensor]
...@@ -73,14 +75,14 @@ class MambaDecoderLayer(nn.Module): ...@@ -73,14 +75,14 @@ class MambaDecoderLayer(nn.Module):
class MambaModel(nn.Module): class MambaModel(nn.Module):
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
config: MambaConfig,
quant_config: Optional[QuantizationConfig] = None,
cache_config: Optional[CacheConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config self.config = config
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
lora_vocab = ((lora_config.lora_extra_vocab_size * lora_vocab = ((lora_config.lora_extra_vocab_size *
...@@ -130,14 +132,9 @@ class MambaModel(nn.Module): ...@@ -130,14 +132,9 @@ class MambaModel(nn.Module):
class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree): class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config lora_config = vllm_config.lora_config
scheduler_config = vllm_config.scheduler_config scheduler_config = vllm_config.scheduler_config
assert not cache_config.enable_prefix_caching, \ assert not cache_config.enable_prefix_caching, \
...@@ -146,10 +143,8 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree): ...@@ -146,10 +143,8 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
super().__init__() super().__init__()
self.config = config self.config = config
self.scheduler_config = scheduler_config self.scheduler_config = scheduler_config
self.backbone = MambaModel(config, self.backbone = MambaModel(vllm_config=vllm_config,
cache_config=cache_config, prefix=maybe_prefix(prefix, "backbone"))
quant_config=quant_config,
lora_config=lora_config)
self.unpadded_vocab_size = config.vocab_size self.unpadded_vocab_size = config.vocab_size
if lora_config: if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
......
...@@ -29,7 +29,7 @@ from transformers import PretrainedConfig ...@@ -29,7 +29,7 @@ from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
...@@ -53,7 +53,8 @@ from vllm.sequence import IntermediateTensors ...@@ -53,7 +53,8 @@ from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter, from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers) make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class MiniCPMMoE(nn.Module): class MiniCPMMoE(nn.Module):
...@@ -351,15 +352,14 @@ class MiniCPMDecoderLayer(nn.Module): ...@@ -351,15 +352,14 @@ class MiniCPMDecoderLayer(nn.Module):
@support_torch_compile @support_torch_compile
class MiniCPMModel(nn.Module): class MiniCPMModel(nn.Module):
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config self.config = config
self.cache_config = cache_config self.cache_config = cache_config
self.quant_config = quant_config self.quant_config = quant_config
...@@ -461,24 +461,22 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -461,24 +461,22 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
} }
embedding_padding_modules = ["lm_head"] embedding_padding_modules = ["lm_head"]
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config lora_config = vllm_config.lora_config
self.prefix = prefix
self.vllm_config = vllm_config
self.config = config self.config = config
self.lora_config = lora_config self.lora_config = lora_config
self.cache_config = cache_config self.cache_config = cache_config
self.quant_config = quant_config self.quant_config = quant_config
self.num_experts = getattr(self.config, "num_experts", 0) self.num_experts = getattr(self.config, "num_experts", 0)
self._init_model() self._init_model(vllm_config=vllm_config, prefix=prefix)
unpadded_vocab_size = config.vocab_size unpadded_vocab_size = config.vocab_size
if lora_config: if lora_config:
unpadded_vocab_size += lora_config.lora_extra_vocab_size unpadded_vocab_size += lora_config.lora_extra_vocab_size
...@@ -502,11 +500,9 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -502,11 +500,9 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.model.make_empty_intermediate_tensors)
def _init_model(self): def _init_model(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.model = MiniCPMModel(config=self.config, self.model = MiniCPMModel(vllm_config=vllm_config,
cache_config=self.cache_config, prefix=maybe_prefix(prefix, "model"))
quant_config=self.quant_config,
lora_config=self.lora_config)
def forward( def forward(
self, self,
......
...@@ -28,7 +28,7 @@ from torch import nn ...@@ -28,7 +28,7 @@ from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...@@ -40,7 +40,7 @@ from vllm.model_executor.models.minicpm import (MiniCPMDecoderLayer, ...@@ -40,7 +40,7 @@ from vllm.model_executor.models.minicpm import (MiniCPMDecoderLayer,
MiniCPMForCausalLM, MiniCPMForCausalLM,
MiniCPMModel) MiniCPMModel)
from .utils import make_layers from .utils import make_layers, maybe_prefix
class MiniCPM3Attention(nn.Module): class MiniCPM3Attention(nn.Module):
...@@ -238,8 +238,6 @@ class MiniCPM3ForCausalLM(MiniCPMForCausalLM): ...@@ -238,8 +238,6 @@ class MiniCPM3ForCausalLM(MiniCPMForCausalLM):
# `embedding_modules` and `embedding_padding_modules` # `embedding_modules` and `embedding_padding_modules`
# are inherited from MiniCPMForCausalLM # are inherited from MiniCPMForCausalLM
def _init_model(self): def _init_model(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.model = MiniCPM3Model(config=self.config, self.model = MiniCPM3Model(vllm_config=vllm_config,
cache_config=self.cache_config, prefix=maybe_prefix(prefix, "model"))
quant_config=self.quant_config,
lora_config=self.lora_config)
...@@ -34,7 +34,7 @@ from transformers import PretrainedConfig ...@@ -34,7 +34,7 @@ from transformers import PretrainedConfig
from typing_extensions import NotRequired from typing_extensions import NotRequired
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, VllmConfig from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs) InputContext, token_inputs)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
...@@ -59,7 +59,7 @@ from vllm.sequence import IntermediateTensors, SequenceData ...@@ -59,7 +59,7 @@ from vllm.sequence import IntermediateTensors, SequenceData
from .idefics2_vision_model import Idefics2VisionTransformer from .idefics2_vision_model import Idefics2VisionTransformer
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
from .utils import is_pp_missing_parameter from .utils import is_pp_missing_parameter, maybe_prefix
_KEYS_TO_MODIFY_MAPPING = { _KEYS_TO_MODIFY_MAPPING = {
"llm.lm_head": "lm_head", "llm.lm_head": "lm_head",
...@@ -390,7 +390,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -390,7 +390,6 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
): ):
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
multimodal_config = vllm_config.model_config.multimodal_config multimodal_config = vllm_config.model_config.multimodal_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
super().__init__() super().__init__()
# All MiniCPM-V models disable `tie_word_embeddings` but # All MiniCPM-V models disable `tie_word_embeddings` but
...@@ -401,11 +400,11 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -401,11 +400,11 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
self.version = get_version_by_config(self.config) self.version = get_version_by_config(self.config)
self.llm = self.init_llm(config, self.llm = self.init_llm(vllm_config=vllm_config,
cache_config, prefix=maybe_prefix(prefix, "llm"))
quant_config, self.vpm = self.init_vision_module(config,
prefix="llm") quant_config,
self.vpm = self.init_vision_module(config, quant_config, prefix="vpm") prefix=maybe_prefix(prefix, "vpm"))
param_dtype = torch.get_default_dtype() param_dtype = torch.get_default_dtype()
self.vpm.to(dtype=param_dtype) self.vpm.to(dtype=param_dtype)
self.vision_dim = (self.vpm.embed_dim if self.version == (2, 0) else self.vision_dim = (self.vpm.embed_dim if self.version == (2, 0) else
...@@ -414,13 +413,15 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -414,13 +413,15 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
self.resampler = self.init_resampler(self.embed_dim, self.resampler = self.init_resampler(self.embed_dim,
self.vision_dim, self.vision_dim,
quant_config=quant_config, quant_config=quant_config,
prefix="resampler") prefix=maybe_prefix(
prefix, "resampler"))
self.resampler.to(device="cuda", dtype=param_dtype) self.resampler.to(device="cuda", dtype=param_dtype)
# TODO: why is there _KEYS_TO_MODIFY_MAPPING? lm_head should be in llm # TODO: why is there _KEYS_TO_MODIFY_MAPPING? lm_head should be in llm
self.lm_head = ParallelLMHead(config.vocab_size, self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size, config.hidden_size,
quant_config=quant_config, quant_config=quant_config,
prefix="llm.lm_head") prefix=maybe_prefix(
prefix, "llm.lm_head"))
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler() self.sampler = get_sampler()
...@@ -661,9 +662,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -661,9 +662,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
def init_llm( def init_llm(
self, self,
config: PretrainedConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "", prefix: str = "",
) -> nn.Module: ) -> nn.Module:
raise NotImplementedError raise NotImplementedError
...@@ -711,16 +710,10 @@ class MiniCPMV2_0(MiniCPMVBaseModel): ...@@ -711,16 +710,10 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
def init_llm( def init_llm(
self, self,
config: PretrainedConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "", prefix: str = "",
) -> nn.Module: ) -> nn.Module:
return LLMWrapper(MiniCPMModel(vllm_config=vllm_config, prefix=prefix),
return LLMWrapper(MiniCPMModel(config,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix),
name="model") name="model")
def init_vision_module( def init_vision_module(
...@@ -875,15 +868,10 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA): ...@@ -875,15 +868,10 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
def init_llm( def init_llm(
self, self,
config: PretrainedConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "", prefix: str = "",
) -> nn.Module: ) -> nn.Module:
return LLMWrapper(LlamaModel(config, return LLMWrapper(LlamaModel(vllm_config=vllm_config, prefix=prefix),
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix),
name="model") name="model")
def init_vision_module( def init_vision_module(
...@@ -1022,16 +1010,10 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA): ...@@ -1022,16 +1010,10 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
def init_llm( def init_llm(
self, self,
config: PretrainedConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "", prefix: str = "",
) -> nn.Module: ) -> nn.Module:
return LLMWrapper(Qwen2Model(vllm_config=vllm_config, prefix=prefix),
return LLMWrapper(Qwen2Model(config,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix),
name="model") name="model")
def init_vision_module( def init_vision_module(
...@@ -1151,4 +1133,4 @@ class MiniCPMV(MiniCPMVBaseModel, SupportsLoRA): ...@@ -1151,4 +1133,4 @@ class MiniCPMV(MiniCPMVBaseModel, SupportsLoRA):
if instance_class is None: if instance_class is None:
raise ValueError( raise ValueError(
"Currently, MiniCPMV only supports versions 2.0, 2.5, and 2.6") "Currently, MiniCPMV only supports versions 2.0, 2.5, and 2.6")
return instance_class(vllm_config, prefix=prefix) return instance_class(vllm_config=vllm_config, prefix=prefix)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment