Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
415f76a9
"ssh:/git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "3dcb3e8b9838cbbef83ce326b1a35b31a3cf14f2"
Unverified
Commit
415f76a9
authored
Oct 16, 2024
by
Patrick von Platen
Committed by
GitHub
Oct 16, 2024
Browse files
Support mistral interleaved attn (#9414)
parent
cf1d62a6
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
28 additions
and
10 deletions
+28
-10
vllm/config.py
vllm/config.py
+28
-10
No files found.
vllm/config.py
View file @
415f76a9
...
@@ -173,14 +173,20 @@ class ModelConfig:
...
@@ -173,14 +173,20 @@ class ModelConfig:
if
self
.
enforce_eager
is
None
:
if
self
.
enforce_eager
is
None
:
self
.
enforce_eager
=
False
self
.
enforce_eager
=
False
if
(
not
self
.
disable_sliding_window
sliding_window
=
getattr
(
self
.
hf_text_config
,
"sliding_window"
,
None
)
and
self
.
hf_text_config
.
model_type
==
"gemma2"
has_interleaved_attention
=
(
sliding_window
is
not
None
)
and
(
and
self
.
hf_text_config
.
sliding_window
is
not
None
):
isinstance
(
sliding_window
,
list
)
or
(
self
.
hf_text_config
.
model_type
in
[
"gemma2"
]))
if
(
not
self
.
disable_sliding_window
and
has_interleaved_attention
):
sliding_window_len_min
=
get_min_sliding_window
(
self
.
hf_text_config
.
sliding_window
)
print_warning_once
(
print_warning_once
(
"Gemma 2 uses sliding window attention for every odd layer
, "
f
"
{
self
.
hf_text_config
.
model_type
}
has interleaved attention
, "
"which is currently not supported by vLLM. Disabling sliding "
"which is currently not supported by vLLM. Disabling sliding "
"window and capping the max length to the sliding window size "
"window and capping the max length to the sliding window size "
f
"(
{
self
.
hf_text_config
.
sliding_window
}
)."
)
f
"(
{
sliding_window
_len_min
}
)."
)
self
.
disable_sliding_window
=
True
self
.
disable_sliding_window
=
True
self
.
max_model_len
=
_get_and_verify_max_len
(
self
.
max_model_len
=
_get_and_verify_max_len
(
...
@@ -431,7 +437,8 @@ class ModelConfig:
...
@@ -431,7 +437,8 @@ class ModelConfig:
"pipeline parallelism currently. Disabling it."
)
"pipeline parallelism currently. Disabling it."
)
self
.
use_async_output_proc
=
False
self
.
use_async_output_proc
=
False
def
get_hf_config_sliding_window
(
self
)
->
Optional
[
int
]:
def
get_hf_config_sliding_window
(
self
)
->
Union
[
Optional
[
int
],
List
[
Optional
[
int
]]]:
"""Get the sliding window size, or None if disabled."""
"""Get the sliding window size, or None if disabled."""
# Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in
# Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in
...
@@ -442,7 +449,7 @@ class ModelConfig:
...
@@ -442,7 +449,7 @@ class ModelConfig:
return
None
return
None
return
getattr
(
self
.
hf_text_config
,
"sliding_window"
,
None
)
return
getattr
(
self
.
hf_text_config
,
"sliding_window"
,
None
)
def
get_sliding_window
(
self
)
->
Optional
[
int
]:
def
get_sliding_window
(
self
)
->
Optional
[
Union
[
int
,
List
[
Optional
[
int
]]]
]:
"""Get the sliding window size, or None if disabled.
"""Get the sliding window size, or None if disabled.
"""
"""
# If user disables sliding window, return None.
# If user disables sliding window, return None.
...
@@ -1689,7 +1696,7 @@ def _get_and_verify_max_len(
...
@@ -1689,7 +1696,7 @@ def _get_and_verify_max_len(
hf_config
:
PretrainedConfig
,
hf_config
:
PretrainedConfig
,
max_model_len
:
Optional
[
int
],
max_model_len
:
Optional
[
int
],
disable_sliding_window
:
bool
,
disable_sliding_window
:
bool
,
sliding_window_len
:
Optional
[
int
],
sliding_window_len
:
Optional
[
Union
[
int
,
List
[
Optional
[
int
]]]
],
spec_target_max_model_len
:
Optional
[
int
]
=
None
,
spec_target_max_model_len
:
Optional
[
int
]
=
None
,
)
->
int
:
)
->
int
:
"""Get and verify the model's maximum length."""
"""Get and verify the model's maximum length."""
...
@@ -1722,9 +1729,12 @@ def _get_and_verify_max_len(
...
@@ -1722,9 +1729,12 @@ def _get_and_verify_max_len(
# If sliding window is manually disabled, max_length should be less
# If sliding window is manually disabled, max_length should be less
# than the sliding window length in the model config.
# than the sliding window length in the model config.
if
disable_sliding_window
and
sliding_window_len
is
not
None
:
if
disable_sliding_window
and
sliding_window_len
is
not
None
:
sliding_window_len_min
=
get_min_sliding_window
(
sliding_window_len
)
max_len_key
=
"sliding_window"
\
max_len_key
=
"sliding_window"
\
if
sliding_window_len
<
derived_max_model_len
else
max_len_key
if
sliding_window_len_min
<
derived_max_model_len
else
max_len_key
derived_max_model_len
=
min
(
derived_max_model_len
,
sliding_window_len
)
derived_max_model_len
=
min
(
derived_max_model_len
,
sliding_window_len_min
)
# If none of the keys were found in the config, use a default and
# If none of the keys were found in the config, use a default and
# log a warning.
# log a warning.
...
@@ -1805,6 +1815,14 @@ def _get_and_verify_max_len(
...
@@ -1805,6 +1815,14 @@ def _get_and_verify_max_len(
return
int
(
max_model_len
)
return
int
(
max_model_len
)
def get_min_sliding_window(
        sliding_window: Union[int, List[Optional[int]]]) -> int:
    """Return the smallest sliding window size described by ``sliding_window``.

    Args:
        sliding_window: Either a single window size, or a list of window
            sizes in which ``None`` entries are ignored.

    Returns:
        ``sliding_window`` itself when it is not a list, otherwise the
        minimum of the non-``None`` entries of the list.

    Raises:
        ValueError: If ``sliding_window`` is a list that contains no
            non-``None`` entry (including the empty list).
    """
    if not isinstance(sliding_window, list):
        return sliding_window

    # Skip the ``None`` placeholders; only concrete sizes can be compared.
    window_sizes = [size for size in sliding_window if size is not None]
    if not window_sizes:
        # A bare ``min()`` over an empty sequence would raise the same
        # exception type, but with a message that does not mention the
        # sliding window configuration at all.
        raise ValueError(
            "Cannot determine the minimum sliding window: the sliding "
            f"window list {sliding_window!r} has no non-None entry.")
    return min(window_sizes)
def
get_served_model_name
(
model
:
str
,
def
get_served_model_name
(
model
:
str
,
served_model_name
:
Optional
[
Union
[
str
,
List
[
str
]]]):
served_model_name
:
Optional
[
Union
[
str
,
List
[
str
]]]):
"""
"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment