import torch @torch.jit.script def spline_weighting(x: torch.Tensor, weight: torch.Tensor, basis: torch.Tensor, weight_index: torch.Tensor) -> torch.Tensor: return torch.ops.spline_conv.spline_weighting(x, weight, basis, weight_index) # class SplineWeighting(torch.autograd.Function): # @staticmethod # def forward(ctx, x, weight, basis, weight_index): # ctx.weight_index = weight_index # ctx.save_for_backward(x, weight, basis) # op = get_func('weighting_fw', x) # out = op(x, weight, basis, weight_index) # return out # @staticmethod # def backward(ctx, grad_out): # x, weight, basis = ctx.saved_tensors # grad_x = grad_weight = grad_basis = None # if ctx.needs_input_grad[0]: # op = get_func('weighting_bw_x', x) # grad_x = op(grad_out, weight, basis, ctx.weight_index) # if ctx.needs_input_grad[1]: # op = get_func('weighting_bw_w', x) # grad_weight = op(grad_out, x, basis, ctx.weight_index, # weight.size(0)) # if ctx.needs_input_grad[2]: # op = get_func('weighting_bw_b', x) # grad_basis = op(grad_out, x, weight, ctx.weight_index) # return grad_x, grad_weight, grad_basis, None