Unverified Commit 7f34d757 authored by Stas Bekman, committed by GitHub

[s2s trainer] fix DP mode (#8823)

* fix DP case on multi-gpu

* make executable

* test all 3 modes

* use the correct check for distributed

* dp doesn't need a special case

* restore original name

* cleanup
parent d8fc26e9
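
The core of the fix is the check used to decide when the data loader needs a distributed sampler. Below is a minimal sketch, not taken from the commit, of how the two multi-GPU modes are told apart using TrainingArguments-style fields (the helper name and the standalone function are illustrative assumptions):

# Sketch only: DataParallel (DP) drives all GPUs from a single process, so
# local_rank stays -1 even though n_gpu > 1; DistributedDataParallel (DDP)
# launches one process per GPU and each process gets a real local_rank.
def needs_distributed_sampler(local_rank: int, n_gpu: int) -> bool:
    """Only DDP needs each process to read a distinct shard of the data."""
    # old, incorrect check: return n_gpu > 1  (also True for DP)
    return local_rank != -1

assert needs_distributed_sampler(local_rank=-1, n_gpu=1) is False  # single GPU / CPU
assert needs_distributed_sampler(local_rank=-1, n_gpu=4) is False  # DP: one process, many GPUs
assert needs_distributed_sampler(local_rank=0, n_gpu=1) is True    # DDP: one process per GPU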
+#!/usr/bin/env python
 import logging
 import os
 import sys
...
@@ -122,7 +122,8 @@ class Seq2SeqTrainer(Trainer):
         else:
             if self.args.sortish_sampler:
                 self.train_dataset.make_sortish_sampler(
-                    self.args.per_device_train_batch_size, distributed=self.args.n_gpu > 1
+                    self.args.per_device_train_batch_size,
+                    distributed=(self.args.local_rank != -1),
                 )
 
             return (
...
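
For intuition on why the flag matters, the sketch below uses plain PyTorch with num_replicas and rank supplied by hand (so no process group is required); it shows that a distributed sampler shards the dataset across ranks, which is what DDP needs and DP must not do:

from torch.utils.data import DistributedSampler, SequentialSampler

dataset = list(range(8))  # stand-in for the real training dataset

# DP / single process: one sampler walks the whole dataset
dp_indices = list(SequentialSampler(dataset))
assert dp_indices == [0, 1, 2, 3, 4, 5, 6, 7]

# DDP with 2 processes: each rank only sees its own shard
rank0 = sorted(DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=False))
rank1 = sorted(DistributedSampler(dataset, num_replicas=2, rank=1, shuffle=False))
assert sorted(rank0 + rank1) == dp_indices  # together the shards cover everything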
@@ -4,7 +4,14 @@ from unittest.mock import patch
 from transformers import BertTokenizer, EncoderDecoderModel
 from transformers.file_utils import is_datasets_available
-from transformers.testing_utils import TestCasePlus, execute_subprocess_async, get_gpu_count, slow
+from transformers.testing_utils import (
+    TestCasePlus,
+    execute_subprocess_async,
+    get_gpu_count,
+    require_torch_multi_gpu,
+    require_torch_non_multi_gpu,
+    slow,
+)
 from transformers.trainer_callback import TrainerState
 from transformers.trainer_utils import set_seed
@@ -18,17 +25,32 @@ MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"
 class TestFinetuneTrainer(TestCasePlus):
-    def test_finetune_trainer(self):
-        output_dir = self.run_trainer(1, "12", MBART_TINY, 1)
+    def finetune_trainer_quick(self, distributed=None):
+        output_dir = self.run_trainer(1, "12", MBART_TINY, 1, distributed)
         logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
         eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
         first_step_stats = eval_metrics[0]
         assert "eval_bleu" in first_step_stats
 
+    @require_torch_non_multi_gpu
+    def test_finetune_trainer_no_dist(self):
+        self.finetune_trainer_quick()
+
+    # the following 2 tests verify that the trainer can handle distributed and non-distributed with n_gpu > 1
+    @require_torch_multi_gpu
+    def test_finetune_trainer_dp(self):
+        self.finetune_trainer_quick(distributed=False)
+
+    @require_torch_multi_gpu
+    def test_finetune_trainer_ddp(self):
+        self.finetune_trainer_quick(distributed=True)
+
     @slow
     def test_finetune_trainer_slow(self):
         # There is a missing call to __init__process_group somewhere
-        output_dir = self.run_trainer(eval_steps=2, max_len="128", model_name=MARIAN_MODEL, num_train_epochs=10)
+        output_dir = self.run_trainer(
+            eval_steps=2, max_len="128", model_name=MARIAN_MODEL, num_train_epochs=10, distributed=False
+        )
 
         # Check metrics
         logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
...
@@ -158,7 +180,9 @@ class TestFinetuneTrainer(TestCasePlus):
         # start training
         trainer.train()
 
-    def run_trainer(self, eval_steps: int, max_len: str, model_name: str, num_train_epochs: int):
+    def run_trainer(
+        self, eval_steps: int, max_len: str, model_name: str, num_train_epochs: int, distributed: bool = False
+    ):
         data_dir = self.examples_dir / "seq2seq/test_data/wmt_en_ro"
         output_dir = self.get_auto_remove_tmp_dir()
         args = f"""
...
@@ -193,8 +217,8 @@ class TestFinetuneTrainer(TestCasePlus):
         """.split()
         # --eval_beams 2
 
-        n_gpu = get_gpu_count()
-        if n_gpu > 1:
+        if distributed:
+            n_gpu = get_gpu_count()
             distributed_args = f"""
                 -m torch.distributed.launch
                 --nproc_per_node={n_gpu}
...
@@ -203,7 +227,6 @@ class TestFinetuneTrainer(TestCasePlus):
             cmd = [sys.executable] + distributed_args + args
             execute_subprocess_async(cmd, env=self.get_env())
         else:
-            # 0 or 1 gpu
             testargs = ["finetune_trainer.py"] + args
             with patch.object(sys, "argv", testargs):
                 main()
...
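
To put the three new tests in perspective, here is a rough sketch of how the same script ends up being launched in each mode; the script path and training arguments are placeholders, not values from the diff:

import sys

script = "finetune_trainer.py"        # placeholder path
args = ["--output_dir", "output"]     # placeholder training args

# no_dist: CPU or a single GPU, plain python invocation
no_dist_cmd = [sys.executable, script] + args

# DP: still a plain python invocation; the Trainer wraps the model in
# torch.nn.DataParallel on its own once it sees n_gpu > 1 and local_rank == -1
dp_cmd = [sys.executable, script] + args

# DDP: the launcher spawns one process per GPU and passes --local_rank to each
n_gpu = 2  # e.g. get_gpu_count()
ddp_cmd = [sys.executable, "-m", "torch.distributed.launch", f"--nproc_per_node={n_gpu}", script] + args

for cmd in (no_dist_cmd, dp_cmd, ddp_cmd):
    print(" ".join(cmd))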