OpenDAS / Megatron-LM · Commits

Commit 0f0f60aa
Authored May 15, 2020 by Neel Kant
Parent: 2f7d666c

    Able to run REALM with terrible index sync

Showing 3 changed files with 19 additions and 12 deletions:

    megatron/data/realm_dataset_utils.py   +5  -5
    megatron/data/realm_index.py           +6  -3
    pretrain_realm.py                      +8  -4
megatron/data/realm_dataset_utils.py  (view file @ 0f0f60aa)

@@ -93,8 +93,6 @@ def salient_span_mask(tokens, mask_id):
     Note: Tokens here are vocab ids and not text tokens."""
     tokenizer = get_tokenizer()
     tokens_str = join_str_list(tokenizer.tokenizer.convert_ids_to_tokens(tokens))
-    print("-" * 100)
-    print("TOKEN STR\n", tokens_str)

     # need to get all named entities
     entities = SPACY_NER(tokens_str).ents
@@ -103,7 +101,6 @@ def salient_span_mask(tokens, mask_id):
         return None

     entity_idx = np.random.randint(0, len(entities))
     selected_entity = entities[entity_idx]
-    print("SELECTED ENTITY\n", selected_entity.text)
     token_pos_map = id_to_str_pos_map(tokens, tokenizer)
     mask_start = mask_end = 0
@@ -114,14 +111,17 @@ def salient_span_mask(tokens, mask_id):
         if not set_mask_start:
             mask_start += 1
         mask_end += 1

-    masked_positions = list(range(mask_start, mask_end + 1))
+    masked_positions = list(range(mask_start - 1, mask_end))
     labels = []
     output_tokens = tokens.copy()

     for id_idx in masked_positions:
         labels.append(tokens[id_idx])
         output_tokens[id_idx] = mask_id

-    print("OUTPUT\n", join_str_list(tokenizer.tokenizer.convert_ids_to_tokens(output_tokens)))
+    #print("-" * 100 + '\n',
+    #      "TOKEN STR\n", tokens_str + '\n',
+    #      "SELECTED ENTITY\n", selected_entity.text + '\n',
+    #      "OUTPUT\n", join_str_list(tokenizer.tokenizer.convert_ids_to_tokens(output_tokens)), flush=True)
     return output_tokens, masked_positions, labels
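The only functional change in this file is the masked_positions range, which shifts the masked window from range(mask_start, mask_end + 1) to range(mask_start - 1, mask_end), i.e. one position earlier. Below is a minimal, self-contained sketch of that masking step, assuming the span boundaries have already been derived; the helper name mask_span and the toy vocab ids are illustrative, not from the repo, and the real function maps a spaCy entity span to wordpiece positions via id_to_str_pos_map before reaching this point.

# Sketch of the span-masking step this hunk adjusts (hypothetical helper).
def mask_span(tokens, mask_id, mask_start, mask_end):
    """Replace the tokens covering the selected span with mask_id.

    tokens     : list of vocab ids
    mask_id    : id of the [MASK] token
    mask_start, mask_end : boundaries as accumulated by the loop in
        salient_span_mask; both advance past the span, hence the -1 shift
        when building the half-open range.
    """
    masked_positions = list(range(mask_start - 1, mask_end))
    labels = []
    output_tokens = tokens.copy()
    for idx in masked_positions:
        labels.append(tokens[idx])        # keep the original id as the label
        output_tokens[idx] = mask_id      # overwrite it with the mask token
    return output_tokens, masked_positions, labels

if __name__ == "__main__":
    toy_tokens = [101, 7592, 2088, 2003, 2307, 102]   # made-up vocab ids
    out, positions, labels = mask_span(toy_tokens, mask_id=103, mask_start=3, mask_end=4)
    print(out, positions, labels)   # masks positions 2 and 3, labels keep their original ids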
megatron/data/realm_index.py  (view file @ 0f0f60aa)

@@ -91,7 +91,7 @@ class FaissMIPSIndex(object):
             self._set_block_index()

     def _set_block_index(self):
-        INDEX_TYPES = ['flat_l2', 'flat_ip']
+        INDEX_TYPES = ['flat_ip']
         if self.index_type not in INDEX_TYPES:
             raise ValueError("Invalid index type specified")
@@ -123,14 +123,17 @@ class FaissMIPSIndex(object):
         """
         if self.index_type == 'flat_l2':
             query_embeds = self.alsh_query_preprocess_fn(query_embeds)
+        query_embeds = np.float32(query_embeds)
+
         if reconstruct:
-            top_k_block_embeds = self.block_mips_index.search_and_reconstruct(query_embeds.astype('float32'), top_k)
+            top_k_block_embeds = self.block_mips_index.search_and_reconstruct(query_embeds, top_k)
             return top_k_block_embeds
         else:
-            distances, block_indices = self.block_mips_index.search(query_embeds.astype('float32'), top_k)
+            distances, block_indices = self.block_mips_index.search(query_embeds, top_k)
             return distances, block_indices

+    # functions below are for ALSH, which currently isn't being used
     def get_norm_powers_and_halves_array(self, embeds):
         norm = np.linalg.norm(embeds, axis=1)
         norm_powers = [np.multiply(norm, norm)]  # squared L2 norms of all
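This change restricts the index to the flat inner-product ('flat_ip') type and casts the query embeddings to float32 once, instead of calling .astype('float32') at every search site. A small sketch of the underlying FAISS pattern the class wraps is below; it assumes faiss is installed, and the array names and sizes are stand-ins rather than the repo's actual embeddings.

# Sketch: flat inner-product (MIPS) FAISS index with a single float32 cast.
import faiss
import numpy as np

embed_dim = 128
block_embeds = np.random.rand(1000, embed_dim).astype('float32')   # "retrieval block" embeddings
query_embeds = np.random.rand(4, embed_dim)                        # may arrive as float64

index = faiss.IndexFlatIP(embed_dim)      # flat inner-product index, i.e. 'flat_ip'
index.add(block_embeds)

query_embeds = np.float32(query_embeds)   # cast once up front, as in the diff
distances, block_indices = index.search(query_embeds, 5)   # top-5 blocks per query
print(block_indices.shape)                # (4, 5)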
pretrain_realm.py  (view file @ 0f0f60aa)

@@ -89,7 +89,7 @@ def forward_step(data_iterator, model):
     # Forward model.
     lm_logits, block_probs = model(tokens, pad_mask, query_block_indices)
     with torch.no_grad():
-        retrieval_utility = get_retrieval_utility(lm_logits, labels, loss_mask)
+        retrieval_utility = get_retrieval_utility(lm_logits, block_probs, labels, loss_mask)

     # P(y|x) = sum_z(P(y|z, x) * P(z|x))
     block_probs = block_probs.unsqueeze(2).unsqueeze(3).expand_as(lm_logits)
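The comment in this hunk states the REALM marginalization, P(y|x) = sum_z P(y|z, x) * P(z|x), and the unsqueeze/expand_as call broadcasts P(z|x) over the sequence and vocabulary dimensions so it can multiply the per-block logits. A small stand-alone sketch of that broadcast and sum is below, assuming lm_logits has shape [batch, num_blocks, seq_len, vocab] and block_probs has shape [batch, num_blocks]; the tensors are random stand-ins, and the softmax used to turn logits into P(y|z, x) is my assumption, since the actual training loss is computed downstream.

# Sketch: marginalizing the per-block LM distribution over retrieved blocks.
import torch

batch, num_blocks, seq_len, vocab = 2, 4, 8, 16
lm_logits = torch.randn(batch, num_blocks, seq_len, vocab)
block_probs = torch.softmax(torch.randn(batch, num_blocks), dim=1)    # P(z | x)

# broadcast P(z|x) over seq_len and vocab, as in the diff
block_probs_exp = block_probs.unsqueeze(2).unsqueeze(3).expand_as(lm_logits)

# P(y|x) = sum_z P(y|z, x) * P(z|x), with P(y|z, x) from a per-block softmax (assumption)
per_block_probs = torch.softmax(lm_logits, dim=-1)
marginal_probs = torch.sum(per_block_probs * block_probs_exp, dim=1)   # [batch, seq_len, vocab]
print(marginal_probs.shape)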
@@ -105,9 +105,13 @@ def forward_step(data_iterator, model):
     return lm_loss, {'lm_loss': reduced_loss[0], 'retrieval_utility': reduced_loss[1]}


-def get_retrieval_utility(lm_logits, labels, loss_mask):
+def get_retrieval_utility(lm_logits, block_probs, labels, loss_mask):
     """log P(y | z, x) - log P(y | null, x)"""
     # [batch x seq_len x vocab_size]
     lm_logits = lm_logits[:, :, :labels.shape[1], :]
+    #non_null_block_probs = block_probs[:, :-1]
+    #non_null_block_probs /= torch.sum(non_null_block_probs, axis=1, keepdim=True)
+    # non_null_block_probs = non_null_block_probsexpand_as(lm_logits[:, :-1, :, :])

     null_block_lm_logits = lm_logits[:, -1, :, :]
     null_block_loss_ = mpu.vocab_parallel_cross_entropy(null_block_lm_logits.contiguous().float(), labels.contiguous())
@@ -119,10 +123,11 @@ def get_retrieval_utility(lm_logits, labels, loss_mask):
         retrieved_block_lm_logits = lm_logits[:, block_num, :, :]
         retrieved_block_loss_ = mpu.vocab_parallel_cross_entropy(retrieved_block_lm_logits.contiguous().float(), labels.contiguous())
+        #retrieved_block_loss_ *= non_null_block_probs[:, block_num].reshape(-1, 1)
         retrieved_block_loss = torch.sum(retrieved_block_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
         retrieved_block_losses.append(retrieved_block_loss)

-    avg_retrieved_block_loss = torch.sum(retrieved_block_losses) / (lm_logits.shape[1] - 1)
+    avg_retrieved_block_loss = torch.sum(torch.cuda.FloatTensor(retrieved_block_losses)) / (lm_logits.shape[1] - 1)
     retrieval_utility = null_block_loss - avg_retrieved_block_loss
     return retrieval_utility
@@ -171,6 +176,5 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
 if __name__ == "__main__":

     pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
              args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
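get_retrieval_utility reports the docstring's quantity, log P(y | z, x) - log P(y | null, x), as the null-block masked loss minus the average masked loss over the retrieved blocks; the fix in this commit wraps the Python list of per-block losses in torch.cuda.FloatTensor before summing. Below is a hedged, stand-alone sketch of the same computation: torch.nn.functional.cross_entropy stands in for mpu.vocab_parallel_cross_entropy and torch.stack stands in for the CUDA tensor wrapping, so it runs without model parallelism or a GPU, and all shapes and tensors are stand-ins.

# Sketch: retrieval utility = loss(null block) - mean loss(retrieved blocks).
import torch
import torch.nn.functional as F

def retrieval_utility(lm_logits, labels, loss_mask):
    # lm_logits: [batch, num_blocks, seq_len, vocab]; the last block is the null block
    batch, num_blocks, seq_len, vocab = lm_logits.shape

    def masked_loss(block_logits):
        # per-token cross entropy, masked and averaged over non-padding positions
        loss_ = F.cross_entropy(block_logits.reshape(-1, vocab), labels.reshape(-1),
                                reduction='none')
        return torch.sum(loss_ * loss_mask.reshape(-1)) / loss_mask.sum()

    null_block_loss = masked_loss(lm_logits[:, -1])
    retrieved_losses = [masked_loss(lm_logits[:, b]) for b in range(num_blocks - 1)]
    avg_retrieved_loss = torch.sum(torch.stack(retrieved_losses)) / (num_blocks - 1)
    # since loss = -log P, this equals log P(y|z, x) - log P(y|null, x) on average
    return null_block_loss - avg_retrieved_loss

lm_logits = torch.randn(2, 4, 8, 16)
labels = torch.randint(0, 16, (2, 8))
loss_mask = torch.ones(2, 8)
print(retrieval_utility(lm_logits, labels, loss_mask))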