Unverified Commit 9d7d0005 authored by Stas Bekman's avatar Stas Bekman Committed by GitHub
Browse files

[training] SAVE_STATE_WARNING was removed in pytorch (#8979)

* [training] SAVE_STATE_WARNING was removed in pytorch

FYI `SAVE_STATE_WARNING` has been removed 3 days ago: pytorch/pytorch#46813

Fixes: #8232

@sgugger

* style, but add () to prevent autoformatters from botching it

* switch to try/except

* cleanup
parent 2ae7388e
...@@ -23,7 +23,6 @@ from typing import List, Optional, Union ...@@ -23,7 +23,6 @@ from typing import List, Optional, Union
import numpy as np import numpy as np
import torch import torch
from packaging import version
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, Sampler from torch.utils.data.sampler import RandomSampler, Sampler
...@@ -34,10 +33,11 @@ from .utils import logging ...@@ -34,10 +33,11 @@ from .utils import logging
if is_torch_tpu_available(): if is_torch_tpu_available():
import torch_xla.core.xla_model as xm import torch_xla.core.xla_model as xm
if version.parse(torch.__version__) <= version.parse("1.4.1"): # this is used to supress an undesired warning emitted by pytorch versions 1.4.2-1.7.0
SAVE_STATE_WARNING = "" try:
else:
from torch.optim.lr_scheduler import SAVE_STATE_WARNING from torch.optim.lr_scheduler import SAVE_STATE_WARNING
except ImportError:
SAVE_STATE_WARNING = ""
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment