Commit 03f4110b authored by rusty1s's avatar rusty1s
Browse files

gradcheck

parent dff93289
......@@ -2,18 +2,20 @@ from itertools import product
import pytest
import torch
from torch_spline_conv.basis import basis_forward
from torch.autograd import Variable, gradcheck
from torch_spline_conv.basis import spline_basis, SplineBasis
from torch_spline_conv.utils.ffi import implemented_degrees
from .tensor import tensors
tests = [{
'pseudo': [0, 0.0625, 0.25, 0.75, 0.9375, 1],
'pseudo': [[0], [0.0625], [0.25], [0.75], [0.9375], [1]],
'kernel_size': [5],
'is_open_spline': [1],
'basis': [[1, 0], [0.75, 0.25], [1, 0], [1, 0], [0.25, 0.75], [1, 0]],
'weight_index': [[0, 1], [0, 1], [1, 2], [3, 4], [3, 4], [4, 0]],
}, {
'pseudo': [0, 0.0625, 0.25, 0.75, 0.9375, 1],
'pseudo': [[0], [0.0625], [0.25], [0.75], [0.9375], [1]],
'kernel_size': [4],
'is_open_spline': [0],
'basis': [[1, 0], [0.75, 0.25], [1, 0], [1, 0], [0.25, 0.75], [1, 0]],
......@@ -28,27 +30,39 @@ tests = [{
@pytest.mark.parametrize('tensor,i', product(tensors, range(len(tests))))
def test_basis_forward_cpu(tensor, i):
def test_spline_basis_cpu(tensor, i):
data = tests[i]
pseudo = getattr(torch, tensor)(data['pseudo'])
kernel_size = torch.LongTensor(data['kernel_size'])
is_open_spline = torch.ByteTensor(data['is_open_spline'])
basis, weight_index = basis_forward(1, pseudo, kernel_size, is_open_spline)
basis, weight_index = spline_basis(1, pseudo, kernel_size, is_open_spline)
assert basis.tolist() == data['basis']
assert weight_index.tolist() == data['weight_index']
@pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA')
@pytest.mark.parametrize('tensor,i', product(tensors, range(len(tests))))
def test_basis_forward_gpu(tensor, i): # pragma: no cover
def test_spline_basis_gpu(tensor, i): # pragma: no cover
data = tests[i]
pseudo = getattr(torch.cuda, tensor)(data['pseudo'])
kernel_size = torch.cuda.LongTensor(data['kernel_size'])
is_open_spline = torch.cuda.ByteTensor(data['is_open_spline'])
basis, weight_index = basis_forward(1, pseudo, kernel_size, is_open_spline)
basis, weight_index = spline_basis(1, pseudo, kernel_size, is_open_spline)
assert basis.cpu().tolist() == data['basis']
assert weight_index.cpu().tolist() == data['weight_index']
def test_spline_basis_grad_cpu():
degree = 1
kernel_size = torch.LongTensor([5, 5, 5])
is_open_spline = torch.ByteTensor([1, 0, 1])
op = SplineBasis(degree, kernel_size, is_open_spline)
pseudo = torch.DoubleTensor(4, 3).uniform_(0, 1)
pseudo = Variable(pseudo, requires_grad=True)
assert gradcheck(op, (pseudo, ), eps=1e-6, atol=1e-4) is True
......@@ -6,7 +6,6 @@ from .utils.ffi import basis_backward as ffi_basis_backward
def basis_forward(degree, pseudo, kernel_size, is_open_spline):
pseudo = pseudo.unsqueeze(-1) if pseudo.dim() == 1 else pseudo
num_nodes, S = pseudo.size(0), (degree + 1)**kernel_size.size(0)
basis = pseudo.new(num_nodes, S)
weight_index = kernel_size.new(num_nodes, S)
......@@ -17,28 +16,36 @@ def basis_forward(degree, pseudo, kernel_size, is_open_spline):
def basis_backward(degree, grad_basis, pseudo, kernel_size, is_open_spline):
grad_pseudo = pseudo.new(pseudo.size())
ffi_basis_backward(degree, grad_pseudo, pseudo, kernel_size,
ffi_basis_backward(degree, grad_pseudo, grad_basis, pseudo, kernel_size,
is_open_spline)
return grad_pseudo
class Basis(Function):
class SplineBasis(Function):
def __init__(self, degree, kernel_size, is_open_spline):
super(Basis, self).__init__()
super(SplineBasis, self).__init__()
self.degree = degree
self.kernel_size = kernel_size
self.is_open_spline = is_open_spline
def forward(self, pseudo):
self.save_for_backawrd(pseudo)
self.save_for_backward(pseudo)
return basis_forward(self.degree, pseudo, self.kernel_size,
self.is_open_spline)
def backward(self, grad_basis, grad_weight_index):
pass
pseudo, = self.saved_tensors
grad_pseudo = None
if self.needs_input_grad[0]:
grad_pseudo = basis_backward(self.degree, grad_basis, pseudo,
self.kernel_size, self.is_open_spline)
def basis(degree, pseudo, kernel_size, is_open_spline):
return grad_pseudo
def spline_basis(degree, pseudo, kernel_size, is_open_spline):
if torch.is_tensor(pseudo):
return basis_forward(degree, pseudo, kernel_size, is_open_spline)
else:
return Basis(degree, kernel_size, is_open_spline)(pseudo)
return SplineBasis(degree, kernel_size, is_open_spline)(pseudo)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment