Commit 5942af97 authored by Jared Casper

Alias core.parallel_state as mpu and use it throughout code. RIP mpu.

parent c2ea914f
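The commit message above describes the refactor: the old `megatron.mpu` package is retired, and `megatron.core.parallel_state` is re-exported under the legacy name `mpu`. The sketch below illustrates how such an alias can be exposed from `megatron/core/__init__.py`; that file's exact contents are not part of this diff, so the module list and the `mpu = parallel_state` line are assumptions inferred from the import changes that follow.

# megatron/core/__init__.py -- hypothetical sketch, not taken from this commit
from . import parallel_state
from . import tensor_parallel
from . import utils

# Keep "mpu" (model parallel utilities) as a legacy alias for parallel_state,
# which is why call sites below can write `from megatron.core import mpu`.
mpu = parallel_state

With an alias like this in place, the call-site edits in the diff reduce to swapping `from megatron import mpu` for `from megatron.core import mpu` and `core.tensor_parallel.*` for `tensor_parallel.*`.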
@@ -8,7 +8,7 @@ import torch
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 from megatron import get_args
-from megatron import mpu
+from megatron.core import mpu
 from .module import MegatronModule
...
@@ -5,8 +5,7 @@
 import torch

 from megatron import get_args
-from megatron import mpu
-from megatron import core
+from megatron.core import tensor_parallel
 from .module import MegatronModule
 from .enums import AttnMaskType
@@ -34,9 +33,9 @@ def post_language_model_processing(lm_output, labels, logit_weights,
         labels = labels.transpose(0,1).contiguous()
         if fp16_lm_cross_entropy:
             assert output.dtype == torch.half
-            loss = core.tensor_parallel.vocab_parallel_cross_entropy(output, labels)
+            loss = tensor_parallel.vocab_parallel_cross_entropy(output, labels)
         else:
-            loss = core.tensor_parallel.vocab_parallel_cross_entropy(output.float(), labels)
+            loss = tensor_parallel.vocab_parallel_cross_entropy(output.float(), labels)
         # [s b] => [b, s]
         loss = loss.transpose(0,1).contiguous()
...
@@ -6,7 +6,7 @@ import torch
 import torch.nn.functional as F

 from megatron import get_args
-from megatron import core
+from megatron.core import mpu, tensor_parallel
 from .module import MegatronModule
 from megatron.model.enums import LayerType, AttnMaskType
 from megatron.model.transformer import ParallelTransformer
@@ -22,15 +22,15 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
     if args.async_tensor_model_parallel_allreduce or\
        args.sequence_parallel:
         input_parallel = input_
-        model_parallel = core.get_tensor_model_parallel_world_size() > 1
+        model_parallel = mpu.get_tensor_model_parallel_world_size() > 1
         async_grad_allreduce = args.async_tensor_model_parallel_allreduce and \
             model_parallel and not args.sequence_parallel
     else:
-        input_parallel = core.tensor_parallel.copy_to_tensor_model_parallel_region(input_)
+        input_parallel = tensor_parallel.copy_to_tensor_model_parallel_region(input_)
         async_grad_allreduce = False

     # Matrix multiply.
-    logits_parallel = core.tensor_parallel.linear_with_grad_accumulation_and_async_allreduce(
+    logits_parallel = tensor_parallel.linear_with_grad_accumulation_and_async_allreduce(
         input=input_parallel,
         weight=word_embeddings_weight,
         bias=bias,
@@ -42,7 +42,7 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
     if parallel_output:
         return logits_parallel

-    return core.tensor_parallel.gather_from_tensor_model_parallel_region(logits_parallel)
+    return tensor_parallel.gather_from_tensor_model_parallel_region(logits_parallel)


 def get_language_model(num_tokentypes, add_pooler,
@@ -106,7 +106,7 @@ class Pooler(MegatronModule):
         # gather data along sequence dimensions
         # same pooler is run on all tensor parallel nodes
         if self.sequence_parallel:
-            hidden_states = core.tensor_parallel.gather_from_sequence_parallel_region(
+            hidden_states = tensor_parallel.gather_from_sequence_parallel_region(
                 hidden_states,
                 tensor_parallel_output_grad=False)
@@ -146,7 +146,7 @@ class Embedding(MegatronModule):
         args = get_args()

         # Word embeddings (parallel).
-        self.word_embeddings = core.tensor_parallel.VocabParallelEmbedding(
+        self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
             vocab_size, self.hidden_size,
             init_method=self.init_method,
             params_dtype=args.params_dtype,
@@ -229,8 +229,8 @@ class Embedding(MegatronModule):
         # Dropout.
         if self.sequence_parallel:
-            embeddings = core.tensor_parallel.scatter_to_sequence_parallel_region(embeddings)
-            with core.tensor_parallel.get_cuda_rng_tracker().fork():
+            embeddings = tensor_parallel.scatter_to_sequence_parallel_region(embeddings)
+            with tensor_parallel.get_cuda_rng_tracker().fork():
                 embeddings = self.embedding_dropout(embeddings)
         else:
             embeddings = self.embedding_dropout(embeddings)
...
@@ -7,8 +7,7 @@ from torch.autograd import Variable
 from torch.nn.parameter import Parameter

 from megatron import get_args
-from megatron import mpu
-from megatron import core
+from megatron.core import mpu, tensor_parallel

 _FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
@@ -77,7 +76,7 @@ class MegatronModule(torch.nn.Module):
                 self._word_embeddings_for_head_key = 'word_embeddings_for_head'
                 # set word_embeddings weights to 0 here, then copy first
                 # stage's weights using all_reduce below.
-                self.word_embeddings = core.tensor_parallel.VocabParallelEmbedding(
+                self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
                     args.padded_vocab_size, args.hidden_size,
                     init_method=init_method_normal(args.init_method_std),
                     params_dtype=args.params_dtype,
...
@@ -5,7 +5,6 @@
 import torch

 from megatron import get_args, print_rank_last
-from megatron import mpu
 from megatron.model.enums import AttnMaskType
 from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids
 from megatron.model.language_model import get_language_model
...
@@ -5,7 +5,7 @@ from megatron import get_args, print_rank_0
 from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
 from megatron.model import BertModel
 from .module import MegatronModule
-from megatron import mpu
+from megatron.core import mpu
 from megatron.model.enums import AttnMaskType
 from megatron.model.utils import get_linear_layer
 from megatron.model.utils import init_method_normal
...
@@ -4,10 +4,8 @@
 import torch

-from megatron import (
-    get_args,
-    mpu
-)
+from megatron import get_args
+from megatron.core import tensor_parallel
 from megatron.model.enums import AttnMaskType
 from megatron.model.language_model import parallel_lm_logits, get_language_model
 from megatron.model.transformer import LayerNorm
@@ -151,10 +149,10 @@ class T5Model(MegatronModule):
             lm_labels = lm_labels.transpose(0,1).contiguous()
             if self.fp16_lm_cross_entropy:
                 assert lm_logits.dtype == torch.half
-                lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
+                lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits, lm_labels)
             else:
-                lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits.float(),
+                lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits.float(),
                                                            lm_labels)
             # [s b] => [b s]
             lm_loss = lm_loss.transpose(0,1).contiguous()
             return lm_loss
...
@@ -6,10 +6,9 @@ from contextlib import nullcontext
 import torch
 import torch.nn.functional as F

-from megatron import get_timers, get_args
-from megatron.core import get_global_memory_buffer
-from megatron import core
+from megatron import get_timers, get_args, core
 from .module import MegatronModule
+from megatron.core import mpu, tensor_parallel
 from megatron.model.enums import AttnMaskType, ModelType, LayerType, AttnType
 from megatron.model import LayerNorm
 from megatron.model.fused_softmax import FusedScaleMaskSoftmax
@@ -79,7 +78,7 @@ class ParallelMLP(MegatronModule):
         # Project to 4h.
-        self.dense_h_to_4h = core.tensor_parallel.ColumnParallelLinear(
+        self.dense_h_to_4h = tensor_parallel.ColumnParallelLinear(
             args.hidden_size,
             args.ffn_hidden_size,
             gather_output=False,
@@ -96,7 +95,7 @@ class ParallelMLP(MegatronModule):
             self.activation_func = erf_gelu

         # Project back to h.
-        self.dense_4h_to_h = core.tensor_parallel.RowParallelLinear(
+        self.dense_4h_to_h = tensor_parallel.RowParallelLinear(
             args.ffn_hidden_size,
             args.hidden_size,
             input_is_parallel=True,
@@ -189,7 +188,7 @@ class CoreAttention(MegatronModule):
         projection_size = args.kv_channels * args.num_attention_heads

         # Per attention head and per partition values.
-        world_size = core.get_tensor_model_parallel_world_size()
+        world_size = mpu.get_tensor_model_parallel_world_size()
         self.hidden_size_per_partition = core.utils.divide(projection_size,
                                                            world_size)
         self.hidden_size_per_attention_head = core.utils.divide(
@@ -237,7 +236,7 @@ class CoreAttention(MegatronModule):
             output_size[0] * output_size[1], -1)

         # preallocting input tensor: [b * np, sq, sk]
-        matmul_input_buffer = get_global_memory_buffer().get_tensor(
+        matmul_input_buffer = mpu.get_global_memory_buffer().get_tensor(
             (output_size[0]*output_size[1], output_size[2], output_size[3]),
             query_layer.dtype, "mpu")
@@ -263,7 +262,7 @@ class CoreAttention(MegatronModule):
         # seem a bit unusual, but is taken from the original Transformer paper.
         if not self.sequence_parallel:
-            with core.tensor_parallel.get_cuda_rng_tracker().fork():
+            with tensor_parallel.get_cuda_rng_tracker().fork():
                 attention_probs = self.attention_dropout(attention_probs)
         else:
             attention_probs = self.attention_dropout(attention_probs)
@@ -327,7 +326,7 @@ class ParallelAttention(MegatronModule):
         projection_size = args.kv_channels * args.num_attention_heads

         # Per attention head and per partition values.
-        world_size = core.get_tensor_model_parallel_world_size()
+        world_size = mpu.get_tensor_model_parallel_world_size()
         self.hidden_size_per_attention_head = core.utils.divide(
             projection_size, args.num_attention_heads)
         self.num_attention_heads_per_partition = core.utils.divide(
@@ -335,7 +334,7 @@ class ParallelAttention(MegatronModule):
         # Strided linear layer.
         if attention_type == AttnType.self_attn:
-            self.query_key_value = core.tensor_parallel.ColumnParallelLinear(
+            self.query_key_value = tensor_parallel.ColumnParallelLinear(
                 args.hidden_size,
                 3 * projection_size,
                 gather_output=False,
@@ -344,7 +343,7 @@ class ParallelAttention(MegatronModule):
                 **_args_to_kwargs())
         else:
             assert attention_type == AttnType.cross_attn
-            self.query = core.tensor_parallel.ColumnParallelLinear(
+            self.query = tensor_parallel.ColumnParallelLinear(
                 args.hidden_size,
                 projection_size,
                 gather_output=False,
@@ -353,7 +352,7 @@ class ParallelAttention(MegatronModule):
                 **_args_to_kwargs())

-            self.key_value = core.tensor_parallel.ColumnParallelLinear(
+            self.key_value = tensor_parallel.ColumnParallelLinear(
                 args.hidden_size,
                 2 * projection_size,
                 gather_output=False,
@@ -366,7 +365,7 @@ class ParallelAttention(MegatronModule):
         self.checkpoint_core_attention = args.recompute_granularity == 'selective'

         # Output.
-        self.dense = core.tensor_parallel.RowParallelLinear(
+        self.dense = tensor_parallel.RowParallelLinear(
             projection_size,
             args.hidden_size,
             input_is_parallel=True,
@@ -386,7 +385,7 @@ class ParallelAttention(MegatronModule):
                                           value_layer, attention_mask)
             return output_

-        hidden_states = core.tensor_parallel.checkpoint(
+        hidden_states = tensor_parallel.checkpoint(
             custom_forward,
             False, query_layer, key_layer, value_layer, attention_mask)
@@ -439,7 +438,7 @@ class ParallelAttention(MegatronModule):
             # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
             (query_layer,
              key_layer,
-             value_layer) = core.tensor_parallel.split_tensor_along_last_dim(mixed_x_layer, 3)
+             value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_x_layer, 3)
         else:
             # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
             mixed_kv_layer, _ = self.key_value(encoder_output)
@@ -452,7 +451,7 @@ class ParallelAttention(MegatronModule):
             # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
             (key_layer,
-             value_layer) = core.tensor_parallel.split_tensor_along_last_dim(mixed_kv_layer, 2)
+             value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_kv_layer, 2)

             # Attention head [sq, b, h] --> [sq, b, hp]
             query_layer, _ = self.query(hidden_states)
@@ -769,7 +768,7 @@ class ParallelTransformer(MegatronModule):
         self.sequence_parallel = args.sequence_parallel

         # Number of layers.
-        self.num_layers = core.get_num_layers(
+        self.num_layers = mpu.get_num_layers(
             args, args.model_type == ModelType.encoder_and_decoder)

         self.drop_path_rates = [rate.item() for rate in torch.linspace(0, self.drop_path_rate, args.num_layers)]
@@ -799,21 +798,21 @@ class ParallelTransformer(MegatronModule):
             # layers to stages like (each list is a model chunk):
             # Stage 0: [0, 1] [4, 5]
             # Stage 1: [2, 3] [6, 7]
-            offset = core.get_virtual_pipeline_model_parallel_rank() * (
+            offset = mpu.get_virtual_pipeline_model_parallel_rank() * (
                 args.num_layers // args.virtual_pipeline_model_parallel_size) + \
-                (core.get_pipeline_model_parallel_rank() * self.num_layers)
+                (mpu.get_pipeline_model_parallel_rank() * self.num_layers)
         else:
             # Each stage gets a contiguous set of layers.
             if args.model_type == ModelType.encoder_and_decoder and \
-                    core.get_pipeline_model_parallel_world_size() > 1:
-                pipeline_rank = core.get_pipeline_model_parallel_rank()
+                    mpu.get_pipeline_model_parallel_world_size() > 1:
+                pipeline_rank = mpu.get_pipeline_model_parallel_rank()
                 if layer_type == LayerType.encoder:
                     offset = pipeline_rank * self.num_layers
                 else:
                     num_ranks_in_enc = args.pipeline_model_parallel_split_rank
                     offset = (pipeline_rank - num_ranks_in_enc) * self.num_layers
             else:
-                offset = core.get_pipeline_model_parallel_rank() * self.num_layers
+                offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers

         if self.num_layers == 0:
             # When a standalone embedding stage is used (e.g.,
@@ -862,7 +861,7 @@ class ParallelTransformer(MegatronModule):
             # A method to further reduce memory usage reducing checkpoints.
             l = 0
             while l < self.num_layers:
-                hidden_states = core.tensor_parallel.checkpoint(
+                hidden_states = tensor_parallel.checkpoint(
                     custom(l, l + self.recompute_num_layers),
                     self.distribute_saved_activations,
                     hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
@@ -874,7 +873,7 @@ class ParallelTransformer(MegatronModule):
             # A method fully use the device memory removing redundant re-computation.
             for l in range(self.num_layers):
                 if l < self.recompute_num_layers:
-                    hidden_states = core.tensor_parallel.checkpoint(
+                    hidden_states = tensor_parallel.checkpoint(
                         custom(l, l + 1),
                         self.distribute_saved_activations,
                         hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
@@ -932,7 +931,7 @@ class ParallelTransformer(MegatronModule):
             )

         if self.sequence_parallel:
-            rng_context = core.tensor_parallel.get_cuda_rng_tracker().fork()
+            rng_context = tensor_parallel.get_cuda_rng_tracker().fork()
         else:
             rng_context = nullcontext()
...
 import torch.nn.functional as F
 import torch
-from megatron import print_rank_0, get_args, mpu
+from megatron import print_rank_0, get_args
+from megatron.core import mpu
 from megatron.data.vit_dataset import ClassificationTransform
 from megatron.data.image_folder import ImageFolder
...
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-
-"""Model parallel utility interface."""
-
-from .initialize import is_unitialized
-from .initialize import destroy_model_parallel
-from .initialize import get_data_parallel_group
-from .initialize import get_data_parallel_rank
-from .initialize import get_data_parallel_world_size
-from .initialize import get_embedding_group
-from .initialize import get_position_embedding_group
-from .initialize import get_model_parallel_group
-from .initialize import get_tensor_model_parallel_group
-from .initialize import get_pipeline_model_parallel_group
-from .initialize import get_tensor_model_parallel_rank, set_tensor_model_parallel_rank
-from .initialize import get_pipeline_model_parallel_rank, set_pipeline_model_parallel_rank
-from .initialize import is_pipeline_first_stage, is_pipeline_last_stage
-from .initialize import is_rank_in_embedding_group
-from .initialize import is_rank_in_position_embedding_group
-from .initialize import is_pipeline_stage_before_split, is_pipeline_stage_after_split
-from .initialize import is_pipeline_stage_at_split
-from .initialize import get_num_layers
-from .initialize import get_tensor_model_parallel_src_rank
-from .initialize import get_data_parallel_src_rank
-from .initialize import get_pipeline_model_parallel_first_rank
-from .initialize import get_pipeline_model_parallel_last_rank
-from .initialize import get_pipeline_model_parallel_next_rank
-from .initialize import get_pipeline_model_parallel_prev_rank
-from .initialize import get_tensor_model_parallel_world_size, set_tensor_model_parallel_world_size
-from .initialize import get_pipeline_model_parallel_world_size, set_pipeline_model_parallel_world_size
-from .initialize import get_virtual_pipeline_model_parallel_rank, set_virtual_pipeline_model_parallel_rank
-from .initialize import initialize_model_parallel
-from .initialize import model_parallel_is_initialized
-from .utils import divide
-from .utils import split_tensor_along_last_dim
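The block above is the re-export surface of the old mpu package interface that this commit retires. As a rough orientation, the sketch below shows where equivalent symbols are imported from after the change; it is inferred only from the call-site edits visible in this diff, not from an explicit mapping in the commit, and uses only names that appear on the new side of the diff.

# Hedged sketch of the post-commit import surface (inferred from this diff).
from megatron.core import mpu, tensor_parallel, utils

world_size = mpu.get_tensor_model_parallel_world_size     # process-group getters stay on mpu
xent = tensor_parallel.vocab_parallel_cross_entropy        # tensor-parallel ops move to tensor_parallel
split = tensor_parallel.split_tensor_along_last_dim
div = utils.divide                                         # generic helpers live in core.utils
viewless = utils.make_viewless_tensor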
@@ -8,10 +8,9 @@ import torch
 from megatron import get_args
 from megatron import get_timers
-from megatron import mpu
 from megatron import print_rank_0
+from megatron.core import mpu, tensor_parallel
 from megatron.model.module import param_is_not_shared
-from megatron.core.tensor_parallel import param_is_not_tensor_parallel_duplicate
 from .optimizer import MixedPrecisionOptimizer, _zero_grad_group_helper
@@ -290,9 +289,9 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                     shard_model_param = model_param.detach().view(-1) \
                         [param_range.start:param_range.end]
                     shard_main_param = shard_model_param.clone().float()
-                    mpu.copy_tensor_model_parallel_attributes(
+                    tensor_parallel.copy_tensor_model_parallel_attributes(
                         shard_model_param, model_param)
-                    mpu.copy_tensor_model_parallel_attributes(
+                    tensor_parallel.copy_tensor_model_parallel_attributes(
                         shard_main_param, model_param)
                     if hasattr(model_param, 'shared'):
                         shard_model_param.shared = model_param.shared
@@ -309,7 +308,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                         [param_range.start:param_range.end]
                     model_fp32_params_this_group.append(model_param)
                     shard_fp32_params_this_group.append(shard_model_param)
-                    mpu.copy_tensor_model_parallel_attributes(
+                    tensor_parallel.copy_tensor_model_parallel_attributes(
                         shard_model_param, model_param)
                     if hasattr(model_param, 'shared'):
                         shard_model_param.shared = model_param.shared
...
@@ -11,13 +11,11 @@ from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

 from megatron import get_timers
-from megatron import mpu
-from megatron import core
 from megatron import print_rank_0
+from megatron.core import mpu, tensor_parallel
 from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.model import Float16Module
 from megatron.model.module import param_is_not_shared
-from megatron.core.tensor_parallel import param_is_not_tensor_parallel_duplicate
 from megatron.utils import unwrap_model

 from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32
@@ -103,7 +101,7 @@ class MegatronOptimizer(ABC):
             grad = param.grad
             grad_not_none = grad is not None
             is_not_shared = param_is_not_shared(param)
-            is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
+            is_not_tp_duplicate = tensor_parallel.param_is_not_tensor_parallel_duplicate(param)
             if grad_not_none and is_not_shared and is_not_tp_duplicate:
                 grads_for_norm.append(grad)
@@ -528,8 +526,8 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
                         # Create a copy
                         main_param = param.detach().clone().float()
                         # Copy tensor model parallel attributes.
-                        core.tensor_parallel.copy_tensor_model_parallel_attributes(main_param,
+                        tensor_parallel.copy_tensor_model_parallel_attributes(main_param,
                                                                                    param)
                         if hasattr(param, 'shared'):
                             main_param.shared = param.shared

                         # Replace the optimizer params with the new fp32 copy.
...
@@ -4,8 +4,8 @@ from functools import reduce
 import operator
 import torch

-from megatron import get_args
-from megatron import mpu
+from megatron import get_args, core
+from megatron.core import mpu


 def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
@@ -81,10 +81,10 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
             args.scatter_gather_tensors_in_pipeline and \
             not args.sequence_parallel:
         if tensor_send_next is not None:
-            tensor_send_next = mpu.split_tensor_into_1d_equal_chunks(tensor_send_next)
+            tensor_send_next = core.tensor_parallel.split_tensor_into_1d_equal_chunks(tensor_send_next)

         if tensor_send_prev is not None:
-            tensor_send_prev = mpu.split_tensor_into_1d_equal_chunks(tensor_send_prev)
+            tensor_send_prev = core.tensor_parallel.split_tensor_into_1d_equal_chunks(tensor_send_prev)

     # Send tensors in both the forward and backward directions as appropriate.
     if args.use_ring_exchange_p2p:
@@ -127,18 +127,18 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
             args.scatter_gather_tensors_in_pipeline and \
             not args.sequence_parallel:
         if recv_prev:
-            tensor_recv_prev = mpu.gather_split_1d_tensor(
+            tensor_recv_prev = core.tensor_parallel.gather_split_1d_tensor(
                 tensor_recv_prev).view(tensor_shape).requires_grad_()
-            tensor_recv_prev = mpu.make_viewless_tensor(tensor_recv_prev,
+            tensor_recv_prev = core.utils.make_viewless_tensor(tensor_recv_prev,
                                                         requires_grad = True,
                                                         keep_graph = False)
         if recv_next:
-            tensor_recv_next = mpu.gather_split_1d_tensor(
+            tensor_recv_next = core.tensor_parallel.gather_split_1d_tensor(
                 tensor_recv_next).view(tensor_shape).requires_grad_()
-            tensor_recv_next = mpu.make_viewless_tensor(tensor_recv_next,
+            tensor_recv_next = core.utils.make_viewless_tensor(tensor_recv_next,
                                                         requires_grad = True,
                                                         keep_graph = False)

     return tensor_recv_prev, tensor_recv_next
...
@@ -8,8 +8,8 @@ from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
 from megatron import get_args
 from megatron import get_num_microbatches
 from megatron import get_timers
-from megatron import mpu
 from megatron import p2p_communication
+from megatron.core import mpu
 from megatron.utils import unwrap_model
 from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.model import Float16Module
...
@@ -5,7 +5,7 @@
 import torch

-from megatron import mpu
+from megatron.core import mpu
 from .communication import broadcast_float_list
 from .generation import (
     generate_tokens_probs_and_return_on_first_stage,
...
@@ -5,7 +5,7 @@
 import torch

-from megatron import mpu
+from megatron.core import mpu
...
@@ -6,9 +6,8 @@ from collections.abc import Iterable
 import torch

-from megatron import (
-    get_args,
-    mpu)
+from megatron import get_args
+from megatron.core import mpu
 from .communication import (
     send_to_next_pipeline_rank,
     recv_from_prev_pipeline_rank_)
...
@@ -5,7 +5,8 @@
 import torch
 import torch.nn.functional as F

-from megatron import get_args, get_tokenizer, mpu
+from megatron import get_args, get_tokenizer
+from megatron.core import mpu
 from megatron.utils import get_ltor_masks_and_position_ids
 from .communication import (
     copy_from_last_to_first_pipeline_stage,
...
@@ -19,8 +19,7 @@ from megatron import get_current_global_batch_size
 from megatron import get_num_microbatches
 from megatron import is_last_rank
 from megatron import update_num_microbatches
-from megatron import mpu
-from megatron import core
+from megatron.core import mpu, tensor_parallel
 from megatron import print_rank_0
 from megatron import print_rank_last
 from megatron.checkpointing import load_checkpoint
@@ -258,7 +257,7 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
     # are set for all params so the optimizer can use them.
     for model_module in model:
         for param in model_module.parameters():
-            core.tensor_parallel.set_defaults_if_not_set_tensor_model_parallel_attributes(param)
+            tensor_parallel.set_defaults_if_not_set_tensor_model_parallel_attributes(param)

     # Print number of parameters.
     if mpu.get_data_parallel_rank() == 0:
...
@@ -10,11 +10,13 @@ from torch.nn.parallel import DistributedDataParallel as torchDDP
 from apex.multi_tensor_apply import multi_tensor_applier
 import amp_C

-from megatron import get_args
-from megatron import get_adlr_autoresume
-from megatron import mpu
-from megatron.model.module import param_is_not_shared
+from megatron import (
+    get_args,
+    get_adlr_autoresume,
+)
+from megatron.core import mpu
 from megatron.core.tensor_parallel import param_is_not_tensor_parallel_duplicate
+from megatron.model.module import param_is_not_shared


 def unwrap_model(model, module_instances=(torchDDP)):
...