Unverified Commit 7b150fab authored by q.yao's avatar q.yao Committed by GitHub
Browse files

[Feature] Optimize the PyTorch CUDA implementation for Criss Cross Attention (#1143)



* optimize criss cross attention

* optimize criss cross attention

* optimize criss cross attention

* fix lint

* fix ci, remove useless variable

* better ca_forward_kernel
Co-authored-by: default avatarwondervictor <victorchanchina@gmail.com>
parent 6fe37225
...@@ -14,25 +14,17 @@ __global__ void ca_forward_kernel(const T *t, const T *f, T *weight, int num, ...@@ -14,25 +14,17 @@ __global__ void ca_forward_kernel(const T *t, const T *f, T *weight, int num,
int y = blockIdx.y * blockDim.y + threadIdx.y; int y = blockIdx.y * blockDim.y + threadIdx.y;
int sp = height * width; int sp = height * width;
int len = height + width - 1; int len = height + width - 1;
int z = blockIdx.z; int z = blockIdx.z % len;
int batch = blockIdx.z / len;
if (x < width && y < height && z < height + width - 1) {
for (int batch = 0; batch < num; ++batch) { if (x < width && y < height) {
for (int plane = 0; plane < chn; ++plane) { T *weight_ptr = weight + (batch * len + z) * sp + y * width + x;
T _t = t[(batch * chn + plane) * sp + y * width + x]; const int t_offset = y * width + x;
const int j = (z - width < y) ? z - width : z - width + 1;
if (z < width) { const int f_offset = z < width ? y * width + z : j * width + x;
int i = z; for (int plane = 0; plane < chn; ++plane) {
T _f = f[(batch * chn + plane) * sp + y * width + i]; const int tf_base = (batch * chn + plane) * sp;
weight[(batch * len + i) * sp + y * width + x] += _t * _f; *weight_ptr += t[tf_base + t_offset] * f[tf_base + f_offset];
} else {
int i = z - width;
int j = i < y ? i : i + 1;
T _f = f[(batch * chn + plane) * sp + j * width + x];
weight[(batch * len + width + i) * sp + y * width + x] += _t * _f;
}
}
} }
} }
} }
...@@ -44,23 +36,22 @@ __global__ void ca_backward_kernel_t(const T *dw, const T *t, const T *f, T *dt, ...@@ -44,23 +36,22 @@ __global__ void ca_backward_kernel_t(const T *dw, const T *t, const T *f, T *dt,
int y = blockIdx.y * blockDim.y + threadIdx.y; int y = blockIdx.y * blockDim.y + threadIdx.y;
int sp = height * width; int sp = height * width;
int len = height + width - 1; int len = height + width - 1;
int plane = blockIdx.z; int plane = blockIdx.z % chn;
int batch = blockIdx.z / chn;
if (x < width && y < height && plane < chn) {
for (int batch = 0; batch < num; ++batch) { if (x < width && y < height) {
for (int i = 0; i < width; ++i) { for (int i = 0; i < width; ++i) {
T _dw = dw[(batch * len + i) * sp + y * width + x]; T _dw = dw[(batch * len + i) * sp + y * width + x];
T _f = f[(batch * chn + plane) * sp + y * width + i]; T _f = f[(batch * chn + plane) * sp + y * width + i];
dt[(batch * chn + plane) * sp + y * width + x] += _dw * _f; dt[(batch * chn + plane) * sp + y * width + x] += _dw * _f;
} }
for (int i = 0; i < height; ++i) { for (int i = 0; i < height; ++i) {
if (i == y) continue; if (i == y) continue;
int j = i < y ? i : i - 1; int j = i < y ? i : i - 1;
T _dw = dw[(batch * len + width + j) * sp + y * width + x]; T _dw = dw[(batch * len + width + j) * sp + y * width + x];
T _f = f[(batch * chn + plane) * sp + i * width + x]; T _f = f[(batch * chn + plane) * sp + i * width + x];
dt[(batch * chn + plane) * sp + y * width + x] += _dw * _f; dt[(batch * chn + plane) * sp + y * width + x] += _dw * _f;
}
} }
} }
} }
...@@ -72,23 +63,22 @@ __global__ void ca_backward_kernel_f(const T *dw, const T *t, const T *f, T *df, ...@@ -72,23 +63,22 @@ __global__ void ca_backward_kernel_f(const T *dw, const T *t, const T *f, T *df,
int y = blockIdx.y * blockDim.y + threadIdx.y; int y = blockIdx.y * blockDim.y + threadIdx.y;
int sp = height * width; int sp = height * width;
int len = height + width - 1; int len = height + width - 1;
int plane = blockIdx.z; int plane = blockIdx.z % chn;
int batch = blockIdx.z / chn;
if (x < width && y < height && plane < chn) {
for (int batch = 0; batch < num; ++batch) { if (x < width && y < height) {
for (int i = 0; i < width; ++i) { for (int i = 0; i < width; ++i) {
T _dw = dw[(batch * len + x) * sp + y * width + i]; T _dw = dw[(batch * len + x) * sp + y * width + i];
T _t = t[(batch * chn + plane) * sp + y * width + i]; T _t = t[(batch * chn + plane) * sp + y * width + i];
df[(batch * chn + plane) * sp + y * width + x] += _dw * _t; df[(batch * chn + plane) * sp + y * width + x] += _dw * _t;
} }
for (int i = 0; i < height; ++i) { for (int i = 0; i < height; ++i) {
if (i == y) continue; if (i == y) continue;
int j = i > y ? y : y - 1; int j = i > y ? y : y - 1;
T _dw = dw[(batch * len + width + j) * sp + i * width + x]; T _dw = dw[(batch * len + width + j) * sp + i * width + x];
T _t = t[(batch * chn + plane) * sp + i * width + x]; T _t = t[(batch * chn + plane) * sp + i * width + x];
df[(batch * chn + plane) * sp + y * width + x] += _dw * _t; df[(batch * chn + plane) * sp + y * width + x] += _dw * _t;
}
} }
} }
} }
...@@ -100,24 +90,22 @@ __global__ void ca_map_forward_kernel(const T *weight, const T *g, T *out, ...@@ -100,24 +90,22 @@ __global__ void ca_map_forward_kernel(const T *weight, const T *g, T *out,
int y = blockIdx.y * blockDim.y + threadIdx.y; int y = blockIdx.y * blockDim.y + threadIdx.y;
int sp = height * width; int sp = height * width;
int len = height + width - 1; int len = height + width - 1;
int plane = blockIdx.z; int plane = blockIdx.z % chn;
int batch = blockIdx.z / chn;
if (x < width && y < height && plane < chn) { if (x < width && y < height) {
for (int batch = 0; batch < num; ++batch) { for (int i = 0; i < width; ++i) {
for (int i = 0; i < width; ++i) { T _g = g[(batch * chn + plane) * sp + y * width + i];
T _g = g[(batch * chn + plane) * sp + y * width + i]; T _w = weight[(batch * len + i) * sp + y * width + x];
T _w = weight[(batch * len + i) * sp + y * width + x]; out[(batch * chn + plane) * sp + y * width + x] += _g * _w;
out[(batch * chn + plane) * sp + y * width + x] += _g * _w; }
} for (int i = 0; i < height; ++i) {
for (int i = 0; i < height; ++i) { if (i == y) continue;
if (i == y) continue;
int j = i < y ? i : i - 1;
int j = i < y ? i : i - 1;
T _g = g[(batch * chn + plane) * sp + i * width + x];
T _g = g[(batch * chn + plane) * sp + i * width + x]; T _w = weight[(batch * len + width + j) * sp + y * width + x];
T _w = weight[(batch * len + width + j) * sp + y * width + x]; out[(batch * chn + plane) * sp + y * width + x] += _g * _w;
out[(batch * chn + plane) * sp + y * width + x] += _g * _w;
}
} }
} }
} }
...@@ -130,25 +118,23 @@ __global__ void ca_map_backward_kernel_w(const T *dout, const T *weight, ...@@ -130,25 +118,23 @@ __global__ void ca_map_backward_kernel_w(const T *dout, const T *weight,
int y = blockIdx.y * blockDim.y + threadIdx.y; int y = blockIdx.y * blockDim.y + threadIdx.y;
int sp = height * width; int sp = height * width;
int len = height + width - 1; int len = height + width - 1;
int z = blockIdx.z;
int z = blockIdx.z % len;
if (x < width && y < height && z < height + width - 1) { int batch = blockIdx.z / len;
for (int batch = 0; batch < num; ++batch) {
for (int plane = 0; plane < chn; ++plane) { if (x < width && y < height) {
T _dout = dout[(batch * chn + plane) * sp + y * width + x]; int widx = (batch * len + z) * sp + y * width + x;
int dout_idx = batch * chn * sp + y * width + x;
if (z < width) { int gidx = batch * chn * sp;
int i = z; if (z < width) {
T _g = g[(batch * chn + plane) * sp + y * width + i]; gidx += y * width + z;
dw[(batch * len + i) * sp + y * width + x] += _dout * _g; } else {
} else { int j = z - width;
int i = z - width; j = j < y ? j : j + 1;
int j = i < y ? i : i + 1; gidx += j * width + x;
}
T _g = g[(batch * chn + plane) * sp + j * width + x]; for (int plane = 0; plane < chn; plane++) {
dw[(batch * len + width + i) * sp + y * width + x] += _dout * _g; dw[widx] += dout[dout_idx + plane * sp] * g[gidx + plane * sp];
}
}
} }
} }
} }
...@@ -161,25 +147,21 @@ __global__ void ca_map_backward_kernel_g(const T *dout, const T *weight, ...@@ -161,25 +147,21 @@ __global__ void ca_map_backward_kernel_g(const T *dout, const T *weight,
int y = blockIdx.y * blockDim.y + threadIdx.y; int y = blockIdx.y * blockDim.y + threadIdx.y;
int sp = height * width; int sp = height * width;
int len = height + width - 1; int len = height + width - 1;
int plane = blockIdx.z; int plane = blockIdx.z % chn;
int batch = blockIdx.z / chn;
if (x < width && y < height && plane < chn) { int index = (batch * chn + plane) * sp + y * width + x;
for (int batch = 0; batch < num; ++batch) {
for (int i = 0; i < width; ++i) { if (x < width && y < height) {
T _dout = dout[(batch * chn + plane) * sp + y * width + i]; for (int i = 0; i < width; ++i) {
T _w = weight[(batch * len + x) * sp + y * width + i]; dg[index] += dout[(batch * chn + plane) * sp + y * width + i] *
dg[(batch * chn + plane) * sp + y * width + x] += _dout * _w; weight[(batch * len + x) * sp + y * width + i];
} }
for (int i = 0; i < height; ++i) { for (int i = 0; i < height; ++i) {
if (i == y) continue; if (i == y) continue;
int j = i > y ? y : y - 1; int j = i > y ? y : y - 1;
dg[index] += dout[(batch * chn + plane) * sp + i * width + x] *
T _dout = dout[(batch * chn + plane) * sp + i * width + x]; weight[(batch * len + width + j) * sp + i * width + x];
T _w = weight[(batch * len + width + j) * sp + i * width + x];
dg[(batch * chn + plane) * sp + y * width + x] += _dout * _w;
}
} }
} }
} }
#endif // CC_ATTENTION_CUDA_KERNEL_CUH #endif // CC_ATTENTION_CUDA_KERNEL_CUH
...@@ -24,8 +24,8 @@ void CAForwardCUDAKernelLauncher(const Tensor t, const Tensor f, ...@@ -24,8 +24,8 @@ void CAForwardCUDAKernelLauncher(const Tensor t, const Tensor f,
dim3 threads(32, 32); dim3 threads(32, 32);
int d1 = (w + threads.x - 1) / threads.x; int d1 = (w + threads.x - 1) / threads.x;
int d2 = (h + threads.y - 1) / threads.y; int d2 = (h + threads.y - 1) / threads.y;
int d3 = h + w; int d3 = h + w - 1;
dim3 blocks(d1, d2, d3); dim3 blocks(d1, d2, d3 * n);
AT_DISPATCH_FLOATING_TYPES(t.scalar_type(), "ca_forward", [&] { AT_DISPATCH_FLOATING_TYPES(t.scalar_type(), "ca_forward", [&] {
ca_forward_kernel<scalar_t><<<blocks, threads, 0, stream>>>( ca_forward_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
...@@ -53,7 +53,7 @@ void CABackwardCUDAKernelLauncher(const Tensor dw, const Tensor t, ...@@ -53,7 +53,7 @@ void CABackwardCUDAKernelLauncher(const Tensor dw, const Tensor t,
dim3 threads(32, 32); dim3 threads(32, 32);
int d1 = (w + threads.x - 1) / threads.x; int d1 = (w + threads.x - 1) / threads.x;
int d2 = (h + threads.y - 1) / threads.y; int d2 = (h + threads.y - 1) / threads.y;
int d3 = c; int d3 = c * n;
dim3 blocks(d1, d2, d3); dim3 blocks(d1, d2, d3);
AT_DISPATCH_FLOATING_TYPES(t.scalar_type(), "ca_backward_kernel_t", [&] { AT_DISPATCH_FLOATING_TYPES(t.scalar_type(), "ca_backward_kernel_t", [&] {
...@@ -90,7 +90,7 @@ void CAMapForwardCUDAKernelLauncher(const Tensor weight, const Tensor g, ...@@ -90,7 +90,7 @@ void CAMapForwardCUDAKernelLauncher(const Tensor weight, const Tensor g,
dim3 threads(32, 32); dim3 threads(32, 32);
int d1 = (w + threads.x - 1) / threads.x; int d1 = (w + threads.x - 1) / threads.x;
int d2 = (h + threads.y - 1) / threads.y; int d2 = (h + threads.y - 1) / threads.y;
int d3 = c; int d3 = c * n;
dim3 blocks(d1, d2, d3); dim3 blocks(d1, d2, d3);
AT_DISPATCH_FLOATING_TYPES(g.scalar_type(), "ca_map_forward", [&] { AT_DISPATCH_FLOATING_TYPES(g.scalar_type(), "ca_map_forward", [&] {
...@@ -119,8 +119,8 @@ void CAMapBackwardCUDAKernelLauncher(const Tensor dout, const Tensor weight, ...@@ -119,8 +119,8 @@ void CAMapBackwardCUDAKernelLauncher(const Tensor dout, const Tensor weight,
dim3 threads(32, 32); dim3 threads(32, 32);
int d1 = (w + threads.x - 1) / threads.x; int d1 = (w + threads.x - 1) / threads.x;
int d2 = (h + threads.y - 1) / threads.y; int d2 = (h + threads.y - 1) / threads.y;
int d3 = h + w; int d3 = h + w - 1;
dim3 blocks(d1, d2, d3); dim3 blocks(d1, d2, d3 * n);
AT_DISPATCH_FLOATING_TYPES( AT_DISPATCH_FLOATING_TYPES(
weight.scalar_type(), "ca_map_backward_kernel_w", [&] { weight.scalar_type(), "ca_map_backward_kernel_w", [&] {
...@@ -130,7 +130,8 @@ void CAMapBackwardCUDAKernelLauncher(const Tensor dout, const Tensor weight, ...@@ -130,7 +130,8 @@ void CAMapBackwardCUDAKernelLauncher(const Tensor dout, const Tensor weight,
g.contiguous().data_ptr<scalar_t>(), g.contiguous().data_ptr<scalar_t>(),
dw.contiguous().data_ptr<scalar_t>(), n, c, h, w); dw.contiguous().data_ptr<scalar_t>(), n, c, h, w);
}); });
d3 = c * n;
blocks = dim3(d1, d2, d3);
AT_DISPATCH_FLOATING_TYPES(g.scalar_type(), "ca_map_backward_kernel_g", [&] { AT_DISPATCH_FLOATING_TYPES(g.scalar_type(), "ca_map_backward_kernel_g", [&] {
ca_map_backward_kernel_g<scalar_t><<<blocks, threads, 0, stream>>>( ca_map_backward_kernel_g<scalar_t><<<blocks, threads, 0, stream>>>(
dout.contiguous().data_ptr<scalar_t>(), dout.contiguous().data_ptr<scalar_t>(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment