"maint/git@developer.sourcefind.cn:OpenDAS/tilelang.git" did not exist on "8f4628e0967441eaa42f6e71162ece4301e87685"
Commit 0fc002df authored by huchen's avatar huchen
Browse files

init the dlexamples new

parent 0e04b692
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
# bert4keras
- Our lightweight reimplementation of BERT for Keras
- A clearer, more lightweight Keras version of BERT
- Online documentation: http://bert4keras.spaces.ac.cn/ (still under construction)
## Features
Currently implemented:
- Load pretrained bert/roberta/albert weights for fine-tuning;
- The attention masks needed for language models and seq2seq;
- A rich set of examples;
- Pretraining-from-scratch code (supports TPU and multi-GPU; see pretraining);
- Compatible with both keras and tf.keras
## Usage
Install the stable release:
```shell
pip install bert4keras
```
Install the latest version:
```shell
pip install git+https://www.github.com/bojone/bert4keras.git
```
See the examples directory for usage examples.
In principle it is compatible with both Python 2 and Python 3, and with TensorFlow 1.14+ as well as TensorFlow 2.x. The reference environment is Python 2.7, TensorFlow 1.14+ and Keras 2.3.1 (tested under Keras 2.2.4, 2.3.0, 2.3.1 and tf.keras).
**For the best experience, we recommend the combination of TensorFlow 1.14 + Keras 2.3.1.**
<blockquote><strong>About environment combinations</strong>
- Both tf+keras and tf+tf.keras are supported; the latter requires setting the environment variable TF_KERAS=1 in advance.
- When using tf+keras, we recommend 2.2.4 <= keras <= 2.3.1 together with 1.14 <= tf <= 2.2; tf 2.3+ cannot be used.
- keras 2.4+ also works, but keras 2.4.x is essentially equivalent to tf.keras, so if you want keras 2.4+ you might as well use tf.keras directly.
</blockquote>
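A minimal sketch of switching to the tf.keras backend (the toy model at the end is only for illustration); the variable must be set before bert4keras is imported for the first time:
```python
import os
os.environ['TF_KERAS'] = '1'  # read by bert4keras.backend at import time

from bert4keras.backend import keras, K  # `keras` now refers to tf.keras

# any model built from here on uses tf.keras layers and backend ops
model = keras.Sequential([keras.layers.Dense(1, input_shape=(8,))])
model.compile(optimizer='adam', loss='mse')
```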
## Weights
Currently supported pretrained weights:
- <strong>Google's original BERT</strong>: https://github.com/google-research/bert
- <strong>brightmart's RoBERTa</strong>: https://github.com/brightmart/roberta_zh
- <strong>HIT (Harbin Institute of Technology) RoBERTa</strong>: https://github.com/ymcui/Chinese-BERT-wwm
- <strong>Google's original ALBERT</strong><sup><a href="https://github.com/bojone/bert4keras/issues/29#issuecomment-552188981">[example]</a></sup>: https://github.com/google-research/ALBERT
- <strong>brightmart's ALBERT</strong>: https://github.com/brightmart/albert_zh
- <strong>Converted ALBERT</strong>: https://github.com/bojone/albert_zh
- <strong>Huawei's NEZHA</strong>: https://github.com/huawei-noah/Pretrained-Language-Model/tree/master/NEZHA-TensorFlow
- <strong>Huawei's NEZHA-GEN</strong>: https://github.com/huawei-noah/Pretrained-Language-Model/tree/master/NEZHA-Gen-TensorFlow
- <strong>Our in-house pretrained language models</strong>: https://github.com/ZhuiyiTechnology/pretrained-models
- <strong>T5</strong>: https://github.com/google-research/text-to-text-transfer-transformer
- <strong>GPT_OpenAI</strong>: https://github.com/bojone/CDial-GPT-tf
- <strong>GPT2_ML</strong>: https://github.com/imcaspar/gpt2-ml
- <strong>Google's original ELECTRA</strong>: https://github.com/google-research/electra
- <strong>HIT ELECTRA</strong>: https://github.com/ymcui/Chinese-ELECTRA
- <strong>CLUE ELECTRA</strong>: https://github.com/CLUEbenchmark/ELECTRA
- <strong>LaBSE (multilingual BERT)</strong>: https://github.com/bojone/labse
- <strong>Models from the Chinese-GEN project</strong>: https://github.com/bojone/chinese-gen
- <strong>T5.1.1</strong>: https://github.com/google-research/text-to-text-transfer-transformer/blob/master/released_checkpoints.md#t511
- <strong>Multilingual T5</strong>: https://github.com/google-research/multilingual-t5/
<strong>Notes</strong>
- Note 1: brightmart's ALBERT was open-sourced earlier than Google's ALBERT, so the weights of early brightmart releases are not fully consistent with the Google version; in other words, the two cannot simply be swapped for one another. To reduce code redundancy, bert4keras 0.2.4 and later only support loading the <u>Google version</u> and those brightmart weights whose names <u>contain "Google"</u>. To load earlier weights, use <a href="https://github.com/bojone/bert4keras/releases/tag/v0.2.3">version 0.2.3</a>, or consider the author's converted <a href="https://github.com/bojone/albert_zh">albert_zh</a>
- Note 2: If the downloaded ELECTRA weights come without a json config file, refer to <a href="https://github.com/ymcui/Chinese-ELECTRA/issues/3">this issue</a> and write one yourself (the `type_vocab_size` field needs to be added).
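A minimal sketch of such a patch (the file name `electra_config.json` and the value 2 for `type_vocab_size` are assumptions based on the usual BERT-style setup; check them against your checkpoint):
```python
import json

config_file = 'electra_config.json'  # hypothetical path to the downloaded config

with open(config_file) as f:
    config = json.load(f)

# bert4keras expects a `type_vocab_size` entry; 2 is the usual BERT-style value
config.setdefault('type_vocab_size', 2)

with open(config_file, 'w') as f:
    json.dump(config, f, indent=2)
```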
#! -*- coding: utf-8 -*-
__version__ = '0.10.5'
# -*- coding: utf-8 -*-
# The backend helpers live in their own module mainly so that native keras and tf.keras can both be supported
# Set the environment variable TF_KERAS=1 to switch to tf.keras
import os, sys
from distutils.util import strtobool
import numpy as np
import tensorflow as tf
from tensorflow.python.util import nest, tf_inspect
from tensorflow.python.eager import tape
from tensorflow.python.ops.custom_gradient import _graph_mode_decorator
# Flag distinguishing tf.keras from plain keras
is_tf_keras = strtobool(os.environ.get('TF_KERAS', '0'))
if is_tf_keras:
import tensorflow.keras as keras
import tensorflow.keras.backend as K
sys.modules['keras'] = keras
else:
import keras
import keras.backend as K
# Whether to enable recomputation (trading extra compute time for memory)
do_recompute = strtobool(os.environ.get('RECOMPUTE', '0'))
def gelu_erf(x):
"""基于Erf直接计算的gelu函数
"""
return 0.5 * x * (1.0 + tf.math.erf(x / np.sqrt(2.0)))
def gelu_tanh(x):
"""基于Tanh近似计算的gelu函数
"""
cdf = 0.5 * (
1.0 + K.tanh((np.sqrt(2 / np.pi) * (x + 0.044715 * K.pow(x, 3))))
)
return x * cdf
def set_gelu(version):
"""设置gelu版本
"""
version = version.lower()
assert version in ['erf', 'tanh'], 'gelu version must be erf or tanh'
if version == 'erf':
keras.utils.get_custom_objects()['gelu'] = gelu_erf
else:
keras.utils.get_custom_objects()['gelu'] = gelu_tanh
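# Illustrative usage (not part of this module): calling set_gelu('tanh') before
# building a model makes any layer that resolves the 'gelu' activation by name
# use gelu_tanh instead of the default gelu_erf, e.g.
#
#     set_gelu('tanh')
#     dense = keras.layers.Dense(768, activation='gelu')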
def piecewise_linear(t, schedule):
"""分段线性函数
其中schedule是形如{1000: 1, 2000: 0.1}的字典,
表示 t ∈ [0, 1000]时,输出从0均匀增加至1,而
t ∈ [1000, 2000]时,输出从1均匀降低到0.1,最后
t > 2000时,保持0.1不变。
"""
schedule = sorted(schedule.items())
if schedule[0][0] != 0:
schedule = [(0, 0.0)] + schedule
x = K.constant(schedule[0][1], dtype=K.floatx())
t = K.cast(t, K.floatx())
for i in range(len(schedule)):
t_begin = schedule[i][0]
x_begin = x
if i != len(schedule) - 1:
dx = schedule[i + 1][1] - schedule[i][1]
dt = schedule[i + 1][0] - schedule[i][0]
slope = 1.0 * dx / dt
x = schedule[i][1] + slope * (t - t_begin)
else:
x = K.constant(schedule[i][1], dtype=K.floatx())
x = K.switch(t >= t_begin, x, x_begin)
return x
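# Illustrative usage (values chosen for the example): with the schedule
# {1000: 1, 2000: 0.1}, the returned scalar rises linearly from 0 to 1 over the
# first 1000 steps, decays to 0.1 by step 2000 and then stays at 0.1, e.g.
#
#     scale = piecewise_linear(K.constant(1500), {1000: 1, 2000: 0.1})
#     # K.eval(scale) -> 0.55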
def search_layer(inputs, name, exclude_from=None):
"""根据inputs和name来搜索层
说明:inputs为某个层或某个层的输出;name为目标层的名字。
实现:根据inputs一直往上递归搜索,直到发现名字为name的层为止;
如果找不到,那就返回None。
"""
if exclude_from is None:
exclude_from = set()
if isinstance(inputs, keras.layers.Layer):
layer = inputs
else:
layer = inputs._keras_history[0]
if layer.name == name:
return layer
elif layer in exclude_from:
return None
else:
exclude_from.add(layer)
if isinstance(layer, keras.models.Model):
model = layer
for layer in model.layers:
if layer.name == name:
return layer
inbound_layers = layer._inbound_nodes[0].inbound_layers
if not isinstance(inbound_layers, list):
inbound_layers = [inbound_layers]
if len(inbound_layers) > 0:
for layer in inbound_layers:
layer = search_layer(layer, name, exclude_from)
if layer is not None:
return layer
def sequence_masking(x, mask, value=0.0, axis=None):
"""为序列条件mask的函数
mask: 形如(batch_size, seq_len)的0-1矩阵;
value: mask部分要被替换成的值,可以是'-inf'或'inf';
axis: 序列所在轴,默认为1;
"""
if mask is None:
return x
else:
if K.dtype(mask) != K.dtype(x):
mask = K.cast(mask, K.dtype(x))
if value == '-inf':
value = -1e12
elif value == 'inf':
value = 1e12
if axis is None:
axis = 1
elif axis < 0:
axis = K.ndim(x) + axis
assert axis > 0, 'axis must be greater than 0'
for _ in range(axis - 1):
mask = K.expand_dims(mask, 1)
for _ in range(K.ndim(x) - K.ndim(mask)):
mask = K.expand_dims(mask, K.ndim(mask))
return x * mask + value * (1 - mask)
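# Illustrative usage (shapes assumed for the example): masking attention scores
# of shape (batch_size, q_len, k_len) along the key axis so that padded keys are
# pushed towards -inf before a softmax:
#
#     scores = sequence_masking(scores, mask, '-inf', axis=2)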
def batch_gather(params, indices):
"""同tf旧版本的batch_gather
"""
if K.dtype(indices)[:3] != 'int':
indices = K.cast(indices, 'int32')
try:
return tf.gather(params, indices, batch_dims=K.ndim(indices) - 1)
except Exception as e1:
try:
return tf.batch_gather(params, indices)
except Exception as e2:
raise ValueError('%s\n%s\n' % (e1, e2))
def pool1d(
x,
pool_size,
strides=1,
padding='valid',
data_format=None,
pool_mode='max'
):
"""向量序列的pool函数
"""
x = K.expand_dims(x, 1)
x = K.pool2d(
x,
pool_size=(1, pool_size),
strides=(1, strides),
padding=padding,
data_format=data_format,
pool_mode=pool_mode
)
return x[:, 0]
def divisible_temporal_padding(x, n):
"""将一维向量序列右padding到长度能被n整除
"""
r_len = K.shape(x)[1] % n
p_len = K.switch(r_len > 0, n - r_len, 0)
return K.temporal_padding(x, (0, p_len))
def swish(x):
"""swish函数(这样封装过后才有 __name__ 属性)
"""
return tf.nn.swish(x)
def leaky_relu(x, alpha=0.2):
"""leaky relu函数(这样封装过后才有 __name__ 属性)
"""
return tf.nn.leaky_relu(x, alpha=alpha)
class Sinusoidal(keras.initializers.Initializer):
"""Sin-Cos位置向量初始化器
来自:https://arxiv.org/abs/1706.03762
"""
def __call__(self, shape, dtype=None):
"""Sin-Cos形式的位置向量
"""
vocab_size, depth = shape
embeddings = np.zeros(shape)
for pos in range(vocab_size):
for i in range(depth // 2):
theta = pos / np.power(10000, 2. * i / depth)
embeddings[pos, 2 * i] = np.sin(theta)
embeddings[pos, 2 * i + 1] = np.cos(theta)
return embeddings
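# Illustrative usage (the sizes 512 and 768 are assumptions): the initializer is
# meant for a position-embedding weight of shape (max_position, hidden_size):
#
#     position_embedding = keras.layers.Embedding(
#         512, 768, embeddings_initializer=Sinusoidal(), trainable=False
#     )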
def symbolic(f):
"""恒等装饰器(兼容旧版本keras用)
"""
return f
def graph_mode_decorator(f, *args, **kwargs):
"""tf 2.1与之前版本的传参方式不一样,这里做个同步
"""
if tf.__version__ < '2.1':
return _graph_mode_decorator(f, *args, **kwargs)
else:
return _graph_mode_decorator(f, args, kwargs)
def recompute_grad(call):
"""重计算装饰器(用来装饰Keras层的call函数)
关于重计算,请参考:https://arxiv.org/abs/1604.06174
"""
if not do_recompute:
return call
def inner(self, inputs, **kwargs):
"""定义需要求梯度的函数以及重新定义求梯度过程
(参考自官方自带的tf.recompute_grad函数)
"""
flat_inputs = nest.flatten(inputs)
call_args = tf_inspect.getfullargspec(call).args
for key in ['mask', 'training']:
if key not in call_args and key in kwargs:
del kwargs[key]
def kernel_call():
"""定义前向计算
"""
return call(self, inputs, **kwargs)
def call_and_grad(*inputs):
"""定义前向计算和反向计算
"""
if is_tf_keras:
with tape.stop_recording():
outputs = kernel_call()
outputs = tf.identity(outputs)
else:
outputs = kernel_call()
def grad_fn(doutputs, variables=None):
watches = list(inputs)
if variables is not None:
watches += list(variables)
with tf.GradientTape() as t:
t.watch(watches)
with tf.control_dependencies([doutputs]):
outputs = kernel_call()
grads = t.gradient(
outputs, watches, output_gradients=[doutputs]
)
del t
return grads[:len(inputs)], grads[len(inputs):]
return outputs, grad_fn
if is_tf_keras: # only usable with tf >= 2.0
outputs, grad_fn = call_and_grad(*flat_inputs)
flat_outputs = nest.flatten(outputs)
def actual_grad_fn(*doutputs):
grads = grad_fn(*doutputs, variables=self.trainable_weights)
return grads[0] + grads[1]
watches = flat_inputs + self.trainable_weights
watches = [tf.convert_to_tensor(x) for x in watches]
tape.record_operation(
call.__name__, flat_outputs, watches, actual_grad_fn
)
return outputs
else: # works with keras and tf >= 1.14
return graph_mode_decorator(call_and_grad, *flat_inputs)
return inner
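# Illustrative usage (the layer below is an assumption, not part of this module):
# set RECOMPUTE=1 in the environment before importing bert4keras, then decorate a
# custom layer's call method so its activations are recomputed during backprop:
#
#     class MyDense(keras.layers.Dense):
#         @recompute_grad
#         def call(self, inputs, **kwargs):
#             return super(MyDense, self).call(inputs)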
# Add a symbolic method (decorator) to older keras versions
# so that the code in optimizers.py stays compatible
K.symbolic = getattr(K, 'symbolic', None) or symbolic
custom_objects = {
'gelu_erf': gelu_erf,
'gelu_tanh': gelu_tanh,
'gelu': gelu_erf,
'swish': swish,
'leaky_relu': leaky_relu,
'Sinusoidal': Sinusoidal,
}
keras.utils.get_custom_objects().update(custom_objects)
#! -*- coding: utf-8 -*-
# Tokenization utilities
import unicodedata, re
from bert4keras.snippets import is_string, is_py2
from bert4keras.snippets import open
from bert4keras.snippets import convert_to_unicode
from bert4keras.snippets import truncate_sequences
def load_vocab(dict_path, encoding='utf-8', simplified=False, startswith=None):
"""从bert的词典文件中读取词典
"""
token_dict = {}
with open(dict_path, encoding=encoding) as reader:
for line in reader:
token = line.split()
token = token[0] if token else line.strip()
token_dict[token] = len(token_dict)
if simplified: # filter out redundant tokens
new_token_dict, keep_tokens = {}, []
startswith = startswith or []
for t in startswith:
new_token_dict[t] = len(new_token_dict)
keep_tokens.append(token_dict[t])
for t, _ in sorted(token_dict.items(), key=lambda s: s[1]):
if t not in new_token_dict:
keep = True
if len(t) > 1:
for c in Tokenizer.stem(t):
if (
Tokenizer._is_cjk_character(c) or
Tokenizer._is_punctuation(c)
):
keep = False
break
if keep:
new_token_dict[t] = len(new_token_dict)
keep_tokens.append(token_dict[t])
return new_token_dict, keep_tokens
else:
return token_dict
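# Illustrative usage (the path is an assumption): load a BERT vocab as-is, or
# simplify it while keeping a few special tokens at the front; keep_tokens then
# records the ids of the retained tokens in the original vocabulary:
#
#     token_dict = load_vocab('/path/to/vocab.txt')
#     token_dict, keep_tokens = load_vocab(
#         '/path/to/vocab.txt',
#         simplified=True,
#         startswith=['[PAD]', '[UNK]', '[CLS]', '[SEP]'],
#     )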
def save_vocab(dict_path, token_dict, encoding='utf-8'):
"""将词典(比如精简过的)保存为文件
"""
with open(dict_path, 'w', encoding=encoding) as writer:
for k, v in sorted(token_dict.items(), key=lambda s: s[1]):
writer.write(k + '\n')
class TokenizerBase(object):
"""分词器基类
"""
def __init__(
self,
token_start='[CLS]',
token_end='[SEP]',
pre_tokenize=None,
token_translate=None
):
"""参数说明:
pre_tokenize:外部传入的分词函数,用作对文本进行预分词。如果传入
pre_tokenize,则先执行pre_tokenize(text),然后在它
的基础上执行原本的tokenize函数;
token_translate:映射字典,主要用在tokenize之后,将某些特殊的token
替换为对应的token。
"""
self._token_pad = '[PAD]'
self._token_unk = '[UNK]'
self._token_mask = '[MASK]'
self._token_start = token_start
self._token_end = token_end
self._pre_tokenize = pre_tokenize
self._token_translate = token_translate or {}
self._token_translate_inv = {
v: k
for k, v in self._token_translate.items()
}
def tokenize(self, text, maxlen=None):
"""分词函数
"""
tokens = [
self._token_translate.get(token) or token
for token in self._tokenize(text)
]
if self._token_start is not None:
tokens.insert(0, self._token_start)
if self._token_end is not None:
tokens.append(self._token_end)
if maxlen is not None:
index = int(self._token_end is not None) + 1
truncate_sequences(maxlen, -index, tokens)
return tokens
def token_to_id(self, token):
"""token转换为对应的id
"""
raise NotImplementedError
def tokens_to_ids(self, tokens):
"""token序列转换为对应的id序列
"""
return [self.token_to_id(token) for token in tokens]
def encode(
self,
first_text,
second_text=None,
maxlen=None,
pattern='S*E*E',
truncate_from='right'
):
"""输出文本对应token id和segment id
"""
if is_string(first_text):
first_tokens = self.tokenize(first_text)
else:
first_tokens = first_text
if second_text is None:
second_tokens = None
elif is_string(second_text):
second_tokens = self.tokenize(second_text)
else:
second_tokens = second_text
if maxlen is not None:
if truncate_from == 'right':
index = -int(self._token_end is not None) - 1
elif truncate_from == 'left':
index = int(self._token_start is not None)
else:
index = truncate_from
if second_text is not None and pattern == 'S*E*E':
maxlen += 1
truncate_sequences(maxlen, index, first_tokens, second_tokens)
first_token_ids = self.tokens_to_ids(first_tokens)
first_segment_ids = [0] * len(first_token_ids)
if second_text is not None:
if pattern == 'S*E*E':
idx = int(bool(self._token_start))
second_tokens = second_tokens[idx:]
second_token_ids = self.tokens_to_ids(second_tokens)
second_segment_ids = [1] * len(second_token_ids)
first_token_ids.extend(second_token_ids)
first_segment_ids.extend(second_segment_ids)
return first_token_ids, first_segment_ids
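# Illustrative usage (the texts and maxlen are assumptions): encoding a sentence
# pair returns the concatenated token ids plus segment ids, where segment 1 marks
# the second sentence:
#
#     token_ids, segment_ids = tokenizer.encode(u'text a', u'text b', maxlen=64)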
def id_to_token(self, i):
"""id序列为对应的token
"""
raise NotImplementedError
def ids_to_tokens(self, ids):
"""id序列转换为对应的token序列
"""
return [self.id_to_token(i) for i in ids]
def decode(self, ids):
"""转为可读文本
"""
raise NotImplementedError
def _tokenize(self, text):
"""基本分词函数
"""
raise NotImplementedError
class Tokenizer(TokenizerBase):
"""Bert原生分词器
纯Python实现,代码修改自keras_bert的tokenizer实现
"""
def __init__(
self, token_dict, do_lower_case=False, word_maxlen=200, **kwargs
):
super(Tokenizer, self).__init__(**kwargs)
if is_string(token_dict):
token_dict = load_vocab(token_dict)
self._do_lower_case = do_lower_case
self._token_dict = token_dict
self._token_dict_inv = {v: k for k, v in token_dict.items()}
self._vocab_size = len(token_dict)
self._word_maxlen = word_maxlen
for token in ['pad', 'unk', 'mask', 'start', 'end']:
try:
_token_id = token_dict[getattr(self, '_token_%s' % token)]
setattr(self, '_token_%s_id' % token, _token_id)
except:
pass
def token_to_id(self, token):
"""token转换为对应的id
"""
return self._token_dict.get(token, self._token_unk_id)
def id_to_token(self, i):
"""id转换为对应的token
"""
return self._token_dict_inv[i]
def decode(self, ids, tokens=None):
"""转为可读文本
"""
tokens = tokens or self.ids_to_tokens(ids)
tokens = [token for token in tokens if not self._is_special(token)]
text, flag = '', False
for i, token in enumerate(tokens):
if token[:2] == '##':
text += token[2:]
elif len(token) == 1 and self._is_cjk_character(token):
text += token
elif len(token) == 1 and self._is_punctuation(token):
text += token
text += ' '
elif i > 0 and self._is_cjk_character(text[-1]):
text += token
else:
text += ' '
text += token
text = re.sub(' +', ' ', text)
text = re.sub('\' (re|m|s|t|ve|d|ll) ', '\'\\1 ', text)
punctuation = self._cjk_punctuation() + '+-/={(<['
punctuation_regex = '|'.join([re.escape(p) for p in punctuation])
punctuation_regex = '(%s) ' % punctuation_regex
text = re.sub(punctuation_regex, '\\1', text)
text = re.sub('(\d\.) (\d)', '\\1\\2', text)
return text.strip()
def _tokenize(self, text, pre_tokenize=True):
"""基本分词函数
"""
if self._do_lower_case:
if is_py2:
text = unicode(text)
text = text.lower()
text = unicodedata.normalize('NFD', text)
text = ''.join([
ch for ch in text if unicodedata.category(ch) != 'Mn'
])
if pre_tokenize and self._pre_tokenize is not None:
tokens = []
for token in self._pre_tokenize(text):
if token in self._token_dict:
tokens.append(token)
else:
tokens.extend(self._tokenize(token, False))
return tokens
spaced = ''
for ch in text:
if self._is_punctuation(ch) or self._is_cjk_character(ch):
spaced += ' ' + ch + ' '
elif self._is_space(ch):
spaced += ' '
elif ord(ch) == 0 or ord(ch) == 0xfffd or self._is_control(ch):
continue
else:
spaced += ch
tokens = []
for word in spaced.strip().split():
tokens.extend(self._word_piece_tokenize(word))
return tokens
def _word_piece_tokenize(self, word):
"""word内分成subword
"""
if len(word) > self._word_maxlen:
return [word]
tokens, start, end = [], 0, 0
while start < len(word):
end = len(word)
while end > start:
sub = word[start:end]
if start > 0:
sub = '##' + sub
if sub in self._token_dict:
break
end -= 1
if start == end:
return [word]
else:
tokens.append(sub)
start = end
return tokens
@staticmethod
def stem(token):
"""获取token的“词干”(如果是##开头,则自动去掉##)
"""
if token[:2] == '##':
return token[2:]
else:
return token
@staticmethod
def _is_space(ch):
"""空格类字符判断
"""
return ch == ' ' or ch == '\n' or ch == '\r' or ch == '\t' or \
unicodedata.category(ch) == 'Zs'
@staticmethod
def _is_punctuation(ch):
"""标点符号类字符判断(全/半角均在此内)
提醒:unicodedata.category这个函数在py2和py3下的
表现可能不一样,比如u'§'字符,在py2下的结果为'So',
在py3下的结果是'Po'。
"""
code = ord(ch)
return 33 <= code <= 47 or \
58 <= code <= 64 or \
91 <= code <= 96 or \
123 <= code <= 126 or \
unicodedata.category(ch).startswith('P')
@staticmethod
def _cjk_punctuation():
return u'\uff02\uff03\uff04\uff05\uff06\uff07\uff08\uff09\uff0a\uff0b\uff0c\uff0d\uff0f\uff1a\uff1b\uff1c\uff1d\uff1e\uff20\uff3b\uff3c\uff3d\uff3e\uff3f\uff40\uff5b\uff5c\uff5d\uff5e\uff5f\uff60\uff62\uff63\uff64\u3000\u3001\u3003\u3008\u3009\u300a\u300b\u300c\u300d\u300e\u300f\u3010\u3011\u3014\u3015\u3016\u3017\u3018\u3019\u301a\u301b\u301c\u301d\u301e\u301f\u3030\u303e\u303f\u2013\u2014\u2018\u2019\u201b\u201c\u201d\u201e\u201f\u2026\u2027\ufe4f\ufe51\ufe54\u00b7\uff01\uff1f\uff61\u3002'
@staticmethod
def _is_cjk_character(ch):
"""CJK类字符判断(包括中文字符也在此列)
参考:https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
"""
code = ord(ch)
return 0x4E00 <= code <= 0x9FFF or \
0x3400 <= code <= 0x4DBF or \
0x20000 <= code <= 0x2A6DF or \
0x2A700 <= code <= 0x2B73F or \
0x2B740 <= code <= 0x2B81F or \
0x2B820 <= code <= 0x2CEAF or \
0xF900 <= code <= 0xFAFF or \
0x2F800 <= code <= 0x2FA1F
@staticmethod
def _is_control(ch):
"""控制类字符判断
"""
return unicodedata.category(ch) in ('Cc', 'Cf')
@staticmethod
def _is_special(ch):
"""判断是不是有特殊含义的符号
"""
return bool(ch) and (ch[0] == '[') and (ch[-1] == ']')
def rematch(self, text, tokens):
"""给出原始的text和tokenize后的tokens的映射关系
"""
if is_py2:
text = unicode(text)
if self._do_lower_case:
text = text.lower()
normalized_text, char_mapping = '', []
for i, ch in enumerate(text):
if self._do_lower_case:
ch = unicodedata.normalize('NFD', ch)
ch = ''.join([c for c in ch if unicodedata.category(c) != 'Mn'])
ch = ''.join([
c for c in ch
if not (ord(c) == 0 or ord(c) == 0xfffd or self._is_control(c))
])
normalized_text += ch
char_mapping.extend([i] * len(ch))
text, token_mapping, offset = normalized_text, [], 0
for token in tokens:
if self._is_special(token):
token_mapping.append([])
else:
token = self.stem(token)
start = text[offset:].index(token) + offset
end = start + len(token)
token_mapping.append(char_mapping[start:end])
offset = end
return token_mapping
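# Illustrative usage (not part of the class): rematch is handy for span-extraction
# tasks, since mapping[i] lists the character offsets of the original text that
# are covered by tokens[i]:
#
#     tokens = tokenizer.tokenize(text)
#     mapping = tokenizer.rematch(text, tokens)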
class SpTokenizer(TokenizerBase):
"""基于SentencePiece模型的封装,使用上跟Tokenizer基本一致。
"""
def __init__(self, sp_model_path, **kwargs):
super(SpTokenizer, self).__init__(**kwargs)
import sentencepiece as spm
self.sp_model = spm.SentencePieceProcessor()
self.sp_model.Load(sp_model_path)
self._token_pad = self.sp_model.id_to_piece(self.sp_model.pad_id())
self._token_unk = self.sp_model.id_to_piece(self.sp_model.unk_id())
self._vocab_size = self.sp_model.get_piece_size()
for token in ['pad', 'unk', 'mask', 'start', 'end']:
try:
_token = getattr(self, '_token_%s' % token)
_token_id = self.sp_model.piece_to_id(_token)
setattr(self, '_token_%s_id' % token, _token_id)
except:
pass
def token_to_id(self, token):
"""token转换为对应的id
"""
return self.sp_model.piece_to_id(token)
def id_to_token(self, i):
"""id转换为对应的token
"""
if i < self._vocab_size:
return self.sp_model.id_to_piece(i)
else:
return ''
def decode(self, ids):
"""转为可读文本
"""
tokens = [
self._token_translate_inv.get(token) or token
for token in self.ids_to_tokens(ids)
]
text = self.sp_model.decode_pieces(tokens)
return convert_to_unicode(text)
def _tokenize(self, text):
"""基本分词函数
"""
if self._pre_tokenize is not None:
text = ' '.join(self._pre_tokenize(text))
tokens = self.sp_model.encode_as_pieces(text)
return tokens
def _is_special(self, i):
"""判断是不是有特殊含义的符号
"""
return self.sp_model.is_control(i) or \
self.sp_model.is_unknown(i) or \
self.sp_model.is_unused(i)
def _is_decodable(self, i):
"""判断是否应该被解码输出
"""
return (i < self._vocab_size) and not self._is_special(i)
# Collection of examples
Note: the examples on GitHub are only guaranteed to work with the latest bert4keras on GitHub. If something breaks, please try upgrading bert4keras first.
## Overview
- [basic_extract_features.py](https://github.com/bojone/bert4keras/tree/master/examples/basic_extract_features.py): basic test; extracts BERT's encoded sequence for a sentence.
- [basic_gibbs_sampling_via_mlm.py](https://github.com/bojone/bert4keras/tree/master/examples/basic_gibbs_sampling_via_mlm.py): basic test; random text generation with BERT + Gibbs sampling, see [here](https://kexue.fm/archives/8119)
- [basic_language_model_cpm_lm.py](https://github.com/bojone/bert4keras/tree/master/examples/basic_language_model_cpm_lm.py): basic test; checks the generation quality of [CPM_LM](https://github.com/TsinghuaAI/CPM-Generate).
- [basic_language_model_gpt2_ml.py](https://github.com/bojone/bert4keras/tree/master/examples/basic_language_model_gpt2_ml.py): basic test; checks the generation quality of [GPT2_ML](https://github.com/imcaspar/gpt2-ml).
- [basic_language_model_nezha_gen_gpt.py](https://github.com/bojone/bert4keras/tree/master/examples/basic_language_model_nezha_gen_gpt.py): basic test; checks the generation quality of [GPT Base (a.k.a. NEZHA-GEN)](https://github.com/huawei-noah/Pretrained-Language-Model/tree/master/NEZHA-Gen-TensorFlow).
- [basic_make_uncased_model_cased.py](https://github.com/bojone/bert4keras/tree/master/examples/basic_make_uncased_model_cased.py): basic test; makes an uncased model case-sensitive through a simple vocabulary tweak.
- [basic_masked_language_model.py](https://github.com/bojone/bert4keras/tree/master/examples/basic_masked_language_model.py): basic test; checks BERT's MLM head.
- [basic_simple_web_serving_simbert.py](https://github.com/bojone/bert4keras/tree/master/examples/basic_simple_web_serving_simbert.py): basic test; tries the built-in WebServing (wraps a model as a web API).
- [task_conditional_language_model.py](https://github.com/bojone/bert4keras/tree/master/examples/task_conditional_language_model.py): task example; a conditional language model built from BERT + [Conditional Layer Normalization](https://kexue.fm/archives/7124).
- [task_iflytek_adversarial_training.py](https://github.com/bojone/bert4keras/tree/master/examples/task_iflytek_adversarial_training.py): task example; improves classification with [adversarial training](https://kexue.fm/archives/7234).
- [task_iflytek_bert_of_theseus.py](https://github.com/bojone/bert4keras/tree/master/examples/task_iflytek_bert_of_theseus.py): task example; model compression via [BERT-of-Theseus](https://kexue.fm/archives/7575).
- [task_iflytek_gradient_penalty.py](https://github.com/bojone/bert4keras/tree/master/examples/task_iflytek_gradient_penalty.py): task example; improves classification with a [gradient penalty](https://kexue.fm/archives/7234), which can be viewed as another form of adversarial training.
- [task_image_caption.py](https://github.com/bojone/bert4keras/tree/master/examples/task_image_caption.py): task example; image caption generation with BERT + [Conditional Layer Normalization](https://kexue.fm/archives/7124) + an ImageNet-pretrained model.
- [task_language_model.py](https://github.com/bojone/bert4keras/tree/master/examples/task_language_model.py): task example; loads pretrained BERT weights as an unconditional language model, effectively equivalent to GPT.
- [task_language_model_chinese_chess.py](https://github.com/bojone/bert4keras/tree/master/examples/task_language_model_chinese_chess.py): task example; plays Chinese chess GPT-style, see the [blog post](https://kexue.fm/archives/7877)
- [task_question_answer_generation_by_seq2seq.py](https://github.com/bojone/bert4keras/tree/master/examples/task_question_answer_generation_by_seq2seq.py): task example; [automatic construction of question-answer pairs](https://kexue.fm/archives/7630) with a [UniLM](https://kexue.fm/archives/6933)-style Seq2Seq model; autoregressive text generation.
- [task_reading_comprehension_by_mlm.py](https://github.com/bojone/bert4keras/tree/master/examples/task_reading_comprehension_by_mlm.py): task example; [reading-comprehension QA](https://kexue.fm/archives/7148) with an MLM model; a simple form of non-autoregressive text generation.
- [task_reading_comprehension_by_seq2seq.py](https://github.com/bojone/bert4keras/tree/master/examples/task_reading_comprehension_by_seq2seq.py): task example; [reading-comprehension QA](https://kexue.fm/archives/7115) with a [UniLM](https://kexue.fm/archives/6933)-style Seq2Seq model; autoregressive text generation.
- [task_relation_extraction.py](https://github.com/bojone/bert4keras/tree/master/examples/task_relation_extraction.py): task example; [relation extraction](https://kexue.fm/archives/7161) with BERT plus a custom "half pointer, half tagging" scheme.
- [task_sentence_similarity_lcqmc.py](https://github.com/bojone/bert4keras/tree/master/examples/task_sentence_similarity_lcqmc.py): task example; sentence-pair classification.
- [task_sentiment_albert.py](https://github.com/bojone/bert4keras/tree/master/examples/task_sentiment_albert.py): task example; sentiment classification with an ALBERT model.
- [task_sentiment_integrated_gradients.py](https://github.com/bojone/bert4keras/tree/master/examples/task_sentiment_integrated_gradients.py): task example; visualizes a sentiment classifier with [integrated gradients](https://kexue.fm/archives/7533).
- [task_sentiment_virtual_adversarial_training.py](https://github.com/bojone/bert4keras/tree/master/examples/task_sentiment_virtual_adversarial_training.py): task example; semi-supervised learning with [virtual adversarial training](https://kexue.fm/archives/7466) to improve sentiment classification with few labels.
- [task_seq2seq_ape210k_math_word_problem.py](https://github.com/bojone/bert4keras/tree/master/examples/task_seq2seq_ape210k_math_word_problem.py): task example; solves primary-school math word problems (formula generation) with a [UniLM](https://kexue.fm/archives/6933)-style Seq2Seq model, see [here](https://kexue.fm/archives/7809)
- [task_seq2seq_autotitle.py](https://github.com/bojone/bert4keras/tree/master/examples/task_seq2seq_autotitle.py): task example; news headline generation with a [UniLM](https://kexue.fm/archives/6933)-style Seq2Seq model.
- [task_seq2seq_autotitle_csl.py](https://github.com/bojone/bert4keras/tree/master/examples/task_seq2seq_autotitle_csl.py): task example; paper title generation with a [UniLM](https://kexue.fm/archives/6933)-style Seq2Seq model, evaluation code included.
- [task_seq2seq_autotitle_csl_mt5.py](https://github.com/bojone/bert4keras/tree/master/examples/task_seq2seq_autotitle_csl_mt5.py): task example; paper title generation with a [multilingual T5](https://kexue.fm/archives/7867) Seq2Seq model, evaluation code included.
- [task_seq2seq_autotitle_multigpu.py](https://github.com/bojone/bert4keras/tree/master/examples/task_seq2seq_autotitle_multigpu.py): task example; news headline generation with a [UniLM](https://kexue.fm/archives/6933)-style Seq2Seq model, single-machine multi-GPU version.
- [task_sequence_labeling_cws_crf.py](https://github.com/bojone/bert4keras/tree/master/examples/task_sequence_labeling_cws_crf.py): task example; Chinese word segmentation with BERT + [CRF](https://kexue.fm/archives/7196).
- [task_sequence_labeling_ner_crf.py](https://github.com/bojone/bert4keras/tree/master/examples/task_sequence_labeling_ner_crf.py): task example; Chinese NER with BERT + [CRF](https://kexue.fm/archives/7196).
#! -*- coding: utf-8 -*-
# Sanity check: feature extraction
import numpy as np
from bert4keras.backend import keras
from bert4keras.models import build_transformer_model
from bert4keras.tokenizers import Tokenizer
from bert4keras.snippets import to_array
config_path = '/root/kg/bert/chinese_L-12_H-768_A-12/bert_config.json'
checkpoint_path = '/root/kg/bert/chinese_L-12_H-768_A-12/bert_model.ckpt'
dict_path = '/root/kg/bert/chinese_L-12_H-768_A-12/vocab.txt'
tokenizer = Tokenizer(dict_path, do_lower_case=True) # build the tokenizer
model = build_transformer_model(config_path, checkpoint_path) # build the model and load the weights
# encoding test
token_ids, segment_ids = tokenizer.encode(u'语言模型')
token_ids, segment_ids = to_array([token_ids], [segment_ids])
print('\n ===== predicting =====\n')
print(model.predict([token_ids, segment_ids]))
"""
Output:
[[[-0.63251007 0.2030236 0.07936534 ... 0.49122632 -0.20493352
0.2575253 ]
[-0.7588351 0.09651865 1.0718756 ... -0.6109694 0.04312154
0.03881441]
[ 0.5477043 -0.792117 0.44435206 ... 0.42449304 0.41105673
0.08222899]
[-0.2924238 0.6052722 0.49968526 ... 0.8604137 -0.6533166
0.5369075 ]
[-0.7473459 0.49431565 0.7185162 ... 0.3848612 -0.74090636
0.39056838]
[-0.8741375 -0.21650358 1.338839 ... 0.5816864 -0.4373226
0.56181806]]]
"""
print('\n ===== reloading and predicting =====\n')
model.save('test.model')
del model
model = keras.models.load_model('test.model')
print(model.predict([token_ids, segment_ids]))
#! -*- coding: utf-8 -*-
# Sanity check: Gibbs sampling with an MLM
from tqdm import tqdm
import numpy as np
from bert4keras.models import build_transformer_model
from bert4keras.tokenizers import Tokenizer
from bert4keras.snippets import to_array
config_path = '/root/kg/bert/chinese_L-12_H-768_A-12/bert_config.json'
checkpoint_path = '/root/kg/bert/chinese_L-12_H-768_A-12/bert_model.ckpt'
dict_path = '/root/kg/bert/chinese_L-12_H-768_A-12/vocab.txt'
tokenizer = Tokenizer(dict_path, do_lower_case=True) # build the tokenizer
model = build_transformer_model(
config_path=config_path, checkpoint_path=checkpoint_path, with_mlm=True
) # build the model and load the weights
sentences = []
init_sent = u'科学技术是第一生产力。' # seed sentence, or None
minlen, maxlen = 8, 32
steps = 10000
converged_steps = 1000
vocab_size = tokenizer._vocab_size
if init_sent is None:
length = np.random.randint(minlen, maxlen + 1)
tokens = ['[CLS]'] + ['[MASK]'] * length + ['[SEP]']
token_ids = tokenizer.tokens_to_ids(tokens)
segment_ids = [0] * len(token_ids)
else:
token_ids, segment_ids = tokenizer.encode(init_sent)
length = len(token_ids) - 2
for _ in tqdm(range(steps), desc='Sampling'):
# Gibbs sampling step: randomly mask one token, then resample it from the MLM.
i = np.random.choice(length) + 1
token_ids[i] = tokenizer._token_mask_id
probas = model.predict(to_array([token_ids], [segment_ids]))[0, i]
token = np.random.choice(vocab_size, p=probas)
token_ids[i] = token
sentences.append(tokenizer.decode(token_ids))
print(u'A sample of the random generations:')
for _ in range(10):
print(np.random.choice(sentences[converged_steps:]))
#! -*- coding: utf-8 -*-
# Basic test: Tsinghua's open-source Chinese GPT-2 model (2.6 billion parameters)
# Project: https://github.com/TsinghuaAI/CPM-Generate
# Blog post: https://kexue.fm/archives/7912
import numpy as np
from bert4keras.models import build_transformer_model
from bert4keras.tokenizers import SpTokenizer
from bert4keras.snippets import AutoRegressiveDecoder
from bert4keras.snippets import uniout
import jieba
jieba.initialize()
# model paths
config_path = '/root/kg/bert/CPM_LM_2.6B_TF/config.json'
checkpoint_path = '/root/kg/bert/CPM_LM_2.6B_TF/model.ckpt'
spm_path = '/root/kg/bert/CPM_LM_2.6B_TF/chinese_vocab.model'
def pre_tokenize(text):
"""分词前处理函数
"""
return [
w.replace(' ', u'\u2582').replace('\n', u'\u2583')
for w in jieba.cut(text, cut_all=False)
]
tokenizer = SpTokenizer(
spm_path,
token_start=None,
token_end=None,
pre_tokenize=pre_tokenize,
token_translate={u'\u2583': '<cls>'}
) # build the tokenizer
model = build_transformer_model(
config_path=config_path, checkpoint_path=checkpoint_path, model='gpt2'
) # build the model and load the weights
class TextExpansion(AutoRegressiveDecoder):
"""基于随机采样的文本续写
"""
@AutoRegressiveDecoder.wraps(default_rtype='probas')
def predict(self, inputs, output_ids, states):
token_ids = np.concatenate([inputs[0], output_ids], 1)
return self.last_token(model).predict(token_ids)
def generate(self, text, n=1, topp=0.95, temperature=1):
"""输出结果会有一定的随机性,如果只关心Few Shot效果,
可以考虑将解码方式换为beam search。
"""
token_ids, _ = tokenizer.encode(text)
results = self.random_sample([token_ids],
n,
topp=topp,
temperature=temperature) # random sampling
results = [token_ids + [int(i) for i in ids] for ids in results]
texts = [tokenizer.decode(ids) for ids in results]
return [self.post_replace(text) for text in texts]
def post_replace(self, text):
for s, t in [(' ', ''), (u'\u2582', ' '), (u'\u2583', '\n')]:
text = text.replace(s, t)
return text
text_expansion = TextExpansion(
start_id=None,
end_id=3, # 3 is <cls>, which also acts as the newline marker
maxlen=16,
)
# Commonsense reasoning
# Output for this example: 北京
query = u"""
美国的首都是华盛顿
法国的首都是巴黎
日本的首都是东京
中国的首都是
"""
print(text_expansion.generate(query[1:-1], 1)[0])
# Word translation
# Output for this example: bird
query = u"""
狗 dog
猫 cat
猪 pig
"""
print(text_expansion.generate(query[1:-1], 1)[0])
# Subject extraction
# Output for this example: 杨振宁
query = u"""
从1931年起,华罗庚在清华大学边学习边工作 华罗庚
在一间简陋的房间里,陈景润攻克了“哥德巴赫猜想” 陈景润
在这里,丘成桐得到IBM奖学金 丘成桐
杨振宁在粒子物理学、统计力学和凝聚态物理等领域作出里程碑性贡献
"""
print(text_expansion.generate(query[1:-1], 1)[0])
# Triple extraction
# Output for this example: 张红,体重,140斤
query = u"""
姚明的身高是211cm,是很多人心目中的偶像。 ->姚明,身高,211cm
毛泽东是绍兴人,早年在长沙读书。->毛泽东,出生地,绍兴
虽然周杰伦在欧洲办的婚礼,但是他是土生土长的中国人->周杰伦,国籍,中国
小明出生于武汉,但是却不喜欢在武汉生成,长大后去了北京。->小明,出生地,武汉
吴亦凡是很多人的偶像,但是他却是加拿大人,另很多人失望->吴亦凡,国籍,加拿大
武耀的生日在5月8号,这一天,大家都为他庆祝了生日->武耀,生日,5月8号
《青花瓷》是周杰伦最得意的一首歌。->周杰伦,作品,《青花瓷》
北京是中国的首都。->中国,首都,北京
蒋碧的家乡在盘龙城,毕业后去了深圳工作。->蒋碧,籍贯,盘龙城
上周我们和王立一起去了他的家乡云南玩昨天才回到了武汉。->王立,籍贯,云南
昨天11月17号,我和朋友一起去了海底捞,期间服务员为我的朋友刘章庆祝了生日。->刘章,生日,11月17号
张红的体重达到了140斤,她很苦恼。->
"""
print(text_expansion.generate(query[1:-1], 1)[0])
#! -*- coding: utf-8 -*-
# Sanity check: MLM
import numpy as np
from bert4keras.models import build_transformer_model
from bert4keras.tokenizers import Tokenizer
from bert4keras.snippets import to_array
config_path = '/root/kg/bert/chinese_L-12_H-768_A-12/bert_config.json'
checkpoint_path = '/root/kg/bert/chinese_L-12_H-768_A-12/bert_model.ckpt'
dict_path = '/root/kg/bert/chinese_L-12_H-768_A-12/vocab.txt'
tokenizer = Tokenizer(dict_path, do_lower_case=True) # build the tokenizer
model = build_transformer_model(
config_path=config_path, checkpoint_path=checkpoint_path, with_mlm=True
) # build the model and load the weights
token_ids, segment_ids = tokenizer.encode(u'科学技术是第一生产力')
# mask out "技术"
token_ids[3] = token_ids[4] = tokenizer._token_mask_id
token_ids, segment_ids = to_array([token_ids], [segment_ids])
# predict the masked positions with the MLM head
probas = model.predict([token_ids, segment_ids])[0]
print(tokenizer.decode(probas[3:5].argmax(axis=1))) # the result is exactly "技术"