ModelZoo / ResNet50_tensorflow · Commits · Commit e5c71d51

Refactor: use common tf_utils

Authored Dec 14, 2019 by Hongkun Yu; committed by A. Unique TensorFlower, Dec 14, 2019
PiperOrigin-RevId: 285613648
Parent: 4b06a97a

Showing 1 changed file with 9 additions and 54 deletions:

official/nlp/xlnet/xlnet_modeling.py (+9, -54)
@@ -23,6 +23,7 @@ import copy
 import numpy as np
 import tensorflow as tf
+from official.modeling import tf_utils
 from official.nlp.xlnet import data_utils
@@ -102,52 +103,6 @@ def is_special_none_tensor(tensor):
   return tensor.shape.ndims == 0 and tensor.dtype == tf.int32
 
 
-def unpack_inputs(inputs):
-  """Unpacks a tuple of `inputs` tensors to a tuple.
-
-  Args:
-    inputs: A list of tensors.
-
-  Returns:
-    A tuple of tensors. If any input is a special constant tensor, replace it
-    with None.
-  """
-  inputs = tf.nest.flatten(inputs)
-  outputs = []
-  for x in inputs:
-    if is_special_none_tensor(x):
-      outputs.append(None)
-    else:
-      outputs.append(x)
-  x = tuple(outputs)
-
-  # To trick the very pointless 'unbalanced-tuple-unpacking' pylint check
-  # from triggering.
-  if len(x) == 1:
-    return x[0]
-  return tuple(outputs)
-
-
-def pack_inputs(inputs):
-  """Packs a list of `inputs` tensors to a tuple.
-
-  Args:
-    inputs: A list of tensors.
-
-  Returns:
-    A tuple of tensors. If any input is None, replace it with a special constant
-    tensor.
-  """
-  inputs = tf.nest.flatten(inputs)
-  outputs = []
-  for x in inputs:
-    if x is None:
-      outputs.append(tf.constant(0, shape=[], dtype=tf.int32))
-    else:
-      outputs.append(x)
-  return tuple(outputs)
-
-
 class PositionalEmbedding(tf.keras.layers.Layer):
   """Generates relative positional embeddings used in Transformer-XL and XLNet."""
@@ -196,7 +151,7 @@ class RelativeAttention(tf.keras.layers.Layer):
 
   def __call__(self, q_head, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat,
                r_w_bias, r_r_bias, r_s_bias, attn_mask, **kwargs):
-    inputs = pack_inputs([
+    inputs = tf_utils.pack_inputs([
         q_head, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat, r_w_bias,
         r_r_bias, r_s_bias, attn_mask
     ])
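This __call__/call split is the pattern the whole file follows: __call__ packs an explicit argument list (some entries possibly None) into a single Keras-friendly structure, and call() unpacks it on the other side. A compressed sketch of the same pattern with a hypothetical two-input layer (PackedLayer is not from this file):

import tensorflow as tf

from official.modeling import tf_utils

class PackedLayer(tf.keras.layers.Layer):

  def __call__(self, x, y, **kwargs):
    # Pack the fixed argument list into one structure Keras will accept.
    inputs = tf_utils.pack_inputs([x, y])
    return super(PackedLayer, self).__call__(inputs, **kwargs)

  def call(self, inputs):
    # Unpack; the sentinel for y is restored to None here.
    x, y = tf_utils.unpack_inputs(inputs)
    return x if y is None else x + y

layer = PackedLayer()
print(layer(tf.ones([2]), None))  # y=None survives the Keras plumbing

Only the sentinel crosses the Keras boundary; by the time call() runs, optional arguments are plain None again.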
@@ -205,7 +160,7 @@ class RelativeAttention(tf.keras.layers.Layer):
   def call(self, inputs):
     """Implements call() for the layer."""
     (q_head, k_head_h, v_head_h, k_head_r, seg_embed, seg_mat, r_w_bias,
-     r_r_bias, r_s_bias, attn_mask) = unpack_inputs(inputs)
+     r_r_bias, r_s_bias, attn_mask) = tf_utils.unpack_inputs(inputs)
 
     # content based attention score
     ac = tf.einsum('ibnd,jbnd->ijbn', q_head + r_w_bias, k_head_h)
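The einsum string fixes the shape convention: i and j index query and key positions, b the batch, n the heads, d the per-head size. A toy sketch with made-up sizes; the [n_head, d_head] bias shape is an assumption chosen so the broadcast works:

import tensorflow as tf

# 'ibnd,jbnd->ijbn': i = query length, j = key length, b = batch,
# n = heads, d = head size. Sizes below are hypothetical.
qlen, klen, batch, n_head, d_head = 4, 6, 2, 8, 16
q_head = tf.random.normal([qlen, batch, n_head, d_head])
k_head_h = tf.random.normal([klen, batch, n_head, d_head])
r_w_bias = tf.zeros([n_head, d_head])  # broadcasts over i and b

# Content-based score: one dot product per (query, key, head) pair.
ac = tf.einsum('ibnd,jbnd->ijbn', q_head + r_w_bias, k_head_h)
print(ac.shape)  # (4, 6, 2, 8)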
@@ -363,7 +318,7 @@ class RelativeMultiheadAttention(tf.keras.layers.Layer):
 
   def __call__(self, h, g, r, r_w_bias, r_r_bias, seg_mat, r_s_bias, seg_embed,
                attn_mask_h, attn_mask_g, mems, target_mapping, **kwargs):
-    inputs = pack_inputs([
+    inputs = tf_utils.pack_inputs([
         h, g, r, r_w_bias, r_r_bias, seg_mat, r_s_bias, seg_embed, attn_mask_h,
         attn_mask_g, mems, target_mapping,
     ])
@@ -372,7 +327,7 @@ class RelativeMultiheadAttention(tf.keras.layers.Layer):
   def call(self, inputs):
     """Implements call() for the layer."""
     (h, g, r, r_w_bias, r_r_bias, seg_mat, r_s_bias, seg_embed, attn_mask_h,
-     attn_mask_g, mems, target_mapping) = unpack_inputs(inputs)
+     attn_mask_g, mems, target_mapping) = tf_utils.unpack_inputs(inputs)
 
     if mems is not None and mems.shape.ndims > 1:
       cat = tf.concat([mems, h], 0)
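The ndims > 1 guard filters out placeholder values (the pack/unpack sentinel is a 0-d tensor) as well as None before concatenating cached memory with the current segment along the length axis. A sketch assuming the usual Transformer-XL layout of [length, batch, d_model], with toy sizes:

import tensorflow as tf

mems = tf.zeros([3, 2, 8])  # cached hidden states from the previous segment
h = tf.ones([4, 2, 8])      # current segment's hidden states

# Keys and values are computed over memory + current segment, so attention
# can reach back across the segment boundary.
cat = tf.concat([mems, h], 0)
print(cat.shape)  # (7, 2, 8)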
@@ -1011,12 +966,12 @@ class LMLossLayer(tf.keras.layers.Layer):
     super(LMLossLayer, self).build(unused_input_shapes)
 
   def __call__(self, hidden, target, lookup_table, target_mask, **kwargs):
-    inputs = pack_inputs([hidden, target, lookup_table, target_mask])
+    inputs = tf_utils.pack_inputs([hidden, target, lookup_table, target_mask])
     return super(LMLossLayer, self).__call__(inputs, **kwargs)
 
   def call(self, inputs):
     """Implements call() for the layer."""
-    (hidden, target, lookup_table, tgt_mask) = unpack_inputs(inputs)
+    (hidden, target, lookup_table, tgt_mask) = tf_utils.unpack_inputs(inputs)
     if self.use_proj:
       hidden = self.proj_layer_norm(self.proj_layer(hidden))
     if self.tie_weight:
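The hunk shows the use_proj and tie_weight branches but cuts off before the logits. As a hedged illustration only, with shapes and the einsum assumed in the style of the XLNet reference code rather than taken from this diff: weight tying reuses the embedding lookup_table as the output projection instead of a separate softmax matrix.

import tensorflow as tf

# Hypothetical shapes: hidden [seq, batch, d_model],
# lookup_table [n_token, d_model].
hidden = tf.random.normal([4, 2, 8])
lookup_table = tf.random.normal([32, 8])

# With tied weights, logits come from the embedding table itself.
logits = tf.einsum('ibd,nd->ibn', hidden, lookup_table)
print(logits.shape)  # (4, 2, 32)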
@@ -1122,12 +1077,12 @@ class ClassificationLossLayer(tf.keras.layers.Layer):
     super(ClassificationLossLayer, self).build(unused_input_shapes)
 
   def __call__(self, hidden, labels, **kwargs):
-    inputs = pack_inputs([hidden, labels])
+    inputs = tf_utils.pack_inputs([hidden, labels])
     return super(ClassificationLossLayer, self).__call__(inputs, **kwargs)
 
   def call(self, inputs):
     """Implements call() for the layer."""
-    (hidden, labels) = unpack_inputs(inputs)
+    (hidden, labels) = tf_utils.unpack_inputs(inputs)
     logits = self.proj_layer(hidden)
     one_hot_target = tf.one_hot(labels, self.n_class, dtype=hidden.dtype)  # pytype: disable=attribute-error
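The hunk ends at the one-hot target. A sketch of how such a target is typically consumed, under that assumption and not as a line from this file: softmax cross-entropy between the one-hot target and the logits.

import tensorflow as tf

# Toy sizes: batch of 2, 5 classes.
logits = tf.random.normal([2, 5])
labels = tf.constant([1, 3])
one_hot_target = tf.one_hot(labels, 5, dtype=logits.dtype)

# Standard softmax cross-entropy written against the one-hot target.
loss = -tf.reduce_sum(one_hot_target * tf.nn.log_softmax(logits, -1), -1)
print(loss.shape)  # (2,)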