Commit c2ea914f authored by Jared Casper
Move layers from mpu to core.tensor_parallel.

parent 209f91c9
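
For reference, a minimal sketch of the call-site migration implied by this move, using only the symbols visible in the hunks below; the tag_params helper is hypothetical and added purely for illustration:

# Before this commit, callers reached these helpers through megatron.mpu:
#   from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
#   mpu.set_defaults_if_not_set_tensor_model_parallel_attributes(param)

# After this commit, the same helpers live under megatron.core.tensor_parallel:
from megatron import core
from megatron.core.tensor_parallel import param_is_not_tensor_parallel_duplicate

def tag_params(model):
    # Hypothetical helper mirroring the get_model() change below: make sure
    # every parameter carries tensor-model-parallel attributes before the
    # optimizer consumes them.
    for param in model.parameters():
        core.tensor_parallel.set_defaults_if_not_set_tensor_model_parallel_attributes(param)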
@@ -20,6 +20,7 @@ from megatron import get_num_microbatches
 from megatron import is_last_rank
 from megatron import update_num_microbatches
 from megatron import mpu
+from megatron import core
 from megatron import print_rank_0
 from megatron import print_rank_last
 from megatron.checkpointing import load_checkpoint
@@ -257,7 +258,7 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
     # are set for all params so the optimizer can use them.
     for model_module in model:
         for param in model_module.parameters():
-            mpu.set_defaults_if_not_set_tensor_model_parallel_attributes(param)
+            core.tensor_parallel.set_defaults_if_not_set_tensor_model_parallel_attributes(param)

     # Print number of parameters.
     if mpu.get_data_parallel_rank() == 0:
...
@@ -14,7 +14,7 @@ from megatron import get_args
 from megatron import get_adlr_autoresume
 from megatron import mpu
 from megatron.model.module import param_is_not_shared
-from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
+from megatron.core.tensor_parallel import param_is_not_tensor_parallel_duplicate


 def unwrap_model(model, module_instances=(torchDDP)):
...