Commit 7b0ec2c6 authored by rusty1s's avatar rusty1s
Browse files

update torch 0.4.0

parent 8d6a32c4
...@@ -40,7 +40,7 @@ output = SplineConv.apply(src, ...@@ -40,7 +40,7 @@ output = SplineConv.apply(src,
weight, weight,
kernel_size, kernel_size,
is_open_spline, is_open_spline,
degree, degree=1,
root_weight=None, root_weight=None,
bias=None) bias=None)
``` ```
...@@ -65,7 +65,7 @@ The kernel function is defined over the weighted B-spline tensor product basis, ...@@ -65,7 +65,7 @@ The kernel function is defined over the weighted B-spline tensor product basis,
* **weight** *(Tensor)* - Trainable weight parameters of shape `(kernel_size x in_channels x out_channels)`. * **weight** *(Tensor)* - Trainable weight parameters of shape `(kernel_size x in_channels x out_channels)`.
* **kernel_size** *(LongTensor)* - Number of trainable weight parameters in each edge dimension. * **kernel_size** *(LongTensor)* - Number of trainable weight parameters in each edge dimension.
* **is_open_spline** *(ByteTensor)* - Whether to use open or closed B-spline bases for each dimension. * **is_open_spline** *(ByteTensor)* - Whether to use open or closed B-spline bases for each dimension.
* **degree** *(Scalar)* - B-spline basis degree. * **degree** *(int, optional)* - B-spline basis degree. (default: `1`)
* **root_weight** *(Tensor, optional)* - Additional shared trainable parameters for each feature of the root node of shape `(in_channels x out_channels)`. (default: `None`) * **root_weight** *(Tensor, optional)* - Additional shared trainable parameters for each feature of the root node of shape `(in_channels x out_channels)`. (default: `None`)
* **bias** *(Tensor, optional)* - Optional bias of shape `(out_channels)`. (default: `None`) * **bias** *(Tensor, optional)* - Optional bias of shape `(out_channels)`. (default: `None`)
...@@ -85,7 +85,7 @@ pseudo = torch.Tensor(6, 2) # two-dimensional edge attributes ...@@ -85,7 +85,7 @@ pseudo = torch.Tensor(6, 2) # two-dimensional edge attributes
weight = torch.Tensor(25, 2, 4) # 25 trainable parameters for in_channels x out_channels weight = torch.Tensor(25, 2, 4) # 25 trainable parameters for in_channels x out_channels
kernel_size = torch.LongTensor([5, 5]) # 5 trainable parameters in each edge dimension kernel_size = torch.LongTensor([5, 5]) # 5 trainable parameters in each edge dimension
is_open_spline = torch.ByteTensor([1, 1]) # only use open B-splines is_open_spline = torch.ByteTensor([1, 1]) # only use open B-splines
degree = torch.tensor(1) # B-spline degree of 1 degree = 1 # B-spline degree of 1
root_weight = torch.Tensor(2, 4) # separately weight root nodes root_weight = torch.Tensor(2, 4) # separately weight root nodes
bias = None # do not apply an additional bias bias = None # do not apply an additional bias
......
import os.path as osp import os.path as osp
import shutil
import subprocess import subprocess
import torch import torch
from torch.utils.ffi import create_extension from torch.utils.ffi import create_extension
if osp.exists('build'):
shutil.rmtree('build')
files = ['Basis', 'Weighting'] files = ['Basis', 'Weighting']
headers = ['aten/TH/TH{}.h'.format(f) for f in files] headers = ['aten/TH/TH{}.h'.format(f) for f in files]
......
...@@ -2,9 +2,7 @@ from itertools import product ...@@ -2,9 +2,7 @@ from itertools import product
import pytest import pytest
import torch import torch
from torch.autograd import gradcheck
from torch_spline_conv.basis import SplineBasis from torch_spline_conv.basis import SplineBasis
from torch_spline_conv.utils.ffi import implemented_degrees as degrees
from .utils import dtypes, devices, tensor from .utils import dtypes, devices, tensor
...@@ -31,24 +29,10 @@ tests = [{ ...@@ -31,24 +29,10 @@ tests = [{
@pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices)) @pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
def test_spline_basis_forward(test, dtype, device): def test_spline_basis_forward(test, dtype, device):
degree = torch.tensor(1)
pseudo = tensor(test['pseudo'], dtype, device) pseudo = tensor(test['pseudo'], dtype, device)
kernel_size = tensor(test['kernel_size'], torch.long, device) kernel_size = tensor(test['kernel_size'], torch.long, device)
is_open_spline = tensor(test['is_open_spline'], torch.uint8, device) is_open_spline = tensor(test['is_open_spline'], torch.uint8, device)
basis, weight_index = SplineBasis.apply(degree, pseudo, kernel_size, basis, weight_idx = SplineBasis.apply(pseudo, kernel_size, is_open_spline)
is_open_spline)
assert basis.tolist() == test['basis'] assert basis.tolist() == test['basis']
assert weight_index.tolist() == test['weight_index'] assert weight_idx.tolist() == test['weight_index']
@pytest.mark.parametrize('degree,device', product(degrees.keys(), devices))
def test_spline_basis_backward(degree, device):
degree = torch.tensor(degree)
pseudo = torch.rand((4, 3), dtype=torch.double, device=device)
pseudo.requires_grad_()
kernel_size = tensor([5, 5, 5], torch.long, device)
is_open_spline = tensor([1, 0, 1], torch.uint8, device)
data = (degree, pseudo, kernel_size, is_open_spline)
# assert gradcheck(SplineBasis.apply, data, eps=1e-6, atol=1e-4) is True
...@@ -3,7 +3,7 @@ from itertools import product ...@@ -3,7 +3,7 @@ from itertools import product
import pytest import pytest
import torch import torch
from torch.autograd import gradcheck from torch.autograd import gradcheck
from torch_spline_conv import spline_conv from torch_spline_conv import SplineConv
from torch_spline_conv.utils.ffi import implemented_degrees as degrees from torch_spline_conv.utils.ffi import implemented_degrees as degrees
from .utils import dtypes, devices, tensor from .utils import dtypes, devices, tensor
...@@ -48,36 +48,30 @@ def test_spline_conv_forward(test, dtype, device): ...@@ -48,36 +48,30 @@ def test_spline_conv_forward(test, dtype, device):
weight = tensor(test['weight'], dtype, device) weight = tensor(test['weight'], dtype, device)
kernel_size = tensor(test['kernel_size'], torch.long, device) kernel_size = tensor(test['kernel_size'], torch.long, device)
is_open_spline = tensor(test['is_open_spline'], torch.uint8, device) is_open_spline = tensor(test['is_open_spline'], torch.uint8, device)
degree = torch.tensor(1)
root_weight = tensor(test['root_weight'], dtype, device) root_weight = tensor(test['root_weight'], dtype, device)
bias = tensor(test['bias'], dtype, device) bias = tensor(test['bias'], dtype, device)
output = spline_conv(src, edge_index, pseudo, weight, kernel_size, output = SplineConv.apply(src, edge_index, pseudo, weight, kernel_size,
is_open_spline, degree, root_weight, bias) is_open_spline, 1, root_weight, bias)
assert output.tolist() == test['output'] assert output.tolist() == test['output']
@pytest.mark.parametrize('degree,device', product(degrees.keys(), devices)) @pytest.mark.parametrize('degree,device', product(degrees.keys(), devices))
def test_spline_basis_backward(degree, device): def test_spline_basis_backward(degree, device):
pass src = torch.rand((3, 2), dtype=torch.double, device=device)
# src = torch.DoubleTensor(3, 2).uniform_(-1, 1) src.requires_grad_()
# edge_index = torch.LongTensor([[0, 1, 1, 2], [1, 0, 2, 1]]) edge_index = tensor([[0, 1, 1, 2], [1, 0, 2, 1]], torch.long, device)
# pseudo = torch.DoubleTensor(4, 3).uniform_(0, 1) pseudo = torch.rand((4, 3), dtype=torch.double, device=device)
# weight = torch.DoubleTensor(125, 2, 4).uniform_(-1, 1) pseudo.requires_grad_()
# kernel_size = torch.LongTensor([5, 5, 5]) weight = torch.rand((125, 2, 4), dtype=torch.double, device=device)
# is_open_spline = torch.ByteTensor([1, 0, 1]) weight.requires_grad_()
# root_weight = torch.DoubleTensor(2, 4).uniform_(-1, 1) kernel_size = tensor([5, 5, 5], torch.long, device)
# bias = torch.DoubleTensor(4).uniform_(-1, 1) is_open_spline = tensor([1, 0, 1], torch.uint8, device)
root_weight = torch.rand((2, 4), dtype=torch.double, device=device)
root_weight.requires_grad_()
bias = torch.rand((4), dtype=torch.double, device=device)
bias.requires_grad_()
# src = Variable(src, requires_grad=True) data = (src, edge_index, pseudo, weight, kernel_size, is_open_spline,
# pseudo = Variable(pseudo, requires_grad=True) degree, root_weight, bias)
# weight = Variable(weight, requires_grad=True) assert gradcheck(SplineConv.apply, data, eps=1e-6, atol=1e-4) is True
# root_weight = Variable(root_weight, requires_grad=True)
# bias = Variable(bias, requires_grad=True)
# def op(src, pseudo, weight, root_weight, bias):
# return spline_conv(src, edge_index, pseudo, weight, kernel_size,
# is_open_spline, degree, root_weight, bias)
# data = (src, pseudo, weight, root_weight, bias)
# assert gradcheck(op, data, eps=1e-6, atol=1e-4) is True
...@@ -33,19 +33,17 @@ def test_spline_weighting_forward(test, dtype, device): ...@@ -33,19 +33,17 @@ def test_spline_weighting_forward(test, dtype, device):
@pytest.mark.parametrize('device', devices) @pytest.mark.parametrize('device', devices)
def test_spline_basis_backward(device): def test_spline_basis_backward(device):
degree = torch.tensor(1)
pseudo = torch.rand((4, 2), dtype=torch.double, device=device) pseudo = torch.rand((4, 2), dtype=torch.double, device=device)
pseudo.requires_grad_() pseudo.requires_grad_()
kernel_size = tensor([5, 5], torch.long, device) kernel_size = tensor([5, 5], torch.long, device)
is_open_spline = tensor([1, 1], torch.uint8, device) is_open_spline = tensor([1, 1], torch.uint8, device)
basis, weight_index = SplineBasis.apply(degree, pseudo, kernel_size, basis, weight_idx = SplineBasis.apply(pseudo, kernel_size, is_open_spline)
is_open_spline)
src = torch.rand((4, 2), dtype=torch.double, device=device) src = torch.rand((4, 2), dtype=torch.double, device=device)
src.requires_grad_() src.requires_grad_()
weight = torch.rand((25, 2, 4), dtype=torch.double, device=device) weight = torch.rand((25, 2, 4), dtype=torch.double, device=device)
weight.requires_grad_() weight.requires_grad_()
data = (src, weight, basis, weight_index) data = (src, weight, basis, weight_idx)
assert gradcheck(SplineWeighting.apply, data, eps=1e-6, atol=1e-4) is True assert gradcheck(SplineWeighting.apply, data, eps=1e-6, atol=1e-4) is True
...@@ -20,17 +20,20 @@ def bw(degree, grad_basis, pseudo, kernel_size, is_open_spline): ...@@ -20,17 +20,20 @@ def bw(degree, grad_basis, pseudo, kernel_size, is_open_spline):
class SplineBasis(Function): class SplineBasis(Function):
@staticmethod @staticmethod
def forward(ctx, degree, pseudo, kernel_size, is_open_spline): def forward(ctx, pseudo, kernel_size, is_open_spline, degree=1):
ctx.save_for_backward(degree, pseudo, kernel_size, is_open_spline) ctx.save_for_backward(pseudo)
return fw(degree.item(), pseudo, kernel_size, is_open_spline) ctx.kernel_size = kernel_size
ctx.is_open_spline = is_open_spline
ctx.degree = degree
return fw(degree, pseudo, kernel_size, is_open_spline)
@staticmethod @staticmethod
def backward(ctx, grad_basis, grad_weight_index): def backward(ctx, grad_basis, grad_weight_index):
degree, pseudo, kernel_size, is_open_spline = ctx.saved_tensors pseudo, = ctx.saved_tensors
grad_pseudo = None grad_pseudo = None
if ctx.needs_input_grad[1]: if ctx.needs_input_grad[0]:
grad_pseudo = bw(degree.item(), grad_basis, pseudo, kernel_size, grad_pseudo = bw(ctx.degree, grad_basis, pseudo, ctx.kernel_size,
is_open_spline) ctx.is_open_spline)
return None, grad_pseudo, None, None return grad_pseudo, None, None, None
import torch import torch
from torch.autograd import Function
from .basis import SplineBasis from .basis import SplineBasis
from .weighting import SplineWeighting from .weighting import SplineWeighting
...@@ -7,7 +6,7 @@ from .weighting import SplineWeighting ...@@ -7,7 +6,7 @@ from .weighting import SplineWeighting
from .utils.degree import degree as node_degree from .utils.degree import degree as node_degree
class SplineConv(Function): class SplineConv(object):
"""Applies the spline-based convolution operator :math:`(f \star g)(i) = """Applies the spline-based convolution operator :math:`(f \star g)(i) =
\frac{1}{|\mathcal{N}(i)|} \sum_{l=1}^{M_{in}} \sum_{j \in \mathcal{N}(i)} \frac{1}{|\mathcal{N}(i)|} \sum_{l=1}^{M_{in}} \sum_{j \in \mathcal{N}(i)}
f_l(j) \cdot g_l(u(i, j))` over several node features of an input graph. f_l(j) \cdot g_l(u(i, j))` over several node features of an input graph.
...@@ -28,7 +27,7 @@ class SplineConv(Function): ...@@ -28,7 +27,7 @@ class SplineConv(Function):
parameters in each edge dimension. parameters in each edge dimension.
is_open_spline (:class:`ByteTensor`): Whether to use open or closed is_open_spline (:class:`ByteTensor`): Whether to use open or closed
B-spline bases for each dimension. B-spline bases for each dimension.
degree (:class:`Scalar`): B-spline basis degree. degree (int, optional): B-spline basis degree. (default: :obj:`1`)
root_weight (:class:`Tensor`, optional): Additional shared trainable root_weight (:class:`Tensor`, optional): Additional shared trainable
parameters for each feature of the root node of shape parameters for each feature of the root node of shape
(in_channels x out_channels). (default: :obj:`None`) (in_channels x out_channels). (default: :obj:`None`)
...@@ -39,16 +38,15 @@ class SplineConv(Function): ...@@ -39,16 +38,15 @@ class SplineConv(Function):
""" """
@staticmethod @staticmethod
def forward(ctx, def apply(src,
src, edge_index,
edge_index, pseudo,
pseudo, weight,
weight, kernel_size,
kernel_size, is_open_spline,
is_open_spline, degree=1,
degree, root_weight=None,
root_weight=None, bias=None):
bias=None):
src = src.unsqueeze(-1) if src.dim() == 1 else src src = src.unsqueeze(-1) if src.dim() == 1 else src
pseudo = pseudo.unsqueeze(-1) if pseudo.dim() == 1 else pseudo pseudo = pseudo.unsqueeze(-1) if pseudo.dim() == 1 else pseudo
...@@ -57,8 +55,8 @@ class SplineConv(Function): ...@@ -57,8 +55,8 @@ class SplineConv(Function):
n, m_out = src.size(0), weight.size(2) n, m_out = src.size(0), weight.size(2)
# Weight each node. # Weight each node.
b, wi = SplineBasis.apply(degree, pseudo, kernel_size, is_open_spline) data = SplineBasis.apply(pseudo, kernel_size, is_open_spline, degree)
output = SplineWeighting.apply(src[col], weight, b, wi) output = SplineWeighting.apply(src[col], weight, *data)
# Convert e x m_out to n x m_out features. # Convert e x m_out to n x m_out features.
row_expand = row.unsqueeze(-1).expand_as(output) row_expand = row.unsqueeze(-1).expand_as(output)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment