Commit f622968c authored by Jan Eric Lenssen

started backprop to adj

parent c16689ca
from .spline_conv import spline_conv
from .spline_conv_bp2adj import spline_conv_bp2adj
__all__ = ['spline_conv', 'spline_conv_bp2adj']
import torch
from torch.autograd import Function
from ....utils.cuda import (cuda_num_threads, Stream, Dtype, load_kernel,
                            kernel_loop, get_blocks)
_edgewise_spline_weighting_forward_kernel = kernel_loop + '''
extern "C"
__global__ void edgewise_spline_weighting_forward_kernel(
    const ${Dtype}* input, const ${Dtype}* weight, ${Dtype}* output,
    const ${Dtype}* amount, const long* index) {

  CUDA_KERNEL_LOOP(idx, ${num_threads}) {

    // One thread per (edge, output channel) pair.
    const int e_idx = idx / ${M_out};
    const int m_out_idx = idx % ${M_out};

    ${Dtype} result = 0.0;
    ${Dtype} w;
    ${Dtype} f;
    ${Dtype} b;
    int k;
    long c;
    long w_idx;

    for (int k_idx = 0; k_idx < ${k_max}; k_idx++) {
      k = e_idx * ${k_max} + k_idx;
      // B-spline basis product and weight matrix index for this corner.
      b = amount[k];
      c = index[k];

      for (int m_in_idx = 0; m_in_idx < ${M_in}; m_in_idx++) {
        w_idx = c * ${M_out} * ${M_in} + m_in_idx * ${M_out} + m_out_idx;
        w = weight[w_idx];
        f = input[e_idx * ${M_in} + m_in_idx];
        result += b * w * f;
      }
    }
    output[idx] = result;
  }
}
'''
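# Illustrative only (not part of the original module): a pure-Python sketch of
# what the forward kernel above computes, assuming `input` is [|E| x M_in],
# `weight` is [K x M_in x M_out] and `amount`/`index` are [|E| x k_max].
def _edgewise_spline_weighting_forward_reference(input, weight, amount, index):
    num_edges, M_in = input.size()
    M_out = weight.size(2)
    output = input.new(num_edges, M_out).fill_(0)
    for e in range(num_edges):
        for k in range(amount.size(1)):
            # Each kernel corner contributes its basis product `b` times the
            # linear map of the edge features by weight matrix `index[e, k]`.
            b, c = amount[e, k], index[e, k]
            output[e] += b * torch.mv(weight[c].t(), input[e])
    return output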
_edgewise_spline_weighting_backward_kernel = kernel_loop + '''
extern "C"
__global__ void edgewise_spline_weighting_backward_kernel(
    const ${Dtype}* grad_output, ${Dtype}* grad_input, ${Dtype}* grad_weight,
    ${Dtype}* grad_b, const ${Dtype}* input, const ${Dtype}* weight,
    const ${Dtype}* amount, const long* index) {

  CUDA_KERNEL_LOOP(idx, ${num_threads}) {

    const int e_idx = idx / ${M_out};
    const int m_out_idx = idx % ${M_out};

    ${Dtype} w;
    ${Dtype} g;
    ${Dtype} f;
    ${Dtype} w_grad;
    ${Dtype} b_grad;
    ${Dtype} b;
    int k;
    long c;
    long w_idx;

    for (int k_idx = 0; k_idx < ${k_max}; k_idx++) {
      k = e_idx * ${k_max} + k_idx;
      b = amount[k];
      c = index[k];

      for (int m_in_idx = 0; m_in_idx < ${M_in}; m_in_idx++) {
        w_idx = c * ${M_out} * ${M_in} + m_in_idx * ${M_out} + m_out_idx;
        w = weight[w_idx];

        // Calculate the input gradient.
        // This is inefficient: the `reduce_sum` shouldn't be done like this.
        // Looping over `M_out` would be better to avoid the `atomicAdd`.
        g = grad_output[e_idx * ${M_out} + m_out_idx];
        atomicAdd(&(grad_input[e_idx * ${M_in} + m_in_idx]), b * w * g);

        // Calculate the weight gradient.
        // Not so efficient either, but not avoidable.
        f = input[e_idx * ${M_in} + m_in_idx];
        w_grad = f * b * g;
        atomicAdd(&(grad_weight[w_idx]), w_grad);

        // Calculate the gradient of the B-spline amounts.
        b_grad = f * w * g;
        atomicAdd(&(grad_b[k]), b_grad);
      }
    }
  }
}
'''
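# For reference: with
# output[e, j] = sum_k sum_i amount[e, k] * weight[index[e, k], i, j] * input[e, i]
# and g = grad_output, the backward kernel above accumulates
#   grad_input[e, i]               += b * w * g[e, j],
#   grad_weight[index[e, k], i, j] += f * b * g[e, j],
#   grad_b[e, k]                   += f * w * g[e, j],
# summed over all free indices, which is exactly the chain rule for the
# trilinear form above.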
_spline_kernel = kernel_loop + '''
extern "C"
__global__ void spline_kernel(
    const ${Dtype}* input, ${Dtype}* amount, long* index,
    const long* kernel_size, const long* is_open_spline) {

  CUDA_KERNEL_LOOP(idx, ${num_threads}) {

    // One thread per (edge, kernel corner) pair, with k_max = 2^dim corners.
    const int e_idx = idx / ${k_max};
    int k_idx = idx % ${k_max};

    int K = ${K};
    int k_idx_mod;
    int bot;
    int top;
    ${Dtype} value;
    ${Dtype} frac;
    ${Dtype} a = 1.0;
    long i = 0;

    for (int d_idx = 0; d_idx < ${dim}; d_idx++) {
      K /= kernel_size[d_idx];

      // The d-th bit of `k_idx` selects the lower or upper knot in this
      // dimension.
      k_idx_mod = k_idx % 2;
      k_idx >>= 1;

      value = input[e_idx * ${dim} + d_idx] *
              (kernel_size[d_idx] - is_open_spline[d_idx]);
      frac = value - floor(value);
      a *= (1 - k_idx_mod) * frac + k_idx_mod * (1 - frac);

      bot = int(floor(value));
      top = (bot + 1) % kernel_size[d_idx];
      bot %= kernel_size[d_idx];
      i += (k_idx_mod * bot + (1 - k_idx_mod) * top) * K;
    }

    amount[idx] = a;
    index[idx] = i;
  }
}
'''
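# Illustrative only (not part of the original module): a pure-Python sketch of
# the degree-1 spline kernel above for a single edge, assuming plain Python
# sequences for `u` (the [dim] pseudo-coordinate), `kernel_size` and
# `is_open_spline`, and the total number of weight matrices `K`.
import math

def _spline_reference(u, kernel_size, is_open_spline, K):
    dim = len(u)
    amounts, indices = [], []
    for corner in range(2 ** dim):
        a, i, K_rest, bits = 1.0, 0, K, corner
        for d in range(dim):
            K_rest //= kernel_size[d]
            bit, bits = bits % 2, bits >> 1
            value = u[d] * (kernel_size[d] - is_open_spline[d])
            frac = value - math.floor(value)
            # Bit 0 picks the upper knot with weight `frac`, bit 1 the lower
            # knot with weight `1 - frac`.
            a *= (1 - bit) * frac + bit * (1 - frac)
            bot = int(math.floor(value))
            top = (bot + 1) % kernel_size[d]
            bot %= kernel_size[d]
            i += (bit * bot + (1 - bit) * top) * K_rest
        amounts.append(a)
        indices.append(i)
    return amounts, indices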
# TODO: Backward spline kernel. For now this is a verbatim copy of the forward
# kernel above (work in progress, see commit message); it is not loaded yet.
_spline_kernel_backwards = kernel_loop + '''
extern "C"
__global__ void spline_kernel(
    const ${Dtype}* input, ${Dtype}* amount, long* index,
    const long* kernel_size, const long* is_open_spline) {

  CUDA_KERNEL_LOOP(idx, ${num_threads}) {

    const int e_idx = idx / ${k_max};
    int k_idx = idx % ${k_max};

    int K = ${K};
    int k_idx_mod;
    int bot;
    int top;
    ${Dtype} value;
    ${Dtype} frac;
    ${Dtype} a = 1.0;
    long i = 0;

    for (int d_idx = 0; d_idx < ${dim}; d_idx++) {
      K /= kernel_size[d_idx];
      k_idx_mod = k_idx % 2;
      k_idx >>= 1;

      value = input[e_idx * ${dim} + d_idx] *
              (kernel_size[d_idx] - is_open_spline[d_idx]);
      frac = value - floor(value);
      a *= (1 - k_idx_mod) * frac + k_idx_mod * (1 - frac);

      bot = int(floor(value));
      top = (bot + 1) % kernel_size[d_idx];
      bot %= kernel_size[d_idx];
      i += (k_idx_mod * bot + (1 - k_idx_mod) * top) * K;
    }

    amount[idx] = a;
    index[idx] = i;
  }
}
'''
class EdgewiseSplineWeightingGPU(Function):
    def __init__(self, kernel_size, is_open_spline, K):
        super(EdgewiseSplineWeightingGPU, self).__init__()
        assert kernel_size.is_cuda and is_open_spline.is_cuda
        self.kernel_size = kernel_size
        self.is_open_spline = is_open_spline
        self.K = K
    def forward(self, input, values, weight):
        assert input.is_cuda and weight.is_cuda
        self.save_for_backward(input, weight)

        values = values.unsqueeze(1) if len(values.size()) < 2 else values
        num_edges, dim = values.size()
        k_max = 2 ** dim

        amount = values.new(num_edges, k_max)
        index = values.new(num_edges, k_max).long()
        num_threads = amount.numel()

        # Compute B-spline basis products and weight indices per edge.
        with torch.cuda.device_of(input):
            f = load_kernel(
                'spline_kernel',
                _spline_kernel,
                Dtype=Dtype(input),
                num_threads=num_threads,
                k_max=k_max,
                dim=dim,
                K=self.K)
            f(block=(cuda_num_threads, 1, 1),
              grid=(get_blocks(num_threads), 1, 1),
              args=[
                  values.data_ptr(),
                  amount.data_ptr(),
                  index.data_ptr(),
                  self.kernel_size.data_ptr(),
                  self.is_open_spline.data_ptr()
              ],
              stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))

        self.amount = amount
        self.index = index

        _, M_in, M_out = weight.size()
        k_max = self.amount.size(1)
        output = input.new(input.size(0), M_out)
        num_threads = output.numel()

        # Weight the edge features by the interpolated spline kernel.
        with torch.cuda.device_of(input):
            f = load_kernel(
                'edgewise_spline_weighting_forward_kernel',
                _edgewise_spline_weighting_forward_kernel,
                Dtype=Dtype(input),
                num_threads=num_threads,
                M_in=M_in,
                M_out=M_out,
                k_max=k_max)
            f(block=(cuda_num_threads, 1, 1),
              grid=(get_blocks(num_threads), 1, 1),
              args=[
                  input.data_ptr(),
                  weight.data_ptr(),
                  output.data_ptr(),
                  self.amount.data_ptr(),
                  self.index.data_ptr()
              ],
              stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))

        return output
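    # `forward` caches `self.amount` and `self.index` (the B-spline basis
    # products and flattened weight indices per edge), so `backward` can reuse
    # them without re-running the spline kernel.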
    def backward(self, grad_output):
        input, weight = self.saved_tensors
        K, M_in, M_out = weight.size()
        E, k_max = self.amount.size()

        grad_input = grad_output.new(input.size(0), M_in).fill_(0)
        grad_weight = grad_output.new(K, M_in, M_out).fill_(0)
        grad_b = grad_output.new(E, k_max).fill_(0)
        num_threads = grad_output.numel()

        with torch.cuda.device_of(grad_output):
            f = load_kernel(
                'edgewise_spline_weighting_backward_kernel',
                _edgewise_spline_weighting_backward_kernel,
                Dtype=Dtype(input),
                num_threads=num_threads,
                M_in=M_in,
                M_out=M_out,
                k_max=k_max,
                K=K)
            f(block=(cuda_num_threads, 1, 1),
              grid=(get_blocks(num_threads), 1, 1),
              args=[
                  grad_output.data_ptr(),
                  grad_input.data_ptr(),
                  grad_weight.data_ptr(),
                  grad_b.data_ptr(),
                  input.data_ptr(),
                  weight.data_ptr(),
                  self.amount.data_ptr(),
                  self.index.data_ptr()
              ],
              stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))

        # `grad_b` holds the gradient with respect to the B-spline amounts;
        # chaining it back to the adjacency values via the backward spline
        # kernel is still to be done (work in progress), so no gradient is
        # returned for `values` yet. Gradients are returned in the order of
        # the forward inputs `(input, values, weight)`.
        return grad_input, None, grad_weight
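# For the eventual backprop to the adjacency values, the remaining chain rule
# through the spline kernel is, per dimension d with
# s_d = kernel_size[d] - is_open_spline[d],
# frac_d = u_d * s_d - floor(u_d * s_d) and bit_d the d-th bit of the corner:
#   a = prod_d ((1 - bit_d) * frac_d + bit_d * (1 - frac_d)),
#   d frac_d / d u_d = s_d  (away from the knots), hence
#   d a / d u_d = (1 - 2 * bit_d) * s_d * prod_{d' != d} (...).
# Combining these factors with `grad_b` yields d loss / d u.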
import torch
from torch.autograd import Variable
from .spline import spline
from .edgewise_spline_weighting_bp2adj_gpu import EdgewiseSplineWeightingGPU
def spline_conv_bp2adj(adj,  # Tensor
                       input,  # Variable
                       weight,  # Variable
                       kernel_size,
                       is_open_spline,
                       K,
                       degree=1,
                       bias=None):

    values = adj._values()
    row, col = adj._indices()

    # Get features for every end vertex with shape [|E| x M_in].
    output = input[col]

    # Convert the [|E| x M_in] feature matrix to [|E| x M_out]. Note that
    # `forward` expects its arguments in the order `(input, values, weight)`.
    op = EdgewiseSplineWeightingGPU(kernel_size, is_open_spline, K)
    output = op(output, values, weight[:-1])

    # Convolution via `scatter_add`. Converts the [|E| x M_out] feature matrix
    # to an [n x M_out] feature matrix.
    zero = output.data.new(adj.size(1), output.size(1)).fill_(0.0)
    zero = Variable(zero) if not torch.is_tensor(output) else zero
    r = row.view(-1, 1).expand(row.size(0), output.size(1))
    output = zero.scatter_add_(0, Variable(r), output)

    # Weight the root node features by multiplying them with the root weight
    # matrix.
    output += torch.mm(input, weight[-1])

    # Normalize output by node degree (`node_degree` avoids shadowing the
    # `degree` argument, which denotes the B-spline degree).
    ones = values.new(values.size(0)).fill_(1)
    zero = values.new(output.size(0)).fill_(0)
    node_degree = zero.scatter_add_(0, row, ones)
    node_degree = torch.clamp(node_degree, min=1)
    output = output / Variable(node_degree.view(-1, 1))

    if bias is not None:
        output += bias

    return output
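# Illustrative only: a minimal invocation sketch under assumed shapes. All
# tensors below are hypothetical example data; `adj` is a sparse [n x n x dim]
# tensor whose values are the edge pseudo-coordinates, and `weight` holds
# K + 1 matrices (the last one being the root weight).
if __name__ == '__main__':
    n, M_in, M_out, dim = 4, 8, 16, 2
    kernel_size = torch.cuda.LongTensor([5, 5])
    is_open_spline = torch.cuda.LongTensor([1, 1])
    K = 5 * 5

    i = torch.cuda.LongTensor([[0, 1, 1, 2], [1, 0, 2, 3]])
    u = torch.cuda.FloatTensor([[0.25, 0.75], [0.5, 0.5], [0.0, 0.9], [0.9, 0.0]])
    adj = torch.cuda.sparse.FloatTensor(i, u, torch.Size([n, n, dim]))

    input = Variable(torch.cuda.FloatTensor(n, M_in).normal_())
    weight = Variable(torch.cuda.FloatTensor(K + 1, M_in, M_out).normal_())

    out = spline_conv_bp2adj(adj, input, weight, kernel_size, is_open_spline, K)
    print(out.size())  # Expected: (n, M_out) == (4, 16)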