OpenDAS / Megatron-LM · Commits

Commit f2094783
Authored May 06, 2020 by Neel Kant

Add REALMAnswerSpanModel and MLM features

Parent: c17d880c
5 changed files with 289 additions and 21 deletions (+289 / -21)
hashed_index.py                  +3   -2
megatron/data/realm_dataset.py   +178 -14
megatron/model/realm_model.py    +79  -1
pretrain_bert_ict.py             +3   -3
pretrain_realm.py                +26  -1
hashed_index.py

@@ -134,7 +134,7 @@ def load_ict_checkpoint(only_query_model=False, only_block_model=False, no_grad=
     return model
 
 
-def get_ict_dataset():
+def get_ict_dataset(use_titles=True):
     args = get_args()
     block_dataset = get_indexed_dataset_(args.data_path, 'mmap', True)
     titles_dataset = get_indexed_dataset_(args.titles_data_path, 'mmap', True)

@@ -148,7 +148,8 @@ def get_ict_dataset():
         max_num_samples=None,
         max_seq_length=288,  # doesn't matter
         short_seq_prob=0.0001,  # doesn't matter
-        seed=1
+        seed=1,
+        use_titles=use_titles
     )
     dataset = ICTDataset(**kwargs)
     return dataset
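The new use_titles flag is threaded through to ICTDataset so callers that only need raw block text, such as the block-embedding path in pretrain_realm.py below, can skip title concatenation. A minimal usage sketch (illustrative, not part of this diff; the sample's exact fields are defined by ICTDataset.__getitem__):

    # Build the ICT dataset without titles so each block is just "[CLS] block [SEP]".
    ict_dataset = get_ict_dataset(use_titles=False)
    sample = ict_dataset[0]  # block/query token ids and pad masks, per ICTDataset.__getitem__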
megatron/data/realm_dataset.py

@@ -10,9 +10,7 @@ from torch.utils.data import Dataset
 from megatron import get_tokenizer, print_rank_0, mpu
 from megatron.data.bert_dataset import BertDataset
-from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_convert_to_numpy
-
-#qa_nlp = spacy.load('en_core_web_lg')
+from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_convert_to_numpy, is_start_piece
 
 
 def build_simple_training_sample(sample, target_seq_length, max_seq_length,

@@ -40,6 +38,169 @@ def build_simple_training_sample(sample, target_seq_length, max_seq_length,
     return train_sample
 
 
+qa_nlp = spacy.load('en_core_web_lg')
+
+
+def salient_span_mask(tokens, vocab_id_list, vocab_id_to_token_dict,
+                      cls_id, sep_id, mask_id, np_rng, do_permutation=False):
+    """Creates the predictions for the masked LM objective.
+    Note: Tokens here are vocab ids and not text tokens."""
+    cand_indexes = []
+    # Note(mingdachen): We create a list for recording if the piece is
+    # the starting piece of current token, where 1 means true, so that
+    # on-the-fly whole word masking is possible.
+    token_boundary = [0] * len(tokens)
+
+    for (i, token) in enumerate(tokens):
+        if token == cls_id or token == sep_id:
+            token_boundary[i] = 1
+            continue
+        # Whole Word Masking means that if we mask all of the wordpieces
+        # corresponding to an original word.
+        #
+        # Note that Whole Word Masking does *not* change the training code
+        # at all -- we still predict each WordPiece independently, softmaxed
+        # over the entire vocabulary.
+        if len(cand_indexes) >= 1 and not is_start_piece(vocab_id_to_token_dict[token]):
+            cand_indexes[-1].append(i)
+        else:
+            cand_indexes.append([i])
+            if is_start_piece(vocab_id_to_token_dict[token]):
+                token_boundary[i] = 1
+
+    output_tokens = list(tokens)
+
+    masked_lm_positions = []
+    masked_lm_labels = []
+
+    ngram_indexes = []
+    for idx in range(len(cand_indexes)):
+        ngram_index = []
+        for n in ngrams:
+            ngram_index.append(cand_indexes[idx:idx + n])
+        ngram_indexes.append(ngram_index)
+
+    np_rng.shuffle(ngram_indexes)
+
+    masked_lms = []
+    covered_indexes = set()
+    for cand_index_set in ngram_indexes:
+        if len(masked_lms) >= num_to_predict:
+            break
+        if not cand_index_set:
+            continue
+        # Note(mingdachen):
+        # Skip current piece if they are covered in lm masking or previous ngrams.
+        for index_set in cand_index_set[0]:
+            for index in index_set:
+                if index in covered_indexes:
+                    continue
+
+        n = np_rng.choice(ngrams[:len(cand_index_set)],
+                          p=pvals[:len(cand_index_set)] /
+                          pvals[:len(cand_index_set)].sum(keepdims=True))
+        index_set = sum(cand_index_set[n - 1], [])
+        n -= 1
+        # Note(mingdachen):
+        # Repeatedly looking for a candidate that does not exceed the
+        # maximum number of predictions by trying shorter ngrams.
+        while len(masked_lms) + len(index_set) > num_to_predict:
+            if n == 0:
+                break
+            index_set = sum(cand_index_set[n - 1], [])
+            n -= 1
+        # If adding a whole-word mask would exceed the maximum number of
+        # predictions, then just skip this candidate.
+        if len(masked_lms) + len(index_set) > num_to_predict:
+            continue
+        is_any_index_covered = False
+        for index in index_set:
+            if index in covered_indexes:
+                is_any_index_covered = True
+                break
+        if is_any_index_covered:
+            continue
+        for index in index_set:
+            covered_indexes.add(index)
+
+            masked_token = None
+            # 80% of the time, replace with [MASK]
+            if np_rng.random() < 0.8:
+                masked_token = mask_id
+            else:
+                # 10% of the time, keep original
+                if np_rng.random() < 0.5:
+                    masked_token = tokens[index]
+                # 10% of the time, replace with random word
+                else:
+                    masked_token = vocab_id_list[np_rng.randint(0, len(vocab_id_list))]
+
+            output_tokens[index] = masked_token
+            masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
+    assert len(masked_lms) <= num_to_predict
+
+    np_rng.shuffle(ngram_indexes)
+
+    select_indexes = set()
+    if do_permutation:
+        for cand_index_set in ngram_indexes:
+            if len(select_indexes) >= num_to_predict:
+                break
+            if not cand_index_set:
+                continue
+            # Note(mingdachen):
+            # Skip current piece if they are covered in lm masking or previous ngrams.
+            for index_set in cand_index_set[0]:
+                for index in index_set:
+                    if index in covered_indexes or index in select_indexes:
+                        continue
+
+            n = np.random.choice(ngrams[:len(cand_index_set)],
+                                 p=pvals[:len(cand_index_set)] /
+                                 pvals[:len(cand_index_set)].sum(keepdims=True))
+            index_set = sum(cand_index_set[n - 1], [])
+            n -= 1
+
+            while len(select_indexes) + len(index_set) > num_to_predict:
+                if n == 0:
+                    break
+                index_set = sum(cand_index_set[n - 1], [])
+                n -= 1
+            # If adding a whole-word mask would exceed the maximum number of
+            # predictions, then just skip this candidate.
+            if len(select_indexes) + len(index_set) > num_to_predict:
+                continue
+            is_any_index_covered = False
+            for index in index_set:
+                if index in covered_indexes or index in select_indexes:
+                    is_any_index_covered = True
+                    break
+            if is_any_index_covered:
+                continue
+            for index in index_set:
+                select_indexes.add(index)
+        assert len(select_indexes) <= num_to_predict
+
+        select_indexes = sorted(select_indexes)
+        permute_indexes = list(select_indexes)
+        np_rng.shuffle(permute_indexes)
+        orig_token = list(output_tokens)
+
+        for src_i, tgt_i in zip(select_indexes, permute_indexes):
+            output_tokens[src_i] = orig_token[tgt_i]
+            masked_lms.append(MaskedLmInstance(index=src_i, label=orig_token[src_i]))
+
+    masked_lms = sorted(masked_lms, key=lambda x: x.index)
+
+    for p in masked_lms:
+        masked_lm_positions.append(p.index)
+        masked_lm_labels.append(p.label)
+    return (output_tokens, masked_lm_positions, masked_lm_labels, token_boundary)
+
+
 class REALMDataset(Dataset):
     """Dataset containing simple masked sentences for masked language modeling.
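As committed, salient_span_mask appears to be adapted from create_masked_lm_predictions in megatron/data/dataset_utils.py and still refers to ngrams, pvals, num_to_predict, and MaskedLmInstance, which are not defined in its body, while the module-level qa_nlp = spacy.load('en_core_web_lg') is loaded but not yet used in this hunk. For context, a minimal sketch of what spaCy-based salient span selection could look like (illustrative only, not part of this commit; the function name and the choice to mask all named entities are assumptions):

    import spacy

    qa_nlp = spacy.load('en_core_web_lg')

    def find_salient_spans(text):
        """Return (start_char, end_char) spans of named entities; REALM-style salient
        span masking would mask every wordpiece inside such a span."""
        doc = qa_nlp(text)
        return [(ent.start_char, ent.end_char) for ent in doc.ents]

    # find_salient_spans("Einstein was born in Ulm in 1879.") would return spans
    # covering "Einstein", "Ulm", and "1879".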
@@ -196,7 +357,7 @@ class ICTDataset(Dataset):
     """Dataset containing sentences and their blocks for an inverse cloze task."""
     def __init__(self, name, block_dataset, title_dataset, data_prefix,
                  num_epochs, max_num_samples, max_seq_length,
-                 short_seq_prob, seed):
+                 short_seq_prob, seed, use_titles=True):
         self.name = name
         self.seed = seed
         self.max_seq_length = max_seq_length

@@ -204,6 +365,7 @@ class ICTDataset(Dataset):
         self.title_dataset = title_dataset
         self.short_seq_prob = short_seq_prob
         self.rng = random.Random(self.seed)
+        self.use_titles = use_titles
         self.samples_mapping = self.get_samples_mapping(data_prefix, num_epochs, max_num_samples)

@@ -220,15 +382,16 @@ class ICTDataset(Dataset):
     def __getitem__(self, idx):
         start_idx, end_idx, doc_idx, block_idx = self.samples_mapping[idx]
-        title = list(self.title_dataset[int(doc_idx)])
+        if self.use_titles:
+            title = list(self.title_dataset[int(doc_idx)])
+            title_pad_offset = 3 + len(title)
+        else:
+            title = None
+            title_pad_offset = 2
         block = [list(self.block_dataset[i]) for i in range(start_idx, end_idx)]
         assert len(block) > 1
 
-        # avoid selecting the first or last sentence to be the query.
-        if len(block) == 2:
-            rand_sent_idx = int(self.rng.random() > 0.5)
-        else:
-            rand_sent_idx = self.rng.randint(1, len(block) - 2)
+        rand_sent_idx = self.rng.randint(0, len(block) - 1)
 
         # keep the query in the context 10% of the time.
         if self.rng.random() < 1:

@@ -239,7 +402,7 @@ class ICTDataset(Dataset):
         # still need to truncate because blocks are concluded when
         # the sentence lengths have exceeded max_seq_length.
         query = query[:self.max_seq_length - 2]
-        block = list(itertools.chain(*block))[:self.max_seq_length - (3 + len(title))]
+        block = list(itertools.chain(*block))[:self.max_seq_length - title_pad_offset]
 
         query_tokens, query_pad_mask = self.concat_and_pad_tokens(query)
         block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)

@@ -279,9 +442,10 @@ class ICTDataset(Dataset):
     def concat_and_pad_tokens(self, tokens, title=None):
         """concat with special tokens and pad sequence to self.max_seq_length"""
-        tokens = [self.cls_id] + tokens + [self.sep_id]
-        if title is not None:
-            tokens += title + [self.sep_id]
+        if title is None:
+            tokens = [self.cls_id] + tokens + [self.sep_id]
+        else:
+            tokens = [self.cls_id] + title + [self.sep_id] + tokens + [self.sep_id]
         assert len(tokens) <= self.max_seq_length, len(tokens)
 
         num_pad = self.max_seq_length - len(tokens)
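The concat_and_pad_tokens change moves the title ahead of the block text: with a title the sequence is now [CLS] title [SEP] tokens [SEP] (previously [CLS] tokens [SEP] title [SEP]), and without one it is [CLS] tokens [SEP]; the title_pad_offset values of 3 + len(title) and 2 in __getitem__ reserve exactly those special-token and title positions before truncation. A small worked example (illustrative token ids, not part of this commit):

    cls_id, sep_id = 101, 102          # BERT-style special token ids, for illustration
    title = [2054, 2003]               # two title wordpieces
    tokens = [7592, 2088, 999]         # three block wordpieces

    with_title = [cls_id] + title + [sep_id] + tokens + [sep_id]
    # [101, 2054, 2003, 102, 7592, 2088, 999, 102]  -> offset 3 + len(title) = 5

    without_title = [cls_id] + tokens + [sep_id]
    # [101, 7592, 2088, 999, 102]                   -> offset 2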
megatron/model/realm_model.py

@@ -2,12 +2,79 @@ import numpy as np
 import torch
 import torch.nn.functional as F
 
 from megatron import get_args
 from megatron.checkpointing import load_checkpoint
 from megatron.data.realm_index import detach
 from megatron.model import BertModel
 from megatron.model.utils import get_linear_layer, init_method_normal
 from megatron.module import MegatronModule
 
 
+class REALMAnswerSpanModel(MegatronModule):
+    def __init__(self, realm_model, mlp_hidden_size=64):
+        super(REALMAnswerSpanModel, self).__init__()
+        self.realm_model = realm_model
+        self.mlp_hidden_size = mlp_hidden_size
+
+        args = get_args()
+        init_method = init_method_normal(args.init_method_std)
+        self.fc1 = get_linear_layer(2 * args.hidden_size, self.mlp_hidden_size, init_method)
+        self._fc1_key = 'fc1'
+        self.fc2 = get_linear_layer(self.mlp_hidden_size, 1, init_method)
+        self._fc2_key = 'fc2'
+
+        max_length = 10
+        self.start_ends = []
+        for length in range(max_length):
+            self.start_ends.extend([(i, i + length) for i in range(288 - length)])
+
+    def forward(self, question_tokens, question_attention_mask, answer_tokens, answer_token_lengths):
+        lm_logits, block_probs, topk_block_tokens = self.realm_model(
+            question_tokens, question_attention_mask,
+            query_block_indices=None, return_topk_block_tokens=True)
+
+        batch_span_reps, batch_loss_masks = [], []
+        # go through batch one-by-one
+        for i in range(len(answer_token_lengths)):
+            answer_length = answer_token_lengths[i]
+            answer_span_tokens = answer_tokens[i][:answer_length]
+            span_reps, loss_masks = [], []
+
+            # go through the top k for the batch item
+            for logits, block_tokens in zip(lm_logits[i], topk_block_tokens[i]):
+                block_logits = logits[len(logits) / 2:]
+                span_starts = range(len(block_tokens) - (answer_length - 1))
+
+                # record the start, end indices of spans which match the answer
+                matching_indices = set([(idx, idx + answer_length - 1) for idx in span_starts
+                                        if np.array_equal(block_tokens[idx:idx + answer_length],
+                                                          answer_span_tokens)])
+
+                # create a mask for computing the loss on P(y | z, x)
+                # [num_spans]
+                loss_masks.append(torch.LongTensor(
+                    [int(idx_pair in matching_indices) for idx_pair in self.start_ends]))
+
+                # get all of the candidate spans that need to be fed to MLP
+                # [num_spans x 2 * embed_size]
+                span_reps.append([torch.cat((block_logits[s], block_logits[e]))
+                                  for (s, e) in self.start_ends])
+
+            # data for all k blocks for a single batch item
+            # [k x num_spans]
+            batch_loss_masks.append(torch.stack(loss_masks))
+            # [k x num_spans x 2 * embed_size]
+            batch_span_reps.append(torch.stack(span_reps))
+
+        # data for all batch items
+        # [batch_size x k x num_spans]
+        batch_loss_masks = torch.stack(batch_loss_masks)
+        batch_span_reps = torch.stack(batch_span_reps)
+
+        # [batch_size x k x num_spans]
+        batch_span_logits = self.fc2(self.fc1(batch_span_reps)).squeeze()
+
+        return batch_span_logits, batch_loss_masks, block_probs
+
+        # block_probs = block_probs.unsqueeze(2).unsqueeze(3).expand_as(lm_logits)
+        # lm_logits = torch.sum(lm_logits * block_probs, dim=1)
+
+
 class REALMBertModel(MegatronModule):
     def __init__(self, retriever):
         super(REALMBertModel, self).__init__()
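REALMAnswerSpanModel enumerates every candidate span of length 1 to 10 over a 288-token block via self.start_ends, scores each span with a two-layer MLP over the concatenated start and end logits, and returns the span logits together with a mask marking spans that exactly match the answer tokens. Note that range(max_length) yields end offsets 0 through 9, and that block_logits = logits[len(logits) / 2:] relies on integer division, so under Python 3 it would presumably need //. A quick check of the span count (illustrative, not part of this commit):

    # Number of (start, end) pairs REALMAnswerSpanModel enumerates per block.
    max_length, seq_len = 10, 288
    start_ends = []
    for length in range(max_length):
        start_ends.extend([(i, i + length) for i in range(seq_len - length)])
    print(len(start_ends))  # 2835 == sum(seq_len - L for L in range(max_length))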
@@ -24,11 +91,13 @@ class REALMBertModel(MegatronModule):
         self.top_k = self.retriever.top_k
         self._retriever_key = 'retriever'
 
-    def forward(self, tokens, attention_mask, query_block_indices):
+    def forward(self, tokens, attention_mask, query_block_indices, return_topk_block_tokens=False):
         # [batch_size x k x seq_length]
         topk_block_tokens, topk_block_attention_mask = self.retriever.retrieve_evidence_blocks(
             tokens, attention_mask, query_block_indices=query_block_indices, include_null_doc=True)
 
         batch_size = tokens.shape[0]
+        # create a copy in case it needs to be returned
+        ret_topk_block_tokens = np.array(topk_block_tokens)
         seq_length = topk_block_tokens.shape[2]
         topk_block_tokens = torch.cuda.LongTensor(topk_block_tokens).reshape(-1, seq_length)

@@ -58,6 +127,10 @@ class REALMBertModel(MegatronModule):
         # [batch_size x k x 2 * seq_length x vocab_size]
         lm_logits, _ = self.lm_model.forward(all_tokens, all_attention_mask, all_token_types)
         lm_logits = lm_logits.reshape(batch_size, self.top_k, 2 * seq_length, -1)
 
+        if return_topk_block_tokens:
+            return lm_logits, block_probs, ret_topk_block_tokens
+
         return lm_logits, block_probs
 
     def state_dict_for_save_checkpoint(self, destination=None, prefix='',

@@ -111,6 +184,11 @@ class REALMRetriever(MegatronModule):
         query_embeds = detach(true_model.embed_query(query_tokens, query_pad_mask))
         _, block_indices = self.hashed_index.search_mips_index(
             query_embeds, top_k=self.top_k, reconstruct=False)
         all_topk_tokens, all_topk_pad_masks = [], []
 
+        # this will result in no candidate exclusion
+        if query_block_indices is None:
+            query_block_indices = [-1] * len(block_indices)
+
         for query_idx, indices in enumerate(block_indices):
             # [k x meta_dim]
             # exclude trivial candidate if it appears, else just trim the weakest in the top-k
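REALMBertModel.forward stashes ret_topk_block_tokens as a NumPy copy before the retrieved blocks are moved to CUDA and flattened, so that REALMAnswerSpanModel can still read the token ids in their original [batch_size x k x seq_length] layout; the retriever likewise substitutes a -1 sentinel per query when query_block_indices is None, so no retrieved candidate is treated as the query's own (trivial) block. A shape sketch of the copy-then-reshape step (illustrative sizes, CPU tensors instead of torch.cuda.LongTensor, not part of this commit):

    import numpy as np
    import torch

    batch_size, k, seq_length = 2, 3, 4
    topk_block_tokens = np.arange(batch_size * k * seq_length).reshape(batch_size, k, seq_length)

    ret_topk_block_tokens = np.array(topk_block_tokens)                 # keeps [batch, k, seq]
    flat = torch.LongTensor(topk_block_tokens).reshape(-1, seq_length)  # [batch * k, seq] for the LM

    print(ret_topk_block_tokens.shape, flat.shape)                      # (2, 3, 4) torch.Size([6, 4])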
pretrain_bert_ict.py

@@ -83,11 +83,11 @@ def forward_step(data_iterator, model):
     retrieval_scores = model(query_tokens, query_pad_mask, block_tokens, block_pad_mask).float()
     softmaxed = F.softmax(retrieval_scores, dim=1)
-    top5_vals, top5_indices = torch.topk(softmaxed, k=5, sorted=True)
+    sorted_vals, sorted_indices = torch.topk(softmaxed, k=softmaxed.shape[1], sorted=True)
     batch_size = softmaxed.shape[0]
 
-    top1_acc = torch.cuda.FloatTensor([sum([int(top5_indices[i, 0] == i) for i in range(batch_size)]) / batch_size])
-    top5_acc = torch.cuda.FloatTensor([sum([int(i in top5_indices[i]) for i in range(batch_size)]) / batch_size])
+    top1_acc = torch.cuda.FloatTensor([sum([int(sorted_indices[i, 0] == i) for i in range(batch_size)]) / batch_size])
+    top5_acc = torch.cuda.FloatTensor([sum([int(i in sorted_indices[i, :5]) for i in range(batch_size)]) / batch_size])
 
     retrieval_loss = F.cross_entropy(softmaxed, torch.arange(batch_size).cuda())
     reduced_losses = reduce_losses([retrieval_loss, top1_acc, top5_acc])
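Sorting all candidate scores (k=softmaxed.shape[1]) instead of taking only the top five lets the same sorted_indices tensor serve both the top-1 and top-5 accuracy computations and keeps working when fewer than five candidates are present. A minimal sketch of that pattern (illustrative, not part of this commit), assuming the correct block for query i sits in column i:

    import torch

    def topk_accuracy(scores, k):
        """scores: [batch_size, num_candidates]; sort once, reuse for any cutoff k."""
        _, sorted_indices = torch.topk(scores, k=scores.shape[1], sorted=True)
        batch_size = scores.shape[0]
        hits = sum(int(i in sorted_indices[i, :k]) for i in range(batch_size))
        return hits / batch_size

    scores = torch.randn(8, 8)
    print(topk_accuracy(scores, 1), topk_accuracy(scores, 5))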
pretrain_realm.py

@@ -38,7 +38,7 @@ def model_provider():
     print_rank_0('building REALM models ...')
 
     ict_model = load_ict_checkpoint()
-    ict_dataset = get_ict_dataset()
+    ict_dataset = get_ict_dataset(use_titles=False)
     all_block_data = BlockData.load_from_file(args.block_data_path)
     # hashed_index = RandProjectionLSHIndex.load_from_file(args.block_index_path)
     hashed_index = FaissMIPSIndex(index_type='flat_l2', embed_size=128)

@@ -73,6 +73,11 @@ def get_batch(data_iterator):
     return tokens, labels, loss_mask, pad_mask, query_block_indices
 
 
+def get_qa_batch(data_iterator):
+    question_tokens, question_attention_mask, answer_tokens, answer_token_lengths = next(data_iterator)
+    return question_tokens, question_attention_mask, answer_tokens, answer_token_lengths
+
+
 def forward_step(data_iterator, model):
     """Forward step."""
     timers = get_timers()

@@ -101,6 +106,26 @@ def forward_step(data_iterator, model):
     return lm_loss, {'lm_loss': reduced_loss[0]}
 
 
+def qa_forward_step(data_iterator, model):
+    timers = get_timers()
+
+    # this dataset interface needs to be implemented
+    timers('batch generator').start()
+    question_tokens, question_attention_mask, answer_tokens, answer_token_lengths = get_qa_batch(data_iterator)
+    timers('batch generator').stop()
+
+    batch_span_logits, batch_loss_masks, block_probs = model(
+        question_tokens, question_attention_mask, answer_tokens, answer_token_lengths)
+
+    # [batch_size x k x num_spans]
+    block_probs = block_probs.unsqueeze(2).expand_as(batch_span_logits)
+    batch_span_probs = F.softmax(batch_span_logits, dim=2)
+
+    reduced_block_span_probs = torch.sum(batch_span_probs * block_probs, dim=1)
+    qa_span_loss_ = -torch.log(reduced_block_span_probs)
+    qa_span_loss = torch.sum(qa_span_loss_.view(-1) * batch_loss_masks)
+
+
 def train_valid_test_datasets_provider(train_val_test_num_samples):
     """Build train, valid and test datasets."""
     args = get_args()
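qa_forward_step marginalizes the span distribution over the top-k retrieved blocks, i.e. P(y|x) = sum_z P(z|x) P(y|z,x), then takes the negative log restricted (via batch_loss_masks) to spans that match the gold answer; as committed it computes qa_span_loss but does not yet return it. A shape sketch of that marginalization (illustrative sizes; collapsing the mask over k is an assumption about how the qa_span_loss_.view(-1) * batch_loss_masks shapes would line up, not part of this commit):

    import torch
    import torch.nn.functional as F

    batch_size, k, num_spans = 2, 5, 2835
    batch_span_logits = torch.randn(batch_size, k, num_spans)
    block_probs = torch.rand(batch_size, k)
    block_probs = block_probs / block_probs.sum(dim=1, keepdim=True)    # P(z | x)
    batch_loss_masks = torch.randint(0, 2, (batch_size, k, num_spans))  # 1 where span == answer

    span_probs = F.softmax(batch_span_logits, dim=2)                    # P(y | z, x)
    marginal = torch.sum(span_probs * block_probs.unsqueeze(2), dim=1)  # [batch, num_spans]

    mask = (batch_loss_masks.sum(dim=1) > 0).float()                    # [batch, num_spans]
    loss = torch.sum(-torch.log(marginal + 1e-10) * mask)
    print(loss)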