Unverified Commit b7036f49 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Load optimizer state on CPU to avoid CUDA OOM (#22159)

parent ebdb185b
...@@ -2416,7 +2416,6 @@ class Trainer: ...@@ -2416,7 +2416,6 @@ class Trainer:
self.optimizer.load_state_dict(optimizer_state) self.optimizer.load_state_dict(optimizer_state)
self.lr_scheduler.load_state_dict(lr_scheduler_state) self.lr_scheduler.load_state_dict(lr_scheduler_state)
else: else:
map_location = "cpu" if is_sagemaker_mp_enabled() else self.args.device
if is_sagemaker_mp_enabled(): if is_sagemaker_mp_enabled():
if os.path.isfile(os.path.join(checkpoint, "user_content.pt")): if os.path.isfile(os.path.join(checkpoint, "user_content.pt")):
# Optimizer checkpoint was saved with smp >= 1.10 # Optimizer checkpoint was saved with smp >= 1.10
...@@ -2436,7 +2435,7 @@ class Trainer: ...@@ -2436,7 +2435,7 @@ class Trainer:
self.model_wrapped.register_post_step_hook(opt_load_hook) self.model_wrapped.register_post_step_hook(opt_load_hook)
else: else:
self.optimizer.load_state_dict( self.optimizer.load_state_dict(
torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location) torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu")
) )
with warnings.catch_warnings(record=True) as caught_warnings: with warnings.catch_warnings(record=True) as caught_warnings:
self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME))) self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment