"online_apiserver_test/benchmarks/kernels/benchmark_moe.py" did not exist on "fba2e3b53349552607f568c17f48428c716c8c65"
Commit c007ba1a authored by sunzhq2's avatar sunzhq2 Committed by xuxo
Browse files

update

parents
Pipeline #3464 failed with stages
in 0 seconds
#! -*- coding:utf-8 -*-
# MRC (machine reading comprehension) approach to NER
# 数据集:http://s3.bmio.net/kashgari/china-people-daily-ner-corpus.tar.gz
# [valid_f1]: 95.75
import numpy as np
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from bert4torch.snippets import sequence_padding, Callback, ListDataset, seed_everything
from bert4torch.tokenizers import Tokenizer
from bert4torch.models import build_transformer_model, BaseModel
from tqdm import tqdm
from collections import defaultdict
max_c_len = 224
max_q_len = 32
batch_size = 6 # the effective batch size is batch_size * number of entity types
categories = ['LOC', 'PER', 'ORG']
ent2query = {"LOC": "找出下述句子中的地址名",
"PER": "找出下述句子中的人名",
"ORG": "找出下述句子中的机构名"}
# BERT base
config_path = 'F:/Projects/pretrain_ckpt/bert/[google_tf_base]--chinese_L-12_H-768_A-12/bert_config.json'
checkpoint_path = 'F:/Projects/pretrain_ckpt/bert/[google_tf_base]--chinese_L-12_H-768_A-12/pytorch_model.bin'
dict_path = 'F:/Projects/pretrain_ckpt/bert/[google_tf_base]--chinese_L-12_H-768_A-12/vocab.txt'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# fix the random seed
seed_everything(42)
# load the dataset
class MyDataset(ListDataset):
@staticmethod
def load_data(filename):
D = []
with open(filename, encoding='utf-8') as f:
f = f.read()
for l in f.split('\n\n'):
if not l:
continue
d = ['']
for i, c in enumerate(l.split('\n')):
char, flag = c.split(' ')
d[0] += char
if flag[0] == 'B':
d.append([i, i, flag[2:]])
elif flag[0] == 'I':
d[-1][1] = i
D.append(d)
return D
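# Note (illustrative example, not taken from the corpus itself): each block of the corpus is
# expected to be one char per line in "char BIO-tag" form, with blank lines between sentences, e.g.
#   中 B-LOC
#   国 I-LOC
#   人 O
# load_data() concatenates the chars into d[0] and appends one [start, end, type] span per entity.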
# build the tokenizer
tokenizer = Tokenizer(dict_path, do_lower_case=True)
def collate_fn(batch):
batch_token_ids, batch_segment_ids, batch_start_labels, batch_end_labels = [], [], [], []
batch_ent_type = []
for d in batch:
tokens_b = tokenizer.tokenize(d[0], maxlen=max_c_len)[1:] # drop [CLS]
mapping = tokenizer.rematch(d[0], tokens_b)
start_mapping = {j[0]: i for i, j in enumerate(mapping) if j}
end_mapping = {j[-1]: i for i, j in enumerate(mapping) if j}
# group the entities by entity type
label_dict = defaultdict(list)
for start, end, label in d[1:]:
if start in start_mapping and end in end_mapping:
start = start_mapping[start]
end = end_mapping[end]
label_dict[label].append((start, end))
# iterate over entity types; the query is tokens_a, the context is tokens_b
# sample layout: [CLS] + tokens_a + [SEP] + tokens_b + [SEP]
for _type in categories:
start_ids = [0] * len(tokens_b)
end_ids = [0] * len(tokens_b)
text_a = ent2query[_type]
tokens_a = tokenizer.tokenize(text_a, maxlen=max_q_len)
for _label in label_dict[_type]:
start_ids[_label[0]] = 1
end_ids[_label[1]] = 1
start_ids = [0] * len(tokens_a) + start_ids
end_ids = [0] * len(tokens_a) + end_ids
token_ids = tokenizer.tokens_to_ids(tokens_a) + tokenizer.tokens_to_ids(tokens_b)
segment_ids = [0] * len(tokens_a) + [1] * len(tokens_b)
assert len(start_ids) == len(end_ids) == len(token_ids) == len(segment_ids)
batch_token_ids.append(token_ids)
batch_segment_ids.append(segment_ids)
batch_start_labels.append(start_ids)
batch_end_labels.append(end_ids)
batch_ent_type.append(_type)
batch_token_ids = torch.tensor(sequence_padding(batch_token_ids), dtype=torch.long, device=device)
batch_segment_ids = torch.tensor(sequence_padding(batch_segment_ids), dtype=torch.long, device=device)
batch_start_labels = torch.tensor(sequence_padding(batch_start_labels), dtype=torch.long, device=device)
batch_end_labels = torch.tensor(sequence_padding(batch_end_labels), dtype=torch.long, device=device)
return [batch_token_ids, batch_segment_ids], [batch_segment_ids, batch_start_labels, batch_end_labels, batch_ent_type]
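# In effect collate_fn turns every raw sentence into len(categories) samples, one per entity-type
# query, e.g. (illustrative) "[CLS] 找出下述句子中的人名 [SEP] <sentence tokens> [SEP]", with the
# start/end labels set only for spans of that type inside the context part.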
# convert the datasets into dataloaders
train_dataloader = DataLoader(MyDataset('F:/Projects/data/corpus/ner/china-people-daily-ner-corpus/example.train'), batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
valid_dataloader = DataLoader(MyDataset('F:/Projects/data/corpus/ner/china-people-daily-ner-corpus/example.dev'), batch_size=batch_size, collate_fn=collate_fn)
# define the model structure on top of BERT
class Model(BaseModel):
def __init__(self):
super().__init__()
self.bert = build_transformer_model(config_path=config_path, checkpoint_path=checkpoint_path)
self.mid_linear = nn.Sequential(
nn.Linear(768, 128),
nn.ReLU(),
nn.Dropout(0.1)
)
self.start_fc = nn.Linear(128, 2)
self.end_fc = nn.Linear(128, 2)
def forward(self, token_ids, segment_ids):
sequence_output = self.bert([token_ids, segment_ids]) # [bts, seq_len, hdsz]
seq_out = self.mid_linear(sequence_output) # [bts, seq_len, mid_dims]
start_logits = self.start_fc(seq_out) # [bts, seq_len, 2]
end_logits = self.end_fc(seq_out) # [bts, seq_len, 2]
return start_logits, end_logits
model = Model().to(device)
class Loss(nn.CrossEntropyLoss):
def forward(self, outputs, labels):
start_logits, end_logits = outputs
mask, start_ids, end_ids = labels[:3]
start_logits = start_logits.view(-1, 2)
end_logits = end_logits.view(-1, 2)
# drop the labels of text_a and the padding part, compute the loss on the remaining positions
active_loss = mask.view(-1) == 1
active_start_logits = start_logits[active_loss]
active_end_logits = end_logits[active_loss]
active_start_labels = start_ids.view(-1)[active_loss]
active_end_labels = end_ids.view(-1)[active_loss]
start_loss = super().forward(active_start_logits, active_start_labels)
end_loss = super().forward(active_end_logits, active_end_labels)
return start_loss + end_loss
model.compile(loss=Loss(), optimizer=optim.Adam(model.parameters(), lr=2e-5))
def evaluate(data):
X, Y, Z = 0, 1e-10, 1e-10
for (token_ids, segment_ids), labels in tqdm(data, desc='Evaluation'):
start_logit, end_logit = model.predict([token_ids, segment_ids]) # [btz, seq_len, 2]
mask, start_ids, end_ids, ent_type = labels
# entity-level metrics
entity_pred = mrc_decode(start_logit, end_logit, ent_type, mask)
entity_true = mrc_decode(start_ids, end_ids, ent_type)
X += len(entity_pred.intersection(entity_true))
Y += len(entity_pred)
Z += len(entity_true)
f1, precision, recall = 2 * X / (Y + Z), X/ Y, X / Z
return f1, precision, recall
# strict decoding baseline
def mrc_decode(start_preds, end_preds, ent_type, mask=None):
'''Return the (start, end) spans of the entities
'''
predict_entities = set()
if mask is not None: # for predictions, mask out the query and padding parts
start_preds = torch.argmax(start_preds, -1) * mask
end_preds = torch.argmax(end_preds, -1) * mask
start_preds = start_preds.cpu().numpy()
end_preds = end_preds.cpu().numpy()
for bt_i in range(start_preds.shape[0]):
start_pred = start_preds[bt_i]
end_pred = end_preds[bt_i]
# collect the results for each sample
for i, s_type in enumerate(start_pred):
if s_type == 0:
continue
for j, e_type in enumerate(end_pred[i:]):
if s_type == e_type:
# (sample id, entity start, entity end, entity type)
predict_entities.add((bt_i, i, i+j, ent_type[bt_i]))
break
return predict_entities
class Evaluator(Callback):
"""评估与保存
"""
def __init__(self):
self.best_val_f1 = 0.
def on_epoch_end(self, steps, epoch, logs=None):
f1, precision, recall = evaluate(valid_dataloader)
if f1 > self.best_val_f1:
self.best_val_f1 = f1
# model.save_weights('best_model.pt')
print(f'[val] f1: {f1:.5f}, p: {precision:.5f} r: {recall:.5f} best_f1: {self.best_val_f1:.5f}')
if __name__ == '__main__':
evaluator = Evaluator()
model.fit(train_dataloader, epochs=20, steps_per_epoch=None, callbacks=[evaluator])
else:
model.load_weights('best_model.pt')
#! -*- coding:utf-8 -*-
# Span-based approach to NER
# 数据集:http://s3.bmio.net/kashgari/china-people-daily-ner-corpus.tar.gz
# [valid_f1]: 96.31
import numpy as np
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from bert4torch.snippets import sequence_padding, Callback, ListDataset, seed_everything
from bert4torch.losses import FocalLoss
from bert4torch.tokenizers import Tokenizer
from bert4torch.models import build_transformer_model, BaseModel
from tqdm import tqdm
max_len = 256
batch_size = 16
categories = ['LOC', 'PER', 'ORG']
categories_id2label = {i: k for i, k in enumerate(categories, start=1)}
categories_label2id = {k: i for i, k in enumerate(categories, start=1)}
# BERT base
config_path = 'F:/Projects/pretrain_ckpt/bert/[google_tf_base]--chinese_L-12_H-768_A-12/bert_config.json'
checkpoint_path = 'F:/Projects/pretrain_ckpt/bert/[google_tf_base]--chinese_L-12_H-768_A-12/pytorch_model.bin'
dict_path = 'F:/Projects/pretrain_ckpt/bert/[google_tf_base]--chinese_L-12_H-768_A-12/vocab.txt'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# fix the random seed
seed_everything(42)
# load the dataset
class MyDataset(ListDataset):
@staticmethod
def load_data(filename):
D = []
with open(filename, encoding='utf-8') as f:
f = f.read()
for l in f.split('\n\n'):
if not l:
continue
d = ['']
for i, c in enumerate(l.split('\n')):
char, flag = c.split(' ')
d[0] += char
if flag[0] == 'B':
d.append([i, i, flag[2:]])
elif flag[0] == 'I':
d[-1][1] = i
D.append(d)
return D
# build the tokenizer
tokenizer = Tokenizer(dict_path, do_lower_case=True)
def collate_fn(batch):
batch_token_ids, batch_start_labels, batch_end_labels = [], [], []
for d in batch:
tokens = tokenizer.tokenize(d[0], maxlen=max_len)[1:] # drop [CLS]
mapping = tokenizer.rematch(d[0], tokens)
start_mapping = {j[0]: i for i, j in enumerate(mapping) if j}
end_mapping = {j[-1]: i for i, j in enumerate(mapping) if j}
token_ids = tokenizer.tokens_to_ids(tokens)
start_ids = [0] * len(tokens)
end_ids = [0] * len(tokens)
for start, end, label in d[1:]:
if start in start_mapping and end in end_mapping:
start = start_mapping[start]
end = end_mapping[end]
start_ids[start] = categories_label2id[label]
end_ids[end] = categories_label2id[label]
batch_token_ids.append(token_ids)
batch_start_labels.append(start_ids)
batch_end_labels.append(end_ids)
batch_token_ids = torch.tensor(sequence_padding(batch_token_ids), dtype=torch.long, device=device)
batch_start_labels = torch.tensor(sequence_padding(batch_start_labels), dtype=torch.long, device=device)
batch_end_labels = torch.tensor(sequence_padding(batch_end_labels), dtype=torch.long, device=device)
batch_mask = batch_token_ids.gt(0).long()
return [batch_token_ids], [batch_mask, batch_start_labels, batch_end_labels]
# convert the datasets into dataloaders
train_dataloader = DataLoader(MyDataset('F:/Projects/data/corpus/ner/china-people-daily-ner-corpus/example.train'), batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
valid_dataloader = DataLoader(MyDataset('F:/Projects/data/corpus/ner/china-people-daily-ner-corpus/example.dev'), batch_size=batch_size, collate_fn=collate_fn)
# define the model structure on top of BERT
class Model(BaseModel):
def __init__(self):
super().__init__()
self.bert = build_transformer_model(config_path=config_path, checkpoint_path=checkpoint_path, segment_vocab_size=0)
self.mid_linear = nn.Sequential(
nn.Linear(768, 128),
nn.ReLU(),
nn.Dropout(0.1)
)
self.start_fc = nn.Linear(128, len(categories)+1) # class 0 means no entity
self.end_fc = nn.Linear(128, len(categories)+1)
def forward(self, token_ids):
sequence_output = self.bert(token_ids) # [bts, seq_len, hdsz]
seq_out = self.mid_linear(sequence_output) # [bts, seq_len, mid_dims]
start_logits = self.start_fc(seq_out) # [bts, seq_len, num_tags]
end_logits = self.end_fc(seq_out) # [bts, seq_len, num_tags]
return start_logits, end_logits
model = Model().to(device)
class Loss(nn.CrossEntropyLoss):
def forward(self, outputs, labels):
start_logits, end_logits = outputs
mask, start_ids, end_ids = labels
start_logits = start_logits.view(-1, len(categories)+1)
end_logits = end_logits.view(-1, len(categories)+1)
# drop the labels of the padding part, compute the loss on the remaining positions
active_loss = mask.view(-1) == 1
active_start_logits = start_logits[active_loss]
active_end_logits = end_logits[active_loss]
active_start_labels = start_ids.view(-1)[active_loss]
active_end_labels = end_ids.view(-1)[active_loss]
start_loss = super().forward(active_start_logits, active_start_labels)
end_loss = super().forward(active_end_logits, active_end_labels)
return start_loss + end_loss
model.compile(loss=Loss(), optimizer=optim.Adam(model.parameters(), lr=2e-5))
def evaluate(data):
X, Y, Z = 0, 1e-10, 1e-10
for token_ids, labels in tqdm(data, desc='Evaluation'):
start_logit, end_logit = model.predict(token_ids) # [btz, seq_len, len(categories)+1]
mask, start_ids, end_ids = labels
# entity-level metrics
entity_pred = span_decode(start_logit, end_logit, mask)
entity_true = span_decode(start_ids, end_ids)
X += len(entity_pred.intersection(entity_true))
Y += len(entity_pred)
Z += len(entity_true)
f1, precision, recall = 2 * X / (Y + Z), X/ Y, X / Z
return f1, precision, recall
# strict decoding baseline
def span_decode(start_preds, end_preds, mask=None):
'''Return the (start, end) spans of the entities
'''
predict_entities = set()
if mask is not None: # mask out the padding part of the predictions
start_preds = torch.argmax(start_preds, -1) * mask
end_preds = torch.argmax(end_preds, -1) * mask
start_preds = start_preds.cpu().numpy()
end_preds = end_preds.cpu().numpy()
for bt_i in range(start_preds.shape[0]):
start_pred = start_preds[bt_i]
end_pred = end_preds[bt_i]
# collect the results for each sample
for i, s_type in enumerate(start_pred):
if s_type == 0:
continue
for j, e_type in enumerate(end_pred[i:]):
if s_type == e_type:
# (sample id, entity start, entity end, entity type)
predict_entities.add((bt_i, i, i+j, categories_id2label[s_type]))
break
return predict_entities
class Evaluator(Callback):
"""评估与保存
"""
def __init__(self):
self.best_val_f1 = 0.
def on_epoch_end(self, steps, epoch, logs=None):
f1, precision, recall = evaluate(valid_dataloader)
if f1 > self.best_val_f1:
self.best_val_f1 = f1
# model.save_weights('best_model.pt')
print(f'[val] f1: {f1:.5f}, p: {precision:.5f} r: {recall:.5f} best_f1: {self.best_val_f1:.5f}')
if __name__ == '__main__':
evaluator = Evaluator()
model.fit(train_dataloader, epochs=20, steps_per_epoch=None, callbacks=[evaluator])
else:
model.load_weights('best_model.pt')
#! -*- coding:utf-8 -*-
# tplinker_plus for named entity recognition
# [valid_f1]: 95.71
import numpy as np
from bert4torch.models import build_transformer_model, BaseModel
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from bert4torch.snippets import sequence_padding, Callback, ListDataset, seed_everything
from bert4torch.tokenizers import Tokenizer
from bert4torch.losses import MultilabelCategoricalCrossentropy
from bert4torch.layers import TplinkerHandshakingKernel
maxlen = 64
batch_size = 16
categories_label2id = {"LOC": 0, "ORG": 1, "PER": 2}
categories_id2label = dict((value, key) for key,value in categories_label2id.items())
# BERT base
config_path = 'F:/Projects/pretrain_ckpt/bert/[google_tf_base]--chinese_L-12_H-768_A-12/bert_config.json'
checkpoint_path = 'F:/Projects/pretrain_ckpt/bert/[google_tf_base]--chinese_L-12_H-768_A-12/pytorch_model.bin'
dict_path = 'F:/Projects/pretrain_ckpt/bert/[google_tf_base]--chinese_L-12_H-768_A-12/vocab.txt'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# fix the random seed
seed_everything(42)
# load the dataset
class MyDataset(ListDataset):
@staticmethod
def load_data(filename):
data = []
with open(filename, encoding='utf-8') as f:
f = f.read()
for l in f.split('\n\n'):
if not l:
continue
text, label = '', []
for i, c in enumerate(l.split('\n')):
char, flag = c.split(' ')
text += char
if flag[0] == 'B':
label.append([i, i, flag[2:]])
elif flag[0] == 'I':
label[-1][1] = i
text_list = tokenizer.tokenize(text)[1:-1] # drop the leading [CLS] and trailing [SEP]
tokens = [j for i in text_list for j in i][:maxlen] # flatten to char-level tokens
data.append((tokens, label)) # label is [[start, end, entity], ...]
return data
def trans_ij2k(seq_len, i, j):
'''Convert row i, column j into its index in the flattened upper triangle
'''
if (i > seq_len - 1) or (j > seq_len - 1) or (i > j):
return 0
return int(0.5*(2*seq_len-i+1)*i+(j-i))
map_ij2k = {(i, j): trans_ij2k(maxlen, i, j) for i in range(maxlen) for j in range(maxlen) if j >= i}
map_k2ij = {v: k for k, v in map_ij2k.items()}
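# Illustrative sanity check of the flattening (small seq_len for readability): with seq_len = 3,
# trans_ij2k(3, 0, j) -> 0, 1, 2 for j = 0..2, trans_ij2k(3, 1, j) -> 3, 4 for j = 1..2, and
# trans_ij2k(3, 2, 2) -> 5, i.e. the upper triangle is laid out row by row; map_k2ij inverts it.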
def tran_ent_rel2id():
'''Get the tag-to-id mapping for the final classification layer
'''
tag2id = {}
for p in categories_label2id.keys():
tag2id[p] = len(tag2id)
return tag2id
tag2id = tran_ent_rel2id()
id2tag = {v: k for k, v in tag2id.items()}
# build the tokenizer
tokenizer = Tokenizer(dict_path, do_lower_case=True)
def collate_fn(batch):
pair_len = maxlen * (maxlen+1)//2
# batch_head_labels: [btz, pair_len, tag2id_len]
batch_labels = torch.zeros((len(batch), pair_len, len(tag2id)), dtype=torch.long, device=device)
batch_token_ids = []
for i, (tokens, labels) in enumerate(batch):
batch_token_ids.append(tokenizer.tokens_to_ids(tokens)) # length was already truncated above
for s_i in labels:
if s_i[1] >= len(tokens): # skip entities whose end exceeds the truncated text length
continue
batch_labels[i, map_ij2k[s_i[0], s_i[1]], tag2id[s_i[2]]] = 1
batch_token_ids = torch.tensor(sequence_padding(batch_token_ids, length=maxlen), dtype=torch.long, device=device)
return [batch_token_ids], batch_labels
# convert the datasets into dataloaders
train_dataloader = DataLoader(MyDataset('F:/Projects/data/corpus/ner/china-people-daily-ner-corpus/example.train'), batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
valid_dataloader = DataLoader(MyDataset('F:/Projects/data/corpus/ner/china-people-daily-ner-corpus/example.dev'), batch_size=batch_size, collate_fn=collate_fn)
# define the model structure on top of BERT
class Model(BaseModel):
def __init__(self):
super().__init__()
self.bert = build_transformer_model(config_path=config_path, checkpoint_path=checkpoint_path, segment_vocab_size=0)
self.fc = nn.Linear(768, len(tag2id))
self.handshaking_kernel = TplinkerHandshakingKernel(768, shaking_type='cln_plus', inner_enc_type='lstm')
def forward(self, inputs):
last_hidden_state = self.bert(inputs) # [btz, seq_len, hdsz]
shaking_hiddens = self.handshaking_kernel(last_hidden_state)
output = self.fc(shaking_hiddens) # [btz, pair_len, tag_size]
return output
model = Model().to(device)
model.compile(loss=MultilabelCategoricalCrossentropy(), optimizer=optim.Adam(model.parameters(), lr=2e-5))
def evaluate(data, threshold=0):
X, Y, Z = 0, 1e-10, 1e-10
for x_true, label in data:
scores = model.predict(x_true) # [btz, pair_len, tag_size]
for i, score in enumerate(scores):
R = set()
for pair_id, tag_id in zip(*np.where(score.cpu().numpy() > threshold)):
start, end = map_k2ij[pair_id][0], map_k2ij[pair_id][1]
R.add((start, end, tag_id))
T = set()
for pair_id, tag_id in zip(*np.where(label[i].cpu().numpy() > threshold)):
start, end = map_k2ij[pair_id][0], map_k2ij[pair_id][1]
T.add((start, end, tag_id))
X += len(R & T)
Y += len(R)
Z += len(T)
f1, precision, recall = 2 * X / (Y + Z), X / Y, X / Z
return f1, precision, recall
class Evaluator(Callback):
"""评估与保存
"""
def __init__(self):
self.best_val_f1 = 0.
def on_epoch_end(self, steps, epoch, logs=None):
f1, precision, recall = evaluate(valid_dataloader)
if f1 > self.best_val_f1:
self.best_val_f1 = f1
# model.save_weights('best_model.pt')
print(f'[val] f1: {f1:.5f}, p: {precision:.5f} r: {recall:.5f} best_f1: {self.best_val_f1:.5f}')
if __name__ == '__main__':
evaluator = Evaluator()
model.fit(train_dataloader, epochs=20, steps_per_epoch=None, callbacks=[evaluator])
else:
model.load_weights('best_model.pt')
import onnx
from onnx import helper, numpy_helper
import argparse
def fuse_qkv_bias(onnx_model_path, output_model_path):
model = onnx.load(onnx_model_path)
graph = model.graph
# print the graph for inspection
print(onnx.helper.printable_graph(graph))
# collect all MultiHeadAttention (MHA) nodes
mha_nodes = [node for node in graph.node if node.op_type == "MultiHeadAttention"]
if not mha_nodes:
raise ValueError("No MultiHeadAttention nodes found in the model!")
# process each MHA node
for mha_node in mha_nodes:
# q/k/v inputs and the fused qkv bias
inputs = list(mha_node.input)
q_input = inputs[0]
k_input = inputs[1]
v_input = inputs[2]
bias_input = inputs[3]
bias_initializer = next((init for init in graph.initializer if init.name == bias_input), None)
if bias_initializer is None:
raise ValueError(f"Bias tensor {bias_input} not found in initializers!")
bias_array = numpy_helper.to_array(bias_initializer)
hidden_size = bias_array.shape[0] // 3
# split the fused bias into its q/k/v parts
q_bias = bias_array[:hidden_size]
k_bias = bias_array[hidden_size:2 * hidden_size]
v_bias = bias_array[2 * hidden_size:]
# create separate q/k/v bias initializers
q_bias_name = f"{mha_node.name}_q_bias"
k_bias_name = f"{mha_node.name}_k_bias"
v_bias_name = f"{mha_node.name}_v_bias"
q_bias_tensor = numpy_helper.from_array(q_bias, name=q_bias_name)
k_bias_tensor = numpy_helper.from_array(k_bias, name=k_bias_name)
v_bias_tensor = numpy_helper.from_array(v_bias, name=v_bias_name)
graph.initializer.extend([q_bias_tensor, k_bias_tensor, v_bias_tensor])
# build Add nodes that apply the biases to q/k/v
q_add_output = f"{q_input}_biased"
k_add_output = f"{k_input}_biased"
v_add_output = f"{v_input}_biased"
q_add_node = helper.make_node(
"Add",
[q_input, q_bias_name],
[q_add_output],
name=f"{mha_node.name}_Add_Q"
)
k_add_node = helper.make_node(
"Add",
[k_input, k_bias_name],
[k_add_output],
name=f"{mha_node.name}_Add_K"
)
v_add_node = helper.make_node(
"Add",
[v_input, v_bias_name],
[v_add_output],
name=f"{mha_node.name}_Add_V"
)
# rewire the MHA node inputs to the biased q/k/v tensors
mha_node.input[0] = q_add_output # query
mha_node.input[1] = k_add_output # key
mha_node.input[2] = v_add_output # value
if len(inputs) == 5:
mha_node.input[3] = inputs[-1] # mask_index_0
del mha_node.input[3] # bias
# locate the MHA node so the Add nodes can be inserted in front of it
for idx, node in enumerate(graph.node):
if node == mha_node:
insert_pos = idx
break
else:
raise RuntimeError(f"Failed to find MHA node {mha_node.name} in graph")
# insert the Add nodes before the MHA node
graph.node.insert(insert_pos, v_add_node)
graph.node.insert(insert_pos, k_add_node)
graph.node.insert(insert_pos, q_add_node)
# remove the original fused bias initializer
graph.initializer.remove(bias_initializer)
# optionally validate, then save the modified model
# onnx.checker.check_model(model)
onnx.save(model, output_model_path)
print(f"Updated ONNX model saved to {output_model_path}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Fuse QKV bias in an ONNX model.')
parser.add_argument('input', type=str, help='Path to the input ONNX model')
parser.add_argument('output', type=str, help='Path to save the output ONNX model')
args = parser.parse_args()
fuse_qkv_bias(args.input, args.output)
export MIGRAPHX_ENABLE_GEMM_SOFTMAX_GEMM_FUSE=1
export MIGRAPHX_ENABLE_MHA=1
export HIP_VISIBLE_DEVICES=3
model_path=/home/sunzhq/workspace/onnx_models/bert/bert_best.onnx
output_path=/home/sunzhq/workspace/yidong-infer/bert/bert4torch_cmcc/examples/sequence_labeling/models/bert_best_mha.onnx
output_path_sim=/home/sunzhq/workspace/yidong-infer/bert/bert4torch_cmcc/examples/sequence_labeling/models/bert_best_mha_md5.onnx
output_mxr=/home/sunzhq/workspace/yidong-infer/bert/bert4torch_cmcc/examples/sequence_labeling/models/bert_best_mha_md5.mxr
python3 -m onnxruntime.transformers.optimizer \
--input ${model_path} \
--output ${output_path} \
--use_multi_head_attention \
--num_heads 12 \
--hidden_size 768 \
--model_type bert \
--disable_skip_layer_norm \
--disable_gelu \
--use_gpu \
--disable_embed_layer_norm \
--use_mask_index \
--use_raw_attention_mask
python modify_onnx.py --input ${output_path} --output ${output_path_sim}
migraphx-driver compile ${output_path_sim} \
--fp16 --binary \
--output ${output_mxr} \
--input-dim @input 64 256
import os
import os.path as osp
import onnx
import numpy as np
from onnx_modifier import ONNXModifier
import argparse
def change_mask_to_bias(om: ONNXModifier):
attention_nodes = om.get_nodes("MultiHeadAttention")
for attention_node in attention_nodes:
init_name = f"{attention_node.name}_qkv_bias"
init_value = om.get_initializer_value(init_name)
init_value = init_value.reshape(3, -1)
q_bias_init = om.create_initializer(f"{attention_node.name}_q_bias", init_value[0, :])
k_bias_init = om.create_initializer(f"{attention_node.name}_k_bias", init_value[1, :])
v_bias_init = om.create_initializer(f"{attention_node.name}_v_bias", init_value[2, :])
q_matmul_node = om.get_from_node(attention_node.inputs[0])
k_matmul_node = om.get_from_node(attention_node.inputs[1])
v_matmul_node = om.get_from_node(attention_node.inputs[2])
q_add_node = om.create_node("Add", f"{attention_node.name}_q_Add", [attention_node.inputs[0], q_bias_init.name], [f"{attention_node.name}_q_Add_output_0"], index=q_matmul_node.index+1)
k_add_node = om.create_node("Add", f"{attention_node.name}_k_Add", [attention_node.inputs[1], k_bias_init.name], [f"{attention_node.name}_k_Add_output_0"], index=k_matmul_node.index+1)
v_add_node = om.create_node("Add", f"{attention_node.name}_v_Add", [attention_node.inputs[2], v_bias_init.name], [f"{attention_node.name}_v_Add_output_0"], index=v_matmul_node.index+1)
attention_node.set_input(0, q_add_node.outputs[0])
attention_node.set_input(1, k_add_node.outputs[0])
attention_node.set_input(2, v_add_node.outputs[0])
assert len(attention_node.inputs) == 5
attention_node.set_input(3, attention_node.inputs.pop(4))
# if not om.get_initializer("num_heads"):
# num_heads = om.create_initializer("num_heads", np.array([12]))
# if not om.get_initializer("dimension"):
# dimension = om.create_initializer("dimension", np.array([64]))
# if not om.get_initializer("hidden_size"):
# hidden_size = om.create_initializer("hidden_size", np.array([64]))
# if not om.get_node("sequence_length_Unsqueeze"):
# sequence_length = om.create_node("Unsqueeze", "sequence_length_Unsqueeze", ["/bert/embeddings/Gather_output_0", om.create_initializer("axes_init", np.array([0])).name], ["sequence_length_Unsqueeze_output_0"], index=om.get_node("/bert/embeddings/Gather").index+1)
# if not om.get_node("BSHD_Concat"):
# bshd_concat_node = om.create_node(
# "Concat",
# "BSHD_Concat",
# ["/bert/embeddings/Unsqueeze_1_output_0", sequence_length.outputs[0], num_heads.name, dimension.name],
# [f"{attention_node.name}_shape_Concat_output_0"],
# index=attention_node.index
# )
# k_reshape_node = om.create_node("Reshape", f"{attention_node.name}_k_reshape", [k_add_node.outputs[0], bshd_concat_node.outputs[0]], [f"{attention_node.name}_k_reshape_output_0"], index=k_add_node.index+1)
# v_reshape_node = om.create_node("Reshape", f"{attention_node.name}_v_reshape", [v_add_node.outputs[0], bshd_concat_node.outputs[0]], [f"{attention_node.name}_v_reshape_output_0"], index=v_add_node.index+1)
# if not om.get_initializer("nhsd_perm"):
# nhsd_perm = om.create_initializer("nhsd_perm", np.array([0, 2, 1, 3]))
# k_transpose_node = om.create_node("Transpose", f"{attention_node.name}_k_transpose", [k_reshape_node.outputs[0]], [f"{attention_node.name}_k_transpose_output_0"], index=k_reshape_node.index+1)
# v_transpose_node = om.create_node("Transpose", f"{attention_node.name}_v_transpose", [v_reshape_node.outputs[0]], [f"{attention_node.name}_v_transpose_output_0"], index=v_reshape_node.index+1)
# attention_node.set_input(0, q_add_node.outputs[0])
# attention_node.set_input(1, k_transpose_node.outputs[0])
# attention_node.set_input(2, v_transpose_node.outputs[0])
# assert len(attention_node.inputs) == 5
# attention_node.set_input(3, attention_node.inputs.pop(4))
# attention_node.set_attribute("mask_filter_value", -1000.0)
# attention_node.set_attribute("mask_filter_value", -3.4028234663852886e+38)
print(f"{attention_node.name} changed!")
def process_mask(om):
not_node = om.get_node("/bert/Not")
cast_node = not_node.next_nodes[0]
# cast_node.name = "Cast_new"
print(f"name='{cast_node.name}', op_type='{cast_node.op_type}'")
# cast_node.set_attribute("to", 1)
next_nodes = cast_node.next_nodes
# sub_init = om.create_initializer("Sub_new_A", np.array(1.0, dtype=np.float32))
# sub_node = om.create_node("Sub", "Sub_new", [sub_init.name, cast_node.outputs[0]], ["Sub_new_output_0"], index=cast_node.index+1)
# mul_init = om.create_initializer("Mul_new_B", np.array(-1000.0, dtype=np.float32))
# mul_node = om.create_node("Mul", "Mul_new", [sub_node.outputs[0], mul_init.name], ["Mul_new_output_0"], index=sub_node.index+1)
reducesum_init = om.create_initializer("ReduceSum_for_mask_axes", np.array([1], dtype=np.int64))
reducesum_node = om.create_node("ReduceSum", "ReduceSum_for_mask", [cast_node.outputs[0], reducesum_init.name], [ "ReduceSum_for_mask_output_0"], keepdims=0, index=cast_node.index+1)
for node in next_nodes:
node.replace_input(cast_node.outputs[0], reducesum_node.outputs[0])
def main(input, output):
om = ONNXModifier(input)
change_mask_to_bias(om)
process_mask(om)
om.save(output)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Fuse QKV bias in an ONNX model.')
parser.add_argument('--input', type=str, help='Path to the input ONNX model')
parser.add_argument('--output', type=str, help='Path to save the output ONNX model')
args = parser.parse_args()
main(args.input, args.output)
"""
mdf1: move qkv_bias after the three q/k/v MatMuls
mdf2: move qkv_bias after the three q/k/v MatMuls, then restore the mask as (1-x)*(-1000)
mdf3: move qkv_bias after the three q/k/v MatMuls, reshape k/v to BHSD
mdf4: move qkv_bias after the three q/k/v MatMuls, reshape k/v to BHSD, then restore the mask as (1-x)*(-1000)
mdf5: move qkv_bias after the three q/k/v MatMuls, reduce the mask to 1-D
mdf6: move qkv_bias after the three q/k/v MatMuls, reduce the mask to 1-D, then restore the mask as (1-x)*(-1000)
"""
"""
onnx modifier: provides a convenient way to modify an onnx model
1. add node
2. remove node
3. modify node
4. query node
"""
from collections import defaultdict, deque
import os
import os.path as osp
import shutil
import tempfile
from typing import List, Dict, Set, Tuple, Optional, Union
import warnings
import numpy as np
import onnx
from onnx import AttributeProto, numpy_helper
from onnx import shape_inference
from onnx.helper import make_attribute, make_node, make_opsetid, make_tensor, \
tensor_dtype_to_np_dtype
from onnxconverter_common import float16
import tqdm
# from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
SUPPORT_DTYPES = [
'BOOL', 'STRING', 'BFLOAT16', 'DOUBLE', 'FLOAT', 'FLOAT16',
'INT16', 'INT32', 'INT4', 'INT64', 'INT8', 'UINT16', 'UINT32', 'UINT4', 'UINT64', 'UINT8',
]
SUPPORT_DTYPES.extend([dt.lower() for dt in SUPPORT_DTYPES])
class Node:
def __init__(self, onnx_modifier=None, obj=None, index=None):
self.onnx_modifier = onnx_modifier
self.obj = obj
self.index = index
@property
def name(self):
return self.obj.name
@property
def op_type(self):
return self.obj.op_type
@property
def inputs(self):
return self.obj.input
@property
def outputs(self):
return self.obj.output
@property
def input_names(self):
return self.inputs
@property
def output_names(self):
return self.outputs
def check_modifier(self):
if self.onnx_modifier is None:
raise RuntimeError("onnx_modifier is not initialized")
@property
def prev_nodes(self):
self.check_modifier()
return self.onnx_modifier.get_prev_nodes(self)
@property
def next_nodes(self):
self.check_modifier()
return self.onnx_modifier.get_next_nodes(self)
def replace_input(self, old_name, new_name):
assert old_name in self.obj.input, \
f'"{old_name}" not in input name list of node named "{self.name}"'
for i, in_name in enumerate(self.obj.input):
if in_name == old_name:
self.set_input(i, new_name)
def set_input(self, index, name):
# assert index < len(self.obj.input), "index out of range"
# orig_name = self.obj.input[index]
# self.obj.input[index] = name
assert index < len(self.onnx_modifier.graph.node[self.index].input), "index out of range"
orig_name = self.onnx_modifier.graph.node[self.index].input[index]
self.onnx_modifier.graph.node[self.index].input[index] = name
self.check_modifier()
# We cannot call connection.pop_to_node() unconditionally here:
# when the node's inputs contain orig_name more than once, the node must stay in to_nodes.
if list(self.onnx_modifier.graph.node[self.index].input).count(orig_name) == 0:
self.onnx_modifier.connection_map[orig_name].pop_to_node(self)
if name not in self.onnx_modifier.connection_map:
self.onnx_modifier.connection_map[name] = Connection(name, self.onnx_modifier)
self.onnx_modifier.connection_map[name].add_to_node(self)
def set_inputs(self, names):
assert len(names) == len(self.obj.input), "number of inputs does not match"
assert all(isinstance(name, str) for name in names), "input names must be strings"
self.obj.input[:] = names
def set_output(self, index, name):
assert index < len(self.obj.output), "index out of range"
orig_name = self.obj.output[index]
self.obj.output[index] = name
self.check_modifier()
self.onnx_modifier.connection_map[orig_name].clear_from_node()
if name not in self.onnx_modifier.connection_map:
self.onnx_modifier.connection_map[name] = Connection(name, self.onnx_modifier)
self.onnx_modifier.connection_map[name].set_from_node(self)
def set_outputs(self, names):
assert len(names) == len(self.obj.output), "number of outputs does not match"
assert all(isinstance(name, str) for name in names), "output names must be strings"
self.obj.output[:] = names
@property
def attrs(self):
attrs = {}
for attr in self.obj.attribute:
if attr.type == AttributeProto.FLOAT: # 1
value = attr.f
elif attr.type == AttributeProto.INT: # 2
value = attr.i
elif attr.type == AttributeProto.STRING: # 3
value = attr.s.decode('utf-8')
elif attr.type == AttributeProto.TENSOR: # 4
value = numpy_helper.to_array(attr.t)
elif attr.type == AttributeProto.FLOATS: # 6
value = list(attr.floats)
elif attr.type == AttributeProto.INTS: # 7
value = list(attr.ints)
else:
value = f"Unsupported type: {attr.type}"
attrs[attr.name] = value
return attrs
def set_attribute(self, name, value, name2attr=None):
if not name2attr:
name2attr = {}
for attr in self.obj.attribute:
name2attr[attr.name] = attr
if name in name2attr:
if isinstance(value, float):
name2attr[name].f = value
name2attr[name].type = AttributeProto.FLOAT
elif isinstance(value, int):
name2attr[name].i = value
name2attr[name].type = AttributeProto.INT
elif isinstance(value, str):
name2attr[name].s = value.encode('utf-8')
name2attr[name].type = AttributeProto.STRING
elif isinstance(value, np.ndarray):
name2attr[name].ClearField("t")
name2attr[name].t.CopyFrom(numpy_helper.from_array(value))
elif isinstance(value, list):
is_all_float = all(isinstance(x, float) for x in value)
is_all_int = all(isinstance(x, int) for x in value)
assert is_all_float or is_all_int
if is_all_float:
name2attr[name].ClearField("floats")
name2attr[name].floats.extend(value)
name2attr[name].type = AttributeProto.FLOATS
else:
name2attr[name].ClearField("ints")
name2attr[name].ints.extend(value)
name2attr[name].type = AttributeProto.INTS
else:
if isinstance(value, np.ndarray):
value = numpy_helper.from_array(value)
self.obj.attribute.append(make_attribute(name, value))
def set_attributes(self, attr_dict):
name2attr = {}
for attr in self.obj.attribute:
name2attr[attr.name] = attr
for name, value in attr_dict.items():
self.set_attribute(name, value, name2attr)
class Connection:
def __init__(self, conn_name, onnx_modifier=None):
self.name = conn_name
self.onnx_modifier = onnx_modifier
self.from_node = None
self.to_nodes = []
self.to_node_names = set()
def check_modifier(self):
if self.onnx_modifier is None:
raise RuntimeError("onnx_modifier is not initialized")
def set_from_node(self, node: str | Node):
if isinstance(node, str):
self.check_modifier()
_node = self.onnx_modifier.get_node(node)
assert _node is not None, f'No node named "{node}" in onnx graph!'
elif isinstance(node, Node):
_node = node
else:
raise TypeError(f"Connection.set_from_node except input argument type" \
f" is str or Node, but received: {type(node)}")
self.from_node = _node
def clear_from_node(self):
self.from_node = None
def add_to_node(self, node: str | Node):
if isinstance(node, str):
_name = node
self.check_modifier()
_node = self.onnx_modifier.get_node(node)
assert _node is not None, f'No node named "{node}" in onnx graph!'
elif isinstance(node, Node):
_name = node.name
_node = node
else:
raise TypeError(f"Connection.add_to_node except input argument type" \
f" is str or Node, but received: {type(node)}")
if _name not in self.to_node_names:
self.to_node_names.add(_name)
self.to_nodes.append(_node)
def pop_to_node(self, node: str | Node):
if isinstance(node, str):
_name = node
self.check_modifier()
_node = self.onnx_modifier.get_node(node)
assert _node is not None, f'No node named "{node}" in onnx graph!'
elif isinstance(node, Node):
_name = node.name
_node = node
else:
raise TypeError(f"Connection.pop_to_node except input argument type" \
f" is str or Node, but received: {type(node)}")
if _name not in self.to_node_names:
raise ValueError(f'Node "{_name}" not in target nodes of connection "{self.name}"!')
self.to_node_names.remove(_name)
for i in range(len(self.to_nodes)):
if self.to_nodes[i].name == _name:
return self.to_nodes.pop(i)
else:
raise RuntimeError("to_nodes dismatch to_node_names!")
class ONNXModifier:
def __init__(self, onnx_path):
self.onnx_path = onnx_path
self.node_map = {}
self.initializer_map = {}
self.sparse_initializer_map = {}
self.connection_map = {}
self.value_info_map = {}
self.parse_onnx(self.onnx_path)
def parse_onnx(self, onnx_path):
model = onnx.load(onnx_path)
self.model = model
self.domain = model.domain
self.graph = model.graph
self.ir_version = model.ir_version
self.model_version = model.model_version
self.opset_import = model.opset_import
self.update_map()
def add_opset_import(self, domain: str, version: int):
self.model.opset_import.append(make_opsetid(domain, version))
@property
def input_names(self):
return [i.name for i in self.graph.input]
@property
def output_names(self):
return [o.name for o in self.graph.output]
def add_input(self, name, dtype='float32', shape=None):
assert dtype in set(SUPPORT_DTYPES)
self.create_value_info(name, dtype=dtype, shape=shape)
new_input = self.value_info_map.pop(name)
_new_input = self.graph.value_info.pop()
assert id(new_input) == id(_new_input)
assert name == new_input.name
self.graph.input.append(new_input)
return new_input
def add_output(self, name, new_name=None, shape=None):
if name not in self.value_info_map:
raise ValueError(f"{name} not in onnx_modifier.value_info_map")
index = None
for i, v in enumerate(self.graph.value_info):
if v.name == name:
index = i
break
else:
raise ValueError(f"{name} not in model.graph.value_info")
value_info = self.value_info_map.pop(name)
assert value_info.name == name
assert id(value_info) == id(self.graph.value_info[index])
self.graph.value_info.pop(index)
if shape is not None:
tensor_type = onnx.helper.make_tensor_type_proto(
elem_type=value_info.type.tensor_type.elem_type,
shape=shape
)
value_info.type.CopyFrom(tensor_type)
if new_name is None:
self.graph.output.append(value_info)
else:
from_node = self.get_from_node(name)
to_nodes = self.get_to_nodes(name)
for i, output_name in enumerate(from_node.output_names):
if output_name == name:
from_node.set_output(i, new_name)
for node in to_nodes:
node.replace_input(name, new_name)
value_info.name = new_name
self.graph.output.append(value_info)
def remove_output(self, name):
"""根据名称删除模型输出"""
assert name in self.output_names
print("need remove output name:", name)
index = None
for i, out in enumerate(self.graph.output):
print(f"current(index={i}) output name:", out.name)
if out.name == name:
index = i
break
else:
raise RuntimeError(f"ONNX graphx not has a output named '{name}'.")
self.graph.output.pop(index)
def get_node(self, name_or_index: Union[str, int]):
"""根据节点名称或索引获取节点实例"""
if isinstance(name_or_index, str):
if name_or_index in self.node_map:
return self.node_map.get(name_or_index, None)
elif isinstance(name_or_index, int):
if name_or_index < len(self.graph.node):
return self.node_map.get(
self.graph.node[name_or_index].name, None)
else:
raise ValueError(f"Node index {name_or_index} out of range")
def get_nodes(self, op_type: str):
"""根据节点类型获取节点实例"""
node_names = [node.name for node in self.graph.node if node.op_type == op_type]
nodes = [self.node_map[name] for name in node_names]
return nodes
def get_initializer(self, name: str):
"""根据initializer名称获取initializer"""
return self.initializer_map.get(name)
def get_connection(self, name: str):
"""根据边名称获取边"""
return self.connection_map.get(name)
def get_from_node(self, conn: Union[str, Connection]):
"""获取某条边的输入节点名"""
if isinstance(conn, str):
assert conn in self.connection_map, f"Connection {conn} not in connection_map!"
return self.connection_map[conn].from_node
elif isinstance(conn, Connection):
return conn.from_node
else:
raise TypeError(f"Invalid connection type {type(conn)}")
def get_to_nodes(self, conn: Union[str, Connection]):
"""获取某条边的输出节点"""
if isinstance(conn, str):
assert conn in self.connection_map, f"Connection {conn} not in connection_map!"
return self.connection_map[conn].to_nodes
elif isinstance(conn, Connection):
return conn.to_nodes
else:
raise TypeError(f"Invalid connection type {type(conn)}")
def get_prev_nodes(self, node: Union[str, Node]):
"""获取某节点的上游输入节点"""
if isinstance(node, str):
node = self.node_map[node]
elif isinstance(node, Node):
pass
else:
raise TypeError(f"Invalid node type {type(node)}")
nodes = []
for conn_name in node.inputs:
from_node = self.get_from_node(conn_name)
if from_node:
nodes.append(from_node)
return nodes
def get_next_nodes(self, node: Union[str, Node]):
"""获取某节点的下游节点"""
if isinstance(node, str):
node = self.node_map[node]
elif isinstance(node, Node):
pass
else:
raise TypeError(f"Invalid node type {type(node)}")
nodes = []
for conn_name in node.outputs:
to_nodes = self.get_to_nodes(conn_name)
nodes.extend(to_nodes)
return nodes
def create_node(self, op_type, op_name, inputs, outputs, doc_string=None,
domain=None, index=None, **attrs):
"""创建一个新节点"""
onnx_node = make_node(op_type, inputs, outputs, name=op_name,
doc_string=doc_string, domain=domain, **attrs)
if index is None:
self.graph.node.append(onnx_node)
index = len(self.graph.node) - 1
else:
assert index <= len(self.graph.node), "index out of range"
self.graph.node.insert(index, onnx_node)
for i in range(index + 1, len(self.graph.node)):
node_name = self.graph.node[i].name
old_idx = self.node_map[node_name].index
assert old_idx == i - 1, \
f"Node {node_name} index conflict: {old_idx} != {i - 1}"
self.node_map[node_name].index = i
new_node = Node(self, self.graph.node[index], index)
self.node_map[op_name] = new_node
for in_name in new_node.input_names:
if in_name not in self.value_info_map:
self.create_value_info(in_name, dtype="float")
if in_name not in self.connection_map:
self.connection_map[in_name] = Connection(in_name, self)
self.connection_map[in_name].add_to_node(new_node)
for out_name in new_node.output_names:
if out_name not in self.value_info_map:
self.create_value_info(out_name, dtype="float")
if out_name not in self.connection_map:
self.connection_map[out_name] = Connection(out_name, self)
self.connection_map[out_name].set_from_node(new_node)
return new_node
def create_initializer(self, name, value: np.ndarray):
"""创建一个 initializer"""
assert name not in self.initializer_map, f"initializer {name} already exists!"
init_node = numpy_helper.from_array(value, name=name)
use_external_data = value.nbytes / 1024 / 1024 / 1024 > 2
if use_external_data:
print("use external data:", name)
init_node.data_location = onnx.TensorProto.EXTERNAL
location = name.replace('/', '+') + '.data'
onnx.external_data_helper.set_external_data(init_node, location)
with tempfile.TemporaryDirectory() as tmp_dir:
onnx.external_data_helper.save_external_data(init_node, tmp_dir)
init_node.ClearField("raw_data")
self.graph.initializer.append(init_node)
onnx.external_data_helper.load_external_data_for_tensor(
self.graph.initializer[-1], tmp_dir)
del self.graph.initializer[-1].external_data[:]
self.graph.initializer[-1].ClearField("data_location")
else:
self.graph.initializer.append(init_node)
self.initializer_map[name] = self.graph.initializer[-1]
return self.graph.initializer[-1]
def create_value_info(self, name, dtype=None, shape=None):
if dtype is None:
elem_type = None
else:
assert isinstance(dtype, str)
assert dtype in set(SUPPORT_DTYPES)
elem_type = getattr(onnx.TensorProto, dtype.upper())
value_info = onnx.helper.make_tensor_value_info(name=name,
elem_type=elem_type,
shape=shape)
self.graph.value_info.append(value_info)
self.value_info_map[name] = self.graph.value_info[-1]
return self.graph.value_info[-1]
def get_initializer_value(self, name):
"""获取initializer的数值"""
init = self.get_initializer(name)
return numpy_helper.to_array(init)
def set_initializer_value(self, name, value: np.ndarray):
"""为initializer设置新的数值"""
init = self.get_initializer(name)
# 检查形状和类型
old_shape = list(init.dims)
new_shape = list(value.shape)
# old_dtype = TENSOR_TYPE_TO_NP_TYPE.get(init.data_type, None)
old_dtype = tensor_dtype_to_np_dtype(init.data_type)
new_dtype = value.dtype
if old_shape != new_shape:
warn_message = f"Initailizer {name} shape changed: {old_shape} -> {new_shape}"
warnings.warn(warn_message, RuntimeWarning)
if old_dtype is not None and old_dtype != new_dtype:
warn_message = f"Initailizer {name} dtype changed: {old_dtype} -> {new_dtype}"
warnings.warn(warn_message, RuntimeWarning)
new_tensor_proto = numpy_helper.from_array(value, name=name)
init.CopyFrom(new_tensor_proto)
def connect_node(self, node, inputs_map, outputs_map):
"""将某个节点与其上下游节点连接起来
Args:
node: Node
inputs_map: [(node0, out_idx0), (node1, out_idx1), ...]
outputs_map: [(node0, in_idx0), (node1, in_idx1), ...]
"""
# When connecting A -> B, if A's output name conflicts with B's input name, A's output name wins,
# i.e. B.input[i] = A.output[j]
for i, (n, j) in enumerate(inputs_map):
if isinstance(n, str):
n = self.node_map[n]
assert j < len(n.outputs), \
f"output index {j} out of node {n.name} outputs range"
node.set_input(i, n.outputs[j])
for name, (n, i) in zip(node.outputs, outputs_map):
if isinstance(n, str):
n = self.node_map[n]
assert i < len(n.outputs), \
f"output index {i} out of node {n.name} outputs range"
n.set_output(i, name)
# TODO: update self.connection_map
def pop_node(self, node: Union[str, Node, int], auto_connect=True):
"""根据节点名称或索引移除节点"""
if isinstance(node, str):
node = self.node_map.get(node, None)
if node is None:
return None
index = node.index
assert node.name == self.graph.node[index].name
elif isinstance(node, int):
if node >= len(self.graph.node):
raise ValueError(f"Node index {node} out of range")
index = node
elif isinstance(node, Node):
index = node.index
else:
raise ValueError(f"Invalid node name or index: {node}")
for i in range(index + 1, len(self.graph.node)):
node = self.graph.node[i]
self.node_map[node.name].index -= 1
# print(f"node_name={self.graph.node[index].name} node_index={index}")
_node_obj = self.graph.node[index]
_node = self.get_node(_node_obj.name)
next_nodes = self.get_next_nodes(_node)
self.graph.node.pop(index)
self.node_map.pop(_node_obj.name)
# automatic connecting edges
if auto_connect and len(_node.inputs) == 1 and len(_node.outputs) == 1:
# self.connection_map[_node.inputs[0]].pop_to_node(_node)
for next_node in next_nodes:
next_node.replace_input( _node.outputs[0], _node.inputs[0])
# self.connection_map[_node.inputs[0]].add_to_node(next_node)
# self.connection_map.pop(_node.outputs[0])
# update connection_map
for in_name in _node.input_names:
if _node.name in self.connection_map[in_name].to_node_names:
self.connection_map[in_name].pop_to_node(_node)
for i, out_name in enumerate(_node.output_names):
self.connection_map[out_name].clear_from_node()
return _node
def remove_nodes(self, nodes: List[str | Node], auto_connect=False):
"""同时删除多个节点"""
indices = set()
_nodes = []
invalid_nodes = set()
for node in nodes:
if isinstance(node, str):
if node in self.node_map:
node = self.node_map[node]
if node.index not in indices:
_nodes.append(node)
indices.add(node.index)
else:
invalid_nodes.add(node)
elif isinstance(node, Node):
if node.index not in indices:
_nodes.append(node)
indices.add(node.index)
else:
invalid_nodes.add(node)
_nodes.sort(key=lambda x:x.index, reverse=True)
use_progress_bar = len(_nodes) > 500
if use_progress_bar:
pbar = tqdm.tqdm(total=len(_nodes), desc="Removing nodes")
for node in _nodes:
self.pop_node(node, auto_connect=auto_connect)
if use_progress_bar:
pbar.update(1)
if use_progress_bar:
pbar.close()
# print(f"{len(nodes) - len(invalid_nodes)} nodes have been removed.")
# if len(invalid_nodes) > 0:
# print(f"find {len(invalid_nodes)} invalid nodes:\n", invalid_nodes)
def pop_initializer(self, init_name: str, update_node_inputs: bool = True):
"""根据initializer名字移除initializer"""
_init1 = self.initializer_map.pop(init_name)
init_index = None
for i in range(len(self.graph.initializer)):
if self.graph.initializer[i].name == init_name:
init_index = i
break
else:
raise ValueError(f"Not existing a Initializer named {init_name}")
_init2 = self.graph.initializer.pop(init_index)
assert id(_init1) == id(_init2)
# if update_node_inputs and init_name in self.connection_map:
# to_nodes = self.get_to_nodes(init_name)
# self.connection_map.pop(init_name)
# for node in to_nodes:
# num_inputs = len(node.inputs)
# for i in range(num_inputs-1, -1, -1):
# if node.inputs[i] == init_name:
# node.inputs.pop(i)
return _init1
def update_map(self):
"""更新connection_map与node_map"""
self.node_map.clear()
self.connection_map.clear()
self.initializer_map.clear()
self.sparse_initializer_map.clear()
self.value_info_map.clear()
for i, node in enumerate(self.graph.node):
new_node = Node(self, node, i)
self.node_map[node.name] = new_node
for conn_name in node.input:
if conn_name not in self.connection_map:
self.connection_map[conn_name] = Connection(conn_name, self)
self.connection_map[conn_name].add_to_node(new_node)
for conn_name in node.output:
if conn_name not in self.connection_map:
self.connection_map[conn_name] = Connection(conn_name, self)
self.connection_map[conn_name].set_from_node(new_node)
for i, node in enumerate(self.graph.initializer):
self.initializer_map[node.name] = node
for i, node in enumerate(self.graph.sparse_initializer):
self.sparse_initializer_map[node.name] = [node, i]
for i, conn in enumerate(self.graph.value_info):
self.value_info_map[conn.name] = conn
def find_unuseful_nodes(self):
"""寻找没有用到的节点"""
end_names = set()
for output_name in self.output_names:
end_names.add(self.get_from_node(output_name).name)
unuseful_names = set()
for node in self.node_map.values():
if node.name in end_names:
continue
next_nodes = self.get_next_nodes(node)
if len(next_nodes) == 0:
unuseful_names.add(node.name)
# print("A find unuseful node:", node.name)
q = deque([self.node_map[name] for name in unuseful_names])
while len(q) != 0:
node = q.popleft()
prev_nodes = self.get_prev_nodes(node)
for node1 in prev_nodes:
next_nodes = self.get_next_nodes(node1)
next_names = set([node2.name for node2 in next_nodes])
# if (next_names - end_names).issubset(unuseful_names):
if next_names.issubset(unuseful_names):
if node1.name not in unuseful_names:
q.append(node1)
unuseful_names.add(node1.name)
# print("B find unuseful node:", node1.name)
unuseful_nodes = [self.node_map[name] for name in unuseful_names]
return unuseful_nodes
def remove_trash(self):
"""
1. remove unused nodes
2. remove unused initializers
3. remove connections that have no producer node
4. remove unused model inputs and outputs
"""
self.update_map()
unuseful_nodes = self.find_unuseful_nodes()
print(f"Find unuseful {len(unuseful_nodes)} nodes!")
self.remove_nodes(unuseful_nodes)
self.update_map()
all_node_inputs = set()
for node in self.node_map.values():
all_node_inputs.update(node.input_names)
# remove unuseful initializers
for init_name in list(self.initializer_map.keys()):
if init_name in all_node_inputs:
continue
index = None
for i, init in enumerate(self.graph.initializer):
if init.name == init_name:
index = i
break
else:
raise ValueError(
f"{init_name} not in model.graph.initializer")
self.graph.initializer.pop(index)
# remove unuseful sparse_initializers
for init_name in list(self.sparse_initializer_map.keys()):
if init_name in all_node_inputs:
continue
index = None
for i, init in enumerate(self.graph.sparse_initializer):
if init.name == init_name:
index = i
break
else:
raise ValueError(
f"{init_name} not in model.graph.sparse_initializer")
self.graph.sparse_initializer.pop(index)
self.update_map()
# remove unuseful inputs and outputs
for in_name in self.input_names:
# print(in_name, [n.name for n in self.get_to_nodes(in_name)])
if len(self.get_to_nodes(in_name)) != 0:
continue
for i, _in in enumerate(self.graph.input):
if in_name == _in.name:
self.graph.input.pop(i)
break
for out_name in self.output_names:
# print(out_name, self.get_from_node(out_name).name)
if self.get_from_node(out_name) is not None:
continue
for i, _out in enumerate(self.graph.output):
if out_name == _out.name:
self.graph.output.pop(i)
break
self.update_map()
def infer_shape(self, relative_nodes=[]):
if len(relative_nodes) == 0:
relative_nodes = [node for node in self.graph.node]
for node in relative_nodes:
if isinstance(node, Node):
node = node.obj
for output_name in node.output:
if output_name in self.value_info_map:
value_info = self.value_info_map[output_name]
if value_info.type.HasField('tensor_type'):
if value_info.type.tensor_type.HasField('shape'):
value_info.type.tensor_type.ClearField('shape')
self.model = shape_inference.infer_shapes(
self.model,
strict_mode=True
)
self.update_map()
def infer_node_shape(self, node):
input_shapes = []
input_dtypes = []
for input_name in node.inputs:
value_info = self.value_info_map[input_name]
input_shapes.append(value_info.type.tensor_type.dims)
input_dtypes.append(value_info.type.tensor_type.type)
shape_inference.infer_node_outputs(node.obj, input_shapes, input_dtypes)
def convert_float_to_float16(self):
self.model = float16.convert_float_to_float16(self.model, keep_io_types=True)
def save(self, save_path, save_as_external_data=False,
all_tensors_to_one_file=True):
self.remove_trash()
external_data_name = osp.basename(save_path) + '.data'
external_data_path = osp.join(osp.dirname(save_path), external_data_name)
if save_as_external_data and osp.isfile(external_data_path):
os.remove(external_data_path)
onnx.save(self.model,
save_path,
save_as_external_data=save_as_external_data,
all_tensors_to_one_file=all_tensors_to_one_file,
location=external_data_name,
size_threshold=1024,
convert_attribute=False)
import migraphx
# migraphx-driver compile optimized_bert_best.onnx --fp16 --binary --output new_modle_1.mxr --input-dim @input 64 256
maxInput={"input":[64,256]}
model = migraphx.parse_onnx("/home/sunzhq/workspace/yidong/bert/bert4torch_cmcc/examples/sequence_labeling/bert_best.onnx", map_input_dims=maxInput)
inputName=list(model.get_inputs().keys())[0]
migraphx.quantize_fp16(model)
model.compile(t=migraphx.get_target("gpu"), offload_copy=False, device_id=0)
migraphx.save(model, "bert_best_fp16.mxr")
export HIP_VISIBLE_DEVICES=3
export MIGRAPHX_ENABLE_GEMM_SOFTMAX_GEMM_FUSE=1 # improves performance, may affect accuracy
export MIGRAPHX_ENABLE_MHA=1
migraphx-driver compile /home/sunzhq/workspace/yidong-infer/bert/bert4torch_cmcc/examples/sequence_labeling/models/bert_best_mha_mfd5.onnx \
--fp16 --binary \
--output /home/sunzhq/workspace/yidong-infer/bert/bert4torch_cmcc/examples/sequence_labeling/models/bert_best_mha_mfd5.mxr \
--input-dim @input 64 256
export MIGRAPHX_ENABLE_MHA=1
export MIGRAPHX_ENABLE_CUTLASS=1
export MIGRAPHX_ENABLE_GEMM_SOFTMAX_GEMM_FUSE=1
python bert_bias_seg.py /models/onnx-models/bert_best.onnx /home/sunzhq/workspace/yidong-infer/bert/bert4torch_cmcc/examples/sequence_labeling/tools/models/bert_best_1.onnx
# python modify_onnx.py
import argparse
import collections
import json
import os
import pickle
import torch
import logging
import shutil
from tqdm import tqdm
import time
logger = logging.Logger('log')
def get_path_from_url(url, root_dir, check_exist=True, decompress=True):
""" Download from given url to root_dir.
if file or directory specified by url is exists under
root_dir, return the path directly, otherwise download
from url and decompress it, return the path.
Args:
url (str): download url
root_dir (str): root dir for downloading, it should be
WEIGHTS_HOME or DATASET_HOME
decompress (bool): decompress zip or tar file. Default is `True`
Returns:
str: a local path to save downloaded models & weights & datasets.
"""
import os.path
import os
import tarfile
import zipfile
def is_url(path):
"""
Whether path is URL.
Args:
path (string): URL string or not.
"""
return path.startswith('http://') or path.startswith('https://')
def _map_path(url, root_dir):
# parse path after download under root_dir
fname = os.path.split(url)[-1]
fpath = fname
return os.path.join(root_dir, fpath)
def _get_download(url, fullname):
import requests
# using requests.get method
fname = os.path.basename(fullname)
try:
req = requests.get(url, stream=True)
except Exception as e: # requests.exceptions.ConnectionError
logger.info("Downloading {} from {} failed with exception {}".format(
fname, url, str(e)))
return False
if req.status_code != 200:
raise RuntimeError("Downloading from {} failed with code "
"{}!".format(url, req.status_code))
# To guard against interrupted downloads, download to
# tmp_fullname first, then move tmp_fullname to fullname
# after the download finishes
tmp_fullname = fullname + "_tmp"
total_size = req.headers.get('content-length')
with open(tmp_fullname, 'wb') as f:
if total_size:
with tqdm(total=(int(total_size) + 1023) // 1024, unit='KB') as pbar:
for chunk in req.iter_content(chunk_size=1024):
f.write(chunk)
pbar.update(1)
else:
for chunk in req.iter_content(chunk_size=1024):
if chunk:
f.write(chunk)
shutil.move(tmp_fullname, fullname)
return fullname
def _download(url, path):
"""
Download from url, save to path.
url (str): download url
path (str): download to given path
"""
if not os.path.exists(path):
os.makedirs(path)
fname = os.path.split(url)[-1]
fullname = os.path.join(path, fname)
retry_cnt = 0
logger.info("Downloading {} from {}".format(fname, url))
DOWNLOAD_RETRY_LIMIT = 3
while not os.path.exists(fullname):
if retry_cnt < DOWNLOAD_RETRY_LIMIT:
retry_cnt += 1
else:
raise RuntimeError("Download from {} failed. "
"Retry limit reached".format(url))
if not _get_download(url, fullname):
time.sleep(1)
continue
return fullname
def _uncompress_file_zip(filepath):
with zipfile.ZipFile(filepath, 'r') as files:
file_list = files.namelist()
file_dir = os.path.dirname(filepath)
if _is_a_single_file(file_list):
rootpath = file_list[0]
uncompressed_path = os.path.join(file_dir, rootpath)
files.extractall(file_dir)
elif _is_a_single_dir(file_list):
# `strip(os.sep)` removes any trailing `os.sep` from the path
rootpath = os.path.splitext(file_list[0].strip(os.sep))[0].split(
os.sep)[-1]
uncompressed_path = os.path.join(file_dir, rootpath)
files.extractall(file_dir)
else:
rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1]
uncompressed_path = os.path.join(file_dir, rootpath)
if not os.path.exists(uncompressed_path):
os.makedirs(uncompressed_path)
files.extractall(os.path.join(file_dir, rootpath))
return uncompressed_path
def _is_a_single_file(file_list):
if len(file_list) == 1 and file_list[0].find(os.sep) < 0:
return True
return False
def _is_a_single_dir(file_list):
new_file_list = []
for file_path in file_list:
if '/' in file_path:
file_path = file_path.replace('/', os.sep)
elif '\\' in file_path:
file_path = file_path.replace('\\', os.sep)
new_file_list.append(file_path)
file_name = new_file_list[0].split(os.sep)[0]
for i in range(1, len(new_file_list)):
if file_name != new_file_list[i].split(os.sep)[0]:
return False
return True
def _uncompress_file_tar(filepath, mode="r:*"):
with tarfile.open(filepath, mode) as files:
file_list = files.getnames()
file_dir = os.path.dirname(filepath)
if _is_a_single_file(file_list):
rootpath = file_list[0]
uncompressed_path = os.path.join(file_dir, rootpath)
files.extractall(file_dir)
elif _is_a_single_dir(file_list):
rootpath = os.path.splitext(file_list[0].strip(os.sep))[0].split(
os.sep)[-1]
uncompressed_path = os.path.join(file_dir, rootpath)
files.extractall(file_dir)
else:
rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1]
uncompressed_path = os.path.join(file_dir, rootpath)
if not os.path.exists(uncompressed_path):
os.makedirs(uncompressed_path)
files.extractall(os.path.join(file_dir, rootpath))
return uncompressed_path
def _decompress(fname):
"""
Decompress for zip and tar file
"""
logger.info("Decompressing {}...".format(fname))
# To protect against interrupted decompression, decompress to the
# fpath_tmp directory first; if decompression succeeds, move the
# decompressed files to fpath, then delete fpath_tmp and remove
# the downloaded archive.
if tarfile.is_tarfile(fname):
uncompressed_path = _uncompress_file_tar(fname)
elif zipfile.is_zipfile(fname):
uncompressed_path = _uncompress_file_zip(fname)
else:
raise TypeError("Unsupport compress file type {}".format(fname))
return uncompressed_path
assert is_url(url), "downloading from {} not a url".format(url)
fullpath = _map_path(url, root_dir)
if os.path.exists(fullpath) and check_exist:
logger.info("Found {}".format(fullpath))
else:
fullpath = _download(url, root_dir)
if decompress and (tarfile.is_tarfile(fullpath) or
zipfile.is_zipfile(fullpath)):
fullpath = _decompress(fullpath)
return fullpath
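# Example (hedged): get_path_from_url("https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/vocab.txt", "uie-base")
# downloads the file to ./uie-base/vocab.txt and returns that local path; zip/tar
# archives would additionally be decompressed when decompress=True.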
MODEL_MAP = {
"uie-base": {
"resource_file_urls": {
"model_state.pdparams":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base_v0.1/model_state.pdparams",
"model_config.json":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/model_config.json",
"vocab_file":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/vocab.txt",
"special_tokens_map":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/special_tokens_map.json",
"tokenizer_config":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/tokenizer_config.json"
}
},
"uie-medium": {
"resource_file_urls": {
"model_state.pdparams":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_medium_v1.0/model_state.pdparams",
"model_config.json":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_medium/model_config.json",
"vocab_file":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/vocab.txt",
"special_tokens_map":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/special_tokens_map.json",
"tokenizer_config":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/tokenizer_config.json",
}
},
"uie-mini": {
"resource_file_urls": {
"model_state.pdparams":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_mini_v1.0/model_state.pdparams",
"model_config.json":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_mini/model_config.json",
"vocab_file":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/vocab.txt",
"special_tokens_map":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/special_tokens_map.json",
"tokenizer_config":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/tokenizer_config.json",
}
},
"uie-micro": {
"resource_file_urls": {
"model_state.pdparams":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_micro_v1.0/model_state.pdparams",
"model_config.json":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_micro/model_config.json",
"vocab_file":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/vocab.txt",
"special_tokens_map":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/special_tokens_map.json",
"tokenizer_config":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/tokenizer_config.json",
}
},
"uie-nano": {
"resource_file_urls": {
"model_state.pdparams":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_nano_v1.0/model_state.pdparams",
"model_config.json":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_nano/model_config.json",
"vocab_file":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/vocab.txt",
"special_tokens_map":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/special_tokens_map.json",
"tokenizer_config":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/tokenizer_config.json",
}
},
"uie-medical-base": {
"resource_file_urls": {
"model_state.pdparams":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_medical_base_v0.1/model_state.pdparams",
"model_config.json":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/model_config.json",
"vocab_file":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/vocab.txt",
"special_tokens_map":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/special_tokens_map.json",
"tokenizer_config":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_base/tokenizer_config.json",
}
},
"uie-tiny": {
"resource_file_urls": {
"model_state.pdparams":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_tiny_v0.1/model_state.pdparams",
"model_config.json":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_tiny/model_config.json",
"vocab_file":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_tiny/vocab.txt",
"special_tokens_map":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_tiny/special_tokens_map.json",
"tokenizer_config":
"https://bj.bcebos.com/paddlenlp/taskflow/information_extraction/uie_tiny/tokenizer_config.json"
}
}
}
def build_params_map(attention_num=12):
"""
Build the parameter-name map from PaddlePaddle's ERNIE to a transformers-style BERT.
:return:
"""
weight_map = collections.OrderedDict({
'encoder.embeddings.word_embeddings.weight': "bert.embeddings.word_embeddings.weight",
'encoder.embeddings.position_embeddings.weight': "bert.embeddings.position_embeddings.weight",
'encoder.embeddings.token_type_embeddings.weight': "bert.embeddings.token_type_embeddings.weight",
'encoder.embeddings.task_type_embeddings.weight': "embeddings.task_type_embeddings.weight", # no 'bert.' prefix here, mapped directly onto the bert4torch structure
'encoder.embeddings.layer_norm.weight': 'bert.embeddings.LayerNorm.weight',
'encoder.embeddings.layer_norm.bias': 'bert.embeddings.LayerNorm.bias',
})
# add attention layers
for i in range(attention_num):
weight_map[f'encoder.encoder.layers.{i}.self_attn.q_proj.weight'] = f'bert.encoder.layer.{i}.attention.self.query.weight'
weight_map[f'encoder.encoder.layers.{i}.self_attn.q_proj.bias'] = f'bert.encoder.layer.{i}.attention.self.query.bias'
weight_map[f'encoder.encoder.layers.{i}.self_attn.k_proj.weight'] = f'bert.encoder.layer.{i}.attention.self.key.weight'
weight_map[f'encoder.encoder.layers.{i}.self_attn.k_proj.bias'] = f'bert.encoder.layer.{i}.attention.self.key.bias'
weight_map[f'encoder.encoder.layers.{i}.self_attn.v_proj.weight'] = f'bert.encoder.layer.{i}.attention.self.value.weight'
weight_map[f'encoder.encoder.layers.{i}.self_attn.v_proj.bias'] = f'bert.encoder.layer.{i}.attention.self.value.bias'
weight_map[f'encoder.encoder.layers.{i}.self_attn.out_proj.weight'] = f'bert.encoder.layer.{i}.attention.output.dense.weight'
weight_map[f'encoder.encoder.layers.{i}.self_attn.out_proj.bias'] = f'bert.encoder.layer.{i}.attention.output.dense.bias'
weight_map[f'encoder.encoder.layers.{i}.norm1.weight'] = f'bert.encoder.layer.{i}.attention.output.LayerNorm.weight'
weight_map[f'encoder.encoder.layers.{i}.norm1.bias'] = f'bert.encoder.layer.{i}.attention.output.LayerNorm.bias'
weight_map[f'encoder.encoder.layers.{i}.linear1.weight'] = f'bert.encoder.layer.{i}.intermediate.dense.weight'
weight_map[f'encoder.encoder.layers.{i}.linear1.bias'] = f'bert.encoder.layer.{i}.intermediate.dense.bias'
weight_map[f'encoder.encoder.layers.{i}.linear2.weight'] = f'bert.encoder.layer.{i}.output.dense.weight'
weight_map[f'encoder.encoder.layers.{i}.linear2.bias'] = f'bert.encoder.layer.{i}.output.dense.bias'
weight_map[f'encoder.encoder.layers.{i}.norm2.weight'] = f'bert.encoder.layer.{i}.output.LayerNorm.weight'
weight_map[f'encoder.encoder.layers.{i}.norm2.bias'] = f'bert.encoder.layer.{i}.output.LayerNorm.bias'
# add pooler
weight_map.update(
{
'encoder.pooler.dense.weight': 'bert.pooler.dense.weight',
'encoder.pooler.dense.bias': 'bert.pooler.dense.bias',
'linear_start.weight': 'linear_start.weight',
'linear_start.bias': 'linear_start.bias',
'linear_end.weight': 'linear_end.weight',
'linear_end.bias': 'linear_end.bias',
}
)
return weight_map
def extract_and_convert(input_dir, output_dir):
if not os.path.exists(output_dir):
os.makedirs(output_dir)
logger.info('=' * 20 + 'save config file' + '=' * 20)
config = json.load(open(os.path.join(input_dir, 'model_config.json'), 'rt', encoding='utf-8'))
config = config['init_args'][0]
config["architectures"] = ["UIE"]
config['layer_norm_eps'] = 1e-12
del config['init_class']
if 'sent_type_vocab_size' in config:
config['type_vocab_size'] = config['sent_type_vocab_size']
config['intermediate_size'] = 4 * config['hidden_size']
json.dump(config, open(os.path.join(output_dir, 'config.json'),
'wt', encoding='utf-8'), indent=4)
logger.info('=' * 20 + 'save vocab file' + '=' * 20)
with open(os.path.join(input_dir, 'vocab.txt'), 'rt', encoding='utf-8') as f:
words = f.read().splitlines()
words_set = set()
words_duplicate_indices = []
for i in range(len(words)-1, -1, -1):
word = words[i]
if word in words_set:
words_duplicate_indices.append(i)
words_set.add(word)
for i, idx in enumerate(words_duplicate_indices):
words[idx] = chr(0x1F6A9+i) # Change duplicated word to 🚩 LOL
with open(os.path.join(output_dir, 'vocab.txt'), 'wt', encoding='utf-8') as f:
for word in words:
f.write(word+'\n')
special_tokens_map = {
"unk_token": "[UNK]",
"sep_token": "[SEP]",
"pad_token": "[PAD]",
"cls_token": "[CLS]",
"mask_token": "[MASK]"
}
json.dump(special_tokens_map, open(os.path.join(output_dir, 'special_tokens_map.json'),
'wt', encoding='utf-8'))
tokenizer_config = {
"do_lower_case": True,
"unk_token": "[UNK]",
"sep_token": "[SEP]",
"pad_token": "[PAD]",
"cls_token": "[CLS]",
"mask_token": "[MASK]",
"tokenizer_class": "BertTokenizer"
}
json.dump(tokenizer_config, open(os.path.join(output_dir, 'tokenizer_config.json'),
'wt', encoding='utf-8'))
logger.info('=' * 20 + 'extract weights' + '=' * 20)
state_dict = collections.OrderedDict()
weight_map = build_params_map(attention_num=config['num_hidden_layers'])
paddle_paddle_params = pickle.load(
open(os.path.join(input_dir, 'model_state.pdparams'), 'rb'))
del paddle_paddle_params['StructuredToParameterName@@']
for weight_name, weight_value in paddle_paddle_params.items():
if 'weight' in weight_name:
if 'encoder.encoder' in weight_name or 'pooler' in weight_name or 'linear' in weight_name:
weight_value = weight_value.transpose()
# Fix: embedding error
if 'word_embeddings.weight' in weight_name:
weight_value[0, :] = 0
if weight_name not in weight_map:
logger.info(f"{'='*20} [SKIP] {weight_name} {'='*20}")
continue
state_dict[weight_map[weight_name]] = torch.FloatTensor(weight_value)
logger.info(f"{weight_name} -> {weight_map[weight_name]} {weight_value.shape}")
torch.save(state_dict, os.path.join(output_dir, "pytorch_model.bin"))
def check_model(input_model):
if not os.path.exists(input_model):
if input_model not in MODEL_MAP:
raise ValueError('input_model does not exist and is not a known model name!')
resource_file_urls = MODEL_MAP[input_model]['resource_file_urls']
logger.info("Downloading resource files...")
for key, val in resource_file_urls.items():
file_path = os.path.join(input_model, key)
if not os.path.exists(file_path):
get_path_from_url(val, input_model)
def do_main():
check_model(args.input_model)
extract_and_convert(args.input_model, args.output_model)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("-i", "--input_model", default="uie-base", type=str,
help="Directory of input paddle model.\n Will auto download model [uie-base/uie-tiny]")
parser.add_argument("-o", "--output_model", default="uie_base_pytorch", type=str,
help="Directory of output pytorch model")
args = parser.parse_args()
do_main()
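# Usage sketch (the script file name is assumed): downloads the PaddlePaddle checkpoint
# for the chosen model and writes config.json, vocab.txt, special_tokens_map.json,
# tokenizer_config.json and pytorch_model.bin into the output directory, e.g.
#   python convert.py --input_model uie-base --output_model uie_base_pytorch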
# Data generation, step 1
python finetune_step1_dataprocess.py
# Data generation, step 2
python finetune_step2_doccano.py \
--doccano_file ./data/mid_data/train.json \
--task_type "ext" \
--splits 1.0 0.0 0.0 \
--save_dir ./data/final_data/ \
--negative_ratio 3
python finetune_step2_doccano.py \
--doccano_file ./data/mid_data/dev.json \
--task_type "ext" \
--splits 0.0 1.0 0.0 \
--save_dir ./data/final_data/ \
--negative_ratio 0
python finetune_step2_doccano.py \
--doccano_file ./data/mid_data/test.json \
--task_type "ext" \
--splits 0.0 0.0 1.0 \
--save_dir ./data/final_data/ \
--negative_ratio 0
# Fine-tuning
python finetune_step3_train.py
# Data conversion, step 1
import os
import re
import json
en2ch = {
'ORG':'机构',
'PER':'人名',
'LOC':'籍贯'
}
def preprocess(input_path, save_path, mode):
if not os.path.exists(save_path):
os.makedirs(save_path)
data_path = os.path.join(save_path, mode + ".json")
result = []
tmp = {}
tmp['id'] = 0
tmp['text'] = ''
tmp['relations'] = []
tmp['entities'] = []
# ======= First, collect each sentence together with all of its entities and entity types =======
with open(input_path,'r',encoding='utf-8') as fp:
lines = fp.readlines()
texts = []
entities = []
words = []
entity_tmp = []
entities_tmp = []
entity_label = ''
for line in lines:
line = line.strip().split(" ")
if len(line) == 2:
word = line[0]
label = line[1]
words.append(word)
if "B-" in label:
entity_tmp.append(word)
entity_label = en2ch[label.split("-")[-1]]
elif "I-" in label:
entity_tmp.append(word)
if (label == 'O') and entity_tmp:
if ("".join(entity_tmp), entity_label) not in entities_tmp:
entities_tmp.append(("".join(entity_tmp), entity_label))
entity_tmp, entity_label = [], ''
else:
if entity_tmp and (("".join(entity_tmp), entity_label) not in entities_tmp):
entities_tmp.append(("".join(entity_tmp), entity_label))
entity_tmp, entity_label = [], ''
texts.append("".join(words))
entities.append(entities_tmp)
words = []
entities_tmp = []
# ==========================================
# ======= Then locate each entity's position within its sentence =======
i = 0
for text,entity in zip(texts, entities):
if entity:
ltmp = []
for ent,type in entity:
for span in re.finditer(ent, text):
start = span.start()
end = span.end()
ltmp.append((type, start, end, ent))
# print(ltmp)
ltmp = sorted(ltmp, key=lambda x:(x[1],x[2]))
for j in range(len(ltmp)):
# tmp['entities'].append(["".format(str(j)), ltmp[j][0], ltmp[j][1], ltmp[j][2], ltmp[j][3]])
tmp['entities'].append({"id":j, "start_offset":ltmp[j][1], "end_offset":ltmp[j][2], "label":ltmp[j][0]})
else:
tmp['entities'] = []
tmp['id'] = i
tmp['text'] = text
result.append(tmp)
tmp = {}
tmp['id'] = 0
tmp['text'] = ''
tmp['relations'] = []
tmp['entities'] = []
i += 1
with open(data_path, 'w', encoding='utf-8') as fp:
fp.write("\n".join([json.dumps(i, ensure_ascii=False) for i in result]))
preprocess("F:/Projects/data/corpus/ner/china-people-daily-ner-corpus/example.train", './data/mid_data', "train")
preprocess("F:/Projects/data/corpus/ner/china-people-daily-ner-corpus/example.dev", './data/mid_data', "dev")
preprocess("F:/Projects/data/corpus/ner/china-people-daily-ner-corpus/example.test", './data/mid_data', "test")
# Data generation, step 2
import os
import time
import argparse
import json
from decimal import Decimal
import numpy as np
from bert4torch.snippets import seed_everything
from utils import convert_ext_examples, convert_cls_examples, logger
def do_convert():
seed_everything(args.seed)
tic_time = time.time()
if not os.path.exists(args.doccano_file):
raise ValueError("Please input the correct path of doccano file.")
if not os.path.exists(args.save_dir):
os.makedirs(args.save_dir)
if len(args.splits) != 0 and len(args.splits) != 3:
raise ValueError("Only []/ len(splits)==3 accepted for splits.")
def _check_sum(splits):
return Decimal(str(splits[0])) + Decimal(str(splits[1])) + Decimal(
str(splits[2])) == Decimal("1")
if len(args.splits) == 3 and not _check_sum(args.splits):
raise ValueError(
"Please set correct splits, sum of elements in splits should be equal to 1."
)
with open(args.doccano_file, "r", encoding="utf-8") as f:
raw_examples = f.readlines()
def _create_ext_examples(examples,
negative_ratio=0,
shuffle=False,
is_train=True):
entities, relations = convert_ext_examples(
examples, negative_ratio, is_train=is_train)
examples = entities + relations
if shuffle:
indexes = np.random.permutation(len(examples))
examples = [examples[i] for i in indexes]
return examples
def _create_cls_examples(examples, prompt_prefix, options, shuffle=False):
examples = convert_cls_examples(examples, prompt_prefix, options)
if shuffle:
indexes = np.random.permutation(len(examples))
examples = [examples[i] for i in indexes]
return examples
def _save_examples(save_dir, file_name, examples):
count = 0
save_path = os.path.join(save_dir, file_name)
if not examples:
logger.info("Skip saving %d examples to %s." % (0, save_path))
return
with open(save_path, "w", encoding="utf-8") as f:
for example in examples:
f.write(json.dumps(example, ensure_ascii=False) + "\n")
count += 1
logger.info("Save %d examples to %s." % (count, save_path))
if len(args.splits) == 0:
if args.task_type == "ext":
examples = _create_ext_examples(raw_examples, args.negative_ratio,
args.is_shuffle)
else:
examples = _create_cls_examples(raw_examples, args.prompt_prefix,
args.options, args.is_shuffle)
_save_examples(args.save_dir, "train.txt", examples)
else:
if args.is_shuffle:
indexes = np.random.permutation(len(raw_examples))
raw_examples = [raw_examples[i] for i in indexes]
i1, i2, _ = args.splits
p1 = int(len(raw_examples) * i1)
p2 = int(len(raw_examples) * (i1 + i2))
if args.task_type == "ext":
train_examples = _create_ext_examples(
raw_examples[:p1], args.negative_ratio, args.is_shuffle)
dev_examples = _create_ext_examples(
raw_examples[p1:p2], -1, is_train=False)
test_examples = _create_ext_examples(
raw_examples[p2:], -1, is_train=False)
else:
train_examples = _create_cls_examples(
raw_examples[:p1], args.prompt_prefix, args.options)
dev_examples = _create_cls_examples(
raw_examples[p1:p2], args.prompt_prefix, args.options)
test_examples = _create_cls_examples(
raw_examples[p2:], args.prompt_prefix, args.options)
_save_examples(args.save_dir, "train.txt", train_examples)
_save_examples(args.save_dir, "dev.txt", dev_examples)
_save_examples(args.save_dir, "test.txt", test_examples)
logger.info('Finished! It takes %.2f seconds' % (time.time() - tic_time))
if __name__ == "__main__":
# yapf: disable
parser = argparse.ArgumentParser()
parser.add_argument("-d", "--doccano_file", default="./data/doccano.json",
type=str, help="The doccano file exported from doccano platform.")
parser.add_argument("-s", "--save_dir", default="./data",
type=str, help="The path of data that you wanna save.")
parser.add_argument("--negative_ratio", default=5, type=int,
help="Used only for the extraction task, the ratio of positive and negative samples, number of negtive samples = negative_ratio * number of positive samples")
parser.add_argument("--splits", default=[0.8, 0.1, 0.1], type=float, nargs="*",
help="The ratio of samples in datasets. [0.6, 0.2, 0.2] means 60%% samples used for training, 20%% for evaluation and 20%% for test.")
parser.add_argument("--task_type", choices=['ext', 'cls'], default="ext", type=str,
help="Select task type, ext for the extraction task and cls for the classification task, defaults to ext.")
parser.add_argument("--options", default=["正向", "负向"], type=str, nargs="+",
help="Used only for the classification task, the options for classification")
parser.add_argument("--prompt_prefix", default="情感倾向", type=str,
help="Used only for the classification task, the prompt prefix for classification")
parser.add_argument("--is_shuffle", default=True, type=bool,
help="Whether to shuffle the labeled dataset, defaults to True.")
parser.add_argument("--seed", type=int, default=1000,
help="random seed for initialization")
args = parser.parse_args()
do_convert()
import torch
from torch.utils.data import DataLoader
from model import uie_model, tokenizer
from bert4torch.snippets import seed_everything, sequence_padding, Callback
from torch import nn
from torch.utils.data import Dataset
import numpy as np
import json
from utils import get_bool_ids_greater_than, get_span
from random import sample
batch_size = 16
learning_rate = 1e-5
train_path = 'E:/Github/bert4torch/examples/sequence_labeling/uie/data/final_data/train.txt'
dev_path = 'E:/Github/bert4torch/examples/sequence_labeling/uie/data/final_data/dev.txt'
save_dir = './'
max_seq_len = 256
num_epochs = 10
device = 'cuda' if torch.cuda.is_available() else 'cpu'
seed_everything(42)
uie_model.to(device)
class IEDataset(Dataset):
"""信息抽取
"""
def __init__(self, file_path, tokenizer, max_seq_len, fewshot=None) -> None:
super().__init__()
self.file_path = file_path
if fewshot is None:
self.dataset = list(self.reader(file_path))
else:
assert isinstance(fewshot, int)
self.dataset = sample(list(self.reader(file_path)), fewshot)
self.tokenizer = tokenizer
self.max_seq_len = max_seq_len
def __len__(self):
return len(self.dataset)
def __getitem__(self, index):
return self.dataset[index]
@staticmethod
def reader(data_path, max_seq_len=512):
"""read json
"""
with open(data_path, 'r', encoding='utf-8') as f:
for line in f:
json_line = json.loads(line)
content = json_line['content']
prompt = json_line['prompt']
# Model input looks like: [CLS] Prompt [SEP] Content [SEP]
# It includes three special tokens.
if max_seq_len <= len(prompt) + 3:
raise ValueError("The value of max_seq_len is too small, please set a larger value")
max_content_len = max_seq_len - len(prompt) - 3
if len(content) <= max_content_len:
yield json_line
else:
result_list = json_line['result_list']
json_lines = []
accumulate = 0
while True:
cur_result_list = []
for result in result_list:
if result['start'] + 1 <= max_content_len < result['end']:
max_content_len = result['start']
break
cur_content = content[:max_content_len]
res_content = content[max_content_len:]
while True:
if len(result_list) == 0:
break
elif result_list[0]['end'] <= max_content_len:
if result_list[0]['end'] > 0:
cur_result = result_list.pop(0)
cur_result_list.append(cur_result)
else:
cur_result_list = [result for result in result_list]
break
else:
break
json_line = {'content': cur_content, 'result_list': cur_result_list, 'prompt': prompt}
json_lines.append(json_line)
for result in result_list:
if result['end'] <= 0:
break
result['start'] -= max_content_len
result['end'] -= max_content_len
accumulate += max_content_len
max_content_len = max_seq_len - len(prompt) - 3
if len(res_content) == 0:
break
elif len(res_content) < max_content_len:
json_line = {'content': res_content, 'result_list': result_list, 'prompt': prompt}
json_lines.append(json_line)
break
else:
content = res_content
for json_line in json_lines:
yield json_line
def collate_fn(batch):
"""example: {title, prompt, content, result_list}
"""
batch_token_ids, batch_token_type_ids, batch_start_ids, batch_end_ids = [], [], [], []
for example in batch:
token_ids, token_type_ids, offset_mapping = tokenizer.encode(example["prompt"], example["content"],
maxlen=max_seq_len, return_offsets='transformers')
bias = 0
for index in range(len(offset_mapping)):
if index == 0:
continue
mapping = offset_mapping[index]
if mapping[0] == 0 and mapping[1] == 0 and bias == 0:
bias = index
if mapping[0] == 0 and mapping[1] == 0:
continue
offset_mapping[index][0] += bias
offset_mapping[index][1] += bias
start_ids = [0 for _ in range(len(token_ids))]
end_ids = [0 for _ in range(len(token_ids))]
for item in example["result_list"]:
start = map_offset(item["start"] + bias, offset_mapping)
end = map_offset(item["end"] - 1 + bias, offset_mapping)
start_ids[start] = 1.0
end_ids[end] = 1.0
batch_token_ids.append(token_ids)
batch_token_type_ids.append(token_type_ids)
batch_start_ids.append(start_ids)
batch_end_ids.append(end_ids)
batch_token_ids = torch.tensor(sequence_padding(batch_token_ids), dtype=torch.long, device=device)
batch_token_type_ids = torch.tensor(sequence_padding(batch_token_type_ids), dtype=torch.long, device=device)
batch_start_ids = torch.tensor(sequence_padding(batch_start_ids), dtype=torch.float, device=device)
batch_end_ids = torch.tensor(sequence_padding(batch_end_ids), dtype=torch.float, device=device)
return [batch_token_ids, batch_token_type_ids], [batch_start_ids, batch_end_ids]
def map_offset(ori_offset, offset_mapping):
"""map ori offset to token offset
"""
for index, span in enumerate(offset_mapping):
if span[0] <= ori_offset < span[1]:
return index
return -1
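# Worked example: with offset_mapping [(0, 0), (0, 1), (1, 3), (3, 5)],
# map_offset(2, offset_mapping) returns 2, since character offset 2 falls inside
# the token span (1, 3); offsets not covered by any span return -1.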
# Data preparation
train_ds = IEDataset(train_path, tokenizer=tokenizer, max_seq_len=max_seq_len, fewshot=None)
dev_ds = IEDataset(dev_path, tokenizer=tokenizer, max_seq_len=max_seq_len)
train_dataloader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
valid_dataloader = DataLoader(dev_ds, batch_size=batch_size, collate_fn=collate_fn)
class MyLoss(nn.Module):
def forward(self, y_pred, y_true):
start_prob, end_prob = y_pred
start_ids, end_ids = y_true
loss_start = torch.nn.functional.binary_cross_entropy(start_prob, start_ids)
loss_end = torch.nn.functional.binary_cross_entropy(end_prob, end_ids)
return loss_start + loss_end
uie_model.compile(
loss=MyLoss(),
optimizer=torch.optim.AdamW(lr=learning_rate, params=uie_model.parameters()),
)
class SpanEvaluator(Callback):
"""SpanEvaluator computes the precision, recall and F1-score for span detection.
"""
def __init__(self):
self.num_infer_spans = 0
self.num_label_spans = 0
self.num_correct_spans = 0
self.best_val_f1 = 0
def on_epoch_end(self, steps, epoch, logs=None):
f1, precision, recall = self.evaluate(valid_dataloader)
if f1 > self.best_val_f1:
self.best_val_f1 = f1
# model.save_weights('best_model.pt')
print(f'[val-entity level] f1: {f1:.5f}, p: {precision:.5f} r: {recall:.5f}')
def evaluate(self, dataloader):
self.reset()
for x_true, y_true in dataloader:
start_prob, end_prob = uie_model.predict(*x_true)
start_ids, end_ids = y_true
num_correct, num_infer, num_label = self.compute(start_prob, end_prob, start_ids, end_ids)
self.update(num_correct, num_infer, num_label)
precision, recall, f1 = self.accumulate()
return f1, precision, recall
def compute(self, start_probs, end_probs, gold_start_ids, gold_end_ids):
"""Computes the precision, recall and F1-score for span detection.
"""
start_probs = start_probs.cpu().numpy()
end_probs = end_probs.cpu().numpy()
gold_start_ids = gold_start_ids.cpu().numpy()
gold_end_ids = gold_end_ids.cpu().numpy()
pred_start_ids = get_bool_ids_greater_than(start_probs)
pred_end_ids = get_bool_ids_greater_than(end_probs)
gold_start_ids = get_bool_ids_greater_than(gold_start_ids.tolist())
gold_end_ids = get_bool_ids_greater_than(gold_end_ids.tolist())
num_correct_spans = 0
num_infer_spans = 0
num_label_spans = 0
for predict_start_ids, predict_end_ids, label_start_ids, label_end_ids in zip(
pred_start_ids, pred_end_ids, gold_start_ids, gold_end_ids):
[_correct, _infer, _label] = self.eval_span(predict_start_ids, predict_end_ids, label_start_ids, label_end_ids)
num_correct_spans += _correct
num_infer_spans += _infer
num_label_spans += _label
return num_correct_spans, num_infer_spans, num_label_spans
def update(self, num_correct_spans, num_infer_spans, num_label_spans):
"""
This function takes (num_infer_spans, num_label_spans, num_correct_spans) as input,
to accumulate and update the corresponding status of the SpanEvaluator object.
"""
self.num_infer_spans += num_infer_spans
self.num_label_spans += num_label_spans
self.num_correct_spans += num_correct_spans
def eval_span(self, predict_start_ids, predict_end_ids, label_start_ids,
label_end_ids):
"""
evaluate position extraction (start, end)
return num_correct, num_infer, num_label
input: [1, 2, 10] [4, 12] [2, 10] [4, 11]
output: (1, 2, 2)
"""
pred_set = get_span(predict_start_ids, predict_end_ids)
label_set = get_span(label_start_ids, label_end_ids)
num_correct = len(pred_set & label_set)
num_infer = len(pred_set)
num_label = len(label_set)
return (num_correct, num_infer, num_label)
def accumulate(self):
"""
This function returns the mean precision, recall and f1 score for all accumulated minibatches.
Returns:
tuple: Returns tuple (`precision, recall, f1 score`).
"""
precision = float(self.num_correct_spans / self.num_infer_spans) if self.num_infer_spans else 0.
recall = float(self.num_correct_spans / self.num_label_spans) if self.num_label_spans else 0.
f1_score = float(2 * precision * recall / (precision + recall)) if self.num_correct_spans else 0.
return precision, recall, f1_score
def reset(self):
"""
Reset function empties the evaluation memory for previous mini-batches.
"""
self.num_infer_spans = 0
self.num_label_spans = 0
self.num_correct_spans = 0
if __name__ == "__main__":
evaluator = SpanEvaluator()
print('zero_shot performance: ', evaluator.evaluate(valid_dataloader))
uie_model.fit(train_dataloader, epochs=num_epochs, steps_per_epoch=None, callbacks=[evaluator])
import numpy as np
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from bert4torch.snippets import sequence_padding, Callback, ListDataset, seed_everything
from bert4torch.losses import FocalLoss
from bert4torch.tokenizers import Tokenizer
from bert4torch.models import build_transformer_model, BaseModel, BERT
from tqdm import tqdm
config_path = 'F:/Projects/pretrain_ckpt/uie/uie_base_pytorch/config.json'
checkpoint_path = 'F:/Projects/pretrain_ckpt/uie/uie_base_pytorch/pytorch_model.bin'
dict_path = 'F:/Projects/pretrain_ckpt/uie/uie_base_pytorch/vocab.txt'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
tokenizer = Tokenizer(dict_path, do_lower_case=True)
class UIE(BERT):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
hidden_size = self.hidden_size
self.linear_start = nn.Linear(hidden_size, 1)
self.linear_end = nn.Linear(hidden_size, 1)
self.sigmoid = nn.Sigmoid()
if kwargs.get('use_task_id'):
# Add task type embedding to BERT
task_type_embeddings = nn.Embedding(kwargs.get('task_type_vocab_size'), self.hidden_size)
self.embeddings.task_type_embeddings = task_type_embeddings
def hook(module, input, output):
return output+task_type_embeddings(torch.zeros(input[0].size(), dtype=torch.int64, device=input[0].device))
self.embeddings.word_embeddings.register_forward_hook(hook)
def forward(self, token_ids, token_type_ids):
outputs = super().forward([token_ids, token_type_ids])
sequence_output = outputs[0]
start_logits = self.linear_start(sequence_output)
start_logits = torch.squeeze(start_logits, -1)
start_prob = self.sigmoid(start_logits)
end_logits = self.linear_end(sequence_output)
end_logits = torch.squeeze(end_logits, -1)
end_prob = self.sigmoid(end_logits)
return start_prob, end_prob
@torch.no_grad()
def predict(self, token_ids, token_type_ids):
self.eval()
start_prob, end_prob = self.forward(token_ids, token_type_ids)
return start_prob, end_prob
uie_model = build_transformer_model(config_path=config_path, checkpoint_path=checkpoint_path, model=UIE, with_pool=True)
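# Usage sketch (illustrative token ids; the training/predict scripts move uie_model to `device`):
#   token_ids = torch.tensor([[101, 2769, 102]])
#   start_prob, end_prob = uie_model.predict(token_ids, torch.zeros_like(token_ids))
#   # both outputs are per-token sigmoid probabilities of shape [batch_size, seq_len]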
import numpy as np
import math
import torch
from bert4torch.snippets import sequence_padding
from utils import get_bool_ids_greater_than, get_span, get_id_and_prob, cut_chinese_sent, dbc2sbc
from pprint import pprint
import torch.nn.functional as F
class UIEPredictor(object):
def __init__(self, schema, device='cpu', position_prob=0.5, max_seq_len=512, batch_size=64, split_sentence=False):
self._device = device
self._position_prob = position_prob
self._max_seq_len = max_seq_len
self._batch_size = batch_size
self._split_sentence = split_sentence
self._schema_tree = None
self.set_schema(schema)
from model import uie_model, tokenizer
self._tokenizer = tokenizer
self.model = uie_model.to(self._device)
def set_schema(self, schema):
if isinstance(schema, dict) or isinstance(schema, str):
schema = [schema]
self._schema_tree = self._build_tree(schema)
def __call__(self, inputs):
texts = inputs
texts = [texts] if isinstance(texts, str) else texts
results = self._multi_stage_predict(texts)
return results
def _multi_stage_predict(self, datas):
"""构建schema tree和预测
"""
results = [{} for _ in range(len(datas))]
# input check to early return
if len(datas) < 1 or self._schema_tree is None:
return results
# copy so that `self._schema_tree` stays unchanged
schema_list = self._schema_tree.children[:]
while len(schema_list) > 0:
node = schema_list.pop(0)
examples = []
input_map = {}
cnt = 0
idx = 0
if not node.prefix:
for data in datas:
examples.append({"text": data, "prompt": dbc2sbc(node.name)})
input_map[cnt] = [idx]
idx += 1
cnt += 1
else:
for pre, data in zip(node.prefix, datas):
if len(pre) == 0:
input_map[cnt] = []
else:
for p in pre:
examples.append({ "text": data, "prompt": dbc2sbc(p + node.name)})
input_map[cnt] = [i + idx for i in range(len(pre))]
idx += len(pre)
cnt += 1
if len(examples) == 0:
result_list = []
else:
result_list = self._single_stage_predict(examples)
if not node.parent_relations:
relations = [[] for i in range(len(datas))]
for k, v in input_map.items():
for idx in v:
if len(result_list[idx]) == 0:
continue
if node.name not in results[k].keys():
results[k][node.name] = result_list[idx]
else:
results[k][node.name].extend(result_list[idx])
if node.name in results[k].keys():
relations[k].extend(results[k][node.name])
else:
relations = node.parent_relations
for k, v in input_map.items():
for i in range(len(v)):
if len(result_list[v[i]]) == 0:
continue
if "relations" not in relations[k][i].keys():
relations[k][i]["relations"] = {
node.name: result_list[v[i]]
}
elif node.name not in relations[k][i]["relations"].keys(
):
relations[k][i]["relations"][
node.name] = result_list[v[i]]
else:
relations[k][i]["relations"][node.name].extend(
result_list[v[i]])
new_relations = [[] for i in range(len(datas))]
for i in range(len(relations)):
for j in range(len(relations[i])):
if "relations" in relations[i][j].keys(
) and node.name in relations[i][j]["relations"].keys():
for k in range(
len(relations[i][j]["relations"][
node.name])):
new_relations[i].append(relations[i][j][
"relations"][node.name][k])
relations = new_relations
prefix = [[] for _ in range(len(datas))]
for k, v in input_map.items():
for idx in v:
for i in range(len(result_list[idx])):
prefix[k].append(result_list[idx][i]["text"] + "的")
for child in node.children:
child.prefix = prefix
child.parent_relations = relations
schema_list.append(child)
return results
def _convert_ids_to_results(self, examples, sentence_ids, probs):
"""
Convert ids to raw text in a single stage.
"""
results = []
for example, sentence_id, prob in zip(examples, sentence_ids, probs):
if len(sentence_id) == 0:
results.append([])
continue
result_list = []
text = example["text"]
prompt = example["prompt"]
for i in range(len(sentence_id)):
start, end = sentence_id[i]
if start < 0 and end >= 0:
continue
if end < 0:
start += (len(prompt) + 1)
end += (len(prompt) + 1)
result = {"text": prompt[start:end],
"probability": prob[i]}
result_list.append(result)
else:
result = {
"text": text[start:end],
"start": start,
"end": end,
"probability": prob[i]
}
result_list.append(result)
results.append(result_list)
return results
def _auto_splitter(self, input_texts, max_text_len, split_sentence=False):
'''
Split the raw texts automatically for model inference.
Args:
input_texts (List[str]): input raw texts.
max_text_len (int): cutting length.
split_sentence (bool): If True, sentence-level split will be performed.
return:
short_input_texts (List[str]): the short input texts for model inference.
input_mapping (dict): mapping between raw text and short input texts.
'''
input_mapping = {}
short_input_texts = []
cnt_org = 0
cnt_short = 0
for text in input_texts:
if not split_sentence:
sens = [text]
else:
sens = cut_chinese_sent(text)
for sen in sens:
lens = len(sen)
if lens <= max_text_len:
short_input_texts.append(sen)
if cnt_org not in input_mapping.keys():
input_mapping[cnt_org] = [cnt_short]
else:
input_mapping[cnt_org].append(cnt_short)
cnt_short += 1
else:
temp_text_list = [sen[i:i + max_text_len] for i in range(0, lens, max_text_len)]
short_input_texts.extend(temp_text_list)
short_idx = cnt_short
cnt_short += math.ceil(lens / max_text_len)
temp_text_id = [short_idx + i for i in range(cnt_short - short_idx)]
if cnt_org not in input_mapping.keys():
input_mapping[cnt_org] = temp_text_id
else:
input_mapping[cnt_org].extend(temp_text_id)
cnt_org += 1
return short_input_texts, input_mapping
def _single_stage_predict(self, inputs):
input_texts = []
prompts = []
for i in range(len(inputs)):
input_texts.append(inputs[i]["text"])
prompts.append(inputs[i]["prompt"])
# max predict length should exclude the length of prompt and summary tokens
max_predict_len = self._max_seq_len - max(len(p) for p in prompts) - 3
short_input_texts, self.input_mapping = self._auto_splitter(input_texts, max_predict_len, split_sentence=self._split_sentence)
short_texts_prompts = []
for k, v in self.input_mapping.items():
short_texts_prompts.extend([prompts[k] for i in range(len(v))])
short_inputs = [{"text": short_input_texts[i], "prompt": short_texts_prompts[i]} for i in range(len(short_input_texts))]
token_ids, segment_ids, offset_maps = self._tokenizer.encode(short_texts_prompts, short_input_texts, maxlen=self._max_seq_len, return_offsets='transformers')
start_prob_concat, end_prob_concat = [], []
for batch_start in range(0, len(short_input_texts), self._batch_size):
batch_token_ids = token_ids[batch_start:batch_start+self._batch_size]
batch_segment_ids = segment_ids[batch_start:batch_start+self._batch_size]
batch_token_ids = torch.tensor(sequence_padding(batch_token_ids), dtype=torch.long, device=self._device)
batch_segment_ids = torch.tensor(sequence_padding(batch_segment_ids), dtype=torch.long, device=self._device)
start_prob, end_prob = self.model.predict(batch_token_ids, batch_segment_ids)
start_prob_concat.append(start_prob.cpu().numpy())
end_prob_concat.append(end_prob.cpu().numpy())
start_prob_concat = np.concatenate(start_prob_concat)
end_prob_concat = np.concatenate(end_prob_concat)
start_ids_list = get_bool_ids_greater_than(start_prob_concat, limit=self._position_prob, return_prob=True)
end_ids_list = get_bool_ids_greater_than(end_prob_concat, limit=self._position_prob, return_prob=True)
sentence_ids = []
probs = []
for start_ids, end_ids, ids, offset_map in zip(start_ids_list, end_ids_list, token_ids, offset_maps):
for i in reversed(range(len(ids))):
if ids[i] != 0:
ids = ids[:i]
break
span_list = get_span(start_ids, end_ids, with_prob=True)
sentence_id, prob = get_id_and_prob(span_list, offset_map)
sentence_ids.append(sentence_id)
probs.append(prob)
results = self._convert_ids_to_results(short_inputs, sentence_ids, probs)
results = self._auto_joiner(results, short_input_texts, self.input_mapping)
return results
def _auto_joiner(self, short_results, short_inputs, input_mapping):
concat_results = []
is_cls_task = False
for short_result in short_results:
if short_result == []:
continue
elif 'start' not in short_result[0].keys(
) and 'end' not in short_result[0].keys():
is_cls_task = True
break
else:
break
for k, vs in input_mapping.items():
if is_cls_task:
cls_options = {}
single_results = []
for v in vs:
if len(short_results[v]) == 0:
continue
if short_results[v][0]['text'] not in cls_options.keys():
cls_options[short_results[v][0][
'text']] = [1, short_results[v][0]['probability']]
else:
cls_options[short_results[v][0]['text']][0] += 1
cls_options[short_results[v][0]['text']][
1] += short_results[v][0]['probability']
if len(cls_options) != 0:
cls_res, cls_info = max(cls_options.items(),
key=lambda x: x[1])
concat_results.append([{
'text': cls_res,
'probability': cls_info[1] / cls_info[0]
}])
else:
concat_results.append([])
else:
offset = 0
single_results = []
for v in vs:
if v == 0:
single_results = short_results[v]
offset += len(short_inputs[v])
else:
for i in range(len(short_results[v])):
if 'start' not in short_results[v][
i] or 'end' not in short_results[v][i]:
continue
short_results[v][i]['start'] += offset
short_results[v][i]['end'] += offset
offset += len(short_inputs[v])
single_results.extend(short_results[v])
concat_results.append(single_results)
return concat_results
def predict(self, input_data):
results = self._multi_stage_predict(input_data)
return results
@classmethod
def _build_tree(cls, schema, name='root'):
"""
Build the schema tree.
"""
schema_tree = SchemaTree(name)
for s in schema:
if isinstance(s, str):
schema_tree.add_child(SchemaTree(s))
elif isinstance(s, dict):
for k, v in s.items():
if isinstance(v, str):
child = [v]
elif isinstance(v, list):
child = v
else:
raise TypeError("Invalid schema, value for each key:value pairs should be list or string but {} received".format(type(v)))
schema_tree.add_child(cls._build_tree(child, name=k))
else:
raise TypeError("Invalid schema, element should be string or dict, but {} received".format(type(s)))
return schema_tree
class SchemaTree(object):
"""SchemaTree的实现
"""
def __init__(self, name='root', children=None):
self.name = name
self.children = []
self.prefix = None
self.parent_relations = None
if children is not None:
for child in children:
self.add_child(child)
def __repr__(self):
return self.name
def add_child(self, node):
assert isinstance(node, SchemaTree), "The children of a node should be an instance of SchemaTree."
self.children.append(node)
if __name__ == '__main__':
# Named entity recognition
schema = ['时间', '选手', '赛事名称'] # Define the schema for entity extraction
ie = UIEPredictor(schema=schema)
pprint(ie("2月8日上午北京冬奥会自由式滑雪女子大跳台决赛中中国选手谷爱凌以188.25分获得金牌!"))
schema = ['肿瘤的大小', '肿瘤的个数', '肝癌级别', '脉管内癌栓分级']
ie.set_schema(schema)
pprint(ie("(右肝肿瘤)肝细胞性肝癌(II-III级,梁索型和假腺管型),肿瘤包膜不完整,紧邻肝被膜,侵及周围肝组织,未见脉管内癌栓(MVI分级:M0级)及卫星子灶形成。(肿物1个,大小4.2×4.0×2.8cm)。"))
# Relation extraction
schema = {'竞赛名称': ['主办方', '承办方', '已举办次数']}
ie.set_schema(schema) # Reset schema
pprint(ie('2022语言与智能技术竞赛由中国中文信息学会和中国计算机学会联合主办,百度公司、中国中文信息学会评测工作委员会和中国计算机学会自然语言处理专委会承办,已连续举办4届,成为全球最热门的中文NLP赛事之一。'))
# Event extraction
schema = {'地震触发词': ['地震强度', '时间', '震中位置', '震源深度']}
ie.set_schema(schema) # Reset schema
ie('中国地震台网正式测定:5月16日06时08分在云南临沧市凤庆县(北纬24.34度,东经99.98度)发生3.5级地震,震源深度10千米。')
# Opinion extraction from reviews
schema = {'评价维度': ['观点词', '情感倾向[正向,负向]']}
ie.set_schema(schema) # Reset schema
pprint(ie("店面干净,很清静,服务员服务热情,性价比很高,发现收银台有排队"))
# Sentiment classification
schema = '情感倾向[正向,负向]'
ie.set_schema(schema)
ie('这个产品用起来真的很流畅,我非常喜欢')
import contextlib
import functools
import json
import logging
import math
import random
import re
import shutil
import threading
import time
from functools import partial
import colorlog
import numpy as np
import torch
from colorama import Back, Fore
from tqdm import tqdm
loggers = {}
log_config = {
'DEBUG': {'level': 10, 'color': 'purple'},
'INFO': {'level': 20, 'color': 'green'},
'TRAIN': {'level': 21, 'color': 'cyan'},
'EVAL': {'level': 22, 'color': 'blue'},
'WARNING': {'level': 30, 'color': 'yellow'},
'ERROR': {'level': 40, 'color': 'red'},
'CRITICAL': {'level': 50, 'color': 'bold_red'}
}
def get_span(start_ids, end_ids, with_prob=False):
"""
Get span set from position start and end list.
Args:
start_ids (List[int]/List[tuple]): The start index list.
end_ids (List[int]/List[tuple]): The end index list.
with_prob (bool): If True, each element of start_ids and end_ids is a tuple like (index, probability).
Returns:
set: The span set without overlapping; every id can only be used once.
"""
if with_prob:
start_ids = sorted(start_ids, key=lambda x: x[0])
end_ids = sorted(end_ids, key=lambda x: x[0])
else:
start_ids = sorted(start_ids)
end_ids = sorted(end_ids)
start_pointer = 0
end_pointer = 0
len_start = len(start_ids)
len_end = len(end_ids)
couple_dict = {}
while start_pointer < len_start and end_pointer < len_end:
if with_prob:
start_id = start_ids[start_pointer][0]
end_id = end_ids[end_pointer][0]
else:
start_id = start_ids[start_pointer]
end_id = end_ids[end_pointer]
if start_id == end_id:
couple_dict[end_ids[end_pointer]] = start_ids[start_pointer]
start_pointer += 1
end_pointer += 1
continue
if start_id < end_id:
couple_dict[end_ids[end_pointer]] = start_ids[start_pointer]
start_pointer += 1
continue
if start_id > end_id:
end_pointer += 1
continue
result = [(couple_dict[end], end) for end in couple_dict]
result = set(result)
return result
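# Worked example (traced from the pointer logic above):
#   get_span([1, 2, 10], [4, 12]) -> {(2, 4), (10, 12)}
# each end index is paired with the closest start index not after it, and a later
# start overrides an earlier one for the same end, so the candidate (1, 4) is dropped.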
def get_bool_ids_greater_than(probs, limit=0.5, return_prob=False):
"""
Get the indices along the last dimension of the probability arrays whose values exceed the limit.
Args:
probs (List[List[float]]): The input probability arrays.
limit (float): The limitation for probability.
return_prob (bool): Whether to return the probability
Returns:
List[List[int]]: The indices along the last dimension that meet the condition.
"""
probs = np.array(probs)
dim_len = len(probs.shape)
if dim_len > 1:
result = []
for p in probs:
result.append(get_bool_ids_greater_than(p, limit, return_prob))
return result
else:
result = []
for i, p in enumerate(probs):
if p > limit:
if return_prob:
result.append((i, p))
else:
result.append(i)
return result
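# Example: get_bool_ids_greater_than([[0.1, 0.7], [0.6, 0.2]]) -> [[1], [0]];
# with return_prob=True the entries become (index, probability) tuples instead.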
class Logger(object):
'''
Default logger in UIE
Args:
name(str) : Logger name, default is 'UIE'
'''
def __init__(self, name: str = None):
name = 'UIE' if not name else name
self.logger = logging.getLogger(name)
for key, conf in log_config.items():
logging.addLevelName(conf['level'], key)
self.__dict__[key] = functools.partial(
self.__call__, conf['level'])
self.__dict__[key.lower()] = functools.partial(
self.__call__, conf['level'])
self.format = colorlog.ColoredFormatter(
'%(log_color)s[%(asctime)-15s] [%(levelname)8s]%(reset)s - %(message)s',
log_colors={key: conf['color']
for key, conf in log_config.items()})
self.handler = logging.StreamHandler()
self.handler.setFormatter(self.format)
self.logger.addHandler(self.handler)
self.logLevel = 'DEBUG'
self.logger.setLevel(logging.DEBUG)
self.logger.propagate = False
self._is_enable = True
def disable(self):
self._is_enable = False
def enable(self):
self._is_enable = True
@property
def is_enable(self) -> bool:
return self._is_enable
def __call__(self, log_level: str, msg: str):
if not self.is_enable:
return
self.logger.log(log_level, msg)
@contextlib.contextmanager
def use_terminator(self, terminator: str):
old_terminator = self.handler.terminator
self.handler.terminator = terminator
yield
self.handler.terminator = old_terminator
@contextlib.contextmanager
def processing(self, msg: str, interval: float = 0.1):
'''
Continuously print a progress bar with rotating special effects.
Args:
msg(str): Message to be printed.
interval(float): Rotation interval. Default to 0.1.
'''
end = False
def _printer():
index = 0
flags = ['\\', '|', '/', '-']
while not end:
flag = flags[index % len(flags)]
with self.use_terminator('\r'):
self.info('{}: {}'.format(msg, flag))
time.sleep(interval)
index += 1
t = threading.Thread(target=_printer)
t.start()
yield
end = True
logger = Logger()
BAR_FORMAT = f'{{desc}}: {Fore.GREEN}{{percentage:3.0f}}%{Fore.RESET} {Fore.BLUE}{{bar}}{Fore.RESET} {Fore.GREEN}{{n_fmt}}/{{total_fmt}} {Fore.RED}{{rate_fmt}}{{postfix}}{Fore.RESET} eta {Fore.CYAN}{{remaining}}{Fore.RESET}'
BAR_FORMAT_NO_TIME = f'{{desc}}: {Fore.GREEN}{{percentage:3.0f}}%{Fore.RESET} {Fore.BLUE}{{bar}}{Fore.RESET} {Fore.GREEN}{{n_fmt}}/{{total_fmt}}{Fore.RESET}'
BAR_TYPE = [
"░▝▗▖▘▚▞▛▙█",
"░▖▘▝▗▚▞█",
" ▖▘▝▗▚▞█",
"░▒█",
" >=",
" ▏▎▍▌▋▊▉█"
"░▏▎▍▌▋▊▉█"
]
tqdm = partial(tqdm, bar_format=BAR_FORMAT, ascii=BAR_TYPE[0], leave=False)
def get_id_and_prob(spans, offset_map):
prompt_length = 0
for i in range(1, len(offset_map)):
if offset_map[i] != [0, 0]:
prompt_length += 1
else:
break
for i in range(1, prompt_length + 1):
offset_map[i][0] -= (prompt_length + 1)
offset_map[i][1] -= (prompt_length + 1)
sentence_id = []
prob = []
for start, end in spans:
prob.append(start[1] * end[1])
sentence_id.append(
(offset_map[start[0]][0], offset_map[end[0]][1]))
return sentence_id, prob
def cut_chinese_sent(para):
"""
Cut Chinese text into sentences more precisely; adapted from
"https://blog.csdn.net/blmoistawinde/article/details/82379256".
"""
para = re.sub(r'([。!?\?])([^”’])', r'\1\n\2', para)
para = re.sub(r'(\.{6})([^”’])', r'\1\n\2', para)
para = re.sub(r'(\…{2})([^”’])', r'\1\n\2', para)
para = re.sub(r'([。!?\?][”’])([^,。!?\?])', r'\1\n\2', para)
para = para.rstrip()
return para.split("\n")
def dbc2sbc(s):
rs = ""
for char in s:
code = ord(char)
if code == 0x3000:
code = 0x0020
else:
code -= 0xfee0
if not (0x0021 <= code and code <= 0x7e):
rs += char
continue
rs += chr(code)
return rs
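# Example: dbc2sbc converts full-width characters to their half-width equivalents,
# e.g. dbc2sbc("ABC,123") -> "ABC,123"; characters whose converted code point
# falls outside the printable ASCII range are left unchanged.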
def convert_cls_examples(raw_examples, prompt_prefix, options):
examples = []
logger.info(f"Converting doccano data...")
with tqdm(total=len(raw_examples)) as pbar:
for line in raw_examples:
items = json.loads(line)
# Compatible with doccano >= 1.6.2
if "data" in items.keys():
text, labels = items["data"], items["label"]
else:
text, labels = items["text"], items["label"]
random.shuffle(options)
prompt = ""
sep = ","
for option in options:
prompt += option
prompt += sep
prompt = prompt_prefix + "[" + prompt.rstrip(sep) + "]"
result_list = []
example = {
"content": text,
"result_list": result_list,
"prompt": prompt
}
for label in labels:
start = prompt.rfind(label[0]) - len(prompt) - 1
end = start + len(label)
result = {"text": label, "start": start, "end": end}
example["result_list"].append(result)
examples.append(example)
return examples
def add_negative_example(examples, texts, prompts, label_set, negative_ratio):
negative_examples = []
positive_examples = []
with tqdm(total=len(prompts)) as pbar:
for i, prompt in enumerate(prompts):
negative_sample = []
redundants_list = list(set(label_set) ^ set(prompt))
redundants_list.sort()
num_positive = len(examples[i])
if num_positive != 0:
actual_ratio = math.ceil(len(redundants_list) / num_positive)
else:
# Set num_positive to 1 for text without positive example
num_positive, actual_ratio = 1, 0
if actual_ratio <= negative_ratio or negative_ratio == -1:
idxs = [k for k in range(len(redundants_list))]
else:
idxs = random.sample(
range(0, len(redundants_list)),
negative_ratio * num_positive)
for idx in idxs:
negative_result = {
"content": texts[i],
"result_list": [],
"prompt": redundants_list[idx]
}
negative_examples.append(negative_result)
positive_examples.extend(examples[i])
pbar.update(1)
return positive_examples, negative_examples
def add_full_negative_example(examples, texts, relation_prompts, predicate_set,
subject_goldens):
with tqdm(total=len(relation_prompts)) as pbar:
for i, relation_prompt in enumerate(relation_prompts):
negative_sample = []
for subject in subject_goldens[i]:
for predicate in predicate_set:
# The relation prompt is constructed as follows:
# subject + "的" + predicate
prompt = subject + "的" + predicate
if prompt not in relation_prompt:
negative_result = {
"content": texts[i],
"result_list": [],
"prompt": prompt
}
negative_sample.append(negative_result)
examples[i].extend(negative_sample)
pbar.update(1)
return examples
def construct_relation_prompt_set(entity_name_set, predicate_set):
relation_prompt_set = set()
for entity_name in entity_name_set:
for predicate in predicate_set:
# The relation prompt is constructed as follows:
# subject + "的" + predicate
relation_prompt = entity_name + "的" + predicate
relation_prompt_set.add(relation_prompt)
return sorted(list(relation_prompt_set))
def convert_ext_examples(raw_examples, negative_ratio, is_train=True):
texts = []
entity_examples = []
relation_examples = []
entity_prompts = []
relation_prompts = []
entity_label_set = []
entity_name_set = []
predicate_set = []
subject_goldens = []
logger.info(f"Converting doccano data...")
with tqdm(total=len(raw_examples)) as pbar:
for line in raw_examples:
items = json.loads(line)
entity_id = 0
if "data" in items.keys():
relation_mode = False
if isinstance(items["label"],
dict) and "entities" in items["label"].keys():
relation_mode = True
text = items["data"]
entities = []
if not relation_mode:
# Export file in JSONL format from doccano < 1.7.0
for item in items["label"]:
entity = {
"id": entity_id,
"start_offset": item[0],
"end_offset": item[1],
"label": item[2]
}
entities.append(entity)
entity_id += 1
else:
# Export file in JSONL format for the relation labeling task from doccano < 1.7.0
for item in items["label"]["entities"]:
entity = {
"id": entity_id,
"start_offset": item["start_offset"],
"end_offset": item["end_offset"],
"label": item["label"]
}
entities.append(entity)
entity_id += 1
relations = []
else:
# Export file in JSONL format from doccano >= 1.7.0
if "label" in items.keys():
text = items["text"]
entities = []
for item in items["label"]:
entity = {
"id": entity_id,
"start_offset": item[0],
"end_offset": item[1],
"label": item[2]
}
entities.append(entity)
entity_id += 1
relations = []
else:
# Export file in JSONL (relation) format
text, relations, entities = items["text"], items[
"relations"], items["entities"]
texts.append(text)
entity_example = []
entity_prompt = []
entity_example_map = {}
entity_map = {} # id to entity name
for entity in entities:
entity_name = text[entity["start_offset"]:entity["end_offset"]]
entity_map[entity["id"]] = {
"name": entity_name,
"start": entity["start_offset"],
"end": entity["end_offset"]
}
entity_label = entity["label"]
result = {
"text": entity_name,
"start": entity["start_offset"],
"end": entity["end_offset"]
}
if entity_label not in entity_example_map.keys():
entity_example_map[entity_label] = {
"content": text,
"result_list": [result],
"prompt": entity_label
}
else:
entity_example_map[entity_label]["result_list"].append(
result)
if entity_label not in entity_label_set:
entity_label_set.append(entity_label)
if entity_name not in entity_name_set:
entity_name_set.append(entity_name)
entity_prompt.append(entity_label)
for v in entity_example_map.values():
entity_example.append(v)
entity_examples.append(entity_example)
entity_prompts.append(entity_prompt)
subject_golden = []
relation_example = []
relation_prompt = []
relation_example_map = {}
for relation in relations:
predicate = relation["type"]
subject_id = relation["from_id"]
object_id = relation["to_id"]
# The relation prompt is constructed as follows:
# subject + "的" + predicate
prompt = entity_map[subject_id]["name"] + "的" + predicate
if entity_map[subject_id]["name"] not in subject_golden:
subject_golden.append(entity_map[subject_id]["name"])
result = {
"text": entity_map[object_id]["name"],
"start": entity_map[object_id]["start"],
"end": entity_map[object_id]["end"]
}
if prompt not in relation_example_map.keys():
relation_example_map[prompt] = {
"content": text,
"result_list": [result],
"prompt": prompt
}
else:
relation_example_map[prompt]["result_list"].append(result)
if predicate not in predicate_set:
predicate_set.append(predicate)
relation_prompt.append(prompt)
for v in relation_example_map.values():
relation_example.append(v)
relation_examples.append(relation_example)
relation_prompts.append(relation_prompt)
subject_goldens.append(subject_golden)
pbar.update(1)
def concat_examples(positive_examples, negative_examples, negative_ratio):
examples = []
if math.ceil(len(negative_examples) /
len(positive_examples)) <= negative_ratio:
examples = positive_examples + negative_examples
else:
# Randomly sample the negative examples to keep the overall negative ratio unchanged.
idxs = random.sample(
range(0, len(negative_examples)),
negative_ratio * len(positive_examples))
negative_examples_sampled = []
for idx in idxs:
negative_examples_sampled.append(negative_examples[idx])
examples = positive_examples + negative_examples_sampled
return examples
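# Illustrative arithmetic (hypothetical numbers): with 10 positive and 80 negative examples
# and negative_ratio=5, ceil(80/10)=8 > 5, so 5*10=50 negatives are sampled and concatenated
# with the positives; with only 30 negatives, ceil(30/10)=3 <= 5 and all of them are kept.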
logger.info(f"Adding negative samples for first stage prompt...")
positive_examples, negative_examples = add_negative_example(
entity_examples, texts, entity_prompts, entity_label_set,
negative_ratio)
if len(positive_examples) == 0:
all_entity_examples = []
elif is_train:
all_entity_examples = concat_examples(positive_examples,
negative_examples, negative_ratio)
else:
all_entity_examples = positive_examples + negative_examples
all_relation_examples = []
if len(predicate_set) != 0:
if is_train:
logger.info(f"Adding negative samples for second stage prompt...")
relation_prompt_set = construct_relation_prompt_set(entity_name_set,
predicate_set)
positive_examples, negative_examples = add_negative_example(
relation_examples, texts, relation_prompts, relation_prompt_set,
negative_ratio)
all_relation_examples = concat_examples(
positive_examples, negative_examples, negative_ratio)
else:
logger.info(f"Adding negative samples for second stage prompt...")
relation_examples = add_full_negative_example(
relation_examples, texts, relation_prompts, predicate_set,
subject_goldens)
all_relation_examples = [
r
for relation_example in relation_examples
for r in relation_example
]
return all_entity_examples, all_relation_examples
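# A minimal sketch (hypothetical text, label and offsets) of one converted entity example:
# {"content": "李白出生于碎叶城", "result_list": [{"text": "李白", "start": 0, "end": 2}], "prompt": "人名"}
# Relation examples have the same shape but use the subject + "的" + predicate prompt described above.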
def get_path_from_url(url,
root_dir,
check_exist=True,
decompress=True):
""" Download from given url to root_dir.
if file or directory specified by url is exists under
root_dir, return the path directly, otherwise download
from url and decompress it, return the path.
Args:
url (str): download url
root_dir (str): root dir for downloading, it should be
WEIGHTS_HOME or DATASET_HOME
decompress (bool): decompress zip or tar file. Default is `True`
Returns:
str: a local path to save downloaded models & weights & datasets.
"""
import os.path
import os
import tarfile
import zipfile
def is_url(path):
"""
Whether path is URL.
Args:
path (string): URL string or not.
"""
return path.startswith('http://') or path.startswith('https://')
def _map_path(url, root_dir):
# parse path after download under root_dir
fname = os.path.split(url)[-1]
fpath = fname
return os.path.join(root_dir, fpath)
def _get_download(url, fullname):
import requests
# using requests.get method
fname = os.path.basename(fullname)
try:
req = requests.get(url, stream=True)
except Exception as e: # requests.exceptions.ConnectionError
logger.info("Downloading {} from {} failed with exception {}".format(
fname, url, str(e)))
return False
if req.status_code != 200:
raise RuntimeError("Downloading from {} failed with code "
"{}!".format(url, req.status_code))
# To protect against interrupted downloads, write to
# tmp_fullname first and move it to fullname once the
# download has finished
tmp_fullname = fullname + "_tmp"
total_size = req.headers.get('content-length')
with open(tmp_fullname, 'wb') as f:
if total_size:
with tqdm(total=(int(total_size) + 1023) // 1024, unit='KB') as pbar:
for chunk in req.iter_content(chunk_size=1024):
f.write(chunk)
pbar.update(1)
else:
for chunk in req.iter_content(chunk_size=1024):
if chunk:
f.write(chunk)
shutil.move(tmp_fullname, fullname)
return fullname
def _download(url, path):
"""
Download from url, save to path.
url (str): download url
path (str): download to given path
"""
if not os.path.exists(path):
os.makedirs(path)
fname = os.path.split(url)[-1]
fullname = os.path.join(path, fname)
retry_cnt = 0
logger.info("Downloading {} from {}".format(fname, url))
DOWNLOAD_RETRY_LIMIT = 3
while not os.path.exists(fullname):
if retry_cnt < DOWNLOAD_RETRY_LIMIT:
retry_cnt += 1
else:
raise RuntimeError("Download from {} failed. "
"Retry limit reached".format(url))
if not _get_download(url, fullname):
time.sleep(1)
continue
return fullname
def _uncompress_file_zip(filepath):
with zipfile.ZipFile(filepath, 'r') as files:
file_list = files.namelist()
file_dir = os.path.dirname(filepath)
if _is_a_single_file(file_list):
rootpath = file_list[0]
uncompressed_path = os.path.join(file_dir, rootpath)
files.extractall(file_dir)
elif _is_a_single_dir(file_list):
# `strip(os.sep)` to remove `os.sep` in the tail of path
rootpath = os.path.splitext(file_list[0].strip(os.sep))[0].split(
os.sep)[-1]
uncompressed_path = os.path.join(file_dir, rootpath)
files.extractall(file_dir)
else:
rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1]
uncompressed_path = os.path.join(file_dir, rootpath)
if not os.path.exists(uncompressed_path):
os.makedirs(uncompressed_path)
files.extractall(os.path.join(file_dir, rootpath))
return uncompressed_path
def _is_a_single_file(file_list):
if len(file_list) == 1 and file_list[0].find(os.sep) < 0:
return True
return False
def _is_a_single_dir(file_list):
new_file_list = []
for file_path in file_list:
if '/' in file_path:
file_path = file_path.replace('/', os.sep)
elif '\\' in file_path:
file_path = file_path.replace('\\', os.sep)
new_file_list.append(file_path)
file_name = new_file_list[0].split(os.sep)[0]
for i in range(1, len(new_file_list)):
if file_name != new_file_list[i].split(os.sep)[0]:
return False
return True
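# Illustrative member lists (hypothetical archives):
#   ['model.bin']                      -> _is_a_single_file is True
#   ['ckpt/', 'ckpt/config.json']      -> _is_a_single_dir is True
#   ['a/x.txt', 'b/y.txt']             -> both are False, so members are extracted into a new directory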
def _uncompress_file_tar(filepath, mode="r:*"):
with tarfile.open(filepath, mode) as files:
file_list = files.getnames()
file_dir = os.path.dirname(filepath)
if _is_a_single_file(file_list):
rootpath = file_list[0]
uncompressed_path = os.path.join(file_dir, rootpath)
files.extractall(file_dir)
elif _is_a_single_dir(file_list):
rootpath = os.path.splitext(file_list[0].strip(os.sep))[0].split(
os.sep)[-1]
uncompressed_path = os.path.join(file_dir, rootpath)
files.extractall(file_dir)
else:
rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1]
uncompressed_path = os.path.join(file_dir, rootpath)
if not os.path.exists(uncompressed_path):
os.makedirs(uncompressed_path)
files.extractall(os.path.join(file_dir, rootpath))
return uncompressed_path
def _decompress(fname):
"""
Decompress a zip or tar file and return the uncompressed path.
"""
logger.info("Decompressing {}...".format(fname))
# Dispatch on the archive type and extract next to the downloaded file;
# unsupported formats raise a TypeError below.
if tarfile.is_tarfile(fname):
uncompressed_path = _uncompress_file_tar(fname)
elif zipfile.is_zipfile(fname):
uncompressed_path = _uncompress_file_zip(fname)
else:
raise TypeError("Unsupport compress file type {}".format(fname))
return uncompressed_path
assert is_url(url), "downloading from {} not a url".format(url)
fullpath = _map_path(url, root_dir)
if os.path.exists(fullpath) and check_exist:
logger.info("Found {}".format(fullpath))
else:
fullpath = _download(url, root_dir)
if decompress and (tarfile.is_tarfile(fullpath) or
zipfile.is_zipfile(fullpath)):
fullpath = _decompress(fullpath)
return fullpath
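# Usage sketch for get_path_from_url; the URL and root_dir below are hypothetical placeholders:
# ckpt_dir = get_path_from_url(
#     'https://example.com/checkpoints/model.tar.gz',   # any downloadable zip/tar url
#     root_dir='./pretrain_ckpt')                       # downloaded here and decompressed in place
# print(ckpt_dir)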
#! -*- coding: utf-8 -*-
# Build a web service for SimBERT synonym generation with the library's built-in WebServing helper.
# A simple bottlepy-based wrapper, intended only for temporary testing; performance is not guaranteed.
# For detailed usage see https://github.com/bojone/bert4keras/blob/8ffb46a16a79f87aa8cdf045df7994036b4be47d/bert4keras/snippets.py#L580
import torch
from bert4torch.models import build_transformer_model, BaseModel
from bert4torch.snippets import sequence_padding, AutoRegressiveDecoder, get_pool_emb
from bert4torch.tokenizers import Tokenizer, load_vocab
from bert4torch.snippets import WebServing
# Basic settings
maxlen = 32
choice = 'simbert' # simbert simbert_v2
if choice == 'simbert':
args_model_path = "F:/Projects/pretrain_ckpt/simbert/[sushen_torch_base]--simbert_chinese_base"
args_model = 'bert'
else:
args_model_path = "F:/Projects/pretrain_ckpt/simbert/[sushen_torch_base]--roformer_chinese_sim_char_base"
args_model = 'roformer'
# Load simbert or simbert_v2 weights
root_model_path = args_model_path
dict_path = root_model_path + "/vocab.txt"
config_path = root_model_path + "/config.json"
checkpoint_path = root_model_path + '/pytorch_model.bin'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Load and simplify the vocab, and build the tokenizer
token_dict, keep_tokens = load_vocab(
dict_path=dict_path,
simplified=True,
startswith=['[PAD]', '[UNK]', '[CLS]', '[SEP]'],
)
tokenizer = Tokenizer(token_dict, do_lower_case=True)
# Build and load the model
class Model(BaseModel):
def __init__(self, pool_method='cls'):
super().__init__()
self.bert = build_transformer_model(config_path=config_path, checkpoint_path=checkpoint_path, with_pool='linear', model=args_model,
application='unilm', keep_tokens=keep_tokens)
self.pool_method = pool_method
def forward(self, token_ids, segment_ids):
hidden_state, pool_cls, seq_logit = self.bert([token_ids, segment_ids])
sen_emb = get_pool_emb(hidden_state, pool_cls, token_ids.gt(0).long(), self.pool_method)
return seq_logit, sen_emb
model = Model(pool_method='cls').to(device)
class SynonymsGenerator(AutoRegressiveDecoder):
"""seq2seq解码器
"""
@AutoRegressiveDecoder.wraps('logits')
def predict(self, inputs, output_ids, states):
token_ids, segment_ids = inputs
token_ids = torch.cat([token_ids, output_ids], 1)
segment_ids = torch.cat([segment_ids, torch.ones_like(output_ids, device=device)], 1)
seq_logit, _ = model.predict([token_ids, segment_ids])
return seq_logit[:, -1, :]
def generate(self, text, n=1, topk=5):
token_ids, segment_ids = tokenizer.encode(text, maxlen=maxlen)
output_ids = self.random_sample([token_ids, segment_ids], n, topk)  # random sampling
return [tokenizer.decode(ids.cpu().numpy()) for ids in output_ids]
synonyms_generator = SynonymsGenerator(start_id=None, end_id=tokenizer._token_end_id, maxlen=maxlen, device=device)
def cal_sen_emb(text_list):
'''Given a list of texts, compute their sentence embeddings.
'''
X, S = [], []
for t in text_list:
x, s = tokenizer.encode(t)
X.append(x)
S.append(s)
X = torch.tensor(sequence_padding(X), dtype=torch.long, device=device)
S = torch.tensor(sequence_padding(S), dtype=torch.long, device=device)
_, Z = model.predict([X, S])
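# Z has shape [len(text_list), hidden_size]: one CLS-pooled sentence embedding per input text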
return Z
def gen_synonyms(text, n=100, k=20):
""""含义: 产生sent的n个相似句,然后返回最相似的k个。
做法:用seq2seq生成,并用encoder算相似度并排序。
"""
r = synonyms_generator.generate(text, n)
r = [i for i in set(r) if i != text]  # drop candidates identical to the original text
r = [text] + r
Z = cal_sen_emb(r)
Z /= (Z**2).sum(dim=1, keepdim=True)**0.5  # L2-normalize the embeddings
argsort = torch.matmul(Z[1:], -Z[0]).argsort()
return [r[i + 1] for i in argsort[:k]]
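# Local usage sketch (illustrative input; uncomment to test generation without starting the web service):
# print(gen_synonyms('苹果多少钱一斤', n=50, k=5))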
if __name__ == '__main__':
arguments = {'text': (None, True), 'n': (int, False), 'k': (int, False)}
web = WebServing(port=8864)
web.route('/gen_synonyms', gen_synonyms, arguments)
web.start()
# You can now test the service at http://127.0.0.1:8864/gen_synonyms?text=苹果多少钱一斤
# Client-side call example
import requests
import json
def send_msg(requestData):
# The service above listens on port 8864 at /gen_synonyms and takes query-string
# arguments (see the test URL above), so send a GET request with params.
url = 'http://localhost:8864/gen_synonyms'
ret = requests.get(url, params=requestData)
if ret.status_code == 200:
return json.loads(ret.text)
send_msg({'text': '苹果多少钱一斤', 'n': 50, 'k': 10})