Commit dca47cfb authored by Mostofa Patwary's avatar Mostofa Patwary
Browse files

debugging DPR

parent f64977fd
import sys import sys
import time
import torch import torch
import torch.distributed as dist import torch.distributed as dist
...@@ -102,7 +103,12 @@ class IndexBuilder(object): ...@@ -102,7 +103,12 @@ class IndexBuilder(object):
while not hasattr(unwrapped_model, 'embed_text'): while not hasattr(unwrapped_model, 'embed_text'):
unwrapped_model = unwrapped_model.module unwrapped_model = unwrapped_model.module
counter = 0
start_time = time.time()
cur_time = start_time
while True: while True:
#start_time = time.time()
t1 = time.time()
try: try:
# batch also has query_tokens and query_pad_data # batch also has query_tokens and query_pad_data
row_id, context_tokens, context_mask, context_types, \ row_id, context_tokens, context_mask, context_types, \
...@@ -111,6 +117,8 @@ class IndexBuilder(object): ...@@ -111,6 +117,8 @@ class IndexBuilder(object):
except (StopIteration, IndexError): except (StopIteration, IndexError):
break break
#print_rank_0("get batch time {}".format(cur_time - time.time()))
t2 = time.time()
# TODO: can we add with torch.no_grad() to reduce memory usage # TODO: can we add with torch.no_grad() to reduce memory usage
# detach, separate fields and add to BlockData # detach, separate fields and add to BlockData
assert context_mask.dtype == torch.bool assert context_mask.dtype == torch.bool
...@@ -120,10 +128,18 @@ class IndexBuilder(object): ...@@ -120,10 +128,18 @@ class IndexBuilder(object):
context_logits = detach(context_logits) context_logits = detach(context_logits)
row_id = detach(row_id) row_id = detach(row_id)
#print_rank_0("embed text {}".format(cur_time - time.time()))
t3 = time.time()
self.evidence_embedder_obj.add_block_data(row_id, context_logits) self.evidence_embedder_obj.add_block_data(row_id, context_logits)
self.track_and_report_progress(batch_size=len(row_id)) self.track_and_report_progress(batch_size=len(row_id))
#print_rank_0("add block time {}".format(cur_time - time.time()))
t4 = time.time()
counter += 1
if counter % 1000 == 0:
print_rank_0("total time {} 1000 iter time {}".format(time.time() - start_time, time.time() - cur_time))
print_rank_0("breakdown batch {} model {} block {}".format(t2 - t1, t3 - t2, t4 -t3))
cur_time = time.time()
# This process signals to finalize its shard and then synchronize with # This process signals to finalize its shard and then synchronize with
# the other processes # the other processes
self.evidence_embedder_obj.save_shard() self.evidence_embedder_obj.save_shard()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment