"docs/vscode:/vscode.git/clone" did not exist on "ea9336c0c169497a22c6f6f3ae26e67039b08b6d"
Commit 83fdbd60 authored by Grégory Châtel
Browse files

Adding read_swag_examples to load the dataset.

parent 7183cded
...@@ -14,6 +14,9 @@ ...@@ -14,6 +14,9 @@
# limitations under the License. # limitations under the License.
"""BERT finetuning runner.""" """BERT finetuning runner."""
import pandas as pd
class SwagExample(object): class SwagExample(object):
"""A single training/test example for the SWAG dataset.""" """A single training/test example for the SWAG dataset."""
def __init__(self, def __init__(self,
...@@ -53,26 +56,32 @@ class SwagExample(object): ...@@ -53,26 +56,32 @@ class SwagExample(object):
return ', '.join(l) return ', '.join(l)
if __name__ == "__main__": def read_swag_examples(input_file, is_training):
e = SwagExample( input_df = pd.read_csv(input_file)
3416,
'Members of the procession walk down the street holding small horn brass instruments.',
'A drum line',
'passes by walking down the street playing their instruments.',
'has heard approaching them.',
"arrives and they're outside dancing and asleep.",
'turns the lead singer watches the performance.',
)
print(e)
e = SwagExample( if is_training and 'label' not in input_df.columns:
3416, raise ValueError(
'Members of the procession walk down the street holding small horn brass instruments.', "For training, the input file must contain a label column.")
'A drum line',
'passes by walking down the street playing their instruments.', examples = [
'has heard approaching them.', SwagExample(
"arrives and they're outside dancing and asleep.", swag_id = row['fold-ind'],
'turns the lead singer watches the performance.', context_sentence = row['sent1'],
0 start_ending = row['sent2'],
) ending_0 = row['ending0'],
print(e) ending_1 = row['ending1'],
ending_2 = row['ending2'],
ending_3 = row['ending3'],
label = row['label'] if is_training else None
) for _, row in input_df.iterrows()
]
return examples
if __name__ == "__main__":
examples = read_swag_examples('data/train.csv', True)
print(len(examples))
for example in examples[:5]:
print('###########################')
print(example)
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment