Commit dc3db673 authored by Rick Ho

fix replica condition and minor optimizations

parent a8ecd3d7
@@ -99,12 +99,15 @@ class FMoETransformerMLP(nn.Module):
         )

     def forward(self, inp: torch.Tensor):
-        if self.num_expert != 1:
-            B: int = inp.shape[1]
+        original_shape = inp.shape
+        inp = inp.reshape(-1, self.d_model)
+        if self.model_parallel_size > 1:
+            B: int = inp.shape[0]
             local_batch_size = B // self.model_parallel_size
             batch_start = local_batch_size * self.model_parallel_rank
             batch_end = min(batch_start + local_batch_size, B)
-            inp = inp[:, batch_start:batch_end, :].contiguous()
+            inp = inp[batch_start:batch_end]
         residual = inp
         if self.pre_lnorm:
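A minimal standalone sketch of the per-rank slicing performed by the rewritten block above: the input is flattened to a (tokens, d_model) matrix and each model-parallel rank keeps an equal, contiguous block of rows. The helper name slice_for_rank and the toy shapes are illustrative, not part of the commit.

import torch

def slice_for_rank(inp: torch.Tensor, d_model: int, mp_size: int, mp_rank: int):
    # Flatten (seq_len, batch, d_model) -> (tokens, d_model), remembering the
    # original shape so it can be restored after the final all-gather.
    original_shape = inp.shape
    inp = inp.reshape(-1, d_model)
    tokens = inp.shape[0]
    local = tokens // mp_size          # assumes tokens % mp_size == 0
    start = local * mp_rank
    end = start + local
    return inp[start:end], original_shape

# Illustrative shapes only: 4 tokens of width 8, split across 2 ranks.
x = torch.randn(2, 2, 8)
local_x, shape = slice_for_rank(x, d_model=8, mp_size=2, mp_rank=0)
print(local_x.shape)  # torch.Size([2, 8])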
@@ -112,9 +115,9 @@ class FMoETransformerMLP(nn.Module):
         gate_top_k_idx, gate_score = self.gate(inp)
-        inp = inp.view(-1, self.d_model).repeat_interleave(
-            repeats=self.top_k, dim=0
-        )  # (BxLxtop_k) x d_model
+        # to: (BxLxtop_k) x d_model
+        inp = inp.repeat_interleave(repeats=self.top_k, dim=0)
         x = _fmoe_full_forward(
             inp,
             gate_top_k_idx,
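The simplified repeat_interleave call above works because the input is already flattened to (tokens, d_model). As a toy illustration (values and shapes invented for the example), repeating along dim 0 keeps the top_k copies of each token adjacent, so copy j of token i can be dispatched to that token's j-th selected expert:

import torch

top_k = 2
tokens = torch.tensor([[1.0, 1.0],
                       [2.0, 2.0]])        # 2 tokens, d_model = 2
expanded = tokens.repeat_interleave(repeats=top_k, dim=0)
print(expanded)
# tensor([[1., 1.],
#         [1., 1.],
#         [2., 2.],
#         [2., 2.]])  -> rows 2*i and 2*i+1 are the top_k copies of token i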
@@ -124,26 +127,20 @@ class FMoETransformerMLP(nn.Module):
             self.world_size,
         )
-        core_out = x.view(-1, self.top_k, self.d_model)  # (BxL) x top_k x d_model
-        core_out = torch.bmm(gate_score, core_out)  # (BxL) x 1 x d_model
-        core_out = core_out.view(residual.size(0), residual.size(1), self.d_model)
+        # to: (BxL) x top_k x d_model
+        core_out = x.view(-1, self.top_k, self.d_model)
+        # to: (BxL) x 1 x d_model
+        core_out = torch.bmm(gate_score, core_out)
         output = core_out + residual
         if not self.pre_lnorm:
             output = self.layer_norm(output)
-        if self.num_expert != 1:
+        if self.model_parallel_size > 1:
             world_size = self.model_parallel_size
-            if world_size == 1:
-                return output, self.bias
-            rank = self.model_parallel_rank
             tensor_list = [torch.empty_like(output) for _ in range(world_size)]
-            tensor_list[rank] = output
-            torch.distributed.all_gather(tensor_list, output, group=self.group)
-            # Note: torch.cat already creates a contiguous tensor.
-            output = torch.cat(tensor_list, dim=1).contiguous()
-        return output, self.bias
+            torch.distributed.all_gather(tensor_list, output, group=self.group)
+            output = torch.cat(tensor_list, dim=1)
+        return output.reshape(original_shape), self.bias
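For reference, a hedged sketch of the two patterns touched in this hunk: the bmm-based weighted combination of the top_k expert outputs, and the all_gather/cat re-assembly across model-parallel ranks. The (tokens, 1, top_k) shape assumed for gate_score and the use of dim=0 for the concatenation (matching the dimension that was sliced earlier) are assumptions for the flattened layout, not a restatement of the committed code.

import torch

def combine_topk(x: torch.Tensor, gate_score: torch.Tensor, top_k: int, d_model: int):
    # x: (tokens * top_k, d_model) expert outputs.
    # gate_score: (tokens, 1, top_k) softmaxed weights for each token's experts.
    core_out = x.view(-1, top_k, d_model)        # (tokens, top_k, d_model)
    core_out = torch.bmm(gate_score, core_out)   # (tokens, 1, d_model)
    return core_out.view(-1, d_model)            # (tokens, d_model)

def gather_model_parallel(output: torch.Tensor, group=None):
    # Re-assemble the full batch from the per-rank slices. Concatenating along
    # dim 0 mirrors the dim-0 slicing of the flattened input.
    world_size = torch.distributed.get_world_size(group=group)
    tensor_list = [torch.empty_like(output) for _ in range(world_size)]
    torch.distributed.all_gather(tensor_list, output, group=group)
    return torch.cat(tensor_list, dim=0)

The bmm here is just a per-token weighted sum: each (1, top_k) row of gate scores mixes that token's top_k expert outputs into a single d_model vector.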
@@ -3,7 +3,7 @@ from .layers import FMoETransformerMLP
 def create_moe_mlp(args, model_parallel_rank, group):
     assert (
-        args.num_experts % args.model_parallel_size == 0
+        args.seq_length * args.batch_size % args.model_parallel_size == 0
     ), "Num experts should be multiple of mp size"
     num_experts = args.num_experts // args.model_parallel_size
     fmoe = FMoETransformerMLP(
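The updated assertion mirrors the forward pass above: each rank keeps tokens // model_parallel_size rows of the flattened input, so seq_length * batch_size has to divide evenly or trailing tokens would be dropped. A toy check, with made-up numbers:

seq_length, batch_size, model_parallel_size = 1024, 8, 4
tokens = seq_length * batch_size                  # 8192 flattened tokens
assert tokens % model_parallel_size == 0, \
    "seq_length * batch_size must be a multiple of the model-parallel size"
local_batch_size = tokens // model_parallel_size  # 2048 tokens per rank
print(local_batch_size)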