import itertools
import os
import queue
import threading
import torch
from vllm.attention.backends.rocm_flash_attn import ROCmFlashAttentionMetadata
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.forward_context import set_forward_context
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal.inputs import MultiModalKwargs
from vllm.two_batch_overlap.forward_context import init_tbo_forward_context
from vllm.utils import async_tensor_h2d
from vllm.logger import init_logger
from vllm.profiler.prof import profile

# Two-batch overlap (TBO): a forward pass is split into two half-batches that
# run on two worker threads, taking turns at every tensor-parallel all-reduce.

# Feature switches, read once at import time.
enable_tbo = os.environ.get('VLLM_ENABLE_TBO') == '1'
enable_tbo_decode = os.environ.get('VLLM_TBO_DECODE') == '1'
tbo_one_stream = os.environ.get('VLLM_TBO_ONE_STREAM') == '1'

logger = init_logger(__name__)


def is_enable_tbo():
    """Return True when two-batch overlap was enabled via VLLM_ENABLE_TBO=1."""
    return enable_tbo


class TwoBatchOverlap:
    """Run one forward pass as two half-batches on two worker threads.

    The "left" and "right" workers each run the model on their half of the
    batch. At every tensor-parallel all-reduce (see ``tbo_all_reduce``) a
    worker hands its buffer to the main thread through ``all_reduce_queue``.
    The main thread performs the collective in ``all_reduce`` and returns the
    result through ``all_reduce_out``. The worker then passes the turn to the
    other worker via the ``sem_left`` / ``sem_right`` semaphores.

    Unless VLLM_TBO_ONE_STREAM=1, the collective runs on ``stream`` while the
    workers compute on ``step_stream``. This lets the GPU overlap one half's
    all-reduce with the other half's compute.
    """

    def __init__(self):
        # Per-worker inputs; a ``None`` item tells that worker to exit.
        self.model_input_left_queue = queue.Queue()
        self.model_input_right_queue = queue.Queue()
        # Per-worker forward outputs, consumed by get_model_output().
        self.states_left_queue = queue.Queue()
        self.states_right_queue = queue.Queue()
        # Worker -> main thread requests ``[tensor, event_c2t, event_t2c]``.
        # A ``None`` item ends the main thread's all_reduce() loop.
        self.all_reduce_queue = queue.Queue()
        # Main thread -> worker all-reduce results.
        self.all_reduce_out = queue.Queue()
        self.left_thread = None
        self.right_thread = None
        # Thread idents; each worker records its own when it starts.
        self.left_tid = 0
        self.right_tid = 0
        # Turn-taking semaphores: a worker proceeds only after acquiring its own.
        self.sem_left = threading.Semaphore(0)
        self.sem_right = threading.Semaphore(0)
        # True until the left worker has taken its first turn of a step.
        self.left_first = False
        # True while a split forward pass is in flight; routes tbo_all_reduce().
        self.tbo_running = False
        # Stream used for the all-reduce when not in one-stream mode.
        self.stream = torch.cuda.Stream()
        # Stream the workers compute on.
        self.step_stream = torch.cuda.Stream()
        # Joins the caller's stream and step_stream at step start and end.
        self.step_event = torch.cuda.Event(enable_timing=False)
        # c2t: recorded on the compute stream, waited on by the all-reduce stream.
        # t2c: recorded after the all-reduce, waited on by step_stream.
        self.event_left_c2t = torch.cuda.Event(enable_timing=False)
        self.event_right_c2t = torch.cuda.Event(enable_timing=False)
        self.event_left_t2c = torch.cuda.Event(enable_timing=False)
        self.event_right_t2c = torch.cuda.Event(enable_timing=False)

    def init_tbo_thread(self):
        """Start the left and right worker threads."""
        self.left_thread = threading.Thread(
            target=self.thread_two_batch_overlap,
            args=(self.model_input_left_queue,))
        self.right_thread = threading.Thread(
            target=self.thread_two_batch_overlap,
            args=(self.model_input_right_queue,))
        self.left_thread.start()
        self.right_thread.start()

    def finish_thread(self):
        """Send the exit sentinel to each running worker and join it."""
        if self.left_thread is not None:
            self.model_input_left_queue.put(None)
            self.left_thread.join()
            self.left_thread = None
        if self.right_thread is not None:
            self.model_input_right_queue.put(None)
            self.right_thread.join()
            self.right_thread = None
        logger.info('tbo:finish threads')

    @torch.inference_mode()
    def thread_two_batch_overlap(self, queue):
        """Worker loop: run the model on each half-batch taken from ``queue``.

        If ``queue`` is ``model_input_left_queue`` this thread becomes the left
        worker; otherwise it is the right worker. The loop exits on ``None``.

        NOTE(review): an exception raised by the model here is not forwarded.
        The main thread would then stay blocked in all_reduce(), because the
        right worker's ``None`` sentinel is never sent.
        """
        tid = threading.get_ident()
        is_left_thread = queue is self.model_input_left_queue
        if is_left_thread:
            self.left_tid = tid
        else:
            self.right_tid = tid
        logger.info('tbo:new thread %d', tid)
        init_tbo_forward_context(is_left_thread, tid)

        with torch.cuda.stream(self.step_stream):
            while True:
                model_input = queue.get()
                if model_input is None:
                    break
                profile.ProfRangePush('start')
                # Block until it is this worker's turn to launch work.
                self.tbo_thread_synchronize(tid)
                with set_forward_context(model_input.attn_metadata,
                                         self.vllm_config,
                                         self.virtual_engine):
                    hidden_or_intermediate_states = self.model_executable(
                        input_ids=model_input.input_tokens,
                        positions=model_input.input_positions,
                        intermediate_tensors=self.intermediate_tensors,
                        **MultiModalKwargs.as_kwargs(self.multi_modal_kwargs,
                                                     device=self.self_device),
                        **self.seqlen_agnostic_kwargs,
                        **self.model_kwargs,
                    )
                if is_left_thread:
                    # The left worker is done: hand the turn to the right one for good.
                    self.sem_right.release()
                    self.states_left_queue.put(hidden_or_intermediate_states)
                else:
                    # The right worker takes its turns after the left one, so it
                    # finishes last and stops the main thread's all_reduce() loop.
                    self.all_reduce_queue.put(None)
                    self.states_right_queue.put(hidden_or_intermediate_states)
                profile.ProfRangePop()

    def tbo_thread_synchronize(self, tid):
        """Give the turn to the other worker and block until it comes back.

        Args:
            tid: ``threading.get_ident()`` of the calling worker.

        Returns:
            The caller's ``(event_c2t, event_t2c)`` pair.
        """
        if tid == self.left_tid:
            # On the left worker's first turn of a step there is nothing to hand
            # over; the right worker's first call releases sem_left instead.
            if not self.left_first:
                self.sem_right.release()
            # Close the range opened before this call ('start' or the last turn).
            profile.ProfRangePop()
            self.sem_left.acquire()
            profile.ProfRangePush('left')
            self.left_first = False
            return self.event_left_c2t, self.event_left_t2c
        self.sem_left.release()
        profile.ProfRangePop()
        self.sem_right.acquire()
        profile.ProfRangePush('right')
        return self.event_right_c2t, self.event_right_t2c

    def set_model_input(self, model_input_left, model_input_right, vllm_config,
                        virtual_engine, model_executable, intermediate_tensors,
                        multi_modal_kwargs, self_device, seqlen_agnostic_kwargs,
                        model_kwargs):
        """Store the shared forward arguments and hand each worker its half.

        Starts the worker threads on first use.
        """
        if self.left_thread is None:
            self.init_tbo_thread()
        self.vllm_config = vllm_config
        self.virtual_engine = virtual_engine
        self.model_executable = model_executable
        self.intermediate_tensors = intermediate_tensors
        self.multi_modal_kwargs = multi_modal_kwargs
        self.self_device = self_device
        self.seqlen_agnostic_kwargs = seqlen_agnostic_kwargs
        self.model_kwargs = model_kwargs
        # Publish the inputs last so workers see the attributes set above.
        self.model_input_left_queue.put(model_input_left)
        self.model_input_right_queue.put(model_input_right)

    def get_model_output(self):
        """Block until both workers have produced their states; return (left, right)."""
        states_left = self.states_left_queue.get()
        states_right = self.states_right_queue.get()
        return states_left, states_right

    def all_reduce(self):
        """Serve the workers' all-reduce requests until a ``None`` sentinel arrives."""
        while True:
            request = self.all_reduce_queue.get()
            if request is None:
                break
            buf, event_c2t, event_t2c = request
            if tbo_one_stream:
                output = tensor_model_parallel_all_reduce(buf)
            else:
                # Start the collective only after the work queued so far on the
                # current stream, and mark its completion for the worker.
                event_c2t.record()
                with torch.cuda.stream(self.stream):
                    self.stream.wait_event(event_c2t)
                    output = tensor_model_parallel_all_reduce(buf)
                    event_t2c.record()
            self.all_reduce_out.put(output)


# Process-wide TBO instance; created lazily by init_two_batch_overlap().
tbo_obj = None


def init_two_batch_overlap():
    """Create the global TwoBatchOverlap instance if TBO is enabled and none exists."""
    global tbo_obj
    if enable_tbo and tbo_obj is None:
        tbo_obj = TwoBatchOverlap()


def finish_two_batch_overlap():
    """Stop the TBO worker threads and drop the global instance."""
    global tbo_obj
    if tbo_obj is not None:
        tbo_obj.finish_thread()
        tbo_obj = None


def tbo_all_reduce(obj):
    """Tensor-parallel all-reduce that cooperates with a running TBO step.

    While a split forward pass is running, the request is delegated to the main
    thread and the caller yields its turn to the other worker. Otherwise this is
    a plain ``tensor_model_parallel_all_reduce``.
    """
    if enable_tbo and tbo_obj is not None and tbo_obj.tbo_running:
        tid = threading.get_ident()
        # all_reduce() always unpacks three items, so the caller's events are
        # selected in one-stream mode too (they are simply unused there).
        if tid == tbo_obj.left_tid:
            event_c2t, event_t2c = tbo_obj.event_left_c2t, tbo_obj.event_left_t2c
        else:
            event_c2t, event_t2c = tbo_obj.event_right_c2t, tbo_obj.event_right_t2c
        tbo_obj.all_reduce_queue.put([obj, event_c2t, event_t2c])
        output = tbo_obj.all_reduce_out.get()
        tbo_obj.tbo_thread_synchronize(tid)
        if not tbo_one_stream:
            # Later compute on step_stream must see the all-reduce result.
            tbo_obj.step_stream.wait_event(event_t2c)
        return output
    return tensor_model_parallel_all_reduce(obj)


def cumsum(lst):
    """Return running totals of ``lst`` with a leading 0 (length ``len(lst) + 1``).

    >>> cumsum([2, 3, 1])
    [0, 2, 5, 6]
    """
    return list(itertools.accumulate(lst, initial=0))


def split_model_input(model_input, self_device, batch_size_left, batch_size_right):
    """Split ``model_input`` into two sub-batches at a sequence boundary.

    The first ``batch_size_left`` sequences form the left half and the next
    ``batch_size_right`` sequences form the right half.

    Token-level tensors (input tokens, positions, slot mapping) are split by
    each half's summed query lengths. Per-sequence tensors (seq_lens_tensor,
    block tables, context lens) are split by sequence count.

    All sequences are treated as prefills when ``model_input.is_prompt`` is
    true and as decodes otherwise. Both halves must be non-empty, because
    ``max()`` is taken over each half.

    Returns:
        ``(model_input_left, model_input_right)`` as
        ``ModelInputForGPUWithSamplingMetadata`` instances.
    """
    # Local import, presumably to avoid an import cycle with the model runner.
    from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata

    attn_metadata = model_input.attn_metadata
    sampling_metadata = model_input.sampling_metadata
    is_prompt = model_input.is_prompt
    batch_size_split = [batch_size_left, batch_size_right]

    # Host-side per-sequence lists, sliced at the sequence boundary.
    query_lens_halves = (model_input.query_lens[:batch_size_left],
                         model_input.query_lens[batch_size_left:])
    seq_lens_halves = (attn_metadata.seq_lens[:batch_size_left],
                       attn_metadata.seq_lens[batch_size_left:])
    block_tables_list_halves = (attn_metadata.block_tables_list[:batch_size_left],
                                attn_metadata.block_tables_list[batch_size_left:])
    # Number of flattened tokens belonging to each half.
    query_tokens_split = [sum(lens) for lens in query_lens_halves]

    # Token-level tensors: split by token count.
    input_tokens_halves = torch.split(model_input.input_tokens, query_tokens_split, dim=0)
    input_positions_halves = torch.split(model_input.input_positions, query_tokens_split, dim=0)
    slot_mapping_halves = torch.split(attn_metadata.slot_mapping, query_tokens_split, dim=0)
    # Per-sequence tensors: split by sequence count.
    seq_lens_tensor_halves = torch.split(attn_metadata.seq_lens_tensor, batch_size_split, dim=0)
    block_tables_halves = torch.split(attn_metadata.block_tables, batch_size_split, dim=0)
    context_lens_tensor_halves = torch.split(attn_metadata.context_lens_tensor, batch_size_split, dim=0)

    # NOTE(review): assumes request_ids_to_seq_ids is ordered like the batch
    # with one entry per sequence; requests with several sequences would skew
    # this split. TODO confirm.
    request_items = list(model_input.request_ids_to_seq_ids.items())
    request_ids_to_seq_ids_halves = (dict(request_items[:batch_size_left]),
                                     dict(request_items[batch_size_left:]))

    seq_groups = sampling_metadata.seq_groups
    if seq_groups is None:
        seq_groups_halves = (None, None)
    else:
        seq_groups_halves = (seq_groups[:batch_size_left], seq_groups[batch_size_left:])

    zero_tensor = torch.tensor([0], device=self_device, dtype=torch.int32)

    halves = []
    for half in range(2):
        batch_size = batch_size_split[half]
        query_lens = query_lens_halves[half]
        seq_lens = seq_lens_halves[half]
        seq_lens_tensor = seq_lens_tensor_halves[half]

        if is_prompt:
            num_prefills = batch_size
            num_prefill_tokens = query_tokens_split[half]
            num_decode_tokens = 0
            max_prefill_seq_len = max(seq_lens)
            max_decode_seq_len = 0
        else:
            num_prefills = 0
            num_prefill_tokens = 0
            num_decode_tokens = batch_size
            max_prefill_seq_len = 0
            max_decode_seq_len = max(seq_lens)

        # Prefix sums with a leading 0: start offset of each sequence's query
        # tokens in this half's flattened tokens, and of its full length.
        query_start_loc = async_tensor_h2d(cumsum(query_lens), torch.int32, self_device, True)
        seq_start_loc = torch.cat((zero_tensor, seq_lens_tensor.cumsum(dim=0)), dim=0).to(torch.int32)
        # Row of each sequence's last query token within this half's output.
        # Derived from query lengths: seq_lens also count cached context tokens
        # (and full lengths during decode), which would index past this half.
        selected_token_indices = (query_start_loc[1:] - 1).to(torch.long)

        attn_metadata_half = ROCmFlashAttentionMetadata(
            seq_lens_tensor=seq_lens_tensor,
            max_decode_seq_len=max_decode_seq_len,
            block_tables=block_tables_halves[half],
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            slot_mapping=slot_mapping_halves[half],
            multi_modal_placeholder_index_maps={},
            enable_kv_scales_calculation=attn_metadata.enable_kv_scales_calculation,
            seq_lens=seq_lens,
            max_prefill_seq_len=max_prefill_seq_len,
            use_cuda_graph=attn_metadata.use_cuda_graph,
            max_query_len=max(query_lens),
            query_start_loc=query_start_loc,
            seq_start_loc=seq_start_loc,
            context_lens_tensor=context_lens_tensor_halves[half],
            max_decode_query_len=None,
            _cached_prefill_metadata=None,
            _cached_decode_metadata=None,
            tree_attention_masks_tensor=None,
            block_tables_list=block_tables_list_halves[half],
            # Encoder / cross-attention fields are left unset.
            encoder_seq_lens=None,
            encoder_seq_lens_tensor=None,
            max_encoder_seq_len=None,
            num_encoder_tokens=None,
            cross_slot_mapping=None,
            cross_block_tables=None,
        )

        halves.append(ModelInputForGPUWithSamplingMetadata(
            input_tokens=input_tokens_halves[half],
            input_positions=input_positions_halves[half],
            token_types=None,
            seq_lens=seq_lens,
            query_lens=query_lens,
            lora_mapping=model_input.lora_mapping,
            lora_requests=model_input.lora_requests,
            attn_metadata=attn_metadata_half,
            prompt_adapter_mapping=model_input.prompt_adapter_mapping,
            prompt_adapter_requests=model_input.prompt_adapter_requests,
            multi_modal_kwargs=model_input.multi_modal_kwargs,
            request_ids_to_seq_ids=request_ids_to_seq_ids_halves[half],
            finished_requests_ids=model_input.finished_requests_ids,
            virtual_engine=model_input.virtual_engine,
            async_callback=model_input.async_callback,
            scheduler_outputs=model_input.scheduler_outputs,
            previous_hidden_states=model_input.previous_hidden_states,
            sampling_metadata=SamplingMetadata(
                seq_groups=seq_groups_halves[half],
                selected_token_indices=selected_token_indices,
                # NOTE(review): this is the full batch's mapping, not split per half.
                categorized_sample_indices=sampling_metadata.categorized_sample_indices,
                num_prompts=num_prefills,
                skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output,
                reuse_sampling_tensors=sampling_metadata.reuse_sampling_tensors,
            ),
            is_prompt=is_prompt,
        ))

    model_input_left, model_input_right = halves
    return model_input_left, model_input_right


def merge_model_output(states_left, states_right):
    """Concatenate the two halves' outputs along dim 0, the dimension inputs were split on.

    NOTE(review): assumes both outputs are plain tensors (e.g. not
    IntermediateTensors from pipeline parallelism).
    """
    return torch.concat([states_left, states_right], dim=0)


def _should_use_tbo(model_input):
    """Return True if ``model_input`` is eligible for the two-batch split."""
    attn_metadata = model_input.attn_metadata
    # The split is only implemented for ROCm flash-attention metadata.
    if not isinstance(attn_metadata, ROCmFlashAttentionMetadata):
        return False
    # Each half needs at least one sequence.
    if len(attn_metadata.seq_lens) < 2:
        return False
    # Decode is opt-in (VLLM_TBO_DECODE=1) and is not split under CUDA graphs.
    if not model_input.is_prompt and (not enable_tbo_decode or attn_metadata.use_cuda_graph):
        return False
    return True


def tbo_model_executable(
    model_input,
    vllm_config,
    virtual_engine,
    model_executable,
    intermediate_tensors,
    multi_modal_kwargs,
    self_device,
    seqlen_agnostic_kwargs,
    model_kwargs,
):
    """Run ``model_executable`` on ``model_input``, using two-batch overlap when possible.

    Falls back to a single ordinary forward pass when any of these holds:
      * TBO is not enabled;
      * the attention metadata is not ROCm flash-attention;
      * the batch has fewer than two sequences;
      * it is a decode batch and either VLLM_TBO_DECODE is off or CUDA graphs
        are in use.

    Otherwise the batch is split in half (the right half gets the extra
    sequence when the size is odd). The halves run on the TBO worker threads
    and their outputs are concatenated along dim 0.

    Returns:
        The model's hidden (or intermediate) states for the whole batch.
    """
    profile.ProfRangePush('tbo_model_executable')
    try:
        init_two_batch_overlap()
        if tbo_obj is None or not _should_use_tbo(model_input):
            with set_forward_context(model_input.attn_metadata, vllm_config, virtual_engine):
                return model_executable(
                    input_ids=model_input.input_tokens,
                    positions=model_input.input_positions,
                    intermediate_tensors=intermediate_tensors,
                    **MultiModalKwargs.as_kwargs(multi_modal_kwargs, device=self_device),
                    **seqlen_agnostic_kwargs,
                    **model_kwargs,
                )

        batch_size = len(model_input.attn_metadata.seq_lens)
        batch_size_left = batch_size // 2
        batch_size_right = batch_size - batch_size_left

        tbo_obj.tbo_running = True
        tbo_obj.left_first = True
        try:
            # step_stream must start after work already queued on the caller's stream.
            tbo_obj.step_event.record()
            current_stream = torch.cuda.current_stream()
            with torch.cuda.stream(tbo_obj.step_stream):
                tbo_obj.step_stream.wait_event(tbo_obj.step_event)
                model_input_left, model_input_right = split_model_input(
                    model_input, self_device, batch_size_left, batch_size_right)
                tbo_obj.set_model_input(model_input_left, model_input_right, vllm_config,
                                        virtual_engine, model_executable, intermediate_tensors,
                                        multi_modal_kwargs, self_device, seqlen_agnostic_kwargs,
                                        model_kwargs)
                # Serve the workers' all-reduces until the right worker finishes.
                tbo_obj.all_reduce()
                states_left, states_right = tbo_obj.get_model_output()
                hidden_or_intermediate_states = merge_model_output(states_left, states_right)
                tbo_obj.step_event.record()
        finally:
            # Never leave later all-reduces routed to idle worker threads.
            tbo_obj.tbo_running = False
        # The caller's stream must wait for everything issued on step_stream.
        current_stream.wait_event(tbo_obj.step_event)
        return hidden_or_intermediate_states
    finally:
        profile.ProfRangePop()