Unverified Commit 8e36da7a authored by Thomas Wolf's avatar Thomas Wolf Committed by GitHub
Browse files

Merge pull request #347 from jplehmann/feature/sst2-processor

Processor for SST-2 task
parents 21c88a07 0f96d4b1
...@@ -196,6 +196,37 @@ class ColaProcessor(DataProcessor): ...@@ -196,6 +196,37 @@ class ColaProcessor(DataProcessor):
return examples return examples
class Sst2Processor(DataProcessor):
"""Processor for the SST-2 data set (GLUE version)."""
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
def get_labels(self):
"""See base class."""
return ["0", "1"]
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
examples = []
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, i)
text_a = line[0]
label = line[1]
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
return examples
def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer): def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer):
"""Loads a data file into a list of `InputBatch`s.""" """Loads a data file into a list of `InputBatch`s."""
...@@ -401,10 +432,12 @@ def main(): ...@@ -401,10 +432,12 @@ def main():
"cola": ColaProcessor, "cola": ColaProcessor,
"mnli": MnliProcessor, "mnli": MnliProcessor,
"mrpc": MrpcProcessor, "mrpc": MrpcProcessor,
"sst-2": Sst2Processor,
} }
num_labels_task = { num_labels_task = {
"cola": 2, "cola": 2,
"sst-2": 2,
"mnli": 3, "mnli": 3,
"mrpc": 2, "mrpc": 2,
} }
...@@ -597,7 +630,7 @@ def main(): ...@@ -597,7 +630,7 @@ def main():
model.eval() model.eval()
eval_loss, eval_accuracy = 0, 0 eval_loss, eval_accuracy = 0, 0
nb_eval_steps, nb_eval_examples = 0, 0 nb_eval_steps, nb_eval_examples = 0, 0
for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_dataloader, desc="Evaluating"): for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_dataloader, desc="Evaluating"):
input_ids = input_ids.to(device) input_ids = input_ids.to(device)
input_mask = input_mask.to(device) input_mask = input_mask.to(device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment