Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
75e774d4
Unverified
Commit
75e774d4
authored
May 27, 2025
by
J Wyman
Committed by
GitHub
May 27, 2025
Browse files
feat: NIXL Based RDMA Support w/ Multimodal Example (#1060)
parent
9acaa8d1
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
1784 additions
and
112 deletions
+1784
-112
.pre-commit-config.yaml
.pre-commit-config.yaml
+2
-2
examples/multimodal/components/decode_worker.py
examples/multimodal/components/decode_worker.py
+88
-29
examples/multimodal/components/encode_worker.py
examples/multimodal/components/encode_worker.py
+99
-7
examples/multimodal/components/prefill_worker.py
examples/multimodal/components/prefill_worker.py
+102
-56
examples/multimodal/components/processor.py
examples/multimodal/components/processor.py
+3
-3
examples/multimodal/configs/agg.yaml
examples/multimodal/configs/agg.yaml
+2
-2
examples/multimodal/configs/disagg.yaml
examples/multimodal/configs/disagg.yaml
+3
-3
examples/multimodal/connect/__init__.py
examples/multimodal/connect/__init__.py
+1471
-0
examples/multimodal/graphs/agg.py
examples/multimodal/graphs/agg.py
+4
-3
examples/multimodal/graphs/disagg.py
examples/multimodal/graphs/disagg.py
+6
-4
examples/multimodal/utils/protocol.py
examples/multimodal/utils/protocol.py
+4
-3
No files found.
.pre-commit-config.yaml
View file @
75e774d4
...
...
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
exclude
:
^(src/grpc_generated|.*\.patch$)
exclude
:
^(src/grpc_generated|.*\.patch$
|.*/connect/.*\.py
)
repos
:
-
repo
:
https://github.com/timothycrosley/isort
rev
:
5.12.0
...
...
@@ -82,4 +82,4 @@ repos:
# NOTE: pyright may be able to find other classes of errors not covered above,
# but would require some configuring and venv setup to properly eliminate noise
# and give it visiblity into all the local and third_party packages expected.
\ No newline at end of file
# and give it visiblity into all the local and third_party packages expected.
examples/multimodal/components/worker.py
→
examples/multimodal/components/
decode_
worker.py
View file @
75e774d4
...
...
@@ -19,10 +19,11 @@ import os
import
signal
from
typing
import
Optional
import
connect
import
torch
from
components.disagg_router
import
PyDisaggregatedRouter
from
components.encode_worker
import
EncodeWorker
from
components.prefill_worker
import
PrefillWorker
from
components.encode_worker
import
Vllm
EncodeWorker
from
components.prefill_worker
import
Vllm
PrefillWorker
from
transformers
import
LlavaForConditionalGeneration
from
utils.logging
import
check_required_workers
from
utils.nixl
import
NixlMetadataStore
...
...
@@ -53,11 +54,11 @@ logger = logging.getLogger(__name__)
resources
=
{
"gpu"
:
1
,
"cpu"
:
"10"
,
"memory"
:
"20Gi"
},
workers
=
1
,
)
class
VllmWorker
:
class
Vllm
Decode
Worker
:
# For disaggregated serving, we need to link the prefill worker to the vllm worker
prefill_worker
=
depends
(
PrefillWorker
)
prefill_worker
=
depends
(
Vllm
PrefillWorker
)
# For aggregated serving, we need to link the encode worker to the vllm worker.
encode_worker
=
depends
(
EncodeWorker
)
encode_worker
=
depends
(
Vllm
EncodeWorker
)
def
__init__
(
self
):
self
.
client
=
None
...
...
@@ -141,7 +142,11 @@ class VllmWorker:
vision_tower
.
vision_model
.
embeddings
.
position_embedding
.
num_embeddings
)
else
:
enc_comp_ns
,
enc_comp_name
=
EncodeWorker
.
dynamo_address
()
# type: ignore
EMBEDDINGS_SHAPE
=
(
1
,
577
,
4096
)
EMBEDDINGS_DTYPE
=
torch
.
float16
EMBEDDINGS_DEVICE
=
"cuda"
enc_comp_ns
,
enc_comp_name
=
VllmEncodeWorker
.
dynamo_address
()
# type: ignore
self
.
encode_worker_client
=
(
await
runtime
.
namespace
(
enc_comp_ns
)
.
component
(
enc_comp_name
)
...
...
@@ -149,9 +154,22 @@ class VllmWorker:
.
client
()
)
self
.
_connector
=
connect
.
Connector
(
runtime
=
runtime
,
namespace
=
enc_comp_ns
)
await
self
.
_connector
.
initialize
()
# Create a longer-lived buffer for receiving the image embeddings.
embeddings
=
torch
.
empty
(
EMBEDDINGS_SHAPE
,
dtype
=
EMBEDDINGS_DTYPE
,
device
=
EMBEDDINGS_DEVICE
)
descriptor
=
connect
.
Descriptor
(
embeddings
)
# Register the descriptor w/ NIXL (this is optional, if not done here the connect subsytem will take care of this automatically).
descriptor
.
register_memory
(
self
.
_connector
)
self
.
_embeddings_descriptor
=
(
embeddings
,
descriptor
)
await
check_required_workers
(
self
.
encode_worker_client
,
self
.
min_workers
)
self
.
disaggregated_router
=
None
logger
.
info
(
"VllmWorker has been initialized"
)
logger
.
info
(
"Initialization complete."
)
def
shutdown_vllm_engine
(
self
,
signum
,
frame
):
"""Shutdown the background loop"""
...
...
@@ -159,7 +177,7 @@ class VllmWorker:
loop
=
asyncio
.
get_event_loop
()
try
:
self
.
engine_client
.
close
()
logger
.
info
(
"
VllmWorker s
hutdown complete"
)
logger
.
info
(
"
S
hutdown complete
.
"
)
except
Exception
as
e
:
logger
.
error
(
f
"Error during shutdown:
{
e
}
"
)
finally
:
...
...
@@ -177,8 +195,18 @@ class VllmWorker:
@
endpoint
()
async
def
generate
(
self
,
request
:
vLLMMultimodalRequest
):
image_features
=
None
request_id
=
request
.
request_id
image_url
=
request
.
image_url
logger
.
info
(
f
"Received multimodal request {{ id:
{
request_id
}
, image_url: '
{
image_url
}
' }}."
)
embeddings
=
None
if
self
.
do_remote_prefill
:
logger
.
debug
(
f
"Disaggregated: request {{ id:
{
request_id
}
, image_url: '
{
image_url
}
' }}"
" prefill worker will populate the decode model's key-value cache ahead of time;"
" no direct encode worker interaction required."
)
if
self
.
disaggregated_router
is
not
None
:
async
with
PrefillQueue
.
get_instance
(
nats_server
=
self
.
_prefill_queue_nats_server
,
...
...
@@ -195,21 +223,21 @@ class VllmWorker:
disagg_router_decision
=
True
if
self
.
do_remote_prefill
and
disagg_router_decision
:
logger
.
debug
(
f
"Prefilling remotely for request {{ id:
{
request_id
}
, image_url: '
{
image_url
}
' }} with length
{
len
(
request
.
engine_prompt
[
'prompt_token_ids'
])
}
"
)
remote_prefill_params
=
RemotePrefillParams
(
is_remote_prefill
=
True
,
remote_prefill_request_callback
=
self
.
get_remote_prefill_request_callback
(),
# Pass the image url as part of the RemotePrefillParams, which will be passed to the prefill worker via RemotePrefillRequest
multimodal_data_source
=
{
"image_url"
:
request
.
image_url
,
"image_url"
:
image_url
,
},
)
logger
.
info
(
f
"Prefilling remotely for request
{
request
.
request_id
}
with length
{
len
(
request
.
engine_prompt
[
'prompt_token_ids'
])
}
"
)
else
:
remote_prefill_params
=
None
logger
.
info
(
f
"Prefilling locally for request
{
request
.
request_id
}
with length
{
len
(
request
.
engine_prompt
[
'prompt_token_ids'
])
}
"
logger
.
debug
(
f
"Prefilling locally for request {
{ id:
{
request_id
}
, image_url: '
{
image_url
}
' }
} with length
{
len
(
request
.
engine_prompt
[
'prompt_token_ids'
])
}
"
)
# The decode worker will pre-allocate the memory based on the prompt token length for the prefill worker to transfer the kv cache.
...
...
@@ -231,33 +259,61 @@ class VllmWorker:
)
else
:
# For aggregated serving, the vllm worker will directly send the encode request to the encode worker.
encode_generator
=
await
self
.
encode_worker_client
.
round_robin
(
EncodeRequest
(
image_url
=
request
.
image_url
,
).
model_dump_json
()
logger
.
debug
(
f
"Aggregated: request {{ id:
{
request_id
}
, image_url: '
{
image_url
}
' }}"
" no prefill worker available, embeddings directly from encode worker."
)
device
=
torch
.
device
(
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
async
for
encode_response
in
encode_generator
:
encode_output
=
EncodeResponse
.
model_validate_json
(
encode_response
.
data
()
# Extract the pre-allocated, reusable image embeddings tensor and its descriptor.
# Doing this avoids unnessesary memory de/registration with NIXL.
embeddings
,
descriptor
=
self
.
_embeddings_descriptor
with
self
.
_connector
.
create_writable
(
descriptor
)
as
writable
:
# Extract serialized metadata about the operation from the writable operation,
# and use it to create a new EncodeRequest.
encode_request
=
EncodeRequest
(
request_id
=
request_id
,
image_url
=
image_url
,
serialized_request
=
writable
.
to_serialized
(),
)
image_features
=
torch
.
tensor
(
encode_output
.
image_features
,
device
=
device
,
dtype
=
torch
.
float16
logger
.
debug
(
f
"Encode request:
{
encode_request
.
model_dump_json
()
}
"
)
encode_generator
=
await
self
.
encode_worker_client
.
round_robin
(
encode_request
.
model_dump_json
()
)
async
for
encode_response
in
encode_generator
:
encode_output
=
EncodeResponse
.
model_validate_json
(
encode_response
.
data
()
)
logger
.
info
(
f
"Received response: {{ id:
{
encode_output
.
request_id
}
}}"
)
# Wait for the write operation to complete.
# This will block until the write operation is complete.
# This await should be a no-op since we've already received a response from the encode worker.
await
writable
.
wait_for_completion
()
# At this point, the `embeddings` tensor is filled with the image embeddings from the remote encode worker.
remote_prefill_params
=
None
logger
.
info
(
f
"Prefilling locally for request
{
request
.
request_id
}
with length
{
len
(
request
.
engine_prompt
[
'prompt_token_ids'
])
}
"
f
"Prefilling locally for request {
{ id:
{
request_id
}
, image_url: '
{
image_url
}
' }
} with length
{
len
(
request
.
engine_prompt
[
'prompt_token_ids'
])
}
"
)
prompt_ids
=
request
.
engine_prompt
[
"prompt_token_ids"
]
# rust HTTP requires Delta streaming
request
.
sampling_params
.
output_kind
=
RequestOutputKind
.
DELTA
if
image_features
is
not
None
:
multi_modal_data
=
{
"image"
:
image_features
}
# When using aggregated serving, the encode worker will have provided the key-value cache updates via the prefill worker.
# When using disaggregated serving, the encode worker will have provided the key-value cache updates via the encode worker.
if
embeddings
is
not
None
:
logger
.
debug
(
"Aggregated: embedding data from encode worker provided via multi-modal data to decode model."
)
multi_modal_data
=
{
"image"
:
embeddings
}
else
:
logger
.
debug
(
"Disaggregated: no embedding data required as prefill will have provided key-value cache updates via encode worker."
)
multi_modal_data
=
None
async
for
response
in
self
.
engine_client
.
generate
(
...
...
@@ -269,6 +325,9 @@ class VllmWorker:
request_id
=
request
.
request_id
,
remote_prefill_params
=
remote_prefill_params
,
):
logger
.
debug
(
f
"Yeilding response {{ id:
{
response
.
request_id
}
, prompt: '
{
response
.
prompt
}
' }}"
)
yield
MyRequestOutput
(
request_id
=
response
.
request_id
,
prompt
=
response
.
prompt
,
...
...
examples/multimodal/components/encode_worker.py
View file @
75e774d4
...
...
@@ -15,8 +15,10 @@
import
logging
from
io
import
BytesIO
from
queue
import
Queue
from
typing
import
AsyncIterator
import
connect
import
requests
import
torch
from
PIL
import
Image
...
...
@@ -24,10 +26,25 @@ from transformers import AutoImageProcessor, LlavaForConditionalGeneration
from
utils.protocol
import
EncodeRequest
,
EncodeResponse
from
utils.vllm
import
parse_vllm_args
from
dynamo.sdk
import
endpoint
,
service
from
dynamo.sdk
import
async_on_start
,
endpoint
,
service
logger
=
logging
.
getLogger
(
__name__
)
try
:
import
cupy
as
array_module
if
not
array_module
.
cuda
.
is_available
():
raise
ImportError
(
"CUDA is not available."
)
DEVICE
=
"cuda"
logger
.
info
(
"Using cupy for array operations (GPU mode)."
)
except
ImportError
as
e
:
logger
.
warning
(
f
"Failed to import cupy, falling back to numpy:
{
e
}
."
)
import
numpy
as
array_module
DEVICE
=
"cpu"
CACHE_SIZE_MAXIMUM
=
8
@
service
(
dynamo
=
{
...
...
@@ -36,7 +53,7 @@ logger = logging.getLogger(__name__)
resources
=
{
"gpu"
:
1
,
"cpu"
:
"10"
,
"memory"
:
"20Gi"
},
workers
=
1
,
)
class
EncodeWorker
:
class
Vllm
EncodeWorker
:
def
__init__
(
self
)
->
None
:
class_name
=
self
.
__class__
.
__name__
self
.
engine_args
=
parse_vllm_args
(
class_name
,
""
)
...
...
@@ -50,9 +67,50 @@ class EncodeWorker:
self
.
MODEL_ID
,
device_map
=
"auto"
,
torch_dtype
=
torch
.
float16
).
eval
()
self
.
_image_cache
:
dict
[
str
,
Image
.
Image
]
=
{}
self
.
_cache_queue
:
Queue
[
str
]
=
Queue
(
maxsize
=
CACHE_SIZE_MAXIMUM
)
@
endpoint
()
async
def
encode
(
self
,
request
:
EncodeRequest
)
->
AsyncIterator
[
EncodeResponse
]:
image
=
self
.
open_image
(
request
.
image_url
)
logger
.
debug
(
f
"Received encode request: {{ id:
{
request
.
request_id
}
, image_url: '
{
request
.
image_url
}
' }}."
)
request_id
=
request
.
request_id
image_url
=
request
.
image_url
.
lower
()
# The following steps encode the requested image and provided useful embeddings.
# 1. Open the image from the provided URL.
# 2. Process the image using the image processor.
# 3. Run the image through the vision model's vision tower.
# 4. Run the results of the vision tower through the multi-modal projector.
# 5. Create a descriptor for the embeddings.
# 6. Create a write operation using the serialized request and the descriptor.
# 7. Await for the write operation to complete.
# 8. Yield the encode response.
# Either retrieve the image from the cache or download it and then cache it.
if
request
.
image_url
in
self
.
_image_cache
:
image
=
self
.
_image_cache
[
image_url
]
logger
.
debug
(
f
"Image found in cache for request: {{ id:
{
request_id
}
, image_url: '
{
image_url
}
' }}."
)
else
:
image
=
self
.
open_image
(
image_url
)
logger
.
debug
(
f
"Downloading/opening image for request: {{ id:
{
request_id
}
, image_url: '
{
image_url
}
' }}."
)
# Cache the image for future use, and evict the oldest image if the cache is full.
if
self
.
_cache_queue
.
full
():
oldest_image_url
=
self
.
_cache_queue
.
get
()
del
self
.
_image_cache
[
oldest_image_url
]
self
.
_image_cache
[
request
.
image_url
]
=
image
self
.
_cache_queue
.
put
(
request
.
image_url
)
logger
.
debug
(
f
"Processing image for request: {{ id:
{
request_id
}
, image_url: '
{
image_url
}
' }}"
)
image_embeds
=
self
.
image_processor
(
images
=
image
,
return_tensors
=
"pt"
)
with
torch
.
no_grad
():
...
...
@@ -60,22 +118,56 @@ class EncodeWorker:
vision_outputs
=
self
.
vision_model
.
vision_tower
(
image_embeds
[
"pixel_values"
].
to
(
self
.
vision_model
.
device
)
)
logger
.
debug
(
"Vision model completed."
)
embeddings
=
vision_outputs
.
last_hidden_state
embeddings
=
self
.
vision_model
.
multi_modal_projector
(
embeddings
)
logger
.
debug
(
f
"Embeddings: {{ shape:
{
embeddings
.
shape
}
, dtype:
{
embeddings
.
dtype
}
, device:
{
embeddings
.
device
}
, ptr:
{
embeddings
.
data_ptr
()
}
, elements: {{ count:
{
embeddings
.
numel
()
}
, size:
{
embeddings
.
element_size
()
}
}} }}."
)
if
request
.
serialized_request
is
None
:
logger
.
error
(
f
"Request serialized_request is None for request: {{ id:
{
request_id
}
, image_url: '
{
image_url
}
' }}."
)
# Create a descriptor for the embeddings, this will register the memory with the connector (and the NIXL runtime).
descriptor
=
connect
.
Descriptor
(
embeddings
)
# Create a write operation using the serialized request and the descriptor.
# This will begin the RDMA transfer of the embeddings to the remote worker.
write_op
=
await
self
.
_connector
.
begin_write
(
descriptor
,
request
.
serialized_request
,
)
# Await for the write operation to complete.
# This will block until the data has been written to the remote worker or an error occurs.
await
write_op
.
wait_for_completion
()
image_features
=
vision_outputs
.
last_hidden_state
image_features
=
self
.
vision_model
.
multi_modal_projector
(
image_features
)
yield
EncodeResponse
(
image_features
=
image_features
.
tolist
()
request_id
=
request
.
request_id
,
).
model_dump_json
()
@
async_on_start
()
async
def
on_start
(
self
):
logger
.
info
(
"Startup started."
)
# Create and initialize a dynamo connector for this worker.
# We'll needs this to move data between this worker and remote workers efficiently.
self
.
_connector
=
connect
.
Connector
()
await
self
.
_connector
.
initialize
()
logger
.
info
(
"Startup completed."
)
def
open_image
(
self
,
image
:
str
)
->
Image
.
Image
:
# TODO: Have a seperate field for url and non url - and avoid auto detection
try
:
# Acquire the image and convert it to the format (RGB) the image processor model expects.
if
image
.
startswith
(
"http"
)
or
image
.
startswith
(
"https"
):
response
=
requests
.
get
(
image
)
image_data
=
Image
.
open
(
BytesIO
(
response
.
content
)).
convert
(
"RGB"
)
else
:
image_data
=
Image
.
open
(
image
).
convert
(
"RGB"
)
return
image_data
except
Exception
as
e
:
logger
.
error
(
f
"Error opening image:
{
e
}
"
)
raise
e
return
image_data
examples/multimodal/components/prefill_worker.py
View file @
75e774d4
...
...
@@ -20,8 +20,9 @@ import os
import
signal
import
sys
import
connect
import
torch
from
components.encode_worker
import
EncodeWorker
from
components.encode_worker
import
Vllm
EncodeWorker
from
pydantic
import
BaseModel
from
utils.logging
import
check_required_workers
from
utils.nixl
import
NixlMetadataStore
...
...
@@ -38,6 +39,11 @@ from dynamo.sdk import async_on_start, depends, dynamo_context, endpoint, servic
logger
=
logging
.
getLogger
(
__name__
)
# Constants for the shape and dtype of the embeddings tensor.
EMBEDDINGS_SHAPE
=
(
1
,
577
,
4096
)
EMBEDDINGS_DTYPE
=
torch
.
float16
EMBEDDINGS_DEVICE
=
"cuda"
class
RequestType
(
BaseModel
):
text
:
str
...
...
@@ -50,8 +56,8 @@ class RequestType(BaseModel):
resources
=
{
"gpu"
:
1
,
"cpu"
:
"10"
,
"memory"
:
"20Gi"
},
workers
=
1
,
)
class
PrefillWorker
:
encode_worker
=
depends
(
EncodeWorker
)
class
Vllm
PrefillWorker
:
encode_worker
=
depends
(
Vllm
EncodeWorker
)
def
__init__
(
self
):
class_name
=
self
.
__class__
.
__name__
...
...
@@ -95,7 +101,7 @@ class PrefillWorker:
raise
RuntimeError
(
"Failed to initialize engine client"
)
runtime
=
dynamo_context
[
"runtime"
]
enc_comp_ns
,
enc_comp_name
=
EncodeWorker
.
dynamo_address
()
# type: ignore
enc_comp_ns
,
enc_comp_name
=
Vllm
EncodeWorker
.
dynamo_address
()
# type: ignore
self
.
encode_worker_client
=
(
await
runtime
.
namespace
(
enc_comp_ns
)
.
component
(
enc_comp_name
)
...
...
@@ -103,6 +109,20 @@ class PrefillWorker:
.
client
()
)
self
.
_connector
=
connect
.
Connector
(
runtime
=
runtime
,
namespace
=
enc_comp_ns
)
await
self
.
_connector
.
initialize
()
# Create a longer-lived buffer for receiving the image embeddings.
embeddings
=
torch
.
empty
(
EMBEDDINGS_SHAPE
,
dtype
=
EMBEDDINGS_DTYPE
,
device
=
EMBEDDINGS_DEVICE
,
)
descriptor
=
connect
.
Descriptor
(
embeddings
)
# Register the descriptor w/ NIXL (this is optional, if not done here the connect subsytem will take care of this automatically).
descriptor
.
register_memory
(
self
.
_connector
)
self
.
_embeddings_descriptor
=
(
embeddings
,
descriptor
)
await
check_required_workers
(
self
.
encode_worker_client
,
self
.
min_workers
)
metadata
=
self
.
engine_client
.
nixl_metadata
...
...
@@ -119,19 +139,19 @@ class PrefillWorker:
sys
.
exit
(
1
)
task
.
add_done_callback
(
prefill_queue_handler_cb
)
logger
.
info
(
"
PrefillWorker initialized
"
)
logger
.
info
(
"
Initialization complete.
"
)
def
shutdown_vllm_engine
(
self
,
signum
,
frame
):
"""Shutdown the background loop"""
logger
.
info
(
f
"
Receiv
ed signal
{
signum
}
, shutting down
"
)
logger
.
info
(
f
"
Shutdown start
ed
,
signal
{
signum
}
received.
"
)
loop
=
asyncio
.
get_event_loop
()
try
:
self
.
engine_client
.
close
()
logger
.
info
(
"PrefillWorker shutdown complete"
)
except
Exception
as
e
:
logger
.
error
(
f
"Error during shutdown:
{
e
}
"
)
finally
:
loop
.
stop
()
logger
.
info
(
"Shutdown complete."
)
async
def
prefill_queue_handler
(
self
):
logger
.
info
(
"Prefill queue handler entered"
)
...
...
@@ -166,62 +186,88 @@ class PrefillWorker:
if
request
.
multimodal_data_source
[
"image_url"
]
is
None
:
raise
ValueError
(
"No image url provided for prefill request"
)
encode_generator
=
await
self
.
encode_worker_client
.
round_robin
(
EncodeRequest
(
image_url
=
request
.
multimodal_data_source
[
"image_url"
],
).
model_dump_json
()
request_id
=
request
.
request_id
engine_id
=
request
.
engine_id
image_url
=
request
.
multimodal_data_source
[
"image_url"
]
logger
.
info
(
f
"Received prefill request {{ id:
{
request_id
}
, engine_id:
{
engine_id
}
, image_url: '
{
image_url
}
' }}."
)
async
for
encode_response
in
encode_generator
:
encode_output
=
EncodeResponse
.
model_validate_json
(
encode_response
.
data
())
image_features
=
torch
.
tensor
(
encode_output
.
image_features
,
device
=
"cpu"
,
dtype
=
torch
.
float16
# Extract the pre-allocated, reusable image embeddings tensor and its descriptor.
# Doing this avoids unnessesary memory de/registration with NIXL.
embeddings
,
descriptor
=
self
.
_embeddings_descriptor
# Create a new writable operation from the descriptor.
with
self
.
_connector
.
create_writable
(
descriptor
)
as
writable
:
# Extract serialized metadata about the operation from the writable operation,
# and use it to create a new EncodeRequest.
encode_generator
=
await
self
.
encode_worker_client
.
round_robin
(
EncodeRequest
(
request_id
=
request_id
,
image_url
=
image_url
,
serialized_request
=
writable
.
to_serialized
(),
).
model_dump_json
()
)
async
for
encode_response
in
encode_generator
:
encode_output
=
EncodeResponse
.
model_validate_json
(
encode_response
.
data
(),
)
logger
.
debug
(
f
"Received response: {{ id:
{
encode_output
.
request_id
}
}}."
)
sampling_params
=
request
.
sampling_params
sampling_params
.
max_tokens
=
1
sampling_params
.
min_tokens
=
1
# Wait for the write operation to complete.
# This will block until the write operation is complete.
# This await should be a no-op since we've already received a response from the encode worker.
await
writable
.
wait_for_completion
()
# At this point, the `embeddings` tensor is filled with the image embeddings from the remote encode worker.
remote_prefill_params
=
RemotePrefillParams
(
is_remote_decode
=
True
,
decode_block_ids
=
request
.
block_ids
,
decode_engine_id
=
request
.
engine_id
,
decode_computed_block_ids
=
request
.
computed_block_ids
,
)
sampling_params
=
request
.
sampling_params
sampling_params
.
max_tokens
=
1
sampling_params
.
min_tokens
=
1
# TODO check if metadata has changed
# and reload - currently only loading once
if
request
.
engine_id
not
in
self
.
_loaded_metadata
:
remote_metadata
=
await
self
.
_metadata_store
.
get
(
request
.
engine_id
)
await
self
.
engine_client
.
add_remote_nixl_metadata
(
remote_metadata
)
logger
.
info
(
f
"Loaded nixl metadata from engine
{
request
.
engine_id
}
into "
f
"engine
{
self
.
engine_client
.
nixl_metadata
.
engine_id
}
"
remote_prefill_params
=
RemotePrefillParams
(
is_remote_decode
=
True
,
decode_block_ids
=
request
.
block_ids
,
decode_engine_id
=
engine_id
,
decode_computed_block_ids
=
request
.
computed_block_ids
,
)
# TODO check if metadata has changed
# and reload - currently only loading once
if
engine_id
not
in
self
.
_loaded_metadata
:
remote_metadata
=
await
self
.
_metadata_store
.
get
(
request
.
engine_id
)
await
self
.
engine_client
.
add_remote_nixl_metadata
(
remote_metadata
)
logger
.
info
(
f
"Loaded nixl metadata from engine
{
engine_id
}
into "
f
"engine
{
self
.
engine_client
.
nixl_metadata
.
engine_id
}
"
)
self
.
_loaded_metadata
.
add
(
engine_id
)
# To make sure the decode worker can pre-allocate the memory with the correct size for the prefill worker to transfer the kv cache,
# some placeholder dummy tokens were inserted based on the embedding size in the worker.py.
# The structure of the prompt is "\nUSER: <image> <dummy_tokens>\n<user_prompt>\nASSISTANT:", need to remove the dummy tokens after the image token.
IMAGE_TOKEN_ID
=
32000
embedding_size
=
embeddings
.
shape
[
1
]
padding_size
=
embedding_size
-
1
image_token_index
=
request
.
prompt_token_ids
.
index
(
IMAGE_TOKEN_ID
)
dummy_token_index
=
image_token_index
+
1
prompt_token_ids
=
(
request
.
prompt_token_ids
[:
dummy_token_index
]
+
request
.
prompt_token_ids
[
dummy_token_index
+
padding_size
:]
)
self
.
_loaded_metadata
.
add
(
request
.
engine_id
)
# To make sure the decode worker can pre-allocate the memory with the correct size for the prefill worker to transfer the kv cache,
# some placeholder dummy tokens were inserted based on the embedding size in the worker.py.
# The structure of the prompt is "\nUSER: <image> <dummy_tokens>\n<user_prompt>\nASSISTANT:", need to remove the dummy tokens after the image token.
IMAGE_TOKEN_ID
=
32000
embedding_size
=
image_features
.
shape
[
1
]
padding_size
=
embedding_size
-
1
image_token_index
=
request
.
prompt_token_ids
.
index
(
IMAGE_TOKEN_ID
)
dummy_token_index
=
image_token_index
+
1
prompt_token_ids
=
(
request
.
prompt_token_ids
[:
dummy_token_index
]
+
request
.
prompt_token_ids
[
dummy_token_index
+
padding_size
:]
)
async
for
_
in
self
.
engine_client
.
generate
(
request_id
=
request
.
request
_id
,
prompt
=
TokensPrompt
(
prompt_token_ids
=
prompt_token_ids
,
multi_modal_data
=
{
"image"
:
image_feature
s
},
),
sampling_params
=
sampling_params
,
remote_prefill_params
=
remote_prefill_params
,
):
yield
async
for
_
in
self
.
engine_client
.
generate
(
request_id
=
request_id
,
prompt
=
TokensPrompt
(
prompt_token_ids
=
prompt_token_ids
,
multi_modal_data
=
{
"image"
:
embedding
s
},
),
sampling_params
=
sampling_params
,
remote_prefill_params
=
remote_prefill_params
,
):
yield
@
endpoint
()
async
def
mock
(
self
,
req
:
RequestType
):
...
...
examples/multimodal/components/processor.py
View file @
75e774d4
...
...
@@ -19,7 +19,7 @@ import uuid
from
enum
import
Enum
from
typing
import
AsyncIterator
,
Tuple
,
Union
from
components.worker
import
VllmWorker
from
components.
decode_
worker
import
Vllm
Decode
Worker
from
transformers
import
AutoTokenizer
from
utils.chat_processor
import
ChatProcessor
,
CompletionsProcessor
,
ProcessMixIn
from
utils.logging
import
check_required_workers
...
...
@@ -53,7 +53,7 @@ class Processor(ProcessMixIn):
vLLM pre and post processing
"""
worker
=
depends
(
VllmWorker
)
worker
=
depends
(
Vllm
Decode
Worker
)
def
__init__
(
self
):
class_name
=
self
.
__class__
.
__name__
...
...
@@ -83,7 +83,7 @@ class Processor(ProcessMixIn):
@
async_on_start
async
def
async_init
(
self
):
runtime
=
dynamo_context
[
"runtime"
]
comp_ns
,
comp_name
=
VllmWorker
.
dynamo_address
()
# type: ignore
comp_ns
,
comp_name
=
Vllm
Decode
Worker
.
dynamo_address
()
# type: ignore
self
.
worker_client
=
(
await
runtime
.
namespace
(
comp_ns
)
.
component
(
comp_name
)
...
...
examples/multimodal/configs/agg.yaml
View file @
75e774d4
...
...
@@ -21,7 +21,7 @@ Processor:
router
:
round-robin
common-configs
:
[
model
,
block-size
,
max-model-len
]
VllmWorker
:
Vllm
Decode
Worker
:
enforce-eager
:
true
max-num-batched-tokens
:
16384
enable-prefix-caching
:
true
...
...
@@ -33,7 +33,7 @@ VllmWorker:
gpu
:
1
common-configs
:
[
model
,
block-size
,
max-model-len
]
EncodeWorker
:
Vllm
EncodeWorker
:
tensor-parallel-size
:
1
router
:
random
ServiceArgs
:
...
...
examples/multimodal/configs/disagg.yaml
View file @
75e774d4
...
...
@@ -22,7 +22,7 @@ Processor:
router
:
round-robin
common-configs
:
[
model
,
block-size
]
VllmWorker
:
Vllm
Decode
Worker
:
remote-prefill
:
true
conditional-disagg
:
true
max-local-prefill-length
:
10
...
...
@@ -33,7 +33,7 @@ VllmWorker:
gpu
:
1
common-configs
:
[
model
,
block-size
,
max-model-len
,
kv-transfer-config
]
PrefillWorker
:
Vllm
PrefillWorker
:
max-num-batched-tokens
:
16384
ServiceArgs
:
workers
:
1
...
...
@@ -41,7 +41,7 @@ PrefillWorker:
gpu
:
1
common-configs
:
[
model
,
block-size
,
max-model-len
,
kv-transfer-config
]
EncodeWorker
:
Vllm
EncodeWorker
:
tensor-parallel-size
:
1
router
:
random
ServiceArgs
:
...
...
examples/multimodal/connect/__init__.py
0 → 100644
View file @
75e774d4
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from
__future__
import
annotations
import
asyncio
import
logging
import
socket
import
uuid
import
zlib
from
abc
import
ABC
,
abstractmethod
from
enum
import
IntEnum
from
functools
import
cached_property
from
typing
import
Any
,
List
,
Optional
import
nixl._api
as
nixl_api
import
nixl._bindings
as
nixl_bindings
import
torch
from
pydantic
import
BaseModel
,
ConfigDict
,
field_validator
from
dynamo.runtime
import
DistributedRuntime
from
dynamo.sdk
import
dynamo_context
logger
=
logging
.
getLogger
(
__name__
)
try
:
import
cupy
as
array_module
from
cupy_backends.cuda.api.runtime
import
CUDARuntimeError
logger
.
info
(
"Utilizing cupy to enable GPU acceleration."
)
except
ImportError
:
try
:
import
numpy
as
array_module
logger
.
warning
(
"Failed to load cupy for GPU acceleration, utilizing numpy to provide CPU based operations."
)
except
ImportError
as
e
:
raise
ImportError
(
"Numpy or cupy must be installed to use this module."
)
from
e
class
AbstractOperation
(
ABC
):
"""
Abstract base class for awaitable NIXL based RDMA operations.
"""
def
__init__
(
self
,
connector
:
Connector
,
operation_kind
:
OperationKind
,
local_descriptors
:
Descriptor
|
list
[
Descriptor
],
remote_descriptors
:
Optional
[
Descriptor
|
list
[
Descriptor
]],
notification_key
:
Optional
[
str
],
)
->
None
:
if
not
isinstance
(
connector
,
Connector
):
raise
TypeError
(
"Argument `connector` must be `dynamo.connect.Connector`."
)
if
operation_kind
is
not
OperationKind
.
READ
and
operation_kind
is
not
OperationKind
.
WRITE
:
raise
ValueError
(
"Argument `operation_kind` must be either `READ` or `WRITE`."
)
if
not
(
isinstance
(
local_descriptors
,
(
Descriptor
,
list
))
or
(
isinstance
(
local_descriptors
,
list
)
and
all
(
isinstance
(
d
,
Descriptor
)
for
d
in
local_descriptors
))
):
raise
TypeError
(
"Argument `local_descriptors` must be `dynamo.connect.Descriptor` or `list[dynamo.connect.Descriptor]`."
)
if
(
remote_descriptors
is
not
None
and
not
(
isinstance
(
remote_descriptors
,
Descriptor
)
or
(
isinstance
(
remote_descriptors
,
list
)
and
all
(
isinstance
(
d
,
Descriptor
)
for
d
in
remote_descriptors
))
)
):
raise
TypeError
(
"Argument `remote_descriptors` must be dynamo.connect.Descriptor`, `list[dynamo.connect.Descriptor]`, or `None`."
)
if
isinstance
(
local_descriptors
,
list
)
and
len
(
local_descriptors
)
==
0
:
raise
ValueError
(
"Argument `local_descriptors` must not be an empty list."
)
if
(
remote_descriptors
is
not
None
and
isinstance
(
remote_descriptors
,
list
)
and
len
(
remote_descriptors
)
==
0
):
raise
ValueError
(
"Argument `remote_descriptors` must not be an empty list."
)
notification_key
=
str
(
uuid
.
uuid4
())
if
notification_key
is
None
else
notification_key
if
not
isinstance
(
notification_key
,
str
):
raise
TypeError
(
"Argument `notification_key` must be `str` or `None`."
)
if
len
(
notification_key
)
==
0
:
raise
ValueError
(
"Argument `notification_key` must not be an empty string."
)
self
.
_notification_key
:
str
=
""
if
notification_key
is
None
else
notification_key
self
.
_connector
:
Connector
=
connector
self
.
_operation_kind
:
OperationKind
=
operation_kind
self
.
_local_descriptors
:
Descriptor
|
list
[
Descriptor
]
=
local_descriptors
self
.
_local_dlist
:
Optional
[
list
[
tuple
[
int
,
int
,
int
]]]
=
None
self
.
_local_memtype
:
DeviceKind
=
DeviceKind
.
UNSPECIFIED
self
.
_remote_descriptors
:
Optional
[
Descriptor
|
list
[
Descriptor
]]
=
None
if
remote_descriptors
is
None
else
remote_descriptors
self
.
_remote_dlist
:
Optional
[
list
[
tuple
[
int
,
int
,
int
]]]
=
None
self
.
_remote_memtype
:
DeviceKind
=
DeviceKind
.
UNSPECIFIED
# Register local descriptors with NIXL.
# Note: Only local descriptors should be registered with NIXL,
if
isinstance
(
local_descriptors
,
list
):
for
d
in
local_descriptors
:
d
.
register_memory
(
self
.
_connector
)
else
:
local_descriptors
.
register_memory
(
self
.
_connector
)
# Record local descriptors.
memtype
,
dtlist
=
self
.
_create_dlist
(
local_descriptors
)
self
.
_local_dlist
=
dtlist
self
.
_local_memtype
=
memtype
# Record remote descriptors when provided.
if
remote_descriptors
is
not
None
:
memtype
,
dtlist
=
self
.
_create_dlist
(
remote_descriptors
)
self
.
_remote_dlist
=
dtlist
self
.
_remote_memtype
=
memtype
def
__del__
(
self
)
->
None
:
self
.
_release
()
def
__enter__
(
self
)
->
AbstractOperation
:
return
self
def
__exit__
(
self
,
exc_type
:
Any
,
exc_value
:
Any
,
traceback
:
Any
)
->
None
:
self
.
_release
()
def
_release
(
self
)
->
None
:
"""
Private method to release resources. Only to be called by `self`.
"""
pass
@
property
def
connector
(
self
)
->
Connector
:
"""
Gets the local associated with this operation.
"""
return
self
.
_connector
@
property
def
operation_kind
(
self
)
->
OperationKind
:
"""
Gets the kind of operation.
"""
return
self
.
_operation_kind
@
abstractmethod
async
def
wait_for_completion
(
self
)
->
None
:
"""
Blocks the caller asynchronously until the operation has completed.
"""
raise
NotImplementedError
(
"Abstract method not implemented by derived class."
)
# Private Methods
def
_create_dlist
(
self
,
descriptors
:
Descriptor
|
list
[
Descriptor
],
)
->
tuple
[
DeviceKind
,
list
[
tuple
[
int
,
int
,
int
]]]:
"""
Helper function to create a list of tuples (ptr, size, device) from descriptors.
"""
dlist
:
list
[
tuple
[
int
,
int
,
int
]]
=
[]
memtype
:
DeviceKind
=
DeviceKind
.
UNSPECIFIED
if
isinstance
(
descriptors
,
list
):
memtype
=
descriptors
[
0
].
device
.
kind
for
desc
in
descriptors
:
if
memtype
!=
desc
.
device
.
kind
:
raise
ValueError
(
"All local descriptors must have the same memory type."
)
dlist
.
append
((
desc
.
ptr
,
desc
.
size
,
desc
.
device
.
id
))
else
:
memtype
=
descriptors
.
device
.
kind
dlist
.
append
((
descriptors
.
ptr
,
descriptors
.
size
,
descriptors
.
device
.
id
))
return
(
memtype
,
dlist
)
class
ActiveOperation
(
AbstractOperation
):
"""
Abstract class for active operations that initiates a NIXL based RDMA transfer based `SerializedRequest`
provided by the remote worker's corresponding `PassiveOperation`.
"""
def
__init__
(
self
,
remote
:
Remote
,
operation_kind
:
OperationKind
,
local_descriptors
:
Descriptor
|
list
[
Descriptor
],
remote_descriptors
:
Descriptor
|
list
[
Descriptor
],
notification_key
:
str
,
)
->
None
:
if
not
isinstance
(
remote
,
Remote
)
or
remote
.
_connector
is
None
:
raise
TypeError
(
"Argument `remote` must be valid `dynamo.connect.RemoteAgent`."
)
if
not
isinstance
(
operation_kind
,
OperationKind
):
raise
TypeError
(
"Argument `operation_kind` must `dynamo.connect.OperationKind`."
)
if
operation_kind
is
not
OperationKind
.
READ
and
operation_kind
is
not
OperationKind
.
WRITE
:
raise
ValueError
(
"Argument `operation_kind` must be either `READ` or `WRITE`."
)
if
not
(
isinstance
(
local_descriptors
,
Descriptor
)
or
(
isinstance
(
local_descriptors
,
list
)
and
all
(
isinstance
(
d
,
Descriptor
)
for
d
in
local_descriptors
))
):
raise
TypeError
(
"Argument `local_descriptors` must be `dynamo.connect.Descriptor` or `list[dynamo.connect.Descriptor]`."
)
if
not
(
isinstance
(
remote_descriptors
,
Descriptor
)
or
(
isinstance
(
remote_descriptors
,
list
)
and
all
(
isinstance
(
d
,
Descriptor
)
for
d
in
remote_descriptors
))
):
raise
TypeError
(
"Argument `remote_descriptors` must be `dynamo.connect.Descriptor` or `list[dynamo.connect.Descriptor]`."
)
# Unpack single descriptors from lists if they are provided as single descriptors.
if
isinstance
(
local_descriptors
,
list
)
and
len
(
local_descriptors
)
==
1
:
local_descriptors
=
local_descriptors
[
0
]
if
isinstance
(
remote_descriptors
,
list
)
and
len
(
remote_descriptors
)
==
1
:
remote_descriptors
=
remote_descriptors
[
0
]
if
(
isinstance
(
local_descriptors
,
list
)
and
isinstance
(
remote_descriptors
,
list
)
and
len
(
local_descriptors
)
!=
len
(
remote_descriptors
)):
raise
ValueError
(
"When `local_descriptors` and `remote_descriptors` are lists, they must have the same length."
)
elif
isinstance
(
local_descriptors
,
list
)
!=
isinstance
(
remote_descriptors
,
list
):
raise
ValueError
(
"Both `local_descriptors` and `remote_descriptors` must be either lists or single descriptors."
)
if
not
isinstance
(
notification_key
,
str
):
raise
TypeError
(
"Argument `notification_key` must be `str`."
)
if
len
(
notification_key
)
==
0
:
raise
ValueError
(
"Argument `notification_key` must not be an empty string."
)
self
.
_remote
=
remote
self
.
_status
=
OperationStatus
.
UNINTIALIZED
super
().
__init__
(
remote
.
connector
,
operation_kind
,
local_descriptors
,
remote_descriptors
,
notification_key
)
# Quick check to ensure remote descriptors are not None to make static analysis happy.
if
self
.
_local_dlist
is
None
or
self
.
_remote_dlist
is
None
:
raise
RuntimeError
(
"NIXL descriptor list(s) not bound to operation."
)
self
.
_local_xfer_descs
:
Optional
[
nixl_bindings
.
nixlXferDList
]
=
None
self
.
_remote_xfer_descs
:
Optional
[
nixl_bindings
.
nixlXferDList
]
=
None
self
.
_xfer_hndl
:
Optional
[
nixl_api
.
nixl_xfer_handle
]
=
None
self
.
_local_xfer_descs
=
self
.
_connector
.
_nixl
.
get_xfer_descs
(
descs
=
self
.
_local_dlist
,
mem_type
=
str
(
self
.
_local_memtype
),
)
logger
.
debug
(
f
"Created local NIXL xfer descs:
{
self
.
_local_xfer_descs
}
"
)
self
.
_remote_xfer_descs
=
self
.
_connector
.
_nixl
.
get_xfer_descs
(
descs
=
self
.
_remote_dlist
,
mem_type
=
str
(
self
.
_remote_memtype
),
)
logger
.
debug
(
f
"Created remote NIXL xfer descs:
{
self
.
_remote_xfer_descs
}
"
)
self
.
_xfer_hndl
=
self
.
_connector
.
_nixl
.
initialize_xfer
(
operation
=
str
(
operation_kind
),
local_descs
=
self
.
_local_xfer_descs
,
remote_descs
=
self
.
_remote_xfer_descs
,
remote_agent
=
self
.
_remote
.
name
,
notif_msg
=
self
.
_notification_key
.
encode
(
"utf-8"
),
)
logger
.
debug
(
f
"Created NIXL transfer handle:
{
self
.
_xfer_hndl
}
"
)
def
__del__
(
self
)
->
None
:
super
().
__del__
()
self
.
_release
()
def
__enter__
(
self
)
->
ActiveOperation
:
super
().
__enter__
()
return
self
def
__exit__
(
self
,
exc_type
:
Any
,
exc_value
:
Any
,
traceback
:
Any
)
->
None
:
match
self
.
status
:
case
OperationStatus
.
IN_PROGRESS
|
OperationStatus
.
INITIALIZED
:
self
.
_status
=
OperationStatus
.
CANCELLED
self
.
_release
()
def
__repr__
(
self
)
->
str
:
return
str
(
f
"
{
self
.
__class__
.
__name__
}
("
f
"operation_kind=
{
self
.
_operation_kind
}
, "
f
"local_descriptors=
{
self
.
_local_descriptors
}
, "
f
"remote_descriptors=
{
self
.
_remote_descriptors
}
, "
f
"notification_key='
{
self
.
_notification_key
}
', "
f
"remote='
{
self
.
_remote
.
name
}
', "
f
"status='
{
self
.
_status
}
'"
f
")"
)
def
_release
(
self
)
->
None
:
"""
Private method to release resources.
"""
error
:
Optional
[
Exception
]
=
None
if
self
.
_xfer_hndl
is
not
None
:
try
:
logger
.
debug
(
f
"NIXL transfer handle
{
self
.
_xfer_hndl
}
released."
)
self
.
_connector
.
_nixl
.
release_xfer_handle
(
self
.
_xfer_hndl
)
except
Exception
as
e
:
logger
.
error
(
f
"Failed to release resources:
{
e
}
"
)
error
=
e
finally
:
self
.
_xfer_hndl
=
None
try
:
super
().
_release
()
except
Exception
as
e
:
logger
.
error
(
f
"Failed to release WaitableOperation resources:
{
e
}
"
)
if
error
is
not
None
:
e
.
__cause__
=
error
error
=
e
if
error
is
not
None
:
raise
error
def
_cancel_
(
self
)
->
None
:
if
self
.
_xfer_hndl
is
None
:
return
if
self
.
status
==
OperationStatus
.
ERRORED
:
raise
RuntimeError
(
"Operation is errored, unable to cancel the operation."
)
logger
.
info
(
f
"Cancellation requested for operation {{ kind=
{
self
.
_operation_kind
}
, remote='
{
self
.
_remote
.
name
}
', status=
{
self
.
_status
}
}}."
)
# NIXL will cancel the transfer if it is in progress when the handle is released.
self
.
_connector
.
_nixl
.
release_xfer_handle
(
self
.
_xfer_hndl
)
self
.
_status
=
OperationStatus
.
CANCELLED
self
.
_xfer_hndl
=
None
async
def
_wait_for_completion_
(
self
)
->
None
:
# Loop until the operation is no longer in progress (or "initalized"),
# yielding control to the event loop to allow other operations to run.
iteration_count
=
0
while
True
:
if
iteration_count
&
10
==
0
:
logger
.
debug
(
f
"Waiting for operation {{ kind=
{
self
.
_operation_kind
}
, remote='
{
self
.
_remote
.
name
}
', duration=
{
iteration_count
/
10
}
s }}."
)
match
self
.
status
:
# "in progress" or "initialized" means the operation is ongoing.
case
OperationStatus
.
INITIALIZED
:
await
asyncio
.
sleep
(
0.1
)
case
OperationStatus
.
IN_PROGRESS
:
await
asyncio
.
sleep
(
0.1
)
# Any other state indicates completion or error.
case
_
:
return
@
abstractmethod
def
cancel
(
self
)
->
None
:
"""
Cancels the operation.
No affect if the operation has already completed or errored, or has been cancelled.
"""
raise
NotImplementedError
(
"Abstract method not implemented by derived class."
)
@
property
def
remote
(
self
)
->
Remote
:
"""
Gets the remote agent associated with this operation.
"""
return
self
.
_remote
@
property
def
status
(
self
)
->
OperationStatus
:
"""
Gets the status of the operation.
"""
# Early return if the operation is already complete, errored, or cancelled.
match
self
.
_status
:
case
OperationStatus
.
COMPLETE
|
OperationStatus
.
ERRORED
|
OperationStatus
.
CANCELLED
:
return
self
.
_status
if
self
.
_xfer_hndl
is
None
:
raise
RuntimeError
(
"NIXL transfer handle is invalid."
)
old_status
=
self
.
_status
if
self
.
_status
==
OperationStatus
.
UNINTIALIZED
:
state
=
self
.
_connector
.
_nixl
.
transfer
(
self
.
_xfer_hndl
,
self
.
_notification_key
.
encode
(
"utf-8"
))
logger
.
debug
(
f
"NIXL reported transfer state:
{
state
}
"
)
if
state
==
"ERR"
:
self
.
_status
=
OperationStatus
.
ERRORED
elif
state
==
"DONE"
:
self
.
_status
=
OperationStatus
.
COMPLETE
else
:
self
.
_status
=
OperationStatus
.
INITIALIZED
else
:
state
=
self
.
_connector
.
_nixl
.
check_xfer_state
(
self
.
_xfer_hndl
)
logger
.
debug
(
f
"NIXL reported transfer state:
{
state
}
"
)
if
state
==
"ERR"
:
self
.
_status
=
OperationStatus
.
ERRORED
elif
state
==
"DONE"
:
self
.
_status
=
OperationStatus
.
COMPLETE
else
:
self
.
_status
=
OperationStatus
.
IN_PROGRESS
if
self
.
_status
!=
old_status
:
logger
.
debug
(
f
"
{
self
.
__class__
.
__name__
}
{{ remote: '
{
self
.
_remote
.
name
}
' status: '
{
old_status
}
' => '
{
self
.
_status
}
' }}."
)
return
self
.
_status
class
Connector
:
"""
Core class for managing the connection between agents in a distributed environment.
Use this class to create readable and writable operations, or read and write data to remote agents.
"""
def
__init__
(
self
,
namespace
:
Optional
[
str
]
=
None
,
runtime
:
Optional
[
DistributedRuntime
]
=
None
,
worker_id
:
Optional
[
str
]
=
None
,
)
->
None
:
"""
Creates a new Connector instance.
Parameters
----------
namespace : Optional[str], optional
Dynamo namespace of the component, defaults to "dynamo" when `None`.
runtime : Optional[DistributedRuntime], optional
Reference the dynamo runtime used by the compenent, attempts to use the current runtime when `None`.
worker_id : Optional[str], optional
Unique identifier of the worker, defaults to a new UUID when `None`.
Raises
------
TypeError
When `namespace` is provied and not of type 'str'.
TypeError
When `runtime` iis provied and not of type `dynamo.runtime.DistributedRuntime`.
TypeError
When `worker_id` is provied and not of type `uuid.UUID`.
"""
namespace
=
"dynamo"
if
namespace
is
None
else
namespace
if
not
isinstance
(
namespace
,
str
):
raise
TypeError
(
"Argument `namespace` must be `str` or `None`."
)
if
dynamo_context
is
not
None
and
"runtime"
in
dynamo_context
:
runtime
=
dynamo_context
[
"runtime"
]
if
runtime
is
None
else
runtime
if
not
isinstance
(
runtime
,
DistributedRuntime
)
or
runtime
is
None
:
raise
TypeError
(
"Argument `runtime` must be `dynamo.runtime.DistributedRuntime` or `None`."
)
worker_id
=
worker_id
if
worker_id
is
not
None
else
str
(
uuid
.
uuid4
())
if
not
isinstance
(
worker_id
,
str
)
or
len
(
worker_id
)
==
0
:
raise
TypeError
(
"Argument `worker_id` must be a non-empty `str` or `None`."
)
self
.
_worker_id
=
worker_id
self
.
_is_initialized
=
False
self
.
_runtime
=
runtime
self
.
_namespace
=
namespace
self
.
_nixl
=
nixl_api
.
nixl_agent
(
self
.
_worker_id
)
self
.
_hostname
=
socket
.
gethostname
()
self
.
_agent_metadata
:
Optional
[
bytes
]
=
None
logger
.
debug
(
f
"Created
{
self
.
__repr__
()
}
."
)
def
__repr__
(
self
)
->
str
:
return
str
(
f
"
{
self
.
__class__
.
__name__
}
("
f
"worker_id='
{
self
.
_worker_id
}
', "
f
"namespace=
{
self
.
_namespace
}
, "
f
"hostname=
{
self
.
_hostname
}
, "
f
"metadata=<
{
0
if
self
.
_agent_metadata
is
None
else
len
(
self
.
_agent_metadata
)
}
bytes>"
")"
)
def
__str__
(
self
)
->
str
:
return
self
.
_worker_id
@
cached_property
def
is_cuda_available
(
self
)
->
bool
:
# Note: cuda.is_avalailable initializes cuda
# and can't be called when forking subprocesses
# care should be taken to only call it within
# subprocesses or use 'spawn'
try
:
return
array_module
.
cuda
is
not
None
and
array_module
.
cuda
.
is_available
()
except
CUDARuntimeError
:
return
False
@
property
def
metadata
(
self
)
->
bytes
:
"""
Get the metadata of the agent.
"""
return
self
.
_nixl
.
get_agent_metadata
()
@
property
def
name
(
self
)
->
str
|
None
:
"""
Get the name of the agent.
"""
return
self
.
_worker_id
@
property
def
namespace
(
self
)
->
str
:
"""
Get the namespace of the local.
"""
return
self
.
_namespace
@
property
def
runtime
(
self
)
->
DistributedRuntime
:
"""
Get the runtime of the local.
"""
if
self
.
_runtime
is
None
:
raise
RuntimeError
(
"Runtime is not set. This Connector was not initialized with a runtime."
)
return
self
.
_runtime
async
def
begin_read
(
self
,
remote_request
:
SerializedRequest
,
local_descriptors
:
Descriptor
|
list
[
Descriptor
],
)
->
ReadOperation
:
"""
Creates a read operation for fulfilling a remote readable operation.
Parameters
----------
remote_request : SerializedRequest
Serialized request from a remote worker that has created a readable operation.
local_descriptors : Descriptor | list[Descriptor]
Local descriptor(s) to receive data from the remote worker described by `remote_request`.
Returns
-------
ReadOperation
Awaitable read operation that can be used to transfer data from a remote agent.
Raises
------
TypeError
When `remote_request` is not of type `SerializedRequest`.
TypeError
When `local_descriptors` is not of type `dynamo.connect.Descriptor` or `list[dynamo.connect.Descriptor]`.
"""
if
remote_request
is
None
or
not
isinstance
(
remote_request
,
SerializedRequest
):
raise
TypeError
(
"Argument `remote_request` must be `SerializedRequest`."
)
if
not
(
isinstance
(
local_descriptors
,
Descriptor
)
or
(
isinstance
(
local_descriptors
,
list
)
and
all
(
isinstance
(
d
,
Descriptor
)
for
d
in
local_descriptors
))
):
raise
TypeError
(
"Argument `local_descriptors` must be `dynamo.connect.Descriptor` or `list[dynamo.connect.Descriptor]`."
)
if
remote_request
.
operation_kind
!=
OperationKind
.
READ
.
value
:
raise
RuntimeError
(
"Cannot create a `dynamo.connect.ReadOperation` to read from a remote `dynamo.connect.WritableOperation`."
)
if
not
self
.
_is_initialized
:
raise
RuntimeError
(
"Connector not initialized. Call `initialize()` before calling this method."
)
op
=
ReadOperation
(
self
,
remote_request
,
local_descriptors
)
return
op
async
def
begin_write
(
self
,
local_descriptors
:
Descriptor
|
list
[
Descriptor
],
remote_request
:
SerializedRequest
,
)
->
WriteOperation
:
"""
Creates a write operation for transferring data to a remote agent.
Parameters
----------
remote_request : SerializedRequest
Serialized request from a remote worker that has created a readable operation.
local_descriptors : Descriptor | list[Descriptor]
Local descriptors of one or more data objects to be transferred to the remote agent.
"""
if
remote_request
is
None
or
not
isinstance
(
remote_request
,
SerializedRequest
):
raise
TypeError
(
"Argument `remote_request` must be `SerializedRequest`."
)
if
not
(
isinstance
(
local_descriptors
,
Descriptor
)
or
(
isinstance
(
local_descriptors
,
list
)
and
all
(
isinstance
(
d
,
Descriptor
)
for
d
in
local_descriptors
))
):
raise
TypeError
(
"Argument `local_descriptors` must be `Descriptor` or `list[Descriptor]`."
)
if
remote_request
.
operation_kind
!=
OperationKind
.
WRITE
:
raise
RuntimeError
(
"Cannot create a `WriteOperation` to write to a remote `ReadableOperation`."
)
if
not
isinstance
(
remote_request
.
nixl_metadata
,
str
):
raise
TypeError
(
"Argument `remote_request.nixl_metadata` must be `str`."
)
if
not
self
.
_is_initialized
:
raise
RuntimeError
(
"Connector not initialized. Call `initialize()` before calling this method."
)
op
=
WriteOperation
(
self
,
local_descriptors
,
remote_request
)
return
op
def
create_readable
(
self
,
local_descriptors
:
Descriptor
|
list
[
Descriptor
],
)
->
ReadableOperation
:
"""
Creates a readable operation for transferring data from a remote agent.
Returns
-------
ReadableOperation
A readable operation that can be used to transfer data from a remote agent.
"""
if
not
self
.
_is_initialized
:
raise
RuntimeError
(
"Connector not initialized. Call `initialize()` before calling this method."
)
op
=
ReadableOperation
(
self
,
local_descriptors
)
return
op
def
create_writable
(
self
,
local_descriptors
:
Descriptor
|
list
[
Descriptor
],
)
->
WritableOperation
:
"""
Creates a writable operation for transferring data to a remote agent.
Returns
-------
WritableOperation
A writable operation that can be used to transfer data to a remote agent.
"""
if
not
self
.
_is_initialized
:
raise
RuntimeError
(
"Connector not initialized. Call `initialize()` before calling this method."
)
op
=
WritableOperation
(
self
,
local_descriptors
)
return
op
async
def
initialize
(
self
)
->
None
:
# Only initialize the connector once.
if
self
.
_is_initialized
:
return
self
.
_is_initialized
=
True
# This method is a no-op for now, in the future it may be used to initialize the connector.
logger
.
debug
(
f
"Initialized Connector {{ name: '
{
self
.
_worker_id
}
', namespace '
{
self
.
_namespace
}
' }} completed."
)
class
Descriptor
:
"""
Memory descriptor that ensures memory is registered w/ NIXL, used for transferring data between workers.
"""
def
__init__
(
self
,
data
:
torch
.
Tensor
|
tuple
[
array_module
.
ndarray
,
Device
|
str
]
|
bytes
|
tuple
[
int
,
int
,
Device
|
str
,
Any
],
)
->
None
:
"""
Memory descriptor for transferring data between agents.
Parameters
----------
data : torch.Tensor | tuple[ndarray, Device|str] | bytes | tuple[int, int, Device|str, Any]
The data to be transferred.
When `torch.Tensor` is provided, the attributes of the tensor will be used to create the descriptor.
When `tuple[ndarray, Device]` is provided, the tuple must contain:
- `ndarray`: The CuPy or NumPy array to be transferred.
- `Device`: Either a `dynamo.connect.Device` or a string representing the device type (e.g., "cuda" or "cpu").
When `bytes` is provided, the pointer and size derived from the bytes object and memory type will be assumed to be CPU.
When `tuple[int, int, Device|str, Any]` is provided, the tuple must contain the following elements:
- `int`: Pointer to the data in memory.
- `int`: Size of the data in bytes.
- `Device`: Either a `dynamo.connect.Device` or a string representing the device type (e.g., "cuda" or "cpu").
- `Any`: Optional reference to the data (e.g., the original tensor or bytes object).
This is useful for keeping a reference to the data in memory, but it is not required.
Raises
------
ValueError
When `data` is `None`.
TypeError
When `data` is not a valid type (i.e., not `torch.Tensor`, `bytes`, or a valid tuple).
TypeError
When `data` is a tuple but the elements are not of the expected types (i.e., [`ndarray`, `Device|str`] OR [`int`, `int`, `Device|str`, `Any`]).
"""
TYPE_ERROR_MESSAGE
=
"Argument `data` must be `torch.Tensor`, `tuple[ndarray, Device|str]`, `bytes`, or `tuple[int, int, Device|str, Any]`."
if
data
is
None
:
raise
ValueError
(
"Argument `data` cannot be `None`."
)
if
not
(
isinstance
(
data
,
torch
.
Tensor
)
or
isinstance
(
data
,
bytes
)
or
isinstance
(
data
,
tuple
)):
raise
TypeError
(
TYPE_ERROR_MESSAGE
)
self
.
_data_device
:
Device
=
Device
(
"cpu"
)
self
.
_data_ptr
:
int
=
0
self
.
_data_ref
:
Optional
[
Any
]
=
None
self
.
_data_size
:
int
=
0
# Member fields for managing NIXL memory registration.
# Note: ONLY local descriptors should be registered with NIXL,
# remote descriptors do not have a valid memory address and registration will fault.
self
.
_connector
:
Optional
[
Connector
]
=
None
self
.
_nixl_hndl
:
Optional
[
nixl_bindings
.
nixlRegDList
]
=
None
# Initially `None` cached serialized descriptor reference, populated when `to_serialized()` is called.
self
.
_serialized
:
Optional
[
SerializedDescriptor
]
=
None
# Data is `torch.Tensor`.
if
isinstance
(
data
,
torch
.
Tensor
):
self
.
_data_ptr
=
data
.
data_ptr
()
self
.
_data_size
=
data
.
numel
()
*
data
.
element_size
()
if
data
.
is_cuda
:
self
.
_data_device
=
Device
((
DeviceKind
.
CUDA
,
data
.
get_device
()))
self
.
_data_ref
=
data
logger
.
debug
(
f
"Created
{
self
.
__repr__
()
}
from `torch.Tensor`."
)
# Data is `tuple[ndarray, Device]`.
elif
(
isinstance
(
data
,
tuple
)
and
len
(
data
)
==
2
and
isinstance
(
data
[
0
],
array_module
.
ndarray
)
and
(
isinstance
(
data
[
1
],
Device
)
or
isinstance
(
data
[
1
],
str
))
):
if
hasattr
(
data
[
0
],
"__array_interface__"
):
self
.
_data_ptr
=
data
[
0
].
__array_interface__
[
"data"
][
0
]
elif
hasattr
(
data
[
0
],
"__cuda_array_interface__"
):
self
.
_data_ptr
=
data
[
0
].
__cuda_array_interface__
[
"data"
][
0
]
else
:
raise
TypeError
(
"Argument `data[0]` must be a `ndarray` with a valid array interface."
)
self
.
_data_size
=
data
[
0
].
nbytes
self
.
_data_device
=
data
[
1
]
if
isinstance
(
data
[
1
],
Device
)
else
Device
(
data
[
1
])
self
.
_data_ref
=
data
[
0
]
logger
.
debug
(
f
"Created
{
self
.
__repr__
()
}
from `tuple[ndarray, Device|str]`."
)
# Data is `bytes`.
elif
isinstance
(
data
,
bytes
):
self
.
_data_ptr
=
id
(
data
)
self
.
_data_size
=
len
(
data
)
self
.
_data_ref
=
data
logger
.
debug
(
f
"Created
{
self
.
__repr__
()
}
from `bytes`."
)
# Data is `tuple[int, int, Device, dtype, tuple, Any]`.
elif
isinstance
(
data
,
tuple
)
and
len
(
data
)
>=
2
and
isinstance
(
data
[
0
],
int
)
and
isinstance
(
data
[
1
],
int
):
if
len
(
data
)
>=
3
and
not
(
isinstance
(
data
[
2
],
Device
)
or
isinstance
(
data
[
2
],
str
)):
raise
TypeError
(
"Argument `data` must be a `tuple[int, int, Device|str, Any]`."
)
self
.
_data_ptr
=
data
[
0
]
self
.
_data_size
=
data
[
1
]
if
len
(
data
)
>=
3
:
self
.
_data_device
=
data
[
2
]
if
isinstance
(
data
[
2
],
Device
)
else
Device
(
data
[
2
])
self
.
_data_ref
=
data
[
3
]
if
len
(
data
)
>=
4
else
None
logger
.
debug
(
f
"Created
{
self
.
__repr__
()
}
from `tuple[int, int, Device|str, Any]`."
)
else
:
raise
TypeError
(
TYPE_ERROR_MESSAGE
)
def
__del__
(
self
)
->
None
:
if
self
.
_nixl_hndl
is
not
None
and
self
.
_connector
is
not
None
:
# Unregister the memory with NIXL.
self
.
_connector
.
_nixl
.
deregister_memory
(
self
.
_nixl_hndl
)
self
.
_nixl_hndl
=
None
if
self
.
_data_ref
is
not
None
:
# Release the reference to the data.
del
self
.
_data_ref
logger
.
debug
(
f
"Deleted
{
self
.
__repr__
()
}
."
)
def
__repr__
(
self
)
->
str
:
return
f
"
{
self
.
__class__
.
__name__
}
(
{
self
}
)"
def
__str__
(
self
)
->
str
:
return
f
"ptr=
{
hex
(
self
.
_data_ptr
)
}
, size=
{
self
.
_data_size
}
, device=
{
self
.
_data_device
}
"
@
property
def
device
(
self
)
->
Device
:
"""
Gets the device the of the descriptor.
"""
return
self
.
_data_device
@
property
def
ptr
(
self
)
->
int
:
"""
Gets the pointer of the descriptor.
"""
return
self
.
_data_ptr
@
property
def
size
(
self
)
->
int
:
"""
Gets the size of the descriptor.
"""
return
self
.
_data_size
@
staticmethod
def
from_serialized
(
serialized
:
SerializedDescriptor
,
)
->
Descriptor
:
"""
Deserializes a `SerializedDescriptor` into a `Descriptor` object.
Parameters
----------
serialized : SerializedDescriptor
The serialized descriptor to deserialize.
Returns
-------
Descriptor
The deserialized descriptor.
"""
if
not
isinstance
(
serialized
,
SerializedDescriptor
):
raise
TypeError
(
"Argument `serialized` must be `SerializedDescriptor`."
)
return
serialized
.
to_descriptor
()
def
register_memory
(
self
,
connector
:
Connector
,
)
->
None
:
"""
Registers the memory of the descriptor with NIXL.
"""
if
not
isinstance
(
connector
,
Connector
):
raise
TypeError
(
"Argument `connector` must be `dynamo.connect.Connector`."
)
if
self
.
_data_ptr
==
0
:
raise
ValueError
(
"Cannot register memory with a null pointer."
)
if
not
(
self
.
_nixl_hndl
is
None
and
self
.
_connector
is
None
):
return
# Register the memory with NIXL.
self
.
_connector
=
connector
if
isinstance
(
self
.
_data_ref
,
torch
.
Tensor
):
self
.
_nixl_hndl
=
connector
.
_nixl
.
register_memory
(
self
.
_data_ref
)
else
:
mem_type
=
str
(
self
.
_data_device
.
kind
)
reg_list
=
[(
self
.
_data_ptr
,
self
.
_data_size
,
self
.
_data_device
.
id
,
mem_type
)]
self
.
_nixl_hndl
=
connector
.
_nixl
.
register_memory
(
reg_list
,
mem_type
)
logger
.
debug
(
f
"Registered
{
self
.
__repr__
()
}
with NIXL."
)
def
to_serialized
(
self
)
->
SerializedDescriptor
:
"""
Serializes the descriptor into a `SerializedDescriptor` object.
"""
if
self
.
_serialized
is
None
:
self
.
_serialized
=
SerializedDescriptor
(
device
=
f
"
{
self
.
_data_device
}
"
,
ptr
=
self
.
_data_ptr
,
size
=
self
.
_data_size
,
)
return
self
.
_serialized
class
Device
:
"""
Represents a device in the system.
"""
def
__init__
(
self
,
metadata
:
str
|
tuple
[
DeviceKind
,
int
],
)
->
None
:
if
metadata
is
None
:
raise
ValueError
(
"Argument `metadata` cannot be `None`."
)
if
isinstance
(
metadata
,
tuple
)
and
len
(
metadata
)
==
2
and
isinstance
(
metadata
[
0
],
DeviceKind
)
and
isinstance
(
metadata
[
1
],
int
):
kind
,
device_id
=
metadata
elif
isinstance
(
metadata
,
str
):
metadata
=
metadata
.
strip
().
lower
()
if
metadata
.
startswith
(
"cuda"
)
or
metadata
.
startswith
(
"gpu"
):
kind
=
DeviceKind
.
CUDA
device_id
=
0
if
metadata
.
find
(
":"
)
==
-
1
else
int
(
metadata
.
split
(
":"
)[
1
])
elif
metadata
.
startswith
(
"cpu"
)
or
metadata
.
startswith
(
"host"
):
kind
=
DeviceKind
.
HOST
device_id
=
0
else
:
raise
ValueError
(
"Argument `metadata` must be in the format 'cuda:<device_id>' or 'cpu'."
)
else
:
raise
TypeError
(
"Argument `metadata` must be a `tuple[MemoryKind, int]` or a `str`."
)
self
.
_device_id
=
device_id
self
.
_kind
=
kind
def
__repr__
(
self
)
->
str
:
return
f
"
{
self
.
__class__
.
__name__
}
(kind=
{
self
.
_kind
}
, id=
{
self
.
_device_id
}
)"
def
__str__
(
self
)
->
str
:
return
f
"
{
self
.
_kind
}
:
{
self
.
_device_id
}
"
if
self
.
_kind
is
DeviceKind
.
CUDA
else
f
"
{
self
.
_kind
}
"
@
property
def
id
(
self
)
->
int
:
"""
Gets the device ID of the device.
"""
return
self
.
_device_id
@
property
def
kind
(
self
)
->
DeviceKind
:
"""
Gets the memory kind of the device.
"""
return
self
.
_kind
class
DeviceKind
(
IntEnum
):
"""
Type of memory a descriptor has been allocated to.
"""
UNSPECIFIED
=
0
HOST
=
1
CUDA
=
2
def
__str__
(
self
)
->
str
:
if
self
==
DeviceKind
.
HOST
:
return
"cpu"
elif
self
==
DeviceKind
.
CUDA
:
return
"cuda"
else
:
return
"<invalid>"
class
OperationKind
(
IntEnum
):
"""
Kind of an operation.
"""
UNSPECIFIED
=
0
READ
=
1
WRITE
=
2
def
__str__
(
self
)
->
str
:
if
self
==
OperationKind
.
READ
:
return
"READ"
elif
self
==
OperationKind
.
WRITE
:
return
"WRITE"
else
:
return
"<invalid>"
class
OperationStatus
(
IntEnum
):
"""
Status of an operation.
"""
UNINTIALIZED
=
0
INITIALIZED
=
1
IN_PROGRESS
=
2
COMPLETE
=
3
CANCELLED
=
4
ERRORED
=
5
def
__str__
(
self
)
->
str
:
if
self
==
OperationStatus
.
INITIALIZED
:
return
"INIT"
elif
self
==
OperationStatus
.
IN_PROGRESS
:
return
"PROC"
elif
self
==
OperationStatus
.
COMPLETE
:
return
"DONE"
elif
self
==
OperationStatus
.
ERRORED
:
return
"ERR"
elif
self
==
OperationStatus
.
CANCELLED
:
return
"STOP"
else
:
return
"<invalid>"
class
PassiveOperation
(
AbstractOperation
):
"""
Abstract class for common functionality of passive operations.
"""
def
__init__
(
self
,
connector
:
Connector
,
operation_kind
:
OperationKind
,
local_descriptors
:
Descriptor
|
list
[
Descriptor
],
)
->
None
:
if
operation_kind
is
not
OperationKind
.
READ
and
operation_kind
is
not
OperationKind
.
WRITE
:
raise
ValueError
(
"Argument `operation_kind` must be either `READ` or `WRITE`."
)
self
.
_status
=
OperationStatus
.
UNINTIALIZED
super
().
__init__
(
connector
,
operation_kind
,
local_descriptors
,
None
,
None
)
self
.
_serialized_request
:
Optional
[
SerializedRequest
]
=
None
self
.
_status
=
OperationStatus
.
INITIALIZED
def
__del__
(
self
)
->
None
:
super
().
__del__
()
def
__enter__
(
self
)
->
AbstractOperation
:
super
().
__enter__
()
return
self
def
__exit__
(
self
,
exc_type
:
Any
,
exc_value
:
Any
,
traceback
:
Any
)
->
None
:
super
().
__exit__
(
exc_type
,
exc_value
,
traceback
)
def
__repr__
(
self
)
->
str
:
return
str
(
f
"
{
self
.
__class__
.
__name__
}
("
f
"operation_kind=
{
self
.
_operation_kind
}
, "
f
"local_descriptors=
{
self
.
_local_descriptors
}
, "
f
"notification_key='
{
self
.
_notification_key
}
', "
f
"status='
{
self
.
_status
}
'"
f
")"
)
async
def
_wait_for_completion_
(
self
)
->
None
:
# Loop until the operation is no longer in progress (or "initalized"),
# yielding control to the event loop to allow other operations to run.
while
True
:
match
self
.
status
:
# "in progress" or "initialized" means the operation is ongoing.
case
OperationStatus
.
INITIALIZED
:
await
asyncio
.
sleep
(
0.1
)
case
OperationStatus
.
IN_PROGRESS
:
await
asyncio
.
sleep
(
0.1
)
# Any other state indicates completion or error.
case
_
:
return
@
property
def
status
(
self
)
->
OperationStatus
:
"""
Gets the status of the operation.
"""
# Early return if the operation is already complete, errored, or cancelled.
match
self
.
_status
:
case
OperationStatus
.
COMPLETE
|
OperationStatus
.
ERRORED
|
OperationStatus
.
CANCELLED
:
return
self
.
_status
old_status
=
self
.
_status
# Query NIXL for any notifications.
notifications
=
self
.
_connector
.
_nixl
.
update_notifs
()
if
isinstance
(
notifications
,
dict
):
remote_state
=
OperationStatus
.
IN_PROGRESS
logger
.
debug
(
f
"NIXL reported notifications:
{
len
(
notifications
)
}
."
)
for
key
,
values
in
notifications
.
items
():
if
not
isinstance
(
values
,
list
):
raise
TypeError
(
f
"Expected `dict[str, list[bytes]]` from NIXL notification query; got
{
type
(
notifications
)
}
."
)
for
value
in
values
:
if
not
isinstance
(
value
,
bytes
):
continue
notification_key
=
value
.
decode
(
"utf-8"
)
# Once we've found the notification key, we know the operation is complete.
if
notification_key
==
self
.
_notification_key
:
remote_state
=
OperationStatus
.
COMPLETE
break
if
remote_state
==
OperationStatus
.
COMPLETE
:
self
.
_status
=
remote_state
logger
.
debug
(
f
"
{
self
.
__class__
.
__name__
}
{{ remote: '
{
self
.
_connector
.
name
}
' status: '
{
old_status
}
' => '
{
self
.
_status
}
' }}."
)
return
self
.
_status
def
to_serialized
(
self
)
->
SerializedRequest
:
"""
Gets the request descriptor for the operation.
"""
if
self
.
_serialized_request
is
None
:
# When we've not yet cached the serialized request, we need to generate one before returning it.
# Handle both cases: multiple and single descriptors.
if
isinstance
(
self
.
_local_descriptors
,
list
):
descriptors
=
[
desc
.
to_serialized
()
for
desc
in
self
.
_local_descriptors
]
else
:
descriptors
=
[
self
.
_local_descriptors
.
to_serialized
()]
original_len
=
len
(
self
.
_connector
.
metadata
)
nixl_metadata
=
self
.
_connector
.
metadata
nixl_metadata
=
zlib
.
compress
(
nixl_metadata
,
level
=
6
)
compressed_len
=
len
(
nixl_metadata
)
logger
.
debug
(
f
"Compressed NIXL metadata from
{
original_len
}
bytes to
{
compressed_len
}
bytes."
)
if
compressed_len
>
original_len
:
logger
.
warning
(
f
"Compressed NIXL metadata is larger than original (
{
compressed_len
}
>
{
original_len
}
)."
)
self
.
_serialized_request
=
SerializedRequest
(
descriptors
=
descriptors
,
nixl_metadata
=
nixl_metadata
.
hex
(),
notification_key
=
self
.
_notification_key
,
operation_kind
=
int
(
self
.
_operation_kind
),
)
return
self
.
_serialized_request
@
abstractmethod
async
def
wait_for_completion
(
self
)
->
None
:
"""
Blocks the caller asynchronously until the operation has completed.
"""
raise
NotImplementedError
(
"Abstract method not implemented by derived class."
)
class
ReadOperation
(
ActiveOperation
):
"""
Operation that initiates an RDMA read operation to transfer data from a remote worker's `ReadableOperation`,
as described by `remote_request`, to local buffers.
"""
def
__init__
(
self
,
connector
:
Connector
,
remote_request
:
SerializedRequest
,
local_descriptors
:
Descriptor
|
list
[
Descriptor
],
)
->
None
:
"""
Creates a new instance of `ReadOperation`, registers `local_descriptors` with NIXL,
and begins an RDMA read operation which will transfer data described by `remote_request`
to `local_descriptors`.
Parameters
----------
connector : Connector
Connector instance to use for the operation.
remote_request : SerializedRequest
Serialized request from the remote worker.
local_descriptors : Descriptor | list[Descriptor]
Local descriptor(s) to to receive the data from the remote agent.
"""
if
not
isinstance
(
connector
,
Connector
):
raise
TypeError
(
"Argument `connector` must be `dynamo.connect.Connector`."
)
if
not
isinstance
(
remote_request
,
SerializedRequest
):
raise
TypeError
(
"Argument `remote_request` must be `dynamo.connect.RequestDescriptor`."
)
if
remote_request
.
operation_kind
!=
OperationKind
.
READ
.
value
:
raise
ValueError
(
"Argument `remote_request` must be of kind `READ`."
)
remote
=
Remote
(
connector
,
remote_request
.
nixl_metadata
)
remote_descriptors
=
remote_request
.
to_descriptors
()
if
not
(
isinstance
(
local_descriptors
,
Descriptor
)
or
(
isinstance
(
local_descriptors
,
list
)
and
all
(
isinstance
(
d
,
Descriptor
)
for
d
in
local_descriptors
))
):
raise
TypeError
(
"Argument `local_descriptors` must be `dynamo.connect.Descriptor`, `list[dynamo.connect.Descriptor]`."
)
super
().
__init__
(
remote
,
OperationKind
.
READ
,
local_descriptors
,
remote_descriptors
,
remote_request
.
notification_key
)
logger
.
debug
(
f
"Created
{
self
.
__repr__
()
}
"
)
def
__del__
(
self
)
->
None
:
super
().
__del__
()
logger
.
debug
(
f
"Deleted
{
self
.
__repr__
()
}
"
)
def
__enter__
(
self
)
->
ReadOperation
:
super
().
__enter__
()
return
self
def
__exit__
(
self
,
exc_type
:
Any
,
exc_value
:
Any
,
traceback
:
Any
)
->
None
:
super
().
__exit__
(
exc_type
,
exc_value
,
traceback
)
def
__repr__
(
self
)
->
str
:
return
super
().
__repr__
()
def
cancel
(
self
)
->
None
:
"""
Cancels the operation.
No affect if the operation has already completed or errored, or been cancelled.
"""
super
().
_cancel_
()
def
results
(
self
)
->
list
[
Descriptor
]:
"""
Gets the results of the operation.
Returns a single descriptor if only one was requested, or a list of descriptors if multiple were requested.
"""
if
self
.
_status
!=
OperationStatus
.
COMPLETE
:
raise
RuntimeError
(
"Operation has not completed yet, cannot get results."
)
return
self
.
_local_descriptors
if
isinstance
(
self
.
_local_descriptors
,
list
)
else
[
self
.
_local_descriptors
]
async
def
wait_for_completion
(
self
)
->
None
:
"""
Blocks the caller asynchronously until the operation has completed.
"""
await
super
().
_wait_for_completion_
()
class
ReadableOperation
(
PassiveOperation
):
"""
Operation that can be awaited until a remote worker has completed a `ReadOperation`.
"""
def
__init__
(
self
,
connector
:
Connector
,
local_descriptors
:
Descriptor
|
list
[
Descriptor
],
)
->
None
:
super
().
__init__
(
connector
,
OperationKind
.
READ
,
local_descriptors
)
logger
.
debug
(
f
"Created
{
self
.
__repr__
()
}
"
)
def
__del__
(
self
)
->
None
:
super
().
__del__
()
logger
.
debug
(
f
"Deleted
{
self
.
__repr__
()
}
"
)
def
__enter__
(
self
)
->
ReadableOperation
:
super
().
__enter__
()
return
self
def
__exit__
(
self
,
exc_type
:
Any
,
exc_value
:
Any
,
traceback
:
Any
)
->
None
:
super
().
__exit__
(
exc_type
,
exc_value
,
traceback
)
def
__repr__
(
self
)
->
str
:
return
super
().
__repr__
()
async
def
wait_for_completion
(
self
)
->
None
:
"""
Blocks the caller asynchronously until the operation has completed.
"""
await
super
().
_wait_for_completion_
()
class
Remote
:
"""
Identifies a remote NIXL agent relative to a local NIXL agent.
"""
def
__init__
(
self
,
connector
:
Connector
,
nixl_metadata
:
bytes
|
str
,
)
->
None
:
if
not
isinstance
(
connector
,
Connector
):
raise
TypeError
(
"Argument `local` must be `dynamo.connect.Connector`."
)
if
not
(
isinstance
(
nixl_metadata
,
bytes
)
or
isinstance
(
nixl_metadata
,
str
)):
raise
TypeError
(
"Argument `nixl_metadata` must be `bytes` or `str`."
)
if
len
(
nixl_metadata
)
==
0
:
raise
ValueError
(
"Argument `nixl_metadata` cannot be empty."
)
self
.
_connector
=
connector
# When `nixl_metadata` is a string, it is assumed to have come from a remote worker
# via a `SerializedRequest` object and therefore can assumed be a hex-encoded, compressed
# representation of the NIXL metadata.
if
isinstance
(
nixl_metadata
,
str
):
# Decode the hex-encoded string into bytes.
nixl_metadata
=
bytes
.
fromhex
(
nixl_metadata
)
# Decompress the NIXL metadata.
nixl_metadata
=
zlib
.
decompress
(
nixl_metadata
)
self
.
_name
=
connector
.
_nixl
.
add_remote_agent
(
nixl_metadata
)
if
isinstance
(
self
.
_name
,
bytes
):
self
.
_name
=
self
.
_name
.
decode
(
"utf-8"
)
logger
.
debug
(
f
"Created
{
self
.
__repr__
()
}
."
)
def
__del__
(
self
)
->
None
:
self
.
_release
()
def
__enter__
(
self
)
->
Remote
:
"""
Context manager entry method. Returns the current instance.
"""
return
self
def
__exit__
(
self
,
exc_type
:
Any
,
exc_value
:
Any
,
traceback
:
Any
)
->
None
:
"""
Context manager exit method. Cleans up the instance.
"""
self
.
_release
()
def
__repr__
(
self
)
->
str
:
return
f
"RemoteAgent(name=
{
self
.
_name
}
, connector=
{
self
.
_connector
.
name
}
)"
def
__str__
(
self
)
->
str
:
return
self
.
_name
def
_release
(
self
)
->
None
:
"""
Private method for releasing NIXL resources. Not intended for public use.
"""
pass
@
property
def
connector
(
self
)
->
Connector
:
"""
Gets the local connector associated with this remote agent.
"""
return
self
.
_connector
@
property
def
name
(
self
)
->
str
:
"""
Gets the name of the remote agent.
"""
return
self
.
_name
class
SerializedDescriptor
(
BaseModel
):
"""
Pydantic serialization type for memory descriptors.
"""
model_config
=
ConfigDict
(
extra
=
"forbid"
,
frozen
=
True
,
arbitrary_types_allowed
=
True
,
)
device
:
str
=
"cpu"
ptr
:
int
=
0
size
:
int
=
0
def
to_descriptor
(
self
)
->
Descriptor
:
"""
Deserialize the serialized descriptor into a `Descriptor` object.
"""
return
Descriptor
(
data
=
(
self
.
ptr
,
self
.
size
,
self
.
device
,
None
))
@
field_validator
(
"device"
)
def
validate_memtype
(
cls
,
v
:
str
)
->
str
:
if
not
isinstance
(
v
,
str
):
raise
TypeError
(
"Argument `device` must be `str`."
)
v
=
v
.
strip
().
lower
()
if
not
(
v
.
startswith
(
"cuda"
)
or
v
==
"cpu"
):
raise
ValueError
(
"Argument `device` must be one of 'cpu' or 'cuda:<device_id>'."
)
return
v
@
field_validator
(
"ptr"
)
def
validate_ptr
(
cls
,
v
:
int
)
->
int
:
if
v
==
0
:
raise
ValueError
(
"Argument `ptr` cannot be zero (aka `null` or `None`)."
)
return
v
@
field_validator
(
"size"
)
def
validate_size
(
cls
,
v
:
int
)
->
int
:
if
v
<
0
:
raise
ValueError
(
"Argument `size` must be an integer greater than or equal to zero."
)
return
v
class
SerializedRequest
(
BaseModel
):
"""
Pydantic serialization type for describing the passive side of a transfer.
"""
model_config
=
ConfigDict
(
extra
=
"forbid"
,
frozen
=
True
,
arbitrary_types_allowed
=
True
,
)
descriptors
:
List
[
SerializedDescriptor
]
=
[]
nixl_metadata
:
str
=
""
notification_key
:
str
=
""
operation_kind
:
int
=
0
def
to_descriptors
(
self
)
->
Descriptor
|
list
[
Descriptor
]:
"""
Deserializes the request descriptor into a `dynamo.connect.Descriptor` or list of `dynamo.connect.Descriptor` objects.
"""
if
len
(
self
.
descriptors
)
==
0
:
raise
ValueError
(
"Request descriptor must contain at least one serialized descriptor."
)
if
len
(
self
.
descriptors
)
==
1
:
return
self
.
descriptors
[
0
].
to_descriptor
()
return
[
item
.
to_descriptor
()
for
item
in
self
.
descriptors
]
@
field_validator
(
"operation_kind"
)
def
validate_operation_kind
(
cls
,
v
:
int
)
->
int
:
if
v
<
1
or
v
>
3
:
raise
TypeError
(
"Argument `operation_kind` must be an integer value of `dynamo.connect.OperationKind`."
)
return
v
class
WritableOperation
(
PassiveOperation
):
"""
Operation which can be awaited until written to by a `WriteOperation` from a remote worker.
"""
def
__init__
(
self
,
connector
:
Connector
,
local_descriptors
:
Descriptor
|
list
[
Descriptor
],
)
->
None
:
"""
Creates a new instance of `WritableOperation`, registers the operation and descriptors w/ NIXL,
and enables an RDMA write operation to occur.
Parameters
----------
connector : Connector
Connector instance to use for the operation.
local_descriptors : Descriptor | list[Descriptor]
Descriptors to receive data from a remote worker.
Raises
TypeError
When `local` is not a `dynamo.connect.Connector`.
TypeError
When `local_descriptors` is not a `dynamo.connect.Descriptor` or `list[dynamo.connect.Descriptor]`.
"""
super
().
__init__
(
connector
,
OperationKind
.
WRITE
,
local_descriptors
)
logger
.
debug
(
f
"Created
{
self
.
__repr__
()
}
"
)
def
__del__
(
self
)
->
None
:
super
().
__del__
()
logger
.
debug
(
f
"Deleted
{
self
.
__repr__
()
}
"
)
def
__enter__
(
self
)
->
WritableOperation
:
super
().
__enter__
()
return
self
def
__exit__
(
self
,
exc_type
:
Any
,
exc_value
:
Any
,
traceback
:
Any
)
->
None
:
super
().
__exit__
(
exc_type
,
exc_value
,
traceback
)
def
__repr__
(
self
)
->
str
:
return
super
().
__repr__
()
async
def
wait_for_completion
(
self
)
->
None
:
"""
Blocks the caller asynchronously until the operation has completed.
"""
await
super
().
_wait_for_completion_
()
class
WriteOperation
(
ActiveOperation
):
"""
Awaitable write operation which initiates an RDMA write operation to a remote worker
which provided a `SerializedRequest` object from a `WritableOperation`.
"""
def
__init__
(
self
,
connector
:
Connector
,
local_descriptors
:
Descriptor
|
list
[
Descriptor
],
remote_request
:
SerializedRequest
,
)
->
None
:
"""
Creates a new instance of `WriteOperation`, registers `local_descriptors` with NIXL,
and begins an RDMA write operation which will transfer from `local_descriptors` to
remote target(s) described by `remote_request`
Parameters
----------
connector : Connector
Connector instance to use for the operation.
local_descriptors : Descriptor | list[Descriptor]
Local descriptor(s) to send from, to the remote agent.
remote_request : SerializedRequest
Serialized request from the remote worker that describes the target(s) to send to.
Raises
TypeError
When `connector` is not a `dynamo.connect.Connector`.
TypeError
When `remote_request` is not a `dynamo.connect.RequestDescriptor`.
ValueError
When `remote_request` is not of kind `WRITE`.
ValueError
When `remote_request.nixl_metadata` is not a non-empty `str`.
TypeError
When `local_descriptors` is not a `dynamo.connect.Descriptor` or `list[dynamo.connect.Descriptor]`.
"""
if
not
isinstance
(
connector
,
Connector
):
raise
TypeError
(
"Argument `connector` must be `dynamo.connect.Connector`."
)
if
not
isinstance
(
remote_request
,
SerializedRequest
):
raise
TypeError
(
"Argument `remote_request` must be `dynamo.connect.RequestDescriptor`."
)
if
remote_request
.
operation_kind
!=
OperationKind
.
WRITE
.
value
:
raise
ValueError
(
"Argument `remote_request` must be of kind `WRITE`."
)
remote
=
Remote
(
connector
,
remote_request
.
nixl_metadata
)
remote_descriptors
=
remote_request
.
to_descriptors
()
super
().
__init__
(
remote
,
OperationKind
.
WRITE
,
local_descriptors
,
remote_descriptors
,
remote_request
.
notification_key
)
logger
.
debug
(
f
"Created
{
self
.
__repr__
()
}
"
)
def
__del__
(
self
)
->
None
:
super
().
__del__
()
logger
.
debug
(
f
"Deleted
{
self
.
__repr__
()
}
"
)
def
__enter__
(
self
)
->
WriteOperation
:
super
().
__enter__
()
return
self
def
__exit__
(
self
,
exc_type
:
Any
,
exc_value
:
Any
,
traceback
:
Any
)
->
None
:
super
().
__exit__
(
exc_type
,
exc_value
,
traceback
)
def
__repr__
(
self
)
->
str
:
return
super
().
__repr__
()
def
cancel
(
self
)
->
None
:
"""
Cancels the operation.
No affect if the operation has already completed or errored, or has been cancelled.
"""
super
().
_cancel_
()
async
def
wait_for_completion
(
self
)
->
None
:
"""
Blocks the caller asynchronously until the operation has completed.
"""
await
super
().
_wait_for_completion_
()
examples/multimodal/graphs/agg.py
View file @
75e774d4
...
...
@@ -6,6 +6,7 @@
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
...
...
@@ -13,9 +14,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from
components.encode_worker
import
EncodeWorker
from
components.decode_worker
import
VllmDecodeWorker
from
components.encode_worker
import
VllmEncodeWorker
from
components.frontend
import
Frontend
from
components.processor
import
Processor
from
components.worker
import
VllmWorker
Frontend
.
link
(
Processor
).
link
(
VllmWorker
).
link
(
EncodeWorker
)
Frontend
.
link
(
Processor
).
link
(
Vllm
Decode
Worker
).
link
(
Vllm
EncodeWorker
)
examples/multimodal/graphs/disagg.py
View file @
75e774d4
...
...
@@ -13,10 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from
components.encode_worker
import
EncodeWorker
from
components.decode_worker
import
VllmDecodeWorker
from
components.encode_worker
import
VllmEncodeWorker
from
components.frontend
import
Frontend
from
components.prefill_worker
import
PrefillWorker
from
components.prefill_worker
import
Vllm
PrefillWorker
from
components.processor
import
Processor
from
components.worker
import
VllmWorker
Frontend
.
link
(
Processor
).
link
(
VllmWorker
).
link
(
PrefillWorker
).
link
(
EncodeWorker
)
Frontend
.
link
(
Processor
).
link
(
VllmDecodeWorker
).
link
(
VllmPrefillWorker
).
link
(
VllmEncodeWorker
)
examples/multimodal/utils/protocol.py
View file @
75e774d4
...
...
@@ -17,6 +17,7 @@
import
json
from
typing
import
Any
,
List
,
Optional
import
connect
import
msgspec
from
pydantic
import
BaseModel
,
ConfigDict
,
field_validator
from
pydantic_core
import
core_schema
...
...
@@ -111,12 +112,13 @@ class EncodeRequest(BaseModel):
model_config
=
ConfigDict
(
arbitrary_types_allowed
=
True
)
image_url
:
str
request_id
:
str
serialized_request
:
Optional
[
connect
.
SerializedRequest
]
=
None
class
EncodeResponse
(
BaseModel
):
model_config
=
ConfigDict
(
arbitrary_types_allowed
=
True
)
image_features
:
List
[
List
[
List
[
float
]]]
request_id
:
str
class
MyRequestOutput
(
BaseModel
):
...
...
@@ -129,7 +131,6 @@ class MyRequestOutput(BaseModel):
"""
model_config
=
ConfigDict
(
arbitrary_types_allowed
=
True
)
request_id
:
str
prompt
:
Optional
[
str
]
=
None
prompt_token_ids
:
Optional
[
List
[
int
]]
=
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment