Commit d171584b authored by anton

add cpu version

add cpu-gpu dispatcher
parent 69abe873
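A minimal usage sketch of the CPU/GPU dispatcher added in this commit (hypothetical example; it assumes the package is importable as discounted_cumsum, as in the tests below):

import torch
from discounted_cumsum import discounted_cumsum_left, discounted_cumsum_right

x = torch.randn(4, 16)
y_cpu = discounted_cumsum_left(x, 0.99)                  # routed to the CPU extension
if torch.cuda.is_available():
    y_gpu = discounted_cumsum_right(x.cuda(), 0.99)      # routed to the CUDA extension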
#include <torch/extension.h>
torch::Tensor discounted_cumsum_left(torch::Tensor x, double gamma);
torch::Tensor discounted_cumsum_right(torch::Tensor x, double gamma);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("discounted_cumsum_left", &discounted_cumsum_left, "Discounted Cumulative Sum (Left)");
m.def("discounted_cumsum_right", &discounted_cumsum_right, "Discounted Cumulative Sum (Right)");
}
import time
import torch
from torch.utils.cpp_extension import load

torch_discounted_cumsum_cpu = load(
    name='torch_discounted_cumsum_cpu',
    sources=['discounted_cumsum_cpu.cpp'],
    # verbose=True,
)
torch_discounted_cumsum_cuda = None
if torch.cuda.is_available():
    torch_discounted_cumsum_cuda = load(
        name='torch_discounted_cumsum_cuda',
        sources=['discounted_cumsum_cuda.cpp', 'discounted_cumsum_cuda_kernel.cu'],
        verbose=True,
    )


def _discounted_cumsum_left_dispatcher(input, gamma):
    if not torch.is_tensor(input):
        raise ValueError('Input must be a torch.Tensor')
    if input.is_cuda:
        if torch_discounted_cumsum_cuda is None:
            raise EnvironmentError('Failed to load native CUDA module')
        return torch_discounted_cumsum_cuda.discounted_cumsum_left_cuda(input.contiguous(), gamma)
    else:
        return torch_discounted_cumsum_cpu.discounted_cumsum_left_cpu(input, gamma)
def _discounted_cumsum_right_dispatcher(input, gamma):
if not torch.is_tensor(input):
raise ValueError('Input must be a torch.Tensor')
if input.is_cuda:
if torch_discounted_cumsum_cuda is None:
            raise EnvironmentError('Failed to load native CUDA module')
return torch_discounted_cumsum_cuda.discounted_cumsum_right_cuda(input.contiguous(), gamma)
else:
return torch_discounted_cumsum_cpu.discounted_cumsum_right_cpu(input, gamma)
class DiscountedCumSumLeftFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input, gamma):
output = _discounted_cumsum_left_dispatcher(input, gamma)
ctx.save_for_backward(torch.tensor(gamma))
return output
@staticmethod
def backward(ctx, grad_output):
gamma = ctx.saved_variables[0].item()
grad_input = _discounted_cumsum_right_dispatcher(grad_output, gamma)
return grad_input, None
class DiscountedCumSumRightFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input, gamma):
output = _discounted_cumsum_right_dispatcher(input, gamma)
ctx.save_for_backward(torch.tensor(gamma))
return output
@staticmethod
def backward(ctx, grad_output):
gamma = ctx.saved_variables[0].item()
grad_input = _discounted_cumsum_left_dispatcher(grad_output, gamma)
return grad_input, None
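
# Note on the backward passes above: with y_j = sum_{i <= j} gamma**(j - i) * x_i
# (the left cumsum), the chain rule gives
#   dL/dx_i = sum_{j >= i} gamma**(j - i) * dL/dy_j,
# which is exactly the right discounted cumsum applied to grad_output; likewise,
# the gradient of the right cumsum is the left cumsum of grad_output.
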
def discounted_cumsum_left(input, gamma):
    return DiscountedCumSumLeftFunction.apply(input, gamma)


def discounted_cumsum_right(input, gamma):
    return DiscountedCumSumRightFunction.apply(input, gamma)
def discounted_cumsum_left_gold(input, gamma):
assert input.dim() == 2
assert 0 <= gamma <= 1
out = []
last_col = torch.zeros((input.shape[0], 1), dtype=input.dtype, device=input.device)
for i in range(input.shape[1]):
cur_col = input[:, i].unsqueeze(-1)
last_col = cur_col + gamma * last_col
out.append(last_col)
out = torch.cat(out, dim=1)
return out
def discounted_cumsum_right_gold(input, gamma):
assert input.dim() == 2
assert 0 <= gamma <= 1
out = []
last_col = torch.zeros((input.shape[0], 1), dtype=input.dtype, device=input.device)
for i in reversed(range(input.shape[1])):
cur_col = input[:, i].unsqueeze(-1)
last_col = cur_col + gamma * last_col
out.insert(0, last_col)
out = torch.cat(out, dim=1)
return out
def test_left():
torch.manual_seed(0)
x = torch.full((10, 10000), fill_value=1.0, dtype=torch.float32).cuda()
gamma = 0.99
out_gold_32 = discounted_cumsum_left_gold(x, gamma)
out_gold_64 = discounted_cumsum_left_gold(x.double(), gamma)
out_fn = discounted_cumsum_left(x, gamma)
diff_32 = (out_fn - out_gold_32).abs().max().item()
diff_64 = (out_fn - out_gold_64).abs().max().item()
print('left diff_32', diff_32)
print('left diff_64', diff_64)
def test_right():
torch.manual_seed(0)
x = torch.full((10, 10000), fill_value=1.0, dtype=torch.float32).cuda()
gamma = 0.99
out_gold_32 = discounted_cumsum_right_gold(x, gamma)
out_gold_64 = discounted_cumsum_right_gold(x.double(), gamma)
out_fn = discounted_cumsum_right(x, gamma)
diff_32 = (out_fn - out_gold_32).abs().max().item()
diff_64 = (out_fn - out_gold_64).abs().max().item()
print('right diff_32', diff_32)
print('right diff_64', diff_64)
def test_speed(reps=10000):
torch.manual_seed(0)
x = torch.randn(10, 100000, dtype=torch.float32).cuda()
gamma = 0.99
t1 = time.time()
for _ in range(reps):
discounted_cumsum_right(x, gamma)
t2 = time.time()
print('sec:', t2-t1)
if __name__ == '__main__':
test_left()
test_right()
#test_speed()
#include <torch/extension.h>
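// Shared helper for both CPU kernels below: performs
//   accessor[i][change_pos] += gamma * accessor[i][discounted_pos]
// for every row i of the batch, with the loop manually unrolled by a factor of 4
// and a scalar tail loop for the remaining batchsz % 4 rows.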
template <typename T_accessor, typename scalar_t>
inline
void discounted_sum_update(T_accessor &accessor, int batchsz, scalar_t gamma, int change_pos, int discounted_pos) {
for (int i=0; i<batchsz-3; i+=4) {
accessor[i+0][change_pos] += gamma * accessor[i+0][discounted_pos];
accessor[i+1][change_pos] += gamma * accessor[i+1][discounted_pos];
accessor[i+2][change_pos] += gamma * accessor[i+2][discounted_pos];
accessor[i+3][change_pos] += gamma * accessor[i+3][discounted_pos];
}
for (int i=(batchsz - (batchsz & 3)); i<batchsz; i++) {
accessor[i][change_pos] += gamma * accessor[i][discounted_pos];
}
}
torch::Tensor discounted_cumsum_left_cpu(torch::Tensor x, double gamma) {
TORCH_CHECK(x.device().is_cpu(), "Input must be a CPU tensor");
TORCH_CHECK(x.dim() == 2, "Input must be 2-dimensional");
TORCH_CHECK(0.0 <= gamma && gamma <= 1.0, "Gamma must be in the range [0,1]");
auto y = x.clone();
AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "discounted_cumsum_left_cpu_loop", ([&] {
auto ya = y.accessor<scalar_t, 2>();
        for (int j=0; j<y.size(1); j++) {
            int j_left = j - 1;
            if (j_left < 0) {
                continue;
            }
            discounted_sum_update(ya, y.size(0), gamma, j, j_left);
        }
}));
return y;
}
torch::Tensor discounted_cumsum_right_cpu(torch::Tensor x, double gamma) {
TORCH_CHECK(x.device().is_cpu(), "Input must be a CPU tensor");
TORCH_CHECK(x.dim() == 2, "Input must be 2-dimensional");
TORCH_CHECK(0.0 <= gamma && gamma <= 1.0, "Gamma must be in the range [0,1]");
auto y = x.clone();
AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "discounted_cumsum_right_cpu_loop", ([&] {
auto ya = y.accessor<scalar_t, 2>();
        for (int j=y.size(1)-1; j>=0; j--) {
            int j_right = j + 1;
            if (j_right >= y.size(1)) {
                continue;
            }
            discounted_sum_update(ya, y.size(0), gamma, j, j_right);
        }
}));
return y;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("discounted_cumsum_left_cpu", &discounted_cumsum_left_cpu, "Discounted Cumulative Sum CPU (Left)");
m.def("discounted_cumsum_right_cpu", &discounted_cumsum_right_cpu, "Discounted Cumulative Sum CPU (Right)");
}
#include <torch/extension.h>
torch::Tensor discounted_cumsum_left_cuda(torch::Tensor x, double gamma);
torch::Tensor discounted_cumsum_right_cuda(torch::Tensor x, double gamma);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("discounted_cumsum_left_cuda", &discounted_cumsum_left_cuda, "Discounted Cumulative Sum CUDA (Left)");
m.def("discounted_cumsum_right_cuda", &discounted_cumsum_right_cuda, "Discounted Cumulative Sum CUDA (Right)");
}
#include <torch/extension.h>

enum SumDirection {
    SUM_DIRECTION_LEFT,
    SUM_DIRECTION_RIGHT,

@@ -46,6 +39,13 @@ void resolve_positions<SUM_DIRECTION_RIGHT>(
}

template <typename scalar_t>
__device__ __forceinline__
scalar_t discounted_sum_power(scalar_t a, scalar_t b, scalar_t gamma, int power) {
    return a + b * pow(gamma, scalar_t(power));
}

template <typename scalar_t, SumDirection sum_direction>
__global__
void discounted_cumsum_kernel_stage(

@@ -97,7 +97,7 @@ torch::Tensor discounted_cumsum(torch::Tensor x, double gamma) {
    // Minimum required number of threads, assigns them dynamically to respective positions upon each iteration.
    // Results in uncoalesced writes, which is still faster than coalesced writes with half threads idling.
    TORCH_CHECK(x.device().is_cuda(), "Input must be a CUDA tensor");
    TORCH_CHECK(x.is_contiguous(), "Input must be contiguous");
    TORCH_CHECK(x.dim() == 2, "Input must be 2-dimensional");
    TORCH_CHECK(0.0 <= gamma && gamma <= 1.0, "Gamma must be in the range [0,1]");

@@ -114,7 +114,7 @@ torch::Tensor discounted_cumsum(torch::Tensor x, double gamma) {
    const dim3 blocks((threads_total_x + threads - 1) / threads, x.size(0));
    for (int stage=0; stage<nstages; stage++) {
        AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "discounted_cumsum_kernel_stage", ([&] {
            discounted_cumsum_kernel_stage<scalar_t, sum_direction><<<blocks, threads>>>(
                y.packed_accessor32<scalar_t, 2>(),
                scalar_t(gamma),

@@ -127,11 +127,11 @@ torch::Tensor discounted_cumsum(torch::Tensor x, double gamma) {
}

torch::Tensor discounted_cumsum_left_cuda(torch::Tensor x, double gamma) {
    return discounted_cumsum<SUM_DIRECTION_LEFT>(x, gamma);
}

torch::Tensor discounted_cumsum_right_cuda(torch::Tensor x, double gamma) {
    return discounted_cumsum<SUM_DIRECTION_RIGHT>(x, gamma);
}
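The per-stage launch above suggests a logarithmic scan: each stage combines a column with a column a fixed distance away, discounted by gamma raised to that distance (cf. discounted_sum_power), so n columns need only about log2(n) stages. A rough Python sketch of this scheme for the left direction (an illustrative assumption; the exact per-stage index mapping lives in the kernel body elided from this diff):

import torch

def discounted_cumsum_left_scan(x, gamma):
    # Hillis-Steele-style scan: O(log n) stages over the column axis.
    y = x.clone()
    d = 1
    while d < y.shape[1]:
        # Combine each column with the column d positions to its left,
        # discounted by gamma**d; the right-hand side is materialized
        # before assignment, so previous-stage values are used.
        y[:, d:] = y[:, d:] + (gamma ** d) * y[:, :-d]
        d *= 2
    return y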
torch>=1.5
ninja
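ninja is listed because torch.utils.cpp_extension.load, used above to JIT-compile the extensions, builds them with the Ninja build system.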
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension

setup(
    name='torch_discounted_cumsum',
    ext_modules=[
        CppExtension('torch_discounted_cumsum_cpu', [
            'discounted_cumsum_cpu.cpp'
        ]),
        CUDAExtension('torch_discounted_cumsum_cuda', [
            'discounted_cumsum_cuda.cpp',
            'discounted_cumsum_cuda_kernel.cu',
        ])
    ],
    cmdclass={
        ...
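With this setup.py, the two extensions would typically be built and installed with pip install . (the CUDAExtension additionally needs a CUDA toolkit matching the installed PyTorch); this is the ahead-of-time alternative to the JIT load() path used above.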
import time
import torch
from discounted_cumsum import discounted_cumsum_left, discounted_cumsum_right
def discounted_cumsum_left_gold(input, gamma):
assert input.dim() == 2
assert 0 <= gamma <= 1
out = []
last_col = torch.zeros((input.shape[0], 1), dtype=input.dtype, device=input.device)
for i in range(input.shape[1]):
cur_col = input[:, i].unsqueeze(-1)
last_col = cur_col + gamma * last_col
out.append(last_col)
out = torch.cat(out, dim=1)
return out
def discounted_cumsum_right_gold(input, gamma):
assert input.dim() == 2
assert 0 <= gamma <= 1
out = []
last_col = torch.zeros((input.shape[0], 1), dtype=input.dtype, device=input.device)
for i in reversed(range(input.shape[1])):
cur_col = input[:, i].unsqueeze(-1)
last_col = cur_col + gamma * last_col
out.insert(0, last_col)
out = torch.cat(out, dim=1)
return out
def test_left():
torch.manual_seed(0)
x = torch.full((10, 10000), fill_value=1.0, dtype=torch.float32).cuda()
gamma = 0.99
out_gold_32 = discounted_cumsum_left_gold(x, gamma)
out_gold_64 = discounted_cumsum_left_gold(x.double(), gamma)
out_fn = discounted_cumsum_left(x, gamma)
diff_32 = (out_fn - out_gold_32).abs().max().item()
diff_64 = (out_fn - out_gold_64).abs().max().item()
print('left diff_32', diff_32)
print('left diff_64', diff_64)
def test_right():
torch.manual_seed(0)
x = torch.full((10, 10000), fill_value=1.0, dtype=torch.float32).cuda()
gamma = 0.99
out_gold_32 = discounted_cumsum_right_gold(x, gamma)
out_gold_64 = discounted_cumsum_right_gold(x.double(), gamma)
out_fn = discounted_cumsum_right(x, gamma)
diff_32 = (out_fn - out_gold_32).abs().max().item()
diff_64 = (out_fn - out_gold_64).abs().max().item()
print('right diff_32', diff_32)
print('right diff_64', diff_64)
def test_grad_left():
torch.manual_seed(0)
x = torch.full((10, 10000), fill_value=1.0, dtype=torch.float32).cuda()
x = torch.nn.Parameter(x)
gamma = 0.99
out_gold = discounted_cumsum_left_gold(x, gamma)
out_gold.sum().backward()
out_gold_grad = x.grad.clone()
del x.grad
out_fn = discounted_cumsum_left(x, gamma)
out_fn.sum().backward()
out_fn_grad = x.grad.clone()
diff_grad = (out_gold_grad - out_fn_grad).abs().max().item()
print('left diff_grad', diff_grad)
def test_grad_right():
torch.manual_seed(0)
x = torch.full((10, 10000), fill_value=1.0, dtype=torch.float32).cuda()
x = torch.nn.Parameter(x)
gamma = 0.99
out_gold = discounted_cumsum_right_gold(x, gamma)
out_gold.sum().backward()
out_gold_grad = x.grad.clone()
del x.grad
out_fn = discounted_cumsum_right(x, gamma)
out_fn.sum().backward()
out_fn_grad = x.grad.clone()
diff_grad = (out_gold_grad - out_fn_grad).abs().max().item()
print('right diff_grad', diff_grad)
def test_speed(reps=10000):
torch.manual_seed(0)
x = torch.randn(10, 100000, dtype=torch.float32).cuda()
gamma = 0.99
t1 = time.time()
for _ in range(reps):
discounted_cumsum_right(x, gamma)
t2 = time.time()
print('sec:', t2-t1)
if __name__ == '__main__':
test_left()
test_right()
test_grad_left()
test_grad_right()
#test_speed()