Commit c9169a69 authored by rusty1s's avatar rusty1s
Browse files

added all boilerplate

parent 06eac75e
......@@ -2,7 +2,6 @@
#include <torch/script.h>
#include "cpu/basis_cpu.h"
#include "utils.h"
#ifdef WITH_CUDA
#include "cuda/basis_cuda.h"
......@@ -63,7 +62,7 @@ public:
auto grad_basis = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto pseudo = saved[0], kernel_size = saved[1], is_open_spline = saved[2];
auto gree = ctx->saved_data["degree"].toInt();
auto degree = ctx->saved_data["degree"].toInt();
auto grad_pseudo = spline_basis_bw(grad_basis, pseudo, kernel_size,
is_open_spline, degree);
return {grad_pseudo, Variable(), Variable(), Variable()};
......@@ -73,7 +72,8 @@ public:
std::tuple<torch::Tensor, torch::Tensor>
spline_basis(torch::Tensor pseudo, torch::Tensor kernel_size,
torch::Tensor is_open_spline, int64_t degree) {
return SplineBasis::apply(pseudo, kernel_size, is_open_spline, degree);
auto result = SplineBasis::apply(pseudo, kernel_size, is_open_spline, degree);
return std::make_tuple(result[0], result[1]);
}
static auto registry = torch::RegisterOperators().op(
......
#include "basis_cpu.h"
#include "utils.h"
std::tuple<torch::Tensor, torch::Tensor>
spline_basis_fw_cpu(torch::Tensor pseudo, torch::Tensor kernel_size,
torch::Tensor is_open_spline, int64_t degree) {
return std::make_tuple(pseudo, kernel_size);
}
torch::Tensor spline_basis_bw_cpu(torch::Tensor grad_basis,
torch::Tensor pseudo,
torch::Tensor kernel_size,
torch::Tensor is_open_spline,
int64_t degree) {
return grad_basis;
}
#pragma once
#include <torch/extension.h>
std::tuple<torch::Tensor, torch::Tensor>
spline_basis_fw_cpu(torch::Tensor pseudo, torch::Tensor kernel_size,
torch::Tensor is_open_spline, int64_t degree);
torch::Tensor spline_basis_bw_cpu(torch::Tensor grad_basis,
torch::Tensor pseudo,
torch::Tensor kernel_size,
torch::Tensor is_open_spline, int64_t degree);
#pragma once
#include <torch/extension.h>
#define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor")
#define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch")
#include "weighting_cpu.h"
#include "utils.h"
torch::Tensor spline_weighting_fw_cpu(torch::Tensor x, torch::Tensor weight,
torch::Tensor basis,
torch::Tensor weight_index) {
return x;
}
torch::Tensor spline_weighting_bw_x_cpu(torch::Tensor grad_out,
torch::Tensor weight,
torch::Tensor basis,
torch::Tensor weight_index) {
return grad_out;
}
torch::Tensor spline_weighting_bw_weight_cpu(torch::Tensor grad_out,
torch::Tensor x,
torch::Tensor basis,
torch::Tensor weight_index,
int64_t kernel_size) {
return grad_out;
}
torch::Tensor spline_weighting_bw_basis_cpu(torch::Tensor grad_out,
torch::Tensor x,
torch::Tensor weight,
torch::Tensor weight_index) {
return grad_out;
}
#pragma once
#include <torch/extension.h>
torch::Tensor spline_weighting_fw_cpu(torch::Tensor x, torch::Tensor weight,
torch::Tensor basis,
torch::Tensor weight_index);
torch::Tensor spline_weighting_bw_x_cpu(torch::Tensor grad_out,
torch::Tensor weight,
torch::Tensor basis,
torch::Tensor weight_index);
torch::Tensor spline_weighting_bw_weight_cpu(torch::Tensor grad_out,
torch::Tensor x,
torch::Tensor basis,
torch::Tensor weight_index,
int64_t kernel_size);
torch::Tensor spline_weighting_bw_basis_cpu(torch::Tensor grad_out,
torch::Tensor x,
torch::Tensor weight,
torch::Tensor weight_index);
#include "basis_cuda.h"
#include "utils.cuh"
std::tuple<torch::Tensor, torch::Tensor>
spline_basis_fw_cpu(torch::Tensor pseudo, torch::Tensor kernel_size,
torch::Tensor is_open_spline, int64_t degree) {
return std::make_tuple(pseudo, kernel_size);
}
torch::Tensor spline_basis_bw_cpu(torch::Tensor grad_basis,
torch::Tensor pseudo,
torch::Tensor kernel_size,
torch::Tensor is_open_spline,
int64_t degree) {
return grad_basis;
}
#pragma once
#include <torch/extension.h>
std::tuple<torch::Tensor, torch::Tensor>
spline_basis_fw_cuda(torch::Tensor pseudo, torch::Tensor kernel_size,
torch::Tensor is_open_spline, int64_t degree);
torch::Tensor spline_basis_bw_cuda(torch::Tensor grad_basis,
torch::Tensor pseudo,
torch::Tensor kernel_size,
torch::Tensor is_open_spline,
int64_t degree);
#pragma once
#include <torch/extension.h>
#define CHECK_CUDA(x) \
AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")
#define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch")
#include "weighting_cpu.h"
#include "utils.cuh"
torch::Tensor spline_weighting_fw_cuda(torch::Tensor x, torch::Tensor weight,
torch::Tensor basis,
torch::Tensor weight_index) {
return x;
}
torch::Tensor spline_weighting_bw_x_cuda(torch::Tensor grad_out,
torch::Tensor weight,
torch::Tensor basis,
torch::Tensor weight_index) {
return grad_out;
}
torch::Tensor spline_weighting_bw_weight_cuda(torch::Tensor grad_out,
torch::Tensor x,
torch::Tensor basis,
torch::Tensor weight_index,
int64_t kernel_size) {
return grad_out;
}
torch::Tensor spline_weighting_bw_basis_cuda(torch::Tensor grad_out,
torch::Tensor x,
torch::Tensor weight,
torch::Tensor weight_index) {
return grad_out;
}
#pragma once
#include <torch/extension.h>
torch::Tensor spline_weighting_fw_cuda(torch::Tensor x, torch::Tensor weight,
torch::Tensor basis,
torch::Tensor weight_index);
torch::Tensor spline_weighting_bw_x_cuda(torch::Tensor grad_out,
torch::Tensor weight,
torch::Tensor basis,
torch::Tensor weight_index);
torch::Tensor spline_weighting_bw_weight_cuda(torch::Tensor grad_out,
torch::Tensor x,
torch::Tensor basis,
torch::Tensor weight_index,
int64_t kernel_size);
torch::Tensor spline_weighting_bw_basis_cuda(torch::Tensor grad_out,
torch::Tensor x,
torch::Tensor weight,
torch::Tensor weight_index);
......@@ -2,7 +2,6 @@
#include <torch/script.h>
#include "cpu/weighting_cpu.h"
#include "utils.h"
#ifdef WITH_CUDA
#include "cuda/weighting_cuda.h"
......@@ -114,7 +113,7 @@ public:
torch::Tensor spline_weighting(torch::Tensor x, torch::Tensor weight,
torch::Tensor basis,
torch::Tensor weight_index) {
return SplineWeighting::apply(x, weight, basis, weight_index);
return SplineWeighting::apply(x, weight, basis, weight_index)[0];
}
static auto registry = torch::RegisterOperators().op(
......
......@@ -45,7 +45,7 @@ def get_extensions():
sources += [path]
extension = Extension(
'torch_scatter._' + name,
'torch_spline_conv._' + name,
sources,
include_dirs=[extensions_dir],
define_macros=define_macros,
......
......@@ -3,7 +3,7 @@ from itertools import product
import pytest
import torch
from torch.autograd import gradcheck
from torch_spline_conv import spline_weighting, spline_basis
from torch_spline_conv import spline_basis, spline_weighting
from .utils import dtypes, devices, tensor
......
......@@ -48,7 +48,7 @@ def spline_conv(x: torch.Tensor, edge_index: torch.Tensor,
x = x.unsqueeze(-1) if x.dim() == 1 else x
pseudo = pseudo.unsqueeze(-1) if pseudo.dim() == 1 else pseudo
row, col = edge_index
row, col = edge_index[0], edge_index[1]
N, E, M_out = x.size(0), row.size(0), weight.size(2)
# Weight each node.
......@@ -63,7 +63,8 @@ def spline_conv(x: torch.Tensor, edge_index: torch.Tensor,
# Normalize out by node degree (if wished).
if norm:
deg = out.new_zeros(N).scatter_add_(0, row, out.new_ones(E))
ones = torch.ones(E, dtype=x.dtype, device=x.device)
deg = out.new_zeros(N).scatter_add_(0, row, ones)
out = out / deg.unsqueeze(-1).clamp_(min=1)
# Weight root node separately (if wished).
......
......@@ -5,5 +5,5 @@ import torch
def spline_weighting(x: torch.Tensor, weight: torch.Tensor,
basis: torch.Tensor,
weight_index: torch.Tensor) -> torch.Tensor:
return torch.ops.spline_conv.spline_weighting(x, weight, basis,
weight_index)
return torch.ops.torch_spline_conv.spline_weighting(
x, weight, basis, weight_index)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment