OpenDAS / Megatron-LM

Commit 2f7d666c, authored May 14, 2020 by Neel Kant

Add retrieval utility and autoresume for indexer

parent 9b9b8e01

Showing 5 changed files with 53 additions and 14 deletions
indexer.py                             +18  -11
megatron/arguments.py                   +2   -0
megatron/data/realm_dataset_utils.py    +4   -0
megatron/model/realm_model.py           +2   -0
pretrain_realm.py                      +27   -3
indexer.py

```diff
 import os
 import sys
 import time

 import torch
 from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

-from megatron import get_args
+from megatron import get_args, get_adlr_autoresume, print_rank_0
 from megatron import mpu
 from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
 from megatron.data.bert_dataset import get_indexed_dataset_
 from megatron.data.realm_dataset import ICTDataset
-from megatron.data.realm_index import detach, BlockData, RandProjectionLSHIndex
+from megatron.data.realm_index import detach, BlockData, FaissMIPSIndex
 from megatron.data.samplers import DistributedBatchSampler
 from megatron.initialize import initialize_megatron
 from megatron.model import REALMRetriever
 from megatron.training import get_model
 from megatron.utils import check_adlr_autoresume_termination
 from pretrain_bert_ict import get_batch, model_provider
+from indexer_utils import set_index_com_file_ready, set_model_com_file_not_ready, check_model_com_file_ready
```
```diff
@@ -40,14 +42,14 @@ def test_retriever():
     initialize_megatron(extra_args_provider=None,
                         args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
     args = get_args()
-    model = load_ict_checkpoint(only_block_model=True)
+    model = load_ict_checkpoint()
     model.eval()
     dataset = get_ict_dataset()
     block_data = BlockData.load_from_file(args.block_data_path)
     mips_index = FaissMIPSIndex('flat_ip', 128)
     mips_index.add_block_embed_data(block_data)

-    retriever = REALMRetriever(model, dataset, mips_index, top_k=5)
+    retriever = REALMRetriever(model, dataset, block_data, mips_index, top_k=5)

     strs = ["The last monarch from the house of windsor",
```
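The hunk above loads pre-computed block embeddings into a `FaissMIPSIndex` before constructing the retriever. For readers without the Megatron wrappers, here is a minimal sketch of the same maximum-inner-product retrieval using FAISS directly; the embedding dimension of 128 and `top_k=5` come from the diff, while the random arrays are illustrative stand-ins for real block and query embeddings.

```python
# Sketch: exact inner-product (MIPS) retrieval with a flat FAISS index,
# analogous to what FaissMIPSIndex('flat_ip', 128) wraps.
import numpy as np
import faiss

embed_size = 128                                    # dimension from the diff
block_embeds = np.random.randn(10000, embed_size).astype('float32')

index = faiss.IndexFlatIP(embed_size)               # flat inner-product index
index.add(block_embeds)                             # cf. add_block_embed_data

query_embeds = np.random.randn(4, embed_size).astype('float32')
scores, block_ids = index.search(query_embeds, 5)   # top-5 blocks per query
print(scores.shape, block_ids.shape)                # (4, 5) (4, 5)
```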
```diff
@@ -71,7 +73,6 @@ def main():
     dataset = get_ict_dataset()
     data_iter = iter(get_one_epoch_dataloader(dataset))
     all_block_data = BlockData()
-    hashed_index = RandProjectionLSHIndex(embed_size=128, num_buckets=32, whiten=True)
     i = 1
     total = 0
```
```diff
@@ -103,18 +104,24 @@ def main():
     if args.rank == 0:
         all_block_data.consolidate_shards_and_save()
-        hashed_index.hash_whitened_block_embeds(all_block_data)
-        hashed_index.save_to_file()
     else:
         all_block_data.clear()

     ran_once = True
     set_index_com_file_ready()
     torch.distributed.barrier()
-    while not check_model_com_file_ready():
-        time.sleep(5)
-    set_model_com_file_not_ready()
+    if args.async_indexer:
+        while not check_model_com_file_ready():
+            time.sleep(5)
+            autoresume = get_adlr_autoresume()
+            if autoresume.termination_requested():
+                print_rank_0(">>> autoresume termination request found!")
+                if torch.distributed.get_rank() == 0:
+                    autoresume.request_resume()
+                print_rank_0(">>> training terminated. Returning")
+                sys.exit(0)
+        set_model_com_file_not_ready()


 def load_ict_checkpoint(only_query_model=False, only_block_model=False,
                         no_grad=False, from_realm_chkpt=False):
```
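The `indexer_utils` functions imported above implement a file-based handshake between the asynchronous indexer job and the trainer job: the indexer signals that a fresh index is on disk, then polls for the trainer's "new model" signal, checking for ADLR autoresume termination on each iteration of the wait loop. The module itself is not part of this diff, so the sketch below is one plausible implementation using sentinel files; the directory variable and filenames are assumptions.

```python
# Hypothetical sketch of the indexer_utils handshake used in indexer.py.
# The real module is not shown in this commit; the sentinel-file paths
# and names here are illustrative assumptions.
import os

COM_DIR = os.environ.get('REALM_COM_DIR', '.')        # assumed location
INDEX_READY = os.path.join(COM_DIR, 'index_ready')    # indexer -> trainer
MODEL_READY = os.path.join(COM_DIR, 'model_ready')    # trainer -> indexer

def set_index_com_file_ready():
    # Indexer announces that a freshly built block index is on disk.
    open(INDEX_READY, 'w').close()

def check_model_com_file_ready():
    # Indexer polls this until the trainer has saved a new checkpoint.
    return os.path.exists(MODEL_READY)

def set_model_com_file_not_ready():
    # Consume the trainer's signal so the next round blocks again.
    if os.path.exists(MODEL_READY):
        os.remove(MODEL_READY)
```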
megatron/arguments.py

```diff
@@ -348,6 +348,8 @@ def _add_data_args(parser):
                        help='Path to pickled data structure for efficient block indexing')
     group.add_argument('--block-top-k', type=int, default=5,
                        help='Number of blocks to use as top-k during retrieval')
+    group.add_argument('--async-indexer', action='store_true',
+                       help='Whether the indexer job is running asynchronously with a trainer job')
     group.add_argument('--split', type=str, default='969, 30, 1',
                        help='Comma-separated list of proportions for training,'
                             ' validation, and test split. For example the split '
```
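With `action='store_true'`, the new `--async-indexer` flag defaults to False, so existing launch scripts keep the old synchronous behaviour unless they opt in. A minimal standalone illustration of the same pattern:

```python
# Standalone sketch of the store_true flag pattern added above.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--async-indexer', action='store_true',
                    help='Whether the indexer job is running asynchronously with a trainer job')

print(parser.parse_args([]).async_indexer)                   # False by default
print(parser.parse_args(['--async-indexer']).async_indexer)  # True when passed
```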
megatron/data/realm_dataset_utils.py

```diff
@@ -93,6 +93,8 @@ def salient_span_mask(tokens, mask_id):
     Note: Tokens here are vocab ids and not text tokens."""
     tokenizer = get_tokenizer()
     tokens_str = join_str_list(tokenizer.tokenizer.convert_ids_to_tokens(tokens))
+    print("-" * 100)
+    print("TOKEN STR\n", tokens_str)

     # need to get all named entities
     entities = SPACY_NER(tokens_str).ents
@@ -101,6 +103,7 @@ def salient_span_mask(tokens, mask_id):
         return None

     entity_idx = np.random.randint(0, len(entities))
     selected_entity = entities[entity_idx]
+    print("SELECTED ENTITY\n", selected_entity.text)
     token_pos_map = id_to_str_pos_map(tokens, tokenizer)
     mask_start = mask_end = 0
@@ -118,6 +121,7 @@ def salient_span_mask(tokens, mask_id):
     for id_idx in masked_positions:
         labels.append(tokens[id_idx])
         output_tokens[id_idx] = mask_id
+    print("OUTPUT\n", join_str_list(tokenizer.tokenizer.convert_ids_to_tokens(output_tokens)))
     return output_tokens, masked_positions, labels
```
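`salient_span_mask` implements REALM-style salient span masking: run NER over the detokenized text, pick one named entity uniformly at random, and replace the tokens covering it with the mask id. Below is a self-contained string-level sketch of the idea using spaCy; the model name `en_core_web_sm` and the whitespace tokenization are simplifying assumptions, whereas the real code maps character spans back to BERT vocab ids via `id_to_str_pos_map`.

```python
# String-level sketch of salient span masking. Assumes the spaCy model
# en_core_web_sm is installed; whitespace tokenization is a simplification
# of the vocab-id bookkeeping in the real function.
import numpy as np
import spacy

nlp = spacy.load('en_core_web_sm')

def salient_span_mask_str(text, mask_token='[MASK]'):
    entities = nlp(text).ents
    if len(entities) == 0:
        return None  # mirrors the `return None` path in the diff
    entity = entities[np.random.randint(0, len(entities))]

    out, pos = [], 0
    for tok in text.split():
        start, end = pos, pos + len(tok)
        # Mask every token whose char span overlaps the entity's span.
        overlaps = start < entity.end_char and end > entity.start_char
        out.append(mask_token if overlaps else tok)
        pos = end + 1  # account for the single space separator
    return ' '.join(out)

print(salient_span_mask_str("Queen Elizabeth II lived in London."))
```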
megatron/model/realm_model.py

```diff
@@ -192,6 +192,8 @@ class REALMRetriever(MegatronModule):
         with torch.no_grad():
             if hasattr(self.ict_model, 'module'):
                 true_model = self.ict_model.module
+                if hasattr(true_model, 'module'):
+                    true_model = true_model.module
             else:
                 true_model = self.ict_model

             query_embeds = detach(true_model.embed_query(query_tokens, query_pad_mask))
```
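The extra `hasattr(true_model, 'module')` check handles a model that is wrapped twice, e.g. an FP16 wrapper plus `torchDDP`, each of which stores the underlying model under `.module`. A generic sketch of the unwrapping loop (the two-level wrapping below is simulated for illustration):

```python
# Sketch: peel off arbitrarily nested wrapper modules (e.g. DDP, FP16)
# that each expose the wrapped model as `.module`.
import torch

def unwrap_model(model: torch.nn.Module) -> torch.nn.Module:
    while hasattr(model, 'module'):
        model = model.module
    return model

inner = torch.nn.Linear(4, 4)
wrapped = torch.nn.Module()
wrapped.module = torch.nn.Module()
wrapped.module.module = inner        # simulate two levels of wrapping
assert unwrap_model(wrapped) is inner
```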
pretrain_realm.py

```diff
@@ -87,8 +87,9 @@ def forward_step(data_iterator, model):
     timers('batch generator').stop()

     # Forward model.
-    # TODO: MAKE SURE PAD IS NOT 1 - PAD
     lm_logits, block_probs = model(tokens, pad_mask, query_block_indices)
+    with torch.no_grad():
+        retrieval_utility = get_retrieval_utility(lm_logits, labels, loss_mask)

     # P(y|x) = sum_z(P(y|z, x) * P(z|x))
     block_probs = block_probs.unsqueeze(2).unsqueeze(3).expand_as(lm_logits)
@@ -99,9 +100,32 @@ def forward_step(data_iterator, model):
     lm_loss = torch.sum(lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()

-    reduced_loss = reduce_losses([lm_loss])
+    reduced_loss = reduce_losses([lm_loss, retrieval_utility])
     torch.cuda.synchronize()

-    return lm_loss, {'lm_loss': reduced_loss[0]}
+    return lm_loss, {'lm_loss': reduced_loss[0], 'retrieval_utility': reduced_loss[1]}


+def get_retrieval_utility(lm_logits, labels, loss_mask):
+    """log P(y | z, x) - log P(y | null, x)"""
+    # [batch x seq_len x vocab_size]
+    null_block_lm_logits = lm_logits[:, -1, :, :]
+    null_block_loss_ = mpu.vocab_parallel_cross_entropy(null_block_lm_logits.contiguous().float(),
+                                                        labels.contiguous())
+    null_block_loss = torch.sum(null_block_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
+
+    retrieved_block_losses = []
+    for block_num in range(lm_logits.shape[1] - 1):
+        retrieved_block_lm_logits = lm_logits[:, block_num, :, :]
+        retrieved_block_loss_ = mpu.vocab_parallel_cross_entropy(retrieved_block_lm_logits.contiguous().float(),
+                                                                 labels.contiguous())
+        retrieved_block_loss = torch.sum(retrieved_block_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
+        retrieved_block_losses.append(retrieved_block_loss)
+    avg_retrieved_block_loss = torch.sum(retrieved_block_losses) / (lm_logits.shape[1] - 1)
+
+    retrieval_utility = null_block_loss - avg_retrieved_block_loss
+    return retrieval_utility


 def qa_forward_step(data_iterator, model):
```
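`get_retrieval_utility` measures how much the retrieved evidence actually helps the masked-LM objective: it is the cross-entropy under the null (no evidence) block minus the average cross-entropy across the retrieved blocks, i.e. log P(y|z, x) - log P(y|null, x) averaged over z, so positive values mean retrieval is paying off. Below is a self-contained sketch of the same metric with plain `torch.nn.functional.cross_entropy` standing in for Megatron's model-parallel `mpu.vocab_parallel_cross_entropy`; shapes and the convention that the last block slot is the null block follow the diff, while the random inputs are illustrative.

```python
# Sketch of the retrieval-utility metric using plain PyTorch.
# lm_logits: [batch, num_blocks, seq_len, vocab]; the last block slot
# is the "null" (no evidence) block, as in the diff above.
import torch
import torch.nn.functional as F

def masked_ce(logits, labels, loss_mask):
    # Token-level cross entropy, averaged over unmasked positions.
    ce = F.cross_entropy(logits.reshape(-1, logits.shape[-1]),
                         labels.reshape(-1), reduction='none')
    return torch.sum(ce * loss_mask.reshape(-1)) / loss_mask.sum()

def retrieval_utility(lm_logits, labels, loss_mask):
    null_block_loss = masked_ce(lm_logits[:, -1], labels, loss_mask)
    block_losses = [masked_ce(lm_logits[:, b], labels, loss_mask)
                    for b in range(lm_logits.shape[1] - 1)]
    # Stack the per-block scalar losses into one tensor before reducing.
    avg_block_loss = torch.stack(block_losses).mean()
    return null_block_loss - avg_block_loss

batch, blocks, seq_len, vocab = 2, 5, 8, 50
lm_logits = torch.randn(batch, blocks, seq_len, vocab)
labels = torch.randint(0, vocab, (batch, seq_len))
loss_mask = torch.ones(batch, seq_len)
print(retrieval_utility(lm_logits, labels, loss_mask))
```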