# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json
import random
import re

import numpy as np
import paddle


def set_seed(seed):
    paddle.seed(seed)
    random.seed(seed)
    np.random.seed(seed)


def load_txt(file_path):
    texts = []
    with open(file_path, "r", encoding="utf-8") as f:
        for line in f.readlines():
            texts.append(line.strip())
    return texts


def load_json_file(path):
    examples = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f.readlines():
            example = json.loads(line)
            examples.append(example)
    return examples


def write_json_file(examples, save_path):
    with open(save_path, "w", encoding="utf-8") as f:
        for example in examples:
            line = json.dumps(example, ensure_ascii=False)
            f.write(line + "\n")


def str2bool(v):
    """Support bool type for argparse."""
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    elif v.lower() in ("no", "false", "f", "n", "0"):
        return False
    else:
        raise argparse.ArgumentTypeError("Unsupported value encountered.")


def create_data_loader(dataset, mode="train", batch_size=1, trans_fn=None):
    """
    Create dataloader.
    Args:
        dataset(obj:`paddle.io.Dataset`): Dataset instance.
        mode(obj:`str`, optional, defaults to obj:`train`): If mode is 'train', it will shuffle the dataset randomly.
        batch_size(obj:`int`, optional, defaults to 1): The sample number of a mini-batch.
        trans_fn(obj:`callable`, optional, defaults to `None`): function to convert a data sample to input ids, etc.
    Returns:
        dataloader(obj:`paddle.io.DataLoader`): The dataloader which generates batches.
    """
    if trans_fn:
        dataset = dataset.map(trans_fn)

    shuffle = True if mode == "train" else False
    if mode == "train":
        sampler = paddle.io.DistributedBatchSampler(dataset=dataset, batch_size=batch_size, shuffle=shuffle)
    else:
        sampler = paddle.io.BatchSampler(dataset=dataset, batch_size=batch_size, shuffle=shuffle)
    dataloader = paddle.io.DataLoader(dataset, batch_sampler=sampler, return_list=True)
    return dataloader


def convert_example(example, tokenizer, max_seq_len):
    """
    example: {
        title
        prompt
        content
        result_list
    }
    """
    encoded_inputs = tokenizer(
        text=[example["prompt"]],
        text_pair=[example["content"]],
        truncation=True,
        max_seq_len=max_seq_len,
        pad_to_max_seq_len=True,
        return_attention_mask=True,
        return_position_ids=True,
        return_dict=False,
        return_offsets_mapping=True,
    )
    encoded_inputs = encoded_inputs[0]
    offset_mapping = [list(x) for x in encoded_inputs["offset_mapping"]]
    bias = 0
    for index in range(1, len(offset_mapping)):
        mapping = offset_mapping[index]
        if mapping[0] == 0 and mapping[1] == 0 and bias == 0:
            bias = offset_mapping[index - 1][1] + 1  # Includes [SEP] token
        if mapping[0] == 0 and mapping[1] == 0:
            continue
        offset_mapping[index][0] += bias
        offset_mapping[index][1] += bias
    start_ids = [0 for x in range(max_seq_len)]
    end_ids = [0 for x in range(max_seq_len)]
    for item in example["result_list"]:
        # Result spans are given at char granularity; offset_mapping records char-level offsets for each token.
        start = map_offset(item["start"] + bias, offset_mapping)
        end = map_offset(item["end"] - 1 + bias, offset_mapping)
        start_ids[start] = 1.0
        end_ids[end] = 1.0

    tokenized_output = [
        encoded_inputs["input_ids"],
        encoded_inputs["token_type_ids"],
        encoded_inputs["position_ids"],
        encoded_inputs["attention_mask"],
        start_ids,
        end_ids,
    ]
    tokenized_output = [np.array(x, dtype="int64") for x in tokenized_output]
    return tuple(tokenized_output)


def map_offset(ori_offset, offset_mapping):
    """
    Map a char-level offset to its token index.
    """
    for index, span in enumerate(offset_mapping):
        if span[0] <= ori_offset < span[1]:
            return index
    return -1


def reader(data_path, max_seq_len=512):
    """
    Read a jsonl file and split examples whose content exceeds max_seq_len.
    """
    with open(data_path, "r", encoding="utf-8") as f:
        for line in f:
            json_line = json.loads(line)
            content = json_line["content"].strip()
            prompt = json_line["prompt"]
            # The model input looks like: [CLS] Prompt [SEP] Content [SEP],
            # which adds three special tokens.
            if max_seq_len <= len(prompt) + 3:
                raise ValueError("The value of max_seq_len is too small, please set a larger value")
            max_content_len = max_seq_len - len(prompt) - 3
            if len(content) <= max_content_len:
                yield json_line
            else:
                result_list = json_line["result_list"]
                json_lines = []
                accumulate = 0
                while True:
                    cur_result_list = []
                    for result in result_list:
                        if result["start"] + 1 <= max_content_len < result["end"]:
                            max_content_len = result["start"]
                            break
                    cur_content = content[:max_content_len]
                    res_content = content[max_content_len:]

                    while True:
                        if len(result_list) == 0:
                            break
                        elif result_list[0]["end"] <= max_content_len:
                            if result_list[0]["end"] > 0:
                                cur_result = result_list.pop(0)
                                cur_result_list.append(cur_result)
                            else:
                                cur_result_list = [result for result in result_list]
                                break
                        else:
                            break

                    json_line = {"content": cur_content, "result_list": cur_result_list, "prompt": prompt}
                    json_lines.append(json_line)

                    for result in result_list:
                        if result["end"] <= 0:
                            break
                        result["start"] -= max_content_len
                        result["end"] -= max_content_len
                    accumulate += max_content_len
                    max_content_len = max_seq_len - len(prompt) - 3
                    if len(res_content) == 0:
                        break
                    elif len(res_content) < max_content_len:
                        json_line = {"content": res_content, "result_list": result_list, "prompt": prompt}
                        json_lines.append(json_line)
                        break
                    else:
                        content = res_content
                for json_line in json_lines:
                    yield json_line


def unify_prompt_name(prompt):
    # The classification labels are shuffled during finetuning, so they need
    # to be unified during evaluation.
    if re.search(r"\[.*?\]$", prompt):
        prompt_prefix = prompt[: prompt.find("[", 1)]
        cls_options = re.search(r"\[.*?\]$", prompt).group()[1:-1].split(",")
        cls_options = sorted(list(set(cls_options)))
        cls_options = ",".join(cls_options)
        prompt = prompt_prefix + "[" + cls_options + "]"
        return prompt
    return prompt


def get_relation_type_dict(relation_data):
    """
    Group relation prompts that end with the same "的X" suffix under a single
    relation type, collecting their associated data together.
    """

    def compare(a, b):
        # Compare reversed strings to find the longest common suffix.
        a = a[::-1]
        b = b[::-1]
        res = ""
        for i in range(min(len(a), len(b))):
            if a[i] == b[i]:
                res += a[i]
            else:
                break
        if res == "":
            return res
        elif res[::-1][0] == "的":
            return res[::-1][1:]
        return ""

    relation_type_dict = {}
    added_list = []
    for i in range(len(relation_data)):
        added = False
        if relation_data[i][0] not in added_list:
            for j in range(i + 1, len(relation_data)):
                match = compare(relation_data[i][0], relation_data[j][0])
                if match != "":
                    match = unify_prompt_name(match)
                    if relation_data[i][0] not in added_list:
                        added_list.append(relation_data[i][0])
                        relation_type_dict.setdefault(match, []).append(relation_data[i][1])
                    added_list.append(relation_data[j][0])
                    relation_type_dict.setdefault(match, []).append(relation_data[j][1])
                    added = True
            if not added:
                added_list.append(relation_data[i][0])
                suffix = relation_data[i][0].rsplit("的", 1)[1]
                suffix = unify_prompt_name(suffix)
                relation_type_dict[suffix] = relation_data[i][1]
    return relation_type_dict
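

# ---------------------------------------------------------------------------
# Minimal usage sketch for the helpers above. This is an illustration, not the
# project's training script: it assumes paddlenlp is installed, and the
# checkpoint name "uie-base" and data path "./data/train.txt" are placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from functools import partial

    from paddlenlp.datasets import load_dataset
    from paddlenlp.transformers import AutoTokenizer

    set_seed(1000)

    # `reader` yields one example per jsonl line and splits contents that are
    # longer than max_seq_len into several shorter examples.
    train_ds = load_dataset(reader, data_path="./data/train.txt", max_seq_len=512, lazy=False)

    tokenizer = AutoTokenizer.from_pretrained("uie-base")
    trans_fn = partial(convert_example, tokenizer=tokenizer, max_seq_len=512)

    # `create_data_loader` maps each example to model inputs (input ids, token
    # type ids, position ids, attention mask, start/end labels) and batches them.
    train_data_loader = create_data_loader(train_ds, mode="train", batch_size=16, trans_fn=trans_fn)

    for batch in train_data_loader:
        input_ids, token_type_ids, position_ids, attention_mask, start_ids, end_ids = batch
        print(input_ids.shape, start_ids.shape)
        break

    # `unify_prompt_name` canonicalizes classification prompts whose label list
    # may appear in a different order; the prompts below are illustrative only.
    assert unify_prompt_name("情感倾向[负向,正向]") == unify_prompt_name("情感倾向[正向,负向]")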