OpenDAS / Megatron-LM · Commit 730266ca

Refactor and add more REALM arguments

Authored May 05, 2020 by Neel Kant (parent a2e64ad5)

Showing 9 changed files with 248 additions and 213 deletions:
hashed_index.py                   +2   -2
megatron/arguments.py             +4   -0
megatron/checkpointing.py         +6   -2
megatron/data/dataset_utils.py    +4   -4
megatron/data/realm_dataset.py    +3   -3
megatron/model/__init__.py        +2   -1
megatron/model/bert_model.py      +0   -200
megatron/model/realm_model.py     +226 -0
pretrain_realm.py                 +1   -1
hashed_index.py  (+2 -2)

@@ -5,7 +5,7 @@ from megatron import get_args
 from megatron import mpu
 from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
 from megatron.data.bert_dataset import get_indexed_dataset_
-from megatron.data.realm_dataset import InverseClozeDataset
+from megatron.data.realm_dataset import ICTDataset
 from megatron.data.realm_index import detach, BlockData, RandProjectionLSHIndex
 from megatron.data.samplers import DistributedBatchSampler
 from megatron.initialize import initialize_megatron

@@ -150,7 +150,7 @@ def get_ict_dataset():
         short_seq_prob=0.0001,  # doesn't matter
         seed=1
     )
-    dataset = InverseClozeDataset(**kwargs)
+    dataset = ICTDataset(**kwargs)
     return dataset
megatron/arguments.py  (+4 -0)

@@ -245,6 +245,8 @@ def _add_checkpointing_args(parser):
                        help='Directory containing a model checkpoint.')
     group.add_argument('--ict-load', type=str, default=None,
                        help='Directory containing an ICTBertModel checkpoint')
+    group.add_argument('--bert-load', type=str, default=None,
+                       help='Directory containing an BertModel checkpoint (needed to start REALM)')
     group.add_argument('--no-load-optim', action='store_true',
                        help='Do not load optimizer when loading checkpoint.')
     group.add_argument('--no-load-rng', action='store_true',

@@ -326,6 +328,8 @@ def _add_data_args(parser):
                        help='Path to pickled BlockData data structure')
     group.add_argument('--block-index-path', type=str, default=None,
                        help='Path to pickled data structure for efficient block indexing')
+    group.add_argument('--block-top-k', type=int, default=5,
+                       help='Number of blocks to use as top-k during retrieval')
     group.add_argument('--split', type=str, default='969, 30, 1',
                        help='Comma-separated list of proportions for training,'
                        ' validation, and test split. For example the split '
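For reference, a minimal sketch of how the two new flags surface on the parsed args object. The parser wiring below mirrors the lines added in this commit; the argv values are placeholders for illustration only.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--bert-load', type=str, default=None,
                    help='Directory containing an BertModel checkpoint (needed to start REALM)')
parser.add_argument('--block-top-k', type=int, default=5,
                    help='Number of blocks to use as top-k during retrieval')

args = parser.parse_args(['--bert-load', '/checkpoints/bert_base', '--block-top-k', '8'])
assert args.bert_load == '/checkpoints/bert_base'
assert args.block_top_k == 8   # falls back to 5 when the flag is omitted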
megatron/checkpointing.py  (+6 -2)

@@ -131,11 +131,15 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
 def load_checkpoint(model, optimizer, lr_scheduler):
     """Load a model checkpoint and return the iteration."""
     args = get_args()
+    load_dir = args.load
+    from megatron.model.bert_model import BertModel
+    if isinstance(model, BertModel) and args.bert_load is not None:
+        load_dir = args.bert_load
     if isinstance(model, torchDDP):
         model = model.module
     # Read the tracker file and set the iteration.
-    tracker_filename = get_checkpoint_tracker_filename(args.load)
+    tracker_filename = get_checkpoint_tracker_filename(load_dir)
     # If no tracker file, return iretation zero.
     if not os.path.isfile(tracker_filename):

@@ -164,7 +168,7 @@ def load_checkpoint(model, optimizer, lr_scheduler):
                       tracker_filename)
     # Checkpoint.
-    checkpoint_name = get_checkpoint_name(args.load, iteration, release)
+    checkpoint_name = get_checkpoint_name(load_dir, iteration, release)
     if mpu.get_data_parallel_rank() == 0:
         print('global rank {} is loading checkpoint {}'.format(
             torch.distributed.get_rank(), checkpoint_name))
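In effect, a bare BertModel can now be seeded from its own checkpoint directory while the rest of the REALM stack still loads from --load / --ict-load. A minimal sketch of the intended call pattern (assumes Megatron is already initialized so get_args() is populated; the helper name is illustrative):

from megatron import get_args
from megatron.checkpointing import load_checkpoint
from megatron.model import BertModel

def init_realm_lm():
    """Illustrative helper: build the REALM language model and pull its weights
    from --bert-load rather than --load, via the dispatch added above."""
    args = get_args()
    assert args.bert_load is not None, 'pass --bert-load to seed the REALM LM'
    lm_model = BertModel(num_tokentypes=1, add_binary_head=False, parallel_output=True)
    # load_checkpoint() sees a plain BertModel plus args.bert_load, so it reads
    # from that directory; it returns the stored iteration count.
    iteration = load_checkpoint(lm_model, optimizer=None, lr_scheduler=None)
    return lm_model, iteration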
megatron/data/dataset_utils.py  (+4 -4)

@@ -454,8 +454,8 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
     print_split_stats('test', 2)

     def build_dataset(index, name):
-        from megatron.data.realm_dataset import InverseClozeDataset
-        from megatron.data.realm_dataset import RealmDataset
+        from megatron.data.realm_dataset import ICTDataset
+        from megatron.data.realm_dataset import REALMDataset
         dataset = None
         if splits[index + 1] > splits[index]:
             # Get the pointer to the original doc-idx so we can set it later.

@@ -478,13 +478,13 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
             )
             if dataset_type == 'ict':
-                dataset = InverseClozeDataset(
+                dataset = ICTDataset(
                     block_dataset=indexed_dataset,
                     title_dataset=title_dataset,
                     **kwargs
                 )
             else:
-                dataset_cls = BertDataset if dataset_type == 'standard_bert' else RealmDataset
+                dataset_cls = BertDataset if dataset_type == 'standard_bert' else REALMDataset
                 dataset = dataset_cls(
                     indexed_dataset=indexed_dataset,
                     masked_lm_prob=masked_lm_prob,
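The dispatch above recognizes three dataset_type values. A standalone sketch of just the selection logic (ICTDataset and REALMDataset come from this commit; the BertDataset import location is an assumption):

def select_dataset_cls(dataset_type):
    """Illustrative mirror of build_dataset's class selection:
    'ict' -> ICTDataset, 'standard_bert' -> BertDataset, anything else -> REALMDataset."""
    from megatron.data.realm_dataset import ICTDataset, REALMDataset
    from megatron.data.bert_dataset import BertDataset  # assumed module path
    if dataset_type == 'ict':
        return ICTDataset
    return BertDataset if dataset_type == 'standard_bert' else REALMDataset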
megatron/data/realm_dataset.py  (+3 -3)

@@ -15,7 +15,7 @@ from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_co
 #qa_nlp = spacy.load('en_core_web_lg')

-class RealmDataset(BertDataset):
+class REALMDataset(BertDataset):
     """Dataset containing simple masked sentences for masked language modeling.

     The dataset should yield sentences just like the regular BertDataset

@@ -28,7 +28,7 @@ class RealmDataset(BertDataset):
     def __init__(self, name, indexed_dataset, data_prefix,
                  num_epochs, max_num_samples, masked_lm_prob,
                  max_seq_length, short_seq_prob, seed):
-        super(RealmDataset, self).__init__(name, indexed_dataset, data_prefix,
+        super(REALMDataset, self).__init__(name, indexed_dataset, data_prefix,
                                            num_epochs, max_num_samples, masked_lm_prob,
                                            max_seq_length, short_seq_prob, seed)
         self.build_sample_fn = build_simple_training_sample

@@ -81,7 +81,7 @@ def spacy_ner(block_text):
     candidates['answers'] = answers

-class InverseClozeDataset(Dataset):
+class ICTDataset(Dataset):
     """Dataset containing sentences and their blocks for an inverse cloze task."""
     def __init__(self, name, block_dataset, title_dataset, data_prefix,
                  num_epochs, max_num_samples, max_seq_length,
megatron/model/__init__.py  (+2 -1)

@@ -14,6 +14,7 @@
 # limitations under the License.

 from .distributed import *
-from .bert_model import BertModel, ICTBertModel, REALMBertModel, REALMRetriever
+from .bert_model import BertModel
+from megatron.model.realm_model import ICTBertModel, REALMRetriever, REALMBertModel
 from .gpt2_model import GPT2Model
 from .utils import get_params_for_weight_decay_optimization
megatron/model/bert_model.py  (+0 -200)

@@ -15,14 +15,9 @@
 """BERT model."""

-import pickle
-import numpy as np
 import torch
-import torch.nn.functional as F

 from megatron import get_args
-from megatron.data.realm_index import detach
 from megatron.model.language_model import parallel_lm_logits
 from megatron.model.language_model import get_language_model
 from megatron.model.transformer import LayerNorm

@@ -224,198 +219,3 @@ class BertModel(MegatronModule):
                 state_dict[self._ict_head_key], strict=strict)
 [The 195 removed lines that followed were the REALMBertModel, REALMRetriever, and
  ICTBertModel class definitions; they move into the new megatron/model/realm_model.py
  below, which re-adds them together with the checkpoint-loading and state-dict
  additions shown there.]
megatron/model/realm_model.py  (new file, +226 -0)

import numpy as np
import torch
import torch.nn.functional as F

from megatron.checkpointing import load_checkpoint
from megatron.data.realm_index import detach
from megatron.model import BertModel
from megatron.module import MegatronModule


class REALMBertModel(MegatronModule):
    def __init__(self, retriever):
        super(REALMBertModel, self).__init__()
        bert_args = dict(num_tokentypes=1, add_binary_head=False, parallel_output=True)
        self.lm_model = BertModel(**bert_args)
        load_checkpoint(self.lm_model, optimizer=None, lr_scheduler=None)
        self._lm_key = 'realm_lm'

        self.retriever = retriever
        self._retriever_key = 'retriever'

    def forward(self, tokens, attention_mask):
        # [batch_size x 5 x seq_length]
        top5_block_tokens, top5_block_attention_mask = \
            self.retriever.retrieve_evidence_blocks(tokens, attention_mask)
        batch_size = tokens.shape[0]
        seq_length = top5_block_tokens.shape[2]
        top5_block_tokens = torch.cuda.LongTensor(top5_block_tokens).reshape(-1, seq_length)
        top5_block_attention_mask = torch.cuda.LongTensor(
            top5_block_attention_mask).reshape(-1, seq_length)

        # [batch_size x 5 x embed_size]
        true_model = self.retriever.ict_model.module.module
        fresh_block_logits = true_model.embed_block(
            top5_block_tokens, top5_block_attention_mask).reshape(batch_size, 5, -1)

        # [batch_size x embed_size x 1]
        query_logits = true_model.embed_query(tokens, attention_mask).unsqueeze(2)

        # [batch_size x 5]
        fresh_block_scores = torch.matmul(fresh_block_logits, query_logits).squeeze()
        block_probs = F.softmax(fresh_block_scores, dim=1)

        # [batch_size * 5 x seq_length]
        tokens = torch.stack([tokens.unsqueeze(1)] * 5, dim=1).reshape(-1, seq_length)
        attention_mask = torch.stack([attention_mask.unsqueeze(1)] * 5, dim=1).reshape(-1, seq_length)

        # [batch_size * 5 x 2 * seq_length]
        all_tokens = torch.cat((tokens, top5_block_tokens), axis=1)
        all_attention_mask = torch.cat((attention_mask, top5_block_attention_mask), axis=1)
        all_token_types = torch.zeros(all_tokens.shape).type(torch.int64).cuda()

        # [batch_size x 5 x 2 * seq_length x vocab_size]
        lm_logits, _ = self.lm_model.forward(all_tokens, all_attention_mask, all_token_types)
        lm_logits = lm_logits.reshape(batch_size, 5, 2 * seq_length, -1)

        return lm_logits, block_probs

    def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):
        """For easy load when model is combined with other heads,
        add an extra key."""
        state_dict_ = {}
        state_dict_[self._lm_key] = self.lm_model.state_dict_for_save_checkpoint(
            destination, prefix, keep_vars)
        state_dict_[self._retriever_key] = self.retriever.state_dict_for_save_checkpoint(
            destination, prefix, keep_vars)
        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Load the state dicts of each of the models"""
        self.lm_model.load_state_dict(state_dict[self._lm_key], strict)
        self.retriever.load_state_dict(state_dict[self._retriever_key], strict)


class REALMRetriever(MegatronModule):
    """Retriever which uses a pretrained ICTBertModel and a HashedIndex"""
    def __init__(self, ict_model, ict_dataset, block_data, hashed_index, top_k=5):
        super(REALMRetriever, self).__init__()
        self.ict_model = ict_model
        self.ict_dataset = ict_dataset
        self.block_data = block_data
        self.hashed_index = hashed_index
        self.top_k = top_k
        self._ict_key = 'ict_model'

    def retrieve_evidence_blocks_text(self, query_text):
        """Get the top k evidence blocks for query_text in text form"""
        print("-" * 100)
        print("Query: ", query_text)
        padless_max_len = self.ict_dataset.max_seq_length - 2
        query_tokens = self.ict_dataset.encode_text(query_text)[:padless_max_len]
        query_tokens, query_pad_mask = self.ict_dataset.concat_and_pad_tokens(query_tokens)

        query_tokens = torch.cuda.LongTensor(np.array(query_tokens).reshape(1, -1))
        query_pad_mask = torch.cuda.LongTensor(np.array(query_pad_mask).reshape(1, -1))

        top5_block_tokens, _ = self.retrieve_evidence_blocks(query_tokens, query_pad_mask)
        for i, block in enumerate(top5_block_tokens[0]):
            block_text = self.ict_dataset.decode_tokens(block)
            print('\n> Block {}: {}'.format(i, block_text))

    def retrieve_evidence_blocks(self, query_tokens, query_pad_mask):
        """Embed blocks to be used in a forward pass"""
        with torch.no_grad():
            true_model = self.ict_model.module.module
            query_embeds = detach(true_model.embed_query(query_tokens, query_pad_mask))
        _, block_indices = self.hashed_index.search_mips_index(
            query_embeds, top_k=self.top_k, reconstruct=False)
        all_top5_tokens, all_top5_pad_masks = [], []
        for indices in block_indices:
            # [k x meta_dim]
            top5_metas = np.array([self.block_data.meta_data[idx] for idx in indices])
            top5_block_data = [self.ict_dataset.get_block(*block_meta) for block_meta in top5_metas]
            top5_tokens, top5_pad_masks = zip(*top5_block_data)
            all_top5_tokens.append(np.array(top5_tokens))
            all_top5_pad_masks.append(np.array(top5_pad_masks))

        # [batch_size x k x seq_length]
        return np.array(all_top5_tokens), np.array(all_top5_pad_masks)

    def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):
        """For easy load when model is combined with other heads,
        add an extra key."""
        state_dict_ = {}
        state_dict_[self._ict_key] = self.ict_model.state_dict_for_save_checkpoint(
            destination, prefix, keep_vars)
        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Load the state dicts of each of the models"""
        self.ict_model.load_state_dict(state_dict[self._ict_key], strict)


class ICTBertModel(MegatronModule):
    """Bert-based module for Inverse Cloze task."""
    def __init__(self, ict_head_size, num_tokentypes=1, parallel_output=True,
                 only_query_model=False, only_block_model=False):
        super(ICTBertModel, self).__init__()
        bert_args = dict(num_tokentypes=num_tokentypes, add_binary_head=False,
                         ict_head_size=ict_head_size, parallel_output=parallel_output)
        assert not (only_block_model and only_query_model)
        self.use_block_model = not only_query_model
        self.use_query_model = not only_block_model

        if self.use_query_model:
            # this model embeds (pseudo-)queries - Embed_input in the paper
            self.query_model = BertModel(**bert_args)
            self._query_key = 'question_model'

        if self.use_block_model:
            # this model embeds evidence blocks - Embed_doc in the paper
            self.block_model = BertModel(**bert_args)
            self._block_key = 'context_model'

    def forward(self, query_tokens, query_attention_mask,
                block_tokens, block_attention_mask,
                only_query=False, only_block=False):
        """Run a forward pass for each of the models and compute the similarity scores."""
        if only_query:
            return self.embed_query(query_tokens, query_attention_mask)
        if only_block:
            return self.embed_block(block_tokens, block_attention_mask)

        query_logits = self.embed_query(query_tokens, query_attention_mask)
        block_logits = self.embed_block(block_tokens, block_attention_mask)

        # [batch x embed] * [embed x batch]
        retrieval_scores = query_logits.matmul(torch.transpose(block_logits, 0, 1))
        return retrieval_scores

    def embed_query(self, query_tokens, query_attention_mask):
        """Embed a batch of tokens using the query model"""
        if self.use_query_model:
            query_types = torch.zeros(query_tokens.shape).type(torch.int64).cuda()
            query_ict_logits, _ = self.query_model.forward(
                query_tokens, query_attention_mask, query_types)
            return query_ict_logits
        else:
            raise ValueError("Cannot embed query without query model.")

    def embed_block(self, block_tokens, block_attention_mask):
        """Embed a batch of tokens using the block model"""
        if self.use_block_model:
            block_types = torch.zeros(block_tokens.shape).type(torch.int64).cuda()
            block_ict_logits, _ = self.block_model.forward(
                block_tokens, block_attention_mask, block_types)
            return block_ict_logits
        else:
            raise ValueError("Cannot embed block without block model.")

    def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):
        """Save dict with state dicts of each of the models."""
        state_dict_ = {}
        if self.use_query_model:
            state_dict_[self._query_key] \
                = self.query_model.state_dict_for_save_checkpoint(destination, prefix, keep_vars)

        if self.use_block_model:
            state_dict_[self._block_key] \
                = self.block_model.state_dict_for_save_checkpoint(destination, prefix, keep_vars)

        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Load the state dicts of each of the models"""
        if self.use_query_model:
            print("Loading ICT query model", flush=True)
            self.query_model.load_state_dict(state_dict[self._query_key], strict=strict)

        if self.use_block_model:
            print("Loading ICT block model", flush=True)
            self.block_model.load_state_dict(state_dict[self._block_key], strict=strict)
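REALMBertModel.forward scores each query against its own retrieved blocks by folding the top-5 dimension into the batch and repeating the query tokens to match, then feeding query-plus-block pairs of length 2 * seq_length to the LM. A shape-only sketch with CPU tensors (the sizes and zero contents are placeholders; note that forward still hard-codes 5 even though the retriever's top_k is now configurable):

import torch

batch_size, seq_length, top_k = 2, 4, 5
tokens = torch.zeros(batch_size, seq_length, dtype=torch.int64)
# what retrieve_evidence_blocks returns, as a tensor: [batch_size x 5 x seq_length]
top5_block_tokens = torch.zeros(batch_size, top_k, seq_length, dtype=torch.int64)

# flatten blocks into the batch and repeat each query once per retrieved block
blocks = top5_block_tokens.reshape(-1, seq_length)                  # [batch_size * 5 x seq_length]
queries = torch.stack([tokens.unsqueeze(1)] * top_k,
                      dim=1).reshape(-1, seq_length)                # [batch_size * 5 x seq_length]
all_tokens = torch.cat((queries, blocks), dim=1)                    # [batch_size * 5 x 2 * seq_length]

assert all_tokens.shape == (batch_size * top_k, 2 * seq_length)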
pretrain_realm.py  (+1 -1)

@@ -44,7 +44,7 @@ def model_provider():
     hashed_index = FaissMIPSIndex(index_type='flat_l2', embed_size=128)
     hashed_index.add_block_embed_data(all_block_data)

-    retriever = REALMRetriever(ict_model, ict_dataset, all_block_data, hashed_index)
+    retriever = REALMRetriever(ict_model, ict_dataset, all_block_data, hashed_index, args.block_top_k)

     # TODO: REALMBertModel should accept a path to a pretrained bert-base
     model = REALMBertModel(retriever)
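With this one-line change the retrieval fan-out is driven by the command line rather than the REALMRetriever constructor default. A minimal sketch of the wiring (assumes Megatron is initialized and the ICT model, dataset, block data, and FAISS index are built as in model_provider above; the helper name is illustrative):

from megatron import get_args
from megatron.model import REALMRetriever

def build_retriever(ict_model, ict_dataset, all_block_data, hashed_index):
    """Illustrative: thread --block-top-k (default 5) into the retriever."""
    args = get_args()
    return REALMRetriever(ict_model, ict_dataset, all_block_data,
                          hashed_index, args.block_top_k)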