"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "9ee3dd38626624e063a738b220d81ab6df271fdc"
Commit 7da4e062 authored by Myle Ott

Support deprecation of volatile Variables in latest PyTorch

parent 5637d54e
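Background: PyTorch 0.4 deprecated the volatile flag on Variables in favor of the torch.no_grad() context manager and the torch.set_grad_enabled() global switch, which is what this commit accommodates. A minimal sketch of that migration, using illustrative names (model, data) that are not part of the commit:

import torch

# Before PyTorch 0.4, inference inputs were marked volatile so no graph was built:
#   out = model(Variable(data, volatile=True))
# From 0.4 on, the inference region is wrapped in torch.no_grad() instead:
model = torch.nn.Linear(4, 2)
data = torch.randn(1, 4)
with torch.no_grad():
    out = model(data)  # out.requires_grad is False; no autograd graph is built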
@@ -227,20 +227,21 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
         self.model.train()
         self.optimizer.zero_grad()
 
-        sample_size, logging_output, oom = 0, {}, False
-        if self._sample is not None:
-            try:
-                # calculate loss and sample size
-                self.loss, sample_size, logging_output = self.criterion(self.model, self._sample)
-            except RuntimeError as e:
-                if not eval and 'out of memory' in str(e):
-                    print('| WARNING: ran out of memory on GPU #{}, skipping batch'.format(device_id))
-                    oom = True
-                    self.loss = None
-                    if hasattr(torch.cuda, 'empty_cache'):
-                        torch.cuda.empty_cache()
-                else:
-                    raise e
+        with utils.maybe_no_grad(eval):
+            sample_size, logging_output, oom = 0, {}, False
+            if self._sample is not None:
+                try:
+                    # calculate loss and sample size
+                    self.loss, sample_size, logging_output = self.criterion(self.model, self._sample)
+                except RuntimeError as e:
+                    if not eval and 'out of memory' in str(e):
+                        print('| WARNING: ran out of memory on GPU #{}, skipping batch'.format(device_id))
+                        oom = True
+                        self.loss = None
+                        if hasattr(torch.cuda, 'empty_cache'):
+                            torch.cuda.empty_cache()
+                    else:
+                        raise e
 
         return sample_size, logging_output, oom
@@ -6,6 +6,7 @@
 # can be found in the PATENTS file in the same directory.
 #
 
+import contextlib
 import logging
 import os
 import torch

@@ -244,3 +245,10 @@ def rstrip_pad(tensor, pad):
     if strip > 0:
         return tensor[:-strip]
     return tensor
+
+
+def maybe_no_grad(condition):
+    if hasattr(torch, 'no_grad') and condition:
+        return torch.no_grad()
+    # no-op context manager
+    return contextlib.ExitStack()
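maybe_no_grad returns a real torch.no_grad() context on PyTorch builds that have it, and an empty contextlib.ExitStack() otherwise: an ExitStack with nothing registered enters and exits without side effects, making it a no-op context manager (contextlib.nullcontext() would be the direct equivalent, but it only arrived in Python 3.7, which explains the new contextlib import above). A self-contained usage sketch, independent of the repository's utils module:

import contextlib
import torch

def maybe_no_grad(condition):
    if hasattr(torch, 'no_grad') and condition:
        return torch.no_grad()
    # no-op context manager for older PyTorch or condition=False
    return contextlib.ExitStack()

model = torch.nn.Linear(4, 2)
x = torch.randn(3, 4)

with maybe_no_grad(True):   # eval path: no autograd graph
    y = model(x)
print(y.requires_grad)      # False

with maybe_no_grad(False):  # train path: graph built as usual
    y = model(x)
print(y.requires_grad)      # True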
@@ -35,6 +35,8 @@ def main():
     print(args)
 
     use_cuda = torch.cuda.is_available() and not args.cpu
+    if hasattr(torch, 'set_grad_enabled'):
+        torch.set_grad_enabled(False)
 
     # Load dataset
     if args.replace_unk is None: