Commit b5ac9f33 authored by rusty1s

weighting forward (cpu+gpu)

parent d48533ea
from itertools import product

import pytest
import torch
from torch_spline_conv.weighting import spline_weighting

from .tensor import tensors
tests = [{
    'src': [[1, 2], [3, 4]],
    'weight': [[[1], [2]], [[3], [4]], [[5], [6]], [[7], [8]]],
    'basis': [[0.5, 0, 0.5, 0], [0, 0, 0.5, 0.5]],
    'weight_index': [[0, 1, 2, 3], [0, 1, 2, 3]],
    'output': [
        [0.5 * ((1 * (1 + 5)) + (2 * (2 + 6)))],
        [0.5 * ((3 * (5 + 7)) + (4 * (6 + 8)))],
    ],
}]
@pytest.mark.parametrize('tensor,i', product(tensors, range(len(tests))))
def test_spline_weighting_forward_cpu(tensor, i):
    data = tests[i]
    src = getattr(torch, tensor)(data['src'])
    weight = getattr(torch, tensor)(data['weight'])
    basis = getattr(torch, tensor)(data['basis'])
    weight_index = torch.LongTensor(data['weight_index'])

    output = spline_weighting(src, weight, basis, weight_index)
    assert output.tolist() == data['output']
@pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA')
@pytest.mark.parametrize('tensor,i', product(tensors, range(len(tests))))
def test_spline_weighting_forward_gpu(tensor, i):
    data = tests[i]
    src = getattr(torch.cuda, tensor)(data['src'])
    weight = getattr(torch.cuda, tensor)(data['weight'])
    basis = getattr(torch.cuda, tensor)(data['basis'])
    weight_index = torch.cuda.LongTensor(data['weight_index'])

    output = spline_weighting(src, weight, basis, weight_index)
    assert output.cpu().tolist() == data['output']
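
For reference, the expected outputs above can be reproduced with a dense pure-PyTorch sketch of the weighting step (a hypothetical helper, not the library's C/CUDA implementation): each node n accumulates basis[n, s] * src[n] @ weight[weight_index[n, s]] over its S non-zero basis products.

import torch

def spline_weighting_reference(src, weight, basis, weight_index):
    # src: [N, M_in], weight: [K, M_in, M_out], basis/weight_index: [N, S].
    N, S = basis.size(0), basis.size(1)
    output = src.new(N, weight.size(2)).zero_()
    for n in range(N):
        for s in range(S):
            w = weight[weight_index[n, s]]  # [M_in, M_out]
            output[n] += basis[n, s] * torch.mv(w.t(), src[n])
    return output

For tests[0] this yields [[0.5 * 22], [0.5 * 92]] = [[11.0], [46.0]], matching the 'output' entry.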
from .conv import spline_conv
__version__ = '0.1.0'
......
import torch
from torch.autograd import Function

from .utils.ffi import basis_forward as basis_fw
from .utils.ffi import basis_backward as basis_bw


def basis_forward(degree, pseudo, kernel_size, is_open_spline):
    num_nodes, S = pseudo.size(0), (degree + 1)**kernel_size.size(0)
    basis = pseudo.new(num_nodes, S)
    weight_index = kernel_size.new(num_nodes, S)
    basis_fw(degree, basis, weight_index, pseudo, kernel_size, is_open_spline)
    return basis, weight_index


def basis_backward(degree, grad_basis, pseudo, kernel_size, is_open_spline):
    grad_pseudo = pseudo.new(pseudo.size())
    basis_bw(degree, grad_pseudo, grad_basis, pseudo, kernel_size,
             is_open_spline)
    return grad_pseudo
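
As a quick sanity check of the sizes allocated above (hypothetical numbers): linear B-splines (degree 1) over a d-dimensional pseudo coordinate touch (degree + 1)**d kernel weights per node.

degree, d = 1, 2
S = (degree + 1) ** d  # = 4, matching the four basis/weight_index
                       # columns per node in the test data above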
@@ -34,8 +33,8 @@ class SplineBasis(Function):
                             self.is_open_spline)

    def backward(self, grad_basis, grad_weight_index):
        grad_pseudo = None
        pseudo, = self.saved_tensors

        if self.needs_input_grad[0]:
            grad_pseudo = basis_backward(self.degree, grad_basis, pseudo,
......
@@ -28,3 +28,23 @@ def basis_backward(degree, self, grad_basis, pseudo, kernel_size,
    name = '{}BasisBackward'.format(get_degree_str(degree))
    func = get_func(name, self.is_cuda, self)
    func(self, grad_basis, pseudo, kernel_size, is_open_spline)


def weighting_forward(self, src, weight, basis, weight_index):
    func = get_func('weightingForward', self.is_cuda, self)
    func(self, src, weight, basis, weight_index)


def weighting_backward_src(self, grad_output, weight, basis, weight_index):
    func = get_func('weightingBackwardSrc', self.is_cuda, self)
    func(self, grad_output, weight, basis, weight_index)


def weighting_backward_weight(self, grad_output, src, basis, weight_index):
    func = get_func('weightingBackwardWeight', self.is_cuda, self)
    func(self, grad_output, src, basis, weight_index)


def weighting_backward_basis(self, grad_output, src, weight, weight_index):
    func = get_func('weightingBackwardBasis', self.is_cuda, self)
    func(self, grad_output, src, weight, weight_index)
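
get_func itself is defined earlier in utils/ffi.py and is not part of this diff. A self-contained sketch of the string-keyed dispatch pattern it presumably implements is below; the registry and kernel names are stand-ins, not the real FFI symbols.

_kernels = {
    ('weightingForward', False): lambda *args: None,  # CPU stub
    ('weightingForward', True): lambda *args: None,   # CUDA stub
}

def get_func_sketch(name, is_cuda):
    # Hypothetical: the real get_func presumably also keys on the
    # tensor's scalar type (Float, Double, ...).
    return _kernels[(name, is_cuda)]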
import torch
from torch.autograd import Function

from .utils.ffi import weighting_forward as weighting_fw
from .utils.ffi import weighting_backward_src as weighting_bw_src
from .utils.ffi import weighting_backward_weight as weighting_bw_weight
from .utils.ffi import weighting_backward_basis as weighting_bw_basis


def weighting_forward(src, weight, basis, weight_index):
    output = src.new(src.size(0), weight.size(2))
    weighting_fw(output, src, weight, basis, weight_index)
    return output


def weighting_backward_src(grad_output, weight, basis, weight_index):
    grad_src = grad_output.new(grad_output.size(0), weight.size(1))
    weight = weight.transpose(1, 2).contiguous()  # Coalesced memory access.
    weighting_bw_src(grad_src, grad_output, weight, basis, weight_index)
    return grad_src


def weighting_backward_weight(grad_output, src, basis, weight_index, K):
    grad_weight = src.new(K, src.size(1), grad_output.size(1))
    weighting_bw_weight(grad_weight, grad_output, src, basis, weight_index)
    return grad_weight


def weighting_backward_basis(grad_output, src, weight, weight_index):
    grad_basis = src.new(weight_index.size())
    weighting_bw_basis(grad_basis, grad_output, src, weight, weight_index)
    return grad_basis
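
To see why weighting_backward_src transposes weight, note the chain rule: output[n] = sum_s basis[n, s] * src[n] @ W_s with W_s = weight[weight_index[n, s]], so grad_src[n] = sum_s basis[n, s] * W_s @ grad_output[n]. A dense reference sketch (hypothetical helper, for checking only, not the library's kernel):

def weighting_backward_src_reference(grad_output, weight, basis, weight_index):
    N, S = basis.size(0), basis.size(1)
    grad_src = grad_output.new(N, weight.size(1)).zero_()
    for n in range(N):
        for s in range(S):
            w = weight[weight_index[n, s]]  # [M_in, M_out]
            grad_src[n] += basis[n, s] * torch.mv(w, grad_output[n])
    return grad_src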


class SplineWeighting(Function):
    def __init__(self, weight_index):
        super(SplineWeighting, self).__init__()
        self.weight_index = weight_index

    def forward(self, src, weight, basis):
        self.save_for_backward(src, weight, basis)
        return weighting_forward(src, weight, basis, self.weight_index)

    def backward(self, grad_output):
        grad_src = grad_weight = grad_basis = None
        src, weight, basis = self.saved_tensors

        if self.needs_input_grad[0]:
            grad_src = weighting_backward_src(grad_output, weight, basis,
                                              self.weight_index)

        if self.needs_input_grad[1]:
            # Pass K = weight.size(0); the helper needs the kernel size
            # to allocate grad_weight.
            grad_weight = weighting_backward_weight(
                grad_output, src, basis, self.weight_index, weight.size(0))

        if self.needs_input_grad[2]:
            grad_basis = weighting_backward_basis(grad_output, src, weight,
                                                  self.weight_index)

        return grad_src, grad_weight, grad_basis


def spline_weighting(src, weight, basis, weight_index):
    if torch.is_tensor(src):
        return weighting_forward(src, weight, basis, weight_index)
    else:
        return SplineWeighting(weight_index)(src, weight, basis)
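
Hypothetical usage of the dispatch above under the pre-0.4 tensor/Variable split this code assumes (requires the compiled extension, shown for orientation only):

# out = spline_weighting(src, weight, basis, weight_index)  # plain tensors,
#                                                           # no autograd
# out = spline_weighting(Variable(src), Variable(weight),
#                        Variable(basis), weight_index)     # records the graph
# out.sum().backward()  # fills .grad on the input Variables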