Commit c4f47fe0 authored by anton

implement package tests: precision, speedup, validity

fix cpu version boundary conditions
parent 5089052d
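For context: the discounted cumulative sum under test follows the recurrence y[:, j] = x[:, j] + gamma * y[:, j-1] for the left-to-right variant (the right variant scans in reverse), so the only boundary case is the first visited column, which is copied unchanged. Below is a minimal plain-PyTorch sketch of that recurrence; it is illustrative only (the function name is made up and this is not the package's kernel):

import torch

def discounted_cumsum_left_sketch(x: torch.Tensor, gamma: float) -> torch.Tensor:
    # Illustrative reference: y[:, 0] = x[:, 0]; y[:, j] = x[:, j] + gamma * y[:, j-1].
    # Column j == 0 has no left neighbour, which is the boundary case the CPU kernel
    # change below handles by testing j_left == -1 rather than j_left == 0.
    y = x.clone()
    for j in range(1, x.shape[1]):
        y[:, j] += gamma * y[:, j - 1]
    return y

The gold reference functions in the test code below implement the same recurrence column by column.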
@@ -11,7 +11,7 @@ with open('requirements.txt') as f:
 setup(
     name='torch_discounted_cumsum',
     version=0.1,
-    description='Fast differentiable discounted cumulative sum',
+    description='Fast discounted cumulative sum in PyTorch',
     install_requires=requirements,
     python_requires='>=3.6',
     packages=find_packages(),
...
import time

import torch

from torch_discounted_cumsum import discounted_cumsum_left, discounted_cumsum_right


# Reference ("gold") implementation: out[:, i] = input[:, i] + gamma * out[:, i-1],
# scanning columns left to right.
def discounted_cumsum_left_gold(input, gamma):
    assert input.dim() == 2
    assert 0 <= gamma <= 1
    out = []
    last_col = torch.zeros((input.shape[0], 1), dtype=input.dtype, device=input.device)
    for i in range(input.shape[1]):
        cur_col = input[:, i].unsqueeze(-1)
        last_col = cur_col + gamma * last_col
        out.append(last_col)
    out = torch.cat(out, dim=1)
    return out


# Same recurrence scanned right to left: out[:, i] = input[:, i] + gamma * out[:, i+1].
def discounted_cumsum_right_gold(input, gamma):
    assert input.dim() == 2
    assert 0 <= gamma <= 1
    out = []
    last_col = torch.zeros((input.shape[0], 1), dtype=input.dtype, device=input.device)
    for i in reversed(range(input.shape[1])):
        cur_col = input[:, i].unsqueeze(-1)
        last_col = cur_col + gamma * last_col
        out.insert(0, last_col)
    out = torch.cat(out, dim=1)
    return out

def test_left():
    torch.manual_seed(0)
    x = torch.full((10, 10000), fill_value=1.0, dtype=torch.float32).cuda()
    gamma = 0.99
    out_gold_32 = discounted_cumsum_left_gold(x, gamma)
    out_gold_64 = discounted_cumsum_left_gold(x.double(), gamma)
    out_fn = discounted_cumsum_left(x, gamma)
    diff_32 = (out_fn - out_gold_32).abs().max().item()
    diff_64 = (out_fn - out_gold_64).abs().max().item()
    print('left diff_32', diff_32)
    print('left diff_64', diff_64)


def test_right():
    torch.manual_seed(0)
    x = torch.full((10, 10000), fill_value=1.0, dtype=torch.float32).cuda()
    gamma = 0.99
    out_gold_32 = discounted_cumsum_right_gold(x, gamma)
    out_gold_64 = discounted_cumsum_right_gold(x.double(), gamma)
    out_fn = discounted_cumsum_right(x, gamma)
    diff_32 = (out_fn - out_gold_32).abs().max().item()
    diff_64 = (out_fn - out_gold_64).abs().max().item()
    print('right diff_32', diff_32)
    print('right diff_64', diff_64)

def test_grad_left():
    torch.manual_seed(0)
    x = torch.full((10, 10000), fill_value=1.0, dtype=torch.float32).cuda()
    x = torch.nn.Parameter(x)
    gamma = 0.99
    out_gold = discounted_cumsum_left_gold(x, gamma)
    out_gold.sum().backward()
    out_gold_grad = x.grad.clone()
    del x.grad
    out_fn = discounted_cumsum_left(x, gamma)
    out_fn.sum().backward()
    out_fn_grad = x.grad.clone()
    diff_grad = (out_gold_grad - out_fn_grad).abs().max().item()
    print('left diff_grad', diff_grad)


def test_grad_right():
    torch.manual_seed(0)
    x = torch.full((10, 10000), fill_value=1.0, dtype=torch.float32).cuda()
    x = torch.nn.Parameter(x)
    gamma = 0.99
    out_gold = discounted_cumsum_right_gold(x, gamma)
    out_gold.sum().backward()
    out_gold_grad = x.grad.clone()
    del x.grad
    out_fn = discounted_cumsum_right(x, gamma)
    out_fn.sum().backward()
    out_fn_grad = x.grad.clone()
    diff_grad = (out_gold_grad - out_fn_grad).abs().max().item()
    print('right diff_grad', diff_grad)

def test_speed(reps=10000):
    torch.manual_seed(0)
    x = torch.randn(10, 100000, dtype=torch.float32).cuda()
    gamma = 0.99
    t1 = time.time()
    for _ in range(reps):
        discounted_cumsum_right(x, gamma)
    t2 = time.time()
    print('sec:', t2 - t1)


if __name__ == '__main__':
    test_left()
    test_right()
    test_grad_left()
    test_grad_right()
    # test_speed()

import os
import random
import time
import unittest

import torch
from tqdm import tqdm

from torch_discounted_cumsum import discounted_cumsum_left, discounted_cumsum_right


# Collect the gradient of out.sum() w.r.t. param and reset param.grad.
def get_grad(param, out):
    out.sum().backward()
    grad = param.grad.clone()
    del param.grad
    return grad


# Reference ("gold") implementations of the discounted cumulative sum,
# computed column by column: out[:, i] = input[:, i] + gamma * out[:, i-1].
def discounted_cumsum_left_gold(input, gamma):
    assert input.dim() == 2
    assert 0 <= gamma <= 1
    out = []
    last_col = torch.zeros((input.shape[0], 1), dtype=input.dtype, device=input.device)
    for i in range(input.shape[1]):
        cur_col = input[:, i].unsqueeze(-1)
        last_col = cur_col + gamma * last_col
        out.append(last_col)
    out = torch.cat(out, dim=1)
    return out


def discounted_cumsum_right_gold(input, gamma):
    assert input.dim() == 2
    assert 0 <= gamma <= 1
    out = []
    last_col = torch.zeros((input.shape[0], 1), dtype=input.dtype, device=input.device)
    for i in reversed(range(input.shape[1])):
        cur_col = input[:, i].unsqueeze(-1)
        last_col = cur_col + gamma * last_col
        out.insert(0, last_col)
    out = torch.cat(out, dim=1)
    return out


def discounted_cumsum_lib(x, gamma, dir):
    return {
        'left': discounted_cumsum_left,
        'right': discounted_cumsum_right,
    }[dir](x, gamma)


def discounted_cumsum_gold(x, gamma, dir):
    return {
        'left': discounted_cumsum_left_gold,
        'right': discounted_cumsum_right_gold,
    }[dir](x, gamma)

# Compare library and gold outputs and gradients; return their L-infinity differences.
def compute_linf(batchsz, veclen, dir, gamma=0.99, dtype=torch.float32, cuda=False, data='randn', tol=1e-3, seed=2020):
    torch.manual_seed(seed)
    if data == 'randn':
        x = torch.randn((batchsz, veclen), dtype=dtype)
    elif data == 'ones':
        x = torch.ones((batchsz, veclen), dtype=dtype)
    else:
        raise ValueError('Invalid data generation identifier')
    if cuda:
        x = x.cuda()
    x = torch.nn.Parameter(x)
    out_gold = discounted_cumsum_gold(x, gamma, dir)
    grad_gold = get_grad(x, out_gold)
    out_lib = discounted_cumsum_lib(x, gamma, dir)
    grad_lib = get_grad(x, out_lib)
    out_linf = (out_lib - out_gold).abs().max().item()
    grad_linf = (grad_lib - grad_gold).abs().max().item()
    if out_linf >= tol or grad_linf >= tol:
        print(f'x={x}\nout_gold={out_gold}\nout_lib={out_lib}\ngrad_gold={grad_gold}\ngrad_lib={grad_lib}\n')
    return out_linf, grad_linf

class TestDiscountedCumSum(unittest.TestCase):
    def test_validity(self):
        print('Testing validity...')
        is_cuda = os.environ.get('CUDA_VISIBLE_DEVICES', '') != ''
        for cuda in (True, False):
            if cuda and not is_cuda:
                print('Skipping validity CUDA tests')
                continue
            rng = random.Random(2020)
            with tqdm(total=2*2*2*17) as pbar:
                for data in ('ones', 'randn'):
                    for dtype in (torch.float32, torch.float64):
                        for i in range(2):
                            batchsz = 8 ** i
                            for j in range(17):
                                veclen = max(1, 2 ** j + rng.randint(-1, 1))
                                gamma = rng.random()
                                seed = rng.randint(0, 2 ** 16)
                                dir = rng.choice(['left', 'right'])
                                tol = 2e-3
                                out_linf, grad_linf = compute_linf(
                                    batchsz, veclen, dir, gamma, dtype, cuda, data, tol, seed
                                )
                                msg = f'Validity test failed with batchsz={batchsz}, veclen={veclen}, dir={dir}, ' \
                                      f'gamma={gamma}, dtype={dtype}, cuda={cuda}, data={data}, seed={seed}, ' \
                                      f'out_linf={out_linf}, grad_linf={grad_linf}'
                                self.assertLess(out_linf, tol, msg)
                                self.assertLess(grad_linf, tol, msg)
                                pbar.update(1)

    def test_precision(self):
        print('Testing precision...')
        is_cuda = os.environ.get('CUDA_VISIBLE_DEVICES', '') != ''
        if not is_cuda:
            print('Skipping precision tests')
            return
        batchsz = 1
        veclen = 10000
        gamma = 0.99
        dir = 'right'
        for data in ('ones', 'randn'):
            if data == 'ones':
                precision_factor = 2.0
            else:
                precision_factor = 1.1
            torch.manual_seed(2020)
            if data == 'randn':
                x_32 = torch.randn((batchsz, veclen), dtype=torch.float32)
            elif data == 'ones':
                x_32 = torch.ones((batchsz, veclen), dtype=torch.float32)
            else:
                raise ValueError('Invalid data generation identifier')
            x_32 = x_32.cuda()
            x_64 = x_32.double()
            gold_64 = discounted_cumsum_gold(x_64, gamma, dir)
            gold_32 = discounted_cumsum_gold(x_32, gamma, dir).double()
            lib_32 = discounted_cumsum_lib(x_32, gamma, dir).double()
            err_32_gold = (gold_32 - gold_64).abs().max().item()
            err_32_lib = (lib_32 - gold_64).abs().max().item()
            msg = f'Precision improvement test failed with data={data}, ' \
                  f'err_32_gold={err_32_gold}, err_32_lib={err_32_lib}'
            self.assertLess(precision_factor * err_32_lib, err_32_gold, msg)
            print(f'data={data}\nerr_32_gold={err_32_gold:10.8f}\nerr_32_lib ={err_32_lib:10.8f}')

    def test_speed(self):
        print('Testing speed...')
        is_cuda = os.environ.get('CUDA_VISIBLE_DEVICES', '') != ''
        NUM_RUNS = 30
        NUM_RUNS_GOLD = 6
        if not is_cuda:
            print('Skipping speed tests')
            return
        gamma = 0.99
        x_32 = torch.randn((1, 100000), dtype=torch.float32)
        x_32 += torch.ones_like(x_32)
        x_32_gpu = x_32.cuda()
        timer = time.clock_gettime(time.CLOCK_MONOTONIC)
        for _ in tqdm(range(NUM_RUNS_GOLD), desc='gold', leave=True):
            discounted_cumsum_right_gold(x_32, gamma)
        dur_gold = time.clock_gettime(time.CLOCK_MONOTONIC) - timer
        dur_gold = dur_gold * NUM_RUNS / NUM_RUNS_GOLD
        timer = time.clock_gettime(time.CLOCK_MONOTONIC)
        for _ in tqdm(range(NUM_RUNS), desc='lib_cpu', leave=True):
            discounted_cumsum_right(x_32, gamma)
        dur_lib_cpu = time.clock_gettime(time.CLOCK_MONOTONIC) - timer
        timer = time.clock_gettime(time.CLOCK_MONOTONIC)
        for _ in tqdm(range(NUM_RUNS), desc='lib_cuda', leave=True):
            discounted_cumsum_right(x_32_gpu, gamma)
        dur_lib_cuda = time.clock_gettime(time.CLOCK_MONOTONIC) - timer
        print(f'dur_gold: {dur_gold:7.4f} sec')
        print(f'dur_lib_cpu: {dur_lib_cpu:7.4f} sec')
        print(f'dur_lib_cuda: {dur_lib_cuda:7.4f} sec')
        print(f'speedup gold -> lib_cpu: {dur_gold / dur_lib_cpu:5.2f}')
        print(f'speedup gold -> lib_cuda: {dur_gold / dur_lib_cuda:5.2f}')
        print(f'speedup lib_cpu -> lib_cuda: {dur_lib_cpu / dur_lib_cuda:5.2f}')

if __name__ == '__main__':
    unittest.main()
@@ -72,7 +72,7 @@ class DiscountedCumSumLeftFunction(torch.autograd.Function):
     @staticmethod
     def backward(ctx, grad_output):
-        gamma = ctx.saved_variables[0].item()
+        gamma = ctx.saved_tensors[0].item()
         grad_input = _discounted_cumsum_right_dispatcher(grad_output, gamma)
         return grad_input, None
@@ -86,7 +86,7 @@ class DiscountedCumSumRightFunction(torch.autograd.Function):
     @staticmethod
     def backward(ctx, grad_output):
-        gamma = ctx.saved_variables[0].item()
+        gamma = ctx.saved_tensors[0].item()
         grad_input = _discounted_cumsum_left_dispatcher(grad_output, gamma)
         return grad_input, None
...
@@ -26,7 +26,7 @@ torch::Tensor discounted_cumsum_left_cpu(torch::Tensor x, double gamma) {
         auto ya = y.accessor<scalar_t, 2>();
         for (int j=0; j<y.size(1); j++) {
             int j_left = j-1;
-            if (j_left == 0) {
+            if (j_left == -1) {
                 continue;
             }
             discounted_sum_update(ya, y.size(0), gamma, j, j_left);
@@ -47,7 +47,7 @@ torch::Tensor discounted_cumsum_right_cpu(torch::Tensor x, double gamma) {
         auto ya = y.accessor<scalar_t, 2>();
         for (int j=y.size(1)-1; j>=0; j--) {
             int j_right = j+1;
-            if (j_right == 0) {
+            if (j_right == y.size(1)) {
                 continue;
             }
             discounted_sum_update(ya, y.size(0), gamma, j, j_right);
...