ModelZoo / ResNet50_tensorflow · Commits · b045ce7d

Commit b045ce7d, authored Oct 03, 2019 by Hongkun Yu; committed by A. Unique TensorFlower, Oct 03, 2019.

Internal change

PiperOrigin-RevId: 272777104
Parent: 0f176f6f

Showing 5 changed files with 145 additions and 150 deletions (+145 −150):
official/nlp/xlnet/data_utils.py    +4   −0
official/nlp/xlnet/run_pretrain.py  +6   −4
official/nlp/xlnet/squad_utils.py   +6   −11
official/nlp/xlnet_config.py        +5   −5
official/nlp/xlnet_modeling.py      +124 −130
official/nlp/xlnet/data_utils.py:

@@ -43,6 +43,10 @@ CLS_ID = special_symbols["<cls>"]
 SEP_ID = special_symbols["<sep>"]
 MASK_ID = special_symbols["<mask>"]
 EOD_ID = special_symbols["<eod>"]
+SEG_ID_P = 0
+SEG_ID_Q = 1
+SEG_ID_CLS = 2
+SEG_ID_PAD = 3
 
 def file_based_input_fn_builder(input_file, name_to_features, batch_size,
...
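The four SEG_ID_* constants move into data_utils so the pretraining and SQuAD pipelines share a single definition (squad_utils previously declared its own copies; see its diff below). As a rough, runnable illustration of the convention (a minimal sketch — the token counts are invented, only the constant values come from this diff), a paragraph-question pair is segmented as:

```python
# Segment-ID convention from data_utils: paragraph, question, CLS, padding.
SEG_ID_P, SEG_ID_Q, SEG_ID_CLS, SEG_ID_PAD = 0, 1, 2, 3

paragraph_len, question_len, max_seq_length = 2, 1, 8  # toy sizes

segment_ids = (
    [SEG_ID_P] * (paragraph_len + 1)   # paragraph tokens + their SEP
    + [SEG_ID_Q] * (question_len + 1)  # question tokens + their SEP
    + [SEG_ID_CLS])                    # XLNet puts CLS at the end
segment_ids += [SEG_ID_PAD] * (max_seq_length - len(segment_ids))
print(segment_ids)  # [0, 0, 0, 1, 1, 2, 3, 3]
```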
official/nlp/xlnet/run_pretrain.py:

@@ -48,8 +48,11 @@ FLAGS = flags.FLAGS
 def get_pretrainxlnet_model(model_config, run_config):
-  model = modeling.PretrainingXLNetModel(model_config, run_config, name="model")
-  return model
+  return modeling.PretrainingXLNetModel(
+      use_proj=True,
+      xlnet_config=model_config,
+      run_config=run_config,
+      name="model")
 
 def main(unused_argv):
...
@@ -69,8 +72,7 @@ def main(unused_argv):
   if strategy:
     logging.info("***** Number of cores used : %d",
                  strategy.num_replicas_in_sync)
     logging.info("***** Number of hosts used : %d", num_hosts)
   train_input_fn = functools.partial(
       data_utils.get_pretrain_input_data, FLAGS.train_batch_size, FLAGS.seq_len,
       strategy, FLAGS.train_tfrecord_path, FLAGS.reuse_len, FLAGS.perm_size,
...
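Two things happen here: the pretraining model is now constructed with keyword arguments and use_proj=True (wired through to LMLossLayer below), and train_input_fn is built by freezing the leading arguments of data_utils.get_pretrain_input_data with functools.partial. A self-contained sketch of the partial pattern (the stand-in function and the values below are invented for illustration):

```python
import functools

def get_pretrain_input_data(batch_size, seq_len, strategy, path,
                            reuse_len, perm_size):
    """Stand-in for data_utils.get_pretrain_input_data."""
    return (batch_size, seq_len, strategy, path, reuse_len, perm_size)

# Freeze configuration up front, as main() does with FLAGS values:
train_input_fn = functools.partial(
    get_pretrain_input_data, 32, 512, None, "/tmp/train.tfrecord", 256, 85)

print(train_input_fn())  # (32, 512, None, '/tmp/train.tfrecord', 256, 85)
```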
official/nlp/xlnet/squad_utils.py:

@@ -36,11 +36,6 @@ from official.nlp.xlnet import preprocess_utils
 SPIECE_UNDERLINE = u"▁"
-SEG_ID_P = 0
-SEG_ID_Q = 1
-SEG_ID_CLS = 2
-SEG_ID_PAD = 3
 
 class InputFeatures(object):
   """A single set of features of data."""
...
@@ -705,28 +700,28 @@ def convert_examples_to_features(examples, sp_model, max_seq_length, doc_stride,
           split_token_index)
       token_is_max_context[len(tokens)] = is_max_context
       tokens.append(all_doc_tokens[split_token_index])
-      segment_ids.append(SEG_ID_P)
+      segment_ids.append(data_utils.SEG_ID_P)
       p_mask.append(0)
     paragraph_len = len(tokens)
 
     tokens.append(data_utils.SEP_ID)
-    segment_ids.append(SEG_ID_P)
+    segment_ids.append(data_utils.SEG_ID_P)
     p_mask.append(1)
 
     # note(zhiliny): we put P before Q
     # because during pretraining, B is always shorter than A
     for token in query_tokens:
       tokens.append(token)
-      segment_ids.append(SEG_ID_Q)
+      segment_ids.append(data_utils.SEG_ID_Q)
       p_mask.append(1)
     tokens.append(data_utils.SEP_ID)
-    segment_ids.append(SEG_ID_Q)
+    segment_ids.append(data_utils.SEG_ID_Q)
     p_mask.append(1)
 
     cls_index = len(segment_ids)
     tokens.append(data_utils.CLS_ID)
-    segment_ids.append(SEG_ID_CLS)
+    segment_ids.append(data_utils.SEG_ID_CLS)
     p_mask.append(0)
 
     input_ids = tokens
...
@@ -739,7 +734,7 @@ def convert_examples_to_features(examples, sp_model, max_seq_length, doc_stride,
     while len(input_ids) < max_seq_length:
       input_ids.append(0)
       input_mask.append(1)
-      segment_ids.append(SEG_ID_PAD)
+      segment_ids.append(data_utils.SEG_ID_PAD)
       p_mask.append(1)
 
     assert len(input_ids) == max_seq_length
...
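Note the masking convention in the padding loop: input_mask gets 1 for padded positions (1 = ignore, 0 = real token), and p_mask gets 1 for every position an answer must not start or end on. A runnable sketch of just that loop (the pre-padding feature values are invented; SEG_ID_PAD = 3 as defined in data_utils):

```python
SEG_ID_PAD = 3  # stand-in for data_utils.SEG_ID_PAD

input_ids = [14, 27, 9, 4, 3]      # hypothetical ids, CLS last
input_mask = [0] * len(input_ids)  # 0 marks real tokens in this codebase
segment_ids = [0, 0, 1, 1, 2]
p_mask = [0, 0, 1, 1, 0]
max_seq_length = 8

while len(input_ids) < max_seq_length:
    input_ids.append(0)
    input_mask.append(1)           # 1 marks padding to be ignored
    segment_ids.append(SEG_ID_PAD)
    p_mask.append(1)

assert len(input_ids) == max_seq_length
print(input_mask)   # [0, 0, 0, 0, 0, 1, 1, 1]
print(segment_ids)  # [0, 0, 1, 1, 2, 3, 3, 3]
```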
official/nlp/xlnet_config.py:

@@ -30,7 +30,6 @@ def create_run_config(is_training, is_finetune, flags):
   kwargs = dict(
       is_training=is_training,
       use_tpu=flags.use_tpu,
-      use_bfloat16=flags.use_bfloat16,
       dropout=flags.dropout,
       dropout_att=flags.dropout_att,
       init_method=flags.init_method,
...
@@ -49,6 +48,7 @@ def create_run_config(is_training, is_finetune, flags):
   return RunConfig(**kwargs)
 
+# TODO(hongkuny): refactor XLNetConfig and RunConfig.
 class XLNetConfig(object):
   """Configs for XLNet model.
...
@@ -131,7 +131,6 @@ class RunConfig(object):
   def __init__(self,
                is_training,
                use_tpu,
-               use_bfloat16,
                dropout,
                dropout_att,
                init_method='normal',
...
@@ -141,13 +140,13 @@ class RunConfig(object):
                reuse_len=None,
                bi_data=False,
                clamp_len=-1,
-               same_length=False):
+               same_length=False,
+               use_cls_mask=True):
     """Initializes RunConfig.
 
     Args:
       is_training: bool, whether in training mode.
       use_tpu: bool, whether TPUs are used.
-      use_bfloat16: bool, use bfloat16 instead of float32.
       dropout: float, dropout rate.
       dropout_att: float, dropout rate on attention probabilities.
       init_method: str, the initialization scheme, either "normal" or "uniform".
...
@@ -164,6 +163,7 @@ class RunConfig(object):
         -1 means no clamping.
       same_length: bool, whether to use the same attention length
         for each token.
+      use_cls_mask: bool, whether to introduce cls mask.
     """
     self.init_method = init_method
...
@@ -173,9 +173,9 @@ class RunConfig(object):
     self.dropout = dropout
     self.dropout_att = dropout_att
     self.use_tpu = use_tpu
-    self.use_bfloat16 = use_bfloat16
     self.mem_len = mem_len
     self.reuse_len = reuse_len
     self.bi_data = bi_data
     self.clamp_len = clamp_len
     self.same_length = same_length
+    self.use_cls_mask = use_cls_mask
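Net effect on RunConfig: the bfloat16 switch is gone and use_cls_mask arrives with a default of True, so existing callers silently pick up the CLS-aware segment matrix unless they opt out. A minimal sketch of the new surface (the stub below mirrors only the fields this diff touches, with placeholder values):

```python
class RunConfig(object):
    """Toy stub of the post-commit RunConfig surface (subset of fields)."""

    def __init__(self, is_training, use_tpu, dropout, dropout_att,
                 init_method='normal', same_length=False, use_cls_mask=True):
        self.is_training = is_training
        self.use_tpu = use_tpu
        self.dropout = dropout
        self.dropout_att = dropout_att
        self.init_method = init_method
        self.same_length = same_length
        self.use_cls_mask = use_cls_mask  # new flag, defaults on

config = RunConfig(is_training=True, use_tpu=False,
                   dropout=0.1, dropout_att=0.1)
print(config.use_cls_mask)  # True
```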
official/nlp/xlnet_modeling.py:

@@ -23,6 +23,7 @@ import copy
 import numpy as np
 import tensorflow as tf
+from official.nlp.xlnet import data_utils
 
 def gelu(x):
...
@@ -96,19 +97,6 @@ def _cache_mem(curr_out, prev_mem, mem_len, reuse_len=None):
   return tf.keras.backend.stop_gradient(new_mem)
 
-def embedding_lookup(lookup_table, x, use_tpu=True):
-  """Looks up words embeddings for input id tensor."""
-  if use_tpu:
-    n_token = tf.shape(lookup_table)[0]
-    one_hot_idx = tf.one_hot(x, n_token)
-    if one_hot_idx.shape.ndims == 2:
-      return tf.einsum('nd,in->id', lookup_table, one_hot_idx)
-    else:
-      return tf.einsum('nd,ibn->ibd', lookup_table, one_hot_idx)
-  else:
-    return tf.nn.embedding_lookup(lookup_table, x)
-
 def is_special_none_tensor(tensor):
   """Checks if a tensor is a special None Tensor."""
   return tensor.shape.ndims == 0 and tensor.dtype == tf.int32
...
@@ -169,7 +157,7 @@ class PositionalEmbedding(tf.keras.layers.Layer):
   def build(self, unused_input_shapes):
     """Constructs inversed frequency vector for positional embedding layer."""
     self.inv_freq = 1.0 / (10000.0**(tf.range(0, self.dim, 2.0) / self.dim))
     super(PositionalEmbedding, self).build(unused_input_shapes)
 
   def __call__(self, pos_seq, batch_size):
...
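The inv_freq line above is the standard Transformer sinusoid schedule: channel pair i gets frequency 1/10000^(2i/d). A runnable NumPy sketch of the embedding it induces (the sin/cos concatenation is the usual construction and is assumed here, since the body of __call__ is elided on this page):

```python
import numpy as np

dim = 8
# Same formula as self.inv_freq: 1 / 10000^(2i / d).
inv_freq = 1.0 / (10000.0 ** (np.arange(0, dim, 2.0) / dim))

pos_seq = np.arange(5.0)                            # positions 0..4
sinusoid = np.einsum("i,d->id", pos_seq, inv_freq)  # [pos, dim // 2]
pos_emb = np.concatenate([np.sin(sinusoid), np.cos(sinusoid)], axis=-1)
print(pos_emb.shape)  # (5, 8)
```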
@@ -232,8 +220,12 @@ class RelativeAttention(tf.keras.layers.Layer):
     if seg_mat is None:
       ef = 0
     else:
-      ef = tf.einsum('ibnd,snd->ibns', q_head + r_s_bias, seg_embed)
-      ef = tf.einsum('ijbs,ibns->ijbn', seg_mat, ef)
+      ef = tf.einsum('ibnd,snd->isbn', q_head + r_s_bias, seg_embed)
+      tgt_shape = tf.shape(bd)
+      ef = tf.where(
+          tf.broadcast_to(tf.expand_dims(seg_mat, 3), tgt_shape),
+          tf.broadcast_to(ef[:, 1:, :, :], tgt_shape),
+          tf.broadcast_to(ef[:, :1, :, :], tgt_shape))
 
     # merges attention scores and performs masking
     attn_score = (ac + bd + ef) * self.scale
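seg_mat used to be a one-hot float tensor contracted against the segment scores; it is now a boolean tensor, and tf.where simply selects between the two seg_embed score rows per query/key pair. The two formulations agree, which a small NumPy check makes concrete (sizes are arbitrary; r_s_bias is folded into q_head for brevity):

```python
import numpy as np

q, k, b, n, d = 3, 4, 2, 2, 5          # qlen, klen, batch, heads, head dim
rng = np.random.default_rng(0)
q_head = rng.normal(size=(q, b, n, d))
seg_embed = rng.normal(size=(2, n, d))
seg_flag = rng.integers(0, 2, size=(q, k, b)).astype(bool)

# Old path: one-hot seg_mat plus two einsums.
seg_mat_onehot = np.eye(2)[seg_flag.astype(int)]          # [q, k, b, 2]
ef_old = np.einsum('ibnd,snd->ibns', q_head, seg_embed)
score_old = np.einsum('ijbs,ibns->ijbn', seg_mat_onehot, ef_old)

# New path: boolean seg_mat selects row 1 or row 0 of the scores.
ef_new = np.einsum('ibnd,snd->isbn', q_head, seg_embed)   # [q, 2, b, n]
score_new = np.where(seg_flag[..., None],
                     ef_new[:, 1:, :, :],    # broadcast over klen
                     ef_new[:, :1, :, :])
print(np.allclose(score_old, score_new))  # True
```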
...
@@ -253,8 +245,8 @@ class RelativeAttention(tf.keras.layers.Layer):
 class PositionwiseFF(tf.keras.layers.Layer):
   """Positionwise feed-forward layer."""
 
   def __init__(self, d_model, d_inner, dropout,
                kernel_initializer, activation_type, **kwargs):
     super(PositionwiseFF, self).__init__(**kwargs)
     self.d_model = d_model
     self.d_inner = d_inner
...
@@ -282,10 +274,8 @@ class PositionwiseFF(tf.keras.layers.Layer):
         units=self.d_model,
         kernel_initializer=self.kernel_initializer,
         name='layer_2'))
-    self.inner_dropout = tf.keras.layers.Dropout(rate=self.dropout,
-                                                 name='drop_1')
-    self.output_dropout = tf.keras.layers.Dropout(rate=self.dropout,
-                                                  name='drop_2')
+    self.output_dropout = tf.keras.layers.Dropout(
+        rate=self.dropout, name='drop_2')
     self.output_layer_norm = (
         tf.keras.layers.LayerNormalization(
             name='LayerNorm', axis=-1, epsilon=1e-12))
...
@@ -295,7 +285,6 @@ class PositionwiseFF(tf.keras.layers.Layer):
     """Implements call() for the layer."""
     output = self.inner_projection_layer(inp)
-    output = self.inner_dropout(output)
     output = self.output_projection_layer(output)
     output = self.output_dropout(output)
     output = self.output_layer_norm(output + inp)
...
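After this change the feed-forward block applies dropout once, after the second projection and before the residual add and LayerNorm. A self-contained Keras sketch of the resulting order (dimensions and rate are placeholders; this mirrors only the call() sequence above, not the full layer):

```python
import tensorflow as tf

d_model, d_inner = 16, 64
inner = tf.keras.layers.Dense(d_inner, activation="relu")   # layer_1
outer = tf.keras.layers.Dense(d_model)                      # layer_2
drop = tf.keras.layers.Dropout(0.1)                         # drop_2 only
norm = tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-12)

def positionwise_ff(inp, training=False):
    out = inner(inp)                    # expand to d_inner
    out = outer(out)                    # project back to d_model
    out = drop(out, training=training)  # the single remaining dropout
    return norm(out + inp)              # post-LN residual connection

x = tf.random.normal([2, 5, d_model])
print(positionwise_ff(x).shape)  # (2, 5, 16)
```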
...
@@ -305,14 +294,11 @@ class EmbeddingLookup(tf.keras.layers.Layer):
 class EmbeddingLookup(tf.keras.layers.Layer):
   """Looks up words embeddings for id tensor."""
 
-  def __init__(self,
-               n_token,
-               d_embed,
-               initializer,
-               use_one_hot=False,
-               **kwargs):
+  def __init__(self, n_token, d_embed, initializer, **kwargs):
     super(EmbeddingLookup, self).__init__(**kwargs)
     self.n_token = n_token
     self.d_embed = d_embed
     self.initializer = initializer
-    self.use_one_hot = use_one_hot
 
   def build(self, unused_input_shapes):
     """Implements build() for the layer."""
...
@@ -325,20 +311,7 @@ class EmbeddingLookup(tf.keras.layers.Layer):
     super(EmbeddingLookup, self).build(unused_input_shapes)
 
   def call(self, inputs):
-    x = inputs
-    if self.use_one_hot:
-      one_hot_idx = tf.one_hot(x, self.n_token, dtype=self.dtype)
-      if one_hot_idx.shape.ndims == 2:
-        return tf.einsum('in,nd->id', one_hot_idx,
-                         self.lookup_table), self.lookup_table
-      else:
-        return tf.einsum('ibn,nd->ibd', one_hot_idx,
-                         self.lookup_table), self.lookup_table
-    else:
-      return tf.nn.embedding_lookup(self.lookup_table,
-                                    x), self.lookup_table
+    return tf.nn.embedding_lookup(self.lookup_table, inputs)
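Both deletions in this file retire the same TPU workaround: an embedding gather written as a one-hot matmul (a dense one-hot times the table instead of tf.gather). The two are numerically identical, which this small NumPy check shows (shapes are arbitrary):

```python
import numpy as np

n_token, d_embed = 10, 4
rng = np.random.default_rng(0)
lookup_table = rng.normal(size=(n_token, d_embed))
ids = np.array([[1, 7], [3, 0], [9, 2]])            # [seq_len, batch]

# Gather formulation, as in tf.nn.embedding_lookup.
gathered = lookup_table[ids]                        # [seq, batch, d_embed]

# One-hot matmul formulation, as in the removed 'ibn,nd->ibd' einsum.
one_hot_idx = np.eye(n_token)[ids]                  # [seq, batch, n_token]
matmul = np.einsum('ibn,nd->ibd', one_hot_idx, lookup_table)

print(np.allclose(gathered, matmul))  # True
```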
 class TwoStreamRelativeAttention(tf.keras.layers.Layer):
...
@@ -356,9 +329,10 @@ class TwoStreamRelativeAttention(tf.keras.layers.Layer):
   def build(self, unused_input_shapes):
     """Implements build() for the layer."""
     self.scale = 1.0 / (self.d_head**0.5)
     self.attention_projection_layer = tf.keras.layers.Dense(
         units=self.d_model,
         use_bias=False,
         kernel_initializer=self.initializer,
         name='o')
     self.attention_probs_dropout = tf.keras.layers.Dropout(
...
...
@@ -403,9 +377,8 @@ class TwoStreamRelativeAttention(tf.keras.layers.Layer):
     super(TwoStreamRelativeAttention, self).build(unused_input_shapes)
 
   def __call__(self, h, g, r, r_w_bias, r_r_bias,
                seg_mat, r_s_bias, seg_embed,
                attn_mask_h, attn_mask_g,
                mems, target_mapping):
     inputs = pack_inputs([
         h, g, r, r_w_bias, r_r_bias, seg_mat, r_s_bias, seg_embed, attn_mask_h,
         attn_mask_g, mems, target_mapping
...
@@ -455,15 +428,17 @@ class TwoStreamRelativeAttention(tf.keras.layers.Layer):
       q_head_g = tf.einsum('mbnd,mlb->lbnd', q_head_g, target_mapping)
       attn_vec_g = self.g_attention_layer(
           q_head_g, k_head_h, v_head_h,
-          k_head_r, seg_embed, seg_mat, r_w_bias,
-          r_r_bias, r_s_bias, attn_mask_g)
+          k_head_r, seg_embed, seg_mat,
+          r_w_bias, r_r_bias, r_s_bias, attn_mask_g)
       attn_vec_g = tf.einsum('lbnd,mlb->mbnd', attn_vec_g, target_mapping)
     else:
       attn_vec_g = self.g_attention_layer(
           q_head_g, k_head_h, v_head_h,
-          k_head_r, seg_embed, seg_mat, r_w_bias,
-          r_r_bias, r_s_bias, attn_mask_g)
+          k_head_r, seg_embed, seg_mat,
+          r_w_bias, r_r_bias, r_s_bias, attn_mask_g)
 
     # post processing
...
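When target_mapping is present, the query stream is scattered from the m target slots onto sequence positions before attention ('mbnd,mlb->lbnd') and gathered back afterwards ('lbnd,mlb->mbnd'). With a one-hot mapping those two einsums are exact inverses, as this NumPy round trip shows (all sizes are invented):

```python
import numpy as np

m, l, b, n, d = 2, 5, 3, 2, 4   # targets, seq len, batch, heads, head dim
rng = np.random.default_rng(0)
q_head_g = rng.normal(size=(m, b, n, d))

# One-hot map: target slot -> position in the sequence.
positions = np.array([1, 4])
target_mapping = np.zeros((m, l, b))
target_mapping[np.arange(m), positions, :] = 1.0

q_seq = np.einsum('mbnd,mlb->lbnd', q_head_g, target_mapping)  # scatter
# ... attention over sequence positions would run here ...
back = np.einsum('lbnd,mlb->mbnd', q_seq, target_mapping)      # gather
print(np.allclose(back, q_head_g))  # True
```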
...
@@ -491,7 +466,7 @@ class RelativeMultiheadAttention(tf.keras.layers.Layer):
   def build(self, unused_input_shapes):
     """Implements build() for the layer."""
     self.scale = 1.0 / (self.d_head**0.5)
     self.output_layer_norm = tf.keras.layers.LayerNormalization(
         name='LayerNorm', axis=-1, epsilon=1e-12)
...
@@ -555,9 +530,9 @@ class RelativeMultiheadAttention(tf.keras.layers.Layer):
     k_head_r = tf.einsum('ibh,hnd->ibnd', r, self.kr_projection_layer)
 
     # core attention ops
     attn_vec = self.h_attention_layer(
         q_head_h, k_head_h, v_head_h, k_head_r,
         seg_embed, seg_mat, r_w_bias,
         r_r_bias,
         r_s_bias, attn_mask)
 
     # post processing
...
...
@@ -592,7 +567,7 @@ class TransformerXLModel(tf.keras.layers.Layer):
                use_tpu=True,
                reuse_len=None,
                ff_activation='relu',
-               use_bfloat16=False,
+               use_cls_mask=False,
                **kwargs):
     """Initializes TransformerXLModel.
...
@@ -620,7 +595,7 @@ class TransformerXLModel(tf.keras.layers.Layer):
       reuse_len: int, the number of tokens in the current batch to be cached
         and reused in the future.
       ff_activation: str, "relu" or "gelu".
-      use_bfloat16: bool, use bfloat16 instead of float32.
+      use_cls_mask: bool, whether to introduce cls mask.
       **kwargs: Other parameters.
     """
...
@@ -636,7 +611,6 @@ class TransformerXLModel(tf.keras.layers.Layer):
     self.d_inner = d_inner
     self.ff_activation = ff_activation
     self.untie_r = untie_r
-    self.use_bfloat16 = use_bfloat16
     self.use_tpu = use_tpu
     self.dropout = dropout
     self.dropout_att = dropout_att
...
...
@@ -646,21 +620,21 @@ class TransformerXLModel(tf.keras.layers.Layer):
     self.bi_data = bi_data
     self.clamp_len = clamp_len
     self.same_length = same_length
+    self.use_cls_mask = use_cls_mask
 
   def build(self, unused_input_shapes):
     """Implements build() for the layer."""
-    self.tf_float = tf.bfloat16 if self.use_bfloat16 else tf.float32
+    self.tf_float = tf.float32
 
-    self.embedding_lookup = EmbeddingLookup(n_token=self.n_token,
-                                            d_embed=self.d_model,
-                                            initializer=self.initializer,
-                                            use_one_hot=self.use_tpu,
-                                            dtype=self.tf_float,
-                                            name='word_embedding')
+    self.embedding_lookup = EmbeddingLookup(
+        n_token=self.n_token,
+        d_embed=self.d_model,
+        initializer=self.initializer,
+        dtype=self.tf_float,
+        name='word_embedding')
 
     self.h_dropout = tf.keras.layers.Dropout(rate=self.dropout)
     self.g_dropout = tf.keras.layers.Dropout(rate=self.dropout)
-    self.output_dropout = tf.keras.layers.Dropout(rate=self.dropout)
 
     if self.untie_r:
       self.r_w_bias = (
...
@@ -702,11 +676,11 @@ class TransformerXLModel(tf.keras.layers.Layer):
       self.seg_embed = self.add_weight(
           'seg_embed', [self.n_layer, 2, self.n_head, self.d_head],
           dtype=self.tf_float, initializer=self.initializer)
 
-      self.mask_emb = self.add_weight('mask_emb/mask_emb',
-                                      shape=[1, 1, self.d_model],
-                                      dtype=self.tf_float)
+      self.mask_emb = self.add_weight(
+          'mask_emb/mask_emb', shape=[1, 1, self.d_model],
+          dtype=self.tf_float)
 
     self.emb_dropout = tf.keras.layers.Dropout(rate=self.dropout)
     self.fwd_position_embedding = PositionalEmbedding(self.d_model)
...
@@ -741,16 +715,16 @@ class TransformerXLModel(tf.keras.layers.Layer):
               d_inner=self.d_inner,
               dropout=self.dropout,
               kernel_initializer=self.initializer,
-              activation_type=self.ff_activation, name='layer_%d/ff' % (i)))
+              activation_type=self.ff_activation,
+              name='layer_%d/ff' % (i)))
       self.h_positionwise_ffn_layers.append(
           PositionwiseFF(
              d_model=self.d_model,
              d_inner=self.d_inner,
              dropout=self.dropout,
              kernel_initializer=self.initializer,
-              activation_type=self.ff_activation, name='layer_%d/ff' % (i)))
+              activation_type=self.ff_activation,
+              name='layer_%d/ff' % (i)))
 
     self.output_dropout = tf.keras.layers.Dropout(rate=self.dropout)
...
@@ -766,9 +740,15 @@ class TransformerXLModel(tf.keras.layers.Layer):
               inp_q=None):
     # Uses dict to feed inputs into call() in order to keep mems as a python
     # list.
-    inputs = {'inp_k': inp_k, 'seg_id': seg_id, 'input_mask': input_mask,
-              'mems': mems, 'perm_mask': perm_mask,
-              'target_mapping': target_mapping, 'inp_q': inp_q}
+    inputs = {
+        'inp_k': inp_k,
+        'seg_id': seg_id,
+        'input_mask': input_mask,
+        'mems': mems,
+        'perm_mask': perm_mask,
+        'target_mapping': target_mapping,
+        'inp_q': inp_q
+    }
     return super(TransformerXLModel, self).__call__(inputs)
 
   def call(self, inputs):
...
...
@@ -827,14 +807,14 @@ class TransformerXLModel(tf.keras.layers.Layer):
     if attn_mask is not None:
       non_tgt_mask = -tf.eye(qlen, dtype=self.tf_float)
-      non_tgt_mask = tf.concat([tf.zeros([qlen, mlen], dtype=self.tf_float),
-                                non_tgt_mask], axis=-1)
-      non_tgt_mask = tf.cast((attn_mask + non_tgt_mask[:, :, None, None]) > 0,
-                             dtype=self.tf_float)
+      non_tgt_mask = tf.concat(
+          [tf.zeros([qlen, mlen], dtype=self.tf_float), non_tgt_mask],
+          axis=-1)
+      non_tgt_mask = tf.cast(
+          (attn_mask + non_tgt_mask[:, :, None, None]) > 0,
+          dtype=self.tf_float)
     else:
       non_tgt_mask = None
 
-    word_emb_k, _ = self.embedding_lookup(inp_k)
+    word_emb_k = self.embedding_lookup(inp_k)
 
     if inp_q is not None:
       if target_mapping is not None:
...
...
@@ -855,15 +835,18 @@ class TransformerXLModel(tf.keras.layers.Layer):
       mem_pad = tf.zeros([mlen, bsz], dtype=tf.int32)
-      cat_ids = tf.concat([mem_pad, seg_id], 0)
-
-      # `1` indicates not in the same segment [qlen x klen x bsz]
-      seg_mat = tf.cast(
-          tf.logical_not(tf.equal(seg_id[:, None], cat_ids[None, :])),
-          tf.int32)
-      seg_mat = tf.one_hot(seg_mat, 2, dtype=self.tf_float)
+      cat_id = tf.concat([mem_pad, seg_id], 0)
+
+      if self.use_cls_mask:
+        # `1` indicates not in the same segment [qlen x klen x bsz]
+        # seg_id: [qlen x bsz] & cat_id: [klen x bsz]
+        cls_mat = tf.logical_or(
+            tf.equal(seg_id, tf.constant([data_utils.SEG_ID_CLS]))[:, None],
+            tf.equal(cat_id, tf.constant([data_utils.SEG_ID_CLS]))[None, :])
+        seg_mat = tf.equal(seg_id[:, None], cat_id[None, :])
+        seg_mat = tf.logical_or(cls_mat, seg_mat)
+      else:
+        seg_mat = tf.logical_not(tf.equal(seg_id[:, None], cat_id[None, :]))
     else:
       seg_mat = None
...
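With use_cls_mask on, any query/key pair where either side carries the CLS segment is forced into the "same bucket" of the relative segment encoding. A runnable NumPy sketch of the new boolean construction (toy seg_id values; no cached memory, so cat_id equals seg_id; SEG_ID_CLS = 2 per data_utils):

```python
import numpy as np

SEG_ID_CLS = 2
seg_id = np.array([0, 0, 1, 2])   # [qlen]; batch dimension omitted
cat_id = seg_id                   # mlen == 0 in this toy example

cls_mat = np.logical_or(
    (seg_id == SEG_ID_CLS)[:, None],
    (cat_id == SEG_ID_CLS)[None, :])
seg_mat = np.logical_or(cls_mat, seg_id[:, None] == cat_id[None, :])
print(seg_mat.astype(int))
# [[1 1 0 1]
#  [1 1 0 1]
#  [0 0 1 1]
#  [1 1 1 1]]
```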
...
@@ -894,8 +877,8 @@ class TransformerXLModel(tf.keras.layers.Layer):
           self.clamp_len)
       if bsz is not None:
         fwd_pos_emb = self.fwd_position_embedding(fwd_pos_seq, bsz // 2)
         bwd_pos_emb = self.bwd_position_embedding(bwd_pos_seq, bsz // 2)
       else:
         fwd_pos_emb = self.fwd_position_embedding(fwd_pos_seq, None)
         bwd_pos_emb = self.bwd_position_embedding(bwd_pos_seq, None)
...
@@ -906,8 +889,8 @@ class TransformerXLModel(tf.keras.layers.Layer):
       if dtype is not None and dtype != tf.float32:
         fwd_pos_seq = tf.cast(fwd_pos_seq, dtype=dtype)
       if self.clamp_len > 0:
         fwd_pos_seq = tf.clip_by_value(fwd_pos_seq, -self.clamp_len,
                                        self.lamp_len)
       pos_emb = self.fwd_position_embedding(fwd_pos_seq, bsz)
...
...
@@ -969,9 +952,9 @@ class TransformerXLModel(tf.keras.layers.Layer):
         output_h = h_ffn_layer(output_h)
 
     if inp_q is not None:
-      output = self.output_dropout(output_g)
+      output = output_g
     else:
-      output = self.output_dropout(output_h)
+      output = output_h
 
     return output, new_mems, None
...
...
@@ -983,7 +966,7 @@ class PretrainingXLNetModel(tf.keras.Model):
   """
 
-  def __init__(self, xlnet_config, run_config, **kwargs):
+  def __init__(self, use_proj, xlnet_config, run_config, **kwargs):
     super(PretrainingXLNetModel, self).__init__(**kwargs)
     self.run_config = run_config
     self.initializer = _get_initializer(run_config)
...
@@ -1001,7 +984,6 @@ class PretrainingXLNetModel(tf.keras.Model):
         ff_activation=self.xlnet_config.ff_activation,
         untie_r=self.xlnet_config.untie_r,
         is_training=self.run_config.is_training,
-        use_bfloat16=self.run_config.use_bfloat16,
         use_tpu=self.run_config.use_tpu,
         dropout=self.run_config.dropout,
         dropout_att=self.run_config.dropout_att,
...
...
@@ -1010,15 +992,17 @@ class PretrainingXLNetModel(tf.keras.Model):
         bi_data=self.run_config.bi_data,
         clamp_len=self.run_config.clamp_len,
         same_length=self.run_config.same_length,
+        use_cls_mask=self.run_config.use_cls_mask,
         name='transformer')
 
-    self.lmloss_layer = LMLossLayer(n_token=self.xlnet_config.n_token,
-                                    d_model=self.xlnet_config.d_model,
-                                    initializer=self.initializer,
-                                    use_bfloat16=self.run_config.use_bfloat16,
-                                    tie_weight=True,
-                                    bi_data=self.run_config.bi_data,
-                                    use_tpu=self.run_config.use_tpu,
-                                    name='lm_loss')
+    self.lmloss_layer = LMLossLayer(
+        n_token=self.xlnet_config.n_token,
+        d_model=self.xlnet_config.d_model,
+        initializer=self.initializer,
+        tie_weight=True,
+        bi_data=self.run_config.bi_data,
+        use_tpu=self.run_config.use_tpu,
+        use_proj=use_proj,
+        name='lm_loss')
 
   def call(self, features):
     """Implements call() for the layer."""
...
...
@@ -1082,7 +1066,6 @@ class ClassificationXLNetModel(tf.keras.Model):
         ff_activation=self.xlnet_config.ff_activation,
         untie_r=self.xlnet_config.untie_r,
         is_training=self.run_config.is_training,
-        use_bfloat16=self.run_config.use_bfloat16,
         use_tpu=self.run_config.use_tpu,
         dropout=self.run_config.dropout,
         dropout_att=self.run_config.dropout_att,
...
...
@@ -1133,23 +1116,28 @@ class ClassificationXLNetModel(tf.keras.Model):
 class LMLossLayer(tf.keras.layers.Layer):
   """Layer computing cross entropy loss for language modeling."""
 
-  def __init__(self, n_token, d_model, initializer, use_bfloat16,
-               tie_weight=False, bi_data=True, use_tpu=False, **kwargs):
+  def __init__(self,
+               n_token,
+               d_model,
+               initializer,
+               tie_weight=False,
+               bi_data=True,
+               use_tpu=False,
+               use_proj=False,
+               **kwargs):
     """Constructs LMLoss layer.
 
     Args:
       n_token: Number of tokens in vocabulary.
       d_model: The dimension of model hidden state.
       initializer: Initializer used for parameters.
-      use_bfloat16: Whether to use bfloat16.
       tie_weight: Whether to share weights between embedding lookup layer and
         next-token prediction layer.
       bi_data: Whether to use bidirectional input pipeline. Usually set to True
         during pretraining and False during finetuning.
       use_tpu: bool, whether to use TPU.
+      use_proj: bool, whether to add a projection layer before LM prediction.
       **kwargs: Other parameters.
     """
     super(LMLossLayer, self).__init__(**kwargs)
     self.n_token = n_token
...
@@ -1159,17 +1147,26 @@ class LMLossLayer(tf.keras.layers.Layer):
     self.tie_weight = tie_weight
     self.bi_data = bi_data
     self.use_tpu = use_tpu
-    self.use_bfloat16 = use_bfloat16
+    self.use_proj = use_proj
 
   def build(self, unused_input_shapes):
     """Implements build() for the layer."""
+    if self.use_proj:
+      self.proj_layer = tf.keras.layers.Dense(
+          units=self.d_model,
+          kernel_initializer=self.initializer,
+          activation=gelu,
+          name='lm_projection')
+      self.proj_layer_norm = tf.keras.layers.LayerNormalization(
+          axis=-1, epsilon=1e-12, name='lm_projection/LayerNorm')
     if not self.tie_weight:
-      self.softmax_w = self.add_weight('weight',
-                                       shape=[self.n_token, self.d_model],
-                                       initializer=self.initializer)
-      self.softmax_b = self.add_weight('bias', shape=[self.n_token],
-                                       initializer=tf.zeros_initializer())
+      self.softmax_w = self.add_weight(
+          'weight',
+          shape=[self.n_token, self.d_model],
+          initializer=self.initializer)
+      self.softmax_b = self.add_weight(
+          'bias', shape=[self.n_token], initializer=tf.zeros_initializer())
     super(LMLossLayer, self).build(unused_input_shapes)
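When use_proj is set, hidden states pass through a gelu Dense of width d_model plus LayerNorm before the (optionally weight-tied) softmax. A compact runnable sketch of that head (dimensions are placeholders, tf.nn.gelu stands in for the module's own gelu, and the layer names mirror the diff):

```python
import tensorflow as tf

d_model, n_token = 8, 50
proj_layer = tf.keras.layers.Dense(
    units=d_model, activation=tf.nn.gelu, name="lm_projection")
proj_layer_norm = tf.keras.layers.LayerNormalization(
    axis=-1, epsilon=1e-12, name="lm_projection/LayerNorm")
lookup_table = tf.random.normal([n_token, d_model])  # tied embedding weights
softmax_b = tf.zeros([n_token])

hidden = tf.random.normal([5, 2, d_model])           # [seq, batch, d_model]
hidden = proj_layer_norm(proj_layer(hidden))         # the new use_proj path
logits = tf.einsum("ibd,nd->ibn", hidden, lookup_table) + softmax_b
print(logits.shape)  # (5, 2, 50)
```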
...
@@ -1180,6 +1177,8 @@ class LMLossLayer(tf.keras.layers.Layer):
   def call(self, inputs):
     """Implements call() for the layer."""
     (hidden, target, lookup_table, tgt_mask) = unpack_inputs(inputs)
+    if self.use_proj:
+      hidden = self.proj_layer_norm(self.proj_layer(hidden))
     if self.tie_weight:
       logits = tf.einsum('ibd,nd->ibn', hidden, lookup_table) + self.softmax_b
     else:
...
@@ -1189,11 +1188,8 @@ class LMLossLayer(tf.keras.layers.Layer):
       one_hot_target = tf.one_hot(target, self.n_token, dtype=logits.dtype)
       loss = -tf.reduce_sum(tf.nn.log_softmax(logits) * one_hot_target, -1)
     else:
-      loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=target,
-                                                            logits=logits)
+      loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
+          labels=target, logits=logits)
 
-    if self.use_bfloat16:
-      tgt_mask = tf.cast(tgt_mask, tf.float32)
-      loss = tf.cast(loss, tf.float32)
 
     total_loss = tf.reduce_sum(loss * tgt_mask) / tf.reduce_sum(tgt_mask)
...
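The TPU branch computes the same cross entropy as the sparse op, just via an explicit one-hot and log-softmax; the final line then averages only over the positions selected by tgt_mask. A runnable NumPy check (small random values):

```python
import numpy as np

rng = np.random.default_rng(0)
n_token = 6
logits = rng.normal(size=(4, n_token))
target = rng.integers(0, n_token, size=4)
tgt_mask = np.array([1.0, 0.0, 1.0, 1.0])  # 1 = position counts in the loss

log_probs = logits - np.log(np.exp(logits).sum(-1, keepdims=True))
one_hot = np.eye(n_token)[target]
loss_onehot = -(log_probs * one_hot).sum(-1)     # TPU-friendly branch
loss_sparse = -log_probs[np.arange(4), target]   # sparse-CE equivalent
print(np.allclose(loss_onehot, loss_sparse))     # True

total_loss = (loss_onehot * tgt_mask).sum() / tgt_mask.sum()
print(total_loss)
```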
...
@@ -1321,7 +1317,6 @@ class QAXLNetModel(tf.keras.Model):
         ff_activation=self.xlnet_config.ff_activation,
         untie_r=self.xlnet_config.untie_r,
         is_training=self.run_config.is_training,
-        use_bfloat16=self.run_config.use_bfloat16,
         use_tpu=self.run_config.use_tpu,
         dropout=self.run_config.dropout,
         dropout_att=self.run_config.dropout_att,
...
@@ -1370,8 +1365,7 @@ class QAXLNetModel(tf.keras.Model):
 class QALossLayer(tf.keras.layers.Layer):
-  """Layer computing position and regression loss for question answering task.
-  """
+  """Layer computing position and regression loss for question answering task."""
 
   def __init__(self, d_model, start_n_top, end_n_top, initializer, dropout,
                **kwargs):
...