OpenDAS / Megatron-LM

Commit 256eb6ed, authored Apr 20, 2020 by Neel Kant

Enhance hashed_index and more improvements elsewhere

Parent: 9f9b2cf8
Showing 3 changed files with 62 additions and 21 deletions (+62 -21)
hashed_index.py               +13 -6
megatron/model/bert_model.py  +21 -11
pretrain_realm.py             +28 -4
hashed_index.py

@@ -17,6 +17,10 @@ from megatron.training import get_model
 from pretrain_bert_ict import get_batch, model_provider
 
 
+def detach(tensor):
+    return tensor.detach().cpu().numpy()
+
+
 def embed_docs():
     initialize_megatron(extra_args_provider=None,
                         args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})

@@ -45,12 +49,13 @@ def embed_docs():
         block_hash_pos = torch.matmul(block_logits, hash_matrix)
         block_hash_full = torch.cat((block_hash_pos, -block_hash_pos), axis=1)
-        block_hashes = torch.argmax(block_hash_full, axis=1).detach().cpu().numpy()
+        block_hashes = detach(torch.argmax(block_hash_full, axis=1))
         for hash, indices_array in zip(block_hashes, block_indices):
-            hash_data[int(hash)].append(indices_array.detach().cpu().numpy())
+            hash_data[int(hash)].append(detach(indices_array))
 
-        block_logits = block_logits.detach().cpu().numpy()
-        block_indices = block_indices.detach().cpu().numpy()[:, 3]
+        block_logits = detach(block_logits)
+        # originally this has [start_idx, end_idx, doc_idx, block_idx]
+        block_indices = detach(block_indices)[:, 3]
         for logits, idx in zip(block_logits, block_indices):
             block_data[int(idx)] = logits

@@ -68,6 +73,10 @@ def embed_docs():
     torch.distributed.barrier()
+    all_data.clear()
+    del all_data
+    del model
+
     if mpu.get_data_parallel_rank() == 0:
         all_block_data = defaultdict(dict)
         dir_name = 'block_hash_data'

@@ -80,9 +89,7 @@ def embed_docs():
         with open('block_hash_data.pkl', 'wb') as final_file:
             pickle.dump(all_block_data, final_file)
         os.rmdir(dir_name)
-    return
 
 
 def load_checkpoint():
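The bucketing above works like a signed random projection: each block embedding is multiplied by hash_matrix ([embed_size x num_buckets / 2]), the result is concatenated with its negation, and the argmax picks the projection direction with the largest magnitude together with its sign, so every block lands in one of num_buckets buckets. A minimal, self-contained sketch of that mechanic with illustrative sizes (not part of the commit):

import torch

embed_size, num_buckets = 128, 32
# [embed_size x num_buckets / 2] random projection matrix
hash_matrix = torch.randn(embed_size, num_buckets // 2)

block_logits = torch.randn(4, embed_size)                               # a small batch of block embeddings
block_hash_pos = torch.matmul(block_logits, hash_matrix)                # [batch x num_buckets / 2]
block_hash_full = torch.cat((block_hash_pos, -block_hash_pos), dim=1)   # [batch x num_buckets]
block_hashes = torch.argmax(block_hash_full, dim=1)                     # bucket id per block, in [0, num_buckets)
print(block_hashes)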
megatron/model/bert_model.py

@@ -15,6 +15,8 @@
 """BERT model."""
 
+import pickle
+
 import numpy as np
 import torch

@@ -215,7 +217,7 @@ class BertModel(MegatronModule):
 class REALMBertModel(MegatronModule):
-    def __init__(self, ict_model_path, block_hash_data_path):
+    def __init__(self, ict_model, block_hash_data_path):
         super(REALMBertModel, self).__init__()
         bert_args = dict(
             num_tokentypes=2,

@@ -226,17 +228,21 @@ class REALMBertModel(MegatronModule):
         self._lm_key = 'realm_lm'
         self.ict_model = ict_model
-        self.ict_dataset = ict_dataset
-
-        self.block_hash_data = block_hash_data
+        with open(block_hash_data_path, 'rb') as data_file:
+            data = pickle.load(data_file)
+
+        # {block_idx: block_embed} - the main index
+        self.block_data = data['block_data']
+        # {hash_num: [start, end, doc, block]} - the hash table
+        self.hash_data = data['hash_data']
+        # [embed_size x num_buckets / 2] - the projection matrix used for hashing
+        self.hash_matrix = self.hash_data['matrix']
 
     def forward(self, tokens, attention_mask, token_types):
         # [batch_size x embed_size]
         query_logits = self.ict_model.embed_query(tokens, attention_mask, token_types)
 
-        hash_matrix_pos = self.hash_data['matrix']
-        # [batch_size, num_buckets / 2]
-        query_hash_pos = torch.matmul(query_logits, hash_matrix_pos)
+        # [batch_size x num_buckets / 2]
+        query_hash_pos = torch.matmul(query_logits, self.hash_matrix)
         query_hash_full = torch.cat((query_hash_pos, -query_hash_pos), axis=1)
 
         # [batch_size]

@@ -247,15 +253,19 @@ class REALMBertModel(MegatronModule):
             # TODO: this should be made into a single np.array in preprocessing
             bucket_blocks = self.hash_data[hash]
             block_indices = bucket_blocks[:, 3]
 
-            # [bucket_pop, embed_size]
+            # [bucket_pop x embed_size]
             block_embeds = [self.block_data[idx] for idx in block_indices]
 
-            # will become [batch_size, bucket_pop, embed_size]
+            # will become [batch_size x bucket_pop x embed_size]
             # will require padding to do tensor multiplication
             batch_block_embeds.append(block_embeds)
 
+        # [batch_size x max bucket_pop x embed_size]
         batch_block_embeds = np.array(batch_block_embeds)
 
-        retrieval_scores = query_logits.matmul(torch.transpose(batch_block_embeds, 0, 1))
+        # [batch_size x 1 x max bucket_pop]
+        retrieval_scores = query_logits.matmul(torch.transpose(batch_block_embeds, 1, 2))
+
+        # [batch_size x max bucket_pop]
+        retrieval_scores = retrieval_scores.squeeze()
+
+        top5_vals, top5_indices = torch.topk(retrieval_scores, k=5)
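REALMBertModel.forward hashes each query with the same projection matrix, pulls the candidate block embeddings out of the query's bucket, scores them against the query embedding with a matmul, and keeps the top 5. As the comments note, buckets have different populations, so the candidate lists need padding before they can be stacked into one tensor. Below is a minimal sketch of that padded scoring step, assuming the candidates are already padded to the largest bucket population; the pad_mask and the illustrative sizes are assumptions, not part of the commit:

import torch

batch_size, embed_size, max_bucket_pop = 2, 128, 7

query_logits = torch.randn(batch_size, embed_size)                        # [batch_size x embed_size]
batch_block_embeds = torch.randn(batch_size, max_bucket_pop, embed_size)  # padded candidate embeddings
pad_mask = torch.zeros(batch_size, max_bucket_pop, dtype=torch.bool)      # True marks a padded slot
pad_mask[0, 5:] = True                                                     # e.g. the first bucket only held 5 blocks

# [batch_size x 1 x max_bucket_pop]
retrieval_scores = query_logits.unsqueeze(1).matmul(torch.transpose(batch_block_embeds, 1, 2))
# [batch_size x max_bucket_pop], with padded slots pushed to -inf so topk never selects them
retrieval_scores = retrieval_scores.squeeze(1).masked_fill(pad_mask, float('-inf'))

top5_vals, top5_indices = torch.topk(retrieval_scores, k=5, dim=1)
print(top5_indices)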
pretrain_realm.py

@@ -17,24 +17,49 @@
 import torch
 import torch.nn.functional as F
+from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
 
 from megatron import get_args
 from megatron import get_timers
 from megatron import mpu
 from megatron import print_rank_0
+from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
 from megatron.data.bert_dataset import build_train_valid_test_datasets
 from megatron.model import ICTBertModel, REALMBertModel
-from megatron.training import pretrain
+from megatron.training import get_model, pretrain
 from megatron.utils import reduce_losses
+from pretrain_bert_ict import model_provider as ict_model_provider
 
 num_batches = 0
 
 
 def model_provider():
     """Build the model."""
 
     args = get_args()
     print_rank_0('building BERT models ...')
 
-    realm_model = REALMBertModel(args.ict_model_path,
+    ict_model = get_model(ict_model_provider)
+    if isinstance(ict_model, torchDDP):
+        model = ict_model.module
+
+    tracker_filename = get_checkpoint_tracker_filename(args.load)
+    with open(tracker_filename, 'r') as f:
+        iteration = int(f.read().strip())
+        assert iteration > 0
+    checkpoint_name = get_checkpoint_name(args.load, iteration, False)
+    if mpu.get_data_parallel_rank() == 0:
+        print('global rank {} is loading checkpoint {}'.format(
+            torch.distributed.get_rank(), checkpoint_name))
+
+    state_dict = torch.load(checkpoint_name, map_location='cpu')
+    ict_model.load_state_dict(state_dict['model'])
+    torch.distributed.barrier()
+
+    if mpu.get_data_parallel_rank() == 0:
+        print(' successfully loaded {}'.format(checkpoint_name))
+
+    realm_model = REALMBertModel(ict_model,
                                  args.block_hash_data_path)
 
     return ict_model

@@ -107,8 +132,7 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
         masked_lm_prob=args.mask_prob,
         short_seq_prob=args.short_seq_prob,
         seed=args.seed,
-        skip_warmup=(not args.mmap_warmup),
-        ict_dataset=True)
+        skip_warmup=(not args.mmap_warmup))
 
     print_rank_0("> finished creating BERT ICT datasets ...")
 
     return train_ds, valid_ds, test_ds
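model_provider now restores the ICT checkpoint itself and passes the live ict_model, together with args.block_hash_data_path, into REALMBertModel. For the constructor in megatron/model/bert_model.py to work, the pickle at that path has to hold both the block index and the hash table with its projection matrix. Below is a sketch of one layout that would satisfy those lookups; the keys come from the constructor, while the sizes and example entries are purely illustrative (the exact on-disk format is an assumption, not taken from the commit):

import pickle

import numpy as np

embed_size, num_buckets = 128, 32

data = {
    # {block_idx: block_embed} - the main index consumed as self.block_data
    'block_data': {0: np.zeros(embed_size, dtype=np.float32)},
    # {hash_num: [[start, end, doc, block], ...]} - the hash table consumed as self.hash_data,
    # with the projection matrix stored alongside under 'matrix'
    'hash_data': {
        0: np.array([[0, 10, 0, 0]]),
        'matrix': np.zeros((embed_size, num_buckets // 2), dtype=np.float32),
    },
}

with open('block_hash_data.pkl', 'wb') as f:
    pickle.dump(data, f)

with open('block_hash_data.pkl', 'rb') as f:
    loaded = pickle.load(f)
assert loaded['hash_data']['matrix'].shape == (embed_size, num_buckets // 2)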