Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8e838a89
Commit
8e838a89
authored
Dec 17, 2025
by
niuhb
Browse files
fall back tbo version
parent
ffcc47b7
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
328 additions
and
46 deletions
+328
-46
vllm/two_batch_overlap/v1/model_input_split_v1.py
vllm/two_batch_overlap/v1/model_input_split_v1.py
+36
-43
vllm/two_batch_overlap/v1/two_batch_overlap_v1.py
vllm/two_batch_overlap/v1/two_batch_overlap_v1.py
+292
-3
No files found.
vllm/two_batch_overlap/v1/model_input_split_v1.py
View file @
8e838a89
...
@@ -159,7 +159,7 @@ def prepare_tbo_atten_metadata(
...
@@ -159,7 +159,7 @@ def prepare_tbo_atten_metadata(
# The block_table for RIGHT starts from (req_offset-1).
# The block_table for RIGHT starts from (req_offset-1).
# Align both offsets to that, and re-build the seq_lens for row-0.
# Align both offsets to that, and re-build the seq_lens for row-0.
seq_len_offset
=
req_offset
-
1
seq_len_offset
=
req_offset
-
1
query_start_offset
=
req_offset
query_start_offset
=
req_offset
-
1
# row-0 is the split request (global row index = req_offset-1):
# row-0 is the split request (global row index = req_offset-1):
base_hist
=
runner
.
input_batch
.
num_computed_tokens_cpu
[
req_offset
-
1
].
item
()
base_hist
=
runner
.
input_batch
.
num_computed_tokens_cpu
[
req_offset
-
1
].
item
()
...
@@ -180,7 +180,7 @@ def prepare_tbo_atten_metadata(
...
@@ -180,7 +180,7 @@ def prepare_tbo_atten_metadata(
else
:
else
:
# RIGHT without split-in-req: natural positions
# RIGHT without split-in-req: natural positions
seq_len_offset
=
req_offset
seq_len_offset
=
req_offset
query_start_offset
=
req_offset
+
1
query_start_offset
=
req_offset
seq_lens_cpu_local
=
torch
.
as_tensor
(
default_seq_lens
,
device
=
runner
.
seq_lens_cpu
.
device
)
seq_lens_cpu_local
=
torch
.
as_tensor
(
default_seq_lens
,
device
=
runner
.
seq_lens_cpu
.
device
)
# Copy query_start_loc into global GPU buffer window
# Copy query_start_loc into global GPU buffer window
...
@@ -201,10 +201,8 @@ def prepare_tbo_atten_metadata(
...
@@ -201,10 +201,8 @@ def prepare_tbo_atten_metadata(
runner
.
seq_lens
[
seq_len_offset
+
num_reqs
:].
fill_
(
0
)
runner
.
seq_lens
[
seq_len_offset
+
num_reqs
:].
fill_
(
0
)
# Build common metadata (pass CLONES to avoid aliasing between threads)
# Build common metadata (pass CLONES to avoid aliasing between threads)
# query_start_loc = runner.query_start_loc[query_start_offset: query_start_offset + num_reqs + 1].clone()
query_start_loc
=
runner
.
query_start_loc
[
query_start_offset
:
query_start_offset
+
num_reqs
+
1
].
clone
()
# seq_lens = runner.seq_lens[seq_len_offset : seq_len_offset + num_reqs].clone()
seq_lens
=
runner
.
seq_lens
[
seq_len_offset
:
seq_len_offset
+
num_reqs
].
clone
()
query_start_loc
=
runner
.
query_start_loc
[
query_start_offset
:
query_start_offset
+
num_reqs
+
1
]
seq_lens
=
runner
.
seq_lens
[
seq_len_offset
:
seq_len_offset
+
num_reqs
]
common_attn_metadata
=
CommonAttentionMetadata
(
common_attn_metadata
=
CommonAttentionMetadata
(
query_start_loc
=
query_start_loc
,
query_start_loc
=
query_start_loc
,
seq_lens
=
seq_lens
,
seq_lens
=
seq_lens
,
...
@@ -306,6 +304,8 @@ def tbo_split_and_execute_model(
...
@@ -306,6 +304,8 @@ def tbo_split_and_execute_model(
intermediate_tensors
:
Optional
[
IntermediateTensors
],
intermediate_tensors
:
Optional
[
IntermediateTensors
],
skip_cuda_graphs
:
bool
,
skip_cuda_graphs
:
bool
,
)
->
Union
[
ModelRunnerOutput
,
IntermediateTensors
]:
)
->
Union
[
ModelRunnerOutput
,
IntermediateTensors
]:
if
torch
.
distributed
.
get_rank
()
==
0
:
print
(
"###############enter tbo"
)
# If below TBO threshold, run the normal single-batch path (supports decode/prefill as-is).
# If below TBO threshold, run the normal single-batch path (supports decode/prefill as-is).
# Two-batch overlap path
# Two-batch overlap path
split_scheduler_output
(
runner
,
scheduler_output
)
split_scheduler_output
(
runner
,
scheduler_output
)
...
@@ -320,44 +320,37 @@ def tbo_split_and_execute_model(
...
@@ -320,44 +320,37 @@ def tbo_split_and_execute_model(
)
)
# === Added: split inputs_embeds & intermediate_tensors per half; setup KV connector ===
# === Added: split inputs_embeds & intermediate_tensors per half; setup KV connector ===
#
real
token
nums
#
真实
token
num_tokens_left
=
int
(
input_split
.
scheduler_output_left
.
total_num_scheduled_tokens
)
real_L
=
int
(
input_split
.
scheduler_output_left
.
total_num_scheduled_tokens
)
num_tokens_right
=
int
(
input_split
.
scheduler_output_right
.
total_num_scheduled_tokens
)
real_R
=
int
(
input_split
.
scheduler_output_right
.
total_num_scheduled_tokens
)
#
split intermediate tensors
#
按左右半批切成两份
def
_split_i
ntermediate_tensors
(
it
,
l
,
r
):
def
_split_i
t
(
it
,
l
,
r
):
if
it
is
None
:
return
None
,
None
if
it
is
None
:
return
None
,
None
l
eft_tensor_map
,
right_tensor_map
=
{},
{}
l
m
,
rm
=
{},
{}
for
name
,
tensor
in
it
.
tensors
.
items
():
for
k
,
v
in
it
.
tensors
.
items
():
vl
,
vr
=
torch
.
split
(
tensor
[:
l
+
r
],
[
l
,
r
],
dim
=
0
)
vl
,
vr
=
torch
.
split
(
v
[:
l
+
r
],
[
l
,
r
],
dim
=
0
)
l
eft_tensor_map
[
name
],
right_tensor_map
[
name
]
=
vl
,
vr
l
m
[
k
],
rm
[
k
]
=
vl
,
vr
return
IntermediateTensors
(
l
eft_tensor_map
),
IntermediateTensors
(
r
ight_tensor_map
)
return
IntermediateTensors
(
l
m
),
IntermediateTensors
(
r
m
)
intermediate_tensors_left
,
intermediate_tensors_right
=
_split_i
ntermediate_tensors
(
intermediate_tensors_left
,
intermediate_tensors_right
=
_split_i
t
(
intermediate_tensors
,
num_tokens_left
,
num_tokens_right
intermediate_tensors
,
real_L
,
real_R
)
)
with
set_forward_context
(
attn_metadata
,
runner
.
maybe_setup_kv_connector
(
scheduler_output
)
runner
.
vllm_config
,
num_tokens
=
num_input_tokens
,
model_output
=
tbo_model_executable_v1
(
num_tokens_across_dp
=
num_tokens_across_dp
,
runner
,
skip_cuda_graphs
=
True
):
attn_metadata_left
,
attn_metadata_right
,
runner
.
maybe_setup_kv_connector
(
scheduler_output
)
num_input_tokens_left
,
num_input_tokens_right
,
model_output
=
tbo_model_executable_v1
(
num_tokens_across_dp
,
runner
,
input_ids
,
positions
,
attn_metadata_left
,
(
intermediate_tensors_left
,
intermediate_tensors_right
),
attn_metadata_right
,
inputs_embeds
,
num_input_tokens_left
,
)
num_input_tokens_right
,
num_tokens_across_dp
,
runner
.
maybe_wait_for_kv_save
()
input_ids
,
finished_sending
,
finished_recving
=
runner
.
get_finished_kv_transfers
(
scheduler_output
)
positions
,
(
intermediate_tensors_left
,
intermediate_tensors_right
),
return
model_output
,
finished_sending
,
finished_recving
inputs_embeds
)
runner
.
maybe_wait_for_kv_save
()
finished_sending
,
finished_recving
=
(
runner
.
get_finished_kv_transfers
(
scheduler_output
))
return
model_output
,
finished_sending
,
finished_recving
\ No newline at end of file
vllm/two_batch_overlap/v1/two_batch_overlap_v1.py
View file @
8e838a89
# import os
# import queue
# import threading
# import torch
# from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
# from vllm.distributed.parallel_state import get_tp_group
# from vllm.forward_context import set_forward_context
# from vllm.multimodal.inputs import MultiModalKwargs
# from vllm.sequence import IntermediateTensors
# from vllm.two_batch_overlap.forward_context import init_tbo_forward_context
# from vllm.logger import init_logger
# from vllm.profiler.prof import profile
# from vllm import envs
# logger = init_logger(__name__)
# tbo_step_stream = None
# all_reduce_stream = None
# STOP = object()
# class TwoBatchOverlap:
# def __init__(self):
# global tbo_step_stream, all_reduce_stream
# self.model_input_left_queue = queue.Queue()
# self.model_input_right_queue = queue.Queue()
# self.states_left_queue = queue.Queue()
# self.states_right_queue = queue.Queue()
# self.left_thread = None
# self.right_thread = None
# self.left_tid = 0
# self.right_tid = 0
# self._stop_evt = threading.Event()
# self._threads_started = False
# self.sem_left = threading.Semaphore(0)
# self.sem_right = threading.Semaphore(0)
# self.left_first = False
# self.tbo_running = False
# self.tbo_in_capture = False
# if tbo_step_stream is None:
# tbo_step_stream = torch.cuda.Stream()
# all_reduce_stream = torch.cuda.Stream()
# self.step_event = torch.cuda.Event(enable_timing=False)
# self.event_left_c2t = torch.cuda.Event(enable_timing=False)
# self.event_right_c2t = torch.cuda.Event(enable_timing=False)
# self.event_left_t2c = torch.cuda.Event(enable_timing=False)
# self.event_right_t2c = torch.cuda.Event(enable_timing=False)
# def init_tbo_thread(self):
# if self._threads_started:
# return
# if self.left_thread is None or not self.left_thread.is_alive():
# self.left_thread = threading.Thread(target=self.thread_two_batch_overlap,
# args=(self.model_input_left_queue,), daemon=True)
# self.left_thread.start()
# if self.right_thread is None or not self.right_thread.is_alive():
# self.right_thread = threading.Thread(target=self.thread_two_batch_overlap,
# args=(self.model_input_right_queue,), daemon=True)
# self.right_thread.start()
# self._threads_started = True
# def shutdown(self, timeout=5.0):
# self._stop_evt.set()
# try:
# self.model_input_left_queue.put(STOP)
# self.model_input_right_queue.put(STOP)
# except Exception:
# pass
# if self.left_thread is not None:
# self.left_thread.join(timeout=timeout)
# self.left_thread = None
# if self.right_thread is not None:
# self.right_thread.join(timeout=timeout)
# self.right_thread = None
# @torch.inference_mode()
# def thread_two_batch_overlap(self, q):
# is_left_thread = False
# tid = threading.get_ident()
# if q is self.model_input_left_queue:
# self.left_tid = tid
# is_left_thread = True
# init_tbo_forward_context(True, self.left_tid)
# else:
# self.right_tid = tid
# init_tbo_forward_context(False, self.right_tid)
# while not self._stop_evt.is_set():
# item = q.get()
# if item is STOP:
# break
# with torch.cuda.stream(tbo_step_stream):
# self.tbo_thread_synchronize(tid)
# if is_left_thread:
# attn_metadata = self.attn_metadata_left
# num_input_tokens = self.num_input_tokens_left
# input_ids = self.input_ids_left
# positions = self.positions_left
# else:
# attn_metadata = self.attn_metadata_right
# num_input_tokens = self.num_input_tokens_right
# input_ids = self.input_ids_right
# positions = self.positions_right
# # Select per-thread tensors (left/right) with backward-compatible fallback
# if is_left_thread:
# intermediate_tensors = getattr(self, 'intermediate_tensors_left', None)
# else:
# intermediate_tensors = getattr(self, 'intermediate_tensors_right', None)
# if intermediate_tensors is None:
# intermediate_tensors = getattr(self, 'intermediate_tensors_left', None)
# with set_forward_context(attn_metadata,
# self.model_runner.vllm_config,
# num_tokens=num_input_tokens,
# num_tokens_across_dp=self.num_tokens_across_dp,
# skip_cuda_graphs=True,
# ):
# model_output = self.model_runner.model(
# input_ids=input_ids,
# positions=positions,
# intermediate_tensors=intermediate_tensors,
# inputs_embeds=self.inputs_embeds,
# )
# if is_left_thread:
# self.sem_right.release()
# self.states_left_queue.put(model_output)
# else:
# self.states_right_queue.put(model_output)
# def tbo_thread_synchronize(self, tid):
# if tid == self.left_tid:
# if not self.left_first:
# self.sem_right.release()
# self.left_first = False
# self.sem_left.acquire()
# return self.event_left_c2t, self.event_left_t2c
# else:
# self.sem_left.release()
# self.sem_right.acquire()
# return self.event_right_c2t, self.event_right_t2c
# def set_model_input(self,
# model_runner,
# attn_metadata_left,
# attn_metadata_right,
# num_input_tokens_left,
# num_input_tokens_right,
# input_ids_left,
# input_ids_right,
# positions_left,
# positions_right,
# num_tokens_across_dp,
# intermediate_tensors,
# inputs_embeds,
# ):
# self.model_runner = model_runner
# self.attn_metadata_left = attn_metadata_left
# self.attn_metadata_right = attn_metadata_right
# self.num_input_tokens_left = num_input_tokens_left
# self.num_input_tokens_right = num_input_tokens_right
# self.input_ids_left = input_ids_left
# self.input_ids_right = input_ids_right
# self.positions_left = positions_left
# self.positions_right = positions_right
# self.num_tokens_across_dp = num_tokens_across_dp
# self.inputs_embeds = inputs_embeds
# if isinstance(intermediate_tensors, tuple):
# self.intermediate_tensors_left, self.intermediate_tensors_right = intermediate_tensors
# else:
# self.intermediate_tensors_left = intermediate_tensors
# self.intermediate_tensors_right = None
# self.model_input_left_queue.put(None)
# self.model_input_right_queue.put(None)
# def get_model_output(self):
# states_left = self.states_left_queue.get()
# states_right = self.states_right_queue.get()
# return states_left, states_right
# tbo_obj_v1 = None
# def is_enable_tbo_v1():
# global tbo_obj_v1
# return tbo_obj_v1 is not None
# def init_two_batch_overlap():
# global tbo_obj_v1
# if tbo_obj_v1 is None:
# tbo_obj_v1 = TwoBatchOverlap()
# tbo_obj_v1.init_tbo_thread()
# def tbo_maybe_save_kv_layer_to_connector(layer_name, kv_cache):
# from vllm.attention.layer import maybe_save_kv_layer_to_connector
# if envs.VLLM_ENABLE_TBO and tbo_obj_v1 != None and tbo_obj_v1.tbo_running:
# tid = threading.get_ident()
# if tid == tbo_obj_v1.left_tid:
# return
# maybe_save_kv_layer_to_connector(layer_name, kv_cache)
# def tbo_all_reduce_v1(obj):
# if envs.VLLM_ENABLE_TBO and tbo_obj_v1 is not None and tbo_obj_v1.tbo_running:
# tid = threading.get_ident()
# if tid == tbo_obj_v1.left_tid:
# event_c2t, event_t2c = tbo_obj_v1.event_left_c2t, tbo_obj_v1.event_left_t2c
# else:
# event_c2t, event_t2c = tbo_obj_v1.event_right_c2t, tbo_obj_v1.event_right_t2c
# event_c2t.record()
# with torch.cuda.stream(all_reduce_stream):
# all_reduce_stream.wait_event(event_c2t)
# output = tensor_model_parallel_all_reduce(obj)
# event_t2c.record()
# tbo_obj_v1.tbo_thread_synchronize(tid)
# tbo_step_stream.wait_event(event_t2c)
# return output
# return tensor_model_parallel_all_reduce(obj)
# def merge_model_output(states_left, states_right):
# if isinstance(states_left, IntermediateTensors):
# output_map = {}
# for key in states_left.tensors:
# output_map[key] = torch.concat([states_left.tensors[key], states_right.tensors[key]], dim=0)
# output = IntermediateTensors(output_map)
# else:
# output = torch.concat([states_left, states_right], dim=0)
# return output
# def tbo_model_executable_v1(
# model_runner,
# attn_metadata_left,
# attn_metadata_right,
# num_input_tokens_left,
# num_input_tokens_right,
# num_tokens_across_dp,
# input_ids,
# positions,
# intermediate_tensors,
# inputs_embeds,
# ):
# init_two_batch_overlap()
# tbo_obj_v1.tbo_running = True
# tbo_obj_v1.left_first = True
# tbo_obj_v1.step_event.record()
# current_stream = torch.cuda.current_stream()
# num_total_tokens = num_input_tokens_left + num_input_tokens_right
# with torch.cuda.stream(tbo_step_stream):
# tbo_step_stream.wait_event(tbo_obj_v1.step_event)
# tokens_split = [num_input_tokens_left, num_input_tokens_right]
# input_ids_left, input_ids_right = torch.split(input_ids[:num_total_tokens], tokens_split, dim=0)
# positions_left, positions_right = torch.split(positions[:num_total_tokens], tokens_split, dim=0)
# tbo_obj_v1.set_model_input(model_runner,
# attn_metadata_left,
# attn_metadata_right,
# num_input_tokens_left,
# num_input_tokens_right,
# input_ids_left,
# input_ids_right,
# positions_left,
# positions_right,
# num_tokens_across_dp,
# intermediate_tensors,
# inputs_embeds,
# )
# model_output_left, model_output_right = tbo_obj_v1.get_model_output()
# hidden_or_intermediate_states = merge_model_output(model_output_left, model_output_right)
# tbo_obj_v1.tbo_running = False
# tbo_obj_v1.step_event.record()
# current_stream.wait_event(tbo_obj_v1.step_event)
# return hidden_or_intermediate_states
# def finalize_two_batch_overlap():
# global tbo_obj_v1
# if tbo_obj_v1 is not None:
# try:
# tbo_obj_v1.shutdown()
# finally:
# tbo_obj_v1 = None
import
os
import
os
import
queue
import
queue
import
threading
import
threading
...
@@ -17,7 +306,7 @@ logger = init_logger(__name__)
...
@@ -17,7 +306,7 @@ logger = init_logger(__name__)
tbo_step_stream
=
None
tbo_step_stream
=
None
all_reduce_stream
=
None
all_reduce_stream
=
None
PERSIST_THREADS
=
os
.
getenv
(
'VLLM_TBO_PERSIST_THREADS'
,
'1'
)
not
in
(
'0'
,
'false'
,
'False'
,
'no'
,
'NO'
,
''
)
STOP
=
object
()
STOP
=
object
()
class
TwoBatchOverlap
:
class
TwoBatchOverlap
:
...
@@ -48,7 +337,7 @@ class TwoBatchOverlap:
...
@@ -48,7 +337,7 @@ class TwoBatchOverlap:
self
.
event_right_t2c
=
torch
.
cuda
.
Event
(
enable_timing
=
False
)
self
.
event_right_t2c
=
torch
.
cuda
.
Event
(
enable_timing
=
False
)
def
init_tbo_thread
(
self
):
def
init_tbo_thread
(
self
):
if
self
.
_threads_started
:
if
self
.
_threads_started
and
PERSIST_THREADS
:
return
return
if
self
.
left_thread
is
None
or
not
self
.
left_thread
.
is_alive
():
if
self
.
left_thread
is
None
or
not
self
.
left_thread
.
is_alive
():
self
.
left_thread
=
threading
.
Thread
(
target
=
self
.
thread_two_batch_overlap
,
self
.
left_thread
=
threading
.
Thread
(
target
=
self
.
thread_two_batch_overlap
,
...
@@ -286,4 +575,4 @@ def finalize_two_batch_overlap():
...
@@ -286,4 +575,4 @@ def finalize_two_batch_overlap():
try
:
try
:
tbo_obj_v1
.
shutdown
()
tbo_obj_v1
.
shutdown
()
finally
:
finally
:
tbo_obj_v1
=
None
tbo_obj_v1
=
None
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment