Unverified Commit f89d18ff authored by youkaichao's avatar youkaichao Committed by GitHub
Browse files

[6/N] pass whole config to inner model (#10205)


Signed-off-by: default avataryoukaichao <youkaichao@gmail.com>
parent f0f2e563
...@@ -17,7 +17,7 @@ from vllm.model_executor.models.qwen2 import Qwen2Model ...@@ -17,7 +17,7 @@ from vllm.model_executor.models.qwen2 import Qwen2Model
from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput from vllm.sequence import IntermediateTensors, PoolerOutput
from .utils import AutoWeightsLoader from .utils import AutoWeightsLoader, maybe_prefix
class Qwen2ForSequenceClassification(nn.Module): class Qwen2ForSequenceClassification(nn.Module):
...@@ -43,11 +43,7 @@ class Qwen2ForSequenceClassification(nn.Module): ...@@ -43,11 +43,7 @@ class Qwen2ForSequenceClassification(nn.Module):
embedding_modules = {} embedding_modules = {}
embedding_padding_modules = [] embedding_padding_modules = []
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config cache_config = vllm_config.cache_config
...@@ -70,7 +66,8 @@ class Qwen2ForSequenceClassification(nn.Module): ...@@ -70,7 +66,8 @@ class Qwen2ForSequenceClassification(nn.Module):
self.lora_config = lora_config self.lora_config = lora_config
self.quant_config = quant_config self.quant_config = quant_config
self.model = Qwen2Model(config, cache_config, quant_config) self.model = Qwen2Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.score = RowParallelLinear(config.hidden_size, self.score = RowParallelLinear(config.hidden_size,
config.num_labels, config.num_labels,
......
...@@ -54,7 +54,8 @@ from vllm.utils import print_warning_once ...@@ -54,7 +54,8 @@ from vllm.utils import print_warning_once
from .interfaces import SupportsPP from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter, from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers) make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class Qwen2MoeMLP(nn.Module): class Qwen2MoeMLP(nn.Module):
...@@ -315,14 +316,13 @@ class Qwen2MoeDecoderLayer(nn.Module): ...@@ -315,14 +316,13 @@ class Qwen2MoeDecoderLayer(nn.Module):
@support_torch_compile @support_torch_compile
class Qwen2MoeModel(nn.Module): class Qwen2MoeModel(nn.Module):
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
...@@ -377,18 +377,14 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP): ...@@ -377,18 +377,14 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP):
fall_back_to_pt_during_load = False fall_back_to_pt_during_load = False
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.model = Qwen2MoeModel(config, cache_config, quant_config) self.model = Qwen2MoeModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.lm_head = ParallelLMHead(config.vocab_size, self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size, config.hidden_size,
quant_config=quant_config) quant_config=quant_config)
......
...@@ -18,7 +18,7 @@ from vllm.sequence import IntermediateTensors, PoolerOutput ...@@ -18,7 +18,7 @@ from vllm.sequence import IntermediateTensors, PoolerOutput
from .interfaces import SupportsPP from .interfaces import SupportsPP
from .qwen2 import Qwen2Model from .qwen2 import Qwen2Model
from .utils import AutoWeightsLoader from .utils import AutoWeightsLoader, maybe_prefix
class ReLU(nn.Module): class ReLU(nn.Module):
...@@ -55,11 +55,7 @@ class Qwen2ForRewardModel(nn.Module, SupportsPP): ...@@ -55,11 +55,7 @@ class Qwen2ForRewardModel(nn.Module, SupportsPP):
embedding_modules = {} embedding_modules = {}
embedding_padding_modules = [] embedding_padding_modules = []
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config cache_config = vllm_config.cache_config
...@@ -82,7 +78,8 @@ class Qwen2ForRewardModel(nn.Module, SupportsPP): ...@@ -82,7 +78,8 @@ class Qwen2ForRewardModel(nn.Module, SupportsPP):
self.lora_config = lora_config self.lora_config = lora_config
self.quant_config = quant_config self.quant_config = quant_config
self.model = Qwen2Model(config, cache_config, quant_config) self.model = Qwen2Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.score = nn.Sequential( self.score = nn.Sequential(
ColumnParallelLinear(config.hidden_size, ColumnParallelLinear(config.hidden_size,
......
...@@ -70,7 +70,7 @@ from vllm.transformers_utils.processor import cached_get_processor ...@@ -70,7 +70,7 @@ from vllm.transformers_utils.processor import cached_get_processor
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
from .utils import (PPMissingLayer, get_vit_attn_backend, from .utils import (PPMissingLayer, get_vit_attn_backend,
is_pp_missing_parameter, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory) make_empty_intermediate_tensors_factory, maybe_prefix)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -966,11 +966,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -966,11 +966,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
embedding_modules = {} embedding_modules = {}
embedding_padding_modules = [] embedding_padding_modules = []
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config cache_config = vllm_config.cache_config
...@@ -986,13 +982,11 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -986,13 +982,11 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
config.vision_config, config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6), norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=self._maybe_ignore_quant_config(quant_config), quant_config=self._maybe_ignore_quant_config(quant_config),
prefix="visual", prefix=maybe_prefix(prefix, "visual"),
) )
self.model = Qwen2Model(config, self.model = Qwen2Model(vllm_config=vllm_config,
cache_config, prefix=maybe_prefix(prefix, "model"))
quant_config,
prefix="model")
if get_pp_group().is_last_rank: if get_pp_group().is_last_rank:
if config.tie_word_embeddings: if config.tie_word_embeddings:
...@@ -1001,7 +995,8 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1001,7 +995,8 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
self.lm_head = ParallelLMHead(config.vocab_size, self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size, config.hidden_size,
quant_config=quant_config, quant_config=quant_config,
prefix="lm_head") prefix=maybe_prefix(
prefix, "lm_head"))
else: else:
self.lm_head = PPMissingLayer() self.lm_head = PPMissingLayer()
......
...@@ -29,7 +29,7 @@ from transformers import PretrainedConfig ...@@ -29,7 +29,7 @@ from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
...@@ -53,7 +53,8 @@ from vllm.sequence import IntermediateTensors ...@@ -53,7 +53,8 @@ from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter, from .utils import (PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers) make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class SolarMLP(nn.Module): class SolarMLP(nn.Module):
...@@ -266,15 +267,14 @@ class SolarDecoderLayer(nn.Module): ...@@ -266,15 +267,14 @@ class SolarDecoderLayer(nn.Module):
@support_torch_compile @support_torch_compile
class SolarModel(nn.Module): class SolarModel(nn.Module):
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config self.config = config
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
lora_vocab = ((lora_config.lora_extra_vocab_size * lora_vocab = ((lora_config.lora_extra_vocab_size *
...@@ -409,25 +409,17 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -409,25 +409,17 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"up_proj": ("gate_up_proj", 1), "up_proj": ("gate_up_proj", 1),
} }
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config lora_config = vllm_config.lora_config
self.config = config self.config = config
self.lora_config = lora_config self.lora_config = lora_config
self.model = SolarModel( self.model = SolarModel(
config, vllm_config=vllm_config,
cache_config, prefix=maybe_prefix(prefix, "model"),
quant_config,
lora_config=lora_config,
prefix="model",
) )
if get_pp_group().is_last_rank: if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size self.unpadded_vocab_size = config.vocab_size
......
...@@ -43,7 +43,8 @@ from vllm.sequence import IntermediateTensors ...@@ -43,7 +43,8 @@ from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter, from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers) make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class StablelmMLP(nn.Module): class StablelmMLP(nn.Module):
...@@ -193,12 +194,13 @@ class StablelmDecoderLayer(nn.Module): ...@@ -193,12 +194,13 @@ class StablelmDecoderLayer(nn.Module):
class StableLMEpochModel(nn.Module): class StableLMEpochModel(nn.Module):
def __init__(self, def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = '') -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.embed_tokens = VocabParallelEmbedding( self.embed_tokens = VocabParallelEmbedding(
config.vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
...@@ -245,18 +247,14 @@ class StableLMEpochModel(nn.Module): ...@@ -245,18 +247,14 @@ class StableLMEpochModel(nn.Module):
class StablelmForCausalLM(nn.Module, SupportsPP): class StablelmForCausalLM(nn.Module, SupportsPP):
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.model = StableLMEpochModel(config, cache_config, quant_config) self.model = StableLMEpochModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.lm_head = ParallelLMHead(config.vocab_size, self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size, config.hidden_size,
quant_config=quant_config) quant_config=quant_config)
......
...@@ -43,7 +43,8 @@ from vllm.sequence import IntermediateTensors ...@@ -43,7 +43,8 @@ from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter, from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers) make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class Starcoder2Attention(nn.Module): class Starcoder2Attention(nn.Module):
...@@ -195,12 +196,13 @@ class Starcoder2DecoderLayer(nn.Module): ...@@ -195,12 +196,13 @@ class Starcoder2DecoderLayer(nn.Module):
@support_torch_compile @support_torch_compile
class Starcoder2Model(nn.Module): class Starcoder2Model(nn.Module):
def __init__(self, def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config: Starcoder2Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config self.config = config
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
...@@ -245,19 +247,13 @@ class Starcoder2Model(nn.Module): ...@@ -245,19 +247,13 @@ class Starcoder2Model(nn.Module):
class Starcoder2ForCausalLM(nn.Module, SupportsPP): class Starcoder2ForCausalLM(nn.Module, SupportsPP):
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
self.config = config self.config = config
self.model = Starcoder2Model(config, self.model = Starcoder2Model(vllm_config=vllm_config,
cache_config, prefix=maybe_prefix(prefix, "model"))
quant_config=quant_config)
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
self.unpadded_vocab_size = config.vocab_size self.unpadded_vocab_size = config.vocab_size
if config.tie_word_embeddings: if config.tie_word_embeddings:
......
...@@ -34,7 +34,7 @@ from vllm.utils import is_list_of ...@@ -34,7 +34,7 @@ from vllm.utils import is_list_of
from .interfaces import SupportsMultiModal, SupportsPP from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings_from_map) merge_multimodal_embeddings_from_map)
_AUDIO_PLACEHOLDER_TOKEN = 128002 _AUDIO_PLACEHOLDER_TOKEN = 128002
...@@ -339,11 +339,7 @@ class ModifiedWhisperEncoder(WhisperEncoder): ...@@ -339,11 +339,7 @@ class ModifiedWhisperEncoder(WhisperEncoder):
@INPUT_REGISTRY.register_input_processor(input_processor_for_ultravox) @INPUT_REGISTRY.register_input_processor(input_processor_for_ultravox)
class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP): class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
multimodal_config = vllm_config.model_config.multimodal_config multimodal_config = vllm_config.model_config.multimodal_config
...@@ -354,6 +350,8 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -354,6 +350,8 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):
self.secondary_weights = [] self.secondary_weights = []
self.audio_tower = ModifiedWhisperEncoder(config.audio_config) self.audio_tower = ModifiedWhisperEncoder(config.audio_config)
if config.audio_model_id is not None: if config.audio_model_id is not None:
# this prefix is not for initialization, but for loading weights
# note the trailing dot
self.secondary_weights.append( self.secondary_weights.append(
DefaultModelLoader.Source( DefaultModelLoader.Source(
model_or_path=config.audio_model_id, model_or_path=config.audio_model_id,
...@@ -362,8 +360,12 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -362,8 +360,12 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):
)) ))
self.multi_modal_projector = UltravoxProjector(config) self.multi_modal_projector = UltravoxProjector(config)
self.language_model = init_vllm_registered_model( self.language_model = init_vllm_registered_model(
config.text_config, vllm_config, prefix="language_model") config.text_config,
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "language_model"))
if config.text_model_id is not None: if config.text_model_id is not None:
# this prefix is not for initialization, but for loading weights
# note the trailing dot
self.secondary_weights.append( self.secondary_weights.append(
DefaultModelLoader.Source(model_or_path=config.text_model_id, DefaultModelLoader.Source(model_or_path=config.text_model_id,
revision=None, revision=None,
......
...@@ -46,7 +46,8 @@ from vllm.sequence import IntermediateTensors ...@@ -46,7 +46,8 @@ from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter, from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers) make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class XverseMLP(nn.Module): class XverseMLP(nn.Module):
...@@ -223,11 +224,7 @@ class XverseDecoderLayer(nn.Module): ...@@ -223,11 +224,7 @@ class XverseDecoderLayer(nn.Module):
@support_torch_compile @support_torch_compile
class XverseModel(nn.Module): class XverseModel(nn.Module):
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config cache_config = vllm_config.cache_config
...@@ -315,15 +312,10 @@ class XverseForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -315,15 +312,10 @@ class XverseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
} }
embedding_padding_modules = ["lm_head"] embedding_padding_modules = ["lm_head"]
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config lora_config = vllm_config.lora_config
...@@ -331,7 +323,8 @@ class XverseForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -331,7 +323,8 @@ class XverseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.lora_config = lora_config self.lora_config = lora_config
self.quant_config = quant_config self.quant_config = quant_config
self.model = XverseModel(config, cache_config, quant_config) self.model = XverseModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.lm_head = ParallelLMHead(config.vocab_size, self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size, config.hidden_size,
quant_config=quant_config) quant_config=quant_config)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment