Commit 12291212 authored by maxiao1's avatar maxiao1 Committed by lizhigong
Browse files

pd分离_tbo

parent 3daae57c
...@@ -414,10 +414,10 @@ def unified_attention( ...@@ -414,10 +414,10 @@ def unified_attention(
output = self.impl.forward(self, query, key, value, kv_cache, output = self.impl.forward(self, query, key, value, kv_cache,
attn_metadata) attn_metadata)
if envs.VLLM_ENABLE_TBO: # if envs.VLLM_ENABLE_TBO:
tbo_maybe_save_kv_layer_to_connector(layer_name, kv_cache) # tbo_maybe_save_kv_layer_to_connector(layer_name, kv_cache)
else: # else:
maybe_save_kv_layer_to_connector(layer_name, kv_cache) maybe_save_kv_layer_to_connector(layer_name, kv_cache)
return output return output
...@@ -462,10 +462,10 @@ def unified_attention_with_output( ...@@ -462,10 +462,10 @@ def unified_attention_with_output(
attn_metadata, attn_metadata,
output=output, output=output,
output_scale=output_scale) output_scale=output_scale)
if envs.VLLM_ENABLE_TBO: # if envs.VLLM_ENABLE_TBO:
tbo_maybe_save_kv_layer_to_connector(layer_name, kv_cache) # tbo_maybe_save_kv_layer_to_connector(layer_name, kv_cache)
else: # else:
maybe_save_kv_layer_to_connector(layer_name, kv_cache) maybe_save_kv_layer_to_connector(layer_name, kv_cache)
def unified_attention_with_output_fake( def unified_attention_with_output_fake(
......
...@@ -75,8 +75,11 @@ class SiluAndMul(CustomOp): ...@@ -75,8 +75,11 @@ class SiluAndMul(CustomOp):
def forward_native(self, x: torch.Tensor) -> torch.Tensor: def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward().""" """PyTorch-native implementation equivalent to forward()."""
d = x.shape[-1] // 2 if not torch.compiler.is_compiling(): # 非 capture 阶段
return F.silu(x[..., :d]) * x[..., d:] return self.forward_cuda(x) # 强制走 fused kernel
else:
d = x.shape[-1] // 2
return F.silu(x[..., :d]) * x[..., d:]
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2 d = x.shape[-1] // 2
......
...@@ -165,38 +165,40 @@ class RMSNorm(CustomOp): ...@@ -165,38 +165,40 @@ class RMSNorm(CustomOp):
x: torch.Tensor, x: torch.Tensor,
residual: Optional[torch.Tensor] = None, residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
"""PyTorch-native implementation equivalent to forward()."""
orig_dtype = x.dtype if not torch.compiler.is_compiling(): # 非 capture 阶段
x = x.to(torch.float32) return self.forward_cuda(x, residual) # 强制走 fused kernel
if residual is not None:
x = x + residual.to(torch.float32)
residual = x.to(orig_dtype)
hidden_size = x.shape[-1]
if hidden_size != self.hidden_size:
raise ValueError("Expected hidden_size to be "
f"{self.hidden_size}, but found: {hidden_size}")
if self.variance_size_override is None:
x_var = x
else:
if hidden_size < self.variance_size_override:
raise ValueError(
"Expected hidden_size to be at least "
f"{self.variance_size_override}, but found: {hidden_size}")
x_var = x[:, :, :self.variance_size_override]
variance = x_var.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + self.variance_epsilon)
x = x.to(orig_dtype)
if self.has_weight:
x = x * self.weight
if residual is None:
return x
else: else:
return x, residual # 否则fallback到原始实现
orig_dtype = x.dtype
x = x.to(torch.float32)
if residual is not None:
x = x + residual.to(torch.float32)
residual = x.to(orig_dtype)
hidden_size = x.shape[-1]
if hidden_size != self.hidden_size:
raise ValueError("Expected hidden_size to be "
f"{self.hidden_size}, but found: {hidden_size}")
if self.variance_size_override is None:
x_var = x
else:
if hidden_size < self.variance_size_override:
raise ValueError(
"Expected hidden_size to be at least "
f"{self.variance_size_override}, but found: {hidden_size}")
x_var = x[:, :, :self.variance_size_override]
variance = x_var.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + self.variance_epsilon)
x = x.to(orig_dtype)
if self.has_weight:
x = x * self.weight
if residual is None:
return x
else:
return x, residual
def forward_cuda( def forward_cuda(
self, self,
......
from typing import Any, Optional, Union from typing import Any, Optional, Union
import numpy as np import numpy as np
import torch import torch
...@@ -25,99 +24,79 @@ class TBOModelInputSplit(): ...@@ -25,99 +24,79 @@ class TBOModelInputSplit():
self.req_num_right = 0 self.req_num_right = 0
self.scheduler_output_left = None self.scheduler_output_left = None
self.scheduler_output_right = None self.scheduler_output_right = None
self.query_start_loc_right = None self.split_in_req = False
input_split = TBOModelInputSplit() input_split = TBOModelInputSplit()
def split_scheduler_output(runner, scheduler_output:SchedulerOutput): def split_scheduler_output(runner, scheduler_output: SchedulerOutput):
"""Split a step's scheduled tokens evenly into left/right halves.
If a request crosses the split boundary, mark split_in_req=True and
assign left/right token counts accordingly.
"""
split_tokens = scheduler_output.total_num_scheduled_tokens // 2 split_tokens = scheduler_output.total_num_scheduled_tokens // 2
req_ids = runner.input_batch.req_ids split_counter = 0
tokens_counter = 0 num_scheduled_tokens_left: dict[int, int] = {}
min_idx = -1 num_scheduled_tokens_right: dict[int, int] = {}
min_counter = 0 input_split.req_ids_left.clear()
for i, id in enumerate(req_ids): input_split.req_ids_right.clear()
tokens_counter += scheduler_output.num_scheduled_tokens[id] total_num_scheduled_tokens_left = split_tokens
diff = abs(tokens_counter - split_tokens) total_num_scheduled_tokens_right = scheduler_output.total_num_scheduled_tokens - split_tokens
if min_idx == -1 or diff < min_counter:
min_idx = i req_splited = False
min_counter = diff input_split.split_in_req = False
if tokens_counter > split_tokens or diff == 0:
break
input_split.req_num_left = min_idx + 1
if input_split.req_num_left == len(req_ids):
input_split.req_num_left = input_split.req_num_left - 1
input_split.req_ids_left = req_ids[:input_split.req_num_left]
input_split.req_ids_right = req_ids[input_split.req_num_left:]
input_split.req_num_right = len(req_ids) - input_split.req_num_left
new_req_data_left = []
new_req_data_right = []
cached_reqs_left = []
cached_reqs_right = []
num_scheduled_tokens_left = {}
num_scheduled_tokens_right = {}
total_num_scheduled_tokens_left = 0
total_num_scheduled_tokens_right = 0
for new_req in scheduler_output.scheduled_new_reqs:
if new_req.req_id in input_split.req_ids_left:
new_req_data_left.append(new_req)
else:
new_req_data_right.append(new_req)
cached_reqs_left = CachedRequestData.make_empty()
cached_reqs_right = CachedRequestData.make_empty()
for req_idx, req_id in enumerate(scheduler_output.scheduled_cached_reqs.req_ids):
if req_id in input_split.req_ids_left:
cached_reqs_left.req_ids.append(req_id)
cached_reqs_left.resumed_from_preemption.append(scheduler_output.scheduled_cached_reqs.resumed_from_preemption[req_idx])
if len(scheduler_output.scheduled_cached_reqs.new_token_ids) > 0:
cached_reqs_left.new_token_ids.append(scheduler_output.scheduled_cached_reqs.new_token_ids[req_idx])
cached_reqs_left.new_block_ids.append(scheduler_output.scheduled_cached_reqs.new_block_ids[req_idx])
cached_reqs_left.num_computed_tokens.append(scheduler_output.scheduled_cached_reqs.num_computed_tokens[req_idx])
else:
cached_reqs_right.req_ids.append(req_id)
cached_reqs_right.resumed_from_preemption.append(scheduler_output.scheduled_cached_reqs.resumed_from_preemption[req_idx])
if len(scheduler_output.scheduled_cached_reqs.new_token_ids) > 0:
cached_reqs_right.new_token_ids.append(scheduler_output.scheduled_cached_reqs.new_token_ids[req_idx])
cached_reqs_right.new_block_ids.append(scheduler_output.scheduled_cached_reqs.new_block_ids[req_idx])
cached_reqs_right.num_computed_tokens.append(scheduler_output.scheduled_cached_reqs.num_computed_tokens[req_idx])
for key, value in scheduler_output.num_scheduled_tokens.items(): for key, value in scheduler_output.num_scheduled_tokens.items():
if key in input_split.req_ids_left: split_counter += value
if split_counter == split_tokens:
req_splited = True
num_scheduled_tokens_left[key] = value num_scheduled_tokens_left[key] = value
total_num_scheduled_tokens_left += value input_split.req_ids_left.append(key)
elif split_counter > split_tokens:
if req_splited:
# boundary already hit earlier; entire req goes to right
num_scheduled_tokens_right[key] = value
input_split.req_ids_right.append(key)
else:
# The boundary falls inside this request -> split within req
req_splited = True
input_split.split_in_req = True
right_tokens = split_counter - split_tokens
left_tokens = value - right_tokens
# right part
num_scheduled_tokens_right[key] = right_tokens
input_split.req_ids_right.append(key)
# left part
num_scheduled_tokens_left[key] = left_tokens
input_split.req_ids_left.append(key)
else: else:
num_scheduled_tokens_right[key] = value # before boundary, entire req goes to left
total_num_scheduled_tokens_right += value num_scheduled_tokens_left[key] = value
input_split.req_ids_left.append(key)
input_split.req_num_left = len(input_split.req_ids_left)
input_split.req_num_right = len(input_split.req_ids_right)
input_split.scheduler_output_left = SchedulerOutput( input_split.scheduler_output_left = SchedulerOutput(
scheduled_new_reqs=new_req_data_left, scheduled_new_reqs=None,
scheduled_cached_reqs=cached_reqs_left, scheduled_cached_reqs=None,
num_scheduled_tokens=num_scheduled_tokens_left, num_scheduled_tokens=num_scheduled_tokens_left,
total_num_scheduled_tokens=total_num_scheduled_tokens_left, total_num_scheduled_tokens=total_num_scheduled_tokens_left,
scheduled_spec_decode_tokens=scheduler_output.scheduled_spec_decode_tokens, scheduled_spec_decode_tokens=scheduler_output.scheduled_spec_decode_tokens,
scheduled_encoder_inputs=scheduler_output.scheduled_encoder_inputs, ##unsupport yet scheduled_encoder_inputs=scheduler_output.scheduled_encoder_inputs, # unsupported yet
num_common_prefix_blocks=scheduler_output.num_common_prefix_blocks, num_common_prefix_blocks=scheduler_output.num_common_prefix_blocks,
# finished_req_ids is an existing state in the scheduler,
# instead of being newly scheduled in this step.
# It contains the request IDs that are finished in between
# the previous and the current steps.
finished_req_ids=scheduler_output.finished_req_ids, finished_req_ids=scheduler_output.finished_req_ids,
free_encoder_input_ids=scheduler_output.free_encoder_input_ids, free_encoder_input_ids=scheduler_output.free_encoder_input_ids,
structured_output_request_ids=scheduler_output.structured_output_request_ids, structured_output_request_ids=scheduler_output.structured_output_request_ids,
grammar_bitmask=scheduler_output.grammar_bitmask, grammar_bitmask=scheduler_output.grammar_bitmask,
) )
input_split.scheduler_output_right = SchedulerOutput( input_split.scheduler_output_right = SchedulerOutput(
scheduled_new_reqs=new_req_data_right, scheduled_new_reqs=None,
scheduled_cached_reqs=cached_reqs_right, scheduled_cached_reqs=None,
num_scheduled_tokens=num_scheduled_tokens_right, num_scheduled_tokens=num_scheduled_tokens_right,
total_num_scheduled_tokens=total_num_scheduled_tokens_right, total_num_scheduled_tokens=total_num_scheduled_tokens_right,
scheduled_spec_decode_tokens=scheduler_output.scheduled_spec_decode_tokens, scheduled_spec_decode_tokens=scheduler_output.scheduled_spec_decode_tokens,
scheduled_encoder_inputs=scheduler_output.scheduled_encoder_inputs, ##unsupport yet scheduled_encoder_inputs=scheduler_output.scheduled_encoder_inputs, # unsupported yet
num_common_prefix_blocks=scheduler_output.num_common_prefix_blocks, num_common_prefix_blocks=scheduler_output.num_common_prefix_blocks,
# finished_req_ids is an existing state in the scheduler,
# instead of being newly scheduled in this step.
# It contains the request IDs that are finished in between
# the previous and the current steps.
finished_req_ids=scheduler_output.finished_req_ids, finished_req_ids=scheduler_output.finished_req_ids,
free_encoder_input_ids=scheduler_output.free_encoder_input_ids, free_encoder_input_ids=scheduler_output.free_encoder_input_ids,
structured_output_request_ids=scheduler_output.structured_output_request_ids, structured_output_request_ids=scheduler_output.structured_output_request_ids,
...@@ -129,102 +108,159 @@ def prepare_tbo_atten_metadata( ...@@ -129,102 +108,159 @@ def prepare_tbo_atten_metadata(
runner, runner,
scheduler_output: "SchedulerOutput", scheduler_output: "SchedulerOutput",
req_ids, req_ids,
req_offset req_offset: int,
) -> tuple[dict[str, Any], torch.Tensor, Optional[SpecDecodeMetadata]]: ) -> dict[str, Any]: # (attn_metadata)
"""Prepare attention metadata for one half (left/right).
Key fixes for correctness when a request is split:
- Align seq_len_offset / query_start_offset with block_table slicing.
- For the right half, if a request was split, make seq_lens[0]
= (history + left-prefix + right-half tokens).
- Pass cloned slices to CommonAttentionMetadata to avoid aliasing.
"""
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
assert total_num_scheduled_tokens > 0 assert total_num_scheduled_tokens > 0
num_reqs = len(req_ids) num_reqs = len(req_ids)
assert num_reqs > 0 assert num_reqs > 0
seq_len_offset = req_offset # Tokens per req in THIS half
# Get the number of scheduled tokens for each request.
tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids] tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
num_scheduled_tokens = np.array(tokens, dtype=np.int32) num_scheduled_tokens = np.array(tokens, dtype=np.int32)
max_num_scheduled_tokens = max(tokens) max_num_scheduled_tokens = max(tokens)
if req_offset > 0: #right # Request indices (relative to the WHOLE step), used by kernels
if input_split.query_start_loc_right == None: req_indices = np.repeat(runner.arange_np[:num_reqs],
# TODO: create when system init num_scheduled_tokens) + req_offset
input_split.query_start_loc_right = torch.zeros(runner.max_num_reqs + 1,
dtype=torch.int32,
device=runner.device)
cu_num_tokens, arange = runner._get_cumsum_and_arange(
num_scheduled_tokens)
# Prepare the attention metadata.
runner.query_start_loc_np[0] = 0
runner.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
# Cumulative tokens within this half
cu_num_tokens, arange = runner._get_cumsum_and_arange(num_scheduled_tokens)
input_split.query_start_loc_right[0: num_reqs + 1].copy_( # --- query_start_loc (within this half) ---
runner.query_start_loc_cpu[:num_reqs + 1], non_blocking=True) runner.query_start_loc_np[0] = 0
# Note: pad query_start_loc to be non-decreasing, as kernels runner.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
# like FlashAttention requires that
input_split.query_start_loc_right[num_reqs + 1:].fill_(
runner.query_start_loc_cpu[num_reqs].item())
query_start_loc = input_split.query_start_loc_right[: num_reqs + 1]
# --- seq_lens (absolute context length per-req row) ---
# Default (no split across req boundary)
# Maps rows [req_offset ... req_offset+num_reqs-1]
default_seq_lens = (
runner.input_batch.num_computed_tokens_cpu[req_offset : req_offset + num_reqs]
+ num_scheduled_tokens
)
# Offsets for copying into the *global* GPU buffers
# Left-half writes at the natural position; right-half depends on split.
if req_offset == 0:
# LEFT
seq_len_offset = 0
query_start_offset = 0
seq_lens_cpu_local = torch.as_tensor(default_seq_lens, device=runner.seq_lens_cpu.device)
else: else:
query_start_loc = runner.query_start_loc[:num_reqs + 1] # RIGHT
if input_split.split_in_req:
# The block_table for RIGHT starts from (req_offset-1).
seq_lens = runner.seq_lens[seq_len_offset : seq_len_offset + num_reqs] # Align both offsets to that, and re-build the seq_lens for row-0.
seq_len_offset = req_offset - 1
query_start_offset = req_offset - 1
# row-0 is the split request (global row index = req_offset-1):
base_hist = runner.input_batch.num_computed_tokens_cpu[req_offset - 1].item()
left_prefix = input_split.scheduler_output_left.num_scheduled_tokens[req_ids[0]]
right_tokens0 = scheduler_output.num_scheduled_tokens[req_ids[0]]
first_row = base_hist + left_prefix + right_tokens0
if num_reqs > 1:
# rows 1.. map to global rows [req_offset .. req_offset+num_reqs-2]
tail_base = runner.input_batch.num_computed_tokens_cpu[req_offset : req_offset + num_reqs - 1]
tail_tokens = num_scheduled_tokens[1:]
tail = tail_base + tail_tokens
seq_lens_cpu_local = torch.empty(num_reqs, dtype=runner.seq_lens_cpu.dtype, device=runner.seq_lens_cpu.device)
seq_lens_cpu_local[0] = first_row
seq_lens_cpu_local[1:] = torch.as_tensor(tail, device=runner.seq_lens_cpu.device)
else:
seq_lens_cpu_local = torch.tensor([first_row], dtype=runner.seq_lens_cpu.dtype, device=runner.seq_lens_cpu.device)
else:
# RIGHT without split-in-req: natural positions
seq_len_offset = req_offset
query_start_offset = req_offset
seq_lens_cpu_local = torch.as_tensor(default_seq_lens, device=runner.seq_lens_cpu.device)
# Copy query_start_loc into global GPU buffer window
runner.query_start_loc[query_start_offset: query_start_offset + num_reqs + 1].copy_(
runner.query_start_loc_cpu[:num_reqs + 1], non_blocking=True
)
# Pad tail (FlashAttn requires non-decreasing)
if req_offset > 0:
runner.query_start_loc[query_start_offset + num_reqs + 1:].fill_(
runner.query_start_loc_cpu[num_reqs].item()
)
# Copy seq_lens into the aligned window; zero out the remainder on RIGHT
runner.seq_lens[seq_len_offset : seq_len_offset + num_reqs].copy_(
seq_lens_cpu_local, non_blocking=True
)
if req_offset > 0:
runner.seq_lens[seq_len_offset + num_reqs:].fill_(0)
# Build common metadata (pass CLONES to avoid aliasing between threads)
query_start_loc = runner.query_start_loc[query_start_offset: query_start_offset + num_reqs + 1].clone()
seq_lens = runner.seq_lens[seq_len_offset : seq_len_offset + num_reqs].clone()
common_attn_metadata = CommonAttentionMetadata( common_attn_metadata = CommonAttentionMetadata(
query_start_loc=query_start_loc, query_start_loc=query_start_loc,
seq_lens=seq_lens, seq_lens=seq_lens,
num_reqs=num_reqs, num_reqs=num_reqs,
num_actual_tokens=total_num_scheduled_tokens, num_actual_tokens=total_num_scheduled_tokens,
max_query_len=max_num_scheduled_tokens) max_query_len=max_num_scheduled_tokens,
)
# Prepare attention metadata for each KV cache group
attn_metadata: dict[str, Any] = {} attn_metadata: dict[str, Any] = {}
# Prepare the attention metadata for each KV cache group and make layers for kv_cache_group_id, kv_cache_group_spec in enumerate(runner.kv_cache_config.kv_cache_groups):
# in the same group share the same metadata.
for kv_cache_group_id, kv_cache_group_spec in enumerate(
runner.kv_cache_config.kv_cache_groups):
# Prepare for cascade attention if enabled & beneficial.
common_prefix_len = 0 common_prefix_len = 0
metadata_builder = runner.attn_metadata_builders[kv_cache_group_id] metadata_builder = runner.attn_metadata_builders[kv_cache_group_id]
if runner.cascade_attn_enabled: if runner.cascade_attn_enabled:
common_prefix_len = runner._compute_cascade_attn_prefix_len( common_prefix_len = runner._compute_cascade_attn_prefix_len(
num_scheduled_tokens, num_scheduled_tokens,
scheduler_output. scheduler_output.num_common_prefix_blocks[kv_cache_group_id],
num_common_prefix_blocks[kv_cache_group_id],
kv_cache_group_spec.kv_cache_spec, kv_cache_group_spec.kv_cache_spec,
metadata_builder, metadata_builder,
) )
# Slice block_table / slot_mapping for RIGHT half
if req_offset > 0: if req_offset > 0:
origin_block_table = metadata_builder.block_table.block_table origin_block_table = metadata_builder.block_table.block_table
metadata_builder.block_table.block_table = origin_block_table[req_offset:, :] if input_split.split_in_req:
metadata_builder.block_table.block_table = origin_block_table[req_offset - 1:, :]
else:
metadata_builder.block_table.block_table = origin_block_table[req_offset:, :]
origin_slot_mapping = metadata_builder.block_table.slot_mapping origin_slot_mapping = metadata_builder.block_table.slot_mapping
metadata_builder.block_table.slot_mapping = \ origin_slot_mapping_cpu = metadata_builder.block_table.slot_mapping_cpu
origin_slot_mapping[input_split.scheduler_output_left.total_num_scheduled_tokens:] left_tokens = input_split.scheduler_output_left.total_num_scheduled_tokens
origin_slot_map_cpu = metadata_builder.block_table.slot_mapping_cpu metadata_builder.block_table.slot_mapping = origin_slot_mapping[left_tokens:]
metadata_builder.block_table.slot_mapping_cpu = \ metadata_builder.block_table.slot_mapping_cpu = origin_slot_mapping_cpu[left_tokens:]
origin_slot_map_cpu[input_split.scheduler_output_left.total_num_scheduled_tokens:]
if isinstance(metadata_builder, MLACommonMetadataBuilder): # now support prefill only # MLA-specific counters (safe to ignore for Qwen/FA paths)
if isinstance(metadata_builder, MLACommonMetadataBuilder):
_num_decodes_record = metadata_builder._num_decodes _num_decodes_record = metadata_builder._num_decodes
_num_prefills_record = metadata_builder._num_prefills _num_prefills_record = metadata_builder._num_prefills
_num_decode_tokens_record = metadata_builder._num_decode_tokens _num_decode_tokens_record = metadata_builder._num_decode_tokens
_num_prefill_tokens_record = metadata_builder._num_prefill_tokens _num_prefill_tokens_record = metadata_builder._num_prefill_tokens
metadata_builder._num_decodes = 0 metadata_builder._num_decodes = 0
metadata_builder._num_prefills = num_reqs metadata_builder._num_prefills = num_reqs
metadata_builder._num_decode_tokens = 0 metadata_builder._num_decode_tokens = 0
metadata_builder._num_prefill_tokens = total_num_scheduled_tokens metadata_builder._num_prefill_tokens = total_num_scheduled_tokens
attn_metadata_i = (
metadata_builder.build( attn_metadata_i = metadata_builder.build(
common_prefix_len=common_prefix_len, common_prefix_len=common_prefix_len,
common_attn_metadata=common_attn_metadata)) # maybe FlashAttentionMetadata common_attn_metadata=common_attn_metadata,
)
# Restore tables
if req_offset > 0: if req_offset > 0:
metadata_builder.block_table.block_table = origin_block_table metadata_builder.block_table.block_table = origin_block_table
metadata_builder.block_table.slot_mapping = origin_slot_mapping metadata_builder.block_table.slot_mapping = origin_slot_mapping
metadata_builder.block_table.slot_mapping_cpu = origin_slot_map_cpu metadata_builder.block_table.slot_mapping_cpu = origin_slot_mapping_cpu
if isinstance(metadata_builder, MLACommonMetadataBuilder): # now support prefill only if isinstance(metadata_builder, MLACommonMetadataBuilder):
metadata_builder._num_decodes = _num_decodes_record metadata_builder._num_decodes = _num_decodes_record
metadata_builder._num_prefills = _num_prefills_record metadata_builder._num_prefills = _num_prefills_record
metadata_builder._num_decode_tokens = _num_decode_tokens_record metadata_builder._num_decode_tokens = _num_decode_tokens_record
...@@ -235,31 +271,27 @@ def prepare_tbo_atten_metadata( ...@@ -235,31 +271,27 @@ def prepare_tbo_atten_metadata(
return attn_metadata return attn_metadata
def pad_num_input_tokens(self, scheduler_output): def pad_num_input_tokens(self, scheduler_output):
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
if (self.use_cuda_graph if (self.use_cuda_graph and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]): # CUDA graphs (piecewise). Add padding to batch size.
# Use piecewise CUDA graphs. num_input_tokens = self.vllm_config.pad_for_cudagraph(num_scheduled_tokens)
# Add padding to the batch size.
num_input_tokens = self.vllm_config.pad_for_cudagraph(
num_scheduled_tokens)
else: else:
# Eager mode. # Eager mode: pad to TP multiple for SP+collective fusion
# Pad tokens to multiple of tensor_parallel_size when
# enabled collective fusion for SP
tp_size = self.vllm_config.parallel_config.tensor_parallel_size tp_size = self.vllm_config.parallel_config.tensor_parallel_size
if self.vllm_config.compilation_config.pass_config. \ if self.vllm_config.compilation_config.pass_config.enable_sequence_parallelism and tp_size > 1:
enable_sequence_parallelism and tp_size > 1:
from vllm.utils import round_up from vllm.utils import round_up
num_input_tokens = round_up(num_scheduled_tokens, tp_size) num_input_tokens = round_up(num_scheduled_tokens, tp_size)
else: else:
num_input_tokens = num_scheduled_tokens num_input_tokens = num_scheduled_tokens
# Padding for DP # DP padding
num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens) num_pad, num_tokens_across_dp = self.get_dp_padding(num_input_tokens)
num_input_tokens += num_pad num_input_tokens += num_pad
return num_input_tokens, num_tokens_across_dp return num_input_tokens, num_tokens_across_dp
def tbo_split_and_execute_model( def tbo_split_and_execute_model(
runner, runner,
attn_metadata, attn_metadata,
...@@ -269,25 +301,41 @@ def tbo_split_and_execute_model( ...@@ -269,25 +301,41 @@ def tbo_split_and_execute_model(
positions, positions,
inputs_embeds, inputs_embeds,
scheduler_output: "SchedulerOutput", scheduler_output: "SchedulerOutput",
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors],
skip_cuda_graphs: bool = True, skip_cuda_graphs: bool,
) -> Union[ModelRunnerOutput, IntermediateTensors]: ) -> Union[ModelRunnerOutput, IntermediateTensors]:
use_tbo = False # If below TBO threshold, run the normal single-batch path (supports decode/prefill as-is).
if isinstance(runner.attn_metadata_builders[0], MLACommonMetadataBuilder) and \ # Two-batch overlap path
runner.attn_metadata_builders[0]._num_decodes > 0: #is mla decode split_scheduler_output(runner, scheduler_output)
use_tbo = False num_input_tokens_left = input_split.scheduler_output_left.total_num_scheduled_tokens
else: num_input_tokens_right = input_split.scheduler_output_right.total_num_scheduled_tokens
if len(scheduler_output.num_scheduled_tokens) > 1 and num_input_tokens > envs.VLLM_TBO_MIN_TOKENS:
split_scheduler_output(runner, scheduler_output) attn_metadata_left = prepare_tbo_atten_metadata(
use_tbo = True runner, input_split.scheduler_output_left, input_split.req_ids_left, 0
if use_tbo: )
num_input_tokens_left = input_split.scheduler_output_left.total_num_scheduled_tokens attn_metadata_right = prepare_tbo_atten_metadata(
num_input_tokens_right = num_input_tokens - num_input_tokens_left runner, input_split.scheduler_output_right, input_split.req_ids_right, input_split.req_num_left
)
attn_metadata_left = prepare_tbo_atten_metadata(runner, input_split.scheduler_output_left, input_split.req_ids_left, 0)
attn_metadata_right = prepare_tbo_atten_metadata(runner, input_split.scheduler_output_right, input_split.req_ids_right, input_split.req_num_left) # === Added: split inputs_embeds & intermediate_tensors per half; setup KV connector ===
# 真实 token
with set_forward_context(attn_metadata, real_L = int(input_split.scheduler_output_left.total_num_scheduled_tokens)
real_R = int(input_split.scheduler_output_right.total_num_scheduled_tokens)
# 按左右半批切成两份
def _split_it(it, l, r):
if it is None: return None, None
lm, rm = {}, {}
for k, v in it.tensors.items():
vl, vr = torch.split(v[:l + r], [l, r], dim=0)
lm[k], rm[k] = vl, vr
return IntermediateTensors(lm), IntermediateTensors(rm)
intermediate_tensors_left, intermediate_tensors_right = _split_it(
intermediate_tensors, real_L, real_R
)
with set_forward_context(attn_metadata,
runner.vllm_config, runner.vllm_config,
num_tokens=num_input_tokens, num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp, num_tokens_across_dp=num_tokens_across_dp,
...@@ -303,33 +351,11 @@ def tbo_split_and_execute_model( ...@@ -303,33 +351,11 @@ def tbo_split_and_execute_model(
num_tokens_across_dp, num_tokens_across_dp,
input_ids, input_ids,
positions, positions,
intermediate_tensors, (intermediate_tensors_left, intermediate_tensors_right),
inputs_embeds) inputs_embeds)
runner.maybe_wait_for_kv_save() runner.maybe_wait_for_kv_save()
finished_sending, finished_recving = ( finished_sending, finished_recving = (
runner.get_finished_kv_transfers(scheduler_output)) runner.get_finished_kv_transfers(scheduler_output))
#finished_sending, finished_recving = None, None
else:
# Run the decoder.
# Use persistent buffers for CUDA graphs.
envs.VLLM_ENABLE_TBO = False
with set_forward_context(attn_metadata,
runner.vllm_config,
num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp,
skip_cuda_graphs=skip_cuda_graphs):
runner.maybe_setup_kv_connector(scheduler_output)
model_output = runner.model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
runner.maybe_wait_for_kv_save()
finished_sending, finished_recving = (
runner.get_finished_kv_transfers(scheduler_output))
envs.VLLM_ENABLE_TBO = True
return model_output, finished_sending, finished_recving return model_output, finished_sending, finished_recving
\ No newline at end of file
...@@ -17,10 +17,12 @@ logger = init_logger(__name__) ...@@ -17,10 +17,12 @@ logger = init_logger(__name__)
tbo_step_stream = None tbo_step_stream = None
all_reduce_stream = None all_reduce_stream = None
class TwoBatchOverlap(): PERSIST_THREADS = os.getenv('VLLM_TBO_PERSIST_THREADS', '1') not in ('0','false','False','no','NO','')
STOP = object()
class TwoBatchOverlap:
def __init__(self): def __init__(self):
global tbo_step_stream global tbo_step_stream, all_reduce_stream
global all_reduce_stream
self.model_input_left_queue = queue.Queue() self.model_input_left_queue = queue.Queue()
self.model_input_right_queue = queue.Queue() self.model_input_right_queue = queue.Queue()
self.states_left_queue = queue.Queue() self.states_left_queue = queue.Queue()
...@@ -29,12 +31,14 @@ class TwoBatchOverlap(): ...@@ -29,12 +31,14 @@ class TwoBatchOverlap():
self.right_thread = None self.right_thread = None
self.left_tid = 0 self.left_tid = 0
self.right_tid = 0 self.right_tid = 0
self._stop_evt = threading.Event()
self._threads_started = False
self.sem_left = threading.Semaphore(0) self.sem_left = threading.Semaphore(0)
self.sem_right = threading.Semaphore(0) self.sem_right = threading.Semaphore(0)
self.left_first = False self.left_first = False
self.tbo_running = False self.tbo_running = False
self.tbo_in_capture = False self.tbo_in_capture = False
if tbo_step_stream == None: if tbo_step_stream is None:
tbo_step_stream = torch.cuda.Stream() tbo_step_stream = torch.cuda.Stream()
all_reduce_stream = torch.cuda.Stream() all_reduce_stream = torch.cuda.Stream()
self.step_event = torch.cuda.Event(enable_timing=False) self.step_event = torch.cuda.Event(enable_timing=False)
...@@ -44,60 +48,85 @@ class TwoBatchOverlap(): ...@@ -44,60 +48,85 @@ class TwoBatchOverlap():
self.event_right_t2c = torch.cuda.Event(enable_timing=False) self.event_right_t2c = torch.cuda.Event(enable_timing=False)
def init_tbo_thread(self): def init_tbo_thread(self):
self.model_input_left_queue.empty() if self._threads_started and PERSIST_THREADS:
self.model_input_right_queue.empty() return
self.left_thread = threading.Thread(target=self.thread_two_batch_overlap, args=(self.model_input_left_queue,)) if self.left_thread is None or not self.left_thread.is_alive():
self.left_thread.start() self.left_thread = threading.Thread(target=self.thread_two_batch_overlap,
self.right_thread = threading.Thread(target=self.thread_two_batch_overlap, args=(self.model_input_right_queue,)) args=(self.model_input_left_queue,), daemon=True)
self.right_thread.start() self.left_thread.start()
if get_tp_group().rank == 0: if self.right_thread is None or not self.right_thread.is_alive():
logger.info('tbo:two batch overlap start') self.right_thread = threading.Thread(target=self.thread_two_batch_overlap,
args=(self.model_input_right_queue,), daemon=True)
def finish_thread(self): self.right_thread.start()
self.left_thread.join() self._threads_started = True
self.left_thread = None
self.right_thread.join() def shutdown(self, timeout=5.0):
self.right_thread = None self._stop_evt.set()
try:
self.model_input_left_queue.put(STOP)
self.model_input_right_queue.put(STOP)
except Exception:
pass
if self.left_thread is not None:
self.left_thread.join(timeout=timeout)
self.left_thread = None
if self.right_thread is not None:
self.right_thread.join(timeout=timeout)
self.right_thread = None
@torch.inference_mode() @torch.inference_mode()
def thread_two_batch_overlap(self, queue): def thread_two_batch_overlap(self, q):
is_left_thread = False is_left_thread = False
tid = threading.get_ident() tid = threading.get_ident()
if queue == self.model_input_left_queue: if q is self.model_input_left_queue:
self.left_tid = tid self.left_tid = tid
is_left_thread = True is_left_thread = True
init_tbo_forward_context(True, self.left_tid) init_tbo_forward_context(True, self.left_tid)
else: else:
self.right_tid = tid self.right_tid = tid
init_tbo_forward_context(False, self.right_tid) init_tbo_forward_context(False, self.right_tid)
with torch.cuda.stream(tbo_step_stream):
queue.get() while not self._stop_evt.is_set():
self.tbo_thread_synchronize(tid) item = q.get()
if is_left_thread: if item is STOP:
attn_metadata = self.attn_metadata_left break
num_input_tokens = self.num_input_tokens_left
input_ids = self.input_ids_left with torch.cuda.stream(tbo_step_stream):
positions = self.positions_left self.tbo_thread_synchronize(tid)
else:
attn_metadata = self.attn_metadata_right if is_left_thread:
num_input_tokens = self.num_input_tokens_right attn_metadata = self.attn_metadata_left
input_ids = self.input_ids_right num_input_tokens = self.num_input_tokens_left
positions = self.positions_right input_ids = self.input_ids_left
positions = self.positions_left
model_output = None else:
# Run the decoder. attn_metadata = self.attn_metadata_right
# Use persistent buffers for CUDA graphs. num_input_tokens = self.num_input_tokens_right
with set_forward_context(attn_metadata, input_ids = self.input_ids_right
self.model_runner.vllm_config, positions = self.positions_right
num_tokens=num_input_tokens,
num_tokens_across_dp=self.num_tokens_across_dp, # Select per-thread tensors (left/right) with backward-compatible fallback
skip_cuda_graphs=True): if is_left_thread:
model_output = self.model_runner.model( intermediate_tensors = getattr(self, 'intermediate_tensors_left', None)
input_ids=input_ids, else:
positions=positions, intermediate_tensors = getattr(self, 'intermediate_tensors_right', None)
intermediate_tensors=self.intermediate_tensors, if intermediate_tensors is None:
inputs_embeds=self.inputs_embeds, intermediate_tensors = getattr(self, 'intermediate_tensors_left', None)
)
with set_forward_context(attn_metadata,
self.model_runner.vllm_config,
num_tokens=num_input_tokens,
num_tokens_across_dp=self.num_tokens_across_dp,
skip_cuda_graphs=True,
):
model_output = self.model_runner.model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=self.inputs_embeds,
)
if is_left_thread: if is_left_thread:
self.sem_right.release() self.sem_right.release()
self.states_left_queue.put(model_output) self.states_left_queue.put(model_output)
...@@ -117,18 +146,19 @@ class TwoBatchOverlap(): ...@@ -117,18 +146,19 @@ class TwoBatchOverlap():
return self.event_right_c2t, self.event_right_t2c return self.event_right_c2t, self.event_right_t2c
def set_model_input(self, def set_model_input(self,
model_runner, model_runner,
attn_metadata_left, attn_metadata_left,
attn_metadata_right, attn_metadata_right,
num_input_tokens_left, num_input_tokens_left,
num_input_tokens_right, num_input_tokens_right,
input_ids_left, input_ids_left,
input_ids_right, input_ids_right,
positions_left, positions_left,
positions_right, positions_right,
num_tokens_across_dp, num_tokens_across_dp,
intermediate_tensors, intermediate_tensors,
inputs_embeds): inputs_embeds,
):
self.model_runner = model_runner self.model_runner = model_runner
self.attn_metadata_left = attn_metadata_left self.attn_metadata_left = attn_metadata_left
self.attn_metadata_right = attn_metadata_right self.attn_metadata_right = attn_metadata_right
...@@ -139,26 +169,34 @@ class TwoBatchOverlap(): ...@@ -139,26 +169,34 @@ class TwoBatchOverlap():
self.positions_left = positions_left self.positions_left = positions_left
self.positions_right = positions_right self.positions_right = positions_right
self.num_tokens_across_dp = num_tokens_across_dp self.num_tokens_across_dp = num_tokens_across_dp
self.intermediate_tensors = intermediate_tensors
self.inputs_embeds = inputs_embeds self.inputs_embeds = inputs_embeds
if isinstance(intermediate_tensors, tuple):
self.intermediate_tensors_left, self.intermediate_tensors_right = intermediate_tensors
else:
self.intermediate_tensors_left = intermediate_tensors
self.intermediate_tensors_right = None
self.model_input_left_queue.put(None) self.model_input_left_queue.put(None)
self.model_input_right_queue.put(None) self.model_input_right_queue.put(None)
def get_model_output(self): def get_model_output(self):
states_left = self.states_left_queue.get() states_left = self.states_left_queue.get()
states_right = self.states_right_queue.get() states_right = self.states_right_queue.get()
return states_left, states_right return states_left, states_right
tbo_obj_v1 = None tbo_obj_v1 = None
def is_enable_tbo_v1(): def is_enable_tbo_v1():
global tbo_obj_v1 global tbo_obj_v1
return tbo_obj_v1 != None return tbo_obj_v1 is not None
def init_two_batch_overlap(): def init_two_batch_overlap():
global tbo_obj_v1 global tbo_obj_v1
if tbo_obj_v1 == None: if tbo_obj_v1 is None:
tbo_obj_v1 = TwoBatchOverlap() tbo_obj_v1 = TwoBatchOverlap()
tbo_obj_v1.init_tbo_thread() tbo_obj_v1.init_tbo_thread()
...@@ -171,7 +209,7 @@ def tbo_maybe_save_kv_layer_to_connector(layer_name, kv_cache): ...@@ -171,7 +209,7 @@ def tbo_maybe_save_kv_layer_to_connector(layer_name, kv_cache):
maybe_save_kv_layer_to_connector(layer_name, kv_cache) maybe_save_kv_layer_to_connector(layer_name, kv_cache)
def tbo_all_reduce_v1(obj): def tbo_all_reduce_v1(obj):
if envs.VLLM_ENABLE_TBO and tbo_obj_v1 != None and tbo_obj_v1.tbo_running: if envs.VLLM_ENABLE_TBO and tbo_obj_v1 is not None and tbo_obj_v1.tbo_running:
tid = threading.get_ident() tid = threading.get_ident()
if tid == tbo_obj_v1.left_tid: if tid == tbo_obj_v1.left_tid:
event_c2t, event_t2c = tbo_obj_v1.event_left_c2t, tbo_obj_v1.event_left_t2c event_c2t, event_t2c = tbo_obj_v1.event_left_c2t, tbo_obj_v1.event_left_t2c
...@@ -185,7 +223,7 @@ def tbo_all_reduce_v1(obj): ...@@ -185,7 +223,7 @@ def tbo_all_reduce_v1(obj):
tbo_obj_v1.tbo_thread_synchronize(tid) tbo_obj_v1.tbo_thread_synchronize(tid)
tbo_step_stream.wait_event(event_t2c) tbo_step_stream.wait_event(event_t2c)
return output return output
return tensor_model_parallel_all_reduce(obj) return tensor_model_parallel_all_reduce(obj)
def merge_model_output(states_left, states_right): def merge_model_output(states_left, states_right):
if isinstance(states_left, IntermediateTensors): if isinstance(states_left, IntermediateTensors):
...@@ -199,45 +237,53 @@ def merge_model_output(states_left, states_right): ...@@ -199,45 +237,53 @@ def merge_model_output(states_left, states_right):
def tbo_model_executable_v1( def tbo_model_executable_v1(
model_runner, model_runner,
attn_metadata_left, attn_metadata_left,
attn_metadata_right, attn_metadata_right,
num_input_tokens_left, num_input_tokens_left,
num_input_tokens_right, num_input_tokens_right,
num_tokens_across_dp, num_tokens_across_dp,
input_ids, input_ids,
positions, positions,
intermediate_tensors, intermediate_tensors,
inputs_embeds inputs_embeds,
): ):
init_two_batch_overlap() init_two_batch_overlap()
tbo_obj_v1.tbo_running = True tbo_obj_v1.tbo_running = True
tbo_obj_v1.left_first = True tbo_obj_v1.left_first = True
tbo_obj_v1.step_event.record() tbo_obj_v1.step_event.record()
current_stream = torch.cuda.current_stream() current_stream = torch.cuda.current_stream()
num_total_tokens = num_input_tokens_left + num_input_tokens_right
with torch.cuda.stream(tbo_step_stream): with torch.cuda.stream(tbo_step_stream):
tbo_step_stream.wait_event(tbo_obj_v1.step_event) tbo_step_stream.wait_event(tbo_obj_v1.step_event)
tokens_split = [num_input_tokens_left, num_input_tokens_right] tokens_split = [num_input_tokens_left, num_input_tokens_right]
input_ids_left, input_ids_right = torch.split(input_ids, tokens_split, dim=0) input_ids_left, input_ids_right = torch.split(input_ids[:num_total_tokens], tokens_split, dim=0)
positions_left, positions_right = torch.split(positions, tokens_split, dim=0) positions_left, positions_right = torch.split(positions[:num_total_tokens], tokens_split, dim=0)
tbo_obj_v1.set_model_input(model_runner, tbo_obj_v1.set_model_input(model_runner,
attn_metadata_left, attn_metadata_left,
attn_metadata_right, attn_metadata_right,
num_input_tokens_left, num_input_tokens_left,
num_input_tokens_right, num_input_tokens_right,
input_ids_left, input_ids_left,
input_ids_right, input_ids_right,
positions_left, positions_left,
positions_right, positions_right,
num_tokens_across_dp, num_tokens_across_dp,
intermediate_tensors, intermediate_tensors,
inputs_embeds) inputs_embeds,
)
model_output_left, model_output_right = tbo_obj_v1.get_model_output() model_output_left, model_output_right = tbo_obj_v1.get_model_output()
hidden_or_intermediate_states = merge_model_output(model_output_left, model_output_right) hidden_or_intermediate_states = merge_model_output(model_output_left, model_output_right)
tbo_obj_v1.tbo_running = False tbo_obj_v1.tbo_running = False
tbo_obj_v1.step_event.record() tbo_obj_v1.step_event.record()
tbo_obj_v1.finish_thread()
current_stream.wait_event(tbo_obj_v1.step_event) current_stream.wait_event(tbo_obj_v1.step_event)
return hidden_or_intermediate_states return hidden_or_intermediate_states
\ No newline at end of file
def finalize_two_batch_overlap():
global tbo_obj_v1
if tbo_obj_v1 is not None:
try:
tbo_obj_v1.shutdown()
finally:
tbo_obj_v1 = None
...@@ -1374,8 +1374,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1374,8 +1374,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# If attention doesn't support CUDA Graphs for this batch, but we # If attention doesn't support CUDA Graphs for this batch, but we
# compiled with full CUDA graphs, we have to skip them entirely. # compiled with full CUDA graphs, we have to skip them entirely.
skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs
if envs.VLLM_ENABLE_TBO and scheduler_output.total_num_scheduled_tokens >= envs.VLLM_TBO_MIN_TOKENS:
if envs.VLLM_ENABLE_TBO and (not self.use_cuda_graph or skip_cuda_graphs):
model_output, finished_sending, finished_recving = \ model_output, finished_sending, finished_recving = \
tbo_split_and_execute_model(self, attn_metadata, num_input_tokens, tbo_split_and_execute_model(self, attn_metadata, num_input_tokens,
num_tokens_across_dp, input_ids, positions, num_tokens_across_dp, input_ids, positions,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment