Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1173804d
Unverified
Commit
1173804d
authored
Jun 16, 2025
by
Isotr0py
Committed by
GitHub
Jun 16, 2025
Browse files
[Bugfix] Fix TP inference for Flex attention backend (#19657)
Signed-off-by:
Isotr0py
<
2037008807@qq.com
>
parent
4d542402
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
54 additions
and
2 deletions
+54
-2
tests/v1/engine/test_engine_core.py
tests/v1/engine/test_engine_core.py
+35
-1
vllm/v1/attention/backends/flex_attention.py
vllm/v1/attention/backends/flex_attention.py
+7
-1
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+2
-0
vllm/v1/worker/gpu_worker.py
vllm/v1/worker/gpu_worker.py
+5
-0
vllm/v1/worker/tpu_worker.py
vllm/v1/worker/tpu_worker.py
+5
-0
No files found.
tests/v1/engine/test_engine_core.py
View file @
1173804d
...
@@ -19,7 +19,7 @@ from vllm.v1.executor.abstract import Executor, UniProcExecutor
...
@@ -19,7 +19,7 @@ from vllm.v1.executor.abstract import Executor, UniProcExecutor
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.v1.outputs
import
ModelRunnerOutput
from
...utils
import
create_new_process_for_each_test
from
...utils
import
create_new_process_for_each_test
,
multi_gpu_test
if
not
current_platform
.
is_cuda
():
if
not
current_platform
.
is_cuda
():
pytest
.
skip
(
reason
=
"V1 currently only supported on CUDA."
,
pytest
.
skip
(
reason
=
"V1 currently only supported on CUDA."
,
...
@@ -378,3 +378,37 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
...
@@ -378,3 +378,37 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
# Odd steps schedules a new batch.
# Odd steps schedules a new batch.
assert
output
is
None
assert
output
is
None
step
+=
1
step
+=
1
@
multi_gpu_test
(
num_gpus
=
2
)
def
test_engine_core_tp
(
monkeypatch
:
pytest
.
MonkeyPatch
):
"""
Test engine can initialize worker in tp properly
"""
with
monkeypatch
.
context
()
as
m
:
m
.
setenv
(
"VLLM_USE_V1"
,
"1"
)
"""Setup the EngineCore."""
engine_args
=
EngineArgs
(
model
=
MODEL_NAME
,
tensor_parallel_size
=
2
,
# Reduce startup time.
enforce_eager
=
True
,
)
vllm_config
=
engine_args
.
create_engine_config
()
executor_class
=
Executor
.
get_class
(
vllm_config
)
with
set_default_torch_num_threads
(
1
):
engine_core
=
EngineCore
(
vllm_config
=
vllm_config
,
executor_class
=
executor_class
,
log_stats
=
True
)
def
get_worker_cache_config_field
(
worker
,
key
:
str
):
return
getattr
(
worker
.
cache_config
,
key
)
num_gpu_blocks
=
engine_core
.
collective_rpc
(
get_worker_cache_config_field
,
args
=
(
"num_gpu_blocks"
,
))
num_cpu_blocks
=
engine_core
.
collective_rpc
(
get_worker_cache_config_field
,
args
=
(
"num_cpu_blocks"
,
))
assert
all
(
x
is
not
None
for
x
in
num_gpu_blocks
)
assert
all
(
x
is
not
None
for
x
in
num_cpu_blocks
)
vllm/v1/attention/backends/flex_attention.py
View file @
1173804d
...
@@ -13,6 +13,7 @@ from torch.nn.attention.flex_attention import (BlockMask, _mask_mod_signature,
...
@@ -13,6 +13,7 @@ from torch.nn.attention.flex_attention import (BlockMask, _mask_mod_signature,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
,
AttentionType
,
AttentionMetadata
,
AttentionType
,
is_quantized_kv_cache
)
is_quantized_kv_cache
)
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.v1.attention.backends.utils
import
(
AttentionMetadataBuilder
,
from
vllm.v1.attention.backends.utils
import
(
AttentionMetadataBuilder
,
...
@@ -236,7 +237,12 @@ class FlexAttentionMetadata:
...
@@ -236,7 +237,12 @@ class FlexAttentionMetadata:
def
build_block_mask
(
self
)
->
BlockMask
:
def
build_block_mask
(
self
)
->
BlockMask
:
assert
self
.
mask_mod
is
not
None
assert
self
.
mask_mod
is
not
None
return
create_block_mask_compiled
(
# FIXME: With TP>1, create_block_mask_compiled will raise
# CUDA error: an illegal memory access was encountered
create_block_mask_fn
=
(
create_block_mask_compiled
if
get_tensor_model_parallel_world_size
()
==
1
else
create_block_mask
)
return
create_block_mask_fn
(
self
.
mask_mod
,
self
.
mask_mod
,
None
,
None
,
None
,
None
,
...
...
vllm/v1/engine/core.py
View file @
1173804d
...
@@ -84,6 +84,8 @@ class EngineCore:
...
@@ -84,6 +84,8 @@ class EngineCore:
vllm_config
.
cache_config
.
num_gpu_blocks
=
num_gpu_blocks
vllm_config
.
cache_config
.
num_gpu_blocks
=
num_gpu_blocks
vllm_config
.
cache_config
.
num_cpu_blocks
=
num_cpu_blocks
vllm_config
.
cache_config
.
num_cpu_blocks
=
num_cpu_blocks
self
.
collective_rpc
(
"initialize_cache"
,
args
=
(
num_gpu_blocks
,
num_cpu_blocks
))
self
.
structured_output_manager
=
StructuredOutputManager
(
vllm_config
)
self
.
structured_output_manager
=
StructuredOutputManager
(
vllm_config
)
...
...
vllm/v1/worker/gpu_worker.py
View file @
1173804d
...
@@ -112,6 +112,11 @@ class Worker(WorkerBase):
...
@@ -112,6 +112,11 @@ class Worker(WorkerBase):
buffer
.
data
.
copy_
(
self
.
_sleep_saved_buffers
[
name
].
data
)
buffer
.
data
.
copy_
(
self
.
_sleep_saved_buffers
[
name
].
data
)
self
.
_sleep_saved_buffers
=
{}
self
.
_sleep_saved_buffers
=
{}
def
initialize_cache
(
self
,
num_gpu_blocks
:
int
,
num_cpu_blocks
:
int
)
->
None
:
self
.
cache_config
.
num_gpu_blocks
=
num_gpu_blocks
self
.
cache_config
.
num_cpu_blocks
=
num_cpu_blocks
def
init_device
(
self
):
def
init_device
(
self
):
if
self
.
device_config
.
device
.
type
==
"cuda"
:
if
self
.
device_config
.
device
.
type
==
"cuda"
:
# torch.distributed.all_reduce does not free the input tensor until
# torch.distributed.all_reduce does not free the input tensor until
...
...
vllm/v1/worker/tpu_worker.py
View file @
1173804d
...
@@ -93,6 +93,11 @@ class TPUWorker:
...
@@ -93,6 +93,11 @@ class TPUWorker:
if
self
.
model_config
.
seed
is
None
:
if
self
.
model_config
.
seed
is
None
:
self
.
model_config
.
seed
=
0
self
.
model_config
.
seed
=
0
def
initialize_cache
(
self
,
num_gpu_blocks
:
int
,
num_cpu_blocks
:
int
)
->
None
:
self
.
cache_config
.
num_gpu_blocks
=
num_gpu_blocks
self
.
cache_config
.
num_cpu_blocks
=
num_cpu_blocks
def
init_device
(
self
):
def
init_device
(
self
):
os
.
environ
[
"PJRT_DEVICE"
]
=
"TPU"
os
.
environ
[
"PJRT_DEVICE"
]
=
"TPU"
# Note: Currently the XLA compiler wrongly uses 2D ring strategy on 1D
# Note: Currently the XLA compiler wrongly uses 2D ring strategy on 1D
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment