OpenDAS / Megatron-LM / Commits

Commit aae93362
Create HashedIndex class
Authored Apr 21, 2020 by Neel Kant
Parent: 9b599754

Changes: 1
Showing 1 changed file with 106 additions and 47 deletions

hashed_index.py (+106, -47)
hashed_index.py
@@ -22,7 +22,99 @@ def detach(tensor):
     return tensor.detach().cpu().numpy()


-def embed_docs():
+class HashedIndex(object):
+    """Class for holding hashed data"""
+    def __init__(self, embed_size, num_buckets, seed=0):
+        np.random.seed(seed)
+        self.block_data = defaultdict(list)
+        self.hash_data = defaultdict(list)
+        self.hash_matrix = np.random.rand(embed_size, num_buckets / 2)
+
+    def state(self):
+        state = {
+            'block_data': self.block_data,
+            'hash_data': self.hash_data,
+            'hash_matrix': self.hash_matrix
+        }
+        return state
+
+    def get_block_bucket(self, hash):
+        return self.hash_data[hash]
+
+    def get_block_embed(self, block_idx):
+        return self.block_data[block_idx]
+
+    def hash_embeds(self, embeds, block_data=None):
+        """Hash a tensor of embeddings using a random projection matrix"""
+        embed_scores_pos = torch.matmul(embeds, torch.cuda.HalfTensor(self.hash_matrix))
+        embed_scores = torch.cat((embed_scores_pos, -embed_scores_pos), axis=1)
+        embed_hashes = detach(torch.argmax(embed_scores, axis=1))
+
+        if block_data is not None:
+            for hash, indices in zip(embed_hashes, block_data):
+                self.hash_data[hash].append(indices)
+
+        return embed_hashes
+
+    def assign_block_embeds(self, block_indices, block_embeds, allow_overwrite=False):
+        """Assign the embeddings for each block index into a hash map"""
+        for idx, embed in zip(block_indices, block_embeds):
+            if not allow_overwrite and int(idx) in self.block_data:
+                raise ValueError("Attempted to overwrite a read-only HashedIndex")
+            self.block_data[int(idx)] = embed
+
+    def save_shard(self, rank):
+        dir_name = 'block_hash_data'
+        if not os.path.isdir(dir_name):
+            os.mkdir(dir_name)
+
+        # save the data for each shard
+        with open('{}/{}.pkl'.format(dir_name, rank), 'wb') as data_file:
+            pickle.dump(self.state(), data_file)
+
+    def consolidate_shards_and_save(self):
+        """Combine all the shards made using self.save_shard()"""
+        dir_name = 'block_hash_data'
+        fnames = os.listdir(dir_name)
+        for fname in fnames:
+            with open('{}/{}'.format(dir_name, fname), 'rb') as f:
+                data = pickle.load(f)
+                assert data['hash_matrix'] == self.hash_matrix
+
+                old_size = len(self.block_data)
+                shard_size = len(data['block_data'])
+                self.block_data.update(data['block_data'])
+                assert len(self.block_data) == old_size + shard_size
+
+                for bucket, items in data['hash_data'].items():
+                    self.hash_data[bucket].extend(items)
+
+        with open('block_hash_data.pkl', 'wb') as final_file:
+            pickle.dump(self.state(), final_file)
+
+        shutil.rmtree(dir_name, ignore_errors=True)
+
+    def clear(self):
+        """Clear the data structures to save memory"""
+        self.block_data = defaultdict(list)
+        self.hash_data = defaultdict(list)
+
+
+def main():
+    # TODO
+    # consider broadcasting/all-reducing all in memory rather than using the filesystem
+    # create a different process group in the same nccl world - don't have to use chkpts on disc or transfer things on disc
+    # torch distributed new_group contains a list of ranks, gives back a group which I can hand to the collective operations
+    # create a training process group, indexing process group
+    # pass the training group to the distributed DDP, instead of the large world process group
+    # use indexing process group for the shard-combining
+    # communication group between process "8" and process "0" which tells training group that there's a new index
+    # also, process 0 sends process 8 the new model
+    # if I want to launch a separate process for indexing, may have to work with environment variables to
+    # allocate the resources well. Have to subsequently assign the correct gpus to the indexing job
+    # consider initializing everything in a single group and break off processes based on the ranks
     initialize_megatron(extra_args_provider=None,
                         args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
     args = get_args()
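The bucketing that HashedIndex.hash_embeds performs is sign-based random projection: score each embedding against a random matrix, concatenate the positive and negated scores, and take the argmax, which picks the largest-magnitude projection together with its sign. Below is a minimal standalone sketch of that idea in NumPy; it is not part of the commit, and it reuses the names embed_size, num_buckets, and hash_matrix only for readability. Note the sketch uses num_buckets // 2 because np.random.rand expects integer dimensions (the commit writes num_buckets / 2, which is a float under Python 3).

    # sketch only: CPU re-expression of the hashing scheme in hash_embeds
    import numpy as np

    embed_size, num_buckets = 128, 2048
    rng = np.random.RandomState(0)
    # random projection matrix, one column per half-bucket
    hash_matrix = rng.rand(embed_size, num_buckets // 2)

    def hash_embeds(embeds, hash_matrix):
        """Map each row of `embeds` to one of num_buckets buckets."""
        scores_pos = embeds @ hash_matrix                        # (n, num_buckets / 2)
        scores = np.concatenate((scores_pos, -scores_pos), axis=1)
        return np.argmax(scores, axis=1)                         # bucket id per row

    embeds = rng.randn(4, embed_size)                            # toy embeddings
    print(hash_embeds(embeds, hash_matrix))

The commit's version does the same computation on the GPU with torch.matmul and a half-precision copy of hash_matrix, and optionally records which block indices landed in each bucket.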
@@ -30,68 +122,35 @@ def embed_docs():
     model.eval()
     dataset = get_dataset()
     data_iter = iter(get_dataloader(dataset))

-    hash_data = defaultdict(list)
-    hash_matrix = torch.cuda.HalfTensor(np.random.rand(128, 1024))
-    hash_data['matrix'] = hash_matrix
-    block_data = defaultdict(list)
+    hashed_index = HashedIndex(embed_size=128, num_buckets=2048)

     i = 0
     while True:
         try:
-            input_tokens, input_types, input_pad_mask, \
-                query_tokens, query_pad_mask, \
-                block_tokens, block_token_types, block_pad_mask, block_indices = get_batch(data_iter)
+            block_tokens, block_pad_mask, block_indices = get_batch(data_iter)
         except:
             break

-        input_logits, block_logits = model.module.module.forward(
-            input_tokens, input_types, input_pad_mask, block_tokens, block_pad_mask, block_token_types)
-
-        block_hash_pos = torch.matmul(block_logits, hash_matrix)
-        block_hash_full = torch.cat((block_hash_pos, -block_hash_pos), axis=1)
-        block_hashes = detach(torch.argmax(block_hash_full, axis=1))
-        for hash, indices_array in zip(block_hashes, block_indices):
-            hash_data[int(hash)].append(detach(indices_array))
-
-        block_logits = detach(block_logits)
-        # originally this has [start_idx, end_idx, doc_idx, block_idx]
-        block_indices = detach(block_indices)[:, 3]
-        for logits, idx in zip(block_logits, block_indices):
-            block_data[int(idx)] = logits
+        actual_model = model.module.module
+        block_indices = detach(block_indices)
+        block_logits = actual_model.embed_block(block_tokens, block_pad_mask)
+        hashed_index.hash_embeds(block_logits, block_indices)
+        hashed_index.assign_block_embeds(block_indices, detach(block_logits))

         if i % 100 == 0:
             print(i, flush=True)
         i += 1

-    dir_name = 'block_hash_data'
-    if not os.path.isdir(dir_name):
-        os.mkdir(dir_name)
-
-    # save the data for each shard
-    with open('{}/{}.pkl'.format(dir_name, args.rank), 'wb') as data_file:
-        all_data = {'block_data': block_data, 'hash_data': hash_data}
-        pickle.dump(all_data, data_file)
+    hashed_index.save_shard(args.rank)

     torch.distributed.barrier()
-    all_data.clear()
-    del all_data
     del model

+    # rank 0 process consolidates shards and saves into final file
     if mpu.get_data_parallel_rank() == 0:
-        all_block_data = defaultdict(dict)
-        dir_name = 'block_hash_data'
-        fnames = os.listdir(dir_name)
-        for fname in fnames:
-            with open('{}/{}'.format(dir_name, fname), 'rb') as f:
-                data = pickle.load(f)
-                all_block_data['hash_data'].update(data['hash_data'])
-                all_block_data['block_data'].update(data['block_data'])
-
-        with open('block_hash_data.pkl', 'wb') as final_file:
-            pickle.dump(all_block_data, final_file)
-
-        shutil.rmtree(dir_name, ignore_errors=True)
+        hashed_index.consolidate_shards_and_save()
+    else:
+        hashed_index.clear()


 def load_checkpoint():
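The reworked main() now delegates the per-rank bookkeeping to HashedIndex: every data-parallel rank hashes its own slice of blocks, writes a pickle shard, and rank 0 merges the shards into a single file. The sketch below shows that flow in isolation; it assumes HashedIndex from this commit is importable and that torch.distributed has been initialized (as initialize_megatron does), and it uses dist.get_rank() as a stand-in for the args.rank / mpu.get_data_parallel_rank() used in main().

    # sketch only: per-rank save + rank-0 consolidation with HashedIndex
    import torch.distributed as dist

    index = HashedIndex(embed_size=128, num_buckets=2048)   # same args as main()

    # ... each rank fills `index` from its own slice of the dataset, e.g.
    #     index.hash_embeds(block_logits, block_indices)
    #     index.assign_block_embeds(block_indices, detach(block_logits))

    rank = dist.get_rank()        # stand-in for args.rank in the commit
    index.save_shard(rank)        # writes block_hash_data/<rank>.pkl
    dist.barrier()                # wait until every shard is on disk

    if rank == 0:
        index.consolidate_shards_and_save()   # merges shards into block_hash_data.pkl
    else:
        index.clear()                         # free memory on the other ranks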
@@ -162,4 +221,4 @@ def get_dataloader(dataset):


 if __name__ == "__main__":
-    embed_docs()
+    main()
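The TODO block in main() proposes keeping training and indexing in one NCCL world but giving DDP a training-only process group and using a separate group for shard-combining, instead of exchanging checkpoints on disk. The following is a hedged sketch of that idea, not code from the commit; the rank split (0-7 training, 8+ indexing) is purely illustrative, and it assumes a standard torch.distributed launcher sets the usual environment variables.

    # sketch only: splitting one NCCL world into training and indexing groups
    import torch
    import torch.distributed as dist

    dist.init_process_group(backend='nccl')          # launcher provides MASTER_ADDR, etc.
    world = dist.get_world_size()

    train_ranks = list(range(0, 8))                  # illustrative: ranks 0-7 train
    index_ranks = list(range(8, world))              # illustrative: ranks 8+ rebuild the index

    train_group = dist.new_group(ranks=train_ranks)
    index_group = dist.new_group(ranks=index_ranks)

    rank = dist.get_rank()
    if rank in train_ranks:
        # hand the training-only group to DDP instead of the world group, e.g.
        # model = torch.nn.parallel.DistributedDataParallel(model, process_group=train_group)
        pass
    else:
        # use index_group collectives (e.g. dist.barrier(group=index_group))
        # for shard-combining, rather than writing checkpoints to disk
        pass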