Commit df34f228 authored by Grégory Châtel
Browse files

Removing the dependency to pandas and using the csv module to load data.

parent 0876b77f
......@@ -14,13 +14,12 @@
# limitations under the License.
"""BERT finetuning runner."""
import pandas as pd
import logging
import os
import argparse
import random
from tqdm import tqdm, trange
import csv
import numpy as np
import torch
......@@ -100,25 +99,28 @@ class InputFeatures(object):
def read_swag_examples(input_file, is_training):
    """Load SWAG examples from a CSV file using the standard csv module.

    The first row of the file is treated as the header. Fields are looked
    up by column name ('fold-ind', 'sent1', 'sent2', 'ending0'..'ending3'
    and, for training, 'label') rather than by position, so the result does
    not depend on the order of the columns or on extra columns being present.

    Args:
        input_file: path to the SWAG CSV file.
        is_training: when true, the file must provide a 'label' column and
            each label is converted to int; otherwise every example gets
            label None.

    Returns:
        A list with one SwagExample per data row (empty when the file has
        no data rows).

    Raises:
        ValueError: if is_training is true and the header has no 'label'
            column (this includes a completely empty file).
        KeyError: if a data row is read from a file whose header lacks one
            of the other expected columns.
    """
    # newline='' lets the csv module handle newlines embedded in quoted
    # fields itself; the explicit encoding avoids depending on the locale.
    with open(input_file, 'r', encoding='utf-8', newline='') as f:
        reader = csv.DictReader(f)
        # fieldnames is None when the file is empty.
        header = reader.fieldnames or []

        if is_training and 'label' not in header:
            raise ValueError(
                "For training, the input file must contain a label column."
            )

        # The reader is lazy, so the examples must be built while the file
        # is still open. Blank lines are skipped by DictReader.
        examples = [
            SwagExample(
                swag_id=row['fold-ind'],
                context_sentence=row['sent1'],
                # In the swag dataset, the common beginning of each
                # choice is stored in "sent2".
                start_ending=row['sent2'],
                ending_0=row['ending0'],
                ending_1=row['ending1'],
                ending_2=row['ending2'],
                ending_3=row['ending3'],
                label=int(row['label']) if is_training else None,
            )
            for row in reader
        ]

    return examples
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment