Commit 4c605c1e authored by anton

removed the slower coalesced flavor

parent caff36e7
 #include <torch/extension.h>
 
-torch::Tensor discounted_cumsum_right_minthreads(torch::Tensor x, double gamma);
-torch::Tensor discounted_cumsum_right_coalesced(torch::Tensor x, double gamma);
+torch::Tensor discounted_cumsum_right(torch::Tensor x, double gamma);
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-    m.def("discounted_cumsum_right_minthreads", &discounted_cumsum_right_minthreads,
-          "Discounted Cumulative Sum Right Minimum Threads");
-    m.def("discounted_cumsum_right_coalesced", &discounted_cumsum_right_coalesced,
-          "Discounted Cumulative Sum Right Coalesced Writes");
+    m.def("discounted_cumsum_right", &discounted_cumsum_right,
+          "Discounted Cumulative Sum Right");
 }
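The binding above only exposes a name and a docstring, so for orientation here is a minimal Python reference of what the exported op presumably computes: the right-to-left discounted cumulative sum y[:, i] = x[:, i] + gamma * y[:, i + 1], i.e. y[:, i] = sum over j >= i of gamma**(j - i) * x[:, j]. The function name below is illustrative and not part of the extension.

import torch

def discounted_cumsum_right_reference(x: torch.Tensor, gamma: float) -> torch.Tensor:
    # Plain right-to-left recurrence; O(rows * cols), used only to explain the semantics.
    y = x.clone()
    for i in range(x.size(1) - 2, -1, -1):
        y[:, i] += gamma * y[:, i + 1]
    return y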
@@ -29,12 +29,8 @@ torch_discounted_cumsum = load(
 # return d_input, d_weights, d_bias, d_old_h, d_old_cell
 
 
-def discounted_cumsum_right_minthreads(input, gamma):
-    return torch_discounted_cumsum.discounted_cumsum_right_minthreads(input, gamma)
-
-
-def discounted_cumsum_right_coalesced(input, gamma):
-    return torch_discounted_cumsum.discounted_cumsum_right_coalesced(input, gamma)
+def discounted_cumsum_right(input, gamma):
+    return torch_discounted_cumsum.discounted_cumsum_right(input, gamma)
 
 
 def discounted_cumsum_right_gold(input, gamma):
@@ -50,33 +46,30 @@ def discounted_cumsum_right_gold(input, gamma):
     return out
 
 
-def test_fn(fn):
+def test():
     torch.manual_seed(0)
     x = torch.full((10, 10000), fill_value=1.0, dtype=torch.float32).cuda()
     gamma = 0.99
     out_gold_32 = discounted_cumsum_right_gold(x, gamma)
     out_gold_64 = discounted_cumsum_right_gold(x.double(), gamma)
-    out_fn = fn(x, gamma)
+    out_fn = discounted_cumsum_right(x, gamma)
     diff_32 = (out_fn - out_gold_32).abs().max().item()
     diff_64 = (out_fn - out_gold_64).abs().max().item()
-    print(fn.__name__)
     print('diff_32', diff_32)
     print('diff_64', diff_64)
 
 
-def test_speed(fn, reps=10000):
+def test_speed(reps=10000):
     torch.manual_seed(0)
     x = torch.randn(10, 100000, dtype=torch.float32).cuda()
     gamma = 0.99
     t1 = time.time()
     for _ in range(reps):
-        fn(x, gamma)
+        discounted_cumsum_right(x, gamma)
     t2 = time.time()
-    print(fn.__name__, t2-t1)
+    print('sec:', t2-t1)
 
 
 if __name__ == '__main__':
-    test_fn(discounted_cumsum_right_minthreads)
-    test_fn(discounted_cumsum_right_coalesced)
-    test_speed(discounted_cumsum_right_minthreads)
-    test_speed(discounted_cumsum_right_coalesced)
+    test()
+    test_speed()
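One caveat about test_speed above: CUDA kernel launches are asynchronous, so wrapping the loop in time.time() measures mostly launch and queueing time rather than kernel execution. Below is a hedged sketch of a synchronized variant, assuming a CUDA device is available; the helper name sync_timed is made up for illustration.

import time
import torch

def sync_timed(fn, x, gamma, reps=10000):
    torch.cuda.synchronize()   # drain any pending work before starting the clock
    t1 = time.time()
    for _ in range(reps):
        fn(x, gamma)
    torch.cuda.synchronize()   # wait for the last launched kernel to finish
    t2 = time.time()
    print('sec:', t2 - t1)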
@@ -14,14 +14,11 @@ int log2ceil(int x) {
 template <typename scalar_t>
-__global__ void discounted_cumsum_right_kernel_minthreads_stage(
+__global__ void discounted_cumsum_right_kernel_stage(
     torch::PackedTensorAccessor32<scalar_t, 2> x,
     const scalar_t gamma,
     int stage
 ) {
-    // Pros: Minimum required number of threads, assigns them dynamically to respective positions upon each iteration.
-    // Cons: Uncoalesced writes.
     const int len = x.size(1);
     const int threadidx = blockIdx.x * blockDim.x + threadIdx.x;
     const int threadidy = blockIdx.y * blockDim.y + threadIdx.y;
@@ -53,49 +50,9 @@ __global__ void discounted_cumsum_right_kernel_minthreads_stage(
 }
 
 
-template <typename scalar_t>
-__global__ void discounted_cumsum_right_kernel_coalesced_stage(
-    torch::PackedTensorAccessor32<scalar_t, 2> x,
-    const scalar_t gamma,
-    int stage
-) {
-    // Pros: Coalesced writes.
-    // Cons: Threads allocated statically per each element. Half of threads idles upon each iteration.
-    const int len = x.size(1);
-    const int threadidx = blockIdx.x * blockDim.x + threadIdx.x;
-    const int threadidy = blockIdx.y * blockDim.y + threadIdx.y;
-
-    if (threadidx >= len || threadidy >= x.size(0)) {
-        return;
-    }
-
-    int gr_prev_stride = 1 << stage;
-    int gr_cur_stride = gr_prev_stride << 1;
-    int gr_of_thread = threadidx >> (stage + 1);
-    int thread_in_gr = threadidx - (gr_of_thread << (stage + 1));
-
-    int change_pos = threadidx;
-    int discounted_pos = gr_of_thread * gr_cur_stride + gr_prev_stride;
-    int discount_power = gr_prev_stride - thread_in_gr;
-
-    if (thread_in_gr >= gr_prev_stride || discounted_pos >= len) {
-        return;
-    }
-
-    x[threadidy][change_pos] = discounted_sum_pow(
-        x[threadidy][change_pos],
-        x[threadidy][discounted_pos],
-        gamma,
-        discount_power
-    );
-}
-
-
-torch::Tensor discounted_cumsum_right_minthreads(torch::Tensor x, double gamma) {
-    // Pros: Minimum required number of threads, assigns them dynamically to respective positions upon each iteration.
-    // Cons: Uncoalesced writes.
+torch::Tensor discounted_cumsum_right(torch::Tensor x, double gamma) {
+    // Minimum required number of threads, assigns them dynamically to respective positions upon each iteration.
+    // Results in uncoalesced writes, which is still faster than coalesced writes with half threads idling.
     TORCH_CHECK(x.type().is_cuda(), "Input must be a CUDA tensor");
     TORCH_CHECK(x.is_contiguous(), "Input must be contiguous");
@@ -108,47 +65,14 @@ torch::Tensor discounted_cumsum_right_minthreads(torch::Tensor x, double gamma)
     auto y = x.clone();
 
-    const int threads = 32;
+    const int threads = 64;
     const int nstages = log2ceil(x.size(1));
     const int threads_total_x = 1 << (nstages - 1);
     const dim3 blocks((threads_total_x + threads - 1) / threads, x.size(0));
 
     for (int stage=0; stage<nstages; stage++) {
-        AT_DISPATCH_FLOATING_TYPES(x.type(), "discounted_cumsum_right_kernel_minthreads_stage", ([&] {
-            discounted_cumsum_right_kernel_minthreads_stage<scalar_t><<<blocks, threads>>>(
-                y.packed_accessor32<scalar_t, 2>(),
-                scalar_t(gamma),
-                stage
-            );
-        }));
-    }
-
-    return y;
-}
-
-
-torch::Tensor discounted_cumsum_right_coalesced(torch::Tensor x, double gamma) {
-    // Pros: Coalesced writes.
-    // Cons: Threads allocated statically per each element. Half of threads idles upon each iteration.
-    TORCH_CHECK(x.type().is_cuda(), "Input must be a CUDA tensor");
-    TORCH_CHECK(x.is_contiguous(), "Input must be contiguous");
-    TORCH_CHECK(x.dim() == 2, "Input must be 2-dimensional");
-    TORCH_CHECK(0.0 <= gamma && gamma <= 1.0, "Gamma must be in the range [0,1]");
-
-    if (x.size(1) == 0) {
-        return x;
-    }
-
-    auto y = x.clone();
-
-    const int threads = 32;
-    const int nstages = log2ceil(x.size(1));
-    const dim3 blocks((x.size(1) + threads - 1) / threads, x.size(0));
-
-    for (int stage=0; stage<nstages; stage++) {
-        AT_DISPATCH_FLOATING_TYPES(x.type(), "discounted_cumsum_right_kernel_coalesced_stage", ([&] {
-            discounted_cumsum_right_kernel_coalesced_stage<scalar_t><<<blocks, threads>>>(
+        AT_DISPATCH_FLOATING_TYPES(x.type(), "discounted_cumsum_right_kernel_stage", ([&] {
+            discounted_cumsum_right_kernel_stage<scalar_t><<<blocks, threads>>>(
                 y.packed_accessor32<scalar_t, 2>(),
                 scalar_t(gamma),
                 stage
...
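The removed coalesced kernel above also documents the per-stage update that the surviving kernel performs with a leaner thread mapping: at stage s, every group of 2^(s+1) positions takes the running sum stored at its pivot (group start + 2^s) and folds it into each position of the group's left half with the matching power of gamma. Below is a CPU sketch of that staged scheme in Python, assuming discounted_sum_pow(a, b, gamma, p) evaluates to a + b * gamma**p; the function name is illustrative and the loop is not a substitute for the kernel.

import math
import torch

def discounted_cumsum_right_staged(x: torch.Tensor, gamma: float) -> torch.Tensor:
    # Emulates the log2ceil(len) stages of the kernel on the CPU.
    y = x.clone()
    length = y.size(1)
    nstages = math.ceil(math.log2(length)) if length > 1 else 0
    for stage in range(nstages):
        prev_stride = 1 << stage           # gr_prev_stride in the kernel
        cur_stride = prev_stride << 1      # gr_cur_stride
        for group_start in range(0, length, cur_stride):
            pivot = group_start + prev_stride      # discounted_pos
            if pivot >= length:
                continue
            for pos in range(group_start, pivot):  # change_pos values in the left half
                power = pivot - pos                # discount_power
                y[:, pos] += (gamma ** power) * y[:, pivot]
    return y

After nstages passes every position holds the full right-discounted sum, matching the reference recurrence shown earlier.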