Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
75e774d4
Unverified
Commit
75e774d4
authored
May 27, 2025
by
J Wyman
Committed by
GitHub
May 27, 2025
Browse files
feat: NIXL Based RDMA Support w/ Multimodal Example (#1060)
parent
9acaa8d1
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
1784 additions
and
112 deletions
+1784
-112
.pre-commit-config.yaml
.pre-commit-config.yaml
+2
-2
examples/multimodal/components/decode_worker.py
examples/multimodal/components/decode_worker.py
+88
-29
examples/multimodal/components/encode_worker.py
examples/multimodal/components/encode_worker.py
+99
-7
examples/multimodal/components/prefill_worker.py
examples/multimodal/components/prefill_worker.py
+102
-56
examples/multimodal/components/processor.py
examples/multimodal/components/processor.py
+3
-3
examples/multimodal/configs/agg.yaml
examples/multimodal/configs/agg.yaml
+2
-2
examples/multimodal/configs/disagg.yaml
examples/multimodal/configs/disagg.yaml
+3
-3
examples/multimodal/connect/__init__.py
examples/multimodal/connect/__init__.py
+1471
-0
examples/multimodal/graphs/agg.py
examples/multimodal/graphs/agg.py
+4
-3
examples/multimodal/graphs/disagg.py
examples/multimodal/graphs/disagg.py
+6
-4
examples/multimodal/utils/protocol.py
examples/multimodal/utils/protocol.py
+4
-3
No files found.
.pre-commit-config.yaml
View file @
75e774d4
...
@@ -13,7 +13,7 @@
...
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
exclude
:
^(src/grpc_generated|.*\.patch$)
exclude
:
^(src/grpc_generated|.*\.patch$
|.*/connect/.*\.py
)
repos
:
repos
:
-
repo
:
https://github.com/timothycrosley/isort
-
repo
:
https://github.com/timothycrosley/isort
rev
:
5.12.0
rev
:
5.12.0
...
@@ -82,4 +82,4 @@ repos:
...
@@ -82,4 +82,4 @@ repos:
# NOTE: pyright may be able to find other classes of errors not covered above,
# NOTE: pyright may be able to find other classes of errors not covered above,
# but would require some configuring and venv setup to properly eliminate noise
# but would require some configuring and venv setup to properly eliminate noise
# and give it visiblity into all the local and third_party packages expected.
# and give it visiblity into all the local and third_party packages expected.
\ No newline at end of file
examples/multimodal/components/worker.py
→
examples/multimodal/components/
decode_
worker.py
View file @
75e774d4
...
@@ -19,10 +19,11 @@ import os
...
@@ -19,10 +19,11 @@ import os
import
signal
import
signal
from
typing
import
Optional
from
typing
import
Optional
import
connect
import
torch
import
torch
from
components.disagg_router
import
PyDisaggregatedRouter
from
components.disagg_router
import
PyDisaggregatedRouter
from
components.encode_worker
import
EncodeWorker
from
components.encode_worker
import
Vllm
EncodeWorker
from
components.prefill_worker
import
PrefillWorker
from
components.prefill_worker
import
Vllm
PrefillWorker
from
transformers
import
LlavaForConditionalGeneration
from
transformers
import
LlavaForConditionalGeneration
from
utils.logging
import
check_required_workers
from
utils.logging
import
check_required_workers
from
utils.nixl
import
NixlMetadataStore
from
utils.nixl
import
NixlMetadataStore
...
@@ -53,11 +54,11 @@ logger = logging.getLogger(__name__)
...
@@ -53,11 +54,11 @@ logger = logging.getLogger(__name__)
resources
=
{
"gpu"
:
1
,
"cpu"
:
"10"
,
"memory"
:
"20Gi"
},
resources
=
{
"gpu"
:
1
,
"cpu"
:
"10"
,
"memory"
:
"20Gi"
},
workers
=
1
,
workers
=
1
,
)
)
class
VllmWorker
:
class
Vllm
Decode
Worker
:
# For disaggregated serving, we need to link the prefill worker to the vllm worker
# For disaggregated serving, we need to link the prefill worker to the vllm worker
prefill_worker
=
depends
(
PrefillWorker
)
prefill_worker
=
depends
(
Vllm
PrefillWorker
)
# For aggregated serving, we need to link the encode worker to the vllm worker.
# For aggregated serving, we need to link the encode worker to the vllm worker.
encode_worker
=
depends
(
EncodeWorker
)
encode_worker
=
depends
(
Vllm
EncodeWorker
)
def
__init__
(
self
):
def
__init__
(
self
):
self
.
client
=
None
self
.
client
=
None
...
@@ -141,7 +142,11 @@ class VllmWorker:
...
@@ -141,7 +142,11 @@ class VllmWorker:
vision_tower
.
vision_model
.
embeddings
.
position_embedding
.
num_embeddings
vision_tower
.
vision_model
.
embeddings
.
position_embedding
.
num_embeddings
)
)
else
:
else
:
enc_comp_ns
,
enc_comp_name
=
EncodeWorker
.
dynamo_address
()
# type: ignore
EMBEDDINGS_SHAPE
=
(
1
,
577
,
4096
)
EMBEDDINGS_DTYPE
=
torch
.
float16
EMBEDDINGS_DEVICE
=
"cuda"
enc_comp_ns
,
enc_comp_name
=
VllmEncodeWorker
.
dynamo_address
()
# type: ignore
self
.
encode_worker_client
=
(
self
.
encode_worker_client
=
(
await
runtime
.
namespace
(
enc_comp_ns
)
await
runtime
.
namespace
(
enc_comp_ns
)
.
component
(
enc_comp_name
)
.
component
(
enc_comp_name
)
...
@@ -149,9 +154,22 @@ class VllmWorker:
...
@@ -149,9 +154,22 @@ class VllmWorker:
.
client
()
.
client
()
)
)
self
.
_connector
=
connect
.
Connector
(
runtime
=
runtime
,
namespace
=
enc_comp_ns
)
await
self
.
_connector
.
initialize
()
# Create a longer-lived buffer for receiving the image embeddings.
embeddings
=
torch
.
empty
(
EMBEDDINGS_SHAPE
,
dtype
=
EMBEDDINGS_DTYPE
,
device
=
EMBEDDINGS_DEVICE
)
descriptor
=
connect
.
Descriptor
(
embeddings
)
# Register the descriptor w/ NIXL (this is optional, if not done here the connect subsytem will take care of this automatically).
descriptor
.
register_memory
(
self
.
_connector
)
self
.
_embeddings_descriptor
=
(
embeddings
,
descriptor
)
await
check_required_workers
(
self
.
encode_worker_client
,
self
.
min_workers
)
await
check_required_workers
(
self
.
encode_worker_client
,
self
.
min_workers
)
self
.
disaggregated_router
=
None
self
.
disaggregated_router
=
None
logger
.
info
(
"VllmWorker has been initialized"
)
logger
.
info
(
"Initialization complete."
)
def
shutdown_vllm_engine
(
self
,
signum
,
frame
):
def
shutdown_vllm_engine
(
self
,
signum
,
frame
):
"""Shutdown the background loop"""
"""Shutdown the background loop"""
...
@@ -159,7 +177,7 @@ class VllmWorker:
...
@@ -159,7 +177,7 @@ class VllmWorker:
loop
=
asyncio
.
get_event_loop
()
loop
=
asyncio
.
get_event_loop
()
try
:
try
:
self
.
engine_client
.
close
()
self
.
engine_client
.
close
()
logger
.
info
(
"
VllmWorker s
hutdown complete"
)
logger
.
info
(
"
S
hutdown complete
.
"
)
except
Exception
as
e
:
except
Exception
as
e
:
logger
.
error
(
f
"Error during shutdown:
{
e
}
"
)
logger
.
error
(
f
"Error during shutdown:
{
e
}
"
)
finally
:
finally
:
...
@@ -177,8 +195,18 @@ class VllmWorker:
...
@@ -177,8 +195,18 @@ class VllmWorker:
@
endpoint
()
@
endpoint
()
async
def
generate
(
self
,
request
:
vLLMMultimodalRequest
):
async
def
generate
(
self
,
request
:
vLLMMultimodalRequest
):
image_features
=
None
request_id
=
request
.
request_id
image_url
=
request
.
image_url
logger
.
info
(
f
"Received multimodal request {{ id:
{
request_id
}
, image_url: '
{
image_url
}
' }}."
)
embeddings
=
None
if
self
.
do_remote_prefill
:
if
self
.
do_remote_prefill
:
logger
.
debug
(
f
"Disaggregated: request {{ id:
{
request_id
}
, image_url: '
{
image_url
}
' }}"
" prefill worker will populate the decode model's key-value cache ahead of time;"
" no direct encode worker interaction required."
)
if
self
.
disaggregated_router
is
not
None
:
if
self
.
disaggregated_router
is
not
None
:
async
with
PrefillQueue
.
get_instance
(
async
with
PrefillQueue
.
get_instance
(
nats_server
=
self
.
_prefill_queue_nats_server
,
nats_server
=
self
.
_prefill_queue_nats_server
,
...
@@ -195,21 +223,21 @@ class VllmWorker:
...
@@ -195,21 +223,21 @@ class VllmWorker:
disagg_router_decision
=
True
disagg_router_decision
=
True
if
self
.
do_remote_prefill
and
disagg_router_decision
:
if
self
.
do_remote_prefill
and
disagg_router_decision
:
logger
.
debug
(
f
"Prefilling remotely for request {{ id:
{
request_id
}
, image_url: '
{
image_url
}
' }} with length
{
len
(
request
.
engine_prompt
[
'prompt_token_ids'
])
}
"
)
remote_prefill_params
=
RemotePrefillParams
(
remote_prefill_params
=
RemotePrefillParams
(
is_remote_prefill
=
True
,
is_remote_prefill
=
True
,
remote_prefill_request_callback
=
self
.
get_remote_prefill_request_callback
(),
remote_prefill_request_callback
=
self
.
get_remote_prefill_request_callback
(),
# Pass the image url as part of the RemotePrefillParams, which will be passed to the prefill worker via RemotePrefillRequest
# Pass the image url as part of the RemotePrefillParams, which will be passed to the prefill worker via RemotePrefillRequest
multimodal_data_source
=
{
multimodal_data_source
=
{
"image_url"
:
request
.
image_url
,
"image_url"
:
image_url
,
},
},
)
)
logger
.
info
(
f
"Prefilling remotely for request
{
request
.
request_id
}
with length
{
len
(
request
.
engine_prompt
[
'prompt_token_ids'
])
}
"
)
else
:
else
:
remote_prefill_params
=
None
remote_prefill_params
=
None
logger
.
info
(
logger
.
debug
(
f
"Prefilling locally for request
{
request
.
request_id
}
with length
{
len
(
request
.
engine_prompt
[
'prompt_token_ids'
])
}
"
f
"Prefilling locally for request {
{ id:
{
request_id
}
, image_url: '
{
image_url
}
' }
} with length
{
len
(
request
.
engine_prompt
[
'prompt_token_ids'
])
}
"
)
)
# The decode worker will pre-allocate the memory based on the prompt token length for the prefill worker to transfer the kv cache.
# The decode worker will pre-allocate the memory based on the prompt token length for the prefill worker to transfer the kv cache.
...
@@ -231,33 +259,61 @@ class VllmWorker:
...
@@ -231,33 +259,61 @@ class VllmWorker:
)
)
else
:
else
:
# For aggregated serving, the vllm worker will directly send the encode request to the encode worker.
logger
.
debug
(
encode_generator
=
await
self
.
encode_worker_client
.
round_robin
(
f
"Aggregated: request {{ id:
{
request_id
}
, image_url: '
{
image_url
}
' }}"
EncodeRequest
(
" no prefill worker available, embeddings directly from encode worker."
image_url
=
request
.
image_url
,
).
model_dump_json
()
)
)
device
=
torch
.
device
(
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
# Extract the pre-allocated, reusable image embeddings tensor and its descriptor.
async
for
encode_response
in
encode_generator
:
# Doing this avoids unnessesary memory de/registration with NIXL.
encode_output
=
EncodeResponse
.
model_validate_json
(
embeddings
,
descriptor
=
self
.
_embeddings_descriptor
encode_response
.
data
()
with
self
.
_connector
.
create_writable
(
descriptor
)
as
writable
:
# Extract serialized metadata about the operation from the writable operation,
# and use it to create a new EncodeRequest.
encode_request
=
EncodeRequest
(
request_id
=
request_id
,
image_url
=
image_url
,
serialized_request
=
writable
.
to_serialized
(),
)
)
image_features
=
torch
.
tensor
(
logger
.
debug
(
f
"Encode request:
{
encode_request
.
model_dump_json
()
}
"
)
encode_output
.
image_features
,
device
=
device
,
dtype
=
torch
.
float16
encode_generator
=
await
self
.
encode_worker_client
.
round_robin
(
encode_request
.
model_dump_json
()
)
)
async
for
encode_response
in
encode_generator
:
encode_output
=
EncodeResponse
.
model_validate_json
(
encode_response
.
data
()
)
logger
.
info
(
f
"Received response: {{ id:
{
encode_output
.
request_id
}
}}"
)
# Wait for the write operation to complete.
# This will block until the write operation is complete.
# This await should be a no-op since we've already received a response from the encode worker.
await
writable
.
wait_for_completion
()
# At this point, the `embeddings` tensor is filled with the image embeddings from the remote encode worker.
remote_prefill_params
=
None
remote_prefill_params
=
None
logger
.
info
(
logger
.
info
(
f
"Prefilling locally for request
{
request
.
request_id
}
with length
{
len
(
request
.
engine_prompt
[
'prompt_token_ids'
])
}
"
f
"Prefilling locally for request {
{ id:
{
request_id
}
, image_url: '
{
image_url
}
' }
} with length
{
len
(
request
.
engine_prompt
[
'prompt_token_ids'
])
}
"
)
)
prompt_ids
=
request
.
engine_prompt
[
"prompt_token_ids"
]
prompt_ids
=
request
.
engine_prompt
[
"prompt_token_ids"
]
# rust HTTP requires Delta streaming
# rust HTTP requires Delta streaming
request
.
sampling_params
.
output_kind
=
RequestOutputKind
.
DELTA
request
.
sampling_params
.
output_kind
=
RequestOutputKind
.
DELTA
if
image_features
is
not
None
:
# When using aggregated serving, the encode worker will have provided the key-value cache updates via the prefill worker.
multi_modal_data
=
{
"image"
:
image_features
}
# When using disaggregated serving, the encode worker will have provided the key-value cache updates via the encode worker.
if
embeddings
is
not
None
:
logger
.
debug
(
"Aggregated: embedding data from encode worker provided via multi-modal data to decode model."
)
multi_modal_data
=
{
"image"
:
embeddings
}
else
:
else
:
logger
.
debug
(
"Disaggregated: no embedding data required as prefill will have provided key-value cache updates via encode worker."
)
multi_modal_data
=
None
multi_modal_data
=
None
async
for
response
in
self
.
engine_client
.
generate
(
async
for
response
in
self
.
engine_client
.
generate
(
...
@@ -269,6 +325,9 @@ class VllmWorker:
...
@@ -269,6 +325,9 @@ class VllmWorker:
request_id
=
request
.
request_id
,
request_id
=
request
.
request_id
,
remote_prefill_params
=
remote_prefill_params
,
remote_prefill_params
=
remote_prefill_params
,
):
):
logger
.
debug
(
f
"Yeilding response {{ id:
{
response
.
request_id
}
, prompt: '
{
response
.
prompt
}
' }}"
)
yield
MyRequestOutput
(
yield
MyRequestOutput
(
request_id
=
response
.
request_id
,
request_id
=
response
.
request_id
,
prompt
=
response
.
prompt
,
prompt
=
response
.
prompt
,
...
...
examples/multimodal/components/encode_worker.py
View file @
75e774d4
...
@@ -15,8 +15,10 @@
...
@@ -15,8 +15,10 @@
import
logging
import
logging
from
io
import
BytesIO
from
io
import
BytesIO
from
queue
import
Queue
from
typing
import
AsyncIterator
from
typing
import
AsyncIterator
import
connect
import
requests
import
requests
import
torch
import
torch
from
PIL
import
Image
from
PIL
import
Image
...
@@ -24,10 +26,25 @@ from transformers import AutoImageProcessor, LlavaForConditionalGeneration
...
@@ -24,10 +26,25 @@ from transformers import AutoImageProcessor, LlavaForConditionalGeneration
from
utils.protocol
import
EncodeRequest
,
EncodeResponse
from
utils.protocol
import
EncodeRequest
,
EncodeResponse
from
utils.vllm
import
parse_vllm_args
from
utils.vllm
import
parse_vllm_args
from
dynamo.sdk
import
endpoint
,
service
from
dynamo.sdk
import
async_on_start
,
endpoint
,
service
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
try
:
import
cupy
as
array_module
if
not
array_module
.
cuda
.
is_available
():
raise
ImportError
(
"CUDA is not available."
)
DEVICE
=
"cuda"
logger
.
info
(
"Using cupy for array operations (GPU mode)."
)
except
ImportError
as
e
:
logger
.
warning
(
f
"Failed to import cupy, falling back to numpy:
{
e
}
."
)
import
numpy
as
array_module
DEVICE
=
"cpu"
CACHE_SIZE_MAXIMUM
=
8
@
service
(
@
service
(
dynamo
=
{
dynamo
=
{
...
@@ -36,7 +53,7 @@ logger = logging.getLogger(__name__)
...
@@ -36,7 +53,7 @@ logger = logging.getLogger(__name__)
resources
=
{
"gpu"
:
1
,
"cpu"
:
"10"
,
"memory"
:
"20Gi"
},
resources
=
{
"gpu"
:
1
,
"cpu"
:
"10"
,
"memory"
:
"20Gi"
},
workers
=
1
,
workers
=
1
,
)
)
class
EncodeWorker
:
class
Vllm
EncodeWorker
:
def
__init__
(
self
)
->
None
:
def
__init__
(
self
)
->
None
:
class_name
=
self
.
__class__
.
__name__
class_name
=
self
.
__class__
.
__name__
self
.
engine_args
=
parse_vllm_args
(
class_name
,
""
)
self
.
engine_args
=
parse_vllm_args
(
class_name
,
""
)
...
@@ -50,9 +67,50 @@ class EncodeWorker:
...
@@ -50,9 +67,50 @@ class EncodeWorker:
self
.
MODEL_ID
,
device_map
=
"auto"
,
torch_dtype
=
torch
.
float16
self
.
MODEL_ID
,
device_map
=
"auto"
,
torch_dtype
=
torch
.
float16
).
eval
()
).
eval
()
self
.
_image_cache
:
dict
[
str
,
Image
.
Image
]
=
{}
self
.
_cache_queue
:
Queue
[
str
]
=
Queue
(
maxsize
=
CACHE_SIZE_MAXIMUM
)
@
endpoint
()
@
endpoint
()
async
def
encode
(
self
,
request
:
EncodeRequest
)
->
AsyncIterator
[
EncodeResponse
]:
async
def
encode
(
self
,
request
:
EncodeRequest
)
->
AsyncIterator
[
EncodeResponse
]:
image
=
self
.
open_image
(
request
.
image_url
)
logger
.
debug
(
f
"Received encode request: {{ id:
{
request
.
request_id
}
, image_url: '
{
request
.
image_url
}
' }}."
)
request_id
=
request
.
request_id
image_url
=
request
.
image_url
.
lower
()
# The following steps encode the requested image and provided useful embeddings.
# 1. Open the image from the provided URL.
# 2. Process the image using the image processor.
# 3. Run the image through the vision model's vision tower.
# 4. Run the results of the vision tower through the multi-modal projector.
# 5. Create a descriptor for the embeddings.
# 6. Create a write operation using the serialized request and the descriptor.
# 7. Await for the write operation to complete.
# 8. Yield the encode response.
# Either retrieve the image from the cache or download it and then cache it.
if
request
.
image_url
in
self
.
_image_cache
:
image
=
self
.
_image_cache
[
image_url
]
logger
.
debug
(
f
"Image found in cache for request: {{ id:
{
request_id
}
, image_url: '
{
image_url
}
' }}."
)
else
:
image
=
self
.
open_image
(
image_url
)
logger
.
debug
(
f
"Downloading/opening image for request: {{ id:
{
request_id
}
, image_url: '
{
image_url
}
' }}."
)
# Cache the image for future use, and evict the oldest image if the cache is full.
if
self
.
_cache_queue
.
full
():
oldest_image_url
=
self
.
_cache_queue
.
get
()
del
self
.
_image_cache
[
oldest_image_url
]
self
.
_image_cache
[
request
.
image_url
]
=
image
self
.
_cache_queue
.
put
(
request
.
image_url
)
logger
.
debug
(
f
"Processing image for request: {{ id:
{
request_id
}
, image_url: '
{
image_url
}
' }}"
)
image_embeds
=
self
.
image_processor
(
images
=
image
,
return_tensors
=
"pt"
)
image_embeds
=
self
.
image_processor
(
images
=
image
,
return_tensors
=
"pt"
)
with
torch
.
no_grad
():
with
torch
.
no_grad
():
...
@@ -60,22 +118,56 @@ class EncodeWorker:
...
@@ -60,22 +118,56 @@ class EncodeWorker:
vision_outputs
=
self
.
vision_model
.
vision_tower
(
vision_outputs
=
self
.
vision_model
.
vision_tower
(
image_embeds
[
"pixel_values"
].
to
(
self
.
vision_model
.
device
)
image_embeds
[
"pixel_values"
].
to
(
self
.
vision_model
.
device
)
)
)
logger
.
debug
(
"Vision model completed."
)
embeddings
=
vision_outputs
.
last_hidden_state
embeddings
=
self
.
vision_model
.
multi_modal_projector
(
embeddings
)
logger
.
debug
(
f
"Embeddings: {{ shape:
{
embeddings
.
shape
}
, dtype:
{
embeddings
.
dtype
}
, device:
{
embeddings
.
device
}
, ptr:
{
embeddings
.
data_ptr
()
}
, elements: {{ count:
{
embeddings
.
numel
()
}
, size:
{
embeddings
.
element_size
()
}
}} }}."
)
if
request
.
serialized_request
is
None
:
logger
.
error
(
f
"Request serialized_request is None for request: {{ id:
{
request_id
}
, image_url: '
{
image_url
}
' }}."
)
# Create a descriptor for the embeddings, this will register the memory with the connector (and the NIXL runtime).
descriptor
=
connect
.
Descriptor
(
embeddings
)
# Create a write operation using the serialized request and the descriptor.
# This will begin the RDMA transfer of the embeddings to the remote worker.
write_op
=
await
self
.
_connector
.
begin_write
(
descriptor
,
request
.
serialized_request
,
)
# Await for the write operation to complete.
# This will block until the data has been written to the remote worker or an error occurs.
await
write_op
.
wait_for_completion
()
image_features
=
vision_outputs
.
last_hidden_state
image_features
=
self
.
vision_model
.
multi_modal_projector
(
image_features
)
yield
EncodeResponse
(
yield
EncodeResponse
(
image_features
=
image_features
.
tolist
()
request_id
=
request
.
request_id
,
).
model_dump_json
()
).
model_dump_json
()
@
async_on_start
()
async
def
on_start
(
self
):
logger
.
info
(
"Startup started."
)
# Create and initialize a dynamo connector for this worker.
# We'll needs this to move data between this worker and remote workers efficiently.
self
.
_connector
=
connect
.
Connector
()
await
self
.
_connector
.
initialize
()
logger
.
info
(
"Startup completed."
)
def
open_image
(
self
,
image
:
str
)
->
Image
.
Image
:
def
open_image
(
self
,
image
:
str
)
->
Image
.
Image
:
# TODO: Have a seperate field for url and non url - and avoid auto detection
# TODO: Have a seperate field for url and non url - and avoid auto detection
try
:
try
:
# Acquire the image and convert it to the format (RGB) the image processor model expects.
if
image
.
startswith
(
"http"
)
or
image
.
startswith
(
"https"
):
if
image
.
startswith
(
"http"
)
or
image
.
startswith
(
"https"
):
response
=
requests
.
get
(
image
)
response
=
requests
.
get
(
image
)
image_data
=
Image
.
open
(
BytesIO
(
response
.
content
)).
convert
(
"RGB"
)
image_data
=
Image
.
open
(
BytesIO
(
response
.
content
)).
convert
(
"RGB"
)
else
:
else
:
image_data
=
Image
.
open
(
image
).
convert
(
"RGB"
)
image_data
=
Image
.
open
(
image
).
convert
(
"RGB"
)
return
image_data
except
Exception
as
e
:
except
Exception
as
e
:
logger
.
error
(
f
"Error opening image:
{
e
}
"
)
logger
.
error
(
f
"Error opening image:
{
e
}
"
)
raise
e
raise
e
return
image_data
examples/multimodal/components/prefill_worker.py
View file @
75e774d4
...
@@ -20,8 +20,9 @@ import os
...
@@ -20,8 +20,9 @@ import os
import
signal
import
signal
import
sys
import
sys
import
connect
import
torch
import
torch
from
components.encode_worker
import
EncodeWorker
from
components.encode_worker
import
Vllm
EncodeWorker
from
pydantic
import
BaseModel
from
pydantic
import
BaseModel
from
utils.logging
import
check_required_workers
from
utils.logging
import
check_required_workers
from
utils.nixl
import
NixlMetadataStore
from
utils.nixl
import
NixlMetadataStore
...
@@ -38,6 +39,11 @@ from dynamo.sdk import async_on_start, depends, dynamo_context, endpoint, servic
...
@@ -38,6 +39,11 @@ from dynamo.sdk import async_on_start, depends, dynamo_context, endpoint, servic
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
# Constants for the shape and dtype of the embeddings tensor.
EMBEDDINGS_SHAPE
=
(
1
,
577
,
4096
)
EMBEDDINGS_DTYPE
=
torch
.
float16
EMBEDDINGS_DEVICE
=
"cuda"
class
RequestType
(
BaseModel
):
class
RequestType
(
BaseModel
):
text
:
str
text
:
str
...
@@ -50,8 +56,8 @@ class RequestType(BaseModel):
...
@@ -50,8 +56,8 @@ class RequestType(BaseModel):
resources
=
{
"gpu"
:
1
,
"cpu"
:
"10"
,
"memory"
:
"20Gi"
},
resources
=
{
"gpu"
:
1
,
"cpu"
:
"10"
,
"memory"
:
"20Gi"
},
workers
=
1
,
workers
=
1
,
)
)
class
PrefillWorker
:
class
Vllm
PrefillWorker
:
encode_worker
=
depends
(
EncodeWorker
)
encode_worker
=
depends
(
Vllm
EncodeWorker
)
def
__init__
(
self
):
def
__init__
(
self
):
class_name
=
self
.
__class__
.
__name__
class_name
=
self
.
__class__
.
__name__
...
@@ -95,7 +101,7 @@ class PrefillWorker:
...
@@ -95,7 +101,7 @@ class PrefillWorker:
raise
RuntimeError
(
"Failed to initialize engine client"
)
raise
RuntimeError
(
"Failed to initialize engine client"
)
runtime
=
dynamo_context
[
"runtime"
]
runtime
=
dynamo_context
[
"runtime"
]
enc_comp_ns
,
enc_comp_name
=
EncodeWorker
.
dynamo_address
()
# type: ignore
enc_comp_ns
,
enc_comp_name
=
Vllm
EncodeWorker
.
dynamo_address
()
# type: ignore
self
.
encode_worker_client
=
(
self
.
encode_worker_client
=
(
await
runtime
.
namespace
(
enc_comp_ns
)
await
runtime
.
namespace
(
enc_comp_ns
)
.
component
(
enc_comp_name
)
.
component
(
enc_comp_name
)
...
@@ -103,6 +109,20 @@ class PrefillWorker:
...
@@ -103,6 +109,20 @@ class PrefillWorker:
.
client
()
.
client
()
)
)
self
.
_connector
=
connect
.
Connector
(
runtime
=
runtime
,
namespace
=
enc_comp_ns
)
await
self
.
_connector
.
initialize
()
# Create a longer-lived buffer for receiving the image embeddings.
embeddings
=
torch
.
empty
(
EMBEDDINGS_SHAPE
,
dtype
=
EMBEDDINGS_DTYPE
,
device
=
EMBEDDINGS_DEVICE
,
)
descriptor
=
connect
.
Descriptor
(
embeddings
)
# Register the descriptor w/ NIXL (this is optional, if not done here the connect subsytem will take care of this automatically).
descriptor
.
register_memory
(
self
.
_connector
)
self
.
_embeddings_descriptor
=
(
embeddings
,
descriptor
)
await
check_required_workers
(
self
.
encode_worker_client
,
self
.
min_workers
)
await
check_required_workers
(
self
.
encode_worker_client
,
self
.
min_workers
)
metadata
=
self
.
engine_client
.
nixl_metadata
metadata
=
self
.
engine_client
.
nixl_metadata
...
@@ -119,19 +139,19 @@ class PrefillWorker:
...
@@ -119,19 +139,19 @@ class PrefillWorker:
sys
.
exit
(
1
)
sys
.
exit
(
1
)
task
.
add_done_callback
(
prefill_queue_handler_cb
)
task
.
add_done_callback
(
prefill_queue_handler_cb
)
logger
.
info
(
"
PrefillWorker initialized
"
)
logger
.
info
(
"
Initialization complete.
"
)
def
shutdown_vllm_engine
(
self
,
signum
,
frame
):
def
shutdown_vllm_engine
(
self
,
signum
,
frame
):
"""Shutdown the background loop"""
"""Shutdown the background loop"""
logger
.
info
(
f
"
Receiv
ed signal
{
signum
}
, shutting down
"
)
logger
.
info
(
f
"
Shutdown start
ed
,
signal
{
signum
}
received.
"
)
loop
=
asyncio
.
get_event_loop
()
loop
=
asyncio
.
get_event_loop
()
try
:
try
:
self
.
engine_client
.
close
()
self
.
engine_client
.
close
()
logger
.
info
(
"PrefillWorker shutdown complete"
)
except
Exception
as
e
:
except
Exception
as
e
:
logger
.
error
(
f
"Error during shutdown:
{
e
}
"
)
logger
.
error
(
f
"Error during shutdown:
{
e
}
"
)
finally
:
finally
:
loop
.
stop
()
loop
.
stop
()
logger
.
info
(
"Shutdown complete."
)
async
def
prefill_queue_handler
(
self
):
async
def
prefill_queue_handler
(
self
):
logger
.
info
(
"Prefill queue handler entered"
)
logger
.
info
(
"Prefill queue handler entered"
)
...
@@ -166,62 +186,88 @@ class PrefillWorker:
...
@@ -166,62 +186,88 @@ class PrefillWorker:
if
request
.
multimodal_data_source
[
"image_url"
]
is
None
:
if
request
.
multimodal_data_source
[
"image_url"
]
is
None
:
raise
ValueError
(
"No image url provided for prefill request"
)
raise
ValueError
(
"No image url provided for prefill request"
)
encode_generator
=
await
self
.
encode_worker_client
.
round_robin
(
request_id
=
request
.
request_id
EncodeRequest
(
engine_id
=
request
.
engine_id
image_url
=
request
.
multimodal_data_source
[
"image_url"
],
image_url
=
request
.
multimodal_data_source
[
"image_url"
]
).
model_dump_json
()
logger
.
info
(
f
"Received prefill request {{ id:
{
request_id
}
, engine_id:
{
engine_id
}
, image_url: '
{
image_url
}
' }}."
)
)
async
for
encode_response
in
encode_generator
:
encode_output
=
EncodeResponse
.
model_validate_json
(
encode_response
.
data
())
# Extract the pre-allocated, reusable image embeddings tensor and its descriptor.
image_features
=
torch
.
tensor
(
# Doing this avoids unnessesary memory de/registration with NIXL.
encode_output
.
image_features
,
device
=
"cpu"
,
dtype
=
torch
.
float16
embeddings
,
descriptor
=
self
.
_embeddings_descriptor
# Create a new writable operation from the descriptor.
with
self
.
_connector
.
create_writable
(
descriptor
)
as
writable
:
# Extract serialized metadata about the operation from the writable operation,
# and use it to create a new EncodeRequest.
encode_generator
=
await
self
.
encode_worker_client
.
round_robin
(
EncodeRequest
(
request_id
=
request_id
,
image_url
=
image_url
,
serialized_request
=
writable
.
to_serialized
(),
).
model_dump_json
()
)
)
async
for
encode_response
in
encode_generator
:
encode_output
=
EncodeResponse
.
model_validate_json
(
encode_response
.
data
(),
)
logger
.
debug
(
f
"Received response: {{ id:
{
encode_output
.
request_id
}
}}."
)
sampling_params
=
request
.
sampling_params
# Wait for the write operation to complete.
sampling_params
.
max_tokens
=
1
# This will block until the write operation is complete.
sampling_params
.
min_tokens
=
1
# This await should be a no-op since we've already received a response from the encode worker.
await
writable
.
wait_for_completion
()
# At this point, the `embeddings` tensor is filled with the image embeddings from the remote encode worker.
remote_prefill_params
=
RemotePrefillParams
(
sampling_params
=
request
.
sampling_params
is_remote_decode
=
True
,
sampling_params
.
max_tokens
=
1
decode_block_ids
=
request
.
block_ids
,
sampling_params
.
min_tokens
=
1
decode_engine_id
=
request
.
engine_id
,
decode_computed_block_ids
=
request
.
computed_block_ids
,
)
# TODO check if metadata has changed
remote_prefill_params
=
RemotePrefillParams
(
# and reload - currently only loading once
is_remote_decode
=
True
,
if
request
.
engine_id
not
in
self
.
_loaded_metadata
:
decode_block_ids
=
request
.
block_ids
,
remote_metadata
=
await
self
.
_metadata_store
.
get
(
request
.
engine_id
)
decode_engine_id
=
engine_id
,
await
self
.
engine_client
.
add_remote_nixl_metadata
(
remote_metadata
)
decode_computed_block_ids
=
request
.
computed_block_ids
,
logger
.
info
(
)
f
"Loaded nixl metadata from engine
{
request
.
engine_id
}
into "
f
"engine
{
self
.
engine_client
.
nixl_metadata
.
engine_id
}
"
# TODO check if metadata has changed
# and reload - currently only loading once
if
engine_id
not
in
self
.
_loaded_metadata
:
remote_metadata
=
await
self
.
_metadata_store
.
get
(
request
.
engine_id
)
await
self
.
engine_client
.
add_remote_nixl_metadata
(
remote_metadata
)
logger
.
info
(
f
"Loaded nixl metadata from engine
{
engine_id
}
into "
f
"engine
{
self
.
engine_client
.
nixl_metadata
.
engine_id
}
"
)
self
.
_loaded_metadata
.
add
(
engine_id
)
# To make sure the decode worker can pre-allocate the memory with the correct size for the prefill worker to transfer the kv cache,
# some placeholder dummy tokens were inserted based on the embedding size in the worker.py.
# The structure of the prompt is "\nUSER: <image> <dummy_tokens>\n<user_prompt>\nASSISTANT:", need to remove the dummy tokens after the image token.
IMAGE_TOKEN_ID
=
32000
embedding_size
=
embeddings
.
shape
[
1
]
padding_size
=
embedding_size
-
1
image_token_index
=
request
.
prompt_token_ids
.
index
(
IMAGE_TOKEN_ID
)
dummy_token_index
=
image_token_index
+
1
prompt_token_ids
=
(
request
.
prompt_token_ids
[:
dummy_token_index
]
+
request
.
prompt_token_ids
[
dummy_token_index
+
padding_size
:]
)
)
self
.
_loaded_metadata
.
add
(
request
.
engine_id
)
# To make sure the decode worker can pre-allocate the memory with the correct size for the prefill worker to transfer the kv cache,
# some placeholder dummy tokens were inserted based on the embedding size in the worker.py.
# The structure of the prompt is "\nUSER: <image> <dummy_tokens>\n<user_prompt>\nASSISTANT:", need to remove the dummy tokens after the image token.
IMAGE_TOKEN_ID
=
32000
embedding_size
=
image_features
.
shape
[
1
]
padding_size
=
embedding_size
-
1
image_token_index
=
request
.
prompt_token_ids
.
index
(
IMAGE_TOKEN_ID
)
dummy_token_index
=
image_token_index
+
1
prompt_token_ids
=
(
request
.
prompt_token_ids
[:
dummy_token_index
]
+
request
.
prompt_token_ids
[
dummy_token_index
+
padding_size
:]
)
async
for
_
in
self
.
engine_client
.
generate
(
async
for
_
in
self
.
engine_client
.
generate
(
request_id
=
request
.
request
_id
,
request_id
=
request_id
,
prompt
=
TokensPrompt
(
prompt
=
TokensPrompt
(
prompt_token_ids
=
prompt_token_ids
,
prompt_token_ids
=
prompt_token_ids
,
multi_modal_data
=
{
"image"
:
image_feature
s
},
multi_modal_data
=
{
"image"
:
embedding
s
},
),
),
sampling_params
=
sampling_params
,
sampling_params
=
sampling_params
,
remote_prefill_params
=
remote_prefill_params
,
remote_prefill_params
=
remote_prefill_params
,
):
):
yield
yield
@
endpoint
()
@
endpoint
()
async
def
mock
(
self
,
req
:
RequestType
):
async
def
mock
(
self
,
req
:
RequestType
):
...
...
examples/multimodal/components/processor.py
View file @
75e774d4
...
@@ -19,7 +19,7 @@ import uuid
...
@@ -19,7 +19,7 @@ import uuid
from
enum
import
Enum
from
enum
import
Enum
from
typing
import
AsyncIterator
,
Tuple
,
Union
from
typing
import
AsyncIterator
,
Tuple
,
Union
from
components.worker
import
VllmWorker
from
components.
decode_
worker
import
Vllm
Decode
Worker
from
transformers
import
AutoTokenizer
from
transformers
import
AutoTokenizer
from
utils.chat_processor
import
ChatProcessor
,
CompletionsProcessor
,
ProcessMixIn
from
utils.chat_processor
import
ChatProcessor
,
CompletionsProcessor
,
ProcessMixIn
from
utils.logging
import
check_required_workers
from
utils.logging
import
check_required_workers
...
@@ -53,7 +53,7 @@ class Processor(ProcessMixIn):
...
@@ -53,7 +53,7 @@ class Processor(ProcessMixIn):
vLLM pre and post processing
vLLM pre and post processing
"""
"""
worker
=
depends
(
VllmWorker
)
worker
=
depends
(
Vllm
Decode
Worker
)
def
__init__
(
self
):
def
__init__
(
self
):
class_name
=
self
.
__class__
.
__name__
class_name
=
self
.
__class__
.
__name__
...
@@ -83,7 +83,7 @@ class Processor(ProcessMixIn):
...
@@ -83,7 +83,7 @@ class Processor(ProcessMixIn):
@
async_on_start
@
async_on_start
async
def
async_init
(
self
):
async
def
async_init
(
self
):
runtime
=
dynamo_context
[
"runtime"
]
runtime
=
dynamo_context
[
"runtime"
]
comp_ns
,
comp_name
=
VllmWorker
.
dynamo_address
()
# type: ignore
comp_ns
,
comp_name
=
Vllm
Decode
Worker
.
dynamo_address
()
# type: ignore
self
.
worker_client
=
(
self
.
worker_client
=
(
await
runtime
.
namespace
(
comp_ns
)
await
runtime
.
namespace
(
comp_ns
)
.
component
(
comp_name
)
.
component
(
comp_name
)
...
...
examples/multimodal/configs/agg.yaml
View file @
75e774d4
...
@@ -21,7 +21,7 @@ Processor:
...
@@ -21,7 +21,7 @@ Processor:
router
:
round-robin
router
:
round-robin
common-configs
:
[
model
,
block-size
,
max-model-len
]
common-configs
:
[
model
,
block-size
,
max-model-len
]
VllmWorker
:
Vllm
Decode
Worker
:
enforce-eager
:
true
enforce-eager
:
true
max-num-batched-tokens
:
16384
max-num-batched-tokens
:
16384
enable-prefix-caching
:
true
enable-prefix-caching
:
true
...
@@ -33,7 +33,7 @@ VllmWorker:
...
@@ -33,7 +33,7 @@ VllmWorker:
gpu
:
1
gpu
:
1
common-configs
:
[
model
,
block-size
,
max-model-len
]
common-configs
:
[
model
,
block-size
,
max-model-len
]
EncodeWorker
:
Vllm
EncodeWorker
:
tensor-parallel-size
:
1
tensor-parallel-size
:
1
router
:
random
router
:
random
ServiceArgs
:
ServiceArgs
:
...
...
examples/multimodal/configs/disagg.yaml
View file @
75e774d4
...
@@ -22,7 +22,7 @@ Processor:
...
@@ -22,7 +22,7 @@ Processor:
router
:
round-robin
router
:
round-robin
common-configs
:
[
model
,
block-size
]
common-configs
:
[
model
,
block-size
]
VllmWorker
:
Vllm
Decode
Worker
:
remote-prefill
:
true
remote-prefill
:
true
conditional-disagg
:
true
conditional-disagg
:
true
max-local-prefill-length
:
10
max-local-prefill-length
:
10
...
@@ -33,7 +33,7 @@ VllmWorker:
...
@@ -33,7 +33,7 @@ VllmWorker:
gpu
:
1
gpu
:
1
common-configs
:
[
model
,
block-size
,
max-model-len
,
kv-transfer-config
]
common-configs
:
[
model
,
block-size
,
max-model-len
,
kv-transfer-config
]
PrefillWorker
:
Vllm
PrefillWorker
:
max-num-batched-tokens
:
16384
max-num-batched-tokens
:
16384
ServiceArgs
:
ServiceArgs
:
workers
:
1
workers
:
1
...
@@ -41,7 +41,7 @@ PrefillWorker:
...
@@ -41,7 +41,7 @@ PrefillWorker:
gpu
:
1
gpu
:
1
common-configs
:
[
model
,
block-size
,
max-model-len
,
kv-transfer-config
]
common-configs
:
[
model
,
block-size
,
max-model-len
,
kv-transfer-config
]
EncodeWorker
:
Vllm
EncodeWorker
:
tensor-parallel-size
:
1
tensor-parallel-size
:
1
router
:
random
router
:
random
ServiceArgs
:
ServiceArgs
:
...
...
examples/multimodal/connect/__init__.py
0 → 100644
View file @
75e774d4
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from
__future__
import
annotations
import
asyncio
import
logging
import
socket
import
uuid
import
zlib
from
abc
import
ABC
,
abstractmethod
from
enum
import
IntEnum
from
functools
import
cached_property
from
typing
import
Any
,
List
,
Optional
import
nixl._api
as
nixl_api
import
nixl._bindings
as
nixl_bindings
import
torch
from
pydantic
import
BaseModel
,
ConfigDict
,
field_validator
from
dynamo.runtime
import
DistributedRuntime
from
dynamo.sdk
import
dynamo_context
logger
=
logging
.
getLogger
(
__name__
)
try
:
import
cupy
as
array_module
from
cupy_backends.cuda.api.runtime
import
CUDARuntimeError
logger
.
info
(
"Utilizing cupy to enable GPU acceleration."
)
except
ImportError
:
try
:
import
numpy
as
array_module
logger
.
warning
(
"Failed to load cupy for GPU acceleration, utilizing numpy to provide CPU based operations."
)
except
ImportError
as
e
:
raise
ImportError
(
"Numpy or cupy must be installed to use this module."
)
from
e
class
AbstractOperation
(
ABC
):
"""
Abstract base class for awaitable NIXL based RDMA operations.
"""
def
__init__
(
self
,
connector
:
Connector
,
operation_kind
:
OperationKind
,
local_descriptors
:
Descriptor
|
list
[
Descriptor
],
remote_descriptors
:
Optional
[
Descriptor
|
list
[
Descriptor
]],
notification_key
:
Optional
[
str
],
)
->
None
:
if
not
isinstance
(
connector
,
Connector
):
raise
TypeError
(
"Argument `connector` must be `dynamo.connect.Connector`."
)
if
operation_kind
is
not
OperationKind
.
READ
and
operation_kind
is
not
OperationKind
.
WRITE
:
raise
ValueError
(
"Argument `operation_kind` must be either `READ` or `WRITE`."
)
if
not
(
isinstance
(
local_descriptors
,
(
Descriptor
,
list
))
or
(
isinstance
(
local_descriptors
,
list
)
and
all
(
isinstance
(
d
,
Descriptor
)
for
d
in
local_descriptors
))
):
raise
TypeError
(
"Argument `local_descriptors` must be `dynamo.connect.Descriptor` or `list[dynamo.connect.Descriptor]`."
)
if
(
remote_descriptors
is
not
None
and
not
(
isinstance
(
remote_descriptors
,
Descriptor
)
or
(
isinstance
(
remote_descriptors
,
list
)
and
all
(
isinstance
(
d
,
Descriptor
)
for
d
in
remote_descriptors
))
)
):
raise
TypeError
(
"Argument `remote_descriptors` must be dynamo.connect.Descriptor`, `list[dynamo.connect.Descriptor]`, or `None`."
)
if
isinstance
(
local_descriptors
,
list
)
and
len
(
local_descriptors
)
==
0
:
raise
ValueError
(
"Argument `local_descriptors` must not be an empty list."
)
if
(
remote_descriptors
is
not
None
and
isinstance
(
remote_descriptors
,
list
)
and
len
(
remote_descriptors
)
==
0
):
raise
ValueError
(
"Argument `remote_descriptors` must not be an empty list."
)
notification_key
=
str
(
uuid
.
uuid4
())
if
notification_key
is
None
else
notification_key
if
not
isinstance
(
notification_key
,
str
):
raise
TypeError
(
"Argument `notification_key` must be `str` or `None`."
)
if
len
(
notification_key
)
==
0
:
raise
ValueError
(
"Argument `notification_key` must not be an empty string."
)
self
.
_notification_key
:
str
=
""
if
notification_key
is
None
else
notification_key
self
.
_connector
:
Connector
=
connector
self
.
_operation_kind
:
OperationKind
=
operation_kind
self
.
_local_descriptors
:
Descriptor
|
list
[
Descriptor
]
=
local_descriptors
self
.
_local_dlist
:
Optional
[
list
[
tuple
[
int
,
int
,
int
]]]
=
None
self
.
_local_memtype
:
DeviceKind
=
DeviceKind
.
UNSPECIFIED
self
.
_remote_descriptors
:
Optional
[
Descriptor
|
list
[
Descriptor
]]
=
None
if
remote_descriptors
is
None
else
remote_descriptors
self
.
_remote_dlist
:
Optional
[
list
[
tuple
[
int
,
int
,
int
]]]
=
None
self
.
_remote_memtype
:
DeviceKind
=
DeviceKind
.
UNSPECIFIED
# Register local descriptors with NIXL.
# Note: Only local descriptors should be registered with NIXL,
if
isinstance
(
local_descriptors
,
list
):
for
d
in
local_descriptors
:
d
.
register_memory
(
self
.
_connector
)
else
:
local_descriptors
.
register_memory
(
self
.
_connector
)
# Record local descriptors.
memtype
,
dtlist
=
self
.
_create_dlist
(
local_descriptors
)
self
.
_local_dlist
=
dtlist
self
.
_local_memtype
=
memtype
# Record remote descriptors when provided.
if
remote_descriptors
is
not
None
:
memtype
,
dtlist
=
self
.
_create_dlist
(
remote_descriptors
)
self
.
_remote_dlist
=
dtlist
self
.
_remote_memtype
=
memtype
def
__del__
(
self
)
->
None
:
self
.
_release
()
def
__enter__
(
self
)
->
AbstractOperation
:
return
self
def
__exit__
(
self
,
exc_type
:
Any
,
exc_value
:
Any
,
traceback
:
Any
)
->
None
:
self
.
_release
()
def
_release
(
self
)
->
None
:
"""
Private method to release resources. Only to be called by `self`.
"""
pass
@
property
def
connector
(
self
)
->
Connector
:
"""
Gets the local associated with this operation.
"""
return
self
.
_connector
@
property
def
operation_kind
(
self
)
->
OperationKind
:
"""
Gets the kind of operation.
"""
return
self
.
_operation_kind
@
abstractmethod
async
def
wait_for_completion
(
self
)
->
None
:
"""
Blocks the caller asynchronously until the operation has completed.
"""
raise
NotImplementedError
(
"Abstract method not implemented by derived class."
)
# Private Methods
def
_create_dlist
(
self
,
descriptors
:
Descriptor
|
list
[
Descriptor
],
)
->
tuple
[
DeviceKind
,
list
[
tuple
[
int
,
int
,
int
]]]:
"""
Helper function to create a list of tuples (ptr, size, device) from descriptors.
"""
dlist
:
list
[
tuple
[
int
,
int
,
int
]]
=
[]
memtype
:
DeviceKind
=
DeviceKind
.
UNSPECIFIED
if
isinstance
(
descriptors
,
list
):
memtype
=
descriptors
[
0
].
device
.
kind
for
desc
in
descriptors
:
if
memtype
!=
desc
.
device
.
kind
:
raise
ValueError
(
"All local descriptors must have the same memory type."
)
dlist
.
append
((
desc
.
ptr
,
desc
.
size
,
desc
.
device
.
id
))
else
:
memtype
=
descriptors
.
device
.
kind
dlist
.
append
((
descriptors
.
ptr
,
descriptors
.
size
,
descriptors
.
device
.
id
))
return
(
memtype
,
dlist
)
class
ActiveOperation
(
AbstractOperation
):
"""
Abstract class for active operations that initiates a NIXL based RDMA transfer based `SerializedRequest`
provided by the remote worker's corresponding `PassiveOperation`.
"""
def
__init__
(
self
,
remote
:
Remote
,
operation_kind
:
OperationKind
,
local_descriptors
:
Descriptor
|
list
[
Descriptor
],
remote_descriptors
:
Descriptor
|
list
[
Descriptor
],
notification_key
:
str
,
)
->
None
:
if
not
isinstance
(
remote
,
Remote
)
or
remote
.
_connector
is
None
:
raise
TypeError
(
"Argument `remote` must be valid `dynamo.connect.RemoteAgent`."
)
if
not
isinstance
(
operation_kind
,
OperationKind
):
raise
TypeError
(
"Argument `operation_kind` must `dynamo.connect.OperationKind`."
)
if
operation_kind
is
not
OperationKind
.
READ
and
operation_kind
is
not
OperationKind
.
WRITE
:
raise
ValueError
(
"Argument `operation_kind` must be either `READ` or `WRITE`."
)
if
not
(
isinstance
(
local_descriptors
,
Descriptor
)
or
(
isinstance
(
local_descriptors
,
list
)
and
all
(
isinstance
(
d
,
Descriptor
)
for
d
in
local_descriptors
))
):
raise
TypeError
(
"Argument `local_descriptors` must be `dynamo.connect.Descriptor` or `list[dynamo.connect.Descriptor]`."
)
if
not
(
isinstance
(
remote_descriptors
,
Descriptor
)
or
(
isinstance
(
remote_descriptors
,
list
)
and
all
(
isinstance
(
d
,
Descriptor
)
for
d
in
remote_descriptors
))
):
raise
TypeError
(
"Argument `remote_descriptors` must be `dynamo.connect.Descriptor` or `list[dynamo.connect.Descriptor]`."
)
# Unpack single descriptors from lists if they are provided as single descriptors.
if
isinstance
(
local_descriptors
,
list
)
and
len
(
local_descriptors
)
==
1
:
local_descriptors
=
local_descriptors
[
0
]
if
isinstance
(
remote_descriptors
,
list
)
and
len
(
remote_descriptors
)
==
1
:
remote_descriptors
=
remote_descriptors
[
0
]
if
(
isinstance
(
local_descriptors
,
list
)
and
isinstance
(
remote_descriptors
,
list
)
and
len
(
local_descriptors
)
!=
len
(
remote_descriptors
)):
raise
ValueError
(
"When `local_descriptors` and `remote_descriptors` are lists, they must have the same length."
)
elif
isinstance
(
local_descriptors
,
list
)
!=
isinstance
(
remote_descriptors
,
list
):
raise
ValueError
(
"Both `local_descriptors` and `remote_descriptors` must be either lists or single descriptors."
)
if
not
isinstance
(
notification_key
,
str
):
raise
TypeError
(
"Argument `notification_key` must be `str`."
)
if
len
(
notification_key
)
==
0
:
raise
ValueError
(
"Argument `notification_key` must not be an empty string."
)
self
.
_remote
=
remote
self
.
_status
=
OperationStatus
.
UNINTIALIZED
super
().
__init__
(
remote
.
connector
,
operation_kind
,
local_descriptors
,
remote_descriptors
,
notification_key
)
# Quick check to ensure remote descriptors are not None to make static analysis happy.
if
self
.
_local_dlist
is
None
or
self
.
_remote_dlist
is
None
:
raise
RuntimeError
(
"NIXL descriptor list(s) not bound to operation."
)
self
.
_local_xfer_descs
:
Optional
[
nixl_bindings
.
nixlXferDList
]
=
None
self
.
_remote_xfer_descs
:
Optional
[
nixl_bindings
.
nixlXferDList
]
=
None
self
.
_xfer_hndl
:
Optional
[
nixl_api
.
nixl_xfer_handle
]
=
None
self
.
_local_xfer_descs
=
self
.
_connector
.
_nixl
.
get_xfer_descs
(
descs
=
self
.
_local_dlist
,
mem_type
=
str
(
self
.
_local_memtype
),
)
logger
.
debug
(
f
"Created local NIXL xfer descs:
{
self
.
_local_xfer_descs
}
"
)
self
.
_remote_xfer_descs
=
self
.
_connector
.
_nixl
.
get_xfer_descs
(
descs
=
self
.
_remote_dlist
,
mem_type
=
str
(
self
.
_remote_memtype
),
)
logger
.
debug
(
f
"Created remote NIXL xfer descs:
{
self
.
_remote_xfer_descs
}
"
)
self
.
_xfer_hndl
=
self
.
_connector
.
_nixl
.
initialize_xfer
(
operation
=
str
(
operation_kind
),
local_descs
=
self
.
_local_xfer_descs
,
remote_descs
=
self
.
_remote_xfer_descs
,
remote_agent
=
self
.
_remote
.
name
,
notif_msg
=
self
.
_notification_key
.
encode
(
"utf-8"
),
)
logger
.
debug
(
f
"Created NIXL transfer handle:
{
self
.
_xfer_hndl
}
"
)
def
__del__
(
self
)
->
None
:
super
().
__del__
()
self
.
_release
()
def
__enter__
(
self
)
->
ActiveOperation
:
super
().
__enter__
()
return
self
def
__exit__
(
self
,
exc_type
:
Any
,
exc_value
:
Any
,
traceback
:
Any
)
->
None
:
match
self
.
status
:
case
OperationStatus
.
IN_PROGRESS
|
OperationStatus
.
INITIALIZED
:
self
.
_status
=
OperationStatus
.
CANCELLED
self
.
_release
()
def
__repr__
(
self
)
->
str
:
return
str
(
f
"
{
self
.
__class__
.
__name__
}
("
f
"operation_kind=
{
self
.
_operation_kind
}
, "
f
"local_descriptors=
{
self
.
_local_descriptors
}
, "
f
"remote_descriptors=
{
self
.
_remote_descriptors
}
, "
f
"notification_key='
{
self
.
_notification_key
}
', "
f
"remote='
{
self
.
_remote
.
name
}
', "
f
"status='
{
self
.
_status
}
'"
f
")"
)
def
_release
(
self
)
->
None
:
"""
Private method to release resources.
"""
error
:
Optional
[
Exception
]
=
None
if
self
.
_xfer_hndl
is
not
None
:
try
:
logger
.
debug
(
f
"NIXL transfer handle
{
self
.
_xfer_hndl
}
released."
)
self
.
_connector
.
_nixl
.
release_xfer_handle
(
self
.
_xfer_hndl
)
except
Exception
as
e
:
logger
.
error
(
f
"Failed to release resources:
{
e
}
"
)
error
=
e
finally
:
self
.
_xfer_hndl
=
None
try
:
super
().
_release
()
except
Exception
as
e
:
logger
.
error
(
f
"Failed to release WaitableOperation resources:
{
e
}
"
)
if
error
is
not
None
:
e
.
__cause__
=
error
error
=
e
if
error
is
not
None
:
raise
error
def
_cancel_
(
self
)
->
None
:
if
self
.
_xfer_hndl
is
None
:
return
if
self
.
status
==
OperationStatus
.
ERRORED
:
raise
RuntimeError
(
"Operation is errored, unable to cancel the operation."
)
logger
.
info
(
f
"Cancellation requested for operation {{ kind=
{
self
.
_operation_kind
}
, remote='
{
self
.
_remote
.
name
}
', status=
{
self
.
_status
}
}}."
)
# NIXL will cancel the transfer if it is in progress when the handle is released.
self
.
_connector
.
_nixl
.
release_xfer_handle
(
self
.
_xfer_hndl
)
self
.
_status
=
OperationStatus
.
CANCELLED
self
.
_xfer_hndl
=
None
async
def
_wait_for_completion_
(
self
)
->
None
:
# Loop until the operation is no longer in progress (or "initalized"),
# yielding control to the event loop to allow other operations to run.
iteration_count
=
0
while
True
:
if
iteration_count
&
10
==
0
:
logger
.
debug
(
f
"Waiting for operation {{ kind=
{
self
.
_operation_kind
}
, remote='
{
self
.
_remote
.
name
}
', duration=
{
iteration_count
/
10
}
s }}."
)
match
self
.
status
:
# "in progress" or "initialized" means the operation is ongoing.
case
OperationStatus
.
INITIALIZED
:
await
asyncio
.
sleep
(
0.1
)
case
OperationStatus
.
IN_PROGRESS
:
await
asyncio
.
sleep
(
0.1
)
# Any other state indicates completion or error.
case
_
:
return
@
abstractmethod
def
cancel
(
self
)
->
None
:
"""
Cancels the operation.
No affect if the operation has already completed or errored, or has been cancelled.
"""
raise
NotImplementedError
(
"Abstract method not implemented by derived class."
)
@
property
def
remote
(
self
)
->
Remote
:
"""
Gets the remote agent associated with this operation.
"""
return
self
.
_remote
@
property
def
status
(
self
)
->
OperationStatus
:
"""
Gets the status of the operation.
"""
# Early return if the operation is already complete, errored, or cancelled.
match
self
.
_status
:
case
OperationStatus
.
COMPLETE
|
OperationStatus
.
ERRORED
|
OperationStatus
.
CANCELLED
:
return
self
.
_status
if
self
.
_xfer_hndl
is
None
:
raise
RuntimeError
(
"NIXL transfer handle is invalid."
)
old_status
=
self
.
_status
if
self
.
_status
==
OperationStatus
.
UNINTIALIZED
:
state
=
self
.
_connector
.
_nixl
.
transfer
(
self
.
_xfer_hndl
,
self
.
_notification_key
.
encode
(
"utf-8"
))
logger
.
debug
(
f
"NIXL reported transfer state:
{
state
}
"
)
if
state
==
"ERR"
:
self
.
_status
=
OperationStatus
.
ERRORED
elif
state
==
"DONE"
:
self
.
_status
=
OperationStatus
.
COMPLETE
else
:
self
.
_status
=
OperationStatus
.
INITIALIZED
else
:
state
=
self
.
_connector
.
_nixl
.
check_xfer_state
(
self
.
_xfer_hndl
)
logger
.
debug
(
f
"NIXL reported transfer state:
{
state
}
"
)
if
state
==
"ERR"
:
self
.
_status
=
OperationStatus
.
ERRORED
elif
state
==
"DONE"
:
self
.
_status
=
OperationStatus
.
COMPLETE
else
:
self
.
_status
=
OperationStatus
.
IN_PROGRESS
if
self
.
_status
!=
old_status
:
logger
.
debug
(
f
"
{
self
.
__class__
.
__name__
}
{{ remote: '
{
self
.
_remote
.
name
}
' status: '
{
old_status
}
' => '
{
self
.
_status
}
' }}."
)
return
self
.
_status
class
Connector
:
"""
Core class for managing the connection between agents in a distributed environment.
Use this class to create readable and writable operations, or read and write data to remote agents.
"""
def
__init__
(
self
,
namespace
:
Optional
[
str
]
=
None
,
runtime
:
Optional
[
DistributedRuntime
]
=
None
,
worker_id
:
Optional
[
str
]
=
None
,
)
->
None
:
"""
Creates a new Connector instance.
Parameters
----------
namespace : Optional[str], optional
Dynamo namespace of the component, defaults to "dynamo" when `None`.
runtime : Optional[DistributedRuntime], optional
Reference the dynamo runtime used by the compenent, attempts to use the current runtime when `None`.
worker_id : Optional[str], optional
Unique identifier of the worker, defaults to a new UUID when `None`.
Raises
------
TypeError
When `namespace` is provied and not of type 'str'.
TypeError
When `runtime` iis provied and not of type `dynamo.runtime.DistributedRuntime`.
TypeError
When `worker_id` is provied and not of type `uuid.UUID`.
"""
namespace
=
"dynamo"
if
namespace
is
None
else
namespace
if
not
isinstance
(
namespace
,
str
):
raise
TypeError
(
"Argument `namespace` must be `str` or `None`."
)
if
dynamo_context
is
not
None
and
"runtime"
in
dynamo_context
:
runtime
=
dynamo_context
[
"runtime"
]
if
runtime
is
None
else
runtime
if
not
isinstance
(
runtime
,
DistributedRuntime
)
or
runtime
is
None
:
raise
TypeError
(
"Argument `runtime` must be `dynamo.runtime.DistributedRuntime` or `None`."
)
worker_id
=
worker_id
if
worker_id
is
not
None
else
str
(
uuid
.
uuid4
())
if
not
isinstance
(
worker_id
,
str
)
or
len
(
worker_id
)
==
0
:
raise
TypeError
(
"Argument `worker_id` must be a non-empty `str` or `None`."
)
self
.
_worker_id
=
worker_id
self
.
_is_initialized
=
False
self
.
_runtime
=
runtime
self
.
_namespace
=
namespace
self
.
_nixl
=
nixl_api
.
nixl_agent
(
self
.
_worker_id
)
self
.
_hostname
=
socket
.
gethostname
()
self
.
_agent_metadata
:
Optional
[
bytes
]
=
None
logger
.
debug
(
f
"Created
{
self
.
__repr__
()
}
."
)
def
__repr__
(
self
)
->
str
:
return
str
(
f
"
{
self
.
__class__
.
__name__
}
("
f
"worker_id='
{
self
.
_worker_id
}
', "
f
"namespace=
{
self
.
_namespace
}
, "
f
"hostname=
{
self
.
_hostname
}
, "
f
"metadata=<
{
0
if
self
.
_agent_metadata
is
None
else
len
(
self
.
_agent_metadata
)
}
bytes>"
")"
)
def
__str__
(
self
)
->
str
:
return
self
.
_worker_id
@
cached_property
def
is_cuda_available
(
self
)
->
bool
:
# Note: cuda.is_avalailable initializes cuda
# and can't be called when forking subprocesses
# care should be taken to only call it within
# subprocesses or use 'spawn'
try
:
return
array_module
.
cuda
is
not
None
and
array_module
.
cuda
.
is_available
()
except
CUDARuntimeError
:
return
False
@
property
def
metadata
(
self
)
->
bytes
:
"""
Get the metadata of the agent.
"""
return
self
.
_nixl
.
get_agent_metadata
()
@
property
def
name
(
self
)
->
str
|
None
:
"""
Get the name of the agent.
"""
return
self
.
_worker_id
@
property
def
namespace
(
self
)
->
str
:
"""
Get the namespace of the local.
"""
return
self
.
_namespace
@
property
def
runtime
(
self
)
->
DistributedRuntime
:
"""
Get the runtime of the local.
"""
if
self
.
_runtime
is
None
:
raise
RuntimeError
(
"Runtime is not set. This Connector was not initialized with a runtime."
)
return
self
.
_runtime
async
def
begin_read
(
self
,
remote_request
:
SerializedRequest
,
local_descriptors
:
Descriptor
|
list
[
Descriptor
],
)
->
ReadOperation
:
"""
Creates a read operation for fulfilling a remote readable operation.
Parameters
----------
remote_request : SerializedRequest
Serialized request from a remote worker that has created a readable operation.
local_descriptors : Descriptor | list[Descriptor]
Local descriptor(s) to receive data from the remote worker described by `remote_request`.
Returns
-------
ReadOperation
Awaitable read operation that can be used to transfer data from a remote agent.
Raises
------
TypeError
When `remote_request` is not of type `SerializedRequest`.
TypeError
When `local_descriptors` is not of type `dynamo.connect.Descriptor` or `list[dynamo.connect.Descriptor]`.
"""
if
remote_request
is
None
or
not
isinstance
(
remote_request
,
SerializedRequest
):
raise
TypeError
(
"Argument `remote_request` must be `SerializedRequest`."
)
if
not
(
isinstance
(
local_descriptors
,
Descriptor
)
or
(
isinstance
(
local_descriptors
,
list
)
and
all
(
isinstance
(
d
,
Descriptor
)
for
d
in
local_descriptors
))
):
raise
TypeError
(
"Argument `local_descriptors` must be `dynamo.connect.Descriptor` or `list[dynamo.connect.Descriptor]`."
)
if
remote_request
.
operation_kind
!=
OperationKind
.
READ
.
value
:
raise
RuntimeError
(
"Cannot create a `dynamo.connect.ReadOperation` to read from a remote `dynamo.connect.WritableOperation`."
)
if
not
self
.
_is_initialized
:
raise
RuntimeError
(
"Connector not initialized. Call `initialize()` before calling this method."
)
op
=
ReadOperation
(
self
,
remote_request
,
local_descriptors
)
return
op
async
def
begin_write
(
self
,
local_descriptors
:
Descriptor
|
list
[
Descriptor
],
remote_request
:
SerializedRequest
,
)
->
WriteOperation
:
"""
Creates a write operation for transferring data to a remote agent.
Parameters
----------
remote_request : SerializedRequest
Serialized request from a remote worker that has created a readable operation.
local_descriptors : Descriptor | list[Descriptor]
Local descriptors of one or more data objects to be transferred to the remote agent.
"""
if
remote_request
is
None
or
not
isinstance
(
remote_request
,
SerializedRequest
):
raise
TypeError
(
"Argument `remote_request` must be `SerializedRequest`."
)
if
not
(
isinstance
(
local_descriptors
,
Descriptor
)
or
(
isinstance
(
local_descriptors
,
list
)
and
all
(
isinstance
(
d
,
Descriptor
)
for
d
in
local_descriptors
))
):
raise
TypeError
(
"Argument `local_descriptors` must be `Descriptor` or `list[Descriptor]`."
)
if
remote_request
.
operation_kind
!=
OperationKind
.
WRITE
:
raise
RuntimeError
(
"Cannot create a `WriteOperation` to write to a remote `ReadableOperation`."
)
if
not
isinstance
(
remote_request
.
nixl_metadata
,
str
):
raise
TypeError
(
"Argument `remote_request.nixl_metadata` must be `str`."
)
if
not
self
.
_is_initialized
:
raise
RuntimeError
(
"Connector not initialized. Call `initialize()` before calling this method."
)
op
=
WriteOperation
(
self
,
local_descriptors
,
remote_request
)
return
op
def
create_readable
(
self
,
local_descriptors
:
Descriptor
|
list
[
Descriptor
],
)
->
ReadableOperation
:
"""
Creates a readable operation for transferring data from a remote agent.
Returns
-------
ReadableOperation
A readable operation that can be used to transfer data from a remote agent.
"""
if
not
self
.
_is_initialized
:
raise
RuntimeError
(
"Connector not initialized. Call `initialize()` before calling this method."
)
op
=
ReadableOperation
(
self
,
local_descriptors
)
return
op
def
create_writable
(
self
,
local_descriptors
:
Descriptor
|
list
[
Descriptor
],
)
->
WritableOperation
:
"""
Creates a writable operation for transferring data to a remote agent.
Returns
-------
WritableOperation
A writable operation that can be used to transfer data to a remote agent.
"""
if
not
self
.
_is_initialized
:
raise
RuntimeError
(
"Connector not initialized. Call `initialize()` before calling this method."
)
op
=
WritableOperation
(
self
,
local_descriptors
)
return
op
async
def
initialize
(
self
)
->
None
:
# Only initialize the connector once.
if
self
.
_is_initialized
:
return
self
.
_is_initialized
=
True
# This method is a no-op for now, in the future it may be used to initialize the connector.
logger
.
debug
(
f
"Initialized Connector {{ name: '
{
self
.
_worker_id
}
', namespace '
{
self
.
_namespace
}
' }} completed."
)
class
Descriptor
:
"""
Memory descriptor that ensures memory is registered w/ NIXL, used for transferring data between workers.
"""
def
__init__
(
self
,
data
:
torch
.
Tensor
|
tuple
[
array_module
.
ndarray
,
Device
|
str
]
|
bytes
|
tuple
[
int
,
int
,
Device
|
str
,
Any
],
)
->
None
:
"""
Memory descriptor for transferring data between agents.
Parameters
----------
data : torch.Tensor | tuple[ndarray, Device|str] | bytes | tuple[int, int, Device|str, Any]
The data to be transferred.
When `torch.Tensor` is provided, the attributes of the tensor will be used to create the descriptor.
When `tuple[ndarray, Device]` is provided, the tuple must contain:
- `ndarray`: The CuPy or NumPy array to be transferred.
- `Device`: Either a `dynamo.connect.Device` or a string representing the device type (e.g., "cuda" or "cpu").
When `bytes` is provided, the pointer and size derived from the bytes object and memory type will be assumed to be CPU.
When `tuple[int, int, Device|str, Any]` is provided, the tuple must contain the following elements:
- `int`: Pointer to the data in memory.
- `int`: Size of the data in bytes.
- `Device`: Either a `dynamo.connect.Device` or a string representing the device type (e.g., "cuda" or "cpu").
- `Any`: Optional reference to the data (e.g., the original tensor or bytes object).
This is useful for keeping a reference to the data in memory, but it is not required.
Raises
------
ValueError
When `data` is `None`.
TypeError
When `data` is not a valid type (i.e., not `torch.Tensor`, `bytes`, or a valid tuple).
TypeError
When `data` is a tuple but the elements are not of the expected types (i.e., [`ndarray`, `Device|str`] OR [`int`, `int`, `Device|str`, `Any`]).
"""
TYPE_ERROR_MESSAGE
=
"Argument `data` must be `torch.Tensor`, `tuple[ndarray, Device|str]`, `bytes`, or `tuple[int, int, Device|str, Any]`."
if
data
is
None
:
raise
ValueError
(
"Argument `data` cannot be `None`."
)
if
not
(
isinstance
(
data
,
torch
.
Tensor
)
or
isinstance
(
data
,
bytes
)
or
isinstance
(
data
,
tuple
)):
raise
TypeError
(
TYPE_ERROR_MESSAGE
)
self
.
_data_device
:
Device
=
Device
(
"cpu"
)
self
.
_data_ptr
:
int
=
0
self
.
_data_ref
:
Optional
[
Any
]
=
None
self
.
_data_size
:
int
=
0
# Member fields for managing NIXL memory registration.
# Note: ONLY local descriptors should be registered with NIXL,
# remote descriptors do not have a valid memory address and registration will fault.
self
.
_connector
:
Optional
[
Connector
]
=
None
self
.
_nixl_hndl
:
Optional
[
nixl_bindings
.
nixlRegDList
]
=
None
# Initially `None` cached serialized descriptor reference, populated when `to_serialized()` is called.
self
.
_serialized
:
Optional
[
SerializedDescriptor
]
=
None
# Data is `torch.Tensor`.
if
isinstance
(
data
,
torch
.
Tensor
):
self
.
_data_ptr
=
data
.
data_ptr
()
self
.
_data_size
=
data
.
numel
()
*
data
.
element_size
()
if
data
.
is_cuda
:
self
.
_data_device
=
Device
((
DeviceKind
.
CUDA
,
data
.
get_device
()))
self
.
_data_ref
=
data
logger
.
debug
(
f
"Created
{
self
.
__repr__
()
}
from `torch.Tensor`."
)
# Data is `tuple[ndarray, Device]`.
elif
(
isinstance
(
data
,
tuple
)
and
len
(
data
)
==
2
and
isinstance
(
data
[
0
],
array_module
.
ndarray
)
and
(
isinstance
(
data
[
1
],
Device
)
or
isinstance
(
data
[
1
],
str
))
):
if
hasattr
(
data
[
0
],
"__array_interface__"
):
self
.
_data_ptr
=
data
[
0
].
__array_interface__
[
"data"
][
0
]
elif
hasattr
(
data
[
0
],
"__cuda_array_interface__"
):
self
.
_data_ptr
=
data
[
0
].
__cuda_array_interface__
[
"data"
][
0
]
else
:
raise
TypeError
(
"Argument `data[0]` must be a `ndarray` with a valid array interface."
)
self
.
_data_size
=
data
[
0
].
nbytes
self
.
_data_device
=
data
[
1
]
if
isinstance
(
data
[
1
],
Device
)
else
Device
(
data
[
1
])
self
.
_data_ref
=
data
[
0
]
logger
.
debug
(
f
"Created
{
self
.
__repr__
()
}
from `tuple[ndarray, Device|str]`."
)
# Data is `bytes`.
elif
isinstance
(
data
,
bytes
):
self
.
_data_ptr
=
id
(
data
)
self
.
_data_size
=
len
(
data
)
self
.
_data_ref
=
data
logger
.
debug
(
f
"Created
{
self
.
__repr__
()
}
from `bytes`."
)
# Data is `tuple[int, int, Device, dtype, tuple, Any]`.
elif
isinstance
(
data
,
tuple
)
and
len
(
data
)
>=
2
and
isinstance
(
data
[
0
],
int
)
and
isinstance
(
data
[
1
],
int
):
if
len
(
data
)
>=
3
and
not
(
isinstance
(
data
[
2
],
Device
)
or
isinstance
(
data
[
2
],
str
)):
raise
TypeError
(
"Argument `data` must be a `tuple[int, int, Device|str, Any]`."
)
self
.
_data_ptr
=
data
[
0
]
self
.
_data_size
=
data
[
1
]
if
len
(
data
)
>=
3
:
self
.
_data_device
=
data
[
2
]
if
isinstance
(
data
[
2
],
Device
)
else
Device
(
data
[
2
])
self
.
_data_ref
=
data
[
3
]
if
len
(
data
)
>=
4
else
None
logger
.
debug
(
f
"Created
{
self
.
__repr__
()
}
from `tuple[int, int, Device|str, Any]`."
)
else
:
raise
TypeError
(
TYPE_ERROR_MESSAGE
)
def
__del__
(
self
)
->
None
:
if
self
.
_nixl_hndl
is
not
None
and
self
.
_connector
is
not
None
:
# Unregister the memory with NIXL.
self
.
_connector
.
_nixl
.
deregister_memory
(
self
.
_nixl_hndl
)
self
.
_nixl_hndl
=
None
if
self
.
_data_ref
is
not
None
:
# Release the reference to the data.
del
self
.
_data_ref
logger
.
debug
(
f
"Deleted
{
self
.
__repr__
()
}
."
)
def
__repr__
(
self
)
->
str
:
return
f
"
{
self
.
__class__
.
__name__
}
(
{
self
}
)"
def
__str__
(
self
)
->
str
:
return
f
"ptr=
{
hex
(
self
.
_data_ptr
)
}
, size=
{
self
.
_data_size
}
, device=
{
self
.
_data_device
}
"
@
property
def
device
(
self
)
->
Device
:
"""
Gets the device the of the descriptor.
"""
return
self
.
_data_device
@
property
def
ptr
(
self
)
->
int
:
"""
Gets the pointer of the descriptor.
"""
return
self
.
_data_ptr
@
property
def
size
(
self
)
->
int
:
"""
Gets the size of the descriptor.
"""
return
self
.
_data_size
@
staticmethod
def
from_serialized
(
serialized
:
SerializedDescriptor
,
)
->
Descriptor
:
"""
Deserializes a `SerializedDescriptor` into a `Descriptor` object.
Parameters
----------
serialized : SerializedDescriptor
The serialized descriptor to deserialize.
Returns
-------
Descriptor
The deserialized descriptor.
"""
if
not
isinstance
(
serialized
,
SerializedDescriptor
):
raise
TypeError
(
"Argument `serialized` must be `SerializedDescriptor`."
)
return
serialized
.
to_descriptor
()
def
register_memory
(
self
,
connector
:
Connector
,
)
->
None
:
"""
Registers the memory of the descriptor with NIXL.
"""
if
not
isinstance
(
connector
,
Connector
):
raise
TypeError
(
"Argument `connector` must be `dynamo.connect.Connector`."
)
if
self
.
_data_ptr
==
0
:
raise
ValueError
(
"Cannot register memory with a null pointer."
)
if
not
(
self
.
_nixl_hndl
is
None
and
self
.
_connector
is
None
):
return
# Register the memory with NIXL.
self
.
_connector
=
connector
if
isinstance
(
self
.
_data_ref
,
torch
.
Tensor
):
self
.
_nixl_hndl
=
connector
.
_nixl
.
register_memory
(
self
.
_data_ref
)
else
:
mem_type
=
str
(
self
.
_data_device
.
kind
)
reg_list
=
[(
self
.
_data_ptr
,
self
.
_data_size
,
self
.
_data_device
.
id
,
mem_type
)]
self
.
_nixl_hndl
=
connector
.
_nixl
.
register_memory
(
reg_list
,
mem_type
)
logger
.
debug
(
f
"Registered
{
self
.
__repr__
()
}
with NIXL."
)
def
to_serialized
(
self
)
->
SerializedDescriptor
:
"""
Serializes the descriptor into a `SerializedDescriptor` object.
"""
if
self
.
_serialized
is
None
:
self
.
_serialized
=
SerializedDescriptor
(
device
=
f
"
{
self
.
_data_device
}
"
,
ptr
=
self
.
_data_ptr
,
size
=
self
.
_data_size
,
)
return
self
.
_serialized
class
Device
:
"""
Represents a device in the system.
"""
def
__init__
(
self
,
metadata
:
str
|
tuple
[
DeviceKind
,
int
],
)
->
None
:
if
metadata
is
None
:
raise
ValueError
(
"Argument `metadata` cannot be `None`."
)
if
isinstance
(
metadata
,
tuple
)
and
len
(
metadata
)
==
2
and
isinstance
(
metadata
[
0
],
DeviceKind
)
and
isinstance
(
metadata
[
1
],
int
):
kind
,
device_id
=
metadata
elif
isinstance
(
metadata
,
str
):
metadata
=
metadata
.
strip
().
lower
()
if
metadata
.
startswith
(
"cuda"
)
or
metadata
.
startswith
(
"gpu"
):
kind
=
DeviceKind
.
CUDA
device_id
=
0
if
metadata
.
find
(
":"
)
==
-
1
else
int
(
metadata
.
split
(
":"
)[
1
])
elif
metadata
.
startswith
(
"cpu"
)
or
metadata
.
startswith
(
"host"
):
kind
=
DeviceKind
.
HOST
device_id
=
0
else
:
raise
ValueError
(
"Argument `metadata` must be in the format 'cuda:<device_id>' or 'cpu'."
)
else
:
raise
TypeError
(
"Argument `metadata` must be a `tuple[MemoryKind, int]` or a `str`."
)
self
.
_device_id
=
device_id
self
.
_kind
=
kind
def
__repr__
(
self
)
->
str
:
return
f
"
{
self
.
__class__
.
__name__
}
(kind=
{
self
.
_kind
}
, id=
{
self
.
_device_id
}
)"
def
__str__
(
self
)
->
str
:
return
f
"
{
self
.
_kind
}
:
{
self
.
_device_id
}
"
if
self
.
_kind
is
DeviceKind
.
CUDA
else
f
"
{
self
.
_kind
}
"
@
property
def
id
(
self
)
->
int
:
"""
Gets the device ID of the device.
"""
return
self
.
_device_id
@
property
def
kind
(
self
)
->
DeviceKind
:
"""
Gets the memory kind of the device.
"""
return
self
.
_kind
class
DeviceKind
(
IntEnum
):
"""
Type of memory a descriptor has been allocated to.
"""
UNSPECIFIED
=
0
HOST
=
1
CUDA
=
2
def
__str__
(
self
)
->
str
:
if
self
==
DeviceKind
.
HOST
:
return
"cpu"
elif
self
==
DeviceKind
.
CUDA
:
return
"cuda"
else
:
return
"<invalid>"
class
OperationKind
(
IntEnum
):
"""
Kind of an operation.
"""
UNSPECIFIED
=
0
READ
=
1
WRITE
=
2
def
__str__
(
self
)
->
str
:
if
self
==
OperationKind
.
READ
:
return
"READ"
elif
self
==
OperationKind
.
WRITE
:
return
"WRITE"
else
:
return
"<invalid>"
class
OperationStatus
(
IntEnum
):
"""
Status of an operation.
"""
UNINTIALIZED
=
0
INITIALIZED
=
1
IN_PROGRESS
=
2
COMPLETE
=
3
CANCELLED
=
4
ERRORED
=
5
def
__str__
(
self
)
->
str
:
if
self
==
OperationStatus
.
INITIALIZED
:
return
"INIT"
elif
self
==
OperationStatus
.
IN_PROGRESS
:
return
"PROC"
elif
self
==
OperationStatus
.
COMPLETE
:
return
"DONE"
elif
self
==
OperationStatus
.
ERRORED
:
return
"ERR"
elif
self
==
OperationStatus
.
CANCELLED
:
return
"STOP"
else
:
return
"<invalid>"
class
PassiveOperation
(
AbstractOperation
):
"""
Abstract class for common functionality of passive operations.
"""
def
__init__
(
self
,
connector
:
Connector
,
operation_kind
:
OperationKind
,
local_descriptors
:
Descriptor
|
list
[
Descriptor
],
)
->
None
:
if
operation_kind
is
not
OperationKind
.
READ
and
operation_kind
is
not
OperationKind
.
WRITE
:
raise
ValueError
(
"Argument `operation_kind` must be either `READ` or `WRITE`."
)
self
.
_status
=
OperationStatus
.
UNINTIALIZED
super
().
__init__
(
connector
,
operation_kind
,
local_descriptors
,
None
,
None
)
self
.
_serialized_request
:
Optional
[
SerializedRequest
]
=
None
self
.
_status
=
OperationStatus
.
INITIALIZED
def
__del__
(
self
)
->
None
:
super
().
__del__
()
def
__enter__
(
self
)
->
AbstractOperation
:
super
().
__enter__
()
return
self
def
__exit__
(
self
,
exc_type
:
Any
,
exc_value
:
Any
,
traceback
:
Any
)
->
None
:
super
().
__exit__
(
exc_type
,
exc_value
,
traceback
)
def
__repr__
(
self
)
->
str
:
return
str
(
f
"
{
self
.
__class__
.
__name__
}
("
f
"operation_kind=
{
self
.
_operation_kind
}
, "
f
"local_descriptors=
{
self
.
_local_descriptors
}
, "
f
"notification_key='
{
self
.
_notification_key
}
', "
f
"status='
{
self
.
_status
}
'"
f
")"
)
async
def
_wait_for_completion_
(
self
)
->
None
:
# Loop until the operation is no longer in progress (or "initalized"),
# yielding control to the event loop to allow other operations to run.
while
True
:
match
self
.
status
:
# "in progress" or "initialized" means the operation is ongoing.
case
OperationStatus
.
INITIALIZED
:
await
asyncio
.
sleep
(
0.1
)
case
OperationStatus
.
IN_PROGRESS
:
await
asyncio
.
sleep
(
0.1
)
# Any other state indicates completion or error.
case
_
:
return
@
property
def
status
(
self
)
->
OperationStatus
:
"""
Gets the status of the operation.
"""
# Early return if the operation is already complete, errored, or cancelled.
match
self
.
_status
:
case
OperationStatus
.
COMPLETE
|
OperationStatus
.
ERRORED
|
OperationStatus
.
CANCELLED
:
return
self
.
_status
old_status
=
self
.
_status
# Query NIXL for any notifications.
notifications
=
self
.
_connector
.
_nixl
.
update_notifs
()
if
isinstance
(
notifications
,
dict
):
remote_state
=
OperationStatus
.
IN_PROGRESS
logger
.
debug
(
f
"NIXL reported notifications:
{
len
(
notifications
)
}
."
)
for
key
,
values
in
notifications
.
items
():
if
not
isinstance
(
values
,
list
):
raise
TypeError
(
f
"Expected `dict[str, list[bytes]]` from NIXL notification query; got
{
type
(
notifications
)
}
."
)
for
value
in
values
:
if
not
isinstance
(
value
,
bytes
):
continue
notification_key
=
value
.
decode
(
"utf-8"
)
# Once we've found the notification key, we know the operation is complete.
if
notification_key
==
self
.
_notification_key
:
remote_state
=
OperationStatus
.
COMPLETE
break
if
remote_state
==
OperationStatus
.
COMPLETE
:
self
.
_status
=
remote_state
logger
.
debug
(
f
"
{
self
.
__class__
.
__name__
}
{{ remote: '
{
self
.
_connector
.
name
}
' status: '
{
old_status
}
' => '
{
self
.
_status
}
' }}."
)
return
self
.
_status
def
to_serialized
(
self
)
->
SerializedRequest
:
"""
Gets the request descriptor for the operation.
"""
if
self
.
_serialized_request
is
None
:
# When we've not yet cached the serialized request, we need to generate one before returning it.
# Handle both cases: multiple and single descriptors.
if
isinstance
(
self
.
_local_descriptors
,
list
):
descriptors
=
[
desc
.
to_serialized
()
for
desc
in
self
.
_local_descriptors
]
else
:
descriptors
=
[
self
.
_local_descriptors
.
to_serialized
()]
original_len
=
len
(
self
.
_connector
.
metadata
)
nixl_metadata
=
self
.
_connector
.
metadata
nixl_metadata
=
zlib
.
compress
(
nixl_metadata
,
level
=
6
)
compressed_len
=
len
(
nixl_metadata
)
logger
.
debug
(
f
"Compressed NIXL metadata from
{
original_len
}
bytes to
{
compressed_len
}
bytes."
)
if
compressed_len
>
original_len
:
logger
.
warning
(
f
"Compressed NIXL metadata is larger than original (
{
compressed_len
}
>
{
original_len
}
)."
)
self
.
_serialized_request
=
SerializedRequest
(
descriptors
=
descriptors
,
nixl_metadata
=
nixl_metadata
.
hex
(),
notification_key
=
self
.
_notification_key
,
operation_kind
=
int
(
self
.
_operation_kind
),
)
return
self
.
_serialized_request
@
abstractmethod
async
def
wait_for_completion
(
self
)
->
None
:
"""
Blocks the caller asynchronously until the operation has completed.
"""
raise
NotImplementedError
(
"Abstract method not implemented by derived class."
)
class
ReadOperation
(
ActiveOperation
):
"""
Operation that initiates an RDMA read operation to transfer data from a remote worker's `ReadableOperation`,
as described by `remote_request`, to local buffers.
"""
def
__init__
(
self
,
connector
:
Connector
,
remote_request
:
SerializedRequest
,
local_descriptors
:
Descriptor
|
list
[
Descriptor
],
)
->
None
:
"""
Creates a new instance of `ReadOperation`, registers `local_descriptors` with NIXL,
and begins an RDMA read operation which will transfer data described by `remote_request`
to `local_descriptors`.
Parameters
----------
connector : Connector
Connector instance to use for the operation.
remote_request : SerializedRequest
Serialized request from the remote worker.
local_descriptors : Descriptor | list[Descriptor]
Local descriptor(s) to to receive the data from the remote agent.
"""
if
not
isinstance
(
connector
,
Connector
):
raise
TypeError
(
"Argument `connector` must be `dynamo.connect.Connector`."
)
if
not
isinstance
(
remote_request
,
SerializedRequest
):
raise
TypeError
(
"Argument `remote_request` must be `dynamo.connect.RequestDescriptor`."
)
if
remote_request
.
operation_kind
!=
OperationKind
.
READ
.
value
:
raise
ValueError
(
"Argument `remote_request` must be of kind `READ`."
)
remote
=
Remote
(
connector
,
remote_request
.
nixl_metadata
)
remote_descriptors
=
remote_request
.
to_descriptors
()
if
not
(
isinstance
(
local_descriptors
,
Descriptor
)
or
(
isinstance
(
local_descriptors
,
list
)
and
all
(
isinstance
(
d
,
Descriptor
)
for
d
in
local_descriptors
))
):
raise
TypeError
(
"Argument `local_descriptors` must be `dynamo.connect.Descriptor`, `list[dynamo.connect.Descriptor]`."
)
super
().
__init__
(
remote
,
OperationKind
.
READ
,
local_descriptors
,
remote_descriptors
,
remote_request
.
notification_key
)
logger
.
debug
(
f
"Created
{
self
.
__repr__
()
}
"
)
def
__del__
(
self
)
->
None
:
super
().
__del__
()
logger
.
debug
(
f
"Deleted
{
self
.
__repr__
()
}
"
)
def
__enter__
(
self
)
->
ReadOperation
:
super
().
__enter__
()
return
self
def
__exit__
(
self
,
exc_type
:
Any
,
exc_value
:
Any
,
traceback
:
Any
)
->
None
:
super
().
__exit__
(
exc_type
,
exc_value
,
traceback
)
def
__repr__
(
self
)
->
str
:
return
super
().
__repr__
()
def
cancel
(
self
)
->
None
:
"""
Cancels the operation.
No affect if the operation has already completed or errored, or been cancelled.
"""
super
().
_cancel_
()
def
results
(
self
)
->
list
[
Descriptor
]:
"""
Gets the results of the operation.
Returns a single descriptor if only one was requested, or a list of descriptors if multiple were requested.
"""
if
self
.
_status
!=
OperationStatus
.
COMPLETE
:
raise
RuntimeError
(
"Operation has not completed yet, cannot get results."
)
return
self
.
_local_descriptors
if
isinstance
(
self
.
_local_descriptors
,
list
)
else
[
self
.
_local_descriptors
]
async
def
wait_for_completion
(
self
)
->
None
:
"""
Blocks the caller asynchronously until the operation has completed.
"""
await
super
().
_wait_for_completion_
()
class
ReadableOperation
(
PassiveOperation
):
"""
Operation that can be awaited until a remote worker has completed a `ReadOperation`.
"""
def
__init__
(
self
,
connector
:
Connector
,
local_descriptors
:
Descriptor
|
list
[
Descriptor
],
)
->
None
:
super
().
__init__
(
connector
,
OperationKind
.
READ
,
local_descriptors
)
logger
.
debug
(
f
"Created
{
self
.
__repr__
()
}
"
)
def
__del__
(
self
)
->
None
:
super
().
__del__
()
logger
.
debug
(
f
"Deleted
{
self
.
__repr__
()
}
"
)
def
__enter__
(
self
)
->
ReadableOperation
:
super
().
__enter__
()
return
self
def
__exit__
(
self
,
exc_type
:
Any
,
exc_value
:
Any
,
traceback
:
Any
)
->
None
:
super
().
__exit__
(
exc_type
,
exc_value
,
traceback
)
def
__repr__
(
self
)
->
str
:
return
super
().
__repr__
()
async
def
wait_for_completion
(
self
)
->
None
:
"""
Blocks the caller asynchronously until the operation has completed.
"""
await
super
().
_wait_for_completion_
()
class
Remote
:
"""
Identifies a remote NIXL agent relative to a local NIXL agent.
"""
def
__init__
(
self
,
connector
:
Connector
,
nixl_metadata
:
bytes
|
str
,
)
->
None
:
if
not
isinstance
(
connector
,
Connector
):
raise
TypeError
(
"Argument `local` must be `dynamo.connect.Connector`."
)
if
not
(
isinstance
(
nixl_metadata
,
bytes
)
or
isinstance
(
nixl_metadata
,
str
)):
raise
TypeError
(
"Argument `nixl_metadata` must be `bytes` or `str`."
)
if
len
(
nixl_metadata
)
==
0
:
raise
ValueError
(
"Argument `nixl_metadata` cannot be empty."
)
self
.
_connector
=
connector
# When `nixl_metadata` is a string, it is assumed to have come from a remote worker
# via a `SerializedRequest` object and therefore can assumed be a hex-encoded, compressed
# representation of the NIXL metadata.
if
isinstance
(
nixl_metadata
,
str
):
# Decode the hex-encoded string into bytes.
nixl_metadata
=
bytes
.
fromhex
(
nixl_metadata
)
# Decompress the NIXL metadata.
nixl_metadata
=
zlib
.
decompress
(
nixl_metadata
)
self
.
_name
=
connector
.
_nixl
.
add_remote_agent
(
nixl_metadata
)
if
isinstance
(
self
.
_name
,
bytes
):
self
.
_name
=
self
.
_name
.
decode
(
"utf-8"
)
logger
.
debug
(
f
"Created
{
self
.
__repr__
()
}
."
)
def
__del__
(
self
)
->
None
:
self
.
_release
()
def
__enter__
(
self
)
->
Remote
:
"""
Context manager entry method. Returns the current instance.
"""
return
self
def
__exit__
(
self
,
exc_type
:
Any
,
exc_value
:
Any
,
traceback
:
Any
)
->
None
:
"""
Context manager exit method. Cleans up the instance.
"""
self
.
_release
()
def
__repr__
(
self
)
->
str
:
return
f
"RemoteAgent(name=
{
self
.
_name
}
, connector=
{
self
.
_connector
.
name
}
)"
def
__str__
(
self
)
->
str
:
return
self
.
_name
def
_release
(
self
)
->
None
:
"""
Private method for releasing NIXL resources. Not intended for public use.
"""
pass
@
property
def
connector
(
self
)
->
Connector
:
"""
Gets the local connector associated with this remote agent.
"""
return
self
.
_connector
@
property
def
name
(
self
)
->
str
:
"""
Gets the name of the remote agent.
"""
return
self
.
_name
class
SerializedDescriptor
(
BaseModel
):
"""
Pydantic serialization type for memory descriptors.
"""
model_config
=
ConfigDict
(
extra
=
"forbid"
,
frozen
=
True
,
arbitrary_types_allowed
=
True
,
)
device
:
str
=
"cpu"
ptr
:
int
=
0
size
:
int
=
0
def
to_descriptor
(
self
)
->
Descriptor
:
"""
Deserialize the serialized descriptor into a `Descriptor` object.
"""
return
Descriptor
(
data
=
(
self
.
ptr
,
self
.
size
,
self
.
device
,
None
))
@
field_validator
(
"device"
)
def
validate_memtype
(
cls
,
v
:
str
)
->
str
:
if
not
isinstance
(
v
,
str
):
raise
TypeError
(
"Argument `device` must be `str`."
)
v
=
v
.
strip
().
lower
()
if
not
(
v
.
startswith
(
"cuda"
)
or
v
==
"cpu"
):
raise
ValueError
(
"Argument `device` must be one of 'cpu' or 'cuda:<device_id>'."
)
return
v
@
field_validator
(
"ptr"
)
def
validate_ptr
(
cls
,
v
:
int
)
->
int
:
if
v
==
0
:
raise
ValueError
(
"Argument `ptr` cannot be zero (aka `null` or `None`)."
)
return
v
@
field_validator
(
"size"
)
def
validate_size
(
cls
,
v
:
int
)
->
int
:
if
v
<
0
:
raise
ValueError
(
"Argument `size` must be an integer greater than or equal to zero."
)
return
v
class
SerializedRequest
(
BaseModel
):
"""
Pydantic serialization type for describing the passive side of a transfer.
"""
model_config
=
ConfigDict
(
extra
=
"forbid"
,
frozen
=
True
,
arbitrary_types_allowed
=
True
,
)
descriptors
:
List
[
SerializedDescriptor
]
=
[]
nixl_metadata
:
str
=
""
notification_key
:
str
=
""
operation_kind
:
int
=
0
def
to_descriptors
(
self
)
->
Descriptor
|
list
[
Descriptor
]:
"""
Deserializes the request descriptor into a `dynamo.connect.Descriptor` or list of `dynamo.connect.Descriptor` objects.
"""
if
len
(
self
.
descriptors
)
==
0
:
raise
ValueError
(
"Request descriptor must contain at least one serialized descriptor."
)
if
len
(
self
.
descriptors
)
==
1
:
return
self
.
descriptors
[
0
].
to_descriptor
()
return
[
item
.
to_descriptor
()
for
item
in
self
.
descriptors
]
@
field_validator
(
"operation_kind"
)
def
validate_operation_kind
(
cls
,
v
:
int
)
->
int
:
if
v
<
1
or
v
>
3
:
raise
TypeError
(
"Argument `operation_kind` must be an integer value of `dynamo.connect.OperationKind`."
)
return
v
class
WritableOperation
(
PassiveOperation
):
"""
Operation which can be awaited until written to by a `WriteOperation` from a remote worker.
"""
def
__init__
(
self
,
connector
:
Connector
,
local_descriptors
:
Descriptor
|
list
[
Descriptor
],
)
->
None
:
"""
Creates a new instance of `WritableOperation`, registers the operation and descriptors w/ NIXL,
and enables an RDMA write operation to occur.
Parameters
----------
connector : Connector
Connector instance to use for the operation.
local_descriptors : Descriptor | list[Descriptor]
Descriptors to receive data from a remote worker.
Raises
TypeError
When `local` is not a `dynamo.connect.Connector`.
TypeError
When `local_descriptors` is not a `dynamo.connect.Descriptor` or `list[dynamo.connect.Descriptor]`.
"""
super
().
__init__
(
connector
,
OperationKind
.
WRITE
,
local_descriptors
)
logger
.
debug
(
f
"Created
{
self
.
__repr__
()
}
"
)
def
__del__
(
self
)
->
None
:
super
().
__del__
()
logger
.
debug
(
f
"Deleted
{
self
.
__repr__
()
}
"
)
def
__enter__
(
self
)
->
WritableOperation
:
super
().
__enter__
()
return
self
def
__exit__
(
self
,
exc_type
:
Any
,
exc_value
:
Any
,
traceback
:
Any
)
->
None
:
super
().
__exit__
(
exc_type
,
exc_value
,
traceback
)
def
__repr__
(
self
)
->
str
:
return
super
().
__repr__
()
async
def
wait_for_completion
(
self
)
->
None
:
"""
Blocks the caller asynchronously until the operation has completed.
"""
await
super
().
_wait_for_completion_
()
class
WriteOperation
(
ActiveOperation
):
"""
Awaitable write operation which initiates an RDMA write operation to a remote worker
which provided a `SerializedRequest` object from a `WritableOperation`.
"""
def
__init__
(
self
,
connector
:
Connector
,
local_descriptors
:
Descriptor
|
list
[
Descriptor
],
remote_request
:
SerializedRequest
,
)
->
None
:
"""
Creates a new instance of `WriteOperation`, registers `local_descriptors` with NIXL,
and begins an RDMA write operation which will transfer from `local_descriptors` to
remote target(s) described by `remote_request`
Parameters
----------
connector : Connector
Connector instance to use for the operation.
local_descriptors : Descriptor | list[Descriptor]
Local descriptor(s) to send from, to the remote agent.
remote_request : SerializedRequest
Serialized request from the remote worker that describes the target(s) to send to.
Raises
TypeError
When `connector` is not a `dynamo.connect.Connector`.
TypeError
When `remote_request` is not a `dynamo.connect.RequestDescriptor`.
ValueError
When `remote_request` is not of kind `WRITE`.
ValueError
When `remote_request.nixl_metadata` is not a non-empty `str`.
TypeError
When `local_descriptors` is not a `dynamo.connect.Descriptor` or `list[dynamo.connect.Descriptor]`.
"""
if
not
isinstance
(
connector
,
Connector
):
raise
TypeError
(
"Argument `connector` must be `dynamo.connect.Connector`."
)
if
not
isinstance
(
remote_request
,
SerializedRequest
):
raise
TypeError
(
"Argument `remote_request` must be `dynamo.connect.RequestDescriptor`."
)
if
remote_request
.
operation_kind
!=
OperationKind
.
WRITE
.
value
:
raise
ValueError
(
"Argument `remote_request` must be of kind `WRITE`."
)
remote
=
Remote
(
connector
,
remote_request
.
nixl_metadata
)
remote_descriptors
=
remote_request
.
to_descriptors
()
super
().
__init__
(
remote
,
OperationKind
.
WRITE
,
local_descriptors
,
remote_descriptors
,
remote_request
.
notification_key
)
logger
.
debug
(
f
"Created
{
self
.
__repr__
()
}
"
)
def
__del__
(
self
)
->
None
:
super
().
__del__
()
logger
.
debug
(
f
"Deleted
{
self
.
__repr__
()
}
"
)
def
__enter__
(
self
)
->
WriteOperation
:
super
().
__enter__
()
return
self
def
__exit__
(
self
,
exc_type
:
Any
,
exc_value
:
Any
,
traceback
:
Any
)
->
None
:
super
().
__exit__
(
exc_type
,
exc_value
,
traceback
)
def
__repr__
(
self
)
->
str
:
return
super
().
__repr__
()
def
cancel
(
self
)
->
None
:
"""
Cancels the operation.
No affect if the operation has already completed or errored, or has been cancelled.
"""
super
().
_cancel_
()
async
def
wait_for_completion
(
self
)
->
None
:
"""
Blocks the caller asynchronously until the operation has completed.
"""
await
super
().
_wait_for_completion_
()
examples/multimodal/graphs/agg.py
View file @
75e774d4
...
@@ -6,6 +6,7 @@
...
@@ -6,6 +6,7 @@
# You may obtain a copy of the License at
# You may obtain a copy of the License at
#
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# distributed under the License is distributed on an "AS IS" BASIS,
...
@@ -13,9 +14,9 @@
...
@@ -13,9 +14,9 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
components.encode_worker
import
EncodeWorker
from
components.decode_worker
import
VllmDecodeWorker
from
components.encode_worker
import
VllmEncodeWorker
from
components.frontend
import
Frontend
from
components.frontend
import
Frontend
from
components.processor
import
Processor
from
components.processor
import
Processor
from
components.worker
import
VllmWorker
Frontend
.
link
(
Processor
).
link
(
VllmWorker
).
link
(
EncodeWorker
)
Frontend
.
link
(
Processor
).
link
(
Vllm
Decode
Worker
).
link
(
Vllm
EncodeWorker
)
examples/multimodal/graphs/disagg.py
View file @
75e774d4
...
@@ -13,10 +13,12 @@
...
@@ -13,10 +13,12 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
components.encode_worker
import
EncodeWorker
from
components.decode_worker
import
VllmDecodeWorker
from
components.encode_worker
import
VllmEncodeWorker
from
components.frontend
import
Frontend
from
components.frontend
import
Frontend
from
components.prefill_worker
import
PrefillWorker
from
components.prefill_worker
import
Vllm
PrefillWorker
from
components.processor
import
Processor
from
components.processor
import
Processor
from
components.worker
import
VllmWorker
Frontend
.
link
(
Processor
).
link
(
VllmWorker
).
link
(
PrefillWorker
).
link
(
EncodeWorker
)
Frontend
.
link
(
Processor
).
link
(
VllmDecodeWorker
).
link
(
VllmPrefillWorker
).
link
(
VllmEncodeWorker
)
examples/multimodal/utils/protocol.py
View file @
75e774d4
...
@@ -17,6 +17,7 @@
...
@@ -17,6 +17,7 @@
import
json
import
json
from
typing
import
Any
,
List
,
Optional
from
typing
import
Any
,
List
,
Optional
import
connect
import
msgspec
import
msgspec
from
pydantic
import
BaseModel
,
ConfigDict
,
field_validator
from
pydantic
import
BaseModel
,
ConfigDict
,
field_validator
from
pydantic_core
import
core_schema
from
pydantic_core
import
core_schema
...
@@ -111,12 +112,13 @@ class EncodeRequest(BaseModel):
...
@@ -111,12 +112,13 @@ class EncodeRequest(BaseModel):
model_config
=
ConfigDict
(
arbitrary_types_allowed
=
True
)
model_config
=
ConfigDict
(
arbitrary_types_allowed
=
True
)
image_url
:
str
image_url
:
str
request_id
:
str
serialized_request
:
Optional
[
connect
.
SerializedRequest
]
=
None
class
EncodeResponse
(
BaseModel
):
class
EncodeResponse
(
BaseModel
):
model_config
=
ConfigDict
(
arbitrary_types_allowed
=
True
)
model_config
=
ConfigDict
(
arbitrary_types_allowed
=
True
)
request_id
:
str
image_features
:
List
[
List
[
List
[
float
]]]
class
MyRequestOutput
(
BaseModel
):
class
MyRequestOutput
(
BaseModel
):
...
@@ -129,7 +131,6 @@ class MyRequestOutput(BaseModel):
...
@@ -129,7 +131,6 @@ class MyRequestOutput(BaseModel):
"""
"""
model_config
=
ConfigDict
(
arbitrary_types_allowed
=
True
)
model_config
=
ConfigDict
(
arbitrary_types_allowed
=
True
)
request_id
:
str
request_id
:
str
prompt
:
Optional
[
str
]
=
None
prompt
:
Optional
[
str
]
=
None
prompt_token_ids
:
Optional
[
List
[
int
]]
=
None
prompt_token_ids
:
Optional
[
List
[
int
]]
=
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment