Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4aba6e3d
Unverified
Commit
4aba6e3d
authored
Nov 22, 2024
by
youkaichao
Committed by
GitHub
Nov 22, 2024
Browse files
[core] gemma2 full context length support (#10584)
Signed-off-by: youkaichao <youkaichao@gmail.com>
parent
978b3974
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
55 additions
and
24 deletions
+55
-24
tests/basic_correctness/test_basic_correctness.py
tests/basic_correctness/test_basic_correctness.py
+18
-7
vllm/attention/layer.py
vllm/attention/layer.py
+10
-2
vllm/config.py
vllm/config.py
+20
-9
vllm/model_executor/models/gemma2.py
vllm/model_executor/models/gemma2.py
+7
-6
No files found.
tests/basic_correctness/test_basic_correctness.py
View file @
4aba6e3d
...
...
@@ -14,11 +14,12 @@ from vllm import LLM
from
vllm.platforms
import
current_platform
from
vllm.worker.model_runner
import
ModelInputForGPUWithSamplingMetadata
from
..conftest
import
VllmRunner
from
..models.utils
import
check_outputs_equal
from
..utils
import
multi_gpu_test
MODELS
=
[
"
facebook/opt-125m
"
,
"
google/gemma-2-2b-it
"
,
"meta-llama/Llama-3.2-1B"
,
]
...
...
@@ -42,8 +43,6 @@ def test_vllm_gc_ed():
@
pytest
.
mark
.
parametrize
(
"enforce_eager"
,
[
False
,
True
])
def
test_models
(
hf_runner
,
vllm_runner
,
example_prompts
,
model
:
str
,
backend
:
str
,
dtype
:
str
,
...
...
@@ -54,15 +53,27 @@ def test_models(
if
backend
==
"FLASHINFER"
and
current_platform
.
is_rocm
():
pytest
.
skip
(
"Flashinfer does not support ROCm/HIP."
)
if
backend
==
"XFORMERS"
and
model
==
"google/gemma-2-2b-it"
:
pytest
.
skip
(
"XFORMERS does not support gemma2 with full context length."
)
os
.
environ
[
"VLLM_ATTENTION_BACKEND"
]
=
backend
# 5042 tokens for gemma2
# gemma2 has alternating sliding window size of 4096
# we need a prompt with more than 4096 tokens to test the sliding window
prompt
=
"The following numbers of the sequence "
+
", "
.
join
(
str
(
i
)
for
i
in
range
(
1024
))
+
" are:"
example_prompts
=
[
prompt
]
with
hf_runner
(
model
,
dtype
=
dtype
)
as
hf_model
:
hf_outputs
=
hf_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
with
vllm_runner
(
model
,
dtype
=
dtype
,
enforce_eager
=
enforce_eager
,
gpu_memory_utilization
=
0.7
)
as
vllm_model
:
with
VllmRunner
(
model
,
max_model_len
=
8192
,
dtype
=
dtype
,
enforce_eager
=
enforce_eager
,
gpu_memory_utilization
=
0.7
)
as
vllm_model
:
vllm_outputs
=
vllm_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
check_outputs_equal
(
...
...
vllm/attention/layer.py
View file @
4aba6e3d
...
...
@@ -40,18 +40,26 @@ class Attention(nn.Module):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
per_layer_sliding_window
:
Optional
[
int
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
if
per_layer_sliding_window
is
not
None
:
# per-layer sliding window
sliding_window
=
per_layer_sliding_window
elif
cache_config
is
not
None
:
# model-level sliding window
sliding_window
=
cache_config
.
sliding_window
else
:
sliding_window
=
None
if
cache_config
is
not
None
:
kv_cache_dtype
=
cache_config
.
cache_dtype
block_size
=
cache_config
.
block_size
sliding_window
=
cache_config
.
sliding_window
is_attention_free
=
cache_config
.
is_attention_free
else
:
kv_cache_dtype
=
"auto"
block_size
=
16
sliding_window
=
None
is_attention_free
=
False
if
num_kv_heads
is
None
:
num_kv_heads
=
num_heads
...
...
vllm/config.py
View file @
4aba6e3d
...
...
@@ -233,15 +233,26 @@ class ModelConfig:
(
self
.
hf_text_config
.
model_type
in
[
"gemma2"
]))
if
(
not
self
.
disable_sliding_window
and
has_interleaved_attention
):
sliding_window_len_min
=
get_min_sliding_window
(
self
.
hf_text_config
.
sliding_window
)
print_warning_once
(
f
"
{
self
.
hf_text_config
.
model_type
}
has interleaved attention, "
"which is currently not supported by vLLM. Disabling sliding "
"window and capping the max length to the sliding window size "
f
"(
{
sliding_window_len_min
}
)."
)
self
.
disable_sliding_window
=
True
if
envs
.
VLLM_ATTENTION_BACKEND
==
"XFORMERS"
:
sliding_window_len_min
=
get_min_sliding_window
(
self
.
hf_text_config
.
sliding_window
)
print_warning_once
(
f
"
{
self
.
hf_text_config
.
model_type
}
has interleaved "
"attention, which is currently not supported by the "
"XFORMERS backend. Disabling sliding window and capping "
"the max length to the sliding window size "
f
"(
{
sliding_window_len_min
}
)."
)
self
.
disable_sliding_window
=
True
else
:
# for a model with interleaved attention,
# the scheduler and the model treat it as full attention
# (i.e., not dropping any tokens outside the window).
# only the attention layer itself is aware of the sliding
# window, and use the window size to compute the attention.
self
.
hf_text_config
.
interleaved_sliding_window
=
sliding_window
delattr
(
self
.
hf_text_config
,
"sliding_window"
)
sliding_window
=
None
self
.
max_model_len
=
_get_and_verify_max_len
(
hf_config
=
self
.
hf_text_config
,
...
...
vllm/model_executor/models/gemma2.py
View file @
4aba6e3d
...
...
@@ -143,12 +143,12 @@ class Gemma2Attention(nn.Module):
is_neox_style
=
True
,
)
# FIXME(woosuk): While Gemma 2 uses sliding window attention for every
# odd layer, vLLM currently ignores it and uses global attention for
# all layers.
use_sliding_window
=
(
layer_idx
%
2
==
1
and
config
.
sliding_window
is
not
None
)
del
use_sliding_window
# Unused.
# reference:
# https://github.com/huggingface/transformers/blob/54be2d7ae87e873482b984cc956e165ca4dc0ba3/src/transformers/models/gemma2/modeling_gemma2.py#L312 # noqa
use_sliding_window
=
(
layer_idx
%
2
==
0
and
config
.
interleaved_sliding_window
is
not
None
)
sliding_window
=
config
.
interleaved_sliding_window
if
\
use_sliding_window
else
None
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scaling
,
...
...
@@ -156,6 +156,7 @@ class Gemma2Attention(nn.Module):
cache_config
=
cache_config
,
quant_config
=
quant_config
,
logits_soft_cap
=
attn_logits_soft_cap
,
per_layer_sliding_window
=
sliding_window
,
prefix
=
f
"
{
prefix
}
.attn"
)
def
forward
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment