Commit 965b2565 authored by thomwolf

add distributed training

parent 1ceac85e
@@ -449,14 +449,15 @@ def main():
     else:
         device = torch.device("cuda", args.local_rank)
         n_gpu = 1
-        # print("Initializing the distributed backend: NCCL")
-    print("device", device, "n_gpu", n_gpu)
+        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
+        torch.distributed.init_process_group(backend='nccl')
+    print("device", device, "n_gpu", n_gpu, "distributed training", bool(args.local_rank != -1))
 
     if args.accumulate_gradients < 1:
         raise ValueError("Invalid accumulate_gradients parameter: {}, should be >= 1".format(
                             args.accumulate_gradients))
 
-    args.batch_size = args.batch_size / args.accumulate_gradients
+    args.train_batch_size = args.train_batch_size / args.accumulate_gradients
 
     random.seed(args.seed)
     np.random.seed(args.seed)
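Dividing args.train_batch_size by args.accumulate_gradients shrinks the per-step batch so that accumulate_gradients forward/backward passes add up to the batch size the user asked for. A minimal sketch of the accumulation pattern this enables, assuming a standard PyTorch loop; model, optimizer, and train_dataloader are placeholders, not names from this commit:

    # Each dataloader batch uses the already-divided args.train_batch_size.
    optimizer.zero_grad()
    for step, batch in enumerate(train_dataloader):
        loss = model(*batch)                      # placeholder: model returns the loss
        loss = loss / args.accumulate_gradients   # scale so accumulated grads average correctly
        loss.backward()                           # gradients accumulate across calls
        if (step + 1) % args.accumulate_gradients == 0:
            optimizer.step()                      # one update per effective batch
            optimizer.zero_grad()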
@@ -502,7 +503,10 @@ def main():
         model.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
     model.to(device)
 
-    if n_gpu > 1:
+    if args.local_rank != -1:
+        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
+                                                          output_device=args.local_rank)
+    elif n_gpu > 1:
         model = torch.nn.DataParallel(model)
 
     no_decay = ['bias', 'gamma', 'beta']
......
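Put together, the distributed path added by this commit follows the standard PyTorch recipe: in distributed mode there is one process per GPU, each wrapping its replica in DistributedDataParallel (gradients are all-reduced during backward()), while a single process driving several GPUs falls back to DataParallel. A condensed sketch; the argument parsing and the stand-in model are illustrative, not code from the diff:

    import argparse
    import torch

    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="local_rank for distributed training on GPUs")
    args = parser.parse_args()

    if args.local_rank == -1:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # NCCL backend; expects MASTER_ADDR/MASTER_PORT etc. in the environment
        torch.distributed.init_process_group(backend='nccl')

    model = torch.nn.Linear(10, 2)  # stand-in for the BERT model
    model.to(device)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)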
@@ -748,13 +748,21 @@ def main():
     else:
         device = torch.device("cuda", args.local_rank)
         n_gpu = 1
-        # print("Initializing the distributed backend: NCCL")
-    print("device", device, "n_gpu", n_gpu)
+        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
+        torch.distributed.init_process_group(backend='nccl')
+    print("device", device, "n_gpu", n_gpu, "distributed training", bool(args.local_rank != -1))
 
+    if args.accumulate_gradients < 1:
+        raise ValueError("Invalid accumulate_gradients parameter: {}, should be >= 1".format(
+                            args.accumulate_gradients))
+
+    args.train_batch_size = args.train_batch_size / args.accumulate_gradients
 
     random.seed(args.seed)
     np.random.seed(args.seed)
     torch.manual_seed(args.seed)
-    if n_gpu>0: torch.cuda.manual_seed_all(args.seed)
+    if n_gpu > 0:
+        torch.cuda.manual_seed_all(args.seed)
 
     if not args.do_train and not args.do_predict:
         raise ValueError("At least one of `do_train` or `do_predict` must be True.")
@@ -795,8 +803,11 @@ def main():
     if args.init_checkpoint is not None:
         model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
     model.to(device)
 
-    if n_gpu > 1:
+    if args.local_rank != -1:
+        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
+                                                          output_device=args.local_rank)
+    elif n_gpu > 1:
         model = torch.nn.DataParallel(model)
 
     no_decay = ['bias', 'gamma', 'beta']
@@ -831,7 +842,8 @@ def main():
         all_start_positions = torch.tensor([f.start_position for f in train_features], dtype=torch.long)
         all_end_positions = torch.tensor([f.end_position for f in train_features], dtype=torch.long)
-        train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_start_positions, all_end_positions)
+        train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
+                                   all_start_positions, all_end_positions)
         if args.local_rank == -1:
             train_sampler = RandomSampler(train_data)
         else:
......
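The last hunk is truncated before the else: branch. In distributed mode each process must draw a disjoint shard of the training data, which is what torch.utils.data.distributed.DistributedSampler provides; a hedged sketch of how the sampler selection typically looks, with the DistributedSampler branch being an assumption rather than text from this diff:

    # The else: body is cut off in the diff above; DistributedSampler is assumed.
    from torch.utils.data import DataLoader, RandomSampler
    from torch.utils.data.distributed import DistributedSampler

    if args.local_rank == -1:
        train_sampler = RandomSampler(train_data)       # single-process shuffling
    else:
        train_sampler = DistributedSampler(train_data)  # disjoint shard per process
    # Note: int(...) guards against the float produced by the `/` division of
    # args.train_batch_size earlier in this commit.
    train_dataloader = DataLoader(train_data, sampler=train_sampler,
                                  batch_size=int(args.train_batch_size))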