Commit 704092b1 authored by Sengxian

Fix input grad in mp group

parent 092c8d67
@@ -192,3 +192,27 @@ class AllGather(Function):
     def backward(ctx, grad_out):
         rank, dim0 = ctx.args
         return grad_out[rank * dim0:(rank + 1) * dim0], None, None, None
+
+
+class Slice(Function):
+    r'''
+    A wrapper for the Slice function to support auto-differentiation.
+    '''
+    @staticmethod
+    def forward(ctx, inp, rank, world_size, group):
+        B: int = inp.shape[0]
+        local_batch_size = B // world_size
+        batch_start = local_batch_size * rank
+        batch_end = min(batch_start + local_batch_size, B)
+        inp = inp[batch_start:batch_end]
+        ctx.args = world_size, group
+        return inp
+
+    @staticmethod
+    def backward(ctx, grad_out):
+        world_size, group = ctx.args
+        tensor_list = [torch.empty_like(grad_out) for _ in range(world_size)]
+        torch.distributed.all_gather(tensor_list, grad_out, group=group)
+        torch.cuda.synchronize()
+        grad_out = torch.cat(tensor_list, dim=0)
+        return grad_out, None, None, None
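Why the autograd wrapper is needed: plain tensor indexing is already differentiable, but it routes gradient only into the rows this rank kept; the rows belonging to other model-parallel ranks stay zero and are never communicated. Slice.backward all-gathers the per-rank slice gradients across the group, so every rank reconstructs the gradient for the full, replicated input. A minimal sketch of using the new function (not part of the commit; it assumes the Slice class above is in scope and a CUDA build with a visible GPU, since Slice.backward calls torch.cuda.synchronize()):

    import os
    import torch
    import torch.distributed as dist

    # Single-process defaults; under `torchrun --nproc_per_node=2` the launcher
    # provides these variables and each of the two ranks keeps half the batch.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    os.environ.setdefault("RANK", "0")
    os.environ.setdefault("WORLD_SIZE", "1")
    dist.init_process_group("gloo")

    inp = torch.randn(4, 8, requires_grad=True)   # replicated input, batch size 4
    out = Slice.apply(inp, dist.get_rank(), dist.get_world_size(), dist.group.WORLD)
    out.sum().backward()

    # The backward all-gather fills the gradient for the whole batch on every rank.
    assert inp.grad.shape == inp.shape
    assert bool((inp.grad != 0).all())
    dist.destroy_process_group()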
@@ -8,7 +8,7 @@ import numpy as np
 from .functions import moe_prepare_forward
 from .functions import MOEScatter, MOEGather, MOELinear
-from .functions import AllGather
+from .functions import AllGather, Slice
 from .gates import NaiveGate
@@ -179,11 +179,8 @@ class FMoE(nn.Module):
         expert is multiplied to the experts' output tensors as a weight.
         '''
         if self.mp_size > 1:
-            B: int = inp.shape[0]
-            local_batch_size = B // self.mp_size
-            batch_start = local_batch_size * self.mp_rank
-            batch_end = min(batch_start + local_batch_size, B)
-            inp = inp[batch_start:batch_end]
+            inp = Slice.apply(inp,
+                    self.mp_rank, self.mp_size, self.mp_group)
         gate_top_k_idx, gate_score = self.gate(inp)
         # to: (BxLxtop_k) x d_model
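The removed block is exactly what Slice.forward now does, but as inline indexing autograd recorded only a slice op, so under model parallelism each rank's inp.grad was populated only for its own rows, which is the bug the commit message refers to. A single-process illustration of the old behavior (hypothetical, not from the commit):

    import torch

    # Simulate rank 0 of an mp group of size 2 on a replicated batch of 4 rows.
    inp = torch.randn(4, 8, requires_grad=True)
    rank, mp_size = 0, 2
    local_batch_size = inp.shape[0] // mp_size
    local = inp[rank * local_batch_size:(rank + 1) * local_batch_size]
    local.sum().backward()

    # Rows owned by rank 1 receive zero gradient, and nothing in the old code
    # ever exchanged these slice gradients across the mp group.
    print(inp.grad[2:].abs().sum())   # tensor(0.)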