OpenDAS / Megatron-LM · Commits

Commit 6f56b909, authored Mar 31, 2020 by Neel Kant
Remove debug statements and correct dataloader
Parent: 932c0970

Showing 4 changed files with 23 additions and 43 deletions:
    megatron/data_utils/datasets.py    +23  -21
    megatron/model/bert_model.py        +0   -3
    megatron/training.py                +0   -8
    pretrain_bert_ict.py                +0  -11
megatron/data_utils/datasets.py

@@ -924,7 +924,6 @@ class InverseClozeDataset(data.Dataset):
             'context_types': np.array(context_token_types),
             'context_pad_mask': np.array(context_pad_mask)
         }
-        print("got item")
         return sample

@@ -958,7 +957,7 @@ class InverseClozeDataset(data.Dataset):
         doc = self.get_sentence_split_doc(doc_idx)
         if not doc:
             doc = None
-        print("got doc sentences")
         # set up and tokenize the entire selected document
         num_sentences = len(doc)
         all_token_lists = []

@@ -968,39 +967,42 @@ class InverseClozeDataset(data.Dataset):
             all_token_lists.append(tokens)
             all_token_type_lists.append(token_types)
-        print("got tokenized sentences")
         sentence_token_lens = [len(l) for l in all_token_lists]
-        inclusion_mask = [True] * num_sentences
+        inclusion_mask = [False] * num_sentences
         # select a random sentence from the document as input
         input_sentence_idx = rng.randint(0, len(all_token_lists) - 1)
-        input_tokens = all_token_lists[input_sentence_idx].copy()
+        input_tokens = all_token_lists[input_sentence_idx].copy()[:self.max_seq_len - 2]
-        input_token_types = all_token_type_lists[input_sentence_idx].copy()
+        input_token_types = all_token_type_lists[input_sentence_idx].copy()[:self.max_seq_len - 2]
         # 10% of the time, the input sentence is left in the context.
         # The other 90% of the time, remove it.
-        if rng.random() > 0.1:
-            inclusion_mask[input_sentence_idx] = False
+        if rng.random() < 0.1:
+            inclusion_mask[input_sentence_idx] = True
         # parameters for examining sentences to remove from the context
-        remove_preceding = True
-        view_radius = 0
-        while sum(s for i, s in enumerate(sentence_token_lens) if inclusion_mask[i]) > target_seq_length:
+        view_preceding = True
+        view_radius = 1
+        while sum(s for i, s in enumerate(sentence_token_lens) if inclusion_mask[i]) < self.max_seq_len - 2:
             # keep removing sentences while the context is too large.
-            if remove_preceding:
-                if view_radius < input_sentence_idx:
-                    inclusion_mask[view_radius] = False
+            if view_preceding:
+                examine_idx = input_sentence_idx - view_radius
+                if examine_idx >= 0:
+                    inclusion_mask[examine_idx] = True
+            else:
+                examine_idx = input_sentence_idx + view_radius
+                if examine_idx < num_sentences:
+                    inclusion_mask[examine_idx] = True
                 view_radius += 1
-            elif not remove_preceding and num_sentences - view_radius > input_sentence_idx:
-                inclusion_mask[num_sentences - view_radius] = False
-            remove_preceding = not remove_preceding
+            view_preceding = not view_preceding
+            if view_radius > num_sentences:
+                break
-        print("got inclusion mask")
         # assemble the tokens and token types of the context
-        context_tokens = list(itertools.chain(*[l for i, l in enumerate(all_token_lists) if inclusion_mask[i]]))
+        context_tokens = list(itertools.chain(*[l for i, l in enumerate(all_token_lists) if inclusion_mask[i]]))[:self.max_seq_len - 2]
-        context_token_types = list(itertools.chain(*[l for i, l in enumerate(all_token_type_lists) if inclusion_mask[i]]))
+        context_token_types = list(itertools.chain(*[l for i, l in enumerate(all_token_type_lists) if inclusion_mask[i]]))[:self.max_seq_len - 2]
         # concatenate 'CLS' and 'SEP' tokens and add extra token types
         input_tokens, input_token_types, input_pad_mask = self.concat_and_pad_tokens(

@@ -1008,7 +1010,6 @@ class InverseClozeDataset(data.Dataset):
         context_tokens, context_token_types, context_pad_mask = self.concat_and_pad_tokens(
             context_tokens, context_token_types)
-        print("got all tokens")
         return (input_tokens, input_token_types, input_pad_mask), \
                (context_tokens, context_token_types, context_pad_mask)

@@ -1018,6 +1019,7 @@ class InverseClozeDataset(data.Dataset):
         tokens = [self.tokenizer.get_command('ENC').Id] + tokens + [self.tokenizer.get_command('sep').Id]
         token_types = [token_types[0]] + token_types + [token_types[0]]
+        assert len(tokens) <= self.max_seq_len
         num_pad = max(0, self.max_seq_len - len(tokens))
         pad_mask = [0] * len(tokens) + [1] * num_pad
         tokens += [self.tokenizer.get_command('pad').Id] * num_pad
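For reference, the corrected selection logic can be exercised on its own. The sketch below is not code from this repository; the helper name build_inclusion_mask and the standalone rng argument are hypothetical. It mirrors the new hunk above: start from an all-False mask, keep the query sentence 10% of the time, then alternately pull in preceding and following sentences until the context reaches the max_seq_len - 2 token budget or the whole document is included.

import random

def build_inclusion_mask(sentence_token_lens, input_sentence_idx,
                         max_seq_len, rng=random):
    # Hypothetical standalone sketch of the corrected selection logic:
    # grow the context outward from the query sentence instead of
    # shrinking a fully-included document.
    num_sentences = len(sentence_token_lens)
    inclusion_mask = [False] * num_sentences

    # 10% of the time the query sentence stays in the context.
    if rng.random() < 0.1:
        inclusion_mask[input_sentence_idx] = True

    view_preceding = True
    view_radius = 1
    while sum(s for i, s in enumerate(sentence_token_lens)
              if inclusion_mask[i]) < max_seq_len - 2:
        if view_preceding:
            # look one sentence further back
            examine_idx = input_sentence_idx - view_radius
            if examine_idx >= 0:
                inclusion_mask[examine_idx] = True
        else:
            # look one sentence further ahead, then widen the radius
            examine_idx = input_sentence_idx + view_radius
            if examine_idx < num_sentences:
                inclusion_mask[examine_idx] = True
            view_radius += 1
        view_preceding = not view_preceding
        if view_radius > num_sentences:
            # whole document examined; stop even if under budget
            break
    return inclusion_mask

if __name__ == "__main__":
    # toy usage: five sentences of the given token lengths, query is sentence 2
    print(build_inclusion_mask([12, 8, 30, 7, 15], input_sentence_idx=2, max_seq_len=40))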
megatron/model/bert_model.py

@@ -292,13 +292,10 @@ class ICTBertModel(MegatronModule):
                 context_tokens, context_attention_mask, context_types):
         question_ict_logits, _ = self.question_model.forward(input_tokens, input_attention_mask, input_types)
-        print("(bert ict forward) got question logits")
         context_ict_logits, _ = self.context_model.forward(context_tokens, context_attention_mask, context_types)
-        print("(bert ict forward) got context logits")
         # [batch x h] * [h x batch]
         retrieval_scores = question_ict_logits.matmul(torch.transpose(context_ict_logits, 0, 1))
-        print("(bert ict forward) got retrieval scores")
         return retrieval_scores
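The remaining forward pass scores every question embedding in the batch against every context embedding. A minimal sketch with random stand-in tensors (not the repository's ICTBertModel; batch and hidden sizes are arbitrary) shows the shape of the result: a [batch x batch] matrix whose diagonal entries correspond to the true question/context pairs.

import torch

batch, hidden = 4, 8
question_ict_logits = torch.randn(batch, hidden)   # stand-in for question_model output
context_ict_logits = torch.randn(batch, hidden)    # stand-in for context_model output

# [batch x h] * [h x batch] -> [batch x batch]; entry (i, j) scores
# question i against context j, so the diagonal holds the true pairs.
retrieval_scores = question_ict_logits.matmul(
    torch.transpose(context_ict_logits, 0, 1))
print(retrieval_scores.shape)  # torch.Size([4, 4])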
megatron/training.py

@@ -253,7 +253,6 @@ def setup_model_and_optimizer(model_provider_func, args):
 def backward_step(optimizer, model, loss, args, timers):
     """Backward step."""
-    print("back1")
     # Backward pass.
     optimizer.zero_grad()
     if args.fp16:

@@ -261,7 +260,6 @@ def backward_step(optimizer, model, loss, args, timers):
     else:
         loss.backward()
-    print("back2")
     # All-reduce if needed.
     if args.DDP_impl == 'local':
         timers('allreduce').start()

@@ -269,12 +267,10 @@ def backward_step(optimizer, model, loss, args, timers):
                                    fp32_allreduce=args.fp32_allreduce)
         timers('allreduce').stop()
-    print("back3")
     # Update master gradients.
     if args.fp16:
         optimizer.update_master_grads()
-    print("back4")
     # Clipping gradients helps prevent the exploding gradient.
     if args.clip_grad > 0:
         if not args.fp16:

@@ -282,7 +278,6 @@ def backward_step(optimizer, model, loss, args, timers):
         else:
             optimizer.clip_master_grads(args.clip_grad)
-    print("back5")
 def train_step(forward_step_func, data_iterator, model, optimizer, lr_scheduler,
                args, timers):

@@ -293,21 +288,18 @@ def train_step(forward_step_func, data_iterator, model, optimizer, lr_scheduler,
     loss, loss_reduced = forward_step_func(data_iterator, model, args, timers)
     timers('forward').stop()
     torch.cuda.synchronize()
-    print("confirm forward")
     # Calculate gradients, reduce across processes, and clip.
     timers('backward').start()
     backward_step(optimizer, model, loss, args, timers)
     timers('backward').stop()
     torch.cuda.synchronize()
-    print("did backward step")
     # Update parameters.
     timers('optimizer').start()
     optimizer.step()
     timers('optimizer').stop()
     torch.cuda.synchronize()
-    print("did optim step")
     # Update learning rate.
     skipped_iter = 0
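With the print statements gone, backward_step and train_step reduce to the usual ordering: forward, zero_grad, backward, optional gradient clipping, then the optimizer step. A minimal fp32 sketch of that ordering (plain PyTorch, no DDP, no fp16 master gradients, hypothetical toy model; not the repository's training loop) for reference:

import torch

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
clip_grad = 1.0

# forward pass on a toy batch
x, y = torch.randn(4, 10), torch.randn(4, 1)
loss = torch.nn.functional.mse_loss(model(x), y)

# backward step: zero grads, backprop, clip, then update parameters
optimizer.zero_grad()
loss.backward()
if clip_grad > 0:
    torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad)
optimizer.step()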
pretrain_bert_ict.py

@@ -79,9 +79,6 @@ def get_batch(data_iterator, timers):
     context_types = data_b['context_types'].long()
     context_pad_mask = data_b['context_pad_mask'].long()
-    global num_batches
-    print("got batch {}".format(num_batches))
     return input_tokens, input_types, input_pad_mask, \
            context_tokens, context_types, context_pad_mask

@@ -98,19 +95,11 @@ def forward_step(data_iterator, model, args, timers):
     # Forward model.
     retrieval_scores = model(input_tokens, 1 - input_pad_mask, input_types,
                              context_tokens, 1 - context_pad_mask, context_types)
-    print("ran model to get retrieval scores")
     softmaxed = F.softmax(retrieval_scores, dim=0)
     retrieval_loss = F.cross_entropy(softmaxed, torch.arange(softmaxed.shape[0]).cuda())
-    print(type(retrieval_loss))
     reduced_losses = reduce_losses([retrieval_loss])
-    global num_batches
-    print("did forward step {}".format(num_batches))
-    num_batches += 1
-    print(retrieval_loss, {'retrieval loss': reduced_losses[0]})
     return retrieval_loss, {'retrieval loss': reduced_losses[0]}
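forward_step treats the batch itself as the set of retrieval candidates: row i of retrieval_scores should pick out column i, so the targets are simply torch.arange(batch). A CPU-only sketch of that loss (random stand-in scores, .cuda() omitted so it runs anywhere; not the repository's code):

import torch
import torch.nn.functional as F

batch = 4
retrieval_scores = torch.randn(batch, batch)   # question x context scores

# same two lines as in forward_step above, minus the .cuda() call:
# softmax over the scores, then cross-entropy against the diagonal targets
softmaxed = F.softmax(retrieval_scores, dim=0)
retrieval_loss = F.cross_entropy(softmaxed, torch.arange(batch))
print(float(retrieval_loss))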