Commit 01ae2d72 authored by Sengxian

Optimize redundant communication

Each model-parallel rank now feeds only its own slice of the batch through the MoE experts, and the slices are reassembled with an all_gather at the end of FMoETransformerMLP.forward, instead of every rank redundantly processing the full batch.

parent fdbac1df
@@ -24,6 +24,7 @@ class FMoELinear(nn.Module):
class FMoENaiveGate(nn.Module):
def __init__(self, d_model, num_expert, world_size, top_k=2):
super(FMoENaiveGate, self).__init__()
# print(f"gate: {num_expert * world_size}")
self.gate = nn.Linear(d_model, num_expert * world_size)
self.top_k = top_k
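
For orientation, here is a minimal, hypothetical sketch (not part of this commit, and not the actual FMoENaiveGate.forward, which lies outside this hunk) of how such a naive top-k gate is typically applied: the linear layer scores every global expert and torch.topk selects the k best experts per token.

import torch
import torch.nn as nn

d_model, num_expert, world_size, top_k = 16, 2, 2, 2
gate = nn.Linear(d_model, num_expert * world_size)   # one score per global expert

tokens = torch.randn(8, d_model)                      # 8 tokens, d_model features each
score = gate(tokens)                                  # [8, num_expert * world_size]
gate_val, gate_idx = torch.topk(score, k=top_k, dim=-1)  # top-k expert scores/indices per token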
@@ -69,16 +70,21 @@ class FMoETransformerMLP(nn.Module):
d_model=1024,
d_hidden=4096,
world_size=1,
model_parallel_size=1,
        model_parallel_rank=0,
group=None,
activation=torch.nn.functional.gelu,
top_k=2,
pre_lnorm=False,
):
super(FMoETransformerMLP, self).__init__()
self.num_expert = num_expert
self.d_model = d_model
self.d_hidden = d_hidden
self.world_size = world_size
self.model_parallel_size = model_parallel_size
self.model_parallel_rank = model_parallel_rank
self.group = group
self.activation = activation
self.pre_lnorm = pre_lnorm
self.top_k = top_k
@@ -86,15 +92,24 @@ class FMoETransformerMLP(nn.Module):
self.htoh4 = FMoELinear(num_expert, d_model, d_hidden)
self.h4toh = FMoELinear(num_expert, d_hidden, d_model)
# print(f"FMoETransformerMLP world_size: {world_size} num_expert: {num_expert}")
self.gate = FMoENaiveGate(d_model, num_expert, world_size, top_k)
self.layer_norm = nn.LayerNorm(d_model)
self.bias = torch.nn.parameter.Parameter(
torch.zeros(d_model, dtype=torch.float32)
)
    def forward(self, inp: torch.Tensor):
        if self.num_expert != 1:
            # Each model-parallel rank keeps only its own slice of the batch, so
            # the experts do not redundantly process the full batch on every
            # rank; the full batch is reassembled by all_gather at the end.
            B: int = inp.shape[1]
            local_batch_size = B // self.model_parallel_size
            batch_start = local_batch_size * self.model_parallel_rank
            batch_end = min(batch_start + local_batch_size, B)
            inp = inp[:, batch_start:batch_end, :].contiguous()
# print(inp.shape)
# print(f"mp_rank: {self.model_parallel_rank}, [{batch_start}, {batch_end})")
residual = inp
if self.pre_lnorm:
inp = self.layer_norm(inp)
@@ -121,4 +136,19 @@ class FMoETransformerMLP(nn.Module):
if not self.pre_lnorm:
output = self.layer_norm(output)
        if self.num_expert != 1:
            world_size = self.model_parallel_size
            if world_size == 1:
                return output, self.bias
            rank = self.model_parallel_rank
            # Gather the batch slices computed by the other model-parallel
            # ranks and concatenate them back into the full batch.
            tensor_list = [torch.empty_like(output) for _ in range(world_size)]
            tensor_list[rank] = output
            torch.distributed.all_gather(tensor_list, output, group=self.group)
            # torch.cat already returns a contiguous tensor.
            output = torch.cat(tensor_list, dim=1)
        return output, self.bias
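
To make the communication pattern above concrete, here is a minimal, hypothetical sketch (not part of this commit) of the same split -> compute -> all_gather -> concatenate round trip on a toy tensor. It assumes the script is launched with torchrun (so RANK and WORLD_SIZE are set), uses the gloo backend, and that the batch size is divisible by the world size, as the code above implicitly requires.

import torch
import torch.distributed as dist


def split_compute_gather(x: torch.Tensor, rank: int, world_size: int, group=None):
    # x has shape [seq_len, batch, d_model]; each rank keeps only its batch slice.
    B = x.shape[1]
    local = B // world_size
    start, end = local * rank, min(local * (rank + 1), B)
    part = x[:, start:end, :].contiguous()

    part = part * 2.0  # stand-in for the expensive expert computation

    # Every rank contributes its slice; afterwards all ranks hold the full batch.
    parts = [torch.empty_like(part) for _ in range(world_size)]
    dist.all_gather(parts, part, group=group)
    return torch.cat(parts, dim=1)


if __name__ == "__main__":
    dist.init_process_group(backend="gloo")
    rank, world = dist.get_rank(), dist.get_world_size()
    x = torch.arange(2 * 4 * 3, dtype=torch.float32).reshape(2, 4, 3)
    y = split_compute_gather(x, rank, world)
    assert torch.allclose(y, x * 2.0)  # the full, transformed batch is recovered
    dist.destroy_process_group()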
from .layers import FMoETransformerMLP


def create_moe_mlp(args, model_parallel_rank, group):
assert (
args.num_experts % args.model_parallel_size == 0
), "Num experts should be multiple of mp size"
@@ -10,7 +10,9 @@ def create_moe_mlp(args):
num_experts,
d_model=args.hidden_size,
d_hidden=args.hidden_size * 4,
        world_size=args.world_size,
        model_parallel_size=args.model_parallel_size,
        model_parallel_rank=model_parallel_rank,
group=group,
)
return fmoe
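
A hedged usage sketch (not from the repository) of how a caller might wire the new arguments in, assuming torch.distributed is already initialized, a model-parallel process group has been created elsewhere, and an argparse-style args object exposing the fields referenced above (num_experts, hidden_size, world_size, model_parallel_size):

import argparse
import torch.distributed as dist

# The import path for create_moe_mlp is an assumption; use wherever it lives in the repo.
from fmoe import create_moe_mlp

args = argparse.Namespace(
    num_experts=4,            # must be a multiple of model_parallel_size
    hidden_size=1024,
    world_size=dist.get_world_size(),
    model_parallel_size=2,
)
mp_group = dist.new_group(ranks=[0, 1])       # hypothetical model-parallel group
mp_rank = dist.get_rank(group=mp_group)       # this process's rank within that group
moe_layer = create_moe_mlp(args, model_parallel_rank=mp_rank, group=mp_group)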