Unverified commit 904655c5, authored by Hank Han, committed by GitHub

[2/N] Added the core structure of elastic EP and the EPLB algorithm with faulty-rank handling (#10606)


Co-authored-by: Xun Sun <UNIDY2002@outlook.com>
Co-authored-by: Shangming Cai <csmthu@gmail.com>
parent e028af69
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional

import torch

from sglang.srt.managers.schedule_batch import ServerArgs
from sglang.srt.utils import is_cpu, is_cuda


@dataclass
class ElasticEPState:
    active_ranks: Optional[torch.Tensor]
    last_active_ranks: Optional[torch.Tensor]
    active_ranks_cpu: Optional[torch.Tensor]

    def is_active_equal_last(self) -> bool:
        return torch.equal(self.active_ranks, self.last_active_ranks)

    def sync_active_to_cpu(self):
        if self.active_ranks is not None:
            self.active_ranks_cpu = self.active_ranks.detach().cpu().clone()

    def snapshot_active_to_last(self):
        if self.active_ranks is not None:
            self.last_active_ranks = self.active_ranks.clone()


class ElasticEPStateManager:
    _instance: Optional[ElasticEPState] = None

    @classmethod
    def instance(cls) -> Optional[ElasticEPState]:
        return cls._instance

    @classmethod
    def init(cls, server_args: ServerArgs):
        if cls._instance is not None:
            return cls._instance
        if server_args.elastic_ep_backend is not None:
            cls._instance = cls._build_state(ep_size=None, device=None)
        return cls._instance

    @staticmethod
    def _select_device() -> torch.device:
        if is_cuda():
            return torch.device("cuda")
        elif is_cpu():
            return torch.device("cpu")
        else:
            raise NotImplementedError("Only CUDA and CPU support elastic EP for now.")

    @classmethod
    def _build_state(
        cls, *, ep_size: Optional[int] = None, device: Optional[torch.device] = None
    ) -> ElasticEPState:
        active = cls.healthy_rank_state(ep_size=ep_size, device=device)
        return ElasticEPState(
            active_ranks=active,
            last_active_ranks=active.clone(),
            active_ranks_cpu=active.detach().cpu().clone(),
        )

    @classmethod
    def healthy_rank_state(
        cls, *, ep_size: Optional[int] = None, device: Optional[torch.device] = None
    ) -> torch.Tensor:
        size = ep_size if ep_size is not None else torch.distributed.get_world_size()
        dev = device if device is not None else cls._select_device()
        return torch.ones(size, dtype=torch.int32, device=dev)
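Editor's note: as a minimal usage sketch (not part of the diff), this is how the state manager is expected to be driven; it assumes torch.distributed is already initialized and that `server_args.elastic_ep_backend` is set, and `faulty_rank` is a placeholder index.

# Sketch only: `server_args` and `faulty_rank` are placeholders for illustration.
state = ElasticEPStateManager.init(server_args)   # builds an all-ones active_ranks tensor
if state is not None:
    state.active_ranks[faulty_rank] = 0           # hypothetical: mark one rank as faulty
    state.sync_active_to_cpu()                    # keep a CPU copy for host-side checks
    if not state.is_active_equal_last():          # membership changed since last snapshot
        state.snapshot_active_to_last()           # remember the new membership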
......@@ -3,7 +3,8 @@ from typing import Optional
import torch
from sglang.srt.eplb.eplb_algorithms import deepseek, deepseek_vec
from sglang.srt.elastic_ep.elastic_ep import ElasticEPStateManager
from sglang.srt.eplb.eplb_algorithms import deepseek, deepseek_vec, elasticity_aware
class EplbAlgorithm(Enum):
......@@ -11,6 +12,7 @@ class EplbAlgorithm(Enum):
    deepseek_hierarchical = auto()
    deepseek_vec = auto()
    deepseek_vec_hierarchical = auto()
    elasticity_aware = auto()

    # TODO: may have more algorithms later
......@@ -45,6 +47,21 @@ def rebalance_experts(
            enable_hierarchical=algorithm == EplbAlgorithm.deepseek_vec_hierarchical,
        )

    if algorithm == EplbAlgorithm.elasticity_aware:
        return elasticity_aware.rebalance_experts(
            weight=tokens_per_expert.sum(dim=0),
            num_replicas=num_physical_experts,
            num_groups=num_groups,
            num_nodes=num_nodes,
            num_gpus=num_physical_experts // num_local_physical_experts,
            enable_hierarchical=True,
            active_ranks=(
                ElasticEPStateManager.instance().active_ranks
                if ElasticEPStateManager.instance() is not None
                else ElasticEPStateManager.healthy_rank_state()
            ),
        )

    raise NotImplementedError
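Editor's note on the `active_ranks` argument above: when the elastic EP state manager has not been initialized, the algorithm falls back to an all-healthy vector. A small sketch (the world size of 4 and the CPU device are assumed values, not from the diff):

# Sketch: how active_ranks is resolved for the elasticity_aware algorithm.
state = ElasticEPStateManager.instance()
if state is not None:
    active_ranks = state.active_ranks  # live elastic-EP view; zeros mark faulty ranks
else:
    # No elastic state yet: assume every rank is healthy.
    active_ranks = ElasticEPStateManager.healthy_rank_state(
        ep_size=4, device=torch.device("cpu")  # assumed values for illustration
    )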
......
from typing import Tuple

import torch

from sglang.srt.eplb.eplb_algorithms.deepseek import rebalance_experts_hierarchical


def rebalance_experts(
    weight: torch.Tensor,
    num_replicas: int,
    num_groups: int,
    num_nodes: int,
    num_gpus: int,
    enable_hierarchical: bool,
    active_ranks: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Entry point for the elasticity-aware expert-parallelism load balancer.

    Parameters:
        weight: [layers, num_logical_experts], the load statistics for all logical experts
        num_replicas: number of physical experts, must be a multiple of `num_gpus`
        num_groups: number of expert groups
        num_nodes: number of server nodes, where the intra-node network (e.g., NVLink) is faster
        num_gpus: number of GPUs, must be a multiple of `num_nodes`
        enable_hierarchical: whether to use the hierarchical load-balance policy when all ranks are active
        active_ranks: [num_gpus], 1 for active (healthy) ranks and 0 for faulty ranks

    Returns:
        physical_to_logical_map: [layers, num_replicas], the expert index of each replica
        logical_to_physical_map: [layers, num_logical_experts, X], the replica indices for each expert
        expert_count: [layers, num_logical_experts], number of physical replicas for each logical expert
    """
    num_layers, num_logical_experts = weight.shape
    weight = weight.float().cpu()

    num_active_ranks = active_ranks.sum().item()
    num_local_experts = num_replicas // num_gpus

    if num_active_ranks < num_gpus:
        # Some ranks are faulty: fall back to the global load-balance policy,
        # balancing only across the active ranks.
        phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
            weight,
            num_local_experts * num_active_ranks,
            1,
            1,
            num_active_ranks,
        )
    elif enable_hierarchical:
        # Use the hierarchical load-balance policy.
        phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
            weight, num_replicas, num_groups, num_nodes, num_gpus
        )
    else:
        # Use the global load-balance policy.
        phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
            weight, num_replicas, 1, 1, num_gpus
        )

    maxlogcnt = logcnt.max().item()
    log2phy: torch.Tensor = torch.full(
        (num_layers, num_logical_experts, maxlogcnt),
        -1,
        dtype=torch.int64,
        device=logcnt.device,
    )
    log2phy.view(num_layers, -1).scatter_(
        -1,
        phy2log * maxlogcnt + phyrank,
        torch.arange(
            num_local_experts * num_active_ranks,
            dtype=torch.int64,
            device=log2phy.device,
        ).expand(num_layers, -1),
    )

    if num_active_ranks < num_gpus:
        # Expand the compact mapping back to the full `num_gpus` layout: insert an
        # all-zero slice for each faulty rank and shift the physical indices in
        # `log2phy` accordingly.
        phy2log_slices = list(
            phy2log.view(num_layers, num_active_ranks, -1).unbind(dim=1)
        )
        active_ranks_list = active_ranks.tolist()
        for idx, active_rank in enumerate(active_ranks_list):
            if not active_rank:
                phy2log_slices.insert(idx, torch.zeros_like(phy2log_slices[0]))
                log2phy = torch.where(
                    log2phy >= idx * num_local_experts,
                    log2phy + num_local_experts,
                    log2phy,
                )
        phy2log = torch.stack(phy2log_slices, dim=1).contiguous().view(num_layers, -1)

    return phy2log, log2phy, logcnt
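Editor's note: to make the shapes concrete, here is a small worked example (a sketch: the numbers are invented, and it assumes sglang and the deepseek EPLB helpers are importable).

# Sketch: 2 MoE layers, 8 logical experts, 16 physical slots on 4 GPUs (4 per GPU),
# with rank 2 marked faulty. Numbers are illustrative only.
import torch
from sglang.srt.eplb.eplb_algorithms.elasticity_aware import rebalance_experts

weight = torch.rand(2, 8)                                    # [layers, num_logical_experts]
active_ranks = torch.tensor([1, 1, 0, 1], dtype=torch.int32)

phy2log, log2phy, logcnt = rebalance_experts(
    weight=weight,
    num_replicas=16,
    num_groups=1,
    num_nodes=1,
    num_gpus=4,
    enable_hierarchical=True,  # ignored here: 3 active ranks < 4 GPUs forces the fallback
    active_ranks=active_ranks,
)

print(phy2log.shape)  # torch.Size([2, 16]); slots of the faulty rank are filled with expert 0
print(log2phy.shape)  # torch.Size([2, 8, maxlogcnt]) where maxlogcnt = logcnt.max()
print(logcnt.shape)   # torch.Size([2, 8])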
......@@ -4,6 +4,7 @@ import logging
from dataclasses import dataclass
from typing import NamedTuple, Optional, Tuple
from sglang.srt.elastic_ep.elastic_ep import ElasticEPStateManager
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.layers.dp_attention import get_is_extend_in_batch
from sglang.srt.layers.moe.token_dispatcher.base import (
......@@ -63,14 +64,6 @@ class MooncakeCombineInput(NamedTuple):
assert isinstance(MooncakeCombineInput, CombineInput)
_ACTIVE_RANKS: Optional[torch.Tensor] = None


def get_ep_active_ranks() -> torch.Tensor:
    assert _ACTIVE_RANKS is not None, "_ACTIVE_RANKS is not initialized"
    return _ACTIVE_RANKS


class EPBuffer:
    _buffer = None
    _hidden_size: Optional[int] = None
......@@ -153,12 +146,7 @@ class _MooncakeEPDispatcherImpl:
        self.first_execution = True
        self.timeout_us = 10000000

        global _ACTIVE_RANKS
        if _ACTIVE_RANKS is None:
            _ACTIVE_RANKS = torch.ones(
                (self.num_experts,), dtype=torch.int32, device="cuda"
            )
        self.active_ranks = _ACTIVE_RANKS
        self.active_ranks = ElasticEPStateManager.instance().active_ranks

        self.handle = None
......
......@@ -24,7 +24,7 @@ import threading
import time
from collections import defaultdict
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.distributed as dist
......@@ -51,6 +51,7 @@ from sglang.srt.distributed import (
set_symm_mem_all_reduce,
)
from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
from sglang.srt.elastic_ep.elastic_ep import ElasticEPStateManager
from sglang.srt.eplb.eplb_manager import EPLBManager
from sglang.srt.eplb.expert_distribution import (
ExpertDistributionRecorder,
......@@ -379,6 +380,11 @@ class ModelRunner:
        )
        self.expert_location_updater = ExpertLocationUpdater()

        if self.server_args.elastic_ep_backend:
            ElasticEPStateManager.init(self.server_args)

        # Load the model
        self.sampler = Sampler()
        self.load_model()
......@@ -956,16 +962,33 @@ class ModelRunner:
        new_expert_location_metadata: ExpertLocationMetadata,
        update_layer_ids: List[int],
    ):
        self.expert_location_updater.update(
            self.model.routed_experts_weights_of_layer,
            new_expert_location_metadata,
            update_layer_ids=update_layer_ids,
            nnodes=self.server_args.nnodes,
            rank=self.tp_rank,
        )
        if ElasticEPStateManager.instance() is not None:
            # TODO: refactor the weights update when elastic EP is enabled.
            old_expert_location_metadata = get_global_expert_location_metadata()
            assert old_expert_location_metadata is not None
            old_expert_location_metadata.update(
                new_expert_location_metadata,
                update_layer_ids=update_layer_ids,
            )
            self.update_weights_from_disk(
                self.server_args.model_path,
                self.server_args.load_format,
                lambda name: "mlp.experts" in name and "mlp.shared_experts" not in name,
            )
        else:
            self.expert_location_updater.update(
                self.model.routed_experts_weights_of_layer,
                new_expert_location_metadata,
                update_layer_ids=update_layer_ids,
                nnodes=self.server_args.nnodes,
                rank=self.tp_rank,
            )
def update_weights_from_disk(
self, model_path: str, load_format: str
self,
model_path: str,
load_format: str,
weight_name_filter: Optional[Callable[[str], bool]] = None,
) -> tuple[bool, str]:
"""Update engine weights in-place from the disk."""
logger.info(
......@@ -987,6 +1010,11 @@ class ModelRunner:
            iter = loader._get_weights_iterator(
                DefaultModelLoader.Source.init_new(config, self.model)
            )
            if weight_name_filter is not None:
                iter = (
                    (name, weight) for name, weight in iter if weight_name_filter(name)
                )
            return iter

        def model_load_weights(model, iter):
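Editor's note: the new `weight_name_filter` parameter lets the elastic-EP path reload only the routed-expert weights. A hedged usage sketch (`model_runner`, the model path, and the load format are placeholders, not values from the diff):

# Sketch only: reload just the routed-expert weights after a rebalance.
ok, message = model_runner.update_weights_from_disk(
    model_path="/models/placeholder-moe",
    load_format="auto",
    weight_name_filter=lambda name: "mlp.experts" in name
    and "mlp.shared_experts" not in name,
)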
......
......@@ -600,6 +600,9 @@ class ServerArgs:
        # Handle any other necessary validations.
        self._handle_other_validations()

        # Handle elastic expert parallelism.
        self._handle_elastic_ep()

    def _handle_deprecated_args(self):
        # handle deprecated tool call parsers
        deprecated_tool_call_parsers = {"qwen25": "qwen", "glm45": "glm"}
......@@ -1225,6 +1228,15 @@ class ServerArgs:
        if self.enable_eplb:
            assert self.ep_size > 1

    def _handle_elastic_ep(self):
        if self.elastic_ep_backend is not None:
            if self.enable_eplb:
                if self.eplb_algorithm == "auto":
                    self.eplb_algorithm = "elasticity_aware"
                assert (
                    self.eplb_algorithm == "elasticity_aware"
                ), "Elastic EP requires eplb_algorithm to be set to 'auto' or 'elasticity_aware'."

    def _handle_expert_distribution_metrics(self):
        if self.enable_expert_distribution_metrics and (
            self.expert_distribution_recorder_mode is None
......
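Editor's note: as a quick illustration of the validation added above, enabling both an elastic EP backend and EPLB with the default algorithm resolves to elasticity_aware. This standalone sketch mirrors the rule; it does not construct a real ServerArgs.

# Sketch: mirrors ServerArgs._handle_elastic_ep for illustration only.
def resolve_eplb_algorithm(elastic_ep_backend, enable_eplb, eplb_algorithm):
    if elastic_ep_backend is not None and enable_eplb:
        if eplb_algorithm == "auto":
            eplb_algorithm = "elasticity_aware"
        assert (
            eplb_algorithm == "elasticity_aware"
        ), "Elastic EP requires eplb_algorithm to be 'auto' or 'elasticity_aware'."
    return eplb_algorithm


assert resolve_eplb_algorithm("mooncake", True, "auto") == "elasticity_aware"
assert resolve_eplb_algorithm(None, True, "deepseek") == "deepseek"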
......@@ -3,6 +3,7 @@ from types import SimpleNamespace
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_disaggregation_utils import get_rdma_devices_args
from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST_MLA,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
......@@ -11,166 +12,12 @@ from sglang.test.test_utils import (
popen_launch_server,
)
class TestPureDP(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--trust-remote-code",
"--tp",
"4",
"--enable-dp-attention",
"--dp",
"4",
"--elastic-ep-backend",
"mooncake",
"--mooncake-ib-device",
"mlx5_roce0,mlx5_roce1,mlx5_roce2,mlx5_roce3,mlx5_roce4,mlx5_roce5,mlx5_roce6,mlx5_roce7",
"--moe-a2a-backend",
"deepep",
"--deepep-mode",
"low_latency",
"--chunked-prefill-size",
"512",
"--cuda-graph-max-bs",
"128",
"--max-running-requests",
"512",
"--mem-fraction-static",
"0.5",
],
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def test_gsm8k(self):
args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=200,
max_new_tokens=512,
parallel=128,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)
metrics = run_eval_few_shot_gsm8k(args)
print(metrics)
self.assertGreater(metrics["accuracy"], 0.60)
class TestHybridDPTP(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--trust-remote-code",
"--tp",
"4",
"--enable-dp-attention",
"--dp",
"2",
"--elastic-ep-backend",
"mooncake",
"--mooncake-ib-device",
"mlx5_roce0,mlx5_roce1,mlx5_roce2,mlx5_roce3,mlx5_roce4,mlx5_roce5,mlx5_roce6,mlx5_roce7",
"--moe-a2a-backend",
"deepep",
"--deepep-mode",
"low_latency",
"--chunked-prefill-size",
"512",
"--cuda-graph-max-bs",
"128",
"--max-running-requests",
"256",
],
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def test_gsm8k(self):
args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=200,
max_new_tokens=512,
parallel=128,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)
metrics = run_eval_few_shot_gsm8k(args)
print(metrics)
self.assertGreater(metrics["accuracy"], 0.60)
ib_devices = get_rdma_devices_args()
class TestTP(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--trust-remote-code",
"--tp",
"4",
"--elastic-ep-backend",
"mooncake",
"--mooncake-ib-device",
"mlx5_roce0,mlx5_roce1,mlx5_roce2,mlx5_roce3,mlx5_roce4,mlx5_roce5,mlx5_roce6,mlx5_roce7",
"--moe-a2a-backend",
"deepep",
"--deepep-mode",
"low_latency",
"--chunked-prefill-size",
"512",
"--cuda-graph-max-bs",
"128",
"--max-running-requests",
"128",
],
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def test_gsm8k(self):
args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=200,
max_new_tokens=512,
parallel=128,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)
metrics = run_eval_few_shot_gsm8k(args)
print(metrics)
extra_args = []
self.assertGreater(metrics["accuracy"], 0.60)
class TestNoGatherdBuffer(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
......@@ -183,16 +30,10 @@ class TestNoGatherdBuffer(CustomTestCase):
"--trust-remote-code",
"--tp",
"4",
"--enable-dp-attention",
"--dp",
"4",
"--moe-dense-tp-size",
"1",
"--enable-dp-lm-head",
"--elastic-ep-backend",
"mooncake",
"--mooncake-ib-device",
"mlx5_roce0,mlx5_roce1,mlx5_roce2,mlx5_roce3,mlx5_roce4,mlx5_roce5,mlx5_roce6,mlx5_roce7",
ib_devices,
"--moe-a2a-backend",
"deepep",
"--deepep-mode",
......@@ -200,9 +41,12 @@ class TestNoGatherdBuffer(CustomTestCase):
"--chunked-prefill-size",
"512",
"--cuda-graph-max-bs",
"32",
"128",
"--max-running-requests",
"512",
"--mem-fraction-static",
"0.5",
*cls.extra_args,
],
)
......@@ -226,60 +70,73 @@ class TestNoGatherdBuffer(CustomTestCase):
self.assertGreater(metrics["accuracy"], 0.60)
class TestTBO(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--trust-remote-code",
"--tp",
"4",
"--enable-dp-attention",
"--dp",
"4",
"--moe-dense-tp-size",
"1",
"--elastic-ep-backend",
"mooncake",
"--mooncake-ib-device",
"mlx5_roce0,mlx5_roce1,mlx5_roce2,mlx5_roce3,mlx5_roce4,mlx5_roce5,mlx5_roce6,mlx5_roce7",
"--moe-a2a-backend",
"deepep",
"--deepep-mode",
"low_latency",
"--chunked-prefill-size",
"512",
"--enable-two-batch-overlap",
"--cuda-graph-max-bs",
"128",
"--max-running-requests",
"512",
],
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def test_gsm8k(self):
args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=200,
max_new_tokens=512,
parallel=128,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)
metrics = run_eval_few_shot_gsm8k(args)
print(metrics)
self.assertGreater(metrics["accuracy"], 0.60)
class TestPureDP(TestTP):
    extra_args = [
        "--tp",
        "4",
        "--enable-dp-attention",
        "--dp",
        "4",
    ]


class TestHybridDPTP(TestTP):
    extra_args = [
        "--tp",
        "4",
        "--enable-dp-attention",
        "--dp",
        "2",
    ]


class TestNoGatherdBuffer(TestTP):
    extra_args = [
        "--tp",
        "4",
        "--enable-dp-attention",
        "--dp",
        "4",
        "--moe-dense-tp-size",
        "1",
    ]


class TestTBO(TestTP):
    extra_args = [
        "--tp",
        "4",
        "--enable-dp-attention",
        "--dp",
        "4",
        "--moe-dense-tp-size",
        "1",
        "--enable-two-batch-overlap",
    ]


class TestMooncakeWithEPLB(TestTP):
    extra_args = [
        "--tp",
        "4",
        "--enable-dp-attention",
        "--dp",
        "4",
        "--moe-dense-tp-size",
        "1",
        "--enable-two-batch-overlap",
        "--enable-eplb",
        "--ep-num-redundant-experts",
        "4",
        "--eplb-rebalance-num-iterations",
        "50",
        "--expert-distribution-recorder-buffer-size",
        "50",
        "--expert-distribution-recorder-mode",
        "stat",
        "--ep-dispatch-algorithm",
        "static",
    ]
if __name__ == "__main__":
......