ModelZoo / ResNet50_tensorflow

Commit 569ec532
Authored Jan 27, 2020 by Hongkun Yu
Committed Jan 27, 2020 by A. Unique TensorFlower

XLNet: Remove pack/unpack hack

PiperOrigin-RevId: 291750235
parent 6a3bcef8

Showing 2 changed files with 25 additions and 61 deletions:

official/nlp/xlnet/xlnet_modeling.py       +24 -59
official/nlp/xlnet/xlnet_modeling_test.py   +1 -2
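For context, the "pack/unpack hack" named in the commit title was a workaround for Keras layers written to accept a single inputs argument: each layer overrode __call__ to pack its real arguments with tf_utils.pack_inputs, then unpacked them again inside call(). The diff below deletes that indirection so call() takes the arguments directly. A minimal before/after sketch of the pattern (simplified and hypothetical, not the actual XLNet layer code):

import tensorflow as tf
from official.modeling import tf_utils  # only needed for the "before" version


class PackedLayer(tf.keras.layers.Layer):
  """Before: arguments are packed into one structure, then unpacked in call()."""

  def __call__(self, hidden, labels, **kwargs):
    inputs = tf_utils.pack_inputs([hidden, labels])
    return super(PackedLayer, self).__call__(inputs, **kwargs)

  def call(self, inputs):
    hidden, labels = tf_utils.unpack_inputs(inputs)
    return hidden, labels


class DirectLayer(tf.keras.layers.Layer):
  """After: call() takes the arguments directly; no __call__ override needed."""

  def call(self, hidden, labels):
    return hidden, labels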
official/nlp/xlnet/xlnet_modeling.py
@@ -23,7 +23,6 @@ import copy
 import numpy as np
 import tensorflow as tf
-from official.modeling import tf_utils
 from official.nlp.xlnet import data_utils
@@ -115,14 +114,8 @@ class PositionalEmbedding(tf.keras.layers.Layer):
     self.inv_freq = 1.0 / (10000.0**(tf.range(0, self.dim, 2.0) / self.dim))
     super(PositionalEmbedding, self).build(unused_input_shapes)
 
-  def __call__(self, pos_seq, batch_size, **kwargs):
-    return super(PositionalEmbedding, self).__call__((pos_seq, batch_size),
-                                                     **kwargs)
-
-  def call(self, inputs):
+  def call(self, pos_seq, batch_size):
     """Implements call() for the layer."""
-    pos_seq, batch_size = inputs
     sinusoid_inp = tf.einsum('i,d->id', pos_seq, self.inv_freq)
     pos_emb = tf.concat([tf.sin(sinusoid_inp), tf.cos(sinusoid_inp)], -1)
     pos_emb = pos_emb[:, None, :]
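After this change the positional-embedding layer is called with pos_seq and batch_size as ordinary arguments, exactly as the updated unit test below does. A short usage sketch (the expected output shape is an inference from the code above, not stated in the diff):

import tensorflow as tf
from official.nlp.xlnet import xlnet_modeling

pos_emb_layer = xlnet_modeling.PositionalEmbedding(4)  # dim = 4
pos_seq = tf.range(1, -1, -1.0)                        # positions [1., 0.]
pos_emb = pos_emb_layer(pos_seq, batch_size=None)
# sin/cos of pos_seq x inv_freq, concatenated on the last axis and expanded
# to [seq_len, 1, dim]; with dim=4 and two positions this should be (2, 1, 4).
print(pos_emb.shape)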
@@ -149,18 +142,9 @@ class RelativeAttention(tf.keras.layers.Layer):
     super(RelativeAttention, self).build(unused_input_shapes)
 
-  def __call__(self, q_head, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat,
-               r_w_bias, r_r_bias, r_s_bias, attn_mask, **kwargs):
-    inputs = tf_utils.pack_inputs([
-        q_head, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat, r_w_bias,
-        r_r_bias, r_s_bias, attn_mask
-    ])
-    return super(RelativeAttention, self).__call__(inputs, **kwargs)
-
-  def call(self, inputs):
+  def call(self, q_head, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat,
+           r_w_bias, r_r_bias, r_s_bias, attn_mask):
     """Implements call() for the layer."""
-    (q_head, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat, r_w_bias,
-     r_r_bias, r_s_bias, attn_mask) = tf_utils.unpack_inputs(inputs)
 
     # content based attention score
     ac = tf.einsum('ibnd,jbnd->ijbn', q_head + r_w_bias, k_head_h)
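The 'ibnd,jbnd->ijbn' contraction above produces one content-based score per (query position, key position, batch, head), with r_w_bias broadcasting over positions and batch. A self-contained shape check; the dimension values here are hypothetical:

import tensorflow as tf

# i/j = query/key length, b = batch, n = heads, d = per-head width (all made up).
i, j, b, n, d = 3, 5, 2, 4, 8
q_head = tf.random.normal([i, b, n, d])
k_head_h = tf.random.normal([j, b, n, d])
r_w_bias = tf.random.normal([n, d])  # broadcasts against q_head

ac = tf.einsum('ibnd,jbnd->ijbn', q_head + r_w_bias, k_head_h)
print(ac.shape)  # (3, 5, 2, 4)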
@@ -316,18 +300,9 @@ class RelativeMultiheadAttention(tf.keras.layers.Layer):
     super(RelativeMultiheadAttention, self).build(unused_input_shapes)
 
-  def __call__(self, h, g, r, r_w_bias, r_r_bias, seg_mat, r_s_bias, seg_embed,
-               attn_mask_h, attn_mask_g, mems, target_mapping, **kwargs):
-    inputs = tf_utils.pack_inputs([
-        h, g, r, r_w_bias, r_r_bias, seg_mat, r_s_bias, seg_embed, attn_mask_h,
-        attn_mask_g, mems, target_mapping,
-    ])
-    return super(RelativeMultiheadAttention, self).__call__(inputs, **kwargs)
-
-  def call(self, inputs):
+  def call(self, h, g, r, r_w_bias, r_r_bias, seg_mat, r_s_bias, seg_embed,
+           attn_mask_h, attn_mask_g, mems, target_mapping):
     """Implements call() for the layer."""
-    (h, g, r, r_w_bias, r_r_bias, seg_mat, r_s_bias, seg_embed, attn_mask_h,
-     attn_mask_g, mems, target_mapping) = tf_utils.unpack_inputs(inputs)
 
     if mems is not None and mems.shape.ndims > 1:
       cat = tf.concat([mems, h], 0)
@@ -343,9 +318,10 @@ class RelativeMultiheadAttention(tf.keras.layers.Layer):
     k_head_r = tf.einsum('ibh,hnd->ibnd', r, self.kr_projection_layer)
 
     # core attention ops
-    attn_vec_h = self.relative_attention_layer(
-        q_head_h, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat, r_w_bias,
-        r_r_bias, r_s_bias, attn_mask_h)
+    attn_vec_h = self.relative_attention_layer(q_head_h, k_head_h, v_head_h,
+                                               k_head_r, seg_embed, seg_mat,
+                                               r_w_bias, r_r_bias, r_s_bias,
+                                               attn_mask_h)
 
     # post processing
     output_h = tf.einsum('ibnd,hnd->ibh', attn_vec_h, self.proj_o)
@@ -358,15 +334,17 @@ class RelativeMultiheadAttention(tf.keras.layers.Layer):
     q_head_g = tf.einsum('ibh,hnd->ibnd', g, self.qh_projection_layer)
 
     if target_mapping is not None:
       q_head_g = tf.einsum('mbnd,mlb->lbnd', q_head_g, target_mapping)
-      attn_vec_g = self.relative_attention_layer(
-          q_head_g, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat, r_w_bias,
-          r_r_bias, r_s_bias, attn_mask_g)
+      attn_vec_g = self.relative_attention_layer(q_head_g, k_head_h, v_head_h,
+                                                 k_head_r, seg_embed, seg_mat,
+                                                 r_w_bias, r_r_bias, r_s_bias,
+                                                 attn_mask_g)
       attn_vec_g = tf.einsum('lbnd,mlb->mbnd', attn_vec_g, target_mapping)
     else:
-      attn_vec_g = self.relative_attention_layer(
-          q_head_g, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat, r_w_bias,
-          r_r_bias, r_s_bias, attn_mask_g)
+      attn_vec_g = self.relative_attention_layer(q_head_g, k_head_h, v_head_h,
+                                                 k_head_r, seg_embed, seg_mat,
+                                                 r_w_bias, r_r_bias, r_s_bias,
+                                                 attn_mask_g)
 
     # post processing
     output_g = tf.einsum('ibnd,hnd->ibh', attn_vec_g, self.proj_o)
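In the target_mapping branch above, the two einsums do shape bookkeeping for the query stream: 'mbnd,mlb->lbnd' spreads the per-target query heads onto their full-sequence positions before attention, and 'lbnd,mlb->mbnd' gathers the attended vectors back per target afterwards. A small sketch with hypothetical sizes:

import tensorflow as tf

# m = prediction targets, l = sequence length, b = batch, n = heads, d = head dim.
m, l, b, n, d = 2, 5, 1, 4, 8
q_head_g = tf.random.normal([m, b, n, d])
# One-hot map from each target to its position in the sequence, shape [m, l, b].
target_mapping = tf.one_hot([1, 3], depth=l)[:, :, None]

q_full = tf.einsum('mbnd,mlb->lbnd', q_head_g, target_mapping)
print(q_full.shape)           # (5, 1, 4, 8): queries laid out over the sequence

attn_full = tf.random.normal([l, b, n, d])  # stand-in for the attention output
attn_per_target = tf.einsum('lbnd,mlb->mbnd', attn_full, target_mapping)
print(attn_per_target.shape)  # (2, 1, 4, 8): back to one vector per target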
@@ -820,7 +798,7 @@ class PretrainingXLNetModel(tf.keras.Model):
     mems = features.get('mems', None)
 
     transformerxl_output, self.new_mems, self.lookup_table = self.transformerxl_model(
-        inp_k=input_ids,
+        input_ids,
         seg_id=seg_ids,
         input_mask=None,
         mems=mems,
@@ -898,12 +876,10 @@ class ClassificationXLNetModel(tf.keras.Model):
     mems = features.get('mems', None)
 
     transformerxl_output, new_mems, self.lookup_table = (
-        self.transformerxl_model(
-            inp_k=input_ids, seg_id=seg_ids, input_mask=input_mask, mems=mems))
+        self.transformerxl_model(input_ids, seg_ids, input_mask, mems))
     summary = self.summarization_layer(transformerxl_output)
-    per_example_loss, logits = self.cl_loss_layer(
-        hidden=summary, labels=label)
+    per_example_loss, logits = self.cl_loss_layer(hidden=summary, labels=label)
     self.add_loss(tf.keras.backend.mean(per_example_loss))
     return new_mems, logits
@@ -965,13 +941,8 @@ class LMLossLayer(tf.keras.layers.Layer):
     super(LMLossLayer, self).build(unused_input_shapes)
 
-  def __call__(self, hidden, target, lookup_table, target_mask, **kwargs):
-    inputs = tf_utils.pack_inputs([hidden, target, lookup_table, target_mask])
-    return super(LMLossLayer, self).__call__(inputs, **kwargs)
-
-  def call(self, inputs):
+  def call(self, hidden, target, lookup_table, target_mask):
     """Implements call() for the layer."""
-    (hidden, target, lookup_table, tgt_mask) = tf_utils.unpack_inputs(inputs)
     if self.use_proj:
       hidden = self.proj_layer_norm(self.proj_layer(hidden))
     if self.tie_weight:
@@ -986,7 +957,7 @@ class LMLossLayer(tf.keras.layers.Layer):
     loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
         labels=target, logits=logits)
 
-    total_loss = tf.reduce_sum(loss * tgt_mask) / tf.reduce_sum(tgt_mask)
+    total_loss = tf.reduce_sum(loss * target_mask) / tf.reduce_sum(target_mask)
 
     return total_loss, logits
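The rewritten line averages the per-token cross-entropy only over positions where target_mask is 1 (the rename from tgt_mask to target_mask follows from the new call() signature). A tiny standalone reproduction with made-up logits and labels:

import tensorflow as tf

# Hypothetical batch of four target tokens over a three-word vocabulary.
logits = tf.constant([[2.0, 0.1, 0.1],
                      [0.1, 2.0, 0.1],
                      [0.1, 0.1, 2.0],
                      [2.0, 0.1, 0.1]])
target = tf.constant([0, 1, 2, 1])
target_mask = tf.constant([1.0, 1.0, 1.0, 0.0])  # last position is padding

loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=target, logits=logits)
# Same reduction as the layer: masked sum divided by the number of real targets.
total_loss = tf.reduce_sum(loss * target_mask) / tf.reduce_sum(target_mask)
print(float(total_loss))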
@@ -1076,13 +1047,8 @@ class ClassificationLossLayer(tf.keras.layers.Layer):
     super(ClassificationLossLayer, self).build(unused_input_shapes)
 
-  def __call__(self, hidden, labels, **kwargs):
-    inputs = tf_utils.pack_inputs([hidden, labels])
-    return super(ClassificationLossLayer, self).__call__(inputs, **kwargs)
-
-  def call(self, inputs):
+  def call(self, hidden, labels):
     """Implements call() for the layer."""
-    (hidden, labels) = tf_utils.unpack_inputs(inputs)
 
     logits = self.proj_layer(hidden)
     one_hot_target = tf.one_hot(labels, self.n_class, dtype=hidden.dtype)  # pytype: disable=attribute-error
@@ -1145,8 +1111,7 @@ class QAXLNetModel(tf.keras.Model):
     p_mask = features['p_mask']
 
     transformerxl_output, new_mems, self.lookup_table = (
-        self.transformerxl_model(
-            inp_k=input_ids, seg_id=seg_ids, input_mask=input_mask))
+        self.transformerxl_model(input_ids, seg_ids, input_mask))
 
     if training:
       loss, logits = self.qa_loss_layer(
official/nlp/xlnet/xlnet_modeling_test.py
@@ -43,8 +43,7 @@ class PositionalEmbeddingLayerTest(tf.test.TestCase):
     d_model = 4
     pos_seq = tf.range(1, -1, -1.0)  # [1., 0.]
     pos_emb_layer = xlnet_modeling.PositionalEmbedding(d_model)
-    pos_emb = pos_emb_layer(
-        pos_seq=pos_seq, batch_size=None).numpy().astype(float)
+    pos_emb = pos_emb_layer(pos_seq, batch_size=None).numpy().astype(float)
     logging.info(pos_emb)
 
     self.assertAllClose(pos_emb, target)