Commit 907ca927 authored by Myle Ott's avatar Myle Ott
Browse files

Better support for torch.no_grad (since volatile is deprecated)

parent 0b84ab19
...@@ -69,6 +69,7 @@ class LinearizedConvolution(ConvTBC): ...@@ -69,6 +69,7 @@ class LinearizedConvolution(ConvTBC):
# append next input # append next input
self.input_buffer[:, -1, :] = input[:, -1, :] self.input_buffer[:, -1, :] = input[:, -1, :]
input = utils.volatile_variable(self.input_buffer) input = utils.volatile_variable(self.input_buffer)
with utils.maybe_no_grad():
output = F.linear(input.view(bsz, -1), weight, self.bias) output = F.linear(input.view(bsz, -1), weight, self.bias)
return output.view(bsz, 1, -1) return output.view(bsz, 1, -1)
......
...@@ -71,6 +71,7 @@ class SequenceGenerator(object): ...@@ -71,6 +71,7 @@ class SequenceGenerator(object):
srclen = input['src_tokens'].size(1) srclen = input['src_tokens'].size(1)
if timer is not None: if timer is not None:
timer.start() timer.start()
with utils.maybe_no_grad():
hypos = self.generate(input['src_tokens'], beam_size=beam_size, hypos = self.generate(input['src_tokens'], beam_size=beam_size,
maxlen=int(maxlen_a*srclen + maxlen_b)) maxlen=int(maxlen_a*srclen + maxlen_b))
if timer is not None: if timer is not None:
...@@ -327,6 +328,7 @@ class SequenceGenerator(object): ...@@ -327,6 +328,7 @@ class SequenceGenerator(object):
avg_probs = None avg_probs = None
avg_attn = None avg_attn = None
for model, encoder_out in zip(self.models, encoder_outs): for model, encoder_out in zip(self.models, encoder_outs):
with utils.maybe_no_grad():
decoder_out, attn = model.decoder(tokens, encoder_out) decoder_out, attn = model.decoder(tokens, encoder_out)
probs = model.get_normalized_probs(decoder_out[:, -1, :], log_probs=False).data probs = model.get_normalized_probs(decoder_out[:, -1, :], log_probs=False).data
if avg_probs is None: if avg_probs is None:
......
...@@ -176,7 +176,7 @@ def _upgrade_args(args): ...@@ -176,7 +176,7 @@ def _upgrade_args(args):
return args return args
def maybe_no_grad(condition): def maybe_no_grad(condition=True):
if hasattr(torch, 'no_grad') and condition: if hasattr(torch, 'no_grad') and condition:
return torch.no_grad() return torch.no_grad()
# no-op context manager # no-op context manager
...@@ -185,8 +185,9 @@ def maybe_no_grad(condition): ...@@ -185,8 +185,9 @@ def maybe_no_grad(condition):
def volatile_variable(*args, **kwargs):
    """Build a Variable suitable for inference-only use.

    On newer torch releases the ``volatile`` flag is deprecated in favor of
    the ``torch.no_grad()`` context manager, so we construct a plain Variable
    and rely on callers to wrap computation in ``no_grad``. On older releases
    (no ``torch.no_grad`` attribute) we fall back to passing ``volatile=True``.
    """
    if not hasattr(torch, 'no_grad'):
        # Legacy torch: request volatile behavior explicitly.
        return Variable(*args, **kwargs, volatile=True)
    # volatile has been deprecated, use the no_grad context manager instead
    return Variable(*args, **kwargs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment