Unverified Commit cd7b1988 authored by Matthias Fey's avatar Matthias Fey Committed by GitHub
Browse files

Merge pull request #14 from rusty1s/tracing

prepare tracing
parents d3169766 32224979
import torch
def degree(index, num_nodes=None, dtype=None, device=None):
num_nodes = index.max().item() + 1 if num_nodes is None else num_nodes
out = torch.zeros((num_nodes), dtype=dtype, device=device)
return out.scatter_add_(0, index, out.new_ones((index.size(0))))
import torch import torch
import torch_spline_conv.weighting_cpu
if torch.cuda.is_available():
import torch_spline_conv.weighting_cuda
@torch.jit.script
def get_func(name, tensor): def spline_weighting(x: torch.Tensor, weight: torch.Tensor,
if tensor.is_cuda: basis: torch.Tensor,
return getattr(torch_spline_conv.weighting_cuda, name) weight_index: torch.Tensor) -> torch.Tensor:
else: return torch.ops.torch_spline_conv.spline_weighting(
return getattr(torch_spline_conv.weighting_cpu, name) x, weight, basis, weight_index)
class SplineWeighting(torch.autograd.Function):
@staticmethod
def forward(ctx, x, weight, basis, weight_index):
ctx.weight_index = weight_index
ctx.save_for_backward(x, weight, basis)
op = get_func('weighting_fw', x)
out = op(x, weight, basis, weight_index)
return out
@staticmethod
def backward(ctx, grad_out):
x, weight, basis = ctx.saved_tensors
grad_x = grad_weight = grad_basis = None
if ctx.needs_input_grad[0]:
op = get_func('weighting_bw_x', x)
grad_x = op(grad_out, weight, basis, ctx.weight_index)
if ctx.needs_input_grad[1]:
op = get_func('weighting_bw_w', x)
grad_weight = op(grad_out, x, basis, ctx.weight_index,
weight.size(0))
if ctx.needs_input_grad[2]:
op = get_func('weighting_bw_b', x)
grad_basis = op(grad_out, x, weight, ctx.weight_index)
return grad_x, grad_weight, grad_basis, None
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment