OpenDAS / Megatron-LM / Commits / 17d897e0

Commit 17d897e0 authored Jan 29, 2021 by Mostofa Patwary
WIP: main_retriver_merge
parent b69bc7ef
Showing 8 changed files with 185 additions and 42 deletions (+185 / -42)
megatron/arguments.py              +17  -6
megatron/data/ict_dataset.py       +19  -3
megatron/learning_rates.py          +7  -1
megatron/model/language_model.py   +13  -0
megatron/model/transformer.py       +4  -1
megatron/training.py               +18  -4
megatron/utils.py                  +37  -1
pretrain_ict.py                    +70  -26
megatron/arguments.py

@@ -39,7 +39,7 @@ def parse_args(extra_args_provider=None, defaults={},
     parser = _add_validation_args(parser)
     parser = _add_data_args(parser)
     parser = _add_autoresume_args(parser)
-    parser = _add_realm_args(parser)
+    parser = _add_biencoder_args(parser)

     # Custom arguments.
     if extra_args_provider is not None:

@@ -310,6 +310,8 @@ def _add_training_args(parser):
     group.add_argument('--checkpoint-activations', action='store_true',
                        help='Checkpoint activation to allow for training '
                        'with larger models, sequences, and batch sizes.')
+    group.add_argument('--override-checkpoint-version', type=float, default=None,
+                       help='Override checkpoint version')
     group.add_argument('--distribute-checkpointed-activations',
                        action='store_true',
                        help='If set, distribute checkpointed activations '

@@ -567,12 +569,19 @@ def _add_autoresume_args(parser):
     return parser


-def _add_realm_args(parser):
-    group = parser.add_argument_group(title='realm')
+def _add_biencoder_args(parser):
+    group = parser.add_argument_group(title='biencoder')

     # network size
     group.add_argument('--ict-head-size', type=int, default=None,
                        help='Size of block embeddings to be used in ICT and REALM (paper default: 128)')
+    group.add_argument('--projection-dim', type=int, default=0,
+                       help='Size of projection head used in biencoder (paper default: 128)')
+    group.add_argument('--shared-query-context-model', action='store_true',
+                       help='Whether to share the parameters of the query and context models or not')
+    group.add_argument('--pool-type', type=str, default='cls-token',
+                       choices=['avg', 'cls-token', 'max'],
+                       help='different options are: avg | cls-token | max, default=cls-token')

     # checkpointing
     group.add_argument('--ict-load', type=str, default=None,

@@ -589,14 +598,16 @@ def _add_realm_args(parser):
                        help='Whether to use one sentence documents in ICT')

     # training
-    group.add_argument('--report-topk-accuracies', nargs='+', default=[],
+    group.add_argument('--report-topk-accuracies', nargs='+', type=int, default=[],
                        help="Which top-k accuracies to report (e.g. '1 5 20')")
+    group.add_argument('--retriever-score-scaling', action='store_true',
+                       help="Whether to scale retriever scores by inverse square root of hidden size")

     # faiss index
     group.add_argument('--faiss-use-gpu', action='store_true',
                        help='Whether create the FaissMIPSIndex on GPU')
-    group.add_argument('--block-data-path', type=str, default=None,
-                       help='Where to save/load BlockData to/from')
+    # group.add_argument('--block-data-path', type=str, default=None,
+    #                    help='Where to save/load BlockData to/from')

     # indexer
     group.add_argument('--indexer-batch-size', type=int, default=128,
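Note: a standalone argparse sketch (not part of this commit) showing the effect of the new type=int on --report-topk-accuracies and the new --retriever-score-scaling flag added in this hunk:

# Minimal standalone sketch: without type=int, nargs='+' would yield strings.
import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='biencoder')
group.add_argument('--report-topk-accuracies', nargs='+', type=int, default=[],
                   help="Which top-k accuracies to report (e.g. '1 5 20')")
group.add_argument('--retriever-score-scaling', action='store_true')

args = parser.parse_args(['--report-topk-accuracies', '1', '5', '20',
                          '--retriever-score-scaling'])
print(args.report_topk_accuracies)   # [1, 5, 20] (ints, not strings)
print(args.retriever_score_scaling)  # True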
megatron/data/ict_dataset.py

@@ -9,6 +9,16 @@ from megatron import get_args
 from megatron.data.dataset_utils import get_indexed_dataset_
 from megatron.data.realm_dataset_utils import get_block_samples_mapping

+def make_attention_mask(source_block, target_block):
+    """
+    Returns a 2-dimensional (2-D) attention mask
+    :param source_block: 1-D array
+    :param target_block: 1-D array
+    """
+    mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1)
+    mask = mask.astype(np.int64)
+    # (source_length, target_length)
+    return mask
+
 def get_ict_dataset(use_titles=True, query_in_block_prob=1):
     """Get a dataset which uses block samples mappings to get ICT/block indexing data (via get_block())

@@ -93,14 +103,20 @@ class ICTDataset(Dataset):
         block = list(itertools.chain(*block))[:self.max_seq_length - title_pad_offset]

         query_tokens, query_pad_mask = self.concat_and_pad_tokens(query)
-        block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
+        context_tokens, context_pad_mask = self.concat_and_pad_tokens(block, title)
+
+        query_mask = make_attention_mask(query_tokens, query_tokens)
+        context_mask = make_attention_mask(context_tokens, context_tokens)
+
+        block_data = sample_data.as_array()

         sample = {
             'query_tokens': query_tokens,
+            'query_mask': query_mask,
             'query_pad_mask': query_pad_mask,
-            'block_tokens': block_tokens,
-            'block_pad_mask': block_pad_mask,
+            'context_tokens': context_tokens,
+            'context_mask': context_mask,
+            'context_pad_mask': context_pad_mask,
+            'block_data': block_data,
         }
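Note: a quick standalone check (plain NumPy, not Megatron code) of the new make_attention_mask helper; the >= 1 test assumes token id 0 is padding:

import numpy as np

def make_attention_mask(source_block, target_block):
    """2-D mask that is 1 wherever both tokens are non-padding (id >= 1)."""
    mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1)
    return mask.astype(np.int64)  # (source_length, target_length)

# Toy token ids: 0 is padding.
tokens = np.array([101, 2054, 2003, 102, 0, 0])
mask = make_attention_mask(tokens, tokens)
print(mask.shape)        # (6, 6)
print(mask[:, 4].sum())  # 0 -- no position may attend to the padded column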
megatron/learning_rates.py

@@ -59,6 +59,12 @@ class AnnealingLR(object):
         """Learning rate decay functions from:
               https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""

+        #print_rank_0(self.warmup_steps)
+        #print_rank_0(self.num_steps)
+        #print_rank_0(self.warmup_steps)
+        #print_rank_0(self.max_lr)
+        #print_rank_0(self.max_lr * float(self.num_steps) / float(self.warmup_steps))
+
         # Use linear warmup for the initial part.
         if self.warmup_steps > 0 and self.num_steps <= self.warmup_steps:
             return self.max_lr * float(self.num_steps) / \

@@ -97,7 +103,7 @@ class AnnealingLR(object):
         new_lr = self.get_lr()
         for group in self.optimizer.param_groups:
             group['lr'] = new_lr
+            #print_rank_0(new_lr)

     def state_dict(self):
         state_dict = {
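Note: a minimal sketch (made-up values, not Megatron code) of the linear-warmup rule that the added debug prints inspect:

# lr ramps linearly from 0 to max_lr over warmup_steps steps.
max_lr, warmup_steps = 1.0e-4, 1000

def warmup_lr(num_steps):
    if warmup_steps > 0 and num_steps <= warmup_steps:
        return max_lr * float(num_steps) / float(warmup_steps)
    return max_lr  # the decay schedule takes over after warmup (not shown)

print(warmup_lr(100))   # 1e-05
print(warmup_lr(1000))  # 0.0001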
megatron/model/language_model.py

@@ -374,6 +374,19 @@ class TransformerLanguageModelBase(MegatronModule):
         # Transformer.
         if self._transformer_key in state_dict:
             state_dict_ = state_dict[self._transformer_key]
+        # for compatiability with t5 architecture
+        # this is temporary unless t5_main is merged
+        elif 'encoder' in state_dict:
+            state_dict_ = state_dict['encoder']
+            # for forward compatibility for t5 architecture
+            state_dict_attention = {}
+            for key in state_dict_.keys():
+                if '.self_attention.' in key:
+                    state_dict_attention[key.replace(".self_attention.",
+                        ".attention.")] = state_dict_[key]
+                else:
+                    state_dict_attention[key] = state_dict_[key]
+            state_dict_ = state_dict_attention
         else:
             # for backward compatibility.
             state_dict_ = {}
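Note: a toy illustration (made-up keys and values) of the '.self_attention.' -> '.attention.' renaming fallback added above:

# Rename T5-style keys so they load into the older attention naming.
state_dict_ = {
    'layers.0.self_attention.query_key_value.weight': 'W_qkv',
    'layers.0.mlp.dense_h_to_4h.weight': 'W_mlp',
}

state_dict_attention = {}
for key in state_dict_.keys():
    if '.self_attention.' in key:
        state_dict_attention[key.replace(".self_attention.", ".attention.")] = state_dict_[key]
    else:
        state_dict_attention[key] = state_dict_[key]

print(sorted(state_dict_attention))
# ['layers.0.attention.query_key_value.weight', 'layers.0.mlp.dense_h_to_4h.weight']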
megatron/model/transformer.py

@@ -214,6 +214,9 @@ class ParallelSelfAttention(MegatronModule):
         mixed_x_layer, _ = self.query_key_value(hidden_states)

         checkpoint_version = get_checkpoint_version()
+        if get_args().override_checkpoint_version is not None:
+            checkpoint_version = get_args().override_checkpoint_version
+
         if checkpoint_version is not None:
             if checkpoint_version == 0:
                 # [s, b, (3 * np * hn)] --> [s, b, (np * 3 * hn)]
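Note: a standalone sketch (hypothetical helper, not Megatron code) of the override logic above: a value passed via --override-checkpoint-version takes precedence over the version read from the checkpoint itself:

def resolve_checkpoint_version(stored_version, override_version=None):
    # Mirror of the precedence rule in the diff: override wins if given.
    checkpoint_version = stored_version
    if override_version is not None:
        checkpoint_version = override_version
    return checkpoint_version

print(resolve_checkpoint_version(0))       # 0 (use what the checkpoint says)
print(resolve_checkpoint_version(0, 2.0))  # 2.0 (forced via the new flag)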
megatron/training.py

@@ -48,7 +48,7 @@ from megatron.model import get_params_for_weight_decay_optimization
 from megatron.model.realm_model import ICTBertModel
 from megatron.utils import check_adlr_autoresume_termination
 from megatron.data.data_loaders import build_pretraining_data_loader
-from megatron.utils import report_memory
+from megatron.utils import report_memory, params_grad_norm, params_global_norm, print_model, print_grads


 def print_datetime(string):

@@ -663,11 +663,25 @@ def train_step(forward_step_func, data_iterator,
             optimizer.clip_master_grads(args.clip_grad)
     timers('backward-clip-grad').stop()

+    #print_rank_0("after backward")
+    #print_grads(model)
+    print_model(model)
+    print_rank_0(params_global_norm(model))
+    print_rank_0(params_grad_norm(model))
+
     # Update parameters.
     timers('optimizer').start()
     optimizer.step()
     timers('optimizer').stop()

+    #print_rank_0("after optimizer")
+    #print_model(model)
+    print_rank_0(params_global_norm(model))
+    #print_rank_0(params_grad_norm(model))
+    #sys.exit()
+
     # Update learning rate.
     skipped_iter = 0
     if not (args.fp16 and optimizer.overflow):

@@ -905,9 +919,9 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
         # Exiting based on iterations
         if args.exit_interval and iteration % args.exit_interval == 0:
-            if not saved_checkpoint:
-                save_checkpoint_and_time(iteration, model, optimizer,
-                                         lr_scheduler)
+            # if not saved_checkpoint:
+            #     save_checkpoint_and_time(iteration, model, optimizer,
+            #                              lr_scheduler)
             torch.distributed.barrier()
             print_datetime('exiting program at iteration {}'.format(iteration))
             sys.exit()
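Note: a standalone sketch (plain PyTorch, single process, not Megatron code) of where the new diagnostics sit in train_step: norms are printed after the backward pass and the parameter norm again after optimizer.step():

import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

def param_norm(m):
    # Global 2-norm over all parameters.
    return sum(p.data.float().norm(2) ** 2 for p in m.parameters()) ** 0.5

def grad_norm(m):
    # Global 2-norm over all gradients that exist.
    return sum(p.grad.data.float().norm(2) ** 2
               for p in m.parameters() if p.grad is not None) ** 0.5

loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
print('after backward:', param_norm(model).item(), grad_norm(model).item())

optimizer.step()
print('after optimizer:', param_norm(model).item())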
megatron/utils.py

@@ -150,4 +150,40 @@ def get_ltor_masks_and_position_ids(data,

     return attention_mask, loss_mask, position_ids

+
+def params_grad_norm(model):
+    print_rank_0("params_grad_norm")
+    norm2 = torch.cuda.FloatTensor([0.0])
+    for param in model.parameters():
+        if param.grad is None:
+            continue
+        norm = torch.norm(param.grad.data.float(), 2)
+        norm2 += norm * norm
+    torch.distributed.all_reduce(norm2)
+    norm = norm2 ** 0.5
+    return norm.item()
+
+
+def params_global_norm(model):
+    print_rank_0("params_global_norm")
+    norm2 = torch.cuda.FloatTensor([0.0])
+    for param in model.parameters():
+        norm = torch.norm(param.data.float(), 2)
+        norm2 += norm * norm
+    torch.distributed.all_reduce(norm2)
+    norm = norm2 ** 0.5
+    return norm.item()
+
+
+def print_model(model):
+    print_rank_0("print-model")
+    for name, param in model.named_parameters():
+        if param.requires_grad:
+            #print("{} {}".format(name, param.data), flush=True)
+            print_rank_0("{} {}".format(name, param.data))
+    return
+
+
+def print_grads(model):
+    print_rank_0("print-grads")
+    for name, param in model.named_parameters():
+        if param.grad is None:
+            continue
+        print_rank_0("{} {}".format(name, param.grad))
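Note: a single-process sanity check (no torch.distributed, not Megatron code) that the sum-of-squared per-parameter norms followed by a square root, as in params_global_norm, matches the 2-norm of all parameters flattened together:

import torch

model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.Linear(8, 2))

norm2 = torch.zeros(1)
for param in model.parameters():
    norm = torch.norm(param.data.float(), 2)
    norm2 += norm * norm
global_norm = (norm2 ** 0.5).item()

# Reference: norm of the concatenation of all flattened parameters.
reference = torch.cat([p.data.float().flatten() for p in model.parameters()]).norm(2).item()
print(abs(global_norm - reference) < 1e-5)  # True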
pretrain_ict.py

@@ -14,6 +14,8 @@
 # limitations under the License.

 """Pretrain BERT for Inverse Cloze Task"""
+import sys
+import math
 import torch
 import torch.distributed as dist

@@ -26,14 +28,16 @@ from megatron import mpu
 from megatron.data.dataset_utils import build_train_valid_test_datasets
 from megatron.training import pretrain
 from megatron.utils import average_losses_across_data_parallel_group
-from megatron.model.realm_model import general_ict_model_provider
-from megatron.data.realm_dataset_utils import get_ict_batch
+from megatron.model.biencoder_model import biencoder_model_provider
+from megatron.data.biencoder_dataset_utils import get_ict_batch


 def pretrain_ict_model_provider():
     args = get_args()
-    return general_ict_model_provider(False, False)
+
+    model = biencoder_model_provider(
+                only_context_model=False,
+                only_query_model=False,
+                shared_query_context_model=args.shared_query_context_model)
+
+    return model


 def get_group_world_size_rank():

@@ -72,7 +76,6 @@ class AllgatherFromDataParallelRegion(torch.autograd.Function):
         output = output_list[rank].contiguous()
         return output


 def forward_step(data_iterator, model, input_tensor):
     """Forward step."""
     args = get_args()

@@ -80,37 +83,76 @@ def forward_step(data_iterator, model, input_tensor):
     # Get the batch.
     timers('batch-generator').start()
-    query_tokens, query_pad_mask, \
-        block_tokens, block_pad_mask, block_indices = get_ict_batch(data_iterator)
+    query_tokens, query_mask, \
+        context_tokens, context_mask, context_indices = get_ict_batch(data_iterator)
     timers('batch-generator').stop()

+    # Query and Context Types
+    query_types = torch.cuda.LongTensor(*query_tokens.shape).fill_(0)
+    context_types = torch.cuda.LongTensor(*context_tokens.shape).fill_(0)
+
-    # Forward model.
-    query_logits, block_logits = model(query_tokens, query_pad_mask, block_tokens, block_pad_mask)
-    micro_batch_size = query_logits.shape[0]
-    global_batch_size = dist.get_world_size() * micro_batch_size
-    # recall we assert that tensor_model_parallel_size == 1
-    all_query_logits = AllgatherFromDataParallelRegion.apply(query_logits)
-    all_block_logits = AllgatherFromDataParallelRegion.apply(block_logits)
+    #print_rank_0(query_tokens)
+    #print_rank_0(context_tokens)
+    #print_rank_0(torch.sum(query_types))
+    #print_rank_0(torch.sum(query_mask))
+    #print_rank_0(torch.sum(context_types))
+    #print_rank_0(torch.sum(context_mask))
+
+    #print_rank_0(params_global_norm(model))
+    #print_rank_0(params_grad_norm(model))
+
+    # Forward model.
+    query_logits, context_logits = model(query_tokens, query_mask, query_types,
+        context_tokens, context_mask, context_types)
+    #print_rank_0(query_logits)
+    #print_rank_0(context_logits)

-    # scores are inner products between query and block embeddings
-    retrieval_scores = all_query_logits.float().matmul(torch.transpose(all_block_logits, 0, 1).float())
-    softmaxed = F.softmax(retrieval_scores, dim=1)
-    sorted_vals, sorted_indices = torch.topk(softmaxed, k=softmaxed.shape[1], sorted=True)
+    micro_batch_size = query_logits.shape[0]
+    # recall we assert that tensor_model_parallel_size == 1
+    #global_batch_size = dist.get_world_size() * micro_batch_size
+    #all_query_logits = AllgatherFromDataParallelRegion.apply(query_logits)
+    #all_context_logits = AllgatherFromDataParallelRegion.apply(context_logits)
+    global_batch_size = micro_batch_size
+    all_query_logits = query_logits
+    all_context_logits = context_logits
+
+    # scores are inner products between query and context embeddings
+    retrieval_scores = torch.matmul(all_query_logits,
+        torch.transpose(all_context_logits, 0, 1))
+
+    # scaling the retriever scores
+    if args.retriever_score_scaling:
+        retrieval_scores = retrieval_scores / math.sqrt(args.hidden_size)
+
+    softmax_scores = F.log_softmax(retrieval_scores, dim=1)
+    sorted_vals, sorted_indices = torch.topk(softmax_scores, k=softmax_scores.shape[1], sorted=True)

     def topk_accuracy(k):
-        return torch.cuda.FloatTensor([sum([int(i in sorted_indices[i, :k]) for i in range(global_batch_size)]) / global_batch_size])
+        return torch.cuda.FloatTensor([sum([int(i in sorted_indices[i, :k]) \
+            for i in range(global_batch_size)]) / global_batch_size])

     topk_accs = [topk_accuracy(int(k)) for k in args.report_topk_accuracies]

-    retrieval_loss = torch.nn.CrossEntropyLoss()(retrieval_scores, torch.arange(global_batch_size).long().cuda())
-    retrieval_loss = retrieval_loss.float()
-    averaged_losses = average_losses_across_data_parallel_group([retrieval_loss, *topk_accs])
-
-    # create stats_dict with retrieval loss and all specified top-k accuracies
-    topk_acc_dict = {'top{}_acc'.format(k): v for k, v in zip(args.report_topk_accuracies, averaged_losses[1:])}
-    stats_dict = dict(retrieval_loss=averaged_losses[0], **topk_acc_dict)
+    labels = torch.arange(global_batch_size).long().cuda()
+    loss = F.nll_loss(softmax_scores, labels, reduction='mean')
+    reduced_losses = average_losses_across_data_parallel_group([loss, *topk_accs])
+
+    # Scale the retrieval loss
+    loss = loss * mpu.get_data_parallel_world_size()

-    return retrieval_loss, stats_dict
+    #retrieval_loss = torch.nn.CrossEntropyLoss()(retrieval_scores, torch.arange(global_batch_size).long().cuda())
+    #retrieval_loss = retrieval_loss.float()
+    #averaged_losses = average_losses_across_data_parallel_group([retrieval_loss, *topk_accs])
+
+    # create stats_dict with retrieval loss and all specified top-k accuracies
+    topk_acc_dict = {'top{}_acc'.format(k): v * 100 for k, v in \
+        zip(args.report_topk_accuracies, reduced_losses[1:])}
+    stats_dict = dict(loss=reduced_losses[0], **topk_acc_dict)
+    #print_rank_0(loss)
+    #print_rank_0(stats_dict)
+    #sys.exit()
+
+    return loss, stats_dict


 def train_valid_test_datasets_provider(train_val_test_num_samples):

@@ -136,5 +178,7 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
 if __name__ == "__main__":

-    pretrain(train_valid_test_datasets_provider, pretrain_ict_model_provider, forward_step,
+    pretrain(train_valid_test_datasets_provider,
+             pretrain_ict_model_provider,
+             forward_step,
             args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
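Note: a self-contained sketch (random CPU tensors, no Megatron) of the in-batch retrieval objective used in forward_step: each query's positive context sits at the same batch index, scores are optionally scaled by 1/sqrt(hidden_size), and the loss is NLL over log-softmaxed scores:

import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, hidden_size = 8, 16
query_logits = torch.randn(batch_size, hidden_size)
context_logits = torch.randn(batch_size, hidden_size)

# Inner-product scores between every query and every in-batch context.
retrieval_scores = torch.matmul(query_logits, context_logits.t())
retrieval_scores = retrieval_scores / math.sqrt(hidden_size)  # --retriever-score-scaling

softmax_scores = F.log_softmax(retrieval_scores, dim=1)
labels = torch.arange(batch_size)  # the positive context is at the same index
loss = F.nll_loss(softmax_scores, labels, reduction='mean')

# Top-k retrieval accuracy, mirroring topk_accuracy() in the diff.
_, sorted_indices = torch.topk(softmax_scores, k=batch_size, sorted=True)
top1 = sum(int(i in sorted_indices[i, :1]) for i in range(batch_size)) / batch_size
print(loss.item(), top1)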