Commit fe1c4016 authored by zhuwenwen's avatar zhuwenwen
Browse files

add two batch overlap decude support muti-stream cuda-graph

parent d805c59c
...@@ -328,3 +328,72 @@ def split_model_input(model_input, self_device, batch_size_left, batch_size_righ ...@@ -328,3 +328,72 @@ def split_model_input(model_input, self_device, batch_size_left, batch_size_righ
is_prompt=model_input.is_prompt, is_prompt=model_input.is_prompt,
) )
return model_input_left, model_input_right return model_input_left, model_input_right
def split_capture_attention_metadata(attn_metadata, batch_size_left, batch_size_right):
batch_size_split = [batch_size_left, batch_size_right]
split_seq_lens_tensor = torch.split(attn_metadata.seq_lens_tensor, batch_size_split, dim=0)
split_block_tables = torch.split(attn_metadata.block_tables, batch_size_split, dim=0)
split_slot_mapping = torch.split(attn_metadata.slot_mapping, batch_size_split, dim=0)
if isinstance(attn_metadata, ROCmFlashAttentionMetadata):
attn_metadata_left = ROCmFlashAttentionMetadata(
seq_lens_tensor = split_seq_lens_tensor[0],
max_decode_seq_len = attn_metadata.max_decode_seq_len,
block_tables = split_block_tables[0],
num_prefills = 0,
num_prefill_tokens = 0,
num_decode_tokens = batch_size_left,
slot_mapping = split_slot_mapping[0],
multi_modal_placeholder_index_maps = attn_metadata.multi_modal_placeholder_index_maps,
enable_kv_scales_calculation = attn_metadata.enable_kv_scales_calculation,
seq_lens = None,
max_prefill_seq_len = 0,
use_cuda_graph = attn_metadata.use_cuda_graph,
max_query_len = 1,
query_start_loc = None,
seq_start_loc = None,
context_lens_tensor = None,
max_decode_query_len = 1,
_cached_prefill_metadata = None,
_cached_decode_metadata = None,
tree_attention_masks_tensor = None,
block_tables_list = None,
encoder_seq_lens = None,
encoder_seq_lens_tensor = None,
max_encoder_seq_len = None,
num_encoder_tokens = None,
cross_slot_mapping = None,
cross_block_tables = None,
)
attn_metadata_right = ROCmFlashAttentionMetadata(
seq_lens_tensor = split_seq_lens_tensor[1],
max_decode_seq_len = attn_metadata.max_decode_seq_len,
block_tables = split_block_tables[1],
num_prefills = 0,
num_prefill_tokens = 0,
num_decode_tokens = batch_size_right,
slot_mapping = split_slot_mapping[1],
multi_modal_placeholder_index_maps = attn_metadata.multi_modal_placeholder_index_maps,
enable_kv_scales_calculation = attn_metadata.enable_kv_scales_calculation,
seq_lens = None,
max_prefill_seq_len = 0,
use_cuda_graph = attn_metadata.use_cuda_graph,
max_query_len = 1,
query_start_loc = None,
seq_start_loc = None,
context_lens_tensor = None,
max_decode_query_len = 1,
_cached_prefill_metadata = None,
_cached_decode_metadata = None,
tree_attention_masks_tensor = None,
block_tables_list = None,
encoder_seq_lens = None,
encoder_seq_lens_tensor = None,
max_encoder_seq_len = None,
num_encoder_tokens = None,
cross_slot_mapping = None,
cross_block_tables = None,
)
else:
print("tbo:not surpport in cuda-graph ", type(attn_metadata))
return attn_metadata_left, attn_metadata_right
import gc
import os import os
import queue import queue
import threading import threading
from typing import List, Optional, Tuple
import torch import torch
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import get_pp_group, get_tp_group from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
from vllm.multimodal.inputs import MultiModalKwargs from vllm.multimodal.inputs import MultiModalKwargs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.two_batch_overlap.forward_context import init_tbo_forward_context from vllm.two_batch_overlap.forward_context import init_tbo_forward_context
from vllm.two_batch_overlap.model_input_split import is_supported_attention_metadata, split_model_input from vllm.two_batch_overlap.model_input_split import is_supported_attention_metadata, split_capture_attention_metadata, split_model_input
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.profiler.prof import profile from vllm.profiler.prof import profile
from vllm import envs from vllm import envs
from vllm.utils import weak_ref_tensor
tbo_one_stream = os.environ.get('VLLM_TBO_ONE_STREAM') == '1' tbo_one_stream = os.environ.get('VLLM_TBO_ONE_STREAM') == '1'
...@@ -37,6 +40,7 @@ class TwoBatchOverlap(): ...@@ -37,6 +40,7 @@ class TwoBatchOverlap():
self.sem_right = threading.Semaphore(0) self.sem_right = threading.Semaphore(0)
self.left_first = False self.left_first = False
self.tbo_running = False self.tbo_running = False
self.tbo_in_capture = False
if tbo_step_stream == None: if tbo_step_stream == None:
tbo_step_stream = torch.cuda.Stream() tbo_step_stream = torch.cuda.Stream()
all_reduce_stream = torch.cuda.Stream() all_reduce_stream = torch.cuda.Stream()
...@@ -84,17 +88,38 @@ class TwoBatchOverlap(): ...@@ -84,17 +88,38 @@ class TwoBatchOverlap():
else: else:
model_kwargs = self.model_kwargs_right model_kwargs = self.model_kwargs_right
intermediate_tensors = self.intermediate_tensors_right intermediate_tensors = self.intermediate_tensors_right
with set_forward_context(model_input.attn_metadata, hidden_or_intermediate_states = None
self.vllm_config, self.virtual_engine): if self.tbo_in_capture:
hidden_or_intermediate_states = self.model_executable( if is_left_thread:
input_ids=model_input.input_tokens, attn_metadata = self.attn_metadata_left
positions=model_input.input_positions, input_tokens = self.input_tokens_left
intermediate_tensors=intermediate_tensors, input_positions = self.split_input_positions[0]
**MultiModalKwargs.as_kwargs(self.multi_modal_kwargs, else:
device=self.self_device), attn_metadata = self.attn_metadata_right
**self.seqlen_agnostic_kwargs, input_tokens = self.input_tokens_right
**model_kwargs, input_positions = self.split_input_positions[1]
) with set_forward_context(attn_metadata,
self.vllm_config, self.virtual_engine):
hidden_or_intermediate_states = self.model_executable(
input_ids=input_tokens,
positions=input_positions,
intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(self.multi_modal_kwargs,
device=self.self_device),
**model_kwargs,
)
elif model_input != None:
with set_forward_context(model_input.attn_metadata,
self.vllm_config, self.virtual_engine):
hidden_or_intermediate_states = self.model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(self.multi_modal_kwargs,
device=self.self_device),
**self.seqlen_agnostic_kwargs,
**model_kwargs,
)
if is_left_thread: if is_left_thread:
self.sem_right.release() self.sem_right.release()
self.states_left_queue.put(hidden_or_intermediate_states) self.states_left_queue.put(hidden_or_intermediate_states)
...@@ -143,6 +168,37 @@ class TwoBatchOverlap(): ...@@ -143,6 +168,37 @@ class TwoBatchOverlap():
self.model_kwargs_right = model_kwargs_right self.model_kwargs_right = model_kwargs_right
self.model_input_left_queue.put(model_input_left) self.model_input_left_queue.put(model_input_left)
self.model_input_right_queue.put(model_input_right) self.model_input_right_queue.put(model_input_right)
def set_capture_model_input(self,
input_tokens_left,
input_tokens_right,
split_input_positions,
vllm_config,
virtual_engine,
runner_model,
runner_device,
intermediate_tensors_left,
intermediate_tensors_right,
model_kwargs_left,
model_kwargs_right,
attn_metadata_left,
attn_metadata_right):
self.input_tokens_left = input_tokens_left
self.input_tokens_right = input_tokens_right
self.split_input_positions = split_input_positions
self.vllm_config = vllm_config
self.virtual_engine = virtual_engine
self.model_executable = runner_model
self.self_device = runner_device
self.intermediate_tensors_left = intermediate_tensors_left
self.intermediate_tensors_right = intermediate_tensors_right
self.model_kwargs_left = model_kwargs_left
self.model_kwargs_right = model_kwargs_right
self.attn_metadata_left = attn_metadata_left
self.attn_metadata_right = attn_metadata_right
self.model_input_left_queue.put(None)
self.model_input_right_queue.put(None)
def get_model_output(self): def get_model_output(self):
states_left = self.states_left_queue.get() states_left = self.states_left_queue.get()
...@@ -280,3 +336,141 @@ def tbo_model_executable( ...@@ -280,3 +336,141 @@ def tbo_model_executable(
current_stream.wait_event(tbo_obj.step_event) current_stream.wait_event(tbo_obj.step_event)
profile.ProfRangePop() profile.ProfRangePop()
return hidden_or_intermediate_states return hidden_or_intermediate_states
def _run_once(vllm_config, virtual_engine,
runner,
self_device,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_inputs: Optional[IntermediateTensors],
attn_metadata: AttentionMetadata,
stream: torch.cuda.Stream,
**kwargs):
global tbo_step_stream
stream_back = tbo_step_stream
tbo_step_stream = stream
init_two_batch_overlap()
tbo_obj.left_first = True
decode_batch_size = input_ids.shape[0]
batch_size_left = int(decode_batch_size / 2)
batch_size_right = decode_batch_size - batch_size_left
query_tokens_split = [batch_size_left, batch_size_right]
input_tokens_left, input_tokens_right = torch.split(input_ids, query_tokens_split, dim=0)
split_input_positions = torch.split(positions, query_tokens_split, dim=0)
model_kwargs_left = kwargs.copy()
model_kwargs_right = kwargs.copy()
intermediate_tensors_left = None
intermediate_tensors_right = None
if "previous_hidden_states" in kwargs:
previous_hidden_states = kwargs["previous_hidden_states"]
split_previous_hidden_states = torch.split(previous_hidden_states, query_tokens_split, dim=0)
model_kwargs_left["previous_hidden_states"] = split_previous_hidden_states[0]
model_kwargs_right["previous_hidden_states"] = split_previous_hidden_states[1]
if intermediate_inputs != None:
query_tokens_split = [batch_size_left, batch_size_right]
intermediate_tensors_left = {}
intermediate_tensors_right = {}
for key in intermediate_inputs.tensors:
split_intermediate_tensors = torch.split(intermediate_inputs.tensors[key], query_tokens_split, dim=0)
intermediate_tensors_left[key] = split_intermediate_tensors[0]
intermediate_tensors_right[key] = split_intermediate_tensors[1]
intermediate_tensors_left = IntermediateTensors(intermediate_tensors_left)
intermediate_tensors_right = IntermediateTensors(intermediate_tensors_right)
attn_metadata_left, attn_metadata_right = split_capture_attention_metadata(attn_metadata, batch_size_left, batch_size_right)
tbo_obj.tbo_running = True
tbo_obj.tbo_in_capture = True
tbo_obj.set_capture_model_input(input_tokens_left,
input_tokens_right,
split_input_positions,
vllm_config,
virtual_engine,
runner.model,
self_device,
intermediate_tensors_left,
intermediate_tensors_right,
model_kwargs_left,
model_kwargs_right,
attn_metadata_left,
attn_metadata_right)
states_left, states_right = tbo_obj.get_model_output()
output_hidden_or_intermediate_states = merge_model_output(states_left, states_right)
tbo_obj.tbo_in_capture = False
tbo_obj.tbo_running = False
tbo_obj.finish_thread()
tbo_step_stream = stream_back
return output_hidden_or_intermediate_states
def tbo_capture(vllm_config, virtual_engine, _NUM_WARMUP_ITERS,
runner,
self_device,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_inputs: Optional[IntermediateTensors],
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
memory_pool: Optional[Tuple[int, int]],
stream: torch.cuda.Stream,
**kwargs):
for i in range(_NUM_WARMUP_ITERS):
_run_once(vllm_config,
virtual_engine,
runner,
self_device,
input_ids,
positions,
intermediate_inputs,
attn_metadata,
torch.cuda.current_stream(),
**kwargs)
torch.cuda.synchronize()
runner._graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(runner._graph, pool=memory_pool, stream=stream):
output_hidden_or_intermediate_states = _run_once(vllm_config,
virtual_engine,
runner,
self_device,
input_ids,
positions,
intermediate_inputs,
attn_metadata,
torch.cuda.current_stream(),
**kwargs)
if isinstance(output_hidden_or_intermediate_states, torch.Tensor):
hidden_or_intermediate_states = weak_ref_tensor(
output_hidden_or_intermediate_states)
elif isinstance(output_hidden_or_intermediate_states,
IntermediateTensors):
hidden_or_intermediate_states = IntermediateTensors(
tensors={
key: weak_ref_tensor(value)
for key, value in
output_hidden_or_intermediate_states.tensors.items()
})
del output_hidden_or_intermediate_states
# make sure `output_hidden_or_intermediate_states` is deleted
# in the graph's memory pool
gc.collect()
torch.cuda.synchronize()
# Save the input and output buffers.
runner.input_buffers = {
"input_ids":
input_ids,
"positions":
positions,
"kv_caches":
kv_caches,
**runner.attn_state.get_graph_input_buffers(
attn_metadata, runner._is_encoder_decoder_model),
**kwargs,
}
if intermediate_inputs is not None:
runner.input_buffers.update(intermediate_inputs.tensors)
if get_pp_group().is_last_rank:
runner.output_buffers = {
"hidden_states": hidden_or_intermediate_states
}
else:
runner.output_buffers = hidden_or_intermediate_states
...@@ -52,7 +52,7 @@ from vllm.prompt_adapter.worker_manager import ( ...@@ -52,7 +52,7 @@ from vllm.prompt_adapter.worker_manager import (
LRUCacheWorkerPromptAdapterManager) LRUCacheWorkerPromptAdapterManager)
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.two_batch_overlap.two_batch_overlap import tbo_model_executable from vllm.two_batch_overlap.two_batch_overlap import tbo_capture, tbo_model_executable
from vllm.utils import (DeviceMemoryProfiler, GiB_bytes, PyObjectCache, from vllm.utils import (DeviceMemoryProfiler, GiB_bytes, PyObjectCache,
async_tensor_h2d, flatten_2d_lists, async_tensor_h2d, flatten_2d_lists,
is_pin_memory_available, supports_dynamo, is_pin_memory_available, supports_dynamo,
...@@ -1668,9 +1668,15 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1668,9 +1668,15 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self._update_inputs_to_capture_for_enc_dec_model( self._update_inputs_to_capture_for_enc_dec_model(
capture_inputs) capture_inputs)
with set_forward_context(attn_metadata, self.vllm_config, if envs.VLLM_ENABLE_TBO and envs.VLLM_TBO_DECODE_BS > 1 and batch_size >= envs.VLLM_TBO_DECODE_BS:
virtual_engine): tbo_capture(self.vllm_config, virtual_engine, _NUM_WARMUP_ITERS,
graph_runner.capture(**capture_inputs) graph_runner,
self.device,
**capture_inputs)
else:
with set_forward_context(attn_metadata, self.vllm_config,
virtual_engine):
graph_runner.capture(**capture_inputs)
self.graph_memory_pool = graph_runner.graph.pool() self.graph_memory_pool = graph_runner.graph.pool()
self.graph_runners[virtual_engine][( self.graph_runners[virtual_engine][(
batch_size, use_inputs_embeds)] = graph_runner batch_size, use_inputs_embeds)] = graph_runner
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment