Unverified Commit 6b16fff0 authored by Jee Jee Li's avatar Jee Jee Li Committed by GitHub
Browse files

[Bugfix] Fix Jais2ForCausalLM (#31198)


Signed-off-by: default avatarJee Jee Li <pandaleefree@gmail.com>
parent f1c2c201
...@@ -48,7 +48,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor ...@@ -48,7 +48,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE,
ParallelLMHead, ParallelLMHead,
VocabParallelEmbedding, VocabParallelEmbedding,
) )
...@@ -167,7 +166,6 @@ class Jais2Attention(nn.Module): ...@@ -167,7 +166,6 @@ class Jais2Attention(nn.Module):
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
rope_parameters=getattr(config, "rope_parameters", None), rope_parameters=getattr(config, "rope_parameters", None),
is_neox_style=is_neox_style, is_neox_style=is_neox_style,
...@@ -304,17 +302,12 @@ class Jais2Model(nn.Module): ...@@ -304,17 +302,12 @@ class Jais2Model(nn.Module):
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
lora_vocab = (
(lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) self.vocab_size = config.vocab_size
if lora_config
else 0
)
self.vocab_size = config.vocab_size + lora_vocab
self.org_vocab_size = config.vocab_size self.org_vocab_size = config.vocab_size
if get_pp_group().is_first_rank or ( if get_pp_group().is_first_rank or (
config.tie_word_embeddings and get_pp_group().is_last_rank config.tie_word_embeddings and get_pp_group().is_last_rank
...@@ -456,29 +449,15 @@ class Jais2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -456,29 +449,15 @@ class Jais2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
super().__init__() super().__init__()
config = vllm_config.model_config.hf_config config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config self.config = config
self.lora_config = lora_config
self.model = self._init_model( self.model = self._init_model(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
) )
if get_pp_group().is_last_rank: if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = ParallelLMHead( self.lm_head = ParallelLMHead(
self.unpadded_vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=(
DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config
else lora_config.lora_vocab_padding_size
),
quant_config=quant_config, quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"), prefix=maybe_prefix(prefix, "lm_head"),
) )
...@@ -487,7 +466,7 @@ class Jais2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -487,7 +466,7 @@ class Jais2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
logit_scale = getattr(config, "logit_scale", 1.0) logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor( self.logits_processor = LogitsProcessor(
self.unpadded_vocab_size, config.vocab_size, logit_scale config.vocab_size, scale=logit_scale
) )
else: else:
self.lm_head = PPMissingLayer() self.lm_head = PPMissingLayer()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment