OpenDAS / Megatron-LM / Commits

Commit 9b599754, authored Apr 20, 2020 by Neel Kant (parent 256eb6ed)

Debug and run hashing code

Showing 1 changed file with 9 additions and 7 deletions: hashed_index.py (+9, -7)

hashed_index.py @ 9b599754
```diff
 from collections import defaultdict
 import os
 import pickle
+import shutil
 import numpy as np
 import torch
```

...
```diff
@@ -43,9 +44,8 @@ def embed_docs():
         except:
             break
         # TODO: make sure input is still in block
-        input_logits, block_logits, _ = model.module.module.forward(input_tokens, input_types, input_pad_mask, block_tokens, block_pad_mask, block_token_types, return_logits=True)
+        input_logits, block_logits = model.module.module.forward(input_tokens, input_types, input_pad_mask, block_tokens, block_pad_mask, block_token_types)
         block_hash_pos = torch.matmul(block_logits, hash_matrix)
         block_hash_full = torch.cat((block_hash_pos, -block_hash_pos), axis=1)
```
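The two lines after the forward call are the hashing step this commit debugs: block embeddings are projected through a random matrix, and concatenating the projections with their negations lets a single argmax select both a hyperplane and which side of it an embedding falls on. Below is a minimal sketch of that random-projection pattern; the shapes, the construction of `hash_matrix`, and the argmax bucketing are assumptions, since the hunk only shows the matmul and the concatenation:

```python
import torch

# Sketch of the random-projection hash above; sizes are assumptions.
embed_size, num_buckets = 128, 16
hash_matrix = torch.randn(embed_size, num_buckets // 2)   # random hyperplanes

block_logits = torch.randn(4, embed_size)                 # 4 block embeddings
block_hash_pos = torch.matmul(block_logits, hash_matrix)  # (4, 8) projections
# Concatenating h with -h means one argmax over 2k columns picks both a
# hyperplane and the side of it the embedding aligns with most strongly.
block_hash_full = torch.cat((block_hash_pos, -block_hash_pos), axis=1)
bucket_ids = block_hash_full.argmax(dim=1)                # one bucket per block
print(bucket_ids)
```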
...

```diff
@@ -67,6 +67,7 @@ def embed_docs():
     if not os.path.isdir(dir_name):
         os.mkdir(dir_name)
+    # save the data for each shard
     with open('{}/{}.pkl'.format(dir_name, args.rank), 'wb') as data_file:
         all_data = {'block_data': block_data, 'hash_data': hash_data}
         pickle.dump(all_data, data_file)
```
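The added comment documents a common data-parallel pattern: each rank pickles its own shard into a shared directory, keyed by `args.rank`, so a single process can merge them afterwards. A runnable sketch with hypothetical stand-in values; `os.makedirs(..., exist_ok=True)` is used here as a race-free equivalent of the isdir/mkdir pair:

```python
import os
import pickle

# Hypothetical stand-ins for the variables in the diff.
dir_name = 'block_hash_data'
rank = 0                                  # would be args.rank in the real code
block_data = {0: 'block embedding'}
hash_data = {0: 'hash bucket'}

# exist_ok=True avoids the race the isdir/mkdir pair has when several
# ranks reach this point at the same time.
os.makedirs(dir_name, exist_ok=True)
with open('{}/{}.pkl'.format(dir_name, rank), 'wb') as data_file:
    pickle.dump({'block_data': block_data, 'hash_data': hash_data}, data_file)
```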
...

```diff
@@ -77,19 +78,20 @@ def embed_docs():
     del all_data
     del model
+    # rank 0 process consolidates shards and saves into final file
     if mpu.get_data_parallel_rank() == 0:
         all_block_data = defaultdict(dict)
         dir_name = 'block_hash_data'
         fnames = os.listdir(dir_name)
         for fname in fnames:
-            with open(fname, 'rb') as f:
+            with open('{}/{}'.format(dir_name, fname), 'rb') as f:
                 data = pickle.load(f)
                 all_block_data['hash_data'].update(data['hash_data'])
                 all_block_data['block_data'].update(data['block_data'])
         with open('block_hash_data.pkl', 'wb') as final_file:
             pickle.dump(all_block_data, final_file)
-        os.rmdir(dir_name)
+        shutil.rmtree(dir_name, ignore_errors=True)


 def load_checkpoint():
```
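Both fixes in this hunk target concrete failure modes: `os.listdir` returns names relative to the directory, so `open(fname, 'rb')` would look in the current working directory and raise `FileNotFoundError`; and `os.rmdir` only removes empty directories, so it would fail on a directory still holding the shard files. A self-contained sketch of the corrected merge, with two fabricated shards so it actually runs:

```python
import os
import pickle
import shutil
from collections import defaultdict

dir_name = 'block_hash_data'
os.makedirs(dir_name, exist_ok=True)
# Fabricate two shards so the merge below has something to consume.
for rank in range(2):
    with open('{}/{}.pkl'.format(dir_name, rank), 'wb') as f:
        pickle.dump({'block_data': {rank: 'b'}, 'hash_data': {rank: 'h'}}, f)

all_block_data = defaultdict(dict)
for fname in os.listdir(dir_name):
    # os.listdir yields names relative to dir_name; join before opening,
    # otherwise open() searches the current working directory instead.
    with open('{}/{}'.format(dir_name, fname), 'rb') as f:
        data = pickle.load(f)
    all_block_data['hash_data'].update(data['hash_data'])
    all_block_data['block_data'].update(data['block_data'])

with open('block_hash_data.pkl', 'wb') as final_file:
    pickle.dump(dict(all_block_data), final_file)
# os.rmdir raises OSError on a non-empty directory; rmtree removes the tree.
shutil.rmtree(dir_name, ignore_errors=True)
```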
...

```diff
@@ -125,8 +127,8 @@ def get_dataset():
     kwargs = dict(
         name='full',
-        context_dataset=block_dataset,
-        titles_dataset=titles_dataset,
+        block_dataset=block_dataset,
+        title_dataset=titles_dataset,
         data_prefix=args.data_path,
         num_epochs=1,
         max_num_samples=None,
```
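The final hunk renames two keyword arguments, presumably to match the dataset constructor's real parameter names; with the old names, unpacking `kwargs` would fail immediately with a `TypeError`. A hypothetical illustration of that failure mode, since the actual constructor signature is not shown in this diff:

```python
# Hypothetical constructor standing in for the real dataset class.
def build_dataset(name, block_dataset, title_dataset, data_prefix,
                  num_epochs, max_num_samples):
    return name, block_dataset, title_dataset

kwargs = dict(name='full', block_dataset=[1, 2], title_dataset=['a', 'b'],
              data_prefix='wiki', num_epochs=1, max_num_samples=None)
build_dataset(**kwargs)  # fine: names match the signature

bad_kwargs = dict(kwargs)
bad_kwargs['context_dataset'] = bad_kwargs.pop('block_dataset')  # old name
try:
    build_dataset(**bad_kwargs)
except TypeError as err:
    print(err)  # ... got an unexpected keyword argument 'context_dataset'
```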
...