#! -*- coding:utf-8 -*-
# Reproduction of the Top-1 solution for the Sohu 2022 entity sentiment classification competition: https://www.biendata.xyz/competition/sohu_2022/
# Write-up: https://zhuanlan.zhihu.com/p/533808475
# Approach: prompt-style concatenation, [CLS]+sentence+[SEP]+ent1+[MASK]+ent2+[MASK]+[SEP], then classify at each [MASK] position
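#
# Illustrative example of the scheme (entity names are made up):
#   content = 'Entity A posted strong earnings while Entity B fell sharply.'
#   entity  = {'Entity A': 1, 'Entity B': -1}
#   text_pair fed to the tokenizer: 'Entity A<mask>Entity B<mask>'
#   The classifier then predicts one of the 5 sentiment classes at each <mask> position.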

import numpy as np
import json
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from bert4torch.snippets import sequence_padding, Callback, ListDataset, text_segmentate, seed_everything
from bert4torch.optimizers import get_linear_schedule_with_warmup
from bert4torch.tokenizers import Tokenizer, SpTokenizer
from bert4torch.models import build_transformer_model, BaseModel
from tqdm import tqdm
import transformers
import random
from sklearn.metrics import f1_score, classification_report, accuracy_score
import warnings
warnings.filterwarnings("ignore")

# Configuration
pretrain_model = 'F:/Projects/pretrain_ckpt/xlnet/[hit_torch_base]--chinese-xlnet-base'
config_path = pretrain_model + '/bert4torch_config.json'
checkpoint_path = pretrain_model + '/pytorch_model.bin'
data_dir = 'E:/Github/Sohu2022/Sohu2022_data/nlp_data'

choice = 'train'
prefix = '_char_512'
save_path = f'./section1{prefix}.txt'
save_path_dev = f'./dev{prefix}.txt'
ckpt_path = f'./best_model{prefix}.pt'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
use_swa = False
use_adv_train = False

# Model settings
epochs = 10
steps_per_epoch = None
total_eval_step = None
num_warmup_steps = 4000
maxlen = 900
batch_size = 6
batch_size_eval = 64
grad_accumulation_steps = 3
categories = [-2, -1, 0, 1, 2]
mask_symbol = '<mask>'  # XLNet's mask token string; the tokenizer maps it to tokenizer.mask_token_id

seed_everything(19260817)  # fix random seeds for reproducibility

# Load the dataset
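# Assumed line format of train.txt, inferred from the fields read below:
#   {"content": "news article text ...", "entity": {"entA": 1, "entB": -2}}
# where each entity mention maps to a sentiment label from `categories`.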
def load_data(filename):
    D = []
    with open(filename, encoding='utf-8') as f:
        for l in tqdm(f.readlines(), desc="Loading data"):
            taskData = json.loads(l.strip())
            text2 = ''.join([ent + mask_symbol for ent in taskData['entity'].keys()])  # build 'ent1<mask>ent2<mask>...'
            D.append((taskData['content'], text2, taskData['entity']))
    return D

def search(tokens, search_token, start_idx=0):
    """Return the indices (offset by start_idx) of all occurrences of search_token in tokens."""
    mask_idxs = []
    for i in range(len(tokens)):
        if tokens[i] == search_token:
            mask_idxs.append(i+start_idx)
    return mask_idxs
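# Usage example: search([101, 9, 103, 9, 103, 102], 103) -> [2, 4]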


# Build the tokenizer; here we use the one shipped with the transformers library
tokenizer = transformers.XLNetTokenizerFast.from_pretrained(pretrain_model)

def collate_fn(batch):
    batch_token_ids, batch_segment_ids, batch_entity_ids, batch_entity_labels = [], [], [], []
    for text, prompt, entity in batch:
        inputs = tokenizer(text=text, text_pair=prompt, add_special_tokens=True, max_length=maxlen, truncation="only_first")
        token_ids, segment_ids = inputs['input_ids'], inputs['token_type_ids']
        ent_ids = search(token_ids, tokenizer.mask_token_id)

        batch_token_ids.append(token_ids)
        batch_segment_ids.append(segment_ids)
        batch_entity_ids.append(ent_ids)
        batch_entity_labels.append([categories.index(label) for label in entity.values()])

    batch_token_ids = torch.tensor(sequence_padding(batch_token_ids), dtype=torch.long, device=device)
    batch_segment_ids = torch.tensor(sequence_padding(batch_segment_ids), dtype=torch.long, device=device)
    batch_entity_ids = torch.tensor(sequence_padding(batch_entity_ids), dtype=torch.long, device=device)
    batch_entity_labels = torch.tensor(sequence_padding(batch_entity_labels, value=-1), dtype=torch.long, device=device)  # [btz, num_entities], padded with -1
    return [batch_token_ids, batch_segment_ids, batch_entity_ids], batch_entity_labels
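# Resulting batch shapes (for reference): batch_token_ids and batch_segment_ids are
# [btz, seq_len]; batch_entity_ids is [btz, num_entities], with 0 marking padded entity
# slots (sequence_padding pads with 0); batch_entity_labels is [btz, num_entities], padded with -1.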

# Build the train/dev split and dataloaders
all_data = load_data(f'{data_dir}/train.txt')
random.shuffle(all_data)
split_index = 2000  # int(len(all_data)*0.9)
train_dataloader = DataLoader(ListDataset(data=all_data[split_index:]), batch_size=batch_size, shuffle=False, collate_fn=collate_fn)  # data already shuffled above
valid_dataloader = DataLoader(ListDataset(data=all_data[:split_index]), batch_size=batch_size_eval, collate_fn=collate_fn)

# Define the model structure on top of the pretrained transformer
class Model(BaseModel):
    def __init__(self):
        super().__init__() 
        self.bert = build_transformer_model(config_path=config_path, checkpoint_path=checkpoint_path, model='xlnet')
        hidden_size = self.bert.configs['hidden_size']
        self.classifier = nn.Sequential(
                nn.Linear(hidden_size, hidden_size),
                nn.LeakyReLU(),
                nn.Dropout(0.1),
                nn.Linear(hidden_size, 5)
                )

    def forward(self, inputs):
        token_ids, segment_ids, entity_ids = inputs
        last_hidden_state = self.bert([token_ids, segment_ids])  # [btz, seq_len, hdsz]

        # Expand entity_ids to [btz, num_entities, hdsz] so torch.gather can pick the
        # hidden state at each <mask> position along the seq_len dimension
        entity_ids = entity_ids.unsqueeze(2).repeat(1, 1, last_hidden_state.shape[-1])
        entity_states = torch.gather(last_hidden_state, dim=1, index=entity_ids)  # [btz, num_entities, hdsz]
        entity_logits = self.classifier(entity_states)  # [btz, num_entities, 5]
        return entity_logits
model = Model().to(device)
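# Shape sanity check (illustrative): with token_ids/segment_ids of shape [btz, seq_len]
# and entity_ids of shape [btz, num_entities], model(inputs) returns [btz, num_entities, 5].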

class Loss(nn.CrossEntropyLoss):
    def forward(self, entity_logit, labels):
        # Flatten to [btz*num_entities, 5] vs [btz*num_entities] and compute cross entropy
        loss = super().forward(entity_logit.reshape(-1, entity_logit.shape[-1]), labels.flatten())
        return loss
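# Note: the loss is built with ignore_index=-1 below, matching the -1 padding of
# batch_entity_labels in collate_fn, so padded entity slots do not contribute to the loss.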
optimizer = optim.AdamW(model.parameters(), lr=5e-5)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps=len(train_dataloader)*epochs, last_epoch=-1)
model.compile(loss=Loss(ignore_index=-1), optimizer=optimizer, scheduler=scheduler, clip_grad_norm=1.0, adversarial_train={'name': 'fgm' if use_adv_train else ''})
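# Gradients are clipped to norm 1.0; FGM adversarial training is active only when use_adv_train is True.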

# Stochastic weight averaging (SWA)
if use_swa:
    def average_function(ax: torch.Tensor, x: torch.Tensor, num: int) -> torch.Tensor:
        # Running arithmetic mean: ax_{n+1} = ax_n + (x - ax_n) / (n + 1)
        return ax + (x - ax) / (num + 1)
    swa_model = torch.optim.swa_utils.AveragedModel(model, avg_fn=average_function)
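    # swa_model.update_parameters(model) is called once per epoch in Evaluator.on_epoch_end,
    # so swa_model tracks the arithmetic mean of the end-of-epoch weights.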

class Evaluator(Callback):
    """Evaluate on the dev set after each epoch and save the best checkpoint
    """
    def __init__(self):
        self.best_val_f1 = 0.

    def on_epoch_end(self, steps, epoch, logs=None):
        f1, acc, pred_result = self.evaluate(valid_dataloader)
        if f1 > self.best_val_f1:
            self.best_val_f1 = f1
            model.save_weights(ckpt_path)
        print(f'[val-entity] f1: {f1:.5f}, acc: {acc:.5f}, best_f1: {self.best_val_f1:.5f}\n')
        if use_swa:
            swa_model.update_parameters(model)

    @staticmethod
    def evaluate(data):
        valid_true, valid_pred = [], []
        eval_step = 0
        result = dict()
        for (token_ids, segment_ids, entity_ids), entity_labels in tqdm(data):
            if use_swa:
                swa_model.eval()
                with torch.no_grad():
                    entity_logit = F.softmax(swa_model([token_ids, segment_ids, entity_ids]), dim=-1)  # [btz, num_entities, num_classes]
            else:
                entity_logit = F.softmax(model.predict([token_ids, segment_ids, entity_ids]), dim=-1)  # [btz, num_entities, num_classes]
            _, entity_pred = torch.max(entity_logit, dim=-1)  # [btz, num_entities]
            # Collect predictions and labels only at real (non-padded) entity positions
            valid_index = (entity_ids.flatten() > 0).nonzero().squeeze(-1)
            valid_pred.extend(entity_pred.flatten()[valid_index].cpu().tolist())
            valid_true.extend(entity_labels.flatten()[valid_index].cpu().tolist())
                
            eval_step += 1
            if (total_eval_step is not None) and (eval_step >= total_eval_step):
                break
        
        valid_true = np.array(valid_true)
        valid_pred = np.array(valid_pred)
        f1 = f1_score(valid_true, valid_pred, average='macro')
        acc = accuracy_score(valid_true, valid_pred)
        print(classification_report(valid_true, valid_pred))
        # Keep only the label, not the probability (result is never populated in this version of the script)
        for k, v in result.items():
            result[k] = {i: j[0] for i, j in v.items()}
        return f1, acc, result

if __name__ == '__main__':
    if choice == 'train':
        evaluator = Evaluator()
        model.fit(train_dataloader, epochs=epochs, steps_per_epoch=steps_per_epoch, grad_accumulation_steps=grad_accumulation_steps, callbacks=[evaluator])

    model.load_weights(ckpt_path)
    f1, acc, pred_result = Evaluator.evaluate(valid_dataloader)