Commit 80a45518 authored by rusty1s's avatar rusty1s
Browse files

normalize arg

parent 50011f1e
...@@ -41,6 +41,7 @@ out = SplineConv.apply(src, ...@@ -41,6 +41,7 @@ out = SplineConv.apply(src,
kernel_size, kernel_size,
is_open_spline, is_open_spline,
degree=1, degree=1,
norm=True,
root_weight=None, root_weight=None,
bias=None) bias=None)
``` ```
...@@ -66,6 +67,7 @@ The kernel function is defined over the weighted B-spline tensor product basis, ...@@ -66,6 +67,7 @@ The kernel function is defined over the weighted B-spline tensor product basis,
* **kernel_size** *(LongTensor)* - Number of trainable weight parameters in each edge dimension. * **kernel_size** *(LongTensor)* - Number of trainable weight parameters in each edge dimension.
* **is_open_spline** *(ByteTensor)* - Whether to use open or closed B-spline bases for each dimension. * **is_open_spline** *(ByteTensor)* - Whether to use open or closed B-spline bases for each dimension.
* **degree** *(int, optional)* - B-spline basis degree. (default: `1`) * **degree** *(int, optional)* - B-spline basis degree. (default: `1`)
* **norm** *(bool, optional)*: Whether to normalize output by node degree. (default: `True`)
* **root_weight** *(Tensor, optional)* - Additional shared trainable parameters for each feature of the root node of shape `(in_channels x out_channels)`. (default: `None`) * **root_weight** *(Tensor, optional)* - Additional shared trainable parameters for each feature of the root node of shape `(in_channels x out_channels)`. (default: `None`)
* **bias** *(Tensor, optional)* - Optional bias of shape `(out_channels)`. (default: `None`) * **bias** *(Tensor, optional)* - Optional bias of shape `(out_channels)`. (default: `None`)
...@@ -86,11 +88,12 @@ weight = torch.rand((25, 2, 4), dtype=torch.float) # 25 parameters for in_chann ...@@ -86,11 +88,12 @@ weight = torch.rand((25, 2, 4), dtype=torch.float) # 25 parameters for in_chann
kernel_size = torch.tensor([5, 5]) # 5 parameters in each edge dimension kernel_size = torch.tensor([5, 5]) # 5 parameters in each edge dimension
is_open_spline = torch.tensor([1, 1], dtype=torch.uint8) # only use open B-splines is_open_spline = torch.tensor([1, 1], dtype=torch.uint8) # only use open B-splines
degree = 1 # B-spline degree of 1 degree = 1 # B-spline degree of 1
norm = True # Normalize output by node degree.
root_weight = torch.rand((2, 4), dtype=torch.float) # separately weight root nodes root_weight = torch.rand((2, 4), dtype=torch.float) # separately weight root nodes
bias = None # do not apply an additional bias bias = None # do not apply an additional bias
out = SplineConv.apply(src, edge_index, pseudo, weight, kernel_size, out = SplineConv.apply(src, edge_index, pseudo, weight, kernel_size,
is_open_spline, degree, root_weight, bias) is_open_spline, degree, norm root_weight, bias)
print(out.size()) print(out.size())
torch.Size([4, 4]) # 4 nodes with 4 features each torch.Size([4, 4]) # 4 nodes with 4 features each
......
...@@ -2,7 +2,7 @@ from os import path as osp ...@@ -2,7 +2,7 @@ from os import path as osp
from setuptools import setup, find_packages from setuptools import setup, find_packages
__version__ = '1.0.2' __version__ = '1.0.3'
url = 'https://github.com/rusty1s/pytorch_spline_conv' url = 'https://github.com/rusty1s/pytorch_spline_conv'
install_requires = ['cffi'] install_requires = ['cffi']
......
...@@ -31,11 +31,11 @@ tests = [{ ...@@ -31,11 +31,11 @@ tests = [{
'root_weight': [[12.5], [13]], 'root_weight': [[12.5], [13]],
'bias': [1], 'bias': [1],
'expected': [ 'expected': [
1 + (12.5 * 9 + 13 * 10 + 8.5 + 40.5 + 107.5 + 101.5) / 5, [1 + 12.5 * 9 + 13 * 10 + (8.5 + 40.5 + 107.5 + 101.5) / 4],
1 + 12.5 * 1 + 13 * 2, [1 + 12.5 * 1 + 13 * 2],
1 + 12.5 * 3 + 13 * 4, [1 + 12.5 * 3 + 13 * 4],
1 + 12.5 * 5 + 13 * 6, [1 + 12.5 * 5 + 13 * 6],
1 + 12.5 * 7 + 13 * 8, [1 + 12.5 * 7 + 13 * 8],
] ]
}] }]
...@@ -52,9 +52,8 @@ def test_spline_conv_forward(test, dtype, device): ...@@ -52,9 +52,8 @@ def test_spline_conv_forward(test, dtype, device):
bias = tensor(test['bias'], dtype, device) bias = tensor(test['bias'], dtype, device)
out = SplineConv.apply(src, edge_index, pseudo, weight, kernel_size, out = SplineConv.apply(src, edge_index, pseudo, weight, kernel_size,
is_open_spline, 1, root_weight, bias) is_open_spline, 1, True, root_weight, bias)
assert list(out.size()) == [5, 1] assert out.tolist() == test['expected']
assert pytest.approx(out.view(-1).tolist()) == test['expected']
@pytest.mark.parametrize('degree,device', product(degrees.keys(), devices)) @pytest.mark.parametrize('degree,device', product(degrees.keys(), devices))
...@@ -74,5 +73,5 @@ def test_spline_basis_backward(degree, device): ...@@ -74,5 +73,5 @@ def test_spline_basis_backward(degree, device):
bias.requires_grad_() bias.requires_grad_()
data = (src, edge_index, pseudo, weight, kernel_size, is_open_spline, data = (src, edge_index, pseudo, weight, kernel_size, is_open_spline,
degree, root_weight, bias) degree, True, root_weight, bias)
assert gradcheck(SplineConv.apply, data, eps=1e-6, atol=1e-4) is True assert gradcheck(SplineConv.apply, data, eps=1e-6, atol=1e-4) is True
...@@ -2,6 +2,6 @@ from .basis import SplineBasis ...@@ -2,6 +2,6 @@ from .basis import SplineBasis
from .weighting import SplineWeighting from .weighting import SplineWeighting
from .conv import SplineConv from .conv import SplineConv
__version__ = '1.0.2' __version__ = '1.0.3'
__all__ = ['SplineBasis', 'SplineWeighting', 'SplineConv', '__version__'] __all__ = ['SplineBasis', 'SplineWeighting', 'SplineConv', '__version__']
...@@ -28,6 +28,8 @@ class SplineConv(object): ...@@ -28,6 +28,8 @@ class SplineConv(object):
is_open_spline (:class:`ByteTensor`): Whether to use open or closed is_open_spline (:class:`ByteTensor`): Whether to use open or closed
B-spline bases for each dimension. B-spline bases for each dimension.
degree (int, optional): B-spline basis degree. (default: :obj:`1`) degree (int, optional): B-spline basis degree. (default: :obj:`1`)
norm (bool, optional): Whether to normalize output by node degree.
(default: :obj:`True`)
root_weight (:class:`Tensor`, optional): Additional shared trainable root_weight (:class:`Tensor`, optional): Additional shared trainable
parameters for each feature of the root node of shape parameters for each feature of the root node of shape
(in_channels x out_channels). (default: :obj:`None`) (in_channels x out_channels). (default: :obj:`None`)
...@@ -45,6 +47,7 @@ class SplineConv(object): ...@@ -45,6 +47,7 @@ class SplineConv(object):
kernel_size, kernel_size,
is_open_spline, is_open_spline,
degree=1, degree=1,
norm=True,
root_weight=None, root_weight=None,
bias=None): bias=None):
...@@ -62,15 +65,14 @@ class SplineConv(object): ...@@ -62,15 +65,14 @@ class SplineConv(object):
row_expand = row.unsqueeze(-1).expand_as(out) row_expand = row.unsqueeze(-1).expand_as(out)
out = src.new_zeros((n, m_out)).scatter_add_(0, row_expand, out) out = src.new_zeros((n, m_out)).scatter_add_(0, row_expand, out)
deg = node_degree(row, n, out.dtype, out.device) # Normalize out by node degree (if wished).
if norm:
deg = node_degree(row, n, out.dtype, out.device)
out = out / deg.unsqueeze(-1).clamp(min=1)
# Weight root node separately (if wished). # Weight root node separately (if wished).
if root_weight is not None: if root_weight is not None:
out += torch.mm(src, root_weight) out += torch.mm(src, root_weight)
deg += 1
# Normalize out by node degree.
out /= deg.unsqueeze(-1).clamp(min=1)
# Add bias (if wished). # Add bias (if wished).
if bias is not None: if bias is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment