Commit caff36e7 authored by anton

First working discounted cumsum (right) CUDA kernel and PyTorch bindings

.idea
venv*
deploy*
#include <torch/extension.h>
torch::Tensor discounted_cumsum_right_minthreads(torch::Tensor x, double gamma);
torch::Tensor discounted_cumsum_right_coalesced(torch::Tensor x, double gamma);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("discounted_cumsum_right_minthreads", &discounted_cumsum_right_minthreads,
"Discounted Cumulative Sum Right Minimum Threads");
m.def("discounted_cumsum_right_coalesced", &discounted_cumsum_right_coalesced,
"Discounted Cumulative Sum Right Coalesced Writes");
}
import time
import torch
from torch.utils.cpp_extension import load
torch_discounted_cumsum = load(
name='torch_discounted_cumsum',
sources=['discounted_cumsum.cpp', 'discounted_cumsum_kernel.cu'],
verbose=True,
)
class DiscountedCumSumRightFunction(torch.autograd.Function):
    # Autograd wrapper sketch; not wired into the wrappers below. It relies on the
    # identity that the gradient of a right discounted cumsum w.r.t. its input is a
    # left discounted cumsum of the incoming gradient, which the same kernel computes
    # after flipping along dim 1 before and after the call.
    @staticmethod
    def forward(ctx, input, gamma):
        output = torch_discounted_cumsum.discounted_cumsum_right_coalesced(input, gamma)
        ctx.gamma = gamma
        return output

    @staticmethod
    def backward(ctx, grad_output):
        grad_flipped = torch.flip(grad_output, dims=[1]).contiguous()
        grad_flipped = torch_discounted_cumsum.discounted_cumsum_right_coalesced(grad_flipped, ctx.gamma)
        grad_input = torch.flip(grad_flipped, dims=[1])
        return grad_input, None
def discounted_cumsum_right_minthreads(input, gamma):
return torch_discounted_cumsum.discounted_cumsum_right_minthreads(input, gamma)
def discounted_cumsum_right_coalesced(input, gamma):
return torch_discounted_cumsum.discounted_cumsum_right_coalesced(input, gamma)
def discounted_cumsum_right_gold(input, gamma):
    # Reference implementation, computed column by column from the right:
    #   out[:, i] = input[:, i] + gamma * out[:, i + 1]
    # which is equivalent to out[:, i] = sum_{j >= i} gamma^(j - i) * input[:, j].
    assert input.dim() == 2
    assert 0 <= gamma <= 1
out = []
last_col = torch.zeros((input.shape[0], 1), dtype=input.dtype, device=input.device)
for i in reversed(range(input.shape[1])):
cur_col = input[:, i].unsqueeze(-1)
last_col = cur_col + gamma * last_col
out.insert(0, last_col)
out = torch.cat(out, dim=1)
return out
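# Hand-checked example of the recurrence in discounted_cumsum_right_gold: for a single
# row [1, 1, 1, 1] with gamma = 0.5, the expected output is [1.875, 1.75, 1.5, 1.0]:
#   out[3] = 1
#   out[2] = 1 + 0.5 * 1.0  = 1.5
#   out[1] = 1 + 0.5 * 1.5  = 1.75
#   out[0] = 1 + 0.5 * 1.75 = 1.875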
def test_fn(fn):
torch.manual_seed(0)
x = torch.full((10, 10000), fill_value=1.0, dtype=torch.float32).cuda()
gamma = 0.99
out_gold_32 = discounted_cumsum_right_gold(x, gamma)
out_gold_64 = discounted_cumsum_right_gold(x.double(), gamma)
out_fn = fn(x, gamma)
diff_32 = (out_fn - out_gold_32).abs().max().item()
diff_64 = (out_fn - out_gold_64).abs().max().item()
print(fn.__name__)
print('diff_32', diff_32)
print('diff_64', diff_64)
def test_speed(fn, reps=10000):
    torch.manual_seed(0)
    x = torch.randn(10, 100000, dtype=torch.float32).cuda()
    gamma = 0.99
    # Kernel launches are asynchronous, so synchronize before reading the clock.
    torch.cuda.synchronize()
    t1 = time.time()
    for _ in range(reps):
        fn(x, gamma)
    torch.cuda.synchronize()
    t2 = time.time()
    print(fn.__name__, t2 - t1)
if __name__ == '__main__':
test_fn(discounted_cumsum_right_minthreads)
test_fn(discounted_cumsum_right_coalesced)
test_speed(discounted_cumsum_right_minthreads)
test_speed(discounted_cumsum_right_coalesced)
#include <torch/extension.h>
template <typename scalar_t>
__device__ __forceinline__ scalar_t discounted_sum_pow(scalar_t a, scalar_t b, scalar_t gamma, int power) {
return a + b * pow(gamma, scalar_t(power));
}
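// Host-side helper: number of pairwise-merge stages needed to scan a row of length x.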
__inline__
int log2ceil(int x) {
return (int)ceil(log2((float)x));
}
template <typename scalar_t>
__global__ void discounted_cumsum_right_kernel_minthreads_stage(
torch::PackedTensorAccessor32<scalar_t, 2> x,
const scalar_t gamma,
int stage
) {
// Pros: launches only the minimum required number of threads, assigning them dynamically to the positions updated at each stage.
// Cons: Uncoalesced writes.
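// How a stage works (descriptive note): before stage s, each element already holds the
// discounted sum of the values from its own position to the end of its block of 2^s
// elements. Stage s merges adjacent blocks pairwise: every element of the left block
// adds the first element of the right block, discounted by gamma^(distance to it),
// which extends its sum to the end of the merged block of 2^(s+1) elements. Elements of
// the right block are already correct for the merged block, so only 2^s threads per
// merged block are needed; this kernel assigns exactly those threads.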
const int len = x.size(1);
const int threadidx = blockIdx.x * blockDim.x + threadIdx.x;
const int threadidy = blockIdx.y * blockDim.y + threadIdx.y;
if (threadidy >= x.size(0)) {
return;
}
int gr_prev_stride = 1 << stage;
int gr_cur_stride = gr_prev_stride << 1;
int gr_of_thread = threadidx >> stage;
int thread_in_gr = threadidx - (gr_of_thread << stage);
int change_pos = gr_of_thread * gr_cur_stride + thread_in_gr;
int discounted_pos = gr_of_thread * gr_cur_stride + gr_prev_stride;
int discount_power = gr_prev_stride - thread_in_gr;
if (change_pos >= len || discounted_pos >= len) {
return;
}
x[threadidy][change_pos] = discounted_sum_pow(
x[threadidy][change_pos],
x[threadidy][discounted_pos],
gamma,
discount_power
);
}
template <typename scalar_t>
__global__ void discounted_cumsum_right_kernel_coalesced_stage(
torch::PackedTensorAccessor32<scalar_t, 2> x,
const scalar_t gamma,
int stage
) {
// Pros: Coalesced writes.
// Cons: one thread is allocated statically per element, so half of the threads idle at each stage.
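// Mapping note: one thread is launched per element and writes only to its own position
// (change_pos == threadidx), so consecutive threads write consecutive memory. Threads
// whose position falls in the right half of the current merged block
// (thread_in_gr >= gr_prev_stride) exit early, since those elements already hold their
// sums for the merged block.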
const int len = x.size(1);
const int threadidx = blockIdx.x * blockDim.x + threadIdx.x;
const int threadidy = blockIdx.y * blockDim.y + threadIdx.y;
if (threadidx >= len || threadidy >= x.size(0)) {
return;
}
int gr_prev_stride = 1 << stage;
int gr_cur_stride = gr_prev_stride << 1;
int gr_of_thread = threadidx >> (stage + 1);
int thread_in_gr = threadidx - (gr_of_thread << (stage + 1));
int change_pos = threadidx;
int discounted_pos = gr_of_thread * gr_cur_stride + gr_prev_stride;
int discount_power = gr_prev_stride - thread_in_gr;
if (thread_in_gr >= gr_prev_stride || discounted_pos >= len) {
return;
}
x[threadidy][change_pos] = discounted_sum_pow(
x[threadidy][change_pos],
x[threadidy][discounted_pos],
gamma,
discount_power
);
}
torch::Tensor discounted_cumsum_right_minthreads(torch::Tensor x, double gamma) {
// Pros: launches only the minimum required number of threads, assigning them dynamically to the positions updated at each stage.
// Cons: Uncoalesced writes.
TORCH_CHECK(x.is_cuda(), "Input must be a CUDA tensor");
TORCH_CHECK(x.is_contiguous(), "Input must be contiguous");
TORCH_CHECK(x.dim() == 2, "Input must be 2-dimensional");
TORCH_CHECK(0.0 <= gamma && gamma <= 1.0, "Gamma must be in the range [0,1]");
// Rows of length 0 or 1 are already their own discounted cumsum; returning early also
// avoids the invalid shift below when log2ceil would return 0.
if (x.size(1) <= 1) {
    return x.clone();
}
auto y = x.clone();
const int threads = 32;
const int nstages = log2ceil(x.size(1));
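// Each stage needs at most 2^(nstages - 1) worker threads per row: one for every
// element that falls in the left half of a merged block.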
const int threads_total_x = 1 << (nstages - 1);
const dim3 blocks((threads_total_x + threads - 1) / threads, x.size(0));
for (int stage=0; stage<nstages; stage++) {
AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "discounted_cumsum_right_kernel_minthreads_stage", ([&] {
discounted_cumsum_right_kernel_minthreads_stage<scalar_t><<<blocks, threads>>>(
y.packed_accessor32<scalar_t, 2>(),
scalar_t(gamma),
stage
);
}));
}
return y;
}
torch::Tensor discounted_cumsum_right_coalesced(torch::Tensor x, double gamma) {
// Pros: Coalesced writes.
// Cons: one thread is allocated statically per element, so half of the threads idle at each stage.
TORCH_CHECK(x.is_cuda(), "Input must be a CUDA tensor");
TORCH_CHECK(x.is_contiguous(), "Input must be contiguous");
TORCH_CHECK(x.dim() == 2, "Input must be 2-dimensional");
TORCH_CHECK(0.0 <= gamma && gamma <= 1.0, "Gamma must be in the range [0,1]");
// Rows of length 0 or 1 are already their own discounted cumsum.
if (x.size(1) <= 1) {
    return x.clone();
}
auto y = x.clone();
const int threads = 32;
const int nstages = log2ceil(x.size(1));
const dim3 blocks((x.size(1) + threads - 1) / threads, x.size(0));
for (int stage=0; stage<nstages; stage++) {
AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "discounted_cumsum_right_kernel_coalesced_stage", ([&] {
discounted_cumsum_right_kernel_coalesced_stage<scalar_t><<<blocks, threads>>>(
y.packed_accessor32<scalar_t, 2>(),
scalar_t(gamma),
stage
);
}));
}
return y;
}
torch>=1.5
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
setup(
name='torch_discounted_cumsum',
ext_modules=[
CUDAExtension('torch_discounted_cumsum', [
'discounted_cumsum.cpp',
'discounted_cumsum_kernel.cu',
])
],
cmdclass={
'build_ext': BuildExtension
})
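# Build and usage sketch (assumptions: a CUDA-enabled PyTorch install, and that the
# compiled module is importable under the extension name passed to CUDAExtension above,
# here taken to be 'torch_discounted_cumsum'; the bound functions come from
# discounted_cumsum.cpp):
#
#   pip install .
#
#   import torch
#   import torch_discounted_cumsum
#   x = torch.ones(2, 8, device='cuda')
#   y = torch_discounted_cumsum.discounted_cumsum_right_coalesced(x, 0.99)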