"vscode:/vscode.git/clone" did not exist on "3990952f76df3d553dfb54a9600eafc44899e2bf"
Commit 185a0df5 authored by Myle Ott

Fix warning about deprecated `volatile` kwarg for Variables

parent dccf7909
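
For context (not part of the commit itself): the `volatile=True` kwarg on `Variable` is deprecated in newer PyTorch, which recommends building inference tensors under `torch.no_grad()` instead. A minimal sketch of the old and new idioms, assuming a PyTorch version that exposes `torch.no_grad`:

```python
import torch
from torch.autograd import Variable

x = torch.zeros(2, 3)

# Old idiom: emits a deprecation warning on newer PyTorch.
# v = Variable(x, volatile=True)

# New idiom: create the Variable inside a no-grad context instead,
# falling back to `volatile=True` only on older PyTorch.
if hasattr(torch, 'no_grad'):
    with torch.no_grad():
        v = Variable(x)
else:
    v = Variable(x, volatile=True)
```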
@@ -11,6 +11,8 @@ import torch
 from torch.autograd.variable import Variable
 import torch.nn.functional as F
 
+from fairseq import utils
+
 from .fairseq_criterion import FairseqCriterion
@@ -41,7 +43,7 @@ class LabelSmoothedNLLLoss(torch.autograd.Function):
     @staticmethod
     def backward(ctx, grad):
-        return Variable(ctx.grad_input, volatile=True) * grad, None, None, None, None, None
+        return utils.volatile_variable(ctx.grad_input) * grad, None, None, None, None, None
 
 
 class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
@@ -7,9 +7,11 @@
 #
 
 import torch
-from torch.autograd import Variable, Function
+from torch.autograd import Function
 from torch.nn.modules.utils import _single
 
+from fairseq import utils
+
 try:
     from fairseq import temporal_convolution_tbc
 except ImportError as e:
@@ -93,9 +95,9 @@ class ConvTBCFunction(Function):
            input,
            weight)
 
-        grad_input = Variable(grad_input, volatile=True)
-        grad_weight = Variable(grad_weight, volatile=True)
-        grad_bias = Variable(grad_bias, volatile=True)
+        grad_input = utils.volatile_variable(grad_input)
+        grad_weight = utils.volatile_variable(grad_weight)
+        grad_bias = utils.volatile_variable(grad_bias)
 
         return grad_input, grad_weight, grad_bias, None
@@ -8,6 +8,9 @@
 import torch
 import torch.nn.functional as F
 
+from fairseq import utils
+
 from .conv_tbc import ConvTBC
@@ -65,7 +68,7 @@ class LinearizedConvolution(ConvTBC):
            self.input_buffer[:, :-1, :] = self.input_buffer[:, 1:, :].clone()
            # append next input
            self.input_buffer[:, -1, :] = input[:, -1, :]
-           input = torch.autograd.Variable(self.input_buffer, volatile=True)
+           input = utils.volatile_variable(self.input_buffer)
        output = F.linear(input.view(bsz, -1), weight, self.bias)
        return output.view(bsz, 1, -1)
@@ -322,7 +322,7 @@ class SequenceGenerator(object):
     def _decode(self, tokens, encoder_outs):
         # wrap in Variable
-        tokens = Variable(tokens, volatile=True)
+        tokens = utils.volatile_variable(tokens)
 
         avg_probs = None
         avg_attn = None
@@ -176,6 +176,20 @@ def _upgrade_args(args):
     return args
 
 
+def maybe_no_grad(condition):
+    if hasattr(torch, 'no_grad') and condition:
+        return torch.no_grad()
+    # no-op context manager
+    return contextlib.ExitStack()
+
+
+def volatile_variable(*args, **kwargs):
+    if hasattr(torch, 'no_grad'):
+        with torch.no_grad():
+            return Variable(*args, **kwargs)
+    return Variable(*args, **kwargs, volatile=True)
+
+
 def make_variable(sample, volatile=False, cuda_device=None):
     """Wrap input tensors in Variable class."""
@@ -183,7 +197,7 @@ def make_variable(sample, volatile=False, cuda_device=None):
        if torch.is_tensor(maybe_tensor):
            if cuda_device is not None and torch.cuda.is_available():
                maybe_tensor = maybe_tensor.cuda(async=True, device=cuda_device)
-           return Variable(maybe_tensor, volatile=volatile)
+           return volatile_variable(maybe_tensor)
        elif isinstance(maybe_tensor, dict):
            return {
                key: _make_variable(value)
@@ -255,10 +269,3 @@ def strip_pad(tensor, pad):
     if tensor[-1] == pad:
         tensor = rstrip_pad(tensor, pad)
     return tensor
-
-
-def maybe_no_grad(condition):
-    if hasattr(torch, 'no_grad') and condition:
-        return torch.no_grad()
-    # no-op context manager
-    return contextlib.ExitStack()
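
A hypothetical usage sketch of the two helpers consolidated in `fairseq.utils` above (`volatile_variable`, added here, and `maybe_no_grad`, moved up in the file); the tensor values and the `inference_mode` flag are illustrative only:

```python
import torch
from fairseq import utils

tokens = torch.LongTensor([[4, 5, 6]])

# Wrap a tensor for inference without the deprecated `volatile=True` kwarg;
# on newer PyTorch the Variable is constructed under `torch.no_grad()`.
tokens_var = utils.volatile_variable(tokens)

# Conditionally disable gradient tracking; on PyTorch builds without
# `torch.no_grad`, this returns a no-op context manager (contextlib.ExitStack).
inference_mode = True  # hypothetical flag
with utils.maybe_no_grad(inference_mode):
    scores = tokens_var.float().sum()
```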