Unverified Commit 7563d5a3 authored by Sylvain Gugger, committed by GitHub

Catch PyTorch warning when saving/loading scheduler (#7401)

parent 1749ca31
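Background for the change: when a learning-rate scheduler's state is saved or loaded without the optimizer's, some PyTorch 1.x releases emit a `UserWarning` telling the user to save or load the optimizer state as well. `Trainer` already checkpoints both, so the warning is pure noise; this commit records warnings around the scheduler calls and re-raises everything except that known message (kept verbatim below, including PyTorch's own "optimzer" typo, so the string comparison can match). As a reminder of the stdlib mechanism involved, a minimal sketch that is illustrative and not part of the commit:

```python
import warnings

with warnings.catch_warnings(record=True) as caught_warnings:
    # Inside this block, warnings are appended to caught_warnings
    # instead of being printed.
    warnings.warn("recorded, not printed", UserWarning)

print(len(caught_warnings))             # 1
print(caught_warnings[0].category)      # <class 'UserWarning'>
print(str(caught_warnings[0].message))  # recorded, not printed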
@@ -59,6 +59,8 @@ from .utils import logging
 _use_native_amp = False
 _use_apex = False
+PT_LR_SCHEDULER_WARNING = "Please also save or load the state of the optimzer when saving or loading the scheduler."
+
 # Check if Pytorch version >= 1.6 to switch between Native AMP and Apex
 if version.parse(torch.__version__) < version.parse("1.6"):
     from .file_utils import is_apex_available
@@ -99,6 +101,14 @@ if is_ray_available():
 logger = logging.get_logger(__name__)

+def reissue_pt_warnings(caught_warnings):
+    # Reissue warnings that are not the PT_LR_SCHEDULER_WARNING.
+    # A recorded w.message is a Warning instance, so compare its str() form.
+    if len(caught_warnings) > 1:
+        for w in caught_warnings:
+            if w.category != UserWarning or str(w.message) != PT_LR_SCHEDULER_WARNING:
+                warnings.warn(w.message, w.category)
+
 @contextmanager
 def torch_distributed_zero_first(local_rank: int):
     """
@@ -643,7 +653,9 @@ class Trainer:
             self.optimizer.load_state_dict(
                 torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)
             )
-            self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))
+            with warnings.catch_warnings(record=True) as caught_warnings:
+                self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))
+            reissue_pt_warnings(caught_warnings)

         model = self.model
         if self.args.fp16 and _use_apex:
......@@ -821,10 +833,14 @@ class Trainer:
if is_torch_tpu_available():
xm.rendezvous("saving_optimizer_states")
xm.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
with warnings.catch_warnings(record=True) as caught_warnings:
xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
reissue_pt_warnings(caught_warnings)
elif self.is_world_process_zero():
torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
with warnings.catch_warnings(record=True) as caught_warnings:
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
reissue_pt_warnings(caught_warnings)
epoch_pbar.update(1)
if self.args.max_steps > 0 and self.global_step >= self.args.max_steps:
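One placement detail worth noting in the hunks above: the warning is raised by the scheduler's `state_dict()` / `load_state_dict()`, not by `torch.save` / `torch.load`, so those scheduler calls must sit inside the recording block. A self-contained check of the round trip, assuming the constant and helper from this commit are in scope and a PyTorch release that still emits the warning (on newer releases nothing is caught and the wrapper is a harmless no-op):

```python
import os
import tempfile
import warnings

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR

# A toy optimizer/scheduler pair standing in for Trainer's.
optimizer = SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
scheduler = LambdaLR(optimizer, lambda step: 1.0)

with tempfile.TemporaryDirectory() as output_dir:
    # state_dict() is the call that warns, so it runs inside the block.
    with warnings.catch_warnings(record=True) as caught_warnings:
        torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
    reissue_pt_warnings(caught_warnings)  # the known warning stays silent

    with warnings.catch_warnings(record=True) as caught_warnings:
        scheduler.load_state_dict(torch.load(os.path.join(output_dir, "scheduler.pt")))
    reissue_pt_warnings(caught_warnings)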