Commit ac26fc19 authored by rusty1s

prepare tracing

parent d3169766
import torch
from typing import Optional

from .basis import spline_basis
from .weighting import spline_weighting


@torch.jit.script
def spline_conv(x: torch.Tensor, edge_index: torch.Tensor,
                pseudo: torch.Tensor, weight: torch.Tensor,
                kernel_size: torch.Tensor, is_open_spline: torch.Tensor,
                degree: int = 1, norm: bool = True,
                root_weight: Optional[torch.Tensor] = None,
                bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    r"""Applies the spline-based convolution operator :math:`(f \star g)(i) =
    \frac{1}{|\mathcal{N}(i)|} \sum_{l=1}^{M_{in}} \sum_{j \in \mathcal{N}(i)}
    f_l(j) \cdot g_l(u(i, j))` over several node features of an input graph.

    :rtype: :class:`Tensor`
    """
    x = x.unsqueeze(-1) if x.dim() == 1 else x
    pseudo = pseudo.unsqueeze(-1) if pseudo.dim() == 1 else pseudo

    row, col = edge_index[0], edge_index[1]
    N, E, M_out = x.size(0), row.size(0), weight.size(2)

    # Weight each node.
    basis, weight_index = spline_basis(pseudo, kernel_size, is_open_spline,
                                       degree)
    out = spline_weighting(x[col], weight, basis, weight_index)

    # Convert E x M_out to N x M_out features.
    row_expand = row.unsqueeze(-1).expand_as(out)
    out = x.new_zeros((N, M_out)).scatter_add_(0, row_expand, out)

    # Normalize out by node degree (if wished).
    if norm:
        deg = out.new_zeros(N).scatter_add_(0, row, out.new_ones(E))
        out = out / deg.unsqueeze(-1).clamp_(min=1)

    # Weight root node separately (if wished).
    if root_weight is not None:
        out = out + torch.matmul(x, root_weight)

    # Add bias (if wished).
    if bias is not None:
        out = out + bias

    return out
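For orientation, here is a minimal usage sketch of `spline_conv` (not part of the commit). It assumes the compiled CPU/CUDA extensions are built and that the function is re-exported from the package root; the graph sizes and feature dimensions are illustrative: 4 nodes, 5 edges, 2D pseudo-coordinates, and a 5 x 5 B-spline kernel mapping 2 input to 4 output features, so `weight.size(0)` must equal the product of `kernel_size`.

import torch
from torch_spline_conv import spline_conv  # assumed package-root export

x = torch.rand(4, 2)                             # 4 nodes, 2 features each
edge_index = torch.tensor([[0, 1, 1, 2, 3],      # row (target) indices
                           [1, 0, 2, 3, 2]])     # col (source) indices
pseudo = torch.rand(5, 2)                        # one 2D pseudo-coordinate per edge
kernel_size = torch.tensor([5, 5])               # 5 weights per edge dimension
is_open_spline = torch.tensor([1, 1], dtype=torch.uint8)
weight = torch.rand(25, 2, 4)                    # 5 * 5 kernels, 2 -> 4 features

out = spline_conv(x, edge_index, pseudo, weight, kernel_size,
                  is_open_spline)
print(out.size())  # torch.Size([4, 4])

With `norm=True`, each aggregated row is divided by the number of times its index occurs in `row` (clamped to at least 1), which realizes the :math:`\frac{1}{|\mathcal{N}(i)|}` factor from the docstring.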
import torch


def degree(index, num_nodes=None, dtype=None, device=None):
    num_nodes = index.max().item() + 1 if num_nodes is None else num_nodes
    out = torch.zeros(num_nodes, dtype=dtype, device=device)
    return out.scatter_add_(0, index, out.new_ones(index.size(0)))
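A quick sanity check of this helper (a sketch, not from the commit): it counts how often each node index occurs, which is the same per-node degree that the `norm` branch of `spline_conv` now computes inline via `scatter_add_`.

import torch

row = torch.tensor([0, 0, 1, 2])   # node 0 appears twice, node 3 never
print(degree(row, num_nodes=4))    # tensor([2., 1., 1., 0.])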
import torch

import torch_spline_conv.weighting_cpu

if torch.cuda.is_available():
    import torch_spline_conv.weighting_cuda


@torch.jit.script
def spline_weighting(x: torch.Tensor, weight: torch.Tensor,
                     basis: torch.Tensor,
                     weight_index: torch.Tensor) -> torch.Tensor:
    return torch.ops.spline_conv.spline_weighting(x, weight, basis,
                                                  weight_index)


def get_func(name, tensor):
    if tensor.is_cuda:
        return getattr(torch_spline_conv.weighting_cuda, name)
    else:
        return getattr(torch_spline_conv.weighting_cpu, name)


# class SplineWeighting(torch.autograd.Function):
#     @staticmethod
#     def forward(ctx, x, weight, basis, weight_index):
#         ctx.weight_index = weight_index
#         ctx.save_for_backward(x, weight, basis)
#         op = get_func('weighting_fw', x)
#         out = op(x, weight, basis, weight_index)
#         return out
#
#     @staticmethod
#     def backward(ctx, grad_out):
#         x, weight, basis = ctx.saved_tensors
#         grad_x = grad_weight = grad_basis = None
#
#         if ctx.needs_input_grad[0]:
#             op = get_func('weighting_bw_x', x)
#             grad_x = op(grad_out, weight, basis, ctx.weight_index)
#
#         if ctx.needs_input_grad[1]:
#             op = get_func('weighting_bw_w', x)
#             grad_weight = op(grad_out, x, basis, ctx.weight_index,
#                              weight.size(0))
#
#         if ctx.needs_input_grad[2]:
#             op = get_func('weighting_bw_b', x)
#             grad_basis = op(grad_out, x, weight, ctx.weight_index)
#
#         return grad_x, grad_weight, grad_basis, None
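The commented-out `SplineWeighting` class is the old Python-side autograd path; the scripted wrapper instead dispatches to the `spline_conv::spline_weighting` custom op so the function can be traced. A gradient-check sketch (not part of the commit) is shown below; it assumes the compiled extension registers both forward and backward for that op, and the sizes `E`, `S`, `K`, `M_in`, `M_out` are purely illustrative: `E` edge feature rows, each combined with `S` kernel matrices selected by `weight_index` and weighted by the B-spline `basis` values.

import torch
from torch.autograd import gradcheck

E, S, K, M_in, M_out = 8, 4, 25, 2, 4  # edges, basis terms, kernels, features

x = torch.rand(E, M_in, dtype=torch.double, requires_grad=True)
weight = torch.rand(K, M_in, M_out, dtype=torch.double, requires_grad=True)
basis = torch.rand(E, S, dtype=torch.double, requires_grad=True)
weight_index = torch.randint(K, (E, S), dtype=torch.long)

# Prints True if the op's analytic gradients match finite differences.
print(gradcheck(spline_weighting, (x, weight, basis, weight_index)))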