Commit 49a77ac1 authored by samuel.broscheit

Clean up a little bit

parent 3bf3f959
@@ -736,9 +736,28 @@ def main():
     tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)

-    train_examples = None
-    num_train_optimization_steps = None
+    # Prepare model
+    cache_dir = args.cache_dir if args.cache_dir else os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(args.local_rank))
+    model = BertForSequenceClassification.from_pretrained(args.bert_model,
+              cache_dir=cache_dir,
+              num_labels=num_labels)
+    if args.fp16:
+        model.half()
+    model.to(device)
+    if args.local_rank != -1:
+        try:
+            from apex.parallel import DistributedDataParallel as DDP
+        except ImportError:
+            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
+
+        model = DDP(model)
+    elif n_gpu > 1:
+        model = torch.nn.DataParallel(model)
+
     if args.do_train:
+        # Prepare data loader
         train_examples = processor.get_train_examples(args.data_dir)
         train_features = convert_examples_to_features(
             train_examples, label_list, args.max_seq_length, tokenizer, output_mode)
@@ -762,26 +781,8 @@ def main():
         if args.local_rank != -1:
             num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()

-    # Prepare model
-    cache_dir = args.cache_dir if args.cache_dir else os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(args.local_rank))
-    model = BertForSequenceClassification.from_pretrained(args.bert_model,
-              cache_dir=cache_dir,
-              num_labels=num_labels)
-    if args.fp16:
-        model.half()
-    model.to(device)
-    if args.local_rank != -1:
-        try:
-            from apex.parallel import DistributedDataParallel as DDP
-        except ImportError:
-            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
-
-        model = DDP(model)
-    elif n_gpu > 1:
-        model = torch.nn.DataParallel(model)
-
     # Prepare optimizer
-    if args.do_train:
     param_optimizer = list(model.named_parameters())
     no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
     optimizer_grouped_parameters = [
@@ -815,7 +816,7 @@ def main():
     global_step = 0
     nb_tr_steps = 0
     tr_loss = 0
-    if args.do_train:
     logger.info("***** Running training *****")
     logger.info("  Num examples = %d", len(train_examples))
     logger.info("  Batch size = %d", args.train_batch_size)
......
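The hunks above move the whole model-preparation block (from_pretrained, optional FP16 half(), device placement, apex DistributedDataParallel or torch DataParallel wrapping) ahead of the if args.do_train: branch in the sequence-classification example, and remove the separate if args.do_train: guards that the old layout kept around the optimizer setup and the training log messages. Purely as an illustration of the block being relocated, here is a minimal standalone sketch; prepare_model and its signature are hypothetical (the example scripts inline this logic), while the calls in the body mirror the added lines in the diff:

# A minimal sketch of the relocated model-preparation logic; prepare_model is a
# hypothetical helper, but every call in the body mirrors the diff above.
import os

import torch
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE


def prepare_model(model_class, bert_model, device, n_gpu, local_rank,
                  fp16=False, cache_dir=None, **model_kwargs):
    """Build and wrap the model once, before any do_train / do_eval branching."""
    cache_dir = cache_dir or os.path.join(
        str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(local_rank))
    model = model_class.from_pretrained(bert_model, cache_dir=cache_dir, **model_kwargs)
    if fp16:
        model.half()
    model.to(device)
    if local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex "
                              "to use distributed and fp16 training.")
        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)
    return model

For the classification script above, the inlined code corresponds to a call like prepare_model(BertForSequenceClassification, args.bert_model, device, n_gpu, args.local_rank, fp16=args.fp16, cache_dir=args.cache_dir, num_labels=num_labels).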
@@ -894,14 +894,31 @@ def main():
     tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)

-    train_examples = None
-    num_train_optimization_steps = None
+    # Prepare model
+    model = BertForQuestionAnswering.from_pretrained(args.bert_model,
+                cache_dir=os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(args.local_rank)))
+    if args.fp16:
+        model.half()
+    model.to(device)
+    if args.local_rank != -1:
+        try:
+            from apex.parallel import DistributedDataParallel as DDP
+        except ImportError:
+            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
+
+        model = DDP(model)
+    elif n_gpu > 1:
+        model = torch.nn.DataParallel(model)
+
     if args.do_train:
+        # Prepare data loader
         train_examples = read_squad_examples(
             input_file=args.train_file, is_training=True, version_2_with_negative=args.version_2_with_negative)
         cached_train_features_file = args.train_file+'_{0}_{1}_{2}_{3}'.format(
             list(filter(None, args.bert_model.split('/'))).pop(), str(args.max_seq_length), str(args.doc_stride), str(args.max_query_length))
-        train_features = None
         try:
             with open(cached_train_features_file, "rb") as reader:
                 train_features = pickle.load(reader)
@@ -933,25 +950,8 @@ def main():
         if args.local_rank != -1:
             num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()

-    # Prepare model
-    model = BertForQuestionAnswering.from_pretrained(args.bert_model,
-                cache_dir=os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(args.local_rank)))
-    if args.fp16:
-        model.half()
-    model.to(device)
-    if args.local_rank != -1:
-        try:
-            from apex.parallel import DistributedDataParallel as DDP
-        except ImportError:
-            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
-
-        model = DDP(model)
-    elif n_gpu > 1:
-        model = torch.nn.DataParallel(model)
-
     # Prepare optimizer
-    if args.do_train:
     param_optimizer = list(model.named_parameters())

     # hack to remove pooler, which is not used
@@ -988,7 +988,7 @@ def main():
                                          t_total=num_train_optimization_steps)

     global_step = 0
-    if args.do_train:
     logger.info("***** Running training *****")
     logger.info("  Num orig examples = %d", len(train_examples))
     logger.info("  Num split examples = %d", len(train_features))
......
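The SQuAD hunks apply the same reordering, and their context lines also show the feature-caching pattern: training features are pickled under a name built from the training file, the model name, and the length settings, then loaded on later runs. A small sketch of that cache-or-compute pattern follows; load_or_build_features is a hypothetical name, the pickle.load fallback mirrors the context lines in the diff, and the write-back on a cache miss is an assumption about the part of the script elided here:

# Sketch of the cache-or-compute pattern used for SQuAD training features
# (hypothetical helper; the write-back step is assumed, not visible above).
import pickle


def load_or_build_features(cached_train_features_file, build_fn):
    try:
        with open(cached_train_features_file, "rb") as reader:
            return pickle.load(reader)
    except (FileNotFoundError, EOFError, pickle.UnpicklingError):
        features = build_fn()
        with open(cached_train_features_file, "wb") as writer:
            pickle.dump(features, writer)
        return features

# Hypothetical call, matching the diff's variables:
# train_features = load_or_build_features(
#     cached_train_features_file,
#     lambda: convert_examples_to_features(train_examples, tokenizer, ...))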
@@ -358,9 +358,27 @@ def main():
     tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)

-    train_examples = None
-    num_train_optimization_steps = None
+    # Prepare model
+    model = BertForMultipleChoice.from_pretrained(args.bert_model,
+        cache_dir=os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(args.local_rank)),
+        num_choices=4)
+    if args.fp16:
+        model.half()
+    model.to(device)
+    if args.local_rank != -1:
+        try:
+            from apex.parallel import DistributedDataParallel as DDP
+        except ImportError:
+            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
+
+        model = DDP(model)
+    elif n_gpu > 1:
+        model = torch.nn.DataParallel(model)
+
     if args.do_train:
+        # Prepare data loader
         train_examples = read_swag_examples(os.path.join(args.data_dir, 'train.csv'), is_training = True)
         train_features = convert_examples_to_features(
             train_examples, tokenizer, args.max_seq_length, True)
@@ -379,25 +397,8 @@ def main():
         if args.local_rank != -1:
             num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()

-    # Prepare model
-    model = BertForMultipleChoice.from_pretrained(args.bert_model,
-        cache_dir=os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(args.local_rank)),
-        num_choices=4)
-    if args.fp16:
-        model.half()
-    model.to(device)
-    if args.local_rank != -1:
-        try:
-            from apex.parallel import DistributedDataParallel as DDP
-        except ImportError:
-            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
-
-        model = DDP(model)
-    elif n_gpu > 1:
-        model = torch.nn.DataParallel(model)
-
     # Prepare optimizer
-    if args.do_train:
     param_optimizer = list(model.named_parameters())

     # hack to remove pooler, which is not used
@@ -433,7 +434,7 @@ def main():
                          t_total=num_train_optimization_steps)

     global_step = 0
-    if args.do_train:
     logger.info("***** Running training *****")
     logger.info("  Num examples = %d", len(train_examples))
     logger.info("  Batch size = %d", args.train_batch_size)
......
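All three scripts also share the "# Prepare optimizer" context lines, which this view truncates. For reference, a minimal self-contained sketch of what those lines set up, using the SWAG model from the last hunks and assuming the BertAdam optimizer from pytorch_pretrained_bert; the learning-rate, warmup, and step-count values are placeholders standing in for the scripts' argparse values, and the pooler-removal line follows the "hack to remove pooler" comment in the context:

# Sketch of the optimizer preparation behind the truncated context lines above.
# Assumptions: BertAdam from pytorch_pretrained_bert; placeholder hyperparameters.
from pytorch_pretrained_bert.modeling import BertForMultipleChoice
from pytorch_pretrained_bert.optimization import BertAdam

model = BertForMultipleChoice.from_pretrained('bert-base-uncased', num_choices=4)

param_optimizer = list(model.named_parameters())
# hack to remove pooler, which is not used (as in the context line above)
param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]

# No weight decay for biases and LayerNorm parameters.
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
     'weight_decay': 0.01},
    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
     'weight_decay': 0.0},
]

num_train_optimization_steps = 1000  # placeholder; the scripts compute this from the data loader
optimizer = BertAdam(optimizer_grouped_parameters,
                     lr=5e-5,
                     warmup=0.1,
                     t_total=num_train_optimization_steps)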