Unverified Commit 91373a0d authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Fix `head_dim` not existing in all model configs (Transformers backend) (#14141)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 848a6438
...@@ -25,7 +25,6 @@ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS ...@@ -25,7 +25,6 @@ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from vllm.attention import Attention from vllm.attention import Attention
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.utils import divide
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
ReplicatedLinear, ReplicatedLinear,
...@@ -128,10 +127,12 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA): ...@@ -128,10 +127,12 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA):
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config cache_config = vllm_config.cache_config
model_config = vllm_config.model_config
parallel_config = vllm_config.parallel_config
self.config = config self.config = config
self.vocab_size = config.vocab_size self.vocab_size = model_config.get_vocab_size()
self.unpadded_vocab_size = config.vocab_size self.unpadded_vocab_size = model_config.get_vocab_size()
self.model: PreTrainedModel = AutoModel.from_config( self.model: PreTrainedModel = AutoModel.from_config(
self.config, self.config,
...@@ -145,15 +146,17 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA): ...@@ -145,15 +146,17 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA):
self.apply_base_model_tp_plan(self.model) self.apply_base_model_tp_plan(self.model)
# Attention modifications (assumes 1 attention op per hidden layer) # Attention modifications (assumes 1 attention op per hidden layer)
tp_size = get_tensor_model_parallel_world_size() num_heads = model_config.get_num_attention_heads(parallel_config)
head_size = model_config.get_head_size()
num_kv_heads = model_config.get_num_kv_heads(parallel_config)
self.attention_instances = [ self.attention_instances = [
Attention( Attention(
num_heads=divide(config.num_attention_heads, tp_size), num_heads=num_heads,
head_size=config.head_dim, head_size=head_size,
# NOTE: We use Llama scale as default, if it's set by # NOTE: We use Llama scale as default, if it's set by
# Transformers, it's updated in vllm_flash_attention_forward # Transformers, it's updated in vllm_flash_attention_forward
scale=config.head_dim**-0.5, scale=head_size**-0.5,
num_kv_heads=divide(config.num_key_value_heads, tp_size), num_kv_heads=num_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=self.quant_config, quant_config=self.quant_config,
prefix=f"{i}.attn") for i in range(config.num_hidden_layers) prefix=f"{i}.attn") for i in range(config.num_hidden_layers)
...@@ -163,7 +166,7 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA): ...@@ -163,7 +166,7 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA):
self.replace_vocab_embed_class(self.model) self.replace_vocab_embed_class(self.model)
# ForCausalLM modifications # ForCausalLM modifications
self.lm_head = ParallelLMHead(config.vocab_size, self.lm_head = ParallelLMHead(self.vocab_size,
config.hidden_size, config.hidden_size,
quant_config=self.quant_config, quant_config=self.quant_config,
prefix=maybe_prefix(prefix, "lm_head")) prefix=maybe_prefix(prefix, "lm_head"))
...@@ -172,7 +175,7 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA): ...@@ -172,7 +175,7 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA):
logit_scale = getattr(config, "logit_scale", 1.0) logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size, logit_scale) self.vocab_size, logit_scale)
self.sampler = get_sampler() self.sampler = get_sampler()
def apply_base_model_tp_plan(self, module: nn.Module, prefix: str = ""): def apply_base_model_tp_plan(self, module: nn.Module, prefix: str = ""):
...@@ -203,12 +206,12 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA): ...@@ -203,12 +206,12 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA):
new_module = VocabParallelEmbedding( new_module = VocabParallelEmbedding(
self.vocab_size, self.vocab_size,
self.config.hidden_size, self.config.hidden_size,
org_num_embeddings=self.config.vocab_size, org_num_embeddings=self.vocab_size,
quant_config=None, quant_config=None,
) )
log_replacement("input embedding", self.model.get_input_embeddings(), log_replacement("input embedding", self.model.get_input_embeddings(),
new_module) new_module)
self.model.set_input_embeddings(new_module) module.set_input_embeddings(new_module)
def forward( def forward(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment