Commit dfd428cb authored by Jared Casper's avatar Jared Casper
Browse files

Merge branch 'main_retriver_merge' into 'main'

ICT Retriever

See merge request ADLR/megatron-lm!235
parents 68233932 3f3ba5e7
...@@ -317,14 +317,12 @@ def setup_model_and_optimizer(model_provider_func): ...@@ -317,14 +317,12 @@ def setup_model_and_optimizer(model_provider_func):
assert args.DDP_impl == 'local' assert args.DDP_impl == 'local'
# get model without FP16 and/or TorchDDP wrappers # get model without FP16 and/or TorchDDP wrappers
model = unwrap_model(model) if args.iteration == 0 and len(unwrapped_model) == 1 \
for module in model: and hasattr(unwrapped_model[0], 'init_state_dict_from_bert'):
if args.iteration == 0 and hasattr(module, print_rank_0("Initializing ICT from pretrained BERT model")
'init_state_dict_from_bert'): unwrapped_model[0].init_state_dict_from_bert()
print("Initializing ICT from pretrained BERT model", flush=True) if args.fp16:
module.init_state_dict_from_bert() optimizer.reload_model_params()
if args.fp16:
optimizer.reload_model_params()
return model, optimizer, lr_scheduler return model, optimizer, lr_scheduler
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment