Unverified commit fd28640d, authored by fzyzcjy and committed by GitHub

Add `update_weights_from_tensor` (#2631)

parent 7863e436
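The diff below threads a new `UpdateWeightsFromTensorReqInput` from the Engine through TokenizerManager, Scheduler, and TpModelWorker down to `ModelRunner.model.load_weights`, and exposes it as `Engine.update_weights_from_tensor(name, tensor)`. A minimal usage sketch of the new API, assuming an offline `sgl.Engine`; the model constant, parameter name, and tensor shape mirror the new test at the end of this diff and are illustrative only:

import torch

import sglang as sgl
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST

engine = sgl.Engine(model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST)

# Overwrite a single named parameter with an in-memory tensor.
# The tensor must have the same shape as the existing parameter.
param_name = "model.layers.2.self_attn.k_proj.weight"
success, message = engine.update_weights_from_tensor(param_name, torch.full((3072, 2048), 1.5))
assert success, message

engine.shutdown()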
@@ -220,3 +220,5 @@ work_dirs/
*.app
compile_commands.json
*.iml
@@ -21,6 +21,8 @@ from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Optional, Tuple, Union

import torch

from sglang.srt.managers.schedule_batch import BaseFinishReason
from sglang.srt.sampling.sampling_params import SamplingParams

@@ -407,6 +409,18 @@ class UpdateWeightsFromDistributedReqOutput:
    message: str


@dataclass
class UpdateWeightsFromTensorReqInput:
    name: str
    tensor: torch.Tensor


@dataclass
class UpdateWeightsFromTensorReqOutput:
    success: bool
    message: str


@dataclass
class InitWeightsUpdateGroupReqInput:
    # The master address
...
@@ -22,7 +22,7 @@ import warnings
from collections import deque
from concurrent import futures
from types import SimpleNamespace
from typing import Callable, Dict, List, Optional, Tuple
from typing import Dict, List, Optional

import psutil
import setproctitle

@@ -52,6 +52,8 @@ from sglang.srt.managers.io_struct import (
    UpdateWeightFromDiskReqOutput,
    UpdateWeightsFromDistributedReqInput,
    UpdateWeightsFromDistributedReqOutput,
    UpdateWeightsFromTensorReqInput,
    UpdateWeightsFromTensorReqOutput,
)
from sglang.srt.managers.schedule_batch import (
    FINISH_ABORT,

@@ -478,6 +480,11 @@ class Scheduler:
            self.send_to_tokenizer.send_pyobj(
                UpdateWeightsFromDistributedReqOutput(success, message)
            )
        elif isinstance(recv_req, UpdateWeightsFromTensorReqInput):
            success, message = self.update_weights_from_tensor(recv_req)
            self.send_to_tokenizer.send_pyobj(
                UpdateWeightsFromTensorReqOutput(success, message)
            )
        elif isinstance(recv_req, GetWeightsByNameReqInput):
            parameter = self.get_weights_by_name(recv_req)
            self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))

@@ -1458,6 +1465,17 @@ class Scheduler:
            logger.error(message)
        return success, message

    def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
        """Update the online model parameters from in-memory tensors."""
        success, message = self.tp_worker.update_weights_from_tensor(recv_req)
        # TODO: extract the code shared between update_weights_from_distributed
        # and update_weights_from_tensor.
        if success:
            flush_cache_success = self.flush_cache()
            assert flush_cache_success, "Cache flush failed after updating weights"
        else:
            logger.error(message)
        return success, message

    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
        parameter = self.tp_worker.get_weights_by_name(recv_req)
        return parameter
...
@@ -59,6 +59,8 @@ from sglang.srt.managers.io_struct import (
    UpdateWeightFromDiskReqOutput,
    UpdateWeightsFromDistributedReqInput,
    UpdateWeightsFromDistributedReqOutput,
    UpdateWeightsFromTensorReqInput,
    UpdateWeightsFromTensorReqOutput,
)
from sglang.srt.metrics.collector import TokenizerMetricsCollector
from sglang.srt.sampling.sampling_params import SamplingParams

@@ -179,6 +181,9 @@ class TokenizerManager:
        self.update_weights_from_distributed_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.update_weights_from_tensor_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.get_weights_by_name_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )

@@ -515,6 +520,22 @@ class TokenizerManager:
            result = (await self.update_weights_from_distributed_communicator(obj))[0]
            return result.success, result.message

    async def update_weights_from_tensor(
        self,
        obj: UpdateWeightsFromTensorReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str]:
        self.auto_create_handle_loop()
        assert (
            self.server_args.dp_size == 1
        ), "dp_size must be 1 for update weights from tensor"
        # Holding the writer lock means that weight sync
        # cannot run while requests are in progress.
        async with self.model_update_lock.writer_lock:
            result = (await self.update_weights_from_tensor_communicator(obj))[0]
            return result.success, result.message

    async def get_weights_by_name(
        self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
    ):

@@ -708,6 +729,11 @@ class TokenizerManager:
                self.server_args.dp_size == 1
            ), "dp_size must be 1 for update weights from distributed"
            self.update_weights_from_distributed_communicator.handle_recv(recv_obj)
        elif isinstance(recv_obj, UpdateWeightsFromTensorReqOutput):
            assert (
                self.server_args.dp_size == 1
            ), "dp_size must be 1 for update weights from tensor"
            self.update_weights_from_tensor_communicator.handle_recv(recv_obj)
        elif isinstance(recv_obj, GetWeightsByNameReqOutput):
            self.get_weights_by_name_communicator.handle_recv(recv_obj)
        else:
...
@@ -24,6 +24,7 @@ from sglang.srt.managers.io_struct import (
    InitWeightsUpdateGroupReqInput,
    UpdateWeightFromDiskReqInput,
    UpdateWeightsFromDistributedReqInput,
    UpdateWeightsFromTensorReqInput,
)
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch

@@ -188,6 +189,12 @@ class TpModelWorker:
        )
        return success, message

    def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
        success, message = self.model_runner.update_weights_from_tensor(
            recv_req.name, recv_req.tensor
        )
        return success, message

    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
        parameter = self.model_runner.get_weights_by_name(
            recv_req.name, recv_req.truncate_size
...
@@ -28,6 +28,7 @@ from sglang.srt.managers.io_struct import (
    InitWeightsUpdateGroupReqInput,
    UpdateWeightFromDiskReqInput,
    UpdateWeightsFromDistributedReqInput,
    UpdateWeightsFromTensorReqInput,
)
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
from sglang.srt.managers.tp_worker import TpModelWorker

@@ -225,6 +226,10 @@ class TpModelWorkerClient:
        success, message = self.worker.update_weights_from_distributed(recv_req)
        return success, message

    def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
        success, message = self.worker.update_weights_from_tensor(recv_req)
        return success, message

    def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
        return self.worker.get_weights_by_name(recv_req)
...
@@ -429,6 +429,10 @@ class ModelRunner:
        logger.error(error_msg)
        return False, error_msg

    def update_weights_from_tensor(self, name, tensor: torch.Tensor):
        self.model.load_weights([(name, tensor)])
        return True, "Success"  # TODO: error handling

    def get_weights_by_name(
        self, name: str, truncate_size: int = 100
    ) -> Optional[torch.Tensor]:
...
@@ -57,6 +57,7 @@ from sglang.srt.managers.io_struct import (
    OpenSessionReqInput,
    UpdateWeightFromDiskReqInput,
    UpdateWeightsFromDistributedReqInput,
    UpdateWeightsFromTensorReqInput,
)
from sglang.srt.managers.scheduler import run_scheduler_process
from sglang.srt.managers.tokenizer_manager import TokenizerManager

@@ -109,6 +110,7 @@ app.add_middleware(
tokenizer_manager: TokenizerManager = None
scheduler_info: Dict = None


##### Native API endpoints #####

@@ -866,6 +868,14 @@ class Engine:
            tokenizer_manager.update_weights_from_distributed(obj, None)
        )

    def update_weights_from_tensor(self, name, tensor):
        """Update weights from an in-memory tensor."""
        obj = UpdateWeightsFromTensorReqInput(name=name, tensor=tensor)
        loop = asyncio.get_event_loop()
        return loop.run_until_complete(
            tokenizer_manager.update_weights_from_tensor(obj, None)
        )

    def get_weights_by_name(self, name, truncate_size=100):
        """Get weights by parameter name."""
        obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)
...
@@ -40,6 +40,7 @@ suites = {
    "test_triton_attention_kernels.py",
    "test_triton_attention_backend.py",
    "test_update_weights_from_disk.py",
    "test_update_weights_from_tensor.py",
    "test_vision_chunked_prefill.py",
    "test_vision_openai_server.py",
    "test_session_control.py",
...
import unittest

import torch

import sglang as sgl
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST


class TestUpdateWeightsFromTensor(unittest.TestCase):
    def test_update_weights_from_tensor(self):
        engine = sgl.Engine(model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST)

        param_name = "model.layers.2.self_attn.k_proj.weight"

        def _check_param(expect_values):
            actual_values = torch.tensor(engine.get_weights_by_name(param_name))[0, :5]
            assert torch.allclose(
                actual_values, torch.tensor(expect_values), atol=0.001
            ), f"{actual_values=}"

        # The parameter should still hold its original pretrained values.
        _check_param([0.0571, -0.0114, 0.0444, 0.0215, -0.0149])

        # Overwrite the whole parameter with a constant tensor and verify the update took effect.
        new_tensor = torch.full((3072, 2048), 1.5)
        engine.update_weights_from_tensor(param_name, new_tensor)
        _check_param([1.5] * 5)

        engine.shutdown()


if __name__ == "__main__":
    unittest.main()