"vscode:/vscode.git/clone" did not exist on "ca675ed41d58d63ae41d7835c5dcb07677a54ad1"
Commit 781766e4 authored by rusty1s's avatar rusty1s
Browse files

rename

parent 74199575
...@@ -34,7 +34,7 @@ pip install cffi torch-spline-conv ...@@ -34,7 +34,7 @@ pip install cffi torch-spline-conv
```python ```python
from torch_spline_conv import SplineConv from torch_spline_conv import SplineConv
output = SplineConv.apply(src, out = SplineConv.apply(src,
edge_index, edge_index,
pseudo, pseudo,
weight, weight,
...@@ -71,7 +71,7 @@ The kernel function is defined over the weighted B-spline tensor product basis, ...@@ -71,7 +71,7 @@ The kernel function is defined over the weighted B-spline tensor product basis,
### Returns ### Returns
* **output** *(Tensor)* - Output node features of shape `(number_of_nodes x out_channels)`. * **out** *(Tensor)* - out node features of shape `(number_of_nodes x out_channels)`.
### Example ### Example
...@@ -89,10 +89,10 @@ degree = 1 # B-spline degree of 1 ...@@ -89,10 +89,10 @@ degree = 1 # B-spline degree of 1
root_weight = torch.Tensor(2, 4) # separately weight root nodes root_weight = torch.Tensor(2, 4) # separately weight root nodes
bias = None # do not apply an additional bias bias = None # do not apply an additional bias
output = SplineConv.apply(src, edge_index, pseudo, weight, kernel_size, out = SplineConv.apply(src, edge_index, pseudo, weight, kernel_size,
is_open_spline, degree, root_weight, bias) is_open_spline, degree, root_weight, bias)
print(output.size()) print(out.size())
torch.Size([4, 4]) # 4 nodes with 4 features each torch.Size([4, 4]) # 4 nodes with 4 features each
``` ```
......
...@@ -30,7 +30,7 @@ tests = [{ ...@@ -30,7 +30,7 @@ tests = [{
'is_open_spline': [1, 0], 'is_open_spline': [1, 0],
'root_weight': [[12.5], [13]], 'root_weight': [[12.5], [13]],
'bias': [1], 'bias': [1],
'output': [ 'expected': [
[1 + 12.5 * 9 + 13 * 10 + (8.5 + 40.5 + 107.5 + 101.5) / 4], [1 + 12.5 * 9 + 13 * 10 + (8.5 + 40.5 + 107.5 + 101.5) / 4],
[1 + 12.5 * 1 + 13 * 2], [1 + 12.5 * 1 + 13 * 2],
[1 + 12.5 * 3 + 13 * 4], [1 + 12.5 * 3 + 13 * 4],
...@@ -51,9 +51,9 @@ def test_spline_conv_forward(test, dtype, device): ...@@ -51,9 +51,9 @@ def test_spline_conv_forward(test, dtype, device):
root_weight = tensor(test['root_weight'], dtype, device) root_weight = tensor(test['root_weight'], dtype, device)
bias = tensor(test['bias'], dtype, device) bias = tensor(test['bias'], dtype, device)
output = SplineConv.apply(src, edge_index, pseudo, weight, kernel_size, out = SplineConv.apply(src, edge_index, pseudo, weight, kernel_size,
is_open_spline, 1, root_weight, bias) is_open_spline, 1, root_weight, bias)
assert output.tolist() == test['output'] assert out.tolist() == test['expected']
@pytest.mark.parametrize('degree,device', product(degrees.keys(), devices)) @pytest.mark.parametrize('degree,device', product(degrees.keys(), devices))
......
...@@ -13,7 +13,7 @@ tests = [{ ...@@ -13,7 +13,7 @@ tests = [{
'weight': [[[1], [2]], [[3], [4]], [[5], [6]], [[7], [8]]], 'weight': [[[1], [2]], [[3], [4]], [[5], [6]], [[7], [8]]],
'basis': [[0.5, 0, 0.5, 0], [0, 0, 0.5, 0.5]], 'basis': [[0.5, 0, 0.5, 0], [0, 0, 0.5, 0.5]],
'weight_index': [[0, 1, 2, 3], [0, 1, 2, 3]], 'weight_index': [[0, 1, 2, 3], [0, 1, 2, 3]],
'output': [ 'expected': [
[0.5 * ((1 * (1 + 5)) + (2 * (2 + 6)))], [0.5 * ((1 * (1 + 5)) + (2 * (2 + 6)))],
[0.5 * ((3 * (5 + 7)) + (4 * (6 + 8)))], [0.5 * ((3 * (5 + 7)) + (4 * (6 + 8)))],
] ]
...@@ -27,8 +27,8 @@ def test_spline_weighting_forward(test, dtype, device): ...@@ -27,8 +27,8 @@ def test_spline_weighting_forward(test, dtype, device):
basis = tensor(test['basis'], dtype, device) basis = tensor(test['basis'], dtype, device)
weight_index = tensor(test['weight_index'], torch.long, device) weight_index = tensor(test['weight_index'], torch.long, device)
output = SplineWeighting.apply(src, weight, basis, weight_index) out = SplineWeighting.apply(src, weight, basis, weight_index)
assert output.tolist() == test['output'] assert out.tolist() == test['expected']
@pytest.mark.parametrize('device', devices) @pytest.mark.parametrize('device', devices)
......
...@@ -56,22 +56,22 @@ class SplineConv(object): ...@@ -56,22 +56,22 @@ class SplineConv(object):
# Weight each node. # Weight each node.
data = SplineBasis.apply(pseudo, kernel_size, is_open_spline, degree) data = SplineBasis.apply(pseudo, kernel_size, is_open_spline, degree)
output = SplineWeighting.apply(src[col], weight, *data) out = SplineWeighting.apply(src[col], weight, *data)
# Convert e x m_out to n x m_out features. # Convert e x m_out to n x m_out features.
row_expand = row.unsqueeze(-1).expand_as(output) row_expand = row.unsqueeze(-1).expand_as(out)
output = src.new_zeros((n, m_out)).scatter_add_(0, row_expand, output) out = src.new_zeros((n, m_out)).scatter_add_(0, row_expand, out)
# Normalize output by node degree. # Normalize out by node degree.
deg = node_degree(row, n, output.dtype, output.device) deg = node_degree(row, n, out.dtype, out.device)
output /= deg.unsqueeze(-1).clamp(min=1) out /= deg.unsqueeze(-1).clamp(min=1)
# Weight root node separately (if wished). # Weight root node separately (if wished).
if root_weight is not None: if root_weight is not None:
output += torch.mm(src, root_weight) out += torch.mm(src, root_weight)
# Add bias (if wished). # Add bias (if wished).
if bias is not None: if bias is not None:
output += bias out += bias
return output return out
...@@ -33,16 +33,16 @@ def fw_weighting(self, src, weight, basis, weight_index): ...@@ -33,16 +33,16 @@ def fw_weighting(self, src, weight, basis, weight_index):
func(self, src, weight, basis, weight_index) func(self, src, weight, basis, weight_index)
def bw_weighting_src(self, grad_output, weight, basis, weight_index): def bw_weighting_src(self, grad_out, weight, basis, weight_index):
func = get_func('weightingBackwardSrc', self) func = get_func('weightingBackwardSrc', self)
func(self, grad_output, weight, basis, weight_index) func(self, grad_out, weight, basis, weight_index)
def bw_weighting_weight(self, grad_output, src, basis, weight_index): def bw_weighting_weight(self, grad_out, src, basis, weight_index):
func = get_func('weightingBackwardWeight', self) func = get_func('weightingBackwardWeight', self)
func(self, grad_output, src, basis, weight_index) func(self, grad_out, src, basis, weight_index)
def bw_weighting_basis(self, grad_output, src, weight, weight_index): def bw_weighting_basis(self, grad_out, src, weight, weight_index):
func = get_func('weightingBackwardBasis', self) func = get_func('weightingBackwardBasis', self)
func(self, grad_output, src, weight, weight_index) func(self, grad_out, src, weight, weight_index)
...@@ -5,26 +5,26 @@ from .utils.ffi import bw_weighting_weight, bw_weighting_basis ...@@ -5,26 +5,26 @@ from .utils.ffi import bw_weighting_weight, bw_weighting_basis
def fw(src, weight, basis, weight_index): def fw(src, weight, basis, weight_index):
output = src.new_empty((src.size(0), weight.size(2))) out = src.new_empty((src.size(0), weight.size(2)))
fw_weighting(output, src, weight, basis, weight_index) fw_weighting(out, src, weight, basis, weight_index)
return output return out
def bw_src(grad_output, weight, basis, weight_index): def bw_src(grad_out, weight, basis, weight_index):
grad_src = grad_output.new_empty((grad_output.size(0), weight.size(1))) grad_src = grad_out.new_empty((grad_out.size(0), weight.size(1)))
bw_weighting_src(grad_src, grad_output, weight, basis, weight_index) bw_weighting_src(grad_src, grad_out, weight, basis, weight_index)
return grad_src return grad_src
def bw_weight(grad_output, src, basis, weight_index, K): def bw_weight(grad_out, src, basis, weight_index, K):
grad_weight = src.new_empty((K, src.size(1), grad_output.size(1))) grad_weight = src.new_empty((K, src.size(1), grad_out.size(1)))
bw_weighting_weight(grad_weight, grad_output, src, basis, weight_index) bw_weighting_weight(grad_weight, grad_out, src, basis, weight_index)
return grad_weight return grad_weight
def bw_basis(grad_output, src, weight, weight_index): def bw_basis(grad_out, src, weight, weight_index):
grad_basis = src.new_empty(weight_index.size()) grad_basis = src.new_empty(weight_index.size())
bw_weighting_basis(grad_basis, grad_output, src, weight, weight_index) bw_weighting_basis(grad_basis, grad_out, src, weight, weight_index)
return grad_basis return grad_basis
...@@ -35,18 +35,18 @@ class SplineWeighting(Function): ...@@ -35,18 +35,18 @@ class SplineWeighting(Function):
return fw(src, weight, basis, weight_index) return fw(src, weight, basis, weight_index)
@staticmethod @staticmethod
def backward(ctx, grad_output): # pragma: no cover def backward(ctx, grad_out): # pragma: no cover
grad_src = grad_weight = grad_basis = None grad_src = grad_weight = grad_basis = None
src, weight, basis, weight_index = ctx.saved_tensors src, weight, basis, weight_index = ctx.saved_tensors
if ctx.needs_input_grad[0]: if ctx.needs_input_grad[0]:
grad_src = bw_src(grad_output, weight, basis, weight_index) grad_src = bw_src(grad_out, weight, basis, weight_index)
if ctx.needs_input_grad[1]: if ctx.needs_input_grad[1]:
K = weight.size(0) K = weight.size(0)
grad_weight = bw_weight(grad_output, src, basis, weight_index, K) grad_weight = bw_weight(grad_out, src, basis, weight_index, K)
if ctx.needs_input_grad[2]: if ctx.needs_input_grad[2]:
grad_basis = bw_basis(grad_output, src, weight, weight_index) grad_basis = bw_basis(grad_out, src, weight, weight_index)
return grad_src, grad_weight, grad_basis, None return grad_src, grad_weight, grad_basis, None
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment