"doc/vscode:/vscode.git/clone" did not exist on "233ac55e2a32c2ba261496aa1d598ffbc687e4ef"
Commit b3c1340d authored by Daniel Povey's avatar Daniel Povey
Browse files

Test backward; now seems to work.

parent 86e3a617
......@@ -58,7 +58,7 @@ torch::Tensor integrated_conv_cpu(torch::Tensor input,
static_cast<unsigned int>(src_w) < static_cast<unsigned int>(W))
src = src_input_a[src_h][src_w];
scalar_t relu = src + dest + this_pos_add_a[kh][kw];
if (relu > 0.0)
if (relu >= 0.0)
sum += relu * this_pos_mul_a[kh][kw];
}
}
......@@ -127,7 +127,7 @@ std::vector<torch::Tensor> integrated_conv_backward_cpu(torch::Tensor input,
for (int h = 0; h < H; h++) {
for (int w = 0; w < W; w++) {
scalar_t dest = input_a[n][c + C][h][w],
dest_grad = 0.0, // to be multiplied by this_output_grad later..
dest_grad = 0.0, // to be multiplied by this_grad_output later..
this_grad_output = grad_output_a[n][c][h][w];
for (int kh = 0; kh < kH; kh++) {
int src_h = h + kh - kH / 2;
......@@ -140,7 +140,7 @@ std::vector<torch::Tensor> integrated_conv_backward_cpu(torch::Tensor input,
scalar_t relu = src + dest + pos_add_a[c][kh][kw];
if (relu >= 0.0) {
scalar_t pos_mul_val = pos_mul_a[c][kh][kw];
dest_grad += pos_mul_val; // will later multiply by this_output_grad
dest_grad += pos_mul_val; // will later multiply by this_grad_output
grad_pos_add_a[c][kh][kw] += this_grad_output * pos_mul_val;
grad_pos_mul_a[c][kh][kw] += this_grad_output * relu;
if (static_cast<unsigned int>(src_h) < static_cast<unsigned int>(H) &&
......@@ -149,7 +149,7 @@ std::vector<torch::Tensor> integrated_conv_backward_cpu(torch::Tensor input,
}
}
}
grad_input_a[n][c + C][h][w] += dest_grad * this_grad_output;
grad_input_a[n][c + C][h][w] = dest_grad * this_grad_output;
}
}
}
......
......@@ -348,7 +348,7 @@ void integrated_conv_kernel_backward(
// where the 'h' and 'w' indexes are into the zero-padded input
// image.
*dest_img_buf = src_img_buf + ppatch_size, // version of input image that relates to destination position
*grad_output_buf = src_img_buf + ppatch_size, // output gradient for padded patch, indexed [h*ppatchW + w]
*grad_output_buf = dest_img_buf + ppatch_size, // output gradient for padded patch, indexed [h*ppatchW + w]
*grad_pos_add_buf = grad_output_buf + ppatch_size, // total grad for pos_add for this thread block, indexed [kh*kW + kw]
*grad_pos_mul_buf = grad_pos_add_buf + (kH * kW), // total grad for pos_mul for this thread block, indexed [kh*kW + kw]
*reduce_buf = grad_pos_mul_buf + (kH * kW); // buffer for reduction over threads, size == blockDim.x
......@@ -360,14 +360,18 @@ void integrated_conv_kernel_backward(
// Load parts of the kernel parameters pos_add and pos_mul into shared memory,
// in pos_add_buf and pos_mul_buf; zero the corresponding gradient buffers.
// We know that blockDim.x >= kH * kW, see threads_per_kernel_pos.
if (threadIdx.x < kH * kW) {
int i = threadIdx.x;
for (int i = threadIdx.x % (blockDim.x / 2); i < kH * kW; i += (blockDim.x / 2)) {
int kh = i / kW, kw = i % kW;
if (threadIdx.x < blockDim.x / 2) { // First half of threads take care of pos_add..
pos_add_buf[i] = pos_add[c][kh][kw];
pos_mul_buf[i] = pos_mul[c][kh][kw];
grad_pos_add_buf[i] = 0.0;
} else { // Second half take care of pos_mul... there is no warp divergence
// because we make sure blockDim.x is a multiple of 64.
pos_mul_buf[i] = pos_mul[c][kh][kw];
grad_pos_mul_buf[i] = 0.0;
}
}
// n is the index within the batch of images. Loop to make sure we cover all
// images in the batch. input.size(0) is the batch size N. All threads in
......@@ -391,7 +395,8 @@ void integrated_conv_kernel_backward(
// Load the 'src' and 'dest' versions of the padded patch into
// shared-memory buffers, and also the output gradient.
for (int i = threadIdx.x % (blockDim.x / 2); i < ppatch_size; i += (blockDim.x / 2)) {
for (int i = threadIdx.x % (blockDim.x / 2);
i < ppatch_size; i += (blockDim.x / 2)) {
int h_in_ppatch = i / ppatchW,
w_in_ppatch = i % ppatchW;
int h = patch_h_offset + h_in_ppatch - (kH / 2), // kH / 2 is offset due to padding
......@@ -401,7 +406,7 @@ void integrated_conv_kernel_backward(
// load `input`
scalar_t src_val = scalar_t(0),
dest_val = scalar_t(0);
if ((unsigned int)h < (unsigned int)H && // h >= 0 && h < H.
if ((unsigned int)h < (unsigned int)H && // h >= 0 && h < H
(unsigned int)w < (unsigned int)W) { // w >= 0 && w < W
int C = grad_output.size(1);
src_val = input[n][c][h][w];
......@@ -429,7 +434,7 @@ void integrated_conv_kernel_backward(
grad_input_dest_sum = 0.0; // grad for channel c + C, for our pixel
// of `input` (contribution of this thread)
if (pos_in_patch < patch_size) {
// This block computes `grad_input_sum`.
// This block computes `grad_input_src_sum` and `grad_input_dest_sum`
// The num-threads for the backward kernel may not be an exact multiple
// of patch_size, so we need the if-guard.
......@@ -456,7 +461,6 @@ void integrated_conv_kernel_backward(
// This is actually more like cross-correlation, as we don't have a
// negative sign on the h and w indexes in the kernel.
int src_h_in_ppatch = h_in_patch + h_in_kernel,
src_w_in_ppatch = w_in_patch + w_in_kernel;
int src_pos_in_ppatch = src_h_in_ppatch * ppatchW + src_w_in_ppatch;
......@@ -469,9 +473,11 @@ void integrated_conv_kernel_backward(
scalar_t this_grad_output = grad_output_buf[pos_in_ppatch];
grad_input_dest_sum += this_grad_output * pos_mul_val;
}
// To compute a contribution to "this_input_src_grad", we need to consider the
// contribution to the destination pixel that it would have contributed to
// with this same offset.
// To compute a contribution to "this_input_src_grad", we need to
// consider the contribution to the destination pixel that it would
// have contributed to with this same offset.
// We have to flip the offsets: instead of "+ h_in_kernel",
// we use (kH - 1) - h_in_kernel.
int dest_h_in_ppatch = h_in_patch + (kH - 1) - h_in_kernel,
dest_w_in_ppatch = w_in_patch + (kW - 1) - w_in_kernel,
dest_pos_in_ppatch = dest_h_in_ppatch * ppatchW + dest_w_in_ppatch;
......@@ -485,6 +491,7 @@ void integrated_conv_kernel_backward(
}
// Aggregate `grad_input_src_sum` over threads, if needed; and write the
// result to `grad_input`.
// h and w are un-padded indexes into the entire image.
int h = patch_h_offset + pos_in_patch / patchW,
w = patch_w_offset + pos_in_patch % patchW;
......@@ -514,7 +521,8 @@ void integrated_conv_kernel_backward(
kw = pos_in_kernel % kW;
// This group of (threads_per_kernel_pos) threads is responsible
// for position (kh, kw) in the kernel; we iterate over the patch.
// for position (kh, kw) in the kernel; we iterate over the patch
// (an un-padded patch of output).
scalar_t pos_add_val = pos_add_buf[pos_in_kernel],
pos_mul_val = pos_mul_buf[pos_in_kernel];
......@@ -524,15 +532,15 @@ void integrated_conv_kernel_backward(
// and pos_mul; we let `pos_in_patch` correspond to the *output*
// position, and work out the input position based on the kernel position.
int h_in_patch = pos_in_patch / patchH,
w_in_patch = pos_in_patch / patchW;
int h_in_patch = pos_in_patch / patchW,
w_in_patch = pos_in_patch % patchW;
// pos_in_ppatch is the position in the padded patch corresponding to
// `pos_in_patch`.
int pos_in_ppatch = (h_in_patch + kH / 2) * ppatchW + (w_in_patch + kW / 2);
scalar_t dest_val = dest_img_buf[pos_in_ppatch];
int offset_pos_in_ppatch = (h_in_patch + kh) * ppatchW + (w_in_patch + kw);
scalar_t src_val = src_img_buf[offset_pos_in_ppatch];
int src_pos_in_ppatch = (h_in_patch + kh) * ppatchW + (w_in_patch + kw);
scalar_t src_val = src_img_buf[src_pos_in_ppatch];
scalar_t relu = dest_val + src_val + pos_add_val;
if (relu >= 0.0) {
......@@ -546,13 +554,15 @@ void integrated_conv_kernel_backward(
this_grad_pos_mul = tiled_warp_reduce_sum(
threads_per_kernel_pos, reduce_buf, this_grad_pos_mul);
if (threadIdx.x % threads_per_kernel_pos == 0) {
grad_pos_add_buf[pos_in_kernel] = this_grad_pos_add;
grad_pos_mul_buf[pos_in_kernel] = this_grad_pos_mul;
grad_pos_add_buf[pos_in_kernel] += this_grad_pos_add;
grad_pos_mul_buf[pos_in_kernel] += this_grad_pos_mul;
}
}
}
}
__syncthreads(); // make sure all threads have written to grad_pos_add_buf and
// grad_pos_mul_buf.
int block = blockIdx.z * gridDim.y + blockIdx.y;
int kernel_pos = threadIdx.x;
......
......@@ -18,7 +18,7 @@ def test_integrated_conv_zeros():
kH = 5
kW = 5
pos_add = torch.zeros(C, kH, kW, device=device, dtype=dtype)
pos_mul = torch.zeros(C, kH, kW, device=device, dtype=dtype)
pos_mul = torch.ones(C, kH, kW, device=device, dtype=dtype)
input.requires_grad = True
pos_add.requires_grad = True
pos_mul.requires_grad = True
......@@ -45,20 +45,28 @@ def test_integrated_conv_compare():
print("dtype=", dtype)
input = torch.randn(N, 2 * C, H, W, dtype=dtype)
device = torch.device('cuda:0')
input_cuda = input.to(device)
input_cuda = input.to(device).detach()
kH = 5
kW = 5
pos_add = torch.randn(C, kH, kW, dtype=dtype)
pos_mul = torch.randn(C, kH, kW, dtype=dtype)
pos_add_cuda = pos_add.to(device)
pos_mul_cuda = pos_mul.to(device)
pos_add_cuda = pos_add.to(device).detach()
pos_mul_cuda = pos_mul.to(device).detach()
for x in [ pos_add, pos_mul, pos_add_cuda, pos_mul_cuda, input, input_cuda ]:
x.requires_grad = True
output = integrated_conv(input, pos_add, pos_mul)
output_cuda = integrated_conv(input_cuda, pos_add_cuda, pos_mul_cuda)
print("output = ", output)
print("output_cuda = ", output_cuda)
output_grad = torch.randn(*output.shape, dtype=dtype)
output.backward(gradient=output_grad)
output_cuda.backward(gradient=output_grad.to(device))
diff = (output - output_cuda.to(torch.device('cpu'))).abs().sum()
abs = output.abs().sum()
print("Diff = ", diff, ", abs = ", abs)
......@@ -66,6 +74,21 @@ def test_integrated_conv_compare():
atol=1.0e-05)
for a,b,name in [ (pos_add, pos_add_cuda, 'pos_add'),
(pos_mul, pos_mul_cuda, 'pos_mul'),
(input, input_cuda, 'input') ]:
grad = a.grad
cuda_grad = b.grad.to(torch.device('cpu'))
diff_abs = (grad - cuda_grad).abs().sum().item()
sum_abs = (grad + cuda_grad).abs().sum().item()
print(f"Comparing grad of {name}: diff={diff_abs}, sum={sum_abs}")
if diff_abs > 1.0e-05 * sum_abs:
print(f"Error: too much difference in grad of {name}.")
print("grad = ", grad)
print("cuda_grad = ", cuda_grad)
def test_integrated_conv_rand_compare():
for _ in range(30):
N = random.randint(1, 256)
......@@ -108,11 +131,80 @@ def test_integrated_conv_rand_compare():
output_cuda = integrated_conv(input_cuda, pos_add_cuda, pos_mul_cuda)
diff = (output - output_cuda.to(torch.device('cpu'))).abs().sum()
abs = output.abs().sum()
print("Diff = ", diff, ", abs = ", abs)
sum_abs = output.abs().sum()
print("Diff = ", diff, ", abs = ", sum_abs)
if not torch.allclose(output, output_cuda.to(torch.device('cpu')),
atol=1.0e-05):
if (diff / sum_abs).item() > 0.001:
print("output = ", output)
print("output_cuda = ", output_cuda)
assert 0, "outputs differ"
def test_integrated_conv_rand_grad():
    """Finite-difference check of integrated_conv gradients.

    For 30 random configurations (batch size N, channels C, image H x W,
    odd kernel sizes kH x kW), runs integrated_conv forward + backward with a
    random output gradient, then perturbs each of pos_add, pos_mul and input
    by a small random delta and compares the analytically predicted change
    (delta . grad) against the actually observed change in the weighted
    output.  Runs on CPU and, when available, CUDA, in float32 and float64.
    Currently only prints the comparison (the asserts are commented out).
    """
    for _ in range(30):
        N = random.randint(1, 256)
        C = random.randint(1, 64)
        H = random.randint(1, 128)
        W = random.randint(1, 128)
        # Halve the largest dimension until the problem size is small enough.
        while N * C * H * W > 65535:
            if N >= C and N >= H and N >= W:
                N = N // 2
            elif C >= H and C >= W:
                C = C // 2
            elif H >= W:
                H = H // 2
            else:
                W = W // 2
        for device in [ torch.device('cpu'), torch.device('cuda:0') ]:
            if device == torch.device('cuda:0') and not torch.cuda.is_available():
                print("Warning: torch not available, not testing this part.")
                continue
            for dtype in [torch.float32, torch.float64]:
                print("dtype=", dtype, ", device=", device)
                # Input has 2*C channels: first C are 'src', second C are 'dest'.
                input = torch.randn(N, 2 * C, H, W, dtype=dtype, device=device)
                # Kernel sizes must be odd; bump even draws up by one.
                kH = random.randint(1, 10)
                kW = random.randint(1, 10)
                if kH % 2 == 0:
                    kH += 1
                if kW % 2 == 0:
                    kW += 1
                pos_add = torch.randn(C, kH, kW, dtype=dtype, device=device)
                pos_mul = torch.randn(C, kH, kW, dtype=dtype, device=device)
                input.requires_grad = True
                pos_add.requires_grad = True
                pos_mul.requires_grad = True
                output = integrated_conv(input, pos_add, pos_mul)
                output_grad = torch.randn(N, C, H, W, dtype=dtype, device=device)
                output.backward(gradient=output_grad)
                delta = 1.0e-05
                # Perturb pos_add: predicted vs. observed change in
                # (output_grad * output).sum().
                pos_delta = delta * torch.randn(C, kH, kW, dtype=dtype, device=device)
                pred_change = (pos_delta * pos_add.grad).sum().to('cpu').item()
                change = (output_grad * (integrated_conv(input, pos_add + pos_delta, pos_mul) - output )).sum().to('cpu').item()
                print(f"For pos_add: pred_change={pred_change}, change={change}")
                #assert abs(pred_change - change) < 1.0e-04
                # Perturb pos_mul.
                pred_change = (pos_delta * pos_mul.grad).sum().to('cpu').item()
                change = (output_grad * (integrated_conv(input, pos_add, pos_mul + pos_delta) - output )).sum().to('cpu').item()
                print(f"For pos_mul: pred_change={pred_change}, change={change}")
                #assert abs(pred_change - change) / abs(change) < 1.0e-04
                # Perturb the input.
                input_delta = delta * torch.randn(N, 2*C, H, W, dtype=dtype, device=device)
                pred_change = (input_delta * input.grad).sum().to('cpu').item()
                change = (output_grad * (integrated_conv(input + input_delta, pos_add, pos_mul) - output )).sum().to('cpu').item()
                print(f"For input: pred_change={pred_change}, change={change}")
                #assert abs(pred_change - change) / abs(change) < 1.0e-04
if __name__ == "__main__":
    # Run all tests when invoked as a script; the gradient check first since
    # it exercises the newest code path.
    test_integrated_conv_rand_grad()
    test_integrated_conv_zeros()
    test_integrated_conv_compare()
    test_integrated_conv_rand_compare()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment