"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "3bd95b0fafbfc00135daca6bdd547a7c6d610665"
Commit 83fdbd60 authored by Grégory Châtel
Browse files

Adding read_swag_examples to load the dataset.

parent 7183cded
...@@ -14,6 +14,9 @@ ...@@ -14,6 +14,9 @@
# limitations under the License. # limitations under the License.
"""BERT finetuning runner.""" """BERT finetuning runner."""
import pandas as pd
class SwagExample(object): class SwagExample(object):
"""A single training/test example for the SWAG dataset.""" """A single training/test example for the SWAG dataset."""
def __init__(self, def __init__(self,
...@@ -53,26 +56,32 @@ class SwagExample(object): ...@@ -53,26 +56,32 @@ class SwagExample(object):
return ', '.join(l) return ', '.join(l)
def read_swag_examples(input_file, is_training):
    """Load SWAG examples from a CSV file.

    Args:
        input_file: Path (or file-like object) passed to ``pd.read_csv``.
            The rows are read through the columns ``fold-ind``, ``sent1``,
            ``sent2`` and ``ending0`` .. ``ending3``.
        is_training: When true, the file must also contain a ``label``
            column and its value is attached to each example; otherwise
            every example is built with ``label=None``.

    Returns:
        A list with one ``SwagExample`` per CSV row, in row order.

    Raises:
        ValueError: If ``is_training`` is true and the file has no
            ``label`` column.
    """
    input_df = pd.read_csv(input_file)

    # Fail up front with a clear message instead of a KeyError on the
    # first row when a training file is missing its labels.
    if is_training and 'label' not in input_df.columns:
        raise ValueError(
            "For training, the input file must contain a label column.")

    examples = [
        SwagExample(
            swag_id=row['fold-ind'],
            context_sentence=row['sent1'],
            start_ending=row['sent2'],
            ending_0=row['ending0'],
            ending_1=row['ending1'],
            ending_2=row['ending2'],
            ending_3=row['ending3'],
            label=row['label'] if is_training else None,
        ) for _, row in input_df.iterrows()
    ]

    return examples
if __name__ == "__main__":
    # Manual smoke test: load the training CSV, report how many examples
    # were read, and print the first five separated by a banner line.
    examples = read_swag_examples('data/train.csv', True)
    print(len(examples))
    for example in examples[:5]:
        print('###########################')
        print(example)
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment