Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
cf56cf78
Unverified
Commit
cf56cf78
authored
Sep 21, 2025
by
Isotr0py
Committed by
GitHub
Sep 21, 2025
Browse files
[V1] Add sliding window support to Flex Attention backend (#24089)
Signed-off-by:
Isotr0py
<
mozf@mail2.sysu.edu.cn
>
parent
7ed82d19
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
229 additions
and
69 deletions
+229
-69
tests/v1/attention/test_attention_backends.py
tests/v1/attention/test_attention_backends.py
+157
-51
vllm/v1/attention/backends/flex_attention.py
vllm/v1/attention/backends/flex_attention.py
+72
-18
No files found.
tests/v1/attention/test_attention_backends.py
View file @
cf56cf78
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for v1 attention backends without GPUModelRunner dependency."""
"""Tests for v1 attention backends without GPUModelRunner dependency."""
from
functools
import
partial
from
typing
import
Optional
,
Union
import
pytest
import
pytest
import
torch
import
torch
from
torch.nn.attention.flex_attention
import
create_block_mask
,
flex_attention
from
tests.v1.attention.utils
import
(
BatchSpec
,
_Backend
,
from
tests.v1.attention.utils
import
(
BatchSpec
,
_Backend
,
create_common_attn_metadata
,
create_common_attn_metadata
,
create_standard_kv_cache_spec
,
create_standard_kv_cache_spec
,
create_vllm_config
,
create_vllm_config
,
get_attention_backend
)
get_attention_backend
)
from
vllm.config
import
ModelConfig
from
vllm.platforms
import
current_platform
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
,
cdiv
,
is_torch_equal_or_newer
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
,
cdiv
,
is_torch_equal_or_newer
from
vllm.v1.attention.backends.utils
import
(
CommonAttentionMetadata
,
from
vllm.v1.attention.backends.utils
import
(
CommonAttentionMetadata
,
set_kv_cache_layout
)
set_kv_cache_layout
)
...
@@ -183,13 +188,19 @@ class MockAttentionLayer:
...
@@ -183,13 +188,19 @@ class MockAttentionLayer:
self
.
_v_scale_float
=
1.0
self
.
_v_scale_float
=
1.0
def
run_attention_backend
(
backend
:
_Backend
,
kv_cache_spec
:
FullAttentionSpec
,
def
run_attention_backend
(
layer_names
:
list
[
str
],
vllm_config
,
backend
:
_Backend
,
device
:
torch
.
device
,
kv_cache_spec
:
FullAttentionSpec
,
common_attn_metadata
:
CommonAttentionMetadata
,
layer_names
:
list
[
str
],
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
vllm_config
,
value
:
torch
.
Tensor
,
device
:
torch
.
device
,
kv_cache
:
torch
.
Tensor
)
->
torch
.
Tensor
:
common_attn_metadata
:
CommonAttentionMetadata
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
sliding_window
:
Optional
[
int
]
=
None
,
)
->
torch
.
Tensor
:
"""Run attention computation using the specified backend's AttentionImpl."""
"""Run attention computation using the specified backend's AttentionImpl."""
# Handle special case for FLEX_ATTENTION_SLOW
# Handle special case for FLEX_ATTENTION_SLOW
...
@@ -253,7 +264,7 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec,
...
@@ -253,7 +264,7 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec,
scale
=
scale
,
scale
=
scale
,
num_kv_heads
=
num_kv_heads
,
num_kv_heads
=
num_kv_heads
,
alibi_slopes
=
None
,
alibi_slopes
=
None
,
sliding_window
=
None
,
sliding_window
=
sliding_window
,
kv_cache_dtype
=
"auto"
,
kv_cache_dtype
=
"auto"
,
)
)
...
@@ -275,13 +286,16 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec,
...
@@ -275,13 +286,16 @@ def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec,
return
output
return
output
@
pytest
.
mark
.
parametrize
(
"batch_spec_name"
,
[
def
_test_backend_correctness
(
"small_decode"
,
"small_prefill"
,
"mixed_small"
,
"medium_decode"
,
batch_spec
:
BatchSpec
,
"medium_prefill"
,
"mixed_medium"
,
"large_decode"
,
"large_prefill"
,
model
:
str
,
"single_decode"
,
"single_prefill"
backend_to_test
:
list
[
Union
[
_Backend
,
str
]],
])
mask_mod
,
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"meta-llama/Meta-Llama-3-8B"
])
*
,
def
test_backend_correctness
(
batch_spec_name
:
str
,
model
:
str
):
block_size
:
int
=
16
,
atol
:
float
=
1e-2
,
rtol
:
float
=
1e-2
,
):
"""
"""
Test that all backends produce similar outputs to a reference implementation
Test that all backends produce similar outputs to a reference implementation
using torch.nn.functional.scaled_dot_product_attention.
using torch.nn.functional.scaled_dot_product_attention.
...
@@ -297,9 +311,10 @@ def test_backend_correctness(batch_spec_name: str, model: str):
...
@@ -297,9 +311,10 @@ def test_backend_correctness(batch_spec_name: str, model: str):
simulated paged KV cache.
simulated paged KV cache.
5. Comparing the vLLM backend's output to the ground-truth SDPA output.
5. Comparing the vLLM backend's output to the ground-truth SDPA output.
"""
"""
batch_spec
=
BATCH_SPECS
[
batch_spec_name
]
current_platform
.
seed_everything
(
42
)
vllm_config
=
create_vllm_config
(
model_name
=
model
,
vllm_config
=
create_vllm_config
(
model_name
=
model
,
max_model_len
=
max
(
batch_spec
.
seq_lens
),
max_model_len
=
max
(
batch_spec
.
seq_lens
),
block_size
=
block_size
,
num_gpu_blocks
=
8192
)
num_gpu_blocks
=
8192
)
device
=
torch
.
device
(
"cuda:0"
)
device
=
torch
.
device
(
"cuda:0"
)
...
@@ -314,6 +329,7 @@ def test_backend_correctness(batch_spec_name: str, model: str):
...
@@ -314,6 +329,7 @@ def test_backend_correctness(batch_spec_name: str, model: str):
num_kv_heads
=
vllm_config
.
model_config
.
get_num_kv_heads
(
num_kv_heads
=
vllm_config
.
model_config
.
get_num_kv_heads
(
vllm_config
.
parallel_config
)
vllm_config
.
parallel_config
)
head_size
=
vllm_config
.
model_config
.
get_head_size
()
head_size
=
vllm_config
.
model_config
.
get_head_size
()
sliding_window
=
vllm_config
.
model_config
.
get_sliding_window
()
dtype
=
_convert_dtype_to_torch
(
vllm_config
.
model_config
.
dtype
)
dtype
=
_convert_dtype_to_torch
(
vllm_config
.
model_config
.
dtype
)
block_size
=
vllm_config
.
cache_config
.
block_size
block_size
=
vllm_config
.
cache_config
.
block_size
scale
=
1.0
/
(
head_size
**
0.5
)
scale
=
1.0
/
(
head_size
**
0.5
)
...
@@ -361,22 +377,21 @@ def test_backend_correctness(batch_spec_name: str, model: str):
...
@@ -361,22 +377,21 @@ def test_backend_correctness(batch_spec_name: str, model: str):
# Create causal mask: query token i attends to positions 0 to
# Create causal mask: query token i attends to positions 0 to
# (context_len + i)
# (context_len + i)
kv_len
=
s_len
kv_len
=
s_len
offset
=
context_len
attn_mask
=
torch
.
full
((
q_len
,
kv_len
),
final_mask_mod
=
partial
(
mask_mod
,
context_len
=
context_len
)
float
(
'-inf'
),
block_mask
=
create_block_mask
(
final_mask_mod
,
device
=
device
,
B
=
None
,
dtype
=
dtype
)
H
=
None
,
for
i
in
range
(
q_len
):
Q_LEN
=
q_len
,
attn_mask
[
i
,
:
offset
+
i
+
1
]
=
0.0
KV_LEN
=
kv_len
,
device
=
device
)
sdpa_out_i
=
torch
.
nn
.
functional
.
scaled_dot_product_attention
(
sdpa_out_i
=
flex_attention
(
q_sdpa_in
,
q_sdpa_in
,
k_sdpa_in
,
k_sdpa_in
,
v_sdpa_in
,
v_sdpa_in
,
block_mask
=
block_mask
,
attn_mask
=
attn_mask
,
scale
=
scale
,
scale
=
scale
,
enable_gqa
=
True
)
enable_gqa
=
True
)
# Convert back to (L, H, D)
all_sdpa_outputs
.
append
(
sdpa_out_i
.
transpose
(
1
,
2
).
squeeze
(
0
))
all_sdpa_outputs
.
append
(
sdpa_out_i
.
transpose
(
1
,
2
).
squeeze
(
0
))
# Inputs for vLLM backends are just the new tokens
# Inputs for vLLM backends are just the new tokens
...
@@ -412,7 +427,7 @@ def test_backend_correctness(batch_spec_name: str, model: str):
...
@@ -412,7 +427,7 @@ def test_backend_correctness(batch_spec_name: str, model: str):
# 4. Run vLLM backends and compare
# 4. Run vLLM backends and compare
# Note: flex_attention has known Triton kernel compatibility issues
# Note: flex_attention has known Triton kernel compatibility issues
# with test infrastructures
# with test infrastructures
for
backend_name
in
BACKENDS_TO_TEST
:
for
backend_name
in
backend_to_test
:
# FlashAttentionm + FlexAttention:
# FlashAttentionm + FlexAttention:
# [2, num_blocks, block_size, num_kv_heads, head_size]
# [2, num_blocks, block_size, num_kv_heads, head_size]
# FlashInfer:
# FlashInfer:
...
@@ -427,12 +442,19 @@ def test_backend_correctness(batch_spec_name: str, model: str):
...
@@ -427,12 +442,19 @@ def test_backend_correctness(batch_spec_name: str, model: str):
2
,
3
).
contiguous
().
transpose
(
2
,
3
)
2
,
3
).
contiguous
().
transpose
(
2
,
3
)
set_kv_cache_layout
(
"HND"
)
set_kv_cache_layout
(
"HND"
)
backend_output
=
run_attention_backend
(
backend_name
,
kv_cache_spec
,
backend_output
=
run_attention_backend
(
[
"placeholder"
],
vllm_config
,
backend_name
,
device
,
common_attn_metadata
,
kv_cache_spec
,
query_vllm
,
key_vllm
,
[
"placeholder"
],
value_vllm
,
vllm_config
,
kv_cache_for_backend
)
device
,
common_attn_metadata
,
query_vllm
,
key_vllm
,
value_vllm
,
kv_cache_for_backend
,
sliding_window
=
sliding_window
,
)
# Check shape and dtype consistency
# Check shape and dtype consistency
assert
backend_output
.
shape
==
sdpa_output
.
shape
,
(
assert
backend_output
.
shape
==
sdpa_output
.
shape
,
(
...
@@ -446,18 +468,102 @@ def test_backend_correctness(batch_spec_name: str, model: str):
...
@@ -446,18 +468,102 @@ def test_backend_correctness(batch_spec_name: str, model: str):
f
"[
{
backend_name
}
] produced non-finite values"
)
f
"[
{
backend_name
}
] produced non-finite values"
)
# Check numerical similarity
# Check numerical similarity
rtol
=
1e-2
def
error_msg
(
msg
:
str
,
backend_name
:
str
):
atol
=
5e-3
return
(
f
"[
{
backend_name
}
] output differs from SDPA baseline. "
f
"
{
msg
}
"
)
max_diff
=
torch
.
max
(
torch
.
abs
(
backend_output
-
sdpa_output
)).
item
()
max_rel_diff
=
torch
.
max
(
torch
.
testing
.
assert_close
(
backend_output
,
torch
.
abs
(
backend_output
-
sdpa_output
)
/
torch
.
abs
(
sdpa_output
)).
item
()
all_close
=
torch
.
allclose
(
backend_output
,
sdpa_output
,
sdpa_output
,
rtol
=
rtol
,
rtol
=
rtol
,
atol
=
atol
)
atol
=
atol
,
msg
=
partial
(
error_msg
,
backend_name
=
backend_name
))
assert
all_close
,
(
f
"[
{
backend_name
}
] output differs from SDPA baseline. "
@
pytest
.
mark
.
parametrize
(
"batch_spec_name"
,
[
f
"Max diff:
{
max_diff
:.
6
f
}
, max rel diff:
{
max_rel_diff
:.
6
f
}
)"
)
"small_decode"
,
"small_prefill"
,
"mixed_small"
,
"medium_decode"
,
\ No newline at end of file
"medium_prefill"
,
"mixed_medium"
,
"large_decode"
,
"large_prefill"
,
"single_decode"
,
"single_prefill"
])
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"meta-llama/Meta-Llama-3-8B"
])
def
test_causal_backend_correctness
(
batch_spec_name
:
str
,
model
:
str
):
"""Test backend's correctness with causal attention."""
def
causal_mask_mod
(
b
:
torch
.
Tensor
,
h
:
torch
.
Tensor
,
q_idx
:
torch
.
Tensor
,
kv_idx
:
torch
.
Tensor
,
*
,
context_len
:
int
,
):
return
(
q_idx
+
context_len
)
>=
kv_idx
batch_spec
=
BATCH_SPECS
[
batch_spec_name
]
LARGE_BLOCK_BACKENDS
=
([
_Backend
.
FLEX_ATTENTION
]
if
is_torch_equal_or_newer
(
"2.9.0.dev0"
)
else
[])
SMALL_BLOCK_BACKENDS
=
[
x
for
x
in
BACKENDS_TO_TEST
if
x
not
in
LARGE_BLOCK_BACKENDS
]
_test_backend_correctness
(
batch_spec
,
model
,
SMALL_BLOCK_BACKENDS
,
causal_mask_mod
)
# Fast FlexAttention needs to run with block_size=128
if
LARGE_BLOCK_BACKENDS
:
_test_backend_correctness
(
batch_spec
,
model
,
LARGE_BLOCK_BACKENDS
,
causal_mask_mod
,
block_size
=
128
)
SLIDING_WINDOW_BACKENDS_TO_TEST
=
[
_Backend
.
FLASH_ATTN_VLLM_V1
,
_Backend
.
FLEX_ATTENTION
,
_Backend
.
TRITON_ATTN_VLLM_V1
,
"FLEX_ATTENTION_SLOW"
]
@
pytest
.
mark
.
parametrize
(
"batch_spec_name"
,
[
"small_decode"
,
"small_prefill"
,
"mixed_medium"
,
"large_decode"
,
"large_prefill"
])
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"microsoft/Phi-tiny-MoE-instruct"
])
def
test_sliding_window_backend_correctness
(
batch_spec_name
:
str
,
model
:
str
):
"""Test backend's correctness with sliding window attention."""
def
sliding_window_mask_mod
(
b
:
torch
.
Tensor
,
h
:
torch
.
Tensor
,
q_idx
:
torch
.
Tensor
,
kv_idx
:
torch
.
Tensor
,
*
,
context_len
:
int
,
sliding_window
:
int
,
):
causal_mask
=
q_idx
+
context_len
>=
kv_idx
window_mask
=
q_idx
+
context_len
-
kv_idx
<
sliding_window
return
causal_mask
&
window_mask
batch_spec
=
BATCH_SPECS
[
batch_spec_name
]
model_config
=
ModelConfig
(
model
=
model
,
max_model_len
=
max
(
batch_spec
.
seq_lens
))
sliding_window
=
model_config
.
get_sliding_window
()
sliding_window_mask_mod_fn
=
partial
(
sliding_window_mask_mod
,
sliding_window
=
sliding_window
)
LARGE_BLOCK_BACKENDS
=
([
_Backend
.
FLEX_ATTENTION
]
if
is_torch_equal_or_newer
(
"2.9.0.dev0"
)
else
[])
SMALL_BLOCK_BACKENDS
=
[
x
for
x
in
SLIDING_WINDOW_BACKENDS_TO_TEST
if
x
not
in
LARGE_BLOCK_BACKENDS
]
_test_backend_correctness
(
batch_spec
,
model
,
SMALL_BLOCK_BACKENDS
,
sliding_window_mask_mod_fn
)
# Fast FlexAttention needs to run with block_size=128
if
LARGE_BLOCK_BACKENDS
:
_test_backend_correctness
(
batch_spec
,
model
,
LARGE_BLOCK_BACKENDS
,
sliding_window_mask_mod_fn
,
block_size
=
128
)
vllm/v1/attention/backends/flex_attention.py
View file @
cf56cf78
...
@@ -9,7 +9,7 @@ import torch
...
@@ -9,7 +9,7 @@ import torch
import
torch._dynamo.decorators
import
torch._dynamo.decorators
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
torch.nn.attention.flex_attention
import
(
BlockMask
,
_mask_mod_signature
,
from
torch.nn.attention.flex_attention
import
(
BlockMask
,
_mask_mod_signature
,
_score_mod_signature
,
_score_mod_signature
,
and_masks
,
create_block_mask
,
create_block_mask
,
flex_attention
)
flex_attention
)
...
@@ -292,6 +292,7 @@ class FlexAttentionMetadata:
...
@@ -292,6 +292,7 @@ class FlexAttentionMetadata:
q_block_size
:
int
=
16
q_block_size
:
int
=
16
kv_block_size
:
int
=
16
kv_block_size
:
int
=
16
transformed_score_mod
:
Optional
[
_score_mod_signature
]
=
None
transformed_score_mod
:
Optional
[
_score_mod_signature
]
=
None
sliding_window
:
Optional
[
int
]
=
None
def
_convert_physical_to_logical
(
def
_convert_physical_to_logical
(
self
,
self
,
...
@@ -380,6 +381,53 @@ class FlexAttentionMetadata:
...
@@ -380,6 +381,53 @@ class FlexAttentionMetadata:
return
final_mask_mod
return
final_mask_mod
def
get_sliding_window_mask_mod
(
self
)
->
_mask_mod_signature
:
"""Creates the sliding window mask_mod function for FlexAttention.
Note that the sliding window mask here is bidirectional, we need
to mask it with the bidirectional/causal mask for encoder/decoder.
"""
if
self
.
sliding_window
is
None
:
raise
ValueError
(
"sliding_window must be set for sliding window attention"
)
def
sliding_window_mask_mod
(
b
:
torch
.
Tensor
,
h
:
torch
.
Tensor
,
q_idx
:
torch
.
Tensor
,
kv_idx
:
torch
.
Tensor
):
return
torch
.
abs
(
q_idx
-
kv_idx
)
<
self
.
sliding_window
def
final_mask_mod
(
b
:
torch
.
Tensor
,
h
:
torch
.
Tensor
,
q_idx
:
torch
.
Tensor
,
physical_kv_idx
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
(
is_valid
,
logical_q_idx
,
logical_kv_idx
)
=
self
.
_convert_physical_to_logical
(
self
.
doc_ids
,
q_idx
,
physical_kv_idx
)
return
torch
.
where
(
is_valid
,
sliding_window_mask_mod
(
b
,
h
,
logical_q_idx
,
logical_kv_idx
),
False
,
)
return
final_mask_mod
if
self
.
causal
else
sliding_window_mask_mod
def
get_mask_mod
(
self
):
# Stage-1: initialize the base mask_mod
# (causal mask for decoder or bidirectional mask for encoder)
if
self
.
causal
:
mask_mod
=
self
.
get_causal_mask_mod
()
else
:
mask_mod
=
self
.
get_bidirectional_mask_mod
()
# stage-2: add external mask_mod for special attention during
# forwarding runtime to create the combined mask_mod.
if
self
.
sliding_window
is
not
None
:
# Add sliding window mask for sliding window attention
sliding_window_mask_mod
=
self
.
get_sliding_window_mask_mod
()
mask_mod
=
and_masks
(
mask_mod
,
sliding_window_mask_mod
)
return
mask_mod
def
get_transformed_score_mod
(
self
)
->
Optional
[
_score_mod_signature
]:
def
get_transformed_score_mod
(
self
)
->
Optional
[
_score_mod_signature
]:
"""Creates the transformed score_mod function for FlexAttention.
"""Creates the transformed score_mod function for FlexAttention.
...
@@ -472,12 +520,9 @@ class FlexAttentionMetadata:
...
@@ -472,12 +520,9 @@ class FlexAttentionMetadata:
return
BlockMask
.
from_kv_blocks
(
**
block_mask_kwargs
)
return
BlockMask
.
from_kv_blocks
(
**
block_mask_kwargs
)
def
build_block_mask
(
self
)
->
BlockMask
:
def
build_block_mask
(
self
)
->
BlockMask
:
if
self
.
causal
:
mask_mod
=
self
.
get_mask_mod
()
mask_mod
=
self
.
get_causal_mask_mod
()
kv_len
=
(
self
.
total_cache_tokens
kv_len
=
self
.
total_cache_tokens
if
self
.
causal
else
self
.
num_actual_tokens
)
else
:
mask_mod
=
self
.
get_bidirectional_mask_mod
()
kv_len
=
self
.
num_actual_tokens
return
create_block_mask_compiled
(
return
create_block_mask_compiled
(
mask_mod
,
mask_mod
,
None
,
None
,
...
@@ -498,11 +543,7 @@ class FlexAttentionMetadata:
...
@@ -498,11 +543,7 @@ class FlexAttentionMetadata:
self
.
doc_ids
=
_offsets_to_doc_ids_tensor
(
self
.
query_start_loc
)
self
.
doc_ids
=
_offsets_to_doc_ids_tensor
(
self
.
query_start_loc
)
self
.
num_blocks
=
self
.
total_cache_tokens
//
self
.
block_size
self
.
num_blocks
=
self
.
total_cache_tokens
//
self
.
block_size
if
self
.
causal
:
self
.
mask_mod
=
self
.
get_mask_mod
()
self
.
mask_mod
=
self
.
get_causal_mask_mod
()
else
:
self
.
mask_mod
=
self
.
get_bidirectional_mask_mod
()
self
.
transformed_score_mod
=
self
.
get_transformed_score_mod
()
self
.
transformed_score_mod
=
self
.
get_transformed_score_mod
()
if
self
.
direct_build
and
self
.
causal
:
if
self
.
direct_build
and
self
.
causal
:
...
@@ -607,7 +648,7 @@ class FlexAttentionMetadataBuilder(
...
@@ -607,7 +648,7 @@ class FlexAttentionMetadataBuilder(
class
FlexAttentionImpl
(
AttentionImpl
):
class
FlexAttentionImpl
(
AttentionImpl
):
sliding_window
:
Optional
[
tuple
[
int
,
int
]
]
sliding_window
:
Optional
[
int
]
alibi_slopes
:
Optional
[
torch
.
Tensor
]
alibi_slopes
:
Optional
[
torch
.
Tensor
]
logits_soft_cap
:
Optional
[
float
]
logits_soft_cap
:
Optional
[
float
]
...
@@ -641,11 +682,9 @@ class FlexAttentionImpl(AttentionImpl):
...
@@ -641,11 +682,9 @@ class FlexAttentionImpl(AttentionImpl):
"FlexAttention does not support alibi slopes yet."
)
"FlexAttention does not support alibi slopes yet."
)
else
:
else
:
self
.
alibi_slopes
=
None
self
.
alibi_slopes
=
None
if
sliding_window
is
not
None
:
raise
NotImplementedError
(
self
.
sliding_window
=
sliding_window
"FlexAttention does not support sliding window yet."
)
else
:
self
.
sliding_window
=
(
-
1
,
-
1
)
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
logits_soft_cap
=
logits_soft_cap
self
.
logits_soft_cap
=
logits_soft_cap
if
self
.
logits_soft_cap
is
not
None
:
if
self
.
logits_soft_cap
is
not
None
:
...
@@ -712,6 +751,21 @@ class FlexAttentionImpl(AttentionImpl):
...
@@ -712,6 +751,21 @@ class FlexAttentionImpl(AttentionImpl):
num_actual_tokens
=
attn_metadata
.
num_actual_tokens
num_actual_tokens
=
attn_metadata
.
num_actual_tokens
if
attn_metadata
.
sliding_window
!=
self
.
sliding_window
:
attn_metadata
.
sliding_window
=
self
.
sliding_window
if
attn_metadata
.
direct_build
:
# TODO: Support skipping the computation of sliding window
# in direct block mask building code path.
logger
.
warning_once
(
"Using direct block mask building with sliding window, "
"which is suboptimal now. Performance may be degraded."
)
# update mask mod in attention metadata
attn_metadata
.
mask_mod
=
attn_metadata
.
get_mask_mod
()
attn_metadata
.
block_mask
=
(
attn_metadata
.
_build_block_mask_direct
())
else
:
attn_metadata
.
block_mask
=
attn_metadata
.
build_block_mask
()
if
not
attn_metadata
.
causal
:
if
not
attn_metadata
.
causal
:
assert
self
.
attn_type
==
AttentionType
.
ENCODER_ONLY
assert
self
.
attn_type
==
AttentionType
.
ENCODER_ONLY
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment