Commit e6632e3f authored by rusty1s's avatar rusty1s
Browse files

to function

parent 6729e5b9
...@@ -32,10 +32,11 @@ pip install cffi torch-spline-conv ...@@ -32,10 +32,11 @@ pip install cffi torch-spline-conv
## Usage

```python
from torch_spline_conv import SplineConv

output = SplineConv.apply(src, edge_index, pseudo, weight, kernel_size,
                          is_open_spline, degree=1, root_weight=None,
                          bias=None)
```

Applies the spline-based convolution operator
...@@ -70,7 +71,7 @@ The kernel function is defined over the weighted B-spline tensor product basis, ...@@ -70,7 +71,7 @@ The kernel function is defined over the weighted B-spline tensor product basis,
```python
import torch
from torch_spline_conv import SplineConv

src = torch.Tensor(4, 2)  # 4 nodes with 2 features each
edge_index = torch.LongTensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]])  # 6 edges
degree = torch.tensor(1)  # B-spline degree of 1
root_weight = torch.Tensor(2, 4)  # separately weight root nodes
bias = None  # do not apply an additional bias

output = SplineConv.apply(src, edge_index, pseudo, weight, kernel_size,
                          is_open_spline, degree, root_weight, bias)
print(output.size())
torch.Size([4, 4])  # 4 nodes with 4 features each
```
......
from .conv import SplineConv

__version__ = '0.1.0'
__all__ = ['SplineConv', '__version__']
import torch import torch
from torch.autograd import Function
from .basis import SplineBasis from .basis import SplineBasis
from .weighting import SplineWeighting from .weighting import SplineWeighting
...@@ -6,15 +7,7 @@ from .weighting import SplineWeighting ...@@ -6,15 +7,7 @@ from .weighting import SplineWeighting
from .utils.degree import degree as node_degree from .utils.degree import degree as node_degree
def spline_conv(src, class SplineConv(Function):
edge_index,
pseudo,
weight,
kernel_size,
is_open_spline,
degree,
root_weight=None,
bias=None):
"""Applies the spline-based convolution operator :math:`(f \star g)(i) = """Applies the spline-based convolution operator :math:`(f \star g)(i) =
\frac{1}{|\mathcal{N}(i)|} \sum_{l=1}^{M_{in}} \sum_{j \in \mathcal{N}(i)} \frac{1}{|\mathcal{N}(i)|} \sum_{l=1}^{M_{in}} \sum_{j \in \mathcal{N}(i)}
f_l(j) \cdot g_l(u(i, j))` over several node features of an input graph. f_l(j) \cdot g_l(u(i, j))` over several node features of an input graph.
...@@ -45,31 +38,42 @@ def spline_conv(src, ...@@ -45,31 +38,42 @@ def spline_conv(src,
:rtype: :class:`Tensor` :rtype: :class:`Tensor`
""" """
src = src.unsqueeze(-1) if src.dim() == 1 else src @staticmethod
pseudo = pseudo.unsqueeze(-1) if pseudo.dim() == 1 else pseudo def forward(ctx,
src,
edge_index,
pseudo,
weight,
kernel_size,
is_open_spline,
degree,
root_weight=None,
bias=None):
src = src.unsqueeze(-1) if src.dim() == 1 else src
pseudo = pseudo.unsqueeze(-1) if pseudo.dim() == 1 else pseudo
row, col = edge_index row, col = edge_index
n, m_out = src.size(0), weight.size(2) n, m_out = src.size(0), weight.size(2)
# Weight each node. # Weight each node.
basis, weight_index = SplineBasis.apply(degree, pseudo, kernel_size, b, wi = SplineBasis.apply(degree, pseudo, kernel_size, is_open_spline)
is_open_spline) output = SplineWeighting.apply(src[col], weight, b, wi)
output = SplineWeighting.apply(src[col], weight, basis, weight_index)
# Perform the real convolution => Convert e x m_out to n x m_out features. # Convert e x m_out to n x m_out features.
row_expand = row.unsqueeze(-1).expand_as(output) row_expand = row.unsqueeze(-1).expand_as(output)
output = src.new_zeros((n, m_out)).scatter_add_(0, row_expand, output) output = src.new_zeros((n, m_out)).scatter_add_(0, row_expand, output)
# Normalize output by node degree. # Normalize output by node degree.
deg = node_degree(row, n, out=src.new_empty(())) deg = node_degree(row, n, out=src.new_empty(()))
output /= deg.unsqueeze(-1).clamp(min=1) output /= deg.unsqueeze(-1).clamp(min=1)
# Weight root node separately (if wished). # Weight root node separately (if wished).
if root_weight is not None: if root_weight is not None:
output += torch.mm(src, root_weight) output += torch.mm(src, root_weight)
# Add bias (if wished). # Add bias (if wished).
if bias is not None: if bias is not None:
output += bias output += bias
return output return output
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment