Commit 5aa6d7c2 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.8.5-zero_overhead' into 'v0.8.5.post1-dev'

V0.8.5 zero overhead

See merge request dcutoolkit/deeplearing/vllm!115
parents 15587bd8 828aeaae
...@@ -62,6 +62,7 @@ from vllm.utils import (Counter, Device, deprecate_kwargs, ...@@ -62,6 +62,7 @@ from vllm.utils import (Counter, Device, deprecate_kwargs,
resolve_obj_by_qualname, weak_bind) resolve_obj_by_qualname, weak_bind)
from vllm.version import __version__ as VLLM_VERSION from vllm.version import __version__ as VLLM_VERSION
from vllm.worker.model_runner_base import InputProcessingError from vllm.worker.model_runner_base import InputProcessingError
from vllm.profiler.prof import profile
logger = init_logger(__name__) logger = init_logger(__name__)
_LOCAL_LOGGING_INTERVAL_SEC = 5 _LOCAL_LOGGING_INTERVAL_SEC = 5
...@@ -413,6 +414,7 @@ class LLMEngine: ...@@ -413,6 +414,7 @@ class LLMEngine:
# Flag to set when an input fails to process and the engine should run # Flag to set when an input fails to process and the engine should run
# the next step without re-scheduling. # the next step without re-scheduling.
self._skip_scheduling_next_step = False self._skip_scheduling_next_step = False
profile.StartTracer()
def _initialize_kv_caches(self) -> None: def _initialize_kv_caches(self) -> None:
"""Initialize the KV cache in the worker(s). """Initialize the KV cache in the worker(s).
......
...@@ -16,6 +16,7 @@ from vllm.lora.request import LoRARequest ...@@ -16,6 +16,7 @@ from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest, PoolerOutput from vllm.sequence import ExecuteModelRequest, PoolerOutput
from vllm.two_batch_overlap.two_batch_overlap import finish_two_batch_overlap
from vllm.utils import make_async from vllm.utils import make_async
from vllm.worker.worker_base import WorkerBase from vllm.worker.worker_base import WorkerBase
...@@ -143,6 +144,7 @@ class ExecutorBase(ABC): ...@@ -143,6 +144,7 @@ class ExecutorBase(ABC):
def stop_remote_worker_execution_loop(self) -> None: def stop_remote_worker_execution_loop(self) -> None:
"""Releases parallel workers from model loop.""" """Releases parallel workers from model loop."""
finish_two_batch_overlap()
return return
def add_lora(self, lora_request: LoRARequest) -> bool: def add_lora(self, lora_request: LoRARequest) -> bool:
...@@ -301,6 +303,7 @@ class DistributedExecutorBase(ExecutorBase): ...@@ -301,6 +303,7 @@ class DistributedExecutorBase(ExecutorBase):
return driver_outputs return driver_outputs
def stop_remote_worker_execution_loop(self) -> None: def stop_remote_worker_execution_loop(self) -> None:
finish_two_batch_overlap()
if self.parallel_worker_tasks is None: if self.parallel_worker_tasks is None:
return return
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import os
import time import time
from collections import defaultdict from collections import defaultdict
from contextlib import contextmanager from contextlib import contextmanager
...@@ -16,6 +17,7 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group, ...@@ -16,6 +17,7 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group,
is_v1_kv_transfer_group) is_v1_kv_transfer_group)
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1 from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.two_batch_overlap.forward_context import get_tbo_forward_context, set_tbo_forward_context
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.abstract import AttentionMetadata
...@@ -28,6 +30,9 @@ forward_start_time: float = 0 ...@@ -28,6 +30,9 @@ forward_start_time: float = 0
batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL
batchsize_forward_time: defaultdict = defaultdict(list) batchsize_forward_time: defaultdict = defaultdict(list)
enable_tbo = os.environ.get('VLLM_ENABLE_TBO') == '1'
def is_enable_tbo():
    """Return the module-level ``enable_tbo`` flag.

    The flag is computed once at import time from the ``VLLM_ENABLE_TBO``
    environment variable (true only when it equals ``'1'``).
    """
    return enable_tbo
@dataclass @dataclass
class DPMetadata: class DPMetadata:
...@@ -50,6 +55,14 @@ _forward_context: Optional[ForwardContext] = None ...@@ -50,6 +55,14 @@ _forward_context: Optional[ForwardContext] = None
def get_forward_context() -> ForwardContext: def get_forward_context() -> ForwardContext:
if is_enable_tbo():
forward_context = get_tbo_forward_context()
"""Get the current forward context."""
assert forward_context is not None, (
"Forward context is not set. "
"Please use `set_forward_context` to set the forward context.")
return forward_context
"""Get the current forward context.""" """Get the current forward context."""
assert _forward_context is not None, ( assert _forward_context is not None, (
"Forward context is not set. " "Forward context is not set. "
...@@ -112,7 +125,8 @@ def set_forward_context(attn_metadata: Any, ...@@ -112,7 +125,8 @@ def set_forward_context(attn_metadata: Any,
kv_connector = get_kv_transfer_group() kv_connector = get_kv_transfer_group()
assert isinstance(kv_connector, KVConnectorBase_V1) assert isinstance(kv_connector, KVConnectorBase_V1)
kv_connector.start_load_kv(_forward_context) kv_connector.start_load_kv(_forward_context)
if is_enable_tbo():
set_tbo_forward_context(_forward_context)
try: try:
yield yield
finally: finally:
...@@ -157,3 +171,5 @@ def set_forward_context(attn_metadata: Any, ...@@ -157,3 +171,5 @@ def set_forward_context(attn_metadata: Any,
kv_connector.wait_for_save() kv_connector.wait_for_save()
_forward_context = prev_context _forward_context = prev_context
if is_enable_tbo():
set_tbo_forward_context(_forward_context)
...@@ -1237,6 +1237,9 @@ class RowParallelLinear(LinearBase): ...@@ -1237,6 +1237,9 @@ class RowParallelLinear(LinearBase):
}) })
else: else:
self.register_parameter("bias", None) self.register_parameter("bias", None)
from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce, is_enable_tbo
self.tbo_all_reduce = tbo_all_reduce
self.enable_tbo = is_enable_tbo()
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
tp_rank = get_tensor_model_parallel_rank() tp_rank = get_tensor_model_parallel_rank()
...@@ -1307,6 +1310,9 @@ class RowParallelLinear(LinearBase): ...@@ -1307,6 +1310,9 @@ class RowParallelLinear(LinearBase):
input_parallel, input_parallel,
bias=bias_) bias=bias_)
if self.reduce_results and self.tp_size > 1: if self.reduce_results and self.tp_size > 1:
if self.enable_tbo:
output = self.tbo_all_reduce(output_parallel)
else:
output = tensor_model_parallel_all_reduce(output_parallel) output = tensor_model_parallel_all_reduce(output_parallel)
else: else:
output = output_parallel output = output_parallel
......
...@@ -283,6 +283,9 @@ class VocabParallelEmbedding(torch.nn.Module): ...@@ -283,6 +283,9 @@ class VocabParallelEmbedding(torch.nn.Module):
self.num_embeddings_padded, self.num_embeddings_padded,
params_dtype=params_dtype, params_dtype=params_dtype,
weight_loader=self.weight_loader) weight_loader=self.weight_loader)
from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce, is_enable_tbo
self.tbo_all_reduce = tbo_all_reduce
self.enable_tbo = is_enable_tbo()
@classmethod @classmethod
def _get_indices(cls, vocab_size_padded: int, org_vocab_size_padded: int, def _get_indices(cls, vocab_size_padded: int, org_vocab_size_padded: int,
...@@ -434,6 +437,9 @@ class VocabParallelEmbedding(torch.nn.Module): ...@@ -434,6 +437,9 @@ class VocabParallelEmbedding(torch.nn.Module):
if self.tp_size > 1: if self.tp_size > 1:
output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0) output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
# Reduce across all the model parallel GPUs. # Reduce across all the model parallel GPUs.
if self.enable_tbo:
output = self.tbo_all_reduce(output_parallel)
else:
output = tensor_model_parallel_all_reduce(output_parallel) output = tensor_model_parallel_all_reduce(output_parallel)
return output return output
......
import threading
_forward_context_left = None
_forward_context_right = None
_left_tid = 0
_right_tid = 0
def init_tbo_forward_context(left_flag, tid):
    """Register the thread id that owns the left or the right context slot.

    Args:
        left_flag: truthy to register ``tid`` as the left thread, falsy to
            register it as the right thread.
        tid: thread identifier (callers in this file pass the value of
            ``threading.get_ident()``).
    """
    global _left_tid, _right_tid
    if not left_flag:
        _right_tid = tid
        return
    _left_tid = tid
def set_tbo_forward_context(_forward_context):
    """Store ``_forward_context`` in the slot belonging to the calling thread.

    The left slot is used when the caller's thread id equals the registered
    left thread id; every other thread (including unregistered ones) writes
    the right slot.
    """
    global _forward_context_left, _forward_context_right
    if threading.get_ident() != _left_tid:
        _forward_context_right = _forward_context
        return
    _forward_context_left = _forward_context
def get_tbo_forward_context():
    """Return the forward context stored for the calling thread.

    The left slot is returned when the caller's thread id equals the
    registered left thread id; any other thread reads the right slot.
    """
    caller_is_left = threading.get_ident() == _left_tid
    return _forward_context_left if caller_is_left else _forward_context_right
import os
import queue
import threading
import torch
from vllm.attention.backends.rocm_flash_attn import ROCmFlashAttentionMetadata
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.forward_context import set_forward_context
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal.inputs import MultiModalKwargs
from vllm.two_batch_overlap.forward_context import init_tbo_forward_context
from vllm.utils import async_tensor_h2d
from vllm.logger import init_logger
from vllm.profiler.prof import profile
enable_tbo = os.environ.get('VLLM_ENABLE_TBO') == '1'
enable_tbo_decode = os.environ.get('VLLM_TBO_DECODE') == '1'
tbo_one_stream = os.environ.get('VLLM_TBO_ONE_STREAM') == '1'
logger = init_logger(__name__)
def is_enable_tbo():
    """Return the module-level ``enable_tbo`` flag.

    The flag is computed once at import time from the ``VLLM_ENABLE_TBO``
    environment variable (true only when it equals ``'1'``).
    """
    return enable_tbo
tbo_step_stream = None
all_reduce_stream = None
class TwoBatchOverlap():
    """Run the two halves of a batch on two worker threads that take turns.

    Each worker thread ("left" and "right") pulls a model input from its own
    queue, runs the model under ``tbo_step_stream`` and pushes the result to
    its own output queue.  The two semaphores hand execution back and forth
    between the workers at every ``tbo_thread_synchronize`` call, while the
    thread that calls ``all_reduce`` services the all-reduce requests the
    workers post on ``all_reduce_queue``.
    """

    def __init__(self):
        """Create the queues, semaphores and CUDA events shared by the threads."""
        global tbo_step_stream
        global all_reduce_stream
        # Per-worker input queues; ``None`` is the shutdown sentinel.
        self.model_input_left_queue = queue.Queue()
        self.model_input_right_queue = queue.Queue()
        # Per-worker result queues, read by get_model_output().
        self.states_left_queue = queue.Queue()
        self.states_right_queue = queue.Queue()
        # All-reduce requests posted by the workers and the matching results.
        self.all_reduce_queue = queue.Queue()
        self.all_reduce_out = queue.Queue()
        self.left_thread = None
        self.right_thread = None
        self.left_tid = 0
        self.right_tid = 0
        self.sem_left = threading.Semaphore(0)
        self.sem_right = threading.Semaphore(0)
        self.left_first = False
        self.tbo_running = False
        # The two streams are module-level and created only once, so they are
        # shared by every TwoBatchOverlap instance.
        if tbo_step_stream is None:
            tbo_step_stream = torch.cuda.Stream()
            all_reduce_stream = torch.cuda.Stream()
        self.step_event = torch.cuda.Event(enable_timing=False)
        self.event_left_c2t = torch.cuda.Event(enable_timing=False)
        self.event_right_c2t = torch.cuda.Event(enable_timing=False)
        self.event_left_t2c = torch.cuda.Event(enable_timing=False)
        self.event_right_t2c = torch.cuda.Event(enable_timing=False)

    @staticmethod
    def _drain_queue(pending):
        """Discard every item currently buffered in *pending* without blocking."""
        while True:
            try:
                pending.get_nowait()
            except queue.Empty:
                return

    def init_tbo_thread(self):
        """Clear both input queues and start the left and right worker threads."""
        # Queue.empty() only reports whether the queue is empty; the original
        # calls discarded its result, so nothing was ever cleared.  Drain
        # explicitly so no stale input reaches the new threads.
        self._drain_queue(self.model_input_left_queue)
        self._drain_queue(self.model_input_right_queue)
        self.left_thread = threading.Thread(target=self.thread_two_batch_overlap, args=(self.model_input_left_queue,))
        self.right_thread = threading.Thread(target=self.thread_two_batch_overlap, args=(self.model_input_right_queue,))
        self.left_thread.start()
        self.right_thread.start()

    def finish_thread(self):
        """Send the ``None`` sentinel to each running worker and join it."""
        if self.left_thread is not None:
            self.model_input_left_queue.put(None)
            self.left_thread.join()
            self.left_thread = None
        if self.right_thread is not None:
            self.model_input_right_queue.put(None)
            self.right_thread.join()
            self.right_thread = None
        logger.info('tbo:finish threads')

    @torch.inference_mode()
    def thread_two_batch_overlap(self, queue):
        """Worker loop: run the model on every input received on *queue*.

        Args:
            queue: the input queue of this worker.  The worker is the left
                one exactly when this is ``self.model_input_left_queue``.
                (The parameter name shadows the ``queue`` module inside this
                method only.)
        """
        is_left_thread = False
        tid = threading.get_ident()
        if queue is self.model_input_left_queue:
            self.left_tid = tid
            is_left_thread = True
            logger.info('tbo:new thread %d', self.left_tid)
            init_tbo_forward_context(True, self.left_tid)
        else:
            self.right_tid = tid
            logger.info('tbo:new thread %d', self.right_tid)
            init_tbo_forward_context(False, self.right_tid)
        with torch.cuda.stream(tbo_step_stream):
            while True:
                model_input = queue.get()
                if model_input is None:
                    # Shutdown sentinel posted by finish_thread().
                    break
                profile.ProfRangePush('start')
                self.tbo_thread_synchronize(tid)
                with set_forward_context(model_input.attn_metadata,
                                         self.vllm_config, self.virtual_engine):
                    hidden_or_intermediate_states = self.model_executable(
                        input_ids=model_input.input_tokens,
                        positions=model_input.input_positions,
                        intermediate_tensors=self.intermediate_tensors,
                        **MultiModalKwargs.as_kwargs(self.multi_modal_kwargs,
                                                     device=self.self_device),
                        **self.seqlen_agnostic_kwargs,
                        **self.model_kwargs,
                    )
                if is_left_thread:
                    # Let the right worker run its remaining work.
                    self.sem_right.release()
                    self.states_left_queue.put(hidden_or_intermediate_states)
                else:
                    # ``None`` ends the all_reduce() service loop.
                    self.all_reduce_queue.put(None)
                    self.states_right_queue.put(hidden_or_intermediate_states)
                profile.ProfRangePop()

    def tbo_thread_synchronize(self, tid):
        """Hand execution to the other worker and block until it hands it back.

        Args:
            tid: thread id of the calling worker.

        Returns:
            The ``(c2t, t2c)`` event pair belonging to the calling worker.
        """
        if tid == self.left_tid:
            # On the left worker's first call of a step (left_first is True)
            # the right worker is not released; it is released on later calls.
            if not self.left_first:
                self.sem_right.release()
            profile.ProfRangePop()
            self.sem_left.acquire()
            profile.ProfRangePush('left')
            self.left_first = False
            return self.event_left_c2t, self.event_left_t2c
        else:
            self.sem_left.release()
            profile.ProfRangePop()
            self.sem_right.acquire()
            profile.ProfRangePush('right')
            return self.event_right_c2t, self.event_right_t2c

    def set_model_input(self,
                        model_input_left,
                        model_input_right,
                        vllm_config,
                        virtual_engine,
                        model_executable,
                        intermediate_tensors,
                        multi_modal_kwargs,
                        self_device,
                        seqlen_agnostic_kwargs,
                        model_kwargs):
        """Publish the shared execution arguments and enqueue both inputs.

        Starts the worker threads on first use.  The shared arguments are
        stored on ``self`` before the inputs are queued because the workers
        read them from ``self`` after receiving an input.
        """
        if self.left_thread is None:
            self.init_tbo_thread()
        self.vllm_config = vllm_config
        self.virtual_engine = virtual_engine
        self.model_executable = model_executable
        self.intermediate_tensors = intermediate_tensors
        self.multi_modal_kwargs = multi_modal_kwargs
        self.self_device = self_device
        self.seqlen_agnostic_kwargs = seqlen_agnostic_kwargs
        self.model_kwargs = model_kwargs
        self.model_input_left_queue.put(model_input_left)
        self.model_input_right_queue.put(model_input_right)

    def get_model_output(self):
        """Block until both workers produced a result; return ``(left, right)``."""
        states_left = self.states_left_queue.get()
        states_right = self.states_right_queue.get()
        return states_left, states_right

    def all_reduce(self):
        """Serve all-reduce requests until the ``None`` sentinel arrives.

        Each request is ``[buf, event_c2t, event_t2c]``.  In one-stream mode
        the reduce runs on the caller's current stream; otherwise it runs on
        ``all_reduce_stream``, ordered after ``event_c2t`` and followed by a
        record of ``event_t2c``.
        """
        while True:
            obj = self.all_reduce_queue.get()
            if obj is None:
                break
            buf, event_c2t, event_t2c = obj
            if tbo_one_stream:
                output = tensor_model_parallel_all_reduce(buf)
            else:
                event_c2t.record()
                with torch.cuda.stream(all_reduce_stream):
                    all_reduce_stream.wait_event(event_c2t)
                    output = tensor_model_parallel_all_reduce(buf)
                    event_t2c.record()
            self.all_reduce_out.put(output)
tbo_obj = None
def init_two_batch_overlap():
    """Create the shared ``TwoBatchOverlap`` instance on first use.

    Does nothing when TBO is disabled (``enable_tbo`` is false) or when the
    module-level ``tbo_obj`` already exists.
    """
    global tbo_obj
    if enable_tbo and tbo_obj is None:
        tbo_obj = TwoBatchOverlap()
def finish_two_batch_overlap():
    """Stop the worker threads of the shared instance and drop the instance.

    Does nothing when no ``TwoBatchOverlap`` instance exists, so it can be
    called repeatedly.
    """
    global tbo_obj
    if tbo_obj is None:
        return
    tbo_obj.finish_thread()
    tbo_obj = None
def tbo_all_reduce(obj):
    """All-reduce ``obj``, routing through the TBO service loop when active.

    While a two-batch-overlap step is running, the request is posted on
    ``tbo_obj.all_reduce_queue`` together with the calling worker's event
    pair, the result is read from ``tbo_obj.all_reduce_out`` and execution
    is handed to the other worker.  Otherwise the regular tensor-parallel
    all-reduce is used.

    Args:
        obj: the tensor to reduce.

    Returns:
        The reduced tensor.
    """
    if enable_tbo and tbo_obj is not None and tbo_obj.tbo_running:
        tid = threading.get_ident()
        # Always pick the event pair: the queue item carries it in every
        # mode.  Selecting it only when ``tbo_one_stream`` was false left the
        # names unbound (UnboundLocalError) in one-stream mode.
        if tid == tbo_obj.left_tid:
            event_c2t, event_t2c = tbo_obj.event_left_c2t, tbo_obj.event_left_t2c
        else:
            event_c2t, event_t2c = tbo_obj.event_right_c2t, tbo_obj.event_right_t2c
        tbo_obj.all_reduce_queue.put([obj, event_c2t, event_t2c])
        output = tbo_obj.all_reduce_out.get()
        tbo_obj.tbo_thread_synchronize(tid)
        if not tbo_one_stream:
            # Order later work on the step stream after the reduce finished.
            tbo_step_stream.wait_event(event_t2c)
        return output
    return tensor_model_parallel_all_reduce(obj)
def cumsum(lst):
    """Return the running totals of ``lst`` with a leading zero.

    Args:
        lst: an iterable of addable values (lists of ints in this file).

    Returns:
        A list one element longer than ``lst``: ``[0, lst[0],
        lst[0] + lst[1], ...]``.  An empty input yields ``[0]``.
    """
    cum_lst = [0]
    # ``total`` rather than ``sum`` so the builtin is not shadowed.
    total = 0
    for value in lst:
        # Rebind instead of ``+=`` so already-appended totals are never
        # mutated in place.
        total = total + value
        cum_lst.append(total)
    return cum_lst
def split_model_input(model_input, self_device, batch_size_left, batch_size_right):
    """Split one model input into a left and a right model input.

    The first ``batch_size_left`` sequences go to the left half and the
    remaining ones to the right half.  Per-token tensors are split by the
    summed query lengths of each half, per-sequence tensors by the two batch
    sizes.

    Args:
        model_input: the full input; its ``attn_metadata`` is read through
            the attributes used below (``seq_lens``, ``block_tables``, ...).
        self_device: device used for the small helper tensors created here.
        batch_size_left: number of sequences in the left half.
        batch_size_right: number of sequences in the right half.

    Returns:
        ``(model_input_left, model_input_right)``, both
        ``ModelInputForGPUWithSamplingMetadata`` carrying a
        ``ROCmFlashAttentionMetadata``.
    """
    # Token counts of each half, used to split the per-token tensors.
    query_tokens_split = [sum(model_input.query_lens[0:batch_size_left]), sum(model_input.query_lens[batch_size_left:])]
    # Sequence counts of each half, used to split the per-sequence tensors.
    batch_size_split = [batch_size_left, batch_size_right]
    split_input_tokens = torch.split(model_input.input_tokens, query_tokens_split, dim=0)
    split_input_positions = torch.split(model_input.input_positions, query_tokens_split, dim=0)
    seq_lens_left = model_input.attn_metadata.seq_lens[0:batch_size_left]
    seq_lens_right = model_input.attn_metadata.seq_lens[batch_size_left:]
    query_lens_left = model_input.query_lens[0:batch_size_left]
    query_lens_right = model_input.query_lens[batch_size_left:]
    split_seq_lens_tensor = torch.split(model_input.attn_metadata.seq_lens_tensor, batch_size_split, dim=0)
    split_block_tables = torch.split(model_input.attn_metadata.block_tables, batch_size_split, dim=0)
    # Defaults; the prefill or decode group is overwritten below depending on
    # ``model_input.is_prompt``.
    num_prefills_left = 0
    num_prefills_right = 0
    num_prefill_tokens_left = 0
    num_prefill_tokens_right = 0
    num_decode_tokens_left = 0
    num_decode_tokens_right = 0
    max_prefill_seq_len_left = 0
    max_prefill_seq_len_right = 0
    max_decode_seq_len_left = 0
    max_decode_seq_len_right = 0
    max_decode_query_len_left = None
    max_decode_query_len_right = None
    # Encoder / cross-attention fields are never filled in by this function;
    # both halves always receive None for them.
    encoder_seq_lens_left = None
    encoder_seq_lens_right = None
    encoder_seq_lens_tensor_left = None
    encoder_seq_lens_tensor_right = None
    max_encoder_seq_len_left = None
    max_encoder_seq_len_right = None
    num_encoder_tokens_left = None
    num_encoder_tokens_right = None
    cross_slot_mapping_left = None
    cross_slot_mapping_right = None
    cross_block_tables_left = None
    cross_block_tables_right = None
    if model_input.is_prompt:
        # The whole batch is treated as prefill.
        num_prefills_left = batch_size_left
        num_prefills_right = batch_size_right
        num_prefill_tokens_left = sum(model_input.query_lens[0:batch_size_left])
        num_prefill_tokens_right = sum(model_input.query_lens[batch_size_left:])
        max_prefill_seq_len_left = max(model_input.attn_metadata.seq_lens[0:batch_size_left])
        max_prefill_seq_len_right = max(model_input.attn_metadata.seq_lens[batch_size_left:])
    else:
        # The whole batch is treated as decode; one decode token is counted
        # per sequence.
        num_decode_tokens_left = batch_size_left
        num_decode_tokens_right = batch_size_right
        max_decode_seq_len_left = max(model_input.attn_metadata.seq_lens[0:batch_size_left])
        max_decode_seq_len_right = max(model_input.attn_metadata.seq_lens[batch_size_left:])
    split_slot_mapping = torch.split(model_input.attn_metadata.slot_mapping, query_tokens_split, dim=0)
    max_query_len_left = max(model_input.query_lens[0:batch_size_left])
    max_query_len_right = max(model_input.query_lens[batch_size_left:])
    zero_tensor = torch.tensor([0], device=self_device, dtype=torch.int32)
    # Start offsets are rebuilt per half so each half starts at 0.
    query_start_loc_left_list = cumsum(query_lens_left)
    query_start_loc_right_list = cumsum(query_lens_right)
    query_start_loc_left = async_tensor_h2d(query_start_loc_left_list, torch.int32,
                                            self_device,
                                            True)
    query_start_loc_right = async_tensor_h2d(query_start_loc_right_list, torch.int32,
                                             self_device,
                                             True)
    seq_start_loc_left = torch.cat((zero_tensor, split_seq_lens_tensor[0].cumsum(dim=0)), dim=0).to(torch.int32)
    seq_start_loc_right = torch.cat((zero_tensor, split_seq_lens_tensor[1].cumsum(dim=0)), dim=0).to(torch.int32)
    split_context_lens_tensor = torch.split(model_input.attn_metadata.context_lens_tensor, batch_size_split, dim=0)
    block_tables_list_left = model_input.attn_metadata.block_tables_list[0:batch_size_left]
    block_tables_list_right = model_input.attn_metadata.block_tables_list[batch_size_left:]
    request_ids_to_seq_ids_left = {}
    request_ids_to_seq_ids_right = {}
    # The first ``batch_size_left`` dict entries (in iteration order) go left.
    # NOTE(review): this presumes one dict entry per sequence, in batch
    # order - confirm against the caller.
    counter = 0
    for key, value in model_input.request_ids_to_seq_ids.items():
        if counter < batch_size_left:
            request_ids_to_seq_ids_left[key] = value
        else:
            request_ids_to_seq_ids_right[key] = value
        counter += 1
    seq_groups_left = None
    seq_groups_right = None
    if model_input.sampling_metadata.seq_groups is not None:
        seq_groups_left = model_input.sampling_metadata.seq_groups[0:batch_size_left]
        seq_groups_right = model_input.sampling_metadata.seq_groups[batch_size_left:]
    # NOTE(review): these indices come from the sequence lengths, not the
    # query lengths - confirm this is intended when the two differ.
    selected_token_indices_left = split_seq_lens_tensor[0].cumsum(dim=0) - 1
    selected_token_indices_right = split_seq_lens_tensor[1].cumsum(dim=0) - 1
    # Imported here rather than at module level (presumably to avoid a
    # circular import with vllm.worker.model_runner - TODO confirm).
    from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
    attn_metadata_left = ROCmFlashAttentionMetadata(
        seq_lens_tensor = split_seq_lens_tensor[0],
        max_decode_seq_len = max_decode_seq_len_left,
        block_tables = split_block_tables[0],
        num_prefills = num_prefills_left,
        num_prefill_tokens = num_prefill_tokens_left,
        num_decode_tokens = num_decode_tokens_left,
        slot_mapping = split_slot_mapping[0],
        multi_modal_placeholder_index_maps = {},
        enable_kv_scales_calculation = model_input.attn_metadata.enable_kv_scales_calculation,
        seq_lens = seq_lens_left,
        max_prefill_seq_len = max_prefill_seq_len_left,
        use_cuda_graph = model_input.attn_metadata.use_cuda_graph,
        max_query_len = max_query_len_left,
        query_start_loc = query_start_loc_left,
        seq_start_loc = seq_start_loc_left,
        context_lens_tensor = split_context_lens_tensor[0],
        max_decode_query_len = max_decode_query_len_left,
        _cached_prefill_metadata = None,
        _cached_decode_metadata = None,
        tree_attention_masks_tensor = None,
        block_tables_list = block_tables_list_left,
        encoder_seq_lens = encoder_seq_lens_left,
        encoder_seq_lens_tensor = encoder_seq_lens_tensor_left,
        max_encoder_seq_len = max_encoder_seq_len_left,
        num_encoder_tokens = num_encoder_tokens_left,
        cross_slot_mapping = cross_slot_mapping_left,
        cross_block_tables = cross_block_tables_left,
    )
    # LoRA / prompt-adapter / multi-modal fields and the remaining
    # bookkeeping fields are passed through unsplit from the full input.
    model_input_left = ModelInputForGPUWithSamplingMetadata(
        input_tokens=split_input_tokens[0],
        input_positions=split_input_positions[0],
        token_types=None,
        seq_lens=seq_lens_left,
        query_lens=query_lens_left,
        lora_mapping=model_input.lora_mapping,
        lora_requests=model_input.lora_requests,
        attn_metadata=attn_metadata_left,
        prompt_adapter_mapping=model_input.prompt_adapter_mapping,
        prompt_adapter_requests=model_input.prompt_adapter_requests,
        multi_modal_kwargs=model_input.multi_modal_kwargs,
        request_ids_to_seq_ids=request_ids_to_seq_ids_left,
        finished_requests_ids=model_input.finished_requests_ids,
        virtual_engine=model_input.virtual_engine,
        async_callback=model_input.async_callback,
        scheduler_outputs=model_input.scheduler_outputs,
        previous_hidden_states=model_input.previous_hidden_states,
        sampling_metadata=SamplingMetadata(
            seq_groups=seq_groups_left,
            selected_token_indices=selected_token_indices_left,
            categorized_sample_indices=model_input.sampling_metadata.categorized_sample_indices,
            num_prompts=num_prefills_left,
            skip_sampler_cpu_output=model_input.sampling_metadata.skip_sampler_cpu_output,
            reuse_sampling_tensors=model_input.sampling_metadata.reuse_sampling_tensors,
        ),
        is_prompt=model_input.is_prompt,
    )
    # Right half: same construction as the left half, using index 1 of every
    # split and the ``*_right`` values.
    attn_metadata_right = ROCmFlashAttentionMetadata(
        seq_lens_tensor = split_seq_lens_tensor[1],
        max_decode_seq_len = max_decode_seq_len_right,
        block_tables = split_block_tables[1],
        num_prefills = num_prefills_right,
        num_prefill_tokens = num_prefill_tokens_right,
        num_decode_tokens = num_decode_tokens_right,
        slot_mapping = split_slot_mapping[1],
        multi_modal_placeholder_index_maps = {},
        enable_kv_scales_calculation = model_input.attn_metadata.enable_kv_scales_calculation,
        seq_lens = seq_lens_right,
        max_prefill_seq_len = max_prefill_seq_len_right,
        use_cuda_graph = model_input.attn_metadata.use_cuda_graph,
        max_query_len = max_query_len_right,
        query_start_loc = query_start_loc_right,
        seq_start_loc = seq_start_loc_right,
        context_lens_tensor = split_context_lens_tensor[1],
        max_decode_query_len = max_decode_query_len_right,
        _cached_prefill_metadata = None,
        _cached_decode_metadata = None,
        tree_attention_masks_tensor = None,
        block_tables_list = block_tables_list_right,
        encoder_seq_lens = encoder_seq_lens_right,
        encoder_seq_lens_tensor = encoder_seq_lens_tensor_right,
        max_encoder_seq_len = max_encoder_seq_len_right,
        num_encoder_tokens = num_encoder_tokens_right,
        cross_slot_mapping = cross_slot_mapping_right,
        cross_block_tables = cross_block_tables_right,
    )
    model_input_right = ModelInputForGPUWithSamplingMetadata(
        input_tokens=split_input_tokens[1],
        input_positions=split_input_positions[1],
        token_types=None,
        seq_lens=seq_lens_right,
        query_lens=query_lens_right,
        lora_mapping=model_input.lora_mapping,
        lora_requests=model_input.lora_requests,
        attn_metadata=attn_metadata_right,
        prompt_adapter_mapping=model_input.prompt_adapter_mapping,
        prompt_adapter_requests=model_input.prompt_adapter_requests,
        multi_modal_kwargs=model_input.multi_modal_kwargs,
        request_ids_to_seq_ids=request_ids_to_seq_ids_right,
        finished_requests_ids=model_input.finished_requests_ids,
        virtual_engine=model_input.virtual_engine,
        async_callback=model_input.async_callback,
        scheduler_outputs=model_input.scheduler_outputs,
        previous_hidden_states=model_input.previous_hidden_states,
        sampling_metadata=SamplingMetadata(
            seq_groups=seq_groups_right,
            selected_token_indices=selected_token_indices_right,
            categorized_sample_indices=model_input.sampling_metadata.categorized_sample_indices,
            num_prompts=num_prefills_right,
            skip_sampler_cpu_output=model_input.sampling_metadata.skip_sampler_cpu_output,
            reuse_sampling_tensors=model_input.sampling_metadata.reuse_sampling_tensors,
        ),
        is_prompt=model_input.is_prompt,
    )
    return model_input_left, model_input_right
def merge_model_output(states_left, states_right):
    """Join the two halves' outputs along dim 0, left half first."""
    halves = [states_left, states_right]
    return torch.concat(halves, dim=0)
def tbo_model_executable(
    model_input,
    vllm_config,
    virtual_engine,
    model_executable,
    intermediate_tensors,
    multi_modal_kwargs,
    self_device,
    seqlen_agnostic_kwargs,
    model_kwargs,
):
    """Run the model, splitting the batch across two threads when possible.

    The batch is run unsplit, on the calling thread, when it has fewer than
    two sequences, when it is a decode batch and ``VLLM_TBO_DECODE`` is not
    enabled, when the attention metadata is not
    ``ROCmFlashAttentionMetadata``, or when it is a CUDA-graph decode batch.
    Otherwise the batch is split into two halves that are executed by the
    two ``TwoBatchOverlap`` worker threads, while this thread services their
    all-reduce requests; the two results are concatenated.

    Returns:
        The model output for the whole batch.
    """
    init_two_batch_overlap()
    is_rocm_fa = isinstance(model_input.attn_metadata, ROCmFlashAttentionMetadata)
    is_cuda_graph_decode = model_input.attn_metadata.use_cuda_graph and not model_input.is_prompt
    batch_size = len(model_input.attn_metadata.seq_lens)
    # ``< 2`` rather than ``== 1`` so an empty batch also takes the unsplit
    # path instead of being split into two empty halves.
    if batch_size < 2 or \
            (not model_input.is_prompt and not enable_tbo_decode) or \
            not is_rocm_fa or \
            is_cuda_graph_decode:
        with set_forward_context(model_input.attn_metadata,
                                 vllm_config, virtual_engine):
            hidden_or_intermediate_states = model_executable(
                input_ids=model_input.input_tokens,
                positions=model_input.input_positions,
                intermediate_tensors=intermediate_tensors,
                **MultiModalKwargs.as_kwargs(multi_modal_kwargs,
                                             device=self_device),
                **seqlen_agnostic_kwargs,
                **model_kwargs,
            )
        return hidden_or_intermediate_states
    profile.ProfRangePush('tbo_model_executable')
    tbo_obj.tbo_running = True
    tbo_obj.left_first = True
    try:
        # The right half receives the extra sequence of an odd-sized batch.
        batch_size_left = batch_size // 2
        batch_size_right = batch_size - batch_size_left
        model_input_left, model_input_right = split_model_input(model_input, self_device, batch_size_left, batch_size_right)
        # Order the step stream after the work already queued on the
        # caller's current stream.
        tbo_obj.step_event.record()
        current_stream = torch.cuda.current_stream()
        with torch.cuda.stream(tbo_step_stream):
            tbo_step_stream.wait_event(tbo_obj.step_event)
            tbo_obj.set_model_input(model_input_left,
                                    model_input_right,
                                    vllm_config,
                                    virtual_engine,
                                    model_executable,
                                    intermediate_tensors,
                                    multi_modal_kwargs,
                                    self_device,
                                    seqlen_agnostic_kwargs,
                                    model_kwargs)
            # Blocks until the right worker posts the ``None`` sentinel.
            tbo_obj.all_reduce()
            states_left, states_right = tbo_obj.get_model_output()
            hidden_or_intermediate_states = merge_model_output(states_left, states_right)
            tbo_obj.tbo_running = False
            tbo_obj.step_event.record()
        # Make the caller's stream wait for the step-stream work.
        current_stream.wait_event(tbo_obj.step_event)
    finally:
        # Always leave TBO mode and close the profiler range, even when
        # splitting or execution raised; a stale ``tbo_running`` would send
        # later all-reduces into the TBO queue.
        tbo_obj.tbo_running = False
        profile.ProfRangePop()
    return hidden_or_intermediate_states
...@@ -50,6 +50,7 @@ from vllm.prompt_adapter.worker_manager import ( ...@@ -50,6 +50,7 @@ from vllm.prompt_adapter.worker_manager import (
LRUCacheWorkerPromptAdapterManager) LRUCacheWorkerPromptAdapterManager)
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.two_batch_overlap.two_batch_overlap import is_enable_tbo, tbo_model_executable
from vllm.utils import (DeviceMemoryProfiler, GiB_bytes, PyObjectCache, from vllm.utils import (DeviceMemoryProfiler, GiB_bytes, PyObjectCache,
async_tensor_h2d, flatten_2d_lists, async_tensor_h2d, flatten_2d_lists,
is_pin_memory_available, supports_dynamo, is_pin_memory_available, supports_dynamo,
...@@ -158,6 +159,7 @@ class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU): ...@@ -158,6 +159,7 @@ class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
tensor_dict = { tensor_dict = {
"input_tokens": self.input_tokens, "input_tokens": self.input_tokens,
"input_positions": self.input_positions, "input_positions": self.input_positions,
"query_lens": self.query_lens,
"lora_requests": self.lora_requests, "lora_requests": self.lora_requests,
"lora_mapping": self.lora_mapping, "lora_mapping": self.lora_mapping,
"multi_modal_kwargs": self.multi_modal_kwargs, "multi_modal_kwargs": self.multi_modal_kwargs,
...@@ -166,6 +168,7 @@ class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU): ...@@ -166,6 +168,7 @@ class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
"virtual_engine": self.virtual_engine, "virtual_engine": self.virtual_engine,
"request_ids_to_seq_ids": self.request_ids_to_seq_ids, "request_ids_to_seq_ids": self.request_ids_to_seq_ids,
"finished_requests_ids": self.finished_requests_ids, "finished_requests_ids": self.finished_requests_ids,
"is_prompt": self.is_prompt,
} }
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
_add_sampling_metadata_broadcastable_dict(tensor_dict, _add_sampling_metadata_broadcastable_dict(tensor_dict,
...@@ -1776,6 +1779,18 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): ...@@ -1776,6 +1779,18 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
model_forward_start.record() model_forward_start.record()
if not bypass_model_exec: if not bypass_model_exec:
if is_enable_tbo():
hidden_or_intermediate_states = tbo_model_executable(
model_input,
self.vllm_config,
virtual_engine,
model_executable,
intermediate_tensors,
multi_modal_kwargs,
self.device,
seqlen_agnostic_kwargs,
model_kwargs)
else:
with set_forward_context(model_input.attn_metadata, with set_forward_context(model_input.attn_metadata,
self.vllm_config, virtual_engine): self.vllm_config, virtual_engine):
hidden_or_intermediate_states = model_executable( hidden_or_intermediate_states = model_executable(
......
...@@ -18,6 +18,7 @@ from vllm.logger import init_logger ...@@ -18,6 +18,7 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest, IntermediateTensors from vllm.sequence import ExecuteModelRequest, IntermediateTensors
from vllm.two_batch_overlap.two_batch_overlap import finish_two_batch_overlap
from vllm.utils import (enable_trace_function_call_for_thread, from vllm.utils import (enable_trace_function_call_for_thread,
resolve_obj_by_qualname, run_method, resolve_obj_by_qualname, run_method,
update_environment_variables, update_environment_variables,
...@@ -77,7 +78,6 @@ class WorkerBase: ...@@ -77,7 +78,6 @@ class WorkerBase:
from vllm.platforms import current_platform from vllm.platforms import current_platform
self.current_platform = current_platform self.current_platform = current_platform
def init_device(self) -> None: def init_device(self) -> None:
"""Initialize device state, such as loading the model or other on-device """Initialize device state, such as loading the model or other on-device
memory allocations. memory allocations.
...@@ -113,6 +113,7 @@ class WorkerBase: ...@@ -113,6 +113,7 @@ class WorkerBase:
while True: while True:
output = self.execute_model(execute_model_req=None) output = self.execute_model(execute_model_req=None)
if output is None: if output is None:
finish_two_batch_overlap()
return None return None
def determine_num_available_blocks(self) -> Tuple[int, int]: def determine_num_available_blocks(self) -> Tuple[int, int]:
......
...@@ -40,7 +40,7 @@ from vllm.zero_overhead.tokenizer import ZeroOverheadDetokenizer ...@@ -40,7 +40,7 @@ from vllm.zero_overhead.tokenizer import ZeroOverheadDetokenizer
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message) usage_message)
from vllm.profiler.prof import profile from vllm.profiler.prof import profile
from vllm.zero_overhead.utils import SpecStepKind, get_accepted_token_ids, get_spec_step, is_zero_no_thread, set_spec_step from vllm.zero_overhead.utils import SpecStepKind, get_accepted_token_ids, get_spec_step, is_zero_no_thread, set_spec_step, zero_overhead_stream
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -87,6 +87,7 @@ class ZeroOverheadEngine(LLMEngine): ...@@ -87,6 +87,7 @@ class ZeroOverheadEngine(LLMEngine):
self.log_stats = log_stats self.log_stats = log_stats
self.use_cached_outputs = use_cached_outputs self.use_cached_outputs = use_cached_outputs
self.thread_running = False
if not self.model_config.skip_tokenizer_init: if not self.model_config.skip_tokenizer_init:
self.tokenizer = self._init_tokenizer() self.tokenizer = self._init_tokenizer()
...@@ -254,8 +255,8 @@ class ZeroOverheadEngine(LLMEngine): ...@@ -254,8 +255,8 @@ class ZeroOverheadEngine(LLMEngine):
self.async_d2h = None self.async_d2h = None
self.last_record = None self.last_record = None
self.async_event = torch.cuda.Event(enable_timing=False) self.async_event = torch.cuda.Event(enable_timing=False)
self.thread_running = False
self.q_recorder = queue.Queue() self.q_recorder = queue.Queue()
self.use_stream = zero_overhead_stream(self.model_executor.device_config.device)
if not is_zero_no_thread(): if not is_zero_no_thread():
self.zero_thread = threading.Thread(target=self.thread_zero_overhead) self.zero_thread = threading.Thread(target=self.thread_zero_overhead)
self.thread_running = True self.thread_running = True
...@@ -272,14 +273,19 @@ class ZeroOverheadEngine(LLMEngine): ...@@ -272,14 +273,19 @@ class ZeroOverheadEngine(LLMEngine):
self.thread_running = False self.thread_running = False
self.sem_m2s.release() self.sem_m2s.release()
def thread_zero_overhead(self): def thread_zero_overhead(self):
logger.info('zero overhead thread start!') logger.info('zero overhead thread start!')
last_sampler = get_last_sampler()
last_sampler.seq_ids.clear()
try: try:
with torch.cuda.stream(self.use_stream):
while True: while True:
self.sem_m2s.acquire() self.sem_m2s.acquire()
if not self.thread_running: if not self.thread_running:
logger.debug("Stopping remote worker execution loop.") logger.debug("Stopping remote worker execution loop.")
self.model_executor.stop_remote_worker_execution_loop() self.model_executor.stop_remote_worker_execution_loop()
break break
virtual_engine = 0 virtual_engine = 0
...@@ -560,6 +566,7 @@ class ZeroOverheadEngine(LLMEngine): ...@@ -560,6 +566,7 @@ class ZeroOverheadEngine(LLMEngine):
return ctx.request_outputs return ctx.request_outputs
def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]: def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
with torch.cuda.stream(self.use_stream):
if is_zero_no_thread(): if is_zero_no_thread():
out = self.no_thread_step() out = self.no_thread_step()
if out is None: #the first step need launch twice if out is None: #the first step need launch twice
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
from enum import Enum from enum import Enum
import os import os
import torch
zero_overhead = os.environ.get('VLLM_ZERO_OVERHEAD') == '1' zero_overhead = os.environ.get('VLLM_ZERO_OVERHEAD') == '1'
zero_no_thread = os.environ.get('VLLM_ZERO_NO_THREAD') == '1' zero_no_thread = os.environ.get('VLLM_ZERO_NO_THREAD') == '1'
...@@ -62,3 +63,12 @@ def record_accepted_token_ids(tensor, seq_ids): ...@@ -62,3 +63,12 @@ def record_accepted_token_ids(tensor, seq_ids):
def get_accepted_token_ids(): def get_accepted_token_ids():
return spec_context.accepted_token_ids, spec_context.accepted_seq_ids return spec_context.accepted_token_ids, spec_context.accepted_seq_ids
# Zero-overhead scheduling does not run inference on the default stream, in
# order to avoid the stream-synchronization problem of memory allocation
# introduced by the runtime.
alloc_stream = {}
def zero_overhead_stream(target_device):
    """Return the CUDA stream cached for ``target_device``.

    A new ``torch.cuda.Stream`` is created for the device on the first call
    and stored in ``alloc_stream``; later calls with the same device return
    the same stream object.

    Args:
        target_device: device the stream belongs to; also the cache key.
    """
    if target_device not in alloc_stream:
        alloc_stream[target_device] = torch.cuda.Stream(device=target_device)
    return alloc_stream[target_device]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment