# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Fine-tuning the library models for named entity recognition on CoNLL-2003 (Bert or Roberta). """


import logging
import os
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple

import numpy as np
from seqeval.metrics import f1_score, precision_score, recall_score
from torch import nn

from transformers import (
    AutoConfig,
    AutoModelForTokenClassification,
    AutoTokenizer,
    EvalPrediction,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    set_seed,
)
from utils_ner import NerDataset, Split, get_labels


# Module-level logger; its level/format are configured in main() via logging.basicConfig.
logger = logging.getLogger(__name__)


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.

    Attributes:
        model_name_or_path: local path or huggingface.co model identifier (required).
        config_name: config to load instead of the one at ``model_name_or_path``; when empty,
            main() falls back to ``model_name_or_path``.
        tokenizer_name: tokenizer to load instead of the one at ``model_name_or_path``; when
            empty, main() falls back to ``model_name_or_path``.
        use_fast: forwarded as ``use_fast`` to the tokenizer loader.
        cache_dir: forwarded as ``cache_dir`` to the config, tokenizer and model loaders.
    """

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    use_fast: bool = field(default=False, metadata={"help": "Set this flag to use fast tokenization."})
    # If you want to tweak more attributes on your tokenizer, you should do it in a distinct script,
    # or just modify its tokenizer_config.json.
    cache_dir: Optional[str] = field(
        default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
    )


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    Attributes:
        data_dir: directory with the CoNLL-formatted .txt files; main() also reads
            ``test.txt`` from it when writing test predictions.
        labels: optional path to a label file, passed straight to ``get_labels``.
        max_seq_length: forwarded to every ``NerDataset`` built in main().
        overwrite_cache: forwarded to every ``NerDataset`` built in main().
    """

    data_dir: str = field(
        metadata={"help": "The input data dir. Should contain the .txt files for a CoNLL-2003-formatted task."}
    )
    # The help text promises this is optional, but without a default the dataclass made it a
    # required argument. None is passed through to get_labels, which presumably falls back to
    # the CoNLL-2003 label set (get_labels is defined in utils_ner; not visible here).
    labels: Optional[str] = field(
        default=None,
        metadata={"help": "Path to a file containing all labels. If not specified, CoNLL-2003 labels are used."},
    )
    max_seq_length: int = field(
        default=128,
        metadata={
            "help": "The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )


def main():
    """Fine-tune and/or evaluate a token-classification (NER) model on a CoNLL-2003-style dataset.

    Arguments are parsed from the command line into ``ModelArguments``,
    ``DataTrainingArguments`` and ``TrainingArguments``. Depending on ``do_train`` /
    ``do_eval`` / ``do_predict`` this trains the model, writes ``eval_results.txt``, and writes
    ``test_results.txt`` plus ``test_predictions.txt`` (``test.txt`` re-emitted as
    "<token> <predicted label>" lines) into ``output_dir``.

    Returns:
        Dict of evaluation metrics; empty unless evaluation ran on this process
        (``local_rank`` of -1 or 0).

    Raises:
        ValueError: if training is requested, ``output_dir`` exists and is non-empty, and
            ``overwrite_output_dir`` is not set.
    """
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
        )

    # Setup logging: only the main process (local_rank -1 or 0) logs at INFO.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        training_args.local_rank,
        training_args.device,
        training_args.n_gpu,
        bool(training_args.local_rank != -1),
        training_args.fp16,
    )
    logger.info("Training/evaluation parameters %s", training_args)

    # Set seed
    set_seed(training_args.seed)

    # Prepare CONLL-2003 task
    labels = get_labels(data_args.labels)
    label_map: Dict[int, str] = {i: label for i, label in enumerate(labels)}
    num_labels = len(labels)

    # Load pretrained model and tokenizer
    #
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.

    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        num_labels=num_labels,
        id2label=label_map,
        label2id={label: i for i, label in enumerate(labels)},
        cache_dir=model_args.cache_dir,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_fast=model_args.use_fast,
    )
    model = AutoModelForTokenClassification.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
    )

    # Get datasets
    train_dataset = (
        NerDataset(
            data_dir=data_args.data_dir,
            tokenizer=tokenizer,
            labels=labels,
            model_type=config.model_type,
            max_seq_length=data_args.max_seq_length,
            overwrite_cache=data_args.overwrite_cache,
            mode=Split.train,
            local_rank=training_args.local_rank,
        )
        if training_args.do_train
        else None
    )
    eval_dataset = (
        NerDataset(
            data_dir=data_args.data_dir,
            tokenizer=tokenizer,
            labels=labels,
            model_type=config.model_type,
            max_seq_length=data_args.max_seq_length,
            overwrite_cache=data_args.overwrite_cache,
            mode=Split.dev,
            local_rank=training_args.local_rank,
        )
        if training_args.do_eval
        else None
    )

    # Label id that CrossEntropyLoss ignores by default; positions carrying it are dropped when
    # aligning predictions. It is constant, so look it up once rather than constructing a new
    # loss module for every token inside the alignment loop.
    ignore_index = nn.CrossEntropyLoss().ignore_index

    def align_predictions(
        predictions: np.ndarray, label_ids: np.ndarray
    ) -> Tuple[List[List[str]], List[List[str]]]:
        """Turn logits and gold label ids into per-example lists of label strings.

        ``predictions`` is arg-maxed over axis 2, so it must be rank 3 with the label scores on
        the last axis; ``label_ids`` is indexed with the same (example, position) pairs.
        Positions whose gold id equals ``ignore_index`` are skipped in both outputs, which keeps
        the returned prediction and reference lists aligned.

        Returns:
            ``(preds_list, out_label_list)``: one list of label names per example.
        """
        preds = np.argmax(predictions, axis=2)

        batch_size, seq_len = preds.shape

        out_label_list = [[] for _ in range(batch_size)]
        preds_list = [[] for _ in range(batch_size)]

        for i in range(batch_size):
            for j in range(seq_len):
                if label_ids[i, j] != ignore_index:
                    out_label_list[i].append(label_map[label_ids[i][j]])
                    preds_list[i].append(label_map[preds[i][j]])

        return preds_list, out_label_list

    def compute_metrics(p: EvalPrediction) -> Dict:
        """Entity-level precision/recall/F1 (seqeval) over the aligned label sequences."""
        preds_list, out_label_list = align_predictions(p.predictions, p.label_ids)
        return {
            "precision": precision_score(out_label_list, preds_list),
            "recall": recall_score(out_label_list, preds_list),
            "f1": f1_score(out_label_list, preds_list),
        }

    # Initialize our Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=compute_metrics,
    )

    # Training
    if training_args.do_train:
        # When the model was given as a local directory, pass it on so training can resume from it.
        trainer.train(
            model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None
        )

    # Evaluation
    results = {}
    if training_args.do_eval and training_args.local_rank in [-1, 0]:
        logger.info("*** Evaluate ***")

        result = trainer.evaluate()

        output_eval_file = os.path.join(training_args.output_dir, "eval_results.txt")
        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results *****")
            for key, value in result.items():
                logger.info("  %s = %s", key, value)
                writer.write("%s = %s\n" % (key, value))

            results.update(result)

    # Predict
    if training_args.do_predict and training_args.local_rank in [-1, 0]:
        test_dataset = NerDataset(
            data_dir=data_args.data_dir,
            tokenizer=tokenizer,
            labels=labels,
            model_type=config.model_type,
            max_seq_length=data_args.max_seq_length,
            overwrite_cache=data_args.overwrite_cache,
            mode=Split.test,
            local_rank=training_args.local_rank,
        )

        predictions, label_ids, metrics = trainer.predict(test_dataset)
        preds_list, _ = align_predictions(predictions, label_ids)

        output_test_results_file = os.path.join(training_args.output_dir, "test_results.txt")
        with open(output_test_results_file, "w") as writer:
            for key, value in metrics.items():
                logger.info("  %s = %s", key, value)
                writer.write("%s = %s\n" % (key, value))

        # Save predictions: re-emit test.txt with each token line replaced by
        # "<first column> <predicted label>"; separator lines are copied verbatim.
        # Explicit utf-8 keeps decoding independent of the platform locale.
        output_test_predictions_file = os.path.join(training_args.output_dir, "test_predictions.txt")
        with open(output_test_predictions_file, "w", encoding="utf-8") as writer:
            with open(os.path.join(data_args.data_dir, "test.txt"), "r", encoding="utf-8") as f:
                example_id = 0
                for line in f:
                    if line.startswith("-DOCSTART-") or line == "" or line == "\n":
                        writer.write(line)
                        # Move to the next example only once the current one's predictions are
                        # used up, so consecutive separators do not skip examples. The bound
                        # check prevents an IndexError on separators after the last example
                        # (e.g. trailing blank lines at the end of test.txt).
                        if example_id < len(preds_list) and not preds_list[example_id]:
                            example_id += 1
                    elif example_id < len(preds_list) and preds_list[example_id]:
                        output_line = line.split()[0] + " " + preds_list[example_id].pop(0) + "\n"
                        writer.write(output_line)
                    else:
                        # No prediction left for this token, e.g. it fell beyond max_seq_length.
                        logger.warning("Maximum sequence length exceeded: No prediction for '%s'.", line.split()[0])

    return results


if __name__ == "__main__":
    # Script entry point: run the full train/eval/predict pipeline only when executed directly,
    # not when this module is imported.
    main()