Unverified commit f89d18ff, authored by youkaichao, committed by GitHub
Browse files

[6/N] pass whole config to inner model (#10205)


Signed-off-by: youkaichao <youkaichao@gmail.com>
parent f0f2e563
...@@ -28,7 +28,7 @@ from transformers import MixtralConfig ...@@ -28,7 +28,7 @@ from transformers import MixtralConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
...@@ -48,7 +48,8 @@ from vllm.sequence import IntermediateTensors ...@@ -48,7 +48,8 @@ from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter, from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers) make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class MixtralMoE(nn.Module): class MixtralMoE(nn.Module):
...@@ -248,15 +249,14 @@ class MixtralDecoderLayer(nn.Module): ...@@ -248,15 +249,14 @@ class MixtralDecoderLayer(nn.Module):
@support_torch_compile @support_torch_compile
class MixtralModel(nn.Module): class MixtralModel(nn.Module):
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
config: MixtralConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
lora_vocab = (lora_config.lora_extra_vocab_size * lora_vocab = (lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0 (lora_config.max_loras or 1)) if lora_config else 0
...@@ -332,24 +332,16 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -332,24 +332,16 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
} }
embedding_padding_modules = ["lm_head"] embedding_padding_modules = ["lm_head"]
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config lora_config = vllm_config.lora_config
self.config = config self.config = config
self.lora_config = lora_config self.lora_config = lora_config
self.model = MixtralModel(config, self.model = MixtralModel(vllm_config=vllm_config,
cache_config, prefix=maybe_prefix(prefix, "model"))
quant_config,
lora_config=lora_config,
prefix="model")
self.unpadded_vocab_size = config.vocab_size self.unpadded_vocab_size = config.vocab_size
if lora_config: if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
......
...@@ -49,7 +49,8 @@ from vllm.sequence import IntermediateTensors ...@@ -49,7 +49,8 @@ from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter, from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers) make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class MixtralMLP(nn.Module): class MixtralMLP(nn.Module):
...@@ -293,14 +294,13 @@ class MixtralDecoderLayer(nn.Module): ...@@ -293,14 +294,13 @@ class MixtralDecoderLayer(nn.Module):
class MixtralModel(nn.Module): class MixtralModel(nn.Module):
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
config: MixtralConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
...@@ -350,18 +350,14 @@ class MixtralModel(nn.Module): ...@@ -350,18 +350,14 @@ class MixtralModel(nn.Module):
class MixtralForCausalLM(nn.Module, SupportsPP): class MixtralForCausalLM(nn.Module, SupportsPP):
fall_back_to_pt_during_load = False fall_back_to_pt_during_load = False
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.model = MixtralModel(config, cache_config, quant_config) self.model = MixtralModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.lm_head = ParallelLMHead(config.vocab_size, self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size, config.hidden_size,
quant_config=quant_config) quant_config=quant_config)
......
...@@ -33,7 +33,7 @@ from transformers.models.mllama.processing_mllama import ( ...@@ -33,7 +33,7 @@ from transformers.models.mllama.processing_mllama import (
import vllm.distributed.parallel_state as ps import vllm.distributed.parallel_state as ps
from vllm.attention import Attention, AttentionMetadata, AttentionType from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.attention.ops.paged_attn import PagedAttention from vllm.attention.ops.paged_attn import PagedAttention
from vllm.config import CacheConfig, VllmConfig from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.inputs import (INPUT_REGISTRY, DummyData, EncoderDecoderInputs, from vllm.inputs import (INPUT_REGISTRY, DummyData, EncoderDecoderInputs,
InputContext, TokenInputs, token_inputs) InputContext, TokenInputs, token_inputs)
...@@ -56,6 +56,7 @@ from vllm.utils import is_list_of ...@@ -56,6 +56,7 @@ from vllm.utils import is_list_of
from .clip import CLIPMLP from .clip import CLIPMLP
from .interfaces import SupportsMultiModal from .interfaces import SupportsMultiModal
from .llama import LlamaDecoderLayer, LlamaMLP from .llama import LlamaDecoderLayer, LlamaMLP
from .utils import maybe_prefix
logger = init_logger(__name__) logger = init_logger(__name__)
MLLAMA_IMAGE_TOKEN_ID = 128256 MLLAMA_IMAGE_TOKEN_ID = 128256
...@@ -939,15 +940,13 @@ class MllamaTextModel(nn.Module): ...@@ -939,15 +940,13 @@ class MllamaTextModel(nn.Module):
config_class = config_mllama.MllamaTextConfig config_class = config_mllama.MllamaTextConfig
base_model_prefix = "model" base_model_prefix = "model"
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
config: config_mllama.MllamaTextConfig,
cache_config: Optional[CacheConfig],
quant_config: Optional[QuantizationConfig],
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config.text_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(config.vocab_size + 8, self.embed_tokens = VocabParallelEmbedding(config.vocab_size + 8,
...@@ -1029,18 +1028,14 @@ class MllamaForCausalLM(nn.Module): ...@@ -1029,18 +1028,14 @@ class MllamaForCausalLM(nn.Module):
"MllamaCrossAttentionDecoderLayer", "MllamaSelfAttentionDecoderLayer" "MllamaCrossAttentionDecoderLayer", "MllamaSelfAttentionDecoderLayer"
] ]
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
config: config_mllama.MllamaTextConfig,
cache_config: Optional[CacheConfig],
quant_config: Optional[QuantizationConfig],
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config.text_config
quant_config = vllm_config.quant_config
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
self.model = MllamaTextModel(config, self.model = MllamaTextModel(vllm_config=vllm_config,
cache_config,
quant_config,
prefix=f"{prefix}.model") prefix=f"{prefix}.model")
self.lm_head = ParallelLMHead( self.lm_head = ParallelLMHead(
config.vocab_size, config.vocab_size,
...@@ -1108,14 +1103,9 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -1108,14 +1103,9 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
"up_proj": ("gate_up_proj", 1), "up_proj": ("gate_up_proj", 1),
} }
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
self.vocab_size = config.text_config.vocab_size self.vocab_size = config.text_config.vocab_size
self.hidden_size = config.text_config.hidden_size self.hidden_size = config.text_config.hidden_size
...@@ -1127,12 +1117,11 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -1127,12 +1117,11 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
self.vision_model = MllamaVisionModel(config.vision_config, self.vision_model = MllamaVisionModel(config.vision_config,
quant_config, quant_config,
prefix="vision_model") prefix=maybe_prefix(
prefix, "vision_model"))
self.language_model = MllamaForCausalLM( self.language_model = MllamaForCausalLM(
config.text_config, vllm_config=vllm_config,
cache_config=cache_config, prefix=maybe_prefix(prefix, "language_model"),
quant_config=quant_config,
prefix="language_model",
) )
self.multi_modal_projector = ColumnParallelLinear( self.multi_modal_projector = ColumnParallelLinear(
config.vision_config.vision_output_dim, config.vision_config.vision_output_dim,
...@@ -1140,7 +1129,7 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -1140,7 +1129,7 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
bias=True, bias=True,
quant_config=quant_config, quant_config=quant_config,
gather_output=True, gather_output=True,
prefix="multi_modal_projector", prefix=maybe_prefix(prefix, "multi_modal_projector"),
) )
self.logits_processor = LogitsProcessor(config.output_hidden_states, self.logits_processor = LogitsProcessor(config.output_hidden_states,
config.text_config.vocab_size) config.text_config.vocab_size)
......
...@@ -44,7 +44,8 @@ from vllm.transformers_utils.processor import get_processor ...@@ -44,7 +44,8 @@ from vllm.transformers_utils.processor import get_processor
from .interfaces import SupportsMultiModal, SupportsPP from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (get_vit_attn_backend, from .utils import (get_vit_attn_backend,
make_empty_intermediate_tensors_factory, make_layers) make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
# TODO: hard-coded for now. Consider making it configurable. # TODO: hard-coded for now. Consider making it configurable.
VIT_LAYERS = [-2, -9] VIT_LAYERS = [-2, -9]
...@@ -716,14 +717,13 @@ class MolmoVisionBackbone(nn.Module): ...@@ -716,14 +717,13 @@ class MolmoVisionBackbone(nn.Module):
@support_torch_compile @support_torch_compile
class MolmoModel(nn.Module): class MolmoModel(nn.Module):
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config self.config = config
self.embedding_size = config.embedding_size or config.vocab_size self.embedding_size = config.embedding_size or config.vocab_size
...@@ -1024,14 +1024,9 @@ def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs): ...@@ -1024,14 +1024,9 @@ def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs):
@INPUT_REGISTRY.register_input_processor(input_processor_for_molmo) @INPUT_REGISTRY.register_input_processor(input_processor_for_molmo)
class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config multimodal_config = vllm_config.model_config.multimodal_config
self.config = config self.config = config
...@@ -1040,7 +1035,8 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -1040,7 +1035,8 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
vision_config = VisionBackboneConfig() vision_config = VisionBackboneConfig()
self.vision_backbone = MolmoVisionBackbone(config, vision_config, self.vision_backbone = MolmoVisionBackbone(config, vision_config,
quant_config) quant_config)
self.model = MolmoModel(config, cache_config, quant_config) self.model = MolmoModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
if self.config.weight_tying: if self.config.weight_tying:
self.lm_head = self.model.transformer.wte self.lm_head = self.model.transformer.wte
......
...@@ -26,7 +26,8 @@ from vllm.transformers_utils.configs.mpt import MPTConfig ...@@ -26,7 +26,8 @@ from vllm.transformers_utils.configs.mpt import MPTConfig
from .interfaces import SupportsPP from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter, from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers) make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
def _get_alibi_slopes( def _get_alibi_slopes(
...@@ -207,14 +208,13 @@ class MPTBlock(nn.Module): ...@@ -207,14 +208,13 @@ class MPTBlock(nn.Module):
@support_torch_compile @support_torch_compile
class MPTModel(nn.Module): class MPTModel(nn.Module):
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
config: MPTConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
assert config.embedding_fraction == 1.0 assert config.embedding_fraction == 1.0
assert config.norm_type == "low_precision_layernorm" assert config.norm_type == "low_precision_layernorm"
...@@ -267,20 +267,16 @@ class MPTModel(nn.Module): ...@@ -267,20 +267,16 @@ class MPTModel(nn.Module):
class MPTForCausalLM(nn.Module, SupportsPP): class MPTForCausalLM(nn.Module, SupportsPP):
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
self.config = config self.config = config
assert config.tie_word_embeddings assert config.tie_word_embeddings
self.quant_config = quant_config self.quant_config = quant_config
self.transformer = MPTModel(config, cache_config, quant_config) self.transformer = MPTModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "transformer"))
self.lm_head = self.transformer.wte self.lm_head = self.transformer.wte
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler() self.sampler = get_sampler()
......
...@@ -27,7 +27,7 @@ from torch import nn ...@@ -27,7 +27,7 @@ from torch import nn
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...@@ -47,7 +47,8 @@ from vllm.transformers_utils.configs import NemotronConfig ...@@ -47,7 +47,8 @@ from vllm.transformers_utils.configs import NemotronConfig
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter, from .utils import (PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers) make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
# The architecture is pretty similar to Llama, with these changes: # The architecture is pretty similar to Llama, with these changes:
# - There is no gate_proj, just up_proj # - There is no gate_proj, just up_proj
...@@ -293,15 +294,14 @@ class NemotronDecoderLayer(nn.Module): ...@@ -293,15 +294,14 @@ class NemotronDecoderLayer(nn.Module):
@support_torch_compile @support_torch_compile
class NemotronModel(nn.Module): class NemotronModel(nn.Module):
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
config: NemotronConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config self.config = config
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
lora_vocab = (lora_config.lora_extra_vocab_size * lora_vocab = (lora_config.lora_extra_vocab_size *
...@@ -401,14 +401,9 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -401,14 +401,9 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"v_proj": ("qkv_proj", 2), "v_proj": ("qkv_proj", 2),
} }
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config lora_config = vllm_config.lora_config
assert isinstance(config, NemotronConfig) assert isinstance(config, NemotronConfig)
...@@ -416,11 +411,8 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -416,11 +411,8 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.config = config self.config = config
self.lora_config = lora_config self.lora_config = lora_config
self.model = NemotronModel(config, self.model = NemotronModel(vllm_config=vllm_config,
cache_config, prefix=maybe_prefix(prefix, "model"))
quant_config,
lora_config=lora_config,
prefix="model")
if get_pp_group().is_last_rank: if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size self.unpadded_vocab_size = config.vocab_size
if lora_config: if lora_config:
......
...@@ -46,7 +46,8 @@ from vllm.sequence import IntermediateTensors ...@@ -46,7 +46,8 @@ from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter, from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers) make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class OlmoAttention(nn.Module): class OlmoAttention(nn.Module):
...@@ -224,12 +225,13 @@ class OlmoDecoderLayer(nn.Module): ...@@ -224,12 +225,13 @@ class OlmoDecoderLayer(nn.Module):
@support_torch_compile @support_torch_compile
class OlmoModel(nn.Module): class OlmoModel(nn.Module):
def __init__(self, def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config: OlmoConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config self.config = config
self.embed_tokens = VocabParallelEmbedding(config.vocab_size, self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
...@@ -291,17 +293,13 @@ class OlmoForCausalLM(nn.Module, SupportsPP): ...@@ -291,17 +293,13 @@ class OlmoForCausalLM(nn.Module, SupportsPP):
Extremely barebones HF model wrapper. Extremely barebones HF model wrapper.
""" """
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
self.config = config self.config = config
self.model = OlmoModel(config, cache_config, quant_config) self.model = OlmoModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
if config.tie_word_embeddings: if config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens self.lm_head = self.model.embed_tokens
else: else:
......
...@@ -38,7 +38,8 @@ from vllm.utils import print_warning_once ...@@ -38,7 +38,8 @@ from vllm.utils import print_warning_once
from .interfaces import SupportsPP from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter, from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers) make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class OlmoeMoE(nn.Module): class OlmoeMoE(nn.Module):
...@@ -243,14 +244,13 @@ class OlmoeDecoderLayer(nn.Module): ...@@ -243,14 +244,13 @@ class OlmoeDecoderLayer(nn.Module):
@support_torch_compile @support_torch_compile
class OlmoeModel(nn.Module): class OlmoeModel(nn.Module):
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
...@@ -309,18 +309,14 @@ class OlmoeForCausalLM(nn.Module, SupportsPP): ...@@ -309,18 +309,14 @@ class OlmoeForCausalLM(nn.Module, SupportsPP):
fall_back_to_pt_during_load = False fall_back_to_pt_during_load = False
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.model = OlmoeModel(config, cache_config, quant_config) self.model = OlmoeModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.lm_head = ParallelLMHead(config.vocab_size, self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size, config.hidden_size,
quant_config=quant_config) quant_config=quant_config)
......
...@@ -293,14 +293,13 @@ class OPTDecoder(nn.Module): ...@@ -293,14 +293,13 @@ class OPTDecoder(nn.Module):
@support_torch_compile @support_torch_compile
class OPTModel(nn.Module): class OPTModel(nn.Module):
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
config: OPTConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.decoder = OPTDecoder(config, self.decoder = OPTDecoder(config,
cache_config, cache_config,
quant_config, quant_config,
...@@ -342,21 +341,14 @@ class OPTForCausalLM(nn.Module, SupportsPP): ...@@ -342,21 +341,14 @@ class OPTForCausalLM(nn.Module, SupportsPP):
".q_proj.", ".k_proj.", ".v_proj.", ".out_proj.", ".fc1.", ".fc2." ".q_proj.", ".k_proj.", ".v_proj.", ".out_proj.", ".fc1.", ".fc2."
] ]
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
super().__init__() super().__init__()
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.model = OPTModel(config, self.model = OPTModel(vllm_config=vllm_config,
cache_config,
quant_config,
prefix=maybe_prefix(prefix, "model")) prefix=maybe_prefix(prefix, "model"))
if self.config.tie_word_embeddings: if self.config.tie_word_embeddings:
self.lm_head = self.model.decoder.embed_tokens self.lm_head = self.model.decoder.embed_tokens
......
...@@ -29,7 +29,8 @@ from vllm.sequence import IntermediateTensors ...@@ -29,7 +29,8 @@ from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter, from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers) make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class OrionMLP(nn.Module): class OrionMLP(nn.Module):
...@@ -208,14 +209,13 @@ class OrionDecoderLayer(nn.Module): ...@@ -208,14 +209,13 @@ class OrionDecoderLayer(nn.Module):
@support_torch_compile @support_torch_compile
class OrionModel(nn.Module): class OrionModel(nn.Module):
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config self.config = config
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
...@@ -268,18 +268,14 @@ class OrionModel(nn.Module): ...@@ -268,18 +268,14 @@ class OrionModel(nn.Module):
class OrionForCausalLM(nn.Module, SupportsPP): class OrionForCausalLM(nn.Module, SupportsPP):
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.model = OrionModel(config, cache_config, quant_config) self.model = OrionModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.lm_head = ParallelLMHead(config.vocab_size, self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size, config.hidden_size,
quant_config=quant_config) quant_config=quant_config)
......
...@@ -20,7 +20,7 @@ from .interfaces import SupportsMultiModal, SupportsPP ...@@ -20,7 +20,7 @@ from .interfaces import SupportsMultiModal, SupportsPP
from .siglip import (SiglipVisionModel, dummy_image_for_siglip, from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip, get_max_siglip_image_tokens) dummy_seq_data_for_siglip, get_max_siglip_image_tokens)
from .utils import (AutoWeightsLoader, init_vllm_registered_model, from .utils import (AutoWeightsLoader, init_vllm_registered_model,
merge_multimodal_embeddings) maybe_prefix, merge_multimodal_embeddings)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -131,11 +131,7 @@ class PaliGemmaMultiModalProjector(nn.Module): ...@@ -131,11 +131,7 @@ class PaliGemmaMultiModalProjector(nn.Module):
class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP): SupportsPP):
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
...@@ -145,7 +141,8 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -145,7 +141,8 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
self.vision_tower = SiglipVisionModel(config.vision_config, self.vision_tower = SiglipVisionModel(config.vision_config,
quant_config, quant_config,
prefix="vision_tower") prefix=maybe_prefix(
prefix, "vision_tower"))
self.multi_modal_projector = PaliGemmaMultiModalProjector( self.multi_modal_projector = PaliGemmaMultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size, vision_hidden_size=config.vision_config.hidden_size,
projection_dim=config.vision_config.projection_dim) projection_dim=config.vision_config.projection_dim)
...@@ -155,7 +152,7 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -155,7 +152,7 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
self.language_model = init_vllm_registered_model( self.language_model = init_vllm_registered_model(
config.text_config, config.text_config,
vllm_config=vllm_config, vllm_config=vllm_config,
prefix="language_model") prefix=maybe_prefix(prefix, "language_model"))
logit_scale = getattr(config, "logit_scale", 1.0) logit_scale = getattr(config, "logit_scale", 1.0)
self.language_model.logits_processor.scale *= logit_scale self.language_model.logits_processor.scale *= logit_scale
......
...@@ -45,7 +45,8 @@ from vllm.sequence import IntermediateTensors ...@@ -45,7 +45,8 @@ from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter, from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers) make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class PersimmonMLP(nn.Module): class PersimmonMLP(nn.Module):
...@@ -212,12 +213,13 @@ class PersimmonDecoderLayer(nn.Module): ...@@ -212,12 +213,13 @@ class PersimmonDecoderLayer(nn.Module):
@support_torch_compile @support_torch_compile
class PersimmonModel(nn.Module): class PersimmonModel(nn.Module):
def __init__(self, def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config: PersimmonConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(config.vocab_size, self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
...@@ -265,20 +267,13 @@ class PersimmonModel(nn.Module): ...@@ -265,20 +267,13 @@ class PersimmonModel(nn.Module):
class PersimmonForCausalLM(nn.Module, SupportsPP): class PersimmonForCausalLM(nn.Module, SupportsPP):
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config self.config = config
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
self.model = PersimmonModel(config, self.model = PersimmonModel(vllm_config=vllm_config,
cache_config=cache_config, prefix=maybe_prefix(prefix, "model"))
quant_config=quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size, config.hidden_size,
bias=False) bias=False)
......
...@@ -60,7 +60,8 @@ from vllm.sequence import IntermediateTensors ...@@ -60,7 +60,8 @@ from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter, from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers) make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class PhiAttention(nn.Module): class PhiAttention(nn.Module):
...@@ -196,12 +197,13 @@ class PhiLayer(nn.Module): ...@@ -196,12 +197,13 @@ class PhiLayer(nn.Module):
@support_torch_compile @support_torch_compile
class PhiModel(nn.Module): class PhiModel(nn.Module):
def __init__(self, def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config: PhiConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.embed_tokens = VocabParallelEmbedding(config.vocab_size, self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
...@@ -277,14 +279,9 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -277,14 +279,9 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
embedding_modules = {} embedding_modules = {}
embedding_padding_modules = [] embedding_padding_modules = []
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config lora_config = vllm_config.lora_config
self.config = config self.config = config
...@@ -294,7 +291,8 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -294,7 +291,8 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.quant_config = quant_config self.quant_config = quant_config
self.model = PhiModel(config, cache_config, quant_config) self.model = PhiModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.lm_head = ParallelLMHead(config.vocab_size, self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size, config.hidden_size,
......
...@@ -24,7 +24,8 @@ from vllm.sequence import IntermediateTensors ...@@ -24,7 +24,8 @@ from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter, from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers) make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
def load_column_parallel_weight(param: torch.nn.Parameter, def load_column_parallel_weight(param: torch.nn.Parameter,
...@@ -299,14 +300,13 @@ class Phi3SmallDecoderLayer(nn.Module): ...@@ -299,14 +300,13 @@ class Phi3SmallDecoderLayer(nn.Module):
class Phi3SmallModel(nn.Module): class Phi3SmallModel(nn.Module):
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config self.config = config
self.embed_tokens = VocabParallelEmbedding(config.vocab_size, self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
config.hidden_size) config.hidden_size)
...@@ -363,18 +363,14 @@ class Phi3SmallModel(nn.Module): ...@@ -363,18 +363,14 @@ class Phi3SmallModel(nn.Module):
class Phi3SmallForCausalLM(nn.Module, SupportsPP): class Phi3SmallForCausalLM(nn.Module, SupportsPP):
_tied_weights_keys = ["lm_head.weight"] _tied_weights_keys = ["lm_head.weight"]
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.model = Phi3SmallModel(config, cache_config, quant_config) self.model = Phi3SmallModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
self.mup_width_multiplier = config.mup_width_multiplier self.mup_width_multiplier = config.mup_width_multiplier
self.lm_head = ParallelLMHead( self.lm_head = ParallelLMHead(
......
...@@ -45,7 +45,7 @@ from vllm.utils import is_list_of ...@@ -45,7 +45,7 @@ from vllm.utils import is_list_of
from .clip import dummy_image_for_clip, dummy_seq_data_for_clip from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
from .interfaces import SupportsMultiModal, SupportsPP from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, maybe_prefix,
merge_multimodal_embeddings) merge_multimodal_embeddings)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -525,11 +525,7 @@ def input_processor_for_phi3v(ctx: InputContext, ...@@ -525,11 +525,7 @@ def input_processor_for_phi3v(ctx: InputContext,
@INPUT_REGISTRY.register_input_processor(input_processor_for_phi3v) @INPUT_REGISTRY.register_input_processor(input_processor_for_phi3v)
class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
...@@ -544,12 +540,14 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -544,12 +540,14 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
config.hidden_size, config.hidden_size,
org_num_embeddings=config.vocab_size, org_num_embeddings=config.vocab_size,
quant_config=quant_config, quant_config=quant_config,
prefix="model.embed_tokens", prefix=maybe_prefix(prefix, "model.embed_tokens"),
) )
# TODO: Optionally initializes this for supporting input embeddings. # TODO: Optionally initializes this for supporting input embeddings.
self.vision_embed_tokens = Phi3HDImageEmbedding( self.vision_embed_tokens = Phi3HDImageEmbedding(
config, quant_config, prefix="model.vision_embed_tokens") config,
quant_config,
prefix=maybe_prefix(prefix, "model.vision_embed_tokens"))
# The prefix is empty intentionally because default prefix of # The prefix is empty intentionally because default prefix of
# LlamaForCausalLM is "model" # LlamaForCausalLM is "model"
......
...@@ -28,7 +28,7 @@ from transformers.configuration_utils import PretrainedConfig ...@@ -28,7 +28,7 @@ from transformers.configuration_utils import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (QKVParallelLinear, from vllm.model_executor.layers.linear import (QKVParallelLinear,
...@@ -48,7 +48,8 @@ from vllm.sequence import IntermediateTensors ...@@ -48,7 +48,8 @@ from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter, from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers) make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class PhiMoEConfig(PretrainedConfig): class PhiMoEConfig(PretrainedConfig):
...@@ -432,15 +433,14 @@ class PhiMoEDecoderLayer(nn.Module): ...@@ -432,15 +433,14 @@ class PhiMoEDecoderLayer(nn.Module):
@support_torch_compile @support_torch_compile
class PhiMoEModel(nn.Module): class PhiMoEModel(nn.Module):
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
config: PhiMoEConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
lora_vocab = ((lora_config.lora_extra_vocab_size * lora_vocab = ((lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0) (lora_config.max_loras or 1)) if lora_config else 0)
...@@ -529,23 +529,15 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -529,23 +529,15 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
} }
embedding_padding_modules = ["lm_head"] embedding_padding_modules = ["lm_head"]
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config lora_config = vllm_config.lora_config
self.config = config self.config = config
self.lora_config = lora_config self.lora_config = lora_config
self.model = PhiMoEModel(config, self.model = PhiMoEModel(vllm_config=vllm_config,
cache_config, prefix=maybe_prefix(prefix, "model"))
quant_config,
lora_config=lora_config)
self.unpadded_vocab_size = config.vocab_size self.unpadded_vocab_size = config.vocab_size
if lora_config: if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
......
...@@ -38,7 +38,7 @@ from vllm.transformers_utils.processor import cached_get_processor ...@@ -38,7 +38,7 @@ from vllm.transformers_utils.processor import cached_get_processor
from vllm.utils import is_list_of from vllm.utils import is_list_of
from .interfaces import SupportsMultiModal, SupportsPP from .interfaces import SupportsMultiModal, SupportsPP
from .utils import init_vllm_registered_model from .utils import init_vllm_registered_model, maybe_prefix
try: try:
from xformers import ops as xops from xformers import ops as xops
...@@ -152,11 +152,7 @@ def input_processor_for_pixtral(ctx: InputContext, inputs: DecoderOnlyInputs): ...@@ -152,11 +152,7 @@ def input_processor_for_pixtral(ctx: InputContext, inputs: DecoderOnlyInputs):
class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP): SupportsPP):
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
multimodal_config = vllm_config.model_config.multimodal_config multimodal_config = vllm_config.model_config.multimodal_config
...@@ -176,7 +172,7 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -176,7 +172,7 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
self.language_model = init_vllm_registered_model( self.language_model = init_vllm_registered_model(
config.text_config, config.text_config,
vllm_config=vllm_config, vllm_config=vllm_config,
prefix="language_model") prefix=maybe_prefix(prefix, "language_model"))
self.vision_encoder = VisionTransformer(self.vision_args) self.vision_encoder = VisionTransformer(self.vision_args)
self.vision_language_adapter = VisionLanguageAdapter( self.vision_language_adapter = VisionLanguageAdapter(
......
...@@ -50,7 +50,8 @@ from vllm.utils import is_list_of ...@@ -50,7 +50,8 @@ from vllm.utils import is_list_of
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
from .utils import (flatten_bn, is_pp_missing_parameter, from .utils import (flatten_bn, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers) make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -552,14 +553,13 @@ class QWenBlock(nn.Module): ...@@ -552,14 +553,13 @@ class QWenBlock(nn.Module):
@support_torch_compile @support_torch_compile
class QWenModel(nn.Module): class QWenModel(nn.Module):
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config self.config = config
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
...@@ -865,20 +865,17 @@ def dummy_data_for_qwen( ...@@ -865,20 +865,17 @@ def dummy_data_for_qwen(
class QWenBaseModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): class QWenBaseModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config multimodal_config = vllm_config.model_config.multimodal_config
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
self.quant_config = quant_config self.quant_config = quant_config
self.transformer = QWenModel(config, cache_config, quant_config) self.transformer = QWenModel(vllm_config=vllm_config,
prefix=maybe_prefix(
prefix, "transformer"))
self.lm_head = ParallelLMHead(config.vocab_size, self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size, config.hidden_size,
quant_config=quant_config) quant_config=quant_config)
......
...@@ -240,14 +240,13 @@ class Qwen2DecoderLayer(nn.Module): ...@@ -240,14 +240,13 @@ class Qwen2DecoderLayer(nn.Module):
@support_torch_compile @support_torch_compile
class Qwen2Model(nn.Module): class Qwen2Model(nn.Module):
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
config: Qwen2Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config self.config = config
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
...@@ -403,11 +402,7 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -403,11 +402,7 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"up_proj": ("gate_up_proj", 1), "up_proj": ("gate_up_proj", 1),
} }
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config cache_config = vllm_config.cache_config
...@@ -429,9 +424,7 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -429,9 +424,7 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.lora_config = lora_config self.lora_config = lora_config
self.quant_config = quant_config self.quant_config = quant_config
self.model = Qwen2Model(config, self.model = Qwen2Model(vllm_config=vllm_config,
cache_config,
quant_config,
prefix=maybe_prefix(prefix, "model")) prefix=maybe_prefix(prefix, "model"))
if config.tie_word_embeddings: if config.tie_word_embeddings:
......
...@@ -264,14 +264,9 @@ def input_mapper_for_qwen2_audio( ...@@ -264,14 +264,9 @@ def input_mapper_for_qwen2_audio(
class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP): SupportsPP):
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config multimodal_config = vllm_config.model_config.multimodal_config
self.config = config self.config = config
...@@ -283,8 +278,9 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -283,8 +278,9 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
self.quant_config = quant_config self.quant_config = quant_config
self.language_model = Qwen2Model(config.text_config, cache_config, self.language_model = Qwen2Model(
quant_config) vllm_config=vllm_config.with_hf_config(config.text_config),
prefix=prefix)
self.unpadded_vocab_size = config.text_config.vocab_size self.unpadded_vocab_size = config.text_config.vocab_size
if config.text_config.tie_word_embeddings: if config.text_config.tie_word_embeddings:
self.lm_head = self.language_model.embed_tokens self.lm_head = self.language_model.embed_tokens
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment