"git@developer.sourcefind.cn:OpenDAS/pytorch3d.git" did not exist on "0a7c354dc1d2e2a6642ef8fbb28a6ad1da89283e"
Commit cb6f96b6 authored by Lawrence McAfee

wip; switching to grad-buffer-centric design

parent a3f3c3ad
@@ -130,9 +130,11 @@ def parse_args(extra_args_provider=None, defaults={},
                 args.global_batch_size), flush=True)
     assert args.global_batch_size > 0
     if args.num_layers_per_virtual_pipeline_stage is not None:
-        assert args.pipeline_model_parallel_size > 2, \
-            'pipeline-model-parallel size should be greater than 2 with ' \
-            'interleaved schedule'
+        # >>> [ temporarily turning off ]
+        # assert args.pipeline_model_parallel_size > 2, \
+        #     'pipeline-model-parallel size should be greater than 2 with ' \
+        #     'interleaved schedule'
+        # <<<
         assert args.num_layers % args.num_layers_per_virtual_pipeline_stage == 0, \
             'number of layers is not divisible by number of layers per virtual ' \
             'pipeline stage'
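For orientation, a small worked example of the arguments this hunk touches (illustrative values, not taken from the commit): with the pipeline-size check commented out, the interleaved schedule can be exercised even on a 2-stage pipeline, provided the remaining divisibility check still passes. The derived chunk count below follows the usual Megatron convention (layers per pipeline rank divided by layers per virtual stage); treat the exact formula as an assumption.

num_layers = 24                              # illustrative
pipeline_model_parallel_size = 2             # now permitted with interleaving
num_layers_per_virtual_pipeline_stage = 3    # illustrative

# The check that remains active in the hunk above.
assert num_layers % num_layers_per_virtual_pipeline_stage == 0

layers_per_pipeline_rank = num_layers // pipeline_model_parallel_size      # 12
virtual_pipeline_model_parallel_size = (
    layers_per_pipeline_rank // num_layers_per_virtual_pipeline_stage)     # 4 model chunks per rank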
@@ -97,11 +97,11 @@ def get_megatron_optimizer(model,
     # from lutil import pax
     # pax(0, {
     #     "model" : model,
-    #     "param_groups" : param_groups,
-    #     "param_groups / 0" : param_groups[0],
-    #     "param_groups / 0 / params" : param_groups[0]["params"],
-    #     "param_groups / 1" : param_groups[1],
-    #     "param_groups / 1 / params" : param_groups[1]["params"],
+    #     # "param_groups" : param_groups,
+    #     # "param_groups / 0" : param_groups[0],
+    #     # "param_groups / 0 / params" : param_groups[0]["params"],
+    #     # "param_groups / 1" : param_groups[1],
+    #     # "param_groups / 1 / params" : param_groups[1]["params"],
     # })
     # <<<
@@ -164,7 +164,8 @@ def get_megatron_optimizer(model,
                                    params_have_main_grad,
                                    args.use_contiguous_buffers_in_local_ddp,
                                    args.bf16,
-                                   grad_scaler)
+                                   grad_scaler,
+                                   model)
     # <<<

     # FP32.
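The two hunks above carry the "grad-buffer-centric" direction named in the commit message: get_megatron_optimizer starts passing model (the DDP-wrapped chunks) into the mixed-precision optimizer alongside grad_scaler. A minimal sketch of what that enables, assuming each chunk is Megatron's local DistributedDataParallel with use_contiguous_buffers_in_local_ddp=True and therefore exposes a _grad_buffers dict (dtype -> contiguous buffer); the attribute name and buffer API are assumptions about the wrapper, not code from this commit.

class GradBufferCentricOptimizerSketch:
    """Sketch only: hold references to each model chunk's contiguous grad
    buffers so gradient handling (zeroing, reduction, copying to main params)
    can operate on whole buffers instead of per-parameter .main_grad tensors."""

    def __init__(self, optimizer, grad_scaler, model):
        self.optimizer = optimizer        # underlying Adam/SGD instance
        self.grad_scaler = grad_scaler    # loss-scaling helper for fp16
        self.models = model               # list of DDP-wrapped model chunks
        # Assumption: each chunk exposes `_grad_buffers`, mapping a torch dtype
        # to one contiguous buffer covering that chunk's gradients.
        self.grad_buffers = [chunk._grad_buffers for chunk in model]

    def zero_grad_buffers(self):
        # Zero whole buffers rather than iterating over parameters.
        for buffers in self.grad_buffers:
            for buf in buffers.values():
                buf.zero()    # assumed MemoryBuffer-style API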
@@ -365,8 +365,12 @@ def setup_model_and_optimizer(model_provider_func,
     unwrapped_model = unwrap_model(model,
                                    (torchDDP, LocalDDP, Float16Module))
-    optimizer = get_megatron_optimizer(unwrapped_model, no_wd_decay_cond,
-                                       scale_lr_cond, lr_mult)
+    # >>>
+    # optimizer = get_megatron_optimizer(unwrapped_model, no_wd_decay_cond,
+    #                                    scale_lr_cond, lr_mult)
+    optimizer = get_megatron_optimizer(model, no_wd_decay_cond,
+                                       scale_lr_cond, lr_mult)
+    # <<<

     opt_param_scheduler = get_optimizer_param_scheduler(optimizer)
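This call-site change mirrors the signature change above: setup_model_and_optimizer now hands get_megatron_optimizer the wrapped model chunks rather than unwrapped_model, so the optimizer can reach both the raw parameters (by unwrapping internally) and the DDP grad buffers. A toy illustration of the shape difference, using stand-in classes rather than Megatron's (the _grad_buffers attribute again being an assumption about the wrapper):

class ToyModule:
    """Stand-in for a bare model chunk."""

class ToyLocalDDP:
    """Stand-in for a local DDP wrapper holding a contiguous grad buffer."""
    def __init__(self, module):
        self.module = module
        self._grad_buffers = {}   # dtype -> contiguous buffer (assumed layout)

model = [ToyLocalDDP(ToyModule())]                    # what training keeps around
unwrapped_model = [chunk.module for chunk in model]   # what was passed before this hunk

# Before: get_megatron_optimizer(unwrapped_model, ...) -> no path to grad buffers.
# After:  get_megatron_optimizer(model, ...)           -> buffers stay reachable.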