Commit 8e838a89 authored by niuhb's avatar niuhb
Browse files

fall back tbo version

parent ffcc47b7
...@@ -159,7 +159,7 @@ def prepare_tbo_atten_metadata( ...@@ -159,7 +159,7 @@ def prepare_tbo_atten_metadata(
# The block_table for RIGHT starts from (req_offset-1). # The block_table for RIGHT starts from (req_offset-1).
# Align both offsets to that, and re-build the seq_lens for row-0. # Align both offsets to that, and re-build the seq_lens for row-0.
seq_len_offset = req_offset - 1 seq_len_offset = req_offset - 1
query_start_offset = req_offset query_start_offset = req_offset - 1
# row-0 is the split request (global row index = req_offset-1): # row-0 is the split request (global row index = req_offset-1):
base_hist = runner.input_batch.num_computed_tokens_cpu[req_offset - 1].item() base_hist = runner.input_batch.num_computed_tokens_cpu[req_offset - 1].item()
...@@ -180,7 +180,7 @@ def prepare_tbo_atten_metadata( ...@@ -180,7 +180,7 @@ def prepare_tbo_atten_metadata(
else: else:
# RIGHT without split-in-req: natural positions # RIGHT without split-in-req: natural positions
seq_len_offset = req_offset seq_len_offset = req_offset
query_start_offset = req_offset + 1 query_start_offset = req_offset
seq_lens_cpu_local = torch.as_tensor(default_seq_lens, device=runner.seq_lens_cpu.device) seq_lens_cpu_local = torch.as_tensor(default_seq_lens, device=runner.seq_lens_cpu.device)
# Copy query_start_loc into global GPU buffer window # Copy query_start_loc into global GPU buffer window
...@@ -201,10 +201,8 @@ def prepare_tbo_atten_metadata( ...@@ -201,10 +201,8 @@ def prepare_tbo_atten_metadata(
runner.seq_lens[seq_len_offset + num_reqs:].fill_(0) runner.seq_lens[seq_len_offset + num_reqs:].fill_(0)
# Build common metadata (pass CLONES to avoid aliasing between threads) # Build common metadata (pass CLONES to avoid aliasing between threads)
# query_start_loc = runner.query_start_loc[query_start_offset: query_start_offset + num_reqs + 1].clone() query_start_loc = runner.query_start_loc[query_start_offset: query_start_offset + num_reqs + 1].clone()
# seq_lens = runner.seq_lens[seq_len_offset : seq_len_offset + num_reqs].clone() seq_lens = runner.seq_lens[seq_len_offset : seq_len_offset + num_reqs].clone()
query_start_loc = runner.query_start_loc[query_start_offset: query_start_offset + num_reqs + 1]
seq_lens = runner.seq_lens[seq_len_offset : seq_len_offset + num_reqs]
common_attn_metadata = CommonAttentionMetadata( common_attn_metadata = CommonAttentionMetadata(
query_start_loc=query_start_loc, query_start_loc=query_start_loc,
seq_lens=seq_lens, seq_lens=seq_lens,
...@@ -306,6 +304,8 @@ def tbo_split_and_execute_model( ...@@ -306,6 +304,8 @@ def tbo_split_and_execute_model(
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
skip_cuda_graphs: bool, skip_cuda_graphs: bool,
) -> Union[ModelRunnerOutput, IntermediateTensors]: ) -> Union[ModelRunnerOutput, IntermediateTensors]:
if torch.distributed.get_rank() == 0:
print("###############enter tbo")
# If below TBO threshold, run the normal single-batch path (supports decode/prefill as-is). # If below TBO threshold, run the normal single-batch path (supports decode/prefill as-is).
# Two-batch overlap path # Two-batch overlap path
split_scheduler_output(runner, scheduler_output) split_scheduler_output(runner, scheduler_output)
...@@ -320,44 +320,37 @@ def tbo_split_and_execute_model( ...@@ -320,44 +320,37 @@ def tbo_split_and_execute_model(
) )
# === Added: split inputs_embeds & intermediate_tensors per half; setup KV connector === # === Added: split inputs_embeds & intermediate_tensors per half; setup KV connector ===
# real token nums # 真实 token
num_tokens_left = int(input_split.scheduler_output_left.total_num_scheduled_tokens) real_L = int(input_split.scheduler_output_left.total_num_scheduled_tokens)
num_tokens_right = int(input_split.scheduler_output_right.total_num_scheduled_tokens) real_R = int(input_split.scheduler_output_right.total_num_scheduled_tokens)
# split intermediate tensors # 按左右半批切成两份
def _split_intermediate_tensors(it, l, r): def _split_it(it, l, r):
if it is None: return None, None if it is None: return None, None
left_tensor_map, right_tensor_map = {}, {} lm, rm = {}, {}
for name, tensor in it.tensors.items(): for k, v in it.tensors.items():
vl, vr = torch.split(tensor[:l + r], [l, r], dim=0) vl, vr = torch.split(v[:l + r], [l, r], dim=0)
left_tensor_map[name], right_tensor_map[name] = vl, vr lm[k], rm[k] = vl, vr
return IntermediateTensors(left_tensor_map), IntermediateTensors(right_tensor_map) return IntermediateTensors(lm), IntermediateTensors(rm)
intermediate_tensors_left, intermediate_tensors_right = _split_intermediate_tensors( intermediate_tensors_left, intermediate_tensors_right = _split_it(
intermediate_tensors, num_tokens_left, num_tokens_right intermediate_tensors, real_L, real_R
) )
with set_forward_context(attn_metadata,
runner.vllm_config,
num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp,
skip_cuda_graphs=True):
runner.maybe_setup_kv_connector(scheduler_output) runner.maybe_setup_kv_connector(scheduler_output)
model_output = tbo_model_executable_v1( model_output = tbo_model_executable_v1(
runner, runner,
attn_metadata_left, attn_metadata_left, attn_metadata_right,
attn_metadata_right,
num_input_tokens_left, num_input_tokens_left,
num_input_tokens_right, num_input_tokens_right,
num_tokens_across_dp, num_tokens_across_dp,
input_ids, input_ids, positions,
positions,
(intermediate_tensors_left, intermediate_tensors_right), (intermediate_tensors_left, intermediate_tensors_right),
inputs_embeds) inputs_embeds,
)
runner.maybe_wait_for_kv_save() runner.maybe_wait_for_kv_save()
finished_sending, finished_recving = ( finished_sending, finished_recving = runner.get_finished_kv_transfers(scheduler_output)
runner.get_finished_kv_transfers(scheduler_output))
return model_output, finished_sending, finished_recving return model_output, finished_sending, finished_recving
# import os
# import queue
# import threading
# import torch
# from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
# from vllm.distributed.parallel_state import get_tp_group
# from vllm.forward_context import set_forward_context
# from vllm.multimodal.inputs import MultiModalKwargs
# from vllm.sequence import IntermediateTensors
# from vllm.two_batch_overlap.forward_context import init_tbo_forward_context
# from vllm.logger import init_logger
# from vllm.profiler.prof import profile
# from vllm import envs
# logger = init_logger(__name__)
# tbo_step_stream = None
# all_reduce_stream = None
# STOP = object()
# class TwoBatchOverlap:
# def __init__(self):
# global tbo_step_stream, all_reduce_stream
# self.model_input_left_queue = queue.Queue()
# self.model_input_right_queue = queue.Queue()
# self.states_left_queue = queue.Queue()
# self.states_right_queue = queue.Queue()
# self.left_thread = None
# self.right_thread = None
# self.left_tid = 0
# self.right_tid = 0
# self._stop_evt = threading.Event()
# self._threads_started = False
# self.sem_left = threading.Semaphore(0)
# self.sem_right = threading.Semaphore(0)
# self.left_first = False
# self.tbo_running = False
# self.tbo_in_capture = False
# if tbo_step_stream is None:
# tbo_step_stream = torch.cuda.Stream()
# all_reduce_stream = torch.cuda.Stream()
# self.step_event = torch.cuda.Event(enable_timing=False)
# self.event_left_c2t = torch.cuda.Event(enable_timing=False)
# self.event_right_c2t = torch.cuda.Event(enable_timing=False)
# self.event_left_t2c = torch.cuda.Event(enable_timing=False)
# self.event_right_t2c = torch.cuda.Event(enable_timing=False)
# def init_tbo_thread(self):
# if self._threads_started:
# return
# if self.left_thread is None or not self.left_thread.is_alive():
# self.left_thread = threading.Thread(target=self.thread_two_batch_overlap,
# args=(self.model_input_left_queue,), daemon=True)
# self.left_thread.start()
# if self.right_thread is None or not self.right_thread.is_alive():
# self.right_thread = threading.Thread(target=self.thread_two_batch_overlap,
# args=(self.model_input_right_queue,), daemon=True)
# self.right_thread.start()
# self._threads_started = True
# def shutdown(self, timeout=5.0):
# self._stop_evt.set()
# try:
# self.model_input_left_queue.put(STOP)
# self.model_input_right_queue.put(STOP)
# except Exception:
# pass
# if self.left_thread is not None:
# self.left_thread.join(timeout=timeout)
# self.left_thread = None
# if self.right_thread is not None:
# self.right_thread.join(timeout=timeout)
# self.right_thread = None
# @torch.inference_mode()
# def thread_two_batch_overlap(self, q):
# is_left_thread = False
# tid = threading.get_ident()
# if q is self.model_input_left_queue:
# self.left_tid = tid
# is_left_thread = True
# init_tbo_forward_context(True, self.left_tid)
# else:
# self.right_tid = tid
# init_tbo_forward_context(False, self.right_tid)
# while not self._stop_evt.is_set():
# item = q.get()
# if item is STOP:
# break
# with torch.cuda.stream(tbo_step_stream):
# self.tbo_thread_synchronize(tid)
# if is_left_thread:
# attn_metadata = self.attn_metadata_left
# num_input_tokens = self.num_input_tokens_left
# input_ids = self.input_ids_left
# positions = self.positions_left
# else:
# attn_metadata = self.attn_metadata_right
# num_input_tokens = self.num_input_tokens_right
# input_ids = self.input_ids_right
# positions = self.positions_right
# # Select per-thread tensors (left/right) with backward-compatible fallback
# if is_left_thread:
# intermediate_tensors = getattr(self, 'intermediate_tensors_left', None)
# else:
# intermediate_tensors = getattr(self, 'intermediate_tensors_right', None)
# if intermediate_tensors is None:
# intermediate_tensors = getattr(self, 'intermediate_tensors_left', None)
# with set_forward_context(attn_metadata,
# self.model_runner.vllm_config,
# num_tokens=num_input_tokens,
# num_tokens_across_dp=self.num_tokens_across_dp,
# skip_cuda_graphs=True,
# ):
# model_output = self.model_runner.model(
# input_ids=input_ids,
# positions=positions,
# intermediate_tensors=intermediate_tensors,
# inputs_embeds=self.inputs_embeds,
# )
# if is_left_thread:
# self.sem_right.release()
# self.states_left_queue.put(model_output)
# else:
# self.states_right_queue.put(model_output)
# def tbo_thread_synchronize(self, tid):
# if tid == self.left_tid:
# if not self.left_first:
# self.sem_right.release()
# self.left_first = False
# self.sem_left.acquire()
# return self.event_left_c2t, self.event_left_t2c
# else:
# self.sem_left.release()
# self.sem_right.acquire()
# return self.event_right_c2t, self.event_right_t2c
# def set_model_input(self,
# model_runner,
# attn_metadata_left,
# attn_metadata_right,
# num_input_tokens_left,
# num_input_tokens_right,
# input_ids_left,
# input_ids_right,
# positions_left,
# positions_right,
# num_tokens_across_dp,
# intermediate_tensors,
# inputs_embeds,
# ):
# self.model_runner = model_runner
# self.attn_metadata_left = attn_metadata_left
# self.attn_metadata_right = attn_metadata_right
# self.num_input_tokens_left = num_input_tokens_left
# self.num_input_tokens_right = num_input_tokens_right
# self.input_ids_left = input_ids_left
# self.input_ids_right = input_ids_right
# self.positions_left = positions_left
# self.positions_right = positions_right
# self.num_tokens_across_dp = num_tokens_across_dp
# self.inputs_embeds = inputs_embeds
# if isinstance(intermediate_tensors, tuple):
# self.intermediate_tensors_left, self.intermediate_tensors_right = intermediate_tensors
# else:
# self.intermediate_tensors_left = intermediate_tensors
# self.intermediate_tensors_right = None
# self.model_input_left_queue.put(None)
# self.model_input_right_queue.put(None)
# def get_model_output(self):
# states_left = self.states_left_queue.get()
# states_right = self.states_right_queue.get()
# return states_left, states_right
# tbo_obj_v1 = None
# def is_enable_tbo_v1():
# global tbo_obj_v1
# return tbo_obj_v1 is not None
# def init_two_batch_overlap():
# global tbo_obj_v1
# if tbo_obj_v1 is None:
# tbo_obj_v1 = TwoBatchOverlap()
# tbo_obj_v1.init_tbo_thread()
# def tbo_maybe_save_kv_layer_to_connector(layer_name, kv_cache):
# from vllm.attention.layer import maybe_save_kv_layer_to_connector
# if envs.VLLM_ENABLE_TBO and tbo_obj_v1 != None and tbo_obj_v1.tbo_running:
# tid = threading.get_ident()
# if tid == tbo_obj_v1.left_tid:
# return
# maybe_save_kv_layer_to_connector(layer_name, kv_cache)
# def tbo_all_reduce_v1(obj):
# if envs.VLLM_ENABLE_TBO and tbo_obj_v1 is not None and tbo_obj_v1.tbo_running:
# tid = threading.get_ident()
# if tid == tbo_obj_v1.left_tid:
# event_c2t, event_t2c = tbo_obj_v1.event_left_c2t, tbo_obj_v1.event_left_t2c
# else:
# event_c2t, event_t2c = tbo_obj_v1.event_right_c2t, tbo_obj_v1.event_right_t2c
# event_c2t.record()
# with torch.cuda.stream(all_reduce_stream):
# all_reduce_stream.wait_event(event_c2t)
# output = tensor_model_parallel_all_reduce(obj)
# event_t2c.record()
# tbo_obj_v1.tbo_thread_synchronize(tid)
# tbo_step_stream.wait_event(event_t2c)
# return output
# return tensor_model_parallel_all_reduce(obj)
# def merge_model_output(states_left, states_right):
# if isinstance(states_left, IntermediateTensors):
# output_map = {}
# for key in states_left.tensors:
# output_map[key] = torch.concat([states_left.tensors[key], states_right.tensors[key]], dim=0)
# output = IntermediateTensors(output_map)
# else:
# output = torch.concat([states_left, states_right], dim=0)
# return output
# def tbo_model_executable_v1(
# model_runner,
# attn_metadata_left,
# attn_metadata_right,
# num_input_tokens_left,
# num_input_tokens_right,
# num_tokens_across_dp,
# input_ids,
# positions,
# intermediate_tensors,
# inputs_embeds,
# ):
# init_two_batch_overlap()
# tbo_obj_v1.tbo_running = True
# tbo_obj_v1.left_first = True
# tbo_obj_v1.step_event.record()
# current_stream = torch.cuda.current_stream()
# num_total_tokens = num_input_tokens_left + num_input_tokens_right
# with torch.cuda.stream(tbo_step_stream):
# tbo_step_stream.wait_event(tbo_obj_v1.step_event)
# tokens_split = [num_input_tokens_left, num_input_tokens_right]
# input_ids_left, input_ids_right = torch.split(input_ids[:num_total_tokens], tokens_split, dim=0)
# positions_left, positions_right = torch.split(positions[:num_total_tokens], tokens_split, dim=0)
# tbo_obj_v1.set_model_input(model_runner,
# attn_metadata_left,
# attn_metadata_right,
# num_input_tokens_left,
# num_input_tokens_right,
# input_ids_left,
# input_ids_right,
# positions_left,
# positions_right,
# num_tokens_across_dp,
# intermediate_tensors,
# inputs_embeds,
# )
# model_output_left, model_output_right = tbo_obj_v1.get_model_output()
# hidden_or_intermediate_states = merge_model_output(model_output_left, model_output_right)
# tbo_obj_v1.tbo_running = False
# tbo_obj_v1.step_event.record()
# current_stream.wait_event(tbo_obj_v1.step_event)
# return hidden_or_intermediate_states
# def finalize_two_batch_overlap():
# global tbo_obj_v1
# if tbo_obj_v1 is not None:
# try:
# tbo_obj_v1.shutdown()
# finally:
# tbo_obj_v1 = None
import os import os
import queue import queue
import threading import threading
...@@ -17,7 +306,7 @@ logger = init_logger(__name__) ...@@ -17,7 +306,7 @@ logger = init_logger(__name__)
tbo_step_stream = None tbo_step_stream = None
all_reduce_stream = None all_reduce_stream = None
PERSIST_THREADS = os.getenv('VLLM_TBO_PERSIST_THREADS', '1') not in ('0','false','False','no','NO','')
STOP = object() STOP = object()
class TwoBatchOverlap: class TwoBatchOverlap:
...@@ -48,7 +337,7 @@ class TwoBatchOverlap: ...@@ -48,7 +337,7 @@ class TwoBatchOverlap:
self.event_right_t2c = torch.cuda.Event(enable_timing=False) self.event_right_t2c = torch.cuda.Event(enable_timing=False)
def init_tbo_thread(self): def init_tbo_thread(self):
if self._threads_started: if self._threads_started and PERSIST_THREADS:
return return
if self.left_thread is None or not self.left_thread.is_alive(): if self.left_thread is None or not self.left_thread.is_alive():
self.left_thread = threading.Thread(target=self.thread_two_batch_overlap, self.left_thread = threading.Thread(target=self.thread_two_batch_overlap,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment