Commit 05ea0cca authored by Neel Kant

Change sync variable to gloo backend

parent a670b6c9
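The change moves the INDEX_READY handshake off the (presumably NCCL-backed) default process group and onto a dedicated gloo group, so the flag tensor can stay on the CPU. A minimal sketch of that pattern, assuming the default process group is already initialized; the helper name is illustrative, not from the commit:

import datetime
import torch.distributed as dist

def make_signal_group(world_size):
    # gloo collectives operate on CPU tensors, so the INDEX_READY flag no
    # longer needs to be moved to the GPU (torch.zeros(1) rather than
    # torch.zeros(1).cuda() in the diff below).
    return dist.new_group(list(range(world_size)),
                          backend="gloo",
                          timeout=datetime.timedelta(0, 7200))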
@@ -16,7 +16,7 @@ from megatron.data.samplers import DistributedBatchSampler
 from megatron.initialize import initialize_megatron
 from megatron.model import REALMRetriever
 from megatron.global_vars import set_global_variables
-from megatron.mpu.initialize import get_index_ready, get_index_group, get_train_group, get_data_parallel_group
+from megatron.mpu.initialize import get_index_ready, get_index_group, get_train_group, get_data_parallel_group, get_gloo_comm_group
 from megatron.mpu.initialize import set_data_parallel_group, set_model_parallel_group, init_realm_groups
 from megatron.initialize import init_distributed, _init_autoresume, _set_random_seed, _write_args_to_tensorboard
 from megatron.training import get_model
@@ -176,10 +176,10 @@ class AsyncIndexBuilder(object):
         INDEX_READY = 1 - INDEX_READY
         print("Switched INDEX_READY", flush=True)
         torch.cuda.synchronize()
-        send_handle = dist.broadcast(INDEX_READY, self.main_builder_idx, async_op=True)
+        send_handle = dist.broadcast(INDEX_READY, self.main_builder_idx, group=get_gloo_comm_group(), async_op=True)
         torch.distributed.barrier(get_data_parallel_group())
-        recv_handle = dist.broadcast(INDEX_READY, 0)
+        dist.broadcast(INDEX_READY, 0, group=get_gloo_comm_group())


 class BasicIndexBuilder(object):
@@ -287,12 +287,14 @@ def get_ict_dataset(use_titles=True):
     return dataset


-def get_one_epoch_dataloader(dataset):
+def get_one_epoch_dataloader(dataset, batch_size=None):
     args = get_args()

     world_size = mpu.get_data_parallel_world_size()
     rank = mpu.get_data_parallel_rank()
-    global_batch_size = args.batch_size * world_size
+    if batch_size is None:
+        batch_size = args.batch_size
+    global_batch_size = batch_size * world_size
     num_workers = args.num_workers

     sampler = torch.utils.data.SequentialSampler(dataset)
...
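The new optional batch_size argument to get_one_epoch_dataloader lets callers override args.batch_size, e.g. so index building can sweep the dataset with a larger batch than training uses. A hypothetical call (the value 128 is illustrative, not from the commit):

dataset = get_ict_dataset(use_titles=True)
dataloader = get_one_epoch_dataloader(dataset, batch_size=128)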
@@ -15,6 +15,7 @@
 """Megatron initialization."""

+import datetime
 import random
 import os
...
@@ -16,6 +16,7 @@
 """Model and data parallel groups."""

+import datetime
 import torch

 from .utils import ensure_divisibility
@@ -26,6 +27,7 @@ _MODEL_PARALLEL_GROUP = None
 # Data parallel group that the current rank belongs to.
 _DATA_PARALLEL_GROUP = None

+_GLOO_COMM_GROUP = None
 _TRAIN_GROUP = None
 _INDEX_GROUP = None
 _INDEX_READY = None
@@ -177,12 +179,22 @@ def destroy_model_parallel():


 def init_realm_groups(max_training_rank, world_size):
+    global _GLOO_COMM_GROUP
+    _GLOO_COMM_GROUP = torch.distributed.new_group(list(range(world_size)),
+                                                   backend="gloo",
+                                                   timeout=datetime.timedelta(0, 7200))
     global _TRAIN_GROUP
     _TRAIN_GROUP = torch.distributed.new_group(list(range(max_training_rank)))
     global _INDEX_GROUP
     _INDEX_GROUP = torch.distributed.new_group(list(range(max_training_rank, world_size)))
     global _INDEX_READY
-    _INDEX_READY = torch.zeros(1).cuda()
+    _INDEX_READY = torch.zeros(1)
+
+
+def get_gloo_comm_group():
+    global _GLOO_COMM_GROUP
+    assert _GLOO_COMM_GROUP is not None
+    return _GLOO_COMM_GROUP


 def get_train_group():
...
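For orientation, a sketch of how the three REALM groups relate after this change, assuming a hypothetical run with world_size=8 and max_training_rank=6 and an already-initialized default process group (the numbers are illustrative, not from the commit):

from megatron.mpu.initialize import init_realm_groups, get_gloo_comm_group, get_train_group

# Ranks 0..5 train, ranks 6..7 build the index; the gloo group spans all
# ranks and carries only the small CPU-side INDEX_READY flag, while the
# train/index groups keep using the default backend for GPU collectives.
init_realm_groups(6, 8)   # (max_training_rank, world_size)
flag_group = get_gloo_comm_group()
train_group = get_train_group()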
@@ -36,7 +36,7 @@ from megatron.initialize import initialize_megatron
 from megatron.learning_rates import AnnealingLR
 from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.model import get_params_for_weight_decay_optimization
-from megatron.mpu.initialize import get_index_ready, get_train_group, get_data_parallel_group
+from megatron.mpu.initialize import get_index_ready, get_train_group, get_data_parallel_group, get_gloo_comm_group
 from megatron.utils import check_adlr_autoresume_termination
 from megatron.utils import make_data_loader
 from megatron.utils import report_memory
@@ -374,14 +374,20 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
     timers('interval time').start()
     report_memory_flag = True
     global INDEX_READY
-    recv_handle = torch.distributed.broadcast(INDEX_READY, args.max_training_rank, async_op=True)
+    # start off by posting a receive call which will be answered.
+    recv_handle = torch.distributed.broadcast(INDEX_READY, args.max_training_rank, group=get_gloo_comm_group(), async_op=True)
     while iteration < args.train_iters:
-        if args.max_training_rank is not None and INDEX_READY == 1:
+        # this only applies for realm right here
+        if args.max_training_rank is not None and recv_handle.is_completed():
+            # should add check that INDEX_READY == 1 but what else could be happening
             true_model = model
             if hasattr(true_model, 'module'):
                 true_model = true_model.module
             if hasattr(true_model, 'module'):
                 true_model = true_model.module
             print(">>>>>>> starting to reload index", flush=True)
             true_model.retriever.reload_index()
             save_checkpoint(iteration, model, optimizer, lr_scheduler)
@@ -390,10 +396,10 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
             INDEX_READY = 1 - INDEX_READY
             print(">>> Switched index ready", flush=True)
             torch.cuda.synchronize()
-            send_handle = torch.distributed.broadcast(INDEX_READY, 0)
+            send_handle = torch.distributed.broadcast(INDEX_READY, 0, group=get_gloo_comm_group())
             torch.distributed.barrier(get_data_parallel_group())
-            recv_handle = torch.distributed.broadcast(INDEX_READY, args.max_training_rank, async_op=True)
-        else:
+            recv_handle = torch.distributed.broadcast(INDEX_READY, args.max_training_rank, group=get_gloo_comm_group(), async_op=True)
+        elif iteration < 100:
            print("moving right along", flush=True)

         loss_dict, skipped_iter = train_step(forward_step_func,
...
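Taken together, the training-side handshake added above looks roughly like the sketch below. This is a simplified reconstruction from the diff, not code from the commit; reload_fn and step_fn are hypothetical stand-ins for the index-reload/checkpoint block and train_step().

import torch
from megatron.mpu.initialize import get_gloo_comm_group

def realm_train_loop_sketch(args, index_ready, reload_fn, step_fn):
    # Post a non-blocking receive for the index builders' "ready" signal.
    recv_handle = torch.distributed.broadcast(
        index_ready, args.max_training_rank,
        group=get_gloo_comm_group(), async_op=True)
    iteration = 0
    while iteration < args.train_iters:
        if recv_handle.is_completed():
            reload_fn()                    # reload the index, save a checkpoint
            index_ready = 1 - index_ready  # flip the CPU flag, as in the commit
            # Acknowledge to the index builders, then re-post the receive.
            torch.distributed.broadcast(index_ready, 0,
                                        group=get_gloo_comm_group())
            recv_handle = torch.distributed.broadcast(
                index_ready, args.max_training_rank,
                group=get_gloo_comm_group(), async_op=True)
        step_fn()
        iteration += 1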