Unverified Commit 1a95f10e authored by youkaichao's avatar youkaichao Committed by GitHub
Browse files

[5/N] pass the whole config to model (#9983)


Signed-off-by: default avataryoukaichao <youkaichao@gmail.com>
parent 49d2a41a
...@@ -9,14 +9,14 @@ import torch.nn as nn ...@@ -9,14 +9,14 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from mistral_common.protocol.instruct.messages import ImageChunk from mistral_common.protocol.instruct.messages import ImageChunk
from PIL import Image from PIL import Image
from transformers import PixtralVisionConfig, PretrainedConfig from transformers import PixtralVisionConfig
from transformers.models.pixtral.image_processing_pixtral import ( from transformers.models.pixtral.image_processing_pixtral import (
_num_image_tokens) _num_image_tokens)
from transformers.models.pixtral.modeling_pixtral import ( from transformers.models.pixtral.modeling_pixtral import (
PixtralRotaryEmbedding, apply_rotary_pos_emb, position_ids_in_meshgrid) PixtralRotaryEmbedding, apply_rotary_pos_emb, position_ids_in_meshgrid)
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, MultiModalConfig from vllm.config import ModelConfig, VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs) InputContext, token_inputs)
from vllm.model_executor.layers.activation import get_act_and_mul_fn from vllm.model_executor.layers.activation import get_act_and_mul_fn
...@@ -152,13 +152,14 @@ def input_processor_for_pixtral(ctx: InputContext, inputs: DecoderOnlyInputs): ...@@ -152,13 +152,14 @@ def input_processor_for_pixtral(ctx: InputContext, inputs: DecoderOnlyInputs):
class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP): SupportsPP):
def __init__(self, def __init__(
config: PretrainedConfig, self,
multimodal_config: MultiModalConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None) -> None: ) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
...@@ -174,8 +175,7 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -174,8 +175,7 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
# init MistralForCausalLM # init MistralForCausalLM
self.language_model = init_vllm_registered_model( self.language_model = init_vllm_registered_model(
config.text_config, config.text_config,
cache_config, vllm_config=vllm_config,
quant_config,
prefix="language_model") prefix="language_model")
self.vision_encoder = VisionTransformer(self.vision_args) self.vision_encoder = VisionTransformer(self.vision_args)
......
...@@ -20,7 +20,7 @@ from transformers import PretrainedConfig ...@@ -20,7 +20,7 @@ from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs) InputContext, token_inputs)
...@@ -867,13 +867,14 @@ class QWenBaseModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): ...@@ -867,13 +867,14 @@ class QWenBaseModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
def __init__( def __init__(
self, self,
config: PretrainedConfig, vllm_config: VllmConfig,
multimodal_config: MultiModalConfig, prefix: str = "",
cache_config: Optional[CacheConfig] = None, ) -> None:
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
self.quant_config = quant_config self.quant_config = quant_config
...@@ -1064,17 +1065,13 @@ class QWenLMHeadModel(QWenBaseModel, SupportsLoRA): ...@@ -1064,17 +1065,13 @@ class QWenLMHeadModel(QWenBaseModel, SupportsLoRA):
def __new__( def __new__(
cls, cls,
config: PretrainedConfig, vllm_config: VllmConfig,
multimodal_config: MultiModalConfig, prefix: str = "",
cache_config: Optional[CacheConfig] = None, ) -> None:
quant_config: Optional[QuantizationConfig] = None, config = vllm_config.model_config.hf_config
lora_config: Optional[LoRAConfig] = None,
):
# Initialize VL # Initialize VL
if hasattr(config, "visual"): if hasattr(config, "visual"):
return QWenVL(config, multimodal_config, cache_config, return QWenVL(vllm_config)
quant_config, lora_config)
# Initialize LLM # Initialize LLM
else: else:
return QWenLLM(config, multimodal_config, cache_config, return QWenLLM(vllm_config)
quant_config, lora_config)
...@@ -29,7 +29,7 @@ from transformers import Qwen2Config ...@@ -29,7 +29,7 @@ from transformers import Qwen2Config
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
...@@ -405,12 +405,14 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -405,12 +405,14 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def __init__( def __init__(
self, self,
config: Qwen2Config, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
# TODO (@robertgshaw2): see if this can be moved out # TODO (@robertgshaw2): see if this can be moved out
if (cache_config.sliding_window is not None if (cache_config.sliding_window is not None
and hasattr(config, "max_window_layers")): and hasattr(config, "max_window_layers")):
...@@ -423,8 +425,6 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -423,8 +425,6 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
config.num_hidden_layers, config.num_hidden_layers,
)) ))
super().__init__()
self.config = config self.config = config
self.lora_config = lora_config self.lora_config = lora_config
......
...@@ -26,16 +26,14 @@ import librosa ...@@ -26,16 +26,14 @@ import librosa
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from transformers import Qwen2AudioConfig, Qwen2AudioEncoder from transformers import Qwen2AudioEncoder
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs) InputContext, token_inputs)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
...@@ -266,13 +264,16 @@ def input_mapper_for_qwen2_audio( ...@@ -266,13 +264,16 @@ def input_mapper_for_qwen2_audio(
class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP): SupportsPP):
def __init__(self, def __init__(
config: Qwen2AudioConfig, self,
multimodal_config: MultiModalConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None) -> None: ) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
......
...@@ -8,14 +8,11 @@ from typing import Iterable, List, Optional, Tuple ...@@ -8,14 +8,11 @@ from typing import Iterable, List, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
from transformers import Qwen2Config
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig, PoolerConfig from vllm.config import VllmConfig
from vllm.model_executor.layers.linear import RowParallelLinear from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.models.qwen2 import Qwen2Model from vllm.model_executor.models.qwen2 import Qwen2Model
from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput from vllm.sequence import IntermediateTensors, PoolerOutput
...@@ -48,12 +45,15 @@ class Qwen2ForSequenceClassification(nn.Module): ...@@ -48,12 +45,15 @@ class Qwen2ForSequenceClassification(nn.Module):
def __init__( def __init__(
self, self,
config: Qwen2Config, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
pooler_config: Optional[PoolerConfig] = None,
) -> None: ) -> None:
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
pooler_config = vllm_config.model_config.pooler_config
# TODO (@robertgshaw2): see if this can be moved out # TODO (@robertgshaw2): see if this can be moved out
if (cache_config.sliding_window is not None if (cache_config.sliding_window is not None
and hasattr(config, "max_window_layers")): and hasattr(config, "max_window_layers")):
...@@ -66,8 +66,6 @@ class Qwen2ForSequenceClassification(nn.Module): ...@@ -66,8 +66,6 @@ class Qwen2ForSequenceClassification(nn.Module):
config.num_hidden_layers, config.num_hidden_layers,
)) ))
super().__init__()
self.config = config self.config = config
self.lora_config = lora_config self.lora_config = lora_config
......
...@@ -30,7 +30,7 @@ from transformers import PretrainedConfig ...@@ -30,7 +30,7 @@ from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, from vllm.distributed import (get_pp_group,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
...@@ -379,11 +379,13 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP): ...@@ -379,11 +379,13 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP):
def __init__( def __init__(
self, self,
config: PretrainedConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.model = Qwen2MoeModel(config, cache_config, quant_config) self.model = Qwen2MoeModel(config, cache_config, quant_config)
......
...@@ -7,14 +7,12 @@ from typing import Iterable, List, Optional, Tuple, Union ...@@ -7,14 +7,12 @@ from typing import Iterable, List, Optional, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from transformers import Qwen2Config
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig, PoolerConfig from vllm.config import VllmConfig
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput from vllm.sequence import IntermediateTensors, PoolerOutput
...@@ -59,12 +57,15 @@ class Qwen2ForRewardModel(nn.Module, SupportsPP): ...@@ -59,12 +57,15 @@ class Qwen2ForRewardModel(nn.Module, SupportsPP):
def __init__( def __init__(
self, self,
config: Qwen2Config, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
pooler_config: Optional[PoolerConfig] = None,
) -> None: ) -> None:
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
pooler_config = vllm_config.model_config.pooler_config
# TODO (@robertgshaw2): see if this can be moved out # TODO (@robertgshaw2): see if this can be moved out
if (cache_config.sliding_window is not None if (cache_config.sliding_window is not None
and hasattr(config, "max_window_layers")): and hasattr(config, "max_window_layers")):
...@@ -77,8 +78,6 @@ class Qwen2ForRewardModel(nn.Module, SupportsPP): ...@@ -77,8 +78,6 @@ class Qwen2ForRewardModel(nn.Module, SupportsPP):
config.num_hidden_layers, config.num_hidden_layers,
)) ))
super().__init__()
self.config = config self.config = config
self.lora_config = lora_config self.lora_config = lora_config
......
...@@ -40,7 +40,7 @@ from transformers.models.qwen2_vl.image_processing_qwen2_vl import ( ...@@ -40,7 +40,7 @@ from transformers.models.qwen2_vl.image_processing_qwen2_vl import (
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.attention.selector import _Backend from vllm.attention.selector import _Backend
from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig from vllm.config import VllmConfig
from vllm.distributed import get_pp_group, parallel_state from vllm.distributed import get_pp_group, parallel_state
from vllm.distributed import utils as dist_utils from vllm.distributed import utils as dist_utils
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
...@@ -966,15 +966,16 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -966,15 +966,16 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
embedding_modules = {} embedding_modules = {}
embedding_padding_modules = [] embedding_padding_modules = []
def __init__(self, def __init__(
config: Qwen2VLConfig, self,
multimodal_config: MultiModalConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None, ) -> None:
lora_config: Optional[LoRAConfig] = None) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
assert not cache_config.enable_prefix_caching, \ assert not cache_config.enable_prefix_caching, \
"Qwen2-VL currently does not support prefix caching" "Qwen2-VL currently does not support prefix caching"
......
...@@ -29,7 +29,7 @@ from transformers import PretrainedConfig ...@@ -29,7 +29,7 @@ from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, LoRAConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
...@@ -411,13 +411,14 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -411,13 +411,14 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def __init__( def __init__(
self, self,
config: PretrainedConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config self.config = config
self.lora_config = lora_config self.lora_config = lora_config
......
...@@ -25,7 +25,7 @@ from torch import nn ...@@ -25,7 +25,7 @@ from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
...@@ -247,11 +247,13 @@ class StablelmForCausalLM(nn.Module, SupportsPP): ...@@ -247,11 +247,13 @@ class StablelmForCausalLM(nn.Module, SupportsPP):
def __init__( def __init__(
self, self,
config: PretrainedConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.model = StableLMEpochModel(config, cache_config, quant_config) self.model = StableLMEpochModel(config, cache_config, quant_config)
......
...@@ -25,7 +25,7 @@ from transformers import Starcoder2Config ...@@ -25,7 +25,7 @@ from transformers import Starcoder2Config
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...@@ -245,11 +245,15 @@ class Starcoder2Model(nn.Module): ...@@ -245,11 +245,15 @@ class Starcoder2Model(nn.Module):
class Starcoder2ForCausalLM(nn.Module, SupportsPP): class Starcoder2ForCausalLM(nn.Module, SupportsPP):
def __init__(self, def __init__(
config: Starcoder2Config, self,
cache_config: Optional[CacheConfig] = None, vllm_config: VllmConfig,
quant_config: Optional[QuantizationConfig] = None): prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config self.config = config
self.model = Starcoder2Model(config, self.model = Starcoder2Model(config,
cache_config, cache_config,
......
...@@ -15,12 +15,11 @@ from transformers.models.whisper import WhisperFeatureExtractor ...@@ -15,12 +15,11 @@ from transformers.models.whisper import WhisperFeatureExtractor
from transformers.models.whisper.modeling_whisper import WhisperEncoder from transformers.models.whisper.modeling_whisper import WhisperEncoder
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs) InputContext, token_inputs)
from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader.loader import DefaultModelLoader from vllm.model_executor.model_loader.loader import DefaultModelLoader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
...@@ -340,12 +339,14 @@ class ModifiedWhisperEncoder(WhisperEncoder): ...@@ -340,12 +339,14 @@ class ModifiedWhisperEncoder(WhisperEncoder):
@INPUT_REGISTRY.register_input_processor(input_processor_for_ultravox) @INPUT_REGISTRY.register_input_processor(input_processor_for_ultravox)
class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP): class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(self, def __init__(
config: UltravoxConfig, self,
multimodal_config: MultiModalConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional["QuantizationConfig"] = None): ) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config self.config = config
self.multi_modal_config = multimodal_config self.multi_modal_config = multimodal_config
assert self.multi_modal_config assert self.multi_modal_config
...@@ -361,10 +362,7 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -361,10 +362,7 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):
)) ))
self.multi_modal_projector = UltravoxProjector(config) self.multi_modal_projector = UltravoxProjector(config)
self.language_model = init_vllm_registered_model( self.language_model = init_vllm_registered_model(
config.text_config, config.text_config, vllm_config, prefix="language_model")
cache_config,
quant_config,
prefix="language_model")
if config.text_model_id is not None: if config.text_model_id is not None:
self.secondary_weights.append( self.secondary_weights.append(
DefaultModelLoader.Source(model_or_path=config.text_model_id, DefaultModelLoader.Source(model_or_path=config.text_model_id,
......
...@@ -11,11 +11,8 @@ from transformers import PretrainedConfig ...@@ -11,11 +11,8 @@ from transformers import PretrainedConfig
import vllm.envs as envs import vllm.envs as envs
from vllm.attention.selector import (_Backend, backend_name_to_enum, from vllm.attention.selector import (_Backend, backend_name_to_enum,
get_global_forced_attn_backend) get_global_forced_attn_backend)
from vllm.config import (CacheConfig, LoRAConfig, MultiModalConfig, from vllm.config import VllmConfig
SchedulerConfig)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.loader import build_model
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models import ModelRegistry
from vllm.multimodal.base import MultiModalPlaceholderMap, NestedTensors from vllm.multimodal.base import MultiModalPlaceholderMap, NestedTensors
...@@ -236,12 +233,7 @@ class AutoWeightsLoader: ...@@ -236,12 +233,7 @@ class AutoWeightsLoader:
def init_vllm_registered_model( def init_vllm_registered_model(
hf_config: PretrainedConfig, hf_config: PretrainedConfig,
cache_config: Optional[CacheConfig], vllm_config: VllmConfig,
quant_config: Optional[QuantizationConfig],
*,
lora_config: Optional[LoRAConfig] = None,
multimodal_config: Optional[MultiModalConfig] = None,
scheduler_config: Optional[SchedulerConfig] = None,
prefix: str = "", prefix: str = "",
) -> nn.Module: ) -> nn.Module:
""" """
...@@ -249,18 +241,11 @@ def init_vllm_registered_model( ...@@ -249,18 +241,11 @@ def init_vllm_registered_model(
based on the arguments passed to the outer vLLM model. based on the arguments passed to the outer vLLM model.
""" """
model_class, _ = ModelRegistry.resolve_model_cls(hf_config.architectures) model_class, _ = ModelRegistry.resolve_model_cls(hf_config.architectures)
import copy
copied_config = copy.deepcopy(vllm_config)
copied_config.model_config.hf_config = hf_config
return build_model( return model_class(vllm_config=copied_config, prefix=prefix)
model_class,
None,
hf_config,
cache_config,
quant_config,
lora_config=lora_config,
multimodal_config=multimodal_config,
scheduler_config=scheduler_config,
prefix=prefix,
)
@overload @overload
......
...@@ -27,7 +27,7 @@ from transformers import PretrainedConfig ...@@ -27,7 +27,7 @@ from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
...@@ -225,13 +225,14 @@ class XverseModel(nn.Module): ...@@ -225,13 +225,14 @@ class XverseModel(nn.Module):
def __init__( def __init__(
self, self,
config: PretrainedConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config self.config = config
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
lora_vocab = (lora_config.lora_extra_vocab_size * lora_vocab = (lora_config.lora_extra_vocab_size *
...@@ -316,13 +317,16 @@ class XverseForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -316,13 +317,16 @@ class XverseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def __init__( def __init__(
self, self,
config: PretrainedConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config self.config = config
self.lora_config = lora_config self.lora_config = lora_config
......
...@@ -61,15 +61,3 @@ def set_compilation_config(config: Optional[CompilationConfig]): ...@@ -61,15 +61,3 @@ def set_compilation_config(config: Optional[CompilationConfig]):
def get_compilation_config() -> Optional[CompilationConfig]: def get_compilation_config() -> Optional[CompilationConfig]:
return _compilation_config return _compilation_config
_vllm_config: Optional[VllmConfig] = None
def set_vllm_config(config: Optional[VllmConfig]):
global _vllm_config
_vllm_config = config
def get_vllm_config() -> Optional[VllmConfig]:
return _vllm_config
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment