Commit 1c933358 authored by thomwolf
Browse files

formatting

parent e25b6fe3
...@@ -18,8 +18,9 @@ ...@@ -18,8 +18,9 @@
import logging import logging
import os import os
from utils_hans import DataProcessor, InputExample, InputFeatures
from transformers.file_utils import is_tf_available from transformers.file_utils import is_tf_available
from utils_hans import DataProcessor, InputExample, InputFeatures
if is_tf_available(): if is_tf_available():
import tensorflow as tf import tensorflow as tf
...@@ -27,7 +28,9 @@ if is_tf_available(): ...@@ -27,7 +28,9 @@ if is_tf_available():
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def hans_convert_examples_to_features(examples, tokenizer, def hans_convert_examples_to_features(
examples,
tokenizer,
max_length=512, max_length=512,
task=None, task=None,
label_list=None, label_list=None,
...@@ -35,7 +38,8 @@ def hans_convert_examples_to_features(examples, tokenizer, ...@@ -35,7 +38,8 @@ def hans_convert_examples_to_features(examples, tokenizer,
pad_on_left=False, pad_on_left=False,
pad_token=0, pad_token=0,
pad_token_segment_id=0, pad_token_segment_id=0,
mask_padding_with_zero=True): mask_padding_with_zero=True,
):
""" """
Loads a data file into a list of ``InputFeatures`` Loads a data file into a list of ``InputFeatures``
...@@ -82,12 +86,7 @@ def hans_convert_examples_to_features(examples, tokenizer, ...@@ -82,12 +86,7 @@ def hans_convert_examples_to_features(examples, tokenizer,
example = processor.get_example_from_tensor_dict(example) example = processor.get_example_from_tensor_dict(example)
example = processor.tfds_map(example) example = processor.tfds_map(example)
inputs = tokenizer.encode_plus( inputs = tokenizer.encode_plus(example.text_a, example.text_b, add_special_tokens=True, max_length=max_length,)
example.text_a,
example.text_b,
add_special_tokens=True,
max_length=max_length,
)
input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"] input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"]
# The mask has 1 for real tokens and 0 for padding tokens. Only real # The mask has 1 for real tokens and 0 for padding tokens. Only real
...@@ -106,8 +105,12 @@ def hans_convert_examples_to_features(examples, tokenizer, ...@@ -106,8 +105,12 @@ def hans_convert_examples_to_features(examples, tokenizer,
token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length) token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length)
assert len(input_ids) == max_length, "Error with input length {} vs {}".format(len(input_ids), max_length) assert len(input_ids) == max_length, "Error with input length {} vs {}".format(len(input_ids), max_length)
assert len(attention_mask) == max_length, "Error with input length {} vs {}".format(len(attention_mask), max_length) assert len(attention_mask) == max_length, "Error with input length {} vs {}".format(
assert len(token_type_ids) == max_length, "Error with input length {} vs {}".format(len(token_type_ids), max_length) len(attention_mask), max_length
)
assert len(token_type_ids) == max_length, "Error with input length {} vs {}".format(
len(token_type_ids), max_length
)
if output_mode == "classification": if output_mode == "classification":
label = label_map[example.label] if example.label in label_map else 0 label = label_map[example.label] if example.label in label_map else 0
...@@ -128,28 +131,40 @@ def hans_convert_examples_to_features(examples, tokenizer, ...@@ -128,28 +131,40 @@ def hans_convert_examples_to_features(examples, tokenizer,
logger.info("label: %s (id = %d)" % (example.label, label)) logger.info("label: %s (id = %d)" % (example.label, label))
features.append( features.append(
InputFeatures(input_ids=input_ids, InputFeatures(
input_ids=input_ids,
attention_mask=attention_mask, attention_mask=attention_mask,
token_type_ids=token_type_ids, token_type_ids=token_type_ids,
label=label, pairID=pairID)) label=label,
pairID=pairID,
)
)
if is_tf_available() and is_tf_dataset: if is_tf_available() and is_tf_dataset:
def gen(): def gen():
for ex in features: for ex in features:
yield ({'input_ids': ex.input_ids, yield (
'attention_mask': ex.attention_mask, {
'token_type_ids': ex.token_type_ids}, "input_ids": ex.input_ids,
ex.label) "attention_mask": ex.attention_mask,
"token_type_ids": ex.token_type_ids,
return tf.data.Dataset.from_generator(gen, },
({'input_ids': tf.int32, ex.label,
'attention_mask': tf.int32, )
'token_type_ids': tf.int32},
tf.int64), return tf.data.Dataset.from_generator(
({'input_ids': tf.TensorShape([None]), gen,
'attention_mask': tf.TensorShape([None]), ({"input_ids": tf.int32, "attention_mask": tf.int32, "token_type_ids": tf.int32}, tf.int64),
'token_type_ids': tf.TensorShape([None])}, (
tf.TensorShape([]))) {
"input_ids": tf.TensorShape([None]),
"attention_mask": tf.TensorShape([None]),
"token_type_ids": tf.TensorShape([None]),
},
tf.TensorShape([]),
),
)
return features return features
...@@ -159,21 +174,20 @@ class HansProcessor(DataProcessor): ...@@ -159,21 +174,20 @@ class HansProcessor(DataProcessor):
def get_example_from_tensor_dict(self, tensor_dict): def get_example_from_tensor_dict(self, tensor_dict):
"""See base class.""" """See base class."""
return InputExample(tensor_dict['idx'].numpy(), return InputExample(
tensor_dict['premise'].numpy().decode('utf-8'), tensor_dict["idx"].numpy(),
tensor_dict['hypothesis'].numpy().decode('utf-8'), tensor_dict["premise"].numpy().decode("utf-8"),
str(tensor_dict['label'].numpy())) tensor_dict["hypothesis"].numpy().decode("utf-8"),
str(tensor_dict["label"].numpy()),
)
def get_train_examples(self, data_dir): def get_train_examples(self, data_dir):
"""See base class.""" """See base class."""
return self._create_examples( return self._create_examples(self._read_tsv(os.path.join(data_dir, "heuristics_train_set.txt")), "train")
self._read_tsv(os.path.join(data_dir, "heuristics_train_set.txt")), "train")
def get_dev_examples(self, data_dir): def get_dev_examples(self, data_dir):
"""See base class.""" """See base class."""
return self._create_examples( return self._create_examples(self._read_tsv(os.path.join(data_dir, "heuristics_evaluation_set.txt")), "dev")
self._read_tsv(os.path.join(data_dir, "heuristics_evaluation_set.txt")),
"dev")
def get_labels(self): def get_labels(self):
"""See base class.""" """See base class."""
...@@ -188,14 +202,12 @@ class HansProcessor(DataProcessor): ...@@ -188,14 +202,12 @@ class HansProcessor(DataProcessor):
guid = "%s-%s" % (set_type, line[0]) guid = "%s-%s" % (set_type, line[0])
text_a = line[5] text_a = line[5]
text_b = line[6] text_b = line[6]
pairID = line[7][2:] if line[7].startswith('ex') else line[7] pairID = line[7][2:] if line[7].startswith("ex") else line[7]
label = line[-1] label = line[-1]
examples.append( examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label, pairID=pairID))
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label, pairID=pairID))
return examples return examples
glue_tasks_num_labels = { glue_tasks_num_labels = {
"hans": 3, "hans": 3,
} }
...@@ -207,4 +219,3 @@ glue_processors = { ...@@ -207,4 +219,3 @@ glue_processors = {
glue_output_modes = { glue_output_modes = {
"hans": "classification", "hans": "classification",
} }
This diff is collapsed.
...@@ -14,10 +14,11 @@ ...@@ -14,10 +14,11 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import csv
import sys
import copy import copy
import csv
import json import json
import sys
class InputExample(object): class InputExample(object):
""" """
...@@ -32,6 +33,7 @@ class InputExample(object): ...@@ -32,6 +33,7 @@ class InputExample(object):
label: (Optional) string. The label of the example. This should be label: (Optional) string. The label of the example. This should be
specified for train and dev examples, but not for test examples. specified for train and dev examples, but not for test examples.
""" """
def __init__(self, guid, text_a, text_b=None, label=None, pairID=None): def __init__(self, guid, text_a, text_b=None, label=None, pairID=None):
self.guid = guid self.guid = guid
self.text_a = text_a self.text_a = text_a
...@@ -117,6 +119,6 @@ class DataProcessor(object): ...@@ -117,6 +119,6 @@ class DataProcessor(object):
lines = [] lines = []
for line in reader: for line in reader:
if sys.version_info[0] == 2: if sys.version_info[0] == 2:
line = list(unicode(cell, 'utf-8') for cell in line) line = list(unicode(cell, "utf-8") for cell in line)
lines.append(line) lines.append(line)
return lines return lines
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment