Commit 69abe873 authored by anton

add left cumsum kernel specialization

refactor variable names
parent 1613a3eb
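The new discounted_cumsum_left kernel is the mirror of the existing right variant: the left scan accumulates y[:, i] = x[:, i] + gamma * y[:, i - 1] from the first column, while the right scan accumulates y[:, i] = x[:, i] + gamma * y[:, i + 1] from the last column. A minimal sequential sketch in plain PyTorch (the reference_* helper names are illustrative and not part of the extension):

import torch

def reference_left(x, gamma):
    # y[:, i] = x[:, i] + gamma * y[:, i - 1], scanning from the first column
    y = torch.empty_like(x)
    acc = torch.zeros(x.shape[0], dtype=x.dtype, device=x.device)
    for i in range(x.shape[1]):
        acc = x[:, i] + gamma * acc
        y[:, i] = acc
    return y

def reference_right(x, gamma):
    # y[:, i] = x[:, i] + gamma * y[:, i + 1], scanning from the last column
    y = torch.empty_like(x)
    acc = torch.zeros(x.shape[0], dtype=x.dtype, device=x.device)
    for i in reversed(range(x.shape[1])):
        acc = x[:, i] + gamma * acc
        y[:, i] = acc
    return y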
#include <torch/extension.h>

torch::Tensor discounted_cumsum_left(torch::Tensor x, double gamma);
torch::Tensor discounted_cumsum_right(torch::Tensor x, double gamma);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("discounted_cumsum_left", &discounted_cumsum_left, "Discounted Cumulative Sum (Left)");
    m.def("discounted_cumsum_right", &discounted_cumsum_right, "Discounted Cumulative Sum (Right)");
}
@@ -29,10 +29,27 @@ torch_discounted_cumsum = load(

# return d_input, d_weights, d_bias, d_old_h, d_old_cell


def discounted_cumsum_left(input, gamma):
    return torch_discounted_cumsum.discounted_cumsum_left(input, gamma)


def discounted_cumsum_right(input, gamma):
    return torch_discounted_cumsum.discounted_cumsum_right(input, gamma)


def discounted_cumsum_left_gold(input, gamma):
    assert input.dim() == 2
    assert 0 <= gamma <= 1
    out = []
    last_col = torch.zeros((input.shape[0], 1), dtype=input.dtype, device=input.device)
    for i in range(input.shape[1]):
        cur_col = input[:, i].unsqueeze(-1)
        last_col = cur_col + gamma * last_col
        out.append(last_col)
    out = torch.cat(out, dim=1)
    return out


def discounted_cumsum_right_gold(input, gamma):
    assert input.dim() == 2
    assert 0 <= gamma <= 1
@@ -46,7 +63,20 @@ def discounted_cumsum_right_gold(input, gamma):
    return out


def test_left():
    torch.manual_seed(0)
    x = torch.full((10, 10000), fill_value=1.0, dtype=torch.float32).cuda()
    gamma = 0.99
    out_gold_32 = discounted_cumsum_left_gold(x, gamma)
    out_gold_64 = discounted_cumsum_left_gold(x.double(), gamma)
    out_fn = discounted_cumsum_left(x, gamma)
    diff_32 = (out_fn - out_gold_32).abs().max().item()
    diff_64 = (out_fn - out_gold_64).abs().max().item()
    print('left diff_32', diff_32)
    print('left diff_64', diff_64)


def test_right():
    torch.manual_seed(0)
    x = torch.full((10, 10000), fill_value=1.0, dtype=torch.float32).cuda()
    gamma = 0.99
@@ -55,8 +85,8 @@ def test():
    out_fn = discounted_cumsum_right(x, gamma)
    diff_32 = (out_fn - out_gold_32).abs().max().item()
    diff_64 = (out_fn - out_gold_64).abs().max().item()
    print('right diff_32', diff_32)
    print('right diff_64', diff_64)


def test_speed(reps=10000):
@@ -71,5 +101,6 @@ def test_speed(reps=10000):

if __name__ == '__main__':
    test_left()
    test_right()
    #test_speed()
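Because the tests use a constant input of ones, the left result has a closed form: column i equals the finite geometric series sum_{k=0..i} gamma^k = (1 - gamma^(i+1)) / (1 - gamma). A hypothetical sanity check against the float64 gold (not part of the test file, names assumed from above):

import torch

gamma = 0.99
cols = torch.arange(10000, dtype=torch.float64)
closed_form = (1 - gamma ** (cols + 1)) / (1 - gamma)   # expected left cumsum of an all-ones row
gold = discounted_cumsum_left_gold(torch.ones(1, 10000, dtype=torch.float64), gamma)[0]
print((gold - closed_form).abs().max().item())          # should be tiny, float64 rounding only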
@@ -3,38 +3,50 @@

template <typename scalar_t>
__device__ __forceinline__
scalar_t discounted_sum_power(scalar_t a, scalar_t b, scalar_t gamma, int power) {
    return a + b * pow(gamma, scalar_t(power));
}


enum SumDirection {
    SUM_DIRECTION_LEFT,
    SUM_DIRECTION_RIGHT,
};


template <SumDirection sum_direction>
__device__ __forceinline__
void resolve_positions(
    const int &stride_prev_group, const int &stride_cur_group, const int &group_of_thread, const int &thread_in_group,
    int &change_pos, int &discounted_pos, int &discount_power
);


template <>
__device__ __forceinline__
void resolve_positions<SUM_DIRECTION_LEFT>(
    const int &stride_prev_group, const int &stride_cur_group, const int &group_of_thread, const int &thread_in_group,
    int &change_pos, int &discounted_pos, int &discount_power
) {
    change_pos = group_of_thread * stride_cur_group + thread_in_group + stride_prev_group;
    discounted_pos = group_of_thread * stride_cur_group + stride_prev_group - 1;
    discount_power = thread_in_group + 1;
}


template <>
__device__ __forceinline__
void resolve_positions<SUM_DIRECTION_RIGHT>(
    const int &stride_prev_group, const int &stride_cur_group, const int &group_of_thread, const int &thread_in_group,
    int &change_pos, int &discounted_pos, int &discount_power
) {
    change_pos = group_of_thread * stride_cur_group + thread_in_group;
    discounted_pos = group_of_thread * stride_cur_group + stride_prev_group;
    discount_power = stride_prev_group - thread_in_group;
}
template <typename scalar_t, SumDirection sum_direction>
__global__
void discounted_cumsum_kernel_stage(
    torch::PackedTensorAccessor32<scalar_t, 2> x,
@@ -42,26 +54,22 @@ void discounted_cumsum_kernel_stage(
    int stage
) {
    const int len = x.size(1);
    const int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int thread_idy = blockIdx.y * blockDim.y + threadIdx.y;

    if (thread_idy >= x.size(0)) {
        return;
    }

    int stride_prev_group = 1 << stage;
    int stride_cur_group = stride_prev_group << 1;

    int group_of_thread = thread_idx >> stage;
    int thread_in_group = thread_idx - (group_of_thread << stage);

    int change_pos, discounted_pos, discount_power;
    resolve_positions<sum_direction>(
        stride_prev_group, stride_cur_group, group_of_thread, thread_in_group,
        change_pos, discounted_pos, discount_power
    );
@@ -69,9 +77,9 @@ void discounted_cumsum_kernel_stage(
        return;
    }

    x[thread_idy][change_pos] = discounted_sum_power(
        x[thread_idy][change_pos],
        x[thread_idy][discounted_pos],
        gamma,
        discount_power
    );
@@ -84,7 +92,7 @@ int log2ceil(int x) {
}


template <SumDirection sum_direction>
torch::Tensor discounted_cumsum(torch::Tensor x, double gamma) {
    // Minimum required number of threads, assigned dynamically to their respective positions on each iteration.
    // This results in uncoalesced writes, which are still faster than coalesced writes with half of the threads idling.
@@ -107,7 +115,7 @@ torch::Tensor discounted_cumsum(torch::Tensor x, double gamma) {
    for (int stage=0; stage<nstages; stage++) {
        AT_DISPATCH_FLOATING_TYPES(x.type(), "discounted_cumsum_kernel_stage", ([&] {
            discounted_cumsum_kernel_stage<scalar_t, sum_direction><<<blocks, threads>>>(
                y.packed_accessor32<scalar_t, 2>(),
                scalar_t(gamma),
                stage
@@ -119,10 +127,11 @@ torch::Tensor discounted_cumsum(torch::Tensor x, double gamma) {
}


torch::Tensor discounted_cumsum_left(torch::Tensor x, double gamma) {
    return discounted_cumsum<SUM_DIRECTION_LEFT>(x, gamma);
}


torch::Tensor discounted_cumsum_right(torch::Tensor x, double gamma) {
    return discounted_cumsum<SUM_DIRECTION_RIGHT>(x, gamma);
}
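For intuition about the index arithmetic in resolve_positions: each kernel stage merges adjacent, already-scanned blocks of length 2^stage into blocks of twice that length, and every element in the block half that is still incomplete adds the boundary element of the other half, discounted by the matching power of gamma. A CPU model of the LEFT direction in plain Python is sketched below; it is illustrative only, assumes out-of-range positions are simply skipped (the kernel's elided bounds check), and uses one loop iteration per GPU "thread".

# Illustrative CPU model of the staged LEFT scan; mirrors resolve_positions<SUM_DIRECTION_LEFT>.
import math

def staged_cumsum_left(values, gamma):
    x = list(values)
    n = len(x)
    nstages = math.ceil(math.log2(n)) if n > 1 else 0
    for stage in range(nstages):
        stride_prev_group = 1 << stage             # length of an already-scanned block
        stride_cur_group = stride_prev_group << 1  # length of the merged block
        for group_start in range(0, n, stride_cur_group):
            for thread_in_group in range(stride_prev_group):   # one "thread" per right-half element
                change_pos = group_start + stride_prev_group + thread_in_group
                if change_pos >= n:
                    break
                discounted_pos = group_start + stride_prev_group - 1   # last element of the left half
                discount_power = thread_in_group + 1
                x[change_pos] += x[discounted_pos] * gamma ** discount_power
    return x

def sequential_left(values, gamma):
    # Reference recurrence: y[i] = x[i] + gamma * y[i - 1]
    out, acc = [], 0.0
    for v in values:
        acc = v + gamma * acc
        out.append(acc)
    return out

vals = [float(i) for i in range(1, 11)]
assert all(abs(a - b) < 1e-9
           for a, b in zip(staged_cumsum_left(vals, 0.9), sequential_left(vals, 0.9)))

The staged model performs O(n log n) additions in log2(n) parallel steps, which is the trade the kernel makes for exposing enough parallelism per row.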