OpenDAS / Megatron-LM / Commits

Commit 7504ef44, authored Apr 08, 2020 by Neel Kant

    Add ICT qualitative test script

Parent: 064bdc46
Showing 3 changed files with 99 additions and 8 deletions:

megatron/data/ict_dataset.py        +3  -8
megatron/ict_qualitative_test.py    +90 -0
megatron/model/bert_model.py        +6  -0
megatron/data/ict_dataset.py
@@ -36,14 +36,13 @@ class InverseClozeDataset(Dataset):
     def __getitem__(self, idx):
         # get rng state corresponding to index (allows deterministic random pair)
         rng = random.Random(idx + 20000 + self.seed)
-        np_rng = np.random.RandomState(seed=[rng.randint(0, 2 ** 32 - 1) for _ in range(16)])

         # get seq length. Save 2 tokens for beginning and end
         target_seq_length = self.max_seq_length - 2
         if rng.random() < self.short_seq_prob:
             target_seq_length = rng.randint(5, target_seq_length)

-        input_data, context_data = self.get_input_and_context(target_seq_length, rng, np_rng)
+        input_data, context_data = self.get_input_and_context(idx, target_seq_length, rng)
         input_tokens, input_token_types, input_pad_mask = input_data
         context_tokens, context_token_types, context_pad_mask = context_data
@@ -79,16 +78,14 @@ class InverseClozeDataset(Dataset):
         token_types = [0] * self.max_seq_length
         return tokens, token_types, pad_mask

-    def get_input_and_context(self, target_seq_length, rng, np_rng):
+    def get_input_and_context(self, idx, target_seq_length, rng):
         """fetches a sentence and its surrounding context"""
         num_tries = 0
         while num_tries < 20:
             num_tries += 1
             doc = None
             while doc is None:
-                doc_idx = np_rng.randint(len(self) - 1)
-                # doc is a list of sentences
-                doc = self.get_sentence_split_doc(doc_idx)
+                doc = self.get_sentence_split_doc(idx)
                 if not doc:
                     doc = None

@@ -140,5 +137,3 @@ class InverseClozeDataset(Dataset):
                 (context_tokens, context_token_types, context_pad_mask)
-        else:
-            raise RuntimeError("Could not get a valid data point from InverseClozeDataset")
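After this change, __getitem__ no longer samples a random document through np_rng; the document is fetched directly at idx, and the only remaining randomness is the per-index rng used for the short-sequence decision. A minimal standalone sketch of that seeding scheme (the seeding expression idx + 20000 + seed comes from the code above; the helper name per_index_rng is my own):

import random

def per_index_rng(idx, seed):
    # one RNG per dataset index: the same (idx, seed) pair always
    # reproduces the same draws, so a sample's truncation decision
    # is stable across epochs and across processes
    return random.Random(idx + 20000 + seed)

# identical indices yield identical sequences of draws
assert per_index_rng(3, 1).random() == per_index_rng(3, 1).random()
# distinct indices are (almost surely) decorrelated
assert per_index_rng(3, 1).random() != per_index_rng(4, 1).random()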
megatron/ict_qualitative_test.py (new file, mode 100644)
import numpy as np
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

from megatron import get_args
from megatron import mpu
from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
from megatron.data.bert_dataset import get_indexed_dataset_
from megatron.data.ict_dataset import InverseClozeDataset
from megatron.initialize import initialize_megatron
from megatron.training import get_model
from pretrain_bert_ict import model_provider


def main():
    initialize_megatron(extra_args_provider=None,
                        args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
    args = get_args()
    model = load_checkpoint()
    dataset = get_dataset()

    num_docs = 100
    all_doc_logits = np.zeros((num_docs, 128))
    for i in range(num_docs):
        # gather up to seq_length tokens from the sentences of doc i
        doc_tokens = []
        doc_token_lists = dataset.get_sentence_split_doc(i)
        ptr = 0
        while len(doc_tokens) < args.seq_length and ptr < len(doc_token_lists):
            doc_tokens.extend(doc_token_lists[ptr])
            ptr += 1

        doc_tokens, doc_token_types, doc_pad_mask = dataset.concat_and_pad_tokens(doc_tokens)
        doc_logits = model.embed_doc(
            np.array(doc_tokens), np.array(doc_pad_mask), np.array(doc_token_types))
        all_doc_logits[i] = doc_logits

    print(all_doc_logits, flush=True)


def load_checkpoint():
    args = get_args()
    model = get_model(model_provider)

    if isinstance(model, torchDDP):
        model = model.module

    # read the latest iteration from the checkpoint tracker file
    tracker_filename = get_checkpoint_tracker_filename(args.load)
    with open(tracker_filename, 'r') as f:
        iteration = int(f.read().strip())
    assert iteration > 0
    checkpoint_name = get_checkpoint_name(args.load, iteration, False)

    if mpu.get_data_parallel_rank() == 0:
        print('global rank {} is loading checkpoint {}'.format(
            torch.distributed.get_rank(), checkpoint_name))

    state_dict = torch.load(checkpoint_name, map_location='cpu')
    model.load_state_dict(state_dict['model'])
    torch.distributed.barrier()

    if mpu.get_data_parallel_rank() == 0:
        print('  successfully loaded {}'.format(checkpoint_name))

    return model


def load_doc_embeds(path):
    pass


def get_dataset():
    args = get_args()
    indexed_dataset = get_indexed_dataset_(args.data_path, 'mmap', True)

    # restrict the doc index to full documents only
    doc_idx_ptr = indexed_dataset.get_doc_idx()
    total_num_documents = indexed_dataset.doc_idx.shape[0] - 1
    indexed_dataset.set_doc_idx(doc_idx_ptr[0:total_num_documents])

    kwargs = dict(
        name='full',
        indexed_dataset=indexed_dataset,
        data_prefix=args.data_path,
        num_epochs=None,
        max_num_samples=total_num_documents,
        max_seq_length=288,  # doesn't matter
        short_seq_prob=0.01,  # doesn't matter
        seed=1)
    dataset = InverseClozeDataset(**kwargs)

    return dataset


if __name__ == "__main__":
    main()
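The script stops at printing the document-embedding matrix. In the ICT setup, retrieval quality is judged by dot-product scores between query and document embeddings (compare retrieval_scores in ICTBertModel), so one plausible follow-on for a qualitative check, sketched here with an assumed query_embed vector and a top_k_docs helper of my own naming, is ranking the embedded documents for a query:

import numpy as np

def top_k_docs(query_embed, all_doc_logits, k=5):
    """Rank documents for one query by dot-product retrieval score.

    query_embed:    (hidden,) embedding from the question tower (assumed).
    all_doc_logits: (num_docs, hidden) matrix as built in main().
    """
    scores = all_doc_logits @ query_embed      # (num_docs,) similarity scores
    top = np.argsort(scores)[::-1][:k]         # highest-scoring doc indices
    return top, scores[top]

# hypothetical usage with the matrix produced by main():
# indices, scores = top_k_docs(query_embed, all_doc_logits)
# a well-trained model should rank a query's source document near the top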
megatron/model/bert_model.py
@@ -15,6 +15,7 @@
 """BERT model."""

+import numpy as np
 import torch

 from megatron import get_args
@@ -242,6 +243,11 @@ class ICTBertModel(MegatronModule):
         return retrieval_scores

+    def embed_doc(self, doc_tokens, doc_attention_mask, doc_types):
+        doc_logits, _ = self.context_model.forward(doc_tokens, 1 - doc_attention_mask, doc_types)
+        return doc_logits
+
     def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                        keep_vars=False):
         state_dict_ = {}
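Note that embed_doc complements the mask it is given (1 - doc_attention_mask) before calling the context tower, and the test script passes the dataset's doc_pad_mask into that argument. This suggests the dataset and the model use opposite mask conventions, with the complement as the conversion between them. A tiny sketch under the assumption that the dataset marks padding with 1:

import numpy as np

# assumed convention: the dataset-side pad mask marks padding with 1,
# while the context tower attends where its mask is 1
doc_pad_mask = np.array([0, 0, 0, 1, 1])   # 3 real tokens, then 2 pads
attention_mask = 1 - doc_pad_mask          # -> [1, 1, 1, 0, 0]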