Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
FAST-RNNT
Commits
9c56e510
Commit
9c56e510
authored
Jul 17, 2021
by
Daniel Povey
Browse files
Refactor code slightly for more memory efficiency
parent
2523eeeb
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
57 additions
and
29 deletions
+57
-29
torch_learned_nonlin/learned_nonlin.py
torch_learned_nonlin/learned_nonlin.py
+54
-29
torch_learned_nonlin/learned_nonlin_test.py
torch_learned_nonlin/learned_nonlin_test.py
+3
-0
No files found.
torch_learned_nonlin/learned_nonlin.py
View file @
9c56e510
...
...
@@ -69,19 +69,67 @@ def _learned_nonlin_backward_dispatcher(input: torch.Tensor,
def _reshape_as_3dim(x: torch.Tensor, dim: int) -> torch.Tensor:
    """
    Returns x reshaped so that dimension 'dim' is the middle of 3 dimensions,
    combining dimensions and unsqueezing as needed.  For example (writing
    the behavior of this function as
          input_shape, dim -> output_shape,
    it will do:
          (3), 0 -> (1, 3, 1)
          (2, 5, 9), 1 -> (2, 5, 9)
          (2, 5, 9), 2 -> (10, 9, 1)
          (3, 4, 5, 6), 2 -> (12, 5, 6)
    The idea is to normalize the shape so the channel dimension is the middle
    of 3, so the implementation can deal with a fixed layout.

    Args:
         x:  tensor to be reshaped
       dim:  Dimension of x that is to be the middle of 3 dimensions in the result.
             If negative, interpreted as an offset from x.ndim.
    Returns:
         A tensor of shape (a, x.shape[dim], b), where `a` is the product of
         the sizes of the dimensions before `dim` and `b` is the product of
         the sizes of the dimensions after it.
    Raises:
         IndexError: if `dim` is not in the range [-x.ndim, x.ndim - 1].
    """
    ndim = x.ndim
    if dim < 0:
        # Must use the tensor argument here, not the `input` builtin.
        dim += ndim
    if not 0 <= dim < ndim:
        # Without this check a still-negative `dim` would silently index from
        # the end of the shape and produce a wrong result.
        raise IndexError(
            "dim out of range for tensor with {} dimension(s)".format(ndim))

    orig_shape = list(x.shape)
    # `new_shape` is `orig_shape` but modified so that the channel dim (`dim`)
    # is dimension/axis 1.  We do this not by transposing, but by combining
    # adjacent dims.
    a, b = 1, 1
    for size in orig_shape[:dim]:
        a *= size
    for size in orig_shape[dim + 1:]:
        b *= size
    new_shape = (a, orig_shape[dim], b)
    return x.reshape(new_shape)  # `reshape` will make a contiguous copy if needed.
class LearnedNonlinFunction(torch.autograd.Function):
    """
    Autograd wrapper around the learned-nonlinearity forward/backward
    dispatchers.  It accepts `input` of any shape together with the index
    `dim` of its channel dimension, and internally presents the data to the
    dispatchers as a 3-dim tensor with the channel dimension in the middle
    (see `_reshape_as_3dim`).
    """

    @staticmethod
    def forward(ctx, input: torch.Tensor, params: torch.Tensor,
                dim: int) -> torch.Tensor:
        """
        Args:
           ctx:  autograd context; `dim` and the tensors needed by backward
                 are stored on it.
         input:  tensor with at least `dim + 1` dimensions (after normalizing
                 a negative `dim`).
        params:  2-dim tensor whose first dimension equals input.shape[dim]
                 and whose second dimension is odd.
           dim:  channel dimension of `input`; if negative, it is interpreted
                 as an offset from input.ndim.
        Returns:
           The dispatcher output, reshaped to the shape of `input`.
        """
        if dim < 0:
            dim += input.ndim
        assert dim >= 0 and dim < input.ndim
        assert params.ndim == 2 and params.shape[1] % 2 == 1
        assert params.shape[0] == input.shape[dim]

        ctx.dim = dim
        # Save the *original* input rather than the reshaped one: if the
        # reshape had to make a copy, that copy is not kept alive until
        # the backward pass.
        ctx.save_for_backward(input, params)
        output = _learned_nonlin_forward_dispatcher(
            _reshape_as_3dim(input, dim), params)
        # NOTE(review): assumes the dispatcher's output has the same shape as
        # the 3-dim tensor it was given -- confirm against the dispatcher.
        # Callers expect a result shaped like `input`, not the 3-dim layout.
        return output.reshape(input.shape)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor
                 ) -> Tuple[torch.Tensor, torch.Tensor, None]:
        """
        Args:
                  ctx:  the autograd context populated by `forward`.
          grad_output:  gradient w.r.t. the output of `forward`; it has the
                        same shape as the original `input`.
        Returns:
           (grad_input, grad_params, None); the final None corresponds to the
           non-tensor `dim` argument of `forward`.
        """
        (input, params) = ctx.saved_tensors
        # We re-do the reshaping in the backward, rather than save the reshaped
        # input, so that if this reshaping results in a copy it is not retained
        # (this saves memory at the expense of a little extra work in such
        # situations).
        grad_input, grad_params = _learned_nonlin_backward_dispatcher(
            _reshape_as_3dim(input, ctx.dim),
            params,
            # grad_output arrives in the original layout, so it needs the
            # same 3-dim view as the input.
            _reshape_as_3dim(grad_output, ctx.dim))
        return grad_input.reshape(input.shape), grad_params, None
def
learned_nonlin
(
input
,
params
,
dim
):
...
...
@@ -142,27 +190,4 @@ def learned_nonlin(input, params, dim):
Return: output, of the same shape as `input`.
"""
if
dim
<
0
:
dim
+=
input
.
ndim
assert
dim
>=
0
and
dim
<
input
.
ndim
assert
params
.
ndim
==
2
and
params
.
shape
[
1
]
%
2
==
1
assert
params
.
shape
[
0
]
==
input
.
shape
[
dim
]
orig_shape
=
list
(
input
.
shape
)
# `new_shape` is `orig_shape` but modified so that the channel dim (`dim`)
# is dimension/axis 1. We do this not by transposing, but by combining
# adjacent dims.
a
,
b
=
1
,
1
for
i
in
range
(
0
,
dim
):
a
*=
orig_shape
[
i
]
for
i
in
range
(
dim
+
1
,
len
(
orig_shape
)):
b
*=
orig_shape
[
i
]
new_shape
=
(
a
,
orig_shape
[
dim
],
b
)
input
=
input
.
reshape
(
new_shape
)
# `reshape` should make input contiguous if needed.
assert
params
.
shape
[
0
]
==
input
.
shape
[
1
]
output
=
torch
.
empty_like
(
input
)
ans
=
LearnedNonlinFunction
.
apply
(
input
,
params
)
return
ans
.
reshape
(
orig_shape
)
return
LearnedNonlinFunction
.
apply
(
x
,
params
,
dim
)
torch_learned_nonlin/learned_nonlin_test.py
View file @
9c56e510
# Caution: this will fail occasionally due to cutoffs not being quite large enough.
# As long as it passes most of the time, it's OK.
import
random
import
torch
from
torch_learned_nonlin
import
learned_nonlin
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment