"vscode:/vscode.git/clone" did not exist on "3d4a2eadd3c1481447b8e885018ed24341ea91a5"
Commit a670b6c9 authored by Neel Kant

Async works for a total of 8 GPUs; indexer debug mode

parent 5684f904
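
This commit moves the trainer/indexer synchronization onto the data-parallel group, removes the sleep-based debug handoff, and signals index refreshes through an asynchronously broadcast INDEX_READY flag. As a rough illustration of that handshake (not the commit's code; main_builder_rank, index_group, and the helper names below are hypothetical stand-ins):

# Minimal sketch of the INDEX_READY handshake set up in the diff below; the names
# main_builder_rank and index_group are placeholders, not Megatron APIs.
import torch
import torch.distributed as dist

def indexer_signal_ready(index_ready, main_builder_rank, index_group):
    """Index-builder side: flip the flag on the main builder, then notify trainers."""
    if dist.get_rank() == main_builder_rank:
        index_ready.fill_(1 - int(index_ready.item()))
        torch.cuda.synchronize()  # flush pending GPU work before signalling
    # non-blocking send toward the trainers, then a barrier only within the indexer group
    send_handle = dist.broadcast(index_ready, src=main_builder_rank, async_op=True)
    dist.barrier(group=index_group)
    return send_handle

def trainer_arm_receive(index_ready, main_builder_rank):
    """Trainer side: post a non-blocking receive and poll index_ready each iteration."""
    return dist.broadcast(index_ready, src=main_builder_rank, async_op=True)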
@@ -16,7 +16,7 @@ from megatron.data.samplers import DistributedBatchSampler
from megatron.initialize import initialize_megatron
from megatron.model import REALMRetriever
from megatron.global_vars import set_global_variables
-from megatron.mpu.initialize import get_index_ready, get_index_group, get_train_group
+from megatron.mpu.initialize import get_index_ready, get_index_group, get_train_group, get_data_parallel_group
from megatron.mpu.initialize import set_data_parallel_group, set_model_parallel_group, init_realm_groups
from megatron.initialize import init_distributed, _init_autoresume, _set_random_seed, _write_args_to_tensorboard
from megatron.training import get_model
@@ -67,12 +67,12 @@ def initialize_and_run_async_megatron(extra_args_provider=None, args_defaults={}
torch.distributed.barrier()
if args.rank < args.max_training_rank:
-torch.distributed.barrier(get_train_group())
+torch.distributed.barrier(get_data_parallel_group())
pprint("All trainers ready.")
return
else:
runner = AsyncIndexBuilder(args.rank)
-torch.distributed.barrier(get_index_group())
+torch.distributed.barrier(get_data_parallel_group())
pprint("All indexers ready.")
runner.run_async()
@@ -123,6 +123,7 @@ class AsyncIndexBuilder(object):
try:
self.model = load_ict_checkpoint(only_block_model=True, no_grad=True, from_realm_chkpt=True)
except:
+print(">>>>> No realm chkpt available", flush=True)
self.model = load_ict_checkpoint(only_block_model=True, no_grad=True, from_realm_chkpt=False)
self.model.eval()
self.dataloader = iter(get_one_epoch_dataloader(get_ict_dataset()))
@@ -148,7 +149,7 @@ class AsyncIndexBuilder(object):
total += block_indices.size
i += 1
-if i % 10 == 0:
+if i % 500 == 0:
print('Batch {:10d} | Total {:10d}'.format(i, total), flush=True)
if self.debug:
break
@@ -162,7 +163,7 @@ class AsyncIndexBuilder(object):
sys.exit(0)
self.block_data.save_shard(self.rank)
-torch.distributed.barrier()
+torch.distributed.barrier(get_data_parallel_group())
del self.model
if self.is_main_builder:
@@ -174,12 +175,11 @@ class AsyncIndexBuilder(object):
if self.is_main_builder:
INDEX_READY = 1 - INDEX_READY
print("Switched INDEX_READY", flush=True)
-import time
-print(time.ctime(time.time()), flush=True)
+torch.cuda.synchronize()
send_handle = dist.broadcast(INDEX_READY, self.main_builder_idx, async_op=True)
-torch.distributed.barrier(get_index_group())
+torch.distributed.barrier(get_data_parallel_group())
-recv_handle = dist.broadcast(INDEX_READY, 0, async_op=True)
+recv_handle = dist.broadcast(INDEX_READY, 0)
class BasicIndexBuilder(object):
@@ -236,7 +236,7 @@ def load_ict_checkpoint(only_query_model=False, only_block_model=False, no_grad=
with open(tracker_filename, 'r') as f:
iteration = int(f.read().strip())
-assert iteration > 0
+# assert iteration > 0
checkpoint_name = get_checkpoint_name(load_path, iteration, False)
if mpu.get_data_parallel_rank() == 0:
print('global rank {} is loading checkpoint {}'.format(
@@ -245,6 +245,7 @@ def load_ict_checkpoint(only_query_model=False, only_block_model=False, no_grad=
state_dict = torch.load(checkpoint_name, map_location='cpu')
ict_state_dict = state_dict['model']
if from_realm_chkpt:
+print(">>>> Attempting to get ict state dict from realm", flush=True)
ict_state_dict = ict_state_dict['retriever']['ict_model']
if only_query_model:
@@ -256,7 +257,7 @@ def load_ict_checkpoint(only_query_model=False, only_block_model=False, no_grad=
model.load_state_dict(ict_state_dict)
else:
model.load_state_dict(ict_state_dict)
-torch.distributed.barrier()
+torch.distributed.barrier(get_data_parallel_group())
if mpu.get_data_parallel_rank() == 0:
print(' successfully loaded {}'.format(checkpoint_name))
@@ -290,9 +291,7 @@ def get_one_epoch_dataloader(dataset):
args = get_args()
world_size = mpu.get_data_parallel_world_size()
-print(world_size, flush=True)
rank = mpu.get_data_parallel_rank()
-print(rank, flush=True)
global_batch_size = args.batch_size * world_size
num_workers = args.num_workers
...
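
For context on why the barriers above now take an explicit group: trainers and index builders run different loops, so a bare torch.distributed.barrier() over the whole world would stall one side while the other keeps working. The actual grouping presumably lives in init_realm_groups (imported above but not shown in this diff); a minimal sketch of the usual new_group pattern, assuming the first max_training_rank ranks train and the rest build the index, is:

# Sketch only: disjoint process groups for trainers vs. index builders.
# Every rank must call new_group() for both groups, even ranks that are not members.
import torch.distributed as dist

def make_realm_groups(world_size, max_training_rank):
    train_group = dist.new_group(ranks=list(range(max_training_rank)))
    index_group = dist.new_group(ranks=list(range(max_training_rank, world_size)))
    return train_group, index_group

# Each side then synchronizes only within its own group, e.g.
#   dist.barrier(group=train_group)   # trainers
#   dist.barrier(group=index_group)   # index builders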
@@ -24,7 +24,7 @@ import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron import mpu
-from megatron.mpu.initialize import get_train_group
+from megatron.mpu.initialize import get_train_group, get_data_parallel_group
from megatron import get_args
from megatron import print_rank_0
@@ -119,14 +119,14 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
print(' successfully saved {}'.format(checkpoint_name))
# Wait so everyone is done (necessary)
-torch.distributed.barrier(get_train_group())
+torch.distributed.barrier(get_data_parallel_group())
# And update the latest iteration
if torch.distributed.get_rank() == 0:
tracker_filename = get_checkpoint_tracker_filename(args.save)
with open(tracker_filename, 'w') as f:
f.write(str(iteration))
# Wait so everyone is done (not necessary)
-torch.distributed.barrier(get_train_group())
+torch.distributed.barrier(get_data_parallel_group())
def load_checkpoint(model, optimizer, lr_scheduler):
@@ -243,7 +243,7 @@ def load_checkpoint(model, optimizer, lr_scheduler):
'exiting ...'.format(checkpoint_name))
sys.exit()
-torch.distributed.barrier()
+# torch.distributed.barrier()
if mpu.get_data_parallel_rank() == 0:
print(' successfully loaded {}'.format(checkpoint_name))
...
@@ -164,14 +164,14 @@ class _Timer:
def start(self):
"""Start the timer."""
assert not self.started_, 'timer has already been started'
-torch.cuda.synchronize()
+# torch.cuda.synchronize()
self.start_time = time.time()
self.started_ = True
def stop(self):
"""Stop the timer."""
assert self.started_, 'timer is not started'
-torch.cuda.synchronize()
+# torch.cuda.synchronize()
self.elapsed_ += (time.time() - self.start_time)
self.started_ = False
...
@@ -56,7 +56,7 @@ class DistributedDataParallel(MegatronModule):
if not no_scale and not reduce_after:
coalesced /= dist.get_world_size(group=self.data_parallel_group)
dist.all_reduce(coalesced, group=self.data_parallel_group)
-torch.cuda.synchronize()
+# torch.cuda.synchronize()
if not no_scale and reduce_after:
coalesced /= dist.get_world_size(group=self.data_parallel_group)
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
...
@@ -199,7 +199,9 @@ class REALMRetriever(MegatronModule):
true_model = true_model.module
else:
true_model = self.ict_model
-query_embeds = detach(true_model.embed_query(query_tokens, query_pad_mask))
+# print("true model: ", true_model, flush=True)
+query_embeds = detach(self.ict_model(query_tokens, query_pad_mask, None, None, only_query=True))
_, block_indices = self.hashed_index.search_mips_index(query_embeds, top_k=self.top_k, reconstruct=False)
all_topk_tokens, all_topk_pad_masks = [], []
@@ -268,7 +270,6 @@ class ICTBertModel(MegatronModule):
def forward(self, query_tokens, query_attention_mask, block_tokens, block_attention_mask, only_query=False, only_block=False):
"""Run a forward pass for each of the models and compute the similarity scores."""
if only_query:
return self.embed_query(query_tokens, query_attention_mask)
...
@@ -109,10 +109,8 @@ def set_model_parallel_group(group):
def get_data_parallel_group():
"""Get the data parallel group the caller rank belongs to."""
-#print(">>> yeah this function works.")
assert _DATA_PARALLEL_GROUP is not None, \
'data parallel group is not initialized'
-#print(_DATA_PARALLEL_GROUP)
return _DATA_PARALLEL_GROUP
...
@@ -36,7 +36,7 @@ from megatron.initialize import initialize_megatron
from megatron.learning_rates import AnnealingLR
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import get_params_for_weight_decay_optimization
-from megatron.mpu.initialize import get_index_ready, get_train_group
+from megatron.mpu.initialize import get_index_ready, get_train_group, get_data_parallel_group
from megatron.utils import check_adlr_autoresume_termination
from megatron.utils import make_data_loader
from megatron.utils import report_memory
@@ -236,7 +236,7 @@ def backward_step(optimizer, model, loss):
"""Backward step."""
args = get_args()
timers = get_timers()
-torch.cuda.synchronize()
+# torch.cuda.synchronize()
# Backward pass.
optimizer.zero_grad(set_grads_to_None=True)
@@ -373,19 +373,10 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
timers('interval time').start()
report_memory_flag = True
-import time
-print(">>> going to sleep", flush=True)
-time.sleep(10)
-print(">>> woke from sleep", flush=True)
-print(time.ctime(time.time()), flush=True)
global INDEX_READY
recv_handle = torch.distributed.broadcast(INDEX_READY, args.max_training_rank, async_op=True)
-print(">>>>>>>> Created recv handle", flush=True)
while iteration < args.train_iters:
-print("INDEX READY: ", INDEX_READY)
if args.max_training_rank is not None and INDEX_READY == 1:
-print(">>>>>>> entering the good stuff", flush=True)
true_model = model
if hasattr(true_model, 'module'):
true_model = true_model.module
@@ -393,24 +384,24 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
true_model = true_model.module
print(">>>>>>> starting to reload index", flush=True)
true_model.retriever.reload_index()
-print(">>>>>>> starting to save checkpoint", flush=True)
save_checkpoint(iteration, model, optimizer, lr_scheduler)
-print(">>>>>>> saved checkpoint", flush=True)
if args.rank == 0:
INDEX_READY = 1 - INDEX_READY
-print("Switched index ready", flush=True)
+print(">>> Switched index ready", flush=True)
-send_handle = torch.distributed.broadcast(INDEX_READY, 0, async_op=True)
-torch.distributed.barrier(get_train_group())
+torch.cuda.synchronize()
+send_handle = torch.distributed.broadcast(INDEX_READY, 0)
+torch.distributed.barrier(get_data_parallel_group())
recv_handle = torch.distributed.broadcast(INDEX_READY, args.max_training_rank, async_op=True)
else:
-print(">>>>>>> moving right along", flush=True)
+print("moving right along", flush=True)
loss_dict, skipped_iter = train_step(forward_step_func,
train_data_iterator,
model,
optimizer,
lr_scheduler)
skipped_iters += skipped_iter
iteration += 1
@@ -443,7 +434,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
iteration, False)
if args.exit_interval and iteration % args.exit_interval == 0:
-torch.distributed.barrier()
+torch.distributed.barrier(get_data_parallel_group())
time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
rank = torch.distributed.get_rank()
print_rank_0('rank: {} | time: {} | exiting the program at '
...
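
The training loop above polls the asynchronously broadcast INDEX_READY flag once per iteration; when it flips, the trainer reloads the MIPS index, saves a checkpoint, flips the flag back, and re-arms the receive. A rough sketch of that poll-and-rearm pattern (illustration only; the helper names stand in for the Megatron functions shown in the diff):

# Illustration of the poll-and-rearm pattern; reload_index and the training step are
# placeholders for the functions that appear in the diff above.
import torch
import torch.distributed as dist

def train_loop(model, index_ready, max_training_rank, train_group, train_iters):
    recv_handle = dist.broadcast(index_ready, src=max_training_rank, async_op=True)
    iteration = 0
    while iteration < train_iters:
        if int(index_ready.item()) == 1:
            model.retriever.reload_index()        # pick up the freshly built index
            # ... save a checkpoint so the index builders see the new weights ...
            if dist.get_rank() == 0:
                index_ready.fill_(1 - int(index_ready.item()))
                torch.cuda.synchronize()
            dist.broadcast(index_ready, src=0)    # tell the index builders we are done
            dist.barrier(group=train_group)
            recv_handle = dist.broadcast(index_ready, src=max_training_rank, async_op=True)
        else:
            # ... run one regular training step ...
            iteration += 1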
@@ -25,6 +25,7 @@ from megatron import mpu
from megatron import print_rank_0
from megatron.checkpointing import save_checkpoint
from megatron.data.samplers import DistributedBatchSampler
+from megatron.mpu.initialize import get_data_parallel_group
from megatron.fp16 import FP16_Optimizer
@@ -32,7 +33,7 @@ def reduce_losses(losses):
"""Reduce a tensor of losses across all GPUs."""
reduced_losses = torch.cat(
[loss.clone().detach().view(1) for loss in losses])
-torch.distributed.all_reduce(reduced_losses)
+torch.distributed.all_reduce(reduced_losses, group=get_data_parallel_group())
reduced_losses = reduced_losses / torch.distributed.get_world_size()
return reduced_losses
@@ -78,7 +79,7 @@ def check_adlr_autoresume_termination(iteration, model,
args = get_args()
autoresume = get_adlr_autoresume()
# Add barrier to ensure consistnecy.
-torch.distributed.barrier()
+torch.distributed.barrier(get_data_parallel_group())
if autoresume.termination_requested():
if args.save:
save_checkpoint(iteration, model, optimizer, lr_scheduler)
...
@@ -102,7 +102,7 @@ def forward_step(data_iterator, model):
lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
reduced_loss = reduce_losses([lm_loss, retrieval_utility])
-torch.cuda.synchronize()
+# torch.cuda.synchronize()
return lm_loss, {'lm_loss': reduced_loss[0], 'retrieval_utility': reduced_loss[1]}
...