"lib/bindings/python/vscode:/vscode.git/clone" did not exist on "182d3b5dc7b2836724a8560ed92cc88ba41fc250"
Unverified Commit 6cb76b96 authored by Qi Wang's avatar Qi Wang Committed by GitHub
Browse files

feat: introduce cuda_ipc for TRT-LLM PrefillHandler (#5773)

parent 039d35ff
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from .cuda_ipc import extract_embeddings_from_handles
from .hasher import MultimodalHasher
__all__ = ["MultimodalHasher"]
__all__ = [
"MultimodalHasher",
"extract_embeddings_from_handles",
]
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import asyncio
import logging
from typing import Any, Dict, List
import torch
from tensorrt_llm._torch.shared_tensor import SharedTensorContainer
logger = logging.getLogger(__name__)
async def extract_embeddings_from_handles(
handles: List[Dict[str, Any]],
) -> List[torch.Tensor]:
"""
Extract all embedding tensors from CUDA IPC handles and move to CPU.
Runs extraction in a worker thread to avoid blocking the event loop
during GPU→CPU transfers.
WARNING: Do not reuse the given `handles` outside this function --
https://github.com/pytorch/pytorch/issues/149187
As of Jan 2026, it's safer to ensure one producer corresponds to one consumer so that
the ref counter_value return to 0, allowing Encode Process to release GPU memory
properly.
Args:
handles: List of CUDA IPC handle dictionaries from encoder response
Returns:
List of embedding tensors on CPU.
Raises:
ValueError: If a handle is missing required fields.
RuntimeError: If CUDA IPC reconstruction fails.
"""
# TODO(DIS-1398): expeiment
# - pinned memory DMA
# - parallelize GPU->CPU transfers in multiple threads
# - combination fo both (i.e. `cpu(non_blocking=True)`)
return await asyncio.to_thread(_extract_embeddings_sync, handles)
def _extract_embeddings_sync(handles: List[Dict[str, Any]]) -> List[torch.Tensor]:
"""Synchronously extract all embeddings from CUDA IPC handles."""
tensors = []
for i, handle_dict in enumerate(handles):
try:
container = SharedTensorContainer.from_dict(handle_dict)
tensor = container.get_local_view().cpu()
tensors.append(tensor)
logger.debug(
f"Extracted embedding {i}: shape={tensor.shape}, dtype={tensor.dtype}"
)
except KeyError as e:
raise ValueError(f"Invalid handle {i} - missing field: {e}")
except Exception as e:
logger.error(f"Failed to extract embedding {i}: {e}")
raise RuntimeError(f"Failed to extract embedding {i}: {e}")
return tensors
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Unit tests for CUDA IPC embedding extraction utilities."""
import asyncio
import multiprocessing as mp
from multiprocessing.synchronize import Event as EventType
from typing import Any, Callable
import pytest
import torch
from tensorrt_llm._torch.shared_tensor.shared_tensor import (
SharedTensorContainer,
_SharedTensorRebuildMethodRegistry,
)
from dynamo.trtllm.multimodal.cuda_ipc import extract_embeddings_from_handles
pytestmark = [
pytest.mark.pre_merge,
pytest.mark.unit,
pytest.mark.trtllm,
pytest.mark.gpu_1,
]
def _create_tensor_on_gpu() -> torch.Tensor:
"""Create test tensor on GPU."""
return torch.arange(100 * 2048, dtype=torch.float16, device="cuda").reshape(
100, 2048
)
def producer_process(
create_tensor: Callable[[], torch.Tensor],
handle_queue: mp.Queue,
done_event: EventType,
):
"""Producer: creates GPU tensor and shares via CUDA IPC."""
try:
tensor = create_tensor()
# Share via CUDA IPC
container = SharedTensorContainer.from_tensor(tensor)
handle = container.dump_to_dict()
handle_queue.put(handle)
# Keep process alive until consumer is done
done_event.wait()
except Exception as e:
print(f"Producer error: {e}")
raise
def consumer_process(
handle_queue: mp.Queue, result_queue: mp.Queue, done_event: EventType
):
"""Consumer: receives handle and extracts embedding via CUDA IPC."""
try:
# Initialize shared tensor rebuild method registry
_SharedTensorRebuildMethodRegistry.initialize()
# Receive handle
handle = handle_queue.get(timeout=10)
# Extract embedding via CUDA IPC - pass list of handles directly (async)
result = asyncio.run(extract_embeddings_from_handles([handle]))
# Send result
result_queue.put(result[0])
except Exception as e:
print(f"Consumer error: {e}")
raise
finally:
# Always signal producer to exit
done_event.set()
class TestExtractEmbeddingsFromHandles:
"""Tests for extract_embeddings_from_handles function."""
def test_extracts_all_embeddings(self):
"""Test that embeddings are extracted successfully from GPU via CUDA IPC."""
ctx = mp.get_context("spawn")
handle_queue: mp.Queue[Any] = ctx.Queue()
result_queue: mp.Queue[Any] = ctx.Queue()
done_event = ctx.Event()
# Start processes
producer = ctx.Process(
target=producer_process,
args=(_create_tensor_on_gpu, handle_queue, done_event),
)
consumer = ctx.Process(
target=consumer_process, args=(handle_queue, result_queue, done_event)
)
producer.start()
consumer.start()
# Get result tensor
result = result_queue.get(timeout=30)
consumer.join(timeout=10)
producer.join(timeout=10)
# Verify against expected tensor
expected = _create_tensor_on_gpu().cpu()
assert result.shape == expected.shape
assert result.device.type == "cpu"
assert torch.equal(result, expected)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment