Unverified Commit 855b95ce authored by amyeroberts's avatar amyeroberts Committed by GitHub
Browse files

Safe import of LRScheduler (#29919)



* Safe import of LRScheduler

* Update src/transformers/trainer_pt_utils.py
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/trainer_pt_utils.py
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>

* Fix up

---------
Co-authored-by: default avatarArthur <48595927+ArthurZucker@users.noreply.github.com>
parent c9d2e855
......@@ -34,13 +34,18 @@ import numpy as np
import torch
import torch.distributed as dist
from torch import nn
from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import Dataset, IterableDataset, RandomSampler, Sampler
from torch.utils.data.distributed import DistributedSampler
from .integrations.deepspeed import is_deepspeed_zero3_enabled
from .tokenization_utils_base import BatchEncoding
from .utils import is_sagemaker_mp_enabled, is_torch_xla_available, is_training_run_on_sagemaker, logging
from .utils import (
is_sagemaker_mp_enabled,
is_torch_available,
is_torch_xla_available,
is_training_run_on_sagemaker,
logging,
)
if is_training_run_on_sagemaker():
......@@ -49,6 +54,15 @@ if is_training_run_on_sagemaker():
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
if is_torch_available():
from .pytorch_utils import is_torch_greater_or_equal_than_2_0
if is_torch_greater_or_equal_than_2_0:
from torch.optim.lr_scheduler import LRScheduler
else:
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
# this is used to suppress an undesired warning emitted by pytorch versions 1.4.2-1.7.0
try:
from torch.optim.lr_scheduler import SAVE_STATE_WARNING
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment