Commit 2de0f11f authored by Jan Eric Lenssen's avatar Jan Eric Lenssen
Browse files

bugfixes adj gradient

parent 9c208e8e
...@@ -29,6 +29,7 @@ def spline_conv( ...@@ -29,6 +29,7 @@ def spline_conv(
output = input[col] output = input[col]
# Convert to [|E| x M_in] feature matrix and calculate [|E| x M_out]. # Convert to [|E| x M_in] feature matrix and calculate [|E| x M_out].
if output.is_cuda: if output.is_cuda:
output = SplineConvGPU(kernel_size, is_open_spline, K, degree, output = SplineConvGPU(kernel_size, is_open_spline, K, degree,
basis_kernel, basis_backward_kernel, basis_kernel, basis_backward_kernel,
......
...@@ -144,28 +144,29 @@ const ${Dtype}* amount, const long* index, int num_threads) { ...@@ -144,28 +144,29 @@ const ${Dtype}* amount, const long* index, int num_threads) {
// Calculate B-spline basis tensor product gradient // Calculate B-spline basis tensor product gradient
adj_g += g * f * w; adj_g += g * f * w;
} }
atomicAdd(&(grad_amount[e_idx*${k_max} +k_idx]), adj_g); atomicAdd(&(grad_amount[e_idx*${k_max} + k_idx]), adj_g);
} }
} }
} }
''' '''
def get_weighting_forward_kernel(M_in, M_out, k_max): def get_weighting_forward_kernel(M_in, M_out, k_max, dtype='float'):
cuda_tensor = torch.FloatTensor([1]).cuda() cuda_tensor = torch.FloatTensor([1]).cuda()
kernel = _edgewise_spline_weighting_forward_kernel kernel = _edgewise_spline_weighting_forward_kernel
with torch.cuda.device_of(cuda_tensor): with torch.cuda.device_of(cuda_tensor):
f_fw = load_kernel( f_fw = load_kernel(
'edgewise_spline_weighting_forward_kernel', 'edgewise_spline_weighting_forward_kernel',
kernel, kernel,
Dtype='float', Dtype=dtype,
M_in=M_in, M_in=M_in,
M_out=M_out, M_out=M_out,
k_max=k_max) k_max=k_max)
return f_fw return f_fw
def get_weighting_backward_kernel(M_in, M_out, k_max, K, bp_to_adj=False): def get_weighting_backward_kernel(M_in, M_out, k_max, K, bp_to_adj=False,
dtype='float'):
cuda_tensor = torch.FloatTensor([1]).cuda() cuda_tensor = torch.FloatTensor([1]).cuda()
if bp_to_adj: if bp_to_adj:
kernel = _edgewise_spline_weighting_backward_kernel_bp2adj kernel = _edgewise_spline_weighting_backward_kernel_bp2adj
...@@ -175,7 +176,7 @@ def get_weighting_backward_kernel(M_in, M_out, k_max, K, bp_to_adj=False): ...@@ -175,7 +176,7 @@ def get_weighting_backward_kernel(M_in, M_out, k_max, K, bp_to_adj=False):
f_bw = load_kernel( f_bw = load_kernel(
'edgewise_spline_weighting_backward_kernel', 'edgewise_spline_weighting_backward_kernel',
kernel, kernel,
Dtype='float', Dtype=dtype,
M_in=M_in, M_in=M_in,
M_out=M_out, M_out=M_out,
k_max=k_max, k_max=k_max,
...@@ -341,27 +342,42 @@ int num_threads) { ...@@ -341,27 +342,42 @@ int num_threads) {
${Dtype} grad_out = 0.0; ${Dtype} grad_out = 0.0;
int quotient = (int)pow(2.0,(double)d_idx); int quotient = (int)pow(2.0,(double)d_idx);
value = input[e_idx * ${dim} + d_idx];
value *= kernel_size[d_idx] - is_open_spline[d_idx];
frac = value - floor(value);
for (int k_idx = 0; k_idx < ${k_max}; k_idx++) { for (int k_idx = 0; k_idx < ${k_max}; k_idx++) {
k_idx_mod = (k_idx/quotient) % 2; k_idx_mod = (k_idx/quotient) % 2;
value = input[e_idx * ${dim} + d_idx];
value *= kernel_size[d_idx] - is_open_spline[d_idx];
frac = value - floor(value);
${Dtype} residual = (1 - k_idx_mod) * (frac - 1) + k_idx_mod * frac; ${Dtype} residual = (1 - k_idx_mod) * (frac - 1) + k_idx_mod * frac;
int a_idx = e_idx*${k_max} + k_idx; int a_idx = e_idx*${k_max} + k_idx;
grad_out += grad_amount[a_idx]*amount[a_idx]/residual; grad_out += grad_amount[a_idx]*amount[a_idx]/residual;
} }
grad_adj[e_idx*${dim} + d_idx] = grad_out; grad_adj[idx] = grad_out*(kernel_size[d_idx] - is_open_spline[d_idx]);
} }
} }
/*
${Dtype} a = -(1 - k_idx_mod) + k_idx_mod;
for (int d_it = 0; d_it < ${dim}; d_it++) {
if(d_it!=d_idx)
{
value = input[e_idx * ${dim} + d_it];
value *= kernel_size[d_it] - is_open_spline[d_it];
frac = value - floor(value);
a *= (1 - k_idx_mod) * (1 - frac) + k_idx_mod * frac;
}
}
grad_out += a*grad_amount[a_idx];
*/
''' '''
def get_basis_kernel(k_max, K, dim, degree): def get_basis_kernel(k_max, K, dim, degree, dtype='float'):
if degree == 3: if degree == 3:
_spline_kernel = _spline_kernel_cubic _spline_kernel = _spline_kernel_cubic
elif degree == 2: elif degree == 2:
...@@ -374,14 +390,14 @@ def get_basis_kernel(k_max, K, dim, degree): ...@@ -374,14 +390,14 @@ def get_basis_kernel(k_max, K, dim, degree):
f = load_kernel( f = load_kernel(
'spline_kernel', 'spline_kernel',
_spline_kernel, _spline_kernel,
Dtype='float', Dtype=dtype,
k_max=k_max, k_max=k_max,
dim=dim, dim=dim,
K=K) K=K)
return f return f
def get_basis_backward_kernel(k_max, K, dim, degree): def get_basis_backward_kernel(k_max, K, dim, degree, dtype='float'):
if degree == 3: if degree == 3:
raise NotImplementedError raise NotImplementedError
elif degree == 2: elif degree == 2:
...@@ -394,7 +410,7 @@ def get_basis_backward_kernel(k_max, K, dim, degree): ...@@ -394,7 +410,7 @@ def get_basis_backward_kernel(k_max, K, dim, degree):
f = load_kernel( f = load_kernel(
'spline_kernel', 'spline_kernel',
_spline_kernel, _spline_kernel,
Dtype='float', Dtype=dtype,
k_max=k_max, k_max=k_max,
dim=dim, dim=dim,
K=K) K=K)
...@@ -431,11 +447,10 @@ class SplineConvGPU(Function): ...@@ -431,11 +447,10 @@ class SplineConvGPU(Function):
self.save_for_backward(input, weight) self.save_for_backward(input, weight)
num_edges, dim = adj_values.size() num_edges, dim = adj_values.size()
k_max = 2 ** dim k_max = (self.degree+1) ** dim
amount = adj_values.new(num_edges, k_max) amount = adj_values.new(num_edges, k_max)
index = adj_values.new(num_edges, k_max).long() index = adj_values.new(num_edges, k_max).long()
num_threads = amount.numel() num_threads = amount.numel()
with torch.cuda.device_of(input): with torch.cuda.device_of(input):
self.f_basis_fw( self.f_basis_fw(
block=(cuda_num_threads, 1, 1), block=(cuda_num_threads, 1, 1),
...@@ -452,8 +467,8 @@ class SplineConvGPU(Function): ...@@ -452,8 +467,8 @@ class SplineConvGPU(Function):
# Weight features # Weight features
output = input.new(input.size(0), self.M_out) output = input.new(input.size(0), self.M_out)
num_threads = output.numel()
num_threads = output.numel()
with torch.cuda.device_of(input): with torch.cuda.device_of(input):
self.f_weighting_fw( self.f_weighting_fw(
block=(cuda_num_threads, 1, 1), block=(cuda_num_threads, 1, 1),
...@@ -468,15 +483,12 @@ class SplineConvGPU(Function): ...@@ -468,15 +483,12 @@ class SplineConvGPU(Function):
], ],
stream=Stream(ptr=torch.cuda.current_stream().cuda_stream)) stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))
self.amount = amount self.amount = amount
self.index = index self.index = index
return output return output
def backward(self, grad_output): def backward(self, grad_output):
print('grad_output:',grad_output.min(), grad_output.max())
grad_input = grad_output.new(grad_output.size(0), self.M_in).fill_(0) grad_input = grad_output.new(grad_output.size(0), self.M_in).fill_(0)
grad_weight = grad_output.new(self.K, self.M_in, self.M_out).fill_(0) grad_weight = grad_output.new(self.K, self.M_in, self.M_out).fill_(0)
num_threads = grad_output.numel() num_threads = grad_output.numel()
...@@ -488,7 +500,6 @@ class SplineConvGPU(Function): ...@@ -488,7 +500,6 @@ class SplineConvGPU(Function):
index = self.index index = self.index
grad_amount = grad_output.new(amount.size(0), grad_amount = grad_output.new(amount.size(0),
amount.size(1)).fill_(0) amount.size(1)).fill_(0)
with torch.cuda.device_of(grad_output): with torch.cuda.device_of(grad_output):
self.f_weighting_bw( self.f_weighting_bw(
block=(cuda_num_threads, 1, 1), block=(cuda_num_threads, 1, 1),
...@@ -529,6 +540,7 @@ class SplineConvGPU(Function): ...@@ -529,6 +540,7 @@ class SplineConvGPU(Function):
#print('grad_weight:',grad_weight[:,:,-1].min(), grad_weight[:,:,-1].max()) #print('grad_weight:',grad_weight[:,:,-1].min(), grad_weight[:,:,-1].max())
#print('grad_amount:',grad_amount.min(), grad_amount.max()) #print('grad_amount:',grad_amount.min(), grad_amount.max())
#print('grad_adj:',grad_adj.min(), grad_adj.max()) #print('grad_adj:',grad_adj.min(), grad_adj.max())
return grad_input, grad_weight, grad_adj return grad_input, grad_weight, grad_adj
else: else:
......
...@@ -11,11 +11,12 @@ from .spline_conv_gpu import get_basis_kernel,get_basis_backward_kernel, \ ...@@ -11,11 +11,12 @@ from .spline_conv_gpu import get_basis_kernel,get_basis_backward_kernel, \
class SplineConvTest(unittest.TestCase): class SplineConvTest(unittest.TestCase):
'''
@unittest.skipIf(not torch.cuda.is_available(), 'no GPU') @unittest.skipIf(not torch.cuda.is_available(), 'no GPU')
def test_forward_gpu(self): def test_forward_gpu(self):
edges = torch.LongTensor([[0, 0, 0, 0], [1, 2, 3, 4]]) edges = torch.LongTensor([[0, 0, 0, 0], [1, 2, 3, 4]])
values = [[0.25, 0.125], [0.25, 0.375], [0.75, 0.625], [0.75, 0.875]] values = [[0.25, 0.125], [0.25, 0.375], [0.75, 0.625], [0.75, 0.875]]
values = torch.FloatTensor(values) values = torch.FloatTensor(values).double()
adj = {'indices': edges.cuda(), 'values': Variable(values.cuda()), adj = {'indices': edges.cuda(), 'values': Variable(values.cuda()),
'size': torch.Size([5, 5, 2])} 'size': torch.Size([5, 5, 2])}
...@@ -23,11 +24,12 @@ class SplineConvTest(unittest.TestCase): ...@@ -23,11 +24,12 @@ class SplineConvTest(unittest.TestCase):
kernel_size = torch.cuda.LongTensor([3, 4]) kernel_size = torch.cuda.LongTensor([3, 4])
is_open_spline = torch.cuda.LongTensor([1, 0]) is_open_spline = torch.cuda.LongTensor([1, 0])
input = torch.FloatTensor([[9, 10], [1, 2], [3, 4], [5, 6], [7, 8]]) input = torch.FloatTensor([[9, 10], [1, 2], [3, 4], [5, 6], [7, 8]]).double()
weight = torch.arange(0.5, 0.5 * 27, step=0.5).view(13, 2, 1) weight = torch.arange(0.5, 0.5 * 27, step=0.5).view(13, 2, 1).double()
input, weight = input.cuda(), weight.cuda() input, weight = input.cuda(), weight.cuda()
input, weight = Variable(input), Variable(weight) input, weight = Variable(input), Variable(weight)
row, col = adj['indices']
output = input[col]
K = 12 K = 12
in_features = 2 in_features = 2
out_features = 1 out_features = 1
...@@ -43,9 +45,29 @@ class SplineConvTest(unittest.TestCase): ...@@ -43,9 +45,29 @@ class SplineConvTest(unittest.TestCase):
basis_bw_k = get_basis_backward_kernel(k_max, K, dim, degree) basis_bw_k = get_basis_backward_kernel(k_max, K, dim, degree)
output = spline_conv( #output = spline_conv(
adj, input, weight, kernel_size, is_open_spline, K, fw_k, bw_k, # adj, input, weight, kernel_size, is_open_spline, K, fw_k, bw_k,
basis_fw_k, basis_bw_k,bp_to_adj=True) # basis_fw_k, basis_bw_k,bp_to_adj=True)
values = adj['values']
output = SplineConvGPU(kernel_size, is_open_spline, K, degree,
basis_fw_k, basis_bw_k, fw_k, bw_k, bp_to_adj=True)\
(output, weight, values)
zero = output.data.new(adj['size'][1], output.size(1)).fill_(0.0)
zero = Variable(zero) if not torch.is_tensor(output) else zero
r = row.view(-1, 1).expand(row.size(0), output.size(1))
output = zero.scatter_add_(0, Variable(r), output)
# Weighten root node features by multiplying with root weight.
output += torch.mm(input, weight[-1])
# Normalize output by degree.
ones = values.data.new(values.size(0)).fill_(1)
zero = values.data.new(output.size(0)).fill_(0)
degree = zero.scatter_add_(0, row, ones)
degree = torch.clamp(degree, min=1)
output = output / Variable(degree.view(-1, 1))
expected_output = [ expected_output = [
[(12.5 * 9 + 13 * 10 + 266) / 4], [(12.5 * 9 + 13 * 10 + 266) / 4],
...@@ -56,14 +78,16 @@ class SplineConvTest(unittest.TestCase): ...@@ -56,14 +78,16 @@ class SplineConvTest(unittest.TestCase):
] ]
assert_almost_equal(output.cpu().data.numpy(), expected_output, 1) assert_almost_equal(output.cpu().data.numpy(), expected_output, 1)
@unittest.skipIf(not torch.cuda.is_available(), 'no GPU') @unittest.skipIf(not torch.cuda.is_available(), 'no GPU')
def test_backward(self): def test_backward(self):
kernel_size = torch.cuda.LongTensor([3, 4]) kernel_size = torch.cuda.LongTensor([3, 4])
is_open_spline = torch.cuda.LongTensor([1, 0]) is_open_spline = torch.cuda.LongTensor([1, 1])
input = torch.randn(4, 2).double().cuda() input = torch.randn(4, 2).double().cuda()
weight = torch.randn(12, 2, 1).double().cuda() weight = torch.randn(12, 2, 1).double().cuda()
values = torch.randn(4, 2).double().cuda() values = torch.FloatTensor(4, 2).uniform_(0, 1).double().cuda()
print(values)
input = Variable(input, requires_grad=True) input = Variable(input, requires_grad=True)
weight = Variable(weight, requires_grad=True) weight = Variable(weight, requires_grad=True)
values = Variable(values, requires_grad=True) values = Variable(values, requires_grad=True)
...@@ -84,7 +108,45 @@ class SplineConvTest(unittest.TestCase): ...@@ -84,7 +108,45 @@ class SplineConvTest(unittest.TestCase):
op = SplineConvGPU(kernel_size, is_open_spline, K, degree, op = SplineConvGPU(kernel_size, is_open_spline, K, degree,
basis_fw_k, basis_bw_k, fw_k, bw_k, bp_to_adj=True) basis_fw_k, basis_bw_k, fw_k, bw_k, bp_to_adj=True)
print(op(input, weight, values))
#test = gradcheck(op, (input, weight, values), eps=1e-6, atol=1e-4)
#self.assertTrue(test)
'''
@unittest.skipIf(not torch.cuda.is_available(), 'no GPU')
def test_backward(self):
input = torch.randn(4, 2).double().cuda()
weight = torch.randn(9, 2, 1).double().cuda()
values = torch.FloatTensor(4, 2).uniform_(0, 1).double().cuda()
print(values)
input = Variable(input, requires_grad=True)
weight = Variable(weight, requires_grad=True)
values = Variable(values, requires_grad=True)
K = 9
in_features = 2
out_features = 1
degree = 1
dim = 2
k_max = (degree + 1) ** dim
kernel_size = torch.cuda.LongTensor([3, 3])
is_open_spline = torch.cuda.LongTensor([1, 0])
fw_k = get_weighting_forward_kernel(in_features, out_features, k_max,
dtype='double')
bw_k = get_weighting_backward_kernel(in_features, out_features, k_max,
K, True,dtype='double')
basis_fw_k = get_basis_kernel(k_max, K, dim, degree, dtype='double')
basis_bw_k = get_basis_backward_kernel(k_max, K, dim, degree,
dtype='double')
op = SplineConvGPU(kernel_size, is_open_spline, K, degree,
basis_fw_k, basis_bw_k, fw_k, bw_k,
bp_to_adj=True)
#print(op(input, weight, values))
test = gradcheck(op, (input, weight, values), eps=1e-6, atol=1e-4) test = gradcheck(op, (input, weight, values), eps=1e-6, atol=1e-4)
print(test)
self.assertTrue(test) self.assertTrue(test)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment