Unverified Commit 54a46a26 authored by Liangsheng Yin, committed by GitHub

Remove `tp_worker.worker` (#11548)

parent 7c94eaee
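
The change is mechanical: `TpModelWorker` drops its `self.worker = self` alias, and every call site that went through `self.tp_worker.worker.model_runner` now reads `self.tp_worker.model_runner` directly. A minimal sketch of the pattern being removed (not the sglang code; the property-based forwarding on the client wrapper is an assumption, this diff only shows the alias removal and the call-site updates):

class ModelRunner:  # stand-in for sglang's ModelRunner
    graph_mem_usage = 0.0

class TpModelWorker:
    def __init__(self):
        self.model_runner = ModelRunner()
        # self.worker = self   # the self-alias removed by this commit

class TpModelWorkerClient:
    """Wrapper that owns a TpModelWorker (e.g. for overlapped scheduling)."""

    def __init__(self):
        self.worker = TpModelWorker()

    @property
    def model_runner(self):
        # Assumed forwarding so both classes expose the same attribute;
        # with it, callers no longer need the `.worker` indirection.
        return self.worker.model_runner

# Scheduler call sites, before and after this commit:
#   before: self.tp_worker.worker.model_runner.graph_mem_usage
#   after:  self.tp_worker.model_runner.graph_mem_usage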
@@ -468,9 +468,7 @@ class Scheduler(
         # Hybrid memory pool
         self.is_hybrid = self.tp_worker.is_hybrid
-        self.is_hybrid_gdn = (
-            self.tp_worker.worker.model_runner.hybrid_gdn_config is not None
-        )
+        self.is_hybrid_gdn = self.tp_worker.model_runner.hybrid_gdn_config is not None
         if self.is_hybrid:
             self.sliding_window_size = self.tp_worker.sliding_window_size
@@ -1882,7 +1880,7 @@ class Scheduler(
             chunked_req_to_exclude.add(self.chunked_req)
             self.tree_cache.cache_unfinished_req(self.chunked_req, chunked=True)
             # chunked request keeps its rid but will get a new req_pool_idx
-            if self.tp_worker.worker.model_runner.mambaish_config is not None:
+            if self.tp_worker.model_runner.mambaish_config is not None:
                 self.req_to_token_pool.free(
                     self.chunked_req.req_pool_idx, free_mamba_cache=False
                 )
@@ -2686,9 +2684,7 @@ class Scheduler(
         ret = vars(get_global_server_args())
         ret["last_gen_throughput"] = self.last_gen_throughput
         ret["memory_usage"] = {
-            "weight": round(
-                self.tp_worker.worker.model_runner.weight_load_mem_usage, 2
-            ),
+            "weight": round(self.tp_worker.model_runner.weight_load_mem_usage, 2),
             "kvcache": round(
                 self.token_to_kv_pool_allocator.get_kvcache().mem_usage, 2
             ),
@@ -2696,7 +2692,7 @@ class Scheduler(
         }
         ret["memory_usage"]["graph"] = round(
-            self.tp_worker.worker.model_runner.graph_mem_usage, 2
+            self.tp_worker.model_runner.graph_mem_usage, 2
         )
         if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
......
 from __future__ import annotations
 
 import logging
-from typing import Tuple
+from typing import TYPE_CHECKING, Tuple
 
 import torch
@@ -23,6 +25,9 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightsFromTensorReqOutput,
 )
 
+if TYPE_CHECKING:
+    from sglang.srt.managers.scheduler import Scheduler
+
 logger = logging.getLogger(__name__)
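
The `TYPE_CHECKING` block added above, together with the `self: Scheduler` annotations in the hunks below, gives the mixin's methods full type information without creating a runtime import cycle with scheduler.py. A self-contained sketch of the same pattern (the method body and attribute access here are illustrative, not taken from this diff):

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by the type checker, so importing Scheduler here
    # cannot cause a circular import at runtime.
    from sglang.srt.managers.scheduler import Scheduler


class SchedulerUpdateWeightsMixin:
    def release_memory_occupation(self: Scheduler, recv_req):
        # Annotating `self` with the concrete class lets mypy/pyright resolve
        # attributes the mixin itself never defines, e.g. self.tp_worker.
        _ = self.tp_worker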
@@ -79,7 +84,9 @@ class SchedulerUpdateWeightsMixin:
         parameter = self.tp_worker.get_weights_by_name(recv_req)
         return GetWeightsByNameReqOutput(parameter)
 
-    def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
+    def release_memory_occupation(
+        self: Scheduler, recv_req: ReleaseMemoryOccupationReqInput
+    ):
         tags = recv_req.tags
         if tags is None or len(tags) == 0:
@@ -94,14 +101,16 @@ class SchedulerUpdateWeightsMixin:
         if GPU_MEMORY_TYPE_WEIGHTS in tags:
             self.stashed_model_static_state = _export_static_state(
-                self.tp_worker.worker.model_runner.model
+                self.tp_worker.model_runner.model
             )
             torch.distributed.barrier(self.tp_cpu_group)
             self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_WEIGHTS)
 
         return ReleaseMemoryOccupationReqOutput()
 
-    def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
+    def resume_memory_occupation(
+        self: Scheduler, recv_req: ResumeMemoryOccupationReqInput
+    ):
         tags = recv_req.tags
         if tags is None or len(tags) == 0:
@@ -114,7 +123,7 @@ class SchedulerUpdateWeightsMixin:
             self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_WEIGHTS)
             torch.distributed.barrier(self.tp_cpu_group)
             _import_static_state(
-                self.tp_worker.worker.model_runner.model,
+                self.tp_worker.model_runner.model,
                 self.stashed_model_static_state,
             )
             del self.stashed_model_static_state
@@ -124,24 +133,20 @@ class SchedulerUpdateWeightsMixin:
         return ResumeMemoryOccupationReqOutput()
 
-    def save_remote_model(self, params):
+    def save_remote_model(self: Scheduler, params):
         url = params["url"]
 
-        worker = self.tp_worker.worker
-        worker.model_runner.save_remote_model(url)
+        self.tp_worker.model_runner.save_remote_model(url)
 
         if self.draft_worker is not None:
             draft_url = params.get("draft_url", None)
             assert (
                 draft_url is not None
             ), "draft_url must be provided when draft model is enabled"
-            draft_worker = self.draft_worker.worker
-            draft_worker.model_runner.save_remote_model(draft_url)
+            self.draft_worker.model_runner.save_remote_model(draft_url)
 
-    def save_sharded_model(self, params):
-        worker = self.tp_worker.worker
-        worker.model_runner.save_sharded_model(
+    def save_sharded_model(self: Scheduler, params):
+        self.tp_worker.model_runner.save_sharded_model(
             path=params["path"],
             pattern=params["pattern"],
             max_size=params["max_size"],
......
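
For context, the `params` dicts these two handlers read from (only the key names come from the code above; the values and the comments' semantics are illustrative assumptions):

# Hypothetical payloads for save_remote_model / save_sharded_model.
save_remote_params = {
    "url": "s3://bucket/model",        # destination for the main model
    "draft_url": "s3://bucket/draft",  # required only when a draft model is loaded
}

save_sharded_params = {
    "path": "/checkpoints/out",  # output location
    "pattern": None,             # shard filename pattern (assumed)
    "max_size": 5 * 1024**3,     # assumed: maximum bytes per shard file
}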
@@ -168,9 +168,6 @@ class TpModelWorker:
         )[0]
         set_random_seed(self.random_seed)
 
-        # A reference make this class has the same member as TpModelWorkerClient
-        self.worker = self
-
         self.hicache_layer_transfer_counter = None
 
     def register_hicache_layer_transfer_counter(self, counter: LayerDoneCounter):
......