OpenDAS / Megatron-LM · Commits

Commit e59496bf
Authored May 24, 2020 by Neel Kant

Restructure Indexer classes

parent e0a1caba

Showing 1 changed file with 49 additions and 67 deletions.

indexer.py (+49, -67)
@@ -92,39 +92,18 @@ def setup_realm_groups_and_vars():
     set_data_parallel_group(get_index_group())


-class AsyncIndexBuilder(object):
-    def __init__(self, rank):
-        self.rank = rank
+class IndexBuilder(object):
+    def __init__(self):
         args = get_args()
-        self.is_main_builder = self.rank == args.max_training_rank
-        self.main_builder_idx = args.max_training_rank
         self.debug = args.debug
+        self.rank = args.rank
         self.model = None
         self.dataloader = None
         self.block_data = None
         self.load_attributes()
-        global INDEX_READY
-        INDEX_READY = get_index_ready()
-
-    def run_async(self):
-        while True:
-            print("Starting (again!)", flush=True)
-            self.build_and_save_index()
-            self.send_index_ready_signal()
-            while INDEX_READY == 1:
-                print("Waiting for new model checkpoint.", flush=True)
-                time.sleep(5)
-            self.load_attributes()
+        self.is_main_builder = args.rank == 0

     def load_attributes(self):
         try:
             self.model = load_ict_checkpoint(only_block_model=True, no_grad=True, from_realm_chkpt=True)
         except:
-            print(">>>>> No realm chkpt available", flush=True)
-            self.model = load_ict_checkpoint(only_block_model=True, no_grad=True, from_realm_chkpt=False)
+            self.model = load_ict_checkpoint(only_block_model=True, no_grad=True, from_realm_chkpt=False)
         self.model.eval()
         self.dataloader = iter(get_one_epoch_dataloader(get_ict_dataset()))
         self.block_data = BlockData()
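Note: taken together with the third hunk below, this commit leaves a shared IndexBuilder base class and reintroduces AsyncIndexBuilder as a subclass of it. A minimal sketch of that split follows; everything except the class and method names is a stand-in, since the real helpers (get_args, load_ict_checkpoint, BlockData, get_one_epoch_dataloader) live elsewhere in Megatron-LM.

# Sketch only: Megatron-LM helpers are stubbed; just the class split
# introduced by this commit is mirrored.

class IndexBuilder(object):
    """One-shot builder: load a block model, embed blocks, save the index."""

    def __init__(self):
        self.model = None
        self.block_data = None
        self.load_attributes()

    def load_attributes(self):
        # Stands in for the try/except in the diff: prefer a REALM
        # checkpoint, fall back to the plain ICT checkpoint.
        self.model = "block-model"
        self.block_data = {}

    def build_and_save_index(self):
        print("embedding blocks and saving the index")


class AsyncIndexBuilder(IndexBuilder):
    """Adds the rebuild loop on top of the shared build logic."""

    def __init__(self, rank):
        self.rank = rank
        super().__init__()  # the actual commit inlines this state instead

    def run_async(self):
        for _ in range(2):  # stands in for the `while True` loop
            self.build_and_save_index()
            self.load_attributes()  # reload after a new checkpoint lands


if __name__ == "__main__":
    AsyncIndexBuilder(rank=0).run_async()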
@@ -149,7 +128,7 @@ class AsyncIndexBuilder(object):
             total += block_indices.size
             i += 1
-            if i % 500 == 0:
+            if i % 1000 == 0:
                 print('Batch {:10d} | Total {:10d}'.format(i, total), flush=True)
             if self.debug:
                 break
@@ -162,57 +141,60 @@ class AsyncIndexBuilder(object):
         self.block_data.consolidate_shards_and_save(ignore_shard=self.rank)
         self.block_data.clear()

-    def send_index_ready_signal(self):
+
+class AsyncIndexBuilder(IndexBuilder):
+    def __init__(self, rank):
+        self.rank = rank
+        args = get_args()
+        self.is_main_builder = self.rank == args.max_training_rank
+        self.main_builder_idx = args.max_training_rank
+        self.debug = args.debug
+        self.model = None
+        self.dataloader = None
+        self.block_data = None
+        self.load_attributes()
         global INDEX_READY
-        if self.is_main_builder:
-            INDEX_READY = 1 - INDEX_READY
-            print("Switched INDEX_READY", flush=True)
-        torch.cuda.synchronize()
-        send_handle = dist.broadcast(INDEX_READY, self.main_builder_idx, group=get_gloo_comm_group(), async_op=True)
         INDEX_READY = get_index_ready()
-        torch.distributed.barrier(get_data_parallel_group())

+    def run_async(self):
+        global INDEX_READY
+        # synchronize for start
+        dist.broadcast(INDEX_READY, 0, group=get_gloo_comm_group())
+        while True:
+            print("Starting (again!)", flush=True)
+            self.build_and_save_index()
+            self.send_index_ready_signal()
+            while INDEX_READY == 1:
+                print("Waiting for new model checkpoint.", flush=True)
+                time.sleep(5)
+            self.load_attributes()

-class BasicIndexBuilder(object):
-    def __init__(self):
-        args = get_args()
-        self.rank = args.rank
-        self.model = load_ict_checkpoint(only_block_model=True, no_grad=True, from_realm_chkpt=False)
-
-    def load_attributes(self):
-        try:
-            self.model = load_ict_checkpoint(only_block_model=True, no_grad=True, from_realm_chkpt=True)
-        except:
-            print(">>>>> No realm chkpt available", flush=True)
-            self.model = load_ict_checkpoint(only_block_model=True, no_grad=True, from_realm_chkpt=False)
-        self.model.eval()
-        self.dataloader = iter(get_one_epoch_dataloader(get_ict_dataset()))
-        self.block_data = BlockData()
-
-    def build_and_save_index(self):
-        i = 1
-        total = 0
-        while True:
-            with torch.no_grad():
-                try:
-                    query_tokens, query_pad_mask, \
-                        block_tokens, block_pad_mask, block_index_data = get_batch(self.dataloader)
-                except:
-                    break
-            block_index_data = detach(block_index_data)
-            block_indices = block_index_data[:, 3]
-            block_meta = block_index_data[:, :3]
-            block_logits = detach(self.model(None, None, block_tokens, block_pad_mask, only_block=True))
-            self.block_data.add_block_data(block_indices, block_logits, block_meta)
-            total += block_indices.size
-            i += 1
-            if i % 2000 == 0:
-                print('Batch {:10d} | Total {:10d}'.format(i, total), flush=True)

+    def send_index_ready_signal(self):
+        global INDEX_READY
+        if self.is_main_builder:
+            INDEX_READY = 1 - INDEX_READY
+            print("Switched INDEX_READY", flush=True)
+        torch.cuda.synchronize()
+        self.block_data.save_shard(self.rank)
+        torch.distributed.barrier()
+        del self.model
+        # send handle
+        dist.broadcast(INDEX_READY, self.main_builder_idx, group=get_gloo_comm_group(), async_op=True)
+        if self.rank == 0:
+            self.block_data.consolidate_shards_and_save(ignore_shard=self.rank)
+            self.block_data.clear()
+        # recv handle
+        dist.broadcast(INDEX_READY, 0, group=get_gloo_comm_group())
+        torch.distributed.barrier(get_data_parallel_group())


 def load_ict_checkpoint(only_query_model=False, only_block_model=False, no_grad=False, from_realm_chkpt=False):
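Note: the send_index_ready_signal handshake above toggles a shared INDEX_READY flag on the main builder rank and broadcasts it over a gloo group so the training ranks learn that a fresh index is available. A runnable single-process sketch of that broadcast pattern follows; the world_size=1 setup, TCP init address, and src=0 are assumptions standing in for the real get_gloo_comm_group() and main_builder_idx.

# Minimal single-process sketch of the INDEX_READY broadcast handshake.
import torch
import torch.distributed as dist

dist.init_process_group(
    backend="gloo",
    init_method="tcp://127.0.0.1:29500",
    rank=0,
    world_size=1,
)

index_ready = torch.tensor([0])      # shared readiness flag
index_ready[0] = 1 - index_ready[0]  # main builder toggles it (0 -> 1)

# Non-blocking broadcast returns a work handle, like `send_handle` in the
# old code; other ranks would see the toggled flag and stop waiting.
work = dist.broadcast(index_ready, src=0, async_op=True)
work.wait()
print("INDEX_READY ->", int(index_ready[0]))

dist.destroy_process_group()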