dcuai / dlexamples · Commits

Commit a32ffa95
authored Feb 03, 2023 by qianyj

update TensorFlow2x test method

parent e286da17

Changes: 268
Showing 8 changed files with 1400 additions and 0 deletions (+1400 −0)
TensorFlow2x/ComputeVision/Classification/models-master/official/nlp/bert/squad_evaluate_v2_0.py (+249 −0)
TensorFlow2x/ComputeVision/Classification/models-master/official/nlp/bert/tf1_checkpoint_converter_lib.py (+201 −0)
TensorFlow2x/ComputeVision/Classification/models-master/official/nlp/bert/tf2_encoder_checkpoint_converter.py (+160 −0)
TensorFlow2x/ComputeVision/Classification/models-master/official/nlp/bert/tokenization.py (+541 −0)
TensorFlow2x/ComputeVision/Classification/models-master/official/nlp/bert/tokenization_test.py (+156 −0)
TensorFlow2x/ComputeVision/Classification/models-master/official/nlp/configs/__init__.py (+14 −0)
TensorFlow2x/ComputeVision/Classification/models-master/official/nlp/configs/bert.py (+43 −0)
TensorFlow2x/ComputeVision/Classification/models-master/official/nlp/configs/electra.py (+36 −0)
Too many changes to show. To preserve performance, only 268 of 268+ files are displayed.
TensorFlow2x/ComputeVision/Classification/models-master/official/nlp/bert/squad_evaluate_v2_0.py
0 → 100644

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation script for SQuAD version 2.0.

The functions are copied and modified from
https://raw.githubusercontent.com/white127/SQUAD-2.0-bidaf/master/evaluate-v2.0.py

In addition to basic functionality, we also compute additional statistics and
plot precision-recall curves if an additional na_prob.json file is provided.
This file is expected to map question IDs to the model's predicted probability
that a question is unanswerable.
"""

import collections
import re
import string

from absl import logging


def _make_qid_to_has_ans(dataset):
  qid_to_has_ans = {}
  for article in dataset:
    for p in article['paragraphs']:
      for qa in p['qas']:
        qid_to_has_ans[qa['id']] = bool(qa['answers'])
  return qid_to_has_ans


def _normalize_answer(s):
  """Lower text and remove punctuation, articles and extra whitespace."""

  def remove_articles(text):
    regex = re.compile(r'\b(a|an|the)\b', re.UNICODE)
    return re.sub(regex, ' ', text)

  def white_space_fix(text):
    return ' '.join(text.split())

  def remove_punc(text):
    exclude = set(string.punctuation)
    return ''.join(ch for ch in text if ch not in exclude)

  def lower(text):
    return text.lower()

  return white_space_fix(remove_articles(remove_punc(lower(s))))


def _get_tokens(s):
  if not s:
    return []
  return _normalize_answer(s).split()


def _compute_exact(a_gold, a_pred):
  return int(_normalize_answer(a_gold) == _normalize_answer(a_pred))


def _compute_f1(a_gold, a_pred):
  """Compute F1-score."""
  gold_toks = _get_tokens(a_gold)
  pred_toks = _get_tokens(a_pred)
  common = collections.Counter(gold_toks) & collections.Counter(pred_toks)
  num_same = sum(common.values())
  if not gold_toks or not pred_toks:
    # If either is no-answer, then F1 is 1 if they agree, 0 otherwise
    return int(gold_toks == pred_toks)
  if num_same == 0:
    return 0
  precision = 1.0 * num_same / len(pred_toks)
  recall = 1.0 * num_same / len(gold_toks)
  f1 = (2 * precision * recall) / (precision + recall)
  return f1


def _get_raw_scores(dataset, predictions):
  """Compute raw scores."""
  exact_scores = {}
  f1_scores = {}
  for article in dataset:
    for p in article['paragraphs']:
      for qa in p['qas']:
        qid = qa['id']
        gold_answers = [
            a['text'] for a in qa['answers'] if _normalize_answer(a['text'])
        ]
        if not gold_answers:
          # For unanswerable questions, only correct answer is empty string
          gold_answers = ['']
        if qid not in predictions:
          logging.error('Missing prediction for %s', qid)
          continue
        a_pred = predictions[qid]
        # Take max over all gold answers
        exact_scores[qid] = max(_compute_exact(a, a_pred) for a in gold_answers)
        f1_scores[qid] = max(_compute_f1(a, a_pred) for a in gold_answers)
  return exact_scores, f1_scores


def _apply_no_ans_threshold(scores, na_probs, qid_to_has_ans,
                            na_prob_thresh=1.0):
  new_scores = {}
  for qid, s in scores.items():
    pred_na = na_probs[qid] > na_prob_thresh
    if pred_na:
      new_scores[qid] = float(not qid_to_has_ans[qid])
    else:
      new_scores[qid] = s
  return new_scores


def _make_eval_dict(exact_scores, f1_scores, qid_list=None):
  """Make evaluation result dictionary."""
  if not qid_list:
    total = len(exact_scores)
    return collections.OrderedDict([
        ('exact', 100.0 * sum(exact_scores.values()) / total),
        ('f1', 100.0 * sum(f1_scores.values()) / total),
        ('total', total),
    ])
  else:
    total = len(qid_list)
    return collections.OrderedDict([
        ('exact', 100.0 * sum(exact_scores[k] for k in qid_list) / total),
        ('f1', 100.0 * sum(f1_scores[k] for k in qid_list) / total),
        ('total', total),
    ])


def _merge_eval(main_eval, new_eval, prefix):
  for k in new_eval:
    main_eval['%s_%s' % (prefix, k)] = new_eval[k]


def _make_precision_recall_eval(scores, na_probs, num_true_pos, qid_to_has_ans):
  """Make evaluation dictionary containing average precision recall."""
  qid_list = sorted(na_probs, key=lambda k: na_probs[k])
  true_pos = 0.0
  cur_p = 1.0
  cur_r = 0.0
  precisions = [1.0]
  recalls = [0.0]
  avg_prec = 0.0
  for i, qid in enumerate(qid_list):
    if qid_to_has_ans[qid]:
      true_pos += scores[qid]
    cur_p = true_pos / float(i + 1)
    cur_r = true_pos / float(num_true_pos)
    if i == len(qid_list) - 1 or na_probs[qid] != na_probs[qid_list[i + 1]]:
      # i.e., if we can put a threshold after this point
      avg_prec += cur_p * (cur_r - recalls[-1])
      precisions.append(cur_p)
      recalls.append(cur_r)
  return {'ap': 100.0 * avg_prec}


def _run_precision_recall_analysis(main_eval, exact_raw, f1_raw, na_probs,
                                   qid_to_has_ans):
  """Run precision recall analysis and return result dictionary."""
  num_true_pos = sum(1 for v in qid_to_has_ans.values() if v)
  if num_true_pos == 0:
    return
  pr_exact = _make_precision_recall_eval(exact_raw, na_probs, num_true_pos,
                                         qid_to_has_ans)
  pr_f1 = _make_precision_recall_eval(f1_raw, na_probs, num_true_pos,
                                      qid_to_has_ans)
  oracle_scores = {k: float(v) for k, v in qid_to_has_ans.items()}
  pr_oracle = _make_precision_recall_eval(oracle_scores, na_probs,
                                          num_true_pos, qid_to_has_ans)
  _merge_eval(main_eval, pr_exact, 'pr_exact')
  _merge_eval(main_eval, pr_f1, 'pr_f1')
  _merge_eval(main_eval, pr_oracle, 'pr_oracle')


def _find_best_thresh(predictions, scores, na_probs, qid_to_has_ans):
  """Find the best threshold for no answer probability."""
  num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
  cur_score = num_no_ans
  best_score = cur_score
  best_thresh = 0.0
  qid_list = sorted(na_probs, key=lambda k: na_probs[k])
  for qid in qid_list:
    if qid not in scores:
      continue
    if qid_to_has_ans[qid]:
      diff = scores[qid]
    else:
      if predictions[qid]:
        diff = -1
      else:
        diff = 0
    cur_score += diff
    if cur_score > best_score:
      best_score = cur_score
      best_thresh = na_probs[qid]
  return 100.0 * best_score / len(scores), best_thresh


def _find_all_best_thresh(main_eval, predictions, exact_raw, f1_raw, na_probs,
                          qid_to_has_ans):
  best_exact, exact_thresh = _find_best_thresh(predictions, exact_raw,
                                               na_probs, qid_to_has_ans)
  best_f1, f1_thresh = _find_best_thresh(predictions, f1_raw, na_probs,
                                         qid_to_has_ans)
  main_eval['final_exact'] = best_exact
  main_eval['final_exact_thresh'] = exact_thresh
  main_eval['final_f1'] = best_f1
  main_eval['final_f1_thresh'] = f1_thresh


def evaluate(dataset, predictions, na_probs=None):
  """Evaluate prediction results."""
  new_orig_data = []
  for article in dataset:
    for p in article['paragraphs']:
      for qa in p['qas']:
        if qa['id'] in predictions:
          new_para = {'qas': [qa]}
          new_article = {'paragraphs': [new_para]}
          new_orig_data.append(new_article)
  dataset = new_orig_data

  if na_probs is None:
    na_probs = {k: 0.0 for k in predictions}
  qid_to_has_ans = _make_qid_to_has_ans(dataset)  # maps qid to True/False
  has_ans_qids = [k for k, v in qid_to_has_ans.items() if v]
  no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v]
  exact_raw, f1_raw = _get_raw_scores(dataset, predictions)
  exact_thresh = _apply_no_ans_threshold(exact_raw, na_probs, qid_to_has_ans)
  f1_thresh = _apply_no_ans_threshold(f1_raw, na_probs, qid_to_has_ans)
  out_eval = _make_eval_dict(exact_thresh, f1_thresh)
  if has_ans_qids:
    has_ans_eval = _make_eval_dict(exact_thresh, f1_thresh,
                                   qid_list=has_ans_qids)
    _merge_eval(out_eval, has_ans_eval, 'HasAns')
  if no_ans_qids:
    no_ans_eval = _make_eval_dict(exact_thresh, f1_thresh,
                                  qid_list=no_ans_qids)
    _merge_eval(out_eval, no_ans_eval, 'NoAns')
  _find_all_best_thresh(out_eval, predictions, exact_raw, f1_raw, na_probs,
                        qid_to_has_ans)
  _run_precision_recall_analysis(out_eval, exact_raw, f1_raw, na_probs,
                                 qid_to_has_ans)
  return out_eval
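
For reference, `evaluate` above expects the parsed SQuAD v2.0 `data` list, a dict mapping question IDs to predicted answer strings, and an optional dict of no-answer probabilities. A minimal usage sketch (illustrative only, not part of this commit; the JSON file names are hypothetical):

import json

from official.nlp.bert import squad_evaluate_v2_0

# Hypothetical file names; substitute your own dev set and prediction files.
with open("dev-v2.0.json") as f:
  dataset = json.load(f)["data"]   # list of articles with 'paragraphs'/'qas'
with open("predictions.json") as f:
  predictions = json.load(f)       # {question_id: predicted answer string}
with open("na_prob.json") as f:
  na_probs = json.load(f)          # {question_id: P(unanswerable)}, optional

metrics = squad_evaluate_v2_0.evaluate(dataset, predictions, na_probs=na_probs)
# The returned OrderedDict contains 'exact', 'f1', 'total', plus HasAns_/NoAns_
# breakdowns and the best-threshold 'final_exact'/'final_f1' entries.
print(metrics["exact"], metrics["f1"], metrics["final_f1"])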
TensorFlow2x/ComputeVision/Classification/models-master/official/nlp/bert/tf1_checkpoint_converter_lib.py
0 → 100644

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Convert checkpoints created by Estimator (tf1) to be Keras compatible."""

import numpy as np
import tensorflow.compat.v1 as tf  # TF 1.x

# Mapping between old <=> new names. The source pattern in original variable
# name will be replaced by destination pattern.
BERT_NAME_REPLACEMENTS = (
    ("bert", "bert_model"),
    ("embeddings/word_embeddings", "word_embeddings/embeddings"),
    ("embeddings/token_type_embeddings",
     "embedding_postprocessor/type_embeddings"),
    ("embeddings/position_embeddings",
     "embedding_postprocessor/position_embeddings"),
    ("embeddings/LayerNorm", "embedding_postprocessor/layer_norm"),
    ("attention/self", "self_attention"),
    ("attention/output/dense", "self_attention_output"),
    ("attention/output/LayerNorm", "self_attention_layer_norm"),
    ("intermediate/dense", "intermediate"),
    ("output/dense", "output"),
    ("output/LayerNorm", "output_layer_norm"),
    ("pooler/dense", "pooler_transform"),
)

BERT_V2_NAME_REPLACEMENTS = (
    ("bert/", ""),
    ("encoder", "transformer"),
    ("embeddings/word_embeddings", "word_embeddings/embeddings"),
    ("embeddings/token_type_embeddings", "type_embeddings/embeddings"),
    ("embeddings/position_embeddings", "position_embedding/embeddings"),
    ("embeddings/LayerNorm", "embeddings/layer_norm"),
    ("attention/self", "self_attention"),
    ("attention/output/dense", "self_attention/attention_output"),
    ("attention/output/LayerNorm", "self_attention_layer_norm"),
    ("intermediate/dense", "intermediate"),
    ("output/dense", "output"),
    ("output/LayerNorm", "output_layer_norm"),
    ("pooler/dense", "pooler_transform"),
    ("cls/predictions", "bert/cls/predictions"),
    ("cls/predictions/output_bias", "cls/predictions/output_bias/bias"),
    ("cls/seq_relationship/output_bias", "predictions/transform/logits/bias"),
    ("cls/seq_relationship/output_weights",
     "predictions/transform/logits/kernel"),
)

BERT_PERMUTATIONS = ()

BERT_V2_PERMUTATIONS = (("cls/seq_relationship/output_weights", (1, 0)),)


def _bert_name_replacement(var_name, name_replacements):
  """Gets the variable name replacement."""
  for src_pattern, tgt_pattern in name_replacements:
    if src_pattern in var_name:
      old_var_name = var_name
      var_name = var_name.replace(src_pattern, tgt_pattern)
      tf.logging.info("Converted: %s --> %s", old_var_name, var_name)
  return var_name


def _has_exclude_patterns(name, exclude_patterns):
  """Checks if a string contains substrings that match patterns to exclude."""
  for p in exclude_patterns:
    if p in name:
      return True
  return False


def _get_permutation(name, permutations):
  """Checks whether a variable requires transposition by pattern matching."""
  for src_pattern, permutation in permutations:
    if src_pattern in name:
      tf.logging.info("Permuted: %s --> %s", name, permutation)
      return permutation
  return None


def _get_new_shape(name, shape, num_heads):
  """Checks whether a variable requires reshape by pattern matching."""
  if "self_attention/attention_output/kernel" in name:
    return tuple([num_heads, shape[0] // num_heads, shape[1]])
  if "self_attention/attention_output/bias" in name:
    return shape

  patterns = [
      "self_attention/query", "self_attention/value", "self_attention/key"
  ]
  for pattern in patterns:
    if pattern in name:
      if "kernel" in name:
        return tuple([shape[0], num_heads, shape[1] // num_heads])
      if "bias" in name:
        return tuple([num_heads, shape[0] // num_heads])
  return None


def create_v2_checkpoint(model, src_checkpoint, output_path,
                         checkpoint_model_name="model"):
  """Converts a name-based matched TF V1 checkpoint to TF V2 checkpoint."""
  # Uses streaming-restore in eager model to read V1 name-based checkpoints.
  model.load_weights(src_checkpoint).assert_existing_objects_matched()
  if hasattr(model, "checkpoint_items"):
    checkpoint_items = model.checkpoint_items
  else:
    checkpoint_items = {}
  checkpoint_items[checkpoint_model_name] = model
  checkpoint = tf.train.Checkpoint(**checkpoint_items)
  checkpoint.save(output_path)


def convert(checkpoint_from_path,
            checkpoint_to_path,
            num_heads,
            name_replacements,
            permutations,
            exclude_patterns=None):
  """Migrates the names of variables within a checkpoint.

  Args:
    checkpoint_from_path: Path to source checkpoint to be read in.
    checkpoint_to_path: Path to checkpoint to be written out.
    num_heads: The number of heads of the model.
    name_replacements: A list of tuples of the form (match_str, replace_str)
      describing variable names to adjust.
    permutations: A list of tuples of the form (match_str, permutation)
      describing permutations to apply to given variables. Note that match_str
      should match the original variable name, not the replaced one.
    exclude_patterns: A list of string patterns to exclude variables from
      checkpoint conversion.

  Returns:
    A dictionary that maps the new variable names to the Variable objects.
    A dictionary that maps the old variable names to the new variable names.
  """
  with tf.Graph().as_default():
    tf.logging.info("Reading checkpoint_from_path %s", checkpoint_from_path)
    reader = tf.train.NewCheckpointReader(checkpoint_from_path)
    name_shape_map = reader.get_variable_to_shape_map()
    new_variable_map = {}
    conversion_map = {}
    for var_name in name_shape_map:
      if exclude_patterns and _has_exclude_patterns(var_name,
                                                    exclude_patterns):
        continue
      # Get the original tensor data.
      tensor = reader.get_tensor(var_name)

      # Look up the new variable name, if any.
      new_var_name = _bert_name_replacement(var_name, name_replacements)

      # See if we need to reshape the underlying tensor.
      new_shape = None
      if num_heads > 0:
        new_shape = _get_new_shape(new_var_name, tensor.shape, num_heads)
      if new_shape:
        tf.logging.info("Variable %s has a shape change from %s to %s",
                        var_name, tensor.shape, new_shape)
        tensor = np.reshape(tensor, new_shape)

      # See if we need to permute the underlying tensor.
      permutation = _get_permutation(var_name, permutations)
      if permutation:
        tensor = np.transpose(tensor, permutation)

      # Create a new variable with the possibly-reshaped or transposed tensor.
      var = tf.Variable(tensor, name=var_name)

      # Save the variable into the new variable map.
      new_variable_map[new_var_name] = var

      # Keep a list of converted variables for sanity checking.
      if new_var_name != var_name:
        conversion_map[var_name] = new_var_name

    saver = tf.train.Saver(new_variable_map)

    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      tf.logging.info("Writing checkpoint_to_path %s", checkpoint_to_path)
      saver.save(sess, checkpoint_to_path, write_meta_graph=False)

  tf.logging.info("Summary:")
  tf.logging.info("  Converted %d variable name(s).", len(new_variable_map))
  tf.logging.info("  Converted: %s", str(conversion_map))
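
For orientation, the two public helpers in this library are used in sequence: `convert` rewrites variable names inside a TF1 name-based checkpoint, and `create_v2_checkpoint` then writes an object-based checkpoint from a matching Keras model. A rough sketch of the first step (illustrative only, not part of this commit; the checkpoint paths and head count are placeholders):

from official.nlp.bert import tf1_checkpoint_converter_lib as converter_lib

# Placeholder paths and head count, for illustration only.
converter_lib.convert(
    checkpoint_from_path="/tmp/bert_v1/bert_model.ckpt",
    checkpoint_to_path="/tmp/bert_v1_renamed/ckpt",
    num_heads=12,
    name_replacements=converter_lib.BERT_V2_NAME_REPLACEMENTS,
    permutations=converter_lib.BERT_V2_PERMUTATIONS,
    exclude_patterns=["adam", "Adam"])

# A Keras BERT encoder built with matching dimensions can then be passed to
# converter_lib.create_v2_checkpoint(model, "/tmp/bert_v1_renamed/ckpt",
#                                    "/tmp/bert_v2/ckpt")
# to produce the object-based V2 checkpoint, as the script below does.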
TensorFlow2x/ComputeVision/Classification/models-master/official/nlp/bert/tf2_encoder_checkpoint_converter.py
0 → 100644

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A converter from a V1 BERT encoder checkpoint to a V2 encoder checkpoint.

The conversion will yield an object-oriented checkpoint that can be used
to restore a BertEncoder or BertPretrainerV2 object (see the `converted_model`
FLAG below).
"""
import os

from absl import app
from absl import flags

import tensorflow as tf
from official.modeling import tf_utils
from official.nlp.bert import configs
from official.nlp.bert import tf1_checkpoint_converter_lib
from official.nlp.modeling import models
from official.nlp.modeling import networks

FLAGS = flags.FLAGS

flags.DEFINE_string("bert_config_file", None,
                    "Bert configuration file to define core bert layers.")
flags.DEFINE_string(
    "checkpoint_to_convert", None,
    "Initial checkpoint from a pretrained BERT model core (that is, only the "
    "BertModel, with no task heads.)")
flags.DEFINE_string("converted_checkpoint_path", None,
                    "Name for the created object-based V2 checkpoint.")
flags.DEFINE_string("checkpoint_model_name", "encoder",
                    "The name of the model when saving the checkpoint, i.e., "
                    "the checkpoint will be saved using: "
                    "tf.train.Checkpoint(FLAGS.checkpoint_model_name=model).")
flags.DEFINE_enum(
    "converted_model", "encoder", ["encoder", "pretrainer"],
    "Whether to convert the checkpoint to a `BertEncoder` model or a "
    "`BertPretrainerV2` model (with mlm but without classification heads).")


def _create_bert_model(cfg):
  """Creates a BERT keras core model from BERT configuration.

  Args:
    cfg: A `BertConfig` to create the core model.

  Returns:
    A BertEncoder network.
  """
  bert_encoder = networks.BertEncoder(
      vocab_size=cfg.vocab_size,
      hidden_size=cfg.hidden_size,
      num_layers=cfg.num_hidden_layers,
      num_attention_heads=cfg.num_attention_heads,
      intermediate_size=cfg.intermediate_size,
      activation=tf_utils.get_activation(cfg.hidden_act),
      dropout_rate=cfg.hidden_dropout_prob,
      attention_dropout_rate=cfg.attention_probs_dropout_prob,
      max_sequence_length=cfg.max_position_embeddings,
      type_vocab_size=cfg.type_vocab_size,
      initializer=tf.keras.initializers.TruncatedNormal(
          stddev=cfg.initializer_range),
      embedding_width=cfg.embedding_size)

  return bert_encoder


def _create_bert_pretrainer_model(cfg):
  """Creates a BERT keras core model from BERT configuration.

  Args:
    cfg: A `BertConfig` to create the core model.

  Returns:
    A BertPretrainerV2 model.
  """
  bert_encoder = _create_bert_model(cfg)
  pretrainer = models.BertPretrainerV2(
      encoder_network=bert_encoder,
      mlm_activation=tf_utils.get_activation(cfg.hidden_act),
      mlm_initializer=tf.keras.initializers.TruncatedNormal(
          stddev=cfg.initializer_range))
  # Makes sure the pretrainer variables are created.
  _ = pretrainer(pretrainer.inputs)
  return pretrainer


def convert_checkpoint(bert_config,
                       output_path,
                       v1_checkpoint,
                       checkpoint_model_name="model",
                       converted_model="encoder"):
  """Converts a V1 checkpoint into an OO V2 checkpoint."""
  output_dir, _ = os.path.split(output_path)
  tf.io.gfile.makedirs(output_dir)

  # Create a temporary V1 name-converted checkpoint in the output directory.
  temporary_checkpoint_dir = os.path.join(output_dir, "temp_v1")
  temporary_checkpoint = os.path.join(temporary_checkpoint_dir, "ckpt")

  tf1_checkpoint_converter_lib.convert(
      checkpoint_from_path=v1_checkpoint,
      checkpoint_to_path=temporary_checkpoint,
      num_heads=bert_config.num_attention_heads,
      name_replacements=tf1_checkpoint_converter_lib.BERT_V2_NAME_REPLACEMENTS,
      permutations=tf1_checkpoint_converter_lib.BERT_V2_PERMUTATIONS,
      exclude_patterns=["adam", "Adam"])

  if converted_model == "encoder":
    model = _create_bert_model(bert_config)
  elif converted_model == "pretrainer":
    model = _create_bert_pretrainer_model(bert_config)
  else:
    raise ValueError("Unsupported converted_model: %s" % converted_model)

  # Create a V2 checkpoint from the temporary checkpoint.
  tf1_checkpoint_converter_lib.create_v2_checkpoint(model, temporary_checkpoint,
                                                    output_path,
                                                    checkpoint_model_name)

  # Clean up the temporary checkpoint, if it exists.
  try:
    tf.io.gfile.rmtree(temporary_checkpoint_dir)
  except tf.errors.OpError:
    # If it doesn't exist, we don't need to clean it up; continue.
    pass


def main(argv):
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")

  output_path = FLAGS.converted_checkpoint_path
  v1_checkpoint = FLAGS.checkpoint_to_convert
  checkpoint_model_name = FLAGS.checkpoint_model_name
  converted_model = FLAGS.converted_model
  bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)
  convert_checkpoint(
      bert_config=bert_config,
      output_path=output_path,
      v1_checkpoint=v1_checkpoint,
      checkpoint_model_name=checkpoint_model_name,
      converted_model=converted_model)


if __name__ == "__main__":
  app.run(main)
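
The script above is normally driven by the flags it defines; the same conversion can also be invoked programmatically, roughly as follows (a sketch, not part of this commit; the file paths are placeholders):

from official.nlp.bert import configs
from official.nlp.bert import tf2_encoder_checkpoint_converter as converter

# Placeholder paths, for illustration only.
bert_config = configs.BertConfig.from_json_file("/tmp/bert_v1/bert_config.json")
converter.convert_checkpoint(
    bert_config=bert_config,
    output_path="/tmp/bert_v2/bert_model.ckpt",
    v1_checkpoint="/tmp/bert_v1/bert_model.ckpt",
    checkpoint_model_name="encoder",   # key used in tf.train.Checkpoint(...)
    converted_model="encoder")         # or "pretrainer" for BertPretrainerV2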
TensorFlow2x/ComputeVision/Classification/models-master/official/nlp/bert/tokenization.py
0 → 100644

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# coding=utf-8
"""Tokenization classes implementation.

The file is forked from:
https://github.com/google-research/bert/blob/master/tokenization.py.
"""

import collections
import re
import unicodedata

import six
import tensorflow as tf

import sentencepiece as spm

SPIECE_UNDERLINE = "▁"


def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
  """Checks whether the casing config is consistent with the checkpoint name."""

  # The casing has to be passed in by the user and there is no explicit check
  # as to whether it matches the checkpoint. The casing information probably
  # should have been stored in the bert_config.json file, but it's not, so
  # we have to heuristically detect it to validate.

  if not init_checkpoint:
    return

  m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint)
  if m is None:
    return

  model_name = m.group(1)

  lower_models = [
      "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12",
      "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12"
  ]

  cased_models = [
      "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16",
      "multi_cased_L-12_H-768_A-12"
  ]

  is_bad_config = False
  if model_name in lower_models and not do_lower_case:
    is_bad_config = True
    actual_flag = "False"
    case_name = "lowercased"
    opposite_flag = "True"

  if model_name in cased_models and do_lower_case:
    is_bad_config = True
    actual_flag = "True"
    case_name = "cased"
    opposite_flag = "False"

  if is_bad_config:
    raise ValueError(
        "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. "
        "However, `%s` seems to be a %s model, so you "
        "should pass in `--do_lower_case=%s` so that the fine-tuning matches "
        "how the model was pre-trained. If this error is wrong, please "
        "just comment out this check." %
        (actual_flag, init_checkpoint, model_name, case_name, opposite_flag))


def convert_to_unicode(text):
  """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
  if six.PY3:
    if isinstance(text, str):
      return text
    elif isinstance(text, bytes):
      return text.decode("utf-8", "ignore")
    else:
      raise ValueError("Unsupported string type: %s" % (type(text)))
  elif six.PY2:
    if isinstance(text, str):
      return text.decode("utf-8", "ignore")
    elif isinstance(text, unicode):
      return text
    else:
      raise ValueError("Unsupported string type: %s" % (type(text)))
  else:
    raise ValueError("Not running on Python2 or Python 3?")


def printable_text(text):
  """Returns text encoded in a way suitable for print or `tf.logging`."""

  # These functions want `str` for both Python2 and Python3, but in one case
  # it's a Unicode string and in the other it's a byte string.
  if six.PY3:
    if isinstance(text, str):
      return text
    elif isinstance(text, bytes):
      return text.decode("utf-8", "ignore")
    else:
      raise ValueError("Unsupported string type: %s" % (type(text)))
  elif six.PY2:
    if isinstance(text, str):
      return text
    elif isinstance(text, unicode):
      return text.encode("utf-8")
    else:
      raise ValueError("Unsupported string type: %s" % (type(text)))
  else:
    raise ValueError("Not running on Python2 or Python 3?")


def load_vocab(vocab_file):
  """Loads a vocabulary file into a dictionary."""
  vocab = collections.OrderedDict()
  index = 0
  with tf.io.gfile.GFile(vocab_file, "r") as reader:
    while True:
      token = convert_to_unicode(reader.readline())
      if not token:
        break
      token = token.strip()
      vocab[token] = index
      index += 1
  return vocab


def convert_by_vocab(vocab, items):
  """Converts a sequence of [tokens|ids] using the vocab."""
  output = []
  for item in items:
    output.append(vocab[item])
  return output


def convert_tokens_to_ids(vocab, tokens):
  return convert_by_vocab(vocab, tokens)


def convert_ids_to_tokens(inv_vocab, ids):
  return convert_by_vocab(inv_vocab, ids)


def whitespace_tokenize(text):
  """Runs basic whitespace cleaning and splitting on a piece of text."""
  text = text.strip()
  if not text:
    return []
  tokens = text.split()
  return tokens


class FullTokenizer(object):
  """Runs end-to-end tokenization."""

  def __init__(self, vocab_file, do_lower_case=True, split_on_punc=True):
    self.vocab = load_vocab(vocab_file)
    self.inv_vocab = {v: k for k, v in self.vocab.items()}
    self.basic_tokenizer = BasicTokenizer(
        do_lower_case=do_lower_case, split_on_punc=split_on_punc)
    self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)

  def tokenize(self, text):
    split_tokens = []
    for token in self.basic_tokenizer.tokenize(text):
      for sub_token in self.wordpiece_tokenizer.tokenize(token):
        split_tokens.append(sub_token)

    return split_tokens

  def convert_tokens_to_ids(self, tokens):
    return convert_by_vocab(self.vocab, tokens)

  def convert_ids_to_tokens(self, ids):
    return convert_by_vocab(self.inv_vocab, ids)


class BasicTokenizer(object):
  """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""

  def __init__(self, do_lower_case=True, split_on_punc=True):
    """Constructs a BasicTokenizer.

    Args:
      do_lower_case: Whether to lower case the input.
      split_on_punc: Whether to apply split on punctuations. By default BERT
        starts a new token for punctuations. This makes detokenization
        difficult for tasks like seq2seq decoding.
    """
    self.do_lower_case = do_lower_case
    self.split_on_punc = split_on_punc

  def tokenize(self, text):
    """Tokenizes a piece of text."""
    text = convert_to_unicode(text)
    text = self._clean_text(text)

    # This was added on November 1st, 2018 for the multilingual and Chinese
    # models. This is also applied to the English models now, but it doesn't
    # matter since the English models were not trained on any Chinese data
    # and generally don't have any Chinese data in them (there are Chinese
    # characters in the vocabulary because Wikipedia does have some Chinese
    # words in the English Wikipedia.).
    text = self._tokenize_chinese_chars(text)

    orig_tokens = whitespace_tokenize(text)
    split_tokens = []
    for token in orig_tokens:
      if self.do_lower_case:
        token = token.lower()
        token = self._run_strip_accents(token)
      if self.split_on_punc:
        split_tokens.extend(self._run_split_on_punc(token))
      else:
        split_tokens.append(token)

    output_tokens = whitespace_tokenize(" ".join(split_tokens))
    return output_tokens

  def _run_strip_accents(self, text):
    """Strips accents from a piece of text."""
    text = unicodedata.normalize("NFD", text)
    output = []
    for char in text:
      cat = unicodedata.category(char)
      if cat == "Mn":
        continue
      output.append(char)
    return "".join(output)

  def _run_split_on_punc(self, text):
    """Splits punctuation on a piece of text."""
    chars = list(text)
    i = 0
    start_new_word = True
    output = []
    while i < len(chars):
      char = chars[i]
      if _is_punctuation(char):
        output.append([char])
        start_new_word = True
      else:
        if start_new_word:
          output.append([])
        start_new_word = False
        output[-1].append(char)
      i += 1

    return ["".join(x) for x in output]

  def _tokenize_chinese_chars(self, text):
    """Adds whitespace around any CJK character."""
    output = []
    for char in text:
      cp = ord(char)
      if self._is_chinese_char(cp):
        output.append(" ")
        output.append(char)
        output.append(" ")
      else:
        output.append(char)
    return "".join(output)

  def _is_chinese_char(self, cp):
    """Checks whether CP is the codepoint of a CJK character."""
    # This defines a "chinese character" as anything in the CJK Unicode block:
    #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
    #
    # Note that the CJK Unicode block is NOT all Japanese and Korean
    # characters, despite its name. The modern Korean Hangul alphabet is a
    # different block, as is Japanese Hiragana and Katakana. Those alphabets
    # are used to write space-separated words, so they are not treated
    # specially and handled like all of the other languages.
    if ((cp >= 0x4E00 and cp <= 0x9FFF) or
        (cp >= 0x3400 and cp <= 0x4DBF) or
        (cp >= 0x20000 and cp <= 0x2A6DF) or
        (cp >= 0x2A700 and cp <= 0x2B73F) or
        (cp >= 0x2B740 and cp <= 0x2B81F) or
        (cp >= 0x2B820 and cp <= 0x2CEAF) or
        (cp >= 0xF900 and cp <= 0xFAFF) or
        (cp >= 0x2F800 and cp <= 0x2FA1F)):
      return True

    return False

  def _clean_text(self, text):
    """Performs invalid character removal and whitespace cleanup on text."""
    output = []
    for char in text:
      cp = ord(char)
      if cp == 0 or cp == 0xfffd or _is_control(char):
        continue
      if _is_whitespace(char):
        output.append(" ")
      else:
        output.append(char)
    return "".join(output)


class WordpieceTokenizer(object):
  """Runs WordPiece tokenization."""

  def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=400):
    self.vocab = vocab
    self.unk_token = unk_token
    self.max_input_chars_per_word = max_input_chars_per_word

  def tokenize(self, text):
    """Tokenizes a piece of text into its word pieces.

    This uses a greedy longest-match-first algorithm to perform tokenization
    using the given vocabulary.

    For example:
      input = "unaffable"
      output = ["un", "##aff", "##able"]

    Args:
      text: A single token or whitespace separated tokens. This should have
        already been passed through `BasicTokenizer`.

    Returns:
      A list of wordpiece tokens.
    """

    text = convert_to_unicode(text)

    output_tokens = []
    for token in whitespace_tokenize(text):
      chars = list(token)
      if len(chars) > self.max_input_chars_per_word:
        output_tokens.append(self.unk_token)
        continue

      is_bad = False
      start = 0
      sub_tokens = []
      while start < len(chars):
        end = len(chars)
        cur_substr = None
        while start < end:
          substr = "".join(chars[start:end])
          if start > 0:
            substr = "##" + substr
          if substr in self.vocab:
            cur_substr = substr
            break
          end -= 1
        if cur_substr is None:
          is_bad = True
          break
        sub_tokens.append(cur_substr)
        start = end

      if is_bad:
        output_tokens.append(self.unk_token)
      else:
        output_tokens.extend(sub_tokens)
    return output_tokens


def _is_whitespace(char):
  """Checks whether `char` is a whitespace character."""
  # \t, \n, and \r are technically control characters but we treat them
  # as whitespace since they are generally considered as such.
  if char == " " or char == "\t" or char == "\n" or char == "\r":
    return True
  cat = unicodedata.category(char)
  if cat == "Zs":
    return True
  return False


def _is_control(char):
  """Checks whether `char` is a control character."""
  # These are technically control characters but we count them as whitespace
  # characters.
  if char == "\t" or char == "\n" or char == "\r":
    return False
  cat = unicodedata.category(char)
  if cat in ("Cc", "Cf"):
    return True
  return False


def _is_punctuation(char):
  """Checks whether `char` is a punctuation character."""
  cp = ord(char)
  # We treat all non-letter/number ASCII as punctuation.
  # Characters such as "^", "$", and "`" are not in the Unicode
  # Punctuation class but we treat them as punctuation anyways, for
  # consistency.
  if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
      (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
    return True
  cat = unicodedata.category(char)
  if cat.startswith("P"):
    return True
  return False


def preprocess_text(inputs, remove_space=True, lower=False):
  """Preprocesses data by removing extra space and normalizing data.

  This method is used together with sentence piece tokenizer and is forked from:
  https://github.com/google-research/google-research/blob/e1f6fa00/albert/tokenization.py

  Args:
    inputs: The input text.
    remove_space: Whether to remove the extra space.
    lower: Whether to lowercase the text.

  Returns:
    The preprocessed text.
  """
  outputs = inputs
  if remove_space:
    outputs = " ".join(inputs.strip().split())

  if six.PY2 and isinstance(outputs, str):
    try:
      outputs = six.ensure_text(outputs, "utf-8")
    except UnicodeDecodeError:
      outputs = six.ensure_text(outputs, "latin-1")

  outputs = unicodedata.normalize("NFKD", outputs)
  outputs = "".join([c for c in outputs if not unicodedata.combining(c)])
  if lower:
    outputs = outputs.lower()

  return outputs


def encode_pieces(sp_model, text, sample=False):
  """Segments text into pieces.

  This method is used together with sentence piece tokenizer and is forked from:
  https://github.com/google-research/google-research/blob/e1f6fa00/albert/tokenization.py

  Args:
    sp_model: A spm.SentencePieceProcessor object.
    text: The input text to be segmented.
    sample: Whether to randomly sample a segmentation output or return a
      deterministic one.

  Returns:
    A list of token pieces.
  """
  if six.PY2 and isinstance(text, six.text_type):
    text = six.ensure_binary(text, "utf-8")

  if not sample:
    pieces = sp_model.EncodeAsPieces(text)
  else:
    pieces = sp_model.SampleEncodeAsPieces(text, 64, 0.1)
  new_pieces = []
  for piece in pieces:
    piece = printable_text(piece)
    if len(piece) > 1 and piece[-1] == "," and piece[-2].isdigit():
      cur_pieces = sp_model.EncodeAsPieces(piece[:-1].replace(
          SPIECE_UNDERLINE, ""))
      if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE:
        if len(cur_pieces[0]) == 1:
          cur_pieces = cur_pieces[1:]
        else:
          cur_pieces[0] = cur_pieces[0][1:]
      cur_pieces.append(piece[-1])
      new_pieces.extend(cur_pieces)
    else:
      new_pieces.append(piece)

  return new_pieces


def encode_ids(sp_model, text, sample=False):
  """Segments text and return token ids.

  This method is used together with sentence piece tokenizer and is forked from:
  https://github.com/google-research/google-research/blob/e1f6fa00/albert/tokenization.py

  Args:
    sp_model: A spm.SentencePieceProcessor object.
    text: The input text to be segmented.
    sample: Whether to randomly sample a segmentation output or return a
      deterministic one.

  Returns:
    A list of token ids.
  """
  pieces = encode_pieces(sp_model, text, sample=sample)
  ids = [sp_model.PieceToId(piece) for piece in pieces]
  return ids


class FullSentencePieceTokenizer(object):
  """Runs end-to-end sentence piece tokenization.

  The interface of this class is intended to keep the same as above
  `FullTokenizer` class for easier usage.
  """

  def __init__(self, sp_model_file):
    """Inits FullSentencePieceTokenizer.

    Args:
      sp_model_file: The path to the sentence piece model file.
    """
    self.sp_model = spm.SentencePieceProcessor()
    self.sp_model.Load(sp_model_file)
    self.vocab = {
        self.sp_model.IdToPiece(i): i
        for i in six.moves.range(self.sp_model.GetPieceSize())
    }

  def tokenize(self, text):
    """Tokenizes text into pieces."""
    return encode_pieces(self.sp_model, text)

  def convert_tokens_to_ids(self, tokens):
    """Converts a list of tokens to a list of ids."""
    return [self.sp_model.PieceToId(printable_text(token)) for token in tokens]

  def convert_ids_to_tokens(self, ids):
    """Converts a list of ids to a list of tokens."""
    return [self.sp_model.IdToPiece(id_) for id_ in ids]
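
For reference, both tokenizer front ends above share the same interface; a minimal usage sketch (illustrative only, not part of this commit; `vocab.txt` and `spiece.model` are placeholder paths):

from official.nlp.bert import tokenization

# WordPiece tokenizer backed by a BERT vocabulary file (placeholder path).
tokenizer = tokenization.FullTokenizer("vocab.txt", do_lower_case=True)
tokens = tokenizer.tokenize("unaffable weather")
# e.g. ["un", "##aff", "##able", "weather"], depending on the vocabulary
ids = tokenizer.convert_tokens_to_ids(tokens)

# SentencePiece tokenizer (ALBERT-style), same interface (placeholder path).
sp_tokenizer = tokenization.FullSentencePieceTokenizer("spiece.model")
sp_ids = sp_tokenizer.convert_tokens_to_ids(sp_tokenizer.tokenize("unaffable weather"))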
TensorFlow2x/ComputeVision/Classification/models-master/official/nlp/bert/tokenization_test.py
0 → 100644

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tempfile

import six
import tensorflow as tf

from official.nlp.bert import tokenization


class TokenizationTest(tf.test.TestCase):
  """Tokenization test.

  The implementation is forked from
  https://github.com/google-research/bert/blob/master/tokenization_test.py.
  """

  def test_full_tokenizer(self):
    vocab_tokens = [
        "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un",
        "runn", "##ing", ","
    ]
    with tempfile.NamedTemporaryFile(delete=False) as vocab_writer:
      if six.PY2:
        vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
      else:
        vocab_writer.write("".join(
            [x + "\n" for x in vocab_tokens]).encode("utf-8"))

      vocab_file = vocab_writer.name

    tokenizer = tokenization.FullTokenizer(vocab_file)
    os.unlink(vocab_file)

    tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
    self.assertAllEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])

    self.assertAllEqual(
        tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])

  def test_chinese(self):
    tokenizer = tokenization.BasicTokenizer()

    self.assertAllEqual(
        tokenizer.tokenize(u"ah\u535A\u63A8zz"),
        [u"ah", u"\u535A", u"\u63A8", u"zz"])

  def test_basic_tokenizer_lower(self):
    tokenizer = tokenization.BasicTokenizer(do_lower_case=True)

    self.assertAllEqual(
        tokenizer.tokenize(u" \tHeLLo!how  \n Are yoU?  "),
        ["hello", "!", "how", "are", "you", "?"])
    self.assertAllEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"])

  def test_basic_tokenizer_no_lower(self):
    tokenizer = tokenization.BasicTokenizer(do_lower_case=False)

    self.assertAllEqual(
        tokenizer.tokenize(u" \tHeLLo!how  \n Are yoU?  "),
        ["HeLLo", "!", "how", "Are", "yoU", "?"])

  def test_basic_tokenizer_no_split_on_punc(self):
    tokenizer = tokenization.BasicTokenizer(
        do_lower_case=True, split_on_punc=False)

    self.assertAllEqual(
        tokenizer.tokenize(u" \tHeLLo!how  \n Are yoU?  "),
        ["hello!how", "are", "you?"])

  def test_wordpiece_tokenizer(self):
    vocab_tokens = [
        "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un",
        "runn", "##ing", "##!", "!"
    ]

    vocab = {}
    for (i, token) in enumerate(vocab_tokens):
      vocab[token] = i
    tokenizer = tokenization.WordpieceTokenizer(vocab=vocab)

    self.assertAllEqual(tokenizer.tokenize(""), [])

    self.assertAllEqual(
        tokenizer.tokenize("unwanted running"),
        ["un", "##want", "##ed", "runn", "##ing"])

    self.assertAllEqual(
        tokenizer.tokenize("unwanted running !"),
        ["un", "##want", "##ed", "runn", "##ing", "!"])

    self.assertAllEqual(
        tokenizer.tokenize("unwanted running!"),
        ["un", "##want", "##ed", "runn", "##ing", "##!"])

    self.assertAllEqual(
        tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"])

  def test_convert_tokens_to_ids(self):
    vocab_tokens = [
        "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un",
        "runn", "##ing"
    ]

    vocab = {}
    for (i, token) in enumerate(vocab_tokens):
      vocab[token] = i

    self.assertAllEqual(
        tokenization.convert_tokens_to_ids(
            vocab, ["un", "##want", "##ed", "runn", "##ing"]), [7, 4, 5, 8, 9])

  def test_is_whitespace(self):
    self.assertTrue(tokenization._is_whitespace(u" "))
    self.assertTrue(tokenization._is_whitespace(u"\t"))
    self.assertTrue(tokenization._is_whitespace(u"\r"))
    self.assertTrue(tokenization._is_whitespace(u"\n"))
    self.assertTrue(tokenization._is_whitespace(u"\u00A0"))

    self.assertFalse(tokenization._is_whitespace(u"A"))
    self.assertFalse(tokenization._is_whitespace(u"-"))

  def test_is_control(self):
    self.assertTrue(tokenization._is_control(u"\u0005"))

    self.assertFalse(tokenization._is_control(u"A"))
    self.assertFalse(tokenization._is_control(u" "))
    self.assertFalse(tokenization._is_control(u"\t"))
    self.assertFalse(tokenization._is_control(u"\r"))
    self.assertFalse(tokenization._is_control(u"\U0001F4A9"))

  def test_is_punctuation(self):
    self.assertTrue(tokenization._is_punctuation(u"-"))
    self.assertTrue(tokenization._is_punctuation(u"$"))
    self.assertTrue(tokenization._is_punctuation(u"`"))
    self.assertTrue(tokenization._is_punctuation(u"."))

    self.assertFalse(tokenization._is_punctuation(u"A"))
    self.assertFalse(tokenization._is_punctuation(u" "))


if __name__ == "__main__":
  tf.test.main()
TensorFlow2x/ComputeVision/Classification/models-master/official/nlp/configs/__init__.py
0 → 100644

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
TensorFlow2x/ComputeVision/Classification/models-master/official/nlp/configs/bert.py
0 → 100644

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Multi-head BERT encoder network with classification heads.

Includes configurations and instantiation methods.
"""
from typing import List, Optional, Text

import dataclasses

from official.modeling.hyperparams import base_config
from official.nlp.configs import encoders


@dataclasses.dataclass
class ClsHeadConfig(base_config.Config):
  inner_dim: int = 0
  num_classes: int = 2
  activation: Optional[Text] = "tanh"
  dropout_rate: float = 0.0
  cls_token_idx: int = 0
  name: Optional[Text] = None


@dataclasses.dataclass
class PretrainerConfig(base_config.Config):
  """Pretrainer configuration."""
  encoder: encoders.EncoderConfig = encoders.EncoderConfig()
  cls_heads: List[ClsHeadConfig] = dataclasses.field(default_factory=list)
  mlm_activation: str = "gelu"
  mlm_initializer_range: float = 0.02
TensorFlow2x/ComputeVision/Classification/models-master/official/nlp/configs/electra.py
0 → 100644

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ELECTRA model configurations and instantiation methods."""
from typing import List

import dataclasses

from official.modeling.hyperparams import base_config
from official.nlp.configs import bert
from official.nlp.configs import encoders


@dataclasses.dataclass
class ElectraPretrainerConfig(base_config.Config):
  """ELECTRA pretrainer configuration."""
  num_masked_tokens: int = 76
  sequence_length: int = 512
  num_classes: int = 2
  discriminator_loss_weight: float = 50.0
  tie_embeddings: bool = True
  disallow_correct: bool = False
  generator_encoder: encoders.EncoderConfig = encoders.EncoderConfig()
  discriminator_encoder: encoders.EncoderConfig = encoders.EncoderConfig()
  cls_heads: List[bert.ClsHeadConfig] = dataclasses.field(default_factory=list)
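
These config classes are plain `base_config.Config` dataclasses; a small sketch of how they might be instantiated (illustrative only, not part of this commit; the field values are arbitrary examples):

from official.nlp.configs import bert
from official.nlp.configs import electra

# A classification head config; the values here are arbitrary examples.
cls_head = bert.ClsHeadConfig(inner_dim=768, num_classes=2, name="next_sentence")

# BERT-style pretrainer with the default encoder config and one head.
pretrainer_cfg = bert.PretrainerConfig(cls_heads=[cls_head])

# ELECTRA pretrainer reusing the same head definition.
electra_cfg = electra.ElectraPretrainerConfig(
    discriminator_loss_weight=50.0, cls_heads=[cls_head])

print(pretrainer_cfg.mlm_activation)   # "gelu" by default
print(electra_cfg.num_masked_tokens)   # 76 by default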