"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "1485bd9c027bced10df2c93fa87bcb7918e52a08"
Commit 0812aee2 authored by Grégory Châtel
Browse files

Fixing problems in convert_examples_to_features.

parent f2b873e9
...@@ -70,20 +70,13 @@ class SwagExample(object): ...@@ -70,20 +70,13 @@ class SwagExample(object):
class InputFeatures(object): class InputFeatures(object):
def __init__(self, def __init__(self,
unique_id,
example_id, example_id,
input_ids, choices_features,
input_mask, label
segment_ids,
label_id
): ):
self.unique_id = unique_id
self.example_id = example_id self.example_id = example_id
self.input_ids = input_ids self.choices_features = choices_features
self.input_mask = input_mask self.label = label
self.segment_ids = segment_ids
self.label_id = label_id
def read_swag_examples(input_file, is_training): def read_swag_examples(input_file, is_training):
input_df = pd.read_csv(input_file) input_df = pd.read_csv(input_file)
...@@ -145,7 +138,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length, ...@@ -145,7 +138,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
# place so that the total length is less than the # place so that the total length is less than the
# specified length. Account for [CLS], [SEP], [SEP] with # specified length. Account for [CLS], [SEP], [SEP] with
# "- 3" # "- 3"
_truncate_seq_pair(context_tokens, ending_tokens, max_seq_length - 3) _truncate_seq_pair(context_tokens_choice, ending_tokens, max_seq_length - 3)
tokens = ["[CLS]"] + context_tokens_choice + ["[SEP]"] + ending_tokens + ["[SEP]"] tokens = ["[CLS]"] + context_tokens_choice + ["[SEP]"] + ending_tokens + ["[SEP]"]
segment_ids = [0] * (len(context_tokens_choice) + 2) + [1] * (len(ending_tokens) + 1) segment_ids = [0] * (len(context_tokens_choice) + 2) + [1] * (len(ending_tokens) + 1)
...@@ -178,7 +171,15 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length, ...@@ -178,7 +171,15 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
if is_training: if is_training:
logger.info(f"label: {label}") logger.info(f"label: {label}")
features.append(
InputFeatures(
example_id = example.swag_id,
choices_features = choices_features,
label = label
)
)
return features
def _truncate_seq_pair(tokens_a, tokens_b, max_length): def _truncate_seq_pair(tokens_a, tokens_b, max_length):
"""Truncates a sequence pair in place to the maximum length.""" """Truncates a sequence pair in place to the maximum length."""
...@@ -206,4 +207,4 @@ if __name__ == "__main__": ...@@ -206,4 +207,4 @@ if __name__ == "__main__":
print("###########################") print("###########################")
print(example) print(example)
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
convert_examples_to_features(examples, tokenizer, max_seq_length, is_training) features = convert_examples_to_features(examples, tokenizer, max_seq_length, is_training)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment