Commit 945004e7 authored by Rick Ho

fix shadow

parent 226e0779
@@ -44,10 +44,9 @@ void _reduce_grad(
         long expert_size) {
     auto smgr = getCudaStreamManager(t.device().index());
-    auto torch_stream = c10::cuda::getCurrentCUDAStream().stream();
     cudaEvent_t evt_stash;
     cudaEventCreate(&evt_stash);
-    cudaEventRecord(evt_stash, torch_stream);
+    cudaEventRecord(evt_stash, smgr->torchStream());
     FMOE_SWE(smgr->stream(0), evt_stash);
     cudaEventDestroy(evt_stash);
...
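The hunk above records the stash event on the stream reported by the stream manager (smgr->torchStream()) rather than querying c10::cuda::getCurrentCUDAStream() locally, and then FMOE_SWE makes fastmoe's stream 0 wait on that event before reducing the gradient. A minimal Python sketch of the same record-then-wait pattern, assuming FMOE_SWE wraps cudaStreamWaitEvent and that torchStream()/stream(0) correspond to PyTorch's current stream and a separate worker stream (names below are illustrative, not fastmoe's API):

    import torch

    # Stand-ins for smgr->torchStream() and smgr->stream(0).
    torch_stream = torch.cuda.current_stream()
    worker_stream = torch.cuda.Stream()

    x = torch.randn(1024, 1024, device="cuda")
    y = x @ x  # work queued on the torch stream

    # Record an event on the torch stream, then make the worker stream wait
    # on it, mirroring cudaEventRecord(evt_stash, smgr->torchStream()) followed
    # by FMOE_SWE(smgr->stream(0), evt_stash).
    evt = torch.cuda.Event()
    evt.record(torch_stream)
    worker_stream.wait_event(evt)

    with torch.cuda.stream(worker_stream):
        z = y.sum()  # safe: this runs only after y is produced on the torch stream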
@@ -37,7 +37,7 @@ class MoEForward(Function):
         try:
             # To skip torch autograd's version check.
             with torch.autograd.graph.saved_tensors_hooks(nothing, nothing):
-                y0 = expert_fn(x, torch.tensor([x.shape[0]], dtype=torch.int64))
+                y0 = expert_fn(x, torch.tensor([x.shape[0]], dtype=torch.int64), expert_idx)
         except Exception as e:
             # Ignore the error and fall back for compatibility to older
             # versions of PyTorch
...
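This hunk passes the expert index through to expert_fn when the shadowed expert's forward is replayed, so an expert_fn that dispatches on that index runs the intended expert instead of a default one. A minimal sketch of a compatible expert_fn, under assumed conventions (the experts list and the dispatch logic are hypothetical, not fastmoe's actual implementation):

    import torch

    def make_expert_fn(experts):
        # Hypothetical: `experts` is a list of per-expert modules.
        def expert_fn(x, fwd_expert_count, expert_idx=None):
            # fwd_expert_count[i] tokens are routed to expert i; the shadow
            # call above sends all x.shape[0] tokens to one expert.
            if expert_idx is not None:
                return experts[expert_idx](x)
            # Fallback matching the older two-argument call handled by the
            # except branch in the original code.
            outputs, base = [], 0
            for i, count in enumerate(fwd_expert_count.tolist()):
                outputs.append(experts[i](x[base:base + count]))
                base += count
            return torch.cat(outputs, dim=0)
        return expert_fn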