Commit 12291212 authored by maxiao1's avatar maxiao1 Committed by lizhigong
Browse files

pd分离_tbo

parent 3daae57c
......@@ -414,9 +414,9 @@ def unified_attention(
output = self.impl.forward(self, query, key, value, kv_cache,
attn_metadata)
if envs.VLLM_ENABLE_TBO:
tbo_maybe_save_kv_layer_to_connector(layer_name, kv_cache)
else:
# if envs.VLLM_ENABLE_TBO:
# tbo_maybe_save_kv_layer_to_connector(layer_name, kv_cache)
# else:
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
return output
......@@ -462,9 +462,9 @@ def unified_attention_with_output(
attn_metadata,
output=output,
output_scale=output_scale)
if envs.VLLM_ENABLE_TBO:
tbo_maybe_save_kv_layer_to_connector(layer_name, kv_cache)
else:
# if envs.VLLM_ENABLE_TBO:
# tbo_maybe_save_kv_layer_to_connector(layer_name, kv_cache)
# else:
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
......
......@@ -75,6 +75,9 @@ class SiluAndMul(CustomOp):
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
if not torch.compiler.is_compiling(): # 非 capture 阶段
return self.forward_cuda(x) # 强制走 fused kernel
else:
d = x.shape[-1] // 2
return F.silu(x[..., :d]) * x[..., d:]
......
......@@ -165,7 +165,11 @@ class RMSNorm(CustomOp):
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
"""PyTorch-native implementation equivalent to forward()."""
if not torch.compiler.is_compiling(): # 非 capture 阶段
return self.forward_cuda(x, residual) # 强制走 fused kernel
else:
# 否则fallback到原始实现
orig_dtype = x.dtype
x = x.to(torch.float32)
if residual is not None:
......@@ -184,11 +188,9 @@ class RMSNorm(CustomOp):
raise ValueError(
"Expected hidden_size to be at least "
f"{self.variance_size_override}, but found: {hidden_size}")
x_var = x[:, :, :self.variance_size_override]
variance = x_var.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + self.variance_epsilon)
x = x.to(orig_dtype)
if self.has_weight:
......
......@@ -17,10 +17,12 @@ logger = init_logger(__name__)
tbo_step_stream = None
all_reduce_stream = None
class TwoBatchOverlap():
PERSIST_THREADS = os.getenv('VLLM_TBO_PERSIST_THREADS', '1') not in ('0','false','False','no','NO','')
STOP = object()
class TwoBatchOverlap:
def __init__(self):
global tbo_step_stream
global all_reduce_stream
global tbo_step_stream, all_reduce_stream
self.model_input_left_queue = queue.Queue()
self.model_input_right_queue = queue.Queue()
self.states_left_queue = queue.Queue()
......@@ -29,12 +31,14 @@ class TwoBatchOverlap():
self.right_thread = None
self.left_tid = 0
self.right_tid = 0
self._stop_evt = threading.Event()
self._threads_started = False
self.sem_left = threading.Semaphore(0)
self.sem_right = threading.Semaphore(0)
self.left_first = False
self.tbo_running = False
self.tbo_in_capture = False
if tbo_step_stream == None:
if tbo_step_stream is None:
tbo_step_stream = torch.cuda.Stream()
all_reduce_stream = torch.cuda.Stream()
self.step_event = torch.cuda.Event(enable_timing=False)
......@@ -44,35 +48,52 @@ class TwoBatchOverlap():
self.event_right_t2c = torch.cuda.Event(enable_timing=False)
def init_tbo_thread(self):
self.model_input_left_queue.empty()
self.model_input_right_queue.empty()
self.left_thread = threading.Thread(target=self.thread_two_batch_overlap, args=(self.model_input_left_queue,))
if self._threads_started and PERSIST_THREADS:
return
if self.left_thread is None or not self.left_thread.is_alive():
self.left_thread = threading.Thread(target=self.thread_two_batch_overlap,
args=(self.model_input_left_queue,), daemon=True)
self.left_thread.start()
self.right_thread = threading.Thread(target=self.thread_two_batch_overlap, args=(self.model_input_right_queue,))
if self.right_thread is None or not self.right_thread.is_alive():
self.right_thread = threading.Thread(target=self.thread_two_batch_overlap,
args=(self.model_input_right_queue,), daemon=True)
self.right_thread.start()
if get_tp_group().rank == 0:
logger.info('tbo:two batch overlap start')
self._threads_started = True
def finish_thread(self):
self.left_thread.join()
def shutdown(self, timeout=5.0):
self._stop_evt.set()
try:
self.model_input_left_queue.put(STOP)
self.model_input_right_queue.put(STOP)
except Exception:
pass
if self.left_thread is not None:
self.left_thread.join(timeout=timeout)
self.left_thread = None
self.right_thread.join()
if self.right_thread is not None:
self.right_thread.join(timeout=timeout)
self.right_thread = None
@torch.inference_mode()
def thread_two_batch_overlap(self, queue):
def thread_two_batch_overlap(self, q):
is_left_thread = False
tid = threading.get_ident()
if queue == self.model_input_left_queue:
if q is self.model_input_left_queue:
self.left_tid = tid
is_left_thread = True
init_tbo_forward_context(True, self.left_tid)
else:
self.right_tid = tid
init_tbo_forward_context(False, self.right_tid)
while not self._stop_evt.is_set():
item = q.get()
if item is STOP:
break
with torch.cuda.stream(tbo_step_stream):
queue.get()
self.tbo_thread_synchronize(tid)
if is_left_thread:
attn_metadata = self.attn_metadata_left
num_input_tokens = self.num_input_tokens_left
......@@ -84,20 +105,28 @@ class TwoBatchOverlap():
input_ids = self.input_ids_right
positions = self.positions_right
model_output = None
# Run the decoder.
# Use persistent buffers for CUDA graphs.
# Select per-thread tensors (left/right) with backward-compatible fallback
if is_left_thread:
intermediate_tensors = getattr(self, 'intermediate_tensors_left', None)
else:
intermediate_tensors = getattr(self, 'intermediate_tensors_right', None)
if intermediate_tensors is None:
intermediate_tensors = getattr(self, 'intermediate_tensors_left', None)
with set_forward_context(attn_metadata,
self.model_runner.vllm_config,
num_tokens=num_input_tokens,
num_tokens_across_dp=self.num_tokens_across_dp,
skip_cuda_graphs=True):
skip_cuda_graphs=True,
):
model_output = self.model_runner.model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=self.intermediate_tensors,
intermediate_tensors=intermediate_tensors,
inputs_embeds=self.inputs_embeds,
)
if is_left_thread:
self.sem_right.release()
self.states_left_queue.put(model_output)
......@@ -128,7 +157,8 @@ class TwoBatchOverlap():
positions_right,
num_tokens_across_dp,
intermediate_tensors,
inputs_embeds):
inputs_embeds,
):
self.model_runner = model_runner
self.attn_metadata_left = attn_metadata_left
self.attn_metadata_right = attn_metadata_right
......@@ -139,9 +169,14 @@ class TwoBatchOverlap():
self.positions_left = positions_left
self.positions_right = positions_right
self.num_tokens_across_dp = num_tokens_across_dp
self.intermediate_tensors = intermediate_tensors
self.inputs_embeds = inputs_embeds
if isinstance(intermediate_tensors, tuple):
self.intermediate_tensors_left, self.intermediate_tensors_right = intermediate_tensors
else:
self.intermediate_tensors_left = intermediate_tensors
self.intermediate_tensors_right = None
self.model_input_left_queue.put(None)
self.model_input_right_queue.put(None)
......@@ -150,15 +185,18 @@ class TwoBatchOverlap():
states_right = self.states_right_queue.get()
return states_left, states_right
tbo_obj_v1 = None
def is_enable_tbo_v1():
global tbo_obj_v1
return tbo_obj_v1 != None
return tbo_obj_v1 is not None
def init_two_batch_overlap():
global tbo_obj_v1
if tbo_obj_v1 == None:
if tbo_obj_v1 is None:
tbo_obj_v1 = TwoBatchOverlap()
tbo_obj_v1.init_tbo_thread()
......@@ -171,7 +209,7 @@ def tbo_maybe_save_kv_layer_to_connector(layer_name, kv_cache):
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
def tbo_all_reduce_v1(obj):
if envs.VLLM_ENABLE_TBO and tbo_obj_v1 != None and tbo_obj_v1.tbo_running:
if envs.VLLM_ENABLE_TBO and tbo_obj_v1 is not None and tbo_obj_v1.tbo_running:
tid = threading.get_ident()
if tid == tbo_obj_v1.left_tid:
event_c2t, event_t2c = tbo_obj_v1.event_left_c2t, tbo_obj_v1.event_left_t2c
......@@ -207,19 +245,19 @@ def tbo_model_executable_v1(
input_ids,
positions,
intermediate_tensors,
inputs_embeds
inputs_embeds,
):
init_two_batch_overlap()
tbo_obj_v1.tbo_running = True
tbo_obj_v1.left_first = True
tbo_obj_v1.step_event.record()
current_stream = torch.cuda.current_stream()
num_total_tokens = num_input_tokens_left + num_input_tokens_right
with torch.cuda.stream(tbo_step_stream):
tbo_step_stream.wait_event(tbo_obj_v1.step_event)
tokens_split = [num_input_tokens_left, num_input_tokens_right]
input_ids_left, input_ids_right = torch.split(input_ids, tokens_split, dim=0)
positions_left, positions_right = torch.split(positions, tokens_split, dim=0)
input_ids_left, input_ids_right = torch.split(input_ids[:num_total_tokens], tokens_split, dim=0)
positions_left, positions_right = torch.split(positions[:num_total_tokens], tokens_split, dim=0)
tbo_obj_v1.set_model_input(model_runner,
attn_metadata_left,
attn_metadata_right,
......@@ -231,13 +269,21 @@ def tbo_model_executable_v1(
positions_right,
num_tokens_across_dp,
intermediate_tensors,
inputs_embeds)
inputs_embeds,
)
model_output_left, model_output_right = tbo_obj_v1.get_model_output()
hidden_or_intermediate_states = merge_model_output(model_output_left, model_output_right)
tbo_obj_v1.tbo_running = False
tbo_obj_v1.step_event.record()
tbo_obj_v1.finish_thread()
current_stream.wait_event(tbo_obj_v1.step_event)
return hidden_or_intermediate_states
def finalize_two_batch_overlap():
global tbo_obj_v1
if tbo_obj_v1 is not None:
try:
tbo_obj_v1.shutdown()
finally:
tbo_obj_v1 = None
......@@ -1374,8 +1374,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# If attention doesn't support CUDA Graphs for this batch, but we
# compiled with full CUDA graphs, we have to skip them entirely.
skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs
if envs.VLLM_ENABLE_TBO and (not self.use_cuda_graph or skip_cuda_graphs):
if envs.VLLM_ENABLE_TBO and scheduler_output.total_num_scheduled_tokens >= envs.VLLM_TBO_MIN_TOKENS:
model_output, finished_sending, finished_recving = \
tbo_split_and_execute_model(self, attn_metadata, num_input_tokens,
num_tokens_across_dp, input_ids, positions,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment