run_glue.py 9.23 KB
Newer Older
thomwolf's avatar
thomwolf committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Lysandre's avatar
Lysandre committed
16
""" Finetuning the library models for sequence classification on GLUE."""
thomwolf's avatar
thomwolf committed
17
18


Julien Chaumond's avatar
Julien Chaumond committed
19
import dataclasses
thomwolf's avatar
thomwolf committed
20
21
import logging
import os
22
import sys
23
from dataclasses import dataclass, field
24
from typing import Callable, Dict, Optional
thomwolf's avatar
thomwolf committed
25
26
27

import numpy as np

28
29
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer, EvalPrediction, GlueDataset
from transformers import GlueDataTrainingArguments as DataTrainingArguments
30
from transformers import (
31
    HfArgumentParser,
Julien Chaumond's avatar
Julien Chaumond committed
32
    Trainer,
33
    TrainingArguments,
Julien Chaumond's avatar
Julien Chaumond committed
34
35
36
37
    glue_compute_metrics,
    glue_output_modes,
    glue_tasks_num_labels,
    set_seed,
38
)
Aymeric Augustin's avatar
Aymeric Augustin committed
39

thomwolf's avatar
thomwolf committed
40
41
42

# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)

thomwolf's avatar
thomwolf committed
43

44
45
46
47
48
49
50
@dataclass
class ModelArguments:
    """
    Options selecting the pretrained model, config and tokenizer to fine-tune from.

    Only ``model_name_or_path`` is required; every other field defaults to ``None``.
    """

    # Required: identifier or path of the pretrained model to start from.
    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"},
    )
    # Optional override for the config; None means "use model_name_or_path".
    config_name: Optional[str] = field(
        metadata={"help": "Pretrained config name or path if not the same as model_name"},
        default=None,
    )
    # Optional override for the tokenizer; None means "use model_name_or_path".
    tokenizer_name: Optional[str] = field(
        metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"},
        default=None,
    )
    # Optional cache directory passed through to the from_pretrained loaders.
    cache_dir: Optional[str] = field(
        metadata={"help": "Where do you want to store the pretrained models downloaded from s3"},
        default=None,
    )


64
def main():
    """Fine-tune and/or evaluate a sequence-classification model on a GLUE task.

    Arguments are parsed into ``ModelArguments``, ``DataTrainingArguments`` and
    ``TrainingArguments``. They come from the command line, or from a JSON file
    when the script receives exactly one argument ending in ``.json``.

    Depending on ``do_train`` / ``do_eval`` / ``do_predict`` this:
      * trains the model and saves model + tokenizer to ``output_dir``;
      * evaluates on the dev split, writing ``eval_results_<task>.txt``;
      * predicts on the test split, writing ``test_results_<task>.txt``.
    For MNLI, evaluation and prediction also run on the mismatched ("mnli-mm") split.

    Returns:
        dict: the metrics of every evaluation run merged into one dict (empty
        when ``do_eval`` is not set).

    Raises:
        ValueError: if ``output_dir`` is non-empty while training without
            ``overwrite_output_dir``, or if ``task_name`` is not a known GLUE task.
    """
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))

    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Refuse to train into a non-empty output directory unless explicitly allowed.
    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
        )

    # Setup logging: INFO for the main process (local_rank -1 or 0), WARN for the others.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        training_args.local_rank,
        training_args.device,
        training_args.n_gpu,
        bool(training_args.local_rank != -1),
        training_args.fp16,
    )
    logger.info("Training/evaluation parameters %s", training_args)

    # Set seed
    set_seed(training_args.seed)

    try:
        num_labels = glue_tasks_num_labels[data_args.task_name]
        output_mode = glue_output_modes[data_args.task_name]
    except KeyError:
        # The message already names the task; the bare KeyError context adds nothing.
        raise ValueError("Task not found: %s" % (data_args.task_name)) from None

    # Load pretrained model and tokenizer
    #
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        num_labels=num_labels,
        finetuning_task=data_args.task_name,
        cache_dir=model_args.cache_dir,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
    )
    model = AutoModelForSequenceClassification.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
    )

    # Get datasets
    train_dataset = (
        GlueDataset(data_args, tokenizer=tokenizer, cache_dir=model_args.cache_dir) if training_args.do_train else None
    )
    eval_dataset = (
        GlueDataset(data_args, tokenizer=tokenizer, mode="dev", cache_dir=model_args.cache_dir)
        if training_args.do_eval
        else None
    )
    test_dataset = (
        GlueDataset(data_args, tokenizer=tokenizer, mode="test", cache_dir=model_args.cache_dir)
        if training_args.do_predict
        else None
    )

    def build_compute_metrics_fn(task_name: str) -> Callable[[EvalPrediction], Dict]:
        """Return a metrics callback that scores predictions with ``glue_compute_metrics(task_name, ...)``."""

        def compute_metrics_fn(p: EvalPrediction):
            if output_mode == "classification":
                preds = np.argmax(p.predictions, axis=1)
            elif output_mode == "regression":
                preds = np.squeeze(p.predictions)
            else:
                # Without this branch `preds` would be unbound and raise an opaque UnboundLocalError.
                raise ValueError(f"Unsupported output mode: {output_mode}")
            return glue_compute_metrics(task_name, preds, p.label_ids)

        return compute_metrics_fn

    # Initialize our Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=build_compute_metrics_fn(data_args.task_name),
    )

    # Training
    if training_args.do_train:
        trainer.train(
            model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None
        )
        trainer.save_model()
        # For convenience, we also re-save the tokenizer to the same directory,
        # so that you can share your model easily on huggingface.co/models =)
        if trainer.is_world_master():
            tokenizer.save_pretrained(training_args.output_dir)

    # MNLI is additionally evaluated/predicted on its mismatched split ("mnli-mm").
    mnli_mm_data_args = (
        dataclasses.replace(data_args, task_name="mnli-mm") if data_args.task_name == "mnli" else None
    )

    # Evaluation
    eval_results = {}
    if training_args.do_eval:
        logger.info("*** Evaluate ***")

        # Loop to handle MNLI double evaluation (matched, mis-matched)
        eval_datasets = [eval_dataset]
        if mnli_mm_data_args is not None:
            eval_datasets.append(
                GlueDataset(mnli_mm_data_args, tokenizer=tokenizer, mode="dev", cache_dir=model_args.cache_dir)
            )

        for dataset in eval_datasets:
            task_name = dataset.args.task_name
            # Metrics are computed per task name ("mnli" vs "mnli-mm"), so rebuild the callback each time.
            trainer.compute_metrics = build_compute_metrics_fn(task_name)
            eval_result = trainer.evaluate(eval_dataset=dataset)

            output_eval_file = os.path.join(training_args.output_dir, f"eval_results_{task_name}.txt")
            if trainer.is_world_master():
                with open(output_eval_file, "w", encoding="utf-8") as writer:
                    logger.info("***** Eval results {} *****".format(task_name))
                    for key, value in eval_result.items():
                        logger.info("  %s = %s", key, value)
                        writer.write("%s = %s\n" % (key, value))

            # NOTE(review): keys shared between the matched and mismatched runs
            # (presumably the loss) are overwritten by the later run — confirm
            # against Trainer.evaluate before relying on them.
            eval_results.update(eval_result)

    # Prediction on the (unlabeled) test split
    if training_args.do_predict:
        logger.info("*** Test ***")
        test_datasets = [test_dataset]
        if mnli_mm_data_args is not None:
            test_datasets.append(
                GlueDataset(mnli_mm_data_args, tokenizer=tokenizer, mode="test", cache_dir=model_args.cache_dir)
            )

        for dataset in test_datasets:
            task_name = dataset.args.task_name
            predictions = trainer.predict(test_dataset=dataset).predictions
            if output_mode == "classification":
                predictions = np.argmax(predictions, axis=1)
            elif output_mode == "regression":
                # Mirror compute_metrics_fn's squeeze: flatten the (presumably (N, 1))
                # output so every item is a scalar for "%3.3f"; reshape keeps N == 1 as 1-D.
                predictions = np.reshape(predictions, -1)

            output_test_file = os.path.join(training_args.output_dir, f"test_results_{task_name}.txt")
            if trainer.is_world_master():
                # Look the label list up once rather than once per prediction.
                label_list = dataset.get_labels() if output_mode == "classification" else None
                with open(output_test_file, "w", encoding="utf-8") as writer:
                    logger.info("***** Test results {} *****".format(task_name))
                    writer.write("index\tprediction\n")
                    for index, item in enumerate(predictions):
                        if output_mode == "regression":
                            writer.write("%d\t%3.3f\n" % (index, item))
                        else:
                            writer.write("%d\t%s\n" % (index, label_list[item]))
    return eval_results
thomwolf's avatar
thomwolf committed
238
239


Lysandre Debut's avatar
Lysandre Debut committed
240
241
242
243
244
def _mp_fn(index):
    """Per-process entry point for ``xla_spawn`` (TPU training).

    ``index`` is presumably the process ordinal passed by the spawner. It is
    ignored because ``main`` reads all of its configuration from ``sys.argv``.
    """
    del index  # required by the spawn call signature; intentionally unused
    main()


if __name__ == "__main__":
    main()