Unverified Commit f89d18ff authored by youkaichao's avatar youkaichao Committed by GitHub
Browse files

[6/N] pass whole config to inner model (#10205)


Signed-off-by: default avataryoukaichao <youkaichao@gmail.com>
parent f0f2e563
......@@ -28,7 +28,7 @@ from transformers import MixtralConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig, VllmConfig
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
......@@ -48,7 +48,8 @@ from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class MixtralMoE(nn.Module):
......@@ -248,15 +249,14 @@ class MixtralDecoderLayer(nn.Module):
@support_torch_compile
class MixtralModel(nn.Module):
def __init__(
self,
config: MixtralConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.padding_idx = config.pad_token_id
lora_vocab = (lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0
......@@ -332,24 +332,16 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
}
embedding_padding_modules = ["lm_head"]
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.lora_config = lora_config
self.model = MixtralModel(config,
cache_config,
quant_config,
lora_config=lora_config,
prefix="model")
self.model = MixtralModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
......
......@@ -49,7 +49,8 @@ from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class MixtralMLP(nn.Module):
......@@ -293,14 +294,13 @@ class MixtralDecoderLayer(nn.Module):
class MixtralModel(nn.Module):
def __init__(
self,
config: MixtralConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
......@@ -350,18 +350,14 @@ class MixtralModel(nn.Module):
class MixtralForCausalLM(nn.Module, SupportsPP):
fall_back_to_pt_during_load = False
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.model = MixtralModel(config, cache_config, quant_config)
self.model = MixtralModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
......
......@@ -33,7 +33,7 @@ from transformers.models.mllama.processing_mllama import (
import vllm.distributed.parallel_state as ps
from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.config import CacheConfig, VllmConfig
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.inputs import (INPUT_REGISTRY, DummyData, EncoderDecoderInputs,
InputContext, TokenInputs, token_inputs)
......@@ -56,6 +56,7 @@ from vllm.utils import is_list_of
from .clip import CLIPMLP
from .interfaces import SupportsMultiModal
from .llama import LlamaDecoderLayer, LlamaMLP
from .utils import maybe_prefix
logger = init_logger(__name__)
MLLAMA_IMAGE_TOKEN_ID = 128256
......@@ -939,15 +940,13 @@ class MllamaTextModel(nn.Module):
config_class = config_mllama.MllamaTextConfig
base_model_prefix = "model"
def __init__(
self,
config: config_mllama.MllamaTextConfig,
cache_config: Optional[CacheConfig],
quant_config: Optional[QuantizationConfig],
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config.text_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(config.vocab_size + 8,
......@@ -1029,18 +1028,14 @@ class MllamaForCausalLM(nn.Module):
"MllamaCrossAttentionDecoderLayer", "MllamaSelfAttentionDecoderLayer"
]
def __init__(
self,
config: config_mllama.MllamaTextConfig,
cache_config: Optional[CacheConfig],
quant_config: Optional[QuantizationConfig],
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config.text_config
quant_config = vllm_config.quant_config
self.vocab_size = config.vocab_size
self.model = MllamaTextModel(config,
cache_config,
quant_config,
self.model = MllamaTextModel(vllm_config=vllm_config,
prefix=f"{prefix}.model")
self.lm_head = ParallelLMHead(
config.vocab_size,
......@@ -1108,14 +1103,9 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
"up_proj": ("gate_up_proj", 1),
}
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.vocab_size = config.text_config.vocab_size
self.hidden_size = config.text_config.hidden_size
......@@ -1127,12 +1117,11 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
self.vision_model = MllamaVisionModel(config.vision_config,
quant_config,
prefix="vision_model")
prefix=maybe_prefix(
prefix, "vision_model"))
self.language_model = MllamaForCausalLM(
config.text_config,
cache_config=cache_config,
quant_config=quant_config,
prefix="language_model",
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model"),
)
self.multi_modal_projector = ColumnParallelLinear(
config.vision_config.vision_output_dim,
......@@ -1140,7 +1129,7 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
bias=True,
quant_config=quant_config,
gather_output=True,
prefix="multi_modal_projector",
prefix=maybe_prefix(prefix, "multi_modal_projector"),
)
self.logits_processor = LogitsProcessor(config.output_hidden_states,
config.text_config.vocab_size)
......
......@@ -44,7 +44,8 @@ from vllm.transformers_utils.processor import get_processor
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (get_vit_attn_backend,
make_empty_intermediate_tensors_factory, make_layers)
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
# TODO: hard-coded for now. Consider making it configurable.
VIT_LAYERS = [-2, -9]
......@@ -716,14 +717,13 @@ class MolmoVisionBackbone(nn.Module):
@support_torch_compile
class MolmoModel(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.embedding_size = config.embedding_size or config.vocab_size
......@@ -1024,14 +1024,9 @@ def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs):
@INPUT_REGISTRY.register_input_processor(input_processor_for_molmo)
class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
......@@ -1040,7 +1035,8 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
vision_config = VisionBackboneConfig()
self.vision_backbone = MolmoVisionBackbone(config, vision_config,
quant_config)
self.model = MolmoModel(config, cache_config, quant_config)
self.model = MolmoModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
if self.config.weight_tying:
self.lm_head = self.model.transformer.wte
......
......@@ -26,7 +26,8 @@ from vllm.transformers_utils.configs.mpt import MPTConfig
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
def _get_alibi_slopes(
......@@ -207,14 +208,13 @@ class MPTBlock(nn.Module):
@support_torch_compile
class MPTModel(nn.Module):
def __init__(
self,
config: MPTConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
assert config.embedding_fraction == 1.0
assert config.norm_type == "low_precision_layernorm"
......@@ -267,20 +267,16 @@ class MPTModel(nn.Module):
class MPTForCausalLM(nn.Module, SupportsPP):
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
assert config.tie_word_embeddings
self.quant_config = quant_config
self.transformer = MPTModel(config, cache_config, quant_config)
self.transformer = MPTModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "transformer"))
self.lm_head = self.transformer.wte
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler()
......
......@@ -27,7 +27,7 @@ from torch import nn
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig, VllmConfig
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
......@@ -47,7 +47,8 @@ from vllm.transformers_utils.configs import NemotronConfig
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
# The architecture is pretty similar to Llama, with these changes:
# - There is no gate_proj, just up_proj
......@@ -293,15 +294,14 @@ class NemotronDecoderLayer(nn.Module):
@support_torch_compile
class NemotronModel(nn.Module):
def __init__(
self,
config: NemotronConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.padding_idx = config.pad_token_id
lora_vocab = (lora_config.lora_extra_vocab_size *
......@@ -401,14 +401,9 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"v_proj": ("qkv_proj", 2),
}
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
assert isinstance(config, NemotronConfig)
......@@ -416,11 +411,8 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.config = config
self.lora_config = lora_config
self.model = NemotronModel(config,
cache_config,
quant_config,
lora_config=lora_config,
prefix="model")
self.model = NemotronModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size
if lora_config:
......
......@@ -46,7 +46,8 @@ from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class OlmoAttention(nn.Module):
......@@ -224,12 +225,13 @@ class OlmoDecoderLayer(nn.Module):
@support_torch_compile
class OlmoModel(nn.Module):
def __init__(self,
config: OlmoConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
......@@ -291,17 +293,13 @@ class OlmoForCausalLM(nn.Module, SupportsPP):
Extremely barebones HF model wrapper.
"""
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.model = OlmoModel(config, cache_config, quant_config)
self.model = OlmoModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
if config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
......
......@@ -38,7 +38,8 @@ from vllm.utils import print_warning_once
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class OlmoeMoE(nn.Module):
......@@ -243,14 +244,13 @@ class OlmoeDecoderLayer(nn.Module):
@support_torch_compile
class OlmoeModel(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
......@@ -309,18 +309,14 @@ class OlmoeForCausalLM(nn.Module, SupportsPP):
fall_back_to_pt_during_load = False
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.model = OlmoeModel(config, cache_config, quant_config)
self.model = OlmoeModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
......
......@@ -293,14 +293,13 @@ class OPTDecoder(nn.Module):
@support_torch_compile
class OPTModel(nn.Module):
def __init__(
self,
config: OPTConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.decoder = OPTDecoder(config,
cache_config,
quant_config,
......@@ -342,21 +341,14 @@ class OPTForCausalLM(nn.Module, SupportsPP):
".q_proj.", ".k_proj.", ".v_proj.", ".out_proj.", ".fc1.", ".fc2."
]
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
super().__init__()
self.config = config
self.quant_config = quant_config
self.model = OPTModel(config,
cache_config,
quant_config,
self.model = OPTModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
if self.config.tie_word_embeddings:
self.lm_head = self.model.decoder.embed_tokens
......
......@@ -29,7 +29,8 @@ from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class OrionMLP(nn.Module):
......@@ -208,14 +209,13 @@ class OrionDecoderLayer(nn.Module):
@support_torch_compile
class OrionModel(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
......@@ -268,18 +268,14 @@ class OrionModel(nn.Module):
class OrionForCausalLM(nn.Module, SupportsPP):
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.model = OrionModel(config, cache_config, quant_config)
self.model = OrionModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
......
......@@ -20,7 +20,7 @@ from .interfaces import SupportsMultiModal, SupportsPP
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip, get_max_siglip_image_tokens)
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
merge_multimodal_embeddings)
maybe_prefix, merge_multimodal_embeddings)
logger = init_logger(__name__)
......@@ -131,11 +131,7 @@ class PaliGemmaMultiModalProjector(nn.Module):
class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
......@@ -145,7 +141,8 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
self.vision_tower = SiglipVisionModel(config.vision_config,
quant_config,
prefix="vision_tower")
prefix=maybe_prefix(
prefix, "vision_tower"))
self.multi_modal_projector = PaliGemmaMultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size,
projection_dim=config.vision_config.projection_dim)
......@@ -155,7 +152,7 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
self.language_model = init_vllm_registered_model(
config.text_config,
vllm_config=vllm_config,
prefix="language_model")
prefix=maybe_prefix(prefix, "language_model"))
logit_scale = getattr(config, "logit_scale", 1.0)
self.language_model.logits_processor.scale *= logit_scale
......
......@@ -45,7 +45,8 @@ from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class PersimmonMLP(nn.Module):
......@@ -212,12 +213,13 @@ class PersimmonDecoderLayer(nn.Module):
@support_torch_compile
class PersimmonModel(nn.Module):
def __init__(self,
config: PersimmonConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
......@@ -265,20 +267,13 @@ class PersimmonModel(nn.Module):
class PersimmonForCausalLM(nn.Module, SupportsPP):
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.vocab_size = config.vocab_size
self.model = PersimmonModel(config,
cache_config=cache_config,
quant_config=quant_config)
self.model = PersimmonModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
bias=False)
......
......@@ -60,7 +60,8 @@ from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class PhiAttention(nn.Module):
......@@ -196,12 +197,13 @@ class PhiLayer(nn.Module):
@support_torch_compile
class PhiModel(nn.Module):
def __init__(self,
config: PhiConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
......@@ -277,14 +279,9 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
embedding_modules = {}
embedding_padding_modules = []
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
......@@ -294,7 +291,8 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.quant_config = quant_config
self.model = PhiModel(config, cache_config, quant_config)
self.model = PhiModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
......
......@@ -24,7 +24,8 @@ from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
def load_column_parallel_weight(param: torch.nn.Parameter,
......@@ -299,14 +300,13 @@ class Phi3SmallDecoderLayer(nn.Module):
class Phi3SmallModel(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
config.hidden_size)
......@@ -363,18 +363,14 @@ class Phi3SmallModel(nn.Module):
class Phi3SmallForCausalLM(nn.Module, SupportsPP):
_tied_weights_keys = ["lm_head.weight"]
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.model = Phi3SmallModel(config, cache_config, quant_config)
self.model = Phi3SmallModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.vocab_size = config.vocab_size
self.mup_width_multiplier = config.mup_width_multiplier
self.lm_head = ParallelLMHead(
......
......@@ -45,7 +45,7 @@ from vllm.utils import is_list_of
from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, maybe_prefix,
merge_multimodal_embeddings)
logger = init_logger(__name__)
......@@ -525,11 +525,7 @@ def input_processor_for_phi3v(ctx: InputContext,
@INPUT_REGISTRY.register_input_processor(input_processor_for_phi3v)
class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
......@@ -544,12 +540,14 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
config.hidden_size,
org_num_embeddings=config.vocab_size,
quant_config=quant_config,
prefix="model.embed_tokens",
prefix=maybe_prefix(prefix, "model.embed_tokens"),
)
# TODO: Optionally initializes this for supporting input embeddings.
self.vision_embed_tokens = Phi3HDImageEmbedding(
config, quant_config, prefix="model.vision_embed_tokens")
config,
quant_config,
prefix=maybe_prefix(prefix, "model.vision_embed_tokens"))
# The prefix is empty intentionally because default prefix of
# LlamaForCausalLM is "model"
......
......@@ -28,7 +28,7 @@ from transformers.configuration_utils import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig, VllmConfig
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (QKVParallelLinear,
......@@ -48,7 +48,8 @@ from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class PhiMoEConfig(PretrainedConfig):
......@@ -432,15 +433,14 @@ class PhiMoEDecoderLayer(nn.Module):
@support_torch_compile
class PhiMoEModel(nn.Module):
def __init__(
self,
config: PhiMoEConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.padding_idx = config.pad_token_id
lora_vocab = ((lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0)
......@@ -529,23 +529,15 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
}
embedding_padding_modules = ["lm_head"]
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.lora_config = lora_config
self.model = PhiMoEModel(config,
cache_config,
quant_config,
lora_config=lora_config)
self.model = PhiMoEModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
......
......@@ -38,7 +38,7 @@ from vllm.transformers_utils.processor import cached_get_processor
from vllm.utils import is_list_of
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import init_vllm_registered_model
from .utils import init_vllm_registered_model, maybe_prefix
try:
from xformers import ops as xops
......@@ -152,11 +152,7 @@ def input_processor_for_pixtral(ctx: InputContext, inputs: DecoderOnlyInputs):
class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
multimodal_config = vllm_config.model_config.multimodal_config
......@@ -176,7 +172,7 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
self.language_model = init_vllm_registered_model(
config.text_config,
vllm_config=vllm_config,
prefix="language_model")
prefix=maybe_prefix(prefix, "language_model"))
self.vision_encoder = VisionTransformer(self.vision_args)
self.vision_language_adapter = VisionLanguageAdapter(
......
......@@ -50,7 +50,8 @@ from vllm.utils import is_list_of
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
from .utils import (flatten_bn, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
logger = init_logger(__name__)
......@@ -552,14 +553,13 @@ class QWenBlock(nn.Module):
@support_torch_compile
class QWenModel(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.vocab_size = config.vocab_size
......@@ -865,20 +865,17 @@ def dummy_data_for_qwen(
class QWenBaseModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
self.multimodal_config = multimodal_config
self.quant_config = quant_config
self.transformer = QWenModel(config, cache_config, quant_config)
self.transformer = QWenModel(vllm_config=vllm_config,
prefix=maybe_prefix(
prefix, "transformer"))
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
......
......@@ -240,14 +240,13 @@ class Qwen2DecoderLayer(nn.Module):
@support_torch_compile
class Qwen2Model(nn.Module):
def __init__(
self,
config: Qwen2Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
......@@ -403,11 +402,7 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"up_proj": ("gate_up_proj", 1),
}
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
......@@ -429,9 +424,7 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.lora_config = lora_config
self.quant_config = quant_config
self.model = Qwen2Model(config,
cache_config,
quant_config,
self.model = Qwen2Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
if config.tie_word_embeddings:
......
......@@ -264,14 +264,9 @@ def input_mapper_for_qwen2_audio(
class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
......@@ -283,8 +278,9 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
self.quant_config = quant_config
self.language_model = Qwen2Model(config.text_config, cache_config,
quant_config)
self.language_model = Qwen2Model(
vllm_config=vllm_config.with_hf_config(config.text_config),
prefix=prefix)
self.unpadded_vocab_size = config.text_config.vocab_size
if config.text_config.tie_word_embeddings:
self.lm_head = self.language_model.embed_tokens
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment