Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
75e774d4
Unverified
Commit
75e774d4
authored
May 27, 2025
by
J Wyman
Committed by
GitHub
May 27, 2025
Browse files
feat: NIXL Based RDMA Support w/ Multimodal Example (#1060)
parent
9acaa8d1
Changes
11
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
1784 additions
and
112 deletions
+1784
-112
.pre-commit-config.yaml
.pre-commit-config.yaml
+2
-2
examples/multimodal/components/decode_worker.py
examples/multimodal/components/decode_worker.py
+88
-29
examples/multimodal/components/encode_worker.py
examples/multimodal/components/encode_worker.py
+99
-7
examples/multimodal/components/prefill_worker.py
examples/multimodal/components/prefill_worker.py
+102
-56
examples/multimodal/components/processor.py
examples/multimodal/components/processor.py
+3
-3
examples/multimodal/configs/agg.yaml
examples/multimodal/configs/agg.yaml
+2
-2
examples/multimodal/configs/disagg.yaml
examples/multimodal/configs/disagg.yaml
+3
-3
examples/multimodal/connect/__init__.py
examples/multimodal/connect/__init__.py
+1471
-0
examples/multimodal/graphs/agg.py
examples/multimodal/graphs/agg.py
+4
-3
examples/multimodal/graphs/disagg.py
examples/multimodal/graphs/disagg.py
+6
-4
examples/multimodal/utils/protocol.py
examples/multimodal/utils/protocol.py
+4
-3
No files found.
.pre-commit-config.yaml
View file @
75e774d4
...
...
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
exclude
:
^(src/grpc_generated|.*\.patch$)
exclude
:
^(src/grpc_generated|.*\.patch$
|.*/connect/.*\.py
)
repos
:
-
repo
:
https://github.com/timothycrosley/isort
rev
:
5.12.0
...
...
examples/multimodal/components/worker.py
→
examples/multimodal/components/
decode_
worker.py
View file @
75e774d4
...
...
@@ -19,10 +19,11 @@ import os
import
signal
from
typing
import
Optional
import
connect
import
torch
from
components.disagg_router
import
PyDisaggregatedRouter
from
components.encode_worker
import
EncodeWorker
from
components.prefill_worker
import
PrefillWorker
from
components.encode_worker
import
Vllm
EncodeWorker
from
components.prefill_worker
import
Vllm
PrefillWorker
from
transformers
import
LlavaForConditionalGeneration
from
utils.logging
import
check_required_workers
from
utils.nixl
import
NixlMetadataStore
...
...
@@ -53,11 +54,11 @@ logger = logging.getLogger(__name__)
resources
=
{
"gpu"
:
1
,
"cpu"
:
"10"
,
"memory"
:
"20Gi"
},
workers
=
1
,
)
class
VllmWorker
:
class
Vllm
Decode
Worker
:
# For disaggregated serving, we need to link the prefill worker to the vllm worker
prefill_worker
=
depends
(
PrefillWorker
)
prefill_worker
=
depends
(
Vllm
PrefillWorker
)
# For aggregated serving, we need to link the encode worker to the vllm worker.
encode_worker
=
depends
(
EncodeWorker
)
encode_worker
=
depends
(
Vllm
EncodeWorker
)
def
__init__
(
self
):
self
.
client
=
None
...
...
@@ -141,7 +142,11 @@ class VllmWorker:
vision_tower
.
vision_model
.
embeddings
.
position_embedding
.
num_embeddings
)
else
:
enc_comp_ns
,
enc_comp_name
=
EncodeWorker
.
dynamo_address
()
# type: ignore
EMBEDDINGS_SHAPE
=
(
1
,
577
,
4096
)
EMBEDDINGS_DTYPE
=
torch
.
float16
EMBEDDINGS_DEVICE
=
"cuda"
enc_comp_ns
,
enc_comp_name
=
VllmEncodeWorker
.
dynamo_address
()
# type: ignore
self
.
encode_worker_client
=
(
await
runtime
.
namespace
(
enc_comp_ns
)
.
component
(
enc_comp_name
)
...
...
@@ -149,9 +154,22 @@ class VllmWorker:
.
client
()
)
self
.
_connector
=
connect
.
Connector
(
runtime
=
runtime
,
namespace
=
enc_comp_ns
)
await
self
.
_connector
.
initialize
()
# Create a longer-lived buffer for receiving the image embeddings.
embeddings
=
torch
.
empty
(
EMBEDDINGS_SHAPE
,
dtype
=
EMBEDDINGS_DTYPE
,
device
=
EMBEDDINGS_DEVICE
)
descriptor
=
connect
.
Descriptor
(
embeddings
)
# Register the descriptor w/ NIXL (this is optional, if not done here the connect subsytem will take care of this automatically).
descriptor
.
register_memory
(
self
.
_connector
)
self
.
_embeddings_descriptor
=
(
embeddings
,
descriptor
)
await
check_required_workers
(
self
.
encode_worker_client
,
self
.
min_workers
)
self
.
disaggregated_router
=
None
logger
.
info
(
"VllmWorker has been initialized"
)
logger
.
info
(
"Initialization complete."
)
def
shutdown_vllm_engine
(
self
,
signum
,
frame
):
"""Shutdown the background loop"""
...
...
@@ -159,7 +177,7 @@ class VllmWorker:
loop
=
asyncio
.
get_event_loop
()
try
:
self
.
engine_client
.
close
()
logger
.
info
(
"
VllmWorker s
hutdown complete"
)
logger
.
info
(
"
S
hutdown complete
.
"
)
except
Exception
as
e
:
logger
.
error
(
f
"Error during shutdown:
{
e
}
"
)
finally
:
...
...
@@ -177,8 +195,18 @@ class VllmWorker:
@
endpoint
()
async
def
generate
(
self
,
request
:
vLLMMultimodalRequest
):
image_features
=
None
request_id
=
request
.
request_id
image_url
=
request
.
image_url
logger
.
info
(
f
"Received multimodal request {{ id:
{
request_id
}
, image_url: '
{
image_url
}
' }}."
)
embeddings
=
None
if
self
.
do_remote_prefill
:
logger
.
debug
(
f
"Disaggregated: request {{ id:
{
request_id
}
, image_url: '
{
image_url
}
' }}"
" prefill worker will populate the decode model's key-value cache ahead of time;"
" no direct encode worker interaction required."
)
if
self
.
disaggregated_router
is
not
None
:
async
with
PrefillQueue
.
get_instance
(
nats_server
=
self
.
_prefill_queue_nats_server
,
...
...
@@ -195,21 +223,21 @@ class VllmWorker:
disagg_router_decision
=
True
if
self
.
do_remote_prefill
and
disagg_router_decision
:
logger
.
debug
(
f
"Prefilling remotely for request {{ id:
{
request_id
}
, image_url: '
{
image_url
}
' }} with length
{
len
(
request
.
engine_prompt
[
'prompt_token_ids'
])
}
"
)
remote_prefill_params
=
RemotePrefillParams
(
is_remote_prefill
=
True
,
remote_prefill_request_callback
=
self
.
get_remote_prefill_request_callback
(),
# Pass the image url as part of the RemotePrefillParams, which will be passed to the prefill worker via RemotePrefillRequest
multimodal_data_source
=
{
"image_url"
:
request
.
image_url
,
"image_url"
:
image_url
,
},
)
logger
.
info
(
f
"Prefilling remotely for request
{
request
.
request_id
}
with length
{
len
(
request
.
engine_prompt
[
'prompt_token_ids'
])
}
"
)
else
:
remote_prefill_params
=
None
logger
.
info
(
f
"Prefilling locally for request
{
request
.
request_id
}
with length
{
len
(
request
.
engine_prompt
[
'prompt_token_ids'
])
}
"
logger
.
debug
(
f
"Prefilling locally for request {
{ id:
{
request_id
}
, image_url: '
{
image_url
}
' }
} with length
{
len
(
request
.
engine_prompt
[
'prompt_token_ids'
])
}
"
)
# The decode worker will pre-allocate the memory based on the prompt token length for the prefill worker to transfer the kv cache.
...
...
@@ -231,33 +259,61 @@ class VllmWorker:
)
else
:
# For aggregated serving, the vllm worker will directly send the encode request to the encode worker.
logger
.
debug
(
f
"Aggregated: request {{ id:
{
request_id
}
, image_url: '
{
image_url
}
' }}"
" no prefill worker available, embeddings directly from encode worker."
)
# Extract the pre-allocated, reusable image embeddings tensor and its descriptor.
# Doing this avoids unnessesary memory de/registration with NIXL.
embeddings
,
descriptor
=
self
.
_embeddings_descriptor
with
self
.
_connector
.
create_writable
(
descriptor
)
as
writable
:
# Extract serialized metadata about the operation from the writable operation,
# and use it to create a new EncodeRequest.
encode_request
=
EncodeRequest
(
request_id
=
request_id
,
image_url
=
image_url
,
serialized_request
=
writable
.
to_serialized
(),
)
logger
.
debug
(
f
"Encode request:
{
encode_request
.
model_dump_json
()
}
"
)
encode_generator
=
await
self
.
encode_worker_client
.
round_robin
(
EncodeRequest
(
image_url
=
request
.
image_url
,
).
model_dump_json
()
encode_request
.
model_dump_json
()
)
device
=
torch
.
device
(
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
async
for
encode_response
in
encode_generator
:
encode_output
=
EncodeResponse
.
model_validate_json
(
encode_response
.
data
()
)
image_features
=
torch
.
tensor
(
encode_output
.
image_features
,
device
=
device
,
dtype
=
torch
.
float16
logger
.
info
(
f
"Received response: {{ id:
{
encode_output
.
request_id
}
}}"
)
# Wait for the write operation to complete.
# This will block until the write operation is complete.
# This await should be a no-op since we've already received a response from the encode worker.
await
writable
.
wait_for_completion
()
# At this point, the `embeddings` tensor is filled with the image embeddings from the remote encode worker.
remote_prefill_params
=
None
logger
.
info
(
f
"Prefilling locally for request
{
request
.
request_id
}
with length
{
len
(
request
.
engine_prompt
[
'prompt_token_ids'
])
}
"
f
"Prefilling locally for request {
{ id:
{
request_id
}
, image_url: '
{
image_url
}
' }
} with length
{
len
(
request
.
engine_prompt
[
'prompt_token_ids'
])
}
"
)
prompt_ids
=
request
.
engine_prompt
[
"prompt_token_ids"
]
# rust HTTP requires Delta streaming
request
.
sampling_params
.
output_kind
=
RequestOutputKind
.
DELTA
if
image_features
is
not
None
:
multi_modal_data
=
{
"image"
:
image_features
}
# When using aggregated serving, the encode worker will have provided the key-value cache updates via the prefill worker.
# When using disaggregated serving, the encode worker will have provided the key-value cache updates via the encode worker.
if
embeddings
is
not
None
:
logger
.
debug
(
"Aggregated: embedding data from encode worker provided via multi-modal data to decode model."
)
multi_modal_data
=
{
"image"
:
embeddings
}
else
:
logger
.
debug
(
"Disaggregated: no embedding data required as prefill will have provided key-value cache updates via encode worker."
)
multi_modal_data
=
None
async
for
response
in
self
.
engine_client
.
generate
(
...
...
@@ -269,6 +325,9 @@ class VllmWorker:
request_id
=
request
.
request_id
,
remote_prefill_params
=
remote_prefill_params
,
):
logger
.
debug
(
f
"Yeilding response {{ id:
{
response
.
request_id
}
, prompt: '
{
response
.
prompt
}
' }}"
)
yield
MyRequestOutput
(
request_id
=
response
.
request_id
,
prompt
=
response
.
prompt
,
...
...
examples/multimodal/components/encode_worker.py
View file @
75e774d4
...
...
@@ -15,8 +15,10 @@
import
logging
from
io
import
BytesIO
from
queue
import
Queue
from
typing
import
AsyncIterator
import
connect
import
requests
import
torch
from
PIL
import
Image
...
...
@@ -24,10 +26,25 @@ from transformers import AutoImageProcessor, LlavaForConditionalGeneration
from
utils.protocol
import
EncodeRequest
,
EncodeResponse
from
utils.vllm
import
parse_vllm_args
from
dynamo.sdk
import
endpoint
,
service
from
dynamo.sdk
import
async_on_start
,
endpoint
,
service
logger
=
logging
.
getLogger
(
__name__
)
try
:
import
cupy
as
array_module
if
not
array_module
.
cuda
.
is_available
():
raise
ImportError
(
"CUDA is not available."
)
DEVICE
=
"cuda"
logger
.
info
(
"Using cupy for array operations (GPU mode)."
)
except
ImportError
as
e
:
logger
.
warning
(
f
"Failed to import cupy, falling back to numpy:
{
e
}
."
)
import
numpy
as
array_module
DEVICE
=
"cpu"
CACHE_SIZE_MAXIMUM
=
8
@
service
(
dynamo
=
{
...
...
@@ -36,7 +53,7 @@ logger = logging.getLogger(__name__)
resources
=
{
"gpu"
:
1
,
"cpu"
:
"10"
,
"memory"
:
"20Gi"
},
workers
=
1
,
)
class
EncodeWorker
:
class
Vllm
EncodeWorker
:
def
__init__
(
self
)
->
None
:
class_name
=
self
.
__class__
.
__name__
self
.
engine_args
=
parse_vllm_args
(
class_name
,
""
)
...
...
@@ -50,9 +67,50 @@ class EncodeWorker:
self
.
MODEL_ID
,
device_map
=
"auto"
,
torch_dtype
=
torch
.
float16
).
eval
()
self
.
_image_cache
:
dict
[
str
,
Image
.
Image
]
=
{}
self
.
_cache_queue
:
Queue
[
str
]
=
Queue
(
maxsize
=
CACHE_SIZE_MAXIMUM
)
@
endpoint
()
async
def
encode
(
self
,
request
:
EncodeRequest
)
->
AsyncIterator
[
EncodeResponse
]:
image
=
self
.
open_image
(
request
.
image_url
)
logger
.
debug
(
f
"Received encode request: {{ id:
{
request
.
request_id
}
, image_url: '
{
request
.
image_url
}
' }}."
)
request_id
=
request
.
request_id
image_url
=
request
.
image_url
.
lower
()
# The following steps encode the requested image and provided useful embeddings.
# 1. Open the image from the provided URL.
# 2. Process the image using the image processor.
# 3. Run the image through the vision model's vision tower.
# 4. Run the results of the vision tower through the multi-modal projector.
# 5. Create a descriptor for the embeddings.
# 6. Create a write operation using the serialized request and the descriptor.
# 7. Await for the write operation to complete.
# 8. Yield the encode response.
# Either retrieve the image from the cache or download it and then cache it.
if
request
.
image_url
in
self
.
_image_cache
:
image
=
self
.
_image_cache
[
image_url
]
logger
.
debug
(
f
"Image found in cache for request: {{ id:
{
request_id
}
, image_url: '
{
image_url
}
' }}."
)
else
:
image
=
self
.
open_image
(
image_url
)
logger
.
debug
(
f
"Downloading/opening image for request: {{ id:
{
request_id
}
, image_url: '
{
image_url
}
' }}."
)
# Cache the image for future use, and evict the oldest image if the cache is full.
if
self
.
_cache_queue
.
full
():
oldest_image_url
=
self
.
_cache_queue
.
get
()
del
self
.
_image_cache
[
oldest_image_url
]
self
.
_image_cache
[
request
.
image_url
]
=
image
self
.
_cache_queue
.
put
(
request
.
image_url
)
logger
.
debug
(
f
"Processing image for request: {{ id:
{
request_id
}
, image_url: '
{
image_url
}
' }}"
)
image_embeds
=
self
.
image_processor
(
images
=
image
,
return_tensors
=
"pt"
)
with
torch
.
no_grad
():
...
...
@@ -60,22 +118,56 @@ class EncodeWorker:
vision_outputs
=
self
.
vision_model
.
vision_tower
(
image_embeds
[
"pixel_values"
].
to
(
self
.
vision_model
.
device
)
)
logger
.
debug
(
"Vision model completed."
)
embeddings
=
vision_outputs
.
last_hidden_state
embeddings
=
self
.
vision_model
.
multi_modal_projector
(
embeddings
)
logger
.
debug
(
f
"Embeddings: {{ shape:
{
embeddings
.
shape
}
, dtype:
{
embeddings
.
dtype
}
, device:
{
embeddings
.
device
}
, ptr:
{
embeddings
.
data_ptr
()
}
, elements: {{ count:
{
embeddings
.
numel
()
}
, size:
{
embeddings
.
element_size
()
}
}} }}."
)
if
request
.
serialized_request
is
None
:
logger
.
error
(
f
"Request serialized_request is None for request: {{ id:
{
request_id
}
, image_url: '
{
image_url
}
' }}."
)
# Create a descriptor for the embeddings, this will register the memory with the connector (and the NIXL runtime).
descriptor
=
connect
.
Descriptor
(
embeddings
)
# Create a write operation using the serialized request and the descriptor.
# This will begin the RDMA transfer of the embeddings to the remote worker.
write_op
=
await
self
.
_connector
.
begin_write
(
descriptor
,
request
.
serialized_request
,
)
# Await for the write operation to complete.
# This will block until the data has been written to the remote worker or an error occurs.
await
write_op
.
wait_for_completion
()
image_features
=
vision_outputs
.
last_hidden_state
image_features
=
self
.
vision_model
.
multi_modal_projector
(
image_features
)
yield
EncodeResponse
(
image_features
=
image_features
.
tolist
()
request_id
=
request
.
request_id
,
).
model_dump_json
()
@
async_on_start
()
async
def
on_start
(
self
):
logger
.
info
(
"Startup started."
)
# Create and initialize a dynamo connector for this worker.
# We'll needs this to move data between this worker and remote workers efficiently.
self
.
_connector
=
connect
.
Connector
()
await
self
.
_connector
.
initialize
()
logger
.
info
(
"Startup completed."
)
def
open_image
(
self
,
image
:
str
)
->
Image
.
Image
:
# TODO: Have a seperate field for url and non url - and avoid auto detection
try
:
# Acquire the image and convert it to the format (RGB) the image processor model expects.
if
image
.
startswith
(
"http"
)
or
image
.
startswith
(
"https"
):
response
=
requests
.
get
(
image
)
image_data
=
Image
.
open
(
BytesIO
(
response
.
content
)).
convert
(
"RGB"
)
else
:
image_data
=
Image
.
open
(
image
).
convert
(
"RGB"
)
return
image_data
except
Exception
as
e
:
logger
.
error
(
f
"Error opening image:
{
e
}
"
)
raise
e
return
image_data
examples/multimodal/components/prefill_worker.py
View file @
75e774d4
...
...
@@ -20,8 +20,9 @@ import os
import
signal
import
sys
import
connect
import
torch
from
components.encode_worker
import
EncodeWorker
from
components.encode_worker
import
Vllm
EncodeWorker
from
pydantic
import
BaseModel
from
utils.logging
import
check_required_workers
from
utils.nixl
import
NixlMetadataStore
...
...
@@ -38,6 +39,11 @@ from dynamo.sdk import async_on_start, depends, dynamo_context, endpoint, servic
logger
=
logging
.
getLogger
(
__name__
)
# Constants for the shape and dtype of the embeddings tensor.
EMBEDDINGS_SHAPE
=
(
1
,
577
,
4096
)
EMBEDDINGS_DTYPE
=
torch
.
float16
EMBEDDINGS_DEVICE
=
"cuda"
class
RequestType
(
BaseModel
):
text
:
str
...
...
@@ -50,8 +56,8 @@ class RequestType(BaseModel):
resources
=
{
"gpu"
:
1
,
"cpu"
:
"10"
,
"memory"
:
"20Gi"
},
workers
=
1
,
)
class
PrefillWorker
:
encode_worker
=
depends
(
EncodeWorker
)
class
Vllm
PrefillWorker
:
encode_worker
=
depends
(
Vllm
EncodeWorker
)
def
__init__
(
self
):
class_name
=
self
.
__class__
.
__name__
...
...
@@ -95,7 +101,7 @@ class PrefillWorker:
raise
RuntimeError
(
"Failed to initialize engine client"
)
runtime
=
dynamo_context
[
"runtime"
]
enc_comp_ns
,
enc_comp_name
=
EncodeWorker
.
dynamo_address
()
# type: ignore
enc_comp_ns
,
enc_comp_name
=
Vllm
EncodeWorker
.
dynamo_address
()
# type: ignore
self
.
encode_worker_client
=
(
await
runtime
.
namespace
(
enc_comp_ns
)
.
component
(
enc_comp_name
)
...
...
@@ -103,6 +109,20 @@ class PrefillWorker:
.
client
()
)
self
.
_connector
=
connect
.
Connector
(
runtime
=
runtime
,
namespace
=
enc_comp_ns
)
await
self
.
_connector
.
initialize
()
# Create a longer-lived buffer for receiving the image embeddings.
embeddings
=
torch
.
empty
(
EMBEDDINGS_SHAPE
,
dtype
=
EMBEDDINGS_DTYPE
,
device
=
EMBEDDINGS_DEVICE
,
)
descriptor
=
connect
.
Descriptor
(
embeddings
)
# Register the descriptor w/ NIXL (this is optional, if not done here the connect subsytem will take care of this automatically).
descriptor
.
register_memory
(
self
.
_connector
)
self
.
_embeddings_descriptor
=
(
embeddings
,
descriptor
)
await
check_required_workers
(
self
.
encode_worker_client
,
self
.
min_workers
)
metadata
=
self
.
engine_client
.
nixl_metadata
...
...
@@ -119,19 +139,19 @@ class PrefillWorker:
sys
.
exit
(
1
)
task
.
add_done_callback
(
prefill_queue_handler_cb
)
logger
.
info
(
"
PrefillWorker initialized
"
)
logger
.
info
(
"
Initialization complete.
"
)
def
shutdown_vllm_engine
(
self
,
signum
,
frame
):
"""Shutdown the background loop"""
logger
.
info
(
f
"
Receiv
ed signal
{
signum
}
, shutting down
"
)
logger
.
info
(
f
"
Shutdown start
ed
,
signal
{
signum
}
received.
"
)
loop
=
asyncio
.
get_event_loop
()
try
:
self
.
engine_client
.
close
()
logger
.
info
(
"PrefillWorker shutdown complete"
)
except
Exception
as
e
:
logger
.
error
(
f
"Error during shutdown:
{
e
}
"
)
finally
:
loop
.
stop
()
logger
.
info
(
"Shutdown complete."
)
async
def
prefill_queue_handler
(
self
):
logger
.
info
(
"Prefill queue handler entered"
)
...
...
@@ -166,16 +186,42 @@ class PrefillWorker:
if
request
.
multimodal_data_source
[
"image_url"
]
is
None
:
raise
ValueError
(
"No image url provided for prefill request"
)
request_id
=
request
.
request_id
engine_id
=
request
.
engine_id
image_url
=
request
.
multimodal_data_source
[
"image_url"
]
logger
.
info
(
f
"Received prefill request {{ id:
{
request_id
}
, engine_id:
{
engine_id
}
, image_url: '
{
image_url
}
' }}."
)
# Extract the pre-allocated, reusable image embeddings tensor and its descriptor.
# Doing this avoids unnessesary memory de/registration with NIXL.
embeddings
,
descriptor
=
self
.
_embeddings_descriptor
# Create a new writable operation from the descriptor.
with
self
.
_connector
.
create_writable
(
descriptor
)
as
writable
:
# Extract serialized metadata about the operation from the writable operation,
# and use it to create a new EncodeRequest.
encode_generator
=
await
self
.
encode_worker_client
.
round_robin
(
EncodeRequest
(
image_url
=
request
.
multimodal_data_source
[
"image_url"
],
request_id
=
request_id
,
image_url
=
image_url
,
serialized_request
=
writable
.
to_serialized
(),
).
model_dump_json
()
)
async
for
encode_response
in
encode_generator
:
encode_output
=
EncodeResponse
.
model_validate_json
(
encode_response
.
data
())
image_features
=
torch
.
tensor
(
encode_output
.
image_features
,
device
=
"cpu"
,
dtype
=
torch
.
float16
encode_output
=
EncodeResponse
.
model_validate_json
(
encode_response
.
data
(),
)
logger
.
debug
(
f
"Received response: {{ id:
{
encode_output
.
request_id
}
}}."
)
# Wait for the write operation to complete.
# This will block until the write operation is complete.
# This await should be a no-op since we've already received a response from the encode worker.
await
writable
.
wait_for_completion
()
# At this point, the `embeddings` tensor is filled with the image embeddings from the remote encode worker.
sampling_params
=
request
.
sampling_params
sampling_params
.
max_tokens
=
1
...
...
@@ -184,26 +230,26 @@ class PrefillWorker:
remote_prefill_params
=
RemotePrefillParams
(
is_remote_decode
=
True
,
decode_block_ids
=
request
.
block_ids
,
decode_engine_id
=
request
.
engine_id
,
decode_engine_id
=
engine_id
,
decode_computed_block_ids
=
request
.
computed_block_ids
,
)
# TODO check if metadata has changed
# and reload - currently only loading once
if
request
.
engine_id
not
in
self
.
_loaded_metadata
:
if
engine_id
not
in
self
.
_loaded_metadata
:
remote_metadata
=
await
self
.
_metadata_store
.
get
(
request
.
engine_id
)
await
self
.
engine_client
.
add_remote_nixl_metadata
(
remote_metadata
)
logger
.
info
(
f
"Loaded nixl metadata from engine
{
request
.
engine_id
}
into "
f
"Loaded nixl metadata from engine
{
engine_id
}
into "
f
"engine
{
self
.
engine_client
.
nixl_metadata
.
engine_id
}
"
)
self
.
_loaded_metadata
.
add
(
request
.
engine_id
)
self
.
_loaded_metadata
.
add
(
engine_id
)
# To make sure the decode worker can pre-allocate the memory with the correct size for the prefill worker to transfer the kv cache,
# some placeholder dummy tokens were inserted based on the embedding size in the worker.py.
# The structure of the prompt is "\nUSER: <image> <dummy_tokens>\n<user_prompt>\nASSISTANT:", need to remove the dummy tokens after the image token.
IMAGE_TOKEN_ID
=
32000
embedding_size
=
image_feature
s
.
shape
[
1
]
embedding_size
=
embedding
s
.
shape
[
1
]
padding_size
=
embedding_size
-
1
image_token_index
=
request
.
prompt_token_ids
.
index
(
IMAGE_TOKEN_ID
)
dummy_token_index
=
image_token_index
+
1
...
...
@@ -213,10 +259,10 @@ class PrefillWorker:
)
async
for
_
in
self
.
engine_client
.
generate
(
request_id
=
request
.
request
_id
,
request_id
=
request_id
,
prompt
=
TokensPrompt
(
prompt_token_ids
=
prompt_token_ids
,
multi_modal_data
=
{
"image"
:
image_feature
s
},
multi_modal_data
=
{
"image"
:
embedding
s
},
),
sampling_params
=
sampling_params
,
remote_prefill_params
=
remote_prefill_params
,
...
...
examples/multimodal/components/processor.py
View file @
75e774d4
...
...
@@ -19,7 +19,7 @@ import uuid
from
enum
import
Enum
from
typing
import
AsyncIterator
,
Tuple
,
Union
from
components.worker
import
VllmWorker
from
components.
decode_
worker
import
Vllm
Decode
Worker
from
transformers
import
AutoTokenizer
from
utils.chat_processor
import
ChatProcessor
,
CompletionsProcessor
,
ProcessMixIn
from
utils.logging
import
check_required_workers
...
...
@@ -53,7 +53,7 @@ class Processor(ProcessMixIn):
vLLM pre and post processing
"""
worker
=
depends
(
VllmWorker
)
worker
=
depends
(
Vllm
Decode
Worker
)
def
__init__
(
self
):
class_name
=
self
.
__class__
.
__name__
...
...
@@ -83,7 +83,7 @@ class Processor(ProcessMixIn):
@
async_on_start
async
def
async_init
(
self
):
runtime
=
dynamo_context
[
"runtime"
]
comp_ns
,
comp_name
=
VllmWorker
.
dynamo_address
()
# type: ignore
comp_ns
,
comp_name
=
Vllm
Decode
Worker
.
dynamo_address
()
# type: ignore
self
.
worker_client
=
(
await
runtime
.
namespace
(
comp_ns
)
.
component
(
comp_name
)
...
...
examples/multimodal/configs/agg.yaml
View file @
75e774d4
...
...
@@ -21,7 +21,7 @@ Processor:
router
:
round-robin
common-configs
:
[
model
,
block-size
,
max-model-len
]
VllmWorker
:
Vllm
Decode
Worker
:
enforce-eager
:
true
max-num-batched-tokens
:
16384
enable-prefix-caching
:
true
...
...
@@ -33,7 +33,7 @@ VllmWorker:
gpu
:
1
common-configs
:
[
model
,
block-size
,
max-model-len
]
EncodeWorker
:
Vllm
EncodeWorker
:
tensor-parallel-size
:
1
router
:
random
ServiceArgs
:
...
...
examples/multimodal/configs/disagg.yaml
View file @
75e774d4
...
...
@@ -22,7 +22,7 @@ Processor:
router
:
round-robin
common-configs
:
[
model
,
block-size
]
VllmWorker
:
Vllm
Decode
Worker
:
remote-prefill
:
true
conditional-disagg
:
true
max-local-prefill-length
:
10
...
...
@@ -33,7 +33,7 @@ VllmWorker:
gpu
:
1
common-configs
:
[
model
,
block-size
,
max-model-len
,
kv-transfer-config
]
PrefillWorker
:
Vllm
PrefillWorker
:
max-num-batched-tokens
:
16384
ServiceArgs
:
workers
:
1
...
...
@@ -41,7 +41,7 @@ PrefillWorker:
gpu
:
1
common-configs
:
[
model
,
block-size
,
max-model-len
,
kv-transfer-config
]
EncodeWorker
:
Vllm
EncodeWorker
:
tensor-parallel-size
:
1
router
:
random
ServiceArgs
:
...
...
examples/multimodal/connect/__init__.py
0 → 100644
View file @
75e774d4
This diff is collapsed.
Click to expand it.
examples/multimodal/graphs/agg.py
View file @
75e774d4
...
...
@@ -6,6 +6,7 @@
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
...
...
@@ -13,9 +14,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from
components.encode_worker
import
EncodeWorker
from
components.decode_worker
import
VllmDecodeWorker
from
components.encode_worker
import
VllmEncodeWorker
from
components.frontend
import
Frontend
from
components.processor
import
Processor
from
components.worker
import
VllmWorker
Frontend
.
link
(
Processor
).
link
(
VllmWorker
).
link
(
EncodeWorker
)
Frontend
.
link
(
Processor
).
link
(
Vllm
Decode
Worker
).
link
(
Vllm
EncodeWorker
)
examples/multimodal/graphs/disagg.py
View file @
75e774d4
...
...
@@ -13,10 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from
components.encode_worker
import
EncodeWorker
from
components.decode_worker
import
VllmDecodeWorker
from
components.encode_worker
import
VllmEncodeWorker
from
components.frontend
import
Frontend
from
components.prefill_worker
import
PrefillWorker
from
components.prefill_worker
import
Vllm
PrefillWorker
from
components.processor
import
Processor
from
components.worker
import
VllmWorker
Frontend
.
link
(
Processor
).
link
(
VllmWorker
).
link
(
PrefillWorker
).
link
(
EncodeWorker
)
Frontend
.
link
(
Processor
).
link
(
VllmDecodeWorker
).
link
(
VllmPrefillWorker
).
link
(
VllmEncodeWorker
)
examples/multimodal/utils/protocol.py
View file @
75e774d4
...
...
@@ -17,6 +17,7 @@
import
json
from
typing
import
Any
,
List
,
Optional
import
connect
import
msgspec
from
pydantic
import
BaseModel
,
ConfigDict
,
field_validator
from
pydantic_core
import
core_schema
...
...
@@ -111,12 +112,13 @@ class EncodeRequest(BaseModel):
model_config
=
ConfigDict
(
arbitrary_types_allowed
=
True
)
image_url
:
str
request_id
:
str
serialized_request
:
Optional
[
connect
.
SerializedRequest
]
=
None
class
EncodeResponse
(
BaseModel
):
model_config
=
ConfigDict
(
arbitrary_types_allowed
=
True
)
image_features
:
List
[
List
[
List
[
float
]]]
request_id
:
str
class
MyRequestOutput
(
BaseModel
):
...
...
@@ -129,7 +131,6 @@ class MyRequestOutput(BaseModel):
"""
model_config
=
ConfigDict
(
arbitrary_types_allowed
=
True
)
request_id
:
str
prompt
:
Optional
[
str
]
=
None
prompt_token_ids
:
Optional
[
List
[
int
]]
=
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment