OpenDAS / Megatron-LM / Commits / e919dd8e

Commit e919dd8e, authored Feb 01, 2021 by Mostofa Patwary

    cleared the commented codes

Parent: bfc20ecf

Showing 6 changed files with 14 additions and 108 deletions (+14, -108):
megatron/arguments.py              +2   -2
megatron/learning_rates.py         +1   -7
megatron/model/biencoder_model.py  +7  -16
megatron/training.py               +2  -23
megatron/utils.py                  +1  -37
pretrain_ict.py                    +1  -23
megatron/arguments.py

@@ -606,8 +606,8 @@ def _add_biencoder_args(parser):
     # faiss index
     group.add_argument('--faiss-use-gpu', action='store_true',
                        help='Whether create the FaissMIPSIndex on GPU')
-    # group.add_argument('--block-data-path', type=str, default=None,
-    #                    help='Where to save/load BlockData to/from')
+    group.add_argument('--block-data-path', type=str, default=None,
+                       help='Where to save/load BlockData to/from')

     # indexer
     group.add_argument('--indexer-batch-size', type=int, default=128,
     ...
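Side note: a self-contained sketch of how the now-uncommented `--block-data-path` option (and the `--faiss-use-gpu` flag above it) behave under argparse. The parser below is a stand-in for Megatron-LM's argument setup, and the example path is purely illustrative:

```python
import argparse

# Minimal sketch: these two options mirror the ones defined in the hunk above;
# the parser itself is illustrative, not Megatron-LM's argument plumbing.
parser = argparse.ArgumentParser(description='biencoder args sketch')
group = parser.add_argument_group(title='biencoder')
group.add_argument('--faiss-use-gpu', action='store_true',
                   help='Whether create the FaissMIPSIndex on GPU')
group.add_argument('--block-data-path', type=str, default=None,
                   help='Where to save/load BlockData to/from')

# Example invocation with an illustrative path.
args = parser.parse_args(['--block-data-path', '/tmp/block_data',
                          '--faiss-use-gpu'])
print(args.block_data_path, args.faiss_use_gpu)  # /tmp/block_data True
```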
megatron/learning_rates.py

@@ -59,12 +59,6 @@ class AnnealingLR(object):
         """Learning rate decay functions from:
               https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""

-        #print_rank_0(self.warmup_steps)
-        #print_rank_0(self.num_steps)
-        #print_rank_0(self.warmup_steps)
-        #print_rank_0(self.max_lr)
-        #print_rank_0(self.max_lr * float(self.num_steps) / float(self.warmup_steps))

         # Use linear warmup for the initial part.
         if self.warmup_steps > 0 and self.num_steps <= self.warmup_steps:
             return self.max_lr * float(self.num_steps) / \
     ...

@@ -103,7 +97,7 @@ class AnnealingLR(object):
         new_lr = self.get_lr()
         for group in self.optimizer.param_groups:
             group['lr'] = new_lr
-        #print_rank_0(new_lr)

     def state_dict(self):
         state_dict = {
     ...
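Side note: the warmup branch kept above is a plain linear ramp, for steps at or below `warmup_steps` the learning rate is `max_lr * num_steps / warmup_steps`. A minimal, self-contained sketch of that behaviour (the function name and the post-warmup placeholder are illustrative, not the AnnealingLR class itself):

```python
def warmup_then_constant_lr(num_steps, warmup_steps, max_lr):
    """Illustrative sketch of the linear-warmup branch shown above.

    During warmup the LR ramps linearly from 0 to max_lr; the decay that
    follows is elided in the hunk, so max_lr is simply held afterwards.
    """
    if warmup_steps > 0 and num_steps <= warmup_steps:
        return max_lr * float(num_steps) / float(warmup_steps)
    return max_lr  # placeholder for the decay schedule not shown in the diff

# e.g. halfway through a 100-step warmup toward max_lr = 1e-4:
assert abs(warmup_then_constant_lr(50, 100, 1e-4) - 5e-5) < 1e-12
```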
megatron/model/biencoder_model.py

@@ -27,7 +27,7 @@ def biencoder_model_provider(only_query_model=False,
     print_rank_0('building BiEncoderModel...')

     # simpler to just keep using 2 tokentypes since
     # the LM we initialize with has 2 tokentypes
     model = BiEncoderModel(
         num_tokentypes=2,
     ...

@@ -78,7 +78,7 @@ class BiEncoderModel(MegatronModule):
     def forward(self, query_tokens, query_attention_mask, query_types,
                 context_tokens, context_attention_mask, context_types):
         """Run a forward pass for each of the models and
         return the respective embeddings."""

         if self.use_query_model:
     ...

@@ -145,7 +145,7 @@ class BiEncoderModel(MegatronModule):
             state_dict[self._context_key], strict=strict)

     def init_state_dict_from_bert(self):
         """Initialize the state from a pretrained BERT model
         on iteration zero of ICT pretraining"""
         args = get_args()
     ...

@@ -160,11 +160,6 @@ class BiEncoderModel(MegatronModule):
             iteration = int(f.read().strip())
             assert iteration > 0

-        #for param in self.query_model.language_model.parameters():
-        #    print(param.data)
-        #    break
-        #sys.exit()

         checkpoint_name = get_checkpoint_name(args.bert_load, iteration, False)
         if mpu.get_data_parallel_rank() == 0:
             print('global rank {} is loading BERT checkpoint {}'.format(
     ...

@@ -193,17 +188,13 @@ class BiEncoderModel(MegatronModule):
         if self.query_model is not None and self.projection_dim > 0:
             self.context_model.projection_enc.load_state_dict \
                 (query_proj_state_dict)

-        #for param in self.query_model.language_model.parameters():
-        #    print(param.data)
-        #    #sys.exit()


 class PretrainedBertModel(MegatronModule):
     """BERT-based encoder for queries or contexts used for
     learned information retrieval."""

     def __init__(self, num_tokentypes=2,
                  parallel_output=True):
         super(PretrainedBertModel, self).__init__()
     ...

@@ -242,7 +233,7 @@ class PretrainedBertModel(MegatronModule):
             tokentype_ids=tokentype_ids)

         # This mask will be used in average-pooling and max-pooling
         pool_mask = (input_ids == self.pad_id).unsqueeze(2)

         # Taking the representation of the [CLS] token of BERT
         if self.pool_type == "cls-token":
             pooled_output = lm_output[:, 0, :]
     ...

@@ -256,7 +247,7 @@ class PretrainedBertModel(MegatronModule):
         # Converting to float16 dtype
         pooled_output = pooled_output.to(lm_output.dtype)

         # Output.
         if self.projection_dim:
             pooled_output = self.projection_enc(pooled_output)
     ...
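Side note: the `pool_type == "cls-token"` branch above takes the first token's hidden state as the sequence embedding, while the pad-based `pool_mask` supports the average/max pooling variants. A minimal sketch of those two pooling strategies; the tensor shapes and the pad id value are illustrative, not the class's actual attributes:

```python
import torch

# Illustrative shapes: batch of 2 sequences, length 4, hidden size 3.
lm_output = torch.randn(2, 4, 3)           # [batch, seq, hidden]
input_ids = torch.tensor([[101, 7, 8, 0],  # 0 is an assumed pad id
                          [101, 9, 0, 0]])
pad_id = 0

# CLS-token pooling: take the first position, as in the "cls-token" branch.
cls_pooled = lm_output[:, 0, :]            # [batch, hidden]

# Average pooling: zero out padded positions before summing.
pool_mask = (input_ids == pad_id).unsqueeze(2)   # True where padded
masked = lm_output.masked_fill(pool_mask, 0.0)
lengths = (~pool_mask).sum(dim=1).clamp(min=1)   # real tokens per sequence
avg_pooled = masked.sum(dim=1) / lengths         # [batch, hidden]

print(cls_pooled.shape, avg_pooled.shape)        # torch.Size([2, 3]) twice
```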
megatron/training.py

@@ -48,7 +48,7 @@ from megatron.model import get_params_for_weight_decay_optimization
 from megatron.model.realm_model import ICTBertModel
 from megatron.utils import check_adlr_autoresume_termination
 from megatron.data.data_loaders import build_pretraining_data_loader
-from megatron.utils import report_memory, params_grad_norm, params_global_norm, print_model, print_grads
+from megatron.utils import report_memory


 def print_datetime(string):
     ...

@@ -648,7 +648,6 @@ def train_step(forward_step_func, data_iterator,
     if args.fp16:
         optimizer.update_master_grads()
     timers('backward-master-grad').stop()

-    grad_norm_local = None
     # Clipping gradients helps prevent the exploding gradient.
     timers('backward-clip-grad').start()
     ...

@@ -663,30 +662,14 @@ def train_step(forward_step_func, data_iterator,
             mpu.clip_grad_norm(parameters, args.clip_grad,
                                parameter_names=parameter_names)
         else:
-            grad_norm_local = optimizer.clip_master_grads(args.clip_grad)
+            optimizer.clip_master_grads(args.clip_grad)
     timers('backward-clip-grad').stop()

-    #print_rank_0("print-grad_norm_local {}".format(grad_norm_local))
-    #print_rank_0("after backward")
-    #print_grads(model)
-    #print_model(model)
-    #print_rank_0(params_global_norm(model))
-    #print_rank_0(params_grad_norm(model))

     # Update parameters.
     timers('optimizer').start()
     optimizer.step()
     timers('optimizer').stop()

-    #print_rank_0("after optimizer")
-    #print_model(model)
-    #print_rank_0(params_global_norm(model))
-    #print_rank_0(params_grad_norm(model))
-    #sys.exit()
-    #print_rank_0("print-optimizer.overflow {}".format(optimizer.overflow))

     # Update learning rate.
     skipped_iter = 0
     if not (args.fp16 and optimizer.overflow):
     ...

@@ -861,10 +844,6 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
     # Iterations.
     iteration = args.iteration

-    #print_rank_0("Check betas before iterations")
-    #for group in optimizer.optimizer.param_groups:
-    #    print_rank_0("betas {} lr {} weight_decay {} eps {}".format(group['betas'], group['lr'], group['weight_decay'], group['eps']))

     timers('interval time').start()
     print_datetime('before the start of training step')
     report_memory_flag = True
     ...
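Side note: the clipping path above caps the global gradient norm at `args.clip_grad` before `optimizer.step()`, via `mpu.clip_grad_norm` or apex's `clip_master_grads` under fp16. A minimal sketch of the same step ordering with stock PyTorch utilities, not the Megatron/apex code paths themselves (the toy model and clip value are arbitrary):

```python
import torch

# Tiny illustrative model and optimizer; clip_grad plays the role of
# args.clip_grad in the hunk above.
model = torch.nn.Linear(8, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
clip_grad = 1.0

x = torch.randn(16, 8)
loss = model(x).pow(2).mean()
loss.backward()

# Rescale all gradients so their global L2 norm is at most clip_grad,
# then take the optimizer step -- the same ordering as in train_step().
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad)
optimizer.step()
optimizer.zero_grad()
print(float(grad_norm))
```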
megatron/utils.py

@@ -150,40 +150,4 @@ def get_ltor_masks_and_position_ids(data,

     return attention_mask, loss_mask, position_ids

-
-def params_grad_norm(model):
-    print_rank_0("params_grad_norm")
-    norm2 = torch.cuda.FloatTensor([0.0])
-    for param in model.parameters():
-        if param.grad is None:
-            continue
-        norm = torch.norm(param.grad.data.float(), 2)
-        norm2 += norm * norm
-    torch.distributed.all_reduce(norm2)
-    norm = norm2 ** 0.5
-    return norm.item()
-
-
-def params_global_norm(model):
-    print_rank_0("params_global_norm")
-    norm2 = torch.cuda.FloatTensor([0.0])
-    for param in model.parameters():
-        norm = torch.norm(param.data.float(), 2)
-        norm2 += norm * norm
-    torch.distributed.all_reduce(norm2)
-    norm = norm2 ** 0.5
-    return norm.item()
-
-
-def print_model(model):
-    print_rank_0("print-model")
-    for name, param in model.named_parameters():
-        if param.requires_grad:
-            #print("{} {}".format(name, param.data), flush=True)
-            print_rank_0("{} {}".format(name, param.data))
-    return
-
-
-def print_grads(model):
-    print_rank_0("print-grads")
-    for name, param in model.named_parameters():
-        if param.grad is None:
-            continue
-        print_rank_0("{} {}".format(name, param.grad))
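Side note: the deleted debug helpers compute a global L2 norm by summing squared per-parameter norms, all-reducing the sum across ranks, and taking a square root. A single-process sketch of that computation (no `torch.distributed` setup here, so the all-reduce step is only indicated in a comment, and the function name is illustrative):

```python
import torch

def params_global_norm_sketch(model):
    """Single-process sketch of the deleted params_global_norm helper."""
    norm2 = torch.zeros(1)
    for param in model.parameters():
        norm = torch.norm(param.data.float(), 2)
        norm2 += norm * norm
    # In the original helper the squared norm is summed across ranks here:
    # torch.distributed.all_reduce(norm2)
    return (norm2 ** 0.5).item()

model = torch.nn.Linear(4, 4)
print(params_global_norm_sketch(model))
```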
pretrain_ict.py

@@ -91,31 +91,16 @@ def forward_step(data_iterator, model, input_tensor):
     query_types = torch.cuda.LongTensor(*query_tokens.shape).fill_(0)
     context_types = torch.cuda.LongTensor(*context_tokens.shape).fill_(0)

-    #print_rank_0(query_tokens)
-    #print_rank_0(context_tokens)
-    #print_rank_0(torch.sum(query_types))
-    #print_rank_0(torch.sum(query_mask))
-    #print_rank_0(torch.sum(context_types))
-    #print_rank_0(torch.sum(context_mask))
-    #print_rank_0(params_global_norm(model))
-    #print_rank_0(params_grad_norm(model))

     # Forward model.
     query_logits, context_logits = model(query_tokens, query_mask,
                                          query_types, context_tokens,
                                          context_mask, context_types)

-    #print_rank_0(query_logits)
-    #print_rank_0(context_logits)

     micro_batch_size = query_logits.shape[0]
     # recall we assert that tensor_model_parallel_size == 1
     global_batch_size = dist.get_world_size() * micro_batch_size
     all_query_logits = AllgatherFromDataParallelRegion.apply(query_logits)
+    all_context_logits = AllgatherFromDataParallelRegion.apply(context_logits)

-    #global_batch_size = micro_batch_size
-    #all_query_logits = query_logits
-    #all_context_logits = context_logits
-    all_context_logits = AllgatherFromDataParallelRegion.apply(context_logits)

     # scores are inner products between query and context embeddings
     retrieval_scores = torch.matmul(all_query_logits,
     ...

@@ -141,17 +126,10 @@ def forward_step(data_iterator, model, input_tensor):
     # Scale the retrieval loss
     loss = loss * mpu.get_data_parallel_world_size()

-    #retrieval_loss = torch.nn.CrossEntropyLoss()(retrieval_scores, torch.arange(global_batch_size).long().cuda())
-    #retrieval_loss = retrieval_loss.float()
-    #averaged_losses = average_losses_across_data_parallel_group([retrieval_loss, *topk_accs])

     # create stats_dict with retrieval loss and all specified top-k accuracies
     topk_acc_dict = {'top{}_acc'.format(k): v * 100 for k, v in \
                      zip(args.report_topk_accuracies, reduced_losses[1:])}
     stats_dict = dict(loss=reduced_losses[0], **topk_acc_dict)

-    #print_rank_0(loss)
-    #print_rank_0(stats_dict)
-    #sys.exit()

     return loss, stats_dict
     ...
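Side note: `forward_step` gathers query/context embeddings across data-parallel ranks, takes inner products as retrieval scores, and trains against in-batch negatives; the removed commented line above (`torch.arange(global_batch_size)`) shows that the target for row i is column i, i.e. each query's own context is the positive. A minimal single-process sketch of that scoring and loss, with random embeddings standing in for the model outputs and the allgather step omitted:

```python
import torch
import torch.nn.functional as F

# Stand-ins for all_query_logits / all_context_logits after the allgather;
# a single process with global_batch_size = 8 and hidden size 128.
global_batch_size, hidden = 8, 128
all_query_logits = torch.randn(global_batch_size, hidden)
all_context_logits = torch.randn(global_batch_size, hidden)

# Scores are inner products between query and context embeddings.
retrieval_scores = torch.matmul(all_query_logits, all_context_logits.t())

# Each query's matching context sits on the diagonal, so the targets are
# simply 0..global_batch_size-1 (in-batch negatives everywhere else).
targets = torch.arange(global_batch_size)
retrieval_loss = F.cross_entropy(retrieval_scores, targets)

# Top-1 retrieval accuracy over the batch, analogous to the top-k stats.
top1_acc = (retrieval_scores.argmax(dim=1) == targets).float().mean()
print(float(retrieval_loss), float(top1_acc))
```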