Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
6cb76b96
"lib/bindings/python/vscode:/vscode.git/clone" did not exist on "182d3b5dc7b2836724a8560ed92cc88ba41fc250"
Unverified
Commit
6cb76b96
authored
Feb 04, 2026
by
Qi Wang
Committed by
GitHub
Feb 04, 2026
Browse files
feat: introduce cuda_ipc for TRT-LLM PrefillHandler (#5773)
parent
039d35ff
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
179 additions
and
1 deletion
+179
-1
components/src/dynamo/trtllm/multimodal/__init__.py
components/src/dynamo/trtllm/multimodal/__init__.py
+5
-1
components/src/dynamo/trtllm/multimodal/cuda_ipc.py
components/src/dynamo/trtllm/multimodal/cuda_ipc.py
+62
-0
components/src/dynamo/trtllm/tests/multimodal/test_trtllm_cuda_ipc.py
...rc/dynamo/trtllm/tests/multimodal/test_trtllm_cuda_ipc.py
+112
-0
No files found.
components/src/dynamo/trtllm/multimodal/__init__.py
View file @
6cb76b96
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from
.cuda_ipc
import
extract_embeddings_from_handles
from
.hasher
import
MultimodalHasher
# Public API of the multimodal package. A single assignment replaces the
# earlier back-to-back pair, whose first value was immediately overwritten.
__all__ = [
    "MultimodalHasher",
    "extract_embeddings_from_handles",
]
components/src/dynamo/trtllm/multimodal/cuda_ipc.py
0 → 100644
View file @
6cb76b96
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import
asyncio
import
logging
from
typing
import
Any
,
Dict
,
List
import
torch
from
tensorrt_llm._torch.shared_tensor
import
SharedTensorContainer
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
async def extract_embeddings_from_handles(
    handles: List[Dict[str, Any]],
) -> List[torch.Tensor]:
    """
    Rebuild embedding tensors from CUDA IPC handles and return CPU copies.

    The blocking extraction is delegated to a worker thread through
    ``asyncio.to_thread`` so the event loop is not blocked during
    GPU→CPU transfers.

    WARNING: Do not reuse the given `handles` outside this function --
    https://github.com/pytorch/pytorch/issues/149187
    As of Jan 2026, it is safer to pair one producer with exactly one
    consumer so that the ref counter returns to 0, allowing the Encode
    Process to release GPU memory properly.

    Args:
        handles: List of CUDA IPC handle dictionaries from encoder response

    Returns:
        List of embedding tensors on CPU.

    Raises:
        ValueError: If a handle is missing required fields.
        RuntimeError: If CUDA IPC reconstruction fails.
    """
    # TODO(DIS-1398): experiment with
    # - pinned memory DMA
    # - parallelizing GPU->CPU transfers across multiple threads
    # - a combination of both (i.e. `cpu(non_blocking=True)`)
    cpu_embeddings = await asyncio.to_thread(_extract_embeddings_sync, handles)
    return cpu_embeddings
def _extract_embeddings_sync(handles: List[Dict[str, Any]]) -> List[torch.Tensor]:
    """Synchronously extract all embeddings from CUDA IPC handles.

    Each handle dictionary is rebuilt with ``SharedTensorContainer.from_dict``,
    its local view is obtained, and that view is copied to the CPU.

    Args:
        handles: Handle dictionaries, processed in order.

    Returns:
        One CPU tensor per handle, in input order. An empty ``handles``
        list yields an empty list.

    Raises:
        ValueError: If rebuilding or reading a handle raises ``KeyError``
            (treated as a missing field). The ``KeyError`` is chained as
            ``__cause__``.
        RuntimeError: If any other exception occurs during extraction. The
            original exception is chained as ``__cause__``.
    """
    tensors: List[torch.Tensor] = []
    for i, handle_dict in enumerate(handles):
        # Only the calls that can fail live inside the ``try``; bookkeeping
        # and debug logging happen once the tensor is safely on the CPU.
        try:
            container = SharedTensorContainer.from_dict(handle_dict)
            tensor = container.get_local_view().cpu()
        except KeyError as e:
            # Chain explicitly so the original KeyError is the stated cause.
            raise ValueError(f"Invalid handle {i} - missing field: {e}") from e
        except Exception as e:
            logger.error("Failed to extract embedding %s: %s", i, e)
            raise RuntimeError(f"Failed to extract embedding {i}: {e}") from e

        tensors.append(tensor)
        # Lazy %-formatting: shape/dtype are only rendered when DEBUG is on.
        logger.debug(
            "Extracted embedding %s: shape=%s, dtype=%s",
            i,
            tensor.shape,
            tensor.dtype,
        )
    return tensors
components/src/dynamo/trtllm/tests/multimodal/test_trtllm_cuda_ipc.py
0 → 100644
View file @
6cb76b96
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Unit tests for CUDA IPC embedding extraction utilities."""
import
asyncio
import
multiprocessing
as
mp
from
multiprocessing.synchronize
import
Event
as
EventType
from
typing
import
Any
,
Callable
import
pytest
import
torch
from
tensorrt_llm._torch.shared_tensor.shared_tensor
import
(
SharedTensorContainer
,
_SharedTensorRebuildMethodRegistry
,
)
from
dynamo.trtllm.multimodal.cuda_ipc
import
extract_embeddings_from_handles
# Markers applied to every test in this module.
pytestmark = [
    pytest.mark.pre_merge,
    pytest.mark.unit,
    pytest.mark.trtllm,
    pytest.mark.gpu_1,
]
def _create_tensor_on_gpu() -> torch.Tensor:
    """Build the deterministic (100, 2048) float16 test tensor on the CUDA device."""
    rows, cols = 100, 2048
    # NOTE(review): rows * cols exceeds the largest finite float16 value
    # (65504) -- confirm that the resulting high-end values are acceptable
    # for the equality check performed by the test.
    flat = torch.arange(rows * cols, dtype=torch.float16, device="cuda")
    return flat.reshape(rows, cols)
def producer_process(
    create_tensor: Callable[[], torch.Tensor],
    handle_queue: mp.Queue,
    done_event: EventType,
):
    """Producer side of the test: build a tensor and publish its CUDA IPC handle.

    Args:
        create_tensor: Zero-argument factory returning the tensor to share.
        handle_queue: Queue on which the serialized handle dict is placed.
        done_event: Event this process blocks on before returning.

    Raises:
        Exception: Any failure is printed and then re-raised unchanged.
    """
    try:
        gpu_tensor = create_tensor()

        # Wrap the tensor for CUDA IPC and serialize the handle. Both the
        # tensor and the container stay bound to locals while we wait below.
        shared = SharedTensorContainer.from_tensor(gpu_tensor)
        ipc_handle = shared.dump_to_dict()
        handle_queue.put(ipc_handle)

        # Stay alive until the consumer signals that it is finished.
        done_event.wait()
    except Exception as err:
        print(f"Producer error: {err}")
        raise
def consumer_process(
    handle_queue: mp.Queue, result_queue: mp.Queue, done_event: EventType
):
    """Consumer side of the test: receive a handle and extract its embedding.

    Args:
        handle_queue: Queue from which one handle dict is read (10 s timeout).
        result_queue: Queue on which the first extracted tensor is placed.
        done_event: Event that is always set on exit to release the producer.

    Raises:
        Exception: Any failure is printed and then re-raised unchanged.
    """
    try:
        # The rebuild-method registry is initialized before any handle is
        # reconstructed in this process.
        _SharedTensorRebuildMethodRegistry.initialize()

        ipc_handle = handle_queue.get(timeout=10)

        # The extractor is async and takes a list of handles; drive it to
        # completion on a fresh event loop.
        embeddings = asyncio.run(extract_embeddings_from_handles([ipc_handle]))

        first_embedding = embeddings[0]
        result_queue.put(first_embedding)
    except Exception as err:
        print(f"Consumer error: {err}")
        raise
    finally:
        # Signal the producer to exit on success and on failure alike.
        done_event.set()
class TestExtractEmbeddingsFromHandles:
    """Tests for extract_embeddings_from_handles function."""

    def test_extracts_all_embeddings(self):
        """Test that embeddings are extracted successfully from GPU via CUDA IPC.

        A producer process shares a GPU tensor through a CUDA IPC handle and a
        consumer process extracts it; the tensor received here must be on the
        CPU and equal to a freshly built reference tensor.

        Child processes are always cleaned up, even when the result never
        arrives or an error occurs while waiting for it.
        """
        ctx = mp.get_context("spawn")
        handle_queue: mp.Queue[Any] = ctx.Queue()
        result_queue: mp.Queue[Any] = ctx.Queue()
        done_event = ctx.Event()

        producer = ctx.Process(
            target=producer_process,
            args=(_create_tensor_on_gpu, handle_queue, done_event),
        )
        consumer = ctx.Process(
            target=consumer_process,
            args=(handle_queue, result_queue, done_event),
        )

        try:
            producer.start()
            consumer.start()

            # Get result tensor
            result = result_queue.get(timeout=30)
        finally:
            # Join the consumer first, as on the happy path. ``pid`` is None
            # for a process that was never started, which cannot be joined.
            if consumer.pid is not None:
                consumer.join(timeout=10)

            # Safety net: the consumer normally sets this itself, but if it
            # died abruptly the producer would otherwise wait forever.
            done_event.set()
            if producer.pid is not None:
                producer.join(timeout=10)

            # Do not leave stragglers behind if a join timed out.
            for proc in (consumer, producer):
                if proc.is_alive():
                    proc.terminate()
                    proc.join(timeout=10)

        # Verify against expected tensor
        expected = _create_tensor_on_gpu().cpu()
        assert result.shape == expected.shape
        assert result.device.type == "cpu"
        assert torch.equal(result, expected)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment