Commit c9169a69 authored by rusty1s's avatar rusty1s
Browse files

added all boilerplate

parent 06eac75e
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
#include <torch/script.h> #include <torch/script.h>
#include "cpu/basis_cpu.h" #include "cpu/basis_cpu.h"
#include "utils.h"
#ifdef WITH_CUDA #ifdef WITH_CUDA
#include "cuda/basis_cuda.h" #include "cuda/basis_cuda.h"
...@@ -63,7 +62,7 @@ public: ...@@ -63,7 +62,7 @@ public:
auto grad_basis = grad_outs[0]; auto grad_basis = grad_outs[0];
auto saved = ctx->get_saved_variables(); auto saved = ctx->get_saved_variables();
auto pseudo = saved[0], kernel_size = saved[1], is_open_spline = saved[2]; auto pseudo = saved[0], kernel_size = saved[1], is_open_spline = saved[2];
auto gree = ctx->saved_data["degree"].toInt(); auto degree = ctx->saved_data["degree"].toInt();
auto grad_pseudo = spline_basis_bw(grad_basis, pseudo, kernel_size, auto grad_pseudo = spline_basis_bw(grad_basis, pseudo, kernel_size,
is_open_spline, degree); is_open_spline, degree);
return {grad_pseudo, Variable(), Variable(), Variable()}; return {grad_pseudo, Variable(), Variable(), Variable()};
...@@ -73,7 +72,8 @@ public: ...@@ -73,7 +72,8 @@ public:
std::tuple<torch::Tensor, torch::Tensor> std::tuple<torch::Tensor, torch::Tensor>
spline_basis(torch::Tensor pseudo, torch::Tensor kernel_size, spline_basis(torch::Tensor pseudo, torch::Tensor kernel_size,
torch::Tensor is_open_spline, int64_t degree) { torch::Tensor is_open_spline, int64_t degree) {
return SplineBasis::apply(pseudo, kernel_size, is_open_spline, degree); auto result = SplineBasis::apply(pseudo, kernel_size, is_open_spline, degree);
return std::make_tuple(result[0], result[1]);
} }
static auto registry = torch::RegisterOperators().op( static auto registry = torch::RegisterOperators().op(
......
#include "basis_cpu.h"
#include "utils.h"
std::tuple<torch::Tensor, torch::Tensor>
spline_basis_fw_cpu(torch::Tensor pseudo, torch::Tensor kernel_size,
torch::Tensor is_open_spline, int64_t degree) {
return std::make_tuple(pseudo, kernel_size);
}
torch::Tensor spline_basis_bw_cpu(torch::Tensor grad_basis,
torch::Tensor pseudo,
torch::Tensor kernel_size,
torch::Tensor is_open_spline,
int64_t degree) {
return grad_basis;
}
#pragma once
#include <torch/extension.h>
std::tuple<torch::Tensor, torch::Tensor>
spline_basis_fw_cpu(torch::Tensor pseudo, torch::Tensor kernel_size,
torch::Tensor is_open_spline, int64_t degree);
torch::Tensor spline_basis_bw_cpu(torch::Tensor grad_basis,
torch::Tensor pseudo,
torch::Tensor kernel_size,
torch::Tensor is_open_spline, int64_t degree);
#pragma once
#include <torch/extension.h>
#define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor")
#define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch")
#include "weighting_cpu.h"
#include "utils.h"
torch::Tensor spline_weighting_fw_cpu(torch::Tensor x, torch::Tensor weight,
torch::Tensor basis,
torch::Tensor weight_index) {
return x;
}
torch::Tensor spline_weighting_bw_x_cpu(torch::Tensor grad_out,
torch::Tensor weight,
torch::Tensor basis,
torch::Tensor weight_index) {
return grad_out;
}
torch::Tensor spline_weighting_bw_weight_cpu(torch::Tensor grad_out,
torch::Tensor x,
torch::Tensor basis,
torch::Tensor weight_index,
int64_t kernel_size) {
return grad_out;
}
torch::Tensor spline_weighting_bw_basis_cpu(torch::Tensor grad_out,
torch::Tensor x,
torch::Tensor weight,
torch::Tensor weight_index) {
return grad_out;
}
#pragma once
#include <torch/extension.h>
torch::Tensor spline_weighting_fw_cpu(torch::Tensor x, torch::Tensor weight,
torch::Tensor basis,
torch::Tensor weight_index);
torch::Tensor spline_weighting_bw_x_cpu(torch::Tensor grad_out,
torch::Tensor weight,
torch::Tensor basis,
torch::Tensor weight_index);
torch::Tensor spline_weighting_bw_weight_cpu(torch::Tensor grad_out,
torch::Tensor x,
torch::Tensor basis,
torch::Tensor weight_index,
int64_t kernel_size);
torch::Tensor spline_weighting_bw_basis_cpu(torch::Tensor grad_out,
torch::Tensor x,
torch::Tensor weight,
torch::Tensor weight_index);
#include "basis_cuda.h"
#include "utils.cuh"
std::tuple<torch::Tensor, torch::Tensor>
spline_basis_fw_cpu(torch::Tensor pseudo, torch::Tensor kernel_size,
torch::Tensor is_open_spline, int64_t degree) {
return std::make_tuple(pseudo, kernel_size);
}
torch::Tensor spline_basis_bw_cpu(torch::Tensor grad_basis,
torch::Tensor pseudo,
torch::Tensor kernel_size,
torch::Tensor is_open_spline,
int64_t degree) {
return grad_basis;
}
#pragma once
#include <torch/extension.h>
std::tuple<torch::Tensor, torch::Tensor>
spline_basis_fw_cuda(torch::Tensor pseudo, torch::Tensor kernel_size,
torch::Tensor is_open_spline, int64_t degree);
torch::Tensor spline_basis_bw_cuda(torch::Tensor grad_basis,
torch::Tensor pseudo,
torch::Tensor kernel_size,
torch::Tensor is_open_spline,
int64_t degree);
#pragma once
#include <torch/extension.h>
#define CHECK_CUDA(x) \
AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")
#define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch")
#include "weighting_cpu.h"
#include "utils.cuh"
torch::Tensor spline_weighting_fw_cuda(torch::Tensor x, torch::Tensor weight,
torch::Tensor basis,
torch::Tensor weight_index) {
return x;
}
torch::Tensor spline_weighting_bw_x_cuda(torch::Tensor grad_out,
torch::Tensor weight,
torch::Tensor basis,
torch::Tensor weight_index) {
return grad_out;
}
torch::Tensor spline_weighting_bw_weight_cuda(torch::Tensor grad_out,
torch::Tensor x,
torch::Tensor basis,
torch::Tensor weight_index,
int64_t kernel_size) {
return grad_out;
}
torch::Tensor spline_weighting_bw_basis_cuda(torch::Tensor grad_out,
torch::Tensor x,
torch::Tensor weight,
torch::Tensor weight_index) {
return grad_out;
}
#pragma once
#include <torch/extension.h>
torch::Tensor spline_weighting_fw_cuda(torch::Tensor x, torch::Tensor weight,
torch::Tensor basis,
torch::Tensor weight_index);
torch::Tensor spline_weighting_bw_x_cuda(torch::Tensor grad_out,
torch::Tensor weight,
torch::Tensor basis,
torch::Tensor weight_index);
torch::Tensor spline_weighting_bw_weight_cuda(torch::Tensor grad_out,
torch::Tensor x,
torch::Tensor basis,
torch::Tensor weight_index,
int64_t kernel_size);
torch::Tensor spline_weighting_bw_basis_cuda(torch::Tensor grad_out,
torch::Tensor x,
torch::Tensor weight,
torch::Tensor weight_index);
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
#include <torch/script.h> #include <torch/script.h>
#include "cpu/weighting_cpu.h" #include "cpu/weighting_cpu.h"
#include "utils.h"
#ifdef WITH_CUDA #ifdef WITH_CUDA
#include "cuda/weighting_cuda.h" #include "cuda/weighting_cuda.h"
...@@ -114,7 +113,7 @@ public: ...@@ -114,7 +113,7 @@ public:
torch::Tensor spline_weighting(torch::Tensor x, torch::Tensor weight, torch::Tensor spline_weighting(torch::Tensor x, torch::Tensor weight,
torch::Tensor basis, torch::Tensor basis,
torch::Tensor weight_index) { torch::Tensor weight_index) {
return SplineWeighting::apply(x, weight, basis, weight_index); return SplineWeighting::apply(x, weight, basis, weight_index)[0];
} }
static auto registry = torch::RegisterOperators().op( static auto registry = torch::RegisterOperators().op(
......
...@@ -45,7 +45,7 @@ def get_extensions(): ...@@ -45,7 +45,7 @@ def get_extensions():
sources += [path] sources += [path]
extension = Extension( extension = Extension(
'torch_scatter._' + name, 'torch_spline_conv._' + name,
sources, sources,
include_dirs=[extensions_dir], include_dirs=[extensions_dir],
define_macros=define_macros, define_macros=define_macros,
......
...@@ -3,7 +3,7 @@ from itertools import product ...@@ -3,7 +3,7 @@ from itertools import product
import pytest import pytest
import torch import torch
from torch.autograd import gradcheck from torch.autograd import gradcheck
from torch_spline_conv import spline_weighting, spline_basis from torch_spline_conv import spline_basis, spline_weighting
from .utils import dtypes, devices, tensor from .utils import dtypes, devices, tensor
......
...@@ -48,7 +48,7 @@ def spline_conv(x: torch.Tensor, edge_index: torch.Tensor, ...@@ -48,7 +48,7 @@ def spline_conv(x: torch.Tensor, edge_index: torch.Tensor,
x = x.unsqueeze(-1) if x.dim() == 1 else x x = x.unsqueeze(-1) if x.dim() == 1 else x
pseudo = pseudo.unsqueeze(-1) if pseudo.dim() == 1 else pseudo pseudo = pseudo.unsqueeze(-1) if pseudo.dim() == 1 else pseudo
row, col = edge_index row, col = edge_index[0], edge_index[1]
N, E, M_out = x.size(0), row.size(0), weight.size(2) N, E, M_out = x.size(0), row.size(0), weight.size(2)
# Weight each node. # Weight each node.
...@@ -63,7 +63,8 @@ def spline_conv(x: torch.Tensor, edge_index: torch.Tensor, ...@@ -63,7 +63,8 @@ def spline_conv(x: torch.Tensor, edge_index: torch.Tensor,
# Normalize out by node degree (if wished). # Normalize out by node degree (if wished).
if norm: if norm:
deg = out.new_zeros(N).scatter_add_(0, row, out.new_ones(E)) ones = torch.ones(E, dtype=x.dtype, device=x.device)
deg = out.new_zeros(N).scatter_add_(0, row, ones)
out = out / deg.unsqueeze(-1).clamp_(min=1) out = out / deg.unsqueeze(-1).clamp_(min=1)
# Weight root node separately (if wished). # Weight root node separately (if wished).
......
...@@ -5,5 +5,5 @@ import torch ...@@ -5,5 +5,5 @@ import torch
def spline_weighting(x: torch.Tensor, weight: torch.Tensor, def spline_weighting(x: torch.Tensor, weight: torch.Tensor,
basis: torch.Tensor, basis: torch.Tensor,
weight_index: torch.Tensor) -> torch.Tensor: weight_index: torch.Tensor) -> torch.Tensor:
return torch.ops.spline_conv.spline_weighting(x, weight, basis, return torch.ops.torch_spline_conv.spline_weighting(
weight_index) x, weight, basis, weight_index)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment