"vllm/vscode:/vscode.git/clone" did not exist on "aba8d6ee006b78149ac4514f460e4038b2d4f607"
Unverified Commit 49d20346 authored by jackwang2120's avatar jackwang2120 Committed by GitHub
Browse files

[Perf] Reduce H2D pageable memory copies (#38794)


Signed-off-by: jackwang2120's avatarjackcfwang <jackcfwang@tencent.com>
Co-authored-by: default avatarIsotr0py <mozf@mail2.sysu.edu.cn>
parent ef076c1b
......@@ -85,40 +85,42 @@ class TritonAttentionMetadata:
scheduler_metadata: torch.Tensor | None = None
prefix_scheduler_metadata: torch.Tensor | None = None
mm_prefix_range: dict[int, list[tuple[int, int]]] | None = None
mm_prefix_range_tensor: torch.Tensor | None = None
@property
def mm_prefix_range_tensor(self) -> torch.Tensor | None:
@staticmethod
def compute_mm_prefix_range_tensor(
mm_prefix_range: dict[int, list[tuple[int, int]]] | None,
num_seqs: int,
device: torch.device,
) -> torch.Tensor | None:
"""Convert mm_prefix_range dict to padded tensor for Triton kernel.
Returns shape: (num_seqs, max_ranges, 2) with 0-padding for empty ranges.
Empty ranges have start==end==0, which kernel skips via is_valid check.
"""
# TODO(Isotr0py): Move to model runner's attention metadata
# preparation to avoid duplicate computation.
if self.mm_prefix_range is None:
if mm_prefix_range is None:
return None
num_seqs = self.seq_lens.shape[0]
device = self.seq_lens.device
# Collect ranges, using [(0,0)] for empty sequences to ensure uniform dims
range_lists = [
self.mm_prefix_range.get(i, [(0, 0)]) or [(0, 0)] for i in range(num_seqs)
mm_prefix_range.get(i, [(0, 0)]) or [(0, 0)] for i in range(num_seqs)
]
# Return None if all ranges are trivial (only (0,0) placeholders)
if all(r == [(0, 0)] for r in range_lists):
return None
# Create 2D tensors with shape (num_ranges, 2) for each sequence
range_tensors = [
torch.tensor(r, dtype=torch.int32, device=device).view(-1, 2)
for r in range_lists
]
return torch.nested.nested_tensor(
range_tensors, layout=torch.jagged
).to_padded_tensor(0)
# Build on CPU first then move to GPU in a single H2D transfer
max_ranges = max(len(r) for r in range_lists)
# Pad all sequences to the same number of ranges
padded = []
for r in range_lists:
padded_r = list(r) + [(0, 0)] * (max_ranges - len(r))
padded.append(padded_r)
# Create tensor with efficient H2D transfer
return torch.tensor(padded, dtype=torch.int32, device=device).view(
num_seqs, max_ranges, 2
)
class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMetadata]):
......
......@@ -2332,13 +2332,8 @@ class GPUModelRunner(
req_idx = self.input_batch.req_id_to_index[req_id]
req_doc_ranges[req_idx] = image_doc_ranges
if isinstance(attn_metadata, list):
for ub_metadata in attn_metadata:
for _metadata in ub_metadata.values():
_metadata.mm_prefix_range = req_doc_ranges # type: ignore[attr-defined]
else:
for _metadata in attn_metadata.values():
_metadata.mm_prefix_range = req_doc_ranges # type: ignore[attr-defined]
# Set mm_prefix_range for all attention metadata
self._set_mm_prefix_range_for_metadata(attn_metadata, req_doc_ranges)
if spec_decode_common_attn_metadata is not None and (
num_reqs != num_reqs_padded or num_tokens != num_tokens_padded
......@@ -6487,6 +6482,46 @@ class GPUModelRunner(
return
self.reorder_batch_threshold = reduce(min_none_high, reorder_batch_thresholds) # type: ignore[assignment]
def _set_mm_prefix_range_for_metadata(
self,
attn_metadata: Any,
req_doc_ranges: dict[int, list[tuple[int, int]]],
) -> None:
"""Set mm_prefix_range for all attention metadata objects.
This method handles both list and non-list attention metadata,
computing mm_prefix_range_tensor once and sharing it across all
metadata objects to avoid redundant host-to-device transfers.
"""
from vllm.v1.attention.backends.triton_attn import (
TritonAttentionMetadata,
)
# Get all metadata objects from either list or dict structure
metadata_list = []
if isinstance(attn_metadata, list):
for ub_metadata in attn_metadata:
metadata_list.extend(ub_metadata.values())
else:
metadata_list.extend(attn_metadata.values())
# Set mm_prefix_range for all metadata and compute tensor once
shared_tensor = None
for metadata in metadata_list:
metadata.mm_prefix_range = req_doc_ranges # type: ignore[attr-defined]
# Only compute tensor for TritonAttentionMetadata
if isinstance(metadata, TritonAttentionMetadata):
if shared_tensor is None:
shared_tensor = (
TritonAttentionMetadata.compute_mm_prefix_range_tensor(
req_doc_ranges,
metadata.seq_lens.shape[0], # type: ignore[attr-defined]
metadata.seq_lens.device, # type: ignore[attr-defined]
)
)
metadata.mm_prefix_range_tensor = shared_tensor
def may_reinitialize_input_batch(
self, kv_cache_config: KVCacheConfig, kernel_block_sizes: list[int]
) -> None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment