Commit 6b0da96b authored by thomwolf
Browse files

clean up

parent 834b485b
...@@ -69,7 +69,7 @@ class InputFeatures(object): ...@@ -69,7 +69,7 @@ class InputFeatures(object):
self.input_mask = input_mask self.input_mask = input_mask
self.segment_ids = segment_ids self.segment_ids = segment_ids
self.label_id = label_id self.label_id = label_id
class DataProcessor(object): class DataProcessor(object):
"""Base class for data converters for sequence classification data sets.""" """Base class for data converters for sequence classification data sets."""
...@@ -95,8 +95,8 @@ class DataProcessor(object): ...@@ -95,8 +95,8 @@ class DataProcessor(object):
for line in reader: for line in reader:
lines.append(line) lines.append(line)
return lines return lines
class MrpcProcessor(DataProcessor): class MrpcProcessor(DataProcessor):
"""Processor for the MRPC data set (GLUE version).""" """Processor for the MRPC data set (GLUE version)."""
...@@ -190,10 +190,9 @@ class ColaProcessor(DataProcessor): ...@@ -190,10 +190,9 @@ class ColaProcessor(DataProcessor):
examples.append( examples.append(
InputExample(guid=guid, text_a=text_a, text_b=None, label=label)) InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
return examples return examples
def convert_examples_to_features(examples, label_list, max_seq_length, def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer):
tokenizer):
"""Loads a data file into a list of `InputBatch`s.""" """Loads a data file into a list of `InputBatch`s."""
label_map = {} label_map = {}
...@@ -380,7 +379,7 @@ def main(): ...@@ -380,7 +379,7 @@ def main():
parser.add_argument("--do_lower_case", parser.add_argument("--do_lower_case",
default=False, default=False,
action='store_true', action='store_true',
help="Whether to lower case the input text. Should be True for uncased models and False for cased models.") help="Whether to lower case the input text. True for uncased models, False for cased models.")
parser.add_argument("--max_seq_length", parser.add_argument("--max_seq_length",
default=128, default=128,
type=int, type=int,
...@@ -424,6 +423,10 @@ def main(): ...@@ -424,6 +423,10 @@ def main():
default=False, default=False,
action='store_true', action='store_true',
help="Whether not to use CUDA when available") help="Whether not to use CUDA when available")
parser.add_argument("--accumulate_gradients",
type=int,
default=1,
help="Number of steps to accumulate gradient on (divide the single step batch_size)")
parser.add_argument("--local_rank", parser.add_argument("--local_rank",
type=int, type=int,
default=-1, default=-1,
...@@ -448,12 +451,12 @@ def main(): ...@@ -448,12 +451,12 @@ def main():
n_gpu = 1 n_gpu = 1
# print("Initializing the distributed backend: NCCL") # print("Initializing the distributed backend: NCCL")
print("device", device, "n_gpu", n_gpu) print("device", device, "n_gpu", n_gpu)
random.seed(args.seed) random.seed(args.seed)
np.random.seed(args.seed) np.random.seed(args.seed)
torch.manual_seed(args.seed) torch.manual_seed(args.seed)
if n_gpu>0: torch.cuda.manual_seed_all(args.seed) if n_gpu>0: torch.cuda.manual_seed_all(args.seed)
if not args.do_train and not args.do_eval: if not args.do_train and not args.do_eval:
raise ValueError("At least one of `do_train` or `do_eval` must be True.") raise ValueError("At least one of `do_train` or `do_eval` must be True.")
......
...@@ -18,15 +18,15 @@ from __future__ import absolute_import ...@@ -18,15 +18,15 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import six
import argparse import argparse
import collections import collections
import logging import logging
import json import json
import math import math
import os import os
from tqdm import tqdm, trange import six
import random import random
from tqdm import tqdm, trange
import numpy as np import numpy as np
import torch import torch
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment