Unverified Commit 75e774d4 authored by J Wyman's avatar J Wyman Committed by GitHub
Browse files

feat: NIXL Based RDMA Support w/ Multimodal Example (#1060)

parent 9acaa8d1
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
exclude: ^(src/grpc_generated|.*\.patch$) exclude: ^(src/grpc_generated|.*\.patch$|.*/connect/.*\.py)
repos: repos:
- repo: https://github.com/timothycrosley/isort - repo: https://github.com/timothycrosley/isort
rev: 5.12.0 rev: 5.12.0
...@@ -82,4 +82,4 @@ repos: ...@@ -82,4 +82,4 @@ repos:
# NOTE: pyright may be able to find other classes of errors not covered above, # NOTE: pyright may be able to find other classes of errors not covered above,
# but would require some configuring and venv setup to properly eliminate noise # but would require some configuring and venv setup to properly eliminate noise
# and give it visiblity into all the local and third_party packages expected. # and give it visiblity into all the local and third_party packages expected.
\ No newline at end of file
...@@ -19,10 +19,11 @@ import os ...@@ -19,10 +19,11 @@ import os
import signal import signal
from typing import Optional from typing import Optional
import connect
import torch import torch
from components.disagg_router import PyDisaggregatedRouter from components.disagg_router import PyDisaggregatedRouter
from components.encode_worker import EncodeWorker from components.encode_worker import VllmEncodeWorker
from components.prefill_worker import PrefillWorker from components.prefill_worker import VllmPrefillWorker
from transformers import LlavaForConditionalGeneration from transformers import LlavaForConditionalGeneration
from utils.logging import check_required_workers from utils.logging import check_required_workers
from utils.nixl import NixlMetadataStore from utils.nixl import NixlMetadataStore
...@@ -53,11 +54,11 @@ logger = logging.getLogger(__name__) ...@@ -53,11 +54,11 @@ logger = logging.getLogger(__name__)
resources={"gpu": 1, "cpu": "10", "memory": "20Gi"}, resources={"gpu": 1, "cpu": "10", "memory": "20Gi"},
workers=1, workers=1,
) )
class VllmWorker: class VllmDecodeWorker:
# For disaggregated serving, we need to link the prefill worker to the vllm worker # For disaggregated serving, we need to link the prefill worker to the vllm worker
prefill_worker = depends(PrefillWorker) prefill_worker = depends(VllmPrefillWorker)
# For aggregated serving, we need to link the encode worker to the vllm worker. # For aggregated serving, we need to link the encode worker to the vllm worker.
encode_worker = depends(EncodeWorker) encode_worker = depends(VllmEncodeWorker)
def __init__(self): def __init__(self):
self.client = None self.client = None
...@@ -141,7 +142,11 @@ class VllmWorker: ...@@ -141,7 +142,11 @@ class VllmWorker:
vision_tower.vision_model.embeddings.position_embedding.num_embeddings vision_tower.vision_model.embeddings.position_embedding.num_embeddings
) )
else: else:
enc_comp_ns, enc_comp_name = EncodeWorker.dynamo_address() # type: ignore EMBEDDINGS_SHAPE = (1, 577, 4096)
EMBEDDINGS_DTYPE = torch.float16
EMBEDDINGS_DEVICE = "cuda"
enc_comp_ns, enc_comp_name = VllmEncodeWorker.dynamo_address() # type: ignore
self.encode_worker_client = ( self.encode_worker_client = (
await runtime.namespace(enc_comp_ns) await runtime.namespace(enc_comp_ns)
.component(enc_comp_name) .component(enc_comp_name)
...@@ -149,9 +154,22 @@ class VllmWorker: ...@@ -149,9 +154,22 @@ class VllmWorker:
.client() .client()
) )
self._connector = connect.Connector(runtime=runtime, namespace=enc_comp_ns)
await self._connector.initialize()
# Create a longer-lived buffer for receiving the image embeddings.
embeddings = torch.empty(
EMBEDDINGS_SHAPE, dtype=EMBEDDINGS_DTYPE, device=EMBEDDINGS_DEVICE
)
descriptor = connect.Descriptor(embeddings)
# Register the descriptor w/ NIXL (this is optional, if not done here the connect subsytem will take care of this automatically).
descriptor.register_memory(self._connector)
self._embeddings_descriptor = (embeddings, descriptor)
await check_required_workers(self.encode_worker_client, self.min_workers) await check_required_workers(self.encode_worker_client, self.min_workers)
self.disaggregated_router = None self.disaggregated_router = None
logger.info("VllmWorker has been initialized")
logger.info("Initialization complete.")
def shutdown_vllm_engine(self, signum, frame): def shutdown_vllm_engine(self, signum, frame):
"""Shutdown the background loop""" """Shutdown the background loop"""
...@@ -159,7 +177,7 @@ class VllmWorker: ...@@ -159,7 +177,7 @@ class VllmWorker:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
try: try:
self.engine_client.close() self.engine_client.close()
logger.info("VllmWorker shutdown complete") logger.info("Shutdown complete.")
except Exception as e: except Exception as e:
logger.error(f"Error during shutdown: {e}") logger.error(f"Error during shutdown: {e}")
finally: finally:
...@@ -177,8 +195,18 @@ class VllmWorker: ...@@ -177,8 +195,18 @@ class VllmWorker:
@endpoint() @endpoint()
async def generate(self, request: vLLMMultimodalRequest): async def generate(self, request: vLLMMultimodalRequest):
image_features = None request_id = request.request_id
image_url = request.image_url
logger.info(
f"Received multimodal request {{ id: {request_id}, image_url: '{image_url}' }}."
)
embeddings = None
if self.do_remote_prefill: if self.do_remote_prefill:
logger.debug(
f"Disaggregated: request {{ id: {request_id}, image_url: '{image_url}' }}"
" prefill worker will populate the decode model's key-value cache ahead of time;"
" no direct encode worker interaction required."
)
if self.disaggregated_router is not None: if self.disaggregated_router is not None:
async with PrefillQueue.get_instance( async with PrefillQueue.get_instance(
nats_server=self._prefill_queue_nats_server, nats_server=self._prefill_queue_nats_server,
...@@ -195,21 +223,21 @@ class VllmWorker: ...@@ -195,21 +223,21 @@ class VllmWorker:
disagg_router_decision = True disagg_router_decision = True
if self.do_remote_prefill and disagg_router_decision: if self.do_remote_prefill and disagg_router_decision:
logger.debug(
f"Prefilling remotely for request {{ id: {request_id}, image_url: '{image_url}' }} with length {len(request.engine_prompt['prompt_token_ids'])}"
)
remote_prefill_params = RemotePrefillParams( remote_prefill_params = RemotePrefillParams(
is_remote_prefill=True, is_remote_prefill=True,
remote_prefill_request_callback=self.get_remote_prefill_request_callback(), remote_prefill_request_callback=self.get_remote_prefill_request_callback(),
# Pass the image url as part of the RemotePrefillParams, which will be passed to the prefill worker via RemotePrefillRequest # Pass the image url as part of the RemotePrefillParams, which will be passed to the prefill worker via RemotePrefillRequest
multimodal_data_source={ multimodal_data_source={
"image_url": request.image_url, "image_url": image_url,
}, },
) )
logger.info(
f"Prefilling remotely for request {request.request_id} with length {len(request.engine_prompt['prompt_token_ids'])}"
)
else: else:
remote_prefill_params = None remote_prefill_params = None
logger.info( logger.debug(
f"Prefilling locally for request {request.request_id} with length {len(request.engine_prompt['prompt_token_ids'])}" f"Prefilling locally for request {{ id: {request_id}, image_url: '{image_url}' }} with length {len(request.engine_prompt['prompt_token_ids'])}"
) )
# The decode worker will pre-allocate the memory based on the prompt token length for the prefill worker to transfer the kv cache. # The decode worker will pre-allocate the memory based on the prompt token length for the prefill worker to transfer the kv cache.
...@@ -231,33 +259,61 @@ class VllmWorker: ...@@ -231,33 +259,61 @@ class VllmWorker:
) )
else: else:
# For aggregated serving, the vllm worker will directly send the encode request to the encode worker. logger.debug(
encode_generator = await self.encode_worker_client.round_robin( f"Aggregated: request {{ id: {request_id}, image_url: '{image_url}' }}"
EncodeRequest( " no prefill worker available, embeddings directly from encode worker."
image_url=request.image_url,
).model_dump_json()
) )
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Extract the pre-allocated, reusable image embeddings tensor and its descriptor.
async for encode_response in encode_generator: # Doing this avoids unnessesary memory de/registration with NIXL.
encode_output = EncodeResponse.model_validate_json( embeddings, descriptor = self._embeddings_descriptor
encode_response.data()
with self._connector.create_writable(descriptor) as writable:
# Extract serialized metadata about the operation from the writable operation,
# and use it to create a new EncodeRequest.
encode_request = EncodeRequest(
request_id=request_id,
image_url=image_url,
serialized_request=writable.to_serialized(),
) )
image_features = torch.tensor( logger.debug(f"Encode request: {encode_request.model_dump_json()}")
encode_output.image_features, device=device, dtype=torch.float16 encode_generator = await self.encode_worker_client.round_robin(
encode_request.model_dump_json()
) )
async for encode_response in encode_generator:
encode_output = EncodeResponse.model_validate_json(
encode_response.data()
)
logger.info(
f"Received response: {{ id: {encode_output.request_id} }}"
)
# Wait for the write operation to complete.
# This will block until the write operation is complete.
# This await should be a no-op since we've already received a response from the encode worker.
await writable.wait_for_completion()
# At this point, the `embeddings` tensor is filled with the image embeddings from the remote encode worker.
remote_prefill_params = None remote_prefill_params = None
logger.info( logger.info(
f"Prefilling locally for request {request.request_id} with length {len(request.engine_prompt['prompt_token_ids'])}" f"Prefilling locally for request {{ id: {request_id}, image_url: '{image_url}' }} with length {len(request.engine_prompt['prompt_token_ids'])}"
) )
prompt_ids = request.engine_prompt["prompt_token_ids"] prompt_ids = request.engine_prompt["prompt_token_ids"]
# rust HTTP requires Delta streaming # rust HTTP requires Delta streaming
request.sampling_params.output_kind = RequestOutputKind.DELTA request.sampling_params.output_kind = RequestOutputKind.DELTA
if image_features is not None: # When using aggregated serving, the encode worker will have provided the key-value cache updates via the prefill worker.
multi_modal_data = {"image": image_features} # When using disaggregated serving, the encode worker will have provided the key-value cache updates via the encode worker.
if embeddings is not None:
logger.debug(
"Aggregated: embedding data from encode worker provided via multi-modal data to decode model."
)
multi_modal_data = {"image": embeddings}
else: else:
logger.debug(
"Disaggregated: no embedding data required as prefill will have provided key-value cache updates via encode worker."
)
multi_modal_data = None multi_modal_data = None
async for response in self.engine_client.generate( async for response in self.engine_client.generate(
...@@ -269,6 +325,9 @@ class VllmWorker: ...@@ -269,6 +325,9 @@ class VllmWorker:
request_id=request.request_id, request_id=request.request_id,
remote_prefill_params=remote_prefill_params, remote_prefill_params=remote_prefill_params,
): ):
logger.debug(
f"Yeilding response {{ id: {response.request_id}, prompt: '{response.prompt}' }}"
)
yield MyRequestOutput( yield MyRequestOutput(
request_id=response.request_id, request_id=response.request_id,
prompt=response.prompt, prompt=response.prompt,
......
...@@ -15,8 +15,10 @@ ...@@ -15,8 +15,10 @@
import logging import logging
from io import BytesIO from io import BytesIO
from queue import Queue
from typing import AsyncIterator from typing import AsyncIterator
import connect
import requests import requests
import torch import torch
from PIL import Image from PIL import Image
...@@ -24,10 +26,25 @@ from transformers import AutoImageProcessor, LlavaForConditionalGeneration ...@@ -24,10 +26,25 @@ from transformers import AutoImageProcessor, LlavaForConditionalGeneration
from utils.protocol import EncodeRequest, EncodeResponse from utils.protocol import EncodeRequest, EncodeResponse
from utils.vllm import parse_vllm_args from utils.vllm import parse_vllm_args
from dynamo.sdk import endpoint, service from dynamo.sdk import async_on_start, endpoint, service
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
try:
import cupy as array_module
if not array_module.cuda.is_available():
raise ImportError("CUDA is not available.")
DEVICE = "cuda"
logger.info("Using cupy for array operations (GPU mode).")
except ImportError as e:
logger.warning(f"Failed to import cupy, falling back to numpy: {e}.")
import numpy as array_module
DEVICE = "cpu"
CACHE_SIZE_MAXIMUM = 8
@service( @service(
dynamo={ dynamo={
...@@ -36,7 +53,7 @@ logger = logging.getLogger(__name__) ...@@ -36,7 +53,7 @@ logger = logging.getLogger(__name__)
resources={"gpu": 1, "cpu": "10", "memory": "20Gi"}, resources={"gpu": 1, "cpu": "10", "memory": "20Gi"},
workers=1, workers=1,
) )
class EncodeWorker: class VllmEncodeWorker:
def __init__(self) -> None: def __init__(self) -> None:
class_name = self.__class__.__name__ class_name = self.__class__.__name__
self.engine_args = parse_vllm_args(class_name, "") self.engine_args = parse_vllm_args(class_name, "")
...@@ -50,9 +67,50 @@ class EncodeWorker: ...@@ -50,9 +67,50 @@ class EncodeWorker:
self.MODEL_ID, device_map="auto", torch_dtype=torch.float16 self.MODEL_ID, device_map="auto", torch_dtype=torch.float16
).eval() ).eval()
self._image_cache: dict[str, Image.Image] = {}
self._cache_queue: Queue[str] = Queue(maxsize=CACHE_SIZE_MAXIMUM)
@endpoint() @endpoint()
async def encode(self, request: EncodeRequest) -> AsyncIterator[EncodeResponse]: async def encode(self, request: EncodeRequest) -> AsyncIterator[EncodeResponse]:
image = self.open_image(request.image_url) logger.debug(
f"Received encode request: {{ id: {request.request_id}, image_url: '{request.image_url}' }}."
)
request_id = request.request_id
image_url = request.image_url.lower()
# The following steps encode the requested image and provided useful embeddings.
# 1. Open the image from the provided URL.
# 2. Process the image using the image processor.
# 3. Run the image through the vision model's vision tower.
# 4. Run the results of the vision tower through the multi-modal projector.
# 5. Create a descriptor for the embeddings.
# 6. Create a write operation using the serialized request and the descriptor.
# 7. Await for the write operation to complete.
# 8. Yield the encode response.
# Either retrieve the image from the cache or download it and then cache it.
if request.image_url in self._image_cache:
image = self._image_cache[image_url]
logger.debug(
f"Image found in cache for request: {{ id: {request_id}, image_url: '{image_url}' }}."
)
else:
image = self.open_image(image_url)
logger.debug(
f"Downloading/opening image for request: {{ id: {request_id}, image_url: '{image_url}' }}."
)
# Cache the image for future use, and evict the oldest image if the cache is full.
if self._cache_queue.full():
oldest_image_url = self._cache_queue.get()
del self._image_cache[oldest_image_url]
self._image_cache[request.image_url] = image
self._cache_queue.put(request.image_url)
logger.debug(
f"Processing image for request: {{ id: {request_id}, image_url: '{image_url}' }}"
)
image_embeds = self.image_processor(images=image, return_tensors="pt") image_embeds = self.image_processor(images=image, return_tensors="pt")
with torch.no_grad(): with torch.no_grad():
...@@ -60,22 +118,56 @@ class EncodeWorker: ...@@ -60,22 +118,56 @@ class EncodeWorker:
vision_outputs = self.vision_model.vision_tower( vision_outputs = self.vision_model.vision_tower(
image_embeds["pixel_values"].to(self.vision_model.device) image_embeds["pixel_values"].to(self.vision_model.device)
) )
logger.debug("Vision model completed.")
embeddings = vision_outputs.last_hidden_state
embeddings = self.vision_model.multi_modal_projector(embeddings)
logger.debug(
f"Embeddings: {{ shape: {embeddings.shape}, dtype: {embeddings.dtype}, device: {embeddings.device}, ptr: {embeddings.data_ptr()}, elements: {{ count: {embeddings.numel()}, size: {embeddings.element_size()} }} }}."
)
if request.serialized_request is None:
logger.error(
f"Request serialized_request is None for request: {{ id: {request_id}, image_url: '{image_url}' }}."
)
# Create a descriptor for the embeddings, this will register the memory with the connector (and the NIXL runtime).
descriptor = connect.Descriptor(embeddings)
# Create a write operation using the serialized request and the descriptor.
# This will begin the RDMA transfer of the embeddings to the remote worker.
write_op = await self._connector.begin_write(
descriptor,
request.serialized_request,
)
# Await for the write operation to complete.
# This will block until the data has been written to the remote worker or an error occurs.
await write_op.wait_for_completion()
image_features = vision_outputs.last_hidden_state
image_features = self.vision_model.multi_modal_projector(image_features)
yield EncodeResponse( yield EncodeResponse(
image_features=image_features.tolist() request_id=request.request_id,
).model_dump_json() ).model_dump_json()
@async_on_start()
async def on_start(self):
logger.info("Startup started.")
# Create and initialize a dynamo connector for this worker.
# We'll needs this to move data between this worker and remote workers efficiently.
self._connector = connect.Connector()
await self._connector.initialize()
logger.info("Startup completed.")
def open_image(self, image: str) -> Image.Image: def open_image(self, image: str) -> Image.Image:
# TODO: Have a seperate field for url and non url - and avoid auto detection # TODO: Have a seperate field for url and non url - and avoid auto detection
try: try:
# Acquire the image and convert it to the format (RGB) the image processor model expects.
if image.startswith("http") or image.startswith("https"): if image.startswith("http") or image.startswith("https"):
response = requests.get(image) response = requests.get(image)
image_data = Image.open(BytesIO(response.content)).convert("RGB") image_data = Image.open(BytesIO(response.content)).convert("RGB")
else: else:
image_data = Image.open(image).convert("RGB") image_data = Image.open(image).convert("RGB")
return image_data
except Exception as e: except Exception as e:
logger.error(f"Error opening image: {e}") logger.error(f"Error opening image: {e}")
raise e raise e
return image_data
...@@ -20,8 +20,9 @@ import os ...@@ -20,8 +20,9 @@ import os
import signal import signal
import sys import sys
import connect
import torch import torch
from components.encode_worker import EncodeWorker from components.encode_worker import VllmEncodeWorker
from pydantic import BaseModel from pydantic import BaseModel
from utils.logging import check_required_workers from utils.logging import check_required_workers
from utils.nixl import NixlMetadataStore from utils.nixl import NixlMetadataStore
...@@ -38,6 +39,11 @@ from dynamo.sdk import async_on_start, depends, dynamo_context, endpoint, servic ...@@ -38,6 +39,11 @@ from dynamo.sdk import async_on_start, depends, dynamo_context, endpoint, servic
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Constants for the shape and dtype of the embeddings tensor.
EMBEDDINGS_SHAPE = (1, 577, 4096)
EMBEDDINGS_DTYPE = torch.float16
EMBEDDINGS_DEVICE = "cuda"
class RequestType(BaseModel): class RequestType(BaseModel):
text: str text: str
...@@ -50,8 +56,8 @@ class RequestType(BaseModel): ...@@ -50,8 +56,8 @@ class RequestType(BaseModel):
resources={"gpu": 1, "cpu": "10", "memory": "20Gi"}, resources={"gpu": 1, "cpu": "10", "memory": "20Gi"},
workers=1, workers=1,
) )
class PrefillWorker: class VllmPrefillWorker:
encode_worker = depends(EncodeWorker) encode_worker = depends(VllmEncodeWorker)
def __init__(self): def __init__(self):
class_name = self.__class__.__name__ class_name = self.__class__.__name__
...@@ -95,7 +101,7 @@ class PrefillWorker: ...@@ -95,7 +101,7 @@ class PrefillWorker:
raise RuntimeError("Failed to initialize engine client") raise RuntimeError("Failed to initialize engine client")
runtime = dynamo_context["runtime"] runtime = dynamo_context["runtime"]
enc_comp_ns, enc_comp_name = EncodeWorker.dynamo_address() # type: ignore enc_comp_ns, enc_comp_name = VllmEncodeWorker.dynamo_address() # type: ignore
self.encode_worker_client = ( self.encode_worker_client = (
await runtime.namespace(enc_comp_ns) await runtime.namespace(enc_comp_ns)
.component(enc_comp_name) .component(enc_comp_name)
...@@ -103,6 +109,20 @@ class PrefillWorker: ...@@ -103,6 +109,20 @@ class PrefillWorker:
.client() .client()
) )
self._connector = connect.Connector(runtime=runtime, namespace=enc_comp_ns)
await self._connector.initialize()
# Create a longer-lived buffer for receiving the image embeddings.
embeddings = torch.empty(
EMBEDDINGS_SHAPE,
dtype=EMBEDDINGS_DTYPE,
device=EMBEDDINGS_DEVICE,
)
descriptor = connect.Descriptor(embeddings)
# Register the descriptor w/ NIXL (this is optional, if not done here the connect subsytem will take care of this automatically).
descriptor.register_memory(self._connector)
self._embeddings_descriptor = (embeddings, descriptor)
await check_required_workers(self.encode_worker_client, self.min_workers) await check_required_workers(self.encode_worker_client, self.min_workers)
metadata = self.engine_client.nixl_metadata metadata = self.engine_client.nixl_metadata
...@@ -119,19 +139,19 @@ class PrefillWorker: ...@@ -119,19 +139,19 @@ class PrefillWorker:
sys.exit(1) sys.exit(1)
task.add_done_callback(prefill_queue_handler_cb) task.add_done_callback(prefill_queue_handler_cb)
logger.info("PrefillWorker initialized") logger.info("Initialization complete.")
def shutdown_vllm_engine(self, signum, frame): def shutdown_vllm_engine(self, signum, frame):
"""Shutdown the background loop""" """Shutdown the background loop"""
logger.info(f"Received signal {signum}, shutting down") logger.info(f"Shutdown started, signal {signum} received.")
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
try: try:
self.engine_client.close() self.engine_client.close()
logger.info("PrefillWorker shutdown complete")
except Exception as e: except Exception as e:
logger.error(f"Error during shutdown: {e}") logger.error(f"Error during shutdown: {e}")
finally: finally:
loop.stop() loop.stop()
logger.info("Shutdown complete.")
async def prefill_queue_handler(self): async def prefill_queue_handler(self):
logger.info("Prefill queue handler entered") logger.info("Prefill queue handler entered")
...@@ -166,62 +186,88 @@ class PrefillWorker: ...@@ -166,62 +186,88 @@ class PrefillWorker:
if request.multimodal_data_source["image_url"] is None: if request.multimodal_data_source["image_url"] is None:
raise ValueError("No image url provided for prefill request") raise ValueError("No image url provided for prefill request")
encode_generator = await self.encode_worker_client.round_robin( request_id = request.request_id
EncodeRequest( engine_id = request.engine_id
image_url=request.multimodal_data_source["image_url"], image_url = request.multimodal_data_source["image_url"]
).model_dump_json()
logger.info(
f"Received prefill request {{ id: {request_id}, engine_id: {engine_id}, image_url: '{image_url}' }}."
) )
async for encode_response in encode_generator:
encode_output = EncodeResponse.model_validate_json(encode_response.data()) # Extract the pre-allocated, reusable image embeddings tensor and its descriptor.
image_features = torch.tensor( # Doing this avoids unnessesary memory de/registration with NIXL.
encode_output.image_features, device="cpu", dtype=torch.float16 embeddings, descriptor = self._embeddings_descriptor
# Create a new writable operation from the descriptor.
with self._connector.create_writable(descriptor) as writable:
# Extract serialized metadata about the operation from the writable operation,
# and use it to create a new EncodeRequest.
encode_generator = await self.encode_worker_client.round_robin(
EncodeRequest(
request_id=request_id,
image_url=image_url,
serialized_request=writable.to_serialized(),
).model_dump_json()
) )
async for encode_response in encode_generator:
encode_output = EncodeResponse.model_validate_json(
encode_response.data(),
)
logger.debug(
f"Received response: {{ id: {encode_output.request_id} }}."
)
sampling_params = request.sampling_params # Wait for the write operation to complete.
sampling_params.max_tokens = 1 # This will block until the write operation is complete.
sampling_params.min_tokens = 1 # This await should be a no-op since we've already received a response from the encode worker.
await writable.wait_for_completion()
# At this point, the `embeddings` tensor is filled with the image embeddings from the remote encode worker.
remote_prefill_params = RemotePrefillParams( sampling_params = request.sampling_params
is_remote_decode=True, sampling_params.max_tokens = 1
decode_block_ids=request.block_ids, sampling_params.min_tokens = 1
decode_engine_id=request.engine_id,
decode_computed_block_ids=request.computed_block_ids,
)
# TODO check if metadata has changed remote_prefill_params = RemotePrefillParams(
# and reload - currently only loading once is_remote_decode=True,
if request.engine_id not in self._loaded_metadata: decode_block_ids=request.block_ids,
remote_metadata = await self._metadata_store.get(request.engine_id) decode_engine_id=engine_id,
await self.engine_client.add_remote_nixl_metadata(remote_metadata) decode_computed_block_ids=request.computed_block_ids,
logger.info( )
f"Loaded nixl metadata from engine {request.engine_id} into "
f"engine {self.engine_client.nixl_metadata.engine_id}" # TODO check if metadata has changed
# and reload - currently only loading once
if engine_id not in self._loaded_metadata:
remote_metadata = await self._metadata_store.get(request.engine_id)
await self.engine_client.add_remote_nixl_metadata(remote_metadata)
logger.info(
f"Loaded nixl metadata from engine {engine_id} into "
f"engine {self.engine_client.nixl_metadata.engine_id}"
)
self._loaded_metadata.add(engine_id)
# To make sure the decode worker can pre-allocate the memory with the correct size for the prefill worker to transfer the kv cache,
# some placeholder dummy tokens were inserted based on the embedding size in the worker.py.
# The structure of the prompt is "\nUSER: <image> <dummy_tokens>\n<user_prompt>\nASSISTANT:", need to remove the dummy tokens after the image token.
IMAGE_TOKEN_ID = 32000
embedding_size = embeddings.shape[1]
padding_size = embedding_size - 1
image_token_index = request.prompt_token_ids.index(IMAGE_TOKEN_ID)
dummy_token_index = image_token_index + 1
prompt_token_ids = (
request.prompt_token_ids[:dummy_token_index]
+ request.prompt_token_ids[dummy_token_index + padding_size :]
) )
self._loaded_metadata.add(request.engine_id)
# To make sure the decode worker can pre-allocate the memory with the correct size for the prefill worker to transfer the kv cache,
# some placeholder dummy tokens were inserted based on the embedding size in the worker.py.
# The structure of the prompt is "\nUSER: <image> <dummy_tokens>\n<user_prompt>\nASSISTANT:", need to remove the dummy tokens after the image token.
IMAGE_TOKEN_ID = 32000
embedding_size = image_features.shape[1]
padding_size = embedding_size - 1
image_token_index = request.prompt_token_ids.index(IMAGE_TOKEN_ID)
dummy_token_index = image_token_index + 1
prompt_token_ids = (
request.prompt_token_ids[:dummy_token_index]
+ request.prompt_token_ids[dummy_token_index + padding_size :]
)
async for _ in self.engine_client.generate( async for _ in self.engine_client.generate(
request_id=request.request_id, request_id=request_id,
prompt=TokensPrompt( prompt=TokensPrompt(
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
multi_modal_data={"image": image_features}, multi_modal_data={"image": embeddings},
), ),
sampling_params=sampling_params, sampling_params=sampling_params,
remote_prefill_params=remote_prefill_params, remote_prefill_params=remote_prefill_params,
): ):
yield yield
@endpoint() @endpoint()
async def mock(self, req: RequestType): async def mock(self, req: RequestType):
......
...@@ -19,7 +19,7 @@ import uuid ...@@ -19,7 +19,7 @@ import uuid
from enum import Enum from enum import Enum
from typing import AsyncIterator, Tuple, Union from typing import AsyncIterator, Tuple, Union
from components.worker import VllmWorker from components.decode_worker import VllmDecodeWorker
from transformers import AutoTokenizer from transformers import AutoTokenizer
from utils.chat_processor import ChatProcessor, CompletionsProcessor, ProcessMixIn from utils.chat_processor import ChatProcessor, CompletionsProcessor, ProcessMixIn
from utils.logging import check_required_workers from utils.logging import check_required_workers
...@@ -53,7 +53,7 @@ class Processor(ProcessMixIn): ...@@ -53,7 +53,7 @@ class Processor(ProcessMixIn):
vLLM pre and post processing vLLM pre and post processing
""" """
worker = depends(VllmWorker) worker = depends(VllmDecodeWorker)
def __init__(self): def __init__(self):
class_name = self.__class__.__name__ class_name = self.__class__.__name__
...@@ -83,7 +83,7 @@ class Processor(ProcessMixIn): ...@@ -83,7 +83,7 @@ class Processor(ProcessMixIn):
@async_on_start @async_on_start
async def async_init(self): async def async_init(self):
runtime = dynamo_context["runtime"] runtime = dynamo_context["runtime"]
comp_ns, comp_name = VllmWorker.dynamo_address() # type: ignore comp_ns, comp_name = VllmDecodeWorker.dynamo_address() # type: ignore
self.worker_client = ( self.worker_client = (
await runtime.namespace(comp_ns) await runtime.namespace(comp_ns)
.component(comp_name) .component(comp_name)
......
...@@ -21,7 +21,7 @@ Processor: ...@@ -21,7 +21,7 @@ Processor:
router: round-robin router: round-robin
common-configs: [model, block-size, max-model-len] common-configs: [model, block-size, max-model-len]
VllmWorker: VllmDecodeWorker:
enforce-eager: true enforce-eager: true
max-num-batched-tokens: 16384 max-num-batched-tokens: 16384
enable-prefix-caching: true enable-prefix-caching: true
...@@ -33,7 +33,7 @@ VllmWorker: ...@@ -33,7 +33,7 @@ VllmWorker:
gpu: 1 gpu: 1
common-configs: [model, block-size, max-model-len] common-configs: [model, block-size, max-model-len]
EncodeWorker: VllmEncodeWorker:
tensor-parallel-size: 1 tensor-parallel-size: 1
router: random router: random
ServiceArgs: ServiceArgs:
......
...@@ -22,7 +22,7 @@ Processor: ...@@ -22,7 +22,7 @@ Processor:
router: round-robin router: round-robin
common-configs: [model, block-size] common-configs: [model, block-size]
VllmWorker: VllmDecodeWorker:
remote-prefill: true remote-prefill: true
conditional-disagg: true conditional-disagg: true
max-local-prefill-length: 10 max-local-prefill-length: 10
...@@ -33,7 +33,7 @@ VllmWorker: ...@@ -33,7 +33,7 @@ VllmWorker:
gpu: 1 gpu: 1
common-configs: [model, block-size, max-model-len, kv-transfer-config] common-configs: [model, block-size, max-model-len, kv-transfer-config]
PrefillWorker: VllmPrefillWorker:
max-num-batched-tokens: 16384 max-num-batched-tokens: 16384
ServiceArgs: ServiceArgs:
workers: 1 workers: 1
...@@ -41,7 +41,7 @@ PrefillWorker: ...@@ -41,7 +41,7 @@ PrefillWorker:
gpu: 1 gpu: 1
common-configs: [model, block-size, max-model-len, kv-transfer-config] common-configs: [model, block-size, max-model-len, kv-transfer-config]
EncodeWorker: VllmEncodeWorker:
tensor-parallel-size: 1 tensor-parallel-size: 1
router: random router: random
ServiceArgs: ServiceArgs:
......
This diff is collapsed.
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
...@@ -13,9 +14,9 @@ ...@@ -13,9 +14,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from components.encode_worker import EncodeWorker from components.decode_worker import VllmDecodeWorker
from components.encode_worker import VllmEncodeWorker
from components.frontend import Frontend from components.frontend import Frontend
from components.processor import Processor from components.processor import Processor
from components.worker import VllmWorker
Frontend.link(Processor).link(VllmWorker).link(EncodeWorker) Frontend.link(Processor).link(VllmDecodeWorker).link(VllmEncodeWorker)
...@@ -13,10 +13,12 @@ ...@@ -13,10 +13,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from components.encode_worker import EncodeWorker from components.decode_worker import VllmDecodeWorker
from components.encode_worker import VllmEncodeWorker
from components.frontend import Frontend from components.frontend import Frontend
from components.prefill_worker import PrefillWorker from components.prefill_worker import VllmPrefillWorker
from components.processor import Processor from components.processor import Processor
from components.worker import VllmWorker
Frontend.link(Processor).link(VllmWorker).link(PrefillWorker).link(EncodeWorker) Frontend.link(Processor).link(VllmDecodeWorker).link(VllmPrefillWorker).link(
VllmEncodeWorker
)
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
import json import json
from typing import Any, List, Optional from typing import Any, List, Optional
import connect
import msgspec import msgspec
from pydantic import BaseModel, ConfigDict, field_validator from pydantic import BaseModel, ConfigDict, field_validator
from pydantic_core import core_schema from pydantic_core import core_schema
...@@ -111,12 +112,13 @@ class EncodeRequest(BaseModel): ...@@ -111,12 +112,13 @@ class EncodeRequest(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True) model_config = ConfigDict(arbitrary_types_allowed=True)
image_url: str image_url: str
request_id: str
serialized_request: Optional[connect.SerializedRequest] = None
class EncodeResponse(BaseModel): class EncodeResponse(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True) model_config = ConfigDict(arbitrary_types_allowed=True)
request_id: str
image_features: List[List[List[float]]]
class MyRequestOutput(BaseModel): class MyRequestOutput(BaseModel):
...@@ -129,7 +131,6 @@ class MyRequestOutput(BaseModel): ...@@ -129,7 +131,6 @@ class MyRequestOutput(BaseModel):
""" """
model_config = ConfigDict(arbitrary_types_allowed=True) model_config = ConfigDict(arbitrary_types_allowed=True)
request_id: str request_id: str
prompt: Optional[str] = None prompt: Optional[str] = None
prompt_token_ids: Optional[List[int]] = None prompt_token_ids: Optional[List[int]] = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment