Commit f7377ac4 authored by rusty1s's avatar rusty1s
Browse files

added tests, caching problems fixed

parent c005e19d
import os
import shutil
import subprocess import subprocess
import torch import torch
from torch.utils.ffi import create_extension from torch.utils.ffi import create_extension
if os.path.exists('build'):
shutil.rmtree('build')
headers = ['torch_spline_conv/src/cpu.h'] headers = ['torch_spline_conv/src/cpu.h']
sources = ['torch_spline_conv/src/cpu.c'] sources = ['torch_spline_conv/src/cpu.c']
include_dirs = ['torch_spline_conv/src'] include_dirs = ['torch_spline_conv/src']
......
...@@ -5,4 +5,4 @@ description-file = README.md ...@@ -5,4 +5,4 @@ description-file = README.md
test=pytest test=pytest
[tool:pytest] [tool:pytest]
addopts = --capture=no --cov addopts = --capture=no
[
{
"degree": 1,
"pseudo": [0, 0.25, 0.5, 0.75, 1],
"kernel_size": [5],
"is_open_spline": [0],
"expected_basis": [[0, 1], [0, 1], [0, 1], [0, 1], [0, 1]],
"expected_index": [[1, 0], [2, 1], [3, 2], [4, 3], [0, 4]]
}
]
from os import path as osp
from itertools import product
import pytest
import json
import torch
from torch_spline_conv.functions.utils import spline_basis
from .utils import tensors, Tensor
f = open(osp.join(osp.dirname(__file__), 'basis.json'), 'r')
data = json.load(f)
f.close()
@pytest.mark.parametrize('tensor,i', product(tensors, range(len(data))))
def test_spline_basis_cpu(tensor, i):
degree = data[i].get('degree')
pseudo = Tensor(tensor, data[i]['pseudo'])
kernel_size = torch.LongTensor(data[i]['kernel_size'])
is_open_spline = torch.ByteTensor(data[i]['is_open_spline'])
K = kernel_size.prod()
expected_basis = Tensor(tensor, data[i]['expected_basis'])
expected_index = torch.ByteTensor(data[i]['expected_index'])
basis, index = spline_basis(degree, pseudo, kernel_size, is_open_spline, K)
print('basis', basis)
print('weight_index', index)
return
assert basis.tolist() == expected_basis.tolist()
assert index.tolist() == expected_index.tolist()
import torch_spline_conv as SplineConv
import torch
tensors = ['FloatTensor']
def Tensor(str, x):
tensor = getattr(torch, str)
return tensor(x)
...@@ -14,16 +14,17 @@ def get_func(name, tensor): ...@@ -14,16 +14,17 @@ def get_func(name, tensor):
def spline_basis(degree, pseudo, kernel_size, is_open_spline, K): def spline_basis(degree, pseudo, kernel_size, is_open_spline, K):
s = (degree + 1)**kernel_size.size(0)
pseudo = pseudo.unsqueeze(-1) if pseudo.dim() == 1 else pseudo
basis = pseudo.new(pseudo.size(0), s)
weight_index = kernel_size.new(pseudo.size(0), s)
degree = degrees.get(degree) degree = degrees.get(degree)
if degree is None: if degree is None:
raise NotImplementedError('Basis computation not implemented for ' raise NotImplementedError('Basis computation not implemented for '
'specified B-spline degree') 'specified B-spline degree')
s = (degree + 1)**kernel_size.size(0) func = get_func('basis_{}'.format(degree), pseudo)
basis = pseudo.new(pseudo.size(0), s)
weight_index = kernel_size.new(pseudo.size(0), s)
func = get_func('basis_{}', degree, pseudo)
func(basis, weight_index, pseudo, kernel_size, is_open_spline, K) func(basis, weight_index, pseudo, kernel_size, is_open_spline, K)
return basis, weight_index return basis, weight_index
......
...@@ -6,27 +6,27 @@ void spline_(basis_linear)(THTensor *basis, THLongTensor *weight_index, THTensor ...@@ -6,27 +6,27 @@ void spline_(basis_linear)(THTensor *basis, THLongTensor *weight_index, THTensor
int64_t k, s, S, d, D; int64_t k, s, S, d, D;
real value; real value;
D = THTensor_(size)(pseudo, 1); D = THTensor_(size)(pseudo, 1);
S = THLongTEnsor_size(weight_index, 1); S = THLongTensor_size(weight_index, 1);
TH_TENSOR_DIM_APPLY3(real, basis, int64_t, weight_index, real, pseudo, 1, TH_TENSOR_DIM_APPLY3_SIZE_EX_EXCEPT_DIM, /* TH_TENSOR_DIM_APPLY3(real, basis, int64_t, weight_index, real, pseudo, 1, TH_TENSOR_DIM_APPLY3_SIZE_EX_EXCEPT_DIM, */
for (s = 0; s < S; s++) { /* for (s = 0; s < S; s++) { */
/* k = K; */ /* /1* k = K; *1/ */
/* b = 1; i = 0; */ /* /1* b = 1; i = 0; *1/ */
for (d = 0; d < D; d++) { /* for (d = 0; d < D; d++) { */
/* k /= kernel_size[d]; */ /* /1* k /= kernel_size[d]; *1/ */
/* value = *(pseudo_data + d * pseudo_stride) * (kernel_size[d] - is_open_spline[d]); */ /* /1* value = *(pseudo_data + d * pseudo_stride) * (kernel_size[d] - is_open_spline[d]); *1/ */
/* int bot = int64_t(value); */ /* /1* int bot = int64_t(value); *1/ */
/* int top = (bot + 1) % kernel_size[d]; */ /* /1* int top = (bot + 1) % kernel_size[d]; *1/ */
/* bot %= kernel_size[d]; */ /* /1* bot %= kernel_size[d]; *1/ */
} /* } */
basis_data[s * basis_stride] = 1; /* basis_data[s * basis_stride] = 1; */
weight_index[s * weight_index_stride] = 2; /* weight_index_data[s * weight_index_stride] = 2; */
}) /* }) */
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment