OpenDAS / Megatron-LM · Commits

Commit c150c48b
Debug and run salient span masking

Authored May 12, 2020 by Neel Kant
Parent: 451d1325
Showing 6 changed files with 29 additions and 60 deletions:

    indexer.py                             +2  -39
    megatron/data/dataset_utils.py         +2   -1
    megatron/data/realm_dataset_utils.py  +17   -9
    megatron/data/realm_index.py           +7   -9
    megatron/training.py                   +1   -1
    pretrain_realm.py                      +0   -1
indexer.py (+2 -39)

```diff
@@ -15,6 +15,7 @@ from megatron.initialize import initialize_megatron
 from megatron.model import REALMRetriever
 from megatron.training import get_model
 from pretrain_bert_ict import get_batch, model_provider
+from indexer_utils import set_index_com_file_ready, set_model_com_file_not_ready, check_model_com_file_ready
 
 # TODO re: main()
```
```diff
@@ -115,45 +116,6 @@ def main():
     set_model_com_file_not_ready()
 
 
-INDEX_COM_FILE = 'ready.index'
-MODEL_COM_FILE = 'ready.model'
-
-
-def set_index_com_file_not_ready():
-    with open(INDEX_COM_FILE, 'w') as com_file:
-        com_file.write('0')
-
-
-def set_index_com_file_ready():
-    with open(INDEX_COM_FILE, 'w') as com_file:
-        com_file.write('1')
-
-
-def check_index_com_file_ready():
-    if not os.path.exists(INDEX_COM_FILE):
-        set_index_com_file_not_ready()
-    with open(INDEX_COM_FILE, 'r') as com_file:
-        return bool(com_file.readline())
-
-
-def set_model_com_file_not_ready():
-    with open(MODEL_COM_FILE, 'w') as com_file:
-        com_file.write('0')
-
-
-def set_model_com_file_ready():
-    with open(MODEL_COM_FILE, 'w') as com_file:
-        com_file.write('1')
-
-
-def check_model_com_file_ready():
-    if not os.path.exists(MODEL_COM_FILE):
-        set_index_com_file_not_ready()
-    with open(MODEL_COM_FILE, 'r') as com_file:
-        return bool(com_file.readline())
-
-
 def load_ict_checkpoint(only_query_model=False, only_block_model=False, no_grad=False, from_realm_chkpt=False):
     args = get_args()
```
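The block removed above is the file-based readiness handshake between the training and indexing processes; per the import added in the first hunk and the matching change in megatron/training.py below, these helpers now live in indexer_utils. One caveat if you reuse the pattern: `bool(com_file.readline())` is truthy for '0' as well as '1', since any non-empty string is truthy in Python, so the check functions never report "not ready" once the file exists. A minimal sketch of a check that distinguishes the two states (the body is illustrative, not the repo's implementation):

```python
import os

INDEX_COM_FILE = 'ready.index'

def check_index_com_file_ready():
    # A missing file means the indexer has not written a status yet.
    if not os.path.exists(INDEX_COM_FILE):
        return False
    with open(INDEX_COM_FILE, 'r') as com_file:
        # Compare the text itself: '1' -> ready, '0' (or junk) -> not ready.
        return com_file.readline().strip() == '1'
```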
```diff
@@ -210,6 +172,7 @@ def get_ict_dataset(use_titles=True):
         max_seq_length=288,  # doesn't matter
         short_seq_prob=0.0001,  # doesn't matter
         seed=1,
+        query_in_block_prob=1,
         use_titles=use_titles
     )
     dataset = ICTDataset(**kwargs)
```
megatron/data/dataset_utils.py (+2 -1)

```diff
@@ -375,6 +375,7 @@ def create_masked_lm_predictions(tokens,
     for p in masked_lms:
         masked_lm_positions.append(p.index)
         masked_lm_labels.append(p.label)
+
     return (output_tokens, masked_lm_positions, masked_lm_labels, token_boundary)
```
```diff
@@ -387,7 +388,7 @@ def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
     padding_length = max_seq_length - num_tokens
     assert padding_length >= 0
     assert len(tokentypes) == num_tokens
-    assert len(masked_positions) == len(masked_labels)
+    assert len(masked_positions) == len(masked_labels), (len(masked_positions), len(masked_labels))
 
     # Tokens and token types.
     filler = [pad_id] * padding_length
```
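With the payload added to this assert, a length mismatch now reports both operands instead of a bare AssertionError, e.g.:

```python
masked_positions = [3, 7]
masked_labels = [2023]
# The tuple payload is printed when the condition fails:
assert len(masked_positions) == len(masked_labels), \
    (len(masked_positions), len(masked_labels))
# AssertionError: (2, 1)
```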
megatron/data/realm_dataset_utils.py (+17 -9)

```diff
@@ -25,14 +25,13 @@ def build_realm_training_sample(sample, max_seq_length,
     except TypeError:
         # this means the above returned None, and None isn't iterable.
-        # TODO: consider coding style.
         print("No salient span found.", flush=True)
         max_predictions_per_seq = masked_lm_prob * max_seq_length
         masked_tokens, masked_positions, masked_labels, _ = create_masked_lm_predictions(
             tokens, vocab_id_list, vocab_id_to_token_dict, masked_lm_prob,
             cls_id, sep_id, mask_id, max_predictions_per_seq, np_rng)
 
     tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np \
-        = pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
+        = pad_and_convert_to_numpy(masked_tokens, tokentypes, masked_positions,
                                    masked_labels, pad_id, max_seq_length)
 
     train_sample = {
```
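The substantive fix in this hunk is passing `masked_tokens` rather than the unmasked `tokens` into `pad_and_convert_to_numpy`. A toy illustration of why that matters, with bert-base-uncased vocabulary ids assumed:

```python
# [CLS] hello world [SEP] in bert-base-uncased ids (assumed for illustration)
tokens        = [101, 7592, 2088, 102]
masked_tokens = [101,  103, 2088, 102]   # position 1 replaced by [MASK] (103)
masked_positions = [1]
masked_labels    = [7592]                # the model must recover "hello"

# Feeding `tokens` instead of `masked_tokens` would give the network an
# input that already contains the label at every masked position, so the
# masked-LM loss could be minimized by copying the input.
```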
```diff
@@ -84,7 +83,7 @@ def id_to_str_pos_map(token_ids, tokenizer):
     # make sure total size is correct
     offset = -2 if token_strs[-1].startswith("##") else 0
     total_len = pos_map[-1] + len(token_strs[-1]) + offset
-    assert total_len == len(join_str_list(token_strs))
+    assert total_len == len(join_str_list(token_strs)) - 1, (total_len, len(join_str_list(token_strs)))
 
     return pos_map
```
```diff
@@ -93,25 +92,34 @@ def salient_span_mask(tokens, mask_id):
     """Creates the predictions for the masked LM objective.
     Note: Tokens here are vocab ids and not text tokens."""
     tokenizer = get_tokenizer()
-    tokens_str = join_str_list(tokenizer.tokenize(tokens))
+    tokens_str = join_str_list(tokenizer.tokenizer.convert_ids_to_tokens(tokens))
 
     # need to get all named entities
     entities = SPACY_NER(tokens_str).ents
+    entities = [e for e in entities if e.text != "CLS"]
     if len(entities) == 0:
         return None
-    entity_idx = np.random.randint(0, len(entities))
-    selected_entity = entities[entity_idx]
+    selected_entity = np.random.choice(entities)
 
     token_pos_map = id_to_str_pos_map(tokens, tokenizer)
-    mask_start = mask_end = token_pos_map.index(selected_entity.start_char)
+    mask_start = mask_end = 0
+    set_mask_start = False
     while mask_end < len(token_pos_map) and token_pos_map[mask_end] < selected_entity.end_char:
+        if token_pos_map[mask_start] > selected_entity.start_char:
+            set_mask_start = True
+        if not set_mask_start:
+            mask_start += 1
         mask_end += 1
+    masked_positions = list(range(mask_start, mask_end + 1))
 
-    labels = tokens.copy()
+    labels = []
     output_tokens = tokens.copy()
-    for id_idx in range(mask_start, mask_end):
+    for id_idx in masked_positions:
+        labels.append(tokens[id_idx])
         output_tokens[id_idx] = mask_id
 
-    return output_tokens, list(range(mask_start, mask_end)), labels
+    return output_tokens, masked_positions, labels
 
 
 def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epochs,
```
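The rewritten loop turns the selected entity's character span into a token-index range by walking `token_pos_map`, the per-token character offsets produced by `id_to_str_pos_map`. A minimal sketch of that span-to-token mapping under assumed inputs (the committed loop walks linearly and its boundary handling around subword tokens may differ):

```python
import bisect

# Start offset of each token in the detokenized string "the quick brown fox".
token_pos_map = [0, 4, 10, 16]
start_char, end_char = 4, 15          # spaCy entity span for "quick brown"

# First token starting at or after the entity's first character...
first = bisect.bisect_left(token_pos_map, start_char)
# ...and the first token starting at or past its end bounds the range.
last = bisect.bisect_left(token_pos_map, end_char)

masked_positions = list(range(first, last))   # -> [1, 2]
```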
megatron/data/realm_index.py (+7 -9)

```diff
@@ -108,12 +108,8 @@ class FaissMIPSIndex(object):
         if self.index_type not in INDEX_TYPES:
             raise ValueError("Invalid index type specified")
 
-        if self.index_type == 'flat_l2':
-            index = faiss.IndexFlatL2(self.embed_size + 2 * self.m)
-            self.block_mips_index = faiss.IndexIDMap(index)
-        elif self.index_type == 'flat_ip':
-            index = faiss.IndexFlatIP(self.embed_size)
-            self.block_mips_index = faiss.IndexIDMap(index)
+        index = faiss.index_factory(self.embed_size, 'Flat', faiss.METRIC_INNER_PRODUCT)
+        self.block_mips_index = faiss.IndexIDMap(index)
 
     def reset_index(self):
         self._set_block_index()
```
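The constructor now goes through `faiss.index_factory` with `METRIC_INNER_PRODUCT`, matching the maximum-inner-product search this class performs, and the float32 casts added in the next two hunks reflect that FAISS only accepts float32 embeddings (and int64 ids for `IndexIDMap`). A self-contained sketch of the same construction, with made-up sizes:

```python
import faiss
import numpy as np

embed_size = 128

# Flat (exact) index scored by inner product, wrapped so we can attach ids.
index = faiss.index_factory(embed_size, 'Flat', faiss.METRIC_INNER_PRODUCT)
index = faiss.IndexIDMap(index)

block_embeds = np.float32(np.random.randn(1000, embed_size))
block_ids = np.arange(1000, dtype=np.int64)
index.add_with_ids(block_embeds, block_ids)

query_embeds = np.float32(np.random.randn(4, embed_size))
scores, top_ids = index.search(query_embeds, 5)   # top-5 inner products per query
```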
```diff
@@ -126,7 +122,7 @@ class FaissMIPSIndex(object):
         if self.index_type == 'flat_l2':
             block_embeds = self.alsh_block_preprocess_fn(block_embeds)
-        self.block_mips_index.add_with_ids(np.array(block_embeds), np.array(block_indices))
+        self.block_mips_index.add_with_ids(np.float32(np.array(block_embeds)), np.array(block_indices))
 
     def search_mips_index(self, query_embeds, top_k, reconstruct=True):
         """Get the top-k blocks by the index distance metric.
```
```diff
@@ -138,10 +134,10 @@ class FaissMIPSIndex(object):
             query_embeds = self.alsh_query_preprocess_fn(query_embeds)
 
         if reconstruct:
-            top_k_block_embeds = self.block_mips_index.search_and_reconstruct(query_embeds, top_k)
+            top_k_block_embeds = self.block_mips_index.search_and_reconstruct(query_embeds.astype('float32'), top_k)
             return top_k_block_embeds
         else:
-            distances, block_indices = self.block_mips_index.search(query_embeds, top_k)
+            distances, block_indices = self.block_mips_index.search(query_embeds.astype('float32'), top_k)
             return distances, block_indices
 
     def get_norm_powers_and_halves_array(self, embeds):
```
```diff
@@ -176,6 +172,8 @@ class FaissMIPSIndex(object):
         return np.float32(np.concatenate((query_embeds, halves_array, norm_powers), axis=1))
 
 
+# This was the original hashing scheme, not used anymore
 class RandProjectionLSHIndex(object):
     """Class for holding hashed data"""
 
     def __init__(self, embed_size, num_buckets, whiten=True, seed=0):
```
megatron/training.py (+1 -1)

```diff
@@ -39,7 +39,7 @@ from megatron.model import get_params_for_weight_decay_optimization
 from megatron.utils import check_adlr_autoresume_termination
 from megatron.utils import make_data_loader
 from megatron.utils import report_memory
-from indexer import check_index_com_file_ready, set_index_com_file_not_ready, set_model_com_file_ready
+from indexer_utils import check_index_com_file_ready, set_index_com_file_not_ready, set_model_com_file_ready
 
 
 def pretrain(train_valid_test_dataset_provider, model_provider,
```
pretrain_realm.py (+0 -1)

```diff
@@ -14,7 +14,6 @@
 # limitations under the License.
 
 """Pretrain BERT for Inverse Cloze Task"""
 
-import torch
 import torch.nn.functional as F
```