Commit 05ea0cca authored by Neel Kant

Change sync variable to gloo backend

parent a670b6c9
......@@ -16,7 +16,7 @@ from megatron.data.samplers import DistributedBatchSampler
from megatron.initialize import initialize_megatron
from megatron.model import REALMRetriever
from megatron.global_vars import set_global_variables
-from megatron.mpu.initialize import get_index_ready, get_index_group, get_train_group, get_data_parallel_group
+from megatron.mpu.initialize import get_index_ready, get_index_group, get_train_group, get_data_parallel_group, get_gloo_comm_group
from megatron.mpu.initialize import set_data_parallel_group, set_model_parallel_group, init_realm_groups
from megatron.initialize import init_distributed, _init_autoresume, _set_random_seed, _write_args_to_tensorboard
from megatron.training import get_model
......@@ -176,10 +176,10 @@ class AsyncIndexBuilder(object):
INDEX_READY = 1 - INDEX_READY
print("Switched INDEX_READY", flush=True)
torch.cuda.synchronize()
-send_handle = dist.broadcast(INDEX_READY, self.main_builder_idx, async_op=True)
+send_handle = dist.broadcast(INDEX_READY, self.main_builder_idx, group=get_gloo_comm_group(), async_op=True)
torch.distributed.barrier(get_data_parallel_group())
-recv_handle = dist.broadcast(INDEX_READY, 0)
+dist.broadcast(INDEX_READY, 0, group=get_gloo_comm_group())
class BasicIndexBuilder(object):
......@@ -287,12 +287,14 @@ def get_ict_dataset(use_titles=True):
return dataset
-def get_one_epoch_dataloader(dataset):
+def get_one_epoch_dataloader(dataset, batch_size=None):
args = get_args()
world_size = mpu.get_data_parallel_world_size()
rank = mpu.get_data_parallel_rank()
-global_batch_size = args.batch_size * world_size
+if batch_size is None:
+    batch_size = args.batch_size
+global_batch_size = batch_size * world_size
num_workers = args.num_workers
sampler = torch.utils.data.SequentialSampler(dataset)
......
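The AsyncIndexBuilder hunk above is one half of a two-way handshake that now runs over the gloo group: the builder flips the shared INDEX_READY flag, announces it with a non-blocking broadcast from its lead rank, and then blocks on the acknowledgement broadcast from rank 0. A minimal sketch of that builder-side pattern, assuming torch.distributed is already initialized; index_ready, main_builder_rank, and gloo_group are illustrative stand-ins for INDEX_READY, self.main_builder_idx, and get_gloo_comm_group(), and the barrier on the builder's data-parallel group is omitted:

import torch.distributed as dist

def signal_index_ready(index_ready, main_builder_rank, gloo_group):
    # Toggle the 0/1 flag in place; gloo collectives work on CPU tensors.
    index_ready.fill_(1 - index_ready.item())
    # Non-blocking announce: the training side polls the matching handle.
    send_handle = dist.broadcast(index_ready, main_builder_rank,
                                 group=gloo_group, async_op=True)
    # Blocking acknowledgement: wait until rank 0 flips the flag back.
    dist.broadcast(index_ready, 0, group=gloo_group)
    return send_handle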
......@@ -15,6 +15,7 @@
"""Megatron initialization."""
import datetime
import random
import os
......
......@@ -16,6 +16,7 @@
"""Model and data parallel groups."""
+import datetime
import torch
from .utils import ensure_divisibility
......@@ -26,6 +27,7 @@ _MODEL_PARALLEL_GROUP = None
# Data parallel group that the current rank belongs to.
_DATA_PARALLEL_GROUP = None
+_GLOO_COMM_GROUP = None
_TRAIN_GROUP = None
_INDEX_GROUP = None
_INDEX_READY = None
......@@ -177,12 +179,22 @@ def destroy_model_parallel():
def init_realm_groups(max_training_rank, world_size):
+global _GLOO_COMM_GROUP
+_GLOO_COMM_GROUP = torch.distributed.new_group(list(range(world_size)),
+                                               backend="gloo",
+                                               timeout=datetime.timedelta(0, 7200))
global _TRAIN_GROUP
_TRAIN_GROUP = torch.distributed.new_group(list(range(max_training_rank)))
global _INDEX_GROUP
_INDEX_GROUP = torch.distributed.new_group(list(range(max_training_rank, world_size)))
global _INDEX_READY
-_INDEX_READY = torch.zeros(1).cuda()
+_INDEX_READY = torch.zeros(1)
+def get_gloo_comm_group():
+    global _GLOO_COMM_GROUP
+    assert _GLOO_COMM_GROUP is not None
+    return _GLOO_COMM_GROUP
def get_train_group():
......
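For context on the init_realm_groups hunk above, here is a minimal sketch of the setup it performs, assuming torch.distributed is already initialized: a secondary gloo process group spanning all ranks, with the sync flag kept as a CPU tensor (torch.zeros(1) instead of .cuda()), since gloo collectives operate on host memory; this keeps the flag exchange off the communicator used for training collectives. make_sync_group is an illustrative name, not the repository's API.

import datetime
import torch
import torch.distributed as dist

def make_sync_group(world_size):
    # Every rank joins the gloo group, mirroring new_group(list(range(world_size))).
    gloo_group = dist.new_group(list(range(world_size)),
                                backend="gloo",
                                timeout=datetime.timedelta(0, 7200))
    # The flag lives on the CPU so it can be broadcast over the gloo backend.
    index_ready = torch.zeros(1)
    return gloo_group, index_ready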
......@@ -36,7 +36,7 @@ from megatron.initialize import initialize_megatron
from megatron.learning_rates import AnnealingLR
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import get_params_for_weight_decay_optimization
-from megatron.mpu.initialize import get_index_ready, get_train_group, get_data_parallel_group
+from megatron.mpu.initialize import get_index_ready, get_train_group, get_data_parallel_group, get_gloo_comm_group
from megatron.utils import check_adlr_autoresume_termination
from megatron.utils import make_data_loader
from megatron.utils import report_memory
......@@ -374,14 +374,20 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
timers('interval time').start()
report_memory_flag = True
global INDEX_READY
-recv_handle = torch.distributed.broadcast(INDEX_READY, args.max_training_rank, async_op=True)
+# start off by posting a receive call which will be answered.
+recv_handle = torch.distributed.broadcast(INDEX_READY, args.max_training_rank, group=get_gloo_comm_group(), async_op=True)
while iteration < args.train_iters:
-if args.max_training_rank is not None and INDEX_READY == 1:
+# this only applies for realm right here
+if args.max_training_rank is not None and recv_handle.is_completed():
+# should add check that INDEX_READY == 1 but what else could be happening
true_model = model
if hasattr(true_model, 'module'):
true_model = true_model.module
+if hasattr(true_model, 'module'):
+    true_model = true_model.module
+print(">>>>>>> starting to reload index", flush=True)
true_model.retriever.reload_index()
save_checkpoint(iteration, model, optimizer, lr_scheduler)
......@@ -390,10 +396,10 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
INDEX_READY = 1 - INDEX_READY
print(">>> Switched index ready", flush=True)
torch.cuda.synchronize()
-send_handle = torch.distributed.broadcast(INDEX_READY, 0)
+send_handle = torch.distributed.broadcast(INDEX_READY, 0, group=get_gloo_comm_group())
torch.distributed.barrier(get_data_parallel_group())
-recv_handle = torch.distributed.broadcast(INDEX_READY, args.max_training_rank, async_op=True)
-else:
+recv_handle = torch.distributed.broadcast(INDEX_READY, args.max_training_rank, group=get_gloo_comm_group(), async_op=True)
+elif iteration < 100:
print("moving right along", flush=True)
loss_dict, skipped_iter = train_step(forward_step_func,
......
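The train() hunk above is the trainer-side half of the exchange: a non-blocking broadcast from the index builder's lead rank (args.max_training_rank) is posted up front, the handle is polled once per iteration, and when it completes the index is reloaded, the flag is flipped, an acknowledgement is broadcast from rank 0, and the receive is re-posted. A hedged sketch of that polling step; index_ready, builder_rank, gloo_group, and retriever stand in for INDEX_READY, args.max_training_rank, get_gloo_comm_group(), and the unwrapped model's retriever, with checkpointing and the data-parallel barrier omitted:

import torch.distributed as dist

def poll_index_ready(index_ready, recv_handle, builder_rank, gloo_group, retriever):
    if not recv_handle.is_completed():
        return recv_handle  # builder has not signalled yet; keep training
    # The builder flipped the flag: load the freshly built index.
    retriever.reload_index()
    # Flip the flag back and acknowledge from rank 0 ...
    index_ready.fill_(1 - index_ready.item())
    dist.broadcast(index_ready, 0, group=gloo_group)
    # ... then post the next non-blocking receive from the builder.
    return dist.broadcast(index_ready, builder_rank,
                          group=gloo_group, async_op=True)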