Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
1b44a4c4
"vscode:/vscode.git/clone" did not exist on "f50b18eec7d646bf98aef576dbb0f47ff512beaa"
Commit
1b44a4c4
authored
Apr 21, 2020
by
Neel Kant
Browse files
add test_retriever
parent
88637044
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
13 additions
and
1 deletion
+13
-1
hashed_index.py
hashed_index.py
+12
-0
megatron/model/__init__.py
megatron/model/__init__.py
+1
-1
No files found.
hashed_index.py
View file @
1b44a4c4
...
@@ -14,6 +14,7 @@ from megatron.data.bert_dataset import get_indexed_dataset_
...
@@ -14,6 +14,7 @@ from megatron.data.bert_dataset import get_indexed_dataset_
from
megatron.data.ict_dataset
import
InverseClozeDataset
from
megatron.data.ict_dataset
import
InverseClozeDataset
from
megatron.data.samplers
import
DistributedBatchSampler
from
megatron.data.samplers
import
DistributedBatchSampler
from
megatron.initialize
import
initialize_megatron
from
megatron.initialize
import
initialize_megatron
from
megatron.model
import
REALMRetriever
from
megatron.training
import
get_model
from
megatron.training
import
get_model
from
pretrain_bert_ict
import
get_batch
,
model_provider
from
pretrain_bert_ict
import
get_batch
,
model_provider
...
@@ -101,6 +102,17 @@ class HashedIndex(object):
...
@@ -101,6 +102,17 @@ class HashedIndex(object):
self
.
hash_data
=
defaultdict
(
list
)
self
.
hash_data
=
defaultdict
(
list
)
def
test_retriever
():
initialize_megatron
(
extra_args_provider
=
None
,
args_defaults
=
{
'tokenizer_type'
:
'BertWordPieceLowerCase'
})
model
=
load_checkpoint
()
model
.
eval
()
dataset
=
get_dataset
()
hashed_index
=
HashedIndex
(
embed_size
=
128
,
num_buckets
=
2048
)
retriever
=
REALMRetriever
(
model
,
dataset
,
hashed_index
)
retriever
.
retrieve_evidence_blocks_text
(
"The last monarch from the house of windsor"
)
def
main
():
def
main
():
# TODO
# TODO
...
...
megatron/model/__init__.py
View file @
1b44a4c4
...
@@ -14,6 +14,6 @@
...
@@ -14,6 +14,6 @@
# limitations under the License.
# limitations under the License.
from
.distributed
import
*
from
.distributed
import
*
from
.bert_model
import
BertModel
,
ICTBertModel
,
REALMBertModel
from
.bert_model
import
BertModel
,
ICTBertModel
,
REALMBertModel
,
REALMRetriever
from
.gpt2_model
import
GPT2Model
from
.gpt2_model
import
GPT2Model
from
.utils
import
get_params_for_weight_decay_optimization
from
.utils
import
get_params_for_weight_decay_optimization
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment