Commit c057dcd7 authored by rusty1s's avatar rusty1s
Browse files

typo

parent 9ea57e71
......@@ -68,21 +68,21 @@ The kernel function *g* is defined over the weighted B-spline tensor product bas
import torch
from torch_spline_conv import spline_conv
src = torch.Tensor(4, 2) # 4 nodes with 2 features
src = torch.Tensor(4, 2) # 4 nodes with 2 features each
edge_index = torch.LongTensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]]) # 6 edges
pseudo = torch.Tensor(6, 2) # two-dimensional edge attributes
weight = torch.Tensor(25, 2, 4) # 25 trainable parameters for each in_channels x out_channels combination
kernel_size = torch.LongTensor([5, 5]) # 5 trainable parameters in each edge dimension
is_open_spline = torch.ByteTensor([1, 1]) # only use open B-splines
degree = 1 # B-spline degree of 1
root_weight = torch.Tensor(2, 4) # Weight root nodes separatly
root_weight = torch.Tensor(2, 4) # Weight root nodes separately
bias = None # No additional bias
output = spline_conv(src, edge_index, pseudo, weight, kernel_size,
is_open_spline, degree, root_weight, bias)
print(output.size())
torch.Size([4, 4]) # 4 nodes with 4 features
torch.Size([4, 4]) # 4 nodes with 4 features each
```
## Cite
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment