Commit 140282d5 authored by Thor Johnsen's avatar Thor Johnsen
Browse files

Bug fixes

parent bec558b1
......@@ -289,11 +289,6 @@ class SpatialBottleneckFunction(torch.autograd.Function):
out1_pad = torch.empty([N,C,Hs+2,W], dtype=out1.dtype, device='cuda', memory_format=memory_format)
stream1.wait_stream(torch.cuda.current_stream())
if spatial_method != 2: stream3.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(stream3):
if explicit_nhwc:
out1_pad[:,1:Hs+1,:,:].copy_(out1)
else:
out1_pad[:,:,1:Hs+1,:].copy_(out1)
with torch.cuda.stream(stream1):
if explicit_nhwc:
top_out1_halo = out1_pad[:,:1,:,:]
......@@ -343,11 +338,11 @@ class SpatialBottleneckFunction(torch.autograd.Function):
out1_pad[:,:,1:Hs+1,:].copy_(out1)
elif spatial_method == 2:
# wait for halo transfer to finish before doing a full convolution of padded x
torch.cuda.current_stream().wait_stream(stream1)
if explicit_nhwc:
out1_pad[:,1:Hs+1,:,:].copy_(out1)
else:
out1_pad[:,:,1:Hs+1,:].copy_(out1)
torch.cuda.current_stream().wait_stream(stream1)
fast_bottleneck.forward_out2_pad(explicit_nhwc, stride_1x1, args, outputs, out1_pad)
elif spatial_method == 3:
fast_bottleneck.forward_out2_mask(explicit_nhwc, stride_1x1, args, outputs, thresholdTop, thresholdBottom)
......@@ -705,8 +700,6 @@ class SpatialBottleneck(torch.nn.Module):
s4, b4 = self.downsample[1].get_scale_bias(self.explicit_nhwc)
w_scale.append(s4)
w_bias.append(b4)
self.w_scale = w_scale
self.w_bias = w_bias
out = spatial_bottleneck_function(*self.spatial_parallel_args, self.explicit_nhwc, self.stride, w_scale, w_bias, self.thresholdTop, self.thresholdBottom, x, *self.w_conv)
else:
out = spatial_bottleneck_function(*self.spatial_parallel_args, self.explicit_nhwc, self.stride, self.w_scale, self.w_bias, self.thresholdTop, self.thresholdBottom, x, *self.w_conv)
......
......@@ -153,23 +153,36 @@ __device__ void checked_signal(
const int v1, const int v2, const int v3, const int v4
)
{
if (blockIdx.x == 0) {
register int r1, r2, r3, r4;
if (threadIdx.x == 0) {
// wait for top neighbor to clear bottom signal (indicating ready for new input)
do {
asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(signal1_flag) : "memory");
} while (r1 == v1 && r2 == v2 && r3 == v3 && r4 == v4);
// signal to top neighbor my output is ready
asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal1_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory");
} else if (threadIdx.x == 1) {
// wait for bottom neighbor to clear top signal (indicating ready for new input)
cg::this_grid().sync();
bool is_main_thread = (blockIdx.x == 0 && threadIdx.x == 0) ? true : false;
if (is_main_thread) {
// flush all writes to global memory
__threadfence_system();
// wait for top or bottom neighbor to clear signal
register int r1, r2, r3, r4;
bool top_zeroed=false, btm_zeroed=false, top_done=false, btm_done=false;
do {
do {
asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(signal2_flag) : "memory");
} while (r1 == v1 && r2 == v2 && r3 == v3 && r4 == v4);
// signal to bottom neighbor my output is ready
asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal2_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory");
}
if (!top_zeroed) {
asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(signal1_flag) : "memory");
if (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4) top_zeroed = true;
}
if (!btm_zeroed) {
asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(signal2_flag) : "memory");
if (r1 != v1 || r2 != v2 || r3 != v3 || r4 != v4) btm_zeroed = true;
}
} while((top_zeroed == top_done) && (btm_zeroed == btm_done));
if (!top_done && top_zeroed) {
// signal to top neighbor my output is ready
asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal1_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory");
top_done = true;
}
if (!btm_done && btm_zeroed) {
// signal to bottom neighbor my output is ready
asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal2_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory");
btm_done = true;
}
} while (!top_done || !btm_done);
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment