Unverified Commit da005253 authored by Sylvain Gugger, committed by GitHub

Proper map location for optimizer load (#22273)

* Proper map location for optimizer load

* What happened to my code?
parent 786092a3
@@ -2433,8 +2433,12 @@ class Trainer:
                         self.model_wrapped.register_post_step_hook(opt_load_hook)
                 else:
+                    # We use the CPU when training on one GPU to avoid OOM for GPU RAM when training big models.
+                    # In distributed training however, we load directly on each GPU and risk the GPU OOM as it's more
+                    # likely to get OOM on CPU (since we load num_gpu times the optimizer state).
+                    map_location = self.args.device if self.args.world_size > 1 else "cpu"
                     self.optimizer.load_state_dict(
-                        torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu")
+                        torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
                     )
                 with warnings.catch_warnings(record=True) as caught_warnings:
                     self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
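For illustration, a minimal standalone sketch of the same map_location logic outside the Trainer class. The function name, parameters, and the "optimizer.pt" filename below are assumptions chosen for the example; they are not part of this commit, which only changes how `map_location` is picked inside `Trainer`.

    import os
    import torch

    def load_optimizer_state(optimizer, checkpoint_dir, device, world_size):
        # Single-GPU training: deserialize onto CPU first so the loaded state
        # does not compete with the model for GPU RAM on big models.
        # Distributed training: load directly onto each process's GPU, since
        # holding world_size copies of the optimizer state in CPU RAM is the
        # more likely source of OOM.
        map_location = device if world_size > 1 else "cpu"
        state = torch.load(os.path.join(checkpoint_dir, "optimizer.pt"), map_location=map_location)
        optimizer.load_state_dict(state)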