run_ner.py 15.3 KB
Newer Older
1
# coding=utf-8
2
# Copyright 2020 The HuggingFace Team All rights reserved.
3
4
5
6
7
8
9
10
11
12
13
14
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15
16
17
18
19
"""
Fine-tuning the library models for token classification.
"""
# You can also adapt this script to your own token classification task and datasets. Pointers for this are left as comments.

20
21
import logging
import os
22
import sys
Julien Chaumond's avatar
Julien Chaumond committed
23
from dataclasses import dataclass, field
24
from typing import Optional
25
26

import numpy as np
27
from datasets import load_dataset
28
from seqeval.metrics import accuracy_score, f1_score, precision_score, recall_score
Aymeric Augustin's avatar
Aymeric Augustin committed
29

30
import transformers
Aymeric Augustin's avatar
Aymeric Augustin committed
31
from transformers import (
32
33
34
    AutoConfig,
    AutoModelForTokenClassification,
    AutoTokenizer,
35
    DataCollatorForTokenClassification,
Julien Chaumond's avatar
Julien Chaumond committed
36
37
38
39
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    set_seed,
Aymeric Augustin's avatar
Aymeric Augustin committed
40
)
41
from transformers.trainer_utils import is_main_process
Aymeric Augustin's avatar
Aymeric Augustin committed
42
43


44
45
46
logger = logging.getLogger(__name__)


Julien Chaumond's avatar
Julien Chaumond committed
47
48
49
50
51
def _help_field(help_text, **field_kwargs):
    """Build a dataclass ``field`` whose ``--help`` text is ``help_text``.

    Any extra keyword arguments (for example ``default``) are forwarded unchanged to
    ``dataclasses.field``.
    """
    return field(metadata={"help": help_text}, **field_kwargs)


@dataclass
class ModelArguments:
    """
    Where the pretrained model, its config and its tokenizer are loaded from.

    Only ``model_name_or_path`` is required. Leaving ``config_name`` or ``tokenizer_name``
    unset makes the script load that component from ``model_name_or_path`` as well.
    """

    # Required (no default): checkpoint path or model identifier.
    model_name_or_path: str = _help_field(
        "Path to pretrained model or model identifier from huggingface.co/models"
    )
    # Optional overrides; None means "reuse model_name_or_path".
    config_name: Optional[str] = _help_field(
        "Pretrained config name or path if not the same as model_name", default=None
    )
    tokenizer_name: Optional[str] = _help_field(
        "Pretrained tokenizer name or path if not the same as model_name", default=None
    )
    # Forwarded as ``cache_dir`` to every ``from_pretrained`` call in the script.
    cache_dir: Optional[str] = _help_field(
        "Where do you want to store the pretrained models downloaded from s3", default=None
    )
65
66


Julien Chaumond's avatar
Julien Chaumond committed
67
68
69
70
71
@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    Either ``dataset_name`` or at least one of ``train_file`` / ``validation_file`` must be
    given. Every local file that is given (train, validation or test) must have a ``.csv``
    or ``.json`` extension.
    """

    task_name: Optional[str] = field(default="ner", metadata={"help": "The name of the task (ner, pos...)."})
    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    train_file: Optional[str] = field(
        default=None, metadata={"help": "The input training data file (a csv or JSON file)."}
    )
    validation_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input evaluation data file to evaluate on (a csv or JSON file)."},
    )
    test_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input test data file to predict on (a csv or JSON file)."},
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    pad_to_max_length: bool = field(
        default=False,
        metadata={
            "help": "Whether to pad all samples to model maximum sentence length. "
            "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
            "efficient on GPU but very bad for TPU."
        },
    )
    label_all_tokens: bool = field(
        default=False,
        metadata={
            "help": "Whether to put the label for one word on all tokens generated by that word or just on the "
            "first one (in which case the other tokens will have a padding index)."
        },
    )

    def __post_init__(self):
        """Validate the data sources and lower-case ``task_name``.

        Raises:
            ValueError: if neither ``dataset_name`` nor a train/validation file is given, or
                if any given train/validation/test file is not a csv or json file.
        """
        if self.dataset_name is None and self.train_file is None and self.validation_file is None:
            raise ValueError("Need either a dataset name or a training/validation file.")
        # Explicit raises instead of `assert`: asserts vanish under `python -O`, which would let
        # unsupported files through to `load_dataset`. test_file is checked too, since it is
        # loaded with the same csv/json loader as the other files.
        for arg_name in ("train_file", "validation_file", "test_file"):
            path = getattr(self, arg_name)
            if path is not None and path.split(".")[-1] not in ("csv", "json"):
                raise ValueError(f"`{arg_name}` should be a csv or a json file.")
        # task_name is Optional, so only normalize it when it is actually set.
        if self.task_name is not None:
            self.task_name = self.task_name.lower()
125

Julien Chaumond's avatar
Julien Chaumond committed
126
127
128
129
130
131
132

def _align_predictions(predictions, label_ids, label_list):
    """Turn model logits and gold label ids into per-example label sequences.

    Args:
        predictions: logits; ``np.argmax`` is taken over axis 2 to pick a label id per token.
        label_ids: gold label ids laid out like the argmax result; positions equal to -100
            (special tokens, and non-first sub-tokens unless ``label_all_tokens``) are dropped.
        label_list: maps a label id to its label.

    Returns:
        ``(true_predictions, true_labels)``: two lists of label lists, aligned position by position.
    """
    predicted_ids = np.argmax(predictions, axis=2)
    true_predictions = []
    true_labels = []
    for prediction, label in zip(predicted_ids, label_ids):
        kept = [(p, l) for (p, l) in zip(prediction, label) if l != -100]
        true_predictions.append([label_list[p] for (p, _) in kept])
        true_labels.append([label_list[l] for (_, l) in kept])
    return true_predictions, true_labels


def _write_results(output_file, results):
    """Log every ``key = value`` pair of ``results`` and write the same lines to ``output_file``."""
    with open(output_file, "w", encoding="utf-8") as writer:
        for key, value in results.items():
            logger.info(f"  {key} = {value}")
            writer.write(f"{key} = {value}\n")


def main():
    """Fine-tune, evaluate and/or predict with a token-classification model.

    Arguments are parsed into ``ModelArguments``, ``DataTrainingArguments`` and
    ``TrainingArguments``. They come from ``sys.argv``, or from a JSON file when that file's
    path is the only argument. Depending on the ``do_train`` / ``do_eval`` / ``do_predict``
    flags, it does the following:
      - trains and saves the model;
      - evaluates on the validation split, writing ``eval_results_ner.txt``;
      - predicts on the test split, writing ``test_results.txt`` and ``test_predictions.txt``.
    All output files go into ``output_dir``.

    Returns:
        The evaluation metrics dict, or an empty dict when ``do_eval`` is off.

    Raises:
        ValueError: if ``output_dir`` is non-empty while training without ``overwrite_output_dir``,
            or if a split needed by the requested steps is missing.
    """
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. "
            "Use --overwrite_output_dir to overcome."
        )

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if is_main_process(training_args.local_rank) else logging.WARN,
    )

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    # Set the verbosity to info of the Transformers logger (on main process only):
    if is_main_process(training_args.local_rank):
        transformers.utils.logging.set_verbosity_info()
        transformers.utils.logging.enable_default_handler()
        transformers.utils.logging.enable_explicit_format()
    logger.info("Training/evaluation parameters %s", training_args)

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # Get the datasets: you can either provide your own CSV/JSON training, evaluation and test files
    # (see below) or just provide the name of one of the public datasets available on the hub at
    # https://huggingface.co/datasets/ (the dataset will be downloaded automatically from the datasets Hub).
    #
    # The input column is the one called 'words' (or the first column if there is none). The label
    # column is the one named after `task_name` (or the second column if there is none).
    #
    # In distributed training, the load_dataset function guarantee that only one local process can concurrently
    # download the dataset.
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name)
    else:
        data_files = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
        if data_args.test_file is not None:
            data_files["test"] = data_args.test_file
        # Pick the loader from the first file given (train, then validation, then test), so a
        # run without --train_file does not fail on `None.split`.
        extension = next(iter(data_files.values())).split(".")[-1]
        raw_datasets = load_dataset(extension, data_files=data_files)
    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Fail early with a clear message instead of an opaque KeyError further down.
    # The label list is always built from the training split, so that split is required
    # even when only evaluating or predicting.
    if "train" not in raw_datasets:
        raise ValueError(
            "A training split is required to build the label list (pass --train_file or a dataset with a 'train' split)."
        )
    if training_args.do_eval and "validation" not in raw_datasets:
        raise ValueError(
            "--do_eval requires a validation split (pass --validation_file or a dataset with a 'validation' split)."
        )
    if training_args.do_predict and "test" not in raw_datasets:
        raise ValueError("--do_predict requires a test split (pass --test_file or a dataset with a 'test' split).")

    if training_args.do_train or "validation" not in raw_datasets:
        column_names = raw_datasets["train"].column_names
    else:
        column_names = raw_datasets["validation"].column_names
    text_column_name = "words" if "words" in column_names else column_names[0]
    label_column_name = data_args.task_name if data_args.task_name in column_names else column_names[1]

    # Labeling (this part will be easier when https://github.com/huggingface/datasets/issues/797 is solved)
    def get_label_list(labels):
        """Return the sorted distinct labels appearing in a column of per-example label sequences."""
        unique_labels = set()
        for label in labels:
            unique_labels.update(label)
        return sorted(unique_labels)

    label_list = get_label_list(raw_datasets["train"][label_column_name])
    label_to_id = {l: i for i, l in enumerate(label_list)}
    num_labels = len(label_list)

    # Load pretrained model and tokenizer
    #
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        num_labels=num_labels,
        finetuning_task=data_args.task_name,
        cache_dir=model_args.cache_dir,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_fast=True,
    )
    model = AutoModelForTokenClassification.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
    )

    # Preprocessing the dataset
    # Padding strategy
    padding = "max_length" if data_args.pad_to_max_length else False

    # Tokenize all texts and align the labels with them.
    def tokenize_and_align_labels(examples):
        """Tokenize a batch of pre-split word lists and give each token a label id.

        Word starts are detected from the offset mapping: a token whose offset is (0, n) with
        n != 0 is treated as the first token of the next word. NOTE(review): this assumes every
        word's first token starts at offset 0 and no word yields an extra (0, 0) token; if a
        tokenizer breaks that assumption, labels shift or `label[label_index]` overflows.
        """
        tokenized_inputs = tokenizer(
            examples[text_column_name],
            padding=padding,
            truncation=True,
            # We use this argument because the texts in our dataset are lists of words (with a label for each word).
            is_split_into_words=True,
            return_offsets_mapping=True,
        )
        # Offsets are only needed to align labels; they are not a model input.
        offset_mappings = tokenized_inputs.pop("offset_mapping")
        labels = []
        for label, offset_mapping in zip(examples[label_column_name], offset_mappings):
            label_index = 0
            current_label = -100
            label_ids = []
            for offset in offset_mapping:
                # We set the label for the first token of each word. Special tokens have an offset of (0, 0)
                # so the test ignores them.
                if offset[0] == 0 and offset[1] != 0:
                    current_label = label_to_id[label[label_index]]
                    label_index += 1
                    label_ids.append(current_label)
                # For special tokens, we set the label to -100 so it's automatically ignored in the loss function.
                elif offset[0] == 0 and offset[1] == 0:
                    label_ids.append(-100)
                # For the other tokens in a word, we set the label to either the current label or -100, depending on
                # the label_all_tokens flag.
                else:
                    label_ids.append(current_label if data_args.label_all_tokens else -100)

            labels.append(label_ids)
        tokenized_inputs["labels"] = labels
        return tokenized_inputs

    tokenized_datasets = raw_datasets.map(
        tokenize_and_align_labels,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        load_from_cache_file=not data_args.overwrite_cache,
    )

    # Data collator
    data_collator = DataCollatorForTokenClassification(tokenizer)

    # Metrics
    def compute_metrics(p):
        """Compute seqeval accuracy/precision/recall/F1 over non-ignored token positions."""
        predictions, labels = p
        true_predictions, true_labels = _align_predictions(predictions, labels, label_list)
        return {
            "accuracy_score": accuracy_score(true_labels, true_predictions),
            "precision": precision_score(true_labels, true_predictions),
            "recall": recall_score(true_labels, true_predictions),
            "f1": f1_score(true_labels, true_predictions),
        }

    # Initialize our Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_datasets["train"] if training_args.do_train else None,
        eval_dataset=tokenized_datasets["validation"] if training_args.do_eval else None,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
    )

    # Training
    if training_args.do_train:
        trainer.train(
            model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None
        )
        trainer.save_model()  # Saves the tokenizer too for easy upload

    # Evaluation
    results = {}
    if training_args.do_eval:
        logger.info("*** Evaluate ***")

        results = trainer.evaluate()

        output_eval_file = os.path.join(training_args.output_dir, "eval_results_ner.txt")
        if trainer.is_world_process_zero():
            logger.info("***** Eval results *****")
            _write_results(output_eval_file, results)

    # Predict
    if training_args.do_predict:
        logger.info("*** Predict ***")

        test_dataset = tokenized_datasets["test"]
        predictions, labels, metrics = trainer.predict(test_dataset)
        # Remove ignored index (special tokens)
        true_predictions, _ = _align_predictions(predictions, labels, label_list)

        # Only the main process writes files (same check as the evaluation block).
        if trainer.is_world_process_zero():
            output_test_results_file = os.path.join(training_args.output_dir, "test_results.txt")
            _write_results(output_test_results_file, metrics)

            # Save predictions: one line per example, labels separated by spaces.
            output_test_predictions_file = os.path.join(training_args.output_dir, "test_predictions.txt")
            with open(output_test_predictions_file, "w", encoding="utf-8") as writer:
                for prediction in true_predictions:
                    writer.write(" ".join(prediction) + "\n")

    return results


376
377
378
379
380
def _mp_fn(index):
    """Per-process entry point used by ``xla_spawn`` when running on TPUs.

    Args:
        index: process index supplied by the spawner. It is not used, because ``main``
            reads its configuration from the command line.
    """
    del index  # unused: every process runs the same main()
    main()


381
382
if __name__ == "__main__":
    # Script entry point: main() parses its arguments from sys.argv (or a single JSON file).
    main()