Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6b16fff0
Unverified
Commit
6b16fff0
authored
Dec 23, 2025
by
Jee Jee Li
Committed by
GitHub
Dec 23, 2025
Browse files
[Bugfix] Fix Jais2ForCausalLM (#31198)
Signed-off-by:
Jee Jee Li
<
pandaleefree@gmail.com
>
parent
f1c2c201
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
4 additions
and
25 deletions
+4
-25
vllm/model_executor/models/jais2.py
vllm/model_executor/models/jais2.py
+4
-25
No files found.
vllm/model_executor/models/jais2.py
View file @
6b16fff0
...
@@ -48,7 +48,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
...
@@ -48,7 +48,6 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
ParallelLMHead
,
VocabParallelEmbedding
,
VocabParallelEmbedding
,
)
)
...
@@ -167,7 +166,6 @@ class Jais2Attention(nn.Module):
...
@@ -167,7 +166,6 @@ class Jais2Attention(nn.Module):
self
.
rotary_emb
=
get_rope
(
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
self
.
head_dim
,
rotary_dim
=
self
.
head_dim
,
max_position
=
max_position_embeddings
,
max_position
=
max_position_embeddings
,
rope_parameters
=
getattr
(
config
,
"rope_parameters"
,
None
),
rope_parameters
=
getattr
(
config
,
"rope_parameters"
,
None
),
is_neox_style
=
is_neox_style
,
is_neox_style
=
is_neox_style
,
...
@@ -304,17 +302,12 @@ class Jais2Model(nn.Module):
...
@@ -304,17 +302,12 @@ class Jais2Model(nn.Module):
config
=
vllm_config
.
model_config
.
hf_config
config
=
vllm_config
.
model_config
.
hf_config
quant_config
=
vllm_config
.
quant_config
quant_config
=
vllm_config
.
quant_config
lora_config
=
vllm_config
.
lora_config
self
.
config
=
config
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
self
.
padding_idx
=
config
.
pad_token_id
self
.
padding_idx
=
config
.
pad_token_id
lora_vocab
=
(
(
lora_config
.
lora_extra_vocab_size
*
(
lora_config
.
max_loras
or
1
))
self
.
vocab_size
=
config
.
vocab_size
if
lora_config
else
0
)
self
.
vocab_size
=
config
.
vocab_size
+
lora_vocab
self
.
org_vocab_size
=
config
.
vocab_size
self
.
org_vocab_size
=
config
.
vocab_size
if
get_pp_group
().
is_first_rank
or
(
if
get_pp_group
().
is_first_rank
or
(
config
.
tie_word_embeddings
and
get_pp_group
().
is_last_rank
config
.
tie_word_embeddings
and
get_pp_group
().
is_last_rank
...
@@ -456,29 +449,15 @@ class Jais2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -456,29 +449,15 @@ class Jais2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
super
().
__init__
()
super
().
__init__
()
config
=
vllm_config
.
model_config
.
hf_config
config
=
vllm_config
.
model_config
.
hf_config
quant_config
=
vllm_config
.
quant_config
quant_config
=
vllm_config
.
quant_config
lora_config
=
vllm_config
.
lora_config
self
.
config
=
config
self
.
config
=
config
self
.
lora_config
=
lora_config
self
.
model
=
self
.
_init_model
(
self
.
model
=
self
.
_init_model
(
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
)
vllm_config
=
vllm_config
,
prefix
=
maybe_prefix
(
prefix
,
"model"
)
)
)
if
get_pp_group
().
is_last_rank
:
if
get_pp_group
().
is_last_rank
:
self
.
unpadded_vocab_size
=
config
.
vocab_size
if
lora_config
:
self
.
unpadded_vocab_size
+=
lora_config
.
lora_extra_vocab_size
self
.
lm_head
=
ParallelLMHead
(
self
.
lm_head
=
ParallelLMHead
(
self
.
unpadded_
vocab_size
,
config
.
vocab_size
,
config
.
hidden_size
,
config
.
hidden_size
,
org_num_embeddings
=
config
.
vocab_size
,
padding_size
=
(
DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if
not
lora_config
else
lora_config
.
lora_vocab_padding_size
),
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"lm_head"
),
prefix
=
maybe_prefix
(
prefix
,
"lm_head"
),
)
)
...
@@ -487,7 +466,7 @@ class Jais2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -487,7 +466,7 @@ class Jais2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
logit_scale
=
getattr
(
config
,
"logit_scale"
,
1.0
)
logit_scale
=
getattr
(
config
,
"logit_scale"
,
1.0
)
self
.
logits_processor
=
LogitsProcessor
(
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
config
.
vocab_size
,
logit_scale
config
.
vocab_size
,
scale
=
logit_scale
)
)
else
:
else
:
self
.
lm_head
=
PPMissingLayer
()
self
.
lm_head
=
PPMissingLayer
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment