Unverified Commit 1a95f10e authored by youkaichao's avatar youkaichao Committed by GitHub
Browse files

[5/N] pass the whole config to model (#9983)


Signed-off-by: default avataryoukaichao <youkaichao@gmail.com>
parent 49d2a41a
......@@ -22,7 +22,7 @@ from transformers import GemmaConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import GeluAndMul
......@@ -374,13 +374,14 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def __init__(
self,
config: GemmaConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
# currently all existing Gemma models have `tie_word_embeddings` enabled
......
......@@ -21,7 +21,7 @@ from transformers import Gemma2Config
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig, PoolerConfig
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import GeluAndMul
......@@ -245,12 +245,13 @@ class Gemma2Model(nn.Module):
def __init__(
self,
config: Gemma2Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.embed_tokens = VocabParallelEmbedding(
......@@ -400,11 +401,13 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def __init__(
self,
config: Gemma2Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
del lora_config # Unused.
super().__init__()
self.config = config
......@@ -470,14 +473,14 @@ class Gemma2EmbeddingModel(nn.Module, SupportsPP):
def __init__(
self,
pooler_config: Optional[PoolerConfig] = None,
**kwargs,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
super().__init__()
self.model = Gemma2Model(**kwargs)
self.model = Gemma2Model(vllm_config, prefix)
self._pooler = Pooler.from_config_with_defaults(
pooler_config,
vllm_config.model_config.pooler_config,
pooling_type=PoolingType.LAST,
normalize=True,
softmax=False)
......
......@@ -24,7 +24,7 @@ from transformers import GPT2Config
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed.parallel_state import (
get_pp_group, get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import get_act_fn
......@@ -242,11 +242,13 @@ class GPT2LMHeadModel(nn.Module, SupportsPP):
def __init__(
self,
config: GPT2Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
vllm_config: VllmConfig,
prefix: str = "",
):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.transformer = GPT2Model(config,
......
......@@ -25,7 +25,7 @@ from transformers import GPTBigCodeConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig
from vllm.config import CacheConfig, LoRAConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
......@@ -260,12 +260,14 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def __init__(
self,
config: GPTBigCodeConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
vllm_config: VllmConfig,
prefix: str = "",
):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.lora_config = lora_config
......
......@@ -23,7 +23,7 @@ from transformers import GPTJConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
......@@ -231,11 +231,13 @@ class GPTJForCausalLM(nn.Module, SupportsPP):
def __init__(
self,
config: GPTJConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
vllm_config: VllmConfig,
prefix: str = "",
):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
assert not config.tie_word_embeddings
......
......@@ -23,7 +23,7 @@ from transformers import GPTNeoXConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
......@@ -244,11 +244,13 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP):
def __init__(
self,
config: GPTNeoXConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
vllm_config: VllmConfig,
prefix: str = "",
):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.gpt_neox = GPTNeoXModel(config, cache_config, quant_config)
......
......@@ -28,7 +28,7 @@ from transformers import GraniteConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig
from vllm.config import CacheConfig, LoRAConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul
......@@ -372,12 +372,14 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def __init__(
self,
config: GraniteConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.lora_config = lora_config
......
......@@ -28,7 +28,7 @@ from transformers.models.granitemoe import GraniteMoeConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig
from vllm.config import CacheConfig, LoRAConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
......@@ -335,12 +335,14 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def __init__(
self,
config: GraniteMoeConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.lora_config = lora_config
......
......@@ -26,7 +26,7 @@ from transformers import PretrainedConfig as Idefics3Config
from transformers import ProcessorMixin as Idefics3ImageProcessor
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.config import CacheConfig, VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.logger import init_logger
......@@ -615,13 +615,16 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal):
def __init__(
self,
config: Idefics3Config,
multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
self.multimodal_config = multimodal_config
......
......@@ -11,9 +11,8 @@ from vllm.utils import supports_kw
if TYPE_CHECKING:
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig
from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import PoolerOutput
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
......@@ -39,10 +38,8 @@ class VllmModel(Protocol[C_co, T_co]):
def __init__(
self,
config: C_co,
*,
cache_config: Optional["CacheConfig"],
quant_config: Optional["QuantizationConfig"],
vllm_config: "VllmConfig",
prefix: str = "",
) -> None:
...
......@@ -58,20 +55,7 @@ class VllmModel(Protocol[C_co, T_co]):
def _check_vllm_model_init(model: Union[Type[object], object]) -> bool:
model_init = model.__init__
vllm_kws = ("cache_config", "quant_config")
missing_kws = tuple(kw for kw in vllm_kws
if not supports_kw(model_init, kw))
if missing_kws and (isinstance(model, type)
and issubclass(model, nn.Module)):
logger.warning(
"The model (%s) is missing "
"vLLM-specific keywords from its initializer: %s",
model,
missing_kws,
)
return len(missing_kws) == 0
return supports_kw(model_init, "vllm_config")
def _check_vllm_model_forward(model: Union[Type[object], object]) -> bool:
......
......@@ -7,7 +7,7 @@ from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
split_tensor_along_last_dim,
......@@ -319,12 +319,13 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP):
def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.model = InternLM2Model(config,
......
......@@ -5,7 +5,7 @@ from torch import nn
from transformers import PretrainedConfig
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig
......@@ -161,11 +161,12 @@ class InternLM2VEForCausalLM(InternLM2ForCausalLM):
def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
super().__init__(config, cache_config, quant_config)
self.model = InternLM2VEModel(config,
cache_config,
......
......@@ -16,7 +16,7 @@ from PIL import Image
from transformers import PretrainedConfig
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.model_executor.layers.quantization import (AWQConfig,
......@@ -410,13 +410,13 @@ input_pipeline = InternVLInputPipeline(IMG_START, IMG_END, IMG_CONTEXT)
@INPUT_REGISTRY.register_input_processor(input_pipeline.input_processor)
class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(self,
config: PretrainedConfig,
multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None:
def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
self.multimodal_config = multimodal_config
self._patch_quant_config(config, quant_config)
......@@ -440,8 +440,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
self.language_model = init_vllm_registered_model(
config.text_config,
cache_config,
quant_config,
vllm_config=vllm_config,
prefix="language_model")
self.mlp1 = self._init_mlp1(config)
......
......@@ -26,7 +26,7 @@ from torch import nn
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
......@@ -288,11 +288,13 @@ class JAISLMHeadModel(nn.Module, SupportsPP):
def __init__(
self,
config: JAISConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
vllm_config: VllmConfig,
prefix: str = "",
):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.transformer = JAISModel(config, cache_config, quant_config)
......
......@@ -7,7 +7,7 @@ from transformers import JambaConfig
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.layer import Attention
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.config import CacheConfig, LoRAConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
......@@ -350,12 +350,14 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
def __init__(
self,
config: JambaConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
scheduler_config: Optional[SchedulerConfig] = None,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
scheduler_config = vllm_config.scheduler_config
assert not cache_config.enable_prefix_caching, \
"Jamba currently does not support prefix caching"
......
......@@ -28,7 +28,7 @@ from transformers import LlamaConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig, PoolerConfig
from vllm.config import CacheConfig, LoRAConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul
......@@ -494,15 +494,15 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def __init__(
self,
config: LlamaConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
vllm_config: VllmConfig,
prefix: str = "",
pooler_config: Optional[PoolerConfig] = None,
) -> None:
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
pooler_config = vllm_config.model_config.pooler_config
self.config = config
self.lora_config = lora_config
......@@ -654,12 +654,22 @@ class LlamaEmbeddingModel(nn.Module, SupportsLoRA, SupportsPP):
def __init__(
self,
pooler_config: Optional[PoolerConfig] = None,
**kwargs,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
super().__init__()
self.model = LlamaModel(**kwargs)
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
pooler_config = vllm_config.model_config.pooler_config
self.model = LlamaModel(config,
cache_config,
quant_config,
lora_config,
prefix=maybe_prefix(prefix, "model"))
self._pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.LAST,
......
......@@ -9,7 +9,7 @@ from transformers import (CLIPVisionConfig, LlavaConfig, PixtralVisionConfig,
PretrainedConfig, SiglipVisionConfig)
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext)
from vllm.model_executor.layers.activation import get_act_fn
......@@ -258,13 +258,13 @@ def init_vision_tower_for_llava(
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava)
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(self,
config: LlavaConfig,
multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None:
def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
self.multimodal_config = multimodal_config
......@@ -290,8 +290,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
self.language_model = init_vllm_registered_model(
config.text_config,
cache_config,
quant_config,
vllm_config=vllm_config,
prefix="language_model")
self.make_empty_intermediate_tensors = (
......
......@@ -11,11 +11,10 @@ from transformers.models.llava_next.modeling_llava_next import (
from typing_extensions import NotRequired
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig, PoolerConfig
from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext)
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata
......@@ -282,13 +281,12 @@ def input_processor_for_llava_next(ctx: InputContext,
class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):
def __init__(self,
config: LlavaNextConfig,
multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
pooler_config: Optional[PoolerConfig] = None) -> None:
def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
pooler_config = vllm_config.model_config.pooler_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
self.multimodal_config = multimodal_config
......@@ -308,8 +306,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
self.language_model = init_vllm_registered_model(
config.text_config,
cache_config,
quant_config,
vllm_config=vllm_config,
prefix="language_model")
# The same model class supports both language generation and embedding
......
......@@ -10,11 +10,10 @@ from transformers import (CLIPVisionConfig, LlavaNextVideoConfig,
SiglipVisionConfig)
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
......@@ -254,12 +253,11 @@ class LlavaNextMultiModalProjector(nn.Module):
class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):
def __init__(self,
config: LlavaNextVideoConfig,
multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None:
def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
self.multimodal_config = multimodal_config
......@@ -277,8 +275,7 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
projector_hidden_act=config.projector_hidden_act)
self.language_model = init_vllm_registered_model(
config.text_config,
cache_config,
quant_config,
vllm_config=vllm_config,
prefix="language_model")
self.make_empty_intermediate_tensors = (
......
......@@ -14,11 +14,10 @@ from transformers.models.llava_onevision.modeling_llava_onevision import (
from typing_extensions import NotRequired
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
......@@ -405,12 +404,11 @@ class LlavaOnevisionMultiModalProjector(nn.Module):
class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):
def __init__(self,
config: LlavaOnevisionConfig,
multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None:
def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
self.multimodal_config = multimodal_config
......@@ -424,8 +422,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
self.multi_modal_projector = LlavaOnevisionMultiModalProjector(config)
self.language_model = init_vllm_registered_model(
config.text_config,
cache_config,
quant_config,
vllm_config=vllm_config,
prefix="language_model")
self.image_newline = nn.Parameter(
torch.empty(config.text_config.hidden_size))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment