OpenDAS / Megatron-LM / Commits

Commit 2f7d666c authored May 14, 2020 by Neel Kant
Parent: 9b9b8e01

    Add retrieval utility and autoresume for indexer
Showing 5 changed files with 53 additions and 14 deletions:
    indexer.py                              +18  -11
    megatron/arguments.py                    +2   -0
    megatron/data/realm_dataset_utils.py     +4   -0
    megatron/model/realm_model.py            +2   -0
    pretrain_realm.py                       +27   -3
indexer.py

 import os
+import sys
 import time
 import torch
 from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

-from megatron import get_args
+from megatron import get_args, get_adlr_autoresume, print_rank_0
 from megatron import mpu
 from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
 from megatron.data.bert_dataset import get_indexed_dataset_
 from megatron.data.realm_dataset import ICTDataset
-from megatron.data.realm_index import detach, BlockData, RandProjectionLSHIndex
+from megatron.data.realm_index import detach, BlockData, FaissMIPSIndex
 from megatron.data.samplers import DistributedBatchSampler
 from megatron.initialize import initialize_megatron
 from megatron.model import REALMRetriever
 from megatron.training import get_model
+from megatron.utils import check_adlr_autoresume_termination
 from pretrain_bert_ict import get_batch, model_provider
 from indexer_utils import set_index_com_file_ready, set_model_com_file_not_ready, check_model_com_file_ready
@@ -40,14 +42,14 @@ def test_retriever():
     initialize_megatron(extra_args_provider=None,
                         args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
     args = get_args()
-    model = load_ict_checkpoint(only_block_model=True)
+    model = load_ict_checkpoint()
     model.eval()
     dataset = get_ict_dataset()
     block_data = BlockData.load_from_file(args.block_data_path)
     mips_index = FaissMIPSIndex('flat_ip', 128)
     mips_index.add_block_embed_data(block_data)
-    retriever = REALMRetriever(model, dataset, mips_index, top_k=5)
+    retriever = REALMRetriever(model, dataset, block_data, mips_index, top_k=5)

     strs = [
         "The last monarch from the house of windsor",
@@ -71,7 +73,6 @@ def main():
     dataset = get_ict_dataset()
     data_iter = iter(get_one_epoch_dataloader(dataset))
     all_block_data = BlockData()
-    hashed_index = RandProjectionLSHIndex(embed_size=128, num_buckets=32, whiten=True)
     i = 1
     total = 0
@@ -103,18 +104,24 @@ def main():
     if args.rank == 0:
         all_block_data.consolidate_shards_and_save()
-        hashed_index.hash_whitened_block_embeds(all_block_data)
-        hashed_index.save_to_file()
     else:
         all_block_data.clear()

     ran_once = True
     set_index_com_file_ready()
     torch.distributed.barrier()
-    while not check_model_com_file_ready():
-        time.sleep(5)
-    set_model_com_file_not_ready()
+    if args.async_indexer:
+        while not check_model_com_file_ready():
+            time.sleep(5)
+
+        autoresume = get_adlr_autoresume()
+        if autoresume.termination_requested():
+            print_rank_0(">>> autoresume termination request found!")
+            if torch.distributed.get_rank() == 0:
+                autoresume.request_resume()
+            print_rank_0(">>> training terminated. Returning")
+            sys.exit(0)
+
+        set_model_com_file_not_ready()


 def load_ict_checkpoint(only_query_model=False, only_block_model=False, no_grad=False, from_realm_chkpt=False):
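With --async-indexer set, main() now publishes the fresh block index, signals the trainer through a communication file, waits for the trainer to signal back that a new checkpoint is available, and consults the ADLR autoresume service so a preempted indexer job can requeue itself and exit cleanly. The indexer_utils module itself is not part of this diff, so the sketch below is only a guess at what a minimal flag-file implementation of those helpers could look like; the file names and semantics are assumptions, not the repository's actual code.

    # Hypothetical sketch of indexer_utils: plain flag files on shared storage.
    import os

    INDEX_READY_FILE = "index_com_ready"   # assumed name of the indexer-to-trainer flag
    MODEL_READY_FILE = "model_com_ready"   # assumed name of the trainer-to-indexer flag


    def set_index_com_file_ready(path=INDEX_READY_FILE):
        """Tell the trainer that a fresh block index has been written."""
        open(path, "w").close()


    def check_model_com_file_ready(path=MODEL_READY_FILE):
        """True once the trainer has published a new model checkpoint."""
        return os.path.exists(path)


    def set_model_com_file_not_ready(path=MODEL_READY_FILE):
        """Consume the trainer's signal so the next wait starts from a clean state."""
        if os.path.exists(path):
            os.remove(path)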
megatron/arguments.py

@@ -348,6 +348,8 @@ def _add_data_args(parser):
                        help='Path to pickled data structure for efficient block indexing')
     group.add_argument('--block-top-k', type=int, default=5,
                        help='Number of blocks to use as top-k during retrieval')
+    group.add_argument('--async-indexer', action='store_true',
+                       help='Whether the indexer job is running asynchronously with a trainer job')
     group.add_argument('--split', type=str, default='969, 30, 1',
                        help='Comma-separated list of proportions for training,'
                        ' validation, and test split. For example the split '
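The new --async-indexer switch uses action='store_true', so args.async_indexer defaults to False and only an explicit flag turns on the handshake path in indexer.py. A tiny standalone argparse check (not from the repository) illustrates the behaviour:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--async-indexer', action='store_true',
                        help='Whether the indexer job is running asynchronously with a trainer job')

    print(parser.parse_args([]).async_indexer)                   # False
    print(parser.parse_args(['--async-indexer']).async_indexer)  # True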
megatron/data/realm_dataset_utils.py

@@ -93,6 +93,8 @@ def salient_span_mask(tokens, mask_id):
     Note: Tokens here are vocab ids and not text tokens."""
     tokenizer = get_tokenizer()
     tokens_str = join_str_list(tokenizer.tokenizer.convert_ids_to_tokens(tokens))
+    print("-" * 100)
+    print("TOKEN STR\n", tokens_str)

     # need to get all named entities
     entities = SPACY_NER(tokens_str).ents
@@ -101,6 +103,7 @@ def salient_span_mask(tokens, mask_id):
         return None
     entity_idx = np.random.randint(0, len(entities))
     selected_entity = entities[entity_idx]
+    print("SELECTED ENTITY\n", selected_entity.text)
     token_pos_map = id_to_str_pos_map(tokens, tokenizer)
     mask_start = mask_end = 0
@@ -118,6 +121,7 @@ def salient_span_mask(tokens, mask_id):
     for id_idx in masked_positions:
         labels.append(tokens[id_idx])
         output_tokens[id_idx] = mask_id
+    print("OUTPUT\n", join_str_list(tokenizer.tokenizer.convert_ids_to_tokens(output_tokens)))

     return output_tokens, masked_positions, labels
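The added print statements trace salient span masking: the token ids are detokenized, spaCy NER proposes entity spans, one entity is chosen at random, and every token it covers is replaced by the mask id. The repository helpers join_str_list and id_to_str_pos_map are not shown in this commit, so the following is only a simplified word-level illustration under the assumption that spaCy's en_core_web_sm model is installed; it is not the function from this file.

    # Hypothetical, word-level illustration of salient span masking.
    import numpy as np
    import spacy

    SPACY_NER = spacy.load("en_core_web_sm")


    def salient_span_mask_words(words, mask_token="[MASK]"):
        """Mask every word covered by one randomly chosen named entity."""
        text = " ".join(words)
        entities = SPACY_NER(text).ents
        if len(entities) == 0:
            return None
        entity = entities[np.random.randint(0, len(entities))]

        masked = list(words)
        start = 0
        for i, word in enumerate(words):
            end = start + len(word)
            # Mask the word if its character span overlaps the entity's span.
            if start < entity.end_char and end > entity.start_char:
                masked[i] = mask_token
            start = end + 1  # skip the joining space
        return masked


    print(salient_span_mask_words("Queen Elizabeth II lived in London".split()))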
megatron/model/realm_model.py

@@ -192,6 +192,8 @@ class REALMRetriever(MegatronModule):
         with torch.no_grad():
             if hasattr(self.ict_model, 'module'):
                 true_model = self.ict_model.module
+                if hasattr(true_model, 'module'):
+                    true_model = true_model.module
             else:
                 true_model = self.ict_model
             query_embeds = detach(true_model.embed_query(query_tokens, query_pad_mask))
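The extra hasattr check handles a doubly wrapped ict_model (for example a DistributedDataParallel wrapper around an FP16 wrapper), descending two levels of .module to reach the module that actually defines embed_query. A generic version of the same idea, shown only as an illustrative sketch rather than code from this commit, loops until no wrapper remains:

    def unwrap_module(model):
        """Peel off nested wrappers (e.g. torchDDP, FP16 modules) that expose .module."""
        while hasattr(model, 'module'):
            model = model.module
        return model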
pretrain_realm.py

@@ -87,8 +87,9 @@ def forward_step(data_iterator, model):
     timers('batch generator').stop()

     # Forward model.
-    # TODO: MAKE SURE PAD IS NOT 1 - PAD
     lm_logits, block_probs = model(tokens, pad_mask, query_block_indices)
+    with torch.no_grad():
+        retrieval_utility = get_retrieval_utility(lm_logits, labels, loss_mask)

     # P(y|x) = sum_z(P(y|z, x) * P(z|x))
     block_probs = block_probs.unsqueeze(2).unsqueeze(3).expand_as(lm_logits)
@@ -99,9 +100,32 @@ def forward_step(data_iterator, model):
     lm_loss = torch.sum(
         lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()

-    reduced_loss = reduce_losses([lm_loss])
+    reduced_loss = reduce_losses([lm_loss, retrieval_utility])
     torch.cuda.synchronize()
-    return lm_loss, {'lm_loss': reduced_loss[0]}
+    return lm_loss, {'lm_loss': reduced_loss[0], 'retrieval_utility': reduced_loss[1]}
+
+
+def get_retrieval_utility(lm_logits, labels, loss_mask):
+    """log P(y | z, x) - log P(y | null, x)"""
+    # [batch x seq_len x vocab_size]
+    null_block_lm_logits = lm_logits[:, -1, :, :]
+    null_block_loss_ = mpu.vocab_parallel_cross_entropy(null_block_lm_logits.contiguous().float(),
+                                                        labels.contiguous())
+    null_block_loss = torch.sum(null_block_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
+
+    retrieved_block_losses = []
+    for block_num in range(lm_logits.shape[1] - 1):
+        retrieved_block_lm_logits = lm_logits[:, block_num, :, :]
+        retrieved_block_loss_ = mpu.vocab_parallel_cross_entropy(retrieved_block_lm_logits.contiguous().float(),
+                                                                 labels.contiguous())
+        retrieved_block_loss = torch.sum(retrieved_block_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
+        retrieved_block_losses.append(retrieved_block_loss)
+    avg_retrieved_block_loss = torch.sum(retrieved_block_losses) / (lm_logits.shape[1] - 1)
+
+    retrieval_utility = null_block_loss - avg_retrieved_block_loss
+    return retrieval_utility
+
+
 def qa_forward_step(data_iterator, model):
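get_retrieval_utility implements the REALM-style diagnostic log P(y | z, x) - log P(y | null, x): because the masked-LM loss is a negative log-likelihood, subtracting the average loss over the retrieved blocks from the loss of the null block measures how much the retrieved evidence improves the prediction, with positive values meaning retrieval helped. Below is a minimal single-GPU sketch of the same computation that uses torch.nn.functional.cross_entropy in place of mpu.vocab_parallel_cross_entropy, an assumption that only holds without tensor model parallelism; it is an illustration, not the commit's code, and it stacks the per-block losses into a tensor before averaging.

    import torch
    import torch.nn.functional as F


    def retrieval_utility(lm_logits, labels, loss_mask):
        """Mean over retrieved blocks z of log P(y | z, x) - log P(y | null, x).

        lm_logits: [batch, num_blocks, seq_len, vocab]; the last block is the null block.
        labels:    [batch, seq_len] token ids.  loss_mask: [batch, seq_len] 0/1 floats.
        """
        def masked_loss(logits):
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)).float(),
                                   labels.reshape(-1), reduction='none')
            return torch.sum(loss * loss_mask.reshape(-1)) / loss_mask.sum()

        null_block_loss = masked_loss(lm_logits[:, -1])
        retrieved_losses = torch.stack([masked_loss(lm_logits[:, k])
                                        for k in range(lm_logits.shape[1] - 1)])
        # Loss is a negative log-likelihood, so a positive difference means the
        # retrieved blocks made the masked tokens easier to predict than the null block.
        return null_block_loss - retrieved_losses.mean()


    # Toy check with random tensors: batch=2, 3 retrieved blocks + 1 null, seq=8, vocab=100.
    logits = torch.randn(2, 4, 8, 100)
    labels = torch.randint(0, 100, (2, 8))
    mask = torch.ones(2, 8)
    print(retrieval_utility(logits, labels, mask))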