import sys
from typing import Dict

from transformers import EvalPrediction, HfArgumentParser, TrainingArguments, is_torch_available
from transformers.testing_utils import TestCasePlus, execute_subprocess_async, require_torch_multi_gpu
from transformers.utils import logging


logger = logging.get_logger(__name__)


if is_torch_available():
    import torch
    from torch import nn
    from torch.utils.data.dataset import Dataset

    from transformers import Trainer

    class DummyDataset(Dataset):
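        # Each item is simply its own index, so the order and completeness of predictions can be verified.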
        def __init__(self, length: int = 101):
            self.length = length

        def __len__(self):
            return self.length

        def __getitem__(self, i) -> int:
            return i

    class DummyDataCollator:
        def __call__(self, features):
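            # `features` is a list of dataset indices; echo them back as both inputs and labels.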
            return {"input_ids": torch.tensor(features), "labels": torch.tensor(features)}

    class DummyModel(nn.Module):
        def __init__(self):
            super().__init__()
            # Add some (unused) params otherwise DDP will complain.
            self.fc = nn.Linear(120, 80)

        def forward(self, input_ids, labels=None):
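            # When labels are passed, return a constant loss and echo the inputs as "predictions",
            # so that completeness and ordering of the gathered outputs can be checked.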
            if labels is not None:
                return torch.tensor(0.0, device=input_ids.device), input_ids
            else:
                return input_ids


class TestTrainerDistributed(TestCasePlus):
    @require_torch_multi_gpu
    def test_trainer(self):

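        # Re-launch this same file under torch.distributed.launch with one process per GPU;
        # the __main__ block below performs the actual distributed checks.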
        distributed_args = f"""
            -m torch.distributed.launch
            --nproc_per_node={torch.cuda.device_count()}
            {self.test_file_dir}/test_trainer_distributed.py
        """.split()
        output_dir = self.get_auto_remove_tmp_dir()
        args = f"--output_dir {output_dir}".split()
        cmd = [sys.executable] + distributed_args + args
        execute_subprocess_async(cmd, env=self.get_env())
        # A successful return here means success: any error in the subprocess would have raised in the call above.


if __name__ == "__main__":
    # The script below is meant to be run under torch.distributed, on a machine with multiple GPUs:
    #
    # PYTHONPATH="src" python -m torch.distributed.launch --nproc_per_node 2 ./tests/test_trainer_distributed.py --output_dir output_dir
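    # (On recent PyTorch versions, `torchrun --nproc_per_node 2 ...` can be used instead of
    # `python -m torch.distributed.launch`.)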

    parser = HfArgumentParser((TrainingArguments,))
    training_args = parser.parse_args_into_dataclasses()[0]

    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s",
        training_args.local_rank,
        training_args.device,
        training_args.n_gpu,
        training_args.local_rank != -1,
    )

    # Essentially, what we want to verify in the distributed case is that we get all samples back,
    # in the right order (this is crucial for prediction, for instance).
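    # Use several dataset lengths, including ones that are unlikely to split evenly across processes.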
    for dataset_length in [101, 40, 7]:
        dataset = DummyDataset(dataset_length)

        def compute_metrics(p: EvalPrediction) -> Dict:
            sequential = list(range(len(dataset)))
            success = p.predictions.tolist() == sequential and p.label_ids.tolist() == sequential
            return {"success": success}

        trainer = Trainer(
            model=DummyModel(),
            args=training_args,
            data_collator=DummyDataCollator(),
            eval_dataset=dataset,
            compute_metrics=compute_metrics,
        )
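        # Trainer prefixes metric names with "eval_", hence the "eval_success" key below.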
        metrics = trainer.evaluate()
        logger.info(metrics)
        if metrics["eval_success"] is not True:
            logger.error(metrics)
            sys.exit(1)

        p = trainer.predict(dataset)
        logger.info(p.metrics)
        if p.metrics["eval_success"] is not True:
            logger.error(p.metrics)
            sys.exit(1)

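        # Run the same checks with eval_accumulation_steps=2, which moves accumulated
        # predictions to the CPU every 2 steps; results should be identical.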
        trainer.args.eval_accumulation_steps = 2

        metrics = trainer.evaluate()
        logger.info(metrics)
        if metrics["eval_success"] is not True:
            logger.error(metrics)
            sys.exit(1)

        p = trainer.predict(dataset)
        logger.info(p.metrics)
        if p.metrics["eval_success"] is not True:
            logger.error(p.metrics)
            sys.exit(1)

        trainer.args.eval_accumulation_steps = None