Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
75e774d4
Unverified
Commit
75e774d4
authored
May 27, 2025
by
J Wyman
Committed by
GitHub
May 27, 2025
Browse files
feat: NIXL Based RDMA Support w/ Multimodal Example (#1060)
parent
9acaa8d1
Changes
11
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
1784 additions
and
112 deletions
+1784
-112
.pre-commit-config.yaml
.pre-commit-config.yaml
+2
-2
examples/multimodal/components/decode_worker.py
examples/multimodal/components/decode_worker.py
+88
-29
examples/multimodal/components/encode_worker.py
examples/multimodal/components/encode_worker.py
+99
-7
examples/multimodal/components/prefill_worker.py
examples/multimodal/components/prefill_worker.py
+102
-56
examples/multimodal/components/processor.py
examples/multimodal/components/processor.py
+3
-3
examples/multimodal/configs/agg.yaml
examples/multimodal/configs/agg.yaml
+2
-2
examples/multimodal/configs/disagg.yaml
examples/multimodal/configs/disagg.yaml
+3
-3
examples/multimodal/connect/__init__.py
examples/multimodal/connect/__init__.py
+1471
-0
examples/multimodal/graphs/agg.py
examples/multimodal/graphs/agg.py
+4
-3
examples/multimodal/graphs/disagg.py
examples/multimodal/graphs/disagg.py
+6
-4
examples/multimodal/utils/protocol.py
examples/multimodal/utils/protocol.py
+4
-3
No files found.
.pre-commit-config.yaml
View file @
75e774d4
...
@@ -13,7 +13,7 @@
...
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
exclude
:
^(src/grpc_generated|.*\.patch$)
exclude
:
^(src/grpc_generated|.*\.patch$
|.*/connect/.*\.py
)
repos
:
repos
:
-
repo
:
https://github.com/timothycrosley/isort
-
repo
:
https://github.com/timothycrosley/isort
rev
:
5.12.0
rev
:
5.12.0
...
@@ -82,4 +82,4 @@ repos:
...
@@ -82,4 +82,4 @@ repos:
# NOTE: pyright may be able to find other classes of errors not covered above,
# NOTE: pyright may be able to find other classes of errors not covered above,
# but would require some configuring and venv setup to properly eliminate noise
# but would require some configuring and venv setup to properly eliminate noise
# and give it visiblity into all the local and third_party packages expected.
# and give it visiblity into all the local and third_party packages expected.
\ No newline at end of file
examples/multimodal/components/worker.py
→
examples/multimodal/components/
decode_
worker.py
View file @
75e774d4
...
@@ -19,10 +19,11 @@ import os
...
@@ -19,10 +19,11 @@ import os
import
signal
import
signal
from
typing
import
Optional
from
typing
import
Optional
import
connect
import
torch
import
torch
from
components.disagg_router
import
PyDisaggregatedRouter
from
components.disagg_router
import
PyDisaggregatedRouter
from
components.encode_worker
import
EncodeWorker
from
components.encode_worker
import
Vllm
EncodeWorker
from
components.prefill_worker
import
PrefillWorker
from
components.prefill_worker
import
Vllm
PrefillWorker
from
transformers
import
LlavaForConditionalGeneration
from
transformers
import
LlavaForConditionalGeneration
from
utils.logging
import
check_required_workers
from
utils.logging
import
check_required_workers
from
utils.nixl
import
NixlMetadataStore
from
utils.nixl
import
NixlMetadataStore
...
@@ -53,11 +54,11 @@ logger = logging.getLogger(__name__)
...
@@ -53,11 +54,11 @@ logger = logging.getLogger(__name__)
resources
=
{
"gpu"
:
1
,
"cpu"
:
"10"
,
"memory"
:
"20Gi"
},
resources
=
{
"gpu"
:
1
,
"cpu"
:
"10"
,
"memory"
:
"20Gi"
},
workers
=
1
,
workers
=
1
,
)
)
class
VllmWorker
:
class
Vllm
Decode
Worker
:
# For disaggregated serving, we need to link the prefill worker to the vllm worker
# For disaggregated serving, we need to link the prefill worker to the vllm worker
prefill_worker
=
depends
(
PrefillWorker
)
prefill_worker
=
depends
(
Vllm
PrefillWorker
)
# For aggregated serving, we need to link the encode worker to the vllm worker.
# For aggregated serving, we need to link the encode worker to the vllm worker.
encode_worker
=
depends
(
EncodeWorker
)
encode_worker
=
depends
(
Vllm
EncodeWorker
)
def
__init__
(
self
):
def
__init__
(
self
):
self
.
client
=
None
self
.
client
=
None
...
@@ -141,7 +142,11 @@ class VllmWorker:
...
@@ -141,7 +142,11 @@ class VllmWorker:
vision_tower
.
vision_model
.
embeddings
.
position_embedding
.
num_embeddings
vision_tower
.
vision_model
.
embeddings
.
position_embedding
.
num_embeddings
)
)
else
:
else
:
enc_comp_ns
,
enc_comp_name
=
EncodeWorker
.
dynamo_address
()
# type: ignore
EMBEDDINGS_SHAPE
=
(
1
,
577
,
4096
)
EMBEDDINGS_DTYPE
=
torch
.
float16
EMBEDDINGS_DEVICE
=
"cuda"
enc_comp_ns
,
enc_comp_name
=
VllmEncodeWorker
.
dynamo_address
()
# type: ignore
self
.
encode_worker_client
=
(
self
.
encode_worker_client
=
(
await
runtime
.
namespace
(
enc_comp_ns
)
await
runtime
.
namespace
(
enc_comp_ns
)
.
component
(
enc_comp_name
)
.
component
(
enc_comp_name
)
...
@@ -149,9 +154,22 @@ class VllmWorker:
...
@@ -149,9 +154,22 @@ class VllmWorker:
.
client
()
.
client
()
)
)
self
.
_connector
=
connect
.
Connector
(
runtime
=
runtime
,
namespace
=
enc_comp_ns
)
await
self
.
_connector
.
initialize
()
# Create a longer-lived buffer for receiving the image embeddings.
embeddings
=
torch
.
empty
(
EMBEDDINGS_SHAPE
,
dtype
=
EMBEDDINGS_DTYPE
,
device
=
EMBEDDINGS_DEVICE
)
descriptor
=
connect
.
Descriptor
(
embeddings
)
# Register the descriptor w/ NIXL (this is optional, if not done here the connect subsytem will take care of this automatically).
descriptor
.
register_memory
(
self
.
_connector
)
self
.
_embeddings_descriptor
=
(
embeddings
,
descriptor
)
await
check_required_workers
(
self
.
encode_worker_client
,
self
.
min_workers
)
await
check_required_workers
(
self
.
encode_worker_client
,
self
.
min_workers
)
self
.
disaggregated_router
=
None
self
.
disaggregated_router
=
None
logger
.
info
(
"VllmWorker has been initialized"
)
logger
.
info
(
"Initialization complete."
)
def
shutdown_vllm_engine
(
self
,
signum
,
frame
):
def
shutdown_vllm_engine
(
self
,
signum
,
frame
):
"""Shutdown the background loop"""
"""Shutdown the background loop"""
...
@@ -159,7 +177,7 @@ class VllmWorker:
...
@@ -159,7 +177,7 @@ class VllmWorker:
loop
=
asyncio
.
get_event_loop
()
loop
=
asyncio
.
get_event_loop
()
try
:
try
:
self
.
engine_client
.
close
()
self
.
engine_client
.
close
()
logger
.
info
(
"
VllmWorker s
hutdown complete"
)
logger
.
info
(
"
S
hutdown complete
.
"
)
except
Exception
as
e
:
except
Exception
as
e
:
logger
.
error
(
f
"Error during shutdown:
{
e
}
"
)
logger
.
error
(
f
"Error during shutdown:
{
e
}
"
)
finally
:
finally
:
...
@@ -177,8 +195,18 @@ class VllmWorker:
...
@@ -177,8 +195,18 @@ class VllmWorker:
@
endpoint
()
@
endpoint
()
async
def
generate
(
self
,
request
:
vLLMMultimodalRequest
):
async
def
generate
(
self
,
request
:
vLLMMultimodalRequest
):
image_features
=
None
request_id
=
request
.
request_id
image_url
=
request
.
image_url
logger
.
info
(
f
"Received multimodal request {{ id:
{
request_id
}
, image_url: '
{
image_url
}
' }}."
)
embeddings
=
None
if
self
.
do_remote_prefill
:
if
self
.
do_remote_prefill
:
logger
.
debug
(
f
"Disaggregated: request {{ id:
{
request_id
}
, image_url: '
{
image_url
}
' }}"
" prefill worker will populate the decode model's key-value cache ahead of time;"
" no direct encode worker interaction required."
)
if
self
.
disaggregated_router
is
not
None
:
if
self
.
disaggregated_router
is
not
None
:
async
with
PrefillQueue
.
get_instance
(
async
with
PrefillQueue
.
get_instance
(
nats_server
=
self
.
_prefill_queue_nats_server
,
nats_server
=
self
.
_prefill_queue_nats_server
,
...
@@ -195,21 +223,21 @@ class VllmWorker:
...
@@ -195,21 +223,21 @@ class VllmWorker:
disagg_router_decision
=
True
disagg_router_decision
=
True
if
self
.
do_remote_prefill
and
disagg_router_decision
:
if
self
.
do_remote_prefill
and
disagg_router_decision
:
logger
.
debug
(
f
"Prefilling remotely for request {{ id:
{
request_id
}
, image_url: '
{
image_url
}
' }} with length
{
len
(
request
.
engine_prompt
[
'prompt_token_ids'
])
}
"
)
remote_prefill_params
=
RemotePrefillParams
(
remote_prefill_params
=
RemotePrefillParams
(
is_remote_prefill
=
True
,
is_remote_prefill
=
True
,
remote_prefill_request_callback
=
self
.
get_remote_prefill_request_callback
(),
remote_prefill_request_callback
=
self
.
get_remote_prefill_request_callback
(),
# Pass the image url as part of the RemotePrefillParams, which will be passed to the prefill worker via RemotePrefillRequest
# Pass the image url as part of the RemotePrefillParams, which will be passed to the prefill worker via RemotePrefillRequest
multimodal_data_source
=
{
multimodal_data_source
=
{
"image_url"
:
request
.
image_url
,
"image_url"
:
image_url
,
},
},
)
)
logger
.
info
(
f
"Prefilling remotely for request
{
request
.
request_id
}
with length
{
len
(
request
.
engine_prompt
[
'prompt_token_ids'
])
}
"
)
else
:
else
:
remote_prefill_params
=
None
remote_prefill_params
=
None
logger
.
info
(
logger
.
debug
(
f
"Prefilling locally for request
{
request
.
request_id
}
with length
{
len
(
request
.
engine_prompt
[
'prompt_token_ids'
])
}
"
f
"Prefilling locally for request {
{ id:
{
request_id
}
, image_url: '
{
image_url
}
' }
} with length
{
len
(
request
.
engine_prompt
[
'prompt_token_ids'
])
}
"
)
)
# The decode worker will pre-allocate the memory based on the prompt token length for the prefill worker to transfer the kv cache.
# The decode worker will pre-allocate the memory based on the prompt token length for the prefill worker to transfer the kv cache.
...
@@ -231,33 +259,61 @@ class VllmWorker:
...
@@ -231,33 +259,61 @@ class VllmWorker:
)
)
else
:
else
:
# For aggregated serving, the vllm worker will directly send the encode request to the encode worker.
logger
.
debug
(
encode_generator
=
await
self
.
encode_worker_client
.
round_robin
(
f
"Aggregated: request {{ id:
{
request_id
}
, image_url: '
{
image_url
}
' }}"
EncodeRequest
(
" no prefill worker available, embeddings directly from encode worker."
image_url
=
request
.
image_url
,
).
model_dump_json
()
)
)
device
=
torch
.
device
(
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
# Extract the pre-allocated, reusable image embeddings tensor and its descriptor.
async
for
encode_response
in
encode_generator
:
# Doing this avoids unnessesary memory de/registration with NIXL.
encode_output
=
EncodeResponse
.
model_validate_json
(
embeddings
,
descriptor
=
self
.
_embeddings_descriptor
encode_response
.
data
()
with
self
.
_connector
.
create_writable
(
descriptor
)
as
writable
:
# Extract serialized metadata about the operation from the writable operation,
# and use it to create a new EncodeRequest.
encode_request
=
EncodeRequest
(
request_id
=
request_id
,
image_url
=
image_url
,
serialized_request
=
writable
.
to_serialized
(),
)
)
image_features
=
torch
.
tensor
(
logger
.
debug
(
f
"Encode request:
{
encode_request
.
model_dump_json
()
}
"
)
encode_output
.
image_features
,
device
=
device
,
dtype
=
torch
.
float16
encode_generator
=
await
self
.
encode_worker_client
.
round_robin
(
encode_request
.
model_dump_json
()
)
)
async
for
encode_response
in
encode_generator
:
encode_output
=
EncodeResponse
.
model_validate_json
(
encode_response
.
data
()
)
logger
.
info
(
f
"Received response: {{ id:
{
encode_output
.
request_id
}
}}"
)
# Wait for the write operation to complete.
# This will block until the write operation is complete.
# This await should be a no-op since we've already received a response from the encode worker.
await
writable
.
wait_for_completion
()
# At this point, the `embeddings` tensor is filled with the image embeddings from the remote encode worker.
remote_prefill_params
=
None
remote_prefill_params
=
None
logger
.
info
(
logger
.
info
(
f
"Prefilling locally for request
{
request
.
request_id
}
with length
{
len
(
request
.
engine_prompt
[
'prompt_token_ids'
])
}
"
f
"Prefilling locally for request {
{ id:
{
request_id
}
, image_url: '
{
image_url
}
' }
} with length
{
len
(
request
.
engine_prompt
[
'prompt_token_ids'
])
}
"
)
)
prompt_ids
=
request
.
engine_prompt
[
"prompt_token_ids"
]
prompt_ids
=
request
.
engine_prompt
[
"prompt_token_ids"
]
# rust HTTP requires Delta streaming
# rust HTTP requires Delta streaming
request
.
sampling_params
.
output_kind
=
RequestOutputKind
.
DELTA
request
.
sampling_params
.
output_kind
=
RequestOutputKind
.
DELTA
if
image_features
is
not
None
:
# When using aggregated serving, the encode worker will have provided the key-value cache updates via the prefill worker.
multi_modal_data
=
{
"image"
:
image_features
}
# When using disaggregated serving, the encode worker will have provided the key-value cache updates via the encode worker.
if
embeddings
is
not
None
:
logger
.
debug
(
"Aggregated: embedding data from encode worker provided via multi-modal data to decode model."
)
multi_modal_data
=
{
"image"
:
embeddings
}
else
:
else
:
logger
.
debug
(
"Disaggregated: no embedding data required as prefill will have provided key-value cache updates via encode worker."
)
multi_modal_data
=
None
multi_modal_data
=
None
async
for
response
in
self
.
engine_client
.
generate
(
async
for
response
in
self
.
engine_client
.
generate
(
...
@@ -269,6 +325,9 @@ class VllmWorker:
...
@@ -269,6 +325,9 @@ class VllmWorker:
request_id
=
request
.
request_id
,
request_id
=
request
.
request_id
,
remote_prefill_params
=
remote_prefill_params
,
remote_prefill_params
=
remote_prefill_params
,
):
):
logger
.
debug
(
f
"Yeilding response {{ id:
{
response
.
request_id
}
, prompt: '
{
response
.
prompt
}
' }}"
)
yield
MyRequestOutput
(
yield
MyRequestOutput
(
request_id
=
response
.
request_id
,
request_id
=
response
.
request_id
,
prompt
=
response
.
prompt
,
prompt
=
response
.
prompt
,
...
...
examples/multimodal/components/encode_worker.py
View file @
75e774d4
...
@@ -15,8 +15,10 @@
...
@@ -15,8 +15,10 @@
import
logging
import
logging
from
io
import
BytesIO
from
io
import
BytesIO
from
queue
import
Queue
from
typing
import
AsyncIterator
from
typing
import
AsyncIterator
import
connect
import
requests
import
requests
import
torch
import
torch
from
PIL
import
Image
from
PIL
import
Image
...
@@ -24,10 +26,25 @@ from transformers import AutoImageProcessor, LlavaForConditionalGeneration
...
@@ -24,10 +26,25 @@ from transformers import AutoImageProcessor, LlavaForConditionalGeneration
from
utils.protocol
import
EncodeRequest
,
EncodeResponse
from
utils.protocol
import
EncodeRequest
,
EncodeResponse
from
utils.vllm
import
parse_vllm_args
from
utils.vllm
import
parse_vllm_args
from
dynamo.sdk
import
endpoint
,
service
from
dynamo.sdk
import
async_on_start
,
endpoint
,
service
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
try
:
import
cupy
as
array_module
if
not
array_module
.
cuda
.
is_available
():
raise
ImportError
(
"CUDA is not available."
)
DEVICE
=
"cuda"
logger
.
info
(
"Using cupy for array operations (GPU mode)."
)
except
ImportError
as
e
:
logger
.
warning
(
f
"Failed to import cupy, falling back to numpy:
{
e
}
."
)
import
numpy
as
array_module
DEVICE
=
"cpu"
CACHE_SIZE_MAXIMUM
=
8
@
service
(
@
service
(
dynamo
=
{
dynamo
=
{
...
@@ -36,7 +53,7 @@ logger = logging.getLogger(__name__)
...
@@ -36,7 +53,7 @@ logger = logging.getLogger(__name__)
resources
=
{
"gpu"
:
1
,
"cpu"
:
"10"
,
"memory"
:
"20Gi"
},
resources
=
{
"gpu"
:
1
,
"cpu"
:
"10"
,
"memory"
:
"20Gi"
},
workers
=
1
,
workers
=
1
,
)
)
class
EncodeWorker
:
class
Vllm
EncodeWorker
:
def
__init__
(
self
)
->
None
:
def
__init__
(
self
)
->
None
:
class_name
=
self
.
__class__
.
__name__
class_name
=
self
.
__class__
.
__name__
self
.
engine_args
=
parse_vllm_args
(
class_name
,
""
)
self
.
engine_args
=
parse_vllm_args
(
class_name
,
""
)
...
@@ -50,9 +67,50 @@ class EncodeWorker:
...
@@ -50,9 +67,50 @@ class EncodeWorker:
self
.
MODEL_ID
,
device_map
=
"auto"
,
torch_dtype
=
torch
.
float16
self
.
MODEL_ID
,
device_map
=
"auto"
,
torch_dtype
=
torch
.
float16
).
eval
()
).
eval
()
self
.
_image_cache
:
dict
[
str
,
Image
.
Image
]
=
{}
self
.
_cache_queue
:
Queue
[
str
]
=
Queue
(
maxsize
=
CACHE_SIZE_MAXIMUM
)
@
endpoint
()
@
endpoint
()
async
def
encode
(
self
,
request
:
EncodeRequest
)
->
AsyncIterator
[
EncodeResponse
]:
async
def
encode
(
self
,
request
:
EncodeRequest
)
->
AsyncIterator
[
EncodeResponse
]:
image
=
self
.
open_image
(
request
.
image_url
)
logger
.
debug
(
f
"Received encode request: {{ id:
{
request
.
request_id
}
, image_url: '
{
request
.
image_url
}
' }}."
)
request_id
=
request
.
request_id
image_url
=
request
.
image_url
.
lower
()
# The following steps encode the requested image and provided useful embeddings.
# 1. Open the image from the provided URL.
# 2. Process the image using the image processor.
# 3. Run the image through the vision model's vision tower.
# 4. Run the results of the vision tower through the multi-modal projector.
# 5. Create a descriptor for the embeddings.
# 6. Create a write operation using the serialized request and the descriptor.
# 7. Await for the write operation to complete.
# 8. Yield the encode response.
# Either retrieve the image from the cache or download it and then cache it.
if
request
.
image_url
in
self
.
_image_cache
:
image
=
self
.
_image_cache
[
image_url
]
logger
.
debug
(
f
"Image found in cache for request: {{ id:
{
request_id
}
, image_url: '
{
image_url
}
' }}."
)
else
:
image
=
self
.
open_image
(
image_url
)
logger
.
debug
(
f
"Downloading/opening image for request: {{ id:
{
request_id
}
, image_url: '
{
image_url
}
' }}."
)
# Cache the image for future use, and evict the oldest image if the cache is full.
if
self
.
_cache_queue
.
full
():
oldest_image_url
=
self
.
_cache_queue
.
get
()
del
self
.
_image_cache
[
oldest_image_url
]
self
.
_image_cache
[
request
.
image_url
]
=
image
self
.
_cache_queue
.
put
(
request
.
image_url
)
logger
.
debug
(
f
"Processing image for request: {{ id:
{
request_id
}
, image_url: '
{
image_url
}
' }}"
)
image_embeds
=
self
.
image_processor
(
images
=
image
,
return_tensors
=
"pt"
)
image_embeds
=
self
.
image_processor
(
images
=
image
,
return_tensors
=
"pt"
)
with
torch
.
no_grad
():
with
torch
.
no_grad
():
...
@@ -60,22 +118,56 @@ class EncodeWorker:
...
@@ -60,22 +118,56 @@ class EncodeWorker:
vision_outputs
=
self
.
vision_model
.
vision_tower
(
vision_outputs
=
self
.
vision_model
.
vision_tower
(
image_embeds
[
"pixel_values"
].
to
(
self
.
vision_model
.
device
)
image_embeds
[
"pixel_values"
].
to
(
self
.
vision_model
.
device
)
)
)
logger
.
debug
(
"Vision model completed."
)
embeddings
=
vision_outputs
.
last_hidden_state
embeddings
=
self
.
vision_model
.
multi_modal_projector
(
embeddings
)
logger
.
debug
(
f
"Embeddings: {{ shape:
{
embeddings
.
shape
}
, dtype:
{
embeddings
.
dtype
}
, device:
{
embeddings
.
device
}
, ptr:
{
embeddings
.
data_ptr
()
}
, elements: {{ count:
{
embeddings
.
numel
()
}
, size:
{
embeddings
.
element_size
()
}
}} }}."
)
if
request
.
serialized_request
is
None
:
logger
.
error
(
f
"Request serialized_request is None for request: {{ id:
{
request_id
}
, image_url: '
{
image_url
}
' }}."
)
# Create a descriptor for the embeddings, this will register the memory with the connector (and the NIXL runtime).
descriptor
=
connect
.
Descriptor
(
embeddings
)
# Create a write operation using the serialized request and the descriptor.
# This will begin the RDMA transfer of the embeddings to the remote worker.
write_op
=
await
self
.
_connector
.
begin_write
(
descriptor
,
request
.
serialized_request
,
)
# Await for the write operation to complete.
# This will block until the data has been written to the remote worker or an error occurs.
await
write_op
.
wait_for_completion
()
image_features
=
vision_outputs
.
last_hidden_state
image_features
=
self
.
vision_model
.
multi_modal_projector
(
image_features
)
yield
EncodeResponse
(
yield
EncodeResponse
(
image_features
=
image_features
.
tolist
()
request_id
=
request
.
request_id
,
).
model_dump_json
()
).
model_dump_json
()
@
async_on_start
()
async
def
on_start
(
self
):
logger
.
info
(
"Startup started."
)
# Create and initialize a dynamo connector for this worker.
# We'll needs this to move data between this worker and remote workers efficiently.
self
.
_connector
=
connect
.
Connector
()
await
self
.
_connector
.
initialize
()
logger
.
info
(
"Startup completed."
)
def
open_image
(
self
,
image
:
str
)
->
Image
.
Image
:
def
open_image
(
self
,
image
:
str
)
->
Image
.
Image
:
# TODO: Have a seperate field for url and non url - and avoid auto detection
# TODO: Have a seperate field for url and non url - and avoid auto detection
try
:
try
:
# Acquire the image and convert it to the format (RGB) the image processor model expects.
if
image
.
startswith
(
"http"
)
or
image
.
startswith
(
"https"
):
if
image
.
startswith
(
"http"
)
or
image
.
startswith
(
"https"
):
response
=
requests
.
get
(
image
)
response
=
requests
.
get
(
image
)
image_data
=
Image
.
open
(
BytesIO
(
response
.
content
)).
convert
(
"RGB"
)
image_data
=
Image
.
open
(
BytesIO
(
response
.
content
)).
convert
(
"RGB"
)
else
:
else
:
image_data
=
Image
.
open
(
image
).
convert
(
"RGB"
)
image_data
=
Image
.
open
(
image
).
convert
(
"RGB"
)
return
image_data
except
Exception
as
e
:
except
Exception
as
e
:
logger
.
error
(
f
"Error opening image:
{
e
}
"
)
logger
.
error
(
f
"Error opening image:
{
e
}
"
)
raise
e
raise
e
return
image_data
examples/multimodal/components/prefill_worker.py
View file @
75e774d4
...
@@ -20,8 +20,9 @@ import os
...
@@ -20,8 +20,9 @@ import os
import
signal
import
signal
import
sys
import
sys
import
connect
import
torch
import
torch
from
components.encode_worker
import
EncodeWorker
from
components.encode_worker
import
Vllm
EncodeWorker
from
pydantic
import
BaseModel
from
pydantic
import
BaseModel
from
utils.logging
import
check_required_workers
from
utils.logging
import
check_required_workers
from
utils.nixl
import
NixlMetadataStore
from
utils.nixl
import
NixlMetadataStore
...
@@ -38,6 +39,11 @@ from dynamo.sdk import async_on_start, depends, dynamo_context, endpoint, servic
...
@@ -38,6 +39,11 @@ from dynamo.sdk import async_on_start, depends, dynamo_context, endpoint, servic
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
# Constants for the shape and dtype of the embeddings tensor.
EMBEDDINGS_SHAPE
=
(
1
,
577
,
4096
)
EMBEDDINGS_DTYPE
=
torch
.
float16
EMBEDDINGS_DEVICE
=
"cuda"
class
RequestType
(
BaseModel
):
class
RequestType
(
BaseModel
):
text
:
str
text
:
str
...
@@ -50,8 +56,8 @@ class RequestType(BaseModel):
...
@@ -50,8 +56,8 @@ class RequestType(BaseModel):
resources
=
{
"gpu"
:
1
,
"cpu"
:
"10"
,
"memory"
:
"20Gi"
},
resources
=
{
"gpu"
:
1
,
"cpu"
:
"10"
,
"memory"
:
"20Gi"
},
workers
=
1
,
workers
=
1
,
)
)
class
PrefillWorker
:
class
Vllm
PrefillWorker
:
encode_worker
=
depends
(
EncodeWorker
)
encode_worker
=
depends
(
Vllm
EncodeWorker
)
def
__init__
(
self
):
def
__init__
(
self
):
class_name
=
self
.
__class__
.
__name__
class_name
=
self
.
__class__
.
__name__
...
@@ -95,7 +101,7 @@ class PrefillWorker:
...
@@ -95,7 +101,7 @@ class PrefillWorker:
raise
RuntimeError
(
"Failed to initialize engine client"
)
raise
RuntimeError
(
"Failed to initialize engine client"
)
runtime
=
dynamo_context
[
"runtime"
]
runtime
=
dynamo_context
[
"runtime"
]
enc_comp_ns
,
enc_comp_name
=
EncodeWorker
.
dynamo_address
()
# type: ignore
enc_comp_ns
,
enc_comp_name
=
Vllm
EncodeWorker
.
dynamo_address
()
# type: ignore
self
.
encode_worker_client
=
(
self
.
encode_worker_client
=
(
await
runtime
.
namespace
(
enc_comp_ns
)
await
runtime
.
namespace
(
enc_comp_ns
)
.
component
(
enc_comp_name
)
.
component
(
enc_comp_name
)
...
@@ -103,6 +109,20 @@ class PrefillWorker:
...
@@ -103,6 +109,20 @@ class PrefillWorker:
.
client
()
.
client
()
)
)
self
.
_connector
=
connect
.
Connector
(
runtime
=
runtime
,
namespace
=
enc_comp_ns
)
await
self
.
_connector
.
initialize
()
# Create a longer-lived buffer for receiving the image embeddings.
embeddings
=
torch
.
empty
(
EMBEDDINGS_SHAPE
,
dtype
=
EMBEDDINGS_DTYPE
,
device
=
EMBEDDINGS_DEVICE
,
)
descriptor
=
connect
.
Descriptor
(
embeddings
)
# Register the descriptor w/ NIXL (this is optional, if not done here the connect subsytem will take care of this automatically).
descriptor
.
register_memory
(
self
.
_connector
)
self
.
_embeddings_descriptor
=
(
embeddings
,
descriptor
)
await
check_required_workers
(
self
.
encode_worker_client
,
self
.
min_workers
)
await
check_required_workers
(
self
.
encode_worker_client
,
self
.
min_workers
)
metadata
=
self
.
engine_client
.
nixl_metadata
metadata
=
self
.
engine_client
.
nixl_metadata
...
@@ -119,19 +139,19 @@ class PrefillWorker:
...
@@ -119,19 +139,19 @@ class PrefillWorker:
sys
.
exit
(
1
)
sys
.
exit
(
1
)
task
.
add_done_callback
(
prefill_queue_handler_cb
)
task
.
add_done_callback
(
prefill_queue_handler_cb
)
logger
.
info
(
"
PrefillWorker initialized
"
)
logger
.
info
(
"
Initialization complete.
"
)
def
shutdown_vllm_engine
(
self
,
signum
,
frame
):
def
shutdown_vllm_engine
(
self
,
signum
,
frame
):
"""Shutdown the background loop"""
"""Shutdown the background loop"""
logger
.
info
(
f
"
Receiv
ed signal
{
signum
}
, shutting down
"
)
logger
.
info
(
f
"
Shutdown start
ed
,
signal
{
signum
}
received.
"
)
loop
=
asyncio
.
get_event_loop
()
loop
=
asyncio
.
get_event_loop
()
try
:
try
:
self
.
engine_client
.
close
()
self
.
engine_client
.
close
()
logger
.
info
(
"PrefillWorker shutdown complete"
)
except
Exception
as
e
:
except
Exception
as
e
:
logger
.
error
(
f
"Error during shutdown:
{
e
}
"
)
logger
.
error
(
f
"Error during shutdown:
{
e
}
"
)
finally
:
finally
:
loop
.
stop
()
loop
.
stop
()
logger
.
info
(
"Shutdown complete."
)
async
def
prefill_queue_handler
(
self
):
async
def
prefill_queue_handler
(
self
):
logger
.
info
(
"Prefill queue handler entered"
)
logger
.
info
(
"Prefill queue handler entered"
)
...
@@ -166,62 +186,88 @@ class PrefillWorker:
...
@@ -166,62 +186,88 @@ class PrefillWorker:
if
request
.
multimodal_data_source
[
"image_url"
]
is
None
:
if
request
.
multimodal_data_source
[
"image_url"
]
is
None
:
raise
ValueError
(
"No image url provided for prefill request"
)
raise
ValueError
(
"No image url provided for prefill request"
)
encode_generator
=
await
self
.
encode_worker_client
.
round_robin
(
request_id
=
request
.
request_id
EncodeRequest
(
engine_id
=
request
.
engine_id
image_url
=
request
.
multimodal_data_source
[
"image_url"
],
image_url
=
request
.
multimodal_data_source
[
"image_url"
]
).
model_dump_json
()
logger
.
info
(
f
"Received prefill request {{ id:
{
request_id
}
, engine_id:
{
engine_id
}
, image_url: '
{
image_url
}
' }}."
)
)
async
for
encode_response
in
encode_generator
:
encode_output
=
EncodeResponse
.
model_validate_json
(
encode_response
.
data
())
# Extract the pre-allocated, reusable image embeddings tensor and its descriptor.
image_features
=
torch
.
tensor
(
# Doing this avoids unnessesary memory de/registration with NIXL.
encode_output
.
image_features
,
device
=
"cpu"
,
dtype
=
torch
.
float16
embeddings
,
descriptor
=
self
.
_embeddings_descriptor
# Create a new writable operation from the descriptor.
with
self
.
_connector
.
create_writable
(
descriptor
)
as
writable
:
# Extract serialized metadata about the operation from the writable operation,
# and use it to create a new EncodeRequest.
encode_generator
=
await
self
.
encode_worker_client
.
round_robin
(
EncodeRequest
(
request_id
=
request_id
,
image_url
=
image_url
,
serialized_request
=
writable
.
to_serialized
(),
).
model_dump_json
()
)
)
async
for
encode_response
in
encode_generator
:
encode_output
=
EncodeResponse
.
model_validate_json
(
encode_response
.
data
(),
)
logger
.
debug
(
f
"Received response: {{ id:
{
encode_output
.
request_id
}
}}."
)
sampling_params
=
request
.
sampling_params
# Wait for the write operation to complete.
sampling_params
.
max_tokens
=
1
# This will block until the write operation is complete.
sampling_params
.
min_tokens
=
1
# This await should be a no-op since we've already received a response from the encode worker.
await
writable
.
wait_for_completion
()
# At this point, the `embeddings` tensor is filled with the image embeddings from the remote encode worker.
remote_prefill_params
=
RemotePrefillParams
(
sampling_params
=
request
.
sampling_params
is_remote_decode
=
True
,
sampling_params
.
max_tokens
=
1
decode_block_ids
=
request
.
block_ids
,
sampling_params
.
min_tokens
=
1
decode_engine_id
=
request
.
engine_id
,
decode_computed_block_ids
=
request
.
computed_block_ids
,
)
# TODO check if metadata has changed
remote_prefill_params
=
RemotePrefillParams
(
# and reload - currently only loading once
is_remote_decode
=
True
,
if
request
.
engine_id
not
in
self
.
_loaded_metadata
:
decode_block_ids
=
request
.
block_ids
,
remote_metadata
=
await
self
.
_metadata_store
.
get
(
request
.
engine_id
)
decode_engine_id
=
engine_id
,
await
self
.
engine_client
.
add_remote_nixl_metadata
(
remote_metadata
)
decode_computed_block_ids
=
request
.
computed_block_ids
,
logger
.
info
(
)
f
"Loaded nixl metadata from engine
{
request
.
engine_id
}
into "
f
"engine
{
self
.
engine_client
.
nixl_metadata
.
engine_id
}
"
# TODO check if metadata has changed
# and reload - currently only loading once
if
engine_id
not
in
self
.
_loaded_metadata
:
remote_metadata
=
await
self
.
_metadata_store
.
get
(
request
.
engine_id
)
await
self
.
engine_client
.
add_remote_nixl_metadata
(
remote_metadata
)
logger
.
info
(
f
"Loaded nixl metadata from engine
{
engine_id
}
into "
f
"engine
{
self
.
engine_client
.
nixl_metadata
.
engine_id
}
"
)
self
.
_loaded_metadata
.
add
(
engine_id
)
# To make sure the decode worker can pre-allocate the memory with the correct size for the prefill worker to transfer the kv cache,
# some placeholder dummy tokens were inserted based on the embedding size in the worker.py.
# The structure of the prompt is "\nUSER: <image> <dummy_tokens>\n<user_prompt>\nASSISTANT:", need to remove the dummy tokens after the image token.
IMAGE_TOKEN_ID
=
32000
embedding_size
=
embeddings
.
shape
[
1
]
padding_size
=
embedding_size
-
1
image_token_index
=
request
.
prompt_token_ids
.
index
(
IMAGE_TOKEN_ID
)
dummy_token_index
=
image_token_index
+
1
prompt_token_ids
=
(
request
.
prompt_token_ids
[:
dummy_token_index
]
+
request
.
prompt_token_ids
[
dummy_token_index
+
padding_size
:]
)
)
self
.
_loaded_metadata
.
add
(
request
.
engine_id
)
# To make sure the decode worker can pre-allocate the memory with the correct size for the prefill worker to transfer the kv cache,
# some placeholder dummy tokens were inserted based on the embedding size in the worker.py.
# The structure of the prompt is "\nUSER: <image> <dummy_tokens>\n<user_prompt>\nASSISTANT:", need to remove the dummy tokens after the image token.
IMAGE_TOKEN_ID
=
32000
embedding_size
=
image_features
.
shape
[
1
]
padding_size
=
embedding_size
-
1
image_token_index
=
request
.
prompt_token_ids
.
index
(
IMAGE_TOKEN_ID
)
dummy_token_index
=
image_token_index
+
1
prompt_token_ids
=
(
request
.
prompt_token_ids
[:
dummy_token_index
]
+
request
.
prompt_token_ids
[
dummy_token_index
+
padding_size
:]
)
async
for
_
in
self
.
engine_client
.
generate
(
async
for
_
in
self
.
engine_client
.
generate
(
request_id
=
request
.
request
_id
,
request_id
=
request_id
,
prompt
=
TokensPrompt
(
prompt
=
TokensPrompt
(
prompt_token_ids
=
prompt_token_ids
,
prompt_token_ids
=
prompt_token_ids
,
multi_modal_data
=
{
"image"
:
image_feature
s
},
multi_modal_data
=
{
"image"
:
embedding
s
},
),
),
sampling_params
=
sampling_params
,
sampling_params
=
sampling_params
,
remote_prefill_params
=
remote_prefill_params
,
remote_prefill_params
=
remote_prefill_params
,
):
):
yield
yield
@
endpoint
()
@
endpoint
()
async
def
mock
(
self
,
req
:
RequestType
):
async
def
mock
(
self
,
req
:
RequestType
):
...
...
examples/multimodal/components/processor.py
View file @
75e774d4
...
@@ -19,7 +19,7 @@ import uuid
...
@@ -19,7 +19,7 @@ import uuid
from
enum
import
Enum
from
enum
import
Enum
from
typing
import
AsyncIterator
,
Tuple
,
Union
from
typing
import
AsyncIterator
,
Tuple
,
Union
from
components.worker
import
VllmWorker
from
components.
decode_
worker
import
Vllm
Decode
Worker
from
transformers
import
AutoTokenizer
from
transformers
import
AutoTokenizer
from
utils.chat_processor
import
ChatProcessor
,
CompletionsProcessor
,
ProcessMixIn
from
utils.chat_processor
import
ChatProcessor
,
CompletionsProcessor
,
ProcessMixIn
from
utils.logging
import
check_required_workers
from
utils.logging
import
check_required_workers
...
@@ -53,7 +53,7 @@ class Processor(ProcessMixIn):
...
@@ -53,7 +53,7 @@ class Processor(ProcessMixIn):
vLLM pre and post processing
vLLM pre and post processing
"""
"""
worker
=
depends
(
VllmWorker
)
worker
=
depends
(
Vllm
Decode
Worker
)
def
__init__
(
self
):
def
__init__
(
self
):
class_name
=
self
.
__class__
.
__name__
class_name
=
self
.
__class__
.
__name__
...
@@ -83,7 +83,7 @@ class Processor(ProcessMixIn):
...
@@ -83,7 +83,7 @@ class Processor(ProcessMixIn):
@
async_on_start
@
async_on_start
async
def
async_init
(
self
):
async
def
async_init
(
self
):
runtime
=
dynamo_context
[
"runtime"
]
runtime
=
dynamo_context
[
"runtime"
]
comp_ns
,
comp_name
=
VllmWorker
.
dynamo_address
()
# type: ignore
comp_ns
,
comp_name
=
Vllm
Decode
Worker
.
dynamo_address
()
# type: ignore
self
.
worker_client
=
(
self
.
worker_client
=
(
await
runtime
.
namespace
(
comp_ns
)
await
runtime
.
namespace
(
comp_ns
)
.
component
(
comp_name
)
.
component
(
comp_name
)
...
...
examples/multimodal/configs/agg.yaml
View file @
75e774d4
...
@@ -21,7 +21,7 @@ Processor:
...
@@ -21,7 +21,7 @@ Processor:
router
:
round-robin
router
:
round-robin
common-configs
:
[
model
,
block-size
,
max-model-len
]
common-configs
:
[
model
,
block-size
,
max-model-len
]
VllmWorker
:
Vllm
Decode
Worker
:
enforce-eager
:
true
enforce-eager
:
true
max-num-batched-tokens
:
16384
max-num-batched-tokens
:
16384
enable-prefix-caching
:
true
enable-prefix-caching
:
true
...
@@ -33,7 +33,7 @@ VllmWorker:
...
@@ -33,7 +33,7 @@ VllmWorker:
gpu
:
1
gpu
:
1
common-configs
:
[
model
,
block-size
,
max-model-len
]
common-configs
:
[
model
,
block-size
,
max-model-len
]
EncodeWorker
:
Vllm
EncodeWorker
:
tensor-parallel-size
:
1
tensor-parallel-size
:
1
router
:
random
router
:
random
ServiceArgs
:
ServiceArgs
:
...
...
examples/multimodal/configs/disagg.yaml
View file @
75e774d4
...
@@ -22,7 +22,7 @@ Processor:
...
@@ -22,7 +22,7 @@ Processor:
router
:
round-robin
router
:
round-robin
common-configs
:
[
model
,
block-size
]
common-configs
:
[
model
,
block-size
]
VllmWorker
:
Vllm
Decode
Worker
:
remote-prefill
:
true
remote-prefill
:
true
conditional-disagg
:
true
conditional-disagg
:
true
max-local-prefill-length
:
10
max-local-prefill-length
:
10
...
@@ -33,7 +33,7 @@ VllmWorker:
...
@@ -33,7 +33,7 @@ VllmWorker:
gpu
:
1
gpu
:
1
common-configs
:
[
model
,
block-size
,
max-model-len
,
kv-transfer-config
]
common-configs
:
[
model
,
block-size
,
max-model-len
,
kv-transfer-config
]
PrefillWorker
:
Vllm
PrefillWorker
:
max-num-batched-tokens
:
16384
max-num-batched-tokens
:
16384
ServiceArgs
:
ServiceArgs
:
workers
:
1
workers
:
1
...
@@ -41,7 +41,7 @@ PrefillWorker:
...
@@ -41,7 +41,7 @@ PrefillWorker:
gpu
:
1
gpu
:
1
common-configs
:
[
model
,
block-size
,
max-model-len
,
kv-transfer-config
]
common-configs
:
[
model
,
block-size
,
max-model-len
,
kv-transfer-config
]
EncodeWorker
:
Vllm
EncodeWorker
:
tensor-parallel-size
:
1
tensor-parallel-size
:
1
router
:
random
router
:
random
ServiceArgs
:
ServiceArgs
:
...
...
examples/multimodal/connect/__init__.py
0 → 100644
View file @
75e774d4
This diff is collapsed.
Click to expand it.
examples/multimodal/graphs/agg.py
View file @
75e774d4
...
@@ -6,6 +6,7 @@
...
@@ -6,6 +6,7 @@
# You may obtain a copy of the License at
# You may obtain a copy of the License at
#
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# distributed under the License is distributed on an "AS IS" BASIS,
...
@@ -13,9 +14,9 @@
...
@@ -13,9 +14,9 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
components.encode_worker
import
EncodeWorker
from
components.decode_worker
import
VllmDecodeWorker
from
components.encode_worker
import
VllmEncodeWorker
from
components.frontend
import
Frontend
from
components.frontend
import
Frontend
from
components.processor
import
Processor
from
components.processor
import
Processor
from
components.worker
import
VllmWorker
Frontend
.
link
(
Processor
).
link
(
VllmWorker
).
link
(
EncodeWorker
)
Frontend
.
link
(
Processor
).
link
(
Vllm
Decode
Worker
).
link
(
Vllm
EncodeWorker
)
examples/multimodal/graphs/disagg.py
View file @
75e774d4
...
@@ -13,10 +13,12 @@
...
@@ -13,10 +13,12 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
components.encode_worker
import
EncodeWorker
from
components.decode_worker
import
VllmDecodeWorker
from
components.encode_worker
import
VllmEncodeWorker
from
components.frontend
import
Frontend
from
components.frontend
import
Frontend
from
components.prefill_worker
import
PrefillWorker
from
components.prefill_worker
import
Vllm
PrefillWorker
from
components.processor
import
Processor
from
components.processor
import
Processor
from
components.worker
import
VllmWorker
Frontend
.
link
(
Processor
).
link
(
VllmWorker
).
link
(
PrefillWorker
).
link
(
EncodeWorker
)
Frontend
.
link
(
Processor
).
link
(
VllmDecodeWorker
).
link
(
VllmPrefillWorker
).
link
(
VllmEncodeWorker
)
examples/multimodal/utils/protocol.py
View file @
75e774d4
...
@@ -17,6 +17,7 @@
...
@@ -17,6 +17,7 @@
import
json
import
json
from
typing
import
Any
,
List
,
Optional
from
typing
import
Any
,
List
,
Optional
import
connect
import
msgspec
import
msgspec
from
pydantic
import
BaseModel
,
ConfigDict
,
field_validator
from
pydantic
import
BaseModel
,
ConfigDict
,
field_validator
from
pydantic_core
import
core_schema
from
pydantic_core
import
core_schema
...
@@ -111,12 +112,13 @@ class EncodeRequest(BaseModel):
...
@@ -111,12 +112,13 @@ class EncodeRequest(BaseModel):
model_config
=
ConfigDict
(
arbitrary_types_allowed
=
True
)
model_config
=
ConfigDict
(
arbitrary_types_allowed
=
True
)
image_url
:
str
image_url
:
str
request_id
:
str
serialized_request
:
Optional
[
connect
.
SerializedRequest
]
=
None
class
EncodeResponse
(
BaseModel
):
class
EncodeResponse
(
BaseModel
):
model_config
=
ConfigDict
(
arbitrary_types_allowed
=
True
)
model_config
=
ConfigDict
(
arbitrary_types_allowed
=
True
)
request_id
:
str
image_features
:
List
[
List
[
List
[
float
]]]
class
MyRequestOutput
(
BaseModel
):
class
MyRequestOutput
(
BaseModel
):
...
@@ -129,7 +131,6 @@ class MyRequestOutput(BaseModel):
...
@@ -129,7 +131,6 @@ class MyRequestOutput(BaseModel):
"""
"""
model_config
=
ConfigDict
(
arbitrary_types_allowed
=
True
)
model_config
=
ConfigDict
(
arbitrary_types_allowed
=
True
)
request_id
:
str
request_id
:
str
prompt
:
Optional
[
str
]
=
None
prompt
:
Optional
[
str
]
=
None
prompt_token_ids
:
Optional
[
List
[
int
]]
=
None
prompt_token_ids
:
Optional
[
List
[
int
]]
=
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment