# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Fine-tuning the library models for named entity recognition on CoNLL-2003. """
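
# Example invocation (the model name and paths below are illustrative, not defaults):
#
#   python run_ner.py \
#       --model_name_or_path bert-base-cased \
#       --data_dir ./data \
#       --labels ./data/labels.txt \
#       --output_dir ./ner-output \
#       --do_train --do_eval --do_predict
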
import logging
import os
import sys
from dataclasses import dataclass, field
from importlib import import_module
from typing import Dict, List, Optional, Tuple

import numpy as np
from seqeval.metrics import accuracy_score, f1_score, precision_score, recall_score
from torch import nn

from transformers import (
    AutoConfig,
    AutoModelForTokenClassification,
    AutoTokenizer,
    EvalPrediction,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    set_seed,
)
from utils_ner import Split, TokenClassificationDataset, TokenClassificationTask


logger = logging.getLogger(__name__)


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    task_type: Optional[str] = field(
        default="NER", metadata={"help": "Task type to fine-tune on (e.g. NER, POS, etc.)"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    use_fast: bool = field(default=False, metadata={"help": "Set this flag to use fast tokenization."})
    # If you want to tweak more attributes on your tokenizer, you should do it in a distinct script,
    # or just modify its tokenizer_config.json.
    cache_dir: Optional[str] = field(
        default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
    )


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    """

    data_dir: str = field(
        metadata={"help": "The input data dir. Should contain the .txt files for a CoNLL-2003-formatted task."}
    )
    labels: Optional[str] = field(
        default=None,
        metadata={"help": "Path to a file containing all labels. If not specified, CoNLL-2003 labels are used."},
    )
    max_seq_length: int = field(
        default=128,
        metadata={
            "help": "The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )


def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
        )

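    # The task class (e.g. NER, POS) is looked up by name in the local tasks.py
    # module, so a new task type only needs a TokenClassificationTask subclass there.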
    module = import_module("tasks")
    try:
        token_classification_task_clazz = getattr(module, model_args.task_type)
        token_classification_task: TokenClassificationTask = token_classification_task_clazz()
    except AttributeError:
        raise ValueError(
            f"Task {model_args.task_type} needs to be defined as a TokenClassificationTask subclass in {module}. "
            f"Available tasks classes are: {TokenClassificationTask.__subclasses__()}"
        )

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        training_args.local_rank,
        training_args.device,
        training_args.n_gpu,
        bool(training_args.local_rank != -1),
        training_args.fp16,
    )
    logger.info("Training/evaluation parameters %s", training_args)

    # Set seed
    set_seed(training_args.seed)

    # Prepare CoNLL-2003 task
    labels = token_classification_task.get_labels(data_args.labels)
    label_map: Dict[int, str] = {i: label for i, label in enumerate(labels)}
    num_labels = len(labels)

    # Load pretrained model and tokenizer
    #
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.

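    # Storing id2label/label2id in the config makes the saved checkpoint
    # self-describing: downstream users see label names instead of bare indices.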
    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        num_labels=num_labels,
        id2label=label_map,
        label2id={label: i for i, label in enumerate(labels)},
        cache_dir=model_args.cache_dir,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_fast=model_args.use_fast,
    )
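    # from_tf=True makes from_pretrained load weights from a TensorFlow
    # checkpoint, triggered here when the path contains ".ckpt".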
    model = AutoModelForTokenClassification.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
    )

    # Get datasets
    train_dataset = (
        TokenClassificationDataset(
            token_classification_task=token_classification_task,
            data_dir=data_args.data_dir,
            tokenizer=tokenizer,
            labels=labels,
            model_type=config.model_type,
            max_seq_length=data_args.max_seq_length,
            overwrite_cache=data_args.overwrite_cache,
            mode=Split.train,
        )
        if training_args.do_train
        else None
    )
    eval_dataset = (
        TokenClassificationDataset(
            token_classification_task=token_classification_task,
            data_dir=data_args.data_dir,
            tokenizer=tokenizer,
            labels=labels,
            model_type=config.model_type,
            max_seq_length=data_args.max_seq_length,
            overwrite_cache=data_args.overwrite_cache,
            mode=Split.dev,
        )
        if training_args.do_eval
        else None
    )

    def align_predictions(predictions: np.ndarray, label_ids: np.ndarray) -> Tuple[List[int], List[int]]:
        preds = np.argmax(predictions, axis=2)

        batch_size, seq_len = preds.shape

        out_label_list = [[] for _ in range(batch_size)]
        preds_list = [[] for _ in range(batch_size)]

        for i in range(batch_size):
            for j in range(seq_len):
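                # ignore_index (-100 by default) marks special tokens and subword
                # continuations that carry no gold label; skip them on both sides.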
                if label_ids[i, j] != nn.CrossEntropyLoss().ignore_index:
                    out_label_list[i].append(label_map[label_ids[i][j]])
                    preds_list[i].append(label_map[preds[i][j]])

        return preds_list, out_label_list

    def compute_metrics(p: EvalPrediction) -> Dict:
        preds_list, out_label_list = align_predictions(p.predictions, p.label_ids)
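        # precision/recall/F1 from seqeval are entity-level: a predicted span only
        # counts when both boundaries and type match; accuracy is token-level.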
        return {
            "accuracy_score": accuracy_score(out_label_list, preds_list),
            "precision": precision_score(out_label_list, preds_list),
            "recall": recall_score(out_label_list, preds_list),
            "f1": f1_score(out_label_list, preds_list),
        }

    # Initialize our Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=compute_metrics,
    )

    # Training
    if training_args.do_train:
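        # Passing a local checkpoint directory also lets the Trainer restore any
        # optimizer/scheduler state saved alongside the model weights.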
        trainer.train(
            model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None
        )
        trainer.save_model()
        # For convenience, we also re-save the tokenizer to the same directory,
        # so that you can share your model easily on huggingface.co/models =)
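        # In distributed runs, only the world-master process writes to disk so
        # that several workers don't clobber the same files.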
        if trainer.is_world_master():
            tokenizer.save_pretrained(training_args.output_dir)

    # Evaluation
    results = {}
    if training_args.do_eval:
        logger.info("*** Evaluate ***")

        result = trainer.evaluate()

        output_eval_file = os.path.join(training_args.output_dir, "eval_results.txt")
        if trainer.is_world_master():
            with open(output_eval_file, "w") as writer:
                logger.info("***** Eval results *****")
                for key, value in result.items():
                    logger.info("  %s = %s", key, value)
                    writer.write("%s = %s\n" % (key, value))

            results.update(result)

    # Predict
    if training_args.do_predict:
        test_dataset = TokenClassificationDataset(
            token_classification_task=token_classification_task,
            data_dir=data_args.data_dir,
            tokenizer=tokenizer,
            labels=labels,
            model_type=config.model_type,
            max_seq_length=data_args.max_seq_length,
            overwrite_cache=data_args.overwrite_cache,
            mode=Split.test,
        )

        predictions, label_ids, metrics = trainer.predict(test_dataset)
        preds_list, _ = align_predictions(predictions, label_ids)

        output_test_results_file = os.path.join(training_args.output_dir, "test_results.txt")
        if trainer.is_world_master():
            with open(output_test_results_file, "w") as writer:
                for key, value in metrics.items():
                    logger.info("  %s = %s", key, value)
                    writer.write("%s = %s\n" % (key, value))

        # Save predictions
        output_test_predictions_file = os.path.join(training_args.output_dir, "test_predictions.txt")
        if trainer.is_world_master():
            with open(output_test_predictions_file, "w") as writer:
                with open(os.path.join(data_args.data_dir, "test.txt"), "r") as f:
                    token_classification_task.write_predictions_to_file(writer, f, preds_list)

    return results


def _mp_fn(index):
    # For xla_spawn (TPUs)
    main()


if __name__ == "__main__":
    main()