Commit b045ce7d authored by Hongkun Yu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 272777104
parent 0f176f6f
......@@ -43,6 +43,10 @@ CLS_ID = special_symbols["<cls>"]
SEP_ID = special_symbols["<sep>"]
MASK_ID = special_symbols["<mask>"]
EOD_ID = special_symbols["<eod>"]
SEG_ID_P = 0
SEG_ID_Q = 1
SEG_ID_CLS = 2
SEG_ID_PAD = 3
def file_based_input_fn_builder(input_file, name_to_features, batch_size,
......
......@@ -48,8 +48,11 @@ FLAGS = flags.FLAGS
def get_pretrainxlnet_model(model_config, run_config):
model = modeling.PretrainingXLNetModel(model_config, run_config, name="model")
return model
return modeling.PretrainingXLNetModel(
use_proj=True,
xlnet_config=model_config,
run_config=run_config,
name="model")
def main(unused_argv):
......@@ -69,8 +72,7 @@ def main(unused_argv):
if strategy:
logging.info("***** Number of cores used : %d",
strategy.num_replicas_in_sync)
logging.info("***** Number of hosts used : %d",
num_hosts)
logging.info("***** Number of hosts used : %d", num_hosts)
train_input_fn = functools.partial(
data_utils.get_pretrain_input_data, FLAGS.train_batch_size, FLAGS.seq_len,
strategy, FLAGS.train_tfrecord_path, FLAGS.reuse_len, FLAGS.perm_size,
......
......@@ -36,11 +36,6 @@ from official.nlp.xlnet import preprocess_utils
SPIECE_UNDERLINE = u"▁"
SEG_ID_P = 0
SEG_ID_Q = 1
SEG_ID_CLS = 2
SEG_ID_PAD = 3
class InputFeatures(object):
"""A single set of features of data."""
......@@ -705,28 +700,28 @@ def convert_examples_to_features(examples, sp_model, max_seq_length, doc_stride,
split_token_index)
token_is_max_context[len(tokens)] = is_max_context
tokens.append(all_doc_tokens[split_token_index])
segment_ids.append(SEG_ID_P)
segment_ids.append(data_utils.SEG_ID_P)
p_mask.append(0)
paragraph_len = len(tokens)
tokens.append(data_utils.SEP_ID)
segment_ids.append(SEG_ID_P)
segment_ids.append(data_utils.SEG_ID_P)
p_mask.append(1)
# note(zhiliny): we put P before Q
# because during pretraining, B is always shorter than A
for token in query_tokens:
tokens.append(token)
segment_ids.append(SEG_ID_Q)
segment_ids.append(data_utils.SEG_ID_Q)
p_mask.append(1)
tokens.append(data_utils.SEP_ID)
segment_ids.append(SEG_ID_Q)
segment_ids.append(data_utils.SEG_ID_Q)
p_mask.append(1)
cls_index = len(segment_ids)
tokens.append(data_utils.CLS_ID)
segment_ids.append(SEG_ID_CLS)
segment_ids.append(data_utils.SEG_ID_CLS)
p_mask.append(0)
input_ids = tokens
......@@ -739,7 +734,7 @@ def convert_examples_to_features(examples, sp_model, max_seq_length, doc_stride,
while len(input_ids) < max_seq_length:
input_ids.append(0)
input_mask.append(1)
segment_ids.append(SEG_ID_PAD)
segment_ids.append(data_utils.SEG_ID_PAD)
p_mask.append(1)
assert len(input_ids) == max_seq_length
......
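For readers following the segment-ID change above, here is a toy walk-through (illustrative token IDs and a tiny max_seq_length, not the real SentencePiece vocabulary) of how the shared constants lay out a SQuAD-style window: paragraph tokens, SEP, question tokens, SEP, CLS, then padding, mirroring the loops in the hunk above.

```python
SEG_ID_P, SEG_ID_Q, SEG_ID_CLS, SEG_ID_PAD = 0, 1, 2, 3  # as defined in data_utils
SEP_ID, CLS_ID = 4, 3  # placeholder special-symbol IDs for illustration

paragraph_tokens = [11, 12, 13]  # toy piece IDs for the paragraph span
query_tokens = [21, 22]          # toy piece IDs for the question
max_seq_length = 10

tokens, segment_ids, p_mask = [], [], []
for tok in paragraph_tokens:     # paragraph (P) comes before the question (Q)
  tokens.append(tok)
  segment_ids.append(SEG_ID_P)
  p_mask.append(0)
tokens.append(SEP_ID)
segment_ids.append(SEG_ID_P)
p_mask.append(1)
for tok in query_tokens:
  tokens.append(tok)
  segment_ids.append(SEG_ID_Q)
  p_mask.append(1)
tokens.append(SEP_ID)
segment_ids.append(SEG_ID_Q)
p_mask.append(1)
cls_index = len(segment_ids)     # CLS goes last
tokens.append(CLS_ID)
segment_ids.append(SEG_ID_CLS)
p_mask.append(0)

input_ids = list(tokens)
input_mask = [0] * len(input_ids)        # 0 = real token, 1 = padding
while len(input_ids) < max_seq_length:   # pad to the fixed length
  input_ids.append(0)
  input_mask.append(1)
  segment_ids.append(SEG_ID_PAD)
  p_mask.append(1)

assert len(input_ids) == len(segment_ids) == len(p_mask) == max_seq_length
```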
......@@ -30,7 +30,6 @@ def create_run_config(is_training, is_finetune, flags):
kwargs = dict(
is_training=is_training,
use_tpu=flags.use_tpu,
use_bfloat16=flags.use_bfloat16,
dropout=flags.dropout,
dropout_att=flags.dropout_att,
init_method=flags.init_method,
......@@ -49,6 +48,7 @@ def create_run_config(is_training, is_finetune, flags):
return RunConfig(**kwargs)
# TODO(hongkuny): refactor XLNetConfig and RunConfig.
class XLNetConfig(object):
"""Configs for XLNet model.
......@@ -131,7 +131,6 @@ class RunConfig(object):
def __init__(self,
is_training,
use_tpu,
use_bfloat16,
dropout,
dropout_att,
init_method='normal',
......@@ -141,13 +140,13 @@ class RunConfig(object):
reuse_len=None,
bi_data=False,
clamp_len=-1,
same_length=False):
same_length=False,
use_cls_mask=True):
"""Initializes RunConfig.
Args:
is_training: bool, whether in training mode.
use_tpu: bool, whether TPUs are used.
use_bfloat16: bool, use bfloat16 instead of float32.
dropout: float, dropout rate.
dropout_att: float, dropout rate on attention probabilities.
init_method: str, the initialization scheme, either "normal" or "uniform".
......@@ -164,6 +163,7 @@ class RunConfig(object):
-1 means no clamping.
same_length: bool, whether to use the same attention length
for each token.
use_cls_mask: bool, whether to introduce cls mask.
"""
self.init_method = init_method
......@@ -173,9 +173,9 @@ class RunConfig(object):
self.dropout = dropout
self.dropout_att = dropout_att
self.use_tpu = use_tpu
self.use_bfloat16 = use_bfloat16
self.mem_len = mem_len
self.reuse_len = reuse_len
self.bi_data = bi_data
self.clamp_len = clamp_len
self.same_length = same_length
self.use_cls_mask = use_cls_mask
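As a quick orientation to the signature change, a minimal sketch of constructing a RunConfig after this commit (use_bfloat16 dropped, use_cls_mask added). The hyperparameter values and the import path are assumptions for illustration, not taken from this diff.

```python
from official.nlp.xlnet.xlnet_config import RunConfig  # assumed module path

run_config = RunConfig(
    is_training=True,
    use_tpu=False,
    dropout=0.1,          # illustrative hyperparameters
    dropout_att=0.1,
    init_method='normal',
    reuse_len=256,
    bi_data=True,
    clamp_len=-1,
    same_length=False,
    use_cls_mask=False)   # new flag; defaults to True
```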
......@@ -23,6 +23,7 @@ import copy
import numpy as np
import tensorflow as tf
from official.nlp.xlnet import data_utils
def gelu(x):
......@@ -96,19 +97,6 @@ def _cache_mem(curr_out, prev_mem, mem_len, reuse_len=None):
return tf.keras.backend.stop_gradient(new_mem)
def embedding_lookup(lookup_table, x, use_tpu=True):
"""Looks up words embeddings for input id tensor."""
if use_tpu:
n_token = tf.shape(lookup_table)[0]
one_hot_idx = tf.one_hot(x, n_token)
if one_hot_idx.shape.ndims == 2:
return tf.einsum('nd,in->id', lookup_table, one_hot_idx)
else:
return tf.einsum('nd,ibn->ibd', lookup_table, one_hot_idx)
else:
return tf.nn.embedding_lookup(lookup_table, x)
def is_special_none_tensor(tensor):
"""Checks if a tensor is a special None Tensor."""
return tensor.shape.ndims == 0 and tensor.dtype == tf.int32
......@@ -169,7 +157,7 @@ class PositionalEmbedding(tf.keras.layers.Layer):
def build(self, unused_input_shapes):
"""Constructs inversed frequency vector for positional embedding layer."""
self.inv_freq = 1.0 / (10000.0 ** (tf.range(0, self.dim, 2.0) / self.dim))
self.inv_freq = 1.0 / (10000.0**(tf.range(0, self.dim, 2.0) / self.dim))
super(PositionalEmbedding, self).build(unused_input_shapes)
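For context, a small self-contained sketch (toy dim and positions) of how an inverse-frequency vector like this is conventionally turned into Transformer-XL-style sinusoidal position embeddings. The __call__ body is not shown in this diff, so the last two lines are the standard construction rather than a quote of it.

```python
import tensorflow as tf

dim = 8
inv_freq = 1.0 / (10000.0**(tf.range(0, dim, 2.0) / dim))     # [dim // 2]
pos_seq = tf.range(5.0)                                       # toy relative positions
sinusoid = tf.einsum('i,d->id', pos_seq, inv_freq)            # [5, dim // 2]
pos_emb = tf.concat([tf.sin(sinusoid), tf.cos(sinusoid)], -1)  # [5, dim]
```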
def __call__(self, pos_seq, batch_size):
......@@ -232,8 +220,12 @@ class RelativeAttention(tf.keras.layers.Layer):
if seg_mat is None:
ef = 0
else:
ef = tf.einsum('ibnd,snd->ibns', q_head + r_s_bias, seg_embed)
ef = tf.einsum('ijbs,ibns->ijbn', seg_mat, ef)
ef = tf.einsum('ibnd,snd->isbn', q_head + r_s_bias, seg_embed)
tgt_shape = tf.shape(bd)
ef = tf.where(
tf.broadcast_to(tf.expand_dims(seg_mat, 3), tgt_shape),
tf.broadcast_to(ef[:, 1:, :, :], tgt_shape),
tf.broadcast_to(ef[:, :1, :, :], tgt_shape))
# merges attention scores and performs masking
attn_score = (ac + bd + ef) * self.scale
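A self-contained sketch with toy shapes of the rewritten segment term: ef is scored against both rows of seg_embed, then tf.where picks row 1 wherever the boolean seg_mat is True and row 0 elsewhere, replacing the earlier one-hot einsum while producing the same [qlen, klen, bsz, n_head] tensor. Shapes and values below are illustrative only.

```python
import tensorflow as tf

qlen, klen, bsz, n_head, d_head = 3, 3, 1, 2, 4
q_head = tf.random.normal([qlen, bsz, n_head, d_head])
r_s_bias = tf.zeros([n_head, d_head])
seg_embed = tf.random.normal([2, n_head, d_head])     # one row per segment relation
bd = tf.zeros([qlen, klen, bsz, n_head])              # stands in for the positional term
seg_mat = tf.cast(tf.eye(qlen), tf.bool)[:, :, None]  # toy boolean [qlen, klen, bsz]

ef = tf.einsum('ibnd,snd->isbn', q_head + r_s_bias, seg_embed)  # [qlen, 2, bsz, n_head]
tgt_shape = tf.shape(bd)
ef = tf.where(
    tf.broadcast_to(tf.expand_dims(seg_mat, 3), tgt_shape),  # pick row 1 where True
    tf.broadcast_to(ef[:, 1:, :, :], tgt_shape),
    tf.broadcast_to(ef[:, :1, :, :], tgt_shape))              # row 0 otherwise
print(ef.shape)  # (3, 3, 1, 2), ready to be added to ac + bd
```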
......@@ -253,8 +245,8 @@ class RelativeAttention(tf.keras.layers.Layer):
class PositionwiseFF(tf.keras.layers.Layer):
"""Positionwise feed-forward layer."""
def __init__(self, d_model, d_inner, dropout,
kernel_initializer, activation_type, **kwargs):
def __init__(self, d_model, d_inner, dropout, kernel_initializer,
activation_type, **kwargs):
super(PositionwiseFF, self).__init__(**kwargs)
self.d_model = d_model
self.d_inner = d_inner
......@@ -282,10 +274,8 @@ class PositionwiseFF(tf.keras.layers.Layer):
units=self.d_model,
kernel_initializer=self.kernel_initializer,
name='layer_2'))
self.inner_dropout = tf.keras.layers.Dropout(rate=self.dropout,
name='drop_1')
self.output_dropout = tf.keras.layers.Dropout(rate=self.dropout,
name='drop_2')
self.output_dropout = tf.keras.layers.Dropout(
rate=self.dropout, name='drop_2')
self.output_layer_norm = (
tf.keras.layers.LayerNormalization(
name='LayerNorm', axis=-1, epsilon=1e-12))
......@@ -295,7 +285,6 @@ class PositionwiseFF(tf.keras.layers.Layer):
"""Implements call() for the layer."""
output = self.inner_projection_layer(inp)
output = self.inner_dropout(output)
output = self.output_projection_layer(output)
output = self.output_dropout(output)
output = self.output_layer_norm(output + inp)
......@@ -305,14 +294,11 @@ class PositionwiseFF(tf.keras.layers.Layer):
class EmbeddingLookup(tf.keras.layers.Layer):
"""Looks up words embeddings for id tensor."""
def __init__(self,
n_token, d_embed, initializer,
use_one_hot=False, **kwargs):
def __init__(self, n_token, d_embed, initializer, **kwargs):
super(EmbeddingLookup, self).__init__(**kwargs)
self.n_token = n_token
self.d_embed = d_embed
self.initializer = initializer
self.use_one_hot = use_one_hot
def build(self, unused_input_shapes):
"""Implements build() for the layer."""
......@@ -325,20 +311,7 @@ class EmbeddingLookup(tf.keras.layers.Layer):
super(EmbeddingLookup, self).build(unused_input_shapes)
def call(self, inputs):
x = inputs
if self.use_one_hot:
one_hot_idx = tf.one_hot(x, self.n_token, dtype=self.dtype)
if one_hot_idx.shape.ndims == 2:
return tf.einsum('in,nd->id',
one_hot_idx,
self.lookup_table), self.lookup_table
else:
return tf.einsum('ibn,nd->ibd',
one_hot_idx,
self.lookup_table), self.lookup_table
else:
return tf.nn.embedding_lookup(self.lookup_table, x), self.lookup_table
return tf.nn.embedding_lookup(self.lookup_table, inputs)
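The one-hot matmul branch (originally a TPU optimization) is removed, so the layer now just gathers rows and returns only the embeddings rather than a (embeddings, table) pair. A tiny sketch with a toy vocabulary showing that both paths compute the same result:

```python
import tensorflow as tf

lookup_table = tf.random.normal([5, 3])      # toy vocab of 5 tokens, d_embed = 3
ids = tf.constant([[1, 4], [0, 2]])          # [seq_len, batch]

gathered = tf.nn.embedding_lookup(lookup_table, ids)                  # kept path
one_hot = tf.einsum('ibn,nd->ibd', tf.one_hot(ids, 5), lookup_table)  # removed path
tf.debugging.assert_near(gathered, one_hot)  # identical embeddings, shape [2, 2, 3]
```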
class TwoStreamRelativeAttention(tf.keras.layers.Layer):
......@@ -356,9 +329,10 @@ class TwoStreamRelativeAttention(tf.keras.layers.Layer):
def build(self, unused_input_shapes):
"""Implements build() for the layer."""
self.scale = 1.0 / (self.d_head ** 0.5)
self.scale = 1.0 / (self.d_head**0.5)
self.attention_projection_layer = tf.keras.layers.Dense(
units=self.d_model, use_bias=False,
units=self.d_model,
use_bias=False,
kernel_initializer=self.initializer,
name='o')
self.attention_probs_dropout = tf.keras.layers.Dropout(
......@@ -403,9 +377,8 @@ class TwoStreamRelativeAttention(tf.keras.layers.Layer):
super(TwoStreamRelativeAttention, self).build(unused_input_shapes)
def __call__(self, h, g, r, r_w_bias, r_r_bias,
seg_mat, r_s_bias, seg_embed, attn_mask_h, attn_mask_g,
mems, target_mapping):
def __call__(self, h, g, r, r_w_bias, r_r_bias, seg_mat, r_s_bias, seg_embed,
attn_mask_h, attn_mask_g, mems, target_mapping):
inputs = pack_inputs([
h, g, r, r_w_bias, r_r_bias, seg_mat, r_s_bias, seg_embed, attn_mask_h,
attn_mask_g, mems, target_mapping
......@@ -455,15 +428,17 @@ class TwoStreamRelativeAttention(tf.keras.layers.Layer):
q_head_g = tf.einsum('mbnd,mlb->lbnd', q_head_g, target_mapping)
attn_vec_g = self.g_attention_layer(
q_head_g, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat, r_w_bias,
r_r_bias, r_s_bias, attn_mask_g)
attn_vec_g = self.g_attention_layer(q_head_g, k_head_h, v_head_h,
k_head_r, seg_embed, seg_mat,
r_w_bias, r_r_bias, r_s_bias,
attn_mask_g)
attn_vec_g = tf.einsum('lbnd,mlb->mbnd', attn_vec_g, target_mapping)
else:
attn_vec_g = self.g_attention_layer(
q_head_g, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat, r_w_bias,
r_r_bias, r_s_bias, attn_mask_g)
attn_vec_g = self.g_attention_layer(q_head_g, k_head_h, v_head_h,
k_head_r, seg_embed, seg_mat,
r_w_bias, r_r_bias, r_s_bias,
attn_mask_g)
# post processing
......@@ -491,7 +466,7 @@ class RelativeMultiheadAttention(tf.keras.layers.Layer):
def build(self, unused_input_shapes):
"""Implements build() for the layer."""
self.scale = 1.0 / (self.d_head ** 0.5)
self.scale = 1.0 / (self.d_head**0.5)
self.output_layer_norm = tf.keras.layers.LayerNormalization(
name='LayerNorm', axis=-1, epsilon=1e-12)
......@@ -555,9 +530,9 @@ class RelativeMultiheadAttention(tf.keras.layers.Layer):
k_head_r = tf.einsum('ibh,hnd->ibnd', r, self.kr_projection_layer)
# core attention ops
attn_vec = self.h_attention_layer(
q_head_h, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat, r_w_bias,
r_r_bias, r_s_bias, attn_mask)
attn_vec = self.h_attention_layer(q_head_h, k_head_h, v_head_h, k_head_r,
seg_embed, seg_mat, r_w_bias, r_r_bias,
r_s_bias, attn_mask)
# post processing
......@@ -592,7 +567,7 @@ class TransformerXLModel(tf.keras.layers.Layer):
use_tpu=True,
reuse_len=None,
ff_activation='relu',
use_bfloat16=False,
use_cls_mask=False,
**kwargs):
"""Initializes TransformerXLModel.
......@@ -620,7 +595,7 @@ class TransformerXLModel(tf.keras.layers.Layer):
reuse_len: int, the number of tokens in the current batch to be cached and
reused in the future.
ff_activation: str, "relu" or "gelu".
use_bfloat16: bool, use bfloat16 instead of float32.
use_cls_mask: bool, whether to introduce cls mask.
**kwargs: Other parameters.
"""
......@@ -636,7 +611,6 @@ class TransformerXLModel(tf.keras.layers.Layer):
self.d_inner = d_inner
self.ff_activation = ff_activation
self.untie_r = untie_r
self.use_bfloat16 = use_bfloat16
self.use_tpu = use_tpu
self.dropout = dropout
self.dropout_att = dropout_att
......@@ -646,21 +620,21 @@ class TransformerXLModel(tf.keras.layers.Layer):
self.bi_data = bi_data
self.clamp_len = clamp_len
self.same_length = same_length
self.use_cls_mask = use_cls_mask
def build(self, unused_input_shapes):
"""Implements build() for the layer."""
self.tf_float = tf.bfloat16 if self.use_bfloat16 else tf.float32
self.tf_float = tf.float32
self.embedding_lookup = EmbeddingLookup(n_token=self.n_token,
self.embedding_lookup = EmbeddingLookup(
n_token=self.n_token,
d_embed=self.d_model,
initializer=self.initializer,
use_one_hot=self.use_tpu,
dtype=self.tf_float,
name='word_embedding')
self.h_dropout = tf.keras.layers.Dropout(rate=self.dropout)
self.g_dropout = tf.keras.layers.Dropout(rate=self.dropout)
self.output_dropout = tf.keras.layers.Dropout(rate=self.dropout)
if self.untie_r:
self.r_w_bias = (
......@@ -702,11 +676,11 @@ class TransformerXLModel(tf.keras.layers.Layer):
self.seg_embed = self.add_weight(
'seg_embed', [self.n_layer, 2, self.n_head, self.d_head],
dtype=self.tf_float, initializer=self.initializer)
dtype=self.tf_float,
initializer=self.initializer)
self.mask_emb = self.add_weight('mask_emb/mask_emb',
shape=[1, 1, self.d_model],
dtype=self.tf_float)
self.mask_emb = self.add_weight(
'mask_emb/mask_emb', shape=[1, 1, self.d_model], dtype=self.tf_float)
self.emb_dropout = tf.keras.layers.Dropout(rate=self.dropout)
self.fwd_position_embedding = PositionalEmbedding(self.d_model)
......@@ -741,16 +715,16 @@ class TransformerXLModel(tf.keras.layers.Layer):
d_inner=self.d_inner,
dropout=self.dropout,
kernel_initializer=self.initializer,
activation_type=self.ff_activation, name='layer_%d/ff'%(i))
)
activation_type=self.ff_activation,
name='layer_%d/ff' % (i)))
self.h_positionwise_ffn_layers.append(
PositionwiseFF(
d_model=self.d_model,
d_inner=self.d_inner,
dropout=self.dropout,
kernel_initializer=self.initializer,
activation_type=self.ff_activation, name='layer_%d/ff'%(i))
)
activation_type=self.ff_activation,
name='layer_%d/ff' % (i)))
self.output_dropout = tf.keras.layers.Dropout(rate=self.dropout)
......@@ -766,9 +740,15 @@ class TransformerXLModel(tf.keras.layers.Layer):
inp_q=None):
# Uses dict to feed inputs into call() in order to keep mems as a python
# list.
inputs = {'inp_k': inp_k, 'seg_id': seg_id, 'input_mask': input_mask,
'mems': mems, 'perm_mask': perm_mask,
'target_mapping': target_mapping, 'inp_q': inp_q}
inputs = {
'inp_k': inp_k,
'seg_id': seg_id,
'input_mask': input_mask,
'mems': mems,
'perm_mask': perm_mask,
'target_mapping': target_mapping,
'inp_q': inp_q
}
return super(TransformerXLModel, self).__call__(inputs)
def call(self, inputs):
......@@ -827,14 +807,14 @@ class TransformerXLModel(tf.keras.layers.Layer):
if attn_mask is not None:
non_tgt_mask = -tf.eye(qlen, dtype=self.tf_float)
non_tgt_mask = tf.concat([tf.zeros([qlen, mlen], dtype=self.tf_float),
non_tgt_mask], axis=-1)
non_tgt_mask = tf.cast((attn_mask + non_tgt_mask[:, :, None, None]) > 0,
dtype=self.tf_float)
non_tgt_mask = tf.concat(
[tf.zeros([qlen, mlen], dtype=self.tf_float), non_tgt_mask], axis=-1)
non_tgt_mask = tf.cast(
(attn_mask + non_tgt_mask[:, :, None, None]) > 0, dtype=self.tf_float)
else:
non_tgt_mask = None
word_emb_k, _ = self.embedding_lookup(inp_k)
word_emb_k = self.embedding_lookup(inp_k)
if inp_q is not None:
if target_mapping is not None:
......@@ -855,15 +835,18 @@ class TransformerXLModel(tf.keras.layers.Layer):
mem_pad = tf.zeros([mlen, bsz], dtype=tf.int32)
cat_ids = tf.concat([mem_pad, seg_id], 0)
cat_id = tf.concat([mem_pad, seg_id], 0)
if self.use_cls_mask:
# `1` indicates not in the same segment [qlen x klen x bsz]
seg_mat = tf.cast(
tf.logical_not(tf.equal(seg_id[:, None], cat_ids[None, :])),
tf.int32)
seg_mat = tf.one_hot(seg_mat, 2, dtype=self.tf_float)
# seg_id: [qlen x bsz] & cat_id: [klen x bsz]
cls_mat = tf.logical_or(
tf.equal(seg_id, tf.constant([data_utils.SEG_ID_CLS]))[:, None],
tf.equal(cat_id, tf.constant([data_utils.SEG_ID_CLS]))[None, :])
seg_mat = tf.equal(seg_id[:, None], cat_id[None, :])
seg_mat = tf.logical_or(cls_mat, seg_mat)
else:
seg_mat = tf.logical_not(tf.equal(seg_id[:, None], cat_id[None, :]))
else:
seg_mat = None
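A toy sketch (batch of one, no memory, so cat_id equals seg_id) of the new boolean seg_mat: with use_cls_mask enabled, two positions are marked True when they share a segment ID or when either one is the CLS position. The constant below mirrors data_utils.SEG_ID_CLS; the sequence is illustrative.

```python
import tensorflow as tf

SEG_ID_CLS = 2                               # mirrors data_utils.SEG_ID_CLS
seg_id = tf.constant([[0], [0], [1], [2]])   # [qlen, bsz]: P, P, Q, CLS
cat_id = seg_id                              # mlen == 0, nothing to prepend

cls_mat = tf.logical_or(
    tf.equal(seg_id, tf.constant([SEG_ID_CLS]))[:, None],
    tf.equal(cat_id, tf.constant([SEG_ID_CLS]))[None, :])
seg_mat = tf.equal(seg_id[:, None], cat_id[None, :])
seg_mat = tf.logical_or(cls_mat, seg_mat)    # [qlen, klen, bsz] boolean
print(tf.squeeze(tf.cast(seg_mat, tf.int32), -1).numpy())
# [[1 1 0 1]
#  [1 1 0 1]
#  [0 0 1 1]
#  [1 1 1 1]]
```

With use_cls_mask=False the matrix falls back to a plain "different segment" test; either way the attention layer above now consumes a boolean mask instead of the previous one-hot encoding.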
......@@ -894,8 +877,8 @@ class TransformerXLModel(tf.keras.layers.Layer):
self.clamp_len)
if bsz is not None:
fwd_pos_emb = self.fwd_position_embedding(fwd_pos_seq, bsz//2)
bwd_pos_emb = self.bwd_position_embedding(bwd_pos_seq, bsz//2)
fwd_pos_emb = self.fwd_position_embedding(fwd_pos_seq, bsz // 2)
bwd_pos_emb = self.bwd_position_embedding(bwd_pos_seq, bsz // 2)
else:
fwd_pos_emb = self.fwd_position_embedding(fwd_pos_seq, None)
bwd_pos_emb = self.bwd_position_embedding(bwd_pos_seq, None)
......@@ -906,8 +889,8 @@ class TransformerXLModel(tf.keras.layers.Layer):
if dtype is not None and dtype != tf.float32:
fwd_pos_seq = tf.cast(fwd_pos_seq, dtype=dtype)
if self.clamp_len > 0:
fwd_pos_seq = tf.clip_by_value(fwd_pos_seq,
-self.clamp_len, self.lamp_len)
fwd_pos_seq = tf.clip_by_value(fwd_pos_seq, -self.clamp_len,
self.clamp_len)
pos_emb = self.fwd_position_embedding(fwd_pos_seq, bsz)
......@@ -969,9 +952,9 @@ class TransformerXLModel(tf.keras.layers.Layer):
output_h = h_ffn_layer(output_h)
if inp_q is not None:
output = self.output_dropout(output_g)
output = output_g
else:
output = self.output_dropout(output_h)
output = output_h
return output, new_mems, None
......@@ -983,7 +966,7 @@ class PretrainingXLNetModel(tf.keras.Model):
"""
def __init__(self, xlnet_config, run_config, **kwargs):
def __init__(self, use_proj, xlnet_config, run_config, **kwargs):
super(PretrainingXLNetModel, self).__init__(**kwargs)
self.run_config = run_config
self.initializer = _get_initializer(run_config)
......@@ -1001,7 +984,6 @@ class PretrainingXLNetModel(tf.keras.Model):
ff_activation=self.xlnet_config.ff_activation,
untie_r=self.xlnet_config.untie_r,
is_training=self.run_config.is_training,
use_bfloat16=self.run_config.use_bfloat16,
use_tpu=self.run_config.use_tpu,
dropout=self.run_config.dropout,
dropout_att=self.run_config.dropout_att,
......@@ -1010,14 +992,16 @@ class PretrainingXLNetModel(tf.keras.Model):
bi_data=self.run_config.bi_data,
clamp_len=self.run_config.clamp_len,
same_length=self.run_config.same_length,
use_cls_mask=self.run_config.use_cls_mask,
name='transformer')
self.lmloss_layer = LMLossLayer(n_token=self.xlnet_config.n_token,
self.lmloss_layer = LMLossLayer(
n_token=self.xlnet_config.n_token,
d_model=self.xlnet_config.d_model,
initializer=self.initializer,
use_bfloat16=self.run_config.use_bfloat16,
tie_weight=True,
bi_data=self.run_config.bi_data,
use_tpu=self.run_config.use_tpu,
use_proj=use_proj,
name='lm_loss')
def call(self, features):
......@@ -1082,7 +1066,6 @@ class ClassificationXLNetModel(tf.keras.Model):
ff_activation=self.xlnet_config.ff_activation,
untie_r=self.xlnet_config.untie_r,
is_training=self.run_config.is_training,
use_bfloat16=self.run_config.use_bfloat16,
use_tpu=self.run_config.use_tpu,
dropout=self.run_config.dropout,
dropout_att=self.run_config.dropout_att,
......@@ -1133,23 +1116,28 @@ class ClassificationXLNetModel(tf.keras.Model):
class LMLossLayer(tf.keras.layers.Layer):
"""Layer computing cross entropy loss for language modeling."""
def __init__(self, n_token, d_model, initializer, use_bfloat16,
tie_weight=False, bi_data=True, use_tpu=False, **kwargs):
def __init__(self,
n_token,
d_model,
initializer,
tie_weight=False,
bi_data=True,
use_tpu=False,
use_proj=False,
**kwargs):
"""Constructs LMLoss layer.
Args:
n_token: Number of tokens in vocabulary.
d_model: The dimension of model hidden state.
initializer: Initializer used for parameters.
use_bfloat16: Whether to use bfloat16.
tie_weight: Whether to share weights between embedding lookup layer and
next-token prediction layer.
bi_data: Whether to use bidirectional input pipeline.
Usually set to True during pretraining and False during finetuning.
bi_data: Whether to use bidirectional input pipeline. Usually set to True
during pretraining and False during finetuning.
use_tpu: bool, whether to use TPU.
use_proj: bool, whether to add a projection layer before LM prediction.
**kwargs: Other parameters.
"""
super(LMLossLayer, self).__init__(**kwargs)
self.n_token = n_token
......@@ -1159,17 +1147,26 @@ class LMLossLayer(tf.keras.layers.Layer):
self.tie_weight = tie_weight
self.bi_data = bi_data
self.use_tpu = use_tpu
self.use_bfloat16 = use_bfloat16
self.use_proj = use_proj
def build(self, unused_input_shapes):
"""Implements build() for the layer."""
if self.use_proj:
self.proj_layer = tf.keras.layers.Dense(
units=self.d_model,
kernel_initializer=self.initializer,
activation=gelu,
name='lm_projection')
self.proj_layer_norm = tf.keras.layers.LayerNormalization(
axis=-1, epsilon=1e-12, name='lm_projection/LayerNorm')
if not self.tie_weight:
self.softmax_w = self.add_weight('weight',
self.softmax_w = self.add_weight(
'weight',
shape=[self.n_token, self.d_model],
initializer=self.initializer)
self.softmax_b = self.add_weight('bias', shape=[self.n_token],
initializer=tf.zeros_initializer())
self.softmax_b = self.add_weight(
'bias', shape=[self.n_token], initializer=tf.zeros_initializer())
super(LMLossLayer, self).build(unused_input_shapes)
......@@ -1180,6 +1177,8 @@ class LMLossLayer(tf.keras.layers.Layer):
def call(self, inputs):
"""Implements call() for the layer."""
(hidden, target, lookup_table, tgt_mask) = unpack_inputs(inputs)
if self.use_proj:
hidden = self.proj_layer_norm(self.proj_layer(hidden))
if self.tie_weight:
logits = tf.einsum('ibd,nd->ibn', hidden, lookup_table) + self.softmax_b
else:
......@@ -1189,11 +1188,8 @@ class LMLossLayer(tf.keras.layers.Layer):
one_hot_target = tf.one_hot(target, self.n_token, dtype=logits.dtype)
loss = -tf.reduce_sum(tf.nn.log_softmax(logits) * one_hot_target, -1)
else:
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=target,
logits=logits)
if self.use_bfloat16:
tgt_mask = tf.cast(tgt_mask, tf.float32)
loss = tf.cast(loss, tf.float32)
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
labels=target, logits=logits)
total_loss = tf.reduce_sum(loss * tgt_mask) / tf.reduce_sum(tgt_mask)
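A self-contained sketch (toy shapes; tf.nn.gelu stands in for the module's gelu helper, so TF >= 2.4 is assumed) of the new use_proj path: project, LayerNorm, then the tied-weight softmax over the embedding table, analogous to BERT's MLM transform.

```python
import tensorflow as tf

seq_len, bsz, d_model, n_token = 4, 2, 8, 16
hidden = tf.random.normal([seq_len, bsz, d_model])
lookup_table = tf.random.normal([n_token, d_model])  # shared with EmbeddingLookup when tie_weight=True
softmax_b = tf.zeros([n_token])

proj_layer = tf.keras.layers.Dense(
    units=d_model, activation=tf.nn.gelu, name='lm_projection')
proj_layer_norm = tf.keras.layers.LayerNormalization(
    axis=-1, epsilon=1e-12, name='lm_projection/LayerNorm')

hidden = proj_layer_norm(proj_layer(hidden))          # the use_proj=True branch
logits = tf.einsum('ibd,nd->ibn', hidden, lookup_table) + softmax_b

target = tf.random.uniform([seq_len, bsz], maxval=n_token, dtype=tf.int32)
tgt_mask = tf.ones([seq_len, bsz])                    # 1 where a prediction is expected
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=target, logits=logits)
total_loss = tf.reduce_sum(loss * tgt_mask) / tf.reduce_sum(tgt_mask)
```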
......@@ -1321,7 +1317,6 @@ class QAXLNetModel(tf.keras.Model):
ff_activation=self.xlnet_config.ff_activation,
untie_r=self.xlnet_config.untie_r,
is_training=self.run_config.is_training,
use_bfloat16=self.run_config.use_bfloat16,
use_tpu=self.run_config.use_tpu,
dropout=self.run_config.dropout,
dropout_att=self.run_config.dropout_att,
......@@ -1370,8 +1365,7 @@ class QAXLNetModel(tf.keras.Model):
class QALossLayer(tf.keras.layers.Layer):
"""Layer computing position and regression loss for question answering task.
"""
"""Layer computing position and regression loss for question answering task."""
def __init__(self, d_model, start_n_top, end_n_top, initializer, dropout,
**kwargs):
......