Commit 69abe873 authored by anton

add left cumsum kernel specialization

refactor variable names
parent 1613a3eb
#include <torch/extension.h>
torch::Tensor discounted_cumsum_left(torch::Tensor x, double gamma);
torch::Tensor discounted_cumsum_right(torch::Tensor x, double gamma);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("discounted_cumsum_right", &discounted_cumsum_right,
"Discounted Cumulative Sum Right");
m.def("discounted_cumsum_left", &discounted_cumsum_left, "Discounted Cumulative Sum (Left)");
m.def("discounted_cumsum_right", &discounted_cumsum_right, "Discounted Cumulative Sum (Right)");
}
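A minimal usage sketch of these bindings from Python, assuming the extension is JIT-compiled with torch.utils.cpp_extension.load into a module named torch_discounted_cumsum (as in the test module below); the source file names here are placeholders, not taken from this commit:

import torch
from torch.utils.cpp_extension import load

# Hypothetical build step; adjust the source paths to the actual files.
torch_discounted_cumsum = load(
    name='torch_discounted_cumsum',
    sources=['discounted_cumsum.cpp', 'discounted_cumsum_kernel.cu'],
)

x = torch.randn(4, 8, device='cuda')
gamma = 0.9
left = torch_discounted_cumsum.discounted_cumsum_left(x, gamma)    # out[:, i] = sum_{j <= i} gamma**(i-j) * x[:, j]
right = torch_discounted_cumsum.discounted_cumsum_right(x, gamma)  # out[:, i] = sum_{j >= i} gamma**(j-i) * x[:, j]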
@@ -29,10 +29,27 @@ torch_discounted_cumsum = load(
# return d_input, d_weights, d_bias, d_old_h, d_old_cell
def discounted_cumsum_left(input, gamma):
return torch_discounted_cumsum.discounted_cumsum_left(input, gamma)
def discounted_cumsum_right(input, gamma):
return torch_discounted_cumsum.discounted_cumsum_right(input, gamma)
def discounted_cumsum_left_gold(input, gamma):
assert input.dim() == 2
assert 0 <= gamma <= 1
out = []
last_col = torch.zeros((input.shape[0], 1), dtype=input.dtype, device=input.device)
for i in range(input.shape[1]):
cur_col = input[:, i].unsqueeze(-1)
last_col = cur_col + gamma * last_col
out.append(last_col)
out = torch.cat(out, dim=1)
return out
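The recurrence above unrolls to the closed form out[:, i] = sum over j <= i of gamma**(i - j) * input[:, j]. A brute-force sketch of that closed form (not part of the commit, O(n^2), for cross-checking only):

def discounted_cumsum_left_naive(input, gamma):
    # Direct evaluation of the closed form of the recurrence above.
    out = torch.zeros_like(input)
    for i in range(input.shape[1]):
        for j in range(i + 1):
            out[:, i] += gamma ** (i - j) * input[:, j]
    return out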
def discounted_cumsum_right_gold(input, gamma):
assert input.dim() == 2
assert 0 <= gamma <= 1
@@ -46,7 +63,20 @@ def discounted_cumsum_right_gold(input, gamma):
return out
def test():
def test_left():
torch.manual_seed(0)
x = torch.full((10, 10000), fill_value=1.0, dtype=torch.float32).cuda()
gamma = 0.99
out_gold_32 = discounted_cumsum_left_gold(x, gamma)
out_gold_64 = discounted_cumsum_left_gold(x.double(), gamma)
out_fn = discounted_cumsum_left(x, gamma)
diff_32 = (out_fn - out_gold_32).abs().max().item()
diff_64 = (out_fn - out_gold_64).abs().max().item()
print('left diff_32', diff_32)
print('left diff_64', diff_64)
def test_right():
torch.manual_seed(0)
x = torch.full((10, 10000), fill_value=1.0, dtype=torch.float32).cuda()
gamma = 0.99
@@ -55,8 +85,8 @@ def test():
out_fn = discounted_cumsum_right(x, gamma)
diff_32 = (out_fn - out_gold_32).abs().max().item()
diff_64 = (out_fn - out_gold_64).abs().max().item()
print('diff_32', diff_32)
print('diff_64', diff_64)
print('right diff_32', diff_32)
print('right diff_64', diff_64)
def test_speed(reps=10000):
@@ -71,5 +101,6 @@ def test_speed(reps=10000):
if __name__ == '__main__':
test()
test_speed()
test_left()
test_right()
#test_speed()
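An additional sanity check that follows from the two definitions (a property, not a test present in this commit): the left scan equals the right scan applied to the input reversed along dim 1 and reversed back. A hedged sketch:

def test_left_right_flip_consistency():
    # left(x) should match right(x.flip(1)).flip(1) up to float error.
    torch.manual_seed(0)
    x = torch.randn(10, 1000, dtype=torch.float32).cuda()
    gamma = 0.99
    out_left = discounted_cumsum_left(x, gamma)
    out_flip = discounted_cumsum_right(x.flip(1), gamma).flip(1)
    print('left/right flip diff', (out_left - out_flip).abs().max().item())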
@@ -3,65 +3,73 @@
template <typename scalar_t>
__device__ __forceinline__
scalar_t discounted_sum_pow(scalar_t a, scalar_t b, scalar_t gamma, int power) {
scalar_t discounted_sum_power(scalar_t a, scalar_t b, scalar_t gamma, int power) {
return a + b * pow(gamma, scalar_t(power));
}
enum SumDirection {
SUM_RIGHT,
SUM_LEFT
SUM_DIRECTION_LEFT,
SUM_DIRECTION_RIGHT,
};
template <SumDirection d>
template <SumDirection sum_direction>
__device__ __forceinline__
void resolve_positions(
const int &gr_prev_stride, const int &gr_cur_stride, const int &gr_of_thread, const int &thread_in_gr,
int &change_pos, int &discounted_pos, int &discount_power
const int &stride_prev_group, const int &stride_cur_group, const int &group_of_thread, const int &thread_in_group,
int &change_pos, int &discounted_pos, int &discount_power
);
template <>
__device__ __forceinline__
void resolve_positions<SUM_RIGHT>(
const int &gr_prev_stride, const int &gr_cur_stride, const int &gr_of_thread, const int &thread_in_gr,
int &change_pos, int &discounted_pos, int &discount_power
void resolve_positions<SUM_DIRECTION_LEFT>(
const int &stride_prev_group, const int &stride_cur_group, const int &group_of_thread, const int &thread_in_group,
int &change_pos, int &discounted_pos, int &discount_power
) {
change_pos = gr_of_thread * gr_cur_stride + thread_in_gr;
discounted_pos = gr_of_thread * gr_cur_stride + gr_prev_stride;
discount_power = gr_prev_stride - thread_in_gr;
change_pos = group_of_thread * stride_cur_group + thread_in_group + stride_prev_group;
discounted_pos = group_of_thread * stride_cur_group + stride_prev_group - 1;
discount_power = thread_in_group + 1;
}
template <typename scalar_t, SumDirection d>
template <>
__device__ __forceinline__
void resolve_positions<SUM_DIRECTION_RIGHT>(
const int &stride_prev_group, const int &stride_cur_group, const int &group_of_thread, const int &thread_in_group,
int &change_pos, int &discounted_pos, int &discount_power
) {
change_pos = group_of_thread * stride_cur_group + thread_in_group;
discounted_pos = group_of_thread * stride_cur_group + stride_prev_group;
discount_power = stride_prev_group - thread_in_group;
}
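To see what the two specializations compute: at each stage, adjacent blocks of size 2**stage are merged pairwise. The right-direction kernel updates positions in the left half of each merged block using the first element of the right half, while the new left-direction kernel updates positions in the right half using the last element of the left half. The following host-side Python sketch (not part of the kernel sources) replays the same stage loop on a single row; the index formulas are copied from resolve_positions above, while the per-stage thread count and the bounds check are assumptions, since those lines are collapsed in this diff view.

def discounted_cumsum_sim(row, gamma, direction):
    # Stage-wise in-place scan for one row; direction is 'left' or 'right'.
    x = list(row)
    n = len(x)
    nstages = (n - 1).bit_length()  # log2ceil(n), assuming n >= 1
    for stage in range(nstages):
        stride_prev_group = 1 << stage
        stride_cur_group = stride_prev_group << 1
        for thread_idx in range(n):  # the real kernel launches fewer threads (assumption)
            group_of_thread = thread_idx >> stage
            thread_in_group = thread_idx - (group_of_thread << stage)
            if direction == 'left':
                change_pos = group_of_thread * stride_cur_group + thread_in_group + stride_prev_group
                discounted_pos = group_of_thread * stride_cur_group + stride_prev_group - 1
                discount_power = thread_in_group + 1
            else:
                change_pos = group_of_thread * stride_cur_group + thread_in_group
                discounted_pos = group_of_thread * stride_cur_group + stride_prev_group
                discount_power = stride_prev_group - thread_in_group
            if change_pos >= n or discounted_pos >= n:  # assumed bounds check
                continue
            x[change_pos] += x[discounted_pos] * gamma ** discount_power
    return x

For example, discounted_cumsum_sim([1.0, 1.0, 1.0], 0.5, 'left') returns [1.0, 1.5, 1.75], matching the left recurrence used in the Python gold implementation.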
template <typename scalar_t, SumDirection sum_direction>
__global__
void discounted_cumsum_kernel_stage(
torch::PackedTensorAccessor32<scalar_t, 2> x,
const scalar_t gamma,
int stage
torch::PackedTensorAccessor32<scalar_t, 2> x,
const scalar_t gamma,
int stage
) {
const int len = x.size(1);
const int threadidx = blockIdx.x * blockDim.x + threadIdx.x;
const int threadidy = blockIdx.y * blockDim.y + threadIdx.y;
const int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
const int thread_idy = blockIdx.y * blockDim.y + threadIdx.y;
if (threadidy >= x.size(0)) {
if (thread_idy >= x.size(0)) {
return;
}
int gr_prev_stride = 1 << stage;
int gr_cur_stride = gr_prev_stride << 1;
int gr_of_thread = threadidx >> stage;
int thread_in_gr = threadidx - (gr_of_thread << stage);
int stride_prev_group = 1 << stage;
int stride_cur_group = stride_prev_group << 1;
//int change_pos = gr_of_thread * gr_cur_stride + thread_in_gr;
//int discounted_pos = gr_of_thread * gr_cur_stride + gr_prev_stride;
//int discount_power = gr_prev_stride - thread_in_gr;
int group_of_thread = thread_idx >> stage;
int thread_in_group = thread_idx - (group_of_thread << stage);
int change_pos, discounted_pos, discount_power;
resolve_positions<d>(
gr_prev_stride, gr_cur_stride, gr_of_thread, thread_in_gr,
resolve_positions<sum_direction>(
stride_prev_group, stride_cur_group, group_of_thread, thread_in_group,
change_pos, discounted_pos, discount_power
);
@@ -69,9 +77,9 @@ void discounted_cumsum_kernel_stage(
return;
}
x[threadidy][change_pos] = discounted_sum_pow(
x[threadidy][change_pos],
x[threadidy][discounted_pos],
x[thread_idy][change_pos] = discounted_sum_power(
x[thread_idy][change_pos],
x[thread_idy][discounted_pos],
gamma,
discount_power
);
@@ -84,7 +92,7 @@ int log2ceil(int x) {
}
template <SumDirection d>
template <SumDirection sum_direction>
torch::Tensor discounted_cumsum(torch::Tensor x, double gamma) {
// Minimum required number of threads, assigns them dynamically to respective positions upon each iteration.
// Results in uncoalesced writes, which is still faster than coalesced writes with half threads idling.
@@ -107,7 +115,7 @@ torch::Tensor discounted_cumsum(torch::Tensor x, double gamma) {
for (int stage=0; stage<nstages; stage++) {
AT_DISPATCH_FLOATING_TYPES(x.type(), "discounted_cumsum_kernel_stage", ([&] {
discounted_cumsum_kernel_stage<scalar_t, d><<<blocks, threads>>>(
discounted_cumsum_kernel_stage<scalar_t, sum_direction><<<blocks, threads>>>(
y.packed_accessor32<scalar_t, 2>(),
scalar_t(gamma),
stage
@@ -119,10 +127,11 @@ torch::Tensor discounted_cumsum(torch::Tensor x, double gamma) {
}
torch::Tensor discounted_cumsum_right(torch::Tensor x, double gamma) {
return discounted_cumsum<SUM_RIGHT>(x, gamma);
torch::Tensor discounted_cumsum_left(torch::Tensor x, double gamma) {
return discounted_cumsum<SUM_DIRECTION_LEFT>(x, gamma);
}
//torch::Tensor discounted_cumsum_left(torch::Tensor x, double gamma) {
//}
torch::Tensor discounted_cumsum_right(torch::Tensor x, double gamma) {
return discounted_cumsum<SUM_DIRECTION_RIGHT>(x, gamma);
}