Unverified Commit c1858b7e authored by Aaron Hao's avatar Aaron Hao Committed by GitHub
Browse files

[Feat][RL][1/2] Native Weight Syncing API: NCCL (#31943)


Signed-off-by: default avatarahao-anyscale <ahao@anyscale.com>
Signed-off-by: default avatarAaron Hao <ahao@anyscale.com>
Co-authored-by: default avatarSumanthRH <sumanthrh99@gmail.com>
parent 82914d2a
......@@ -6,6 +6,10 @@ from collections.abc import AsyncGenerator, Iterable, Mapping
from typing import Any
from vllm.config import ModelConfig, VllmConfig
from vllm.distributed.weight_transfer.base import (
WeightTransferInitRequest,
WeightTransferUpdateRequest,
)
from vllm.inputs.data import PromptType, StreamingInput
from vllm.lora.request import LoRARequest
from vllm.outputs import PoolingRequestOutput, RequestOutput
......@@ -191,3 +195,13 @@ class EngineClient(ABC):
async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
"""Get supported tasks"""
raise NotImplementedError
async def init_weight_transfer_engine(
self, init_request: WeightTransferInitRequest
) -> None:
"""Initialize weight transfer for RL training."""
raise NotImplementedError
async def update_weights(self, request: WeightTransferUpdateRequest) -> None:
"""Batched weight update for RL training."""
raise NotImplementedError
......@@ -34,6 +34,10 @@ from vllm.config.model import (
RunnerOption,
TokenizerMode,
)
from vllm.distributed.weight_transfer.base import (
WeightTransferInitRequest,
WeightTransferUpdateRequest,
)
from vllm.engine.arg_utils import EngineArgs
from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam,
......@@ -360,6 +364,23 @@ class LLM:
def get_tokenizer(self) -> TokenizerLike:
return self.llm_engine.get_tokenizer()
def get_world_size(self, include_dp: bool = True) -> int:
"""Get the world size from the parallel config.
Args:
include_dp: If True (default), returns the world size including
data parallelism (TP * PP * DP). If False, returns the world
size without data parallelism (TP * PP).
Returns:
The world size (tensor_parallel_size * pipeline_parallel_size),
optionally multiplied by data_parallel_size if include_dp is True.
"""
parallel_config = self.llm_engine.vllm_config.parallel_config
if include_dp:
return parallel_config.world_size_across_dp
return parallel_config.world_size
def reset_mm_cache(self) -> None:
self.input_processor.clear_mm_cache()
self.llm_engine.reset_mm_cache()
......@@ -1903,6 +1924,38 @@ class LLM:
# its previous requests.
return sorted(outputs, key=lambda x: int(x.request_id))
def init_weight_transfer_engine(
self, request: WeightTransferInitRequest | dict
) -> None:
"""
Initialize weight transfer for RL training.
Args:
request: Weight transfer initialization request with backend-specific info
"""
init_info_dict = (
request["init_info"] if isinstance(request, dict) else request.init_info
)
self.llm_engine.collective_rpc(
"init_weight_transfer_engine", kwargs={"init_info": init_info_dict}
)
def update_weights(self, request: WeightTransferUpdateRequest | dict) -> None:
"""
Update the weights of the model.
Args:
request: Weight update request with backend-specific update info
"""
update_info_dict = (
request["update_info"] if isinstance(request, dict) else request.update_info
)
self.llm_engine.collective_rpc(
"update_weights", kwargs={"update_info": update_info_dict}
)
def __repr__(self) -> str:
"""Return a transformers-style hierarchical view of the model."""
# Cache the result to avoid repeated collective_rpc calls
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from http import HTTPStatus
from fastapi import APIRouter, FastAPI, Query, Request
from fastapi import APIRouter, FastAPI, HTTPException, Query, Request
from fastapi.responses import JSONResponse
import vllm.envs as envs
from vllm.distributed.weight_transfer.base import (
WeightTransferInitRequest,
WeightTransferUpdateRequest,
)
from vllm.engine.protocol import EngineClient
from vllm.logger import init_logger
......@@ -98,5 +103,63 @@ async def is_paused(raw_request: Request) -> JSONResponse:
return JSONResponse(content={"is_paused": paused})
@router.post("/init_weight_transfer_engine")
async def init_weight_transfer_engine(raw_request: Request):
try:
body = await raw_request.json()
except json.JSONDecodeError as e:
raise HTTPException(status_code=400, detail="Invalid JSON format") from e # noqa: B904
init_info = body.get("init_info")
if init_info is None:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST.value,
detail="Missing 'init_info' in request body",
)
await engine_client(raw_request).init_weight_transfer_engine(
WeightTransferInitRequest(init_info=init_info)
)
return JSONResponse(content={"message": "Weight transfer initialized"})
@router.post("/update_weights")
async def update_weights(raw_request: Request):
try:
body = await raw_request.json()
except json.JSONDecodeError as e:
raise HTTPException(status_code=400, detail="Invalid JSON format") from e # noqa: B904
update_info = body.get("update_info")
if update_info is None:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST.value,
detail="Missing 'update_info' in request body",
)
await engine_client(raw_request).update_weights(
request=WeightTransferUpdateRequest(update_info=update_info)
)
return JSONResponse(content={"message": "Weights updated"})
@router.get("/get_world_size")
async def get_world_size(
raw_request: Request,
include_dp: bool = Query(True),
):
"""Get the world size from the parallel config.
Args:
include_dp: If True (default), returns the world size including
data parallelism (TP * PP * DP). If False, returns the world
size without data parallelism (TP * PP).
"""
parallel_config = engine_client(raw_request).vllm_config.parallel_config
if include_dp:
world_size = parallel_config.world_size_across_dp
else:
world_size = parallel_config.world_size
return JSONResponse(content={"world_size": world_size})
def attach_router(app: FastAPI):
if not envs.VLLM_SERVER_DEV_MODE:
return
app.include_router(router)
......@@ -649,6 +649,9 @@ class Fp8OnlineLinearMethod(Fp8LinearMethod):
)
# Activations not quantized for marlin.
# Prevent duplicate processing (e.g., during weight reload)
layer._already_called_process_weights_after_loading = True
class Fp8MoEMethod(FusedMoEMethodBase):
"""MoE method for FP8.
......@@ -908,6 +911,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer, w13, w2, w13_scale, w2_scale, w13_input_scale, w2_input_scale
)
# Prevent duplicate processing (e.g., during weight reload)
layer._already_called_process_weights_after_loading = True
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
......@@ -1241,6 +1247,9 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
layer.w2_input_scale,
)
# Prevent duplicate processing (e.g., during weight reload)
layer._already_called_process_weights_after_loading = True
class Fp8KVCacheMethod(BaseKVCacheMethod):
"""
......
......@@ -216,6 +216,11 @@ def _layerwise_process(layer: torch.nn.Module, info: LayerReloadingInfo):
# Materialize layer tensors onto device
materialize_layer(layer)
# Reset FP8 online quantization flag so process_weights_after_loading
# will run again during reload
if hasattr(layer, "_already_called_process_weights_after_loading"):
delattr(layer, "_already_called_process_weights_after_loading")
# Unwrap layerwise loading wrappers
for param in get_layer_tensors(layer).values():
param.weight_loader = _get_original_loader(param)
......
......@@ -14,6 +14,10 @@ import torch
import vllm.envs as envs
from vllm import TokensPrompt
from vllm.config import VllmConfig
from vllm.distributed.weight_transfer.base import (
WeightTransferInitRequest,
WeightTransferUpdateRequest,
)
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.protocol import EngineClient
from vllm.inputs import PromptType, StreamingInput
......@@ -1011,3 +1015,44 @@ class AsyncLLM(EngineClient):
@property
def dead_error(self) -> BaseException:
return EngineDeadError()
async def init_weight_transfer_engine(
self, request: WeightTransferInitRequest
) -> None:
"""
Initialize weight transfer for RL training.
Args:
request: Weight transfer initialization request with backend-specific info
"""
from vllm.distributed.weight_transfer.base import (
WeightTransferInitRequest,
)
if isinstance(request, WeightTransferInitRequest):
init_info_dict = request.init_info
else:
raise TypeError(f"Expected WeightTransferInitRequest, got {type(request)}")
await self.collective_rpc(
"init_weight_transfer_engine", kwargs={"init_info": init_info_dict}
)
async def update_weights(self, request: WeightTransferUpdateRequest) -> None:
"""
Batched weight update for RL training.
Args:
request: Weight update request with backend-specific update info
"""
if isinstance(request, WeightTransferUpdateRequest):
update_info_dict = request.update_info
else:
raise TypeError(
f"Expected WeightTransferUpdateRequest, got {type(request)}"
)
await self.collective_rpc(
"update_weights", kwargs={"update_info": update_info_dict}
)
......@@ -33,6 +33,7 @@ from vllm.distributed.parallel_state import (
get_pp_group,
get_tp_group,
)
from vllm.distributed.weight_transfer import WeightTransferEngineFactory
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.models.interfaces import is_mixture_of_experts
......@@ -89,6 +90,16 @@ class Worker(WorkerBase):
# Buffers saved before sleep
self._sleep_saved_buffers: dict[str, torch.Tensor] = {}
# Weight transfer engine (initialized on-demand)
self.weight_transfer_engine = (
WeightTransferEngineFactory.create_engine(
self.vllm_config.weight_transfer_config,
self.vllm_config.parallel_config,
)
if self.vllm_config.weight_transfer_config is not None
else None
)
# Torch/CUDA profiler. Enabled and configured through profiler_config.
self.profiler: Any | None = None
profiler_config = vllm_config.profiler_config
......@@ -932,6 +943,69 @@ class Worker(WorkerBase):
tensorizer_config=tensorizer_config,
)
def init_weight_transfer_engine(self, init_info: dict) -> None:
"""
Initialize weight transfer mechanism.
For NCCL backend, this creates a process group with the trainer.
Args:
init_info: Dictionary containing backend-specific initialization info
"""
if self.weight_transfer_engine is None:
raise RuntimeError(
"Weight transfer not configured. "
"Please set weight_transfer_config to enable weight transfer."
)
# Parse dict into backend-specific typed dataclass
typed_init_info = self.weight_transfer_engine.parse_init_info(init_info)
self.weight_transfer_engine.init_transfer_engine(typed_init_info)
def update_weights(self, update_info: dict) -> None:
"""
Batched weight update from the trainer.
Args:
update_info: Dictionary containing backend-specific update info
"""
if self.weight_transfer_engine is None:
raise RuntimeError(
"Weight transfer not configured. "
"Please set weight_transfer_config to enable weight transfer."
)
# Parse dict into backend-specific typed dataclass
typed_update_info = self.weight_transfer_engine.parse_update_info(update_info)
model = self.model_runner.model
if typed_update_info.is_checkpoint_format:
from vllm.model_executor.model_loader.reload import (
finalize_layerwise_reload,
initialize_layerwise_reload,
)
# Use layerwise reload pattern for checkpoint format weights
with torch.device(self.device):
initialize_layerwise_reload(model)
self.weight_transfer_engine.receive_weights(
typed_update_info,
load_weights=model.load_weights,
)
finalize_layerwise_reload(model, self.model_config)
else:
# Weights are already in kernel format, copy directly
def load_weights_direct(
weights: list[tuple[str, torch.Tensor]],
) -> None:
for name, weight in weights:
param = model.get_parameter(name)
param.copy_(weight)
self.weight_transfer_engine.receive_weights(
typed_update_info,
load_weights=load_weights_direct,
)
def shutdown(self) -> None:
# has_kv_transfer_group can be None during interpreter shutdown.
if ensure_kv_transfer_shutdown is not None:
......@@ -939,6 +1013,9 @@ class Worker(WorkerBase):
if self.profiler is not None:
self.profiler.shutdown()
if weight_transfer_engine := getattr(self, "weight_transfer_engine", None):
weight_transfer_engine.shutdown()
def init_worker_distributed_environment(
vllm_config: VllmConfig,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment