finetune_step3_train.py 11.7 KB
Newer Older
wangsen's avatar
wangsen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
import torch
from torch.utils.data import DataLoader
from model import uie_model, tokenizer
from bert4torch.snippets import seed_everything, sequence_padding, Callback
from torch import nn
from torch.utils.data import Dataset
import numpy as np
import json
from utils import get_bool_ids_greater_than, get_span
from random import sample


batch_size = 16
learning_rate = 1e-5
train_path = 'E:/Github/bert4torch/examples/sequence_labeling/uie/data/final_data/train.txt'
dev_path = 'E:/Github/bert4torch/examples/sequence_labeling/uie/data/final_data/dev.txt'
save_dir = './'
max_seq_len = 256
num_epochs = 10
device = 'cuda' if torch.cuda.is_available() else 'cpu'

seed_everything(42)
uie_model.to(device)

class IEDataset(Dataset):
    """信息抽取
    """
    def __init__(self, file_path, tokenizer, max_seq_len, fewshot=None) -> None:
        super().__init__()
        self.file_path = file_path
        if fewshot is None:
            self.dataset = list(self.reader(file_path))
        else:
            assert isinstance(fewshot, int)
            self.dataset = sample(list(self.reader(file_path)), fewshot)
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        return self.dataset[index]

    @staticmethod
    def reader(data_path, max_seq_len=512):
        """read json
        """
        with open(data_path, 'r', encoding='utf-8') as f:
            for line in f:
                json_line = json.loads(line)
                content = json_line['content']
                prompt = json_line['prompt']
                # Model Input is aslike: [CLS] Prompt [SEP] Content [SEP]
                # It include three summary tokens.
                if max_seq_len <= len(prompt) + 3:
                    raise ValueError("The value of max_seq_len is too small, please set a larger value")
                max_content_len = max_seq_len - len(prompt) - 3
                if len(content) <= max_content_len:
                    yield json_line
                else:
                    result_list = json_line['result_list']
                    json_lines = []
                    accumulate = 0
                    while True:
                        cur_result_list = []

                        for result in result_list:
                            if result['start'] + 1 <= max_content_len < result['end']:
                                max_content_len = result['start']
                                break

                        cur_content = content[:max_content_len]
                        res_content = content[max_content_len:]

                        while True:
                            if len(result_list) == 0:
                                break
                            elif result_list[0]['end'] <= max_content_len:
                                if result_list[0]['end'] > 0:
                                    cur_result = result_list.pop(0)
                                    cur_result_list.append(cur_result)
                                else:
                                    cur_result_list = [result for result in result_list]
                                    break
                            else:
                                break

                        json_line = {'content': cur_content, 'result_list': cur_result_list, 'prompt': prompt}
                        json_lines.append(json_line)

                        for result in result_list:
                            if result['end'] <= 0:
                                break
                            result['start'] -= max_content_len
                            result['end'] -= max_content_len
                        accumulate += max_content_len
                        max_content_len = max_seq_len - len(prompt) - 3
                        if len(res_content) == 0:
                            break
                        elif len(res_content) < max_content_len:
                            json_line = {'content': res_content, 'result_list': result_list, 'prompt': prompt}
                            json_lines.append(json_line)
                            break
                        else:
                            content = res_content

                    for json_line in json_lines:
                        yield json_line


def collate_fn(batch):
    """example: {title, prompt, content, result_list}
    """
    batch_token_ids, batch_token_type_ids, batch_start_ids, batch_end_ids = [], [], [], []
    for example in batch:
        token_ids, token_type_ids, offset_mapping = tokenizer.encode(example["prompt"], example["content"], 
                                                                    maxlen=max_seq_len, return_offsets='transformers')
        bias = 0
        for index in range(len(offset_mapping)):
            if index == 0:
                continue
            mapping = offset_mapping[index]
            if mapping[0] == 0 and mapping[1] == 0 and bias == 0:
                bias = index
            if mapping[0] == 0 and mapping[1] == 0:
                continue
            offset_mapping[index][0] += bias
            offset_mapping[index][1] += bias
        start_ids = [0 for _ in range(len(token_ids))]
        end_ids = [0 for _ in range(len(token_ids))]
        for item in example["result_list"]:
            start = map_offset(item["start"] + bias, offset_mapping)
            end = map_offset(item["end"] - 1 + bias, offset_mapping)
            start_ids[start] = 1.0
            end_ids[end] = 1.0

        batch_token_ids.append(token_ids)
        batch_token_type_ids.append(token_type_ids)
        batch_start_ids.append(start_ids)
        batch_end_ids.append(end_ids)
    
    batch_token_ids = torch.tensor(sequence_padding(batch_token_ids), dtype=torch.long, device=device)
    batch_token_type_ids = torch.tensor(sequence_padding(batch_token_type_ids), dtype=torch.long, device=device)
    batch_start_ids = torch.tensor(sequence_padding(batch_start_ids), dtype=torch.float, device=device)
    batch_end_ids = torch.tensor(sequence_padding(batch_end_ids), dtype=torch.float, device=device)
    return [batch_token_ids, batch_token_type_ids], [batch_start_ids, batch_end_ids]


def map_offset(ori_offset, offset_mapping):
    """map ori offset to token offset
    """
    for index, span in enumerate(offset_mapping):
        if span[0] <= ori_offset < span[1]:
            return index
    return -1

# 数据准备
train_ds = IEDataset(train_path, tokenizer=tokenizer, max_seq_len=max_seq_len, fewshot=None)
dev_ds = IEDataset(dev_path, tokenizer=tokenizer, max_seq_len=max_seq_len) 
train_dataloader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
valid_dataloader = DataLoader(dev_ds, batch_size=batch_size, collate_fn=collate_fn)

class MyLoss(nn.Module):
    def forward(self, y_pred, y_true):
        start_prob, end_prob = y_pred
        start_ids, end_ids = y_true
        loss_start = torch.nn.functional.binary_cross_entropy(start_prob, start_ids)
        loss_end = torch.nn.functional.binary_cross_entropy(end_prob, end_ids)
        return loss_start + loss_end

uie_model.compile(
    loss=MyLoss(),
    optimizer=torch.optim.AdamW(lr=learning_rate, params=uie_model.parameters()),
)

class SpanEvaluator(Callback):
    """SpanEvaluator computes the precision, recall and F1-score for span detection.
    """
    def __init__(self):
        self.num_infer_spans = 0
        self.num_label_spans = 0
        self.num_correct_spans = 0
        self.best_val_f1 = 0

    def on_epoch_end(self, steps, epoch, logs=None):
        f1, precision, recall = self.evaluate(valid_dataloader)
        if f1 > self.best_val_f1:
            self.best_val_f1 = f1
            # model.save_weights('best_model.pt')
        print(f'[val-entity level] f1: {f1:.5f}, p: {precision:.5f} r: {recall:.5f}')

    def evaluate(self, dataloder):
        self.reset()
        for x_true, y_true in dataloder:
            start_prob, end_prob = uie_model.predict(*x_true)
            start_ids, end_ids = y_true
            num_correct, num_infer, num_label = self.compute(start_prob, end_prob, start_ids, end_ids)
            self.update(num_correct, num_infer, num_label)
        precision, recall, f1 = self.accumulate()
        return f1, precision, recall

    def compute(self, start_probs, end_probs, gold_start_ids, gold_end_ids):
        """Computes the precision, recall and F1-score for span detection.
        """
        start_probs = start_probs.cpu().numpy()
        end_probs = end_probs.cpu().numpy()
        gold_start_ids = gold_start_ids.cpu().numpy()
        gold_end_ids = gold_end_ids.cpu().numpy()

        pred_start_ids = get_bool_ids_greater_than(start_probs)
        pred_end_ids = get_bool_ids_greater_than(end_probs)
        gold_start_ids = get_bool_ids_greater_than(gold_start_ids.tolist())
        gold_end_ids = get_bool_ids_greater_than(gold_end_ids.tolist())
        num_correct_spans = 0
        num_infer_spans = 0
        num_label_spans = 0
        for predict_start_ids, predict_end_ids, label_start_ids, label_end_ids in zip(
                pred_start_ids, pred_end_ids, gold_start_ids, gold_end_ids):
            [_correct, _infer, _label] = self.eval_span(predict_start_ids, predict_end_ids, label_start_ids, label_end_ids)
            num_correct_spans += _correct
            num_infer_spans += _infer
            num_label_spans += _label
        return num_correct_spans, num_infer_spans, num_label_spans

    def update(self, num_correct_spans, num_infer_spans, num_label_spans):
        """
        This function takes (num_infer_spans, num_label_spans, num_correct_spans) as input,
        to accumulate and update the corresponding status of the SpanEvaluator object.
        """
        self.num_infer_spans += num_infer_spans
        self.num_label_spans += num_label_spans
        self.num_correct_spans += num_correct_spans

    def eval_span(self, predict_start_ids, predict_end_ids, label_start_ids,
                  label_end_ids):
        """
        evaluate position extraction (start, end)
        return num_correct, num_infer, num_label
        input: [1, 2, 10] [4, 12] [2, 10] [4, 11]
        output: (1, 2, 2)
        """
        pred_set = get_span(predict_start_ids, predict_end_ids)
        label_set = get_span(label_start_ids, label_end_ids)
        num_correct = len(pred_set & label_set)
        num_infer = len(pred_set)
        num_label = len(label_set)
        return (num_correct, num_infer, num_label)

    def accumulate(self):
        """
        This function returns the mean precision, recall and f1 score for all accumulated minibatches.
        Returns:
            tuple: Returns tuple (`precision, recall, f1 score`).
        """
        precision = float(self.num_correct_spans / self.num_infer_spans) if self.num_infer_spans else 0.
        recall = float(self.num_correct_spans / self.num_label_spans) if self.num_label_spans else 0.
        f1_score = float(2 * precision * recall / (precision + recall)) if self.num_correct_spans else 0.
        return precision, recall, f1_score

    def reset(self):
        """
        Reset function empties the evaluation memory for previous mini-batches.
        """
        self.num_infer_spans = 0
        self.num_label_spans = 0
        self.num_correct_spans = 0


if __name__ == "__main__":
    evaluator = SpanEvaluator()
    print('zero_shot performance: ', evaluator.evaluate(valid_dataloader))
    uie_model.fit(train_dataloader, epochs=num_epochs, steps_per_epoch=None, callbacks=[evaluator])