Commit ddd45b81 authored by Jing Li, committed by A. Unique TensorFlower

Fix variable mismatch between XLNet pretrain and finetune tasks.

PiperOrigin-RevId: 274699918
parent bb7d95ab
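
The diff below merges the pretrain-only TwoStreamRelativeAttention layer into RelativeMultiheadAttention, so the pretrain and finetune graphs build one shared set of attention and feed-forward variables; the merged layer runs the query ("g") stream only when `g` is passed. As a rough illustration of that dispatch, here is a toy stand-in (not the code in this commit; the class name, shapes, and inputs are invented for the example):

import tensorflow as tf


class TwoStreamDispatchDemo(tf.Module):
  """Toy stand-in mirroring the merged layer's control flow: one shared
  weight set, with the query ("g") stream computed only when `g` is given."""

  def __init__(self, d_model, name=None):
    super().__init__(name=name)
    # Single projection shared by both streams, as in the merged layer.
    self.kernel = tf.Variable(
        tf.random.normal([d_model, d_model]), name='o/kernel')

  def __call__(self, h, g=None):
    output_h = tf.einsum('ibh,hd->ibd', h, self.kernel)   # content stream
    output_g = (tf.einsum('ibh,hd->ibd', g, self.kernel)  # query stream,
                if g is not None else None)               # pretraining only
    return output_h, output_g


demo = TwoStreamDispatchDemo(d_model=8)
h = tf.random.normal([4, 2, 8])              # [seq_len, batch, d_model]

out_h, out_g = demo(h)                                    # finetune-style call: out_g is None
out_h, out_g = demo(h, g=tf.random.normal([4, 2, 8]))     # pretrain-style call

# Either call pattern uses the same variable, so checkpoints line up.
print([v.name for v in demo.variables])
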
@@ -314,143 +314,6 @@ class EmbeddingLookup(tf.keras.layers.Layer):
     return tf.nn.embedding_lookup(self.lookup_table, inputs)
 
 
-class TwoStreamRelativeAttention(tf.keras.layers.Layer):
-  """Two-stream attention layer with relative positional encoding."""
-
-  def __init__(self, d_model, n_head, d_head, dropout, dropout_att,
-               kernel_initializer, **kwargs):
-    super(TwoStreamRelativeAttention, self).__init__(**kwargs)
-    self.d_model = d_model
-    self.n_head = n_head
-    self.d_head = d_head
-    self.dropout = dropout
-    self.dropout_att = dropout_att
-    self.initializer = kernel_initializer
-
-  def build(self, unused_input_shapes):
-    """Implements build() for the layer."""
-    self.scale = 1.0 / (self.d_head**0.5)
-    self.attention_projection_layer = tf.keras.layers.Dense(
-        units=self.d_model,
-        use_bias=False,
-        kernel_initializer=self.initializer,
-        name='o')
-    self.attention_probs_dropout = tf.keras.layers.Dropout(
-        rate=self.dropout_att)
-    self.attention_out_dropout = tf.keras.layers.Dropout(rate=self.dropout)
-    self.output_layer_norm = tf.keras.layers.LayerNormalization(
-        name='LayerNorm', axis=-1, epsilon=1e-12)
-    self.kh_projection_layer = (
-        self.add_weight(
-            'k/kernel',
-            shape=[self.d_model, self.n_head, self.d_head],
-            initializer=self.initializer))
-    self.vh_projection_layer = (
-        self.add_weight(
-            'v/kernel',
-            shape=[self.d_model, self.n_head, self.d_head],
-            initializer=self.initializer))
-    self.kr_projection_layer = (
-        self.add_weight(
-            'r/kernel',
-            shape=[self.d_model, self.n_head, self.d_head],
-            initializer=self.initializer))
-    self.qh_projection_layer = (
-        self.add_weight(
-            'q/kernel',
-            shape=[self.d_model, self.n_head, self.d_head],
-            initializer=self.initializer))
-    self.h_attention_layer = RelativeAttention(
-        dropout_att=self.dropout_att, scale=self.scale)
-    self.g_attention_layer = RelativeAttention(
-        dropout_att=self.dropout_att, scale=self.scale)
-    self.proj_o = (
-        self.add_weight(
-            'o/kernel',
-            shape=[self.d_model, self.n_head, self.d_head],
-            initializer=self.initializer))
-    self.attention_dropout = tf.keras.layers.Dropout(rate=self.dropout)
-
-    super(TwoStreamRelativeAttention, self).build(unused_input_shapes)
-
-  def __call__(self, h, g, r, r_w_bias, r_r_bias, seg_mat, r_s_bias, seg_embed,
-               attn_mask_h, attn_mask_g, mems, target_mapping):
-    inputs = pack_inputs([
-        h, g, r, r_w_bias, r_r_bias, seg_mat, r_s_bias, seg_embed, attn_mask_h,
-        attn_mask_g, mems, target_mapping
-    ])
-    return super(TwoStreamRelativeAttention, self).__call__(inputs)
-
-  def call(self, inputs):
-    """Implements call() for the layer."""
-    (h, g, r, r_w_bias, r_r_bias, seg_mat, r_s_bias, seg_embed, attn_mask_h,
-     attn_mask_g, mems, target_mapping) = unpack_inputs(inputs)
-
-    if mems is not None and mems.shape.ndims > 1:
-      cat = tf.concat([mems, h], 0)
-    else:
-      cat = h
-
-    # content heads
-    k_head_h = tf.einsum('ibh,hnd->ibnd', cat, self.kh_projection_layer)
-    v_head_h = tf.einsum('ibh,hnd->ibnd', cat, self.vh_projection_layer)
-    k_head_r = tf.einsum('ibh,hnd->ibnd', r, self.kr_projection_layer)
-
-    # positional heads
-    q_head_h = tf.einsum('ibh,hnd->ibnd', h, self.qh_projection_layer)
-
-    # core attention ops
-    attn_vec_h = self.h_attention_layer(q_head_h, k_head_h, v_head_h, k_head_r,
-                                        seg_embed, seg_mat, r_w_bias, r_r_bias,
-                                        r_s_bias, attn_mask_h)
-
-    output_h = tf.einsum('ibnd,hnd->ibh', attn_vec_h, self.proj_o)
-    output_h = self.attention_dropout(output_h)
-    output_h = self.output_layer_norm(output_h + h)
-
-    ##### g-stream
-    # query-stream query head
-    q_head_g = tf.einsum('ibh,hnd->ibnd', g, self.qh_projection_layer)
-
-    # core attention ops
-    if target_mapping is not None:
-      q_head_g = tf.einsum('mbnd,mlb->lbnd', q_head_g, target_mapping)
-      attn_vec_g = self.g_attention_layer(q_head_g, k_head_h, v_head_h,
-                                          k_head_r, seg_embed, seg_mat,
-                                          r_w_bias, r_r_bias, r_s_bias,
-                                          attn_mask_g)
-      attn_vec_g = tf.einsum('lbnd,mlb->mbnd', attn_vec_g, target_mapping)
-    else:
-      attn_vec_g = self.g_attention_layer(q_head_g, k_head_h, v_head_h,
-                                          k_head_r, seg_embed, seg_mat,
-                                          r_w_bias, r_r_bias, r_s_bias,
-                                          attn_mask_g)
-
-    # post processing
-    output_g = tf.einsum('ibnd,hnd->ibh', attn_vec_g, self.proj_o)
-    output_g = self.attention_dropout(output_g)
-    output_g = self.output_layer_norm(output_g + g)
-
-    return output_h, output_g
-
-
 class RelativeMultiheadAttention(tf.keras.layers.Layer):
   """Multi-head attention with relative embedding."""
@@ -488,7 +351,7 @@ class RelativeMultiheadAttention(tf.keras.layers.Layer):
         shape=[self.d_model, self.n_head, self.d_head],
         initializer=self.initializer)
 
-    self.h_attention_layer = RelativeAttention(
+    self.relative_attention_layer = RelativeAttention(
         dropout_att=self.dropout_att, scale=self.scale)
 
     self.proj_o = self.add_weight(
@@ -500,17 +363,18 @@ class RelativeMultiheadAttention(tf.keras.layers.Layer):
     super(RelativeMultiheadAttention, self).build(unused_input_shapes)
 
-  def __call__(self, h, r, r_w_bias, r_r_bias, seg_mat, r_s_bias, seg_embed,
-               attn_mask, mems):
+  def __call__(self, h, g, r, r_w_bias, r_r_bias, seg_mat, r_s_bias, seg_embed,
+               attn_mask_h, attn_mask_g, mems, target_mapping):
     inputs = pack_inputs([
-        h, r, r_w_bias, r_r_bias, seg_mat, r_s_bias, seg_embed, attn_mask, mems
+        h, g, r, r_w_bias, r_r_bias, seg_mat, r_s_bias, seg_embed, attn_mask_h,
+        attn_mask_g, mems, target_mapping,
     ])
     return super(RelativeMultiheadAttention, self).__call__(inputs)
 
   def call(self, inputs):
     """Implements call() for the layer."""
-    (h, r, r_w_bias, r_r_bias, seg_mat, r_s_bias, seg_embed, attn_mask,
-     mems) = unpack_inputs(inputs)
+    (h, g, r, r_w_bias, r_r_bias, seg_mat, r_s_bias, seg_embed, attn_mask_h,
+     attn_mask_g, mems, target_mapping) = unpack_inputs(inputs)
 
     if mems is not None and mems.shape.ndims > 1:
       cat = tf.concat([mems, h], 0)
@@ -518,30 +382,45 @@ class RelativeMultiheadAttention(tf.keras.layers.Layer):
       cat = h
 
     # content heads
     q_head_h = tf.einsum('ibh,hnd->ibnd', h, self.qh_projection_layer)
     k_head_h = tf.einsum('ibh,hnd->ibnd', cat, self.kh_projection_layer)
     v_head_h = tf.einsum('ibh,hnd->ibnd', cat, self.vh_projection_layer)
 
     # positional heads
     k_head_r = tf.einsum('ibh,hnd->ibnd', r, self.kr_projection_layer)
 
     # core attention ops
-    attn_vec = self.h_attention_layer(q_head_h, k_head_h, v_head_h, k_head_r,
-                                      seg_embed, seg_mat, r_w_bias, r_r_bias,
-                                      r_s_bias, attn_mask)
+    attn_vec_h = self.relative_attention_layer(
+        q_head_h, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat, r_w_bias,
+        r_r_bias, r_s_bias, attn_mask_h)
 
     # post processing
-    output = tf.einsum('ibnd,hnd->ibh', attn_vec, self.proj_o)
-    output = self.attention_dropout(output)
-    output = self.output_layer_norm(output + h)
-    return output
+    output_h = tf.einsum('ibnd,hnd->ibh', attn_vec_h, self.proj_o)
+    output_h = self.attention_dropout(output_h)
+    output_h = self.output_layer_norm(output_h + h)
+
+    output_g = None
+    if g is not None:  # enable two-stream attention
+      # g-stream
+      q_head_g = tf.einsum('ibh,hnd->ibnd', g, self.qh_projection_layer)
+      if target_mapping is not None:
+        q_head_g = tf.einsum('mbnd,mlb->lbnd', q_head_g, target_mapping)
+        attn_vec_g = self.relative_attention_layer(
+            q_head_g, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat,
+            r_w_bias, r_r_bias, r_s_bias, attn_mask_g)
+        attn_vec_g = tf.einsum('lbnd,mlb->mbnd', attn_vec_g, target_mapping)
+      else:
+        attn_vec_g = self.relative_attention_layer(
+            q_head_g, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat,
+            r_w_bias, r_r_bias, r_s_bias, attn_mask_g)
+
+      # post processing
+      output_g = tf.einsum('ibnd,hnd->ibh', attn_vec_g, self.proj_o)
+      output_g = self.attention_dropout(output_g)
+      output_g = self.output_layer_norm(output_g + g)
+
+    return (output_h, output_g)
 
 
 class TransformerXLModel(tf.keras.layers.Layer):
@@ -686,20 +565,9 @@ class TransformerXLModel(tf.keras.layers.Layer):
     self.fwd_position_embedding = PositionalEmbedding(self.d_model)
     self.bwd_position_embedding = PositionalEmbedding(self.d_model)
 
-    self.two_stream_layers = []
     self.rel_multihead_layers = []
-    self.g_positionwise_ffn_layers = []
     self.h_positionwise_ffn_layers = []
     for i in range(self.n_layer):
-      self.two_stream_layers.append(
-          TwoStreamRelativeAttention(
-              d_model=self.d_model,
-              dropout=self.dropout,
-              n_head=self.n_head,
-              d_head=self.d_head,
-              dropout_att=self.dropout_att,
-              kernel_initializer=self.initializer,
-              name='layer_%d/rel_attn' % (i)))
       self.rel_multihead_layers.append(
           RelativeMultiheadAttention(
               d_model=self.d_model,
@@ -709,14 +577,6 @@ class TransformerXLModel(tf.keras.layers.Layer):
              dropout_att=self.dropout_att,
              kernel_initializer=self.initializer,
              name='layer_%d/rel_attn' % (i)))
-      self.g_positionwise_ffn_layers.append(
-          PositionwiseFF(
-              d_model=self.d_model,
-              d_inner=self.d_inner,
-              dropout=self.dropout,
-              kernel_initializer=self.initializer,
-              activation_type=self.ff_activation,
-              name='layer_%d/ff' % (i)))
       self.h_positionwise_ffn_layers.append(
           PositionwiseFF(
              d_model=self.d_model,
@@ -825,6 +685,7 @@ class TransformerXLModel(tf.keras.layers.Layer):
       word_emb_q = inp_q_ext * self.mask_emb + (1 - inp_q_ext) * word_emb_k
 
     output_h = self.h_dropout(word_emb_k)
+    output_g = None
     if inp_q is not None:
       output_g = self.g_dropout(word_emb_q)
@@ -912,13 +773,9 @@ class TransformerXLModel(tf.keras.layers.Layer):
       r_s_bias_i = self.r_s_bias if not self.untie_r else self.r_s_bias[i]
       seg_embed_i = self.seg_embed[i]
 
-      if inp_q is not None:
-        two_stream_layer = self.two_stream_layers[i]
-        g_ffn_layer = self.g_positionwise_ffn_layers[i]
-        h_ffn_layer = self.h_positionwise_ffn_layers[i]
-        rel_multihead_layer = self.rel_multihead_layers[i]
-        output_h, output_g = two_stream_layer(
+      ffn_layer = self.h_positionwise_ffn_layers[i]
+      attention_layer = self.rel_multihead_layers[i]
+      output_h, output_g = attention_layer(
           h=output_h,
           g=output_g,
           r=pos_emb,
@@ -931,25 +788,9 @@ class TransformerXLModel(tf.keras.layers.Layer):
           attn_mask_g=attn_mask,
           mems=mems[i],
           target_mapping=target_mapping)
-        output_g = g_ffn_layer(output_g)
-        output_h = g_ffn_layer(output_h)
-      else:
-        rel_multihead_layer = self.rel_multihead_layers[i]
-        h_ffn_layer = self.h_positionwise_ffn_layers[i]
-        output_h = rel_multihead_layer(
-            h=output_h,
-            r=pos_emb,
-            r_w_bias=self.r_w_bias if not self.untie_r else self.r_w_bias[i],
-            r_r_bias=self.r_r_bias if not self.untie_r else self.r_r_bias[i],
-            seg_mat=seg_mat,
-            r_s_bias=r_s_bias_i,
-            seg_embed=seg_embed_i,
-            attn_mask=non_tgt_mask,
-            mems=mems[i])
-        output_h = h_ffn_layer(output_h)
+      output_h = ffn_layer(output_h)
+      if output_g is not None:
+        output_g = ffn_layer(output_g)
 
     if inp_q is not None:
       output = output_g
@@ -1156,7 +997,7 @@ class LMLossLayer(tf.keras.layers.Layer):
           units=self.d_model,
           kernel_initializer=self.initializer,
           activation=gelu,
-          name='lm_projection')
+          name='lm_projection/dense')
       self.proj_layer_norm = tf.keras.layers.LayerNormalization(
           axis=-1, epsilon=1e-12, name='lm_projection/LayerNorm')
     if not self.tie_weight:
...
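
For context on why the shared variable names matter: a checkpoint written by the pretraining program can only restore cleanly into the finetuning program if both construct matching variables. Below is a minimal, self-contained sketch of that requirement, using a generic tf.train.Checkpoint round trip rather than this repository's actual training code (the layer names and the /tmp path are made up):

import tensorflow as tf


def build_backbone():
  # Both the "pretrain" and "finetune" programs construct the identical stack,
  # so every checkpointed variable has a counterpart on restore.
  return tf.keras.Sequential(
      [tf.keras.layers.Dense(8, name='rel_attn'),
       tf.keras.layers.Dense(8, name='ff')],
      name='backbone')


pretrain_net = build_backbone()
pretrain_net(tf.zeros([1, 8]))                 # create the variables
ckpt_path = tf.train.Checkpoint(model=pretrain_net).save('/tmp/xlnet_demo/ckpt')

finetune_net = build_backbone()
finetune_net(tf.zeros([1, 8]))
status = tf.train.Checkpoint(model=finetune_net).restore(ckpt_path)
# Passes only because the two programs build matching layers and variables.
status.assert_existing_objects_matched()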