Unverified Commit 49d20346 authored by jackwang2120's avatar jackwang2120 Committed by GitHub
Browse files

[Perf] Reduce H2D pageable memory copies (#38794)


Signed-off-by: jackwang2120's avatarjackcfwang <jackcfwang@tencent.com>
Co-authored-by: default avatarIsotr0py <mozf@mail2.sysu.edu.cn>
parent ef076c1b
...@@ -85,40 +85,42 @@ class TritonAttentionMetadata: ...@@ -85,40 +85,42 @@ class TritonAttentionMetadata:
scheduler_metadata: torch.Tensor | None = None scheduler_metadata: torch.Tensor | None = None
prefix_scheduler_metadata: torch.Tensor | None = None prefix_scheduler_metadata: torch.Tensor | None = None
mm_prefix_range: dict[int, list[tuple[int, int]]] | None = None mm_prefix_range: dict[int, list[tuple[int, int]]] | None = None
mm_prefix_range_tensor: torch.Tensor | None = None
@property @staticmethod
def mm_prefix_range_tensor(self) -> torch.Tensor | None: def compute_mm_prefix_range_tensor(
mm_prefix_range: dict[int, list[tuple[int, int]]] | None,
num_seqs: int,
device: torch.device,
) -> torch.Tensor | None:
"""Convert mm_prefix_range dict to padded tensor for Triton kernel. """Convert mm_prefix_range dict to padded tensor for Triton kernel.
Returns shape: (num_seqs, max_ranges, 2) with 0-padding for empty ranges. Returns shape: (num_seqs, max_ranges, 2) with 0-padding for empty ranges.
Empty ranges have start==end==0, which kernel skips via is_valid check. Empty ranges have start==end==0, which kernel skips via is_valid check.
""" """
# TODO(Isotr0py): Move to model runner's attention metadata if mm_prefix_range is None:
# preparation to avoid duplicate computation.
if self.mm_prefix_range is None:
return None return None
num_seqs = self.seq_lens.shape[0]
device = self.seq_lens.device
# Collect ranges, using [(0,0)] for empty sequences to ensure uniform dims # Collect ranges, using [(0,0)] for empty sequences to ensure uniform dims
range_lists = [ range_lists = [
self.mm_prefix_range.get(i, [(0, 0)]) or [(0, 0)] for i in range(num_seqs) mm_prefix_range.get(i, [(0, 0)]) or [(0, 0)] for i in range(num_seqs)
] ]
# Return None if all ranges are trivial (only (0,0) placeholders) # Return None if all ranges are trivial (only (0,0) placeholders)
if all(r == [(0, 0)] for r in range_lists): if all(r == [(0, 0)] for r in range_lists):
return None return None
# Create 2D tensors with shape (num_ranges, 2) for each sequence # Build on CPU first then move to GPU in a single H2D transfer
range_tensors = [ max_ranges = max(len(r) for r in range_lists)
torch.tensor(r, dtype=torch.int32, device=device).view(-1, 2) # Pad all sequences to the same number of ranges
for r in range_lists padded = []
] for r in range_lists:
padded_r = list(r) + [(0, 0)] * (max_ranges - len(r))
return torch.nested.nested_tensor( padded.append(padded_r)
range_tensors, layout=torch.jagged # Create tensor with efficient H2D transfer
).to_padded_tensor(0) return torch.tensor(padded, dtype=torch.int32, device=device).view(
num_seqs, max_ranges, 2
)
class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMetadata]): class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMetadata]):
......
...@@ -2332,13 +2332,8 @@ class GPUModelRunner( ...@@ -2332,13 +2332,8 @@ class GPUModelRunner(
req_idx = self.input_batch.req_id_to_index[req_id] req_idx = self.input_batch.req_id_to_index[req_id]
req_doc_ranges[req_idx] = image_doc_ranges req_doc_ranges[req_idx] = image_doc_ranges
if isinstance(attn_metadata, list): # Set mm_prefix_range for all attention metadata
for ub_metadata in attn_metadata: self._set_mm_prefix_range_for_metadata(attn_metadata, req_doc_ranges)
for _metadata in ub_metadata.values():
_metadata.mm_prefix_range = req_doc_ranges # type: ignore[attr-defined]
else:
for _metadata in attn_metadata.values():
_metadata.mm_prefix_range = req_doc_ranges # type: ignore[attr-defined]
if spec_decode_common_attn_metadata is not None and ( if spec_decode_common_attn_metadata is not None and (
num_reqs != num_reqs_padded or num_tokens != num_tokens_padded num_reqs != num_reqs_padded or num_tokens != num_tokens_padded
...@@ -6487,6 +6482,46 @@ class GPUModelRunner( ...@@ -6487,6 +6482,46 @@ class GPUModelRunner(
return return
self.reorder_batch_threshold = reduce(min_none_high, reorder_batch_thresholds) # type: ignore[assignment] self.reorder_batch_threshold = reduce(min_none_high, reorder_batch_thresholds) # type: ignore[assignment]
def _set_mm_prefix_range_for_metadata(
self,
attn_metadata: Any,
req_doc_ranges: dict[int, list[tuple[int, int]]],
) -> None:
"""Set mm_prefix_range for all attention metadata objects.
This method handles both list and non-list attention metadata,
computing mm_prefix_range_tensor once and sharing it across all
metadata objects to avoid redundant host-to-device transfers.
"""
from vllm.v1.attention.backends.triton_attn import (
TritonAttentionMetadata,
)
# Get all metadata objects from either list or dict structure
metadata_list = []
if isinstance(attn_metadata, list):
for ub_metadata in attn_metadata:
metadata_list.extend(ub_metadata.values())
else:
metadata_list.extend(attn_metadata.values())
# Set mm_prefix_range for all metadata and compute tensor once
shared_tensor = None
for metadata in metadata_list:
metadata.mm_prefix_range = req_doc_ranges # type: ignore[attr-defined]
# Only compute tensor for TritonAttentionMetadata
if isinstance(metadata, TritonAttentionMetadata):
if shared_tensor is None:
shared_tensor = (
TritonAttentionMetadata.compute_mm_prefix_range_tensor(
req_doc_ranges,
metadata.seq_lens.shape[0], # type: ignore[attr-defined]
metadata.seq_lens.device, # type: ignore[attr-defined]
)
)
metadata.mm_prefix_range_tensor = shared_tensor
def may_reinitialize_input_batch( def may_reinitialize_input_batch(
self, kv_cache_config: KVCacheConfig, kernel_block_sizes: list[int] self, kv_cache_config: KVCacheConfig, kernel_block_sizes: list[int]
) -> None: ) -> None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment