"vscode:/vscode.git/clone" did not exist on "a59da36fe7ade3a2ced3555074ee9cee07f5d3f3"
Commit 7a86525b authored by rusty1s

new cpu extension api

parent 736354b1
@@ -47,7 +47,7 @@ If you are running into any installation problems, please create an [issue](http
 ```python
 from torch_spline_conv import SplineConv

-out = SplineConv.apply(src,
+out = SplineConv.apply(x,
                        edge_index,
                        pseudo,
                        weight,
@@ -73,7 +73,7 @@ The kernel function is defined over the weighted B-spline tensor product basis,
 ### Parameters

-* **src** *(Tensor)* - Input node features of shape `(number_of_nodes x in_channels)`.
+* **x** *(Tensor)* - Input node features of shape `(number_of_nodes x in_channels)`.
 * **edge_index** *(LongTensor)* - Graph edges, given by source and target indices, of shape `(2 x number_of_edges)`.
 * **pseudo** *(Tensor)* - Edge attributes, i.e. pseudo coordinates, of shape `(number_of_edges x number_of_edge_attributes)` in the fixed interval [0, 1].
 * **weight** *(Tensor)* - Trainable weight parameters of shape `(kernel_size x in_channels x out_channels)`.
@@ -94,7 +94,7 @@ The kernel function is defined over the weighted B-spline tensor product basis,
 import torch
 from torch_spline_conv import SplineConv

-src = torch.rand((4, 2), dtype=torch.float)  # 4 nodes with 2 features each
+x = torch.rand((4, 2), dtype=torch.float)  # 4 nodes with 2 features each
 edge_index = torch.tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]])  # 6 edges
 pseudo = torch.rand((6, 2), dtype=torch.float)  # two-dimensional edge attributes
 weight = torch.rand((25, 2, 4), dtype=torch.float)  # 25 parameters for in_channels x out_channels
@@ -105,7 +105,7 @@ norm = True  # Normalize output by node degree.
 root_weight = torch.rand((2, 4), dtype=torch.float)  # separately weight root nodes
 bias = None  # do not apply an additional bias

-out = SplineConv.apply(src, edge_index, pseudo, weight, kernel_size,
+out = SplineConv.apply(x, edge_index, pseudo, weight, kernel_size,
                        is_open_spline, degree, norm, root_weight, bias)
 print(out.size())
...
| Library | Meaning |
|---------|------------------------------------|
| TH | **T**orc**H** |
| THC | **T**orc**H** **C**uda |
| THCC | **T**orc**H** **C**uda **C**onnect |
#include <TH/TH.h>
#define TH_TENSOR_BASIS_FORWARD(M, basis, weightIndex, pseudo, kernelSize, isOpenSpline, CODE) { \
  real *basisData = THTensor_(data)(basis); \
  int64_t *weightIndexData = THLongTensor_data(weightIndex); \
  real *pseudoData = THTensor_(data)(pseudo); \
  int64_t *kernelSizeData = THLongTensor_data(kernelSize); \
  uint8_t *isOpenSplineData = THByteTensor_data(isOpenSpline); \
\
  ptrdiff_t e, s, d; \
  int64_t k, kMod, wi, wiOffset; \
  real b, v; \
  for (e = 0; e < THTensor_(size)(pseudo, 0); e++) { \
    for (s = 0; s < THTensor_(size)(basis, 1); s++) { \
      k = s; b = 1; wi = 0; wiOffset = 1; \
      for (d = 0; d < THTensor_(size)(pseudo, 1); d++) { \
        kMod = k % (M + 1); \
        k /= M + 1; \
\
        v = pseudoData[e * pseudo->stride[0] + d * pseudo->stride[1]]; \
        v *= kernelSizeData[d] - M * isOpenSplineData[d]; \
\
        wi += (((int64_t) v + kMod) % kernelSizeData[d]) * wiOffset; \
        wiOffset *= kernelSizeData[d]; \
\
        v -= floor(v); \
        v = CODE; \
        b *= v; \
      } \
      basisData[e * basis->stride[0] + s * basis->stride[1]] = b; \
      weightIndexData[e * weightIndex->stride[0] + s * weightIndex->stride[1]] = wi; \
    } \
  } \
}
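For every edge, the macro enumerates the `S = (M+1)^D` B-spline tensor products that are non-zero at that edge's pseudo coordinate, decodes the product index `s` into one per-dimension offset `kMod`, and accumulates both the flattened kernel index `wi` and the product `b` of one-dimensional basis values. A minimal NumPy sketch of the same loop for the linear case (the standalone helper is illustrative, not part of this commit):

```python
# Hypothetical NumPy re-statement of TH_TENSOR_BASIS_FORWARD for m = 1.
import numpy as np

def linear_basis_forward(pseudo, kernel_size, is_open_spline, m=1):
    E, D = pseudo.shape
    S = (m + 1) ** D  # number of non-zero B-spline products per edge
    basis = np.empty((E, S))
    weight_index = np.empty((E, S), dtype=np.int64)
    for e in range(E):
        for s in range(S):
            k, b, wi, wi_offset = s, 1.0, 0, 1
            for d in range(D):
                k_mod = k % (m + 1)  # offset of this product in dimension d
                k //= m + 1
                v = pseudo[e, d] * (kernel_size[d] - m * is_open_spline[d])
                wi += ((int(v) + k_mod) % kernel_size[d]) * wi_offset
                wi_offset *= kernel_size[d]
                v -= np.floor(v)  # fractional position inside the knot cell
                b *= 1 - v - k_mod + 2 * v * k_mod  # linear basis: 1-v or v
            basis[e, s] = b
            weight_index[e, s] = wi
    return basis, weight_index
```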
#define TH_TENSOR_BASIS_BACKWARD(M, self, gradBasis, pseudo, kernelSize, isOpenSpline, CODE, \
                                 GRAD_CODE) { \
  real *selfData = THTensor_(data)(self); \
  real *gradBasisData = THTensor_(data)(gradBasis); \
  real *pseudoData = THTensor_(data)(pseudo); \
  int64_t *kernelSizeData = THLongTensor_data(kernelSize); \
  uint8_t *isOpenSplineData = THByteTensor_data(isOpenSpline); \
\
  ptrdiff_t e, d, s, dIt, dOther; \
  int64_t kMod; \
  real g, v, tmp; \
\
  for (e = 0; e < THTensor_(size)(pseudo, 0); e++) { \
    for (d = 0; d < THTensor_(size)(pseudo, 1); d++) { \
      g = 0; \
      for (s = 0; s < THTensor_(size)(gradBasis, 1); s++) { \
        kMod = (s / (ptrdiff_t) pow(M + 1, d)) % (M + 1); \
        v = pseudoData[e * pseudo->stride[0] + d * pseudo->stride[1]]; \
        v *= kernelSizeData[d] - M * isOpenSplineData[d]; \
        v -= floor(v); \
        v = GRAD_CODE; \
        tmp = v; \
\
        for (dIt = 1; dIt < THTensor_(size)(pseudo, 1); dIt++) { \
          dOther = dIt - (d >= dIt); \
          kMod = (s / (ptrdiff_t) pow(M + 1, dOther)) % (M + 1); \
          v = pseudoData[e * pseudo->stride[0] + dOther * pseudo->stride[1]]; \
          v *= kernelSizeData[dOther] - M * isOpenSplineData[dOther]; \
          v -= floor(v); \
          v = CODE; \
          tmp *= v; \
        } \
\
        g += tmp * gradBasisData[e * gradBasis->stride[0] + s * gradBasis->stride[1]]; \
      } \
      g *= kernelSizeData[d] - M * isOpenSplineData[d]; \
      selfData[e * self->stride[0] + d * self->stride[1]] = g; \
    } \
  } \
}
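The backward macro is the product rule applied per dimension: the one-dimensional basis in dimension `d` is differentiated (`GRAD_CODE`) while the factors of all other dimensions are re-evaluated (`CODE`). In slightly condensed notation, with `k_d = kernelSize[d]` and `o_d = isOpenSpline[d]`:

```latex
\frac{\partial \mathcal{L}}{\partial p_{e,d}}
  = (k_d - M o_d) \sum_{s=1}^{S}
    \frac{\partial \mathcal{L}}{\partial B_{e,s}} \,
    B'_{s,d}(v_{e,d}) \prod_{d' \ne d} B_{s,d'}(v_{e,d'}),
\qquad v_{e,d} = \operatorname{frac}\bigl((k_d - M o_d)\, p_{e,d}\bigr)
```

The outer scale factor `(k_d - M o_d)` comes from the chain rule through the rescaling of the pseudo coordinate and is applied once, after the sum over `s`, exactly as in the last lines of the loop above.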
#include "generic/THBasis.c"
#include "THGenerateFloatTypes.h"
void THFloatTensor_linearBasisForward( THFloatTensor *basis, THLongTensor *weightIndex, THFloatTensor *pseudo, THLongTensor *kernelSize, THByteTensor *isOpenSpline);
void THDoubleTensor_linearBasisForward(THDoubleTensor *basis, THLongTensor *weightIndex, THDoubleTensor *pseudo, THLongTensor *kernelSize, THByteTensor *isOpenSpline);
void THFloatTensor_quadraticBasisForward( THFloatTensor *basis, THLongTensor *weightIndex, THFloatTensor *pseudo, THLongTensor *kernelSize, THByteTensor *isOpenSpline);
void THDoubleTensor_quadraticBasisForward(THDoubleTensor *basis, THLongTensor *weightIndex, THDoubleTensor *pseudo, THLongTensor *kernelSize, THByteTensor *isOpenSpline);
void THFloatTensor_cubicBasisForward( THFloatTensor *basis, THLongTensor *weightIndex, THFloatTensor *pseudo, THLongTensor *kernelSize, THByteTensor *isOpenSpline);
void THDoubleTensor_cubicBasisForward(THDoubleTensor *basis, THLongTensor *weightIndex, THDoubleTensor *pseudo, THLongTensor *kernelSize, THByteTensor *isOpenSpline);
void THFloatTensor_linearBasisBackward( THFloatTensor *self, THFloatTensor *gradBasis, THFloatTensor *pseudo, THLongTensor *kernelSize, THByteTensor *isOpenSpline);
void THDoubleTensor_linearBasisBackward(THDoubleTensor *self, THDoubleTensor *gradBasis, THDoubleTensor *pseudo, THLongTensor *kernelSize, THByteTensor *isOpenSpline);
void THFloatTensor_quadraticBasisBackward( THFloatTensor *self, THFloatTensor *gradBasis, THFloatTensor *pseudo, THLongTensor *kernelSize, THByteTensor *isOpenSpline);
void THDoubleTensor_quadraticBasisBackward(THDoubleTensor *self, THDoubleTensor *gradBasis, THDoubleTensor *pseudo, THLongTensor *kernelSize, THByteTensor *isOpenSpline);
void THFloatTensor_cubicBasisBackward( THFloatTensor *self, THFloatTensor *gradBasis, THFloatTensor *pseudo, THLongTensor *kernelSize, THByteTensor *isOpenSpline);
void THDoubleTensor_cubicBasisBackward(THDoubleTensor *self, THDoubleTensor *gradBasis, THDoubleTensor *pseudo, THLongTensor *kernelSize, THByteTensor *isOpenSpline);
#include <TH/TH.h>
#include "generic/THWeighting.c"
#include "THGenerateFloatTypes.h"
void THFloatTensor_weightingForward( THFloatTensor *self, THFloatTensor *src, THFloatTensor *weight, THFloatTensor *basis, THLongTensor *weightIndex);
void THDoubleTensor_weightingForward(THDoubleTensor *self, THDoubleTensor *src, THDoubleTensor *weight, THDoubleTensor *basis, THLongTensor *weightIndex);
void THFloatTensor_weightingBackwardSrc( THFloatTensor *self, THFloatTensor *gradOutput, THFloatTensor *weight, THFloatTensor *basis, THLongTensor *weightIndex);
void THDoubleTensor_weightingBackwardSrc(THDoubleTensor *self, THDoubleTensor *gradOutput, THDoubleTensor *weight, THDoubleTensor *basis, THLongTensor *weightIndex);
void THFloatTensor_weightingBackwardWeight( THFloatTensor *self, THFloatTensor *gradOutput, THFloatTensor *src, THFloatTensor *basis, THLongTensor *weightIndex);
void THDoubleTensor_weightingBackwardWeight(THDoubleTensor *self, THDoubleTensor *gradOutput, THDoubleTensor *src, THDoubleTensor *basis, THLongTensor *weightIndex);
void THFloatTensor_weightingBackwardBasis( THFloatTensor *self, THFloatTensor *gradOutput, THFloatTensor *src, THFloatTensor *weight, THLongTensor *weightIndex);
void THDoubleTensor_weightingBackwardBasis(THDoubleTensor *self, THDoubleTensor *gradOutput, THDoubleTensor *src, THDoubleTensor *weight, THLongTensor *weightIndex);
#ifndef TH_GENERIC_FILE
#define TH_GENERIC_FILE "generic/THBasis.c"
#else
inline real THTensor_(linear)(real v, int64_t kMod) {
  return 1 - v - kMod + 2 * v * kMod;
}

inline real THTensor_(gradLinear)(real v, int64_t kMod) {
  return 2 * kMod - 1;
}

inline real THTensor_(quadratic)(real v, int64_t kMod) {
  if (kMod == 0) return 0.5 * v * v - v + 0.5;
  else if (kMod == 1) return -v * v + v + 0.5;
  else return 0.5 * v * v;
}

inline real THTensor_(gradQuadratic)(real v, int64_t kMod) {
  if (kMod == 0) return v - 1;
  else if (kMod == 1) return -2 * v + 1;
  else return v;
}

inline real THTensor_(cubic)(real v, int64_t kMod) {
  if (kMod == 0) { v = (1 - v); return v * v * v / 6.0; }
  else if (kMod == 1) return (3 * v * v * v - 6 * v * v + 4) / 6;
  else if (kMod == 2) return (-3 * v * v * v + 3 * v * v + 3 * v + 1) / 6;
  else return v * v * v / 6;
}

inline real THTensor_(gradCubic)(real v, int64_t kMod) {
  if (kMod == 0) return (-v * v + 2 * v - 1) / 2;
  else if (kMod == 1) return (3 * v * v - 4 * v) / 2;
  else if (kMod == 2) return (-3 * v * v + 2 * v + 1) / 2;
  else return v * v / 2;
}
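These inline helpers are the polynomial segments of the uniform B-spline of the respective degree, evaluated at the fractional position v ∈ [0, 1) inside the current knot interval. The quadratic branches above, for instance, are exactly

```latex
N_0(v) = \tfrac{1}{2}(1 - v)^2, \qquad
N_1(v) = -v^2 + v + \tfrac{1}{2}, \qquad
N_2(v) = \tfrac{1}{2} v^2
```

which sum to 1 for every v, as a partition of unity should.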
void THTensor_(linearBasisForward)(THTensor *basis, THLongTensor *weightIndex, THTensor *pseudo,
                                   THLongTensor *kernelSize, THByteTensor *isOpenSpline) {
  TH_TENSOR_BASIS_FORWARD(1, basis, weightIndex, pseudo, kernelSize, isOpenSpline,
                          THTensor_(linear)(v, kMod))
}

void THTensor_(quadraticBasisForward)(THTensor *basis, THLongTensor *weightIndex, THTensor *pseudo,
                                      THLongTensor *kernelSize, THByteTensor *isOpenSpline) {
  TH_TENSOR_BASIS_FORWARD(2, basis, weightIndex, pseudo, kernelSize, isOpenSpline,
                          THTensor_(quadratic)(v, kMod))
}

void THTensor_(cubicBasisForward)(THTensor *basis, THLongTensor *weightIndex, THTensor *pseudo,
                                  THLongTensor *kernelSize, THByteTensor *isOpenSpline) {
  TH_TENSOR_BASIS_FORWARD(3, basis, weightIndex, pseudo, kernelSize, isOpenSpline,
                          THTensor_(cubic)(v, kMod))
}

void THTensor_(linearBasisBackward)(THTensor *self, THTensor *gradBasis, THTensor *pseudo,
                                    THLongTensor *kernelSize, THByteTensor *isOpenSpline) {
  TH_TENSOR_BASIS_BACKWARD(1, self, gradBasis, pseudo, kernelSize, isOpenSpline,
                           THTensor_(linear)(v, kMod), THTensor_(gradLinear)(v, kMod))
}

void THTensor_(quadraticBasisBackward)(THTensor *self, THTensor *gradBasis, THTensor *pseudo,
                                       THLongTensor *kernelSize, THByteTensor *isOpenSpline) {
  TH_TENSOR_BASIS_BACKWARD(2, self, gradBasis, pseudo, kernelSize, isOpenSpline,
                           THTensor_(quadratic)(v, kMod), THTensor_(gradQuadratic)(v, kMod))
}

void THTensor_(cubicBasisBackward)(THTensor *self, THTensor *gradBasis, THTensor *pseudo,
                                   THLongTensor *kernelSize, THByteTensor *isOpenSpline) {
  TH_TENSOR_BASIS_BACKWARD(3, self, gradBasis, pseudo, kernelSize, isOpenSpline,
                           THTensor_(cubic)(v, kMod), THTensor_(gradCubic)(v, kMod))
}
#endif // TH_GENERIC_FILE
#ifndef TH_GENERIC_FILE
#define TH_GENERIC_FILE "generic/THWeighting.c"
#else
void THTensor_(weightingForward)(THTensor *self, THTensor *src, THTensor *weight, THTensor *basis,
                                 THLongTensor *weightIndex) {
  real *selfData = THTensor_(data)(self);
  real *srcData = THTensor_(data)(src);
  real *weightData = THTensor_(data)(weight);
  real *basisData = THTensor_(data)(basis);
  int64_t *weightIndexData = THLongTensor_data(weightIndex);

  ptrdiff_t e, mOut, s, mIn;
  real v, b, tmp;
  int64_t wi;

  for (e = 0; e < THTensor_(size)(src, 0); e++) {
    for (mOut = 0; mOut < THTensor_(size)(self, 1); mOut++) {
      v = 0;
      for (s = 0; s < THTensor_(size)(basis, 1); s++) {
        b = basisData[e * basis->stride[0] + s * basis->stride[1]];
        wi = weightIndexData[e * weightIndex->stride[0] + s * weightIndex->stride[1]];
        for (mIn = 0; mIn < THTensor_(size)(src, 1); mIn++) {
          tmp = weightData[wi * weight->stride[0] + mIn * weight->stride[1] + mOut * weight->stride[2]];
          tmp *= b * srcData[e * src->stride[0] + mIn * src->stride[1]];
          v += tmp;
        }
      }
      selfData[e * self->stride[0] + mOut * self->stride[1]] = v;
    }
  }
}

void THTensor_(weightingBackwardSrc)(THTensor *self, THTensor *gradOutput, THTensor *weight,
                                     THTensor *basis, THLongTensor *weightIndex) {
  THTensor_(fill)(self, 0);

  real *selfData = THTensor_(data)(self);
  real *gradOutputData = THTensor_(data)(gradOutput);
  real *weightData = THTensor_(data)(weight);
  real *basisData = THTensor_(data)(basis);
  int64_t *weightIndexData = THLongTensor_data(weightIndex);

  ptrdiff_t e, mOut, s, mIn;
  real g, b, v;
  int64_t wi;

  for (e = 0; e < THTensor_(size)(self, 0); e++) {
    for (mOut = 0; mOut < THTensor_(size)(gradOutput, 1); mOut++) {
      g = gradOutputData[e * gradOutput->stride[0] + mOut * gradOutput->stride[1]];
      for (s = 0; s < THTensor_(size)(basis, 1); s++) {
        b = basisData[e * basis->stride[0] + s * basis->stride[1]];
        wi = weightIndexData[e * weightIndex->stride[0] + s * weightIndex->stride[1]];
        for (mIn = 0; mIn < THTensor_(size)(self, 1); mIn++) {
          v = weightData[wi * weight->stride[0] + mIn * weight->stride[1] + mOut * weight->stride[2]];
          selfData[e * self->stride[0] + mIn * self->stride[1]] += g * b * v;
        }
      }
    }
  }
}

void THTensor_(weightingBackwardWeight)(THTensor *self, THTensor *gradOutput, THTensor *src,
                                        THTensor *basis, THLongTensor *weightIndex) {
  THTensor_(fill)(self, 0);

  real *selfData = THTensor_(data)(self);
  real *gradOutputData = THTensor_(data)(gradOutput);
  real *srcData = THTensor_(data)(src);
  real *basisData = THTensor_(data)(basis);
  int64_t *weightIndexData = THLongTensor_data(weightIndex);

  ptrdiff_t e, mOut, s, mIn;
  real g, b, v;
  int64_t wi;

  for (e = 0; e < THTensor_(size)(src, 0); e++) {
    for (mOut = 0; mOut < THTensor_(size)(gradOutput, 1); mOut++) {
      g = gradOutputData[e * gradOutput->stride[0] + mOut * gradOutput->stride[1]];
      for (s = 0; s < THTensor_(size)(basis, 1); s++) {
        b = basisData[e * basis->stride[0] + s * basis->stride[1]];
        wi = weightIndexData[e * weightIndex->stride[0] + s * weightIndex->stride[1]];
        for (mIn = 0; mIn < THTensor_(size)(src, 1); mIn++) {
          v = b * g * srcData[e * src->stride[0] + mIn * src->stride[1]];
          selfData[wi * self->stride[0] + mIn * self->stride[1] + mOut * self->stride[2]] += v;
        }
      }
    }
  }
}

void THTensor_(weightingBackwardBasis)(THTensor *self, THTensor *gradOutput, THTensor *src,
                                       THTensor *weight, THLongTensor *weightIndex) {
  THTensor_(fill)(self, 0);

  real *selfData = THTensor_(data)(self);
  real *gradOutputData = THTensor_(data)(gradOutput);
  real *srcData = THTensor_(data)(src);
  real *weightData = THTensor_(data)(weight);
  int64_t *weightIndexData = THLongTensor_data(weightIndex);

  ptrdiff_t e, mOut, s, mIn;
  real g, v, tmp;
  int64_t wi;

  for (e = 0; e < THTensor_(size)(src, 0); e++) {
    for (mOut = 0; mOut < THTensor_(size)(gradOutput, 1); mOut++) {
      g = gradOutputData[e * gradOutput->stride[0] + mOut * gradOutput->stride[1]];
      for (s = 0; s < THLongTensor_size(weightIndex, 1); s++) {
        v = 0;
        wi = weightIndexData[e * weightIndex->stride[0] + s * weightIndex->stride[1]];
        for (mIn = 0; mIn < THTensor_(size)(src, 1); mIn++) {
          tmp = weightData[wi * weight->stride[0] + mIn * weight->stride[1] + mOut * weight->stride[2]];
          tmp *= srcData[e * src->stride[0] + mIn * src->stride[1]];
          v += tmp;
        }
        selfData[e * self->stride[0] + s * self->stride[1]] += g * v;
      }
    }
  }
}
#endif // TH_GENERIC_FILE
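All four weighting kernels are views of one trilinear form, `out[e, m_out] = Σ_s b[e, s] Σ_{m_in} W[wi[e, s], m_in, m_out] · x[e, m_in]`; the three backward variants are its partial derivatives with respect to the features, the weights, and the basis. A compact NumPy cross-check of the forward pass (hypothetical reference code, not part of the commit):

```python
# Einsum re-statement of weightingForward, useful as a unit-test oracle.
import numpy as np

def weighting_forward_reference(x, weight, basis, weight_index):
    # x: (E, M_in), weight: (K, M_in, M_out), basis/weight_index: (E, S)
    w = weight[weight_index]  # gather per-edge kernels: (E, S, M_in, M_out)
    return np.einsum('es,esio,ei->eo', basis, w, x)
```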
#include <torch/torch.h>
template <typename scalar> inline scalar linear(scalar v, int64_t k_mod) {
  return 1 - v - k_mod + 2 * v * k_mod;
}

template <typename scalar> inline scalar quadratic(scalar v, int64_t k_mod) {
  if (k_mod == 0)
    return 0.5 * v * v - v + 0.5;
  else if (k_mod == 1)
    return -v * v + v + 0.5;
  else
    return 0.5 * v * v;
}

template <typename scalar> inline scalar cubic(scalar v, int64_t k_mod) {
  if (k_mod == 0) {
    return (1 - v) * (1 - v) * (1 - v) / 6.0;
  } else if (k_mod == 1)
    return (3 * v * v * v - 6 * v * v + 4) / 6;
  else if (k_mod == 2)
    return (-3 * v * v * v + 3 * v * v + 3 * v + 1) / 6;
  else
    return v * v * v / 6;
}

template <typename scalar> inline scalar grad_linear(scalar v, int64_t k_mod) {
  return 2 * k_mod - 1;
}

template <typename scalar>
inline scalar grad_quadratic(scalar v, int64_t k_mod) {
  if (k_mod == 0)
    return v - 1;
  else if (k_mod == 1)
    return -2 * v + 1;
  else
    return v;
}

template <typename scalar> inline scalar grad_cubic(scalar v, int64_t k_mod) {
  if (k_mod == 0)
    return (-v * v + 2 * v - 1) / 2;
  else if (k_mod == 1)
    return (3 * v * v - 4 * v) / 2;
  else if (k_mod == 2)
    return (-3 * v * v + 2 * v + 1) / 2;
  else
    return v * v / 2;
}
#define BASIS_FORWARD(M, PSEUDO, KERNEL_SIZE, IS_OPEN_SPLINE, FUNC) \
  [&]() -> std::tuple<at::Tensor, at::Tensor> { \
    auto E = PSEUDO.size(0), D = PSEUDO.size(1); \
    auto S = (int64_t)(pow(M + 1, KERNEL_SIZE.size(0)) + 0.5); \
    auto basis = at::empty({E, S}, PSEUDO.type()); \
    auto weight_index = at::empty({E, S}, KERNEL_SIZE.type()); \
\
    AT_DISPATCH_FLOATING_TYPES(PSEUDO.type(), "basis_forward_##M", [&] { \
      auto pseudo_data = PSEUDO.data<scalar_t>(); \
      auto kernel_size_data = KERNEL_SIZE.data<int64_t>(); \
      auto is_open_spline_data = IS_OPEN_SPLINE.data<uint8_t>(); \
      auto basis_data = basis.data<scalar_t>(); \
      auto weight_index_data = weight_index.data<int64_t>(); \
\
      int64_t k, wi, wi_offset; \
      scalar_t b; \
\
      for (ptrdiff_t e = 0; e < E; e++) { \
        for (ptrdiff_t s = 0; s < S; s++) { \
          k = s; \
          wi = 0; \
          wi_offset = 1; \
          b = 1; \
          for (ptrdiff_t d = 0; d < D; d++) { \
            auto k_mod = k % (M + 1); \
            k /= M + 1; \
\
            auto v = pseudo_data[e * pseudo.stride(0) + d * pseudo.stride(1)]; \
            v *= kernel_size_data[d] - M * is_open_spline_data[d]; \
\
            wi += (((int64_t)v + k_mod) % kernel_size_data[d]) * wi_offset; \
            wi_offset *= kernel_size_data[d]; \
\
            v -= floor(v); \
            v = FUNC<scalar_t>(v, k_mod); \
            b *= v; \
          } \
          basis_data[e * S + s] = b; \
          weight_index_data[e * S + s] = wi; \
        } \
      } \
    }); \
\
    return std::make_tuple(basis, weight_index); \
  }()
#define BASIS_BACKWARD(M, GRAD_BASIS, PSEUDO, KERNEL_SIZE, IS_OPEN_SPLINE, \
                       FUNC, GRAD_FUNC) \
  [&]() -> at::Tensor { \
    auto E = PSEUDO.size(0), D = PSEUDO.size(1); \
    auto S = GRAD_BASIS.size(1); \
    auto grad_pseudo = at::empty({E, D}, PSEUDO.type()); \
\
    AT_DISPATCH_FLOATING_TYPES(PSEUDO.type(), "basis_backward_##M", [&] { \
      auto grad_basis_data = GRAD_BASIS.data<scalar_t>(); \
      auto pseudo_data = PSEUDO.data<scalar_t>(); \
      auto kernel_size_data = KERNEL_SIZE.data<int64_t>(); \
      auto is_open_spline_data = IS_OPEN_SPLINE.data<uint8_t>(); \
      auto grad_pseudo_data = grad_pseudo.data<scalar_t>(); \
\
      scalar_t g, tmp; \
\
      for (ptrdiff_t e = 0; e < E; e++) { \
        for (ptrdiff_t d = 0; d < D; d++) { \
          g = 0; \
          for (ptrdiff_t s = 0; s < S; s++) { \
            auto k_mod = (s / (int64_t)(pow(M + 1, d) + 0.5)) % (M + 1); \
            auto v = pseudo_data[e * pseudo.stride(0) + d * pseudo.stride(1)]; \
            v *= kernel_size_data[d] - M * is_open_spline_data[d]; \
            v -= floor(v); \
            v = GRAD_FUNC<scalar_t>(v, k_mod); \
            tmp = v; \
\
            for (ptrdiff_t d_it = 1; d_it < D; d_it++) { \
              auto d_other = d_it - (d >= d_it); \
              k_mod = (s / (int64_t)(pow(M + 1, d_other) + 0.5)) % (M + 1); \
              v = pseudo_data[e * pseudo.stride(0) + \
                              d_other * pseudo.stride(1)]; \
              v *= kernel_size_data[d_other] - \
                   M * is_open_spline_data[d_other]; \
              v -= floor(v); \
              v = FUNC<scalar_t>(v, k_mod); \
              tmp *= v; \
            } \
\
            g += tmp * grad_basis_data[e * grad_basis.stride(0) + \
                                       s * grad_basis.stride(1)]; \
          } \
          g *= kernel_size_data[d] - M * is_open_spline_data[d]; \
          grad_pseudo_data[e * D + d] = g; \
        } \
      } \
    }); \
\
    return grad_pseudo; \
  }()
std::tuple<at::Tensor, at::Tensor> linear_fw(at::Tensor pseudo,
                                             at::Tensor kernel_size,
                                             at::Tensor is_open_spline) {
  return BASIS_FORWARD(1, pseudo, kernel_size, is_open_spline, linear);
}

std::tuple<at::Tensor, at::Tensor> quadratic_fw(at::Tensor pseudo,
                                                at::Tensor kernel_size,
                                                at::Tensor is_open_spline) {
  return BASIS_FORWARD(2, pseudo, kernel_size, is_open_spline, quadratic);
}

std::tuple<at::Tensor, at::Tensor>
cubic_fw(at::Tensor pseudo, at::Tensor kernel_size, at::Tensor is_open_spline) {
  return BASIS_FORWARD(3, pseudo, kernel_size, is_open_spline, cubic);
}

at::Tensor linear_bw(at::Tensor grad_basis, at::Tensor pseudo,
                     at::Tensor kernel_size, at::Tensor is_open_spline) {
  return BASIS_BACKWARD(1, grad_basis, pseudo, kernel_size, is_open_spline,
                        linear, grad_linear);
}

at::Tensor quadratic_bw(at::Tensor grad_basis, at::Tensor pseudo,
                        at::Tensor kernel_size, at::Tensor is_open_spline) {
  return BASIS_BACKWARD(2, grad_basis, pseudo, kernel_size, is_open_spline,
                        quadratic, grad_quadratic);
}

at::Tensor cubic_bw(at::Tensor grad_basis, at::Tensor pseudo,
                    at::Tensor kernel_size, at::Tensor is_open_spline) {
  return BASIS_BACKWARD(3, grad_basis, pseudo, kernel_size, is_open_spline,
                        cubic, grad_cubic);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("linear_fw", &linear_fw, "Linear Basis Forward (CPU)");
  m.def("quadratic_fw", &quadratic_fw, "Quadratic Basis Forward (CPU)");
  m.def("cubic_fw", &cubic_fw, "Cubic Basis Forward (CPU)");
  m.def("linear_bw", &linear_bw, "Linear Basis Backward (CPU)");
  m.def("quadratic_bw", &quadratic_bw, "Quadratic Basis Backward (CPU)");
  m.def("cubic_bw", &cubic_bw, "Cubic Basis Backward (CPU)");
}
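Once compiled, the extension behaves like any other Python module. A quick smoke test of the bindings might look as follows, assuming the module was built as `basis_cpu` (the name given in `setup.py` below) and degree-1 splines:

```python
import torch
import basis_cpu  # built from cpu/basis.cpp via CppExtension

pseudo = torch.rand(6, 2)                        # 6 edges, 2 pseudo coordinates
kernel_size = torch.tensor([5, 5])
is_open_spline = torch.tensor([1, 1], dtype=torch.uint8)

basis, weight_index = basis_cpu.linear_fw(pseudo, kernel_size, is_open_spline)
print(basis.size(), weight_index.size())         # (6, 4), (6, 4): S = (1+1)^2
```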
#include <torch/torch.h>
at::Tensor weighting_fw(at::Tensor x, at::Tensor weight, at::Tensor basis,
                        at::Tensor weight_index) {
  auto E = x.size(0), M_in = x.size(1), M_out = weight.size(2);
  auto S = basis.size(1);
  auto out = at::empty({E, M_out}, x.type());

  AT_DISPATCH_FLOATING_TYPES(out.type(), "weighting_fw", [&] {
    auto x_data = x.data<scalar_t>();
    auto weight_data = weight.data<scalar_t>();
    auto basis_data = basis.data<scalar_t>();
    auto weight_index_data = weight_index.data<int64_t>();
    auto out_data = out.data<scalar_t>();

    scalar_t v;
    for (ptrdiff_t e = 0; e < E; e++) {
      for (ptrdiff_t m_out = 0; m_out < M_out; m_out++) {
        v = 0;
        for (ptrdiff_t s = 0; s < S; s++) {
          auto b = basis_data[e * S + s];
          auto wi = weight_index_data[e * S + s];
          for (ptrdiff_t m_in = 0; m_in < M_in; m_in++) {
            auto tmp =
                weight_data[wi * weight.stride(0) + m_in * weight.stride(1) +
                            m_out * weight.stride(2)];
            tmp *= b * x_data[e * x.stride(0) + m_in * x.stride(1)];
            v += tmp;
          }
        }
        out_data[e * M_out + m_out] = v;
      }
    }
  });

  return out;
}

at::Tensor weighting_bw_x(at::Tensor grad_out, at::Tensor weight,
                          at::Tensor basis, at::Tensor weight_index) {
  auto E = grad_out.size(0), M_in = weight.size(1), M_out = grad_out.size(1);
  auto S = basis.size(1);
  auto grad_x = at::zeros({E, M_in}, grad_out.type());

  AT_DISPATCH_FLOATING_TYPES(grad_out.type(), "weighting_bw_x", [&] {
    auto grad_out_data = grad_out.data<scalar_t>();
    auto weight_data = weight.data<scalar_t>();
    auto basis_data = basis.data<scalar_t>();
    auto weight_index_data = weight_index.data<int64_t>();
    auto grad_x_data = grad_x.data<scalar_t>();

    for (ptrdiff_t e = 0; e < E; e++) {
      for (ptrdiff_t m_out = 0; m_out < M_out; m_out++) {
        auto g =
            grad_out_data[e * grad_out.stride(0) + m_out * grad_out.stride(1)];
        for (ptrdiff_t s = 0; s < S; s++) {
          auto b = basis_data[e * S + s];
          auto wi = weight_index_data[e * S + s];
          for (ptrdiff_t m_in = 0; m_in < M_in; m_in++) {
            auto w =
                weight_data[wi * weight.stride(0) + m_in * weight.stride(1) +
                            m_out * weight.stride(2)];
            grad_x_data[e * M_in + m_in] += g * b * w;
          }
        }
      }
    }
  });

  return grad_x;
}

at::Tensor weighting_bw_w(at::Tensor grad_out, at::Tensor x, at::Tensor basis,
                          at::Tensor weight_index, int64_t K) {
  auto E = grad_out.size(0), M_in = x.size(1), M_out = grad_out.size(1);
  auto S = basis.size(1);
  auto grad_weight = at::zeros({K, M_in, M_out}, grad_out.type());

  AT_DISPATCH_FLOATING_TYPES(grad_out.type(), "weighting_bw_w", [&] {
    auto grad_out_data = grad_out.data<scalar_t>();
    auto x_data = x.data<scalar_t>();
    auto basis_data = basis.data<scalar_t>();
    auto weight_index_data = weight_index.data<int64_t>();
    auto grad_weight_data = grad_weight.data<scalar_t>();

    for (ptrdiff_t e = 0; e < E; e++) {
      for (ptrdiff_t m_out = 0; m_out < M_out; m_out++) {
        auto g =
            grad_out_data[e * grad_out.stride(0) + m_out * grad_out.stride(1)];
        for (ptrdiff_t s = 0; s < S; s++) {
          auto b = basis_data[e * S + s];
          auto wi = weight_index_data[e * S + s];
          for (ptrdiff_t m_in = 0; m_in < M_in; m_in++) {
            auto v = g * b * x_data[e * x.stride(0) + m_in * x.stride(1)];
            grad_weight_data[wi * M_in * M_out + m_in * M_out + m_out] += v;
          }
        }
      }
    }
  });

  return grad_weight;
}

at::Tensor weighting_bw_b(at::Tensor grad_out, at::Tensor x, at::Tensor weight,
                          at::Tensor weight_index) {
  auto E = grad_out.size(0), M_in = x.size(1), M_out = grad_out.size(1);
  auto S = weight_index.size(1);
  auto grad_basis = at::zeros({E, S}, grad_out.type());

  AT_DISPATCH_FLOATING_TYPES(grad_out.type(), "weighting_bw_b", [&] {
    auto grad_out_data = grad_out.data<scalar_t>();
    auto x_data = x.data<scalar_t>();
    auto weight_data = weight.data<scalar_t>();
    auto weight_index_data = weight_index.data<int64_t>();
    auto grad_basis_data = grad_basis.data<scalar_t>();

    for (ptrdiff_t e = 0; e < E; e++) {
      for (ptrdiff_t m_out = 0; m_out < M_out; m_out++) {
        auto g =
            grad_out_data[e * grad_out.stride(0) + m_out * grad_out.stride(1)];
        for (ptrdiff_t s = 0; s < S; s++) {
          scalar_t b = 0;
          auto wi = weight_index_data[e * S + s];
          for (ptrdiff_t m_in = 0; m_in < M_in; m_in++) {
            auto w =
                weight_data[wi * weight.stride(0) + m_in * weight.stride(1) +
                            m_out * weight.stride(2)];
            w *= x_data[e * x.stride(0) + m_in * x.stride(1)];
            b += w;
          }
          grad_basis_data[e * S + s] += g * b;
        }
      }
    }
  });

  return grad_basis;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("weighting_fw", &weighting_fw, "Weighting Forward (CPU)");
  m.def("weighting_bw_x", &weighting_bw_x, "Weighting Backward X (CPU)");
  m.def("weighting_bw_w", &weighting_bw_w, "Weighting Backward W (CPU)");
  m.def("weighting_bw_b", &weighting_bw_b, "Weighting Backward B (CPU)");
}
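A corresponding smoke test for the weighting bindings, again assuming the module was built as `weighting_cpu` per `setup.py`:

```python
import torch
import weighting_cpu  # built from cpu/weighting.cpp via CppExtension

E, M_in, M_out, K, S = 6, 2, 4, 25, 4
x = torch.rand(E, M_in)
weight = torch.rand(K, M_in, M_out)
basis = torch.rand(E, S)
weight_index = torch.randint(0, K, (E, S), dtype=torch.long)

out = weighting_cpu.weighting_fw(x, weight, basis, weight_index)
print(out.size())  # (6, 4)
```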
-from os import path as osp
 from setuptools import setup, find_packages
+import torch
+from torch.utils.cpp_extension import CppExtension, CUDAExtension
+
+ext_modules = [
+    CppExtension('basis_cpu', ['cpu/basis.cpp']),
+    CppExtension('weighting_cpu', ['cpu/weighting.cpp']),
+]
+cmdclass = {'build_ext': torch.utils.cpp_extension.BuildExtension}

 __version__ = '1.0.4'
 url = 'https://github.com/rusty1s/pytorch_spline_conv'

-install_requires = ['cffi']
-setup_requires = ['pytest-runner', 'cffi']
+install_requires = []
+setup_requires = ['pytest-runner']
 tests_require = ['pytest', 'pytest-cov']

 setup(
@@ -25,7 +31,7 @@ setup(
     install_requires=install_requires,
     setup_requires=setup_requires,
     tests_require=tests_require,
-    packages=find_packages(exclude=['build']),
-    ext_package='',
-    cffi_modules=[osp.join(osp.dirname(__file__), 'build.py:ffi')],
+    ext_modules=ext_modules,
+    cmdclass=cmdclass,
+    packages=find_packages(),
 )
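With cffi gone, both extensions build through plain setuptools (`python setup.py install`). For quick local experiments one could also skip `setup.py` entirely and JIT-compile the same source; a hedged sketch using PyTorch's `load` helper, not part of this commit:

```python
from torch.utils.cpp_extension import load

# JIT-compiles cpu/basis.cpp into an importable module; equivalent to the
# CppExtension entry in setup.py, just without a separate build step.
basis_cpu = load(name='basis_cpu', sources=['cpu/basis.cpp'])
```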
@@ -6,6 +6,8 @@ from torch_spline_conv.basis import SplineBasis
 from .utils import dtypes, devices, tensor

+devices = [torch.device('cpu')]
+
 tests = [{
     'pseudo': [[0], [0.0625], [0.25], [0.75], [0.9375], [1]],
     'kernel_size': [5],
@@ -32,7 +34,9 @@ def test_spline_basis_forward(test, dtype, device):
     pseudo = tensor(test['pseudo'], dtype, device)
     kernel_size = tensor(test['kernel_size'], torch.long, device)
     is_open_spline = tensor(test['is_open_spline'], torch.uint8, device)
+    degree = 1

-    basis, weight_idx = SplineBasis.apply(pseudo, kernel_size, is_open_spline)
+    op = SplineBasis.apply
+    basis, weight_index = op(pseudo, kernel_size, is_open_spline, degree)

     assert basis.tolist() == test['basis']
-    assert weight_idx.tolist() == test['weight_index']
+    assert weight_index.tolist() == test['weight_index']
@@ -4,12 +4,14 @@ import pytest
 import torch
 from torch.autograd import gradcheck
 from torch_spline_conv import SplineConv
-from torch_spline_conv.utils.ffi import implemented_degrees as degrees
+from torch_spline_conv.basis import implemented_degrees as degrees

 from .utils import dtypes, devices, tensor

+devices = [torch.device('cpu')]
+
 tests = [{
-    'src': [[9, 10], [1, 2], [3, 4], [5, 6], [7, 8]],
+    'x': [[9, 10], [1, 2], [3, 4], [5, 6], [7, 8]],
     'edge_index': [[0, 0, 0, 0], [1, 2, 3, 4]],
     'pseudo': [[0.25, 0.125], [0.25, 0.375], [0.75, 0.625], [0.75, 0.875]],
     'weight': [
@@ -42,7 +44,7 @@ tests = [{
 @pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
 def test_spline_conv_forward(test, dtype, device):
-    src = tensor(test['src'], dtype, device)
+    x = tensor(test['x'], dtype, device)
     edge_index = tensor(test['edge_index'], torch.long, device)
     pseudo = tensor(test['pseudo'], dtype, device)
     weight = tensor(test['weight'], dtype, device)
@@ -51,15 +53,15 @@ def test_spline_conv_forward(test, dtype, device):
     root_weight = tensor(test['root_weight'], dtype, device)
     bias = tensor(test['bias'], dtype, device)

-    out = SplineConv.apply(src, edge_index, pseudo, weight, kernel_size,
+    out = SplineConv.apply(x, edge_index, pseudo, weight, kernel_size,
                            is_open_spline, 1, True, root_weight, bias)
     assert out.tolist() == test['expected']

 @pytest.mark.parametrize('degree,device', product(degrees.keys(), devices))
 def test_spline_basis_backward(degree, device):
-    src = torch.rand((3, 2), dtype=torch.double, device=device)
-    src.requires_grad_()
+    x = torch.rand((3, 2), dtype=torch.double, device=device)
+    x.requires_grad_()
     edge_index = tensor([[0, 1, 1, 2], [1, 0, 2, 1]], torch.long, device)
     pseudo = torch.rand((4, 3), dtype=torch.double, device=device)
     pseudo.requires_grad_()
@@ -72,6 +74,6 @@ def test_spline_basis_backward(degree, device):
     bias = torch.rand((4), dtype=torch.double, device=device)
     bias.requires_grad_()

-    data = (src, edge_index, pseudo, weight, kernel_size, is_open_spline,
-            degree, True, root_weight, bias)
+    data = (x, edge_index, pseudo, weight, kernel_size, is_open_spline, degree,
+            True, root_weight, bias)
     assert gradcheck(SplineConv.apply, data, eps=1e-6, atol=1e-4) is True
@@ -8,8 +8,10 @@ from torch_spline_conv.basis import SplineBasis
 from .utils import dtypes, devices, tensor

+devices = [torch.device('cpu')]
+
 tests = [{
-    'src': [[1, 2], [3, 4]],
+    'x': [[1, 2], [3, 4]],
     'weight': [[[1], [2]], [[3], [4]], [[5], [6]], [[7], [8]]],
     'basis': [[0.5, 0, 0.5, 0], [0, 0, 0.5, 0.5]],
     'weight_index': [[0, 1, 2, 3], [0, 1, 2, 3]],
@@ -22,28 +24,30 @@ tests = [{
 @pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
 def test_spline_weighting_forward(test, dtype, device):
-    src = tensor(test['src'], dtype, device)
+    x = tensor(test['x'], dtype, device)
     weight = tensor(test['weight'], dtype, device)
     basis = tensor(test['basis'], dtype, device)
     weight_index = tensor(test['weight_index'], torch.long, device)

-    out = SplineWeighting.apply(src, weight, basis, weight_index)
+    out = SplineWeighting.apply(x, weight, basis, weight_index)
     assert out.tolist() == test['expected']

 @pytest.mark.parametrize('device', devices)
 def test_spline_basis_backward(device):
     pseudo = torch.rand((4, 2), dtype=torch.double, device=device)
-    pseudo.requires_grad_()
     kernel_size = tensor([5, 5], torch.long, device)
     is_open_spline = tensor([1, 1], torch.uint8, device)
+    degree = 1

-    basis, weight_idx = SplineBasis.apply(pseudo, kernel_size, is_open_spline)
+    op = SplineBasis.apply
+    basis, weight_index = op(pseudo, kernel_size, is_open_spline, degree)
+    basis.requires_grad_()

-    src = torch.rand((4, 2), dtype=torch.double, device=device)
-    src.requires_grad_()
+    x = torch.rand((4, 2), dtype=torch.double, device=device)
+    x.requires_grad_()
     weight = torch.rand((25, 2, 4), dtype=torch.double, device=device)
     weight.requires_grad_()

-    data = (src, weight, basis, weight_idx)
+    data = (x, weight, basis, weight_index)
     assert gradcheck(SplineWeighting.apply, data, eps=1e-6, atol=1e-4) is True
 import torch
-from torch.autograd import Function
+import basis_cpu

-from .utils.ffi import fw_basis, bw_basis
+implemented_degrees = {1: 'linear', 2: 'quadratic', 3: 'cubic'}
+
+
+def get_func(name, tensor):
+    # module = basis_cuda if tensor.is_cuda else basis_cpu
+    module = basis_cpu
+    return getattr(module, name)

-def fw(degree, pseudo, kernel_size, is_open_spline):
-    num_edges, S = pseudo.size(0), (degree + 1)**kernel_size.size(0)
-    basis = pseudo.new_empty((num_edges, S))
-    weight_index = kernel_size.new_empty((num_edges, S))
-    fw_basis(degree, basis, weight_index, pseudo, kernel_size, is_open_spline)
+def fw(pseudo, kernel_size, is_open_spline, degree):
+    op = get_func('{}_fw'.format(implemented_degrees[degree]), pseudo)
+    basis, weight_index = op(pseudo, kernel_size, is_open_spline)
     return basis, weight_index

-def bw(degree, grad_basis, pseudo, kernel_size, is_open_spline):
-    self = torch.empty_like(pseudo)
-    bw_basis(degree, self, grad_basis, pseudo, kernel_size, is_open_spline)
-    return self
+def bw(grad_basis, pseudo, kernel_size, is_open_spline, degree):
+    op = get_func('{}_bw'.format(implemented_degrees[degree]), pseudo)
+    grad_pseudo = op(grad_basis, pseudo, kernel_size, is_open_spline)
+    return grad_pseudo

-class SplineBasis(Function):
+class SplineBasis(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, pseudo, kernel_size, is_open_spline, degree=1):
+    def forward(ctx, pseudo, kernel_size, is_open_spline, degree):
         ctx.save_for_backward(pseudo)
         ctx.kernel_size = kernel_size
         ctx.is_open_spline = is_open_spline
         ctx.degree = degree
-        return fw(degree, pseudo, kernel_size, is_open_spline)
+        return fw(pseudo, kernel_size, is_open_spline, degree)

     @staticmethod
     def backward(ctx, grad_basis, grad_weight_index):
         pseudo, = ctx.saved_tensors
         grad_pseudo = None
         if ctx.needs_input_grad[0]:
-            grad_pseudo = bw(ctx.degree, grad_basis, pseudo, ctx.kernel_size,
-                             ctx.is_open_spline)
+            grad_pseudo = bw(grad_basis, pseudo, ctx.kernel_size,
+                             ctx.is_open_spline, ctx.degree)
         return grad_pseudo, None, None, None
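Note that `degree` is now a required positional argument of `SplineBasis.apply` (it previously defaulted to 1), and dispatch runs through `implemented_degrees` to select the matching binding. For example:

```python
import torch
from torch_spline_conv.basis import SplineBasis

pseudo = torch.rand(6, 2)
kernel_size = torch.tensor([5, 5])
is_open_spline = torch.tensor([1, 1], dtype=torch.uint8)

# degree must now be passed explicitly; 1 selects basis_cpu.linear_fw.
basis, weight_index = SplineBasis.apply(pseudo, kernel_size, is_open_spline, 1)
```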
@@ -14,7 +14,7 @@ class SplineConv(object):
     tensor product basis for a single input feature map :math:`l`.

     Args:
-        src (:class:`Tensor`): Input node features of shape
+        x (:class:`Tensor`): Input node features of shape
             (number_of_nodes x in_channels).
         edge_index (:class:`LongTensor`): Graph edges, given by source and
             target indices, of shape (2 x number_of_edges) in the fixed
@@ -40,7 +40,7 @@ class SplineConv(object):
     """

     @staticmethod
-    def apply(src,
+    def apply(x,
               edge_index,
               pseudo,
               weight,
@@ -51,19 +51,19 @@
               root_weight=None,
               bias=None):

-        src = src.unsqueeze(-1) if src.dim() == 1 else src
+        x = x.unsqueeze(-1) if x.dim() == 1 else x
         pseudo = pseudo.unsqueeze(-1) if pseudo.dim() == 1 else pseudo
         row, col = edge_index
-        n, m_out = src.size(0), weight.size(2)
+        n, m_out = x.size(0), weight.size(2)

         # Weight each node.
         data = SplineBasis.apply(pseudo, kernel_size, is_open_spline, degree)
-        out = SplineWeighting.apply(src[col], weight, *data)
+        out = SplineWeighting.apply(x[col], weight, *data)

         # Convert e x m_out to n x m_out features.
         row_expand = row.unsqueeze(-1).expand_as(out)
-        out = src.new_zeros((n, m_out)).scatter_add_(0, row_expand, out)
+        out = x.new_zeros((n, m_out)).scatter_add_(0, row_expand, out)

         # Normalize out by node degree (if wished).
         if norm:
@@ -72,7 +72,7 @@ class SplineConv(object):
         # Weight root node separately (if wished).
         if root_weight is not None:
-            out = out + torch.mm(src, root_weight)
+            out = out + torch.mm(x, root_weight)

         # Add bias (if wished).
         if bias is not None:
...
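The body of `apply` is a standard gather-transform-scatter pass: features are gathered from the `col` end of each edge, transformed per edge by `SplineWeighting`, then summed into the `row` end via `scatter_add_`. A stripped-down sketch of just that data movement (hypothetical, with the spline weighting replaced by a pass-through):

```python
import torch

x = torch.rand(4, 2)                                     # node features
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
row, col = edge_index

out_e = x[col]                                           # per-edge messages (stand-in for SplineWeighting)
row_expand = row.unsqueeze(-1).expand_as(out_e)          # broadcast target indices over channels
out = x.new_zeros(x.size(0), x.size(1)).scatter_add_(0, row_expand, out_e)
```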
import torch
import spline_conv_cpu

if torch.cuda.is_available():
    import spline_conv_cuda

def get_func(name, tensor):
    module = spline_conv_cuda if tensor.is_cuda else spline_conv_cpu
    return getattr(module, name)
-from torch.autograd import Function
+import torch
+import weighting_cpu

-from .utils.ffi import fw_weighting, bw_weighting_src
-from .utils.ffi import bw_weighting_weight, bw_weighting_basis
+def get_func(name, tensor):
+    # module = weighting_cuda if tensor.is_cuda else weighting_cpu
+    module = weighting_cpu
+    return getattr(module, name)

-def fw(src, weight, basis, weight_index):
-    out = src.new_empty((src.size(0), weight.size(2)))
-    fw_weighting(out, src, weight, basis, weight_index)
-    return out
-
-def bw_src(grad_out, weight, basis, weight_index):
-    grad_src = grad_out.new_empty((grad_out.size(0), weight.size(1)))
-    bw_weighting_src(grad_src, grad_out, weight, basis, weight_index)
-    return grad_src
-
-def bw_weight(grad_out, src, basis, weight_index, K):
-    grad_weight = src.new_empty((K, src.size(1), grad_out.size(1)))
-    bw_weighting_weight(grad_weight, grad_out, src, basis, weight_index)
-    return grad_weight
-
-def bw_basis(grad_out, src, weight, weight_index):
-    grad_basis = src.new_empty(weight_index.size())
-    bw_weighting_basis(grad_basis, grad_out, src, weight, weight_index)
-    return grad_basis
-
-class SplineWeighting(Function):
+class SplineWeighting(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, src, weight, basis, weight_index):
-        ctx.save_for_backward(src, weight, basis, weight_index)
-        return fw(src, weight, basis, weight_index)
+    def forward(ctx, x, weight, basis, weight_index):
+        ctx.weight_index = weight_index
+        ctx.save_for_backward(x, weight, basis)
+        op = get_func('weighting_fw', x)
+        out = op(x, weight, basis, weight_index)
+        return out

     @staticmethod
     def backward(ctx, grad_out):
-        grad_src = grad_weight = grad_basis = None
-        src, weight, basis, weight_index = ctx.saved_tensors
+        x, weight, basis = ctx.saved_tensors
+        grad_x = grad_weight = grad_basis = None
         if ctx.needs_input_grad[0]:
-            grad_src = bw_src(grad_out, weight, basis, weight_index)
+            op = get_func('weighting_bw_x', x)
+            grad_x = op(grad_out, weight, basis, ctx.weight_index)
         if ctx.needs_input_grad[1]:
-            K = weight.size(0)
-            grad_weight = bw_weight(grad_out, src, basis, weight_index, K)
+            op = get_func('weighting_bw_w', x)
+            grad_weight = op(grad_out, x, basis, ctx.weight_index,
+                             weight.size(0))
         if ctx.needs_input_grad[2]:
-            grad_basis = bw_basis(grad_out, src, weight, weight_index)
+            op = get_func('weighting_bw_b', x)
+            grad_basis = op(grad_out, x, weight, ctx.weight_index)
-        return grad_src, grad_weight, grad_basis, None
+        return grad_x, grad_weight, grad_basis, None
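Because `weight_index` is integral it is stashed on `ctx` instead of being saved for backward, and `backward` returns `None` for it; the remaining three gradients can be verified numerically, assuming the class is importable from `torch_spline_conv.weighting` as in the tests:

```python
import torch
from torch.autograd import gradcheck
from torch_spline_conv.weighting import SplineWeighting  # assumed module path

x = torch.rand(4, 2, dtype=torch.double, requires_grad=True)
weight = torch.rand(25, 2, 4, dtype=torch.double, requires_grad=True)
basis = torch.rand(4, 4, dtype=torch.double, requires_grad=True)
weight_index = torch.randint(0, 25, (4, 4), dtype=torch.long)

assert gradcheck(SplineWeighting.apply, (x, weight, basis, weight_index),
                 eps=1e-6, atol=1e-4)
```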