Commit 58cca6e9 authored by Jared Casper

Transfer consumed train and valid samples to converted checkpoint.

parent 942c402d
@@ -92,10 +92,11 @@ def _load_checkpoint(queue, args):
     # suppress warning about torch.distributed not being initialized
     module.MegatronModule.embedding_warning_printed = True

+    consumed_train_samples = None
+    consumed_valid_samples = None
     def get_models(count, dtype, pre_process, post_process):
-        # with concurrent.futures.ThreadPoolExecutor(max_workers=count) as executor:
-        #     futures = [executor.submit(model_provider, pre_process, post_process) for _ in range(count)]
-        #     models = [f.result().bfloat16() for f in futures]
+        nonlocal consumed_train_samples
+        nonlocal consumed_valid_samples
         models = []
         for rank in range(count):
             mpu.initialize.set_tensor_model_parallel_rank(rank)
@@ -104,7 +105,16 @@ def _load_checkpoint(queue, args):
             margs.consumed_valid_samples = 0
             load_checkpoint(model_, None, None)
             assert(len(model_) == 1)
-            models.append(model_[0])
+            model_ = model_[0]
+            if consumed_train_samples is not None:
+                assert(margs.consumed_train_samples == consumed_train_samples)
+            else:
+                consumed_train_samples = margs.consumed_train_samples
+            if consumed_valid_samples is not None:
+                assert(margs.consumed_valid_samples == consumed_valid_samples)
+            else:
+                consumed_valid_samples = margs.consumed_valid_samples
+            models.append(model_)
         return models

     if margs.num_layers_per_virtual_pipeline_stage is not None:
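
For context: get_models() loads the checkpoint once per tensor-parallel rank, so the closure records the counters reported by the first rank and then asserts that every later rank agrees. A minimal, self-contained sketch of that pattern, with a hypothetical load_rank_args() stub standing in for Megatron's load_checkpoint() and made-up counter values:

from types import SimpleNamespace

def gather_consumed_samples(num_ranks):
    """Record the counters from the first rank and check the rest agree."""
    consumed_train_samples = None
    consumed_valid_samples = None

    def load_rank_args(rank):
        # Hypothetical stand-in for Megatron's load_checkpoint(): a consistent
        # checkpoint stores the same counters in every tensor-parallel rank.
        return SimpleNamespace(consumed_train_samples=1000,
                               consumed_valid_samples=100)

    def check_rank(rank):
        nonlocal consumed_train_samples, consumed_valid_samples
        rank_args = load_rank_args(rank)
        if consumed_train_samples is not None:
            # Later ranks must report exactly what the first rank reported.
            assert rank_args.consumed_train_samples == consumed_train_samples
        else:
            consumed_train_samples = rank_args.consumed_train_samples
        if consumed_valid_samples is not None:
            assert rank_args.consumed_valid_samples == consumed_valid_samples
        else:
            consumed_valid_samples = rank_args.consumed_valid_samples

    for rank in range(num_ranks):
        check_rank(rank)
    return consumed_train_samples, consumed_valid_samples

print(gather_consumed_samples(4))  # -> (1000, 100)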
@@ -150,13 +160,16 @@ def _load_checkpoint(queue, args):
     md.previous_pipeline_parallel_size = margs.pipeline_model_parallel_size
     md.true_vocab_size = true_vocab_size
     md.make_vocab_size_divisible_by = margs.make_vocab_size_divisible_by
-    queue.put(md)

     # Get first pipe stage
     mpu.initialize.set_pipeline_model_parallel_rank(0)
     post_process = pp_size == 1
     models = get_models(tp_size, md.params_dtype, True, post_process)

+    md.consumed_train_samples = consumed_train_samples
+    md.consumed_valid_samples = consumed_valid_samples
+    queue.put(md)
+
     # Send embeddings
     word_embed = []
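
This hunk also moves queue.put(md) below the first get_models() call: the consumed-sample counters are only known once a checkpoint has actually been loaded, so the metadata namespace has to be completed before it is handed to the saver side of the converter. A rough sketch of that ordering under simplified assumptions (a bare multiprocessing queue and dummy counter values in place of a real checkpoint load):

import multiprocessing as mp
from types import SimpleNamespace

def loader(q):
    # Placeholder metadata; the real tool fills this from the Megatron args.
    md = SimpleNamespace(model_type='GPT')

    # Hypothetical stand-in for get_models(): loading the first pipeline stage
    # is what reveals the counters stored in the checkpoint.
    consumed_train_samples, consumed_valid_samples = 1000, 100

    # Attach the counters before handing the metadata to the saver side,
    # which is why queue.put(md) now comes after the first get_models() call.
    md.consumed_train_samples = consumed_train_samples
    md.consumed_valid_samples = consumed_valid_samples
    q.put(md)

if __name__ == '__main__':
    q = mp.Queue()
    proc = mp.Process(target=loader, args=(q,))
    proc.start()
    md = q.get()
    print(md.consumed_train_samples, md.consumed_valid_samples)  # 1000 100
    proc.join()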
@@ -110,6 +110,14 @@ def save_checkpoint(queue, args):
     # margs = megatron args
     margs = get_args()

+    if hasattr(md, 'consumed_train_samples'):
+        margs.consumed_train_samples = md.consumed_train_samples
+        margs.consumed_valid_samples = md.consumed_valid_samples
+        print(f"Setting consumed_train_samples to {margs.consumed_train_samples}"
+              f" and consumed_valid_samples to {margs.consumed_valid_samples}")
+    else:
+        print("consumed_train_samples not provided.")
+
     # Determine how to make our models
     if md.model_type == 'GPT':
         from pretrain_gpt import model_provider
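
On the saver side, the hasattr() check keeps the plugin compatible with loaders that predate this change and never set the new attributes; when the counters are present they are copied onto the saver's Megatron args, so a run resumed from the converted checkpoint picks up at the same consumed-sample offsets instead of restarting from zero. A small sketch of that fallback behaviour, using a hypothetical apply_consumed_samples() helper rather than the tool's actual code path:

from types import SimpleNamespace

def apply_consumed_samples(margs, md):
    """Copy the loader's sample counters onto the saver's args, if present."""
    # Older loader plugins never set these attributes, hence the hasattr check.
    if hasattr(md, 'consumed_train_samples'):
        margs.consumed_train_samples = md.consumed_train_samples
        margs.consumed_valid_samples = md.consumed_valid_samples
    else:
        print("consumed_train_samples not provided.")

margs = SimpleNamespace(consumed_train_samples=0, consumed_valid_samples=0)
apply_consumed_samples(margs, SimpleNamespace())        # old loader: counters stay at 0
apply_consumed_samples(margs, SimpleNamespace(consumed_train_samples=1000,
                                              consumed_valid_samples=100))
print(margs.consumed_train_samples, margs.consumed_valid_samples)  # 1000 100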
@@ -32,6 +32,10 @@ import os
 #   bert_binary_head - Used only if model_type is BERT
 #   previous_tensor_parallel_size - Optional
 #   previous_pipeline_parallel_size - Optional
+#   true_vocab_size
+#   make_vocab_size_divisible_by
+#   consumed_train_samples
+#   consumed_valid_samples
 # - Position embeddings
 # - Word embeddings
 # - For each transformer layer:
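
The comment block above documents the loader-to-saver handoff: a metadata namespace goes over the queue first, followed by the tensors in a fixed order, with the new counters carried as optional metadata attributes. A condensed sketch of a consumer walking that order, using an in-process queue and dummy lists in place of real embedding tensors:

import queue
from types import SimpleNamespace

q = queue.Queue()

# Producer side (normally the loader): metadata namespace first, then tensors
# in the documented order. Plain lists stand in for real embedding tensors.
md = SimpleNamespace(model_type='GPT',
                     consumed_train_samples=1000,
                     consumed_valid_samples=100)
q.put(md)
q.put([0.0] * 4)   # position embeddings (dummy)
q.put([0.0] * 8)   # word embeddings (dummy)

# Consumer side (normally the saver): read the metadata first, since it tells
# the saver how to interpret what follows; the counters are optional attributes.
md = q.get()
train_samples = getattr(md, 'consumed_train_samples', 0)
pos_embed = q.get()
word_embed = q.get()
print(md.model_type, train_samples, len(pos_embed), len(word_embed))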