Unverified Commit 447be242 authored by fzyzcjy, committed by GitHub

Fix OOM when updating expert locations (#6660)

parent 183d9f96
sglang/srt/model_executor/expert_location_updater.py
@@ -27,21 +27,30 @@ from sglang.srt.managers.expert_location import (
 logger = logging.getLogger(__name__)
 
 
-def update_expert_location(
-    routed_experts_weights_of_layer: Dict[int, List[torch.Tensor]],
-    new_expert_location_metadata: ExpertLocationMetadata,
-    nnodes: int,
-    rank: int,
-):
-    old_expert_location_metadata = get_global_expert_location_metadata()
-    _update_expert_weights(
-        routed_experts_weights_of_layer,
-        old_expert_location_metadata,
-        new_expert_location_metadata,
-        nnodes,
-        rank,
-    )
-    old_expert_location_metadata.update(new_expert_location_metadata)
+class ExpertLocationUpdater:
+    def __init__(self):
+        self._first_execution = True
+
+    def update(
+        self,
+        routed_experts_weights_of_layer: Dict[int, List[torch.Tensor]],
+        new_expert_location_metadata: ExpertLocationMetadata,
+        nnodes: int,
+        rank: int,
+    ):
+        if self._first_execution:
+            self._first_execution = False
+            torch.cuda.empty_cache()
+
+        old_expert_location_metadata = get_global_expert_location_metadata()
+        _update_expert_weights(
+            routed_experts_weights_of_layer,
+            old_expert_location_metadata,
+            new_expert_location_metadata,
+            nnodes,
+            rank,
+        )
+        old_expert_location_metadata.update(new_expert_location_metadata)
 
 
 def _update_expert_weights(
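The heart of the fix is the one-time torch.cuda.empty_cache() before the first relocation. A minimal, standalone sketch of that pattern follows; OneShotCacheFlusher and relocate_fn are illustrative names for this sketch, not part of sglang.

import torch


class OneShotCacheFlusher:
    """Hypothetical helper mirroring ExpertLocationUpdater's first-call behavior."""

    def __init__(self):
        self._first_execution = True

    def run(self, relocate_fn, *args, **kwargs):
        # On the first call only, return unused cached blocks held by PyTorch's
        # CUDA caching allocator (e.g. left over from model loading) so the
        # relocation's temporary buffers have room to be allocated.
        if self._first_execution:
            self._first_execution = False
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        return relocate_fn(*args, **kwargs)

Keeping the flag on an instance, as the commit does with ExpertLocationUpdater, is what restricts the cache flush to the first update instead of repeating it on every rebalance.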
sglang/srt/model_executor/model_runner.py
@@ -73,8 +73,8 @@ from sglang.srt.mem_cache.memory_pool import (
     TokenToKVPoolAllocator,
 )
 from sglang.srt.mem_cache.paged_allocator import PagedTokenToKVPoolAllocator
-from sglang.srt.model_executor import expert_location_updater
 from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
+from sglang.srt.model_executor.expert_location_updater import ExpertLocationUpdater
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader import get_model
 from sglang.srt.model_loader.loader import (
@@ -267,6 +267,7 @@ class ModelRunner:
             if self.server_args.enable_eplb and (not self.is_draft_worker)
             else None
         )
+        self.expert_location_updater = ExpertLocationUpdater()
 
         # Load the model
         self.sampler = Sampler()
@@ -600,7 +601,7 @@ class ModelRunner:
     def update_expert_location(
         self, new_expert_location_metadata: ExpertLocationMetadata
     ):
-        expert_location_updater.update_expert_location(
+        self.expert_location_updater.update(
             self.model.routed_experts_weights_of_layer,
             new_expert_location_metadata,
             nnodes=self.server_args.nnodes,
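Why the one-time flush helps with the OOM: blocks that PyTorch's caching allocator has reserved but is not currently using (for example, left over from weight loading) still reduce the free device memory available to the large temporary buffers needed when expert weights are moved. The snippet below is not part of this commit; it only uses standard PyTorch memory counters to make that effect visible.

import torch


def report_cuda_memory(tag: str) -> None:
    # Standard PyTorch counters: memory actually allocated by live tensors vs.
    # memory reserved (cached) by the allocator.
    allocated_gb = torch.cuda.memory_allocated() / 1e9
    reserved_gb = torch.cuda.memory_reserved() / 1e9
    print(f"[{tag}] allocated={allocated_gb:.2f} GB, reserved={reserved_gb:.2f} GB")


if torch.cuda.is_available():
    report_cuda_memory("before empty_cache")
    torch.cuda.empty_cache()  # returns reserved-but-unused blocks to the driver
    report_cuda_memory("after empty_cache")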