"tests/git@developer.sourcefind.cn:OpenDAS/pytorch3d.git" did not exist on "32e1992924929a9b79e880ed6f5bdc74089e8c73"
Commit 787882a0 authored by Lawrence McAfee

cleaned training.py

parent 0f2a9f19
megatron/training.py
@@ -52,9 +52,6 @@ from megatron.schedules import get_forward_backward_func
 from megatron.utils import report_memory
 from megatron.model.vision.knn_monitor import compute_feature_bank
-# >>>
-from lutil import pax
-# <<<
 
 
 def print_datetime(string):
     """Note that this call will sync across all ranks."""
@@ -364,16 +361,11 @@ def setup_model_and_optimizer(model_provider_func,
     args = get_args()
 
     model = get_model(model_provider_func, model_type)
     unwrapped_model = unwrap_model(model,
                                    (torchDDP, LocalDDP, Float16Module))
 
-    # >>>
-    # optimizer = get_megatron_optimizer(unwrapped_model, no_wd_decay_cond,
-    #                                    scale_lr_cond, lr_mult)
     optimizer = get_megatron_optimizer(model, no_wd_decay_cond,
                                        scale_lr_cond, lr_mult)
-    # <<<
-
     opt_param_scheduler = get_optimizer_param_scheduler(optimizer)
 
     if args.load is not None:
@@ -405,8 +397,7 @@ def setup_model_and_optimizer(model_provider_func,
 
 
 def train_step(forward_step_func, data_iterator,
-               model, optimizer, opt_param_scheduler,
-               ITERATION):
+               model, optimizer, opt_param_scheduler):
     """Single training step."""
     args = get_args()
     timers = get_timers()
@@ -417,50 +408,35 @@ def train_step(forward_step_func, data_iterator,
         partition.zero_grad_buffer()
     optimizer.zero_grad()
 
-    # >>>
     # Forward pass.
-    # <<<
     forward_backward_func = get_forward_backward_func()
     losses_reduced = forward_backward_func(
         forward_step_func, data_iterator, model,
         optimizer, timers, forward_only=False)
 
-    # >>>
     # Empty unused memory.
-    # <<<
     if args.empty_unused_memory_level >= 1:
         torch.cuda.empty_cache()
 
-    # >>>
-    # optimizer.debug_model(ITERATION, "before reduce grads.", 1)
-    # <<<
-
-    # >>>
     # Reduce gradients.
     optimizer.reduce_model_grads(args, timers)
-    # <<<
 
+    # Vision gradients.
     if args.vision_pretraining and args.vision_pretraining_type == "dino":
         unwrapped_model = unwrap_model(model[0],
                                        (torchDDP, LocalDDP, Float16Module))
         unwrapped_model.cancel_gradients_last_layer(args.curr_iteration)
 
     # Update parameters.
     timers('optimizer').start()
-    update_successful, grad_norm, num_zeros_in_grad = optimizer.step(args, timers, ITERATION)
+    update_successful, grad_norm, num_zeros_in_grad = optimizer.step(args, timers)
     timers('optimizer').stop()
 
-    # >>>
     # Gather params.
     if update_successful:
-        optimizer.gather_model_params(args, timers, ITERATION)
+        optimizer.gather_model_params(args, timers)
-    # <<<
-
-    # >>>
-    # optimizer.debug_model(ITERATION, "after gather params.", 0)
-    # <<<
-
 
+    # Vision momentum.
     if args.vision_pretraining and args.vision_pretraining_type == "dino":
         unwrapped_model = unwrap_model(model[0],
                                        (torchDDP, LocalDDP, Float16Module))
@@ -476,9 +452,7 @@ def train_step(forward_step_func, data_iterator,
     else:
         skipped_iter = 1
 
-    # >>>
     # Empty unused memory.
-    # <<<
     if args.empty_unused_memory_level >= 2:
         torch.cuda.empty_cache()
 
@@ -706,10 +680,7 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
                        train_data_iterator,
                        model,
                        optimizer,
-                       opt_param_scheduler
-                       # >>>
-                       ,ITERATION = iteration)
-                       # <<<
+                       opt_param_scheduler)
         iteration += 1
         args.consumed_train_samples += mpu.get_data_parallel_world_size() * \
                                        args.micro_batch_size * \
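
For reference, the hunks above leave train_step with roughly the control flow below. This is a minimal sketch reconstructed only from the diff context shown on this page, not the full megatron/training.py: the DDP-impl guard around the grad-buffer loop, the per-microbatch loss averaging, and the vision-pretraining (dino) branches are elided, and the usual Megatron helpers (get_args, get_timers, get_forward_backward_func) plus torch are assumed to be importable as in the file header.

# Sketch only -- reconstructed from the diff context above; guards, loss
# averaging, and vision-pretraining branches from the full file are elided.
import torch
from megatron import get_args, get_timers
from megatron.schedules import get_forward_backward_func


def train_step(forward_step_func, data_iterator,
               model, optimizer, opt_param_scheduler):
    """Single training step."""
    args = get_args()
    timers = get_timers()

    # Zero the grad buffer of each model chunk and the optimizer's own grads
    # (the DDP-impl guard from the full file is elided here).
    for partition in model:
        partition.zero_grad_buffer()
    optimizer.zero_grad()

    # Forward/backward over all microbatches.
    forward_backward_func = get_forward_backward_func()
    losses_reduced = forward_backward_func(
        forward_step_func, data_iterator, model,
        optimizer, timers, forward_only=False)

    # Optionally release cached GPU memory.
    if args.empty_unused_memory_level >= 1:
        torch.cuda.empty_cache()

    # Reduce gradients across data-parallel replicas.
    optimizer.reduce_model_grads(args, timers)

    # Update parameters; optimizer.step() no longer takes ITERATION.
    timers('optimizer').start()
    update_successful, grad_norm, num_zeros_in_grad = optimizer.step(args, timers)
    timers('optimizer').stop()

    # Gather updated params (distributed-optimizer path) on a successful step.
    if update_successful:
        optimizer.gather_model_params(args, timers)

    skipped_iter = 0 if update_successful else 1
    if args.empty_unused_memory_level >= 2:
        torch.cuda.empty_cache()

    # The full file first reduces losses_reduced into a single loss dict.
    return losses_reduced, skipped_iter, grad_norm, num_zeros_in_grad

Note also that the setup_model_and_optimizer hunk keeps passing the DDP-wrapped model, rather than unwrapped_model, to get_megatron_optimizer; presumably this is so the optimizer can reach the contiguous grad buffers owned by the DDP wrapper, though that rationale is not stated in the commit itself.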