Commit c905d890 authored by Daniel Povey

First version.. only forward completed, not compiled.

parent 126d977f
......@@ -9,15 +9,15 @@ with open('requirements.txt') as f:
long_description = """
This package implements an efficient parallel algorithm for the computation of discounted cumulative sums
with differentiable bindings to PyTorch. The `cumsum` operation is frequently seen in data science
domains concerned with time series, including Reinforcement Learning (RL).
The traditional sequential algorithm performs the computation of the output elements in a loop. For an input of size
`N`, it requires `O(N)` operations and takes `O(N)` time steps to complete.
The proposed parallel algorithm requires a total of `O(N log N)` operations, but takes only `O(log N)` time, which is a
considerable trade-off in many applications involving large inputs.
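For reference, the sequential recurrence being parallelized can be written in a few lines (a naive sketch only; `discounted_cumsum_right_naive` is illustrative and assumes the convention y[t] = x[t] + gamma * y[t+1]):

import torch

def discounted_cumsum_right_naive(x: torch.Tensor, gamma: float) -> torch.Tensor:
    # Sequential O(N) reference: y[t] = x[t] + gamma * y[t + 1], i.e. each output
    # element is the discounted sum of the current and all later input elements.
    y = x.clone()
    for t in range(x.shape[-1] - 2, -1, -1):
        y[..., t] += gamma * y[..., t + 1]
    return y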
Features of the parallel algorithm:
- Speed logarithmic in the input size
......@@ -38,19 +38,19 @@ https://www.github.com/toshas/torch-discounted-cumsum
def configure_extensions():
out = [
CppExtension(
'torch_discounted_cumsum_cpu',
'torch_integrated_conv_cpu',
[
os.path.join('torch_discounted_cumsum', 'discounted_cumsum_cpu.cpp'),
os.path.join('torch_integrated_conv', 'integrated_conv_cpu.cpp'),
],
)
]
try:
out.append(
CUDAExtension(
'torch_discounted_cumsum_cuda',
'torch_integrated_conv_cuda',
[
os.path.join('torch_discounted_cumsum', 'discounted_cumsum_cuda.cpp'),
os.path.join('torch_discounted_cumsum', 'discounted_cumsum_cuda_kernel.cu'),
os.path.join('torch_integrated_conv', 'integrated_conv_cuda.cpp'),
os.path.join('torch_integrated_conv', 'integrated_conv_cuda_kernel.cu'),
],
)
)
......@@ -60,7 +60,7 @@ def configure_extensions():
setup(
name='torch_discounted_cumsum',
name='torch_integrated_conv',
version='1.0.2',
description='Fast discounted cumulative sums in PyTorch',
long_description=long_description,
......
import os
import torch
from torch.utils.cpp_extension import load
VERBOSE = False
def _resolve(name):
return os.path.join(os.path.dirname(os.path.realpath(__file__)), name)
try:
import torch_discounted_cumsum_cpu
except ImportError:
if VERBOSE:
print('Falling back to JIT compiling torch_discounted_cumsum_cpu')
torch_discounted_cumsum_cpu = load(
name='torch_discounted_cumsum_cpu',
sources=[
_resolve('discounted_cumsum_cpu.cpp'),
],
verbose=VERBOSE,
)
try:
import torch_discounted_cumsum_cuda
except ImportError:
if VERBOSE:
print('Falling back to JIT compiling torch_discounted_cumsum_cuda')
torch_discounted_cumsum_cuda = None
if torch.cuda.is_available():
torch_discounted_cumsum_cuda = load(
name='torch_discounted_cumsum_cuda',
sources=[
_resolve('discounted_cumsum_cuda.cpp'),
_resolve('discounted_cumsum_cuda_kernel.cu'),
],
verbose=VERBOSE,
)
def _discounted_cumsum_left_dispatcher(input, gamma):
if not torch.is_tensor(input):
raise ValueError('Input must be a torch.Tensor')
if input.is_cuda:
if torch_discounted_cumsum_cuda is None:
raise EnvironmentError(f'Failed to load native CUDA module')
return torch_discounted_cumsum_cuda.discounted_cumsum_left_cuda(input.contiguous(), gamma)
else:
return torch_discounted_cumsum_cpu.discounted_cumsum_left_cpu(input, gamma)
def _discounted_cumsum_right_dispatcher(input, gamma):
if not torch.is_tensor(input):
raise ValueError('Input must be a torch.Tensor')
if input.is_cuda:
if torch_discounted_cumsum_cuda is None:
raise EnvironmentError(f'Failed to load native CUDA module')
return torch_discounted_cumsum_cuda.discounted_cumsum_right_cuda(input.contiguous(), gamma)
else:
return torch_discounted_cumsum_cpu.discounted_cumsum_right_cpu(input, gamma)
class DiscountedCumSumLeftFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input, gamma):
output = _discounted_cumsum_left_dispatcher(input, gamma)
ctx.save_for_backward(torch.tensor(gamma))
return output
@staticmethod
def backward(ctx, grad_output):
gamma = ctx.saved_tensors[0].item()
grad_input = _discounted_cumsum_right_dispatcher(grad_output, gamma)
return grad_input, None
class DiscountedCumSumRightFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input, gamma):
output = _discounted_cumsum_right_dispatcher(input, gamma)
ctx.save_for_backward(torch.tensor(gamma))
return output
@staticmethod
def backward(ctx, grad_output):
gamma = ctx.saved_tensors[0].item()
grad_input = _discounted_cumsum_left_dispatcher(grad_output, gamma)
return grad_input, None
def discounted_cumsum_left(input, gamma):
return DiscountedCumSumLeftFunction.apply(input, gamma)
def discounted_cumsum_right(input, gamma):
return DiscountedCumSumRightFunction.apply(input, gamma)
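A quick usage sketch of the bindings above (assuming the package is importable as `torch_discounted_cumsum`; shapes follow the 2-D (batch, length) layout the kernels expect, and the backward of the right-facing sum is dispatched to the left-facing sum with the same gamma):

import torch
from torch_discounted_cumsum import discounted_cumsum_right

x = torch.randn(4, 16, dtype=torch.double, requires_grad=True)
y = discounted_cumsum_right(x, 0.99)  # shape (4, 16): discounted sums toward the right
y.sum().backward()                    # gradient computed via the left-facing sum above
print(x.grad.shape)                   # torch.Size([4, 16])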
#include <torch/extension.h>
torch::Tensor discounted_cumsum_left_cuda(torch::Tensor x, double gamma);
torch::Tensor discounted_cumsum_right_cuda(torch::Tensor x, double gamma);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("discounted_cumsum_left_cuda", &discounted_cumsum_left_cuda, "Discounted Cumulative Sums CUDA (Left)");
m.def("discounted_cumsum_right_cuda", &discounted_cumsum_right_cuda, "Discounted Cumulative Sums CUDA (Right)");
}
#include <torch/extension.h>
enum SumDirection {
SUM_DIRECTION_LEFT,
SUM_DIRECTION_RIGHT,
};
template <SumDirection sum_direction>
__device__ __forceinline__
void resolve_positions(
const int &stride_prev_group, const int &stride_cur_group, const int &group_of_thread, const int &thread_in_group,
int &change_pos, int &discounted_pos, int &discount_power
);
template <>
__device__ __forceinline__
void resolve_positions<SUM_DIRECTION_LEFT>(
const int &stride_prev_group, const int &stride_cur_group, const int &group_of_thread, const int &thread_in_group,
int &change_pos, int &discounted_pos, int &discount_power
) {
change_pos = group_of_thread * stride_cur_group + thread_in_group + stride_prev_group;
discounted_pos = group_of_thread * stride_cur_group + stride_prev_group - 1;
discount_power = thread_in_group + 1;
}
template <>
__device__ __forceinline__
void resolve_positions<SUM_DIRECTION_RIGHT>(
const int &stride_prev_group, const int &stride_cur_group, const int &group_of_thread, const int &thread_in_group,
int &change_pos, int &discounted_pos, int &discount_power
) {
change_pos = group_of_thread * stride_cur_group + thread_in_group;
discounted_pos = group_of_thread * stride_cur_group + stride_prev_group;
discount_power = stride_prev_group - thread_in_group;
}
template <typename scalar_t>
__device__ __forceinline__
scalar_t discounted_sum_power(scalar_t a, scalar_t b, scalar_t gamma, int power) {
return a + b * pow(gamma, scalar_t(power));
}
template <typename scalar_t, SumDirection sum_direction>
__global__
void discounted_cumsum_kernel_stage(
torch::PackedTensorAccessor32<scalar_t, 2> x,
const scalar_t gamma,
int stage
) {
const int len = x.size(1);
const int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
const int thread_idy = blockIdx.y * blockDim.y + threadIdx.y;
if (thread_idy >= x.size(0)) {
return;
}
int stride_prev_group = 1 << stage;
int stride_cur_group = stride_prev_group << 1;
int group_of_thread = thread_idx >> stage;
int thread_in_group = thread_idx - (group_of_thread << stage);
int change_pos, discounted_pos, discount_power;
resolve_positions<sum_direction>(
stride_prev_group, stride_cur_group, group_of_thread, thread_in_group,
change_pos, discounted_pos, discount_power
);
if (change_pos >= len || discounted_pos >= len) {
return;
}
x[thread_idy][change_pos] = discounted_sum_power(
x[thread_idy][change_pos],
x[thread_idy][discounted_pos],
gamma,
discount_power
);
}
inline
int log2ceil(int x) {
return (int)ceil(log2((float)x));
}
template <SumDirection sum_direction>
torch::Tensor discounted_cumsum(torch::Tensor x, double gamma) {
// Minimum required number of threads, assigns them dynamically to respective positions upon each iteration.
// Results in uncoalesced writes, which is still faster than coalesced writes with half threads idling.
TORCH_CHECK(x.device().is_cuda(), "Input must be a CUDA tensor");
TORCH_CHECK(x.is_contiguous(), "Input must be contiguous");
TORCH_CHECK(x.dim() == 2, "Input must be 2-dimensional");
TORCH_CHECK(0.0 <= gamma && gamma <= 1.0, "Gamma must be in the range [0,1]");
if (x.size(1) == 0) {
return x;
}
auto y = x.clone();
const int threads = 64;
const int nstages = log2ceil(x.size(1));
const int threads_total_x = 1 << (nstages - 1);
const dim3 blocks((threads_total_x + threads - 1) / threads, x.size(0));
for (int stage=0; stage<nstages; stage++) {
AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "discounted_cumsum_kernel_stage", ([&] {
discounted_cumsum_kernel_stage<scalar_t, sum_direction><<<blocks, threads>>>(
y.packed_accessor32<scalar_t, 2>(),
scalar_t(gamma),
stage
);
}));
}
return y;
}
torch::Tensor discounted_cumsum_left_cuda(torch::Tensor x, double gamma) {
return discounted_cumsum<SUM_DIRECTION_LEFT>(x, gamma);
}
torch::Tensor discounted_cumsum_right_cuda(torch::Tensor x, double gamma) {
return discounted_cumsum<SUM_DIRECTION_RIGHT>(x, gamma);
}
import os
import torch
from typing import Tuple
from torch.utils.cpp_extension import load
VERBOSE = False
def _resolve(name):
return os.path.join(os.path.dirname(os.path.realpath(__file__)), name)
try:
import torch_integrated_conv_cpu
except ImportError:
if VERBOSE:
print('Falling back to JIT compiling torch_integrated_conv_cpu')
torch_integrated_conv_cpu = load(
name='torch_integrated_conv_cpu',
sources=[
_resolve('integrated_conv_cpu.cpp'),
],
verbose=VERBOSE,
)
try:
import torch_integrated_conv_cuda
except ImportError:
if VERBOSE:
print('Falling back to JIT compiling torch_integrated_conv_cuda')
torch_integrated_conv_cuda = None
if torch.cuda.is_available():
torch_integrated_conv_cuda = load(
name='torch_integrated_conv_cuda',
sources=[
_resolve('integrated_conv_cuda.cpp'),
_resolve('integrated_conv_cuda_kernel.cu'),
],
verbose=VERBOSE,
)
def _integrated_conv_forward_dispatcher(input: torch.Tensor,
                                        pos_add: torch.Tensor,
                                        pos_mul: torch.Tensor) -> torch.Tensor:
    if input.is_cuda:
        if torch_integrated_conv_cuda is None:
            raise EnvironmentError('Failed to load native CUDA module')
        return torch_integrated_conv_cuda.integrated_conv_cuda(
            input.contiguous(), pos_add.contiguous(), pos_mul.contiguous())
    else:
        return torch_integrated_conv_cpu.integrated_conv_cpu(
            input, pos_add, pos_mul)
def _integrated_conv_backward_dispatcher(input: torch.Tensor,
                                         pos_add: torch.Tensor,
                                         pos_mul: torch.Tensor,
                                         grad_output: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    if input.is_cuda:
        if torch_integrated_conv_cuda is None:
            raise EnvironmentError('Failed to load native CUDA module')
        return tuple(torch_integrated_conv_cuda.integrated_conv_backward_cuda(
            input.contiguous(), pos_add.contiguous(), pos_mul.contiguous(), grad_output.contiguous()))
    else:
        return tuple(torch_integrated_conv_cpu.integrated_conv_backward_cpu(
            input, pos_add, pos_mul, grad_output))
class IntegratedConvFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input: torch.Tensor, pos_add: torch.Tensor, pos_mul: torch.Tensor) -> torch.Tensor:
output = _integrated_conv_forward_dispatcher(input, pos_add, pos_mul)
ctx.save_for_backward(input, pos_add, pos_mul)
return output
@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
(input, pos_add, pos_mul) = ctx.saved_tensors
grad_input, grad_pos_add, grad_pos_mul = _integrated_conv_backward_dispatcher(
input, pos_add, pos_mul, grad_output)
return grad_input, grad_pos_add, grad_pos_mul
def integrated_conv(input, pos_add, pos_mul):
"""Integrated convolution.
Args:
input: The input of shape (N, 2*C, W) for 1-d convolution or (N, 2*C, H, W)
for 2-d convolution, where
N is the batch size, C is the number of output channels, and H and W are
the input image's height and width respectively. The input channels are
of two types, "src" and "dest" respectively, meaning whether they relate
to the source or destination image position; all the "src" channels come
first, then the "dest" channels.
pos_add: Positional encoding: the additive part of the convolution kernel.
This is of shape (C, kW) for 1-d
convolution or (C, kH, kW) for 2-d convolution,
where C is the number of channels and kH and kW are the kernel height and
kernel width. Kernel height and width must be odd (we assume zero padding
so the output size is the same as the input size).
pos_mul: Positional encoding: the multiplicative part of the convolution kernel.
This is of shape (C, kW)
for 1-d convolution or (C, kH, kW) for 2-d convolution, where C
is the number of channels and kH and kW are the kernel height and
kernel width.
Return: output, of shape (N, C, W) for 1-d convolution or (N, C, H, W) for
      2-d convolution. In the 2-d case the output will satisfy:
         output[n, c, h, w] = \sum_{kh=0}^{kH-1} \sum_{kw=0}^{kW-1}
              pos_mul[c, kh, kw] * relu(input[n, c + C, h, w] + input_padded[n, c, h + kh, w + kw] + pos_add[c, kh, kw])
      where input_padded is torch.nn.functional.pad(input, (kW//2, kW//2, kH//2, kH//2)),
meaning zero-padding (this is done implicitly by the implementation).
(Technically this is more closely related to cross-correlation than to
convolution).
"""
if input.ndim == 3:
assert pos_add.ndim == 2 and pos_mul.ndim == 2
# For now we choose to handle only the 2-dimensional case directly. The
# 1-dimensional one is treated as a special case of the 2-dimensional one.
# Actually we could unsqueeze with -2 or -1 here, as the height and width
# behave the same.
return integrated_conv(input.unsqueeze(-2),
pos_add.unsqueeze(-2), pos_mul.unsqueeze(-2)).squeeze(-2)
assert input.ndim == 4 and pos_add.ndim == 3 and pos_mul.ndim == 3
    assert input.shape[1] // 2 == pos_add.shape[0] == pos_mul.shape[0]
return IntegratedConvFunction.apply(input, pos_add, pos_mul)
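For testing the extension, a naive reference of the 2-d forward described in the docstring can be written directly in PyTorch (a sketch only; `integrated_conv_reference` is not part of the package):

import torch

def integrated_conv_reference(input, pos_add, pos_mul):
    # Direct O(N*C*H*W*kH*kW) implementation of the docstring formula, for testing only.
    N, twoC, H, W = input.shape
    C = twoC // 2
    _, kH, kW = pos_add.shape
    src, dest = input[:, :C], input[:, C:]
    src_padded = torch.nn.functional.pad(src, (kW // 2, kW // 2, kH // 2, kH // 2))
    output = torch.zeros(N, C, H, W, dtype=input.dtype)
    for kh in range(kH):
        for kw in range(kW):
            shifted = src_padded[:, :, kh:kh + H, kw:kw + W]
            pre = shifted + dest + pos_add[:, kh, kw].view(1, C, 1, 1)
            output += torch.relu(pre) * pos_mul[:, kh, kw].view(1, C, 1, 1)
    return output

# e.g.: integrated_conv_reference(torch.randn(2, 8, 5, 7), torch.randn(4, 3, 3), torch.randn(4, 3, 3))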
#include <torch/extension.h>
// Forward of integrated_conv.  The docstring of `integrated_conv`
// in integrated_conv.py documents the behavior of this function.
torch::Tensor integrated_conv_cpu(torch::Tensor input,
torch::Tensor pos_add,
torch::Tensor pos_mul) {
TORCH_CHECK(input.dim() == 4, "input must be 4-dimensional");
TORCH_CHECK(pos_add.dim() == 3, "pos_add must be 3-dimensional.");
  TORCH_CHECK(pos_mul.dim() == 3, "pos_mul must be 3-dimensional.");
TORCH_CHECK(input.device().is_cpu(), "Input must be a CPU tensor");
const int N = input.size(0),
C = input.size(1) / 2,
H = input.size(2),
W = input.size(3),
kH = pos_add.size(1),
kW = pos_add.size(2);
TORCH_CHECK(kH % 2 == 1 && kW % 2 == 1);
TORCH_CHECK(input.size(1) % 2 == 0, "Input must have even num-channels");
TORCH_CHECK(pos_add.size(0) == C && pos_mul.size(0) == C &&
pos_mul.size(1) == kH && pos_mul.size(2) == kW,
"Input sizes mismatch.");
TORCH_CHECK(pos_add.device() == input.device() &&
pos_mul.device() == pos_add.device(),
"Input devices mismatch");
  auto scalar_t = input.dtype();
TORCH_CHECK(pos_add.dtype() == scalar_t &&
pos_mul.dtype() == scalar_t,
"Input dtypes mismatch");
torch::Tensor output = torch::empty({N, C, H, W},
torch::TensorOptions().dtype(scalar_t).device(input.device()));
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "integrated_conv_cpu_loop", ([&] {
      auto input_a = input.accessor<scalar_t, 4>();
      auto pos_add_a = pos_add.accessor<scalar_t, 3>();
      auto pos_mul_a = pos_mul.accessor<scalar_t, 3>();
      auto output_a = output.accessor<scalar_t, 4>();
      for (int n = 0; n < N; n++) {
        for (int c = 0; c < C; c++) {
          auto src_input_a = input_a[n][c],
              this_pos_add_a = pos_add_a[c],
              this_pos_mul_a = pos_mul_a[c],
              this_output_a = output_a[n][c];
          for (int h = 0; h < H; h++) {
            for (int w = 0; w < W; w++) {
              scalar_t dest = input_a[n][c + C][h][w],
                  sum = 0.0;
              for (int kh = 0; kh < kH; kh++) {
                int src_h = h + kh - kH / 2;
                for (int kw = 0; kw < kW; kw++) {
                  int src_w = w + kw - kW / 2;
                  scalar_t src = 0.0;
                  if (static_cast<unsigned int>(src_h) < static_cast<unsigned int>(H) &&
                      static_cast<unsigned int>(src_w) < static_cast<unsigned int>(W))
                    src = src_input_a[src_h][src_w];
                  scalar_t relu = src + dest + this_pos_add_a[kh][kw];
                  if (relu > 0.0)
                    sum += relu * this_pos_mul_a[kh][kw];
                }
              }
              this_output_a[h][w] = sum;
}
}
}
}
}));
return output;
}
// backward of integrated_conv; returns (grad_input, grad_pos_add, grad_pos_mul).
std::vector<torch::Tensor> integrated_conv_backward_cpu(torch::Tensor input,
torch::Tensor pos_add,
torch::Tensor pos_mul,
torch::Tensor grad_output) {
// TODO.
return std::vector<torch::Tensor>();
}
template <typename T_accessor, typename scalar_t>
inline
void discounted_sum_update(T_accessor &accessor, int batchsz, scalar_t gamma, int change_pos, int discounted_pos) {
......@@ -38,7 +118,7 @@ torch::Tensor discounted_cumsum_left_cpu(torch::Tensor x, double gamma) {
torch::Tensor discounted_cumsum_right_cpu(torch::Tensor x, double gamma) {
  TORCH_CHECK(x.device().is_cpu(), "Input must be a CPU tensor");
TORCH_CHECK(x.dim() == 2, "Input must be 2-dimensional");
TORCH_CHECK(0.0 <= gamma && gamma <= 1.0, "Gamma must be in the range [0,1]");
......@@ -59,6 +139,6 @@ torch::Tensor discounted_cumsum_right_cpu(torch::Tensor x, double gamma) {
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("discounted_cumsum_left_cpu", &discounted_cumsum_left_cpu, "Discounted Cumulative Sums CPU (Left)");
m.def("discounted_cumsum_right_cpu", &discounted_cumsum_right_cpu, "Discounted Cumulative Sums CPU (Right)");
m.def("integrated_conv_cpu", &integrated_conv_cpu, "Integrated convolution forward function (CPU)");
m.def("integrated_conv_backward_cpu", &integrated_conv_forward_cpu, "Integrated convolution backward function (CPU)");
}
#include <torch/extension.h>
// Forward of integrated_conv.  The docstring of `integrated_conv`
// in integrated_conv.py documents the behavior of this function.
torch::Tensor integrated_conv_cuda(torch::Tensor input,
torch::Tensor pos_add,
torch::Tensor pos_mul);
// backward of integrated_conv; returns (grad_input, grad_pos_add, grad_pos_mul).
std::vector<torch::Tensor> integrated_conv_backward_cuda(torch::Tensor input,
torch::Tensor pos_add,
torch::Tensor pos_mul,
torch::Tensor grad_output);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("integrated_conv_cuda", &integrated_conv_cuda, "Integrated convolution forward function (CUDA)");
m.def("integrated_conv_backward_cuda", &integrated_conv_forward_cuda, "Integrated convolution backward function (CUDA)");
}
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>  // for at::cuda::getCurrentCUDAStream()
#include <cooperative_groups.h>
template <typename scalar_t, typename group_t>
__device__ int reduce_sum(group_t g, scalar_t *temp, scalar_t val)
{
int lane = g.thread_rank();
// Each iteration halves the number of active threads
// Each thread adds its partial sum[i] to sum[lane+i]
#pragma unroll
for (int i = g.size() / 2; i > 0; i /= 2)
{
temp[lane] = val;
g.sync(); // wait for all threads to store
if (lane < i) val += temp[lane + i];
g.sync(); // wait for all threads to load
}
return val; // note: only thread 0 will return full sum
}
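A small Python model of the reduction above (a hypothetical helper for reasoning only) makes it clear why only thread 0 ends up holding the complete sum:

def reduce_sum_model(vals):
    # Python model of reduce_sum: `temp` plays the role of the shared buffer and
    # each iteration halves the number of active "lanes"; only lane 0 ends up
    # with the full sum.  len(vals) is the group size (assumed a power of two).
    vals, n = list(vals), len(vals)
    temp = [0.0] * n
    i = n // 2
    while i > 0:
        for lane in range(n):        # store phase (between g.sync() calls)
            temp[lane] = vals[lane]
        for lane in range(i):        # add phase: lane < i accumulates lane + i
            vals[lane] += temp[lane + i]
        i //= 2
    return vals[0]

print(reduce_sum_model([1.0, 2.0, 3.0, 4.0]))  # 10.0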
/*
Forward of integrated_conv. Each thread group handles a single channel
(equal to blockIdx.x), and loops over patches of the output.
  Template args:
      scalar_t: the floating-point type, e.g. float, double, maybe half.
  Dynamic shared memory: the buffer `extern_buf` is shared between the input
      patch and pieces of pos_add and pos_mul.  It is the caller's
      responsibility to allocate enough dynamic shared memory for the
      provided parameters (see `bytesShared` below).
Args:
input: input image, shape (N, 2*C, H, W)
pos_add: positional encoding, additive part, shape (C, kH, kW)
     pos_mul: positional encoding, multiplicative part, shape (C, kH, kW)
Note: kH and kW must both be odd so that it's clear how to pad.
The thread-block should have one dimension (x); blockDim.x should equal
some small power of 2 (threads_per_opixel) times the output-patch size which is
opatchH * opatchW (the output-patch height and width). We expect
threads_per_opixel to be 1, 2, or 4; we use a linear summation to sum up the
different threads' partial sums, and if threads_per_opixel gets larger we'd
need to make this a logarithmic reduction.
The requirements on the grid dimension are:
gridDim.x == num-channels C (required)
gridDim.y <= num-patches per image (recommended)
gridDim.z <= batch-size N (recommended)
When we invoke this kernel, we'll invoke it as:
integrated_conv_forward<<<gridDim, blockDim, bytesShared, stream>>>
where bytesShared is the number of bytes needed in `extern_buf`:
     bytesShared = sizeof(scalar_t) * numel, where
     numel = 2 * (kH * kW) + max(blockDim.x, (opatchH + kH - 1) * (opatchW + kW - 1))
*/
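Evaluating the launch contract above for a concrete configuration (a hedged sketch with illustrative numbers; `launch_config` is not part of the extension):

def launch_config(C, opatchH, opatchW, kH, kW, threads_per_opixel, sizeof_scalar=4):
    # blockDim.x = threads_per_opixel * output-patch size; gridDim.x must equal C.
    block_dim_x = threads_per_opixel * opatchH * opatchW
    # Shared-memory element count from the comment above: 2*(kH*kW) for pos_add and
    # pos_mul, plus a buffer shared between the padded input patch and the
    # per-thread partial sums.
    numel = 2 * (kH * kW) + max(block_dim_x, (opatchH + kH - 1) * (opatchW + kW - 1))
    return block_dim_x, numel * sizeof_scalar

# e.g. C=256 channels, 16x8 output patch, 3x3 kernel, 4 threads per output pixel:
print(launch_config(256, 16, 8, 3, 3, 4))  # (512, 2120): blockDim.x and bytesShared for float32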
extern __shared__ int extern_buf[];
template <typename scalar_t>
__global__
void integrated_conv_kernel(
torch::PackedTensorAccessor32<scalar_t, 4> input, // N, 2*C, H, W
torch::PackedTensorAccessor32<scalar_t, 3> pos_add, // C, kH, kW
torch::PackedTensorAccessor32<scalar_t, 3> pos_mul, // C, kH, kW
    torch::PackedTensorAccessor32<scalar_t, 4> output,   // N, C, H, W
int opatchH, // output-patch height,
int opatchW // output-patch width
) {
  const int C = input.size(1) / 2,  // number of output channels (input has 2*C channels)
      H = input.size(2),
      W = input.size(3),
      kH = pos_add.size(1),
      kW = pos_add.size(2),
      npatchH = (H + opatchH - 1) / opatchH,  // num patches in vertical dim
      npatchW = (W + opatchW - 1) / opatchW,  // num patches in horizontal dim
      npatch = npatchH * npatchW;  // total number of patches per image
// Channel index.
const int c = blockIdx.x;
// We don't need to check the range of `c` because we set gridDim.x to the
// exact number of channels.
  const int ipatchH = opatchH + kH - 1,
      ipatchW = opatchW + kW - 1,
      ipatch_size = ipatchH * ipatchW,
      opatch_size = opatchH * opatchW;
  // `extern_buf` is general-purpose dynamic shared memory, which we'll divide between
  // pos_add, pos_mul and src_img_buf; src_img_buf is shared between the src image
  // patch (ipatch_size) and the per-thread partial sums (blockDim.x).
  scalar_t
      *pos_add_buf = (scalar_t*)extern_buf,    // pos_add positional-encoding / kernel parameters,
                                               // indexed [kh*kW + kw] where kh and kw are vertical
                                               // and horizontal positions in the kernel.
      *pos_mul_buf = pos_add_buf + (kH * kW),  // pos_mul positional-encoding / kernel parameters,
                                               // indexed [kh*kW + kw] where kh and kw are vertical
                                               // and horizontal positions in the kernel.
      *src_img_buf = pos_mul_buf + (kH * kW);  // version of input image that relates to source position,
                                               // of size [ipatch_size], indexed [h*ipatchW + w]...
                                               // note, the 'h' and 'w' indexes are into the zero-padded input
                                               // image.
  const int threads_per_opixel = blockDim.x / opatch_size;
  assert(blockDim.x == opatch_size * threads_per_opixel);
  auto g = cooperative_groups::this_thread_block();
  auto tile = cooperative_groups::tiled_partition(g, threads_per_opixel);
// pos_in_patch will be interpreted as h_in_patch * opatchW + w_in_patch.
int pos_in_patch = threadIdx.x / threads_per_opixel;
// Load parts of the kernel parameters pos_add and pos_mul into shared memory,
// in pos_add_buf and pos_mul_buf
for (int i = threadIdx.x; i < kH * kW; i += blockDim.x) {
int kh = i / kW,
kw = i % kW;
pos_add_buf[i] = pos_add[c][kh][kw];
pos_mul_buf[i] = pos_mul[c][kh][kw];
}
// n is the index within the batch. Loop to make sure we cover all images in
// the batch. input.size(0) is the batch size N. All threads in the thread-block
// loop the same number of times.
for (int n = blockIdx.z; n < input.size(0); n += gridDim.z) {
// Loop over the patch within the output image. All threads in the
// thread-block loop the same number of times.
for (int patch_idx = blockIdx.y; patch_idx < npatch; patch_idx += gridDim.y) {
// (patch_h_offset, patch_w_offset) are the (vertical, horizontal) indexes
// of the lowest-numbered pixel in the patch of output that this thread
// block is responsible for.
int patch_h_offset = (patch_idx / npatchW) * opatchH,
patch_w_offset = (patch_idx % npatchW) * opatchW;
// This __syncthreads() is only necessary if we have already looped at
// least once over n or patch_idx: it's in case other threads are still
// using the `src_img_buf` buffer for something else.
__syncthreads();
// Load the 'src' part of the input patch; the size of this is the size of
// the output patch plus a border of sizes kH//2, kW//2 on each side.
for (int i = threadIdx.x; i < ipatch_size; i += blockDim.x) {
int h_in_kernel = i / ipatchW,
w_in_kernel = i % ipatchW;
int src_h = patch_h_offset + h_in_kernel - (kH / 2), // kH / 2 is offset due to padding
src_w = patch_w_offset + w_in_kernel - (kW / 2);
scalar_t src_val = scalar_t(0);
if ((unsigned int)src_h < (unsigned int)H &&
(unsigned int)src_w < (unsigned int)W)
src_val = input[n][c][src_h][src_w];
src_img_buf[i] = src_val;
}
// make sure all threads have written to `src_img_buf`
__syncthreads();
      // (h_in_patch, w_in_patch) are the positions within the output patch, and
      // 'h' and 'w' the positions within the output image, that this tile of
      // size threads_per_opixel is responsible for.
      int h_in_patch = pos_in_patch / opatchW,
          w_in_patch = pos_in_patch % opatchW;
      int h = patch_h_offset + h_in_patch,
          w = patch_w_offset + w_in_patch;
// The "destination" pixel; this is an input. It gets added to each
// src pixel, prior to the relu, in the loop below.
scalar_t dest_val = scalar_t(0);
if (h < H && w < W) {
// Several threads (within the same tile, which implies the same warp)
// may load the same value here, but I believe the device's memory
// subsystem handles this well enough that we can just ignore the issue
// rather than try to optimize it.
// https://forums.developer.nvidia.com/t/accessing-same-global-memory-address-within-warps/66574
dest_val = input[n][c + C][h][w]; // else 0.
}
// `sum` is the partial sum that this thread computes; we'll sum this over
// the `threads_per_opixel` threads in the tile to get the output pixel
// value.
scalar_t sum = 0.0;
for (int pos_in_kernel = tile.thread_rank();
pos_in_kernel < (kH * kW);
pos_in_kernel += threads_per_opixel) {
int h_in_kernel = pos_in_kernel / kW,
w_in_kernel = pos_in_kernel % kW;
// Note: this is actually more like cross-correlation, as we don't
// have a negative sign on the h and w indexes in the kernel.
// Also note: we already took care of padding and the associated
// offsets of -(kH / 2) and -(kW / 2).
int h_in_src_patch = h_in_patch + h_in_kernel,
w_in_src_patch = w_in_patch + w_in_kernel;
scalar_t src_val = src_img_buf[h_in_src_patch * ipatchW + w_in_src_patch],
pos_add_val = pos_add_buf[pos_in_kernel];
scalar_t relu = (src_val + dest_val + pos_add_val);
if (relu > 0.0)
sum += relu * pos_mul_buf[pos_in_kernel];
}
// Aggregate `sum` over threads, if needed; and write the result to `output`.
if (threads_per_opixel > 1) {
__syncthreads();
src_img_buf[threadIdx.x] = sum;
__syncthreads();
if (tile.thread_rank() == 0 && h < H && w < W) {
// This linear summation should be OK because threads_per_opixel is
// unlikely to be greater than 4.
for (int i = 1; i < threads_per_opixel; i++)
sum += src_img_buf[threadIdx.x + i];
output[n][c][h][w] = sum;
}
} else {
if (h < H && w < W)
output[n][c][h][w] = sum;
}
}
}
}
torch::Tensor integrated_conv_cuda(torch::Tensor input,
torch::Tensor pos_add,
torch::Tensor pos_mul) {
TORCH_CHECK(input.dim() == 4, "input must be 4-dimensional");
TORCH_CHECK(pos_add.dim() == 3, "pos_add must be 3-dimensional.");
  TORCH_CHECK(pos_mul.dim() == 3, "pos_mul must be 3-dimensional.");
TORCH_CHECK(input.device().is_cuda(), "Input must be a CUDA tensor");
const int N = input.size(0),
C = input.size(1) / 2,
H = input.size(2),
W = input.size(3),
kH = pos_add.size(1),
kW = pos_add.size(2);
TORCH_CHECK(kH % 2 == 1 && kW % 2 == 1);
TORCH_CHECK(input.size(1) % 2 == 0, "Input must have even num-channels");
TORCH_CHECK(pos_add.size(0) == C && pos_mul.size(0) == C &&
pos_mul.size(1) == kH && pos_mul.size(2) == kW,
"Input sizes mismatch.");
TORCH_CHECK(pos_add.device() == input.device() &&
pos_mul.device() == pos_add.device(),
"Input devices mismatch");
  auto scalar_t = input.dtype();
TORCH_CHECK(pos_add.dtype() == scalar_t &&
pos_mul.dtype() == scalar_t,
"Input dtypes mismatch");
torch::Tensor output = torch::empty({N, C, H, W},
torch::TensorOptions().dtype(scalar_t).device(input.device()));
// Work out the configuration with which we call the kernel..
int patchH = std::min(H, kH), // output patch height
patchW = std::min(W, kW); // output patch width
// We don't want the height or width of the patch to be less than the kernel
// width, or the padding will make the input-patch size more than twice the
// output-patch size.
  // We aim for an output-patch size of up to about 128 pixels; this is not something
// very exact, but it roughly corresponds to us wanting to have up to 4 threads
// per output pixel, and the limitation of 512 threads per thread-block which
// we impose so that we can run on architectures with little shared memory.
while (patchW < W && patchH * (patchW + 1) <= 128)
patchW++;
while(patchH < H && (patchH + 1) * patchW <= 128)
patchH++;
  // We are assuming that the thread-block size can be as large as 1024.
int threads_per_opixel;
if (patchH * patchW * 4 <= 512 && (kH * kW) > 16)
threads_per_opixel = 4;
else if (patchH * patchW * 2 <= 512 && (kH * kW) > 8)
threads_per_opixel = 2;
else
threads_per_opixel = 1;
int input_patchH = patchH + kH - 1,
input_patchW = patchW + kW - 1,
input_patch_size = input_patchH * input_patchW;
int threads_per_block = patchH * patchW * threads_per_opixel;
  int buffer_numel = 2 * (kH * kW) + std::max<int>(threads_per_block,
                                                   input_patch_size);
int num_patches_H = (H + patchH - 1) / patchH,
num_patches_W = (W + patchW - 1) / patchW,
num_patches = num_patches_H * num_patches_W;
// gridDim.x == C.
  int num_blocks_patch = 1,  // gridDim.y; should not be more than num_patches
      num_blocks_batch = 1;  // gridDim.z; should not be more than N
while (C * num_blocks_patch <= 256 &&
num_blocks_patch * 2 <= num_patches)
num_blocks_patch *= 2;
if (C * num_patches <= 512)
num_blocks_patch = num_patches;
while (C * num_blocks_patch * num_blocks_batch <= 512 &&
num_blocks_batch * 2 <= N)
num_blocks_batch *= 2;
if (C * num_blocks_patch * N <= 1024)
num_blocks_batch = N;
assert(num_blocks_patch <= num_patches && num_blocks_batch <= N);
std::cout << "N,C,H,W=" << N << "," << C << "," << H << "," << W
<< "; kW,kH=" << kW << "," << kH
<< "; patchH,patchW=" << patchH << ","
<< patchW << ", num_blocks_patch="
<< num_blocks_patch << ", num_blocks_batch="
<< num_blocks_batch << std::endl;
dim3 gridDim(C, num_blocks_patch, num_blocks_batch);
// blockDim is scalar, just threads_per_block.
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "integrated_conv_kernel", ([&] {
integrated_conv_kernel<scalar_t><<<gridDim, threads_per_block, sizeof(scalar_t) * buffer_numel, at::cuda::getCurrentCUDAStream()>>>(
input.packed_accessor32<scalar_t, 4>(),
pos_add.packed_accessor32<scalar_t, 3>(),
pos_mul.packed_accessor32<scalar_t, 3>(),
output.packed_accessor32<scalar_t, 4>(),
patchH,
patchW);
}));
return output;
}
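The patch-size and threads-per-pixel heuristic above can be mirrored in Python to check what configuration a given input shape would get (an illustrative sketch, not part of the extension):

def choose_patch(H, W, kH, kW):
    # Mirrors the heuristic above: start from the kernel size (clamped to the image)
    # and grow the output patch until it holds roughly 128 pixels, then pick
    # threads_per_opixel so that the thread-block stays within 512 threads.
    patchH, patchW = min(H, kH), min(W, kW)
    while patchW < W and patchH * (patchW + 1) <= 128:
        patchW += 1
    while patchH < H and (patchH + 1) * patchW <= 128:
        patchH += 1
    if patchH * patchW * 4 <= 512 and kH * kW > 16:
        threads_per_opixel = 4
    elif patchH * patchW * 2 <= 512 and kH * kW > 8:
        threads_per_opixel = 2
    else:
        threads_per_opixel = 1
    return patchH, patchW, threads_per_opixel

print(choose_patch(64, 64, 3, 3))  # (3, 42, 2) -> 252 threads per block for a 3x3 kernel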
std::vector<torch::Tensor> integrated_conv_backward_cuda(torch::Tensor input,
torch::Tensor pos_add,
torch::Tensor pos_mul,
torch::Tensor grad_output) {
return std::vector<torch::Tensor>();
}