Unverified commit 1a95f10e, authored by youkaichao and committed by GitHub
Browse files

[5/N] pass the whole config to model (#9983)


Signed-off-by: youkaichao <youkaichao@gmail.com>
parent 49d2a41a
...@@ -9,8 +9,7 @@ import math ...@@ -9,8 +9,7 @@ import math
import os import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from contextlib import contextmanager from contextlib import contextmanager
from typing import (Any, Dict, Generator, Iterable, List, Optional, Tuple, from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast
Type, cast)
import gguf import gguf
import huggingface_hub import huggingface_hub
...@@ -18,20 +17,17 @@ import numpy as np ...@@ -18,20 +17,17 @@ import numpy as np
import torch import torch
from huggingface_hub import HfApi, hf_hub_download from huggingface_hub import HfApi, hf_hub_download
from torch import nn from torch import nn
from transformers import AutoModelForCausalLM, PretrainedConfig from transformers import AutoModelForCausalLM
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
from vllm.config import (CacheConfig, LoadConfig, LoadFormat, LoRAConfig, from vllm.config import (LoadConfig, LoadFormat, ModelConfig, ParallelConfig,
ModelConfig, MultiModalConfig, ParallelConfig, VllmConfig)
PoolerConfig, SchedulerConfig, VllmConfig)
from vllm.distributed import (get_tensor_model_parallel_rank, from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.envs import VLLM_USE_MODELSCOPE from vllm.envs import VLLM_USE_MODELSCOPE
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ReplicatedLinear, from vllm.model_executor.layers.linear import (ReplicatedLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.model_loader.tensorizer import ( from vllm.model_executor.model_loader.tensorizer import (
TensorizerConfig, is_vllm_tensorized, load_with_tensorizer, TensorizerConfig, is_vllm_tensorized, load_with_tensorizer,
serialize_vllm_model, tensorizer_weights_iterator) serialize_vllm_model, tensorizer_weights_iterator)
...@@ -43,8 +39,6 @@ from vllm.model_executor.model_loader.weight_utils import ( ...@@ -43,8 +39,6 @@ from vllm.model_executor.model_loader.weight_utils import (
get_gguf_extra_tensor_names, gguf_quant_weights_iterator, get_gguf_extra_tensor_names, gguf_quant_weights_iterator,
initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator, initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator,
safetensors_weights_iterator) safetensors_weights_iterator)
from vllm.model_executor.models import (has_inner_state, supports_lora,
supports_multimodal)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import is_pin_memory_available from vllm.utils import is_pin_memory_available
...@@ -94,85 +88,11 @@ def device_loading_context(module: torch.nn.Module, ...@@ -94,85 +88,11 @@ def device_loading_context(module: torch.nn.Module,
logger = init_logger(__name__) logger = init_logger(__name__)
def _get_model_initialization_kwargs(
model_class: Type[nn.Module],
lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
scheduler_config: Optional[SchedulerConfig] = None,
pooler_config: Optional[PoolerConfig] = None) -> Dict[str, Any]:
"""Get extra kwargs for model initialization."""
extra_kwargs: Dict[str, Any] = {}
if supports_lora(model_class):
# lora_config=None is used to disable LoRA
extra_kwargs["lora_config"] = lora_config
elif lora_config:
raise ValueError(
f"Model {model_class.__name__} does not support LoRA, "
"but LoRA is enabled. Support for this model may "
"be added in the future. If this is important to you, "
"please open an issue on github.")
if supports_multimodal(model_class):
assert multimodal_config is not None
extra_kwargs["multimodal_config"] = multimodal_config
if has_inner_state(model_class) and scheduler_config:
extra_kwargs["scheduler_config"] = scheduler_config
if pooler_config:
extra_kwargs["pooler_config"] = pooler_config
return extra_kwargs
def build_model(model_class: Type[nn.Module],
vllm_config: Optional[VllmConfig],
hf_config: PretrainedConfig,
cache_config: Optional[CacheConfig],
quant_config: Optional[QuantizationConfig],
*,
lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
scheduler_config: Optional[SchedulerConfig],
prefix: Optional[str] = None,
pooler_config: Optional[PoolerConfig] = None) -> nn.Module:
extra_kwargs = _get_model_initialization_kwargs(model_class, lora_config,
multimodal_config,
scheduler_config,
pooler_config)
if prefix:
extra_kwargs["prefix"] = prefix
# TODO: unify all the module initialization code
# to only take the `VllmConfig` object as input
from vllm.plugins import set_vllm_config
set_vllm_config(vllm_config)
return model_class(config=hf_config,
cache_config=cache_config,
quant_config=quant_config,
**extra_kwargs)
def _initialize_model(vllm_config: VllmConfig) -> nn.Module: def _initialize_model(vllm_config: VllmConfig) -> nn.Module:
"""Initialize a model with the given configurations.""" """Initialize a model with the given configurations."""
model_config = vllm_config.model_config model_config = vllm_config.model_config
lora_config = vllm_config.lora_config
scheduler_config = vllm_config.scheduler_config
cache_config = vllm_config.cache_config
model_class, _ = get_model_architecture(model_config) model_class, _ = get_model_architecture(model_config)
return model_class(vllm_config=vllm_config)
return build_model(
model_class,
vllm_config,
model_config.hf_config,
cache_config=cache_config,
quant_config=vllm_config.quant_config,
lora_config=lora_config,
multimodal_config=model_config.multimodal_config,
scheduler_config=scheduler_config,
pooler_config=model_config.pooler_config,
)
class BaseModelLoader(ABC): class BaseModelLoader(ABC):
...@@ -486,24 +406,18 @@ class TensorizerLoader(BaseModelLoader): ...@@ -486,24 +406,18 @@ class TensorizerLoader(BaseModelLoader):
device_config = vllm_config.device_config device_config = vllm_config.device_config
model_config = vllm_config.model_config model_config = vllm_config.model_config
lora_config = vllm_config.lora_config
cache_config = vllm_config.cache_config
with set_default_torch_dtype(model_config.dtype): with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device): with torch.device(device_config.device):
model_class = get_model_architecture(model_config)[0] model_class = get_model_architecture(model_config)[0]
quant_config = vllm_config.quant_config
extra_kwargs = _get_model_initialization_kwargs(
model_class, lora_config, model_config.multimodal_config)
extra_kwargs["quant_config"] = quant_config
extra_kwargs["cache_config"] = cache_config
tensorizer_config = copy.copy(self.tensorizer_config) tensorizer_config = copy.copy(self.tensorizer_config)
tensorizer_config.model_class = model_class tensorizer_config.model_class = model_class
tensorizer_config.hf_config = model_config.hf_config tensorizer_config.hf_config = model_config.hf_config
tensorizer_config.dtype = model_config.dtype tensorizer_config.dtype = model_config.dtype
model = load_with_tensorizer(tensorizer_config, **extra_kwargs) model = load_with_tensorizer(tensorizer_config,
vllm_config=vllm_config)
return model.eval() return model.eval()
def download_model(self, model_config: ModelConfig) -> None: def download_model(self, model_config: ModelConfig) -> None:
......
...@@ -17,8 +17,6 @@ from vllm.config import ModelConfig, ParallelConfig ...@@ -17,8 +17,6 @@ from vllm.config import ModelConfig, ParallelConfig
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine from vllm.engine.llm_engine import LLMEngine
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
...@@ -268,8 +266,7 @@ class TensorizerAgent: ...@@ -268,8 +266,7 @@ class TensorizerAgent:
in vllm/model_executor/model_loader/weight_utils.py in vllm/model_executor/model_loader/weight_utils.py
""" """
def __init__(self, tensorizer_config: TensorizerConfig, def __init__(self, tensorizer_config: TensorizerConfig, vllm_config):
quant_config: QuantizationConfig, **extra_kwargs):
if tensorizer_error_msg is not None: if tensorizer_error_msg is not None:
raise ImportError( raise ImportError(
"Tensorizer is not installed. Please install tensorizer " "Tensorizer is not installed. Please install tensorizer "
...@@ -279,11 +276,7 @@ class TensorizerAgent: ...@@ -279,11 +276,7 @@ class TensorizerAgent:
self.tensorizer_config = tensorizer_config self.tensorizer_config = tensorizer_config
self.tensorizer_args = ( self.tensorizer_args = (
self.tensorizer_config._construct_tensorizer_args()) self.tensorizer_config._construct_tensorizer_args())
self.extra_kwargs = extra_kwargs self.vllm_config = vllm_config
if extra_kwargs.get("quant_config") is not None:
self.quant_config = extra_kwargs["quant_config"]
else:
self.quant_config = quant_config
self.model = self._init_model() self.model = self._init_model()
def _init_model(self): def _init_model(self):
...@@ -293,9 +286,7 @@ class TensorizerAgent: ...@@ -293,9 +286,7 @@ class TensorizerAgent:
assert self.tensorizer_config.model_class is not None assert self.tensorizer_config.model_class is not None
with no_init_or_tensor(): with no_init_or_tensor():
return self.tensorizer_config.model_class( return self.tensorizer_config.model_class(
config=model_args, vllm_config=self.vllm_config, )
quant_config=self.quant_config,
**self.extra_kwargs)
def _resize_lora_embeddings(self): def _resize_lora_embeddings(self):
"""Modify LoRA embedding layers to use bigger tensors """Modify LoRA embedding layers to use bigger tensors
......
...@@ -6,7 +6,7 @@ from torch import nn ...@@ -6,7 +6,7 @@ from torch import nn
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
...@@ -415,14 +415,16 @@ class ArcticModel(nn.Module): ...@@ -415,14 +415,16 @@ class ArcticModel(nn.Module):
class ArcticForCausalLM(nn.Module, SupportsPP): class ArcticForCausalLM(nn.Module, SupportsPP):
def __init__(self, def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
config: ArcticConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
**kwargs) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config self.config = config
self.model = ArcticModel(config, cache_config, quant_config) self.model = ArcticModel(config,
cache_config,
quant_config,
prefix=prefix)
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
self.lm_head = ParallelLMHead( self.lm_head = ParallelLMHead(
self.vocab_size, self.vocab_size,
......
...@@ -26,7 +26,7 @@ from transformers import PretrainedConfig ...@@ -26,7 +26,7 @@ from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
...@@ -332,14 +332,15 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -332,14 +332,15 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def __init__( def __init__(
self, self,
config: PretrainedConfig, vllm_config: VllmConfig,
position_embedding: str, prefix: str = "",
cache_config: Optional[CacheConfig] = None, position_embedding: str = "ROPE",
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
): ):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config self.config = config
self.lora_config = lora_config self.lora_config = lora_config
...@@ -439,17 +440,14 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM): ...@@ -439,17 +440,14 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
def __init__( def __init__(
self, self,
config: PretrainedConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
): ):
config = vllm_config.model_config.hf_config
if config.hidden_size == 4096: # baichuan2 7b if config.hidden_size == 4096: # baichuan2 7b
super().__init__(config, "ROPE", cache_config, quant_config, super().__init__(vllm_config, prefix, "ROPE")
lora_config)
else: # baichuan 13b, baichuan2 13b else: # baichuan 13b, baichuan2 13b
super().__init__(config, "ALIBI", cache_config, quant_config, super().__init__(vllm_config, prefix, "ALIBI")
lora_config)
class BaiChuanForCausalLM(BaiChuanBaseForCausalLM): class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
...@@ -459,10 +457,7 @@ class BaiChuanForCausalLM(BaiChuanBaseForCausalLM): ...@@ -459,10 +457,7 @@ class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
def __init__( def __init__(
self, self,
config: PretrainedConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
): ):
super().__init__(config, "ROPE", cache_config, quant_config, super().__init__(vllm_config, prefix, "ROPE")
lora_config)
...@@ -25,7 +25,7 @@ from transformers import BartConfig ...@@ -25,7 +25,7 @@ from transformers import BartConfig
from transformers.utils import logging from transformers.utils import logging
from vllm.attention import Attention, AttentionMetadata, AttentionType from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, LoRAConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...@@ -810,13 +810,13 @@ class BartModel(nn.Module): ...@@ -810,13 +810,13 @@ class BartModel(nn.Module):
class BartForConditionalGeneration(nn.Module): class BartForConditionalGeneration(nn.Module):
base_model_prefix = "model" base_model_prefix = "model"
def __init__(self, def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
config: BartConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
# currently all existing BART models have `tie_word_embeddings` enabled # currently all existing BART models have `tie_word_embeddings` enabled
assert config.tie_word_embeddings assert config.tie_word_embeddings
self.config = config self.config = config
......
...@@ -6,7 +6,7 @@ from transformers import BertConfig ...@@ -6,7 +6,7 @@ from transformers import BertConfig
from vllm.attention import Attention, AttentionMetadata, AttentionType from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.attention.backends.xformers import XFormersImpl from vllm.attention.backends.xformers import XFormersImpl
from vllm.config import CacheConfig, PoolerConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...@@ -384,12 +384,14 @@ class BertEmbeddingModel(nn.Module): ...@@ -384,12 +384,14 @@ class BertEmbeddingModel(nn.Module):
def __init__( def __init__(
self, self,
config: BertConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
pooler_config: Optional[PoolerConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
pooler_config = vllm_config.model_config.pooler_config
self.model = BertModel(config, cache_config, quant_config) self.model = BertModel(config, cache_config, quant_config)
self._pooler = Pooler.from_config_with_defaults( self._pooler = Pooler.from_config_with_defaults(
pooler_config, pooler_config,
......
...@@ -8,7 +8,7 @@ from transformers import (Blip2Config, Blip2QFormerConfig, Blip2VisionConfig, ...@@ -8,7 +8,7 @@ from transformers import (Blip2Config, Blip2QFormerConfig, Blip2VisionConfig,
apply_chunking_to_forward) apply_chunking_to_forward)
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs) InputContext, token_inputs)
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
...@@ -483,14 +483,17 @@ def input_processor_for_blip2(ctx: InputContext, inputs: DecoderOnlyInputs): ...@@ -483,14 +483,17 @@ def input_processor_for_blip2(ctx: InputContext, inputs: DecoderOnlyInputs):
@INPUT_REGISTRY.register_input_processor(input_processor_for_blip2) @INPUT_REGISTRY.register_input_processor(input_processor_for_blip2)
class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(self, def __init__(
config: Blip2Config, self,
multimodal_config: MultiModalConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None) -> None: ) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
...@@ -513,8 +516,7 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -513,8 +516,7 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
self.language_model = init_vllm_registered_model( self.language_model = init_vllm_registered_model(
config.text_config, config.text_config,
cache_config, vllm_config=vllm_config,
quant_config,
prefix="language_model") prefix="language_model")
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
......
...@@ -24,7 +24,7 @@ from transformers import BloomConfig ...@@ -24,7 +24,7 @@ from transformers import BloomConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
...@@ -283,11 +283,13 @@ class BloomForCausalLM(nn.Module, SupportsPP): ...@@ -283,11 +283,13 @@ class BloomForCausalLM(nn.Module, SupportsPP):
def __init__( def __init__(
self, self,
config: BloomConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.transformer = BloomModel(config, cache_config, quant_config) self.transformer = BloomModel(config, cache_config, quant_config)
......
...@@ -9,7 +9,7 @@ from torch import nn ...@@ -9,7 +9,7 @@ from torch import nn
from transformers import ChameleonConfig, ChameleonVQVAEConfig from transformers import ChameleonConfig, ChameleonVQVAEConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs) InputContext, token_inputs)
...@@ -926,12 +926,14 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -926,12 +926,14 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
def __init__( def __init__(
self, self,
config: ChameleonConfig, vllm_config: VllmConfig,
multimodal_config: MultiModalConfig, prefix: str = "",
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
self.model = ChameleonModel(config, cache_config, quant_config) self.model = ChameleonModel(config, cache_config, quant_config)
......
...@@ -11,7 +11,7 @@ from torch import nn ...@@ -11,7 +11,7 @@ from torch import nn
from torch.nn import LayerNorm from torch.nn import LayerNorm
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs) InputContext, token_inputs)
...@@ -595,14 +595,15 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, ...@@ -595,14 +595,15 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
def __init__( def __init__(
self, self,
config: ChatGLMConfig, vllm_config: VllmConfig,
multimodal_config: MultiModalConfig, prefix: str = "",
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
): ):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config self.config = config
self.lora_config = lora_config self.lora_config = lora_config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
......
...@@ -28,7 +28,7 @@ from transformers import CohereConfig ...@@ -28,7 +28,7 @@ from transformers import CohereConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, LoRAConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
...@@ -334,12 +334,14 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -334,12 +334,14 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def __init__( def __init__(
self, self,
config: CohereConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config self.config = config
# currently all existing command R models have `tie_word_embeddings` # currently all existing command R models have `tie_word_embeddings`
# enabled # enabled
......
...@@ -4,7 +4,7 @@ import torch ...@@ -4,7 +4,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
...@@ -352,11 +352,13 @@ class DbrxForCausalLM(nn.Module, SupportsPP): ...@@ -352,11 +352,13 @@ class DbrxForCausalLM(nn.Module, SupportsPP):
def __init__( def __init__(
self, self,
config: DbrxConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config self.config = config
if config.tie_word_embeddings: if config.tie_word_embeddings:
raise ValueError( raise ValueError(
......
...@@ -22,13 +22,11 @@ ...@@ -22,13 +22,11 @@
# limitations under the License. # limitations under the License.
"""Inference-only DeciLM model compatible with HuggingFace weights.""" """Inference-only DeciLM model compatible with HuggingFace weights."""
from typing import Iterable, Optional, Tuple from typing import Iterable, Tuple
import torch import torch
from transformers import LlamaConfig
from vllm.config import CacheConfig, LoRAConfig from vllm.config import VllmConfig
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.model_executor.models.llama import LlamaForCausalLM
...@@ -55,17 +53,13 @@ class DeciLMForCausalLM(LlamaForCausalLM): ...@@ -55,17 +53,13 @@ class DeciLMForCausalLM(LlamaForCausalLM):
def __init__( def __init__(
self, self,
config: LlamaConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None: ) -> None:
config = vllm_config.model_config.hf_config
config.num_key_value_heads = max(config.num_key_value_heads_per_layer) config.num_key_value_heads = max(config.num_key_value_heads_per_layer)
delattr(config, "num_key_value_heads_per_layer") delattr(config, "num_key_value_heads_per_layer")
super().__init__(config=config, super().__init__(vllm_config=vllm_config)
cache_config=cache_config,
quant_config=quant_config,
lora_config=lora_config)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [ stacked_params_mapping = [
......
...@@ -27,7 +27,7 @@ from torch import nn ...@@ -27,7 +27,7 @@ from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
...@@ -385,11 +385,13 @@ class DeepseekForCausalLM(nn.Module, SupportsPP): ...@@ -385,11 +385,13 @@ class DeepseekForCausalLM(nn.Module, SupportsPP):
def __init__( def __init__(
self, self,
config: PretrainedConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.model = DeepseekModel(config, cache_config, quant_config) self.model = DeepseekModel(config, cache_config, quant_config)
......
...@@ -28,7 +28,7 @@ from transformers import PretrainedConfig ...@@ -28,7 +28,7 @@ from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, from vllm.distributed import (get_pp_group,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
...@@ -481,11 +481,13 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP): ...@@ -481,11 +481,13 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
def __init__( def __init__(
self, self,
config: PretrainedConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.model = DeepseekV2Model(config, self.model = DeepseekV2Model(config,
......
...@@ -4,6 +4,7 @@ import torch ...@@ -4,6 +4,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -12,7 +13,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -12,7 +13,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.eagle import EAGLEConfig
class EAGLE(nn.Module): class EAGLE(nn.Module):
...@@ -34,14 +34,15 @@ class EAGLE(nn.Module): ...@@ -34,14 +34,15 @@ class EAGLE(nn.Module):
in the draft checkpoint (using key token_map). Also, the draft config in the draft checkpoint (using key token_map). Also, the draft config
needs to have truncated_vocab_size (=k) as an attribute.""" needs to have truncated_vocab_size (=k) as an attribute."""
def __init__(self, config: EAGLEConfig, *args, **kwargs) -> None: def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
self.config = config self.config = config
architectures = getattr(self.config.model, "architectures", []) architectures = getattr(self.config.model, "architectures", [])
model_cls, _ = ModelRegistry.resolve_model_cls(architectures) model_cls, _ = ModelRegistry.resolve_model_cls(architectures)
self.model = model_cls(self.config.model, *args, **kwargs) self.model = model_cls(vllm_config, prefix)
self.fc = nn.Linear(config.model.hidden_size * 2, self.fc = nn.Linear(config.model.hidden_size * 2,
config.model.hidden_size, config.model.hidden_size,
bias=getattr(self.config, "eagle_fc_bias", False)) bias=getattr(self.config, "eagle_fc_bias", False))
......
...@@ -29,7 +29,7 @@ from torch import nn ...@@ -29,7 +29,7 @@ from torch import nn
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, LoRAConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
...@@ -440,12 +440,14 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -440,12 +440,14 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def __init__( def __init__(
self, self,
config: ExaoneConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config self.config = config
self.lora_config = lora_config self.lora_config = lora_config
......
...@@ -27,7 +27,7 @@ from transformers import FalconConfig as HF_FalconConfig ...@@ -27,7 +27,7 @@ from transformers import FalconConfig as HF_FalconConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
...@@ -403,11 +403,13 @@ class FalconForCausalLM(nn.Module, SupportsPP): ...@@ -403,11 +403,13 @@ class FalconForCausalLM(nn.Module, SupportsPP):
def __init__( def __init__(
self, self,
config: FalconConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.transformer = FalconModel(config, cache_config, quant_config) self.transformer = FalconModel(config, cache_config, quant_config)
......
...@@ -6,7 +6,7 @@ import torch.nn as nn ...@@ -6,7 +6,7 @@ import torch.nn as nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig from vllm.config import CacheConfig, VllmConfig
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
...@@ -189,11 +189,11 @@ class Florence2LanguageForConditionalGeneration(nn.Module): ...@@ -189,11 +189,11 @@ class Florence2LanguageForConditionalGeneration(nn.Module):
class Florence2ForConditionalGeneration(nn.Module): class Florence2ForConditionalGeneration(nn.Module):
def __init__(self, def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
# TODO(Isotr0py): Add vision backbone # TODO(Isotr0py): Add vision backbone
self.language_model = Florence2LanguageForConditionalGeneration( self.language_model = Florence2LanguageForConditionalGeneration(
......
...@@ -22,14 +22,13 @@ import torch ...@@ -22,14 +22,13 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.utils.checkpoint import torch.utils.checkpoint
from PIL import Image from PIL import Image
from transformers import FuyuConfig, FuyuImageProcessor from transformers import FuyuImageProcessor
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs) InputContext, token_inputs)
from vllm.model_executor.layers.linear import ColumnParallelLinear from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.models.persimmon import PersimmonForCausalLM from vllm.model_executor.models.persimmon import PersimmonForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
...@@ -227,12 +226,12 @@ def input_mapper_for_fuyu(ctx: InputContext, data: object): ...@@ -227,12 +226,12 @@ def input_mapper_for_fuyu(ctx: InputContext, data: object):
@INPUT_REGISTRY.register_input_processor(input_processor_for_fuyu) @INPUT_REGISTRY.register_input_processor(input_processor_for_fuyu)
class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(self, def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
config: FuyuConfig,
multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment