import argparse
import glob
import logging
import os
from argparse import Namespace
from importlib import import_module

import numpy as np
import torch
from seqeval.metrics import accuracy_score, f1_score, precision_score, recall_score
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader, TensorDataset

from lightning_base import BaseTransformer, add_generic_args, generic_train
from utils_ner import TokenClassificationTask


logger = logging.getLogger(__name__)


class NERTransformer(BaseTransformer):
    """
    A training module for NER. See BaseTransformer for the core options.
    """

    mode = "token-classification"

    def __init__(self, hparams):
        if isinstance(hparams, dict):
            hparams = Namespace(**hparams)
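        # The concrete task ("NER", "POS", ...) is looked up by name in the local `tasks` module
        # and must be a TokenClassificationTask subclass.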
        module = import_module("tasks")
        try:
            token_classification_task_clazz = getattr(module, hparams.task_type)
            self.token_classification_task: TokenClassificationTask = token_classification_task_clazz()
        except AttributeError:
            raise ValueError(
                f"Task {hparams.task_type} needs to be defined as a TokenClassificationTask subclass in {module}. "
                f"Available task classes are: {TokenClassificationTask.__subclasses__()}"
            )
        self.labels = self.token_classification_task.get_labels(hparams.labels)
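        # Tokens assigned this label id are ignored by CrossEntropyLoss and skipped again when
        # the label sequences are rebuilt in `_eval_end`.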
        self.pad_token_label_id = CrossEntropyLoss().ignore_index
        super().__init__(hparams, len(self.labels), self.mode)

    def forward(self, **inputs):
        return self.model(**inputs)

    def training_step(self, batch, batch_num):
        """Compute loss and log."""
        inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
        if self.config.model_type != "distilbert":
            inputs["token_type_ids"] = (
                batch[2] if self.config.model_type in ["bert", "xlnet"] else None
            )  # XLM and RoBERTa don't use token_type_ids

        outputs = self(**inputs)
        loss = outputs[0]
        # tensorboard_logs = {"loss": loss, "rate": self.lr_scheduler.get_last_lr()[-1]}
        return {"loss": loss}

    def prepare_data(self):
        """Called to initialize data. Use this call to construct features."""
        args = self.hparams
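        # Features are built once per split and cached at `self._feature_file(mode)`;
        # pass --overwrite_cache to force them to be rebuilt.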
        for mode in ["train", "dev", "test"]:
            cached_features_file = self._feature_file(mode)
            if os.path.exists(cached_features_file) and not args.overwrite_cache:
                logger.info("Loading features from cached file %s", cached_features_file)
                features = torch.load(cached_features_file)
            else:
                logger.info("Creating features from dataset file at %s", args.data_dir)
                examples = self.token_classification_task.read_examples_from_file(args.data_dir, mode)
                features = self.token_classification_task.convert_examples_to_features(
                    examples,
                    self.labels,
                    args.max_seq_length,
                    self.tokenizer,
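                    # Model-specific conventions: XLNet expects the [CLS] token at the end of the
                    # sequence with segment id 2 and left-side padding; other model types here
                    # follow the BERT-style layout.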
                    cls_token_at_end=bool(self.config.model_type in ["xlnet"]),
                    cls_token=self.tokenizer.cls_token,
                    cls_token_segment_id=2 if self.config.model_type in ["xlnet"] else 0,
                    sep_token=self.tokenizer.sep_token,
                    sep_token_extra=False,
                    pad_on_left=bool(self.config.model_type in ["xlnet"]),
                    pad_token=self.tokenizer.pad_token_id,
                    pad_token_segment_id=self.tokenizer.pad_token_type_id,
                    pad_token_label_id=self.pad_token_label_id,
                )
                logger.info("Saving features into cached file %s", cached_features_file)
                torch.save(features, cached_features_file)

    def get_dataloader(self, mode: str, batch_size: int) -> DataLoader:
        """Load datasets. Called after prepare_data."""
        cached_features_file = self._feature_file(mode)
        logger.info("Loading features from cached file %s", cached_features_file)
        features = torch.load(cached_features_file)
        all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
        all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
        if features[0].token_type_ids is not None:
            all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
        else:
            all_token_type_ids = torch.tensor([0 for f in features], dtype=torch.long)
            # HACK: dummy placeholder; token_type_ids will not be used here for much longer
        all_label_ids = torch.tensor([f.label_ids for f in features], dtype=torch.long)
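        # The column order here must match the batch indexing in training_step/validation_step:
        # 0 = input_ids, 1 = attention_mask, 2 = token_type_ids, 3 = labels.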
        return DataLoader(
            TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_label_ids), batch_size=batch_size
        )

    def validation_step(self, batch, batch_nb):
        """Compute validation loss and predictions for one batch."""
        inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
        if self.config.model_type != "distilbert":
            inputs["token_type_ids"] = (
                batch[2] if self.config.model_type in ["bert", "xlnet"] else None
            )  # XLM and RoBERTa don't use token_type_ids
        outputs = self(**inputs)
        tmp_eval_loss, logits = outputs[:2]
        preds = logits.detach().cpu().numpy()
        out_label_ids = inputs["labels"].detach().cpu().numpy()
        return {"val_loss": tmp_eval_loss.detach().cpu(), "pred": preds, "target": out_label_ids}

    def _eval_end(self, outputs):
        """Evaluation logic shared by validation and test."""
        val_loss_mean = torch.stack([x["val_loss"] for x in outputs]).mean()
        preds = np.concatenate([x["pred"] for x in outputs], axis=0)
        preds = np.argmax(preds, axis=2)
        out_label_ids = np.concatenate([x["target"] for x in outputs], axis=0)

        label_map = {i: label for i, label in enumerate(self.labels)}
        out_label_list = [[] for _ in range(out_label_ids.shape[0])]
        preds_list = [[] for _ in range(out_label_ids.shape[0])]
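        # Rebuild per-example label sequences as strings, dropping padded/ignored positions,
        # since the seqeval metrics expect lists of label names rather than padded id matrices.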

        for i in range(out_label_ids.shape[0]):
            for j in range(out_label_ids.shape[1]):
                if out_label_ids[i, j] != self.pad_token_label_id:
                    out_label_list[i].append(label_map[out_label_ids[i, j]])
                    preds_list[i].append(label_map[preds[i, j]])

        results = {
            "val_loss": val_loss_mean,
            "accuracy_score": accuracy_score(out_label_list, preds_list),
            "precision": precision_score(out_label_list, preds_list),
            "recall": recall_score(out_label_list, preds_list),
            "f1": f1_score(out_label_list, preds_list),
        }

        ret = dict(results)
        ret["log"] = results
        return ret, preds_list, out_label_list

    def validation_epoch_end(self, outputs):
        # when stable
        ret, preds, targets = self._eval_end(outputs)
        logs = ret["log"]
        return {"val_loss": logs["val_loss"], "log": logs, "progress_bar": logs}

    def test_epoch_end(self, outputs):
        # uses test_epoch_end instead of the deprecated test_end hook
        ret, predictions, targets = self._eval_end(outputs)

        # Convert to the dict format expected by PyTorch Lightning
        # https://github.com/PyTorchLightning/pytorch-lightning/blob/master/\
        # pytorch_lightning/trainer/logging.py#L139
        logs = ret["log"]
        # `val_loss` is the key returned by `self._eval_end()` but actually refers to `test_loss`
        return {"avg_test_loss": logs["val_loss"], "log": logs, "progress_bar": logs}

    @staticmethod
    def add_model_specific_args(parser, root_dir):
        # Add NER specific options
        BaseTransformer.add_model_specific_args(parser, root_dir)
        parser.add_argument(
            "--task_type", default="NER", type=str, help="Task type to fine-tune on (e.g. NER, POS)"
        )
        parser.add_argument(
            "--max_seq_length",
            default=128,
            type=int,
            help="The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded.",
        )

        parser.add_argument(
            "--labels",
            default="",
            type=str,
            help="Path to a file containing all labels. If not specified, CoNLL-2003 labels are used.",
        )
        parser.add_argument(
            "--gpus",
            default=0,
            type=int,
            help="The number of GPUs to use for training; the default of 0 means CPU only.",
        )

        parser.add_argument(
            "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
        )

        return parser


if __name__ == "__main__":
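    # Example invocation (a sketch: the generic flags such as --data_dir, --model_name_or_path,
    # --output_dir and --do_train/--do_predict are assumed to be defined by add_generic_args
    # in lightning_base):
    #
    #   python run_pl_ner.py \
    #       --data_dir ./data \
    #       --labels ./data/labels.txt \
    #       --model_name_or_path bert-base-cased \
    #       --output_dir ./ner_output \
    #       --task_type NER \
    #       --max_seq_length 128 \
    #       --gpus 1 \
    #       --do_train \
    #       --do_predict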
    parser = argparse.ArgumentParser()
    add_generic_args(parser, os.getcwd())
    parser = NERTransformer.add_model_specific_args(parser, os.getcwd())
    args = parser.parse_args()
    model = NERTransformer(args)
    trainer = generic_train(model, args)

    if args.do_predict:
        # See https://github.com/huggingface/transformers/issues/3159
        # pl use this format to create a checkpoint:
        # https://github.com/PyTorchLightning/pytorch-lightning/blob/master\
        # /pytorch_lightning/callbacks/model_checkpoint.py#L169
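        # Note: the "checkpointepoch=*.ckpt" glob assumes the filename template used by the
        # ModelCheckpoint callback configured in generic_train; adjust it if that template changes.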
        checkpoints = sorted(glob.glob(os.path.join(args.output_dir, "checkpointepoch=*.ckpt"), recursive=True))
        model = model.load_from_checkpoint(checkpoints[-1])
        trainer.test(model)