OpenDAS / Megatron-LM / Commits / efcee158

Commit efcee158, authored May 05, 2020 by Neel Kant

Add null block and exclude trivial block

parent 730266ca
Showing 4 changed files with 184 additions and 56 deletions (+184 -56):

    megatron/data/dataset_utils.py    +8    -2
    megatron/data/realm_dataset.py    +137  -20
    megatron/model/realm_model.py     +32   -26
    pretrain_realm.py                 +7    -8
megatron/data/dataset_utils.py

@@ -483,9 +483,15 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
             title_dataset=title_dataset,
             **kwargs
         )
+    elif dataset_type == 'realm':
+        dataset = REALMDataset(
+            block_dataset=indexed_dataset,
+            title_dataset=title_dataset,
+            masked_lm_prob=masked_lm_prob,
+            **kwargs
+        )
     else:
-        dataset_cls = BertDataset if dataset_type == 'standard_bert' else REALMDataset
-        dataset = dataset_cls(
+        dataset = BertDataset(
             indexed_dataset=indexed_dataset,
             masked_lm_prob=masked_lm_prob,
             **kwargs
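For context, a minimal self-contained sketch of the dispatch this hunk implements; the *Stub classes below are hypothetical stand-ins, not the real Megatron BertDataset/REALMDataset:

    class BertStub:
        def __init__(self, indexed_dataset=None, masked_lm_prob=0.15, **kwargs):
            self.source, self.masked_lm_prob = indexed_dataset, masked_lm_prob

    class RealmStub:
        def __init__(self, block_dataset=None, title_dataset=None,
                     masked_lm_prob=0.15, **kwargs):
            # REALM needs both the block text and the block titles.
            self.blocks, self.titles = block_dataset, title_dataset
            self.masked_lm_prob = masked_lm_prob

    def build_dataset(dataset_type, indexed_dataset, title_dataset,
                      masked_lm_prob, **kwargs):
        if dataset_type == 'realm':          # new branch added by this commit
            return RealmStub(block_dataset=indexed_dataset,
                             title_dataset=title_dataset,
                             masked_lm_prob=masked_lm_prob, **kwargs)
        # everything else ('standard_bert', ...) falls back to plain BERT masking
        return BertStub(indexed_dataset=indexed_dataset,
                        masked_lm_prob=masked_lm_prob, **kwargs)

    ds = build_dataset('realm', indexed_dataset=['block tokens'],
                       title_dataset=['block titles'], masked_lm_prob=0.15,
                       name='train', seed=1234)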
megatron/data/realm_dataset.py

@@ -15,30 +15,10 @@ from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_co
 #qa_nlp = spacy.load('en_core_web_lg')
 
-class REALMDataset(BertDataset):
-    """Dataset containing simple masked sentences for masked language modeling.
-
-    The dataset should yield sentences just like the regular BertDataset
-    However, this dataset also needs to be able to return a set of blocks
-    given their start and end indices.
-
-    Presumably
-    """
-    def __init__(self, name, indexed_dataset, data_prefix, num_epochs,
-                 max_num_samples, masked_lm_prob, max_seq_length,
-                 short_seq_prob, seed):
-        super(REALMDataset, self).__init__(name, indexed_dataset, data_prefix,
-                                           num_epochs, max_num_samples,
-                                           masked_lm_prob, max_seq_length,
-                                           short_seq_prob, seed)
-        self.build_sample_fn = build_simple_training_sample
-
 
 def build_simple_training_sample(sample, target_seq_length, max_seq_length,
                                  vocab_id_list, vocab_id_to_token_dict,
                                  cls_id, sep_id, mask_id, pad_id,
                                  masked_lm_prob, np_rng):
     tokens = list(itertools.chain(*sample))[:max_seq_length - 2]
     tokens, tokentypes = create_single_tokens_and_tokentypes(tokens, cls_id, sep_id)
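A quick illustration of what the first line of build_simple_training_sample does to a block (the token ids below are made up):

    import itertools

    # A block is a list of sentences, each a list of token ids.
    sample = [[7592, 2088, 999], [2023, 2003, 1037, 2742], [2062, 19204, 2015]]
    max_seq_length = 8

    # Flatten the sentences into one stream, leaving room for [CLS] and [SEP].
    tokens = list(itertools.chain(*sample))[:max_seq_length - 2]
    print(tokens)   # [7592, 2088, 999, 2023, 2003, 1037]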
@@ -60,6 +40,137 @@ def build_simple_training_sample(sample, target_seq_length, max_seq_length,
     return train_sample
 
 
+class REALMDataset(Dataset):
+    """Dataset containing simple masked sentences for masked language modeling.
+
+    The dataset should yield sentences just like the regular BertDataset
+    However, this dataset also needs to be able to return a set of blocks
+    given their start and end indices.
+
+    Presumably
+    """
+    def __init__(self, name, block_dataset, title_dataset, data_prefix,
+                 num_epochs, max_num_samples, masked_lm_prob,
+                 max_seq_length, short_seq_prob, seed):
+        self.name = name
+        self.seed = seed
+        self.max_seq_length = max_seq_length
+        self.masked_lm_prob = masked_lm_prob
+        self.block_dataset = block_dataset
+        self.title_dataset = title_dataset
+        self.short_seq_prob = short_seq_prob
+        self.rng = random.Random(self.seed)
+        self.samples_mapping = self.get_samples_mapping(
+            data_prefix, num_epochs, max_num_samples)
+
+        self.tokenizer = get_tokenizer()
+        self.vocab_id_list = list(self.tokenizer.inv_vocab.keys())
+        self.vocab_id_to_token_list = self.tokenizer.inv_vocab
+        self.cls_id = self.tokenizer.cls
+        self.sep_id = self.tokenizer.sep
+        self.mask_id = self.tokenizer.mask
+        self.pad_id = self.tokenizer.pad
+
+    def __len__(self):
+        return self.samples_mapping.shape[0]
+
+    def __getitem__(self, idx):
+        start_idx, end_idx, doc_idx, block_idx = self.samples_mapping[idx]
+        seq_length = self.max_seq_length
+        block = [list(self.block_dataset[i]) for i in range(start_idx, end_idx)]
+        assert len(block) > 1
+        np_rng = np.random.RandomState(seed=(self.seed + idx))
+
+        sample = build_simple_training_sample(block, seq_length,
+                                              self.max_seq_length,
+                                              self.vocab_id_list,
+                                              self.vocab_id_to_token_list,
+                                              self.cls_id,
+                                              self.sep_id,
+                                              self.mask_id,
+                                              self.pad_id,
+                                              self.masked_lm_prob,
+                                              np_rng)
+        sample.update({'query_block_indices': np.array([block_idx])})
+        return sample
+
+    def get_samples_mapping(self, data_prefix, num_epochs, max_num_samples):
+        if not num_epochs:
+            if not max_num_samples:
+                raise ValueError("Need to specify either max_num_samples "
+                                 "or num_epochs")
+            num_epochs = np.iinfo(np.int32).max - 1
+        if not max_num_samples:
+            max_num_samples = np.iinfo(np.int64).max - 1
+
+        # Filename of the index mapping
+        indexmap_filename = data_prefix
+        indexmap_filename += '_{}_indexmap'.format(self.name)
+        if num_epochs != (np.iinfo(np.int32).max - 1):
+            indexmap_filename += '_{}ep'.format(num_epochs)
+        if max_num_samples != (np.iinfo(np.int64).max - 1):
+            indexmap_filename += '_{}mns'.format(max_num_samples)
+        indexmap_filename += '_{}msl'.format(self.max_seq_length)
+        indexmap_filename += '_{}s'.format(self.seed)
+        indexmap_filename += '.npy'
+
+        # Build the indexed mapping if not exist.
+        if torch.distributed.get_rank() == 0 and \
+                not os.path.isfile(indexmap_filename):
+            print(' > WARNING: could not find index map file {}, building '
+                  'the indices on rank 0 ...'.format(indexmap_filename))
+
+            # Make sure the types match the helpers input types.
+            assert self.block_dataset.doc_idx.dtype == np.int64
+            assert self.block_dataset.sizes.dtype == np.int32
+
+            # Build samples mapping
+            verbose = torch.distributed.get_rank() == 0
+            start_time = time.time()
+            print_rank_0(' > building samples index mapping for {} ...'.format(
+                self.name))
+            from megatron.data.dataset_utils import compile_helper
+            compile_helper()
+            from megatron.data import helpers
+            samples_mapping = helpers.build_blocks_mapping(
+                self.block_dataset.doc_idx,
+                self.block_dataset.sizes,
+                self.title_dataset.sizes,
+                num_epochs,
+                max_num_samples,
+                self.max_seq_length - 3,  # account for added tokens
+                self.seed,
+                verbose)
+            print_rank_0(' > done building samples index mapping')
+            np.save(indexmap_filename, samples_mapping, allow_pickle=True)
+            print_rank_0(' > saved the index mapping in {}'.format(
+                indexmap_filename))
+            # Make sure all the ranks have built the mapping
+            print_rank_0(' > elapsed time to build and save samples mapping '
+                         '(seconds): {:4f}'.format(time.time() - start_time))
+
+        # This should be a barrier but nccl barrier assumes
+        # device_index=rank which is not the case for model
+        # parallel case
+        counts = torch.cuda.LongTensor([1])
+        torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
+        assert counts[0].item() == torch.distributed.get_world_size(
+            group=mpu.get_data_parallel_group())
+
+        # Load indexed dataset.
+        print_rank_0(' > loading indexed mapping from {}'.format(
+            indexmap_filename))
+        start_time = time.time()
+        samples_mapping = np.load(indexmap_filename, allow_pickle=True)
+        print_rank_0('    loaded indexed file in {:3.3f} seconds'.format(
+            time.time() - start_time))
+        print_rank_0('    total number of samples: {}'.format(
+            samples_mapping.shape[0]))
+
+        return samples_mapping
+
+
 def create_single_tokens_and_tokentypes(_tokens, cls_id, sep_id):
     tokens = []
     tokens.append(cls_id)
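For concreteness, a standalone sketch of the cached index-map filename assembled in get_samples_mapping above; the prefix and settings are made-up values:

    import numpy as np

    data_prefix, name = '/data/wiki/corpus', 'train'
    num_epochs = 3
    max_num_samples = np.iinfo(np.int64).max - 1   # i.e. "not specified"
    max_seq_length, seed = 288, 1234

    indexmap_filename = data_prefix + '_{}_indexmap'.format(name)
    if num_epochs != (np.iinfo(np.int32).max - 1):
        indexmap_filename += '_{}ep'.format(num_epochs)
    if max_num_samples != (np.iinfo(np.int64).max - 1):
        indexmap_filename += '_{}mns'.format(max_num_samples)
    indexmap_filename += '_{}msl'.format(max_seq_length)
    indexmap_filename += '_{}s'.format(seed)
    indexmap_filename += '.npy'
    print(indexmap_filename)   # /data/wiki/corpus_train_indexmap_3ep_288msl_1234s.npy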
@@ -160,6 +271,12 @@ class ICTDataset(Dataset):
 
         return (block_tokens, block_pad_mask)
 
+    def get_null_block(self):
+        block, title = [], []
+        block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
+        return (block_tokens, block_pad_mask)
+
     def concat_and_pad_tokens(self, tokens, title=None):
         """concat with special tokens and pad sequence to self.max_seq_length"""
         tokens = [self.cls_id] + tokens + [self.sep_id]
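A rough sketch of what get_null_block should yield, assuming concat_and_pad_tokens prepends [CLS], appends [SEP], and pads to max_seq_length (the special-token ids and length below are made up, and the real method may also splice in the title):

    import numpy as np

    CLS, SEP, PAD = 101, 102, 0
    max_seq_length = 16

    def null_block(max_seq_length):
        tokens = [CLS] + [] + [SEP]                 # empty block, empty title
        num_pad = max_seq_length - len(tokens)
        pad_mask = [1] * len(tokens) + [0] * num_pad
        return np.array(tokens + [PAD] * num_pad), np.array(pad_mask)

    block_tokens, block_pad_mask = null_block(max_seq_length)
    print(block_tokens[:4])    # [101 102   0   0]
    print(block_pad_mask[:4])  # [1 1 0 0]

Appending this null block gives the retriever a "no evidence" candidate to assign probability to, alongside the real retrieved blocks.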
megatron/model/realm_model.py

@@ -21,40 +21,43 @@ class REALMBertModel(MegatronModule):
         self._lm_key = 'realm_lm'
         self.retriever = retriever
+        self.top_k = self.retriever.top_k
         self._retriever_key = 'retriever'
 
-    def forward(self, tokens, attention_mask):
-        # [batch_size x 5 x seq_length]
-        top5_block_tokens, top5_block_attention_mask = self.retriever.retrieve_evidence_blocks(
-            tokens, attention_mask)
+    def forward(self, tokens, attention_mask, query_block_indices):
+        # [batch_size x k x seq_length]
+        topk_block_tokens, topk_block_attention_mask = self.retriever.retrieve_evidence_blocks(
+            tokens, attention_mask, query_block_indices=query_block_indices, include_null_doc=True)
 
         batch_size = tokens.shape[0]
-        seq_length = top5_block_tokens.shape[2]
-        top5_block_tokens = torch.cuda.LongTensor(top5_block_tokens).reshape(-1, seq_length)
-        top5_block_attention_mask = torch.cuda.LongTensor(top5_block_attention_mask).reshape(-1, seq_length)
+        seq_length = topk_block_tokens.shape[2]
+        topk_block_tokens = torch.cuda.LongTensor(topk_block_tokens).reshape(-1, seq_length)
+        topk_block_attention_mask = torch.cuda.LongTensor(topk_block_attention_mask).reshape(-1, seq_length)
 
-        # [batch_size x 5 x embed_size]
+        # [batch_size x k x embed_size]
         true_model = self.retriever.ict_model.module.module
-        fresh_block_logits = true_model.embed_block(top5_block_tokens, top5_block_attention_mask).reshape(batch_size, 5, -1)
+        fresh_block_logits = true_model.embed_block(topk_block_tokens, topk_block_attention_mask)
+        fresh_block_logits = fresh_block_logits.reshape(batch_size, self.top_k, -1)
 
         # [batch_size x embed_size x 1]
         query_logits = true_model.embed_query(tokens, attention_mask).unsqueeze(2)
 
-        # [batch_size x 5]
+        # [batch_size x k]
         fresh_block_scores = torch.matmul(fresh_block_logits, query_logits).squeeze()
         block_probs = F.softmax(fresh_block_scores, dim=1)
 
-        # [batch_size * 5 x seq_length]
-        tokens = torch.stack([tokens.unsqueeze(1)] * 5, dim=1).reshape(-1, seq_length)
-        attention_mask = torch.stack([attention_mask.unsqueeze(1)] * 5, dim=1).reshape(-1, seq_length)
+        # [batch_size * k x seq_length]
+        tokens = torch.stack([tokens.unsqueeze(1)] * self.top_k, dim=1).reshape(-1, seq_length)
+        attention_mask = torch.stack([attention_mask.unsqueeze(1)] * self.top_k, dim=1).reshape(-1, seq_length)
 
-        # [batch_size * 5 x 2 * seq_length]
-        all_tokens = torch.cat((tokens, top5_block_tokens), axis=1)
-        all_attention_mask = torch.cat((attention_mask, top5_block_attention_mask), axis=1)
+        # [batch_size * k x 2 * seq_length]
+        all_tokens = torch.cat((tokens, topk_block_tokens), axis=1)
+        all_attention_mask = torch.cat((attention_mask, topk_block_attention_mask), axis=1)
         all_token_types = torch.zeros(all_tokens.shape).type(torch.int64).cuda()
 
-        # [batch_size x 5 x 2 * seq_length x vocab_size]
+        # [batch_size x k x 2 * seq_length x vocab_size]
         lm_logits, _ = self.lm_model.forward(all_tokens, all_attention_mask, all_token_types)
-        lm_logits = lm_logits.reshape(batch_size, 5, 2 * seq_length, -1)
+        lm_logits = lm_logits.reshape(batch_size, self.top_k, 2 * seq_length, -1)
 
         return lm_logits, block_probs
 
     def state_dict_for_save_checkpoint(self, destination=None, prefix='',
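As a sanity check on the shape bookkeeping in forward() above, a standalone CPU sketch with toy sizes (not the real model code):

    import torch

    batch_size, top_k, seq_length = 2, 3, 5
    tokens = torch.randint(0, 1000, (batch_size, seq_length))
    topk_block_tokens = torch.randint(0, 1000, (batch_size, top_k, seq_length)).reshape(-1, seq_length)

    # Repeat each query top_k times so it lines up with its k candidate blocks,
    # then feed the LM one [query ; block] sequence per (query, block) pair.
    tokens = torch.stack([tokens.unsqueeze(1)] * top_k, dim=1).reshape(-1, seq_length)
    all_tokens = torch.cat((tokens, topk_block_tokens), axis=1)

    print(tokens.shape)       # torch.Size([6, 5])  = [batch_size * top_k, seq_length]
    print(all_tokens.shape)   # torch.Size([6, 10]) = [batch_size * top_k, 2 * seq_length]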
@@ -101,24 +104,27 @@ class REALMRetriever(MegatronModule):
             block_text = self.ict_dataset.decode_tokens(block)
             print('\n    > Block {}: {}'.format(i, block_text))
 
-    def retrieve_evidence_blocks(self, query_tokens, query_pad_mask):
+    def retrieve_evidence_blocks(self, query_tokens, query_pad_mask, query_block_indices=None, include_null_doc=False):
         """Embed blocks to be used in a forward pass"""
         with torch.no_grad():
             true_model = self.ict_model.module.module
             query_embeds = detach(true_model.embed_query(query_tokens, query_pad_mask))
         _, block_indices = self.hashed_index.search_mips_index(query_embeds, top_k=self.top_k, reconstruct=False)
-        all_top5_tokens, all_top5_pad_masks = [], []
-        for indices in block_indices:
+        all_topk_tokens, all_topk_pad_masks = [], []
+        for query_idx, indices in enumerate(block_indices):
             # [k x meta_dim]
-            top5_metas = np.array([self.block_data.meta_data[idx] for idx in indices])
-            top5_block_data = [self.ict_dataset.get_block(*block_meta) for block_meta in top5_metas]
-            top5_tokens, top5_pad_masks = zip(*top5_block_data)
+            # exclude trivial candidate if it appears, else just trim the weakest in the top-k
+            topk_metas = [self.block_data.meta_data[idx] for idx in indices if idx != query_block_indices[query_idx]]
+            topk_block_data = [self.ict_dataset.get_block(*block_meta) for block_meta in topk_metas[:self.top_k - 1]]
+            if include_null_doc:
+                topk_block_data.append(self.ict_dataset.get_null_block())
+            topk_tokens, topk_pad_masks = zip(*topk_block_data)
 
-            all_top5_tokens.append(np.array(top5_tokens))
-            all_top5_pad_masks.append(np.array(top5_pad_masks))
+            all_topk_tokens.append(np.array(topk_tokens))
+            all_topk_pad_masks.append(np.array(topk_pad_masks))
 
         # [batch_size x k x seq_length]
-        return np.array(all_top5_tokens), np.array(all_top5_pad_masks)
+        return np.array(all_topk_tokens), np.array(all_topk_pad_masks)
 
     def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                        keep_vars=False):
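A standalone sketch of the selection rule in the new retrieve_evidence_blocks, using made-up block ids; -1 stands in for the result of ict_dataset.get_null_block():

    import numpy as np

    top_k = 4                                      # plays the role of self.top_k
    retrieved_indices = np.array([17, 42, 8, 99])  # MIPS hits for one query, strongest first
    query_block_index = 42                         # the block the query was drawn from
    NULL_BLOCK = -1

    # Drop the query's own block if it was retrieved, otherwise the weakest hit,
    # then append the null block so every query still sees top_k candidates.
    kept = [idx for idx in retrieved_indices if idx != query_block_index][:top_k - 1]
    kept.append(NULL_BLOCK)
    print(kept)   # [17, 8, 99, -1]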
pretrain_realm.py

@@ -44,8 +44,8 @@ def model_provider():
     hashed_index = FaissMIPSIndex(index_type='flat_l2', embed_size=128)
     hashed_index.add_block_embed_data(all_block_data)
 
-    retriever = REALMRetriever(ict_model, ict_dataset, all_block_data, hashed_index, args.block_top_k)
     # TODO: REALMBertModel should accept a path to a pretrained bert-base
+    # top_k + 1 because we may need to exclude trivial candidate
+    retriever = REALMRetriever(ict_model, ict_dataset, all_block_data, hashed_index, args.block_top_k + 1)
     model = REALMBertModel(retriever)
 
     return model
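If I read the retriever changes above correctly, the "+ 1" gives the MIPS search one extra candidate so that, after the trivial (or weakest) block is excluded, block_top_k real blocks remain and the null block is appended on top. A small worked example with a made-up setting:

    block_top_k = 5
    retriever_top_k = block_top_k + 1          # 6 candidates come back from the MIPS index
    kept_real_blocks = retriever_top_k - 1     # 5 after dropping the trivial (or weakest) block
    evidence_per_query = kept_real_blocks + 1  # 6 once the null block is appended
    assert evidence_per_query == retriever_top_k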
@@ -53,7 +53,7 @@ def model_provider():
 def get_batch(data_iterator):
     # Items and their type.
-    keys = ['tokens', 'labels', 'loss_mask', 'pad_mask']
+    keys = ['tokens', 'labels', 'loss_mask', 'pad_mask', 'query_block_indices']
     datatype = torch.int64
 
     # Broadcast data.
@@ -68,8 +68,9 @@ def get_batch(data_iterator):
     labels = data_b['labels'].long()
     loss_mask = data_b['loss_mask'].long()
     pad_mask = data_b['pad_mask'].long()
+    query_block_indices = data_b['query_block_indices'].long()
 
-    return tokens, labels, loss_mask, pad_mask
+    return tokens, labels, loss_mask, pad_mask, query_block_indices
 
 
 def forward_step(data_iterator, model):
@@ -78,16 +79,15 @@ def forward_step(data_iterator, model):
 
     # Get the batch.
     timers('batch generator').start()
-    tokens, labels, loss_mask, pad_mask = get_batch(data_iterator)
+    tokens, labels, loss_mask, pad_mask, query_block_indices = get_batch(data_iterator)
     timers('batch generator').stop()
 
     # Forward model.
     # TODO: MAKE SURE PAD IS NOT 1 - PAD
-    lm_logits, block_probs = model(tokens, pad_mask)
+    lm_logits, block_probs = model(tokens, pad_mask, query_block_indices)
 
     # P(y|x) = sum_z(P(y|z, x) * P(z|x))
     block_probs = block_probs.unsqueeze(2).unsqueeze(3).expand_as(lm_logits)
     #block_probs.register_hook(lambda x: print("block_probs: ", x.shape, flush=True))
     lm_logits = torch.sum(lm_logits * block_probs, dim=1)[:, :labels.shape[1]]
 
     lm_loss_ = mpu.vocab_parallel_cross_entropy(lm_logits.contiguous().float(),
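A standalone sketch of the marginalization P(y|x) = sum_z(P(y|z, x) * P(z|x)) computed above, with toy shapes (batch=2, k=3 blocks, 2*seq_length=8, vocab=11, labels of length 4); random tensors stand in for the real logits:

    import torch
    import torch.nn.functional as F

    batch, k, two_seq, vocab, label_len = 2, 3, 8, 11, 4
    lm_logits = torch.randn(batch, k, two_seq, vocab)      # per-block P(y | z, x) terms
    block_probs = F.softmax(torch.randn(batch, k), dim=1)  # retrieval weights P(z | x)

    # Broadcast each block's weight over the sequence and vocab dims, then sum over z.
    weights = block_probs.unsqueeze(2).unsqueeze(3).expand_as(lm_logits)
    marginal = torch.sum(lm_logits * weights, dim=1)[:, :label_len]
    print(marginal.shape)   # torch.Size([2, 4, 11])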
@@ -95,7 +95,6 @@ def forward_step(data_iterator, model):
     lm_loss = torch.sum(
         lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
 
     reduced_loss = reduce_losses([lm_loss])
     torch.cuda.synchronize()
     print(reduced_loss, flush=True)