Commit ddd45b81 authored by Jing Li, committed by A. Unique TensorFlower

Fix variable mismatch between XLNet pretrain and finetune tasks.

PiperOrigin-RevId: 274699918
parent bb7d95ab
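As the diff below shows, pretraining previously went through a dedicated `TwoStreamRelativeAttention` layer while finetuning used `RelativeMultiheadAttention`, so the two tasks created differently named variables under the same `layer_%d/rel_attn` scope. The fix deletes `TwoStreamRelativeAttention` and lets `RelativeMultiheadAttention` carry an optional query stream, so both tasks build one shared set of attention weights. A minimal sketch of the unified call pattern, assuming the constructor and call signature shown in the diff (all tensor arguments are placeholders produced earlier in the Transformer-XL layer loop):

```python
# Sketch only: constructor arguments mirror TransformerXLModel.build(); the
# tensors (output_h, output_g, pos_emb, seg_mat, masks, mems, target_mapping)
# are placeholders computed earlier in TransformerXLModel.call().
attention_layer = RelativeMultiheadAttention(
    d_model=d_model, n_head=n_head, d_head=d_head,
    dropout=dropout, dropout_att=dropout_att,
    kernel_initializer=initializer, name='layer_0/rel_attn')

# Pretraining (two-stream): the query stream g, its attention mask, and the
# permutation target mapping are provided; both streams are returned.
output_h, output_g = attention_layer(
    h=output_h, g=output_g, r=pos_emb,
    r_w_bias=r_w_bias, r_r_bias=r_r_bias,
    seg_mat=seg_mat, r_s_bias=r_s_bias, seg_embed=seg_embed,
    attn_mask_h=non_tgt_mask, attn_mask_g=attn_mask,
    mems=mems, target_mapping=target_mapping)

# Finetuning (content stream only): g and target_mapping are None, the
# g-stream branch is skipped, and output_g comes back as None.
output_h, output_g = attention_layer(
    h=output_h, g=None, r=pos_emb,
    r_w_bias=r_w_bias, r_r_bias=r_r_bias,
    seg_mat=seg_mat, r_s_bias=r_s_bias, seg_embed=seg_embed,
    attn_mask_h=non_tgt_mask, attn_mask_g=None,
    mems=mems, target_mapping=None)
```

Because both tasks now run through the same layer class, attention variable names such as `layer_0/rel_attn/q/kernel` agree between pretraining and finetuning checkpoints.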
@@ -314,143 +314,6 @@ class EmbeddingLookup(tf.keras.layers.Layer):
return tf.nn.embedding_lookup(self.lookup_table, inputs)
class TwoStreamRelativeAttention(tf.keras.layers.Layer):
"""Two-stream attention layer with relative positional encoding."""
def __init__(self, d_model, n_head, d_head, dropout, dropout_att,
kernel_initializer, **kwargs):
super(TwoStreamRelativeAttention, self).__init__(**kwargs)
self.d_model = d_model
self.n_head = n_head
self.d_head = d_head
self.dropout = dropout
self.dropout_att = dropout_att
self.initializer = kernel_initializer
def build(self, unused_input_shapes):
"""Implements build() for the layer."""
self.scale = 1.0 / (self.d_head**0.5)
self.attention_projection_layer = tf.keras.layers.Dense(
units=self.d_model,
use_bias=False,
kernel_initializer=self.initializer,
name='o')
self.attention_probs_dropout = tf.keras.layers.Dropout(
rate=self.dropout_att)
self.attention_out_dropout = tf.keras.layers.Dropout(rate=self.dropout)
self.output_layer_norm = tf.keras.layers.LayerNormalization(
name='LayerNorm', axis=-1, epsilon=1e-12)
self.kh_projection_layer = (
self.add_weight(
'k/kernel',
shape=[self.d_model, self.n_head, self.d_head],
initializer=self.initializer))
self.vh_projection_layer = (
self.add_weight(
'v/kernel',
shape=[self.d_model, self.n_head, self.d_head],
initializer=self.initializer))
self.kr_projection_layer = (
self.add_weight(
'r/kernel',
shape=[self.d_model, self.n_head, self.d_head],
initializer=self.initializer))
self.qh_projection_layer = (
self.add_weight(
'q/kernel',
shape=[self.d_model, self.n_head, self.d_head],
initializer=self.initializer))
self.h_attention_layer = RelativeAttention(
dropout_att=self.dropout_att, scale=self.scale)
self.g_attention_layer = RelativeAttention(
dropout_att=self.dropout_att, scale=self.scale)
self.proj_o = (
self.add_weight(
'o/kernel',
shape=[self.d_model, self.n_head, self.d_head],
initializer=self.initializer))
self.attention_dropout = tf.keras.layers.Dropout(rate=self.dropout)
super(TwoStreamRelativeAttention, self).build(unused_input_shapes)
def __call__(self, h, g, r, r_w_bias, r_r_bias, seg_mat, r_s_bias, seg_embed,
attn_mask_h, attn_mask_g, mems, target_mapping):
inputs = pack_inputs([
h, g, r, r_w_bias, r_r_bias, seg_mat, r_s_bias, seg_embed, attn_mask_h,
attn_mask_g, mems, target_mapping
])
return super(TwoStreamRelativeAttention, self).__call__(inputs)
def call(self, inputs):
"""Implements call() for the layer."""
(h, g, r, r_w_bias, r_r_bias, seg_mat, r_s_bias, seg_embed, attn_mask_h,
attn_mask_g, mems, target_mapping) = unpack_inputs(inputs)
if mems is not None and mems.shape.ndims > 1:
cat = tf.concat([mems, h], 0)
else:
cat = h
# content heads
k_head_h = tf.einsum('ibh,hnd->ibnd', cat, self.kh_projection_layer)
v_head_h = tf.einsum('ibh,hnd->ibnd', cat, self.vh_projection_layer)
k_head_r = tf.einsum('ibh,hnd->ibnd', r, self.kr_projection_layer)
# positional heads
q_head_h = tf.einsum('ibh,hnd->ibnd', h, self.qh_projection_layer)
# core attention ops
attn_vec_h = self.h_attention_layer(q_head_h, k_head_h, v_head_h, k_head_r,
seg_embed, seg_mat, r_w_bias, r_r_bias,
r_s_bias, attn_mask_h)
output_h = tf.einsum('ibnd,hnd->ibh', attn_vec_h, self.proj_o)
output_h = self.attention_dropout(output_h)
output_h = self.output_layer_norm(output_h + h)
##### g-stream
# query-stream query head
q_head_g = tf.einsum('ibh,hnd->ibnd', g, self.qh_projection_layer)
# core attention ops
if target_mapping is not None:
q_head_g = tf.einsum('mbnd,mlb->lbnd', q_head_g, target_mapping)
attn_vec_g = self.g_attention_layer(q_head_g, k_head_h, v_head_h,
k_head_r, seg_embed, seg_mat,
r_w_bias, r_r_bias, r_s_bias,
attn_mask_g)
attn_vec_g = tf.einsum('lbnd,mlb->mbnd', attn_vec_g, target_mapping)
else:
attn_vec_g = self.g_attention_layer(q_head_g, k_head_h, v_head_h,
k_head_r, seg_embed, seg_mat,
r_w_bias, r_r_bias, r_s_bias,
attn_mask_g)
# post processing
output_g = tf.einsum('ibnd,hnd->ibh', attn_vec_g, self.proj_o)
output_g = self.attention_dropout(output_g)
output_g = self.output_layer_norm(output_g + g)
return output_h, output_g
class RelativeMultiheadAttention(tf.keras.layers.Layer):
"""Multi-head attention with relative embedding."""
@@ -488,7 +351,7 @@ class RelativeMultiheadAttention(tf.keras.layers.Layer):
shape=[self.d_model, self.n_head, self.d_head],
initializer=self.initializer)
self.h_attention_layer = RelativeAttention(
self.relative_attention_layer = RelativeAttention(
dropout_att=self.dropout_att, scale=self.scale)
self.proj_o = self.add_weight(
@@ -500,17 +363,18 @@ class RelativeMultiheadAttention(tf.keras.layers.Layer):
super(RelativeMultiheadAttention, self).build(unused_input_shapes)
def __call__(self, h, r, r_w_bias, r_r_bias, seg_mat, r_s_bias, seg_embed,
attn_mask, mems):
def __call__(self, h, g, r, r_w_bias, r_r_bias, seg_mat, r_s_bias, seg_embed,
attn_mask_h, attn_mask_g, mems, target_mapping):
inputs = pack_inputs([
h, r, r_w_bias, r_r_bias, seg_mat, r_s_bias, seg_embed, attn_mask, mems
h, g, r, r_w_bias, r_r_bias, seg_mat, r_s_bias, seg_embed, attn_mask_h,
attn_mask_g, mems, target_mapping,
])
return super(RelativeMultiheadAttention, self).__call__(inputs)
def call(self, inputs):
"""Implements call() for the layer."""
(h, r, r_w_bias, r_r_bias, seg_mat, r_s_bias, seg_embed, attn_mask,
mems) = unpack_inputs(inputs)
(h, g, r, r_w_bias, r_r_bias, seg_mat, r_s_bias, seg_embed, attn_mask_h,
attn_mask_g, mems, target_mapping) = unpack_inputs(inputs)
if mems is not None and mems.shape.ndims > 1:
cat = tf.concat([mems, h], 0)
@@ -518,30 +382,45 @@ class RelativeMultiheadAttention(tf.keras.layers.Layer):
cat = h
# content heads
q_head_h = tf.einsum('ibh,hnd->ibnd', h, self.qh_projection_layer)
k_head_h = tf.einsum('ibh,hnd->ibnd', cat, self.kh_projection_layer)
v_head_h = tf.einsum('ibh,hnd->ibnd', cat, self.vh_projection_layer)
# positional heads
k_head_r = tf.einsum('ibh,hnd->ibnd', r, self.kr_projection_layer)
# core attention ops
attn_vec = self.h_attention_layer(q_head_h, k_head_h, v_head_h, k_head_r,
seg_embed, seg_mat, r_w_bias, r_r_bias,
r_s_bias, attn_mask)
attn_vec_h = self.relative_attention_layer(
q_head_h, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat, r_w_bias,
r_r_bias, r_s_bias, attn_mask_h)
# post processing
output_h = tf.einsum('ibnd,hnd->ibh', attn_vec_h, self.proj_o)
output_h = self.attention_dropout(output_h)
output_h = self.output_layer_norm(output_h + h)
output = tf.einsum('ibnd,hnd->ibh', attn_vec, self.proj_o)
output_g = None
if g is not None: # enable two-stream attention
# g-stream
q_head_g = tf.einsum('ibh,hnd->ibnd', g, self.qh_projection_layer)
if target_mapping is not None:
q_head_g = tf.einsum('mbnd,mlb->lbnd', q_head_g, target_mapping)
attn_vec_g = self.relative_attention_layer(
q_head_g, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat,
r_w_bias, r_r_bias, r_s_bias, attn_mask_g)
attn_vec_g = tf.einsum('lbnd,mlb->mbnd', attn_vec_g, target_mapping)
output = self.attention_dropout(output)
else:
attn_vec_g = self.relative_attention_layer(
q_head_g, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat,
r_w_bias, r_r_bias, r_s_bias, attn_mask_g)
output = self.output_layer_norm(output + h)
return output
# post processing
output_g = tf.einsum('ibnd,hnd->ibh', attn_vec_g, self.proj_o)
output_g = self.attention_dropout(output_g)
output_g = self.output_layer_norm(output_g + g)
return (output_h, output_g)
class TransformerXLModel(tf.keras.layers.Layer):
@@ -686,20 +565,9 @@ class TransformerXLModel(tf.keras.layers.Layer):
self.fwd_position_embedding = PositionalEmbedding(self.d_model)
self.bwd_position_embedding = PositionalEmbedding(self.d_model)
self.two_stream_layers = []
self.rel_multihead_layers = []
self.g_positionwise_ffn_layers = []
self.h_positionwise_ffn_layers = []
for i in range(self.n_layer):
self.two_stream_layers.append(
TwoStreamRelativeAttention(
d_model=self.d_model,
dropout=self.dropout,
n_head=self.n_head,
d_head=self.d_head,
dropout_att=self.dropout_att,
kernel_initializer=self.initializer,
name='layer_%d/rel_attn' % (i)))
self.rel_multihead_layers.append(
RelativeMultiheadAttention(
d_model=self.d_model,
@@ -709,14 +577,6 @@
dropout_att=self.dropout_att,
kernel_initializer=self.initializer,
name='layer_%d/rel_attn' % (i)))
self.g_positionwise_ffn_layers.append(
PositionwiseFF(
d_model=self.d_model,
d_inner=self.d_inner,
dropout=self.dropout,
kernel_initializer=self.initializer,
activation_type=self.ff_activation,
name='layer_%d/ff' % (i)))
self.h_positionwise_ffn_layers.append(
PositionwiseFF(
d_model=self.d_model,
@@ -825,6 +685,7 @@ class TransformerXLModel(tf.keras.layers.Layer):
word_emb_q = inp_q_ext * self.mask_emb + (1 - inp_q_ext) * word_emb_k
output_h = self.h_dropout(word_emb_k)
output_g = None
if inp_q is not None:
output_g = self.g_dropout(word_emb_q)
@@ -912,44 +773,24 @@ class TransformerXLModel(tf.keras.layers.Layer):
r_s_bias_i = self.r_s_bias if not self.untie_r else self.r_s_bias[i]
seg_embed_i = self.seg_embed[i]
if inp_q is not None:
two_stream_layer = self.two_stream_layers[i]
g_ffn_layer = self.g_positionwise_ffn_layers[i]
h_ffn_layer = self.h_positionwise_ffn_layers[i]
rel_multihead_layer = self.rel_multihead_layers[i]
output_h, output_g = two_stream_layer(
h=output_h,
g=output_g,
r=pos_emb,
r_w_bias=self.r_w_bias if not self.untie_r else self.r_w_bias[i],
r_r_bias=self.r_r_bias if not self.untie_r else self.r_r_bias[i],
seg_mat=seg_mat,
r_s_bias=r_s_bias_i,
seg_embed=seg_embed_i,
attn_mask_h=non_tgt_mask,
attn_mask_g=attn_mask,
mems=mems[i],
target_mapping=target_mapping)
output_g = g_ffn_layer(output_g)
output_h = g_ffn_layer(output_h)
else:
rel_multihead_layer = self.rel_multihead_layers[i]
h_ffn_layer = self.h_positionwise_ffn_layers[i]
output_h = rel_multihead_layer(
h=output_h,
r=pos_emb,
r_w_bias=self.r_w_bias if not self.untie_r else self.r_w_bias[i],
r_r_bias=self.r_r_bias if not self.untie_r else self.r_r_bias[i],
seg_mat=seg_mat,
r_s_bias=r_s_bias_i,
seg_embed=seg_embed_i,
attn_mask=non_tgt_mask,
mems=mems[i])
output_h = h_ffn_layer(output_h)
ffn_layer = self.h_positionwise_ffn_layers[i]
attention_layer = self.rel_multihead_layers[i]
output_h, output_g = attention_layer(
h=output_h,
g=output_g,
r=pos_emb,
r_w_bias=self.r_w_bias if not self.untie_r else self.r_w_bias[i],
r_r_bias=self.r_r_bias if not self.untie_r else self.r_r_bias[i],
seg_mat=seg_mat,
r_s_bias=r_s_bias_i,
seg_embed=seg_embed_i,
attn_mask_h=non_tgt_mask,
attn_mask_g=attn_mask,
mems=mems[i],
target_mapping=target_mapping)
output_h = ffn_layer(output_h)
if output_g is not None:
output_g = ffn_layer(output_g)
if inp_q is not None:
output = output_g
@@ -1156,7 +997,7 @@ class LMLossLayer(tf.keras.layers.Layer):
units=self.d_model,
kernel_initializer=self.initializer,
activation=gelu,
name='lm_projection')
name='lm_projection/dense')
self.proj_layer_norm = tf.keras.layers.LayerNormalization(
axis=-1, epsilon=1e-12, name='lm_projection/LayerNorm')
if not self.tie_weight:
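With both tasks building the same `layer_%d/rel_attn/*` weights (and the projection head now scoped as `lm_projection/dense`), a checkpoint written by one task should restore cleanly into the other's graph. A quick, hedged way to check that the names line up, using `tf.train.list_variables` on two hypothetical checkpoint paths:

```python
import tensorflow as tf

# Hypothetical checkpoint paths; tf.train.list_variables returns (name, shape)
# pairs for every variable stored in a checkpoint.
pretrain_vars = {name for name, _ in tf.train.list_variables('pretrain_ckpt')}
finetune_vars = {name for name, _ in tf.train.list_variables('finetune_ckpt')}

# Names appearing in only one of the two sets are variables that will not
# transfer between the pretrain and finetune graphs.
print('only in pretrain:', sorted(pretrain_vars - finetune_vars))
print('only in finetune:', sorted(finetune_vars - pretrain_vars))
```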