Commit 3d2bc25c authored by rusty1s's avatar rusty1s
Browse files

bugfixes

parent 2041aef6
......@@ -31,21 +31,24 @@ def test_spline_conv_cpu(tensor):
]
expected_output = [
[1 + 12.5 * 9 + 13 * 10 + sum(edgewise_output) / 4],
[1 + 12.5 * 1 + 13 * 2],
[1 + 12.5 * 3 + 13 * 4],
[1 + 12.5 * 5 + 13 * 6],
[1 + 12.5 * 7 + 13 * 8],
(1 + 12.5 * 9 + 13 * 10 + sum(edgewise_output)) / 5,
1 + 12.5 * 1 + 13 * 2,
1 + 12.5 * 3 + 13 * 4,
1 + 12.5 * 5 + 13 * 6,
1 + 12.5 * 7 + 13 * 8,
]
assert output.tolist() == expected_output
output = [pytest.approx(x, 0.01) for x in output.view(-1).tolist()]
assert output == expected_output
x, weight, pseudo = Variable(x), Variable(weight), Variable(pseudo)
root_weight, bias = Variable(root_weight), Variable(bias)
output = spline_conv(x, edge_index, pseudo, weight, kernel_size,
is_open_spline, root_weight, 1, bias)
assert output.data.tolist() == expected_output
output = [pytest.approx(x, 0.01) for x in output.data.view(-1).tolist()]
assert output == expected_output
def test_spline_weighting_backward_cpu():
......@@ -57,7 +60,7 @@ def test_spline_weighting_backward_cpu():
x = torch.DoubleTensor(16, 2).uniform_(-1, 1)
x = Variable(x, requires_grad=True)
pseudo = torch.DoubleTensor(16, 3).uniform_(0, 1)
pseudo = Variable(torch.DoubleTensor(pseudo), requires_grad=True)
pseudo = Variable(pseudo, requires_grad=True)
weight = torch.DoubleTensor(25, 2, 4).uniform_(-1, 1)
weight = Variable(weight, requires_grad=True)
......@@ -88,3 +91,19 @@ def test_spline_conv_gpu(tensor):
output = spline_conv(x, edge_index, pseudo, weight, kernel_size,
is_open_spline, root_weight, 1, bias)
assert output.cpu().tolist() == expected_output.tolist()
def test_spline_weighting_backward_gpu():
for degree in implemented_degrees.keys():
kernel_size = torch.cuda.LongTensor([5, 5, 5])
is_open_spline = torch.cuda.ByteTensor([1, 0, 1])
op = SplineWeighting(kernel_size, is_open_spline, degree)
x = torch.cuda.DoubleTensor(16, 2).uniform_(-1, 1)
x = Variable(x, requires_grad=True)
pseudo = torch.cuda.DoubleTensor(16, 3).uniform_(0, 1)
pseudo = Variable(pseudo, requires_grad=False) # TODO
weight = torch.cuda.DoubleTensor(25, 2, 4).uniform_(-1, 1)
weight = Variable(weight, requires_grad=True)
assert gradcheck(op, (x, pseudo, weight), eps=1e-6, atol=1e-4) is True
......@@ -46,10 +46,9 @@ def spline_weighting_backward_input(grad_output, weight, basis,
grad_input = grad_output.new(grad_output.size(0), weight.size(1))
func = get_func('weighting_backward_input', grad_output)
# Transpose for coalesced memory access.
# Transpose for coalesced memory access on GPU.
weight = weight.transpose(1, 2).contiguous()
func(grad_input, grad_output, weight, basis, weight_index)
weight = weight.transpose(1, 2).contiguous()
return grad_input
......
......@@ -19,15 +19,18 @@ def spline_conv(x,
output = basic_spline_conv(x, edge_index, pseudo, weight, kernel_size,
is_open_spline, degree)
# Normalize output by node degree.
# Compute degree.
degree = x.new() if torch.is_tensor(x) else x.data.new()
degree = node_degree(edge_index, x.size(0), out=degree)
degree = degree.unsqueeze(-1).clamp_(min=1)
output /= degree if torch.is_tensor(x) else Var(degree)
# Weight root node separately (if wished).
if root_weight is not None:
output += torch.mm(x, root_weight)
degree += 1
# Normalize output by node degree.
degree = degree.unsqueeze(-1).clamp_(min=1)
output /= degree if torch.is_tensor(x) else Var(degree)
# Add bias (if wished).
if bias is not None:
......
......@@ -17,13 +17,13 @@ __global__ void weightingForwardKernel(TensorInfo<Real> output, TensorInfo<Real>
KERNEL_LOOP(i, n) {
int64_t edgeOffset = i / output.size[1], inputOffset = edgeOffset * input.stride[0];
int64_t s, S = basis.size[1], m_in, M_in = input.size[1], m_out = i % output.size[1], M_out = output.size[1], weightOffset;
Real value = 0;
Real value = 0; Real b;
for (s = 0; s < S; s++) {
b = basis.data[edgeOffset * S + s];
weightOffset = weightIndex.data[edgeOffset * S + s] * M_in * M_out + m_out;
for (m_in = 0; m_in < M_in; m_in++) {
value += weight.data[weightOffset + m_in * M_out] * input.data[inputOffset + m_in * input.stride[1]];
value += weight.data[weightOffset + m_in * M_out] * input.data[inputOffset + m_in * input.stride[1]] * b;
}
value *= basis.data[edgeOffset * S + s];
}
output.data[i] = value;
}
......@@ -34,13 +34,13 @@ __global__ void weightingBackwardInputKernel(TensorInfo<Real> gradInput, TensorI
KERNEL_LOOP(i, n) {
int64_t edgeOffset = i / gradInput.size[1], gradOutputOffset = edgeOffset * gradOutput.stride[0];
int64_t s, S = basis.size[1], m_in = i % gradInput.size[1], M_in = gradInput.size[1], m_out, M_out = gradOutput.size[1], weightOffset;
Real value = 0;
Real value = 0; Real b;
for (s = 0; s < S; s++) {
b = basis.data[edgeOffset * S + s];
weightOffset = weightIndex.data[edgeOffset * S + s] * M_in * M_out + m_in;
for (m_out = 0; m_out < M_out; m_out++) {
value += weight.data[weightOffset + M_in * m_out] * gradOutput.data[gradOutputOffset + m_out];
value += weight.data[weightOffset + M_in * m_out] * gradOutput.data[gradOutputOffset + m_out] * b;
}
value *= basis.data[edgeOffset * S + s];
}
gradInput.data[i] = value;
}
......
......@@ -78,14 +78,14 @@ void spline_(weighting_forward)(THTensor *output, THTensor *input, THTensor *wei
void spline_(weighting_backward_input)(THTensor *grad_input, THTensor *grad_output, THTensor *weight, THTensor *basis, THLongTensor *weight_index) {
real *weight_data = weight->storage->data + weight->storageOffset; real b;
SPLINE_WEIGHTING(grad_input, grad_output, basis, weight_index, THTensor_(size)(weight, 1), THTensor_(size)(weight, 2), THLongTensor_size(weight_index, 1),
SPLINE_WEIGHTING(grad_input, grad_output, basis, weight_index, THTensor_(size)(weight, 2), THTensor_(size)(weight, 1), THLongTensor_size(weight_index, 1),
for (m_in = 0; m_in < M_in; m_in++) {
value = 0;
for (s = 0; s < S; s++) {
b = *(basis_data + s * basis_stride);
w_idx = *(weight_index_data + s * weight_index_stride);
for (m_out = 0; m_out < M_out; m_out++) {
value += b * *(grad_output_data + m_out * grad_output_stride) * *(weight_data + w_idx * M_in * M_out + m_in * M_out + m_out);
value += b * *(grad_output_data + m_out * grad_output_stride) * *(weight_data + w_idx * M_in * M_out + m_out * M_in + m_in);
}
}
grad_input_data[m_in * grad_input_stride] = value;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment