Commit 42c40715 authored by Jared Casper

Merge branch 'core' into 'main'

Move most of mpu functionality into a new "Megatron core"

See merge request ADLR/megatron-lm!462
parents fdc801ed d3a416cd
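The overall pattern of this refactor: symbols that previously lived under `megatron.mpu` are now imported from the new `megatron.core` package, with parallel-state queries under `megatron.core.mpu`, tensor-parallel layers and ops under `megatron.core.tensor_parallel`, and generic helpers under `megatron.core.utils`. A minimal sketch of the before/after import style, using only names that appear in this diff (the wrapper function itself is illustrative):

```python
# Before: everything was reached through megatron.mpu, e.g.
#   from megatron import mpu
#   loss = mpu.vocab_parallel_cross_entropy(logits, labels)

# After: parallel state and tensor-parallel ops are split out of megatron.core.
from megatron.core import mpu, tensor_parallel          # process groups / TP ops
from megatron.core.utils import divide, make_viewless_tensor  # generic helpers

def loss_fn(logits, labels):
    # Cross entropy over a vocabulary sharded across the tensor-model-parallel
    # group now lives under tensor_parallel rather than megatron.mpu.
    return tensor_parallel.vocab_parallel_cross_entropy(logits, labels)
```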
@@ -8,7 +8,8 @@ import time
 import numpy as np
 import torch

-from megatron import mpu, print_rank_0
+from megatron import print_rank_0
+from megatron.core import mpu
 from megatron.data.blendable_dataset import BlendableDataset
 from megatron.data.dataset_utils import get_datasets_weights_and_num_samples
 from megatron.data.dataset_utils import get_train_valid_test_split_
...
@@ -9,7 +9,8 @@ import random
 import torch
 from torch.utils.data import Dataset

-from megatron import print_rank_0, get_args, get_tokenizer, mpu
+from megatron import print_rank_0, get_args, get_tokenizer
+from megatron.core import tensor_parallel
 from megatron.data.biencoder_dataset_utils import make_attention_mask

 def get_open_retrieval_wiki_dataset():
@@ -32,7 +33,7 @@ def get_open_retrieval_batch(data_iterator):
     # Broadcast data.
     data = None if data_iterator is None else next(data_iterator)
-    data_b = mpu.broadcast_data(keys, data, datatype)
+    data_b = tensor_parallel.broadcast_data(keys, data, datatype)

     # Unpack.
     row_id = data_b['row_id'].long()
...
@@ -4,9 +4,10 @@ import time
 import numpy as np
 import torch

-from megatron import mpu, print_rank_0
+from megatron import print_rank_0
+from megatron.core import mpu, tensor_parallel
 from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_convert_to_numpy
-from megatron import get_args, get_tokenizer, print_rank_0, mpu
+from megatron import get_args, get_tokenizer, print_rank_0

 def get_one_epoch_dataloader(dataset, micro_batch_size=None):
@@ -47,7 +48,7 @@ def get_ict_batch(data_iterator):
         data = None
     else:
         data = next(data_iterator)
-    data_b = mpu.broadcast_data(keys, data, datatype)
+    data_b = tensor_parallel.broadcast_data(keys, data, datatype)

     # Unpack.
     query_tokens = data_b['query_tokens'].long()
...
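The two data helpers above follow the same broadcast pattern: rank 0 of the tensor-model-parallel group reads the batch and `tensor_parallel.broadcast_data` distributes it to the other ranks. A condensed sketch of that pattern, with illustrative keys and dtype (only `broadcast_data` and `'query_tokens'` come from this diff):

```python
import torch

from megatron.core import tensor_parallel

def get_batch(data_iterator):
    keys = ['query_tokens']        # illustrative key set
    datatype = torch.int64

    data = None if data_iterator is None else next(data_iterator)
    # Rank 0 supplies the batch; other ranks pass data=None and receive
    # the broadcast tensors keyed by the same names.
    data_b = tensor_parallel.broadcast_data(keys, data, datatype)
    return data_b['query_tokens'].long()
```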
@@ -7,7 +7,7 @@ import numpy as np
 import torch

 from megatron import get_args
-from megatron import mpu
+from megatron.core import mpu

 def detach(tensor):
@@ -50,10 +50,10 @@ class OpenRetreivalDataStore(object):
     def load_from_file(self):
        """Populate members from instance saved to file"""

-        if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
+        if not mpu.model_parallel_is_initialized() or mpu.get_data_parallel_rank() == 0:
            print("\n> Unpickling BlockData", flush=True)
        state_dict = pickle.load(open(self.embedding_path, 'rb'))
-        if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
+        if not mpu.model_parallel_is_initialized() or mpu.get_data_parallel_rank() == 0:
            print(">> Finished unpickling BlockData\n", flush=True)

        self.embed_data = state_dict['embed_data']
@@ -137,7 +137,7 @@ class FaissMIPSIndex(object):
        except ImportError:
            raise Exception("Error: Please install faiss to use FaissMIPSIndex")

-        if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
+        if not mpu.model_parallel_is_initialized() or mpu.get_data_parallel_rank() == 0:
            print("\n> Building index", flush=True)

        cpu_index = faiss.IndexFlatIP(self.embed_size)
@@ -149,12 +149,12 @@ class FaissMIPSIndex(object):
            config.useFloat16 = True
            gpu_index = faiss.index_cpu_to_all_gpus(cpu_index, co=config)
            self.mips_index = faiss.IndexIDMap(gpu_index)
-            if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
+            if not mpu.model_parallel_is_initialized() or mpu.get_data_parallel_rank() == 0:
                print(">> Initialized index on GPU", flush=True)
        else:
            # CPU index supports IDs so wrap with IDMap
            self.mips_index = faiss.IndexIDMap(cpu_index)
-            if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
+            if not mpu.model_parallel_is_initialized() or mpu.get_data_parallel_rank() == 0:
                print(">> Initialized index on CPU", flush=True)

        # if we were constructed with a BlockData, then automatically load it
@@ -199,7 +199,7 @@ class FaissMIPSIndex(object):
        self.mips_index.add_with_ids(embeds_arr, indices_arr)

-        if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
+        if not mpu.model_parallel_is_initialized() or mpu.get_data_parallel_rank() == 0:
            print(">>> Finished adding block data to index", flush=True)

    def search_mips_index(self, query_embeds, top_k, reconstruct=True):
...
@@ -4,8 +4,6 @@
 import os
 import sys
-from functools import reduce
-import operator

 import torch

 from megatron import dist_signal_handler
@@ -20,7 +18,6 @@ _GLOBAL_TENSORBOARD_WRITER = None
 _GLOBAL_ADLR_AUTORESUME = None
 _GLOBAL_TIMERS = None
 _GLOBAL_SIGNAL_HANDLER = None
-_GLOBAL_MEMORY_BUFFER = None

 def get_args():
     """Return arguments."""
@@ -70,11 +67,6 @@ def get_signal_handler():
     return _GLOBAL_SIGNAL_HANDLER

-def get_global_memory_buffer():
-    _ensure_var_is_initialized(_GLOBAL_MEMORY_BUFFER, 'global memory buffer')
-    return _GLOBAL_MEMORY_BUFFER

 def _set_signal_handler():
     global _GLOBAL_SIGNAL_HANDLER
     _ensure_var_is_not_initialized(_GLOBAL_SIGNAL_HANDLER, 'signal handler')
@@ -96,7 +88,6 @@ def set_global_variables(args):
     _set_tensorboard_writer(args)
     _set_adlr_autoresume(args)
     _set_timers(args)
-    _set_global_memory_buffer()

     if args.exit_signal_handler:
         _set_signal_handler()
@@ -176,13 +167,6 @@ def _set_timers(args):
     _GLOBAL_TIMERS = Timers(args.timing_log_level, args.timing_log_option)

-def _set_global_memory_buffer():
-    """Initialize global buffer"""
-    global _GLOBAL_MEMORY_BUFFER
-    _ensure_var_is_not_initialized(_GLOBAL_MEMORY_BUFFER, 'global memory buffer')
-    _GLOBAL_MEMORY_BUFFER = GlobalMemoryBuffer()

 def _ensure_var_is_initialized(var, name):
     """Make sure the input variable is not None."""
     assert var is not None, '{} is not initialized.'.format(name)
@@ -194,22 +178,3 @@ def _ensure_var_is_not_initialized(var, name):

-class GlobalMemoryBuffer:
-    """Global buffer to avoid dynamic memory allocations.
-       Caller should ensure that buffers of the same name
-       are not used concurrently."""
-
-    def __init__(self):
-        self.buffer = {}
-
-    def get_tensor(self, tensor_shape, dtype, name):
-        required_len = reduce(operator.mul, tensor_shape, 1)
-        if self.buffer.get((name, dtype), None) is None or \
-                self.buffer[(name, dtype)].numel() < required_len:
-            self.buffer[(name, dtype)] = \
-                torch.empty(required_len,
-                            dtype=dtype,
-                            device=torch.cuda.current_device(),
-                            requires_grad=False)
-
-        return self.buffer[(name, dtype)][0:required_len].view(*tensor_shape)
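The global memory buffer and its accessor are dropped from `global_vars.py` because they move into Megatron core; later in this diff, `transformer.py` reaches the same buffer through `mpu.get_global_memory_buffer()`. A minimal sketch of the reuse pattern, with illustrative shape and buffer name:

```python
import torch

from megatron.core import mpu

def preallocated_scores(b_np, sq, sk, dtype=torch.half):
    # get_tensor returns a view over a persistently allocated tensor, so
    # repeated calls with the same (name, dtype) avoid fresh CUDA allocations.
    return mpu.get_global_memory_buffer().get_tensor((b_np, sq, sk), dtype, "mpu")
```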
@@ -4,7 +4,7 @@ import torch
 import torch.distributed as dist

 from megatron import get_args, print_rank_0
-from megatron import mpu
+from megatron.core import mpu
 from megatron.checkpointing import load_biencoder_checkpoint
 from megatron.data.orqa_wiki_dataset import get_open_retrieval_wiki_dataset
 from megatron.data.orqa_wiki_dataset import get_open_retrieval_batch
...
@@ -14,12 +14,10 @@ from megatron import fused_kernels
 from megatron import get_adlr_autoresume
 from megatron import get_args
 from megatron import get_tensorboard_writer
-from megatron import mpu
+from megatron.core import mpu, tensor_parallel
 from megatron.arguments import (parse_args, validate_args)
 from megatron.checkpointing import load_args_from_checkpoint
 from megatron.global_vars import set_global_variables
-from megatron.mpu import (set_tensor_model_parallel_rank,
-                          set_tensor_model_parallel_world_size)
 from megatron.model.transformer import bias_dropout_add_fused_train
 from megatron.model.fused_bias_gelu import bias_gelu
@@ -64,13 +62,14 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
    args = get_args()
    if args.lazy_mpu_init:
+        # TODO is this still a necessary option?
        args.use_cpu_initialization=True
        # delayed initialization of DDP-related stuff
        # We only set basic DDP globals
-        set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
+        mpu.set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
        # and return function for external DDP manager
        # to call when it has DDP initialized
-        set_tensor_model_parallel_rank(args.rank)
+        mpu.set_tensor_model_parallel_rank(args.rank)
        return finish_mpu_init
    else:
        # Megatron's MPU is the master. Complete initialization right away.
@@ -146,7 +145,7 @@ def _compile_dependencies():

 def _initialize_distributed():
-    """Initialize torch.distributed and mpu."""
+    """Initialize torch.distributed and core model parallel."""
    args = get_args()

    device_count = torch.cuda.device_count()
@@ -184,9 +183,14 @@ def _initialize_distributed():
            print('model parallel is already initialized')
        else:
            mpu.initialize_model_parallel(args.tensor_model_parallel_size,
                                          args.pipeline_model_parallel_size,
                                          args.virtual_pipeline_model_parallel_size,
                                          args.pipeline_model_parallel_split_rank)
+            if args.rank == 0:
+                print(f'> initialized tensor model parallel with size '
+                      f'{mpu.get_tensor_model_parallel_world_size()}')
+                print(f'> initialized pipeline model parallel with size '
+                      f'{mpu.get_pipeline_model_parallel_world_size()}')

 def _init_autoresume():
@@ -210,7 +214,7 @@ def _set_random_seed(seed_, data_parallel_random_init=False):
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.device_count() > 0:
-            mpu.model_parallel_cuda_manual_seed(seed)
+            tensor_parallel.model_parallel_cuda_manual_seed(seed)
    else:
        raise ValueError('Seed ({}) should be a positive integer.'.format(seed))
...
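A condensed sketch of how initialization now flows through the core package: process groups are created via `megatron.core.mpu` and the model-parallel RNG state is seeded via `megatron.core.tensor_parallel`. All call names come from this diff; the wrapper function is illustrative only:

```python
import torch

from megatron.core import mpu, tensor_parallel

def init_model_parallel(args, seed):
    if not mpu.model_parallel_is_initialized():
        mpu.initialize_model_parallel(
            args.tensor_model_parallel_size,
            args.pipeline_model_parallel_size,
            args.virtual_pipeline_model_parallel_size,
            args.pipeline_model_parallel_split_rank)

    # The CUDA RNG tracker used for dropout under tensor parallelism is now
    # seeded through tensor_parallel instead of megatron.mpu.
    if torch.cuda.device_count() > 0:
        tensor_parallel.model_parallel_cuda_manual_seed(seed)
```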
@@ -5,7 +5,7 @@
 import torch

 from megatron import get_args
-from megatron import mpu
+from megatron.core import tensor_parallel
 from megatron.model.enums import AttnMaskType
 from megatron.model.language_model import parallel_lm_logits
 from megatron.model.language_model import get_language_model
@@ -61,7 +61,7 @@ class BertLMHead(MegatronModule):
        args = get_args()

        self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
-        mpu.set_tensor_model_parallel_attributes(self.bias, True, 0, 1)
+        tensor_parallel.set_tensor_model_parallel_attributes(self.bias, True, 0, 1)
        self.parallel_output = parallel_output

        self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
@@ -110,10 +110,10 @@ def post_language_model_processing(lm_output, pooled_output,
        # lm_logits : [s, b, h] and lm_labels: [s, b]
        if fp16_lm_cross_entropy:
            assert lm_logits.dtype == torch.half
-            lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
+            lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits, lm_labels)
        else:
-            lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits.float(),
-                                                       lm_labels)
+            lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits.float(),
+                                                                   lm_labels)
        # [s, b] => [b s]
        lm_loss = lm_loss.transpose(0,1).contiguous()
    return lm_loss, binary_logits
...
@@ -2,11 +2,11 @@ import os
 import torch
 import sys

-from megatron import get_args, print_rank_0
+from megatron import get_args, print_rank_0, get_tokenizer
+from megatron.core import mpu
 from megatron.checkpointing import fix_query_key_value_ordering
 from megatron.checkpointing import get_checkpoint_tracker_filename
 from megatron.checkpointing import get_checkpoint_name
-from megatron import mpu, get_tokenizer
 from megatron.model.bert_model import bert_position_ids
 from megatron.model.enums import AttnMaskType
 from megatron.model.language_model import get_language_model
...
@@ -5,7 +5,6 @@
 import torch

 from megatron import get_args, print_rank_last
-from megatron import mpu
 from megatron.model.enums import AttnMaskType
 from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids
 from megatron.model.language_model import get_language_model
...
@@ -8,7 +8,7 @@ import torch
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

 from megatron import get_args
-from megatron import mpu
+from megatron.core import mpu
 from .module import MegatronModule
...
@@ -10,7 +10,7 @@ from torch.nn.parameter import Parameter
 from torch.nn import init
 import importlib

-from megatron.mpu import make_viewless_tensor
+from megatron.core.utils import make_viewless_tensor

 try:
     from apex.contrib.layer_norm.layer_norm import FastLayerNormFN
...
@@ -5,7 +5,7 @@
 import torch

 from megatron import get_args
-from megatron import mpu
+from megatron.core import tensor_parallel
 from .module import MegatronModule

 from .enums import AttnMaskType
@@ -33,9 +33,9 @@ def post_language_model_processing(lm_output, labels, logit_weights,
        labels = labels.transpose(0,1).contiguous()
        if fp16_lm_cross_entropy:
            assert output.dtype == torch.half
-            loss = mpu.vocab_parallel_cross_entropy(output, labels)
+            loss = tensor_parallel.vocab_parallel_cross_entropy(output, labels)
        else:
-            loss = mpu.vocab_parallel_cross_entropy(output.float(), labels)
+            loss = tensor_parallel.vocab_parallel_cross_entropy(output.float(), labels)

        # [s b] => [b, s]
        loss = loss.transpose(0,1).contiguous()
...
@@ -6,7 +6,7 @@ import torch
 import torch.nn.functional as F

 from megatron import get_args
-from megatron import mpu
+from megatron.core import mpu, tensor_parallel
 from .module import MegatronModule
 from megatron.model.enums import LayerType, AttnMaskType
 from megatron.model.transformer import ParallelTransformer
@@ -26,20 +26,23 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
        async_grad_allreduce = args.async_tensor_model_parallel_allreduce and \
            model_parallel and not args.sequence_parallel
    else:
-        input_parallel = mpu.copy_to_tensor_model_parallel_region(input_)
+        input_parallel = tensor_parallel.copy_to_tensor_model_parallel_region(input_)
        async_grad_allreduce = False

    # Matrix multiply.
-    logits_parallel = mpu.LinearWithGradAccumulationAndAsyncCommunication.apply(
-        input_parallel, word_embeddings_weight, bias,
-        args.gradient_accumulation_fusion,
-        async_grad_allreduce, args.sequence_parallel)
+    logits_parallel = tensor_parallel.linear_with_grad_accumulation_and_async_allreduce(
+        input=input_parallel,
+        weight=word_embeddings_weight,
+        bias=bias,
+        gradient_accumulation_fusion=args.gradient_accumulation_fusion,
+        async_grad_allreduce=async_grad_allreduce,
+        sequence_parallel_enabled=args.sequence_parallel)

    # Gather if needed.
    if parallel_output:
        return logits_parallel

-    return mpu.gather_from_tensor_model_parallel_region(logits_parallel)
+    return tensor_parallel.gather_from_tensor_model_parallel_region(logits_parallel)

 def get_language_model(num_tokentypes, add_pooler,
@@ -103,7 +106,7 @@ class Pooler(MegatronModule):
        # gather data along sequence dimensions
        # same pooler is run on all tensor parallel nodes
        if self.sequence_parallel:
-            hidden_states = mpu.gather_from_sequence_parallel_region(
+            hidden_states = tensor_parallel.gather_from_sequence_parallel_region(
                hidden_states,
                tensor_parallel_output_grad=False)
@@ -143,9 +146,13 @@ class Embedding(MegatronModule):
        args = get_args()

        # Word embeddings (parallel).
-        self.word_embeddings = mpu.VocabParallelEmbedding(
+        self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
            vocab_size, self.hidden_size,
-            init_method=self.init_method)
+            init_method=self.init_method,
+            params_dtype=args.params_dtype,
+            use_cpu_initialization=args.use_cpu_initialization,
+            perform_initialization=args.perform_initialization
+        )
        self._word_embeddings_key = 'word_embeddings'

        # Position embedding (serial).
@@ -222,8 +229,8 @@ class Embedding(MegatronModule):
        # Dropout.
        if self.sequence_parallel:
-            embeddings = mpu.scatter_to_sequence_parallel_region(embeddings)
-            with mpu.get_cuda_rng_tracker().fork():
+            embeddings = tensor_parallel.scatter_to_sequence_parallel_region(embeddings)
+            with tensor_parallel.get_cuda_rng_tracker().fork():
                embeddings = self.embedding_dropout(embeddings)
        else:
            embeddings = self.embedding_dropout(embeddings)
...
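The hunk above also changes the calling convention for the fused parallel linear: the old autograd-class invocation `mpu.LinearWithGradAccumulationAndAsyncCommunication.apply(...)` with positional arguments becomes a keyword-argument call to `tensor_parallel.linear_with_grad_accumulation_and_async_allreduce(...)`. A minimal sketch of the new form, mirroring `parallel_lm_logits`; the tensors and the wrapper function are placeholders:

```python
from megatron.core import tensor_parallel

def lm_logits(hidden, weight, bias, args):
    # Copy the input into the tensor-model-parallel region, then run the
    # linear that can overlap the gradient all-reduce with computation.
    hidden_parallel = tensor_parallel.copy_to_tensor_model_parallel_region(hidden)
    logits_parallel = tensor_parallel.linear_with_grad_accumulation_and_async_allreduce(
        input=hidden_parallel,
        weight=weight,
        bias=bias,
        gradient_accumulation_fusion=args.gradient_accumulation_fusion,
        async_grad_allreduce=False,
        sequence_parallel_enabled=args.sequence_parallel)
    # Gather the vocab-sharded logits back to the full vocabulary dimension.
    return tensor_parallel.gather_from_tensor_model_parallel_region(logits_parallel)
```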
@@ -7,7 +7,7 @@ from torch.autograd import Variable
 from torch.nn.parameter import Parameter

 from megatron import get_args
-from megatron import mpu
+from megatron.core import mpu, tensor_parallel

 _FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
@@ -76,9 +76,12 @@ class MegatronModule(torch.nn.Module):
            self._word_embeddings_for_head_key = 'word_embeddings_for_head'
            # set word_embeddings weights to 0 here, then copy first
            # stage's weights using all_reduce below.
-            self.word_embeddings = mpu.VocabParallelEmbedding(
+            self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
                args.padded_vocab_size, args.hidden_size,
-                init_method=init_method_normal(args.init_method_std))
+                init_method=init_method_normal(args.init_method_std),
+                params_dtype=args.params_dtype,
+                use_cpu_initialization=args.use_cpu_initialization,
+                perform_initialization=args.perform_initialization)
            self.word_embeddings.weight.data.fill_(0)
            self.word_embeddings.weight.shared = True
...
@@ -5,7 +5,6 @@
 import torch

 from megatron import get_args, print_rank_last
-from megatron import mpu
 from megatron.model.enums import AttnMaskType
 from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids
 from megatron.model.language_model import get_language_model
...
@@ -5,7 +5,7 @@ from megatron import get_args, print_rank_0
 from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
 from megatron.model import BertModel
 from .module import MegatronModule
-from megatron import mpu
+from megatron.core import mpu
 from megatron.model.enums import AttnMaskType
 from megatron.model.utils import get_linear_layer
 from megatron.model.utils import init_method_normal
...
@@ -4,10 +4,8 @@
 import torch

-from megatron import (
-    get_args,
-    mpu
-)
+from megatron import get_args
+from megatron.core import tensor_parallel
 from megatron.model.enums import AttnMaskType
 from megatron.model.language_model import parallel_lm_logits, get_language_model
 from megatron.model.transformer import LayerNorm
@@ -151,10 +149,10 @@ class T5Model(MegatronModule):
            lm_labels = lm_labels.transpose(0,1).contiguous()
            if self.fp16_lm_cross_entropy:
                assert lm_logits.dtype == torch.half
-                lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
+                lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits, lm_labels)
            else:
-                lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits.float(),
-                                                           lm_labels)
+                lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits.float(),
+                                                                       lm_labels)
            # [s b] => [b s]
            lm_loss = lm_loss.transpose(0,1).contiguous()
            return lm_loss
...
@@ -6,9 +6,9 @@ from contextlib import nullcontext
 import torch
 import torch.nn.functional as F

-from megatron import get_timers, get_args, get_global_memory_buffer
-from megatron import mpu
+from megatron import get_timers, get_args, core
 from .module import MegatronModule
+from megatron.core import mpu, tensor_parallel
 from megatron.model.enums import AttnMaskType, ModelType, LayerType, AttnType
 from megatron.model import LayerNorm
 from megatron.model.fused_softmax import FusedScaleMaskSoftmax
@@ -32,7 +32,7 @@ from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
 """

 class DropPath(MegatronModule):
     """Drop paths (Stochastic Depth) per sample
     (when applied in main path of residual blocks).
     """
@@ -52,6 +52,17 @@ class DropPath(MegatronModule):
        output = hidden_state.div(keep_prob) * random_tensor
        return output

+def _args_to_kwargs():
+    args = get_args()
+
+    common_kwargs = {
+        "params_dtype": args.params_dtype,
+        "use_cpu_initialization": args.use_cpu_initialization,
+        "perform_initialization": args.perform_initialization,
+        "gradient_accumulation_fusion": args.gradient_accumulation_fusion,
+        "sequence_parallel_enabled": args.sequence_parallel,
+    }
+    return common_kwargs
 class ParallelMLP(MegatronModule):
     """MLP.
@@ -65,13 +76,16 @@ class ParallelMLP(MegatronModule):
        super(ParallelMLP, self).__init__()
        args = get_args()

        # Project to 4h.
-        self.dense_h_to_4h = mpu.ColumnParallelLinear(
+        self.dense_h_to_4h = tensor_parallel.ColumnParallelLinear(
            args.hidden_size,
            args.ffn_hidden_size,
            gather_output=False,
            init_method=init_method,
-            skip_bias_add=True)
+            skip_bias_add=True,
+            async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
+            **_args_to_kwargs())

        self.bias_gelu_fusion = args.bias_gelu_fusion
        self.activation_func = F.gelu
@@ -81,12 +95,13 @@ class ParallelMLP(MegatronModule):
            self.activation_func = erf_gelu

        # Project back to h.
-        self.dense_4h_to_h = mpu.RowParallelLinear(
+        self.dense_4h_to_h = tensor_parallel.RowParallelLinear(
            args.ffn_hidden_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
-            skip_bias_add=True)
+            skip_bias_add=True,
+            **_args_to_kwargs())

    def forward(self, hidden_states):
@@ -136,7 +151,7 @@ class SwitchMLP(MegatronModule):
        output_total = torch.empty_like(hidden_states)
        output_bias_total = torch.empty_like(hidden_states)
        #TODO (rprenger) This does each expert in serial, but it could be parallelized

        for expert_num, expert in enumerate(self.experts):
            local_indices = (max_ind == expert_num).nonzero()
            hidden = hidden_states[local_indices,:]
@@ -174,11 +189,11 @@ class CoreAttention(MegatronModule):

        # Per attention head and per partition values.
        world_size = mpu.get_tensor_model_parallel_world_size()
-        self.hidden_size_per_partition = mpu.divide(projection_size,
-                                                    world_size)
-        self.hidden_size_per_attention_head = mpu.divide(
+        self.hidden_size_per_partition = core.utils.divide(projection_size,
+                                                           world_size)
+        self.hidden_size_per_attention_head = core.utils.divide(
            projection_size, args.num_attention_heads)
-        self.num_attention_heads_per_partition = mpu.divide(
+        self.num_attention_heads_per_partition = core.utils.divide(
            args.num_attention_heads, world_size)

        coeff = None
@@ -221,7 +236,7 @@ class CoreAttention(MegatronModule):
            output_size[0] * output_size[1], -1)

        # preallocting input tensor: [b * np, sq, sk]
-        matmul_input_buffer = get_global_memory_buffer().get_tensor(
+        matmul_input_buffer = mpu.get_global_memory_buffer().get_tensor(
            (output_size[0]*output_size[1], output_size[2], output_size[3]),
            query_layer.dtype, "mpu")
@@ -247,7 +262,7 @@ class CoreAttention(MegatronModule):
        # seem a bit unusual, but is taken from the original Transformer paper.

        if not self.sequence_parallel:
-            with mpu.get_cuda_rng_tracker().fork():
+            with tensor_parallel.get_cuda_rng_tracker().fork():
                attention_probs = self.attention_dropout(attention_probs)
        else:
            attention_probs = self.attention_dropout(attention_probs)
@@ -312,43 +327,51 @@ class ParallelAttention(MegatronModule):

        # Per attention head and per partition values.
        world_size = mpu.get_tensor_model_parallel_world_size()
-        self.hidden_size_per_attention_head = mpu.divide(
+        self.hidden_size_per_attention_head = core.utils.divide(
            projection_size, args.num_attention_heads)
-        self.num_attention_heads_per_partition = mpu.divide(
+        self.num_attention_heads_per_partition = core.utils.divide(
            args.num_attention_heads, world_size)

        # Strided linear layer.
        if attention_type == AttnType.self_attn:
-            self.query_key_value = mpu.ColumnParallelLinear(
+            self.query_key_value = tensor_parallel.ColumnParallelLinear(
                args.hidden_size,
                3 * projection_size,
                gather_output=False,
-                init_method=init_method)
+                init_method=init_method,
+                async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
+                **_args_to_kwargs())
        else:
            assert attention_type == AttnType.cross_attn
-            self.query = mpu.ColumnParallelLinear(
+            self.query = tensor_parallel.ColumnParallelLinear(
                args.hidden_size,
                projection_size,
                gather_output=False,
-                init_method=init_method)
+                init_method=init_method,
+                async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
+                **_args_to_kwargs())

-            self.key_value = mpu.ColumnParallelLinear(
+            self.key_value = tensor_parallel.ColumnParallelLinear(
                args.hidden_size,
                2 * projection_size,
                gather_output=False,
-                init_method=init_method)
+                init_method=init_method,
+                async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
+                **_args_to_kwargs())

        self.core_attention = CoreAttention(self.layer_number,
                                            self.attn_mask_type)
        self.checkpoint_core_attention = args.recompute_granularity == 'selective'

        # Output.
-        self.dense = mpu.RowParallelLinear(
+        self.dense = tensor_parallel.RowParallelLinear(
            projection_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
-            skip_bias_add=True)
+            skip_bias_add=True,
+            **_args_to_kwargs())

    def _checkpointed_attention_forward(self, query_layer, key_layer,
                                        value_layer, attention_mask):
@@ -362,7 +385,7 @@ class ParallelAttention(MegatronModule):
                value_layer, attention_mask)
            return output_

-        hidden_states = mpu.checkpoint(
+        hidden_states = tensor_parallel.checkpoint(
            custom_forward,
            False, query_layer, key_layer, value_layer, attention_mask)
@@ -415,7 +438,7 @@ class ParallelAttention(MegatronModule):
            # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
            (query_layer,
             key_layer,
-             value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)
+             value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_x_layer, 3)
        else:
            # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
            mixed_kv_layer, _ = self.key_value(encoder_output)
@@ -428,7 +451,7 @@ class ParallelAttention(MegatronModule):
            # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
            (key_layer,
-             value_layer) = mpu.split_tensor_along_last_dim(mixed_kv_layer, 2)
+             value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_kv_layer, 2)

            # Attention head [sq, b, h] --> [sq, b, hp]
            query_layer, _ = self.query(hidden_states)
@@ -674,9 +697,9 @@ class ParallelTransformerLayer(MegatronModule):
            # won't result in memory savings (like the data loader, or
            # p2p_communication), it serves to document the origin of this
            # 'view' tensor.
-            output = mpu.make_viewless_tensor(inp = output,
-                                              requires_grad = output.requires_grad,
-                                              keep_graph = True)
+            output = core.utils.make_viewless_tensor(inp = output,
+                                                     requires_grad = output.requires_grad,
+                                                     keep_graph = True)

        else:
            out = torch.nn.functional.dropout(mlp_output + mlp_bias,
@@ -713,13 +736,65 @@ class NoopTransformerLayer(MegatronModule):
        return hidden_states.clone()

+def _get_num_layers(args, is_encoder_and_decoder_model, is_decoder=False):
+    """Compute the number of transformer layers resident on the current rank."""
+    if mpu.get_pipeline_model_parallel_world_size() > 1:
+        if is_encoder_and_decoder_model:
+            assert args.pipeline_model_parallel_split_rank is not None
+
+            # When a standalone embedding stage is used, a rank is taken from
+            # the encoder's ranks, to be used for the encoder's embedding
+            # layer. This way, the rank referenced by the 'split rank' remains
+            # the same whether or not a standalone embedding stage is used.
+            num_ranks_in_encoder = (
+                args.pipeline_model_parallel_split_rank - 1
+                if args.standalone_embedding_stage else
+                args.pipeline_model_parallel_split_rank
+            )
+            num_ranks_in_decoder = args.transformer_pipeline_model_parallel_size - num_ranks_in_encoder
+            assert args.encoder_num_layers % num_ranks_in_encoder == 0, \
+                'encoder_num_layers (%d) must be divisible by number of ranks given to encoder (%d)' % (args.encoder_num_layers, num_ranks_in_encoder)
+            assert args.decoder_num_layers % num_ranks_in_decoder == 0, \
+                'decoder_num_layers (%d) must be divisible by number of ranks given to decoder (%d)' % (args.decoder_num_layers, num_ranks_in_decoder)
+            if mpu.is_pipeline_stage_before_split():
+                num_layers = (
+                    0
+                    if args.standalone_embedding_stage
+                    and mpu.get_pipeline_model_parallel_rank() == 0 else
+                    args.encoder_num_layers // num_ranks_in_encoder
+                )
+            else:
+                num_layers = args.decoder_num_layers // num_ranks_in_decoder
+        else:
+            assert args.num_layers == args.encoder_num_layers
+            assert args.num_layers % args.transformer_pipeline_model_parallel_size == 0, \
+                'num_layers must be divisible by transformer_pipeline_model_parallel_size'
+
+            # When a standalone embedding stage is used, all transformer layers
+            # are divided among pipeline rank >= 1, while on pipeline rank 0,
+            # ranks either contain the input embedding layer (virtual pp rank 0),
+            # or no layers at all (virtual pp rank >= 1).
+            num_layers = (
+                0
+                if args.standalone_embedding_stage
+                and mpu.get_pipeline_model_parallel_rank() == 0 else
+                args.num_layers // args.transformer_pipeline_model_parallel_size
+            )
+    else:
+        if not is_decoder:
+            num_layers = args.encoder_num_layers
+        else:
+            num_layers = args.decoder_num_layers
+
+    return num_layers
 class ParallelTransformer(MegatronModule):
     """Transformer class."""

     def __init__(self, init_method, output_layer_init_method,
                  layer_type=LayerType.encoder,
                  self_attn_mask_type=AttnMaskType.padding,
                  post_layer_norm=True,
                  pre_process=True, post_process=True,
                  drop_path_rate=0.0):
        super(ParallelTransformer, self).__init__()
@@ -745,7 +820,7 @@ class ParallelTransformer(MegatronModule):
        self.sequence_parallel = args.sequence_parallel

        # Number of layers.
-        self.num_layers = mpu.get_num_layers(
+        self.num_layers = _get_num_layers(
            args,
            args.model_type == ModelType.encoder_and_decoder,
            layer_type == LayerType.decoder)
@@ -840,7 +915,7 @@ class ParallelTransformer(MegatronModule):
            # A method to further reduce memory usage reducing checkpoints.
            l = 0
            while l < self.num_layers:
-                hidden_states = mpu.checkpoint(
+                hidden_states = tensor_parallel.checkpoint(
                    custom(l, l + self.recompute_num_layers),
                    self.distribute_saved_activations,
                    hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
@@ -852,7 +927,7 @@ class ParallelTransformer(MegatronModule):
            # A method fully use the device memory removing redundant re-computation.
            for l in range(self.num_layers):
                if l < self.recompute_num_layers:
-                    hidden_states = mpu.checkpoint(
+                    hidden_states = tensor_parallel.checkpoint(
                        custom(l, l + 1),
                        self.distribute_saved_activations,
                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
@@ -898,19 +973,19 @@ class ParallelTransformer(MegatronModule):
        # However, we don't explicitly check mbs == 1 here because
        # make_viewless_tensor() has negligible overhead when its input
        # is already viewless.
        #
        # - For the 'else' case above, calling make_viewless_tensor() here is
        #   likely redundant, since p2p_communication.py (likely originator)
        #   already creates viewless tensors. That said, make_viewless_tensor()
        #   is called here to be future-proof and corner-case-proof.
-        hidden_states = mpu.make_viewless_tensor(
+        hidden_states = core.utils.make_viewless_tensor(
            hidden_states,
            requires_grad=True,
            keep_graph=True,
        )

        if self.sequence_parallel:
-            rng_context = mpu.get_cuda_rng_tracker().fork()
+            rng_context = tensor_parallel.get_cuda_rng_tracker().fork()
        else:
            rng_context = nullcontext()
...
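The core tensor-parallel layers no longer read Megatron's global args themselves, so each construction site in transformer.py now forwards dtype, initialization, and fusion options explicitly through the small `_args_to_kwargs()` helper added above. A sketch of the resulting construction pattern (the builder function is illustrative; the keyword names are the ones used in this diff):

```python
from megatron import get_args
from megatron.core import tensor_parallel

def _args_to_kwargs():
    # Bundle the global-args settings the core layers used to read themselves.
    args = get_args()
    return {
        "params_dtype": args.params_dtype,
        "use_cpu_initialization": args.use_cpu_initialization,
        "perform_initialization": args.perform_initialization,
        "gradient_accumulation_fusion": args.gradient_accumulation_fusion,
        "sequence_parallel_enabled": args.sequence_parallel,
    }

def build_dense_h_to_4h(init_method):
    args = get_args()
    # Column-parallel projection to 4h, configured explicitly instead of
    # letting the layer pull these settings from global state.
    return tensor_parallel.ColumnParallelLinear(
        args.hidden_size,
        args.ffn_hidden_size,
        gather_output=False,
        init_method=init_method,
        skip_bias_add=True,
        async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
        **_args_to_kwargs())
```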
 import torch.nn.functional as F
 import torch
-from megatron import print_rank_0, get_args, mpu
+from megatron import print_rank_0, get_args
+from megatron.core import mpu
 from megatron.data.vit_dataset import ClassificationTransform
 from megatron.data.image_folder import ImageFolder
...