Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
44fc54a1
Commit
44fc54a1
authored
Oct 14, 2020
by
Allen Wang
Committed by
A. Unique TensorFlower
Oct 14, 2020
Browse files
Implement XLNet QA model.
PiperOrigin-RevId: 337134894
parent
41bcd7d0
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
209 additions
and
205 deletions
+209
-205
official/nlp/modeling/models/__init__.py
official/nlp/modeling/models/__init__.py
+1
-0
official/nlp/modeling/models/xlnet.py
official/nlp/modeling/models/xlnet.py
+82
-0
official/nlp/modeling/models/xlnet_test.py
official/nlp/modeling/models/xlnet_test.py
+88
-0
official/nlp/modeling/networks/span_labeling.py
official/nlp/modeling/networks/span_labeling.py
+26
-2
official/nlp/modeling/networks/span_labeling_test.py
official/nlp/modeling/networks/span_labeling_test.py
+12
-9
official/nlp/xlnet/xlnet_models.py
official/nlp/xlnet/xlnet_models.py
+0
-122
official/nlp/xlnet/xlnet_models_test.py
official/nlp/xlnet/xlnet_models_test.py
+0
-72
No files found.
official/nlp/modeling/models/__init__.py
View file @
44fc54a1
...
...
@@ -21,3 +21,4 @@ from official.nlp.modeling.models.dual_encoder import DualEncoder
from
official.nlp.modeling.models.electra_pretrainer
import
ElectraPretrainer
from
official.nlp.modeling.models.seq2seq_transformer
import
*
from
official.nlp.modeling.models.xlnet
import
XLNetClassifier
from
official.nlp.modeling.models.xlnet
import
XLNetSpanLabeler
official/nlp/modeling/models/xlnet.py
View file @
44fc54a1
...
...
@@ -20,6 +20,7 @@ from typing import Any, Mapping, Union
import
tensorflow
as
tf
from
official.nlp.modeling
import
layers
from
official.nlp.modeling
import
networks
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
'Text'
)
...
...
@@ -98,3 +99,84 @@ class XLNetClassifier(tf.keras.Model):
@
classmethod
def
from_config
(
cls
,
config
,
custom_objects
=
None
):
return
cls
(
**
config
)
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
'Text'
)
class
XLNetSpanLabeler
(
tf
.
keras
.
Model
):
"""Span labeler model based on XLNet.
This is an implementation of the network structure surrounding a
Transformer-XL encoder as described in "XLNet: Generalized Autoregressive
Pretraining for Language Understanding" (https://arxiv.org/abs/1906.08237).
Arguments:
network: A transformer network. This network should output a sequence output
and a classification output. Furthermore, it should expose its embedding
table via a "get_embedding_table" method.
start_n_top: Beam size for span start.
end_n_top: Beam size for span end.
dropout_rate: The dropout rate for the span labeling layer.
span_labeling_activation
initializer: The initializer (if any) to use in the span labeling network.
Defaults to a Glorot uniform initializer.
"""
def
__init__
(
self
,
network
:
Union
[
tf
.
keras
.
layers
.
Layer
,
tf
.
keras
.
Model
],
start_n_top
:
int
,
end_n_top
:
int
,
dropout_rate
:
float
,
span_labeling_activation
:
tf
.
keras
.
initializers
.
Initializer
=
'tanh'
,
initializer
:
tf
.
keras
.
initializers
.
Initializer
=
'glorot_uniform'
,
**
kwargs
):
super
().
__init__
(
**
kwargs
)
self
.
_config
=
{
'network'
:
network
,
'start_n_top'
:
start_n_top
,
'end_n_top'
:
end_n_top
,
'dropout_rate'
:
dropout_rate
,
'span_labeling_activation'
:
span_labeling_activation
,
'initializer'
:
initializer
,
}
self
.
_network
=
network
self
.
_initializer
=
initializer
self
.
_start_n_top
=
start_n_top
self
.
_end_n_top
=
end_n_top
self
.
_dropout_rate
=
dropout_rate
self
.
_activation
=
span_labeling_activation
self
.
span_labeling
=
networks
.
XLNetSpanLabeling
(
input_width
=
network
.
get_config
()[
'inner_size'
],
start_n_top
=
self
.
_start_n_top
,
end_n_top
=
self
.
_end_n_top
,
activation
=
self
.
_activation
,
dropout_rate
=
self
.
_dropout_rate
,
initializer
=
self
.
_initializer
)
def
call
(
self
,
inputs
:
Mapping
[
str
,
Any
]):
input_ids
=
inputs
[
'input_ids'
]
segment_ids
=
inputs
[
'segment_ids'
]
input_mask
=
inputs
[
'input_mask'
]
class_index
=
tf
.
reshape
(
inputs
[
'class_index'
],
[
-
1
])
position_mask
=
inputs
[
'position_mask'
]
start_positions
=
inputs
[
'start_positions'
]
attention_output
,
new_states
=
self
.
_network
(
input_ids
=
input_ids
,
segment_ids
=
segment_ids
,
input_mask
=
input_mask
)
outputs
=
self
.
span_labeling
(
sequence_data
=
attention_output
,
class_index
=
class_index
,
position_mask
=
position_mask
,
start_positions
=
start_positions
)
return
outputs
,
new_states
def
get_config
(
self
):
return
self
.
_config
@
classmethod
def
from_config
(
cls
,
config
,
custom_objects
=
None
):
return
cls
(
**
config
)
official/nlp/modeling/models/xlnet_test.py
View file @
44fc54a1
...
...
@@ -133,5 +133,93 @@ class XLNetClassifierTest(keras_parameterized.TestCase):
new_xlnet_trainer_model
.
get_config
())
@
keras_parameterized
.
run_all_keras_modes
class
XLNetSpanLabelerTest
(
keras_parameterized
.
TestCase
):
@
parameterized
.
parameters
(
1
,
2
)
def
test_xlnet_trainer
(
self
,
top_n
):
"""Validate that the Keras object can be created."""
seq_length
=
4
# Build a simple XLNet based network to use with the XLNet trainer.
xlnet_base
=
_get_xlnet_base
()
# Create an XLNet trainer with the created network.
xlnet_trainer_model
=
xlnet
.
XLNetSpanLabeler
(
network
=
xlnet_base
,
start_n_top
=
top_n
,
end_n_top
=
top_n
,
initializer
=
tf
.
keras
.
initializers
.
RandomNormal
(
stddev
=
0.1
),
span_labeling_activation
=
'tanh'
,
dropout_rate
=
0.1
)
inputs
=
dict
(
input_ids
=
tf
.
keras
.
layers
.
Input
(
shape
=
(
seq_length
,),
dtype
=
tf
.
int32
,
name
=
'input_word_ids'
),
segment_ids
=
tf
.
keras
.
layers
.
Input
(
shape
=
(
seq_length
,),
dtype
=
tf
.
int32
,
name
=
'segment_ids'
),
input_mask
=
tf
.
keras
.
layers
.
Input
(
shape
=
(
seq_length
,),
dtype
=
tf
.
float32
,
name
=
'input_mask'
),
position_mask
=
tf
.
keras
.
layers
.
Input
(
shape
=
(
seq_length
,),
dtype
=
tf
.
float32
,
name
=
'position_mask'
),
class_index
=
tf
.
keras
.
layers
.
Input
(
shape
=
(),
dtype
=
tf
.
int32
,
name
=
'class_index'
),
start_positions
=
tf
.
keras
.
layers
.
Input
(
shape
=
(),
dtype
=
tf
.
int32
,
name
=
'start_positions'
))
outputs
,
_
=
xlnet_trainer_model
(
inputs
)
self
.
assertIsInstance
(
outputs
,
dict
)
# Test tensor value calls for the created model.
batch_size
=
2
sequence_shape
=
(
batch_size
,
seq_length
)
inputs
=
dict
(
input_ids
=
np
.
random
.
randint
(
10
,
size
=
sequence_shape
,
dtype
=
'int32'
),
segment_ids
=
np
.
random
.
randint
(
2
,
size
=
sequence_shape
,
dtype
=
'int32'
),
input_mask
=
np
.
random
.
randint
(
2
,
size
=
sequence_shape
).
astype
(
'float32'
),
position_mask
=
np
.
random
.
randint
(
1
,
size
=
(
sequence_shape
)).
astype
(
'float32'
),
class_index
=
np
.
random
.
randint
(
1
,
size
=
(
batch_size
)).
astype
(
'uint8'
),
start_positions
=
tf
.
random
.
uniform
(
shape
=
(
batch_size
,),
maxval
=
5
,
dtype
=
tf
.
int32
))
outputs
,
_
=
xlnet_trainer_model
(
inputs
)
expected_inference_keys
=
{
'start_top_log_probs'
,
'end_top_log_probs'
,
'class_logits'
,
'start_top_index'
,
'end_top_index'
,
}
self
.
assertSetEqual
(
expected_inference_keys
,
set
(
outputs
.
keys
()))
outputs
,
_
=
xlnet_trainer_model
(
inputs
,
training
=
True
)
self
.
assertIsInstance
(
outputs
,
dict
)
expected_train_keys
=
{
'start_log_probs'
,
'end_log_probs'
,
'class_logits'
}
self
.
assertSetEqual
(
expected_train_keys
,
set
(
outputs
.
keys
()))
self
.
assertIsInstance
(
outputs
,
dict
)
def
test_serialize_deserialize
(
self
):
"""Validates that the XLNet trainer can be serialized and deserialized."""
# Build a simple XLNet based network to use with the XLNet trainer.
xlnet_base
=
_get_xlnet_base
()
# Create an XLNet trainer with the created network.
xlnet_trainer_model
=
xlnet
.
XLNetSpanLabeler
(
network
=
xlnet_base
,
start_n_top
=
2
,
end_n_top
=
2
,
initializer
=
tf
.
keras
.
initializers
.
RandomNormal
(
stddev
=
0.1
),
span_labeling_activation
=
'tanh'
,
dropout_rate
=
0.1
)
# Create another XLNet trainer via serialization and deserialization.
config
=
xlnet_trainer_model
.
get_config
()
new_xlnet_trainer_model
=
xlnet
.
XLNetSpanLabeler
.
from_config
(
config
)
# Validate that the config can be forced to JSON.
_
=
new_xlnet_trainer_model
.
to_json
()
# If serialization was successful, then the new config should match the old.
self
.
assertAllEqual
(
xlnet_trainer_model
.
get_config
(),
new_xlnet_trainer_model
.
get_config
())
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
official/nlp/modeling/networks/span_labeling.py
View file @
44fc54a1
...
...
@@ -113,6 +113,9 @@ class XLNetSpanLabeling(tf.keras.layers.Layer):
positions, and then uses either the true start positions (if training) or
beam search to predict the end positions.
**Note: `compute_with_beam_search` will not work with the Functional API
(https://www.tensorflow.org/guide/keras/functional).
Arguments:
input_width: The innermost dimension of the input tensor to this network.
start_n_top: Beam size for span start.
...
...
@@ -150,6 +153,7 @@ class XLNetSpanLabeling(tf.keras.layers.Layer):
self
.
end_logits_inner_dense
=
tf
.
keras
.
layers
.
Dense
(
units
=
input_width
,
kernel_initializer
=
initializer
,
activation
=
activation
,
name
=
'predictions/transform/end_logits/inner'
)
self
.
end_logits_layer_norm
=
tf
.
keras
.
layers
.
LayerNormalization
(
axis
=-
1
,
epsilon
=
1e-12
,
...
...
@@ -172,13 +176,33 @@ class XLNetSpanLabeling(tf.keras.layers.Layer):
name
=
'predictions/transform/answer_logits/output'
)
def
end_logits
(
self
,
inputs
):
"""Computes the end logits."""
"""Computes the end logits.
Input shapes into the inner, layer norm, output layers should match.
During training, inputs shape should be
[batch_size, seq_length, input_width].
During inference, input shapes should be
[batch_size, seq_length, start_n_top, input_width].
Args:
inputs: The input for end logits.
Returns:
Calculated end logits.
"""
if
len
(
tf
.
shape
(
inputs
))
==
3
:
# inputs: [B, S, H] -> [B, S, 1, H]
inputs
=
tf
.
expand_dims
(
inputs
,
axis
=
2
)
end_logits
=
self
.
end_logits_inner_dense
(
inputs
)
end_logits
=
self
.
end_logits_layer_norm
(
end_logits
)
end_logits
=
self
.
end_logits_output_dense
(
end_logits
)
end_logits
=
tf
.
squeeze
(
end_logits
)
if
tf
.
rank
(
end_logits
)
>
2
:
# shape = [
batch_size, seq_length, start_n_top
]
# shape = [
B, S, K] -> [B, K, S
]
end_logits
=
tf
.
transpose
(
end_logits
,
[
0
,
2
,
1
])
return
end_logits
...
...
official/nlp/modeling/networks/span_labeling_test.py
View file @
44fc54a1
...
...
@@ -234,8 +234,8 @@ class XLNetSpanLabelingTest(keras_parameterized.TestCase):
}
self
.
assertSetEqual
(
expected_keys
,
set
(
output
.
keys
()))
def
test_
functional_model
_invocation
(
self
):
"""Tests basic invocation of this layer wrapped
by
a
Functional model
."""
def
test_
subclass
_invocation
(
self
):
"""Tests basic invocation of this layer wrapped
in
a
subclass
."""
seq_length
=
8
hidden_size
=
4
batch_size
=
2
...
...
@@ -244,7 +244,7 @@ class XLNetSpanLabelingTest(keras_parameterized.TestCase):
dtype
=
tf
.
float32
)
class_index
=
tf
.
keras
.
Input
(
shape
=
(),
dtype
=
tf
.
uint8
)
position_mask
=
tf
.
keras
.
Input
(
shape
=
(
seq_length
),
dtype
=
tf
.
float32
)
start_positions
=
tf
.
keras
.
Input
(
shape
=
(),
dtype
=
tf
.
floa
t32
)
start_positions
=
tf
.
keras
.
Input
(
shape
=
(),
dtype
=
tf
.
in
t32
)
layer
=
span_labeling
.
XLNetSpanLabeling
(
input_width
=
hidden_size
,
...
...
@@ -272,7 +272,8 @@ class XLNetSpanLabelingTest(keras_parameterized.TestCase):
position_mask
=
tf
.
random
.
uniform
(
shape
=
(
batch_size
,
seq_length
),
dtype
=
tf
.
float32
)
class_index
=
tf
.
ones
(
shape
=
(
batch_size
,),
dtype
=
tf
.
uint8
)
start_positions
=
tf
.
random
.
uniform
(
shape
=
(
batch_size
,),
dtype
=
tf
.
float32
)
start_positions
=
tf
.
random
.
uniform
(
shape
=
(
batch_size
,),
maxval
=
5
,
dtype
=
tf
.
int32
)
inputs
=
dict
(
sequence_data
=
sequence_data
,
position_mask
=
position_mask
,
...
...
@@ -282,14 +283,16 @@ class XLNetSpanLabelingTest(keras_parameterized.TestCase):
output
=
model
(
inputs
)
self
.
assertIsInstance
(
output
,
dict
)
# Test `call` with training flag.
output
=
model
.
call
(
inputs
,
training
=
True
)
self
.
assertIsInstance
(
output
,
dict
)
# Test `call` without training flag.
output
=
model
.
call
(
inputs
,
training
=
False
)
output
=
model
(
inputs
,
training
=
False
)
self
.
assertIsInstance
(
output
,
dict
)
# Test `call` with training flag.
# Note: this fails due to incompatibility with the functional API.
with
self
.
assertRaisesRegexp
(
AssertionError
,
'Could not compute output KerasTensor'
):
model
(
inputs
,
training
=
True
)
def
test_serialize_deserialize
(
self
):
# Create a network object that sets all of its config options.
network
=
span_labeling
.
XLNetSpanLabeling
(
...
...
official/nlp/xlnet/xlnet_models.py
deleted
100644 → 0
View file @
41bcd7d0
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""XLNet models that are compatible with TF 2.x."""
import
tensorflow
as
tf
from
official.nlp.modeling
import
models
from
official.nlp.modeling
import
networks
from
official.nlp.xlnet
import
xlnet_config
def
_get_initializer
(
initialization_method
:
str
,
initialization_range
:
float
,
initialization_std
:
float
)
->
tf
.
keras
.
initializers
.
Initializer
:
"""Gets variable initializer."""
if
initialization_method
==
'uniform'
:
initializer
=
tf
.
keras
.
initializers
.
RandomUniform
(
minval
=-
initialization_range
,
maxval
=
initialization_range
)
elif
initialization_method
==
'normal'
:
initializer
=
tf
.
keras
.
initializers
.
RandomNormal
(
stddev
=
initialization_std
)
else
:
raise
ValueError
(
'Initializer {} not supported'
.
format
(
initialization_method
))
return
initializer
def
get_xlnet_base
(
model_config
:
xlnet_config
.
XLNetConfig
,
run_config
:
xlnet_config
.
RunConfig
,
attention_type
:
str
,
two_stream
:
bool
,
use_cls_mask
:
bool
)
->
tf
.
keras
.
Model
:
"""Gets an 'XLNetBase' object.
Args:
model_config: the config that defines the core XLNet model.
run_config: separate runtime configuration with extra parameters.
attention_type: the attention type for the base XLNet model, "uni" or "bi".
two_stream: whether or not to use two strema attention.
use_cls_mask: whether or not cls mask is included in the input sequences.
Returns:
An XLNetBase object.
"""
initializer
=
_get_initializer
(
initialization_method
=
run_config
.
init_method
,
initialization_range
=
run_config
.
init_range
,
initialization_std
=
run_config
.
init_std
)
kwargs
=
dict
(
vocab_size
=
model_config
.
n_token
,
num_layers
=
model_config
.
n_layer
,
hidden_size
=
model_config
.
d_model
,
num_attention_heads
=
model_config
.
n_head
,
head_size
=
model_config
.
d_head
,
inner_size
=
model_config
.
d_inner
,
dropout_rate
=
run_config
.
dropout
,
attention_dropout_rate
=
run_config
.
dropout_att
,
attention_type
=
attention_type
,
bi_data
=
run_config
.
bi_data
,
initializer
=
initializer
,
two_stream
=
two_stream
,
tie_attention_biases
=
not
model_config
.
untie_r
,
memory_length
=
run_config
.
mem_len
,
clamp_length
=
run_config
.
clamp_len
,
reuse_length
=
run_config
.
reuse_len
,
inner_activation
=
model_config
.
ff_activation
,
use_cls_mask
=
use_cls_mask
)
return
networks
.
XLNetBase
(
**
kwargs
)
def
classifier_model
(
model_config
:
xlnet_config
.
XLNetConfig
,
run_config
:
xlnet_config
.
RunConfig
,
num_labels
:
int
,
final_layer_initializer
:
tf
.
keras
.
initializers
.
Initializer
=
None
)
->
tf
.
keras
.
Model
:
"""Returns a TF2 Keras XLNet classifier model.
Construct a Keras model for predicting `num_labels` outputs from an input with
maximum sequence length `max_seq_length`.
Args:
model_config: the config that defines the core XLNet model.
run_config: separate runtime configuration with extra parameters.
num_labels: integer, the number of classes.
final_layer_initializer: Initializer for final dense layer. If `None`, then
it defaults to the one specified in `run_config`.
Returns:
Combined prediction model inputs -> (one-hot labels)
XLNet sub-model inputs -> (xlnet_outputs)
where inputs are:
(words, segments, mask, permutation mask,
target mapping, masked tokens)
"""
if
final_layer_initializer
is
not
None
:
initializer
=
final_layer_initializer
else
:
initializer
=
tf
.
keras
.
initializers
.
RandomNormal
(
mean
=
0.
,
stddev
=
.
02
)
xlnet_base
=
get_xlnet_base
(
model_config
=
model_config
,
run_config
=
run_config
,
attention_type
=
'bi'
,
two_stream
=
False
,
use_cls_mask
=
False
)
return
models
.
XLNetClassifier
(
network
=
xlnet_base
,
num_classes
=
num_labels
,
dropout_rate
=
run_config
.
dropout
,
summary_type
=
'last'
,
initializer
=
initializer
),
xlnet_base
official/nlp/xlnet/xlnet_models_test.py
deleted
100644 → 0
View file @
41bcd7d0
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
tensorflow
as
tf
from
official.nlp.xlnet
import
xlnet_config
from
official.nlp.xlnet
import
xlnet_models
class
XLNetModelsTest
(
tf
.
test
.
TestCase
):
def
setUp
(
self
):
super
(
XLNetModelsTest
,
self
).
setUp
()
self
.
_xlnet_test_config
=
xlnet_config
.
XLNetConfig
(
args_dict
=
dict
(
n_layer
=
2
,
d_model
=
4
,
n_head
=
1
,
d_head
=
2
,
d_inner
=
4
,
ff_activation
=
'gelu'
,
untie_r
=
True
,
n_token
=
32000
))
self
.
_run_config
=
xlnet_config
.
RunConfig
(
is_training
=
True
,
use_tpu
=
False
,
dropout
=
0.0
,
dropout_att
=
0.0
,
init_method
=
'normal'
,
init_range
=
0.1
,
init_std
=
0.02
,
mem_len
=
0
,
reuse_len
=
4
,
bi_data
=
False
,
clamp_len
=-
1
,
same_length
=
False
)
def
test_xlnet_base
(
self
):
xlnet_base
=
xlnet_models
.
get_xlnet_base
(
model_config
=
self
.
_xlnet_test_config
,
run_config
=
self
.
_run_config
,
attention_type
=
'bi'
,
two_stream
=
False
,
use_cls_mask
=
False
)
self
.
assertIsInstance
(
xlnet_base
,
tf
.
keras
.
layers
.
Layer
)
def
test_xlnet_classifier
(
self
):
xlnet_classifier
,
xlnet_base
=
xlnet_models
.
classifier_model
(
model_config
=
self
.
_xlnet_test_config
,
run_config
=
self
.
_run_config
,
num_labels
=
2
)
self
.
assertIsInstance
(
xlnet_classifier
,
tf
.
keras
.
Model
)
self
.
assertIsInstance
(
xlnet_base
,
tf
.
keras
.
layers
.
Layer
)
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment