"gallery/transforms/plot_datapoints.py" did not exist on "59b27ed64cf126357d60e8f2944d204f83075e2e"
Unverified commit 94818ad1, authored by pc, committed by GitHub

update ca_forward_kernel (#1144)

parent 7b150fab
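
Summary of the change (editor's note): every launcher below now folds the batch size n into the z-dimension of its launch grid, presumably so that each sample in a batch is processed by its own set of thread blocks rather than by a loop inside the kernel (the matching kernel-side changes are outside this excerpt). The forward and map-backward weight launchers also tighten the z-extent from h + w to h + w - 1, which matches the number of criss-cross attention positions per pixel: its row plus its column, with the pixel itself counted only once. A hedged kernel-side sketch follows the diff.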
@@ -24,8 +24,8 @@ void CAForwardCUDAKernelLauncher(const Tensor t, const Tensor f,
   dim3 threads(32, 32);
   int d1 = (w + threads.x - 1) / threads.x;
   int d2 = (h + threads.y - 1) / threads.y;
-  int d3 = h + w;
-  dim3 blocks(d1, d2, d3);
+  int d3 = h + w - 1;
+  dim3 blocks(d1, d2, d3 * n);
   AT_DISPATCH_FLOATING_TYPES(t.scalar_type(), "ca_forward", [&] {
     ca_forward_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
@@ -53,7 +53,7 @@ void CABackwardCUDAKernelLauncher(const Tensor dw, const Tensor t,
   dim3 threads(32, 32);
   int d1 = (w + threads.x - 1) / threads.x;
   int d2 = (h + threads.y - 1) / threads.y;
-  int d3 = c;
+  int d3 = c * n;
   dim3 blocks(d1, d2, d3);
   AT_DISPATCH_FLOATING_TYPES(t.scalar_type(), "ca_backward_kernel_t", [&] {
@@ -90,7 +90,7 @@ void CAMapForwardCUDAKernelLauncher(const Tensor weight, const Tensor g,
   dim3 threads(32, 32);
   int d1 = (w + threads.x - 1) / threads.x;
   int d2 = (h + threads.y - 1) / threads.y;
-  int d3 = c;
+  int d3 = c * n;
   dim3 blocks(d1, d2, d3);
   AT_DISPATCH_FLOATING_TYPES(g.scalar_type(), "ca_map_forward", [&] {
@@ -119,8 +119,8 @@ void CAMapBackwardCUDAKernelLauncher(const Tensor dout, const Tensor weight,
   dim3 threads(32, 32);
   int d1 = (w + threads.x - 1) / threads.x;
   int d2 = (h + threads.y - 1) / threads.y;
-  int d3 = h + w;
-  dim3 blocks(d1, d2, d3);
+  int d3 = h + w - 1;
+  dim3 blocks(d1, d2, d3 * n);
   AT_DISPATCH_FLOATING_TYPES(
       weight.scalar_type(), "ca_map_backward_kernel_w", [&] {
@@ -130,7 +130,8 @@ void CAMapBackwardCUDAKernelLauncher(const Tensor dout, const Tensor weight,
             g.contiguous().data_ptr<scalar_t>(),
             dw.contiguous().data_ptr<scalar_t>(), n, c, h, w);
       });
+  d3 = c * n;
+  blocks = dim3(d1, d2, d3);
   AT_DISPATCH_FLOATING_TYPES(g.scalar_type(), "ca_map_backward_kernel_g", [&] {
     ca_map_backward_kernel_g<scalar_t><<<blocks, threads, 0, stream>>>(
         dout.contiguous().data_ptr<scalar_t>(),
...
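
Editor's note: below is a minimal, hypothetical sketch of a kernel that consumes a grid sized dim3(d1, d2, (h + w - 1) * n), as the updated ca_forward launcher produces. This is not the mmcv kernel; the name ca_forward_sketch, the decode of blockIdx.z, and the position-to-coordinate mapping are illustrative assumptions, shown only to explain why gridDim.z is now (h + w - 1) * n.

    // Hypothetical sketch only -- not the mmcv kernel. It shows how a kernel
    // can recover the batch index and the criss-cross position from
    // blockIdx.z when the launcher packs both into gridDim.z.
    template <typename scalar_t>
    __global__ void ca_forward_sketch(const scalar_t *t, const scalar_t *f,
                                      scalar_t *weight, int n, int c, int h,
                                      int w) {
      int x = blockIdx.x * blockDim.x + threadIdx.x;  // pixel column
      int y = blockIdx.y * blockDim.y + threadIdx.y;  // pixel row
      int len = h + w - 1;           // criss-cross positions per pixel: its
                                     // row plus its column, counting the
                                     // pixel itself only once
      int batch = blockIdx.z / len;  // which sample, 0 <= batch < n
      int pos = blockIdx.z % len;    // which criss-cross position
      if (x >= w || y >= h) return;

      // Assumed mapping: the first w positions walk the pixel's row; the
      // remaining h - 1 walk its column, skipping row y (already counted
      // as a row position).
      int xp = pos < w ? pos : x;
      int yp = pos < w ? y : (pos - w < y ? pos - w : pos - w + 1);

      // Correlation between the query feature at (y, x) and the key
      // feature at (yp, xp), accumulated over channels (NCHW layout).
      scalar_t acc = 0;
      for (int ch = 0; ch < c; ++ch) {
        acc += t[((batch * c + ch) * h + y) * w + x] *
               f[((batch * c + ch) * h + yp) * w + xp];
      }
      weight[((batch * len + pos) * h + y) * w + x] = acc;
    }

The backward launchers' d3 = c * n follows the same decoding pattern, with a channel index in place of the position index.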