Unverified Commit b7036f49 authored by Sylvain Gugger, committed by GitHub

Load optimizer state on CPU to avoid CUDA OOM (#22159)

parent ebdb185b
@@ -2416,7 +2416,6 @@ class Trainer:
                 self.optimizer.load_state_dict(optimizer_state)
                 self.lr_scheduler.load_state_dict(lr_scheduler_state)
             else:
-                map_location = "cpu" if is_sagemaker_mp_enabled() else self.args.device
                 if is_sagemaker_mp_enabled():
                     if os.path.isfile(os.path.join(checkpoint, "user_content.pt")):
                         # Optimizer checkpoint was saved with smp >= 1.10
@@ -2436,7 +2435,7 @@ class Trainer:
                     self.model_wrapped.register_post_step_hook(opt_load_hook)
                 else:
                     self.optimizer.load_state_dict(
-                        torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
+                        torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu")
                     )
             with warnings.catch_warnings(record=True) as caught_warnings:
                 self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
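
Background (not part of the commit): torch.load with map_location="cpu" deserializes the checkpoint into host RAM rather than onto the GPU it was saved from, and optimizer.load_state_dict then casts each state tensor to the device of its matching parameter, so the full optimizer state never lives twice on the GPU at once. A minimal sketch of the pattern, assuming a CUDA device is available and using a hypothetical checkpoint path "optimizer.pt":

import torch

# Assumes a CUDA device; "optimizer.pt" is a hypothetical path for illustration.
model = torch.nn.Linear(10, 10).cuda()
optimizer = torch.optim.AdamW(model.parameters())

# One dummy step so the optimizer actually holds state (exp_avg, exp_avg_sq).
model(torch.randn(4, 10, device="cuda")).sum().backward()
optimizer.step()

torch.save(optimizer.state_dict(), "optimizer.pt")

# On resume: deserialize to CPU first. load_state_dict moves each state tensor
# to the device/dtype of its matching parameter, keeping peak GPU memory low.
state = torch.load("optimizer.pt", map_location="cpu")
optimizer.load_state_dict(state)

Loading straight to self.args.device, as the old code did, briefly holds both the deserialized checkpoint and the live optimizer state on the GPU, which is what triggered the CUDA OOM this commit avoids.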