Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
dca47cfb
Commit
dca47cfb
authored
May 14, 2021
by
Mostofa Patwary
Browse files
debugging DPR
parent
f64977fd
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
18 additions
and
2 deletions
+18
-2
megatron/indexer.py
megatron/indexer.py
+18
-2
No files found.
megatron/indexer.py
View file @
dca47cfb
import
sys
import
sys
import
time
import
torch
import
torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
...
@@ -102,7 +103,12 @@ class IndexBuilder(object):
...
@@ -102,7 +103,12 @@ class IndexBuilder(object):
while
not
hasattr
(
unwrapped_model
,
'embed_text'
):
while
not
hasattr
(
unwrapped_model
,
'embed_text'
):
unwrapped_model
=
unwrapped_model
.
module
unwrapped_model
=
unwrapped_model
.
module
counter
=
0
start_time
=
time
.
time
()
cur_time
=
start_time
while
True
:
while
True
:
#start_time = time.time()
t1
=
time
.
time
()
try
:
try
:
# batch also has query_tokens and query_pad_data
# batch also has query_tokens and query_pad_data
row_id
,
context_tokens
,
context_mask
,
context_types
,
\
row_id
,
context_tokens
,
context_mask
,
context_types
,
\
...
@@ -111,6 +117,8 @@ class IndexBuilder(object):
...
@@ -111,6 +117,8 @@ class IndexBuilder(object):
except
(
StopIteration
,
IndexError
):
except
(
StopIteration
,
IndexError
):
break
break
#print_rank_0("get batch time {}".format(cur_time - time.time()))
t2
=
time
.
time
()
# TODO: can we add with torch.no_grad() to reduce memory usage
# TODO: can we add with torch.no_grad() to reduce memory usage
# detach, separate fields and add to BlockData
# detach, separate fields and add to BlockData
assert
context_mask
.
dtype
==
torch
.
bool
assert
context_mask
.
dtype
==
torch
.
bool
...
@@ -120,10 +128,18 @@ class IndexBuilder(object):
...
@@ -120,10 +128,18 @@ class IndexBuilder(object):
context_logits
=
detach
(
context_logits
)
context_logits
=
detach
(
context_logits
)
row_id
=
detach
(
row_id
)
row_id
=
detach
(
row_id
)
#print_rank_0("embed text {}".format(cur_time - time.time()))
t3
=
time
.
time
()
self
.
evidence_embedder_obj
.
add_block_data
(
row_id
,
context_logits
)
self
.
evidence_embedder_obj
.
add_block_data
(
row_id
,
context_logits
)
self
.
track_and_report_progress
(
batch_size
=
len
(
row_id
))
self
.
track_and_report_progress
(
batch_size
=
len
(
row_id
))
#print_rank_0("add block time {}".format(cur_time - time.time()))
t4
=
time
.
time
()
counter
+=
1
if
counter
%
1000
==
0
:
print_rank_0
(
"total time {} 1000 iter time {}"
.
format
(
time
.
time
()
-
start_time
,
time
.
time
()
-
cur_time
))
print_rank_0
(
"breakdown batch {} model {} block {}"
.
format
(
t2
-
t1
,
t3
-
t2
,
t4
-
t3
))
cur_time
=
time
.
time
()
# This process signals to finalize its shard and then synchronize with
# This process signals to finalize its shard and then synchronize with
# the other processes
# the other processes
self
.
evidence_embedder_obj
.
save_shard
()
self
.
evidence_embedder_obj
.
save_shard
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment