Commit 569ec532 authored by Hongkun Yu, committed by A. Unique TensorFlower

XLNet: Remove pack/unpack hack

PiperOrigin-RevId: 291750235
parent 6a3bcef8
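
For context: the wrappers removed below existed only to squeeze a layer's tensors through a single inputs argument, using the tf_utils.pack_inputs / tf_utils.unpack_inputs helpers visible in the deleted lines. Current tf.keras forwards extra positional arguments from Layer.__call__ straight to call(), so each layer can simply declare its real signature. A minimal sketch of that mechanism with a toy layer (illustrative only, not code from this repo):

import tensorflow as tf


class ScaledSum(tf.keras.layers.Layer):
  """Toy layer whose call() takes several tensors directly."""

  def call(self, x, y, scale):
    # No pack/unpack step: Keras passes all three arguments through to call().
    return (x + y) * scale


layer = ScaledSum()
out = layer(tf.ones([2, 3]), tf.ones([2, 3]), tf.constant(0.5))  # shape [2, 3]
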
@@ -23,7 +23,6 @@ import copy
 import numpy as np
 import tensorflow as tf
-from official.modeling import tf_utils
 from official.nlp.xlnet import data_utils
@@ -115,14 +114,8 @@ class PositionalEmbedding(tf.keras.layers.Layer):
     self.inv_freq = 1.0 / (10000.0**(tf.range(0, self.dim, 2.0) / self.dim))
     super(PositionalEmbedding, self).build(unused_input_shapes)
-  def __call__(self, pos_seq, batch_size, **kwargs):
-    return super(PositionalEmbedding, self).__call__(
-        (pos_seq, batch_size), **kwargs)
-  def call(self, inputs):
+  def call(self, pos_seq, batch_size):
     """Implements call() for the layer."""
-    pos_seq, batch_size = inputs
     sinusoid_inp = tf.einsum('i,d->id', pos_seq, self.inv_freq)
     pos_emb = tf.concat([tf.sin(sinusoid_inp), tf.cos(sinusoid_inp)], -1)
     pos_emb = pos_emb[:, None, :]
@@ -149,18 +142,9 @@ class RelativeAttention(tf.keras.layers.Layer):
     super(RelativeAttention, self).build(unused_input_shapes)
-  def __call__(self, q_head, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat,
-               r_w_bias, r_r_bias, r_s_bias, attn_mask, **kwargs):
-    inputs = tf_utils.pack_inputs([
-        q_head, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat, r_w_bias,
-        r_r_bias, r_s_bias, attn_mask
-    ])
-    return super(RelativeAttention, self).__call__(inputs, **kwargs)
-  def call(self, inputs):
+  def call(self, q_head, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat,
+           r_w_bias, r_r_bias, r_s_bias, attn_mask):
     """Implements call() for the layer."""
-    (q_head, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat, r_w_bias,
-     r_r_bias, r_s_bias, attn_mask) = tf_utils.unpack_inputs(inputs)
     # content based attention score
     ac = tf.einsum('ibnd,jbnd->ijbn', q_head + r_w_bias, k_head_h)
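
As a reading aid for the attention einsums in this layer (not part of the change): the index letters follow XLNet's layout, with i/j the query/key positions, b the batch, n the attention heads, and d the per-head depth. A small shape check under those assumptions:

import tensorflow as tf

q_head = tf.random.normal([3, 2, 4, 8])    # [qlen, batch, n_head, d_head] -> 'ibnd'
k_head_h = tf.random.normal([5, 2, 4, 8])  # [klen, batch, n_head, d_head] -> 'jbnd'
ac = tf.einsum('ibnd,jbnd->ijbn', q_head, k_head_h)
print(ac.shape)  # (3, 5, 2, 4): one content-based score per (query, key, batch, head)
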
@@ -316,18 +300,9 @@ class RelativeMultiheadAttention(tf.keras.layers.Layer):
     super(RelativeMultiheadAttention, self).build(unused_input_shapes)
-  def __call__(self, h, g, r, r_w_bias, r_r_bias, seg_mat, r_s_bias, seg_embed,
-               attn_mask_h, attn_mask_g, mems, target_mapping, **kwargs):
-    inputs = tf_utils.pack_inputs([
-        h, g, r, r_w_bias, r_r_bias, seg_mat, r_s_bias, seg_embed, attn_mask_h,
-        attn_mask_g, mems, target_mapping,
-    ])
-    return super(RelativeMultiheadAttention, self).__call__(inputs, **kwargs)
-  def call(self, inputs):
+  def call(self, h, g, r, r_w_bias, r_r_bias, seg_mat, r_s_bias, seg_embed,
+           attn_mask_h, attn_mask_g, mems, target_mapping):
     """Implements call() for the layer."""
-    (h, g, r, r_w_bias, r_r_bias, seg_mat, r_s_bias, seg_embed, attn_mask_h,
-     attn_mask_g, mems, target_mapping) = tf_utils.unpack_inputs(inputs)
     if mems is not None and mems.shape.ndims > 1:
       cat = tf.concat([mems, h], 0)
@@ -343,9 +318,10 @@ class RelativeMultiheadAttention(tf.keras.layers.Layer):
     k_head_r = tf.einsum('ibh,hnd->ibnd', r, self.kr_projection_layer)
     # core attention ops
-    attn_vec_h = self.relative_attention_layer(
-        q_head_h, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat, r_w_bias,
-        r_r_bias, r_s_bias, attn_mask_h)
+    attn_vec_h = self.relative_attention_layer(q_head_h, k_head_h, v_head_h,
+                                               k_head_r, seg_embed, seg_mat,
+                                               r_w_bias, r_r_bias, r_s_bias,
+                                               attn_mask_h)
     # post processing
     output_h = tf.einsum('ibnd,hnd->ibh', attn_vec_h, self.proj_o)
@@ -358,15 +334,17 @@ class RelativeMultiheadAttention(tf.keras.layers.Layer):
     q_head_g = tf.einsum('ibh,hnd->ibnd', g, self.qh_projection_layer)
     if target_mapping is not None:
       q_head_g = tf.einsum('mbnd,mlb->lbnd', q_head_g, target_mapping)
-      attn_vec_g = self.relative_attention_layer(
-          q_head_g, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat,
-          r_w_bias, r_r_bias, r_s_bias, attn_mask_g)
+      attn_vec_g = self.relative_attention_layer(q_head_g, k_head_h, v_head_h,
+                                                 k_head_r, seg_embed, seg_mat,
+                                                 r_w_bias, r_r_bias, r_s_bias,
+                                                 attn_mask_g)
       attn_vec_g = tf.einsum('lbnd,mlb->mbnd', attn_vec_g, target_mapping)
     else:
-      attn_vec_g = self.relative_attention_layer(
-          q_head_g, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat,
-          r_w_bias, r_r_bias, r_s_bias, attn_mask_g)
+      attn_vec_g = self.relative_attention_layer(q_head_g, k_head_h, v_head_h,
+                                                 k_head_r, seg_embed, seg_mat,
+                                                 r_w_bias, r_r_bias, r_s_bias,
+                                                 attn_mask_g)
     # post processing
     output_g = tf.einsum('ibnd,hnd->ibh', attn_vec_g, self.proj_o)
@@ -820,7 +798,7 @@ class PretrainingXLNetModel(tf.keras.Model):
     mems = features.get('mems', None)
     transformerxl_output, self.new_mems, self.lookup_table = self.transformerxl_model(
-        inp_k=input_ids,
+        input_ids,
         seg_id=seg_ids,
         input_mask=None,
         mems=mems,
@@ -898,12 +876,10 @@ class ClassificationXLNetModel(tf.keras.Model):
     mems = features.get('mems', None)
     transformerxl_output, new_mems, self.lookup_table = (
-        self.transformerxl_model(
-            inp_k=input_ids, seg_id=seg_ids, input_mask=input_mask, mems=mems))
+        self.transformerxl_model(input_ids, seg_ids, input_mask, mems))
     summary = self.summarization_layer(transformerxl_output)
-    per_example_loss, logits = self.cl_loss_layer(
-        hidden=summary, labels=label)
+    per_example_loss, logits = self.cl_loss_layer(hidden=summary, labels=label)
     self.add_loss(tf.keras.backend.mean(per_example_loss))
     return new_mems, logits
@@ -965,13 +941,8 @@ class LMLossLayer(tf.keras.layers.Layer):
     super(LMLossLayer, self).build(unused_input_shapes)
-  def __call__(self, hidden, target, lookup_table, target_mask, **kwargs):
-    inputs = tf_utils.pack_inputs([hidden, target, lookup_table, target_mask])
-    return super(LMLossLayer, self).__call__(inputs, **kwargs)
-  def call(self, inputs):
+  def call(self, hidden, target, lookup_table, target_mask):
     """Implements call() for the layer."""
-    (hidden, target, lookup_table, tgt_mask) = tf_utils.unpack_inputs(inputs)
     if self.use_proj:
       hidden = self.proj_layer_norm(self.proj_layer(hidden))
     if self.tie_weight:
@@ -986,7 +957,7 @@ class LMLossLayer(tf.keras.layers.Layer):
     loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
         labels=target, logits=logits)
-    total_loss = tf.reduce_sum(loss * tgt_mask) / tf.reduce_sum(tgt_mask)
+    total_loss = tf.reduce_sum(loss * target_mask) / tf.reduce_sum(target_mask)
     return total_loss, logits
@@ -1076,13 +1047,8 @@ class ClassificationLossLayer(tf.keras.layers.Layer):
     super(ClassificationLossLayer, self).build(unused_input_shapes)
-  def __call__(self, hidden, labels, **kwargs):
-    inputs = tf_utils.pack_inputs([hidden, labels])
-    return super(ClassificationLossLayer, self).__call__(inputs, **kwargs)
-  def call(self, inputs):
+  def call(self, hidden, labels):
     """Implements call() for the layer."""
-    (hidden, labels) = tf_utils.unpack_inputs(inputs)
     logits = self.proj_layer(hidden)
     one_hot_target = tf.one_hot(labels, self.n_class, dtype=hidden.dtype)  # pytype: disable=attribute-error
@@ -1145,8 +1111,7 @@ class QAXLNetModel(tf.keras.Model):
     p_mask = features['p_mask']
     transformerxl_output, new_mems, self.lookup_table = (
-        self.transformerxl_model(
-            inp_k=input_ids, seg_id=seg_ids, input_mask=input_mask))
+        self.transformerxl_model(input_ids, seg_ids, input_mask))
     if training:
       loss, logits = self.qa_loss_layer(
@@ -43,8 +43,7 @@ class PositionalEmbeddingLayerTest(tf.test.TestCase):
     d_model = 4
     pos_seq = tf.range(1, -1, -1.0)  # [1., 0.]
     pos_emb_layer = xlnet_modeling.PositionalEmbedding(d_model)
-    pos_emb = pos_emb_layer(
-        pos_seq=pos_seq, batch_size=None).numpy().astype(float)
+    pos_emb = pos_emb_layer(pos_seq, batch_size=None).numpy().astype(float)
     logging.info(pos_emb)
     self.assertAllClose(pos_emb, target)
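
With the wrapper gone, callers invoke the layers with their natural signatures, as the updated test above does. A hedged standalone sketch of the same call, assuming the module is importable as official.nlp.xlnet.xlnet_modeling (the import path suggested by the main file):

import tensorflow as tf
from official.nlp.xlnet import xlnet_modeling

layer = xlnet_modeling.PositionalEmbedding(4)            # d_model = 4
pos_emb = layer(tf.range(1, -1, -1.0), batch_size=None)  # sinusoidal embedding,
                                                         # expected shape [seq_len, 1, d_model]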