OpenDAS / Megatron-LM
"scripts/deprecated/test_httpserver_decode_stream.py" did not exist on "30db99b3d98cbc4886dc3e35dce0f1658a44939c"
Commit 360885ee, authored Apr 09, 2020 by Neel Kant
Qualitative test prep works
Parent: 7504ef44
Showing 3 changed files with 128 additions and 6 deletions:

    ict_qualitative_test.py          +119  -0
    megatron/data/ict_dataset.py       +5  -1
    megatron/model/bert_model.py       +4  -5
megatron/ict_qualitative_test.py → ict_qualitative_test.py

 import numpy as np
 import torch
 from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

 from megatron import get_args
 from megatron import mpu
 from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
 from megatron.data.bert_dataset import get_indexed_dataset_
 from megatron.data.ict_dataset import InverseClozeDataset
+from megatron.data.samplers import DistributedBatchSampler
 from megatron.initialize import initialize_megatron
 from megatron.training import get_model
-from pretrain_bert_ict import model_provider
+from pretrain_bert_ict import get_batch, model_provider


 def main():
@@ -17,22 +18,33 @@ def main():
         args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
     args = get_args()
     model = load_checkpoint()
+    model.eval()
     dataset = get_dataset()
+    data_iter = iter(get_dataloader(dataset))

-    num_docs = 100
-    all_doc_logits = np.zeros(num_docs, 128)
-    for i in range(num_docs):
-        doc_tokens = []
-        doc_token_lists = dataset.get_sentence_split_doc(i)
-        ptr = 0
-        while len(doc_tokens) < args.seq_length and ptr < len(doc_token_lists):
-            doc_tokens.extend(doc_token_lists[ptr])
-        doc_tokens, doc_token_types, doc_pad_mask = dataset.concat_and_pad_tokens(doc_tokens)
-        doc_logits = model.embed_doc(np.array(doc_tokens), np.array(doc_pad_mask), np.array(doc_token_types))
-        all_doc_logits[i] = doc_logits
-    print(all_doc_logits, flush=True)
+    all_input_tokens = []
+    all_input_logits = []
+    all_doc_tokens = []
+    all_doc_logits = []
+    for i in range(100):
+        input_tokens, input_types, input_pad_mask, doc_tokens, doc_token_types, doc_pad_mask = get_batch(data_iter)
+        input_logits, doc_logits, _ = model.module.module.forward(input_tokens, input_types, input_pad_mask, doc_tokens, doc_pad_mask, doc_token_types, return_logits=True)
+        all_input_tokens.append(input_tokens.detach().cpu().numpy())
+        all_input_logits.append(input_logits.detach().cpu().numpy())
+        all_doc_tokens.append(doc_tokens.detach().cpu().numpy())
+        all_doc_logits.append(doc_logits.detach().cpu().numpy())

+    all_inputs_tokens = np.array(all_input_tokens).reshape(-1, args.seq_length)
+    all_inputs_logits = np.array(all_input_logits).reshape(-1, 128)
+    all_doc_tokens = np.array(all_doc_tokens).reshape(-1, args.seq_length)
+    all_doc_logits = np.array(all_doc_logits).reshape(-1, 128)

+    np.save('input_tokens.npy', all_input_tokens)
+    np.save('input_logits.npy', all_input_logits)
+    np.save('doc_tokens.npy', all_doc_tokens)
+    np.save('doc_logits.npy', all_doc_logits)
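
The four .npy files written above are the real output of the qualitative test: query ("input") and document embeddings plus the tokens they came from, dumped so they can be inspected offline without the Megatron runtime. The snippet below is not part of the commit; it is a minimal sketch of one such offline check, assuming input_logits.npy and doc_logits.npy hold matched query and context embeddings of width 128, with row i of each file coming from the same ICT example.

    # Not part of the commit: offline sanity check on the dumped embeddings.
    import numpy as np

    def top_k_doc_indices(input_logits, doc_logits, k=5):
        # Score every query against every candidate document: [num_queries, num_docs].
        scores = input_logits @ doc_logits.T
        # Highest-scoring documents first for each query.
        return np.argsort(-scores, axis=1)[:, :k]

    if __name__ == "__main__":
        queries = np.load('input_logits.npy').reshape(-1, 128)
        docs = np.load('doc_logits.npy').reshape(-1, 128)
        top_k = top_k_doc_indices(queries, docs)
        # If the ICT objective worked, query i should usually rank document i highly.
        recall_at_5 = np.mean([i in row for i, row in enumerate(top_k)])
        print('recall@5:', recall_at_5)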
ict_qualitative_test.py (continued)

@@ -61,10 +73,6 @@ def load_checkpoint():
     return model

-
-def load_doc_embeds(path):
-    pass
-

 def get_dataset():
     args = get_args()
     indexed_dataset = get_indexed_dataset_(args.data_path, 'mmap', True)
@@ -79,12 +87,33 @@ def get_dataset():
         num_epochs=None,
         max_num_samples=total_num_documents,
         max_seq_length=288,  # doesn't matter
-        short_seq_prob=0.01,  # doesn't matter
+        short_seq_prob=0.0001,  # doesn't matter
         seed=1
     )
     dataset = InverseClozeDataset(**kwargs)
     return dataset


+def get_dataloader(dataset):
+    args = get_args()
+    world_size = mpu.get_data_parallel_world_size()
+    rank = mpu.get_data_parallel_rank()
+    global_batch_size = args.batch_size * world_size
+    num_workers = args.num_workers
+
+    sampler = torch.utils.data.SequentialSampler(dataset)
+    batch_sampler = DistributedBatchSampler(sampler,
+                                            batch_size=global_batch_size,
+                                            drop_last=True,
+                                            rank=rank,
+                                            world_size=world_size)
+
+    return torch.utils.data.DataLoader(dataset,
+                                       batch_sampler=batch_sampler,
+                                       num_workers=num_workers,
+                                       pin_memory=True)
+
+
 if __name__ == "__main__":
     main()
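
The new get_dataloader walks the dataset in order with a SequentialSampler and hands batching to Megatron's DistributedBatchSampler, so each data-parallel rank keeps a disjoint slice of every global batch. The snippet below is not from the repository; it is a single-process sketch of that batching built only from torch.utils.data, under the assumption that DistributedBatchSampler gives each rank a contiguous batch_size-sized chunk of each global batch.

    # Not part of the commit: single-process stand-in for get_dataloader.
    from torch.utils.data import BatchSampler, DataLoader, SequentialSampler

    dataset = list(range(32))            # stand-in for InverseClozeDataset
    batch_size, world_size, rank = 4, 2, 0
    global_batch_size = batch_size * world_size

    sampler = SequentialSampler(dataset)
    global_batches = BatchSampler(sampler, batch_size=global_batch_size, drop_last=True)

    # Keep only this rank's contiguous slice of every global batch.
    rank_batches = [batch[rank * batch_size:(rank + 1) * batch_size]
                    for batch in global_batches]
    loader = DataLoader(dataset, batch_sampler=rank_batches)

    for batch in loader:
        print(batch)   # rank 0: tensor([0, 1, 2, 3]), tensor([8, 9, 10, 11]), ...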
megatron/data/ict_dataset.py

@@ -29,6 +29,7 @@ class InverseClozeDataset(Dataset):
         self.sep_id = tokenizer.sep
         self.mask_id = tokenizer.mask
         self.pad_id = tokenizer.pad
+        self.offset = 0

     def __len__(self):
         return self.indexed_dataset.doc_idx.shape[0]
@@ -85,9 +86,10 @@ class InverseClozeDataset(Dataset):
             num_tries += 1
             doc = None
             while doc is None:
-                doc = self.get_sentence_split_doc(idx)
+                doc = self.get_sentence_split_doc(idx + self.offset)
                 if not doc:
                     doc = None
+                    self.offset += 1

             num_sentences = len(doc)
             padless_max_len = self.max_seq_length - 2
@@ -97,6 +99,7 @@ class InverseClozeDataset(Dataset):
             input_sentence_idx = rng.randint(0, num_sentences - 1)
             input_tokens = doc[input_sentence_idx][:target_seq_length]
             if not len(input_tokens) > 0:
+                self.offset += 1
                 continue

             context_tokens = []
@@ -127,6 +130,7 @@ class InverseClozeDataset(Dataset):
             # assemble the tokens and token types of the context
             context_tokens = context_tokens[:padless_max_len]
             if not len(context_tokens) > 0:
+                self.offset += 1
                 continue

             # concatenate 'CLS' and 'SEP' tokens and add extra token types
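
The new self.offset field is a persistent cursor: whenever the sampled document, or the sentence and context drawn from it, turns out to be empty, the loop bumps the offset so that this and later lookups are shifted past the bad index instead of retrying it forever. The snippet below is not part of the commit; it is an isolated sketch of that retry pattern with a made-up get_sentence_split_doc stand-in, so it runs on its own.

    # Not part of the commit: the offset-based retry pattern in isolation.
    # Hypothetical corpus in which index 1 is an empty document.
    docs = {0: [["the", "cat"], ["sat"]], 1: [], 2: [["a", "dog"], ["barked"]]}

    def get_sentence_split_doc(idx):
        return docs.get(idx, [])

    def get_doc(idx, offset=0):
        # Mirrors the committed logic: skip empty documents by advancing the offset.
        doc = None
        while doc is None:
            doc = get_sentence_split_doc(idx + offset)
            if not doc:
                doc = None
                offset += 1
        return doc, offset

    print(get_doc(1))   # ([['a', 'dog'], ['barked']], 1): index 1 was empty, so the lookup shifted to index 2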
megatron/model/bert_model.py

@@ -233,7 +233,7 @@ class ICTBertModel(MegatronModule):
         self._context_key = 'context_model'

     def forward(self, input_tokens, input_attention_mask, input_types,
-                context_tokens, context_attention_mask, context_types):
+                context_tokens, context_attention_mask, context_types, return_logits=False):
         question_ict_logits, _ = self.question_model.forward(input_tokens, 1 - input_attention_mask, input_types)
         context_ict_logits, _ = self.context_model.forward(context_tokens, 1 - context_attention_mask, context_types)
@@ -241,12 +241,11 @@ class ICTBertModel(MegatronModule):
         # [batch x h] * [h x batch]
         retrieval_scores = question_ict_logits.matmul(torch.transpose(context_ict_logits, 0, 1))
-        return retrieval_scores
-
-    def embed_doc(self, doc_tokens, doc_attention_mask, doc_types):
-        doc_logits, _ = self.context_model.forward(doc_tokens, 1 - doc_attention_mask, doc_types)
-        return doc_logits
+        if return_logits:
+            return question_ict_logits, context_ict_logits, retrieval_scores
+
+        return retrieval_scores

     def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                        keep_vars=False):
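
With embed_doc removed, forward is the single entry point: it always forms the in-batch score matrix, and return_logits=True additionally returns the raw question and context embeddings that ict_qualitative_test.py dumps to disk. The snippet below is not from the model; it is a self-contained illustration of that score computation using random stand-in tensors.

    # Not part of the commit: the in-batch retrieval score computation in isolation.
    import torch

    batch, hidden = 4, 128
    question_ict_logits = torch.randn(batch, hidden)   # stand-in for the question tower output
    context_ict_logits = torch.randn(batch, hidden)    # stand-in for the context tower output

    # [batch x h] * [h x batch]: one score per (question, context) pair;
    # the diagonal holds each question's score against its own context.
    retrieval_scores = question_ict_logits.matmul(torch.transpose(context_ict_logits, 0, 1))

    return_logits = True
    if return_logits:
        out = (question_ict_logits, context_ict_logits, retrieval_scores)
    else:
        out = retrieval_scores
    print(retrieval_scores.shape)   # torch.Size([4, 4])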