Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
6cb76b96
Unverified
Commit
6cb76b96
authored
Feb 04, 2026
by
Qi Wang
Committed by
GitHub
Feb 04, 2026
Browse files
feat: introduce cuda_ipc for TRT-LLM PrefillHandler (#5773)
parent
039d35ff
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
179 additions
and
1 deletion
+179
-1
components/src/dynamo/trtllm/multimodal/__init__.py
components/src/dynamo/trtllm/multimodal/__init__.py
+5
-1
components/src/dynamo/trtllm/multimodal/cuda_ipc.py
components/src/dynamo/trtllm/multimodal/cuda_ipc.py
+62
-0
components/src/dynamo/trtllm/tests/multimodal/test_trtllm_cuda_ipc.py
...rc/dynamo/trtllm/tests/multimodal/test_trtllm_cuda_ipc.py
+112
-0
No files found.
components/src/dynamo/trtllm/multimodal/__init__.py
View file @
6cb76b96
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from
.cuda_ipc
import
extract_embeddings_from_handles
from
.hasher
import
MultimodalHasher
__all__
=
[
"MultimodalHasher"
]
__all__
=
[
"MultimodalHasher"
,
"extract_embeddings_from_handles"
,
]
components/src/dynamo/trtllm/multimodal/cuda_ipc.py
0 → 100644
View file @
6cb76b96
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import
asyncio
import
logging
from
typing
import
Any
,
Dict
,
List
import
torch
from
tensorrt_llm._torch.shared_tensor
import
SharedTensorContainer
logger
=
logging
.
getLogger
(
__name__
)
async
def
extract_embeddings_from_handles
(
handles
:
List
[
Dict
[
str
,
Any
]],
)
->
List
[
torch
.
Tensor
]:
"""
Extract all embedding tensors from CUDA IPC handles and move to CPU.
Runs extraction in a worker thread to avoid blocking the event loop
during GPU→CPU transfers.
WARNING: Do not reuse the given `handles` outside this function --
https://github.com/pytorch/pytorch/issues/149187
As of Jan 2026, it's safer to ensure one producer corresponds to one consumer so that
the ref counter_value return to 0, allowing Encode Process to release GPU memory
properly.
Args:
handles: List of CUDA IPC handle dictionaries from encoder response
Returns:
List of embedding tensors on CPU.
Raises:
ValueError: If a handle is missing required fields.
RuntimeError: If CUDA IPC reconstruction fails.
"""
# TODO(DIS-1398): expeiment
# - pinned memory DMA
# - parallelize GPU->CPU transfers in multiple threads
# - combination fo both (i.e. `cpu(non_blocking=True)`)
return
await
asyncio
.
to_thread
(
_extract_embeddings_sync
,
handles
)
def
_extract_embeddings_sync
(
handles
:
List
[
Dict
[
str
,
Any
]])
->
List
[
torch
.
Tensor
]:
"""Synchronously extract all embeddings from CUDA IPC handles."""
tensors
=
[]
for
i
,
handle_dict
in
enumerate
(
handles
):
try
:
container
=
SharedTensorContainer
.
from_dict
(
handle_dict
)
tensor
=
container
.
get_local_view
().
cpu
()
tensors
.
append
(
tensor
)
logger
.
debug
(
f
"Extracted embedding
{
i
}
: shape=
{
tensor
.
shape
}
, dtype=
{
tensor
.
dtype
}
"
)
except
KeyError
as
e
:
raise
ValueError
(
f
"Invalid handle
{
i
}
- missing field:
{
e
}
"
)
except
Exception
as
e
:
logger
.
error
(
f
"Failed to extract embedding
{
i
}
:
{
e
}
"
)
raise
RuntimeError
(
f
"Failed to extract embedding
{
i
}
:
{
e
}
"
)
return
tensors
components/src/dynamo/trtllm/tests/multimodal/test_trtllm_cuda_ipc.py
0 → 100644
View file @
6cb76b96
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Unit tests for CUDA IPC embedding extraction utilities."""
import
asyncio
import
multiprocessing
as
mp
from
multiprocessing.synchronize
import
Event
as
EventType
from
typing
import
Any
,
Callable
import
pytest
import
torch
from
tensorrt_llm._torch.shared_tensor.shared_tensor
import
(
SharedTensorContainer
,
_SharedTensorRebuildMethodRegistry
,
)
from
dynamo.trtllm.multimodal.cuda_ipc
import
extract_embeddings_from_handles
pytestmark
=
[
pytest
.
mark
.
pre_merge
,
pytest
.
mark
.
unit
,
pytest
.
mark
.
trtllm
,
pytest
.
mark
.
gpu_1
,
]
def
_create_tensor_on_gpu
()
->
torch
.
Tensor
:
"""Create test tensor on GPU."""
return
torch
.
arange
(
100
*
2048
,
dtype
=
torch
.
float16
,
device
=
"cuda"
).
reshape
(
100
,
2048
)
def
producer_process
(
create_tensor
:
Callable
[[],
torch
.
Tensor
],
handle_queue
:
mp
.
Queue
,
done_event
:
EventType
,
):
"""Producer: creates GPU tensor and shares via CUDA IPC."""
try
:
tensor
=
create_tensor
()
# Share via CUDA IPC
container
=
SharedTensorContainer
.
from_tensor
(
tensor
)
handle
=
container
.
dump_to_dict
()
handle_queue
.
put
(
handle
)
# Keep process alive until consumer is done
done_event
.
wait
()
except
Exception
as
e
:
print
(
f
"Producer error:
{
e
}
"
)
raise
def
consumer_process
(
handle_queue
:
mp
.
Queue
,
result_queue
:
mp
.
Queue
,
done_event
:
EventType
):
"""Consumer: receives handle and extracts embedding via CUDA IPC."""
try
:
# Initialize shared tensor rebuild method registry
_SharedTensorRebuildMethodRegistry
.
initialize
()
# Receive handle
handle
=
handle_queue
.
get
(
timeout
=
10
)
# Extract embedding via CUDA IPC - pass list of handles directly (async)
result
=
asyncio
.
run
(
extract_embeddings_from_handles
([
handle
]))
# Send result
result_queue
.
put
(
result
[
0
])
except
Exception
as
e
:
print
(
f
"Consumer error:
{
e
}
"
)
raise
finally
:
# Always signal producer to exit
done_event
.
set
()
class
TestExtractEmbeddingsFromHandles
:
"""Tests for extract_embeddings_from_handles function."""
def
test_extracts_all_embeddings
(
self
):
"""Test that embeddings are extracted successfully from GPU via CUDA IPC."""
ctx
=
mp
.
get_context
(
"spawn"
)
handle_queue
:
mp
.
Queue
[
Any
]
=
ctx
.
Queue
()
result_queue
:
mp
.
Queue
[
Any
]
=
ctx
.
Queue
()
done_event
=
ctx
.
Event
()
# Start processes
producer
=
ctx
.
Process
(
target
=
producer_process
,
args
=
(
_create_tensor_on_gpu
,
handle_queue
,
done_event
),
)
consumer
=
ctx
.
Process
(
target
=
consumer_process
,
args
=
(
handle_queue
,
result_queue
,
done_event
)
)
producer
.
start
()
consumer
.
start
()
# Get result tensor
result
=
result_queue
.
get
(
timeout
=
30
)
consumer
.
join
(
timeout
=
10
)
producer
.
join
(
timeout
=
10
)
# Verify against expected tensor
expected
=
_create_tensor_on_gpu
().
cpu
()
assert
result
.
shape
==
expected
.
shape
assert
result
.
device
.
type
==
"cpu"
assert
torch
.
equal
(
result
,
expected
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment