Commit 20e75ed6 authored by lizhigong's avatar lizhigong Committed by maxiao1@sugon.com
Browse files

add tbo on v1 engine

parent eba84521
......@@ -159,6 +159,7 @@ if TYPE_CHECKING:
VLLM_ENABLE_TBO: bool = False
VLLM_TBO_REQ_DELAY_MS: int = 0
VLLM_TBO_DECODE_BS: int = 0
VLLM_TBO_MIN_TOKENS: int = 200
VLLM_ZERO_OVERHEAD: bool = False
VLLM_ENABLE_MOE_FUSED_GATE: bool = False
VLLM_USE_FLASH_ATTN_PA: bool = False
......@@ -1069,6 +1070,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_TBO_DECODE_BS":
lambda: int(os.getenv("VLLM_TBO_DECODE_BS", "0")),
# Minimum number of tokens each mini-batch must have for TBO to be enabled on the v1 engine (default: 200).
"VLLM_TBO_MIN_TOKENS":
lambda: int(os.getenv("VLLM_TBO_MIN_TOKENS", "200")),
# Enable zero overhead scheduler.
"VLLM_ZERO_OVERHEAD":
lambda: bool(int(os.getenv("VLLM_ZERO_OVERHEAD", "0"))),
......
......@@ -16,6 +16,7 @@ from vllm.logger import init_logger
from vllm.profiler.prof import profile
from vllm import envs
from vllm.utils import weak_ref_tensor
from vllm.two_batch_overlap.v1.two_batch_overlap_v1 import is_enable_tbo_v1, tbo_all_reduce_v1
tbo_one_stream = os.environ.get('VLLM_TBO_ONE_STREAM') == '1'
......@@ -214,6 +215,8 @@ def init_two_batch_overlap():
tbo_obj.init_tbo_thread()
def tbo_all_reduce(obj):
if is_enable_tbo_v1():
return tbo_all_reduce_v1(obj)
if envs.VLLM_ENABLE_TBO and tbo_obj != None and tbo_obj.tbo_running:
tid = threading.get_ident()
if not tbo_one_stream:
......
This diff is collapsed.
import os
import queue
import threading
import torch
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import get_tp_group
from vllm.forward_context import set_forward_context
from vllm.multimodal.inputs import MultiModalKwargs
from vllm.sequence import IntermediateTensors
from vllm.two_batch_overlap.forward_context import init_tbo_forward_context
from vllm.logger import init_logger
from vllm.profiler.prof import profile
from vllm import envs
# Module-level logger obtained from init_logger for this module's name.
logger = init_logger(__name__)
# CUDA streams shared by the whole module. Both start as None and are
# assigned together the first time TwoBatchOverlap.__init__ runs.
# tbo_step_stream is the stream the worker threads and
# tbo_model_executable_v1 make current; all_reduce_stream is the stream
# tbo_all_reduce_v1 issues its all-reduce on.
tbo_step_stream = None
all_reduce_stream = None
class TwoBatchOverlap():
    """Coordinate two worker threads that each run the model on one half of a batch.

    The two threads ("left" and "right") hand execution back and forth
    through ``sem_left`` / ``sem_right`` (see ``tbo_thread_synchronize``).
    Inputs are published with ``set_model_input`` and the per-half outputs are
    collected with ``get_model_output``.

    NOTE(review): the indentation of this file was lost in extraction; the
    nesting below is reconstructed from the data flow and should be compared
    against the original source.
    """

    def __init__(self):
        global tbo_step_stream
        global all_reduce_stream
        # "Start" queues: set_model_input() puts one token in each, and each
        # worker thread blocks on its own queue before doing anything.
        self.model_input_left_queue = queue.Queue()
        self.model_input_right_queue = queue.Queue()
        # Result queues: each worker puts its model output here.
        self.states_left_queue = queue.Queue()
        self.states_right_queue = queue.Queue()
        self.left_thread = None
        self.right_thread = None
        # Thread idents, filled in by the workers themselves when they start.
        self.left_tid = 0
        self.right_tid = 0
        # Turn-taking semaphores, both initially unavailable.
        self.sem_left = threading.Semaphore(0)
        self.sem_right = threading.Semaphore(0)
        # True only for the left thread's first synchronize of a step.
        self.left_first = False
        self.tbo_running = False
        self.tbo_in_capture = False
        # The two module-level streams are created once and then reused.
        if tbo_step_stream is None:
            tbo_step_stream = torch.cuda.Stream()
            all_reduce_stream = torch.cuda.Stream()
        self.step_event = torch.cuda.Event(enable_timing=False)
        # Per-side event pairs used by tbo_all_reduce_v1: "c2t" is recorded
        # before the all-reduce, "t2c" after it.
        self.event_left_c2t = torch.cuda.Event(enable_timing=False)
        self.event_right_c2t = torch.cuda.Event(enable_timing=False)
        self.event_left_t2c = torch.cuda.Event(enable_timing=False)
        self.event_right_t2c = torch.cuda.Event(enable_timing=False)

    @staticmethod
    def _drain(pending):
        """Discard every item currently buffered in ``pending`` without blocking."""
        while True:
            try:
                pending.get_nowait()
            except queue.Empty:
                return

    def init_tbo_thread(self):
        """Start a fresh left and right worker thread.

        Any start token still sitting in the input queues is dropped first so
        that a new worker cannot be released by a leftover token. (The
        previous code called ``Queue.empty()``, which only reports whether the
        queue is empty and removes nothing.)
        """
        self._drain(self.model_input_left_queue)
        self._drain(self.model_input_right_queue)
        self.left_thread = threading.Thread(target=self.thread_two_batch_overlap, args=(self.model_input_left_queue,))
        self.left_thread.start()
        self.right_thread = threading.Thread(target=self.thread_two_batch_overlap, args=(self.model_input_right_queue,))
        self.right_thread.start()
        logger.info('tbo:two batch overlap start')

    def finish_thread(self):
        """Join both worker threads and forget them.

        Safe to call when a thread was never started or was already joined.
        """
        if self.left_thread is not None:
            self.left_thread.join()
            self.left_thread = None
        if self.right_thread is not None:
            self.right_thread.join()
            self.right_thread = None

    @torch.inference_mode()
    def thread_two_batch_overlap(self, queue):
        """Worker-thread body: run the model once on this side's half of the batch.

        Args:
            queue: the start queue this worker waits on. Receiving
                ``model_input_left_queue`` makes this the left worker; any
                other queue makes it the right worker. (The parameter name
                shadows the ``queue`` module inside this method; it is kept
                for interface compatibility.)
        """
        is_left_thread = False
        tid = threading.get_ident()
        # Identity, not equality: we want to know which queue object we got.
        if queue is self.model_input_left_queue:
            self.left_tid = tid
            is_left_thread = True
            init_tbo_forward_context(True, self.left_tid)
        else:
            self.right_tid = tid
            init_tbo_forward_context(False, self.right_tid)
        with torch.cuda.stream(tbo_step_stream):
            # Block until set_model_input() has published this step's inputs.
            queue.get()
            profile.ProfRangePush('start')
            # Wait for this side's turn before touching the model.
            self.tbo_thread_synchronize(tid)
            if is_left_thread:
                attn_metadata = self.attn_metadata_left
                num_input_tokens = self.num_input_tokens_left
                input_ids = self.input_ids_left
                positions = self.positions_left
            else:
                attn_metadata = self.attn_metadata_right
                num_input_tokens = self.num_input_tokens_right
                input_ids = self.input_ids_right
                positions = self.positions_right
            model_output = None
            # Run the decoder.
            # Use persistent buffers for CUDA graphs.
            with set_forward_context(attn_metadata,
                                     self.model_runner.vllm_config,
                                     num_tokens=num_input_tokens,
                                     num_tokens_across_dp=self.num_tokens_across_dp):
                model_output = self.model_runner.model(
                    input_ids=input_ids,
                    positions=positions,
                    intermediate_tensors=self.intermediate_tensors,
                    inputs_embeds=self.inputs_embeds,
                )
            if is_left_thread:
                # Give the turn to the right side one last time so a right
                # worker blocked in tbo_thread_synchronize can proceed.
                self.sem_right.release()
                self.states_left_queue.put(model_output)
            else:
                self.states_right_queue.put(model_output)
            profile.ProfRangePop()

    def tbo_thread_synchronize(self, tid):
        """Hand the turn to the other side and block until it is handed back.

        Args:
            tid: ident of the calling thread. A value equal to ``left_tid``
                is treated as the left side; anything else as the right side.

        Returns:
            The ``(c2t, t2c)`` event pair belonging to the calling side.
        """
        if tid == self.left_tid:
            # On the left side's first call of a step (left_first is True) the
            # right side is not released, so the left side runs first.
            if not self.left_first:
                self.sem_right.release()
            self.left_first = False
            profile.ProfRangePop()
            self.sem_left.acquire()
            profile.ProfRangePush('left')
            return self.event_left_c2t, self.event_left_t2c
        else:
            self.sem_left.release()
            profile.ProfRangePop()
            self.sem_right.acquire()
            profile.ProfRangePush('right')
            return self.event_right_c2t, self.event_right_t2c

    def set_model_input(self,
                        model_runner,
                        attn_metadata_left,
                        attn_metadata_right,
                        num_input_tokens_left,
                        num_input_tokens_right,
                        input_ids_left,
                        input_ids_right,
                        positions_left,
                        positions_right,
                        num_tokens_across_dp,
                        intermediate_tensors,
                        inputs_embeds):
        """Store this step's inputs on the instance and release both workers.

        Every argument is saved under an attribute of the same name. The
        ``*_left`` / ``*_right`` values are read by the matching worker, while
        ``num_tokens_across_dp``, ``intermediate_tensors`` and
        ``inputs_embeds`` are read by both. One token is then put on each
        start queue, which unblocks the worker threads.
        """
        self.model_runner = model_runner
        self.attn_metadata_left = attn_metadata_left
        self.attn_metadata_right = attn_metadata_right
        self.num_input_tokens_left = num_input_tokens_left
        self.num_input_tokens_right = num_input_tokens_right
        self.input_ids_left = input_ids_left
        self.input_ids_right = input_ids_right
        self.positions_left = positions_left
        self.positions_right = positions_right
        self.num_tokens_across_dp = num_tokens_across_dp
        self.intermediate_tensors = intermediate_tensors
        self.inputs_embeds = inputs_embeds
        # Publish only after every attribute is set: the workers read them
        # as soon as they receive the token.
        self.model_input_left_queue.put(None)
        self.model_input_right_queue.put(None)

    def get_model_output(self):
        """Block until both workers have produced output; return ``(left, right)``."""
        states_left = self.states_left_queue.get()
        states_right = self.states_right_queue.get()
        return states_left, states_right
# Module-wide TwoBatchOverlap instance; stays None until
# init_two_batch_overlap() creates it.
tbo_obj_v1 = None


def is_enable_tbo_v1():
    """Return True once the module-level ``tbo_obj_v1`` has been created."""
    # Read-only access, so no ``global`` declaration is needed; ``is not``
    # is the identity test for None.
    return tbo_obj_v1 is not None
def init_two_batch_overlap():
    """Create the shared ``TwoBatchOverlap`` on first use and start its worker threads.

    The object is built only once, but the threads are started on every call:
    each worker runs a single step and exits, and ``tbo_model_executable_v1``
    joins them at the end of every step.

    NOTE(review): indentation was lost in extraction; placing
    ``init_tbo_thread()`` outside the ``if`` is inferred from that per-step
    join. Confirm against the original source.
    """
    global tbo_obj_v1
    if tbo_obj_v1 is None:
        tbo_obj_v1 = TwoBatchOverlap()
    tbo_obj_v1.init_tbo_thread()
def tbo_all_reduce_v1(obj):
    """All-reduce ``obj``, yielding the turn to the other TBO worker while it runs.

    When TBO is not active (env flag off, no ``tbo_obj_v1``, or no step in
    progress) this is a plain ``tensor_model_parallel_all_reduce``.

    Otherwise the all-reduce is issued on ``all_reduce_stream`` behind an
    event recorded on the caller's current stream, the calling thread hands
    the turn to the other worker, and once it is scheduled again
    ``tbo_step_stream`` is made to wait for the all-reduce to finish.

    Args:
        obj: value forwarded unchanged to ``tensor_model_parallel_all_reduce``.

    Returns:
        Whatever ``tensor_model_parallel_all_reduce`` returns.

    NOTE(review): indentation was lost in extraction. ``event_t2c.record()``
    is placed inside the ``all_reduce_stream`` context so that it marks the
    end of the all-reduce; confirm against the original source.
    """
    tbo_active = (envs.VLLM_ENABLE_TBO
                  and tbo_obj_v1 is not None
                  and tbo_obj_v1.tbo_running)
    if not tbo_active:
        return tensor_model_parallel_all_reduce(obj)

    tid = threading.get_ident()
    # Any thread that is not the left worker uses the right-side events.
    if tid == tbo_obj_v1.left_tid:
        event_c2t, event_t2c = tbo_obj_v1.event_left_c2t, tbo_obj_v1.event_left_t2c
    else:
        event_c2t, event_t2c = tbo_obj_v1.event_right_c2t, tbo_obj_v1.event_right_t2c
    # Mark the point on the current stream the all-reduce must wait for.
    event_c2t.record()
    with torch.cuda.stream(all_reduce_stream):
        all_reduce_stream.wait_event(event_c2t)
        output = tensor_model_parallel_all_reduce(obj)
        event_t2c.record()
    # Let the other worker run while the all-reduce is in flight.
    tbo_obj_v1.tbo_thread_synchronize(tid)
    # Back on our turn: later work on the step stream must see the result.
    tbo_step_stream.wait_event(event_t2c)
    return output
def merge_model_output(states_left, states_right):
    """Concatenate the left and right model outputs along dimension 0.

    If ``states_left`` is an ``IntermediateTensors``, every entry of its
    ``tensors`` mapping is concatenated with the same-named entry of
    ``states_right`` and the result is wrapped in a new
    ``IntermediateTensors``. Otherwise both arguments are concatenated
    directly.
    """
    if not isinstance(states_left, IntermediateTensors):
        return torch.concat([states_left, states_right], dim=0)

    right_tensors = states_right.tensors
    left_tensors = states_left.tensors
    merged = {
        name: torch.concat([left_tensors[name], right_tensors[name]], dim=0)
        for name in left_tensors
    }
    return IntermediateTensors(merged)
def tbo_model_executable_v1(
    model_runner,
    attn_metadata_left,
    attn_metadata_right,
    num_input_tokens_left,
    num_input_tokens_right,
    num_tokens_across_dp,
    input_ids,
    positions,
    intermediate_tensors,
    inputs_embeds
):
    """Run one model step as two halves on the TBO worker threads and merge the result.

    Args:
        model_runner: object whose ``model`` and ``vllm_config`` the worker
            threads use.
        attn_metadata_left / attn_metadata_right: attention metadata handed
            to the left / right worker.
        num_input_tokens_left / num_input_tokens_right: sizes used to split
            ``input_ids`` and ``positions`` along dim 0 (torch.split requires
            them to add up to that dimension's length).
        num_tokens_across_dp: forwarded to ``set_forward_context`` by both
            workers.
        input_ids, positions: full-batch tensors, split here into the
            left and right halves.
        intermediate_tensors, inputs_embeds: forwarded unchanged to both
            workers.
            NOTE(review): these are not split per half - confirm callers only
            pass None or values valid for both halves.

    Returns:
        The left and right outputs merged by ``merge_model_output``.

    NOTE(review): indentation was lost in extraction; the extent of the
    ``with`` block below is reconstructed so that the final
    ``step_event.record()`` happens while ``tbo_step_stream`` is current.
    Confirm against the original source.
    """
    # Creates the shared object on first use and starts this step's threads.
    init_two_batch_overlap()
    tbo_obj_v1.tbo_running = True
    # Makes the left worker skip releasing the right one on its first sync.
    tbo_obj_v1.left_first = True
    # Recorded on the caller's current stream; the step stream waits on it
    # below so it starts after the work already queued by the caller.
    tbo_obj_v1.step_event.record()
    current_stream = torch.cuda.current_stream()
    with torch.cuda.stream(tbo_step_stream):
        tbo_step_stream.wait_event(tbo_obj_v1.step_event)
        tokens_split = [num_input_tokens_left, num_input_tokens_right]
        input_ids_left, input_ids_right = torch.split(input_ids, tokens_split, dim=0)
        positions_left, positions_right = torch.split(positions, tokens_split, dim=0)
        # Publishes the inputs and unblocks both worker threads.
        tbo_obj_v1.set_model_input(model_runner,
                                    attn_metadata_left,
                                    attn_metadata_right,
                                    num_input_tokens_left,
                                    num_input_tokens_right,
                                    input_ids_left,
                                    input_ids_right,
                                    positions_left,
                                    positions_right,
                                    num_tokens_across_dp,
                                    intermediate_tensors,
                                    inputs_embeds)
        # Blocks until both workers have put their output.
        model_output_left, model_output_right = tbo_obj_v1.get_model_output()
        hidden_or_intermediate_states = merge_model_output(model_output_left, model_output_right)
        tbo_obj_v1.tbo_running = False
        # Re-recorded with the step stream current, marking the end of the
        # step's work for the caller's stream to wait on.
        tbo_obj_v1.step_event.record()
    # Workers run a single step each, so join them before returning.
    tbo_obj_v1.finish_thread()
    current_stream.wait_event(tbo_obj_v1.step_event)
    return hidden_or_intermediate_states
\ No newline at end of file
......@@ -22,6 +22,7 @@ from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.two_batch_overlap.v1.gpu_model_runner import TBO_GPUModelRunner
from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput
......@@ -162,6 +163,10 @@ class Worker(WorkerBase):
set_random_seed(self.model_config.seed)
# Construct the model runner
if envs.VLLM_ENABLE_TBO:
self.model_runner: TBO_GPUModelRunner = TBO_GPUModelRunner(
self.vllm_config, self.device)
else:
self.model_runner: GPUModelRunner = GPUModelRunner(
self.vllm_config, self.device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment