Commit 12d3b03d authored by Daniel Povey
Browse files

Fix..

parent 6fbcae3e
from .discounted_cumsum import discounted_cumsum_left, discounted_cumsum_right
from .integrated_conv import integrated_conv
......@@ -126,5 +126,5 @@ def integrated_conv(input, pos_add, pos_mul):
return integrated_conv(input.unsqueeze(-2),
pos_add.unsqueeze(-2), pos_mul.unsqueeze(-2)).squeeze(-2)
assert input.ndim == 4 and pos_add.ndim == 3 and pos_mul.ndim == 3
assert input.dim[1] // 2 == pos_add.dim[0] == pos_mul.dim[0]
assert input.shape[1] // 2 == pos_add.shape[0] == pos_mul.shape[0]
return IntegratedConvFunction.apply(input, pos_add, pos_mul)
#include <torch/extension.h>
#include <c10/cuda/CUDAStream.h> // for getCurrentCUDAStream()
#include <cooperative_groups.h>
/*
  Tree-reduction sum of `val` over the threads of cooperative group `g`.

  Args:
    g:     the group whose threads take part; this code uses its
           thread_rank(), size() and sync() members.
    temp:  scratch buffer that every thread of `g` can read and write
           (e.g. shared memory); indexes 0 .. g.size() - 1 are used.
    val:   this thread's contribution to the sum.

  Returns:
    The partial sum held by the calling thread.  Only the thread with
    thread_rank() == 0 holds the complete sum, and only when g.size() is a
    power of two (with any other size the halving loop skips some elements).

  Every thread of `g` must call this function, because g.sync() is executed
  unconditionally inside the loop.

  The return type is scalar_t, not int, so that floating-point sums are not
  truncated on return.
*/
template <typename scalar_t, typename group_t>
__device__ scalar_t reduce_sum(group_t g, scalar_t *temp, scalar_t val)
{
  int lane = g.thread_rank();
  // Each iteration halves the number of active threads; each active thread
  // adds the partial sum published by thread (lane + i) to its own.
#pragma unroll
  for (int i = g.size() / 2; i > 0; i /= 2)
  {
    temp[lane] = val;
    g.sync();  // wait for all threads to store
    if (lane < i) val += temp[lane + i];
    g.sync();  // wait for all threads to load
  }
  return val;  // note: only thread 0 will return full sum
}
/*
Forward of integrated_conv. Each thread group handles a single channel
(equal to blockIdx.x), and loops over patches of the output.
......@@ -66,12 +46,12 @@ void integrated_conv_kernel(
torch::PackedTensorAccessor32<scalar_t, 4> input, // N, 2*C, H, W
torch::PackedTensorAccessor32<scalar_t, 3> pos_add, // C, kH, kW
torch::PackedTensorAccessor32<scalar_t, 3> pos_mul, // C, kH, kW
torch::PackedTensorAcessor32<scalar_t, 4> output, // N, C, H, W
torch::PackedTensorAccessor32<scalar_t, 4> output, // N, C, H, W
int opatchH, // output-patch height,
int opatchW // output-patch width
) {
const int H = input.size(2),
W = input.size(3)
W = input.size(3),
kH = pos_add.size(1),
kW = pos_add.size(2),
npatchH = (H + opatchH - 1) / opatchH, // num patches in vertical dim
......@@ -84,16 +64,17 @@ void integrated_conv_kernel(
// exact number of channels.
const int ipatchH = opatchH + kH - 1,
ipatchW = ipatchW + kW - 1,
ipatchW = opatchW + kW - 1,
ipatch_size = ipatchH * ipatchW,
opatch_size = opatchH * opatchW;
// `extern_buf` is general-purpose shared memory, which we'll divide between
// pos_add, pos_mul and src_img_buf to be shared between the src image size
// (ipatch_size) and the number of threads (blockDim.x)
__shared__ scalar_t buf[buffer_dim];
__shared__ scalar_t
// these are pointers to __shared__ memory; the compiler should
// be able to figure this out.
scalar_t
*pos_add_buf = (scalar_t*)extern_buf, // pos_add positional-encoding / kernel parameters,
// indexed [kh*kW + kw] where kh and kw are vertical
// and horizontal positions in the kernel.
......@@ -106,10 +87,12 @@ void integrated_conv_kernel(
// image.
threads_per_opixel = blockDim.x / opatch_size;
int threads_per_opixel = blockDim.x / opatch_size;
assert(blockDim.x == opatch_size * threads_per_opixel);
auto tile = cooperative_groups::tiled_partition(g, threads_per_opixel);
auto tile = cooperative_groups::tiled_partition(
cooperative_groups::this_thread_block(),
threads_per_opixel);
// pos_in_patch will be interpreted as h_in_patch * opatchW + w_in_patch.
int pos_in_patch = threadIdx.x / threads_per_opixel;
......@@ -173,6 +156,7 @@ void integrated_conv_kernel(
// subsystem handles this well enough that we can just ignore the issue
// rather than try to optimize it.
// https://forums.developer.nvidia.com/t/accessing-same-global-memory-address-within-warps/66574
int C = input.size(1) / 2;
dest_val = input[n][c + C][h][w]; // else 0.
}
......@@ -190,8 +174,8 @@ void integrated_conv_kernel(
// have a negative sign on the h and w indexes in the kernel.
// Also note: we already took care of padding and the associated
// offsets of -(kH / 2) and -(kW / 2).
int h_in_src_patch = h_in_patch + h_in_kernel,
w_in_src_patch = w_in_patch + w_in_kernel;
int h_in_src_patch = (pos_in_patch / opatchW) + h_in_kernel,
w_in_src_patch = (pos_in_patch % opatchW) + w_in_kernel;
scalar_t src_val = src_img_buf[h_in_src_patch * ipatchW + w_in_src_patch],
pos_add_val = pos_add_buf[pos_in_kernel];
scalar_t relu = (src_val + dest_val + pos_add_val);
......@@ -281,7 +265,7 @@ torch::Tensor integrated_conv_cuda(torch::Tensor input,
int threads_per_block = patchH * patchW * threads_per_opixel;
int buffer_numel = 2 * (kH * kW) + max<int>(threads_per_block,
int buffer_numel = 2 * (kH * kW) + std::max<int>(threads_per_block,
input_patch_size);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment