Commit 0cf88ff0 authored by thomwolf

make examples work without apex

parent 52c53f39
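This commit moves the apex imports out of the module level of both example scripts and into the code paths that actually need them (the fp16 optimizer setup and the distributed model wrapping), so the examples import and run without apex unless --fp16 or --local_rank is used. Below is a minimal sketch of that deferred-import pattern in isolation; the function name and argument handling are illustrative, not taken from the repo:

# Minimal sketch of the deferred-import pattern applied by this commit.
# `build_training` and its argument handling are illustrative, not repo code.
import argparse

def build_training(args):
    if args.fp16 or args.local_rank != -1:
        # apex is only needed on these branches, so the import (and the
        # helpful error message) is deferred until one of them is taken.
        try:
            from apex.optimizers import FusedAdam  # noqa: F401
        except ImportError:
            raise ImportError("Please install apex from "
                              "https://www.github.com/nvidia/apex "
                              "to use distributed and fp16 training.")
    # plain fp32, single-process setup continues here without apex

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--local_rank", type=int, default=-1)
    build_training(parser.parse_args())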
...
@@ -36,13 +36,6 @@ from pytorch_pretrained_bert.modeling import BertForSequenceClassification
 from pytorch_pretrained_bert.optimization import BertAdam
 from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
-try:
-    from apex.optimizers import FP16_Optimizer
-    from apex.optimizers import FusedAdam
-    from apex.parallel import DistributedDataParallel as DDP
-except ImportError:
-    raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this.")
 logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                     datefmt = '%m/%d/%Y %H:%M:%S',
                     level = logging.INFO)
@@ -467,6 +460,11 @@ def main():
         model.half()
     model.to(device)
     if args.local_rank != -1:
+        try:
+            from apex.parallel import DistributedDataParallel as DDP
+        except ImportError:
+            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
         model = DDP(model)
     elif n_gpu > 1:
         model = torch.nn.DataParallel(model)
@@ -482,6 +480,12 @@ def main():
     if args.local_rank != -1:
         t_total = t_total // torch.distributed.get_world_size()
     if args.fp16:
+        try:
+            from apex.optimizers import FP16_Optimizer
+            from apex.optimizers import FusedAdam
+        except ImportError:
+            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
         optimizer = FusedAdam(optimizer_grouped_parameters,
                               lr=args.learning_rate,
                               bias_correction=False,
...
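The last hunk above is truncated ("..."); in examples of this vintage the FusedAdam instance is then wrapped in apex's FP16_Optimizer for loss scaling. The sketch below shows that pattern under the assumption of the pre-1.0 apex API (FP16_Optimizer with static or dynamic loss scaling); the helper name and the loss_scale parameter mirror the scripts' CLI flag but are not copied from this diff:

# Sketch of the fp16 optimizer setup that the guarded import above enables.
# Assumes the pre-1.0 apex API; `build_fp16_optimizer` and `loss_scale` are
# illustrative names, not code from this commit.
def build_fp16_optimizer(optimizer_grouped_parameters, learning_rate, loss_scale=0):
    try:
        from apex.optimizers import FP16_Optimizer, FusedAdam
    except ImportError:
        raise ImportError("Please install apex from https://www.github.com/nvidia/apex "
                          "to use distributed and fp16 training.")
    optimizer = FusedAdam(optimizer_grouped_parameters,
                          lr=learning_rate,
                          bias_correction=False,
                          max_grad_norm=1.0)
    if loss_scale == 0:
        # let apex pick and adjust the loss scale dynamically
        optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
    else:
        optimizer = FP16_Optimizer(optimizer, static_loss_scale=loss_scale)
    return optimizer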
...
@@ -39,13 +39,6 @@ from pytorch_pretrained_bert.modeling import BertForQuestionAnswering
 from pytorch_pretrained_bert.optimization import BertAdam
 from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
-try:
-    from apex.optimizers import FP16_Optimizer
-    from apex.optimizers import FusedAdam
-    from apex.parallel import DistributedDataParallel as DDP
-except ImportError:
-    raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this.")
 logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                     datefmt = '%m/%d/%Y %H:%M:%S',
                     level = logging.INFO)
@@ -813,6 +806,11 @@ def main():
         model.half()
     model.to(device)
     if args.local_rank != -1:
+        try:
+            from apex.parallel import DistributedDataParallel as DDP
+        except ImportError:
+            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
         model = DDP(model)
     elif n_gpu > 1:
         model = torch.nn.DataParallel(model)
@@ -834,6 +832,12 @@ def main():
     if args.local_rank != -1:
         t_total = t_total // torch.distributed.get_world_size()
     if args.fp16:
+        try:
+            from apex.optimizers import FP16_Optimizer
+            from apex.optimizers import FusedAdam
+        except ImportError:
+            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
         optimizer = FusedAdam(optimizer_grouped_parameters,
                               lr=args.learning_rate,
                               bias_correction=False,
...
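The second hunk in each file guards the model-wrapping path the same way: apex's DistributedDataParallel is imported only when --local_rank is set, while multi-GPU single-process runs fall back to torch.nn.DataParallel and need no apex at all. A self-contained sketch of that logic follows; the helper name wrap_model is illustrative, the branch bodies mirror the diff:

# Sketch of the model-wrapping logic guarded by the second hunk of each file.
# `wrap_model` is an illustrative helper name, not taken from the repo.
import torch

def wrap_model(model, device, local_rank=-1, n_gpu=1, fp16=False):
    if fp16:
        model.half()
    model.to(device)
    if local_rank != -1:
        # distributed training still requires apex's DDP in these examples,
        # so the import (and its error message) lives on this branch only
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex "
                              "to use distributed and fp16 training.")
        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)
    return model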