"vscode:/vscode.git/clone" did not exist on "56e707bfccb62ada836d21e431d6db0d10dd73a1"
Commit 57f5a26e authored by rusty1s's avatar rusty1s
Browse files

coverage fix

parent cba43ac9
...@@ -12,34 +12,27 @@ def get_func(name, tensor): ...@@ -12,34 +12,27 @@ def get_func(name, tensor):
return getattr(module, name) return getattr(module, name)
def fw(pseudo, kernel_size, is_open_spline, degree):
op = get_func('{}_fw'.format(implemented_degrees[degree]), pseudo)
basis, weight_index = op(pseudo, kernel_size, is_open_spline)
return basis, weight_index
def bw(grad_basis, pseudo, kernel_size, is_open_spline, degree):
op = get_func('{}_bw'.format(implemented_degrees[degree]), pseudo)
grad_pseudo = op(grad_basis, pseudo, kernel_size, is_open_spline)
return grad_pseudo
class SplineBasis(torch.autograd.Function): class SplineBasis(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, pseudo, kernel_size, is_open_spline, degree): def forward(ctx, pseudo, kernel_size, is_open_spline, degree):
ctx.save_for_backward(pseudo) ctx.save_for_backward(pseudo)
ctx.kernel_size = kernel_size ctx.kernel_size, ctx.is_open_spline = kernel_size, is_open_spline
ctx.is_open_spline = is_open_spline
ctx.degree = degree ctx.degree = degree
return fw(pseudo, kernel_size, is_open_spline, degree)
op = get_func('{}_fw'.format(implemented_degrees[degree]), pseudo)
basis, weight_index = op(pseudo, kernel_size, is_open_spline)
return basis, weight_index
@staticmethod @staticmethod
def backward(ctx, grad_basis, grad_weight_index): def backward(ctx, grad_basis, grad_weight_index):
pseudo, = ctx.saved_tensors pseudo, = ctx.saved_tensors
kernel_size, is_open_spline = ctx.kernel_size, ctx.is_open_spline
degree = ctx.degree
grad_pseudo = None grad_pseudo = None
if ctx.needs_input_grad[0]: if ctx.needs_input_grad[0]:
grad_pseudo = bw(grad_basis, pseudo, ctx.kernel_size, op = get_func('{}_bw'.format(implemented_degrees[degree]), pseudo)
ctx.is_open_spline, ctx.degree) grad_pseudo = op(grad_basis, pseudo, kernel_size, is_open_spline)
return grad_pseudo, None, None, None return grad_pseudo, None, None, None
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment