ModelZoo / ResNet50_tensorflow / Commits

Commit e3a000ad
Authored Nov 10, 2020 by Allen Wang; committed by A. Unique TensorFlower, Nov 10, 2020

Internal change

PiperOrigin-RevId: 341642152
Parent: d0b78926

Showing 7 changed files with 313 additions and 175 deletions.
official/nlp/data/squad_lib_sp.py                     +57  -59
official/nlp/modeling/models/xlnet.py                 +17  -14
official/nlp/modeling/models/xlnet_test.py            +25  -21
official/nlp/modeling/networks/span_labeling.py       +47  -38
official/nlp/modeling/networks/span_labeling_test.py  +18  -25
official/nlp/modeling/networks/xlnet_base.py           +2   -1
official/nlp/tasks/question_answering.py             +147  -17
official/nlp/data/squad_lib_sp.py
...
@@ -695,19 +695,13 @@ def postprocess_output(all_examples,
       null_start_logit = result.start_logits[0]
       null_end_logit = result.end_logits[0]
-      start_indexes_and_logits = _get_best_indexes_and_logits(
-          result=result,
-          n_best_size=n_best_size,
-          start=True,
-          xlnet_format=xlnet_format)
-      end_indexes_and_logits = _get_best_indexes_and_logits(
-          result=result,
-          n_best_size=n_best_size,
-          start=False,
-          xlnet_format=xlnet_format)
-      doc_offset = 0 if xlnet_format else feature.tokens.index("[SEP]") + 1
-      for start_index, start_logit in start_indexes_and_logits:
-        for end_index, end_logit in end_indexes_and_logits:
+      doc_offset = 0 if xlnet_format else feature.tokens.index("[SEP]") + 1
+      for (start_index, start_logit,
+           end_index, end_logit) in _get_best_indexes_and_logits(
+               result=result,
+               n_best_size=n_best_size,
+               xlnet_format=xlnet_format):
           # We could hypothetically create invalid predictions, e.g., predict
           # that the start of the span is in the question. We throw out all
           # invalid predictions.
...
...
@@ -752,7 +746,7 @@ def postprocess_output(all_examples,
     if len(nbest) >= n_best_size:
       break
     feature = features[pred.feature_index]
-    if pred.start_index >= 0:  # this is a non-null prediction
+    if pred.start_index >= 0 or xlnet_format:  # this is a non-null prediction
       tok_start_to_orig_index = feature.tok_start_to_orig_index
       tok_end_to_orig_index = feature.tok_end_to_orig_index
      start_orig_pos = tok_start_to_orig_index[pred.start_index]
...
...
@@ -774,7 +768,7 @@ def postprocess_output(all_examples,
             start_logit=pred.start_logit,
             end_logit=pred.end_logit))
-  # if we didn't inlude the empty option in the n-best, inlcude it
+  # if we didn't inlude the empty option in the n-best, include it
   if version_2_with_negative and not xlnet_format:
     if "" not in seen_predictions:
       nbest.append(
...
...
@@ -814,6 +808,11 @@ def postprocess_output(all_examples,
       all_predictions[example.qas_id] = nbest_json[0]["text"]
     else:
       assert best_non_null_entry is not None
+      if xlnet_format:
+        score_diff = score_null
+        scores_diff_json[example.qas_id] = score_diff
+        all_predictions[example.qas_id] = best_non_null_entry.text
+      else:
         # predict "" iff the null score - the score of best non-null > threshold
         score_diff = score_null - best_non_null_entry.start_logit - (
             best_non_null_entry.end_logit)
...
...
@@ -835,28 +834,27 @@ def write_to_json_files(json_records, json_file):
 def _get_best_indexes_and_logits(result,
                                  n_best_size,
-                                 start=False,
                                  xlnet_format=False):
   """Generates the n-best indexes and logits from a list."""
   if xlnet_format:
     for i in range(n_best_size):
       for j in range(n_best_size):
         j_index = i * n_best_size + j
-        if start:
-          yield result.start_indexes[i], result.start_logits[i]
-        else:
-          yield result.end_indexes[j_index], result.end_logits[j_index]
+        yield (result.start_indexes[i], result.start_logits[i],
+               result.end_indexes[j_index], result.end_logits[j_index])
   else:
-    if start:
-      logits = result.start_logits
-    else:
-      logits = result.end_logits
-    index_and_score = sorted(enumerate(logits),
-                             key=lambda x: x[1], reverse=True)
-    for i in range(len(index_and_score)):
-      if i >= n_best_size:
-        break
-      yield index_and_score[i]
+    start_index_and_score = sorted(enumerate(result.start_logits),
+                                   key=lambda x: x[1], reverse=True)
+    end_index_and_score = sorted(enumerate(result.end_logits),
+                                 key=lambda x: x[1], reverse=True)
+    for i in range(len(start_index_and_score)):
+      if i >= n_best_size:
+        break
+      for j in range(len(end_index_and_score)):
+        if j >= n_best_size:
+          break
+        yield (start_index_and_score[i][0], start_index_and_score[i][1],
+               end_index_and_score[j][0], end_index_and_score[j][1])


 def _compute_softmax(scores):
...
...
@@ -885,13 +883,12 @@ def _compute_softmax(scores):
 class FeatureWriter(object):
   """Writes InputFeature to TF example file."""

-  def __init__(self, filename, is_training, xlnet_format=False):
+  def __init__(self, filename, is_training):
     self.filename = filename
     self.is_training = is_training
     self.num_features = 0
     tf.io.gfile.makedirs(os.path.dirname(filename))
     self._writer = tf.io.TFRecordWriter(filename)
-    self._xlnet_format = xlnet_format

   def process_feature(self, feature):
     """Write a InputFeature to the TFRecordWriter as a tf.train.Example."""
...
...
@@ -907,8 +904,9 @@ class FeatureWriter(object):
     features["input_ids"] = create_int_feature(feature.input_ids)
     features["input_mask"] = create_int_feature(feature.input_mask)
     features["segment_ids"] = create_int_feature(feature.segment_ids)
-    if self._xlnet_format:
+    if feature.paragraph_mask:
       features["paragraph_mask"] = create_int_feature(feature.paragraph_mask)
+    if feature.class_index:
       features["class_index"] = create_int_feature([feature.class_index])

     if self.is_training:
...
...
@@ -943,7 +941,7 @@ def generate_tf_record_from_json_file(input_file_path,
   tokenizer = tokenization.FullSentencePieceTokenizer(
       sp_model_file=sp_model_file)
-  train_writer = FeatureWriter(filename=output_path, is_training=True,
-                               xlnet_format=xlnet_format)
+  train_writer = FeatureWriter(filename=output_path, is_training=True)
   number_of_examples = convert_examples_to_features(
       examples=train_examples,
       tokenizer=tokenizer,
...
...
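Note on squad_lib_sp.py: after this change, _get_best_indexes_and_logits yields joint (start, end) 4-tuples instead of separate start/end pairs, so the caller needs only one flat loop. A minimal standalone sketch of the new contract; the FakeResult stand-in and its logit values are invented for illustration:

import collections

# Hypothetical stand-in for the SQuAD result object; only the fields the
# generator reads in the non-XLNet branch are included.
FakeResult = collections.namedtuple('FakeResult', ['start_logits', 'end_logits'])

result = FakeResult(start_logits=[0.1, 2.3, 0.7], end_logits=[1.5, 0.2, 0.9])

def _get_best_indexes_and_logits(result, n_best_size, xlnet_format=False):
  # Non-XLNet branch only, mirroring the new joint-yield structure.
  start = sorted(enumerate(result.start_logits), key=lambda x: x[1], reverse=True)
  end = sorted(enumerate(result.end_logits), key=lambda x: x[1], reverse=True)
  for i in range(min(n_best_size, len(start))):
    for j in range(min(n_best_size, len(end))):
      yield (start[i][0], start[i][1], end[j][0], end[j][1])

# One flat loop replaces the old nested start/end loops.
for start_index, start_logit, end_index, end_logit in (
    _get_best_indexes_and_logits(result, n_best_size=2)):
  print(start_index, start_logit, end_index, end_logit)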
official/nlp/modeling/models/xlnet.py
...
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""XLNet cls-token classifier."""
+"""XLNet models."""
 # pylint: disable=g-classes-have-attributes
 from typing import Any, Mapping, Union
...
...
@@ -127,7 +127,7 @@ class XLNetSpanLabeler(tf.keras.Model):
     start_n_top: Beam size for span start.
     end_n_top: Beam size for span end.
     dropout_rate: The dropout rate for the span labeling layer.
     span_labeling_activation: The activation for the span labeling head.
     initializer: The initializer (if any) to use in the span labeling network.
       Defaults to a Glorot uniform initializer.
   """
...
...
@@ -135,9 +135,9 @@ class XLNetSpanLabeler(tf.keras.Model):
   def __init__(self,
                network: Union[tf.keras.layers.Layer, tf.keras.Model],
-               start_n_top: int,
-               end_n_top: int,
-               dropout_rate: float,
+               start_n_top: int = 5,
+               end_n_top: int = 5,
+               dropout_rate: float = 0.1,
                span_labeling_activation: tf.keras.initializers.Initializer = 'tanh',
                initializer: tf.keras.initializers.Initializer = 'glorot_uniform',
                **kwargs):
...
...
@@ -165,24 +165,27 @@ class XLNetSpanLabeler(tf.keras.Model):
         initializer=self._initializer)

   def call(self, inputs: Mapping[str, Any]):
-    input_ids = inputs['input_ids']
-    segment_ids = inputs['segment_ids']
+    input_ids = inputs['input_word_ids']
+    segment_ids = inputs['input_type_ids']
     input_mask = inputs['input_mask']
-    class_index = inputs['class_index']
-    position_mask = inputs['position_mask']
-    start_positions = inputs['start_positions']
-    attention_output, new_states = self._network(
+    paragraph_mask = inputs['paragraph_mask']
+    start_positions = inputs.get('start_positions', None)
+    class_index = tf.reshape(inputs['class_index'], [-1])
+    attention_output, _ = self._network(
         input_ids=input_ids,
         segment_ids=segment_ids,
         input_mask=input_mask)
     outputs = self.span_labeling(
         sequence_data=attention_output,
         class_index=class_index,
-        position_mask=position_mask,
+        paragraph_mask=paragraph_mask,
         start_positions=start_positions)
-    return outputs, new_states
+    return outputs

   @property
   def checkpoint_items(self):
     return dict(encoder=self._network)

   def get_config(self):
     return self._config
...
...
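Note on xlnet.py: the model now consumes the standardized BERT-style input keys and returns a single outputs dict instead of an (outputs, states) tuple. A sketch of the new call contract from the caller's side; shapes and values below are invented, and the final call is left commented since it needs a constructed XLNetSpanLabeler:

import numpy as np

batch_size, seq_length = 2, 8

inputs = dict(
    input_word_ids=np.zeros((batch_size, seq_length), dtype='int32'),
    input_type_ids=np.zeros((batch_size, seq_length), dtype='int32'),
    input_mask=np.ones((batch_size, seq_length), dtype='float32'),
    paragraph_mask=np.ones((batch_size, seq_length), dtype='float32'),
    class_index=np.zeros((batch_size,), dtype='int32'),
    # start_positions is now optional; call() falls back to beam search
    # when it is absent.
    start_positions=np.zeros((batch_size,), dtype='int32'))

# outputs = span_labeler(inputs)  # previously: outputs, new_states = ...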
official/nlp/modeling/models/xlnet_test.py
...
@@ -137,9 +137,9 @@ class XLNetClassifierTest(keras_parameterized.TestCase):
 @keras_parameterized.run_all_keras_modes
 class XLNetSpanLabelerTest(keras_parameterized.TestCase):

-  @parameterized.parameters(1, 2)
-  def test_xlnet_trainer(self, top_n):
+  def test_xlnet_trainer(self):
     """Validate that the Keras object can be created."""
+    top_n = 2
     seq_length = 4
     # Build a simple XLNet based network to use with the XLNet trainer.
     xlnet_base = _get_xlnet_base()
...
...
@@ -153,46 +153,50 @@ class XLNetSpanLabelerTest(keras_parameterized.TestCase):
         span_labeling_activation='tanh',
         dropout_rate=0.1)
     inputs = dict(
-        input_ids=tf.keras.layers.Input(
+        input_word_ids=tf.keras.layers.Input(
             shape=(seq_length,), dtype=tf.int32, name='input_word_ids'),
-        segment_ids=tf.keras.layers.Input(
-            shape=(seq_length,), dtype=tf.int32, name='segment_ids'),
+        input_type_ids=tf.keras.layers.Input(
+            shape=(seq_length,), dtype=tf.int32, name='input_type_ids'),
         input_mask=tf.keras.layers.Input(
             shape=(seq_length,), dtype=tf.float32, name='input_mask'),
-        position_mask=tf.keras.layers.Input(
-            shape=(seq_length,), dtype=tf.float32, name='position_mask'),
+        paragraph_mask=tf.keras.layers.Input(
+            shape=(seq_length,), dtype=tf.float32, name='paragraph_mask'),
         class_index=tf.keras.layers.Input(
             shape=(), dtype=tf.int32, name='class_index'),
         start_positions=tf.keras.layers.Input(
             shape=(), dtype=tf.int32, name='start_positions'))
-    outputs, _ = xlnet_trainer_model(inputs)
+    outputs = xlnet_trainer_model(inputs)
+    self.assertIsInstance(outputs, dict)

     # Test tensor value calls for the created model.
     batch_size = 2
     sequence_shape = (batch_size, seq_length)
     inputs = dict(
-        input_ids=np.random.randint(10, size=sequence_shape, dtype='int32'),
-        segment_ids=np.random.randint(2, size=sequence_shape, dtype='int32'),
+        input_word_ids=np.random.randint(10, size=sequence_shape, dtype='int32'),
+        input_type_ids=np.random.randint(2, size=sequence_shape, dtype='int32'),
         input_mask=np.random.randint(2, size=sequence_shape).astype('float32'),
-        position_mask=np.random.randint(
+        paragraph_mask=np.random.randint(
            1, size=(sequence_shape)).astype('float32'),
         class_index=np.random.randint(1, size=(batch_size)).astype('uint8'),
         start_positions=tf.random.uniform(
             shape=(batch_size,), maxval=5, dtype=tf.int32))
-    outputs, _ = xlnet_trainer_model(inputs)
-    expected_inference_keys = {
-        'start_top_log_probs', 'end_top_log_probs', 'class_logits',
-        'start_top_index', 'end_top_index',
-    }
-    self.assertSetEqual(expected_inference_keys, set(outputs.keys()))
+    common_keys = {
+        'start_logits', 'end_logits', 'start_predictions', 'end_predictions',
+        'class_logits',
+    }
+    inference_keys = {
+        'start_top_predictions', 'end_top_predictions',
+        'start_top_index', 'end_top_index',
+    }
+    outputs = xlnet_trainer_model(inputs)
+    self.assertSetEqual(common_keys | inference_keys, set(outputs.keys()))

-    outputs, _ = xlnet_trainer_model(inputs, training=True)
-    expected_train_keys = {'start_log_probs', 'end_log_probs', 'class_logits'}
-    self.assertSetEqual(expected_train_keys, set(outputs.keys()))
+    outputs = xlnet_trainer_model(inputs, training=True)
     self.assertIsInstance(outputs, dict)
+    self.assertSetEqual(common_keys, set(outputs.keys()))

   def test_serialize_deserialize(self):
...
...
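Note on xlnet_test.py: the beam-search outputs checked above are flattened per example: each of the start_n_top start candidates pairs with end_n_top end candidates, so end_top_predictions and end_top_index have shape [batch_size, start_n_top * end_n_top]. A small sketch of recovering the (i, j) beam pair from a flat position, assuming the j_index = i * n_best_size + j convention used in squad_lib_sp.py:

start_n_top, end_n_top = 5, 5

def beam_pair(flat_index):
  # Invert j_index = i * end_n_top + j.
  return flat_index // end_n_top, flat_index % end_n_top

assert beam_pair(0) == (0, 0)
assert beam_pair(7) == (1, 2)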
official/nlp/modeling/networks/span_labeling.py
...
@@ -18,11 +18,9 @@ import collections
 import tensorflow as tf


-def _apply_position_mask(logits, position_mask):
+def _apply_paragraph_mask(logits, paragraph_mask):
   """Applies a position mask to calculated logits."""
-  if tf.rank(logits) != tf.rank(position_mask):
-    position_mask = position_mask[:, None, :]
-  masked_logits = logits * (1 - position_mask) - 1e30 * position_mask
+  masked_logits = logits * (paragraph_mask) - 1e30 * (1 - paragraph_mask)
   return tf.nn.log_softmax(masked_logits, -1), masked_logits
...
...
@@ -137,8 +135,8 @@ class XLNetSpanLabeling(tf.keras.layers.Layer):
   def __init__(self,
                input_width,
-               start_n_top,
-               end_n_top,
+               start_n_top=5,
+               end_n_top=5,
                activation='tanh',
                dropout_rate=0.,
                initializer='glorot_uniform',
...
...
@@ -152,6 +150,8 @@ class XLNetSpanLabeling(tf.keras.layers.Layer):
         'end_n_top': end_n_top,
         'dropout_rate': dropout_rate,
     }
+    if start_n_top <= 1:
+      raise ValueError('`start_n_top` must be greater than 1.')
     self._start_n_top = start_n_top
     self._end_n_top = end_n_top
     self.start_logits_dense = tf.keras.layers.Dense(
...
...
@@ -210,16 +210,12 @@ class XLNetSpanLabeling(tf.keras.layers.Layer):
     end_logits = self.end_logits_layer_norm(end_logits)
     end_logits = self.end_logits_output_dense(end_logits)
     end_logits = tf.squeeze(end_logits)
-    if tf.rank(end_logits) > 2:
-      # shape = [B, S, K] -> [B, K, S]
-      end_logits = tf.transpose(end_logits, [0, 2, 1])
     return end_logits

   def call(self,
            sequence_data,
            class_index,
-           position_mask=None,
+           paragraph_mask=None,
            start_positions=None,
            training=False):
     """Implements call().
...
...
@@ -234,31 +230,35 @@ class XLNetSpanLabeling(tf.keras.layers.Layer):
       sequence_data: The input sequence data of shape
         (batch_size, seq_length, input_width).
       class_index: The class indices of the inputs of shape (batch_size,).
-      position_mask: Invalid position mask such as query and special symbols
+      paragraph_mask: Invalid position mask such as query and special symbols
         (e.g. PAD, SEP, CLS) of shape (batch_size,).
       start_positions: The start positions of each example of shape
        (batch_size,).
       training: Whether or not this is the training phase.

     Returns:
-      A dictionary with the keys 'cls_logits' and
-      - (if training) 'start_log_probs', 'end_log_probs'.
-      - (if inference/beam search) 'start_top_log_probs', 'start_top_index',
-        'end_top_log_probs', 'end_top_index'.
+      A dictionary with the keys 'start_predictions', 'end_predictions',
+      'start_logits', 'end_logits'.
+      If inference, then 'start_top_predictions', 'start_top_index',
+      'end_top_predictions', 'end_top_index' are also included.
     """
+    paragraph_mask = tf.cast(paragraph_mask, dtype=sequence_data.dtype)
+    class_index = tf.reshape(class_index, [-1])
     seq_length = tf.shape(sequence_data)[1]
     start_logits = self.start_logits_dense(sequence_data)
     start_logits = tf.squeeze(start_logits, -1)
-    start_log_probs, masked_start_logits = _apply_position_mask(
-        start_logits, position_mask)
+    start_predictions, masked_start_logits = _apply_paragraph_mask(
+        start_logits, paragraph_mask)

-    if not training:
+    compute_with_beam_search = not training or start_positions is None
+    if compute_with_beam_search:
       # Compute end logits using beam search.
-      start_top_log_probs, start_top_index = tf.nn.top_k(
-          start_log_probs, k=self._start_n_top)
+      start_top_predictions, start_top_index = tf.nn.top_k(
+          start_predictions, k=self._start_n_top)
       start_index = tf.one_hot(
           start_top_index, depth=seq_length, axis=-1, dtype=tf.float32)
       # start_index: [batch_size, end_n_top, seq_length]
...
...
@@ -272,8 +272,13 @@ class XLNetSpanLabeling(tf.keras.layers.Layer):
                                [1, 1, self._start_n_top, 1])
       end_input = tf.concat([end_input, start_features], axis=-1)
       # end_input: [batch_size, seq_length, end_n_top, 2*input_width]
+      paragraph_mask = paragraph_mask[:, None, :]
       end_logits = self.end_logits(end_input)
+      # Note: this will fail if start_n_top is not >= 1.
+      end_logits = tf.transpose(end_logits, [0, 2, 1])
     else:
-      start_positions = tf.reshape(start_positions, -1)
+      start_positions = tf.reshape(start_positions, [-1])
       start_index = tf.one_hot(
           start_positions, depth=seq_length, axis=-1, dtype=tf.float32)
       # start_index: [batch_size, seq_length]
...
...
@@ -285,24 +290,28 @@ class XLNetSpanLabeling(tf.keras.layers.Layer):
       end_input = tf.concat([sequence_data, start_features], axis=-1)
       # end_input: [batch_size, seq_length, 2*input_width]
       end_logits = self.end_logits(end_input)
-      end_log_probs, _ = _apply_position_mask(end_logits, position_mask)
+    end_predictions, masked_end_logits = _apply_paragraph_mask(
+        end_logits, paragraph_mask)

-    output_dict = {}
-    if training:
-      output_dict['start_log_probs'] = start_log_probs
-      output_dict['end_log_probs'] = end_log_probs
-    else:
-      end_top_log_probs, end_top_index = tf.nn.top_k(
-          end_log_probs, k=self._end_n_top)
-      end_top_log_probs = tf.reshape(
-          end_top_log_probs,
-          [-1, self._start_n_top * self._end_n_top])
-      end_top_index = tf.reshape(
-          end_top_index,
-          [-1, self._start_n_top * self._end_n_top])
-      output_dict['start_top_log_probs'] = start_top_log_probs
-      output_dict['start_top_index'] = start_top_index
-      output_dict['end_top_log_probs'] = end_top_log_probs
-      output_dict['end_top_index'] = end_top_index
+    output_dict = dict(
+        start_predictions=start_predictions,
+        end_predictions=end_predictions,
+        start_logits=masked_start_logits,
+        end_logits=masked_end_logits)

+    if not training:
+      end_top_predictions, end_top_index = tf.nn.top_k(
+          end_predictions, k=self._end_n_top)
+      end_top_predictions = tf.reshape(
+          end_top_predictions,
+          [-1, self._start_n_top * self._end_n_top])
+      end_top_index = tf.reshape(
+          end_top_index,
+          [-1, self._start_n_top * self._end_n_top])
+      output_dict['start_top_predictions'] = start_top_predictions
+      output_dict['start_top_index'] = start_top_index
+      output_dict['end_top_predictions'] = end_top_predictions
+      output_dict['end_top_index'] = end_top_index

     # get the representation of CLS
...
...
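Note on span_labeling.py: the sign convention flipped with the rename. paragraph_mask is 1 for positions that may contain the answer and 0 for invalid ones (query, padding, special tokens), whereas the old position_mask marked invalid positions with 1. A small numeric sketch of the new helper, with made-up logits:

import tensorflow as tf

def _apply_paragraph_mask(logits, paragraph_mask):
  # Keep logits where the mask is 1; push masked positions toward -inf
  # so softmax assigns them ~zero probability.
  masked_logits = logits * paragraph_mask - 1e30 * (1 - paragraph_mask)
  return tf.nn.log_softmax(masked_logits, -1), masked_logits

logits = tf.constant([[2.0, 1.0, 3.0, 0.5]])
mask = tf.constant([[0.0, 1.0, 1.0, 0.0]])  # only positions 1 and 2 are valid
log_probs, masked = _apply_paragraph_mask(logits, mask)
# log_probs puts essentially all probability mass on positions 1 and 2.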
official/nlp/modeling/networks/span_labeling_test.py
...
@@ -13,13 +13,6 @@
# limitations under the License.
# ==============================================================================
"""Tests for span_labeling network."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
absl.testing
import
parameterized
import
numpy
as
np
import
tensorflow
as
tf
...
...
@@ -181,39 +174,38 @@ class XLNetSpanLabelingTest(keras_parameterized.TestCase):
     hidden_size = 4
     sequence_data = np.random.uniform(
         size=(batch_size, seq_length, hidden_size)).astype('float32')
-    position_mask = np.random.uniform(
+    paragraph_mask = np.random.uniform(
         size=(batch_size, seq_length)).astype('float32')
     class_index = np.random.uniform(size=(batch_size)).astype('uint8')
     start_positions = np.zeros(shape=(batch_size)).astype('uint8')
     layer = span_labeling.XLNetSpanLabeling(
         input_width=hidden_size,
-        start_n_top=1,
-        end_n_top=1,
+        start_n_top=2,
+        end_n_top=2,
         activation='tanh',
         dropout_rate=0.,
         initializer='glorot_uniform')
     output = layer(sequence_data=sequence_data,
                    class_index=class_index,
-                   position_mask=position_mask,
+                   paragraph_mask=paragraph_mask,
                    start_positions=start_positions,
                    training=True)
     expected_keys = {
-        'start_log_probs', 'end_log_probs', 'class_logits',
+        'start_logits', 'end_logits', 'class_logits',
+        'start_predictions', 'end_predictions',
     }
     self.assertSetEqual(expected_keys, set(output.keys()))

-  @parameterized.named_parameters(('top_1', 1), ('top_n', 5))
-  def test_basic_invocation_beam_search(self, top_n):
+  def test_basic_invocation_beam_search(self):
     batch_size = 2
     seq_length = 8
     hidden_size = 4
+    top_n = 5
     sequence_data = np.random.uniform(
         size=(batch_size, seq_length, hidden_size)).astype('float32')
-    position_mask = np.random.uniform(
+    paragraph_mask = np.random.uniform(
         size=(batch_size, seq_length)).astype('float32')
     class_index = np.random.uniform(size=(batch_size)).astype('uint8')
...
...
@@ -226,11 +218,12 @@ class XLNetSpanLabelingTest(keras_parameterized.TestCase):
         initializer='glorot_uniform')
     output = layer(sequence_data=sequence_data,
                    class_index=class_index,
-                   position_mask=position_mask,
+                   paragraph_mask=paragraph_mask,
                    training=False)
     expected_keys = {
-        'start_top_log_probs', 'end_top_log_probs', 'class_logits',
-        'start_top_index', 'end_top_index',
+        'start_top_predictions', 'end_top_predictions', 'class_logits',
+        'start_top_index', 'end_top_index', 'start_logits', 'end_logits',
+        'start_predictions', 'end_predictions',
     }
     self.assertSetEqual(expected_keys, set(output.keys()))
...
...
@@ -243,7 +236,7 @@ class XLNetSpanLabelingTest(keras_parameterized.TestCase):
     sequence_data = tf.keras.Input(shape=(seq_length, hidden_size),
                                    dtype=tf.float32)
     class_index = tf.keras.Input(shape=(), dtype=tf.uint8)
-    position_mask = tf.keras.Input(shape=(seq_length), dtype=tf.float32)
+    paragraph_mask = tf.keras.Input(shape=(seq_length), dtype=tf.float32)
     start_positions = tf.keras.Input(shape=(), dtype=tf.int32)
     layer = span_labeling.XLNetSpanLabeling(
...
...
@@ -256,27 +249,27 @@ class XLNetSpanLabelingTest(keras_parameterized.TestCase):
     output = layer(sequence_data=sequence_data,
                    class_index=class_index,
-                   position_mask=position_mask,
+                   paragraph_mask=paragraph_mask,
                    start_positions=start_positions)
     model = tf.keras.Model(
         inputs={
             'sequence_data': sequence_data,
             'class_index': class_index,
-            'position_mask': position_mask,
+            'paragraph_mask': paragraph_mask,
             'start_positions': start_positions,
         },
         outputs=output)

     sequence_data = tf.random.uniform(
         shape=(batch_size, seq_length, hidden_size), dtype=tf.float32)
-    position_mask = tf.random.uniform(
+    paragraph_mask = tf.random.uniform(
         shape=(batch_size, seq_length), dtype=tf.float32)
     class_index = tf.ones(shape=(batch_size,), dtype=tf.uint8)
     start_positions = tf.random.uniform(
         shape=(batch_size,), maxval=5, dtype=tf.int32)
     inputs = dict(sequence_data=sequence_data,
-                  position_mask=position_mask,
+                  paragraph_mask=paragraph_mask,
                   class_index=class_index,
                   start_positions=start_positions)
...
...
official/nlp/modeling/networks/xlnet_base.py
...
@@ -629,6 +629,7 @@ class XLNetBase(tf.keras.layers.Layer):
                        "enabled. Please enable `two_stream` to enable two "
                        "stream attention.")

+    dtype = input_mask.dtype if input_mask is not None else tf.float32
     query_attention_mask, content_attention_mask = _compute_attention_mask(
         input_mask=input_mask,
         permutation_mask=permutation_mask,
...
...
@@ -636,7 +637,7 @@ class XLNetBase(tf.keras.layers.Layer):
         seq_length=seq_length,
         memory_length=memory_length,
         batch_size=batch_size,
-        dtype=tf.float32)
+        dtype=dtype)
     relative_position_encoding = _compute_positional_encoding(
         attention_type=self._attention_type,
         position_encoding_layer=self.position_encoding,
...
...
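Note on xlnet_base.py: the attention masks now inherit the dtype of input_mask instead of being hard-coded to float32, which matters when the model runs under mixed precision. A toy illustration of the pattern, with invented values:

import tensorflow as tf

input_mask = tf.ones((2, 8), dtype=tf.float16)  # e.g. under mixed precision
# Fall back to float32 only when no input mask is provided.
dtype = input_mask.dtype if input_mask is not None else tf.float32
assert dtype == tf.float16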
official/nlp/tasks/question_answering.py
...
@@ -14,9 +14,9 @@
 # limitations under the License.
 # ==============================================================================
 """Question answering task."""
-import collections
 import json
 import os
+from typing import List, Optional

 from absl import logging
 import dataclasses
...
...
@@ -58,6 +58,17 @@ class QuestionAnsweringConfig(cfg.TaskConfig):
   validation_data: cfg.DataConfig = cfg.DataConfig()


+@dataclasses.dataclass
+class RawAggregatedResult:
+  """Raw representation for SQuAD predictions."""
+  unique_id: int
+  start_logits: List[float]
+  end_logits: List[float]
+  start_indexes: Optional[List[int]] = None
+  end_indexes: Optional[List[int]] = None
+  class_logits: Optional[float] = None
+
+
 @task_factory.register_task_cls(QuestionAnsweringConfig)
 class QuestionAnsweringTask(base_task.Task):
   """Task object for question answering."""
...
...
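Note on question_answering.py: the new RawAggregatedResult dataclass replaces the module's ad-hoc namedtuple. A quick sketch of the two ways it gets filled, assuming the dataclass exactly as added above; all values are invented. BERT-style results leave the optional fields at their defaults, while the XLNet path also records the top-k indexes and the impossibility logit:

# BERT-style result: only id and logits.
bert_result = RawAggregatedResult(
    unique_id=1000001,
    start_logits=[0.1, 2.3, 0.7],
    end_logits=[1.5, 0.2, 0.9])

# XLNet-style result: top-k predictions plus their indexes and a class logit.
xlnet_result = RawAggregatedResult(
    unique_id=1000002,
    start_logits=[2.3, 0.7],
    end_logits=[1.5, 0.9],
    start_indexes=[1, 2],
    end_indexes=[0, 2],
    class_logits=-0.4)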
@@ -91,7 +102,6 @@ class QuestionAnsweringTask(base_task.Task):
     else:
       encoder_network = encoders.build_encoder(self.task_config.model.encoder)
     encoder_cfg = self.task_config.model.encoder.get()
+    # Currently, we only supports bert-style question answering finetuning.
     return models.BertSpanLabeler(
         network=encoder_network,
         initializer=tf.keras.initializers.TruncatedNormal(
...
...
@@ -147,6 +157,7 @@ class QuestionAnsweringTask(base_task.Task):
       kwargs['do_lower_case'] = params.do_lower_case
       kwargs['tokenizer'] = tokenization.FullSentencePieceTokenizer(
           sp_model_file=params.vocab_file)
+      kwargs['xlnet_format'] = self.task_config.model.encoder.type == 'xlnet'
     elif params.tokenization == 'WordPiece':
       kwargs['tokenizer'] = tokenization.FullTokenizer(
           vocab_file=params.vocab_file, do_lower_case=params.do_lower_case)
...
...
@@ -176,7 +187,8 @@ class QuestionAnsweringTask(base_task.Task):
           input_type_ids=dummy_ids)
       y = dict(start_positions=tf.constant(0, dtype=tf.int32),
-               end_positions=tf.constant(1, dtype=tf.int32))
+               end_positions=tf.constant(1, dtype=tf.int32),
+               is_impossible=tf.constant(0, dtype=tf.int32))
       return (x, y)

     dataset = tf.data.Dataset.range(1)
...
...
@@ -235,25 +247,22 @@ class QuestionAnsweringTask(base_task.Task):
     }
     return logs

-  raw_aggregated_result = collections.namedtuple(
-      'RawResult', ['unique_id', 'start_logits', 'end_logits'])
-
   def aggregate_logs(self, state=None, step_outputs=None):
     assert step_outputs is not None, 'Got no logs from self.validation_step.'
     if state is None:
       state = []

-    for unique_ids, start_logits, end_logits in zip(
-        step_outputs['unique_ids'],
-        step_outputs['start_logits'],
-        step_outputs['end_logits']):
-      u_ids, s_logits, e_logits = (
-          unique_ids.numpy(), start_logits.numpy(), end_logits.numpy())
-      for values in zip(u_ids, s_logits, e_logits):
-        state.append(self.raw_aggregated_result(
-            unique_id=values[0],
-            start_logits=values[1].tolist(),
-            end_logits=values[2].tolist()))
+    for outputs in zip(step_outputs['unique_ids'],
+                       step_outputs['start_logits'],
+                       step_outputs['end_logits']):
+      numpy_values = [
+          output.numpy() for output in outputs if output is not None]
+      for values in zip(*numpy_values):
+        state.append(RawAggregatedResult(
+            unique_id=values[0],
+            start_logits=values[1],
+            end_logits=values[2]))
     return state

   def reduce_aggregated_logs(self, aggregated_logs):
...
...
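Note on aggregate_logs: the new loop first converts each per-step batch tensor to NumPy, then uses zip(*...) to transpose a batch of fields into per-example tuples. A standalone sketch of the same pattern on fake step outputs:

import numpy as np

# Per-step outputs: one array per field, batched over examples.
step = {'unique_ids': np.array([7, 8]),
        'start_logits': np.array([[0.1, 0.2], [0.3, 0.4]]),
        'end_logits': np.array([[0.5, 0.6], [0.7, 0.8]])}

numpy_values = [step['unique_ids'], step['start_logits'], step['end_logits']]
# zip(*...) regroups the batched fields into per-example tuples.
for unique_id, start_logits, end_logits in zip(*numpy_values):
    print(unique_id, start_logits, end_logits)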
@@ -299,6 +308,127 @@ class QuestionAnsweringTask(base_task.Task):
     return eval_metrics


+@dataclasses.dataclass
+class XLNetQuestionAnsweringConfig(QuestionAnsweringConfig):
+  """The config for the XLNet variation of QuestionAnswering."""
+  pass
+
+
+@task_factory.register_task_cls(XLNetQuestionAnsweringConfig)
+class XLNetQuestionAnsweringTask(QuestionAnsweringTask):
+  """XLNet variant of the Question Answering Task.
+
+  The main differences include:
+    - The encoder is an `XLNetBase` class.
+    - The `SpanLabeling` head is an instance of `XLNetSpanLabeling` which
+      predicts start/end positions and impossibility score. During inference,
+      it predicts the top N scores and indexes.
+  """
+
+  def build_model(self):
+    if self.task_config.hub_module_url and self.task_config.init_checkpoint:
+      raise ValueError('At most one of `hub_module_url` and '
+                       '`init_checkpoint` can be specified.')
+    if self.task_config.hub_module_url:
+      encoder_network = utils.get_encoder_from_hub(
+          self.task_config.hub_module_url)
+    else:
+      encoder_network = encoders.build_encoder(self.task_config.model.encoder)
+    encoder_cfg = self.task_config.model.encoder.get()
+    return models.XLNetSpanLabeler(
+        network=encoder_network,
+        start_n_top=self.task_config.n_best_size,
+        end_n_top=self.task_config.n_best_size,
+        initializer=tf.keras.initializers.RandomNormal(
+            stddev=encoder_cfg.initializer_range))
+
+  def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
+    start_positions = labels['start_positions']
+    end_positions = labels['end_positions']
+    is_impossible = labels['is_impossible']
+    is_impossible = tf.cast(tf.reshape(is_impossible, [-1]), tf.float32)
+
+    start_logits = model_outputs['start_logits']
+    end_logits = model_outputs['end_logits']
+    class_logits = model_outputs['class_logits']
+
+    start_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
+        start_positions, start_logits)
+    end_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
+        end_positions, end_logits)
+    is_impossible_loss = tf.keras.losses.binary_crossentropy(
+        is_impossible, class_logits, from_logits=True)
+
+    loss = (tf.reduce_mean(start_loss) + tf.reduce_mean(end_loss)) / 2
+    loss += tf.reduce_mean(is_impossible_loss) / 2
+    return loss
+
+  def process_metrics(self, metrics, labels, model_outputs):
+    metrics = dict([(metric.name, metric) for metric in metrics])
+    start_logits = model_outputs['start_logits']
+    end_logits = model_outputs['end_logits']
+    metrics['start_position_accuracy'].update_state(
+        labels['start_positions'], start_logits)
+    metrics['end_position_accuracy'].update_state(
+        labels['end_positions'], end_logits)
+
+  def process_compiled_metrics(self, compiled_metrics, labels, model_outputs):
+    start_logits = model_outputs['start_logits']
+    end_logits = model_outputs['end_logits']
+    compiled_metrics.update_state(
+        y_true=labels,  # labels has keys 'start_positions' and 'end_positions'.
+        y_pred={
+            'start_positions': start_logits,
+            'end_positions': end_logits,
+        })
+
+  def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
+    features, _ = inputs
+    unique_ids = features.pop('unique_ids')
+    model_outputs = self.inference_step(features, model)
+    start_top_predictions = model_outputs['start_top_predictions']
+    end_top_predictions = model_outputs['end_top_predictions']
+    start_indexes = model_outputs['start_top_index']
+    end_indexes = model_outputs['end_top_index']
+    class_logits = model_outputs['class_logits']
+    logs = {
+        self.loss: 0.0,  # TODO(lehou): compute the real validation loss.
+        'unique_ids': unique_ids,
+        'start_top_predictions': start_top_predictions,
+        'end_top_predictions': end_top_predictions,
+        'start_indexes': start_indexes,
+        'end_indexes': end_indexes,
+        'class_logits': class_logits,
+    }
+    return logs
+
+  def aggregate_logs(self, state=None, step_outputs=None):
+    assert step_outputs is not None, 'Got no logs from self.validation_step.'
+    if state is None:
+      state = []
+
+    for outputs in zip(step_outputs['unique_ids'],
+                       step_outputs['start_top_predictions'],
+                       step_outputs['end_top_predictions'],
+                       step_outputs['start_indexes'],
+                       step_outputs['end_indexes'],
+                       step_outputs['class_logits']):
+      numpy_values = [output.numpy() for output in outputs]
+      for (unique_id, start_top_predictions, end_top_predictions,
+           start_indexes, end_indexes, class_logits) in zip(*numpy_values):
+        state.append(RawAggregatedResult(
+            unique_id=unique_id,
+            start_logits=start_top_predictions.tolist(),
+            end_logits=end_top_predictions.tolist(),
+            start_indexes=start_indexes.tolist(),
+            end_indexes=end_indexes.tolist(),
+            class_logits=class_logits))
+    return state
+
+
 def predict(task: QuestionAnsweringTask, params: cfg.DataConfig,
             model: tf.keras.Model):
   """Predicts on the input data.
...
...
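Note on the added build_losses: it averages the start- and end-position cross-entropies, then adds half of the impossibility binary cross-entropy, so span loss and impossibility loss each contribute a factor of 1/2. A standalone sketch of the same arithmetic on fake tensors (shapes and values invented):

import tensorflow as tf

batch_size, seq_length = 2, 8
start_positions = tf.constant([1, 3])
end_positions = tf.constant([2, 5])
is_impossible = tf.constant([0.0, 1.0])
start_logits = tf.random.normal((batch_size, seq_length))
end_logits = tf.random.normal((batch_size, seq_length))
class_logits = tf.random.normal((batch_size,))

start_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
    start_positions, start_logits)
end_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
    end_positions, end_logits)
impossible_loss = tf.keras.losses.binary_crossentropy(
    is_impossible, class_logits, from_logits=True)

# Span loss and impossibility loss each contribute half.
loss = (tf.reduce_mean(start_loss) + tf.reduce_mean(end_loss)) / 2
loss += tf.reduce_mean(impossible_loss) / 2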