Unverified commit c471d39e, authored by fzyzcjy and committed by GitHub

Support loading weights when physical experts are different from logical experts (#6386)

parent d0443275
@@ -5,6 +5,7 @@ import torch
 from torch.nn import Module
 from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
+from sglang.srt.managers.expert_location import get_global_expert_location_metadata
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 try:
@@ -425,6 +426,28 @@ class EPMoE(torch.nn.Module):
         weight_name: str,
         shard_id: str,
         expert_id: int,
+    ) -> None:
+        physical_expert_ids = (
+            get_global_expert_location_metadata().logical_to_all_physical(
+                self.layer_id, expert_id
+            )
+        )
+        for physical_expert_id in physical_expert_ids:
+            self._weight_loader_physical(
+                param=param,
+                loaded_weight=loaded_weight,
+                weight_name=weight_name,
+                shard_id=shard_id,
+                expert_id=physical_expert_id,
+            )
+
+    def _weight_loader_physical(
+        self,
+        param: torch.nn.Parameter,
+        loaded_weight: torch.Tensor,
+        weight_name: str,
+        shard_id: str,
+        expert_id: int,
     ) -> None:
         if expert_id < self.start_expert_id or expert_id > self.end_expert_id:
             return
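The hunk above splits weight loading into two steps: the public weight_loader now receives a *logical* expert id, expands it to every *physical* replica via the expert-location metadata, and delegates each replica to _weight_loader_physical, which keeps the old rank-range check. Below is a minimal, self-contained sketch of that fan-out; the FakeExpertLocationMetadata class and the print call are stand-ins for the real sglang objects, and only the logical_to_all_physical name mirrors the actual API.

from typing import Dict, List, Tuple

class FakeExpertLocationMetadata:
    # Maps (layer_id, logical_expert_id) -> padded list of physical ids; -1 marks padding.
    def __init__(self, mapping: Dict[Tuple[int, int], List[int]]):
        self._mapping = mapping

    def logical_to_all_physical(self, layer_id: int, logical_expert_id: int) -> List[int]:
        return [p for p in self._mapping[(layer_id, logical_expert_id)] if p != -1]

def weight_loader(metadata: FakeExpertLocationMetadata, layer_id: int, expert_id: int) -> None:
    # One checkpoint tensor per logical expert may have to be written
    # into several physical expert slots (replicas).
    for physical_expert_id in metadata.logical_to_all_physical(layer_id, expert_id):
        print(f"layer {layer_id}: logical expert {expert_id} -> physical slot {physical_expert_id}")

metadata = FakeExpertLocationMetadata({(0, 3): [5, 12, -1]})
weight_loader(metadata, layer_id=0, expert_id=3)
# layer 0: logical expert 3 -> physical slot 5
# layer 0: logical expert 3 -> physical slot 12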
@@ -15,7 +15,7 @@ import json
 import logging
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Optional
+from typing import List, Optional
 import torch
 import torch.distributed
@@ -163,6 +163,19 @@ class ExpertLocationMetadata:
             logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,
         )
+
+    # -------------------------------- usage ------------------------------------
+
+    def logical_to_all_physical(
+        self, layer_id: int, logical_expert_id: int
+    ) -> List[int]:
+        return [
+            physical_expert_id
+            for physical_expert_id in self.logical_to_all_physical_map[
+                layer_id, logical_expert_id
+            ].tolist()
+            if physical_expert_id != -1
+        ]

 _global_expert_location_metadata: Optional[ExpertLocationMetadata] = None
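The lookup above implies that logical_to_all_physical_map is an integer tensor indexed as [layer_id, logical_expert_id] whose last dimension lists physical expert ids padded with -1; the comprehension simply drops the padding. A hedged standalone sketch of that behavior, with a made-up example tensor (1 layer, 2 logical experts, up to 3 replicas each):

import torch

# Hypothetical map; -1 entries are padding, not real physical slots.
logical_to_all_physical_map = torch.tensor(
    [
        [
            [0, 4, -1],   # logical expert 0 -> physical slots 0 and 4
            [1, -1, -1],  # logical expert 1 -> physical slot 1 only
        ]
    ]
)

layer_id, logical_expert_id = 0, 0
physical_ids = [
    p
    for p in logical_to_all_physical_map[layer_id, logical_expert_id].tolist()
    if p != -1
]
print(physical_ids)  # [0, 4]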