Unverified Commit 9183c23e authored by fzyzcjy's avatar fzyzcjy Committed by GitHub
Browse files

Speed up `update_weights_from_tensor` (#2695)

parent 148254d4
...@@ -426,8 +426,7 @@ class UpdateWeightsFromDistributedReqOutput: ...@@ -426,8 +426,7 @@ class UpdateWeightsFromDistributedReqOutput:
@dataclass
class UpdateWeightsFromTensorReqInput:
    # ForkingPickler-serialized payload produced by
    # MultiprocessingSerializer.serialize; decodes to
    # List[Tuple[str, torch.Tensor]] (name, tensor) pairs — see
    # Engine.update_weights_from_tensor and
    # ModelRunner.update_weights_from_tensor.
    serialized_named_tensors: bytes
@dataclass @dataclass
......
...@@ -30,7 +30,7 @@ from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_a ...@@ -30,7 +30,7 @@ from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_a
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import broadcast_pyobj, set_random_seed from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj, set_random_seed
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -197,7 +197,7 @@ class TpModelWorker: ...@@ -197,7 +197,7 @@ class TpModelWorker:
def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
    """Decode the serialized (name, tensor) pairs and load them via the model runner.

    Returns the (success, message) pair reported by the model runner.
    """
    named_tensors = MultiprocessingSerializer.deserialize(
        recv_req.serialized_named_tensors
    )
    success, message = self.model_runner.update_weights_from_tensor(named_tensors)
    return success, message
......
...@@ -17,7 +17,7 @@ import gc ...@@ -17,7 +17,7 @@ import gc
import json import json
import logging import logging
import time import time
from typing import Optional from typing import List, Optional, Tuple
import torch import torch
import torch.distributed as dist import torch.distributed as dist
...@@ -428,9 +428,9 @@ class ModelRunner: ...@@ -428,9 +428,9 @@ class ModelRunner:
logger.error(error_msg) logger.error(error_msg)
return False, error_msg return False, error_msg
def update_weights_from_tensor(self, named_tensors: List[Tuple[str, torch.Tensor]]):
    """Load in-memory (name, tensor) pairs into the model's weights.

    Args:
        named_tensors: Pairs of (parameter name, replacement tensor).

    Returns:
        A (success, message) tuple, matching the other update_weights_* APIs.
    """
    try:
        self.model.load_weights(named_tensors)
    except Exception as e:
        # Resolves the old "# TODO error handling": report failures the same
        # way the surrounding update-weights code does (log + (False, msg))
        # instead of letting the exception escape.
        error_msg = f"Failed to update weights from tensor: {e}"
        logger.error(error_msg)
        return False, error_msg
    return True, "Success"
def get_weights_by_name( def get_weights_by_name(
self, name: str, truncate_size: int = 100 self, name: str, truncate_size: int = 100
......
...@@ -27,7 +27,9 @@ import signal ...@@ -27,7 +27,9 @@ import signal
import threading import threading
import time import time
from http import HTTPStatus from http import HTTPStatus
from typing import AsyncIterator, Dict, List, Optional, Union from typing import AsyncIterator, Dict, List, Optional, Tuple, Union
import torch
# Fix a bug of Python threading # Fix a bug of Python threading
setattr(threading, "_register_atexit", lambda *args, **kwargs: None) setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
...@@ -78,6 +80,7 @@ from sglang.srt.openai_api.adapter import ( ...@@ -78,6 +80,7 @@ from sglang.srt.openai_api.adapter import (
from sglang.srt.openai_api.protocol import ModelCard, ModelList from sglang.srt.openai_api.protocol import ModelCard, ModelList
from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import ( from sglang.srt.utils import (
MultiprocessingSerializer,
add_api_key_middleware, add_api_key_middleware,
add_prometheus_middleware, add_prometheus_middleware,
assert_pkg_version, assert_pkg_version,
...@@ -874,9 +877,11 @@ class Engine: ...@@ -874,9 +877,11 @@ class Engine:
tokenizer_manager.update_weights_from_distributed(obj, None) tokenizer_manager.update_weights_from_distributed(obj, None)
) )
def update_weights_from_tensor(self, named_tensors: List[Tuple[str, torch.Tensor]]):
    """Update model weights from in-memory (name, tensor) pairs.

    (The previous docstring, "Update weights from distributed source", was
    copied from update_weights_from_distributed; this method takes local
    tensors, not a distributed source.)

    Args:
        named_tensors: Pairs of (parameter name, new tensor value).

    Returns:
        The result of the tokenizer manager's update_weights_from_tensor call.
    """
    # Serialize once here; the receiving TP worker decodes the payload with
    # MultiprocessingSerializer.deserialize.
    obj = UpdateWeightsFromTensorReqInput(
        serialized_named_tensors=MultiprocessingSerializer.serialize(named_tensors)
    )
    loop = asyncio.get_event_loop()
    return loop.run_until_complete(
        tokenizer_manager.update_weights_from_tensor(obj, None)
    )
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import base64 import base64
import dataclasses import dataclasses
import io
import ipaddress import ipaddress
import itertools import itertools
import json import json
...@@ -34,6 +35,7 @@ import warnings ...@@ -34,6 +35,7 @@ import warnings
from functools import lru_cache from functools import lru_cache
from importlib.metadata import PackageNotFoundError, version from importlib.metadata import PackageNotFoundError, version
from io import BytesIO from io import BytesIO
from multiprocessing.reduction import ForkingPickler
from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Union from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Union
import numpy as np import numpy as np
...@@ -60,7 +62,6 @@ from triton.runtime.cache import ( ...@@ -60,7 +62,6 @@ from triton.runtime.cache import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
show_time_cost = False show_time_cost = False
time_infos = {} time_infos = {}
...@@ -1206,7 +1207,6 @@ def _cuda_device_count_stateless(cuda_visible_devices: Optional[str] = None) -> ...@@ -1206,7 +1207,6 @@ def _cuda_device_count_stateless(cuda_visible_devices: Optional[str] = None) ->
# https://github.com/pytorch/pytorch/blob/ # https://github.com/pytorch/pytorch/blob/
# c1cd946818442aca8c7f812b16d187ce1586c3bc/ # c1cd946818442aca8c7f812b16d187ce1586c3bc/
# torch/cuda/__init__.py#L831C1-L831C17 # torch/cuda/__init__.py#L831C1-L831C17
import torch.cuda
import torch.version import torch.version
if not torch.cuda._is_compiled(): if not torch.cuda._is_compiled():
...@@ -1335,3 +1335,16 @@ def parse_tool_response(text, tools, **kwargs): ...@@ -1335,3 +1335,16 @@ def parse_tool_response(text, tools, **kwargs):
for call_info in call_info_list for call_info in call_info_list
] ]
return text, call_info_list return text, call_info_list
class MultiprocessingSerializer:
    """Serialize/deserialize objects for transfer between local processes.

    Uses ``multiprocessing.reduction.ForkingPickler`` rather than plain
    ``pickle`` — presumably so objects with multiprocessing-aware reducers
    (e.g. torch tensors) are moved efficiently instead of byte-copied;
    confirm against the torch.multiprocessing docs.
    """

    @staticmethod
    def serialize(obj):
        """Return *obj* pickled to bytes with ForkingPickler."""
        buf = io.BytesIO()
        ForkingPickler(buf).dump(obj)
        # getvalue() returns the full buffer contents regardless of stream
        # position, so the seek(0)/read() dance is unnecessary.
        return buf.getvalue()

    @staticmethod
    def deserialize(data):
        """Inverse of :meth:`serialize`.

        NOTE(security): this is pickle-based — only call it on payloads
        produced by a trusted local process, never on untrusted input.
        """
        return ForkingPickler.loads(data)
import time
import unittest import unittest
import torch import torch
...@@ -6,27 +7,32 @@ import sglang as sgl ...@@ -6,27 +7,32 @@ import sglang as sgl
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST
class TestUpdateWeightsFromTensor(unittest.TestCase):
    def test_update_weights_from_tensor(self):
        """End-to-end check: overwrite several up_proj weights via
        update_weights_from_tensor and verify the new values are visible
        through get_weights_by_name. Also prints the wall-clock time of the
        update call (this commit's speed-up target)."""
        engine = sgl.Engine(model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST)

        param_names = [f"model.layers.{i}.mlp.up_proj.weight" for i in range(6, 16)]

        # Sanity-check the pre-update values of the first parameter.
        # NOTE(review): these constants are tied to the specific small test
        # model checkpoint — they will drift if the model changes.
        _check_param(engine, param_names[0], [0.0087, -0.0214, -0.0004, 0.0039, 0.0110])

        # One shared tensor is reused for every parameter in the batch.
        # assumes up_proj weights are (16384, 2048) in this model — TODO confirm
        new_tensor = torch.full((16384, 2048), 1.5)
        time_start = time.time()
        engine.update_weights_from_tensor([(x, new_tensor) for x in param_names])
        print(f"Time delta: {time.time() - time_start:.03f}")

        # Spot-check a few of the updated parameters.
        for param_name in param_names[:3]:
            _check_param(engine, param_name, [1.5] * 5)

        engine.shutdown()
def _check_param(engine, param_name, expect_values):
    """Assert that the first five entries of row 0 of `param_name`'s weight
    (as reported by engine.get_weights_by_name) match `expect_values`
    within an absolute tolerance of 0.002."""
    actual_values = torch.tensor(engine.get_weights_by_name(param_name))[0, :5]
    assert torch.allclose(
        actual_values, torch.tensor(expect_values), atol=0.002
    ), f"{actual_values=}"
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment