Commit 907ca927 authored by Myle Ott

Better support for torch.no_grad (since volatile is deprecated)

parent 0b84ab19
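For context, the change replaces the deprecated `volatile` flag with a small helper that hands back `torch.no_grad()` where the running PyTorch provides it, and a do-nothing context manager otherwise. A minimal self-contained sketch of that pattern, assuming the `_NoOpContext` fallback name (the repo's actual fallback implementation is truncated out of the diff below):

```python
import torch

class _NoOpContext(object):
    """Do-nothing stand-in for torch.no_grad() on older PyTorch."""
    def __enter__(self):
        return self

    def __exit__(self, *exc):
        return False  # never suppress exceptions

def maybe_no_grad(condition=True):
    # Disable autograd when the running PyTorch supports it (>= 0.4)
    # and the caller asked for it; otherwise do nothing.
    if hasattr(torch, 'no_grad') and condition:
        return torch.no_grad()
    return _NoOpContext()
```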
```diff
@@ -69,7 +69,8 @@ class LinearizedConvolution(ConvTBC):
             # append next input
             self.input_buffer[:, -1, :] = input[:, -1, :]
             input = utils.volatile_variable(self.input_buffer)
-        output = F.linear(input.view(bsz, -1), weight, self.bias)
+        with utils.maybe_no_grad():
+            output = F.linear(input.view(bsz, -1), weight, self.bias)
         return output.view(bsz, 1, -1)
 
     def clear_incremental_state(self):
```
```diff
@@ -71,8 +71,9 @@ class SequenceGenerator(object):
                 srclen = input['src_tokens'].size(1)
                 if timer is not None:
                     timer.start()
-                hypos = self.generate(input['src_tokens'], beam_size=beam_size,
-                                      maxlen=int(maxlen_a*srclen + maxlen_b))
+                with utils.maybe_no_grad():
+                    hypos = self.generate(input['src_tokens'], beam_size=beam_size,
+                                          maxlen=int(maxlen_a*srclen + maxlen_b))
                 if timer is not None:
                     timer.stop(s['ntokens'])
                 for i, id in enumerate(s['id'].data):
```
```diff
@@ -327,7 +328,8 @@ class SequenceGenerator(object):
         avg_probs = None
         avg_attn = None
         for model, encoder_out in zip(self.models, encoder_outs):
-            decoder_out, attn = model.decoder(tokens, encoder_out)
+            with utils.maybe_no_grad():
+                decoder_out, attn = model.decoder(tokens, encoder_out)
             probs = model.get_normalized_probs(decoder_out[:, -1, :], log_probs=False).data
             if avg_probs is None:
                 avg_probs = probs
```
```diff
@@ -176,7 +176,7 @@ def _upgrade_args(args):
     return args
 
 
-def maybe_no_grad(condition):
+def maybe_no_grad(condition=True):
     if hasattr(torch, 'no_grad') and condition:
         return torch.no_grad()
     # no-op context manager
```
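The new `condition=True` default lets call sites omit the argument for the common inference path while still being able to keep gradients on. A quick check against the sketch above, assuming PyTorch >= 0.4:

```python
import torch

x = torch.ones(2, 2, requires_grad=True)

with maybe_no_grad():                 # default condition=True: autograd off
    y = x * 2
print(y.requires_grad)                # False

with maybe_no_grad(condition=False):  # autograd left untouched
    z = x * 2
print(z.requires_grad)                # True
```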
```diff
@@ -185,9 +185,10 @@ def maybe_no_grad(condition):
 def volatile_variable(*args, **kwargs):
     if hasattr(torch, 'no_grad'):
-        with torch.no_grad():
-            return Variable(*args, **kwargs)
-    return Variable(*args, **kwargs, volatile=True)
+        # volatile has been deprecated, use the no_grad context manager instead
+        return Variable(*args, **kwargs)
+    else:
+        return Variable(*args, **kwargs, volatile=True)
 
 
 def make_variable(sample, volatile=False, cuda_device=None):
```
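After this change `volatile_variable` no longer tries to disable autograd itself (wrapping only the `Variable` construction in `no_grad` never affected operations performed later, outside the block); call sites wrap the downstream computation in `maybe_no_grad()` instead, as the three hunks above do. A hypothetical end-to-end use on PyTorch >= 0.4, loosely mirroring the `LinearizedConvolution` call site and reusing the sketch's `maybe_no_grad`:

```python
import torch
import torch.nn.functional as F
from torch.autograd import Variable

input_buffer = torch.zeros(4, 8)       # bsz=4, flattened conv window of 8
weight = torch.randn(3, 8)             # hypothetical projection weight

inp = volatile_variable(input_buffer)  # post-commit: just a plain Variable
with maybe_no_grad():                  # the caller disables autograd
    output = F.linear(inp.view(4, -1), weight)
print(output.requires_grad)            # False: no backward graph was built
```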