Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
ac79d374
Commit
ac79d374
authored
Apr 23, 2020
by
Neel Kant
Browse files
Debug test_retriever
parent
3fb02b8e
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
21 additions
and
10 deletions
+21
-10
hashed_index.py
hashed_index.py
+13
-2
megatron/data/ict_dataset.py
megatron/data/ict_dataset.py
+3
-3
megatron/model/bert_model.py
megatron/model/bert_model.py
+5
-5
No files found.
hashed_index.py
View file @
ac79d374
...
...
@@ -103,7 +103,9 @@ class HashedIndex(object):
@
classmethod
def
load_from_file
(
cls
,
fname
):
print
(
" > Unpickling block hash data"
)
state_dict
=
pickle
.
load
(
open
(
fname
,
'rb'
))
print
(
" > Finished unpickling"
)
hash_matrix
=
state_dict
[
'hash_matrix'
]
new_index
=
HashedIndex
(
hash_matrix
.
shape
[
0
],
hash_matrix
.
shape
[
1
]
*
2
)
...
...
@@ -121,7 +123,16 @@ def test_retriever():
dataset
=
get_dataset
()
hashed_index
=
HashedIndex
.
load_from_file
(
'block_hash_data.pkl'
)
retriever
=
REALMRetriever
(
model
,
dataset
,
hashed_index
)
retriever
.
retrieve_evidence_blocks_text
(
"The last monarch from the house of windsor"
)
strs
=
[
"The last monarch from the house of windsor"
,
"married to Elvis Presley"
,
"tallest building in the world today"
,
"who makes graphics cards"
]
for
s
in
strs
:
retriever
.
retrieve_evidence_blocks_text
(
s
)
def
main
():
...
...
@@ -246,4 +257,4 @@ def get_dataloader(dataset):
# Script entry point: run the retriever smoke test when this file is
# executed directly (this commit switched the entry call from main()
# to test_retriever()).
if __name__ == "__main__":
    test_retriever()
megatron/data/ict_dataset.py
View file @
ac79d374
...
...
@@ -84,10 +84,10 @@ class InverseClozeDataset(Dataset):
def get_block(self, start_idx, end_idx, doc_idx):
    """Get the IDs for an evidence block plus the title of the corresponding document.

    Args:
        start_idx: index of the first sentence of the block in self.block_dataset.
        end_idx: one past the index of the last sentence of the block.
        doc_idx: index (castable to int) of the source document in self.title_dataset.

    Returns:
        (block_tokens, block_pad_mask) as produced by self.concat_and_pad_tokens.
    """
    block = [list(self.block_dataset[i]) for i in range(start_idx, end_idx)]
    title = list(self.title_dataset[int(doc_idx)])

    # Flatten the per-sentence token lists into one sequence, then truncate so
    # that the title plus 3 special tokens still fit within max_seq_length.
    # chain.from_iterable avoids unpacking the whole list as call arguments
    # (idiomatic replacement for chain(*block)).
    block = list(itertools.chain.from_iterable(block))[:self.max_seq_length - (3 + len(title))]

    block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
    return block_tokens, block_pad_mask
...
...
megatron/model/bert_model.py
View file @
ac79d374
...
...
@@ -293,20 +293,20 @@ class REALMRetriever(MegatronModule):
query_tokens
=
self
.
ict_dataset
.
encode_text
(
query_text
)[:
padless_max_len
]
query_tokens
,
query_pad_mask
=
self
.
ict_dataset
.
concat_and_pad_tokens
(
query_tokens
)
query_tokens
=
torch
.
cuda
.
Int
Tensor
(
np
.
array
(
query_tokens
).
reshape
(
1
,
-
1
))
query_pad_mask
=
torch
.
cuda
.
Int
Tensor
(
np
.
array
(
query_pad_mask
).
reshape
(
1
,
-
1
))
query_tokens
=
torch
.
cuda
.
Long
Tensor
(
np
.
array
(
query_tokens
).
reshape
(
1
,
-
1
))
query_pad_mask
=
torch
.
cuda
.
Long
Tensor
(
np
.
array
(
query_pad_mask
).
reshape
(
1
,
-
1
))
query_embed
=
self
.
ict_model
.
embed_query
(
query_tokens
,
query_pad_mask
)
query_embed
=
self
.
ict_model
.
module
.
module
.
embed_query
(
query_tokens
,
query_pad_mask
)
query_hash
=
self
.
hashed_index
.
hash_embeds
(
query_embed
)
assert
query_hash
.
size
==
1
block_bucket
=
self
.
hashed_index
.
get_block_bucket
(
query_hash
[
0
])
block_embeds
=
[
self
.
hashed_index
.
get_block_embed
[
idx
]
for
idx
in
block_bucket
[:,
3
]
]
block_embeds
=
[
self
.
hashed_index
.
get_block_embed
(
arr
[
3
])
for
arr
in
block_bucket
]
block_embed_tensor
=
torch
.
cuda
.
HalfTensor
(
np
.
array
(
block_embeds
))
retrieval_scores
=
query_embed
.
matmul
(
torch
.
transpose
(
block_embed_tensor
,
0
,
1
))
top5_vals
,
top5_indices
=
torch
.
topk
(
retrieval_scores
,
k
=
5
,
sorted
=
True
)
top5_start_end_doc
=
[
block_bucket
[
idx
][:
3
]
for
idx
in
top5_indices
]
top5_start_end_doc
=
[
block_bucket
[
idx
][:
3
]
for
idx
in
top5_indices
.
squeeze
()
]
top5_blocks
=
[(
self
.
ict_dataset
.
get_block
(
*
indices
))
for
indices
in
top5_start_end_doc
]
for
i
,
(
block
,
_
)
in
enumerate
(
top5_blocks
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment