task_language_model.py 8.36 KB
Newer Older
yangzhong's avatar
yangzhong committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
#! -*- coding: utf-8 -*-
# bert做language model任务,小说生成

import glob, re
from tqdm import tqdm
from bert4torch.models import build_transformer_model
from bert4torch.tokenizers import Tokenizer, load_vocab
from bert4torch.snippets import sequence_padding, AutoRegressiveDecoder, Callback, ListDataset
import torch
from torch.utils.data import Dataset, DataLoader
from torchinfo import summary
import torch.nn as nn
import torch.optim as optim

maxlen = 256      # max sample length in tokens (incl. [CLS]/[SEP])
batch_size = 8
epochs = 10000    # effectively "train forever"; best model is checkpointed

# BERT checkpoint paths (Google Chinese BERT-base converted to PyTorch)
config_path = 'F:/Projects/pretrain_ckpt/bert/[google_tf_base]--chinese_L-12_H-768_A-12/bert_config.json'
checkpoint_path = 'F:/Projects/pretrain_ckpt/bert/[google_tf_base]--chinese_L-12_H-768_A-12/pytorch_model.bin'
dict_path = 'F:/Projects/pretrain_ckpt/bert/[google_tf_base]--chinese_L-12_H-768_A-12/vocab.txt'
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load and prune the vocabulary, then build the tokenizer.
# simplified=True trims the vocab; keep_tokens presumably maps the kept
# entries back to rows of the original embedding matrix (used below in
# build_transformer_model) — semantics per bert4torch's load_vocab.
token_dict, keep_tokens = load_vocab(
    dict_path=dict_path,
    simplified=True,
    startswith=['[PAD]', '[UNK]', '[CLS]', '[SEP]'],
)
tokenizer = Tokenizer(token_dict, do_lower_case=True)
# 加载数据集
class MyDataset(ListDataset):
    @staticmethod
    def load_data(filenames):
        """加载数据,并尽量划分为不超过maxlen的句子
        """
        novels = []

        for txt in glob.glob(filenames):
            txt = open(txt, encoding='utf-8').read()
            txt = txt.replace('\r', '').replace('\n', '')
            txt = txt.replace(u'整理制作,并提供下载', '')
            txt = re.sub(u'www.*?com', '', txt)
            txt = txt.replace(u'\u3000', ' ')
            sents = []
            for t in txt.split('  '):
                for s in re.findall(u'.*?。', t):
                    if len(s) <= maxlen - 2:
                        sents.append(s)
            novels.append(sents)
        data = []
        pbar = tqdm(desc=u'构建语料中', total=sum(len(n) for n in novels))

        for novel in novels:
            s = u''
            for i in range(len(novel)):
                for j in range(len(novel) - i):
                    if len(s) + len(novel[i + j]) > maxlen - 2:
                        data.append(s)
                        s = u''
                        break
                    else:
                        s += novel[i + j]
                pbar.update(1)
                if i + j >= len(novel):
                    break
            if s:
                data.append(s)

        pbar.close()
        return data

def collate_fn(batch):
    """Tokenize a batch of raw texts and pad them into long tensors.

    Returns ([token_ids, segment_ids], token_ids): for LM training the
    padded token ids double as the target.
    """
    encoded = [tokenizer.encode(text) for text in batch]
    token_ids = torch.tensor(
        sequence_padding([tok for tok, _ in encoded]), dtype=torch.long, device=device)
    segment_ids = torch.tensor(
        sequence_padding([seg for _, seg in encoded]), dtype=torch.long, device=device)
    return [token_ids, segment_ids], token_ids

# Build the training dataloader over the novel corpus.
train_dataloader = DataLoader(MyDataset('F:/Projects/data/corpus/pretrain/金庸小说/*.txt'), 
                   batch_size=batch_size, shuffle=True, collate_fn=collate_fn) 

# Build the model: BERT weights with application='lm' (causal attention mask)
# and the MLM head exposed so per-position token logits are produced.
model = build_transformer_model(
    config_path,
    checkpoint_path,
    with_mlm=True,
    application='lm',
    keep_tokens=keep_tokens,  # keep only the pruned vocab's embedding rows
).to(device)
summary(model, input_data=[next(iter(train_dataloader))[0]])

class CrossEntropyLoss(nn.CrossEntropyLoss):
    """Causal-LM loss: each position is scored against the NEXT token.

    `outputs` is the model's (hidden_states, mlm_scores) pair; only the
    logits are used.  Logits drop the last position (nothing follows it)
    and targets drop the first, giving the usual shift-by-one LM objective.
    Construct with ignore_index=0 so [PAD] positions contribute nothing.
    (The original defined an `__init__` that merely forwarded **kwargs to
    super(); the inherited constructor is identical, so it was removed.)
    """
    def forward(self, outputs, target):
        _, mlm_scores = outputs  # mlm_scores: (batch, seq_len, vocab)
        # Flatten positions 0..L-2 into (batch*(L-1), vocab).
        logits = mlm_scores[:, :-1, :].reshape(-1, mlm_scores.shape[-1])
        # Shifted targets: position t must predict token t+1.
        labels = target[:, 1:].flatten()
        return super().forward(logits, labels)

# ignore_index=0 masks [PAD] (id 0 after pruning, since startswith begins with
# '[PAD]'); Adam with lr=1e-5, typical for fine-tuning a pretrained model.
model.compile(loss=CrossEntropyLoss(ignore_index=0), optimizer=optim.Adam(model.parameters(), 1e-5))

# Random-sampling decoder.
class StoryCompletion(AutoRegressiveDecoder):
    """Story continuation via nucleus (top-p) random sampling.
    """
    @AutoRegressiveDecoder.wraps(default_rtype='logits')
    def predict(self, inputs, output_ids, states):
        # Prompt ids followed by everything generated so far.
        full_ids = torch.cat([inputs[0], output_ids], 1)
        seg_ids = torch.zeros_like(full_ids, device=device)
        _, scores = model.predict([full_ids, seg_ids])
        # Only the final position's logits matter for the next token.
        return scores[:, -1, :]

    def generate(self, text, n=1, topp=0.95):
        prompt_ids, _ = tokenizer.encode(text)
        # Drop the trailing [SEP] so generation continues the prompt.
        sampled = self.random_sample([prompt_ids[:-1]], n, topp=topp)
        return [text + tokenizer.decode(ids.cpu().numpy()) for ids in sampled]

# Decoder instance: no fixed start token (the prompt supplies context);
# generation stops at the tokenizer's end ([SEP]) id.
story_completion = StoryCompletion(start_id=None, end_id=tokenizer._token_end_id, maxlen=maxlen, device=device)

def just_show():
    """Print sampled continuations for three fixed demo prompts."""
    prompts = (
        u'当晚两人在一家小客店中宿歇。张无忌躺在炕上,越想越是担心,走到赵敏窗外,但听她呼吸调匀,正自香梦沉酣。',
        u'虚竹飞身跃上松树的枝干,只见段延庆的钢杖深深嵌在树枝之中,全凭一股内力粘劲,挂住了下面四人,内力之深厚,实是非同小可。虚竹伸左手抓住钢杖,提将上来。',
        u'杨过居住在侠客岛,是令狐冲的弟子,武器是金蛇剑。',
    )
    for prompt in prompts:
        completions = story_completion.generate(prompt)
        print(u'输入: %s' % prompt)
        print(u'结果: %s\n' % ('\n'.join(completions)))


class Evaluator(Callback):
    """Track the best (lowest) loss and demo generation after each epoch.
    """
    def __init__(self):
        self.lowest = 1e10  # best loss seen so far

    def on_epoch_end(self, steps, epoch, logs=None):
        # Remember the best checkpoint (saving currently disabled).
        current = logs['loss']
        if current <= self.lowest:
            self.lowest = current
            # model.save_weights('./best_model.pt')
        # Show what the model generates right now.
        just_show()


if __name__ == '__main__':

    evaluator = Evaluator()

    # steps_per_epoch=100: evaluate/demo every 100 batches rather than per full pass.
    model.fit(train_dataloader, epochs=epochs, steps_per_epoch=100, callbacks=[evaluator])

else:

    # NOTE(review): Evaluator's (commented-out) save path is './best_model.pt'
    # but this loads './best_model.weights' — paths look inconsistent; confirm
    # which file actually exists before relying on the import-time load.
    model.load_weights('./best_model.weights')
"""
效果:

输入: 当晚两人在一家小客店中宿歇。张无忌躺在炕上,越想越是担心,走到赵敏窗外,但听她呼吸调匀,正自香梦沉酣。
结果: 当晚两人在一家小客店中宿歇。张无忌躺在炕上,越想越是担心,走到赵敏窗外,但听她呼吸调匀,正自香梦沉酣。次日清晨,张无忌便和赵敏去买了一匹高头大马,自己骑了随伴。那马甚有神骏,三十六斤重的身躯之中,竟无一头白马。他心中怦怦乱跳,暗想:若能将赵敏引出迷城,我决不致再和她相会,但若和赵姑娘相遇,我一生一世决计再难相见。何况我是她的私生女儿,这般亲热,岂不是好?我如何能和她相见?今后我要教训教训她才好?我教教她,教训她,要她心里快快活活的。他心如刀割,当即回到客店,将张无忌的所在说了。

输入: 虚竹飞身跃上松树的枝干,只见段延庆的钢杖深深嵌在树枝之中,全凭一股内力粘劲,挂住了下面四人,内力之深厚,实是非同小可。虚竹伸左手抓住钢杖,提将上来。
结果: 虚竹飞身跃上松树的枝干,只见段延庆的钢杖深深嵌在树枝之中,全凭一股内力粘劲,挂住了下面四人,内力之深厚,实是非同小可。虚竹伸左手抓住钢杖,提将上来。那矮子见他如此功力,大吃一惊,叫道:什么人?是谁?你干什么?我师父是谁?你们是谁?是谁?你们是谁?我师父是谁?你这矮子,便是段延庆。你们不知道我师父便是,是不是?快快说来。那矮子道:我师父便是延庆太子,他的徒弟也是段延庆。他老人家在唐朝做镇南王,你们便将他改名为延庆太子,叫做延庆太子!这名头倒怪,你们大伙儿听见了,也不知道他老人家是死是活。

输入: 杨过居住在侠客岛,是令狐冲的弟子,武器是金蛇剑。
结果: 杨过居住在侠客岛,是令狐冲的弟子,武器是金蛇剑。这时见他手中所握,竟是一柄特制的短剑,心中大喜,叫道::原来是金蛇郎君的剑!原来你便是金蛇郎君的弟子,这一下可要叫我失望了。那人哈哈一笑,说道:好啊!好啊,好啊!我的金蛇剑是我的,不过我是你的。这人道:我姓杨名过,名字叫过。你是我儿子,是我女儿,是不是?你这么大的年纪,怎地自称金刀驸马?我这就给你取个名字,叫作过儿。
"""