Commit ac26fc19 authored by rusty1s's avatar rusty1s
Browse files

prepare tracing

parent d3169766
import torch from typing import Optional
from .basis import SplineBasis import torch
from .weighting import SplineWeighting
from .utils.degree import degree as node_degree from .basis import spline_basis
from .weighting import spline_weighting
class SplineConv(object): @torch.jit.script
def spline_conv(x: torch.Tensor, edge_index: torch.Tensor,
pseudo: torch.Tensor, weight: torch.Tensor,
kernel_size: torch.Tensor, is_open_spline: torch.Tensor,
degree: int = 1, norm: bool = True,
root_weight: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
r"""Applies the spline-based convolution operator :math:`(f \star g)(i) = r"""Applies the spline-based convolution operator :math:`(f \star g)(i) =
\frac{1}{|\mathcal{N}(i)|} \sum_{l=1}^{M_{in}} \sum_{j \in \mathcal{N}(i)} \frac{1}{|\mathcal{N}(i)|} \sum_{l=1}^{M_{in}} \sum_{j \in \mathcal{N}(i)}
f_l(j) \cdot g_l(u(i, j))` over several node features of an input graph. f_l(j) \cdot g_l(u(i, j))` over several node features of an input graph.
...@@ -38,37 +44,34 @@ class SplineConv(object): ...@@ -38,37 +44,34 @@ class SplineConv(object):
:rtype: :class:`Tensor` :rtype: :class:`Tensor`
""" """
@staticmethod
def apply(x, edge_index, pseudo, weight, kernel_size, is_open_spline,
degree=1, norm=True, root_weight=None, bias=None):
x = x.unsqueeze(-1) if x.dim() == 1 else x x = x.unsqueeze(-1) if x.dim() == 1 else x
pseudo = pseudo.unsqueeze(-1) if pseudo.dim() == 1 else pseudo pseudo = pseudo.unsqueeze(-1) if pseudo.dim() == 1 else pseudo
row, col = edge_index
N, E, M_out = x.size(0), row.size(0), weight.size(2)
row, col = edge_index # Weight each node.
n, m_out = x.size(0), weight.size(2) basis, weight_index = spline_basis(pseudo, kernel_size, is_open_spline,
degree)
# Weight each node. out = spline_weighting(x[col], weight, basis, weight_index)
basis, weight_index = SplineBasis.apply(pseudo, kernel_size,
is_open_spline, degree)
weight_index = weight_index.detach()
out = SplineWeighting.apply(x[col], weight, basis, weight_index)
# Convert e x m_out to n x m_out features. # Convert E x M_out to N x M_out features.
row_expand = row.unsqueeze(-1).expand_as(out) row_expanded = row.unsqueeze(-1).expand_as(out)
out = x.new_zeros((n, m_out)).scatter_add_(0, row_expand, out) out = x.new_zeros((N, M_out)).scatter_add_(0, row_expanded, out)
# Normalize out by node degree (if wished). # Normalize out by node degree (if wished).
if norm: if norm:
deg = node_degree(row, n, out.dtype, out.device) deg = out.new_zeros(N).scatter_add_(0, row, out.new_ones(E))
out = out / deg.unsqueeze(-1).clamp(min=1) out = out / deg.unsqueeze(-1).clamp_(min=1)
# Weight root node separately (if wished). # Weight root node separately (if wished).
if root_weight is not None: if root_weight is not None:
out = out + torch.mm(x, root_weight) out = out + torch.matmul(x, root_weight)
# Add bias (if wished). # Add bias (if wished).
if bias is not None: if bias is not None:
out = out + bias out = out + bias
return out return out
import torch
def degree(index, num_nodes=None, dtype=None, device=None):
num_nodes = index.max().item() + 1 if num_nodes is None else num_nodes
out = torch.zeros((num_nodes), dtype=dtype, device=device)
return out.scatter_add_(0, index, out.new_ones((index.size(0))))
import torch import torch
import torch_spline_conv.weighting_cpu
if torch.cuda.is_available():
import torch_spline_conv.weighting_cuda
@torch.jit.script
def spline_weighting(x: torch.Tensor, weight: torch.Tensor,
basis: torch.Tensor,
weight_index: torch.Tensor) -> torch.Tensor:
return torch.ops.spline_conv.spline_weighting(x, weight, basis,
weight_index)
def get_func(name, tensor):
if tensor.is_cuda:
return getattr(torch_spline_conv.weighting_cuda, name)
else:
return getattr(torch_spline_conv.weighting_cpu, name)
# class SplineWeighting(torch.autograd.Function):
# @staticmethod
# def forward(ctx, x, weight, basis, weight_index):
# ctx.weight_index = weight_index
# ctx.save_for_backward(x, weight, basis)
# op = get_func('weighting_fw', x)
# out = op(x, weight, basis, weight_index)
# return out
class SplineWeighting(torch.autograd.Function): # @staticmethod
@staticmethod # def backward(ctx, grad_out):
def forward(ctx, x, weight, basis, weight_index): # x, weight, basis = ctx.saved_tensors
ctx.weight_index = weight_index # grad_x = grad_weight = grad_basis = None
ctx.save_for_backward(x, weight, basis)
op = get_func('weighting_fw', x)
out = op(x, weight, basis, weight_index)
return out
@staticmethod # if ctx.needs_input_grad[0]:
def backward(ctx, grad_out): # op = get_func('weighting_bw_x', x)
x, weight, basis = ctx.saved_tensors # grad_x = op(grad_out, weight, basis, ctx.weight_index)
grad_x = grad_weight = grad_basis = None
if ctx.needs_input_grad[0]: # if ctx.needs_input_grad[1]:
op = get_func('weighting_bw_x', x) # op = get_func('weighting_bw_w', x)
grad_x = op(grad_out, weight, basis, ctx.weight_index) # grad_weight = op(grad_out, x, basis, ctx.weight_index,
# weight.size(0))
if ctx.needs_input_grad[1]: # if ctx.needs_input_grad[2]:
op = get_func('weighting_bw_w', x) # op = get_func('weighting_bw_b', x)
grad_weight = op(grad_out, x, basis, ctx.weight_index, # grad_basis = op(grad_out, x, weight, ctx.weight_index)
weight.size(0))
if ctx.needs_input_grad[2]: # return grad_x, grad_weight, grad_basis, None
op = get_func('weighting_bw_b', x)
grad_basis = op(grad_out, x, weight, ctx.weight_index)
return grad_x, grad_weight, grad_basis, None
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment