Commit e748d785 authored by stephenwu's avatar stephenwu
Browse files

fixed style issues

parent 790e49e5
...@@ -309,16 +309,16 @@ def write_glue_classification(task, ...@@ -309,16 +309,16 @@ def write_glue_classification(task,
# Classification. # Classification.
writer.write('%d\t%s\n' % (index, class_names[prediction])) writer.write('%d\t%s\n' % (index, class_names[prediction]))
def write_superglue_classification(task, def write_superglue_classification(task,
model, model,
input_file, input_file,
output_file, output_file,
predict_batch_size, predict_batch_size,
seq_length, seq_length,
class_names, class_names,
label_type='int', label_type='int',
min_float_value=None, min_float_value=None,
max_float_value=None): max_float_value=None):
"""Makes classification predictions for glue and writes to output file. """Makes classification predictions for superglue and writes to output file.
Args: Args:
task: `Task` instance. task: `Task` instance.
...@@ -350,7 +350,6 @@ def write_superglue_classification(task, ...@@ -350,7 +350,6 @@ def write_superglue_classification(task,
include_example_id=True) include_example_id=True)
predictions = sentence_prediction.predict(task, data_config, model) predictions = sentence_prediction.predict(task, data_config, model)
with tf.io.gfile.GFile(output_file, 'w') as writer: with tf.io.gfile.GFile(output_file, 'w') as writer:
for index, prediction in enumerate(predictions): for index, prediction in enumerate(predictions):
if label_type == 'int': if label_type == 'int':
......
...@@ -36,8 +36,8 @@ def define_flags(): ...@@ -36,8 +36,8 @@ def define_flags():
'run prediction using the model in `model_dir`.') 'run prediction using the model in `model_dir`.')
flags.DEFINE_enum('task_name', None, [ flags.DEFINE_enum('task_name', None, [
'AX-b', 'CB', 'COPA', 'MULTIRC', 'RTE', 'WiC', 'WSC', 'AX-b', 'CB', 'COPA', 'MULTIRC', 'RTE', 'WiC', 'WSC',
'BoolQ', 'ReCoRD', 'AX-g', 'BoolQ', 'ReCoRD', 'AX-g',
], 'The type of GLUE task.') ], 'The type of GLUE task.')
flags.DEFINE_string('train_input_path', None, flags.DEFINE_string('train_input_path', None,
...@@ -160,4 +160,4 @@ def validate_flags(flags_obj: flags.FlagValues, ...@@ -160,4 +160,4 @@ def validate_flags(flags_obj: flags.FlagValues,
_validate_path(flags_obj.model_config_file, 'model_config_file') _validate_path(flags_obj.model_config_file, 'model_config_file')
logging.info( logging.info(
'Using the pretrained checkpoint from %s and model_config_file from ' 'Using the pretrained checkpoint from %s and model_config_file from '
'%s.', flags_obj.init_checkpoint, flags_obj.model_config_file) '%s.', flags_obj.init_checkpoint, flags_obj.model_config_file)
\ No newline at end of file
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Runs prediction to generate submission files for GLUE tasks.""" """Runs prediction to generate submission files for SuperGLUE tasks."""
import functools import functools
import json import json
import os import os
...@@ -27,14 +27,13 @@ import tensorflow as tf ...@@ -27,14 +27,13 @@ import tensorflow as tf
from official.common import distribute_utils from official.common import distribute_utils
# Imports registered experiment configs. # Imports registered experiment configs.
from official.common import registry_imports # pylint: disable=unused-import
from official.core import exp_factory from official.core import exp_factory
from official.core import task_factory from official.core import task_factory
from official.core import train_lib from official.core import train_lib
from official.core import train_utils from official.core import train_utils
from official.modeling.hyperparams import params_dict from official.modeling.hyperparams import params_dict
from official.nlp.finetuning import binary_helper from official.nlp.finetuning import binary_helper
from official.nlp.finetuning.superglue import flags as glue_flags from official.nlp.finetuning.superglue import flags as superglue_flags
# Device configs. # Device configs.
...@@ -81,13 +80,13 @@ def _override_exp_config_by_file(exp_config, exp_config_files): ...@@ -81,13 +80,13 @@ def _override_exp_config_by_file(exp_config, exp_config_files):
def _override_exp_config_by_flags(exp_config, input_meta_data): def _override_exp_config_by_flags(exp_config, input_meta_data):
"""Overrides an `ExperimentConfig` object by flags.""" """Overrides an `ExperimentConfig` object by flags."""
if FLAGS.task_name in ('AX-b'): if FLAGS.task_name in 'AX-b':
override_task_cfg_fn = functools.partial( override_task_cfg_fn = functools.partial(
binary_helper.override_sentence_prediction_task_config, binary_helper.override_sentence_prediction_task_config,
num_classes=input_meta_data['num_labels'], num_classes=input_meta_data['num_labels'],
metric_type='matthews_corrcoef') metric_type='matthews_corrcoef')
elif FLAGS.task_name in ('CB', 'COPA', 'RTE', 'WiC', 'WSC', 'BoolQ', elif FLAGS.task_name in ('CB', 'COPA', 'RTE', 'WiC', 'WSC', 'BoolQ',
'ReCoRD','AX-g'): 'ReCoRD', 'AX-g'):
override_task_cfg_fn = functools.partial( override_task_cfg_fn = functools.partial(
binary_helper.override_sentence_prediction_task_config, binary_helper.override_sentence_prediction_task_config,
num_classes=input_meta_data['num_labels']) num_classes=input_meta_data['num_labels'])
...@@ -152,7 +151,7 @@ def _write_submission_file(task, seq_length): ...@@ -152,7 +151,7 @@ def _write_submission_file(task, seq_length):
checkpoint = tf.train.Checkpoint(model=model) checkpoint = tf.train.Checkpoint(model=model)
checkpoint.read(ckpt_file).expect_partial() checkpoint.read(ckpt_file).expect_partial()
write_fn = binary_helper.write_glue_classification write_fn = binary_helper.write_superglue_classification
write_fn_map = { write_fn_map = {
'RTE': 'RTE':
functools.partial( functools.partial(
...@@ -176,7 +175,7 @@ def main(argv): ...@@ -176,7 +175,7 @@ def main(argv):
if len(argv) > 1: if len(argv) > 1:
raise app.UsageError('Too many command-line arguments.') raise app.UsageError('Too many command-line arguments.')
glue_flags.validate_flags(FLAGS, file_exists_fn=tf.io.gfile.exists) superglue_flags.validate_flags(FLAGS, file_exists_fn=tf.io.gfile.exists)
gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params) gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
distribution_strategy = distribute_utils.get_distribution_strategy( distribution_strategy = distribute_utils.get_distribution_strategy(
...@@ -218,7 +217,7 @@ def main(argv): ...@@ -218,7 +217,7 @@ def main(argv):
if __name__ == '__main__': if __name__ == '__main__':
glue_flags.define_flags() superglue_flags.define_flags()
flags.mark_flag_as_required('mode') flags.mark_flag_as_required('mode')
flags.mark_flag_as_required('task_name') flags.mark_flag_as_required('task_name')
app.run(main) app.run(main)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment