Commit 6b0da96b authored by thomwolf's avatar thomwolf
Browse files

clean up

parent 834b485b
......@@ -69,7 +69,7 @@ class InputFeatures(object):
self.input_mask = input_mask
self.segment_ids = segment_ids
self.label_id = label_id
class DataProcessor(object):
"""Base class for data converters for sequence classification data sets."""
......@@ -95,8 +95,8 @@ class DataProcessor(object):
for line in reader:
lines.append(line)
return lines
class MrpcProcessor(DataProcessor):
"""Processor for the MRPC data set (GLUE version)."""
......@@ -190,10 +190,9 @@ class ColaProcessor(DataProcessor):
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
return examples
def convert_examples_to_features(examples, label_list, max_seq_length,
tokenizer):
def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer):
"""Loads a data file into a list of `InputBatch`s."""
label_map = {}
......@@ -380,7 +379,7 @@ def main():
parser.add_argument("--do_lower_case",
default=False,
action='store_true',
help="Whether to lower case the input text. Should be True for uncased models and False for cased models.")
help="Whether to lower case the input text. True for uncased models, False for cased models.")
parser.add_argument("--max_seq_length",
default=128,
type=int,
......@@ -424,6 +423,10 @@ def main():
default=False,
action='store_true',
help="Whether not to use CUDA when available")
parser.add_argument("--accumulate_gradients",
type=int,
default=1,
help="Number of steps to accumulate gradient on (divide the single step batch_size)")
parser.add_argument("--local_rank",
type=int,
default=-1,
......@@ -448,12 +451,12 @@ def main():
n_gpu = 1
# print("Initializing the distributed backend: NCCL")
print("device", device, "n_gpu", n_gpu)
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if n_gpu>0: torch.cuda.manual_seed_all(args.seed)
if not args.do_train and not args.do_eval:
raise ValueError("At least one of `do_train` or `do_eval` must be True.")
......
......@@ -18,15 +18,15 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import six
import argparse
import collections
import logging
import json
import math
import os
from tqdm import tqdm, trange
import six
import random
from tqdm import tqdm, trange
import numpy as np
import torch
......
Markdown is supported
0% Loading or Attach a file by drag & drop or click to upload.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment