"vscode:/vscode.git/clone" did not exist on "97ef6ff8b3ad0b485173a99a7b1960536bc22437"
Commit a5d51c01 authored by Thor Johnsen's avatar Thor Johnsen
Browse files

Bug fix

parent 8b6f8fc1
...@@ -510,7 +510,7 @@ class SpatialBottleneckFunction(torch.autograd.Function): ...@@ -510,7 +510,7 @@ class SpatialBottleneckFunction(torch.autograd.Function):
w1by3 = w[:,:1,:,:].clone() w1by3 = w[:,:1,:,:].clone()
ctx.stream2.wait_stream(ctx.stream1) # wait for halo transfers to finish ctx.stream2.wait_stream(ctx.stream1) # wait for halo transfers to finish
ctx.stream2.wait_stream(torch.cuda.current_stream()) # wait for backward_grad_out1_mask to finish before launching halo correction kernel ctx.stream2.wait_stream(torch.cuda.current_stream()) # wait for backward_grad_out1_mask to finish before launching halo correction kernel
with torch.cuda.stream(ctx.stream1): with torch.cuda.stream(ctx.stream2):
btm_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo_corr(ctx.explicit_nhwc, ctx.stride_1x1, t_list, w1by3, grads, btm_halo, btm_relu_halo, btm_grad_out1.clone()) btm_grad_out1_halo = fast_bottleneck.backward_grad_out1_halo_corr(ctx.explicit_nhwc, ctx.stride_1x1, t_list, w1by3, grads, btm_halo, btm_relu_halo, btm_grad_out1.clone())
btm_grad_out1.copy_(btm_grad_out1_halo) btm_grad_out1.copy_(btm_grad_out1_halo)
if ctx.spatial_group_rank > 0: if ctx.spatial_group_rank > 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment