Commit a670b6c9 authored by Neel Kant

Async works for 8 GPUs total, indexer debug mode

parent 5684f904
@@ -16,7 +16,7 @@ from megatron.data.samplers import DistributedBatchSampler
from megatron.initialize import initialize_megatron
from megatron.model import REALMRetriever
from megatron.global_vars import set_global_variables
from megatron.mpu.initialize import get_index_ready, get_index_group, get_train_group
from megatron.mpu.initialize import get_index_ready, get_index_group, get_train_group, get_data_parallel_group
from megatron.mpu.initialize import set_data_parallel_group, set_model_parallel_group, init_realm_groups
from megatron.initialize import init_distributed, _init_autoresume, _set_random_seed, _write_args_to_tensorboard
from megatron.training import get_model
@@ -67,12 +67,12 @@ def initialize_and_run_async_megatron(extra_args_provider=None, args_defaults={}
torch.distributed.barrier()
if args.rank < args.max_training_rank:
torch.distributed.barrier(get_train_group())
torch.distributed.barrier(get_data_parallel_group())
pprint("All trainers ready.")
return
else:
runner = AsyncIndexBuilder(args.rank)
torch.distributed.barrier(get_index_group())
torch.distributed.barrier(get_data_parallel_group())
pprint("All indexers ready.")
runner.run_async()
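
For reference: with no model parallelism, each side's data-parallel group is just its own set of ranks, so the per-side barriers above reduce to the sketch below. The helpers named in the diff (get_train_group, get_index_group, get_data_parallel_group) live in megatron.mpu.initialize; this standalone version built with torch.distributed.new_group is illustrative only and assumes init_process_group has already run.

import torch.distributed as dist

def make_realm_groups(max_training_rank):
    """Build disjoint trainer/indexer groups (sketch, not the repo's code)."""
    world_size = dist.get_world_size()
    train_ranks = list(range(max_training_rank))
    index_ranks = list(range(max_training_rank, world_size))

    # new_group must be called by every rank, in the same order on all ranks.
    train_group = dist.new_group(ranks=train_ranks)
    index_group = dist.new_group(ranks=index_ranks)

    my_group = train_group if dist.get_rank() < max_training_rank else index_group
    dist.barrier(group=my_group)   # "All trainers ready." / "All indexers ready."
    return train_group, index_group
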
@@ -123,6 +123,7 @@ class AsyncIndexBuilder(object):
try:
self.model = load_ict_checkpoint(only_block_model=True, no_grad=True, from_realm_chkpt=True)
except:
print(">>>>> No realm chkpt available", flush=True)
self.model = load_ict_checkpoint(only_block_model=True, no_grad=True, from_realm_chkpt=False)
self.model.eval()
self.dataloader = iter(get_one_epoch_dataloader(get_ict_dataset()))
@@ -148,7 +149,7 @@ class AsyncIndexBuilder(object):
total += block_indices.size
i += 1
if i % 10 == 0:
if i % 500 == 0:
print('Batch {:10d} | Total {:10d}'.format(i, total), flush=True)
if self.debug:
break
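
A minimal sketch of the embedding loop this hunk tunes, with the logging interval raised to 500 and the debug-mode early exit; embed_batch and block_data.add are placeholders for the model forward pass and the shard store used in this class.

def build_block_index(dataloader, embed_batch, block_data, debug=False, log_every=500):
    """Embed blocks batch by batch, logging progress every `log_every` batches (sketch)."""
    i = total = 0
    for batch in dataloader:
        block_embeds, block_indices = embed_batch(batch)
        block_data.add(block_indices, block_embeds)   # placeholder accumulator
        total += len(block_indices)
        i += 1
        if i % log_every == 0:
            print('Batch {:10d} | Total {:10d}'.format(i, total), flush=True)
        if debug:
            # Debug mode: one batch is enough to exercise the whole pipeline.
            break
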
@@ -162,7 +163,7 @@ class AsyncIndexBuilder(object):
sys.exit(0)
self.block_data.save_shard(self.rank)
torch.distributed.barrier()
torch.distributed.barrier(get_data_parallel_group())
del self.model
if self.is_main_builder:
@@ -174,12 +175,11 @@ class AsyncIndexBuilder(object):
if self.is_main_builder:
INDEX_READY = 1 - INDEX_READY
print("Switched INDEX_READY", flush=True)
import time
print(time.ctime(time.time()), flush=True)
torch.cuda.synchronize()
send_handle = dist.broadcast(INDEX_READY, self.main_builder_idx, async_op=True)
torch.distributed.barrier(get_index_group())
recv_handle = dist.broadcast(INDEX_READY, 0, async_op=True)
torch.distributed.barrier(get_data_parallel_group())
recv_handle = dist.broadcast(INDEX_READY, 0)
class BasicIndexBuilder(object):
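
The indexer-side half of the INDEX_READY handshake, sketched under one consistent reading of the diff (blocking send and receive, flag held in a one-element tensor). is_main_builder, main_builder_rank and the data-parallel group mirror names visible above; the wiring is simplified.

import torch
import torch.distributed as dist

def signal_index_ready(index_ready, is_main_builder, main_builder_rank, index_dp_group):
    """Indexer-side half of the INDEX_READY handshake (sketch)."""
    if is_main_builder:
        index_ready.copy_(1 - index_ready)   # toggle 0 <-> 1
        if torch.cuda.is_available():
            torch.cuda.synchronize()
    # Tell every rank, trainers included, that a fresh index is on disk.
    dist.broadcast(index_ready, src=main_builder_rank)
    dist.barrier(group=index_dp_group)
    # Wait for the trainers' acknowledgement, re-broadcast from rank 0.
    dist.broadcast(index_ready, src=0)
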
@@ -236,7 +236,7 @@ def load_ict_checkpoint(only_query_model=False, only_block_model=False, no_grad=
with open(tracker_filename, 'r') as f:
iteration = int(f.read().strip())
assert iteration > 0
# assert iteration > 0
checkpoint_name = get_checkpoint_name(load_path, iteration, False)
if mpu.get_data_parallel_rank() == 0:
print('global rank {} is loading checkpoint {}'.format(
@@ -245,6 +245,7 @@ def load_ict_checkpoint(only_query_model=False, only_block_model=False, no_grad=
state_dict = torch.load(checkpoint_name, map_location='cpu')
ict_state_dict = state_dict['model']
if from_realm_chkpt:
print(">>>> Attempting to get ict state dict from realm", flush=True)
ict_state_dict = ict_state_dict['retriever']['ict_model']
if only_query_model:
@@ -256,7 +257,7 @@ def load_ict_checkpoint(only_query_model=False, only_block_model=False, no_grad=
model.load_state_dict(ict_state_dict)
else:
model.load_state_dict(ict_state_dict)
torch.distributed.barrier()
torch.distributed.barrier(get_data_parallel_group())
if mpu.get_data_parallel_rank() == 0:
print(' successfully loaded {}'.format(checkpoint_name))
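
load_ict_checkpoint now falls back when no REALM checkpoint exists and unwraps the ICT weights from a nested REALM state dict; a stripped-down sketch of that unwrapping, with the key names taken from this hunk and everything else simplified.

import torch

def load_ict_state_dict(checkpoint_name, from_realm_chkpt=False):
    """Load an ICT state dict, optionally unwrapping it from a REALM checkpoint (sketch)."""
    state_dict = torch.load(checkpoint_name, map_location='cpu')
    ict_state_dict = state_dict['model']
    if from_realm_chkpt:
        # REALM checkpoints nest the ICT weights under the retriever.
        ict_state_dict = ict_state_dict['retriever']['ict_model']
    return ict_state_dict

# usage: model.load_state_dict(load_ict_state_dict(checkpoint_name, from_realm_chkpt=True))
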
@@ -290,9 +291,7 @@ def get_one_epoch_dataloader(dataset):
args = get_args()
world_size = mpu.get_data_parallel_world_size()
print(world_size, flush=True)
rank = mpu.get_data_parallel_rank()
print(rank, flush=True)
global_batch_size = args.batch_size * world_size
num_workers = args.num_workers
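
For context, a one-epoch data-parallel dataloader can be sketched with the stock DistributedSampler; the repo itself uses its own DistributedBatchSampler, so treat this as a generic stand-in.

from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from megatron import mpu   # assumed available, as in the surrounding code

def one_epoch_dataloader(dataset, batch_size, num_workers):
    """Shard `dataset` across the data-parallel group for a single pass (sketch)."""
    sampler = DistributedSampler(dataset,
                                 num_replicas=mpu.get_data_parallel_world_size(),
                                 rank=mpu.get_data_parallel_rank(),
                                 shuffle=False)
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler,
                      num_workers=num_workers, pin_memory=True)
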
@@ -24,7 +24,7 @@ import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron import mpu
from megatron.mpu.initialize import get_train_group
from megatron.mpu.initialize import get_train_group, get_data_parallel_group
from megatron import get_args
from megatron import print_rank_0
@@ -119,14 +119,14 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
print(' successfully saved {}'.format(checkpoint_name))
# Wait so everyone is done (necessary)
torch.distributed.barrier(get_train_group())
torch.distributed.barrier(get_data_parallel_group())
# And update the latest iteration
if torch.distributed.get_rank() == 0:
tracker_filename = get_checkpoint_tracker_filename(args.save)
with open(tracker_filename, 'w') as f:
f.write(str(iteration))
# Wait so everyone is done (not necessary)
torch.distributed.barrier(get_train_group())
torch.distributed.barrier(get_data_parallel_group())
def load_checkpoint(model, optimizer, lr_scheduler):
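
The tracker-file update that these barriers protect follows a simple pattern; a sketch using the conventional Megatron tracker filename, with barriers taken on whichever group the caller synchronizes on.

import os
import torch.distributed as dist

def update_checkpoint_tracker(save_dir, iteration, group):
    """Record the latest checkpointed iteration after all ranks finish saving (sketch)."""
    # Everyone waits until all ranks have written their shards.
    dist.barrier(group=group)
    if dist.get_rank() == 0:
        tracker = os.path.join(save_dir, 'latest_checkpointed_iteration.txt')
        with open(tracker, 'w') as f:
            f.write(str(iteration))
    # Second barrier so no rank reads the tracker before it is written.
    dist.barrier(group=group)
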
@@ -243,7 +243,7 @@ def load_checkpoint(model, optimizer, lr_scheduler):
'exiting ...'.format(checkpoint_name))
sys.exit()
torch.distributed.barrier()
# torch.distributed.barrier()
if mpu.get_data_parallel_rank() == 0:
print(' successfully loaded {}'.format(checkpoint_name))
@@ -164,14 +164,14 @@ class _Timer:
def start(self):
"""Start the timer."""
assert not self.started_, 'timer has already been started'
torch.cuda.synchronize()
# torch.cuda.synchronize()
self.start_time = time.time()
self.started_ = True
def stop(self):
"""Stop the timer."""
assert self.started_, 'timer is not started'
torch.cuda.synchronize()
# torch.cuda.synchronize()
self.elapsed_ += (time.time() - self.start_time)
self.started_ = False
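
A sketch of the same timer with the CUDA synchronization made optional rather than removed; whether to sync is a policy choice, presumably dropped here to avoid stalls in the asynchronous setup.

import time
import torch

class SimpleTimer:
    """Wall-clock timer; CUDA sync is opt-in so async code paths are not stalled (sketch)."""

    def __init__(self, sync_cuda=False):
        self.sync_cuda = sync_cuda and torch.cuda.is_available()
        self.elapsed_ = 0.0
        self.start_time = 0.0
        self.started_ = False

    def start(self):
        assert not self.started_, 'timer has already been started'
        if self.sync_cuda:
            torch.cuda.synchronize()
        self.start_time = time.time()
        self.started_ = True

    def stop(self):
        assert self.started_, 'timer is not started'
        if self.sync_cuda:
            torch.cuda.synchronize()
        self.elapsed_ += time.time() - self.start_time
        self.started_ = False
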
@@ -56,7 +56,7 @@ class DistributedDataParallel(MegatronModule):
if not no_scale and not reduce_after:
coalesced /= dist.get_world_size(group=self.data_parallel_group)
dist.all_reduce(coalesced, group=self.data_parallel_group)
torch.cuda.synchronize()
# torch.cuda.synchronize()
if not no_scale and reduce_after:
coalesced /= dist.get_world_size(group=self.data_parallel_group)
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
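
The surrounding all-reduce uses PyTorch's flatten/unflatten helpers to coalesce gradients into one collective; a self-contained sketch of that pattern on an explicit process group.

import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

def allreduce_gradients(grads, group):
    """Average a list of gradient tensors across `group` in a single collective (sketch)."""
    coalesced = _flatten_dense_tensors(grads)
    coalesced /= dist.get_world_size(group=group)
    dist.all_reduce(coalesced, group=group)
    for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
        buf.copy_(synced)
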
@@ -199,7 +199,9 @@ class REALMRetriever(MegatronModule):
true_model = true_model.module
else:
true_model = self.ict_model
query_embeds = detach(true_model.embed_query(query_tokens, query_pad_mask))
# print("true model: ", true_model, flush=True)
query_embeds = detach(self.ict_model(query_tokens, query_pad_mask, None, None, only_query=True))
_, block_indices = self.hashed_index.search_mips_index(query_embeds, top_k=self.top_k, reconstruct=False)
all_topk_tokens, all_topk_pad_masks = [], []
@@ -268,7 +270,6 @@ class ICTBertModel(MegatronModule):
def forward(self, query_tokens, query_attention_mask, block_tokens, block_attention_mask, only_query=False, only_block=False):
"""Run a forward pass for each of the models and compute the similarity scores."""
if only_query:
return self.embed_query(query_tokens, query_attention_mask)
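
The retriever above now embeds queries through the shared ICT model's only_query path; a toy two-tower model showing how such only_query/only_block shortcuts typically look (the encoders here are placeholders, not the BERT towers used in the repo).

import torch
import torch.nn as nn

def masked_mean(vectors, mask):
    """Mean-pool token vectors over unpadded positions."""
    mask = mask.unsqueeze(-1).float()
    return (vectors * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)

class TinyICTModel(nn.Module):
    """Two-tower sketch: separate query and block encoders, dot-product scores."""

    def __init__(self, vocab_size=30522, hidden=128):
        super().__init__()
        self.query_emb = nn.Embedding(vocab_size, hidden)
        self.block_emb = nn.Embedding(vocab_size, hidden)

    def embed_query(self, query_tokens, query_mask):
        return masked_mean(self.query_emb(query_tokens), query_mask)

    def embed_block(self, block_tokens, block_mask):
        return masked_mean(self.block_emb(block_tokens), block_mask)

    def forward(self, query_tokens, query_mask, block_tokens, block_mask,
                only_query=False, only_block=False):
        if only_query:
            return self.embed_query(query_tokens, query_mask)
        if only_block:
            return self.embed_block(block_tokens, block_mask)
        query = self.embed_query(query_tokens, query_mask)
        block = self.embed_block(block_tokens, block_mask)
        return query @ block.t()   # (num_queries, num_blocks) retrieval scores
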
@@ -109,10 +109,8 @@ def set_model_parallel_group(group):
def get_data_parallel_group():
"""Get the data parallel group the caller rank belongs to."""
#print(">>> yeah this function works.")
assert _DATA_PARALLEL_GROUP is not None, \
'data parallel group is not initialized'
#print(_DATA_PARALLEL_GROUP)
return _DATA_PARALLEL_GROUP
@@ -36,7 +36,7 @@ from megatron.initialize import initialize_megatron
from megatron.learning_rates import AnnealingLR
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import get_params_for_weight_decay_optimization
from megatron.mpu.initialize import get_index_ready, get_train_group
from megatron.mpu.initialize import get_index_ready, get_train_group, get_data_parallel_group
from megatron.utils import check_adlr_autoresume_termination
from megatron.utils import make_data_loader
from megatron.utils import report_memory
@@ -236,7 +236,7 @@ def backward_step(optimizer, model, loss):
"""Backward step."""
args = get_args()
timers = get_timers()
torch.cuda.synchronize()
# torch.cuda.synchronize()
# Backward pass.
optimizer.zero_grad(set_grads_to_None=True)
@@ -373,19 +373,10 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
timers('interval time').start()
report_memory_flag = True
import time
print(">>> going to sleep", flush=True)
time.sleep(10)
print(">>> woke from sleep", flush=True)
print(time.ctime(time.time()), flush=True)
global INDEX_READY
recv_handle = torch.distributed.broadcast(INDEX_READY, args.max_training_rank, async_op=True)
print(">>>>>>>> Created recv handle", flush=True)
while iteration < args.train_iters:
print("INDEX READY: ", INDEX_READY)
if args.max_training_rank is not None and INDEX_READY == 1:
print(">>>>>>> entering the good stuff", flush=True)
true_model = model
if hasattr(true_model, 'module'):
true_model = true_model.module
@@ -393,24 +384,24 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
true_model = true_model.module
print(">>>>>>> starting to reload index", flush=True)
true_model.retriever.reload_index()
print(">>>>>>> starting to save checkpoint", flush=True)
save_checkpoint(iteration, model, optimizer, lr_scheduler)
print(">>>>>>> saved checkpoint", flush=True)
if args.rank == 0:
INDEX_READY = 1 - INDEX_READY
print("Switched index ready", flush=True)
send_handle = torch.distributed.broadcast(INDEX_READY, 0, async_op=True)
torch.distributed.barrier(get_train_group())
print(">>> Switched index ready", flush=True)
torch.cuda.synchronize()
send_handle = torch.distributed.broadcast(INDEX_READY, 0)
torch.distributed.barrier(get_data_parallel_group())
recv_handle = torch.distributed.broadcast(INDEX_READY, args.max_training_rank, async_op=True)
else:
print(">>>>>>> moving right along", flush=True)
print("moving right along", flush=True)
loss_dict, skipped_iter = train_step(forward_step_func,
train_data_iterator,
model,
optimizer,
lr_scheduler)
skipped_iters += skipped_iter
iteration += 1
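
The trainer-side half of the INDEX_READY handshake in this loop, sketched with hypothetical reload_index and save_ckpt callables standing in for retriever.reload_index() and save_checkpoint(...); the rank layout and flag follow the diff, the rest is simplified.

import torch
import torch.distributed as dist

def maybe_swap_index(index_ready, recv_handle, reload_index, save_ckpt,
                     max_training_rank, data_parallel_group):
    """Trainer-side half of the INDEX_READY handshake (sketch)."""
    if int(index_ready.item()) != 1:
        return recv_handle                 # nothing new from the indexers yet
    reload_index()                         # point the retriever at the fresh shards
    save_ckpt()                            # indexers restart from these weights
    if dist.get_rank() == 0:
        index_ready.copy_(1 - index_ready)   # acknowledge by toggling back
        if torch.cuda.is_available():
            torch.cuda.synchronize()
    dist.broadcast(index_ready, src=0)
    dist.barrier(group=data_parallel_group)
    # Re-arm the receive for the indexers' next "index ready" signal.
    return dist.broadcast(index_ready, src=max_training_rank, async_op=True)
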
@@ -443,7 +434,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
iteration, False)
if args.exit_interval and iteration % args.exit_interval == 0:
torch.distributed.barrier()
torch.distributed.barrier(get_data_parallel_group())
time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
rank = torch.distributed.get_rank()
print_rank_0('rank: {} | time: {} | exiting the program at '
@@ -25,6 +25,7 @@ from megatron import mpu
from megatron import print_rank_0
from megatron.checkpointing import save_checkpoint
from megatron.data.samplers import DistributedBatchSampler
from megatron.mpu.initialize import get_data_parallel_group
from megatron.fp16 import FP16_Optimizer
@@ -32,7 +33,7 @@ def reduce_losses(losses):
"""Reduce a tensor of losses across all GPUs."""
reduced_losses = torch.cat(
[loss.clone().detach().view(1) for loss in losses])
torch.distributed.all_reduce(reduced_losses)
torch.distributed.all_reduce(reduced_losses, group=get_data_parallel_group())
reduced_losses = reduced_losses / torch.distributed.get_world_size()
return reduced_losses
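
A sketch of reduce_losses restricted to an explicit group; dividing by that group's world size (rather than the global one) is one consistent choice now that the index-builder ranks never enter this collective.

import torch
import torch.distributed as dist

def reduce_losses(losses, group):
    """Average scalar losses across the ranks of `group` (sketch)."""
    reduced = torch.cat([loss.clone().detach().view(1) for loss in losses])
    dist.all_reduce(reduced, group=group)
    reduced /= dist.get_world_size(group=group)
    return reduced
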
@@ -78,7 +79,7 @@ def check_adlr_autoresume_termination(iteration, model,
args = get_args()
autoresume = get_adlr_autoresume()
# Add barrier to ensure consistency.
torch.distributed.barrier()
torch.distributed.barrier(get_data_parallel_group())
if autoresume.termination_requested():
if args.save:
save_checkpoint(iteration, model, optimizer, lr_scheduler)
@@ -102,7 +102,7 @@ def forward_step(data_iterator, model):
lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
reduced_loss = reduce_losses([lm_loss, retrieval_utility])
torch.cuda.synchronize()
# torch.cuda.synchronize()
return lm_loss, {'lm_loss': reduced_loss[0], 'retrieval_utility': reduced_loss[1]}
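
The masked loss average on the line above, as a tiny self-contained example.

import torch

def masked_lm_loss(per_token_loss, loss_mask):
    """Average a per-token loss over unmasked positions only."""
    loss_mask = loss_mask.float()
    return torch.sum(per_token_loss.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()

loss = torch.tensor([[1.0, 2.0, 3.0]])
mask = torch.tensor([[1.0, 0.0, 1.0]])
print(masked_lm_loss(loss, mask))   # tensor(2.) -- only the two unmasked tokens count
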