basic_gibbs_sampling_via_mlm.py 2.01 KB
Newer Older
yangzhong's avatar
yangzhong committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
#! -*- coding: utf-8 -*-
# 测试代码可用性: 结合MLM的Gibbs采样

from tqdm import tqdm
import numpy as np
from bert4torch.models import build_transformer_model
from bert4torch.tokenizers import Tokenizer
import torch
import torch.nn as nn


root_model_path = "F:/Projects/pretrain_ckpt/bert/[google_tf_base]--chinese_L-12_H-768_A-12"
vocab_path = root_model_path + "/vocab.txt"
config_path = root_model_path + "/bert_config.json"
checkpoint_path = root_model_path + '/pytorch_model.bin'

tokenizer = Tokenizer(vocab_path, do_lower_case=True)  # 建立分词器
model = build_transformer_model(
    config_path=config_path, checkpoint_path=checkpoint_path, with_mlm='softmax'
)  # 建立模型,加载权重

sentences = []
init_sent = u'科学技术是第一生产力。'  # 给定句子或者None
minlen, maxlen = 8, 32
steps = 10000
converged_steps = 1000
vocab_size = tokenizer._vocab_size

if init_sent is None:
    length = np.random.randint(minlen, maxlen + 1)
    tokens = ['[CLS]'] + ['[MASK]'] * length + ['[SEP]']
    token_ids = tokenizer.tokens_to_ids(tokens)
    segment_ids = [0] * len(token_ids)
else:
    token_ids, segment_ids = tokenizer.encode(init_sent)
    length = len(token_ids) - 2

device='cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
model.eval()
with torch.no_grad():
    for _ in tqdm(range(steps), desc='Sampling'):
        # Gibbs采样流程:随机mask掉一个token,然后通过MLM模型重新采样这个token。
        i = np.random.choice(length) + 1
        token_ids[i] = tokenizer._token_mask_id
        token_ids_tensor, segment_ids_tensor = torch.tensor([token_ids], device=device), torch.tensor([segment_ids], device=device)
        _, probas = model([token_ids_tensor, segment_ids_tensor])
        probas = probas[0, i]
        token = np.random.choice(vocab_size, p=probas.cpu().numpy())
        token_ids[i] = token
        sentences.append(tokenizer.decode(token_ids))

print(u'部分随机采样结: ')
for _ in range(10):
    print(np.random.choice(sentences[converged_steps:]))