"vscode:/vscode.git/clone" did not exist on "53ca15529a89bce5b776d36e929da9e98737555b"
Commit 5942af97 authored by Jared Casper

Alias core.parallel_state as mpu and use it throughout code. RIP mpu.

parent c2ea914f
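For context, the renaming in the hunks below relies on `megatron.core` re-exporting `parallel_state` under its legacy name. The package `__init__.py` is not part of the hunks shown here, so the following is only a minimal sketch of what such an alias would look like, not the file's actual contents:

    # megatron/core/__init__.py (sketch; assumed, not shown in this diff)
    # Re-export parallel_state under its legacy name "mpu" so call sites
    # can write `from megatron.core import mpu`.
    from megatron.core import parallel_state
    from megatron.core import parallel_state as mpu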
@@ -30,7 +30,8 @@ def _load_checkpoint(queue, args):
         from megatron.global_vars import set_args, set_global_variables
         from megatron.checkpointing import load_args_from_checkpoint, load_checkpoint
         from megatron.model import ModelType, module
-        from megatron import mpu, fused_kernels
+        from megatron.core import mpu
+        from megatron import fused_kernels
     except ModuleNotFoundError:
         print("Unable to import Megatron, please specify the path to Megatron using --megatron-path. Exiting.")
         queue.put("exit")
@@ -99,7 +100,7 @@ def _load_checkpoint(queue, args):
         nonlocal consumed_valid_samples
         models = []
         for rank in range(count):
-            mpu.initialize.set_tensor_model_parallel_rank(rank)
+            mpu.parallel_state.set_tensor_model_parallel_rank(rank)
             model_ = [model_provider(pre_process, post_process).to(dtype)]
             margs.consumed_train_samples = 0
             margs.consumed_valid_samples = 0
@@ -123,8 +124,8 @@ def _load_checkpoint(queue, args):
         exit(1)
     set_global_variables(margs)
-    mpu.initialize.set_tensor_model_parallel_world_size(margs.tensor_model_parallel_size)
-    mpu.initialize.set_pipeline_model_parallel_world_size(margs.pipeline_model_parallel_size)
+    mpu.parallel_state.set_tensor_model_parallel_world_size(margs.tensor_model_parallel_size)
+    mpu.parallel_state.set_pipeline_model_parallel_world_size(margs.pipeline_model_parallel_size)
     fused_kernels.load(margs)
     # Get true (non-padded) vocab size
@@ -162,7 +163,7 @@ def _load_checkpoint(queue, args):
     md.make_vocab_size_divisible_by = margs.make_vocab_size_divisible_by
     # Get first pipe stage
-    mpu.initialize.set_pipeline_model_parallel_rank(0)
+    mpu.parallel_state.set_pipeline_model_parallel_rank(0)
     post_process = pp_size == 1
     models = get_models(tp_size, md.params_dtype, True, post_process)
@@ -188,7 +189,7 @@ def _load_checkpoint(queue, args):
     total_layer_num = 0
     for pp_rank in range(pp_size):
         if pp_rank > 0:
-            mpu.initialize.set_pipeline_model_parallel_rank(pp_rank)
+            mpu.parallel_state.set_pipeline_model_parallel_rank(pp_rank)
             post_process = pp_rank == pp_size - 1
             models = get_models(tp_size, md.params_dtype, False, post_process)
         for layer_num in range(len(models[0].language_model.encoder.layers)):
@@ -34,7 +34,8 @@ def save_checkpoint(queue, args):
         from megatron.global_vars import set_global_variables, get_args
         from megatron.model import ModelType
         from megatron.tokenizer.tokenizer import _vocab_size_with_padding
-        from megatron import mpu, fused_kernels
+        from megatron import fused_kernels
+        from megatron.core import mpu
     except ModuleNotFoundError:
         print("Unable to import Megatron, please specify the path to Megatron using --megatron-path. Exiting.")
         exit(1)
@@ -152,10 +153,10 @@ def save_checkpoint(queue, args):
         return models
     # fake initializing distributed
-    mpu.initialize.set_tensor_model_parallel_world_size(args.target_tensor_parallel_size)
-    mpu.initialize.set_pipeline_model_parallel_world_size(args.target_pipeline_parallel_size)
-    mpu.initialize.set_tensor_model_parallel_rank(0)
-    mpu.initialize.set_pipeline_model_parallel_rank(0)
+    mpu.set_tensor_model_parallel_world_size(args.target_tensor_parallel_size)
+    mpu.set_pipeline_model_parallel_world_size(args.target_pipeline_parallel_size)
+    mpu.set_tensor_model_parallel_rank(0)
+    mpu.set_pipeline_model_parallel_rank(0)
     fused_kernels.load(margs)
     # Embeddings
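The setters touched in the hunk above are what let the conversion tool pretend a tensor/pipeline-parallel job is running without initializing torch.distributed. A standalone sketch of that pattern, using the same setters the diff uses (the world sizes and loop body here are illustrative, not taken from the commit):

    # Sketch: fake a 2-way tensor / 1-way pipeline parallel layout offline,
    # mirroring the "fake initializing distributed" block above.
    from megatron.core import mpu

    mpu.set_tensor_model_parallel_world_size(2)
    mpu.set_pipeline_model_parallel_world_size(1)
    mpu.set_pipeline_model_parallel_rank(0)

    for tp_rank in range(2):
        # Visit the tensor-parallel shards one at a time.
        mpu.set_tensor_model_parallel_rank(tp_rank)
        # ... build or load the model shard for this rank here ...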
@@ -197,7 +198,7 @@ def save_checkpoint(queue, args):
     out_word_embed = torch.chunk(full_word_embed, args.target_tensor_parallel_size, dim=0)
     # Make models for first pipeline stage and fill in embeddings
-    mpu.initialize.set_pipeline_model_parallel_rank(0)
+    mpu.set_pipeline_model_parallel_rank(0)
     post_process = args.target_pipeline_parallel_size == 1
     models = get_models(args.target_tensor_parallel_size, md.params_dtype, True, post_process)
     for tp_rank, model in enumerate(models):
@@ -211,7 +212,7 @@ def save_checkpoint(queue, args):
     for pp_rank in range(args.target_pipeline_parallel_size):
         # For later pipeline parallel ranks, make the new models
         if pp_rank > 0:
-            mpu.initialize.set_pipeline_model_parallel_rank(pp_rank)
+            mpu.set_pipeline_model_parallel_rank(pp_rank)
             post_process = pp_rank == args.target_pipeline_parallel_size - 1
             models = get_models(args.target_tensor_parallel_size, md.params_dtype, False, post_process)
@@ -317,6 +318,6 @@ def save_checkpoint(queue, args):
         print("ERROR: got some more data but was expecting to be done")
     for tp_rank in range(args.target_tensor_parallel_size):
-        mpu.initialize.set_tensor_model_parallel_rank(tp_rank)
+        mpu.set_tensor_model_parallel_rank(tp_rank)
         save_checkpoint(md.iteration, [models[tp_rank]], None, None)
     print("Done!")
@@ -8,7 +8,7 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
 import socket
 from megatron import get_args
 from megatron import print_rank_0
-from megatron import mpu
+from megatron.core import mpu
 from megatron.checkpointing import load_checkpoint
 from megatron.initialize import initialize_megatron
 from megatron.model import GPTModel