Unverified commit 9bff03f2, authored by milesial and committed by GitHub
Browse files

feat: vLLM backend with frontend media decoding (#5781)


Signed-off-by: Alexandre Milesi <milesial@users.noreply.github.com>
parent f70dd663
...@@ -2354,6 +2354,7 @@ dependencies = [ ...@@ -2354,6 +2354,7 @@ dependencies = [
"erased-serde", "erased-serde",
"etcd-client", "etcd-client",
"ffmpeg-next", "ffmpeg-next",
"flate2",
"futures", "futures",
"futures-util", "futures-util",
"galil-seiferas", "galil-seiferas",
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
import time
import uuid
from typing import Any, Dict, Tuple
import numpy as np
import torch
from dynamo import nixl_connect
from dynamo.nixl_connect import OperationKind, RdmaMetadata, SerializedDescriptor
logger = logging.getLogger(__name__)
async def read_decoded_media_via_nixl(
    connector: nixl_connect.Connector,
    decoded_meta: Dict[str, Any],
    return_metadata: bool = False,
) -> np.ndarray | Tuple[np.ndarray, Dict[str, Any]]:
    """
    Read pre-decoded media data via NIXL RDMA transfer, into a CPU numpy array.

    Args:
        connector: Initialized NIXL connector for RDMA operations.
        decoded_meta: Metadata dict from the frontend. Required keys are
            ``nixl_metadata``, ``shape`` and ``nixl_descriptor`` (itself a dict
            with ``addr`` and ``size``, plus optional ``mem_type`` and
            ``device_id``). Optional keys are ``dtype`` (name of a torch
            dtype, defaults to ``"uint8"``) and ``metadata``.
        return_metadata: When True, also return ``decoded_meta.get("metadata")``.

    Returns:
        np.ndarray containing the transferred media data, or a tuple of that
        array and the media metadata (``None`` when the frontend sent none)
        if ``return_metadata`` is True.

    Raises:
        KeyError: If a required key is missing from ``decoded_meta`` or from
            its ``nixl_descriptor``.
        ValueError: If ``dtype`` does not name a torch dtype.
    """
    serialized_nixl_metadata = decoded_meta["nixl_metadata"]
    descriptor = decoded_meta["nixl_descriptor"]
    shape = decoded_meta["shape"]

    # Resolve the dtype before touching NIXL so that malformed metadata fails
    # fast with a clear message. An explicit None is treated like a missing key.
    dtype_str = str(decoded_meta.get("dtype") or "uint8").lower()
    dtype = getattr(torch, dtype_str, None)
    if not isinstance(dtype, torch.dtype):
        # getattr alone would also accept non-dtype attributes such as "nn".
        raise ValueError(f"Unsupported media dtype from frontend: {dtype_str!r}")

    # The remote buffer lives in host memory unless the descriptor says otherwise.
    if descriptor.get("mem_type", "dram").lower() == "dram":
        remote_device = "cpu"
    else:
        remote_device = f"cuda:{descriptor.get('device_id', 0)}"

    rdma_metadata = RdmaMetadata(
        descriptors=[
            SerializedDescriptor(
                device=remote_device,
                ptr=descriptor["addr"],
                size=descriptor["size"],
            )
        ],
        nixl_metadata=serialized_nixl_metadata,
        notification_key=str(uuid.uuid4()),
        operation_kind=int(OperationKind.READ),
    )

    # Create empty tensor to receive RDMA data
    alloc_start = time.perf_counter()
    tensor = torch.empty(shape, dtype=dtype)
    alloc_end = time.perf_counter()
    local_descriptor = nixl_connect.Descriptor(tensor)

    read_start = time.perf_counter()
    read_op = await connector.begin_read(rdma_metadata, local_descriptor)
    await read_op.wait_for_completion()
    read_end = time.perf_counter()

    # Lazy %-formatting: the message is only built when DEBUG is enabled.
    logger.debug(
        "Loaded media via NIXL RDMA: shape=%s, read_time=%.4fs, alloc_time=%.6fs",
        shape,
        read_end - read_start,
        alloc_end - alloc_start,
    )

    array = tensor.numpy()  # zero-copy
    if return_metadata:
        return array, decoded_meta.get("metadata")
    return array
...@@ -71,6 +71,7 @@ class Config: ...@@ -71,6 +71,7 @@ class Config:
enable_multimodal: bool = False enable_multimodal: bool = False
multimodal_encode_prefill_worker: bool = False multimodal_encode_prefill_worker: bool = False
mm_prompt_template: str = "USER: <image>\n<prompt> ASSISTANT:" mm_prompt_template: str = "USER: <image>\n<prompt> ASSISTANT:"
frontend_decoding: bool = False
# vLLM-native encoder worker (ECConnector mode) # vLLM-native encoder worker (ECConnector mode)
vllm_native_encoder_worker: bool = False vllm_native_encoder_worker: bool = False
...@@ -217,6 +218,15 @@ def parse_args() -> Config: ...@@ -217,6 +218,15 @@ def parse_args() -> Config:
"'USER: <image> please describe the image ASSISTANT:'." "'USER: <image> please describe the image ASSISTANT:'."
), ),
) )
parser.add_argument(
"--frontend-decoding",
action="store_true",
help=(
"Enable frontend decoding of multimodal images. "
"When enabled, images are decoded in the Rust frontend and transferred to the backend via NIXL RDMA. "
"Without this flag, images are decoded in the Python backend (default behavior)."
),
)
parser.add_argument( parser.add_argument(
"--vllm-native-encoder-worker", "--vllm-native-encoder-worker",
action="store_true", action="store_true",
...@@ -402,6 +412,7 @@ def parse_args() -> Config: ...@@ -402,6 +412,7 @@ def parse_args() -> Config:
config.multimodal_encode_prefill_worker = args.multimodal_encode_prefill_worker config.multimodal_encode_prefill_worker = args.multimodal_encode_prefill_worker
config.enable_multimodal = args.enable_multimodal config.enable_multimodal = args.enable_multimodal
config.mm_prompt_template = args.mm_prompt_template config.mm_prompt_template = args.mm_prompt_template
config.frontend_decoding = args.frontend_decoding
config.vllm_native_encoder_worker = args.vllm_native_encoder_worker config.vllm_native_encoder_worker = args.vllm_native_encoder_worker
config.ec_connector_backend = args.ec_connector_backend config.ec_connector_backend = args.ec_connector_backend
config.ec_storage_path = args.ec_storage_path config.ec_storage_path = args.ec_storage_path
......
...@@ -21,7 +21,9 @@ from vllm.outputs import RequestOutput ...@@ -21,7 +21,9 @@ from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams, StructuredOutputsParams from vllm.sampling_params import SamplingParams, StructuredOutputsParams
from vllm.v1.engine.exceptions import EngineDeadError from vllm.v1.engine.exceptions import EngineDeadError
import dynamo.nixl_connect as nixl_connect
from dynamo.common.utils.input_params import InputParamManager from dynamo.common.utils.input_params import InputParamManager
from dynamo.common.utils.media_nixl import read_decoded_media_via_nixl
from dynamo.common.utils.otel_tracing import build_trace_headers from dynamo.common.utils.otel_tracing import build_trace_headers
from dynamo.llm import ( from dynamo.llm import (
ModelInput, ModelInput,
...@@ -244,6 +246,7 @@ class BaseWorkerHandler(ABC): ...@@ -244,6 +246,7 @@ class BaseWorkerHandler(ABC):
config=None, config=None,
use_vllm_tokenizer: bool = False, use_vllm_tokenizer: bool = False,
shutdown_event: asyncio.Event | None = None, shutdown_event: asyncio.Event | None = None,
enable_frontend_decoding: bool = False,
): ):
self.runtime = runtime self.runtime = runtime
self.component = component self.component = component
...@@ -257,6 +260,10 @@ class BaseWorkerHandler(ABC): ...@@ -257,6 +260,10 @@ class BaseWorkerHandler(ABC):
self.temp_dirs: list[tempfile.TemporaryDirectory] = [] self.temp_dirs: list[tempfile.TemporaryDirectory] = []
self.model_max_len = model_max_len self.model_max_len = model_max_len
self.enable_multimodal = enable_multimodal self.enable_multimodal = enable_multimodal
self.enable_frontend_decoding = enable_frontend_decoding
# NIXL connector for frontend decoding - lazy initialized
self._nixl_connector = None
self._nixl_connector_lock = asyncio.Lock()
# LoRA tracking # LoRA tracking
self.lora_id_for_name: dict[str, int] = {} self.lora_id_for_name: dict[str, int] = {}
self.lora_name_to_path: dict[str, str] = {} self.lora_name_to_path: dict[str, str] = {}
...@@ -879,6 +886,10 @@ class BaseWorkerHandler(ABC): ...@@ -879,6 +886,10 @@ class BaseWorkerHandler(ABC):
""" """
Load a batch of images from multimodal data items. Load a batch of images from multimodal data items.
Supports two paths:
1. Url variant: Download and decode image from URL (default)
2. Decoded variant: Read pre-decoded image via NIXL RDMA (requires --frontend-decoding)
Args: Args:
image_mm_items: List of multimodal data items for images image_mm_items: List of multimodal data items for images
Returns: Returns:
...@@ -887,25 +898,41 @@ class BaseWorkerHandler(ABC): ...@@ -887,25 +898,41 @@ class BaseWorkerHandler(ABC):
Exception: If any image fails to load Exception: If any image fails to load
""" """
image_futures = [] image_futures = []
for item in image_mm_items: for item in image_mm_items:
if isinstance(item, dict) and URL_VARIANT_KEY in item: if isinstance(item, dict) and URL_VARIANT_KEY in item:
# URL path: download and decode in Python backend
url = item[URL_VARIANT_KEY] url = item[URL_VARIANT_KEY]
image_futures.append(self.image_loader.load_image(url)) image_futures.append(self.image_loader.load_image(url))
logger.debug(f"Preparing to load image from URL: {url[:80]}...") logger.debug(f"Preparing to load image from URL: {url[:80]}...")
elif isinstance(item, dict) and DECODED_VARIANT_KEY in item: elif isinstance(item, dict) and DECODED_VARIANT_KEY in item:
logger.warning( if self.enable_frontend_decoding:
"Decoded multimodal data not yet supported in standard worker" async with self._nixl_connector_lock:
) if self._nixl_connector is None:
self._nixl_connector = nixl_connect.Connector()
await self._nixl_connector.initialize()
metadata = item[DECODED_VARIANT_KEY]
image_futures.append(
read_decoded_media_via_nixl(self._nixl_connector, metadata)
)
else:
logger.error(
"Received Decoded multimodal data but --frontend-decoding not enabled. "
"Use --frontend-decoding flag to enable NIXL RDMA image transfer."
)
raise ValueError("Could not load decoded media from frontend")
# Process images in parallel
results = await asyncio.gather(*image_futures, return_exceptions=True) results = await asyncio.gather(*image_futures, return_exceptions=True)
loaded_images = [] loaded_images = []
collective_exceptions = "" collective_exceptions = ""
for i, result in enumerate(results): for media_item, result in zip(image_mm_items, results):
if isinstance(result, Exception): if isinstance(result, Exception):
url = image_mm_items[i].get(URL_VARIANT_KEY, "unknown") source = media_item.get(URL_VARIANT_KEY, "decoded")
logger.error(f"Failed to load image from {url[:80]}...: {result}") logger.error(f"Failed to load image from {source[:80]}...: {result}")
collective_exceptions += ( collective_exceptions += (
f"Failed to load image from {url[:80]}...: {result}\n" f"Failed to load image from {source[:80]}...: {result}\n"
) )
continue continue
loaded_images.append(result) loaded_images.append(result)
...@@ -1238,6 +1265,7 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -1238,6 +1265,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
config=None, config=None,
use_vllm_tokenizer: bool = False, use_vllm_tokenizer: bool = False,
shutdown_event: asyncio.Event | None = None, shutdown_event: asyncio.Event | None = None,
enable_frontend_decoding: bool = False,
): ):
super().__init__( super().__init__(
runtime, runtime,
...@@ -1250,6 +1278,7 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -1250,6 +1278,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
config, config,
use_vllm_tokenizer, use_vllm_tokenizer,
shutdown_event, shutdown_event,
enable_frontend_decoding,
) )
async def generate(self, request, context): async def generate(self, request, context):
...@@ -1451,6 +1480,7 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -1451,6 +1480,7 @@ class PrefillWorkerHandler(BaseWorkerHandler):
config=None, config=None,
use_vllm_tokenizer: bool = False, use_vllm_tokenizer: bool = False,
shutdown_event: asyncio.Event | None = None, shutdown_event: asyncio.Event | None = None,
enable_frontend_decoding: bool = False,
): ):
super().__init__( super().__init__(
runtime, runtime,
...@@ -1463,6 +1493,7 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -1463,6 +1493,7 @@ class PrefillWorkerHandler(BaseWorkerHandler):
config, config,
use_vllm_tokenizer, use_vllm_tokenizer,
shutdown_event, shutdown_event,
enable_frontend_decoding,
) )
async def generate(self, request, context): async def generate(self, request, context):
......
...@@ -27,6 +27,17 @@ from dynamo.llm import ( ...@@ -27,6 +27,17 @@ from dynamo.llm import (
fetch_llm, fetch_llm,
register_llm, register_llm,
) )
# Optional imports for frontend decoding support
try:
from dynamo.llm import MediaDecoder, MediaFetcher
MEDIA_DECODER_AVAILABLE = True
except ImportError:
MediaDecoder = None
MediaFetcher = None
MEDIA_DECODER_AVAILABLE = False
from dynamo.runtime import DistributedRuntime from dynamo.runtime import DistributedRuntime
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.vllm.multimodal_handlers import ( from dynamo.vllm.multimodal_handlers import (
...@@ -407,6 +418,23 @@ async def register_vllm_model( ...@@ -407,6 +418,23 @@ async def register_vllm_model(
data_parallel_size = getattr(vllm_config.parallel_config, "data_parallel_size", 1) data_parallel_size = getattr(vllm_config.parallel_config, "data_parallel_size", 1)
runtime_config.data_parallel_size = data_parallel_size runtime_config.data_parallel_size = data_parallel_size
# Configure media decoder for frontend image decoding when enabled
# This enables frontend to decode images and transfer via NIXL RDMA
media_decoder = None
media_fetcher = None
if config.frontend_decoding:
if not MEDIA_DECODER_AVAILABLE:
raise RuntimeError(
"--frontend-decoding requires MediaDecoder support. "
"Ensure dynamo.llm module includes MediaDecoder and MediaFetcher."
)
media_decoder = MediaDecoder()
media_decoder.enable_image({"limits": {"max_alloc": 128 * 1024 * 1024}})
# media_decoder.enable_video({})
media_fetcher = MediaFetcher()
media_fetcher.timeout_ms(30000)
await register_llm( await register_llm(
model_input, model_input,
model_type, model_type,
...@@ -417,6 +445,8 @@ async def register_vllm_model( ...@@ -417,6 +445,8 @@ async def register_vllm_model(
migration_limit=migration_limit, migration_limit=migration_limit,
runtime_config=runtime_config, runtime_config=runtime_config,
custom_template_path=config.custom_jinja_template, custom_template_path=config.custom_jinja_template,
media_decoder=media_decoder,
media_fetcher=media_fetcher,
) )
...@@ -449,6 +479,7 @@ async def init_prefill( ...@@ -449,6 +479,7 @@ async def init_prefill(
config=config, config=config,
use_vllm_tokenizer=config.use_vllm_tokenizer, use_vllm_tokenizer=config.use_vllm_tokenizer,
shutdown_event=shutdown_event, shutdown_event=shutdown_event,
enable_frontend_decoding=config.frontend_decoding,
) )
handler.add_temp_dir(prometheus_temp_dir) handler.add_temp_dir(prometheus_temp_dir)
...@@ -577,6 +608,7 @@ async def init( ...@@ -577,6 +608,7 @@ async def init(
config=config, config=config,
use_vllm_tokenizer=config.use_vllm_tokenizer, use_vllm_tokenizer=config.use_vllm_tokenizer,
shutdown_event=shutdown_event, shutdown_event=shutdown_event,
enable_frontend_decoding=config.frontend_decoding,
) )
handler.add_temp_dir(prometheus_temp_dir) handler.add_temp_dir(prometheus_temp_dir)
......
...@@ -1662,6 +1662,7 @@ dependencies = [ ...@@ -1662,6 +1662,7 @@ dependencies = [
"erased-serde", "erased-serde",
"etcd-client", "etcd-client",
"ffmpeg-next", "ffmpeg-next",
"flate2",
"futures", "futures",
"futures-util", "futures-util",
"galil-seiferas", "galil-seiferas",
......
...@@ -441,10 +441,13 @@ class ActiveOperation(AbstractOperation): ...@@ -441,10 +441,13 @@ class ActiveOperation(AbstractOperation):
self._status = OperationStatus.CANCELLED self._status = OperationStatus.CANCELLED
self._xfer_hndl = None self._xfer_hndl = None
async def _wait_for_completion_(self) -> None: async def _wait_for_completion_(
self, min_poll_ms=5, max_poll_ms=100, backoff_factor=1.5
) -> None:
# Loop until the operation is no longer in progress (or "initialized"), # Loop until the operation is no longer in progress (or "initialized"),
# yielding control to the event loop to allow other operations to run. # yielding control to the event loop to allow other operations to run.
iteration_count = 0 iteration_count = 0
sleep_time = min_poll_ms
while True: while True:
if iteration_count & 10 == 0: if iteration_count & 10 == 0:
logger.debug( logger.debug(
...@@ -452,10 +455,9 @@ class ActiveOperation(AbstractOperation): ...@@ -452,10 +455,9 @@ class ActiveOperation(AbstractOperation):
) )
match self.status: match self.status:
# "in progress" or "initialized" means the operation is ongoing. # "in progress" or "initialized" means the operation is ongoing.
case OperationStatus.INITIALIZED: case OperationStatus.INITIALIZED | OperationStatus.IN_PROGRESS:
await asyncio.sleep(0.1) await asyncio.sleep(sleep_time / 1000)
case OperationStatus.IN_PROGRESS: sleep_time = min(sleep_time * backoff_factor, max_poll_ms)
await asyncio.sleep(0.1)
# Any other state indicates completion or error. # Any other state indicates completion or error.
case _: case _:
return return
...@@ -1371,16 +1373,18 @@ class PassiveOperation(AbstractOperation): ...@@ -1371,16 +1373,18 @@ class PassiveOperation(AbstractOperation):
f")" f")"
) )
async def _wait_for_completion_(self) -> None: async def _wait_for_completion_(
self, min_poll_ms=5, max_poll_ms=100, backoff_factor=1.5
) -> None:
# Loop until the operation is no longer in progress (or "initialized"), # Loop until the operation is no longer in progress (or "initialized"),
# yielding control to the event loop to allow other operations to run. # yielding control to the event loop to allow other operations to run.
sleep_time = min_poll_ms
while True: while True:
match self.status: match self.status:
# "in progress" or "initialized" means the operation is ongoing. # "in progress" or "initialized" means the operation is ongoing.
case OperationStatus.INITIALIZED: case OperationStatus.INITIALIZED | OperationStatus.IN_PROGRESS:
await asyncio.sleep(0.1) await asyncio.sleep(sleep_time / 1000)
case OperationStatus.IN_PROGRESS: sleep_time = min(sleep_time * backoff_factor, max_poll_ms)
await asyncio.sleep(0.1)
# Any other state indicates completion or error. # Any other state indicates completion or error.
case _: case _:
return return
......
...@@ -25,7 +25,7 @@ block-manager = ["dep:nixl-sys", "dep:cudarc", "dep:nix", "dep:aligned-vec", "de ...@@ -25,7 +25,7 @@ block-manager = ["dep:nixl-sys", "dep:cudarc", "dep:nix", "dep:aligned-vec", "de
block-manager-bench = ["block-manager", "testing-full", "dep:clap", "dep:indicatif"] block-manager-bench = ["block-manager", "testing-full", "dep:clap", "dep:indicatif"]
cuda = ["dep:cudarc"] cuda = ["dep:cudarc"]
integration = ["dynamo-runtime/integration"] integration = ["dynamo-runtime/integration"]
media-nixl = ["dep:nixl-sys", "dep:dynamo-memory"] media-nixl = ["dep:nixl-sys", "dep:dynamo-memory", "dep:flate2"]
media-ffmpeg = ["dep:video-rs", "dep:ffmpeg-next", "dep:memfile", "media-nixl"] media-ffmpeg = ["dep:video-rs", "dep:ffmpeg-next", "dep:memfile", "media-nixl"]
kv-router-stress = ["dep:clap", "dep:indicatif"] kv-router-stress = ["dep:clap", "dep:indicatif"]
...@@ -110,6 +110,9 @@ nixl-sys = { version = "=0.9.0", optional = true } ...@@ -110,6 +110,9 @@ nixl-sys = { version = "=0.9.0", optional = true }
cudarc = { workspace = true, optional = true } cudarc = { workspace = true, optional = true }
nix = { version = "0.26", optional = true } nix = { version = "0.26", optional = true }
# media-nixl (zlib compression for NIXL metadata)
flate2 = { version = "1", optional = true }
# block_manager_bench # block_manager_bench
clap = { version = "4.5.49", features = ["derive"], optional = true } clap = { version = "4.5.49", features = ["derive"], optional = true }
indicatif = { version = "0.18.0", optional = true } indicatif = { version = "0.18.0", optional = true }
......
...@@ -10,6 +10,8 @@ use { ...@@ -10,6 +10,8 @@ use {
base64::{Engine as _, engine::general_purpose}, base64::{Engine as _, engine::general_purpose},
dynamo_memory::SystemStorage, dynamo_memory::SystemStorage,
dynamo_memory::nixl::{self, NixlAgent, NixlDescriptor, RegisteredView}, dynamo_memory::nixl::{self, NixlAgent, NixlDescriptor, RegisteredView},
flate2::{Compression, write::ZlibEncoder},
std::io::Write,
std::sync::Arc, std::sync::Arc,
}; };
...@@ -108,7 +110,8 @@ impl<D: Dimension> TryFrom<ArrayBase<OwnedRepr<u8>, D>> for DecodedMediaData { ...@@ -108,7 +110,8 @@ impl<D: Dimension> TryFrom<ArrayBase<OwnedRepr<u8>, D>> for DecodedMediaData {
} }
// Get NIXL metadata for a descriptor // Get NIXL metadata for a descriptor
// Avoids cross-request leak possibility and reduces metadata size // Returns zlib-compressed, base64-encoded metadata in format: "b64:<compressed_base64>"
// This format matches what Python nixl_connect expects for RdmaMetadata.nixl_metadata
// TODO: pre-allocate a fixed NIXL-registered RAM pool so metadata can be cached on the target? // TODO: pre-allocate a fixed NIXL-registered RAM pool so metadata can be cached on the target?
#[cfg(feature = "media-nixl")] #[cfg(feature = "media-nixl")]
pub fn get_nixl_metadata(agent: &NixlAgent, _storage: &SystemStorage) -> Result<String> { pub fn get_nixl_metadata(agent: &NixlAgent, _storage: &SystemStorage) -> Result<String> {
...@@ -118,7 +121,12 @@ pub fn get_nixl_metadata(agent: &NixlAgent, _storage: &SystemStorage) -> Result< ...@@ -118,7 +121,12 @@ pub fn get_nixl_metadata(agent: &NixlAgent, _storage: &SystemStorage) -> Result<
// reg_desc_list.add_storage_desc(storage)?; // reg_desc_list.add_storage_desc(storage)?;
// let nixl_partial_md = agent.raw_agent().get_local_partial_md(&reg_desc_list, None)?; // let nixl_partial_md = agent.raw_agent().get_local_partial_md(&reg_desc_list, None)?;
let b64_encoded = general_purpose::STANDARD.encode(&nixl_md); // Compress with zlib (level 6, matching Python's default)
let mut encoder = ZlibEncoder::new(Vec::new(), Compression::new(6));
encoder.write_all(&nixl_md)?;
let compressed = encoder.finish()?;
let b64_encoded = general_purpose::STANDARD.encode(&compressed);
Ok(format!("b64:{}", b64_encoded)) Ok(format!("b64:{}", b64_encoded))
} }
......
...@@ -304,6 +304,37 @@ vllm_configs = { ...@@ -304,6 +304,37 @@ vllm_configs = {
) )
], ],
), ),
"multimodal_agg_frontend_decoding": VLLMConfig(
name="multimodal_agg_frontend_decoding",
directory=vllm_dir,
script_name="agg_multimodal.sh",
marks=[pytest.mark.gpu_1, pytest.mark.pre_merge],
model="Qwen/Qwen2-VL-2B-Instruct",
# Pass --frontend-decoding to enable Rust frontend image decoding + NIXL RDMA transfer
script_args=[
"--model",
"Qwen/Qwen2-VL-2B-Instruct",
"--frontend-decoding",
],
request_payloads=[
chat_payload(
[
{
"type": "text",
"text": "What colors are in the following image? Respond only with the colors.",
},
{
"type": "image_url",
"image_url": {"url": MULTIMODAL_IMG_URL},
},
],
repeat_count=1,
expected_response=["green"],
temperature=0.0,
max_tokens=100,
)
],
),
"multimodal_agg_llava_epd": VLLMConfig( "multimodal_agg_llava_epd": VLLMConfig(
name="multimodal_agg_llava_epd", name="multimodal_agg_llava_epd",
directory=vllm_dir, directory=vllm_dir,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment