Commit dca47cfb authored by Mostofa Patwary
Browse files

debugging DPR

parent f64977fd
import sys
import time
import torch
import torch.distributed as dist
......@@ -102,7 +103,12 @@ class IndexBuilder(object):
while not hasattr(unwrapped_model, 'embed_text'):
unwrapped_model = unwrapped_model.module
counter = 0
start_time = time.time()
cur_time = start_time
while True:
#start_time = time.time()
t1 = time.time()
try:
# batch also has query_tokens and query_pad_data
row_id, context_tokens, context_mask, context_types, \
......@@ -111,6 +117,8 @@ class IndexBuilder(object):
except (StopIteration, IndexError):
break
#print_rank_0("get batch time {}".format(cur_time - time.time()))
t2 = time.time()
# TODO: can we add with torch.no_grad() to reduce memory usage
# detach, separate fields and add to BlockData
assert context_mask.dtype == torch.bool
......@@ -120,10 +128,18 @@ class IndexBuilder(object):
context_logits = detach(context_logits)
row_id = detach(row_id)
#print_rank_0("embed text {}".format(cur_time - time.time()))
t3 = time.time()
self.evidence_embedder_obj.add_block_data(row_id, context_logits)
self.track_and_report_progress(batch_size=len(row_id))
#print_rank_0("add block time {}".format(cur_time - time.time()))
t4 = time.time()
counter += 1
if counter % 1000 == 0:
print_rank_0("total time {} 1000 iter time {}".format(time.time() - start_time, time.time() - cur_time))
print_rank_0("breakdown batch {} model {} block {}".format(t2 - t1, t3 - t2, t4 -t3))
cur_time = time.time()
# This process signals to finalize its shard and then synchronize with
# the other processes
self.evidence_embedder_obj.save_shard()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment