"src/vscode:/vscode.git/clone" did not exist on "98b6bee1a13d225167ef1b8cbab910f8b4f25d3c"
Commit 787882a0 authored by Lawrence McAfee

cleaned training.py

parent 0f2a9f19
@@ -52,9 +52,6 @@ from megatron.schedules import get_forward_backward_func
 from megatron.utils import report_memory
 from megatron.model.vision.knn_monitor import compute_feature_bank
-# >>>
-from lutil import pax
-# <<<


 def print_datetime(string):
     """Note that this call will sync across all ranks."""
@@ -364,16 +361,11 @@ def setup_model_and_optimizer(model_provider_func,
     args = get_args()
     model = get_model(model_provider_func, model_type)
     unwrapped_model = unwrap_model(model,
                                    (torchDDP, LocalDDP, Float16Module))
-    # >>>
-    # optimizer = get_megatron_optimizer(unwrapped_model, no_wd_decay_cond,
-    #                                     scale_lr_cond, lr_mult)
     optimizer = get_megatron_optimizer(model, no_wd_decay_cond,
                                        scale_lr_cond, lr_mult)
-    # <<<
     opt_param_scheduler = get_optimizer_param_scheduler(optimizer)
     if args.load is not None:
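The net effect of this hunk is that get_megatron_optimizer is now handed the wrapped model list (model) rather than unwrapped_model, and the debug markers around the swap are gone, presumably so the optimizer can reach the DDP gradient buffers used by reduce_model_grads and gather_model_params further down. A condensed sketch of how the surrounding code reads after the commit; the helpers (get_model, unwrap_model, get_optimizer_param_scheduler) and the wrapper classes (torchDDP, LocalDDP, Float16Module) are assumed from the rest of training.py and are not part of this diff:

    # `model` is the wrapped model list; `unwrapped_model` strips the
    # DDP / fp16 wrappers and is still used elsewhere (e.g. checkpoint load).
    model = get_model(model_provider_func, model_type)
    unwrapped_model = unwrap_model(model,
                                   (torchDDP, LocalDDP, Float16Module))

    # The optimizer is now built from the wrapped model list, not from
    # unwrapped_model (the commented-out call this commit deletes).
    optimizer = get_megatron_optimizer(model, no_wd_decay_cond,
                                       scale_lr_cond, lr_mult)
    opt_param_scheduler = get_optimizer_param_scheduler(optimizer)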
@@ -405,8 +397,7 @@ def setup_model_and_optimizer(model_provider_func,
 def train_step(forward_step_func, data_iterator,
-               model, optimizer, opt_param_scheduler,
-               ITERATION):
+               model, optimizer, opt_param_scheduler):
     """Single training step."""
     args = get_args()
     timers = get_timers()
@@ -417,50 +408,35 @@ def train_step(forward_step_func, data_iterator,
             partition.zero_grad_buffer()
     optimizer.zero_grad()

-    # >>>
     # Forward pass.
-    # <<<
     forward_backward_func = get_forward_backward_func()
     losses_reduced = forward_backward_func(
         forward_step_func, data_iterator, model,
         optimizer, timers, forward_only=False)

-    # >>>
     # Empty unused memory.
-    # <<<
     if args.empty_unused_memory_level >= 1:
         torch.cuda.empty_cache()

-    # >>>
-    # optimizer.debug_model(ITERATION, "before reduce grads.", 1)
-    # <<<
-    # >>>
     # Reduce gradients.
     optimizer.reduce_model_grads(args, timers)
-    # <<<

     # Vision gradients.
     if args.vision_pretraining and args.vision_pretraining_type == "dino":
         unwrapped_model = unwrap_model(model[0],
                                        (torchDDP, LocalDDP, Float16Module))
         unwrapped_model.cancel_gradients_last_layer(args.curr_iteration)

     # Update parameters.
     timers('optimizer').start()
-    update_successful, grad_norm, num_zeros_in_grad = optimizer.step(args, timers, ITERATION)
+    update_successful, grad_norm, num_zeros_in_grad = optimizer.step(args, timers)
     timers('optimizer').stop()

-    # >>>
     # Gather params.
     if update_successful:
-        optimizer.gather_model_params(args, timers, ITERATION)
-    # <<<
-    # >>>
-    # optimizer.debug_model(ITERATION, "after gather params.", 0)
-    # <<<
+        optimizer.gather_model_params(args, timers)

     # Vision momentum.
     if args.vision_pretraining and args.vision_pretraining_type == "dino":
         unwrapped_model = unwrap_model(model[0],
                                        (torchDDP, LocalDDP, Float16Module))
@@ -476,9 +452,7 @@ def train_step(forward_step_func, data_iterator,
     else:
         skipped_iter = 1

-    # >>>
     # Empty unused memory.
-    # <<<
     if args.empty_unused_memory_level >= 2:
         torch.cuda.empty_cache()
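Pieced together, the three train_step hunks above amount to this: the ITERATION debug argument disappears from the signature and from the optimizer calls, and the # >>> / # <<< scaffolding around the kept lines is stripped, leaving the optimizer API as reduce_model_grads(args, timers), step(args, timers) and gather_model_params(args, timers). A condensed sketch of the cleaned function, assuming training.py's existing imports (torch, get_args, get_timers, get_forward_backward_func) and eliding the guard conditions, vision-pretraining branches, learning-rate increment computation, and loss-reduction tail that are not shown in this diff:

def train_step(forward_step_func, data_iterator,
               model, optimizer, opt_param_scheduler):
    """Single training step (condensed sketch of the post-commit code)."""
    args = get_args()
    timers = get_timers()

    # Zero the contiguous grad buffer of each model chunk, then the
    # optimizer's own gradients (guard conditions elided).
    for partition in model:
        partition.zero_grad_buffer()
    optimizer.zero_grad()

    # Forward pass.
    forward_backward_func = get_forward_backward_func()
    losses_reduced = forward_backward_func(
        forward_step_func, data_iterator, model,
        optimizer, timers, forward_only=False)

    # Empty unused memory.
    if args.empty_unused_memory_level >= 1:
        torch.cuda.empty_cache()

    # Reduce gradients.
    optimizer.reduce_model_grads(args, timers)

    # Update parameters.
    timers('optimizer').start()
    update_successful, grad_norm, num_zeros_in_grad = optimizer.step(args, timers)
    timers('optimizer').stop()

    # Gather params.
    if update_successful:
        optimizer.gather_model_params(args, timers)

    # Advance the learning-rate schedule only when the step was applied,
    # and record whether it was skipped (e.g. fp16 loss-scale overflow).
    if update_successful:
        opt_param_scheduler.step(increment=...)  # increment computation elided
        skipped_iter = 0
    else:
        skipped_iter = 1

    # Empty unused memory.
    if args.empty_unused_memory_level >= 2:
        torch.cuda.empty_cache()

    # Loss averaging across microbatches and the final return of
    # (loss dict, skipped_iter, grad_norm, num_zeros_in_grad) are not
    # shown in this diff and are elided here.
    ...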
@@ -706,10 +680,7 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
                        train_data_iterator,
                        model,
                        optimizer,
-                       opt_param_scheduler
-                       # >>>
-                       ,ITERATION = iteration)
-                       # <<<
+                       opt_param_scheduler)
         iteration += 1
         args.consumed_train_samples += mpu.get_data_parallel_world_size() * \
                                        args.micro_batch_size * \
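At the call site inside train()'s main loop, the only change is that the trailing ITERATION=iteration keyword and its markers go away. For context, a short sketch of how the call reads after the commit; the unpacking of train_step's return values is assumed from upstream Megatron-LM and is not part of the hunk shown:

        # Inside the while-loop of train(): one optimization step per pass.
        loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = \
            train_step(forward_step_func,
                       train_data_iterator,
                       model,
                       optimizer,
                       opt_param_scheduler)
        iteration += 1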