OpenDAS / Megatron-LM / Commits / bfc20ecf

Commit bfc20ecf, authored Feb 01, 2021 by Mostofa Patwary
Parent: 0295bb89

    fixed issue with initializing ICT from pretrained BERT model

Changes: 2 changed files with 21 additions and 12 deletions (+21 -12)
    megatron/training.py    +15 -6
    pretrain_ict.py         +6  -6
megatron/training.py
...
@@ -320,6 +320,8 @@ def setup_model_and_optimizer(model_provider_func):
             'init_state_dict_from_bert'):
         print("Initializing ICT from pretrained BERT model", flush=True)
         unwrapped_model.init_state_dict_from_bert()
+        if args.fp16:
+            optimizer._model_params_to_master_params()
 
     return model, optimizer, lr_scheduler
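The two added lines are the substance of the fix: `init_state_dict_from_bert()` overwrites the fp16 model weights after the fp16 optimizer has already cloned fp32 master copies of the old parameters, so those master copies must be refreshed or the first optimizer step would step from the stale values. Below is a minimal sketch of that master-copy pattern with hypothetical class and method names, not Megatron's or Apex's actual optimizer API:

import torch

class Fp16OptimizerSketch:
    """Hypothetical fp16 wrapper: the optimizer steps on fp32 master copies."""
    def __init__(self, model_params, lr=1e-4):
        self.model_params = list(model_params)                  # fp16 params used in fwd/bwd
        self.master_params = [p.detach().clone().float()        # fp32 copies the optimizer updates
                              for p in self.model_params]
        self.inner = torch.optim.Adam(self.master_params, lr=lr)

    def model_params_to_master_params(self):
        # Refresh the fp32 master copies after the fp16 weights were replaced,
        # e.g. re-initialized from a pretrained BERT checkpoint.
        with torch.no_grad():
            for master, model in zip(self.master_params, self.model_params):
                master.copy_(model.float())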
...
@@ -646,6 +648,7 @@ def train_step(forward_step_func, data_iterator,
     if args.fp16:
         optimizer.update_master_grads()
     timers('backward-master-grad').stop()
 
+    grad_norm_local = None
     # Clipping gradients helps prevent the exploding gradient.
     timers('backward-clip-grad').start()
...
@@ -660,16 +663,16 @@ def train_step(forward_step_func, data_iterator,
             mpu.clip_grad_norm(parameters, args.clip_grad,
                                parameter_names=parameter_names)
         else:
-            optimizer.clip_master_grads(args.clip_grad)
+            grad_norm_local = optimizer.clip_master_grads(args.clip_grad)
     timers('backward-clip-grad').stop()
+    #print_rank_0("print-grad_norm_local {}".format(grad_norm_local))
 
     #print_rank_0("after backward")
     #print_grads(model)
-    print_model(model)
-    print_rank_0(params_global_norm(model))
-    print_rank_0(params_grad_norm(model))
+    #print_model(model)
+    #print_rank_0(params_global_norm(model))
+    #print_rank_0(params_grad_norm(model))
 
     # Update parameters.
     timers('optimizer').start()
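The `grad_norm_local` change keeps the value returned by gradient clipping instead of discarding it, so it can be inspected via the (commented) debug print. As a generic PyTorch illustration of the same idea, not the Megatron `mpu` / Apex clipping path used above, `torch.nn.utils.clip_grad_norm_` returns the total gradient norm it computed:

import torch

model = torch.nn.Linear(8, 4)
loss = model(torch.randn(2, 8)).sum()
loss.backward()

# Returns the total norm of the gradients before clipping; keep it for logging.
grad_norm_local = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
print("total grad norm before clipping:", float(grad_norm_local))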
...
@@ -678,9 +681,11 @@ def train_step(forward_step_func, data_iterator,
     #print_rank_0("after optimizer")
     #print_model(model)
-    print_rank_0(params_global_norm(model))
+    # print_rank_0(params_global_norm(model))
     #print_rank_0(params_grad_norm(model))
     #sys.exit()
+    #print_rank_0("print-optimizer.overflow {}".format(optimizer.overflow))
 
     # Update learning rate.
     skipped_iter = 0
...
@@ -856,6 +861,10 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
     # Iterations.
     iteration = args.iteration
 
+    #print_rank_0("Check betas before iterations")
+    #for group in optimizer.optimizer.param_groups:
+    #    print_rank_0("betas {} lr {} weight_decay {} eps {}".format(group['betas'], group['lr'], group['weight_decay'], group['eps']))
+
     timers('interval time').start()
     print_datetime('before the start of training step')
     report_memory_flag = True
...
pretrain_ict.py
...
@@ -109,13 +109,13 @@ def forward_step(data_iterator, model, input_tensor):
     micro_batch_size = query_logits.shape[0]
     # recall we assert that tensor_model_parallel_size == 1
-    # global_batch_size = dist.get_world_size() * micro_batch_size
-    # all_query_logits = AllgatherFromDataParallelRegion.apply(query_logits)
-    # all_context_logits = AllgatherFromDataParallelRegion.apply(context_logits)
-    global_batch_size = micro_batch_size
-    all_query_logits = query_logits
-    all_context_logits = context_logits
+    global_batch_size = dist.get_world_size() * micro_batch_size
+    all_query_logits = AllgatherFromDataParallelRegion.apply(query_logits)
+    all_context_logits = AllgatherFromDataParallelRegion.apply(context_logits)
+    # global_batch_size = micro_batch_size
+    # all_query_logits = query_logits
+    # all_context_logits = context_logits
 
     # scores are inner products between query and context embeddings
     retrieval_scores = torch.matmul(all_query_logits,
...
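This hunk re-enables gathering the query and context embeddings from every data-parallel rank before computing retrieval scores, so each query is scored against the full global batch of contexts rather than only its local micro-batch. A generic sketch of that gather-then-score pattern using plain torch.distributed (Megatron's AllgatherFromDataParallelRegion additionally routes gradients back through the gather, which plain all_gather below does not):

import torch
import torch.distributed as dist

def retrieval_scores_across_ranks(query_logits, context_logits):
    # Assumes torch.distributed is already initialized for the data-parallel group.
    world_size = dist.get_world_size()

    gathered_q = [torch.empty_like(query_logits) for _ in range(world_size)]
    gathered_c = [torch.empty_like(context_logits) for _ in range(world_size)]
    dist.all_gather(gathered_q, query_logits)
    dist.all_gather(gathered_c, context_logits)

    all_query_logits = torch.cat(gathered_q, dim=0)      # [global_batch, hidden]
    all_context_logits = torch.cat(gathered_c, dim=0)    # [global_batch, hidden]

    # Scores are inner products between query and context embeddings.
    return torch.matmul(all_query_logits, all_context_logits.t())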