Unverified Commit cd7b1988 authored by Matthias Fey, committed by GitHub

Merge pull request #14 from rusty1s/tracing

prepare tracing
parents d3169766 32224979
@@ -3,5 +3,6 @@ source=torch_spline_conv
[report]
exclude_lines =
    pragma: no cover
    torch.jit.script
    raise
    except
__pycache__/
build/
dist/
.cache/
...
language: shell

os:
  - linux
  - osx
  - windows

env:
  global:
    - CUDA_HOME=/usr/local/cuda
  jobs:
    - TORCH_VERSION=1.4.0 PYTHON_VERSION=3.8 IDX=cpu
    - TORCH_VERSION=1.4.0 PYTHON_VERSION=3.8 IDX=cu92
    - TORCH_VERSION=1.4.0 PYTHON_VERSION=3.8 IDX=cu100
    - TORCH_VERSION=1.4.0 PYTHON_VERSION=3.8 IDX=cu101
    - TORCH_VERSION=1.4.0 PYTHON_VERSION=3.7 IDX=cpu
    - TORCH_VERSION=1.4.0 PYTHON_VERSION=3.7 IDX=cu92
    - TORCH_VERSION=1.4.0 PYTHON_VERSION=3.7 IDX=cu100
    - TORCH_VERSION=1.4.0 PYTHON_VERSION=3.7 IDX=cu101
    - TORCH_VERSION=1.4.0 PYTHON_VERSION=3.6 IDX=cpu
    - TORCH_VERSION=1.4.0 PYTHON_VERSION=3.6 IDX=cu92
    - TORCH_VERSION=1.4.0 PYTHON_VERSION=3.6 IDX=cu100
    - TORCH_VERSION=1.4.0 PYTHON_VERSION=3.6 IDX=cu101

jobs:
  exclude:  # Exclude *all* macOS CUDA jobs and Windows CUDA 9.2/10.0 jobs.
    - os: osx
      env: TORCH_VERSION=1.4.0 PYTHON_VERSION=3.8 IDX=cu92
    - os: osx
      env: TORCH_VERSION=1.4.0 PYTHON_VERSION=3.8 IDX=cu100
    - os: osx
      env: TORCH_VERSION=1.4.0 PYTHON_VERSION=3.8 IDX=cu101
    - os: osx
      env: TORCH_VERSION=1.4.0 PYTHON_VERSION=3.7 IDX=cu92
    - os: osx
      env: TORCH_VERSION=1.4.0 PYTHON_VERSION=3.7 IDX=cu100
    - os: osx
      env: TORCH_VERSION=1.4.0 PYTHON_VERSION=3.7 IDX=cu101
    - os: osx
      env: TORCH_VERSION=1.4.0 PYTHON_VERSION=3.6 IDX=cu92
    - os: osx
      env: TORCH_VERSION=1.4.0 PYTHON_VERSION=3.6 IDX=cu100
    - os: osx
      env: TORCH_VERSION=1.4.0 PYTHON_VERSION=3.6 IDX=cu101
    - os: windows
      env: TORCH_VERSION=1.4.0 PYTHON_VERSION=3.8 IDX=cu92
    - os: windows
      env: TORCH_VERSION=1.4.0 PYTHON_VERSION=3.8 IDX=cu100
    - os: windows
      env: TORCH_VERSION=1.4.0 PYTHON_VERSION=3.7 IDX=cu92
    - os: windows
      env: TORCH_VERSION=1.4.0 PYTHON_VERSION=3.7 IDX=cu100
    - os: windows
      env: TORCH_VERSION=1.4.0 PYTHON_VERSION=3.6 IDX=cu92
    - os: windows
      env: TORCH_VERSION=1.4.0 PYTHON_VERSION=3.6 IDX=cu100
    - os: windows
      env: TORCH_VERSION=1.4.0 PYTHON_VERSION=3.8 IDX=cu101

install:
  - source script/cuda.sh
  - source script/conda.sh
  - conda create --yes -n test python="${PYTHON_VERSION}"
  - source activate test
  - conda install pytorch=${TORCH_VERSION} ${TOOLKIT} -c pytorch --yes
  - source script/torch.sh
  - pip install flake8 codecov
  - python setup.py install

script:
  - flake8 .
  - python setup.py test

after_success:
  - python setup.py bdist_wheel --dist-dir=dist/torch-${TORCH_VERSION}
  - python script/rename_wheel.py ${IDX}
  - codecov

deploy:
  provider: s3
  region: eu-central-1
  edge: true
  access_key_id: ${S3_ACCESS_KEY}
  secret_access_key: ${S3_SECRET_ACCESS_KEY}
  bucket: pytorch-geometric.com
  local_dir: dist/torch-${TORCH_VERSION}
  upload_dir: whl/torch-${TORCH_VERSION}
  acl: public_read
  on:
    repo: rusty1s/pytorch_spline_conv
    tags: true

notifications:
  email: false
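The effective CI matrix is the cross product of the three `os` values with the twelve `env.jobs` entries, minus the `exclude` list above: macOS builds CPU-only wheels, and Windows drops the CUDA 9.2/10.0 jobs plus the Python 3.8/CUDA 10.1 combination. A standalone sketch (not part of the repository) that enumerates how many jobs this leaves:

```python
from itertools import product

oses = ["linux", "osx", "windows"]
pythons = ["3.8", "3.7", "3.6"]
idxs = ["cpu", "cu92", "cu100", "cu101"]

def excluded(os_, py, idx):
    # Mirrors the exclude list in the config above.
    if os_ == "osx" and idx != "cpu":
        return True
    if os_ == "windows" and idx in ("cu92", "cu100"):
        return True
    if os_ == "windows" and py == "3.8" and idx == "cu101":
        return True
    return False

jobs = [(o, p, i) for o, p, i in product(oses, pythons, idxs) if not excluded(o, p, i)]
print(len(jobs))  # 12 Linux + 3 macOS (CPU) + 5 Windows = 20 builds
```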
Copyright (c) 2020 Matthias Fey <matthias.fey@tu-dortmund.de>

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
...
include README.md
include LICENSE
recursive-exclude test *
recursive-include csrc *
@@ -21,11 +21,30 @@ The operator works on all floating point data types and is implemented both for

## Installation

### Binaries

We provide pip wheels for all major OS/PyTorch/CUDA combinations, see [here](https://pytorch-geometric.com/whl).
To install the binaries for PyTorch 1.4.0, simply run

```
pip install torch-spline-conv==latest+${CUDA} -f https://pytorch-geometric.com/whl/torch-1.4.0.html
```

where `${CUDA}` should be replaced by either `cpu`, `cu92`, `cu100` or `cu101` depending on your PyTorch installation.

|             | `cpu` | `cu92` | `cu100` | `cu101` |
|-------------|-------|--------|---------|---------|
| **Linux**   | ✅    | ✅     | ✅      | ✅      |
| **Windows** | ✅    | ❌     | ❌      | ✅      |
| **macOS**   | ✅    |        |         |         |

### From source

Ensure that at least PyTorch 1.4.0 is installed and verify that `cuda/bin` and `cuda/include` are in your `$PATH` and `$CPATH` respectively, *e.g.*:

```
$ python -c "import torch; print(torch.__version__)"
>>> 1.4.0

$ echo $PATH
>>> /usr/local/cuda/bin:...
```

@@ -40,24 +59,28 @@ Then run:

```
pip install torch-spline-conv
```

When running inside a Docker container without an NVIDIA driver, PyTorch needs to evaluate the compute capabilities and may fail.
In this case, ensure that the compute capabilities are set via `TORCH_CUDA_ARCH_LIST`, *e.g.*:

```
export TORCH_CUDA_ARCH_LIST="6.0 6.1 7.2+PTX 7.5+PTX"
```

## Usage

```python
from torch_spline_conv import spline_conv

out = spline_conv(x,
                  edge_index,
                  pseudo,
                  weight,
                  kernel_size,
                  is_open_spline,
                  degree=1,
                  norm=True,
                  root_weight=None,
                  bias=None)
```

Applies the spline-based convolution operator
@@ -93,7 +116,7 @@ The kernel function is defined over the weighted B-spline tensor product basis,

```python
import torch
from torch_spline_conv import spline_conv

x = torch.rand((4, 2), dtype=torch.float)  # 4 nodes with 2 features each
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]])  # 6 edges
@@ -106,8 +129,8 @@ norm = True  # Normalize output by node degree.
root_weight = torch.rand((2, 4), dtype=torch.float)  # separately weight root nodes
bias = None  # do not apply an additional bias

out = spline_conv(x, edge_index, pseudo, weight, kernel_size,
                  is_open_spline, degree, norm, root_weight, bias)

print(out.size())
torch.Size([4, 4])  # 4 nodes with 4 features each
```
...
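For scripted installs, the `${CUDA}` suffix can be derived from the local PyTorch build instead of being typed by hand. A minimal sketch (the helper below is hypothetical and not part of the package) that prints the matching pip command for the wheel index above:

```python
import torch

def wheel_tag():
    # Map the detected CUDA version onto the wheel suffixes listed above.
    cuda = torch.version.cuda  # e.g. "10.1", or None for CPU-only builds
    if cuda is None:
        return "cpu"
    return "cu" + cuda.replace(".", "")  # "9.2" -> "cu92", "10.1" -> "cu101"

tag = wheel_tag()
print(f"pip install torch-spline-conv==latest+{tag} "
      f"-f https://pytorch-geometric.com/whl/torch-1.4.0.html")
```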
#include <torch/extension.h>
#include "compat.h"
template <typename scalar_t> inline scalar_t linear(scalar_t v, int64_t k_mod) {
return 1 - v - k_mod + 2 * v * k_mod;
}
template <typename scalar_t>
inline scalar_t quadratic(scalar_t v, int64_t k_mod) {
if (k_mod == 0)
return 0.5 * v * v - v + 0.5;
else if (k_mod == 1)
return -v * v + v + 0.5;
else
return 0.5 * v * v;
}
template <typename scalar_t> inline scalar_t cubic(scalar_t v, int64_t k_mod) {
if (k_mod == 0)
return (1 - v) * (1 - v) * (1 - v) / 6.0;
else if (k_mod == 1)
return (3 * v * v * v - 6 * v * v + 4) / 6;
else if (k_mod == 2)
return (-3 * v * v * v + 3 * v * v + 3 * v + 1) / 6;
else
return v * v * v / 6;
}
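The three helpers above are the open B-spline basis pieces of degree 1, 2 and 3, evaluated on the fractional part `v` of a pseudo-coordinate and selected by `k_mod`. For reference, a direct Python transcription (an illustrative sketch, not shipped with the package):

```python
def linear(v, k_mod):
    # Degree-1 pieces: (1 - v) for k_mod == 0, v for k_mod == 1.
    return 1 - v - k_mod + 2 * v * k_mod

def quadratic(v, k_mod):
    if k_mod == 0:
        return 0.5 * v * v - v + 0.5
    elif k_mod == 1:
        return -v * v + v + 0.5
    return 0.5 * v * v

def cubic(v, k_mod):
    if k_mod == 0:
        return (1 - v) ** 3 / 6.0
    elif k_mod == 1:
        return (3 * v ** 3 - 6 * v ** 2 + 4) / 6.0
    elif k_mod == 2:
        return (-3 * v ** 3 + 3 * v ** 2 + 3 * v + 1) / 6.0
    return v ** 3 / 6.0

# The pieces of each degree form a partition of unity on [0, 1):
v = 0.3
assert abs(sum(quadratic(v, k) for k in range(3)) - 1.0) < 1e-12
```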
#define BASIS_FORWARD(M, PSEUDO, KERNEL_SIZE, IS_OPEN_SPLINE, FUNC) \
[&]() -> std::tuple<at::Tensor, at::Tensor> { \
auto E = PSEUDO.size(0), D = PSEUDO.size(1); \
auto S = (int64_t)(pow(M + 1, KERNEL_SIZE.size(0)) + 0.5); \
auto basis = at::empty({E, S}, PSEUDO.options()); \
auto weight_index = at::empty({E, S}, KERNEL_SIZE.options()); \
\
AT_DISPATCH_FLOATING_TYPES( \
PSEUDO.scalar_type(), "basis_forward_##M", [&] { \
auto pseudo_data = PSEUDO.DATA_PTR<scalar_t>(); \
auto kernel_size_data = KERNEL_SIZE.DATA_PTR<int64_t>(); \
auto is_open_spline_data = IS_OPEN_SPLINE.DATA_PTR<uint8_t>(); \
auto basis_data = basis.DATA_PTR<scalar_t>(); \
auto weight_index_data = weight_index.DATA_PTR<int64_t>(); \
\
int64_t k, wi, wi_offset; \
scalar_t b; \
\
for (ptrdiff_t e = 0; e < E; e++) { \
for (ptrdiff_t s = 0; s < S; s++) { \
k = s; \
wi = 0; \
wi_offset = 1; \
b = 1; \
for (ptrdiff_t d = 0; d < D; d++) { \
auto k_mod = k % (M + 1); \
k /= M + 1; \
\
auto v = \
pseudo_data[e * pseudo.stride(0) + d * pseudo.stride(1)]; \
v *= kernel_size_data[d] - M * is_open_spline_data[d]; \
\
wi += \
(((int64_t)v + k_mod) % kernel_size_data[d]) * wi_offset; \
wi_offset *= kernel_size_data[d]; \
\
v -= floor(v); \
v = FUNC<scalar_t>(v, k_mod); \
b *= v; \
} \
basis_data[e * S + s] = b; \
weight_index_data[e * S + s] = wi; \
} \
} \
}); \
return std::make_tuple(basis, weight_index); \
}()
std::tuple<at::Tensor, at::Tensor> linear_fw(at::Tensor pseudo,
at::Tensor kernel_size,
at::Tensor is_open_spline) {
return BASIS_FORWARD(1, pseudo, kernel_size, is_open_spline, linear);
}
std::tuple<at::Tensor, at::Tensor> quadratic_fw(at::Tensor pseudo,
at::Tensor kernel_size,
at::Tensor is_open_spline) {
return BASIS_FORWARD(2, pseudo, kernel_size, is_open_spline, quadratic);
}
std::tuple<at::Tensor, at::Tensor>
cubic_fw(at::Tensor pseudo, at::Tensor kernel_size, at::Tensor is_open_spline) {
return BASIS_FORWARD(3, pseudo, kernel_size, is_open_spline, cubic);
}
template <typename scalar_t>
inline scalar_t grad_linear(scalar_t v, int64_t k_mod) {
return 2 * k_mod - 1;
}
template <typename scalar_t>
inline scalar_t grad_quadratic(scalar_t v, int64_t k_mod) {
if (k_mod == 0)
return v - 1;
else if (k_mod == 1)
return -2 * v + 1;
else
return v;
}
template <typename scalar_t>
inline scalar_t grad_cubic(scalar_t v, int64_t k_mod) {
if (k_mod == 0)
return (-v * v + 2 * v - 1) / 2;
else if (k_mod == 1)
return (3 * v * v - 4 * v) / 2;
else if (k_mod == 2)
return (-3 * v * v + 2 * v + 1) / 2;
else
return v * v / 2;
}
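`grad_linear`, `grad_quadratic` and `grad_cubic` are the analytic derivatives of the corresponding basis pieces with respect to `v`. A quick finite-difference check in Python (illustrative only; it re-declares the quadratic piece so the snippet is self-contained):

```python
def quadratic(v, k_mod):
    if k_mod == 0:
        return 0.5 * v * v - v + 0.5
    elif k_mod == 1:
        return -v * v + v + 0.5
    return 0.5 * v * v

def grad_quadratic(v, k_mod):
    if k_mod == 0:
        return v - 1
    elif k_mod == 1:
        return -2 * v + 1
    return v

eps = 1e-6
for k_mod in range(3):
    for v in (0.1, 0.5, 0.9):
        numeric = (quadratic(v + eps, k_mod) - quadratic(v - eps, k_mod)) / (2 * eps)
        assert abs(numeric - grad_quadratic(v, k_mod)) < 1e-6
```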
#define BASIS_BACKWARD(M, GRAD_BASIS, PSEUDO, KERNEL_SIZE, IS_OPEN_SPLINE, \
FUNC, GRAD_FUNC) \
[&]() -> at::Tensor { \
auto E = PSEUDO.size(0), D = PSEUDO.size(1); \
auto S = GRAD_BASIS.size(1); \
auto grad_pseudo = at::empty({E, D}, PSEUDO.options()); \
\
AT_DISPATCH_FLOATING_TYPES( \
PSEUDO.scalar_type(), "basis_backward_##M", [&] { \
auto grad_basis_data = GRAD_BASIS.DATA_PTR<scalar_t>(); \
auto pseudo_data = PSEUDO.DATA_PTR<scalar_t>(); \
auto kernel_size_data = KERNEL_SIZE.DATA_PTR<int64_t>(); \
auto is_open_spline_data = IS_OPEN_SPLINE.DATA_PTR<uint8_t>(); \
auto grad_pseudo_data = grad_pseudo.DATA_PTR<scalar_t>(); \
\
scalar_t g, tmp; \
\
for (ptrdiff_t e = 0; e < E; e++) { \
for (ptrdiff_t d = 0; d < D; d++) { \
g = 0; \
for (ptrdiff_t s = 0; s < S; s++) { \
auto k_mod = (s / (int64_t)(pow(M + 1, d) + 0.5)) % (M + 1); \
auto v = \
pseudo_data[e * pseudo.stride(0) + d * pseudo.stride(1)]; \
v *= kernel_size_data[d] - M * is_open_spline_data[d]; \
v -= floor(v); \
v = GRAD_FUNC<scalar_t>(v, k_mod); \
tmp = v; \
\
for (ptrdiff_t d_it = 1; d_it < D; d_it++) { \
auto d_new = d_it - (d >= d_it); \
k_mod = (s / (int64_t)(pow(M + 1, d_new) + 0.5)) % (M + 1); \
v = pseudo_data[e * pseudo.stride(0) + \
d_new * pseudo.stride(1)]; \
v *= kernel_size_data[d_new] - \
M * is_open_spline_data[d_new]; \
v -= floor(v); \
v = FUNC<scalar_t>(v, k_mod); \
tmp *= v; \
} \
g += tmp * grad_basis_data[e * grad_basis.stride(0) + \
s * grad_basis.stride(1)]; \
} \
g *= kernel_size_data[d] - M * is_open_spline_data[d]; \
grad_pseudo_data[e * D + d] = g; \
} \
} \
}); \
return grad_pseudo; \
}()
at::Tensor linear_bw(at::Tensor grad_basis, at::Tensor pseudo,
at::Tensor kernel_size, at::Tensor is_open_spline) {
return BASIS_BACKWARD(1, grad_basis, pseudo, kernel_size, is_open_spline,
linear, grad_linear);
}
at::Tensor quadratic_bw(at::Tensor grad_basis, at::Tensor pseudo,
at::Tensor kernel_size, at::Tensor is_open_spline) {
return BASIS_BACKWARD(2, grad_basis, pseudo, kernel_size, is_open_spline,
quadratic, grad_quadratic);
}
at::Tensor cubic_bw(at::Tensor grad_basis, at::Tensor pseudo,
at::Tensor kernel_size, at::Tensor is_open_spline) {
return BASIS_BACKWARD(3, grad_basis, pseudo, kernel_size, is_open_spline,
cubic, grad_cubic);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("linear_fw", &linear_fw, "Linear Basis Forward (CPU)");
m.def("quadratic_fw", &quadratic_fw, "Quadratic Basis Forward (CPU)");
m.def("cubic_fw", &cubic_fw, "Cubic Basis Forward (CPU)");
m.def("linear_bw", &linear_bw, "Linear Basis Backward (CPU)");
m.def("quadratic_bw", &quadratic_bw, "Quadratic Basis Backward (CPU)");
m.def("cubic_bw", &cubic_bw, "Cubic Basis Backward (CPU)");
}
#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
#include <Python.h>
#include <torch/script.h>
#include "cpu/basis_cpu.h"
#ifdef WITH_CUDA
#include "cuda/basis_cuda.h"
#endif
#ifdef _WIN32
PyMODINIT_FUNC PyInit__basis(void) { return NULL; }
#endif
std::tuple<torch::Tensor, torch::Tensor>
spline_basis_fw(torch::Tensor pseudo, torch::Tensor kernel_size,
torch::Tensor is_open_spline, int64_t degree) {
if (pseudo.device().is_cuda()) {
#ifdef WITH_CUDA
return spline_basis_fw_cuda(pseudo, kernel_size, is_open_spline, degree);
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return spline_basis_fw_cpu(pseudo, kernel_size, is_open_spline, degree);
}
}
torch::Tensor spline_basis_bw(torch::Tensor grad_basis, torch::Tensor pseudo,
torch::Tensor kernel_size,
torch::Tensor is_open_spline, int64_t degree) {
if (grad_basis.device().is_cuda()) {
#ifdef WITH_CUDA
return spline_basis_bw_cuda(grad_basis, pseudo, kernel_size, is_open_spline,
degree);
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return spline_basis_bw_cpu(grad_basis, pseudo, kernel_size, is_open_spline,
degree);
}
}
using torch::autograd::AutogradContext;
using torch::autograd::Variable;
using torch::autograd::variable_list;
class SplineBasis : public torch::autograd::Function<SplineBasis> {
public:
static variable_list forward(AutogradContext *ctx, Variable pseudo,
Variable kernel_size, Variable is_open_spline,
int64_t degree) {
ctx->saved_data["degree"] = degree;
auto result = spline_basis_fw(pseudo, kernel_size, is_open_spline, degree);
auto basis = std::get<0>(result), weight_index = std::get<1>(result);
ctx->save_for_backward({pseudo, kernel_size, is_open_spline});
ctx->mark_non_differentiable({weight_index});
return {basis, weight_index};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto grad_basis = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto pseudo = saved[0], kernel_size = saved[1], is_open_spline = saved[2];
auto degree = ctx->saved_data["degree"].toInt();
auto grad_pseudo = spline_basis_bw(grad_basis, pseudo, kernel_size,
is_open_spline, degree);
return {grad_pseudo, Variable(), Variable(), Variable()};
}
};
std::tuple<torch::Tensor, torch::Tensor>
spline_basis(torch::Tensor pseudo, torch::Tensor kernel_size,
torch::Tensor is_open_spline, int64_t degree) {
pseudo = pseudo.contiguous();
auto result = SplineBasis::apply(pseudo, kernel_size, is_open_spline, degree);
return std::make_tuple(result[0], result[1]);
}
static auto registry = torch::RegisterOperators().op(
"torch_spline_conv::spline_basis", &spline_basis);
#include "basis_cpu.h"
#include "utils.h"
template <typename scalar_t, int64_t degree> struct Basis {
static inline scalar_t forward(scalar_t v, int64_t k_mod) {
if (degree == 1) {
return 1. - v - k_mod + 2. * v * k_mod;
} else if (degree == 2) {
if (k_mod == 0)
return 0.5 * v * v - v + 0.5;
else if (k_mod == 1)
return -v * v + v + 0.5;
else
return 0.5 * v * v;
} else if (degree == 3) {
if (k_mod == 0)
return (1. - v) * (1. - v) * (1. - v) / 6.;
else if (k_mod == 1)
return (3. * v * v * v - 6. * v * v + 4.) / 6.;
else if (k_mod == 2)
return (-3. * v * v * v + 3. * v * v + 3. * v + 1.) / 6.;
else
return v * v * v / 6.;
} else {
return (scalar_t)-1.;
}
}
static inline scalar_t backward(scalar_t v, int64_t k_mod) {
if (degree == 1) {
return 2 * k_mod - 1;
} else if (degree == 2) {
if (k_mod == 0)
return v - 1.;
else if (k_mod == 1)
return -2. * v + 1.;
else
return v;
} else if (degree == 3) {
if (k_mod == 0)
return (-v * v + 2. * v - 1.) / 2.;
else if (k_mod == 1)
return (3. * v * v - 4. * v) / 2.;
else if (k_mod == 2)
return (-3. * v * v + 2. * v + 1.) / 2.;
else
return v * v / 2.;
} else {
return (scalar_t)-1.;
}
}
};
std::tuple<torch::Tensor, torch::Tensor>
spline_basis_fw_cpu(torch::Tensor pseudo, torch::Tensor kernel_size,
torch::Tensor is_open_spline, int64_t degree) {
CHECK_CPU(pseudo);
CHECK_CPU(kernel_size);
CHECK_CPU(is_open_spline);
CHECK_INPUT(kernel_size.dim() == 1);
CHECK_INPUT(pseudo.size(1) == kernel_size.numel());
CHECK_INPUT(is_open_spline.dim());
CHECK_INPUT(pseudo.size(1) == is_open_spline.numel());
auto E = pseudo.size(0);
auto D = pseudo.size(1);
auto S = (int64_t)(pow(degree + 1, D) + 0.5);
auto basis = at::empty({E, S}, pseudo.options());
auto weight_index = at::empty({E, S}, kernel_size.options());
auto kernel_size_data = kernel_size.data_ptr<int64_t>();
auto is_open_spline_data = is_open_spline.data_ptr<uint8_t>();
auto weight_index_data = weight_index.data_ptr<int64_t>();
AT_DISPATCH_FLOATING_TYPES(pseudo.scalar_type(), "basis_fw", [&] {
auto pseudo_data = pseudo.data_ptr<scalar_t>();
auto basis_data = basis.data_ptr<scalar_t>();
AT_DISPATCH_DEGREE_TYPES(degree, [&] {
int64_t k, wi, wi_offset;
scalar_t b;
for (int64_t e = 0; e < E; e++) {
for (int64_t s = 0; s < S; s++) {
k = s, wi = 0, wi_offset = 1, b = (scalar_t)1.;
for (int64_t d = 0; d < D; d++) {
int64_t k_mod = k % (DEGREE + 1);
k /= DEGREE + 1;
auto v = pseudo_data[e * pseudo.stride(0) + d * pseudo.stride(1)];
v *= kernel_size_data[d] - DEGREE * is_open_spline_data[d];
wi += (((int64_t)v + k_mod) % kernel_size_data[d]) * wi_offset;
wi_offset *= kernel_size_data[d];
v -= floor(v);
v = Basis<scalar_t, DEGREE>::forward(v, k_mod);
b *= v;
}
basis_data[e * S + s] = b;
weight_index_data[e * S + s] = wi;
}
}
});
});
return std::make_tuple(basis, weight_index);
}
torch::Tensor spline_basis_bw_cpu(torch::Tensor grad_basis,
torch::Tensor pseudo,
torch::Tensor kernel_size,
torch::Tensor is_open_spline,
int64_t degree) {
CHECK_CPU(grad_basis);
CHECK_CPU(pseudo);
CHECK_CPU(kernel_size);
CHECK_CPU(is_open_spline);
CHECK_INPUT(grad_basis.size(0) == pseudo.size(0));
CHECK_INPUT(kernel_size.dim() == 1);
CHECK_INPUT(pseudo.size(1) == kernel_size.numel());
CHECK_INPUT(is_open_spline.dim());
CHECK_INPUT(pseudo.size(1) == is_open_spline.numel());
auto E = pseudo.size(0);
auto D = pseudo.size(1);
auto S = grad_basis.size(1);
auto grad_pseudo = at::empty({E, D}, pseudo.options());
auto kernel_size_data = kernel_size.data_ptr<int64_t>();
auto is_open_spline_data = is_open_spline.data_ptr<uint8_t>();
AT_DISPATCH_FLOATING_TYPES(pseudo.scalar_type(), "basis_bw", [&] {
auto grad_basis_data = grad_basis.data_ptr<scalar_t>();
auto pseudo_data = pseudo.data_ptr<scalar_t>();
auto grad_pseudo_data = grad_pseudo.data_ptr<scalar_t>();
AT_DISPATCH_DEGREE_TYPES(degree, [&] {
scalar_t g, tmp;
for (int64_t e = 0; e < E; e++) {
for (int64_t d = 0; d < D; d++) {
g = (scalar_t)0.;
for (int64_t s = 0; s < S; s++) {
int64_t k_mod =
(s / (int64_t)(pow(DEGREE + 1, d) + 0.5)) % (DEGREE + 1);
auto v = pseudo_data[e * pseudo.stride(0) + d * pseudo.stride(1)];
v *= kernel_size_data[d] - DEGREE * is_open_spline_data[d];
v -= floor(v);
v = Basis<scalar_t, DEGREE>::backward(v, k_mod);
tmp = v;
for (int64_t d_it = 1; d_it < D; d_it++) {
int64_t d_new = d_it - (d >= d_it);
k_mod =
(s / (int64_t)(pow(DEGREE + 1, d_new) + 0.5)) % (DEGREE + 1);
v = pseudo_data[e * pseudo.stride(0) + d_new * pseudo.stride(1)];
v *=
kernel_size_data[d_new] - DEGREE * is_open_spline_data[d_new];
v -= floor(v);
v = Basis<scalar_t, DEGREE>::forward(v, k_mod);
tmp *= v;
}
g += tmp * grad_basis_data[e * grad_basis.stride(0) +
s * grad_basis.stride(1)];
}
g *= kernel_size_data[d] - DEGREE * is_open_spline_data[d];
grad_pseudo_data[e * D + d] = g;
}
}
});
});
return grad_pseudo;
}
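The nested loops above enumerate, for every edge `e`, the `S = (degree + 1)^D` B-spline products that are non-zero at its pseudo-coordinate: `basis[e, s]` stores the product of the per-dimension basis values and `weight_index[e, s]` the flattened index of the corresponding kernel weight. A plain-Python rendering of one `(e, s)` iteration (illustrative only; the basis function is passed in, e.g. the degree-1 piece sketched earlier):

```python
import math

def basis_and_index(pseudo_e, kernel_size, is_open_spline, degree, s, basis_fn):
    # pseudo_e: list of D pseudo-coordinates in [0, 1] for a single edge.
    D = len(pseudo_e)
    k, wi, wi_offset, b = s, 0, 1, 1.0
    for d in range(D):
        k_mod = k % (degree + 1)
        k //= degree + 1
        v = pseudo_e[d] * (kernel_size[d] - degree * is_open_spline[d])
        wi += ((int(v) + k_mod) % kernel_size[d]) * wi_offset
        wi_offset *= kernel_size[d]
        v -= math.floor(v)
        b *= basis_fn(v, k_mod)
    return b, wi

# Example: one edge, 5 x 5 open-spline kernel, degree 1 -> S = 4 contributions.
pseudo_e = [0.25, 0.7]
for s in range(4):
    b, wi = basis_and_index(pseudo_e, [5, 5], [1, 1], 1, s,
                            lambda v, k: 1 - v - k + 2 * v * k)
    print(s, round(b, 3), wi)
```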
#pragma once
#include <torch/extension.h>
std::tuple<torch::Tensor, torch::Tensor>
spline_basis_fw_cpu(torch::Tensor pseudo, torch::Tensor kernel_size,
torch::Tensor is_open_spline, int64_t degree);
torch::Tensor spline_basis_bw_cpu(torch::Tensor grad_basis,
torch::Tensor pseudo,
torch::Tensor kernel_size,
torch::Tensor is_open_spline, int64_t degree);
#pragma once
#include <torch/extension.h>
#define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor")
#define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch")
#define AT_DISPATCH_DEGREE_TYPES(degree, ...) \
[&] { \
switch (degree) { \
case 1: { \
static constexpr int64_t DEGREE = 1; \
return __VA_ARGS__(); \
} \
case 2: { \
static constexpr int64_t DEGREE = 2; \
return __VA_ARGS__(); \
} \
case 3: { \
static constexpr int64_t DEGREE = 3; \
return __VA_ARGS__(); \
} \
default: \
AT_ERROR("Basis degree not implemented"); \
} \
}()
#include "weighting_cpu.h"

#include "utils.h"

torch::Tensor spline_weighting_fw_cpu(torch::Tensor x, torch::Tensor weight,
                                      torch::Tensor basis,
                                      torch::Tensor weight_index) {
  CHECK_CPU(x);
  CHECK_CPU(weight);
  CHECK_CPU(basis);
  CHECK_CPU(weight_index);
  CHECK_INPUT(x.size(1) == weight.size(1));

  auto E = x.size(0);
  auto M_in = x.size(1);
  auto M_out = weight.size(2);
  auto S = basis.size(1);

  auto out = at::empty({E, M_out}, x.options());

  auto weight_index_data = weight_index.data_ptr<int64_t>();

  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "weighting_fw", [&] {
    auto x_data = x.data_ptr<scalar_t>();
    auto weight_data = weight.data_ptr<scalar_t>();
    auto basis_data = basis.data_ptr<scalar_t>();
    auto out_data = out.data_ptr<scalar_t>();

    scalar_t v;
    for (int64_t e = 0; e < E; e++) {
      for (int64_t m_out = 0; m_out < M_out; m_out++) {
        v = 0;
        for (int64_t s = 0; s < S; s++) {
          auto b = basis_data[e * S + s];
          auto wi = weight_index_data[e * S + s];
          for (int64_t m_in = 0; m_in < M_in; m_in++) {
            auto tmp =
                weight_data[wi * weight.stride(0) + m_in * weight.stride(1) +
                            m_out * weight.stride(2)];
@@ -39,27 +51,40 @@ at::Tensor weighting_fw(at::Tensor x, at::Tensor weight, at::Tensor basis,
  return out;
}

torch::Tensor spline_weighting_bw_x_cpu(torch::Tensor grad_out,
                                        torch::Tensor weight,
                                        torch::Tensor basis,
                                        torch::Tensor weight_index) {
  CHECK_CPU(grad_out);
  CHECK_CPU(weight);
  CHECK_CPU(basis);
  CHECK_CPU(weight_index);
  CHECK_INPUT(grad_out.size(1) == weight.size(2));

  auto E = grad_out.size(0);
  auto M_in = weight.size(1);
  auto M_out = grad_out.size(1);
  auto S = basis.size(1);

  auto grad_x = at::zeros({E, M_in}, grad_out.options());

  auto weight_index_data = weight_index.data_ptr<int64_t>();

  AT_DISPATCH_FLOATING_TYPES(grad_out.scalar_type(), "weighting_bw_x", [&] {
    auto grad_out_data = grad_out.data_ptr<scalar_t>();
    auto weight_data = weight.data_ptr<scalar_t>();
    auto basis_data = basis.data_ptr<scalar_t>();
    auto grad_x_data = grad_x.data_ptr<scalar_t>();

    for (int64_t e = 0; e < E; e++) {
      for (int64_t m_out = 0; m_out < M_out; m_out++) {
        auto g =
            grad_out_data[e * grad_out.stride(0) + m_out * grad_out.stride(1)];
        for (int64_t s = 0; s < S; s++) {
          auto b = basis_data[e * S + s];
          auto wi = weight_index_data[e * S + s];
          for (int64_t m_in = 0; m_in < M_in; m_in++) {
            auto w =
                weight_data[wi * weight.stride(0) + m_in * weight.stride(1) +
                            m_out * weight.stride(2)];
@@ -73,27 +98,39 @@ at::Tensor weighting_bw_x(at::Tensor grad_out, at::Tensor weight,
  return grad_x;
}

torch::Tensor spline_weighting_bw_weight_cpu(torch::Tensor grad_out,
                                             torch::Tensor x,
                                             torch::Tensor basis,
                                             torch::Tensor weight_index,
                                             int64_t kernel_size) {
  CHECK_CPU(grad_out);
  CHECK_CPU(x);
  CHECK_CPU(basis);
  CHECK_CPU(weight_index);

  auto E = grad_out.size(0);
  auto M_in = x.size(1);
  auto M_out = grad_out.size(1);
  auto S = basis.size(1);

  auto grad_weight = at::zeros({kernel_size, M_in, M_out}, grad_out.options());

  auto weight_index_data = weight_index.data_ptr<int64_t>();

  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "weighting_bw_weight", [&] {
    auto grad_out_data = grad_out.data_ptr<scalar_t>();
    auto x_data = x.data_ptr<scalar_t>();
    auto basis_data = basis.data_ptr<scalar_t>();
    auto grad_weight_data = grad_weight.data_ptr<scalar_t>();

    for (int64_t e = 0; e < E; e++) {
      for (int64_t m_out = 0; m_out < M_out; m_out++) {
        auto g =
            grad_out_data[e * grad_out.stride(0) + m_out * grad_out.stride(1)];
        for (int64_t s = 0; s < S; s++) {
          auto b = basis_data[e * S + s];
          auto wi = weight_index_data[e * S + s];
          for (int64_t m_in = 0; m_in < M_in; m_in++) {
            auto v = g * b * x_data[e * x.stride(0) + m_in * x.stride(1)];
            grad_weight_data[wi * M_in * M_out + m_in * M_out + m_out] += v;
          }
@@ -105,27 +142,41 @@ at::Tensor weighting_bw_w(at::Tensor grad_out, at::Tensor x, at::Tensor basis,
  return grad_weight;
}

torch::Tensor spline_weighting_bw_basis_cpu(torch::Tensor grad_out,
                                            torch::Tensor x,
                                            torch::Tensor weight,
                                            torch::Tensor weight_index) {
  CHECK_CPU(grad_out);
  CHECK_CPU(x);
  CHECK_CPU(weight);
  CHECK_CPU(weight_index);
  CHECK_INPUT(x.size(1) == weight.size(1));
  CHECK_INPUT(grad_out.size(1) == weight.size(2));

  auto E = grad_out.size(0);
  auto M_in = x.size(1);
  auto M_out = grad_out.size(1);
  auto S = weight_index.size(1);

  auto grad_basis = at::zeros({E, S}, grad_out.options());

  auto weight_index_data = weight_index.data_ptr<int64_t>();

  AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "weighting_bw_basis", [&] {
    auto grad_out_data = grad_out.data_ptr<scalar_t>();
    auto x_data = x.data_ptr<scalar_t>();
    auto weight_data = weight.data_ptr<scalar_t>();
    auto grad_basis_data = grad_basis.data_ptr<scalar_t>();

    for (int64_t e = 0; e < E; e++) {
      for (int64_t m_out = 0; m_out < M_out; m_out++) {
        auto g =
            grad_out_data[e * grad_out.stride(0) + m_out * grad_out.stride(1)];
        for (int64_t s = 0; s < S; s++) {
          scalar_t b = 0;
          auto wi = weight_index_data[e * S + s];
          for (int64_t m_in = 0; m_in < M_in; m_in++) {
            auto w =
                weight_data[wi * weight.stride(0) + m_in * weight.stride(1) +
                            m_out * weight.stride(2)];
@@ -140,10 +191,3 @@ at::Tensor weighting_bw_b(at::Tensor grad_out, at::Tensor x, at::Tensor weight,
  return grad_basis;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("weighting_fw", &weighting_fw, "Weighting Forward (CPU)");
m.def("weighting_bw_x", &weighting_bw_x, "Weighting Backward X (CPU)");
m.def("weighting_bw_w", &weighting_bw_w, "Weighting Backward Weight (CPU)");
m.def("weighting_bw_b", &weighting_bw_b, "Weighting Backward Basis (CPU)");
}
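In tensor notation, the forward weighting computes `out[e, o] = Σ_s basis[e, s] · Σ_i x[e, i] · weight[weight_index[e, s], i, o]`. The same contraction can be written with a gather plus `einsum`, which is a handy reference when checking the kernels (a sketch using plain PyTorch, not part of the extension):

```python
import torch

def spline_weighting_reference(x, weight, basis, weight_index):
    # x: [E, M_in], weight: [K, M_in, M_out], basis / weight_index: [E, S]
    w = weight[weight_index]                     # [E, S, M_in, M_out], gathered per (edge, s)
    return torch.einsum('ei,es,esio->eo', x, basis, w)

E, M_in, M_out, K, S = 6, 2, 4, 25, 4
x = torch.rand(E, M_in)
weight = torch.rand(K, M_in, M_out)
basis = torch.rand(E, S)
weight_index = torch.randint(K, (E, S))
print(spline_weighting_reference(x, weight, basis, weight_index).shape)  # torch.Size([6, 4])
```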
#pragma once
#include <torch/extension.h>
torch::Tensor spline_weighting_fw_cpu(torch::Tensor x, torch::Tensor weight,
torch::Tensor basis,
torch::Tensor weight_index);
torch::Tensor spline_weighting_bw_x_cpu(torch::Tensor grad_out,
torch::Tensor weight,
torch::Tensor basis,
torch::Tensor weight_index);
torch::Tensor spline_weighting_bw_weight_cpu(torch::Tensor grad_out,
torch::Tensor x,
torch::Tensor basis,
torch::Tensor weight_index,
int64_t kernel_size);
torch::Tensor spline_weighting_bw_basis_cpu(torch::Tensor grad_out,
torch::Tensor x,
torch::Tensor weight,
torch::Tensor weight_index);
#pragma once
static inline __device__ void atomAdd(float *address, float val) {
atomicAdd(address, val);
}
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600 || CUDA_VERSION < 8000)
static inline __device__ void atomAdd(double *address, double val) {
unsigned long long int *address_as_ull = (unsigned long long int *)address;
unsigned long long int old = *address_as_ull;
unsigned long long int assumed;
do {
assumed = old;
old = atomicCAS(address_as_ull, assumed,
__double_as_longlong(val + __longlong_as_double(assumed)));
} while (assumed != old);
}
#else
static inline __device__ void atomAdd(double *address, double val) {
atomicAdd(address, val);
}
#endif
#include "basis_cuda.h"
#include <ATen/cuda/CUDAContext.h>
#include "utils.cuh"
#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS
template <typename scalar_t, int64_t degree> struct Basis {
static inline __device__ scalar_t forward(scalar_t v, int64_t k_mod) {
if (degree == 1) {
return 1. - v - k_mod + 2. * v * k_mod;
} else if (degree == 2) {
if (k_mod == 0)
return 0.5 * v * v - v + 0.5;
else if (k_mod == 1)
return -v * v + v + 0.5;
else
return 0.5 * v * v;
} else if (degree == 3) {
if (k_mod == 0)
return (1. - v) * (1. - v) * (1. - v) / 6.;
else if (k_mod == 1)
return (3. * v * v * v - 6. * v * v + 4.) / 6.;
else if (k_mod == 2)
return (-3. * v * v * v + 3. * v * v + 3. * v + 1.) / 6.;
else
return v * v * v / 6.;
} else {
return (scalar_t)-1.;
}
}
static inline __device__ scalar_t backward(scalar_t v, int64_t k_mod) {
if (degree == 1) {
return 2 * k_mod - 1;
} else if (degree == 2) {
if (k_mod == 0)
return v - 1.;
else if (k_mod == 1)
return -2. * v + 1.;
else
return v;
} else if (degree == 3) {
if (k_mod == 0)
return (-v * v + 2. * v - 1.) / 2.;
else if (k_mod == 1)
return (3. * v * v - 4. * v) / 2.;
else if (k_mod == 2)
return (-3. * v * v + 2. * v + 1.) / 2.;
else
return v * v / 2.;
} else {
return (scalar_t)-1.;
}
}
};
template <typename scalar_t, int64_t degree>
__global__ void
spline_basis_fw_kernel(const scalar_t *pseudo, const int64_t *kernel_size,
const uint8_t *is_open_spline, scalar_t *basis,
int64_t *weight_index, int64_t E, int64_t D, int64_t S,
int64_t numel) {
const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
const int64_t e = thread_idx / S;
const int64_t s = thread_idx % S;
if (thread_idx < numel) {
int64_t k = s, wi = 0, wi_offset = 1;
scalar_t b = (scalar_t)1.;
for (int64_t d = 0; d < D; d++) {
const int64_t k_mod = k % (degree + 1);
k /= degree + 1;
scalar_t v = pseudo[e * D + d];
v *= kernel_size[d] - degree * is_open_spline[d];
wi += (((int64_t)v + k_mod) % kernel_size[d]) * wi_offset;
wi_offset *= kernel_size[d];
v -= floor(v);
v = Basis<scalar_t, degree>::forward(v, k_mod);
b *= v;
}
basis[thread_idx] = b;
weight_index[thread_idx] = wi;
}
}
std::tuple<torch::Tensor, torch::Tensor>
spline_basis_fw_cuda(torch::Tensor pseudo, torch::Tensor kernel_size,
torch::Tensor is_open_spline, int64_t degree) {
CHECK_CUDA(pseudo);
CHECK_CUDA(kernel_size);
CHECK_CUDA(is_open_spline);
cudaSetDevice(pseudo.get_device());
CHECK_INPUT(kernel_size.dim() == 1);
CHECK_INPUT(pseudo.size(1) == kernel_size.numel());
CHECK_INPUT(is_open_spline.dim());
CHECK_INPUT(pseudo.size(1) == is_open_spline.numel());
auto E = pseudo.size(0);
auto D = pseudo.size(1);
auto S = (int64_t)(powf(degree + 1, D) + 0.5);
auto basis = at::empty({E, S}, pseudo.options());
auto weight_index = at::empty({E, S}, kernel_size.options());
auto kernel_size_data = kernel_size.data_ptr<int64_t>();
auto is_open_spline_data = is_open_spline.data_ptr<uint8_t>();
auto weight_index_data = weight_index.data_ptr<int64_t>();
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES(pseudo.scalar_type(), "basis_fw", [&] {
auto pseudo_data = pseudo.data_ptr<scalar_t>();
auto basis_data = basis.data_ptr<scalar_t>();
AT_DISPATCH_DEGREE_TYPES(degree, [&] {
spline_basis_fw_kernel<scalar_t, DEGREE>
<<<BLOCKS(basis.numel()), THREADS, 0, stream>>>(
pseudo_data, kernel_size_data, is_open_spline_data, basis_data,
weight_index_data, E, D, S, basis.numel());
});
});
return std::make_tuple(basis, weight_index);
}
template <typename scalar_t, int64_t degree>
__global__ void
spline_basis_bw_kernel(const scalar_t *grad_basis, const scalar_t *pseudo,
const int64_t *kernel_size,
const uint8_t *is_open_spline, scalar_t *grad_pseudo,
int64_t E, int64_t D, int64_t S, int64_t numel) {
const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
const int64_t e = thread_idx / D;
const int64_t d = thread_idx % D;
if (thread_idx < numel) {
scalar_t g = (scalar_t)0., tmp;
for (ptrdiff_t s = 0; s < S; s++) {
int64_t k_mod = (s / (int64_t)(powf(degree + 1, d) + 0.5)) % (degree + 1);
scalar_t v = pseudo[e * D + d];
v *= kernel_size[d] - degree * is_open_spline[d];
v -= floor(v);
v = Basis<scalar_t, degree>::backward(v, k_mod);
tmp = v;
for (int64_t d_it = 1; d_it < D; d_it++) {
const int64_t d_new = d_it - (d >= d_it);
k_mod = (s / (int64_t)(powf(degree + 1, d_new) + 0.5)) % (degree + 1);
v = pseudo[e * D + d_new];
v *= kernel_size[d_new] - degree * is_open_spline[d_new];
v -= floor(v);
v = Basis<scalar_t, degree>::forward(v, k_mod);
tmp *= v;
}
g += tmp * grad_basis[e * S + s];
}
g *= kernel_size[d] - degree * is_open_spline[d];
grad_pseudo[thread_idx] = g;
}
}
torch::Tensor spline_basis_bw_cuda(torch::Tensor grad_basis,
torch::Tensor pseudo,
torch::Tensor kernel_size,
torch::Tensor is_open_spline,
int64_t degree) {
CHECK_CUDA(grad_basis);
CHECK_CUDA(pseudo);
CHECK_CUDA(kernel_size);
CHECK_CUDA(is_open_spline);
cudaSetDevice(grad_basis.get_device());
CHECK_INPUT(grad_basis.size(0) == pseudo.size(0));
CHECK_INPUT(kernel_size.dim() == 1);
CHECK_INPUT(pseudo.size(1) == kernel_size.numel());
CHECK_INPUT(is_open_spline.dim());
CHECK_INPUT(pseudo.size(1) == is_open_spline.numel());
auto E = pseudo.size(0);
auto D = pseudo.size(1);
auto S = grad_basis.size(1);
auto grad_pseudo = at::empty({E, D}, pseudo.options());
auto kernel_size_data = kernel_size.data_ptr<int64_t>();
auto is_open_spline_data = is_open_spline.data_ptr<uint8_t>();
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES(pseudo.scalar_type(), "basis_bw", [&] {
auto grad_basis_data = grad_basis.data_ptr<scalar_t>();
auto pseudo_data = pseudo.data_ptr<scalar_t>();
auto grad_pseudo_data = grad_pseudo.data_ptr<scalar_t>();
AT_DISPATCH_DEGREE_TYPES(degree, [&] {
spline_basis_bw_kernel<scalar_t, DEGREE>
<<<BLOCKS(grad_pseudo.numel()), THREADS, 0, stream>>>(
grad_basis_data, pseudo_data, kernel_size_data,
is_open_spline_data, grad_pseudo_data, E, D, S,
grad_pseudo.numel());
});
});
return grad_pseudo;
}
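`spline_basis_fw_cuda` parallelises the same per-`(e, s)` computation as the CPU path, one thread per output element, with `BLOCKS(N)` rounding `N / THREADS` up. A quick way to cross-check the two code paths once the extension is built with CUDA support (a sketch; it assumes a CUDA-capable machine and an importable `torch_spline_conv` build):

```python
import torch
import torch_spline_conv  # noqa: F401  -- registers torch.ops.torch_spline_conv.*

pseudo = torch.rand(128, 3)
kernel_size = torch.tensor([5, 5, 5])
is_open_spline = torch.tensor([1, 1, 1], dtype=torch.uint8)

cpu = torch.ops.torch_spline_conv.spline_basis(pseudo, kernel_size, is_open_spline, 2)
gpu = torch.ops.torch_spline_conv.spline_basis(
    pseudo.cuda(), kernel_size.cuda(), is_open_spline.cuda(), 2)

assert torch.allclose(cpu[0], gpu[0].cpu(), atol=1e-5)  # basis values
assert torch.equal(cpu[1], gpu[1].cpu())                # weight indices
```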
#pragma once
#include <torch/extension.h>
std::tuple<torch::Tensor, torch::Tensor>
spline_basis_fw_cuda(torch::Tensor pseudo, torch::Tensor kernel_size,
torch::Tensor is_open_spline, int64_t degree);
torch::Tensor spline_basis_bw_cuda(torch::Tensor grad_basis,
torch::Tensor pseudo,
torch::Tensor kernel_size,
torch::Tensor is_open_spline,
int64_t degree);
#pragma once
#include <torch/extension.h>
#define CHECK_CUDA(x) \
AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")
#define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch")
#define AT_DISPATCH_DEGREE_TYPES(degree, ...) \
[&] { \
switch (degree) { \
case 1: { \
const int64_t DEGREE = 1; \
return __VA_ARGS__(); \
} \
case 2: { \
const int64_t DEGREE = 2; \
return __VA_ARGS__(); \
} \
case 3: { \
const int64_t DEGREE = 3; \
return __VA_ARGS__(); \
} \
default: \
AT_ERROR("Basis degree not implemented"); \
} \
}()
#include "weighting_cuda.h"
#include <ATen/cuda/CUDAContext.h>
#include "atomics.cuh"
#include "utils.cuh"
#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS
template <typename scalar_t>
__global__ void
spline_weighting_fw_kernel(const scalar_t *x, const scalar_t *weight,
const scalar_t *basis, const int64_t *weight_index,
scalar_t *out, int64_t E, int64_t M_in,
int64_t M_out, int64_t S, int64_t numel) {
const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
const int64_t e = thread_idx / M_out;
const int64_t m_out = thread_idx % M_out;
if (thread_idx < numel) {
scalar_t v = (scalar_t)0.;
for (ptrdiff_t s = 0; s < S; s++) {
const scalar_t b = basis[e * S + s];
const int64_t wi = weight_index[e * S + s];
for (int64_t m_in = 0; m_in < M_in; m_in++) {
scalar_t tmp = weight[wi * M_in * M_out + m_in * M_out + m_out];
tmp *= b * x[e * M_in + m_in];
v += tmp;
}
}
out[thread_idx] = v;
}
}
torch::Tensor spline_weighting_fw_cuda(torch::Tensor x, torch::Tensor weight,
torch::Tensor basis,
torch::Tensor weight_index) {
CHECK_CUDA(x);
CHECK_CUDA(weight);
CHECK_CUDA(basis);
CHECK_CUDA(weight_index);
cudaSetDevice(x.get_device());
CHECK_INPUT(x.size(1) == weight.size(1));
auto E = x.size(0);
auto M_in = x.size(1);
auto M_out = weight.size(2);
auto S = basis.size(1);
auto out = at::empty({E, M_out}, x.options());
auto weight_index_data = weight_index.data_ptr<int64_t>();
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "weighting_fw", [&] {
auto x_data = x.data_ptr<scalar_t>();
auto weight_data = weight.data_ptr<scalar_t>();
auto basis_data = basis.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
spline_weighting_fw_kernel<scalar_t>
<<<BLOCKS(out.numel()), THREADS, 0, stream>>>(
x_data, weight_data, basis_data, weight_index_data, out_data, E,
M_in, M_out, S, out.numel());
});
return out;
}
template <typename scalar_t>
__global__ void
spline_weighting_bw_x_kernel(const scalar_t *grad_out, const scalar_t *weight,
const scalar_t *basis, const int64_t *weight_index,
scalar_t *grad_x, int64_t E, int64_t M_in,
int64_t M_out, int64_t S, int64_t numel) {
const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
const int64_t e = thread_idx / M_in;
const int64_t m_in = thread_idx % M_in;
if (thread_idx < numel) {
scalar_t v = (scalar_t)0.;
for (int64_t s = 0; s < S; s++) {
const scalar_t b = basis[e * S + s];
const int64_t wi = weight_index[e * S + s];
for (int64_t m_out = 0; m_out < M_out; m_out++) {
scalar_t tmp = weight[wi * M_out * M_in + m_out * M_in + m_in];
tmp *= b * grad_out[e * M_out + m_out];
v += tmp;
}
}
grad_x[thread_idx] = v;
}
}
torch::Tensor spline_weighting_bw_x_cuda(torch::Tensor grad_out,
torch::Tensor weight,
torch::Tensor basis,
torch::Tensor weight_index) {
CHECK_CUDA(grad_out);
CHECK_CUDA(weight);
CHECK_CUDA(basis);
CHECK_CUDA(weight_index);
cudaSetDevice(grad_out.get_device());
CHECK_INPUT(grad_out.size(1) == weight.size(2));
auto E = grad_out.size(0);
auto M_in = weight.size(1);
auto M_out = grad_out.size(1);
auto S = basis.size(1);
auto grad_x = at::zeros({E, M_in}, grad_out.options());
weight = weight.transpose(1, 2).contiguous(); // Contiguous memory-access.
auto weight_index_data = weight_index.data_ptr<int64_t>();
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES(grad_out.scalar_type(), "weighting_bw_x", [&] {
auto grad_out_data = grad_out.data_ptr<scalar_t>();
auto weight_data = weight.data_ptr<scalar_t>();
auto basis_data = basis.data_ptr<scalar_t>();
auto grad_x_data = grad_x.data_ptr<scalar_t>();
spline_weighting_bw_x_kernel<scalar_t>
<<<BLOCKS(grad_x.numel()), THREADS, 0, stream>>>(
grad_out_data, weight_data, basis_data, weight_index_data,
grad_x_data, E, M_in, M_out, S, grad_x.numel());
});
return grad_x;
}
template <typename scalar_t>
__global__ void spline_weighting_bw_weight_kernel(
const scalar_t *grad_out, const scalar_t *x, const scalar_t *basis,
const int64_t *weight_index, scalar_t *grad_weight, int64_t E, int64_t M_in,
int64_t M_out, int64_t S, int64_t numel) {
const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
const int64_t e = thread_idx / M_out;
const int64_t m_out = thread_idx % M_out;
if (thread_idx < numel) {
auto g = grad_out[e * M_out + m_out];
for (int64_t s = 0; s < S; s++) {
const scalar_t b = basis[e * S + s];
const int64_t wi = weight_index[e * S + s];
for (int64_t m_in = 0; m_in < M_in; m_in++) {
auto v = g * b * x[e * M_in + m_in];
atomAdd(&grad_weight[wi * M_in * M_out + m_in * M_out + m_out], v);
}
}
}
}
torch::Tensor spline_weighting_bw_weight_cuda(torch::Tensor grad_out,
torch::Tensor x,
torch::Tensor basis,
torch::Tensor weight_index,
int64_t kernel_size) {
CHECK_CUDA(grad_out);
CHECK_CUDA(x);
CHECK_CUDA(basis);
CHECK_CUDA(weight_index);
cudaSetDevice(grad_out.get_device());
auto E = grad_out.size(0);
auto M_in = x.size(1);
auto M_out = grad_out.size(1);
auto S = basis.size(1);
auto grad_weight = at::zeros({kernel_size, M_in, M_out}, grad_out.options());
auto weight_index_data = weight_index.data_ptr<int64_t>();
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "weighting_bw_weight", [&] {
auto grad_out_data = grad_out.data_ptr<scalar_t>();
auto x_data = x.data_ptr<scalar_t>();
auto basis_data = basis.data_ptr<scalar_t>();
auto grad_weight_data = grad_weight.data_ptr<scalar_t>();
spline_weighting_bw_weight_kernel<scalar_t>
<<<BLOCKS(grad_out.numel()), THREADS, 0, stream>>>(
grad_out_data, x_data, basis_data, weight_index_data,
grad_weight_data, E, M_in, M_out, S, grad_out.numel());
});
return grad_weight;
}
template <typename scalar_t>
__global__ void spline_weighting_bw_basis_kernel(
const scalar_t *grad_out, const scalar_t *x, const scalar_t *weight,
const int64_t *weight_index, scalar_t *grad_basis, int64_t E, int64_t M_in,
int64_t M_out, int64_t S, int64_t numel) {
const size_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
const int64_t e = thread_idx / M_out;
const int64_t m_out = thread_idx % M_out;
if (thread_idx < numel) {
const scalar_t g = grad_out[e * M_out + m_out];
for (int64_t s = 0; s < S; s++) {
scalar_t v = (scalar_t)0.;
const int64_t wi = weight_index[e * S + s];
for (int64_t m_in = 0; m_in < M_in; m_in++) {
const scalar_t w = weight[wi * M_in * M_out + m_in * M_out + m_out];
v += g * w * x[e * M_in + m_in];
}
atomAdd(&grad_basis[e * S + s], v);
}
}
}
torch::Tensor spline_weighting_bw_basis_cuda(torch::Tensor grad_out,
torch::Tensor x,
torch::Tensor weight,
torch::Tensor weight_index) {
CHECK_CUDA(grad_out);
CHECK_CUDA(x);
CHECK_CUDA(weight);
CHECK_CUDA(weight_index);
cudaSetDevice(grad_out.get_device());
CHECK_INPUT(x.size(1) == weight.size(1));
CHECK_INPUT(grad_out.size(1) == weight.size(2));
auto E = grad_out.size(0);
auto M_in = x.size(1);
auto M_out = grad_out.size(1);
auto S = weight_index.size(1);
auto grad_basis = at::zeros({E, S}, grad_out.options());
auto weight_index_data = weight_index.data_ptr<int64_t>();
auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "weighting_bw_basis", [&] {
auto grad_out_data = grad_out.data_ptr<scalar_t>();
auto x_data = x.data_ptr<scalar_t>();
auto weight_data = weight.data_ptr<scalar_t>();
auto grad_basis_data = grad_basis.data_ptr<scalar_t>();
spline_weighting_bw_basis_kernel<scalar_t>
<<<BLOCKS(grad_out.numel()), THREADS, 0, stream>>>(
grad_out_data, x_data, weight_data, weight_index_data,
grad_basis_data, E, M_in, M_out, S, grad_out.numel());
});
return grad_basis;
}
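The weight-gradient kernel accumulates into `grad_weight[wi, m_in, m_out]` with `atomAdd`, since many `(edge, s)` pairs can map to the same kernel index `wi`; in PyTorch terms this is a scatter-add. A reference formulation of that accumulation (illustrative sketch only, not part of the extension):

```python
import torch

def spline_weighting_bw_weight_reference(grad_out, x, basis, weight_index, kernel_size):
    # grad_out: [E, M_out], x: [E, M_in], basis / weight_index: [E, S]
    E, S = basis.shape
    contrib = torch.einsum('es,ei,eo->esio', basis, x, grad_out)  # per-(edge, s) outer products
    grad_weight = torch.zeros(kernel_size, x.size(1), grad_out.size(1))
    # index_add_ sums duplicate indices, mirroring the atomic adds in the kernel.
    grad_weight.index_add_(0, weight_index.view(-1),
                           contrib.reshape(E * S, x.size(1), grad_out.size(1)))
    return grad_weight

E, M_in, M_out, K, S = 6, 2, 4, 25, 4
gw = spline_weighting_bw_weight_reference(torch.rand(E, M_out), torch.rand(E, M_in),
                                          torch.rand(E, S), torch.randint(K, (E, S)), K)
print(gw.shape)  # torch.Size([25, 2, 4])
```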
#pragma once
#include <torch/extension.h>
torch::Tensor spline_weighting_fw_cuda(torch::Tensor x, torch::Tensor weight,
torch::Tensor basis,
torch::Tensor weight_index);
torch::Tensor spline_weighting_bw_x_cuda(torch::Tensor grad_out,
torch::Tensor weight,
torch::Tensor basis,
torch::Tensor weight_index);
torch::Tensor spline_weighting_bw_weight_cuda(torch::Tensor grad_out,
torch::Tensor x,
torch::Tensor basis,
torch::Tensor weight_index,
int64_t kernel_size);
torch::Tensor spline_weighting_bw_basis_cuda(torch::Tensor grad_out,
torch::Tensor x,
torch::Tensor weight,
torch::Tensor weight_index);