Unverified commit 447be242, authored by fzyzcjy, committed by GitHub

Fix OOM when updating expert locations (#6660)

parent 183d9f96
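
Context for the fix: the update path allocates sizeable temporary CUDA buffers while expert weights are moved to their new locations, and PyTorch's caching allocator may still be holding freed blocks in reserve, so the first update could OOM even though memory is nominally free. The patch therefore flushes the allocator cache once, immediately before the first update. A standalone illustration of the mechanism (not part of the diff; the tensor size is arbitrary):

import torch

# PyTorch caches freed blocks: allocated drops to zero, reserved stays high.
x = torch.empty(256, 1024, 1024, device="cuda")  # ~1 GiB of float32
del x
print(torch.cuda.memory_allocated(), torch.cuda.memory_reserved())

# empty_cache() returns cached blocks to the driver, creating headroom for
# other consumers -- here, the buffers used to shuffle expert weights.
torch.cuda.empty_cache()
print(torch.cuda.memory_allocated(), torch.cuda.memory_reserved())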
@@ -27,21 +27,30 @@ from sglang.srt.managers.expert_location import (
 logger = logging.getLogger(__name__)
 
 
-def update_expert_location(
-    routed_experts_weights_of_layer: Dict[int, List[torch.Tensor]],
-    new_expert_location_metadata: ExpertLocationMetadata,
-    nnodes: int,
-    rank: int,
-):
-    old_expert_location_metadata = get_global_expert_location_metadata()
-    _update_expert_weights(
-        routed_experts_weights_of_layer,
-        old_expert_location_metadata,
-        new_expert_location_metadata,
-        nnodes,
-        rank,
-    )
-    old_expert_location_metadata.update(new_expert_location_metadata)
+class ExpertLocationUpdater:
+    def __init__(self):
+        self._first_execution = True
+
+    def update(
+        self,
+        routed_experts_weights_of_layer: Dict[int, List[torch.Tensor]],
+        new_expert_location_metadata: ExpertLocationMetadata,
+        nnodes: int,
+        rank: int,
+    ):
+        if self._first_execution:
+            self._first_execution = False
+            torch.cuda.empty_cache()
+
+        old_expert_location_metadata = get_global_expert_location_metadata()
+        _update_expert_weights(
+            routed_experts_weights_of_layer,
+            old_expert_location_metadata,
+            new_expert_location_metadata,
+            nnodes,
+            rank,
+        )
+        old_expert_location_metadata.update(new_expert_location_metadata)
 
 
 def _update_expert_weights(
...
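
Because the one-shot empty_cache() guard lives in instance state, the intended usage is a single long-lived updater reused across rebalances (which is exactly how ModelRunner wires it below). A hedged sketch; the event loop and weights dict are placeholders:

updater = ExpertLocationUpdater()  # construct once, at startup

for new_metadata in rebalance_events:  # placeholder event source
    # First call flushes the CUDA allocator cache; subsequent calls skip it.
    updater.update(
        routed_experts_weights_of_layer,  # Dict[int, List[torch.Tensor]]
        new_metadata,                     # ExpertLocationMetadata
        nnodes=nnodes,
        rank=rank,
    )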
@@ -73,8 +73,8 @@ from sglang.srt.mem_cache.memory_pool import (
     TokenToKVPoolAllocator,
 )
 from sglang.srt.mem_cache.paged_allocator import PagedTokenToKVPoolAllocator
-from sglang.srt.model_executor import expert_location_updater
 from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
+from sglang.srt.model_executor.expert_location_updater import ExpertLocationUpdater
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader import get_model
 from sglang.srt.model_loader.loader import (
@@ -267,6 +267,7 @@ class ModelRunner:
             if self.server_args.enable_eplb and (not self.is_draft_worker)
             else None
         )
+        self.expert_location_updater = ExpertLocationUpdater()
 
         # Load the model
         self.sampler = Sampler()
@@ -600,7 +601,7 @@ class ModelRunner:
     def update_expert_location(
         self, new_expert_location_metadata: ExpertLocationMetadata
     ):
-        expert_location_updater.update_expert_location(
+        self.expert_location_updater.update(
             self.model.routed_experts_weights_of_layer,
             new_expert_location_metadata,
             nnodes=self.server_args.nnodes,
...
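
End to end, the call path after this change is roughly as follows; obtaining the ModelRunner and computing the new metadata happen outside this diff, so both are placeholders:

# Hypothetical driver code: model_runner and compute_rebalanced_metadata()
# stand in for sglang's EPLB machinery, which is not shown in this diff.
new_metadata = compute_rebalanced_metadata()
model_runner.update_expert_location(new_metadata)
# Internally this forwards to self.expert_location_updater.update(...),
# which flushes the CUDA cache on its first call only.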