Commit e748d785 authored by stephenwu's avatar stephenwu
Browse files

fixed style issues

parent 790e49e5
......@@ -309,16 +309,16 @@ def write_glue_classification(task,
# Classification.
writer.write('%d\t%s\n' % (index, class_names[prediction]))
def write_superglue_classification(task,
model,
input_file,
output_file,
predict_batch_size,
seq_length,
class_names,
label_type='int',
min_float_value=None,
max_float_value=None):
"""Makes classification predictions for glue and writes to output file.
model,
input_file,
output_file,
predict_batch_size,
seq_length,
class_names,
label_type='int',
min_float_value=None,
max_float_value=None):
"""Makes classification predictions for superglue and writes to output file.
Args:
task: `Task` instance.
......@@ -350,7 +350,6 @@ def write_superglue_classification(task,
include_example_id=True)
predictions = sentence_prediction.predict(task, data_config, model)
with tf.io.gfile.GFile(output_file, 'w') as writer:
for index, prediction in enumerate(predictions):
if label_type == 'int':
......
......@@ -36,8 +36,8 @@ def define_flags():
'run prediction using the model in `model_dir`.')
flags.DEFINE_enum('task_name', None, [
'AX-b', 'CB', 'COPA', 'MULTIRC', 'RTE', 'WiC', 'WSC',
'BoolQ', 'ReCoRD', 'AX-g',
'AX-b', 'CB', 'COPA', 'MULTIRC', 'RTE', 'WiC', 'WSC',
'BoolQ', 'ReCoRD', 'AX-g',
], 'The type of GLUE task.')
flags.DEFINE_string('train_input_path', None,
......@@ -160,4 +160,4 @@ def validate_flags(flags_obj: flags.FlagValues,
_validate_path(flags_obj.model_config_file, 'model_config_file')
logging.info(
'Using the pretrained checkpoint from %s and model_config_file from '
'%s.', flags_obj.init_checkpoint, flags_obj.model_config_file)
\ No newline at end of file
'%s.', flags_obj.init_checkpoint, flags_obj.model_config_file)
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Runs prediction to generate submission files for GLUE tasks."""
"""Runs prediction to generate submission files for SuperGLUE tasks."""
import functools
import json
import os
......@@ -27,14 +27,13 @@ import tensorflow as tf
from official.common import distribute_utils
# Imports registered experiment configs.
from official.common import registry_imports # pylint: disable=unused-import
from official.core import exp_factory
from official.core import task_factory
from official.core import train_lib
from official.core import train_utils
from official.modeling.hyperparams import params_dict
from official.nlp.finetuning import binary_helper
from official.nlp.finetuning.superglue import flags as glue_flags
from official.nlp.finetuning.superglue import flags as superglue_flags
# Device configs.
......@@ -81,13 +80,13 @@ def _override_exp_config_by_file(exp_config, exp_config_files):
def _override_exp_config_by_flags(exp_config, input_meta_data):
"""Overrides an `ExperimentConfig` object by flags."""
if FLAGS.task_name in ('AX-b'):
if FLAGS.task_name in 'AX-b':
override_task_cfg_fn = functools.partial(
binary_helper.override_sentence_prediction_task_config,
num_classes=input_meta_data['num_labels'],
metric_type='matthews_corrcoef')
elif FLAGS.task_name in ('CB', 'COPA', 'RTE', 'WiC', 'WSC', 'BoolQ',
'ReCoRD','AX-g'):
'ReCoRD', 'AX-g'):
override_task_cfg_fn = functools.partial(
binary_helper.override_sentence_prediction_task_config,
num_classes=input_meta_data['num_labels'])
......@@ -152,7 +151,7 @@ def _write_submission_file(task, seq_length):
checkpoint = tf.train.Checkpoint(model=model)
checkpoint.read(ckpt_file).expect_partial()
write_fn = binary_helper.write_glue_classification
write_fn = binary_helper.write_superglue_classification
write_fn_map = {
'RTE':
functools.partial(
......@@ -176,7 +175,7 @@ def main(argv):
if len(argv) > 1:
raise app.UsageError('Too many command-line arguments.')
glue_flags.validate_flags(FLAGS, file_exists_fn=tf.io.gfile.exists)
superglue_flags.validate_flags(FLAGS, file_exists_fn=tf.io.gfile.exists)
gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
distribution_strategy = distribute_utils.get_distribution_strategy(
......@@ -218,7 +217,7 @@ def main(argv):
if __name__ == '__main__':
glue_flags.define_flags()
superglue_flags.define_flags()
flags.mark_flag_as_required('mode')
flags.mark_flag_as_required('task_name')
app.run(main)
\ No newline at end of file
app.run(main)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment