Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3e9ce609
"vscode:/vscode.git/clone" did not exist on "81f09cfd80a5a2e1572ee79facd60bb823923367"
Unverified
Commit
3e9ce609
authored
May 28, 2025
by
wang.yuqi
Committed by
GitHub
May 27, 2025
Browse files
[Bugfix] Fix nomic max_model_len (#18755)
parent
794ae1f5
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
242 additions
and
3 deletions
+242
-3
examples/offline_inference/context_extension.py
examples/offline_inference/context_extension.py
+46
-0
tests/models/language/pooling/test_nomic_max_model_len.py
tests/models/language/pooling/test_nomic_max_model_len.py
+130
-0
vllm/config.py
vllm/config.py
+14
-0
vllm/model_executor/models/bert_with_rope.py
vllm/model_executor/models/bert_with_rope.py
+52
-3
No files found.
examples/offline_inference/context_extension.py
0 → 100644
View file @
3e9ce609
# SPDX-License-Identifier: Apache-2.0
from
vllm
import
LLM
,
SamplingParams
rope_theta
=
1000000
original_max_position_embeddings
=
32768
factor
=
4.0
# Use yarn to extend context
hf_overrides
=
{
"rope_theta"
:
rope_theta
,
"rope_scaling"
:
{
"rope_type"
:
"yarn"
,
"factor"
:
factor
,
"original_max_position_embeddings"
:
original_max_position_embeddings
,
},
"max_model_len"
:
int
(
original_max_position_embeddings
*
factor
),
}
llm
=
LLM
(
model
=
"Qwen/Qwen3-0.6B"
,
hf_overrides
=
hf_overrides
)
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
top_p
=
0.95
,
max_tokens
=
128
,
)
conversation
=
[
{
"role"
:
"system"
,
"content"
:
"You are a helpful assistant"
},
{
"role"
:
"user"
,
"content"
:
"Hello"
},
{
"role"
:
"assistant"
,
"content"
:
"Hello! How can I assist you today?"
},
]
outputs
=
llm
.
chat
(
conversation
,
sampling_params
,
use_tqdm
=
False
)
def
print_outputs
(
outputs
):
print
(
"
\n
Generated Outputs:
\n
"
+
"-"
*
80
)
for
output
in
outputs
:
prompt
=
output
.
prompt
generated_text
=
output
.
outputs
[
0
].
text
print
(
f
"Prompt:
{
prompt
!
r
}
\n
"
)
print
(
f
"Generated text:
{
generated_text
!
r
}
"
)
print
(
"-"
*
80
)
print_outputs
(
outputs
)
tests/models/language/pooling/test_nomic_max_model_len.py
0 → 100644
View file @
3e9ce609
# SPDX-License-Identifier: Apache-2.0
# ruff: noqa: SIM117
import
pytest
from
...utils
import
EmbedModelInfo
MODELS
=
[
EmbedModelInfo
(
"nomic-ai/nomic-embed-text-v1"
),
#EmbedModelInfo("nomic-ai/nomic-embed-text-v1.5"),
#EmbedModelInfo("nomic-ai/CodeRankEmbed"),
EmbedModelInfo
(
"nomic-ai/nomic-embed-text-v2-moe"
),
#EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-long"),
]
rope_theta
=
1000
factor
=
4.0
original_max_position_embeddings
=
2048
max_model_len
=
int
(
original_max_position_embeddings
*
factor
)
@
pytest
.
mark
.
parametrize
(
"model_info"
,
MODELS
)
def
test_default
(
model_info
,
vllm_runner
):
with
vllm_runner
(
model_info
.
name
,
task
=
"embed"
,
max_model_len
=
None
)
as
vllm_model
:
model_config
=
vllm_model
.
model
.
llm_engine
.
model_config
if
model_info
.
name
==
"nomic-ai/nomic-embed-text-v2-moe"
:
# For nomic-embed-text-v2-moe the length is set to 512
# by sentence_bert_config.json.
assert
model_config
.
max_model_len
==
512
else
:
assert
(
model_config
.
max_model_len
==
original_max_position_embeddings
)
@
pytest
.
mark
.
parametrize
(
"model_info"
,
MODELS
)
def
test_set_max_model_len_legal
(
model_info
,
vllm_runner
):
# set max_model_len <= 512
with
vllm_runner
(
model_info
.
name
,
task
=
"embed"
,
max_model_len
=
256
)
as
vllm_model
:
model_config
=
vllm_model
.
model
.
llm_engine
.
model_config
assert
model_config
.
max_model_len
==
256
# set 512 < max_model_len <= 2048
if
model_info
.
name
==
"nomic-ai/nomic-embed-text-v2-moe"
:
# For nomic-embed-text-v2-moe the length is set to 512
# by sentence_bert_config.json.
with
pytest
.
raises
(
ValueError
):
with
vllm_runner
(
model_info
.
name
,
task
=
"embed"
,
max_model_len
=
1024
):
pass
else
:
with
vllm_runner
(
model_info
.
name
,
task
=
"embed"
,
max_model_len
=
1024
)
as
vllm_model
:
model_config
=
vllm_model
.
model
.
llm_engine
.
model_config
assert
model_config
.
max_model_len
==
1024
@
pytest
.
mark
.
parametrize
(
"model_info"
,
MODELS
)
def
test_set_max_model_len_illegal
(
model_info
,
vllm_runner
):
# set max_model_len > 2048
with
pytest
.
raises
(
ValueError
):
with
vllm_runner
(
model_info
.
name
,
task
=
"embed"
,
max_model_len
=
4096
):
pass
# set max_model_len > 2048 by hf_overrides
hf_overrides
=
{
"max_model_len"
:
4096
}
with
pytest
.
raises
(
ValueError
):
with
vllm_runner
(
model_info
.
name
,
task
=
"embed"
,
max_model_len
=
None
,
hf_overrides
=
hf_overrides
):
pass
@
pytest
.
mark
.
parametrize
(
"model_info"
,
MODELS
)
def
test_use_rope_scaling_legal
(
model_info
,
vllm_runner
):
hf_overrides
=
{
"rope_theta"
:
rope_theta
,
"rope_scaling"
:
{
"rope_type"
:
"yarn"
,
"factor"
:
factor
,
"original_max_position_embeddings"
:
original_max_position_embeddings
},
"max_model_len"
:
max_model_len
}
with
vllm_runner
(
model_info
.
name
,
task
=
"embed"
,
max_model_len
=
None
,
hf_overrides
=
hf_overrides
):
pass
@
pytest
.
mark
.
parametrize
(
"model_info"
,
MODELS
)
def
test_use_rope_scaling_illegal
(
model_info
,
vllm_runner
):
hf_overrides
=
{
"rope_theta"
:
rope_theta
,
"rope_scaling"
:
{
"rope_type"
:
"yarn"
,
"factor"
:
factor
,
"original_max_position_embeddings"
:
original_max_position_embeddings
}
}
# illegal max_model_len
with
pytest
.
raises
(
ValueError
):
with
vllm_runner
(
model_info
.
name
,
task
=
"embed"
,
max_model_len
=
max_model_len
+
1
,
hf_overrides
=
hf_overrides
):
pass
hf_overrides
=
{
"rope_theta"
:
rope_theta
,
"rope_scaling"
:
{
"rope_type"
:
"yarn"
,
"factor"
:
factor
,
"original_max_position_embeddings"
:
original_max_position_embeddings
},
"max_model_len"
:
max_model_len
+
1
}
# illegal max_model_len by hf_overrides
with
pytest
.
raises
(
ValueError
):
with
vllm_runner
(
model_info
.
name
,
task
=
"embed"
,
max_model_len
=
None
,
hf_overrides
=
hf_overrides
):
pass
vllm/config.py
View file @
3e9ce609
...
@@ -571,6 +571,7 @@ class ModelConfig:
...
@@ -571,6 +571,7 @@ class ModelConfig:
sliding_window
=
None
sliding_window
=
None
self
.
original_max_model_len
=
self
.
max_model_len
self
.
max_model_len
=
_get_and_verify_max_len
(
self
.
max_model_len
=
_get_and_verify_max_len
(
hf_config
=
self
.
hf_text_config
,
hf_config
=
self
.
hf_text_config
,
max_model_len
=
self
.
max_model_len
,
max_model_len
=
self
.
max_model_len
,
...
@@ -4471,6 +4472,19 @@ class VllmConfig:
...
@@ -4471,6 +4472,19 @@ class VllmConfig:
self
.
compilation_config
.
init_with_cudagraph_sizes
(
self
.
compilation_config
.
init_with_cudagraph_sizes
(
batch_size_capture_list
)
batch_size_capture_list
)
def
recalculate_max_model_len
(
self
,
max_model_len
:
int
):
model_config
=
self
.
model_config
max_model_len
=
_get_and_verify_max_len
(
hf_config
=
model_config
.
hf_text_config
,
max_model_len
=
max_model_len
,
disable_sliding_window
=
model_config
.
disable_sliding_window
,
sliding_window_len
=
model_config
.
get_hf_config_sliding_window
(),
spec_target_max_model_len
=
model_config
.
spec_target_max_model_len
,
encoder_config
=
model_config
.
encoder_config
)
self
.
model_config
.
max_model_len
=
max_model_len
self
.
scheduler_config
.
max_model_len
=
max_model_len
self
.
compute_hash
()
def
__str__
(
self
):
def
__str__
(
self
):
return
(
return
(
f
"model=
{
self
.
model_config
.
model
!
r
}
,"
f
"model=
{
self
.
model_config
.
model
!
r
}
,"
...
...
vllm/model_executor/models/bert_with_rope.py
View file @
3e9ce609
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
collections.abc
import
Iterable
from
collections.abc
import
Iterable
from
copy
import
deepcopy
from
typing
import
Optional
from
typing
import
Optional
import
torch
import
torch
...
@@ -10,6 +11,7 @@ from vllm.attention import Attention, AttentionType
...
@@ -10,6 +11,7 @@ from vllm.attention import Attention, AttentionType
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.activation
import
(
get_act_and_mul_fn
,
from
vllm.model_executor.layers.activation
import
(
get_act_and_mul_fn
,
get_act_fn
)
get_act_fn
)
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
...
@@ -27,6 +29,8 @@ from vllm.model_executor.models.interfaces import SupportsQuant
...
@@ -27,6 +29,8 @@ from vllm.model_executor.models.interfaces import SupportsQuant
from
vllm.model_executor.models.utils
import
WeightsMapper
from
vllm.model_executor.models.utils
import
WeightsMapper
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
logger
=
init_logger
(
__name__
)
class
BertWithRopeEmbedding
(
nn
.
Module
):
class
BertWithRopeEmbedding
(
nn
.
Module
):
...
@@ -513,10 +517,11 @@ class NomicBertModel(BertWithRope):
...
@@ -513,10 +517,11 @@ class NomicBertModel(BertWithRope):
head_dim
=
config
.
hidden_size
//
config
.
num_attention_heads
head_dim
=
config
.
hidden_size
//
config
.
num_attention_heads
rotary_emb_dim
=
head_dim
*
config
.
rotary_emb_fraction
rotary_emb_dim
=
head_dim
*
config
.
rotary_emb_fraction
max_trained_positions
=
getattr
(
config
,
"max_trained_positions"
,
2048
)
config
.
rotary_kwargs
=
{
config
.
rotary_kwargs
=
{
"head_size"
:
head_dim
,
"head_size"
:
head_dim
,
"rotary_dim"
:
rotary_emb_dim
,
"rotary_dim"
:
rotary_emb_dim
,
"max_position"
:
config
.
max_trained_positions
,
"max_position"
:
max_trained_positions
,
"base"
:
getattr
(
config
,
"rope_theta"
,
config
.
rotary_emb_base
),
"base"
:
getattr
(
config
,
"rope_theta"
,
config
.
rotary_emb_base
),
"rope_scaling"
:
getattr
(
config
,
"rope_scaling"
,
None
)
"rope_scaling"
:
getattr
(
config
,
"rope_scaling"
,
None
)
}
}
...
@@ -525,8 +530,52 @@ class NomicBertModel(BertWithRope):
...
@@ -525,8 +530,52 @@ class NomicBertModel(BertWithRope):
# than max_trained_positions 2048, the results are consistent
# than max_trained_positions 2048, the results are consistent
# with SentenceTransformer.
# with SentenceTransformer.
# The context extension uses vllm style rope_theta and rope_scaling.
# The context extension uses vllm style rope_theta and rope_scaling.
# See #17785
# See #17785 #18755
if
(
not
vllm_config
.
model_config
.
hf_overrides
and
vllm_config
.
model_config
.
original_max_model_len
is
None
):
# Default
# Reset max_model_len to max_trained_positions.
# nomic-embed-text-v2-moe the length is set to 512
# by sentence_bert_config.json.
max_model_len_before
=
vllm_config
.
model_config
.
max_model_len
max_model_len
=
min
(
vllm_config
.
model_config
.
max_model_len
,
max_trained_positions
)
vllm_config
.
recalculate_max_model_len
(
max_model_len
)
logger
.
warning
(
"Nomic context extension is disabled. "
"Changing max_model_len from %s to %s. "
"To enable context extension, see: "
"https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/context_extension.html"
,
max_model_len_before
,
vllm_config
.
model_config
.
max_model_len
)
else
:
# We need to re-verify max_model_len to avoid lengths
# greater than position_embedding.
model_config
=
vllm_config
.
model_config
hf_text_config
=
model_config
.
hf_text_config
if
isinstance
(
model_config
.
hf_overrides
,
dict
):
# hf_overrides_kw
max_model_len
=
model_config
.
hf_overrides
.
get
(
"max_model_len"
,
vllm_config
.
model_config
.
max_model_len
)
else
:
# hf_overrides_fn
# This might be overridden by sentence_bert_config.json.
max_model_len
=
vllm_config
.
model_config
.
max_model_len
# reset hf_text_config for recalculate_max_model_len.
if
hasattr
(
hf_text_config
,
"max_model_len"
):
delattr
(
hf_text_config
,
"max_model_len"
)
hf_text_config
.
max_position_embeddings
=
max_trained_positions
hf_text_config
.
rope_scaling
=
config
.
rotary_kwargs
[
"rope_scaling"
]
# The priority of sentence_bert_config.json is higher
# than max_position_embeddings
encoder_config
=
deepcopy
(
model_config
.
encoder_config
)
encoder_config
.
pop
(
"max_seq_length"
,
None
)
model_config
.
encoder_config
=
encoder_config
vllm_config
.
recalculate_max_model_len
(
max_model_len
)
return
config
return
config
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment