OpenDAS / Megatron-LM

Commit 88637044, authored Apr 21, 2020 by Neel Kant (parent 9a617f6c)

Debug hashed_index.main
Showing 2 changed files with 10 additions and 8 deletions.

hashed_index.py               +7 -5
megatron/model/bert_model.py  +3 -3
hashed_index.py
...
@@ -28,7 +28,7 @@ class HashedIndex(object):
         np.random.seed(seed)
         self.block_data = defaultdict(list)
         self.hash_data = defaultdict(list)
-        self.hash_matrix = np.random.rand(embed_size, num_buckets / 2)
+        self.hash_matrix = np.random.rand(embed_size, int(num_buckets / 2))
 
     def state(self):
         state = {
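This hunk fixes a Python 3 division issue: num_buckets / 2 is true division and yields a float, and recent NumPy versions raise a TypeError when np.random.rand receives a non-integer dimension. A minimal reproduction, with illustrative values for embed_size and num_buckets (not taken from the commit):

    import numpy as np

    embed_size, num_buckets = 128, 32   # illustrative values only

    # Before the fix: num_buckets / 2 is the float 16.0, and np.random.rand
    # raises "TypeError: 'float' object cannot be interpreted as an integer".
    # hash_matrix = np.random.rand(embed_size, num_buckets / 2)

    # After the fix: casting to int gives valid dimensions.
    hash_matrix = np.random.rand(embed_size, int(num_buckets / 2))
    print(hash_matrix.shape)            # (128, 16)

The floor-division form num_buckets // 2 would express the same intent without the extra cast, though the commit keeps int(num_buckets / 2).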
...
@@ -72,19 +72,21 @@ class HashedIndex(object):
         with open('{}/{}.pkl'.format(dir_name, rank), 'wb') as data_file:
             pickle.dump(self.state(), data_file)
 
-    def consolidate_shards_and_save(self):
+    def consolidate_shards_and_save(self, ignore_shard=0):
         """Combine all the shards made using self.save_shard()"""
         dir_name = 'block_hash_data'
         fnames = os.listdir(dir_name)
         for fname in fnames:
+            if str(ignore_shard) in fname:
+                continue
             with open('{}/{}'.format(dir_name, fname), 'rb') as f:
                 data = pickle.load(f)
-                assert data['hash_matrix'] == self.hash_matrix
+                assert np.array_equal(data['hash_matrix'], self.hash_matrix)
                 old_size = len(self.block_data)
                 shard_size = len(data['block_data'])
                 self.block_data.update(data['block_data'])
-                assert len(self.block_data) == old_size + shard_size
+                assert len(self.block_data) == old_size + shard_size, (old_size, shard_size, len(self.block_data))
                 for bucket, items in data['hash_data'].items():
                     self.hash_data[bucket].extend(items)
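Two behavioral fixes here. First, == on NumPy arrays compares elementwise and returns a boolean array, so the original assert would raise "ValueError: The truth value of an array with more than one element is ambiguous" instead of checking equality; np.array_equal reduces to a single bool and also handles shape mismatches. A sketch of the difference:

    import numpy as np

    a = np.random.rand(4, 8)
    b = a.copy()

    # Elementwise comparison yields a (4, 8) boolean array; asserting on it
    # raises a ValueError about ambiguous truth values.
    # assert a == b

    # np.array_equal returns a single bool (False on shape mismatch too).
    assert np.array_equal(a, b)

Second, the size assert now carries a message tuple, so a failure reports (old_size, shard_size, len(self.block_data)) instead of a bare AssertionError. The new ignore_shard=0 parameter presumably skips the consolidating rank's own shard; note that the substring test str(ignore_shard) in fname would also skip, for example, shard 10 when ignoring shard 1.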
...
@@ -137,7 +139,7 @@ def main():
         block_logits = actual_model.embed_block(block_tokens, block_pad_mask)
         hashed_index.hash_embeds(block_logits, block_indices)
-        hashed_index.assign_block_embeds(block_indices, detach(block_logits))
+        hashed_index.assign_block_embeds(block_indices[:, 3], detach(block_logits))
 
         if i % 100 == 0:
             print(i, flush=True)
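assign_block_embeds now receives only column 3 of block_indices rather than the full 2-D array, so each embedding is keyed by a single integer per block. The exact row layout of block_indices is not visible in this diff; a plausible reading (an assumption, not confirmed by the commit) is that each row packs several fields with a unique block id in column 3:

    import numpy as np

    # Hypothetical layout: [start_idx, end_idx, doc_idx, block_idx] per row.
    block_indices = np.array([[0, 12, 7, 101],
                              [12, 30, 7, 102]])

    block_ids = block_indices[:, 3]   # array([101, 102]): one id per block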
...
megatron/model/bert_model.py
...
@@ -329,7 +329,7 @@ class ICTBertModel(MegatronModule):
             ict_head_size=ict_head_size,
             parallel_output=parallel_output)
 
-        assert not only_block_model and only_query_model
+        assert not (only_block_model and only_query_model)
         self.use_block_model = not only_query_model
         self.use_query_model = not only_block_model
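This is an operator-precedence fix: not binds tighter than and, so the original assert parsed as (not only_block_model) and only_query_model, which fails whenever only_query_model is False, including the default case where both flags are False. The parenthesized form expresses the intended invariant, that the two flags are not both set:

    only_block_model, only_query_model = False, False

    # Original parse: (not only_block_model) and only_query_model -> False,
    # so the assert fired even for this valid default configuration.
    # assert not only_block_model and only_query_model

    # Fixed: rejects only the contradictory case where both flags are True.
    assert not (only_block_model and only_query_model)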
...
@@ -355,7 +355,7 @@ class ICTBertModel(MegatronModule):
     def embed_query(self, query_tokens, query_attention_mask):
         """Embed a batch of tokens using the query model"""
         if self.use_query_model:
-            query_types = torch.zeros(query_tokens.shape).type(torch.float16).cuda()
+            query_types = torch.zeros(query_tokens.shape).type(torch.int64).cuda()
             query_ict_logits, _ = self.query_model.forward(query_tokens, query_attention_mask, query_types)
             return query_ict_logits
         else:
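The token-type tensor is ultimately used as indices into a token-type embedding table, and PyTorch embedding lookups require integer (Long/Int) indices, so the float16 tensor would fail at lookup time. A minimal illustration with a standalone nn.Embedding (the sizes here are illustrative, not the model's real dimensions):

    import torch
    import torch.nn as nn

    token_type_embedding = nn.Embedding(2, 768)   # illustrative sizes
    query_tokens = torch.zeros(4, 32)             # stand-in batch

    # float16 types: lookup raises a RuntimeError, since indices must be
    # an integer tensor.
    # bad_types = torch.zeros(query_tokens.shape).type(torch.float16)
    # token_type_embedding(bad_types)

    good_types = torch.zeros(query_tokens.shape).type(torch.int64)
    out = token_type_embedding(good_types)        # shape (4, 32, 768)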
...
@@ -364,7 +364,7 @@ class ICTBertModel(MegatronModule):
     def embed_block(self, block_tokens, block_attention_mask):
         """Embed a batch of tokens using the block model"""
         if self.use_block_model:
-            block_types = torch.zeros(block_tokens.shape).type(torch.float16).cuda()
+            block_types = torch.zeros(block_tokens.shape).type(torch.int64).cuda()
             block_ict_logits, _ = self.block_model.forward(block_tokens, block_attention_mask, block_types)
             return block_ict_logits
         else:
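This hunk mirrors the embed_query fix above for the block tower: both sides of the two-tower ICT model build their token-type ids the same way, so both needed the float16 to int64 correction. As a style note (a suggestion, not part of the commit), the cast-and-move chain can be collapsed into a single factory call that also follows the input's device instead of hard-coding .cuda():

    import torch

    block_tokens = torch.zeros(4, 32, dtype=torch.int64)   # stand-in batch

    # Equivalent to the fixed line, but sets dtype and device in one call.
    block_types = torch.zeros(block_tokens.shape, dtype=torch.int64,
                              device=block_tokens.device)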
...