"...git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "4ae77dfd42041dc2defe21f6ccf76aecb4478812"
Unverified commit 233a1e9a, authored by Hyunjae Woo and committed by GitHub
Browse files

feat: Enable ModelExpress P2P weight transfer in Dynamo vLLM worker (#6186)


Signed-off-by: Hyunjae Woo <hwoo@nvidia.com>
parent 5624d144
...@@ -85,6 +85,7 @@ def parse_args() -> Config: ...@@ -85,6 +85,7 @@ def parse_args() -> Config:
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Dynamo vLLM worker configuration", description="Dynamo vLLM worker configuration",
formatter_class=argparse.RawTextHelpFormatter, formatter_class=argparse.RawTextHelpFormatter,
allow_abbrev=False,
) )
# Build argument parser # Build argument parser
...@@ -215,6 +216,14 @@ def update_dynamo_config_with_engine( ...@@ -215,6 +216,14 @@ def update_dynamo_config_with_engine(
) )
dynamo_config.connector = normalized # type: ignore[assignment] dynamo_config.connector = normalized # type: ignore[assignment]
# Validate ModelExpress P2P server URL
if getattr(engine_config, "load_format", None) in ("mx-source", "mx-target"):
if not dynamo_config.model_express_url:
raise ValueError(
f"--model-express-url or MODEL_EXPRESS_URL env var is required "
f"when using --load-format={engine_config.load_format}"
)
def update_engine_config_with_dynamo( def update_engine_config_with_dynamo(
dynamo_config: Config, engine_config: AsyncEngineArgs dynamo_config: Config, engine_config: AsyncEngineArgs
......
...@@ -189,6 +189,16 @@ class DynamoVllmArgGroup(ArgGroup): ...@@ -189,6 +189,16 @@ class DynamoVllmArgGroup(ArgGroup):
help="Path to vLLM-Omni stage configuration YAML file for --omni mode (optional).", help="Path to vLLM-Omni stage configuration YAML file for --omni mode (optional).",
) )
# ModelExpress P2P
add_argument(
g,
flag_name="--model-express-url",
env_var="MODEL_EXPRESS_URL",
default=None,
help="ModelExpress P2P server URL (e.g., http://mx-server:8080). "
"Required when using --load-format=mx-source or --load-format=mx-target.",
)
# @dataclass() # @dataclass()
class DynamoVllmConfig(ConfigBase): class DynamoVllmConfig(ConfigBase):
...@@ -221,6 +231,9 @@ class DynamoVllmConfig(ConfigBase): ...@@ -221,6 +231,9 @@ class DynamoVllmConfig(ConfigBase):
omni: bool omni: bool
stage_configs_path: Optional[str] = None stage_configs_path: Optional[str] = None
# ModelExpress P2P
model_express_url: Optional[str] = None
def validate(self) -> None: def validate(self) -> None:
"""Validate vLLM wrapper configuration.""" """Validate vLLM wrapper configuration."""
self._validate_prefill_decode_exclusive() self._validate_prefill_decode_exclusive()
......
...@@ -401,6 +401,22 @@ def setup_vllm_engine(config, stat_logger=None): ...@@ -401,6 +401,22 @@ def setup_vllm_engine(config, stat_logger=None):
if engine_args.load_format == "gms": if engine_args.load_format == "gms":
engine_args.worker_cls = "gpu_memory_service.integrations.vllm.worker.GMSWorker" engine_args.worker_cls = "gpu_memory_service.integrations.vllm.worker.GMSWorker"
if engine_args.load_format in ("mx-source", "mx-target"):
try:
from modelexpress import register_modelexpress_loaders
# Ensure the ModelExpress server URL env var is set for the model loader
if config.model_express_url:
os.environ["MODEL_EXPRESS_URL"] = config.model_express_url
register_modelexpress_loaders()
# Use wrapper worker to ensure loaders are registered in spawned worker processes
engine_args.worker_cls = "modelexpress.vllm_worker.ModelExpressWorker"
except ImportError as e:
raise ImportError(
f"ModelExpress package required for --load-format={engine_args.load_format}. "
"Install with: pip install modelexpress"
) from e
# Load default sampling params from `generation_config.json` # Load default sampling params from `generation_config.json`
default_sampling_params = ( default_sampling_params = (
engine_args.create_model_config().get_diff_sampling_param() engine_args.create_model_config().get_diff_sampling_param()
......
...@@ -71,3 +71,72 @@ def test_custom_jinja_template_env_var_expansion(monkeypatch, mock_vllm_cli): ...@@ -71,3 +71,72 @@ def test_custom_jinja_template_env_var_expansion(monkeypatch, mock_vllm_cli):
f"Expected custom_jinja_template value to be {JINJA_TEMPLATE_PATH}, " f"Expected custom_jinja_template value to be {JINJA_TEMPLATE_PATH}, "
f"got {config.custom_jinja_template}" f"got {config.custom_jinja_template}"
) )
@pytest.mark.parametrize("load_format", ["mx-source", "mx-target"])
def test_model_express_url_from_cli_arg(mock_vllm_cli, load_format):
    """An explicit --model-express-url value ends up on the parsed config.

    Runs once for each ModelExpress load format (mx-source, mx-target).
    """
    expected_url = "http://mx-server:8080"
    cli_args = [
        "--model",
        "Qwen/Qwen3-0.6B",
        "--load-format",
        load_format,
        "--model-express-url",
        expected_url,
    ]
    mock_vllm_cli(*cli_args)

    parsed = parse_args()

    assert parsed.model_express_url == expected_url
@pytest.mark.parametrize("load_format", ["mx-source", "mx-target"])
def test_model_express_url_from_env_var(monkeypatch, mock_vllm_cli, load_format):
    """Without the CLI flag, the URL is read from the MODEL_EXPRESS_URL env var."""
    env_url = "http://env-mx:9090"
    monkeypatch.setenv("MODEL_EXPRESS_URL", env_url)

    # Deliberately no --model-express-url here: only the environment supplies it.
    cli_args = ["--model", "Qwen/Qwen3-0.6B", "--load-format", load_format]
    mock_vllm_cli(*cli_args)

    parsed = parse_args()

    assert parsed.model_express_url == env_url
@pytest.mark.parametrize("load_format", ["mx-source", "mx-target"])
def test_model_express_url_cli_overrides_env(monkeypatch, mock_vllm_cli, load_format):
    """When both sources are set, the CLI flag wins over MODEL_EXPRESS_URL."""
    cli_url = "http://cli-mx:8080"

    # Set a different URL in the environment so the two sources are distinguishable.
    monkeypatch.setenv("MODEL_EXPRESS_URL", "http://env-mx:9090")

    cli_args = [
        "--model",
        "Qwen/Qwen3-0.6B",
        "--load-format",
        load_format,
        "--model-express-url",
        cli_url,
    ]
    mock_vllm_cli(*cli_args)

    parsed = parse_args()

    assert parsed.model_express_url == cli_url
@pytest.mark.parametrize("load_format", ["mx-source", "mx-target"])
def test_model_express_url_missing_raises(monkeypatch, mock_vllm_cli, load_format):
    """parse_args raises ValueError for an mx-* load format with no server URL."""
    # Make sure neither the CLI nor the environment provides a URL.
    monkeypatch.delenv("MODEL_EXPRESS_URL", raising=False)

    cli_args = ["--model", "Qwen/Qwen3-0.6B", "--load-format", load_format]
    mock_vllm_cli(*cli_args)

    # The error message is expected to name the offending load format.
    expected_fragment = re.escape(f"--load-format={load_format}")
    with pytest.raises(ValueError, match=expected_fragment):
        parse_args()
def test_model_express_url_none_for_default_load_format(monkeypatch, mock_vllm_cli):
    """Test that model_express_url is None when load format is not mx-*.

    The --model-express-url option falls back to the MODEL_EXPRESS_URL
    environment variable, so the variable is removed first. Otherwise the
    assertion would fail on any machine or CI job that exports it.
    """
    monkeypatch.delenv("MODEL_EXPRESS_URL", raising=False)
    mock_vllm_cli("--model", "Qwen/Qwen3-0.6B")

    config = parse_args()

    assert config.model_express_url is None
...@@ -45,6 +45,8 @@ vllm: ...@@ -45,6 +45,8 @@ vllm:
enable_media_ffmpeg: "true" enable_media_ffmpeg: "true"
enable_gpu_memory_service: "true" enable_gpu_memory_service: "true"
enable_kvbm: "true" enable_kvbm: "true"
enable_modelexpress_p2p: "false"
modelexpress_ref: "3d73992ce6c10e52ddc54f7f12af35d27e173f15"
sglang: sglang:
base_image: nvcr.io/nvidia/cuda-dl-base base_image: nvcr.io/nvidia/cuda-dl-base
......
...@@ -85,6 +85,10 @@ ARG LMCACHE_REF={{ context.vllm.lmcache_ref }} ...@@ -85,6 +85,10 @@ ARG LMCACHE_REF={{ context.vllm.lmcache_ref }}
# If left blank, then we will fallback to vLLM defaults # If left blank, then we will fallback to vLLM defaults
ARG DEEPGEMM_REF="" ARG DEEPGEMM_REF=""
# ModelExpress for P2P weight transfer (optional)
ARG ENABLE_MODELEXPRESS_P2P={{ context.vllm.enable_modelexpress_p2p }}
ARG MODELEXPRESS_REF={{ context.vllm.modelexpress_ref }}
{%- endif -%} {%- endif -%}
{% if framework == "trtllm" %} {% if framework == "trtllm" %}
......
...@@ -214,6 +214,15 @@ RUN --mount=type=cache,target=/home/dynamo/.cache/uv,uid=1000,gid=0,mode=0775 \ ...@@ -214,6 +214,15 @@ RUN --mount=type=cache,target=/home/dynamo/.cache/uv,uid=1000,gid=0,mode=0775 \
# pip/uv bypasses umask when creating .egg-info files, but chmod -R is fast here (small directory) # pip/uv bypasses umask when creating .egg-info files, but chmod -R is fast here (small directory)
chmod -R g+w /workspace/benchmarks chmod -R g+w /workspace/benchmarks
# Install ModelExpress for P2P weight transfer (optional)
ARG ENABLE_MODELEXPRESS_P2P
ARG MODELEXPRESS_REF
RUN if [ "${ENABLE_MODELEXPRESS_P2P}" = "true" ]; then \
echo "Installing ModelExpress from ref: ${MODELEXPRESS_REF}" && \
uv pip install "modelexpress @ git+https://github.com/ai-dynamo/modelexpress.git@${MODELEXPRESS_REF}#subdirectory=modelexpress_client/python"; \
fi
# Install common and test dependencies. Cache uv downloads; uv handles its own locking for this cache. # Install common and test dependencies. Cache uv downloads; uv handles its own locking for this cache.
RUN --mount=type=bind,source=./container/deps/requirements.txt,target=/tmp/requirements.txt \ RUN --mount=type=bind,source=./container/deps/requirements.txt,target=/tmp/requirements.txt \
--mount=type=bind,source=./container/deps/requirements.test.txt,target=/tmp/requirements.test.txt \ --mount=type=bind,source=./container/deps/requirements.test.txt,target=/tmp/requirements.test.txt \
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment