Commit 7de5c6aa authored by Matthew Carrigan

PEP8 and formatting cleanups

parent 1798e98e
@@ -9,7 +9,7 @@ from collections import namedtuple
 from torch.utils.data import DataLoader, Dataset, RandomSampler
 from torch.utils.data.distributed import DistributedSampler
-from tqdm import tqdm, trange
+from tqdm import tqdm
 from pytorch_pretrained_bert.modeling import BertForPreTraining
 from pytorch_pretrained_bert.tokenization import BertTokenizer
@@ -149,7 +149,8 @@ def main():
                         help="random seed for initialization")
     args = parser.parse_args()
-    assert args.pregenerated_data.is_dir(), "--pregenerated_data should point to the folder of files made by pregenerate_training_data.py!"
+    assert args.pregenerated_data.is_dir(), \
+        "--pregenerated_data should point to the folder of files made by pregenerate_training_data.py!"
     samples_per_epoch = []
     for i in range(args.epochs):
@@ -237,7 +238,8 @@ def main():
             from apex.optimizers import FP16_Optimizer
             from apex.optimizers import FusedAdam
         except ImportError:
-            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
+            raise ImportError(
+                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
         optimizer = FusedAdam(optimizer_grouped_parameters,
                               lr=args.learning_rate,
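For reference, a minimal sketch of how this fp16 path is typically wired up once the apex imports succeed: FusedAdam is built from the grouped parameters and then wrapped in FP16_Optimizer. The loss_scale flag and the keyword arguments beyond lr are assumptions based on the usual layout of these scripts; they are not part of this hunk.

# Sketch only: assumes apex is installed; names such as loss_scale are assumed, not shown in this hunk.
from apex.optimizers import FP16_Optimizer, FusedAdam

def build_fp16_optimizer(optimizer_grouped_parameters, learning_rate, loss_scale=0):
    optimizer = FusedAdam(optimizer_grouped_parameters,
                          lr=learning_rate,
                          bias_correction=False,  # warmup/bias handling is done manually under fp16
                          max_grad_norm=1.0)
    # loss_scale == 0 selects dynamic loss scaling, otherwise a fixed scale is used
    if loss_scale == 0:
        return FP16_Optimizer(optimizer, dynamic_loss_scale=True)
    return FP16_Optimizer(optimizer, static_loss_scale=loss_scale)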
@@ -293,7 +295,8 @@ def main():
                 if args.fp16:
                     # modify learning rate with special warm up BERT uses
                     # if args.fp16 is False, BertAdam is used that handles this automatically
-                    lr_this_step = args.learning_rate * warmup_linear(global_step/num_train_optimization_steps, args.warmup_proportion)
+                    lr_this_step = args.learning_rate * warmup_linear(global_step/num_train_optimization_steps,
+                                                                      args.warmup_proportion)
                     for param_group in optimizer.param_groups:
                         param_group['lr'] = lr_this_step
                 optimizer.step()
...
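The hunk above rescales the learning rate by hand because, under fp16, the optimizer does not apply BERT's warmup schedule itself (BertAdam handles it when fp16 is off). A rough sketch of the linear-warmup schedule being applied, with warmup_linear approximated from how it is called here:

def warmup_linear(x, warmup=0.002):
    # x: fraction of training completed (global_step / num_train_optimization_steps)
    # warmup: fraction of steps spent ramping up (args.warmup_proportion)
    if x < warmup:
        return x / warmup   # linear ramp from 0 to 1 during warmup
    return 1.0 - x          # then linear decay towards 0

# Mirrors the hunk above: scale the base learning rate at each optimization step.
# lr_this_step = args.learning_rate * warmup_linear(global_step / num_train_optimization_steps,
#                                                   args.warmup_proportion)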
@@ -269,6 +269,5 @@ def main():
             metrics_file.write(json.dumps(metrics))
 if __name__ == '__main__':
     main()