Unverified Commit c1858b7e authored by Aaron Hao's avatar Aaron Hao Committed by GitHub
Browse files

[Feat][RL][1/2] Native Weight Syncing API: NCCL (#31943)


Signed-off-by: default avatarahao-anyscale <ahao@anyscale.com>
Signed-off-by: default avatarAaron Hao <ahao@anyscale.com>
Co-authored-by: default avatarSumanthRH <sumanthrh99@gmail.com>
parent 82914d2a
...@@ -6,6 +6,10 @@ from collections.abc import AsyncGenerator, Iterable, Mapping ...@@ -6,6 +6,10 @@ from collections.abc import AsyncGenerator, Iterable, Mapping
from typing import Any from typing import Any
from vllm.config import ModelConfig, VllmConfig from vllm.config import ModelConfig, VllmConfig
from vllm.distributed.weight_transfer.base import (
WeightTransferInitRequest,
WeightTransferUpdateRequest,
)
from vllm.inputs.data import PromptType, StreamingInput from vllm.inputs.data import PromptType, StreamingInput
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.outputs import PoolingRequestOutput, RequestOutput from vllm.outputs import PoolingRequestOutput, RequestOutput
...@@ -191,3 +195,13 @@ class EngineClient(ABC): ...@@ -191,3 +195,13 @@ class EngineClient(ABC):
async def get_supported_tasks(self) -> tuple[SupportedTask, ...]: async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
"""Get supported tasks""" """Get supported tasks"""
raise NotImplementedError raise NotImplementedError
async def init_weight_transfer_engine(
self, init_request: WeightTransferInitRequest
) -> None:
"""Initialize weight transfer for RL training."""
raise NotImplementedError
async def update_weights(self, request: WeightTransferUpdateRequest) -> None:
"""Batched weight update for RL training."""
raise NotImplementedError
...@@ -34,6 +34,10 @@ from vllm.config.model import ( ...@@ -34,6 +34,10 @@ from vllm.config.model import (
RunnerOption, RunnerOption,
TokenizerMode, TokenizerMode,
) )
from vllm.distributed.weight_transfer.base import (
WeightTransferInitRequest,
WeightTransferUpdateRequest,
)
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.entrypoints.chat_utils import ( from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam, ChatCompletionMessageParam,
...@@ -360,6 +364,23 @@ class LLM: ...@@ -360,6 +364,23 @@ class LLM:
def get_tokenizer(self) -> TokenizerLike: def get_tokenizer(self) -> TokenizerLike:
return self.llm_engine.get_tokenizer() return self.llm_engine.get_tokenizer()
def get_world_size(self, include_dp: bool = True) -> int:
"""Get the world size from the parallel config.
Args:
include_dp: If True (default), returns the world size including
data parallelism (TP * PP * DP). If False, returns the world
size without data parallelism (TP * PP).
Returns:
The world size (tensor_parallel_size * pipeline_parallel_size),
optionally multiplied by data_parallel_size if include_dp is True.
"""
parallel_config = self.llm_engine.vllm_config.parallel_config
if include_dp:
return parallel_config.world_size_across_dp
return parallel_config.world_size
def reset_mm_cache(self) -> None: def reset_mm_cache(self) -> None:
self.input_processor.clear_mm_cache() self.input_processor.clear_mm_cache()
self.llm_engine.reset_mm_cache() self.llm_engine.reset_mm_cache()
...@@ -1903,6 +1924,38 @@ class LLM: ...@@ -1903,6 +1924,38 @@ class LLM:
# its previous requests. # its previous requests.
return sorted(outputs, key=lambda x: int(x.request_id)) return sorted(outputs, key=lambda x: int(x.request_id))
def init_weight_transfer_engine(
self, request: WeightTransferInitRequest | dict
) -> None:
"""
Initialize weight transfer for RL training.
Args:
request: Weight transfer initialization request with backend-specific info
"""
init_info_dict = (
request["init_info"] if isinstance(request, dict) else request.init_info
)
self.llm_engine.collective_rpc(
"init_weight_transfer_engine", kwargs={"init_info": init_info_dict}
)
def update_weights(self, request: WeightTransferUpdateRequest | dict) -> None:
"""
Update the weights of the model.
Args:
request: Weight update request with backend-specific update info
"""
update_info_dict = (
request["update_info"] if isinstance(request, dict) else request.update_info
)
self.llm_engine.collective_rpc(
"update_weights", kwargs={"update_info": update_info_dict}
)
def __repr__(self) -> str: def __repr__(self) -> str:
"""Return a transformers-style hierarchical view of the model.""" """Return a transformers-style hierarchical view of the model."""
# Cache the result to avoid repeated collective_rpc calls # Cache the result to avoid repeated collective_rpc calls
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from http import HTTPStatus from http import HTTPStatus
from fastapi import APIRouter, FastAPI, Query, Request from fastapi import APIRouter, FastAPI, HTTPException, Query, Request
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
import vllm.envs as envs
from vllm.distributed.weight_transfer.base import (
WeightTransferInitRequest,
WeightTransferUpdateRequest,
)
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -98,5 +103,63 @@ async def is_paused(raw_request: Request) -> JSONResponse: ...@@ -98,5 +103,63 @@ async def is_paused(raw_request: Request) -> JSONResponse:
return JSONResponse(content={"is_paused": paused}) return JSONResponse(content={"is_paused": paused})
@router.post("/init_weight_transfer_engine")
async def init_weight_transfer_engine(raw_request: Request):
try:
body = await raw_request.json()
except json.JSONDecodeError as e:
raise HTTPException(status_code=400, detail="Invalid JSON format") from e # noqa: B904
init_info = body.get("init_info")
if init_info is None:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST.value,
detail="Missing 'init_info' in request body",
)
await engine_client(raw_request).init_weight_transfer_engine(
WeightTransferInitRequest(init_info=init_info)
)
return JSONResponse(content={"message": "Weight transfer initialized"})
@router.post("/update_weights")
async def update_weights(raw_request: Request):
try:
body = await raw_request.json()
except json.JSONDecodeError as e:
raise HTTPException(status_code=400, detail="Invalid JSON format") from e # noqa: B904
update_info = body.get("update_info")
if update_info is None:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST.value,
detail="Missing 'update_info' in request body",
)
await engine_client(raw_request).update_weights(
request=WeightTransferUpdateRequest(update_info=update_info)
)
return JSONResponse(content={"message": "Weights updated"})
@router.get("/get_world_size")
async def get_world_size(
raw_request: Request,
include_dp: bool = Query(True),
):
"""Get the world size from the parallel config.
Args:
include_dp: If True (default), returns the world size including
data parallelism (TP * PP * DP). If False, returns the world
size without data parallelism (TP * PP).
"""
parallel_config = engine_client(raw_request).vllm_config.parallel_config
if include_dp:
world_size = parallel_config.world_size_across_dp
else:
world_size = parallel_config.world_size
return JSONResponse(content={"world_size": world_size})
def attach_router(app: FastAPI): def attach_router(app: FastAPI):
if not envs.VLLM_SERVER_DEV_MODE:
return
app.include_router(router) app.include_router(router)
...@@ -649,6 +649,9 @@ class Fp8OnlineLinearMethod(Fp8LinearMethod): ...@@ -649,6 +649,9 @@ class Fp8OnlineLinearMethod(Fp8LinearMethod):
) )
# Activations not quantized for marlin. # Activations not quantized for marlin.
# Prevent duplicate processing (e.g., during weight reload)
layer._already_called_process_weights_after_loading = True
class Fp8MoEMethod(FusedMoEMethodBase): class Fp8MoEMethod(FusedMoEMethodBase):
"""MoE method for FP8. """MoE method for FP8.
...@@ -908,6 +911,9 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -908,6 +911,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer, w13, w2, w13_scale, w2_scale, w13_input_scale, w2_input_scale layer, w13, w2, w13_scale, w2_scale, w13_input_scale, w2_input_scale
) )
# Prevent duplicate processing (e.g., during weight reload)
layer._already_called_process_weights_after_loading = True
def maybe_make_prepare_finalize( def maybe_make_prepare_finalize(
self, self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
...@@ -1241,6 +1247,9 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod): ...@@ -1241,6 +1247,9 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
layer.w2_input_scale, layer.w2_input_scale,
) )
# Prevent duplicate processing (e.g., during weight reload)
layer._already_called_process_weights_after_loading = True
class Fp8KVCacheMethod(BaseKVCacheMethod): class Fp8KVCacheMethod(BaseKVCacheMethod):
""" """
......
...@@ -216,6 +216,11 @@ def _layerwise_process(layer: torch.nn.Module, info: LayerReloadingInfo): ...@@ -216,6 +216,11 @@ def _layerwise_process(layer: torch.nn.Module, info: LayerReloadingInfo):
# Materialize layer tensors onto device # Materialize layer tensors onto device
materialize_layer(layer) materialize_layer(layer)
# Reset FP8 online quantization flag so process_weights_after_loading
# will run again during reload
if hasattr(layer, "_already_called_process_weights_after_loading"):
delattr(layer, "_already_called_process_weights_after_loading")
# Unwrap layerwise loading wrappers # Unwrap layerwise loading wrappers
for param in get_layer_tensors(layer).values(): for param in get_layer_tensors(layer).values():
param.weight_loader = _get_original_loader(param) param.weight_loader = _get_original_loader(param)
......
...@@ -14,6 +14,10 @@ import torch ...@@ -14,6 +14,10 @@ import torch
import vllm.envs as envs import vllm.envs as envs
from vllm import TokensPrompt from vllm import TokensPrompt
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed.weight_transfer.base import (
WeightTransferInitRequest,
WeightTransferUpdateRequest,
)
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.inputs import PromptType, StreamingInput from vllm.inputs import PromptType, StreamingInput
...@@ -1011,3 +1015,44 @@ class AsyncLLM(EngineClient): ...@@ -1011,3 +1015,44 @@ class AsyncLLM(EngineClient):
@property @property
def dead_error(self) -> BaseException: def dead_error(self) -> BaseException:
return EngineDeadError() return EngineDeadError()
async def init_weight_transfer_engine(
self, request: WeightTransferInitRequest
) -> None:
"""
Initialize weight transfer for RL training.
Args:
request: Weight transfer initialization request with backend-specific info
"""
from vllm.distributed.weight_transfer.base import (
WeightTransferInitRequest,
)
if isinstance(request, WeightTransferInitRequest):
init_info_dict = request.init_info
else:
raise TypeError(f"Expected WeightTransferInitRequest, got {type(request)}")
await self.collective_rpc(
"init_weight_transfer_engine", kwargs={"init_info": init_info_dict}
)
async def update_weights(self, request: WeightTransferUpdateRequest) -> None:
"""
Batched weight update for RL training.
Args:
request: Weight update request with backend-specific update info
"""
if isinstance(request, WeightTransferUpdateRequest):
update_info_dict = request.update_info
else:
raise TypeError(
f"Expected WeightTransferUpdateRequest, got {type(request)}"
)
await self.collective_rpc(
"update_weights", kwargs={"update_info": update_info_dict}
)
...@@ -33,6 +33,7 @@ from vllm.distributed.parallel_state import ( ...@@ -33,6 +33,7 @@ from vllm.distributed.parallel_state import (
get_pp_group, get_pp_group,
get_tp_group, get_tp_group,
) )
from vllm.distributed.weight_transfer import WeightTransferEngineFactory
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor.models.interfaces import is_mixture_of_experts from vllm.model_executor.models.interfaces import is_mixture_of_experts
...@@ -89,6 +90,16 @@ class Worker(WorkerBase): ...@@ -89,6 +90,16 @@ class Worker(WorkerBase):
# Buffers saved before sleep # Buffers saved before sleep
self._sleep_saved_buffers: dict[str, torch.Tensor] = {} self._sleep_saved_buffers: dict[str, torch.Tensor] = {}
# Weight transfer engine (initialized on-demand)
self.weight_transfer_engine = (
WeightTransferEngineFactory.create_engine(
self.vllm_config.weight_transfer_config,
self.vllm_config.parallel_config,
)
if self.vllm_config.weight_transfer_config is not None
else None
)
# Torch/CUDA profiler. Enabled and configured through profiler_config. # Torch/CUDA profiler. Enabled and configured through profiler_config.
self.profiler: Any | None = None self.profiler: Any | None = None
profiler_config = vllm_config.profiler_config profiler_config = vllm_config.profiler_config
...@@ -932,6 +943,69 @@ class Worker(WorkerBase): ...@@ -932,6 +943,69 @@ class Worker(WorkerBase):
tensorizer_config=tensorizer_config, tensorizer_config=tensorizer_config,
) )
def init_weight_transfer_engine(self, init_info: dict) -> None:
"""
Initialize weight transfer mechanism.
For NCCL backend, this creates a process group with the trainer.
Args:
init_info: Dictionary containing backend-specific initialization info
"""
if self.weight_transfer_engine is None:
raise RuntimeError(
"Weight transfer not configured. "
"Please set weight_transfer_config to enable weight transfer."
)
# Parse dict into backend-specific typed dataclass
typed_init_info = self.weight_transfer_engine.parse_init_info(init_info)
self.weight_transfer_engine.init_transfer_engine(typed_init_info)
def update_weights(self, update_info: dict) -> None:
"""
Batched weight update from the trainer.
Args:
update_info: Dictionary containing backend-specific update info
"""
if self.weight_transfer_engine is None:
raise RuntimeError(
"Weight transfer not configured. "
"Please set weight_transfer_config to enable weight transfer."
)
# Parse dict into backend-specific typed dataclass
typed_update_info = self.weight_transfer_engine.parse_update_info(update_info)
model = self.model_runner.model
if typed_update_info.is_checkpoint_format:
from vllm.model_executor.model_loader.reload import (
finalize_layerwise_reload,
initialize_layerwise_reload,
)
# Use layerwise reload pattern for checkpoint format weights
with torch.device(self.device):
initialize_layerwise_reload(model)
self.weight_transfer_engine.receive_weights(
typed_update_info,
load_weights=model.load_weights,
)
finalize_layerwise_reload(model, self.model_config)
else:
# Weights are already in kernel format, copy directly
def load_weights_direct(
weights: list[tuple[str, torch.Tensor]],
) -> None:
for name, weight in weights:
param = model.get_parameter(name)
param.copy_(weight)
self.weight_transfer_engine.receive_weights(
typed_update_info,
load_weights=load_weights_direct,
)
def shutdown(self) -> None: def shutdown(self) -> None:
# has_kv_transfer_group can be None during interpreter shutdown. # has_kv_transfer_group can be None during interpreter shutdown.
if ensure_kv_transfer_shutdown is not None: if ensure_kv_transfer_shutdown is not None:
...@@ -939,6 +1013,9 @@ class Worker(WorkerBase): ...@@ -939,6 +1013,9 @@ class Worker(WorkerBase):
if self.profiler is not None: if self.profiler is not None:
self.profiler.shutdown() self.profiler.shutdown()
if weight_transfer_engine := getattr(self, "weight_transfer_engine", None):
weight_transfer_engine.shutdown()
def init_worker_distributed_environment( def init_worker_distributed_environment(
vllm_config: VllmConfig, vllm_config: VllmConfig,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment