# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import json
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional

import torch
import torch.distributed
import torch.nn.functional as F

from sglang.srt.configs.model_config import ModelConfig

# NOTE: assumed import path for the EPLB rebalancing helper used in
# ExpertLocationMetadata.init_by_eplb below; adjust if the module lives elsewhere.
from sglang.srt.managers import deepseek_eplb
from sglang.srt.model_loader import get_model_architecture
from sglang.srt.server_args import ServerArgs

logger = logging.getLogger(__name__)


@dataclass
class ExpertLocationMetadata:
    physical_to_logical_map: torch.Tensor  # (layers, num_physical_experts)
    logical_to_all_physical_map: torch.Tensor  # (layers, num_logical_experts, X)
    logical_to_all_physical_map_num_valid: torch.Tensor  # (layers, num_logical_experts)
    logical_to_rank_dispatch_physical_map: torch.Tensor  # (layers, num_logical_experts)

    # -------------------------------- properties ------------------------------------

    @property
    def num_layers(self) -> int:
        return self.physical_to_logical_map.shape[0]

    @property
    def num_physical_experts(self) -> int:
        return self.physical_to_logical_map.shape[1]

    @property
    def num_local_physical_experts(self) -> int:
        ans, remainder = divmod(self.num_physical_experts, self.ep_size)
        assert remainder == 0
        return ans

    @property
    def num_logical_experts(self) -> int:
        return self.logical_to_all_physical_map.shape[1]

    @property
    def ep_size(self):
        # TODO change when EP size != world size
        return torch.distributed.get_world_size()

    def __post_init__(self):
        num_layers_0, num_physical_experts_0 = self.physical_to_logical_map.shape
        num_layers_1, num_logical_experts_0, num_physical_experts_1 = (
            self.logical_to_all_physical_map.shape
        )
        num_layers_2, num_logical_experts_1 = (
            self.logical_to_all_physical_map_num_valid.shape
        )
        num_layers_3, num_logical_experts_2 = (
            self.logical_to_rank_dispatch_physical_map.shape
        )
        assert num_layers_0 == num_layers_1 == num_layers_2 == num_layers_3
        assert num_logical_experts_0 == num_logical_experts_1 == num_logical_experts_2
        assert num_physical_experts_0 == num_physical_experts_1

    # -------------------------------- construction ------------------------------------

    @staticmethod
    def init_trivial(server_args: ServerArgs, model_config: ModelConfig):
        """Trivial location - logical expert i corresponds to physical expert i"""
        common = ExpertLocationMetadata._init_common(server_args, model_config)
        num_physical_experts = common["num_physical_experts"]
        model_config_for_expert_location = common["model_config_for_expert_location"]
        num_layers = model_config_for_expert_location.num_layers
        num_logical_experts = model_config_for_expert_location.num_logical_experts

        physical_to_logical_map = (
            torch.arange(0, num_physical_experts).repeat(num_layers, 1)
            % num_logical_experts
        )

        return ExpertLocationMetadata.init_by_mapping(
            server_args,
            model_config,
            physical_to_logical_map=physical_to_logical_map,
        )
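
    # Illustrative example (comment only, not executed): the mapping passed to
    # init_by_mapping has shape (num_layers, num_physical_experts). For 2 layers,
    # 4 physical expert slots and 3 logical experts, init_trivial above produces
    #     [[0, 1, 2, 0],
    #      [0, 1, 2, 0]]
    # i.e. physical slot i serves logical expert i % num_logical_experts, and the
    # single redundant slot replicates logical expert 0.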

    @staticmethod
    def init_by_mapping(
        server_args: ServerArgs,
        model_config: ModelConfig,
        physical_to_logical_map,
    ):
        if not isinstance(physical_to_logical_map, torch.Tensor):
            physical_to_logical_map = torch.tensor(physical_to_logical_map)
        physical_to_logical_map = physical_to_logical_map.to(server_args.device)

        common = ExpertLocationMetadata._init_common(server_args, model_config)
        model_config_for_expert_location = common["model_config_for_expert_location"]
        logical_to_all_physical_map = _compute_logical_to_all_physical_map(
            physical_to_logical_map,
            num_logical_experts=model_config_for_expert_location.num_logical_experts,
        )

        return ExpertLocationMetadata._init_raw(
            ep_size=common["ep_size"],
            physical_to_logical_map=physical_to_logical_map,
            logical_to_all_physical_map=logical_to_all_physical_map,
        )

    @staticmethod
    def init_by_eplb(
        server_args: ServerArgs, model_config: ModelConfig, logical_count: torch.Tensor
    ):
        if not isinstance(logical_count, torch.Tensor):
            logical_count = torch.tensor(logical_count)
        if len(logical_count.shape) == 2:
            logical_count = logical_count.unsqueeze(0)
        logical_count = logical_count.to(server_args.device)

        common = ExpertLocationMetadata._init_common(server_args, model_config)
        model_config_for_expert_location = common["model_config_for_expert_location"]
        num_physical_experts = common["num_physical_experts"]

        phase = server_args.disaggregation_mode
        if phase == "null":
            phase = "decode"

        physical_to_logical_map, logical_to_all_physical_map, expert_count = (
            deepseek_eplb.rebalance_experts(
                tokens_per_expert=logical_count,
                num_physical_experts=num_physical_experts,
                num_local_physical_experts=num_physical_experts // common["ep_size"],
                num_groups=model_config_for_expert_location.num_groups,
                num_nodes=server_args.nnodes,
                phase=phase,
            )
        )

        return ExpertLocationMetadata._init_raw(
            ep_size=common["ep_size"],
            physical_to_logical_map=physical_to_logical_map,
            logical_to_all_physical_map=logical_to_all_physical_map,
        )

    @staticmethod
    def _init_common(server_args: ServerArgs, model_config: ModelConfig):
        model_config_for_expert_location = (
            ModelConfigForExpertLocation.from_model_config(model_config)
        )

        num_physical_experts = (
            model_config_for_expert_location.num_logical_experts
            + server_args.ep_num_redundant_experts
        )
        ep_size = server_args.ep_size
        assert num_physical_experts % ep_size == 0
        num_local_physical_experts = num_physical_experts // ep_size

        return dict(
            model_config_for_expert_location=model_config_for_expert_location,
            num_physical_experts=num_physical_experts,
            num_local_physical_experts=num_local_physical_experts,
            ep_size=ep_size,
        )

    @staticmethod
    def _init_raw(
        ep_size: int,
        physical_to_logical_map: torch.Tensor,
        logical_to_all_physical_map: torch.Tensor,
    ):
        _, num_physical_experts = physical_to_logical_map.shape

        logical_to_all_physical_map_padded = F.pad(
            logical_to_all_physical_map,
            (0, num_physical_experts - logical_to_all_physical_map.shape[-1]),
            value=-1,
        )

        logical_to_all_physical_map_num_valid = torch.count_nonzero(
            logical_to_all_physical_map != -1, dim=-1
        )

        return ExpertLocationMetadata(
            physical_to_logical_map=physical_to_logical_map,
            logical_to_all_physical_map=logical_to_all_physical_map_padded,
            logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,
            logical_to_rank_dispatch_physical_map=compute_logical_to_rank_dispatch_physical_map(
                logical_to_all_physical_map=logical_to_all_physical_map,
                logical_to_all_physical_map_num_valid=logical_to_all_physical_map_num_valid,
                num_gpus=ep_size,
                num_physical_experts=num_physical_experts,
                ep_rank=torch.distributed.get_rank(),
            ),
        )

    # -------------------------------- usage ------------------------------------

    def logical_to_all_physical(
        self, layer_id: int, logical_expert_id: int
    ) -> List[int]:
        return [
            physical_expert_id
            for physical_expert_id in self.logical_to_all_physical_map[
                layer_id, logical_expert_id
            ].tolist()
            if physical_expert_id != -1
        ]


_global_expert_location_metadata: Optional[ExpertLocationMetadata] = None


def get_global_expert_location_metadata():
    return _global_expert_location_metadata


def set_global_expert_location_metadata(value):
    global _global_expert_location_metadata
    assert _global_expert_location_metadata is None
    _global_expert_location_metadata = value
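

# Illustrative example (comment only, not executed) of the derivation performed by
# _compute_logical_to_all_physical_map below: for a single layer with
#     physical_to_logical_map = [[0, 1, 1, 2]]
# (physical slots 1 and 2 both serve logical expert 1), the padded result is
#     logical_to_all_physical_map = [[[0, -1], [1, 2], [3, -1]]]
# and _init_raw later derives logical_to_all_physical_map_num_valid = [[1, 2, 1]].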


def _compute_logical_to_all_physical_map(
    physical_to_logical_map: torch.Tensor, num_logical_experts: int
):
    # This is rarely called, so we use for loops for maximum clarity

    num_layers, num_physical_experts = physical_to_logical_map.shape

    logical_to_all_physical_map = [
        [[] for _ in range(num_logical_experts)] for _ in range(num_layers)
    ]
    for layer_id in range(num_layers):
        for physical_expert_id in range(num_physical_experts):
            logical_expert_id = physical_to_logical_map[
                layer_id, physical_expert_id
            ].item()
            logical_to_all_physical_map[layer_id][logical_expert_id].append(
                physical_expert_id
            )

    logical_to_all_physical_map = _pad_nested_array(
        logical_to_all_physical_map, pad_value=-1
    )

    return torch.tensor(
        logical_to_all_physical_map, device=physical_to_logical_map.device
    )


def _pad_nested_array(arr, pad_value):
    max_len = max(len(inner) for outer in arr for inner in outer)
    padded = [
        [inner + [pad_value] * (max_len - len(inner)) for inner in outer]
        for outer in arr
    ]
    return padded


# TODO use more sophisticated approaches
def compute_logical_to_rank_dispatch_physical_map(
    logical_to_all_physical_map: torch.Tensor,
    logical_to_all_physical_map_num_valid: torch.Tensor,
    num_gpus: int,
    num_physical_experts: int,
    ep_rank: int,
    base_seed: int = 42,
):
    # For each (layer, logical expert), pick one physical replica for this rank:
    # start from a seeded random choice among the valid replicas, then prefer a
    # replica that lives on this rank's own GPU whenever one exists.
    device = logical_to_all_physical_map.device
    num_local_physical_experts = num_physical_experts // num_gpus
    num_layers, num_logical_experts, _ = logical_to_all_physical_map.shape

    g = torch.Generator(device=device)
    g.manual_seed(base_seed + ep_rank)

    output_shape = (num_layers, num_logical_experts)
    chosen_index = (
        torch.randint(
            0, 65536, output_shape, dtype=torch.int32, device=device, generator=g
        )
        % logical_to_all_physical_map_num_valid
    )
    logical_to_rank_dispatch_physical_map = torch.gather(
        logical_to_all_physical_map, dim=2, index=chosen_index.unsqueeze(-1)
    ).squeeze(-1)
    assert logical_to_rank_dispatch_physical_map.shape == output_shape

    for index in range(logical_to_all_physical_map_num_valid.max().item()):
        partial_logical_to_all_physical_map = logical_to_all_physical_map[:, :, index]
        is_valid = partial_logical_to_all_physical_map != -1
        is_same_gpu = (
            partial_logical_to_all_physical_map // num_local_physical_experts
        ) == ep_rank
        logical_to_rank_dispatch_physical_map = torch.where(
            is_valid & is_same_gpu,
            partial_logical_to_all_physical_map,
            logical_to_rank_dispatch_physical_map,
        )

    assert torch.all(logical_to_rank_dispatch_physical_map != -1)
    return logical_to_rank_dispatch_physical_map


@dataclass
class ModelConfigForExpertLocation:
    num_layers: int
    num_logical_experts: int
    num_groups: Optional[int] = None

    @staticmethod
    def init_dummy():
        return ModelConfigForExpertLocation(num_layers=1, num_logical_experts=1)

    @staticmethod
    def from_model_config(model_config: ModelConfig):
        model_class, _ = get_model_architecture(model_config)
        if hasattr(model_class, "get_model_config_for_expert_location"):
            return model_class.get_model_config_for_expert_location(
                model_config.hf_config
            )
        else:
            return ModelConfigForExpertLocation.init_dummy()
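

# Illustrative examples (comment only) of ServerArgs.init_expert_location values
# accepted by compute_initial_expert_location_metadata below:
#     "trivial"                                      -> ExpertLocationMetadata.init_trivial
#     a path ending in ".pt" or ".json"              -> loaded from file, dispatched by key
#     '{"physical_to_logical_map": [[0, 1, 1, 2]]}'  -> ExpertLocationMetadata.init_by_mapping
#     '{"logical_count": ...}'                       -> ExpertLocationMetadata.init_by_eplb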


def compute_initial_expert_location_metadata(
    server_args: ServerArgs, model_config: ModelConfig
) -> ExpertLocationMetadata:
    data = server_args.init_expert_location
    if data == "trivial":
        logger.info("init_expert_location from trivial")
        return ExpertLocationMetadata.init_trivial(server_args, model_config)

    # TODO unify with the utils function
    if data.endswith(".pt"):
        data_dict = torch.load(data, weights_only=True)
    elif data.endswith(".json"):
        data_dict = json.loads(Path(data).read_text())
    else:
        data_dict = json.loads(data)

    if "physical_to_logical_map" in data_dict:
        logger.info(
            "init_expert_location from init_by_mapping using ServerArgs.init_expert_location"
        )
        return ExpertLocationMetadata.init_by_mapping(
            server_args, model_config, **data_dict
        )
    elif "logical_count" in data_dict:
        logger.info(
            "init_expert_location from init_by_eplb using ServerArgs.init_expert_location"
        )
        return ExpertLocationMetadata.init_by_eplb(
            server_args, model_config, logical_count=data_dict["logical_count"]
        )
    else:
        raise NotImplementedError(
            f"Unknown init_expert_location format ({list(data_dict.keys())=})"
        )
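

# Minimal smoke-test sketch (illustrative only, not part of the SGLang API): it
# exercises the pure helpers above on a toy 1-layer mapping with 4 physical slots,
# 3 logical experts and 2 GPUs; the values in the comments assume the default
# base_seed and ep_rank=0.
if __name__ == "__main__":
    toy_physical_to_logical = torch.tensor([[0, 1, 1, 2]])
    toy_logical_to_all_physical = _compute_logical_to_all_physical_map(
        toy_physical_to_logical, num_logical_experts=3
    )
    print(toy_logical_to_all_physical)  # tensor([[[0, -1], [1, 2], [3, -1]]])

    toy_num_valid = torch.count_nonzero(toy_logical_to_all_physical != -1, dim=-1)
    toy_dispatch = compute_logical_to_rank_dispatch_physical_map(
        logical_to_all_physical_map=toy_logical_to_all_physical,
        logical_to_all_physical_map_num_valid=toy_num_valid,
        num_gpus=2,
        num_physical_experts=4,
        ep_rank=0,
    )
    print(toy_dispatch)  # tensor([[0, 1, 3]]): rank 0 prefers its local replicas 0 and 1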