Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6c1208d0
Unverified
Commit
6c1208d0
authored
Nov 20, 2024
by
Pavani Majety
Committed by
GitHub
Nov 20, 2024
Browse files
[Core] Add Sliding Window Support with Flashinfer (#10462)
Signed-off-by:
Pavani Majety
<
pmajety@nvidia.com
>
parent
388ee3de
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
18 additions
and
7 deletions
+18
-7
tests/core/block/e2e/test_correctness_sliding_window.py
tests/core/block/e2e/test_correctness_sliding_window.py
+10
-2
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+8
-5
No files found.
tests/core/block/e2e/test_correctness_sliding_window.py
View file @
6c1208d0
...
@@ -3,6 +3,7 @@ from typing import List
...
@@ -3,6 +3,7 @@ from typing import List
import
pytest
import
pytest
from
tests.kernels.utils
import
override_backend_env_variable
from
vllm
import
LLM
,
SamplingParams
from
vllm
import
LLM
,
SamplingParams
from
.conftest
import
get_text_from_llm_generator
from
.conftest
import
get_text_from_llm_generator
...
@@ -28,8 +29,9 @@ BLOCK_SIZE = 16
...
@@ -28,8 +29,9 @@ BLOCK_SIZE = 16
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
"FLASH_ATTN"
,
"FLASHINFER"
,
"XFORMERS"
])
def
test_sliding_window_retrival
(
baseline_llm_generator
,
test_llm_generator
,
def
test_sliding_window_retrival
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
seed
):
batch_size
,
seed
,
backend
,
monkeypatch
):
"""
"""
The test does a bunch of assignments "x1 = 10
\n
x2 = 33
\n
..." and then
The test does a bunch of assignments "x1 = 10
\n
x2 = 33
\n
..." and then
asks for value of one of them (which is outside the sliding window).
asks for value of one of them (which is outside the sliding window).
...
@@ -38,6 +40,8 @@ def test_sliding_window_retrival(baseline_llm_generator, test_llm_generator,
...
@@ -38,6 +40,8 @@ def test_sliding_window_retrival(baseline_llm_generator, test_llm_generator,
Additionally, we compare the results of the v1 and v2 managers.
Additionally, we compare the results of the v1 and v2 managers.
"""
"""
override_backend_env_variable
(
monkeypatch
,
backend
)
sampling_params
=
SamplingParams
(
sampling_params
=
SamplingParams
(
max_tokens
=
1024
,
max_tokens
=
1024
,
ignore_eos
=
True
,
ignore_eos
=
True
,
...
@@ -84,7 +88,9 @@ def test_sliding_window_retrival(baseline_llm_generator, test_llm_generator,
...
@@ -84,7 +88,9 @@ def test_sliding_window_retrival(baseline_llm_generator, test_llm_generator,
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{
"enable_chunked_prefill"
:
True
}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{
"enable_chunked_prefill"
:
True
}])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_sliding_window_chunked_prefill
(
test_llm_generator
,
batch_size
,
seed
):
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
"FLASH_ATTN"
,
"FLASHINFER"
,
"XFORMERS"
])
def
test_sliding_window_chunked_prefill
(
test_llm_generator
,
batch_size
,
seed
,
backend
,
monkeypatch
):
"""
"""
This is similar to test_sliding_window_retrival, however, it doesn't
This is similar to test_sliding_window_retrival, however, it doesn't
compare against the v1 block manager since v1 doesn't support
compare against the v1 block manager since v1 doesn't support
...
@@ -93,6 +99,8 @@ def test_sliding_window_chunked_prefill(test_llm_generator, batch_size, seed):
...
@@ -93,6 +99,8 @@ def test_sliding_window_chunked_prefill(test_llm_generator, batch_size, seed):
The results with and without chunked prefill are not the same due to
The results with and without chunked prefill are not the same due to
numerical instabilities.
numerical instabilities.
"""
"""
override_backend_env_variable
(
monkeypatch
,
backend
)
sampling_params
=
SamplingParams
(
sampling_params
=
SamplingParams
(
max_tokens
=
10
,
max_tokens
=
10
,
ignore_eos
=
True
,
ignore_eos
=
True
,
...
...
vllm/attention/backends/flashinfer.py
View file @
6c1208d0
...
@@ -757,9 +757,8 @@ class FlashInferImpl(AttentionImpl):
...
@@ -757,9 +757,8 @@ class FlashInferImpl(AttentionImpl):
if
alibi_slopes
is
not
None
:
if
alibi_slopes
is
not
None
:
alibi_slopes
=
torch
.
tensor
(
alibi_slopes
,
dtype
=
torch
.
float32
)
alibi_slopes
=
torch
.
tensor
(
alibi_slopes
,
dtype
=
torch
.
float32
)
self
.
alibi_slopes
=
alibi_slopes
self
.
alibi_slopes
=
alibi_slopes
if
sliding_window
is
not
None
:
self
.
sliding_window
=
((
sliding_window
-
1
,
raise
ValueError
(
"Sliding window is not supported in FlashInfer."
)
0
)
if
sliding_window
is
not
None
else
(
-
1
,
-
1
))
self
.
sliding_window
=
(
-
1
,
-
1
)
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
logits_soft_cap
=
logits_soft_cap
self
.
logits_soft_cap
=
logits_soft_cap
...
@@ -865,6 +864,8 @@ def unified_flash_infer(
...
@@ -865,6 +864,8 @@ def unified_flash_infer(
assert
query
.
shape
[
0
]
==
num_prefill_tokens
assert
query
.
shape
[
0
]
==
num_prefill_tokens
assert
decode_query
.
shape
[
0
]
==
num_decode_tokens
assert
decode_query
.
shape
[
0
]
==
num_decode_tokens
window_left
=
window_size
[
0
]
if
window_size
is
not
None
else
-
1
prefill_output
:
Optional
[
torch
.
Tensor
]
=
None
prefill_output
:
Optional
[
torch
.
Tensor
]
=
None
decode_output
:
Optional
[
torch
.
Tensor
]
=
None
decode_output
:
Optional
[
torch
.
Tensor
]
=
None
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
...
@@ -895,7 +896,8 @@ def unified_flash_infer(
...
@@ -895,7 +896,8 @@ def unified_flash_infer(
logits_soft_cap
=
logits_soft_cap
,
logits_soft_cap
=
logits_soft_cap
,
causal
=
True
,
causal
=
True
,
k_scale
=
k_scale
,
k_scale
=
k_scale
,
v_scale
=
v_scale
)
v_scale
=
v_scale
,
window_left
=
window_left
)
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
assert
attn_metadata
.
decode_metadata
is
not
None
assert
attn_metadata
.
decode_metadata
is
not
None
assert
attn_metadata
.
decode_metadata
.
decode_wrapper
is
not
None
assert
attn_metadata
.
decode_metadata
.
decode_wrapper
is
not
None
...
@@ -905,7 +907,8 @@ def unified_flash_infer(
...
@@ -905,7 +907,8 @@ def unified_flash_infer(
sm_scale
=
softmax_scale
,
sm_scale
=
softmax_scale
,
logits_soft_cap
=
logits_soft_cap
,
logits_soft_cap
=
logits_soft_cap
,
k_scale
=
k_scale
,
k_scale
=
k_scale
,
v_scale
=
v_scale
)
v_scale
=
v_scale
,
window_left
=
window_left
)
if
prefill_output
is
None
and
decode_output
is
not
None
:
if
prefill_output
is
None
and
decode_output
is
not
None
:
# Decode only batch.
# Decode only batch.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment