"""Two-batch overlap (TBO) execution helpers.

A forward pass is split into a left and a right half-batch, each run by
its own worker thread. The two threads hand control back and forth
through a pair of semaphores. Tensor-parallel all-reduces requested by
the workers are issued by the driving (main) thread. Unless
VLLM_TBO_ONE_STREAM=1, these all-reduces run on a separate CUDA stream.
"""
import os
import queue
import threading
import torch
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.forward_context import set_forward_context
from vllm.multimodal.inputs import MultiModalKwargs
from vllm.two_batch_overlap.forward_context import init_tbo_forward_context
from vllm.two_batch_overlap.model_input_split import is_supported_attention_metadata, split_model_input
from vllm.logger import init_logger
from vllm.profiler.prof import profile
from vllm import envs

# Opt-in switches, read once at import time.
enable_tbo_decode = os.environ.get('VLLM_TBO_DECODE') == '1'
tbo_one_stream = os.environ.get('VLLM_TBO_ONE_STREAM') == '1'

logger = init_logger(__name__)

# CUDA streams shared by all TwoBatchOverlap instances; created lazily by
# the first instance.
tbo_step_stream = None
all_reduce_stream = None


class TwoBatchOverlap:
    """Owns the two worker threads and the queues/events used to drive them."""

    def __init__(self):
        global tbo_step_stream
        global all_reduce_stream
        # Per-worker input queues; a None item tells that worker to exit.
        self.model_input_left_queue = queue.Queue()
        self.model_input_right_queue = queue.Queue()
        # Per-worker model outputs, consumed by get_model_output().
        self.states_left_queue = queue.Queue()
        self.states_right_queue = queue.Queue()
        # Worker -> main thread all-reduce requests ([buf, event_c2t,
        # event_t2c]); a None item ends the main thread's all_reduce() loop.
        self.all_reduce_queue = queue.Queue()
        # Main thread -> worker all-reduce results.
        self.all_reduce_out = queue.Queue()
        self.left_thread = None
        self.right_thread = None
        self.left_tid = 0
        self.right_tid = 0
        # Hand-off semaphores: each worker blocks on its own semaphore until
        # the other worker releases it.
        self.sem_left = threading.Semaphore(0)
        self.sem_right = threading.Semaphore(0)
        # Set at the start of each step so the left worker runs first.
        self.left_first = False
        # True while tbo_model_executable is driving the workers;
        # tbo_all_reduce only routes through the queues while this is set.
        self.tbo_running = False
        if tbo_step_stream is None:
            tbo_step_stream = torch.cuda.Stream()
            all_reduce_stream = torch.cuda.Stream()
        self.step_event = torch.cuda.Event(enable_timing=False)
        # c2t is recorded before all_reduce_stream starts the all-reduce.
        # t2c is recorded on all_reduce_stream after it finishes.
        self.event_left_c2t = torch.cuda.Event(enable_timing=False)
        self.event_right_c2t = torch.cuda.Event(enable_timing=False)
        self.event_left_t2c = torch.cuda.Event(enable_timing=False)
        self.event_right_t2c = torch.cuda.Event(enable_timing=False)

    @staticmethod
    def _drain_queue(q):
        """Remove every item currently in *q* without blocking."""
        while True:
            try:
                q.get_nowait()
            except queue.Empty:
                return

    def init_tbo_thread(self):
        """Clear stale inputs and start the left and right worker threads."""
        # Queue.empty() only reports emptiness; drain explicitly so leftover
        # inputs (or exit sentinels) cannot reach the new threads.
        self._drain_queue(self.model_input_left_queue)
        self._drain_queue(self.model_input_right_queue)
        self.left_thread = threading.Thread(target=self.thread_two_batch_overlap,
                                            args=(self.model_input_left_queue,))
        self.right_thread = threading.Thread(target=self.thread_two_batch_overlap,
                                             args=(self.model_input_right_queue,))
        self.left_thread.start()
        self.right_thread.start()
        logger.info('tbo:two batch overlap threads start')

    def finish_thread(self):
        """Send the exit sentinel to each running worker and join it."""
        if self.left_thread is not None:
            self.model_input_left_queue.put(None)
            self.left_thread.join()
            self.left_thread = None
        if self.right_thread is not None:
            self.model_input_right_queue.put(None)
            self.right_thread.join()
            self.right_thread = None
        logger.info('tbo:finish threads')

    @torch.inference_mode()
    def thread_two_batch_overlap(self, queue):
        """Worker loop: run the model on each half-batch taken from *queue*.

        *queue* (which shadows the ``queue`` module inside this method)
        decides whether this is the left or the right worker. The loop exits
        when it receives None.
        """
        is_left_thread = False
        tid = threading.get_ident()
        if queue is self.model_input_left_queue:
            self.left_tid = tid
            is_left_thread = True
            init_tbo_forward_context(True, self.left_tid)
        else:
            self.right_tid = tid
            init_tbo_forward_context(False, self.right_tid)
        with torch.cuda.stream(tbo_step_stream):
            while True:
                model_input = queue.get()
                if model_input is None:
                    break
                profile.ProfRangePush('start')
                self.tbo_thread_synchronize(tid)
                # NOTE(review): if model_executable raises, this thread exits
                # without posting a result. The main thread then blocks
                # forever in all_reduce()/get_model_output().
                with set_forward_context(model_input.attn_metadata, self.vllm_config,
                                         self.virtual_engine):
                    hidden_or_intermediate_states = self.model_executable(
                        input_ids=model_input.input_tokens,
                        positions=model_input.input_positions,
                        intermediate_tensors=self.intermediate_tensors,
                        **MultiModalKwargs.as_kwargs(self.multi_modal_kwargs,
                                                     device=self.self_device),
                        **self.seqlen_agnostic_kwargs,
                        **self.model_kwargs,
                    )
                if is_left_thread:
                    # Let the right worker finish its remaining work.
                    self.sem_right.release()
                    self.states_left_queue.put(hidden_or_intermediate_states)
                else:
                    # The right worker finishes last; tell the main thread's
                    # all_reduce() loop that no more requests are coming.
                    self.all_reduce_queue.put(None)
                    self.states_right_queue.put(hidden_or_intermediate_states)
                profile.ProfRangePop()

    def tbo_thread_synchronize(self, tid):
        """Hand control to the other worker and block until it is returned.

        On the first call of a step (left_first set), the left worker does
        not release the right one. It waits until the right worker hands
        over, which guarantees the left worker runs first.

        Returns the (c2t, t2c) events belonging to the calling worker.
        """
        if tid == self.left_tid:
            if not self.left_first:
                self.sem_right.release()
            profile.ProfRangePop()
            self.sem_left.acquire()
            profile.ProfRangePush('left')
            self.left_first = False
            return self.event_left_c2t, self.event_left_t2c
        else:
            self.sem_left.release()
            profile.ProfRangePop()
            self.sem_right.acquire()
            profile.ProfRangePush('right')
            return self.event_right_c2t, self.event_right_t2c

    def set_model_input(self, model_input_left, model_input_right, vllm_config,
                        virtual_engine, model_executable, intermediate_tensors,
                        multi_modal_kwargs, self_device, seqlen_agnostic_kwargs,
                        model_kwargs):
        """Store the per-step model arguments and enqueue both half-batches.

        Starts the worker threads on first use.
        """
        if self.left_thread is None:
            self.init_tbo_thread()
        self.vllm_config = vllm_config
        self.virtual_engine = virtual_engine
        self.model_executable = model_executable
        self.intermediate_tensors = intermediate_tensors
        self.multi_modal_kwargs = multi_modal_kwargs
        self.self_device = self_device
        self.seqlen_agnostic_kwargs = seqlen_agnostic_kwargs
        self.model_kwargs = model_kwargs
        self.model_input_left_queue.put(model_input_left)
        self.model_input_right_queue.put(model_input_right)

    def get_model_output(self):
        """Block until both workers have produced output; return (left, right)."""
        states_left = self.states_left_queue.get()
        states_right = self.states_right_queue.get()
        return states_left, states_right

    def all_reduce(self):
        """Serve worker all-reduce requests on the calling (main) thread.

        Runs until the right worker posts None. Each result is returned to
        the requesting worker through all_reduce_out.
        """
        while True:
            obj = self.all_reduce_queue.get()
            if obj is None:
                break
            buf, event_c2t, event_t2c = obj
            if tbo_one_stream:
                output = tensor_model_parallel_all_reduce(buf)
            else:
                # Order the all-reduce after the work already queued on the
                # current stream, and record t2c once it completes.
                event_c2t.record()
                with torch.cuda.stream(all_reduce_stream):
                    all_reduce_stream.wait_event(event_c2t)
                    output = tensor_model_parallel_all_reduce(buf)
                    event_t2c.record()
            self.all_reduce_out.put(output)


tbo_obj = None


def init_two_batch_overlap():
    """Create the module-level TwoBatchOverlap if TBO is enabled and none exists."""
    global tbo_obj
    if envs.VLLM_ENABLE_TBO and tbo_obj is None:
        tbo_obj = TwoBatchOverlap()


def finish_two_batch_overlap():
    """Stop the worker threads and drop the module-level TwoBatchOverlap."""
    global tbo_obj
    if tbo_obj is not None:
        tbo_obj.finish_thread()
        tbo_obj = None


def tbo_all_reduce(obj):
    """All-reduce *obj*, routed through the TBO main thread while a TBO step runs.

    While a TBO step is running, the request is queued for the main thread.
    The calling worker then hands control to the other worker before it
    continues. Otherwise this is a plain tensor_model_parallel_all_reduce.
    """
    if envs.VLLM_ENABLE_TBO and tbo_obj is not None and tbo_obj.tbo_running:
        tid = threading.get_ident()
        # Always bind the events: the queued request unpacks three items even
        # in one-stream mode, where all_reduce() simply ignores them.
        if tid == tbo_obj.left_tid:
            event_c2t, event_t2c = tbo_obj.event_left_c2t, tbo_obj.event_left_t2c
        else:
            event_c2t, event_t2c = tbo_obj.event_right_c2t, tbo_obj.event_right_t2c
        tbo_obj.all_reduce_queue.put([obj, event_c2t, event_t2c])
        output = tbo_obj.all_reduce_out.get()
        tbo_obj.tbo_thread_synchronize(tid)
        if not tbo_one_stream:
            # Make later step-stream work wait for the all-reduce result.
            tbo_step_stream.wait_event(event_t2c)
        return output
    return tensor_model_parallel_all_reduce(obj)


def merge_model_output(states_left, states_right):
    """Concatenate the left and right half-batch outputs along dim 0.

    NOTE(review): assumes both outputs are plain tensors; non-tensor
    intermediate outputs would not concatenate here — confirm.
    """
    output = torch.concat([states_left, states_right], dim=0)
    return output


def _should_bypass_tbo(model_input):
    """Return True when *model_input* must run as one batch instead of being split."""
    attn_metadata = model_input.attn_metadata
    # Check support first so unsupported metadata is never probed for
    # attributes it may not have.
    if not is_supported_attention_metadata(attn_metadata):
        logger.info("tbo:not support yet %s", type(attn_metadata))
        return True
    if not model_input.is_prompt:
        # Decode steps are split only with VLLM_TBO_DECODE=1, and never when
        # attn_metadata.use_cuda_graph is set.
        if not enable_tbo_decode or attn_metadata.use_cuda_graph:
            return True
    # Each half needs at least one sequence.
    return len(attn_metadata.seq_lens) < 2


def tbo_model_executable(
    model_input,
    vllm_config,
    virtual_engine,
    model_executable,
    intermediate_tensors,
    multi_modal_kwargs,
    self_device,
    seqlen_agnostic_kwargs,
    model_kwargs,
):
    """Run *model_executable* on *model_input*, splitting it into two half-batches when possible.

    Falls back to a single direct call when the batch cannot or should not be
    split (see _should_bypass_tbo). Otherwise the two halves run on the TBO
    worker threads and their outputs are concatenated (left, then right).
    """
    if _should_bypass_tbo(model_input):
        with set_forward_context(model_input.attn_metadata, vllm_config, virtual_engine):
            return model_executable(
                input_ids=model_input.input_tokens,
                positions=model_input.input_positions,
                intermediate_tensors=intermediate_tensors,
                **MultiModalKwargs.as_kwargs(multi_modal_kwargs, device=self_device),
                **seqlen_agnostic_kwargs,
                **model_kwargs,
            )

    profile.ProfRangePush('tbo_model_executable')
    init_two_batch_overlap()
    tbo_obj.tbo_running = True
    tbo_obj.left_first = True

    batch_size = len(model_input.attn_metadata.seq_lens)
    # The right half takes the extra sequence when the batch size is odd.
    batch_size_left = batch_size // 2
    batch_size_right = batch_size - batch_size_left
    model_input_left, model_input_right = split_model_input(model_input, self_device,
                                                            batch_size_left, batch_size_right)

    # Make the TBO step stream wait for work already queued on the caller's stream.
    tbo_obj.step_event.record()
    current_stream = torch.cuda.current_stream()
    with torch.cuda.stream(tbo_step_stream):
        tbo_step_stream.wait_event(tbo_obj.step_event)
        tbo_obj.set_model_input(model_input_left, model_input_right, vllm_config,
                                virtual_engine, model_executable, intermediate_tensors,
                                multi_modal_kwargs, self_device, seqlen_agnostic_kwargs,
                                model_kwargs)
        # Serve worker all-reduces until the right worker signals completion.
        tbo_obj.all_reduce()
        states_left, states_right = tbo_obj.get_model_output()
        hidden_or_intermediate_states = merge_model_output(states_left, states_right)
        tbo_obj.tbo_running = False
        # Hand the finished step back to the caller's stream.
        tbo_obj.step_event.record()
    current_stream.wait_event(tbo_obj.step_event)
    profile.ProfRangePop()
    return hidden_or_intermediate_states