Unverified commit f89d18ff, authored by youkaichao and committed by GitHub
Browse files

[6/N] pass whole config to inner model (#10205)


Signed-off-by: youkaichao <youkaichao@gmail.com>
parent f0f2e563
...@@ -34,7 +34,8 @@ from vllm.transformers_utils.configs.arctic import ArcticConfig ...@@ -34,7 +34,8 @@ from vllm.transformers_utils.configs.arctic import ArcticConfig
from .interfaces import SupportsPP from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter, from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers) make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -364,14 +365,13 @@ class ArcticDecoderLayer(nn.Module): ...@@ -364,14 +365,13 @@ class ArcticDecoderLayer(nn.Module):
@support_torch_compile @support_torch_compile
class ArcticModel(nn.Module): class ArcticModel(nn.Module):
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
config: ArcticConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding( self.embed_tokens = VocabParallelEmbedding(
...@@ -418,13 +418,10 @@ class ArcticForCausalLM(nn.Module, SupportsPP): ...@@ -418,13 +418,10 @@ class ArcticForCausalLM(nn.Module, SupportsPP):
def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
self.config = config self.config = config
self.model = ArcticModel(config, self.model = ArcticModel(vllm_config=vllm_config,
cache_config, prefix=maybe_prefix(prefix, "model"))
quant_config,
prefix=prefix)
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
self.lm_head = ParallelLMHead( self.lm_head = ParallelLMHead(
self.vocab_size, self.vocab_size,
......
...@@ -253,13 +253,18 @@ class BaiChuanDecoderLayer(nn.Module): ...@@ -253,13 +253,18 @@ class BaiChuanDecoderLayer(nn.Module):
@support_torch_compile @support_torch_compile
class BaiChuanModel(nn.Module): class BaiChuanModel(nn.Module):
def __init__(self, def __init__(
config: PretrainedConfig, self,
position_embedding: str, vllm_config: VllmConfig,
cache_config: Optional[CacheConfig] = None, prefix: str = "",
quant_config: Optional[QuantizationConfig] = None, position_embedding: str = "ROPE",
prefix: str = ""): ) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config self.config = config
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
...@@ -332,21 +337,22 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -332,21 +337,22 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
def __init__( def __init__(
self, self,
*,
vllm_config: VllmConfig, vllm_config: VllmConfig,
prefix: str = "", prefix: str = "",
position_embedding: str = "ROPE", position_embedding: str = "ROPE",
): ):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config lora_config = vllm_config.lora_config
self.config = config self.config = config
self.lora_config = lora_config self.lora_config = lora_config
self.quant_config = quant_config self.quant_config = quant_config
self.model = BaiChuanModel(config, position_embedding, cache_config, self.model = BaiChuanModel(vllm_config=vllm_config,
quant_config) prefix=prefix,
position_embedding=position_embedding)
self.lm_head = ParallelLMHead(config.vocab_size, self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size, config.hidden_size,
quant_config=quant_config) quant_config=quant_config)
...@@ -438,16 +444,16 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM): ...@@ -438,16 +444,16 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
NOTE: the class name has a lower case 'c'. NOTE: the class name has a lower case 'c'.
""" """
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
vllm_config: VllmConfig,
prefix: str = "",
):
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
if config.hidden_size == 4096: # baichuan2 7b if config.hidden_size == 4096: # baichuan2 7b
super().__init__(vllm_config, prefix, "ROPE") super().__init__(vllm_config=vllm_config,
prefix=prefix,
position_embedding="ROPE")
else: # baichuan 13b, baichuan2 13b else: # baichuan 13b, baichuan2 13b
super().__init__(vllm_config, prefix, "ALIBI") super().__init__(vllm_config=vllm_config,
prefix=prefix,
position_embedding="ALIBI")
class BaiChuanForCausalLM(BaiChuanBaseForCausalLM): class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
...@@ -455,9 +461,7 @@ class BaiChuanForCausalLM(BaiChuanBaseForCausalLM): ...@@ -455,9 +461,7 @@ class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
NOTE: the class name has an upper case 'C'. NOTE: the class name has an upper case 'C'.
""" """
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self, super().__init__(vllm_config=vllm_config,
vllm_config: VllmConfig, prefix=prefix,
prefix: str = "", position_embedding="ROPE")
):
super().__init__(vllm_config, prefix, "ROPE")
...@@ -41,6 +41,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -41,6 +41,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .utils import maybe_prefix
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
...@@ -739,13 +741,14 @@ class BartModel(nn.Module): ...@@ -739,13 +741,14 @@ class BartModel(nn.Module):
"encoder.embed_tokens.weight", "decoder.embed_tokens.weight" "encoder.embed_tokens.weight", "decoder.embed_tokens.weight"
] ]
def __init__(self, def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config: BartConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config self.config = config
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
...@@ -810,20 +813,16 @@ class BartModel(nn.Module): ...@@ -810,20 +813,16 @@ class BartModel(nn.Module):
class BartForConditionalGeneration(nn.Module): class BartForConditionalGeneration(nn.Module):
base_model_prefix = "model" base_model_prefix = "model"
def __init__(self, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config lora_config = vllm_config.lora_config
# currently all existing BART models have `tie_word_embeddings` enabled # currently all existing BART models have `tie_word_embeddings` enabled
assert config.tie_word_embeddings assert config.tie_word_embeddings
self.config = config self.config = config
self.model = BartModel(config, self.model = BartModel(vllm_config=vllm_config,
cache_config, prefix=maybe_prefix(prefix, "model"))
quant_config,
lora_config=lora_config)
self.unpadded_vocab_size = config.vocab_size self.unpadded_vocab_size = config.vocab_size
if lora_config: if lora_config:
......
...@@ -21,6 +21,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -21,6 +21,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput from vllm.sequence import IntermediateTensors, PoolerOutput
from .utils import maybe_prefix
class BertEmbedding(nn.Module): class BertEmbedding(nn.Module):
...@@ -309,12 +311,13 @@ class BertOutput(nn.Module): ...@@ -309,12 +311,13 @@ class BertOutput(nn.Module):
class BertModel(nn.Module): class BertModel(nn.Module):
def __init__(self, def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config: BertConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.embeddings = BertEmbedding(config) self.embeddings = BertEmbedding(config)
self.encoder = BertEncoder(config, self.encoder = BertEncoder(config,
cache_config, cache_config,
...@@ -382,17 +385,11 @@ class BertEmbeddingModel(nn.Module): ...@@ -382,17 +385,11 @@ class BertEmbeddingModel(nn.Module):
_pooler: An instance of Pooler used for pooling operations. _pooler: An instance of Pooler used for pooling operations.
""" """
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
pooler_config = vllm_config.model_config.pooler_config pooler_config = vllm_config.model_config.pooler_config
self.model = BertModel(config, cache_config, quant_config) self.model = BertModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self._pooler = Pooler.from_config_with_defaults( self._pooler = Pooler.from_config_with_defaults(
pooler_config, pooler_config,
pooling_type=PoolingType.CLS, pooling_type=PoolingType.CLS,
......
...@@ -23,7 +23,7 @@ from .blip import (BlipVisionModel, dummy_image_for_blip, ...@@ -23,7 +23,7 @@ from .blip import (BlipVisionModel, dummy_image_for_blip,
get_max_blip_image_tokens) get_max_blip_image_tokens)
from .interfaces import SupportsMultiModal, SupportsPP from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, init_vllm_registered_model, from .utils import (AutoWeightsLoader, init_vllm_registered_model,
merge_multimodal_embeddings) maybe_prefix, merge_multimodal_embeddings)
# We use this internally as placeholders since there is no image token # We use this internally as placeholders since there is no image token
# defined on the HuggingFace repo # defined on the HuggingFace repo
...@@ -483,11 +483,7 @@ def input_processor_for_blip2(ctx: InputContext, inputs: DecoderOnlyInputs): ...@@ -483,11 +483,7 @@ def input_processor_for_blip2(ctx: InputContext, inputs: DecoderOnlyInputs):
@INPUT_REGISTRY.register_input_processor(input_processor_for_blip2) @INPUT_REGISTRY.register_input_processor(input_processor_for_blip2)
class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
...@@ -517,7 +513,7 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -517,7 +513,7 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
self.language_model = init_vllm_registered_model( self.language_model = init_vllm_registered_model(
config.text_config, config.text_config,
vllm_config=vllm_config, vllm_config=vllm_config,
prefix="language_model") prefix=maybe_prefix(prefix, "language_model"))
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors) self.language_model.make_empty_intermediate_tensors)
......
...@@ -42,7 +42,8 @@ from vllm.sequence import IntermediateTensors ...@@ -42,7 +42,8 @@ from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter, from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers) make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
...@@ -221,14 +222,13 @@ class BloomBlock(nn.Module): ...@@ -221,14 +222,13 @@ class BloomBlock(nn.Module):
@support_torch_compile @support_torch_compile
class BloomModel(nn.Module): class BloomModel(nn.Module):
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
config: BloomConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
# Embedding + LN Embedding # Embedding + LN Embedding
...@@ -288,11 +288,12 @@ class BloomForCausalLM(nn.Module, SupportsPP): ...@@ -288,11 +288,12 @@ class BloomForCausalLM(nn.Module, SupportsPP):
): ):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.transformer = BloomModel(config, cache_config, quant_config) self.transformer = BloomModel(vllm_config=vllm_config,
prefix=maybe_prefix(
prefix, "transformer"))
if self.config.tie_word_embeddings: if self.config.tie_word_embeddings:
self.lm_head = self.transformer.word_embeddings self.lm_head = self.transformer.word_embeddings
else: else:
......
...@@ -37,7 +37,8 @@ from vllm.utils import print_warning_once ...@@ -37,7 +37,8 @@ from vllm.utils import print_warning_once
from .interfaces import SupportsMultiModal, SupportsPP from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (is_pp_missing_parameter, from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers) make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
# These configs are not part of the model config but the preprocessor # These configs are not part of the model config but the preprocessor
# and processor files, so we hardcode them in the model file for now. # and processor files, so we hardcode them in the model file for now.
...@@ -831,14 +832,13 @@ class ChameleonImageVocabularyMapping: ...@@ -831,14 +832,13 @@ class ChameleonImageVocabularyMapping:
class ChameleonModel(nn.Module): class ChameleonModel(nn.Module):
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
config: ChameleonConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config self.config = config
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
...@@ -924,19 +924,14 @@ class ChameleonModel(nn.Module): ...@@ -924,19 +924,14 @@ class ChameleonModel(nn.Module):
class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal, class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP): SupportsPP):
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config multimodal_config = vllm_config.model_config.multimodal_config
self.config = config self.config = config
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
self.model = ChameleonModel(config, cache_config, quant_config) self.model = ChameleonModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.unpadded_vocab_size = config.vocab_size self.unpadded_vocab_size = config.vocab_size
self.lm_head = ParallelLMHead( self.lm_head = ParallelLMHead(
self.unpadded_vocab_size, self.unpadded_vocab_size,
......
...@@ -39,7 +39,8 @@ from vllm.transformers_utils.configs import ChatGLMConfig ...@@ -39,7 +39,8 @@ from vllm.transformers_utils.configs import ChatGLMConfig
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
from .utils import (is_pp_missing_parameter, from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers) make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -481,14 +482,13 @@ class GLMTransformer(nn.Module): ...@@ -481,14 +482,13 @@ class GLMTransformer(nn.Module):
class ChatGLMModel(nn.Module): class ChatGLMModel(nn.Module):
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
config: ChatGLMConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config self.config = config
self.embedding = VocabParallelEmbedding(config.padded_vocab_size, self.embedding = VocabParallelEmbedding(config.padded_vocab_size,
...@@ -600,7 +600,6 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, ...@@ -600,7 +600,6 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
): ):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config lora_config = vllm_config.lora_config
multimodal_config = vllm_config.model_config.multimodal_config multimodal_config = vllm_config.model_config.multimodal_config
...@@ -611,7 +610,9 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, ...@@ -611,7 +610,9 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
self.quant_config = quant_config self.quant_config = quant_config
self.max_position_embeddings = getattr(config, "max_sequence_length", self.max_position_embeddings = getattr(config, "max_sequence_length",
8192) 8192)
self.transformer = ChatGLMModel(config, cache_config, quant_config) self.transformer = ChatGLMModel(vllm_config=vllm_config,
prefix=maybe_prefix(
prefix, "transformer"))
if self.config.tie_word_embeddings: if self.config.tie_word_embeddings:
self.transformer.output_layer.weight = ( self.transformer.output_layer.weight = (
self.transformer.embedding.weight) self.transformer.embedding.weight)
......
...@@ -28,7 +28,7 @@ from transformers import CohereConfig ...@@ -28,7 +28,7 @@ from transformers import CohereConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
...@@ -49,7 +49,8 @@ from vllm.sequence import IntermediateTensors ...@@ -49,7 +49,8 @@ from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter, from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers) make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
@torch.compile @torch.compile
...@@ -253,15 +254,14 @@ class CohereDecoderLayer(nn.Module): ...@@ -253,15 +254,14 @@ class CohereDecoderLayer(nn.Module):
@support_torch_compile @support_torch_compile
class CohereModel(nn.Module): class CohereModel(nn.Module):
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
config: CohereConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
prefix: str = "",
):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config self.config = config
lora_vocab = (lora_config.lora_extra_vocab_size * lora_vocab = (lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0 (lora_config.max_loras or 1)) if lora_config else 0
...@@ -332,14 +332,9 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -332,14 +332,9 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
embedding_modules = {"embed_tokens": "input_embeddings"} embedding_modules = {"embed_tokens": "input_embeddings"}
embedding_padding_modules = [] embedding_padding_modules = []
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config lora_config = vllm_config.lora_config
self.config = config self.config = config
...@@ -353,10 +348,8 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -353,10 +348,8 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size, config.vocab_size,
scale=config.logit_scale) scale=config.logit_scale)
self.model = CohereModel(config, self.model = CohereModel(vllm_config=vllm_config,
cache_config, prefix=maybe_prefix(prefix, "model"))
quant_config,
lora_config=lora_config)
self.sampler = get_sampler() self.sampler = get_sampler()
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.model.make_empty_intermediate_tensors)
......
...@@ -25,7 +25,8 @@ from vllm.transformers_utils.configs.dbrx import DbrxConfig ...@@ -25,7 +25,8 @@ from vllm.transformers_utils.configs.dbrx import DbrxConfig
from .interfaces import SupportsPP from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter, from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers) make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class DbrxRouter(nn.Module): class DbrxRouter(nn.Module):
...@@ -294,14 +295,13 @@ class DbrxBlock(nn.Module): ...@@ -294,14 +295,13 @@ class DbrxBlock(nn.Module):
class DbrxModel(nn.Module): class DbrxModel(nn.Module):
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
config: DbrxConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.wte = VocabParallelEmbedding( self.wte = VocabParallelEmbedding(
config.vocab_size, config.vocab_size,
config.d_model, config.d_model,
...@@ -357,7 +357,6 @@ class DbrxForCausalLM(nn.Module, SupportsPP): ...@@ -357,7 +357,6 @@ class DbrxForCausalLM(nn.Module, SupportsPP):
): ):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
self.config = config self.config = config
if config.tie_word_embeddings: if config.tie_word_embeddings:
...@@ -365,7 +364,9 @@ class DbrxForCausalLM(nn.Module, SupportsPP): ...@@ -365,7 +364,9 @@ class DbrxForCausalLM(nn.Module, SupportsPP):
"tie_word_embeddings is not supported for Dbrx models.") "tie_word_embeddings is not supported for Dbrx models.")
self.quant_config = quant_config self.quant_config = quant_config
self.unpadded_vocab_size = config.vocab_size self.unpadded_vocab_size = config.vocab_size
self.transformer = DbrxModel(config, cache_config, quant_config) self.transformer = DbrxModel(vllm_config=vllm_config,
prefix=maybe_prefix(
prefix, "transformer"))
self.lm_head = ParallelLMHead( self.lm_head = ParallelLMHead(
config.vocab_size, config.vocab_size,
config.d_model, config.d_model,
......
...@@ -51,11 +51,7 @@ class DeciLMForCausalLM(LlamaForCausalLM): ...@@ -51,11 +51,7 @@ class DeciLMForCausalLM(LlamaForCausalLM):
instead. instead.
""" """
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
config.num_key_value_heads = max(config.num_key_value_heads_per_layer) config.num_key_value_heads = max(config.num_key_value_heads_per_layer)
delattr(config, "num_key_value_heads_per_layer") delattr(config, "num_key_value_heads_per_layer")
......
...@@ -50,7 +50,8 @@ from vllm.sequence import IntermediateTensors ...@@ -50,7 +50,8 @@ from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter, from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers) make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class DeepseekMLP(nn.Module): class DeepseekMLP(nn.Module):
...@@ -326,14 +327,13 @@ class DeepseekModel(nn.Module): ...@@ -326,14 +327,13 @@ class DeepseekModel(nn.Module):
fall_back_to_pt_during_load = False fall_back_to_pt_during_load = False
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
...@@ -383,18 +383,14 @@ class DeepseekModel(nn.Module): ...@@ -383,18 +383,14 @@ class DeepseekModel(nn.Module):
class DeepseekForCausalLM(nn.Module, SupportsPP): class DeepseekForCausalLM(nn.Module, SupportsPP):
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.model = DeepseekModel(config, cache_config, quant_config) self.model = DeepseekModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.lm_head = ParallelLMHead(config.vocab_size, self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size, config.hidden_size,
quant_config=quant_config) quant_config=quant_config)
......
...@@ -51,7 +51,8 @@ from vllm.sequence import IntermediateTensors ...@@ -51,7 +51,8 @@ from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP from .interfaces import SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter, from .utils import (PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers) make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class DeepseekV2MLP(nn.Module): class DeepseekV2MLP(nn.Module):
...@@ -408,14 +409,13 @@ class DeepseekV2Model(nn.Module): ...@@ -408,14 +409,13 @@ class DeepseekV2Model(nn.Module):
fall_back_to_pt_during_load = False fall_back_to_pt_during_load = False
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
...@@ -479,21 +479,14 @@ class DeepseekV2Model(nn.Module): ...@@ -479,21 +479,14 @@ class DeepseekV2Model(nn.Module):
class DeepseekV2ForCausalLM(nn.Module, SupportsPP): class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.model = DeepseekV2Model(config, self.model = DeepseekV2Model(vllm_config=vllm_config,
cache_config, prefix=maybe_prefix(prefix, "model"))
quant_config,
prefix="model")
self.lm_head = ParallelLMHead(config.vocab_size, self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size, config.hidden_size,
quant_config=quant_config) quant_config=quant_config)
......
...@@ -14,6 +14,8 @@ from vllm.model_executor.models import ModelRegistry ...@@ -14,6 +14,8 @@ from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .utils import maybe_prefix
class EAGLE(nn.Module): class EAGLE(nn.Module):
"""This class implements the EAGLE draft model from the paper: https://arxiv.org/pdf/2401.15077 """This class implements the EAGLE draft model from the paper: https://arxiv.org/pdf/2401.15077
...@@ -42,7 +44,8 @@ class EAGLE(nn.Module): ...@@ -42,7 +44,8 @@ class EAGLE(nn.Module):
architectures = getattr(self.config.model, "architectures", []) architectures = getattr(self.config.model, "architectures", [])
model_cls, _ = ModelRegistry.resolve_model_cls(architectures) model_cls, _ = ModelRegistry.resolve_model_cls(architectures)
self.model = model_cls(vllm_config, prefix) self.model = model_cls(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.fc = nn.Linear(config.model.hidden_size * 2, self.fc = nn.Linear(config.model.hidden_size * 2,
config.model.hidden_size, config.model.hidden_size,
bias=getattr(self.config, "eagle_fc_bias", False)) bias=getattr(self.config, "eagle_fc_bias", False))
......
...@@ -29,7 +29,7 @@ from torch import nn ...@@ -29,7 +29,7 @@ from torch import nn
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
...@@ -54,7 +54,8 @@ from vllm.transformers_utils.configs.exaone import ExaoneConfig ...@@ -54,7 +54,8 @@ from vllm.transformers_utils.configs.exaone import ExaoneConfig
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter, from .utils import (PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers) make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class ExaoneGatedMLP(nn.Module): class ExaoneGatedMLP(nn.Module):
...@@ -314,15 +315,14 @@ class ExaoneDecoderLayer(nn.Module): ...@@ -314,15 +315,14 @@ class ExaoneDecoderLayer(nn.Module):
@support_torch_compile @support_torch_compile
class ExaoneModel(nn.Module): class ExaoneModel(nn.Module):
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
config: ExaoneConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config self.config = config
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
lora_vocab = ((lora_config.lora_extra_vocab_size * lora_vocab = ((lora_config.lora_extra_vocab_size *
...@@ -438,14 +438,9 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -438,14 +438,9 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"c_fc_1": ("gate_up_proj", 1), "c_fc_1": ("gate_up_proj", 1),
} }
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config lora_config = vllm_config.lora_config
...@@ -453,11 +448,8 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -453,11 +448,8 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.lora_config = lora_config self.lora_config = lora_config
self.transformer = ExaoneModel( self.transformer = ExaoneModel(
config, vllm_config=vllm_config,
cache_config, prefix=maybe_prefix(prefix, "model"),
quant_config,
lora_config=lora_config,
prefix="model",
) )
if get_pp_group().is_last_rank: if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size self.unpadded_vocab_size = config.vocab_size
......
...@@ -48,7 +48,8 @@ from vllm.transformers_utils.configs import RWConfig ...@@ -48,7 +48,8 @@ from vllm.transformers_utils.configs import RWConfig
from .interfaces import SupportsPP from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter, from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers) make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
FalconConfig = Union[HF_FalconConfig, RWConfig] FalconConfig = Union[HF_FalconConfig, RWConfig]
...@@ -332,14 +333,13 @@ class FalconDecoderLayer(nn.Module): ...@@ -332,14 +333,13 @@ class FalconDecoderLayer(nn.Module):
@support_torch_compile @support_torch_compile
class FalconModel(nn.Module): class FalconModel(nn.Module):
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
config: FalconConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config self.config = config
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads self.num_heads = config.num_attention_heads
...@@ -408,11 +408,12 @@ class FalconForCausalLM(nn.Module, SupportsPP): ...@@ -408,11 +408,12 @@ class FalconForCausalLM(nn.Module, SupportsPP):
): ):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.transformer = FalconModel(config, cache_config, quant_config) self.transformer = FalconModel(vllm_config=vllm_config,
prefix=maybe_prefix(
prefix, "transformer"))
# only Falcon-11B doesn't share lm_head weight with word embeddings # only Falcon-11B doesn't share lm_head weight with word embeddings
# and previous Falcon model doesn't have tie_word_embeddings config # and previous Falcon model doesn't have tie_word_embeddings config
# so we set tie_word_embeddings to True by default # so we set tie_word_embeddings to True by default
......
...@@ -3,13 +3,10 @@ from typing import Iterable, List, Optional, Tuple ...@@ -3,13 +3,10 @@ from typing import Iterable, List, Optional, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
from transformers import PretrainedConfig
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, VllmConfig from vllm.config import VllmConfig
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.bart import (BartDecoder, BartEncoder, from vllm.model_executor.models.bart import (BartDecoder, BartEncoder,
...@@ -23,11 +20,13 @@ from .utils import AutoWeightsLoader ...@@ -23,11 +20,13 @@ from .utils import AutoWeightsLoader
class Florence2LanguageModel(nn.Module): class Florence2LanguageModel(nn.Module):
def __init__(self, def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config self.config = config
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
...@@ -93,15 +92,14 @@ class Florence2LanguageModel(nn.Module): ...@@ -93,15 +92,14 @@ class Florence2LanguageModel(nn.Module):
class Florence2LanguageForConditionalGeneration(nn.Module): class Florence2LanguageForConditionalGeneration(nn.Module):
def __init__(self, def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
self.config = config self.config = config
self.model = Florence2LanguageModel(config, self.model = Florence2LanguageModel(vllm_config=vllm_config,
cache_config=cache_config, prefix=prefix)
quant_config=quant_config)
embed_scale = math.sqrt( embed_scale = math.sqrt(
config.d_model) if config.scale_embedding else 1.0 config.d_model) if config.scale_embedding else 1.0
...@@ -189,17 +187,15 @@ class Florence2LanguageForConditionalGeneration(nn.Module): ...@@ -189,17 +187,15 @@ class Florence2LanguageForConditionalGeneration(nn.Module):
class Florence2ForConditionalGeneration(nn.Module): class Florence2ForConditionalGeneration(nn.Module):
def __init__(self, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
# TODO(Isotr0py): Add vision backbone # TODO(Isotr0py): Add vision backbone
self.language_model = Florence2LanguageForConditionalGeneration( self.language_model = Florence2LanguageForConditionalGeneration(
config=config.text_config, vllm_config=vllm_config.with_hf_config(config.text_config),
cache_config=cache_config, prefix=prefix,
quant_config=quant_config) )
@property @property
def sampler(self): def sampler(self):
......
...@@ -258,14 +258,13 @@ class GemmaDecoderLayer(nn.Module): ...@@ -258,14 +258,13 @@ class GemmaDecoderLayer(nn.Module):
@support_torch_compile @support_torch_compile
class GemmaModel(nn.Module): class GemmaModel(nn.Module):
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
config: GemmaConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config self.config = config
self.embed_tokens = VocabParallelEmbedding( self.embed_tokens = VocabParallelEmbedding(
...@@ -372,14 +371,9 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -372,14 +371,9 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
embedding_modules = {} embedding_modules = {}
embedding_padding_modules = [] embedding_padding_modules = []
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config lora_config = vllm_config.lora_config
...@@ -389,9 +383,7 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -389,9 +383,7 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.lora_config = lora_config self.lora_config = lora_config
self.quant_config = quant_config self.quant_config = quant_config
self.model = GemmaModel(config, self.model = GemmaModel(vllm_config=vllm_config,
cache_config,
quant_config,
prefix=maybe_prefix(prefix, "model")) prefix=maybe_prefix(prefix, "model"))
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = get_sampler() self.sampler = get_sampler()
......
...@@ -43,7 +43,8 @@ from vllm.sequence import IntermediateTensors, PoolerOutput ...@@ -43,7 +43,8 @@ from vllm.sequence import IntermediateTensors, PoolerOutput
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, is_pp_missing_parameter, from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers) make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -243,11 +244,7 @@ class Gemma2DecoderLayer(nn.Module): ...@@ -243,11 +244,7 @@ class Gemma2DecoderLayer(nn.Module):
@support_torch_compile @support_torch_compile
class Gemma2Model(nn.Module): class Gemma2Model(nn.Module):
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config cache_config = vllm_config.cache_config
...@@ -399,13 +396,8 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -399,13 +396,8 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"up_proj": ("gate_up_proj", 1), "up_proj": ("gate_up_proj", 1),
} }
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config lora_config = vllm_config.lora_config
del lora_config # Unused. del lora_config # Unused.
...@@ -414,7 +406,8 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -414,7 +406,8 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
# currently all existing Gemma models have `tie_word_embeddings` enabled # currently all existing Gemma models have `tie_word_embeddings` enabled
assert config.tie_word_embeddings assert config.tie_word_embeddings
self.quant_config = quant_config self.quant_config = quant_config
self.model = Gemma2Model(config, cache_config, quant_config) self.model = Gemma2Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.logits_processor = LogitsProcessor( self.logits_processor = LogitsProcessor(
config.vocab_size, soft_cap=config.final_logit_softcapping) config.vocab_size, soft_cap=config.final_logit_softcapping)
self.sampler = get_sampler() self.sampler = get_sampler()
...@@ -471,14 +464,11 @@ class Gemma2EmbeddingModel(nn.Module, SupportsPP): ...@@ -471,14 +464,11 @@ class Gemma2EmbeddingModel(nn.Module, SupportsPP):
_pooler: An instance of Pooler used for pooling operations. _pooler: An instance of Pooler used for pooling operations.
""" """
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
super().__init__() super().__init__()
self.model = Gemma2Model(vllm_config, prefix) self.model = Gemma2Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self._pooler = Pooler.from_config_with_defaults( self._pooler = Pooler.from_config_with_defaults(
vllm_config.model_config.pooler_config, vllm_config.model_config.pooler_config,
pooling_type=PoolingType.LAST, pooling_type=PoolingType.LAST,
......
...@@ -42,7 +42,8 @@ from vllm.sequence import IntermediateTensors ...@@ -42,7 +42,8 @@ from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter, from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers) make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
class GPT2Attention(nn.Module): class GPT2Attention(nn.Module):
...@@ -184,14 +185,13 @@ class GPT2Block(nn.Module): ...@@ -184,14 +185,13 @@ class GPT2Block(nn.Module):
@support_torch_compile @support_torch_compile
class GPT2Model(nn.Module): class GPT2Model(nn.Module):
def __init__( def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self,
config: GPT2Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config self.config = config
assert not config.add_cross_attention assert not config.add_cross_attention
assert not config.scale_attn_by_inverse_layer_idx assert not config.scale_attn_by_inverse_layer_idx
...@@ -247,14 +247,12 @@ class GPT2LMHeadModel(nn.Module, SupportsPP): ...@@ -247,14 +247,12 @@ class GPT2LMHeadModel(nn.Module, SupportsPP):
): ):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.transformer = GPT2Model(config, self.transformer = GPT2Model(vllm_config=vllm_config,
cache_config, prefix=maybe_prefix(
quant_config, prefix, "transformer"))
prefix="transformer")
if self.config.tie_word_embeddings: if self.config.tie_word_embeddings:
self.lm_head = self.transformer.wte self.lm_head = self.transformer.wte
else: else:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment