OpenDAS / Megatron-LM
"tests/git@developer.sourcefind.cn:OpenDAS/pytorch3d.git" did not exist on "32e1992924929a9b79e880ed6f5bdc74089e8c73"
Commit 787882a0, authored Mar 14, 2022 by Lawrence McAfee

    cleaned training.py

Parent: 0f2a9f19
Showing 1 changed file with 7 additions and 36 deletions.

megatron/training.py (+7, -36) @ 787882a0

--- a/megatron/training.py
+++ b/megatron/training.py
@@ -52,9 +52,6 @@ from megatron.schedules import get_forward_backward_func
 from megatron.utils import report_memory
 from megatron.model.vision.knn_monitor import compute_feature_bank
-# >>>
-from lutil import pax
-# <<<
 
 
 def print_datetime(string):
     """Note that this call will sync across all ranks."""
@@ -364,16 +361,11 @@ def setup_model_and_optimizer(model_provider_func,
     args = get_args()
 
     model = get_model(model_provider_func, model_type)
     unwrapped_model = unwrap_model(model,
                                    (torchDDP, LocalDDP, Float16Module))
-    # >>>
-    # optimizer = get_megatron_optimizer(unwrapped_model, no_wd_decay_cond,
-    #                                    scale_lr_cond, lr_mult)
     optimizer = get_megatron_optimizer(model, no_wd_decay_cond,
                                        scale_lr_cond, lr_mult)
-    # <<<
 
     opt_param_scheduler = get_optimizer_param_scheduler(optimizer)
 
     if args.load is not None:
@@ -405,8 +397,7 @@ def setup_model_and_optimizer(model_provider_func,
 
 
 def train_step(forward_step_func, data_iterator,
-               model, optimizer, opt_param_scheduler,
-               ITERATION):
+               model, optimizer, opt_param_scheduler):
     """Single training step."""
     args = get_args()
     timers = get_timers()
@@ -417,50 +408,35 @@ def train_step(forward_step_func, data_iterator,
             partition.zero_grad_buffer()
     optimizer.zero_grad()
 
-    # >>>
     # Forward pass.
-    # <<<
     forward_backward_func = get_forward_backward_func()
     losses_reduced = forward_backward_func(
         forward_step_func, data_iterator, model,
         optimizer, timers, forward_only=False)
 
-    # >>>
     # Empty unused memory.
-    # <<<
     if args.empty_unused_memory_level >= 1:
         torch.cuda.empty_cache()
 
-    # >>>
-    # optimizer.debug_model(ITERATION, "before reduce grads.", 1)
-    # <<<
-    # >>>
     # Reduce gradients.
     optimizer.reduce_model_grads(args, timers)
-    # <<<
 
     # Vision gradients.
     if args.vision_pretraining and args.vision_pretraining_type == "dino":
         unwrapped_model = unwrap_model(model[0],
                                        (torchDDP, LocalDDP, Float16Module))
         unwrapped_model.cancel_gradients_last_layer(args.curr_iteration)
 
     # Update parameters.
     timers('optimizer').start()
-    update_successful, grad_norm, num_zeros_in_grad = optimizer.step(args, timers, ITERATION)
+    update_successful, grad_norm, num_zeros_in_grad = optimizer.step(args, timers)
     timers('optimizer').stop()
 
-    # >>>
     # Gather params.
     if update_successful:
-        optimizer.gather_model_params(args, timers, ITERATION)
-    # <<<
-    # >>>
-    # optimizer.debug_model(ITERATION, "after gather params.", 0)
-    # <<<
+        optimizer.gather_model_params(args, timers)
 
     # Vision momentum.
     if args.vision_pretraining and args.vision_pretraining_type == "dino":
         unwrapped_model = unwrap_model(model[0],
                                        (torchDDP, LocalDDP, Float16Module))
@@ -476,9 +452,7 @@ def train_step(forward_step_func, data_iterator,
     else:
         skipped_iter = 1
 
-    # >>>
     # Empty unused memory.
-    # <<<
     if args.empty_unused_memory_level >= 2:
         torch.cuda.empty_cache()
@@ -706,10 +680,7 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
                        train_data_iterator,
                        model,
                        optimizer,
-                       opt_param_scheduler
-                       # >>>
-                       , ITERATION = iteration)
-                       # <<<
+                       opt_param_scheduler)
         iteration += 1
         args.consumed_train_samples += mpu.get_data_parallel_world_size() * \
                                        args.micro_batch_size * \
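Net effect of the commit, read off the hunks above: the "# >>>" / "# <<<" debug scaffolding is removed (the "from lutil import pax" import and the commented-out optimizer.debug_model(...) calls), and the extra ITERATION argument is no longer threaded through train_step, optimizer.step, and optimizer.gather_model_params. Below is a minimal sketch of the cleaned-up train_step, assembled from the added and context lines of this diff rather than copied verbatim from megatron/training.py; the parts not shown in the diff are marked as elided.

# Sketch only: reconstructed from the added and context lines of this diff,
# not a verbatim copy of megatron/training.py after the commit.
from megatron import get_args, get_timers

def train_step(forward_step_func, data_iterator,
               model, optimizer, opt_param_scheduler):
    """Single training step (the extra ITERATION parameter is gone)."""
    args = get_args()
    timers = get_timers()

    # ... zero grads, forward/backward pass, gradient reduction, and
    # vision-specific hooks as shown in the context lines above ...

    # Update parameters: optimizer.step() is back to (args, timers) only.
    timers('optimizer').start()
    update_successful, grad_norm, num_zeros_in_grad = optimizer.step(args, timers)
    timers('optimizer').stop()

    # Gather params: likewise without the ITERATION argument.
    if update_successful:
        optimizer.gather_model_params(args, timers)

    # ... loss reduction, skipped_iter bookkeeping, and return values elided ...

At the call site inside train(), the matching change is simply dropping the trailing ", ITERATION = iteration)" continuation, so the call reverts to train_step(forward_step_func, train_data_iterator, model, optimizer, opt_param_scheduler), with iteration incremented on the following line as before.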