Unverified Commit 5083a736 authored by Rick Ho, committed by GitHub

Merge pull request #108 from laekov/faster-bug

Fix type mismatch, shape mismatch and lack of condition in FasterMoE's expert shadowing
parents 665b99bf 4682c1d0
@@ -29,6 +29,8 @@ def stash_expert_params(e, params):
 def pop_expert_params(e):
     if not hasattr(e, 'expert_param_stash'):
         return
+    if not e.expert_param_stash:
+        return
     for n, p in e.named_parameters():
         with torch.no_grad():
             p.copy_(e.expert_param_stash[n])
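The new guard covers the case where `expert_param_stash` exists but is empty, i.e. nothing was actually stashed for this expert; without it, the loop above would index an empty dict and raise a `KeyError`. Below is a minimal, self-contained sketch of the stash/pop pattern; the `nn.Linear` expert and the body of `stash_expert_params` are illustrative stand-ins, not FastMoE's exact code.

```python
import torch
import torch.nn as nn

def stash_expert_params(e, params):
    # Save the expert's current weights, then overwrite them with the flat
    # parameter buffer received for the shadowed expert. (Illustrative body.)
    if not hasattr(e, 'expert_param_stash'):
        e.expert_param_stash = dict()
    offset = 0
    for n, p in e.named_parameters():
        if n not in e.expert_param_stash:
            e.expert_param_stash[n] = p.data.clone()
        with torch.no_grad():
            p.copy_(params[offset:offset + p.numel()].reshape(p.shape))
        offset += p.numel()

def pop_expert_params(e):
    # Restore the saved weights; the empty-dict guard is the fix in this hunk.
    if not hasattr(e, 'expert_param_stash'):
        return
    if not e.expert_param_stash:
        return
    for n, p in e.named_parameters():
        with torch.no_grad():
            p.copy_(e.expert_param_stash[n])
    e.expert_param_stash.clear()

expert = nn.Linear(4, 4)  # stand-in for an expert module
flat = torch.randn(sum(p.numel() for p in expert.parameters()))
stash_expert_params(expert, flat)
pop_expert_params(expert)   # restores the original weights and clears the stash
pop_expert_params(expert)   # stash now empty: returns early instead of KeyError
```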
@@ -53,6 +55,6 @@ def set_grads(e, grads):
         seg = grads[offset:offset + p.numel()]
         offset += p.numel()
         if p.grad is None:
-            p.grad = seg.clone()
+            p.grad = seg.clone().reshape(p.shape)
         else:
             p.grad += seg.reshape(p.shape)
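Here `grads` is a flat 1-D buffer holding all of the expert's gradients back to back, so the slice `seg` is 1-D. The `else` branch already reshaped it, but the `p.grad is None` branch did not, leaving `p.grad` with a shape different from `p` (the shape mismatch named in the commit message). A small sketch of the fixed behavior, using a plain `nn.Linear` as a stand-in for an expert:

```python
import torch
import torch.nn as nn

def set_grads(e, grads):
    # 'grads' is a flat 1-D buffer with the expert's gradients back to back;
    # each slice must be reshaped before it becomes p.grad, otherwise p.grad
    # would be 1-D while p is, e.g., (out_features, in_features).
    offset = 0
    for _, p in e.named_parameters():
        seg = grads[offset:offset + p.numel()]
        offset += p.numel()
        if p.grad is None:
            p.grad = seg.clone().reshape(p.shape)
        else:
            p.grad += seg.reshape(p.shape)

expert = nn.Linear(4, 4)
flat_grads = torch.randn(sum(p.numel() for p in expert.parameters()))
set_grads(expert, flat_grads)   # first call: p.grad created with p's shape
set_grads(expert, flat_grads)   # second call: accumulation stays shape-safe
assert expert.weight.grad.shape == expert.weight.shape
```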
@@ -35,7 +35,7 @@ class MoEForward(Function):
                 x.requires_grad = True
                 # To skip torch autograd's version check.
                 with torch.autograd.graph.saved_tensors_hooks(nothing, nothing):
-                    y0 = expert_fn(x, [x.shape[0]])
+                    y0 = expert_fn(x, torch.tensor([x.shape[0]], dtype=torch.int64))
                 ctx.gibs[idx] = x
                 ctx.gobs[idx] = y0
                 y.copy_(y0)
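The second argument of `expert_fn` is the per-expert token count, and downstream code treats it as an integer tensor rather than a Python list, hence the type-mismatch fix of wrapping `[x.shape[0]]` in `torch.tensor(..., dtype=torch.int64)`. The following is an illustration of why the list form breaks; this `expert_fn` is a simplified stand-in, not FastMoE's real one.

```python
import torch
import torch.nn as nn

experts = [nn.Linear(8, 8) for _ in range(2)]  # illustrative local experts

def expert_fn(inp, fwd_expert_count):
    # fwd_expert_count[i] = number of tokens routed to expert i. The body
    # calls tensor methods on it, so a plain list like [inp.shape[0]] fails.
    counts = fwd_expert_count.cpu().numpy()
    outputs, base = [], 0
    for i, n in enumerate(counts):
        outputs.append(experts[i](inp[base:base + n]))
        base += n
    return torch.cat(outputs, dim=0)

x = torch.randn(6, 8)
y = expert_fn(x, torch.tensor([4, 2], dtype=torch.int64))   # ok
# In the shadowed-expert path above, all tokens go to one expert, so the
# count tensor is just torch.tensor([x.shape[0]], dtype=torch.int64).
# expert_fn(x, [4, 2]) would raise AttributeError: 'list' has no attribute 'cpu'.
```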