Commit 185a0df5 authored by Myle Ott

Fix warning about deprecated `volatile` kwarg for Variables

parent dccf7909
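Note: recent PyTorch releases (0.4 and later) deprecate the `volatile` keyword on `Variable` in favor of the `torch.no_grad()` context manager, so `Variable(x, volatile=True)` now emits a deprecation warning. A minimal sketch of the old and new patterns (not part of this commit; assumes a PyTorch build that provides `torch.no_grad()`):

import torch
from torch.autograd import Variable

x = torch.randn(2, 3)

# Old pattern: on newer PyTorch this triggers the deprecation warning that
# this commit silences, because `volatile` no longer has any effect.
# v = Variable(x, volatile=True)

# New pattern: disable gradient tracking with the no_grad() context manager.
with torch.no_grad():
    y = x * 2  # no autograd history is recorded for y

The `volatile_variable()` and `maybe_no_grad()` helpers added to `fairseq.utils` below wrap this version check so call sites keep working on both old and new PyTorch.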
@@ -11,6 +11,8 @@ import torch
 from torch.autograd.variable import Variable
 import torch.nn.functional as F
 
+from fairseq import utils
+
 from .fairseq_criterion import FairseqCriterion
@@ -41,7 +43,7 @@ class LabelSmoothedNLLLoss(torch.autograd.Function):
     @staticmethod
     def backward(ctx, grad):
-        return Variable(ctx.grad_input, volatile=True) * grad, None, None, None, None, None
+        return utils.volatile_variable(ctx.grad_input) * grad, None, None, None, None, None
 
 
 class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
......
@@ -7,9 +7,11 @@
 #
 import torch
-from torch.autograd import Variable, Function
+from torch.autograd import Function
 from torch.nn.modules.utils import _single
 
+from fairseq import utils
+
 try:
     from fairseq import temporal_convolution_tbc
 except ImportError as e:
@@ -93,9 +95,9 @@ class ConvTBCFunction(Function):
             input,
             weight)
 
-        grad_input = Variable(grad_input, volatile=True)
-        grad_weight = Variable(grad_weight, volatile=True)
-        grad_bias = Variable(grad_bias, volatile=True)
+        grad_input = utils.volatile_variable(grad_input)
+        grad_weight = utils.volatile_variable(grad_weight)
+        grad_bias = utils.volatile_variable(grad_bias)
 
         return grad_input, grad_weight, grad_bias, None
......
@@ -8,6 +8,9 @@
 import torch
 import torch.nn.functional as F
 
+from fairseq import utils
+
 from .conv_tbc import ConvTBC
@@ -65,7 +68,7 @@ class LinearizedConvolution(ConvTBC):
             self.input_buffer[:, :-1, :] = self.input_buffer[:, 1:, :].clone()
         # append next input
         self.input_buffer[:, -1, :] = input[:, -1, :]
-        input = torch.autograd.Variable(self.input_buffer, volatile=True)
+        input = utils.volatile_variable(self.input_buffer)
         output = F.linear(input.view(bsz, -1), weight, self.bias)
         return output.view(bsz, 1, -1)
......
@@ -322,7 +322,7 @@ class SequenceGenerator(object):
     def _decode(self, tokens, encoder_outs):
         # wrap in Variable
-        tokens = Variable(tokens, volatile=True)
+        tokens = utils.volatile_variable(tokens)
 
         avg_probs = None
         avg_attn = None
......
@@ -176,6 +176,20 @@ def _upgrade_args(args):
     return args
 
 
+def maybe_no_grad(condition):
+    if hasattr(torch, 'no_grad') and condition:
+        return torch.no_grad()
+    # no-op context manager
+    return contextlib.ExitStack()
+
+
+def volatile_variable(*args, **kwargs):
+    if hasattr(torch, 'no_grad'):
+        with torch.no_grad():
+            return Variable(*args, **kwargs)
+    return Variable(*args, **kwargs, volatile=True)
+
+
 def make_variable(sample, volatile=False, cuda_device=None):
     """Wrap input tensors in Variable class."""
@@ -183,7 +197,7 @@ def make_variable(sample, volatile=False, cuda_device=None):
         if torch.is_tensor(maybe_tensor):
             if cuda_device is not None and torch.cuda.is_available():
                 maybe_tensor = maybe_tensor.cuda(async=True, device=cuda_device)
-            return Variable(maybe_tensor, volatile=volatile)
+            return volatile_variable(maybe_tensor)
         elif isinstance(maybe_tensor, dict):
             return {
                 key: _make_variable(value)
@@ -255,10 +269,3 @@ def strip_pad(tensor, pad):
     if tensor[-1] == pad:
         tensor = rstrip_pad(tensor, pad)
     return tensor
-
-
-def maybe_no_grad(condition):
-    if hasattr(torch, 'no_grad') and condition:
-        return torch.no_grad()
-    # no-op context manager
-    return contextlib.ExitStack()
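For illustration, a usage sketch of the two helpers introduced above (names and fallback behaviour taken from the diff; the example tensor and condition are made up):

import torch
from fairseq import utils

x = torch.randn(4, 5)

# volatile_variable() wraps a tensor for inference: when torch.no_grad() is
# available the Variable is built inside that context, otherwise it falls
# back to Variable(..., volatile=True).
v = utils.volatile_variable(x)

# maybe_no_grad(condition) returns torch.no_grad() when it exists and the
# condition is true, otherwise a no-op context manager (contextlib.ExitStack()).
with utils.maybe_no_grad(True):
    out = v.sum()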