Unverified Commit 7a1f7fc5 authored by Cheng Wan, committed by GitHub

[Feature] Hybrid EP and TP (#8590)

parent 51c38163
@@ -138,6 +138,7 @@ class BenchArgs:
 def load_model(server_args, port_args, tp_rank):
     suppress_other_loggers()
     rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
+    moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size)

     model_config = ModelConfig.from_server_args(server_args)
     model_runner = ModelRunner(
@@ -146,6 +147,8 @@ def load_model(server_args, port_args, tp_rank):
         gpu_id=tp_rank,
         tp_rank=tp_rank,
         tp_size=server_args.tp_size,
+        moe_ep_rank=moe_ep_rank,
+        moe_ep_size=server_args.ep_size,
         pp_rank=0,
         pp_size=1,
         nccl_port=port_args.nccl_port,
......
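
The hunks above thread a new moe_ep_rank into ModelRunner, derived from the existing tp_rank. A quick illustrative sketch of the formula (the sizes below are made up, not from the patch): consecutive blocks of tp_size // ep_size tensor-parallel ranks map to the same expert-parallel rank.

```python
# Illustrative sketch of the moe_ep_rank formula used above (assumed sizes).
tp_size, ep_size = 8, 4           # example values, not from the patch
moe_tp_size = tp_size // ep_size  # 2 TP ranks per expert-parallel group

for tp_rank in range(tp_size):
    moe_ep_rank = tp_rank // moe_tp_size
    print(f"tp_rank={tp_rank} -> moe_ep_rank={moe_ep_rank}")
# tp_rank 0,1 -> EP0 | 2,3 -> EP1 | 4,5 -> EP2 | 6,7 -> EP3
```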
@@ -354,6 +354,13 @@ class GroupCoordinator:
                 self.cpu_group, 1 << 22, 6
             )

+    def __repr__(self):
+        return (
+            f"ranks={self.ranks} rank={self.rank} local_rank={self.local_rank} use_pynccl={self.use_pynccl} "
+            f"device_group={self.device_group} cpu_group={self.cpu_group} unique_name={self.unique_name} "
+            f"world_size={self.world_size} rank_in_group={self.rank_in_group}"
+        )
+
     @property
     def first_rank(self):
         """Return the global rank of the first process in the group"""
@@ -1141,6 +1148,20 @@ def get_tp_group() -> GroupCoordinator:
     return _TP


+_MOE_EP: Optional[GroupCoordinator] = None
+_MOE_TP: Optional[GroupCoordinator] = None
+
+
+def get_moe_ep_group() -> GroupCoordinator:
+    assert _MOE_EP is not None, "expert model parallel group is not initialized"
+    return _MOE_EP
+
+
+def get_moe_tp_group() -> GroupCoordinator:
+    assert _MOE_TP is not None, "expert model parallel group is not initialized"
+    return _MOE_TP
+
+
 # kept for backward compatibility
 get_tensor_model_parallel_group = get_tp_group
@@ -1250,6 +1271,7 @@ def init_distributed_environment(
 def initialize_model_parallel(
     tensor_model_parallel_size: int = 1,
+    expert_model_parallel_size: int = 1,
     pipeline_model_parallel_size: int = 1,
     backend: Optional[str] = None,
     duplicate_tp_group: bool = False,
@@ -1327,6 +1349,45 @@ def initialize_model_parallel(
         _TP.pynccl_comm.disabled = False
         _PDMUX_PREFILL_TP_GROUP.pynccl_comm.disabled = False

+    moe_ep_size = expert_model_parallel_size
+    moe_tp_size = tensor_model_parallel_size // moe_ep_size
+
+    global _MOE_EP
+    assert _MOE_EP is None, "expert model parallel group is already initialized"
+    group_ranks = []
+    for i in range(num_tensor_model_parallel_groups):
+        for j in range(moe_tp_size):
+            st = i * tensor_model_parallel_size + j
+            en = (i + 1) * tensor_model_parallel_size + j
+            ranks = list(range(st, en, moe_tp_size))
+            group_ranks.append(ranks)
+
+    _MOE_EP = init_model_parallel_group(
+        group_ranks,
+        get_world_group().local_rank,
+        backend,
+        use_custom_allreduce=False,
+        group_name="moe_ep",
+    )
+
+    global _MOE_TP
+    assert _MOE_TP is None, "expert model parallel group is already initialized"
+    group_ranks = []
+    for i in range(num_tensor_model_parallel_groups):
+        for j in range(moe_ep_size):
+            st = i * tensor_model_parallel_size + j * moe_tp_size
+            en = i * tensor_model_parallel_size + (j + 1) * moe_tp_size
+            ranks = list(range(st, en))
+            group_ranks.append(ranks)
+
+    _MOE_TP = init_model_parallel_group(
+        group_ranks,
+        get_world_group().local_rank,
+        backend,
+        use_custom_allreduce=False,
+        group_name="moe_tp",
+    )
+
     # Build the pipeline model-parallel groups.
     num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size
     global _PP
@@ -1347,6 +1408,7 @@ def initialize_model_parallel(
 def ensure_model_parallel_initialized(
     tensor_model_parallel_size: int,
+    expert_model_parallel_size: int,
     pipeline_model_parallel_size: int,
     backend: Optional[str] = None,
 ) -> None:
@@ -1357,7 +1419,10 @@ def ensure_model_parallel_initialized(
     backend = backend or torch.distributed.get_backend(get_world_group().device_group)
     if not model_parallel_is_initialized():
         initialize_model_parallel(
-            tensor_model_parallel_size, pipeline_model_parallel_size, backend
+            tensor_model_parallel_size,
+            expert_model_parallel_size,
+            pipeline_model_parallel_size,
+            backend,
         )
         return
@@ -1417,6 +1482,26 @@ def get_tensor_model_parallel_rank():
     return get_tp_group().rank_in_group


+def get_moe_expert_parallel_world_size():
+    """Return world size for the moe expert parallel group."""
+    return get_moe_ep_group().world_size
+
+
+def get_moe_expert_parallel_rank():
+    """Return my rank for the moe expert parallel group."""
+    return get_moe_ep_group().rank_in_group
+
+
+def get_moe_tensor_parallel_world_size():
+    """Return world size for the moe tensor parallel group."""
+    return get_moe_tp_group().world_size
+
+
+def get_moe_tensor_parallel_rank():
+    """Return my rank for the moe tensor parallel group."""
+    return get_moe_tp_group().rank_in_group
+
+
 def destroy_model_parallel():
     """Set the groups to none and destroy them."""
     global _TP
......
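
To see what the two loops above produce, here is a self-contained sketch of the same index arithmetic with small, assumed sizes (tensor_model_parallel_size=4, expert_model_parallel_size=2, a single TP group); it only mirrors the rank math and does not call init_model_parallel_group.

```python
# Sketch of the MoE-EP / MoE-TP rank grouping above, with assumed sizes.
tensor_model_parallel_size, moe_ep_size = 4, 2
moe_tp_size = tensor_model_parallel_size // moe_ep_size  # 2
num_tensor_model_parallel_groups = 1                     # world_size // tp_size

ep_groups = []
for i in range(num_tensor_model_parallel_groups):
    for j in range(moe_tp_size):
        st = i * tensor_model_parallel_size + j
        en = (i + 1) * tensor_model_parallel_size + j
        ep_groups.append(list(range(st, en, moe_tp_size)))

tp_groups = []
for i in range(num_tensor_model_parallel_groups):
    for j in range(moe_ep_size):
        st = i * tensor_model_parallel_size + j * moe_tp_size
        en = i * tensor_model_parallel_size + (j + 1) * moe_tp_size
        tp_groups.append(list(range(st, en)))

print(ep_groups)  # [[0, 2], [1, 3]] -> each MoE-EP group strides by moe_tp_size
print(tp_groups)  # [[0, 1], [2, 3]] -> each MoE-TP group is a contiguous slice
```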
@@ -719,6 +719,7 @@ def _launch_subprocesses(
                 + ((pp_rank % pp_size_per_node) * tp_size_per_node)
                 + (tp_rank % tp_size_per_node) * server_args.gpu_id_step
             )
+            moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size)
             proc = mp.Process(
                 target=run_scheduler_process,
                 args=(
@@ -726,6 +727,7 @@ def _launch_subprocesses(
                     port_args,
                     gpu_id,
                     tp_rank,
+                    moe_ep_rank,
                     pp_rank,
                     None,
                     writer,
......
@@ -135,7 +135,7 @@ class EPMoE(FusedMoE):
             enable_ep_moe=True,
         )

-        self.start_expert_id = self.ep_rank * self.num_local_experts
+        self.start_expert_id = self.moe_ep_rank * self.num_local_experts
         self.end_expert_id = self.start_expert_id + self.num_local_experts - 1

         self.intermediate_size = intermediate_size
......
@@ -7,6 +7,10 @@ from typing import List, Optional, Tuple
 import torch

 from sglang.srt.distributed import (
+    get_moe_expert_parallel_rank,
+    get_moe_expert_parallel_world_size,
+    get_moe_tensor_parallel_rank,
+    get_moe_tensor_parallel_world_size,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
@@ -88,10 +92,6 @@ class FusedMoE(torch.nn.Module):
         self.layer_id = layer_id
         self.top_k = top_k
         self.hidden_size = hidden_size
-        self.tp_size = (
-            tp_size if tp_size is not None else get_tensor_model_parallel_world_size()
-        )
-        self.tp_rank = get_tensor_model_parallel_rank()
         self.num_experts = num_experts
         self.num_fused_shared_experts = num_fused_shared_experts
         self.expert_map_cpu = None
@@ -103,30 +103,27 @@ class FusedMoE(torch.nn.Module):
             enable_ep_moe = False

         self.enable_flashinfer_cutlass_moe = enable_flashinfer_cutlass_moe
+        self.moe_ep_size = get_moe_expert_parallel_world_size()
+        self.moe_ep_rank = get_moe_expert_parallel_rank()
+        self.moe_tp_size = get_moe_tensor_parallel_world_size()
+        self.moe_tp_rank = get_moe_tensor_parallel_rank()
+        assert num_experts % self.moe_ep_size == 0
+        self.num_local_experts = num_experts // self.moe_ep_size
         if enable_ep_moe:
             # TODO(ch-wan): support shared experts fusion
-            self.ep_size = self.tp_size
-            self.ep_rank = self.tp_rank
-            self.tp_size = 1
-            self.tp_rank = 0
             # Create a tensor of size num_experts filled with -1
             self.expert_map_cpu = torch.full((self.num_experts,), -1, dtype=torch.int32)
             # Create a expert map for the local experts
-            assert num_experts % self.ep_size == 0
-            self.num_local_experts = num_experts // self.ep_size
             self.expert_map_cpu[
-                self.ep_rank
-                * self.num_local_experts : (self.ep_rank + 1)
+                self.moe_ep_rank
+                * self.num_local_experts : (self.moe_ep_rank + 1)
                 * self.num_local_experts
             ] = torch.arange(0, self.num_local_experts, dtype=torch.int32, device="cpu")
             self.expert_map_gpu = self.expert_map_cpu.to(device="cuda")
-        else:
-            self.ep_size = 1
-            self.ep_rank = 0
-            self.num_local_experts = num_experts

         self.routed_scaling_factor = routed_scaling_factor
-        assert intermediate_size % self.tp_size == 0
-        self.intermediate_size_per_partition = intermediate_size // self.tp_size
+        assert intermediate_size % self.moe_tp_size == 0
+        self.intermediate_size_per_partition = intermediate_size // self.moe_tp_size
         self.reduce_results = reduce_results
         self.activation = activation
         self.apply_router_weight_on_input = apply_router_weight_on_input
@@ -437,8 +434,7 @@ class FusedMoE(torch.nn.Module):
         expert_id: int,
     ) -> None:
-        # TP rank is set to 0 if EP is enabled
-        tp_rank = 0 if self.ep_size > 1 else get_tensor_model_parallel_rank()
+        tp_rank = self.moe_tp_rank

         # compressed-tensors checkpoints with packed weights are stored flipped
         # TODO (mgoin): check self.quant_method.quant_config.quant_format
@@ -630,17 +626,17 @@ class FusedMoE(torch.nn.Module):
             routed_scaling_factor=self.routed_scaling_factor,
             **(
                 dict(
-                    tp_rank=self.tp_rank,
-                    tp_size=self.tp_size,
-                    ep_rank=self.ep_rank,
-                    ep_size=self.ep_size,
+                    tp_rank=self.moe_tp_rank,
+                    tp_size=self.moe_tp_size,
+                    ep_rank=self.moe_ep_rank,
+                    ep_size=self.moe_ep_size,
                 )
                 if self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod"
                 else {}
             ),
         )

-        if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
+        if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

         return final_hidden_states
......
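
With the hybrid layout, FusedMoE always derives moe_ep_* and moe_tp_* from the new process groups: experts are sharded across moe_ep_size ranks, and each expert's intermediate dimension is split across moe_tp_size ranks. A minimal sketch of the expert map built above, using assumed sizes (8 routed experts, 2 EP ranks); it mirrors the indexing only, not the full layer.

```python
import torch

# Assumed sizes for illustration: 8 experts split across 2 expert-parallel ranks.
num_experts, moe_ep_size, moe_ep_rank = 8, 2, 1
num_local_experts = num_experts // moe_ep_size  # 4

# Global expert id -> local expert id on this rank, -1 for experts held elsewhere.
expert_map_cpu = torch.full((num_experts,), -1, dtype=torch.int32)
expert_map_cpu[
    moe_ep_rank * num_local_experts : (moe_ep_rank + 1) * num_local_experts
] = torch.arange(0, num_local_experts, dtype=torch.int32)

print(expert_map_cpu.tolist())  # [-1, -1, -1, -1, 0, 1, 2, 3] on moe_ep_rank=1
```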
@@ -222,6 +222,7 @@ class DataParallelController:
                 + ((pp_rank % pp_size_per_node) * tp_size_per_node)
                 + (tp_rank % tp_size_per_node) * server_args.gpu_id_step
             )
+            moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size)
             proc = mp.Process(
                 target=run_scheduler_process,
                 args=(
@@ -229,6 +230,7 @@ class DataParallelController:
                     rank_port_args,
                     gpu_id,
                     tp_rank,
+                    moe_ep_rank,
                     pp_rank,
                     dp_rank,
                     writer,
......
@@ -200,15 +200,18 @@ class Scheduler(
         port_args: PortArgs,
         gpu_id: int,
         tp_rank: int,
+        moe_ep_rank: int,
         pp_rank: int,
         dp_rank: Optional[int],
     ):
         # Parse args
         self.server_args = server_args
         self.tp_rank = tp_rank
+        self.moe_ep_rank = moe_ep_rank
         self.pp_rank = pp_rank
         self.dp_rank = dp_rank
         self.tp_size = server_args.tp_size
+        self.moe_ep_size = server_args.ep_size
         self.pp_size = server_args.pp_size
         self.dp_size = server_args.dp_size
         self.schedule_policy = server_args.schedule_policy
@@ -310,6 +313,7 @@ class Scheduler(
             server_args=server_args,
             gpu_id=gpu_id,
             tp_rank=tp_rank,
+            moe_ep_rank=moe_ep_rank,
             pp_rank=pp_rank,
             dp_rank=dp_rank,
             nccl_port=port_args.nccl_port,
@@ -322,6 +326,7 @@ class Scheduler(
             self.draft_worker = EAGLEWorker(
                 gpu_id=gpu_id,
                 tp_rank=tp_rank,
+                moe_ep_rank=moe_ep_rank,
                 server_args=server_args,
                 nccl_port=port_args.nccl_port,
                 target_worker=self.tp_worker,
@@ -2358,6 +2363,7 @@ def run_scheduler_process(
     port_args: PortArgs,
     gpu_id: int,
     tp_rank: int,
+    moe_ep_rank: int,
     pp_rank: int,
     dp_rank: Optional[int],
     pipe_writer,
@@ -2368,6 +2374,8 @@ def run_scheduler_process(
         prefix += f" DP{dp_rank}"
     if server_args.tp_size > 1:
         prefix += f" TP{tp_rank}"
+    if server_args.ep_size > 1:
+        prefix += f" EP{moe_ep_rank}"
     if server_args.pp_size > 1:
         prefix += f" PP{pp_rank}"
@@ -2391,7 +2399,9 @@ def run_scheduler_process(
     # Create a scheduler and run the event loop
     try:
-        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, pp_rank, dp_rank)
+        scheduler = Scheduler(
+            server_args, port_args, gpu_id, tp_rank, moe_ep_rank, pp_rank, dp_rank
+        )
         pipe_writer.send(
             {
                 "status": "ready",
......
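
The scheduler process title now carries the EP rank between the TP and PP ranks whenever ep_size > 1. A tiny sketch of the resulting prefix, using made-up ranks and sizes:

```python
# Made-up ranks/sizes, just to show the prefix string run_scheduler_process builds.
dp_rank, tp_rank, moe_ep_rank, pp_rank = None, 5, 1, 0
tp_size, ep_size, pp_size = 8, 2, 2  # moe_ep_rank = 5 // (8 // 2) = 1

prefix = ""
if dp_rank is not None:
    prefix += f" DP{dp_rank}"
if tp_size > 1:
    prefix += f" TP{tp_rank}"
if ep_size > 1:
    prefix += f" EP{moe_ep_rank}"
if pp_size > 1:
    prefix += f" PP{pp_rank}"

print(repr(prefix))  # ' TP5 EP1 PP0'
```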
@@ -56,6 +56,7 @@ class TpModelWorker:
         server_args: ServerArgs,
         gpu_id: int,
         tp_rank: int,
+        moe_ep_rank: int,
         pp_rank: int,
         dp_rank: Optional[int],
         nccl_port: int,
@@ -66,6 +67,7 @@ class TpModelWorker:
         # Parse args
         self.tp_size = server_args.tp_size
         self.tp_rank = tp_rank
+        self.moe_ep_rank = moe_ep_rank
         self.pp_rank = pp_rank

         # Init model and tokenizer
@@ -85,6 +87,8 @@ class TpModelWorker:
             gpu_id=gpu_id,
             tp_rank=tp_rank,
             tp_size=server_args.tp_size,
+            moe_ep_rank=moe_ep_rank,
+            moe_ep_size=server_args.ep_size,
             pp_rank=pp_rank,
             pp_size=server_args.pp_size,
             nccl_port=nccl_port,
......
@@ -58,13 +58,14 @@ class TpModelWorkerClient:
         server_args: ServerArgs,
         gpu_id: int,
         tp_rank: int,
+        moe_ep_rank: int,
         pp_rank: int,
         dp_rank: Optional[int],
         nccl_port: int,
     ):
         # Load the model
         self.worker = TpModelWorker(
-            server_args, gpu_id, tp_rank, pp_rank, dp_rank, nccl_port
+            server_args, gpu_id, tp_rank, moe_ep_rank, pp_rank, dp_rank, nccl_port
         )
         self.max_running_requests = self.worker.max_running_requests
         self.device = self.worker.device
......
@@ -157,6 +157,8 @@ class ModelRunner:
         gpu_id: int,
         tp_rank: int,
         tp_size: int,
+        moe_ep_rank: int,
+        moe_ep_size: int,
         pp_rank: int,
         pp_size: int,
         nccl_port: int,
@@ -175,6 +177,8 @@ class ModelRunner:
         logger.addFilter(RankZeroFilter(tp_rank == 0))
         self.tp_rank = tp_rank
         self.tp_size = tp_size
+        self.moe_ep_rank = moe_ep_rank
+        self.moe_ep_size = moe_ep_size
         self.dp_size = server_args.dp_size
         self.pp_rank = pp_rank
         self.pp_size = pp_size
@@ -549,6 +553,7 @@ class ModelRunner:
         initialize_model_parallel(
             tensor_model_parallel_size=self.tp_size,
             pipeline_model_parallel_size=self.pp_size,
+            expert_model_parallel_size=self.moe_ep_size,
             duplicate_tp_group=self.server_args.enable_pdmux,
         )
         initialize_dp_attention(
......
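
ModelRunner now forwards moe_ep_size into initialize_model_parallel alongside tp_size and pp_size. The grouping implicitly assumes that ep_size divides tp_size, with the MoE tensor-parallel degree as the quotient; a small sanity-check sketch with assumed sizes (not an actual sglang API call):

```python
# Sanity-check sketch of the size relationship implied above (assumed values).
tp_size, moe_ep_size, pp_size = 8, 2, 1

assert tp_size % moe_ep_size == 0, "ep_size is expected to divide tp_size"
moe_tp_size = tp_size // moe_ep_size  # 4

# Each GPU keeps its single TP rank and additionally belongs to exactly one
# MoE-EP group (size moe_ep_size) and one MoE-TP group (size moe_tp_size).
assert moe_ep_size * moe_tp_size == tp_size
print(tp_size * pp_size, moe_ep_size, moe_tp_size)  # 8 2 4
```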
@@ -270,14 +270,6 @@ class ServerArgs:
     sm_group_num: int = 3

     def __post_init__(self):
-        # Expert parallelism
-        # We put it here first due to some internal ckpt conversation issues.
-        if self.enable_ep_moe:
-            self.ep_size = self.tp_size
-            logger.warning(
-                f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
-            )
-
         # Set missing default values
         if self.tokenizer_path is None:
             self.tokenizer_path = self.model_path
@@ -1335,6 +1327,7 @@ class ServerArgs:
         parser.add_argument(
             "--expert-parallel-size",
             "--ep-size",
+            "--ep",
             type=int,
             default=ServerArgs.ep_size,
             help="The expert parallelism size.",
......
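
With the hard-coded `ep_size = tp_size` override removed from `__post_init__`, the expert-parallel degree is now whatever the user passes, and the argument gains a shorter `--ep` alias. A minimal stand-in parser (not the real ServerArgs wiring; the explicit dest is an assumption for the sketch) showing the alias behavior:

```python
import argparse

# Minimal stand-in for the parser fragment above, to show the new "--ep" alias.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--expert-parallel-size",
    "--ep-size",
    "--ep",
    type=int,
    default=1,
    dest="ep_size",  # assumed destination name for this sketch
    help="The expert parallelism size.",
)

print(parser.parse_args(["--ep", "2"]).ep_size)       # 2
print(parser.parse_args(["--ep-size", "4"]).ep_size)  # 4
print(parser.parse_args([]).ep_size)                  # 1 (default)
```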
@@ -73,6 +73,7 @@ class EAGLEWorker(TpModelWorker):
         gpu_id: int,
         tp_rank: int,
         dp_rank: Optional[int],
+        moe_ep_rank: int,
         nccl_port: int,
         target_worker: TpModelWorker,
     ):
@@ -127,6 +128,7 @@ class EAGLEWorker(TpModelWorker):
            tp_rank=tp_rank,
            pp_rank=0,  # FIXME
            dp_rank=dp_rank,
+           moe_ep_rank=moe_ep_rank,
            nccl_port=nccl_port,
            is_draft_worker=True,
            req_to_token_pool=self.req_to_token_pool,
......