Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
59031aa7
Commit
59031aa7
authored
May 03, 2020
by
Neel Kant
Browse files
more for pretrain_realm
parent
002cb170
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
8 additions
and
6 deletions
+8
-6
megatron/model/bert_model.py
megatron/model/bert_model.py
+4
-3
pretrain_realm.py
pretrain_realm.py
+4
-3
No files found.
megatron/model/bert_model.py
View file @
59031aa7
...
...
@@ -284,10 +284,11 @@ class REALMBertModel(MegatronModule):
class
REALMRetriever
(
MegatronModule
):
"""Retriever which uses a pretrained ICTBertModel and a HashedIndex"""
def
__init__
(
self
,
ict_model
,
ict_dataset
,
hashed_index
,
top_k
=
5
):
def
__init__
(
self
,
ict_model
,
ict_dataset
,
block_data
,
hashed_index
,
top_k
=
5
):
super
(
REALMRetriever
,
self
).
__init__
()
self
.
ict_model
=
ict_model
self
.
ict_dataset
=
ict_dataset
self
.
block_data
=
block_data
self
.
hashed_index
=
hashed_index
self
.
top_k
=
top_k
...
...
@@ -320,8 +321,8 @@ class REALMRetriever(MegatronModule):
block_buckets
[
j
]
=
block_buckets
[
i
].
copy
()
# [batch_size x max_bucket_population x embed_size]
block_embeds
=
[
torch
.
cuda
.
FloatTensor
(
np
.
array
([
self
.
hashed_index
.
get_block_embed
(
arr
[
3
])
for
arr
in
bucket
]))
for
bucket
in
block_buckets
]
block_embeds
=
[
torch
.
cuda
.
FloatTensor
(
np
.
array
([
self
.
block_data
.
embed_data
[
idx
]
for
idx
in
bucket
]))
for
bucket
in
block_buckets
]
all_top5_tokens
,
all_top5_pad_masks
=
[],
[]
for
query_embed
,
embed_tensor
,
bucket
in
zip
(
query_embeds
,
block_embeds
,
block_buckets
):
...
...
pretrain_realm.py
View file @
59031aa7
...
...
@@ -19,7 +19,7 @@ import torch
import
torch.nn.functional
as
F
from
hashed_index
import
load_ict_checkpoint
,
get_ict_dataset
from
megatron.data.realm_dataset
import
HashedIndex
from
megatron.data.realm_index
import
BlockData
,
RandProjectionLSHIndex
from
megatron
import
get_args
from
megatron
import
get_timers
from
megatron
import
mpu
...
...
@@ -39,9 +39,10 @@ def model_provider():
ict_model
=
load_ict_checkpoint
()
ict_dataset
=
get_ict_dataset
()
hashed_index
=
HashedIndex
.
load_from_file
(
args
.
hash_data_path
)
all_block_data
=
BlockData
.
load_from_file
(
args
.
block_data_path
)
hashed_index
=
RandProjectionLSHIndex
.
load_from_file
(
args
.
block_index_path
)
retriever
=
REALMRetriever
(
ict_model
,
ict_dataset
,
hashed_index
)
retriever
=
REALMRetriever
(
ict_model
,
ict_dataset
,
all_block_data
,
hashed_index
)
# TODO: REALMBertModel should accept a path to a pretrained bert-base
model
=
REALMBertModel
(
retriever
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment