# Because multi-GPU tests are complex and can impact other tests, and to aid debugging, we keep them in a separate module.

import os
import sys

from transformers.testing_utils import (
    TestCasePlus,
    execute_subprocess_async,
    get_gpu_count,
    require_torch_gpu,
    require_torch_multi_gpu,
    slow,
)

from .test_seq2seq_examples import CHEAP_ARGS, make_test_data_dir
from .utils import load_json


class TestSummarizationDistillerMultiGPU(TestCasePlus):
    @classmethod
    def setUpClass(cls):
        return cls

    @require_torch_multi_gpu
    def test_multi_gpu(self):
        updates = dict(
            no_teacher=True,
            freeze_encoder=True,
            gpus=2,
            overwrite_output_dir=True,
            sortish_sampler=True,
        )
        self._test_distiller_cli_fork(updates, check_contents=False)

    def _test_distiller_cli_fork(self, updates, check_contents=True):
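        """Run distillation.py in a subprocess with the given arg overrides and,
        when check_contents is True, validate the files and metrics it writes."""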
        default_updates = dict(
            label_smoothing=0.0,
            early_stopping_patience=-1,
            train_batch_size=1,
            eval_batch_size=2,
            max_epochs=2,
            alpha_mlm=0.2,
            alpha_ce=0.8,
            do_predict=True,
            model_name_or_path="sshleifer/tinier_bart",
            teacher=CHEAP_ARGS["model_name_or_path"],
            val_check_interval=0.5,
        )
        default_updates.update(updates)
        args_d: dict = CHEAP_ARGS.copy()
        tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
        output_dir = self.get_auto_remove_tmp_dir()
        args_d.update(data_dir=tmp_dir, output_dir=output_dir, **default_updates)

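        # turn the args dict into CLI flags: drop unsupported keys and False/None
        # values, render True as a bare --flag, everything else as --key=value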
        def convert(k, v):
            if k in ["tgt_suffix", "server_ip", "server_port", "out", "n_tpu_cores"]:
                return ""
            if v is False or v is None:
                return ""
            if v is True:  # or len(str(v))==0:
                return f"--{k}"
            return f"--{k}={v}"

        cli_args = [x for x in (convert(k, v) for k, v in args_d.items()) if len(x)]
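        # launch distillation.py in a separate python process (the `_fork` in the
        # name); execute_subprocess_async waits for it and raises on failure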
        cmd = [sys.executable, f"{self.test_file_dir}/distillation.py"] + cli_args
        execute_subprocess_async(cmd, env=self.get_env())

        # the caller may only want to verify that the multi-gpu run completes
        if not check_contents:
            return

        contents = os.listdir(output_dir)
        contents = {os.path.basename(p) for p in contents}
        ckpt_files = [p for p in contents if p.endswith("ckpt")]
        assert len(ckpt_files) > 0

        self.assertIn("test_generations.txt", contents)
        self.assertIn("test_results.txt", contents)

        # these values normally come from the training module; hardcode them here since we don't have access to `model`
        metrics_save_path = os.path.join(output_dir, "metrics.json")
        val_metric = "rouge2"

        metrics = load_json(metrics_save_path)
        # {'test': [{'test_avg_loss': 10.63731575012207, 'test_avg_rouge1': 0.0, 'test_avg_rouge2': 0.0, 'test_avg_rougeL': 0.0, 'test_avg_gen_time': 0.1822289228439331, 'test_avg_gen_len': 142.0, 'step_count': 1}]}
        print(metrics)
        last_step_stats = metrics["val"][-1]
        self.assertGreaterEqual(last_step_stats["val_avg_gen_time"], 0.01)
        self.assertIsInstance(last_step_stats[f"val_avg_{val_metric}"], float)
        self.assertEqual(len(metrics["test"]), 1)
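        # with max_epochs=2 and val_check_interval=0.5 this works out to
        # int(2 * (1 / 0.5) / 2 + 1) = 3 recorded validation runs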
        desired_n_evals = int(args_d["max_epochs"] * (1 / args_d["val_check_interval"]) / 2 + 1)
        self.assertEqual(len(metrics["val"]), desired_n_evals)

    @slow
    @require_torch_gpu
    def test_distributed_eval(self):
        output_dir = self.get_auto_remove_tmp_dir()
        args = f"""
            --model_name Helsinki-NLP/opus-mt-en-ro
            --save_dir {output_dir}
            --data_dir {self.test_file_dir_str}/test_data/wmt_en_ro
            --num_beams 2
            --task translation
        """.split()

        # this test should run even when only one GPU is available, but if more are present we use them all
        n_gpu = get_gpu_count()
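        # torch.distributed.launch spawns one run_distributed_eval.py process per GPU (nproc_per_node)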
        distributed_args = f"""
            -m torch.distributed.launch
            --nproc_per_node={n_gpu}
            {self.test_file_dir}/run_distributed_eval.py
        """.split()
        cmd = [sys.executable] + distributed_args + args
        execute_subprocess_async(cmd, env=self.get_env())

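        # run_distributed_eval.py writes the aggregated BLEU score to <save_dir>/test_bleu.json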
        metrics_save_path = os.path.join(output_dir, "test_bleu.json")
        metrics = load_json(metrics_save_path)
        # print(metrics)
        self.assertGreaterEqual(metrics["bleu"], 25)