Commit c45d8ac5 authored by Grégory Châtel
Browse files

Storing the feature of each choice as a dict for readability.

parent 0812aee2
...@@ -73,9 +73,17 @@ class InputFeatures(object): ...@@ -73,9 +73,17 @@ class InputFeatures(object):
example_id, example_id,
choices_features, choices_features,
label label
): ):
self.example_id = example_id self.example_id = example_id
self.choices_features = choices_features self.choices_features = [
{
'input_ids': input_ids,
'input_mask': input_mask,
'segment_ids': segment_ids
}
for _, input_ids, input_mask, segment_ids in choices_features
]
self.label = label self.label = label
def read_swag_examples(input_file, is_training): def read_swag_examples(input_file, is_training):
...@@ -181,6 +189,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length, ...@@ -181,6 +189,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
return features return features
def _truncate_seq_pair(tokens_a, tokens_b, max_length): def _truncate_seq_pair(tokens_a, tokens_b, max_length):
"""Truncates a sequence pair in place to the maximum length.""" """Truncates a sequence pair in place to the maximum length."""
...@@ -207,4 +216,11 @@ if __name__ == "__main__": ...@@ -207,4 +216,11 @@ if __name__ == "__main__":
print("###########################") print("###########################")
print(example) print(example)
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
features = convert_examples_to_features(examples, tokenizer, max_seq_length, is_training) features = convert_examples_to_features(examples[:500], tokenizer, max_seq_length, is_training)
for i in range(10):
choice_feature_list = features[i].choices_features
for choice_idx, choice_feature in enumerate(choice_feature_list):
print(f'choice_idx: {choice_idx}')
print(f'input_ids: {" ".join(map(str, choice_feature["input_ids"]))}')
print(f'input_mask: {" ".join(map(str, choice_feature["input_mask"]))}')
print(f'segment_ids: {" ".join(map(str, choice_feature["segment_ids"]))}')
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment