Commit 80a45518 authored by rusty1s

normalize arg

parent 50011f1e
@@ -41,6 +41,7 @@ out = SplineConv.apply(src,
                        edge_index,
                        pseudo,
                        weight,
                        kernel_size,
                        is_open_spline,
                        degree=1,
+                       norm=True,
                        root_weight=None,
                        bias=None)
 ```
@@ -66,6 +67,7 @@ The kernel function is defined over the weighted B-spline tensor product basis,
 * **kernel_size** *(LongTensor)* - Number of trainable weight parameters in each edge dimension.
 * **is_open_spline** *(ByteTensor)* - Whether to use open or closed B-spline bases for each dimension.
 * **degree** *(int, optional)* - B-spline basis degree. (default: `1`)
+* **norm** *(bool, optional)* - Whether to normalize output by node degree. (default: `True`)
 * **root_weight** *(Tensor, optional)* - Additional shared trainable parameters for each feature of the root node of shape `(in_channels x out_channels)`. (default: `None`)
 * **bias** *(Tensor, optional)* - Optional bias of shape `(out_channels)`. (default: `None`)
@@ -86,11 +88,12 @@ weight = torch.rand((25, 2, 4), dtype=torch.float)  # 25 parameters for in_chann
 kernel_size = torch.tensor([5, 5])  # 5 parameters in each edge dimension
 is_open_spline = torch.tensor([1, 1], dtype=torch.uint8)  # only use open B-splines
 degree = 1  # B-spline degree of 1
+norm = True  # normalize output by node degree
 root_weight = torch.rand((2, 4), dtype=torch.float)  # separately weight root nodes
 bias = None  # do not apply an additional bias

 out = SplineConv.apply(src, edge_index, pseudo, weight, kernel_size,
-                       is_open_spline, degree, root_weight, bias)
+                       is_open_spline, degree, norm, root_weight, bias)
 print(out.size())
 torch.Size([4, 4])  # 4 nodes with 4 features each
...
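Since `norm` is passed positionally between `degree` and `root_weight` in the call above, degree normalization can also be switched off. A minimal sketch, reusing the tensors defined in the example above (only the `norm` argument changes):

```python
# Same tensors as in the example above; passing False instead of norm=True skips
# the division of the aggregated neighbor messages by the node degree, so the
# output contains degree-weighted sums rather than averages.
out = SplineConv.apply(src, edge_index, pseudo, weight, kernel_size,
                       is_open_spline, degree, False, root_weight, bias)
print(out.size())  # torch.Size([4, 4]) as before; only the values differ
```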
@@ -2,7 +2,7 @@ from os import path as osp
 from setuptools import setup, find_packages

-__version__ = '1.0.2'
+__version__ = '1.0.3'
 url = 'https://github.com/rusty1s/pytorch_spline_conv'
 install_requires = ['cffi']
...
@@ -31,11 +31,11 @@ tests = [{
     'root_weight': [[12.5], [13]],
     'bias': [1],
     'expected': [
-        1 + (12.5 * 9 + 13 * 10 + 8.5 + 40.5 + 107.5 + 101.5) / 5,
-        1 + 12.5 * 1 + 13 * 2,
-        1 + 12.5 * 3 + 13 * 4,
-        1 + 12.5 * 5 + 13 * 6,
-        1 + 12.5 * 7 + 13 * 8,
+        [1 + 12.5 * 9 + 13 * 10 + (8.5 + 40.5 + 107.5 + 101.5) / 4],
+        [1 + 12.5 * 1 + 13 * 2],
+        [1 + 12.5 * 3 + 13 * 4],
+        [1 + 12.5 * 5 + 13 * 6],
+        [1 + 12.5 * 7 + 13 * 8],
     ]
 }]
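The updated expectation reflects the new semantics: with `norm=True`, only the aggregated neighbor messages are averaged by the node degree (4 incoming edges for the first node), while the root-weight term and the bias are added afterwards without normalization. A quick arithmetic check with the numbers from the fixture above:

```python
# Numbers taken from the test fixture above (first node: degree 4, features [9, 10]).
neighbor_sum = 8.5 + 40.5 + 107.5 + 101.5   # 258.0
root_term = 12.5 * 9 + 13 * 10              # 242.5
bias = 1

old_expected = bias + (root_term + neighbor_sum) / 5  # root included in the average (degree + 1)
new_expected = bias + root_term + neighbor_sum / 4    # only neighbors averaged by the degree

print(old_expected, new_expected)  # 101.1 308.0
```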
@@ -52,9 +52,8 @@ def test_spline_conv_forward(test, dtype, device):
     bias = tensor(test['bias'], dtype, device)

     out = SplineConv.apply(src, edge_index, pseudo, weight, kernel_size,
-                           is_open_spline, 1, root_weight, bias)
-    assert list(out.size()) == [5, 1]
-    assert pytest.approx(out.view(-1).tolist()) == test['expected']
+                           is_open_spline, 1, True, root_weight, bias)
+    assert out.tolist() == test['expected']


 @pytest.mark.parametrize('degree,device', product(degrees.keys(), devices))
@@ -74,5 +73,5 @@ def test_spline_basis_backward(degree, device):
     bias.requires_grad_()

     data = (src, edge_index, pseudo, weight, kernel_size, is_open_spline,
-            degree, root_weight, bias)
+            degree, True, root_weight, bias)
     assert gradcheck(SplineConv.apply, data, eps=1e-6, atol=1e-4) is True
@@ -2,6 +2,6 @@ from .basis import SplineBasis
 from .weighting import SplineWeighting
 from .conv import SplineConv

-__version__ = '1.0.2'
+__version__ = '1.0.3'
 __all__ = ['SplineBasis', 'SplineWeighting', 'SplineConv', '__version__']
@@ -28,6 +28,8 @@ class SplineConv(object):
         is_open_spline (:class:`ByteTensor`): Whether to use open or closed
             B-spline bases for each dimension.
         degree (int, optional): B-spline basis degree. (default: :obj:`1`)
+        norm (bool, optional): Whether to normalize output by node degree.
+            (default: :obj:`True`)
         root_weight (:class:`Tensor`, optional): Additional shared trainable
             parameters for each feature of the root node of shape
             (in_channels x out_channels). (default: :obj:`None`)
@@ -45,6 +47,7 @@ class SplineConv(object):
               kernel_size,
               is_open_spline,
               degree=1,
+              norm=True,
               root_weight=None,
               bias=None):
@@ -62,15 +65,14 @@
         row_expand = row.unsqueeze(-1).expand_as(out)
         out = src.new_zeros((n, m_out)).scatter_add_(0, row_expand, out)

-        deg = node_degree(row, n, out.dtype, out.device)
+        # Normalize out by node degree (if wished).
+        if norm:
+            deg = node_degree(row, n, out.dtype, out.device)
+            out = out / deg.unsqueeze(-1).clamp(min=1)

         # Weight root node separately (if wished).
         if root_weight is not None:
             out += torch.mm(src, root_weight)
-            deg += 1
-
-        # Normalize out by node degree.
-        out /= deg.unsqueeze(-1).clamp(min=1)

         # Add bias (if wished).
         if bias is not None:
...
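For reference, a minimal standalone sketch of what the `norm` flag does in the aggregation step above, written with plain `torch` ops only (the helper name and the use of `torch.bincount` in place of the library's `node_degree` are assumptions for illustration):

```python
import torch

def degree_normalized_aggregation(messages, row, num_nodes, norm=True):
    # Sum the per-edge messages into their target nodes (scatter-add over dim 0).
    out = messages.new_zeros((num_nodes, messages.size(1)))
    out.scatter_add_(0, row.unsqueeze(-1).expand_as(messages), messages)
    if norm:
        # Node degree = number of incoming edges per target node;
        # clamp to 1 so isolated nodes are not divided by zero.
        deg = torch.bincount(row, minlength=num_nodes).to(out.dtype)
        out = out / deg.unsqueeze(-1).clamp(min=1)
    return out

# Example: 4 edges pointing into 3 nodes, one all-ones message per edge.
messages = torch.ones(4, 2)
row = torch.tensor([0, 0, 1, 2])
print(degree_normalized_aggregation(messages, row, 3))         # averaged per node
print(degree_normalized_aggregation(messages, row, 3, False))  # raw sums
```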