Commit fe1c4016 authored by zhuwenwen's avatar zhuwenwen
Browse files

add two batch overlap decude support muti-stream cuda-graph

parent d805c59c
......@@ -328,3 +328,72 @@ def split_model_input(model_input, self_device, batch_size_left, batch_size_righ
is_prompt=model_input.is_prompt,
)
return model_input_left, model_input_right
def split_capture_attention_metadata(attn_metadata, batch_size_left, batch_size_right):
batch_size_split = [batch_size_left, batch_size_right]
split_seq_lens_tensor = torch.split(attn_metadata.seq_lens_tensor, batch_size_split, dim=0)
split_block_tables = torch.split(attn_metadata.block_tables, batch_size_split, dim=0)
split_slot_mapping = torch.split(attn_metadata.slot_mapping, batch_size_split, dim=0)
if isinstance(attn_metadata, ROCmFlashAttentionMetadata):
attn_metadata_left = ROCmFlashAttentionMetadata(
seq_lens_tensor = split_seq_lens_tensor[0],
max_decode_seq_len = attn_metadata.max_decode_seq_len,
block_tables = split_block_tables[0],
num_prefills = 0,
num_prefill_tokens = 0,
num_decode_tokens = batch_size_left,
slot_mapping = split_slot_mapping[0],
multi_modal_placeholder_index_maps = attn_metadata.multi_modal_placeholder_index_maps,
enable_kv_scales_calculation = attn_metadata.enable_kv_scales_calculation,
seq_lens = None,
max_prefill_seq_len = 0,
use_cuda_graph = attn_metadata.use_cuda_graph,
max_query_len = 1,
query_start_loc = None,
seq_start_loc = None,
context_lens_tensor = None,
max_decode_query_len = 1,
_cached_prefill_metadata = None,
_cached_decode_metadata = None,
tree_attention_masks_tensor = None,
block_tables_list = None,
encoder_seq_lens = None,
encoder_seq_lens_tensor = None,
max_encoder_seq_len = None,
num_encoder_tokens = None,
cross_slot_mapping = None,
cross_block_tables = None,
)
attn_metadata_right = ROCmFlashAttentionMetadata(
seq_lens_tensor = split_seq_lens_tensor[1],
max_decode_seq_len = attn_metadata.max_decode_seq_len,
block_tables = split_block_tables[1],
num_prefills = 0,
num_prefill_tokens = 0,
num_decode_tokens = batch_size_right,
slot_mapping = split_slot_mapping[1],
multi_modal_placeholder_index_maps = attn_metadata.multi_modal_placeholder_index_maps,
enable_kv_scales_calculation = attn_metadata.enable_kv_scales_calculation,
seq_lens = None,
max_prefill_seq_len = 0,
use_cuda_graph = attn_metadata.use_cuda_graph,
max_query_len = 1,
query_start_loc = None,
seq_start_loc = None,
context_lens_tensor = None,
max_decode_query_len = 1,
_cached_prefill_metadata = None,
_cached_decode_metadata = None,
tree_attention_masks_tensor = None,
block_tables_list = None,
encoder_seq_lens = None,
encoder_seq_lens_tensor = None,
max_encoder_seq_len = None,
num_encoder_tokens = None,
cross_slot_mapping = None,
cross_block_tables = None,
)
else:
print("tbo:not surpport in cuda-graph ", type(attn_metadata))
return attn_metadata_left, attn_metadata_right
import gc
import os
import queue
import threading
from typing import List, Optional, Tuple
import torch
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.forward_context import set_forward_context
from vllm.multimodal.inputs import MultiModalKwargs
from vllm.sequence import IntermediateTensors
from vllm.two_batch_overlap.forward_context import init_tbo_forward_context
from vllm.two_batch_overlap.model_input_split import is_supported_attention_metadata, split_model_input
from vllm.two_batch_overlap.model_input_split import is_supported_attention_metadata, split_capture_attention_metadata, split_model_input
from vllm.logger import init_logger
from vllm.profiler.prof import profile
from vllm import envs
from vllm.utils import weak_ref_tensor
tbo_one_stream = os.environ.get('VLLM_TBO_ONE_STREAM') == '1'
......@@ -37,6 +40,7 @@ class TwoBatchOverlap():
self.sem_right = threading.Semaphore(0)
self.left_first = False
self.tbo_running = False
self.tbo_in_capture = False
if tbo_step_stream == None:
tbo_step_stream = torch.cuda.Stream()
all_reduce_stream = torch.cuda.Stream()
......@@ -84,6 +88,27 @@ class TwoBatchOverlap():
else:
model_kwargs = self.model_kwargs_right
intermediate_tensors = self.intermediate_tensors_right
hidden_or_intermediate_states = None
if self.tbo_in_capture:
if is_left_thread:
attn_metadata = self.attn_metadata_left
input_tokens = self.input_tokens_left
input_positions = self.split_input_positions[0]
else:
attn_metadata = self.attn_metadata_right
input_tokens = self.input_tokens_right
input_positions = self.split_input_positions[1]
with set_forward_context(attn_metadata,
self.vllm_config, self.virtual_engine):
hidden_or_intermediate_states = self.model_executable(
input_ids=input_tokens,
positions=input_positions,
intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(self.multi_modal_kwargs,
device=self.self_device),
**model_kwargs,
)
elif model_input != None:
with set_forward_context(model_input.attn_metadata,
self.vllm_config, self.virtual_engine):
hidden_or_intermediate_states = self.model_executable(
......@@ -144,6 +169,37 @@ class TwoBatchOverlap():
self.model_input_left_queue.put(model_input_left)
self.model_input_right_queue.put(model_input_right)
def set_capture_model_input(self,
input_tokens_left,
input_tokens_right,
split_input_positions,
vllm_config,
virtual_engine,
runner_model,
runner_device,
intermediate_tensors_left,
intermediate_tensors_right,
model_kwargs_left,
model_kwargs_right,
attn_metadata_left,
attn_metadata_right):
self.input_tokens_left = input_tokens_left
self.input_tokens_right = input_tokens_right
self.split_input_positions = split_input_positions
self.vllm_config = vllm_config
self.virtual_engine = virtual_engine
self.model_executable = runner_model
self.self_device = runner_device
self.intermediate_tensors_left = intermediate_tensors_left
self.intermediate_tensors_right = intermediate_tensors_right
self.model_kwargs_left = model_kwargs_left
self.model_kwargs_right = model_kwargs_right
self.attn_metadata_left = attn_metadata_left
self.attn_metadata_right = attn_metadata_right
self.model_input_left_queue.put(None)
self.model_input_right_queue.put(None)
def get_model_output(self):
states_left = self.states_left_queue.get()
states_right = self.states_right_queue.get()
......@@ -280,3 +336,141 @@ def tbo_model_executable(
current_stream.wait_event(tbo_obj.step_event)
profile.ProfRangePop()
return hidden_or_intermediate_states
def _run_once(vllm_config, virtual_engine,
runner,
self_device,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_inputs: Optional[IntermediateTensors],
attn_metadata: AttentionMetadata,
stream: torch.cuda.Stream,
**kwargs):
global tbo_step_stream
stream_back = tbo_step_stream
tbo_step_stream = stream
init_two_batch_overlap()
tbo_obj.left_first = True
decode_batch_size = input_ids.shape[0]
batch_size_left = int(decode_batch_size / 2)
batch_size_right = decode_batch_size - batch_size_left
query_tokens_split = [batch_size_left, batch_size_right]
input_tokens_left, input_tokens_right = torch.split(input_ids, query_tokens_split, dim=0)
split_input_positions = torch.split(positions, query_tokens_split, dim=0)
model_kwargs_left = kwargs.copy()
model_kwargs_right = kwargs.copy()
intermediate_tensors_left = None
intermediate_tensors_right = None
if "previous_hidden_states" in kwargs:
previous_hidden_states = kwargs["previous_hidden_states"]
split_previous_hidden_states = torch.split(previous_hidden_states, query_tokens_split, dim=0)
model_kwargs_left["previous_hidden_states"] = split_previous_hidden_states[0]
model_kwargs_right["previous_hidden_states"] = split_previous_hidden_states[1]
if intermediate_inputs != None:
query_tokens_split = [batch_size_left, batch_size_right]
intermediate_tensors_left = {}
intermediate_tensors_right = {}
for key in intermediate_inputs.tensors:
split_intermediate_tensors = torch.split(intermediate_inputs.tensors[key], query_tokens_split, dim=0)
intermediate_tensors_left[key] = split_intermediate_tensors[0]
intermediate_tensors_right[key] = split_intermediate_tensors[1]
intermediate_tensors_left = IntermediateTensors(intermediate_tensors_left)
intermediate_tensors_right = IntermediateTensors(intermediate_tensors_right)
attn_metadata_left, attn_metadata_right = split_capture_attention_metadata(attn_metadata, batch_size_left, batch_size_right)
tbo_obj.tbo_running = True
tbo_obj.tbo_in_capture = True
tbo_obj.set_capture_model_input(input_tokens_left,
input_tokens_right,
split_input_positions,
vllm_config,
virtual_engine,
runner.model,
self_device,
intermediate_tensors_left,
intermediate_tensors_right,
model_kwargs_left,
model_kwargs_right,
attn_metadata_left,
attn_metadata_right)
states_left, states_right = tbo_obj.get_model_output()
output_hidden_or_intermediate_states = merge_model_output(states_left, states_right)
tbo_obj.tbo_in_capture = False
tbo_obj.tbo_running = False
tbo_obj.finish_thread()
tbo_step_stream = stream_back
return output_hidden_or_intermediate_states
def tbo_capture(vllm_config, virtual_engine, _NUM_WARMUP_ITERS,
runner,
self_device,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_inputs: Optional[IntermediateTensors],
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
memory_pool: Optional[Tuple[int, int]],
stream: torch.cuda.Stream,
**kwargs):
for i in range(_NUM_WARMUP_ITERS):
_run_once(vllm_config,
virtual_engine,
runner,
self_device,
input_ids,
positions,
intermediate_inputs,
attn_metadata,
torch.cuda.current_stream(),
**kwargs)
torch.cuda.synchronize()
runner._graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(runner._graph, pool=memory_pool, stream=stream):
output_hidden_or_intermediate_states = _run_once(vllm_config,
virtual_engine,
runner,
self_device,
input_ids,
positions,
intermediate_inputs,
attn_metadata,
torch.cuda.current_stream(),
**kwargs)
if isinstance(output_hidden_or_intermediate_states, torch.Tensor):
hidden_or_intermediate_states = weak_ref_tensor(
output_hidden_or_intermediate_states)
elif isinstance(output_hidden_or_intermediate_states,
IntermediateTensors):
hidden_or_intermediate_states = IntermediateTensors(
tensors={
key: weak_ref_tensor(value)
for key, value in
output_hidden_or_intermediate_states.tensors.items()
})
del output_hidden_or_intermediate_states
# make sure `output_hidden_or_intermediate_states` is deleted
# in the graph's memory pool
gc.collect()
torch.cuda.synchronize()
# Save the input and output buffers.
runner.input_buffers = {
"input_ids":
input_ids,
"positions":
positions,
"kv_caches":
kv_caches,
**runner.attn_state.get_graph_input_buffers(
attn_metadata, runner._is_encoder_decoder_model),
**kwargs,
}
if intermediate_inputs is not None:
runner.input_buffers.update(intermediate_inputs.tensors)
if get_pp_group().is_last_rank:
runner.output_buffers = {
"hidden_states": hidden_or_intermediate_states
}
else:
runner.output_buffers = hidden_or_intermediate_states
......@@ -52,7 +52,7 @@ from vllm.prompt_adapter.worker_manager import (
LRUCacheWorkerPromptAdapterManager)
from vllm.sampling_params import SamplingParams
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.two_batch_overlap.two_batch_overlap import tbo_model_executable
from vllm.two_batch_overlap.two_batch_overlap import tbo_capture, tbo_model_executable
from vllm.utils import (DeviceMemoryProfiler, GiB_bytes, PyObjectCache,
async_tensor_h2d, flatten_2d_lists,
is_pin_memory_available, supports_dynamo,
......@@ -1668,6 +1668,12 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self._update_inputs_to_capture_for_enc_dec_model(
capture_inputs)
if envs.VLLM_ENABLE_TBO and envs.VLLM_TBO_DECODE_BS > 1 and batch_size >= envs.VLLM_TBO_DECODE_BS:
tbo_capture(self.vllm_config, virtual_engine, _NUM_WARMUP_ITERS,
graph_runner,
self.device,
**capture_inputs)
else:
with set_forward_context(attn_metadata, self.vllm_config,
virtual_engine):
graph_runner.capture(**capture_inputs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment