OpenDAS / FAST-RNNT / Commits / 8f096574

Commit 8f096574, authored Jul 10, 2021 by Daniel Povey

    Optimization of CPU code; start drafting forward code for CUDA

parent d6081b04

Showing 3 changed files with 131 additions and 288 deletions:

    torch_learned_nonlin/learned_nonlin_cpu.cpp           +31   -30
    torch_learned_nonlin/learned_nonlin_cuda.cpp           +7    -8
    torch_learned_nonlin/learned_nonlin_cuda_kernel.cu     +93   -250
torch_learned_nonlin/learned_nonlin_cpu.cpp  (view file @ 8f096574)
@@ -34,10 +34,12 @@ torch::Tensor learned_nonlin_cpu(torch::Tensor input,
         y_vals_a = y_vals.accessor<scalar_t, 2>();
     for (int c = 0; c < C; c++) {
-      scalar_t sum_negative = 0.0, sum_positive = 0.0;
+      scalar_t sum_negative = 0.0,
+          sum_positive = 0.0, scale = exp(params_a[c][0]);
       for (int i = 0; i < K; i++) {
-        y_vals_a[c][K + i] = sum_positive;
-        y_vals_a[c][K - i] = sum_negative;
+        y_vals_a[c][K + i] = sum_positive * scale;
+        y_vals_a[c][K - i] = sum_negative * scale;
         sum_positive += params_a[c][1 + K + i];
         sum_negative -= params_a[c][K - i];
       }
@@ -51,9 +53,7 @@ torch::Tensor learned_nonlin_cpu(torch::Tensor input,
     for (int b = 0; b < B; b++) {
       for (int c = 0; c < C; c++) {
-        scalar_t l = params_a[c][0],
-            scale = exp(l),
-            inv_scale = 1.0 / scale;
+        scalar_t inv_scale = exp(-params_a[c][0]);
         for (int t = 0; t < T; t++) {
           // `x` is the scaled input x plus an offset so that -K maps to 0.
           // Note: the discontinuities in our function are at -(K-1) ... +(K+1),
@@ -72,7 +72,7 @@ torch::Tensor learned_nonlin_cpu(torch::Tensor input,
           scalar_t y = (x - (scalar_t)min) * params_a[c][min + 1] + y_vals_a[c][min];
           // printf("x = %f, y = %f, min = %d; y = (%f - %d) * %f+ %f\n", x, y, min,
           // x, min, params_a[c][min + 1], y_vals_a[c][min - 1]);
-          output_a[b][c][t] = y * scale;
+          output_a[b][c][t] = y;
         }
       }
     }}));
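The forward loop above evaluates a per-channel piecewise linear function: x is the input scaled by exp(-params[c][0]) and shifted so that the left-most knot maps to 0, `min` selects one of the 2*K linear pieces, and the output is a linear function of x on that piece. A small stand-alone scalar sketch of that pattern follows; it is illustrative only (the helper name eval_one is hypothetical, the clamp stands in for the binary search in the real code, and the exact scaling convention is whatever learned_nonlin.py defines):

// Illustrative sketch only; eval_one is a hypothetical helper, not part of
// the extension.  Mirrors the inner-loop evaluation pattern of the CPU code.
#include <cmath>
#include <vector>

static double eval_one(const std::vector<double>& params,  // size 2*K + 1; params[0] is l
                       const std::vector<double>& y_vals,  // size 2*K, precomputed offsets
                       int K, double input) {
  double inv_scale = std::exp(-params[0]);
  double x = input * inv_scale + K;   // shift so that -K maps to 0
  int min = (int)x;                   // clamp stands in for the binary search
  if (min < 0) min = 0;
  if (min > 2 * K - 1) min = 2 * K - 1;
  return (x - (double)min) * params[min + 1] + y_vals[min];
}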
@@ -120,12 +120,13 @@ std::vector<torch::Tensor> learned_nonlin_backward_cpu(torch::Tensor input,
         y_vals_grad_a = y_vals_grad.accessor<scalar_t, 2>();
     for (int c = 0; c < C; c++) {
-      scalar_t sum_negative = 0.0, sum_positive = 0.0;
+      scalar_t sum_negative = 0.0,
+          sum_positive = 0.0, scale = exp(params_a[c][0]);
       for (int i = 0; i < K; i++) {
         y_vals_a[c][K + i] = sum_positive;
         y_vals_a[c][K - i] = sum_negative;
-        sum_positive += params_a[c][1 + K + i];
-        sum_negative -= params_a[c][K - i];
+        sum_positive += params_a[c][1 + K + i] * scale;
+        sum_negative -= params_a[c][K - i] * scale;
       }
       // the reference point for the lowest, half-infinite interval (the one
       // starting at x=-(K-1) is still x=-(K-1); this value is repeated in y_vals.
@@ -138,10 +139,7 @@ std::vector<torch::Tensor> learned_nonlin_backward_cpu(torch::Tensor input,
     for (int b = 0; b < B; b++) {
       for (int c = 0; c < C; c++) {
-        scalar_t l = params_a[c][0],
-            scale = exp(l),
-            inv_scale = 1.0 / scale,
-            scale_grad = 0.0,
+        scalar_t inv_scale = exp(-params_a[c][0]),
             inv_scale_grad = 0.0;
         for (int t = 0; t < T; t++) {
           // `x` is the scaled input x plus an offset so that -K maps to 0.
@@ -151,7 +149,7 @@ std::vector<torch::Tensor> learned_nonlin_backward_cpu(torch::Tensor input,
           // that are < -(K-1) and > (K-1)
           scalar_t input = input_a[b][c][t],
               x = input * inv_scale + K,
-              output_grad = output_grad_a[b][c][t];
+              y_grad = output_grad_a[b][c][t];
           int min = 0, diff = K;
           while (diff > 0) {
             int mid = min + diff;
@@ -160,10 +158,6 @@ std::vector<torch::Tensor> learned_nonlin_backward_cpu(torch::Tensor input,
             diff = diff >> 1;
           }
           // OK, at this point, 0 <= min < 2*K.
-          scalar_t y = (x - (scalar_t)min) * params_a[c][min + 1] + y_vals_a[c][min];
-          // backprop for: output_a[b][c][t] = y * scale;
-          scale_grad += y * output_grad;
-          scalar_t y_grad = scale * output_grad;
           // backprop for:
           // scalar_t y = (x - (scalar_t)min) * params_a[c][min + 1] + y_vals_a[c][min];
           scalar_t x_grad = y_grad * params_a[c][min + 1];
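The interval search (`int min = 0, diff = K; while (diff > 0) { ... }`) is only partly visible in these hunks. A hypothetical completion is sketched below; the comparison on `mid` is an assumption (it does not appear in the diff), chosen so that the loop ends with 0 <= min < 2*K as the code's comment asserts:

// Hypothetical completion of the search skeleton shown above; the condition
// marked "assumed" is not visible in the diff.  With K a power of 2, this
// finds the largest integer min in [0, 2*K) with min <= x, clamping x that
// falls outside that range.
static int find_interval(double x, int K) {
  int min = 0, diff = K;
  while (diff > 0) {
    int mid = min + diff;
    if (mid < 2 * K && (double)mid <= x)  // assumed condition
      min = mid;
    diff = diff >> 1;
  }
  return min;  // 0 <= min < 2*K
}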
@@ -173,29 +167,36 @@ std::vector<torch::Tensor> learned_nonlin_backward_cpu(torch::Tensor input,
           inv_scale_grad += x_grad * input;
           input_grad_a[b][c][t] = x_grad * inv_scale;
         }
-        // Do the backprop to l as if we had done:
-        //   scale = exp(l); inv_scale = exp(-l);
-        scalar_t l_grad = scale * scale_grad - inv_scale * inv_scale_grad;
-        params_grad_a[c][0] += l_grad;
+        // Do the backprop for: inv_scale = exp(-params_a[c][0])
+        params_grad_a[c][0] -= inv_scale * inv_scale_grad;
       }
     }
     // Now do the backprop for the loop above where we set y_vals_a.
     for (int c = 0; c < C; c++) {
       // backprop for: y_vals_a[c][0] = y_vals_a[c][1];
       y_vals_grad_a[c][1] += y_vals_grad_a[c][0];
-      scalar_t sum_negative_grad = 0.0,
+      scalar_t scale = exp(params_a[c][0]),
+          inv_scale = 1.0 / scale,
+          scale_grad = 0.0,
+          sum_negative_grad = 0.0,
           sum_positive_grad = 0.0;
       for (int i = K - 1; i >= 0; i--) {
         // backprop for: sum_negative -= params_a[c][K - i];
         params_grad_a[c][K - i] -= sum_negative_grad;
-        // backprop for: sum_positive += params_a[c][1 + K + i];
+        // backprop for: sum_positive += params_a[c][1 + K + i] * scale;
         params_grad_a[c][1 + K + i] += sum_positive_grad;
-        // backprop for: y_vals_a[c][K - i] = sum_negative;
-        sum_negative_grad += y_vals_grad_a[c][K - i];
-        // backprop for: y_vals_a[c][K + i] = sum_positive;
-        sum_positive_grad += y_vals_grad_a[c][K + i];
+        // backprop for: y_vals_a[c][K - i] = sum_negative * scale;
+        sum_negative_grad += y_vals_grad_a[c][K - i] * scale;
+        // The next code line is equivalent to:
+        //   scale_grad += y_vals_grad_a[c][K - i] * sum_negative, substituting:
+        //   sum_negative == y_vals_a[c][K - i] / scale
+        scale_grad += y_vals_grad_a[c][K - i] * y_vals_a[c][K - i] * inv_scale;
+        // backprop for: y_vals_a[c][K + i] = sum_positive * scale;
+        sum_positive_grad += y_vals_grad_a[c][K + i] * scale;
+        scale_grad += y_vals_grad_a[c][K + i] * y_vals_a[c][K + i] * inv_scale;
       }
+      // Backprop for: scale = exp(params_a[c][0]),
+      params_grad_a[c][0] += scale * scale_grad;
     }
   }));
   return std::vector<torch::Tensor>({input_grad, params_grad});
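The two updates applied to params_grad_a[c][0] above are the chain rule through scale = exp(l) and inv_scale = exp(-l): for any loss L, dL/dl = scale * dL/dscale - inv_scale * dL/dinv_scale. A minimal finite-difference check of that identity (stand-alone, with an arbitrary smooth loss; hypothetical names, not part of the extension):

// Stand-alone check of dL/dl = scale * scale_grad - inv_scale * inv_scale_grad
// for scale = exp(l), inv_scale = exp(-l).  Loss is arbitrary; names are
// hypothetical.
#include <cmath>
#include <cstdio>

int main() {
  auto loss = [](double scale, double inv_scale) {
    return 3.0 * scale + 0.5 * inv_scale;   // any smooth function of the two
  };
  double l = 0.7, eps = 1e-6;
  double scale = std::exp(l), inv_scale = std::exp(-l);
  double scale_grad = 3.0, inv_scale_grad = 0.5;   // dL/dscale, dL/dinv_scale
  double analytic = scale * scale_grad - inv_scale * inv_scale_grad;
  double numeric = (loss(std::exp(l + eps), std::exp(-(l + eps))) -
                    loss(std::exp(l - eps), std::exp(-(l - eps)))) / (2 * eps);
  std::printf("analytic %.8f vs numeric %.8f\n", analytic, numeric);
  return 0;
}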
torch_learned_nonlin/learned_nonlin_cuda.cpp  (view file @ 8f096574)
@@ -4,18 +4,17 @@
 // forward of learned_nonlin.  """... """ comment of `learned_nonlin`
 // in learned_nonlin.py documents the behavior of this function.
 torch::Tensor learned_nonlin_cuda(torch::Tensor input,
-                                  torch::Tensor pos_add,
-                                  torch::Tensor pos_mul);
+                                  torch::Tensor params);

-// backward of learned_nonlin; returns (grad_input, grad_pos_add, grad_pos_mul).
+// backward of learned_nonlin; returns (grad_input, grad_params).
 std::vector<torch::Tensor> learned_nonlin_backward_cuda(torch::Tensor input,
-                                                        torch::Tensor pos_add,
-                                                        torch::Tensor pos_mul,
+                                                        torch::Tensor params,
                                                         torch::Tensor grad_output);

 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("learned_nonlin_cuda", &learned_nonlin_cuda, "Integrated convolution forward function (CUDA)");
-  m.def("learned_nonlin_backward_cuda", &learned_nonlin_backward_cuda, "Integrated convolution backward function (CUDA)");
+  m.def("learned_nonlin_cuda", &learned_nonlin_cuda, "Learned nonlinearity forward function (CUDA)");
+  m.def("learned_nonlin_backward_cuda", &learned_nonlin_backward_cuda, "Learned nonlinearity backward function (CUDA)");
 }
torch_learned_nonlin/learned_nonlin_cuda_kernel.cu  (view file @ 8f096574)
@@ -3,6 +3,8 @@
 #include <cooperative_groups.h>

+#define THREADS_PER_BLOCK 256
+
 /*
@@ -39,33 +41,36 @@ __forceinline__ __device__ scalar_t tiled_warp_reduce_sum(int threads_per_tile,
   // return the full sums of their tiles.
 }

 /*
   Forward of learned_nonlin.  Each thread group handles a single channel (equal
-  to blockIdx.x), and loops over patches of the output and over the image n
-  within the batch (different thread groups may be responsible for different
-  subsets of patches and/or images, see docs of gridDim below).
+  to blockIdx.x); the gridDim is (C, nb) where 1 <= nb <= B (nb relates to the batch).

   Template args:
       scalar_t: the floating-point type, e.g. float, double, maybe half.

   Args:
-      input:  input image, shape (N, 2*C, H, W)
-      pos_add:  positional encoding, additive part, shape (C, kH, kW)
-      pos_mul:  positional encoding, multiplicative part, shape (C, kH, kW)
-      output:  output image, shape (N, 2*C, H, W)
-      Note: kH and kW must both be odd so that it's clear how to zero-pad.
-
-      The thread-block should have one dimension (x); blockDim.x should equal
-      some small power of 2 (threads_per_opixel) times the output-patch size which is
-      opatchH * opatchW (the output-patch height and width).  We expect
-      threads_per_opixel to be 1, 2, or 4; we use a linear summation to sum up the
-      different threads' partial sums, and if threads_per_opixel gets larger we'd
-      need to make this a logarithmic reduction.
+      input:  input image, shape (B, C, T) where B is batch size, C is
+          the number of channels and T is the time axis.  (For more-than-1d
+          convolution setups, T would really be more than 1 axis, reshaped).
+      params:  of shape (C, N+1) where N is the number of linear regions in the
+          piecewise linear function; params[c][0] is l which is
+          a log scale parameter that dictates how far apart
+          the discontinuities in the piecewise linear function are,
+          and params[c][n+1] for 0 <= n < N are the derivatives
+          of the linear parts of the piecewise linear function.
+          The discontinuities of the function are at:
+               exp(l) * [ -(N/2 - 1), -(N/2 - 2), ... (N/2 - 1) ]
+      This kernel is allocated with `extern_buf` containing enough memory
+      to store 2*N values of type scalar_t.
+
+      The blockDim must equal (THREADS_PER_BLOCK, 1, 1)

   The requirements on the grid dimension are:
        gridDim.x == num-channels C (required)
-       gridDim.y <= num-patches per image (recommended)
-       gridDim.z <= batch-size N (recommended)
+       1 <= gridDim.y <= B, where B is the number of blocks
+       gridDim.z == 1

   When we invoke this kernel, we'll invoke it as:
     learned_nonlin_forward<<<gridDim, blockDim, bytesShared, stream>>>
   where bytesShared is the number of bytes needed in `extern_buf`:
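To make the parameterization in the Args comment concrete: with l = params[c][0], N = params.size(1) - 1 and K = N/2, the knots of the piecewise linear function sit at exp(l) * { -(K-1), ..., -1, 0, 1, ..., K-1 }. A host-side sketch of that, using a hypothetical helper name (illustrative only, not part of the kernel code):

// Hypothetical helper: list the N-1 knot positions for one row of `params`,
// following the Args comment above.
#include <cmath>
#include <vector>

static std::vector<double> knot_positions(const std::vector<double>& params_row) {
  int N = (int)params_row.size() - 1, K = N / 2;
  double scale = std::exp(params_row[0]);
  std::vector<double> knots;
  for (int j = -(K - 1); j <= K - 1; j++)
    knots.push_back(scale * j);
  return knots;   // N - 1 knots delimiting N linear pieces
}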
@@ -77,150 +82,46 @@ extern __shared__ int extern_buf[];
 template <typename scalar_t>
 __global__
 void learned_nonlin_kernel(
-    torch::PackedTensorAccessor32<scalar_t, 4> input,  // N, 2*C, H, W
-    torch::PackedTensorAccessor32<scalar_t, 3> pos_add,  // C, kH, kW
-    torch::PackedTensorAccessor32<scalar_t, 3> pos_mul,  // C, kH, kW
-    torch::PackedTensorAccessor32<scalar_t, 4> output,  // N, C, H, W
-    int opatchH,  // output-patch height,
-    int opatchW  // output-patch width
-    ) {
-  const int H = input.size(2),
-      W = input.size(3),
-      kH = pos_add.size(1),
-      kW = pos_add.size(2),
-      npatchH = (H + opatchH - 1) / opatchH,  // num patches in vertical dim
-      npatchW = (W + opatchW - 1) / opatchW,  // num patches in horizontal dim
-      npatch = npatchH * npatchW;  // total number of patches per image
+    torch::PackedTensorAccessor32<scalar_t, 3> input,  // B, C, T, i.e. batch, channels, time
+    torch::PackedTensorAccessor32<scalar_t, 2> params,  // C, N + 1
+    torch::PackedTensorAccessor32<scalar_t, 3> output) {  // B, C, T
-  // Channel index.
-  const int c = blockIdx.x;
-  // We don't need to check the range of `c` because we set gridDim.x to the
-  // exact number of channels.
-  const int ipatchH = opatchH + kH - 1,
-      ipatchW = opatchW + kW - 1,
-      ipatch_size = ipatchH * ipatchW,
-      opatch_size = opatchH * opatchW;
-  // `extern_buf` is general-purpose shared memory, which we'll divide between
-  // pos_add, pos_mul and src_img_buf to be shared between the src image size
-  // (ipatch_size) and the number of threads (blockDim.x)
-  // these are pointers to __shared__ memory; the compiler should
-  // be able to figure this out.
-  scalar_t
-      *pos_add_buf = (scalar_t*) extern_buf,   // pos_add positional-encoding / kernel parameters,
-                                               // indexed [kh*kW + kw] where kh and kw are vertical
-                                               // and horizontal positions in the kernel.
-      *pos_mul_buf = pos_add_buf + (kH * kW),  // pos_mul positional-encoding / kernel parameters,
-                                               // indexed [kh*kW + kw] where kh and kw are vertical
-                                               // and horizontal positions in the kernel.
-      *src_img_buf = pos_mul_buf + (kH * kW);  // version of input image that relates to source position,
-                                               // of size [ipatch_size], indexed [h*ipatchW + w]...
-                                               // note, the 'h' and 'w' indexes are into the zero-padded input
-                                               // image.
+  const int B = input.size(0),
+      C = input.size(1),
+      T = input.size(2),
+      N = params.size(1) - 1,
+      K = N / 2;
+  // Note: N and K are powers fo 2, with K >= 1.
-  int threads_per_opixel = blockDim.x / opatch_size;
-  assert(blockDim.x == opatch_size * threads_per_opixel);
+  const int c = blockIdx.x,  // c is channel index
+      b = blockIdx.y;  // b is batch index; we'll iterate over b
-  // pos_in_patch will be interpreted as h_in_patch * opatchW + w_in_patch.
-  int pos_in_patch = threadIdx.x / threads_per_opixel;
-  // Load parts of the kernel parameters pos_add and pos_mul into shared memory,
-  // in pos_add_buf and pos_mul_buf
-  for (int i = threadIdx.x; i < kH * kW; i += blockDim.x) {
-    int kh = i / kW,
-        kw = i % kW;
-    pos_add_buf[i] = pos_add[c][kh][kw];
-    pos_mul_buf[i] = pos_mul[c][kh][kw];
+  scalar_t *y_vals_buf = (scalar_t*) extern_buf,  // [N]
+      *params_buf = (scalar_t*) y_vals_buf + N;  // [N]
+  // Load parameters
+  for (int n = threadIdx.x; n < N + 1; n += THREADS_PER_BLOCK) {
+    params_buf[n - 1] = params[c][n];
+  }
+  if (threadIdx.x == 0) {
+    scalar_t scale = exp(params_buf[-1]),
+        inv_scale = 1.0 / scale;
+    params_buf[-1] = scale;
+    params_buf[-2] = inv_scale;
+  } else if (threadIdx.x & ~96 == 0) {
+    // threadIdx.x == 32 or 64.  These, and 0, are in separate warps so we can
+    // allow them to do separate jobs.  This code takes linear time in K which
+    // is not at all ideal and could be improved if K is largish, but it shouldn't
+    // dominate the total time taken if we are processing a lot of data;
+    // and anyway, we doubt that K will be need to be more than 4 or 8 or so,
+    // so the potential savings are quite small.
-  // n is the index within the batch.  Loop to make sure we cover all images in
-  // the batch.  input.size(0) is the batch size N.  All threads in the thread-block
-  // loop the same number of times.
-  for (int n = blockIdx.z; n < input.size(0); n += gridDim.z) {
-    // Loop over the patch within the output image.  All threads in the
-    // thread-block loop the same number of times.
-    for (int patch_idx = blockIdx.y; patch_idx < npatch; patch_idx += gridDim.y) {
-      // (patch_h_offset, patch_w_offset) are the (vertical, horizontal) indexes
-      // of the lowest-numbered pixel in the patch of output that this thread
-      // block is responsible for.
-      int patch_h_offset = (patch_idx / npatchW) * opatchH,
-          patch_w_offset = (patch_idx % npatchW) * opatchW;
-      // This __syncthreads() is only necessary if we have already looped at
-      // least once over n or patch_idx: it's in case other threads are still
-      // using the `src_img_buf` buffer for something else.
-      __syncthreads();
-      // Load the 'src' part of the input patch; the size of this is the size of
-      // the output patch plus a border of sizes kH//2, kW//2 on each side.
-      for (int i = threadIdx.x; i < ipatch_size; i += blockDim.x) {
-        int h_in_kernel = i / ipatchW,
-            w_in_kernel = i % ipatchW;
-        int src_h = patch_h_offset + h_in_kernel - (kH / 2),  // kH / 2 is offset due to padding
-            src_w = patch_w_offset + w_in_kernel - (kW / 2);
-        scalar_t src_val = scalar_t(0);
-        if ((unsigned int)src_h < (unsigned int)H &&  // h >= 0 && h < H
-            (unsigned int)src_w < (unsigned int)W)  // w >= 0 && w < W
-          src_val = input[n][c][src_h][src_w];
-        src_img_buf[i] = src_val;
-      }
-      // make sure all threads have written to `src_img_buf`
-      __syncthreads();
-      // 'h' and 'w' are the positions within the output image, that this tile
-      // of size threads_per_opixel is responsible for.
-      int h = patch_h_offset + pos_in_patch / opatchW,
-          w = patch_w_offset + pos_in_patch % opatchW;
-      // The "destination" pixel; this is an input.  It gets added to each
-      // src pixel, prior to the relu, in the loop below.
-      scalar_t dest_val = scalar_t(0);
-      if (h < H && w < W) {
-        // Several threads (within the same tile, which implies the same warp)
-        // may load the same value here, but I believe the device's memory
-        // subsystem handles this well enough that we can just ignore the issue
-        // rather than try to optimize it.
-        // https://forums.developer.nvidia.com/t/accessing-same-global-memory-address-within-warps/66574
-        int C = input.size(1) / 2;
-        dest_val = input[n][c + C][h][w];  // else 0.
-      }
-      // `sum` is the partial sum that this thread computes; we'll sum this over
-      // the `threads_per_opixel` threads in the tile to get the output pixel
-      // value.
-      scalar_t sum = 0.0;
-      for (int pos_in_kernel = threadIdx.x % threads_per_opixel;
-           pos_in_kernel < (kH * kW);
-           pos_in_kernel += threads_per_opixel) {
-        int h_in_kernel = pos_in_kernel / kW,
-            w_in_kernel = pos_in_kernel % kW;
-        // Note: this is actually more like cross-correlation, as we don't
-        // have a negative sign on the h and w indexes in the kernel.
-        // Also note: we already took care of padding and the associated
-        // offsets of -(kH / 2) and -(kW / 2).
-        int h_in_src_patch = (pos_in_patch / opatchW) + h_in_kernel,
-            w_in_src_patch = (pos_in_patch % opatchW) + w_in_kernel;
-        scalar_t src_val = src_img_buf[h_in_src_patch * ipatchW + w_in_src_patch],
-            pos_add_val = pos_add_buf[pos_in_kernel];
-        scalar_t relu = (src_val + dest_val + pos_add_val);
-        if (relu > 0.0)
-          sum += relu * pos_mul_buf[pos_in_kernel];
-      }
-      // Sync threads because src_img_buf is also used above.
-      __syncthreads();
-      // Aggregate `sum` over threads
-      sum = tiled_warp_reduce_sum(threads_per_opixel, src_img_buf, sum);
-      if (threadIdx.x % threads_per_opixel == 0 && h < H && w < W) {
-        output[n][c][h][w] = sum;
-      }
-    }
-  }
+  __syncthreads();
+  scalar_t scale = params_buf[-1],
+      inv_scale = params_buf[-2];
 }
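The draft kernel above carves the dynamic shared memory `extern_buf` into an N-element y_vals_buf followed by an N-element params_buf, which is why the launcher below sizes the launch with shared_mem_numel = 2 * N. A hedged restatement of that size computation (the helper name is hypothetical; the numbers simply follow the code as shown):

// Hypothetical helper restating the shared-memory budget implied by the
// draft kernel: y_vals_buf (N values) + params_buf (N values).
#include <cstddef>

template <typename scalar_t>
constexpr std::size_t learned_nonlin_shared_bytes(int N) {
  return sizeof(scalar_t) * 2 * static_cast<std::size_t>(N);
}
// e.g. learned_nonlin_shared_bytes<float>(8) == 64 bytes when N == 8.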
@@ -578,117 +479,58 @@ void learned_nonlin_kernel_backward(
+torch::Tensor learned_nonlin_cuda(torch::Tensor input,
+                                  torch::Tensor params) {
+  TORCH_CHECK(input.dim() == 3, "input must be 3-dimensional");
+  TORCH_CHECK(params.dim() == 2, "params must be 2-dimensional.");
+  TORCH_CHECK(params.size(1) >= 3 &&
+              ((params.size(1) - 1) & (params.size(1) - 2)) == 0,
+              "params.size(1) has invalid value, must be a power of 2 plus 1.");
+  TORCH_CHECK(params.size(0) == input.size(1),
+              "params vs input channels mismatch");
-torch::Tensor learned_nonlin_cuda(torch::Tensor input,
-                                  torch::Tensor pos_add,
-                                  torch::Tensor pos_mul) {
-  TORCH_CHECK(input.dim() == 4, "input must be 4-dimensional");
-  TORCH_CHECK(pos_add.dim() == 3, "pos_add must be 3-dimensional.");
-  TORCH_CHECK(pos_mul.dim() == 3, "pos_add must be 3-dimensional.");
   TORCH_CHECK(input.device().is_cuda(), "Input must be a CUDA tensor");
-  const int N = input.size(0),
-      C = input.size(1) / 2,
-      H = input.size(2),
-      W = input.size(3),
-      kH = pos_add.size(1),
-      kW = pos_add.size(2);
-  TORCH_CHECK(kH % 2 == 1 && kW % 2 == 1);
-  TORCH_CHECK(input.size(1) % 2 == 0, "Input must have even num-channels");
-  TORCH_CHECK(pos_add.size(0) == C && pos_mul.size(0) == C &&
-              pos_mul.size(1) == kH && pos_mul.size(2) == kW,
-              "Input sizes mismatch.");
-  TORCH_CHECK(pos_add.device() == input.device() &&
-              pos_mul.device() == pos_add.device(),
-              "Input devices mismatch");
-  auto scalar_t = input.scalar_type();
-  TORCH_CHECK(pos_add.scalar_type() == scalar_t &&
-              pos_mul.scalar_type() == scalar_t,
-              "Input dtypes mismatch");
+  TORCH_CHECK(params.device().is_cuda(), "Params must be a CUDA tensor");
-  torch::Tensor output = torch::empty({N, C, H, W},
-                                      torch::TensorOptions().dtype(scalar_t).device(input.device()));
+  const int B = input.size(0),
+      C = input.size(1),
+      T = input.size(2),
+      N = params.size(1) - 1;
   // Work out the configuration to call the kernel with..
-  int patchH = std::min(H, kH),  // output patch height
-      patchW = std::min(W, kW);  // output patch width
-  // We don't want the height or width of the patch to be less than the kernel
-  // width, or the padding will make the input-patch size more than twice the
-  // output-patch size.
-  // We aim for the output-patch size to be more than 128; this is not something
-  // very exact, but it roughly corresponds to us wanting to have up to 4 threads
-  // per output pixel, and the limitation of 512 threads per thread-block which
-  // we impose so that we can run on architectures with little shared memory.
-  while (patchW < W && patchH * (patchW + 1) <= 128)
-    patchW++;
-  while (patchH < H && (patchH + 1) * patchW <= 128)
-    patchH++;
+  auto scalar_t = input.scalar_type();
+  auto opts = torch::TensorOptions().dtype(scalar_t).device(input.device());
-  // We are assuming that the thread-block size can be as large as 512; this
-  // works even on old CUDA architectures.
-  int threads_per_opixel;
-  if (patchH * patchW * 4 <= 512 && (kH * kW) > 16)
-    threads_per_opixel = 4;
-  else if (patchH * patchW * 2 <= 512 && (kH * kW) > 8)
-    threads_per_opixel = 2;
-  else
-    threads_per_opixel = 1;
+  torch::Tensor output = torch::empty({B, C, T}, opts);
-  int input_patchH = patchH + kH - 1,
-      input_patchW = patchW + kW - 1,
-      input_patch_size = input_patchH * input_patchW;
+  if (C * B * T == 0)
+    return output;
-  int threads_per_block = patchH * patchW * threads_per_opixel;
+  // The number of thread blocks is at least C (the number of channels), but
+  // if the number of channels is small we may split further on the batch.
-  int buffer_numel = 2 * (kH * kW) + std::max<int>(threads_per_block, input_patch_size);
+  int batches_per_block = 1;
+  if (C * batches_per_block < 128) {
+    // Aim for at least 128 thread blocks..
+    batches_per_block = 128 / C;
+    if (batches_per_block > B)
+      batches_per_block = B;
+  }
-  int num_patches_H = (H + patchH - 1) / patchH,
-      num_patches_W = (W + patchW - 1) / patchW,
-      num_patches = num_patches_H * num_patches_W;
+  int shared_mem_numel = 2 * N,
+      num_blocks_batch = (B + batches_per_block - 1) / B;
+  // gridDim.x == C.
-  int num_blocks_patch = 1,  // gridDim.y.
-      num_blocks_batch = 1;  // gridDim.z
-  while (C * num_blocks_patch <= 256 && num_blocks_patch * 2 <= num_patches)
-    num_blocks_patch *= 2;
-  if (C * num_patches <= 512)
-    num_blocks_patch = num_patches;
-  while (C * num_blocks_patch * num_blocks_batch <= 512 && num_blocks_batch * 2 <= N)
-    num_blocks_batch *= 2;
-  if (C * num_blocks_patch * N <= 1024)
-    num_blocks_batch = N;
-  assert(num_blocks_patch <= num_patches && num_blocks_batch <= N);
+  dim3 gridDim(C, num_blocks_batch, 1);
-  static int debug_count = 50;
-  if (debug_count > 0) {
-    debug_count--;
-    std::cout << "N,C,H,W=" << N << "," << C << "," << H << "," << W
-              << "; kW,kH=" << kW << "," << kH
-              << "; patchH,patchW=" << patchH << "," << patchW
-              << ", num_blocks_patch=" << num_blocks_patch
-              << ", num_blocks_batch=" << num_blocks_batch
-              << ", threads_per_opixel=" << threads_per_opixel
-              << ", threads_per_block=" << threads_per_block << std::endl;
-  }
-  dim3 gridDim(C, num_blocks_patch, num_blocks_batch);
-  // blockDim is scalar, just threads_per_block.
+  // blockDim is scalar, just THREADS_PER_BLOCK.
   AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "learned_nonlin_kernel", ([&] {
-        learned_nonlin_kernel<scalar_t><<<gridDim, threads_per_block, sizeof(scalar_t) * buffer_numel, at::cuda::getCurrentCUDAStream()>>>(
-            input.packed_accessor32<scalar_t, 4>(),
-            pos_add.packed_accessor32<scalar_t, 3>(),
-            pos_mul.packed_accessor32<scalar_t, 3>(),
-            output.packed_accessor32<scalar_t, 4>(),
-            patchH,
-            patchW);
+        learned_nonlin_kernel<scalar_t><<<gridDim, THREADS_PER_BLOCK, sizeof(scalar_t) * shared_mem_numel, at::cuda::getCurrentCUDAStream()>>>(
+            input.packed_accessor32<scalar_t, 3>(),
+            params.packed_accessor32<scalar_t, 2>(),
+            output.packed_accessor32<scalar_t, 3>());
      }));
   return output;
 }
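The TORCH_CHECK on params.size(1) above encodes "a power of 2 plus 1" with the usual bit trick: for s = params.size(1), (s - 1) & (s - 2) == 0 holds exactly when N = s - 1 is a power of two. A stand-alone restatement (hypothetical helper, for clarity only):

// Hypothetical helper restating the power-of-2-plus-1 check used above.
static bool is_power_of_two_plus_one(long s) {
  return s >= 3 && (((s - 1) & (s - 2)) == 0);
}
// is_power_of_two_plus_one(5) == true   (N == 4)
// is_power_of_two_plus_one(6) == false  (N == 5)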
@@ -696,9 +538,10 @@ torch::Tensor learned_nonlin_cuda(torch::Tensor input,
 std::vector<torch::Tensor> learned_nonlin_backward_cuda(torch::Tensor input,
-                                                        torch::Tensor pos_add,
-                                                        torch::Tensor pos_mul,
+                                                        torch::Tensor params,
                                                         torch::Tensor grad_output) {
+  /*
   TORCH_CHECK(input.dim() == 4, "input must be 4-dimensional");
   TORCH_CHECK(pos_add.dim() == 3, "pos_add must be 3-dimensional.");
   TORCH_CHECK(pos_mul.dim() == 3, "pos_add must be 3-dimensional.");
@@ -856,5 +699,5 @@ std::vector<torch::Tensor> learned_nonlin_backward_cuda(torch::Tensor input,
   grad_pos_add = at::sum(grad_pos_add, {0});
   grad_pos_mul = at::sum(grad_pos_mul, {0});
-  return std::vector<torch::Tensor>({grad_input, grad_pos_add, grad_pos_mul});
+  return std::vector<torch::Tensor>({grad_input, grad_pos_add, grad_pos_mul});
+  */
 }