# run_pl_ner.py: train and evaluate a transformer token-classification (NER) model with PyTorch Lightning.
import argparse
import glob
import logging
import os
import re

import numpy as np
import torch
from seqeval.metrics import f1_score, precision_score, recall_score
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader, TensorDataset

from transformer_base import BaseTransformer, add_generic_args, generic_train
from utils_ner import convert_examples_to_features, get_labels, read_examples_from_file


logger = logging.getLogger(__name__)


class NERTransformer(BaseTransformer):
    """
    A training module for NER. See BaseTransformer for the core options.

    Builds and caches tokenized features from ``hparams.data_dir`` for the
    "train", "dev" and "test" splits, runs the train/validation/test steps,
    and reports seqeval precision / recall / F1 at the end of each
    evaluation epoch.

    NOTE(review): ``self._feature_file``, ``self.lr_scheduler``,
    ``self.config``, ``self.tokenizer`` and ``self.model`` are used here but
    not defined in this class; presumably provided by BaseTransformer — confirm.
    """

    mode = "token-classification"

    def __init__(self, hparams):
        """
        Args:
            hparams: parsed CLI namespace; ``hparams.labels`` is passed to
                ``get_labels`` to obtain the ordered label list.
        """
        self.labels = get_labels(hparams.labels)
        num_labels = len(self.labels)
        # Padding / non-first-subword positions are labelled with this id
        # (CrossEntropyLoss's default ignore_index) and are skipped by the
        # metrics in _eval_end; presumably also ignored by the model's loss.
        self.pad_token_label_id = CrossEntropyLoss().ignore_index
        super().__init__(hparams, num_labels, self.mode)

    def forward(self, **inputs):
        return self.model(**inputs)

    def _build_inputs(self, batch):
        """
        Turn a batch from ``load_dataset`` into keyword arguments for the model.

        ``batch`` is ``(input_ids, attention_mask, token_type_ids, label_ids)``,
        matching the TensorDataset built in ``load_dataset``.
        """
        inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
        if self.config.model_type != "distilbert":
            # Only BERT and XLNet get segment ids; XLM and RoBERTa don't use
            # token_type_ids, so None is passed. DistilBERT gets no key at all.
            inputs["token_type_ids"] = batch[2] if self.config.model_type in ["bert", "xlnet"] else None
        return inputs

    def training_step(self, batch, batch_num):
        "Compute loss and log."
        outputs = self(**self._build_inputs(batch))
        loss = outputs[0]
        tensorboard_logs = {"loss": loss, "rate": self.lr_scheduler.get_last_lr()[-1]}
        return {"loss": loss, "log": tensorboard_logs}

    def prepare_data(self):
        """
        Called to initialize data. Tokenize each split and cache the features.

        A split whose cache file already exists is left untouched unless
        ``--overwrite_cache`` is set; the cached features are loaded later by
        ``load_dataset``, so they are not deserialized here.
        """
        args = self.hparams
        for mode in ["train", "dev", "test"]:
            cached_features_file = self._feature_file(mode)
            if os.path.exists(cached_features_file) and not args.overwrite_cache:
                logger.info("Using existing cached features file %s", cached_features_file)
                continue

            logger.info("Creating features from dataset file at %s", args.data_dir)
            examples = read_examples_from_file(args.data_dir, mode)
            model_type = self.config.model_type
            features = convert_examples_to_features(
                examples,
                self.labels,
                args.max_seq_length,
                self.tokenizer,
                # XLNet puts CLS at the end, uses segment id 2 for it, and pads on the left.
                cls_token_at_end=bool(model_type in ["xlnet"]),
                cls_token=self.tokenizer.cls_token,
                cls_token_segment_id=2 if model_type in ["xlnet"] else 0,
                sep_token=self.tokenizer.sep_token,
                # RoBERTa uses an extra separator token.
                sep_token_extra=bool(model_type in ["roberta"]),
                pad_on_left=bool(model_type in ["xlnet"]),
                pad_token=self.tokenizer.pad_token_id,
                pad_token_segment_id=self.tokenizer.pad_token_type_id,
                pad_token_label_id=self.pad_token_label_id,
            )
            logger.info("Saving features into cached file %s", cached_features_file)
            torch.save(features, cached_features_file)

    def load_dataset(self, mode, batch_size, shuffle=False):
        """
        Load datasets. Called after prepare data.

        Args:
            mode: split name ("train", "dev" or "test").
            batch_size: DataLoader batch size.
            shuffle: forwarded to DataLoader (defaults to False, the previous behaviour).

        Returns:
            DataLoader yielding ``(input_ids, attention_mask, token_type_ids, label_ids)``.

        Raises:
            ValueError: if the cached feature file contains no features.
        """
        cached_features_file = self._feature_file(mode)
        logger.info("Loading features from cached file %s", cached_features_file)
        features = torch.load(cached_features_file)
        if not features:
            raise ValueError("No features found in cached file {}".format(cached_features_file))

        all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
        all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
        if features[0].token_type_ids is not None:
            all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
        else:
            # HACK(we will not use this anymore soon): one-zero-per-example
            # placeholder so the batch layout stays 4-tuple; _build_inputs only
            # reads batch[2] for bert/xlnet.
            all_token_type_ids = torch.zeros(len(features), dtype=torch.long)
        all_label_ids = torch.tensor([f.label_ids for f in features], dtype=torch.long)
        return DataLoader(
            TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_label_ids),
            batch_size=batch_size,
            shuffle=shuffle,
        )

    def validation_step(self, batch, batch_nb):
        """
        Compute validation loss and collect logits / gold labels for _eval_end.

        Returns a dict with the detached CPU loss under "val_loss", the numpy
        logits under "pred" and the numpy gold label ids under "target".
        """
        inputs = self._build_inputs(batch)
        outputs = self(**inputs)
        tmp_eval_loss, logits = outputs[:2]
        preds = logits.detach().cpu().numpy()
        out_label_ids = inputs["labels"].detach().cpu().numpy()
        return {"val_loss": tmp_eval_loss.detach().cpu(), "pred": preds, "target": out_label_ids}

    def _eval_end(self, outputs):
        """
        Evaluation called for both Val and Test.

        Aggregates the per-batch dicts from ``validation_step``, drops positions
        whose gold label is ``pad_token_label_id``, and scores the remaining
        label sequences with seqeval.

        Returns:
            (ret, preds_list, out_label_list): ``ret`` holds val_loss /
            precision / recall / f1 plus the same dict under "log"; the lists
            hold per-example label-name sequences.
        """
        val_loss_mean = torch.stack([x["val_loss"] for x in outputs]).mean()
        # argmax over axis 2 picks the highest-scoring label id per token
        # (assumes logits are (batch, seq, num_labels) — as the original did).
        preds = np.argmax(np.concatenate([x["pred"] for x in outputs], axis=0), axis=2)
        out_label_ids = np.concatenate([x["target"] for x in outputs], axis=0)

        label_map = dict(enumerate(self.labels))
        out_label_list = []
        preds_list = []
        for pred_row, label_row in zip(preds, out_label_ids):
            keep = label_row != self.pad_token_label_id
            out_label_list.append([label_map[label_id] for label_id in label_row[keep]])
            preds_list.append([label_map[label_id] for label_id in pred_row[keep]])

        results = {
            "val_loss": val_loss_mean,
            "precision": precision_score(out_label_list, preds_list),
            "recall": recall_score(out_label_list, preds_list),
            "f1": f1_score(out_label_list, preds_list),
        }

        ret = dict(results)
        ret["log"] = results
        return ret, preds_list, out_label_list

    def validation_epoch_end(self, outputs):
        """Aggregate validation metrics and expose them to the logger and progress bar."""
        ret, _, _ = self._eval_end(outputs)
        logs = ret["log"]
        return {"val_loss": logs["val_loss"], "log": logs, "progress_bar": logs}

    def test_epoch_end(self, outputs):
        """
        Aggregate test metrics (test_epoch_end replaces the deprecated test_end).

        NOTE(review): ``outputs`` is expected to have validation_step's shape;
        test_step is not defined in this class — presumably inherited, confirm.
        """
        ret, _, _ = self._eval_end(outputs)

        # Converting to the dict required by pl
        # https://github.com/PyTorchLightning/pytorch-lightning/blob/master/\
        # pytorch_lightning/trainer/logging.py#L139
        logs = ret["log"]
        # `val_loss` is the key returned by `self._eval_end()` but actually refers to `test_loss`
        return {"avg_test_loss": logs["val_loss"], "log": logs, "progress_bar": logs}

    @staticmethod
    def add_model_specific_args(parser, root_dir):
        """Register the generic BaseTransformer options plus the NER-specific ones on ``parser``."""
        BaseTransformer.add_model_specific_args(parser, root_dir)
        parser.add_argument(
            "--max_seq_length",
            default=128,
            type=int,
            help="The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )

        parser.add_argument(
            "--labels",
            default="",
            type=str,
            help="Path to a file containing all labels. If not specified, CoNLL-2003 labels are used.",
        )

        parser.add_argument(
            "--data_dir",
            default=None,
            type=str,
            required=True,
            help="The input data dir. Should contain the training files for the CoNLL-2003 NER task.",
        )

        parser.add_argument(
            "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
        )

        return parser


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    add_generic_args(parser, os.getcwd())
    parser = NERTransformer.add_model_specific_args(parser, os.getcwd())
    args = parser.parse_args()
    model = NERTransformer(args)
    trainer = generic_train(model, args)

    if args.do_predict:
        # See https://github.com/huggingface/transformers/issues/3159
        # pl use this format to create a checkpoint:
        # https://github.com/PyTorchLightning/pytorch-lightning/blob/master\
        # /pytorch_lightning/callbacks/model_checkpoint.py#L169

        def _checkpoint_sort_key(path):
            """Order checkpoints by numeric epoch (a plain string sort puts epoch=10 before epoch=9)."""
            match = re.search(r"epoch=(\d+)", os.path.basename(path))
            return (int(match.group(1)) if match else -1, path)

        checkpoints = sorted(
            glob.glob(os.path.join(args.output_dir, "checkpointepoch=*.ckpt")), key=_checkpoint_sort_key
        )
        if not checkpoints:
            raise FileNotFoundError(
                "No checkpoint matching 'checkpointepoch=*.ckpt' found in {}".format(args.output_dir)
            )
        # Evaluate the most recent (highest-epoch) checkpoint on the test split.
        model = model.load_from_checkpoint(checkpoints[-1])
        trainer.test(model)