OpenDAS / Megatron-LM / Commits

Commit 56e81e99, authored May 03, 2020 by Neel Kant

Complete refactor of RandProjectLSHIndex

Parent: 642802e0

Showing 5 changed files with 33 additions and 33 deletions.
hashed_index.py                  +20 -23
megatron/data/bert_dataset.py     +1  -2
megatron/data/dataset_utils.py    +4  -3
megatron/data/realm_dataset.py    +1  -1
megatron/data/realm_index.py      +7  -4
hashed_index.py
@@ -6,7 +6,7 @@ from megatron import mpu
 from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
 from megatron.data.bert_dataset import get_indexed_dataset_
 from megatron.data.realm_dataset import InverseClozeDataset
-from megatron.data.realm_index import BlockData, RandProjectionLSHIndex
+from megatron.data.realm_index import detach, BlockData, RandProjectionLSHIndex
 from megatron.data.samplers import DistributedBatchSampler
 from megatron.initialize import initialize_megatron
 from megatron.model import REALMRetriever
@@ -14,10 +14,6 @@ from megatron.training import get_model
 from pretrain_bert_ict import get_batch, model_provider


-def detach(tensor):
-    return tensor.detach().cpu().numpy()
-
-
 def test_retriever():
     initialize_megatron(extra_args_provider=None,
                         args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
@@ -71,26 +67,27 @@ def main():
     i = 1
     total = 0
     while True:
-        try:
-            query_tokens, query_pad_mask, \
-            block_tokens, block_pad_mask, block_index_data = get_batch(data_iter)
-        except:
-            break
-
-        block_index_data = detach(block_index_data)
-        block_indices = block_index_data[:, 3]
-        block_meta = block_index_data[:, :3]
-        block_logits = model(None, None, block_tokens, block_pad_mask, only_block=True)
-
-        all_block_data.add_block_data(block_indices, block_logits, block_meta)
-        total += block_indices.size
-        i += 1
-        if i % 20 == 0:
-            print('Batch {:10d} | Total {:10d}'.format(i, total), flush=True)
-        if args.debug:
-            break
+        with torch.no_grad():
+            try:
+                query_tokens, query_pad_mask, \
+                block_tokens, block_pad_mask, block_index_data = get_batch(data_iter)
+            except:
+                break
+
+            block_index_data = detach(block_index_data)
+            block_indices = block_index_data[:, 3]
+            block_meta = block_index_data[:, :3]
+            block_logits = detach(model(None, None, block_tokens, block_pad_mask, only_block=True))
+
+            all_block_data.add_block_data(block_indices, block_logits, block_meta)
+            total += block_indices.size
+            i += 1
+            if i % 20 == 0:
+                print('Batch {:10d} | Total {:10d}'.format(i, total), flush=True)
+            if args.debug:
+                break

     all_block_data.save_shard(args.rank)
     torch.distributed.barrier()
     del model
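The substantive change in this file: the batch loop now runs under torch.no_grad(), and the model output is passed through detach (now imported from megatron.data.realm_index) before being stored, so neither autograd state nor GPU memory is held while the index is built. A minimal sketch of the pattern, with a hypothetical TinyEncoder standing in for the real Megatron model:

# Sketch of the no_grad + detach pattern; TinyEncoder and the random
# batches are stand-ins, not Megatron code.
import numpy as np
import torch


def detach(tensor):
    """Drop the autograd link and move the tensor to host memory as numpy."""
    return tensor.detach().cpu().numpy()


class TinyEncoder(torch.nn.Module):
    def __init__(self, dim=8):
        super().__init__()
        self.proj = torch.nn.Linear(dim, dim)

    def forward(self, x):
        return self.proj(x)


model = TinyEncoder()
embeddings = []
with torch.no_grad():                 # no activations kept for backward
    for _ in range(3):
        batch = torch.randn(4, 8)
        embeddings.append(detach(model(batch)))   # plain numpy from here on

print(np.concatenate(embeddings).shape)           # (12, 8)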
megatron/data/bert_dataset.py
@@ -24,11 +24,9 @@ from torch.utils.data import Dataset
 from megatron import get_tokenizer
 from megatron import mpu
-from megatron.data.dataset_utils import build_training_sample
 from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
 from megatron import print_rank_0

-DATASET_TYPES = ['standard_bert', 'ict', 'realm']


 class BertDataset(Dataset):
@@ -64,6 +62,7 @@ class BertDataset(Dataset):
         self.sep_id = tokenizer.sep
         self.mask_id = tokenizer.mask
         self.pad_id = tokenizer.pad
+        from megatron.data.dataset_utils import build_training_sample
         self.build_sample_fn = build_training_sample

     def __len__(self):
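The import of build_training_sample moves from module scope into BertDataset.__init__. Deferring an import into a function body is the usual way to break a circular dependency: bert_dataset.py and dataset_utils.py each need a name from the other, so one side has to wait until both modules are fully loaded (the DATASET_TYPES move and the function-local imports in dataset_utils.py below are the other half of the same untangling). A toy two-module illustration, with hypothetical files a.py and b.py:

# a.py -- imports from b at module scope, which is fine
from b import helper

def feature():
    return helper() + 1

# b.py -- a top-level `from a import feature` here would fail:
# importing a runs a.py, which imports b, which would try to read
# `feature` from the half-initialized module a. Deferring the import
# into the function body delays it until both modules exist.
def helper():
    return 41

def uses_a():
    from a import feature   # deferred import breaks the cycle
    return feature()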
megatron/data/dataset_utils.py
@@ -23,9 +23,9 @@ import itertools
 import numpy as np

 from megatron import print_rank_0
-from megatron.data.bert_dataset import DATASET_TYPES, get_indexed_dataset_, get_train_valid_test_split_, BertDataset
-from megatron.data.realm_dataset import InverseClozeDataset
+from megatron.data.bert_dataset import get_indexed_dataset_, get_train_valid_test_split_, BertDataset
+DATASET_TYPES = ['standard_bert', 'ict', 'realm']

 def compile_helper():
     """Compile helper function ar runtime. Make sure this
@@ -454,6 +454,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
     print_split_stats('test', 2)

     def build_dataset(index, name):
+        from megatron.data.realm_dataset import InverseClozeDataset
         from megatron.data.realm_dataset import RealmDataset
         dataset = None
         if splits[index + 1] > splits[index]:
@@ -502,4 +503,4 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
     valid_dataset = build_dataset(1, 'valid')
     test_dataset = build_dataset(2, 'test')

-    return (train_dataset, valid_dataset, test_dataset)
\ No newline at end of file
+    return (train_dataset, valid_dataset, test_dataset)
megatron/data/realm_dataset.py
@@ -12,7 +12,7 @@ from megatron import get_tokenizer, print_rank_0, mpu
 from megatron.data.bert_dataset import BertDataset
 from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_convert_to_numpy

-qa_nlp = spacy.load('en_core_web_lg')
+# qa_nlp = spacy.load('en_core_web_lg')

 class RealmDataset(BertDataset):
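Note: commenting out the module-level spacy.load('en_core_web_lg') presumably drops the hard import-time dependency on spacy and its large English model; nothing else in this hunk references qa_nlp.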
megatron/data/realm_index.py
@@ -3,13 +3,16 @@ import os
 import pickle
 import shutil

-from hashed_index import detach
 import numpy as np
 import torch

 from megatron import get_args


+def detach(tensor):
+    return tensor.detach().cpu().numpy()
+
+
 class BlockData(object):
     def __init__(self):
         self.embed_data = dict()
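detach previously lived in the hashed_index.py script and was imported from there by this library module, an inverted dependency; now the helper is defined here and the script imports it instead (first hunk of hashed_index.py above). The .detach().cpu() dance is required because numpy() refuses to operate on a tensor still attached to the autograd graph:

import torch

t = torch.ones(3, requires_grad=True)
# t.numpy() would raise "Can't call numpy() on Tensor that requires grad"
arr = t.detach().cpu().numpy()   # detach from the graph, then move to host
print(arr)                       # [1. 1. 1.]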
@@ -43,7 +46,7 @@ class BlockData(object):
         if not allow_overwrite and idx in self.embed_data:
             raise ValueError("Unexpectedly tried to overwrite block data")

-        self.embed_data[idx] = embed
+        self.embed_data[idx] = np.float16(embed)
         self.meta_data[idx] = meta

     def save_shard(self, rank):
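add_block_data now casts each embedding to np.float16 before storing it, halving the in-memory and pickled-shard footprint at the cost of precision (float16 keeps roughly three significant decimal digits). A quick sketch with a made-up embedding size:

import numpy as np

embed = np.random.randn(128).astype(np.float32)   # as it comes off the model
print(embed.nbytes)                   # 512 bytes
print(np.float16(embed).nbytes)       # 256 bytes -- half the storage

# Worst-case rounding error introduced by the cast:
print(float(np.abs(embed - np.float16(embed)).max()))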
@@ -213,7 +216,7 @@ class RandProjectionLSHIndex(object):

     def hash_embeds(self, embeds, write_block_data=None):
         """Hash a tensor of embeddings using a random projection matrix"""
-        embed_scores_pos = torch.matmul(embeds, torch.cuda.FloatTensor(self.hash_matrix))
+        embed_scores_pos = torch.matmul(embeds, torch.cuda.HalfTensor(self.hash_matrix))
         embed_scores = torch.cat((embed_scores_pos, -embed_scores_pos), axis=1)
         embed_hashes = detach(torch.argmax(embed_scores, axis=1))
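With embeddings now stored as fp16, the projection matrix is materialized with torch.cuda.HalfTensor so the matmul dtypes agree. The hashing scheme itself is unchanged: project each embedding onto k random directions, append the negated scores, and take the argmax, so the hash is the index of the best-aligned direction among the k projections and their negations. A CPU sketch with made-up sizes (the real code uses torch.cuda.HalfTensor on GPU):

import torch

dim, k = 8, 16                        # embedding dim, projection count
hash_matrix = torch.randn(dim, k)     # random projection directions

embeds = torch.randn(4, dim)
scores_pos = embeds @ hash_matrix                        # (4, k)
scores = torch.cat((scores_pos, -scores_pos), dim=1)     # (4, 2k)
hashes = torch.argmax(scores, dim=1)                     # bucket in [0, 2k)
print(hashes)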
@@ -226,7 +229,7 @@ class RandProjectionLSHIndex(object):

     def hash_whitened_block_embeds(self, block_data):
         """Transform all block embeds to have zero mean and unit covariance
         when treated as samples from a distribution"""
-        block_idx, all_embeds = zip(block_data.embed_data.items())
+        block_idx, all_embeds = zip(*block_data.embed_data.items())
         arr_embeds = np.transpose(np.array(all_embeds))
         mean = np.mean(arr_embeds, axis=1).reshape(-1, 1)
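The zip(*...) change is a real bug fix, not a refactor: zip(d.items()) zips a single iterable and yields 1-tuples of (key, value) pairs, while zip(*d.items()) unpacks the pairs and transposes them into a tuple of keys and a tuple of values, which is what the unpacking into block_idx, all_embeds expects:

d = {1: 'a', 2: 'b'}

print(list(zip(d.items())))     # [((1, 'a'),), ((2, 'b'),)] -- 1-tuples
print(list(zip(*d.items())))    # [(1, 2), ('a', 'b')] -- keys and values

block_idx, all_embeds = zip(*d.items())   # what the fixed line relies on
print(block_idx, all_embeds)              # (1, 2) ('a', 'b')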