Commit 56c1bd63 authored by Rick Ho

fix no grad after all-gather bug

parent d83234b0
@@ -143,3 +143,19 @@ class MOEGather(Function):
         else:
             global_grad_out_buf = grad_out_buf
         return global_grad_out_buf, None, None, None, None, None
+
+
+class AllGather(Function):
+    @staticmethod
+    def forward(ctx, inp, rank, world_size, group):
+        tensor_list = [torch.empty_like(inp) for _ in range(world_size)]
+        torch.distributed.all_gather(tensor_list, inp, group=group)
+        torch.cuda.synchronize()
+        output = torch.cat(tensor_list, dim=0)
+        ctx.args = rank, inp.shape[0]
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_out):
+        rank, dim0 = ctx.args
+        return grad_out[rank * dim0:(rank + 1) * dim0], None, None, None
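
torch.distributed.all_gather is opaque to autograd: the gathered tensors carry no grad_fn, so once the expert output is concatenated the backward pass can no longer reach the local input (the "no grad after all-gather" bug in the commit title). Wrapping the collective in a torch.autograd.Function whose backward hands back only this rank's slice of the incoming gradient restores the graph, which is what the AllGather class above does. Below is a minimal, single-process CPU sketch of the same pattern; DemoAllGather, the gloo setup, and the port number are illustrative assumptions, not part of the commit (the real class additionally calls torch.cuda.synchronize()).

import os
import torch
import torch.distributed as dist
from torch.autograd import Function


class DemoAllGather(Function):
    # Same pattern as the AllGather added above; the name is hypothetical.
    @staticmethod
    def forward(ctx, inp, rank, world_size, group):
        tensor_list = [torch.empty_like(inp) for _ in range(world_size)]
        dist.all_gather(tensor_list, inp, group=group)  # not tracked by autograd
        ctx.args = rank, inp.shape[0]
        return torch.cat(tensor_list, dim=0)

    @staticmethod
    def backward(ctx, grad_out):
        rank, dim0 = ctx.args
        # Keep only this rank's rows of the gradient; the non-tensor inputs
        # (rank, world_size, group) receive None.
        return grad_out[rank * dim0:(rank + 1) * dim0], None, None, None


if __name__ == "__main__":
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    x = torch.randn(4, 8, requires_grad=True)
    gathered = DemoAllGather.apply(x, 0, 1, dist.group.WORLD)
    gathered.sum().backward()
    print(x.grad is not None)  # True: gradients flow back to the local shard

    dist.destroy_process_group()
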
@@ -69,8 +69,6 @@ class FMoETransformerMLP(nn.Module):
         d_model=1024,
         d_hidden=4096,
         world_size=1,
-        model_parallel_size=1,
-        model_parallel_rank=1,
         mp_group=None,
         activation=torch.nn.functional.gelu,
         top_k=2,
@@ -81,9 +79,13 @@ class FMoETransformerMLP(nn.Module):
         self.d_model = d_model
         self.d_hidden = d_hidden
         self.world_size = world_size
-        self.model_parallel_size = model_parallel_size
-        self.model_parallel_rank = model_parallel_rank
         self.mp_group = mp_group
+        if mp_group is None:
+            self.mp_size = 1
+            self.mp_rank = 0
+        else:
+            self.mp_size = mp_group.size()
+            self.mp_rank = mp_group.rank()
         self.activation = activation
         self.pre_lnorm = pre_lnorm
         self.top_k = top_k
@@ -104,10 +106,10 @@ class FMoETransformerMLP(nn.Module):
         original_shape = inp.shape
         inp = inp.reshape(-1, self.d_model)
 
-        if self.model_parallel_size > 1:
+        if self.mp_size > 1:
             B: int = inp.shape[0]
-            local_batch_size = B // self.model_parallel_size
-            batch_start = local_batch_size * self.model_parallel_rank
+            local_batch_size = B // self.mp_size
+            batch_start = local_batch_size * self.mp_rank
             batch_end = min(batch_start + local_batch_size, B)
             inp = inp[batch_start:batch_end]
 
@@ -138,11 +140,8 @@ class FMoETransformerMLP(nn.Module):
         if not self.pre_lnorm:
             output = self.layer_norm(output)
 
-        if self.model_parallel_size > 1:
-            world_size = self.model_parallel_size
-            tensor_list = [torch.empty_like(output) for _ in range(world_size)]
-            torch.distributed.all_gather(tensor_list, output, group=self.mp_group)
-            output = torch.cat(tensor_list, dim=1)
+        if self.mp_size > 1:
+            output = AllGather.apply(output,
+                    self.mp_rank, self.mp_size, self.mp_group)
 
         return output.reshape(original_shape), self.bias
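
For context on the split/gather in forward(): the flattened tokens of shape (B, d_model) are sliced by model-parallel rank, each rank runs only its slice through the experts, and the AllGather at the end concatenates the per-rank outputs back along dim 0 in rank order. A small stand-alone sketch of that arithmetic (no distributed setup, purely illustrative):

import torch

B, d_model, mp_size = 8, 4, 2
inp = torch.arange(B * d_model, dtype=torch.float32).reshape(B, d_model)

local_batch_size = B // mp_size
shards = []
for mp_rank in range(mp_size):
    batch_start = local_batch_size * mp_rank
    batch_end = min(batch_start + local_batch_size, B)
    shards.append(inp[batch_start:batch_end])  # the rows this rank would process

# Concatenating the per-rank slices along dim 0 restores the original batch
# order, which is what the AllGather at the end of forward() relies on.
assert torch.equal(torch.cat(shards, dim=0), inp)
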
 from .layers import FMoETransformerMLP
 from .distributed import DistributedGroupedDataParallel
 
 
-def create_moe_mlp(args, model_parallel_rank, group):
+def create_moe_mlp(args, group):
     assert (
         args.seq_length * args.batch_size % args.model_parallel_size == 0
     ), "Batch size x sequence length should be multiple of mp size"
@@ -14,9 +15,7 @@ def create_moe_mlp(args, model_parallel_rank, group):
         d_model=args.hidden_size,
         d_hidden=args.hidden_size * 4,
         world_size=world_size,
-        model_parallel_size=args.model_parallel_size,
-        model_parallel_rank=model_parallel_rank,
-        mp_group=group,
+        mp_group=group
     )
     for p in fmoe.gate.parameters():
         setattr(p, 'shared', True)
@@ -38,9 +37,7 @@ def fmoefy(model, num_experts=None, distributed_experts=True):
     args.distributed_experts = distributed_experts
 
     for l in model.language_model.transformer.layers:
-        l.mlp = create_moe_mlp(args,
-                mpu.get_model_parallel_rank(),
-                mpu.get_model_parallel_group())
+        l.mlp = create_moe_mlp(args, mpu.get_model_parallel_group())
 
     return model
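
With the constructor change, the Megatron adapter only passes the model-parallel process group; the layer reads both size and rank from the group itself, so mpu.get_model_parallel_rank() no longer has to be threaded through. A minimal sketch of the ProcessGroup calls the layer now relies on (single-process gloo group purely for illustration; not part of the commit):

import os
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

mp_group = dist.new_group(ranks=[0])     # stands in for mpu.get_model_parallel_group()
print(mp_group.size(), mp_group.rank())  # -> 1 0, the values used as mp_size / mp_rank

dist.destroy_process_group()
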