Commit c80ebba6 authored by Daniel Povey's avatar Daniel Povey

A version with apparently-working forward..

parent 2e506591
from .integrated_conv import integrated_conv
import os
import torch
from typing import Tuple
from torch.utils.cpp_extension import load
VERBOSE = False
def _resolve(name):
return os.path.join(os.path.dirname(os.path.realpath(__file__)), name)
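# If a prebuilt extension (installed via setup.py) is importable we use it;
# otherwise we fall back to JIT-compiling the C++/CUDA sources that live
# next to this file, via torch.utils.cpp_extension.load.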
try:
import torch_integrated_conv_cpu
except ImportError:
if VERBOSE:
print('Falling back to JIT compiling torch_integrated_conv_cpu')
torch_integrated_conv_cpu = load(
name='torch_integrated_conv_cpu',
sources=[
_resolve('integrated_conv_cpu.cpp'),
],
verbose=VERBOSE,
)
try:
import torch_integrated_conv_cuda
except ImportError:
if VERBOSE:
print('Falling back to JIT compiling torch_integrated_conv_cuda')
torch_integrated_conv_cuda = None
if torch.cuda.is_available():
torch_integrated_conv_cuda = load(
name='torch_integrated_conv_cuda',
sources=[
_resolve('integrated_conv_cuda.cpp'),
_resolve('integrated_conv_cuda_kernel.cu'),
],
verbose=VERBOSE,
)
def _integrated_conv_forward_dispatcher(input: torch.Tensor,
pos_add: torch.Tensor,
pos_mul: torch.Tensor) -> torch.Tensor:
if input.is_cuda:
if torch_integrated_conv_cuda is None:
raise EnvironmentError('Failed to load native CUDA module')
return torch_integrated_conv_cuda.integrated_conv_cuda(
input.contiguous(), pos_add.contiguous(), pos_mul.contiguous())
else:
return torch_integrated_conv_cpu.integrated_conv_cpu(
input, pos_add, pos_mul)
def _integrated_conv_backward_dispatcher(input: torch.Tensor,
pos_add: torch.Tensor,
pos_mul: torch.Tensor,
grad_output: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if input.is_cuda:
if torch_integrated_conv_cuda is None:
raise EnvironmentError('Failed to load native CUDA module')
# Actually it's not a hard requirement that these things be contiguous.
return tuple(torch_integrated_conv_cuda.integrated_conv_backward_cuda(
input.contiguous(), pos_add.contiguous(), pos_mul.contiguous(),
grad_output))
else:
return tuple(torch_integrated_conv_cpu.integrated_conv_backward_cpu(
input, pos_add, pos_mul, grad_output))
class IntegratedConvFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input: torch.Tensor, pos_add: torch.Tensor, pos_mul: torch.Tensor) -> torch.Tensor:
output = _integrated_conv_forward_dispatcher(input, pos_add, pos_mul)
ctx.save_for_backward(input, pos_add, pos_mul)
return output
@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
(input, pos_add, pos_mul) = ctx.saved_tensors
grad_input, grad_pos_add, grad_pos_mul = _integrated_conv_backward_dispatcher(
input, pos_add, pos_mul, grad_output)
return grad_input, grad_pos_add, grad_pos_mul
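# A quick way to sanity-check the analytic backward above is
# torch.autograd.gradcheck, which compares it against finite differences.
# This is a minimal sketch (not part of the original module); it assumes
# small double-precision inputs for speed and stability. Note that the relu
# inside the kernel makes the function only piecewise differentiable, so
# points very close to the kink can produce spurious failures.
def _gradcheck_integrated_conv(N=1, C=2, H=3, W=4, kH=3, kW=3):
    input = torch.randn(N, 2 * C, H, W, dtype=torch.float64, requires_grad=True)
    pos_add = torch.randn(C, kH, kW, dtype=torch.float64, requires_grad=True)
    pos_mul = torch.randn(C, kH, kW, dtype=torch.float64, requires_grad=True)
    return torch.autograd.gradcheck(IntegratedConvFunction.apply,
                                    (input, pos_add, pos_mul))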
def integrated_conv(input, pos_add, pos_mul):
"""Integrated convolution.
Args:
input: The input of shape (N, 2*C, W) for 1-d convolution or (N, 2*C, H, W)
for 2-d convolution, where
N is the batch size, C is the number of output channels, and H and W are
the input image's height and width respectively. The input channels are
of two types, "src" and "dest", indicating whether they relate
to the source or destination image position; all the "src" channels come
first, then the "dest" channels.
pos_add: Positional encoding: the additive part of the convolution kernel.
This is of shape (C, kW) for 1-d
convolution or (C, kH, kW) for 2-d convolution,
where C is the number of channels and kH and kW are the kernel height and
kernel width. Kernel height and width must be odd (we assume zero padding
so the output size is the same as the input size).
pos_mul: Positional encoding: the multiplicative part of the convolution kernel.
This is of shape (C, kW)
for 1-d convolution or (C, kH, kW) for 2-d convolution, where C
is the number of channels and kH and kW are the kernel height and
kernel width.
Return: output, of shape (N, C, W) for 1-d convolution or (N, C, H, W) for
2-d convolution. In the 2-d case the output will satisfy:
output[n, c, h, w] = \sum_{kh=0}^{kH-1} \sum_{kw=0}^{kW-1}
pos_mul[c, kh, kw] * relu(input[n, c + C, h, w] + input_padded[n, c, h+kh, w+kw] + pos_add[c, kh, kw])
where input_padded is torch.nn.functional.pad(input, (kW//2, kW//2, kH//2, kH//2)),
meaning zero-padding (this is done implicitly by the implementation).
(Technically this is more closely related to cross-correlation than to
convolution).
"""
if input.ndim == 3:
assert pos_add.ndim == 2 and pos_mul.ndim == 2
# For now we choose to handle only the 2-dimensional case directly. The
# 1-dimensional one is treated as a special case of the 2-dimensional one.
# Actually we could unsqueeze with -2 or -1 here, as the height and width
# behave the same.
return integrated_conv(input.unsqueeze(-2),
pos_add.unsqueeze(-2), pos_mul.unsqueeze(-2)).squeeze(-2)
assert input.ndim == 4 and pos_add.ndim == 3 and pos_mul.ndim == 3
assert input.shape[1] // 2 == pos_add.shape[0] == pos_mul.shape[0]
return IntegratedConvFunction.apply(input, pos_add, pos_mul)
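# For clarity, a pure-PyTorch reference implementation of the formula in the
# docstring above (a slow sketch intended only for testing against the
# compiled kernels; the name and helper are illustrative, not part of the
# original API):
def _integrated_conv_ref(input: torch.Tensor,
                         pos_add: torch.Tensor,
                         pos_mul: torch.Tensor) -> torch.Tensor:
    N, C2, H, W = input.shape
    C = C2 // 2
    kH, kW = pos_add.shape[1], pos_add.shape[2]
    src, dest = input[:, :C], input[:, C:]
    # Implicit zero-padding of the "src" half of the input.
    src_padded = torch.nn.functional.pad(src, (kW // 2, kW // 2, kH // 2, kH // 2))
    output = torch.zeros(N, C, H, W, dtype=input.dtype, device=input.device)
    for kh in range(kH):
        for kw in range(kW):
            shifted = src_padded[:, :, kh:kh + H, kw:kw + W]
            term = torch.relu(dest + shifted + pos_add[:, kh, kw].reshape(1, C, 1, 1))
            output = output + pos_mul[:, kh, kw].reshape(1, C, 1, 1) * term
    return output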
#include <torch/extension.h>
// Forward of integrated_conv. The """...""" docstring of `integrated_conv`
// in integrated_conv.py documents the behavior of this function.
torch::Tensor integrated_conv_cpu(torch::Tensor input,
torch::Tensor pos_add,
torch::Tensor pos_mul) {
TORCH_CHECK(input.dim() == 4, "input must be 4-dimensional");
TORCH_CHECK(pos_add.dim() == 3, "pos_add must be 3-dimensional.");
TORCH_CHECK(pos_mul.dim() == 3, "pos_mul must be 3-dimensional.");
TORCH_CHECK(input.device().is_cpu(), "Input must be a CPU tensor");
const int N = input.size(0),
C = input.size(1) / 2,
H = input.size(2),
W = input.size(3),
kH = pos_add.size(1),
kW = pos_add.size(2);
TORCH_CHECK(kH % 2 == 1 && kW % 2 == 1, "Kernel height and width must be odd");
TORCH_CHECK(input.size(1) % 2 == 0, "Input must have even num-channels");
TORCH_CHECK(pos_add.size(0) == C && pos_mul.size(0) == C &&
pos_mul.size(1) == kH && pos_mul.size(2) == kW,
"Input sizes mismatch.");
TORCH_CHECK(pos_add.device() == input.device() &&
pos_mul.device() == pos_add.device(),
"Input devices mismatch");
auto scalar_t = input.scalar_type();
TORCH_CHECK(pos_add.scalar_type() == scalar_t &&
pos_mul.scalar_type() == scalar_t,
"Input dtypes mismatch");
torch::Tensor output = torch::empty({N, C, H, W},
torch::TensorOptions().dtype(scalar_t).device(input.device()));
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "integrated_conv_cpu_loop", ([&] {
auto input_a = input.accessor<scalar_t, 4>(),
output_a = output.accessor<scalar_t, 4>();
auto pos_add_a = pos_add.accessor<scalar_t, 3>(),
pos_mul_a = pos_mul.accessor<scalar_t, 3>();
for (int n = 0; n < N; n++) {
for (int c = 0; c < C; c++) {
auto src_input_a = input_a[n][c],
this_pos_add_a = pos_add_a[c],
this_pos_mul_a = pos_mul_a[c],
this_output_a = output_a[n][c];
for (int h = 0; h < H; h++) {
for (int w = 0; w < W; w++) {
scalar_t dest = input_a[n][c + C][h][w],
sum = 0.0;
for (int kh = 0; kh < kH; kh++) {
int src_h = h + kh - kH / 2;
for (int kw = 0; kw < kW; kw++) {
int src_w = w + kw - kW / 2;
scalar_t src = 0.0;
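// Casting to unsigned makes negative indices wrap to large values, so a
// single comparison implements 0 <= src_h < H (and likewise for src_w);
// out-of-range positions leave src == 0, i.e. implicit zero padding.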
if (static_cast<unsigned int>(src_h) < static_cast<unsigned int>(H) &&
static_cast<unsigned int>(src_w) < static_cast<unsigned int>(W))
src = src_input_a[src_h][src_w];
scalar_t relu = src + dest + this_pos_add_a[kh][kw];
if (relu >= 0.0)
sum += relu * this_pos_mul_a[kh][kw];
}
}
this_output_a[h][w] = sum;
}
}
}
}
}));
return output;
}
// backward of integrated_conv; returns (grad_input, grad_pos_add, grad_pos_mul).
std::vector<torch::Tensor> integrated_conv_backward_cpu(torch::Tensor input,
torch::Tensor pos_add,
torch::Tensor pos_mul,
torch::Tensor grad_output) {
TORCH_CHECK(input.dim() == 4, "input must be 4-dimensional");
TORCH_CHECK(pos_add.dim() == 3, "pos_add must be 3-dimensional.");
TORCH_CHECK(pos_mul.dim() == 3, "pos_mul must be 3-dimensional.");
TORCH_CHECK(input.device().is_cpu(), "Input must be a CPU tensor");
const int N = input.size(0),
C = input.size(1) / 2,
H = input.size(2),
W = input.size(3),
kH = pos_add.size(1),
kW = pos_add.size(2);
TORCH_CHECK(kH % 2 == 1 && kW % 2 == 1, "Kernel height and width must be odd");
TORCH_CHECK(input.size(1) % 2 == 0, "Input must have even num-channels");
TORCH_CHECK(pos_add.size(0) == C && pos_mul.size(0) == C &&
pos_mul.size(1) == kH && pos_mul.size(2) == kW,
"Input sizes mismatch.");
TORCH_CHECK(pos_add.device() == input.device() &&
pos_mul.device() == pos_add.device(),
"Input devices mismatch");
auto scalar_t = input.scalar_type();
TORCH_CHECK(pos_add.scalar_type() == scalar_t &&
pos_mul.scalar_type() == scalar_t,
"Input dtypes mismatch");
TORCH_CHECK(grad_output.dim() == 4 && grad_output.size(0) == N
&& grad_output.size(1) == C && grad_output.size(2) == H
&& grad_output.size(3) == W);
torch::Tensor grad_input = torch::zeros({N, 2*C, H, W},
torch::TensorOptions().dtype(scalar_t).device(input.device())),
grad_pos_add = torch::zeros({C, kH, kW},
torch::TensorOptions().dtype(scalar_t).device(input.device())),
grad_pos_mul = torch::zeros({C, kH, kW},
torch::TensorOptions().dtype(scalar_t).device(input.device()));
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "integrated_conv_cpu_loop", ([&] {
auto input_a = input.accessor<scalar_t, 4>(),
grad_output_a = grad_output.accessor<scalar_t, 4>(),
grad_input_a = grad_input.accessor<scalar_t, 4>();
auto pos_add_a = pos_add.accessor<scalar_t, 3>(),
pos_mul_a = pos_mul.accessor<scalar_t, 3>(),
grad_pos_add_a = grad_pos_add.accessor<scalar_t, 3>(),
grad_pos_mul_a = grad_pos_mul.accessor<scalar_t, 3>();
for (int n = 0; n < N; n++) {
for (int c = 0; c < C; c++) {
for (int h = 0; h < H; h++) {
for (int w = 0; w < W; w++) {
scalar_t dest = input_a[n][c + C][h][w],
dest_grad = 0.0, // to be multiplied by this_grad_output later..
this_grad_output = grad_output_a[n][c][h][w];
for (int kh = 0; kh < kH; kh++) {
int src_h = h + kh - kH / 2;
for (int kw = 0; kw < kW; kw++) {
int src_w = w + kw - kW / 2;
scalar_t src = 0.0;
if (static_cast<unsigned int>(src_h) < static_cast<unsigned int>(H) &&
static_cast<unsigned int>(src_w) < static_cast<unsigned int>(W))
src = input_a[n][c][src_h][src_w];
scalar_t relu = src + dest + pos_add_a[c][kh][kw];
if (relu >= 0.0) {
scalar_t pos_mul_val = pos_mul_a[c][kh][kw];
dest_grad += pos_mul_val; // will later multiply by this_grad_output
grad_pos_add_a[c][kh][kw] += this_grad_output * pos_mul_val;
grad_pos_mul_a[c][kh][kw] += this_grad_output * relu;
if (static_cast<unsigned int>(src_h) < static_cast<unsigned int>(H) &&
static_cast<unsigned int>(src_w) < static_cast<unsigned int>(W))
grad_input_a[n][c][src_h][src_w] += this_grad_output * pos_mul_val;
}
}
}
grad_input_a[n][c + C][h][w] = dest_grad * this_grad_output;
}
}
}
}
}));
return std::vector<torch::Tensor>({grad_input, grad_pos_add, grad_pos_mul});
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("integrated_conv_cpu", &integrated_conv_cpu, "Integrated convolution forward function (CPU)");
m.def("integrated_conv_backward_cpu", &integrated_conv_backward_cpu, "Integrated convolution backward function (CPU)");
}
@@ -38,19 +38,19 @@ https://www.github.com/toshas/torch-discounted-cumsum
def configure_extensions():
out = [
CppExtension(
-'torch_integrated_conv_cpu',
+'torch_learned_nonlin_cpu',
[
-os.path.join('torch_integrated_conv', 'integrated_conv_cpu.cpp'),
+os.path.join('torch_learned_nonlin', 'learned_nonlin_cpu.cpp'),
],
)
]
try:
out.append(
CUDAExtension(
-'torch_integrated_conv_cuda',
+'torch_learned_nonlin_cuda',
[
-os.path.join('torch_integrated_conv', 'integrated_conv_cuda.cpp'),
-os.path.join('torch_integrated_conv', 'integrated_conv_cuda_kernel.cu'),
+os.path.join('torch_learned_nonlin', 'learned_nonlin_cuda.cpp'),
+os.path.join('torch_learned_nonlin', 'learned_nonlin_cuda_kernel.cu'),
],
)
)
@@ -60,7 +60,7 @@ def configure_extensions():
setup(
-name='torch_integrated_conv',
+name='torch_learned_nonlin',
version='1.0.2',
description='Fast discounted cumulative sums in PyTorch',
long_description=long_description,
......
from .learned_nonlin import learned_nonlin
import os
import torch
from typing import Tuple
from torch.utils.cpp_extension import load
VERBOSE = False
def _resolve(name):
return os.path.join(os.path.dirname(os.path.realpath(__file__)), name)
try:
import torch_learned_nonlin_cpu
except ImportError:
if VERBOSE:
print('Falling back to JIT compiling torch_learned_nonlin_cpu')
torch_learned_nonlin_cpu = load(
name='torch_learned_nonlin_cpu',
sources=[
_resolve('learned_nonlin_cpu.cpp'),
],
verbose=VERBOSE,
)
try:
import torch_learned_nonlin_cuda
except ImportError:
if VERBOSE:
print('Falling back to JIT compiling torch_learned_nonlin_cuda')
torch_learned_nonlin_cuda = None
if torch.cuda.is_available():
torch_learned_nonlin_cuda = load(
name='torch_learned_nonlin_cuda',
sources=[
_resolve('learned_nonlin_cuda.cpp'),
_resolve('learned_nonlin_cuda_kernel.cu'),
],
verbose=VERBOSE,
)
def _learned_nonlin_forward_dispatcher(input: torch.Tensor,
params: torch.Tensor) -> torch.Tensor:
if input.is_cuda:
if torch_learned_nonlin_cuda is None:
raise EnvironmentError('Failed to load native CUDA module')
return torch_learned_nonlin_cuda.learned_nonlin_cuda(
input, params.contiguous())
else:
return torch_learned_nonlin_cpu.learned_nonlin_cpu(
input, params)
def _learned_nonlin_backward_dispatcher(input: torch.Tensor,
params: torch.Tensor,
grad_output: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
if input.is_cuda:
if torch_learned_nonlin_cuda is None:
raise EnvironmentError('Failed to load native CUDA module')
return tuple(torch_learned_nonlin_cuda.learned_nonlin_backward_cuda(
input, params,
grad_output))
else:
return tuple(torch_learned_nonlin_cpu.learned_nonlin_backward_cpu(
input, params, grad_output))
class LearnedNonlinFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input: torch.Tensor, params: torch.Tensor) -> torch.Tensor:
output = _learned_nonlin_forward_dispatcher(input, params)
ctx.save_for_backward(input, params)
return output
@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
(input, params) = ctx.saved_tensors
grad_input, grad_params = _learned_nonlin_backward_dispatcher(
input, params, grad_output)
return grad_input, grad_params
def learned_nonlin(input, params, dim):
"""Learned nonlinearity.
Args:
input: The input, to be transformed pointwise; may be of any shape.
params: The parameters of the learned nonlinearity. Interpreted
as of shape (C, N + 1), where C is the number of channels and N, which
must be a power of 2, is the number of linear regions in the
piecewise linear function (so params.shape[1] must be a power of 2 plus 1).
The first element of each row is the log of the spacing between the
discontinuities, and the remaining N elements are the slopes of the
function in the linear pieces. We can explain what this function
is as follows:
Let the row of `params` for a particular channel be
interpreted as (l, d0, d1, d2 ... ). Let K = N/2, and L = exp(l).
Then the discontinuities in the function are at:
L * ( -K+1, -K+2, .., -1, 0, 1, .. K-1 )
and the values d0, d1 .. are interpreted as the slopes of the
function in the intervals, respectively:
(-inf, L*(-K+1)), [L*(-K+1), L*(-K+2)), ...
and we use these together with the assumption that the
function's value at x=0 is 0, to compute the function's value.
In terms of concrete calculations, we do it as follows:
Firstly, we can get rid of the factor of L by treating the l
parameter as a scale on the input and output, i.e.:
x = input * exp(-l)
... do the calculation y = f(x) without a scale, interpreting the
discontinuities as being at integer values -K+1, -K+2, ... K-1,
and then:
output = y * exp(l)
The core computation requires computing the y-values at the
discontinuities at -K+1, -K+2 and so on. Each one equals
the sign of the offset times the sum of the slopes 'd' for the
regions between that point and zero. If we number these
offsets o_0, o_1 and so on up to o_{N-2}, then the formula is:
for o_n with n < K, o_n = -sum(k = n+1..K-1) d_k
for o_n with n >= K, o_n = sum(k = K..n) d_k
e.g. if K=3 and (d0, d1, d2, d3, d4, d5) = (1, 2, 1, 2, 1, 1), then:
o_0 = -(d1+d2) = -3 # x=-2 maps to y=-3
o_1 = -(d2) = -1 # x=-1 maps to y=-1
o_2 = () = 0 # x=0 maps to y=0
o_3 = (d3) = 2 # x=1 maps to y=2
o_4 = (d3 + d4) = 3 # x=2 maps to y=3
dim: The dimension of `input` that corresponds to the channel. It is
recommended that the channel should not be the fastest-varying
dimension (the one with stride=1), because this will make
the data loads and stores be non-coalesced and the kernels
will be quite slow.
Return: output, of the same shape as `input`.
"""
if dim < 0:
dim += input.ndim
assert dim >= 0 and dim < input.ndim
assert params.ndim == 2 and params.shape[1] % 2 == 1
assert params.shape[0] == input.shape[dim]
orig_shape = list(input.shape)
# `new_shape` is `orig_shape` but modified so that the channel dim (`dim`)
# is dimension/axis 1. We do this not by transposing, but by combining
# adjacent dims.
a, b = 1, 1
for i in range(0, dim):
a *= orig_shape[i]
for i in range(dim + 1, len(orig_shape)):
b *= orig_shape[i]
new_shape = (a, orig_shape[dim], b)
input = input.reshape(new_shape) # `reshape` should make input contiguous if needed.
assert params.shape[0] == input.shape[1]
ans = LearnedNonlinFunction.apply(input, params)
return ans.reshape(orig_shape)
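# For clarity, a pure-PyTorch reference of the computation described in the
# docstring above (a slow sketch for testing against the compiled kernels;
# the name is illustrative, not part of the original API). It mirrors the
# C++ forward: input of shape (B, C, T), with the lowest, half-infinite
# piece anchored so the function is continuous at its leftmost breakpoint.
def _learned_nonlin_ref(input: torch.Tensor, params: torch.Tensor) -> torch.Tensor:
    B, C, T = input.shape
    N = params.shape[1] - 1
    K = N // 2
    scale = params[:, 0].exp().reshape(1, C, 1)
    d = params[:, 1:]  # (C, N): slopes of the linear pieces
    # x is the scaled input shifted so breakpoints fall at 1 .. 2K-1.
    x = input / scale + K
    n = x.floor().clamp(min=0, max=2 * K - 1).to(torch.long)  # piece index
    # y_vals[c, j]: function value at breakpoint j, built outward from f(0)=0.
    y_vals = torch.zeros(C, N, dtype=input.dtype, device=input.device)
    for j in range(K + 1, N):
        y_vals[:, j] = y_vals[:, j - 1] + d[:, j - 1]
    for j in range(K - 1, 0, -1):
        y_vals[:, j] = y_vals[:, j + 1] - d[:, j]
    y_vals[:, 0] = y_vals[:, 1] - d[:, 0]  # continuity of the lowest piece
    slope = d.reshape(1, C, N).expand(B, C, N).gather(2, n)
    anchor = y_vals.reshape(1, C, N).expand(B, C, N).gather(2, n)
    return ((x - n.to(x.dtype)) * slope + anchor) * scale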
#include <torch/extension.h>
// forward of learned_nonlin. See """... """ comment of `learned_nonlin` in
// learned_nonlin.py for documentation of the behavior of this function.
torch::Tensor learned_nonlin_cpu(torch::Tensor input,
torch::Tensor params) {
TORCH_CHECK(input.dim() == 3, "input must be 3-dimensional");
TORCH_CHECK(params.dim() == 2, "params must be 2-dimensional.");
TORCH_CHECK(params.size(1) >= 3 &&
((params.size(1) - 1) & (params.size(1) - 2)) == 0,
"params.size(1) has invalid value, must be a power of 2 plus 1.");
TORCH_CHECK(params.size(0) == input.size(1),
"params vs input channels mismatch");
TORCH_CHECK(input.device().is_cpu(), "Input must be a CPU tensor");
TORCH_CHECK(params.device().is_cpu(), "Params must be a CPU tensor");
const int B = input.size(0),
C = input.size(1),
T = input.size(2),
N = params.size(1) - 1,
K = N / 2;
auto scalar_t = input.scalar_type();
auto opts = torch::TensorOptions().dtype(scalar_t).device(input.device());
torch::Tensor y_vals = torch::empty({C, N}, opts),
output = torch::empty({B, C, T}, opts);
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "learned_nonlin_cpu_loop", ([&] {
auto params_a = params.accessor<scalar_t, 2>(),
y_vals_a = y_vals.accessor<scalar_t, 2>();
for (int c = 0; c < C; c++) {
scalar_t sum_negative = 0.0,
sum_positive = 0.0;
for (int i = 0; i < K; i++) {
y_vals_a[c][K + i] = sum_positive;
y_vals_a[c][K - i] = sum_negative;
sum_positive += params_a[c][1 + K + i];
sum_negative -= params_a[c][K - i];
}
// The value stored at index 0 anchors the lowest, half-infinite interval
// (the one ending at x=-(K-1)): we choose it so that the function is
// continuous there, i.e. the line with slope params[c][1] passes through
// (-(K-1), y_vals[c][1]).
y_vals_a[c][0] = y_vals_a[c][1] - params_a[c][1];
}
auto input_a = input.accessor<scalar_t, 3>(),
output_a = output.accessor<scalar_t, 3>();
for (int b = 0; b < B; b++) {
for (int c = 0; c < C; c++) {
scalar_t l = params_a[c][0],
scale = exp(l),
inv_scale = 1.0 / scale;
for (int t = 0; t < T; t++) {
// `x` is the scaled input x plus an offset so that -K maps to 0.
// Note: the discontinuities in our function are at -(K-1) ... +(K-1),
// so in a sense -K and +K are not special, but we include those
// extra values as an easy way to handle the semi-infinite regions
// that are < -(K-1) and > (K-1)
scalar_t x = input_a[b][c][t] * inv_scale + K,
y;
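// Binary search for the linear piece containing x: since K is a power of
// 2, halving `diff` from K down to 1 leaves `min` as the index of the
// piece, in [0, 2*K - 1] (pieces 0 and 2K-1 are the half-infinite ones).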
int min = 0, diff = K;
while (diff > 0) {
int mid = min + diff;
if (x >= mid)
min = mid;
diff = diff >> 1;
}
// OK, at this point, 0 <= min < 2*K.
y = (x - (scalar_t)min) * params_a[c][min + 1] + y_vals_a[c][min];
// printf("x = %f, y = %f, min = %d; y = (%f - %d) * %f+ %f\n", x, y, min,
// x, min, params_a[c][min + 1], y_vals_a[c][min - 1]);
output_a[b][c][t] = y * scale;
}
}
}}));
return output;
}
// backward of learned_nonlin. Returns (input_grad, params_grad)
std::vector<torch::Tensor> learned_nonlin_backward_cpu(torch::Tensor input,
torch::Tensor params,
torch::Tensor output_grad) {
TORCH_CHECK(input.dim() == 3, "input must be 3-dimensional");
TORCH_CHECK(params.dim() == 2, "params must be 2-dimensional.");
TORCH_CHECK(params.size(1) >= 3 &&
((params.size(1) - 1) & (params.size(1) - 2)) == 0,
"params.size(1) has invalid value, must be a power of 2 plus 1.");
TORCH_CHECK(params.size(0) == input.size(1),
"params vs input channels mismatch");
TORCH_CHECK(input.sizes() == output_grad.sizes(),
"Output-grad vs. input sizes mismatch.");
TORCH_CHECK(input.device().is_cpu(), "Input must be a CPU tensor");
TORCH_CHECK(params.device().is_cpu(), "Params must be a CPU tensor");
TORCH_CHECK(output_grad.device().is_cpu(), "Output-grad must be a CPU tensor");
const int B = input.size(0),
C = input.size(1),
T = input.size(2),
N = params.size(1) - 1,
K = N / 2;
auto scalar_t = input.scalar_type();
auto opts = torch::TensorOptions().dtype(scalar_t).device(input.device());
torch::Tensor y_vals = torch::empty({C, N}, opts),
y_vals_grad = torch::zeros({C, N}, opts),
params_grad = torch::zeros({C, N + 1}, opts),
input_grad = torch::zeros({B, C, T}, opts);
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "learned_nonlin_backward_cpu_loop", ([&] {
auto params_a = params.accessor<scalar_t, 2>(),
params_grad_a = params_grad.accessor<scalar_t, 2>(),
y_vals_a = y_vals.accessor<scalar_t, 2>(),
y_vals_grad_a = y_vals_grad.accessor<scalar_t, 2>();
for (int c = 0; c < C; c++) {
scalar_t sum_negative = 0.0,
sum_positive = 0.0;
for (int i = 0; i < K; i++) {
y_vals_a[c][K - 1 + i] = sum_positive;
y_vals_a[c][K - 1 - i] = sum_negative;
sum_positive += params_a[c][1 + K + i];
sum_negative -= params_a[c][K - i];
}
}
auto input_a = input.accessor<scalar_t, 3>(),
output_grad_a = output_grad.accessor<scalar_t, 3>(),
input_grad_a = input_grad.accessor<scalar_t, 3>();
for (int b = 0; b < B; b++) {
for (int c = 0; c < C; c++) {
scalar_t l = params_a[c][0],
scale = exp(l),
inv_scale = 1.0 / scale,
scale_grad = 0.0,
inv_scale_grad = 0.0;
for (int t = 0; t < T; t++) {
// `x` is the scaled input x plus an offset so that -K maps to 0.
// Note: the discontinuities in our function are at -(K-1) ... +(K-1),
// so in a sense -K and +K are not special, but we include those
// extra values as an easy way to handle the semi-infinite regions
// that are < -(K-1) and > (K-1)
scalar_t x = input_a[b][c][t] * inv_scale + K,
output_grad = output_grad_a[b][c][t],
x_grad,
y;
int min = 0, diff = K;
while (diff > 0) {
int mid = min + diff;
if (x >= mid)
min = mid;
diff = diff >> 1;
}
// OK, at this point, 0 <= min < 2*K.
// The "+ 1" is to get (input_a[b][c][t] * inv_scale) - (-(K+1))
if (min == 0) {
y = (x + 1) * params_a[c][1] + y_vals_a[c][0];
// output_a[b][c][t] = y * scale;
scale_grad += y * output_grad;
scalar_t y_grad = scale * output_grad;
x_grad = y_grad * params_a[c][1];
//y_vals_grad_a[c][0] +=
} else {
y = (x - (scalar_t)min) * params_a[c][min + 1] + y_vals_a[c][min - 1];
// printf("x = %f, y = %f, min = %d; y = (%f - %d) * %f+ %f\n", x, y, min,
// x, min, params_a[c][min + 1], y_vals_a[c][min - 1]);
}
//output_a[b][c][t] = y * scale;
}
}
}}));
// TODO: the gradient computation above is incomplete; for now we return
// the partially computed gradients so that this compiles and matches the
// documented return signature.
return std::vector<torch::Tensor>({input_grad, params_grad});
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("learned_nonlin_cpu", &learned_nonlin_cpu, "Integrated convolution forward function (CPU)");
m.def("learned_nonlin_backward_cpu", &learned_nonlin_backward_cpu, "Integrated convolution backward function (CPU)");
}
#include <torch/extension.h>
-// forward of integrated_conv. """... """ comment of `integrated_conv`
-// in integrated_conv.py documents the behavior of this function.
-torch::Tensor integrated_conv_cuda(torch::Tensor input,
+// forward of learned_nonlin. """... """ comment of `learned_nonlin`
+// in learned_nonlin.py documents the behavior of this function.
+torch::Tensor learned_nonlin_cuda(torch::Tensor input,
torch::Tensor pos_add,
torch::Tensor pos_mul);
-// backward of integrated_conv; returns (grad_input, grad_pos_add, grad_pos_mul).
-std::vector<torch::Tensor> integrated_conv_backward_cuda(torch::Tensor input,
+// backward of learned_nonlin; returns (grad_input, grad_pos_add, grad_pos_mul).
+std::vector<torch::Tensor> learned_nonlin_backward_cuda(torch::Tensor input,
torch::Tensor pos_add,
torch::Tensor pos_mul,
torch::Tensor grad_output);
@@ -16,6 +16,6 @@ std::vector<torch::Tensor> integrated_conv_backward_cuda(torch::Tensor input,
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("integrated_conv_cuda", &integrated_conv_cuda, "Integrated convolution forward function (CUDA)");
m.def("integrated_conv_backward_cuda", &integrated_conv_backward_cuda, "Integrated convolution backward function (CUDA)");
m.def("learned_nonlin_cuda", &learned_nonlin_cuda, "Integrated convolution forward function (CUDA)");
m.def("learned_nonlin_backward_cuda", &learned_nonlin_backward_cuda, "Integrated convolution backward function (CUDA)");
}
@@ -40,7 +40,7 @@ __forceinline__ __device__ scalar_t tiled_warp_reduce_sum(int threads_per_tile,
}
/*
-Forward of integrated_conv. Each thread group handles a single channel (equal
+Forward of learned_nonlin. Each thread group handles a single channel (equal
to blockIdx.x), and loops over patches of the output and over the image n
within the batch (different thread groups may be responsible for different
subsets of patches and/or images, see docs of gridDim below).
@@ -67,7 +67,7 @@ __forceinline__ __device__ scalar_t tiled_warp_reduce_sum(int threads_per_tile,
gridDim.y <= num-patches per image (recommended)
gridDim.z <= batch-size N (recommended)
When we invoke this kernel, we'll invoke it as:
-integrated_conv_forward<<<gridDim, blockDim, bytesShared, stream>>>
+learned_nonlin_forward<<<gridDim, blockDim, bytesShared, stream>>>
where bytesShared is the number of bytes needed in `extern_buf`:
bytesShared = sizeof(shared_t) * numel, where
numel = 2 * (kH * kW) + max(blockDim.x, (opatchH + kH - 1) * (patchW + kW - 1))
@@ -76,7 +76,7 @@ extern __shared__ int extern_buf[];
template <typename scalar_t>
__global__
-void integrated_conv_kernel(
+void learned_nonlin_kernel(
torch::PackedTensorAccessor32<scalar_t, 4> input, // N, 2*C, H, W
torch::PackedTensorAccessor32<scalar_t, 3> pos_add, // C, kH, kW
torch::PackedTensorAccessor32<scalar_t, 3> pos_mul, // C, kH, kW
@@ -225,7 +225,7 @@ void integrated_conv_kernel(
/*
-Backward of integrated_conv. Each thread group handles a single channel (equal
+Backward of learned_nonlin. Each thread group handles a single channel (equal
to blockIdx.x), and loops over patches of the output and over the image n
within the batch (different thread groups may be responsible for different
subsets of patches and/or images, see docs of gridDim below).
@@ -290,7 +290,7 @@ void integrated_conv_kernel(
gridDim.y <= num-patches per image (recommended)
gridDim.z <= batch-size N (recommended)
When we invoke this kernel, we'll invoke it as:
-integrated_conv_forward<<<gridDim, blockDim, bytesShared, stream>>>
+learned_nonlin_forward<<<gridDim, blockDim, bytesShared, stream>>>
where bytesShared is the number of bytes needed in `extern_buf`:
bytesShared = sizeof(shared_t) * numel, where
@@ -300,7 +300,7 @@ void integrated_conv_kernel(
template <typename scalar_t>
__global__
-void integrated_conv_kernel_backward(
+void learned_nonlin_kernel_backward(
torch::PackedTensorAccessor32<scalar_t, 4> input, // N, 2*C, H, W
torch::PackedTensorAccessor32<scalar_t, 3> pos_add, // C, kH, kW
torch::PackedTensorAccessor32<scalar_t, 3> pos_mul, // C, kH, kW
@@ -581,7 +581,7 @@ void integrated_conv_kernel_backward(
-torch::Tensor integrated_conv_cuda(torch::Tensor input,
+torch::Tensor learned_nonlin_cuda(torch::Tensor input,
torch::Tensor pos_add,
torch::Tensor pos_mul) {
TORCH_CHECK(input.dim() == 4, "input must be 4-dimensional");
@@ -665,7 +665,9 @@ torch::Tensor integrated_conv_cuda(torch::Tensor input,
assert(num_blocks_patch <= num_patches && num_blocks_batch <= N);
-#if 0
+static int debug_count = 50;
+if (debug_count > 0) {
+debug_count--;
std::cout << "N,C,H,W=" << N << "," << C << "," << H << "," << W
<< "; kW,kH=" << kW << "," << kH
<< "; patchH,patchW=" << patchH << ","
@@ -675,12 +677,12 @@ torch::Tensor integrated_conv_cuda(torch::Tensor input,
<< ", threads_per_opixel=" << threads_per_opixel
<< ", threads_per_block=" << threads_per_block
<< std::endl;
-#endif
+}
dim3 gridDim(C, num_blocks_patch, num_blocks_batch);
// blockDim is scalar, just threads_per_block.
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "integrated_conv_kernel", ([&] {
integrated_conv_kernel<scalar_t><<<gridDim, threads_per_block, sizeof(scalar_t) * buffer_numel, at::cuda::getCurrentCUDAStream()>>>(
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "learned_nonlin_kernel", ([&] {
learned_nonlin_kernel<scalar_t><<<gridDim, threads_per_block, sizeof(scalar_t) * buffer_numel, at::cuda::getCurrentCUDAStream()>>>(
input.packed_accessor32<scalar_t, 4>(),
pos_add.packed_accessor32<scalar_t, 3>(),
pos_mul.packed_accessor32<scalar_t, 3>(),
@@ -693,7 +695,7 @@ torch::Tensor integrated_conv_cuda(torch::Tensor input,
-std::vector<torch::Tensor> integrated_conv_backward_cuda(torch::Tensor input,
+std::vector<torch::Tensor> learned_nonlin_backward_cuda(torch::Tensor input,
torch::Tensor pos_add,
torch::Tensor pos_mul,
torch::Tensor grad_output) {
@@ -807,7 +809,9 @@ std::vector<torch::Tensor> integrated_conv_backward_cuda(torch::Tensor input,
assert(patchH * patchW * threads_per_pixel <= threads_per_block);
assert(kH * kW * threads_per_kernel_pos <= threads_per_block);
-#if 0
+static int debug_count = 50;
+if (debug_count > 0) {
+debug_count--;
std::cout << "[backward:] N,C,H,W=" << N << "," << C << "," << H << "," << W
<< "; kW,kH=" << kW << "," << kH
<< "; patchH,patchW=" << patchH << ","
@@ -819,7 +823,7 @@ std::vector<torch::Tensor> integrated_conv_backward_cuda(torch::Tensor input,
<< ", threads_per_block=" << threads_per_block
<< ", buffer_numel=" << buffer_numel
<< std::endl;
-#endif
+}
int num_blocks = num_blocks_patch * num_blocks_batch;
@@ -833,8 +837,8 @@ std::vector<torch::Tensor> integrated_conv_backward_cuda(torch::Tensor input,
dim3 gridDim(C, num_blocks_patch, num_blocks_batch);
// blockDim is scalar, just threads_per_block.
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "integrated_conv_kernel_backward", ([&] {
integrated_conv_kernel_backward<scalar_t><<<gridDim, threads_per_block,
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "learned_nonlin_kernel_backward", ([&] {
learned_nonlin_kernel_backward<scalar_t><<<gridDim, threads_per_block,
sizeof(scalar_t) * buffer_numel,
at::cuda::getCurrentCUDAStream()>>>(
input.packed_accessor32<scalar_t, 4>(),
......
import random
import torch
-from torch_integrated_conv import integrated_conv
+from torch_learned_nonlin import learned_nonlin
-def test_integrated_conv_zeros():
+def test_learned_nonlin_basic():
+for dtype in [torch.float32, torch.float64]:
+B = 2
+C = 4
+T = 10
+x = -2.0 + 0.4 * torch.arange(10, dtype=dtype)
+x = x.reshape(1, 1, 10).repeat(B, C, 1)
+K = 4
+N = K * 2
+params = torch.arange(N + 1, dtype=dtype).unsqueeze(0) + torch.arange(C, dtype=dtype).unsqueeze(1)
+print("x = ", x)
+print("params = ", params)
+print("x.shape = ", x.shape)
+y = learned_nonlin(x, params, dim = 1)
+print("y = ", y)
+def test_learned_nonlin_zeros():
N = 1
C = 2
H = 3
@@ -24,7 +44,7 @@ def test_integrated_conv_zeros():
pos_mul.requires_grad = True
output_ref = torch.zeros(N, C, H, W, device=device, dtype=dtype)
-output = integrated_conv(input, pos_add, pos_mul)
+output = learned_nonlin(input, pos_add, pos_mul)
assert torch.allclose(output, output_ref)
output.sum().backward()
@@ -33,7 +53,7 @@ def test_integrated_conv_zeros():
print("pos_mul_grad=", pos_mul.grad)
-def test_integrated_conv_compare():
+def test_learned_nonlin_compare():
N = 1
C = 2
H = 3
@@ -58,8 +78,8 @@ def test_integrated_conv_compare():
for x in [ pos_add, pos_mul, pos_add_cuda, pos_mul_cuda, input, input_cuda ]:
x.requires_grad = True
-output = integrated_conv(input, pos_add, pos_mul)
-output_cuda = integrated_conv(input_cuda, pos_add_cuda, pos_mul_cuda)
+output = learned_nonlin(input, pos_add, pos_mul)
+output_cuda = learned_nonlin(input_cuda, pos_add_cuda, pos_mul_cuda)
print("output = ", output)
print("output_cuda = ", output_cuda)
@@ -89,7 +109,7 @@ def test_integrated_conv_compare():
-def test_integrated_conv_rand_compare():
+def test_learned_nonlin_rand_compare():
for _ in range(30):
N = random.randint(1, 256)
C = random.randint(1, 64)
@@ -127,8 +147,8 @@ def test_integrated_conv_rand_compare():
pos_add_cuda = pos_add.to(device)
pos_mul_cuda = pos_mul.to(device)
-output = integrated_conv(input, pos_add, pos_mul)
-output_cuda = integrated_conv(input_cuda, pos_add_cuda, pos_mul_cuda)
+output = learned_nonlin(input, pos_add, pos_mul)
+output_cuda = learned_nonlin(input_cuda, pos_add_cuda, pos_mul_cuda)
diff = (output - output_cuda.to(torch.device('cpu'))).abs().sum()
sum_abs = output.abs().sum()
@@ -141,7 +161,7 @@ def test_integrated_conv_rand_compare():
-def test_integrated_conv_rand_grad():
+def test_learned_nonlin_rand_grad():
for _ in range(30):
N = random.randint(1, 256)
C = random.randint(1, 64)
@@ -179,7 +199,7 @@ def test_integrated_conv_rand_grad():
pos_add.requires_grad = True
pos_mul.requires_grad = True
-output = integrated_conv(input, pos_add, pos_mul)
+output = learned_nonlin(input, pos_add, pos_mul)
output_grad = torch.randn(N, C, H, W, dtype=dtype, device=device)
output.backward(gradient=output_grad)
@@ -187,24 +207,26 @@ def test_integrated_conv_rand_grad():
delta = 1.0e-05
pos_delta = delta * torch.randn(C, kH, kW, dtype=dtype, device=device)
pred_change = (pos_delta * pos_add.grad).sum().to('cpu').item()
-change = (output_grad * (integrated_conv(input, pos_add + pos_delta, pos_mul) - output )).sum().to('cpu').item()
+change = (output_grad * (learned_nonlin(input, pos_add + pos_delta, pos_mul) - output )).sum().to('cpu').item()
print(f"For pos_add: pred_change={pred_change}, change={change}")
#assert abs(pred_change - change) < 1.0e-04
pred_change = (pos_delta * pos_mul.grad).sum().to('cpu').item()
-change = (output_grad * (integrated_conv(input, pos_add, pos_mul + pos_delta) - output )).sum().to('cpu').item()
+change = (output_grad * (learned_nonlin(input, pos_add, pos_mul + pos_delta) - output )).sum().to('cpu').item()
print(f"For pos_mul: pred_change={pred_change}, change={change}")
#assert abs(pred_change - change) / abs(change) < 1.0e-04
input_delta = delta * torch.randn(N, 2*C, H, W, dtype=dtype, device=device)
pred_change = (input_delta * input.grad).sum().to('cpu').item()
-change = (output_grad * (integrated_conv(input + input_delta, pos_add, pos_mul) - output )).sum().to('cpu').item()
+change = (output_grad * (learned_nonlin(input + input_delta, pos_add, pos_mul) - output )).sum().to('cpu').item()
print(f"For input: pred_change={pred_change}, change={change}")
#assert abs(pred_change - change) / abs(change) < 1.0e-04
if __name__ == "__main__":
-test_integrated_conv_rand_grad()
-test_integrated_conv_zeros()
-test_integrated_conv_compare()
-test_integrated_conv_rand_compare()
+test_learned_nonlin_basic()
+if False:
+test_learned_nonlin_rand_grad()
+test_learned_nonlin_zeros()
+test_learned_nonlin_compare()
+test_learned_nonlin_rand_compare()