Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6c1208d0
Unverified
Commit 6c1208d0 authored Nov 20, 2024 by Pavani Majety
Committed by GitHub, Nov 20, 2024
Browse files
[Core] Add Sliding Window Support with Flashinfer (#10462)
Signed-off-by: Pavani Majety <pmajety@nvidia.com>
parent
388ee3de
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing 2 changed files with 18 additions and 7 deletions
+18
-7
tests/core/block/e2e/test_correctness_sliding_window.py
tests/core/block/e2e/test_correctness_sliding_window.py
+10
-2
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+8
-5
No files found.
tests/core/block/e2e/test_correctness_sliding_window.py
View file @
6c1208d0
...
...
@@ -3,6 +3,7 @@ from typing import List
import
pytest
from
tests.kernels.utils
import
override_backend_env_variable
from
vllm
import
LLM
,
SamplingParams
from
.conftest
import
get_text_from_llm_generator
...
...
@@ -28,8 +29,9 @@ BLOCK_SIZE = 16
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
"FLASH_ATTN"
,
"FLASHINFER"
,
"XFORMERS"
])
def
test_sliding_window_retrival
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
seed
):
batch_size
,
seed
,
backend
,
monkeypatch
):
"""
The test does a bunch of assignments "x1 = 10
\n
x2 = 33
\n
..." and then
asks for value of one of them (which is outside the sliding window).
...
...
@@ -38,6 +40,8 @@ def test_sliding_window_retrival(baseline_llm_generator, test_llm_generator,
Additionally, we compare the results of the v1 and v2 managers.
"""
override_backend_env_variable
(
monkeypatch
,
backend
)
sampling_params
=
SamplingParams
(
max_tokens
=
1024
,
ignore_eos
=
True
,
...
...
@@ -84,7 +88,9 @@ def test_sliding_window_retrival(baseline_llm_generator, test_llm_generator,
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{
"enable_chunked_prefill"
:
True
}])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_sliding_window_chunked_prefill
(
test_llm_generator
,
batch_size
,
seed
):
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
"FLASH_ATTN"
,
"FLASHINFER"
,
"XFORMERS"
])
def
test_sliding_window_chunked_prefill
(
test_llm_generator
,
batch_size
,
seed
,
backend
,
monkeypatch
):
"""
This is similar to test_sliding_window_retrival, however, it doesn't
compare against the v1 block manager since v1 doesn't support
...
...
@@ -93,6 +99,8 @@ def test_sliding_window_chunked_prefill(test_llm_generator, batch_size, seed):
The results with and without chunked prefill are not the same due to
numerical instabilities.
"""
override_backend_env_variable
(
monkeypatch
,
backend
)
sampling_params
=
SamplingParams
(
max_tokens
=
10
,
ignore_eos
=
True
,
...
...
vllm/attention/backends/flashinfer.py
View file @
6c1208d0
...
...
@@ -757,9 +757,8 @@ class FlashInferImpl(AttentionImpl):
if
alibi_slopes
is
not
None
:
alibi_slopes
=
torch
.
tensor
(
alibi_slopes
,
dtype
=
torch
.
float32
)
self
.
alibi_slopes
=
alibi_slopes
if
sliding_window
is
not
None
:
raise
ValueError
(
"Sliding window is not supported in FlashInfer."
)
self
.
sliding_window
=
(
-
1
,
-
1
)
self
.
sliding_window
=
((
sliding_window
-
1
,
0
)
if
sliding_window
is
not
None
else
(
-
1
,
-
1
))
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
logits_soft_cap
=
logits_soft_cap
...
...
@@ -865,6 +864,8 @@ def unified_flash_infer(
assert
query
.
shape
[
0
]
==
num_prefill_tokens
assert
decode_query
.
shape
[
0
]
==
num_decode_tokens
window_left
=
window_size
[
0
]
if
window_size
is
not
None
else
-
1
prefill_output
:
Optional
[
torch
.
Tensor
]
=
None
decode_output
:
Optional
[
torch
.
Tensor
]
=
None
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
...
...
@@ -895,7 +896,8 @@ def unified_flash_infer(
logits_soft_cap
=
logits_soft_cap
,
causal
=
True
,
k_scale
=
k_scale
,
v_scale
=
v_scale
)
v_scale
=
v_scale
,
window_left
=
window_left
)
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
assert
attn_metadata
.
decode_metadata
is
not
None
assert
attn_metadata
.
decode_metadata
.
decode_wrapper
is
not
None
...
...
@@ -905,7 +907,8 @@ def unified_flash_infer(
sm_scale
=
softmax_scale
,
logits_soft_cap
=
logits_soft_cap
,
k_scale
=
k_scale
,
v_scale
=
v_scale
)
v_scale
=
v_scale
,
window_left
=
window_left
)
if
prefill_output
is
None
and
decode_output
is
not
None
:
# Decode only batch.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please register or sign in to comment.