Commit 7da4e062 authored by Myle Ott's avatar Myle Ott
Browse files

Support deprecation of volatile Variables in latest PyTorch

parent 5637d54e
......@@ -227,6 +227,7 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
self.model.train()
self.optimizer.zero_grad()
with utils.maybe_no_grad(eval):
sample_size, logging_output, oom = 0, {}, False
if self._sample is not None:
try:
......
......@@ -6,6 +6,7 @@
# can be found in the PATENTS file in the same directory.
#
import contextlib
import logging
import os
import torch
......@@ -244,3 +245,10 @@ def rstrip_pad(tensor, pad):
if strip > 0:
return tensor[:-strip]
return tensor
def maybe_no_grad(condition):
if hasattr(torch, 'no_grad') and condition:
return torch.no_grad()
# no-op context manager
return contextlib.ExitStack()
......@@ -35,6 +35,8 @@ def main():
print(args)
use_cuda = torch.cuda.is_available() and not args.cpu
if hasattr(torch, 'set_grad_enabled'):
torch.set_grad_enabled(False)
# Load dataset
if args.replace_unk is None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment