Unverified commit 0d909f6b, authored by Sylvain Gugger and committed by GitHub

Fairscale FSDP fix model save (#10596)

* Hotfix fairscale FSDP

* Evaluation works

* Save on process zero
parent ac17f711
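
The thread running through these changes: under fairscale's ZeRO-DP-2/3 (fully sharded) wrappers, model.state_dict() is a collective call that reassembles the sharded parameters, so every rank has to execute it, while the actual write to disk should only happen on process zero. Running it only on the main process is presumably why the two fully sharded tests below had been skipped as hanging. The following is a minimal sketch of that save pattern outside of Trainer, assuming an initialized torch.distributed process group and an already wrapped model; the helper name and the weights filename are illustrative, not taken from this diff.

import os

import torch
import torch.distributed as dist


def save_sharded_checkpoint(wrapped_model, output_dir):
    # Hypothetical helper, not part of the Trainer API.
    # Under ZeRO-DP-2/3, state_dict() is a collective that gathers the full
    # (unsharded) parameters, so every rank must call it ...
    state_dict = wrapped_model.state_dict()

    # ... but only the main process writes the consolidated weights to disk.
    if not dist.is_initialized() or dist.get_rank() == 0:
        os.makedirs(output_dir, exist_ok=True)
        torch.save(state_dict, os.path.join(output_dir, "pytorch_model.bin"))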
@@ -66,7 +66,7 @@ def require_apex(test_case):
 class TestTrainerExt(TestCasePlus):
-    def run_seq2seq_quick(self, distributed=False, extra_args_str=None, eval=True, predict_with_generate=True):
+    def run_seq2seq_quick(self, distributed=False, extra_args_str=None, predict_with_generate=True):
         output_dir = self.run_trainer(
             eval_steps=1,
             max_len=12,
@@ -83,9 +83,9 @@ class TestTrainerExt(TestCasePlus):
         if predict_with_generate:
             assert "eval_bleu" in first_step_stats
             last_step_stats = eval_metrics[-1]
             assert isinstance(last_step_stats["eval_bleu"], float)
             assert not math.isnan(float(last_step_stats["eval_loss"])), "eval_loss must not be `nan`"

     @require_torch_non_multi_gpu
     def test_run_seq2seq_no_dist(self):
@@ -116,14 +116,12 @@ class TestTrainerExt(TestCasePlus):
     # test --sharded_ddp zero_dp_2 w/o --fp16
     @require_torch_multi_gpu
     @require_fairscale
-    @unittest.skip("XXX: Fixme: hanging")
     def test_run_seq2seq_fully_sharded_ddp(self):
         self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp zero_dp_2", predict_with_generate=False)

     # test --sharded_ddp zero_dp_2 w/ --fp16
     @require_torch_multi_gpu
     @require_fairscale
-    @unittest.skip("XXX: Fixme: hanging")
     def test_run_seq2seq_fully_sharded_ddp_fp16(self):
         self.run_seq2seq_quick(
             distributed=True, extra_args_str="--sharded_ddp zero_dp_2 --fp16", predict_with_generate=False
@@ -206,8 +204,8 @@ class TestTrainerExt(TestCasePlus):
             --warmup_steps 8
             --evaluation_strategy steps
             --logging_steps 0
-            --save_steps {str(eval_steps)}
             --eval_steps {str(eval_steps)}
+            --save_steps {str(eval_steps)}
             --group_by_length
             --label_smoothing_factor 0.1
             --adafactor
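
For reference, the code path these re-enabled tests exercise can also be selected from user code rather than the CLI flag. A hedged sketch, assuming a transformers version from this era where TrainingArguments accepts sharded_ddp as a string; the output path is a placeholder, and an actual run additionally needs a model, datasets, fairscale, and a multi-GPU torch.distributed launch.

from transformers import TrainingArguments

# Placeholder values; constructing the arguments does not start training.
args = TrainingArguments(
    output_dir="out",
    sharded_ddp="zero_dp_2",  # same string the tests pass via --sharded_ddp (or "zero_dp_3")
)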
@@ -1497,11 +1497,14 @@ class Trainer:
         """
         if is_torch_tpu_available():
             self._save_tpu(output_dir)
-        else:
+        elif (
+            ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp
+        ):
+            state_dict = self.model.state_dict()
             if self.is_world_process_zero():
-                self._save(output_dir)
-            if self.args.local_rank != -1:
-                dist.barrier()
+                self._save(output_dir, state_dict=state_dict)
+        elif self.is_world_process_zero():
+            self._save(output_dir)

     def _save_tpu(self, output_dir: Optional[str] = None):
         output_dir = output_dir if output_dir is not None else self.args.output_dir
@@ -1531,7 +1534,7 @@ class Trainer:
         if self.tokenizer is not None and self.is_world_process_zero():
             self.tokenizer.save_pretrained(output_dir)

-    def _save(self, output_dir: Optional[str] = None):
+    def _save(self, output_dir: Optional[str] = None, state_dict=None):
         # If we are executing this function, we are the process zero, so we don't check for that.
         output_dir = output_dir if output_dir is not None else self.args.output_dir
         os.makedirs(output_dir, exist_ok=True)
@@ -1540,13 +1543,16 @@ class Trainer:
         # They can then be reloaded using `from_pretrained()`
         if not isinstance(self.model, PreTrainedModel):
             if isinstance(unwrap_model(self.model), PreTrainedModel):
-                unwrap_model(self.model).save_pretrained(output_dir, state_dict=self.model.state_dict())
+                if state_dict is None:
+                    state_dict = self.model.state_dict()
+                unwrap_model(self.model).save_pretrained(output_dir, state_dict=state_dict)
             else:
                 logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
-                state_dict = self.model.state_dict()
+                if state_dict is None:
+                    state_dict = self.model.state_dict()
                 torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
         else:
-            self.model.save_pretrained(output_dir)
+            self.model.save_pretrained(output_dir, state_dict=state_dict)

         if self.tokenizer is not None:
             self.tokenizer.save_pretrained(output_dir)
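
Because _save now accepts a pre-gathered state_dict (and still falls back to self.model.state_dict() when it is None), a checkpoint written under zero_dp_2/zero_dp_3 ends up as an ordinary consolidated weights file, so it reloads exactly as the from_pretrained() comment above promises. A minimal reload sketch; the checkpoint path and auto class are illustrative.

from transformers import AutoModelForSeq2SeqLM

# Illustrative path: whatever --output_dir / Trainer.save_model() wrote.
model = AutoModelForSeq2SeqLM.from_pretrained("path/to/output_dir")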