OpenDAS / Megatron-LM · Commits

Commit ca6b6687, authored Apr 03, 2020 by Neel Kant
parent 423c51b0

Fix InverseClozeDataset behavior (with commented out test code)
Showing 3 changed files with 48 additions and 14 deletions:

  megatron/data_utils/datasets.py    +24  -1
  megatron/model/bert_model.py        +9  -7
  pretrain_bert_ict.py               +15  -6
megatron/data_utils/datasets.py
@@ -966,6 +966,7 @@ class InverseClozeDataset(data.Dataset):

        padless_max_len = self.max_seq_len - 2

        # select a random sentence from the document as input
        # TODO: consider adding multiple input sentences.
        input_sentence_idx = rng.randint(0, num_sentences - 1)
        tokens, token_types = self.sentence_tokenize(doc[input_sentence_idx], 0)
        input_tokens, input_token_types = tokens[:target_seq_length], token_types[:target_seq_length]
@@ -976,14 +977,17 @@ class InverseClozeDataset(data.Dataset):

        # 10% of the time, the input sentence is left in the context.
        # The other 90% of the time, remove it.
        if rng.random() < 0.1:
        # if True:
            context_tokens = input_tokens.copy()
            context_token_types = input_token_types.copy()

        # parameters for examining sentences to remove from the context
        # TODO: test detokenized stuff, make sure it's the same doc in the same order.
        # change preceding rng condition to always true
        view_preceding = True
        view_radius = 1
        while len(context_tokens) < padless_max_len:
-           # keep removing sentences while the context is too large.
+           # keep adding sentences while the context can accommodate more.
            if view_preceding:
                examine_idx = input_sentence_idx - view_radius
                if examine_idx >= 0:
@@ -1001,6 +1005,25 @@ class InverseClozeDataset(data.Dataset):

            if view_radius > num_sentences:
                break

        # detokenized_input = self.tokenizer.DecodeIds(input_tokens)
        # detokenized_context = self.tokenizer.DecodeIds(context_tokens)
        # encoded_sentences = [self.tokenizer.EncodeAsIds(s).tokenization for s in doc]
        # full_document_encoded = list(itertools.chain(*encoded_sentences))
        # detokenized_doc = self.tokenizer.DecodeIds(full_document_encoded)
        # b1 = detokenized_input in detokenized_doc
        # b2 = detokenized_context in detokenized_doc
        # print("-" * 100)
        # print('> input idx: {}'.format(input_sentence_idx))
        # print('> input in doc: {}'.format(b1))
        # print('> context in doc: {}'.format(b2))
        # print('> input: {}'.format(detokenized_input))
        # print('> context: {}'.format(detokenized_context))
        # print('\n> doc: {}'.format(detokenized_doc))
        # if not (b1 and b2):
        #     raise ValueError("you dun goofed")

        # assemble the tokens and token types of the context
        context_tokens = context_tokens[:padless_max_len]
        context_token_types = context_token_types[:padless_max_len]
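The dataset logic above is easier to follow outside the diff. Below is a minimal, self-contained sketch of the inverse-cloze example construction it performs: pick one sentence of a document as the query, then grow a context window around it until the sequence budget is nearly full. The helper name build_ict_example, the tokenize callable, and the plain sentence list are illustrative assumptions, not the repository's actual API; the real dataset also carries token_types alongside the tokens.

import random

def build_ict_example(doc_sentences, tokenize, max_seq_len, rng=random, keep_prob=0.1):
    # Sketch only: mirrors the view_preceding / view_radius logic in the diff above.
    padless_max_len = max_seq_len - 2          # leave room for two special tokens
    num_sentences = len(doc_sentences)

    # select a random sentence from the document as the query
    idx = rng.randint(0, num_sentences - 1)
    input_tokens = tokenize(doc_sentences[idx])

    # 10% of the time the query sentence stays in the context, otherwise it is removed
    context_tokens = list(input_tokens) if rng.random() < keep_prob else []

    # alternately add preceding / following sentences until the context is full
    view_radius, view_preceding = 1, True
    while len(context_tokens) < padless_max_len:
        j = idx - view_radius if view_preceding else idx + view_radius
        if 0 <= j < num_sentences:
            sent = tokenize(doc_sentences[j])
            context_tokens = sent + context_tokens if view_preceding else context_tokens + sent
        if not view_preceding:
            view_radius += 1
        view_preceding = not view_preceding
        if view_radius > num_sentences:
            break

    return input_tokens[:padless_max_len], context_tokens[:padless_max_len]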
megatron/model/bert_model.py
@@ -215,6 +215,7 @@ class BertModel(MegatronModule):

        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(destination, prefix, keep_vars)
        if not self.add_ict_head:
            state_dict_[self._lm_head_key] \
                = self.lm_head.state_dict_for_save_checkpoint(destination, prefix, keep_vars)
@@ -232,6 +233,7 @@ class BertModel(MegatronModule):

        self.language_model.load_state_dict(state_dict[self._language_model_key], strict=strict)
        if not self.add_ict_head:
            self.lm_head.load_state_dict(state_dict[self._lm_head_key], strict=strict)
        if self.add_binary_head:
@@ -291,8 +293,8 @@ class ICTBertModel(MegatronModule):

    def forward(self, input_tokens, input_attention_mask, input_types,
                context_tokens, context_attention_mask, context_types):
-       question_ict_logits, _ = self.question_model.forward(input_tokens, input_attention_mask, input_types)
-       context_ict_logits, _ = self.context_model.forward(context_tokens, context_attention_mask, context_types)
+       question_ict_logits, _ = self.question_model.forward(input_tokens, 1 - input_attention_mask, input_types)
+       context_ict_logits, _ = self.context_model.forward(context_tokens, 1 - context_attention_mask, context_types)

        # [batch x h] * [h x batch]
        retrieval_scores = question_ict_logits.matmul(torch.transpose(context_ict_logits, 0, 1))
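Two things are worth noting about this hunk. First, together with the matching change in pretrain_bert_ict.py below, it moves the 1 - mask inversion of the padding mask from the training script into the model's forward pass, so callers now pass the pad mask directly. Second, the retrieval score is simply an in-batch dot product between the question and context embeddings. A minimal sketch with random stand-ins for the two BERT towers (the dimensions are arbitrary assumptions):

import torch

batch, hidden = 32, 128

# stand-ins for the pooled ICT outputs of the question and context towers
question_ict_logits = torch.randn(batch, hidden)
context_ict_logits = torch.randn(batch, hidden)

# [batch x h] * [h x batch] -> [batch x batch]; entry (i, j) scores question i
# against context j, so the diagonal holds each question's true-context score
retrieval_scores = question_ict_logits.matmul(torch.transpose(context_ict_logits, 0, 1))
assert retrieval_scores.shape == (batch, batch)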
pretrain_bert_ict.py
@@ -93,14 +93,23 @@ def forward_step(data_iterator, model, args, timers):

    timers('batch generator').stop()

    # Forward model.
-   retrieval_scores = model(input_tokens, 1 - input_pad_mask, input_types,
-                            context_tokens, 1 - context_pad_mask, context_types)
+   # TODO: important to make sure that everything, including padding mask is as expected here.
+   retrieval_scores = model(input_tokens, input_pad_mask, input_types,
+                            context_tokens, context_pad_mask, context_types).float()

-   softmaxed = F.softmax(retrieval_scores, dim=0)
-   retrieval_loss = F.cross_entropy(softmaxed, torch.arange(softmaxed.shape[0]).cuda())
-   reduced_losses = reduce_losses([retrieval_loss])
+   softmaxed = F.softmax(retrieval_scores, dim=1)
+   top5_vals, top5_indices = torch.topk(softmaxed, k=5, sorted=True)
+   batch_size = softmaxed.shape[0]

-   return retrieval_loss, {'retrieval loss': reduced_losses[0]}
+   top1_acc = torch.cuda.FloatTensor([sum([int(top5_indices[i, 0] == i) for i in range(batch_size)]) / batch_size])
+   top5_acc = torch.cuda.FloatTensor([sum([int(i in top5_indices[i]) for i in range(batch_size)]) / batch_size])
+
+   retrieval_loss = F.cross_entropy(softmaxed, torch.arange(batch_size).cuda())
+   reduced_losses = reduce_losses([retrieval_loss, top1_acc, top5_acc])
+
+   return retrieval_loss, {'retrieval loss': reduced_losses[0],
+                           'top1_acc': reduced_losses[1],
+                           'top5_acc': reduced_losses[2]}


def get_train_val_test_data(args):
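For reference, a self-contained sketch of the in-batch retrieval loss and the top-1 / top-5 accuracy metrics computed in this hunk, using random scores in place of the model output and plain tensor reductions instead of the explicit Python loops (note that F.cross_entropy applies log-softmax internally, so the sketch feeds it the raw scores):

import torch
import torch.nn.functional as F

batch_size = 8
# stand-in for the model's [batch x batch] question-context score matrix
retrieval_scores = torch.randn(batch_size, batch_size)

# the correct context for question i sits in column i, so the labels are 0..batch_size-1
labels = torch.arange(batch_size)
retrieval_loss = F.cross_entropy(retrieval_scores, labels)

# top-1 / top-5 retrieval accuracy over the batch
softmaxed = F.softmax(retrieval_scores, dim=1)
top5_vals, top5_indices = torch.topk(softmaxed, k=5, dim=1, sorted=True)
top1_acc = (top5_indices[:, 0] == labels).float().mean()
top5_acc = (top5_indices == labels.unsqueeze(1)).any(dim=1).float().mean()

print(retrieval_loss.item(), top1_acc.item(), top5_acc.item())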