Unverified Commit 3c08f173 authored by Wenhao Chen, committed by GitHub

[hotfix]: modify create_ep_hierarchical_group and add test (#5032)

* feat: modify create_ep_hierarchical_group args

* test: add ep tests

* fix: remove get_process_group_ranks

* fix: fix src_rank
parent 97cd0cd5
@@ -150,7 +150,8 @@ class HierarchicalAllToAll(torch.autograd.Function):
     def forward(
         ctx: Any,
         inputs: Tensor,
-        groups: Tuple[ProcessGroup],
+        groups: Tuple[ProcessGroup, ProcessGroup],
+        src_rank: int
     ) -> Tensor:
         """
         Returns:
@@ -159,12 +160,12 @@ class HierarchicalAllToAll(torch.autograd.Function):
         # TODO: we can reduce comm volume by removing empty capacity
         if ctx is not None:
             ctx.comm_grps = groups
+            ctx.src_rank = src_rank

         intra_node_group, inter_node_group = groups
         local_world_size = dist.get_world_size(intra_node_group)
         num_group = dist.get_world_size(inter_node_group) if inter_node_group is not None else 1
         world_size = local_world_size * num_group
-        src_rank = dist.get_process_group_ranks(intra_node_group)[0]

         outputs = torch.empty_like(inputs)
         if dist.get_rank() == src_rank:
@@ -196,9 +197,10 @@ class HierarchicalAllToAll(torch.autograd.Function):
         return outputs

     @staticmethod
-    def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None]:
+    def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
         return (
-            HierarchicalAllToAll.forward(None, grad_outputs[0], ctx.comm_grps),
+            HierarchicalAllToAll.forward(None, grad_outputs[0], ctx.comm_grps, ctx.src_rank),
+            None,
             None,
         )
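The widened `backward` return type follows PyTorch's contract for custom autograd Functions: `backward` must return exactly one value per `forward` argument, with `None` for non-differentiable inputs such as the groups tuple and the new `src_rank`. A minimal, self-contained sketch of that contract (a toy Function, not the MoE code):

```python
import torch

class ScaleBy(torch.autograd.Function):
    """Toy Function: one tensor input plus one non-tensor input."""

    @staticmethod
    def forward(ctx, x: torch.Tensor, factor: float) -> torch.Tensor:
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_output):
        # One return value per forward argument: a gradient for `x`,
        # and None for the non-differentiable `factor`.
        return grad_output * ctx.factor, None

x = torch.ones(3, requires_grad=True)
ScaleBy.apply(x, 2.0).sum().backward()
print(x.grad)  # tensor([2., 2., 2.])
```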
@@ -13,7 +13,7 @@
 from colossalai.moe.load_balance import LoadBalancer
 from colossalai.moe.manager import MOE_MANAGER
 from colossalai.moe.routers import MoeRouter, get_router_cls
 from colossalai.moe.utils import create_ep_hierarchical_group, get_noise_generator
-from colossalai.tensor.moe_tensor.api import get_dp_group, get_ep_group, get_ep_size
+from colossalai.tensor.moe_tensor.api import get_dp_group, get_ep_group, get_ep_group_ranks, get_ep_size
@@ -105,8 +105,11 @@ class SparseMLP(nn.Module):
         if self.expert_parallel is not None:
             self.ep_group = get_ep_group(self.experts)
             self.ep_size = get_ep_size(self.experts)
-            self.ep_hierarchical_group = create_ep_hierarchical_group(
-                self.ep_group) if enable_hierarchical_comm else None
+            self.ep_hierarchical_group = None
+            if enable_hierarchical_comm:
+                self.ep_intra_src_rank, *self.ep_hierarchical_group = create_ep_hierarchical_group(
+                    get_ep_group_ranks(self.experts)
+                )
             self.dp_group = get_dp_group(self.experts)
         else:
             self.ep_group = None
@@ -225,10 +228,10 @@ class SparseMLP(nn.Module):
         """
         if not overlap or dist.get_world_size(self.ep_group) == 1:
             if self.ep_hierarchical_group is not None:
-                expert_input = HierarchicalAllToAll.apply(dispatch_data, self.ep_hierarchical_group)
+                expert_input = HierarchicalAllToAll.apply(dispatch_data, self.ep_hierarchical_group, self.ep_intra_src_rank)
                 expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.hidden_size)
                 expert_output = self.experts(expert_input)
-                expert_output = HierarchicalAllToAll.apply(expert_output, self.ep_hierarchical_group)
+                expert_output = HierarchicalAllToAll.apply(expert_output, self.ep_hierarchical_group, self.ep_intra_src_rank)
                 return expert_output
             else:
                 expert_input = AllToAll.apply(dispatch_data, self.ep_group, False)[0]
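The constructor change above leans on Python's starred assignment: the first element of the 3-tuple returned by `create_ep_hierarchical_group` lands in `ep_intra_src_rank`, and the two process groups are collected into `ep_hierarchical_group` as a list. A standalone illustration of that unpacking, using placeholder values instead of real process groups:

```python
def fake_create_ep_hierarchical_group():
    # Placeholder for the real (intra_src_rank, intra_group, inter_group) result.
    return 0, "intra_group", None

ep_intra_src_rank, *ep_hierarchical_group = fake_create_ep_hierarchical_group()
print(ep_intra_src_rank)      # 0
print(ep_hierarchical_group)  # ['intra_group', None] -- a list, not a tuple

# The forward pass later unpacks it the same way the autograd Function does:
intra_node_group, inter_node_group = ep_hierarchical_group
```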
@@ -179,15 +179,15 @@ def set_moe_args(config: Any, args: dict):
 def create_ep_hierarchical_group(
-    ep_group: dist.ProcessGroup,
+    ep_group_ranks: List[int],
     nproc_per_node: Optional[int] = None,
-) -> Tuple[Optional[dist.ProcessGroup],
-           Optional[dist.ProcessGroup]]:
+) -> Tuple[int, dist.ProcessGroup, Optional[dist.ProcessGroup]]:
     """
     e.g., If ep_group = [1, 2, 5, 6], and nproc_per_node = 4
     Then, ep_intra_group = [1, 2] & [5, 6], ep_inter_group = [1, 5] & None
     """
     assert dist.is_initialized(), "Please initialize torch.distributed first."
+    rank = dist.get_rank()
     if nproc_per_node is None:
         nproc_per_node = os.environ.get("LOCAL_WORLD_SIZE")
         assert nproc_per_node is not None, "Please use torchrun to launch the job, or specify nproc_per_node manually."
@@ -197,24 +197,23 @@ def create_ep_hierarchical_group(
         "nproc_per_node should be a divisor of world_size."
     num_node = dist.get_world_size() // nproc_per_node

-    rank = dist.get_rank()
-    ep_ranks = dist.get_process_group_ranks(ep_group)
+    intra_src_rank = None
     ep_intra_node_group = None
     for i in range(num_node):
         ep_intra_ranks = [
             i * nproc_per_node + j
             for j in range(nproc_per_node)
-            if j in ep_ranks
+            if j in ep_group_ranks
         ]
         group = dist.new_group(ep_intra_ranks)
         if rank in ep_intra_ranks:
             assert ep_intra_node_group is None
             ep_intra_node_group = group
+            intra_src_rank = ep_intra_ranks[0]

     ep_inter_node_group = None
     ep_inter_ranks = [
-        ep_ranks[0] + i * nproc_per_node
+        ep_group_ranks[0] + i * nproc_per_node
         for i in range(num_node)
     ]
     if len(ep_inter_ranks) > 1:
@@ -222,4 +221,4 @@ def create_ep_hierarchical_group(
         if rank in ep_inter_ranks:
             ep_inter_node_group = group

-    return ep_intra_node_group, ep_inter_node_group
+    return intra_src_rank, ep_intra_node_group, ep_inter_node_group
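The docstring example can be sanity-checked without initializing `torch.distributed`; the sketch below replays the two rank computations above for the documented inputs (`ep_group_ranks = [1, 2, 5, 6]`, `nproc_per_node = 4`, assuming a world size of 8):

```python
ep_group_ranks = [1, 2, 5, 6]
nproc_per_node = 4
world_size = 8  # assumed: 2 nodes x 4 processes per node
num_node = world_size // nproc_per_node

# Intra-node groups: local index j joins node i's group iff j is in ep_group_ranks.
intra_groups = [
    [i * nproc_per_node + j for j in range(nproc_per_node) if j in ep_group_ranks]
    for i in range(num_node)
]
# Inter-node group: the first EP rank offset onto each node.
inter_group = [ep_group_ranks[0] + i * nproc_per_node for i in range(num_node)]

print(intra_groups)  # [[1, 2], [5, 6]]
print(inter_group)   # [1, 5]
```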
+from typing import List
+
 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
@@ -124,7 +126,7 @@ def get_dp_rank(tensor: torch.Tensor) -> int:
     return dist.get_rank(get_dp_group(tensor))

-def get_ep_group_ranks(tensor: torch.Tensor) -> int:
+def get_ep_group_ranks(tensor: torch.Tensor) -> List[int]:
     """
     Get the expert parallel group ranks of the given tensor.
@@ -137,7 +139,7 @@ def get_ep_group_ranks(tensor: torch.Tensor) -> List[int]:
     return tensor.moe_info.ep_group_ranks

-def get_dp_group_ranks(tensor: torch.Tensor) -> int:
+def get_dp_group_ranks(tensor: torch.Tensor) -> List[int]:
     """
     Get the data parallel group ranks of the given tensor.
 import os
 import warnings
+from typing import Dict

 import pytest
 import torch
@@ -123,7 +124,7 @@ def sync_local_from_ep(local_model: SparseMLP, ep_model: SparseMLP, assert_grad_
         local_param.data.copy_(all_param.data)

-def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size: int, dim: int, seed: int):
+def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size: int, dim: int, config: Dict):
     assert batch_size % world_size == 0

     colossalai.launch(config=dict(), rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
@@ -133,8 +134,9 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size
     local_model = SparseMLP(num_experts=num_experts, hidden_size=dim, intermediate_size=dim * 2)
     MOE_MANAGER.__init__()
     MOE_MANAGER.setup(parallel="EP")
-    os.environ["LOCAL_WORLD_SIZE"] = str(world_size)
-    enable_hierarchical_comm = torch.__version__ >= "1.13.1"
+    enable_hierarchical_comm = config.get("enable_hierarchical_comm", False)
+    if enable_hierarchical_comm:
+        os.environ["LOCAL_WORLD_SIZE"] = str(world_size)
     ep_model = SparseMLP(
         num_experts=num_experts,
         hidden_size=dim,
@@ -161,7 +163,6 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size
     tp_grad_handler = MoeGradientHandler(tp_model)

     rank = dist.get_rank()
-    torch.cuda.manual_seed(seed)
     input_data = torch.randn(batch_size, dim, device=get_current_device())
     micro_batch_size = batch_size // world_size
     index = rank * micro_batch_size
@@ -218,11 +219,14 @@ def run_test(rank: int, world_size: int, port: int, num_experts: int, batch_size
 @pytest.mark.parametrize("num_experts", [4, 64])
 @pytest.mark.parametrize("batch_size", [16])
 @pytest.mark.parametrize("dim", [64])
-@pytest.mark.parametrize("seed", [42, 127])
+@pytest.mark.parametrize("config", [
+    {"enable_hierarchical_comm": False},
+    {"enable_hierarchical_comm": True},
+])
 @rerun_if_address_is_in_use()
-def test_moe_ep_tp(num_experts: int, batch_size: int, dim: int, seed: int):
-    spawn(run_test, 2, num_experts=num_experts, batch_size=batch_size, dim=dim, seed=seed)
+def test_moe_ep_tp(num_experts: int, batch_size: int, dim: int, config: Dict):
+    spawn(run_test, 2, num_experts=num_experts, batch_size=batch_size, dim=dim, config=config)

 if __name__ == '__main__':
-    test_moe_ep_tp(num_experts=8, batch_size=32, dim=32, seed=42)
+    test_moe_ep_tp(num_experts=8, batch_size=32, dim=32)
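The new dict-valued `parametrize` stacks with the other decorators, so pytest expands the product of all parameter lists into separate test cases. A minimal standalone sketch of that expansion (toy test, hypothetical names):

```python
import pytest

@pytest.mark.parametrize("num_experts", [4, 64])
@pytest.mark.parametrize("config", [
    {"enable_hierarchical_comm": False},
    {"enable_hierarchical_comm": True},
])
def test_expansion(num_experts: int, config: dict):
    # Stacked decorators yield 2 x 2 = 4 generated test cases.
    assert isinstance(config["enable_hierarchical_comm"], bool)
```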