Commit 3b9dc880 authored by Jared Casper

Update to work with latest main branch.

parent 67aa8619
@@ -96,6 +96,16 @@ class MegatronModule(torch.nn.Module):
             self.word_embeddings.weight.data.fill_(0)
             self.word_embeddings.weight.shared = True
 
+        if not torch.distributed.is_initialized():
+            if not getattr(MegatronModule, "embedding_warning_printed", False):
+                print("WARNING! Distributed processes aren't initialized, so "
+                      "word embeddings in the last layer are not initialized. "
+                      "If you are just manipulating a model this is fine, but "
+                      "this needs to be handled manually. If you are training "
+                      "something is definitely wrong.")
+                MegatronModule.embedding_warning_printed = True
+            return
+
         # Zero out initial weights for decoder embedding.
         # NOTE: We don't currently support T5 with the interleaved schedule.
         if not mpu.is_pipeline_first_stage(ignore_virtual=True) and \
@@ -105,7 +115,6 @@ class MegatronModule(torch.nn.Module):
         # Ensure that first and last stages have the same initial parameter
         # values.
-        if torch.distributed.is_initialized():
         if mpu.is_rank_in_embedding_group():
             torch.distributed.all_reduce(self.word_embeddings_weight().data,
                                          group=mpu.get_embedding_group())
@@ -124,12 +133,6 @@ class MegatronModule(torch.nn.Module):
                 position_embeddings = self.language_model.embedding.position_embeddings
                 torch.distributed.all_reduce(position_embeddings.weight.data,
                                              group=mpu.get_embedding_group())
-        else:
-            print("WARNING! Distributed processes aren't initialized, so "
-                  "word embeddings in the last layer are not initialized. "
-                  "If you are just manipulating a model this is fine, but "
-                  "this needs to be handled manually. If you are training "
-                  "something is definitely wrong.")
 
 def conversion_helper(val, conversion):
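In the MegatronModule hunks above, the old trailing else-branch warning is replaced by an early return: when torch.distributed isn't initialized, the warning is printed at most once (tracked via the embedding_warning_printed class attribute) and the embedding all-reduce is skipped entirely. A minimal standalone sketch of that warn-once pattern, using an illustrative class name rather than the real MegatronModule:

class EmbeddingInitializer:  # illustrative stand-in, not the Megatron class
    warning_printed = False  # class-level flag, like MegatronModule.embedding_warning_printed

    def initialize_word_embeddings(self):
        if not EmbeddingInitializer.warning_printed:
            print("WARNING! Distributed processes aren't initialized.")
            EmbeddingInitializer.warning_printed = True
        return  # bail out before any collective communication

EmbeddingInitializer().initialize_word_embeddings()  # prints the warning once
EmbeddingInitializer().initialize_word_embeddings()  # silent: the flag lives on the class, not the instance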
@@ -23,34 +23,13 @@ def _load_checkpoint(queue, args):
         from megatron.arguments import parse_args, validate_args
         from megatron.global_vars import set_args, set_global_variables, rebuild_tokenizer
         from megatron.checkpointing import load_args_from_checkpoint, load_checkpoint
+        from megatron.model import ModelType
         from megatron import mpu, fused_kernels
     except ModuleNotFoundError:
         print("Unable to import Megatron, please specify the path to Megatron using --megatron-path. Exiting.")
         queue.put("exit")
         exit(1)
 
-    def get_models(count, dtype, pre_process, post_process):
-        if args.model_type == 'GPT':
-            from pretrain_gpt import model_provider
-        elif args.model_type == 'BERT':
-            from pretrain_bert import model_provider
-        else:
-            raise Exception(f'unrecognized model type: {args.model_type}')
-        # with concurrent.futures.ThreadPoolExecutor(max_workers=count) as executor:
-        #     futures = [executor.submit(model_provider, pre_process, post_process) for _ in range(count)]
-        #     models = [f.result().bfloat16() for f in futures]
-        models = []
-        for rank in range(count):
-            mpu.initialize.set_tensor_model_parallel_rank(rank)
-            model_ = [model_provider(pre_process, post_process).to(dtype)]
-            margs.consumed_train_samples = 0
-            margs.consumed_valid_samples = 0
-            load_checkpoint(model_, None, None)
-            assert(len(model_) == 1)
-            models.append(model_[0])
-        return models
-
     # We want all arguments to come from us
     sys.argv = ['script.py',
                 '--no-masked-softmax-fusion',
@@ -95,6 +74,31 @@ def _load_checkpoint(queue, args):
     check_for_arg('params_dtype')
 
+    # Determine how to make our models
+    if args.model_type == 'GPT':
+        from pretrain_gpt import model_provider
+        margs.model_type = ModelType.encoder_or_decoder
+    elif args.model_type == 'BERT':
+        from pretrain_bert import model_provider
+        margs.model_type = ModelType.encoder_or_decoder
+    else:
+        raise Exception(f'unrecognized model type: {args.model_type}')
+
+    def get_models(count, dtype, pre_process, post_process):
+        # with concurrent.futures.ThreadPoolExecutor(max_workers=count) as executor:
+        #     futures = [executor.submit(model_provider, pre_process, post_process) for _ in range(count)]
+        #     models = [f.result().bfloat16() for f in futures]
+        models = []
+        for rank in range(count):
+            mpu.initialize.set_tensor_model_parallel_rank(rank)
+            model_ = [model_provider(pre_process, post_process).to(dtype)]
+            margs.consumed_train_samples = 0
+            margs.consumed_valid_samples = 0
+            load_checkpoint(model_, None, None)
+            assert(len(model_) == 1)
+            models.append(model_[0])
+        return models
+
     set_args(margs)
 
     if margs.num_layers_per_virtual_pipeline_stage is not None:
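In the checkpoint loader (_load_checkpoint), provider selection and the margs.model_type assignment now happen once, after the Megatron args are validated, so get_models only builds one model per tensor-parallel rank and loads its checkpoint. A rough, self-contained sketch of that "pick the provider once, then use a small per-rank factory" shape; build_gpt, build_bert, and the Linear stand-ins are illustrative placeholders, not Megatron code:

import torch

def build_gpt(pre_process, post_process):   # placeholder for pretrain_gpt.model_provider
    return torch.nn.Linear(8, 8)

def build_bert(pre_process, post_process):  # placeholder for pretrain_bert.model_provider
    return torch.nn.Linear(8, 8)

model_type = 'GPT'
if model_type == 'GPT':                     # selection happens once, up front
    model_provider = build_gpt
elif model_type == 'BERT':
    model_provider = build_bert
else:
    raise Exception(f'unrecognized model type: {model_type}')

def get_models(count, dtype, pre_process=True, post_process=True):
    # One model per tensor-parallel rank; the real loader also sets the
    # tensor-parallel rank and calls load_checkpoint() for each model.
    return [model_provider(pre_process, post_process).to(dtype) for _ in range(count)]

models = get_models(2, torch.float16)       # two tensor-parallel shards, fp16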
@@ -30,6 +30,7 @@ def save_checkpoint(queue, args):
     try:
         from megatron.checkpointing import save_checkpoint
         from megatron.global_vars import set_global_variables, get_args
+        from megatron.model import ModelType
         from megatron import mpu, fused_kernels
     except ModuleNotFoundError:
         print("Unable to import Megatron, please specify the path to Megatron using --megatron-path. Exiting.")
@@ -44,18 +45,6 @@ def save_checkpoint(queue, args):
     md = queue_get()
 
-    def get_models(count, dtype, pre_process, post_process):
-        if md.model_type == 'GPT':
-            from pretrain_gpt import model_provider
-        elif md.model_type == 'BERT':
-            from pretrain_bert import model_provider
-        else:
-            raise Exception(f'unrecognized model type: {md.model_type}')
-        # with concurrent.futures.ThreadPoolExecutor(max_workers=count) as executor:
-        #     futures = [executor.submit(model_provider, pre_process, post_process) for _ in range(count)]
-        #     models = [f.result().bfloat16() for f in futures]
-        models = [model_provider(pre_process, post_process).to(dtype) for _ in range(count)]
-        return models
-
     if args.target_tensor_parallel_size is None:
         if hasattr(md, 'previous_tensor_parallel_size'):
@@ -114,6 +103,20 @@ def save_checkpoint(queue, args):
     # margs = megatron args
     margs = get_args()
 
+    # Determine how to make our models
+    if md.model_type == 'GPT':
+        from pretrain_gpt import model_provider
+        margs.model_type = ModelType.encoder_or_decoder
+    elif md.model_type == 'BERT':
+        from pretrain_bert import model_provider
+        margs.model_type = ModelType.encoder_or_decoder
+    else:
+        raise Exception(f'unrecognized model type: {args.model_type}')
+
+    def get_models(count, dtype, pre_process, post_process):
+        models = [model_provider(pre_process, post_process).to(dtype) for _ in range(count)]
+        return models
+
     # fake initializing distributed
     mpu.initialize.set_tensor_model_parallel_world_size(args.target_tensor_parallel_size)
     mpu.initialize.set_pipeline_model_parallel_world_size(args.target_pipeline_parallel_size)
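Both converter scripts now import ModelType from megatron.model and record it on the Megatron args before any model is constructed, which appears to be the compatibility piece this commit is after. A hedged sketch of that assignment with a local stand-in enum; the single member shown is the only one this diff confirms:

from enum import Enum
from types import SimpleNamespace

class ModelType(Enum):          # local stand-in for megatron.model.ModelType
    encoder_or_decoder = 1

def set_model_type(margs, name):
    # Mirrors the new "Determine how to make our models" blocks above.
    if name in ('GPT', 'BERT'):
        margs.model_type = ModelType.encoder_or_decoder
    else:
        raise Exception(f'unrecognized model type: {name}')
    return margs

margs = set_model_type(SimpleNamespace(), 'BERT')
print(margs.model_type)         # ModelType.encoder_or_decoder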