chenpangpang / transformers / Commits

Commit 8163baab
Convert indentation from 2 spaces to 4 spaces
Authored Nov 01, 2018 by Tim Rault
Parent: 555b7d66

Showing 11 changed files with 3917 additions and 3918 deletions. The commit reindents the code from 2-space to 4-space indentation, so each hunk below is shown once, in its new 4-space form.
Changed files:

  create_pretraining_data.py   +312  -312
  extract_features.py          +289  -289
  modeling.py                  +840  -840
  modeling_test.py             +243  -244
  optimization.py              +131  -131
  optimization_test.py          +21   -21
  run_classifier.py            +527  -527
  run_pretraining.py           +354  -354
  run_squad.py                 +900  -900
  tokenization.py              +227  -227
  tokenization_test.py          +73   -73
create_pretraining_data.py
@@ -63,379 +63,379 @@ flags.DEFINE_float(

class TrainingInstance(object):
    """A single training instance (sentence pair)."""

    def __init__(self, tokens, segment_ids, masked_lm_positions,
                 masked_lm_labels, is_random_next):
        self.tokens = tokens
        self.segment_ids = segment_ids
        self.is_random_next = is_random_next
        self.masked_lm_positions = masked_lm_positions
        self.masked_lm_labels = masked_lm_labels

    def __str__(self):
        s = ""
        s += "tokens: %s\n" % (" ".join(
            [tokenization.printable_text(x) for x in self.tokens]))
        s += "segment_ids: %s\n" % (" ".join([str(x) for x in self.segment_ids]))
        s += "is_random_next: %s\n" % self.is_random_next
        s += "masked_lm_positions: %s\n" % (" ".join(
            [str(x) for x in self.masked_lm_positions]))
        s += "masked_lm_labels: %s\n" % (" ".join(
            [tokenization.printable_text(x) for x in self.masked_lm_labels]))
        s += "\n"
        return s

    def __repr__(self):
        return self.__str__()


def write_instance_to_example_files(instances, tokenizer, max_seq_length,
                                    max_predictions_per_seq, output_files):
    """Create TF example files from `TrainingInstance`s."""
    writers = []
    for output_file in output_files:
        writers.append(tf.python_io.TFRecordWriter(output_file))

    writer_index = 0
    total_written = 0
    for (inst_index, instance) in enumerate(instances):
        input_ids = tokenizer.convert_tokens_to_ids(instance.tokens)
        input_mask = [1] * len(input_ids)
        segment_ids = list(instance.segment_ids)
        assert len(input_ids) <= max_seq_length

        while len(input_ids) < max_seq_length:
            input_ids.append(0)
            input_mask.append(0)
            segment_ids.append(0)

        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length

        masked_lm_positions = list(instance.masked_lm_positions)
        masked_lm_ids = tokenizer.convert_tokens_to_ids(instance.masked_lm_labels)
        masked_lm_weights = [1.0] * len(masked_lm_ids)

        while len(masked_lm_positions) < max_predictions_per_seq:
            masked_lm_positions.append(0)
            masked_lm_ids.append(0)
            masked_lm_weights.append(0.0)

        next_sentence_label = 1 if instance.is_random_next else 0

        features = collections.OrderedDict()
        features["input_ids"] = create_int_feature(input_ids)
        features["input_mask"] = create_int_feature(input_mask)
        features["segment_ids"] = create_int_feature(segment_ids)
        features["masked_lm_positions"] = create_int_feature(masked_lm_positions)
        features["masked_lm_ids"] = create_int_feature(masked_lm_ids)
        features["masked_lm_weights"] = create_float_feature(masked_lm_weights)
        features["next_sentence_labels"] = create_int_feature([next_sentence_label])

        tf_example = tf.train.Example(features=tf.train.Features(feature=features))

        writers[writer_index].write(tf_example.SerializeToString())
        writer_index = (writer_index + 1) % len(writers)

        total_written += 1

        if inst_index < 20:
            tf.logging.info("*** Example ***")
            tf.logging.info("tokens: %s" % " ".join(
                [tokenization.printable_text(x) for x in instance.tokens]))

            for feature_name in features.keys():
                feature = features[feature_name]
                values = []
                if feature.int64_list.value:
                    values = feature.int64_list.value
                elif feature.float_list.value:
                    values = feature.float_list.value
                tf.logging.info(
                    "%s: %s" % (feature_name, " ".join([str(x) for x in values])))

    for writer in writers:
        writer.close()

    tf.logging.info("Wrote %d total instances", total_written)


def create_int_feature(values):
    feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
    return feature


def create_float_feature(values):
    feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
    return feature


def create_training_instances(input_files, tokenizer, max_seq_length,
                              dupe_factor, short_seq_prob, masked_lm_prob,
                              max_predictions_per_seq, rng):
    """Create `TrainingInstance`s from raw text."""
    all_documents = [[]]

    # Input file format:
    # (1) One sentence per line. These should ideally be actual sentences, not
    # entire paragraphs or arbitrary spans of text. (Because we use the
    # sentence boundaries for the "next sentence prediction" task).
    # (2) Blank lines between documents. Document boundaries are needed so
    # that the "next sentence prediction" task doesn't span between documents.
    for input_file in input_files:
        with tf.gfile.GFile(input_file, "r") as reader:
            while True:
                line = tokenization.convert_to_unicode(reader.readline())
                if not line:
                    break
                line = line.strip()

                # Empty lines are used as document delimiters
                if not line:
                    all_documents.append([])
                tokens = tokenizer.tokenize(line)
                if tokens:
                    all_documents[-1].append(tokens)

    # Remove empty documents
    all_documents = [x for x in all_documents if x]
    rng.shuffle(all_documents)

    vocab_words = list(tokenizer.vocab.keys())
    instances = []
    for _ in range(dupe_factor):
        for document_index in range(len(all_documents)):
            instances.extend(
                create_instances_from_document(
                    all_documents, document_index, max_seq_length, short_seq_prob,
                    masked_lm_prob, max_predictions_per_seq, vocab_words, rng))

    rng.shuffle(instances)
    return instances


def create_instances_from_document(
        all_documents, document_index, max_seq_length, short_seq_prob,
        masked_lm_prob, max_predictions_per_seq, vocab_words, rng):
    """Creates `TrainingInstance`s for a single document."""
    document = all_documents[document_index]

    # Account for [CLS], [SEP], [SEP]
    max_num_tokens = max_seq_length - 3

    # We *usually* want to fill up the entire sequence since we are padding
    # to `max_seq_length` anyways, so short sequences are generally wasted
    # computation. However, we *sometimes*
    # (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter
    # sequences to minimize the mismatch between pre-training and fine-tuning.
    # The `target_seq_length` is just a rough target however, whereas
    # `max_seq_length` is a hard limit.
    target_seq_length = max_num_tokens
    if rng.random() < short_seq_prob:
        target_seq_length = rng.randint(2, max_num_tokens)

    # We DON'T just concatenate all of the tokens from a document into a long
    # sequence and choose an arbitrary split point because this would make the
    # next sentence prediction task too easy. Instead, we split the input into
    # segments "A" and "B" based on the actual "sentences" provided by the user
    # input.
    instances = []
    current_chunk = []
    current_length = 0
    i = 0
    while i < len(document):
        segment = document[i]
        current_chunk.append(segment)
        current_length += len(segment)
        if i == len(document) - 1 or current_length >= target_seq_length:
            if current_chunk:
                # `a_end` is how many segments from `current_chunk` go into the `A`
                # (first) sentence.
                a_end = 1
                if len(current_chunk) >= 2:
                    a_end = rng.randint(1, len(current_chunk) - 1)

                tokens_a = []
                for j in range(a_end):
                    tokens_a.extend(current_chunk[j])

                tokens_b = []
                # Random next
                is_random_next = False
                if len(current_chunk) == 1 or rng.random() < 0.5:
                    is_random_next = True
                    target_b_length = target_seq_length - len(tokens_a)

                    # This should rarely go for more than one iteration for large
                    # corpora. However, just to be careful, we try to make sure that
                    # the random document is not the same as the document
                    # we're processing.
                    for _ in range(10):
                        random_document_index = rng.randint(0, len(all_documents) - 1)
                        if random_document_index != document_index:
                            break

                    random_document = all_documents[random_document_index]
                    random_start = rng.randint(0, len(random_document) - 1)
                    for j in range(random_start, len(random_document)):
                        tokens_b.extend(random_document[j])
                        if len(tokens_b) >= target_b_length:
                            break
                    # We didn't actually use these segments so we "put them back" so
                    # they don't go to waste.
                    num_unused_segments = len(current_chunk) - a_end
                    i -= num_unused_segments
                # Actual next
                else:
                    is_random_next = False
                    for j in range(a_end, len(current_chunk)):
                        tokens_b.extend(current_chunk[j])
                truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng)

                assert len(tokens_a) >= 1
                assert len(tokens_b) >= 1

                tokens = []
                segment_ids = []
                tokens.append("[CLS]")
                segment_ids.append(0)
                for token in tokens_a:
                    tokens.append(token)
                    segment_ids.append(0)

                tokens.append("[SEP]")
                segment_ids.append(0)

                for token in tokens_b:
                    tokens.append(token)
                    segment_ids.append(1)
                tokens.append("[SEP]")
                segment_ids.append(1)

                (tokens, masked_lm_positions,
                 masked_lm_labels) = create_masked_lm_predictions(
                     tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng)
                instance = TrainingInstance(
                    tokens=tokens,
                    segment_ids=segment_ids,
                    is_random_next=is_random_next,
                    masked_lm_positions=masked_lm_positions,
                    masked_lm_labels=masked_lm_labels)
                instances.append(instance)
            current_chunk = []
            current_length = 0
        i += 1

    return instances


def create_masked_lm_predictions(tokens, masked_lm_prob,
                                 max_predictions_per_seq, vocab_words, rng):
    """Creates the predictions for the masked LM objective."""

    cand_indexes = []
    for (i, token) in enumerate(tokens):
        if token == "[CLS]" or token == "[SEP]":
            continue
        cand_indexes.append(i)

    rng.shuffle(cand_indexes)

    output_tokens = list(tokens)

    masked_lm = collections.namedtuple("masked_lm", ["index", "label"])  # pylint: disable=invalid-name

    num_to_predict = min(max_predictions_per_seq,
                         max(1, int(round(len(tokens) * masked_lm_prob))))

    masked_lms = []
    covered_indexes = set()
    for index in cand_indexes:
        if len(masked_lms) >= num_to_predict:
            break
        if index in covered_indexes:
            continue
        covered_indexes.add(index)

        masked_token = None
        # 80% of the time, replace with [MASK]
        if rng.random() < 0.8:
            masked_token = "[MASK]"
        else:
            # 10% of the time, keep original
            if rng.random() < 0.5:
                masked_token = tokens[index]
            # 10% of the time, replace with random word
            else:
                masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)]

        output_tokens[index] = masked_token

        masked_lms.append(masked_lm(index=index, label=tokens[index]))

    masked_lms = sorted(masked_lms, key=lambda x: x.index)

    masked_lm_positions = []
    masked_lm_labels = []
    for p in masked_lms:
        masked_lm_positions.append(p.index)
        masked_lm_labels.append(p.label)

    return (output_tokens, masked_lm_positions, masked_lm_labels)


def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng):
    """Truncates a pair of sequences to a maximum sequence length."""
    while True:
        total_length = len(tokens_a) + len(tokens_b)
        if total_length <= max_num_tokens:
            break

        trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
        assert len(trunc_tokens) >= 1

        # We want to sometimes truncate from the front and sometimes from the
        # back to add more randomness and avoid biases.
        if rng.random() < 0.5:
            del trunc_tokens[0]
        else:
            trunc_tokens.pop()


def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)

    tokenizer = tokenization.FullTokenizer(
        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)

    input_files = []
    for input_pattern in FLAGS.input_file.split(","):
        input_files.extend(tf.gfile.Glob(input_pattern))

    tf.logging.info("*** Reading from input files ***")
    for input_file in input_files:
        tf.logging.info("  %s", input_file)

    rng = random.Random(FLAGS.random_seed)
    instances = create_training_instances(
        input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor,
        FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq,
        rng)

    output_files = FLAGS.output_file.split(",")
    tf.logging.info("*** Writing to output files ***")
    for output_file in output_files:
        tf.logging.info("  %s", output_file)

    write_instance_to_example_files(instances, tokenizer, FLAGS.max_seq_length,
                                    FLAGS.max_predictions_per_seq, output_files)


if __name__ == "__main__":
    flags.mark_flag_as_required("input_file")
    flags.mark_flag_as_required("output_file")
    flags.mark_flag_as_required("vocab_file")
    tf.app.run()
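A minimal, self-contained sketch of the 80/10/10 masking rule implemented in create_masked_lm_predictions above: 80% of selected positions become [MASK], 10% keep the original token, and 10% are replaced by a random vocabulary word. The helper name choose_masked_token and the toy vocabulary are illustrative only, not part of the repository.

import random

def choose_masked_token(original_token, vocab_words, rng):
    # 80% of the time, replace with [MASK]
    if rng.random() < 0.8:
        return "[MASK]"
    # Otherwise, split the remaining 20% evenly:
    # keep the original token half the time, pick a random word the other half.
    if rng.random() < 0.5:
        return original_token
    return vocab_words[rng.randint(0, len(vocab_words) - 1)]

rng = random.Random(12345)
vocab = ["the", "dog", "barks", "loudly"]
print([choose_masked_token("dog", vocab, rng) for _ in range(5)])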
extract_features.py
@@ -80,330 +80,330 @@ flags.DEFINE_bool(

class InputExample(object):

    def __init__(self, unique_id, text_a, text_b):
        self.unique_id = unique_id
        self.text_a = text_a
        self.text_b = text_b


class InputFeatures(object):
    """A single set of features of data."""

    def __init__(self, unique_id, tokens, input_ids, input_mask, input_type_ids):
        self.unique_id = unique_id
        self.tokens = tokens
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.input_type_ids = input_type_ids


def input_fn_builder(features, seq_length):
    """Creates an `input_fn` closure to be passed to TPUEstimator."""

    all_unique_ids = []
    all_input_ids = []
    all_input_mask = []
    all_input_type_ids = []

    for feature in features:
        all_unique_ids.append(feature.unique_id)
        all_input_ids.append(feature.input_ids)
        all_input_mask.append(feature.input_mask)
        all_input_type_ids.append(feature.input_type_ids)

    def input_fn(params):
        """The actual input function."""
        batch_size = params["batch_size"]

        num_examples = len(features)

        # This is for demo purposes and does NOT scale to large data sets. We do
        # not use Dataset.from_generator() because that uses tf.py_func which is
        # not TPU compatible. The right way to load data is with TFRecordReader.
        d = tf.data.Dataset.from_tensor_slices({
            "unique_ids":
                tf.constant(all_unique_ids, shape=[num_examples], dtype=tf.int32),
            "input_ids":
                tf.constant(
                    all_input_ids, shape=[num_examples, seq_length],
                    dtype=tf.int32),
            "input_mask":
                tf.constant(
                    all_input_mask,
                    shape=[num_examples, seq_length],
                    dtype=tf.int32),
            "input_type_ids":
                tf.constant(
                    all_input_type_ids,
                    shape=[num_examples, seq_length],
                    dtype=tf.int32),
        })

        d = d.batch(batch_size=batch_size, drop_remainder=False)
        return d

    return input_fn


def model_fn_builder(bert_config, init_checkpoint, layer_indexes, use_tpu,
                     use_one_hot_embeddings):
    """Returns `model_fn` closure for TPUEstimator."""

    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator."""

        unique_ids = features["unique_ids"]
        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        input_type_ids = features["input_type_ids"]

        model = modeling.BertModel(
            config=bert_config,
            is_training=False,
            input_ids=input_ids,
            input_mask=input_mask,
            token_type_ids=input_type_ids,
            use_one_hot_embeddings=use_one_hot_embeddings)

        if mode != tf.estimator.ModeKeys.PREDICT:
            raise ValueError("Only PREDICT modes are supported: %s" % (mode))

        tvars = tf.trainable_variables()
        scaffold_fn = None
        (assignment_map, _) = modeling.get_assigment_map_from_checkpoint(
            tvars, init_checkpoint)
        if use_tpu:

            def tpu_scaffold():
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
                return tf.train.Scaffold()

            scaffold_fn = tpu_scaffold
        else:
            tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        all_layers = model.get_all_encoder_layers()

        predictions = {
            "unique_id": unique_ids,
        }

        for (i, layer_index) in enumerate(layer_indexes):
            predictions["layer_output_%d" % i] = all_layers[layer_index]

        output_spec = tf.contrib.tpu.TPUEstimatorSpec(
            mode=mode, predictions=predictions, scaffold_fn=scaffold_fn)
        return output_spec

    return model_fn


def convert_examples_to_features(examples, seq_length, tokenizer):
    """Loads a data file into a list of `InputBatch`s."""

    features = []
    for (ex_index, example) in enumerate(examples):
        tokens_a = tokenizer.tokenize(example.text_a)

        tokens_b = None
        if example.text_b:
            tokens_b = tokenizer.tokenize(example.text_b)

        if tokens_b:
            # Modifies `tokens_a` and `tokens_b` in place so that the total
            # length is less than the specified length.
            # Account for [CLS], [SEP], [SEP] with "- 3"
            _truncate_seq_pair(tokens_a, tokens_b, seq_length - 3)
        else:
            # Account for [CLS] and [SEP] with "- 2"
            if len(tokens_a) > seq_length - 2:
                tokens_a = tokens_a[0:(seq_length - 2)]

        # The convention in BERT is:
        # (a) For sequence pairs:
        #  tokens:   [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
        #  type_ids: 0     0  0    0    0     0        0 0    1  1  1  1   1 1
        # (b) For single sequences:
        #  tokens:   [CLS] the dog is hairy . [SEP]
        #  type_ids: 0     0   0   0  0     0 0
        #
        # Where "type_ids" are used to indicate whether this is the first
        # sequence or the second sequence. The embedding vectors for `type=0` and
        # `type=1` were learned during pre-training and are added to the wordpiece
        # embedding vector (and position vector). This is not *strictly* necessary
        # since the [SEP] token unambiguously separates the sequences, but it makes
        # it easier for the model to learn the concept of sequences.
        #
        # For classification tasks, the first vector (corresponding to [CLS]) is
        # used as the "sentence vector". Note that this only makes sense because
        # the entire model is fine-tuned.
        tokens = []
        input_type_ids = []
        tokens.append("[CLS]")
        input_type_ids.append(0)
        for token in tokens_a:
            tokens.append(token)
            input_type_ids.append(0)
        tokens.append("[SEP]")
        input_type_ids.append(0)

        if tokens_b:
            for token in tokens_b:
                tokens.append(token)
                input_type_ids.append(1)
            tokens.append("[SEP]")
            input_type_ids.append(1)

        input_ids = tokenizer.convert_tokens_to_ids(tokens)

        # The mask has 1 for real tokens and 0 for padding tokens. Only real
        # tokens are attended to.
        input_mask = [1] * len(input_ids)

        # Zero-pad up to the sequence length.
        while len(input_ids) < seq_length:
            input_ids.append(0)
            input_mask.append(0)
            input_type_ids.append(0)

        assert len(input_ids) == seq_length
        assert len(input_mask) == seq_length
        assert len(input_type_ids) == seq_length

        if ex_index < 5:
            tf.logging.info("*** Example ***")
            tf.logging.info("unique_id: %s" % (example.unique_id))
            tf.logging.info("tokens: %s" % " ".join([str(x) for x in tokens]))
            tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
            tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
            tf.logging.info(
                "input_type_ids: %s" % " ".join([str(x) for x in input_type_ids]))

        features.append(
            InputFeatures(
                unique_id=example.unique_id,
                tokens=tokens,
                input_ids=input_ids,
                input_mask=input_mask,
                input_type_ids=input_type_ids))
    return features


def _truncate_seq_pair(tokens_a, tokens_b, max_length):
    """Truncates a sequence pair in place to the maximum length."""

    # This is a simple heuristic which will always truncate the longer sequence
    # one token at a time. This makes more sense than truncating an equal percent
    # of tokens from each, since if one sequence is very short then each token
    # that's truncated likely contains more information than a longer sequence.
    while True:
        total_length = len(tokens_a) + len(tokens_b)
        if total_length <= max_length:
            break
        if len(tokens_a) > len(tokens_b):
            tokens_a.pop()
        else:
            tokens_b.pop()


def read_examples(input_file):
    """Read a list of `InputExample`s from an input file."""
    examples = []
    unique_id = 0
    with tf.gfile.GFile(input_file, "r") as reader:
        while True:
            line = tokenization.convert_to_unicode(reader.readline())
            if not line:
                break
            line = line.strip()
            text_a = None
            text_b = None
            m = re.match(r"^(.*) \|\|\| (.*)$", line)
            if m is None:
                text_a = line
            else:
                text_a = m.group(1)
                text_b = m.group(2)
            examples.append(
                InputExample(unique_id=unique_id, text_a=text_a, text_b=text_b))
            unique_id += 1
    return examples


def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)

    layer_indexes = [int(x) for x in FLAGS.layers.split(",")]

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    tokenizer = tokenization.FullTokenizer(
        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)

    is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
    run_config = tf.contrib.tpu.RunConfig(
        master=FLAGS.master,
        tpu_config=tf.contrib.tpu.TPUConfig(
            num_shards=FLAGS.num_tpu_cores,
            per_host_input_for_training=is_per_host))

    examples = read_examples(FLAGS.input_file)

    features = convert_examples_to_features(
        examples=examples, seq_length=FLAGS.max_seq_length, tokenizer=tokenizer)

    unique_id_to_feature = {}
    for feature in features:
        unique_id_to_feature[feature.unique_id] = feature

    model_fn = model_fn_builder(
        bert_config=bert_config,
        init_checkpoint=FLAGS.init_checkpoint,
        layer_indexes=layer_indexes,
        use_tpu=FLAGS.use_tpu,
        use_one_hot_embeddings=FLAGS.use_one_hot_embeddings)

    # If TPU is not available, this will fall back to normal Estimator on CPU
    # or GPU.
    estimator = tf.contrib.tpu.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=model_fn,
        config=run_config,
        predict_batch_size=FLAGS.batch_size)

    input_fn = input_fn_builder(
        features=features, seq_length=FLAGS.max_seq_length)

    with codecs.getwriter("utf-8")(tf.gfile.Open(FLAGS.output_file,
                                                 "w")) as writer:
        for result in estimator.predict(input_fn, yield_single_examples=True):
            unique_id = int(result["unique_id"])
            feature = unique_id_to_feature[unique_id]
            output_json = collections.OrderedDict()
            output_json["linex_index"] = unique_id
            all_features = []
            for (i, token) in enumerate(feature.tokens):
                all_layers = []
                for (j, layer_index) in enumerate(layer_indexes):
                    layer_output = result["layer_output_%d" % j]
                    layers = collections.OrderedDict()
                    layers["index"] = layer_index
                    layers["values"] = [
                        round(float(x), 6) for x in layer_output[i:(i + 1)].flat
                    ]
                    all_layers.append(layers)
                features = collections.OrderedDict()
                features["token"] = token
                features["layers"] = all_layers
                all_features.append(features)
            output_json["features"] = all_features
            writer.write(json.dumps(output_json) + "\n")


if __name__ == "__main__":
    flags.mark_flag_as_required("input_file")
    flags.mark_flag_as_required("vocab_file")
    flags.mark_flag_as_required("bert_config_file")
    flags.mark_flag_as_required("init_checkpoint")
    flags.mark_flag_as_required("output_file")
    tf.app.run()
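read_examples above accepts either a single sentence per line or a pair of sentences separated by " ||| ". A minimal sketch of that parsing convention, using the same regular expression; the sample sentences are made up and the snippet stands alone (no TensorFlow needed):

import re

line = "the dog is hairy ||| no it is not"
m = re.match(r"^(.*) \|\|\| (.*)$", line)
if m is None:
    # single-sentence line: only text_a is set
    text_a, text_b = line, None
else:
    # paired line: both segments are extracted
    text_a, text_b = m.group(1), m.group(2)
print(text_a)  # "the dog is hairy"
print(text_b)  # "no it is not"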
modeling.py
View file @
8163baab
...
@@ -28,354 +28,354 @@ import tensorflow as tf
...
@@ -28,354 +28,354 @@ import tensorflow as tf
class
BertConfig
(
object
):
class
BertConfig
(
object
):
"""Configuration for `BertModel`."""
"""Configuration for `BertModel`."""
def
__init__
(
self
,
def
__init__
(
self
,
vocab_size
,
vocab_size
,
hidden_size
=
768
,
hidden_size
=
768
,
num_hidden_layers
=
12
,
num_hidden_layers
=
12
,
num_attention_heads
=
12
,
num_attention_heads
=
12
,
intermediate_size
=
3072
,
intermediate_size
=
3072
,
hidden_act
=
"gelu"
,
hidden_act
=
"gelu"
,
hidden_dropout_prob
=
0.1
,
hidden_dropout_prob
=
0.1
,
attention_probs_dropout_prob
=
0.1
,
attention_probs_dropout_prob
=
0.1
,
max_position_embeddings
=
512
,
max_position_embeddings
=
512
,
type_vocab_size
=
16
,
type_vocab_size
=
16
,
initializer_range
=
0.02
):
initializer_range
=
0.02
):
"""Constructs BertConfig.
"""Constructs BertConfig.
Args:
vocab_size: Vocabulary size of `inputs_ids` in `BertModel`.
hidden_size: Size of the encoder layers and the pooler layer.
num_hidden_layers: Number of hidden layers in the Transformer encoder.
num_attention_heads: Number of attention heads for each attention layer in
the Transformer encoder.
intermediate_size: The size of the "intermediate" (i.e., feed-forward)
layer in the Transformer encoder.
hidden_act: The non-linear activation function (function or string) in the
encoder and pooler.
hidden_dropout_prob: The dropout probabilitiy for all fully connected
layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob: The dropout ratio for the attention
probabilities.
max_position_embeddings: The maximum sequence length that this model might
ever be used with. Typically set this to something large just in case
(e.g., 512 or 1024 or 2048).
type_vocab_size: The vocabulary size of the `token_type_ids` passed into
`BertModel`.
initializer_range: The sttdev of the truncated_normal_initializer for
initializing all weight matrices.
"""
self
.
vocab_size
=
vocab_size
self
.
hidden_size
=
hidden_size
self
.
num_hidden_layers
=
num_hidden_layers
self
.
num_attention_heads
=
num_attention_heads
self
.
hidden_act
=
hidden_act
self
.
intermediate_size
=
intermediate_size
self
.
hidden_dropout_prob
=
hidden_dropout_prob
self
.
attention_probs_dropout_prob
=
attention_probs_dropout_prob
self
.
max_position_embeddings
=
max_position_embeddings
self
.
type_vocab_size
=
type_vocab_size
self
.
initializer_range
=
initializer_range
@
classmethod
def
from_dict
(
cls
,
json_object
):
"""Constructs a `BertConfig` from a Python dictionary of parameters."""
config
=
BertConfig
(
vocab_size
=
None
)
for
(
key
,
value
)
in
six
.
iteritems
(
json_object
):
config
.
__dict__
[
key
]
=
value
return
config
@
classmethod
def
from_json_file
(
cls
,
json_file
):
"""Constructs a `BertConfig` from a json file of parameters."""
with
tf
.
gfile
.
GFile
(
json_file
,
"r"
)
as
reader
:
text
=
reader
.
read
()
return
cls
.
from_dict
(
json
.
loads
(
text
))
def
to_dict
(
self
):
"""Serializes this instance to a Python dictionary."""
output
=
copy
.
deepcopy
(
self
.
__dict__
)
return
output
def
to_json_string
(
self
):
"""Serializes this instance to a JSON string."""
return
json
.
dumps
(
self
.
to_dict
(),
indent
=
2
,
sort_keys
=
True
)
+
"
\n
"
Args:
vocab_size: Vocabulary size of `inputs_ids` in `BertModel`.
hidden_size: Size of the encoder layers and the pooler layer.
num_hidden_layers: Number of hidden layers in the Transformer encoder.
num_attention_heads: Number of attention heads for each attention layer in
the Transformer encoder.
intermediate_size: The size of the "intermediate" (i.e., feed-forward)
layer in the Transformer encoder.
hidden_act: The non-linear activation function (function or string) in the
encoder and pooler.
hidden_dropout_prob: The dropout probabilitiy for all fully connected
layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob: The dropout ratio for the attention
probabilities.
max_position_embeddings: The maximum sequence length that this model might
ever be used with. Typically set this to something large just in case
(e.g., 512 or 1024 or 2048).
type_vocab_size: The vocabulary size of the `token_type_ids` passed into
`BertModel`.
initializer_range: The sttdev of the truncated_normal_initializer for
initializing all weight matrices.
"""
self
.
vocab_size
=
vocab_size
self
.
hidden_size
=
hidden_size
self
.
num_hidden_layers
=
num_hidden_layers
self
.
num_attention_heads
=
num_attention_heads
self
.
hidden_act
=
hidden_act
self
.
intermediate_size
=
intermediate_size
self
.
hidden_dropout_prob
=
hidden_dropout_prob
self
.
attention_probs_dropout_prob
=
attention_probs_dropout_prob
self
.
max_position_embeddings
=
max_position_embeddings
self
.
type_vocab_size
=
type_vocab_size
self
.
initializer_range
=
initializer_range
@
classmethod
def
from_dict
(
cls
,
json_object
):
"""Constructs a `BertConfig` from a Python dictionary of parameters."""
config
=
BertConfig
(
vocab_size
=
None
)
for
(
key
,
value
)
in
six
.
iteritems
(
json_object
):
config
.
__dict__
[
key
]
=
value
return
config
@
classmethod
def
from_json_file
(
cls
,
json_file
):
"""Constructs a `BertConfig` from a json file of parameters."""
with
tf
.
gfile
.
GFile
(
json_file
,
"r"
)
as
reader
:
text
=
reader
.
read
()
return
cls
.
from_dict
(
json
.
loads
(
text
))
def
to_dict
(
self
):
"""Serializes this instance to a Python dictionary."""
output
=
copy
.
deepcopy
(
self
.
__dict__
)
return
output
def
to_json_string
(
self
):
class
BertModel
(
object
):
"""Serializes this instance to a JSON string."""
"""BERT model ("Bidirectional Embedding Representations from a Transformer").
return
json
.
dumps
(
self
.
to_dict
(),
indent
=
2
,
sort_keys
=
True
)
+
"
\n
"
Example usage:
class
BertModel
(
object
):
```python
"""BERT model ("Bidirectional Embedding Representations from a Transformer").
# Already been converted into WordPiece token ids
input_ids = tf.constant([[31, 51, 99], [15, 5, 0]])
Example usage:
input_mask = tf.constant([[1, 1, 1], [1, 1, 0]])
token_type_ids = tf.constant([[0, 0, 1], [0, 2, 0]])
```python
# Already been converted into WordPiece token ids
input_ids = tf.constant([[31, 51, 99], [15, 5, 0]])
input_mask = tf.constant([[1, 1, 1], [1, 1, 0]])
token_type_ids = tf.constant([[0, 0, 1], [0, 2, 0]])
config = modeling.BertConfig(vocab_size=32000, hidden_size=512,
num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)
model = modeling.BertModel(config=config, is_training=True,
input_ids=input_ids, input_mask=input_mask, token_type_ids=token_type_ids)
label_embeddings = tf.get_variable(...)
pooled_output = model.get_pooled_output()
logits = tf.matmul(pooled_output, label_embeddings)
...
```
"""
    def __init__(self,
                 config,
                 is_training,
                 input_ids,
                 input_mask=None,
                 token_type_ids=None,
                 use_one_hot_embeddings=True,
                 scope=None):
        """Constructor for BertModel.

        Args:
            config: `BertConfig` instance.
            is_training: bool. True for training model, false for eval model. Controls
                whether dropout will be applied.
            input_ids: int32 Tensor of shape [batch_size, seq_length].
            input_mask: (optional) int32 Tensor of shape [batch_size, seq_length].
            token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length].
            use_one_hot_embeddings: (optional) bool. Whether to use one-hot word
                embeddings or tf.embedding_lookup() for the word embeddings. On the TPU,
                it is much faster if this is True; on the CPU or GPU, it is faster if
                this is False.
            scope: (optional) variable scope. Defaults to "bert".

        Raises:
            ValueError: The config is invalid or one of the input tensor shapes
                is invalid.
        """
        config = copy.deepcopy(config)
        if not is_training:
            config.hidden_dropout_prob = 0.0
            config.attention_probs_dropout_prob = 0.0

        input_shape = get_shape_list(input_ids, expected_rank=2)
        batch_size = input_shape[0]
        seq_length = input_shape[1]

        if input_mask is None:
            input_mask = tf.ones(shape=[batch_size, seq_length], dtype=tf.int32)

        if token_type_ids is None:
            token_type_ids = tf.zeros(shape=[batch_size, seq_length], dtype=tf.int32)

        with tf.variable_scope("bert", scope):
            with tf.variable_scope("embeddings"):
                # Perform embedding lookup on the word ids.
                (self.embedding_output, self.embedding_table) = embedding_lookup(
                    input_ids=input_ids,
                    vocab_size=config.vocab_size,
                    embedding_size=config.hidden_size,
                    initializer_range=config.initializer_range,
                    word_embedding_name="word_embeddings",
                    use_one_hot_embeddings=use_one_hot_embeddings)

                # Add positional embeddings and token type embeddings, then layer
                # normalize and perform dropout.
                self.embedding_output = embedding_postprocessor(
                    input_tensor=self.embedding_output,
                    use_token_type=True,
                    token_type_ids=token_type_ids,
                    token_type_vocab_size=config.type_vocab_size,
                    token_type_embedding_name="token_type_embeddings",
                    use_position_embeddings=True,
                    position_embedding_name="position_embeddings",
                    initializer_range=config.initializer_range,
                    max_position_embeddings=config.max_position_embeddings,
                    dropout_prob=config.hidden_dropout_prob)

            with tf.variable_scope("encoder"):
                # This converts a 2D mask of shape [batch_size, seq_length] to a 3D
                # mask of shape [batch_size, seq_length, seq_length] which is used
                # for the attention scores.
                attention_mask = create_attention_mask_from_input_mask(
                    input_ids, input_mask)

                # Run the stacked transformer.
                # `sequence_output` shape = [batch_size, seq_length, hidden_size].
                self.all_encoder_layers = transformer_model(
                    input_tensor=self.embedding_output,
                    attention_mask=attention_mask,
                    hidden_size=config.hidden_size,
                    num_hidden_layers=config.num_hidden_layers,
                    num_attention_heads=config.num_attention_heads,
                    intermediate_size=config.intermediate_size,
                    intermediate_act_fn=get_activation(config.hidden_act),
                    hidden_dropout_prob=config.hidden_dropout_prob,
                    attention_probs_dropout_prob=config.attention_probs_dropout_prob,
                    initializer_range=config.initializer_range,
                    do_return_all_layers=True)

            self.sequence_output = self.all_encoder_layers[-1]
            # The "pooler" converts the encoded sequence tensor of shape
            # [batch_size, seq_length, hidden_size] to a tensor of shape
            # [batch_size, hidden_size]. This is necessary for segment-level
            # (or segment-pair-level) classification tasks where we need a fixed
            # dimensional representation of the segment.
            with tf.variable_scope("pooler"):
                # We "pool" the model by simply taking the hidden state corresponding
                # to the first token. We assume that this has been pre-trained.
                first_token_tensor = tf.squeeze(self.sequence_output[:, 0:1, :], axis=1)
                self.pooled_output = tf.layers.dense(
                    first_token_tensor,
                    config.hidden_size,
                    activation=tf.tanh,
                    kernel_initializer=create_initializer(config.initializer_range))

    def get_pooled_output(self):
        return self.pooled_output

    def get_sequence_output(self):
        """Gets final hidden layer of encoder.

        Returns:
            float Tensor of shape [batch_size, seq_length, hidden_size] corresponding
            to the final hidden layer of the transformer encoder.
        """
        return self.sequence_output

    def get_all_encoder_layers(self):
        return self.all_encoder_layers

    def get_embedding_output(self):
        """Gets output of the embedding lookup (i.e., input to the transformer).

        Returns:
            float Tensor of shape [batch_size, seq_length, hidden_size] corresponding
            to the output of the embedding layer, after summing the word
            embeddings with the positional embeddings and the token type embeddings,
            then performing layer normalization. This is the input to the transformer.
        """
        return self.embedding_output

    def get_embedding_table(self):
        return self.embedding_table


def gelu(input_tensor):
    """Gaussian Error Linear Unit.

    This is a smoother version of the RELU.
    Original paper: https://arxiv.org/abs/1606.08415

    Args:
        input_tensor: float Tensor to perform activation.

    Returns:
        `input_tensor` with the GELU activation applied.
    """
    cdf = 0.5 * (1.0 + tf.erf(input_tensor / tf.sqrt(2.0)))
    return input_tensor * cdf


def get_activation(activation_string):
    """Maps a string to a Python function, e.g., "relu" => `tf.nn.relu`.

    Args:
        activation_string: String name of the activation function.

    Returns:
        A Python function corresponding to the activation function. If
        `activation_string` is None, empty, or "linear", this will return None.
        If `activation_string` is not a string, it will return `activation_string`.

    Raises:
        ValueError: The `activation_string` does not correspond to a known
            activation.
    """
    # We assume that anything that's not a string is already an activation
    # function, so we just return it.
    if not isinstance(activation_string, six.string_types):
        return activation_string

    if not activation_string:
        return None

    act = activation_string.lower()
    if act == "linear":
        return None
    elif act == "relu":
        return tf.nn.relu
    elif act == "gelu":
        return gelu
    elif act == "tanh":
        return tf.tanh
    else:
        raise ValueError("Unsupported activation: %s" % act)
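As an aside, `get_activation` resolves the `hidden_act` string from `BertConfig` into the callable passed to `transformer_model`. Below is an illustrative sketch (not part of the diff itself), assuming TensorFlow 1.x as used in this file; the tensor values are made up for illustration:

```python
import tensorflow as tf

# "gelu" resolves to the gelu() function defined above; "linear" maps to None.
act_fn = get_activation("gelu")
assert act_fn is gelu
assert get_activation("linear") is None

# GELU(x) = x * 0.5 * (1 + erf(x / sqrt(2))): roughly x for large positive x,
# roughly 0 for large negative x, and smooth around 0 (unlike ReLU).
x = tf.constant([-2.0, 0.0, 2.0])
y = act_fn(x)
with tf.Session() as sess:
    print(sess.run(y))  # approximately [-0.0455, 0.0, 1.9545]
```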
def get_assigment_map_from_checkpoint(tvars, init_checkpoint):
    """Compute the union of the current variables and checkpoint variables."""
    assignment_map = {}
    initialized_variable_names = {}

    name_to_variable = collections.OrderedDict()
    for var in tvars:
        name = var.name
        m = re.match("^(.*):\\d+$", name)
        if m is not None:
            name = m.group(1)
        name_to_variable[name] = var

    init_vars = tf.train.list_variables(init_checkpoint)

    assignment_map = collections.OrderedDict()
    for x in init_vars:
        (name, var) = (x[0], x[1])
        if name not in name_to_variable:
            continue
        assignment_map[name] = name
        initialized_variable_names[name] = 1
        initialized_variable_names[name + ":0"] = 1

    return (assignment_map, initialized_variable_names)
def dropout(input_tensor, dropout_prob):
    """Perform dropout.

    Args:
        input_tensor: float Tensor.
        dropout_prob: Python float. The probability of dropping out a value (NOT of
            *keeping* a dimension as in `tf.nn.dropout`).

    Returns:
        A version of `input_tensor` with dropout applied.
    """
    if dropout_prob is None or dropout_prob == 0.0:
        return input_tensor

    output = tf.nn.dropout(input_tensor, 1.0 - dropout_prob)
    return output
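Note that `dropout_prob` here is the probability of dropping a unit, while `tf.nn.dropout` in TF 1.x takes a keep probability, hence the `1.0 - dropout_prob` conversion. A small sketch, assuming TensorFlow 1.x:

```python
import tensorflow as tf

x = tf.ones([2, 4])
# dropout_prob=0.1 keeps each value with probability 0.9 (and rescales by 1/0.9).
y = dropout(x, dropout_prob=0.1)
# dropout_prob=0.0 (or None) is a no-op and returns the input tensor unchanged.
z = dropout(x, dropout_prob=0.0)
assert z is x
```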
def layer_norm(input_tensor, name=None):
    """Run layer normalization on the last dimension of the tensor."""
    return tf.contrib.layers.layer_norm(
        inputs=input_tensor, begin_norm_axis=-1, begin_params_axis=-1, scope=name)
def layer_norm_and_dropout(input_tensor, dropout_prob, name=None):
    """Runs layer normalization followed by dropout."""
    output_tensor = layer_norm(input_tensor, name)
    output_tensor = dropout(output_tensor, dropout_prob)
    return output_tensor
def create_initializer(initializer_range=0.02):
    """Creates a `truncated_normal_initializer` with the given range."""
    return tf.truncated_normal_initializer(stddev=initializer_range)
def embedding_lookup(input_ids,
...
@@ -384,47 +384,47 @@ def embedding_lookup(input_ids,
                     initializer_range=0.02,
                     word_embedding_name="word_embeddings",
                     use_one_hot_embeddings=False):
    """Looks up word embeddings for an id tensor.

    Args:
        input_ids: int32 Tensor of shape [batch_size, seq_length] containing word
            ids.
        vocab_size: int. Size of the embedding vocabulary.
        embedding_size: int. Width of the word embeddings.
        initializer_range: float. Embedding initialization range.
        word_embedding_name: string. Name of the embedding table.
        use_one_hot_embeddings: bool. If True, use one-hot method for word
            embeddings. If False, use `tf.nn.embedding_lookup()`. One hot is better
            for TPUs.

    Returns:
        float Tensor of shape [batch_size, seq_length, embedding_size].
    """
    # This function assumes that the input is of shape [batch_size, seq_length,
    # num_inputs].
    #
    # If the input is a 2D tensor of shape [batch_size, seq_length], we
    # reshape to [batch_size, seq_length, 1].
    if input_ids.shape.ndims == 2:
        input_ids = tf.expand_dims(input_ids, axis=[-1])

    embedding_table = tf.get_variable(
        name=word_embedding_name,
        shape=[vocab_size, embedding_size],
        initializer=create_initializer(initializer_range))

    if use_one_hot_embeddings:
        flat_input_ids = tf.reshape(input_ids, [-1])
        one_hot_input_ids = tf.one_hot(flat_input_ids, depth=vocab_size)
        output = tf.matmul(one_hot_input_ids, embedding_table)
    else:
        output = tf.nn.embedding_lookup(embedding_table, input_ids)

    input_shape = get_shape_list(input_ids)

    output = tf.reshape(output,
                        input_shape[0:-1] + [input_shape[-1] * embedding_size])
    return (output, embedding_table)
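For a quick shape check, the lookup maps integer ids to embedding vectors. A minimal sketch, assuming TensorFlow 1.x and a hypothetical vocabulary of 100 with 16-dimensional embeddings:

```python
import tensorflow as tf

input_ids = tf.constant([[1, 2, 3, 4], [5, 6, 7, 8]])  # [batch_size=2, seq_length=4]
output, table = embedding_lookup(
    input_ids=input_ids,
    vocab_size=100,
    embedding_size=16,
    use_one_hot_embeddings=False)
print(output.shape)  # (2, 4, 16)
print(table.shape)   # (100, 16)
```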
def embedding_postprocessor(input_tensor,
...
@@ -437,131 +437,131 @@ def embedding_postprocessor(input_tensor,
                            initializer_range=0.02,
                            max_position_embeddings=512,
                            dropout_prob=0.1):
    """Performs various post-processing on a word embedding tensor.

    Args:
        input_tensor: float Tensor of shape [batch_size, seq_length,
            embedding_size].
        use_token_type: bool. Whether to add embeddings for `token_type_ids`.
        token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length].
            Must be specified if `use_token_type` is True.
        token_type_vocab_size: int. The vocabulary size of `token_type_ids`.
        token_type_embedding_name: string. The name of the embedding table variable
            for token type ids.
        use_position_embeddings: bool. Whether to add position embeddings for the
            position of each token in the sequence.
        position_embedding_name: string. The name of the embedding table variable
            for positional embeddings.
        initializer_range: float. Range of the weight initialization.
        max_position_embeddings: int. Maximum sequence length that might ever be
            used with this model. This can be longer than the sequence length of
            input_tensor, but cannot be shorter.
        dropout_prob: float. Dropout probability applied to the final output tensor.

    Returns:
        float tensor with same shape as `input_tensor`.

    Raises:
        ValueError: One of the tensor shapes or input values is invalid.
    """
    input_shape = get_shape_list(input_tensor, expected_rank=3)
    batch_size = input_shape[0]
    seq_length = input_shape[1]
    width = input_shape[2]

    if seq_length > max_position_embeddings:
        raise ValueError("The seq length (%d) cannot be greater than "
                         "`max_position_embeddings` (%d)" %
                         (seq_length, max_position_embeddings))

    output = input_tensor

    if use_token_type:
        if token_type_ids is None:
            raise ValueError("`token_type_ids` must be specified if "
                             "`use_token_type` is True.")
        token_type_table = tf.get_variable(
            name=token_type_embedding_name,
            shape=[token_type_vocab_size, width],
            initializer=create_initializer(initializer_range))
        # This vocab will be small so we always do one-hot here, since it is always
        # faster for a small vocabulary.
        flat_token_type_ids = tf.reshape(token_type_ids, [-1])
        one_hot_ids = tf.one_hot(flat_token_type_ids, depth=token_type_vocab_size)
        token_type_embeddings = tf.matmul(one_hot_ids, token_type_table)
        token_type_embeddings = tf.reshape(token_type_embeddings,
                                           [batch_size, seq_length, width])
        output += token_type_embeddings

    if use_position_embeddings:
        full_position_embeddings = tf.get_variable(
            name=position_embedding_name,
            shape=[max_position_embeddings, width],
            initializer=create_initializer(initializer_range))
        # Since the position embedding table is a learned variable, we create it
        # using a (long) sequence length `max_position_embeddings`. The actual
        # sequence length might be shorter than this, for faster training of
        # tasks that do not have long sequences.
        #
        # So `full_position_embeddings` is effectively an embedding table
        # for position [0, 1, 2, ..., max_position_embeddings-1], and the current
        # sequence has positions [0, 1, 2, ... seq_length-1], so we can just
        # perform a slice.
        if seq_length < max_position_embeddings:
            position_embeddings = tf.slice(full_position_embeddings, [0, 0],
                                           [seq_length, -1])
        else:
            position_embeddings = full_position_embeddings

        num_dims = len(output.shape.as_list())

        # Only the last two dimensions are relevant (`seq_length` and `width`), so
        # we broadcast among the first dimensions, which is typically just
        # the batch size.
        position_broadcast_shape = []
        for _ in range(num_dims - 2):
            position_broadcast_shape.append(1)
        position_broadcast_shape.extend([seq_length, width])
        position_embeddings = tf.reshape(position_embeddings,
                                         position_broadcast_shape)
        output += position_embeddings

    output = layer_norm_and_dropout(output, dropout_prob)
    return output
create_attention_mask_from_input_mask
(
from_tensor
,
to_mask
):
def
create_attention_mask_from_input_mask
(
from_tensor
,
to_mask
):
"""Create 3D attention mask from a 2D tensor mask.
"""Create 3D attention mask from a 2D tensor mask.
Args:
Args:
from_tensor: 2D or 3D Tensor of shape [batch_size, from_seq_length, ...].
from_tensor: 2D or 3D Tensor of shape [batch_size, from_seq_length, ...].
to_mask: int32 Tensor of shape [batch_size, to_seq_length].
to_mask: int32 Tensor of shape [batch_size, to_seq_length].
Returns:
Returns:
float Tensor of shape [batch_size, from_seq_length, to_seq_length].
float Tensor of shape [batch_size, from_seq_length, to_seq_length].
"""
"""
from_shape
=
get_shape_list
(
from_tensor
,
expected_rank
=
[
2
,
3
])
from_shape
=
get_shape_list
(
from_tensor
,
expected_rank
=
[
2
,
3
])
batch_size
=
from_shape
[
0
]
batch_size
=
from_shape
[
0
]
from_seq_length
=
from_shape
[
1
]
from_seq_length
=
from_shape
[
1
]
to_shape
=
get_shape_list
(
to_mask
,
expected_rank
=
2
)
to_shape
=
get_shape_list
(
to_mask
,
expected_rank
=
2
)
to_seq_length
=
to_shape
[
1
]
to_seq_length
=
to_shape
[
1
]
to_mask
=
tf
.
cast
(
to_mask
=
tf
.
cast
(
tf
.
reshape
(
to_mask
,
[
batch_size
,
1
,
to_seq_length
]),
tf
.
float32
)
tf
.
reshape
(
to_mask
,
[
batch_size
,
1
,
to_seq_length
]),
tf
.
float32
)
# We don't assume that `from_tensor` is a mask (although it could be). We
# We don't assume that `from_tensor` is a mask (although it could be). We
# don't actually care if we attend *from* padding tokens (only *to* padding)
# don't actually care if we attend *from* padding tokens (only *to* padding)
# tokens so we create a tensor of all ones.
# tokens so we create a tensor of all ones.
#
#
# `broadcast_ones` = [batch_size, from_seq_length, 1]
# `broadcast_ones` = [batch_size, from_seq_length, 1]
broadcast_ones
=
tf
.
ones
(
broadcast_ones
=
tf
.
ones
(
shape
=
[
batch_size
,
from_seq_length
,
1
],
dtype
=
tf
.
float32
)
shape
=
[
batch_size
,
from_seq_length
,
1
],
dtype
=
tf
.
float32
)
# Here we broadcast along two dimensions to create the mask.
# Here we broadcast along two dimensions to create the mask.
mask
=
broadcast_ones
*
to_mask
mask
=
broadcast_ones
*
to_mask
return
mask
return
mask
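In effect the 2D padding mask is broadcast so that every query position shares the same row of allowed key positions. A small sketch, assuming TensorFlow 1.x:

```python
import tensorflow as tf

from_tensor = tf.zeros([1, 3, 8])    # [batch, from_seq_length, width]
to_mask = tf.constant([[1, 1, 0]])   # last position is padding
mask = create_attention_mask_from_input_mask(from_tensor, to_mask)
with tf.Session() as sess:
    print(sess.run(mask))
# [[[1. 1. 0.]
#   [1. 1. 0.]
#   [1. 1. 0.]]]  -- shape [1, 3, 3]; no query may attend to the padded key
```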
def attention_layer(from_tensor,
...
@@ -578,185 +578,185 @@ def attention_layer(from_tensor,
                    batch_size=None,
                    from_seq_length=None,
                    to_seq_length=None):
    """Performs multi-headed attention from `from_tensor` to `to_tensor`.

    This is an implementation of multi-headed attention based on "Attention
    is all you Need". If `from_tensor` and `to_tensor` are the same, then
    this is self-attention. Each timestep in `from_tensor` attends to the
    corresponding sequence in `to_tensor`, and returns a fixed-width vector.

    This function first projects `from_tensor` into a "query" tensor and
    `to_tensor` into "key" and "value" tensors. These are (effectively) a list
    of tensors of length `num_attention_heads`, where each tensor is of shape
    [batch_size, seq_length, size_per_head].

    Then, the query and key tensors are dot-producted and scaled. These are
    softmaxed to obtain attention probabilities. The value tensors are then
    interpolated by these probabilities, then concatenated back to a single
    tensor and returned.

    In practice, the multi-headed attention is done with transposes and
    reshapes rather than actual separate tensors.

    Args:
        from_tensor: float Tensor of shape [batch_size, from_seq_length,
            from_width].
        to_tensor: float Tensor of shape [batch_size, to_seq_length, to_width].
        attention_mask: (optional) int32 Tensor of shape [batch_size,
            from_seq_length, to_seq_length]. The values should be 1 or 0. The
            attention scores will effectively be set to -infinity for any positions in
            the mask that are 0, and will be unchanged for positions that are 1.
        num_attention_heads: int. Number of attention heads.
        size_per_head: int. Size of each attention head.
        query_act: (optional) Activation function for the query transform.
        key_act: (optional) Activation function for the key transform.
        value_act: (optional) Activation function for the value transform.
        attention_probs_dropout_prob: (optional) float. Dropout probability of the
            attention probabilities.
        initializer_range: float. Range of the weight initializer.
        do_return_2d_tensor: bool. If True, the output will be of shape [batch_size
            * from_seq_length, num_attention_heads * size_per_head]. If False, the
            output will be of shape [batch_size, from_seq_length, num_attention_heads
            * size_per_head].
        batch_size: (Optional) int. If the input is 2D, this might be the batch size
            of the 3D version of the `from_tensor` and `to_tensor`.
        from_seq_length: (Optional) If the input is 2D, this might be the seq length
            of the 3D version of the `from_tensor`.
        to_seq_length: (Optional) If the input is 2D, this might be the seq length
            of the 3D version of the `to_tensor`.

    Returns:
        float Tensor of shape [batch_size, from_seq_length,
            num_attention_heads * size_per_head]. (If `do_return_2d_tensor` is
            true, this will be of shape [batch_size * from_seq_length,
            num_attention_heads * size_per_head]).

    Raises:
        ValueError: Any of the arguments or tensor shapes are invalid.
    """

    def transpose_for_scores(input_tensor, batch_size, num_attention_heads,
                             seq_length, width):
        output_tensor = tf.reshape(
            input_tensor, [batch_size, seq_length, num_attention_heads, width])
        output_tensor = tf.transpose(output_tensor, [0, 2, 1, 3])
        return output_tensor

    from_shape = get_shape_list(from_tensor, expected_rank=[2, 3])
    to_shape = get_shape_list(to_tensor, expected_rank=[2, 3])

    if len(from_shape) != len(to_shape):
        raise ValueError(
            "The rank of `from_tensor` must match the rank of `to_tensor`.")

    if len(from_shape) == 3:
        batch_size = from_shape[0]
        from_seq_length = from_shape[1]
        to_seq_length = to_shape[1]
    elif len(from_shape) == 2:
        if (batch_size is None or from_seq_length is None or to_seq_length is None):
            raise ValueError(
                "When passing in rank 2 tensors to attention_layer, the values "
                "for `batch_size`, `from_seq_length`, and `to_seq_length` "
                "must all be specified.")

    # Scalar dimensions referenced here:
    #     B = batch size (number of sequences)
    #     F = `from_tensor` sequence length
    #     T = `to_tensor` sequence length
    #     N = `num_attention_heads`
    #     H = `size_per_head`

    from_tensor_2d = reshape_to_matrix(from_tensor)
    to_tensor_2d = reshape_to_matrix(to_tensor)

    # `query_layer` = [B*F, N*H]
    query_layer = tf.layers.dense(
        from_tensor_2d,
        num_attention_heads * size_per_head,
        activation=query_act,
        name="query",
        kernel_initializer=create_initializer(initializer_range))

    # `key_layer` = [B*T, N*H]
    key_layer = tf.layers.dense(
        to_tensor_2d,
        num_attention_heads * size_per_head,
        activation=key_act,
        name="key",
        kernel_initializer=create_initializer(initializer_range))

    # `value_layer` = [B*T, N*H]
    value_layer = tf.layers.dense(
        to_tensor_2d,
        num_attention_heads * size_per_head,
        activation=value_act,
        name="value",
        kernel_initializer=create_initializer(initializer_range))

    # `query_layer` = [B, N, F, H]
    query_layer = transpose_for_scores(query_layer, batch_size,
                                       num_attention_heads, from_seq_length,
                                       size_per_head)

    # `key_layer` = [B, N, T, H]
    key_layer = transpose_for_scores(key_layer, batch_size, num_attention_heads,
                                     to_seq_length, size_per_head)

    # Take the dot product between "query" and "key" to get the raw
    # attention scores.
    # `attention_scores` = [B, N, F, T]
    attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True)
    attention_scores = tf.multiply(attention_scores,
                                   1.0 / math.sqrt(float(size_per_head)))

    if attention_mask is not None:
        # `attention_mask` = [B, 1, F, T]
        attention_mask = tf.expand_dims(attention_mask, axis=[1])

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        adder = (1.0 - tf.cast(attention_mask, tf.float32)) * -10000.0

        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        attention_scores += adder

    # Normalize the attention scores to probabilities.
    # `attention_probs` = [B, N, F, T]
    attention_probs = tf.nn.softmax(attention_scores)

    # This is actually dropping out entire tokens to attend to, which might
    # seem a bit unusual, but is taken from the original Transformer paper.
    attention_probs = dropout(attention_probs, attention_probs_dropout_prob)

    # `value_layer` = [B, T, N, H]
    value_layer = tf.reshape(
        value_layer,
        [batch_size, to_seq_length, num_attention_heads, size_per_head])

    # `value_layer` = [B, N, T, H]
    value_layer = tf.transpose(value_layer, [0, 2, 1, 3])

    # `context_layer` = [B, N, F, H]
    context_layer = tf.matmul(attention_probs, value_layer)

    # `context_layer` = [B, F, N, H]
    context_layer = tf.transpose(context_layer, [0, 2, 1, 3])

    if do_return_2d_tensor:
        # `context_layer` = [B*F, N*V]
        context_layer = tf.reshape(
            context_layer,
            [batch_size * from_seq_length, num_attention_heads * size_per_head])
    else:
        # `context_layer` = [B, F, N*V]
        context_layer = tf.reshape(
            context_layer,
            [batch_size, from_seq_length, num_attention_heads * size_per_head])

    return context_layer
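A usage sketch for self-attention over a small batch, assuming TensorFlow 1.x; the sizes are arbitrary, and `do_return_2d_tensor` is passed explicitly since its default is not shown in this hunk:

```python
import tensorflow as tf

x = tf.zeros([2, 5, 24])   # [batch, seq_length, width]
context = attention_layer(
    from_tensor=x,
    to_tensor=x,                 # from == to, i.e. self-attention
    num_attention_heads=3,
    size_per_head=8,
    do_return_2d_tensor=False)
print(context.shape)  # (2, 5, 24) -- num_attention_heads * size_per_head = 24
```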
def transformer_model(input_tensor,
...
@@ -770,225 +770,225 @@ def transformer_model(input_tensor,
                      attention_probs_dropout_prob=0.1,
                      initializer_range=0.02,
                      do_return_all_layers=False):
    """Multi-headed, multi-layer Transformer from "Attention is All You Need".

    This is almost an exact implementation of the original Transformer encoder.

    See the original paper:
    https://arxiv.org/abs/1706.03762

    Also see:
    https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py

    Args:
        input_tensor: float Tensor of shape [batch_size, seq_length, hidden_size].
        attention_mask: (optional) int32 Tensor of shape [batch_size, seq_length,
            seq_length], with 1 for positions that can be attended to and 0 in
            positions that should not be.
        hidden_size: int. Hidden size of the Transformer.
        num_hidden_layers: int. Number of layers (blocks) in the Transformer.
        num_attention_heads: int. Number of attention heads in the Transformer.
        intermediate_size: int. The size of the "intermediate" (a.k.a., feed
            forward) layer.
        intermediate_act_fn: function. The non-linear activation function to apply
            to the output of the intermediate/feed-forward layer.
        hidden_dropout_prob: float. Dropout probability for the hidden layers.
        attention_probs_dropout_prob: float. Dropout probability of the attention
            probabilities.
        initializer_range: float. Range of the initializer (stddev of truncated
            normal).
        do_return_all_layers: Whether to also return all layers or just the final
            layer.

    Returns:
        float Tensor of shape [batch_size, seq_length, hidden_size], the final
        hidden layer of the Transformer.

    Raises:
        ValueError: A Tensor shape or parameter is invalid.
    """
    if hidden_size % num_attention_heads != 0:
        raise ValueError(
            "The hidden size (%d) is not a multiple of the number of attention "
            "heads (%d)" % (hidden_size, num_attention_heads))

    attention_head_size = int(hidden_size / num_attention_heads)
    input_shape = get_shape_list(input_tensor, expected_rank=3)
    batch_size = input_shape[0]
    seq_length = input_shape[1]
    input_width = input_shape[2]

    # The Transformer performs sum residuals on all layers so the input needs
    # to be the same as the hidden size.
    if input_width != hidden_size:
        raise ValueError("The width of the input tensor (%d) != hidden size (%d)" %
                         (input_width, hidden_size))

    # We keep the representation as a 2D tensor to avoid re-shaping it back and
    # forth from a 3D tensor to a 2D tensor. Re-shapes are normally free on
    # the GPU/CPU but may not be free on the TPU, so we want to minimize them to
    # help the optimizer.
    prev_output = reshape_to_matrix(input_tensor)

    all_layer_outputs = []
    for layer_idx in range(num_hidden_layers):
        with tf.variable_scope("layer_%d" % layer_idx):
            layer_input = prev_output

            with tf.variable_scope("attention"):
                attention_heads = []
                with tf.variable_scope("self"):
                    attention_head = attention_layer(
                        from_tensor=layer_input,
                        to_tensor=layer_input,
                        attention_mask=attention_mask,
                        num_attention_heads=num_attention_heads,
                        size_per_head=attention_head_size,
                        attention_probs_dropout_prob=attention_probs_dropout_prob,
                        initializer_range=initializer_range,
                        do_return_2d_tensor=True,
                        batch_size=batch_size,
                        from_seq_length=seq_length,
                        to_seq_length=seq_length)
                    attention_heads.append(attention_head)

                attention_output = None
                if len(attention_heads) == 1:
                    attention_output = attention_heads[0]
                else:
                    # In the case where we have other sequences, we just concatenate
                    # them to the self-attention head before the projection.
                    attention_output = tf.concat(attention_heads, axis=-1)

                # Run a linear projection of `hidden_size` then add a residual
                # with `layer_input`.
                with tf.variable_scope("output"):
                    attention_output = tf.layers.dense(
                        attention_output,
                        hidden_size,
                        kernel_initializer=create_initializer(initializer_range))
                    attention_output = dropout(attention_output, hidden_dropout_prob)
                    attention_output = layer_norm(attention_output + layer_input)

            # The activation is only applied to the "intermediate" hidden layer.
            with tf.variable_scope("intermediate"):
                intermediate_output = tf.layers.dense(
                    attention_output,
                    intermediate_size,
                    activation=intermediate_act_fn,
                    kernel_initializer=create_initializer(initializer_range))

            # Down-project back to `hidden_size` then add the residual.
            with tf.variable_scope("output"):
                layer_output = tf.layers.dense(
                    intermediate_output,
                    hidden_size,
                    kernel_initializer=create_initializer(initializer_range))
                layer_output = dropout(layer_output, hidden_dropout_prob)
                layer_output = layer_norm(layer_output + attention_output)
                prev_output = layer_output
                all_layer_outputs.append(layer_output)

    if do_return_all_layers:
        final_outputs = []
        for layer_output in all_layer_outputs:
            final_output = reshape_from_matrix(layer_output, input_shape)
            final_outputs.append(final_output)
        return final_outputs
    else:
        final_output = reshape_from_matrix(prev_output, input_shape)
        return final_output
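A sketch of running the encoder stack on its own, assuming TensorFlow 1.x; the dimensions are arbitrary, but `hidden_size` must be divisible by `num_attention_heads` and must match the input width:

```python
import tensorflow as tf

x = tf.zeros([2, 6, 32])   # [batch, seq_length, hidden_size]
final_layer = transformer_model(
    input_tensor=x,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=64,
    intermediate_act_fn=gelu)
print(final_layer.shape)  # (2, 6, 32)
```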
def
get_shape_list
(
tensor
,
expected_rank
=
None
,
name
=
None
):
def
get_shape_list
(
tensor
,
expected_rank
=
None
,
name
=
None
):
"""Returns a list of the shape of tensor, preferring static dimensions.
"""Returns a list of the shape of tensor, preferring static dimensions.
Args:
tensor: A tf.Tensor object to find the shape of.
expected_rank: (optional) int. The expected rank of `tensor`. If this is
specified and the `tensor` has a different rank, and exception will be
thrown.
name: Optional name of the tensor for the error message.
Returns:
A list of dimensions of the shape of tensor. All static dimensions will
be returned as python integers, and dynamic dimensions will be returned
as tf.Tensor scalars.
"""
if
name
is
None
:
name
=
tensor
.
name
if
expected_rank
is
not
None
:
assert_rank
(
tensor
,
expected_rank
,
name
)
shape
=
tensor
.
shape
.
as_list
()
non_static_indexes
=
[]
for
(
index
,
dim
)
in
enumerate
(
shape
):
if
dim
is
None
:
non_static_indexes
.
append
(
index
)
if
not
non_static_indexes
:
return
shape
dyn_shape
=
tf
.
shape
(
tensor
)
Args:
for
index
in
non_static_indexes
:
tensor: A tf.Tensor object to find the shape of.
shape
[
index
]
=
dyn_shape
[
index
]
expected_rank: (optional) int. The expected rank of `tensor`. If this is
return
shape
specified and the `tensor` has a different rank, and exception will be
thrown.
name: Optional name of the tensor for the error message.
Returns:
A list of dimensions of the shape of tensor. All static dimensions will
be returned as python integers, and dynamic dimensions will be returned
as tf.Tensor scalars.
"""
if
name
is
None
:
name
=
tensor
.
name
def
reshape_to_matrix
(
input_tensor
):
if
expected_rank
is
not
None
:
"""Reshapes a >= rank 2 tensor to a rank 2 tensor (i.e., a matrix)."""
assert_rank
(
tensor, expected_rank, name)

    shape = tensor.shape.as_list()

    non_static_indexes = []
    for (index, dim) in enumerate(shape):
        if dim is None:
            non_static_indexes.append(index)

    if not non_static_indexes:
        return shape

    dyn_shape = tf.shape(tensor)
    for index in non_static_indexes:
        shape[index] = dyn_shape[index]
    return shape


def reshape_to_matrix(input_tensor):
    """Reshapes a >= rank 2 tensor to a rank 2 tensor (i.e., a matrix)."""
    ndims = input_tensor.shape.ndims
    if ndims < 2:
        raise ValueError("Input tensor must have at least rank 2. Shape = %s" %
                         (input_tensor.shape))
    if ndims == 2:
        return input_tensor

    width = input_tensor.shape[-1]
    output_tensor = tf.reshape(input_tensor, [-1, width])
    return output_tensor


def reshape_from_matrix(output_tensor, orig_shape_list):
    """Reshapes a rank 2 tensor back to its original rank >= 2 tensor."""
    if len(orig_shape_list) == 2:
        return output_tensor

    output_shape = get_shape_list(output_tensor)

    orig_dims = orig_shape_list[0:-1]
    width = output_shape[-1]

    return tf.reshape(output_tensor, orig_dims + [width])


def assert_rank(tensor, expected_rank, name=None):
    """Raises an exception if the tensor rank is not of the expected rank.

    Args:
        tensor: A tf.Tensor to check the rank of.
        expected_rank: Python integer or list of integers, expected rank.
        name: Optional name of the tensor for the error message.

    Raises:
        ValueError: If the expected shape doesn't match the actual shape.
    """
    if name is None:
        name = tensor.name

    expected_rank_dict = {}
    if isinstance(expected_rank, six.integer_types):
        expected_rank_dict[expected_rank] = True
    else:
        for x in expected_rank:
            expected_rank_dict[x] = True

    actual_rank = tensor.shape.ndims
    if actual_rank not in expected_rank_dict:
        scope_name = tf.get_variable_scope().name
        raise ValueError(
            "For the tensor `%s` in scope `%s`, the actual rank "
            "`%d` (shape = %s) is not equal to the expected rank `%s`" %
            (name, scope_name, actual_rank, str(tensor.shape), str(expected_rank)))
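A minimal usage sketch of the three helpers above, not part of this commit. It assumes TensorFlow 1.x and the surrounding modeling.py definitions; the shapes are made up for illustration.

import tensorflow as tf

batch_size, seq_length, hidden_size = 2, 4, 8
acts = tf.ones([batch_size, seq_length, hidden_size])

assert_rank(acts, expected_rank=3, name="acts")   # passes silently for rank 3
matrix = reshape_to_matrix(acts)                  # collapses to [2*4, 8]
restored = reshape_from_matrix(matrix, get_shape_list(acts))

with tf.Session() as sess:
    print(sess.run(tf.shape(matrix)))    # [8 8]
    print(sess.run(tf.shape(restored)))  # [2 4 8]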
modeling_test.py
View file @ 8163baab
@@ -27,250 +27,249 @@ import tensorflow as tf
class BertModelTest(tf.test.TestCase):

    class BertModelTester(object):

        def __init__(self,
                     parent,
                     batch_size=13,
                     seq_length=7,
                     is_training=True,
                     use_input_mask=True,
                     use_token_type_ids=True,
                     vocab_size=99,
                     hidden_size=32,
                     num_hidden_layers=5,
                     num_attention_heads=4,
                     intermediate_size=37,
                     hidden_act="gelu",
                     hidden_dropout_prob=0.1,
                     attention_probs_dropout_prob=0.1,
                     max_position_embeddings=512,
                     type_vocab_size=16,
                     initializer_range=0.02,
                     scope=None):
            self.parent = parent
            self.batch_size = batch_size
            self.seq_length = seq_length
            self.is_training = is_training
            self.use_input_mask = use_input_mask
            self.use_token_type_ids = use_token_type_ids
            self.vocab_size = vocab_size
            self.hidden_size = hidden_size
            self.num_hidden_layers = num_hidden_layers
            self.num_attention_heads = num_attention_heads
            self.intermediate_size = intermediate_size
            self.hidden_act = hidden_act
            self.hidden_dropout_prob = hidden_dropout_prob
            self.attention_probs_dropout_prob = attention_probs_dropout_prob
            self.max_position_embeddings = max_position_embeddings
            self.type_vocab_size = type_vocab_size
            self.initializer_range = initializer_range
            self.scope = scope

        def create_model(self):
            input_ids = BertModelTest.ids_tensor(
                [self.batch_size, self.seq_length], self.vocab_size)

            input_mask = None
            if self.use_input_mask:
                input_mask = BertModelTest.ids_tensor(
                    [self.batch_size, self.seq_length], vocab_size=2)

            token_type_ids = None
            if self.use_token_type_ids:
                token_type_ids = BertModelTest.ids_tensor(
                    [self.batch_size, self.seq_length], self.type_vocab_size)

            config = modeling.BertConfig(
                vocab_size=self.vocab_size,
                hidden_size=self.hidden_size,
                num_hidden_layers=self.num_hidden_layers,
                num_attention_heads=self.num_attention_heads,
                intermediate_size=self.intermediate_size,
                hidden_act=self.hidden_act,
                hidden_dropout_prob=self.hidden_dropout_prob,
                attention_probs_dropout_prob=self.attention_probs_dropout_prob,
                max_position_embeddings=self.max_position_embeddings,
                type_vocab_size=self.type_vocab_size,
                initializer_range=self.initializer_range)

            model = modeling.BertModel(
                config=config,
                is_training=self.is_training,
                input_ids=input_ids,
                input_mask=input_mask,
                token_type_ids=token_type_ids,
                scope=self.scope)

            outputs = {
                "embedding_output": model.get_embedding_output(),
                "sequence_output": model.get_sequence_output(),
                "pooled_output": model.get_pooled_output(),
                "all_encoder_layers": model.get_all_encoder_layers(),
            }
            return outputs

        def check_output(self, result):
            self.parent.assertAllEqual(
                result["embedding_output"].shape,
                [self.batch_size, self.seq_length, self.hidden_size])

            self.parent.assertAllEqual(
                result["sequence_output"].shape,
                [self.batch_size, self.seq_length, self.hidden_size])

            self.parent.assertAllEqual(
                result["pooled_output"].shape,
                [self.batch_size, self.hidden_size])

    def test_default(self):
        self.run_tester(BertModelTest.BertModelTester(self))

    def test_config_to_json_string(self):
        config = modeling.BertConfig(vocab_size=99, hidden_size=37)
        obj = json.loads(config.to_json_string())
        self.assertEqual(obj["vocab_size"], 99)
        self.assertEqual(obj["hidden_size"], 37)

    def run_tester(self, tester):
        with self.test_session() as sess:
            ops = tester.create_model()
            init_op = tf.group(tf.global_variables_initializer(),
                               tf.local_variables_initializer())
            sess.run(init_op)
            output_result = sess.run(ops)
            tester.check_output(output_result)

            self.assert_all_tensors_reachable(sess, [init_op, ops])

    @classmethod
    def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
        """Creates a random int32 tensor of the shape within the vocab size."""
        if rng is None:
            rng = random.Random()

        total_dims = 1
        for dim in shape:
            total_dims *= dim

        values = []
        for _ in range(total_dims):
            values.append(rng.randint(0, vocab_size - 1))

        return tf.constant(value=values, dtype=tf.int32, shape=shape, name=name)

    def assert_all_tensors_reachable(self, sess, outputs):
        """Checks that all the tensors in the graph are reachable from outputs."""
        graph = sess.graph

        ignore_strings = [
            "^.*/dilation_rate$",
            "^.*/Tensordot/concat$",
            "^.*/Tensordot/concat/axis$",
            "^testing/.*$",
        ]

        ignore_regexes = [re.compile(x) for x in ignore_strings]

        unreachable = self.get_unreachable_ops(graph, outputs)
        filtered_unreachable = []
        for x in unreachable:
            do_ignore = False
            for r in ignore_regexes:
                m = r.match(x.name)
                if m is not None:
                    do_ignore = True
            if do_ignore:
                continue
            filtered_unreachable.append(x)
        unreachable = filtered_unreachable

        self.assertEqual(
            len(unreachable), 0, "The following ops are unreachable: %s" %
            (" ".join([x.name for x in unreachable])))

    @classmethod
    def get_unreachable_ops(cls, graph, outputs):
        """Finds all of the tensors in graph that are unreachable from outputs."""
        outputs = cls.flatten_recursive(outputs)
        output_to_op = collections.defaultdict(list)
        op_to_all = collections.defaultdict(list)
        assign_out_to_in = collections.defaultdict(list)

        for op in graph.get_operations():
            for x in op.inputs:
                op_to_all[op.name].append(x.name)
            for y in op.outputs:
                output_to_op[y.name].append(op.name)
                op_to_all[op.name].append(y.name)
            if str(op.type) == "Assign":
                for y in op.outputs:
                    for x in op.inputs:
                        assign_out_to_in[y.name].append(x.name)

        assign_groups = collections.defaultdict(list)
        for out_name in assign_out_to_in.keys():
            name_group = assign_out_to_in[out_name]
            for n1 in name_group:
                assign_groups[n1].append(out_name)
                for n2 in name_group:
                    if n1 != n2:
                        assign_groups[n1].append(n2)

        seen_tensors = {}
        stack = [x.name for x in outputs]
        while stack:
            name = stack.pop()
            if name in seen_tensors:
                continue
            seen_tensors[name] = True

            if name in output_to_op:
                for op_name in output_to_op[name]:
                    if op_name in op_to_all:
                        for input_name in op_to_all[op_name]:
                            if input_name not in stack:
                                stack.append(input_name)

            expanded_names = []
            if name in assign_groups:
                for assign_name in assign_groups[name]:
                    expanded_names.append(assign_name)

            for expanded_name in expanded_names:
                if expanded_name not in stack:
                    stack.append(expanded_name)

        unreachable_ops = []
        for op in graph.get_operations():
            is_unreachable = False
            all_names = [x.name for x in op.inputs] + [x.name for x in op.outputs]
            for name in all_names:
                if name not in seen_tensors:
                    is_unreachable = True
            if is_unreachable:
                unreachable_ops.append(op)
        return unreachable_ops

    @classmethod
    def flatten_recursive(cls, item):
        """Flattens (potentially nested) a tuple/dictionary/list to a list."""
        output = []
        if isinstance(item, list):
            output.extend(item)
        elif isinstance(item, tuple):
            output.extend(list(item))
        elif isinstance(item, dict):
            for (_, v) in six.iteritems(item):
                output.append(v)
        else:
            return [item]

        flat_output = []
        for x in output:
            flat_output.extend(cls.flatten_recursive(x))
        return flat_output


if __name__ == "__main__":
    tf.test.main()
optimization.py
View file @ 8163baab
@@ -23,149 +23,149 @@ import tensorflow as tf
def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu):
    """Creates an optimizer training op."""
    global_step = tf.train.get_or_create_global_step()

    learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32)

    # Implements linear decay of the learning rate.
    learning_rate = tf.train.polynomial_decay(
        learning_rate,
        global_step,
        num_train_steps,
        end_learning_rate=0.0,
        power=1.0,
        cycle=False)

    # Implements linear warmup. I.e., if global_step < num_warmup_steps, the
    # learning rate will be `global_step/num_warmup_steps * init_lr`.
    if num_warmup_steps:
        global_steps_int = tf.cast(global_step, tf.int32)
        warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32)

        global_steps_float = tf.cast(global_steps_int, tf.float32)
        warmup_steps_float = tf.cast(warmup_steps_int, tf.float32)

        warmup_percent_done = global_steps_float / warmup_steps_float
        warmup_learning_rate = init_lr * warmup_percent_done

        is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32)
        learning_rate = (
            (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate)

    # It is recommended that you use this optimizer for fine tuning, since this
    # is how the model was trained (note that the Adam m/v variables are NOT
    # loaded from init_checkpoint.)
    optimizer = AdamWeightDecayOptimizer(
        learning_rate=learning_rate,
        weight_decay_rate=0.01,
        beta_1=0.9,
        beta_2=0.999,
        epsilon=1e-6,
        exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])

    if use_tpu:
        optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)

    tvars = tf.trainable_variables()
    grads = tf.gradients(loss, tvars)

    # This is how the model was pre-trained.
    (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)

    train_op = optimizer.apply_gradients(
        zip(grads, tvars), global_step=global_step)

    new_global_step = global_step + 1
    train_op = tf.group(train_op, [global_step.assign(new_global_step)])
    return train_op
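To make the schedule above concrete, here is a plain-Python sketch, not part of this commit, of the learning rate the graph code produces: linear warmup up to init_lr over num_warmup_steps, then linear decay to 0.0 by num_train_steps. The numbers are made up for illustration.

def sketch_learning_rate(step, init_lr=5e-5, num_train_steps=1000,
                         num_warmup_steps=100):
    """Plain-Python mirror of the warmup + linear-decay schedule above."""
    # Linear (power=1.0, cycle=False) decay from init_lr to 0.0.
    decayed = init_lr * (1.0 - min(step, num_train_steps) / float(num_train_steps))
    if step < num_warmup_steps:
        # Linear warmup: global_step / num_warmup_steps * init_lr.
        return init_lr * step / float(num_warmup_steps)
    return decayed

for step in (0, 50, 100, 500, 1000):
    print(step, sketch_learning_rate(step))
# 0 -> 0.0, 50 -> 2.5e-05, 100 -> 4.5e-05, 500 -> 2.5e-05, 1000 -> 0.0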
class AdamWeightDecayOptimizer(tf.train.Optimizer):
    """A basic Adam optimizer that includes "correct" L2 weight decay."""

    def __init__(self,
                 learning_rate,
                 weight_decay_rate=0.0,
                 beta_1=0.9,
                 beta_2=0.999,
                 epsilon=1e-6,
                 exclude_from_weight_decay=None,
                 name="AdamWeightDecayOptimizer"):
        """Constructs an AdamWeightDecayOptimizer."""
        super(AdamWeightDecayOptimizer, self).__init__(False, name)

        self.learning_rate = learning_rate
        self.weight_decay_rate = weight_decay_rate
        self.beta_1 = beta_1
        self.beta_2 = beta_2
        self.epsilon = epsilon
        self.exclude_from_weight_decay = exclude_from_weight_decay

    def apply_gradients(self, grads_and_vars, global_step=None, name=None):
        """See base class."""
        assignments = []
        for (grad, param) in grads_and_vars:
            if grad is None or param is None:
                continue

            param_name = self._get_variable_name(param.name)

            m = tf.get_variable(
                name=param_name + "/adam_m",
                shape=param.shape.as_list(),
                dtype=tf.float32,
                trainable=False,
                initializer=tf.zeros_initializer())
            v = tf.get_variable(
                name=param_name + "/adam_v",
                shape=param.shape.as_list(),
                dtype=tf.float32,
                trainable=False,
                initializer=tf.zeros_initializer())

            # Standard Adam update.
            next_m = (
                tf.multiply(self.beta_1, m) + tf.multiply(1.0 - self.beta_1, grad))
            next_v = (
                tf.multiply(self.beta_2, v) +
                tf.multiply(1.0 - self.beta_2, tf.square(grad)))

            update = next_m / (tf.sqrt(next_v) + self.epsilon)

            # Just adding the square of the weights to the loss function is *not*
            # the correct way of using L2 regularization/weight decay with Adam,
            # since that will interact with the m and v parameters in strange ways.
            #
            # Instead we want to decay the weights in a manner that doesn't interact
            # with the m/v parameters. This is equivalent to adding the square
            # of the weights to the loss with plain (non-momentum) SGD.
            if self._do_use_weight_decay(param_name):
                update += self.weight_decay_rate * param

            update_with_lr = self.learning_rate * update

            next_param = param - update_with_lr

            assignments.extend(
                [param.assign(next_param),
                 m.assign(next_m),
                 v.assign(next_v)])
        return tf.group(*assignments, name=name)

    def _do_use_weight_decay(self, param_name):
        """Whether to use L2 weight decay for `param_name`."""
        if not self.weight_decay_rate:
            return False
        if self.exclude_from_weight_decay:
            for r in self.exclude_from_weight_decay:
                if re.search(r, param_name) is not None:
                    return False
        return True

    def _get_variable_name(self, param_name):
        """Get the variable name from the tensor name."""
        m = re.match("^(.*):\\d+$", param_name)
        if m is not None:
            param_name = m.group(1)
        return param_name
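As a sanity check on the update rule in apply_gradients above, a single parameter step can be written out in NumPy. This is an illustrative sketch, not part of the commit; like the optimizer above it omits Adam bias correction, and the values are made up.

import numpy as np

def adamw_step(param, grad, m, v, lr=1e-3, wd=0.01,
               beta_1=0.9, beta_2=0.999, epsilon=1e-6):
    """One decoupled-weight-decay Adam step, mirroring apply_gradients above."""
    next_m = beta_1 * m + (1.0 - beta_1) * grad
    next_v = beta_2 * v + (1.0 - beta_2) * np.square(grad)
    update = next_m / (np.sqrt(next_v) + epsilon)
    update += wd * param            # decay applied to the weights, not the loss
    next_param = param - lr * update
    return next_param, next_m, next_v

p, m, v = np.array([0.5]), np.zeros(1), np.zeros(1)
g = np.array([0.2])
print(adamw_step(p, g, m, v))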
optimization_test.py
View file @ 8163baab
@@ -22,27 +22,27 @@ import tensorflow as tf
class OptimizationTest(tf.test.TestCase):

    def test_adam(self):
        with self.test_session() as sess:
            w = tf.get_variable(
                "w",
                shape=[3],
                initializer=tf.constant_initializer([0.1, -0.2, -0.1]))
            x = tf.constant([0.4, 0.2, -0.5])
            loss = tf.reduce_mean(tf.square(x - w))
            tvars = tf.trainable_variables()
            grads = tf.gradients(loss, tvars)
            global_step = tf.train.get_or_create_global_step()
            optimizer = optimization.AdamWeightDecayOptimizer(learning_rate=0.2)
            train_op = optimizer.apply_gradients(zip(grads, tvars), global_step)
            init_op = tf.group(tf.global_variables_initializer(),
                               tf.local_variables_initializer())
            sess.run(init_op)
            for _ in range(100):
                sess.run(train_op)
            w_np = sess.run(w)
            self.assertAllClose(w_np.flat, [0.4, 0.2, -0.5], rtol=1e-2, atol=1e-2)


if __name__ == "__main__":
    tf.test.main()
run_classifier.py
View file @ 8163baab
@@ -118,583 +118,583 @@ flags.DEFINE_integer(
class InputExample(object):
    """A single training/test example for simple sequence classification."""

    def __init__(self, guid, text_a, text_b=None, label=None):
        """Constructs an InputExample.

        Args:
            guid: Unique id for the example.
            text_a: string. The untokenized text of the first sequence. For single
                sequence tasks, only this sequence must be specified.
            text_b: (Optional) string. The untokenized text of the second sequence.
                Must only be specified for sequence pair tasks.
            label: (Optional) string. The label of the example. This should be
                specified for train and dev examples, but not for test examples.
        """
        self.guid = guid
        self.text_a = text_a
        self.text_b = text_b
        self.label = label


class InputFeatures(object):
    """A single set of features of data."""

    def __init__(self, input_ids, input_mask, segment_ids, label_id):
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.label_id = label_id
class DataProcessor(object):
    """Base class for data converters for sequence classification data sets."""

    def get_train_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the train set."""
        raise NotImplementedError()

    def get_dev_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the dev set."""
        raise NotImplementedError()

    def get_labels(self):
        """Gets the list of labels for this data set."""
        raise NotImplementedError()

    @classmethod
    def _read_tsv(cls, input_file, quotechar=None):
        """Reads a tab separated value file."""
        with tf.gfile.Open(input_file, "r") as f:
            reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
            lines = []
            for line in reader:
                lines.append(line)
            return lines
class MnliProcessor(DataProcessor):
    """Processor for the MultiNLI data set (GLUE version)."""

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "dev_matched.tsv")),
            "dev_matched")

    def get_labels(self):
        """See base class."""
        return ["contradiction", "entailment", "neutral"]

    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets."""
        examples = []
        for (i, line) in enumerate(lines):
            if i == 0:
                continue
            guid = "%s-%s" % (set_type, tokenization.convert_to_unicode(line[0]))
            text_a = tokenization.convert_to_unicode(line[8])
            text_b = tokenization.convert_to_unicode(line[9])
            label = tokenization.convert_to_unicode(line[-1])
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples
class MrpcProcessor(DataProcessor):
    """Processor for the MRPC data set (GLUE version)."""

    def get_train_examples(self, data_dir):
        """See base class."""
        print("LOOKING AT {}".format(os.path.join(data_dir, "train.tsv")))
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_labels(self):
        """See base class."""
        return ["0", "1"]

    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets."""
        examples = []
        for (i, line) in enumerate(lines):
            if i == 0:
                continue
            guid = "%s-%s" % (set_type, i)
            text_a = tokenization.convert_to_unicode(line[3])
            text_b = tokenization.convert_to_unicode(line[4])
            label = tokenization.convert_to_unicode(line[0])
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples
class ColaProcessor(DataProcessor):
    """Processor for the CoLA data set (GLUE version)."""

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_labels(self):
        """See base class."""
        return ["0", "1"]

    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets."""
        examples = []
        for (i, line) in enumerate(lines):
            guid = "%s-%s" % (set_type, i)
            text_a = tokenization.convert_to_unicode(line[3])
            label = tokenization.convert_to_unicode(line[1])
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
        return examples
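For illustration only, not part of this commit: a hypothetical MRPC-style row and the InputExample the MrpcProcessor above would build from it. The column layout (label in column 0, sentences in columns 3 and 4) follows _create_examples; the sentences and ids are made up.

line = ["1", "100", "101",
        "The cat sat on the mat.",
        "A cat was sitting on the mat."]
example = InputExample(
    guid="train-1",
    text_a=line[3],
    text_b=line[4],
    label=line[0])   # "1" here means the pair is labeled a paraphrase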
def convert_examples_to_features(examples, label_list, max_seq_length,
                                 tokenizer):
    """Loads a data file into a list of `InputBatch`s."""

    label_map = {}
    for (i, label) in enumerate(label_list):
        label_map[label] = i

    features = []
    for (ex_index, example) in enumerate(examples):
        tokens_a = tokenizer.tokenize(example.text_a)

        tokens_b = None
        if example.text_b:
            tokens_b = tokenizer.tokenize(example.text_b)

        if tokens_b:
            # Modifies `tokens_a` and `tokens_b` in place so that the total
            # length is less than the specified length.
            # Account for [CLS], [SEP], [SEP] with "- 3"
            _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
        else:
            # Account for [CLS] and [SEP] with "- 2"
            if len(tokens_a) > max_seq_length - 2:
                tokens_a = tokens_a[0:(max_seq_length - 2)]

        # The convention in BERT is:
        # (a) For sequence pairs:
        #  tokens:   [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
        #  type_ids: 0     0  0    0    0     0        0 0    1  1  1  1   1 1
        # (b) For single sequences:
        #  tokens:   [CLS] the dog is hairy . [SEP]
        #  type_ids: 0     0   0   0  0     0 0
        #
        # Where "type_ids" are used to indicate whether this is the first
        # sequence or the second sequence. The embedding vectors for `type=0` and
        # `type=1` were learned during pre-training and are added to the wordpiece
        # embedding vector (and position vector). This is not *strictly* necessary
        # since the [SEP] token unambiguously separates the sequences, but it makes
        # it easier for the model to learn the concept of sequences.
        #
        # For classification tasks, the first vector (corresponding to [CLS]) is
        # used as the "sentence vector". Note that this only makes sense because
        # the entire model is fine-tuned.
        tokens = []
        segment_ids = []
        tokens.append("[CLS]")
        segment_ids.append(0)
        for token in tokens_a:
            tokens.append(token)
            segment_ids.append(0)
        tokens.append("[SEP]")
        segment_ids.append(0)

        if tokens_b:
            for token in tokens_b:
                tokens.append(token)
                segment_ids.append(1)
            tokens.append("[SEP]")
            segment_ids.append(1)

        input_ids = tokenizer.convert_tokens_to_ids(tokens)

        # The mask has 1 for real tokens and 0 for padding tokens. Only real
        # tokens are attended to.
        input_mask = [1] * len(input_ids)

        # Zero-pad up to the sequence length.
        while len(input_ids) < max_seq_length:
            input_ids.append(0)
            input_mask.append(0)
            segment_ids.append(0)

        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length

        label_id = label_map[example.label]
        if ex_index < 5:
            tf.logging.info("*** Example ***")
            tf.logging.info("guid: %s" % (example.guid))
            tf.logging.info("tokens: %s" % " ".join(
                [tokenization.printable_text(x) for x in tokens]))
            tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
            tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
            tf.logging.info(
                "segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
            tf.logging.info("label: %s (id = %d)" % (example.label, label_id))

        features.append(
            InputFeatures(
                input_ids=input_ids,
                input_mask=input_mask,
                segment_ids=segment_ids,
                label_id=label_id))
    return features
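To make the feature layout above concrete, here is the structure the function would produce for a hypothetical, already-tokenized sentence pair with max_seq_length=10. This is an illustrative sketch, not part of the commit; only the list structure matters, and the tokens are made up.

tokens      = ["[CLS]", "is", "it", "rain", "##ing", "[SEP]", "no", "[SEP]"]
segment_ids = [0,        0,    0,    0,      0,       0,       1,    1]
input_mask  = [1,        1,    1,    1,      1,       1,       1,    1]
# Zero-padding up to max_seq_length=10 then appends two 0s to input_ids,
# input_mask, and segment_ids, so all three lists end up with length 10.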
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
    """Truncates a sequence pair in place to the maximum length."""

    # This is a simple heuristic which will always truncate the longer sequence
    # one token at a time. This makes more sense than truncating an equal percent
    # of tokens from each, since if one sequence is very short then each token
    # that's truncated likely contains more information than a longer sequence.
    while True:
        total_length = len(tokens_a) + len(tokens_b)
        if total_length <= max_length:
            break
        if len(tokens_a) > len(tokens_b):
            tokens_a.pop()
        else:
            tokens_b.pop()
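A quick illustration of the heuristic above, not part of this commit, with made-up token lists: the longer sequence loses tokens one at a time until the pair fits.

tokens_a = ["a1", "a2", "a3", "a4", "a5", "a6"]
tokens_b = ["b1", "b2", "b3"]
_truncate_seq_pair(tokens_a, tokens_b, max_length=7)
print(tokens_a)  # ['a1', 'a2', 'a3', 'a4'] -- trimmed from 6 to 4 tokens
print(tokens_b)  # ['b1', 'b2', 'b3']       -- untouched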
def create_model(bert_config, is_training, input_ids, input_mask, segment_ids,
                 labels, num_labels, use_one_hot_embeddings):
    """Creates a classification model."""
    model = modeling.BertModel(
        config=bert_config,
        is_training=is_training,
        input_ids=input_ids,
        input_mask=input_mask,
        token_type_ids=segment_ids,
        use_one_hot_embeddings=use_one_hot_embeddings)

    # In the demo, we are doing a simple classification task on the entire
    # segment.
    #
    # If you want to use the token-level output, use model.get_sequence_output()
    # instead.
    output_layer = model.get_pooled_output()

    hidden_size = output_layer.shape[-1].value

    output_weights = tf.get_variable(
        "output_weights", [num_labels, hidden_size],
        initializer=tf.truncated_normal_initializer(stddev=0.02))

    output_bias = tf.get_variable(
        "output_bias", [num_labels], initializer=tf.zeros_initializer())

    with tf.variable_scope("loss"):
        if is_training:
            # I.e., 0.1 dropout
            output_layer = tf.nn.dropout(output_layer, keep_prob=0.9)

        logits = tf.matmul(output_layer, output_weights, transpose_b=True)
        logits = tf.nn.bias_add(logits, output_bias)
        log_probs = tf.nn.log_softmax(logits, axis=-1)

        one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32)

        per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
        loss = tf.reduce_mean(per_example_loss)

        return (loss, per_example_loss, logits)
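The loss block above is ordinary softmax cross-entropy over the pooled [CLS] vector. A NumPy sketch of the same math for a single example, not part of this commit, with made-up logits:

import numpy as np

logits = np.array([2.0, 0.5, -1.0])                   # one example, num_labels = 3
log_probs = logits - np.log(np.sum(np.exp(logits)))   # log_softmax
label = 0
one_hot = np.eye(3)[label]
per_example_loss = -np.sum(one_hot * log_probs)
print(per_example_loss)   # ~0.24, the cross-entropy for this example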
def model_fn_builder(bert_config, num_labels, init_checkpoint, learning_rate,
                     num_train_steps, num_warmup_steps, use_tpu,
                     use_one_hot_embeddings):
    """Returns `model_fn` closure for TPUEstimator."""

    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator."""

        tf.logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf.logging.info("  name = %s, shape = %s" % (name, features[name].shape))

        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]
        label_ids = features["label_ids"]

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        (total_loss, per_example_loss, logits) = create_model(
            bert_config, is_training, input_ids, input_mask, segment_ids, label_ids,
            num_labels, use_one_hot_embeddings)

        tvars = tf.trainable_variables()

        scaffold_fn = None
        if init_checkpoint:
            (assignment_map,
             initialized_variable_names) = modeling.get_assigment_map_from_checkpoint(
                 tvars, init_checkpoint)
            if use_tpu:

                def tpu_scaffold():
                    tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
                    return tf.train.Scaffold()

                scaffold_fn = tpu_scaffold
            else:
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        tf.logging.info("**** Trainable Variables ****")
        for var in tvars:
            init_string = ""
            if var.name in initialized_variable_names:
                init_string = ", *INIT_FROM_CKPT*"
            tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                            init_string)

        output_spec = None
        if mode == tf.estimator.ModeKeys.TRAIN:
            train_op = optimization.create_optimizer(
                total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu)

            output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                train_op=train_op,
                scaffold_fn=scaffold_fn)
        elif mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(per_example_loss, label_ids, logits):
                predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
                accuracy = tf.metrics.accuracy(label_ids, predictions)
                loss = tf.metrics.mean(per_example_loss)
                return {
                    "eval_accuracy": accuracy,
                    "eval_loss": loss,
                }

            eval_metrics = (metric_fn, [per_example_loss, label_ids, logits])
            output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                eval_metrics=eval_metrics,
                scaffold_fn=scaffold_fn)
        else:
            raise ValueError("Only TRAIN and EVAL modes are supported: %s" % (mode))

        return output_spec

    return model_fn
def input_fn_builder(features, seq_length, is_training, drop_remainder):
    """Creates an `input_fn` closure to be passed to TPUEstimator."""

    all_input_ids = []
    all_input_mask = []
    all_segment_ids = []
    all_label_ids = []

    for feature in features:
        all_input_ids.append(feature.input_ids)
        all_input_mask.append(feature.input_mask)
        all_segment_ids.append(feature.segment_ids)
        all_label_ids.append(feature.label_id)

    def input_fn(params):
        """The actual input function."""
        batch_size = params["batch_size"]

        num_examples = len(features)

        # This is for demo purposes and does NOT scale to large data sets. We do
        # not use Dataset.from_generator() because that uses tf.py_func which is
        # not TPU compatible. The right way to load data is with TFRecordReader.
        d = tf.data.Dataset.from_tensor_slices({
            "input_ids":
                tf.constant(
                    all_input_ids, shape=[num_examples, seq_length],
                    dtype=tf.int32),
            "input_mask":
                tf.constant(
                    all_input_mask,
                    shape=[num_examples, seq_length],
                    dtype=tf.int32),
            "segment_ids":
                tf.constant(
                    all_segment_ids,
                    shape=[num_examples, seq_length],
                    dtype=tf.int32),
            "label_ids":
                tf.constant(all_label_ids, shape=[num_examples], dtype=tf.int32),
        })

        if is_training:
            d = d.repeat()
            d = d.shuffle(buffer_size=100)

        d = d.batch(batch_size=batch_size, drop_remainder=drop_remainder)
        return d

    return input_fn
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)

    processors = {
        "cola": ColaProcessor,
        "mnli": MnliProcessor,
        "mrpc": MrpcProcessor,
    }

    if not FLAGS.do_train and not FLAGS.do_eval:
        raise ValueError("At least one of `do_train` or `do_eval` must be True.")

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    if FLAGS.max_seq_length > bert_config.max_position_embeddings:
        raise ValueError(
            "Cannot use sequence length %d because the BERT model "
            "was only trained up to sequence length %d" %
            (FLAGS.max_seq_length, bert_config.max_position_embeddings))

    tf.gfile.MakeDirs(FLAGS.output_dir)

    task_name = FLAGS.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()

    label_list = processor.get_labels()

    tokenizer = tokenization.FullTokenizer(
        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)

    tpu_cluster_resolver = None
    if FLAGS.use_tpu and FLAGS.tpu_name:
        tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

    is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
    run_config = tf.contrib.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        master=FLAGS.master,
        model_dir=FLAGS.output_dir,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps,
        tpu_config=tf.contrib.tpu.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop,
            num_shards=FLAGS.num_tpu_cores,
            per_host_input_for_training=is_per_host))

    train_examples = None
    num_train_steps = None
    num_warmup_steps = None
    if FLAGS.do_train:
        train_examples = processor.get_train_examples(FLAGS.data_dir)
        num_train_steps = int(
            len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs)
        num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

    model_fn = model_fn_builder(
        bert_config=bert_config,
        num_labels=len(label_list),
        init_checkpoint=FLAGS.init_checkpoint,
        learning_rate=FLAGS.learning_rate,
        num_train_steps=num_train_steps,
        num_warmup_steps=num_warmup_steps,
        use_tpu=FLAGS.use_tpu,
        use_one_hot_embeddings=FLAGS.use_tpu)

    # If TPU is not available, this will fall back to normal Estimator on CPU
    # or GPU.
    estimator = tf.contrib.tpu.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=model_fn,
        config=run_config,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size)

    if FLAGS.do_train:
        train_features = convert_examples_to_features(
            train_examples, label_list, FLAGS.max_seq_length, tokenizer)
        tf.logging.info("***** Running training *****")
        tf.logging.info("  Num examples = %d", len(train_examples))
        tf.logging.info("  Batch size = %d", FLAGS.train_batch_size)
        tf.logging.info("  Num steps = %d", num_train_steps)
        train_input_fn = input_fn_builder(
            features=train_features,
            seq_length=FLAGS.max_seq_length,
            is_training=True,
            drop_remainder=True)
        estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)

    if FLAGS.do_eval:
        eval_examples = processor.get_dev_examples(FLAGS.data_dir)
        eval_features = convert_examples_to_features(
            eval_examples, label_list, FLAGS.max_seq_length, tokenizer)

        tf.logging.info("***** Running evaluation *****")
        tf.logging.info("  Num examples = %d", len(eval_examples))
        tf.logging.info("  Batch size = %d", FLAGS.eval_batch_size)

        # This tells the estimator to run through the entire set.
        eval_steps = None
        # However, if running eval on the TPU, you will need to specify the
        # number of steps.
        if FLAGS.use_tpu:
            # Eval will be slightly WRONG on the TPU because it will truncate
            # the last batch.
            eval_steps = int(len(eval_examples) / FLAGS.eval_batch_size)

        eval_drop_remainder = True if FLAGS.use_tpu else False
        eval_input_fn = input_fn_builder(
            features=eval_features,
            seq_length=FLAGS.max_seq_length,
            is_training=False,
            drop_remainder=eval_drop_remainder)

        result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps)

        output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
        with tf.gfile.GFile(output_eval_file, "w") as writer:
            tf.logging.info("***** Eval results *****")
            for key in sorted(result.keys()):
                tf.logging.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))


if __name__ == "__main__":
    flags.mark_flag_as_required("data_dir")
    flags.mark_flag_as_required("task_name")
    flags.mark_flag_as_required("vocab_file")
    flags.mark_flag_as_required("bert_config_file")
    flags.mark_flag_as_required("output_dir")
    tf.app.run()
run_pretraining.py
View file @
8163baab
...
@@ -109,217 +109,217 @@ flags.DEFINE_integer(
def model_fn_builder(bert_config, init_checkpoint, learning_rate,
                     num_train_steps, num_warmup_steps, use_tpu,
                     use_one_hot_embeddings):
    """Returns `model_fn` closure for TPUEstimator."""

    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator."""

        tf.logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf.logging.info("  name = %s, shape = %s" % (name, features[name].shape))

        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]
        masked_lm_positions = features["masked_lm_positions"]
        masked_lm_ids = features["masked_lm_ids"]
        masked_lm_weights = features["masked_lm_weights"]
        next_sentence_labels = features["next_sentence_labels"]

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        model = modeling.BertModel(
            config=bert_config,
            is_training=is_training,
            input_ids=input_ids,
            input_mask=input_mask,
            token_type_ids=segment_ids,
            use_one_hot_embeddings=use_one_hot_embeddings)

        (masked_lm_loss,
         masked_lm_example_loss, masked_lm_log_probs) = get_masked_lm_output(
             bert_config, model.get_sequence_output(), model.get_embedding_table(),
             masked_lm_positions, masked_lm_ids, masked_lm_weights)

        (next_sentence_loss, next_sentence_example_loss,
         next_sentence_log_probs) = get_next_sentence_output(
             bert_config, model.get_pooled_output(), next_sentence_labels)

        total_loss = masked_lm_loss + next_sentence_loss

        tvars = tf.trainable_variables()

        initialized_variable_names = {}
        scaffold_fn = None
        if init_checkpoint:
            (assignment_map,
             initialized_variable_names) = modeling.get_assigment_map_from_checkpoint(
                 tvars, init_checkpoint)
            if use_tpu:

                def tpu_scaffold():
                    tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
                    return tf.train.Scaffold()

                scaffold_fn = tpu_scaffold
            else:
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        tf.logging.info("**** Trainable Variables ****")
        for var in tvars:
            init_string = ""
            if var.name in initialized_variable_names:
                init_string = ", *INIT_FROM_CKPT*"
            tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                            init_string)

        output_spec = None
        if mode == tf.estimator.ModeKeys.TRAIN:
            train_op = optimization.create_optimizer(
                total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu)

            output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                train_op=train_op,
                scaffold_fn=scaffold_fn)
        elif mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
                          masked_lm_weights, next_sentence_example_loss,
                          next_sentence_log_probs, next_sentence_labels):
                """Computes the loss and accuracy of the model."""
                masked_lm_log_probs = tf.reshape(masked_lm_log_probs,
                                                 [-1, masked_lm_log_probs.shape[-1]])
                masked_lm_predictions = tf.argmax(
                    masked_lm_log_probs, axis=-1, output_type=tf.int32)
                masked_lm_example_loss = tf.reshape(masked_lm_example_loss, [-1])
                masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
                masked_lm_weights = tf.reshape(masked_lm_weights, [-1])
                masked_lm_accuracy = tf.metrics.accuracy(
                    labels=masked_lm_ids,
                    predictions=masked_lm_predictions,
                    weights=masked_lm_weights)
                masked_lm_mean_loss = tf.metrics.mean(
                    values=masked_lm_example_loss, weights=masked_lm_weights)

                next_sentence_log_probs = tf.reshape(
                    next_sentence_log_probs, [-1, next_sentence_log_probs.shape[-1]])
                next_sentence_predictions = tf.argmax(
                    next_sentence_log_probs, axis=-1, output_type=tf.int32)
                next_sentence_labels = tf.reshape(next_sentence_labels, [-1])
                next_sentence_accuracy = tf.metrics.accuracy(
                    labels=next_sentence_labels, predictions=next_sentence_predictions)
                next_sentence_mean_loss = tf.metrics.mean(
                    values=next_sentence_example_loss)

                return {
                    "masked_lm_accuracy": masked_lm_accuracy,
                    "masked_lm_loss": masked_lm_mean_loss,
                    "next_sentence_accuracy": next_sentence_accuracy,
                    "next_sentence_loss": next_sentence_mean_loss,
                }

            eval_metrics = (metric_fn, [
                masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
                masked_lm_weights, next_sentence_example_loss,
                next_sentence_log_probs, next_sentence_labels
            ])
            output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                eval_metrics=eval_metrics,
                scaffold_fn=scaffold_fn)
        else:
            raise ValueError("Only TRAIN and EVAL modes are supported: %s" % (mode))

        return output_spec

    return model_fn


def get_masked_lm_output(bert_config, input_tensor, output_weights, positions,
                         label_ids, label_weights):
    """Get loss and log probs for the masked LM."""
    input_tensor = gather_indexes(input_tensor, positions)

    with tf.variable_scope("cls/predictions"):
        # We apply one more non-linear transformation before the output layer.
        # This matrix is not used after pre-training.
        with tf.variable_scope("transform"):
            input_tensor = tf.layers.dense(
                input_tensor,
                units=bert_config.hidden_size,
                activation=modeling.get_activation(bert_config.hidden_act),
                kernel_initializer=modeling.create_initializer(
                    bert_config.initializer_range))
            input_tensor = modeling.layer_norm(input_tensor)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        output_bias = tf.get_variable(
            "output_bias",
            shape=[bert_config.vocab_size],
            initializer=tf.zeros_initializer())
        logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
        logits = tf.nn.bias_add(logits, output_bias)
        log_probs = tf.nn.log_softmax(logits, axis=-1)

        label_ids = tf.reshape(label_ids, [-1])
        label_weights = tf.reshape(label_weights, [-1])

        one_hot_labels = tf.one_hot(
            label_ids, depth=bert_config.vocab_size, dtype=tf.float32)

        # The `positions` tensor might be zero-padded (if the sequence is too
        # short to have the maximum number of predictions). The `label_weights`
        # tensor has a value of 1.0 for every real prediction and 0.0 for the
        # padding predictions.
        per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1])
        numerator = tf.reduce_sum(label_weights * per_example_loss)
        denominator = tf.reduce_sum(label_weights) + 1e-5
        loss = numerator / denominator

    return (loss, per_example_loss, log_probs)


def get_next_sentence_output(bert_config, input_tensor, labels):
    """Get loss and log probs for the next sentence prediction."""

    # Simple binary classification. Note that 0 is "next sentence" and 1 is
    # "random sentence". This weight matrix is not used after pre-training.
    with tf.variable_scope("cls/seq_relationship"):
        output_weights = tf.get_variable(
            "output_weights",
            shape=[2, bert_config.hidden_size],
            initializer=modeling.create_initializer(bert_config.initializer_range))
        output_bias = tf.get_variable(
            "output_bias", shape=[2], initializer=tf.zeros_initializer())

        logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
        logits = tf.nn.bias_add(logits, output_bias)
        log_probs = tf.nn.log_softmax(logits, axis=-1)
        labels = tf.reshape(labels, [-1])
        one_hot_labels = tf.one_hot(labels, depth=2, dtype=tf.float32)
        per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
        loss = tf.reduce_mean(per_example_loss)
        return (loss, per_example_loss, log_probs)


def gather_indexes(sequence_tensor, positions):
    """Gathers the vectors at the specific positions over a minibatch."""
    sequence_shape = modeling.get_shape_list(sequence_tensor, expected_rank=3)
    batch_size = sequence_shape[0]
    seq_length = sequence_shape[1]
    width = sequence_shape[2]

    flat_offsets = tf.reshape(
        tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
    flat_positions = tf.reshape(positions + flat_offsets, [-1])
    flat_sequence_tensor = tf.reshape(sequence_tensor,
                                      [batch_size * seq_length, width])
    output_tensor = tf.gather(flat_sequence_tensor, flat_positions)
    return output_tensor
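To make the index arithmetic in `gather_indexes` concrete, here is a small NumPy stand-in for the same flattening (an illustration under assumed toy shapes, not code from the commit):

    # Illustration only: NumPy equivalent of the TF ops in gather_indexes.
    import numpy as np

    batch_size, seq_length, width = 2, 4, 3
    sequence = np.arange(batch_size * seq_length * width).reshape(
        (batch_size, seq_length, width))
    positions = np.array([[1, 3], [0, 2]])  # masked positions for each example

    # Offset each row's positions by its start index in the flattened
    # [batch * seq_length, width] tensor, then gather one row per position.
    flat_offsets = (np.arange(batch_size) * seq_length).reshape((-1, 1))
    flat_positions = (positions + flat_offsets).reshape(-1)      # [1, 3, 4, 6]
    flat_sequence = sequence.reshape((batch_size * seq_length, width))
    gathered = flat_sequence[flat_positions]  # shape (4, width)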
def input_fn_builder(input_files,
...
@@ -327,168 +327,168 @@ def input_fn_builder(input_files,
                     max_predictions_per_seq,
                     is_training,
                     num_cpu_threads=4):
    """Creates an `input_fn` closure to be passed to TPUEstimator."""

    def input_fn(params):
        """The actual input function."""
        batch_size = params["batch_size"]

        name_to_features = {
            "input_ids":
                tf.FixedLenFeature([max_seq_length], tf.int64),
            "input_mask":
                tf.FixedLenFeature([max_seq_length], tf.int64),
            "segment_ids":
                tf.FixedLenFeature([max_seq_length], tf.int64),
            "masked_lm_positions":
                tf.FixedLenFeature([max_predictions_per_seq], tf.int64),
            "masked_lm_ids":
                tf.FixedLenFeature([max_predictions_per_seq], tf.int64),
            "masked_lm_weights":
                tf.FixedLenFeature([max_predictions_per_seq], tf.float32),
            "next_sentence_labels":
                tf.FixedLenFeature([1], tf.int64),
        }

        # For training, we want a lot of parallel reading and shuffling.
        # For eval, we want no shuffling and parallel reading doesn't matter.
        if is_training:
            d = tf.data.Dataset.from_tensor_slices(tf.constant(input_files))
            d = d.repeat()
            d = d.shuffle(buffer_size=len(input_files))

            # `cycle_length` is the number of parallel files that get read.
            cycle_length = min(num_cpu_threads, len(input_files))

            # `sloppy` mode means that the interleaving is not exact. This adds
            # even more randomness to the training pipeline.
            d = d.apply(
                tf.contrib.data.parallel_interleave(
                    tf.data.TFRecordDataset,
                    sloppy=is_training,
                    cycle_length=cycle_length))
            d = d.shuffle(buffer_size=100)
        else:
            d = tf.data.TFRecordDataset(input_files)
            # Since we evaluate for a fixed number of steps we don't want to encounter
            # out-of-range exceptions.
            d = d.repeat()

        # We must `drop_remainder` on training because the TPU requires fixed
        # size dimensions. For eval, we assume we are evaling on the CPU or GPU
        # and we *don't* want to drop the remainder, otherwise we won't cover
        # every sample.
        d = d.apply(
            tf.contrib.data.map_and_batch(
                lambda record: _decode_record(record, name_to_features),
                batch_size=batch_size,
                num_parallel_batches=num_cpu_threads,
                drop_remainder=True))
        return d

    return input_fn


def _decode_record(record, name_to_features):
    """Decodes a record to a TensorFlow example."""
    example = tf.parse_single_example(record, name_to_features)

    # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
    # So cast all int64 to int32.
    for name in list(example.keys()):
        t = example[name]
        if t.dtype == tf.int64:
            t = tf.to_int32(t)
        example[name] = t

    return example


def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)

    if not FLAGS.do_train and not FLAGS.do_eval:
        raise ValueError("At least one of `do_train` or `do_eval` must be True.")

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    tf.gfile.MakeDirs(FLAGS.output_dir)

    input_files = []
    for input_pattern in FLAGS.input_file.split(","):
        input_files.extend(tf.gfile.Glob(input_pattern))

    tf.logging.info("*** Input Files ***")
    for input_file in input_files:
        tf.logging.info("  %s" % input_file)

    tpu_cluster_resolver = None
    if FLAGS.use_tpu and FLAGS.tpu_name:
        tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

    is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
    run_config = tf.contrib.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        master=FLAGS.master,
        model_dir=FLAGS.output_dir,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps,
        tpu_config=tf.contrib.tpu.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop,
            num_shards=FLAGS.num_tpu_cores,
            per_host_input_for_training=is_per_host))

    model_fn = model_fn_builder(
        bert_config=bert_config,
        init_checkpoint=FLAGS.init_checkpoint,
        learning_rate=FLAGS.learning_rate,
        num_train_steps=FLAGS.num_train_steps,
        num_warmup_steps=FLAGS.num_warmup_steps,
        use_tpu=FLAGS.use_tpu,
        use_one_hot_embeddings=FLAGS.use_tpu)

    # If TPU is not available, this will fall back to normal Estimator on CPU
    # or GPU.
    estimator = tf.contrib.tpu.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=model_fn,
        config=run_config,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size)

    if FLAGS.do_train:
        tf.logging.info("***** Running training *****")
        tf.logging.info("  Batch size = %d", FLAGS.train_batch_size)
        train_input_fn = input_fn_builder(
            input_files=input_files,
            max_seq_length=FLAGS.max_seq_length,
            max_predictions_per_seq=FLAGS.max_predictions_per_seq,
            is_training=True)
        estimator.train(input_fn=train_input_fn, max_steps=FLAGS.num_train_steps)

    if FLAGS.do_eval:
        tf.logging.info("***** Running evaluation *****")
        tf.logging.info("  Batch size = %d", FLAGS.eval_batch_size)

        eval_input_fn = input_fn_builder(
            input_files=input_files,
            max_seq_length=FLAGS.max_seq_length,
            max_predictions_per_seq=FLAGS.max_predictions_per_seq,
            is_training=False)

        result = estimator.evaluate(
            input_fn=eval_input_fn, steps=FLAGS.max_eval_steps)

        output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
        with tf.gfile.GFile(output_eval_file, "w") as writer:
            tf.logging.info("***** Eval results *****")
            for key in sorted(result.keys()):
                tf.logging.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))


if __name__ == "__main__":
    flags.mark_flag_as_required("input_file")
    flags.mark_flag_as_required("bert_config_file")
    flags.mark_flag_as_required("output_dir")
    tf.app.run()
run_squad.py
View file @
8163baab
...
@@ -146,562 +146,562 @@ flags.DEFINE_bool(
class SquadExample(object):
    """A single training/test example for simple sequence classification."""

    def __init__(self,
                 qas_id,
                 question_text,
                 doc_tokens,
                 orig_answer_text=None,
                 start_position=None,
                 end_position=None):
        self.qas_id = qas_id
        self.question_text = question_text
        self.doc_tokens = doc_tokens
        self.orig_answer_text = orig_answer_text
        self.start_position = start_position
        self.end_position = end_position

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        s = ""
        s += "qas_id: %s" % (tokenization.printable_text(self.qas_id))
        s += ", question_text: %s" % (
            tokenization.printable_text(self.question_text))
        s += ", doc_tokens: [%s]" % (" ".join(self.doc_tokens))
        if self.start_position:
            s += ", start_position: %d" % (self.start_position)
        if self.start_position:
            s += ", end_position: %d" % (self.end_position)
        return s


class InputFeatures(object):
    """A single set of features of data."""

    def __init__(self,
                 unique_id,
                 example_index,
                 doc_span_index,
                 tokens,
                 token_to_orig_map,
                 token_is_max_context,
                 input_ids,
                 input_mask,
                 segment_ids,
                 start_position=None,
                 end_position=None):
        self.unique_id = unique_id
        self.example_index = example_index
        self.doc_span_index = doc_span_index
        self.tokens = tokens
        self.token_to_orig_map = token_to_orig_map
        self.token_is_max_context = token_is_max_context
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.start_position = start_position
        self.end_position = end_position


def read_squad_examples(input_file, is_training):
    """Read a SQuAD json file into a list of SquadExample."""
    with tf.gfile.Open(input_file, "r") as reader:
        input_data = json.load(reader)["data"]

    def is_whitespace(c):
        if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F:
            return True
        return False

    examples = []
    for entry in input_data:
        for paragraph in entry["paragraphs"]:
            paragraph_text = paragraph["context"]
            doc_tokens = []
            char_to_word_offset = []
            prev_is_whitespace = True
            for c in paragraph_text:
                if is_whitespace(c):
                    prev_is_whitespace = True
                else:
                    if prev_is_whitespace:
                        doc_tokens.append(c)
                    else:
                        doc_tokens[-1] += c
                    prev_is_whitespace = False
                char_to_word_offset.append(len(doc_tokens) - 1)

            for qa in paragraph["qas"]:
                qas_id = qa["id"]
                question_text = qa["question"]
                start_position = None
                end_position = None
                orig_answer_text = None
                if is_training:
                    if len(qa["answers"]) != 1:
                        raise ValueError(
                            "For training, each question should have exactly 1 answer.")
                    answer = qa["answers"][0]
                    orig_answer_text = answer["text"]
                    answer_offset = answer["answer_start"]
                    answer_length = len(orig_answer_text)
                    start_position = char_to_word_offset[answer_offset]
                    end_position = char_to_word_offset[answer_offset + answer_length - 1]
                    # Only add answers where the text can be exactly recovered from the
                    # document. If this CAN'T happen it's likely due to weird Unicode
                    # stuff so we will just skip the example.
                    #
                    # Note that this means for training mode, every example is NOT
                    # guaranteed to be preserved.
                    actual_text = " ".join(doc_tokens[start_position:(end_position + 1)])
                    cleaned_answer_text = " ".join(
                        tokenization.whitespace_tokenize(orig_answer_text))
                    if actual_text.find(cleaned_answer_text) == -1:
                        tf.logging.warning("Could not find answer: '%s' vs. '%s'",
                                           actual_text, cleaned_answer_text)
                        continue

                example = SquadExample(
                    qas_id=qas_id,
                    question_text=question_text,
                    doc_tokens=doc_tokens,
                    orig_answer_text=orig_answer_text,
                    start_position=start_position,
                    end_position=end_position)
                examples.append(example)
    return examples


def convert_examples_to_features(examples, tokenizer, max_seq_length,
                                 doc_stride, max_query_length, is_training):
    """Loads a data file into a list of `InputBatch`s."""

    unique_id = 1000000000

    features = []
    for (example_index, example) in enumerate(examples):
        query_tokens = tokenizer.tokenize(example.question_text)

        if len(query_tokens) > max_query_length:
            query_tokens = query_tokens[0:max_query_length]

        tok_to_orig_index = []
        orig_to_tok_index = []
        all_doc_tokens = []
        for (i, token) in enumerate(example.doc_tokens):
            orig_to_tok_index.append(len(all_doc_tokens))
            sub_tokens = tokenizer.tokenize(token)
            for sub_token in sub_tokens:
                tok_to_orig_index.append(i)
                all_doc_tokens.append(sub_token)

        tok_start_position = None
        tok_end_position = None
        if is_training:
            tok_start_position = orig_to_tok_index[example.start_position]
            if example.end_position < len(example.doc_tokens) - 1:
                tok_end_position = orig_to_tok_index[example.end_position + 1] - 1
            else:
                tok_end_position = len(all_doc_tokens) - 1
            (tok_start_position, tok_end_position) = _improve_answer_span(
                all_doc_tokens, tok_start_position, tok_end_position, tokenizer,
                example.orig_answer_text)

        # The -3 accounts for [CLS], [SEP] and [SEP]
        max_tokens_for_doc = max_seq_length - len(query_tokens) - 3

        # We can have documents that are longer than the maximum sequence length.
        # To deal with this we do a sliding window approach, where we take chunks
        # of the up to our max length with a stride of `doc_stride`.
        _DocSpan = collections.namedtuple(  # pylint: disable=invalid-name
            "DocSpan", ["start", "length"])
        doc_spans = []
        start_offset = 0
        while start_offset < len(all_doc_tokens):
            length = len(all_doc_tokens) - start_offset
            if length > max_tokens_for_doc:
                length = max_tokens_for_doc
            doc_spans.append(_DocSpan(start=start_offset, length=length))
            if start_offset + length == len(all_doc_tokens):
                break
            start_offset += min(length, doc_stride)

        for (doc_span_index, doc_span) in enumerate(doc_spans):
            tokens = []
            token_to_orig_map = {}
            token_is_max_context = {}
            segment_ids = []
            tokens.append("[CLS]")
            segment_ids.append(0)
            for token in query_tokens:
                tokens.append(token)
                segment_ids.append(0)
            tokens.append("[SEP]")
            segment_ids.append(0)

            for i in range(doc_span.length):
                split_token_index = doc_span.start + i
                token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]

                is_max_context = _check_is_max_context(doc_spans, doc_span_index,
                                                       split_token_index)
                token_is_max_context[len(tokens)] = is_max_context
                tokens.append(all_doc_tokens[split_token_index])
                segment_ids.append(1)
            tokens.append("[SEP]")
            segment_ids.append(1)

            input_ids = tokenizer.convert_tokens_to_ids(tokens)

            # The mask has 1 for real tokens and 0 for padding tokens. Only real
            # tokens are attended to.
            input_mask = [1] * len(input_ids)

            # Zero-pad up to the sequence length.
            while len(input_ids) < max_seq_length:
                input_ids.append(0)
                input_mask.append(0)
                segment_ids.append(0)

            assert len(input_ids) == max_seq_length
            assert len(input_mask) == max_seq_length
            assert len(segment_ids) == max_seq_length

            start_position = None
            end_position = None
            if is_training:
                # For training, if our document chunk does not contain an annotation
                # we throw it out, since there is nothing to predict.
                doc_start = doc_span.start
                doc_end = doc_span.start + doc_span.length - 1
                if (example.start_position < doc_start or
                        example.end_position < doc_start or
                        example.start_position > doc_end or
                        example.end_position > doc_end):
                    continue

                doc_offset = len(query_tokens) + 2
                start_position = tok_start_position - doc_start + doc_offset
                end_position = tok_end_position - doc_start + doc_offset

            if example_index < 20:
                tf.logging.info("*** Example ***")
                tf.logging.info("unique_id: %s" % (unique_id))
                tf.logging.info("example_index: %s" % (example_index))
                tf.logging.info("doc_span_index: %s" % (doc_span_index))
                tf.logging.info("tokens: %s" % " ".join(
                    [tokenization.printable_text(x) for x in tokens]))
                tf.logging.info("token_to_orig_map: %s" % " ".join(
                    ["%d:%d" % (x, y) for (x, y) in six.iteritems(token_to_orig_map)]))
                tf.logging.info("token_is_max_context: %s" % " ".join([
                    "%d:%s" % (x, y) for (x, y) in six.iteritems(token_is_max_context)
                ]))
                tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
                tf.logging.info(
                    "input_mask: %s" % " ".join([str(x) for x in input_mask]))
                tf.logging.info(
                    "segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
                if is_training:
                    answer_text = " ".join(tokens[start_position:(end_position + 1)])
                    tf.logging.info("start_position: %d" % (start_position))
                    tf.logging.info("end_position: %d" % (end_position))
                    tf.logging.info(
                        "answer: %s" % (tokenization.printable_text(answer_text)))

            features.append(
                InputFeatures(
                    unique_id=unique_id,
                    example_index=example_index,
                    doc_span_index=doc_span_index,
                    tokens=tokens,
                    token_to_orig_map=token_to_orig_map,
                    token_is_max_context=token_is_max_context,
                    input_ids=input_ids,
                    input_mask=input_mask,
                    segment_ids=segment_ids,
                    start_position=start_position,
                    end_position=end_position))
            unique_id += 1

    return features
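The sliding-window enumeration above is the part most worth checking in isolation when tuning `doc_stride`; below is a plain-Python sketch of just that loop, with made-up numbers in the example call (an illustration, not code from the commit):

    import collections

    DocSpan = collections.namedtuple("DocSpan", ["start", "length"])

    def enumerate_doc_spans(num_doc_tokens, max_tokens_for_doc, doc_stride):
        # Mirrors the while-loop above: fixed-size chunks that advance by
        # doc_stride until the last document token is covered.
        doc_spans = []
        start_offset = 0
        while start_offset < num_doc_tokens:
            length = min(num_doc_tokens - start_offset, max_tokens_for_doc)
            doc_spans.append(DocSpan(start=start_offset, length=length))
            if start_offset + length == num_doc_tokens:
                break
            start_offset += min(length, doc_stride)
        return doc_spans

    # enumerate_doc_spans(500, 200, 128) ->
    # [DocSpan(start=0, length=200), DocSpan(start=128, length=200),
    #  DocSpan(start=256, length=200), DocSpan(start=384, length=116)]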
def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer,
                         orig_answer_text):
    """Returns tokenized answer spans that better match the annotated answer."""

    # The SQuAD annotations are character based. We first project them to
    # whitespace-tokenized words. But then after WordPiece tokenization, we can
    # often find a "better match". For example:
    #
    #   Question: What year was John Smith born?
    #   Context: The leader was John Smith (1895-1943).
    #   Answer: 1895
    #
    # The original whitespace-tokenized answer will be "(1895-1943).". However
    # after tokenization, our tokens will be "( 1895 - 1943 ) .". So we can match
    # the exact answer, 1895.
    #
    # However, this is not always possible. Consider the following:
    #
    #   Question: What country is the top exporter of electronics?
    #   Context: The Japanese electronics industry is the largest in the world.
    #   Answer: Japan
    #
    # In this case, the annotator chose "Japan" as a character sub-span of
    # the word "Japanese". Since our WordPiece tokenizer does not split
    # "Japanese", we just use "Japanese" as the annotation. This is fairly rare
    # in SQuAD, but does happen.
    tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text))

    for new_start in range(input_start, input_end + 1):
        for new_end in range(input_end, new_start - 1, -1):
            text_span = " ".join(doc_tokens[new_start:(new_end + 1)])
            if text_span == tok_answer_text:
                return (new_start, new_end)

    return (input_start, input_end)
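A toy illustration of the search above on the "(1895-1943)." example from the comment; the FakeTokenizer and doc_tokens here are stand-ins invented for this sketch, not part of the file.

    class FakeTokenizer(object):
        def tokenize(self, text):
            return [text]  # "1895" stays a single piece in this toy setup

    doc_tokens = ["(", "1895", "-", "1943", ")", "."]   # already WordPiece-split
    tok_answer = " ".join(FakeTokenizer().tokenize("1895"))

    best = (0, 5)  # start from the whole whitespace-token span "(1895-1943)."
    for new_start in range(0, 6):
        for new_end in range(5, new_start - 1, -1):
            if " ".join(doc_tokens[new_start:new_end + 1]) == tok_answer:
                best = (new_start, new_end)
    print(best)  # (1, 1): the span is tightened to just the token "1895"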
def _check_is_max_context(doc_spans, cur_span_index, position):
    """Check if this is the 'max context' doc span for the token."""

    # Because of the sliding window approach taken to scoring documents, a single
    # token can appear in multiple documents. E.g.
    #  Doc: the man went to the store and bought a gallon of milk
    #  Span A: the man went to the
    #  Span B: to the store and bought
    #  Span C: and bought a gallon of
    #  ...
    #
    # Now the word 'bought' will have two scores from spans B and C. We only
    # want to consider the score with "maximum context", which we define as
    # the *minimum* of its left and right context (the *sum* of left and
    # right context will always be the same, of course).
    #
    # In the example the maximum context for 'bought' would be span C since
    # it has 1 left context and 3 right context, while span B has 4 left context
    # and 0 right context.
    best_score = None
    best_span_index = None
    for (span_index, doc_span) in enumerate(doc_spans):
        end = doc_span.start + doc_span.length - 1
        if position < doc_span.start:
            continue
        if position > end:
            continue
        num_left_context = position - doc_span.start
        num_right_context = end - position
        score = min(num_left_context, num_right_context) + 0.01 * doc_span.length
        if best_score is None or score > best_score:
            best_score = score
            best_span_index = span_index

    return cur_span_index == best_span_index
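A minimal sketch that reproduces the scoring arithmetic from the comment above for the word 'bought' (position 7 of the example document); the DocSpan namedtuple and the span start/length values are assumptions made for this illustration.

    import collections

    DocSpan = collections.namedtuple("DocSpan", ["start", "length"])
    span_b = DocSpan(start=3, length=5)   # "to the store and bought"
    span_c = DocSpan(start=6, length=5)   # "and bought a gallon of"

    def max_context_score(doc_span, position):
        end = doc_span.start + doc_span.length - 1
        left = position - doc_span.start
        right = end - position
        return min(left, right) + 0.01 * doc_span.length

    print(max_context_score(span_b, 7))  # 0.05: min(4 left, 0 right) + tiebreak
    print(max_context_score(span_c, 7))  # 1.05: min(1 left, 3 right) + tiebreak -> span C wins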
def create_model(bert_config, is_training, input_ids, input_mask, segment_ids,
                 use_one_hot_embeddings):
    """Creates a classification model."""
    model = modeling.BertModel(
        config=bert_config,
        is_training=is_training,
        input_ids=input_ids,
        input_mask=input_mask,
        token_type_ids=segment_ids,
        use_one_hot_embeddings=use_one_hot_embeddings)

    final_hidden = model.get_sequence_output()

    final_hidden_shape = modeling.get_shape_list(final_hidden, expected_rank=3)
    batch_size = final_hidden_shape[0]
    seq_length = final_hidden_shape[1]
    hidden_size = final_hidden_shape[2]

    output_weights = tf.get_variable(
        "cls/squad/output_weights", [2, hidden_size],
        initializer=tf.truncated_normal_initializer(stddev=0.02))

    output_bias = tf.get_variable(
        "cls/squad/output_bias", [2], initializer=tf.zeros_initializer())

    final_hidden_matrix = tf.reshape(final_hidden,
                                     [batch_size * seq_length, hidden_size])
    logits = tf.matmul(final_hidden_matrix, output_weights, transpose_b=True)
    logits = tf.nn.bias_add(logits, output_bias)

    logits = tf.reshape(logits, [batch_size, seq_length, 2])
    logits = tf.transpose(logits, [2, 0, 1])

    unstacked_logits = tf.unstack(logits, axis=0)

    (start_logits, end_logits) = (unstacked_logits[0], unstacked_logits[1])

    return (start_logits, end_logits)
def model_fn_builder(bert_config, init_checkpoint, learning_rate,
                     num_train_steps, num_warmup_steps, use_tpu,
                     use_one_hot_embeddings):
    """Returns `model_fn` closure for TPUEstimator."""

    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator."""

        tf.logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf.logging.info("  name = %s, shape = %s" % (name, features[name].shape))

        unique_ids = features["unique_ids"]
        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        (start_logits, end_logits) = create_model(
            bert_config=bert_config,
            is_training=is_training,
            input_ids=input_ids,
            input_mask=input_mask,
            segment_ids=segment_ids,
            use_one_hot_embeddings=use_one_hot_embeddings)

        tvars = tf.trainable_variables()

        initialized_variable_names = {}
        scaffold_fn = None
        if init_checkpoint:
            (assignment_map, initialized_variable_names
            ) = modeling.get_assigment_map_from_checkpoint(tvars, init_checkpoint)
            if use_tpu:

                def tpu_scaffold():
                    tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
                    return tf.train.Scaffold()

                scaffold_fn = tpu_scaffold
            else:
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        tf.logging.info("**** Trainable Variables ****")
        for var in tvars:
            init_string = ""
            if var.name in initialized_variable_names:
                init_string = ", *INIT_FROM_CKPT*"
            tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                            init_string)

        output_spec = None
        if mode == tf.estimator.ModeKeys.TRAIN:
            seq_length = modeling.get_shape_list(input_ids)[1]

            def compute_loss(logits, positions):
                one_hot_positions = tf.one_hot(
                    positions, depth=seq_length, dtype=tf.float32)
                log_probs = tf.nn.log_softmax(logits, axis=-1)
                loss = -tf.reduce_mean(
                    tf.reduce_sum(one_hot_positions * log_probs, axis=-1))
                return loss

            start_positions = features["start_positions"]
            end_positions = features["end_positions"]

            start_loss = compute_loss(start_logits, start_positions)
            end_loss = compute_loss(end_logits, end_positions)

            total_loss = (start_loss + end_loss) / 2.0

            train_op = optimization.create_optimizer(
                total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu)

            output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                train_op=train_op,
                scaffold_fn=scaffold_fn)
        elif mode == tf.estimator.ModeKeys.PREDICT:
            predictions = {
                "unique_ids": unique_ids,
                "start_logits": start_logits,
                "end_logits": end_logits,
            }
            output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode, predictions=predictions, scaffold_fn=scaffold_fn)
        else:
            raise ValueError(
                "Only TRAIN and PREDICT modes are supported: %s" % (mode))

        return output_spec

    return model_fn


def input_fn_builder(features, seq_length, is_training, drop_remainder):
    """Creates an `input_fn` closure to be passed to TPUEstimator."""

    all_unique_ids = []
    all_input_ids = []
    all_input_mask = []
    all_segment_ids = []
    all_start_positions = []
    all_end_positions = []

    for feature in features:
        all_unique_ids.append(feature.unique_id)
        all_input_ids.append(feature.input_ids)
        all_input_mask.append(feature.input_mask)
        all_segment_ids.append(feature.segment_ids)
        if is_training:
            all_start_positions.append(feature.start_position)
            all_end_positions.append(feature.end_position)

    def input_fn(params):
        """The actual input function."""
        batch_size = params["batch_size"]

        num_examples = len(features)

        # This is for demo purposes and does NOT scale to large data sets. We do
        # not use Dataset.from_generator() because that uses tf.py_func which is
        # not TPU compatible. The right way to load data is with TFRecordReader.
        feature_map = {
            "unique_ids":
                tf.constant(all_unique_ids, shape=[num_examples], dtype=tf.int32),
            "input_ids":
                tf.constant(
                    all_input_ids, shape=[num_examples, seq_length],
                    dtype=tf.int32),
            "input_mask":
                tf.constant(
                    all_input_mask,
                    shape=[num_examples, seq_length],
                    dtype=tf.int32),
            "segment_ids":
                tf.constant(
                    all_segment_ids,
                    shape=[num_examples, seq_length],
                    dtype=tf.int32),
        }

        if is_training:
            feature_map["start_positions"] = tf.constant(
                all_start_positions, shape=[num_examples], dtype=tf.int32)
            feature_map["end_positions"] = tf.constant(
                all_end_positions, shape=[num_examples], dtype=tf.int32)

        d = tf.data.Dataset.from_tensor_slices(feature_map)

        if is_training:
            d = d.repeat()
            d = d.shuffle(buffer_size=100)

        d = d.batch(batch_size=batch_size, drop_remainder=drop_remainder)
        return d

    return input_fn
RawResult = collections.namedtuple("RawResult",
...
@@ -711,410 +711,410 @@ RawResult = collections.namedtuple("RawResult",
def write_predictions(all_examples, all_features, all_results, n_best_size,
                      max_answer_length, do_lower_case, output_prediction_file,
                      output_nbest_file):
    """Write final predictions to the json file."""
    tf.logging.info("Writing predictions to: %s" % (output_prediction_file))
    tf.logging.info("Writing nbest to: %s" % (output_nbest_file))

    example_index_to_features = collections.defaultdict(list)
    for feature in all_features:
        example_index_to_features[feature.example_index].append(feature)

    unique_id_to_result = {}
    for result in all_results:
        unique_id_to_result[result.unique_id] = result

    _PrelimPrediction = collections.namedtuple(  # pylint: disable=invalid-name
        "PrelimPrediction",
        ["feature_index", "start_index", "end_index", "start_logit", "end_logit"])

    all_predictions = collections.OrderedDict()
    all_nbest_json = collections.OrderedDict()
    for (example_index, example) in enumerate(all_examples):
        features = example_index_to_features[example_index]

        prelim_predictions = []
        for (feature_index, feature) in enumerate(features):
            result = unique_id_to_result[feature.unique_id]

            start_indexes = _get_best_indexes(result.start_logits, n_best_size)
            end_indexes = _get_best_indexes(result.end_logits, n_best_size)
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # We could hypothetically create invalid predictions, e.g., predict
                    # that the start of the span is in the question. We throw out all
                    # invalid predictions.
                    if start_index >= len(feature.tokens):
                        continue
                    if end_index >= len(feature.tokens):
                        continue
                    if start_index not in feature.token_to_orig_map:
                        continue
                    if end_index not in feature.token_to_orig_map:
                        continue
                    if not feature.token_is_max_context.get(start_index, False):
                        continue
                    if end_index < start_index:
                        continue
                    length = end_index - start_index + 1
                    if length > max_answer_length:
                        continue
                    prelim_predictions.append(
                        _PrelimPrediction(
                            feature_index=feature_index,
                            start_index=start_index,
                            end_index=end_index,
                            start_logit=result.start_logits[start_index],
                            end_logit=result.end_logits[end_index]))

        prelim_predictions = sorted(
            prelim_predictions,
            key=lambda x: (x.start_logit + x.end_logit),
            reverse=True)

        _NbestPrediction = collections.namedtuple(  # pylint: disable=invalid-name
            "NbestPrediction", ["text", "start_logit", "end_logit"])

        seen_predictions = {}
        nbest = []
        for pred in prelim_predictions:
            if len(nbest) >= n_best_size:
                break
            feature = features[pred.feature_index]

            tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)]
            orig_doc_start = feature.token_to_orig_map[pred.start_index]
            orig_doc_end = feature.token_to_orig_map[pred.end_index]
            orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)]
            tok_text = " ".join(tok_tokens)

            # De-tokenize WordPieces that have been split off.
            tok_text = tok_text.replace(" ##", "")
            tok_text = tok_text.replace("##", "")

            # Clean whitespace
            tok_text = tok_text.strip()
            tok_text = " ".join(tok_text.split())
            orig_text = " ".join(orig_tokens)

            final_text = get_final_text(tok_text, orig_text, do_lower_case)
            if final_text in seen_predictions:
                continue

            seen_predictions[final_text] = True
            nbest.append(
                _NbestPrediction(
                    text=final_text,
                    start_logit=pred.start_logit,
                    end_logit=pred.end_logit))

        # In very rare edge cases we could have no valid predictions. So we
        # just create a nonce prediction in this case to avoid failure.
        if not nbest:
            nbest.append(
                _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0))

        assert len(nbest) >= 1

        total_scores = []
        for entry in nbest:
            total_scores.append(entry.start_logit + entry.end_logit)

        probs = _compute_softmax(total_scores)

        nbest_json = []
        for (i, entry) in enumerate(nbest):
            output = collections.OrderedDict()
            output["text"] = entry.text
            output["probability"] = probs[i]
            output["start_logit"] = entry.start_logit
            output["end_logit"] = entry.end_logit
            nbest_json.append(output)

        assert len(nbest_json) >= 1

        all_predictions[example.qas_id] = nbest_json[0]["text"]
        all_nbest_json[example.qas_id] = nbest_json

    with tf.gfile.GFile(output_prediction_file, "w") as writer:
        writer.write(json.dumps(all_predictions, indent=4) + "\n")

    with tf.gfile.GFile(output_nbest_file, "w") as writer:
        writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
def get_final_text(pred_text, orig_text, do_lower_case):
    """Project the tokenized prediction back to the original text."""

    # When we created the data, we kept track of the alignment between original
    # (whitespace tokenized) tokens and our WordPiece tokenized tokens. So
    # now `orig_text` contains the span of our original text corresponding to the
    # span that we predicted.
    #
    # However, `orig_text` may contain extra characters that we don't want in
    # our prediction.
    #
    # For example, let's say:
    #   pred_text = steve smith
    #   orig_text = Steve Smith's
    #
    # We don't want to return `orig_text` because it contains the extra "'s".
    #
    # We don't want to return `pred_text` because it's already been normalized
    # (the SQuAD eval script also does punctuation stripping/lower casing but
    # our tokenizer does additional normalization like stripping accent
    # characters).
    #
    # What we really want to return is "Steve Smith".
    #
    # Therefore, we have to apply a semi-complicated alignment heuristic between
    # `pred_text` and `orig_text` to get a character-to-character alignment. This
    # can fail in certain cases in which case we just return `orig_text`.

    def _strip_spaces(text):
        ns_chars = []
        ns_to_s_map = collections.OrderedDict()
        for (i, c) in enumerate(text):
            if c == " ":
                continue
            ns_to_s_map[len(ns_chars)] = i
            ns_chars.append(c)
        ns_text = "".join(ns_chars)
        return (ns_text, ns_to_s_map)

    # We first tokenize `orig_text`, strip whitespace from the result
    # and `pred_text`, and check if they are the same length. If they are
    # NOT the same length, the heuristic has failed. If they are the same
    # length, we assume the characters are one-to-one aligned.
    tokenizer = tokenization.BasicTokenizer(do_lower_case=do_lower_case)

    tok_text = " ".join(tokenizer.tokenize(orig_text))

    start_position = tok_text.find(pred_text)
    if start_position == -1:
        if FLAGS.verbose_logging:
            tf.logging.info(
                "Unable to find text: '%s' in '%s'" % (pred_text, orig_text))
        return orig_text
    end_position = start_position + len(pred_text) - 1

    (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text)
    (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text)

    if len(orig_ns_text) != len(tok_ns_text):
        if FLAGS.verbose_logging:
            tf.logging.info("Length not equal after stripping spaces: '%s' vs '%s'",
                            orig_ns_text, tok_ns_text)
        return orig_text

    # We then project the characters in `pred_text` back to `orig_text` using
    # the character-to-character alignment.
    tok_s_to_ns_map = {}
    for (i, tok_index) in six.iteritems(tok_ns_to_s_map):
        tok_s_to_ns_map[tok_index] = i

    orig_start_position = None
    if start_position in tok_s_to_ns_map:
        ns_start_position = tok_s_to_ns_map[start_position]
        if ns_start_position in orig_ns_to_s_map:
            orig_start_position = orig_ns_to_s_map[ns_start_position]

    if orig_start_position is None:
        if FLAGS.verbose_logging:
            tf.logging.info("Couldn't map start position")
        return orig_text

    orig_end_position = None
    if end_position in tok_s_to_ns_map:
        ns_end_position = tok_s_to_ns_map[end_position]
        if ns_end_position in orig_ns_to_s_map:
            orig_end_position = orig_ns_to_s_map[ns_end_position]

    if orig_end_position is None:
        if FLAGS.verbose_logging:
            tf.logging.info("Couldn't map end position")
        return orig_text

    output_text = orig_text[orig_start_position:(orig_end_position + 1)]
    return output_text
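A minimal, self-contained sketch of the projection idea on the "Steve Smith's" example from the comment above. It re-implements only the whitespace-stripping alignment, not the full function; the strip_spaces helper and the pre-tokenized tok_text string are assumptions made for this sketch.

    import collections

    def strip_spaces(text):
        # Map positions in the space-free string back to positions in `text`.
        ns_chars, ns_to_s = [], collections.OrderedDict()
        for i, c in enumerate(text):
            if c == " ":
                continue
            ns_to_s[len(ns_chars)] = i
            ns_chars.append(c)
        return "".join(ns_chars), ns_to_s

    orig_text = "Steve Smith's"
    tok_text = "steve smith ' s"      # what a lower-casing BasicTokenizer would produce
    pred_text = "steve smith"

    start = tok_text.find(pred_text)               # 0
    end = start + len(pred_text) - 1               # 10
    orig_ns, orig_map = strip_spaces(orig_text)    # "SteveSmith's"
    tok_ns, tok_map = strip_spaces(tok_text)       # "stevesmith's", same length
    tok_s_to_ns = {v: k for k, v in tok_map.items()}

    orig_start = orig_map[tok_s_to_ns[start]]
    orig_end = orig_map[tok_s_to_ns[end]]
    print(orig_text[orig_start:orig_end + 1])      # "Steve Smith"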
def _get_best_indexes(logits, n_best_size):
    """Get the n-best logits from a list."""
    index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)

    best_indexes = []
    for i in range(len(index_and_score)):
        if i >= n_best_size:
            break
        best_indexes.append(index_and_score[i][0])
    return best_indexes
def _compute_softmax(scores):
    """Compute softmax probability over raw logits."""
    if not scores:
        return []

    max_score = None
    for score in scores:
        if max_score is None or score > max_score:
            max_score = score

    exp_scores = []
    total_sum = 0.0
    for score in scores:
        x = math.exp(score - max_score)
        exp_scores.append(x)
        total_sum += x

    probs = []
    for score in exp_scores:
        probs.append(score / total_sum)
    return probs
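A quick sanity check of the max-subtraction trick used above: shifting all logits by a constant does not change the softmax output, but it keeps math.exp() from overflowing on large logits. The logit values here are made up for illustration.

    import math

    def softmax(scores):
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        return [e / total for e in exps]

    print(softmax([1000.0, 1001.0, 1002.0]))  # ~[0.09, 0.24, 0.67], no overflow
    print(softmax([0.0, 1.0, 2.0]))           # same distribution from shifted logits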
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)

    if not FLAGS.do_train and not FLAGS.do_predict:
        raise ValueError("At least one of `do_train` or `do_predict` must be True.")

    if FLAGS.do_train:
        if not FLAGS.train_file:
            raise ValueError(
                "If `do_train` is True, then `train_file` must be specified.")
    if FLAGS.do_predict:
        if not FLAGS.predict_file:
            raise ValueError(
                "If `do_predict` is True, then `predict_file` must be specified.")

    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    if FLAGS.max_seq_length > bert_config.max_position_embeddings:
        raise ValueError(
            "Cannot use sequence length %d because the BERT model "
            "was only trained up to sequence length %d" %
            (FLAGS.max_seq_length, bert_config.max_position_embeddings))

    tf.gfile.MakeDirs(FLAGS.output_dir)

    tokenizer = tokenization.FullTokenizer(
        vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)

    tpu_cluster_resolver = None
    if FLAGS.use_tpu and FLAGS.tpu_name:
        tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

    is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
    run_config = tf.contrib.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        master=FLAGS.master,
        model_dir=FLAGS.output_dir,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps,
        tpu_config=tf.contrib.tpu.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop,
            num_shards=FLAGS.num_tpu_cores,
            per_host_input_for_training=is_per_host))

    train_examples = None
    num_train_steps = None
    num_warmup_steps = None
    if FLAGS.do_train:
        train_examples = read_squad_examples(
            input_file=FLAGS.train_file, is_training=True)
        num_train_steps = int(
            len(train_examples) / FLAGS.train_batch_size * FLAGS.num_train_epochs)
        num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

    model_fn = model_fn_builder(
        bert_config=bert_config,
        init_checkpoint=FLAGS.init_checkpoint,
        learning_rate=FLAGS.learning_rate,
        num_train_steps=num_train_steps,
        num_warmup_steps=num_warmup_steps,
        use_tpu=FLAGS.use_tpu,
        use_one_hot_embeddings=FLAGS.use_tpu)

    # If TPU is not available, this will fall back to normal Estimator on CPU
    # or GPU.
    estimator = tf.contrib.tpu.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=model_fn,
        config=run_config,
        train_batch_size=FLAGS.train_batch_size,
        predict_batch_size=FLAGS.predict_batch_size)

    if FLAGS.do_train:
        train_features = convert_examples_to_features(
            examples=train_examples,
            tokenizer=tokenizer,
            max_seq_length=FLAGS.max_seq_length,
            doc_stride=FLAGS.doc_stride,
            max_query_length=FLAGS.max_query_length,
            is_training=True)
        tf.logging.info("***** Running training *****")
        tf.logging.info("  Num orig examples = %d", len(train_examples))
        tf.logging.info("  Num split examples = %d", len(train_features))
        tf.logging.info("  Batch size = %d", FLAGS.train_batch_size)
        tf.logging.info("  Num steps = %d", num_train_steps)
        train_input_fn = input_fn_builder(
            features=train_features,
            seq_length=FLAGS.max_seq_length,
            is_training=True,
            drop_remainder=True)
        estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)

    if FLAGS.do_predict:
        eval_examples = read_squad_examples(
            input_file=FLAGS.predict_file, is_training=False)
        eval_features = convert_examples_to_features(
            examples=eval_examples,
            tokenizer=tokenizer,
            max_seq_length=FLAGS.max_seq_length,
            doc_stride=FLAGS.doc_stride,
            max_query_length=FLAGS.max_query_length,
            is_training=False)

        tf.logging.info("***** Running predictions *****")
        tf.logging.info("  Num orig examples = %d", len(eval_examples))
        tf.logging.info("  Num split examples = %d", len(eval_features))
        tf.logging.info("  Batch size = %d", FLAGS.predict_batch_size)

        all_results = []

        predict_input_fn = input_fn_builder(
            features=eval_features,
            seq_length=FLAGS.max_seq_length,
            is_training=False,
            drop_remainder=False)

        # If running eval on the TPU, you will need to specify the number of
        # steps.
        all_results = []
        for result in estimator.predict(
                predict_input_fn, yield_single_examples=True):
            if len(all_results) % 1000 == 0:
                tf.logging.info("Processing example: %d" % (len(all_results)))
            unique_id = int(result["unique_ids"])
            start_logits = [float(x) for x in result["start_logits"].flat]
            end_logits = [float(x) for x in result["end_logits"].flat]
            all_results.append(
                RawResult(
                    unique_id=unique_id,
                    start_logits=start_logits,
                    end_logits=end_logits))

        output_prediction_file = os.path.join(FLAGS.output_dir, "predictions.json")
        output_nbest_file = os.path.join(FLAGS.output_dir, "nbest_predictions.json")

        write_predictions(eval_examples, eval_features, all_results,
                          FLAGS.n_best_size, FLAGS.max_answer_length,
                          FLAGS.do_lower_case, output_prediction_file,
                          output_nbest_file)


if __name__ == "__main__":
    flags.mark_flag_as_required("vocab_file")
    flags.mark_flag_as_required("bert_config_file")
    flags.mark_flag_as_required("output_dir")
    tf.app.run()
tokenization.py
View file @ 8163baab
...
@@ -25,268 +25,268 @@ import tensorflow as tf
def convert_to_unicode(text):
    """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
    if six.PY3:
        if isinstance(text, str):
            return text
        elif isinstance(text, bytes):
            return text.decode("utf-8", "ignore")
        else:
            raise ValueError("Unsupported string type: %s" % (type(text)))
    elif six.PY2:
        if isinstance(text, str):
            return text.decode("utf-8", "ignore")
        elif isinstance(text, unicode):
            return text
        else:
            raise ValueError("Unsupported string type: %s" % (type(text)))
    else:
        raise ValueError("Not running on Python2 or Python 3?")
def printable_text(text):
    """Returns text encoded in a way suitable for print or `tf.logging`."""

    # These functions want `str` for both Python2 and Python3, but in one case
    # it's a Unicode string and in the other it's a byte string.
    if six.PY3:
        if isinstance(text, str):
            return text
        elif isinstance(text, bytes):
            return text.decode("utf-8", "ignore")
        else:
            raise ValueError("Unsupported string type: %s" % (type(text)))
    elif six.PY2:
        if isinstance(text, str):
            return text
        elif isinstance(text, unicode):
            return text.encode("utf-8")
        else:
            raise ValueError("Unsupported string type: %s" % (type(text)))
    else:
        raise ValueError("Not running on Python2 or Python 3?")
def load_vocab(vocab_file):
    """Loads a vocabulary file into a dictionary."""
    vocab = collections.OrderedDict()
    index = 0
    with tf.gfile.GFile(vocab_file, "r") as reader:
        while True:
            token = convert_to_unicode(reader.readline())
            if not token:
                break
            token = token.strip()
            vocab[token] = index
            index += 1
    return vocab
def convert_tokens_to_ids(vocab, tokens):
    """Converts a sequence of tokens into ids using the vocab."""
    ids = []
    for token in tokens:
        ids.append(vocab[token])
    return ids
def whitespace_tokenize(text):
    """Runs basic whitespace cleaning and splitting on a piece of text."""
    text = text.strip()
    if not text:
        return []
    tokens = text.split()
    return tokens
class FullTokenizer(object):
    """Runs end-to-end tokenization."""

    def __init__(self, vocab_file, do_lower_case=True):
        self.vocab = load_vocab(vocab_file)
        self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)

    def tokenize(self, text):
        split_tokens = []
        for token in self.basic_tokenizer.tokenize(text):
            for sub_token in self.wordpiece_tokenizer.tokenize(token):
                split_tokens.append(sub_token)

        return split_tokens

    def convert_tokens_to_ids(self, tokens):
        return convert_tokens_to_ids(self.vocab, tokens)
class BasicTokenizer(object):
    """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""

    def __init__(self, do_lower_case=True):
        """Constructs a BasicTokenizer.

        Args:
            do_lower_case: Whether to lower case the input.
        """
        self.do_lower_case = do_lower_case

    def tokenize(self, text):
        """Tokenizes a piece of text."""
        text = convert_to_unicode(text)
        text = self._clean_text(text)
        orig_tokens = whitespace_tokenize(text)
        split_tokens = []
        for token in orig_tokens:
            if self.do_lower_case:
                token = token.lower()
                token = self._run_strip_accents(token)
            split_tokens.extend(self._run_split_on_punc(token))

        output_tokens = whitespace_tokenize(" ".join(split_tokens))
        return output_tokens

    def _run_strip_accents(self, text):
        """Strips accents from a piece of text."""
        text = unicodedata.normalize("NFD", text)
        output = []
        for char in text:
            cat = unicodedata.category(char)
            if cat == "Mn":
                continue
            output.append(char)
        return "".join(output)

    def _run_split_on_punc(self, text):
        """Splits punctuation on a piece of text."""
        chars = list(text)
        i = 0
        start_new_word = True
        output = []
        while i < len(chars):
            char = chars[i]
            if _is_punctuation(char):
                output.append([char])
                start_new_word = True
            else:
                if start_new_word:
                    output.append([])
                start_new_word = False
                output[-1].append(char)
            i += 1

        return ["".join(x) for x in output]

    def _clean_text(self, text):
        """Performs invalid character removal and whitespace cleanup on text."""
        output = []
        for char in text:
            cp = ord(char)
            if cp == 0 or cp == 0xfffd or _is_control(char):
                continue
            if _is_whitespace(char):
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)
class WordpieceTokenizer(object):
    """Runs WordPiece tokenization."""

    def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100):
        self.vocab = vocab
        self.unk_token = unk_token
        self.max_input_chars_per_word = max_input_chars_per_word

    def tokenize(self, text):
        """Tokenizes a piece of text into its word pieces.

        This uses a greedy longest-match-first algorithm to perform tokenization
        using the given vocabulary.

        For example:
            input = "unaffable"
            output = ["un", "##aff", "##able"]

        Args:
            text: A single token or whitespace separated tokens. This should have
                already been passed through `BasicTokenizer`.

        Returns:
            A list of wordpiece tokens.
        """

        text = convert_to_unicode(text)

        output_tokens = []
        for token in whitespace_tokenize(text):
            chars = list(token)
            if len(chars) > self.max_input_chars_per_word:
                output_tokens.append(self.unk_token)
                continue

            is_bad = False
            start = 0
            sub_tokens = []
            while start < len(chars):
                end = len(chars)
                cur_substr = None
                while start < end:
                    substr = "".join(chars[start:end])
                    if start > 0:
                        substr = "##" + substr
                    if substr in self.vocab:
                        cur_substr = substr
                        break
                    end -= 1
                if cur_substr is None:
                    is_bad = True
                    break
                sub_tokens.append(cur_substr)
                start = end

            if is_bad:
                output_tokens.append(self.unk_token)
            else:
                output_tokens.extend(sub_tokens)
        return output_tokens
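A short usage sketch of the class above; the toy vocabulary mirrors the one built in tokenization_test.py rather than a real released vocab file:

# Illustrative only: a real run would load the vocab.txt shipped with a BERT checkpoint.
vocab = {token: i for i, token in enumerate(
    ["[UNK]", "un", "##aff", "##able", "runn", "##ing"])}
tokenizer = WordpieceTokenizer(vocab=vocab)

tokenizer.tokenize("unaffable running")   # -> ["un", "##aff", "##able", "runn", "##ing"]
tokenizer.tokenize("unaffableX running")  # -> ["[UNK]", "runn", "##ing"]; "##X" never matches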
def _is_whitespace(char):
    """Checks whether `char` is a whitespace character."""
    # \t, \n, and \r are technically control characters but we treat them
    # as whitespace since they are generally considered as such.
    if char == " " or char == "\t" or char == "\n" or char == "\r":
        return True
    cat = unicodedata.category(char)
    if cat == "Zs":
        return True
    return False
def _is_control(char):
    """Checks whether `char` is a control character."""
    # These are technically control characters but we count them as whitespace
    # characters.
    if char == "\t" or char == "\n" or char == "\r":
        return False
    cat = unicodedata.category(char)
    if cat.startswith("C"):
        return True
    return False
def _is_punctuation(char):
    """Checks whether `char` is a punctuation character."""
    cp = ord(char)
    # We treat all non-letter/number ASCII as punctuation.
    # Characters such as "^", "$", and "`" are not in the Unicode
    # Punctuation class but we treat them as punctuation anyways, for
    # consistency.
    if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
            (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
        return True
    cat = unicodedata.category(char)
    if cat.startswith("P"):
        return True
    return False
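A few illustrative checks of the three helpers above (they mirror assertions in tokenization_test.py further down):

assert _is_whitespace(u"\u00A0")   # non-breaking space has Unicode category Zs
assert _is_control(u"\u0005")      # ENQ is a control character (category Cc)
assert not _is_control(u"\t")      # tab is deliberately classified as whitespace, not control
assert _is_punctuation(u"$")       # outside Unicode category "P", caught by the ASCII ranges
assert not _is_punctuation(u"A")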
tokenization_test.py
View file @ 8163baab
...
@@ -25,101 +25,101 @@ import tensorflow as tf
class TokenizationTest(tf.test.TestCase):

    def test_full_tokenizer(self):
        vocab_tokens = [
            "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
            "##ing", ","
        ]
        with tempfile.NamedTemporaryFile(delete=False) as vocab_writer:
            vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))

            vocab_file = vocab_writer.name

        tokenizer = tokenization.FullTokenizer(vocab_file)
        os.unlink(vocab_file)

        tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
        self.assertAllEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])

        self.assertAllEqual(
            tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
    def test_basic_tokenizer_lower(self):
        tokenizer = tokenization.BasicTokenizer(do_lower_case=True)

        self.assertAllEqual(
            tokenizer.tokenize(u" \tHeLLo!how  \n Are yoU?  "),
            ["hello", "!", "how", "are", "you", "?"])
        self.assertAllEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"])
    def test_basic_tokenizer_no_lower(self):
        tokenizer = tokenization.BasicTokenizer(do_lower_case=False)

        self.assertAllEqual(
            tokenizer.tokenize(u" \tHeLLo!how  \n Are yoU?  "),
            ["HeLLo", "!", "how", "Are", "yoU", "?"])
    def test_wordpiece_tokenizer(self):
        vocab_tokens = [
            "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
            "##ing"
        ]

        vocab = {}
        for (i, token) in enumerate(vocab_tokens):
            vocab[token] = i
        tokenizer = tokenization.WordpieceTokenizer(vocab=vocab)

        self.assertAllEqual(tokenizer.tokenize(""), [])

        self.assertAllEqual(
            tokenizer.tokenize("unwanted running"),
            ["un", "##want", "##ed", "runn", "##ing"])

        self.assertAllEqual(
            tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"])
    def test_convert_tokens_to_ids(self):
        vocab_tokens = [
            "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
            "##ing"
        ]

        vocab = {}
        for (i, token) in enumerate(vocab_tokens):
            vocab[token] = i

        self.assertAllEqual(
            tokenization.convert_tokens_to_ids(
                vocab, ["un", "##want", "##ed", "runn", "##ing"]), [7, 4, 5, 8, 9])
    def test_is_whitespace(self):
        self.assertTrue(tokenization._is_whitespace(u" "))
        self.assertTrue(tokenization._is_whitespace(u"\t"))
        self.assertTrue(tokenization._is_whitespace(u"\r"))
        self.assertTrue(tokenization._is_whitespace(u"\n"))
        self.assertTrue(tokenization._is_whitespace(u"\u00A0"))

        self.assertFalse(tokenization._is_whitespace(u"A"))
        self.assertFalse(tokenization._is_whitespace(u"-"))
    def test_is_control(self):
        self.assertTrue(tokenization._is_control(u"\u0005"))

        self.assertFalse(tokenization._is_control(u"A"))
        self.assertFalse(tokenization._is_control(u" "))
        self.assertFalse(tokenization._is_control(u"\t"))
        self.assertFalse(tokenization._is_control(u"\r"))
    def test_is_punctuation(self):
        self.assertTrue(tokenization._is_punctuation(u"-"))
        self.assertTrue(tokenization._is_punctuation(u"$"))
        self.assertTrue(tokenization._is_punctuation(u"`"))
        self.assertTrue(tokenization._is_punctuation(u"."))

        self.assertFalse(tokenization._is_punctuation(u"A"))
        self.assertFalse(tokenization._is_punctuation(u" "))
if __name__ == "__main__":
    tf.test.main()