wuxk1 / Megatron-LM — commit cc691cbf (parent e3e5ea89)
Authored Feb 10, 2021 by Deepak Narayanan

Move unwrap to megatron/utils.py and clean up imports in megatron/schedules.py
Showing 4 changed files with 50 additions and 44 deletions:
megatron/checkpointing.py   +5  -16
megatron/schedules.py       +22 -22
megatron/training.py        +5  -5
megatron/utils.py           +18 -1
megatron/checkpointing.py

@@ -21,12 +21,12 @@ import sys
 
 import numpy as np
 
 import torch
-from torch.nn.parallel import DistributedDataParallel as torchDDP
 
 from megatron import (get_args,
                       mpu,
                       print_rank_0,
-                      update_num_microbatches)
+                      update_num_microbatches,
+                      utils)
 
 _CHECKPOINT_VERSION = None

@@ -111,12 +111,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
     args = get_args()
 
     # Only rank zero of the data parallel writes to the disk.
-    unwrapped_model = []
-    for model_module in model:
-        if isinstance(model_module, torchDDP):
-            model_module = model_module.module
-        unwrapped_model.append(model_module)
-    model = unwrapped_model
+    model = utils.unwrap_model(model)
 
     print_rank_0('saving checkpoint at iteration {:7d} to {}'.format(
         iteration, args.save))

@@ -220,12 +215,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
     args = get_args()
     load_dir = getattr(args, load_arg)
 
-    unwrapped_model = []
-    for model_module in model:
-        if isinstance(model_module, torchDDP):
-            model_module = model_module.module
-        unwrapped_model.append(model_module)
-    model = unwrapped_model
+    model = utils.unwrap_model(model)
 
     # Read the tracker file and set the iteration.
     tracker_filename = get_checkpoint_tracker_filename(load_dir)

@@ -389,8 +379,7 @@ def load_ict_checkpoint(model, only_query_model=False, only_block_model=False, f
     args = get_args()
 
-    if isinstance(model, torchDDP):
-        model = model.module
+    model = utils.unwrap_model(model)
     load_path = args.load if from_realm_chkpt else args.ict_load
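The hunks above replace hand-rolled unwrap loops with a single call to the new utils.unwrap_model helper (added in the megatron/utils.py section below). A minimal sketch of the equivalence, runnable without torch.distributed: Wrapper is a hypothetical stand-in for torchDDP, and the helper body is copied from the utils.py hunk.

    import torch.nn as nn

    class Wrapper(nn.Module):
        """Hypothetical stand-in for torchDDP: holds the real model as .module."""
        def __init__(self, module):
            super().__init__()
            self.module = module

    def unwrap_model(model, module_instances=(Wrapper,)):
        # Body copied from the new helper in megatron/utils.py.
        return_list = True
        if not isinstance(model, list):
            model = [model]
            return_list = False
        unwrapped_model = []
        for model_module in model:
            while isinstance(model_module, module_instances):
                model_module = model_module.module
            unwrapped_model.append(model_module)
        if not return_list:
            return unwrapped_model[0]
        return unwrapped_model

    # A mixed list, as save_checkpoint may receive: wrapped and bare modules.
    wrapped = [Wrapper(nn.Linear(4, 4)), nn.Linear(4, 4)]
    bare = unwrap_model(wrapped)
    assert all(isinstance(m, nn.Linear) for m in bare)

Note that the helper's default module_instances=(torchDDP) is the bare class rather than a one-element tuple ((torchDDP,) would be a tuple), but isinstance accepts either a class or a tuple of classes, so both spellings behave the same.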
megatron/schedules.py

@@ -16,14 +16,10 @@
 
 import torch
 
 from megatron import get_args
+from megatron import get_num_microbatches
 from megatron import get_timers
 from megatron import mpu
-from megatron import get_num_microbatches
-from megatron.p2p_communication import recv_forward, recv_backward
-from megatron.p2p_communication import send_forward, send_backward
-from megatron.p2p_communication import send_forward_recv_backward, send_backward_recv_forward
-from megatron.p2p_communication import send_forward_recv_forward, send_backward_recv_backward
-from megatron.p2p_communication import send_forward_backward_recv_forward_backward
+from megatron import p2p_communication
 
 
 def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_reduced):

@@ -154,7 +150,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
     # Run warmup forward passes.
     mpu.set_virtual_pipeline_model_parallel_rank(0)
-    input_tensors[0].append(recv_forward(timers, use_ring_exchange=True))
+    input_tensors[0].append(p2p_communication.recv_forward(timers, use_ring_exchange=True))
     for k in range(num_warmup_microbatches):
         output_tensor = forward_step_helper(k)
         next_forward_model_chunk_id = get_model_chunk_id(k + 1, forward=True)

@@ -173,13 +169,14 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
             if mpu.is_pipeline_last_stage(ignore_virtual=True):
                 recv_next = False
             input_tensor, output_tensor_grad = \
-                send_forward_backward_recv_forward_backward(
+                p2p_communication.send_forward_backward_recv_forward_backward(
                     output_tensor, input_tensor_grad,
                     recv_prev=recv_prev, recv_next=recv_next,
                     timers=timers)
             output_tensor_grads[num_model_chunks-1].append(output_tensor_grad)
         else:
-            input_tensor = send_forward_recv_forward(output_tensor, recv_prev, timers)
+            input_tensor = \
+                p2p_communication.send_forward_recv_forward(output_tensor, recv_prev, timers)
         input_tensors[next_forward_model_chunk_id].append(input_tensor)
 
     # Run 1F1B in steady state.

@@ -238,7 +235,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
         # Communicate tensors.
         input_tensor, output_tensor_grad = \
-            send_forward_backward_recv_forward_backward(
+            p2p_communication.send_forward_backward_recv_forward_backward(
                 output_tensor, input_tensor_grad,
                 recv_prev=recv_prev, recv_next=recv_next,
                 timers=timers)

@@ -253,7 +250,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
     if not forward_only:
         if all_warmup_microbatches:
             output_tensor_grads[num_model_chunks-1].append(
-                recv_backward(timers, use_ring_exchange=True))
+                p2p_communication.recv_backward(timers, use_ring_exchange=True))
         for k in range(num_microbatches_remaining, num_microbatches):
             input_tensor_grad = backward_step_helper(k)
             next_backward_model_chunk_id = get_model_chunk_id(k + 1, forward=False)

@@ -264,7 +261,8 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
             if k == (num_microbatches - 1):
                 recv_next = False
             output_tensor_grads[next_backward_model_chunk_id].append(
-                send_backward_recv_backward(input_tensor_grad, recv_next, timers))
+                p2p_communication.send_backward_recv_backward(
+                    input_tensor_grad, recv_next, timers))
 
     return losses_reduced

@@ -294,7 +292,7 @@ def forward_backward_pipelining(forward_step_func, data_iterator, model,
     # Run warmup forward passes.
     for i in range(num_warmup_microbatches):
-        input_tensor = recv_forward(timers)
+        input_tensor = p2p_communication.recv_forward(timers)
         output_tensor = forward_step(forward_step_func, data_iterator, model,
                                      input_tensor, losses_reduced)
 
         # Barrier before first receive to measure forward stall.

@@ -302,7 +300,7 @@ def forward_backward_pipelining(forward_step_func, data_iterator, model,
             timers('forward-pipeline-stall').start()
             torch.distributed.barrier(group=mpu.get_pipeline_model_parallel_group())
             timers('forward-pipeline-stall').stop()
-        send_forward(output_tensor, timers)
+        p2p_communication.send_forward(output_tensor, timers)
 
         input_tensors.append(input_tensor)
         output_tensors.append(output_tensor)

@@ -317,7 +315,7 @@ def forward_backward_pipelining(forward_step_func, data_iterator, model,
     # If all microbatches are run in warmup / cooldown phase, then no need to
     # receive this tensor here.
     if num_microbatches_remaining > 0:
-        input_tensor = recv_forward(timers)
+        input_tensor = p2p_communication.recv_forward(timers)
 
     # Run 1F1B in steady state.
     for i in range(num_microbatches_remaining):

@@ -326,9 +324,10 @@ def forward_backward_pipelining(forward_step_func, data_iterator, model,
         output_tensor = forward_step(forward_step_func, data_iterator, model,
                                      input_tensor, losses_reduced)
         if forward_only:
-            send_forward(output_tensor, timers)
+            p2p_communication.send_forward(output_tensor, timers)
         else:
-            output_tensor_grad = send_forward_recv_backward(output_tensor, timers)
+            output_tensor_grad = \
+                p2p_communication.send_forward_recv_backward(output_tensor, timers)
 
         # Add input_tensor and output_tensor to end of list, then pop from the
         # start of the list for backward pass.

@@ -337,7 +336,7 @@ def forward_backward_pipelining(forward_step_func, data_iterator, model,
         if forward_only:
             if not last_iteration:
-                input_tensor = recv_forward(timers)
+                input_tensor = p2p_communication.recv_forward(timers)
         else:
             input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)

@@ -347,9 +346,10 @@ def forward_backward_pipelining(forward_step_func, data_iterator, model,
             if last_iteration:
                 input_tensor = None
-                send_backward(input_tensor_grad, timers)
+                p2p_communication.send_backward(input_tensor_grad, timers)
             else:
-                input_tensor = send_backward_recv_forward(input_tensor_grad, timers)
+                input_tensor = \
+                    p2p_communication.send_backward_recv_forward(input_tensor_grad, timers)
 
     # Run cooldown backward passes.
     if not forward_only:

@@ -357,12 +357,12 @@ def forward_backward_pipelining(forward_step_func, data_iterator, model,
             input_tensor = input_tensors.pop(0)
             output_tensor = output_tensors.pop(0)
 
-            output_tensor_grad = recv_backward(timers)
+            output_tensor_grad = p2p_communication.recv_backward(timers)
 
             input_tensor_grad = \
                 backward_step(optimizer, input_tensor, output_tensor,
                               output_tensor_grad)
 
-            send_backward(input_tensor_grad, timers)
+            p2p_communication.send_backward(input_tensor_grad, timers)
 
     return losses_reduced
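The import cleanup in this file replaces five name-by-name imports of the send/recv helpers with a single module import, so each call site now states where the helper comes from and the import block no longer has to track every new p2p function. The pattern, shown as fragments drawn from this diff (not standalone code):

    # before: helpers imported by name, called bare
    from megatron.p2p_communication import recv_forward, recv_backward
    input_tensor = recv_forward(timers)

    # after: one module import, module-qualified call sites
    from megatron import p2p_communication
    input_tensor = p2p_communication.recv_forward(timers)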
megatron/training.py

@@ -46,6 +46,7 @@ from megatron.learning_rates import AnnealingLR
 from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.model.realm_model import ICTBertModel
 from megatron.utils import check_adlr_autoresume_termination
+from megatron.utils import unwrap_model
 from megatron.data.data_samplers import build_pretraining_data_loader
 from megatron.utils import calc_params_l2_norm
 from megatron.schedules import forward_backward_no_pipelining

@@ -288,9 +289,8 @@ def setup_model_and_optimizer(model_provider_func):
     model = get_model(model_provider_func)
-    unwrapped_model = model
-    while isinstance(unwrapped_model, (torchDDP, LocalDDP, FP16Module)):
-        unwrapped_model = unwrapped_model.module
+    unwrapped_model = unwrap_model(model,
+                                   (torchDDP, LocalDDP, FP16Module))
 
     optimizer = get_megatron_optimizer(unwrapped_model)
     lr_scheduler = get_learning_rate_scheduler(optimizer)

@@ -370,8 +370,8 @@ def train_step(forward_step_func, data_iterator,
         unwrapped_model = model[0]
     elif mpu.is_pipeline_last_stage(ignore_virtual=True):
         unwrapped_model = model[-1]
-        while isinstance(unwrapped_model, (torchDDP, LocalDDP, FP16Module)):
-            unwrapped_model = unwrapped_model.module
+        unwrapped_model = unwrap_model(
+            unwrapped_model, (torchDDP, LocalDDP, FP16Module))
 
         if unwrapped_model.share_word_embeddings:
             word_embeddings_weight = unwrapped_model.word_embeddings_weight()
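Both call sites in training.py pass an explicit tuple of wrapper classes, and because unwrap_model peels with a while loop it removes nested wrappers, exactly like the while-isinstance loops it replaces (whereas the old single-level if-isinstance in checkpointing.py removed at most one layer). A runnable sketch of that behavior; the stand-in classes are hypothetical substitutes for torchDDP, LocalDDP, and FP16Module:

    import torch.nn as nn

    class _Wrap(nn.Module):
        def __init__(self, module):
            super().__init__()
            self.module = module

    class DDPStandIn(_Wrap):    # plays the role of torchDDP / LocalDDP
        pass

    class FP16StandIn(_Wrap):   # plays the role of FP16Module
        pass

    # DDP wrapping an FP16 wrapper wrapping the bare model, as in training.
    model = DDPStandIn(FP16StandIn(nn.Linear(2, 2)))

    # The while loop mirrors unwrap_model(model, (torchDDP, LocalDDP, FP16Module)):
    # it keeps peeling until the object is no longer any wrapper type.
    unwrapped = model
    while isinstance(unwrapped, (DDPStandIn, FP16StandIn)):
        unwrapped = unwrapped.module
    assert isinstance(unwrapped, nn.Linear)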
megatron/utils.py

@@ -18,6 +18,7 @@
 import sys
 
 import torch
+from torch.nn.parallel import DistributedDataParallel as torchDDP
 from apex.multi_tensor_apply import multi_tensor_applier
 import amp_C

@@ -26,11 +27,25 @@ from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_adlr_autoresume
 from megatron import mpu
-from megatron.checkpointing import save_checkpoint
 from megatron.model.module import param_is_not_shared
 from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
 
 
+def unwrap_model(model, module_instances=(torchDDP)):
+    return_list = True
+    if not isinstance(model, list):
+        model = [model]
+        return_list = False
+    unwrapped_model = []
+    for model_module in model:
+        while isinstance(model_module, module_instances):
+            model_module = model_module.module
+        unwrapped_model.append(model_module)
+    if not return_list:
+        return unwrapped_model[0]
+    return unwrapped_model
+
+
 def calc_params_l2_norm(model):
     """Calculate l2 norm of parameters """
     # Remove duplicate params.

@@ -106,6 +121,8 @@ def print_params_min_max_norm(optimizer, iteration):
 
 def check_adlr_autoresume_termination(iteration, model,
                                       optimizer, lr_scheduler):
     """Check for autoresume signal and exit if it is received."""
+    from megatron.checkpointing import save_checkpoint
+
     args = get_args()
     autoresume = get_adlr_autoresume()
     # Add barrier to ensure consistnecy.
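Note how the first and last hunks of this file work together: utils.py drops its module-level `from megatron.checkpointing import save_checkpoint` and re-imports it inside check_adlr_autoresume_termination. Since checkpointing.py now imports utils, a module-level import in both directions would be circular; deferring one side to call time breaks the cycle because the import only runs after both modules have finished initializing. A two-file sketch, where a.py and b.py are hypothetical stand-ins for megatron/utils.py and megatron/checkpointing.py:

    # --- b.py (like checkpointing.py): imports its sibling at module level ---
    import a

    def save():
        print("saved via", a.helper())

    # --- a.py (like utils.py): must NOT import b at module level, or whichever
    # --- module is imported first would see a partially initialized sibling ---
    def helper():
        return "a.helper"

    def check():
        import b    # deferred: executed at call time, when b is fully loaded
        b.save()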