#! -*- coding: utf-8 -*-
# BERT for Seq2Seq tasks, using the UNILM scheme
# Introduction: https://kexue.fm/archives/6933

from bert4torch.models import build_transformer_model
from bert4torch.tokenizers import Tokenizer, load_vocab
from bert4torch.snippets import sequence_padding
from bert4torch.snippets import AutoRegressiveDecoder, Callback, ListDataset
import torch
from torchinfo import summary
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import glob

# Basic hyperparameters
maxlen = 256
batch_size = 16
epochs = 10000

# BERT configuration
config_path = 'F:/Projects/pretrain_ckpt/bert/[google_tf_base]--chinese_L-12_H-768_A-12/bert_config.json'
checkpoint_path = 'F:/Projects/pretrain_ckpt/bert/[google_tf_base]--chinese_L-12_H-768_A-12/pytorch_model.bin'
dict_path = 'F:/Projects/pretrain_ckpt/bert/[google_tf_base]--chinese_L-12_H-768_A-12/vocab.txt'
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load and simplify the vocabulary, then build the tokenizer
token_dict, keep_tokens = load_vocab(
    dict_path=dict_path,
    simplified=True,
    startswith=['[PAD]', '[UNK]', '[CLS]', '[SEP]'],
)
tokenizer = Tokenizer(token_dict, do_lower_case=True)


def collate_fn(batch):
    """Single sample format: [CLS]content[SEP]title[SEP]
    (the first line of each THUCNews file is the title, the rest is the body)
    """
    batch_token_ids, batch_segment_ids = [], []
    for txt in batch:
        with open(txt, encoding='utf-8') as f:
            text = f.read()
        text = text.split('\n')
        if len(text) > 1:
            title = text[0]
            content = '\n'.join(text[1:])
            token_ids, segment_ids = tokenizer.encode(content, title, maxlen=maxlen)
            batch_token_ids.append(token_ids)
            batch_segment_ids.append(segment_ids)
    batch_token_ids = torch.tensor(sequence_padding(batch_token_ids), dtype=torch.long, device=device)
    batch_segment_ids = torch.tensor(sequence_padding(batch_segment_ids), dtype=torch.long, device=device)
    # inputs and targets are identical: the loss below picks out only the title part
    return [batch_token_ids, batch_segment_ids], [batch_token_ids, batch_segment_ids]


train_dataloader = DataLoader(
    ListDataset(glob.glob('F:/Projects/data/corpus/sentence_classification/THUCNews/*/*.txt')),
    batch_size=batch_size, shuffle=True, collate_fn=collate_fn
)

model = build_transformer_model(
    config_path,
    checkpoint_path,
    with_mlm=True,
    application='unilm',
    keep_tokens=keep_tokens,  # keep only the tokens in keep_tokens, shrinking the original vocabulary
).to(device)
summary(model, input_data=[next(iter(train_dataloader))[0]])


class CrossEntropyLoss(nn.CrossEntropyLoss):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, outputs, target):
        """
        y_pred: [btz, seq_len, vocab_size]
        target: (y_true, y_segment)
        UNILM style: the non-seq2seq part has to be masked out manually.
        """
        _, y_pred = outputs
        y_true, y_mask = target
        y_true = y_true[:, 1:]  # target token_ids
        y_mask = y_mask[:, 1:]  # segment_ids, which conveniently mark the part to predict
        y_pred = y_pred[:, :-1, :]  # predictions, shifted left by one position
        y_pred = y_pred.reshape(-1, y_pred.shape[-1])
        # positions outside the title (y_mask == 0) are zeroed to [PAD],
        # which ignore_index=0 below then skips
        y_true = (y_true * y_mask).flatten()
        return super().forward(y_pred, y_true)


model.compile(loss=CrossEntropyLoss(ignore_index=0), optimizer=optim.Adam(model.parameters(), 1e-5))
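
# ---------------------------------------------------------------------------
# Illustration (not part of the original script): a minimal sketch of the
# attention mask that `application='unilm'` builds internally, following the
# UNILM scheme described at https://kexue.fm/archives/6933. Source tokens
# (segment 0) attend bidirectionally; target tokens (segment 1) see the full
# source plus the target prefix up to themselves. The helper name
# `unilm_mask_sketch` is hypothetical, shown only to clarify the mechanism.
def unilm_mask_sketch(segment_ids):
    """segment_ids: [btz, seq_len] -> attention mask: [btz, seq_len, seq_len]"""
    idxs = torch.cumsum(segment_ids, dim=1)  # running count of target tokens
    # position j is visible from position i iff j lies in the source
    # (idxs[j] == 0) or j is no later than i within the target
    return (idxs.unsqueeze(1) <= idxs.unsqueeze(2)).float()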

class AutoTitle(AutoRegressiveDecoder):
    """seq2seq decoder
    """
    @AutoRegressiveDecoder.wraps(default_rtype='logits')
    def predict(self, inputs, output_ids, states):
        token_ids, segment_ids = inputs
        token_ids = torch.cat([token_ids, output_ids], 1)
        segment_ids = torch.cat([segment_ids, torch.ones_like(output_ids, device=device)], 1)
        _, y_pred = model.predict([token_ids, segment_ids])
        return y_pred[:, -1, :]  # logits of the last position

    def generate(self, text, topk=1):
        max_c_len = maxlen - self.maxlen  # leave room for the generated title
        token_ids, segment_ids = tokenizer.encode(text, maxlen=max_c_len)
        output_ids = self.beam_search([token_ids, segment_ids], topk=topk)  # beam search decoding
        return tokenizer.decode(output_ids.cpu().numpy())


autotitle = AutoTitle(start_id=None, end_id=tokenizer._token_end_id, maxlen=32, device=device)


def just_show():
    # two Chinese news snippets used as demo inputs
    s1 = u'夏天来临,皮肤在强烈紫外线的照射下,晒伤不可避免,因此,晒后及时修复显得尤为重要,否则可能会造成长期伤害。专家表示,选择晒后护肤品要慎重,芦荟凝胶是最安全,有效的一种选择,晒伤严重者,还请及时就医。'
    s2 = u'8月28日,网络爆料称,华住集团旗下连锁酒店用户数据疑似发生泄露。从卖家发布的内容看,数据包含华住旗下汉庭、禧玥、桔子、宜必思等10余个品牌酒店的住客信息。泄露的信息包括华住官网注册资料、酒店入住登记的身份信息及酒店开房记录,住客姓名、手机号、邮箱、身份证号、登录账号密码等。卖家对这个约5亿条数据打包出售。第三方安全平台威胁猎人对信息出售者提供的三万条数据进行验证,认为数据真实性非常高。当天下午,华住集团发声明称,已在内部迅速开展核查,并第一时间报警。当晚,上海警方消息称,接到华住集团报案,警方已经介入调查。'
    for s in [s1, s2]:
        print(u'Generated title:', autotitle.generate(s))


class Evaluator(Callback):
    """Evaluation and checkpointing
    """
    def __init__(self):
        self.lowest = 1e10

    def on_epoch_end(self, steps, epoch, logs=None):
        # save the best model
        if logs['loss'] <= self.lowest:
            self.lowest = logs['loss']
            # model.save_weights('./best_model.pt')
        # show some generated titles
        just_show()


if __name__ == '__main__':
    just_show()
    evaluator = Evaluator()
    model.fit(train_dataloader, steps_per_epoch=100, epochs=epochs, callbacks=[evaluator])
else:
    model.load_weights('./best_model.pt')
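
# ---------------------------------------------------------------------------
# Illustration (not part of the original script): a minimal greedy-decoding
# sketch of what AutoRegressiveDecoder.beam_search above reduces to with
# topk=1: append each generated token with segment id 1 and stop at [SEP].
# `greedy_generate` is a hypothetical helper, shown only to clarify decoding;
# the `model.predict` call mirrors the one in AutoTitle.predict.
def greedy_generate(text, max_new_tokens=32):
    token_ids, segment_ids = tokenizer.encode(text, maxlen=maxlen - max_new_tokens)
    token_ids = torch.tensor([token_ids], dtype=torch.long, device=device)
    segment_ids = torch.tensor([segment_ids], dtype=torch.long, device=device)
    generated = []
    for _ in range(max_new_tokens):
        _, y_pred = model.predict([token_ids, segment_ids])
        next_id = y_pred[0, -1, :].argmax().item()  # greedy: most likely next token
        if next_id == tokenizer._token_end_id:  # [SEP] ends the title
            break
        generated.append(next_id)
        next_token = torch.tensor([[next_id]], dtype=torch.long, device=device)
        token_ids = torch.cat([token_ids, next_token], dim=1)
        segment_ids = torch.cat([segment_ids, torch.ones_like(next_token)], dim=1)
    return tokenizer.decode(generated)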