Commit 19a23d09 authored by wangsen's avatar wangsen
Browse files

Initial commit

parents
Pipeline #1247 failed with stages
in 0 seconds
#! -*- coding: utf-8 -*-
# 基础测试:mlm预测
from bert4torch.models import build_transformer_model
from bert4torch.tokenizers import Tokenizer
import torch
# 加载模型,请更换成自己的路径
root_model_path = "F:/Projects/pretrain_ckpt/bert/[google_tf_base]--chinese_L-12_H-768_A-12"
vocab_path = root_model_path + "/vocab.txt"
config_path = root_model_path + "/bert_config.json"
checkpoint_path = root_model_path + '/pytorch_model.bin'
# 建立分词器
tokenizer = Tokenizer(vocab_path, do_lower_case=True)
model = build_transformer_model(config_path, checkpoint_path, with_mlm='softmax') # 建立模型,加载权重
token_ids, segments_ids = tokenizer.encode("科学技术是第一生产力")
token_ids[3] = token_ids[4] = tokenizer._token_mask_id
print(''.join(tokenizer.ids_to_tokens(token_ids)))
tokens_ids_tensor = torch.tensor([token_ids])
segment_ids_tensor = torch.tensor([segments_ids])
# 需要传入参数with_mlm
model.eval()
with torch.no_grad():
_, probas = model([tokens_ids_tensor, segment_ids_tensor])
result = torch.argmax(probas[0, 3:5], dim=-1).numpy()
print(tokenizer.decode(result))
#! -*- coding: utf-8 -*-
# 基本测试:清华开源的中文GPT2模型(26亿参数)
# 项目链接:https://github.com/TsinghuaAI/CPM-Generate
# 博客介绍:https://kexue.fm/archives/7912
# 权重需转换后方可加载,转换脚本见convert_script文件夹
import numpy as np
from bert4torch.models import build_transformer_model
from bert4torch.tokenizers import SpTokenizer
from bert4torch.snippets import AutoRegressiveDecoder
import torch
import jieba
jieba.initialize()
# 模型路径
config_path = 'F:/Projects/pretrain_ckpt/gpt2/[cpm_gpt2_torch]--cpm_lm_2.6b/bert4torch_config.json'
checkpoint_path = 'F:/Projects/pretrain_ckpt/gpt2/[cpm_gpt2_torch]--cpm_lm_2.6b/bert4torch_pytorch_model.bin'
spm_path = 'F:/Projects/pretrain_ckpt/gpt2/[cpm_gpt2_torch]--cpm_lm_2.6b/chinese_vocab.model'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
def pre_tokenize(text):
"""分词前处理函数,'\n'替换成'▃', ' '替换成'▂'
"""
return [
w.replace(' ', u'\u2582').replace('\n', u'\u2583')
for w in jieba.cut(text, cut_all=False)
]
tokenizer = SpTokenizer(
spm_path,
token_start=None,
token_end=None,
pre_tokenize=pre_tokenize,
token_translate={u'\u2583': '<cls>'} # '\n'替换成<cls>
) # 建立分词器
model = build_transformer_model(
config_path=config_path, checkpoint_path=checkpoint_path, model='gpt2', segment_vocab_size=0
).to(device) # 建立模型,加载权重
class TextExpansion(AutoRegressiveDecoder):
"""基于随机采样的文本续写
"""
@AutoRegressiveDecoder.wraps(default_rtype='probas')
def predict(self, inputs, output_ids, states):
token_ids = torch.cat([inputs[0], output_ids], 1)
logits = model.predict([token_ids])
return logits[:, -1, :]
def generate(self, text, n=1, topp=0.95, temperature=1):
"""输出结果会有一定的随机性,如果只关心Few Shot效果,
可以考虑将解码方式换为beam search。
"""
token_ids, _ = tokenizer.encode(text)
results = self.random_sample([token_ids], n, topp=topp, temperature=temperature) # 基于随机采样
results = [token_ids + [int(i) for i in ids.cpu().numpy()] for ids in results]
texts = [tokenizer.decode(ids) for ids in results]
return [self.post_replace(text) for text in texts]
def post_replace(self, text):
for s, t in [(' ', ''), (u'\u2582', ' '), (u'\u2583', '\n')]:
text = text.replace(s, t)
return text
text_expansion = TextExpansion(
start_id=None,
end_id=3, # 3是<cls>,也是换行符
maxlen=16,
device=device
)
# 常识推理
# 本例输出:北京
query = u"""
美国的首都是华盛顿
法国的首都是巴黎
日本的首都是东京
中国的首都是
"""
print(text_expansion.generate(query[1:-1], 1)[0])
# 单词翻译
# 本例输出:bird
query = u"""
狗 dog
猫 cat
猪 pig
"""
print(text_expansion.generate(query[1:-1], 1)[0])
# 主语抽取
# 本例输出:杨振宁
query = u"""
从1931年起,华罗庚在清华大学边学习边工作 华罗庚
在一间简陋的房间里,陈景润攻克了“哥德巴赫猜想” 陈景润
在这里,丘成桐得到IBM奖学金 丘成桐
杨振宁在粒子物理学、统计力学和凝聚态物理等领域作出里程碑性贡献
"""
print(text_expansion.generate(query[1:-1], 1)[0])
# 三元组抽取
# 本例输出:张红,体重,140斤
query = u"""
姚明的身高是211cm,是很多人心目中的偶像。 ->姚明,身高,211cm
毛泽东是绍兴人,早年在长沙读书。->毛泽东,出生地,绍兴
虽然周杰伦在欧洲办的婚礼,但是他是土生土长的中国人->周杰伦,国籍,中国
小明出生于武汉,但是却不喜欢在武汉生成,长大后去了北京。->小明,出生地,武汉
吴亦凡是很多人的偶像,但是他却是加拿大人,另很多人失望->吴亦凡,国籍,加拿大
武耀的生日在5月8号,这一天,大家都为他庆祝了生日->武耀,生日,5月8号
《青花瓷》是周杰伦最得意的一首歌。->周杰伦,作品,《青花瓷》
北京是中国的首都。->中国,首都,北京
蒋碧的家乡在盘龙城,毕业后去了深圳工作。->蒋碧,籍贯,盘龙城
上周我们和王立一起去了他的家乡云南玩昨天才回到了武汉。->王立,籍贯,云南
昨天11月17号,我和朋友一起去了海底捞,期间服务员为我的朋友刘章庆祝了生日。->刘章,生日,11月17号
张红的体重达到了140斤,她很苦恼。->
"""
print(text_expansion.generate(query[1:-1], 1)[0])
#! -*- coding: utf-8 -*-
# 基础测试:ERNIE模型测试
from bert4torch.models import build_transformer_model
from bert4torch.tokenizers import Tokenizer
import torch
# 加载模型,请更换成自己的路径
root_model_path = "F:/Projects/pretrain_ckpt/ernie/[baidu_torch_base]--ernie-1-base-zh"
# root_model_path = "F:/Projects/pretrain_ckpt/ernie/[baidu_torch_base]--ernie-3-base-zh"
vocab_path = root_model_path + "/vocab.txt"
config_path = root_model_path + "/config.json"
checkpoint_path = root_model_path + '/pytorch_model.bin'
# 建立分词器
tokenizer = Tokenizer(vocab_path, do_lower_case=True)
model = build_transformer_model(config_path, checkpoint_path, model='ERNIE', with_mlm='softmax') # 建立模型,加载权重
token_ids, segments_ids = tokenizer.encode("科学技术是第一生产力")
token_ids[3] = token_ids[4] = tokenizer._token_mask_id
print(''.join(tokenizer.ids_to_tokens(token_ids)))
tokens_ids_tensor = torch.tensor([token_ids])
segment_ids_tensor = torch.tensor([segments_ids])
# 需要传入参数
model.eval()
with torch.no_grad():
_, probas = model([tokens_ids_tensor, segment_ids_tensor])
result = torch.argmax(probas[0, 3:5], dim=-1).numpy()
print(tokenizer.decode(result))
#! -*- coding: utf-8 -*-
# 基本测试:gpt2_ml的效果测试
# 项目链接(tf版本):https://github.com/imcaspar/gpt2-ml
# 权重需转换后方可加载,转换脚本见convert_script文件夹
import torch
from bert4torch.models import build_transformer_model
from bert4torch.tokenizers import Tokenizer
from bert4torch.snippets import AutoRegressiveDecoder
config_path = 'F:/Projects/pretrain_ckpt/gpt2/[gpt2-ml_torch_15g]/bert4torch_config.json'
checkpoint_path = 'F:/Projects/pretrain_ckpt/gpt2/[gpt2-ml_torch_15g]/bert4torch_pytorch_model.bin'
dict_path = 'F:/Projects/pretrain_ckpt/gpt2/[gpt2-ml_torch_15g]/vocab.txt'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
tokenizer = Tokenizer(dict_path, token_start=None, token_end=None, do_lower_case=True) # 建立分词器
model = build_transformer_model(config_path=config_path, checkpoint_path=checkpoint_path, model='gpt2_ml', segment_vocab_size=0).to(device) # 建立模型,加载权重
class ArticleCompletion(AutoRegressiveDecoder):
"""基于随机采样的文章续写
"""
@AutoRegressiveDecoder.wraps(default_rtype='probas')
def predict(self, inputs, output_ids, states):
token_ids = torch.cat([inputs[0], output_ids], 1)
logits = model.predict([token_ids])
return logits[:, -1, :]
def generate(self, text, n=1, topp=0.95):
token_ids, _ = tokenizer.encode(text)
results = self.random_sample([token_ids], n, topp=topp) # 基于随机采样
return [text + tokenizer.decode(ids.cpu().numpy()) for ids in results]
article_completion = ArticleCompletion(
start_id=None,
end_id=511, # 511是中文句号
maxlen=256,
minlen=128,
device=device
)
for text in [u'今天天气不错', u'双十一', u'科学空间']:
print(article_completion.generate(text))
"""
部分结果:
>>> article_completion.generate(u'今天天气不错')
[u'今天天气不错。昨天的天气是多云到晴的天气,今天的天气还不错,不会太冷。明后两天天气还是比较好的。不过今天的天气比较闷热,最高温度在30℃左右,明后两天天气会更加热。预计今天的最高温度为30℃,明后两天的最 高温度为32℃左右,今天的最高气温将在30℃左右。(记者李莉)。新华网重庆频道诚邀广大网友投稿,您可以用相机或手机记录下身边的感人故事,精彩瞬间。请将作者、拍摄时间、地点和简要说明连同照片发给我们,我们将精选其中的好图、美图在页面上展示,让所有新华网友共赏。[投稿] 。本报讯(记者陈敏华) 今年上半年,重庆市各级公安机关在全力抓好']
>>> article_completion.generate(u'双十一')
[u'双十一大是中国共产党在新的历史起点上召开的一次十分重要的代表大会, 是全面落实科学发展观、推进中国特色社会主义伟大事业的一次重要会议。会议的召开, 是党和政府对新世纪新阶段我国改革开放和社会主义现代化建设 事业的新的历史任务的一次重要总动员, 必将对我们党全面推进党的建']
>>> article_completion.generate(u'科学空间')
[u'科学空间站上的两个机器人在进入轨道后,一边在轨道上工作,一边用它们的身体和心脏在空间站上的一个大气层进行活动,以确保它们在进入地球之后不会因太阳风暴而受到影响;而另外一个机器人则在进入轨道的过程中,通 过机器人与地球上的大气层相互作用,使地球的大气层不断地向地球的大气层中转移,以使其能够在空间站上工作,并且使用它们的身体和心脏来完成它们的各种任务。']
"""
#! -*- coding: utf-8 -*-
# 基本测试:中文GPT模型,base版本,华为开源的
# 权重链接: https://pan.baidu.com/s/1-FB0yl1uxYDCGIRvU1XNzQ 提取码: xynn,这里使用的是转pytorch后的模型文件
# 参考项目:https://github.com/bojone/chinese-gen
import torch
from bert4torch.models import build_transformer_model
from bert4torch.tokenizers import Tokenizer
from bert4torch.snippets import AutoRegressiveDecoder
config_path = 'F:/Projects/pretrain_ckpt/bert/[huawei_noah_tf_base]--chinese_nezha_gpt_L-12_H-768_A-12/config.json'
checkpoint_path = 'F:/Projects/pretrain_ckpt/bert/[huawei_noah_tf_base]--chinese_nezha_gpt_L-12_H-768_A-12/pytorch_model.bin'
dict_path = 'F:/Projects/pretrain_ckpt/bert/[huawei_noah_tf_base]--chinese_nezha_gpt_L-12_H-768_A-12/vocab.txt'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
tokenizer = Tokenizer(dict_path, do_lower_case=True) # 建立分词器
model = build_transformer_model(
config_path=config_path,
checkpoint_path=checkpoint_path,
segment_vocab_size=0, # 去掉segmeng_ids输入
application='lm',
).to(device) # 建立模型,加载权重
class ArticleCompletion(AutoRegressiveDecoder):
"""基于随机采样的文章续写
"""
@AutoRegressiveDecoder.wraps(default_rtype='logits')
def predict(self, inputs, output_ids, states):
token_ids = torch.cat([inputs[0], output_ids], 1)
_, mlm_scores = model.predict([token_ids])
return mlm_scores[:, -1, :]
def generate(self, text, n=1, topp=0.95):
token_ids = tokenizer.encode(text)[0][:-1]
results = self.random_sample([token_ids], n, topp=topp) # 基于随机采样
return [text + tokenizer.decode(ids.cpu().numpy()) for ids in results]
article_completion = ArticleCompletion(
start_id=None,
end_id=511, # 511是中文句号
maxlen=256,
minlen=128,
device=device
)
print(article_completion.generate(u'今天天气不错'))
"""
部分结果:
>>> article_completion.generate(u'今天天气不错')
[u'今天天气不错。昨天的天气是多云到晴的天气,今天的天气还不错,不会太冷。明后两天天气还是比较好的。不过今天的天气比较闷热,最高温度在30℃左右,明后两天天气会更加热。预计今天的最高温度为30℃,明后两天的最 高温度为32℃左右,今天的最高气温将在30℃左右。(记者李莉)。新华网重庆频道诚邀广大网友投稿,您可以用相机或手机记录下身边的感人故事,精彩瞬间。请将作者、拍摄时间、地点和简要说明连同照片发给我们,我们将精选其中的好图、美图在页面上展示,让所有新华网友共赏。[投稿] 。本报讯(记者陈敏华) 今年上半年,重庆市各级公安机关在全力抓好']
>>> article_completion.generate(u'双十一')
[u'双十一大是中国共产党在新的历史起点上召开的一次十分重要的代表大会, 是全面落实科学发展观、推进中国特色社会主义伟大事业的一次重要会议。会议的召开, 是党和政府对新世纪新阶段我国改革开放和社会主义现代化建设 事业的新的历史任务的一次重要总动员, 必将对我们党全面推进党的建']
>>> article_completion.generate(u'科学空间')
[u'科学空间站上的两个机器人在进入轨道后,一边在轨道上工作,一边用它们的身体和心脏在空间站上的一个大气层进行活动,以确保它们在进入地球之后不会因太阳风暴而受到影响;而另外一个机器人则在进入轨道的过程中,通 过机器人与地球上的大气层相互作用,使地球的大气层不断地向地球的大气层中转移,以使其能够在空间站上工作,并且使用它们的身体和心脏来完成它们的各种任务。']
"""
#! -*- coding: utf-8 -*-
# NEZHA模型做闲聊任务,这里只提供了测试脚本
# 源项目:https://github.com/bojone/nezha_gpt_dialog
# 权重转换脚本见:https://github.com/Tongjilibo/bert4torch/blob/master/examples/convert_script/convert_nezha_gpt_dialog.py
from bert4torch.models import build_transformer_model
from bert4torch.tokenizers import Tokenizer
from bert4torch.snippets import AutoRegressiveDecoder
import torch
# nezha配置
config_path = 'F:/Projects/pretrain_ckpt/nezha/[sushen_tf_base]--nezha_gpt_dialog/config.json'
checkpoint_path = 'F:/Projects/pretrain_ckpt/nezha/[sushen_tf_base]--nezha_gpt_dialog/pytorch_model.bin'
dict_path = 'F:/Projects/pretrain_ckpt/nezha/[sushen_tf_base]--nezha_gpt_dialog/vocab.txt'
# 建立分词器
tokenizer = Tokenizer(dict_path, do_lower_case=True)
# 建立并加载模型
model = build_transformer_model(
config_path,
checkpoint_path,
model='nezha',
application='lm',
)
class ChatBot(AutoRegressiveDecoder):
"""基于随机采样对话机器人
"""
@AutoRegressiveDecoder.wraps(default_rtype='logits')
def predict(self, inputs, output_ids, states):
token_ids, segment_ids = inputs
token_ids = torch.concat([token_ids, output_ids], 1)
curr_segment_ids = torch.ones_like(output_ids) - segment_ids[0, -1]
segment_ids = torch.concat([segment_ids, curr_segment_ids], 1)
return model.predict([token_ids, segment_ids])[-1][:, -1]
def response(self, texts, topk=5):
token_ids, segment_ids = [tokenizer._token_start_id], [0]
for i, text in enumerate(texts):
ids = tokenizer.encode(text)[0][1:]
token_ids.extend(ids)
segment_ids.extend([i % 2] * len(ids))
results = self.random_sample([token_ids, segment_ids], 1, topk)
return tokenizer.decode(results[0].cpu().numpy())
chatbot = ChatBot(start_id=None, end_id=tokenizer._token_end_id, maxlen=32)
print(chatbot.response([u'别爱我没结果', u'你这样会失去我的', u'失去了又能怎样']))
"""
回复是随机的,例如:那你还爱我吗 | 不知道 | 爱情是不是不能因为一点小事就否定了 | 我会一直爱你,你一个人会很辛苦 | 等等。
"""
\ No newline at end of file
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment