OpenDAS / Megatron-LM

Commit 730266ca, authored May 05, 2020 by Neel Kant
Parent: a2e64ad5

Refactor and add more REALM arguments

Showing 9 changed files with 248 additions and 213 deletions (+248 / -213).
Changed files:
  hashed_index.py                   +2    -2
  megatron/arguments.py             +4    -0
  megatron/checkpointing.py         +6    -2
  megatron/data/dataset_utils.py    +4    -4
  megatron/data/realm_dataset.py    +3    -3
  megatron/model/__init__.py        +2    -1
  megatron/model/bert_model.py      +0    -200
  megatron/model/realm_model.py     +226  -0
  pretrain_realm.py                 +1    -1
hashed_index.py (+2 / -2)

@@ -5,7 +5,7 @@ from megatron import get_args
 from megatron import mpu
 from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
 from megatron.data.bert_dataset import get_indexed_dataset_
-from megatron.data.realm_dataset import InverseClozeDataset
+from megatron.data.realm_dataset import ICTDataset
 from megatron.data.realm_index import detach, BlockData, RandProjectionLSHIndex
 from megatron.data.samplers import DistributedBatchSampler
 from megatron.initialize import initialize_megatron

@@ -150,7 +150,7 @@ def get_ict_dataset():
                       short_seq_prob=0.0001,  # doesn't matter
                       seed=1)
-    dataset = InverseClozeDataset(**kwargs)
+    dataset = ICTDataset(**kwargs)
     return dataset
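This commit renames InverseClozeDataset to ICTDataset (and, below, RealmDataset to REALMDataset) across the repository, so any out-of-tree script that still imports the old name will break. A minimal compatibility shim is sketched below; the alias is hypothetical and not part of this commit, and it assumes the Megatron-LM package is importable.

# Hypothetical migration shim for external callers; not included in this commit.
from megatron.data.realm_dataset import ICTDataset

InverseClozeDataset = ICTDataset  # deprecated alias; prefer ICTDataset going forward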
megatron/arguments.py (+4 / -0)

@@ -245,6 +245,8 @@ def _add_checkpointing_args(parser):
                        help='Directory containing a model checkpoint.')
+    group.add_argument('--ict-load', type=str, default=None, help='Directory containing an ICTBertModel checkpoint')
+    group.add_argument('--bert-load', type=str, default=None, help='Directory containing an BertModel checkpoint (needed to start REALM)')
     group.add_argument('--no-load-optim', action='store_true',
                        help='Do not load optimizer when loading checkpoint.')
     group.add_argument('--no-load-rng', action='store_true',

@@ -326,6 +328,8 @@ def _add_data_args(parser):
                        help='Path to pickled BlockData data structure')
+    group.add_argument('--block-index-path', type=str, default=None, help='Path to pickled data structure for efficient block indexing')
+    group.add_argument('--block-top-k', type=int, default=5, help='Number of blocks to use as top-k during retrieval')
     group.add_argument('--split', type=str, default='969, 30, 1',
                        help='Comma-separated list of proportions for training,'
                             ' validation, and test split. For example the split '
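For reference, a minimal standalone sketch of how the four new REALM flags behave once parsed. This uses plain argparse outside of Megatron's _add_*_args plumbing; the flag names and defaults mirror the diff above, while the example values are made up.

import argparse

parser = argparse.ArgumentParser(description='Sketch of the new REALM flags')
group = parser.add_argument_group('realm')
group.add_argument('--ict-load', type=str, default=None,
                   help='Directory containing an ICTBertModel checkpoint')
group.add_argument('--bert-load', type=str, default=None,
                   help='Directory containing a BertModel checkpoint (needed to start REALM)')
group.add_argument('--block-index-path', type=str, default=None,
                   help='Path to pickled data structure for efficient block indexing')
group.add_argument('--block-top-k', type=int, default=5,
                   help='Number of blocks to use as top-k during retrieval')

# argparse converts '--block-top-k' to the attribute 'block_top_k'.
args = parser.parse_args(['--bert-load', '/path/to/bert_ckpt', '--block-top-k', '8'])
print(args.bert_load, args.block_top_k)  # /path/to/bert_ckpt 8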
megatron/checkpointing.py (+6 / -2)

@@ -131,11 +131,15 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
 def load_checkpoint(model, optimizer, lr_scheduler):
     """Load a model checkpoint and return the iteration."""
     args = get_args()

+    load_dir = args.load
+    from megatron.model.bert_model import BertModel
+    if isinstance(model, BertModel) and args.bert_load is not None:
+        load_dir = args.bert_load
+
     if isinstance(model, torchDDP):
         model = model.module

     # Read the tracker file and set the iteration.
-    tracker_filename = get_checkpoint_tracker_filename(args.load)
+    tracker_filename = get_checkpoint_tracker_filename(load_dir)

     # If no tracker file, return iretation zero.
     if not os.path.isfile(tracker_filename):

@@ -164,7 +168,7 @@ def load_checkpoint(model, optimizer, lr_scheduler):
         tracker_filename)

     # Checkpoint.
-    checkpoint_name = get_checkpoint_name(args.load, iteration, release)
+    checkpoint_name = get_checkpoint_name(load_dir, iteration, release)
     if mpu.get_data_parallel_rank() == 0:
         print('global rank {} is loading checkpoint {}'.format(
             torch.distributed.get_rank(), checkpoint_name))
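The behavioral change in load_checkpoint is small but easy to miss: a BertModel can now be loaded from --bert-load instead of --load. A self-contained sketch of just that selection logic follows; the helper name and the stub classes are hypothetical, since the real code inlines this inside load_checkpoint as shown above.

class _FakeArgs:
    load = '/checkpoints/realm'
    bert_load = '/checkpoints/bert_base'

class BertModel:
    # Stand-in for megatron.model.bert_model.BertModel.
    pass

def resolve_load_dir(model, args):
    # Mirrors the new logic: BertModel instances prefer --bert-load when it is set.
    load_dir = args.load
    if isinstance(model, BertModel) and args.bert_load is not None:
        load_dir = args.bert_load
    return load_dir

print(resolve_load_dir(BertModel(), _FakeArgs()))  # /checkpoints/bert_base
print(resolve_load_dir(object(), _FakeArgs()))     # /checkpoints/realm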
megatron/data/dataset_utils.py (+4 / -4)

@@ -454,8 +454,8 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
     print_split_stats('test', 2)

     def build_dataset(index, name):
-        from megatron.data.realm_dataset import InverseClozeDataset
-        from megatron.data.realm_dataset import RealmDataset
+        from megatron.data.realm_dataset import ICTDataset
+        from megatron.data.realm_dataset import REALMDataset
         dataset = None
         if splits[index + 1] > splits[index]:
             # Get the pointer to the original doc-idx so we can set it later.

@@ -478,13 +478,13 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
             )

             if dataset_type == 'ict':
-                dataset = InverseClozeDataset(
+                dataset = ICTDataset(
                     block_dataset=indexed_dataset,
                     title_dataset=title_dataset,
                     **kwargs
                 )
             else:
-                dataset_cls = BertDataset if dataset_type == 'standard_bert' else RealmDataset
+                dataset_cls = BertDataset if dataset_type == 'standard_bert' else REALMDataset
                 dataset = dataset_cls(
                     indexed_dataset=indexed_dataset,
                     masked_lm_prob=masked_lm_prob,
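The dataset selection above reduces to a three-way dispatch on dataset_type: 'ict' gets its own branch because ICTDataset needs the block and title datasets, and everything else falls to BertDataset or REALMDataset. A stripped-down sketch with stub classes (hypothetical; the real constructors take the keyword arguments shown in the diff):

class BertDataset: pass
class REALMDataset: pass
class ICTDataset: pass

def pick_dataset_cls(dataset_type):
    # Same branching as build_dataset above, with the constructors elided.
    if dataset_type == 'ict':
        return ICTDataset
    return BertDataset if dataset_type == 'standard_bert' else REALMDataset

assert pick_dataset_cls('ict') is ICTDataset
assert pick_dataset_cls('standard_bert') is BertDataset
assert pick_dataset_cls('realm') is REALMDataset  # any other type maps to REALMDataset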
megatron/data/realm_dataset.py (+3 / -3)

@@ -15,7 +15,7 @@ from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_co
 #qa_nlp = spacy.load('en_core_web_lg')

-class RealmDataset(BertDataset):
+class REALMDataset(BertDataset):
     """Dataset containing simple masked sentences for masked language modeling.

     The dataset should yield sentences just like the regular BertDataset

@@ -28,7 +28,7 @@ class RealmDataset(BertDataset):
     def __init__(self, name, indexed_dataset, data_prefix,
                  num_epochs, max_num_samples, masked_lm_prob,
                  max_seq_length, short_seq_prob, seed):
-        super(RealmDataset, self).__init__(name, indexed_dataset, data_prefix,
+        super(REALMDataset, self).__init__(name, indexed_dataset, data_prefix,
                                            num_epochs, max_num_samples, masked_lm_prob,
                                            max_seq_length, short_seq_prob, seed)
         self.build_sample_fn = build_simple_training_sample

@@ -81,7 +81,7 @@ def spacy_ner(block_text):
             candidates['answers'] = answers

-class InverseClozeDataset(Dataset):
+class ICTDataset(Dataset):
     """Dataset containing sentences and their blocks for an inverse cloze task."""
     def __init__(self, name, block_dataset, title_dataset, data_prefix,
                  num_epochs, max_num_samples, max_seq_length,
megatron/model/__init__.py (+2 / -1)

@@ -14,6 +14,7 @@
 # limitations under the License.

 from .distributed import *
-from .bert_model import BertModel, ICTBertModel, REALMBertModel, REALMRetriever
+from .bert_model import BertModel
+from megatron.model.realm_model import ICTBertModel, REALMRetriever, REALMBertModel
 from .gpt2_model import GPT2Model
 from .utils import get_params_for_weight_decay_optimization
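Because megatron/model/__init__.py re-exports the moved classes, call sites that import from the package are unaffected; only direct imports from megatron.model.bert_model need updating. The import lines below illustrate this, assuming a working Megatron-LM checkout on the path:

# Package-level imports keep working via the re-export above:
from megatron.model import ICTBertModel, REALMBertModel, REALMRetriever

# Direct module imports must now point at the new file:
from megatron.model.realm_model import REALMRetriever  # was megatron.model.bert_model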
megatron/model/bert_model.py (+0 / -200)

@@ -15,14 +15,9 @@
 """BERT model."""

-import pickle
-import numpy as np
 import torch
-import torch.nn.functional as F

 from megatron import get_args
-from megatron.data.realm_index import detach
 from megatron.model.language_model import parallel_lm_logits
 from megatron.model.language_model import get_language_model
 from megatron.model.transformer import LayerNorm

@@ -224,198 +219,3 @@ class BertModel(MegatronModule):
                 state_dict[self._ict_head_key], strict=strict)
-
-
-class REALMBertModel(MegatronModule):
-    def __init__(self, retriever):
-        super(REALMBertModel, self).__init__()
-        bert_args = dict(
-            num_tokentypes=1,
-            add_binary_head=False,
-            parallel_output=True
-        )
-        self.lm_model = BertModel(**bert_args)
-        self._lm_key = 'realm_lm'
-
-        self.retriever = retriever
-        self._retriever_key = 'retriever'
-
-    def forward(self, tokens, attention_mask):
-        # [batch_size x 5 x seq_length]
-        top5_block_tokens, top5_block_attention_mask = self.retriever.retrieve_evidence_blocks(tokens, attention_mask)
-
-        batch_size = tokens.shape[0]
-        seq_length = top5_block_tokens.shape[2]
-        top5_block_tokens = torch.cuda.LongTensor(top5_block_tokens).reshape(-1, seq_length)
-        top5_block_attention_mask = torch.cuda.LongTensor(top5_block_attention_mask).reshape(-1, seq_length)
-
-        # [batch_size x 5 x embed_size]
-        true_model = self.retriever.ict_model.module.module
-        fresh_block_logits = true_model.embed_block(top5_block_tokens, top5_block_attention_mask).reshape(batch_size, 5, -1)
-
-        # [batch_size x embed_size x 1]
-        query_logits = true_model.embed_query(tokens, attention_mask).unsqueeze(2)
-
-        # [batch_size x 5]
-        fresh_block_scores = torch.matmul(fresh_block_logits, query_logits).squeeze()
-        block_probs = F.softmax(fresh_block_scores, dim=1)
-
-        # [batch_size * 5 x seq_length]
-        tokens = torch.stack([tokens.unsqueeze(1)] * 5, dim=1).reshape(-1, seq_length)
-        attention_mask = torch.stack([attention_mask.unsqueeze(1)] * 5, dim=1).reshape(-1, seq_length)
-
-        # [batch_size * 5 x 2 * seq_length]
-        all_tokens = torch.cat((tokens, top5_block_tokens), axis=1)
-        all_attention_mask = torch.cat((attention_mask, top5_block_attention_mask), axis=1)
-        all_token_types = torch.zeros(all_tokens.shape).type(torch.int64).cuda()
-
-        # [batch_size x 5 x 2 * seq_length x vocab_size]
-        lm_logits, _ = self.lm_model.forward(all_tokens, all_attention_mask, all_token_types)
-        lm_logits = lm_logits.reshape(batch_size, 5, 2 * seq_length, -1)
-
-        return lm_logits, block_probs
-
-    def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):
-        """For easy load when model is combined with other heads,
-        add an extra key."""
-        state_dict_ = {}
-        state_dict_[self._lm_key] = self.lm_model.state_dict_for_save_checkpoint(destination, prefix, keep_vars)
-        return state_dict_
-
-
-class REALMRetriever(MegatronModule):
-    """Retriever which uses a pretrained ICTBertModel and a HashedIndex"""
-    def __init__(self, ict_model, ict_dataset, block_data, hashed_index, top_k=5):
-        super(REALMRetriever, self).__init__()
-        self.ict_model = ict_model
-        self.ict_dataset = ict_dataset
-        self.block_data = block_data
-        self.hashed_index = hashed_index
-        self.top_k = top_k
-
-    def retrieve_evidence_blocks_text(self, query_text):
-        """Get the top k evidence blocks for query_text in text form"""
-        print("-" * 100)
-        print("Query: ", query_text)
-
-        padless_max_len = self.ict_dataset.max_seq_length - 2
-        query_tokens = self.ict_dataset.encode_text(query_text)[:padless_max_len]
-
-        query_tokens, query_pad_mask = self.ict_dataset.concat_and_pad_tokens(query_tokens)
-        query_tokens = torch.cuda.LongTensor(np.array(query_tokens).reshape(1, -1))
-        query_pad_mask = torch.cuda.LongTensor(np.array(query_pad_mask).reshape(1, -1))
-
-        top5_block_tokens, _ = self.retrieve_evidence_blocks(query_tokens, query_pad_mask)
-        for i, block in enumerate(top5_block_tokens[0]):
-            block_text = self.ict_dataset.decode_tokens(block)
-            print('\n    > Block {}: {}'.format(i, block_text))
-
-    def retrieve_evidence_blocks(self, query_tokens, query_pad_mask):
-        """Embed blocks to be used in a forward pass"""
-        with torch.no_grad():
-            true_model = self.ict_model.module.module
-            query_embeds = detach(true_model.embed_query(query_tokens, query_pad_mask))
-
-        _, block_indices = self.hashed_index.search_mips_index(query_embeds, top_k=self.top_k, reconstruct=False)
-        all_top5_tokens, all_top5_pad_masks = [], []
-        for indices in block_indices:
-            # [k x meta_dim]
-            top5_metas = np.array([self.block_data.meta_data[idx] for idx in indices])
-            top5_block_data = [self.ict_dataset.get_block(*block_meta) for block_meta in top5_metas]
-            top5_tokens, top5_pad_masks = zip(*top5_block_data)
-
-            all_top5_tokens.append(np.array(top5_tokens))
-            all_top5_pad_masks.append(np.array(top5_pad_masks))
-
-        # [batch_size x k x seq_length]
-        return np.array(all_top5_tokens), np.array(all_top5_pad_masks)
-
-
-class ICTBertModel(MegatronModule):
-    """Bert-based module for Inverse Cloze task."""
-    def __init__(self,
-                 ict_head_size,
-                 num_tokentypes=1,
-                 parallel_output=True,
-                 only_query_model=False,
-                 only_block_model=False):
-        super(ICTBertModel, self).__init__()
-        bert_args = dict(
-            num_tokentypes=num_tokentypes,
-            add_binary_head=False,
-            ict_head_size=ict_head_size,
-            parallel_output=parallel_output
-        )
-        assert not (only_block_model and only_query_model)
-        self.use_block_model = not only_query_model
-        self.use_query_model = not only_block_model
-
-        if self.use_query_model:
-            # this model embeds (pseudo-)queries - Embed_input in the paper
-            self.query_model = BertModel(**bert_args)
-            self._query_key = 'question_model'
-
-        if self.use_block_model:
-            # this model embeds evidence blocks - Embed_doc in the paper
-            self.block_model = BertModel(**bert_args)
-            self._block_key = 'context_model'
-
-    def forward(self, query_tokens, query_attention_mask,
-                block_tokens, block_attention_mask,
-                only_query=False, only_block=False):
-        """Run a forward pass for each of the models and compute the similarity scores."""
-        if only_query:
-            return self.embed_query(query_tokens, query_attention_mask)
-
-        if only_block:
-            return self.embed_block(block_tokens, block_attention_mask)
-
-        query_logits = self.embed_query(query_tokens, query_attention_mask)
-        block_logits = self.embed_block(block_tokens, block_attention_mask)
-
-        # [batch x embed] * [embed x batch]
-        retrieval_scores = query_logits.matmul(torch.transpose(block_logits, 0, 1))
-        return retrieval_scores
-
-    def embed_query(self, query_tokens, query_attention_mask):
-        """Embed a batch of tokens using the query model"""
-        if self.use_query_model:
-            query_types = torch.zeros(query_tokens.shape).type(torch.int64).cuda()
-            query_ict_logits, _ = self.query_model.forward(query_tokens, query_attention_mask, query_types)
-            return query_ict_logits
-        else:
-            raise ValueError("Cannot embed query without query model.")
-
-    def embed_block(self, block_tokens, block_attention_mask):
-        """Embed a batch of tokens using the block model"""
-        if self.use_block_model:
-            block_types = torch.zeros(block_tokens.shape).type(torch.int64).cuda()
-            block_ict_logits, _ = self.block_model.forward(block_tokens, block_attention_mask, block_types)
-            return block_ict_logits
-        else:
-            raise ValueError("Cannot embed block without block model.")
-
-    def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):
-        """Save dict with state dicts of each of the models."""
-        state_dict_ = {}
-        if self.use_query_model:
-            state_dict_[self._query_key] \
-                = self.query_model.state_dict_for_save_checkpoint(
-                    destination, prefix, keep_vars)
-
-        if self.use_block_model:
-            state_dict_[self._block_key] \
-                = self.block_model.state_dict_for_save_checkpoint(
-                    destination, prefix, keep_vars)
-
-        return state_dict_
-
-    def load_state_dict(self, state_dict, strict=True):
-        """Load the state dicts of each of the models"""
-        if self.use_query_model:
-            print("Loading ICT query model", flush=True)
-            self.query_model.load_state_dict(state_dict[self._query_key], strict=strict)
-
-        if self.use_block_model:
-            print("Loading ICT block model", flush=True)
-            self.block_model.load_state_dict(state_dict[self._block_key], strict=strict)
megatron/model/realm_model.py (new file, 0 → 100644, +226 / -0)

import numpy as np
import torch
import torch.nn.functional as F

from megatron.checkpointing import load_checkpoint
from megatron.data.realm_index import detach
from megatron.model import BertModel
from megatron.module import MegatronModule


class REALMBertModel(MegatronModule):
    def __init__(self, retriever):
        super(REALMBertModel, self).__init__()
        bert_args = dict(
            num_tokentypes=1,
            add_binary_head=False,
            parallel_output=True
        )
        self.lm_model = BertModel(**bert_args)
        load_checkpoint(self.lm_model, optimizer=None, lr_scheduler=None)
        self._lm_key = 'realm_lm'

        self.retriever = retriever
        self._retriever_key = 'retriever'

    def forward(self, tokens, attention_mask):
        # [batch_size x 5 x seq_length]
        top5_block_tokens, top5_block_attention_mask = self.retriever.retrieve_evidence_blocks(tokens, attention_mask)

        batch_size = tokens.shape[0]
        seq_length = top5_block_tokens.shape[2]
        top5_block_tokens = torch.cuda.LongTensor(top5_block_tokens).reshape(-1, seq_length)
        top5_block_attention_mask = torch.cuda.LongTensor(top5_block_attention_mask).reshape(-1, seq_length)

        # [batch_size x 5 x embed_size]
        true_model = self.retriever.ict_model.module.module
        fresh_block_logits = true_model.embed_block(top5_block_tokens, top5_block_attention_mask).reshape(batch_size, 5, -1)

        # [batch_size x embed_size x 1]
        query_logits = true_model.embed_query(tokens, attention_mask).unsqueeze(2)

        # [batch_size x 5]
        fresh_block_scores = torch.matmul(fresh_block_logits, query_logits).squeeze()
        block_probs = F.softmax(fresh_block_scores, dim=1)

        # [batch_size * 5 x seq_length]
        tokens = torch.stack([tokens.unsqueeze(1)] * 5, dim=1).reshape(-1, seq_length)
        attention_mask = torch.stack([attention_mask.unsqueeze(1)] * 5, dim=1).reshape(-1, seq_length)

        # [batch_size * 5 x 2 * seq_length]
        all_tokens = torch.cat((tokens, top5_block_tokens), axis=1)
        all_attention_mask = torch.cat((attention_mask, top5_block_attention_mask), axis=1)
        all_token_types = torch.zeros(all_tokens.shape).type(torch.int64).cuda()

        # [batch_size x 5 x 2 * seq_length x vocab_size]
        lm_logits, _ = self.lm_model.forward(all_tokens, all_attention_mask, all_token_types)
        lm_logits = lm_logits.reshape(batch_size, 5, 2 * seq_length, -1)

        return lm_logits, block_probs

    def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):
        """For easy load when model is combined with other heads,
        add an extra key."""
        state_dict_ = {}
        state_dict_[self._lm_key] = self.lm_model.state_dict_for_save_checkpoint(destination, prefix, keep_vars)
        state_dict_[self._retriever_key] = self.retriever.state_dict_for_save_checkpoint(destination, prefix, keep_vars)
        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Load the state dicts of each of the models"""
        self.lm_model.load_state_dict(state_dict[self._lm_key], strict)
        self.retriever.load_state_dict(state_dict[self._retriever_key], strict)


class REALMRetriever(MegatronModule):
    """Retriever which uses a pretrained ICTBertModel and a HashedIndex"""
    def __init__(self, ict_model, ict_dataset, block_data, hashed_index, top_k=5):
        super(REALMRetriever, self).__init__()
        self.ict_model = ict_model
        self.ict_dataset = ict_dataset
        self.block_data = block_data
        self.hashed_index = hashed_index
        self.top_k = top_k
        self._ict_key = 'ict_model'

    def retrieve_evidence_blocks_text(self, query_text):
        """Get the top k evidence blocks for query_text in text form"""
        print("-" * 100)
        print("Query: ", query_text)

        padless_max_len = self.ict_dataset.max_seq_length - 2
        query_tokens = self.ict_dataset.encode_text(query_text)[:padless_max_len]

        query_tokens, query_pad_mask = self.ict_dataset.concat_and_pad_tokens(query_tokens)
        query_tokens = torch.cuda.LongTensor(np.array(query_tokens).reshape(1, -1))
        query_pad_mask = torch.cuda.LongTensor(np.array(query_pad_mask).reshape(1, -1))

        top5_block_tokens, _ = self.retrieve_evidence_blocks(query_tokens, query_pad_mask)
        for i, block in enumerate(top5_block_tokens[0]):
            block_text = self.ict_dataset.decode_tokens(block)
            print('\n    > Block {}: {}'.format(i, block_text))

    def retrieve_evidence_blocks(self, query_tokens, query_pad_mask):
        """Embed blocks to be used in a forward pass"""
        with torch.no_grad():
            true_model = self.ict_model.module.module
            query_embeds = detach(true_model.embed_query(query_tokens, query_pad_mask))

        _, block_indices = self.hashed_index.search_mips_index(query_embeds, top_k=self.top_k, reconstruct=False)
        all_top5_tokens, all_top5_pad_masks = [], []
        for indices in block_indices:
            # [k x meta_dim]
            top5_metas = np.array([self.block_data.meta_data[idx] for idx in indices])
            top5_block_data = [self.ict_dataset.get_block(*block_meta) for block_meta in top5_metas]
            top5_tokens, top5_pad_masks = zip(*top5_block_data)

            all_top5_tokens.append(np.array(top5_tokens))
            all_top5_pad_masks.append(np.array(top5_pad_masks))

        # [batch_size x k x seq_length]
        return np.array(all_top5_tokens), np.array(all_top5_pad_masks)

    def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):
        """For easy load when model is combined with other heads,
        add an extra key."""
        state_dict_ = {}
        state_dict_[self._ict_key] = self.ict_model.state_dict_for_save_checkpoint(destination, prefix, keep_vars)
        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Load the state dicts of each of the models"""
        self.ict_model.load_state_dict(state_dict[self._ict_key], strict)


class ICTBertModel(MegatronModule):
    """Bert-based module for Inverse Cloze task."""
    def __init__(self,
                 ict_head_size,
                 num_tokentypes=1,
                 parallel_output=True,
                 only_query_model=False,
                 only_block_model=False):
        super(ICTBertModel, self).__init__()
        bert_args = dict(
            num_tokentypes=num_tokentypes,
            add_binary_head=False,
            ict_head_size=ict_head_size,
            parallel_output=parallel_output
        )
        assert not (only_block_model and only_query_model)
        self.use_block_model = not only_query_model
        self.use_query_model = not only_block_model

        if self.use_query_model:
            # this model embeds (pseudo-)queries - Embed_input in the paper
            self.query_model = BertModel(**bert_args)
            self._query_key = 'question_model'

        if self.use_block_model:
            # this model embeds evidence blocks - Embed_doc in the paper
            self.block_model = BertModel(**bert_args)
            self._block_key = 'context_model'

    def forward(self, query_tokens, query_attention_mask,
                block_tokens, block_attention_mask,
                only_query=False, only_block=False):
        """Run a forward pass for each of the models and compute the similarity scores."""
        if only_query:
            return self.embed_query(query_tokens, query_attention_mask)

        if only_block:
            return self.embed_block(block_tokens, block_attention_mask)

        query_logits = self.embed_query(query_tokens, query_attention_mask)
        block_logits = self.embed_block(block_tokens, block_attention_mask)

        # [batch x embed] * [embed x batch]
        retrieval_scores = query_logits.matmul(torch.transpose(block_logits, 0, 1))
        return retrieval_scores

    def embed_query(self, query_tokens, query_attention_mask):
        """Embed a batch of tokens using the query model"""
        if self.use_query_model:
            query_types = torch.zeros(query_tokens.shape).type(torch.int64).cuda()
            query_ict_logits, _ = self.query_model.forward(query_tokens, query_attention_mask, query_types)
            return query_ict_logits
        else:
            raise ValueError("Cannot embed query without query model.")

    def embed_block(self, block_tokens, block_attention_mask):
        """Embed a batch of tokens using the block model"""
        if self.use_block_model:
            block_types = torch.zeros(block_tokens.shape).type(torch.int64).cuda()
            block_ict_logits, _ = self.block_model.forward(block_tokens, block_attention_mask, block_types)
            return block_ict_logits
        else:
            raise ValueError("Cannot embed block without block model.")

    def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):
        """Save dict with state dicts of each of the models."""
        state_dict_ = {}
        if self.use_query_model:
            state_dict_[self._query_key] \
                = self.query_model.state_dict_for_save_checkpoint(
                    destination, prefix, keep_vars)

        if self.use_block_model:
            state_dict_[self._block_key] \
                = self.block_model.state_dict_for_save_checkpoint(
                    destination, prefix, keep_vars)

        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Load the state dicts of each of the models"""
        if self.use_query_model:
            print("Loading ICT query model", flush=True)
            self.query_model.load_state_dict(state_dict[self._query_key], strict=strict)

        if self.use_block_model:
            print("Loading ICT block model", flush=True)
            self.block_model.load_state_dict(state_dict[self._block_key], strict=strict)
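The shape bookkeeping in REALMBertModel.forward is the subtle part: block embeddings of shape [batch_size, k, embed_size] are scored against a query embedding of shape [batch_size, embed_size, 1] with a batched matmul, and the scores are softmaxed over the k retrieved blocks. A toy, CPU-only sketch of just that scoring step follows; it uses random tensors and k=5, not the Megatron models themselves.

import torch
import torch.nn.functional as F

batch_size, k, embed_size = 2, 5, 128
block_embeds = torch.randn(batch_size, k, embed_size)   # stand-in for fresh_block_logits
query_embeds = torch.randn(batch_size, embed_size, 1)   # stand-in for embed_query(...).unsqueeze(2)

# Batched matmul gives one score per (query, retrieved block) pair.
scores = torch.matmul(block_embeds, query_embeds).squeeze(-1)  # [batch_size, k]
block_probs = F.softmax(scores, dim=1)                         # P(block | query), rows sum to 1

print(block_probs.shape)         # torch.Size([2, 5])
print(block_probs.sum(dim=1))    # tensor([1., 1.]) up to floating-point error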
pretrain_realm.py (+1 / -1)

@@ -44,7 +44,7 @@ def model_provider():
     hashed_index = FaissMIPSIndex(index_type='flat_l2', embed_size=128)
     hashed_index.add_block_embed_data(all_block_data)

-    retriever = REALMRetriever(ict_model, ict_dataset, all_block_data, hashed_index)
+    retriever = REALMRetriever(ict_model, ict_dataset, all_block_data, hashed_index, args.block_top_k)

     # TODO: REALMBertModel should accept a path to a pretrained bert-base
     model = REALMBertModel(retriever)
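The net effect of this one-line change is that model_provider now threads --block-top-k through to the retriever instead of always using the default of 5. A minimal illustration with a stub class (hypothetical; only the top_k handling mirrors REALMRetriever.__init__ from realm_model.py):

class _StubRetriever:
    # Mirrors REALMRetriever.__init__'s signature, keeping only top_k.
    def __init__(self, ict_model, ict_dataset, block_data, hashed_index, top_k=5):
        self.top_k = top_k

block_top_k = 8  # would come from args.block_top_k, i.e. the new --block-top-k flag
before = _StubRetriever(None, None, None, None)               # old call: always 5
after = _StubRetriever(None, None, None, None, block_top_k)   # new call: honors the flag
print(before.top_k, after.top_k)  # 5 8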