Commit 5942af97 authored by Jared Casper

Alias core.parallel_state as mpu and use it throughout code. RIP mpu.

parent c2ea914f
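
The alias itself is not visible in the hunks below. As a hedged sketch only, the commit message suggests megatron/core/__init__.py ends up looking roughly like this (module names inferred from the new import sites such as `from megatron.core import mpu, tensor_parallel`; treat the exact contents as an assumption, not the actual source):

# megatron/core/__init__.py -- assumed layout, not shown in this diff
from . import parallel_state
from . import tensor_parallel
from . import utils

# Alias parallel_state under its long-standing short name so call sites
# can keep writing e.g. mpu.get_tensor_model_parallel_world_size().
mpu = parallel_state

With an alias like that in place, every former `from megatron import mpu` import below becomes `from megatron.core import mpu` with no further changes needed at the call sites.
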
......@@ -8,7 +8,7 @@ import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from megatron import get_args
from megatron import mpu
from megatron.core import mpu
from .module import MegatronModule
......
......@@ -5,8 +5,7 @@
import torch
from megatron import get_args
from megatron import mpu
from megatron import core
from megatron.core import tensor_parallel
from .module import MegatronModule
from .enums import AttnMaskType
......@@ -34,9 +33,9 @@ def post_language_model_processing(lm_output, labels, logit_weights,
labels = labels.transpose(0,1).contiguous()
if fp16_lm_cross_entropy:
assert output.dtype == torch.half
loss = core.tensor_parallel.vocab_parallel_cross_entropy(output, labels)
loss = tensor_parallel.vocab_parallel_cross_entropy(output, labels)
else:
loss = core.tensor_parallel.vocab_parallel_cross_entropy(output.float(), labels)
loss = tensor_parallel.vocab_parallel_cross_entropy(output.float(), labels)
# [s b] => [b, s]
loss = loss.transpose(0,1).contiguous()
......
......@@ -6,7 +6,7 @@ import torch
import torch.nn.functional as F
from megatron import get_args
from megatron import core
from megatron.core import mpu, tensor_parallel
from .module import MegatronModule
from megatron.model.enums import LayerType, AttnMaskType
from megatron.model.transformer import ParallelTransformer
......@@ -22,15 +22,15 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
if args.async_tensor_model_parallel_allreduce or\
args.sequence_parallel:
input_parallel = input_
model_parallel = core.get_tensor_model_parallel_world_size() > 1
model_parallel = mpu.get_tensor_model_parallel_world_size() > 1
async_grad_allreduce = args.async_tensor_model_parallel_allreduce and \
model_parallel and not args.sequence_parallel
else:
input_parallel = core.tensor_parallel.copy_to_tensor_model_parallel_region(input_)
input_parallel = tensor_parallel.copy_to_tensor_model_parallel_region(input_)
async_grad_allreduce = False
# Matrix multiply.
logits_parallel = core.tensor_parallel.linear_with_grad_accumulation_and_async_allreduce(
logits_parallel = tensor_parallel.linear_with_grad_accumulation_and_async_allreduce(
input=input_parallel,
weight=word_embeddings_weight,
bias=bias,
......@@ -42,7 +42,7 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
if parallel_output:
return logits_parallel
return core.tensor_parallel.gather_from_tensor_model_parallel_region(logits_parallel)
return tensor_parallel.gather_from_tensor_model_parallel_region(logits_parallel)
def get_language_model(num_tokentypes, add_pooler,
......@@ -106,7 +106,7 @@ class Pooler(MegatronModule):
# gather data along sequence dimensions
# same pooler is run on all tensor parallel nodes
if self.sequence_parallel:
hidden_states = core.tensor_parallel.gather_from_sequence_parallel_region(
hidden_states = tensor_parallel.gather_from_sequence_parallel_region(
hidden_states,
tensor_parallel_output_grad=False)
......@@ -146,7 +146,7 @@ class Embedding(MegatronModule):
args = get_args()
# Word embeddings (parallel).
self.word_embeddings = core.tensor_parallel.VocabParallelEmbedding(
self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
vocab_size, self.hidden_size,
init_method=self.init_method,
params_dtype=args.params_dtype,
......@@ -229,8 +229,8 @@ class Embedding(MegatronModule):
# Dropout.
if self.sequence_parallel:
embeddings = core.tensor_parallel.scatter_to_sequence_parallel_region(embeddings)
with core.tensor_parallel.get_cuda_rng_tracker().fork():
embeddings = tensor_parallel.scatter_to_sequence_parallel_region(embeddings)
with tensor_parallel.get_cuda_rng_tracker().fork():
embeddings = self.embedding_dropout(embeddings)
else:
embeddings = self.embedding_dropout(embeddings)
......
......@@ -7,8 +7,7 @@ from torch.autograd import Variable
from torch.nn.parameter import Parameter
from megatron import get_args
from megatron import mpu
from megatron import core
from megatron.core import mpu, tensor_parallel
_FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
......@@ -77,7 +76,7 @@ class MegatronModule(torch.nn.Module):
self._word_embeddings_for_head_key = 'word_embeddings_for_head'
# set word_embeddings weights to 0 here, then copy first
# stage's weights using all_reduce below.
self.word_embeddings = core.tensor_parallel.VocabParallelEmbedding(
self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
args.padded_vocab_size, args.hidden_size,
init_method=init_method_normal(args.init_method_std),
params_dtype=args.params_dtype,
......
......@@ -5,7 +5,6 @@
import torch
from megatron import get_args, print_rank_last
from megatron import mpu
from megatron.model.enums import AttnMaskType
from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids
from megatron.model.language_model import get_language_model
......
......@@ -5,7 +5,7 @@ from megatron import get_args, print_rank_0
from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
from megatron.model import BertModel
from .module import MegatronModule
from megatron import mpu
from megatron.core import mpu
from megatron.model.enums import AttnMaskType
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
......
......@@ -4,10 +4,8 @@
import torch
from megatron import (
get_args,
mpu
)
from megatron import get_args
from megatron.core import tensor_parallel
from megatron.model.enums import AttnMaskType
from megatron.model.language_model import parallel_lm_logits, get_language_model
from megatron.model.transformer import LayerNorm
......@@ -151,10 +149,10 @@ class T5Model(MegatronModule):
lm_labels = lm_labels.transpose(0,1).contiguous()
if self.fp16_lm_cross_entropy:
assert lm_logits.dtype == torch.half
lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits, lm_labels)
else:
lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits.float(),
lm_labels)
lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits.float(),
lm_labels)
# [s b] => [b s]
lm_loss = lm_loss.transpose(0,1).contiguous()
return lm_loss
......
......@@ -6,10 +6,9 @@ from contextlib import nullcontext
import torch
import torch.nn.functional as F
from megatron import get_timers, get_args
from megatron.core import get_global_memory_buffer
from megatron import core
from megatron import get_timers, get_args, core
from .module import MegatronModule
from megatron.core import mpu, tensor_parallel
from megatron.model.enums import AttnMaskType, ModelType, LayerType, AttnType
from megatron.model import LayerNorm
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
......@@ -79,7 +78,7 @@ class ParallelMLP(MegatronModule):
# Project to 4h.
self.dense_h_to_4h = core.tensor_parallel.ColumnParallelLinear(
self.dense_h_to_4h = tensor_parallel.ColumnParallelLinear(
args.hidden_size,
args.ffn_hidden_size,
gather_output=False,
......@@ -96,7 +95,7 @@ class ParallelMLP(MegatronModule):
self.activation_func = erf_gelu
# Project back to h.
self.dense_4h_to_h = core.tensor_parallel.RowParallelLinear(
self.dense_4h_to_h = tensor_parallel.RowParallelLinear(
args.ffn_hidden_size,
args.hidden_size,
input_is_parallel=True,
......@@ -189,7 +188,7 @@ class CoreAttention(MegatronModule):
projection_size = args.kv_channels * args.num_attention_heads
# Per attention head and per partition values.
world_size = core.get_tensor_model_parallel_world_size()
world_size = mpu.get_tensor_model_parallel_world_size()
self.hidden_size_per_partition = core.utils.divide(projection_size,
world_size)
self.hidden_size_per_attention_head = core.utils.divide(
......@@ -237,7 +236,7 @@ class CoreAttention(MegatronModule):
output_size[0] * output_size[1], -1)
# preallocting input tensor: [b * np, sq, sk]
matmul_input_buffer = get_global_memory_buffer().get_tensor(
matmul_input_buffer = mpu.get_global_memory_buffer().get_tensor(
(output_size[0]*output_size[1], output_size[2], output_size[3]),
query_layer.dtype, "mpu")
......@@ -263,7 +262,7 @@ class CoreAttention(MegatronModule):
# seem a bit unusual, but is taken from the original Transformer paper.
if not self.sequence_parallel:
with core.tensor_parallel.get_cuda_rng_tracker().fork():
with tensor_parallel.get_cuda_rng_tracker().fork():
attention_probs = self.attention_dropout(attention_probs)
else:
attention_probs = self.attention_dropout(attention_probs)
......@@ -327,7 +326,7 @@ class ParallelAttention(MegatronModule):
projection_size = args.kv_channels * args.num_attention_heads
# Per attention head and per partition values.
world_size = core.get_tensor_model_parallel_world_size()
world_size = mpu.get_tensor_model_parallel_world_size()
self.hidden_size_per_attention_head = core.utils.divide(
projection_size, args.num_attention_heads)
self.num_attention_heads_per_partition = core.utils.divide(
......@@ -335,7 +334,7 @@ class ParallelAttention(MegatronModule):
# Strided linear layer.
if attention_type == AttnType.self_attn:
self.query_key_value = core.tensor_parallel.ColumnParallelLinear(
self.query_key_value = tensor_parallel.ColumnParallelLinear(
args.hidden_size,
3 * projection_size,
gather_output=False,
......@@ -344,7 +343,7 @@ class ParallelAttention(MegatronModule):
**_args_to_kwargs())
else:
assert attention_type == AttnType.cross_attn
self.query = core.tensor_parallel.ColumnParallelLinear(
self.query = tensor_parallel.ColumnParallelLinear(
args.hidden_size,
projection_size,
gather_output=False,
......@@ -353,7 +352,7 @@ class ParallelAttention(MegatronModule):
**_args_to_kwargs())
self.key_value = core.tensor_parallel.ColumnParallelLinear(
self.key_value = tensor_parallel.ColumnParallelLinear(
args.hidden_size,
2 * projection_size,
gather_output=False,
......@@ -366,7 +365,7 @@ class ParallelAttention(MegatronModule):
self.checkpoint_core_attention = args.recompute_granularity == 'selective'
# Output.
self.dense = core.tensor_parallel.RowParallelLinear(
self.dense = tensor_parallel.RowParallelLinear(
projection_size,
args.hidden_size,
input_is_parallel=True,
......@@ -386,7 +385,7 @@ class ParallelAttention(MegatronModule):
value_layer, attention_mask)
return output_
hidden_states = core.tensor_parallel.checkpoint(
hidden_states = tensor_parallel.checkpoint(
custom_forward,
False, query_layer, key_layer, value_layer, attention_mask)
......@@ -439,7 +438,7 @@ class ParallelAttention(MegatronModule):
# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
(query_layer,
key_layer,
value_layer) = core.tensor_parallel.split_tensor_along_last_dim(mixed_x_layer, 3)
value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_x_layer, 3)
else:
# Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
mixed_kv_layer, _ = self.key_value(encoder_output)
......@@ -452,7 +451,7 @@ class ParallelAttention(MegatronModule):
# [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
(key_layer,
value_layer) = core.tensor_parallel.split_tensor_along_last_dim(mixed_kv_layer, 2)
value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_kv_layer, 2)
# Attention head [sq, b, h] --> [sq, b, hp]
query_layer, _ = self.query(hidden_states)
......@@ -769,7 +768,7 @@ class ParallelTransformer(MegatronModule):
self.sequence_parallel = args.sequence_parallel
# Number of layers.
self.num_layers = core.get_num_layers(
self.num_layers = mpu.get_num_layers(
args, args.model_type == ModelType.encoder_and_decoder)
self.drop_path_rates = [rate.item() for rate in torch.linspace(0, self.drop_path_rate, args.num_layers)]
......@@ -799,21 +798,21 @@ class ParallelTransformer(MegatronModule):
# layers to stages like (each list is a model chunk):
# Stage 0: [0, 1] [4, 5]
# Stage 1: [2, 3] [6, 7]
offset = core.get_virtual_pipeline_model_parallel_rank() * (
offset = mpu.get_virtual_pipeline_model_parallel_rank() * (
args.num_layers // args.virtual_pipeline_model_parallel_size) + \
(core.get_pipeline_model_parallel_rank() * self.num_layers)
(mpu.get_pipeline_model_parallel_rank() * self.num_layers)
else:
# Each stage gets a contiguous set of layers.
if args.model_type == ModelType.encoder_and_decoder and \
core.get_pipeline_model_parallel_world_size() > 1:
pipeline_rank = core.get_pipeline_model_parallel_rank()
mpu.get_pipeline_model_parallel_world_size() > 1:
pipeline_rank = mpu.get_pipeline_model_parallel_rank()
if layer_type == LayerType.encoder:
offset = pipeline_rank * self.num_layers
else:
num_ranks_in_enc = args.pipeline_model_parallel_split_rank
offset = (pipeline_rank - num_ranks_in_enc) * self.num_layers
else:
offset = core.get_pipeline_model_parallel_rank() * self.num_layers
offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers
if self.num_layers == 0:
# When a standalone embedding stage is used (e.g.,
......@@ -862,7 +861,7 @@ class ParallelTransformer(MegatronModule):
# A method to further reduce memory usage reducing checkpoints.
l = 0
while l < self.num_layers:
hidden_states = core.tensor_parallel.checkpoint(
hidden_states = tensor_parallel.checkpoint(
custom(l, l + self.recompute_num_layers),
self.distribute_saved_activations,
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
......@@ -874,7 +873,7 @@ class ParallelTransformer(MegatronModule):
# A method fully use the device memory removing redundant re-computation.
for l in range(self.num_layers):
if l < self.recompute_num_layers:
hidden_states = core.tensor_parallel.checkpoint(
hidden_states = tensor_parallel.checkpoint(
custom(l, l + 1),
self.distribute_saved_activations,
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
......@@ -932,7 +931,7 @@ class ParallelTransformer(MegatronModule):
)
if self.sequence_parallel:
rng_context = core.tensor_parallel.get_cuda_rng_tracker().fork()
rng_context = tensor_parallel.get_cuda_rng_tracker().fork()
else:
rng_context = nullcontext()
......
import torch.nn.functional as F
import torch
from megatron import print_rank_0, get_args, mpu
from megatron import print_rank_0, get_args
from megatron.core import mpu
from megatron.data.vit_dataset import ClassificationTransform
from megatron.data.image_folder import ImageFolder
......
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Model parallel utility interface."""
from .initialize import is_unitialized
from .initialize import destroy_model_parallel
from .initialize import get_data_parallel_group
from .initialize import get_data_parallel_rank
from .initialize import get_data_parallel_world_size
from .initialize import get_embedding_group
from .initialize import get_position_embedding_group
from .initialize import get_model_parallel_group
from .initialize import get_tensor_model_parallel_group
from .initialize import get_pipeline_model_parallel_group
from .initialize import get_tensor_model_parallel_rank, set_tensor_model_parallel_rank
from .initialize import get_pipeline_model_parallel_rank, set_pipeline_model_parallel_rank
from .initialize import is_pipeline_first_stage, is_pipeline_last_stage
from .initialize import is_rank_in_embedding_group
from .initialize import is_rank_in_position_embedding_group
from .initialize import is_pipeline_stage_before_split, is_pipeline_stage_after_split
from .initialize import is_pipeline_stage_at_split
from .initialize import get_num_layers
from .initialize import get_tensor_model_parallel_src_rank
from .initialize import get_data_parallel_src_rank
from .initialize import get_pipeline_model_parallel_first_rank
from .initialize import get_pipeline_model_parallel_last_rank
from .initialize import get_pipeline_model_parallel_next_rank
from .initialize import get_pipeline_model_parallel_prev_rank
from .initialize import get_tensor_model_parallel_world_size, set_tensor_model_parallel_world_size
from .initialize import get_pipeline_model_parallel_world_size, set_pipeline_model_parallel_world_size
from .initialize import get_virtual_pipeline_model_parallel_rank, set_virtual_pipeline_model_parallel_rank
from .initialize import initialize_model_parallel
from .initialize import model_parallel_is_initialized
from .utils import divide
from .utils import split_tensor_along_last_dim
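
Taken together with the hunks elsewhere in this commit, a minimal usage sketch of how this model-parallel interface is consumed after the change (each call below appears verbatim somewhere in this diff; the variable names are illustrative, and the topology queries assume initialize_model_parallel() has already been run):

from megatron.core import mpu, tensor_parallel

# Topology queries go through the aliased parallel_state module.
world_size = mpu.get_tensor_model_parallel_world_size()
data_parallel_rank = mpu.get_data_parallel_rank()

# Vocabulary-parallel cross entropy is now imported from
# megatron.core.tensor_parallel rather than the old megatron.mpu package.
# loss = tensor_parallel.vocab_parallel_cross_entropy(logits, labels)
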
......@@ -8,10 +8,9 @@ import torch
from megatron import get_args
from megatron import get_timers
from megatron import mpu
from megatron import print_rank_0
from megatron.core import mpu, tensor_parallel
from megatron.model.module import param_is_not_shared
from megatron.core.tensor_parallel import param_is_not_tensor_parallel_duplicate
from .optimizer import MixedPrecisionOptimizer, _zero_grad_group_helper
......@@ -290,9 +289,9 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
shard_model_param = model_param.detach().view(-1) \
[param_range.start:param_range.end]
shard_main_param = shard_model_param.clone().float()
mpu.copy_tensor_model_parallel_attributes(
tensor_parallel.copy_tensor_model_parallel_attributes(
shard_model_param, model_param)
mpu.copy_tensor_model_parallel_attributes(
tensor_parallel.copy_tensor_model_parallel_attributes(
shard_main_param, model_param)
if hasattr(model_param, 'shared'):
shard_model_param.shared = model_param.shared
......@@ -309,7 +308,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
[param_range.start:param_range.end]
model_fp32_params_this_group.append(model_param)
shard_fp32_params_this_group.append(shard_model_param)
mpu.copy_tensor_model_parallel_attributes(
tensor_parallel.copy_tensor_model_parallel_attributes(
shard_model_param, model_param)
if hasattr(model_param, 'shared'):
shard_model_param.shared = model_param.shared
......
......@@ -11,13 +11,11 @@ from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from megatron import get_timers
from megatron import mpu
from megatron import core
from megatron import print_rank_0
from megatron.core import mpu, tensor_parallel
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
from megatron.model.module import param_is_not_shared
from megatron.core.tensor_parallel import param_is_not_tensor_parallel_duplicate
from megatron.utils import unwrap_model
from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32
......@@ -103,7 +101,7 @@ class MegatronOptimizer(ABC):
grad = param.grad
grad_not_none = grad is not None
is_not_shared = param_is_not_shared(param)
is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
is_not_tp_duplicate = tensor_parallel.param_is_not_tensor_parallel_duplicate(param)
if grad_not_none and is_not_shared and is_not_tp_duplicate:
grads_for_norm.append(grad)
......@@ -528,8 +526,8 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
# Create a copy
main_param = param.detach().clone().float()
# Copy tensor model parallel attributes.
core.tensor_parallel.copy_tensor_model_parallel_attributes(main_param,
param)
tensor_parallel.copy_tensor_model_parallel_attributes(main_param,
param)
if hasattr(param, 'shared'):
main_param.shared = param.shared
# Replace the optimizer params with the new fp32 copy.
......
......@@ -4,8 +4,8 @@ from functools import reduce
import operator
import torch
from megatron import get_args
from megatron import mpu
from megatron import get_args, core
from megatron.core import mpu
def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
......@@ -81,10 +81,10 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
args.scatter_gather_tensors_in_pipeline and \
not args.sequence_parallel:
if tensor_send_next is not None:
tensor_send_next = mpu.split_tensor_into_1d_equal_chunks(tensor_send_next)
tensor_send_next = core.tensor_parallel.split_tensor_into_1d_equal_chunks(tensor_send_next)
if tensor_send_prev is not None:
tensor_send_prev = mpu.split_tensor_into_1d_equal_chunks(tensor_send_prev)
tensor_send_prev = core.tensor_parallel.split_tensor_into_1d_equal_chunks(tensor_send_prev)
# Send tensors in both the forward and backward directions as appropriate.
if args.use_ring_exchange_p2p:
......@@ -127,18 +127,18 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
args.scatter_gather_tensors_in_pipeline and \
not args.sequence_parallel:
if recv_prev:
tensor_recv_prev = mpu.gather_split_1d_tensor(
tensor_recv_prev = core.tensor_parallel.gather_split_1d_tensor(
tensor_recv_prev).view(tensor_shape).requires_grad_()
tensor_recv_prev = mpu.make_viewless_tensor(tensor_recv_prev,
requires_grad = True,
keep_graph = False)
tensor_recv_prev = core.utils.make_viewless_tensor(tensor_recv_prev,
requires_grad = True,
keep_graph = False)
if recv_next:
tensor_recv_next = mpu.gather_split_1d_tensor(
tensor_recv_next = core.tensor_parallel.gather_split_1d_tensor(
tensor_recv_next).view(tensor_shape).requires_grad_()
tensor_recv_next = mpu.make_viewless_tensor(tensor_recv_next,
requires_grad = True,
keep_graph = False)
tensor_recv_next = core.utils.make_viewless_tensor(tensor_recv_next,
requires_grad = True,
keep_graph = False)
return tensor_recv_prev, tensor_recv_next
......
......@@ -8,8 +8,8 @@ from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron import get_args
from megatron import get_num_microbatches
from megatron import get_timers
from megatron import mpu
from megatron import p2p_communication
from megatron.core import mpu
from megatron.utils import unwrap_model
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module
......
......@@ -5,7 +5,7 @@
import torch
from megatron import mpu
from megatron.core import mpu
from .communication import broadcast_float_list
from .generation import (
generate_tokens_probs_and_return_on_first_stage,
......
......@@ -5,7 +5,7 @@
import torch
from megatron import mpu
from megatron.core import mpu
......
......@@ -6,9 +6,8 @@ from collections.abc import Iterable
import torch
from megatron import (
get_args,
mpu)
from megatron import get_args
from megatron.core import mpu
from .communication import (
send_to_next_pipeline_rank,
recv_from_prev_pipeline_rank_)
......
......@@ -5,7 +5,8 @@
import torch
import torch.nn.functional as F
from megatron import get_args, get_tokenizer, mpu
from megatron import get_args, get_tokenizer
from megatron.core import mpu
from megatron.utils import get_ltor_masks_and_position_ids
from .communication import (
copy_from_last_to_first_pipeline_stage,
......
......@@ -19,8 +19,7 @@ from megatron import get_current_global_batch_size
from megatron import get_num_microbatches
from megatron import is_last_rank
from megatron import update_num_microbatches
from megatron import mpu
from megatron import core
from megatron.core import mpu, tensor_parallel
from megatron import print_rank_0
from megatron import print_rank_last
from megatron.checkpointing import load_checkpoint
......@@ -258,7 +257,7 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
# are set for all params so the optimizer can use them.
for model_module in model:
for param in model_module.parameters():
core.tensor_parallel.set_defaults_if_not_set_tensor_model_parallel_attributes(param)
tensor_parallel.set_defaults_if_not_set_tensor_model_parallel_attributes(param)
# Print number of parameters.
if mpu.get_data_parallel_rank() == 0:
......
......@@ -10,11 +10,13 @@ from torch.nn.parallel import DistributedDataParallel as torchDDP
from apex.multi_tensor_apply import multi_tensor_applier
import amp_C
from megatron import get_args
from megatron import get_adlr_autoresume
from megatron import mpu
from megatron.model.module import param_is_not_shared
from megatron import (
get_args,
get_adlr_autoresume,
)
from megatron.core import mpu
from megatron.core.tensor_parallel import param_is_not_tensor_parallel_duplicate
from megatron.model.module import param_is_not_shared
def unwrap_model(model, module_instances=(torchDDP)):
......