"git@developer.sourcefind.cn:change/sglang.git" did not exist on "e3046ea3a8189aa897a24428da94af67a10a0ee1"
Commit 7da4e062 authored by Myle Ott's avatar Myle Ott
Browse files

Support deprecation of volatile Variables in latest PyTorch

parent 5637d54e
...@@ -227,6 +227,7 @@ class MultiprocessingTrainer(MultiprocessingEventLoop): ...@@ -227,6 +227,7 @@ class MultiprocessingTrainer(MultiprocessingEventLoop):
self.model.train() self.model.train()
self.optimizer.zero_grad() self.optimizer.zero_grad()
with utils.maybe_no_grad(eval):
sample_size, logging_output, oom = 0, {}, False sample_size, logging_output, oom = 0, {}, False
if self._sample is not None: if self._sample is not None:
try: try:
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
# can be found in the PATENTS file in the same directory. # can be found in the PATENTS file in the same directory.
# #
import contextlib
import logging import logging
import os import os
import torch import torch
...@@ -244,3 +245,10 @@ def rstrip_pad(tensor, pad): ...@@ -244,3 +245,10 @@ def rstrip_pad(tensor, pad):
if strip > 0: if strip > 0:
return tensor[:-strip] return tensor[:-strip]
return tensor return tensor
def maybe_no_grad(condition):
if hasattr(torch, 'no_grad') and condition:
return torch.no_grad()
# no-op context manager
return contextlib.ExitStack()
...@@ -35,6 +35,8 @@ def main(): ...@@ -35,6 +35,8 @@ def main():
print(args) print(args)
use_cuda = torch.cuda.is_available() and not args.cpu use_cuda = torch.cuda.is_available() and not args.cpu
if hasattr(torch, 'set_grad_enabled'):
torch.set_grad_enabled(False)
# Load dataset # Load dataset
if args.replace_unk is None: if args.replace_unk is None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment