Unverified Commit 54a46a26 authored by Liangsheng Yin, committed by GitHub

Remove `tp_worker.worker` (#11548)

parent 7c94eaee
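This commit drops the `self.worker = self` alias that `TpModelWorker` kept so it exposed the same attribute surface as `TpModelWorkerClient` (which wraps a worker), and rewrites every call site from `self.tp_worker.worker.model_runner` to `self.tp_worker.model_runner`. A minimal sketch of the pattern being removed, with stand-in class bodies rather than the actual SGLang implementations:

```python
class ModelRunner:
    """Stand-in for the real ModelRunner; only here to make the sketch run."""

    graph_mem_usage = 0.0


class TpModelWorker:
    def __init__(self) -> None:
        self.model_runner = ModelRunner()
        # Removed by this commit: a self-alias so `tp_worker.worker.model_runner`
        # resolved on a bare worker just as it does on a wrapping client.
        self.worker = self


w = TpModelWorker()
assert w.worker.model_runner is w.model_runner  # old, indirect spelling
print(w.model_runner.graph_mem_usage)           # new, direct spelling
```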
@@ -468,9 +468,7 @@ class Scheduler(
         # Hybrid memory pool
         self.is_hybrid = self.tp_worker.is_hybrid
-        self.is_hybrid_gdn = (
-            self.tp_worker.worker.model_runner.hybrid_gdn_config is not None
-        )
+        self.is_hybrid_gdn = self.tp_worker.model_runner.hybrid_gdn_config is not None
         if self.is_hybrid:
             self.sliding_window_size = self.tp_worker.sliding_window_size
@@ -1882,7 +1880,7 @@ class Scheduler(
             chunked_req_to_exclude.add(self.chunked_req)
             self.tree_cache.cache_unfinished_req(self.chunked_req, chunked=True)
             # chunked request keeps its rid but will get a new req_pool_idx
-            if self.tp_worker.worker.model_runner.mambaish_config is not None:
+            if self.tp_worker.model_runner.mambaish_config is not None:
                 self.req_to_token_pool.free(
                     self.chunked_req.req_pool_idx, free_mamba_cache=False
                 )
@@ -2686,9 +2684,7 @@ class Scheduler(
         ret = vars(get_global_server_args())
         ret["last_gen_throughput"] = self.last_gen_throughput
         ret["memory_usage"] = {
-            "weight": round(
-                self.tp_worker.worker.model_runner.weight_load_mem_usage, 2
-            ),
+            "weight": round(self.tp_worker.model_runner.weight_load_mem_usage, 2),
             "kvcache": round(
                 self.token_to_kv_pool_allocator.get_kvcache().mem_usage, 2
             ),
@@ -2696,7 +2692,7 @@ class Scheduler(
         }
         ret["memory_usage"]["graph"] = round(
-            self.tp_worker.worker.model_runner.graph_mem_usage, 2
+            self.tp_worker.model_runner.graph_mem_usage, 2
         )
         if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
...
+from __future__ import annotations
+
 import logging
-from typing import Tuple
+from typing import TYPE_CHECKING, Tuple

 import torch
@@ -23,6 +25,9 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightsFromTensorReqOutput,
 )

+if TYPE_CHECKING:
+    from sglang.srt.managers.scheduler import Scheduler
+
 logger = logging.getLogger(__name__)
@@ -79,7 +84,9 @@ class SchedulerUpdateWeightsMixin:
         parameter = self.tp_worker.get_weights_by_name(recv_req)
         return GetWeightsByNameReqOutput(parameter)

-    def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
+    def release_memory_occupation(
+        self: Scheduler, recv_req: ReleaseMemoryOccupationReqInput
+    ):
         tags = recv_req.tags

         if tags is None or len(tags) == 0:
@@ -94,14 +101,16 @@ class SchedulerUpdateWeightsMixin:
         if GPU_MEMORY_TYPE_WEIGHTS in tags:
             self.stashed_model_static_state = _export_static_state(
-                self.tp_worker.worker.model_runner.model
+                self.tp_worker.model_runner.model
             )
             torch.distributed.barrier(self.tp_cpu_group)
             self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_WEIGHTS)

         return ReleaseMemoryOccupationReqOutput()

-    def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
+    def resume_memory_occupation(
+        self: Scheduler, recv_req: ResumeMemoryOccupationReqInput
+    ):
         tags = recv_req.tags

         if tags is None or len(tags) == 0:
@@ -114,7 +123,7 @@ class SchedulerUpdateWeightsMixin:
             self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_WEIGHTS)
             torch.distributed.barrier(self.tp_cpu_group)
             _import_static_state(
-                self.tp_worker.worker.model_runner.model,
+                self.tp_worker.model_runner.model,
                 self.stashed_model_static_state,
             )
             del self.stashed_model_static_state
@@ -124,24 +133,20 @@ class SchedulerUpdateWeightsMixin:
         return ResumeMemoryOccupationReqOutput()

-    def save_remote_model(self, params):
+    def save_remote_model(self: Scheduler, params):
         url = params["url"]

-        worker = self.tp_worker.worker
-        worker.model_runner.save_remote_model(url)
+        self.tp_worker.model_runner.save_remote_model(url)

         if self.draft_worker is not None:
             draft_url = params.get("draft_url", None)
             assert (
                 draft_url is not None
             ), "draft_url must be provided when draft model is enabled"
-            draft_worker = self.draft_worker.worker
-            draft_worker.model_runner.save_remote_model(draft_url)
+            self.draft_worker.model_runner.save_remote_model(draft_url)

-    def save_sharded_model(self, params):
-        worker = self.tp_worker.worker
-        worker.model_runner.save_sharded_model(
+    def save_sharded_model(self: Scheduler, params):
+        self.tp_worker.model_runner.save_sharded_model(
             path=params["path"],
             pattern=params["pattern"],
             max_size=params["max_size"],
...
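Alongside the attribute rewrite, the mixin methods gain explicit `self: Scheduler` annotations. Since `SchedulerUpdateWeightsMixin` is only ever mixed into `Scheduler`, attributes such as `tp_worker` exist on the combined class rather than on the mixin itself; annotating `self` (with the import guarded by `TYPE_CHECKING` to avoid a runtime circular import) lets static checkers resolve them. A self-contained sketch of the idiom, using hypothetical module and attribute names:

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by type checkers, so the mixin module and the
    # scheduler module never import each other at runtime.
    from scheduler import Scheduler  # hypothetical module path


class UpdateWeightsMixin:
    def release_memory_occupation(self: Scheduler) -> None:
        # With the annotation, mypy/pyright treat `self` as the full
        # Scheduler, so Scheduler-only attributes resolve here.
        self.tp_worker.model_runner.model
```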
@@ -168,9 +168,6 @@ class TpModelWorker:
         )[0]
         set_random_seed(self.random_seed)

-        # A reference make this class has the same member as TpModelWorkerClient
-        self.worker = self
-
         self.hicache_layer_transfer_counter = None

     def register_hicache_layer_transfer_counter(self, counter: LayerDoneCounter):
...
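With the alias gone, a wrapper such as `TpModelWorkerClient` can keep the two classes interchangeable by forwarding the shared attributes instead. This is only a sketch of one way to do that, not necessarily what the actual SGLang class does:

```python
class TpModelWorkerClient:
    """Illustrative wrapper; the real SGLang class body differs."""

    def __init__(self, worker) -> None:
        self.worker = worker  # the wrapped TpModelWorker

    @property
    def model_runner(self):
        # Forwarding keeps `tp_worker.model_runner` valid whether callers
        # hold a bare worker or a wrapped one.
        return self.worker.model_runner
```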