Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
7a6a8741
Commit
7a6a8741
authored
Oct 14, 2020
by
Allen Wang
Committed by
A. Unique TensorFlower
Oct 14, 2020
Browse files
Implement XLNet QA model.
PiperOrigin-RevId: 337134894
parent
dd04e547
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
209 additions
and
205 deletions
+209
-205
official/nlp/modeling/models/__init__.py
official/nlp/modeling/models/__init__.py
+1
-0
official/nlp/modeling/models/xlnet.py
official/nlp/modeling/models/xlnet.py
+82
-0
official/nlp/modeling/models/xlnet_test.py
official/nlp/modeling/models/xlnet_test.py
+88
-0
official/nlp/modeling/networks/span_labeling.py
official/nlp/modeling/networks/span_labeling.py
+26
-2
official/nlp/modeling/networks/span_labeling_test.py
official/nlp/modeling/networks/span_labeling_test.py
+12
-9
official/nlp/xlnet/xlnet_models.py
official/nlp/xlnet/xlnet_models.py
+0
-122
official/nlp/xlnet/xlnet_models_test.py
official/nlp/xlnet/xlnet_models_test.py
+0
-72
No files found.
official/nlp/modeling/models/__init__.py
View file @
7a6a8741
...
@@ -21,3 +21,4 @@ from official.nlp.modeling.models.dual_encoder import DualEncoder
...
@@ -21,3 +21,4 @@ from official.nlp.modeling.models.dual_encoder import DualEncoder
from
official.nlp.modeling.models.electra_pretrainer
import
ElectraPretrainer
from
official.nlp.modeling.models.electra_pretrainer
import
ElectraPretrainer
from
official.nlp.modeling.models.seq2seq_transformer
import
*
from
official.nlp.modeling.models.seq2seq_transformer
import
*
from
official.nlp.modeling.models.xlnet
import
XLNetClassifier
from
official.nlp.modeling.models.xlnet
import
XLNetClassifier
from
official.nlp.modeling.models.xlnet
import
XLNetSpanLabeler
official/nlp/modeling/models/xlnet.py
View file @
7a6a8741
...
@@ -20,6 +20,7 @@ from typing import Any, Mapping, Union
...
@@ -20,6 +20,7 @@ from typing import Any, Mapping, Union
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.nlp.modeling
import
layers
from
official.nlp.modeling
import
layers
from
official.nlp.modeling
import
networks
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
'Text'
)
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
'Text'
)
...
@@ -98,3 +99,84 @@ class XLNetClassifier(tf.keras.Model):
...
@@ -98,3 +99,84 @@ class XLNetClassifier(tf.keras.Model):
@
classmethod
@
classmethod
def
from_config
(
cls
,
config
,
custom_objects
=
None
):
def
from_config
(
cls
,
config
,
custom_objects
=
None
):
return
cls
(
**
config
)
return
cls
(
**
config
)
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
'Text'
)
class
XLNetSpanLabeler
(
tf
.
keras
.
Model
):
"""Span labeler model based on XLNet.
This is an implementation of the network structure surrounding a
Transformer-XL encoder as described in "XLNet: Generalized Autoregressive
Pretraining for Language Understanding" (https://arxiv.org/abs/1906.08237).
Arguments:
network: A transformer network. This network should output a sequence output
and a classification output. Furthermore, it should expose its embedding
table via a "get_embedding_table" method.
start_n_top: Beam size for span start.
end_n_top: Beam size for span end.
dropout_rate: The dropout rate for the span labeling layer.
span_labeling_activation
initializer: The initializer (if any) to use in the span labeling network.
Defaults to a Glorot uniform initializer.
"""
def
__init__
(
self
,
network
:
Union
[
tf
.
keras
.
layers
.
Layer
,
tf
.
keras
.
Model
],
start_n_top
:
int
,
end_n_top
:
int
,
dropout_rate
:
float
,
span_labeling_activation
:
tf
.
keras
.
initializers
.
Initializer
=
'tanh'
,
initializer
:
tf
.
keras
.
initializers
.
Initializer
=
'glorot_uniform'
,
**
kwargs
):
super
().
__init__
(
**
kwargs
)
self
.
_config
=
{
'network'
:
network
,
'start_n_top'
:
start_n_top
,
'end_n_top'
:
end_n_top
,
'dropout_rate'
:
dropout_rate
,
'span_labeling_activation'
:
span_labeling_activation
,
'initializer'
:
initializer
,
}
self
.
_network
=
network
self
.
_initializer
=
initializer
self
.
_start_n_top
=
start_n_top
self
.
_end_n_top
=
end_n_top
self
.
_dropout_rate
=
dropout_rate
self
.
_activation
=
span_labeling_activation
self
.
span_labeling
=
networks
.
XLNetSpanLabeling
(
input_width
=
network
.
get_config
()[
'inner_size'
],
start_n_top
=
self
.
_start_n_top
,
end_n_top
=
self
.
_end_n_top
,
activation
=
self
.
_activation
,
dropout_rate
=
self
.
_dropout_rate
,
initializer
=
self
.
_initializer
)
def
call
(
self
,
inputs
:
Mapping
[
str
,
Any
]):
input_ids
=
inputs
[
'input_ids'
]
segment_ids
=
inputs
[
'segment_ids'
]
input_mask
=
inputs
[
'input_mask'
]
class_index
=
tf
.
reshape
(
inputs
[
'class_index'
],
[
-
1
])
position_mask
=
inputs
[
'position_mask'
]
start_positions
=
inputs
[
'start_positions'
]
attention_output
,
new_states
=
self
.
_network
(
input_ids
=
input_ids
,
segment_ids
=
segment_ids
,
input_mask
=
input_mask
)
outputs
=
self
.
span_labeling
(
sequence_data
=
attention_output
,
class_index
=
class_index
,
position_mask
=
position_mask
,
start_positions
=
start_positions
)
return
outputs
,
new_states
def
get_config
(
self
):
return
self
.
_config
@
classmethod
def
from_config
(
cls
,
config
,
custom_objects
=
None
):
return
cls
(
**
config
)
official/nlp/modeling/models/xlnet_test.py
View file @
7a6a8741
...
@@ -133,5 +133,93 @@ class XLNetClassifierTest(keras_parameterized.TestCase):
...
@@ -133,5 +133,93 @@ class XLNetClassifierTest(keras_parameterized.TestCase):
new_xlnet_trainer_model
.
get_config
())
new_xlnet_trainer_model
.
get_config
())
@
keras_parameterized
.
run_all_keras_modes
class
XLNetSpanLabelerTest
(
keras_parameterized
.
TestCase
):
@
parameterized
.
parameters
(
1
,
2
)
def
test_xlnet_trainer
(
self
,
top_n
):
"""Validate that the Keras object can be created."""
seq_length
=
4
# Build a simple XLNet based network to use with the XLNet trainer.
xlnet_base
=
_get_xlnet_base
()
# Create an XLNet trainer with the created network.
xlnet_trainer_model
=
xlnet
.
XLNetSpanLabeler
(
network
=
xlnet_base
,
start_n_top
=
top_n
,
end_n_top
=
top_n
,
initializer
=
tf
.
keras
.
initializers
.
RandomNormal
(
stddev
=
0.1
),
span_labeling_activation
=
'tanh'
,
dropout_rate
=
0.1
)
inputs
=
dict
(
input_ids
=
tf
.
keras
.
layers
.
Input
(
shape
=
(
seq_length
,),
dtype
=
tf
.
int32
,
name
=
'input_word_ids'
),
segment_ids
=
tf
.
keras
.
layers
.
Input
(
shape
=
(
seq_length
,),
dtype
=
tf
.
int32
,
name
=
'segment_ids'
),
input_mask
=
tf
.
keras
.
layers
.
Input
(
shape
=
(
seq_length
,),
dtype
=
tf
.
float32
,
name
=
'input_mask'
),
position_mask
=
tf
.
keras
.
layers
.
Input
(
shape
=
(
seq_length
,),
dtype
=
tf
.
float32
,
name
=
'position_mask'
),
class_index
=
tf
.
keras
.
layers
.
Input
(
shape
=
(),
dtype
=
tf
.
int32
,
name
=
'class_index'
),
start_positions
=
tf
.
keras
.
layers
.
Input
(
shape
=
(),
dtype
=
tf
.
int32
,
name
=
'start_positions'
))
outputs
,
_
=
xlnet_trainer_model
(
inputs
)
self
.
assertIsInstance
(
outputs
,
dict
)
# Test tensor value calls for the created model.
batch_size
=
2
sequence_shape
=
(
batch_size
,
seq_length
)
inputs
=
dict
(
input_ids
=
np
.
random
.
randint
(
10
,
size
=
sequence_shape
,
dtype
=
'int32'
),
segment_ids
=
np
.
random
.
randint
(
2
,
size
=
sequence_shape
,
dtype
=
'int32'
),
input_mask
=
np
.
random
.
randint
(
2
,
size
=
sequence_shape
).
astype
(
'float32'
),
position_mask
=
np
.
random
.
randint
(
1
,
size
=
(
sequence_shape
)).
astype
(
'float32'
),
class_index
=
np
.
random
.
randint
(
1
,
size
=
(
batch_size
)).
astype
(
'uint8'
),
start_positions
=
tf
.
random
.
uniform
(
shape
=
(
batch_size
,),
maxval
=
5
,
dtype
=
tf
.
int32
))
outputs
,
_
=
xlnet_trainer_model
(
inputs
)
expected_inference_keys
=
{
'start_top_log_probs'
,
'end_top_log_probs'
,
'class_logits'
,
'start_top_index'
,
'end_top_index'
,
}
self
.
assertSetEqual
(
expected_inference_keys
,
set
(
outputs
.
keys
()))
outputs
,
_
=
xlnet_trainer_model
(
inputs
,
training
=
True
)
self
.
assertIsInstance
(
outputs
,
dict
)
expected_train_keys
=
{
'start_log_probs'
,
'end_log_probs'
,
'class_logits'
}
self
.
assertSetEqual
(
expected_train_keys
,
set
(
outputs
.
keys
()))
self
.
assertIsInstance
(
outputs
,
dict
)
def
test_serialize_deserialize
(
self
):
"""Validates that the XLNet trainer can be serialized and deserialized."""
# Build a simple XLNet based network to use with the XLNet trainer.
xlnet_base
=
_get_xlnet_base
()
# Create an XLNet trainer with the created network.
xlnet_trainer_model
=
xlnet
.
XLNetSpanLabeler
(
network
=
xlnet_base
,
start_n_top
=
2
,
end_n_top
=
2
,
initializer
=
tf
.
keras
.
initializers
.
RandomNormal
(
stddev
=
0.1
),
span_labeling_activation
=
'tanh'
,
dropout_rate
=
0.1
)
# Create another XLNet trainer via serialization and deserialization.
config
=
xlnet_trainer_model
.
get_config
()
new_xlnet_trainer_model
=
xlnet
.
XLNetSpanLabeler
.
from_config
(
config
)
# Validate that the config can be forced to JSON.
_
=
new_xlnet_trainer_model
.
to_json
()
# If serialization was successful, then the new config should match the old.
self
.
assertAllEqual
(
xlnet_trainer_model
.
get_config
(),
new_xlnet_trainer_model
.
get_config
())
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
tf
.
test
.
main
()
official/nlp/modeling/networks/span_labeling.py
View file @
7a6a8741
...
@@ -113,6 +113,9 @@ class XLNetSpanLabeling(tf.keras.layers.Layer):
...
@@ -113,6 +113,9 @@ class XLNetSpanLabeling(tf.keras.layers.Layer):
positions, and then uses either the true start positions (if training) or
positions, and then uses either the true start positions (if training) or
beam search to predict the end positions.
beam search to predict the end positions.
**Note: `compute_with_beam_search` will not work with the Functional API
(https://www.tensorflow.org/guide/keras/functional).
Arguments:
Arguments:
input_width: The innermost dimension of the input tensor to this network.
input_width: The innermost dimension of the input tensor to this network.
start_n_top: Beam size for span start.
start_n_top: Beam size for span start.
...
@@ -150,6 +153,7 @@ class XLNetSpanLabeling(tf.keras.layers.Layer):
...
@@ -150,6 +153,7 @@ class XLNetSpanLabeling(tf.keras.layers.Layer):
self
.
end_logits_inner_dense
=
tf
.
keras
.
layers
.
Dense
(
self
.
end_logits_inner_dense
=
tf
.
keras
.
layers
.
Dense
(
units
=
input_width
,
units
=
input_width
,
kernel_initializer
=
initializer
,
kernel_initializer
=
initializer
,
activation
=
activation
,
name
=
'predictions/transform/end_logits/inner'
)
name
=
'predictions/transform/end_logits/inner'
)
self
.
end_logits_layer_norm
=
tf
.
keras
.
layers
.
LayerNormalization
(
self
.
end_logits_layer_norm
=
tf
.
keras
.
layers
.
LayerNormalization
(
axis
=-
1
,
epsilon
=
1e-12
,
axis
=-
1
,
epsilon
=
1e-12
,
...
@@ -172,13 +176,33 @@ class XLNetSpanLabeling(tf.keras.layers.Layer):
...
@@ -172,13 +176,33 @@ class XLNetSpanLabeling(tf.keras.layers.Layer):
name
=
'predictions/transform/answer_logits/output'
)
name
=
'predictions/transform/answer_logits/output'
)
def
end_logits
(
self
,
inputs
):
def
end_logits
(
self
,
inputs
):
"""Computes the end logits."""
"""Computes the end logits.
Input shapes into the inner, layer norm, output layers should match.
During training, inputs shape should be
[batch_size, seq_length, input_width].
During inference, input shapes should be
[batch_size, seq_length, start_n_top, input_width].
Args:
inputs: The input for end logits.
Returns:
Calculated end logits.
"""
if
len
(
tf
.
shape
(
inputs
))
==
3
:
# inputs: [B, S, H] -> [B, S, 1, H]
inputs
=
tf
.
expand_dims
(
inputs
,
axis
=
2
)
end_logits
=
self
.
end_logits_inner_dense
(
inputs
)
end_logits
=
self
.
end_logits_inner_dense
(
inputs
)
end_logits
=
self
.
end_logits_layer_norm
(
end_logits
)
end_logits
=
self
.
end_logits_layer_norm
(
end_logits
)
end_logits
=
self
.
end_logits_output_dense
(
end_logits
)
end_logits
=
self
.
end_logits_output_dense
(
end_logits
)
end_logits
=
tf
.
squeeze
(
end_logits
)
end_logits
=
tf
.
squeeze
(
end_logits
)
if
tf
.
rank
(
end_logits
)
>
2
:
if
tf
.
rank
(
end_logits
)
>
2
:
# shape = [
batch_size, seq_length, start_n_top
]
# shape = [
B, S, K] -> [B, K, S
]
end_logits
=
tf
.
transpose
(
end_logits
,
[
0
,
2
,
1
])
end_logits
=
tf
.
transpose
(
end_logits
,
[
0
,
2
,
1
])
return
end_logits
return
end_logits
...
...
official/nlp/modeling/networks/span_labeling_test.py
View file @
7a6a8741
...
@@ -234,8 +234,8 @@ class XLNetSpanLabelingTest(keras_parameterized.TestCase):
...
@@ -234,8 +234,8 @@ class XLNetSpanLabelingTest(keras_parameterized.TestCase):
}
}
self
.
assertSetEqual
(
expected_keys
,
set
(
output
.
keys
()))
self
.
assertSetEqual
(
expected_keys
,
set
(
output
.
keys
()))
def
test_
functional_model
_invocation
(
self
):
def
test_
subclass
_invocation
(
self
):
"""Tests basic invocation of this layer wrapped
by
a
Functional model
."""
"""Tests basic invocation of this layer wrapped
in
a
subclass
."""
seq_length
=
8
seq_length
=
8
hidden_size
=
4
hidden_size
=
4
batch_size
=
2
batch_size
=
2
...
@@ -244,7 +244,7 @@ class XLNetSpanLabelingTest(keras_parameterized.TestCase):
...
@@ -244,7 +244,7 @@ class XLNetSpanLabelingTest(keras_parameterized.TestCase):
dtype
=
tf
.
float32
)
dtype
=
tf
.
float32
)
class_index
=
tf
.
keras
.
Input
(
shape
=
(),
dtype
=
tf
.
uint8
)
class_index
=
tf
.
keras
.
Input
(
shape
=
(),
dtype
=
tf
.
uint8
)
position_mask
=
tf
.
keras
.
Input
(
shape
=
(
seq_length
),
dtype
=
tf
.
float32
)
position_mask
=
tf
.
keras
.
Input
(
shape
=
(
seq_length
),
dtype
=
tf
.
float32
)
start_positions
=
tf
.
keras
.
Input
(
shape
=
(),
dtype
=
tf
.
floa
t32
)
start_positions
=
tf
.
keras
.
Input
(
shape
=
(),
dtype
=
tf
.
in
t32
)
layer
=
span_labeling
.
XLNetSpanLabeling
(
layer
=
span_labeling
.
XLNetSpanLabeling
(
input_width
=
hidden_size
,
input_width
=
hidden_size
,
...
@@ -272,7 +272,8 @@ class XLNetSpanLabelingTest(keras_parameterized.TestCase):
...
@@ -272,7 +272,8 @@ class XLNetSpanLabelingTest(keras_parameterized.TestCase):
position_mask
=
tf
.
random
.
uniform
(
position_mask
=
tf
.
random
.
uniform
(
shape
=
(
batch_size
,
seq_length
),
dtype
=
tf
.
float32
)
shape
=
(
batch_size
,
seq_length
),
dtype
=
tf
.
float32
)
class_index
=
tf
.
ones
(
shape
=
(
batch_size
,),
dtype
=
tf
.
uint8
)
class_index
=
tf
.
ones
(
shape
=
(
batch_size
,),
dtype
=
tf
.
uint8
)
start_positions
=
tf
.
random
.
uniform
(
shape
=
(
batch_size
,),
dtype
=
tf
.
float32
)
start_positions
=
tf
.
random
.
uniform
(
shape
=
(
batch_size
,),
maxval
=
5
,
dtype
=
tf
.
int32
)
inputs
=
dict
(
sequence_data
=
sequence_data
,
inputs
=
dict
(
sequence_data
=
sequence_data
,
position_mask
=
position_mask
,
position_mask
=
position_mask
,
...
@@ -282,14 +283,16 @@ class XLNetSpanLabelingTest(keras_parameterized.TestCase):
...
@@ -282,14 +283,16 @@ class XLNetSpanLabelingTest(keras_parameterized.TestCase):
output
=
model
(
inputs
)
output
=
model
(
inputs
)
self
.
assertIsInstance
(
output
,
dict
)
self
.
assertIsInstance
(
output
,
dict
)
# Test `call` with training flag.
output
=
model
.
call
(
inputs
,
training
=
True
)
self
.
assertIsInstance
(
output
,
dict
)
# Test `call` without training flag.
# Test `call` without training flag.
output
=
model
.
call
(
inputs
,
training
=
False
)
output
=
model
(
inputs
,
training
=
False
)
self
.
assertIsInstance
(
output
,
dict
)
self
.
assertIsInstance
(
output
,
dict
)
# Test `call` with training flag.
# Note: this fails due to incompatibility with the functional API.
with
self
.
assertRaisesRegexp
(
AssertionError
,
'Could not compute output KerasTensor'
):
model
(
inputs
,
training
=
True
)
def
test_serialize_deserialize
(
self
):
def
test_serialize_deserialize
(
self
):
# Create a network object that sets all of its config options.
# Create a network object that sets all of its config options.
network
=
span_labeling
.
XLNetSpanLabeling
(
network
=
span_labeling
.
XLNetSpanLabeling
(
...
...
official/nlp/xlnet/xlnet_models.py
deleted
100644 → 0
View file @
dd04e547
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""XLNet models that are compatible with TF 2.x."""
import
tensorflow
as
tf
from
official.nlp.modeling
import
models
from
official.nlp.modeling
import
networks
from
official.nlp.xlnet
import
xlnet_config
def
_get_initializer
(
initialization_method
:
str
,
initialization_range
:
float
,
initialization_std
:
float
)
->
tf
.
keras
.
initializers
.
Initializer
:
"""Gets variable initializer."""
if
initialization_method
==
'uniform'
:
initializer
=
tf
.
keras
.
initializers
.
RandomUniform
(
minval
=-
initialization_range
,
maxval
=
initialization_range
)
elif
initialization_method
==
'normal'
:
initializer
=
tf
.
keras
.
initializers
.
RandomNormal
(
stddev
=
initialization_std
)
else
:
raise
ValueError
(
'Initializer {} not supported'
.
format
(
initialization_method
))
return
initializer
def
get_xlnet_base
(
model_config
:
xlnet_config
.
XLNetConfig
,
run_config
:
xlnet_config
.
RunConfig
,
attention_type
:
str
,
two_stream
:
bool
,
use_cls_mask
:
bool
)
->
tf
.
keras
.
Model
:
"""Gets an 'XLNetBase' object.
Args:
model_config: the config that defines the core XLNet model.
run_config: separate runtime configuration with extra parameters.
attention_type: the attention type for the base XLNet model, "uni" or "bi".
two_stream: whether or not to use two strema attention.
use_cls_mask: whether or not cls mask is included in the input sequences.
Returns:
An XLNetBase object.
"""
initializer
=
_get_initializer
(
initialization_method
=
run_config
.
init_method
,
initialization_range
=
run_config
.
init_range
,
initialization_std
=
run_config
.
init_std
)
kwargs
=
dict
(
vocab_size
=
model_config
.
n_token
,
num_layers
=
model_config
.
n_layer
,
hidden_size
=
model_config
.
d_model
,
num_attention_heads
=
model_config
.
n_head
,
head_size
=
model_config
.
d_head
,
inner_size
=
model_config
.
d_inner
,
dropout_rate
=
run_config
.
dropout
,
attention_dropout_rate
=
run_config
.
dropout_att
,
attention_type
=
attention_type
,
bi_data
=
run_config
.
bi_data
,
initializer
=
initializer
,
two_stream
=
two_stream
,
tie_attention_biases
=
not
model_config
.
untie_r
,
memory_length
=
run_config
.
mem_len
,
clamp_length
=
run_config
.
clamp_len
,
reuse_length
=
run_config
.
reuse_len
,
inner_activation
=
model_config
.
ff_activation
,
use_cls_mask
=
use_cls_mask
)
return
networks
.
XLNetBase
(
**
kwargs
)
def
classifier_model
(
model_config
:
xlnet_config
.
XLNetConfig
,
run_config
:
xlnet_config
.
RunConfig
,
num_labels
:
int
,
final_layer_initializer
:
tf
.
keras
.
initializers
.
Initializer
=
None
)
->
tf
.
keras
.
Model
:
"""Returns a TF2 Keras XLNet classifier model.
Construct a Keras model for predicting `num_labels` outputs from an input with
maximum sequence length `max_seq_length`.
Args:
model_config: the config that defines the core XLNet model.
run_config: separate runtime configuration with extra parameters.
num_labels: integer, the number of classes.
final_layer_initializer: Initializer for final dense layer. If `None`, then
it defaults to the one specified in `run_config`.
Returns:
Combined prediction model inputs -> (one-hot labels)
XLNet sub-model inputs -> (xlnet_outputs)
where inputs are:
(words, segments, mask, permutation mask,
target mapping, masked tokens)
"""
if
final_layer_initializer
is
not
None
:
initializer
=
final_layer_initializer
else
:
initializer
=
tf
.
keras
.
initializers
.
RandomNormal
(
mean
=
0.
,
stddev
=
.
02
)
xlnet_base
=
get_xlnet_base
(
model_config
=
model_config
,
run_config
=
run_config
,
attention_type
=
'bi'
,
two_stream
=
False
,
use_cls_mask
=
False
)
return
models
.
XLNetClassifier
(
network
=
xlnet_base
,
num_classes
=
num_labels
,
dropout_rate
=
run_config
.
dropout
,
summary_type
=
'last'
,
initializer
=
initializer
),
xlnet_base
official/nlp/xlnet/xlnet_models_test.py
deleted
100644 → 0
View file @
dd04e547
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
tensorflow
as
tf
from
official.nlp.xlnet
import
xlnet_config
from
official.nlp.xlnet
import
xlnet_models
class
XLNetModelsTest
(
tf
.
test
.
TestCase
):
def
setUp
(
self
):
super
(
XLNetModelsTest
,
self
).
setUp
()
self
.
_xlnet_test_config
=
xlnet_config
.
XLNetConfig
(
args_dict
=
dict
(
n_layer
=
2
,
d_model
=
4
,
n_head
=
1
,
d_head
=
2
,
d_inner
=
4
,
ff_activation
=
'gelu'
,
untie_r
=
True
,
n_token
=
32000
))
self
.
_run_config
=
xlnet_config
.
RunConfig
(
is_training
=
True
,
use_tpu
=
False
,
dropout
=
0.0
,
dropout_att
=
0.0
,
init_method
=
'normal'
,
init_range
=
0.1
,
init_std
=
0.02
,
mem_len
=
0
,
reuse_len
=
4
,
bi_data
=
False
,
clamp_len
=-
1
,
same_length
=
False
)
def
test_xlnet_base
(
self
):
xlnet_base
=
xlnet_models
.
get_xlnet_base
(
model_config
=
self
.
_xlnet_test_config
,
run_config
=
self
.
_run_config
,
attention_type
=
'bi'
,
two_stream
=
False
,
use_cls_mask
=
False
)
self
.
assertIsInstance
(
xlnet_base
,
tf
.
keras
.
layers
.
Layer
)
def
test_xlnet_classifier
(
self
):
xlnet_classifier
,
xlnet_base
=
xlnet_models
.
classifier_model
(
model_config
=
self
.
_xlnet_test_config
,
run_config
=
self
.
_run_config
,
num_labels
=
2
)
self
.
assertIsInstance
(
xlnet_classifier
,
tf
.
keras
.
Model
)
self
.
assertIsInstance
(
xlnet_base
,
tf
.
keras
.
layers
.
Layer
)
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment