OpenDAS / Megatron-LM · Commits

Commit dfaf674d, authored Apr 20, 2020 by Neel Kant
Merge branch 'hashing' of https://gitlab-master.nvidia.com/ADLR/megatron-lm into hashing
Parents: 43d5d84b, 017a943f

Showing 1 changed file with 29 additions and 39 deletions

hashed_index.py (+29, -39)
 from collections import defaultdict
+import os
 import pickle
 import numpy as np
@@ -16,7 +17,7 @@ from megatron.training import get_model
 from pretrain_bert_ict import get_batch, model_provider


-def main():
+def embed_docs():
     initialize_megatron(extra_args_provider=None,
                         args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
     args = get_args()
@@ -27,15 +28,9 @@ def main():
     hash_data = defaultdict(list)
     hash_matrix = torch.cuda.HalfTensor(np.random.rand(128, 1024))
-    #all_input_tokens = []
-    #all_input_logits = []
-    #all_block_tokens = []
+    hash_data['matrix'] = hash_matrix
     block_data = defaultdict(list)
-    all_block_logits = []
-    all_block_indices = []
-    my_rank = args.rank
-    block_file = open(f'block_data{my_rank}.pkl', 'wb')

     i = 0
     while True:
         try:
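For context on how the stored hash_matrix is used: the next hunk buckets each 128-dimensional block embedding by its largest signed random projection, via torch.cat((block_hash_pos, -block_hash_pos)) followed by argmax. The matmul that produces block_hash_pos is not visible in the expanded hunks, so the sketch below assumes it is the product of the embeddings with hash_matrix; the bucket_ids helper name and the toy CPU example are illustrative assumptions, not code from the repository.

    import torch

    def bucket_ids(block_embeddings, hash_matrix):
        """Assumed reconstruction of the bucketing step in embed_docs().

        Projects each (N, 128) embedding onto the 1024 columns of hash_matrix,
        concatenates the positive and negated projections, and takes the argmax,
        so the bucket id is the random direction (and sign) the embedding aligns
        with most strongly, mirroring the diff's cat/argmax pair.
        """
        block_hash_pos = block_embeddings @ hash_matrix                         # (N, 1024), assumed
        block_hash_full = torch.cat((block_hash_pos, -block_hash_pos), dim=1)   # (N, 2048)
        return torch.argmax(block_hash_full, dim=1)                             # bucket ids in [0, 2048)

    # Toy CPU example with the same shapes as the diff (128-dim embeddings, 1024 projections);
    # the script itself keeps hash_matrix as a torch.cuda.HalfTensor.
    hash_matrix = torch.rand(128, 1024)
    embeddings = torch.randn(4, 128)
    print(bucket_ids(embeddings, hash_matrix))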
@@ -52,47 +47,42 @@ def main():
             block_hash_full = torch.cat((block_hash_pos, -block_hash_pos), axis=1)
             block_hashes = torch.argmax(block_hash_full, axis=1).detach().cpu().numpy()
             for hash, indices_array in zip(block_hashes, block_indices):
-                hash_data[int(hash)].append(indices_array)
+                hash_data[int(hash)].append(indices_array.detach().cpu().numpy())

-            #all_input_tokens.append(input_tokens.detach().cpu().numpy())
-            #all_input_logits.append(input_logits.detach().cpu().numpy())
-            #all_block_tokens.append(block_tokens.detach().cpu().numpy())
-            #all_block_logits.append(block_logits.detach().cpu().numpy())
-            #all_block_indices.append(block_indices.detach().cpu().numpy()[:, 3])
             block_logits = block_logits.detach().cpu().numpy()
             block_indices = block_indices.detach().cpu().numpy()[:, 3]
             for logits, idx in zip(block_logits, block_indices):
-                pickle.dump({idx: logits}, block_file)
+                block_data[int(idx)] = logits

-            if i == 100:
-                print(i)
+            if i % 100 == 0:
+                print(i, flush=True)
             i += 1

-    block_file.close()
-    #all_input_tokens = np.array(all_input_tokens).reshape(-1, args.seq_length)
-    #all_input_logits = np.array(all_input_logits).reshape(-1, 128)
-    #all_block_tokens = np.array(all_block_tokens).reshape(-1, args.seq_length)
-    #all_block_logits = np.array(all_block_logits).reshape(-1, 128)
-    #all_block_indices = np.array(all_block_indices).reshape(all_block_logits.shape[0])
-    #for logits, idx in zip(all_block_logits, all_block_indices):
-    #    block_data[idx] = logits
-    #with as block_file:
-    #    pickle.dump(block_data, block_file)
-    #np.save(f'input_tokens{my_rank}.npy', all_input_tokens)
-    #np.save(f'input_logits{my_rank}.npy', all_input_logits)
-    #np.save(f'block_tokens{my_rank}.npy', all_block_tokens)
-    #np.save(f'block_logits{my_rank}.npy', all_block_logits)
-    for hash, block_indices in hash_data.items():
-        hash_data[hash] = np.array(block_indices)
-    hash_data['matrix'] = hash_matrix
-    with open(f'hash_data{my_rank}.pkl', 'wb') as hash_file:
-        pickle.dump(hash_data, hash_file)
+    dir_name = 'block_hash_data'
+    if not os.path.isdir(dir_name):
+        os.mkdir(dir_name)
+    with open('{}/{}.pkl'.format(dir_name, args.rank), 'wb') as data_file:
+        all_data = {'block_data': block_data, 'hash_data': hash_data}
+        pickle.dump(all_data, data_file)
+    torch.distributed.barrier()
+    if mpu.get_data_parallel_rank() == 0:
+        all_block_data = defaultdict(dict)
+        dir_name = 'block_hash_data'
+        fnames = os.listdir(dir_name)
+        for fname in fnames:
+            with open(fname, 'rb') as f:
+                data = pickle.load(f)
+                all_block_data['hash_data'].update(data['hash_data'])
+                all_block_data['block_data'].update(data['block_data'])
+        with open('block_hash_data.pkl', 'wb') as final_file:
+            pickle.dump(all_block_data, final_file)
+        os.rmdir(dir_name)
+    return


 def load_checkpoint():
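The new tail of embed_docs() follows a single-writer aggregation pattern: every data-parallel rank dumps its shard to block_hash_data/{rank}.pkl, all ranks synchronize on torch.distributed.barrier(), and rank 0 merges the shards into one block_hash_data.pkl. Below is a minimal sketch of that pattern; dump_shard and merge_shards are invented names, and unlike the committed code the sketch joins the directory onto each listed filename and deletes the shard files before removing the directory, since os.rmdir only removes empty directories.

    import os
    import pickle
    from collections import defaultdict

    def dump_shard(dir_name, rank, block_data, hash_data):
        # One file per data-parallel rank, mirroring '{}/{}.pkl'.format(dir_name, args.rank).
        os.makedirs(dir_name, exist_ok=True)
        with open(os.path.join(dir_name, f'{rank}.pkl'), 'wb') as data_file:
            pickle.dump({'block_data': block_data, 'hash_data': hash_data}, data_file)

    def merge_shards(dir_name, out_path='block_hash_data.pkl'):
        # Rank-0 side: fold every per-rank pickle into one dict-of-dicts, as the diff does
        # with all_block_data['hash_data'].update(...) / all_block_data['block_data'].update(...).
        all_block_data = defaultdict(dict)
        for fname in os.listdir(dir_name):
            path = os.path.join(dir_name, fname)
            with open(path, 'rb') as f:
                data = pickle.load(f)
            all_block_data['hash_data'].update(data['hash_data'])
            all_block_data['block_data'].update(data['block_data'])
            os.remove(path)  # clean up the shard so the directory can be removed
        with open(out_path, 'wb') as final_file:
            pickle.dump(dict(all_block_data), final_file)
        os.rmdir(dir_name)

    # In the script, the calls would be guarded roughly like this (sketch):
    #     dump_shard('block_hash_data', args.rank, block_data, hash_data)
    #     torch.distributed.barrier()
    #     if mpu.get_data_parallel_rank() == 0:
    #         merge_shards('block_hash_data')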
@@ -163,4 +153,4 @@ def get_dataloader(dataset):
 if __name__ == "__main__":
-    main()
+    embed_docs()