Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1a135a9d
Commit
1a135a9d
authored
Oct 10, 2025
by
maxiao1
Committed by
lizhigong
Oct 10, 2025
Browse files
token split by token adapt to pd separation & p2p can be used async
parent
fc5bfc66
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
35 additions
and
43 deletions
+35
-43
vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
...ted/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
+19
-21
vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py
...ibuted/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py
+6
-3
vllm/envs.py
vllm/envs.py
+4
-0
vllm/two_batch_overlap/v1/model_input_split_v1.py
vllm/two_batch_overlap/v1/model_input_split_v1.py
+2
-2
vllm/two_batch_overlap/v1/two_batch_overlap_v1.py
vllm/two_batch_overlap/v1/two_batch_overlap_v1.py
+2
-15
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+2
-2
No files found.
vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
View file @
1a135a9d
...
@@ -18,7 +18,7 @@ from vllm.forward_context import get_forward_context
...
@@ -18,7 +18,7 @@ from vllm.forward_context import get_forward_context
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.v1.attention.backends.mla.common
import
MLACommonMetadata
from
vllm.v1.attention.backends.mla.common
import
MLACommonMetadata
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.two_batch_overlap.v1.two_batch_overlap_v1
import
tbo_get_done_event
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.forward_context
import
ForwardContext
from
vllm.forward_context
import
ForwardContext
...
@@ -36,6 +36,7 @@ class ReqMeta:
...
@@ -36,6 +36,7 @@ class ReqMeta:
token_ids
:
torch
.
Tensor
token_ids
:
torch
.
Tensor
# Slot mappings, should have the same length as token_ids
# Slot mappings, should have the same length as token_ids
slot_mapping
:
torch
.
Tensor
slot_mapping
:
torch
.
Tensor
slot_mapping_device
:
torch
.
Tensor
=
None
@
staticmethod
@
staticmethod
def
make_meta
(
request_id
:
str
,
token_ids
:
list
[
int
],
block_ids
:
list
[
int
],
def
make_meta
(
request_id
:
str
,
token_ids
:
list
[
int
],
block_ids
:
list
[
int
],
...
@@ -273,9 +274,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
...
@@ -273,9 +274,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
Assume the shape of the layer is (2, num_pages, page_size, xxx)
Assume the shape of the layer is (2, num_pages, page_size, xxx)
if MLA is not used, and (num_pages, page_size, xxx) otherwise.
if MLA is not used, and (num_pages, page_size, xxx) otherwise.
"""
"""
# if envs.VLLM_ENABLE_TBO:
# slot_mapping = slot_mapping.pin_memory().to(device=layer.device, non_blocking=True)
if
isinstance
(
attn_metadata
,
MLACommonMetadata
):
if
isinstance
(
attn_metadata
,
MLACommonMetadata
):
num_pages
,
page_size
=
layer
.
shape
[
0
],
layer
.
shape
[
1
]
num_pages
,
page_size
=
layer
.
shape
[
0
],
layer
.
shape
[
1
]
return
layer
.
reshape
(
num_pages
*
page_size
,
-
1
)[
slot_mapping
,
return
layer
.
reshape
(
num_pages
*
page_size
,
-
1
)[
slot_mapping
,
...
@@ -287,41 +286,40 @@ class P2pNcclConnector(KVConnectorBase_V1):
...
@@ -287,41 +286,40 @@ class P2pNcclConnector(KVConnectorBase_V1):
connector_metadata
=
self
.
_get_connector_metadata
()
connector_metadata
=
self
.
_get_connector_metadata
()
assert
isinstance
(
connector_metadata
,
P2pNcclConnectorMetadata
)
assert
isinstance
(
connector_metadata
,
P2pNcclConnectorMetadata
)
if
envs
.
VLLM_ENABLE_TBO
:
if
envs
.
VLLM_ENABLE_TBO
or
envs
.
VLLM_P2P_ASYNC
:
send_stream
=
self
.
p2p_nccl_engine
.
send_stream
for
request
in
connector_metadata
.
requests
:
for
request
in
connector_metadata
.
requests
:
request_id
=
request
.
request_id
request_id
=
request
.
request_id
ip
,
port
=
self
.
parse_request_id
(
request_id
,
True
)
ip
,
port
=
self
.
parse_request_id
(
request_id
,
True
)
remote_address
=
ip
+
":"
+
str
(
port
+
self
.
_rank
)
remote_address
=
ip
+
":"
+
str
(
port
+
self
.
_rank
)
slot_mapping
=
request
.
slot_mapping
kv_cache
=
extract_kv_from_layer
(
kv_layer
,
request
.
slot_mapping
)
if
request
.
slot_mapping
_device
is
None
:
# tbo_evt = torch.cuda.Event(enable_timing=False)
request
.
slot_mapping_device
=
\
# tbo_evt.record()
request
.
slot_mapping
.
pin_memory
().
to
(
device
=
kv_layer
.
device
,
non_blocking
=
True
)
# with torch.cuda.stream(send_stream):
slot_mapping
=
request
.
slot_mapping_device
# send_stream.wait_event(tbo_evt) # 等 TBO all_reduce_stream 完成本轮
kv_cache
=
extract_kv_from_layer
(
kv_layer
,
slot_mapping
)
# kv_cache.record_stream(send_stream
)
tbo_evt
=
torch
.
cuda
.
Event
(
enable_timing
=
False
)
tbo_evt
.
record
()
pp_rank
=
(
self
.
parallel_config
.
rank
//
pp_rank
=
(
self
.
parallel_config
.
rank
//
self
.
parallel_config
.
tensor_parallel_size
)
%
\
self
.
parallel_config
.
tensor_parallel_size
)
%
\
self
.
parallel_config
.
pipeline_parallel_size
self
.
parallel_config
.
pipeline_parallel_size
if
(
self
.
pp_size
==
1
):
if
(
self
.
pp_size
==
1
):
self
.
p2p_nccl_engine
.
send_tensor
(
request_id
+
"#"
+
layer_name
,
self
.
p2p_nccl_engine
.
send_tensor
(
request_id
+
"#"
+
layer_name
,
kv_cache
,
remote_address
)
kv_cache
,
remote_address
,
tbo_evt
)
elif
(
self
.
pp_size
==
2
):
elif
(
self
.
pp_size
==
2
):
if
(
pp_rank
==
0
):
if
(
pp_rank
==
0
):
self
.
p2p_nccl_engine
.
send_tensor
(
request_id
+
"#"
+
layer_name
,
self
.
p2p_nccl_engine
.
send_tensor
(
request_id
+
"#"
+
layer_name
,
kv_cache
,
remote_address
)
kv_cache
,
remote_address
,
tbo_evt
)
self
.
p2p_nccl_engine
.
send_tensor
(
request_id
+
"#"
+
layer_name
,
self
.
p2p_nccl_engine
.
send_tensor
(
request_id
+
"#"
+
layer_name
,
kv_cache
,
ip
+
":"
+
str
(
port
+
self
.
_rank
+
4
))
kv_cache
,
ip
+
":"
+
str
(
port
+
self
.
_rank
+
4
)
,
tbo_evt
)
else
:
else
:
self
.
p2p_nccl_engine
.
send_tensor
(
request_id
+
"#"
+
layer_name
,
self
.
p2p_nccl_engine
.
send_tensor
(
request_id
+
"#"
+
layer_name
,
kv_cache
,
remote_address
)
kv_cache
,
remote_address
,
tbo_evt
)
self
.
p2p_nccl_engine
.
send_tensor
(
request_id
+
"#"
+
layer_name
,
self
.
p2p_nccl_engine
.
send_tensor
(
request_id
+
"#"
+
layer_name
,
kv_cache
,
ip
+
":"
+
str
(
port
+
self
.
_rank
-
4
))
kv_cache
,
ip
+
":"
+
str
(
port
+
self
.
_rank
-
4
)
,
tbo_evt
)
elif
(
self
.
pp_size
==
8
):
elif
(
self
.
pp_size
==
8
):
for
i
in
range
(
8
):
for
i
in
range
(
8
):
self
.
p2p_nccl_engine
.
send_tensor
(
request_id
+
"#"
+
layer_name
,
self
.
p2p_nccl_engine
.
send_tensor
(
request_id
+
"#"
+
layer_name
,
kv_cache
,
ip
+
":"
+
str
(
port
+
i
))
kv_cache
,
ip
+
":"
+
str
(
port
+
i
)
,
tbo_evt
)
else
:
else
:
print
(
"Error: only suppprt pp1 pp2 pp8!!!!!!"
)
print
(
"Error: only suppprt pp1 pp2 pp8!!!!!!"
)
else
:
else
:
...
@@ -330,7 +328,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
...
@@ -330,7 +328,7 @@ class P2pNcclConnector(KVConnectorBase_V1):
ip
,
port
=
self
.
parse_request_id
(
request_id
,
True
)
ip
,
port
=
self
.
parse_request_id
(
request_id
,
True
)
remote_address
=
ip
+
":"
+
str
(
port
+
self
.
_rank
)
remote_address
=
ip
+
":"
+
str
(
port
+
self
.
_rank
)
kv_cache
=
extract_kv_from_layer
(
kv_layer
,
request
.
slot_mapping
)
kv_cache
=
extract_kv_from_layer
(
kv_layer
,
request
.
slot_mapping
)
pp_rank
=
(
self
.
parallel_config
.
rank
//
self
.
parallel_config
.
tensor_parallel_size
pp_rank
=
(
self
.
parallel_config
.
rank
//
self
.
parallel_config
.
tensor_parallel_size
)
%
self
.
parallel_config
.
pipeline_parallel_size
)
%
self
.
parallel_config
.
pipeline_parallel_size
if
(
self
.
pp_size
==
1
):
if
(
self
.
pp_size
==
1
):
...
...
vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py
View file @
1a135a9d
...
@@ -20,7 +20,7 @@ from vllm.distributed.device_communicators.pynccl_wrapper import (
...
@@ -20,7 +20,7 @@ from vllm.distributed.device_communicators.pynccl_wrapper import (
from
vllm.distributed.kv_transfer.kv_connector.v1.p2p.tensor_memory_pool
import
(
# noqa: E501
from
vllm.distributed.kv_transfer.kv_connector.v1.p2p.tensor_memory_pool
import
(
# noqa: E501
TensorMemoryPool
)
TensorMemoryPool
)
from
vllm.utils
import
current_stream
,
get_ip
from
vllm.utils
import
current_stream
,
get_ip
from
vllm
.two_batch_overlap.v1.two_batch_overlap_v1
import
all_reduce_stream
as
tbo_all_reduce_stream
from
vllm
import
envs
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.forward_context
import
ForwardContext
from
vllm.forward_context
import
ForwardContext
...
@@ -196,6 +196,7 @@ class P2pNcclEngine:
...
@@ -196,6 +196,7 @@ class P2pNcclEngine:
tensor_id
:
str
,
tensor_id
:
str
,
tensor
:
torch
.
Tensor
,
tensor
:
torch
.
Tensor
,
remote_address
:
typing
.
Optional
[
str
]
=
None
,
remote_address
:
typing
.
Optional
[
str
]
=
None
,
tbo_evt
=
None
,
)
->
bool
:
)
->
bool
:
if
remote_address
is
None
:
if
remote_address
is
None
:
with
self
.
recv_store_cv
:
with
self
.
recv_store_cv
:
...
@@ -207,7 +208,7 @@ class P2pNcclEngine:
...
@@ -207,7 +208,7 @@ class P2pNcclEngine:
return
self
.
_send_sync
(
tensor_id
,
tensor
,
remote_address
)
return
self
.
_send_sync
(
tensor_id
,
tensor
,
remote_address
)
elif
self
.
send_type
==
"PUT_ASYNC"
:
elif
self
.
send_type
==
"PUT_ASYNC"
:
with
self
.
send_queue_cv
:
with
self
.
send_queue_cv
:
self
.
send_queue
.
append
([
tensor_id
,
remote_address
,
tensor
])
self
.
send_queue
.
append
([
tensor_id
,
remote_address
,
tensor
,
tbo_evt
])
self
.
send_queue_cv
.
notify
()
self
.
send_queue_cv
.
notify
()
else
:
# GET
else
:
# GET
with
self
.
send_store_cv
:
with
self
.
send_store_cv
:
...
@@ -391,9 +392,11 @@ class P2pNcclEngine:
...
@@ -391,9 +392,11 @@ class P2pNcclEngine:
with
self
.
send_queue_cv
:
with
self
.
send_queue_cv
:
while
not
self
.
send_queue
:
while
not
self
.
send_queue
:
self
.
send_queue_cv
.
wait
()
self
.
send_queue_cv
.
wait
()
tensor_id
,
remote_address
,
tensor
=
self
.
send_queue
.
popleft
()
tensor_id
,
remote_address
,
tensor
,
tbo_evt
=
self
.
send_queue
.
popleft
()
if
not
self
.
send_queue
:
if
not
self
.
send_queue
:
self
.
send_queue_cv
.
notify
()
self
.
send_queue_cv
.
notify
()
if
(
envs
.
VLLM_ENABLE_TBO
or
envs
.
VLLM_P2P_ASYNC
)
and
tbo_evt
is
not
None
:
self
.
send_stream
.
wait_event
(
tbo_evt
)
self
.
_send_sync
(
tensor_id
,
tensor
,
remote_address
)
self
.
_send_sync
(
tensor_id
,
tensor
,
remote_address
)
def
wait_for_sent
(
self
):
def
wait_for_sent
(
self
):
...
...
vllm/envs.py
View file @
1a135a9d
...
@@ -169,6 +169,7 @@ if TYPE_CHECKING:
...
@@ -169,6 +169,7 @@ if TYPE_CHECKING:
VLLM_USE_MERGE_ATTN_STATES_OPT
:
bool
=
False
VLLM_USE_MERGE_ATTN_STATES_OPT
:
bool
=
False
USE_FUSED_RMS_QUANT
:
bool
=
False
USE_FUSED_RMS_QUANT
:
bool
=
False
USE_FUSED_SILU_MUL_QUANT
:
bool
=
False
USE_FUSED_SILU_MUL_QUANT
:
bool
=
False
VLLM_P2P_ASYNC
:
bool
=
False
def
get_default_cache_root
():
def
get_default_cache_root
():
return
os
.
getenv
(
return
os
.
getenv
(
...
@@ -1114,6 +1115,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
...
@@ -1114,6 +1115,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
"USE_FUSED_SILU_MUL_QUANT"
:
"USE_FUSED_SILU_MUL_QUANT"
:
lambda
:
(
os
.
getenv
(
'USE_FUSED_SILU_MUL_QUANT'
,
'0'
).
lower
()
in
lambda
:
(
os
.
getenv
(
'USE_FUSED_SILU_MUL_QUANT'
,
'0'
).
lower
()
in
(
"true"
,
"1"
)),
(
"true"
,
"1"
)),
# vllm pd separation will be used async
"VLLM_P2P_ASYNC"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_P2P_ASYNC"
,
"0"
))),
}
}
# --8<-- [end:env-vars-definition]
# --8<-- [end:env-vars-definition]
...
...
vllm/two_batch_overlap/v1/model_input_split_v1.py
View file @
1a135a9d
...
@@ -323,11 +323,11 @@ def tbo_split_and_execute_model(
...
@@ -323,11 +323,11 @@ def tbo_split_and_execute_model(
)
)
# === Added: split inputs_embeds & intermediate_tensors per half; setup KV connector ===
# === Added: split inputs_embeds & intermediate_tensors per half; setup KV connector ===
#
真实
token
#
real
token
nums
num_tokens_left
=
int
(
input_split
.
scheduler_output_left
.
total_num_scheduled_tokens
)
num_tokens_left
=
int
(
input_split
.
scheduler_output_left
.
total_num_scheduled_tokens
)
num_tokens_right
=
int
(
input_split
.
scheduler_output_right
.
total_num_scheduled_tokens
)
num_tokens_right
=
int
(
input_split
.
scheduler_output_right
.
total_num_scheduled_tokens
)
#
按左右半批切成两份
#
split intermediate tensors
def
_split_intermediate_tensors
(
it
,
l
,
r
):
def
_split_intermediate_tensors
(
it
,
l
,
r
):
if
it
is
None
:
return
None
,
None
if
it
is
None
:
return
None
,
None
left_tensor_map
,
right_tensor_map
=
{},
{}
left_tensor_map
,
right_tensor_map
=
{},
{}
...
...
vllm/two_batch_overlap/v1/two_batch_overlap_v1.py
View file @
1a135a9d
...
@@ -17,7 +17,7 @@ logger = init_logger(__name__)
...
@@ -17,7 +17,7 @@ logger = init_logger(__name__)
tbo_step_stream
=
None
tbo_step_stream
=
None
all_reduce_stream
=
None
all_reduce_stream
=
None
PERSIST_THREADS
=
os
.
getenv
(
'VLLM_TBO_PERSIST_THREADS'
,
'1'
)
not
in
(
'0'
,
'false'
,
'False'
,
'no'
,
'NO'
,
''
)
STOP
=
object
()
STOP
=
object
()
class
TwoBatchOverlap
:
class
TwoBatchOverlap
:
...
@@ -48,7 +48,7 @@ class TwoBatchOverlap:
...
@@ -48,7 +48,7 @@ class TwoBatchOverlap:
self
.
event_right_t2c
=
torch
.
cuda
.
Event
(
enable_timing
=
False
)
self
.
event_right_t2c
=
torch
.
cuda
.
Event
(
enable_timing
=
False
)
def
init_tbo_thread
(
self
):
def
init_tbo_thread
(
self
):
if
self
.
_threads_started
and
PERSIST_THREADS
:
if
self
.
_threads_started
:
return
return
if
self
.
left_thread
is
None
or
not
self
.
left_thread
.
is_alive
():
if
self
.
left_thread
is
None
or
not
self
.
left_thread
.
is_alive
():
self
.
left_thread
=
threading
.
Thread
(
target
=
self
.
thread_two_batch_overlap
,
self
.
left_thread
=
threading
.
Thread
(
target
=
self
.
thread_two_batch_overlap
,
...
@@ -220,7 +220,6 @@ def tbo_all_reduce_v1(obj):
...
@@ -220,7 +220,6 @@ def tbo_all_reduce_v1(obj):
all_reduce_stream
.
wait_event
(
event_c2t
)
all_reduce_stream
.
wait_event
(
event_c2t
)
output
=
tensor_model_parallel_all_reduce
(
obj
)
output
=
tensor_model_parallel_all_reduce
(
obj
)
event_t2c
.
record
()
event_t2c
.
record
()
#tbo_mark_allreduce_done()
tbo_obj_v1
.
tbo_thread_synchronize
(
tid
)
tbo_obj_v1
.
tbo_thread_synchronize
(
tid
)
tbo_step_stream
.
wait_event
(
event_t2c
)
tbo_step_stream
.
wait_event
(
event_t2c
)
return
output
return
output
...
@@ -281,18 +280,6 @@ def tbo_model_executable_v1(
...
@@ -281,18 +280,6 @@ def tbo_model_executable_v1(
return
hidden_or_intermediate_states
return
hidden_or_intermediate_states
_tbo_done_event
=
torch
.
cuda
.
Event
(
enable_timing
=
False
)
def
tbo_mark_allreduce_done
():
"""Record completion of all_reduce_stream for external synchronization."""
global
all_reduce_stream
,
_tbo_done_event
_tbo_done_event
.
record
(
all_reduce_stream
)
def
tbo_get_done_event
():
"""Return the event recorded by all_reduce_stream."""
return
_tbo_done_event
def
finalize_two_batch_overlap
():
def
finalize_two_batch_overlap
():
global
tbo_obj_v1
global
tbo_obj_v1
if
tbo_obj_v1
is
not
None
:
if
tbo_obj_v1
is
not
None
:
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
1a135a9d
...
@@ -1295,7 +1295,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1295,7 +1295,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
scheduler_output
:
"SchedulerOutput"
,
scheduler_output
:
"SchedulerOutput"
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
)
->
Union
[
ModelRunnerOutput
,
IntermediateTensors
]:
)
->
Union
[
ModelRunnerOutput
,
IntermediateTensors
]:
profile
.
StartTracer
()
self
.
_update_states
(
scheduler_output
)
self
.
_update_states
(
scheduler_output
)
if
not
scheduler_output
.
total_num_scheduled_tokens
:
if
not
scheduler_output
.
total_num_scheduled_tokens
:
if
not
has_kv_transfer_group
():
if
not
has_kv_transfer_group
():
...
@@ -1574,7 +1574,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1574,7 +1574,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
get_kv_transfer_group
().
clear_connector_metadata
()
get_kv_transfer_group
().
clear_connector_metadata
()
self
.
eplb_step
()
self
.
eplb_step
()
print
(
'###valid_sampled_token_ids'
,
valid_sampled_token_ids
)
return
ModelRunnerOutput
(
return
ModelRunnerOutput
(
req_ids
=
self
.
input_batch
.
req_ids
,
req_ids
=
self
.
input_batch
.
req_ids
,
req_id_to_index
=
self
.
input_batch
.
req_id_to_index
,
req_id_to_index
=
self
.
input_batch
.
req_id_to_index
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment