Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ddfed314
Unverified
Commit
ddfed314
authored
Jun 16, 2025
by
Driss Guessous
Committed by
GitHub
Jun 17, 2025
Browse files
Fixes IMA for TP w/ flex-attention (#19712)
Signed-off-by:
drisspg
<
drisspguessous@gmail.com
>
parent
5b3ad5ec
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
2 additions
and
10 deletions
+2
-10
tests/kernels/test_flex_attention.py
tests/kernels/test_flex_attention.py
+0
-2
vllm/v1/attention/backends/flex_attention.py
vllm/v1/attention/backends/flex_attention.py
+2
-8
No files found.
tests/kernels/test_flex_attention.py
View file @
ddfed314
...
...
@@ -51,7 +51,6 @@ def test_flex_attention_vs_default_backend(monkeypatch):
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
m
.
setenv
(
"VLLM_ATTENTION_BACKEND"
,
"FLEX_ATTENTION"
)
m
.
setenv
(
"VLLM_ENABLE_V1_MULTIPROCESSING"
,
"0"
)
set_seed
(
seed
)
...
...
@@ -66,7 +65,6 @@ def test_flex_attention_vs_default_backend(monkeypatch):
# Run with default backend
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
m
.
setenv
(
"VLLM_ENABLE_V1_MULTIPROCESSING"
,
"0"
)
set_seed
(
seed
)
llm_default
=
LLM
(
model_name
,
...
...
vllm/v1/attention/backends/flex_attention.py
View file @
ddfed314
...
...
@@ -13,7 +13,6 @@ from torch.nn.attention.flex_attention import (BlockMask, _mask_mod_signature,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
,
AttentionType
,
is_quantized_kv_cache
)
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.v1.attention.backends.utils
import
(
AttentionMetadataBuilder
,
...
...
@@ -237,17 +236,13 @@ class FlexAttentionMetadata:
def
build_block_mask
(
self
)
->
BlockMask
:
assert
self
.
mask_mod
is
not
None
# FIXME: With TP>1, create_block_mask_compiled will raise
# CUDA error: an illegal memory access was encountered
create_block_mask_fn
=
(
create_block_mask_compiled
if
get_tensor_model_parallel_world_size
()
==
1
else
create_block_mask
)
return
create_block_mask_fn
(
return
create_block_mask_compiled
(
self
.
mask_mod
,
None
,
None
,
self
.
num_actual_tokens
,
self
.
total_cache_tokens
,
device
=
self
.
block_table
.
device
,
)
def
__post_init__
(
self
):
...
...
@@ -429,7 +424,6 @@ class FlexAttentionImpl(AttentionImpl):
shape = [num_tokens, num_heads * head_size]
"""
assert
output
is
not
None
,
"Output tensor must be provided."
if
output_scale
is
not
None
:
raise
NotImplementedError
(
"fused output quantization is not yet supported"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment