Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a810671a
"vllm/vscode:/vscode.git/clone" did not exist on "6a6108511f251c2b8278a84e4266504c55e1f037"
Commit
a810671a
authored
Jan 08, 2026
by
zhuwenwen
Browse files
Merge tag 'v0.14.0rc0' into v0.14.0rc0-ori
parents
86b5aefe
6a09612b
Changes
291
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
703 additions
and
443 deletions
+703
-443
vllm/compilation/piecewise_backend.py
vllm/compilation/piecewise_backend.py
+1
-2
vllm/config/compilation.py
vllm/config/compilation.py
+9
-11
vllm/config/observability.py
vllm/config/observability.py
+3
-0
vllm/config/parallel.py
vllm/config/parallel.py
+32
-27
vllm/config/vllm.py
vllm/config/vllm.py
+5
-2
vllm/distributed/device_communicators/all2all.py
vllm/distributed/device_communicators/all2all.py
+24
-5
vllm/distributed/device_communicators/base_device_communicator.py
...tributed/device_communicators/base_device_communicator.py
+6
-1
vllm/distributed/device_communicators/cuda_communicator.py
vllm/distributed/device_communicators/cuda_communicator.py
+11
-5
vllm/distributed/ec_transfer/ec_connector/example_connector.py
...distributed/ec_transfer/ec_connector/example_connector.py
+4
-1
vllm/distributed/kv_transfer/kv_connector/utils.py
vllm/distributed/kv_transfer/kv_connector/utils.py
+61
-22
vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py
...connector/v1/lmcache_integration/multi_process_adapter.py
+9
-0
vllm/distributed/kv_transfer/kv_connector/v1/lmcache_mp_connector.py
...buted/kv_transfer/kv_connector/v1/lmcache_mp_connector.py
+55
-17
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
...distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+319
-182
vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
...buted/kv_transfer/kv_connector/v1/offloading_connector.py
+2
-21
vllm/distributed/parallel_state.py
vllm/distributed/parallel_state.py
+10
-3
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+21
-2
vllm/entrypoints/context.py
vllm/entrypoints/context.py
+60
-2
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+28
-44
vllm/entrypoints/openai/api_server.py
vllm/entrypoints/openai/api_server.py
+6
-95
vllm/entrypoints/openai/parser/responses_parser.py
vllm/entrypoints/openai/parser/responses_parser.py
+37
-1
No files found.
vllm/compilation/piecewise_backend.py
View file @
a810671a
...
@@ -170,8 +170,7 @@ class PiecewiseBackend:
...
@@ -170,8 +170,7 @@ class PiecewiseBackend:
range_entry
=
self
.
_find_range_for_shape
(
runtime_shape
)
range_entry
=
self
.
_find_range_for_shape
(
runtime_shape
)
assert
range_entry
is
not
None
,
(
assert
range_entry
is
not
None
,
(
f
"Shape out of considered range:
{
runtime_shape
}
"
f
"Shape:
{
runtime_shape
}
out of considered ranges:
{
self
.
compile_ranges
}
"
"[1, max_num_batched_tokens]"
)
)
self
.
_maybe_compile_for_range_entry
(
range_entry
,
args
)
self
.
_maybe_compile_for_range_entry
(
range_entry
,
args
)
...
...
vllm/config/compilation.py
View file @
a810671a
...
@@ -437,14 +437,14 @@ class CompilationConfig:
...
@@ -437,14 +437,14 @@ class CompilationConfig:
compile_ranges_split_points
:
list
[
int
]
|
None
=
None
compile_ranges_split_points
:
list
[
int
]
|
None
=
None
"""Split points that represent compile ranges for inductor.
"""Split points that represent compile ranges for inductor.
The compile ranges are
The compile ranges are
[1, split_points[0]],
[1, split_points[0]],
[split_points[0] + 1, split_points[1]], ...,
[split_points[0] + 1, split_points[1]], ...,
[split_points[-1] + 1, max_num_batched_tokens].
[split_points[-1] + 1, max_num_batched_tokens].
Compile sizes are also used single element ranges,
Compile sizes are also used single element ranges,
the range is represented as [compile_sizes[i], compile_sizes[i]].
the range is represented as [compile_sizes[i], compile_sizes[i]].
If a range overlaps with the compile size, graph for compile size
If a range overlaps with the compile size, graph for compile size
will be prioritized, i.e. if we have a range [1, 8] and a compile size 4,
will be prioritized, i.e. if we have a range [1, 8] and a compile size 4,
graph for compile size 4 will be compiled and used instead of the graph
graph for compile size 4 will be compiled and used instead of the graph
for range [1, 8].
for range [1, 8].
...
@@ -899,7 +899,7 @@ class CompilationConfig:
...
@@ -899,7 +899,7 @@ class CompilationConfig:
self
.
compute_bs_to_padded_graph_size
()
self
.
compute_bs_to_padded_graph_size
()
def
set_splitting_ops_for_v1
(
def
set_splitting_ops_for_v1
(
self
,
all2all_backend
:
str
|
None
=
None
,
data_parallel_size
:
int
|
None
=
None
self
,
all2all_backend
:
str
,
data_parallel_size
:
int
=
1
):
):
# To compatible with OOT hardware plugin platform (for example vllm-ascend)
# To compatible with OOT hardware plugin platform (for example vllm-ascend)
# which currently only supports sequence parallelism in eager mode.
# which currently only supports sequence parallelism in eager mode.
...
@@ -934,7 +934,7 @@ class CompilationConfig:
...
@@ -934,7 +934,7 @@ class CompilationConfig:
or
self
.
cudagraph_mode
==
CUDAGraphMode
.
FULL_AND_PIECEWISE
or
self
.
cudagraph_mode
==
CUDAGraphMode
.
FULL_AND_PIECEWISE
):
):
logger
.
warning_once
(
logger
.
warning_once
(
"Using piecewise c
ompilation
with empty splitting_ops"
"Using piecewise c
udagraph
with empty splitting_ops"
)
)
if
self
.
cudagraph_mode
==
CUDAGraphMode
.
PIECEWISE
:
if
self
.
cudagraph_mode
==
CUDAGraphMode
.
PIECEWISE
:
logger
.
warning_once
(
logger
.
warning_once
(
...
@@ -956,11 +956,9 @@ class CompilationConfig:
...
@@ -956,11 +956,9 @@ class CompilationConfig:
self
.
splitting_ops
=
[]
self
.
splitting_ops
=
[]
# Disable CUDA graphs for DeepEP high-throughput since its not CG compatible
# Disable CUDA graphs for DeepEP high-throughput since its not CG compatible
backend
=
all2all_backend
or
envs
.
VLLM_ALL2ALL_BACKEND
dp_size
=
data_parallel_size
if
data_parallel_size
is
not
None
else
1
if
(
if
(
backend
==
"deepep_high_throughput"
all2all_
backend
==
"deepep_high_throughput"
and
d
p
_size
>
1
and
d
ata_parallel
_size
>
1
and
self
.
cudagraph_mode
!=
CUDAGraphMode
.
NONE
and
self
.
cudagraph_mode
!=
CUDAGraphMode
.
NONE
):
):
# TODO: Piecewise Cuda graph might be enabled
# TODO: Piecewise Cuda graph might be enabled
...
...
vllm/config/observability.py
View file @
a810671a
...
@@ -64,6 +64,9 @@ class ObservabilityConfig:
...
@@ -64,6 +64,9 @@ class ObservabilityConfig:
module in the model and attach informations such as input/output shapes to
module in the model and attach informations such as input/output shapes to
nvtx range markers. Noted that this doesn't work with CUDA graphs enabled."""
nvtx range markers. Noted that this doesn't work with CUDA graphs enabled."""
enable_mfu_metrics
:
bool
=
False
"""Enable Model FLOPs Utilization (MFU) metrics."""
@
cached_property
@
cached_property
def
collect_model_forward_time
(
self
)
->
bool
:
def
collect_model_forward_time
(
self
)
->
bool
:
"""Whether to collect model forward time for the request."""
"""Whether to collect model forward time for the request."""
...
...
vllm/config/parallel.py
View file @
a810671a
...
@@ -36,6 +36,14 @@ ExpertPlacementStrategy = Literal["linear", "round_robin"]
...
@@ -36,6 +36,14 @@ ExpertPlacementStrategy = Literal["linear", "round_robin"]
DistributedExecutorBackend
=
Literal
[
"ray"
,
"mp"
,
"uni"
,
"external_launcher"
]
DistributedExecutorBackend
=
Literal
[
"ray"
,
"mp"
,
"uni"
,
"external_launcher"
]
DataParallelBackend
=
Literal
[
"ray"
,
"mp"
]
DataParallelBackend
=
Literal
[
"ray"
,
"mp"
]
EPLBPolicyOption
=
Literal
[
"default"
]
EPLBPolicyOption
=
Literal
[
"default"
]
All2AllBackend
=
Literal
[
"naive"
,
"pplx"
,
"deepep_high_throughput"
,
"deepep_low_latency"
,
"allgather_reducescatter"
,
"flashinfer_all2allv"
,
]
@
config
@
config
...
@@ -126,24 +134,14 @@ class ParallelConfig:
...
@@ -126,24 +134,14 @@ class ParallelConfig:
with 4 experts and 2 ranks, rank 0 will have experts [0, 2] and rank 1
with 4 experts and 2 ranks, rank 0 will have experts [0, 2] and rank 1
will have experts [1, 3]. This strategy can help improve load balancing
will have experts [1, 3]. This strategy can help improve load balancing
for grouped expert models with no redundant experts."""
for grouped expert models with no redundant experts."""
all2all_backend
:
(
all2all_backend
:
All2AllBackend
=
"allgather_reducescatter"
Literal
[
"""All2All backend for MoE expert parallel communication. Available options:
"naive"
,
"pplx"
,
- "naive": Naive all2all implementation using broadcasts
\n
"deepep_high_throughput"
,
- "allgather_reducescatter": All2all based on allgather and reducescatter
\n
"deepep_low_latency"
,
- "pplx": Use pplx kernels
\n
"allgather_reducescatter"
,
- "deepep_high_throughput": Use deepep high-throughput kernels
\n
"flashinfer_all2allv"
,
- "deepep_low_latency": Use deepep low-latency kernels
\n
]
|
None
)
=
None
"""All2All backend for MoE expert parallel communication. If not set, uses
the value from VLLM_ALL2ALL_BACKEND environment variable. Available options:
- "naive": Naive all2all implementation using broadcasts
- "allgather_reducescatter": All2all based on allgather and reducescatter
- "pplx": Use pplx kernels
- "deepep_high_throughput": Use deepep high-throughput kernels
- "deepep_low_latency": Use deepep low-latency kernels
- "flashinfer_all2allv": Use flashinfer alltoallv kernels for mnnvl"""
- "flashinfer_all2allv": Use flashinfer alltoallv kernels for mnnvl"""
max_parallel_loading_workers
:
int
|
None
=
None
max_parallel_loading_workers
:
int
|
None
=
None
...
@@ -156,6 +154,8 @@ class ParallelConfig:
...
@@ -156,6 +154,8 @@ class ParallelConfig:
enable_dbo
:
bool
=
False
enable_dbo
:
bool
=
False
"""Enable dual batch overlap for the model executor."""
"""Enable dual batch overlap for the model executor."""
ubatch_size
:
int
=
0
"""Number of ubatch size."""
dbo_decode_token_threshold
:
int
=
32
dbo_decode_token_threshold
:
int
=
32
"""The threshold for dual batch overlap for batches only containing decodes.
"""The threshold for dual batch overlap for batches only containing decodes.
...
@@ -325,6 +325,14 @@ class ParallelConfig:
...
@@ -325,6 +325,14 @@ class ParallelConfig:
including data parallelism."""
including data parallelism."""
return
self
.
world_size
*
self
.
data_parallel_size
return
self
.
world_size
*
self
.
data_parallel_size
@
property
def
use_ubatching
(
self
)
->
bool
:
return
self
.
enable_dbo
or
self
.
ubatch_size
>
1
@
property
def
num_ubatches
(
self
)
->
int
:
return
2
if
self
.
enable_dbo
else
self
.
ubatch_size
def
get_next_dp_init_port
(
self
)
->
int
:
def
get_next_dp_init_port
(
self
)
->
int
:
"""
"""
We might need to initialize process groups in multiple
We might need to initialize process groups in multiple
...
@@ -485,20 +493,17 @@ class ParallelConfig:
...
@@ -485,20 +493,17 @@ class ParallelConfig:
from
vllm.config.utils
import
get_hash_factors
,
hash_factors
from
vllm.config.utils
import
get_hash_factors
,
hash_factors
factors
=
get_hash_factors
(
self
,
ignored_factors
)
factors
=
get_hash_factors
(
self
,
ignored_factors
)
# Explicitly include backend affecting env factor as before
factors
[
"VLLM_ALL2ALL_BACKEND"
]
=
str
(
envs
.
VLLM_ALL2ALL_BACKEND
)
return
hash_factors
(
factors
)
return
hash_factors
(
factors
)
def
__post_init__
(
self
)
->
None
:
def
__post_init__
(
self
)
->
None
:
# Set all2all_backend from env var if not specified, with deprecation warning
# Set all2all_backend from env var if not specified, with deprecation warning
if
self
.
all2all_backend
is
None
:
if
envs
.
is_set
(
"VLLM_ALL2ALL_BACKEND"
):
logger
.
warning_once
(
"VLLM_ALL2ALL_BACKEND environment variable is deprecated and "
"will be removed in v0.15.0. Please use the "
"--all2all-backend command-line argument instead."
)
self
.
all2all_backend
=
envs
.
VLLM_ALL2ALL_BACKEND
self
.
all2all_backend
=
envs
.
VLLM_ALL2ALL_BACKEND
if
envs
.
is_set
(
"VLLM_ALL2ALL_BACKEND"
):
logger
.
warning_once
(
"VLLM_ALL2ALL_BACKEND environment variable is deprecated and "
"will be removed in a future release. Please use the "
"--all2all-backend command-line argument instead."
)
# Continue with the rest of the initialization
# Continue with the rest of the initialization
self
.
world_size
=
(
self
.
world_size
=
(
...
...
vllm/config/vllm.py
View file @
a810671a
...
@@ -870,9 +870,12 @@ class VllmConfig:
...
@@ -870,9 +870,12 @@ class VllmConfig:
f
"cudagraph_mode=
{
self
.
compilation_config
.
cudagraph_mode
}
"
f
"cudagraph_mode=
{
self
.
compilation_config
.
cudagraph_mode
}
"
)
)
if
self
.
parallel_config
.
enable_dbo
:
if
self
.
parallel_config
.
use_ubatching
:
a2a_backend
=
self
.
parallel_config
.
all2all_backend
a2a_backend
=
self
.
parallel_config
.
all2all_backend
assert
a2a_backend
in
[
"deepep_low_latency"
,
"deepep_high_throughput"
],
(
assert
a2a_backend
in
[
"deepep_low_latency"
,
"deepep_high_throughput"
,
],
(
"Microbatching currently only supports the deepep_low_latency and "
"Microbatching currently only supports the deepep_low_latency and "
f
"deepep_high_throughput all2all backend.
{
a2a_backend
}
is not "
f
"deepep_high_throughput all2all backend.
{
a2a_backend
}
is not "
"supported. To fix use --all2all-backend=deepep_low_latency or "
"supported. To fix use --all2all-backend=deepep_low_latency or "
...
...
vllm/distributed/device_communicators/all2all.py
View file @
a810671a
...
@@ -64,7 +64,12 @@ class NaiveAll2AllManager(All2AllManagerBase):
...
@@ -64,7 +64,12 @@ class NaiveAll2AllManager(All2AllManagerBase):
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
is_sequence_parallel
:
bool
=
False
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
if
extra_tensors
is
not
None
:
raise
NotImplementedError
(
"extra_tensors is not supported for NaiveAll2AllManager"
)
sp_size
=
self
.
tp_group
.
world_size
if
is_sequence_parallel
else
1
sp_size
=
self
.
tp_group
.
world_size
if
is_sequence_parallel
else
1
dp_metadata
=
get_forward_context
().
dp_metadata
dp_metadata
=
get_forward_context
().
dp_metadata
assert
dp_metadata
is
not
None
assert
dp_metadata
is
not
None
...
@@ -76,6 +81,7 @@ class NaiveAll2AllManager(All2AllManagerBase):
...
@@ -76,6 +81,7 @@ class NaiveAll2AllManager(All2AllManagerBase):
router_logits
=
self
.
naive_multicast
(
router_logits
=
self
.
naive_multicast
(
router_logits
,
cu_tokens_across_sp_cpu
,
is_sequence_parallel
router_logits
,
cu_tokens_across_sp_cpu
,
is_sequence_parallel
)
)
return
hidden_states
,
router_logits
return
hidden_states
,
router_logits
def
combine
(
def
combine
(
...
@@ -113,7 +119,11 @@ class AgRsAll2AllManager(All2AllManagerBase):
...
@@ -113,7 +119,11 @@ class AgRsAll2AllManager(All2AllManagerBase):
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
is_sequence_parallel
:
bool
=
False
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
(
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
):
"""
"""
Gather hidden_states and router_logits from all dp ranks.
Gather hidden_states and router_logits from all dp ranks.
"""
"""
...
@@ -121,15 +131,22 @@ class AgRsAll2AllManager(All2AllManagerBase):
...
@@ -121,15 +131,22 @@ class AgRsAll2AllManager(All2AllManagerBase):
assert
dp_metadata
is
not
None
assert
dp_metadata
is
not
None
sizes
=
dp_metadata
.
get_chunk_sizes_across_dp_rank
()
sizes
=
dp_metadata
.
get_chunk_sizes_across_dp_rank
()
assert
sizes
is
not
None
assert
sizes
is
not
None
dist_group
=
get_ep_group
()
if
is_sequence_parallel
else
get_dp_group
()
dist_group
=
get_ep_group
()
if
is_sequence_parallel
else
get_dp_group
()
assert
sizes
[
dist_group
.
rank_in_group
]
==
hidden_states
.
shape
[
0
]
assert
sizes
[
dist_group
.
rank_in_group
]
==
hidden_states
.
shape
[
0
]
hidden_states
,
router_logits
=
dist_group
.
all_gatherv
(
[
hidden_states
,
router_logits
],
tensors_to_gather
=
[
hidden_states
,
router_logits
]
if
extra_tensors
is
not
None
:
tensors_to_gather
.
extend
(
extra_tensors
)
gathered_tensors
=
dist_group
.
all_gatherv
(
tensors_to_gather
,
dim
=
0
,
dim
=
0
,
sizes
=
sizes
,
sizes
=
sizes
,
)
)
return
hidden_states
,
router_logits
if
extra_tensors
is
not
None
:
return
(
gathered_tensors
[
0
],
gathered_tensors
[
1
],
gathered_tensors
[
2
:])
return
gathered_tensors
[
0
],
gathered_tensors
[
1
]
def
combine
(
def
combine
(
self
,
hidden_states
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
self
,
hidden_states
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
...
@@ -204,6 +221,7 @@ class PPLXAll2AllManager(All2AllManagerBase):
...
@@ -204,6 +221,7 @@ class PPLXAll2AllManager(All2AllManagerBase):
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
is_sequence_parallel
:
bool
=
False
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
raise
NotImplementedError
raise
NotImplementedError
...
@@ -251,6 +269,7 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
...
@@ -251,6 +269,7 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
is_sequence_parallel
:
bool
=
False
,
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
raise
NotImplementedError
raise
NotImplementedError
...
...
vllm/distributed/device_communicators/base_device_communicator.py
View file @
a810671a
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
threading
import
threading
from
typing
import
Any
from
weakref
import
WeakValueDictionary
from
weakref
import
WeakValueDictionary
import
torch
import
torch
...
@@ -68,7 +69,11 @@ class All2AllManagerBase:
...
@@ -68,7 +69,11 @@ class All2AllManagerBase:
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
is_sequence_parallel
:
bool
=
False
,
):
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
Any
:
# Subclasses should either:
# - implement handling for extra_tensors, or
# - raise a clear error if extra_tensors is not supported.
raise
NotImplementedError
raise
NotImplementedError
def
set_num_sms
(
self
,
num_sms
:
int
):
def
set_num_sms
(
self
,
num_sms
:
int
):
...
...
vllm/distributed/device_communicators/cuda_communicator.py
View file @
a810671a
...
@@ -318,17 +318,23 @@ class CudaCommunicator(DeviceCommunicatorBase):
...
@@ -318,17 +318,23 @@ class CudaCommunicator(DeviceCommunicatorBase):
return
output_list
return
output_list
def
dispatch
(
def
dispatch
(
# type: ignore[override]
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
is_sequence_parallel
:
bool
=
False
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
(
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
):
assert
self
.
all2all_manager
is
not
None
assert
self
.
all2all_manager
is
not
None
hidden_states
,
router_logits
=
self
.
all2all_manager
.
dispatch
(
return
self
.
all2all_manager
.
dispatch
(
hidden_states
,
router_logits
,
is_sequence_parallel
hidden_states
,
router_logits
,
is_sequence_parallel
,
extra_tensors
,
# type: ignore[call-arg]
)
)
return
hidden_states
,
router_logits
def
combine
(
def
combine
(
self
,
hidden_states
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
self
,
hidden_states
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
...
...
vllm/distributed/ec_transfer/ec_connector/example_connector.py
View file @
a810671a
...
@@ -73,6 +73,7 @@ class ECExampleConnector(ECConnectorBase):
...
@@ -73,6 +73,7 @@ class ECExampleConnector(ECConnectorBase):
data hashes (`mm_hash`) to encoder cache tensors.
data hashes (`mm_hash`) to encoder cache tensors.
kwargs (dict): Additional keyword arguments for the connector.
kwargs (dict): Additional keyword arguments for the connector.
"""
"""
from
vllm.platforms
import
current_platform
# Get the metadata
# Get the metadata
metadata
:
ECConnectorMetadata
=
self
.
_get_connector_metadata
()
metadata
:
ECConnectorMetadata
=
self
.
_get_connector_metadata
()
...
@@ -91,7 +92,9 @@ class ECExampleConnector(ECConnectorBase):
...
@@ -91,7 +92,9 @@ class ECExampleConnector(ECConnectorBase):
if
mm_data
.
mm_hash
in
encoder_cache
:
if
mm_data
.
mm_hash
in
encoder_cache
:
continue
continue
filename
=
self
.
_generate_filename_debug
(
mm_data
.
mm_hash
)
filename
=
self
.
_generate_filename_debug
(
mm_data
.
mm_hash
)
ec_cache
=
safetensors
.
torch
.
load_file
(
filename
)[
"ec_cache"
].
cuda
()
ec_cache
=
safetensors
.
torch
.
load_file
(
filename
,
device
=
current_platform
.
device_type
)[
"ec_cache"
]
encoder_cache
[
mm_data
.
mm_hash
]
=
ec_cache
encoder_cache
[
mm_data
.
mm_hash
]
=
ec_cache
logger
.
debug
(
"Success load encoder cache for hash %s"
,
mm_data
.
mm_hash
)
logger
.
debug
(
"Success load encoder cache for hash %s"
,
mm_data
.
mm_hash
)
...
...
vllm/distributed/kv_transfer/kv_connector/utils.py
View file @
a810671a
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
KV cache helper for store.
KV cache helper for store.
"""
"""
from
collections.abc
import
Iterator
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Literal
from
typing
import
TYPE_CHECKING
,
Literal
...
@@ -21,6 +22,8 @@ if TYPE_CHECKING:
...
@@ -21,6 +22,8 @@ if TYPE_CHECKING:
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
EngineId
=
str
def
get_kv_connector_cache_layout
():
def
get_kv_connector_cache_layout
():
# NOTE (NickLucche) When running disaggregated PD with NIXL, HND layout is
# NOTE (NickLucche) When running disaggregated PD with NIXL, HND layout is
...
@@ -201,6 +204,26 @@ def copy_kv_blocks(
...
@@ -201,6 +204,26 @@ def copy_kv_blocks(
copy_fn
(
src_tensor
,
dst_tensor
,
src_indices
,
dst_indices
)
copy_fn
(
src_tensor
,
dst_tensor
,
src_indices
,
dst_indices
)
def
yield_req_data
(
scheduler_output
,
)
->
Iterator
[
tuple
[
str
,
tuple
[
list
[
int
],
...],
bool
]]:
"""
Yields:
(req_id, new_block_id_groups, preempted)
"""
# new requests
for
req_data
in
scheduler_output
.
scheduled_new_reqs
:
yield
req_data
.
req_id
,
req_data
.
block_ids
,
False
# cached requests
cached_reqs
=
scheduler_output
.
scheduled_cached_reqs
yield
from
zip
(
cached_reqs
.
req_ids
,
cached_reqs
.
new_block_ids
,
(
req_id
in
cached_reqs
.
resumed_req_ids
for
req_id
in
cached_reqs
.
req_ids
),
)
@
dataclass
@
dataclass
class
TpKVTopology
:
class
TpKVTopology
:
"""
"""
...
@@ -209,12 +232,12 @@ class TpKVTopology:
...
@@ -209,12 +232,12 @@ class TpKVTopology:
"""
"""
tp_rank
:
int
tp_rank
:
int
remote_tp_size
:
dict
[
str
,
int
]
remote_tp_size
:
dict
[
EngineId
,
int
]
is_mla
:
bool
is_mla
:
bool
total_num_kv_heads
:
int
total_num_kv_heads
:
int
attn_backend
:
type
[
AttentionBackend
]
attn_backend
:
type
[
AttentionBackend
]
engine_id
:
str
engine_id
:
EngineId
remote_block_size
:
dict
[
str
,
int
]
remote_block_size
:
dict
[
EngineId
,
int
]
def
__post_init__
(
self
):
def
__post_init__
(
self
):
# Figure out whether the first dimension of the cache is K/V
# Figure out whether the first dimension of the cache is K/V
...
@@ -256,18 +279,28 @@ class TpKVTopology:
...
@@ -256,18 +279,28 @@ class TpKVTopology:
Calculate the tensor parallel ratio between local and remote TP.
Calculate the tensor parallel ratio between local and remote TP.
We can think of it as the number of local TP workers-per-remote TP
We can think of it as the number of local TP workers-per-remote TP
workers. Local workers will read from the same remote TP worker in
workers. Local workers will read from the same remote TP worker in
groups of size `tp_ratio`.
groups of size `tp_ratio`.If remote tp_size > local tp_size, the
ratio is flipped (remote_size/local_size) and the returned value is
negative.
"""
"""
assert
self
.
tp_size
%
remote_tp_size
==
0
,
(
if
self
.
tp_size
>=
remote_tp_size
:
f
"Local tensor parallel size
{
self
.
tp_size
}
is not divisible "
assert
self
.
tp_size
%
remote_tp_size
==
0
,
(
f
"by remote tensor parallel size
{
remote_tp_size
}
."
f
"Local tensor parallel size
{
self
.
tp_size
}
is not divisible "
f
"by remote tensor parallel size
{
remote_tp_size
}
."
)
return
self
.
tp_size
//
remote_tp_size
assert
remote_tp_size
%
self
.
tp_size
==
0
,
(
f
"Remote tensor parallel size
{
remote_tp_size
}
is not divisible "
f
"by local tensor parallel size
{
self
.
tp_size
}
."
)
)
return
self
.
tp_size
//
remote_tp_size
# P TP > D TP case, return the ratio as negative
return
-
remote_tp_size
//
self
.
tp_size
def
block_size_ratio
(
def
block_size_ratio
(
self
,
self
,
remote_block_size
:
int
,
remote_block_size
:
int
,
)
->
floa
t
:
)
->
in
t
:
"""
"""
Calculate the block size ratio between local and remote TP.
Calculate the block size ratio between local and remote TP.
"""
"""
...
@@ -279,19 +312,19 @@ class TpKVTopology:
...
@@ -279,19 +312,19 @@ class TpKVTopology:
def
tp_ratio_from_engine_id
(
def
tp_ratio_from_engine_id
(
self
,
self
,
remote_engine_id
:
str
,
remote_engine_id
:
EngineId
,
)
->
int
:
)
->
int
:
remote_tp_size
=
self
.
remote_tp_size
[
remote_engine_id
]
remote_tp_size
=
self
.
remote_tp_size
[
remote_engine_id
]
return
self
.
tp_ratio
(
remote_tp_size
)
return
self
.
tp_ratio
(
remote_tp_size
)
def
block_size_ratio_from_engine_id
(
def
block_size_ratio_from_engine_id
(
self
,
self
,
remote_engine_id
:
str
,
remote_engine_id
:
EngineId
,
)
->
floa
t
:
)
->
in
t
:
remote_block_size
=
self
.
remote_block_size
[
remote_engine_id
]
remote_block_size
=
self
.
remote_block_size
[
remote_engine_id
]
return
self
.
block_size_ratio
(
remote_block_size
)
return
self
.
block_size_ratio
(
remote_block_size
)
def
is_kv_replicated
(
self
,
engine_id
:
str
)
->
bool
:
def
is_kv_replicated
(
self
,
engine_id
:
EngineId
)
->
bool
:
"""
"""
Whether the KV cache is replicated across TP workers due to the
Whether the KV cache is replicated across TP workers due to the
number of TP workers being greater than the number of KV heads.
number of TP workers being greater than the number of KV heads.
...
@@ -299,24 +332,30 @@ class TpKVTopology:
...
@@ -299,24 +332,30 @@ class TpKVTopology:
tp_size
=
self
.
remote_tp_size
[
engine_id
]
tp_size
=
self
.
remote_tp_size
[
engine_id
]
return
tp_size
//
self
.
total_num_kv_heads
>=
1
return
tp_size
//
self
.
total_num_kv_heads
>=
1
def
replicates_kv_cache
(
self
,
remote_engine_id
:
str
)
->
bool
:
def
replicates_kv_cache
(
self
,
remote_engine_id
:
EngineId
)
->
bool
:
# MLA is always replicated as the hidden dim can't be split.
# MLA is always replicated as the hidden dim can't be split.
return
self
.
is_mla
or
self
.
is_kv_replicated
(
remote_engine_id
)
return
self
.
is_mla
or
self
.
is_kv_replicated
(
remote_engine_id
)
def
get_target_remote_rank
(
def
get_target_remote_rank
s
(
self
,
self
,
remote_tp_size
:
int
,
remote_tp_size
:
int
,
)
->
int
:
)
->
list
[
int
]
:
"""
"""
Get the remote TP rank (on P) that the current local TP rank
Get the remote TP rank (on P) that the current local TP rank
(on D) will read from.
(on D) will read from. When remote tp_size > local tp_size, we
read from multiple remote ranks.
"""
"""
tp_ratio
=
self
.
tp_ratio
(
remote_tp_size
)
tp_ratio
=
self
.
tp_ratio
(
remote_tp_size
)
return
self
.
tp_rank
//
tp_ratio
if
tp_ratio
>
0
:
return
[
self
.
tp_rank
//
tp_ratio
]
# P TP > D TP case, D reads from |tp_ratio| remote workers.
tp_ratio
=
-
tp_ratio
return
[
self
.
tp_rank
*
tp_ratio
+
i
for
i
in
range
(
tp_ratio
)]
def
get_target_remote_rank_from_engine_id
(
def
get_target_remote_rank
s
_from_engine_id
(
self
,
self
,
remote_engine_id
:
str
,
remote_engine_id
:
EngineId
,
)
->
int
:
)
->
list
[
int
]
:
remote_tp_size
=
self
.
remote_tp_size
[
remote_engine_id
]
remote_tp_size
=
self
.
remote_tp_size
[
remote_engine_id
]
return
self
.
get_target_remote_rank
(
remote_tp_size
)
return
self
.
get_target_remote_rank
s
(
remote_tp_size
)
vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py
View file @
a810671a
...
@@ -147,6 +147,14 @@ class LMCacheMPSchedulerAdapter:
...
@@ -147,6 +147,14 @@ class LMCacheMPSchedulerAdapter:
"""
"""
return
self
.
blocks_in_chunk
return
self
.
blocks_in_chunk
def
_cleanup_lookup_result
(
self
,
request_id
:
str
)
->
None
:
"""
Clean up lookup future for a finished request to prevent memory leak.
Args:
request_id: The ID of the finished request.
"""
self
.
lookup_futures
.
pop
(
request_id
,
None
)
# Helper functions
# Helper functions
def
_create_key
(
self
,
block_hash
:
bytes
)
->
IPCCacheEngineKey
:
def
_create_key
(
self
,
block_hash
:
bytes
)
->
IPCCacheEngineKey
:
"""Convert a block hash to an IPC cache engine key"""
"""Convert a block hash to an IPC cache engine key"""
...
@@ -262,6 +270,7 @@ class LMCacheMPWorkerAdapter:
...
@@ -262,6 +270,7 @@ class LMCacheMPWorkerAdapter:
):
):
keys
=
[]
keys
=
[]
block_ids
=
[]
block_ids
=
[]
for
op
in
ops
:
for
op
in
ops
:
keys
.
extend
(
self
.
_block_hashes_to_keys
(
op
.
block_hashes
))
keys
.
extend
(
self
.
_block_hashes_to_keys
(
op
.
block_hashes
))
block_ids
.
extend
(
op
.
block_ids
)
block_ids
.
extend
(
op
.
block_ids
)
...
...
vllm/distributed/kv_transfer/kv_connector/v1/lmcache_mp_connector.py
View file @
a810671a
...
@@ -24,6 +24,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_integration import (
...
@@ -24,6 +24,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_integration import (
)
)
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.outputs
import
KVConnectorOutput
from
vllm.v1.outputs
import
KVConnectorOutput
from
vllm.v1.request
import
RequestStatus
from
vllm.v1.utils
import
ConstantList
from
vllm.v1.utils
import
ConstantList
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
...
@@ -211,7 +212,7 @@ class LMCacheMPRequestTracker:
...
@@ -211,7 +212,7 @@ class LMCacheMPRequestTracker:
"""
"""
self
.
num_stored_blocks
+=
num_new_blocks
self
.
num_stored_blocks
+=
num_new_blocks
def
update
_block_ids
(
def
append
_block_ids
(
self
,
self
,
new_block_ids
:
list
[
int
],
new_block_ids
:
list
[
int
],
):
):
...
@@ -455,10 +456,6 @@ class LMCacheMPConnector(KVConnectorBase_V1):
...
@@ -455,10 +456,6 @@ class LMCacheMPConnector(KVConnectorBase_V1):
metadata
=
self
.
_get_connector_metadata
()
metadata
=
self
.
_get_connector_metadata
()
assert
isinstance
(
metadata
,
LMCacheMPConnectorMetadata
)
assert
isinstance
(
metadata
,
LMCacheMPConnectorMetadata
)
with
torch
.
cuda
.
stream
(
torch
.
cuda
.
current_stream
()):
event
=
torch
.
cuda
.
Event
(
interprocess
=
True
)
event
.
record
()
request_ids
=
[]
request_ids
=
[]
ops
=
[]
ops
=
[]
...
@@ -468,10 +465,14 @@ class LMCacheMPConnector(KVConnectorBase_V1):
...
@@ -468,10 +465,14 @@ class LMCacheMPConnector(KVConnectorBase_V1):
request_ids
.
append
(
meta
.
request_id
)
request_ids
.
append
(
meta
.
request_id
)
ops
.
append
(
meta
.
op
)
ops
.
append
(
meta
.
op
)
if
len
(
request_ids
)
>
0
:
if
len
(
request_ids
)
==
0
:
self
.
worker_adapter
.
batched_submit_retrieve_requests
(
return
request_ids
,
ops
,
event
)
with
torch
.
cuda
.
stream
(
torch
.
cuda
.
current_stream
()):
event
=
torch
.
cuda
.
Event
(
interprocess
=
True
)
event
.
record
()
self
.
worker_adapter
.
batched_submit_retrieve_requests
(
request_ids
,
ops
,
event
)
def
wait_for_layer_load
(
self
,
layer_name
:
str
)
->
None
:
def
wait_for_layer_load
(
self
,
layer_name
:
str
)
->
None
:
"""
"""
...
@@ -518,10 +519,6 @@ class LMCacheMPConnector(KVConnectorBase_V1):
...
@@ -518,10 +519,6 @@ class LMCacheMPConnector(KVConnectorBase_V1):
metadata
=
self
.
_get_connector_metadata
()
metadata
=
self
.
_get_connector_metadata
()
assert
isinstance
(
metadata
,
LMCacheMPConnectorMetadata
)
assert
isinstance
(
metadata
,
LMCacheMPConnectorMetadata
)
with
torch
.
cuda
.
stream
(
torch
.
cuda
.
current_stream
()):
event
=
torch
.
cuda
.
Event
(
interprocess
=
True
)
event
.
record
()
request_ids
=
[]
request_ids
=
[]
ops
=
[]
ops
=
[]
for
meta
in
metadata
.
requests
:
for
meta
in
metadata
.
requests
:
...
@@ -530,8 +527,14 @@ class LMCacheMPConnector(KVConnectorBase_V1):
...
@@ -530,8 +527,14 @@ class LMCacheMPConnector(KVConnectorBase_V1):
request_ids
.
append
(
meta
.
request_id
)
request_ids
.
append
(
meta
.
request_id
)
ops
.
append
(
meta
.
op
)
ops
.
append
(
meta
.
op
)
if
len
(
request_ids
)
>
0
:
if
len
(
request_ids
)
==
0
:
self
.
worker_adapter
.
batched_submit_store_requests
(
request_ids
,
ops
,
event
)
return
with
torch
.
cuda
.
stream
(
torch
.
cuda
.
current_stream
()):
event
=
torch
.
cuda
.
Event
(
interprocess
=
True
)
event
.
record
()
self
.
worker_adapter
.
batched_submit_store_requests
(
request_ids
,
ops
,
event
)
def
get_finished
(
def
get_finished
(
self
,
finished_req_ids
:
set
[
str
]
self
,
finished_req_ids
:
set
[
str
]
...
@@ -627,6 +630,9 @@ class LMCacheMPConnector(KVConnectorBase_V1):
...
@@ -627,6 +630,9 @@ class LMCacheMPConnector(KVConnectorBase_V1):
into account.
into account.
"""
"""
tracker
=
self
.
_get_or_create_request_tracker
(
request
)
tracker
=
self
.
_get_or_create_request_tracker
(
request
)
# TODO: support loading KV for preempted requests in the future
if
request
.
status
==
RequestStatus
.
PREEMPTED
:
return
0
,
False
self
.
scheduler_adapter
.
maybe_submit_lookup_request
(
self
.
scheduler_adapter
.
maybe_submit_lookup_request
(
request
.
request_id
,
convert_block_hashes_to_bytes
(
request
.
block_hashes
)
request
.
request_id
,
convert_block_hashes_to_bytes
(
request
.
block_hashes
)
...
@@ -683,7 +689,7 @@ class LMCacheMPConnector(KVConnectorBase_V1):
...
@@ -683,7 +689,7 @@ class LMCacheMPConnector(KVConnectorBase_V1):
# No matter we need to retrieve or not, we need to update
# No matter we need to retrieve or not, we need to update
# the block ids into the tracker
# the block ids into the tracker
tracker
.
update
_block_ids
(
block_ids
)
tracker
.
append
_block_ids
(
block_ids
)
# Update the state of the tracker
# Update the state of the tracker
condition
=
tracker
.
needs_retrieve
()
condition
=
tracker
.
needs_retrieve
()
...
@@ -695,6 +701,8 @@ class LMCacheMPConnector(KVConnectorBase_V1):
...
@@ -695,6 +701,8 @@ class LMCacheMPConnector(KVConnectorBase_V1):
if
condition
if
condition
else
LMCacheMPRequestState
.
READY
else
LMCacheMPRequestState
.
READY
)
)
# Clean up lookup future in scheduler adapter
self
.
scheduler_adapter
.
_cleanup_lookup_result
(
request
.
request_id
)
def
build_connector_meta
(
def
build_connector_meta
(
self
,
scheduler_output
:
SchedulerOutput
self
,
scheduler_output
:
SchedulerOutput
...
@@ -748,6 +756,8 @@ class LMCacheMPConnector(KVConnectorBase_V1):
...
@@ -748,6 +756,8 @@ class LMCacheMPConnector(KVConnectorBase_V1):
Optional KVTransferParams to be included in the request outputs
Optional KVTransferParams to be included in the request outputs
returned by the engine.
returned by the engine.
"""
"""
# Clean up request tracker to prevent memory leak
self
.
_cleanup_request_tracker
(
request
.
request_id
)
return
True
,
None
return
True
,
None
def
take_events
(
self
)
->
Iterable
[
"KVCacheEvent"
]:
def
take_events
(
self
)
->
Iterable
[
"KVCacheEvent"
]:
...
@@ -866,7 +876,8 @@ class LMCacheMPConnector(KVConnectorBase_V1):
...
@@ -866,7 +876,8 @@ class LMCacheMPConnector(KVConnectorBase_V1):
# Update block ids
# Update block ids
new_block_ids
=
reformat_block_ids
(
cached_reqs
.
new_block_ids
[
idx
])
new_block_ids
=
reformat_block_ids
(
cached_reqs
.
new_block_ids
[
idx
])
request_tracker
.
update_block_ids
(
new_block_ids
)
if
request_id
not
in
cached_reqs
.
resumed_req_ids
:
request_tracker
.
append_block_ids
(
new_block_ids
)
# Update new scheduled tokens
# Update new scheduled tokens
num_new_tokens
=
cached_reqs
.
num_computed_tokens
[
idx
]
num_new_tokens
=
cached_reqs
.
num_computed_tokens
[
idx
]
...
@@ -889,7 +900,34 @@ class LMCacheMPConnector(KVConnectorBase_V1):
...
@@ -889,7 +900,34 @@ class LMCacheMPConnector(KVConnectorBase_V1):
self
,
request
:
"Request"
self
,
request
:
"Request"
)
->
LMCacheMPRequestTracker
:
)
->
LMCacheMPRequestTracker
:
request_id
=
request
.
request_id
request_id
=
request
.
request_id
# Remove the old trackers that is created before the preemption
if
(
request
.
status
==
RequestStatus
.
PREEMPTED
and
request_id
in
self
.
request_trackers
):
tracker
=
self
.
request_trackers
[
request_id
]
# NOTE: since this function may be called multiple times
# for a single request (because get_num_new_matched_tokens
# may be called multiple times) for the same request, we
# will only do the remove if the tracker is not in the "fresh"
# state, i.e., PREFETCHING
if
tracker
.
state
!=
LMCacheMPRequestState
.
PREFETCHING
:
self
.
request_trackers
.
pop
(
request_id
)
if
request_id
not
in
self
.
request_trackers
:
if
request_id
not
in
self
.
request_trackers
:
new_tracker
=
LMCacheMPRequestTracker
(
request
)
new_tracker
=
LMCacheMPRequestTracker
(
request
)
self
.
request_trackers
[
request_id
]
=
new_tracker
self
.
request_trackers
[
request_id
]
=
new_tracker
return
self
.
request_trackers
[
request_id
]
return
self
.
request_trackers
[
request_id
]
def
_cleanup_request_tracker
(
self
,
request_id
:
str
)
->
None
:
"""
Clean up request tracker and associated lookup future for a request.
This should be called when a request is finished to prevent memory leak.
"""
# Clean up request tracker
if
self
.
request_trackers
.
pop
(
request_id
,
None
):
logger
.
debug
(
"[KVConnector] Cleaned up request_tracker for request %s"
,
request_id
,
)
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
View file @
a810671a
...
@@ -23,7 +23,11 @@ from vllm import envs
...
@@ -23,7 +23,11 @@ from vllm import envs
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.attention.selector
import
get_attn_backend
from
vllm.attention.selector
import
get_attn_backend
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.distributed.kv_transfer.kv_connector.utils
import
TpKVTopology
from
vllm.distributed.kv_transfer.kv_connector.utils
import
(
EngineId
,
TpKVTopology
,
yield_req_data
,
)
from
vllm.distributed.kv_transfer.kv_connector.v1.base
import
(
from
vllm.distributed.kv_transfer.kv_connector.v1.base
import
(
CopyBlocksOp
,
CopyBlocksOp
,
KVConnectorBase_V1
,
KVConnectorBase_V1
,
...
@@ -56,7 +60,6 @@ if TYPE_CHECKING:
...
@@ -56,7 +60,6 @@ if TYPE_CHECKING:
from
vllm.v1.request
import
Request
from
vllm.v1.request
import
Request
TransferHandle
=
int
TransferHandle
=
int
EngineId
=
str
ReqId
=
str
ReqId
=
str
#
#
...
@@ -482,7 +485,7 @@ class NixlConnectorScheduler:
...
@@ -482,7 +485,7 @@ class NixlConnectorScheduler:
# New requests are added by update_state_after_alloc in
# New requests are added by update_state_after_alloc in
# the scheduler. Used to make metadata passed to Worker.
# the scheduler. Used to make metadata passed to Worker.
self
.
_reqs_need_recv
:
dict
[
ReqId
,
tuple
[
Request
,
list
[
int
]]]
=
{}
self
.
_reqs_need_recv
:
dict
[
ReqId
,
tuple
[
Request
,
list
[
int
]]]
=
{}
self
.
_reqs_need_save
:
dict
[
ReqId
,
tuple
[
Request
,
list
[
int
]]
]
=
{}
self
.
_reqs_need_save
:
dict
[
ReqId
,
Request
]
=
{}
# Reqs to send and their expiration time
# Reqs to send and their expiration time
self
.
_reqs_need_send
:
dict
[
ReqId
,
float
]
=
{}
self
.
_reqs_need_send
:
dict
[
ReqId
,
float
]
=
{}
self
.
_reqs_in_batch
:
set
[
ReqId
]
=
set
()
self
.
_reqs_in_batch
:
set
[
ReqId
]
=
set
()
...
@@ -628,16 +631,7 @@ class NixlConnectorScheduler:
...
@@ -628,16 +631,7 @@ class NixlConnectorScheduler:
if
self
.
use_host_buffer
and
params
.
get
(
"do_remote_decode"
):
if
self
.
use_host_buffer
and
params
.
get
(
"do_remote_decode"
):
# NOTE: when accelerator is not directly supported by Nixl,
# NOTE: when accelerator is not directly supported by Nixl,
# prefilled blocks need to be saved to host memory before transfer.
# prefilled blocks need to be saved to host memory before transfer.
self
.
_reqs_need_save
[
request
.
request_id
]
=
request
# save all blocks
block_ids
=
blocks
.
get_block_ids
()[
0
]
# TODO: skip the blocks that are already in the host xfer buffer.
# Currently, the host xfer buffer block is 1-to-1 mapped to device
# kv blocks, so host blocks won't be flushed as long as its device
# block is not overwritten; and it will be safe to skip saving them
# to host xfer buffer.
if
block_ids
:
self
.
_reqs_need_save
[
request
.
request_id
]
=
(
request
,
block_ids
)
elif
params
.
get
(
"do_remote_prefill"
):
elif
params
.
get
(
"do_remote_prefill"
):
if
params
.
get
(
"remote_block_ids"
):
if
params
.
get
(
"remote_block_ids"
):
if
all
(
if
all
(
...
@@ -689,13 +683,32 @@ class NixlConnectorScheduler:
...
@@ -689,13 +683,32 @@ class NixlConnectorScheduler:
kv_transfer_params
=
req
.
kv_transfer_params
,
kv_transfer_params
=
req
.
kv_transfer_params
,
)
)
for
req_id
,
(
req
,
block_ids
)
in
self
.
_reqs_need_save
.
items
():
# NOTE: For the prefill side, there might be a chance that an early added
# request is a chunked prefill, so we need to check if new blocks are added
for
req_id
,
new_block_id_groups
,
_
in
yield_req_data
(
scheduler_output
):
req_to_save
=
self
.
_reqs_need_save
.
get
(
req_id
)
if
req_to_save
is
None
or
new_block_id_groups
is
None
:
continue
req
=
req_to_save
assert
req
.
kv_transfer_params
is
not
None
assert
req
.
kv_transfer_params
is
not
None
meta
.
add_new_req_to_save
(
meta
.
add_new_req_to_save
(
request_id
=
req_id
,
request_id
=
req_id
,
local_block_ids
=
block_id
s
,
local_block_ids
=
new_
block_id
_groups
[
0
]
,
kv_transfer_params
=
req
.
kv_transfer_params
,
kv_transfer_params
=
req
.
kv_transfer_params
,
)
)
assert
scheduler_output
.
num_scheduled_tokens
is
not
None
num_scheduled_tokens
=
scheduler_output
.
num_scheduled_tokens
[
req_id
]
is_partial
=
(
req
.
num_computed_tokens
+
num_scheduled_tokens
)
<
req
.
num_prompt_tokens
if
not
is_partial
:
# For non-partial prefills, once new req_meta is scheduled, it
# can be removed from _reqs_need_save.
# For partial prefill case, we will retain the request in
# _reqs_need_save until all blocks are scheduled with req_meta.
# Therefore, only pop if `not is_partial`.
self
.
_reqs_need_save
.
pop
(
req_id
)
meta
.
reqs_to_send
=
self
.
_reqs_need_send
meta
.
reqs_to_send
=
self
.
_reqs_need_send
meta
.
reqs_in_batch
=
self
.
_reqs_in_batch
meta
.
reqs_in_batch
=
self
.
_reqs_in_batch
...
@@ -703,7 +716,6 @@ class NixlConnectorScheduler:
...
@@ -703,7 +716,6 @@ class NixlConnectorScheduler:
# Clear the list once workers start the transfers
# Clear the list once workers start the transfers
self
.
_reqs_need_recv
.
clear
()
self
.
_reqs_need_recv
.
clear
()
self
.
_reqs_need_save
.
clear
()
self
.
_reqs_in_batch
=
set
()
self
.
_reqs_in_batch
=
set
()
self
.
_reqs_not_processed
=
set
()
self
.
_reqs_not_processed
=
set
()
self
.
_reqs_need_send
=
{}
self
.
_reqs_need_send
=
{}
...
@@ -749,6 +761,8 @@ class NixlConnectorScheduler:
...
@@ -749,6 +761,8 @@ class NixlConnectorScheduler:
# Also include the case of a P/D Prefill request with immediate
# Also include the case of a P/D Prefill request with immediate
# block free (eg abort). Stop tracking this request.
# block free (eg abort). Stop tracking this request.
self
.
_reqs_not_processed
.
add
(
request
.
request_id
)
self
.
_reqs_not_processed
.
add
(
request
.
request_id
)
# Clear _reqs_need_save if a request is aborted as partial prefill.
self
.
_reqs_need_save
.
pop
(
request
.
request_id
,
None
)
return
False
,
None
return
False
,
None
# TODO: check whether block_ids actually ever be 0. If not we could
# TODO: check whether block_ids actually ever be 0. If not we could
...
@@ -873,9 +887,10 @@ class NixlConnectorWorker:
...
@@ -873,9 +887,10 @@ class NixlConnectorWorker:
self
.
copy_blocks
:
CopyBlocksOp
|
None
=
None
self
.
copy_blocks
:
CopyBlocksOp
|
None
=
None
# Map of engine_id -> kv_caches_base_addr. For TP case, each local
# Map of engine_id -> kv_caches_base_addr. For TP case, each local
# rank will still only pull from a single remote TP worker.
self
.
kv_caches_base_addr
:
dict
[
EngineId
,
list
[
int
]]
=
{}
self
.
device_id
:
int
=
0
self
.
device_id
:
int
=
0
# Current rank may pull from multiple remote TP workers.
# EngineId, dict[int, list[int]] -> engine_id, tp_rank, base_addr_for_layer
self
.
kv_caches_base_addr
=
defaultdict
[
EngineId
,
dict
[
int
,
list
[
int
]]](
dict
)
# Number of NIXL regions. Currently one region per cache
# Number of NIXL regions. Currently one region per cache
# (so 1 per layer for MLA, otherwise 2 per layer)
# (so 1 per layer for MLA, otherwise 2 per layer)
...
@@ -883,10 +898,12 @@ class NixlConnectorWorker:
...
@@ -883,10 +898,12 @@ class NixlConnectorWorker:
self
.
num_layers
=
0
self
.
num_layers
=
0
# nixl_prepped_dlist_handle.
# nixl_prepped_dlist_handle.
self
.
src_xfer_side_handle
:
int
=
0
self
.
src_xfer_handles_by_block_size
:
dict
[
int
,
int
]
=
{}
self
.
src_xfer_side_handles
:
dict
[
int
,
int
]
=
{}
# Populated dynamically during handshake based on remote configuration.
# Map of engine_id -> nixl_prepped_dlist_handle (int)].
# Keep track of regions at different tp_ratio values. tp_ratio->handles
self
.
dst_xfer_side_handles
:
dict
[
EngineId
,
int
]
=
{}
self
.
src_xfer_handles_by_tp_ratio
:
dict
[
int
,
list
[
int
]]
=
{}
# Map of engine_id -> {tp_rank: nixl_prepped_dlist_handle (int)}.
self
.
dst_xfer_side_handles
=
defaultdict
[
EngineId
,
dict
[
int
,
int
]](
dict
)
# Map of engine_id -> num_blocks. All ranks in the same deployment will
# Map of engine_id -> num_blocks. All ranks in the same deployment will
# have the same number of blocks.
# have the same number of blocks.
...
@@ -977,103 +994,108 @@ class NixlConnectorWorker:
...
@@ -977,103 +994,108 @@ class NixlConnectorWorker:
expected_engine_id
:
str
,
expected_engine_id
:
str
,
)
->
dict
[
int
,
str
]:
)
->
dict
[
int
,
str
]:
"""Do a NIXL handshake with a remote instance."""
"""Do a NIXL handshake with a remote instance."""
# When target instance TP > local TP, we need to perform multiple
start_time
=
time
.
perf_counter
()
# handshakes. Do it in a single background job for simplicity.
# Regardless, only handshake with the remote TP rank(s) that current
# NOTE(rob): we need each rank to have a unique port. This is
# local rank will read from. Note that With homogeneous TP,
# a hack to keep us moving. We will switch when moving to etcd
# this happens to be the same single rank_i.
# or where we have a single ZMQ socket in the scheduler.
p_remote_ranks
=
self
.
kv_topo
.
get_target_remote_ranks
(
remote_tp_size
)
remote_rank_to_agent_name
=
{}
# Handshake only with the remote TP rank that current local rank will
# pull from. With homogeneous TP it happens to be the same rank_i.
p_remote_rank
=
self
.
kv_topo
.
get_target_remote_rank
(
remote_tp_size
)
path
=
make_zmq_path
(
"tcp"
,
host
,
port
)
path
=
make_zmq_path
(
"tcp"
,
host
,
port
)
logger
.
debug
(
"Querying metadata on path: %s at remote tp rank %s"
,
path
,
p_remote_rank
)
# Send query for the request.
with
zmq_ctx
(
zmq
.
REQ
,
path
)
as
sock
:
with
zmq_ctx
(
zmq
.
REQ
,
path
)
as
sock
:
msg
=
msgspec
.
msgpack
.
encode
((
GET_META_MSG
,
p_remote_rank
))
for
remote_rank
in
p_remote_ranks
:
# Set receive timeout to 5 seconds to avoid hanging on dead server
logger
.
debug
(
sock
.
setsockopt
(
zmq
.
RCVTIMEO
,
5000
)
# milliseconds
"Querying metadata on path: %s at remote tp rank %s"
,
sock
.
send
(
msg
)
path
,
handshake_bytes
=
sock
.
recv
()
remote_rank
,
# Decode handshake payload to get compatibility hash
handshake_decoder
=
msgspec
.
msgpack
.
Decoder
(
NixlHandshakePayload
)
try
:
handshake_payload
=
handshake_decoder
.
decode
(
handshake_bytes
)
except
(
msgspec
.
DecodeError
,
msgspec
.
ValidationError
)
as
e
:
raise
RuntimeError
(
f
"Failed to decode NixlHandshakePayload. This likely indicates "
f
"an incompatibility between connector version. Error:
{
e
}
"
)
from
e
got_metadata_time
=
time
.
perf_counter
()
logger
.
debug
(
"NIXL handshake: get metadata took: %s"
,
got_metadata_time
-
start_time
)
# Check compatibility hash BEFORE decoding agent metadata
if
(
self
.
enforce_compat_hash
and
handshake_payload
.
compatibility_hash
!=
self
.
compat_hash
):
raise
RuntimeError
(
f
"NIXL compatibility hash mismatch. "
f
"Local:
{
self
.
compat_hash
}
, "
f
"Remote:
{
handshake_payload
.
compatibility_hash
}
. "
f
"Prefill and decode instances have incompatible configurations. "
f
"This may be due to: different vLLM versions, models, dtypes, "
f
"KV cache layouts, attention backends, etc. "
f
"Both instances must use identical configurations."
f
"Disable this check using "
f
'--kv-transfer-config
\'
{{"kv_connector_extra_config": '
f
'{{"enforce_handshake_compat": false}}}}
\'
'
)
)
logger
.
info
(
start_time
=
time
.
perf_counter
()
"NIXL compatibility check passed (hash: %s)"
,
# Send query for the request.
handshake_payload
.
compatibility_hash
,
msg
=
msgspec
.
msgpack
.
encode
((
GET_META_MSG
,
remote_rank
))
)
# Set receive timeout to 5 seconds to avoid hanging on dead server
sock
.
setsockopt
(
zmq
.
RCVTIMEO
,
5000
)
# milliseconds
sock
.
send
(
msg
)
handshake_bytes
=
sock
.
recv
()
# Decode agent metadata
# Decode handshake payload to get compatibility hash
metadata_decoder
=
msgspec
.
msgpack
.
Decoder
(
NixlAgentMetadata
)
handshake_decoder
=
msgspec
.
msgpack
.
Decoder
(
NixlHandshakePayload
)
try
:
try
:
metadata
=
metadata_decoder
.
decode
(
handshake_payload
=
handshake_decoder
.
decode
(
handshake_bytes
)
handshake_payload
.
agent_metadata_bytes
except
(
msgspec
.
DecodeError
,
msgspec
.
ValidationError
)
as
e
:
raise
RuntimeError
(
f
"Failed to decode NixlHandshakePayload. This likely indicates "
f
"an incompatibility between connector version. Error:
{
e
}
"
)
from
e
got_metadata_time
=
time
.
perf_counter
()
logger
.
debug
(
"NIXL handshake: get metadata took: %s"
,
got_metadata_time
-
start_time
,
)
)
except
(
msgspec
.
DecodeError
,
msgspec
.
ValidationError
)
as
e
:
# This should not happen if hash matched
raise
RuntimeError
(
f
"Failed to decode NixlAgentMetadata. Error:
{
e
}
"
)
from
e
# Ensure engine id matches.
# Check compatibility hash BEFORE decoding agent metadata
if
metadata
.
engine_id
!=
expected_engine_id
:
if
(
raise
RuntimeError
(
self
.
enforce_compat_hash
f
"Remote NIXL agent engine ID mismatch. "
and
handshake_payload
.
compatibility_hash
!=
self
.
compat_hash
f
"Expected
{
expected_engine_id
}
,"
):
f
"received
{
metadata
.
engine_id
}
."
raise
RuntimeError
(
)
f
"NIXL compatibility hash mismatch. "
f
"Local:
{
self
.
compat_hash
}
, "
f
"Remote:
{
handshake_payload
.
compatibility_hash
}
. "
f
"Prefill and decode instances have incompatible "
f
"configurations. This may be due to: different vLLM versions,"
f
" models, dtypes, KV cache layouts, attention backends, etc. "
f
"Both instances must use identical configurations."
f
"Disable this check using "
f
'--kv-transfer-config
\'
{{"kv_connector_extra_config": '
f
'{{"enforce_handshake_compat": false}}}}
\'
'
)
# Register Remote agent.
logger
.
info
(
assert
metadata
.
block_size
<=
self
.
block_size
,
(
"NIXL compatibility check passed (hash: %s)"
,
"nP > nD is not supported yet."
handshake_payload
.
compatibility_hash
,
)
)
remote_agent_name
=
self
.
add_remote_agent
(
metadata
,
p_remote_rank
,
remote_tp_size
)
setup_agent_time
=
time
.
perf_counter
()
# Decode agent metadata
logger
.
debug
(
metadata_decoder
=
msgspec
.
msgpack
.
Decoder
(
NixlAgentMetadata
)
"NIXL handshake: add agent took: %s"
,
try
:
setup_agent_time
-
got_metadata_time
,
metadata
=
metadata_decoder
.
decode
(
)
handshake_payload
.
agent_metadata_bytes
)
except
(
msgspec
.
DecodeError
,
msgspec
.
ValidationError
)
as
e
:
# This should not happen if hash matched
raise
RuntimeError
(
f
"Failed to decode NixlAgentMetadata. Error:
{
e
}
"
)
from
e
# Ensure engine id matches.
if
metadata
.
engine_id
!=
expected_engine_id
:
raise
RuntimeError
(
f
"Remote NIXL agent engine ID mismatch. "
f
"Expected
{
expected_engine_id
}
,"
f
"received
{
metadata
.
engine_id
}
."
)
# Ensure engine id matches.
if
metadata
.
engine_id
!=
expected_engine_id
:
raise
RuntimeError
(
f
"Remote NIXL agent engine ID mismatch. "
f
"Expected
{
expected_engine_id
}
,"
f
"received
{
metadata
.
engine_id
}
."
)
setup_agent_time
=
time
.
perf_counter
()
# Remote rank -> agent name.
# Register Remote agent.
return
{
p_remote_rank
:
remote_agent_name
}
remote_agent_name
=
self
.
add_remote_agent
(
metadata
,
remote_rank
,
remote_tp_size
)
logger
.
debug
(
"NIXL handshake: add agent took: %s"
,
setup_agent_time
-
got_metadata_time
,
)
remote_rank_to_agent_name
[
remote_rank
]
=
remote_agent_name
return
remote_rank_to_agent_name
def
initialize_host_xfer_buffer
(
self
,
kv_caches
:
dict
[
str
,
torch
.
Tensor
])
->
None
:
def
initialize_host_xfer_buffer
(
self
,
kv_caches
:
dict
[
str
,
torch
.
Tensor
])
->
None
:
"""
"""
...
@@ -1283,7 +1305,7 @@ class NixlConnectorWorker:
...
@@ -1283,7 +1305,7 @@ class NixlConnectorWorker:
assert
len
(
self
.
block_len_per_layer
)
==
len
(
seen_base_addresses
)
assert
len
(
self
.
block_len_per_layer
)
==
len
(
seen_base_addresses
)
assert
self
.
num_blocks
!=
0
assert
self
.
num_blocks
!=
0
self
.
kv_caches_base_addr
[
self
.
engine_id
]
=
seen_base_addresses
self
.
kv_caches_base_addr
[
self
.
engine_id
]
[
self
.
tp_rank
]
=
seen_base_addresses
self
.
num_regions
=
len
(
caches_data
)
self
.
num_regions
=
len
(
caches_data
)
self
.
num_layers
=
len
(
xfer_buffers
.
keys
())
self
.
num_layers
=
len
(
xfer_buffers
.
keys
())
...
@@ -1310,9 +1332,9 @@ class NixlConnectorWorker:
...
@@ -1310,9 +1332,9 @@ class NixlConnectorWorker:
# Register local/src descr for NIXL xfer.
# Register local/src descr for NIXL xfer.
self
.
seen_base_addresses
=
seen_base_addresses
self
.
seen_base_addresses
=
seen_base_addresses
self
.
src_xfer_
side_
handle
=
self
.
register_local_xfer_handler
(
self
.
block_size
)
self
.
src_xfer_handle
s_by_block_size
[
self
.
block_size
],
self
.
src_blocks_data
=
(
self
.
register_local_xfer_handler
(
self
.
block_size
)
self
.
src_xfer_side_handles
[
self
.
block_size
]
=
self
.
src_xfer_side_handle
)
# TODO(mgoin): Hybrid memory allocator is currently disabled for
# TODO(mgoin): Hybrid memory allocator is currently disabled for
# models with local attention (Llama 4). Can remove this once enabled.
# models with local attention (Llama 4). Can remove this once enabled.
...
@@ -1340,8 +1362,8 @@ class NixlConnectorWorker:
...
@@ -1340,8 +1362,8 @@ class NixlConnectorWorker:
agent_metadata
=
NixlAgentMetadata
(
agent_metadata
=
NixlAgentMetadata
(
engine_id
=
self
.
engine_id
,
engine_id
=
self
.
engine_id
,
agent_metadata
=
self
.
nixl_wrapper
.
get_agent_metadata
(),
agent_metadata
=
self
.
nixl_wrapper
.
get_agent_metadata
(),
kv_caches_base_addr
=
self
.
kv_caches_base_addr
[
self
.
engine_id
],
device_id
=
self
.
device_id
,
device_id
=
self
.
device_id
,
kv_caches_base_addr
=
self
.
kv_caches_base_addr
[
self
.
engine_id
][
self
.
tp_rank
],
num_blocks
=
self
.
num_blocks
,
num_blocks
=
self
.
num_blocks
,
block_lens
=
self
.
block_len_per_layer
,
block_lens
=
self
.
block_len_per_layer
,
kv_cache_layout
=
self
.
kv_cache_layout
kv_cache_layout
=
self
.
kv_cache_layout
...
@@ -1359,7 +1381,7 @@ class NixlConnectorWorker:
...
@@ -1359,7 +1381,7 @@ class NixlConnectorWorker:
def
register_local_xfer_handler
(
def
register_local_xfer_handler
(
self
,
self
,
block_size
:
int
,
block_size
:
int
,
)
->
int
:
)
->
tuple
[
int
,
list
[
tuple
[
int
,
int
,
int
]]]
:
"""
"""
Function used for register local xfer handler with local block_size or
Function used for register local xfer handler with local block_size or
Remote block_size.
Remote block_size.
...
@@ -1407,7 +1429,7 @@ class NixlConnectorWorker:
...
@@ -1407,7 +1429,7 @@ class NixlConnectorWorker:
descs
=
self
.
nixl_wrapper
.
get_xfer_descs
(
blocks_data
,
self
.
nixl_memory_type
)
descs
=
self
.
nixl_wrapper
.
get_xfer_descs
(
blocks_data
,
self
.
nixl_memory_type
)
# NIXL_INIT_AGENT to be used for preparations of local descs.
# NIXL_INIT_AGENT to be used for preparations of local descs.
return
self
.
nixl_wrapper
.
prep_xfer_dlist
(
"NIXL_INIT_AGENT"
,
descs
)
return
self
.
nixl_wrapper
.
prep_xfer_dlist
(
"NIXL_INIT_AGENT"
,
descs
)
,
blocks_data
def
add_remote_agent
(
def
add_remote_agent
(
self
,
self
,
...
@@ -1421,10 +1443,12 @@ class NixlConnectorWorker:
...
@@ -1421,10 +1443,12 @@ class NixlConnectorWorker:
In particular, handle both homogeneous and heterogeneous TP. The former
In particular, handle both homogeneous and heterogeneous TP. The former
requires local rank_i to read from remote rank_i.
requires local rank_i to read from remote rank_i.
The latter, assuming D.world_size > P.world_size, requires that two or
The latter, in the case of D.world_size < P.world_size, requires that a
more local TP worker share the xfer from a single TP worker.
local (D) TP worker reads from multiple remote (P) TP workers.
Conversely, assuming D.world_size > P.world_size, two or more local TP
workers will read from a single remote TP worker.
Here's an example (non-MLA
case
):
Here's an example
for the last case described above
(non-MLA):
rank_offset p_remote_tp_rank
rank_offset p_remote_tp_rank
(kv split no)
(kv split no)
...
@@ -1474,9 +1498,6 @@ class NixlConnectorWorker:
...
@@ -1474,9 +1498,6 @@ class NixlConnectorWorker:
nixl_agent_meta
.
agent_metadata
nixl_agent_meta
.
agent_metadata
)
)
# Handle tp_size>num_kv_heads: replicate KV cache.
replicates_kv_cache
=
self
.
kv_topo
.
replicates_kv_cache
(
engine_id
)
# Create dst descs and xfer side handles. TP workers have same #blocks
# Create dst descs and xfer side handles. TP workers have same #blocks
# so we only register once per engine_id.
# so we only register once per engine_id.
# Example:
# Example:
...
@@ -1490,14 +1511,52 @@ class NixlConnectorWorker:
...
@@ -1490,14 +1511,52 @@ class NixlConnectorWorker:
self
.
dst_num_blocks
[
engine_id
]
=
nixl_agent_meta
.
num_blocks
self
.
dst_num_blocks
[
engine_id
]
=
nixl_agent_meta
.
num_blocks
# Keep track of remote agent kv caches base addresses.
# Keep track of remote agent kv caches base addresses.
self
.
kv_caches_base_addr
[
engine_id
]
=
nixl_agent_meta
.
kv_caches_base_addr
self
.
kv_caches_base_addr
[
engine_id
][
remote_tp_rank
]
=
(
nixl_agent_meta
.
kv_caches_base_addr
)
self
.
_validate_remote_agent_handshake
(
nixl_agent_meta
,
remote_tp_size
)
self
.
_validate_remote_agent_handshake
(
nixl_agent_meta
,
remote_tp_size
)
#
Number of D TP workers reading from a single P TP worker. This is
#
This is 1 when P and D `--tensor-parallel-size` match. Otherwise,
#
1 when P and D `--tensor-parallel-size` match
.
#
this is the ratio between the two sizes
.
tp_ratio
=
self
.
kv_topo
.
tp_ratio_from_engine_id
(
engine_id
)
tp_ratio
=
self
.
kv_topo
.
tp_ratio_from_engine_id
(
engine_id
)
# Handle tp_size>num_kv_heads: replicate KV cache.
indexes_into_remote
=
(
not
self
.
kv_topo
.
replicates_kv_cache
(
engine_id
)
and
tp_ratio
>
0
)
logger
.
debug
(
"Registering remote agent (%s, rank %s) memory regions with tp_ratio %s"
,
engine_id
,
remote_tp_rank
,
tp_ratio
,
)
### (Optional) Register local agent memory regions. MLA is not split.
if
(
tp_ratio
<
0
and
not
self
.
use_mla
and
tp_ratio
not
in
self
.
src_xfer_handles_by_tp_ratio
):
# Remote tp_size > local tp_size: read from multiple remote ranks.
# Logically "split" own regions into |tp_ratio| chunks. Mind that
# we only do this once per remote tp_size (replica-friendly).
self
.
src_xfer_handles_by_tp_ratio
[
tp_ratio
]
=
[]
for
i
in
range
(
-
tp_ratio
):
blocks_data
=
[]
for
memory_region
in
self
.
src_blocks_data
:
addr
,
local_block_len
,
own_tp_rank
=
memory_region
# Computing block len layer by layer allows for different
# block sizes to be used.
remote_block_len
=
local_block_len
//
(
-
tp_ratio
)
addr
=
addr
+
i
*
remote_block_len
blocks_data
.
append
((
addr
,
remote_block_len
,
own_tp_rank
))
descs
=
self
.
nixl_wrapper
.
get_xfer_descs
(
blocks_data
,
self
.
nixl_memory_type
)
handle
=
self
.
nixl_wrapper
.
prep_xfer_dlist
(
"NIXL_INIT_AGENT"
,
descs
)
self
.
src_xfer_handles_by_tp_ratio
[
tp_ratio
].
append
(
handle
)
### Register remote agent memory regions
### Register remote agent memory regions
blocks_data
=
[]
blocks_data
=
[]
# With homogeneous TP, D pulls the whole kv cache from corresponding
# With homogeneous TP, D pulls the whole kv cache from corresponding
...
@@ -1507,14 +1566,19 @@ class NixlConnectorWorker:
...
@@ -1507,14 +1566,19 @@ class NixlConnectorWorker:
# Register all remote blocks, but only the corresponding kv heads.
# Register all remote blocks, but only the corresponding kv heads.
for
i
,
base_addr
in
enumerate
(
nixl_agent_meta
.
kv_caches_base_addr
):
for
i
,
base_addr
in
enumerate
(
nixl_agent_meta
.
kv_caches_base_addr
):
kv_block_len
=
self
.
get_backend_aware_kv_block_len
(
layer_idx
=
i
)
# Read our whole local region size from remote.
remote_kv_block_len
=
kv_block_len
//
block_size_ratio
local_block_len
=
self
.
get_backend_aware_kv_block_len
(
layer_idx
=
i
)
remote_kv_block_len
=
local_block_len
//
block_size_ratio
if
block_size_ratio
>
1
:
if
block_size_ratio
>
1
:
# using remote kv_block_len as transfer unit
# using remote kv_block_len as transfer unit
kv_block_len
=
remote_kv_block_len
local_block_len
=
remote_kv_block_len
if
tp_ratio
<
0
and
not
self
.
use_mla
:
# Remote tp is bigger: read a chunk of local region from remote
local_block_len
=
local_block_len
//
(
-
tp_ratio
)
rank_offset
=
(
rank_offset
=
(
self
.
tp_rank
%
tp_ratio
*
remote_kv_block_len
self
.
tp_rank
%
tp_ratio
*
remote_kv_block_len
if
not
replicates_kv_cach
e
if
indexes_into_remot
e
else
0
else
0
)
)
for
block_id
in
range
(
nixl_agent_meta
.
num_blocks
):
for
block_id
in
range
(
nixl_agent_meta
.
num_blocks
):
...
@@ -1524,7 +1588,7 @@ class NixlConnectorWorker:
...
@@ -1524,7 +1588,7 @@ class NixlConnectorWorker:
# self.block_len == remote_block_len//tp_ratio bytes.
# self.block_len == remote_block_len//tp_ratio bytes.
addr
=
base_addr
+
block_offset
+
rank_offset
addr
=
base_addr
+
block_offset
+
rank_offset
# (addr, len, device id)
# (addr, len, device id)
blocks_data
.
append
((
addr
,
kv
_block_len
,
nixl_agent_meta
.
device_id
))
blocks_data
.
append
((
addr
,
local
_block_len
,
nixl_agent_meta
.
device_id
))
if
self
.
kv_topo
.
is_kv_layout_blocks_first
:
if
self
.
kv_topo
.
is_kv_layout_blocks_first
:
# With FlashInfer index V separately to allow head splitting.
# With FlashInfer index V separately to allow head splitting.
...
@@ -1533,7 +1597,7 @@ class NixlConnectorWorker:
...
@@ -1533,7 +1597,7 @@ class NixlConnectorWorker:
addr
=
base_addr
+
block_offset
+
rank_offset
addr
=
base_addr
+
block_offset
+
rank_offset
v_addr
=
addr
+
nixl_agent_meta
.
block_lens
[
i
]
//
2
v_addr
=
addr
+
nixl_agent_meta
.
block_lens
[
i
]
//
2
blocks_data
.
append
(
blocks_data
.
append
(
(
v_addr
,
kv
_block_len
,
nixl_agent_meta
.
device_id
)
(
v_addr
,
local
_block_len
,
nixl_agent_meta
.
device_id
)
)
)
logger
.
debug
(
logger
.
debug
(
...
@@ -1546,15 +1610,15 @@ class NixlConnectorWorker:
...
@@ -1546,15 +1610,15 @@ class NixlConnectorWorker:
# Register with NIXL.
# Register with NIXL.
descs
=
self
.
nixl_wrapper
.
get_xfer_descs
(
blocks_data
,
self
.
nixl_memory_type
)
descs
=
self
.
nixl_wrapper
.
get_xfer_descs
(
blocks_data
,
self
.
nixl_memory_type
)
self
.
dst_xfer_side_handles
[
engine_id
]
=
self
.
nixl_wrapper
.
prep_xfer_dlist
(
self
.
dst_xfer_side_handles
[
engine_id
]
[
remote_tp_rank
]
=
(
remote_agent_name
,
descs
self
.
nixl_wrapper
.
prep_xfer_dlist
(
remote_agent_name
,
descs
)
)
)
if
block_size_ratio
>
1
:
if
block_size_ratio
>
1
:
# when prefill with smaller block_size, we need to init a
# when prefill with smaller block_size, we need to init a
# new handler with same block_len to match
# new handler with same block_len to match
self
.
src_xfer_
side_
handles
[
nixl_agent_meta
.
block_size
]
=
(
self
.
src_xfer_handles
_by_block_size
[
nixl_agent_meta
.
block_size
]
=
(
self
.
register_local_xfer_handler
(
nixl_agent_meta
.
block_size
)
self
.
register_local_xfer_handler
(
nixl_agent_meta
.
block_size
)
[
0
]
)
)
return
remote_agent_name
return
remote_agent_name
...
@@ -1574,7 +1638,9 @@ class NixlConnectorWorker:
...
@@ -1574,7 +1638,9 @@ class NixlConnectorWorker:
block_size_ratio
=
self
.
kv_topo
.
block_size_ratio_from_engine_id
(
block_size_ratio
=
self
.
kv_topo
.
block_size_ratio_from_engine_id
(
remote_engine_id
remote_engine_id
)
)
assert
tp_ratio
>
0
,
"Decode TP cannot be smaller than prefill TP"
# Num kv_heads > tp_size and P TP > D TP case, not supported
assert
not
(
tp_ratio
<
0
and
self
.
kv_topo
.
is_kv_replicated
(
remote_engine_id
))
assert
not
self
.
_use_pallas
or
tp_ratio
==
1
,
(
assert
not
self
.
_use_pallas
or
tp_ratio
==
1
,
(
"TPU (pallas_v1) DOES NOT support heterogeneous TP yet."
"TPU (pallas_v1) DOES NOT support heterogeneous TP yet."
)
)
...
@@ -1616,17 +1682,29 @@ class NixlConnectorWorker:
...
@@ -1616,17 +1682,29 @@ class NixlConnectorWorker:
"All remote layers must have the same block size"
"All remote layers must have the same block size"
)
)
assert
(
if
tp_ratio
>
0
:
remote_block_len
# Remote tp is smaller: remote block_len size is bigger
==
(
self
.
block_len_per_layer
[
0
]
*
tp_ratio
)
//
block_size_ratio
assert
(
),
(
remote_block_len
"Remote P worker KV layer cache must be of shape [2, N, "
==
(
self
.
block_len_per_layer
[
0
]
*
tp_ratio
)
//
block_size_ratio
"local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
),
(
)
"Remote P worker KV layer cache must be of shape [2, N, "
"local_kv_heads*tp_ratio, page_size, head_dim] and same dtype."
)
# noqa: E501
else
:
assert
block_size_ratio
==
1
,
(
"Different local/remote block sizes are not supported when"
" P TP > D TP."
)
# Remote tp is bigger: remote block_len size is smaller
assert
remote_block_len
==
self
.
block_len_per_layer
[
0
]
//
(
-
tp_ratio
),
(
"Remote P worker KV layer cache must be of shape [2, N, "
"local_kv_heads/tp_ratio, page_size, head_dim] and same dtype."
)
# noqa: E501
# TP workers have same #blocks.
# TP workers
that handhshake with same remote
have same #blocks.
assert
self
.
dst_num_blocks
[
remote_engine_id
]
==
nixl_agent_meta
.
num_blocks
assert
self
.
dst_num_blocks
[
remote_engine_id
]
==
nixl_agent_meta
.
num_blocks
# Same number of regions/~layers.
assert
len
(
nixl_agent_meta
.
kv_caches_base_addr
)
==
len
(
self
.
block_len_per_layer
)
assert
len
(
nixl_agent_meta
.
kv_caches_base_addr
)
==
len
(
self
.
block_len_per_layer
)
def
sync_recved_kv_to_device
(
self
,
req_id
:
str
,
meta
:
ReqMeta
):
def
sync_recved_kv_to_device
(
self
,
req_id
:
str
,
meta
:
ReqMeta
):
...
@@ -1710,7 +1788,7 @@ class NixlConnectorWorker:
...
@@ -1710,7 +1788,7 @@ class NixlConnectorWorker:
)
)
cache
.
index_copy_
(
0
,
indices
,
permuted_blocks
)
cache
.
index_copy_
(
0
,
indices
,
permuted_blocks
)
def
blocksize_post_process
(
self
,
block_ids_per_ratio
:
dict
[
floa
t
,
list
[
list
[
int
]]]):
def
blocksize_post_process
(
self
,
block_ids_per_ratio
:
dict
[
in
t
,
list
[
list
[
int
]]]):
def
_process_local_gt_remote
(
blocks_to_update
,
block_size_ratio
):
def
_process_local_gt_remote
(
blocks_to_update
,
block_size_ratio
):
n_kv_heads
,
block_size
,
head_size
=
blocks_to_update
.
shape
[
1
:]
n_kv_heads
,
block_size
,
head_size
=
blocks_to_update
.
shape
[
1
:]
remote_block_size
=
block_size
//
block_size_ratio
remote_block_size
=
block_size
//
block_size_ratio
...
@@ -1840,7 +1918,7 @@ class NixlConnectorWorker:
...
@@ -1840,7 +1918,7 @@ class NixlConnectorWorker:
notified_req_ids
:
set
[
str
]
=
set
()
notified_req_ids
:
set
[
str
]
=
set
()
for
notifs
in
self
.
nixl_wrapper
.
get_new_notifs
().
values
():
for
notifs
in
self
.
nixl_wrapper
.
get_new_notifs
().
values
():
for
notif
in
notifs
:
for
notif
in
notifs
:
req_id
,
tp_
ratio
=
notif
.
decode
(
"utf-8"
).
rsplit
(
":"
,
1
)
req_id
,
tp_
size
=
notif
.
decode
(
"utf-8"
).
rsplit
(
":"
,
1
)
if
(
if
(
req_id
not
in
self
.
_reqs_to_send
req_id
not
in
self
.
_reqs_to_send
and
req_id
not
in
self
.
_reqs_to_process
and
req_id
not
in
self
.
_reqs_to_process
...
@@ -1853,9 +1931,22 @@ class NixlConnectorWorker:
...
@@ -1853,9 +1931,22 @@ class NixlConnectorWorker:
)
)
continue
continue
# NOTE: `tp_ratio` is the opposite when swapping local<>remote
n_consumers
=
int
(
tp_size
)
tp_ratio
=
self
.
kv_topo
.
tp_ratio
(
n_consumers
)
# Number of reads *per producer* to wait for.
# When remote D TP > local P TP we expect `tp_ratio` reads.
consumers_per_producer
=
(
-
tp_ratio
if
n_consumers
>
self
.
world_size
else
1
)
self
.
consumer_notification_counts_by_req
[
req_id
]
+=
1
self
.
consumer_notification_counts_by_req
[
req_id
]
+=
1
# Wait all consumers (D) to be done reading before freeing.
# Wait all consumers (D) to be done reading before freeing.
if
self
.
consumer_notification_counts_by_req
[
req_id
]
==
int
(
tp_ratio
):
if
(
self
.
consumer_notification_counts_by_req
[
req_id
]
==
consumers_per_producer
):
notified_req_ids
.
add
(
req_id
)
notified_req_ids
.
add
(
req_id
)
del
self
.
consumer_notification_counts_by_req
[
req_id
]
del
self
.
consumer_notification_counts_by_req
[
req_id
]
self
.
_reqs_to_process
.
remove
(
req_id
)
self
.
_reqs_to_process
.
remove
(
req_id
)
...
@@ -1872,7 +1963,7 @@ class NixlConnectorWorker:
...
@@ -1872,7 +1963,7 @@ class NixlConnectorWorker:
"""
"""
done_req_ids
:
set
[
str
]
=
set
()
done_req_ids
:
set
[
str
]
=
set
()
for
req_id
,
handles
in
list
(
transfers
.
items
()):
for
req_id
,
handles
in
list
(
transfers
.
items
()):
in_progress
=
False
in_progress
=
[]
for
handle
in
handles
:
for
handle
in
handles
:
try
:
try
:
xfer_state
=
self
.
nixl_wrapper
.
check_xfer_state
(
handle
)
xfer_state
=
self
.
nixl_wrapper
.
check_xfer_state
(
handle
)
...
@@ -1882,7 +1973,7 @@ class NixlConnectorWorker:
...
@@ -1882,7 +1973,7 @@ class NixlConnectorWorker:
self
.
xfer_stats
.
record_transfer
(
res
)
self
.
xfer_stats
.
record_transfer
(
res
)
self
.
nixl_wrapper
.
release_xfer_handle
(
handle
)
self
.
nixl_wrapper
.
release_xfer_handle
(
handle
)
elif
xfer_state
==
"PROC"
:
elif
xfer_state
==
"PROC"
:
in_progress
=
True
in_progress
.
append
(
handle
)
continue
continue
else
:
else
:
logger
.
error
(
logger
.
error
(
...
@@ -1892,7 +1983,6 @@ class NixlConnectorWorker:
...
@@ -1892,7 +1983,6 @@ class NixlConnectorWorker:
xfer_state
,
xfer_state
,
)
)
self
.
_handle_failed_transfer
(
req_id
,
handle
)
self
.
_handle_failed_transfer
(
req_id
,
handle
)
in_progress
=
False
except
Exception
:
except
Exception
:
logger
.
exception
(
logger
.
exception
(
"NIXL transfer exception for request %s. "
"NIXL transfer exception for request %s. "
...
@@ -1900,11 +1990,13 @@ class NixlConnectorWorker:
...
@@ -1900,11 +1990,13 @@ class NixlConnectorWorker:
req_id
,
req_id
,
)
)
self
.
_handle_failed_transfer
(
req_id
,
handle
)
self
.
_handle_failed_transfer
(
req_id
,
handle
)
in_progress
=
False
if
not
in_progress
:
if
not
in_progress
:
# Only report request as completed when all transfers are done.
done_req_ids
.
add
(
req_id
)
done_req_ids
.
add
(
req_id
)
del
transfers
[
req_id
]
del
transfers
[
req_id
]
else
:
transfers
[
req_id
]
=
in_progress
return
done_req_ids
return
done_req_ids
def
_handle_failed_transfer
(
self
,
req_id
:
str
,
handle
:
int
):
def
_handle_failed_transfer
(
self
,
req_id
:
str
,
handle
:
int
):
...
@@ -1982,18 +2074,62 @@ class NixlConnectorWorker:
...
@@ -1982,18 +2074,62 @@ class NixlConnectorWorker:
def
_read_blocks_for_req
(
self
,
req_id
:
str
,
meta
:
ReqMeta
):
def
_read_blocks_for_req
(
self
,
req_id
:
str
,
meta
:
ReqMeta
):
assert
meta
.
remote
is
not
None
assert
meta
.
remote
is
not
None
logger
.
debug
(
remote_ranks
=
self
.
kv_topo
.
get_target_remote_ranks_from_engine_id
(
"Remote agent %s available, calling _read_blocks for req %s"
,
meta
.
remote
.
engine_id
meta
.
remote
.
engine_id
,
req_id
,
)
self
.
_read_blocks
(
request_id
=
req_id
,
dst_engine_id
=
meta
.
remote
.
engine_id
,
remote_request_id
=
meta
.
remote
.
request_id
,
local_block_ids
=
meta
.
local_physical_block_ids
,
remote_block_ids
=
meta
.
remote
.
block_ids
,
)
)
tp_ratio
=
self
.
kv_topo
.
tp_ratio_from_engine_id
(
meta
.
remote
.
engine_id
)
# D may have to perform multiple reads from different remote ranks.
for
i
,
remote_rank
in
enumerate
(
remote_ranks
):
if
self
.
use_mla
and
tp_ratio
<
0
and
i
>
0
:
# MLA opt: when P TP > D TP, only a single read is executed for
# the first remote rank (cache is duplicated)..
break
remote_block_size
=
self
.
kv_topo
.
remote_block_size
[
meta
.
remote
.
engine_id
]
logger
.
debug
(
"Remote agent %s available, calling _read_blocks"
" on remote rank %s with remote block size %s for req %s"
,
meta
.
remote
.
engine_id
,
remote_rank
,
remote_block_size
,
req_id
,
)
# Get side handles.
if
tp_ratio
<
0
and
not
self
.
use_mla
:
assert
remote_block_size
==
self
.
block_size
# Remote tp_size > local tp_size: we must perform multiple
# reads. Get the memory chunk onto which we will write to.
local_xfer_side_handle
=
self
.
src_xfer_handles_by_tp_ratio
[
tp_ratio
][
i
]
else
:
# Single read from remote, we write to the whole memory region.
# Also handle remote block size different from local block size.
local_xfer_side_handle
=
self
.
src_xfer_handles_by_block_size
[
remote_block_size
]
# Destination handle: remote_engine_id -> remote_rank -> handle.
remote_xfer_side_handle
=
self
.
dst_xfer_side_handles
[
meta
.
remote
.
engine_id
][
remote_rank
]
self
.
_read_blocks
(
request_id
=
req_id
,
dst_engine_id
=
meta
.
remote
.
engine_id
,
remote_request_id
=
meta
.
remote
.
request_id
,
local_block_ids
=
meta
.
local_physical_block_ids
,
remote_block_ids
=
meta
.
remote
.
block_ids
,
remote_rank
=
remote_rank
,
local_xfer_side_handle
=
local_xfer_side_handle
,
remote_xfer_side_handle
=
remote_xfer_side_handle
,
)
if
self
.
use_mla
and
tp_ratio
<
0
:
# ..but we still need to notify the other remote ranks that we
# have the blocks we need so they can update the request state.
notif_id
=
f
"
{
req_id
}
:
{
self
.
world_size
}
"
.
encode
()
remote_agents
=
self
.
_remote_agents
[
meta
.
remote
.
engine_id
]
for
rank_to_notify
,
agent
in
remote_agents
.
items
():
if
rank_to_notify
!=
remote_rank
:
self
.
nixl_wrapper
.
send_notif
(
agent
,
notif_msg
=
notif_id
)
def
_read_blocks
(
def
_read_blocks
(
self
,
self
,
...
@@ -2002,7 +2138,14 @@ class NixlConnectorWorker:
...
@@ -2002,7 +2138,14 @@ class NixlConnectorWorker:
dst_engine_id
:
str
,
dst_engine_id
:
str
,
request_id
:
str
,
request_id
:
str
,
remote_request_id
:
str
,
remote_request_id
:
str
,
remote_rank
:
int
,
local_xfer_side_handle
:
int
,
remote_xfer_side_handle
:
int
,
):
):
"""
Post a READ point-to-point xfer request from a single local worker to
a single remote worker.
"""
block_size_ratio
=
self
.
kv_topo
.
block_size_ratio_from_engine_id
(
dst_engine_id
)
block_size_ratio
=
self
.
kv_topo
.
block_size_ratio_from_engine_id
(
dst_engine_id
)
if
block_size_ratio
>
1
:
if
block_size_ratio
>
1
:
local_block_ids
=
self
.
get_mapped_blocks
(
local_block_ids
=
self
.
get_mapped_blocks
(
...
@@ -2031,18 +2174,14 @@ class NixlConnectorWorker:
...
@@ -2031,18 +2174,14 @@ class NixlConnectorWorker:
# saturate IB with heterogeneous TP sizes. We should remove the staging
# saturate IB with heterogeneous TP sizes. We should remove the staging
# blocks until we are ready.
# blocks until we are ready.
# Number of D TP workers that will read from dst P. Propagate
tp_rati
o
# Number of D TP workers that will read from dst P. Propagate
inf
o
# on notification so that dst worker can wait before freeing blocks.
# on notification so that dst worker can wait before freeing blocks.
tp_ratio
=
self
.
kv_topo
.
tp_ratio_from_engine_id
(
dst_engine_id
)
notif_id
=
f
"
{
remote_request_id
}
:
{
self
.
world_size
}
"
.
encode
()
notif_id
=
f
"
{
remote_request_id
}
:
{
tp_ratio
}
"
.
encode
()
# Full prefix cache hit: do not need to read remote blocks,
# Full prefix cache hit: do not need to read remote blocks,
# just notify P worker that we have the blocks we need.
# just notify P worker that we have the blocks we need.
num_local_blocks
=
len
(
local_block_ids
)
num_local_blocks
=
len
(
local_block_ids
)
if
num_local_blocks
==
0
:
if
num_local_blocks
==
0
:
remote_rank
=
self
.
kv_topo
.
get_target_remote_rank_from_engine_id
(
dst_engine_id
)
agent_name
=
self
.
_remote_agents
[
dst_engine_id
][
remote_rank
]
agent_name
=
self
.
_remote_agents
[
dst_engine_id
][
remote_rank
]
try
:
try
:
self
.
nixl_wrapper
.
send_notif
(
agent_name
,
notif_msg
=
notif_id
)
self
.
nixl_wrapper
.
send_notif
(
agent_name
,
notif_msg
=
notif_id
)
...
@@ -2062,13 +2201,6 @@ class NixlConnectorWorker:
...
@@ -2062,13 +2201,6 @@ class NixlConnectorWorker:
if
num_local_blocks
<
num_remote_blocks
:
if
num_local_blocks
<
num_remote_blocks
:
remote_block_ids
=
remote_block_ids
[
-
num_local_blocks
:]
remote_block_ids
=
remote_block_ids
[
-
num_local_blocks
:]
# Get side handles.
remote_block_size
=
self
.
kv_topo
.
remote_block_size
[
dst_engine_id
]
local_xfer_side_handle
=
self
.
src_xfer_side_handles
.
get
(
remote_block_size
,
self
.
src_xfer_side_handle
)
remote_xfer_side_handle
=
self
.
dst_xfer_side_handles
[
dst_engine_id
]
# NOTE (nicolo) With homogeneous TP, each TP worker loads KV from
# NOTE (nicolo) With homogeneous TP, each TP worker loads KV from
# corresponding rank. With heterogeneous TP, fixing D>P, the D tp
# corresponding rank. With heterogeneous TP, fixing D>P, the D tp
# workers will issue xfers to parts of the P worker remote kv caches.
# workers will issue xfers to parts of the P worker remote kv caches.
...
@@ -2230,7 +2362,7 @@ class NixlConnectorWorker:
...
@@ -2230,7 +2362,7 @@ class NixlConnectorWorker:
block_ids_np
,
self
.
_physical_blocks_per_logical_kv_block
,
block_arange
block_ids_np
,
self
.
_physical_blocks_per_logical_kv_block
,
block_arange
).
tolist
()
).
tolist
()
def
get_backend_aware_kv_block_len
(
self
,
layer_idx
:
int
):
def
get_backend_aware_kv_block_len
(
self
,
layer_idx
:
int
)
->
int
:
"""
"""
Get the block length for one K/V element (K and V have the same size).
Get the block length for one K/V element (K and V have the same size).
...
@@ -2276,11 +2408,16 @@ class NixlConnectorWorker:
...
@@ -2276,11 +2408,16 @@ class NixlConnectorWorker:
for
handle
in
handles
:
for
handle
in
handles
:
self
.
nixl_wrapper
.
release_xfer_handle
(
handle
)
self
.
nixl_wrapper
.
release_xfer_handle
(
handle
)
self
.
_recving_transfers
.
clear
()
self
.
_recving_transfers
.
clear
()
if
self
.
src_xfer_side_handle
:
for
handle
in
self
.
src_xfer_handles_by_block_size
.
values
():
self
.
nixl_wrapper
.
release_dlist_handle
(
self
.
src_xfer_side_handle
)
self
.
nixl_wrapper
.
release_dlist_handle
(
handle
)
self
.
src_xfer_side_handle
=
0
self
.
src_xfer_handles_by_block_size
.
clear
()
for
dst_xfer_side_handle
in
self
.
dst_xfer_side_handles
.
values
():
for
handles
in
self
.
src_xfer_handles_by_tp_ratio
.
values
():
self
.
nixl_wrapper
.
release_dlist_handle
(
dst_xfer_side_handle
)
for
handle
in
handles
:
self
.
nixl_wrapper
.
release_dlist_handle
(
handle
)
self
.
src_xfer_handles_by_tp_ratio
.
clear
()
for
dst_xfer_side_handles
in
self
.
dst_xfer_side_handles
.
values
():
for
dst_xfer_side_handle
in
dst_xfer_side_handles
.
values
():
self
.
nixl_wrapper
.
release_dlist_handle
(
dst_xfer_side_handle
)
self
.
dst_xfer_side_handles
.
clear
()
self
.
dst_xfer_side_handles
.
clear
()
for
remote_agents
in
self
.
_remote_agents
.
values
():
for
remote_agents
in
self
.
_remote_agents
.
values
():
for
agent_name
in
remote_agents
.
values
():
for
agent_name
in
remote_agents
.
values
():
...
...
vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
View file @
a810671a
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections
import
defaultdict
from
collections
import
defaultdict
from
collections.abc
import
Iterable
,
Iterator
from
collections.abc
import
Iterable
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
itertools
import
islice
from
itertools
import
islice
from
typing
import
Any
,
ClassVar
from
typing
import
Any
,
ClassVar
...
@@ -12,6 +12,7 @@ from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
...
@@ -12,6 +12,7 @@ from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
from
vllm.attention.layer
import
Attention
from
vllm.attention.layer
import
Attention
from
vllm.config
import
VllmConfig
,
get_layers_from_vllm_config
from
vllm.config
import
VllmConfig
,
get_layers_from_vllm_config
from
vllm.distributed.kv_events
import
BlockRemoved
,
BlockStored
,
KVCacheEvent
from
vllm.distributed.kv_events
import
BlockRemoved
,
BlockStored
,
KVCacheEvent
from
vllm.distributed.kv_transfer.kv_connector.utils
import
yield_req_data
from
vllm.distributed.kv_transfer.kv_connector.v1
import
(
from
vllm.distributed.kv_transfer.kv_connector.v1
import
(
KVConnectorBase_V1
,
KVConnectorBase_V1
,
KVConnectorRole
,
KVConnectorRole
,
...
@@ -516,23 +517,3 @@ class OffloadingConnectorWorker:
...
@@ -516,23 +517,3 @@ class OffloadingConnectorWorker:
del
self
.
_store_jobs
[
req_id
]
del
self
.
_store_jobs
[
req_id
]
return
finished_sending
,
finished_recving
return
finished_sending
,
finished_recving
def
yield_req_data
(
scheduler_output
,
)
->
Iterator
[
tuple
[
str
,
tuple
[
list
[
int
],
...],
bool
]]:
"""
Yields:
(req_id, new_block_id_groups, preempted)
"""
# new requests
for
req_data
in
scheduler_output
.
scheduled_new_reqs
:
yield
req_data
.
req_id
,
req_data
.
block_ids
,
False
# cached requests
cached_reqs
=
scheduler_output
.
scheduled_cached_reqs
yield
from
zip
(
cached_reqs
.
req_ids
,
cached_reqs
.
new_block_ids
,
(
req_id
in
cached_reqs
.
resumed_req_ids
for
req_id
in
cached_reqs
.
req_ids
),
)
vllm/distributed/parallel_state.py
View file @
a810671a
...
@@ -1007,10 +1007,17 @@ class GroupCoordinator:
...
@@ -1007,10 +1007,17 @@ class GroupCoordinator:
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
is_sequence_parallel
:
bool
=
False
,
is_sequence_parallel
:
bool
=
False
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
extra_tensors
:
list
[
torch
.
Tensor
]
|
None
=
None
,
)
->
(
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
):
if
self
.
device_communicator
is
not
None
:
if
self
.
device_communicator
is
not
None
:
return
self
.
device_communicator
.
dispatch
(
return
self
.
device_communicator
.
dispatch
(
# type: ignore[call-arg]
hidden_states
,
router_logits
,
is_sequence_parallel
hidden_states
,
router_logits
,
is_sequence_parallel
,
extra_tensors
,
)
)
else
:
else
:
return
hidden_states
,
router_logits
return
hidden_states
,
router_logits
...
...
vllm/engine/arg_utils.py
View file @
a810671a
...
@@ -93,6 +93,7 @@ from vllm.transformers_utils.utils import is_cloud_storage
...
@@ -93,6 +93,7 @@ from vllm.transformers_utils.utils import is_cloud_storage
from
vllm.utils.argparse_utils
import
FlexibleArgumentParser
from
vllm.utils.argparse_utils
import
FlexibleArgumentParser
from
vllm.utils.mem_constants
import
GiB_bytes
from
vllm.utils.mem_constants
import
GiB_bytes
from
vllm.utils.network_utils
import
get_ip
from
vllm.utils.network_utils
import
get_ip
from
vllm.utils.torch_utils
import
resolve_kv_cache_dtype_string
from
vllm.v1.sample.logits_processor
import
LogitsProcessor
from
vllm.v1.sample.logits_processor
import
LogitsProcessor
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
...
@@ -106,6 +107,7 @@ else:
...
@@ -106,6 +107,7 @@ else:
LoadFormats
=
Any
LoadFormats
=
Any
UsageContext
=
Any
UsageContext
=
Any
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
# object is used to allow for special typing forms
# object is used to allow for special typing forms
...
@@ -406,8 +408,9 @@ class EngineArgs:
...
@@ -406,8 +408,9 @@ class EngineArgs:
data_parallel_external_lb
:
bool
=
False
data_parallel_external_lb
:
bool
=
False
data_parallel_backend
:
str
=
ParallelConfig
.
data_parallel_backend
data_parallel_backend
:
str
=
ParallelConfig
.
data_parallel_backend
enable_expert_parallel
:
bool
=
ParallelConfig
.
enable_expert_parallel
enable_expert_parallel
:
bool
=
ParallelConfig
.
enable_expert_parallel
all2all_backend
:
str
|
None
=
ParallelConfig
.
all2all_backend
all2all_backend
:
str
=
ParallelConfig
.
all2all_backend
enable_dbo
:
bool
=
ParallelConfig
.
enable_dbo
enable_dbo
:
bool
=
ParallelConfig
.
enable_dbo
ubatch_size
:
int
=
ParallelConfig
.
ubatch_size
dbo_decode_token_threshold
:
int
=
ParallelConfig
.
dbo_decode_token_threshold
dbo_decode_token_threshold
:
int
=
ParallelConfig
.
dbo_decode_token_threshold
dbo_prefill_token_threshold
:
int
=
ParallelConfig
.
dbo_prefill_token_threshold
dbo_prefill_token_threshold
:
int
=
ParallelConfig
.
dbo_prefill_token_threshold
disable_nccl_for_dp_synchronization
:
bool
=
(
disable_nccl_for_dp_synchronization
:
bool
=
(
...
@@ -520,6 +523,7 @@ class EngineArgs:
...
@@ -520,6 +523,7 @@ class EngineArgs:
enable_layerwise_nvtx_tracing
:
bool
=
(
enable_layerwise_nvtx_tracing
:
bool
=
(
ObservabilityConfig
.
enable_layerwise_nvtx_tracing
ObservabilityConfig
.
enable_layerwise_nvtx_tracing
)
)
enable_mfu_metrics
:
bool
=
ObservabilityConfig
.
enable_mfu_metrics
scheduling_policy
:
SchedulerPolicy
=
SchedulerConfig
.
policy
scheduling_policy
:
SchedulerPolicy
=
SchedulerConfig
.
policy
scheduler_cls
:
str
|
type
[
object
]
|
None
=
SchedulerConfig
.
scheduler_cls
scheduler_cls
:
str
|
type
[
object
]
|
None
=
SchedulerConfig
.
scheduler_cls
...
@@ -841,6 +845,10 @@ class EngineArgs:
...
@@ -841,6 +845,10 @@ class EngineArgs:
"--all2all-backend"
,
**
parallel_kwargs
[
"all2all_backend"
]
"--all2all-backend"
,
**
parallel_kwargs
[
"all2all_backend"
]
)
)
parallel_group
.
add_argument
(
"--enable-dbo"
,
**
parallel_kwargs
[
"enable_dbo"
])
parallel_group
.
add_argument
(
"--enable-dbo"
,
**
parallel_kwargs
[
"enable_dbo"
])
parallel_group
.
add_argument
(
"--ubatch-size"
,
**
parallel_kwargs
[
"ubatch_size"
],
)
parallel_group
.
add_argument
(
parallel_group
.
add_argument
(
"--dbo-decode-token-threshold"
,
"--dbo-decode-token-threshold"
,
**
parallel_kwargs
[
"dbo_decode_token_threshold"
],
**
parallel_kwargs
[
"dbo_decode_token_threshold"
],
...
@@ -1035,6 +1043,10 @@ class EngineArgs:
...
@@ -1035,6 +1043,10 @@ class EngineArgs:
"--enable-layerwise-nvtx-tracing"
,
"--enable-layerwise-nvtx-tracing"
,
**
observability_kwargs
[
"enable_layerwise_nvtx_tracing"
],
**
observability_kwargs
[
"enable_layerwise_nvtx_tracing"
],
)
)
observability_group
.
add_argument
(
"--enable-mfu-metrics"
,
**
observability_kwargs
[
"enable_mfu_metrics"
],
)
# Scheduler arguments
# Scheduler arguments
scheduler_kwargs
=
get_kwargs
(
SchedulerConfig
)
scheduler_kwargs
=
get_kwargs
(
SchedulerConfig
)
...
@@ -1356,12 +1368,17 @@ class EngineArgs:
...
@@ -1356,12 +1368,17 @@ class EngineArgs:
f
"dcp_size=
{
self
.
decode_context_parallel_size
}
."
f
"dcp_size=
{
self
.
decode_context_parallel_size
}
."
)
)
# Resolve "auto" kv_cache_dtype to actual value from model config
resolved_cache_dtype
=
resolve_kv_cache_dtype_string
(
self
.
kv_cache_dtype
,
model_config
)
cache_config
=
CacheConfig
(
cache_config
=
CacheConfig
(
block_size
=
self
.
block_size
,
block_size
=
self
.
block_size
,
gpu_memory_utilization
=
self
.
gpu_memory_utilization
,
gpu_memory_utilization
=
self
.
gpu_memory_utilization
,
kv_cache_memory_bytes
=
self
.
kv_cache_memory_bytes
,
kv_cache_memory_bytes
=
self
.
kv_cache_memory_bytes
,
swap_space
=
self
.
swap_space
,
swap_space
=
self
.
swap_space
,
cache_dtype
=
self
.
kv
_cache_dtype
,
cache_dtype
=
resolved
_cache_dtype
,
is_attention_free
=
model_config
.
is_attention_free
,
is_attention_free
=
model_config
.
is_attention_free
,
num_gpu_blocks_override
=
self
.
num_gpu_blocks_override
,
num_gpu_blocks_override
=
self
.
num_gpu_blocks_override
,
sliding_window
=
sliding_window
,
sliding_window
=
sliding_window
,
...
@@ -1557,6 +1574,7 @@ class EngineArgs:
...
@@ -1557,6 +1574,7 @@ class EngineArgs:
enable_expert_parallel
=
self
.
enable_expert_parallel
,
enable_expert_parallel
=
self
.
enable_expert_parallel
,
all2all_backend
=
self
.
all2all_backend
,
all2all_backend
=
self
.
all2all_backend
,
enable_dbo
=
self
.
enable_dbo
,
enable_dbo
=
self
.
enable_dbo
,
ubatch_size
=
self
.
ubatch_size
,
dbo_decode_token_threshold
=
self
.
dbo_decode_token_threshold
,
dbo_decode_token_threshold
=
self
.
dbo_decode_token_threshold
,
dbo_prefill_token_threshold
=
self
.
dbo_prefill_token_threshold
,
dbo_prefill_token_threshold
=
self
.
dbo_prefill_token_threshold
,
disable_nccl_for_dp_synchronization
=
self
.
disable_nccl_for_dp_synchronization
,
disable_nccl_for_dp_synchronization
=
self
.
disable_nccl_for_dp_synchronization
,
...
@@ -1676,6 +1694,7 @@ class EngineArgs:
...
@@ -1676,6 +1694,7 @@ class EngineArgs:
kv_cache_metrics_sample
=
self
.
kv_cache_metrics_sample
,
kv_cache_metrics_sample
=
self
.
kv_cache_metrics_sample
,
cudagraph_metrics
=
self
.
cudagraph_metrics
,
cudagraph_metrics
=
self
.
cudagraph_metrics
,
enable_layerwise_nvtx_tracing
=
self
.
enable_layerwise_nvtx_tracing
,
enable_layerwise_nvtx_tracing
=
self
.
enable_layerwise_nvtx_tracing
,
enable_mfu_metrics
=
self
.
enable_mfu_metrics
,
)
)
# Compilation config overrides
# Compilation config overrides
...
...
vllm/entrypoints/context.py
View file @
a810671a
...
@@ -2,11 +2,13 @@
...
@@ -2,11 +2,13 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
asyncio
import
asyncio
import
contextlib
import
contextlib
import
copy
import
json
import
json
import
logging
import
logging
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
collections.abc
import
Callable
from
collections.abc
import
Callable
from
contextlib
import
AsyncExitStack
from
contextlib
import
AsyncExitStack
from
dataclasses
import
replace
from
typing
import
TYPE_CHECKING
,
Union
from
typing
import
TYPE_CHECKING
,
Union
from
openai.types.responses.response_function_tool_call_output_item
import
(
from
openai.types.responses.response_function_tool_call_output_item
import
(
...
@@ -164,6 +166,12 @@ class SimpleContext(ConversationContext):
...
@@ -164,6 +166,12 @@ class SimpleContext(ConversationContext):
def
__init__
(
self
):
def
__init__
(
self
):
self
.
last_output
=
None
self
.
last_output
=
None
# Accumulated final output for streaming mode
self
.
_accumulated_text
:
str
=
""
self
.
_accumulated_token_ids
:
list
[
int
]
=
[]
self
.
_accumulated_logprobs
:
list
=
[]
self
.
num_prompt_tokens
=
0
self
.
num_prompt_tokens
=
0
self
.
num_output_tokens
=
0
self
.
num_output_tokens
=
0
self
.
num_cached_tokens
=
0
self
.
num_cached_tokens
=
0
...
@@ -183,6 +191,13 @@ class SimpleContext(ConversationContext):
...
@@ -183,6 +191,13 @@ class SimpleContext(ConversationContext):
self
.
num_cached_tokens
=
output
.
num_cached_tokens
or
0
self
.
num_cached_tokens
=
output
.
num_cached_tokens
or
0
self
.
num_output_tokens
+=
len
(
output
.
outputs
[
0
].
token_ids
or
[])
self
.
num_output_tokens
+=
len
(
output
.
outputs
[
0
].
token_ids
or
[])
# Accumulate text, token_ids, and logprobs for streaming mode
delta_output
=
output
.
outputs
[
0
]
self
.
_accumulated_text
+=
delta_output
.
text
self
.
_accumulated_token_ids
.
extend
(
delta_output
.
token_ids
)
if
delta_output
.
logprobs
is
not
None
:
self
.
_accumulated_logprobs
.
extend
(
delta_output
.
logprobs
)
if
len
(
self
.
input_messages
)
==
0
:
if
len
(
self
.
input_messages
)
==
0
:
output_prompt
=
output
.
prompt
or
""
output_prompt
=
output
.
prompt
or
""
output_prompt_token_ids
=
output
.
prompt_token_ids
or
[]
output_prompt_token_ids
=
output
.
prompt_token_ids
or
[]
...
@@ -194,11 +209,26 @@ class SimpleContext(ConversationContext):
...
@@ -194,11 +209,26 @@ class SimpleContext(ConversationContext):
)
)
self
.
output_messages
.
append
(
self
.
output_messages
.
append
(
ResponseRawMessageAndToken
(
ResponseRawMessageAndToken
(
message
=
output
.
output
s
[
0
]
.
text
,
message
=
delta_
output
.
text
,
tokens
=
output
.
output
s
[
0
]
.
token_ids
,
tokens
=
delta_
output
.
token_ids
,
)
)
)
)
@
property
def
final_output
(
self
)
->
RequestOutput
|
None
:
"""Return the final output, with complete text/token_ids/logprobs."""
if
self
.
last_output
is
not
None
and
self
.
last_output
.
outputs
:
assert
isinstance
(
self
.
last_output
,
RequestOutput
)
final_output
=
copy
.
copy
(
self
.
last_output
)
# copy inner item to avoid modify last_output
final_output
.
outputs
=
[
replace
(
item
)
for
item
in
self
.
last_output
.
outputs
]
final_output
.
outputs
[
0
].
text
=
self
.
_accumulated_text
final_output
.
outputs
[
0
].
token_ids
=
tuple
(
self
.
_accumulated_token_ids
)
if
self
.
_accumulated_logprobs
:
final_output
.
outputs
[
0
].
logprobs
=
self
.
_accumulated_logprobs
return
final_output
return
self
.
last_output
def
append_tool_output
(
self
,
output
)
->
None
:
def
append_tool_output
(
self
,
output
)
->
None
:
raise
NotImplementedError
(
"Should not be called."
)
raise
NotImplementedError
(
"Should not be called."
)
...
@@ -267,12 +297,40 @@ class ParsableContext(ConversationContext):
...
@@ -267,12 +297,40 @@ class ParsableContext(ConversationContext):
self
.
chat_template
=
chat_template
self
.
chat_template
=
chat_template
self
.
chat_template_content_format
=
chat_template_content_format
self
.
chat_template_content_format
=
chat_template_content_format
self
.
input_messages
:
list
[
ResponseRawMessageAndToken
]
=
[]
self
.
output_messages
:
list
[
ResponseRawMessageAndToken
]
=
[]
def
append_output
(
self
,
output
:
RequestOutput
)
->
None
:
def
append_output
(
self
,
output
:
RequestOutput
)
->
None
:
self
.
num_prompt_tokens
=
len
(
output
.
prompt_token_ids
or
[])
self
.
num_prompt_tokens
=
len
(
output
.
prompt_token_ids
or
[])
self
.
num_cached_tokens
=
output
.
num_cached_tokens
or
0
self
.
num_cached_tokens
=
output
.
num_cached_tokens
or
0
self
.
num_output_tokens
+=
len
(
output
.
outputs
[
0
].
token_ids
or
[])
self
.
num_output_tokens
+=
len
(
output
.
outputs
[
0
].
token_ids
or
[])
self
.
parser
.
process
(
output
.
outputs
[
0
])
self
.
parser
.
process
(
output
.
outputs
[
0
])
# only store if enable_response_messages is True, save memory
if
self
.
request
.
enable_response_messages
:
output_prompt
=
output
.
prompt
or
""
output_prompt_token_ids
=
output
.
prompt_token_ids
or
[]
if
len
(
self
.
input_messages
)
==
0
:
self
.
input_messages
.
append
(
ResponseRawMessageAndToken
(
message
=
output_prompt
,
tokens
=
output_prompt_token_ids
,
)
)
else
:
self
.
output_messages
.
append
(
ResponseRawMessageAndToken
(
message
=
output_prompt
,
tokens
=
output_prompt_token_ids
,
)
)
self
.
output_messages
.
append
(
ResponseRawMessageAndToken
(
message
=
output
.
outputs
[
0
].
text
,
tokens
=
output
.
outputs
[
0
].
token_ids
,
)
)
def
append_tool_output
(
self
,
output
:
list
[
ResponseInputOutputItem
])
->
None
:
def
append_tool_output
(
self
,
output
:
list
[
ResponseInputOutputItem
])
->
None
:
self
.
parser
.
response_messages
.
extend
(
output
)
self
.
parser
.
response_messages
.
extend
(
output
)
...
...
vllm/entrypoints/llm.py
View file @
a810671a
...
@@ -18,6 +18,7 @@ from vllm.beam_search import (
...
@@ -18,6 +18,7 @@ from vllm.beam_search import (
create_sort_beams_key_function
,
create_sort_beams_key_function
,
)
)
from
vllm.config
import
(
from
vllm.config
import
(
AttentionConfig
,
CompilationConfig
,
CompilationConfig
,
PoolerConfig
,
PoolerConfig
,
ProfilerConfig
,
ProfilerConfig
,
...
@@ -175,6 +176,10 @@ class LLM:
...
@@ -175,6 +176,10 @@ class LLM:
compilation_config: Either an integer or a dictionary. If it is an
compilation_config: Either an integer or a dictionary. If it is an
integer, it is used as the mode of compilation optimization. If it
integer, it is used as the mode of compilation optimization. If it
is a dictionary, it can specify the full compilation configuration.
is a dictionary, it can specify the full compilation configuration.
attention_config: Configuration for attention mechanisms. Can be a
dictionary or an AttentionConfig instance. If a dictionary, it will
be converted to an AttentionConfig. Allows specifying the attention
backend and other attention-related settings.
**kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs].
**kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs].
Note:
Note:
...
@@ -213,6 +218,7 @@ class LLM:
...
@@ -213,6 +218,7 @@ class LLM:
|
StructuredOutputsConfig
|
StructuredOutputsConfig
|
None
=
None
,
|
None
=
None
,
profiler_config
:
dict
[
str
,
Any
]
|
ProfilerConfig
|
None
=
None
,
profiler_config
:
dict
[
str
,
Any
]
|
ProfilerConfig
|
None
=
None
,
attention_config
:
dict
[
str
,
Any
]
|
AttentionConfig
|
None
=
None
,
kv_cache_memory_bytes
:
int
|
None
=
None
,
kv_cache_memory_bytes
:
int
|
None
=
None
,
compilation_config
:
int
|
dict
[
str
,
Any
]
|
CompilationConfig
|
None
=
None
,
compilation_config
:
int
|
dict
[
str
,
Any
]
|
CompilationConfig
|
None
=
None
,
logits_processors
:
list
[
str
|
type
[
LogitsProcessor
]]
|
None
=
None
,
logits_processors
:
list
[
str
|
type
[
LogitsProcessor
]]
|
None
=
None
,
...
@@ -252,51 +258,28 @@ class LLM:
...
@@ -252,51 +258,28 @@ class LLM:
if
hf_overrides
is
None
:
if
hf_overrides
is
None
:
hf_overrides
=
{}
hf_overrides
=
{}
if
compilation_config
is
not
None
:
def
_make_config
(
value
:
Any
,
cls
:
type
[
_R
])
->
_R
:
if
isinstance
(
compilation_config
,
int
):
"""Convert dict/None/instance to a config instance."""
compilation_config_instance
=
CompilationConfig
(
if
value
is
None
:
mode
=
CompilationMode
(
compilation_config
)
return
cls
()
)
if
isinstance
(
value
,
dict
):
elif
isinstance
(
compilation_config
,
dict
):
return
cls
(
**
{
k
:
v
for
k
,
v
in
value
.
items
()
if
is_init_field
(
cls
,
k
)})
# type: ignore[arg-type]
compilation_config_instance
=
CompilationConfig
(
return
value
**
{
k
:
v
if
isinstance
(
compilation_config
,
int
):
for
k
,
v
in
compilation_config
.
items
()
compilation_config_instance
=
CompilationConfig
(
if
is_init_field
(
CompilationConfig
,
k
)
mode
=
CompilationMode
(
compilation_config
)
}
)
)
else
:
compilation_config_instance
=
compilation_config
else
:
compilation_config_instance
=
CompilationConfig
()
if
structured_outputs_config
is
not
None
:
if
isinstance
(
structured_outputs_config
,
dict
):
structured_outputs_instance
=
StructuredOutputsConfig
(
**
{
k
:
v
for
k
,
v
in
structured_outputs_config
.
items
()
if
is_init_field
(
StructuredOutputsConfig
,
k
)
}
)
else
:
structured_outputs_instance
=
structured_outputs_config
else
:
structured_outputs_instance
=
StructuredOutputsConfig
()
if
profiler_config
is
not
None
:
if
isinstance
(
profiler_config
,
dict
):
profiler_config_instance
=
ProfilerConfig
(
**
{
k
:
v
for
k
,
v
in
profiler_config
.
items
()
if
is_init_field
(
ProfilerConfig
,
k
)
}
)
else
:
profiler_config_instance
=
profiler_config
else
:
else
:
profiler_config_instance
=
ProfilerConfig
()
compilation_config_instance
=
_make_config
(
compilation_config
,
CompilationConfig
)
structured_outputs_instance
=
_make_config
(
structured_outputs_config
,
StructuredOutputsConfig
)
profiler_config_instance
=
_make_config
(
profiler_config
,
ProfilerConfig
)
attention_config_instance
=
_make_config
(
attention_config
,
AttentionConfig
)
# warn about single-process data parallel usage.
# warn about single-process data parallel usage.
_dp_size
=
int
(
kwargs
.
get
(
"data_parallel_size"
,
1
))
_dp_size
=
int
(
kwargs
.
get
(
"data_parallel_size"
,
1
))
...
@@ -341,6 +324,7 @@ class LLM:
...
@@ -341,6 +324,7 @@ class LLM:
pooler_config
=
pooler_config
,
pooler_config
=
pooler_config
,
structured_outputs_config
=
structured_outputs_instance
,
structured_outputs_config
=
structured_outputs_instance
,
profiler_config
=
profiler_config_instance
,
profiler_config
=
profiler_config_instance
,
attention_config
=
attention_config_instance
,
compilation_config
=
compilation_config_instance
,
compilation_config
=
compilation_config_instance
,
logits_processors
=
logits_processors
,
logits_processors
=
logits_processors
,
**
kwargs
,
**
kwargs
,
...
...
vllm/entrypoints/openai/api_server.py
View file @
a810671a
...
@@ -17,21 +17,20 @@ from argparse import Namespace
...
@@ -17,21 +17,20 @@ from argparse import Namespace
from
collections.abc
import
AsyncGenerator
,
AsyncIterator
,
Awaitable
from
collections.abc
import
AsyncGenerator
,
AsyncIterator
,
Awaitable
from
contextlib
import
asynccontextmanager
from
contextlib
import
asynccontextmanager
from
http
import
HTTPStatus
from
http
import
HTTPStatus
from
typing
import
Annotated
,
Any
,
Literal
from
typing
import
Annotated
,
Any
import
model_hosting_container_standards.sagemaker
as
sagemaker_standards
import
model_hosting_container_standards.sagemaker
as
sagemaker_standards
import
pydantic
import
pydantic
import
uvloop
import
uvloop
from
fastapi
import
APIRouter
,
Depends
,
FastAPI
,
Form
,
HTTPException
,
Query
,
Request
from
fastapi
import
APIRouter
,
Depends
,
FastAPI
,
Form
,
HTTPException
,
Request
from
fastapi.exceptions
import
RequestValidationError
from
fastapi.exceptions
import
RequestValidationError
from
fastapi.middleware.cors
import
CORSMiddleware
from
fastapi.middleware.cors
import
CORSMiddleware
from
fastapi.responses
import
JSONResponse
,
Response
,
StreamingResponse
from
fastapi.responses
import
JSONResponse
,
StreamingResponse
from
starlette.concurrency
import
iterate_in_threadpool
from
starlette.concurrency
import
iterate_in_threadpool
from
starlette.datastructures
import
URL
,
Headers
,
MutableHeaders
,
State
from
starlette.datastructures
import
URL
,
Headers
,
MutableHeaders
,
State
from
starlette.types
import
ASGIApp
,
Message
,
Receive
,
Scope
,
Send
from
starlette.types
import
ASGIApp
,
Message
,
Receive
,
Scope
,
Send
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.config
import
VllmConfig
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.protocol
import
EngineClient
from
vllm.engine.protocol
import
EngineClient
from
vllm.entrypoints.anthropic.protocol
import
(
from
vllm.entrypoints.anthropic.protocol
import
(
...
@@ -639,97 +638,6 @@ async def create_translations(
...
@@ -639,97 +638,6 @@ async def create_translations(
return
StreamingResponse
(
content
=
generator
,
media_type
=
"text/event-stream"
)
return
StreamingResponse
(
content
=
generator
,
media_type
=
"text/event-stream"
)
if
envs
.
VLLM_SERVER_DEV_MODE
:
logger
.
warning
(
"SECURITY WARNING: Development endpoints are enabled! "
"This should NOT be used in production!"
)
PydanticVllmConfig
=
pydantic
.
TypeAdapter
(
VllmConfig
)
@
router
.
get
(
"/server_info"
)
async
def
show_server_info
(
raw_request
:
Request
,
config_format
:
Annotated
[
Literal
[
"text"
,
"json"
],
Query
()]
=
"text"
,
):
vllm_config
:
VllmConfig
=
raw_request
.
app
.
state
.
vllm_config
server_info
=
{
"vllm_config"
:
str
(
vllm_config
)
if
config_format
==
"text"
else
PydanticVllmConfig
.
dump_python
(
vllm_config
,
mode
=
"json"
,
fallback
=
str
)
# fallback=str is needed to handle e.g. torch.dtype
}
return
JSONResponse
(
content
=
server_info
)
@
router
.
post
(
"/reset_prefix_cache"
)
async
def
reset_prefix_cache
(
raw_request
:
Request
,
reset_running_requests
:
bool
=
Query
(
default
=
False
),
reset_external
:
bool
=
Query
(
default
=
False
),
):
"""
Reset the local prefix cache.
Optionally, if the query parameter `reset_external=true`
also resets the external (connector-managed) prefix cache.
Note that we currently do not check if the prefix cache
is successfully reset in the API server.
Example:
POST /reset_prefix_cache?reset_external=true
"""
logger
.
info
(
"Resetting prefix cache..."
)
await
engine_client
(
raw_request
).
reset_prefix_cache
(
reset_running_requests
,
reset_external
)
return
Response
(
status_code
=
200
)
@
router
.
post
(
"/reset_mm_cache"
)
async
def
reset_mm_cache
(
raw_request
:
Request
):
"""
Reset the multi-modal cache. Note that we currently do not check if the
multi-modal cache is successfully reset in the API server.
"""
logger
.
info
(
"Resetting multi-modal cache..."
)
await
engine_client
(
raw_request
).
reset_mm_cache
()
return
Response
(
status_code
=
200
)
@
router
.
post
(
"/collective_rpc"
)
async
def
collective_rpc
(
raw_request
:
Request
):
try
:
body
=
await
raw_request
.
json
()
except
json
.
JSONDecodeError
as
e
:
raise
HTTPException
(
status_code
=
HTTPStatus
.
BAD_REQUEST
.
value
,
detail
=
f
"JSON decode error:
{
e
}
"
,
)
from
e
method
=
body
.
get
(
"method"
)
if
method
is
None
:
raise
HTTPException
(
status_code
=
HTTPStatus
.
BAD_REQUEST
.
value
,
detail
=
"Missing 'method' in request body"
,
)
# For security reason, only serialized string args/kwargs are passed.
# User-defined `method` is responsible for deserialization if needed.
args
:
list
[
str
]
=
body
.
get
(
"args"
,
[])
kwargs
:
dict
[
str
,
str
]
=
body
.
get
(
"kwargs"
,
{})
timeout
:
float
|
None
=
body
.
get
(
"timeout"
)
results
=
await
engine_client
(
raw_request
).
collective_rpc
(
method
=
method
,
timeout
=
timeout
,
args
=
tuple
(
args
),
kwargs
=
kwargs
)
if
results
is
None
:
return
Response
(
status_code
=
200
)
response
:
list
[
Any
]
=
[]
for
result
in
results
:
if
result
is
None
or
isinstance
(
result
,
dict
|
list
):
response
.
append
(
result
)
else
:
response
.
append
(
str
(
result
))
return
JSONResponse
(
content
=
{
"results"
:
response
})
def
load_log_config
(
log_config_file
:
str
|
None
)
->
dict
|
None
:
def
load_log_config
(
log_config_file
:
str
|
None
)
->
dict
|
None
:
if
not
log_config_file
:
if
not
log_config_file
:
return
None
return
None
...
@@ -1174,6 +1082,9 @@ async def init_app_state(
...
@@ -1174,6 +1082,9 @@ async def init_app_state(
if
"generate"
in
supported_tasks
if
"generate"
in
supported_tasks
else
None
else
None
)
)
# Warm up chat template processing to avoid first-request latency
if
state
.
openai_serving_chat
is
not
None
:
await
state
.
openai_serving_chat
.
warmup
()
state
.
openai_serving_completion
=
(
state
.
openai_serving_completion
=
(
OpenAIServingCompletion
(
OpenAIServingCompletion
(
engine_client
,
engine_client
,
...
...
vllm/entrypoints/openai/parser/responses_parser.py
View file @
a810671a
...
@@ -3,7 +3,11 @@
...
@@ -3,7 +3,11 @@
import
logging
import
logging
from
collections.abc
import
Callable
from
collections.abc
import
Callable
from
openai.types.responses.response_function_tool_call
import
ResponseFunctionToolCall
from
openai.types.responses
import
ResponseFunctionToolCall
,
ResponseOutputItem
from
openai.types.responses.response_function_tool_call_output_item
import
(
ResponseFunctionToolCallOutputItem
,
)
from
openai.types.responses.response_output_item
import
McpCall
from
openai.types.responses.response_output_message
import
ResponseOutputMessage
from
openai.types.responses.response_output_message
import
ResponseOutputMessage
from
openai.types.responses.response_output_text
import
ResponseOutputText
from
openai.types.responses.response_output_text
import
ResponseOutputText
from
openai.types.responses.response_reasoning_item
import
(
from
openai.types.responses.response_reasoning_item
import
(
...
@@ -11,6 +15,7 @@ from openai.types.responses.response_reasoning_item import (
...
@@ -11,6 +15,7 @@ from openai.types.responses.response_reasoning_item import (
ResponseReasoningItem
,
ResponseReasoningItem
,
)
)
from
vllm.entrypoints.constants
import
MCP_PREFIX
from
vllm.entrypoints.openai.protocol
import
ResponseInputOutputItem
,
ResponsesRequest
from
vllm.entrypoints.openai.protocol
import
ResponseInputOutputItem
,
ResponsesRequest
from
vllm.outputs
import
CompletionOutput
from
vllm.outputs
import
CompletionOutput
from
vllm.reasoning.abs_reasoning_parsers
import
ReasoningParser
from
vllm.reasoning.abs_reasoning_parsers
import
ReasoningParser
...
@@ -111,6 +116,37 @@ class ResponsesParser:
...
@@ -111,6 +116,37 @@ class ResponsesParser:
return
self
return
self
def
make_response_output_items_from_parsable_context
(
self
,
)
->
list
[
ResponseOutputItem
]:
"""Given a list of sentences, construct ResponseOutput Items."""
response_messages
=
self
.
response_messages
[
self
.
num_init_messages
:]
output_messages
:
list
[
ResponseOutputItem
]
=
[]
for
message
in
response_messages
:
if
not
isinstance
(
message
,
ResponseFunctionToolCallOutputItem
):
output_messages
.
append
(
message
)
else
:
if
len
(
output_messages
)
==
0
:
raise
ValueError
(
"Cannot have a FunctionToolCallOutput before FunctionToolCall."
)
if
isinstance
(
output_messages
[
-
1
],
ResponseFunctionToolCall
):
mcp_message
=
McpCall
(
id
=
f
"
{
MCP_PREFIX
}{
random_uuid
()
}
"
,
arguments
=
output_messages
[
-
1
].
arguments
,
name
=
output_messages
[
-
1
].
name
,
server_label
=
output_messages
[
-
1
].
name
,
# TODO: store the server label
type
=
"mcp_call"
,
status
=
"completed"
,
output
=
message
.
output
,
# TODO: support error output
)
output_messages
[
-
1
]
=
mcp_message
return
output_messages
def
get_responses_parser_for_simple_context
(
def
get_responses_parser_for_simple_context
(
*
,
*
,
...
...
Prev
1
…
5
6
7
8
9
10
11
12
13
…
15
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment