OpenDAS / Megatron-LM

Commit 0e8f4331: Correct CrossEntropyLoss
Authored May 27, 2020 by Neel Kant
Parent: 8e22824e

Showing 6 changed files with 87 additions and 16 deletions:

  indexer.py                      +2   -1
  megatron/arguments.py           +1   -1
  megatron/data/ict_dataset.py    +2   -1
  megatron/model/realm_model.py   +31  -2
  megatron/training.py            +14  -4
  pretrain_bert_ict.py            +37  -7
indexer.py

@@ -95,6 +95,7 @@ def setup_realm_groups_and_vars():
 class IndexBuilder(object):
     def __init__(self):
         args = get_args()
+        self.debug = args.debug
         self.rank = args.rank
         self.model = None
         self.dataloader = None

@@ -287,6 +288,6 @@ def get_one_epoch_dataloader(dataset, batch_size=None):
 if __name__ == "__main__":
     initialize_megatron(extra_args_provider=None,
                         args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
-    index_builder = BasicIndexBuilder()
+    index_builder = IndexBuilder()
     index_builder.build_and_save_index()
megatron/arguments.py

@@ -265,7 +265,7 @@ def _add_checkpointing_args(parser):
     group.add_argument('--ict-load', type=str, default=None,
                        help='Directory containing an ICTBertModel checkpoint')
     group.add_argument('--bert-load', type=str, default=None,
-                       help='Directory containing an BertModel checkpoint (needed to start REALM)')
+                       help='Directory containing an BertModel checkpoint (needed to start ICT and REALM)')
     group.add_argument('--no-load-optim', action='store_true',
                        help='Do not load optimizer when loading checkpoint.')
     group.add_argument('--no-load-rng', action='store_true',
megatron/data/ict_dataset.py

@@ -97,7 +97,8 @@ class InverseClozeDataset(Dataset):
         """concat with special tokens and pad sequence to self.max_seq_length"""
         tokens = [self.cls_id] + tokens + [self.sep_id]
         if title is not None:
-            tokens += title + [self.sep_id]
+            # tokens += title + [self.sep_id]
+            tokens = t
         assert len(tokens) <= self.max_seq_length, len(tokens)
         num_pad = self.max_seq_length - len(tokens)
megatron/model/realm_model.py

@@ -294,10 +294,11 @@ class ICTBertModel(MegatronModule):
         query_logits = self.embed_query(query_tokens, query_attention_mask)
         block_logits = self.embed_block(block_tokens, block_attention_mask)
+        return query_logits, block_logits

         # [batch x embed] * [embed x batch]
-        retrieval_scores = query_logits.matmul(torch.transpose(block_logits, 0, 1))
-        return retrieval_scores
+        # retrieval_scores = query_logits.matmul(torch.transpose(block_logits, 0, 1))
+        # return retrieval_scores

     def embed_query(self, query_tokens, query_attention_mask):
         """Embed a batch of tokens using the query model"""
@@ -343,3 +344,31 @@ class ICTBertModel(MegatronModule):
             print("Loading ICT block model", flush=True)
             self.block_model.load_state_dict(
                 state_dict[self._block_key], strict=strict)
+
+    def init_state_dict_from_bert(self):
+        args = get_args()
+        import os
+        from megatron import mpu
+        from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
+
+        tracker_filename = get_checkpoint_tracker_filename(args.bert_load)
+        if not os.path.isfile(tracker_filename):
+            raise FileNotFoundError("Could not find BERT load for ICT")
+        with open(tracker_filename, 'r') as f:
+            iteration = int(f.read().strip())
+            assert iteration > 0
+        checkpoint_name = get_checkpoint_name(args.bert_load, iteration, False)
+        if mpu.get_data_parallel_rank() == 0:
+            print('global rank {} is loading checkpoint {}'.format(
+                torch.distributed.get_rank(), checkpoint_name))
+
+        try:
+            state_dict = torch.load(checkpoint_name, map_location='cpu')
+        except BaseException:
+            raise ValueError("Could not load checkpoint")
+
+        model_dict = state_dict['model']['language_model']
+        self.query_model.language_model.load_state_dict(model_dict)
+        self.block_model.language_model.load_state_dict(model_dict)
+
+        query_ict_head_state_dict = self.state_dict_for_save_checkpoint()[self._query_key]['ict_head']
+        self.block_model.ict_head.load_state_dict(query_ict_head_state_dict)
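The last two added lines tie the block tower's ICT head to the query tower's, so after warm-starting both language models from the same BERT checkpoint the two heads also project into the same space. A hedged, self-contained sketch of that weight-tying step, with nn.Linear standing in for the real ict_head module:

import torch.nn as nn

query_head = nn.Linear(768, 128)
block_head = nn.Linear(768, 128)
block_head.load_state_dict(query_head.state_dict())  # both towers start from the same projection
assert all((p1 == p2).all() for p1, p2 in zip(query_head.parameters(), block_head.parameters()))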
megatron/training.py

@@ -37,6 +37,7 @@ from megatron.learning_rates import AnnealingLR
 from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.model import get_params_for_weight_decay_optimization
 from megatron.mpu.initialize import get_index_ready, get_train_group, get_data_parallel_group, get_gloo_comm_group
+from megatron.model.realm_model import ICTBertModel
 from megatron.utils import check_adlr_autoresume_termination
 from megatron.utils import make_data_loader
 from megatron.utils import report_memory
@@ -229,6 +230,12 @@ def setup_model_and_optimizer(model_provider_func):
     else:
         args.iteration = 0

+    if args.iteration == 0 and isinstance(model.module.module, ICTBertModel):
+        print("Yes, located ICT model", flush=True)
+        model.module.module.init_state_dict_from_bert()
+    elif args.iteration == 0:
+        print("Ooops", flush=True)
+
     return model, optimizer, lr_scheduler
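At iteration 0, setup_model_and_optimizer now reaches through two wrapper layers to find the ICT model and warm-start it from the --bert-load checkpoint. A minimal sketch of why two .module hops are needed, assuming the usual Megatron wrapping order (an FP16 module nested inside the DDP wrapper); the Wrapper class below is an illustrative stand-in, not part of the commit:

import torch.nn as nn

class Wrapper(nn.Module):
    """Illustrative stand-in for the DDP / FP16 wrappers applied in megatron.training."""
    def __init__(self, module):
        super().__init__()
        self.module = module

core = nn.Linear(4, 4)                                # stand-in for the raw ICTBertModel
wrapped = Wrapper(Wrapper(core))                      # DDP(FP16(model))-style nesting
print(isinstance(wrapped.module.module, nn.Linear))   # True: two hops reach the core model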
@@ -239,10 +246,12 @@ def backward_step(optimizer, model, loss):
     # torch.cuda.synchronize()

     # Backward pass.
-    optimizer.zero_grad(set_grads_to_None=True)
+    # optimizer.zero_grad(set_grads_to_None=True)
     if args.fp16:
+        optimizer.zero_grad(set_grads_to_None=True)
         optimizer.backward(loss, update_master_grads=False)
     else:
+        optimizer.zero_grad()
         loss.backward()

     # All-reduce if needed.
@@ -377,9 +386,10 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
     print('>>> Starting train()', flush=True)
     # start off by posting a receive call which will be answered.
     # synchronize for start
-    torch.distributed.broadcast(INDEX_READY, 0, group=get_gloo_comm_group())
-    recv_handle = torch.distributed.broadcast(INDEX_READY, args.max_training_rank, group=get_gloo_comm_group(), async_op=True)
-    last_reload_iteration = iteration
+    if args.max_training_rank is not None:
+        torch.distributed.broadcast(INDEX_READY, 0, group=get_gloo_comm_group())
+        recv_handle = torch.distributed.broadcast(INDEX_READY, args.max_training_rank, group=get_gloo_comm_group(), async_op=True)
+        last_reload_iteration = iteration

     while iteration < args.train_iters:
         # this only applies for realm right here
         if args.max_training_rank is not None and recv_handle.is_completed() and iteration >= last_reload_iteration + 500:
pretrain_bert_ict.py

@@ -16,6 +16,7 @@
 """Pretrain BERT for Inverse Cloze Task"""

 import torch
+import torch.distributed as dist
 import torch.nn.functional as F

 from megatron import get_args

@@ -71,6 +72,7 @@ def get_batch(data_iterator):
 def forward_step(data_iterator, model):
     """Forward step."""
+    args = get_args()
     timers = get_timers()

     # Get the batch.
@@ -80,21 +82,49 @@ def forward_step(data_iterator, model):
     timers('batch generator').stop()

     # Forward model.
-    retrieval_scores = model(query_tokens, query_pad_mask, block_tokens, block_pad_mask).float()
+    # retrieval_scores = model(query_tokens, query_pad_mask, block_tokens, block_pad_mask).float()
+    query_logits, block_logits = model(query_tokens, query_pad_mask, block_tokens, block_pad_mask)
+
+    data_parallel_size = dist.get_world_size() / args.model_parallel_size
+    batch_size = query_logits.shape[0]
+    global_batch_size = int(batch_size * data_parallel_size)
+    all_logits_shape = (int(global_batch_size), int(query_logits.shape[1]))
+
+    all_query_logits = torch.zeros(all_logits_shape).type(query_logits.dtype).cuda()
+    all_block_logits = all_query_logits.clone().cuda()
+    all_query_logits[args.rank * batch_size:(args.rank + 1) * batch_size] = query_logits
+    all_block_logits[args.rank * batch_size:(args.rank + 1) * batch_size] = block_logits
+
+    # print(all_query_logits[:, :5], flush=True)
+    # print(all_block_logits[:, :5], flush=True)
+    dist.all_reduce(all_query_logits)
+    dist.all_reduce(all_block_logits)
+    # print(all_query_logits[:, :5], flush=True)
+    # print(all_block_logits[:, :5], flush=True)
+
+    retrieval_scores = all_query_logits.float().matmul(torch.transpose(all_block_logits, 0, 1).float())
     softmaxed = F.softmax(retrieval_scores, dim=1)
     sorted_vals, sorted_indices = torch.topk(softmaxed, k=softmaxed.shape[1], sorted=True)

-    batch_size = softmaxed.shape[0]
-    top1_acc = torch.cuda.FloatTensor([sum([int(sorted_indices[i, 0] == i) for i in range(batch_size)]) / batch_size])
-    top5_acc = torch.cuda.FloatTensor([sum([int(i in sorted_indices[i, :5]) for i in range(batch_size)]) / batch_size])
+    def topk_acc(k):
+        return torch.cuda.FloatTensor([sum([int(i in sorted_indices[i, :k]) for i in range(global_batch_size)]) / global_batch_size])
+    top_accs = [topk_acc(k) for k in [1, 8, 20, 100]]

-    retrieval_loss = F.cross_entropy(softmaxed, torch.arange(batch_size).cuda())
-    reduced_losses = reduce_losses([retrieval_loss, top1_acc, top5_acc])
+    retrieval_loss = torch.nn.CrossEntropyLoss()(retrieval_scores, torch.arange(global_batch_size).long().cuda())
+    # correct_probs = torch.gather(softmaxed, 1, torch.arange(global_batch_size).long().cuda().reshape(-1, 1))
+    # assert correct_probs[3] == softmaxed[3, 3]
+    # retrieval_loss = -torch.sum(torch.log(correct_probs)) / global_batch_size
+
+    reduced_losses = reduce_losses([retrieval_loss, *top_accs])
     stats_dict = {'retrieval loss': reduced_losses[0],
                   'top1_acc': reduced_losses[1],
-                  'top5_acc': reduced_losses[2]}
+                  'top8_acc': reduced_losses[2],
+                  'top20_acc': reduced_losses[3],
+                  'top100_acc': reduced_losses[4],
+                  }

     return retrieval_loss, stats_dict
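The commit title refers to the last change above: the old loss fed already-softmaxed probabilities into F.cross_entropy, which applies log_softmax internally, so the scores were effectively softmaxed twice and the gradient signal was flattened; the corrected loss feeds the raw retrieval_scores and the global labels. A small hedged sketch (not part of the commit, tensor values are illustrative) showing the difference:

import torch
import torch.nn.functional as F

scores = torch.randn(8, 8)                    # stand-in for retrieval_scores
labels = torch.arange(8)                      # query i should retrieve block i

corrected = F.cross_entropy(scores, labels)   # what the commit switches to
manual = -F.log_softmax(scores, dim=1).diag().mean()
assert torch.allclose(corrected, manual)      # same quantity as -log p(correct block)

double_softmax = F.cross_entropy(F.softmax(scores, dim=1), labels)
print(corrected.item(), double_softmax.item())  # the latter collapses toward log(batch_size)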