Unverified Commit f2b91835 authored by Sourab Mangrulkar, committed by GitHub

fix bugs with trainer (#24134)



* fix the deepspeed test failures

* apex fix

* FSDP save ckpt fix

* Update src/transformers/trainer.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

---------
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent be10092e
@@ -1749,7 +1749,16 @@ class Trainer:
         # prepare using `accelerator` prepare
         if use_accelerator_prepare:
-            model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
+            if hasattr(self.lr_scheduler, "step"):
+                if self.use_apex:
+                    model = self.accelerator.prepare(self.model)
+                else:
+                    model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
+            else:
+                # to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config.
+                model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
+                    self.model, self.optimizer, self.lr_scheduler
+                )
 
         if self.is_fsdp_enabled:
             self.model = model
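
For context on this hunk: when the learning-rate scheduler is specified in the DeepSpeed config rather than built by the Trainer, `self.lr_scheduler` holds accelerate's `DummyScheduler` placeholder, which has no `step` method; that is what the `hasattr` check detects, and the placeholder must be passed through `accelerator.prepare` so DeepSpeed can substitute the real scheduler built from its config. Under apex, `amp.initialize` has already wrapped the optimizer, so only the model is prepared. A minimal sketch of the same dispatch outside the Trainer (the helper name `prepare_for_training` is hypothetical, not Trainer API):

```python
import torch
from accelerate import Accelerator

def prepare_for_training(accelerator, model, optimizer, lr_scheduler, use_apex=False):
    # Mirrors the branching in the hunk above (a sketch, not the Trainer method).
    if hasattr(lr_scheduler, "step"):
        if use_apex:
            # apex path: amp.initialize() already wraps the optimizer,
            # so only the model goes through prepare().
            model = accelerator.prepare(model)
        else:
            model, optimizer = accelerator.prepare(model, optimizer)
    else:
        # DummyScheduler placeholder (scheduler defined in the DeepSpeed
        # config): pass it through prepare() so DeepSpeed can build and
        # return the real scheduler.
        model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)
    return model, optimizer, lr_scheduler

# Usage with a regular PyTorch scheduler, which does have `step`:
accelerator = Accelerator()
net = torch.nn.Linear(8, 2)
opt = torch.optim.AdamW(net.parameters(), lr=1e-3)
sched = torch.optim.lr_scheduler.LambdaLR(opt, lambda _: 1.0)
net, opt, sched = prepare_for_training(accelerator, net, opt, sched)
```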
@@ -2841,6 +2850,7 @@ class Trainer:
             or self.is_fsdp_enabled
         ):
             if self.is_fsdp_enabled:
+                os.makedirs(output_dir, exist_ok=True)
                 self.accelerator.state.fsdp_plugin.save_model(self.accelerator, self.model, output_dir)
             else:
                 state_dict = self.model.state_dict()
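
This second hunk is the "FSDP save ckpt fix" from the commit message: `fsdp_plugin.save_model` writes the (possibly sharded) state dict into `output_dir`, and on this code path the directory may not have been created yet, so the save step needs `os.makedirs(..., exist_ok=True)` first. A sketch of the fixed save path (the wrapper name `save_fsdp_checkpoint` is hypothetical; the `fsdp_plugin.save_model` call matches the diff above):

```python
import os

def save_fsdp_checkpoint(accelerator, model, output_dir):
    # Ensure the target directory exists before the FSDP plugin writes
    # the (possibly sharded) state dict into it.
    os.makedirs(output_dir, exist_ok=True)
    accelerator.state.fsdp_plugin.save_model(accelerator, model, output_dir)
```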