Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9ef9173c
"vscode:/vscode.git/clone" did not exist on "18e85452979d2f974f2c193d159816a893fbc253"
Unverified
Commit
9ef9173c
authored
Jun 05, 2025
by
Nicolò Lucchesi
Committed by
GitHub
Jun 05, 2025
Browse files
[P/D][NixlConnector] Enable FlashInfer backend (#19090)
parent
85e2b7bb
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
51 additions
and
15 deletions
+51
-15
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
...distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+50
-15
vllm/platforms/interface.py
vllm/platforms/interface.py
+1
-0
No files found.
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
View file @
9ef9173c
...
@@ -15,6 +15,7 @@ import torch
...
@@ -15,6 +15,7 @@ import torch
import
zmq
import
zmq
from
vllm
import
envs
from
vllm
import
envs
from
vllm.attention.selector
import
backend_name_to_enum
,
get_attn_backend
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.distributed.kv_transfer.kv_connector.v1.base
import
(
from
vllm.distributed.kv_transfer.kv_connector.v1.base
import
(
KVConnectorBase_V1
,
KVConnectorMetadata
,
KVConnectorRole
)
KVConnectorBase_V1
,
KVConnectorMetadata
,
KVConnectorRole
)
...
@@ -22,6 +23,7 @@ from vllm.distributed.parallel_state import (
...
@@ -22,6 +23,7 @@ from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
get_tp_group
)
get_tp_group
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.platforms
import
_Backend
from
vllm.utils
import
make_zmq_path
,
make_zmq_socket
,
round_down
from
vllm.utils
import
make_zmq_path
,
make_zmq_socket
,
round_down
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.request
import
RequestStatus
from
vllm.v1.request
import
RequestStatus
...
@@ -57,6 +59,7 @@ class NixlAgentMetadata(
...
@@ -57,6 +59,7 @@ class NixlAgentMetadata(
num_blocks
:
int
num_blocks
:
int
tp_size
:
int
tp_size
:
int
block_len
:
int
block_len
:
int
attn_backend_name
:
str
@
dataclass
@
dataclass
...
@@ -384,11 +387,25 @@ class NixlConnectorWorker:
...
@@ -384,11 +387,25 @@ class NixlConnectorWorker:
self
.
vllm_config
=
vllm_config
self
.
vllm_config
=
vllm_config
self
.
block_size
=
vllm_config
.
cache_config
.
block_size
self
.
block_size
=
vllm_config
.
cache_config
.
block_size
self
.
model_config
=
vllm_config
.
model_config
self
.
cache_config
=
vllm_config
.
cache_config
# TODO(mgoin): remove this once we have hybrid memory allocator
# TODO(mgoin): remove this once we have hybrid memory allocator
# Optimization for models with local attention (Llama 4)
# Optimization for models with local attention (Llama 4)
# List of block window sizes for each layer for local attention
# List of block window sizes for each layer for local attention
self
.
block_window_per_layer
:
list
[
Optional
[
int
]]
=
[]
self
.
block_window_per_layer
:
list
[
Optional
[
int
]]
=
[]
self
.
use_mla
=
self
.
model_config
.
use_mla
backend
=
get_attn_backend
(
self
.
model_config
.
get_head_size
(),
self
.
model_config
.
dtype
,
self
.
cache_config
.
cache_dtype
,
self
.
block_size
,
self
.
model_config
.
is_attention_free
,
use_mla
=
self
.
use_mla
)
self
.
backend_name
=
backend
.
get_name
()
attn_backend
=
backend_name_to_enum
(
self
.
backend_name
)
self
.
_use_flashinfer
=
attn_backend
==
_Backend
.
FLASHINFER_VLLM_V1
logger
.
debug
(
"Detected attention backend %s"
,
self
.
backend_name
)
self
.
_tp_size
:
dict
[
str
,
int
]
=
{
self
.
engine_id
:
self
.
world_size
}
self
.
_tp_size
:
dict
[
str
,
int
]
=
{
self
.
engine_id
:
self
.
world_size
}
# With heterogeneous TP, P must wait for all assigned D TP workers to
# With heterogeneous TP, P must wait for all assigned D TP workers to
...
@@ -472,12 +489,16 @@ class NixlConnectorWorker:
...
@@ -472,12 +489,16 @@ class NixlConnectorWorker:
kv_elem_size
=
first_kv_cache
.
element_size
()
kv_elem_size
=
first_kv_cache
.
element_size
()
# TODO(tms): Find a more robust way to detect and handle MLA
# TODO(tms): Find a more robust way to detect and handle MLA
self
.
use_mla
=
len
(
first_kv_cache
.
shape
)
==
3
# NOTE (NickLucche) To move blocks efficiently with NIXL, the expected
# NOTE (NickLucche) To move blocks efficiently with NIXL, the expected
# KV memory layout is HND, as opposed to the default NHD. Note that it
# KV memory layout is HND, as opposed to the default NHD. Note that it
# will only affect the strides. For MLA instead, we require no
# will only affect the strides. For MLA instead, we require no
# such thing and resort to the standard layout.
# such thing and resort to the standard layout.
if
self
.
use_mla
:
use_mla
=
len
(
first_kv_cache
.
shape
)
==
3
assert
use_mla
==
self
.
use_mla
# TODO (NickLucche) not compatible with hybrid allocator. Enforce check
# once it goes live, as a single kv layout is expected for xfers.
if
use_mla
:
# MLA case.
# MLA case.
self
.
num_blocks
=
first_kv_cache
.
shape
[
0
]
self
.
num_blocks
=
first_kv_cache
.
shape
[
0
]
block_rank
=
2
# [block_size, latent_dim]
block_rank
=
2
# [block_size, latent_dim]
...
@@ -485,11 +506,16 @@ class NixlConnectorWorker:
...
@@ -485,11 +506,16 @@ class NixlConnectorWorker:
block_size
,
kv_latent_dim
=
block_shape
block_size
,
kv_latent_dim
=
block_shape
self
.
slot_size_bytes
=
kv_elem_size
*
kv_latent_dim
self
.
slot_size_bytes
=
kv_elem_size
*
kv_latent_dim
else
:
else
:
# [2 (k and v), num_blocks, block_size, kv_heads, head_dim]
# [2 (k and v), num_blocks, ...]
if
self
.
_use_flashinfer
:
# FlashInfer swaps 2<->num_blocks dimensions.
self
.
num_blocks
=
first_kv_cache
.
shape
[
0
]
block_rank
=
4
# [2, block_size, kv_heads, head_dim]
else
:
self
.
num_blocks
=
first_kv_cache
.
shape
[
1
]
self
.
num_blocks
=
first_kv_cache
.
shape
[
1
]
block_rank
=
3
# [block_size, kv_heads, head_dim]
block_rank
=
3
# [block_size, kv_heads, head_dim]
block_shape
=
first_kv_cache
.
shape
[
-
block_rank
:]
block_shape
=
first_kv_cache
.
shape
[
-
block_rank
:]
block_size
,
n_kv_heads
,
head_dim
=
block_shape
block_size
,
n_kv_heads
,
head_dim
=
block_shape
[
-
3
:]
# head size in bytes.
# head size in bytes.
self
.
slot_size_bytes
=
kv_elem_size
*
n_kv_heads
*
head_dim
self
.
slot_size_bytes
=
kv_elem_size
*
n_kv_heads
*
head_dim
assert
block_size
==
self
.
block_size
assert
block_size
==
self
.
block_size
...
@@ -497,12 +523,10 @@ class NixlConnectorWorker:
...
@@ -497,12 +523,10 @@ class NixlConnectorWorker:
# hybrid attn, etc
# hybrid attn, etc
# block size in bytes
# block size in bytes
self
.
block_len
=
kv_elem_size
*
math
.
prod
(
block_shape
)
self
.
block_len
=
kv_elem_size
*
math
.
prod
(
block_shape
)
logger
.
info
(
logger
.
debug
(
"Registering KV_Caches. use_mla: %s, shape %s"
,
"Registering KV_Caches: use_mla: %s, num_blocks: %s, "
self
.
use_mla
,
first_kv_cache
.
shape
)
"block_shape: %s, per_layer_kv_cache_shape: %s"
,
use_mla
,
logger
.
debug
(
"num_blocks: %s, block_shape: %s"
,
self
.
num_blocks
,
self
.
num_blocks
,
block_shape
,
first_kv_cache
.
shape
)
block_shape
)
logger
.
debug
(
"Per layer kv cache size: %s"
,
first_kv_cache
.
shape
)
self
.
dst_num_blocks
[
self
.
engine_id
]
=
self
.
num_blocks
self
.
dst_num_blocks
[
self
.
engine_id
]
=
self
.
num_blocks
self
.
kv_caches
=
kv_caches
self
.
kv_caches
=
kv_caches
kv_caches_base_addr
=
[]
kv_caches_base_addr
=
[]
...
@@ -514,9 +538,12 @@ class NixlConnectorWorker:
...
@@ -514,9 +538,12 @@ class NixlConnectorWorker:
# are non-contiguous (it's not locally guaranteed that they will be)
# are non-contiguous (it's not locally guaranteed that they will be)
# Disadvantage is that the encoded NixlAgentMetadata is now larger
# Disadvantage is that the encoded NixlAgentMetadata is now larger
# (roughly 8KB vs 5KB).
# (roughly 8KB vs 5KB).
# Conversely for FlashInfer, K and V are transferred in the same tensor
# to better exploit the memory layout (i.e. num_blocks is the first dim).
for
cache_or_caches
in
kv_caches
.
values
():
for
cache_or_caches
in
kv_caches
.
values
():
# Normalize to always be a list of caches
# Normalize to always be a list of caches
cache_list
=
[
cache_or_caches
]
if
self
.
use_mla
else
cache_or_caches
cache_list
=
[
cache_or_caches
]
if
use_mla
or
self
.
_use_flashinfer
\
else
cache_or_caches
for
cache
in
cache_list
:
for
cache
in
cache_list
:
base_addr
=
cache
.
data_ptr
()
base_addr
=
cache
.
data_ptr
()
region_len
=
self
.
num_blocks
*
self
.
block_len
region_len
=
self
.
num_blocks
*
self
.
block_len
...
@@ -581,7 +608,8 @@ class NixlConnectorWorker:
...
@@ -581,7 +608,8 @@ class NixlConnectorWorker:
kv_caches_base_addr
=
self
.
kv_caches_base_addr
[
self
.
engine_id
],
kv_caches_base_addr
=
self
.
kv_caches_base_addr
[
self
.
engine_id
],
num_blocks
=
self
.
num_blocks
,
num_blocks
=
self
.
num_blocks
,
tp_size
=
self
.
world_size
,
tp_size
=
self
.
world_size
,
block_len
=
self
.
block_len
)
block_len
=
self
.
block_len
,
attn_backend_name
=
self
.
backend_name
)
ready_event
=
threading
.
Event
()
ready_event
=
threading
.
Event
()
self
.
_nixl_handshake_listener_t
=
threading
.
Thread
(
self
.
_nixl_handshake_listener_t
=
threading
.
Thread
(
target
=
self
.
_nixl_handshake_listener
,
target
=
self
.
_nixl_handshake_listener
,
...
@@ -641,6 +669,10 @@ class NixlConnectorWorker:
...
@@ -641,6 +669,10 @@ class NixlConnectorWorker:
assert
self
.
_tp_size
[
engine_id
]
==
nixl_agent_meta
.
tp_size
assert
self
.
_tp_size
[
engine_id
]
==
nixl_agent_meta
.
tp_size
else
:
else
:
self
.
_tp_size
[
engine_id
]
=
nixl_agent_meta
.
tp_size
self
.
_tp_size
[
engine_id
]
=
nixl_agent_meta
.
tp_size
# We may eventually enable this after asserting equality in cache
# layout and close outputs.
assert
nixl_agent_meta
.
attn_backend_name
==
self
.
backend_name
self
.
_remote_agents
[
engine_id
][
self
.
_remote_agents
[
engine_id
][
remote_tp_rank
]
=
self
.
nixl_wrapper
.
add_remote_agent
(
remote_tp_rank
]
=
self
.
nixl_wrapper
.
add_remote_agent
(
nixl_agent_meta
.
agent_metadata
)
nixl_agent_meta
.
agent_metadata
)
...
@@ -659,13 +691,16 @@ class NixlConnectorWorker:
...
@@ -659,13 +691,16 @@ class NixlConnectorWorker:
else
:
else
:
remote_block_size
=
nixl_agent_meta
.
block_len
//
(
remote_block_size
=
nixl_agent_meta
.
block_len
//
(
self
.
slot_size_bytes
*
tp_ratio
)
self
.
slot_size_bytes
*
tp_ratio
)
if
self
.
_use_flashinfer
:
# Account for joint KV in FlashInfer.
remote_block_size
//=
2
assert
nixl_agent_meta
.
block_len
==
self
.
block_len
*
tp_ratio
,
(
assert
nixl_agent_meta
.
block_len
==
self
.
block_len
*
tp_ratio
,
(
"Remote P worker KV layer cache must be of shape [2, N, "
"Remote P worker KV layer cache must be of shape [2, N, "
"local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
"local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
)
)
assert
self
.
block_size
==
remote_block_size
,
"Remote P worker with "
assert
self
.
block_size
==
remote_block_size
,
"Remote P worker with "
\
"different block size is not supported"
"different block size is not supported"
assert
self
.
num_blocks
>=
nixl_agent_meta
.
num_blocks
assert
self
.
num_blocks
>=
nixl_agent_meta
.
num_blocks
...
...
vllm/platforms/interface.py
View file @
9ef9173c
...
@@ -47,6 +47,7 @@ class _Backend(enum.Enum):
...
@@ -47,6 +47,7 @@ class _Backend(enum.Enum):
ROCM_AITER_MLA_VLLM_V1
=
enum
.
auto
()
ROCM_AITER_MLA_VLLM_V1
=
enum
.
auto
()
TORCH_SDPA
=
enum
.
auto
()
TORCH_SDPA
=
enum
.
auto
()
FLASHINFER
=
enum
.
auto
()
FLASHINFER
=
enum
.
auto
()
FLASHINFER_VLLM_V1
=
enum
.
auto
()
TRITON_MLA
=
enum
.
auto
()
# Supported by V1
TRITON_MLA
=
enum
.
auto
()
# Supported by V1
TRITON_MLA_VLLM_V1
=
enum
.
auto
()
TRITON_MLA_VLLM_V1
=
enum
.
auto
()
FLASHMLA_VLLM_V1
=
enum
.
auto
()
FLASHMLA_VLLM_V1
=
enum
.
auto
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment