OpenDAS / Megatron-LM · Commits

Commit 9d225b44, authored Apr 29, 2020 by Neel Kant

    Whitening code

Parent: 5e56e563

Showing 5 changed files with 90 additions and 49 deletions:
    hashed_index.py                   +86  -17
    megatron/arguments.py              +2   -0
    megatron/data/realm_dataset.py     +0   -3
    pretrain_bert_ict.py               +2   -1
    test_retriever.sh                  +0  -28
hashed_index.py  (view file @ 9d225b44)
@@ -25,18 +25,23 @@ def detach(tensor):
 class HashedIndex(object):
     """Class for holding hashed data"""
-    def __init__(self, embed_size, num_buckets, seed=0):
+    def __init__(self, embed_size, num_buckets, whiten=False, seed=0):
         np.random.seed(seed)
         self.block_data = defaultdict(list)
         self.hash_data = defaultdict(list)

         hash_matrix = np.random.rand(embed_size, int(num_buckets / 2))
         self.hash_matrix = hash_matrix / np.linalg.norm(hash_matrix, axis=0).reshape(1, -1)
+        self.embed_mean = None
+        self.embed_whitener = None
+        self.whiten = whiten

     def state(self):
         state = {
             'block_data': self.block_data,
             'hash_data': self.hash_data,
-            'hash_matrix': self.hash_matrix
+            'hash_matrix': self.hash_matrix,
+            'embed_mean': self.embed_mean,
+            'embed_whitener': self.embed_whitener,
         }
         return state
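For context, hash_matrix holds num_buckets / 2 unit-norm random directions; an embedding is assigned to the bucket whose direction (or its negation) gives the largest dot product, which is the signed random-projection hashing that the whiten_block_embeds() hunk further down performs in torch. A minimal NumPy sketch of that bucketing, illustrative only (the helper and variable names below are not from the commit):

    # Illustrative sketch, not part of the commit: bucket assignment with the
    # normalized random-projection matrix, mirroring the torch logic in
    # whiten_block_embeds() using NumPy only.
    import numpy as np

    def bucket_ids(embeds, hash_matrix):
        # embeds: (n, embed_size); hash_matrix: (embed_size, num_buckets // 2)
        scores_pos = embeds @ hash_matrix                            # (n, num_buckets // 2)
        scores = np.concatenate((scores_pos, -scores_pos), axis=1)   # (n, num_buckets)
        return np.argmax(scores, axis=1)                             # one bucket id per embedding

    embed_size, num_buckets = 128, 4096
    rng = np.random.RandomState(0)
    hash_matrix = rng.rand(embed_size, num_buckets // 2)
    hash_matrix /= np.linalg.norm(hash_matrix, axis=0).reshape(1, -1)
    print(bucket_ids(rng.randn(4, embed_size), hash_matrix))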
@@ -79,8 +84,6 @@ class HashedIndex(object):
         dir_name = 'block_hash_data'
         fnames = os.listdir(dir_name)
         for fname in fnames:
-            if str(ignore_shard) in fname:
-                continue
             with open('{}/{}'.format(dir_name, fname), 'rb') as f:
                 data = pickle.load(f)
                 assert np.array_equal(data['hash_matrix'], self.hash_matrix)
@@ -88,10 +91,14 @@ class HashedIndex(object):
                 old_size = len(self.block_data)
                 shard_size = len(data['block_data'])
                 self.block_data.update(data['block_data'])
-                assert len(self.block_data) == old_size + shard_size, (old_size, shard_size, len(self.block_data))
+                assert (len(self.block_data) == old_size + shard_size) or (str(ignore_shard) in fname)

                 for bucket, items in data['hash_data'].items():
-                    self.hash_data[bucket].extend(items)
+                    if not self.whiten:
+                        self.hash_data[bucket].extend(items)
+
+        if self.whiten:
+            self.whiten_block_embeds()

         args = get_args()
         with open(args.hash_data_path, 'wb') as final_file:
@@ -100,8 +107,43 @@ class HashedIndex(object):
     def clear(self):
         """Clear the data structures to save memory"""
-        self.block_data = defaultdict(list)
+        self.block_data = dict()
         self.hash_data = defaultdict(list)

+    def whiten_block_embeds(self):
+        """Transform all block embeds to have zero mean and unit covariance
+        when treated as samples from a distribution"""
+        block_idx, all_embeds = zip(*self.block_data.items())
+        arr_embeds = np.transpose(np.array(all_embeds))
+        mean = np.mean(arr_embeds, axis=1).reshape(-1, 1)
+        centered = arr_embeds - mean
+
+        inv_cov = np.linalg.inv(np.cov(arr_embeds))
+        whitener = np.transpose(np.linalg.cholesky(inv_cov))
+        whitened = np.transpose(whitener.dot(centered))
+
+        self.embed_mean = mean.reshape(-1)
+        self.embed_whitener = whitener
+        self.block_data = dict(zip(block_idx, list(whitened)))
+        self.hash_data = defaultdict(list)
+
+        batch_size = 16384
+        i = 0
+        with torch.no_grad():
+            hashing_tensor = torch.cuda.HalfTensor(self.hash_matrix)
+            while True:
+                batch_slice = slice(i * batch_size, (i + 1) * batch_size)
+                batch_embed = torch.cuda.HalfTensor(whitened[batch_slice])
+                batch_block_idx = block_idx[batch_slice]
+                if batch_embed.size == 0:
+                    break
+
+                hash_scores_pos = torch.matmul(batch_embed, hashing_tensor)
+                embed_scores = torch.cat((hash_scores_pos, -hash_scores_pos), axis=1)
+                embed_hashes = detach(torch.argmax(embed_scores, axis=1))
+                for hash, embed in zip(list(embed_hashes), list(detach(batch_embed))):
+                    # [int] instead of [array<int>] since this is just for analysis rn
+                    self.hash_data[hash].append(batch_block_idx)
+                i += 1

     @classmethod
     def load_from_file(cls, fname):
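Why the transposed Cholesky factor whitens, in the names of the hunk above: with Sigma = np.cov(arr_embeds) and the Cholesky factorization Sigma^{-1} = L L^T (L lower-triangular, as returned by np.linalg.cholesky), the whitener is W = L^T, and the covariance of the transformed samples is

    W Sigma W^T = L^T (L L^T)^{-1} L = L^T L^{-T} L^{-1} L = I

so the centered, whitened block embeddings have identity sample covariance, matching the docstring; embed_mean and embed_whitener store the parameters of that transform.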
@@ -114,8 +156,26 @@ class HashedIndex(object):
         new_index.block_data = state_dict['block_data']
         new_index.hash_data = state_dict['hash_data']
         new_index.hash_matrix = hash_matrix

         return new_index

+    @classmethod
+    def whiten_and_rehash(cls, fname):
+        """Load up a HashedIndex, whiten it and rehash"""
+        index = cls.load_from_file(fname)
+        all_vectors = []
+        for block_embed in index.block_data.values():
+            all_vectors.append(block_embed)
+
+        arr_vectors = np.transpose(np.array(all_vectors))
+        mean = np.mean(arr_vectors, axis=1)
+        cov = np.cov(arr_vectors)
+        inv_cov = np.linalg.inv(cov)
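whiten_and_rehash() starts from the same statistics (mean, covariance, inverse covariance) over the stored block embeddings. A small self-contained NumPy check of that construction on synthetic data, illustrative only and not code from the commit:

    # Illustrative sketch, not part of the commit: verify that the Cholesky-based
    # whitener used above produces ~identity sample covariance.
    import numpy as np

    rng = np.random.RandomState(0)
    embeds = rng.randn(10000, 128) @ rng.randn(128, 128)    # correlated fake embeddings, rows = samples

    arr = embeds.T                                          # (embed_size, n), matching the code above
    mean = np.mean(arr, axis=1).reshape(-1, 1)
    centered = arr - mean

    inv_cov = np.linalg.inv(np.cov(arr))                    # Sigma^{-1} = L L^T
    whitener = np.transpose(np.linalg.cholesky(inv_cov))    # W = L^T
    whitened = np.transpose(whitener.dot(centered))         # rows = whitened samples

    # sample covariance of the whitened embeddings is ~identity
    print(np.abs(np.cov(whitened.T) - np.eye(128)).max())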
 def test_retriever():
     initialize_megatron(extra_args_provider=None,

@@ -163,10 +223,12 @@ def main():
     model = load_ict_checkpoint(only_block_model=True, no_grad=True)
     model.eval()
     dataset = get_ict_dataset()
-    data_iter = iter(get_dataloader(dataset))
+    data_iter = iter(get_one_epoch_dataloader(dataset))

-    hashed_index = HashedIndex(embed_size=128, num_buckets=4096)
-    i = 0
+    hashed_index = HashedIndex(embed_size=128, num_buckets=4096, whiten=True)
+    i = 1
+    total = 0
+    whiten = False

     while True:
         try:
             query_tokens, query_pad_mask, \
@@ -176,18 +238,25 @@ def main():
             block_indices = detach(block_indices)
             block_logits = model(None, None, block_tokens, block_pad_mask, only_block=True)

-            hashed_index.hash_embeds(block_logits, block_indices)
+            # If whiten, then hashing needs to be done after whitening the block embeds
+            # which is done in consolidate_shards_and_save()
+            if not whiten:
+                hashed_index.hash_embeds(block_logits, block_indices)
             hashed_index.assign_block_embeds(block_indices[:, 3], detach(block_logits))
-            if i % 100 == 0:
-                print(i, flush=True)
+            total += block_indices.size
             i += 1
+            if i % 20 == 0:
+                print('Batch {:10d} | Total {:10d}'.format(i, total), flush=True)
+            if args.debug:
+                break

     hashed_index.save_shard(args.rank)
     torch.distributed.barrier()
     del model

-    if mpu.get_data_parallel_rank() == 0:
+    if args.rank == 0:
         hashed_index.consolidate_shards_and_save()
     else:
         hashed_index.clear()
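With whiten=True, the per-batch hashing in main() is skipped (see the comment in the hunk above): each rank only accumulates raw block embeddings and saves a shard, and bucketing happens once after consolidation, when whiten_block_embeds() runs. A schematic single-process sketch of that merge-then-whiten-then-hash ordering; the function, names, and structure below are illustrative, not the commit's API:

    # Schematic sketch, not part of the commit: shards are merged first, the merged
    # embeddings are whitened, and only then are buckets assigned.
    import numpy as np
    from collections import defaultdict

    def consolidate_and_whiten(shards, hash_matrix):
        block_data = {}
        for shard in shards:                                  # merge per-rank {block_id: embedding} shards
            block_data.update(shard)

        idx, embeds = zip(*block_data.items())
        arr = np.array(embeds).T                              # (embed_size, n)
        centered = arr - arr.mean(axis=1, keepdims=True)
        whitener = np.linalg.cholesky(np.linalg.inv(np.cov(arr))).T
        whitened = (whitener @ centered).T                    # whiten first ...

        scores_pos = whitened @ hash_matrix
        scores = np.concatenate((scores_pos, -scores_pos), axis=1)
        hash_data = defaultdict(list)
        for block_id, bucket in zip(idx, np.argmax(scores, axis=1)):
            hash_data[bucket].append(block_id)                # ... then bucket the whitened embeddings
        return dict(zip(idx, whitened)), hash_data

    rng = np.random.RandomState(0)
    hash_matrix = rng.rand(128, 2048)
    hash_matrix /= np.linalg.norm(hash_matrix, axis=0)
    shards = [{i + r * 1000: rng.randn(128) for i in range(500)} for r in range(2)]
    block_data, hash_data = consolidate_and_whiten(shards, hash_matrix)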
@@ -247,7 +316,7 @@ def get_ict_dataset():
     return dataset


-def get_dataloader(dataset):
+def get_one_epoch_dataloader(dataset):
     args = get_args()

     world_size = mpu.get_data_parallel_world_size()
megatron/arguments.py  (view file @ 9d225b44)
@@ -184,6 +184,8 @@ def _add_training_args(parser):

 def _add_initialization_args(parser):
     group = parser.add_argument_group(title='initialization')

+    group.add_argument('--debug', action='store_true',
+                       help='Run things in debug mode')
     group.add_argument('--seed', type=int, default=1234,
                        help='Random seed used for python, numpy, '
                        'pytorch, and cuda.')
megatron/data/realm_dataset.py  (view file @ 9d225b44)
@@ -46,9 +46,6 @@ def build_simple_training_sample(sample, target_seq_length, max_seq_length,
         = pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
                                    masked_labels, pad_id, max_seq_length)

-    # REALM true sequence length is twice as long but none of that is to be predicted with LM
-    # loss_mask_np = np.concatenate((loss_mask_np, np.ones(loss_mask_np.shape)), -1).astype(np.int64)
-
     train_sample = {
         'tokens': tokens_np,
         'labels': labels_np,
pretrain_bert_ict.py  (view file @ 9d225b44)
@@ -29,6 +29,7 @@ from megatron.utils import reduce_losses
 num_batches = 0

+
 def model_provider(only_query_model=False, only_block_model=False):
     """Build the model."""
     args = get_args()

@@ -103,7 +104,7 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
     """Build train, valid and test datasets."""
     args = get_args()
     print_rank_0('> building train, validation, and test datasets '
-                 'for BERT ...')
+                 'for BERT ICT ...')
     train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
         data_prefix=args.data_path,
test_retriever.sh  (deleted; file mode 100755 → 0; view file @ 5e56e563)
-COMMAND="/home/scratch.gcf/adlr-utils/release/cluster-interface/latest/mp_launch python hashed_index.py \
-       --num-layers 12 \
-       --hidden-size 768 \
-       --num-attention-heads 12 \
-       --batch-size 8 \
-       --checkpoint-activations \
-       --seq-length 288 \
-       --max-position-embeddings 288 \
-       --train-iters 100000 \
-       --load /home/dcg-adlr-nkant-output.cosmos1203/chkpts/realm_debug \
-       --ict-load /home/dcg-adlr-nkant-output.cosmos1203/chkpts/ict_best \
-       --save /home/dcg-adlr-nkant-output.cosmos1203/chkpts/realm_debug \
-       --data-path /home/universal-lm-data.cosmos549/datasets/wiki-indexed/wikipedia_lines \
-       --titles-data-path /home/universal-lm-data.cosmos549/datasets/wiki-indexed/wikipedia_lines-titles \
-       --hash-data-path /home/dcg-adlr-nkant-data.cosmos1202/hash_data/ict_best.pkl \
-       --vocab-file /home/universal-lm-data.cosmos549/scratch/mshoeybi/data/albert/vocab.txt \
-       --split 58,1,1 \
-       --distributed-backend nccl \
-       --lr 0.0001 \
-       --num-workers 2 \
-       --lr-decay-style linear \
-       --warmup .01 \
-       --save-interval 3000 \
-       --fp16 \
-       --adlr-autoresume \
-       --adlr-autoresume-interval 100"
-
-submit_job --image 'http://gitlab-master.nvidia.com/adlr/megatron-lm/megatron:20.03' --mounts /home/universal-lm-data.cosmos549,/home/dcg-adlr-nkant-source.cosmos1204,/home/dcg-adlr-nkant-data.cosmos1202,/home/dcg-adlr-nkant-output.cosmos1203,/home/nkant --name test_retriever --partition interactive --gpu 1 --nodes 1 --autoresume_timer 300 -c "${COMMAND}"