Commit f7377ac4 authored by rusty1s's avatar rusty1s
Browse files

added tests, caching problems fixed

parent c005e19d
import os
import shutil
import subprocess
import torch
from torch.utils.ffi import create_extension
if os.path.exists('build'):
shutil.rmtree('build')
headers = ['torch_spline_conv/src/cpu.h']
sources = ['torch_spline_conv/src/cpu.c']
include_dirs = ['torch_spline_conv/src']
......
......@@ -5,4 +5,4 @@ description-file = README.md
test=pytest
[tool:pytest]
addopts = --capture=no --cov
addopts = --capture=no
[
{
"degree": 1,
"pseudo": [0, 0.25, 0.5, 0.75, 1],
"kernel_size": [5],
"is_open_spline": [0],
"expected_basis": [[0, 1], [0, 1], [0, 1], [0, 1], [0, 1]],
"expected_index": [[1, 0], [2, 1], [3, 2], [4, 3], [0, 4]]
}
]
from os import path as osp
from itertools import product
import pytest
import json
import torch
from torch_spline_conv.functions.utils import spline_basis
from .utils import tensors, Tensor
f = open(osp.join(osp.dirname(__file__), 'basis.json'), 'r')
data = json.load(f)
f.close()
@pytest.mark.parametrize('tensor,i', product(tensors, range(len(data))))
def test_spline_basis_cpu(tensor, i):
degree = data[i].get('degree')
pseudo = Tensor(tensor, data[i]['pseudo'])
kernel_size = torch.LongTensor(data[i]['kernel_size'])
is_open_spline = torch.ByteTensor(data[i]['is_open_spline'])
K = kernel_size.prod()
expected_basis = Tensor(tensor, data[i]['expected_basis'])
expected_index = torch.ByteTensor(data[i]['expected_index'])
basis, index = spline_basis(degree, pseudo, kernel_size, is_open_spline, K)
print('basis', basis)
print('weight_index', index)
return
assert basis.tolist() == expected_basis.tolist()
assert index.tolist() == expected_index.tolist()
import torch_spline_conv as SplineConv
import torch
tensors = ['FloatTensor']
def Tensor(str, x):
tensor = getattr(torch, str)
return tensor(x)
......@@ -14,16 +14,17 @@ def get_func(name, tensor):
def spline_basis(degree, pseudo, kernel_size, is_open_spline, K):
s = (degree + 1)**kernel_size.size(0)
pseudo = pseudo.unsqueeze(-1) if pseudo.dim() == 1 else pseudo
basis = pseudo.new(pseudo.size(0), s)
weight_index = kernel_size.new(pseudo.size(0), s)
degree = degrees.get(degree)
if degree is None:
raise NotImplementedError('Basis computation not implemented for '
'specified B-spline degree')
s = (degree + 1)**kernel_size.size(0)
basis = pseudo.new(pseudo.size(0), s)
weight_index = kernel_size.new(pseudo.size(0), s)
func = get_func('basis_{}', degree, pseudo)
func = get_func('basis_{}'.format(degree), pseudo)
func(basis, weight_index, pseudo, kernel_size, is_open_spline, K)
return basis, weight_index
......
......@@ -6,27 +6,27 @@ void spline_(basis_linear)(THTensor *basis, THLongTensor *weight_index, THTensor
int64_t k, s, S, d, D;
real value;
D = THTensor_(size)(pseudo, 1);
S = THLongTEnsor_size(weight_index, 1);
TH_TENSOR_DIM_APPLY3(real, basis, int64_t, weight_index, real, pseudo, 1, TH_TENSOR_DIM_APPLY3_SIZE_EX_EXCEPT_DIM,
for (s = 0; s < S; s++) {
/* k = K; */
/* b = 1; i = 0; */
S = THLongTensor_size(weight_index, 1);
/* TH_TENSOR_DIM_APPLY3(real, basis, int64_t, weight_index, real, pseudo, 1, TH_TENSOR_DIM_APPLY3_SIZE_EX_EXCEPT_DIM, */
/* for (s = 0; s < S; s++) { */
/* /1* k = K; *1/ */
/* /1* b = 1; i = 0; *1/ */
for (d = 0; d < D; d++) {
/* k /= kernel_size[d]; */
/* for (d = 0; d < D; d++) { */
/* /1* k /= kernel_size[d]; *1/ */
/* value = *(pseudo_data + d * pseudo_stride) * (kernel_size[d] - is_open_spline[d]); */
/* /1* value = *(pseudo_data + d * pseudo_stride) * (kernel_size[d] - is_open_spline[d]); *1/ */
/* int bot = int64_t(value); */
/* int top = (bot + 1) % kernel_size[d]; */
/* bot %= kernel_size[d]; */
}
basis_data[s * basis_stride] = 1;
weight_index[s * weight_index_stride] = 2;
})
/* /1* int bot = int64_t(value); *1/ */
/* /1* int top = (bot + 1) % kernel_size[d]; *1/ */
/* /1* bot %= kernel_size[d]; *1/ */
/* } */
/* basis_data[s * basis_stride] = 1; */
/* weight_index_data[s * weight_index_stride] = 2; */
/* }) */
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment