import numpy as np
import math
import torch
from bert4torch.snippets import sequence_padding
from utils import get_bool_ids_greater_than, get_span, get_id_and_prob, cut_chinese_sent, dbc2sbc
from pprint import pprint
import torch.nn.functional as F


class UIEPredictor(object):
    """Prompt-based information extraction driven by a schema tree.

    The schema is turned into a tree of `SchemaTree` nodes. Every node is
    used as a prompt that is paired with each input text; child nodes are
    prompted with the text extracted for their parent followed by "的".

    Args:
        schema: A string, a dict, or a list of strings/dicts describing
            what to extract (see `_build_tree`).
        device: Device the model and the input tensors are moved to.
        position_prob: Threshold passed as `limit` when selecting start/end
            positions from the predicted probabilities.
        max_seq_len: Maximum sequence length passed to the tokenizer; also
            bounds the length of each text chunk.
        batch_size: Number of (prompt, text) pairs per forward pass.
        split_sentence: If True, texts are first split with
            `cut_chinese_sent` before length-based chunking.
    """

    def __init__(self, schema, device='cpu', position_prob=0.5, max_seq_len=512, batch_size=64, split_sentence=False):
        self._device = device
        self._position_prob = position_prob
        self._max_seq_len = max_seq_len
        # Honour the caller's settings (they used to be silently replaced by
        # the hard-coded defaults 64 / False).
        self._batch_size = batch_size
        self._split_sentence = split_sentence
        self._schema_tree = None
        self.set_schema(schema)
        # Imported lazily so the model is only loaded when a predictor is
        # actually constructed.
        from model import uie_model, tokenizer
        self._tokenizer = tokenizer
        self.model = uie_model.to(self._device)

    def set_schema(self, schema):
        """Replace the current schema; a single str/dict is wrapped in a list."""
        if isinstance(schema, (dict, str)):
            schema = [schema]
        self._schema_tree = self._build_tree(schema)

    def __call__(self, inputs):
        """Run extraction on one text (str) or a list of texts.

        Returns:
            A list with one result dict per input text.
        """
        texts = [inputs] if isinstance(inputs, str) else inputs
        return self._multi_stage_predict(texts)

    def _multi_stage_predict(self, datas):
        """Walk the schema tree breadth-first and predict every node.

        Args:
            datas (List[str]): input texts.

        Returns:
            List[dict]: one dict per text, mapping top-level schema names to
            their extracted items. Items extracted for child nodes are
            attached to their parent item under the "relations" key.
        """
        results = [{} for _ in range(len(datas))]
        # Nothing to do without inputs or without a schema.
        if len(datas) < 1 or self._schema_tree is None:
            return results

        # Work on a copy so `self._schema_tree` itself stays unchanged.
        schema_list = self._schema_tree.children[:]
        while len(schema_list) > 0:
            node = schema_list.pop(0)
            examples = []
            # input_map[text index] -> indices into `examples`/`result_list`.
            input_map = {}
            if not node.prefix:
                # Top-level node: one example per text, prompt is the name.
                for cnt, data in enumerate(datas):
                    examples.append({"text": data, "prompt": dbc2sbc(node.name)})
                    input_map[cnt] = [cnt]
            else:
                # Child node: one example per parent item found in the text.
                next_idx = 0
                for cnt, (pre, data) in enumerate(zip(node.prefix, datas)):
                    input_map[cnt] = [next_idx + i for i in range(len(pre))]
                    for p in pre:
                        examples.append({"text": data, "prompt": dbc2sbc(p + node.name)})
                    next_idx += len(pre)

            if len(examples) == 0:
                result_list = []
            else:
                result_list = self._single_stage_predict(examples)

            if not node.parent_relations:
                # Store the extracted items directly in the per-text results.
                relations = [[] for _ in range(len(datas))]
                for k, v in input_map.items():
                    for ex_idx in v:
                        if len(result_list[ex_idx]) == 0:
                            continue
                        if node.name not in results[k]:
                            results[k][node.name] = result_list[ex_idx]
                        else:
                            results[k][node.name].extend(result_list[ex_idx])
                    if node.name in results[k]:
                        relations[k].extend(results[k][node.name])
            else:
                # Attach the items to the parent item that produced the
                # prompt. The parent dicts are the same objects that are
                # already stored in `results`, so this updates `results` too.
                relations = node.parent_relations
                for k, v in input_map.items():
                    for i, ex_idx in enumerate(v):
                        if len(result_list[ex_idx]) == 0:
                            continue
                        parent = relations[k][i]
                        if "relations" not in parent:
                            parent["relations"] = {node.name: result_list[ex_idx]}
                        elif node.name not in parent["relations"]:
                            parent["relations"][node.name] = result_list[ex_idx]
                        else:
                            parent["relations"][node.name].extend(result_list[ex_idx])

                # The items just attached become the parents of the next level.
                new_relations = [[] for _ in range(len(datas))]
                for i, parents in enumerate(relations):
                    for parent in parents:
                        if "relations" in parent and node.name in parent["relations"]:
                            new_relations[i].extend(parent["relations"][node.name])
                relations = new_relations

            # Prompt prefixes for the children: "<extracted text>的".
            prefix = [[] for _ in range(len(datas))]
            for k, v in input_map.items():
                for ex_idx in v:
                    for item in result_list[ex_idx]:
                        prefix[k].append(item["text"] + "的")

            for child in node.children:
                child.prefix = prefix
                child.parent_relations = relations
                schema_list.append(child)
        return results

    def _convert_ids_to_results(self, examples, sentence_ids, probs):
        """Convert (start, end) id pairs to raw text in a single stage.

        Args:
            examples (List[dict]): dicts with "text" and "prompt" keys.
            sentence_ids: per example, a sequence of (start, end) pairs.
            probs: per example, the probability of each pair.

        Returns:
            List[List[dict]]: per example, the extracted items. A pair with a
            negative `end` is read from the prompt (no "start"/"end" keys in
            the item); otherwise it is read from the text.
        """
        results = []
        for example, sentence_id, prob in zip(examples, sentence_ids, probs):
            if len(sentence_id) == 0:
                results.append([])
                continue
            result_list = []
            text = example["text"]
            prompt = example["prompt"]
            for i, (start, end) in enumerate(sentence_id):
                # A span starting before the text but ending inside it is
                # discarded.
                if start < 0 and end >= 0:
                    continue
                if end < 0:
                    # Shift the negative positions into prompt coordinates.
                    shift = len(prompt) + 1
                    start += shift
                    end += shift
                    result_list.append({"text": prompt[start:end], "probability": prob[i]})
                else:
                    result_list.append({
                        "text": text[start:end],
                        "start": start,
                        "end": end,
                        "probability": prob[i],
                    })
            results.append(result_list)
        return results

    def _auto_splitter(self, input_texts, max_text_len, split_sentence=False):
        """Split the raw texts automatically for model inference.

        Args:
            input_texts (List[str]): input raw texts.
            max_text_len (int): cutting length; must be positive.
            split_sentence (bool): If True, sentence-level split will be
                performed with `cut_chinese_sent` first.

        Returns:
            short_input_texts (List[str]): the short input texts for model
                inference.
            input_mapping (dict): index of each raw text -> indices of its
                short texts. Every raw text gets an entry, possibly empty.

        Raises:
            ValueError: if `max_text_len` is not positive.
        """
        if max_text_len <= 0:
            raise ValueError(
                "max_text_len must be positive, but {} received".format(max_text_len))

        input_mapping = {}
        short_input_texts = []
        for cnt_org, text in enumerate(input_texts):
            # Always create the entry so that results stay aligned with the
            # inputs even when the sentence split yields nothing.
            short_ids = input_mapping.setdefault(cnt_org, [])
            sens = cut_chinese_sent(text) if split_sentence else [text]
            for sen in sens:
                lens = len(sen)
                if lens <= max_text_len:
                    pieces = [sen]
                else:
                    pieces = [sen[i:i + max_text_len] for i in range(0, lens, max_text_len)]
                first = len(short_input_texts)
                short_input_texts.extend(pieces)
                short_ids.extend(range(first, first + len(pieces)))
        return short_input_texts, input_mapping

    @staticmethod
    def _concat_probs(batches):
        """Concatenate per-batch probability arrays along axis 0.

        Each batch is padded on its own, so the arrays may differ in width
        (axis 1). Narrower arrays are right-padded with zeros first.
        NOTE(review): assumes a probability of 0 is never selected by
        `get_bool_ids_greater_than` -- confirm against `utils`.
        """
        if all(b.ndim >= 2 for b in batches):
            width = max(b.shape[1] for b in batches)
            padded = []
            for b in batches:
                if b.shape[1] < width:
                    pad_width = [(0, 0)] * b.ndim
                    pad_width[1] = (0, width - b.shape[1])
                    b = np.pad(b, pad_width, mode="constant")
                padded.append(b)
            batches = padded
        return np.concatenate(batches)

    def _single_stage_predict(self, inputs):
        """Predict one schema level for a list of (text, prompt) examples.

        Args:
            inputs (List[dict]): dicts with "text" and "prompt" keys.

        Returns:
            List[List[dict]]: per input example, the extracted items.
        """
        if not inputs:
            return []
        input_texts = [example["text"] for example in inputs]
        prompts = [example["prompt"] for example in inputs]

        # max predict length should exclude the length of prompt and summary
        # tokens. Use the longest prompt (`max(prompts)` alone would pick the
        # lexicographically greatest one).
        max_predict_len = self._max_seq_len - max(len(p) for p in prompts) - 3
        short_input_texts, self.input_mapping = self._auto_splitter(
            input_texts, max_predict_len, split_sentence=self._split_sentence)
        if not short_input_texts:
            return [[] for _ in inputs]

        # Repeat each prompt once per short text derived from its raw text.
        short_texts_prompts = []
        for k, v in self.input_mapping.items():
            short_texts_prompts.extend([prompts[k]] * len(v))
        short_inputs = [{"text": text, "prompt": prompt}
                        for text, prompt in zip(short_input_texts, short_texts_prompts)]

        token_ids, segment_ids, offset_maps = self._tokenizer.encode(
            short_texts_prompts, short_input_texts, maxlen=self._max_seq_len, return_offsets='transformers')

        start_prob_batches, end_prob_batches = [], []
        for batch_start in range(0, len(short_input_texts), self._batch_size):
            batch_end = batch_start + self._batch_size
            batch_token_ids = torch.tensor(
                sequence_padding(token_ids[batch_start:batch_end]), dtype=torch.long, device=self._device)
            batch_segment_ids = torch.tensor(
                sequence_padding(segment_ids[batch_start:batch_end]), dtype=torch.long, device=self._device)
            start_prob, end_prob = self.model.predict(batch_token_ids, batch_segment_ids)
            start_prob_batches.append(start_prob.cpu().numpy())
            end_prob_batches.append(end_prob.cpu().numpy())
        start_prob_concat = self._concat_probs(start_prob_batches)
        end_prob_concat = self._concat_probs(end_prob_batches)

        start_ids_list = get_bool_ids_greater_than(start_prob_concat, limit=self._position_prob, return_prob=True)
        end_ids_list = get_bool_ids_greater_than(end_prob_concat, limit=self._position_prob, return_prob=True)

        sentence_ids = []
        probs = []
        for start_ids, end_ids, offset_map in zip(start_ids_list, end_ids_list, offset_maps):
            span_list = get_span(start_ids, end_ids, with_prob=True)
            sentence_id, prob = get_id_and_prob(span_list, offset_map)
            sentence_ids.append(sentence_id)
            probs.append(prob)

        results = self._convert_ids_to_results(short_inputs, sentence_ids, probs)
        return self._auto_joiner(results, short_input_texts, self.input_mapping)

    def _auto_joiner(self, short_results, short_inputs, input_mapping):
        """Merge the results of the short texts back per raw text.

        Args:
            short_results (List[List[dict]]): items per short text.
            short_inputs (List[str]): the short texts themselves.
            input_mapping (dict): raw text index -> short text indices.

        Returns:
            List[List[dict]]: one list of items per entry of `input_mapping`.
        """
        # The first non-empty result decides the task type: items without
        # both "start" and "end" are treated as classification results.
        is_cls_task = False
        for short_result in short_results:
            if short_result == []:
                continue
            first = short_result[0]
            is_cls_task = 'start' not in first and 'end' not in first
            break

        concat_results = []
        for k, vs in input_mapping.items():
            if is_cls_task:
                # label -> [occurrences, summed probability]; only the first
                # item of each short text is considered.
                cls_options = {}
                for v in vs:
                    if len(short_results[v]) == 0:
                        continue
                    top = short_results[v][0]
                    if top['text'] not in cls_options:
                        cls_options[top['text']] = [1, top['probability']]
                    else:
                        cls_options[top['text']][0] += 1
                        cls_options[top['text']][1] += top['probability']
                if len(cls_options) != 0:
                    cls_res, cls_info = max(cls_options.items(), key=lambda x: x[1])
                    concat_results.append([{
                        'text': cls_res,
                        'probability': cls_info[1] / cls_info[0]
                    }])
                else:
                    concat_results.append([])
            else:
                # Shift "start"/"end" by the total length of the short texts
                # that precede the current one within the same raw text.
                offset = 0
                single_results = []
                for v in vs:
                    for item in short_results[v]:
                        if 'start' not in item or 'end' not in item:
                            continue
                        item['start'] += offset
                        item['end'] += offset
                    offset += len(short_inputs[v])
                    single_results.extend(short_results[v])
                concat_results.append(single_results)
        return concat_results

    def predict(self, input_data):
        """Same as calling the predictor; a single str is treated as one text."""
        return self(input_data)

    @classmethod
    def _build_tree(cls, schema, name='root'):
        """Build the schema tree.

        Args:
            schema: iterable of strings (leaf nodes) and/or dicts mapping a
                node name to a string or a list describing its children.
            name: name of the node that is created for this level.

        Raises:
            TypeError: if an element or a dict value has an unsupported type.
        """
        schema_tree = SchemaTree(name)
        for s in schema:
            if isinstance(s, str):
                schema_tree.add_child(SchemaTree(s))
            elif isinstance(s, dict):
                for k, v in s.items():
                    if isinstance(v, str):
                        child = [v]
                    elif isinstance(v, list):
                        child = v
                    else:
                        raise TypeError("Invalid schema, value for each key:value pairs should be list or string but {} received".format(type(v)))
                    schema_tree.add_child(cls._build_tree(child, name=k))
            else:
                raise TypeError("Invalid schema, element should be string or dict, but {} received".format(type(s)))
        return schema_tree


class SchemaTree(object):
    """Implementation of the schema tree node."""

    def __init__(self, name='root', children=None):
        self.name = name
        self.children = []
        # Both are filled in by the predictor while it walks the tree.
        self.prefix = None
        self.parent_relations = None
        if children is not None:
            for child in children:
                self.add_child(child)

    def __repr__(self):
        return self.name

    def add_child(self, node):
        """Append `node` (a SchemaTree) to this node's children."""
        assert isinstance(node, SchemaTree), "The children of a node should be an instacne of SchemaTree."
        self.children.append(node)


if __name__ == '__main__':
    # Named entity recognition
    schema = ['时间', '选手', '赛事名称']  # Define the schema for entity extraction
    ie = UIEPredictor(schema=schema)
    pprint(ie("2月8日上午北京冬奥会自由式滑雪女子大跳台决赛中中国选手谷爱凌以188.25分获得金牌!"))

    schema = ['肿瘤的大小', '肿瘤的个数', '肝癌级别', '脉管内癌栓分级']
    ie.set_schema(schema)
    pprint(ie("(右肝肿瘤)肝细胞性肝癌(II-III级,梁索型和假腺管型),肿瘤包膜不完整,紧邻肝被膜,侵及周围肝组织,未见脉管内癌栓(MVI分级:M0级)及卫星子灶形成。(肿物1个,大小4.2×4.0×2.8cm)。"))

    # Relation extraction
    schema = {'竞赛名称': ['主办方', '承办方', '已举办次数']}
    ie.set_schema(schema)  # Reset schema
    pprint(ie('2022语言与智能技术竞赛由中国中文信息学会和中国计算机学会联合主办,百度公司、中国中文信息学会评测工作委员会和中国计算机学会自然语言处理专委会承办,已连续举办4届,成为全球最热门的中文NLP赛事之一。'))

    # Event extraction
    schema = {'地震触发词': ['地震强度', '时间', '震中位置', '震源深度']}
    ie.set_schema(schema)  # Reset schema
    ie('中国地震台网正式测定:5月16日06时08分在云南临沧市凤庆县(北纬24.34度,东经99.98度)发生3.5级地震,震源深度10千米。')

    # Review opinion extraction
    schema = {'评价维度': ['观点词', '情感倾向[正向,负向]']}
    ie.set_schema(schema)  # Reset schema
    pprint(ie("店面干净,很清静,服务员服务热情,性价比很高,发现收银台有排队"))

    # Sentiment classification
    schema = '情感倾向[正向,负向]'
    ie.set_schema(schema)
    ie('这个产品用起来真的很流畅,我非常喜欢')