Commit b3c1340d authored by Daniel Povey

Test backward; now seems to work.

parent 86e3a617
@@ -58,7 +58,7 @@ torch::Tensor integrated_conv_cpu(torch::Tensor input,
               static_cast<unsigned int>(src_w) < static_cast<unsigned int>(W))
             src = src_input_a[src_h][src_w];
           scalar_t relu = src + dest + this_pos_add_a[kh][kw];
-          if (relu > 0.0)
+          if (relu >= 0.0)
             sum += relu * this_pos_mul_a[kh][kw];
         }
       }
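For reference, the forward semantics that this hunk touches can be written as a slow but readable PyTorch equivalent (a sketch only; the name integrated_conv_ref and the pure-Python loops are illustrative, not part of the repo). Channel c of the input is the 'src' image and channel c + C the 'dest' image, as in the backward code below:

    import torch

    def integrated_conv_ref(input, pos_add, pos_mul):
        # input: (N, 2*C, H, W); channel c holds 'src', channel c + C holds 'dest'.
        N, twoC, H, W = input.shape
        C, kH, kW = pos_add.shape
        output = torch.zeros(N, C, H, W, dtype=input.dtype, device=input.device)
        for n in range(N):
            for c in range(C):
                for h in range(H):
                    for w in range(W):
                        dest = input[n, c + C, h, w]
                        s = 0.0
                        for kh in range(kH):
                            for kw in range(kW):
                                src_h, src_w = h + kh - kH // 2, w + kw - kW // 2
                                src = (input[n, c, src_h, src_w]
                                       if 0 <= src_h < H and 0 <= src_w < W else 0.0)
                                relu = src + dest + pos_add[c, kh, kw]
                                if relu >= 0.0:  # '>=', matching the fixed line above
                                    s += relu * pos_mul[c, kh, kw]
                        output[n, c, h, w] = s
        return output

The change from '>' to '>=' only matters when the pre-activation is exactly zero, but it makes the forward threshold agree with the '>=' convention already used in the backward code below.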
@@ -127,7 +127,7 @@ std::vector<torch::Tensor> integrated_conv_backward_cpu(torch::Tensor input,
     for (int h = 0; h < H; h++) {
       for (int w = 0; w < W; w++) {
         scalar_t dest = input_a[n][c + C][h][w],
-            dest_grad = 0.0,  // to be multiplied by this_output_grad later..
+            dest_grad = 0.0,  // to be multiplied by this_grad_output later..
             this_grad_output = grad_output_a[n][c][h][w];
         for (int kh = 0; kh < kH; kh++) {
           int src_h = h + kh - kH / 2;
@@ -140,7 +140,7 @@ std::vector<torch::Tensor> integrated_conv_backward_cpu(torch::Tensor input,
           scalar_t relu = src + dest + pos_add_a[c][kh][kw];
           if (relu >= 0.0) {
             scalar_t pos_mul_val = pos_mul_a[c][kh][kw];
-            dest_grad += pos_mul_val;  // will later multiply by this_output_grad
+            dest_grad += pos_mul_val;  // will later multiply by this_grad_output
             grad_pos_add_a[c][kh][kw] += this_grad_output * pos_mul_val;
             grad_pos_mul_a[c][kh][kw] += this_grad_output * relu;
             if (static_cast<unsigned int>(src_h) < static_cast<unsigned int>(H) &&
@@ -149,7 +149,7 @@ std::vector<torch::Tensor> integrated_conv_backward_cpu(torch::Tensor input,
             }
           }
         }
-        grad_input_a[n][c + C][h][w] += dest_grad * this_grad_output;
+        grad_input_a[n][c + C][h][w] = dest_grad * this_grad_output;
       }
     }
   }
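The backward rules touched by the three hunks above all follow from differentiating y = relu(src + dest + pos_add) * pos_mul at each kernel position: when the pre-activation is positive, dy/dsrc = dy/ddest = dy/dpos_add = pos_mul and dy/dpos_mul = relu. A minimal scalar autograd check of those rules (standalone sketch, not repo code):

    import torch

    # y = relu(src + dest + pos_add) * pos_mul; values chosen so relu > 0.
    src, dest, pos_add, pos_mul = (torch.tensor(v, requires_grad=True)
                                   for v in (0.7, -0.2, 0.4, 1.5))
    relu = torch.relu(src + dest + pos_add)
    y = relu * pos_mul
    y.backward()
    assert torch.isclose(src.grad, pos_mul.detach())      # dy/dsrc = pos_mul
    assert torch.isclose(dest.grad, pos_mul.detach())     # dy/ddest = pos_mul
    assert torch.isclose(pos_add.grad, pos_mul.detach())  # dy/dpos_add = pos_mul
    assert torch.isclose(pos_mul.grad, relu.detach())     # dy/dpos_mul = relu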
@@ -348,7 +348,7 @@ void integrated_conv_kernel_backward(
                          // where the 'h' and 'w' indexes are into the zero-padded input
                          // image.
       *dest_img_buf = src_img_buf + ppatch_size,  // version of input image that relates to destination position
-      *grad_output_buf = src_img_buf + ppatch_size,  // output gradient for padded patch, indexed [h*ppatchW + w]
+      *grad_output_buf = dest_img_buf + ppatch_size,  // output gradient for padded patch, indexed [h*ppatchW + w]
       *grad_pos_add_buf = grad_output_buf + ppatch_size,  // total grad for pos_add for this thread block, indexed [kh*kW + kw]
       *grad_pos_mul_buf = grad_pos_add_buf + (kH * kW),  // total grad for pos_mul for this thread block, indexed [kh*kW + kw]
       *reduce_buf = grad_pos_mul_buf + (kH * kW);  // buffer for reduction over threads, size == blockDim.x
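The fix here is a one-token aliasing bug: grad_output_buf was set to src_img_buf + ppatch_size, which is the same address as dest_img_buf, so output gradients overwrote the 'dest' patch. A small arithmetic sketch of the intended disjoint offsets (covering only the pointers visible in this hunk; sizes are assumptions taken from the comments):

    # Offsets, in elements, relative to src_img_buf; each region must start
    # where the previous one ends for the regions to be disjoint.
    def shared_offsets(ppatch_size, kH, kW, block_dim_x):
        src_img = 0
        dest_img = src_img + ppatch_size
        grad_output = dest_img + ppatch_size   # buggy version used src_img + ppatch_size,
                                               # which aliases dest_img
        grad_pos_add = grad_output + ppatch_size
        grad_pos_mul = grad_pos_add + kH * kW
        reduce_buf = grad_pos_mul + kH * kW    # size == blockDim.x
        return reduce_buf + block_dim_x        # total elements used

    assert shared_offsets(49, 5, 5, 128) == 3 * 49 + 2 * 25 + 128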
@@ -360,13 +360,17 @@ void integrated_conv_kernel_backward(
   // Load parts of the kernel parameters pos_add and pos_mul into shared memory,
   // in pos_add_buf and pos_mul_buf; zero the corresponding gradient buffers.
   // We know that blockDim.x >= kH * kW, see threads_per_kernel_pos.
-  if (threadIdx.x < kH * kW) {
-    int i = threadIdx.x;
+  for (int i = threadIdx.x % (blockDim.x / 2); i < kH * kW; i += (blockDim.x / 2)) {
     int kh = i / kW, kw = i % kW;
-    pos_add_buf[i] = pos_add[c][kh][kw];
-    pos_mul_buf[i] = pos_mul[c][kh][kw];
-    grad_pos_add_buf[i] = 0.0;
-    grad_pos_mul_buf[i] = 0.0;
+    if (threadIdx.x < blockDim.x / 2) {  // First half of threads take care of pos_add..
+      pos_add_buf[i] = pos_add[c][kh][kw];
+      grad_pos_add_buf[i] = 0.0;
+    } else {  // Second half take care of pos_mul... there is no warp divergence
+              // because we make sure blockDim.x is a multiple of 64.
+      pos_mul_buf[i] = pos_mul[c][kh][kw];
+      grad_pos_mul_buf[i] = 0.0;
+    }
   }
   // n is the index within the batch of images. Loop to make sure we cover all
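The no-warp-divergence claim in the new comment holds because blockDim.x is kept a multiple of 64, so blockDim.x / 2 is a multiple of the 32-thread warp size and the if/else split between the first and second half of the block never divides a warp. A small standalone check of that property:

    # Verify that the "first half / second half" branch assigns every
    # 32-thread warp entirely to one branch when blockDim.x % 64 == 0.
    def no_warp_divergence(block_dim_x, warp_size=32):
        half = block_dim_x // 2
        branches = [tid < half for tid in range(block_dim_x)]
        return all(
            len({branches[w * warp_size + i] for i in range(warp_size)}) == 1
            for w in range(block_dim_x // warp_size))

    assert no_warp_divergence(64) and no_warp_divergence(128)
    assert not no_warp_divergence(32)  # half a warp would diverge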
@@ -391,7 +395,8 @@ void integrated_conv_kernel_backward(
     // Load the 'src' and 'dest' versions of the padded patch into
     // shared-memory buffers, and also the output gradient.
-    for (int i = threadIdx.x % (blockDim.x / 2); i < ppatch_size; i += (blockDim.x / 2)) {
+    for (int i = threadIdx.x % (blockDim.x / 2);
+         i < ppatch_size; i += (blockDim.x / 2)) {
       int h_in_ppatch = i / ppatchW,
           w_in_ppatch = i % ppatchW;
       int h = patch_h_offset + h_in_ppatch - (kH / 2),  // kH / 2 is offset due to padding
@@ -401,7 +406,7 @@ void integrated_conv_kernel_backward(
       // load `input`
       scalar_t src_val = scalar_t(0),
           dest_val = scalar_t(0);
-      if ((unsigned int)h < (unsigned int)H &&  // h >= 0 && h < H.
+      if ((unsigned int)h < (unsigned int)H &&  // h >= 0 && h < H
           (unsigned int)w < (unsigned int)W) {  // w >= 0 && w < W
         int C = grad_output.size(1);
         src_val = input[n][c][h][w];
@@ -429,7 +434,7 @@ void integrated_conv_kernel_backward(
              grad_input_dest_sum = 0.0;  // grad for channel c + C, for our pixel
                                          // of `input` (contribution of this thread)
     if (pos_in_patch < patch_size) {
-      // This block computes `grad_input_sum`.
+      // This block computes `grad_input_src_sum` and `grad_input_dest_sum`
       // The num-threads for the backward kernel may not be an exact multiple
       // of patch_size, so we need the if-guard.
@@ -456,7 +461,6 @@ void integrated_conv_kernel_backward(
         // This is actually more like cross-correlation, as we don't have a
         // negative sign on the h and w indexes in the kernel.
         int src_h_in_ppatch = h_in_patch + h_in_kernel,
             src_w_in_ppatch = w_in_patch + w_in_kernel;
         int src_pos_in_ppatch = src_h_in_ppatch * ppatchW + src_w_in_ppatch;
@@ -469,9 +473,11 @@ void integrated_conv_kernel_backward(
           scalar_t this_grad_output = grad_output_buf[pos_in_ppatch];
           grad_input_dest_sum += this_grad_output * pos_mul_val;
         }
-        // To compute a contribution to "this_input_src_grad", we need to consider the
-        // contribution to the destination pixel that it would have contributed to
-        // with this same offset.
+        // To compute a contribution to "this_input_src_grad", we need to
+        // consider the contribution to the destination pixel that it would
+        // have contributed to with this same offset.
+        // We have to flip the offsets: instead of "+ h_in_kernel",
+        // we use (kH - 1) - h_in_kernel.
         int dest_h_in_ppatch = h_in_patch + (kH - 1) - h_in_kernel,
             dest_w_in_ppatch = w_in_patch + (kW - 1) - w_in_kernel,
             dest_pos_in_ppatch = dest_h_in_ppatch * ppatchW + dest_w_in_ppatch;
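The flipped offset in the new comment can be checked with a little index algebra: a dest (output) pixel at un-padded patch row hd reads the src row hd + kh - kH / 2, so the dest pixels that read a fixed src row hs sit at hd = hs - (kh - kH / 2); converting to padded-patch coordinates by adding kH / 2 gives hs + 2 * (kH / 2) - kh, which equals hs + (kH - 1) - kh when kH is odd (and the kernel sizes here are always odd). A quick standalone check of that identity (same reasoning applies to w):

    for kH in [1, 3, 5, 7]:
        for hs in range(10):          # src row within the un-padded patch
            for kh in range(kH):
                hd = hs - (kh - kH // 2)                   # un-padded dest row
                assert hd + kH // 2 == hs + (kH - 1) - kh  # padded-patch row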
@@ -485,6 +491,7 @@ void integrated_conv_kernel_backward(
       }
       // Aggregate `grad_input_src_sum` over threads, if needed; and write the
       // result to `grad_input`.
+      // h and w are un-padded indexes into the entire image.
       int h = patch_h_offset + pos_in_patch / patchW,
           w = patch_w_offset + pos_in_patch % patchW;
@@ -514,7 +521,8 @@ void integrated_conv_kernel_backward(
           kw = pos_in_kernel % kW;
       // This group of (threads_per_kernel_pos) threads is responsible
-      // for position (kh, kw) in the kernel; we iterate over the patch.
+      // for position (kh, kw) in the kernel; we iterate over the patch
+      // (an un-padded patch of output).
       scalar_t pos_add_val = pos_add_buf[pos_in_kernel],
                pos_mul_val = pos_mul_buf[pos_in_kernel];
@@ -524,15 +532,15 @@ void integrated_conv_kernel_backward(
         // and pos_mul; we let `pos_in_patch` correspond to the *output*
         // position, and work out the input position based on the kernel position.
-        int h_in_patch = pos_in_patch / patchH,
-            w_in_patch = pos_in_patch / patchW;
+        int h_in_patch = pos_in_patch / patchW,
+            w_in_patch = pos_in_patch % patchW;
         // pos_in_ppatch is the position in the padded patch corresponding to
         // `pos_in_patch`.
         int pos_in_ppatch = (h_in_patch + kH / 2) * ppatchW + (w_in_patch + kW / 2);
         scalar_t dest_val = dest_img_buf[pos_in_ppatch];
-        int offset_pos_in_ppatch = (h_in_patch + kh) * ppatchW + (w_in_patch + kw);
-        scalar_t src_val = src_img_buf[offset_pos_in_ppatch];
+        int src_pos_in_ppatch = (h_in_patch + kh) * ppatchW + (w_in_patch + kw);
+        scalar_t src_val = src_img_buf[src_pos_in_ppatch];
         scalar_t relu = dest_val + src_val + pos_add_val;
         if (relu >= 0.0) {
@@ -546,13 +554,15 @@ void integrated_conv_kernel_backward(
         this_grad_pos_mul = tiled_warp_reduce_sum(
             threads_per_kernel_pos, reduce_buf, this_grad_pos_mul);
         if (threadIdx.x % threads_per_kernel_pos == 0) {
-          grad_pos_add_buf[pos_in_kernel] = this_grad_pos_add;
-          grad_pos_mul_buf[pos_in_kernel] = this_grad_pos_mul;
+          grad_pos_add_buf[pos_in_kernel] += this_grad_pos_add;
+          grad_pos_mul_buf[pos_in_kernel] += this_grad_pos_mul;
         }
       }
     }
   }
+  __syncthreads();  // make sure all threads have written to grad_pos_add_buf and
+                    // grad_pos_mul_buf.
   int block = blockIdx.z * gridDim.y + blockIdx.y;
   int kernel_pos = threadIdx.x;
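The '=' to '+=' change matters because each thread block loops over several images n (see the "Loop to make sure we cover all" comment a few hunks up), and the reduced per-kernel-position gradients must accumulate across those loop iterations rather than keep only the last image's contribution. A scalar analogue with made-up numbers:

    # Hypothetical analogue of the fix: the buffer must sum over the outer
    # n-loop, not be overwritten on each iteration.
    per_image_partials = [0.5, 1.25, -0.75]  # made-up reduced grads, one per image

    grad_buggy, grad_fixed = 0.0, 0.0
    for partial in per_image_partials:
        grad_buggy = partial      # old code: keeps only the last image
        grad_fixed += partial     # fixed code: sums over all images
    assert grad_fixed == 1.0 and grad_buggy == -0.75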
@@ -18,7 +18,7 @@ def test_integrated_conv_zeros():
     kH = 5
     kW = 5
     pos_add = torch.zeros(C, kH, kW, device=device, dtype=dtype)
-    pos_mul = torch.zeros(C, kH, kW, device=device, dtype=dtype)
+    pos_mul = torch.ones(C, kH, kW, device=device, dtype=dtype)
     input.requires_grad = True
     pos_add.requires_grad = True
     pos_mul.requires_grad = True
@@ -45,20 +45,28 @@ def test_integrated_conv_compare():
         print("dtype=", dtype)
         input = torch.randn(N, 2 * C, H, W, dtype=dtype)
         device = torch.device('cuda:0')
-        input_cuda = input.to(device)
+        input_cuda = input.to(device).detach()
         kH = 5
         kW = 5
         pos_add = torch.randn(C, kH, kW, dtype=dtype)
         pos_mul = torch.randn(C, kH, kW, dtype=dtype)
-        pos_add_cuda = pos_add.to(device)
-        pos_mul_cuda = pos_mul.to(device)
+        pos_add_cuda = pos_add.to(device).detach()
+        pos_mul_cuda = pos_mul.to(device).detach()
+
+        for x in [ pos_add, pos_mul, pos_add_cuda, pos_mul_cuda, input, input_cuda ]:
+            x.requires_grad = True
+
         output = integrated_conv(input, pos_add, pos_mul)
         output_cuda = integrated_conv(input_cuda, pos_add_cuda, pos_mul_cuda)
         print("output = ", output)
         print("output_cuda = ", output_cuda)
+        output_grad = torch.randn(*output.shape, dtype=dtype)
+        output.backward(gradient=output_grad)
+        output_cuda.backward(gradient=output_grad.to(device))
+
         diff = (output - output_cuda.to(torch.device('cpu'))).abs().sum()
         abs = output.abs().sum()
         print("Diff = ", diff, ", abs = ", abs)
@@ -66,6 +74,21 @@ def test_integrated_conv_compare():
                               atol=1.0e-05)
+
+        for a, b, name in [ (pos_add, pos_add_cuda, 'pos_add'),
+                            (pos_mul, pos_mul_cuda, 'pos_mul'),
+                            (input, input_cuda, 'input') ]:
+            grad = a.grad
+            cuda_grad = b.grad.to(torch.device('cpu'))
+            diff_abs = (grad - cuda_grad).abs().sum().item()
+            sum_abs = (grad + cuda_grad).abs().sum().item()
+            print(f"Comparing grad of {name}: diff={diff_abs}, sum={sum_abs}")
+            if diff_abs > 1.0e-05 * sum_abs:
+                print(f"Error: too much difference in grad of {name}.")
+                print("grad = ", grad)
+                print("cuda_grad = ", cuda_grad)
+

 def test_integrated_conv_rand_compare():
     for _ in range(30):
         N = random.randint(1, 256)
@@ -108,11 +131,80 @@ def test_integrated_conv_rand_compare():
                 output_cuda = integrated_conv(input_cuda, pos_add_cuda, pos_mul_cuda)
                 diff = (output - output_cuda.to(torch.device('cpu'))).abs().sum()
-                abs = output.abs().sum()
-                print("Diff = ", diff, ", abs = ", abs)
-                if not torch.allclose(output, output_cuda.to(torch.device('cpu')),
-                                      atol=1.0e-05):
+                sum_abs = output.abs().sum()
+                print("Diff = ", diff, ", abs = ", sum_abs)
+                if (diff / sum_abs).item() > 0.001:
                     print("output = ", output)
                     print("output_cuda = ", output_cuda)
                     assert 0, "outputs differ"
+
+
+def test_integrated_conv_rand_grad():
+    for _ in range(30):
+        N = random.randint(1, 256)
+        C = random.randint(1, 64)
+        H = random.randint(1, 128)
+        W = random.randint(1, 128)
+        while N * C * H * W > 65535:
+            if N >= C and N >= H and N >= W:
+                N = N // 2
+            elif C >= H and C >= W:
+                C = C // 2
+            elif H >= W:
+                H = H // 2
+            else:
+                W = W // 2
+        for device in [ torch.device('cpu'), torch.device('cuda:0') ]:
+            if device == torch.device('cuda:0') and not torch.cuda.is_available():
+                print("Warning: CUDA not available, not testing this part.")
+                continue
+            for dtype in [torch.float32, torch.float64]:
+                print("dtype=", dtype, ", device=", device)
+                input = torch.randn(N, 2 * C, H, W, dtype=dtype, device=device)
+                kH = random.randint(1, 10)
+                kW = random.randint(1, 10)
+                if kH % 2 == 0:
+                    kH += 1
+                if kW % 2 == 0:
+                    kW += 1
+                pos_add = torch.randn(C, kH, kW, dtype=dtype, device=device)
+                pos_mul = torch.randn(C, kH, kW, dtype=dtype, device=device)
+                input.requires_grad = True
+                pos_add.requires_grad = True
+                pos_mul.requires_grad = True
+
+                output = integrated_conv(input, pos_add, pos_mul)
+                output_grad = torch.randn(N, C, H, W, dtype=dtype, device=device)
+                output.backward(gradient=output_grad)
+
+                delta = 1.0e-05
+                pos_delta = delta * torch.randn(C, kH, kW, dtype=dtype, device=device)
+                pred_change = (pos_delta * pos_add.grad).sum().to('cpu').item()
+                change = (output_grad * (integrated_conv(input, pos_add + pos_delta, pos_mul) - output)).sum().to('cpu').item()
+                print(f"For pos_add: pred_change={pred_change}, change={change}")
+                # assert abs(pred_change - change) < 1.0e-04
+
+                pred_change = (pos_delta * pos_mul.grad).sum().to('cpu').item()
+                change = (output_grad * (integrated_conv(input, pos_add, pos_mul + pos_delta) - output)).sum().to('cpu').item()
+                print(f"For pos_mul: pred_change={pred_change}, change={change}")
+                # assert abs(pred_change - change) / abs(change) < 1.0e-04
+
+                input_delta = delta * torch.randn(N, 2 * C, H, W, dtype=dtype, device=device)
+                pred_change = (input_delta * input.grad).sum().to('cpu').item()
+                change = (output_grad * (integrated_conv(input + input_delta, pos_add, pos_mul) - output)).sum().to('cpu').item()
+                print(f"For input: pred_change={pred_change}, change={change}")
+                # assert abs(pred_change - change) / abs(change) < 1.0e-04
+
+
+if __name__ == "__main__":
+    test_integrated_conv_rand_grad()
+    test_integrated_conv_zeros()
+    test_integrated_conv_compare()
+    test_integrated_conv_rand_compare()
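The new test checks gradients by first-order Taylor expansion: with L = (output_grad * f(x)).sum(), backward() fills x.grad with dL/dx, so (delta * x.grad).sum() should approximate (output_grad * (f(x + delta) - f(x))).sum() for a small random delta. A standalone sketch of the same pattern, using an arbitrary smooth function in place of the repo's op:

    import torch

    x = torch.randn(4, 5, dtype=torch.float64, requires_grad=True)
    out_grad = torch.randn(4, 5, dtype=torch.float64)
    f = lambda t: (t * t).sin()            # any smooth elementwise function
    f(x).backward(gradient=out_grad)       # x.grad = out_grad * f'(x)

    delta = 1.0e-06 * torch.randn(4, 5, dtype=torch.float64)
    pred_change = (delta * x.grad).sum().item()
    change = (out_grad * (f(x + delta) - f(x))).sum().item()
    assert abs(pred_change - change) < 1.0e-08  # agreement to first order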