"src/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "a8409aba767c4c3737518d99de35813b0a194880"
Unverified Commit 8e36da7a authored by Thomas Wolf's avatar Thomas Wolf Committed by GitHub
Browse files

Merge pull request #347 from jplehmann/feature/sst2-processor

Processor for SST-2 task
parents 21c88a07 0f96d4b1
......@@ -196,6 +196,37 @@ class ColaProcessor(DataProcessor):
return examples
class Sst2Processor(DataProcessor):
"""Processor for the SST-2 data set (GLUE version)."""
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
def get_labels(self):
"""See base class."""
return ["0", "1"]
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
examples = []
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, i)
text_a = line[0]
label = line[1]
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
return examples
def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer):
"""Loads a data file into a list of `InputBatch`s."""
......@@ -401,10 +432,12 @@ def main():
"cola": ColaProcessor,
"mnli": MnliProcessor,
"mrpc": MrpcProcessor,
"sst-2": Sst2Processor,
}
num_labels_task = {
"cola": 2,
"sst-2": 2,
"mnli": 3,
"mrpc": 2,
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment