Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4fa3e333
Unverified
Commit
4fa3e333
authored
Oct 20, 2024
by
Chen Zhang
Committed by
GitHub
Oct 20, 2024
Browse files
[Kernel] Support sliding window in flash attention backend (#9403)
parent
962d2c63
Changes
13
Show whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
41 additions
and
61 deletions
+41
-61
tests/kernels/test_attention_selector.py
tests/kernels/test_attention_selector.py
+15
-20
tests/kernels/test_flash_attn.py
tests/kernels/test_flash_attn.py
+16
-13
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+5
-8
vllm/attention/layer.py
vllm/attention/layer.py
+3
-4
vllm/attention/selector.py
vllm/attention/selector.py
+2
-8
vllm/worker/cache_engine.py
vllm/worker/cache_engine.py
+0
-1
vllm/worker/cpu_model_runner.py
vllm/worker/cpu_model_runner.py
+0
-1
vllm/worker/cpu_worker.py
vllm/worker/cpu_worker.py
+0
-1
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+0
-1
vllm/worker/openvino_model_runner.py
vllm/worker/openvino_model_runner.py
+0
-1
vllm/worker/openvino_worker.py
vllm/worker/openvino_worker.py
+0
-1
vllm/worker/tpu_model_runner.py
vllm/worker/tpu_model_runner.py
+0
-1
vllm/worker/xpu_model_runner.py
vllm/worker/xpu_model_runner.py
+0
-1
No files found.
tests/kernels/test_attention_selector.py
View file @
4fa3e333
...
@@ -20,21 +20,21 @@ def test_env(name: str, device: str, monkeypatch):
...
@@ -20,21 +20,21 @@ def test_env(name: str, device: str, monkeypatch):
if
device
==
"cpu"
:
if
device
==
"cpu"
:
with
patch
(
"vllm.attention.selector.is_cpu"
,
return_value
=
True
):
with
patch
(
"vllm.attention.selector.is_cpu"
,
return_value
=
True
):
backend
=
which_attn_to_use
(
16
,
None
,
torch
.
float16
,
torch
.
float16
,
backend
=
which_attn_to_use
(
16
,
torch
.
float16
,
torch
.
float16
,
16
,
16
,
False
)
False
)
assert
backend
.
name
==
"TORCH_SDPA"
assert
backend
.
name
==
"TORCH_SDPA"
elif
device
==
"hip"
:
elif
device
==
"hip"
:
with
patch
(
"vllm.attention.selector.is_hip"
,
return_value
=
True
):
with
patch
(
"vllm.attention.selector.is_hip"
,
return_value
=
True
):
backend
=
which_attn_to_use
(
16
,
None
,
torch
.
float16
,
torch
.
float16
,
backend
=
which_attn_to_use
(
16
,
torch
.
float16
,
torch
.
float16
,
16
,
16
,
False
)
False
)
assert
backend
.
name
==
"ROCM_FLASH"
assert
backend
.
name
==
"ROCM_FLASH"
elif
device
==
"openvino"
:
elif
device
==
"openvino"
:
with
patch
(
"vllm.attention.selector.is_openvino"
,
return_value
=
True
):
with
patch
(
"vllm.attention.selector.is_openvino"
,
return_value
=
True
):
backend
=
which_attn_to_use
(
16
,
None
,
torch
.
float16
,
torch
.
float16
,
backend
=
which_attn_to_use
(
16
,
torch
.
float16
,
torch
.
float16
,
16
,
16
,
False
)
False
)
assert
backend
.
name
==
"OPENVINO"
assert
backend
.
name
==
"OPENVINO"
else
:
else
:
backend
=
which_attn_to_use
(
16
,
None
,
torch
.
float16
,
torch
.
float16
,
16
,
backend
=
which_attn_to_use
(
16
,
torch
.
float16
,
torch
.
float16
,
16
,
False
)
False
)
assert
backend
.
name
==
name
assert
backend
.
name
==
name
...
@@ -46,37 +46,32 @@ def test_flash_attn(monkeypatch):
...
@@ -46,37 +46,32 @@ def test_flash_attn(monkeypatch):
# Unsupported CUDA arch
# Unsupported CUDA arch
with
patch
(
"torch.cuda.get_device_capability"
,
return_value
=
(
7
,
5
)):
with
patch
(
"torch.cuda.get_device_capability"
,
return_value
=
(
7
,
5
)):
backend
=
which_attn_to_use
(
16
,
None
,
torch
.
float16
,
None
,
16
,
False
)
backend
=
which_attn_to_use
(
16
,
torch
.
float16
,
None
,
16
,
False
)
assert
backend
.
name
!=
STR_FLASH_ATTN_VAL
assert
backend
.
name
!=
STR_FLASH_ATTN_VAL
# Unsupported data type
# Unsupported data type
backend
=
which_attn_to_use
(
16
,
None
,
torch
.
float8_e4m3fn
,
None
,
16
,
False
)
backend
=
which_attn_to_use
(
16
,
torch
.
float8_e4m3fn
,
None
,
16
,
False
)
assert
backend
.
name
!=
STR_FLASH_ATTN_VAL
assert
backend
.
name
!=
STR_FLASH_ATTN_VAL
# Unsupported kv cache data type
# Unsupported kv cache data type
backend
=
which_attn_to_use
(
16
,
None
,
torch
.
float16
,
"fp8"
,
16
,
False
)
backend
=
which_attn_to_use
(
16
,
torch
.
float16
,
"fp8"
,
16
,
False
)
assert
backend
.
name
!=
STR_FLASH_ATTN_VAL
assert
backend
.
name
!=
STR_FLASH_ATTN_VAL
# Unsupported block size
# Unsupported block size
backend
=
which_attn_to_use
(
16
,
None
,
torch
.
float16
,
None
,
8
,
False
)
backend
=
which_attn_to_use
(
16
,
torch
.
float16
,
None
,
8
,
False
)
assert
backend
.
name
!=
STR_FLASH_ATTN_VAL
# Unsupported sliding window
backend
=
which_attn_to_use
(
16
,
1
,
torch
.
float16
,
None
,
16
,
False
)
assert
backend
.
name
!=
STR_FLASH_ATTN_VAL
assert
backend
.
name
!=
STR_FLASH_ATTN_VAL
# flash-attn is not installed
# flash-attn is not installed
with
patch
.
dict
(
'sys.modules'
,
{
'vllm_flash_attn'
:
None
}):
with
patch
.
dict
(
'sys.modules'
,
{
'vllm_flash_attn'
:
None
}):
backend
=
which_attn_to_use
(
16
,
None
,
torch
.
float16
,
None
,
16
,
False
)
backend
=
which_attn_to_use
(
16
,
torch
.
float16
,
None
,
16
,
False
)
assert
backend
.
name
!=
STR_FLASH_ATTN_VAL
assert
backend
.
name
!=
STR_FLASH_ATTN_VAL
# Unsupported head size
# Unsupported head size
backend
=
which_attn_to_use
(
17
,
None
,
torch
.
float16
,
None
,
16
,
False
)
backend
=
which_attn_to_use
(
17
,
torch
.
float16
,
None
,
16
,
False
)
assert
backend
.
name
!=
STR_FLASH_ATTN_VAL
assert
backend
.
name
!=
STR_FLASH_ATTN_VAL
# Attention-free models should bypass env and use PlaceholderAttention
# Attention-free models should bypass env and use PlaceholderAttention
backend
=
which_attn_to_use
(
16
,
None
,
torch
.
float16
,
torch
.
float16
,
16
,
backend
=
which_attn_to_use
(
16
,
torch
.
float16
,
torch
.
float16
,
16
,
True
)
True
)
assert
backend
.
name
!=
STR_FLASH_ATTN_VAL
assert
backend
.
name
!=
STR_FLASH_ATTN_VAL
...
@@ -84,4 +79,4 @@ def test_invalid_env(monkeypatch):
...
@@ -84,4 +79,4 @@ def test_invalid_env(monkeypatch):
"""Throw an exception if the backend name is invalid."""
"""Throw an exception if the backend name is invalid."""
override_backend_env_variable
(
monkeypatch
,
STR_INVALID_VAL
)
override_backend_env_variable
(
monkeypatch
,
STR_INVALID_VAL
)
with
pytest
.
raises
(
ValueError
):
with
pytest
.
raises
(
ValueError
):
which_attn_to_use
(
16
,
None
,
torch
.
float16
,
None
,
16
,
False
)
which_attn_to_use
(
16
,
torch
.
float16
,
None
,
16
,
False
)
tests/kernels/test_flash_attn.py
View file @
4fa3e333
...
@@ -78,6 +78,7 @@ def ref_paged_attn(
...
@@ -78,6 +78,7 @@ def ref_paged_attn(
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
[
None
,
10.0
,
50.0
])
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
[
None
,
10.0
,
50.0
])
@
pytest
.
mark
.
parametrize
(
"num_blocks"
,
NUM_BLOCKS
)
@
pytest
.
mark
.
parametrize
(
"num_blocks"
,
NUM_BLOCKS
)
@
pytest
.
mark
.
parametrize
(
"sliding_window"
,
[
None
,
256
])
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_flash_attn_with_paged_kv
(
def
test_flash_attn_with_paged_kv
(
kv_lens
:
List
[
int
],
kv_lens
:
List
[
int
],
...
@@ -87,6 +88,7 @@ def test_flash_attn_with_paged_kv(
...
@@ -87,6 +88,7 @@ def test_flash_attn_with_paged_kv(
block_size
:
int
,
block_size
:
int
,
soft_cap
:
Optional
[
float
],
soft_cap
:
Optional
[
float
],
num_blocks
:
int
,
num_blocks
:
int
,
sliding_window
:
Optional
[
int
],
)
->
None
:
)
->
None
:
torch
.
set_default_device
(
"cuda"
)
torch
.
set_default_device
(
"cuda"
)
seed_everything
(
0
)
seed_everything
(
0
)
...
@@ -96,6 +98,8 @@ def test_flash_attn_with_paged_kv(
...
@@ -96,6 +98,8 @@ def test_flash_attn_with_paged_kv(
assert
num_query_heads
%
num_kv_heads
==
0
assert
num_query_heads
%
num_kv_heads
==
0
max_kv_len
=
max
(
kv_lens
)
max_kv_len
=
max
(
kv_lens
)
scale
=
head_size
**-
0.5
scale
=
head_size
**-
0.5
window_size
=
((
sliding_window
-
1
,
0
)
if
sliding_window
is
not
None
else
(
-
1
,
-
1
))
query
=
torch
.
randn
(
num_seqs
,
num_query_heads
,
head_size
,
dtype
=
dtype
)
query
=
torch
.
randn
(
num_seqs
,
num_query_heads
,
head_size
,
dtype
=
dtype
)
key_cache
=
torch
.
randn
(
num_blocks
,
key_cache
=
torch
.
randn
(
num_blocks
,
...
@@ -121,10 +125,10 @@ def test_flash_attn_with_paged_kv(
...
@@ -121,10 +125,10 @@ def test_flash_attn_with_paged_kv(
block_table
=
block_tables
,
block_table
=
block_tables
,
cache_seqlens
=
kv_lens_tensor
,
cache_seqlens
=
kv_lens_tensor
,
softcap
=
soft_cap
if
soft_cap
is
not
None
else
0
,
softcap
=
soft_cap
if
soft_cap
is
not
None
else
0
,
window_size
=
window_size
,
).
squeeze
(
1
)
).
squeeze
(
1
)
ref_output
=
ref_paged_attn
(
ref_output
=
ref_paged_attn
(
query
=
query
,
query
=
query
,
key_cache
=
key_cache
,
key_cache
=
key_cache
,
value_cache
=
value_cache
,
value_cache
=
value_cache
,
query_lens
=
[
1
]
*
num_seqs
,
query_lens
=
[
1
]
*
num_seqs
,
...
@@ -132,7 +136,7 @@ def test_flash_attn_with_paged_kv(
...
@@ -132,7 +136,7 @@ def test_flash_attn_with_paged_kv(
block_tables
=
block_tables
,
block_tables
=
block_tables
,
scale
=
scale
,
scale
=
scale
,
soft_cap
=
soft_cap
,
soft_cap
=
soft_cap
,
)
sliding_window
=
sliding_window
)
torch
.
testing
.
assert_close
(
output
,
ref_output
,
atol
=
2e-2
,
rtol
=
1e-2
),
\
torch
.
testing
.
assert_close
(
output
,
ref_output
,
atol
=
2e-2
,
rtol
=
1e-2
),
\
f
"
{
torch
.
max
(
torch
.
abs
(
output
-
ref_output
))
}
"
f
"
{
torch
.
max
(
torch
.
abs
(
output
-
ref_output
))
}
"
...
@@ -141,7 +145,7 @@ def test_flash_attn_with_paged_kv(
...
@@ -141,7 +145,7 @@ def test_flash_attn_with_paged_kv(
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZES
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZES
)
@
pytest
.
mark
.
parametrize
(
"block_size"
,
BLOCK_SIZES
)
@
pytest
.
mark
.
parametrize
(
"block_size"
,
BLOCK_SIZES
)
@
pytest
.
mark
.
parametrize
(
"sliding_window"
,
[
None
])
@
pytest
.
mark
.
parametrize
(
"sliding_window"
,
[
None
,
256
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
[
None
,
10.0
,
50.0
])
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
[
None
,
10.0
,
50.0
])
@
pytest
.
mark
.
parametrize
(
"num_blocks"
,
NUM_BLOCKS
)
@
pytest
.
mark
.
parametrize
(
"num_blocks"
,
NUM_BLOCKS
)
...
@@ -166,8 +170,7 @@ def test_varlen_with_paged_kv(
...
@@ -166,8 +170,7 @@ def test_varlen_with_paged_kv(
assert
num_query_heads
%
num_kv_heads
==
0
assert
num_query_heads
%
num_kv_heads
==
0
max_query_len
=
max
(
query_lens
)
max_query_len
=
max
(
query_lens
)
max_kv_len
=
max
(
kv_lens
)
max_kv_len
=
max
(
kv_lens
)
window_size
=
((
sliding_window
,
window_size
=
((
sliding_window
-
1
,
0
)
if
sliding_window
is
not
None
else
sliding_window
)
if
sliding_window
is
not
None
else
(
-
1
,
-
1
))
(
-
1
,
-
1
))
scale
=
head_size
**-
0.5
scale
=
head_size
**-
0.5
...
...
vllm/attention/backends/flash_attn.py
View file @
4fa3e333
...
@@ -524,8 +524,8 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -524,8 +524,8 @@ class FlashAttentionImpl(AttentionImpl):
if
alibi_slopes
is
not
None
:
if
alibi_slopes
is
not
None
:
alibi_slopes
=
torch
.
tensor
(
alibi_slopes
,
dtype
=
torch
.
float32
)
alibi_slopes
=
torch
.
tensor
(
alibi_slopes
,
dtype
=
torch
.
float32
)
self
.
alibi_slopes
=
alibi_slopes
self
.
alibi_slopes
=
alibi_slopes
self
.
sliding_window
=
((
sliding_window
,
sliding_window
)
self
.
sliding_window
=
((
sliding_window
-
1
,
if
sliding_window
is
not
None
else
(
-
1
,
-
1
))
0
)
if
sliding_window
is
not
None
else
(
-
1
,
-
1
))
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
kv_cache_dtype
=
kv_cache_dtype
if
logits_soft_cap
is
None
:
if
logits_soft_cap
is
None
:
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
...
@@ -535,12 +535,6 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -535,12 +535,6 @@ class FlashAttentionImpl(AttentionImpl):
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
if
sliding_window
is
not
None
:
# NOTE(woosuk): flash-attn's sliding window does not work with
# paged KV cache.
raise
ValueError
(
"Sliding window is not supported in FlashAttention."
)
support_head_sizes
=
FlashAttentionBackend
.
get_supported_head_sizes
()
support_head_sizes
=
FlashAttentionBackend
.
get_supported_head_sizes
()
if
head_size
not
in
support_head_sizes
:
if
head_size
not
in
support_head_sizes
:
raise
ValueError
(
raise
ValueError
(
...
@@ -704,6 +698,7 @@ def unified_flash_attention(
...
@@ -704,6 +698,7 @@ def unified_flash_attention(
max_seqlen_k
=
max_seq_len
,
max_seqlen_k
=
max_seq_len
,
softmax_scale
=
softmax_scale
,
softmax_scale
=
softmax_scale
,
causal
=
True
,
causal
=
True
,
window_size
=
window_size
,
alibi_slopes
=
alibi_slopes
,
alibi_slopes
=
alibi_slopes
,
block_table
=
prefill_meta
.
block_tables
,
block_table
=
prefill_meta
.
block_tables
,
softcap
=
logits_soft_cap
,
softcap
=
logits_soft_cap
,
...
@@ -725,6 +720,7 @@ def unified_flash_attention(
...
@@ -725,6 +720,7 @@ def unified_flash_attention(
max_seqlen_k
=
decode_meta
.
max_decode_seq_len
,
max_seqlen_k
=
decode_meta
.
max_decode_seq_len
,
softmax_scale
=
softmax_scale
,
softmax_scale
=
softmax_scale
,
causal
=
True
,
causal
=
True
,
window_size
=
window_size
,
alibi_slopes
=
alibi_slopes
,
alibi_slopes
=
alibi_slopes
,
softcap
=
logits_soft_cap
,
softcap
=
logits_soft_cap
,
block_table
=
decode_meta
.
block_tables
,
block_table
=
decode_meta
.
block_tables
,
...
@@ -739,6 +735,7 @@ def unified_flash_attention(
...
@@ -739,6 +735,7 @@ def unified_flash_attention(
cache_seqlens
=
decode_meta
.
seq_lens_tensor
,
cache_seqlens
=
decode_meta
.
seq_lens_tensor
,
softmax_scale
=
softmax_scale
,
softmax_scale
=
softmax_scale
,
causal
=
True
,
causal
=
True
,
window_size
=
window_size
,
alibi_slopes
=
alibi_slopes
,
alibi_slopes
=
alibi_slopes
,
softcap
=
logits_soft_cap
,
softcap
=
logits_soft_cap
,
).
squeeze
(
1
)
).
squeeze
(
1
)
...
...
vllm/attention/layer.py
View file @
4fa3e333
...
@@ -78,10 +78,9 @@ class Attention(nn.Module):
...
@@ -78,10 +78,9 @@ class Attention(nn.Module):
# During model initialization, the default dtype is set as the model
# During model initialization, the default dtype is set as the model
# weight and activation dtype.
# weight and activation dtype.
dtype
=
torch
.
get_default_dtype
()
dtype
=
torch
.
get_default_dtype
()
attn_backend
=
get_attn_backend
(
head_size
,
sliding_window
,
dtype
,
attn_backend
=
get_attn_backend
(
head_size
,
dtype
,
kv_cache_dtype
,
kv_cache_dtype
,
block_size
,
block_size
,
is_attention_free
,
is_attention_free
,
blocksparse_params
blocksparse_params
is
not
None
)
is
not
None
)
impl_cls
=
attn_backend
.
get_impl_cls
()
impl_cls
=
attn_backend
.
get_impl_cls
()
self
.
impl
=
impl_cls
(
num_heads
,
head_size
,
scale
,
num_kv_heads
,
self
.
impl
=
impl_cls
(
num_heads
,
head_size
,
scale
,
num_kv_heads
,
alibi_slopes
,
sliding_window
,
kv_cache_dtype
,
alibi_slopes
,
sliding_window
,
kv_cache_dtype
,
...
...
vllm/attention/selector.py
View file @
4fa3e333
...
@@ -90,7 +90,6 @@ def get_global_forced_attn_backend() -> Optional[_Backend]:
...
@@ -90,7 +90,6 @@ def get_global_forced_attn_backend() -> Optional[_Backend]:
@
lru_cache
(
maxsize
=
None
)
@
lru_cache
(
maxsize
=
None
)
def
get_attn_backend
(
def
get_attn_backend
(
head_size
:
int
,
head_size
:
int
,
sliding_window
:
Optional
[
int
],
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
kv_cache_dtype
:
Optional
[
str
],
kv_cache_dtype
:
Optional
[
str
],
block_size
:
int
,
block_size
:
int
,
...
@@ -105,8 +104,8 @@ def get_attn_backend(
...
@@ -105,8 +104,8 @@ def get_attn_backend(
BlocksparseFlashAttentionBackend
)
BlocksparseFlashAttentionBackend
)
return
BlocksparseFlashAttentionBackend
return
BlocksparseFlashAttentionBackend
backend
=
which_attn_to_use
(
head_size
,
sliding_window
,
dtyp
e
,
backend
=
which_attn_to_use
(
head_size
,
dtype
,
kv_cache_dtype
,
block_siz
e
,
kv_cache_dtype
,
block_size
,
is_attention_free
)
is_attention_free
)
if
backend
==
_Backend
.
FLASH_ATTN
:
if
backend
==
_Backend
.
FLASH_ATTN
:
from
vllm.attention.backends.flash_attn
import
(
# noqa: F401
from
vllm.attention.backends.flash_attn
import
(
# noqa: F401
FlashAttentionBackend
)
FlashAttentionBackend
)
...
@@ -155,7 +154,6 @@ def get_attn_backend(
...
@@ -155,7 +154,6 @@ def get_attn_backend(
def
which_attn_to_use
(
def
which_attn_to_use
(
head_size
:
int
,
head_size
:
int
,
sliding_window
:
Optional
[
int
],
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
kv_cache_dtype
:
Optional
[
str
],
kv_cache_dtype
:
Optional
[
str
],
block_size
:
int
,
block_size
:
int
,
...
@@ -243,10 +241,6 @@ def which_attn_to_use(
...
@@ -243,10 +241,6 @@ def which_attn_to_use(
"Cannot use FlashAttention-2 backend for block size not "
"Cannot use FlashAttention-2 backend for block size not "
"divisible by 16."
)
"divisible by 16."
)
selected_backend
=
_Backend
.
XFORMERS
selected_backend
=
_Backend
.
XFORMERS
elif
sliding_window
is
not
None
:
logger
.
info
(
"Cannot use FlashAttention-2 backend due to sliding window."
)
selected_backend
=
_Backend
.
XFORMERS
# FlashAttn is valid for the model, checking if the package is installed.
# FlashAttn is valid for the model, checking if the package is installed.
if
selected_backend
==
_Backend
.
FLASH_ATTN
:
if
selected_backend
==
_Backend
.
FLASH_ATTN
:
...
...
vllm/worker/cache_engine.py
View file @
4fa3e333
...
@@ -53,7 +53,6 @@ class CacheEngine:
...
@@ -53,7 +53,6 @@ class CacheEngine:
# Get attention backend.
# Get attention backend.
self
.
attn_backend
=
get_attn_backend
(
self
.
head_size
,
self
.
attn_backend
=
get_attn_backend
(
self
.
head_size
,
model_config
.
get_sliding_window
(),
model_config
.
dtype
,
model_config
.
dtype
,
cache_config
.
cache_dtype
,
cache_config
.
cache_dtype
,
self
.
block_size
,
self
.
block_size
,
...
...
vllm/worker/cpu_model_runner.py
View file @
4fa3e333
...
@@ -420,7 +420,6 @@ class CPUModelRunner(ModelRunnerBase[ModelInputForCPU]):
...
@@ -420,7 +420,6 @@ class CPUModelRunner(ModelRunnerBase[ModelInputForCPU]):
self
.
block_size
=
cache_config
.
block_size
self
.
block_size
=
cache_config
.
block_size
self
.
attn_backend
=
get_attn_backend
(
self
.
attn_backend
=
get_attn_backend
(
self
.
model_config
.
get_head_size
(),
self
.
model_config
.
get_head_size
(),
self
.
model_config
.
get_sliding_window
(),
self
.
model_config
.
dtype
,
self
.
model_config
.
dtype
,
self
.
kv_cache_dtype
,
self
.
kv_cache_dtype
,
self
.
block_size
,
self
.
block_size
,
...
...
vllm/worker/cpu_worker.py
View file @
4fa3e333
...
@@ -57,7 +57,6 @@ class CPUCacheEngine:
...
@@ -57,7 +57,6 @@ class CPUCacheEngine:
# Get attention backend.
# Get attention backend.
self
.
attn_backend
=
get_attn_backend
(
self
.
attn_backend
=
get_attn_backend
(
self
.
model_config
.
get_head_size
(),
self
.
model_config
.
get_head_size
(),
self
.
model_config
.
get_sliding_window
(),
self
.
model_config
.
dtype
,
self
.
model_config
.
dtype
,
cache_config
.
cache_dtype
,
cache_config
.
cache_dtype
,
self
.
block_size
,
self
.
block_size
,
...
...
vllm/worker/model_runner.py
View file @
4fa3e333
...
@@ -1011,7 +1011,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
...
@@ -1011,7 +1011,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self
.
attn_backend
=
get_attn_backend
(
self
.
attn_backend
=
get_attn_backend
(
self
.
model_config
.
get_head_size
(),
self
.
model_config
.
get_head_size
(),
self
.
model_config
.
get_sliding_window
(),
self
.
model_config
.
dtype
,
self
.
model_config
.
dtype
,
self
.
kv_cache_dtype
,
self
.
kv_cache_dtype
,
self
.
block_size
,
self
.
block_size
,
...
...
vllm/worker/openvino_model_runner.py
View file @
4fa3e333
...
@@ -75,7 +75,6 @@ class OpenVINOModelRunner:
...
@@ -75,7 +75,6 @@ class OpenVINOModelRunner:
self
.
attn_backend
=
get_attn_backend
(
self
.
attn_backend
=
get_attn_backend
(
self
.
model_config
.
get_head_size
(),
self
.
model_config
.
get_head_size
(),
self
.
model_config
.
get_sliding_window
(),
self
.
model_config
.
dtype
,
self
.
model_config
.
dtype
,
self
.
kv_cache_dtype
,
self
.
kv_cache_dtype
,
self
.
block_size
,
self
.
block_size
,
...
...
vllm/worker/openvino_worker.py
View file @
4fa3e333
...
@@ -71,7 +71,6 @@ class OpenVINOCacheEngine:
...
@@ -71,7 +71,6 @@ class OpenVINOCacheEngine:
# Get attention backend.
# Get attention backend.
self
.
attn_backend
=
get_attn_backend
(
self
.
attn_backend
=
get_attn_backend
(
self
.
head_size
,
self
.
head_size
,
self
.
model_config
.
get_sliding_window
(),
self
.
model_config
.
dtype
,
self
.
model_config
.
dtype
,
self
.
cache_config
.
cache_dtype
,
self
.
cache_config
.
cache_dtype
,
self
.
block_size
,
self
.
block_size
,
...
...
vllm/worker/tpu_model_runner.py
View file @
4fa3e333
...
@@ -114,7 +114,6 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
...
@@ -114,7 +114,6 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
dtype
=
np
.
int32
)
dtype
=
np
.
int32
)
self
.
attn_backend
=
get_attn_backend
(
self
.
attn_backend
=
get_attn_backend
(
self
.
model_config
.
get_head_size
(),
self
.
model_config
.
get_head_size
(),
self
.
model_config
.
get_sliding_window
(),
self
.
model_config
.
dtype
,
self
.
model_config
.
dtype
,
self
.
cache_config
.
cache_dtype
,
self
.
cache_config
.
cache_dtype
,
self
.
block_size
,
self
.
block_size
,
...
...
vllm/worker/xpu_model_runner.py
View file @
4fa3e333
...
@@ -374,7 +374,6 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
...
@@ -374,7 +374,6 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
self
.
attn_backend
=
get_attn_backend
(
self
.
attn_backend
=
get_attn_backend
(
self
.
model_config
.
get_head_size
(),
self
.
model_config
.
get_head_size
(),
self
.
model_config
.
get_sliding_window
(),
self
.
model_config
.
dtype
,
self
.
model_config
.
dtype
,
self
.
kv_cache_dtype
,
self
.
kv_cache_dtype
,
self
.
block_size
,
self
.
block_size
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment