Unverified commit cd7b1988 authored by Matthias Fey, committed by GitHub

Merge pull request #14 from rusty1s/tracing

prepare tracing
parents d3169766 32224979
@@ -3,5 +3,6 @@ source=torch_spline_conv
[report]
exclude_lines =
pragma: no cover
cuda
backward
+torch.jit.script
raise
except
__pycache__/
_ext/
build/
dist/
.cache/
......
language: shell

os:
  - linux
  - osx
  - windows

env:
  global:
    - CUDA_HOME=/usr/local/cuda
  jobs:
    - TORCH_VERSION=1.4.0 PYTHON_VERSION=3.8 IDX=cpu
    - TORCH_VERSION=1.4.0 PYTHON_VERSION=3.8 IDX=cu92
    - TORCH_VERSION=1.4.0 PYTHON_VERSION=3.8 IDX=cu100
    - TORCH_VERSION=1.4.0 PYTHON_VERSION=3.8 IDX=cu101
    - TORCH_VERSION=1.4.0 PYTHON_VERSION=3.7 IDX=cpu
    - TORCH_VERSION=1.4.0 PYTHON_VERSION=3.7 IDX=cu92
    - TORCH_VERSION=1.4.0 PYTHON_VERSION=3.7 IDX=cu100
    - TORCH_VERSION=1.4.0 PYTHON_VERSION=3.7 IDX=cu101
    - TORCH_VERSION=1.4.0 PYTHON_VERSION=3.6 IDX=cpu
    - TORCH_VERSION=1.4.0 PYTHON_VERSION=3.6 IDX=cu92
    - TORCH_VERSION=1.4.0 PYTHON_VERSION=3.6 IDX=cu100
    - TORCH_VERSION=1.4.0 PYTHON_VERSION=3.6 IDX=cu101

jobs:
  include:
    - os: linux
      language: python
      python: 3.7
      addons:
        apt:
          sources:
            - ubuntu-toolchain-r-test
          packages:
            - gcc-5
            - g++-5
      env:
        - CC=gcc-5
        - CXX=g++-5
    - os: osx
      language: sh
      before_cache:
        - brew cleanup
      cache:
        directories:
          - $HOME/Library/Caches/Homebrew
          - /usr/local/Homebrew
      addons:
        homebrew:
          packages: python3
      before_install:
        - python3 -m pip install --upgrade virtualenv
        - virtualenv -p python3 --system-site-packages "$HOME/venv"
        - source "$HOME/venv/bin/activate"
      env:
        - CC=clang
        - CXX=clang++
  exclude: # Exclude *all* macOS CUDA jobs and Windows CUDA 9.2/10.0 jobs.
    - os: osx
      env: TORCH_VERSION=1.4.0 PYTHON_VERSION=3.8 IDX=cu92
    - os: osx
      env: TORCH_VERSION=1.4.0 PYTHON_VERSION=3.8 IDX=cu100
    - os: osx
      env: TORCH_VERSION=1.4.0 PYTHON_VERSION=3.8 IDX=cu101
    - os: osx
      env: TORCH_VERSION=1.4.0 PYTHON_VERSION=3.7 IDX=cu92
    - os: osx
      env: TORCH_VERSION=1.4.0 PYTHON_VERSION=3.7 IDX=cu100
    - os: osx
      env: TORCH_VERSION=1.4.0 PYTHON_VERSION=3.7 IDX=cu101
    - os: osx
      env: TORCH_VERSION=1.4.0 PYTHON_VERSION=3.6 IDX=cu92
    - os: osx
      env: TORCH_VERSION=1.4.0 PYTHON_VERSION=3.6 IDX=cu100
    - os: osx
      env: TORCH_VERSION=1.4.0 PYTHON_VERSION=3.6 IDX=cu101
    - os: windows
      env: TORCH_VERSION=1.4.0 PYTHON_VERSION=3.8 IDX=cu92
    - os: windows
      env: TORCH_VERSION=1.4.0 PYTHON_VERSION=3.8 IDX=cu100
    - os: windows
      env: TORCH_VERSION=1.4.0 PYTHON_VERSION=3.7 IDX=cu92
    - os: windows
      env: TORCH_VERSION=1.4.0 PYTHON_VERSION=3.7 IDX=cu100
    - os: windows
      env: TORCH_VERSION=1.4.0 PYTHON_VERSION=3.6 IDX=cu92
    - os: windows
      env: TORCH_VERSION=1.4.0 PYTHON_VERSION=3.6 IDX=cu100
    - os: windows
      env: TORCH_VERSION=1.4.0 PYTHON_VERSION=3.8 IDX=cu101

install:
-  - pip install numpy
-  - pip install --pre torch torchvision -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
-  - pip install pycodestyle
-  - pip install flake8
-  - pip install codecov
+  - source script/cuda.sh
+  - source script/conda.sh
+  - conda create --yes -n test python="${PYTHON_VERSION}"
+  - source activate test
+  - conda install pytorch=${TORCH_VERSION} ${TOOLKIT} -c pytorch --yes
+  - source script/torch.sh
+  - pip install flake8 codecov
+  - python setup.py install

script:
-  - python -c "import torch; print(torch.__version__)"
-  - pycodestyle .
   - flake8 .
-  - python setup.py install
   - python setup.py test

after_success:
  - python setup.py bdist_wheel --dist-dir=dist/torch-${TORCH_VERSION}
  - python script/rename_wheel.py ${IDX}
  - codecov

deploy:
  provider: s3
  region: eu-central-1
  edge: true
  access_key_id: ${S3_ACCESS_KEY}
  secret_access_key: ${S3_SECRET_ACCESS_KEY}
  bucket: pytorch-geometric.com
  local_dir: dist/torch-${TORCH_VERSION}
  upload_dir: whl/torch-${TORCH_VERSION}
  acl: public_read
  on:
    repo: rusty1s/pytorch_spline_conv
    tags: true

notifications:
  email: false
-Copyright (c) 2019 Matthias Fey <matthias.fey@tu-dortmund.de>
+Copyright (c) 2020 Matthias Fey <matthias.fey@tu-dortmund.de>
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
......
include README.md
include LICENSE
-recursive-include cpu *
-recursive-include cuda *
+recursive-exclude test *
+recursive-include csrc *
@@ -21,11 +21,30 @@ The operator works on all floating point data types and is implemented both for
## Installation
-Ensure that at least PyTorch 1.1.0 is installed and verify that `cuda/bin` and `cuda/include` are in your `$PATH` and `$CPATH` respectively, *e.g.*:
### Binaries
We provide pip wheels for all major OS/PyTorch/CUDA combinations, see [here](https://pytorch-geometric.com/whl).
To install the binaries for PyTorch 1.4.0, simply run
```
pip install torch-spline-conv==latest+${CUDA} -f https://pytorch-geometric.com/whl/torch-1.4.0.html
```
where `${CUDA}` should be replaced by either `cpu`, `cu92`, `cu100` or `cu101` depending on your PyTorch installation.
| | `cpu` | `cu92` | `cu100` | `cu101` |
|-------------|-------|--------|---------|---------|
| **Linux** | ✅ | ✅ | ✅ | ✅ |
| **Windows** | ✅ | ❌ | ❌ | ✅ |
| **macOS** | ✅ | | | |
### From source
+Ensure that at least PyTorch 1.4.0 is installed and verify that `cuda/bin` and `cuda/include` are in your `$PATH` and `$CPATH` respectively, *e.g.*:
```
$ python -c "import torch; print(torch.__version__)"
->>> 1.1.0
+>>> 1.4.0
$ echo $PATH
>>> /usr/local/cuda/bin:...
@@ -40,24 +59,28 @@ Then run:
pip install torch-spline-conv
```
If you are running into any installation problems, please create an [issue](https://github.com/rusty1s/pytorch_spline_conv/issues).
Be sure to import `torch` before importing this package, so that the symbols the dynamic linker needs are already loaded.
When running inside a Docker container without an NVIDIA driver, PyTorch cannot query the GPU compute capabilities and the build may fail.
In this case, set the compute capabilities manually via `TORCH_CUDA_ARCH_LIST`, *e.g.*:
```
export TORCH_CUDA_ARCH_LIST="6.0 6.1 7.2+PTX 7.5+PTX"
```
## Usage
```diff
-from torch_spline_conv import SplineConv
+from torch_spline_conv import spline_conv

-out = SplineConv.apply(x,
-                       edge_index,
-                       pseudo,
-                       weight,
-                       kernel_size,
-                       is_open_spline,
-                       degree=1,
-                       norm=True,
-                       root_weight=None,
-                       bias=None)
+out = spline_conv(x,
+                  edge_index,
+                  pseudo,
+                  weight,
+                  kernel_size,
+                  is_open_spline,
+                  degree=1,
+                  norm=True,
+                  root_weight=None,
+                  bias=None)
```
Applies the spline-based convolution operator
@@ -93,7 +116,7 @@ The kernel function is defined over the weighted B-spline tensor product basis,
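In compact form, the convolution operator reads (a sketch following the SplineCNN paper of Fey et al., not text from this diff; `h_Θ` is the B-spline kernel function mentioned above):

```latex
\mathbf{x}'_i = \frac{1}{|\mathcal{N}(i)|} \sum_{j \in \mathcal{N}(i)}
    \mathbf{x}_j \cdot h_{\mathbf{\Theta}}(\mathbf{e}_{i,j})
```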
```python
import torch
-from torch_spline_conv import SplineConv
+from torch_spline_conv import spline_conv
x = torch.rand((4, 2), dtype=torch.float) # 4 nodes with 2 features each
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]]) # 6 edges
@@ -106,8 +129,8 @@ norm = True # Normalize output by node degree.
root_weight = torch.rand((2, 4), dtype=torch.float) # separately weight root nodes
bias = None # do not apply an additional bias
-out = SplineConv.apply(x, edge_index, pseudo, weight, kernel_size,
-                       is_open_spline, degree, norm, root_weight, bias)
+out = spline_conv(x, edge_index, pseudo, weight, kernel_size,
+                  is_open_spline, degree, norm, root_weight, bias)
print(out.size())
torch.Size([4, 4]) # 4 nodes with 4 features each
......
#include <torch/extension.h>
#include "compat.h"
template <typename scalar_t> inline scalar_t linear(scalar_t v, int64_t k_mod) {
return 1 - v - k_mod + 2 * v * k_mod;
}
template <typename scalar_t>
inline scalar_t quadratic(scalar_t v, int64_t k_mod) {
if (k_mod == 0)
return 0.5 * v * v - v + 0.5;
else if (k_mod == 1)
return -v * v + v + 0.5;
else
return 0.5 * v * v;
}
template <typename scalar_t> inline scalar_t cubic(scalar_t v, int64_t k_mod) {
if (k_mod == 0)
return (1 - v) * (1 - v) * (1 - v) / 6.0;
else if (k_mod == 1)
return (3 * v * v * v - 6 * v * v + 4) / 6;
else if (k_mod == 2)
return (-3 * v * v * v + 3 * v * v + 3 * v + 1) / 6;
else
return v * v * v / 6;
}
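For reference, the three branches above are exactly the shifted pieces of the uniform B-spline bases, evaluated at a fractional position v ∈ [0, 1); a summary (our transcription, not part of the diff):

```latex
% degree 1: N_{1,0}(v) = 1 - v,        N_{1,1}(v) = v
% degree 2: N_{2,0}(v) = (1-v)^2/2,    N_{2,1}(v) = (-2v^2 + 2v + 1)/2,   N_{2,2}(v) = v^2/2
% degree 3: N_{3,0}(v) = (1-v)^3/6,    N_{3,1}(v) = (3v^3 - 6v^2 + 4)/6,
%           N_{3,2}(v) = (-3v^3 + 3v^2 + 3v + 1)/6,                      N_{3,3}(v) = v^3/6
% Each family forms a partition of unity, e.g. for degree 2:
\frac{(1-v)^2}{2} + \frac{-2v^2 + 2v + 1}{2} + \frac{v^2}{2} = 1
```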
#define BASIS_FORWARD(M, PSEUDO, KERNEL_SIZE, IS_OPEN_SPLINE, FUNC) \
[&]() -> std::tuple<at::Tensor, at::Tensor> { \
auto E = PSEUDO.size(0), D = PSEUDO.size(1); \
auto S = (int64_t)(pow(M + 1, KERNEL_SIZE.size(0)) + 0.5); \
auto basis = at::empty({E, S}, PSEUDO.options()); \
auto weight_index = at::empty({E, S}, KERNEL_SIZE.options()); \
\
AT_DISPATCH_FLOATING_TYPES( \
PSEUDO.scalar_type(), "basis_forward_##M", [&] { \
auto pseudo_data = PSEUDO.DATA_PTR<scalar_t>(); \
auto kernel_size_data = KERNEL_SIZE.DATA_PTR<int64_t>(); \
auto is_open_spline_data = IS_OPEN_SPLINE.DATA_PTR<uint8_t>(); \
auto basis_data = basis.DATA_PTR<scalar_t>(); \
auto weight_index_data = weight_index.DATA_PTR<int64_t>(); \
\
int64_t k, wi, wi_offset; \
scalar_t b; \
\
for (ptrdiff_t e = 0; e < E; e++) { \
for (ptrdiff_t s = 0; s < S; s++) { \
k = s; \
wi = 0; \
wi_offset = 1; \
b = 1; \
for (ptrdiff_t d = 0; d < D; d++) { \
auto k_mod = k % (M + 1); \
k /= M + 1; \
\
auto v = \
pseudo_data[e * pseudo.stride(0) + d * pseudo.stride(1)]; \
v *= kernel_size_data[d] - M * is_open_spline_data[d]; \
\
wi += \
(((int64_t)v + k_mod) % kernel_size_data[d]) * wi_offset; \
wi_offset *= kernel_size_data[d]; \
\
v -= floor(v); \
v = FUNC<scalar_t>(v, k_mod); \
b *= v; \
} \
basis_data[e * S + s] = b; \
weight_index_data[e * S + s] = wi; \
} \
} \
}); \
return std::make_tuple(basis, weight_index); \
}()
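The index arithmetic is the subtle part of the macro above: each of the S = (m+1)^D summands corresponds to one combination of local basis pieces, found by reading `s` as a base-(m+1) number with one digit `k_mod` per dimension, while `wi` accumulates the matching position in the flattened weight tensor. A standalone sketch of just this decomposition (hypothetical values; `v` fixed to 0 for brevity):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const int64_t m = 1;                        // B-spline degree
  const std::vector<int64_t> kernel = {5, 4}; // kernel_size per dimension
  const int64_t D = (int64_t)kernel.size();
  const int64_t S = 4;                        // (m + 1)^D summands per edge

  for (int64_t s = 0; s < S; s++) {
    int64_t k = s, wi = 0, wi_offset = 1;
    printf("s=%lld  digits:", (long long)s);
    for (int64_t d = 0; d < D; d++) {
      const int64_t k_mod = k % (m + 1); // digit choosing the local piece
      k /= m + 1;
      printf(" %lld", (long long)k_mod);
      wi += (k_mod % kernel[d]) * wi_offset; // with v = 0; see macro for full term
      wi_offset *= kernel[d];
    }
    printf("  ->  weight_index=%lld\n", (long long)wi);
  }
  return 0;
}
```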
std::tuple<at::Tensor, at::Tensor> linear_fw(at::Tensor pseudo,
at::Tensor kernel_size,
at::Tensor is_open_spline) {
return BASIS_FORWARD(1, pseudo, kernel_size, is_open_spline, linear);
}
std::tuple<at::Tensor, at::Tensor> quadratic_fw(at::Tensor pseudo,
at::Tensor kernel_size,
at::Tensor is_open_spline) {
return BASIS_FORWARD(2, pseudo, kernel_size, is_open_spline, quadratic);
}
std::tuple<at::Tensor, at::Tensor>
cubic_fw(at::Tensor pseudo, at::Tensor kernel_size, at::Tensor is_open_spline) {
return BASIS_FORWARD(3, pseudo, kernel_size, is_open_spline, cubic);
}
template <typename scalar_t>
inline scalar_t grad_linear(scalar_t v, int64_t k_mod) {
return 2 * k_mod - 1;
}
template <typename scalar_t>
inline scalar_t grad_quadratic(scalar_t v, int64_t k_mod) {
if (k_mod == 0)
return v - 1;
else if (k_mod == 1)
return -2 * v + 1;
else
return v;
}
template <typename scalar_t>
inline scalar_t grad_cubic(scalar_t v, int64_t k_mod) {
if (k_mod == 0)
return (-v * v + 2 * v - 1) / 2;
else if (k_mod == 1)
return (3 * v * v - 4 * v) / 2;
else if (k_mod == 2)
return (-3 * v * v + 2 * v + 1) / 2;
else
return v * v / 2;
}
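These are the term-wise derivatives of the forward pieces above; the chain-rule factor kernel_size[d] − m·is_open_spline[d] from the rescaling of v is applied separately in BASIS_BACKWARD. As a summary (our transcription):

```latex
% degree 1: N'_{1,0}(v) = -1,                 N'_{1,1}(v) = 1          (i.e. 2k - 1)
% degree 2: N'_{2,0}(v) = v - 1,              N'_{2,1}(v) = -2v + 1,   N'_{2,2}(v) = v
% degree 3: N'_{3,0}(v) = -(1 - v)^2/2,       N'_{3,1}(v) = (3v^2 - 4v)/2,
%           N'_{3,2}(v) = (-3v^2 + 2v + 1)/2, N'_{3,3}(v) = v^2/2
```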
#define BASIS_BACKWARD(M, GRAD_BASIS, PSEUDO, KERNEL_SIZE, IS_OPEN_SPLINE, \
FUNC, GRAD_FUNC) \
[&]() -> at::Tensor { \
auto E = PSEUDO.size(0), D = PSEUDO.size(1); \
auto S = GRAD_BASIS.size(1); \
auto grad_pseudo = at::empty({E, D}, PSEUDO.options()); \
\
AT_DISPATCH_FLOATING_TYPES( \
PSEUDO.scalar_type(), "basis_backward_##M", [&] { \
auto grad_basis_data = GRAD_BASIS.DATA_PTR<scalar_t>(); \
auto pseudo_data = PSEUDO.DATA_PTR<scalar_t>(); \
auto kernel_size_data = KERNEL_SIZE.DATA_PTR<int64_t>(); \
auto is_open_spline_data = IS_OPEN_SPLINE.DATA_PTR<uint8_t>(); \
auto grad_pseudo_data = grad_pseudo.DATA_PTR<scalar_t>(); \
\
scalar_t g, tmp; \
\
for (ptrdiff_t e = 0; e < E; e++) { \
for (ptrdiff_t d = 0; d < D; d++) { \
g = 0; \
for (ptrdiff_t s = 0; s < S; s++) { \
auto k_mod = (s / (int64_t)(pow(M + 1, d) + 0.5)) % (M + 1); \
auto v = \
pseudo_data[e * pseudo.stride(0) + d * pseudo.stride(1)]; \
v *= kernel_size_data[d] - M * is_open_spline_data[d]; \
v -= floor(v); \
v = GRAD_FUNC<scalar_t>(v, k_mod); \
tmp = v; \
\
for (ptrdiff_t d_it = 1; d_it < D; d_it++) { \
auto d_new = d_it - (d >= d_it); \
k_mod = (s / (int64_t)(pow(M + 1, d_new) + 0.5)) % (M + 1); \
v = pseudo_data[e * pseudo.stride(0) + \
d_new * pseudo.stride(1)]; \
v *= kernel_size_data[d_new] - \
M * is_open_spline_data[d_new]; \
v -= floor(v); \
v = FUNC<scalar_t>(v, k_mod); \
tmp *= v; \
} \
g += tmp * grad_basis_data[e * grad_basis.stride(0) + \
s * grad_basis.stride(1)]; \
} \
g *= kernel_size_data[d] - M * is_open_spline_data[d]; \
grad_pseudo_data[e * D + d] = g; \
} \
} \
}); \
return grad_pseudo; \
}()
at::Tensor linear_bw(at::Tensor grad_basis, at::Tensor pseudo,
at::Tensor kernel_size, at::Tensor is_open_spline) {
return BASIS_BACKWARD(1, grad_basis, pseudo, kernel_size, is_open_spline,
linear, grad_linear);
}
at::Tensor quadratic_bw(at::Tensor grad_basis, at::Tensor pseudo,
at::Tensor kernel_size, at::Tensor is_open_spline) {
return BASIS_BACKWARD(2, grad_basis, pseudo, kernel_size, is_open_spline,
quadratic, grad_quadratic);
}
at::Tensor cubic_bw(at::Tensor grad_basis, at::Tensor pseudo,
at::Tensor kernel_size, at::Tensor is_open_spline) {
return BASIS_BACKWARD(3, grad_basis, pseudo, kernel_size, is_open_spline,
cubic, grad_cubic);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("linear_fw", &linear_fw, "Linear Basis Forward (CPU)");
m.def("quadratic_fw", &quadratic_fw, "Quadratic Basis Forward (CPU)");
m.def("cubic_fw", &cubic_fw, "Cubic Basis Forward (CPU)");
m.def("linear_bw", &linear_bw, "Linear Basis Backward (CPU)");
m.def("quadratic_bw", &quadratic_bw, "Quadratic Basis Backward (CPU)");
m.def("cubic_bw", &cubic_bw, "Cubic Basis Backward (CPU)");
}
#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
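The shim exists because PyTorch 1.3 renamed `Tensor::data<T>()` to `Tensor::data_ptr<T>()`; call sites write `DATA_PTR` and compile against either API. A minimal sketch (assumes a PyTorch C++ build with this header on the include path; the `main` is illustrative only):

```cpp
#include <cstdio>
#include <torch/torch.h>
#include "compat.h"

int main() {
  auto t = torch::arange(4, torch::kFloat);
  // Expands to t.data_ptr<float>() on PyTorch >= 1.3 and t.data<float>() before.
  auto *p = t.DATA_PTR<float>();
  printf("%f\n", p[2]); // prints 2.000000
}
```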
#include <Python.h>
#include <torch/script.h>
#include "cpu/basis_cpu.h"
#ifdef WITH_CUDA
#include "cuda/basis_cuda.h"
#endif
#ifdef _WIN32
PyMODINIT_FUNC PyInit__basis(void) { return NULL; }
#endif
std::tuple<torch::Tensor, torch::Tensor>
spline_basis_fw(torch::Tensor pseudo, torch::Tensor kernel_size,
torch::Tensor is_open_spline, int64_t degree) {
if (pseudo.device().is_cuda()) {
#ifdef WITH_CUDA
return spline_basis_fw_cuda(pseudo, kernel_size, is_open_spline, degree);
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return spline_basis_fw_cpu(pseudo, kernel_size, is_open_spline, degree);
}
}
torch::Tensor spline_basis_bw(torch::Tensor grad_basis, torch::Tensor pseudo,
torch::Tensor kernel_size,
torch::Tensor is_open_spline, int64_t degree) {
if (grad_basis.device().is_cuda()) {
#ifdef WITH_CUDA
return spline_basis_bw_cuda(grad_basis, pseudo, kernel_size, is_open_spline,
degree);
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return spline_basis_bw_cpu(grad_basis, pseudo, kernel_size, is_open_spline,
degree);
}
}
using torch::autograd::AutogradContext;
using torch::autograd::Variable;
using torch::autograd::variable_list;
class SplineBasis : public torch::autograd::Function<SplineBasis> {
public:
static variable_list forward(AutogradContext *ctx, Variable pseudo,
Variable kernel_size, Variable is_open_spline,
int64_t degree) {
ctx->saved_data["degree"] = degree;
auto result = spline_basis_fw(pseudo, kernel_size, is_open_spline, degree);
auto basis = std::get<0>(result), weight_index = std::get<1>(result);
ctx->save_for_backward({pseudo, kernel_size, is_open_spline});
ctx->mark_non_differentiable({weight_index});
return {basis, weight_index};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto grad_basis = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto pseudo = saved[0], kernel_size = saved[1], is_open_spline = saved[2];
auto degree = ctx->saved_data["degree"].toInt();
auto grad_pseudo = spline_basis_bw(grad_basis, pseudo, kernel_size,
is_open_spline, degree);
return {grad_pseudo, Variable(), Variable(), Variable()};
}
};
std::tuple<torch::Tensor, torch::Tensor>
spline_basis(torch::Tensor pseudo, torch::Tensor kernel_size,
torch::Tensor is_open_spline, int64_t degree) {
pseudo = pseudo.contiguous();
auto result = SplineBasis::apply(pseudo, kernel_size, is_open_spline, degree);
return std::make_tuple(result[0], result[1]);
}
static auto registry = torch::RegisterOperators().op(
"torch_spline_conv::spline_basis", &spline_basis);
#include "basis_cpu.h"
#include "utils.h"
template <typename scalar_t, int64_t degree> struct Basis {
static inline scalar_t forward(scalar_t v, int64_t k_mod) {
if (degree == 1) {
return 1. - v - k_mod + 2. * v * k_mod;
} else if (degree == 2) {
if (k_mod == 0)
return 0.5 * v * v - v + 0.5;
else if (k_mod == 1)
return -v * v + v + 0.5;
else
return 0.5 * v * v;
} else if (degree == 3) {
if (k_mod == 0)
return (1. - v) * (1. - v) * (1. - v) / 6.;
else if (k_mod == 1)
return (3. * v * v * v - 6. * v * v + 4.) / 6.;
else if (k_mod == 2)
return (-3. * v * v * v + 3. * v * v + 3. * v + 1.) / 6.;
else
return v * v * v / 6.;
} else {
return (scalar_t)-1.;
}
}
static inline scalar_t backward(scalar_t v, int64_t k_mod) {
if (degree == 1) {
return 2 * k_mod - 1;
} else if (degree == 2) {
if (k_mod == 0)
return v - 1.;
else if (k_mod == 1)
return -2. * v + 1.;
else
return v;
} else if (degree == 3) {
if (k_mod == 0)
return (-v * v + 2. * v - 1.) / 2.;
else if (k_mod == 1)
return (3. * v * v - 4. * v) / 2.;
else if (k_mod == 2)
return (-3. * v * v + 2. * v + 1.) / 2.;
else
return v * v / 2.;
} else {
return (scalar_t)-1.;
}
}
};
std::tuple<torch::Tensor, torch::Tensor>
spline_basis_fw_cpu(torch::Tensor pseudo, torch::Tensor kernel_size,
torch::Tensor is_open_spline, int64_t degree) {
CHECK_CPU(pseudo);
CHECK_CPU(kernel_size);
CHECK_CPU(is_open_spline);
CHECK_INPUT(kernel_size.dim() == 1);
CHECK_INPUT(pseudo.size(1) == kernel_size.numel());
CHECK_INPUT(is_open_spline.dim() == 1);
CHECK_INPUT(pseudo.size(1) == is_open_spline.numel());
auto E = pseudo.size(0);
auto D = pseudo.size(1);
auto S = (int64_t)(pow(degree + 1, D) + 0.5);
auto basis = at::empty({E, S}, pseudo.options());
auto weight_index = at::empty({E, S}, kernel_size.options());
auto kernel_size_data = kernel_size.data_ptr<int64_t>();
auto is_open_spline_data = is_open_spline.data_ptr<uint8_t>();
auto weight_index_data = weight_index.data_ptr<int64_t>();
AT_DISPATCH_FLOATING_TYPES(pseudo.scalar_type(), "basis_fw", [&] {
auto pseudo_data = pseudo.data_ptr<scalar_t>();
auto basis_data = basis.data_ptr<scalar_t>();
AT_DISPATCH_DEGREE_TYPES(degree, [&] {
int64_t k, wi, wi_offset;
scalar_t b;
for (int64_t e = 0; e < E; e++) {
for (int64_t s = 0; s < S; s++) {
k = s, wi = 0, wi_offset = 1, b = (scalar_t)1.;
for (int64_t d = 0; d < D; d++) {
int64_t k_mod = k % (DEGREE + 1);
k /= DEGREE + 1;
auto v = pseudo_data[e * pseudo.stride(0) + d * pseudo.stride(1)];
v *= kernel_size_data[d] - DEGREE * is_open_spline_data[d];
wi += (((int64_t)v + k_mod) % kernel_size_data[d]) * wi_offset;
wi_offset *= kernel_size_data[d];
v -= floor(v);
v = Basis<scalar_t, DEGREE>::forward(v, k_mod);
b *= v;
}
basis_data[e * S + s] = b;
weight_index_data[e * S + s] = wi;
}
}
});
});
return std::make_tuple(basis, weight_index);
}
torch::Tensor spline_basis_bw_cpu(torch::Tensor grad_basis,
torch::Tensor pseudo,
torch::Tensor kernel_size,
torch::Tensor is_open_spline,
int64_t degree) {
CHECK_CPU(grad_basis);
CHECK_CPU(pseudo);
CHECK_CPU(kernel_size);
CHECK_CPU(is_open_spline);
CHECK_INPUT(grad_basis.size(0) == pseudo.size(0));
CHECK_INPUT(kernel_size.dim() == 1);
CHECK_INPUT(pseudo.size(1) == kernel_size.numel());
CHECK_INPUT(is_open_spline.dim() == 1);
CHECK_INPUT(pseudo.size(1) == is_open_spline.numel());
auto E = pseudo.size(0);
auto D = pseudo.size(1);
auto S = grad_basis.size(1);
auto grad_pseudo = at::empty({E, D}, pseudo.options());
auto kernel_size_data = kernel_size.data_ptr<int64_t>();
auto is_open_spline_data = is_open_spline.data_ptr<uint8_t>();
AT_DISPATCH_FLOATING_TYPES(pseudo.scalar_type(), "basis_bw", [&] {
auto grad_basis_data = grad_basis.data_ptr<scalar_t>();
auto pseudo_data = pseudo.data_ptr<scalar_t>();
auto grad_pseudo_data = grad_pseudo.data_ptr<scalar_t>();
AT_DISPATCH_DEGREE_TYPES(degree, [&] {
scalar_t g, tmp;
for (int64_t e = 0; e < E; e++) {
for (int64_t d = 0; d < D; d++) {
g = (scalar_t)0.;
for (int64_t s = 0; s < S; s++) {
int64_t k_mod =
(s / (int64_t)(pow(DEGREE + 1, d) + 0.5)) % (DEGREE + 1);
auto v = pseudo_data[e * pseudo.stride(0) + d * pseudo.stride(1)];
v *= kernel_size_data[d] - DEGREE * is_open_spline_data[d];
v -= floor(v);
v = Basis<scalar_t, DEGREE>::backward(v, k_mod);
tmp = v;
for (int64_t d_it = 1; d_it < D; d_it++) {
int64_t d_new = d_it - (d >= d_it);
k_mod =
(s / (int64_t)(pow(DEGREE + 1, d_new) + 0.5)) % (DEGREE + 1);
v = pseudo_data[e * pseudo.stride(0) + d_new * pseudo.stride(1)];
v *=
kernel_size_data[d_new] - DEGREE * is_open_spline_data[d_new];
v -= floor(v);
v = Basis<scalar_t, DEGREE>::forward(v, k_mod);
tmp *= v;
}
g += tmp * grad_basis_data[e * grad_basis.stride(0) +
s * grad_basis.stride(1)];
}
g *= kernel_size_data[d] - DEGREE * is_open_spline_data[d];
grad_pseudo_data[e * D + d] = g;
}
}
});
});
return grad_pseudo;
}
#pragma once
#include <torch/extension.h>
std::tuple<torch::Tensor, torch::Tensor>
spline_basis_fw_cpu(torch::Tensor pseudo, torch::Tensor kernel_size,
torch::Tensor is_open_spline, int64_t degree);
torch::Tensor spline_basis_bw_cpu(torch::Tensor grad_basis,
torch::Tensor pseudo,
torch::Tensor kernel_size,
torch::Tensor is_open_spline, int64_t degree);
#pragma once
#include <torch/extension.h>
#define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor")
#define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch")
#define AT_DISPATCH_DEGREE_TYPES(degree, ...) \
[&] { \
switch (degree) { \
case 1: { \
static constexpr int64_t DEGREE = 1; \
return __VA_ARGS__(); \
} \
case 2: { \
static constexpr int64_t DEGREE = 2; \
return __VA_ARGS__(); \
} \
case 3: { \
static constexpr int64_t DEGREE = 3; \
return __VA_ARGS__(); \
} \
default: \
AT_ERROR("Basis degree not implemented"); \
} \
}()
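This mirrors `AT_DISPATCH_FLOATING_TYPES`: the runtime `degree` becomes the compile-time constant `DEGREE`, so `Basis<scalar_t, DEGREE>` can be instantiated and fully inlined per degree. The same pattern in a self-contained sketch (hypothetical names, not from the diff):

```cpp
#include <cstdint>
#include <cstdio>
#include <stdexcept>

// With DEGREE known at compile time, the loop below can be unrolled/inlined.
template <int64_t DEGREE> int64_t num_products(int64_t dims) {
  int64_t s = 1;
  for (int64_t d = 0; d < dims; d++)
    s *= DEGREE + 1; // (DEGREE + 1)^dims
  return s;
}

// Runtime value -> compile-time parameter, like AT_DISPATCH_DEGREE_TYPES.
int64_t dispatch_degree(int64_t degree, int64_t dims) {
  switch (degree) {
  case 1:
    return num_products<1>(dims);
  case 2:
    return num_products<2>(dims);
  case 3:
    return num_products<3>(dims);
  default:
    throw std::runtime_error("Basis degree not implemented");
  }
}

int main() { printf("%lld\n", (long long)dispatch_degree(2, 3)); } // prints 27
```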
-#include <torch/extension.h>
+#include "weighting_cpu.h"

-#include "compat.h"
+#include "utils.h"

-at::Tensor weighting_fw(at::Tensor x, at::Tensor weight, at::Tensor basis,
-                        at::Tensor weight_index) {
-  auto E = x.size(0), M_in = x.size(1), M_out = weight.size(2);
+torch::Tensor spline_weighting_fw_cpu(torch::Tensor x, torch::Tensor weight,
+                                      torch::Tensor basis,
+                                      torch::Tensor weight_index) {
+  CHECK_CPU(x);
+  CHECK_CPU(weight);
+  CHECK_CPU(basis);
+  CHECK_CPU(weight_index);
+
+  CHECK_INPUT(x.size(1) == weight.size(1));
+
+  auto E = x.size(0);
+  auto M_in = x.size(1);
+  auto M_out = weight.size(2);
   auto S = basis.size(1);

   auto out = at::empty({E, M_out}, x.options());

-  AT_DISPATCH_FLOATING_TYPES(out.scalar_type(), "weighting_fw", [&] {
-    auto x_data = x.DATA_PTR<scalar_t>();
-    auto weight_data = weight.DATA_PTR<scalar_t>();
-    auto basis_data = basis.DATA_PTR<scalar_t>();
-    auto weight_index_data = weight_index.DATA_PTR<int64_t>();
-    auto out_data = out.DATA_PTR<scalar_t>();
+  auto weight_index_data = weight_index.data_ptr<int64_t>();
+  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "weighting_fw", [&] {
+    auto x_data = x.data_ptr<scalar_t>();
+    auto weight_data = weight.data_ptr<scalar_t>();
+    auto basis_data = basis.data_ptr<scalar_t>();
+    auto out_data = out.data_ptr<scalar_t>();

     scalar_t v;
-    for (ptrdiff_t e = 0; e < E; e++) {
-      for (ptrdiff_t m_out = 0; m_out < M_out; m_out++) {
+    for (int64_t e = 0; e < E; e++) {
+      for (int64_t m_out = 0; m_out < M_out; m_out++) {
         v = 0;
-        for (ptrdiff_t s = 0; s < S; s++) {
+        for (int64_t s = 0; s < S; s++) {
           auto b = basis_data[e * S + s];
           auto wi = weight_index_data[e * S + s];
-          for (ptrdiff_t m_in = 0; m_in < M_in; m_in++) {
+          for (int64_t m_in = 0; m_in < M_in; m_in++) {
             auto tmp =
                 weight_data[wi * weight.stride(0) + m_in * weight.stride(1) +
                             m_out * weight.stride(2)];
@@ -39,27 +51,40 @@ at::Tensor weighting_fw(at::Tensor x, at::Tensor weight, at::Tensor basis,
return out;
}
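Per edge e and output channel o, the triple loop above computes (a summary of the code, with B the basis matrix, wi the weight index, and W the weight tensor):

```latex
\mathrm{out}_{e,o} = \sum_{s=1}^{S} B_{e,s} \sum_{i=1}^{M_{\mathrm{in}}}
    W_{\mathrm{wi}(e,s),\,i,\,o} \, x_{e,i}
```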
-at::Tensor weighting_bw_x(at::Tensor grad_out, at::Tensor weight,
-                          at::Tensor basis, at::Tensor weight_index) {
-  auto E = grad_out.size(0), M_in = weight.size(1), M_out = grad_out.size(1);
+torch::Tensor spline_weighting_bw_x_cpu(torch::Tensor grad_out,
+                                        torch::Tensor weight,
+                                        torch::Tensor basis,
+                                        torch::Tensor weight_index) {
+  CHECK_CPU(grad_out);
+  CHECK_CPU(weight);
+  CHECK_CPU(basis);
+  CHECK_CPU(weight_index);
+
+  CHECK_INPUT(grad_out.size(1) == weight.size(2));
+
+  auto E = grad_out.size(0);
+  auto M_in = weight.size(1);
+  auto M_out = grad_out.size(1);
   auto S = basis.size(1);

   auto grad_x = at::zeros({E, M_in}, grad_out.options());

+  auto weight_index_data = weight_index.data_ptr<int64_t>();
   AT_DISPATCH_FLOATING_TYPES(grad_out.scalar_type(), "weighting_bw_x", [&] {
-    auto grad_out_data = grad_out.DATA_PTR<scalar_t>();
-    auto weight_data = weight.DATA_PTR<scalar_t>();
-    auto basis_data = basis.DATA_PTR<scalar_t>();
-    auto weight_index_data = weight_index.DATA_PTR<int64_t>();
-    auto grad_x_data = grad_x.DATA_PTR<scalar_t>();
-    for (ptrdiff_t e = 0; e < E; e++) {
-      for (ptrdiff_t m_out = 0; m_out < M_out; m_out++) {
+    auto grad_out_data = grad_out.data_ptr<scalar_t>();
+    auto weight_data = weight.data_ptr<scalar_t>();
+    auto basis_data = basis.data_ptr<scalar_t>();
+    auto grad_x_data = grad_x.data_ptr<scalar_t>();
+    for (int64_t e = 0; e < E; e++) {
+      for (int64_t m_out = 0; m_out < M_out; m_out++) {
         auto g =
             grad_out_data[e * grad_out.stride(0) + m_out * grad_out.stride(1)];
-        for (ptrdiff_t s = 0; s < S; s++) {
+        for (int64_t s = 0; s < S; s++) {
           auto b = basis_data[e * S + s];
           auto wi = weight_index_data[e * S + s];
-          for (ptrdiff_t m_in = 0; m_in < M_in; m_in++) {
+          for (int64_t m_in = 0; m_in < M_in; m_in++) {
             auto w =
                 weight_data[wi * weight.stride(0) + m_in * weight.stride(1) +
                             m_out * weight.stride(2)];
@@ -73,27 +98,39 @@ at::Tensor weighting_bw_x(at::Tensor grad_out, at::Tensor weight,
return grad_x;
}
-at::Tensor weighting_bw_w(at::Tensor grad_out, at::Tensor x, at::Tensor basis,
-                          at::Tensor weight_index, int64_t K) {
-  auto E = grad_out.size(0), M_in = x.size(1), M_out = grad_out.size(1);
+torch::Tensor spline_weighting_bw_weight_cpu(torch::Tensor grad_out,
+                                             torch::Tensor x,
+                                             torch::Tensor basis,
+                                             torch::Tensor weight_index,
+                                             int64_t kernel_size) {
+  CHECK_CPU(grad_out);
+  CHECK_CPU(x);
+  CHECK_CPU(basis);
+  CHECK_CPU(weight_index);
+
+  auto E = grad_out.size(0);
+  auto M_in = x.size(1);
+  auto M_out = grad_out.size(1);
   auto S = basis.size(1);

-  auto grad_weight = at::zeros({K, M_in, M_out}, grad_out.options());
-  AT_DISPATCH_FLOATING_TYPES(grad_out.scalar_type(), "weighting_bw_w", [&] {
-    auto grad_out_data = grad_out.DATA_PTR<scalar_t>();
-    auto x_data = x.DATA_PTR<scalar_t>();
-    auto basis_data = basis.DATA_PTR<scalar_t>();
-    auto weight_index_data = weight_index.DATA_PTR<int64_t>();
-    auto grad_weight_data = grad_weight.DATA_PTR<scalar_t>();
-    for (ptrdiff_t e = 0; e < E; e++) {
-      for (ptrdiff_t m_out = 0; m_out < M_out; m_out++) {
+  auto grad_weight = at::zeros({kernel_size, M_in, M_out}, grad_out.options());
+  auto weight_index_data = weight_index.data_ptr<int64_t>();
+  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "weighting_bw_weight", [&] {
+    auto grad_out_data = grad_out.data_ptr<scalar_t>();
+    auto x_data = x.data_ptr<scalar_t>();
+    auto basis_data = basis.data_ptr<scalar_t>();
+    auto grad_weight_data = grad_weight.data_ptr<scalar_t>();
+    for (int64_t e = 0; e < E; e++) {
+      for (int64_t m_out = 0; m_out < M_out; m_out++) {
         auto g =
             grad_out_data[e * grad_out.stride(0) + m_out * grad_out.stride(1)];
-        for (ptrdiff_t s = 0; s < S; s++) {
+        for (int64_t s = 0; s < S; s++) {
           auto b = basis_data[e * S + s];
           auto wi = weight_index_data[e * S + s];
-          for (ptrdiff_t m_in = 0; m_in < M_in; m_in++) {
+          for (int64_t m_in = 0; m_in < M_in; m_in++) {
             auto v = g * b * x_data[e * x.stride(0) + m_in * x.stride(1)];
             grad_weight_data[wi * M_in * M_out + m_in * M_out + m_out] += v;
           }
@@ -105,27 +142,41 @@ at::Tensor weighting_bw_w(at::Tensor grad_out, at::Tensor x, at::Tensor basis,
return grad_weight;
}
-at::Tensor weighting_bw_b(at::Tensor grad_out, at::Tensor x, at::Tensor weight,
-                          at::Tensor weight_index) {
-  auto E = grad_out.size(0), M_in = x.size(1), M_out = grad_out.size(1);
+torch::Tensor spline_weighting_bw_basis_cpu(torch::Tensor grad_out,
+                                            torch::Tensor x,
+                                            torch::Tensor weight,
+                                            torch::Tensor weight_index) {
+  CHECK_CPU(grad_out);
+  CHECK_CPU(x);
+  CHECK_CPU(weight);
+  CHECK_CPU(weight_index);
+
+  CHECK_INPUT(x.size(1) == weight.size(1));
+  CHECK_INPUT(grad_out.size(1) == weight.size(2));
+
+  auto E = grad_out.size(0);
+  auto M_in = x.size(1);
+  auto M_out = grad_out.size(1);
   auto S = weight_index.size(1);

   auto grad_basis = at::zeros({E, S}, grad_out.options());

-  AT_DISPATCH_FLOATING_TYPES(grad_out.scalar_type(), "weighting_bw_b", [&] {
-    auto grad_out_data = grad_out.DATA_PTR<scalar_t>();
-    auto x_data = x.DATA_PTR<scalar_t>();
-    auto weight_data = weight.DATA_PTR<scalar_t>();
-    auto weight_index_data = weight_index.DATA_PTR<int64_t>();
-    auto grad_basis_data = grad_basis.DATA_PTR<scalar_t>();
-    for (ptrdiff_t e = 0; e < E; e++) {
-      for (ptrdiff_t m_out = 0; m_out < M_out; m_out++) {
+  auto weight_index_data = weight_index.data_ptr<int64_t>();
+  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "weighting_bw_basis", [&] {
+    auto grad_out_data = grad_out.data_ptr<scalar_t>();
+    auto x_data = x.data_ptr<scalar_t>();
+    auto weight_data = weight.data_ptr<scalar_t>();
+    auto grad_basis_data = grad_basis.data_ptr<scalar_t>();
+    for (int64_t e = 0; e < E; e++) {
+      for (int64_t m_out = 0; m_out < M_out; m_out++) {
         auto g =
             grad_out_data[e * grad_out.stride(0) + m_out * grad_out.stride(1)];
-        for (ptrdiff_t s = 0; s < S; s++) {
+        for (int64_t s = 0; s < S; s++) {
           scalar_t b = 0;
           auto wi = weight_index_data[e * S + s];
-          for (ptrdiff_t m_in = 0; m_in < M_in; m_in++) {
+          for (int64_t m_in = 0; m_in < M_in; m_in++) {
             auto w =
                 weight_data[wi * weight.stride(0) + m_in * weight.stride(1) +
                             m_out * weight.stride(2)];
@@ -140,10 +191,3 @@ at::Tensor weighting_bw_b(at::Tensor grad_out, at::Tensor x, at::Tensor weight,
return grad_basis;
}
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("weighting_fw", &weighting_fw, "Weighting Forward (CPU)");
-  m.def("weighting_bw_x", &weighting_bw_x, "Weighting Backward X (CPU)");
-  m.def("weighting_bw_w", &weighting_bw_w, "Weighting Backward Weight (CPU)");
-  m.def("weighting_bw_b", &weighting_bw_b, "Weighting Backward Basis (CPU)");
-}
#pragma once
#include <torch/extension.h>
torch::Tensor spline_weighting_fw_cpu(torch::Tensor x, torch::Tensor weight,
torch::Tensor basis,
torch::Tensor weight_index);
torch::Tensor spline_weighting_bw_x_cpu(torch::Tensor grad_out,
torch::Tensor weight,
torch::Tensor basis,
torch::Tensor weight_index);
torch::Tensor spline_weighting_bw_weight_cpu(torch::Tensor grad_out,
torch::Tensor x,
torch::Tensor basis,
torch::Tensor weight_index,
int64_t kernel_size);
torch::Tensor spline_weighting_bw_basis_cpu(torch::Tensor grad_out,
torch::Tensor x,
torch::Tensor weight,
torch::Tensor weight_index);
#pragma once
static inline __device__ void atomAdd(float *address, float val) {
atomicAdd(address, val);
}
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600 || CUDA_VERSION < 8000)
static inline __device__ void atomAdd(double *address, double val) {
unsigned long long int *address_as_ull = (unsigned long long int *)address;
unsigned long long int old = *address_as_ull;
unsigned long long int assumed;
do {
assumed = old;
old = atomicCAS(address_as_ull, assumed,
__double_as_longlong(val + __longlong_as_double(assumed)));
} while (assumed != old);
}
#else
static inline __device__ void atomAdd(double *address, double val) {
atomicAdd(address, val);
}
#endif
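Pre-Pascal GPUs (compute capability < 6.0) and CUDA toolkits before 8.0 lack a native `atomicAdd(double*, double)`, so the fallback above emulates it by retrying a compare-and-swap on the value's 64-bit pattern until no other thread has intervened. The same idea in portable C++20 (a host-side sketch, not the device code):

```cpp
#include <atomic>
#include <bit> // std::bit_cast (C++20)
#include <cstdint>
#include <cstdio>

void atomic_add_double(std::atomic<uint64_t> *address, double val) {
  uint64_t expected = address->load();
  uint64_t desired;
  do { // recompute from the freshest value until the CAS succeeds
    desired = std::bit_cast<uint64_t>(val + std::bit_cast<double>(expected));
  } while (!address->compare_exchange_weak(expected, desired));
}

int main() {
  std::atomic<uint64_t> acc{std::bit_cast<uint64_t>(0.0)};
  atomic_add_double(&acc, 1.5);
  atomic_add_double(&acc, 2.5);
  printf("%f\n", std::bit_cast<double>(acc.load())); // prints 4.000000
}
```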
#include "basis_cuda.h"
#include <ATen/cuda/CUDAContext.h>
#include "utils.cuh"
#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS
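`BLOCKS` is plain ceiling division, so `BLOCKS(N) * THREADS >= N` threads cover all outputs; for example (our arithmetic, not from the diff):

```latex
\mathrm{BLOCKS}(2500) = \left\lfloor \frac{2500 + 1023}{1024} \right\rfloor = 3,
\qquad 3 \times 1024 = 3072 \ge 2500
```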
template <typename scalar_t, int64_t degree> struct Basis {
static inline __device__ scalar_t forward(scalar_t v, int64_t k_mod) {
if (degree == 1) {
return 1. - v - k_mod + 2. * v * k_mod;
} else if (degree == 2) {
if (k_mod == 0)
return 0.5 * v * v - v + 0.5;
else if (k_mod == 1)
return -v * v + v + 0.5;
else
return 0.5 * v * v;
} else if (degree == 3) {
if (k_mod == 0)
return (1. - v) * (1. - v) * (1. - v) / 6.;
else if (k_mod == 1)
return (3. * v * v * v - 6. * v * v + 4.) / 6.;
else if (k_mod == 2)
return (-3. * v * v * v + 3. * v * v + 3. * v + 1.) / 6.;
else
return v * v * v / 6.;
} else {
return (scalar_t)-1.;
}
}
static inline __device__ scalar_t backward(scalar_t v, int64_t k_mod) {
if (degree == 1) {
return 2 * k_mod - 1;
} else if (degree == 2) {
if (k_mod == 0)
return v - 1.;
else if (k_mod == 1)
return -2. * v + 1.;
else
return v;
} else if (degree == 3) {
if (k_mod == 0)
return (-v * v + 2. * v - 1.) / 2.;
else if (k_mod == 1)
return (3. * v * v - 4. * v) / 2.;
else if (k_mod == 2)
return (-3. * v * v + 2. * v + 1.) / 2.;
else
return v * v / 2.;
} else {
return (scalar_t)-1.;
}
}
};
template <typename scalar_t, int64_t degree>
__global__ void
spline_basis_fw_kernel(const scalar_t *pseudo, const int64_t *kernel_size,
const uint8_t *is_open_spline, scalar_t *basis,
int64_t *weight_index, int64_t E, int64_t D, int64_t S,
int64_t numel) {
const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
const int64_t e = thread_idx / S;
const int64_t s = thread_idx % S;
if (thread_idx < numel) {
int64_t k = s, wi = 0, wi_offset = 1;
scalar_t b = (scalar_t)1.;
for (int64_t d = 0; d < D; d++) {
const int64_t k_mod = k % (degree + 1);
k /= degree + 1;
scalar_t v = pseudo[e * D + d];
v *= kernel_size[d] - degree * is_open_spline[d];
wi += (((int64_t)v + k_mod) % kernel_size[d]) * wi_offset;
wi_offset *= kernel_size[d];
v -= floor(v);
v = Basis<scalar_t, degree>::forward(v, k_mod);
b *= v;
}
basis[thread_idx] = b;
weight_index[thread_idx] = wi;
}
}
std::tuple<torch::Tensor, torch::Tensor>
spline_basis_fw_cuda(torch::Tensor pseudo, torch::Tensor kernel_size,
torch::Tensor is_open_spline, int64_t degree) {
CHECK_CUDA(pseudo);
CHECK_CUDA(kernel_size);
CHECK_CUDA(is_open_spline);
cudaSetDevice(pseudo.get_device());
CHECK_INPUT(kernel_size.dim() == 1);
CHECK_INPUT(pseudo.size(1) == kernel_size.numel());
CHECK_INPUT(is_open_spline.dim() == 1);
CHECK_INPUT(pseudo.size(1) == is_open_spline.numel());
auto E = pseudo.size(0);
auto D = pseudo.size(1);
auto S = (int64_t)(powf(degree + 1, D) + 0.5);
auto basis = at::empty({E, S}, pseudo.options());
auto weight_index = at::empty({E, S}, kernel_size.options());
auto kernel_size_data = kernel_size.data_ptr<int64_t>();
auto is_open_spline_data = is_open_spline.data_ptr<uint8_t>();
auto weight_index_data = weight_index.data_ptr<int64_t>();
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES(pseudo.scalar_type(), "basis_fw", [&] {
auto pseudo_data = pseudo.data_ptr<scalar_t>();
auto basis_data = basis.data_ptr<scalar_t>();
AT_DISPATCH_DEGREE_TYPES(degree, [&] {
spline_basis_fw_kernel<scalar_t, DEGREE>
<<<BLOCKS(basis.numel()), THREADS, 0, stream>>>(
pseudo_data, kernel_size_data, is_open_spline_data, basis_data,
weight_index_data, E, D, S, basis.numel());
});
});
return std::make_tuple(basis, weight_index);
}
template <typename scalar_t, int64_t degree>
__global__ void
spline_basis_bw_kernel(const scalar_t *grad_basis, const scalar_t *pseudo,
const int64_t *kernel_size,
const uint8_t *is_open_spline, scalar_t *grad_pseudo,
int64_t E, int64_t D, int64_t S, int64_t numel) {
const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
const int64_t e = thread_idx / D;
const int64_t d = thread_idx % D;
if (thread_idx < numel) {
scalar_t g = (scalar_t)0., tmp;
for (ptrdiff_t s = 0; s < S; s++) {
int64_t k_mod = (s / (int64_t)(powf(degree + 1, d) + 0.5)) % (degree + 1);
scalar_t v = pseudo[e * D + d];
v *= kernel_size[d] - degree * is_open_spline[d];
v -= floor(v);
v = Basis<scalar_t, degree>::backward(v, k_mod);
tmp = v;
for (int64_t d_it = 1; d_it < D; d_it++) {
const int64_t d_new = d_it - (d >= d_it);
k_mod = (s / (int64_t)(powf(degree + 1, d_new) + 0.5)) % (degree + 1);
v = pseudo[e * D + d_new];
v *= kernel_size[d_new] - degree * is_open_spline[d_new];
v -= floor(v);
v = Basis<scalar_t, degree>::forward(v, k_mod);
tmp *= v;
}
g += tmp * grad_basis[e * S + s];
}
g *= kernel_size[d] - degree * is_open_spline[d];
grad_pseudo[thread_idx] = g;
}
}
torch::Tensor spline_basis_bw_cuda(torch::Tensor grad_basis,
torch::Tensor pseudo,
torch::Tensor kernel_size,
torch::Tensor is_open_spline,
int64_t degree) {
CHECK_CUDA(grad_basis);
CHECK_CUDA(pseudo);
CHECK_CUDA(kernel_size);
CHECK_CUDA(is_open_spline);
cudaSetDevice(grad_basis.get_device());
CHECK_INPUT(grad_basis.size(0) == pseudo.size(0));
CHECK_INPUT(kernel_size.dim() == 1);
CHECK_INPUT(pseudo.size(1) == kernel_size.numel());
CHECK_INPUT(is_open_spline.dim() == 1);
CHECK_INPUT(pseudo.size(1) == is_open_spline.numel());
auto E = pseudo.size(0);
auto D = pseudo.size(1);
auto S = grad_basis.size(1);
auto grad_pseudo = at::empty({E, D}, pseudo.options());
auto kernel_size_data = kernel_size.data_ptr<int64_t>();
auto is_open_spline_data = is_open_spline.data_ptr<uint8_t>();
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES(pseudo.scalar_type(), "basis_bw", [&] {
auto grad_basis_data = grad_basis.data_ptr<scalar_t>();
auto pseudo_data = pseudo.data_ptr<scalar_t>();
auto grad_pseudo_data = grad_pseudo.data_ptr<scalar_t>();
AT_DISPATCH_DEGREE_TYPES(degree, [&] {
spline_basis_bw_kernel<scalar_t, DEGREE>
<<<BLOCKS(grad_pseudo.numel()), THREADS, 0, stream>>>(
grad_basis_data, pseudo_data, kernel_size_data,
is_open_spline_data, grad_pseudo_data, E, D, S,
grad_pseudo.numel());
});
});
return grad_pseudo;
}
#pragma once
#include <torch/extension.h>
std::tuple<torch::Tensor, torch::Tensor>
spline_basis_fw_cuda(torch::Tensor pseudo, torch::Tensor kernel_size,
torch::Tensor is_open_spline, int64_t degree);
torch::Tensor spline_basis_bw_cuda(torch::Tensor grad_basis,
torch::Tensor pseudo,
torch::Tensor kernel_size,
torch::Tensor is_open_spline,
int64_t degree);
#pragma once
#include <torch/extension.h>
#define CHECK_CUDA(x) \
AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")
#define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch")
#define AT_DISPATCH_DEGREE_TYPES(degree, ...) \
[&] { \
switch (degree) { \
case 1: { \
const int64_t DEGREE = 1; \
return __VA_ARGS__(); \
} \
case 2: { \
const int64_t DEGREE = 2; \
return __VA_ARGS__(); \
} \
case 3: { \
const int64_t DEGREE = 3; \
return __VA_ARGS__(); \
} \
default: \
AT_ERROR("Basis degree not implemented"); \
} \
}()
#include "weighting_cuda.h"
#include <ATen/cuda/CUDAContext.h>
#include "atomics.cuh"
#include "utils.cuh"
#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS
template <typename scalar_t>
__global__ void
spline_weighting_fw_kernel(const scalar_t *x, const scalar_t *weight,
const scalar_t *basis, const int64_t *weight_index,
scalar_t *out, int64_t E, int64_t M_in,
int64_t M_out, int64_t S, int64_t numel) {
const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
const int64_t e = thread_idx / M_out;
const int64_t m_out = thread_idx % M_out;
if (thread_idx < numel) {
scalar_t v = (scalar_t)0.;
for (ptrdiff_t s = 0; s < S; s++) {
const scalar_t b = basis[e * S + s];
const int64_t wi = weight_index[e * S + s];
for (int64_t m_in = 0; m_in < M_in; m_in++) {
scalar_t tmp = weight[wi * M_in * M_out + m_in * M_out + m_out];
tmp *= b * x[e * M_in + m_in];
v += tmp;
}
}
out[thread_idx] = v;
}
}
torch::Tensor spline_weighting_fw_cuda(torch::Tensor x, torch::Tensor weight,
torch::Tensor basis,
torch::Tensor weight_index) {
CHECK_CUDA(x);
CHECK_CUDA(weight);
CHECK_CUDA(basis);
CHECK_CUDA(weight_index);
cudaSetDevice(x.get_device());
CHECK_INPUT(x.size(1) == weight.size(1));
auto E = x.size(0);
auto M_in = x.size(1);
auto M_out = weight.size(2);
auto S = basis.size(1);
auto out = at::empty({E, M_out}, x.options());
auto weight_index_data = weight_index.data_ptr<int64_t>();
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "weighting_fw", [&] {
auto x_data = x.data_ptr<scalar_t>();
auto weight_data = weight.data_ptr<scalar_t>();
auto basis_data = basis.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
spline_weighting_fw_kernel<scalar_t>
<<<BLOCKS(out.numel()), THREADS, 0, stream>>>(
x_data, weight_data, basis_data, weight_index_data, out_data, E,
M_in, M_out, S, out.numel());
});
return out;
}
template <typename scalar_t>
__global__ void
spline_weighting_bw_x_kernel(const scalar_t *grad_out, const scalar_t *weight,
const scalar_t *basis, const int64_t *weight_index,
scalar_t *grad_x, int64_t E, int64_t M_in,
int64_t M_out, int64_t S, int64_t numel) {
const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
const int64_t e = thread_idx / M_in;
const int64_t m_in = thread_idx % M_in;
if (thread_idx < numel) {
scalar_t v = (scalar_t)0.;
for (int64_t s = 0; s < S; s++) {
const scalar_t b = basis[e * S + s];
const int64_t wi = weight_index[e * S + s];
for (int64_t m_out = 0; m_out < M_out; m_out++) {
scalar_t tmp = weight[wi * M_out * M_in + m_out * M_in + m_in];
tmp *= b * grad_out[e * M_out + m_out];
v += tmp;
}
}
grad_x[thread_idx] = v;
}
}
torch::Tensor spline_weighting_bw_x_cuda(torch::Tensor grad_out,
torch::Tensor weight,
torch::Tensor basis,
torch::Tensor weight_index) {
CHECK_CUDA(grad_out);
CHECK_CUDA(weight);
CHECK_CUDA(basis);
CHECK_CUDA(weight_index);
cudaSetDevice(grad_out.get_device());
CHECK_INPUT(grad_out.size(1) == weight.size(2));
auto E = grad_out.size(0);
auto M_in = weight.size(1);
auto M_out = grad_out.size(1);
auto S = basis.size(1);
auto grad_x = at::zeros({E, M_in}, grad_out.options());
weight = weight.transpose(1, 2).contiguous(); // Contiguous memory-access.
auto weight_index_data = weight_index.data_ptr<int64_t>();
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES(grad_out.scalar_type(), "weighting_bw_x", [&] {
auto grad_out_data = grad_out.data_ptr<scalar_t>();
auto weight_data = weight.data_ptr<scalar_t>();
auto basis_data = basis.data_ptr<scalar_t>();
auto grad_x_data = grad_x.data_ptr<scalar_t>();
spline_weighting_bw_x_kernel<scalar_t>
<<<BLOCKS(grad_x.numel()), THREADS, 0, stream>>>(
grad_out_data, weight_data, basis_data, weight_index_data,
grad_x_data, E, M_in, M_out, S, grad_x.numel());
});
return grad_x;
}
template <typename scalar_t>
__global__ void spline_weighting_bw_weight_kernel(
const scalar_t *grad_out, const scalar_t *x, const scalar_t *basis,
const int64_t *weight_index, scalar_t *grad_weight, int64_t E, int64_t M_in,
int64_t M_out, int64_t S, int64_t numel) {
const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
const int64_t e = thread_idx / M_out;
const int64_t m_out = thread_idx % M_out;
if (thread_idx < numel) {
auto g = grad_out[e * M_out + m_out];
for (int64_t s = 0; s < S; s++) {
const scalar_t b = basis[e * S + s];
const int64_t wi = weight_index[e * S + s];
for (int64_t m_in = 0; m_in < M_in; m_in++) {
auto v = g * b * x[e * M_in + m_in];
atomAdd(&grad_weight[wi * M_in * M_out + m_in * M_out + m_out], v);
}
}
}
}
torch::Tensor spline_weighting_bw_weight_cuda(torch::Tensor grad_out,
torch::Tensor x,
torch::Tensor basis,
torch::Tensor weight_index,
int64_t kernel_size) {
CHECK_CUDA(grad_out);
CHECK_CUDA(x);
CHECK_CUDA(basis);
CHECK_CUDA(weight_index);
cudaSetDevice(grad_out.get_device());
auto E = grad_out.size(0);
auto M_in = x.size(1);
auto M_out = grad_out.size(1);
auto S = basis.size(1);
auto grad_weight = at::zeros({kernel_size, M_in, M_out}, grad_out.options());
auto weight_index_data = weight_index.data_ptr<int64_t>();
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "weighting_bw_weight", [&] {
auto grad_out_data = grad_out.data_ptr<scalar_t>();
auto x_data = x.data_ptr<scalar_t>();
auto basis_data = basis.data_ptr<scalar_t>();
auto grad_weight_data = grad_weight.data_ptr<scalar_t>();
spline_weighting_bw_weight_kernel<scalar_t>
<<<BLOCKS(grad_out.numel()), THREADS, 0, stream>>>(
grad_out_data, x_data, basis_data, weight_index_data,
grad_weight_data, E, M_in, M_out, S, grad_out.numel());
});
return grad_weight;
}
template <typename scalar_t>
__global__ void spline_weighting_bw_basis_kernel(
const scalar_t *grad_out, const scalar_t *x, const scalar_t *weight,
const int64_t *weight_index, scalar_t *grad_basis, int64_t E, int64_t M_in,
int64_t M_out, int64_t S, int64_t numel) {
const size_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
const int64_t e = thread_idx / M_out;
const int64_t m_out = thread_idx % M_out;
if (thread_idx < numel) {
const scalar_t g = grad_out[e * M_out + m_out];
for (int64_t s = 0; s < S; s++) {
scalar_t v = (scalar_t)0.;
const int64_t wi = weight_index[e * S + s];
for (int64_t m_in = 0; m_in < M_in; m_in++) {
const scalar_t w = weight[wi * M_in * M_out + m_in * M_out + m_out];
v += g * w * x[e * M_in + m_in];
}
atomAdd(&grad_basis[e * S + s], v);
}
}
}
torch::Tensor spline_weighting_bw_basis_cuda(torch::Tensor grad_out,
torch::Tensor x,
torch::Tensor weight,
torch::Tensor weight_index) {
CHECK_CUDA(grad_out);
CHECK_CUDA(x);
CHECK_CUDA(weight);
CHECK_CUDA(weight_index);
cudaSetDevice(grad_out.get_device());
CHECK_INPUT(x.size(1) == weight.size(1));
CHECK_INPUT(grad_out.size(1) == weight.size(2));
auto E = grad_out.size(0);
auto M_in = x.size(1);
auto M_out = grad_out.size(1);
auto S = weight_index.size(1);
auto grad_basis = at::zeros({E, S}, grad_out.options());
auto weight_index_data = weight_index.data_ptr<int64_t>();
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "weighting_bw_basis", [&] {
auto grad_out_data = grad_out.data_ptr<scalar_t>();
auto x_data = x.data_ptr<scalar_t>();
auto weight_data = weight.data_ptr<scalar_t>();
auto grad_basis_data = grad_basis.data_ptr<scalar_t>();
spline_weighting_bw_basis_kernel<scalar_t>
<<<BLOCKS(grad_out.numel()), THREADS, 0, stream>>>(
grad_out_data, x_data, weight_data, weight_index_data,
grad_basis_data, E, M_in, M_out, S, grad_out.numel());
});
return grad_basis;
}
#pragma once
#include <torch/extension.h>
torch::Tensor spline_weighting_fw_cuda(torch::Tensor x, torch::Tensor weight,
torch::Tensor basis,
torch::Tensor weight_index);
torch::Tensor spline_weighting_bw_x_cuda(torch::Tensor grad_out,
torch::Tensor weight,
torch::Tensor basis,
torch::Tensor weight_index);
torch::Tensor spline_weighting_bw_weight_cuda(torch::Tensor grad_out,
torch::Tensor x,
torch::Tensor basis,
torch::Tensor weight_index,
int64_t kernel_size);
torch::Tensor spline_weighting_bw_basis_cuda(torch::Tensor grad_out,
torch::Tensor x,
torch::Tensor weight,
torch::Tensor weight_index);