Commit 0e29b9b7 authored by xuxo

yidong infer init
#! -*- coding:utf-8 -*-
# GlobalPointer for named entity recognition
# Dataset: http://s3.bmio.net/kashgari/china-people-daily-ner-corpus.tar.gz
# Blog: https://kexue.fm/archives/8373
# [valid_f1]: 95.66
import numpy as np
from bert4torch.models import build_transformer_model, BaseModel
import torch
from torch.utils.data import DataLoader
import torch.optim as optim
from bert4torch.snippets import sequence_padding, Callback, ListDataset, seed_everything
from bert4torch.tokenizers import Tokenizer
from bert4torch.losses import MultilabelCategoricalCrossentropy
from bert4torch.layers import GlobalPointer
import random
import os
maxlen = 256
batch_size = 16
categories_label2id = {"LOC": 0, "ORG": 1, "PER": 2}
categories_id2label = dict((value, key) for key,value in categories_label2id.items())
ner_vocab_size = len(categories_label2id)
ner_head_size = 64
# BERT base
config_path = 'F:/Projects/pretrain_ckpt/bert/[google_tf_base]--chinese_L-12_H-768_A-12/bert_config.json'
checkpoint_path = 'F:/Projects/pretrain_ckpt/bert/[google_tf_base]--chinese_L-12_H-768_A-12/pytorch_model.bin'
dict_path = 'F:/Projects/pretrain_ckpt/bert/[google_tf_base]--chinese_L-12_H-768_A-12/vocab.txt'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Fix the random seed
seed_everything(42)
# Load the dataset
class MyDataset(ListDataset):
@staticmethod
def load_data(filename):
data = []
with open(filename, encoding='utf-8') as f:
f = f.read()
for l in f.split('\n\n'):
if not l:
continue
text, label = '', []
for i, c in enumerate(l.split('\n')):
char, flag = c.split(' ')
text += char
if flag[0] == 'B':
label.append([i, i, flag[2:]])
elif flag[0] == 'I':
label[-1][1] = i
data.append((text, label)) # label is a list of [start, end, entity_type]
return data
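# A hedged, illustrative sketch of what load_data expects and returns (char-per-line BIO
# format, sentences separated by a blank line); the sentence below is an assumption, not
# taken from the corpus:
#   海 B-LOC
#   钓 O
#   比 O
#   赛 O
#   地 O
#   点 O
#   在 O
#   厦 B-LOC
#   门 I-LOC
# -> ('海钓比赛地点在厦门', [[0, 0, 'LOC'], [7, 8, 'LOC']]), i.e. char-level
# [start, end, entity_type] triples with an inclusive end index.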
# Build the tokenizer
tokenizer = Tokenizer(dict_path, do_lower_case=True)
def collate_fn(batch):
batch_token_ids, batch_labels = [], []
for i, (text, text_labels) in enumerate(batch):
tokens = tokenizer.tokenize(text, maxlen=maxlen)
mapping = tokenizer.rematch(text, tokens)
start_mapping = {j[0]: i for i, j in enumerate(mapping) if j}
end_mapping = {j[-1]: i for i, j in enumerate(mapping) if j}
token_ids = tokenizer.tokens_to_ids(tokens)
labels = np.zeros((len(categories_label2id), maxlen, maxlen))
for start, end, label in text_labels:
if start in start_mapping and end in end_mapping:
start = start_mapping[start]
end = end_mapping[end]
label = categories_label2id[label]
labels[label, start, end] = 1
batch_token_ids.append(token_ids) # length already truncated above
batch_labels.append(labels[:, :len(token_ids), :len(token_ids)])
batch_token_ids = torch.tensor(sequence_padding(batch_token_ids), dtype=torch.long, device=device)
batch_labels = torch.tensor(sequence_padding(batch_labels, seq_dims=3), dtype=torch.long, device=device)
return batch_token_ids, batch_labels
# Convert the dataset into dataloaders
train_dataloader = DataLoader(MyDataset('F:/Projects/data/corpus/ner/china-people-daily-ner-corpus/example.train'), batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
valid_dataloader = DataLoader(MyDataset('F:/Projects/data/corpus/ner/china-people-daily-ner-corpus/example.dev'), batch_size=batch_size, collate_fn=collate_fn)
# Define the model structure on top of BERT
class Model(BaseModel):
def __init__(self):
super().__init__()
self.bert = build_transformer_model(config_path=config_path, checkpoint_path=checkpoint_path, segment_vocab_size=0)
self.global_pointer = GlobalPointer(hidden_size=768, heads=ner_vocab_size, head_size=ner_head_size)
def forward(self, token_ids):
sequence_output = self.bert([token_ids]) # [btz, seq_len, hdsz]
logit = self.global_pointer(sequence_output, token_ids.gt(0).long())
return logit
model = Model().to(device)
class MyLoss(MultilabelCategoricalCrossentropy):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def forward(self, y_pred, y_true):
y_true = y_true.view(y_true.shape[0]*y_true.shape[1], -1) # [btz*ner_vocab_size, seq_len*seq_len]
y_pred = y_pred.view(y_pred.shape[0]*y_pred.shape[1], -1) # [btz*ner_vocab_size, seq_len*seq_len]
return super().forward(y_pred, y_true)
model.compile(loss=MyLoss(), optimizer=optim.Adam(model.parameters(), lr=2e-5))
def evaluate(data, threshold=0):
X, Y, Z = 0, 1e-10, 1e-10
for x_true, label in data:
scores = model.predict(x_true)
for i, score in enumerate(scores):
R = set()
for l, start, end in zip(*np.where(score.cpu() > threshold)):
R.add((start, end, categories_id2label[l]))
T = set()
for l, start, end in zip(*np.where(label[i].cpu() > threshold)):
T.add((start, end, categories_id2label[l]))
X += len(R & T)
Y += len(R)
Z += len(T)
f1, precision, recall = 2 * X / (Y + Z), X / Y, X / Z
return f1, precision, recall
class Evaluator(Callback):
"""Evaluation and saving
"""
def __init__(self):
self.best_val_f1 = 0.
def on_epoch_end(self, steps, epoch, logs=None):
f1, precision, recall = evaluate(valid_dataloader)
if f1 > self.best_val_f1:
self.best_val_f1 = f1
# model.save_weights('best_model.pt')
print(f'[val] f1: {f1:.5f}, p: {precision:.5f} r: {recall:.5f} best_f1: {self.best_val_f1:.5f}')
if __name__ == '__main__':
evaluator = Evaluator()
model.fit(train_dataloader, epochs=20, steps_per_epoch=None, callbacks=[evaluator])
else:
model.load_weights('best_model.pt')
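# Hedged inference sketch (the sentence is illustrative); decoding follows the same logic
# as evaluate() above, keeping every (type, start, end) cell whose score exceeds 0:
# text = '中国首都是北京'
# tokens = tokenizer.tokenize(text, maxlen=maxlen)
# token_ids = torch.tensor([tokenizer.tokens_to_ids(tokens)], dtype=torch.long, device=device)
# scores = model.predict(token_ids)[0]  # [heads, seq_len, seq_len]
# entities = [(start, end, categories_id2label[l]) for l, start, end in zip(*np.where(scores.cpu() > 0))]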
#! -*- coding:utf-8 -*-
# MRC (machine reading comprehension) approach for NER
# 数据集:http://s3.bmio.net/kashgari/china-people-daily-ner-corpus.tar.gz
# [valid_f1]: 95.75
import numpy as np
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from bert4torch.snippets import sequence_padding, Callback, ListDataset, seed_everything
from bert4torch.tokenizers import Tokenizer
from bert4torch.models import build_transformer_model, BaseModel
from tqdm import tqdm
from collections import defaultdict
max_c_len = 224
max_q_len = 32
batch_size = 6 # the effective batch size is batch_size * number of entity types
categories = ['LOC', 'PER', 'ORG']
ent2query = {"LOC": "找出下述句子中的地址名",
"PER": "找出下述句子中的人名",
"ORG": "找出下述句子中的机构名"}
# BERT base
config_path = 'F:/Projects/pretrain_ckpt/bert/[google_tf_base]--chinese_L-12_H-768_A-12/bert_config.json'
checkpoint_path = 'F:/Projects/pretrain_ckpt/bert/[google_tf_base]--chinese_L-12_H-768_A-12/pytorch_model.bin'
dict_path = 'F:/Projects/pretrain_ckpt/bert/[google_tf_base]--chinese_L-12_H-768_A-12/vocab.txt'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Fix the random seed
seed_everything(42)
# Load the dataset
class MyDataset(ListDataset):
@staticmethod
def load_data(filename):
D = []
with open(filename, encoding='utf-8') as f:
f = f.read()
for l in f.split('\n\n'):
if not l:
continue
d = ['']
for i, c in enumerate(l.split('\n')):
char, flag = c.split(' ')
d[0] += char
if flag[0] == 'B':
d.append([i, i, flag[2:]])
elif flag[0] == 'I':
d[-1][1] = i
D.append(d)
return D
# Build the tokenizer
tokenizer = Tokenizer(dict_path, do_lower_case=True)
def collate_fn(batch):
batch_token_ids, batch_segment_ids, batch_start_labels, batch_end_labels = [], [], [], []
batch_ent_type = []
for d in batch:
tokens_b = tokenizer.tokenize(d[0], maxlen=max_c_len)[1:] # drop [CLS]
mapping = tokenizer.rematch(d[0], tokens_b)
start_mapping = {j[0]: i for i, j in enumerate(mapping) if j}
end_mapping = {j[-1]: i for i, j in enumerate(mapping) if j}
# Group entities by entity type
label_dict = defaultdict(list)
for start, end, label in d[1:]:
if start in start_mapping and end in end_mapping:
start = start_mapping[start]
end = end_mapping[end]
label_dict[label].append((start, end))
# Iterate over entity types: the query is tokens_a, the context is tokens_b
# Sample composition: [CLS] + tokens_a + [SEP] + tokens_b + [SEP]
for _type in categories:
start_ids = [0] * len(tokens_b)
end_ids = [0] * len(tokens_b)
text_a = ent2query[_type]
tokens_a = tokenizer.tokenize(text_a, maxlen=max_q_len)
for _label in label_dict[_type]:
start_ids[_label[0]] = 1
end_ids[_label[1]] = 1
start_ids = [0] * len(tokens_a) + start_ids
end_ids = [0] * len(tokens_a) + end_ids
token_ids = tokenizer.tokens_to_ids(tokens_a) + tokenizer.tokens_to_ids(tokens_b)
segment_ids = [0] * len(tokens_a) + [1] * len(tokens_b)
assert len(start_ids) == len(end_ids) == len(token_ids) == len(segment_ids)
batch_token_ids.append(token_ids)
batch_segment_ids.append(segment_ids)
batch_start_labels.append(start_ids)
batch_end_labels.append(end_ids)
batch_ent_type.append(_type)
batch_token_ids = torch.tensor(sequence_padding(batch_token_ids), dtype=torch.long, device=device)
batch_segment_ids = torch.tensor(sequence_padding(batch_segment_ids), dtype=torch.long, device=device)
batch_start_labels = torch.tensor(sequence_padding(batch_start_labels), dtype=torch.long, device=device)
batch_end_labels = torch.tensor(sequence_padding(batch_end_labels), dtype=torch.long, device=device)
return [batch_token_ids, batch_segment_ids], [batch_segment_ids, batch_start_labels, batch_end_labels, batch_ent_type]
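# Hedged illustration of one (query, context) sample built above (token strings shown
# instead of ids; the context sentence is an assumption):
#   tokens:       [CLS] 找 出 下 述 句 子 中 的 人 名 [SEP] | 周 恩 来 访 问 印 度 [SEP]
#   segment_ids:    0   0  0  0  0  0  0  0  0  0  0   0   |  1  1  1  1  1  1  1   1
#   start/end ids are all zero on the query part and mark entity boundaries only inside
#   the context part (here the PER span 周恩来 -> start at 周, end at 来).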
# Convert the dataset into dataloaders
train_dataloader = DataLoader(MyDataset('F:/Projects/data/corpus/ner/china-people-daily-ner-corpus/example.train'), batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
valid_dataloader = DataLoader(MyDataset('F:/Projects/data/corpus/ner/china-people-daily-ner-corpus/example.dev'), batch_size=batch_size, collate_fn=collate_fn)
# Define the model structure on top of BERT
class Model(BaseModel):
def __init__(self):
super().__init__()
self.bert = build_transformer_model(config_path=config_path, checkpoint_path=checkpoint_path)
self.mid_linear = nn.Sequential(
nn.Linear(768, 128),
nn.ReLU(),
nn.Dropout(0.1)
)
self.start_fc = nn.Linear(128, 2)
self.end_fc = nn.Linear(128, 2)
def forward(self, token_ids, segment_ids):
sequence_output = self.bert([token_ids, segment_ids]) # [bts, seq_len, hdsz]
seq_out = self.mid_linear(sequence_output) # [bts, seq_len, mid_dims]
start_logits = self.start_fc(seq_out) # [bts, seq_len, 2]
end_logits = self.end_fc(seq_out) # [bts, seq_len, 2]
return start_logits, end_logits
model = Model().to(device)
class Loss(nn.CrossEntropyLoss):
def forward(self, outputs, labels):
start_logits, end_logits = outputs
mask, start_ids, end_ids = labels[:3]
start_logits = start_logits.view(-1, 2)
end_logits = end_logits.view(-1, 2)
# Drop the labels on text_a and padding positions and compute the loss on real tokens only
active_loss = mask.view(-1) == 1
active_start_logits = start_logits[active_loss]
active_end_logits = end_logits[active_loss]
active_start_labels = start_ids.view(-1)[active_loss]
active_end_labels = end_ids.view(-1)[active_loss]
start_loss = super().forward(active_start_logits, active_start_labels)
end_loss = super().forward(active_end_logits, active_end_labels)
return start_loss + end_loss
model.compile(loss=Loss(), optimizer=optim.Adam(model.parameters(), lr=2e-5))
def evaluate(data):
X, Y, Z = 0, 1e-10, 1e-10
for (token_ids, segment_ids), labels in tqdm(data, desc='Evaluation'):
start_logit, end_logit = model.predict([token_ids, segment_ids]) # [btz, seq_len, 2]
mask, start_ids, end_ids, ent_type = labels
# entity-level comparison
entity_pred = mrc_decode(start_logit, end_logit, ent_type, mask)
entity_true = mrc_decode(start_ids, end_ids, ent_type)
X += len(entity_pred.intersection(entity_true))
Y += len(entity_pred)
Z += len(entity_true)
f1, precision, recall = 2 * X / (Y + Z), X/ Y, X / Z
return f1, precision, recall
# Strict decoding baseline
def mrc_decode(start_preds, end_preds, ent_type, mask=None):
'''Return the start and end positions of entities
'''
predict_entities = set()
if mask is not None: # mask out the query and padding part of the predictions
start_preds = torch.argmax(start_preds, -1) * mask
end_preds = torch.argmax(end_preds, -1) * mask
start_preds = start_preds.cpu().numpy()
end_preds = end_preds.cpu().numpy()
for bt_i in range(start_preds.shape[0]):
start_pred = start_preds[bt_i]
end_pred = end_preds[bt_i]
# Collect the results for each sample
for i, s_type in enumerate(start_pred):
if s_type == 0:
continue
for j, e_type in enumerate(end_pred[i:]):
if s_type == e_type:
# [sample id, entity start, entity end, entity type]
predict_entities.add((bt_i, i, i+j, ent_type[bt_i]))
break
return predict_entities
class Evaluator(Callback):
"""Evaluation and saving
"""
def __init__(self):
self.best_val_f1 = 0.
def on_epoch_end(self, steps, epoch, logs=None):
f1, precision, recall = evaluate(valid_dataloader)
if f1 > self.best_val_f1:
self.best_val_f1 = f1
# model.save_weights('best_model.pt')
print(f'[val] f1: {f1:.5f}, p: {precision:.5f} r: {recall:.5f} best_f1: {self.best_val_f1:.5f}')
if __name__ == '__main__':
evaluator = Evaluator()
model.fit(train_dataloader, epochs=20, steps_per_epoch=None, callbacks=[evaluator])
else:
model.load_weights('best_model.pt')
#! -*- coding:utf-8 -*-
# Span-pointer (reading-comprehension style) approach for NER
# 数据集:http://s3.bmio.net/kashgari/china-people-daily-ner-corpus.tar.gz
# [valid_f1]: 96.31
import numpy as np
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from bert4torch.snippets import sequence_padding, Callback, ListDataset, seed_everything
from bert4torch.losses import FocalLoss
from bert4torch.tokenizers import Tokenizer
from bert4torch.models import build_transformer_model, BaseModel
from tqdm import tqdm
max_len = 256
batch_size = 16
categories = ['LOC', 'PER', 'ORG']
categories_id2label = {i: k for i, k in enumerate(categories, start=1)}
categories_label2id = {k: i for i, k in enumerate(categories, start=1)}
# BERT base
config_path = 'F:/Projects/pretrain_ckpt/bert/[google_tf_base]--chinese_L-12_H-768_A-12/bert_config.json'
checkpoint_path = 'F:/Projects/pretrain_ckpt/bert/[google_tf_base]--chinese_L-12_H-768_A-12/pytorch_model.bin'
dict_path = 'F:/Projects/pretrain_ckpt/bert/[google_tf_base]--chinese_L-12_H-768_A-12/vocab.txt'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Fix the random seed
seed_everything(42)
# Load the dataset
class MyDataset(ListDataset):
@staticmethod
def load_data(filename):
D = []
with open(filename, encoding='utf-8') as f:
f = f.read()
for l in f.split('\n\n'):
if not l:
continue
d = ['']
for i, c in enumerate(l.split('\n')):
char, flag = c.split(' ')
d[0] += char
if flag[0] == 'B':
d.append([i, i, flag[2:]])
elif flag[0] == 'I':
d[-1][1] = i
D.append(d)
return D
# Build the tokenizer
tokenizer = Tokenizer(dict_path, do_lower_case=True)
def collate_fn(batch):
batch_token_ids, batch_start_labels, batch_end_labels = [], [], []
for d in batch:
tokens = tokenizer.tokenize(d[0], maxlen=max_len)[1:] # drop [CLS]
mapping = tokenizer.rematch(d[0], tokens)
start_mapping = {j[0]: i for i, j in enumerate(mapping) if j}
end_mapping = {j[-1]: i for i, j in enumerate(mapping) if j}
token_ids = tokenizer.tokens_to_ids(tokens)
start_ids = [0] * len(tokens)
end_ids = [0] * len(tokens)
for start, end, label in d[1:]:
if start in start_mapping and end in end_mapping:
start = start_mapping[start]
end = end_mapping[end]
start_ids[start] = categories_label2id[label]
end_ids[end] = categories_label2id[label]
batch_token_ids.append(token_ids)
batch_start_labels.append(start_ids)
batch_end_labels.append(end_ids)
batch_token_ids = torch.tensor(sequence_padding(batch_token_ids), dtype=torch.long, device=device)
batch_start_labels = torch.tensor(sequence_padding(batch_start_labels), dtype=torch.long, device=device)
batch_end_labels = torch.tensor(sequence_padding(batch_end_labels), dtype=torch.long, device=device)
batch_mask = batch_token_ids.gt(0).long()
return [batch_token_ids], [batch_mask, batch_start_labels, batch_end_labels]
# Convert the dataset into dataloaders
train_dataloader = DataLoader(MyDataset('F:/Projects/data/corpus/ner/china-people-daily-ner-corpus/example.train'), batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
valid_dataloader = DataLoader(MyDataset('F:/Projects/data/corpus/ner/china-people-daily-ner-corpus/example.dev'), batch_size=batch_size, collate_fn=collate_fn)
# Define the model structure on top of BERT
class Model(BaseModel):
def __init__(self):
super().__init__()
self.bert = build_transformer_model(config_path=config_path, checkpoint_path=checkpoint_path, segment_vocab_size=0)
self.mid_linear = nn.Sequential(
nn.Linear(768, 128),
nn.ReLU(),
nn.Dropout(0.1)
)
self.start_fc = nn.Linear(128, len(categories)+1) # 0 means no entity
self.end_fc = nn.Linear(128, len(categories)+1)
def forward(self, token_ids):
sequence_output = self.bert(token_ids) # [bts, seq_len, hdsz]
seq_out = self.mid_linear(sequence_output) # [bts, seq_len, mid_dims]
start_logits = self.start_fc(seq_out) # [bts, seq_len, num_tags]
end_logits = self.end_fc(seq_out) # [bts, seq_len, num_tags]
return start_logits, end_logits
model = Model().to(device)
class Loss(nn.CrossEntropyLoss):
def forward(self, outputs, labels):
start_logits, end_logits = outputs
mask, start_ids, end_ids = labels
start_logits = start_logits.view(-1, len(categories)+1)
end_logits = end_logits.view(-1, len(categories)+1)
# Drop the labels on padding positions and compute the loss on real tokens only
active_loss = mask.view(-1) == 1
active_start_logits = start_logits[active_loss]
active_end_logits = end_logits[active_loss]
active_start_labels = start_ids.view(-1)[active_loss]
active_end_labels = end_ids.view(-1)[active_loss]
start_loss = super().forward(active_start_logits, active_start_labels)
end_loss = super().forward(active_end_logits, active_end_labels)
return start_loss + end_loss
model.compile(loss=Loss(), optimizer=optim.Adam(model.parameters(), lr=2e-5))
def evaluate(data):
X, Y, Z = 0, 1e-10, 1e-10
for token_ids, labels in tqdm(data, desc='Evaluation'):
start_logit, end_logit = model.predict(token_ids) # [btz, seq_len, num_tags]
mask, start_ids, end_ids = labels
# entity-level comparison
entity_pred = span_decode(start_logit, end_logit, mask)
entity_true = span_decode(start_ids, end_ids)
X += len(entity_pred.intersection(entity_true))
Y += len(entity_pred)
Z += len(entity_true)
f1, precision, recall = 2 * X / (Y + Z), X/ Y, X / Z
return f1, precision, recall
# Strict decoding baseline
def span_decode(start_preds, end_preds, mask=None):
'''Return the start and end positions of entities
'''
predict_entities = set()
if mask is not None: # mask out the padding part
start_preds = torch.argmax(start_preds, -1) * mask
end_preds = torch.argmax(end_preds, -1) * mask
start_preds = start_preds.cpu().numpy()
end_preds = end_preds.cpu().numpy()
for bt_i in range(start_preds.shape[0]):
start_pred = start_preds[bt_i]
end_pred = end_preds[bt_i]
# Collect the results for each sample
for i, s_type in enumerate(start_pred):
if s_type == 0:
continue
for j, e_type in enumerate(end_pred[i:]):
if s_type == e_type:
# [sample id, entity start, entity end, entity type]
predict_entities.add((bt_i, i, i+j, categories_id2label[s_type]))
break
return predict_entities
class Evaluator(Callback):
"""Evaluation and saving
"""
def __init__(self):
self.best_val_f1 = 0.
def on_epoch_end(self, steps, epoch, logs=None):
f1, precision, recall = evaluate(valid_dataloader)
if f1 > self.best_val_f1:
self.best_val_f1 = f1
# model.save_weights('best_model.pt')
print(f'[val] f1: {f1:.5f}, p: {precision:.5f} r: {recall:.5f} best_f1: {self.best_val_f1:.5f}')
if __name__ == '__main__':
evaluator = Evaluator()
model.fit(train_dataloader, epochs=20, steps_per_epoch=None, callbacks=[evaluator])
else:
model.load_weights('best_model.pt')
#! -*- coding:utf-8 -*-
# tplinker_plus for named entity recognition
# [valid_f1]: 95.71
import numpy as np
from bert4torch.models import build_transformer_model, BaseModel
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from bert4torch.snippets import sequence_padding, Callback, ListDataset, seed_everything
from bert4torch.tokenizers import Tokenizer
from bert4torch.losses import MultilabelCategoricalCrossentropy
from bert4torch.layers import TplinkerHandshakingKernel
maxlen = 64
batch_size = 16
categories_label2id = {"LOC": 0, "ORG": 1, "PER": 2}
categories_id2label = dict((value, key) for key,value in categories_label2id.items())
# BERT base
config_path = 'F:/Projects/pretrain_ckpt/bert/[google_tf_base]--chinese_L-12_H-768_A-12/bert_config.json'
checkpoint_path = 'F:/Projects/pretrain_ckpt/bert/[google_tf_base]--chinese_L-12_H-768_A-12/pytorch_model.bin'
dict_path = 'F:/Projects/pretrain_ckpt/bert/[google_tf_base]--chinese_L-12_H-768_A-12/vocab.txt'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Fix the random seed
seed_everything(42)
# Load the dataset
class MyDataset(ListDataset):
@staticmethod
def load_data(filename):
data = []
with open(filename, encoding='utf-8') as f:
f = f.read()
for l in f.split('\n\n'):
if not l:
continue
text, label = '', []
for i, c in enumerate(l.split('\n')):
char, flag = c.split(' ')
text += char
if flag[0] == 'B':
label.append([i, i, flag[2:]])
elif flag[0] == 'I':
label[-1][1] = i
text_list = tokenizer.tokenize(text)[1:-1] # drop the leading [CLS] and trailing [SEP]
tokens = [j for i in text_list for j in i][:maxlen] # treat each char as a unit
data.append((tokens, label)) # label is a list of [start, end, entity_type]
return data
def trans_ij2k(seq_len, i, j):
'''Convert row i, column j into the index of the flattened upper-triangular matrix
'''
if (i > seq_len - 1) or (j > seq_len - 1) or (i > j):
return 0
return int(0.5*(2*seq_len-i+1)*i+(j-i))
map_ij2k = {(i, j): trans_ij2k(maxlen, i, j) for i in range(maxlen) for j in range(maxlen) if j >= i}
map_k2ij = {v: k for k, v in map_ij2k.items()}
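# Hedged sanity check of the flattening: row i of the upper triangle occupies
# (maxlen - i) consecutive slots, so with maxlen=64:
# assert map_ij2k[(0, 0)] == 0 and map_ij2k[(0, 63)] == 63
# assert map_ij2k[(1, 1)] == 64 and map_ij2k[(1, 2)] == 65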
def tran_ent_rel2id():
'''Get the tag mapping for the final classification layer
'''
tag2id = {}
for p in categories_label2id.keys():
tag2id[p] = len(tag2id)
return tag2id
tag2id = tran_ent_rel2id()
id2tag = {v: k for k, v in tag2id.items()}
# Build the tokenizer
tokenizer = Tokenizer(dict_path, do_lower_case=True)
def collate_fn(batch):
pair_len = maxlen * (maxlen+1)//2
# batch_head_labels: [btz, pair_len, tag2id_len]
batch_labels = torch.zeros((len(batch), pair_len, len(tag2id)), dtype=torch.long, device=device)
batch_token_ids = []
for i, (tokens, labels) in enumerate(batch):
batch_token_ids.append(tokenizer.tokens_to_ids(tokens)) # length already truncated above
for s_i in labels:
if s_i[1] >= len(tokens): # skip entities whose end exceeds the text length
continue
batch_labels[i, map_ij2k[s_i[0], s_i[1]], tag2id[s_i[2]]] = 1
batch_token_ids = torch.tensor(sequence_padding(batch_token_ids, length=maxlen), dtype=torch.long, device=device)
return [batch_token_ids], batch_labels
# Convert the dataset into dataloaders
train_dataloader = DataLoader(MyDataset('F:/Projects/data/corpus/ner/china-people-daily-ner-corpus/example.train'), batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
valid_dataloader = DataLoader(MyDataset('F:/Projects/data/corpus/ner/china-people-daily-ner-corpus/example.dev'), batch_size=batch_size, collate_fn=collate_fn)
# Define the model structure on top of BERT
class Model(BaseModel):
def __init__(self):
super().__init__()
self.bert = build_transformer_model(config_path=config_path, checkpoint_path=checkpoint_path, segment_vocab_size=0)
self.fc = nn.Linear(768, len(tag2id))
self.handshaking_kernel = TplinkerHandshakingKernel(768, shaking_type='cln_plus', inner_enc_type='lstm')
def forward(self, inputs):
last_hidden_state = self.bert(inputs) # [btz, seq_len, hdsz]
shaking_hiddens = self.handshaking_kernel(last_hidden_state)
output = self.fc(shaking_hiddens) # [btz, pair_len, tag_size]
return output
model = Model().to(device)
model.compile(loss=MultilabelCategoricalCrossentropy(), optimizer=optim.Adam(model.parameters(), lr=2e-5))
def evaluate(data, threshold=0):
X, Y, Z = 0, 1e-10, 1e-10
for x_true, label in data:
scores = model.predict(x_true) # [btz, pair_len, tag_size]
for i, score in enumerate(scores):
R = set()
for pair_id, tag_id in zip(*np.where(score.cpu().numpy() > threshold)):
start, end = map_k2ij[pair_id][0], map_k2ij[pair_id][1]
R.add((start, end, tag_id))
T = set()
for pair_id, tag_id in zip(*np.where(label[i].cpu().numpy() > threshold)):
start, end = map_k2ij[pair_id][0], map_k2ij[pair_id][1]
T.add((start, end, tag_id))
X += len(R & T)
Y += len(R)
Z += len(T)
f1, precision, recall = 2 * X / (Y + Z), X / Y, X / Z
return f1, precision, recall
class Evaluator(Callback):
"""Evaluation and saving
"""
def __init__(self):
self.best_val_f1 = 0.
def on_epoch_end(self, steps, epoch, logs=None):
f1, precision, recall = evaluate(valid_dataloader)
if f1 > self.best_val_f1:
self.best_val_f1 = f1
# model.save_weights('best_model.pt')
print(f'[val] f1: {f1:.5f}, p: {precision:.5f} r: {recall:.5f} best_f1: {self.best_val_f1:.5f}')
if __name__ == '__main__':
evaluator = Evaluator()
model.fit(train_dataloader, epochs=20, steps_per_epoch=None, callbacks=[evaluator])
else:
model.load_weights('best_model.pt')
import argparse
import collections
import json
import os
import pickle
import torch
import logging
import shutil
from tqdm import tqdm
import time
logger = logging.Logger('log')
def get_path_from_url(url, root_dir, check_exist=True, decompress=True):
""" Download from given url to root_dir.
If the file or directory specified by url already exists under
root_dir, return the path directly; otherwise download it from
url, decompress it if needed, and return the path.
Args:
url (str): download url
root_dir (str): root dir for downloading, it should be
WEIGHTS_HOME or DATASET_HOME
decompress (bool): decompress zip or tar file. Default is `True`
Returns:
str: a local path to save downloaded models & weights & datasets.
"""
import os.path
import os
import tarfile
import zipfile
def is_url(path):
"""
Whether path is URL.
Args:
path (string): URL string or not.
"""
return path.startswith('http://') or path.startswith('https://')
def _map_path(url, root_dir):
# parse path after download under root_dir
fname = os.path.split(url)[-1]
fpath = fname
return os.path.join(root_dir, fpath)
def _get_download(url, fullname):
import requests
# using requests.get method
fname = os.path.basename(fullname)
try:
req = requests.get(url, stream=True)
except Exception as e: # requests.exceptions.ConnectionError
logger.info("Downloading {} from {} failed with exception {}".format(
fname, url, str(e)))
return False
if req.status_code != 200:
raise RuntimeError("Downloading from {} failed with code "
"{}!".format(url, req.status_code))
# To protect against interrupted downloads, download to
# tmp_fullname first and move tmp_fullname to fullname
# once the download has finished
tmp_fullname = fullname + "_tmp"
total_size = req.headers.get('content-length')
with open(tmp_fullname, 'wb') as f:
if total_size:
with tqdm(total=(int(total_size) + 1023) // 1024, unit='KB') as pbar:
for chunk in req.iter_content(chunk_size=1024):
f.write(chunk)
pbar.update(1)
else:
for chunk in req.iter_content(chunk_size=1024):
if chunk:
f.write(chunk)
shutil.move(tmp_fullname, fullname)
return fullname
def _download(url, path):
"""
Download from url, save to path.
url (str): download url
path (str): download to given path
"""
if not os.path.exists(path):
os.makedirs(path)
fname = os.path.split(url)[-1]
fullname = os.path.join(path, fname)
retry_cnt = 0
logger.info("Downloading {} from {}".format(fname, url))
DOWNLOAD_RETRY_LIMIT = 3
while not os.path.exists(fullname):
if retry_cnt < DOWNLOAD_RETRY_LIMIT:
retry_cnt += 1
else:
raise RuntimeError("Download from {} failed. "
"Retry limit reached".format(url))
if not _get_download(url, fullname):
time.sleep(1)
continue
return fullname
def _uncompress_file_zip(filepath):
with zipfile.ZipFile(filepath, 'r') as files:
file_list = files.namelist()
file_dir = os.path.dirname(filepath)
if _is_a_single_file(file_list):
rootpath = file_list[0]
uncompressed_path = os.path.join(file_dir, rootpath)
files.extractall(file_dir)
elif _is_a_single_dir(file_list):
# `strip(os.sep)` to remove `os.sep` in the tail of path
rootpath = os.path.splitext(file_list[0].strip(os.sep))[0].split(
os.sep)[-1]
uncompressed_path = os.path.join(file_dir, rootpath)
files.extractall(file_dir)
else:
rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1]
uncompressed_path = os.path.join(file_dir, rootpath)
if not os.path.exists(uncompressed_path):
os.makedirs(uncompressed_path)
files.extractall(os.path.join(file_dir, rootpath))
return uncompressed_path
def _is_a_single_file(file_list):
if len(file_list) == 1 and file_list[0].find(os.sep) < 0:
return True
return False
def _is_a_single_dir(file_list):
new_file_list = []
for file_path in file_list:
if '/' in file_path:
file_path = file_path.replace('/', os.sep)
elif '\\' in file_path:
file_path = file_path.replace('\\', os.sep)
new_file_list.append(file_path)
file_name = new_file_list[0].split(os.sep)[0]
for i in range(1, len(new_file_list)):
if file_name != new_file_list[i].split(os.sep)[0]:
return False
return True
def _uncompress_file_tar(filepath, mode="r:*"):
with tarfile.open(filepath, mode) as files:
file_list = files.getnames()
file_dir = os.path.dirname(filepath)
if _is_a_single_file(file_list):
rootpath = file_list[0]
uncompressed_path = os.path.join(file_dir, rootpath)
files.extractall(file_dir)
elif _is_a_single_dir(file_list):
rootpath = os.path.splitext(file_list[0].strip(os.sep))[0].split(
os.sep)[-1]
uncompressed_path = os.path.join(file_dir, rootpath)
files.extractall(file_dir)
else:
rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1]
uncompressed_path = os.path.join(file_dir, rootpath)
if not os.path.exists(uncompressed_path):
os.makedirs(uncompressed_path)
files.extractall(os.path.join(file_dir, rootpath))
return uncompressed_path
def _decompress(fname):
"""
Decompress for zip and tar file
"""
logger.info("Decompressing {}...".format(fname))
# To protect against interrupted decompression,
# decompress into an fpath_tmp directory first; if decompression
# succeeds, move the decompressed files to fpath, delete
# fpath_tmp and remove the downloaded archive.
if tarfile.is_tarfile(fname):
uncompressed_path = _uncompress_file_tar(fname)
elif zipfile.is_zipfile(fname):
uncompressed_path = _uncompress_file_zip(fname)
else:
raise TypeError("Unsupported compressed file type {}".format(fname))
return uncompressed_path
assert is_url(url), "downloading from {} not a url".format(url)
fullpath = _map_path(url, root_dir)
if os.path.exists(fullpath) and check_exist:
logger.info("Found {}".format(fullpath))
else:
fullpath = _download(url, root_dir)
if decompress and (tarfile.is_tarfile(fullpath) or
zipfile.is_zipfile(fullpath)):
fullpath = _decompress(fullpath)
return fullpath
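# Hedged usage sketch (the target directory name is an assumption, not part of this script):
# local_path = get_path_from_url(
#     MODEL_MAP['uie-base']['resource_file_urls']['vocab_file'], root_dir='./uie-base')
# downloads vocab.txt into ./uie-base/ and returns its local path; no decompression is
# attempted because the file is neither a tar nor a zip archive.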
MODEL_MAP = {
"uie-base": {
"resource_file_urls": {
"model_state.pdparams":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base_v0.1/model_state.pdparams",
"model_config.json":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/model_config.json",
"vocab_file":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/vocab.txt",
"special_tokens_map":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/special_tokens_map.json",
"tokenizer_config":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/tokenizer_config.json"
}
},
"uie-medium": {
"resource_file_urls": {
"model_state.pdparams":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_medium_v1.0/model_state.pdparams",
"model_config.json":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_medium/model_config.json",
"vocab_file":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/vocab.txt",
"special_tokens_map":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/special_tokens_map.json",
"tokenizer_config":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/tokenizer_config.json",
}
},
"uie-mini": {
"resource_file_urls": {
"model_state.pdparams":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_mini_v1.0/model_state.pdparams",
"model_config.json":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_mini/model_config.json",
"vocab_file":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/vocab.txt",
"special_tokens_map":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/special_tokens_map.json",
"tokenizer_config":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/tokenizer_config.json",
}
},
"uie-micro": {
"resource_file_urls": {
"model_state.pdparams":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_micro_v1.0/model_state.pdparams",
"model_config.json":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_micro/model_config.json",
"vocab_file":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/vocab.txt",
"special_tokens_map":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/special_tokens_map.json",
"tokenizer_config":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/tokenizer_config.json",
}
},
"uie-nano": {
"resource_file_urls": {
"model_state.pdparams":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_nano_v1.0/model_state.pdparams",
"model_config.json":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_nano/model_config.json",
"vocab_file":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/vocab.txt",
"special_tokens_map":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/special_tokens_map.json",
"tokenizer_config":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/tokenizer_config.json",
}
},
"uie-medical-base": {
"resource_file_urls": {
"model_state.pdparams":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_medical_base_v0.1/model_state.pdparams",
"model_config.json":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/model_config.json",
"vocab_file":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/vocab.txt",
"special_tokens_map":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/special_tokens_map.json",
"tokenizer_config":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/tokenizer_config.json",
}
},
"uie-tiny": {
"resource_file_urls": {
"model_state.pdparams":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_tiny_v0.1/model_state.pdparams",
"model_config.json":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_tiny/model_config.json",
"vocab_file":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_tiny/vocab.txt",
"special_tokens_map":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_tiny/special_tokens_map.json",
"tokenizer_config":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_tiny/tokenizer_config.json"
}
}
}
def build_params_map(attention_num=12):
"""
build params map from paddle-paddle's ERNIE to transformer's BERT
:return:
"""
weight_map = collections.OrderedDict({
'encoder.embeddings.word_embeddings.weight': "bert.embeddings.word_embeddings.weight",
'encoder.embeddings.position_embeddings.weight': "bert.embeddings.position_embeddings.weight",
'encoder.embeddings.token_type_embeddings.weight': "bert.embeddings.token_type_embeddings.weight",
'encoder.embeddings.task_type_embeddings.weight': "embeddings.task_type_embeddings.weight", # no 'bert.' prefix here; maps directly onto the bert4torch structure
'encoder.embeddings.layer_norm.weight': 'bert.embeddings.LayerNorm.weight',
'encoder.embeddings.layer_norm.bias': 'bert.embeddings.LayerNorm.bias',
})
# add attention layers
for i in range(attention_num):
weight_map[f'encoder.encoder.layers.{i}.self_attn.q_proj.weight'] = f'bert.encoder.layer.{i}.attention.self.query.weight'
weight_map[f'encoder.encoder.layers.{i}.self_attn.q_proj.bias'] = f'bert.encoder.layer.{i}.attention.self.query.bias'
weight_map[f'encoder.encoder.layers.{i}.self_attn.k_proj.weight'] = f'bert.encoder.layer.{i}.attention.self.key.weight'
weight_map[f'encoder.encoder.layers.{i}.self_attn.k_proj.bias'] = f'bert.encoder.layer.{i}.attention.self.key.bias'
weight_map[f'encoder.encoder.layers.{i}.self_attn.v_proj.weight'] = f'bert.encoder.layer.{i}.attention.self.value.weight'
weight_map[f'encoder.encoder.layers.{i}.self_attn.v_proj.bias'] = f'bert.encoder.layer.{i}.attention.self.value.bias'
weight_map[f'encoder.encoder.layers.{i}.self_attn.out_proj.weight'] = f'bert.encoder.layer.{i}.attention.output.dense.weight'
weight_map[f'encoder.encoder.layers.{i}.self_attn.out_proj.bias'] = f'bert.encoder.layer.{i}.attention.output.dense.bias'
weight_map[f'encoder.encoder.layers.{i}.norm1.weight'] = f'bert.encoder.layer.{i}.attention.output.LayerNorm.weight'
weight_map[f'encoder.encoder.layers.{i}.norm1.bias'] = f'bert.encoder.layer.{i}.attention.output.LayerNorm.bias'
weight_map[f'encoder.encoder.layers.{i}.linear1.weight'] = f'bert.encoder.layer.{i}.intermediate.dense.weight'
weight_map[f'encoder.encoder.layers.{i}.linear1.bias'] = f'bert.encoder.layer.{i}.intermediate.dense.bias'
weight_map[f'encoder.encoder.layers.{i}.linear2.weight'] = f'bert.encoder.layer.{i}.output.dense.weight'
weight_map[f'encoder.encoder.layers.{i}.linear2.bias'] = f'bert.encoder.layer.{i}.output.dense.bias'
weight_map[f'encoder.encoder.layers.{i}.norm2.weight'] = f'bert.encoder.layer.{i}.output.LayerNorm.weight'
weight_map[f'encoder.encoder.layers.{i}.norm2.bias'] = f'bert.encoder.layer.{i}.output.LayerNorm.bias'
# add pooler
weight_map.update(
{
'encoder.pooler.dense.weight': 'bert.pooler.dense.weight',
'encoder.pooler.dense.bias': 'bert.pooler.dense.bias',
'linear_start.weight': 'linear_start.weight',
'linear_start.bias': 'linear_start.bias',
'linear_end.weight': 'linear_end.weight',
'linear_end.bias': 'linear_end.bias',
}
)
return weight_map
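# Illustration grounded in the mapping above: for a 12-layer model, build_params_map()
# renames, for example,
#   'encoder.encoder.layers.0.self_attn.q_proj.weight' -> 'bert.encoder.layer.0.attention.self.query.weight'
# which is how the paddle ERNIE state-dict keys are rewritten before being saved as pytorch_model.bin.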
def extract_and_convert(input_dir, output_dir):
if not os.path.exists(output_dir):
os.makedirs(output_dir)
logger.info('=' * 20 + 'save config file' + '=' * 20)
config = json.load(open(os.path.join(input_dir, 'model_config.json'), 'rt', encoding='utf-8'))
config = config['init_args'][0]
config["architectures"] = ["UIE"]
config['layer_norm_eps'] = 1e-12
del config['init_class']
if 'sent_type_vocab_size' in config:
config['type_vocab_size'] = config['sent_type_vocab_size']
config['intermediate_size'] = 4 * config['hidden_size']
json.dump(config, open(os.path.join(output_dir, 'config.json'),
'wt', encoding='utf-8'), indent=4)
logger.info('=' * 20 + 'save vocab file' + '=' * 20)
with open(os.path.join(input_dir, 'vocab.txt'), 'rt', encoding='utf-8') as f:
words = f.read().splitlines()
words_set = set()
words_duplicate_indices = []
for i in range(len(words)-1, -1, -1):
word = words[i]
if word in words_set:
words_duplicate_indices.append(i)
words_set.add(word)
for i, idx in enumerate(words_duplicate_indices):
words[idx] = chr(0x1F6A9+i) # Change duplicated word to 🚩 LOL
with open(os.path.join(output_dir, 'vocab.txt'), 'wt', encoding='utf-8') as f:
for word in words:
f.write(word+'\n')
special_tokens_map = {
"unk_token": "[UNK]",
"sep_token": "[SEP]",
"pad_token": "[PAD]",
"cls_token": "[CLS]",
"mask_token": "[MASK]"
}
json.dump(special_tokens_map, open(os.path.join(output_dir, 'special_tokens_map.json'),
'wt', encoding='utf-8'))
tokenizer_config = {
"do_lower_case": True,
"unk_token": "[UNK]",
"sep_token": "[SEP]",
"pad_token": "[PAD]",
"cls_token": "[CLS]",
"mask_token": "[MASK]",
"tokenizer_class": "BertTokenizer"
}
json.dump(tokenizer_config, open(os.path.join(output_dir, 'tokenizer_config.json'),
'wt', encoding='utf-8'))
logger.info('=' * 20 + 'extract weights' + '=' * 20)
state_dict = collections.OrderedDict()
weight_map = build_params_map(attention_num=config['num_hidden_layers'])
paddle_paddle_params = pickle.load(
open(os.path.join(input_dir, 'model_state.pdparams'), 'rb'))
del paddle_paddle_params['StructuredToParameterName@@']
for weight_name, weight_value in paddle_paddle_params.items():
if 'weight' in weight_name:
if 'encoder.encoder' in weight_name or 'pooler' in weight_name or 'linear' in weight_name:
weight_value = weight_value.transpose()
# Fix: embedding error
if 'word_embeddings.weight' in weight_name:
weight_value[0, :] = 0
if weight_name not in weight_map:
logger.info(f"{'='*20} [SKIP] {weight_name} {'='*20}")
continue
state_dict[weight_map[weight_name]] = torch.FloatTensor(weight_value)
logger.info(f"{weight_name} -> {weight_map[weight_name]} {weight_value.shape}")
torch.save(state_dict, os.path.join(output_dir, "pytorch_model.bin"))
def check_model(input_model):
if not os.path.exists(input_model):
if input_model not in MODEL_MAP:
raise ValueError('input_model not exists!')
resource_file_urls = MODEL_MAP[input_model]['resource_file_urls']
logger.info("Downloading resource files...")
for key, val in resource_file_urls.items():
file_path = os.path.join(input_model, key)
if not os.path.exists(file_path):
get_path_from_url(val, input_model)
def do_main():
check_model(args.input_model)
extract_and_convert(args.input_model, args.output_model)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("-i", "--input_model", default="uie-base", type=str,
help="Directory of input paddle model.\n Will auto download model [uie-base/uie-tiny]")
parser.add_argument("-o", "--output_model", default="uie_base_pytorch", type=str,
help="Directory of output pytorch model")
args = parser.parse_args()
do_main()
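# Hedged usage sketch (the file name convert_uie.py is an assumption):
#   python convert_uie.py --input_model uie-base --output_model uie_base_pytorch
# downloads the paddle checkpoint if it is missing and writes config.json, vocab.txt,
# special_tokens_map.json, tokenizer_config.json and pytorch_model.bin into ./uie_base_pytorch.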
# Data generation step 1
python finetune_step1_dataprocess.py
# Data generation step 2
python finetune_step2_doccano.py \
--doccano_file ./data/mid_data/train.json \
--task_type "ext" \
--splits 1.0 0.0 0.0 \
--save_dir ./data/final_data/ \
--negative_ratio 3
python finetune_step2_doccano.py \
--doccano_file ./data/mid_data/dev.json \
--task_type "ext" \
--splits 0.0 1.0 0.0 \
--save_dir ./data/final_data/ \
--negative_ratio 0
python finetune_step2_doccano.py \
--doccano_file ./data/mid_data/test.json \
--task_type "ext" \
--splits 0.0 0.0 1.0 \
--save_dir ./data/final_data/ \
--negative_ratio 0
# Finetune training
python finetune_step3_train.py
# Data conversion step 1
import os
import re
import json
en2ch = {
'ORG':'机构',
'PER':'人名',
'LOC':'籍贯'
}
def preprocess(input_path, save_path, mode):
if not os.path.exists(save_path):
os.makedirs(save_path)
data_path = os.path.join(save_path, mode + ".json")
result = []
tmp = {}
tmp['id'] = 0
tmp['text'] = ''
tmp['relations'] = []
tmp['entities'] = []
# ======= First collect each sentence together with all its entities and entity types =======
with open(input_path,'r',encoding='utf-8') as fp:
lines = fp.readlines()
texts = []
entities = []
words = []
entity_tmp = []
entities_tmp = []
entity_label = ''
for line in lines:
line = line.strip().split(" ")
if len(line) == 2:
word = line[0]
label = line[1]
words.append(word)
if "B-" in label:
entity_tmp.append(word)
entity_label = en2ch[label.split("-")[-1]]
elif "I-" in label:
entity_tmp.append(word)
if (label == 'O') and entity_tmp:
if ("".join(entity_tmp), entity_label) not in entities_tmp:
entities_tmp.append(("".join(entity_tmp), entity_label))
entity_tmp, entity_label = [], ''
else:
if entity_tmp and (("".join(entity_tmp), entity_label) not in entities_tmp):
entities_tmp.append(("".join(entity_tmp), entity_label))
entity_tmp, entity_label = [], ''
texts.append("".join(words))
entities.append(entities_tmp)
words = []
entities_tmp = []
# ==========================================
# ======= Locate the positions of entities within each sentence =======
i = 0
for text,entity in zip(texts, entities):
if entity:
ltmp = []
for ent,type in entity:
for span in re.finditer(ent, text):
start = span.start()
end = span.end()
ltmp.append((type, start, end, ent))
# print(ltmp)
ltmp = sorted(ltmp, key=lambda x:(x[1],x[2]))
for j in range(len(ltmp)):
# tmp['entities'].append(["".format(str(j)), ltmp[j][0], ltmp[j][1], ltmp[j][2], ltmp[j][3]])
tmp['entities'].append({"id":j, "start_offset":ltmp[j][1], "end_offset":ltmp[j][2], "label":ltmp[j][0]})
else:
tmp['entities'] = []
tmp['id'] = i
tmp['text'] = text
result.append(tmp)
tmp = {}
tmp['id'] = 0
tmp['text'] = ''
tmp['relations'] = []
tmp['entities'] = []
i += 1
with open(data_path, 'w', encoding='utf-8') as fp:
fp.write("\n".join([json.dumps(i, ensure_ascii=False) for i in result]))
preprocess("F:/Projects/data/corpus/ner/china-people-daily-ner-corpus/example.train", './data/mid_data', "train")
preprocess("F:/Projects/data/corpus/ner/china-people-daily-ner-corpus/example.dev", './data/mid_data', "dev")
preprocess("F:/Projects/data/corpus/ner/china-people-daily-ner-corpus/example.test", './data/mid_data', "test")
# Data generation step 2
import os
import time
import argparse
import json
from decimal import Decimal
import numpy as np
from bert4torch.snippets import seed_everything
from utils import convert_ext_examples, convert_cls_examples, logger
def do_convert():
seed_everything(args.seed)
tic_time = time.time()
if not os.path.exists(args.doccano_file):
raise ValueError("Please input the correct path of doccano file.")
if not os.path.exists(args.save_dir):
os.makedirs(args.save_dir)
if len(args.splits) != 0 and len(args.splits) != 3:
raise ValueError("Only []/ len(splits)==3 accepted for splits.")
def _check_sum(splits):
return Decimal(str(splits[0])) + Decimal(str(splits[1])) + Decimal(
str(splits[2])) == Decimal("1")
if len(args.splits) == 3 and not _check_sum(args.splits):
raise ValueError(
"Please set correct splits, sum of elements in splits should be equal to 1."
)
with open(args.doccano_file, "r", encoding="utf-8") as f:
raw_examples = f.readlines()
def _create_ext_examples(examples,
negative_ratio=0,
shuffle=False,
is_train=True):
entities, relations = convert_ext_examples(
examples, negative_ratio, is_train=is_train)
examples = entities + relations
if shuffle:
indexes = np.random.permutation(len(examples))
examples = [examples[i] for i in indexes]
return examples
def _create_cls_examples(examples, prompt_prefix, options, shuffle=False):
examples = convert_cls_examples(examples, prompt_prefix, options)
if shuffle:
indexes = np.random.permutation(len(examples))
examples = [examples[i] for i in indexes]
return examples
def _save_examples(save_dir, file_name, examples):
count = 0
save_path = os.path.join(save_dir, file_name)
if not examples:
logger.info("Skip saving %d examples to %s." % (0, save_path))
return
with open(save_path, "w", encoding="utf-8") as f:
for example in examples:
f.write(json.dumps(example, ensure_ascii=False) + "\n")
count += 1
logger.info("Save %d examples to %s." % (count, save_path))
if len(args.splits) == 0:
if args.task_type == "ext":
examples = _create_ext_examples(raw_examples, args.negative_ratio,
args.is_shuffle)
else:
examples = _create_cls_examples(raw_examples, args.prompt_prefix,
args.options, args.is_shuffle)
_save_examples(args.save_dir, "train.txt", examples)
else:
if args.is_shuffle:
indexes = np.random.permutation(len(raw_examples))
raw_examples = [raw_examples[i] for i in indexes]
i1, i2, _ = args.splits
p1 = int(len(raw_examples) * i1)
p2 = int(len(raw_examples) * (i1 + i2))
if args.task_type == "ext":
train_examples = _create_ext_examples(
raw_examples[:p1], args.negative_ratio, args.is_shuffle)
dev_examples = _create_ext_examples(
raw_examples[p1:p2], -1, is_train=False)
test_examples = _create_ext_examples(
raw_examples[p2:], -1, is_train=False)
else:
train_examples = _create_cls_examples(
raw_examples[:p1], args.prompt_prefix, args.options)
dev_examples = _create_cls_examples(
raw_examples[p1:p2], args.prompt_prefix, args.options)
test_examples = _create_cls_examples(
raw_examples[p2:], args.prompt_prefix, args.options)
_save_examples(args.save_dir, "train.txt", train_examples)
_save_examples(args.save_dir, "dev.txt", dev_examples)
_save_examples(args.save_dir, "test.txt", test_examples)
logger.info('Finished! It takes %.2f seconds' % (time.time() - tic_time))
if __name__ == "__main__":
# yapf: disable
parser = argparse.ArgumentParser()
parser.add_argument("-d", "--doccano_file", default="./data/doccano.json",
type=str, help="The doccano file exported from doccano platform.")
parser.add_argument("-s", "--save_dir", default="./data",
type=str, help="The path of the directory where you want to save the data.")
parser.add_argument("--negative_ratio", default=5, type=int,
help="Used only for the extraction task, the ratio of positive and negative samples, number of negative samples = negative_ratio * number of positive samples")
parser.add_argument("--splits", default=[0.8, 0.1, 0.1], type=float, nargs="*",
help="The ratio of samples in datasets. [0.6, 0.2, 0.2] means 60%% samples used for training, 20%% for evaluation and 20%% for test.")
parser.add_argument("--task_type", choices=['ext', 'cls'], default="ext", type=str,
help="Select task type, ext for the extraction task and cls for the classification task, defaults to ext.")
parser.add_argument("--options", default=["正向", "负向"], type=str, nargs="+",
help="Used only for the classification task, the options for classification")
parser.add_argument("--prompt_prefix", default="情感倾向", type=str,
help="Used only for the classification task, the prompt prefix for classification")
parser.add_argument("--is_shuffle", default=True, type=bool,
help="Whether to shuffle the labeled dataset, defaults to True.")
parser.add_argument("--seed", type=int, default=1000,
help="random seed for initialization")
args = parser.parse_args()
do_convert()
import torch
from torch.utils.data import DataLoader
from model import uie_model, tokenizer
from bert4torch.snippets import seed_everything, sequence_padding, Callback
from torch import nn
from torch.utils.data import Dataset
import numpy as np
import json
from utils import get_bool_ids_greater_than, get_span
from random import sample
batch_size = 16
learning_rate = 1e-5
train_path = 'E:/Github/bert4torch/examples/sequence_labeling/uie/data/final_data/train.txt'
dev_path = 'E:/Github/bert4torch/examples/sequence_labeling/uie/data/final_data/dev.txt'
save_dir = './'
max_seq_len = 256
num_epochs = 10
device = 'cuda' if torch.cuda.is_available() else 'cpu'
seed_everything(42)
uie_model.to(device)
class IEDataset(Dataset):
"""Information extraction dataset
"""
def __init__(self, file_path, tokenizer, max_seq_len, fewshot=None) -> None:
super().__init__()
self.file_path = file_path
if fewshot is None:
self.dataset = list(self.reader(file_path))
else:
assert isinstance(fewshot, int)
self.dataset = sample(list(self.reader(file_path)), fewshot)
self.tokenizer = tokenizer
self.max_seq_len = max_seq_len
def __len__(self):
return len(self.dataset)
def __getitem__(self, index):
return self.dataset[index]
@staticmethod
def reader(data_path, max_seq_len=512):
"""read json
"""
with open(data_path, 'r', encoding='utf-8') as f:
for line in f:
json_line = json.loads(line)
content = json_line['content']
prompt = json_line['prompt']
# Model input looks like: [CLS] Prompt [SEP] Content [SEP]
# It includes three special tokens.
if max_seq_len <= len(prompt) + 3:
raise ValueError("The value of max_seq_len is too small, please set a larger value")
max_content_len = max_seq_len - len(prompt) - 3
if len(content) <= max_content_len:
yield json_line
else:
result_list = json_line['result_list']
json_lines = []
accumulate = 0
while True:
cur_result_list = []
for result in result_list:
if result['start'] + 1 <= max_content_len < result['end']:
max_content_len = result['start']
break
cur_content = content[:max_content_len]
res_content = content[max_content_len:]
while True:
if len(result_list) == 0:
break
elif result_list[0]['end'] <= max_content_len:
if result_list[0]['end'] > 0:
cur_result = result_list.pop(0)
cur_result_list.append(cur_result)
else:
cur_result_list = [result for result in result_list]
break
else:
break
json_line = {'content': cur_content, 'result_list': cur_result_list, 'prompt': prompt}
json_lines.append(json_line)
for result in result_list:
if result['end'] <= 0:
break
result['start'] -= max_content_len
result['end'] -= max_content_len
accumulate += max_content_len
max_content_len = max_seq_len - len(prompt) - 3
if len(res_content) == 0:
break
elif len(res_content) < max_content_len:
json_line = {'content': res_content, 'result_list': result_list, 'prompt': prompt}
json_lines.append(json_line)
break
else:
content = res_content
for json_line in json_lines:
yield json_line
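# Hedged worked example of the splitting above: with max_seq_len=30 and a 7-char prompt,
# max_content_len is 30 - 7 - 3 = 20; a 35-char content whose single result spans
# start=25, end=28 is emitted as two json lines: content[:20] with an empty result_list,
# and content[20:] with the result shifted to start=5, end=8.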
def collate_fn(batch):
"""example: {title, prompt, content, result_list}
"""
batch_token_ids, batch_token_type_ids, batch_start_ids, batch_end_ids = [], [], [], []
for example in batch:
token_ids, token_type_ids, offset_mapping = tokenizer.encode(example["prompt"], example["content"],
maxlen=max_seq_len, return_offsets='transformers')
bias = 0
for index in range(len(offset_mapping)):
if index == 0:
continue
mapping = offset_mapping[index]
if mapping[0] == 0 and mapping[1] == 0 and bias == 0:
bias = index
if mapping[0] == 0 and mapping[1] == 0:
continue
offset_mapping[index][0] += bias
offset_mapping[index][1] += bias
start_ids = [0 for _ in range(len(token_ids))]
end_ids = [0 for _ in range(len(token_ids))]
for item in example["result_list"]:
start = map_offset(item["start"] + bias, offset_mapping)
end = map_offset(item["end"] - 1 + bias, offset_mapping)
start_ids[start] = 1.0
end_ids[end] = 1.0
batch_token_ids.append(token_ids)
batch_token_type_ids.append(token_type_ids)
batch_start_ids.append(start_ids)
batch_end_ids.append(end_ids)
batch_token_ids = torch.tensor(sequence_padding(batch_token_ids), dtype=torch.long, device=device)
batch_token_type_ids = torch.tensor(sequence_padding(batch_token_type_ids), dtype=torch.long, device=device)
batch_start_ids = torch.tensor(sequence_padding(batch_start_ids), dtype=torch.float, device=device)
batch_end_ids = torch.tensor(sequence_padding(batch_end_ids), dtype=torch.float, device=device)
return [batch_token_ids, batch_token_type_ids], [batch_start_ids, batch_end_ids]
def map_offset(ori_offset, offset_mapping):
"""map ori offset to token offset
"""
for index, span in enumerate(offset_mapping):
if span[0] <= ori_offset < span[1]:
return index
return -1
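# Hedged example: with offset_mapping [(0, 0), (0, 1), (1, 3), (3, 4)] (the first entry
# being a special token), map_offset(2, offset_mapping) returns 2 because 1 <= 2 < 3,
# and map_offset(10, offset_mapping) returns -1 since no span covers that offset.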
# Data preparation
train_ds = IEDataset(train_path, tokenizer=tokenizer, max_seq_len=max_seq_len, fewshot=None)
dev_ds = IEDataset(dev_path, tokenizer=tokenizer, max_seq_len=max_seq_len)
train_dataloader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
valid_dataloader = DataLoader(dev_ds, batch_size=batch_size, collate_fn=collate_fn)
class MyLoss(nn.Module):
def forward(self, y_pred, y_true):
start_prob, end_prob = y_pred
start_ids, end_ids = y_true
loss_start = torch.nn.functional.binary_cross_entropy(start_prob, start_ids)
loss_end = torch.nn.functional.binary_cross_entropy(end_prob, end_ids)
return loss_start + loss_end
uie_model.compile(
loss=MyLoss(),
optimizer=torch.optim.AdamW(lr=learning_rate, params=uie_model.parameters()),
)
class SpanEvaluator(Callback):
"""SpanEvaluator computes the precision, recall and F1-score for span detection.
"""
def __init__(self):
self.num_infer_spans = 0
self.num_label_spans = 0
self.num_correct_spans = 0
self.best_val_f1 = 0
def on_epoch_end(self, steps, epoch, logs=None):
f1, precision, recall = self.evaluate(valid_dataloader)
if f1 > self.best_val_f1:
self.best_val_f1 = f1
# model.save_weights('best_model.pt')
print(f'[val-entity level] f1: {f1:.5f}, p: {precision:.5f} r: {recall:.5f}')
def evaluate(self, dataloader):
self.reset()
for x_true, y_true in dataloader:
start_prob, end_prob = uie_model.predict(*x_true)
start_ids, end_ids = y_true
num_correct, num_infer, num_label = self.compute(start_prob, end_prob, start_ids, end_ids)
self.update(num_correct, num_infer, num_label)
precision, recall, f1 = self.accumulate()
return f1, precision, recall
def compute(self, start_probs, end_probs, gold_start_ids, gold_end_ids):
"""Computes the precision, recall and F1-score for span detection.
"""
start_probs = start_probs.cpu().numpy()
end_probs = end_probs.cpu().numpy()
gold_start_ids = gold_start_ids.cpu().numpy()
gold_end_ids = gold_end_ids.cpu().numpy()
pred_start_ids = get_bool_ids_greater_than(start_probs)
pred_end_ids = get_bool_ids_greater_than(end_probs)
gold_start_ids = get_bool_ids_greater_than(gold_start_ids.tolist())
gold_end_ids = get_bool_ids_greater_than(gold_end_ids.tolist())
num_correct_spans = 0
num_infer_spans = 0
num_label_spans = 0
for predict_start_ids, predict_end_ids, label_start_ids, label_end_ids in zip(
pred_start_ids, pred_end_ids, gold_start_ids, gold_end_ids):
[_correct, _infer, _label] = self.eval_span(predict_start_ids, predict_end_ids, label_start_ids, label_end_ids)
num_correct_spans += _correct
num_infer_spans += _infer
num_label_spans += _label
return num_correct_spans, num_infer_spans, num_label_spans
def update(self, num_correct_spans, num_infer_spans, num_label_spans):
"""
This function takes (num_infer_spans, num_label_spans, num_correct_spans) as input,
to accumulate and update the corresponding status of the SpanEvaluator object.
"""
self.num_infer_spans += num_infer_spans
self.num_label_spans += num_label_spans
self.num_correct_spans += num_correct_spans
def eval_span(self, predict_start_ids, predict_end_ids, label_start_ids,
label_end_ids):
"""
evaluate position extraction (start, end)
return num_correct, num_infer, num_label
input: [1, 2, 10] [4, 12] [2, 10] [4, 11]
output: (1, 2, 2)
"""
pred_set = get_span(predict_start_ids, predict_end_ids)
label_set = get_span(label_start_ids, label_end_ids)
num_correct = len(pred_set & label_set)
num_infer = len(pred_set)
num_label = len(label_set)
return (num_correct, num_infer, num_label)
def accumulate(self):
"""
This function returns the mean precision, recall and f1 score for all accumulated minibatches.
Returns:
tuple: Returns tuple (`precision, recall, f1 score`).
"""
precision = float(self.num_correct_spans / self.num_infer_spans) if self.num_infer_spans else 0.
recall = float(self.num_correct_spans / self.num_label_spans) if self.num_label_spans else 0.
f1_score = float(2 * precision * recall / (precision + recall)) if self.num_correct_spans else 0.
return precision, recall, f1_score
def reset(self):
"""
Reset function empties the evaluation memory for previous mini-batches.
"""
self.num_infer_spans = 0
self.num_label_spans = 0
self.num_correct_spans = 0
if __name__ == "__main__":
evaluator = SpanEvaluator()
print('zero_shot performance: ', evaluator.evaluate(valid_dataloader))
uie_model.fit(train_dataloader, epochs=num_epochs, steps_per_epoch=None, callbacks=[evaluator])
import numpy as np
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from bert4torch.snippets import sequence_padding, Callback, ListDataset, seed_everything
from bert4torch.losses import FocalLoss
from bert4torch.tokenizers import Tokenizer
from bert4torch.models import build_transformer_model, BaseModel, BERT
from tqdm import tqdm
config_path = 'F:/Projects/pretrain_ckpt/uie/uie_base_pytorch/config.json'
checkpoint_path = 'F:/Projects/pretrain_ckpt/uie/uie_base_pytorch/pytorch_model.bin'
dict_path = 'F:/Projects/pretrain_ckpt/uie/uie_base_pytorch/vocab.txt'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
tokenizer = Tokenizer(dict_path, do_lower_case=True)
class UIE(BERT):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
hidden_size = self.hidden_size
self.linear_start = nn.Linear(hidden_size, 1)
self.linear_end = nn.Linear(hidden_size, 1)
self.sigmoid = nn.Sigmoid()
if kwargs.get('use_task_id'):
# Add task type embedding to BERT
task_type_embeddings = nn.Embedding(kwargs.get('task_type_vocab_size'), self.hidden_size)
self.embeddings.task_type_embeddings = task_type_embeddings
def hook(module, input, output):
return output+task_type_embeddings(torch.zeros(input[0].size(), dtype=torch.int64, device=input[0].device))
self.embeddings.word_embeddings.register_forward_hook(hook)
def forward(self, token_ids, token_type_ids):
outputs = super().forward([token_ids, token_type_ids])
sequence_output = outputs[0]
start_logits = self.linear_start(sequence_output)
start_logits = torch.squeeze(start_logits, -1)
start_prob = self.sigmoid(start_logits)
end_logits = self.linear_end(sequence_output)
end_logits = torch.squeeze(end_logits, -1)
end_prob = self.sigmoid(end_logits)
return start_prob, end_prob
@torch.no_grad()
def predict(self, token_ids, token_type_ids):
self.eval()
start_prob, end_prob = self.forward(token_ids, token_type_ids)
return start_prob, end_prob
uie_model = build_transformer_model(config_path=config_path, checkpoint_path=checkpoint_path, model=UIE, with_pool=True)
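# Hedged usage sketch (the prompt/sentence pair is illustrative; tokenizer.encode returning
# (token_ids, segment_ids) is assumed from bert4torch):
# uie_model.to(device)
# token_ids, segment_ids = tokenizer.encode('找出下述句子中的人名', '周恩来总理访问印度')
# start_prob, end_prob = uie_model.predict(
#     torch.tensor([token_ids], device=device), torch.tensor([segment_ids], device=device))
# start_prob and end_prob hold per-token probabilities of being a span start / end for the prompt.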