Commit 93cdbaf5 authored by stephenwu's avatar stephenwu
Browse files

fixed minor changes

parent e748d785
...@@ -315,9 +315,7 @@ def write_superglue_classification(task, ...@@ -315,9 +315,7 @@ def write_superglue_classification(task,
predict_batch_size, predict_batch_size,
seq_length, seq_length,
class_names, class_names,
label_type='int', label_type='int'):
min_float_value=None,
max_float_value=None):
"""Makes classification predictions for superglue and writes to output file. """Makes classification predictions for superglue and writes to output file.
Args: Args:
...@@ -329,12 +327,7 @@ def write_superglue_classification(task, ...@@ -329,12 +327,7 @@ def write_superglue_classification(task,
seq_length: Input sequence length. seq_length: Input sequence length.
class_names: List of string class names. class_names: List of string class names.
label_type: String denoting label type ('int', 'float'), defaults to 'int'. label_type: String denoting label type ('int', 'float'), defaults to 'int'.
min_float_value: If set, predictions will be min-clipped to this value (only
for regression when `label_type` is set to 'float'). Defaults to `None`
(no clipping).
max_float_value: If set, predictions will be max-clipped to this value (only
for regression when `label_type` is set to 'float'). Defaults to `None`
(no clipping).
""" """
if label_type not in ('int'): if label_type not in ('int'):
raise ValueError('Unsupported `label_type`. Given: %s, expected `int` or ' raise ValueError('Unsupported `label_type`. Given: %s, expected `int` or '
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Common flags for GLUE finetuning binary.""" """Common flags for SuperGLUE finetuning binary."""
from typing import Callable from typing import Callable
from absl import flags from absl import flags
...@@ -23,7 +23,7 @@ def define_flags(): ...@@ -23,7 +23,7 @@ def define_flags():
"""Defines flags.""" """Defines flags."""
# =========================================================================== # ===========================================================================
# Glue binary flags. # SuperGlue binary flags.
# =========================================================================== # ===========================================================================
flags.DEFINE_enum( flags.DEFINE_enum(
'mode', 'train_eval_and_predict', 'mode', 'train_eval_and_predict',
...@@ -38,7 +38,7 @@ def define_flags(): ...@@ -38,7 +38,7 @@ def define_flags():
flags.DEFINE_enum('task_name', None, [ flags.DEFINE_enum('task_name', None, [
'AX-b', 'CB', 'COPA', 'MULTIRC', 'RTE', 'WiC', 'WSC', 'AX-b', 'CB', 'COPA', 'MULTIRC', 'RTE', 'WiC', 'WSC',
'BoolQ', 'ReCoRD', 'AX-g', 'BoolQ', 'ReCoRD', 'AX-g',
], 'The type of GLUE task.') ], 'The type of SuperGLUE task.')
flags.DEFINE_string('train_input_path', None, flags.DEFINE_string('train_input_path', None,
'The file path to the training data.') 'The file path to the training data.')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment