Unverified commit c60e0e1e, authored by Stas Bekman, committed by GitHub
Browse files

deepspeed + grad acumm (#9622)

parent 6d3b688b
@@ -112,6 +112,11 @@ class TestFinetuneTrainer(TestCasePlus):
    def test_finetune_trainer_deepspeed(self):
        self.finetune_trainer_quick(deepspeed=True)
@require_torch_multi_gpu
@require_deepspeed
def test_finetune_trainer_deepspeed_grad_acum(self):
    # Regression test for #9622: training under DeepSpeed must also work when
    # gradient accumulation is enabled. Skipped unless multiple GPUs and the
    # deepspeed package are available (per the decorators above).
    self.finetune_trainer_quick(deepspeed=True, extra_args_str="--gradient_accumulation_steps 2")
    @slow
    def test_finetune_trainer_slow(self):
        # There is a missing call to __init__process_group somewhere
......
@@ -931,7 +931,9 @@ class Trainer:
                 )
                 # Optimizer step
-                if is_torch_tpu_available():
+                if self.deepspeed:
+                    self.deepspeed.step()
+                elif is_torch_tpu_available():
                     xm.optimizer_step(self.optimizer)
                 elif self.use_amp:
                     self.scaler.step(self.optimizer)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment