Unverified Commit 1a95f10e authored by youkaichao's avatar youkaichao Committed by GitHub
Browse files

[5/N] pass the whole config to model (#9983)


Signed-off-by: default avataryoukaichao <youkaichao@gmail.com>
parent 49d2a41a
...@@ -22,7 +22,7 @@ from transformers import GemmaConfig ...@@ -22,7 +22,7 @@ from transformers import GemmaConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import GeluAndMul from vllm.model_executor.layers.activation import GeluAndMul
...@@ -374,13 +374,14 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -374,13 +374,14 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def __init__( def __init__(
self, self,
config: GemmaConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config self.config = config
# currently all existing Gemma models have `tie_word_embeddings` enabled # currently all existing Gemma models have `tie_word_embeddings` enabled
......
...@@ -21,7 +21,7 @@ from transformers import Gemma2Config ...@@ -21,7 +21,7 @@ from transformers import Gemma2Config
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig, PoolerConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import GeluAndMul from vllm.model_executor.layers.activation import GeluAndMul
...@@ -245,12 +245,13 @@ class Gemma2Model(nn.Module): ...@@ -245,12 +245,13 @@ class Gemma2Model(nn.Module):
def __init__( def __init__(
self, self,
config: Gemma2Config, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config self.config = config
self.embed_tokens = VocabParallelEmbedding( self.embed_tokens = VocabParallelEmbedding(
...@@ -400,11 +401,13 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -400,11 +401,13 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def __init__( def __init__(
self, self,
config: Gemma2Config, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None: ) -> None:
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
del lora_config # Unused. del lora_config # Unused.
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -470,14 +473,14 @@ class Gemma2EmbeddingModel(nn.Module, SupportsPP): ...@@ -470,14 +473,14 @@ class Gemma2EmbeddingModel(nn.Module, SupportsPP):
def __init__( def __init__(
self, self,
pooler_config: Optional[PoolerConfig] = None, vllm_config: VllmConfig,
**kwargs, prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.model = Gemma2Model(**kwargs) self.model = Gemma2Model(vllm_config, prefix)
self._pooler = Pooler.from_config_with_defaults( self._pooler = Pooler.from_config_with_defaults(
pooler_config, vllm_config.model_config.pooler_config,
pooling_type=PoolingType.LAST, pooling_type=PoolingType.LAST,
normalize=True, normalize=True,
softmax=False) softmax=False)
......
...@@ -24,7 +24,7 @@ from transformers import GPT2Config ...@@ -24,7 +24,7 @@ from transformers import GPT2Config
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed.parallel_state import ( from vllm.distributed.parallel_state import (
get_pp_group, get_tensor_model_parallel_world_size) get_pp_group, get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
...@@ -242,11 +242,13 @@ class GPT2LMHeadModel(nn.Module, SupportsPP): ...@@ -242,11 +242,13 @@ class GPT2LMHeadModel(nn.Module, SupportsPP):
def __init__( def __init__(
self, self,
config: GPT2Config, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.transformer = GPT2Model(config, self.transformer = GPT2Model(config,
......
...@@ -25,7 +25,7 @@ from transformers import GPTBigCodeConfig ...@@ -25,7 +25,7 @@ from transformers import GPTBigCodeConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, LoRAConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...@@ -260,12 +260,14 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -260,12 +260,14 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def __init__( def __init__(
self, self,
config: GPTBigCodeConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
): ):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config self.config = config
self.lora_config = lora_config self.lora_config = lora_config
......
...@@ -23,7 +23,7 @@ from transformers import GPTJConfig ...@@ -23,7 +23,7 @@ from transformers import GPTJConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...@@ -231,11 +231,13 @@ class GPTJForCausalLM(nn.Module, SupportsPP): ...@@ -231,11 +231,13 @@ class GPTJForCausalLM(nn.Module, SupportsPP):
def __init__( def __init__(
self, self,
config: GPTJConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
assert not config.tie_word_embeddings assert not config.tie_word_embeddings
......
...@@ -23,7 +23,7 @@ from transformers import GPTNeoXConfig ...@@ -23,7 +23,7 @@ from transformers import GPTNeoXConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...@@ -244,11 +244,13 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP): ...@@ -244,11 +244,13 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP):
def __init__( def __init__(
self, self,
config: GPTNeoXConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.gpt_neox = GPTNeoXModel(config, cache_config, quant_config) self.gpt_neox = GPTNeoXModel(config, cache_config, quant_config)
......
...@@ -28,7 +28,7 @@ from transformers import GraniteConfig ...@@ -28,7 +28,7 @@ from transformers import GraniteConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, LoRAConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
...@@ -372,12 +372,14 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -372,12 +372,14 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def __init__( def __init__(
self, self,
config: GraniteConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config self.config = config
self.lora_config = lora_config self.lora_config = lora_config
......
...@@ -28,7 +28,7 @@ from transformers.models.granitemoe import GraniteMoeConfig ...@@ -28,7 +28,7 @@ from transformers.models.granitemoe import GraniteMoeConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, LoRAConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
...@@ -335,12 +335,14 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -335,12 +335,14 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def __init__( def __init__(
self, self,
config: GraniteMoeConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config self.config = config
self.lora_config = lora_config self.lora_config = lora_config
......
...@@ -26,7 +26,7 @@ from transformers import PretrainedConfig as Idefics3Config ...@@ -26,7 +26,7 @@ from transformers import PretrainedConfig as Idefics3Config
from transformers import ProcessorMixin as Idefics3ImageProcessor from transformers import ProcessorMixin as Idefics3ImageProcessor
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs) InputContext, token_inputs)
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -615,13 +615,16 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -615,13 +615,16 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal):
def __init__( def __init__(
self, self,
config: Idefics3Config, vllm_config: VllmConfig,
multimodal_config: MultiModalConfig, prefix: str = "",
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
......
...@@ -11,9 +11,8 @@ from vllm.utils import supports_kw ...@@ -11,9 +11,8 @@ from vllm.utils import supports_kw
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig from vllm.config import VllmConfig
from vllm.model_executor.layers.pooler import PoolerOutput from vllm.model_executor.layers.pooler import PoolerOutput
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
...@@ -39,10 +38,8 @@ class VllmModel(Protocol[C_co, T_co]): ...@@ -39,10 +38,8 @@ class VllmModel(Protocol[C_co, T_co]):
def __init__( def __init__(
self, self,
config: C_co, vllm_config: "VllmConfig",
*, prefix: str = "",
cache_config: Optional["CacheConfig"],
quant_config: Optional["QuantizationConfig"],
) -> None: ) -> None:
... ...
...@@ -58,20 +55,7 @@ class VllmModel(Protocol[C_co, T_co]): ...@@ -58,20 +55,7 @@ class VllmModel(Protocol[C_co, T_co]):
def _check_vllm_model_init(model: Union[Type[object], object]) -> bool: def _check_vllm_model_init(model: Union[Type[object], object]) -> bool:
model_init = model.__init__ model_init = model.__init__
vllm_kws = ("cache_config", "quant_config") return supports_kw(model_init, "vllm_config")
missing_kws = tuple(kw for kw in vllm_kws
if not supports_kw(model_init, kw))
if missing_kws and (isinstance(model, type)
and issubclass(model, nn.Module)):
logger.warning(
"The model (%s) is missing "
"vLLM-specific keywords from its initializer: %s",
model,
missing_kws,
)
return len(missing_kws) == 0
def _check_vllm_model_forward(model: Union[Type[object], object]) -> bool: def _check_vllm_model_forward(model: Union[Type[object], object]) -> bool:
......
...@@ -7,7 +7,7 @@ from transformers import PretrainedConfig ...@@ -7,7 +7,7 @@ from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
split_tensor_along_last_dim, split_tensor_along_last_dim,
...@@ -319,12 +319,13 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP): ...@@ -319,12 +319,13 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP):
def __init__( def __init__(
self, self,
config: PretrainedConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.model = InternLM2Model(config, self.model = InternLM2Model(config,
......
...@@ -5,7 +5,7 @@ from torch import nn ...@@ -5,7 +5,7 @@ from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group from vllm.distributed import get_pp_group
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
...@@ -161,11 +161,12 @@ class InternLM2VEForCausalLM(InternLM2ForCausalLM): ...@@ -161,11 +161,12 @@ class InternLM2VEForCausalLM(InternLM2ForCausalLM):
def __init__( def __init__(
self, self,
config: PretrainedConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
super().__init__(config, cache_config, quant_config) super().__init__(config, cache_config, quant_config)
self.model = InternLM2VEModel(config, self.model = InternLM2VEModel(config,
cache_config, cache_config,
......
...@@ -16,7 +16,7 @@ from PIL import Image ...@@ -16,7 +16,7 @@ from PIL import Image
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs) InputContext, token_inputs)
from vllm.model_executor.layers.quantization import (AWQConfig, from vllm.model_executor.layers.quantization import (AWQConfig,
...@@ -410,13 +410,13 @@ input_pipeline = InternVLInputPipeline(IMG_START, IMG_END, IMG_CONTEXT) ...@@ -410,13 +410,13 @@ input_pipeline = InternVLInputPipeline(IMG_START, IMG_END, IMG_CONTEXT)
@INPUT_REGISTRY.register_input_processor(input_pipeline.input_processor) @INPUT_REGISTRY.register_input_processor(input_pipeline.input_processor)
class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP): class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(self, def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
config: PretrainedConfig,
multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
self._patch_quant_config(config, quant_config) self._patch_quant_config(config, quant_config)
...@@ -440,8 +440,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -440,8 +440,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
self.language_model = init_vllm_registered_model( self.language_model = init_vllm_registered_model(
config.text_config, config.text_config,
cache_config, vllm_config=vllm_config,
quant_config,
prefix="language_model") prefix="language_model")
self.mlp1 = self._init_mlp1(config) self.mlp1 = self._init_mlp1(config)
......
...@@ -26,7 +26,7 @@ from torch import nn ...@@ -26,7 +26,7 @@ from torch import nn
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...@@ -288,11 +288,13 @@ class JAISLMHeadModel(nn.Module, SupportsPP): ...@@ -288,11 +288,13 @@ class JAISLMHeadModel(nn.Module, SupportsPP):
def __init__( def __init__(
self, self,
config: JAISConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.transformer = JAISModel(config, cache_config, quant_config) self.transformer = JAISModel(config, cache_config, quant_config)
......
...@@ -7,7 +7,7 @@ from transformers import JambaConfig ...@@ -7,7 +7,7 @@ from transformers import JambaConfig
from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig from vllm.config import CacheConfig, LoRAConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
...@@ -350,12 +350,14 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA): ...@@ -350,12 +350,14 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA):
def __init__( def __init__(
self, self,
config: JambaConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
scheduler_config: Optional[SchedulerConfig] = None,
) -> None: ) -> None:
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
scheduler_config = vllm_config.scheduler_config
assert not cache_config.enable_prefix_caching, \ assert not cache_config.enable_prefix_caching, \
"Jamba currently does not support prefix caching" "Jamba currently does not support prefix caching"
......
...@@ -28,7 +28,7 @@ from transformers import LlamaConfig ...@@ -28,7 +28,7 @@ from transformers import LlamaConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig, PoolerConfig from vllm.config import CacheConfig, LoRAConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
...@@ -494,15 +494,15 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -494,15 +494,15 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def __init__( def __init__(
self, self,
config: LlamaConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
prefix: str = "", prefix: str = "",
pooler_config: Optional[PoolerConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
pooler_config = vllm_config.model_config.pooler_config
self.config = config self.config = config
self.lora_config = lora_config self.lora_config = lora_config
...@@ -654,12 +654,22 @@ class LlamaEmbeddingModel(nn.Module, SupportsLoRA, SupportsPP): ...@@ -654,12 +654,22 @@ class LlamaEmbeddingModel(nn.Module, SupportsLoRA, SupportsPP):
def __init__( def __init__(
self, self,
pooler_config: Optional[PoolerConfig] = None, vllm_config: VllmConfig,
**kwargs, prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.model = LlamaModel(**kwargs) config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
pooler_config = vllm_config.model_config.pooler_config
self.model = LlamaModel(config,
cache_config,
quant_config,
lora_config,
prefix=maybe_prefix(prefix, "model"))
self._pooler = Pooler.from_config_with_defaults( self._pooler = Pooler.from_config_with_defaults(
pooler_config, pooler_config,
pooling_type=PoolingType.LAST, pooling_type=PoolingType.LAST,
......
...@@ -9,7 +9,7 @@ from transformers import (CLIPVisionConfig, LlavaConfig, PixtralVisionConfig, ...@@ -9,7 +9,7 @@ from transformers import (CLIPVisionConfig, LlavaConfig, PixtralVisionConfig,
PretrainedConfig, SiglipVisionConfig) PretrainedConfig, SiglipVisionConfig)
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext) InputContext)
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
...@@ -258,13 +258,13 @@ def init_vision_tower_for_llava( ...@@ -258,13 +258,13 @@ def init_vision_tower_for_llava(
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava) @INPUT_REGISTRY.register_input_processor(input_processor_for_llava)
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(self, def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
config: LlavaConfig,
multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
...@@ -290,8 +290,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -290,8 +290,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
self.language_model = init_vllm_registered_model( self.language_model = init_vllm_registered_model(
config.text_config, config.text_config,
cache_config, vllm_config=vllm_config,
quant_config,
prefix="language_model") prefix="language_model")
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
......
...@@ -11,11 +11,10 @@ from transformers.models.llava_next.modeling_llava_next import ( ...@@ -11,11 +11,10 @@ from transformers.models.llava_next.modeling_llava_next import (
from typing_extensions import NotRequired from typing_extensions import NotRequired
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig, PoolerConfig from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext) InputContext)
from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
...@@ -282,13 +281,12 @@ def input_processor_for_llava_next(ctx: InputContext, ...@@ -282,13 +281,12 @@ def input_processor_for_llava_next(ctx: InputContext,
class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP): SupportsPP):
def __init__(self, def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
config: LlavaNextConfig,
multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
pooler_config: Optional[PoolerConfig] = None) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
pooler_config = vllm_config.model_config.pooler_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
...@@ -308,8 +306,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -308,8 +306,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
self.language_model = init_vllm_registered_model( self.language_model = init_vllm_registered_model(
config.text_config, config.text_config,
cache_config, vllm_config=vllm_config,
quant_config,
prefix="language_model") prefix="language_model")
# The same model class supports both language generation and embedding # The same model class supports both language generation and embedding
......
...@@ -10,11 +10,10 @@ from transformers import (CLIPVisionConfig, LlavaNextVideoConfig, ...@@ -10,11 +10,10 @@ from transformers import (CLIPVisionConfig, LlavaNextVideoConfig,
SiglipVisionConfig) SiglipVisionConfig)
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs) InputContext, token_inputs)
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.models.clip import CLIPVisionModel from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
...@@ -254,12 +253,11 @@ class LlavaNextMultiModalProjector(nn.Module): ...@@ -254,12 +253,11 @@ class LlavaNextMultiModalProjector(nn.Module):
class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP): SupportsPP):
def __init__(self, def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
config: LlavaNextVideoConfig,
multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
...@@ -277,8 +275,7 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -277,8 +275,7 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
projector_hidden_act=config.projector_hidden_act) projector_hidden_act=config.projector_hidden_act)
self.language_model = init_vllm_registered_model( self.language_model = init_vllm_registered_model(
config.text_config, config.text_config,
cache_config, vllm_config=vllm_config,
quant_config,
prefix="language_model") prefix="language_model")
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
......
...@@ -14,11 +14,10 @@ from transformers.models.llava_onevision.modeling_llava_onevision import ( ...@@ -14,11 +14,10 @@ from transformers.models.llava_onevision.modeling_llava_onevision import (
from typing_extensions import NotRequired from typing_extensions import NotRequired
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs) InputContext, token_inputs)
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
...@@ -405,12 +404,11 @@ class LlavaOnevisionMultiModalProjector(nn.Module): ...@@ -405,12 +404,11 @@ class LlavaOnevisionMultiModalProjector(nn.Module):
class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP): SupportsPP):
def __init__(self, def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
config: LlavaOnevisionConfig,
multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
...@@ -424,8 +422,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -424,8 +422,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
self.multi_modal_projector = LlavaOnevisionMultiModalProjector(config) self.multi_modal_projector = LlavaOnevisionMultiModalProjector(config)
self.language_model = init_vllm_registered_model( self.language_model = init_vllm_registered_model(
config.text_config, config.text_config,
cache_config, vllm_config=vllm_config,
quant_config,
prefix="language_model") prefix="language_model")
self.image_newline = nn.Parameter( self.image_newline = nn.Parameter(
torch.empty(config.text_config.hidden_size)) torch.empty(config.text_config.hidden_size))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment