Commit e748d785 authored by stephenwu's avatar stephenwu
Browse files

fixed style issues

parent 790e49e5
......@@ -318,7 +318,7 @@ def write_superglue_classification(task,
label_type='int',
min_float_value=None,
max_float_value=None):
"""Makes classification predictions for glue and writes to output file.
"""Makes classification predictions for superglue and writes to output file.
Args:
task: `Task` instance.
......@@ -350,7 +350,6 @@ def write_superglue_classification(task,
include_example_id=True)
predictions = sentence_prediction.predict(task, data_config, model)
with tf.io.gfile.GFile(output_file, 'w') as writer:
for index, prediction in enumerate(predictions):
if label_type == 'int':
......
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Runs prediction to generate submission files for GLUE tasks."""
"""Runs prediction to generate submission files for SuperGLUE tasks."""
import functools
import json
import os
......@@ -27,14 +27,13 @@ import tensorflow as tf
from official.common import distribute_utils
# Imports registered experiment configs.
from official.common import registry_imports # pylint: disable=unused-import
from official.core import exp_factory
from official.core import task_factory
from official.core import train_lib
from official.core import train_utils
from official.modeling.hyperparams import params_dict
from official.nlp.finetuning import binary_helper
from official.nlp.finetuning.superglue import flags as glue_flags
from official.nlp.finetuning.superglue import flags as superglue_flags
# Device configs.
......@@ -81,13 +80,13 @@ def _override_exp_config_by_file(exp_config, exp_config_files):
def _override_exp_config_by_flags(exp_config, input_meta_data):
"""Overrides an `ExperimentConfig` object by flags."""
if FLAGS.task_name in ('AX-b'):
if FLAGS.task_name in 'AX-b':
override_task_cfg_fn = functools.partial(
binary_helper.override_sentence_prediction_task_config,
num_classes=input_meta_data['num_labels'],
metric_type='matthews_corrcoef')
elif FLAGS.task_name in ('CB', 'COPA', 'RTE', 'WiC', 'WSC', 'BoolQ',
'ReCoRD','AX-g'):
'ReCoRD', 'AX-g'):
override_task_cfg_fn = functools.partial(
binary_helper.override_sentence_prediction_task_config,
num_classes=input_meta_data['num_labels'])
......@@ -152,7 +151,7 @@ def _write_submission_file(task, seq_length):
checkpoint = tf.train.Checkpoint(model=model)
checkpoint.read(ckpt_file).expect_partial()
write_fn = binary_helper.write_glue_classification
write_fn = binary_helper.write_superglue_classification
write_fn_map = {
'RTE':
functools.partial(
......@@ -176,7 +175,7 @@ def main(argv):
if len(argv) > 1:
raise app.UsageError('Too many command-line arguments.')
glue_flags.validate_flags(FLAGS, file_exists_fn=tf.io.gfile.exists)
superglue_flags.validate_flags(FLAGS, file_exists_fn=tf.io.gfile.exists)
gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
distribution_strategy = distribute_utils.get_distribution_strategy(
......@@ -218,7 +217,7 @@ def main(argv):
if __name__ == '__main__':
glue_flags.define_flags()
superglue_flags.define_flags()
flags.mark_flag_as_required('mode')
flags.mark_flag_as_required('task_name')
app.run(main)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment