Unverified commit 983bfcf3, authored by Chayenne, committed by GitHub

Online weight updates from torch.distributed (#2279)

parent 28bc60dc
@@ -27,6 +27,7 @@ concurrency:
cancel-in-progress: true
jobs:
unit-test-frontend:
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
runs-on: 1-gpu-runner
@@ -98,6 +99,11 @@ jobs:
python3 test_mla_fp8.py
python3 test_dp_attention.py
- name: Test update weights from distributed
timeout-minutes: 10
run: |
cd test/srt
python3 test_update_weights_from_distributed.py
performance-test-1-gpu-part-1:
if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
@@ -245,6 +251,7 @@ jobs:
cd test/srt
python3 test_moe_eval_accuracy_large.py
finish:
needs: [
unit-test-frontend, unit-test-backend-1-gpu, unit-test-backend-2-gpu,
......
@@ -365,6 +365,41 @@ class UpdateWeightFromDiskReqOutput:
message: str
@dataclass
class UpdateWeightsFromDistributedReqInput:
name: str
dtype: str
shape: List[int]
@dataclass
class UpdateWeightsFromDistributedReqOutput:
success: bool
message: str
@dataclass
class InitWeightsUpdateGroupReqInput:
# The master address
master_address: str
# The master port
master_port: int
# The rank offset
rank_offset: int
# The world size
world_size: int
# The group name
group_name: str = "weight_update_group"
# The backend
backend: str = "nccl"
@dataclass
class InitWeightsUpdateGroupReqOutput:
success: bool
message: str
@dataclass
class GetWeightsByNameReqInput:
name: str
......
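For context, a minimal sketch of how a client might fill in the new request objects defined above; the concrete values (address, port, parameter name, shape) are illustrative assumptions, not taken from the commit:

# Illustrative only: field names come from the dataclasses above, values are assumed.
from sglang.srt.managers.io_struct import (
    InitWeightsUpdateGroupReqInput,
    UpdateWeightsFromDistributedReqInput,
)

# Join a weight-update group; rank 0 is assumed to be the training process.
init_req = InitWeightsUpdateGroupReqInput(
    master_address="localhost",
    master_port=29500,            # assumed free rendezvous port
    rank_offset=1,
    world_size=2,
    group_name="weight_update_group",
    backend="nccl",
)

# Receive one tensor broadcast from rank 0 and load it by name.
update_req = UpdateWeightsFromDistributedReqInput(
    name="model.embed_tokens.weight",
    dtype="bfloat16",
    shape=[128256, 4096],         # illustrative shape
)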
@@ -40,6 +40,8 @@ from sglang.srt.managers.io_struct import (
FlushCacheReq,
GetWeightsByNameReqInput,
GetWeightsByNameReqOutput,
InitWeightsUpdateGroupReqInput,
InitWeightsUpdateGroupReqOutput,
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
@@ -47,6 +49,8 @@ from sglang.srt.managers.io_struct import (
TokenizedGenerateReqInput,
UpdateWeightFromDiskReqInput,
UpdateWeightFromDiskReqOutput,
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromDistributedReqOutput,
)
from sglang.srt.managers.schedule_batch import (
FINISH_ABORT,
@@ -516,6 +520,19 @@ class Scheduler:
elif isinstance(recv_req, GetWeightsByNameReqInput):
parameter = self.get_weights_by_name(recv_req)
self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
elif isinstance(recv_req, InitWeightsUpdateGroupReqInput):
success, message = self.init_weights_update_group(recv_req)
self.send_to_tokenizer.send_pyobj(
InitWeightsUpdateGroupReqOutput(success, message)
)
elif isinstance(recv_req, UpdateWeightsFromDistributedReqInput):
success, message = self.update_weights_from_distributed(recv_req)
self.send_to_tokenizer.send_pyobj(
UpdateWeightsFromDistributedReqOutput(success, message)
)
elif isinstance(recv_req, GetWeightsByNameReqInput):
parameter = self.get_weights_by_name(recv_req)
self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
elif isinstance(recv_req, ProfileReq):
if recv_req == ProfileReq.START_PROFILE:
self.start_profile()
@@ -1378,6 +1395,23 @@ class Scheduler:
logger.error(message)
return success, message
def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
"""Initialize the online model parameter update group."""
success, message = self.tp_worker.init_weights_update_group(recv_req)
return success, message
def update_weights_from_distributed(
self, recv_req: UpdateWeightsFromDistributedReqInput
):
"""Update the online model parameter."""
success, message = self.tp_worker.update_weights_from_distributed(recv_req)
if success:
flash_cache_success = self.flush_cache()
assert flash_cache_success, "Cache flush failed after updating weights"
else:
logger.error(message)
return success, message
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
parameter = self.tp_worker.get_weights_by_name(recv_req)
return parameter
......
@@ -48,6 +48,8 @@ from sglang.srt.managers.io_struct import (
GenerateReqInput,
GetWeightsByNameReqInput,
GetWeightsByNameReqOutput,
InitWeightsUpdateGroupReqInput,
InitWeightsUpdateGroupReqOutput,
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
@@ -55,6 +57,8 @@ from sglang.srt.managers.io_struct import (
TokenizedGenerateReqInput,
UpdateWeightFromDiskReqInput,
UpdateWeightFromDiskReqOutput,
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromDistributedReqOutput,
)
from sglang.srt.metrics.collector import TokenizerMetricsCollector
from sglang.srt.sampling.sampling_params import SamplingParams
@@ -456,6 +460,48 @@ class TokenizerManager:
else:
return False, "Another update is in progress. Please try again later."
async def init_weights_update_group(
self,
obj: InitWeightsUpdateGroupReqInput,
request: Optional[fastapi.Request] = None,
):
if self.to_create_loop:
self.create_handle_loop()
self.send_to_scheduler.send_pyobj(obj)
self.init_weights_update_group_result = asyncio.Future()
assert (
self.server_args.dp_size == 1
), "dp_size must be 1 for init parameter update group"
result = await self.init_weights_update_group_result
return result.success, result.message
async def update_weights_from_distributed(
self,
obj: UpdateWeightsFromDistributedReqInput,
request: Optional[fastapi.Request] = None,
):
if self.to_create_loop:
self.create_handle_loop()
if not self.model_update_lock.locked():
async with self.model_update_lock:
self.send_to_scheduler.send_pyobj(obj)
self.parameter_update_result = asyncio.Future()
assert (
self.server_args.dp_size == 1
), "dp_size must be for update weights from distributed"
result = await self.parameter_update_result
return result.success, result.message
else:
logger.error(
f"Another parameter update is in progress in tokenizer manager"
)
return (
False,
"Another parameter update is in progress. Please try again later.",
)
async def get_weights_by_name(
self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
):
@@ -546,7 +592,9 @@ class TokenizerManager:
BatchEmbeddingOut,
BatchTokenIDOut,
UpdateWeightFromDiskReqOutput,
UpdateWeightsFromDistributedReqOutput,
GetWeightsByNameReqOutput,
InitWeightsUpdateGroupReqOutput,
] = await self.recv_from_detokenizer.recv_pyobj()
if isinstance(recv_obj, UpdateWeightFromDiskReqOutput):
@@ -558,6 +606,12 @@ class TokenizerManager:
if len(self.model_update_tmp) == self.server_args.dp_size:
self.model_update_result.set_result(self.model_update_tmp)
continue
elif isinstance(recv_obj, UpdateWeightsFromDistributedReqOutput):
assert (
self.server_args.dp_size == 1
), "dp_size must be 1 for update weights from distributed"
self.parameter_update_result.set_result(recv_obj)
continue
elif isinstance(recv_obj, GetWeightsByNameReqOutput):
if self.server_args.dp_size == 1:
self.get_weights_by_name_result.set_result(recv_obj)
@@ -568,6 +622,12 @@ class TokenizerManager:
self.get_weights_by_name_tmp
)
continue
elif isinstance(recv_obj, InitWeightsUpdateGroupReqOutput):
assert (
self.server_args.dp_size == 1
), "dp_size must be 1 for init parameter update group"
self.init_weights_update_group_result.set_result(recv_obj)
continue
elif isinstance(recv_obj, OpenSessionReqOutput):
self.session_futures[recv_obj.session_id].set_result(
recv_obj.session_id
......
@@ -21,7 +21,9 @@ from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.managers.io_struct import (
GetWeightsByNameReqInput,
InitWeightsUpdateGroupReqInput,
UpdateWeightFromDiskReqInput,
UpdateWeightsFromDistributedReqInput,
)
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -164,6 +166,25 @@ class TpModelWorker:
)
return success, message
def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
success, message = self.model_runner.init_weights_update_group(
recv_req.master_address,
recv_req.master_port,
recv_req.rank_offset,
recv_req.world_size,
recv_req.group_name,
recv_req.backend,
)
return success, message
def update_weights_from_distributed(
self, recv_req: UpdateWeightsFromDistributedReqInput
):
success, message = self.model_runner.update_weights_from_distributed(
recv_req.name, recv_req.dtype, recv_req.shape
)
return success, message
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
parameter = self.model_runner.get_weights_by_name(
recv_req.name, recv_req.truncate_size
......
@@ -25,7 +25,9 @@ import torch
from sglang.srt.managers.io_struct import (
GetWeightsByNameReqInput,
InitWeightsUpdateGroupReqInput,
UpdateWeightFromDiskReqInput,
UpdateWeightsFromDistributedReqInput,
)
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
from sglang.srt.managers.tp_worker import TpModelWorker
@@ -211,6 +213,16 @@ class TpModelWorkerClient:
success, message = self.worker.update_weights_from_disk(recv_req)
return success, message
def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
success, message = self.worker.init_weights_update_group(recv_req)
return success, message
def update_weights_from_distributed(
self, recv_req: UpdateWeightsFromDistributedReqInput
):
success, message = self.worker.update_weights_from_distributed(recv_req)
return success, message
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
return self.worker.get_weights_by_name(recv_req)
......
@@ -20,10 +20,13 @@ import inspect
import json
import logging
import pkgutil
import time
from functools import lru_cache
from tokenize import tabsize
from typing import Any, Optional, Type, Union
import torch
import torch.distributed as dist
import torch.nn as nn
from vllm.config import DeviceConfig, LoadConfig
from vllm.config import ModelConfig as VllmModelConfig
@@ -59,6 +62,7 @@ from sglang.srt.utils import (
crash_on_warnings,
enable_show_time_cost,
get_available_gpu_memory,
init_custom_process_group,
is_hip,
monkey_patch_vllm_gguf_config,
monkey_patch_vllm_model_config,
@@ -404,6 +408,86 @@ class ModelRunner:
logger.info("Update weights end.")
return True, "Succeeded to update model weights."
def init_weights_update_group(
self,
master_address,
master_port,
rank_offset,
world_size,
group_name,
backend="nccl",
):
"""Initialize the Torch process group for model parameter updates.
`_model_update_group` is used in the RLHF workflow, where rank
0 is the actor model in the training engine, and the other ranks are
the inference engine, which is used for rollout.
In the RLHF workflow, the training engine updates the model
weights/parameters online, and broadcasts them to the inference
engine through the `_model_update_group` process group.
"""
assert (
torch.distributed.is_initialized()
), "Default torch process group must be initialized"
assert group_name != "", "Group name cannot be empty"
rank = rank_offset + self.tp_rank
logger.info(
f"init custom process group: master_address={master_address}, master_port={master_port}, "
f"rank_offset={rank_offset}, world_size={world_size}, group_name={group_name}, backend={backend}"
)
try:
self._model_update_group = init_custom_process_group(
backend=backend,
init_method=f"tcp://{master_address}:{master_port}",
world_size=world_size,
rank=rank,
group_name=group_name,
)
dist.barrier(group=self._model_update_group, device_ids=[rank])
return True, "Succeeded to initialize custom process group."
except Exception as e:
message = f"Failed to initialize custom process group: {e}."
logger.error(message)
return False, message
def update_weights_from_distributed(self, name, dtype, shape):
"""
Update specific parameter in the model weights online
through `_model_update_group` process group.
Args:
name: the name of the parameter to be updated.
dtype: the data type of the parameter to be updated.
shape: the shape of the parameter to be updated.
"""
target_dtype = (
dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
)
current_dtype = getattr(torch, self.dtype) if isinstance(self.dtype, str) else self.dtype
assert (
self._model_update_group is not None
), "model update group must be initialized"
try:
weights = torch.empty(shape, dtype=target_dtype, device=self.device)
torch.distributed.broadcast(weights, src=0, group=self._model_update_group)
self.model.load_weights([(name, weights)])
return True, f"Succeeded to update parameter {name} online."
except Exception as e:
error_msg = (
f"Failed to update parameter online: {e}. "
f"The full weights of the ModelRunner are partially updated. "
f"Please discard the whole weights."
)
logger.error(error_msg)
return False, error_msg
def get_weights_by_name(
self, name: str, truncate_size: int = 100
) -> Optional[torch.Tensor]:
......
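The engine-side update_weights_from_distributed above receives one tensor per call; the matching sender is rank 0 of the custom group, which broadcasts each parameter in turn (this is what rank 0 does in the new test file further below). A hedged sketch of that trainer-side loop, with placeholder names:

# Sketch of the trainer side (rank 0), assuming `group` was created with
# init_custom_process_group and `model` holds the new weights. For every
# broadcast here, the inference side is expected to call
# update_weights_from_distributed(name, dtype, shape) with matching metadata.
import torch
import torch.distributed as dist

def broadcast_new_weights(model: torch.nn.Module, group) -> None:
    for name, param in model.named_parameters():
        # src=0 matches the hard-coded src in ModelRunner.update_weights_from_distributed.
        dist.broadcast(param.data, src=0, group=group)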
@@ -307,6 +307,8 @@ class LlamaForCausalLM(nn.Module):
self.quant_config = quant_config
self.torchao_config = global_server_args_dict["torchao_config"]
self.model = LlamaModel(config, quant_config=quant_config)
# Llama 3.2 1B Instruct sets tie_word_embeddings to True
# Llama 3.1 8B Instruct sets tie_word_embeddings to False
if self.config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
......
@@ -53,8 +53,10 @@ from sglang.srt.managers.io_struct import (
EmbeddingReqInput,
GenerateReqInput,
GetWeightsByNameReqInput,
InitWeightsUpdateGroupReqInput,
OpenSessionReqInput,
UpdateWeightFromDiskReqInput,
UpdateWeightsFromDistributedReqInput,
)
from sglang.srt.managers.scheduler import run_scheduler_process
from sglang.srt.managers.tokenizer_manager import TokenizerManager
@@ -80,6 +82,7 @@ from sglang.srt.utils import (
assert_pkg_version,
configure_logger,
delete_directory,
init_custom_process_group,
is_port_available,
kill_process_tree,
maybe_set_triton_cache_manager,
@@ -211,6 +214,34 @@ async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: R
)
@app.post("/init_weights_update_group")
async def init_weights_update_group(
obj: InitWeightsUpdateGroupReqInput, request: Request
):
"""Initialize the parameter update group."""
success, message = await tokenizer_manager.init_weights_update_group(obj, request)
content = {"success": success, "message": message}
if success:
return ORJSONResponse(content, status_code=200)
else:
return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
@app.post("/update_weights_from_distributed")
async def update_weights_from_distributed(
obj: UpdateWeightsFromDistributedReqInput, request: Request
):
"""Update model parameter from distributed online."""
success, message = await tokenizer_manager.update_weights_from_distributed(
obj, request
)
content = {"success": success, "message": message}
if success:
return ORJSONResponse(content, status_code=200)
else:
return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
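A hedged example of exercising these two endpoints over HTTP with the requests library; the server URL and tensor metadata are assumptions, and the training process is assumed to already be waiting at the rendezvous address:

# Illustrative client-side calls; the base URL and shapes are assumptions.
import requests

base_url = "http://localhost:30000"

resp = requests.post(
    f"{base_url}/init_weights_update_group",
    json={
        "master_address": "localhost",
        "master_port": 29500,      # assumed port opened by the training process
        "rank_offset": 1,
        "world_size": 2,
        "group_name": "weight_update_group",
        "backend": "nccl",
    },
)
print(resp.json())                 # {"success": ..., "message": ...}

resp = requests.post(
    f"{base_url}/update_weights_from_distributed",
    json={
        "name": "model.embed_tokens.weight",
        "dtype": "bfloat16",
        "shape": [128256, 4096],   # illustrative shape
    },
)
print(resp.json())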
@app.api_route("/get_weights_by_name", methods=["GET", "POST"]) @app.api_route("/get_weights_by_name", methods=["GET", "POST"])
async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request): async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):
"""Get model parameter by name.""" """Get model parameter by name."""
...@@ -288,18 +319,6 @@ async def generate_request(obj: GenerateReqInput, request: Request): ...@@ -288,18 +319,6 @@ async def generate_request(obj: GenerateReqInput, request: Request):
) )
@time_func_latency
async def get_weights_by_name_request(obj: GetWeightsByNameReqInput, request: Request):
"""Handle a get parameter by name request."""
try:
ret = await tokenizer_manager.get_weights_by_name(obj, request)
return ret
except ValueError as e:
return ORJSONResponse(
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
)
@app.api_route("/encode", methods=["POST", "PUT"]) @app.api_route("/encode", methods=["POST", "PUT"])
@time_func_latency @time_func_latency
async def encode_request(obj: EmbeddingReqInput, request: Request): async def encode_request(obj: EmbeddingReqInput, request: Request):
@@ -970,7 +989,51 @@ class Engine:
async def get_server_info(self):
return await _get_server_info()
def init_weights_update_group(
self,
master_address: str,
master_port: int,
rank_offset: int,
world_size: int,
group_name: str,
backend: str = "nccl",
):
"""Initialize parameter update group."""
obj = InitWeightsUpdateGroupReqInput(
master_address=master_address,
master_port=master_port,
rank_offset=rank_offset,
world_size=world_size,
group_name=group_name,
backend=backend,
)
async def _init_group():
return await tokenizer_manager.init_weights_update_group(obj, None)
loop = asyncio.get_event_loop()
return loop.run_until_complete(_init_group())
def update_weights_from_distributed(self, name, dtype, shape):
"""Update weights from distributed source."""
obj = UpdateWeightsFromDistributedReqInput(
name=name,
dtype=dtype,
shape=shape,
)
async def _update_weights():
return await tokenizer_manager.update_weights_from_distributed(obj, None)
loop = asyncio.get_event_loop()
return loop.run_until_complete(_update_weights())
def get_weights_by_name(self, name, truncate_size=100):
"""Get weights by parameter name."""
obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)
async def _get_weights():
return await tokenizer_manager.get_weights_by_name(obj, None)
loop = asyncio.get_event_loop()
return loop.run_until_complete(_get_weights())
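The same flow through the offline Engine API added above, as a hedged sketch; the model path, port, and parameter metadata are illustrative assumptions:

# Sketch only: assumes a trainer at localhost:29500 broadcasting on the group.
import torch
import sglang as sgl

engine = sgl.Engine(model_path="meta-llama/Llama-3.2-1B-Instruct")  # illustrative model

engine.init_weights_update_group(
    master_address="localhost",
    master_port=29500,
    rank_offset=1,
    world_size=2,
    group_name="weight_update_group",
    backend="nccl",
)

# One call per tensor the trainer broadcasts, with matching name/dtype/shape.
engine.update_weights_from_distributed(
    "model.embed_tokens.weight", dtype=torch.bfloat16, shape=[128256, 4096]
)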
@@ -39,6 +39,7 @@ import numpy as np
import psutil
import requests
import torch
import torch.distributed
import torch.distributed as dist
import triton
import zmq
@@ -962,6 +963,78 @@ def get_nvgpu_memory_capacity():
)
# Copy from pytorch and OpenRLHF to allow creating multiple main groups.
# https://github.com/pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py
# https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/utils/distributed_util.py
def init_custom_process_group(
backend=None,
init_method=None,
timeout=None,
world_size=-1,
rank=-1,
store=None,
group_name=None,
pg_options=None,
):
from torch.distributed.distributed_c10d import (
Backend,
PrefixStore,
_new_process_group_helper,
_world,
default_pg_timeout,
rendezvous,
)
assert (store is None) or (
init_method is None
), "Cannot specify both init_method and store."
if store is not None:
assert world_size > 0, "world_size must be positive if using store"
assert rank >= 0, "rank must be non-negative if using store"
elif init_method is None:
init_method = "env://"
if backend:
backend = Backend(backend)
else:
backend = Backend("undefined")
if timeout is None:
timeout = default_pg_timeout
# backward compatible API
if store is None:
rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout)
store, rank, world_size = next(rendezvous_iterator)
store.set_timeout(timeout)
# Use a PrefixStore to avoid accidental overrides of keys used by
# different systems (e.g. RPC) in case the store is multi-tenant.
store = PrefixStore(group_name, store)
# NOTE: The pg_options parameter was renamed into backend_options in PyTorch 2.6.0
# https://github.com/pytorch/pytorch/commit/a0c7029a75628cd5fa8df83c0de0ea98ee7fd844
# We need to determine the appropriate parameter name based on PyTorch version
pg_options_param_name = (
"backend_options" if str(torch.__version__) >= "2.6" else "pg_options"
)
pg, _ = _new_process_group_helper(
world_size,
rank,
[],
backend,
store,
group_name=group_name,
**{pg_options_param_name: pg_options},
timeout=timeout,
)
_world.pg_group_ranks[pg] = {i: i for i in range(world_size)}
return pg
def crash_on_warnings():
# Crash on warning if we are running CI tests
return get_bool_env_var("SGLANG_IS_IN_CI")
......
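On the training side, the extra group is created with the same helper; a minimal sketch mirroring what rank 0 does in the new test file below (the port and group name are illustrative):

# Trainer side (rank 0); world_size counts the trainer plus all inference TP ranks.
import torch.distributed as dist
from sglang.srt.utils import init_custom_process_group

group = init_custom_process_group(
    backend="nccl",
    init_method="tcp://localhost:29500",   # assumed rendezvous address
    world_size=2,
    rank=0,
    group_name="weight_update_group",
)
dist.barrier(group=group, device_ids=[0])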
@@ -8,47 +8,46 @@ from transformers import AutoModelForCausalLM
import sglang as sgl
from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
is_in_ci,
popen_launch_server,
)
from sglang.utils import terminate_process
def _process_return(ret):
if isinstance(ret, list) and len(ret) == 2:
print(f"running assert_allclose on data parallel")
np.testing.assert_allclose(ret[0], ret[1])
return np.array(ret[0])
return np.array(ret)
class TestGetWeightsByName(unittest.TestCase):
def init_hf_model(self, model_name, tie_word_embeddings):
self.hf_model = AutoModelForCausalLM.from_pretrained(
model_name, torch_dtype="bfloat16", tie_word_embeddings=tie_word_embeddings
).to("cuda:0")
def init_backend(self, backend, dp, tp, model_name):
self.backend = backend
self.dp = dp
self.tp = tp
if backend == "Engine":
self.engine = sgl.Engine(
model_path=model_name,
random_seed=42,
tp_size=tp,
dp_size=dp,
mem_fraction_static=0.85,
)
else:
self.process = popen_launch_server(
model_name,
DEFAULT_URL_FOR_TEST,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=(
"--tp-size",
@@ -58,12 +57,50 @@ class TestGetWeightsByName(unittest.TestCase):
),
)
def clean_up(self):
del self.hf_model
gc.collect()
torch.cuda.empty_cache()
if self.backend == "Engine":
self.engine.shutdown()
else:
terminate_process(self.process)
def assert_tie_word_embeddings(self, truncate_size):
print(f"assert_tie_word_embeddings")
if self.backend == "Engine":
backend_ret = _process_return(
self.engine.get_weights_by_name("lm_head.weight", truncate_size)
)
else:
backend_ret = _process_return(
requests.get(
f"{DEFAULT_URL_FOR_TEST}/get_weights_by_name",
json={"name": "lm_head.weight", "truncate_size": truncate_size},
).json()
)
print(f"assert_tie_word_embeddings of hf and backend")
assert np.allclose(
self.hf_model.get_parameter("model.embed_tokens.weight")
.cpu()
.detach()
.float()
.numpy()[:truncate_size],
backend_ret,
)
assert np.allclose(
self.hf_model.get_parameter("lm_head.weight")
.cpu()
.detach()
.float()
.numpy()[:truncate_size],
self.hf_model.get_parameter("model.embed_tokens.weight")
.cpu()
.detach()
.float()
.numpy()[:truncate_size],
)
def assert_weights_all_close(self, param_name, truncate_size):
print(
f"param_name: {param_name}, backend: {self.backend}, dp: {self.dp}, tp: {self.tp}"
@@ -73,34 +110,38 @@ class TestGetWeightsByName(unittest.TestCase):
if self.backend == "Engine":
engine_ret = self.engine.get_weights_by_name(param_name, truncate_size)
engine_ret = _process_return(engine_ret)
np.testing.assert_allclose(engine_ret, param_np, rtol=1e-5, atol=1e-5)
if self.backend == "Runtime":
runtime_ret = requests.get(
f"{DEFAULT_URL_FOR_TEST}/get_weights_by_name",
json={"name": param_name, "truncate_size": truncate_size},
).json()
runtime_ret = _process_return(runtime_ret)
np.testing.assert_allclose(runtime_ret, param_np, rtol=1e-5, atol=1e-5)
def test_get_weights_by_name(self):
if is_in_ci():
test_suits = [
("Engine", 1, 1, DEFAULT_SMALL_MODEL_NAME_FOR_TEST),
]
else:
test_suits = [
("Runtime", 1, 1, DEFAULT_SMALL_MODEL_NAME_FOR_TEST),
("Engine", 1, 1, DEFAULT_MODEL_NAME_FOR_TEST),
]
if torch.cuda.device_count() >= 2:
test_suits.append(("Engine", 1, 2, DEFAULT_SMALL_MODEL_NAME_FOR_TEST))
test_suits.append(("Runtime", 2, 1, DEFAULT_MODEL_NAME_FOR_TEST))
if torch.cuda.device_count() >= 4:
test_suits.extend(
[
("Engine", 2, 2, DEFAULT_SMALL_MODEL_NAME_FOR_TEST),
("Runtime", 2, 2, DEFAULT_MODEL_NAME_FOR_TEST),
]
)
parameters = [
"model.embed_tokens.weight",
@@ -117,11 +158,24 @@ class TestGetWeightsByName(unittest.TestCase):
"lm_head.weight",
]
truncate_size = 100
for test_suit in test_suits:
if test_suit[-1] == DEFAULT_MODEL_NAME_FOR_TEST:
tie_word_embeddings = False
else:
tie_word_embeddings = True
self.init_hf_model(test_suit[-1], tie_word_embeddings)
self.init_backend(*test_suit)
for param_name in parameters:
self.assert_weights_all_close(param_name, truncate_size)
if tie_word_embeddings:
self.assert_tie_word_embeddings(truncate_size)
self.clean_up()
if __name__ == "__main__":
......
"""Test distributed weight updates.
This test suite simulates a distributed training environment to ensure
correct weight synchronization. On rank 0, the instruct model represents
pre-training weights, and the base model represents post-training weights.
The base model's weights are broadcasted to other ranks using the online
weight update API.
On other ranks, an engine is initialized with the instruct model, and its
parameters are verified against the Hugging Face model. After updating
weights from the distributed system, post-training weights are loaded
and verified again to ensure consistency and accuracy across the
distributed setup.
"""
import gc
import os
import time
import unittest
import numpy as np
import requests
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from transformers import AutoModelForCausalLM
import sglang as sgl
from sglang.srt.utils import init_custom_process_group
from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
is_in_ci,
popen_launch_server,
)
from sglang.utils import terminate_process
mp.set_start_method("spawn", force=True)
def verify_params_close(params1, params2, error_msg):
"""Verify if two parameter arrays are close enough."""
try:
assert np.allclose(np.array(params1), np.array(params2)), error_msg
except Exception as e:
print(f"Parameters not close for {error_msg}")
print("Params1:", np.array(params1))
print("Params2:", np.array(params2))
raise e
def verify_params_not_close(params1, params2, error_msg):
"""Verify if two parameter arrays are different enough."""
assert not np.allclose(np.array(params1), np.array(params2)), error_msg
def init_process(
rank,
world_size,
param_queue,
truncate_size,
state_dict_key_to_shape,
tp_size,
model_name,
backend,
checking_parameters,
tie_word_embeddings,
):
torch.cuda.set_device(rank)
if rank == 0:
init_process_hf(
rank,
world_size,
param_queue,
truncate_size,
model_name,
checking_parameters,
tie_word_embeddings,
state_dict_key_to_shape,
)
elif rank in [1, 2]:
init_process_sgl(
rank,
world_size,
param_queue,
truncate_size,
model_name,
checking_parameters,
tie_word_embeddings,
state_dict_key_to_shape,
backend,
tp_size,
)
def init_process_hf(
rank,
world_size,
param_queue,
truncate_size,
model_name,
checking_parameters,
tie_word_embeddings,
state_dict_key_to_shape,
):
# These two environment variables are very important
# to avoid unexpected behaviors of CUDA and NCCL.
os.environ["NCCL_CUMEM_ENABLE"] = "0"
os.environ["NCCL_NVLS_ENABLE"] = "0"
# Load model and get parameters
hf_instruct_model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype="bfloat16",
tie_word_embeddings=tie_word_embeddings,
).to("cuda:0")
base_model_name = model_name.replace("-Instruct", "")
hf_base_model = AutoModelForCausalLM.from_pretrained(
base_model_name,
torch_dtype="bfloat16",
tie_word_embeddings=tie_word_embeddings,
).to("cuda:0")
hf_instruct_params = []
hf_base_params = []
print(f"get parameter in hf instruct model and base model")
for parameter_name in checking_parameters:
hf_instruct_params.append(
hf_instruct_model.get_parameter(parameter_name)[:truncate_size]
.cpu()
.detach()
.float()
.numpy()
.tolist()
)
hf_base_params.append(
hf_base_model.get_parameter(parameter_name)[:truncate_size]
.cpu()
.detach()
.float()
.numpy()
.tolist()
)
param_queue.put(("hf_instruct_params", hf_instruct_params))
param_queue.put(("hf_base_params", hf_base_params))
# Init weight update group for rank 0 (the training engine in RLHF).
print(f"rank {rank} world_size: {world_size} init custom process group")
group = init_custom_process_group(
backend="nccl",
init_method="tcp://localhost:65500",
world_size=world_size,
rank=rank,
group_name="test_parameter_update_group",
)
dist.barrier(group=group, device_ids=[rank])
torch.cuda.synchronize()
time_begin_broadcast = time.time()
# The last parameter is lm_head.weight, which is tied
# with embed_tokens.weight. Actually, we only need
# to broadcast embed_tokens.weight once.
broadcast_parameters = list(state_dict_key_to_shape.keys())
if tie_word_embeddings:
broadcast_parameters.remove("lm_head.weight")
# Broadcast all the weights from the training
# engine to other ranks (inference engine).
for parameter_name in broadcast_parameters:
torch.distributed.broadcast(
hf_base_model.get_parameter(parameter_name),
src=0,
group=group,
)
torch.cuda.synchronize()
time_end_broadcast = time.time()
# Measure the latency of broadcasting/weights update.
broadcast_time = time_end_broadcast - time_begin_broadcast
print(f"rank {rank} broadcast parameter time: {broadcast_time:.3f}s")
param_queue.put(("broadcast_time", broadcast_time))
# Delete the huggingface models to free up memory.
del hf_instruct_model
del hf_base_model
gc.collect()
torch.cuda.empty_cache()
def init_process_sgl(
rank,
world_size,
param_queue,
truncate_size,
model_name,
checking_parameters,
tie_word_embeddings,
state_dict_key_to_shape,
backend,
tp_size,
):
torch.cuda.set_device(rank)
torch.cuda.synchronize()
base_gpu_id = 1 if rank == 1 else 1 + tp_size
if backend == "Engine":
engine = sgl.Engine(
model_path=model_name,
random_seed=42,
base_gpu_id=base_gpu_id,
tp_size=tp_size,
)
else:
if rank == 1:
url = DEFAULT_URL_FOR_TEST
else:
url = DEFAULT_URL_FOR_TEST.replace("2157", "2159")
process = popen_launch_server(
model_name,
url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=(
"--base-gpu-id",
str(base_gpu_id),
"--tp-size",
str(tp_size),
),
)
torch.cuda.synchronize()
if backend == "Engine":
print(f"rank {rank} init engine")
else:
print(f"rank {rank} init server on url: {url}")
# Get weights of instruct model, i.e. pre-training weights.
instruct_params = []
for parameter_name in checking_parameters:
instruct_params.append(
engine.get_weights_by_name(parameter_name, truncate_size)
if backend == "Engine"
else requests.get(
f"{url}/get_weights_by_name",
json={"name": parameter_name, "truncate_size": truncate_size},
).json()
)
param_queue.put((f"sgl_dp_{rank}_instruct_params", instruct_params))
# Init weight update group with the training engine.
if backend == "Engine":
engine.init_weights_update_group(
master_address="localhost",
master_port="65500",
rank_offset=base_gpu_id,
world_size=world_size,
group_name="test_parameter_update_group",
backend="nccl",
)
else:
requests.post(
f"{url}/init_weights_update_group",
json={
"master_address": "localhost",
"master_port": "65500",
"rank_offset": base_gpu_id,
"world_size": world_size,
"group_name": "test_parameter_update_group",
"backend": "nccl",
},
)
torch.cuda.synchronize()
time_begin_update = time.time()
# The last parameter is lm_head.weight, which is tied
# with embed_tokens.weight. Actually, we only need
# to update embed_tokens.weight once.
tie_word_embeddings = (
True if model_name == DEFAULT_SMALL_MODEL_NAME_FOR_TEST else False
)
update_parameters = list(state_dict_key_to_shape.keys())
if tie_word_embeddings:
update_parameters.remove("lm_head.weight")
# Get weights from the training engine and update the inference engine.
for parameter_name in update_parameters:
if backend == "Engine":
engine.update_weights_from_distributed(
parameter_name,
dtype=torch.bfloat16,
shape=state_dict_key_to_shape[parameter_name],
)
else:
requests.post(
f"{url}/update_weights_from_distributed",
json={
"name": parameter_name,
"dtype": "bfloat16",
"shape": state_dict_key_to_shape[parameter_name],
},
)
torch.cuda.synchronize()
time_end_update = time.time()
# Measure the latency of broadcast/weights update.
update_time = time_end_update - time_begin_update
print(
f"fully update model_name {model_name} rank {rank} parameter from distributed time: {update_time:.3f}s"
)
param_queue.put((f"update_sgl_dp_{rank}_time", update_time))
# Get the weights of post-training model after weights update for correctness check.
base_params = []
for parameter_name in checking_parameters:
if backend == "Engine":
base_params.append(
engine.get_weights_by_name(parameter_name, truncate_size)
)
else:
base_params.append(
requests.get(
f"{url}/get_weights_by_name",
json={
"name": parameter_name,
"truncate_size": truncate_size,
},
).json()
)
param_queue.put((f"sgl_dp_{rank}_base_params", base_params))
# Shutdown the engine or terminate the server process.
if backend == "Engine":
engine.shutdown()
else:
terminate_process(process)
def assert_tied_weights(params_list, message, should_be_tied):
for params in params_list:
if should_be_tied:
assert np.allclose(params[0], params[-1]), message
else:
assert not np.allclose(params[0], params[-1]), message
def test_update_weights_from_distributed(
tp_size,
dp_size,
model_name,
backend,
state_dict_key_to_shape,
truncate_size,
checking_parameters,
):
tie_word_embeddings = (
True if model_name == DEFAULT_SMALL_MODEL_NAME_FOR_TEST else False
)
print(
f"Testing model: {model_name} tp_size: {tp_size}, dp_size: {dp_size} backend: {backend}"
)
param_queue = mp.Queue()
results = {}
context = mp.spawn(
init_process,
args=(
1 + tp_size * dp_size,
param_queue,
truncate_size,
state_dict_key_to_shape,
tp_size,
model_name,
backend,
checking_parameters,
tie_word_embeddings,
),
nprocs=1 + dp_size,
join=False,
)
while len(results) < 3 * (1 + dp_size):
try:
key, value = param_queue.get(timeout=5)
results[key] = value
except Exception as e:
if all(not p.is_alive() for p in context.processes):
break
context.join()
if len(results) != 3 * (1 + dp_size):
raise RuntimeError(
f"Expected {3 * (1 + dp_size)} parameters but got {len(results)}"
)
params = {
"hf_instruct": results.get("hf_instruct_params"),
"hf_base": results.get("hf_base_params"),
"sgl_dp_1_instruct": results.get("sgl_dp_1_instruct_params"),
"sgl_dp_1_base": results.get("sgl_dp_1_base_params"),
"broadcast_time": results.get("broadcast_time"),
"update_sgl_dp_1_time": results.get("update_sgl_dp_1_time"),
}
if dp_size == 2:
dp2_params = {
"sgl_dp_2_instruct": results.get("sgl_dp_2_instruct_params"),
"sgl_dp_2_base": results.get("sgl_dp_2_base_params"),
"update_sgl_dp_2_time": results.get("update_sgl_dp_2_time"),
}
assert all(v is not None for v in dp2_params.values())
params.update(dp2_params)
# Check the correctness of weights update by verifying
# the weights of instruct model and base model.
for i in range(len(params["hf_instruct"])):
verify_params_close(
params["hf_instruct"][i],
params["sgl_dp_1_instruct"][i],
f"sgl_dp_1_instruct_params rank {i}",
)
verify_params_close(
params["hf_base"][i],
params["sgl_dp_1_base"][i],
f"sgl_dp_1_base_params rank {i}",
)
verify_params_not_close(
params["hf_instruct"][i],
params["hf_base"][i],
f"hf_instruct_params rank {i}",
)
if dp_size == 2:
verify_params_close(
params["hf_base"][i],
params["sgl_dp_2_base"][i],
f"sgl_dp_2_base_params rank {i}",
)
verify_params_close(
params["hf_instruct"][i],
params["sgl_dp_2_instruct"][i],
f"sgl_dp_2_instruct_params rank {i}",
)
assert len(params["hf_instruct"]) == len(
params["hf_base"]
), "hf_instruct_params and hf_base_params have different lengths"
# Check if the weights of lm_head are tied with embed_tokens.
params_to_check = [
(
params["hf_instruct"],
"lm_head.weight is not tied with embed_tokens.weight",
),
(
params["hf_base"],
"lm_head.weight is not tied with embed_tokens.weight",
),
(
params["sgl_dp_1_instruct"],
"lm_head.weight is not tied with embed_tokens.weight",
),
(
params["sgl_dp_1_base"],
"lm_head.weight is not tied with embed_tokens.weight",
),
]
if dp_size == 2:
params_to_check.extend(
[
(
params["sgl_dp_2_instruct"],
"lm_head.weight is not tied with embed_tokens.weight",
),
(
params["sgl_dp_2_base"],
"lm_head.weight is not tied with embed_tokens.weight",
),
]
)
assert_tied_weights(
[params for params, _ in params_to_check],
(
"lm_head.weight is not tied with embed_tokens.weight"
if tie_word_embeddings
else "lm_head.weight is tied with embed_tokens.weight"
),
tie_word_embeddings,
)
# Time limit for broadcast and update on CI is 3 / 6
# On local H100, it's 1 / 2
time_limit = 3 if model_name == DEFAULT_SMALL_MODEL_NAME_FOR_TEST else 6
assert (
params["broadcast_time"] < time_limit
), f"broadcast_time exceeds time limit {time_limit}s"
assert (
params["update_sgl_dp_1_time"] < time_limit
), f"update_sgl_dp_one_time exceeds time limit {time_limit}s"
if dp_size == 2:
assert (
params["update_sgl_dp_2_time"] < time_limit
), f"update_sgl_dp_two_time exceeds time limit {time_limit}s"
# Delete the context and close the parameter queue.
del context
param_queue.close()
param_queue.join_thread()
gc.collect()
torch.cuda.empty_cache()
class TestUpdateWeightsFromDistributed(unittest.TestCase):
def test_update_weights_from_distributed(self):
assert torch.cuda.device_count() >= 2, "At least 2 GPUs are required"
# test_suits : tp, dp, model_name, backend
if is_in_ci():
test_suits = [
(1, 1, DEFAULT_SMALL_MODEL_NAME_FOR_TEST, "Engine"),
]
else:
test_suits = [
(1, 1, DEFAULT_SMALL_MODEL_NAME_FOR_TEST, "Engine"),
(1, 1, DEFAULT_MODEL_NAME_FOR_TEST, "Server"),
]
if torch.cuda.device_count() >= 4:
test_suits.extend(
[
(2, 1, DEFAULT_SMALL_MODEL_NAME_FOR_TEST, "Engine"),
(1, 2, DEFAULT_MODEL_NAME_FOR_TEST, "Server"),
]
)
if torch.cuda.device_count() >= 5:
test_suits.extend(
[
(2, 2, DEFAULT_SMALL_MODEL_NAME_FOR_TEST, "Engine"),
(2, 2, DEFAULT_MODEL_NAME_FOR_TEST, "Server"),
]
)
model_state_dict_shapes = {}
test_models = [test_suit[2] for test_suit in test_suits]
for model_name in test_models:
model = AutoModelForCausalLM.from_pretrained(
model_name, torch_dtype="bfloat16"
).to("cuda:0")
state_dict = model.state_dict()
state_dict_keys = list(state_dict.keys())
model_state_dict_shapes[model_name] = {
key: state_dict[key].shape for key in state_dict_keys
}
del model
gc.collect()
torch.cuda.empty_cache()
truncate_size = 10
checking_parameters = [
"model.embed_tokens.weight",
"model.layers.0.input_layernorm.weight",
"model.layers.1.self_attn.q_proj.weight",
"model.layers.2.self_attn.k_proj.weight",
"model.layers.3.self_attn.v_proj.weight",
"model.layers.4.self_attn.o_proj.weight",
"model.layers.5.mlp.gate_proj.weight",
"model.layers.6.mlp.up_proj.weight",
"model.layers.7.mlp.down_proj.weight",
"model.layers.8.post_attention_layernorm.weight",
"model.norm.weight",
"lm_head.weight",
]
for tp_size, dp_size, model_name, backend in test_suits:
test_update_weights_from_distributed(
tp_size,
dp_size,
model_name,
backend,
model_state_dict_shapes[model_name],
truncate_size,
checking_parameters,
)
if __name__ == "__main__":
unittest.main()