Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
12deafa2
Commit
12deafa2
authored
Jul 10, 2020
by
A. Unique TensorFlower
Browse files
Internal change
PiperOrigin-RevId: 320664492
parent
9d8b1543
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
47 additions
and
4 deletions
+47
-4
official/nlp/albert/run_classifier.py
official/nlp/albert/run_classifier.py
+47
-4
No files found.
official/nlp/albert/run_classifier.py
View file @
12deafa2
...
...
@@ -14,23 +14,61 @@
# ==============================================================================
"""ALBERT classification finetuning runner in tf2.x."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
json
import
os
from
absl
import
app
from
absl
import
flags
from
absl
import
logging
import
tensorflow
as
tf
from
official.nlp.albert
import
configs
as
albert_configs
from
official.nlp.bert
import
bert_models
from
official.nlp.bert
import
run_classifier
as
run_classifier_bert
from
official.utils.misc
import
distribution_utils
FLAGS
=
flags
.
FLAGS
def
predict
(
strategy
,
albert_config
,
input_meta_data
,
predict_input_fn
):
"""Function outputs both the ground truth predictions as .tsv files."""
with
strategy
.
scope
():
classifier_model
=
bert_models
.
classifier_model
(
albert_config
,
input_meta_data
[
'num_labels'
])[
0
]
checkpoint
=
tf
.
train
.
Checkpoint
(
model
=
classifier_model
)
latest_checkpoint_file
=
(
FLAGS
.
predict_checkpoint_path
or
tf
.
train
.
latest_checkpoint
(
FLAGS
.
model_dir
))
assert
latest_checkpoint_file
logging
.
info
(
'Checkpoint file %s found and restoring from '
'checkpoint'
,
latest_checkpoint_file
)
checkpoint
.
restore
(
latest_checkpoint_file
).
assert_existing_objects_matched
()
preds
,
ground_truth
=
run_classifier_bert
.
get_predictions_and_labels
(
strategy
,
classifier_model
,
predict_input_fn
,
return_probs
=
True
)
output_predict_file
=
os
.
path
.
join
(
FLAGS
.
model_dir
,
'test_results.tsv'
)
with
tf
.
io
.
gfile
.
GFile
(
output_predict_file
,
'w'
)
as
writer
:
logging
.
info
(
'***** Predict results *****'
)
for
probabilities
in
preds
:
output_line
=
'
\t
'
.
join
(
str
(
class_probability
)
for
class_probability
in
probabilities
)
+
'
\n
'
writer
.
write
(
output_line
)
ground_truth_labels_file
=
os
.
path
.
join
(
FLAGS
.
model_dir
,
'output_labels.tsv'
)
with
tf
.
io
.
gfile
.
GFile
(
ground_truth_labels_file
,
'w'
)
as
writer
:
logging
.
info
(
'***** Ground truth results *****'
)
for
label
in
ground_truth
:
output_line
=
'
\t
'
.
join
(
str
(
label
))
+
'
\n
'
writer
.
write
(
output_line
)
return
def
main
(
_
):
with
tf
.
io
.
gfile
.
GFile
(
FLAGS
.
input_meta_data_path
,
'rb'
)
as
reader
:
input_meta_data
=
json
.
loads
(
reader
.
read
().
decode
(
'utf-8'
))
...
...
@@ -56,9 +94,14 @@ def main(_):
albert_config
=
albert_configs
.
AlbertConfig
.
from_json_file
(
FLAGS
.
bert_config_file
)
if
FLAGS
.
mode
==
'train_and_eval'
:
run_classifier_bert
.
run_bert
(
strategy
,
input_meta_data
,
albert_config
,
train_input_fn
,
eval_input_fn
)
elif
FLAGS
.
mode
==
'predict'
:
predict
(
strategy
,
albert_config
,
input_meta_data
,
eval_input_fn
)
else
:
raise
ValueError
(
'Unsupported mode is specified: %s'
%
FLAGS
.
mode
)
return
if
__name__
==
'__main__'
:
flags
.
mark_flag_as_required
(
'bert_config_file'
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment