Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
e748d785
Commit
e748d785
authored
May 21, 2021
by
stephenwu
Browse files
fixed style issues
parent
790e49e5
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
21 additions
and
23 deletions
+21
-23
official/nlp/finetuning/binary_helper.py
official/nlp/finetuning/binary_helper.py
+10
-11
official/nlp/finetuning/superglue/flags.py
official/nlp/finetuning/superglue/flags.py
+3
-3
official/nlp/finetuning/superglue/run_superglue.py
official/nlp/finetuning/superglue/run_superglue.py
+8
-9
No files found.
official/nlp/finetuning/binary_helper.py
View file @
e748d785
...
@@ -309,16 +309,16 @@ def write_glue_classification(task,
...
@@ -309,16 +309,16 @@ def write_glue_classification(task,
# Classification.
# Classification.
writer
.
write
(
'%d
\t
%s
\n
'
%
(
index
,
class_names
[
prediction
]))
writer
.
write
(
'%d
\t
%s
\n
'
%
(
index
,
class_names
[
prediction
]))
def
write_superglue_classification
(
task
,
def
write_superglue_classification
(
task
,
model
,
model
,
input_file
,
input_file
,
output_file
,
output_file
,
predict_batch_size
,
predict_batch_size
,
seq_length
,
seq_length
,
class_names
,
class_names
,
label_type
=
'int'
,
label_type
=
'int'
,
min_float_value
=
None
,
min_float_value
=
None
,
max_float_value
=
None
):
max_float_value
=
None
):
"""Makes classification predictions for glue and writes to output file.
"""Makes classification predictions for
super
glue and writes to output file.
Args:
Args:
task: `Task` instance.
task: `Task` instance.
...
@@ -350,7 +350,6 @@ def write_superglue_classification(task,
...
@@ -350,7 +350,6 @@ def write_superglue_classification(task,
include_example_id
=
True
)
include_example_id
=
True
)
predictions
=
sentence_prediction
.
predict
(
task
,
data_config
,
model
)
predictions
=
sentence_prediction
.
predict
(
task
,
data_config
,
model
)
with
tf
.
io
.
gfile
.
GFile
(
output_file
,
'w'
)
as
writer
:
with
tf
.
io
.
gfile
.
GFile
(
output_file
,
'w'
)
as
writer
:
for
index
,
prediction
in
enumerate
(
predictions
):
for
index
,
prediction
in
enumerate
(
predictions
):
if
label_type
==
'int'
:
if
label_type
==
'int'
:
...
...
official/nlp/finetuning/superglue/flags.py
View file @
e748d785
...
@@ -36,8 +36,8 @@ def define_flags():
...
@@ -36,8 +36,8 @@ def define_flags():
'run prediction using the model in `model_dir`.'
)
'run prediction using the model in `model_dir`.'
)
flags
.
DEFINE_enum
(
'task_name'
,
None
,
[
flags
.
DEFINE_enum
(
'task_name'
,
None
,
[
'AX-b'
,
'CB'
,
'COPA'
,
'MULTIRC'
,
'RTE'
,
'WiC'
,
'WSC'
,
'AX-b'
,
'CB'
,
'COPA'
,
'MULTIRC'
,
'RTE'
,
'WiC'
,
'WSC'
,
'BoolQ'
,
'ReCoRD'
,
'AX-g'
,
'BoolQ'
,
'ReCoRD'
,
'AX-g'
,
],
'The type of GLUE task.'
)
],
'The type of GLUE task.'
)
flags
.
DEFINE_string
(
'train_input_path'
,
None
,
flags
.
DEFINE_string
(
'train_input_path'
,
None
,
...
@@ -160,4 +160,4 @@ def validate_flags(flags_obj: flags.FlagValues,
...
@@ -160,4 +160,4 @@ def validate_flags(flags_obj: flags.FlagValues,
_validate_path
(
flags_obj
.
model_config_file
,
'model_config_file'
)
_validate_path
(
flags_obj
.
model_config_file
,
'model_config_file'
)
logging
.
info
(
logging
.
info
(
'Using the pretrained checkpoint from %s and model_config_file from '
'Using the pretrained checkpoint from %s and model_config_file from '
'%s.'
,
flags_obj
.
init_checkpoint
,
flags_obj
.
model_config_file
)
'%s.'
,
flags_obj
.
init_checkpoint
,
flags_obj
.
model_config_file
)
\ No newline at end of file
official/nlp/finetuning/superglue/run_superglue.py
View file @
e748d785
...
@@ -12,7 +12,7 @@
...
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
"""Runs prediction to generate submission files for GLUE tasks."""
"""Runs prediction to generate submission files for
Super
GLUE tasks."""
import
functools
import
functools
import
json
import
json
import
os
import
os
...
@@ -27,14 +27,13 @@ import tensorflow as tf
...
@@ -27,14 +27,13 @@ import tensorflow as tf
from
official.common
import
distribute_utils
from
official.common
import
distribute_utils
# Imports registered experiment configs.
# Imports registered experiment configs.
from
official.common
import
registry_imports
# pylint: disable=unused-import
from
official.core
import
exp_factory
from
official.core
import
exp_factory
from
official.core
import
task_factory
from
official.core
import
task_factory
from
official.core
import
train_lib
from
official.core
import
train_lib
from
official.core
import
train_utils
from
official.core
import
train_utils
from
official.modeling.hyperparams
import
params_dict
from
official.modeling.hyperparams
import
params_dict
from
official.nlp.finetuning
import
binary_helper
from
official.nlp.finetuning
import
binary_helper
from
official.nlp.finetuning.superglue
import
flags
as
glue_flags
from
official.nlp.finetuning.superglue
import
flags
as
super
glue_flags
# Device configs.
# Device configs.
...
@@ -81,13 +80,13 @@ def _override_exp_config_by_file(exp_config, exp_config_files):
...
@@ -81,13 +80,13 @@ def _override_exp_config_by_file(exp_config, exp_config_files):
def
_override_exp_config_by_flags
(
exp_config
,
input_meta_data
):
def
_override_exp_config_by_flags
(
exp_config
,
input_meta_data
):
"""Overrides an `ExperimentConfig` object by flags."""
"""Overrides an `ExperimentConfig` object by flags."""
if
FLAGS
.
task_name
in
(
'AX-b'
)
:
if
FLAGS
.
task_name
in
'AX-b'
:
override_task_cfg_fn
=
functools
.
partial
(
override_task_cfg_fn
=
functools
.
partial
(
binary_helper
.
override_sentence_prediction_task_config
,
binary_helper
.
override_sentence_prediction_task_config
,
num_classes
=
input_meta_data
[
'num_labels'
],
num_classes
=
input_meta_data
[
'num_labels'
],
metric_type
=
'matthews_corrcoef'
)
metric_type
=
'matthews_corrcoef'
)
elif
FLAGS
.
task_name
in
(
'CB'
,
'COPA'
,
'RTE'
,
'WiC'
,
'WSC'
,
'BoolQ'
,
elif
FLAGS
.
task_name
in
(
'CB'
,
'COPA'
,
'RTE'
,
'WiC'
,
'WSC'
,
'BoolQ'
,
'ReCoRD'
,
'AX-g'
):
'ReCoRD'
,
'AX-g'
):
override_task_cfg_fn
=
functools
.
partial
(
override_task_cfg_fn
=
functools
.
partial
(
binary_helper
.
override_sentence_prediction_task_config
,
binary_helper
.
override_sentence_prediction_task_config
,
num_classes
=
input_meta_data
[
'num_labels'
])
num_classes
=
input_meta_data
[
'num_labels'
])
...
@@ -152,7 +151,7 @@ def _write_submission_file(task, seq_length):
...
@@ -152,7 +151,7 @@ def _write_submission_file(task, seq_length):
checkpoint
=
tf
.
train
.
Checkpoint
(
model
=
model
)
checkpoint
=
tf
.
train
.
Checkpoint
(
model
=
model
)
checkpoint
.
read
(
ckpt_file
).
expect_partial
()
checkpoint
.
read
(
ckpt_file
).
expect_partial
()
write_fn
=
binary_helper
.
write_glue_classification
write_fn
=
binary_helper
.
write_
super
glue_classification
write_fn_map
=
{
write_fn_map
=
{
'RTE'
:
'RTE'
:
functools
.
partial
(
functools
.
partial
(
...
@@ -176,7 +175,7 @@ def main(argv):
...
@@ -176,7 +175,7 @@ def main(argv):
if
len
(
argv
)
>
1
:
if
len
(
argv
)
>
1
:
raise
app
.
UsageError
(
'Too many command-line arguments.'
)
raise
app
.
UsageError
(
'Too many command-line arguments.'
)
glue_flags
.
validate_flags
(
FLAGS
,
file_exists_fn
=
tf
.
io
.
gfile
.
exists
)
super
glue_flags
.
validate_flags
(
FLAGS
,
file_exists_fn
=
tf
.
io
.
gfile
.
exists
)
gin
.
parse_config_files_and_bindings
(
FLAGS
.
gin_file
,
FLAGS
.
gin_params
)
gin
.
parse_config_files_and_bindings
(
FLAGS
.
gin_file
,
FLAGS
.
gin_params
)
distribution_strategy
=
distribute_utils
.
get_distribution_strategy
(
distribution_strategy
=
distribute_utils
.
get_distribution_strategy
(
...
@@ -218,7 +217,7 @@ def main(argv):
...
@@ -218,7 +217,7 @@ def main(argv):
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
glue_flags
.
define_flags
()
super
glue_flags
.
define_flags
()
flags
.
mark_flag_as_required
(
'mode'
)
flags
.
mark_flag_as_required
(
'mode'
)
flags
.
mark_flag_as_required
(
'task_name'
)
flags
.
mark_flag_as_required
(
'task_name'
)
app
.
run
(
main
)
app
.
run
(
main
)
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment