Commit e6632e3f authored by rusty1s's avatar rusty1s
Browse files

to function

parent 6729e5b9
......@@ -32,10 +32,11 @@ pip install cffi torch-spline-conv
## Usage
```python
from torch_spline_conv import spline_conv
from torch_spline_conv import SplineConv
output = spline_conv(src, edge_index, pseudo, weight, kernel_size,
is_open_spline, degree=1, root_weight=None, bias=None)
output = SplineConv.apply(src, edge_index, pseudo, weight, kernel_size,
is_open_spline, degree=1, root_weight=None,
bias=None)
```
Applies the spline-based convolution operator
......@@ -70,7 +71,7 @@ The kernel function is defined over the weighted B-spline tensor product basis,
```python
import torch
from torch_spline_conv import spline_conv
from torch_spline_conv import SplineConv
src = torch.Tensor(4, 2) # 4 nodes with 2 features each
edge_index = torch.LongTensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]]) # 6 edges
......@@ -82,8 +83,8 @@ degree = torch.tensor(1) # B-spline degree of 1
root_weight = torch.Tensor(2, 4) # separately weight root nodes
bias = None # do not apply an additional bias
output = spline_conv(src, edge_index, pseudo, weight, kernel_size,
is_open_spline, degree, root_weight, bias)
output = SplineConv.apply(src, edge_index, pseudo, weight, kernel_size,
is_open_spline, degree, root_weight, bias)
print(output.size())
torch.Size([4, 4]) # 4 nodes with 4 features each
......
from .conv import spline_conv
from .conv import SplineConv
__version__ = '0.1.0'
__all__ = ['spline_conv', '__version__']
__all__ = ['SplineConv', '__version__']
import torch
from torch.autograd import Function
from .basis import SplineBasis
from .weighting import SplineWeighting
......@@ -6,15 +7,7 @@ from .weighting import SplineWeighting
from .utils.degree import degree as node_degree
def spline_conv(src,
edge_index,
pseudo,
weight,
kernel_size,
is_open_spline,
degree,
root_weight=None,
bias=None):
class SplineConv(Function):
"""Applies the spline-based convolution operator :math:`(f \star g)(i) =
\frac{1}{|\mathcal{N}(i)|} \sum_{l=1}^{M_{in}} \sum_{j \in \mathcal{N}(i)}
f_l(j) \cdot g_l(u(i, j))` over several node features of an input graph.
......@@ -45,31 +38,42 @@ def spline_conv(src,
:rtype: :class:`Tensor`
"""
src = src.unsqueeze(-1) if src.dim() == 1 else src
pseudo = pseudo.unsqueeze(-1) if pseudo.dim() == 1 else pseudo
@staticmethod
def forward(ctx,
src,
edge_index,
pseudo,
weight,
kernel_size,
is_open_spline,
degree,
root_weight=None,
bias=None):
src = src.unsqueeze(-1) if src.dim() == 1 else src
pseudo = pseudo.unsqueeze(-1) if pseudo.dim() == 1 else pseudo
row, col = edge_index
n, m_out = src.size(0), weight.size(2)
row, col = edge_index
n, m_out = src.size(0), weight.size(2)
# Weight each node.
basis, weight_index = SplineBasis.apply(degree, pseudo, kernel_size,
is_open_spline)
output = SplineWeighting.apply(src[col], weight, basis, weight_index)
# Weight each node.
b, wi = SplineBasis.apply(degree, pseudo, kernel_size, is_open_spline)
output = SplineWeighting.apply(src[col], weight, b, wi)
# Perform the real convolution => Convert e x m_out to n x m_out features.
row_expand = row.unsqueeze(-1).expand_as(output)
output = src.new_zeros((n, m_out)).scatter_add_(0, row_expand, output)
# Convert e x m_out to n x m_out features.
row_expand = row.unsqueeze(-1).expand_as(output)
output = src.new_zeros((n, m_out)).scatter_add_(0, row_expand, output)
# Normalize output by node degree.
deg = node_degree(row, n, out=src.new_empty(()))
output /= deg.unsqueeze(-1).clamp(min=1)
# Normalize output by node degree.
deg = node_degree(row, n, out=src.new_empty(()))
output /= deg.unsqueeze(-1).clamp(min=1)
# Weight root node separately (if wished).
if root_weight is not None:
output += torch.mm(src, root_weight)
# Weight root node separately (if wished).
if root_weight is not None:
output += torch.mm(src, root_weight)
# Add bias (if wished).
if bias is not None:
output += bias
# Add bias (if wished).
if bias is not None:
output += bias
return output
return output
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment