Commit 704092b1 authored by Sengxian

Fix input grad in mp group

parent 092c8d67
@@ -192,3 +192,27 @@ class AllGather(Function):
     def backward(ctx, grad_out):
         rank, dim0 = ctx.args
         return grad_out[rank * dim0:(rank + 1) * dim0], None, None, None
+
+
+class Slice(Function):
+    r'''
+    A wrapper for the Slice function to support auto-differentiation.
+    '''
+    @staticmethod
+    def forward(ctx, inp, rank, world_size, group):
+        B: int = inp.shape[0]
+        local_batch_size = B // world_size
+        batch_start = local_batch_size * rank
+        batch_end = min(batch_start + local_batch_size, B)
+        inp = inp[batch_start:batch_end]
+        ctx.args = world_size, group
+        return inp
+
+    @staticmethod
+    def backward(ctx, grad_out):
+        world_size, group = ctx.args
+        tensor_list = [torch.empty_like(grad_out) for _ in range(world_size)]
+        torch.distributed.all_gather(tensor_list, grad_out, group=group)
+        torch.cuda.synchronize()
+        grad_out = torch.cat(tensor_list, dim=0)
+        return grad_out, None, None, None
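For context, a simplified, self-contained sketch of the pattern `Slice` implements: take the local slice of the batch in forward, then all-gather the gradient in backward so every model-parallel rank recovers the gradient for the full input batch. This is not the library code; the `SliceSketch` name and the single-process "gloo" setup are illustration-only assumptions, and the `torch.cuda.synchronize()` call from the committed version is omitted so the sketch runs on CPU.

# Illustration-only sketch of the Slice autograd pattern, not the fmoe code.
import os
import torch
import torch.distributed as dist
from torch.autograd import Function


class SliceSketch(Function):
    @staticmethod
    def forward(ctx, inp, rank, world_size, group):
        # Keep only this rank's slice of the batch dimension.
        local = inp.shape[0] // world_size
        ctx.args = world_size, group
        return inp[rank * local:(rank + 1) * local]

    @staticmethod
    def backward(ctx, grad_out):
        world_size, group = ctx.args
        # Each rank only produced the gradient for its own slice, so
        # reassemble the full-batch gradient across the group.
        parts = [torch.empty_like(grad_out) for _ in range(world_size)]
        dist.all_gather(parts, grad_out.contiguous(), group=group)
        return torch.cat(parts, dim=0), None, None, None


if __name__ == "__main__":
    # Single-process group purely for illustration.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    inp = torch.randn(8, 4, requires_grad=True)
    out = SliceSketch.apply(inp, dist.get_rank(), dist.get_world_size(), None)
    out.sum().backward()
    # The gradient covers the whole batch, not just the local slice.
    assert inp.grad.shape == inp.shape
    dist.destroy_process_group()

Without the all-gather, autograd through plain tensor slicing leaves each rank with an input gradient that is zero outside its local slice, which appears to be the model-parallel input-gradient bug the commit title refers to.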
@@ -8,7 +8,7 @@ import numpy as np
 from .functions import moe_prepare_forward
 from .functions import MOEScatter, MOEGather, MOELinear
-from .functions import AllGather
+from .functions import AllGather, Slice
 from .gates import NaiveGate
@@ -179,11 +179,8 @@ class FMoE(nn.Module):
         expert is multiplied to the experts' output tensors as a weight.
         '''
         if self.mp_size > 1:
-            B: int = inp.shape[0]
-            local_batch_size = B // self.mp_size
-            batch_start = local_batch_size * self.mp_rank
-            batch_end = min(batch_start + local_batch_size, B)
-            inp = inp[batch_start:batch_end]
+            inp = Slice.apply(inp,
+                              self.mp_rank, self.mp_size, self.mp_group)
         gate_top_k_idx, gate_score = self.gate(inp)
         # to: (BxLxtop_k) x d_model