Commit 378db1cd authored by Jared Casper

Merge branch 'checkpoint_util_fix' into 'main'

Some quick fixes to checkpoint_util.

See merge request ADLR/megatron-lm!568
parents 8dbd0757 be1a575e
@@ -1012,8 +1012,9 @@ class ParallelTransformer(MegatronModule):
         import transformer_engine
         self.use_fp8 = args.fp8_e4m3 or args.fp8_hybrid
         self.fp8_recipe = None
-        self.fp8_group = mpu.get_data_parallel_group()
+        self.fp8_group = None
         if self.use_fp8:
+            self.fp8_group = mpu.get_data_parallel_group()
             if args.fp8_e4m3:
                 fp8_format = transformer_engine.common.recipe.Format.E4M3
             elif args.fp8_hybrid:
...
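For context on this hunk: self.fp8_group now starts as None and the data-parallel group is only looked up inside the `if self.use_fp8:` branch, so code paths that never enable FP8 (such as the single-process checkpoint converter) do not need a data-parallel group to exist. A minimal sketch of how these fields are typically consumed downstream, assuming the Transformer Engine recipe/autocast API and treating the exact recipe arguments as placeholders:

# Illustrative sketch only, not the exact Megatron code: how fp8_recipe and
# fp8_group are typically consumed once use_fp8 is set.
import contextlib
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

def fp8_context(use_fp8, fp8_e4m3, fp8_group):
    # With FP8 disabled nothing has to be configured, so no data-parallel
    # group needs to exist at all.
    if not use_fp8:
        return contextlib.nullcontext()
    fp8_format = recipe.Format.E4M3 if fp8_e4m3 else recipe.Format.HYBRID
    fp8_recipe = recipe.DelayedScaling(fp8_format=fp8_format)
    # The group is only needed here, to reduce FP8 scaling state across
    # data-parallel ranks during training.
    return te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe,
                           fp8_group=fp8_group)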
@@ -43,6 +43,7 @@ def _load_checkpoint(queue, args):
                 '--no-masked-softmax-fusion',
                 '--no-bias-gelu-fusion',
                 '--no-bias-dropout-fusion',
+                '--no-async-tensor-model-parallel-allreduce',
                 '--use-cpu-initialization',
                 '--micro-batch-size', '1',
                 '--no-load-optim',
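The new '--no-async-tensor-model-parallel-allreduce' flag is needed because the loader (and the saver further below) fabricates Megatron's command line before calling the argument parser, and the converter runs in a single process without the distributed launch the async all-reduce path assumes. A condensed sketch of that argv-override pattern, with only a few representative flags and a hypothetical helper name:

# Condensed sketch of the argv-override pattern; build_margs and the
# shortened flag list are illustrative, not the converter's full code.
import sys
from megatron.arguments import parse_args

def build_margs(load_dir):
    sys.argv = ['script.py',
                '--no-masked-softmax-fusion',
                '--no-bias-gelu-fusion',
                '--no-bias-dropout-fusion',
                # The async TP all-reduce presumes a real multi-GPU launch,
                # which the single-process converter never sets up.
                '--no-async-tensor-model-parallel-allreduce',
                '--use-cpu-initialization',
                '--micro-batch-size', '1',
                '--no-load-optim',
                '--load', load_dir]
    return parse_args()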
@@ -101,7 +102,7 @@ def _load_checkpoint(queue, args):
         nonlocal consumed_valid_samples
         models = []
         for rank in range(count):
-            mpu.parallel_state.set_tensor_model_parallel_rank(rank)
+            mpu.set_tensor_model_parallel_rank(rank)
             model_ = [model_provider(pre_process, post_process).to(dtype)]
             margs.consumed_train_samples = 0
             margs.consumed_valid_samples = 0
@@ -125,8 +126,8 @@ def _load_checkpoint(queue, args):
         exit(1)
 
     set_global_variables(margs)
-    mpu.parallel_state.set_tensor_model_parallel_world_size(margs.tensor_model_parallel_size)
-    mpu.parallel_state.set_pipeline_model_parallel_world_size(margs.pipeline_model_parallel_size)
+    mpu.set_tensor_model_parallel_world_size(margs.tensor_model_parallel_size)
+    mpu.set_pipeline_model_parallel_world_size(margs.pipeline_model_parallel_size)
     fused_kernels.load(margs)
 
     # Get true (non-padded) vocab size
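The remaining hunks in this file replace mpu.parallel_state.* calls with mpu.*: in megatron.core, mpu is an alias for the parallel_state module itself, so the setters live directly on mpu and the old two-level attribute access fails. The converter relies on these setters to fake the original job's parallel layout in a single process; a rough sketch of that pattern, assuming the import is `from megatron.core import mpu`:

# Rough sketch of the fake-parallel-state pattern used by the converter;
# `margs` stands in for the parsed Megatron arguments.
from megatron.core import mpu

def fake_parallel_state(margs):
    # No torch.distributed initialization here: the world sizes and ranks
    # are set by hand so checkpoint loading sees the original job's layout.
    mpu.set_tensor_model_parallel_world_size(margs.tensor_model_parallel_size)
    mpu.set_pipeline_model_parallel_world_size(margs.pipeline_model_parallel_size)
    for tp_rank in range(margs.tensor_model_parallel_size):
        # Each tensor-parallel shard is loaded with its rank selected in turn.
        mpu.set_tensor_model_parallel_rank(tp_rank)
        # ... build the model and load this rank's shard ...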
@@ -164,7 +165,7 @@ def _load_checkpoint(queue, args):
     md.make_vocab_size_divisible_by = margs.make_vocab_size_divisible_by
 
     # Get first pipe stage
-    mpu.parallel_state.set_pipeline_model_parallel_rank(0)
+    mpu.set_pipeline_model_parallel_rank(0)
     post_process = pp_size == 1
     models = get_models(tp_size, md.params_dtype, True, post_process)
@@ -190,7 +191,7 @@ def _load_checkpoint(queue, args):
     total_layer_num = 0
     for pp_rank in range(pp_size):
         if pp_rank > 0:
-            mpu.parallel_state.set_pipeline_model_parallel_rank(pp_rank)
+            mpu.set_pipeline_model_parallel_rank(pp_rank)
             post_process = pp_rank == pp_size - 1
             models = get_models(tp_size, md.params_dtype, False, post_process)
         for layer_num in range(len(models[0].language_model.encoder.layers)):
@@ -242,7 +243,6 @@ def _load_checkpoint(queue, args):
 
     # Send BERT lm head and binary head if it exists
     if md.model_type == 'BERT':
-        print("Sending LM Pooler")
         message = {
             "weight": models[0].language_model.pooler.dense.weight.data,
             "bias": models[0].language_model.pooler.dense.bias.data
@@ -258,8 +258,6 @@ def _load_checkpoint(queue, args):
         queue_put("lm head", message)
 
         if md.bert_binary_head:
-            print("Sending BERT Binary head")
-            queue.put("binary head")
             message = {
                 "weight": models[0].binary_head.weight.data,
                 "bias": models[0].binary_head.bias.data
...
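The deleted lines in the last two hunks of this file concern the loader-to-saver handshake: every item placed on the queue is meant to be a named message dict built by queue_put, which already logs what it sends, so the separate prints were redundant and the bare-string queue.put("binary head") would have been consumed by the saver where it expects a dict, desynchronizing the stream. The helper looks roughly like the following sketch (exact log wording may differ):

# Rough sketch of the loader-side helper; `queue` stands in for the
# multiprocessing queue passed into _load_checkpoint.
import multiprocessing as mp

queue = mp.Queue()

def queue_put(name, msg):
    # Log the send (making extra print statements unnecessary) and tag the
    # payload so the saver can verify it is consuming items in order.
    print(f"sending {name}")
    msg["name"] = name
    queue.put(msg)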
@@ -102,6 +102,7 @@ def save_checkpoint(queue, args):
                 '--no-masked-softmax-fusion',
                 '--no-bias-gelu-fusion',
                 '--no-bias-dropout-fusion',
+                '--no-async-tensor-model-parallel-allreduce',
                 '--use-cpu-initialization',
                 '--micro-batch-size', '1',
                 '--no-load-optim',
...