Commit 7b0ec2c6 authored by rusty1s

update torch 0.4.0

parent 8d6a32c4
@@ -40,7 +40,7 @@ output = SplineConv.apply(src,
                          weight,
                          kernel_size,
                          is_open_spline,
-                         degree,
+                         degree=1,
                          root_weight=None,
                          bias=None)
```
@@ -65,7 +65,7 @@ The kernel function is defined over the weighted B-spline tensor product basis,
* **weight** *(Tensor)* - Trainable weight parameters of shape `(kernel_size x in_channels x out_channels)`.
* **kernel_size** *(LongTensor)* - Number of trainable weight parameters in each edge dimension.
* **is_open_spline** *(ByteTensor)* - Whether to use open or closed B-spline bases for each dimension.
-* **degree** *(Scalar)* - B-spline basis degree.
+* **degree** *(int, optional)* - B-spline basis degree. (default: `1`)
* **root_weight** *(Tensor, optional)* - Additional shared trainable parameters for each feature of the root node of shape `(in_channels x out_channels)`. (default: `None`)
* **bias** *(Tensor, optional)* - Optional bias of shape `(out_channels)`. (default: `None`)
@@ -85,7 +85,7 @@ pseudo = torch.Tensor(6, 2) # two-dimensional edge attributes
weight = torch.Tensor(25, 2, 4) # 25 trainable parameters for in_channels x out_channels
kernel_size = torch.LongTensor([5, 5]) # 5 trainable parameters in each edge dimension
is_open_spline = torch.ByteTensor([1, 1]) # only use open B-splines
-degree = torch.tensor(1) # B-spline degree of 1
+degree = 1 # B-spline degree of 1
root_weight = torch.Tensor(2, 4) # separately weight root nodes
bias = None # do not apply an additional bias
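Putting the pieces together, the updated example now passes `degree` as a plain Python `int`; since it defaults to `1`, the argument could equally be omitted. A minimal sketch of the resulting call, assuming `src` and `edge_index` are set up as in the full README example:

```python
output = SplineConv.apply(src, edge_index, pseudo, weight, kernel_size,
                          is_open_spline, degree, root_weight, bias)
```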
......
import os.path as osp
import shutil
import subprocess
import torch
from torch.utils.ffi import create_extension
if osp.exists('build'):
    shutil.rmtree('build')
files = ['Basis', 'Weighting']
headers = ['aten/TH/TH{}.h'.format(f) for f in files]
......
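The build script is truncated above. For orientation, a hedged sketch of how the collected headers are typically handed to `torch.utils.ffi.create_extension`; the extension name and `.c` source paths below are hypothetical, not taken from this commit:

```python
# Hypothetical continuation of the build script shown above.
sources = ['aten/TH/TH{}.c'.format(f) for f in files]  # assumed source layout

ffi = create_extension(
    name='torch_spline_conv._ext.ffi',  # hypothetical extension name
    headers=headers,
    sources=sources,
    relative_to=__file__,
    with_cuda=False)

if __name__ == '__main__':
    ffi.build()
```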
@@ -2,9 +2,7 @@ from itertools import product
import pytest
import torch
-from torch.autograd import gradcheck
from torch_spline_conv.basis import SplineBasis
-from torch_spline_conv.utils.ffi import implemented_degrees as degrees
from .utils import dtypes, devices, tensor
@@ -31,24 +29,10 @@ tests = [{
@pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
def test_spline_basis_forward(test, dtype, device):
-    degree = torch.tensor(1)
    pseudo = tensor(test['pseudo'], dtype, device)
    kernel_size = tensor(test['kernel_size'], torch.long, device)
    is_open_spline = tensor(test['is_open_spline'], torch.uint8, device)
-    basis, weight_index = SplineBasis.apply(degree, pseudo, kernel_size,
-                                            is_open_spline)
+    basis, weight_idx = SplineBasis.apply(pseudo, kernel_size, is_open_spline)
    assert basis.tolist() == test['basis']
-    assert weight_index.tolist() == test['weight_index']
-@pytest.mark.parametrize('degree,device', product(degrees.keys(), devices))
-def test_spline_basis_backward(degree, device):
-    degree = torch.tensor(degree)
-    pseudo = torch.rand((4, 3), dtype=torch.double, device=device)
-    pseudo.requires_grad_()
-    kernel_size = tensor([5, 5, 5], torch.long, device)
-    is_open_spline = tensor([1, 0, 1], torch.uint8, device)
-    data = (degree, pseudo, kernel_size, is_open_spline)
-    # assert gradcheck(SplineBasis.apply, data, eps=1e-6, atol=1e-4) is True
+    assert weight_idx.tolist() == test['weight_index']
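For reference outside the test harness, the new `SplineBasis.apply` call order puts the tensors first and `degree` last as an optional plain `int`. A minimal sketch with assumed shapes (six edges, two pseudo-dimensions, five knots per dimension):

```python
import torch
from torch_spline_conv.basis import SplineBasis

pseudo = torch.rand(6, 2)                  # 6 edges, 2 pseudo-dimensions
kernel_size = torch.LongTensor([5, 5])     # knots per dimension
is_open_spline = torch.ByteTensor([1, 1])  # open B-splines in both dimensions
basis, weight_idx = SplineBasis.apply(pseudo, kernel_size, is_open_spline, 1)
```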
@@ -3,7 +3,7 @@ from itertools import product
import pytest
import torch
from torch.autograd import gradcheck
-from torch_spline_conv import spline_conv
+from torch_spline_conv import SplineConv
from torch_spline_conv.utils.ffi import implemented_degrees as degrees
from .utils import dtypes, devices, tensor
@@ -48,36 +48,30 @@ def test_spline_conv_forward(test, dtype, device):
    weight = tensor(test['weight'], dtype, device)
    kernel_size = tensor(test['kernel_size'], torch.long, device)
    is_open_spline = tensor(test['is_open_spline'], torch.uint8, device)
-    degree = torch.tensor(1)
    root_weight = tensor(test['root_weight'], dtype, device)
    bias = tensor(test['bias'], dtype, device)
-    output = spline_conv(src, edge_index, pseudo, weight, kernel_size,
-                         is_open_spline, degree, root_weight, bias)
+    output = SplineConv.apply(src, edge_index, pseudo, weight, kernel_size,
+                              is_open_spline, 1, root_weight, bias)
    assert output.tolist() == test['output']


@pytest.mark.parametrize('degree,device', product(degrees.keys(), devices))
def test_spline_basis_backward(degree, device):
-    pass
-    # src = torch.DoubleTensor(3, 2).uniform_(-1, 1)
-    # edge_index = torch.LongTensor([[0, 1, 1, 2], [1, 0, 2, 1]])
-    # pseudo = torch.DoubleTensor(4, 3).uniform_(0, 1)
-    # weight = torch.DoubleTensor(125, 2, 4).uniform_(-1, 1)
-    # kernel_size = torch.LongTensor([5, 5, 5])
-    # is_open_spline = torch.ByteTensor([1, 0, 1])
-    # root_weight = torch.DoubleTensor(2, 4).uniform_(-1, 1)
-    # bias = torch.DoubleTensor(4).uniform_(-1, 1)
+    src = torch.rand((3, 2), dtype=torch.double, device=device)
+    src.requires_grad_()
+    edge_index = tensor([[0, 1, 1, 2], [1, 0, 2, 1]], torch.long, device)
+    pseudo = torch.rand((4, 3), dtype=torch.double, device=device)
+    pseudo.requires_grad_()
+    weight = torch.rand((125, 2, 4), dtype=torch.double, device=device)
+    weight.requires_grad_()
+    kernel_size = tensor([5, 5, 5], torch.long, device)
+    is_open_spline = tensor([1, 0, 1], torch.uint8, device)
+    root_weight = torch.rand((2, 4), dtype=torch.double, device=device)
+    root_weight.requires_grad_()
+    bias = torch.rand((4), dtype=torch.double, device=device)
+    bias.requires_grad_()
-    # src = Variable(src, requires_grad=True)
-    # pseudo = Variable(pseudo, requires_grad=True)
-    # weight = Variable(weight, requires_grad=True)
-    # root_weight = Variable(root_weight, requires_grad=True)
-    # bias = Variable(bias, requires_grad=True)
-    # def op(src, pseudo, weight, root_weight, bias):
-    #     return spline_conv(src, edge_index, pseudo, weight, kernel_size,
-    #                        is_open_spline, degree, root_weight, bias)
-    # data = (src, pseudo, weight, root_weight, bias)
-    # assert gradcheck(op, data, eps=1e-6, atol=1e-4) is True
+    data = (src, edge_index, pseudo, weight, kernel_size, is_open_spline,
+            degree, root_weight, bias)
+    assert gradcheck(SplineConv.apply, data, eps=1e-6, atol=1e-4) is True
@@ -33,19 +33,17 @@ def test_spline_weighting_forward(test, dtype, device):
@pytest.mark.parametrize('device', devices)
def test_spline_basis_backward(device):
-    degree = torch.tensor(1)
    pseudo = torch.rand((4, 2), dtype=torch.double, device=device)
    pseudo.requires_grad_()
    kernel_size = tensor([5, 5], torch.long, device)
    is_open_spline = tensor([1, 1], torch.uint8, device)
-    basis, weight_index = SplineBasis.apply(degree, pseudo, kernel_size,
-                                            is_open_spline)
+    basis, weight_idx = SplineBasis.apply(pseudo, kernel_size, is_open_spline)
    src = torch.rand((4, 2), dtype=torch.double, device=device)
    src.requires_grad_()
    weight = torch.rand((25, 2, 4), dtype=torch.double, device=device)
    weight.requires_grad_()
-    data = (src, weight, basis, weight_index)
+    data = (src, weight, basis, weight_idx)
    assert gradcheck(SplineWeighting.apply, data, eps=1e-6, atol=1e-4) is True
@@ -20,17 +20,20 @@ def bw(degree, grad_basis, pseudo, kernel_size, is_open_spline):
class SplineBasis(Function):
    @staticmethod
-    def forward(ctx, degree, pseudo, kernel_size, is_open_spline):
-        ctx.save_for_backward(degree, pseudo, kernel_size, is_open_spline)
-        return fw(degree.item(), pseudo, kernel_size, is_open_spline)
+    def forward(ctx, pseudo, kernel_size, is_open_spline, degree=1):
+        ctx.save_for_backward(pseudo)
+        ctx.kernel_size = kernel_size
+        ctx.is_open_spline = is_open_spline
+        ctx.degree = degree
+        return fw(degree, pseudo, kernel_size, is_open_spline)

    @staticmethod
    def backward(ctx, grad_basis, grad_weight_index):
-        degree, pseudo, kernel_size, is_open_spline = ctx.saved_tensors
+        pseudo, = ctx.saved_tensors
        grad_pseudo = None
-        if ctx.needs_input_grad[1]:
-            grad_pseudo = bw(degree.item(), grad_basis, pseudo, kernel_size,
-                             is_open_spline)
+        if ctx.needs_input_grad[0]:
+            grad_pseudo = bw(ctx.degree, grad_basis, pseudo, ctx.kernel_size,
+                             ctx.is_open_spline)
-        return None, grad_pseudo, None, None
+        return grad_pseudo, None, None, None
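The rewrite follows the torch 0.4.0 autograd contract: `ctx.save_for_backward` is reserved for tensors, non-tensor arguments are stashed as plain attributes on `ctx`, and `backward` returns one gradient per `forward` argument in the new argument order (hence `grad_pseudo` moving to the first slot). A minimal toy `Function` illustrating the same pattern, not part of this commit:

```python
import torch
from torch.autograd import Function


class ScaleByDegree(Function):
    """Toy example: scale the input by a non-tensor `degree` argument."""

    @staticmethod
    def forward(ctx, pseudo, degree=1):
        ctx.save_for_backward(pseudo)  # tensors go through save_for_backward
        ctx.degree = degree            # non-tensor args live on ctx directly
        return pseudo * degree

    @staticmethod
    def backward(ctx, grad_output):
        pseudo, = ctx.saved_tensors
        grad_pseudo = None
        if ctx.needs_input_grad[0]:
            grad_pseudo = grad_output * ctx.degree
        # One return value per forward argument; None for the plain int.
        return grad_pseudo, None
```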
import torch
-from torch.autograd import Function
from .basis import SplineBasis
from .weighting import SplineWeighting
@@ -7,7 +6,7 @@ from .weighting import SplineWeighting
from .utils.degree import degree as node_degree
-class SplineConv(Function):
+class SplineConv(object):
    """Applies the spline-based convolution operator :math:`(f \star g)(i) =
    \frac{1}{|\mathcal{N}(i)|} \sum_{l=1}^{M_{in}} \sum_{j \in \mathcal{N}(i)}
    f_l(j) \cdot g_l(u(i, j))` over several node features of an input graph.
@@ -28,7 +27,7 @@ class SplineConv(Function):
            parameters in each edge dimension.
        is_open_spline (:class:`ByteTensor`): Whether to use open or closed
            B-spline bases for each dimension.
-        degree (:class:`Scalar`): B-spline basis degree.
+        degree (int, optional): B-spline basis degree. (default: :obj:`1`)
        root_weight (:class:`Tensor`, optional): Additional shared trainable
            parameters for each feature of the root node of shape
            (in_channels x out_channels). (default: :obj:`None`)
@@ -39,16 +38,15 @@ class SplineConv(Function):
"""
@staticmethod
def forward(ctx,
src,
edge_index,
pseudo,
weight,
kernel_size,
is_open_spline,
degree,
root_weight=None,
bias=None):
def apply(src,
edge_index,
pseudo,
weight,
kernel_size,
is_open_spline,
degree=1,
root_weight=None,
bias=None):
src = src.unsqueeze(-1) if src.dim() == 1 else src
pseudo = pseudo.unsqueeze(-1) if pseudo.dim() == 1 else pseudo
@@ -57,8 +55,8 @@ class SplineConv(Function):
        n, m_out = src.size(0), weight.size(2)

        # Weight each node.
-        b, wi = SplineBasis.apply(degree, pseudo, kernel_size, is_open_spline)
-        output = SplineWeighting.apply(src[col], weight, b, wi)
+        data = SplineBasis.apply(pseudo, kernel_size, is_open_spline, degree)
+        output = SplineWeighting.apply(src[col], weight, *data)

        # Convert e x m_out to n x m_out features.
        row_expand = row.unsqueeze(-1).expand_as(output)
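Exercised directly, the two-step composition inside `apply` can be sketched as follows, with assumed shapes (`col` holding the source node of each of four edges, two input and four output channels); this is an illustration, not code from this commit:

```python
import torch
from torch_spline_conv.basis import SplineBasis
from torch_spline_conv.weighting import SplineWeighting

src = torch.rand(3, 2)                 # 3 nodes, 2 input channels
col = torch.LongTensor([1, 0, 2, 1])   # source node of each edge
pseudo = torch.rand(4, 2)              # one pseudo-coordinate per edge
weight = torch.rand(25, 2, 4)          # (kernel_size x in_ch x out_ch)
kernel_size = torch.LongTensor([5, 5])
is_open_spline = torch.ByteTensor([1, 1])

# Step 1: B-spline basis products and weight indices, one row per edge.
data = SplineBasis.apply(pseudo, kernel_size, is_open_spline, 1)
# Step 2: weight the gathered neighbor features -> (num_edges x out_ch).
output = SplineWeighting.apply(src[col], weight, *data)
```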
......