Commit 781766e4 authored by rusty1s's avatar rusty1s
Browse files

rename

parent 74199575
......@@ -34,7 +34,7 @@ pip install cffi torch-spline-conv
```python
from torch_spline_conv import SplineConv
output = SplineConv.apply(src,
out = SplineConv.apply(src,
edge_index,
pseudo,
weight,
......@@ -71,7 +71,7 @@ The kernel function is defined over the weighted B-spline tensor product basis,
### Returns
* **output** *(Tensor)* - Output node features of shape `(number_of_nodes x out_channels)`.
* **out** *(Tensor)* - out node features of shape `(number_of_nodes x out_channels)`.
### Example
......@@ -89,10 +89,10 @@ degree = 1 # B-spline degree of 1
root_weight = torch.Tensor(2, 4) # separately weight root nodes
bias = None # do not apply an additional bias
output = SplineConv.apply(src, edge_index, pseudo, weight, kernel_size,
out = SplineConv.apply(src, edge_index, pseudo, weight, kernel_size,
is_open_spline, degree, root_weight, bias)
print(output.size())
print(out.size())
torch.Size([4, 4]) # 4 nodes with 4 features each
```
......
......@@ -30,7 +30,7 @@ tests = [{
'is_open_spline': [1, 0],
'root_weight': [[12.5], [13]],
'bias': [1],
'output': [
'expected': [
[1 + 12.5 * 9 + 13 * 10 + (8.5 + 40.5 + 107.5 + 101.5) / 4],
[1 + 12.5 * 1 + 13 * 2],
[1 + 12.5 * 3 + 13 * 4],
......@@ -51,9 +51,9 @@ def test_spline_conv_forward(test, dtype, device):
root_weight = tensor(test['root_weight'], dtype, device)
bias = tensor(test['bias'], dtype, device)
output = SplineConv.apply(src, edge_index, pseudo, weight, kernel_size,
is_open_spline, 1, root_weight, bias)
assert output.tolist() == test['output']
out = SplineConv.apply(src, edge_index, pseudo, weight, kernel_size,
is_open_spline, 1, root_weight, bias)
assert out.tolist() == test['expected']
@pytest.mark.parametrize('degree,device', product(degrees.keys(), devices))
......
......@@ -13,7 +13,7 @@ tests = [{
'weight': [[[1], [2]], [[3], [4]], [[5], [6]], [[7], [8]]],
'basis': [[0.5, 0, 0.5, 0], [0, 0, 0.5, 0.5]],
'weight_index': [[0, 1, 2, 3], [0, 1, 2, 3]],
'output': [
'expected': [
[0.5 * ((1 * (1 + 5)) + (2 * (2 + 6)))],
[0.5 * ((3 * (5 + 7)) + (4 * (6 + 8)))],
]
......@@ -27,8 +27,8 @@ def test_spline_weighting_forward(test, dtype, device):
basis = tensor(test['basis'], dtype, device)
weight_index = tensor(test['weight_index'], torch.long, device)
output = SplineWeighting.apply(src, weight, basis, weight_index)
assert output.tolist() == test['output']
out = SplineWeighting.apply(src, weight, basis, weight_index)
assert out.tolist() == test['expected']
@pytest.mark.parametrize('device', devices)
......
......@@ -56,22 +56,22 @@ class SplineConv(object):
# Weight each node.
data = SplineBasis.apply(pseudo, kernel_size, is_open_spline, degree)
output = SplineWeighting.apply(src[col], weight, *data)
out = SplineWeighting.apply(src[col], weight, *data)
# Convert e x m_out to n x m_out features.
row_expand = row.unsqueeze(-1).expand_as(output)
output = src.new_zeros((n, m_out)).scatter_add_(0, row_expand, output)
row_expand = row.unsqueeze(-1).expand_as(out)
out = src.new_zeros((n, m_out)).scatter_add_(0, row_expand, out)
# Normalize output by node degree.
deg = node_degree(row, n, output.dtype, output.device)
output /= deg.unsqueeze(-1).clamp(min=1)
# Normalize out by node degree.
deg = node_degree(row, n, out.dtype, out.device)
out /= deg.unsqueeze(-1).clamp(min=1)
# Weight root node separately (if wished).
if root_weight is not None:
output += torch.mm(src, root_weight)
out += torch.mm(src, root_weight)
# Add bias (if wished).
if bias is not None:
output += bias
out += bias
return output
return out
......@@ -33,16 +33,16 @@ def fw_weighting(self, src, weight, basis, weight_index):
func(self, src, weight, basis, weight_index)
def bw_weighting_src(self, grad_output, weight, basis, weight_index):
def bw_weighting_src(self, grad_out, weight, basis, weight_index):
func = get_func('weightingBackwardSrc', self)
func(self, grad_output, weight, basis, weight_index)
func(self, grad_out, weight, basis, weight_index)
def bw_weighting_weight(self, grad_output, src, basis, weight_index):
def bw_weighting_weight(self, grad_out, src, basis, weight_index):
func = get_func('weightingBackwardWeight', self)
func(self, grad_output, src, basis, weight_index)
func(self, grad_out, src, basis, weight_index)
def bw_weighting_basis(self, grad_output, src, weight, weight_index):
def bw_weighting_basis(self, grad_out, src, weight, weight_index):
func = get_func('weightingBackwardBasis', self)
func(self, grad_output, src, weight, weight_index)
func(self, grad_out, src, weight, weight_index)
......@@ -5,26 +5,26 @@ from .utils.ffi import bw_weighting_weight, bw_weighting_basis
def fw(src, weight, basis, weight_index):
output = src.new_empty((src.size(0), weight.size(2)))
fw_weighting(output, src, weight, basis, weight_index)
return output
out = src.new_empty((src.size(0), weight.size(2)))
fw_weighting(out, src, weight, basis, weight_index)
return out
def bw_src(grad_output, weight, basis, weight_index):
grad_src = grad_output.new_empty((grad_output.size(0), weight.size(1)))
bw_weighting_src(grad_src, grad_output, weight, basis, weight_index)
def bw_src(grad_out, weight, basis, weight_index):
grad_src = grad_out.new_empty((grad_out.size(0), weight.size(1)))
bw_weighting_src(grad_src, grad_out, weight, basis, weight_index)
return grad_src
def bw_weight(grad_output, src, basis, weight_index, K):
grad_weight = src.new_empty((K, src.size(1), grad_output.size(1)))
bw_weighting_weight(grad_weight, grad_output, src, basis, weight_index)
def bw_weight(grad_out, src, basis, weight_index, K):
grad_weight = src.new_empty((K, src.size(1), grad_out.size(1)))
bw_weighting_weight(grad_weight, grad_out, src, basis, weight_index)
return grad_weight
def bw_basis(grad_output, src, weight, weight_index):
def bw_basis(grad_out, src, weight, weight_index):
grad_basis = src.new_empty(weight_index.size())
bw_weighting_basis(grad_basis, grad_output, src, weight, weight_index)
bw_weighting_basis(grad_basis, grad_out, src, weight, weight_index)
return grad_basis
......@@ -35,18 +35,18 @@ class SplineWeighting(Function):
return fw(src, weight, basis, weight_index)
@staticmethod
def backward(ctx, grad_output): # pragma: no cover
def backward(ctx, grad_out): # pragma: no cover
grad_src = grad_weight = grad_basis = None
src, weight, basis, weight_index = ctx.saved_tensors
if ctx.needs_input_grad[0]:
grad_src = bw_src(grad_output, weight, basis, weight_index)
grad_src = bw_src(grad_out, weight, basis, weight_index)
if ctx.needs_input_grad[1]:
K = weight.size(0)
grad_weight = bw_weight(grad_output, src, basis, weight_index, K)
grad_weight = bw_weight(grad_out, src, basis, weight_index, K)
if ctx.needs_input_grad[2]:
grad_basis = bw_basis(grad_output, src, weight, weight_index)
grad_basis = bw_basis(grad_out, src, weight, weight_index)
return grad_src, grad_weight, grad_basis, None
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment