Unverified commit 1a95f10e, authored by youkaichao, committed by GitHub
Browse files

[5/N] pass the whole config to model (#9983)


Signed-off-by: youkaichao <youkaichao@gmail.com>
parent 49d2a41a
...@@ -6,7 +6,7 @@ from torch import nn ...@@ -6,7 +6,7 @@ from torch import nn
from transformers import MambaConfig from transformers import MambaConfig
from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig from vllm.config import CacheConfig, LoRAConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
...@@ -132,12 +132,14 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree): ...@@ -132,12 +132,14 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
def __init__( def __init__(
self, self,
config: MambaConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
scheduler_config: Optional[SchedulerConfig] = None,
) -> None: ) -> None:
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
scheduler_config = vllm_config.scheduler_config
assert not cache_config.enable_prefix_caching, \ assert not cache_config.enable_prefix_caching, \
"Mamba does not support prefix caching" "Mamba does not support prefix caching"
......
...@@ -3,13 +3,13 @@ from typing import Iterable, List, Optional, Tuple ...@@ -3,13 +3,13 @@ from typing import Iterable, List, Optional, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.config import VllmConfig
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.transformers_utils.configs.medusa import MedusaConfig
class ResidualBlock(nn.Module): class ResidualBlock(nn.Module):
...@@ -44,7 +44,8 @@ class Medusa(nn.Module): ...@@ -44,7 +44,8 @@ class Medusa(nn.Module):
in the draft checkpoint (using key token_map). Also, the draft config in the draft checkpoint (using key token_map). Also, the draft config
needs to have truncated_vocab_size (=k) as an attribute.""" needs to have truncated_vocab_size (=k) as an attribute."""
def __init__(self, config: MedusaConfig, **_) -> None: def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
config = vllm_config.model_config.hf_config
super().__init__() super().__init__()
self.config = config self.config = config
self.blocks = nn.ModuleList([ self.blocks = nn.ModuleList([
......
...@@ -29,7 +29,7 @@ from transformers import PretrainedConfig ...@@ -29,7 +29,7 @@ from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, LoRAConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
...@@ -463,12 +463,14 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -463,12 +463,14 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def __init__( def __init__(
self, self,
config: PretrainedConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config self.config = config
self.lora_config = lora_config self.lora_config = lora_config
......
...@@ -34,7 +34,7 @@ from transformers import PretrainedConfig ...@@ -34,7 +34,7 @@ from transformers import PretrainedConfig
from typing_extensions import NotRequired from typing_extensions import NotRequired
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig from vllm.config import CacheConfig, VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs) InputContext, token_inputs)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
...@@ -385,11 +385,13 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -385,11 +385,13 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
def __init__( def __init__(
self, self,
config: PretrainedConfig, vllm_config: VllmConfig,
multimodal_config: MultiModalConfig, prefix: str = "",
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
): ):
config = vllm_config.model_config.hf_config
multimodal_config = vllm_config.model_config.multimodal_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
super().__init__() super().__init__()
# All MiniCPM-V models disable `tie_word_embeddings` but # All MiniCPM-V models disable `tie_word_embeddings` but
# `PretrainedConfig.tie_word_embeddings` defaults to True; we cannot # `PretrainedConfig.tie_word_embeddings` defaults to True; we cannot
...@@ -701,12 +703,10 @@ class MiniCPMV2_0(MiniCPMVBaseModel): ...@@ -701,12 +703,10 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
def __init__( def __init__(
self, self,
config: PretrainedConfig, vllm_config: VllmConfig,
multimodal_config: MultiModalConfig, prefix: str = "",
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__(config, multimodal_config, cache_config, quant_config) super().__init__(vllm_config)
assert self.version == (2, 0) assert self.version == (2, 0)
def init_llm( def init_llm(
...@@ -867,13 +867,10 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA): ...@@ -867,13 +867,10 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
def __init__( def __init__(
self, self,
config: PretrainedConfig, vllm_config: VllmConfig,
multimodal_config: MultiModalConfig, prefix: str = "",
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
): ):
super().__init__(config, multimodal_config, cache_config, quant_config) super().__init__(vllm_config)
assert self.version == (2, 5) assert self.version == (2, 5)
def init_llm( def init_llm(
...@@ -1017,12 +1014,10 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA): ...@@ -1017,12 +1014,10 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
def __init__( def __init__(
self, self,
config: PretrainedConfig, vllm_config: VllmConfig,
multimodal_config: MultiModalConfig, prefix: str = "",
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__(config, multimodal_config, cache_config, quant_config) super().__init__(vllm_config)
assert self.version == (2, 6) assert self.version == (2, 6)
def init_llm( def init_llm(
...@@ -1141,12 +1136,8 @@ class MiniCPMV(MiniCPMVBaseModel, SupportsLoRA): ...@@ -1141,12 +1136,8 @@ class MiniCPMV(MiniCPMVBaseModel, SupportsLoRA):
embedding_modules = {} embedding_modules = {}
embedding_padding_modules = [] embedding_padding_modules = []
def __new__(cls, def __new__(cls, vllm_config: VllmConfig, prefix: str = ""):
config: PretrainedConfig, config = vllm_config.model_config.hf_config
multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None):
if not hasattr(config, "version"): if not hasattr(config, "version"):
if config.hidden_size == 2304 and config.query_num == 64: if config.hidden_size == 2304 and config.query_num == 64:
version = (2, 0) version = (2, 0)
...@@ -1160,5 +1151,4 @@ class MiniCPMV(MiniCPMVBaseModel, SupportsLoRA): ...@@ -1160,5 +1151,4 @@ class MiniCPMV(MiniCPMVBaseModel, SupportsLoRA):
if instance_class is None: if instance_class is None:
raise ValueError( raise ValueError(
"Currently, MiniCPMV only supports versions 2.0, 2.5, and 2.6") "Currently, MiniCPMV only supports versions 2.0, 2.5, and 2.6")
return instance_class(config, multimodal_config, cache_config, return instance_class(vllm_config, prefix=prefix)
quant_config)
...@@ -28,7 +28,7 @@ from transformers import MixtralConfig ...@@ -28,7 +28,7 @@ from transformers import MixtralConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, LoRAConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
...@@ -334,13 +334,14 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -334,13 +334,14 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def __init__( def __init__(
self, self,
config: MixtralConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config self.config = config
self.lora_config = lora_config self.lora_config = lora_config
......
...@@ -29,7 +29,7 @@ from torch import nn ...@@ -29,7 +29,7 @@ from torch import nn
from transformers import MixtralConfig from transformers import MixtralConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
...@@ -352,11 +352,13 @@ class MixtralForCausalLM(nn.Module, SupportsPP): ...@@ -352,11 +352,13 @@ class MixtralForCausalLM(nn.Module, SupportsPP):
def __init__( def __init__(
self, self,
config: MixtralConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.model = MixtralModel(config, cache_config, quant_config) self.model = MixtralModel(config, cache_config, quant_config)
......
...@@ -33,7 +33,7 @@ from transformers.models.mllama.processing_mllama import ( ...@@ -33,7 +33,7 @@ from transformers.models.mllama.processing_mllama import (
import vllm.distributed.parallel_state as ps import vllm.distributed.parallel_state as ps
from vllm.attention import Attention, AttentionMetadata, AttentionType from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.attention.ops.paged_attn import PagedAttention from vllm.attention.ops.paged_attn import PagedAttention
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.inputs import (INPUT_REGISTRY, DummyData, EncoderDecoderInputs, from vllm.inputs import (INPUT_REGISTRY, DummyData, EncoderDecoderInputs,
InputContext, TokenInputs, token_inputs) InputContext, TokenInputs, token_inputs)
...@@ -1108,12 +1108,15 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -1108,12 +1108,15 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
"up_proj": ("gate_up_proj", 1), "up_proj": ("gate_up_proj", 1),
} }
def __init__(self, def __init__(
config: config_mllama.MllamaConfig, self,
multimodal_config: MultiModalConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None): ) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.vocab_size = config.text_config.vocab_size self.vocab_size = config.text_config.vocab_size
self.hidden_size = config.text_config.hidden_size self.hidden_size = config.text_config.hidden_size
self.max_num_tiles = config.vision_config.max_num_tiles self.max_num_tiles = config.vision_config.max_num_tiles
......
...@@ -3,8 +3,7 @@ import re ...@@ -3,8 +3,7 @@ import re
from array import array from array import array
from dataclasses import dataclass from dataclasses import dataclass
from functools import lru_cache, partial from functools import lru_cache, partial
from typing import (Any, Iterable, List, Mapping, Optional, Tuple, TypedDict, from typing import Iterable, List, Mapping, Optional, Tuple, TypedDict, Union
Union)
import torch import torch
from einops import rearrange from einops import rearrange
...@@ -16,7 +15,7 @@ from transformers import PretrainedConfig ...@@ -16,7 +15,7 @@ from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.attention.selector import _Backend from vllm.attention.selector import _Backend
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
split_tensor_along_last_dim, split_tensor_along_last_dim,
...@@ -1027,13 +1026,14 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -1027,13 +1026,14 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def __init__( def __init__(
self, self,
config: PretrainedConfig, vllm_config: VllmConfig,
multimodal_config: Optional[MultiModalConfig] = None, prefix: str = "",
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[Mapping[str, Any]] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
......
...@@ -7,7 +7,7 @@ import torch.nn as nn ...@@ -7,7 +7,7 @@ import torch.nn as nn
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
...@@ -269,11 +269,13 @@ class MPTForCausalLM(nn.Module, SupportsPP): ...@@ -269,11 +269,13 @@ class MPTForCausalLM(nn.Module, SupportsPP):
def __init__( def __init__(
self, self,
config: MPTConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None, ) -> None:
):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config self.config = config
assert config.tie_word_embeddings assert config.tie_word_embeddings
self.quant_config = quant_config self.quant_config = quant_config
......
...@@ -27,7 +27,7 @@ from torch import nn ...@@ -27,7 +27,7 @@ from torch import nn
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, LoRAConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...@@ -403,13 +403,14 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -403,13 +403,14 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def __init__( def __init__(
self, self,
config: NemotronConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
assert isinstance(config, NemotronConfig) assert isinstance(config, NemotronConfig)
self.config = config self.config = config
......
...@@ -28,7 +28,7 @@ from transformers import OlmoConfig ...@@ -28,7 +28,7 @@ from transformers import OlmoConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
...@@ -291,11 +291,15 @@ class OlmoForCausalLM(nn.Module, SupportsPP): ...@@ -291,11 +291,15 @@ class OlmoForCausalLM(nn.Module, SupportsPP):
Extremely barebones HF model wrapper. Extremely barebones HF model wrapper.
""" """
def __init__(self, def __init__(
config: OlmoConfig, self,
cache_config: Optional[CacheConfig] = None, vllm_config: VllmConfig,
quant_config: Optional[QuantizationConfig] = None): prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config self.config = config
self.model = OlmoModel(config, cache_config, quant_config) self.model = OlmoModel(config, cache_config, quant_config)
if config.tie_word_embeddings: if config.tie_word_embeddings:
......
...@@ -18,7 +18,7 @@ from transformers import PretrainedConfig ...@@ -18,7 +18,7 @@ from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
...@@ -311,11 +311,13 @@ class OlmoeForCausalLM(nn.Module, SupportsPP): ...@@ -311,11 +311,13 @@ class OlmoeForCausalLM(nn.Module, SupportsPP):
def __init__( def __init__(
self, self,
config: PretrainedConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.model = OlmoeModel(config, cache_config, quant_config) self.model = OlmoeModel(config, cache_config, quant_config)
......
...@@ -24,7 +24,7 @@ from transformers import OPTConfig ...@@ -24,7 +24,7 @@ from transformers import OPTConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...@@ -344,11 +344,13 @@ class OPTForCausalLM(nn.Module, SupportsPP): ...@@ -344,11 +344,13 @@ class OPTForCausalLM(nn.Module, SupportsPP):
def __init__( def __init__(
self, self,
config: OPTConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "", prefix: str = "",
): ) -> None:
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
super().__init__() super().__init__()
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
......
...@@ -11,7 +11,7 @@ from transformers import PretrainedConfig ...@@ -11,7 +11,7 @@ from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
...@@ -270,11 +270,13 @@ class OrionForCausalLM(nn.Module, SupportsPP): ...@@ -270,11 +270,13 @@ class OrionForCausalLM(nn.Module, SupportsPP):
def __init__( def __init__(
self, self,
config: PretrainedConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.model = OrionModel(config, cache_config, quant_config) self.model = OrionModel(config, cache_config, quant_config)
......
...@@ -6,13 +6,11 @@ from torch import nn ...@@ -6,13 +6,11 @@ from torch import nn
from transformers import PaliGemmaConfig from transformers import PaliGemmaConfig
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs) InputContext, token_inputs)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.models.gemma import GemmaForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.utils import cached_get_tokenizer from vllm.multimodal.utils import cached_get_tokenizer
...@@ -21,7 +19,8 @@ from vllm.sequence import IntermediateTensors ...@@ -21,7 +19,8 @@ from vllm.sequence import IntermediateTensors
from .interfaces import SupportsMultiModal, SupportsPP from .interfaces import SupportsMultiModal, SupportsPP
from .siglip import (SiglipVisionModel, dummy_image_for_siglip, from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip, get_max_siglip_image_tokens) dummy_seq_data_for_siglip, get_max_siglip_image_tokens)
from .utils import AutoWeightsLoader, merge_multimodal_embeddings from .utils import (AutoWeightsLoader, init_vllm_registered_model,
merge_multimodal_embeddings)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -132,13 +131,15 @@ class PaliGemmaMultiModalProjector(nn.Module): ...@@ -132,13 +131,15 @@ class PaliGemmaMultiModalProjector(nn.Module):
class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP): SupportsPP):
def __init__(self, def __init__(
config: PaliGemmaConfig, self,
multimodal_config: MultiModalConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None) -> None: ) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
...@@ -150,9 +151,10 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -150,9 +151,10 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
projection_dim=config.vision_config.projection_dim) projection_dim=config.vision_config.projection_dim)
self.quant_config = quant_config self.quant_config = quant_config
self.language_model = GemmaForCausalLM(config.text_config, config.text_config.architectures = ["GemmaForCausalLM"]
cache_config, self.language_model = init_vllm_registered_model(
quant_config, config.text_config,
vllm_config=vllm_config,
prefix="language_model") prefix="language_model")
logit_scale = getattr(config, "logit_scale", 1.0) logit_scale = getattr(config, "logit_scale", 1.0)
self.language_model.logits_processor.scale *= logit_scale self.language_model.logits_processor.scale *= logit_scale
......
...@@ -27,7 +27,7 @@ from transformers import PersimmonConfig ...@@ -27,7 +27,7 @@ from transformers import PersimmonConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...@@ -265,11 +265,15 @@ class PersimmonModel(nn.Module): ...@@ -265,11 +265,15 @@ class PersimmonModel(nn.Module):
class PersimmonForCausalLM(nn.Module, SupportsPP): class PersimmonForCausalLM(nn.Module, SupportsPP):
def __init__(self, def __init__(
config: PersimmonConfig, self,
cache_config: Optional[CacheConfig] = None, vllm_config: VllmConfig,
quant_config: Optional[QuantizationConfig] = None): prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config self.config = config
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
self.model = PersimmonModel(config, self.model = PersimmonModel(config,
......
...@@ -42,7 +42,7 @@ from transformers import PhiConfig ...@@ -42,7 +42,7 @@ from transformers import PhiConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...@@ -279,13 +279,14 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -279,13 +279,14 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def __init__( def __init__(
self, self,
config: PhiConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None, ) -> None:
lora_config: Optional[LoRAConfig] = None,
):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config self.config = config
# lm_head use bias, cannot share word embeddings # lm_head use bias, cannot share word embeddings
assert not config.tie_word_embeddings assert not config.tie_word_embeddings
......
...@@ -6,7 +6,7 @@ from torch import nn ...@@ -6,7 +6,7 @@ from torch import nn
from transformers.configuration_utils import PretrainedConfig from transformers.configuration_utils import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
...@@ -365,12 +365,13 @@ class Phi3SmallForCausalLM(nn.Module, SupportsPP): ...@@ -365,12 +365,13 @@ class Phi3SmallForCausalLM(nn.Module, SupportsPP):
def __init__( def __init__(
self, self,
config: PretrainedConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None, ) -> None:
lora_config: Optional[LoRAConfig] = None,
):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.model = Phi3SmallModel(config, cache_config, quant_config) self.model = Phi3SmallModel(config, cache_config, quant_config)
......
...@@ -25,8 +25,7 @@ from PIL import Image ...@@ -25,8 +25,7 @@ from PIL import Image
from transformers import CLIPVisionConfig, PretrainedConfig from transformers import CLIPVisionConfig, PretrainedConfig
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import (CacheConfig, ModelConfig, MultiModalConfig, from vllm.config import ModelConfig, VllmConfig
PoolerConfig)
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs) InputContext, token_inputs)
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -526,14 +525,16 @@ def input_processor_for_phi3v(ctx: InputContext, ...@@ -526,14 +525,16 @@ def input_processor_for_phi3v(ctx: InputContext,
@INPUT_REGISTRY.register_input_processor(input_processor_for_phi3v) @INPUT_REGISTRY.register_input_processor(input_processor_for_phi3v)
class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(self, def __init__(
config: PretrainedConfig, self,
multimodal_config: MultiModalConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None, ) -> None:
pooler_config: Optional[PoolerConfig] = None) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
pooler_config = vllm_config.model_config.pooler_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
self.image_token_id = _IMAGE_TOKEN_ID self.image_token_id = _IMAGE_TOKEN_ID
...@@ -552,8 +553,8 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -552,8 +553,8 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
# The prefix is empty intentionally because default prefix of # The prefix is empty intentionally because default prefix of
# LlamaForCausalLM is "model" # LlamaForCausalLM is "model"
self.language_model = LlamaForCausalLM(config, cache_config, self.language_model = LlamaForCausalLM(vllm_config=vllm_config,
quant_config) prefix="")
# The same model class supports both language generation and embedding # The same model class supports both language generation and embedding
# because the architecture name is the same # because the architecture name is the same
......
...@@ -28,7 +28,7 @@ from transformers.configuration_utils import PretrainedConfig ...@@ -28,7 +28,7 @@ from transformers.configuration_utils import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, LoRAConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (QKVParallelLinear, from vllm.model_executor.layers.linear import (QKVParallelLinear,
...@@ -531,13 +531,14 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -531,13 +531,14 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def __init__( def __init__(
self, self,
config: PhiMoEConfig, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config self.config = config
self.lora_config = lora_config self.lora_config = lora_config
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment