Commit 9c56e510 authored by Daniel Povey

Refactor code slightly for more memory efficiency

parent 2523eeeb
@@ -69,19 +69,67 @@ def _learned_nonlin_backward_dispatcher(input: torch.Tensor,
def _reshape_as_3dim(x: torch.Tensor, dim: int):
"""
Returns x reshaped so that dimension 'dim' is the middle of 3 dimensions,
    combining dimensions and unsqueezing as needed.  Writing the behavior of
    this function as
        input_shape, dim -> output_shape,
    it will do:
        (3), 0 -> (1, 3, 1)
        (2, 5, 9), 1 -> (2, 5, 9)
        (2, 5, 9), 2 -> (10, 9, 1)
        (3, 4, 5, 6), 2 -> (12, 5, 6)
    The idea is to normalize the shape so that the channel dimension is the
    middle of 3, which lets the implementation deal with a fixed layout.
Args:
x: tensor to be reshaped
dim: Dimension of x that is to be the middle of 3 dimensions in the result.
           If negative, interpreted as an offset from x.ndim.
"""
if dim < 0:
        dim += x.ndim
orig_shape = list(x.shape)
# `new_shape` is `orig_shape` but modified so that the channel dim (`dim`)
# is dimension/axis 1. We do this not by transposing, but by combining
# adjacent dims.
a, b = 1, 1
for i in range(0, dim):
a *= orig_shape[i]
for i in range(dim + 1, len(orig_shape)):
b *= orig_shape[i]
new_shape = (a, orig_shape[dim], b)
return x.reshape(new_shape) # `reshape` will make a contiguous copy if needed.
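# Illustrative only (not part of the original module): the shape mapping of
# `_reshape_as_3dim` on a couple of hypothetical tensors, kept as comments so
# that nothing runs at import time:
#   _reshape_as_3dim(torch.zeros(2, 5, 9), 1).shape    -> torch.Size([2, 5, 9])
#   _reshape_as_3dim(torch.zeros(2, 5, 9), 2).shape    -> torch.Size([10, 9, 1])
#   _reshape_as_3dim(torch.zeros(3, 4, 5, 6), 2).shape -> torch.Size([12, 5, 6])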
class LearnedNonlinFunction(torch.autograd.Function):
@staticmethod
    def forward(ctx, input: torch.Tensor, params: torch.Tensor, dim: int) -> torch.Tensor:
if dim < 0:
dim += input.ndim
assert dim >= 0 and dim < input.ndim
assert params.ndim == 2 and params.shape[1] % 2 == 1
assert params.shape[0] == input.shape[dim]
ctx.dim = dim
ctx.save_for_backward(input, params)
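        # Note: we save the original (un-reshaped) input; the 3-dim reshape is
        # redone in backward(), so any copy made by reshape() is not kept
        # alive by the autograd graph.  This is the memory saving that this
        # refactoring is after.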
output = _learned_nonlin_forward_dispatcher(_reshape_as_3dim(input, dim),
params)
return output
@staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, None]:
(input, params) = ctx.saved_tensors
orig_shape = input.shape
# We re-do the reshaping in the backward, rather than save the reshaped
# input, so that if this reshaping results in a copy it is not retained
# (this saves memory at the expense of a little extra work in such
# situations).
grad_input, grad_params = _learned_nonlin_backward_dispatcher(
            _reshape_as_3dim(input, ctx.dim), params, grad_output)
return grad_input.reshape(input.shape), grad_params, None
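# A sketch of a numerical gradient check (an assumption, not code from this
# commit; it presumes the compiled dispatchers accept float64 tensors):
#   x = torch.randn(2, 3, 4, dtype=torch.double, requires_grad=True)
#   p = torch.randn(3, 5, dtype=torch.double, requires_grad=True)
#   torch.autograd.gradcheck(
#       lambda a, b: LearnedNonlinFunction.apply(a, b, 1), (x, p))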
def learned_nonlin(input, params, dim):
@@ -142,27 +190,4 @@ def learned_nonlin(input, params, dim):
Return: output, of the same shape as `input`.
"""
    return LearnedNonlinFunction.apply(input, params, dim).reshape(input.shape)
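# Example usage (a sketch under assumptions: it needs the compiled extension
# that provides the dispatchers, and the shapes below are made up):
#   input = torch.randn(4, 10, 5, requires_grad=True)
#   params = torch.randn(10, 7, requires_grad=True)  # one row per channel, odd column count
#   output = learned_nonlin(input, params, dim=1)
#   output.sum().backward()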
# Caution: this will fail occasionally due to cutoffs not being quite large enough.
# As long as it passes most of the time, it's OK.
import random
import torch
from torch_learned_nonlin import learned_nonlin