Commit 378db1cd authored by Jared Casper

Merge branch 'checkpoint_util_fix' into 'main'

Some quick fixes to checkpoint_util.

See merge request ADLR/megatron-lm!568
parents 8dbd0757 be1a575e
@@ -1012,8 +1012,9 @@ class ParallelTransformer(MegatronModule):
             import transformer_engine
         self.use_fp8 = args.fp8_e4m3 or args.fp8_hybrid
         self.fp8_recipe = None
-        self.fp8_group = mpu.get_data_parallel_group()
+        self.fp8_group = None
         if self.use_fp8:
+            self.fp8_group = mpu.get_data_parallel_group()
             if args.fp8_e4m3:
                 fp8_format = transformer_engine.common.recipe.Format.E4M3
             elif args.fp8_hybrid:
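For readers skimming the hunk above (the ParallelTransformer constructor): fp8_group now defaults to None and is only set to the data-parallel group when FP8 is actually requested, so runs without FP8 never query that group here. Below is a minimal, self-contained sketch of the guarded-initialization pattern; the helper argument is a stand-in for mpu.get_data_parallel_group(), not the real Megatron call.

from types import SimpleNamespace

def init_fp8_state(args, get_data_parallel_group):
    # Mirror the hunk above: decide whether FP8 is enabled, and only look up
    # the (distributed) data-parallel group if it is.
    use_fp8 = args.fp8_e4m3 or args.fp8_hybrid
    fp8_group = None
    if use_fp8:
        fp8_group = get_data_parallel_group()
    return use_fp8, fp8_group

# With FP8 disabled, no group lookup happens at all:
args = SimpleNamespace(fp8_e4m3=False, fp8_hybrid=False)
print(init_fp8_state(args, lambda: "data-parallel group"))  # (False, None)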
@@ -43,6 +43,7 @@ def _load_checkpoint(queue, args):
                 '--no-masked-softmax-fusion',
                 '--no-bias-gelu-fusion',
                 '--no-bias-dropout-fusion',
+                '--no-async-tensor-model-parallel-allreduce',
                 '--use-cpu-initialization',
                 '--micro-batch-size', '1',
                 '--no-load-optim',
@@ -101,7 +102,7 @@ def _load_checkpoint(queue, args):
         nonlocal consumed_valid_samples
         models = []
         for rank in range(count):
-            mpu.parallel_state.set_tensor_model_parallel_rank(rank)
+            mpu.set_tensor_model_parallel_rank(rank)
             model_ = [model_provider(pre_process, post_process).to(dtype)]
             margs.consumed_train_samples = 0
             margs.consumed_valid_samples = 0
......@@ -125,8 +126,8 @@ def _load_checkpoint(queue, args):
exit(1)
set_global_variables(margs)
mpu.parallel_state.set_tensor_model_parallel_world_size(margs.tensor_model_parallel_size)
mpu.parallel_state.set_pipeline_model_parallel_world_size(margs.pipeline_model_parallel_size)
mpu.set_tensor_model_parallel_world_size(margs.tensor_model_parallel_size)
mpu.set_pipeline_model_parallel_world_size(margs.pipeline_model_parallel_size)
fused_kernels.load(margs)
# Get true (non-padded) vocab size
@@ -164,7 +165,7 @@ def _load_checkpoint(queue, args):
     md.make_vocab_size_divisible_by = margs.make_vocab_size_divisible_by

     # Get first pipe stage
-    mpu.parallel_state.set_pipeline_model_parallel_rank(0)
+    mpu.set_pipeline_model_parallel_rank(0)
     post_process = pp_size == 1
     models = get_models(tp_size, md.params_dtype, True, post_process)
@@ -190,7 +191,7 @@ def _load_checkpoint(queue, args):
     total_layer_num = 0
     for pp_rank in range(pp_size):
         if pp_rank > 0:
-            mpu.parallel_state.set_pipeline_model_parallel_rank(pp_rank)
+            mpu.set_pipeline_model_parallel_rank(pp_rank)
             post_process = pp_rank == pp_size - 1
             models = get_models(tp_size, md.params_dtype, False, post_process)
         for layer_num in range(len(models[0].language_model.encoder.layers)):
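A note on the repeated mpu.parallel_state.* to mpu.* edits in the hunks above: the conversion tool runs in a single process and fakes the tensor/pipeline parallel ranks and world sizes by calling these setters directly, and after this MR it uses the helpers exposed at the mpu package level. The sketch below shows, with simplified stand-ins rather than the real mpu implementation, why plain module-level setters are enough for an offline tool.

# Simplified stand-ins for mpu.set_tensor_model_parallel_rank and
# mpu.set_pipeline_model_parallel_rank: plain module-level state, with no
# torch.distributed initialization required.
_TP_RANK = 0
_PP_RANK = 0

def set_tensor_model_parallel_rank(rank):
    global _TP_RANK
    _TP_RANK = rank

def set_pipeline_model_parallel_rank(rank):
    global _PP_RANK
    _PP_RANK = rank

def get_tensor_model_parallel_rank():
    return _TP_RANK

# The loader iterates over ranks the same way, building one model chunk per
# tensor-parallel rank without launching extra processes:
for rank in range(2):
    set_tensor_model_parallel_rank(rank)
    print("building model chunk for tensor-parallel rank",
          get_tensor_model_parallel_rank())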
@@ -242,7 +243,6 @@ def _load_checkpoint(queue, args):
     # Send BERT lm head and binary head if it exists
     if md.model_type == 'BERT':
         print("Sending LM Pooler")
-        queue.put("pooler")
         message = {
             "weight": models[0].language_model.pooler.dense.weight.data,
             "bias": models[0].language_model.pooler.dense.bias.data
@@ -258,8 +258,6 @@ def _load_checkpoint(queue, args):
         queue_put("lm head", message)

         if md.bert_binary_head:
             print("Sending BERT Binary head")
-            queue.put("binary head")
             message = {
                 "weight": models[0].binary_head.weight.data,
                 "bias": models[0].binary_head.bias.data
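Context for the two BERT-head hunks: the loader streams named weight dictionaries to the saver process over a queue, and the surrounding code already routes them through a queue_put helper, so the bare queue.put(...) calls removed here were redundant and would have pushed stray strings into the stream. Below is a hedged sketch of what such a helper typically looks like; the body is an assumption for illustration, not copied from the repository.

import multiprocessing as mp

def queue_put(queue, name, msg):
    # Hypothetical helper: tag the message with its name and enqueue it in a
    # single step, so no separate queue.put(name) call is needed.
    print(f"sending {name}")
    msg["name"] = name
    queue.put(msg)

if __name__ == "__main__":
    q = mp.Queue()
    queue_put(q, "binary head", {"weight": [0.1, 0.2], "bias": [0.0]})
    received = q.get()
    print(received["name"])  # binary head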
@@ -102,6 +102,7 @@ def save_checkpoint(queue, args):
                 '--no-masked-softmax-fusion',
                 '--no-bias-gelu-fusion',
                 '--no-bias-dropout-fusion',
+                '--no-async-tensor-model-parallel-allreduce',
                 '--use-cpu-initialization',
                 '--micro-batch-size', '1',
                 '--no-load-optim',
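Both the loader and the saver hunks add the same flag to the synthetic command line the conversion tools hand to Megatron's argument parser, which disables the async tensor-model-parallel all-reduce path in this single-process, CPU-initialized setup. Below is a rough sketch of that argv-patching pattern; the helper name and placeholder entries are illustrative, and only the flags visible in the hunks are taken from the diff.

import sys

def build_megatron_argv():
    # Hypothetical helper mirroring how the loader/saver assemble a fake
    # command line before calling Megatron's argument parser; the real tools
    # also append model-size arguments derived from the checkpoint metadata.
    return ['script.py',
            '--no-masked-softmax-fusion',
            '--no-bias-gelu-fusion',
            '--no-bias-dropout-fusion',
            '--no-async-tensor-model-parallel-allreduce',
            '--use-cpu-initialization',
            '--micro-batch-size', '1',
            '--no-load-optim',
            ]

sys.argv = build_megatron_argv()
print(sys.argv)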