Commit b045ce7d authored by Hongkun Yu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 272777104
parent 0f176f6f
@@ -43,6 +43,10 @@ CLS_ID = special_symbols["<cls>"]
 SEP_ID = special_symbols["<sep>"]
 MASK_ID = special_symbols["<mask>"]
 EOD_ID = special_symbols["<eod>"]
+SEG_ID_P = 0
+SEG_ID_Q = 1
+SEG_ID_CLS = 2
+SEG_ID_PAD = 3
 
 
 def file_based_input_fn_builder(input_file, name_to_features, batch_size,
...
@@ -48,8 +48,11 @@ FLAGS = flags.FLAGS
 def get_pretrainxlnet_model(model_config, run_config):
-  model = modeling.PretrainingXLNetModel(model_config, run_config, name="model")
-  return model
+  return modeling.PretrainingXLNetModel(
+      use_proj=True,
+      xlnet_config=model_config,
+      run_config=run_config,
+      name="model")
 
 
 def main(unused_argv):
@@ -69,8 +72,7 @@ def main(unused_argv):
   if strategy:
     logging.info("***** Number of cores used : %d",
                  strategy.num_replicas_in_sync)
-    logging.info("***** Number of hosts used : %d",
-                 num_hosts)
+    logging.info("***** Number of hosts used : %d", num_hosts)
   train_input_fn = functools.partial(
       data_utils.get_pretrain_input_data, FLAGS.train_batch_size, FLAGS.seq_len,
       strategy, FLAGS.train_tfrecord_path, FLAGS.reuse_len, FLAGS.perm_size,
...
@@ -36,11 +36,6 @@ from official.nlp.xlnet import preprocess_utils
 SPIECE_UNDERLINE = u"▁"
 
-SEG_ID_P = 0
-SEG_ID_Q = 1
-SEG_ID_CLS = 2
-SEG_ID_PAD = 3
-
 
 class InputFeatures(object):
   """A single set of features of data."""
@@ -705,28 +700,28 @@ def convert_examples_to_features(examples, sp_model, max_seq_length, doc_stride,
                                              split_token_index)
       token_is_max_context[len(tokens)] = is_max_context
       tokens.append(all_doc_tokens[split_token_index])
-      segment_ids.append(SEG_ID_P)
+      segment_ids.append(data_utils.SEG_ID_P)
       p_mask.append(0)
 
     paragraph_len = len(tokens)
 
     tokens.append(data_utils.SEP_ID)
-    segment_ids.append(SEG_ID_P)
+    segment_ids.append(data_utils.SEG_ID_P)
     p_mask.append(1)
 
     # note(zhiliny): we put P before Q
     # because during pretraining, B is always shorter than A
     for token in query_tokens:
       tokens.append(token)
-      segment_ids.append(SEG_ID_Q)
+      segment_ids.append(data_utils.SEG_ID_Q)
       p_mask.append(1)
     tokens.append(data_utils.SEP_ID)
-    segment_ids.append(SEG_ID_Q)
+    segment_ids.append(data_utils.SEG_ID_Q)
     p_mask.append(1)
 
     cls_index = len(segment_ids)
     tokens.append(data_utils.CLS_ID)
-    segment_ids.append(SEG_ID_CLS)
+    segment_ids.append(data_utils.SEG_ID_CLS)
     p_mask.append(0)
 
     input_ids = tokens
@@ -739,7 +734,7 @@ def convert_examples_to_features(examples, sp_model, max_seq_length, doc_stride,
     while len(input_ids) < max_seq_length:
       input_ids.append(0)
       input_mask.append(1)
-      segment_ids.append(SEG_ID_PAD)
+      segment_ids.append(data_utils.SEG_ID_PAD)
       p_mask.append(1)
 
     assert len(input_ids) == max_seq_length
...
@@ -30,7 +30,6 @@ def create_run_config(is_training, is_finetune, flags):
   kwargs = dict(
       is_training=is_training,
       use_tpu=flags.use_tpu,
-      use_bfloat16=flags.use_bfloat16,
       dropout=flags.dropout,
       dropout_att=flags.dropout_att,
       init_method=flags.init_method,
@@ -49,6 +48,7 @@ def create_run_config(is_training, is_finetune, flags):
   return RunConfig(**kwargs)
 
 
+# TODO(hongkuny): refactor XLNetConfig and RunConfig.
 class XLNetConfig(object):
   """Configs for XLNet model.
 
@@ -131,7 +131,6 @@ class RunConfig(object):
   def __init__(self,
               is_training,
               use_tpu,
-               use_bfloat16,
               dropout,
               dropout_att,
               init_method='normal',
@@ -141,13 +140,13 @@ class RunConfig(object):
               reuse_len=None,
               bi_data=False,
               clamp_len=-1,
-               same_length=False):
+               same_length=False,
+               use_cls_mask=True):
     """Initializes RunConfig.
 
     Args:
       is_training: bool, whether in training mode.
      use_tpu: bool, whether TPUs are used.
-      use_bfloat16: bool, use bfloat16 instead of float32.
      dropout: float, dropout rate.
      dropout_att: float, dropout rate on attention probabilities.
      init_method: str, the initialization scheme, either "normal" or "uniform".
@@ -164,6 +163,7 @@ class RunConfig(object):
        -1 means no clamping.
      same_length: bool, whether to use the same attention length
        for each token.
+      use_cls_mask: bool, whether to introduce cls mask.
    """
 
    self.init_method = init_method
@@ -173,9 +173,9 @@ class RunConfig(object):
    self.dropout = dropout
    self.dropout_att = dropout_att
    self.use_tpu = use_tpu
-    self.use_bfloat16 = use_bfloat16
    self.mem_len = mem_len
    self.reuse_len = reuse_len
    self.bi_data = bi_data
    self.clamp_len = clamp_len
    self.same_length = same_length
+    self.use_cls_mask = use_cls_mask
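For orientation, here is a minimal sketch of how a RunConfig would be constructed after this change: the use_bfloat16 argument is gone and use_cls_mask is passed explicitly. The values below are illustrative, not taken from the commit.

# Usage sketch; assumes the RunConfig class defined in this file.
# All values are illustrative.
pretrain_run_config = RunConfig(
    is_training=True,
    use_tpu=False,
    dropout=0.1,
    dropout_att=0.1,
    init_method='normal',
    mem_len=0,
    reuse_len=256,
    bi_data=True,
    clamp_len=-1,
    same_length=False,
    use_cls_mask=True)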
@@ -23,6 +23,7 @@ import copy
 import numpy as np
 import tensorflow as tf
 
+from official.nlp.xlnet import data_utils
 
 def gelu(x):
@@ -96,19 +97,6 @@ def _cache_mem(curr_out, prev_mem, mem_len, reuse_len=None):
   return tf.keras.backend.stop_gradient(new_mem)
 
 
-def embedding_lookup(lookup_table, x, use_tpu=True):
-  """Looks up words embeddings for input id tensor."""
-  if use_tpu:
-    n_token = tf.shape(lookup_table)[0]
-    one_hot_idx = tf.one_hot(x, n_token)
-    if one_hot_idx.shape.ndims == 2:
-      return tf.einsum('nd,in->id', lookup_table, one_hot_idx)
-    else:
-      return tf.einsum('nd,ibn->ibd', lookup_table, one_hot_idx)
-  else:
-    return tf.nn.embedding_lookup(lookup_table, x)
-
-
 def is_special_none_tensor(tensor):
   """Checks if a tensor is a special None Tensor."""
   return tensor.shape.ndims == 0 and tensor.dtype == tf.int32
@@ -169,7 +157,7 @@ class PositionalEmbedding(tf.keras.layers.Layer):
   def build(self, unused_input_shapes):
     """Constructs inversed frequency vector for positional embedding layer."""
-    self.inv_freq = 1.0 / (10000.0 ** (tf.range(0, self.dim, 2.0) / self.dim))
+    self.inv_freq = 1.0 / (10000.0**(tf.range(0, self.dim, 2.0) / self.dim))
     super(PositionalEmbedding, self).build(unused_input_shapes)
 
   def __call__(self, pos_seq, batch_size):
@@ -232,8 +220,12 @@ class RelativeAttention(tf.keras.layers.Layer):
     if seg_mat is None:
       ef = 0
     else:
-      ef = tf.einsum('ibnd,snd->ibns', q_head + r_s_bias, seg_embed)
-      ef = tf.einsum('ijbs,ibns->ijbn', seg_mat, ef)
+      ef = tf.einsum('ibnd,snd->isbn', q_head + r_s_bias, seg_embed)
+      tgt_shape = tf.shape(bd)
+      ef = tf.where(
+          tf.broadcast_to(tf.expand_dims(seg_mat, 3), tgt_shape),
+          tf.broadcast_to(ef[:, 1:, :, :], tgt_shape),
+          tf.broadcast_to(ef[:, :1, :, :], tgt_shape))
 
     # merges attention scores and performs masking
     attn_score = (ac + bd + ef) * self.scale
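For intuition, the snippet below (illustrative only, not part of the commit) shows how the new boolean seg_mat selects, per query/key pair, between the two rows of the segment bias ef via broadcasting; the shapes follow the einsum strings above.

# Standalone sketch of the tf.where-based segment-bias selection.
import tensorflow as tf

qlen, klen, bsz, n_head = 3, 5, 2, 4
# ef as produced by the 'ibnd,snd->isbn' einsum: [qlen, 2, bsz, n_head].
ef = tf.random.normal([qlen, 2, bsz, n_head])
# Boolean segment matrix, [qlen, klen, bsz].
seg_mat = tf.random.uniform([qlen, klen, bsz]) > 0.5
tgt_shape = [qlen, klen, bsz, n_head]

selected = tf.where(
    tf.broadcast_to(tf.expand_dims(seg_mat, 3), tgt_shape),
    tf.broadcast_to(ef[:, 1:, :, :], tgt_shape),  # row 1 where seg_mat is True
    tf.broadcast_to(ef[:, :1, :, :], tgt_shape))  # row 0 where seg_mat is False
print(selected.shape)  # (3, 5, 2, 4)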
@@ -253,8 +245,8 @@ class RelativeAttention(tf.keras.layers.Layer):
 class PositionwiseFF(tf.keras.layers.Layer):
   """Positionwise feed-forward layer."""
 
-  def __init__(self, d_model, d_inner, dropout,
-               kernel_initializer, activation_type, **kwargs):
+  def __init__(self, d_model, d_inner, dropout, kernel_initializer,
+               activation_type, **kwargs):
     super(PositionwiseFF, self).__init__(**kwargs)
     self.d_model = d_model
     self.d_inner = d_inner
@@ -282,10 +274,8 @@ class PositionwiseFF(tf.keras.layers.Layer):
             units=self.d_model,
             kernel_initializer=self.kernel_initializer,
             name='layer_2'))
-    self.inner_dropout = tf.keras.layers.Dropout(rate=self.dropout,
-                                                 name='drop_1')
-    self.output_dropout = tf.keras.layers.Dropout(rate=self.dropout,
-                                                  name='drop_2')
+    self.output_dropout = tf.keras.layers.Dropout(
+        rate=self.dropout, name='drop_2')
     self.output_layer_norm = (
         tf.keras.layers.LayerNormalization(
             name='LayerNorm', axis=-1, epsilon=1e-12))
@@ -295,7 +285,6 @@ class PositionwiseFF(tf.keras.layers.Layer):
     """Implements call() for the layer."""
     output = self.inner_projection_layer(inp)
-    output = self.inner_dropout(output)
     output = self.output_projection_layer(output)
     output = self.output_dropout(output)
     output = self.output_layer_norm(output + inp)
@@ -305,14 +294,11 @@ class PositionwiseFF(tf.keras.layers.Layer):
 class EmbeddingLookup(tf.keras.layers.Layer):
   """Looks up words embeddings for id tensor."""
 
-  def __init__(self,
-               n_token, d_embed, initializer,
-               use_one_hot=False, **kwargs):
+  def __init__(self, n_token, d_embed, initializer, **kwargs):
     super(EmbeddingLookup, self).__init__(**kwargs)
     self.n_token = n_token
     self.d_embed = d_embed
     self.initializer = initializer
-    self.use_one_hot = use_one_hot
 
   def build(self, unused_input_shapes):
     """Implements build() for the layer."""
@@ -325,20 +311,7 @@ class EmbeddingLookup(tf.keras.layers.Layer):
     super(EmbeddingLookup, self).build(unused_input_shapes)
 
   def call(self, inputs):
-    x = inputs
-    if self.use_one_hot:
-      one_hot_idx = tf.one_hot(x, self.n_token, dtype=self.dtype)
-      if one_hot_idx.shape.ndims == 2:
-        return tf.einsum('in,nd->id',
-                         one_hot_idx,
-                         self.lookup_table), self.lookup_table
-      else:
-        return tf.einsum('ibn,nd->ibd',
-                         one_hot_idx,
-                         self.lookup_table), self.lookup_table
-    else:
-      return tf.nn.embedding_lookup(self.lookup_table, x), self.lookup_table
+    return tf.nn.embedding_lookup(self.lookup_table, inputs)
 
 
 class TwoStreamRelativeAttention(tf.keras.layers.Layer):
@@ -356,9 +329,10 @@ class TwoStreamRelativeAttention(tf.keras.layers.Layer):
   def build(self, unused_input_shapes):
     """Implements build() for the layer."""
-    self.scale = 1.0 / (self.d_head ** 0.5)
+    self.scale = 1.0 / (self.d_head**0.5)
     self.attention_projection_layer = tf.keras.layers.Dense(
-        units=self.d_model, use_bias=False,
+        units=self.d_model,
+        use_bias=False,
         kernel_initializer=self.initializer,
         name='o')
     self.attention_probs_dropout = tf.keras.layers.Dropout(
@@ -403,9 +377,8 @@ class TwoStreamRelativeAttention(tf.keras.layers.Layer):
     super(TwoStreamRelativeAttention, self).build(unused_input_shapes)
 
-  def __call__(self, h, g, r, r_w_bias, r_r_bias,
-               seg_mat, r_s_bias, seg_embed, attn_mask_h, attn_mask_g,
-               mems, target_mapping):
+  def __call__(self, h, g, r, r_w_bias, r_r_bias, seg_mat, r_s_bias, seg_embed,
+               attn_mask_h, attn_mask_g, mems, target_mapping):
     inputs = pack_inputs([
         h, g, r, r_w_bias, r_r_bias, seg_mat, r_s_bias, seg_embed, attn_mask_h,
         attn_mask_g, mems, target_mapping
@@ -455,15 +428,17 @@ class TwoStreamRelativeAttention(tf.keras.layers.Layer):
       q_head_g = tf.einsum('mbnd,mlb->lbnd', q_head_g, target_mapping)
 
-      attn_vec_g = self.g_attention_layer(
-          q_head_g, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat, r_w_bias,
-          r_r_bias, r_s_bias, attn_mask_g)
+      attn_vec_g = self.g_attention_layer(q_head_g, k_head_h, v_head_h,
+                                          k_head_r, seg_embed, seg_mat,
+                                          r_w_bias, r_r_bias, r_s_bias,
+                                          attn_mask_g)
 
       attn_vec_g = tf.einsum('lbnd,mlb->mbnd', attn_vec_g, target_mapping)
     else:
-      attn_vec_g = self.g_attention_layer(
-          q_head_g, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat, r_w_bias,
-          r_r_bias, r_s_bias, attn_mask_g)
+      attn_vec_g = self.g_attention_layer(q_head_g, k_head_h, v_head_h,
+                                          k_head_r, seg_embed, seg_mat,
+                                          r_w_bias, r_r_bias, r_s_bias,
+                                          attn_mask_g)
 
     # post processing
@@ -491,7 +466,7 @@ class RelativeMultiheadAttention(tf.keras.layers.Layer):
   def build(self, unused_input_shapes):
     """Implements build() for the layer."""
-    self.scale = 1.0 / (self.d_head ** 0.5)
+    self.scale = 1.0 / (self.d_head**0.5)
     self.output_layer_norm = tf.keras.layers.LayerNormalization(
         name='LayerNorm', axis=-1, epsilon=1e-12)
@@ -555,9 +530,9 @@ class RelativeMultiheadAttention(tf.keras.layers.Layer):
     k_head_r = tf.einsum('ibh,hnd->ibnd', r, self.kr_projection_layer)
 
     # core attention ops
-    attn_vec = self.h_attention_layer(
-        q_head_h, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat, r_w_bias,
-        r_r_bias, r_s_bias, attn_mask)
+    attn_vec = self.h_attention_layer(q_head_h, k_head_h, v_head_h, k_head_r,
+                                      seg_embed, seg_mat, r_w_bias, r_r_bias,
+                                      r_s_bias, attn_mask)
 
     # post processing
@@ -592,7 +567,7 @@ class TransformerXLModel(tf.keras.layers.Layer):
               use_tpu=True,
               reuse_len=None,
               ff_activation='relu',
-               use_bfloat16=False,
+               use_cls_mask=False,
               **kwargs):
     """Initializes TransformerXLModel.
@@ -620,7 +595,7 @@ class TransformerXLModel(tf.keras.layers.Layer):
      reuse_len: int, the number of tokens in the currect batch to be cached and
        reused in the future.
      ff_activation: str, "relu" or "gelu".
-      use_bfloat16: bool, use bfloat16 instead of float32.
+      use_cls_mask: bool, whether to introduce cls mask.
      **kwargs: Other parameters.
    """
@@ -636,7 +611,6 @@ class TransformerXLModel(tf.keras.layers.Layer):
    self.d_inner = d_inner
    self.ff_activation = ff_activation
    self.untie_r = untie_r
-    self.use_bfloat16 = use_bfloat16
    self.use_tpu = use_tpu
    self.dropout = dropout
    self.dropout_att = dropout_att
@@ -646,21 +620,21 @@ class TransformerXLModel(tf.keras.layers.Layer):
    self.bi_data = bi_data
    self.clamp_len = clamp_len
    self.same_length = same_length
+    self.use_cls_mask = use_cls_mask
 
   def build(self, unused_input_shapes):
     """Implements build() for the layer."""
-    self.tf_float = tf.bfloat16 if self.use_bfloat16 else tf.float32
+    self.tf_float = tf.float32
 
-    self.embedding_lookup = EmbeddingLookup(n_token=self.n_token,
-                                            d_embed=self.d_model,
-                                            initializer=self.initializer,
-                                            use_one_hot=self.use_tpu,
-                                            dtype=self.tf_float,
-                                            name='word_embedding')
+    self.embedding_lookup = EmbeddingLookup(
+        n_token=self.n_token,
+        d_embed=self.d_model,
+        initializer=self.initializer,
+        dtype=self.tf_float,
+        name='word_embedding')
 
     self.h_dropout = tf.keras.layers.Dropout(rate=self.dropout)
     self.g_dropout = tf.keras.layers.Dropout(rate=self.dropout)
-    self.output_dropout = tf.keras.layers.Dropout(rate=self.dropout)
 
     if self.untie_r:
       self.r_w_bias = (
@@ -702,11 +676,11 @@ class TransformerXLModel(tf.keras.layers.Layer):
     self.seg_embed = self.add_weight(
         'seg_embed', [self.n_layer, 2, self.n_head, self.d_head],
-        dtype=self.tf_float, initializer=self.initializer)
+        dtype=self.tf_float,
+        initializer=self.initializer)
 
-    self.mask_emb = self.add_weight('mask_emb/mask_emb',
-                                    shape=[1, 1, self.d_model],
-                                    dtype=self.tf_float)
+    self.mask_emb = self.add_weight(
+        'mask_emb/mask_emb', shape=[1, 1, self.d_model], dtype=self.tf_float)
     self.emb_dropout = tf.keras.layers.Dropout(rate=self.dropout)
 
     self.fwd_position_embedding = PositionalEmbedding(self.d_model)
@@ -741,16 +715,16 @@ class TransformerXLModel(tf.keras.layers.Layer):
              d_inner=self.d_inner,
              dropout=self.dropout,
              kernel_initializer=self.initializer,
-              activation_type=self.ff_activation, name='layer_%d/ff'%(i))
-      )
+              activation_type=self.ff_activation,
+              name='layer_%d/ff' % (i)))
      self.h_positionwise_ffn_layers.append(
          PositionwiseFF(
              d_model=self.d_model,
              d_inner=self.d_inner,
              dropout=self.dropout,
              kernel_initializer=self.initializer,
-              activation_type=self.ff_activation, name='layer_%d/ff'%(i))
-      )
+              activation_type=self.ff_activation,
+              name='layer_%d/ff' % (i)))
 
     self.output_dropout = tf.keras.layers.Dropout(rate=self.dropout)
@@ -766,9 +740,15 @@ class TransformerXLModel(tf.keras.layers.Layer):
               inp_q=None):
     # Uses dict to feed inputs into call() in order to keep mems as a python
     # list.
-    inputs = {'inp_k': inp_k, 'seg_id': seg_id, 'input_mask': input_mask,
-              'mems': mems, 'perm_mask': perm_mask,
-              'target_mapping': target_mapping, 'inp_q': inp_q}
+    inputs = {
+        'inp_k': inp_k,
+        'seg_id': seg_id,
+        'input_mask': input_mask,
+        'mems': mems,
+        'perm_mask': perm_mask,
+        'target_mapping': target_mapping,
+        'inp_q': inp_q
+    }
     return super(TransformerXLModel, self).__call__(inputs)
 
   def call(self, inputs):
@@ -827,14 +807,14 @@ class TransformerXLModel(tf.keras.layers.Layer):
     if attn_mask is not None:
       non_tgt_mask = -tf.eye(qlen, dtype=self.tf_float)
-      non_tgt_mask = tf.concat([tf.zeros([qlen, mlen], dtype=self.tf_float),
-                                non_tgt_mask], axis=-1)
-      non_tgt_mask = tf.cast((attn_mask + non_tgt_mask[:, :, None, None]) > 0,
-                             dtype=self.tf_float)
+      non_tgt_mask = tf.concat(
+          [tf.zeros([qlen, mlen], dtype=self.tf_float), non_tgt_mask], axis=-1)
+      non_tgt_mask = tf.cast(
+          (attn_mask + non_tgt_mask[:, :, None, None]) > 0, dtype=self.tf_float)
     else:
       non_tgt_mask = None
 
-    word_emb_k, _ = self.embedding_lookup(inp_k)
+    word_emb_k = self.embedding_lookup(inp_k)
 
     if inp_q is not None:
       if target_mapping is not None:
@@ -855,15 +835,18 @@ class TransformerXLModel(tf.keras.layers.Layer):
       mem_pad = tf.zeros([mlen, bsz], dtype=tf.int32)
-      cat_ids = tf.concat([mem_pad, seg_id], 0)
-
-      # `1` indicates not in the same segment [qlen x klen x bsz]
-      seg_mat = tf.cast(
-          tf.logical_not(tf.equal(seg_id[:, None], cat_ids[None, :])),
-          tf.int32)
-      seg_mat = tf.one_hot(seg_mat, 2, dtype=self.tf_float)
+      cat_id = tf.concat([mem_pad, seg_id], 0)
+      if self.use_cls_mask:
+        # `1` indicates not in the same segment [qlen x klen x bsz]
+        # seg_id: [qlen x bsz] & cat_id: [klen x bsz]
+        cls_mat = tf.logical_or(
+            tf.equal(seg_id, tf.constant([data_utils.SEG_ID_CLS]))[:, None],
+            tf.equal(cat_id, tf.constant([data_utils.SEG_ID_CLS]))[None, :])
+        seg_mat = tf.equal(seg_id[:, None], cat_id[None, :])
+        seg_mat = tf.logical_or(cls_mat, seg_mat)
+      else:
+        seg_mat = tf.logical_not(tf.equal(seg_id[:, None], cat_id[None, :]))
     else:
       seg_mat = None
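A tiny worked example of the new segment-matrix logic (illustrative only, not from the commit): with use_cls_mask enabled, any position whose segment id is SEG_ID_CLS is treated as related to every other position, in addition to exact segment matches.

# Standalone sketch; SEG_ID_CLS is 2, matching the data_utils constants above.
import tensorflow as tf

SEG_ID_CLS = 2
seg_id = tf.constant([[0], [1], [2]])       # [qlen=3, bsz=1]
cat_id = tf.constant([[0], [0], [1], [2]])  # [klen=4, bsz=1] (mems + current)

cls_mat = tf.logical_or(
    tf.equal(seg_id, tf.constant([SEG_ID_CLS]))[:, None],
    tf.equal(cat_id, tf.constant([SEG_ID_CLS]))[None, :])
seg_mat = tf.equal(seg_id[:, None], cat_id[None, :])
seg_mat = tf.logical_or(cls_mat, seg_mat)
# seg_mat has shape [qlen, klen, bsz]; an entry is True where the segments
# match or where either side carries the CLS segment id.
print(seg_mat.shape)  # (3, 4, 1)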
@@ -894,8 +877,8 @@ class TransformerXLModel(tf.keras.layers.Layer):
                                          self.clamp_len)
 
       if bsz is not None:
-        fwd_pos_emb = self.fwd_position_embedding(fwd_pos_seq, bsz//2)
-        bwd_pos_emb = self.bwd_position_embedding(bwd_pos_seq, bsz//2)
+        fwd_pos_emb = self.fwd_position_embedding(fwd_pos_seq, bsz // 2)
+        bwd_pos_emb = self.bwd_position_embedding(bwd_pos_seq, bsz // 2)
       else:
         fwd_pos_emb = self.fwd_position_embedding(fwd_pos_seq, None)
         bwd_pos_emb = self.bwd_position_embedding(bwd_pos_seq, None)
@@ -906,8 +889,8 @@ class TransformerXLModel(tf.keras.layers.Layer):
       if dtype is not None and dtype != tf.float32:
         fwd_pos_seq = tf.cast(fwd_pos_seq, dtype=dtype)
       if self.clamp_len > 0:
-        fwd_pos_seq = tf.clip_by_value(fwd_pos_seq,
-                                       -self.clamp_len, self.lamp_len)
+        fwd_pos_seq = tf.clip_by_value(fwd_pos_seq, -self.clamp_len,
+                                       self.lamp_len)
       pos_emb = self.fwd_position_embedding(fwd_pos_seq, bsz)
@@ -969,9 +952,9 @@ class TransformerXLModel(tf.keras.layers.Layer):
         output_h = h_ffn_layer(output_h)
 
     if inp_q is not None:
-      output = self.output_dropout(output_g)
+      output = output_g
     else:
-      output = self.output_dropout(output_h)
+      output = output_h
 
     return output, new_mems, None
@@ -983,7 +966,7 @@ class PretrainingXLNetModel(tf.keras.Model):
   """
 
-  def __init__(self, xlnet_config, run_config, **kwargs):
+  def __init__(self, use_proj, xlnet_config, run_config, **kwargs):
     super(PretrainingXLNetModel, self).__init__(**kwargs)
     self.run_config = run_config
     self.initializer = _get_initializer(run_config)
@@ -1001,7 +984,6 @@ class PretrainingXLNetModel(tf.keras.Model):
        ff_activation=self.xlnet_config.ff_activation,
        untie_r=self.xlnet_config.untie_r,
        is_training=self.run_config.is_training,
-        use_bfloat16=self.run_config.use_bfloat16,
        use_tpu=self.run_config.use_tpu,
        dropout=self.run_config.dropout,
        dropout_att=self.run_config.dropout_att,
@@ -1010,15 +992,17 @@ class PretrainingXLNetModel(tf.keras.Model):
        bi_data=self.run_config.bi_data,
        clamp_len=self.run_config.clamp_len,
        same_length=self.run_config.same_length,
+        use_cls_mask=self.run_config.use_cls_mask,
        name='transformer')
 
-    self.lmloss_layer = LMLossLayer(n_token=self.xlnet_config.n_token,
-                                    d_model=self.xlnet_config.d_model,
-                                    initializer=self.initializer,
-                                    use_bfloat16=self.run_config.use_bfloat16,
-                                    tie_weight=True,
-                                    bi_data=self.run_config.bi_data,
-                                    use_tpu=self.run_config.use_tpu,
-                                    name='lm_loss')
+    self.lmloss_layer = LMLossLayer(
+        n_token=self.xlnet_config.n_token,
+        d_model=self.xlnet_config.d_model,
+        initializer=self.initializer,
+        tie_weight=True,
+        bi_data=self.run_config.bi_data,
+        use_tpu=self.run_config.use_tpu,
+        use_proj=use_proj,
+        name='lm_loss')
 
   def call(self, features):
     """Implements call() for the layer."""
@@ -1082,7 +1066,6 @@ class ClassificationXLNetModel(tf.keras.Model):
        ff_activation=self.xlnet_config.ff_activation,
        untie_r=self.xlnet_config.untie_r,
        is_training=self.run_config.is_training,
-        use_bfloat16=self.run_config.use_bfloat16,
        use_tpu=self.run_config.use_tpu,
        dropout=self.run_config.dropout,
        dropout_att=self.run_config.dropout_att,
@@ -1133,23 +1116,28 @@ class ClassificationXLNetModel(tf.keras.Model):
 class LMLossLayer(tf.keras.layers.Layer):
   """Layer computing cross entropy loss for language modeling."""
 
-  def __init__(self, n_token, d_model, initializer, use_bfloat16,
-               tie_weight=False, bi_data=True, use_tpu=False, **kwargs):
+  def __init__(self,
+               n_token,
+               d_model,
+               initializer,
+               tie_weight=False,
+               bi_data=True,
+               use_tpu=False,
+               use_proj=False,
+               **kwargs):
     """Constructs LMLoss layer.
 
     Args:
      n_token: Number of tokens in vocabulary.
      d_model: The dimension of model hidden state.
      initializer: Initializer used for parameters.
-      use_bfloat16: Whether to use bfloat16.
      tie_weight: Whether to share weights between embedding lookup layer and
        next-token prediction layer.
-      bi_data: Whether to use bidirectional input pipeline.
-        Usually set to True during pretraining and False during finetuning.
+      bi_data: Whether to use bidirectional input pipeline. Usually set to True
+        during pretraining and False during finetuning.
      use_tpu: bool, whether to use TPU.
+      use_proj: bool, whether to add a projection layer before LM prediction.
      **kwargs: Other parameters.
    """
    super(LMLossLayer, self).__init__(**kwargs)
    self.n_token = n_token
@@ -1159,17 +1147,26 @@ class LMLossLayer(tf.keras.layers.Layer):
    self.tie_weight = tie_weight
    self.bi_data = bi_data
    self.use_tpu = use_tpu
-    self.use_bfloat16 = use_bfloat16
+    self.use_proj = use_proj
 
   def build(self, unused_input_shapes):
     """Implements build() for the layer."""
+    if self.use_proj:
+      self.proj_layer = tf.keras.layers.Dense(
+          units=self.d_model,
+          kernel_initializer=self.initializer,
+          activation=gelu,
+          name='lm_projection')
+      self.proj_layer_norm = tf.keras.layers.LayerNormalization(
+          axis=-1, epsilon=1e-12, name='lm_projection/LayerNorm')
     if not self.tie_weight:
-      self.softmax_w = self.add_weight('weight',
-                                       shape=[self.n_token, self.d_model],
-                                       initializer=self.initializer)
+      self.softmax_w = self.add_weight(
+          'weight',
+          shape=[self.n_token, self.d_model],
+          initializer=self.initializer)
 
-      self.softmax_b = self.add_weight('bias', shape=[self.n_token],
-                                       initializer=tf.zeros_initializer())
+      self.softmax_b = self.add_weight(
+          'bias', shape=[self.n_token], initializer=tf.zeros_initializer())
 
     super(LMLossLayer, self).build(unused_input_shapes)
@@ -1180,6 +1177,8 @@ class LMLossLayer(tf.keras.layers.Layer):
   def call(self, inputs):
     """Implements call() for the layer."""
     (hidden, target, lookup_table, tgt_mask) = unpack_inputs(inputs)
+    if self.use_proj:
+      hidden = self.proj_layer_norm(self.proj_layer(hidden))
     if self.tie_weight:
       logits = tf.einsum('ibd,nd->ibn', hidden, lookup_table) + self.softmax_b
     else:
@@ -1189,11 +1188,8 @@ class LMLossLayer(tf.keras.layers.Layer):
      one_hot_target = tf.one_hot(target, self.n_token, dtype=logits.dtype)
      loss = -tf.reduce_sum(tf.nn.log_softmax(logits) * one_hot_target, -1)
    else:
-      loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=target,
-                                                            logits=logits)
-    if self.use_bfloat16:
-      tgt_mask = tf.cast(tgt_mask, tf.float32)
-      loss = tf.cast(loss, tf.float32)
+      loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
+          labels=target, logits=logits)
 
    total_loss = tf.reduce_sum(loss * tgt_mask) / tf.reduce_sum(tgt_mask)
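Putting the LMLossLayer changes together, here is a standalone sketch of the new loss path (names and shapes are illustrative, not from the commit): an optional gelu projection plus LayerNorm before the tied-weight logits, with the bfloat16 casts removed.

# Standalone sketch of the reworked LM head; all tensors are dummy data.
import tensorflow as tf

d_model, n_token, seq_len, bsz = 8, 32, 5, 2
hidden = tf.random.normal([seq_len, bsz, d_model])
lookup_table = tf.random.normal([n_token, d_model])
softmax_b = tf.zeros([n_token])
target = tf.random.uniform([seq_len, bsz], maxval=n_token, dtype=tf.int32)
tgt_mask = tf.ones([seq_len, bsz])

# Simple erf-based gelu stand-in for the module's own gelu().
gelu = lambda x: 0.5 * x * (1.0 + tf.math.erf(x / tf.sqrt(2.0)))
proj = tf.keras.layers.Dense(d_model, activation=gelu, name='lm_projection')
proj_norm = tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-12)

hidden = proj_norm(proj(hidden))  # use_proj=True path
logits = tf.einsum('ibd,nd->ibn', hidden, lookup_table) + softmax_b
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
    labels=target, logits=logits)
total_loss = tf.reduce_sum(loss * tgt_mask) / tf.reduce_sum(tgt_mask)
print(total_loss.numpy())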
@@ -1321,7 +1317,6 @@ class QAXLNetModel(tf.keras.Model):
        ff_activation=self.xlnet_config.ff_activation,
        untie_r=self.xlnet_config.untie_r,
        is_training=self.run_config.is_training,
-        use_bfloat16=self.run_config.use_bfloat16,
        use_tpu=self.run_config.use_tpu,
        dropout=self.run_config.dropout,
        dropout_att=self.run_config.dropout_att,
@@ -1370,8 +1365,7 @@ class QAXLNetModel(tf.keras.Model):
 class QALossLayer(tf.keras.layers.Layer):
-  """Layer computing position and regression loss for question answering task.
-  """
+  """Layer computing position and regression loss for question answering task."""
 
   def __init__(self, d_model, start_n_top, end_n_top, initializer, dropout,
               **kwargs):
...