Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9b73a2f4
Unverified
Commit
9b73a2f4
authored
Aug 21, 2024
by
Nick Hill
Committed by
GitHub
Aug 22, 2024
Browse files
[Spec Decoding] Use target model max length as default for draft model (#7706)
parent
6925cdbe
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
10 additions
and
1 deletion
+10
-1
vllm/config.py
vllm/config.py
+10
-1
No files found.
vllm/config.py
View file @
9b73a2f4
...
...
@@ -127,6 +127,7 @@ class ModelConfig:
rope_theta
:
Optional
[
float
]
=
None
,
tokenizer_revision
:
Optional
[
str
]
=
None
,
max_model_len
:
Optional
[
int
]
=
None
,
spec_target_max_model_len
:
Optional
[
int
]
=
None
,
quantization
:
Optional
[
str
]
=
None
,
quantization_param_path
:
Optional
[
str
]
=
None
,
enforce_eager
:
Optional
[
bool
]
=
None
,
...
...
@@ -210,7 +211,8 @@ class ModelConfig:
hf_config
=
self
.
hf_text_config
,
max_model_len
=
max_model_len
,
disable_sliding_window
=
self
.
disable_sliding_window
,
sliding_window_len
=
self
.
get_hf_config_sliding_window
())
sliding_window_len
=
self
.
get_hf_config_sliding_window
(),
spec_target_max_model_len
=
spec_target_max_model_len
)
self
.
served_model_name
=
get_served_model_name
(
model
,
served_model_name
)
self
.
multimodal_config
=
self
.
_init_multimodal_config
(
...
...
@@ -1134,6 +1136,7 @@ class SpeculativeConfig:
code_revision
=
draft_code_revision
,
tokenizer_revision
=
target_model_config
.
tokenizer_revision
,
max_model_len
=
None
,
spec_target_max_model_len
=
target_model_config
.
max_model_len
,
quantization
=
draft_quantization
,
enforce_eager
=
target_model_config
.
enforce_eager
,
max_seq_len_to_capture
=
target_model_config
.
...
...
@@ -1563,6 +1566,7 @@ def _get_and_verify_max_len(
max_model_len
:
Optional
[
int
],
disable_sliding_window
:
bool
,
sliding_window_len
:
Optional
[
int
],
spec_target_max_model_len
:
Optional
[
int
]
=
None
,
)
->
int
:
"""Get and verify the model's maximum length."""
derived_max_model_len
=
float
(
"inf"
)
...
...
@@ -1605,6 +1609,11 @@ def _get_and_verify_max_len(
# If max_model_len is specified, we use it.
return
max_model_len
if
spec_target_max_model_len
is
not
None
:
# If this is a speculative draft model, we use the max model len
# from the target model.
return
spec_target_max_model_len
default_max_len
=
2048
logger
.
warning
(
"The model's config.json does not contain any of the following "
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment