"benchmark/vscode:/vscode.git/clone" did not exist on "1b1701f1f7ea59ffca374cfbd8cd53ed5fd39df8"
Commit 7761cb1d authored by Jan Eric Lenssen

several bugs in pos gradient fixed

parent 73322b61
 from .spline_conv import spline_conv
-from .spline_conv_bp2adj import spline_conv_bp2adj
-__all__ = ['spline_conv', 'spline_conv_bp2adj']
+__all__ = ['spline_conv']
@@ -22,11 +22,12 @@ def spline_conv(
     if input.dim() == 1:
         input = input.unsqueeze(1)

-    values = adj._values()
-    row, col = adj._indices()
+    values = adj['values']
+    row, col = adj['indices']

     # Get features for every end vertex with shape [|E| x M_in].
     output = input[col]

     # Convert to [|E| x M_in] feature matrix and calculate [|E| x M_out].
     if output.is_cuda:
         output = SplineConvGPU(kernel_size, is_open_spline, K, degree,
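For context: the hunk above swaps the torch.sparse accessors adj._values() / adj._indices() for plain dict lookups, so adj is now assumed to be a dict. A minimal sketch of such a dict (the construction is illustrative; only the keys 'values', 'indices' and 'size' are confirmed by the diff):

    import torch

    # Hypothetical construction of the `adj` dict consumed by spline_conv:
    # one entry per edge, with pseudo-coordinates in [0, 1]^dim as values.
    row = torch.LongTensor([0, 0, 1, 2])  # source vertex of each edge
    col = torch.LongTensor([1, 2, 2, 0])  # target vertex of each edge
    pseudo = torch.FloatTensor([[0.25, 0.50], [0.75, 0.50],
                                [0.50, 0.25], [0.50, 0.75]])

    adj = {
        'indices': torch.stack([row, col], dim=0),  # row, col = adj['indices']
        'values': pseudo,                           # values = adj['values']
        'size': torch.Size([3, 3]),                 # adj['size'][1]: vertex count
    }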
@@ -39,7 +40,7 @@ def spline_conv(
     # Convolution via `scatter_add`. Converts [|E| x M_out] feature matrix to
     # [n x M_out] feature matrix.
-    zero = output.data.new(adj.size(1), output.size(1)).fill_(0.0)
+    zero = output.data.new(adj['size'][1], output.size(1)).fill_(0.0)
     zero = Variable(zero) if not torch.is_tensor(output) else zero
     r = row.view(-1, 1).expand(row.size(0), output.size(1))
     output = zero.scatter_add_(0, Variable(r), output)
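The scatter_add step above turns the [|E| x M_out] per-edge results into an [n x M_out] per-vertex matrix by summing every edge's row into the slot of its source vertex. A small self-contained sketch of just that aggregation (toy sizes, assumed for illustration):

    import torch

    n, M_out = 3, 2
    row = torch.LongTensor([0, 0, 1, 2])  # source vertex of each of 4 edges
    edge_out = torch.ones(4, M_out)       # stand-in for the per-edge conv output

    # Expand the edge->vertex index to the feature width, then sum rows in.
    zero = torch.zeros(n, M_out)
    r = row.view(-1, 1).expand(row.size(0), M_out)
    vertex_out = zero.scatter_add_(0, r, edge_out)
    # Vertex 0 collected two edges, vertices 1 and 2 one each:
    # [[2., 2.], [1., 1.], [1., 1.]]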
@@ -48,8 +49,8 @@ def spline_conv(
     output += torch.mm(input, weight[-1])

     # Normalize output by degree.
-    ones = values.new(values.size(0)).fill_(1)
-    zero = values.new(output.size(0)).fill_(0)
+    ones = values.data.new(values.size(0)).fill_(1)
+    zero = values.data.new(output.size(0)).fill_(0)
     degree = zero.scatter_add_(0, row, ones)
     degree = torch.clamp(degree, min=1)
     output = output / Variable(degree.view(-1, 1))
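The normalization above computes each vertex's degree with another scatter_add (of ones this time), clamps it to at least 1 so isolated vertices do not divide by zero, and divides row-wise, turning the edge sum into a mean. A sketch with the same toy graph as before (illustrative values):

    import torch

    row = torch.LongTensor([0, 0, 1, 2])
    vertex_out = torch.Tensor([[2., 2.], [1., 1.], [1., 1.]])

    ones = torch.ones(row.size(0))
    degree = torch.zeros(vertex_out.size(0)).scatter_add_(0, row, ones)
    degree = torch.clamp(degree, min=1)           # isolated vertices stay at 1
    vertex_out = vertex_out / degree.view(-1, 1)  # per-vertex mean over edges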
...
 import torch
-from torch.autograd import Function
+from torch.autograd import Function, Variable

 from ....utils.cuda import (cuda_num_threads, Stream, load_kernel, kernel_loop,
                             get_blocks)
@@ -120,6 +120,7 @@ const ${Dtype}* amount, const long* index, int num_threads) {
     k = e_idx * ${k_max} + k_idx;
     b = amount[k];
     c = index[k];
+    ${Dtype} adj_g = 0.0;
     for (int m_in_idx = 0; m_in_idx < ${M_in}; m_in_idx++) {
       w_idx = c * ${M_out} * ${M_in} +
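The newly declared adj_g accumulator presumably collects, per (edge, basis-term) pair, the gradient of the edge output with respect to the scalar basis weight b = amount[k], summed over the channel loop that follows. Only the declaration and the loop head are visible here, so the body below is an assumption, written out in plain Python with toy values:

    # d(b * x_e W_c) / db, dotted with the incoming gradient of edge e.
    M_in, M_out = 2, 3
    x_e = [0.5, 1.0]                    # input features of edge e
    W_c = [[1., 2., 3.], [4., 5., 6.]]  # weight matrix of basis index c
    g_e = [0.1, 0.2, 0.3]               # grad_output row of edge e

    adj_g = 0.0
    for m_in_idx in range(M_in):
        for m_out_idx in range(M_out):
            adj_g += x_e[m_in_idx] * W_c[m_in_idx][m_out_idx] * g_e[m_out_idx]
    # adj_g is this pair's contribution to grad_amount[e_idx * k_max + k_idx].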
@@ -212,8 +213,7 @@ const long* kernel_size, const long* is_open_spline, int num_threads) {
     k_idx >>= 1;

     value = input[e_idx * ${dim} + d_idx];
-    if(value==1.0)
-      value -= 0.000001;
     value *= kernel_size[d_idx] - is_open_spline[d_idx];
     frac = value - floor(value);
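The deleted epsilon nudge handled the right boundary: at value == 1.0 the scaled coordinate lands exactly on the last knot, so floor() yields the top index with zero fractional part, and the old code shifted it down by 1e-6 instead, presumably to keep the knot index in range. A sketch of the same scaling step in Python (variable names follow the kernel; the wrapper function is illustrative):

    import math

    def knot_and_frac(value, kernel_size, is_open_spline):
        # Mirrors: value *= kernel_size[d] - is_open_spline[d];
        #          frac = value - floor(value);
        value = value * (kernel_size - is_open_spline)
        return int(math.floor(value)), value - math.floor(value)

    print(knot_and_frac(1.0, kernel_size=5, is_open_spline=1))  # (4, 0.0)
    print(knot_and_frac(1.0 - 1e-6, 5, 1))                      # (3, ~0.999996)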
@@ -340,16 +340,12 @@ int num_threads) {
     const int e_idx = idx / ${dim};
     int d_idx = idx % ${dim};
-    int K = ${K};
     int k_idx_mod;
-    int bot;
-    int top;
     ${Dtype} value;
     ${Dtype} frac;
     ${Dtype} grad_out = 0.0;
-    long i = 0;
-    int quotient = (int)pow(2.0,(float)d_idx);
+    int quotient = (int)pow(2.0,(double)d_idx);

     for (int k_idx = 0; k_idx < ${k_max}; k_idx++) {
@@ -359,8 +355,8 @@ int num_threads) {
       value *= kernel_size[d_idx] - is_open_spline[d_idx];
       frac = value - floor(value);
-      residual = (1 - k_idx_mod) * (frac - 1) + k_idx_mod * frac;
-      int a_idx = e_idx*${k_max} + k_idx
+      ${Dtype} residual = (1 - k_idx_mod) * (frac - 1) + k_idx_mod * frac;
+      int a_idx = e_idx*${k_max} + k_idx;
       grad_out += grad_amount[a_idx]*amount[a_idx]/residual;
     }
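The grad_amount[a_idx]*amount[a_idx]/residual line reads as the chain rule for a product of degree-1 B-spline factors: along each dimension the factor is either frac or 1 - frac, its derivative with respect to frac is +1 or -1, and dividing the full product amount by the signed residual recovers exactly the product of the remaining factors times that sign. A small numeric check of that identity (degree 1 only; the setup is illustrative):

    def d_amount_d_frac(amount, frac, k_idx_mod):
        # residual = frac (factor frac, derivative +1) when k_idx_mod == 1,
        # residual = frac - 1 (factor 1 - frac, derivative -1) otherwise.
        residual = (1 - k_idx_mod) * (frac - 1) + k_idx_mod * frac
        return amount / residual

    # Two dimensions, amount = frac_0 * (1 - frac_1):
    frac_0, frac_1 = 0.3, 0.8
    amount = frac_0 * (1.0 - frac_1)           # 0.06
    print(d_amount_d_frac(amount, frac_0, 1))  # ~ 0.2 == 1 - frac_1
    print(d_amount_d_frac(amount, frac_1, 0))  # ~-0.3 == -frac_0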
@@ -471,9 +467,12 @@ class SplineConvGPU(Function):
             stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))

         if self.bp_to_adj:
-            self.save_for_backward(input, weight, adj_values, amount, index)
+            self.save_for_backward(input, weight, adj_values)
         else:
-            self.save_for_backward(input, weight, amount, index)
+            self.save_for_backward(input, weight)
+        self.amount = amount
+        self.index = index

         return output
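Moving amount and index out of save_for_backward matches the legacy autograd contract: save_for_backward is meant for tensors that are inputs or outputs of forward, while intermediate buffers are stashed as plain attributes on the Function object. A toy sketch of that pattern in the same pre-0.4, non-static Function style used in this file (the Scale class is illustrative):

    import torch
    from torch.autograd import Function

    class Scale(Function):  # legacy, instance-style Function, as in this file
        def forward(self, input, weight):
            amount = input.abs()                    # intermediate buffer
            self.save_for_backward(input, weight)   # forward inputs only
            self.amount = amount                    # buffers go on self
            return input * weight

        def backward(self, grad_output):
            input, weight = self.saved_tensors
            amount = self.amount  # retrieved as in the diff (unused here;
                                  # shown only for the retrieval pattern)
            return grad_output * weight, grad_output * input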
@@ -483,7 +482,9 @@ class SplineConvGPU(Function):
         num_threads = grad_output.numel()

         if self.bp_to_adj:
-            input, weight, adj_values, amount, index = self.saved_tensors
+            input, weight, adj_values = self.saved_tensors
+            amount = self.amount
+            index = self.index
             grad_amount = grad_output.new(amount.size(0),
                                           amount.size(1)).fill_(0)

             with torch.cuda.device_of(grad_output):
@@ -502,9 +503,9 @@ class SplineConvGPU(Function):
                 ],
                 stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))

-            num_threads = grad_amount.numel()
             grad_adj = grad_amount.new(grad_amount.size(0),
                                        self.kernel_size.size(0)).fill_(0)
+            num_threads = grad_adj.numel()

             with torch.cuda.device_of(grad_amount):
                 self.f_basis_bw(
@@ -516,14 +517,17 @@ class SplineConvGPU(Function):
                     amount.data_ptr(),
                     grad_adj.data_ptr(),
                     self.kernel_size.data_ptr(),
-                    self.is_open_spline.data_ptr(), num_threads
+                    self.is_open_spline.data_ptr(),
+                    num_threads
                 ],
                 stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))

-            return grad_input, grad_weight, grad_adj
+            return grad_input, grad_weight, None
         else:
-            input, weight, amount, index = self.saved_tensors
+            input, weight = self.saved_tensors
+            amount = self.amount
+            index = self.index

             with torch.cuda.device_of(grad_output):
                 self.f_weighting_bw(
                     block=(cuda_num_threads, 1, 1),
@@ -539,4 +543,4 @@ class SplineConvGPU(Function):
                 ],
                 stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))

-            return grad_input, grad_weight
+            return grad_input, grad_weight, None
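Both branches now return one gradient slot per argument of forward, as the legacy Function API requires; None marks an argument that receives no gradient through this path. A minimal illustration (the three-argument forward signature is an assumption inferred from the save_for_backward call above):

    from torch.autograd import Function

    class Toy(Function):  # legacy style: one backward return per forward arg
        def forward(self, input, weight, adj_values):
            self.save_for_backward(input, weight)
            return input * weight

        def backward(self, grad_output):
            input, weight = self.saved_tensors
            # No gradient flows to adj_values here, hence the trailing None.
            return grad_output * weight, grad_output * input, None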