Commit 05dd9c69 authored by Thor Johnsen
Browse files

Bug fix

parent a5d51c01
...@@ -330,6 +330,7 @@ class SpatialBottleneckFunction(torch.autograd.Function): ...@@ -330,6 +330,7 @@ class SpatialBottleneckFunction(torch.autograd.Function):
# the first kernel of _forward_rest can launch. # the first kernel of _forward_rest can launch.
# At least we can overlap the two halo correction kernels. # At least we can overlap the two halo correction kernels.
if spatial_group_rank < spatial_group_size-1: if spatial_group_rank < spatial_group_size-1:
stream2.wait_stream(stream1) # wait for halo transfers to finish
stream2.wait_stream(torch.cuda.current_stream()) # wait for *_out2_mask to finish stream2.wait_stream(torch.cuda.current_stream()) # wait for *_out2_mask to finish
with torch.cuda.stream(stream2): with torch.cuda.stream(stream2):
w1by3 = args[2][:,2:3,:,:].clone() w1by3 = args[2][:,2:3,:,:].clone()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment