Commit 3aca1415 authored by liangjing

Merge branch 'megatron-lm_dtk24.04' into 'main'

Megatron lm dtk24.04

See merge request !1
parents 0024a5c6 1005e9d3
......@@ -80,7 +80,7 @@ def _set_signal_handler():
def set_global_variables(args):
def set_global_variables(args, build_tokenizer=True):
"""Set args, tokenizer, tensorboard-writer, adlr-autoresume, and timers."""
assert args is not None
......@@ -89,7 +89,7 @@ def set_global_variables(args):
set_args(args)
_build_num_microbatches_calculator(args)
if args.vocab_file or args.tokenizer_model:
if build_tokenizer:
_ = _build_tokenizer(args)
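# Hedged note on this change: the tokenizer is now built whenever
# build_tokenizer=True (the default), so callers that construct their own
# tokenizer can pass build_tokenizer=False to skip _build_tokenizer entirely.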
_set_tensorboard_writer(args)
_set_adlr_autoresume(args)
......
......@@ -15,36 +15,40 @@ from megatron import get_adlr_autoresume
from megatron import get_args
from megatron import get_tensorboard_writer
from megatron.core import mpu, tensor_parallel
from megatron.arguments import (parse_args, validate_args)
from megatron.arguments import parse_args, validate_args
from megatron.checkpointing import load_args_from_checkpoint
from megatron.global_vars import set_global_variables
from megatron.model.transformer import bias_dropout_add_fused_train
from megatron.model.fused_bias_gelu import bias_gelu
def initialize_megatron(extra_args_provider=None, args_defaults={},
ignore_unknown_args=False, allow_no_cuda=False):
def initialize_megatron(
extra_args_provider=None,
args_defaults={},
ignore_unknown_args=False,
allow_no_cuda=False,
):
"""Set global variables, initialize distributed, and
set autoresume and random seeds.
`allow_no_cuda` should not be set unless using megatron for cpu only
data processing. In general this arg should not be set unless you know
`allow_no_cuda` should not be set unless using megatron for cpu only
data processing. In general this arg should not be set unless you know
what you are doing.
Returns a function to finalize distributed env initialization
Returns a function to finalize distributed env initialization
(optionally, only when args.lazy_mpu_init == True)
"""
if not allow_no_cuda:
# Make sure cuda is available.
assert torch.cuda.is_available(), 'Megatron requires CUDA.'
assert torch.cuda.is_available(), "Megatron requires CUDA."
# Parse arguments
args = parse_args(extra_args_provider, ignore_unknown_args)
if args.use_checkpoint_args or args_defaults.get('use_checkpoint_args', False):
assert args.load is not None, '--use-checkpoints-args requires --load argument'
if args.use_checkpoint_args or args_defaults.get("use_checkpoint_args", False):
assert args.load is not None, "--use-checkpoints-args requires --load argument"
load_args_from_checkpoint(args)
validate_args(args, args_defaults)
# set global args, build tokenizer, and set adlr-autoresume,
# tensorboard-writer, and timers.
set_global_variables(args)
......@@ -54,16 +58,16 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
args = get_args()
# Pytorch distributed.
_initialize_distributed()
# Random seeds for reproducibility.
if args.rank == 0:
print('> setting random seeds to {} ...'.format(args.seed))
print("> setting random seeds to {} ...".format(args.seed))
_set_random_seed(args.seed, args.data_parallel_random_init)
args = get_args()
if args.lazy_mpu_init:
if args.lazy_mpu_init:
# TODO is this still a necessary option?
args.use_cpu_initialization=True
args.use_cpu_initialization = True
# delayed initialization of DDP-related stuff
# We only set basic DDP globals
mpu.set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
......@@ -95,11 +99,15 @@ def _compile_dependencies():
# TODO: move this to ninja
if torch.distributed.get_rank() == 0:
start_time = time.time()
print('> compiling dataset index builder ...')
print("> compiling dataset index builder ...")
from megatron.data.dataset_utils import compile_helper
compile_helper()
print('>>> done with dataset index builder. Compilation time: {:.3f} '
'seconds'.format(time.time() - start_time), flush=True)
print(
">>> done with dataset index builder. Compilation time: {:.3f} "
"seconds".format(time.time() - start_time),
flush=True,
)
# ==================
# Load fused kernels
......@@ -107,41 +115,51 @@ def _compile_dependencies():
# Custom kernel constraints check.
seq_len = args.seq_length
attn_batch_size = \
(args.num_attention_heads / args.tensor_model_parallel_size) * \
args.micro_batch_size
attn_batch_size = (
args.num_attention_heads / args.tensor_model_parallel_size
) * args.micro_batch_size
# Constraints on sequence length and attn_batch_size to enable warp based
# optimization and upper triangular optimization (for causal mask)
custom_kernel_constraint = seq_len > 16 and seq_len <=4096 and \
seq_len % 4 == 0 and attn_batch_size % 4 == 0
custom_kernel_constraint = (
seq_len > 16
and seq_len <= 16384
and seq_len % 4 == 0
and attn_batch_size % 4 == 0
)
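# Illustrative check (numbers assumed, not from the source): with
# num_attention_heads = 32, tensor_model_parallel_size = 8, and
# micro_batch_size = 4, attn_batch_size = (32 / 8) * 4 = 16, which satisfies
# attn_batch_size % 4 == 0; a seq_length of 2048 also satisfies
# 16 < 2048 <= 16384 and 2048 % 4 == 0, so the fused softmax path is eligible.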
# Print a warning.
if not ((args.fp16 or args.bf16) and
custom_kernel_constraint and
args.masked_softmax_fusion):
if not (
(args.fp16 or args.bf16)
and custom_kernel_constraint
and args.masked_softmax_fusion
):
if args.rank == 0:
print('WARNING: constraints for invoking optimized'
' fused softmax kernel are not met. We default'
' back to unfused kernel invocations.', flush=True)
print(
"WARNING: constraints for invoking optimized"
" fused softmax kernel are not met. We default"
" back to unfused kernel invocations.",
flush=True,
)
# Always build on rank zero first.
if torch.distributed.get_rank() == 0:
start_time = time.time()
print('> compiling and loading fused kernels ...', flush=True)
fused_kernels.load(args)
print("> compiling and loading fused kernels ...", flush=True)
#fused_kernels.load(args)
torch.distributed.barrier()
else:
torch.distributed.barrier()
fused_kernels.load(args)
#fused_kernels.load(args)
# Simple barrier to make sure all ranks have passed the
# compilation phase successfully before moving on to the
# rest of the program. We think this might ensure that
# the lock is released.
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print('>>> done with compiling and loading fused kernels. '
'Compilation time: {:.3f} seconds'.format(
time.time() - start_time), flush=True)
print(
">>> done with compiling and loading fused kernels. "
"Compilation time: {:.3f} seconds".format(time.time() - start_time),
flush=True,
)
def _initialize_distributed():
......@@ -152,45 +170,58 @@ def _initialize_distributed():
if torch.distributed.is_initialized():
if args.rank == 0:
print('torch distributed is already initialized, '
'skipping initialization ...', flush=True)
args.rank = torch.distributed.get_rank()
args.world_size = torch.distributed.get_world_size()
print(
"torch distributed is already initialized, "
"skipping initialization ...",
flush=True,
)
#args.rank = torch.distributed.get_rank()
#args.world_size = torch.distributed.get_world_size()
else:
if args.rank == 0:
print('> initializing torch distributed ...', flush=True)
print("> initializing torch distributed ...", flush=True)
# Manually set the device ids.
if device_count > 0:
device = args.rank % device_count
if args.local_rank is not None:
assert args.local_rank == device, \
'expected local-rank to be the same as rank % device-count.'
assert (
args.local_rank == device
), "expected local-rank to be the same as rank % device-count."
else:
args.local_rank = device
torch.cuda.set_device(device)
# Call the init process
torch.distributed.init_process_group(
backend=args.distributed_backend,
world_size=args.world_size, rank=args.rank,
timeout=timedelta(minutes=args.distributed_timeout_minutes))
world_size=args.world_size,
rank=args.rank,
init_method=args.dist_url,
timeout=timedelta(minutes=args.distributed_timeout_minutes),
)
# Set the tensor model-parallel, pipeline model-parallel, and
# data-parallel communicators.
if device_count > 0:
if mpu.model_parallel_is_initialized():
print('model parallel is already initialized')
print("model parallel is already initialized")
else:
mpu.initialize_model_parallel(args.tensor_model_parallel_size,
args.pipeline_model_parallel_size,
args.virtual_pipeline_model_parallel_size,
args.pipeline_model_parallel_split_rank)
mpu.initialize_model_parallel(
args.tensor_model_parallel_size,
args.pipeline_model_parallel_size,
args.virtual_pipeline_model_parallel_size,
args.pipeline_model_parallel_split_rank,
args.fp8 is not None,
)
if args.rank == 0:
print(f'> initialized tensor model parallel with size '
f'{mpu.get_tensor_model_parallel_world_size()}')
print(f'> initialized pipeline model parallel with size '
f'{mpu.get_pipeline_model_parallel_world_size()}')
print(
f"> initialized tensor model parallel with size "
f"{mpu.get_tensor_model_parallel_world_size()}"
)
print(
f"> initialized pipeline model parallel with size "
f"{mpu.get_pipeline_model_parallel_world_size()}"
)
def _init_autoresume():
......@@ -216,7 +247,7 @@ def _set_random_seed(seed_, data_parallel_random_init=False):
if torch.cuda.device_count() > 0:
tensor_parallel.model_parallel_cuda_manual_seed(seed)
else:
raise ValueError('Seed ({}) should be a positive integer.'.format(seed))
raise ValueError("Seed ({}) should be a positive integer.".format(seed))
def write_args_to_tensorboard():
......@@ -225,15 +256,14 @@ def write_args_to_tensorboard():
writer = get_tensorboard_writer()
if writer:
for arg in vars(args):
writer.add_text(arg, str(getattr(args, arg)),
global_step=args.iteration)
writer.add_text(arg, str(getattr(args, arg)), global_step=args.iteration)
def set_jit_fusion_options():
"""Set PyTorch JIT layer fusion options."""
# flags required to enable jit fusion kernels
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
TORCH_MAJOR = int(torch.__version__.split(".")[0])
TORCH_MINOR = int(torch.__version__.split(".")[1])
if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10):
# nvfuser
torch._C._jit_set_profiling_executor(True)
......@@ -241,7 +271,7 @@ def set_jit_fusion_options():
torch._C._jit_override_can_fuse_on_cpu(False)
torch._C._jit_override_can_fuse_on_gpu(False)
torch._C._jit_set_texpr_fuser_enabled(False)
torch._C._jit_set_nvfuser_enabled(True)
torch._C._jit_set_nvfuser_enabled(False)
torch._C._debug_set_autodiff_subgraph_inlining(False)
else:
# legacy pytorch fuser
......@@ -254,7 +284,7 @@ def set_jit_fusion_options():
def _warmup_jit_function():
""" Compilie JIT functions before the main training steps """
"""Compilie JIT functions before the main training steps"""
args = get_args()
if args.bf16:
dtype = torch.bfloat16
......@@ -264,11 +294,20 @@ def _warmup_jit_function():
dtype = torch.float32
# Warmup fused bias+gelu
bias = torch.rand(args.ffn_hidden_size // args.tensor_model_parallel_size,
dtype=dtype, device='cuda')
input = torch.rand((args.seq_length, args.micro_batch_size,
args.ffn_hidden_size // args.tensor_model_parallel_size),
dtype=dtype, device='cuda')
bias = torch.rand(
args.ffn_hidden_size // args.tensor_model_parallel_size,
dtype=dtype,
device="cuda",
)
input = torch.rand(
(
args.seq_length,
args.micro_batch_size,
args.ffn_hidden_size // args.tensor_model_parallel_size,
),
dtype=dtype,
device="cuda",
)
# Warmup JIT fusions with the input grad_enable state of both forward
# prop and recomputation
for bias_grad, input_grad in zip([True, True], [False, True]):
......@@ -282,15 +321,25 @@ def _warmup_jit_function():
seq_length = args.seq_length // mpu.get_tensor_model_parallel_world_size()
else:
seq_length = args.seq_length
input = torch.rand((seq_length, args.micro_batch_size, args.hidden_size),
dtype=dtype, device='cuda')
residual = torch.rand((seq_length, args.micro_batch_size, args.hidden_size),
dtype=dtype, device='cuda')
bias = torch.rand((args.hidden_size), dtype=dtype, device='cuda').expand_as(residual)
input = torch.rand(
(seq_length, args.micro_batch_size, args.hidden_size),
dtype=dtype,
device="cuda",
)
residual = torch.rand(
(seq_length, args.micro_batch_size, args.hidden_size),
dtype=dtype,
device="cuda",
)
bias = torch.rand((args.hidden_size), dtype=dtype, device="cuda").expand_as(
residual
)
dropout_rate = 0.1
# Warmup JIT fusions with the input grad_enable state of both forward
# prop and recomputation
for input_grad, bias_grad, residual_grad in zip([False, True], [True, True], [True, True]):
for input_grad, bias_grad, residual_grad in zip(
[False, True], [True, True], [True, True]
):
input.requires_grad = input_grad
bias.requires_grad = bias_grad
residual.requires_grad = residual_grad
......
......@@ -47,31 +47,27 @@ class BertLMHead(MegatronModule):
"""Masked LM head for Bert
Arguments:
config: TransformerConfig object
mpu_vocab_size: model parallel size of vocabulary.
hidden_size: hidden size
init_method: init method for weight initialization
layernorm_epsilon: tolerance for layer norm divisions
parallel_output: whether output logits being distributed or not.
"""
def __init__(self, mpu_vocab_size, hidden_size, init_method,
layernorm_epsilon, parallel_output):
super(BertLMHead, self).__init__()
def __init__(self, mpu_vocab_size, hidden_size, config, parallel_output):
super().__init__(config=config)
args = get_args()
self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
tensor_parallel.set_tensor_model_parallel_attributes(self.bias, True, 0, 1)
self.parallel_output = parallel_output
self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
setattr(self.dense.weight, 'sequence_parallel', args.sequence_parallel)
setattr(self.dense.bias, 'sequence_parallel', args.sequence_parallel)
self.dense = get_linear_layer(hidden_size, hidden_size, config.init_method)
setattr(self.dense.weight, 'sequence_parallel', config.sequence_parallel)
setattr(self.dense.bias, 'sequence_parallel', config.sequence_parallel)
self.layernorm = LayerNorm(hidden_size,
eps=layernorm_epsilon,
sequence_parallel=args.sequence_parallel)
eps=config.layernorm_epsilon,
sequence_parallel=config.sequence_parallel)
self.gelu = torch.nn.functional.gelu
if args.openai_gelu:
self.gelu = openai_gelu
......@@ -124,12 +120,13 @@ class BertModel(MegatronModule):
"""Bert Language model."""
def __init__(self,
config,
num_tokentypes=2,
add_binary_head=True,
parallel_output=True,
pre_process=True,
post_process=True):
super(BertModel, self).__init__()
super().__init__(config=config)
args = get_args()
# TODO this option is not yet implemented in BERT
......@@ -145,29 +142,23 @@ class BertModel(MegatronModule):
if self.return_embeddings:
assert self.post_process and self.add_binary_head
init_method = init_method_normal(args.init_method_std)
scaled_init_method = scaled_init_method_normal(args.init_method_std,
args.num_layers)
self.language_model, self._language_model_key = get_language_model(
config=config,
num_tokentypes=num_tokentypes,
add_pooler=self.add_binary_head,
encoder_attn_mask_type=AttnMaskType.padding,
init_method=init_method,
scaled_init_method=scaled_init_method,
pre_process=self.pre_process,
post_process=self.post_process)
self.initialize_word_embeddings(init_method_normal)
self.initialize_word_embeddings()
if self.post_process:
self.lm_head = BertLMHead(
self.word_embeddings_weight().size(0),
args.hidden_size, init_method, args.layernorm_epsilon, parallel_output)
self.lm_head = BertLMHead(self.shared_embedding_or_output_weight().size(0), config.hidden_size,
config, parallel_output)
self._lm_head_key = 'lm_head'
self.binary_head = None
if self.add_binary_head:
self.binary_head = get_linear_layer(args.hidden_size, 2,
init_method)
self.binary_head = get_linear_layer(config.hidden_size, 2,
config.init_method)
self._binary_head_key = 'binary_head'
def set_input_tensor(self, input_tensor):
......@@ -215,7 +206,7 @@ class BertModel(MegatronModule):
return post_language_model_processing(lm_output, pooled_output,
self.lm_head, self.binary_head,
lm_labels,
self.word_embeddings_weight(),
self.shared_embedding_or_output_weight(),
self.fp16_lm_cross_entropy)
else:
return lm_output
......
......@@ -17,25 +17,23 @@ from .module import MegatronModule
class Classification(MegatronModule):
def __init__(self,
config,
num_classes,
num_tokentypes=2,
pre_process=True,
post_process=True):
super(Classification, self).__init__(share_word_embeddings=False)
super().__init__(config=config, share_embeddings_and_output_weights=False)
args = get_args()
self.num_classes = num_classes
self.pre_process = pre_process
self.post_process = post_process
init_method = init_method_normal(args.init_method_std)
self.language_model, self._language_model_key = get_language_model(
config=config,
num_tokentypes=num_tokentypes,
add_pooler=True,
encoder_attn_mask_type=AttnMaskType.padding,
init_method=init_method,
scaled_init_method=scaled_init_method_normal(args.init_method_std,
args.num_layers),
pre_process=self.pre_process,
post_process=self.post_process)
......
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from abc import ABC
from abc import abstractmethod
......@@ -73,7 +73,7 @@ class DistributedDataParallelBase(MegatronModule, ABC):
class DistributedDataParallel(DistributedDataParallelBase):
"""DDP with contiguous buffers options to storre and accumulate gradients.
"""DDP with contiguous buffers options to store and accumulate gradients.
This class:
- has the potential to reduce memory fragmentation.
- provides the option to do the gradient accumulation
......
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import enum
class LayerType(enum.Enum):
encoder = 1
decoder = 2
retro_encoder = 3
retro_decoder = 4
retro_decoder_with_retriever = 5
class AttnType(enum.Enum):
self_attn = 1
......
......@@ -14,7 +14,7 @@ from megatron.core.utils import make_viewless_tensor
try:
from apex.contrib.layer_norm.layer_norm import FastLayerNormFN
HAVE_PERSIST_LAYER_NORM = True
HAVE_PERSIST_LAYER_NORM = False
except:
HAVE_PERSIST_LAYER_NORM = False
......
......@@ -155,12 +155,12 @@ class FusedScaleMaskSoftmax(nn.Module):
if (
self.scaled_masked_softmax_fusion # user want to fuse
and self.input_in_float16 # input must be fp16
and 16 < sk <= 4096 # sk must be 16 ~ 2048
and 16 < sk <= 16384 # sk must be 16 ~ 16384
and sq % 4 == 0 # sq must be divisible by 4
and sk % 4 == 0 # sk must be divisible by 4
and attn_batches % 4 == 0 # np * b must be divisible by 4
):
if 0 <= sk <= 4096:
if 0 <= sk <= 16384:
batch_per_block = self.get_batch_per_block(sq, sk, b, np)
if self.attn_mask_type == AttnMaskType.causal:
......
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
"""GPT-2 model."""
......@@ -11,8 +11,6 @@ from .module import MegatronModule
from .enums import AttnMaskType
from .language_model import parallel_lm_logits
from .language_model import get_language_model
from .utils import init_method_normal
from .utils import scaled_init_method_normal
def post_language_model_processing(lm_output, labels, logit_weights,
......@@ -46,12 +44,13 @@ class GPTModel(MegatronModule):
"""GPT-2 Language model."""
def __init__(self,
config,
num_tokentypes=0,
parallel_output=True,
pre_process=True,
post_process=True):
args = get_args()
super(GPTModel, self).__init__(share_word_embeddings=not args.untie_embeddings_and_output_weights)
super().__init__(config=config, share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights)
self.parallel_output = parallel_output
self.pre_process = pre_process
......@@ -60,39 +59,39 @@ class GPTModel(MegatronModule):
self.untie_embeddings_and_output_weights = args.untie_embeddings_and_output_weights
self.language_model, self._language_model_key = get_language_model(
config=config,
num_tokentypes=num_tokentypes,
add_pooler=False,
encoder_attn_mask_type=AttnMaskType.causal,
init_method=init_method_normal(args.init_method_std),
scaled_init_method=scaled_init_method_normal(args.init_method_std,
args.num_layers),
pre_process=self.pre_process,
post_process=self.post_process)
if not args.untie_embeddings_and_output_weights:
self.initialize_word_embeddings(init_method_normal)
self.initialize_word_embeddings()
def set_input_tensor(self, input_tensor):
"""See megatron.model.transformer.set_input_tensor()"""
self.language_model.set_input_tensor(input_tensor)
def forward(self, input_ids, position_ids, attention_mask,
ret_input_ids=None, ret_position_ids=None, ret_attn_mask=None,
retriever_input_ids=None,
retriever_position_ids=None,
retriever_attn_mask=None,
labels=None, tokentype_ids=None, inference_params=None):
lm_output = self.language_model(
input_ids,
position_ids,
attention_mask,
ret_input_ids=ret_input_ids,
ret_position_ids=ret_position_ids,
ret_attn_mask=ret_attn_mask,
retriever_input_ids=retriever_input_ids,
retriever_position_ids=retriever_position_ids,
retriever_attn_mask=retriever_attn_mask,
inference_params=inference_params)
if self.post_process:
return post_language_model_processing(
lm_output, labels,
self.language_model.output_layer.weight if self.untie_embeddings_and_output_weights else self.word_embeddings_weight(),
self.language_model.output_layer.weight if self.untie_embeddings_and_output_weights else self.shared_embedding_or_output_weight(),
self.parallel_output,
self.fp16_lm_cross_entropy)
else:
......
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
"""Transformer based language model."""
......@@ -7,11 +7,11 @@ import torch.nn.functional as F
from megatron import get_args
from megatron.core import mpu, tensor_parallel
from megatron.core.enums import ModelType
from megatron.core.models.common.rotary_pos_embedding import RotaryEmbedding
from .enums import LayerType, AttnMaskType
from .enums import AttnMaskType, LayerType
from .module import MegatronModule
from .retro_transformer import ParallelRetroEncoder, ParallelRetroTransformer
from .rotary_pos_embedding import apply_rotary_pos_emb, RotaryEmbedding
from .transformer import ParallelTransformer
from .utils import get_linear_layer
from .utils import init_method_normal, scaled_init_method_normal
......@@ -39,7 +39,7 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
bias=bias,
gradient_accumulation_fusion=args.gradient_accumulation_fusion,
async_grad_allreduce=async_grad_allreduce,
sequence_parallel_enabled=args.sequence_parallel)
sequence_parallel=args.sequence_parallel)
# Gather if needed.
if parallel_output:
......@@ -48,26 +48,24 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
return tensor_parallel.gather_from_tensor_model_parallel_region(logits_parallel)
def get_language_model(num_tokentypes, add_pooler,
encoder_attn_mask_type, init_method=None,
scaled_init_method=None, add_encoder=True,
def get_language_model(config, num_tokentypes, add_pooler,
encoder_attn_mask_type,
add_encoder=True,
add_decoder=False,
decoder_attn_mask_type=AttnMaskType.causal,
pre_process=True, post_process=True):
"""Build language model and return along with the key to save."""
args = get_args()
if config.init_method is None:
config.init_method = init_method_normal(config.init_method_std)
if init_method is None:
init_method = init_method_normal(args.init_method_std)
if scaled_init_method is None:
scaled_init_method = scaled_init_method_normal(args.init_method_std,
args.num_layers)
if config.output_layer_init_method is None:
config.output_layer_init_method = scaled_init_method_normal(config.init_method_std,
config.num_layers)
# Language model.
language_model = TransformerLanguageModel(
init_method,
scaled_init_method,
config,
encoder_attn_mask_type,
num_tokentypes=num_tokentypes,
add_encoder=add_encoder,
......@@ -131,6 +129,10 @@ class Embedding(MegatronModule):
init_method: weight initialization method
num_tokentypes: size of the token-type embeddings. 0 value
will ignore this embedding
embedding_weights_in_fp32: casts word embedding weights to
fp32 before sampling. Required to
maintain reproducibility when
training in bf16.
"""
def __init__(self,
......@@ -138,28 +140,26 @@ class Embedding(MegatronModule):
vocab_size,
max_sequence_length,
embedding_dropout_prob,
init_method,
num_tokentypes=0):
config,
num_tokentypes=0,
embedding_weights_in_fp32=False):
super(Embedding, self).__init__()
self.hidden_size = hidden_size
self.init_method = init_method
self.init_method = config.init_method
self.num_tokentypes = num_tokentypes
args = get_args()
# Word embeddings (parallel).
self.embedding_weights_in_fp32 = embedding_weights_in_fp32
self.params_dtype = args.params_dtype
self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
vocab_size, self.hidden_size,
init_method=self.init_method,
params_dtype=args.params_dtype,
use_cpu_initialization=args.use_cpu_initialization,
perform_initialization=args.perform_initialization
)
vocab_size, self.hidden_size, config=config, init_method=config.init_method)
self._word_embeddings_key = 'word_embeddings'
# Position embedding (serial).
self.add_position_embedding = args.add_position_embedding
self.add_position_embedding = args.position_embedding_type == 'learned_absolute'
if self.add_position_embedding:
self.position_embeddings = torch.nn.Embedding(
max_sequence_length, self.hidden_size)
......@@ -182,7 +182,7 @@ class Embedding(MegatronModule):
else:
self.tokentype_embeddings = None
self.fp32_residual_connection = args.fp32_residual_connection
self.fp32_residual_connection = args.fp32_residual_connection
self.sequence_parallel = args.sequence_parallel
# Embeddings dropout
self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
......@@ -217,7 +217,12 @@ class Embedding(MegatronModule):
def forward(self, input_ids, position_ids, tokentype_ids=None):
# Embeddings.
if self.embedding_weights_in_fp32:
self.word_embeddings = self.word_embeddings.to(torch.float32)
words_embeddings = self.word_embeddings(input_ids)
if self.embedding_weights_in_fp32:
words_embeddings = words_embeddings.to(self.params_dtype)
self.word_embeddings = self.word_embeddings.to(self.params_dtype)
if self.add_position_embedding:
position_embeddings = self.position_embeddings(position_ids)
embeddings = words_embeddings + position_embeddings
......@@ -326,8 +331,7 @@ class TransformerLanguageModel(MegatronModule):
"""
def __init__(self,
init_method,
output_layer_init_method,
config,
encoder_attn_mask_type,
num_tokentypes=0,
add_encoder=True,
......@@ -337,21 +341,22 @@ class TransformerLanguageModel(MegatronModule):
pre_process=True,
post_process=True):
args = get_args()
# TODO: passing share_word_embeddings=False will not work correctly for T5 and embeddings will not be synced. Fix later for T5.
# TODO: passing share_embeddings_and_output_weights=False will not work correctly for T5 and embeddings will not be synced. Fix later for T5.
if args.untie_embeddings_and_output_weights: assert not add_decoder
super(TransformerLanguageModel, self).__init__(share_word_embeddings=not args.untie_embeddings_and_output_weights)
super(TransformerLanguageModel, self).__init__(share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights)
self.pre_process = pre_process
self.post_process = post_process
self.hidden_size = args.hidden_size
self.hidden_size = config.hidden_size
self.num_tokentypes = num_tokentypes
self.init_method = init_method
self.init_method = config.init_method
self.add_encoder = add_encoder
self.encoder_attn_mask_type = encoder_attn_mask_type
self.add_decoder = add_decoder
self.decoder_attn_mask_type = decoder_attn_mask_type
self.add_pooler = add_pooler
self.encoder_hidden_state = None
self.add_retriever = args.retro_add_retriever
self.untie_embeddings_and_output_weights = args.untie_embeddings_and_output_weights
# Embeddings.
......@@ -360,14 +365,15 @@ class TransformerLanguageModel(MegatronModule):
args.padded_vocab_size,
args.max_position_embeddings,
args.hidden_dropout,
self.init_method,
self.num_tokentypes)
config,
self.num_tokentypes,
args.embedding_weights_in_fp32)
self._embedding_key = 'embedding'
# Rotary positional embeddings
self.use_rotary_position_embeddings = \
args.use_rotary_position_embeddings
if args.use_rotary_position_embeddings:
self.use_rotary_position_embeddings = \
args.position_embedding_type == 'rope'
if self.use_rotary_position_embeddings:
self.seq_length = args.seq_length
rotary_dim = args.hidden_size // args.num_attention_heads \
if args.kv_channels is None else args.kv_channels
......@@ -378,41 +384,22 @@ class TransformerLanguageModel(MegatronModule):
# partial rotary embeddings, which is better than full rotary
# Wang and Komatsuzaki et al
# https://github.com/kingoflolz/mesh-transformer-jax/
self.rotary_pos_emb = RotaryEmbedding(rotary_dim)
# Retriever (bi-directional transformer with cross attention)
if args.retro_add_retriever:
self.retriever = ParallelRetroEncoder(
self.init_method,
output_layer_init_method,
self_attn_mask_type=AttnMaskType.padding,
pre_process=self.pre_process,
post_process=False,
self.rotary_pos_emb = RotaryEmbedding(
rotary_dim,
seq_len_interpolation_factor=args.rotary_seq_len_interpolation_factor
)
self._retriever_key = 'retriever'
else:
self.retriever = None
# Encoder (usually set to True, False if part of an encoder-decoder
# architecture and in encoder-only stage).
if self.add_encoder:
if args.retro_add_retriever:
self.encoder = ParallelRetroTransformer(
self.init_method,
output_layer_init_method,
self_attn_mask_type=self.encoder_attn_mask_type,
pre_process=self.pre_process,
post_process=self.post_process,
retriever=self.retriever,
)
else:
self.encoder = ParallelTransformer(
self.init_method,
output_layer_init_method,
self_attn_mask_type=self.encoder_attn_mask_type,
pre_process=self.pre_process,
post_process=self.post_process,
)
self.encoder = ParallelTransformer(
config,
model_type=args.model_type if not args.retro_add_retriever \
else ModelType.retro_decoder,
self_attn_mask_type=self.encoder_attn_mask_type,
pre_process=self.pre_process,
post_process=self.post_process,
)
self._encoder_key = 'encoder'
else:
self.encoder = None
......@@ -421,8 +408,8 @@ class TransformerLanguageModel(MegatronModule):
# architecture and in decoder-only stage).
if self.add_decoder:
self.decoder = ParallelTransformer(
self.init_method,
output_layer_init_method,
config,
model_type=args.model_type,
layer_type=LayerType.decoder,
self_attn_mask_type=self.decoder_attn_mask_type,
pre_process=self.pre_process,
......@@ -441,8 +428,9 @@ class TransformerLanguageModel(MegatronModule):
self.output_layer = tensor_parallel.ColumnParallelLinear(
args.hidden_size,
args.padded_vocab_size,
bias=False, # Setting bias to False always to keep it consistent with embedding tying that also does not have a bias.
init_method=self.init_method)
config=config,
init_method=self.init_method,
bias=False) # Setting bias to False always to keep it consistent with embedding tying that also does not have a bias.
self._output_layer_key = 'output_layer'
def set_input_tensor(self, input_tensor):
......@@ -475,19 +463,14 @@ class TransformerLanguageModel(MegatronModule):
def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
ret_input_ids=None, ret_position_ids=None, ret_attn_mask=None,
retriever_input_ids=None,
retriever_position_ids=None,
retriever_attn_mask=None,
enc_dec_attn_mask=None, tokentype_ids=None,
inference_params=None,
pooling_sequence_index=0,
enc_hidden_states=None, output_enc_hidden=False):
# Retriever embedding.
if self.retriever and self.pre_process:
retriever_input = self.embedding(ret_input_ids, ret_position_ids,
tokentype_ids=tokentype_ids)
else:
retriever_input = None
# Encoder embedding.
if self.pre_process:
encoder_input = self.embedding(enc_input_ids, enc_position_ids,
......@@ -495,31 +478,33 @@ class TransformerLanguageModel(MegatronModule):
else:
encoder_input = None
# Retriever embedding.
if self.add_retriever and self.pre_process:
retriever_input = self.embedding(retriever_input_ids,
retriever_position_ids,
tokentype_ids=tokentype_ids)
else:
retriever_input = None
# Rotary positional embeddings
rotary_pos_emb = None
if self.use_rotary_position_embeddings:
if inference_params is not None:
rotary_pos_emb = \
self.rotary_pos_emb(inference_params.max_sequence_len)
self.rotary_pos_emb(inference_params.max_sequence_length)
else:
rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
# Run encoder.
if enc_hidden_states is None:
if self.encoder is not None:
if self.retriever:
encoder_output = self.encoder(
encoder_input,
enc_attn_mask,
retriever_output=retriever_input,
retriever_attn_mask=ret_attn_mask,
inference_params=inference_params)
else:
encoder_output = self.encoder(
encoder_input,
enc_attn_mask,
inference_params=inference_params,
rotary_pos_emb=rotary_pos_emb)
encoder_output = self.encoder(
encoder_input,
enc_attn_mask,
retriever_input=retriever_input,
retriever_attn_mask=retriever_attn_mask,
inference_params=inference_params,
rotary_pos_emb=rotary_pos_emb)
else:
encoder_output = self.encoder_hidden_state
else:
......
......@@ -25,9 +25,10 @@ class MegatronModule(torch.nn.Module):
"""Megatron specific extensions of torch Module with support
for pipelining."""
def __init__(self, share_word_embeddings=True):
def __init__(self, config=None, share_embeddings_and_output_weights=True):
super(MegatronModule, self).__init__()
self.share_word_embeddings = share_word_embeddings
self.config = config
self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
......@@ -36,21 +37,21 @@ class MegatronModule(torch.nn.Module):
return self.state_dict(prefix=prefix, keep_vars=keep_vars)
def word_embeddings_weight(self):
def shared_embedding_or_output_weight(self):
if self.pre_process:
return self.language_model.embedding.word_embeddings.weight
else:
if not self.share_word_embeddings:
raise Exception('word_embeddings_weight() called for last '
'stage, but share_word_embeddings is false')
if not self.share_embeddings_and_output_weights:
raise Exception('shared_embedding_or_output_weight() called for last '
'stage, but share_embeddings_and_output_weights is false')
return self.word_embeddings.weight
def initialize_word_embeddings(self, init_method_normal):
def initialize_word_embeddings(self):
args = get_args()
if not self.share_word_embeddings:
if not self.share_embeddings_and_output_weights:
raise Exception('initialize_word_embeddings() was called but '
'share_word_embeddings is false')
'share_embeddings_and_output_weights is false')
# This function just initializes the word embeddings in the final stage
# when we are using pipeline parallelism. Nothing to do if we aren't
......@@ -76,11 +77,8 @@ class MegatronModule(torch.nn.Module):
# set word_embeddings weights to 0 here, then copy first
# stage's weights using all_reduce below.
self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
args.padded_vocab_size, args.hidden_size,
init_method=init_method_normal(args.init_method_std),
params_dtype=args.params_dtype,
use_cpu_initialization=args.use_cpu_initialization,
perform_initialization=args.perform_initialization)
args.padded_vocab_size, self.config.hidden_size,
config=self.config, init_method=self.config.init_method)
self.word_embeddings.weight.data.fill_(0)
self.word_embeddings.weight.shared = True
......@@ -103,7 +101,7 @@ class MegatronModule(torch.nn.Module):
# Ensure that first and last stages have the same initial parameter
# values.
if mpu.is_rank_in_embedding_group():
torch.distributed.all_reduce(self.word_embeddings_weight().data,
torch.distributed.all_reduce(self.shared_embedding_or_output_weight().data,
group=mpu.get_embedding_group())
# Ensure that encoder(first stage) and decoder(split stage) position
......
......@@ -17,23 +17,21 @@ from .module import MegatronModule
class MultipleChoice(MegatronModule):
def __init__(self,
config,
num_tokentypes=2,
pre_process=True,
post_process=True):
super(MultipleChoice, self).__init__(share_word_embeddings=False)
super(MultipleChoice, self).__init__(share_embeddings_and_output_weights=False)
args = get_args()
init_method = init_method_normal(args.init_method_std)
self.pre_process = pre_process
self.post_process = post_process
self.language_model, self._language_model_key = get_language_model(
config=config,
num_tokentypes=num_tokentypes,
add_pooler=True,
encoder_attn_mask_type=AttnMaskType.padding,
init_method=init_method,
scaled_init_method=scaled_init_method_normal(args.init_method_std,
args.num_layers),
pre_process=self.pre_process,
post_process=self.post_process)
......
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
"""Retro Transformer.
** Special note about this file **
Many classes and methods in this file directly parallel those in transformer.py
in name and utility. However, due to 1) subtle changes in the code over time
(i.e., transposes and contexts), and 2) other code that is soon to be merged,
this file will *temporarily* remain as is, until a larger integration is
complete.
"""
import math
import numpy as np
import torch
import torch.nn.functional as F
from megatron import get_args, get_retro_args, get_tensorboard_writer
from megatron.core import parallel_state
from megatron.core import tensor_parallel
from megatron.core import utils as core_utils
from megatron.core.enums import ModelType
from megatron.model.enums import AttnMaskType, LayerType, AttnType
from megatron.model import LayerNorm
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu, init_method_normal
from .module import MegatronModule
from .transformer import _get_num_layers, ParallelMLP, NoopTransformerLayer
""" We use the following notation throughout this file:
h: hidden size
n: number of attention heads
p: number of model parallel partitions
np: n/p
hp: h/p
hn: h/n
b: batch size
s: sequence length
l: number of layers
Transformer takes input of size [s, b, h] and returns a
tensor of the same size. We use the following arguments:
hyperparameters: transformer hyperparameters
"""
class DropPath(MegatronModule):
"""Drop paths (Stochastic Depth) per sample
(when applied in main path of residual blocks).
*Note: differs from transformer.py/DropPath in hidden_state transpose.
"""
def __init__(self, drop_prob=0.):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, hidden_state):
if self.drop_prob == 0. or not self.training:
return hidden_state
keep_prob = 1 - self.drop_prob
# work with diff dim tensors, not just 2D ConvNets
shape = (hidden_state.shape[0],) + (1,) * (hidden_state.ndim - 1)
random_tensor = keep_prob + \
torch.rand(shape, dtype=hidden_state.dtype, device=hidden_state.device)
random_tensor.floor_() # binarize
output = hidden_state.div(keep_prob) * random_tensor
return output
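# Behaviour sketch (drop_prob assumed 0.1 for illustration): in training, each
# sample's branch output is zeroed with probability 0.1 and the survivors are
# scaled by 1 / 0.9, so the expected value matches the evaluation-mode pass-through.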
class SwitchMLP(MegatronModule):
"""
Routes input to one of N MLP "experts"
"""
def __init__(self, init_method, output_layer_init_method):
super(SwitchMLP, self).__init__()
args = get_args()
self.router = torch.nn.Linear(args.hidden_size, args.num_experts)
self.experts = torch.nn.ModuleList()
for i in range(args.num_experts):
self.experts.append(ParallelMLP(init_method, output_layer_init_method))
def forward(self, hidden_states):
# hidden_states: [b, s, h]
b = hidden_states.size(0)
s = hidden_states.size(1)
h = hidden_states.size(2)
route = self.router(hidden_states)
route = torch.nn.functional.softmax(route, dim=2)
max_prob, max_ind = torch.max(route, dim=2)
max_prob = torch.unsqueeze(max_prob, 2) # [b s 1]
# TODO (rprenger) TODO this could be made easier to read
# Converting [b, s, h] to [b*s, h].
# Each vector could be routed differently
hidden_states = hidden_states.view(-1, hidden_states.size(2)) # [b*s h]
max_prob = max_prob.view(-1, max_prob.size(2)) # [b*s 1]
max_ind = max_ind.view(-1) # [b*s]
output_total = torch.empty_like(hidden_states)
output_bias_total = torch.empty_like(hidden_states)
#TODO (rprenger) This does each expert in serial, but it could be parallelized
for expert_num, expert in enumerate(self.experts):
local_indices = (max_ind == expert_num).nonzero()
hidden = hidden_states[local_indices,:]
output, output_bias = expert(hidden)
output_bias = output_bias.expand_as(output)
output_total[local_indices,:] = output
output_bias_total[local_indices,:] = output_bias
output_total = output_total*max_prob
output_bias_total = output_bias_total*max_prob
output_total = output_total.view(b, s, h)
output_bias_total = output_bias_total.view(b, s, h)
return output_total, output_bias_total
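# Routing sketch (illustrative sizes): for args.num_experts = 4 and input
# [b, s, h], the router yields [b, s, 4] scores; each token goes to its argmax
# expert, and that expert's output is scaled by the token's softmax probability
# (max_prob) before being reshaped back to [b, s, h].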
class ParallelAttention(MegatronModule):
"""Parallel self-attention layer abstract class.
Self-attention layer takes input with size [b, s, h]
and returns output of the same size.
"""
def __init__(self, init_method,
output_layer_init_method, layer_number,
attention_type=AttnType.self_attn,
attn_mask_type=AttnMaskType.padding):
super(ParallelAttention, self).__init__()
args = get_args()
self.fp16 = args.fp16
self.bf16 = args.bf16
self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
if self.apply_query_key_layer_scaling:
self.attention_softmax_in_fp32 = True
self.layer_number = max(1, layer_number)
self.attention_type = attention_type
self.attn_mask_type = attn_mask_type
self.params_dtype = args.params_dtype
projection_size = args.kv_channels * args.num_attention_heads
# Per attention head and per partition values.
world_size = parallel_state.get_tensor_model_parallel_world_size()
self.hidden_size_per_partition = core_utils.divide(projection_size,
world_size)
self.hidden_size_per_attention_head = core_utils.divide(
projection_size, args.num_attention_heads)
self.num_attention_heads_per_partition = core_utils.divide(
args.num_attention_heads, world_size)
# Strided linear layer.
if attention_type == AttnType.self_attn:
self.query_key_value = tensor_parallel.ColumnParallelLinear(
args.hidden_size,
3 * projection_size,
gather_output=False,
init_method=init_method)
else:
assert attention_type == AttnType.cross_attn
self.query = tensor_parallel.ColumnParallelLinear(
args.hidden_size,
projection_size,
gather_output=False,
init_method=init_method)
self.key_value = tensor_parallel.ColumnParallelLinear(
args.hidden_size,
2 * projection_size,
gather_output=False,
init_method=init_method)
coeff = None
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
if self.apply_query_key_layer_scaling:
coeff = self.layer_number
self.norm_factor *= coeff
self.scale_mask_softmax = FusedScaleMaskSoftmax(
self.fp16, self.bf16,
self.attn_mask_type,
args.masked_softmax_fusion,
attention_mask_func,
self.attention_softmax_in_fp32,
coeff)
# Dropout. Note that for a single iteration, this layer will generate
# different outputs on different number of parallel partitions but
# on average it should not be partition dependent.
self.attention_dropout = torch.nn.Dropout(args.attention_dropout)
# Output.
self.dense = tensor_parallel.RowParallelLinear(
projection_size,
args.hidden_size,
input_is_parallel=True,
init_method=output_layer_init_method,
skip_bias_add=True)
def _allocate_memory(self, inference_max_sequence_len, batch_size):
return torch.empty(
inference_max_sequence_len,
batch_size,
self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head,
dtype=self.params_dtype,
device=torch.cuda.current_device())
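# The buffer above is the per-layer KV cache used during inference:
# [max_sequence_len, batch, heads_per_partition, head_dim] in params_dtype on
# the current device; forward() copies each step's keys/values into slices of it.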
def forward(self, hidden_states, attention_mask,
encoder_output=None, inference_params=None):
# hidden_states: [sq, b, h]
# =================================================
# Pre-allocate memory for key-values for inference.
# =================================================
if inference_params:
if self.layer_number not in inference_params.key_value_memory_dict:
inf_max_seq_len = inference_params.max_sequence_len
inf_max_batch_size = inference_params.max_batch_size
inference_key_memory = self._allocate_memory(
inf_max_seq_len, inf_max_batch_size)
inference_value_memory = self._allocate_memory(
inf_max_seq_len, inf_max_batch_size)
inference_params.key_value_memory_dict[self.layer_number] = (
inference_key_memory, inference_value_memory)
else:
inference_key_memory, inference_value_memory = \
inference_params.key_value_memory_dict[self.layer_number]
# =====================
# Query, Key, and Value
# =====================
if self.attention_type == AttnType.self_attn:
# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
mixed_x_layer, _ = self.query_key_value(hidden_states)
# [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
new_tensor_shape = mixed_x_layer.size()[:-1] + \
(self.num_attention_heads_per_partition,
3 * self.hidden_size_per_attention_head)
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
(query_layer,
key_layer,
value_layer) = tensor_parallel \
.split_tensor_along_last_dim(mixed_x_layer, 3)
else:
# Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
mixed_kv_layer, _ = self.key_value(encoder_output)
# [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
new_tensor_shape = mixed_kv_layer.size()[:-1] + \
(self.num_attention_heads_per_partition,
2 * self.hidden_size_per_attention_head)
mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)
# [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
(key_layer,
value_layer) = tensor_parallel \
.split_tensor_along_last_dim(mixed_kv_layer, 2)
# Attention head [sq, b, h] --> [sq, b, hp]
query_layer, _ = self.query(hidden_states)
# [sq, b, hp] --> [sq, b, np, hn]
new_tensor_shape = query_layer.size()[:-1] + \
(self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head)
query_layer = query_layer.view(*new_tensor_shape)
# ==================================
# Adjust key and value for inference
# ==================================
if inference_params:
batch_start = inference_params.batch_size_offset
batch_end = batch_start + key_layer.size(1)
assert batch_end <= inference_key_memory.size(1)
sequence_start = inference_params.sequence_len_offset
sequence_end = sequence_start + key_layer.size(0)
assert sequence_end <= inference_key_memory.size(0)
# Copy key and values.
inference_key_memory[sequence_start:sequence_end,
batch_start:batch_end, ...] = key_layer
inference_value_memory[sequence_start:sequence_end,
batch_start:batch_end, ...] = value_layer
key_layer = inference_key_memory[
:sequence_end, batch_start:batch_end, ...]
value_layer = inference_value_memory[
:sequence_end, batch_start:batch_end, ...]
# ===================================
# Raw attention scores. [b, np, s, s]
# ===================================
# [b, np, sq, sk]
output_size = (query_layer.size(1),
query_layer.size(2),
query_layer.size(0),
key_layer.size(0))
# [sq, b, np, hn] -> [sq, b * np, hn]
query_layer = query_layer.view(output_size[2],
output_size[0] * output_size[1], -1)
# [sk, b, np, hn] -> [sk, b * np, hn]
key_layer = key_layer.view(output_size[3],
output_size[0] * output_size[1], -1)
# preallocating result tensor: [b * np, sq, sk]
matmul_result = torch.empty(
output_size[0]*output_size[1],
output_size[2],
output_size[3],
dtype=query_layer.dtype,
device=torch.cuda.current_device())
# Raw attention scores. [b * np, sq, sk]
matmul_result = torch.baddbmm(
matmul_result,
query_layer.transpose(0, 1), # [b * np, sq, hn]
key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
beta=0.0, alpha=(1.0/self.norm_factor))
# change view to [b, np, sq, sk]
attention_scores = matmul_result.view(*output_size)
# ===========================
# Attention probs and dropout
# ===========================
# attention scores and attention mask [b, np, sq, sk]
attention_probs = self.scale_mask_softmax(attention_scores,
attention_mask)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
with tensor_parallel.get_cuda_rng_tracker().fork():
attention_probs = self.attention_dropout(attention_probs)
# =========================
# Context layer. [sq, b, hp]
# =========================
# value_layer -> context layer.
# [sk, b, np, hn] --> [b, np, sq, hn]
# context layer shape: [b, np, sq, hn]
output_size = (value_layer.size(1),
value_layer.size(2),
query_layer.size(0),
value_layer.size(3))
# change view [sk, b * np, hn]
value_layer = value_layer.view(value_layer.size(0),
output_size[0] * output_size[1], -1)
# change view [b * np, sq, sk]
attention_probs = attention_probs.view(output_size[0] * output_size[1],
output_size[2], -1)
# matmul: [b * np, sq, hn]
context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
# change view [b, np, sq, hn]
context_layer = context_layer.view(*output_size)
# [b, np, sq, hn] --> [sq, b, np, hn]
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
# [sq, b, np, hn] --> [sq, b, hp]
new_context_layer_shape = context_layer.size()[:-2] + \
(self.hidden_size_per_partition,)
context_layer = context_layer.view(*new_context_layer_shape)
# =================
# Output. [sq, b, h]
# =================
output, bias = self.dense(context_layer)
return output, bias
def bias_dropout_add(x, bias, residual, prob, training):
# type: (Tensor, Tensor, Tensor, float, bool) -> Tensor
out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
out = residual + out
return out
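# Note: dropout is applied to (x + bias) only; the residual is added afterwards,
# so the residual path itself is never dropped.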
def get_bias_dropout_add(training):
def _bias_dropout_add(x, bias, residual, prob):
return bias_dropout_add(x, bias, residual, prob, training)
return _bias_dropout_add
@torch.jit.script
def bias_dropout_add_fused_train(x: torch.Tensor,
bias: torch.Tensor,
residual: torch.Tensor,
prob: float) -> torch.Tensor:
return bias_dropout_add(x, bias, residual, prob, True)
@torch.jit.script
def bias_dropout_add_fused_inference(x: torch.Tensor,
bias: torch.Tensor,
residual: torch.Tensor,
prob: float) -> torch.Tensor:
return bias_dropout_add(x, bias, residual, prob, False)
class ParallelRetroTransformerEncoderLayer(MegatronModule):
"""A single transformer layer for Retro Decoder with an retriever encoder inside and cross attention.
Transformer layer takes input with size [b, s, h] and returns an
output of the same size.
"""
def __init__(self, init_method, output_layer_init_method,
layer_number, layer_type=LayerType.encoder,
self_attn_mask_type=AttnMaskType.padding,
drop_path_rate=0., retriever=None):
args = get_args()
super(ParallelRetroTransformerEncoderLayer, self).__init__()
self.layer_number = layer_number
self.layer_type = layer_type
self.apply_residual_connection_post_layernorm \
= args.apply_residual_connection_post_layernorm
self.bf16 = args.bf16
self.fp32_residual_connection = args.fp32_residual_connection
# Retro Encoder
self.retriever = retriever
# Layernorm on the input data.
self.input_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm)
# Self attention.
self.self_attention = ParallelAttention(
init_method,
output_layer_init_method,
layer_number,
attention_type=AttnType.self_attn,
attn_mask_type=self_attn_mask_type)
self.hidden_dropout = args.hidden_dropout
self.bias_dropout_fusion = args.bias_dropout_fusion
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 \
else None
# Layernorm on the attention output
self.post_attention_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm)
self.inter_attention = ParallelAttention(
init_method,
output_layer_init_method,
layer_number,
attention_type=AttnType.cross_attn)
# Layernorm on the attention output.
self.post_inter_attention_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm)
# MLP
if args.num_experts is not None:
self.mlp = SwitchMLP(init_method, output_layer_init_method)
else:
self.mlp = ParallelMLP(init_method, output_layer_init_method)
def forward(self, hidden_states, attention_mask,
retriever_output, retriever_attn_mask,
encoder_output=None, enc_dec_attn_mask=None,
inference_params=None):
# hidden_states: [s, b, h]
# Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states)
# Self attention.
attention_output, attention_bias = \
self.self_attention(
layernorm_output,
attention_mask,
inference_params=inference_params)
# Residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = hidden_states
if self.drop_path is None:
# jit scripting for a nn.module (with dropout) is not
# triggering the fusion kernel. For now, we use two
# different nn.functional routines to account for varying
# dropout semantics during training and inference phases.
if self.bias_dropout_fusion:
if self.training:
bias_dropout_add_func = bias_dropout_add_fused_train
else:
bias_dropout_add_func = bias_dropout_add_fused_inference
else:
bias_dropout_add_func = get_bias_dropout_add(self.training)
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
layernorm_input = bias_dropout_add_func(
attention_output,
attention_bias.expand_as(residual),
residual,
self.hidden_dropout)
else:
out = torch.nn.functional.dropout(attention_output + attention_bias,
p=self.hidden_dropout,
training=self.training)
layernorm_input = residual + self.drop_path(out)
# Layer norm post the self attention. # [ns, bs, d]
layernorm_output = self.post_attention_layernorm(layernorm_input)
"""
notations:
l: number of chunks
m: number of token per chunk
bs: batch size
d: hidden size
k: number of neighbors
r: number of tokens per neighbors (neighbors + continuation)
"""
args = get_args()
retro_args = get_retro_args()
chunk_length = retro_args.retro_gpt_chunk_length
retrieved_length = retro_args.retro_gpt_retrieved_length
num_neighbors = args.retro_num_neighbors
ns, bs, d = layernorm_output.shape
l = int(np.ceil(ns / chunk_length))
first_ns = ns % chunk_length
if first_ns > 0:
first_chunk, rest_chunk = \
layernorm_output[:first_ns], layernorm_output[first_ns:]
first_chunk = torch.nn.functional.pad(
first_chunk,
(0, 0, 0, 0, 0, chunk_length - first_ns),
'constant',
0)
chunked_output = \
torch.cat((first_chunk, rest_chunk), dim=0) # [l * m, bs, d]
else:
chunked_output = layernorm_output # [l * m, bs, d]
chunked_output = chunked_output \
.reshape(l, chunk_length, bs, d) \
.permute(1, 2, 0, 3) \
.reshape(chunk_length, bs * l, d) \
.contiguous()
# Get Encoder Output
retriever_output = self.retriever(
retriever_output,
retriever_attn_mask,
retriever_output=chunked_output,
retriever_attn_mask=retriever_attn_mask,
inference_params=inference_params) # [r, k * bs * l , d]
retriever_output = retriever_output.reshape(
retrieved_length * num_neighbors, bs * l, d) # [r * k, bs * l, d]
# Chunked Cross attention with Retriever Encoder
pad = (ns - 1) % chunk_length
attending_chunks = layernorm_output[pad:] # [ns - m + 1, bs, d]
padded_chunks = torch.nn.functional.pad(
attending_chunks,
(0, 0, 0, 0, 0, chunk_length-1),
'constant', 0) # [ns, bs, d]
padded_chunked_output = padded_chunks \
.reshape(l, chunk_length, bs, d) \
.permute(1, 2, 0, 3)
padded_chunked_output = padded_chunked_output.reshape(
chunk_length, bs * l, d).contiguous() # [m, bs * l, d]
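# Continuing the example (ns = 2048, m = 64, l = 32): pad = (2048 - 1) % 64 = 63,
# attending_chunks keeps 2048 - 63 = 1985 positions, and padding with m - 1 = 63
# zeros restores length 2048 = l * m before the reshape to [m, bs * l, d].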
# attention_output: [m, bs * l, d]
# attention_bias: [d]
attention_output, attention_bias = \
self.inter_attention(
padded_chunked_output, # Q: main model embedding
None,
encoder_output=retriever_output) # KV: retriever output embedding
# Residual connection
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = layernorm_input
# Re-enable torch grad to enable fused optimization.
with torch.enable_grad():
layernorm_input = bias_dropout_add_func(
attention_output,
attention_bias.expand_as(attention_output),
torch.zeros_like(attention_output),
self.hidden_dropout)
layernorm_input = layernorm_input \
.reshape(chunk_length, bs, l, d) \
.permute(2, 0, 1, 3) # [l, m, bs, d]
layernorm_input = layernorm_input.reshape(chunk_length * l, bs, d)
layernorm_input = torch.nn.functional.pad(
layernorm_input,
(0, 0, 0, 0, pad, 0),
'constant', 0)[:ns] # [ns, b, d]
layernorm_input = layernorm_input + residual
# Layer norm post the decoder attention
layernorm_output = self.post_inter_attention_layernorm(layernorm_input)
# MLP.
mlp_output, mlp_bias = self.mlp(layernorm_output)
# Second residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = layernorm_input
if self.drop_path is None:
# Re-enable torch grad to enable fused optimization.
with torch.enable_grad():
output = bias_dropout_add_func(
mlp_output,
mlp_bias.expand_as(residual),
residual,
self.hidden_dropout)
else:
out = torch.nn.functional.dropout(mlp_output + mlp_bias,
p=self.hidden_dropout,
training=self.training)
output = residual + self.drop_path(out)
return output, retriever_output
class ParallelRetroTransformerLayer(MegatronModule):
"""A single transformer layer for Retro Decoder with cross attention.
Transformer layer takes input with size [b, s, h] and returns an
output of the same size.
"""
def __init__(self, init_method, output_layer_init_method,
layer_number, layer_type=LayerType.encoder,
self_attn_mask_type=AttnMaskType.padding,
drop_path_rate=0.):
args = get_args()
super(ParallelRetroTransformerLayer, self).__init__()
self.layer_number = layer_number
self.layer_type = layer_type
self.apply_residual_connection_post_layernorm \
= args.apply_residual_connection_post_layernorm
self.bf16 = args.bf16
self.fp32_residual_connection = args.fp32_residual_connection
# Layernorm on the input data.
self.input_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm)
# Self attention.
self.self_attention = ParallelAttention(
init_method,
output_layer_init_method,
layer_number,
attention_type=AttnType.self_attn,
attn_mask_type=self_attn_mask_type)
self.hidden_dropout = args.hidden_dropout
self.bias_dropout_fusion = args.bias_dropout_fusion
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 \
else None
# Layernorm on the attention output
self.post_attention_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm)
self.inter_attention = ParallelAttention(
init_method,
output_layer_init_method,
layer_number,
attention_type=AttnType.cross_attn)
# Layernorm on the attention output.
self.post_inter_attention_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm)
# MLP
if args.num_experts is not None:
self.mlp = SwitchMLP(init_method, output_layer_init_method)
else:
self.mlp = ParallelMLP(init_method, output_layer_init_method)
def forward(self, hidden_states, attention_mask,
retriever_output, retriever_attn_mask,
encoder_output=None, enc_dec_attn_mask=None,
inference_params=None):
# hidden_states: [b, s, h]
# Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states)
# Self attention.
attention_output, attention_bias = \
self.self_attention(
layernorm_output,
attention_mask,
inference_params=inference_params)
# Residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = hidden_states
if self.drop_path is None:
# jit scripting for a nn.module (with dropout) is not
# triggering the fusion kernel. For now, we use two
# different nn.functional routines to account for varying
# dropout semantics during training and inference phases.
if self.bias_dropout_fusion:
if self.training:
bias_dropout_add_func = bias_dropout_add_fused_train
else:
bias_dropout_add_func = bias_dropout_add_fused_inference
else:
bias_dropout_add_func = get_bias_dropout_add(self.training)
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
layernorm_input = bias_dropout_add_func(
attention_output,
attention_bias.expand_as(residual),
residual,
self.hidden_dropout)
else:
out = torch.nn.functional.dropout(attention_output + attention_bias,
p=self.hidden_dropout,
training=self.training)
layernorm_input = residual + self.drop_path(out)
# Layer norm post the self attention.
layernorm_output = self.post_attention_layernorm(layernorm_input)
args = get_args()
retro_args = get_retro_args()
chunk_length = retro_args.retro_gpt_chunk_length
ns, bs, d = layernorm_output.shape
l = int(np.ceil(ns / chunk_length))
pad = (ns - 1) % chunk_length
attending_chunks = layernorm_output[pad:]
padded_chunks = torch.nn.functional.pad(
attending_chunks,
(0, 0, 0, 0, 0, chunk_length - 1),
'constant', 0)
padded_chunked_output = padded_chunks \
.reshape(l, chunk_length, bs, d) \
.permute(1, 2, 0, 3)
padded_chunked_output = padded_chunked_output.reshape(
chunk_length, bs * l, d).contiguous()
# Encoder output.
attention_output, attention_bias = \
self.inter_attention(padded_chunked_output,
None,
encoder_output=retriever_output)
# Residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = layernorm_input
# Re-enable torch grad to enable fused optimization.
with torch.enable_grad():
layernorm_input = bias_dropout_add_func(
attention_output,
attention_bias.expand_as(attention_output),
torch.zeros_like(attention_output),
self.hidden_dropout)
layernorm_input = layernorm_input \
.reshape(chunk_length, bs, l, d) \
.permute(2, 0, 1, 3) # [l, m, bs, d]
layernorm_input = layernorm_input.reshape(chunk_length * l, bs, d)
layernorm_input = torch.nn.functional.pad(
layernorm_input,
(0, 0, 0, 0, pad, 0),
'constant', 0)[:ns]
layernorm_input = layernorm_input + residual
# Layer norm post the decoder attention
layernorm_output = self.post_inter_attention_layernorm(layernorm_input)
# MLP.
mlp_output, mlp_bias = self.mlp(layernorm_output)
# Second residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = layernorm_input
if self.drop_path is None:
# Re-enable torch grad to enable fused optimization.
with torch.enable_grad():
output = bias_dropout_add_func(
mlp_output,
mlp_bias.expand_as(residual),
residual,
self.hidden_dropout)
else:
out = torch.nn.functional.dropout(mlp_output + mlp_bias,
p=self.hidden_dropout,
training=self.training)
output = residual + self.drop_path(out)
return output
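# Illustrative numeric check (not part of the original file) of the chunking
# arithmetic used above, assuming a hypothetical chunk length m = 64:
import math

def _chunk_stats_sketch(ns, chunk_length):
    l = math.ceil(ns / chunk_length)       # number of chunks
    first_ns = ns % chunk_length           # tokens in the leading partial chunk
    pad = (ns - 1) % chunk_length          # left offset used for the attending chunks
    return l, first_ns, pad

assert _chunk_stats_sketch(2048, 64) == (32, 0, 63)
assert _chunk_stats_sketch(100, 64) == (2, 36, 35)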
class ParallelRetroEncoderTransformerCALayer(MegatronModule):
"""A single transformer layer for Retro Encoder with cross attention.
Transformer layer takes input with size [b, s, h] and returns an
output of the same size.
"""
def __init__(self, init_method, output_layer_init_method,
layer_number, layer_type=LayerType.encoder,
self_attn_mask_type=AttnMaskType.padding,
drop_path_rate=0.):
args = get_args()
super(ParallelRetroEncoderTransformerCALayer, self).__init__()
self.layer_number = layer_number
self.layer_type = layer_type
self.apply_residual_connection_post_layernorm \
= args.apply_residual_connection_post_layernorm
self.bf16 = args.bf16
self.fp32_residual_connection = args.fp32_residual_connection
# Layernorm on the input data.
self.input_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm)
# Self attention.
self.self_attention = ParallelAttention(
init_method,
output_layer_init_method,
layer_number,
attention_type=AttnType.self_attn,
attn_mask_type=self_attn_mask_type)
self.self_attention.attention_dropout = \
torch.nn.Dropout(args.retro_encoder_attention_dropout)
self.hidden_dropout = args.retro_encoder_hidden_dropout
self.bias_dropout_fusion = args.bias_dropout_fusion
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 \
else None
# Layernorm on the attention output
self.post_attention_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm)
self.inter_attention = ParallelAttention(
init_method,
output_layer_init_method,
layer_number,
attention_type=AttnType.cross_attn)
# Layernorm on the attention output.
self.post_inter_attention_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm)
# MLP
if args.num_experts is not None:
self.mlp = SwitchMLP(init_method, output_layer_init_method)
else:
self.mlp = ParallelMLP(init_method, output_layer_init_method)
def forward(self, hidden_states, attention_mask,
retriever_output, retriever_attn_mask,
encoder_output=None, enc_dec_attn_mask=None,
inference_params=None):
# hidden_states: [s, b, h]
# Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states)
# Self attention.
attention_output, attention_bias = \
self.self_attention(
layernorm_output,
attention_mask,
inference_params=inference_params)
# Residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = hidden_states
if self.drop_path is None:
# jit scripting for a nn.module (with dropout) is not
# triggering the fusion kernel. For now, we use two
# different nn.functional routines to account for varying
# dropout semantics during training and inference phases.
if self.bias_dropout_fusion:
if self.training:
bias_dropout_add_func = bias_dropout_add_fused_train
else:
bias_dropout_add_func = bias_dropout_add_fused_inference
else:
bias_dropout_add_func = get_bias_dropout_add(self.training)
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
layernorm_input = bias_dropout_add_func(
attention_output,
attention_bias.expand_as(residual),
residual,
self.hidden_dropout)
else:
out = torch.nn.functional.dropout(attention_output + attention_bias,
p=self.hidden_dropout,
training=self.training)
layernorm_input = residual + self.drop_path(out)
# Layer norm post the self attention.
layernorm_output = self.post_attention_layernorm(layernorm_input)
# Neighbors.
args = get_args()
retro_args = get_retro_args()
retrieved_length = retro_args.retro_gpt_retrieved_length
num_neighbors = args.retro_num_neighbors
ns, bs, d = layernorm_output.shape # [r, bs * l * k, d]
chunked_outputs = layernorm_output.reshape(retrieved_length, -1,
num_neighbors, d)
chunked_outputs_before_layer_norm = \
layernorm_input.reshape(retrieved_length, -1,
num_neighbors, d) # [r, bs * l, k, d]
layernorm_inputs = []
layernorm_outputs = []
for k in range(num_neighbors):
chunked_output = chunked_outputs[:,:,k].contiguous()
attention_output, attention_bias = \
self.inter_attention(
chunked_output, # Q (neighbor embedding)
None,
encoder_output=retriever_output) # K, V (hidden act)
# Residual connection.
if self.apply_residual_connection_post_layernorm:
residual = chunked_output
else:
residual = chunked_outputs_before_layer_norm[:,:,k]
# Re-enable torch grad to enable fused optimization.
with torch.enable_grad():
layernorm_input = bias_dropout_add_func(
attention_output,
attention_bias.expand_as(residual),
residual,
self.hidden_dropout)
layernorm_inputs.append(layernorm_input)
# Layer norm post the decoder attention
layernorm_output = \
self.post_inter_attention_layernorm(layernorm_input)
layernorm_outputs.append(layernorm_output)
# layernorm_input : [r, k * bs * l, d]
# layernorm_output : [r, k * bs * l, d]
layernorm_input = \
torch.stack(layernorm_inputs, dim=1).reshape(ns, bs, d)
layernorm_output = \
torch.stack(layernorm_outputs, dim=1).reshape(ns, bs, d)
# MLP.
mlp_output, mlp_bias = self.mlp(layernorm_output)
# Second residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = layernorm_input
if self.drop_path is None:
# Re-enable torch grad to enable fused optimization.
with torch.enable_grad():
output = bias_dropout_add_func(
mlp_output,
mlp_bias.expand_as(residual),
residual,
self.hidden_dropout)
else:
out = torch.nn.functional.dropout(mlp_output + mlp_bias,
p=self.hidden_dropout,
training=self.training)
output = residual + self.drop_path(out)
return output
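# Illustrative shape sketch (not part of the original file) of the per-neighbor
# bookkeeping above, using hypothetical sizes r=4, bs*l=2, k=3, d=8: the
# [r, bs*l*k, d] activations are viewed as [r, bs*l, k, d], attended per
# neighbor, and the per-neighbor results are stacked back to [r, -1, d].
import torch

_r, _bsl, _k, _d = 4, 2, 3, 8
_x = torch.randn(_r, _bsl * _k, _d)
_chunked = _x.reshape(_r, _bsl, _k, _d)
_per_neighbor = [_chunked[:, :, n] for n in range(_k)]   # each [r, bs*l, d]
_restacked = torch.stack(_per_neighbor, dim=1)           # [r, k, bs*l, d]
assert _restacked.reshape(_r, -1, _d).shape == _x.shape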
class ParallelTransformerLayer(MegatronModule):
"""A single transformer layer.
Transformer layer takes input with size [b, s, h] and returns an
output of the same size.
"""
def __init__(self, init_method, output_layer_init_method,
layer_number, layer_type=LayerType.encoder,
self_attn_mask_type=AttnMaskType.padding,
drop_path_rate=0.):
args = get_args()
super(ParallelTransformerLayer, self).__init__()
self.layer_number = layer_number
self.layer_type = layer_type
self.apply_residual_connection_post_layernorm \
= args.apply_residual_connection_post_layernorm
self.bf16 = args.bf16
self.fp32_residual_connection = args.fp32_residual_connection
# Layernorm on the input data.
self.input_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm)
# Self attention.
self.self_attention = ParallelAttention(
init_method,
output_layer_init_method,
layer_number,
attention_type=AttnType.self_attn,
attn_mask_type=self_attn_mask_type)
self.hidden_dropout = args.hidden_dropout
self.bias_dropout_fusion = args.bias_dropout_fusion
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 \
else None
# Layernorm on the attention output
self.post_attention_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm)
if self.layer_type == LayerType.decoder:
self.inter_attention = ParallelAttention(
init_method,
output_layer_init_method,
layer_number,
attention_type=AttnType.cross_attn)
# Layernorm on the attention output.
self.post_inter_attention_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm)
# MLP
if args.num_experts is not None:
self.mlp = SwitchMLP(init_method, output_layer_init_method)
else:
self.mlp = ParallelMLP(init_method, output_layer_init_method)
def forward(self, hidden_states, attention_mask,
encoder_output=None, enc_dec_attn_mask=None,
inference_params=None):
# hidden_states: [b, s, h]
# Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states)
# Self attention.
attention_output, attention_bias = \
self.self_attention(
layernorm_output,
attention_mask,
inference_params=inference_params)
# Residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = hidden_states
if self.drop_path is None:
# jit scripting for a nn.module (with dropout) is not
# triggering the fusion kernel. For now, we use two
# different nn.functional routines to account for varying
# dropout semantics during training and inference phases.
if self.bias_dropout_fusion:
if self.training:
bias_dropout_add_func = bias_dropout_add_fused_train
else:
bias_dropout_add_func = bias_dropout_add_fused_inference
else:
bias_dropout_add_func = get_bias_dropout_add(self.training)
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
layernorm_input = bias_dropout_add_func(
attention_output,
attention_bias.expand_as(residual),
residual,
self.hidden_dropout)
else:
out = torch.nn.functional.dropout(attention_output + attention_bias,
p=self.hidden_dropout,
training=self.training)
layernorm_input = residual + self.drop_path(out)
# Layer norm post the self attention.
layernorm_output = self.post_attention_layernorm(layernorm_input)
if self.layer_type == LayerType.decoder:
attention_output, attention_bias = \
self.inter_attention(layernorm_output,
enc_dec_attn_mask,
encoder_output=encoder_output)
# residual connection
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = layernorm_input
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
layernorm_input = bias_dropout_add_func(
attention_output,
attention_bias.expand_as(residual),
residual,
self.hidden_dropout)
# Layer norm post the decoder attention
layernorm_output = self.post_inter_attention_layernorm(layernorm_input)
# MLP.
mlp_output, mlp_bias = self.mlp(layernorm_output)
# Second residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = layernorm_input
if self.drop_path is None:
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
output = bias_dropout_add_func(
mlp_output,
mlp_bias.expand_as(residual),
residual,
self.hidden_dropout)
else:
out = torch.nn.functional.dropout(mlp_output + mlp_bias,
p=self.hidden_dropout,
training=self.training)
output = residual + self.drop_path(out)
return output
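# Illustrative sketch (not part of the original file): the fused and unfused
# bias-dropout-add paths selected above compute the same thing as this plain
# reference, shown only to make the residual dataflow explicit:
import torch

def _bias_dropout_add_reference(x, bias, residual, prob, training):
    out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
    return residual + out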
class ParallelRetroEncoder(MegatronModule):
""" Retro Transformer class for encoder ."""
def __init__(self, init_method, output_layer_init_method,
layer_type=LayerType.encoder,
self_attn_mask_type=AttnMaskType.padding,
pre_process=True, post_process=True,
drop_path_rate=0.0):
super(ParallelRetroEncoder, self).__init__()
args = get_args()
self.bf16 = args.bf16
self.fp32_residual_connection = args.fp32_residual_connection
self.pre_process = pre_process
self.post_process = post_process
self.input_tensor = None
self.drop_path_rate = drop_path_rate
# Store activation checkpointing flags.
self.recompute_granularity = args.recompute_granularity
self.recompute_method = args.recompute_method
self.recompute_num_layers = args.recompute_num_layers
self.distribute_saved_activations = \
args.distribute_saved_activations and not args.sequence_parallel
self.sequence_parallel = args.sequence_parallel
# Number of layers.
self.num_layers = args.retro_encoder_layers
self.drop_path_rates = [rate.item() for rate in torch.linspace(0, self.drop_path_rate, args.num_layers)]
if args.retro_add_retriever:
self.P = [1]
# Transformer layers.
assert args.retro_add_retriever
def build_layer(layer_number):
if layer_number in self.P:
return ParallelRetroEncoderTransformerCALayer(
init_method,
output_layer_init_method,
layer_number,
layer_type=layer_type,
self_attn_mask_type=self_attn_mask_type,
drop_path_rate=self.drop_path_rates[layer_number - 1])
else:
layer = ParallelTransformerLayer(
init_method,
output_layer_init_method,
layer_number,
layer_type=layer_type,
self_attn_mask_type=self_attn_mask_type,
drop_path_rate=self.drop_path_rates[layer_number - 1])
layer.self_attention.attention_dropout = \
torch.nn.Dropout(args.retro_encoder_attention_dropout)
layer.hidden_dropout = args.retro_encoder_hidden_dropout
return layer
if args.virtual_pipeline_model_parallel_size is not None:
assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
'num_layers_per_stage must be divisible by ' \
'virtual_pipeline_model_parallel_size'
assert args.model_type != ModelType.encoder_and_decoder
# Number of layers in each model chunk is the number of layers in
# the stage, divided by the number of model chunks in a stage.
self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size
# With 8 layers, 2 stages, and 4 model chunks, we want an
# assignment of layers to stages like (each list is a model chunk):
# Stage 0: [0] [2] [4] [6]
# Stage 1: [1] [3] [5] [7]
# With 8 layers, 2 stages, and 2 virtual stages, we want an
# assignment of layers to stages like (each list is a model chunk):
# Stage 0: [0, 1] [4, 5]
# Stage 1: [2, 3] [6, 7]
offset = parallel_state.get_virtual_pipeline_model_parallel_rank() * (
args.num_layers // args.virtual_pipeline_model_parallel_size) + \
(parallel_state.get_pipeline_model_parallel_rank() * self.num_layers)
else:
# Each stage gets a contiguous set of layers.
if args.model_type == ModelType.encoder_and_decoder and \
parallel_state.get_pipeline_model_parallel_world_size() > 1:
pipeline_rank = parallel_state.get_pipeline_model_parallel_rank()
if layer_type == LayerType.encoder:
offset = pipeline_rank * self.num_layers
else:
num_ranks_in_enc = args.pipeline_model_parallel_split_rank
offset = (pipeline_rank - num_ranks_in_enc) * self.num_layers
else:
offset = parallel_state.get_pipeline_model_parallel_rank() * self.num_layers
if self.num_layers == 0:
# When a standalone embedding stage is used (e.g.,
# args.standalone_embedding_stage == True), virtual pipeline ranks
# on pipeline rank 0 will have zero transformer layers assigned to
# them. This results in the model's input and output tensors being
# the same, which will cause failure for certain output tensor
# optimizations (e.g., pipeline output deallocation). To remedy
# this, we assign a 'no-op' layer on these ranks, which will
# disconnect the input tensor from the output tensor.
self.num_layers = 1
self.layers = torch.nn.ModuleList([ NoopTransformerLayer(1) ])
else:
self.layers = torch.nn.ModuleList(
[build_layer(i + 1 + offset) for i in range(self.num_layers)])
if self.post_process:
# Final layer norm before output.
self.final_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm)
def _get_layer(self, layer_number):
return self.layers[layer_number]
def _checkpointed_forward(self, hidden_states, attention_mask,
encoder_output, enc_dec_attn_mask):
"""Forward method with activation checkpointing."""
def custom(start, end):
def custom_forward(*inputs):
x_ = inputs[0]
attention_mask = inputs[1]
encoder_output = inputs[2]
enc_dec_attn_mask = inputs[3]
for index in range(start, end):
layer = self._get_layer(index)
x_ = layer(x_, attention_mask, encoder_output, enc_dec_attn_mask)
return x_
return custom_forward
if self.recompute_method == 'uniform':
# Uniformly divide the total number of Transformer layers and
# checkpoint the input activation of each divided chunk.
# A method to further reduce memory usage by reducing the number of checkpoints.
l = 0
while l < self.num_layers:
hidden_states = tensor_parallel.checkpoint(
custom(l, l + self.recompute_num_layers),
self.distribute_saved_activations,
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
l += self.recompute_num_layers
elif self.recompute_method == 'block':
# Checkpoint the input activation of only a set number of individual
# Transformer layers and skip the rest.
# A method to fully use device memory by removing redundant re-computation.
for l in range(self.num_layers):
if l < self.recompute_num_layers:
hidden_states = tensor_parallel.checkpoint(
custom(l, l + 1),
self.distribute_saved_activations,
hidden_states, attention_mask,
encoder_output, enc_dec_attn_mask)
else:
hidden_states = custom(l, l + 1)(
hidden_states, attention_mask,
encoder_output, enc_dec_attn_mask)
else:
raise ValueError("Invalid activation recompute method.")
return hidden_states
def set_input_tensor(self, input_tensor):
"""Set input tensor to be used instead of forward()'s input.
When doing pipeline parallelism the input from the previous
stage comes from communication, not from the input, so the
model's forward_step_func won't have it. This function is thus
used by internal code to bypass the input provided by the
forward_step_func"""
self.input_tensor = input_tensor
def forward(self, hidden_states, attention_mask,
retriever_output, retriever_attn_mask,
encoder_output=None, enc_dec_attn_mask=None,
inference_params=None):
# Checks.
if inference_params:
assert self.recompute_granularity is None, \
'inference does not work with activation checkpointing'
if self.pre_process:
# Data format change to avoid explicit transposes: [b s h] --> [s b h].
# If the input flag for fp32 residual connection is set, convert to float.
if self.fp32_residual_connection:
hidden_states = hidden_states.transpose(0, 1).contiguous().float()
# Otherwise, leave it as is.
else:
hidden_states = hidden_states.transpose(0, 1).contiguous()
else:
# See set_input_tensor()
hidden_states = self.input_tensor
# Viewless tensor.
# - We only need to create a viewless tensor in the case of micro batch
# size (mbs) == 1, since in this case, 'hidden_states.transpose()'
# above creates a view tensor, and '.contiguous()' is a pass-through.
# For mbs >= 2, '.contiguous()' creates a new tensor, eliminating
# the need to make it viewless.
#
# However, we don't explicitly check mbs == 1 here because
# make_viewless_tensor() has negligible overhead when its input
# is already viewless.
#
# - For the 'else' case above, calling make_viewless_tensor() here is
# likely redundant, since p2p_communication.py (likely originator)
# already creates viewless tensors. That said, make_viewless_tensor()
# is called here to be future-proof and corner-case-proof.
hidden_states = core_utils.make_viewless_tensor(
hidden_states,
requires_grad = True,
keep_graph = True,
)
# Transpose encoder output.
if encoder_output is not None:
encoder_output = encoder_output.transpose(0, 1).contiguous()
args = get_args()
assert not args.sequence_parallel, "sequence parallelism requires an rng-tracker context, which is not set up here."
# Forward pass.
if self.recompute_granularity == 'full':
hidden_states = self._checkpointed_forward(hidden_states,
attention_mask,
encoder_output,
enc_dec_attn_mask)
else:
for index in range(self.num_layers):
layer = self._get_layer(index)
if index + 1 in self.P:
hidden_states = layer(
hidden_states,
attention_mask,
retriever_output=retriever_output,
retriever_attn_mask=retriever_attn_mask,
encoder_output=encoder_output,
enc_dec_attn_mask=enc_dec_attn_mask,
inference_params=inference_params)
else:
hidden_states = layer(
hidden_states,
attention_mask,
encoder_output=encoder_output,
enc_dec_attn_mask=enc_dec_attn_mask,
inference_params=inference_params)
# Final layer norm.
if self.post_process:
# Reverting data format change [s b h] --> [b s h].
hidden_states = hidden_states.transpose(0, 1).contiguous()
output = self.final_layernorm(hidden_states)
else:
output = hidden_states
return output
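# Illustrative worked example (not part of the original file) of the virtual
# pipeline layer-offset computation described in the comments above, using the
# 8-layer / 2-stage settings from those comments:
def _vp_layer_chunk_sketch(num_layers, pp_size, vp_size, pp_rank, vp_rank):
    layers_per_chunk = (num_layers // pp_size) // vp_size
    offset = vp_rank * (num_layers // vp_size) + pp_rank * layers_per_chunk
    return list(range(offset, offset + layers_per_chunk))

# 2 stages, 4 model chunks: stage 0 owns [0] [2] [4] [6], stage 1 owns [1] [3] [5] [7].
assert [_vp_layer_chunk_sketch(8, 2, 4, 0, v) for v in range(4)] == [[0], [2], [4], [6]]
# 2 stages, 2 model chunks: stage 1 owns [2, 3] [6, 7].
assert [_vp_layer_chunk_sketch(8, 2, 2, 1, v) for v in range(2)] == [[2, 3], [6, 7]]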
class ParallelRetroTransformer(MegatronModule):
"""Standard GPT Transformer class."""
def __init__(self, init_method, output_layer_init_method,
layer_type=LayerType.encoder,
self_attn_mask_type=AttnMaskType.padding,
pre_process=True, post_process=True,
drop_path_rate=0.0, retriever=None):
super(ParallelRetroTransformer, self).__init__()
args = get_args()
assert pre_process and post_process, \
"pipeline parallelism is not supported."
self.bf16 = args.bf16
self.fp32_residual_connection = args.fp32_residual_connection
self.pre_process = pre_process
self.post_process = post_process
self.input_tensor = None
self.drop_path_rate = drop_path_rate
# Store activation checkpointing flags.
self.recompute_granularity = args.recompute_granularity
self.recompute_method = args.recompute_method
self.recompute_num_layers = args.recompute_num_layers
self.distribute_saved_activations = \
args.distribute_saved_activations and not args.sequence_parallel
self.sequence_parallel = args.sequence_parallel
# Number of layers.
self.num_layers = _get_num_layers(
args, args.model_type == ModelType.encoder_and_decoder)
self.drop_path_rates = [rate.item() for rate in torch.linspace(0, self.drop_path_rate, args.num_layers)]
if args.retro_add_retriever:
if args.num_layers == 12:
self.P = [6, 9, 12]
elif args.num_layers == 24:
self.P = np.arange(9, 25, 3).tolist()
elif args.num_layers == 40:
self.P = np.arange(9, 41, 3).tolist()
self.P.append(40)
self.retriever = retriever
# Transformer layers.
assert args.retro_add_retriever
def build_layer(layer_number):
if layer_number == min(self.P):
return ParallelRetroTransformerEncoderLayer(
init_method,
output_layer_init_method,
layer_number,
layer_type=layer_type,
self_attn_mask_type=self_attn_mask_type,
drop_path_rate=self.drop_path_rates[layer_number - 1],
retriever=retriever
)
elif layer_number in self.P:
return ParallelRetroTransformerLayer(
init_method,
output_layer_init_method,
layer_number,
layer_type=layer_type,
self_attn_mask_type=self_attn_mask_type,
drop_path_rate=self.drop_path_rates[layer_number - 1])
else:
return ParallelTransformerLayer(
init_method,
output_layer_init_method,
layer_number,
layer_type=layer_type,
self_attn_mask_type=self_attn_mask_type,
drop_path_rate=self.drop_path_rates[layer_number - 1])
if args.virtual_pipeline_model_parallel_size is not None:
assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
'num_layers_per_stage must be divisible by ' \
'virtual_pipeline_model_parallel_size'
assert args.model_type != ModelType.encoder_and_decoder
# Number of layers in each model chunk is the number of layers in the stage,
# divided by the number of model chunks in a stage.
self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size
# With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
# layers to stages like (each list is a model chunk):
# Stage 0: [0] [2] [4] [6]
# Stage 1: [1] [3] [5] [7]
# With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
# layers to stages like (each list is a model chunk):
# Stage 0: [0, 1] [4, 5]
# Stage 1: [2, 3] [6, 7]
offset = parallel_state.get_virtual_pipeline_model_parallel_rank() * (
args.num_layers // args.virtual_pipeline_model_parallel_size) + \
(parallel_state.get_pipeline_model_parallel_rank() * self.num_layers)
else:
# Each stage gets a contiguous set of layers.
if args.model_type == ModelType.encoder_and_decoder and \
parallel_state.get_pipeline_model_parallel_world_size() > 1:
pipeline_rank = parallel_state.get_pipeline_model_parallel_rank()
if layer_type == LayerType.encoder:
offset = pipeline_rank * self.num_layers
else:
num_ranks_in_enc = args.pipeline_model_parallel_split_rank
offset = (pipeline_rank - num_ranks_in_enc) * self.num_layers
else:
offset = parallel_state.get_pipeline_model_parallel_rank() * self.num_layers
if self.num_layers == 0:
# When a standalone embedding stage is used (e.g.,
# args.standalone_embedding_stage == True), virtual pipeline ranks
# on pipeline rank 0 will have zero transformer layers assigned to
# them. This results in the model's input and output tensors being
# the same, which will cause failure for certain output tensor
# optimizations (e.g., pipeline output deallocation). To remedy
# this, we assign a 'no-op' layer on these ranks, which will
# disconnect the input tensor from the output tensor.
self.num_layers = 1
self.layers = torch.nn.ModuleList([ NoopTransformerLayer(1) ])
else:
self.layers = torch.nn.ModuleList(
[build_layer(i + 1 + offset) for i in range(self.num_layers)])
if self.post_process:
# Final layer norm before output.
self.final_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm)
def _get_layer(self, layer_number):
return self.layers[layer_number]
def _checkpointed_forward(self, hidden_states, attention_mask,
encoder_output, enc_dec_attn_mask):
"""Forward method with activation checkpointing."""
def custom(start, end):
def custom_forward(*inputs):
x_ = inputs[0]
attention_mask = inputs[1]
encoder_output = inputs[2]
enc_dec_attn_mask = inputs[3]
for index in range(start, end):
layer = self._get_layer(index)
x_ = layer(x_, attention_mask, encoder_output, enc_dec_attn_mask)
return x_
return custom_forward
if self.recompute_method == 'uniform':
# Uniformly divide the total number of Transformer layers and checkpoint
# the input activation of each divided chunk.
# A method to further reduce memory usage by reducing the number of checkpoints.
l = 0
while l < self.num_layers:
hidden_states = tensor_parallel.checkpoint(
custom(l, l + self.recompute_num_layers),
self.distribute_saved_activations,
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
l += self.recompute_num_layers
elif self.recompute_method == 'block':
# Checkpoint the input activation of only a set number of individual
# Transformer layers and skip the rest.
# A method to fully use device memory by removing redundant re-computation.
for l in range(self.num_layers):
if l < self.recompute_num_layers:
hidden_states = tensor_parallel.checkpoint(
custom(l, l + 1),
self.distribute_saved_activations,
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
else:
hidden_states = custom(l, l + 1)(
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
else:
raise ValueError("Invalid activation recompute method.")
return hidden_states
def set_input_tensor(self, input_tensor):
"""Set input tensor to be used instead of forward()'s input.
When doing pipeline parallelism the input from the previous
stage comes from communication, not from the input, so the
model's forward_step_func won't have it. This function is thus
used by internal code to bypass the input provided by the
forward_step_func"""
self.input_tensor = input_tensor
def forward(self, hidden_states, attention_mask,
retriever_output=None, retriever_attn_mask=None,
encoder_output=None, enc_dec_attn_mask=None,
inference_params=None):
# Checks.
if inference_params:
assert self.recompute_granularity is None, \
'inference does not work with activation checkpointing'
args = get_args()
# Transpose retriever output, to match hidden_states shape.
retriever_output = retriever_output.transpose(0, 1).contiguous()
# Viewless tensor.
# - We only need to create a viewless tensor in the case of micro batch
# size (mbs) == 1, since in this case, 'hidden_states.transpose()'
# above creates a view tensor, and '.contiguous()' is a pass-through.
# For mbs >= 2, '.contiguous()' creates a new tensor, eliminating
# the need to make it viewless.
#
# However, we don't explicitly check mbs == 1 here because
# make_viewless_tensor() has negligible overhead when its input
# is already viewless.
#
# - For the 'else' case above, calling make_viewless_tensor() here is
# likely redundant, since p2p_communication.py (likely originator)
# already creates viewless tensors. That said, make_viewless_tensor()
# is called here to be future-proof and corner-case-proof.
hidden_states = core_utils.make_viewless_tensor(
hidden_states,
requires_grad=True,
keep_graph=True,
)
# Transpose encoder output.
if encoder_output is not None:
encoder_output = encoder_output.transpose(0, 1).contiguous()
# Forward pass.
assert not args.sequence_parallel, "sequence parallelism requires an rng-tracker context, which is not set up here."
if self.recompute_granularity == 'full':
hidden_states = self._checkpointed_forward(hidden_states,
attention_mask,
encoder_output,
enc_dec_attn_mask)
else:
for index in range(self.num_layers):
layer = self._get_layer(index)
if args.retro_add_retriever and index + 1 == min(self.P):
hidden_states, E = layer(
hidden_states,
attention_mask,
retriever_output=retriever_output,
retriever_attn_mask=retriever_attn_mask,
encoder_output=encoder_output,
enc_dec_attn_mask=enc_dec_attn_mask,
inference_params=inference_params)
elif args.retro_add_retriever and index + 1 in self.P:
hidden_states = layer(
hidden_states,
attention_mask,
retriever_output=E,
retriever_attn_mask=retriever_attn_mask,
encoder_output=encoder_output,
enc_dec_attn_mask=enc_dec_attn_mask,
inference_params=inference_params)
else:
hidden_states = layer(
hidden_states,
attention_mask,
encoder_output=encoder_output,
enc_dec_attn_mask=enc_dec_attn_mask,
inference_params=inference_params)
# Final layer norm.
output = self.final_layernorm(hidden_states)
return output
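# Illustrative check (not part of the original file) of the decoder layer
# indices P that receive retrieval cross-attention, as chosen in the
# constructor above:
import numpy as np

assert np.arange(9, 25, 3).tolist() == [9, 12, 15, 18, 21, 24]   # 24-layer model
_p40 = np.arange(9, 41, 3).tolist() + [40]                        # 40-layer model
assert _p40[0] == 9 and _p40[-1] == 40 and len(_p40) == 12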
......@@ -11,9 +11,7 @@ from megatron.model.language_model import parallel_lm_logits, get_language_model
from megatron.model import LayerNorm
from megatron.model.utils import (
openai_gelu,
get_linear_layer,
init_method_normal,
scaled_init_method_normal
get_linear_layer
)
from .module import MegatronModule
......@@ -43,17 +41,12 @@ class T5LMHead(MegatronModule):
Arguments:
mpu_vocab_size: model parallel size of vocabulary.
hidden_size: hidden size
init_method: init method for weight initialization
layernorm_epsilon: tolerance for layer norm divisions
parallel_output: whether the output logits are distributed or not.
"""
def __init__(self, mpu_vocab_size, parallel_output):
super(T5LMHead, self).__init__()
args = get_args()
self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
self.bias.model_parallel = True
self.bias.partition_dim = 0
......@@ -72,41 +65,38 @@ class T5Model(MegatronModule):
"""T5 Language model."""
def __init__(self,
config,
num_tokentypes=0,
parallel_output=True,
pre_process=True,
post_process=True,
add_encoder=True,
add_decoder=True):
super(T5Model, self).__init__()
super().__init__(config=config)
args = get_args()
self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
self.parallel_output = parallel_output
init_method = init_method_normal(args.init_method_std)
scaled_init_method = scaled_init_method_normal(args.init_method_std,
args.num_layers)
self.pre_process = pre_process
self.post_process = post_process
self.add_encoder = add_encoder
self.add_decoder = add_decoder
self.language_model, self._language_model_key = get_language_model(
config=config,
num_tokentypes=num_tokentypes,
add_pooler=False,
add_encoder=add_encoder,
add_decoder=add_decoder,
encoder_attn_mask_type=AttnMaskType.padding,
init_method=init_method,
scaled_init_method=scaled_init_method,
pre_process=self.pre_process,
post_process=self.post_process)
self.initialize_word_embeddings(init_method_normal)
self.initialize_word_embeddings()
if self.post_process and self.add_decoder:
self.lm_head = T5LMHead(
self.word_embeddings_weight().size(0),
self.shared_embedding_or_output_weight().size(0),
parallel_output)
self._lm_head_key = 'lm_head'
......@@ -139,7 +129,7 @@ class T5Model(MegatronModule):
decoder_output, encoder_output = lm_output
# Output. [s, b, h]
lm_logits = self.lm_head(decoder_output,
self.word_embeddings_weight())
self.shared_embedding_or_output_weight())
if lm_labels is None:
# [s b h] => [b s h]
......
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
"""Transformer."""
import math
from contextlib import nullcontext
import math
import numpy as np
import torch
import torch.nn.functional as F
from typing import Optional
from megatron import get_timers, get_args, core, get_num_microbatches
from megatron import get_timers, get_args, get_retro_args, core, get_num_microbatches
from .module import MegatronModule
from megatron.core import mpu, tensor_parallel
from megatron.core.enums import ModelType
......@@ -15,7 +16,7 @@ from megatron.model import LayerNorm
from megatron.model.enums import AttnMaskType, LayerType, AttnType
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
from megatron.model.rotary_pos_embedding import apply_rotary_pos_emb
from megatron.core.models.common.rotary_pos_embedding import apply_rotary_pos_emb
from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
try:
......@@ -26,7 +27,10 @@ except ImportError:
try:
from flash_attn.flash_attn_interface import flash_attn_unpadded_func
except ImportError:
flash_attn_unpadded_func = None
try:
from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_unpadded_func
except ImportError:
flash_attn_unpadded_func = None
""" We use the following notation throughout this file:
h: hidden size
......@@ -65,18 +69,6 @@ class DropPath(MegatronModule):
output = hidden_state.div(keep_prob) * random_tensor
return output
def _args_to_kwargs():
args = get_args()
common_kwargs = {
"params_dtype": args.params_dtype,
"use_cpu_initialization": args.use_cpu_initialization,
"perform_initialization": args.perform_initialization,
"gradient_accumulation_fusion": args.gradient_accumulation_fusion,
"sequence_parallel_enabled": args.sequence_parallel,
}
return common_kwargs
class ParallelMLP(MegatronModule):
"""MLP.
......@@ -85,22 +77,26 @@ class ParallelMLP(MegatronModule):
state back into h hidden dimension.
"""
def __init__(self, init_method, output_layer_init_method):
def __init__(self, config):
super(ParallelMLP, self).__init__()
args = get_args()
self.add_bias = args.add_bias_linear
self.add_bias = config.add_bias_linear
ffn_hidden_size = config.ffn_hidden_size
if config.gated_linear_unit:
ffn_hidden_size *= 2
# Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
self.dense_h_to_4h = tensor_parallel.ColumnParallelLinear(
args.hidden_size,
args.ffn_hidden_size * 2 if args.swiglu else args.ffn_hidden_size,
config.hidden_size,
ffn_hidden_size,
config=config,
init_method=config.init_method,
bias=self.add_bias,
gather_output=False,
init_method=init_method,
skip_bias_add=True,
async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
**_args_to_kwargs())
)
self.bias_gelu_fusion = False
self.activation_func = None
......@@ -125,13 +121,13 @@ class ParallelMLP(MegatronModule):
# Project back to h.
self.dense_4h_to_h = tensor_parallel.RowParallelLinear(
args.ffn_hidden_size,
args.hidden_size,
config.ffn_hidden_size,
config.hidden_size,
config=config,
init_method=config.output_layer_init_method,
bias=self.add_bias,
input_is_parallel=True,
init_method=output_layer_init_method,
skip_bias_add=True,
**_args_to_kwargs())
input_is_parallel=True
)
def forward(self, hidden_states):
......@@ -155,13 +151,13 @@ class SwitchMLP(MegatronModule):
"""
Routes input to one of N MLP "experts"
"""
def __init__(self, init_method, output_layer_init_method):
def __init__(self, config):
super(SwitchMLP, self).__init__()
args = get_args()
self.router = torch.nn.Linear(args.hidden_size, args.num_experts)
self.router = torch.nn.Linear(config.hidden_size, args.num_experts)
self.experts = torch.nn.ModuleList()
for i in range(args.num_experts):
self.experts.append(ParallelMLP(init_method, output_layer_init_method))
self.experts.append(ParallelMLP(config))
def forward(self, hidden_states):
# hidden_states: [s, b, h]
......@@ -188,45 +184,48 @@ class SwitchMLP(MegatronModule):
local_indices = (max_ind == expert_num).nonzero()
hidden = hidden_states[local_indices,:]
output, output_bias = expert(hidden)
output_bias = output_bias.expand_as(output)
if output_bias is not None:
output_bias = output_bias.expand_as(output)
output_bias_total[local_indices,:] = output_bias
output_total[local_indices,:] = output
output_bias_total[local_indices,:] = output_bias
output_total = output_total*max_prob
output_bias_total = output_bias_total*max_prob
output_total = output_total.view(s, b, h)
output_bias_total = output_bias_total.view(s, b, h)
if output_bias is not None:
output_bias_total = output_bias_total*max_prob
output_bias_total = output_bias_total.view(s, b, h)
else:
output_bias_total = None
return output_total, output_bias_total
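# Illustrative sketch (not part of the original file) of the top-1 routing used
# by SwitchMLP above, with plain nn.Linear modules standing in for the experts:
import torch

def _switch_route_sketch(hidden, router, experts):
    s, b, h = hidden.shape
    route = torch.softmax(router(hidden), dim=-1)      # [s, b, num_experts]
    max_prob, max_ind = torch.max(route, dim=-1)       # top-1 expert per token
    flat = hidden.reshape(-1, h)
    max_prob = max_prob.reshape(-1, 1)
    max_ind = max_ind.reshape(-1)
    out = torch.zeros_like(flat)
    for e, expert in enumerate(experts):
        idx = (max_ind == e).nonzero(as_tuple=True)[0]
        if idx.numel() > 0:
            out[idx] = expert(flat[idx])
    return (out * max_prob).reshape(s, b, h)           # scale by router probability

_experts = torch.nn.ModuleList([torch.nn.Linear(16, 16) for _ in range(4)])
_y = _switch_route_sketch(torch.randn(3, 2, 16), torch.nn.Linear(16, 4), _experts)
assert _y.shape == (3, 2, 16)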
class CoreAttention(MegatronModule):
def __init__(self, layer_number,
def __init__(self, layer_number, config,
attn_mask_type=AttnMaskType.padding):
super(CoreAttention, self).__init__()
args = get_args()
self.fp16 = args.fp16
self.bf16 = args.bf16
self.fp16 = config.fp16
self.bf16 = config.bf16
self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
if self.apply_query_key_layer_scaling:
self.attention_softmax_in_fp32 = True
self.layer_number = max(1, layer_number)
self.attn_mask_type = attn_mask_type
self.sequence_parallel = args.sequence_parallel
self.sequence_parallel = config.sequence_parallel
projection_size = args.kv_channels * args.num_attention_heads
projection_size = config.kv_channels * config.num_attention_heads
# Per attention head and per partition values.
world_size = mpu.get_tensor_model_parallel_world_size()
self.hidden_size_per_partition = core.utils.divide(projection_size,
world_size)
self.hidden_size_per_attention_head = core.utils.divide(
projection_size, args.num_attention_heads)
projection_size, config.num_attention_heads)
self.num_attention_heads_per_partition = core.utils.divide(
args.num_attention_heads, world_size)
config.num_attention_heads, world_size)
coeff = None
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
......@@ -237,7 +236,7 @@ class CoreAttention(MegatronModule):
self.scale_mask_softmax = FusedScaleMaskSoftmax(
self.fp16, self.bf16,
self.attn_mask_type,
args.masked_softmax_fusion,
config.masked_softmax_fusion,
attention_mask_func,
self.attention_softmax_in_fp32,
coeff)
......@@ -245,7 +244,7 @@ class CoreAttention(MegatronModule):
# Dropout. Note that for a single iteration, this layer will generate
# different outputs on different numbers of parallel partitions but
# on average it should not be partition dependent.
self.attention_dropout = torch.nn.Dropout(args.attention_dropout)
self.attention_dropout = torch.nn.Dropout(config.attention_dropout)
def forward(self, query_layer, key_layer,
value_layer, attention_mask):
......@@ -261,8 +260,8 @@ class CoreAttention(MegatronModule):
key_layer.size(0))
# [sq, b, np, hn] -> [sq, b * np, hn]
query_layer = query_layer.view(output_size[2],
output_size[0] * output_size[1], -1)
query_layer = query_layer.reshape(output_size[2],
output_size[0] * output_size[1], -1)
# [sk, b, np, hn] -> [sk, b * np, hn]
key_layer = key_layer.view(output_size[3],
output_size[0] * output_size[1], -1)
......@@ -379,17 +378,18 @@ class FlashSelfAttention(torch.nn.Module):
is_causal = self.causal
cu_seqlens_k = cu_seqlens_q
dropout_p = self.dropout_p
else:
# Turn off the FlashAttention causal mask after the first inference
# autoregressive iteration; only on the first step do q, k, v share a seqlen.
is_causal = seqlen_q == seqlen_k
cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32,
device=q.device)
self.dropout_p = 0
dropout_p = 0
output = flash_attn_unpadded_func(
q, k, v, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
self.dropout_p,
dropout_p,
softmax_scale=self.softmax_scale, causal=is_causal
)
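# Illustrative sketch (not part of the original file): the varlen flash-attention
# kernel takes cumulative sequence-length offsets instead of a padded batch; for
# equal-length sequences those offsets are simply multiples of the seqlen:
import torch

def _cu_seqlens_sketch(batch_size, seqlen):
    return torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32)

assert _cu_seqlens_sketch(3, 4).tolist() == [0, 4, 8, 12]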
......@@ -404,8 +404,7 @@ class ParallelAttention(MegatronModule):
and returns output of the same size.
"""
def __init__(self, init_method,
output_layer_init_method, layer_number,
def __init__(self, config, layer_number,
attention_type=AttnType.self_attn,
attn_mask_type=AttnMaskType.padding):
super(ParallelAttention, self).__init__()
......@@ -413,10 +412,21 @@ class ParallelAttention(MegatronModule):
self.layer_number = max(1, layer_number)
self.attention_type = attention_type
self.attn_mask_type = attn_mask_type
self.params_dtype = args.params_dtype
self.sequence_parallel = args.sequence_parallel
self.params_dtype = config.params_dtype
self.sequence_parallel = config.sequence_parallel
self.group_query_attention = args.group_query_attention
self.num_query_groups = args.num_query_groups
query_projection_size = config.kv_channels * config.num_attention_heads
if self.group_query_attention:
kv_projection_size = args.kv_channels * args.num_query_groups
else:
kv_projection_size = args.kv_channels * args.num_attention_heads
self.use_flash_attn = args.use_flash_attn
self.use_flash_attn = args.use_flash_attn \
and attention_type == AttnType.self_attn \
and self.attn_mask_type == AttnMaskType.causal
if self.use_flash_attn:
if flash_attn_unpadded_func is None:
raise ImportError('FlashAttention is not installed, please install with '
......@@ -428,64 +438,72 @@ class ParallelAttention(MegatronModule):
if rearrange is None:
raise ImportError('einops is not installed, please install with pip install einops')
projection_size = args.kv_channels * args.num_attention_heads
# Per attention head and per partition values.
world_size = mpu.get_tensor_model_parallel_world_size()
self.hidden_size_per_attention_head = core.utils.divide(
projection_size, args.num_attention_heads)
query_projection_size, config.num_attention_heads)
self.num_attention_heads_per_partition = core.utils.divide(
args.num_attention_heads, world_size)
config.num_attention_heads, world_size)
if self.group_query_attention:
if args.num_query_groups % world_size != 0:
raise NotImplementedError('Currently the num_query_groups should be '
'a multiple of the tensor parallel size')
self.num_query_groups_per_partition = core.utils.divide(
args.num_query_groups, world_size)
else:
self.num_query_groups_per_partition = self.num_attention_heads_per_partition
# Strided linear layer.
if attention_type == AttnType.self_attn:
self.query_key_value = tensor_parallel.ColumnParallelLinear(
args.hidden_size,
3 * projection_size,
config.hidden_size,
query_projection_size + 2 * kv_projection_size,
config=config,
init_method=config.init_method,
bias=args.add_bias_linear,
gather_output=False,
init_method=init_method,
async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
**_args_to_kwargs())
gather_output=False)
else:
assert attention_type == AttnType.cross_attn
self.query = tensor_parallel.ColumnParallelLinear(
args.hidden_size,
projection_size,
bias=args.add_bias_linear,
gather_output=False,
init_method=init_method,
async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
**_args_to_kwargs())
if self.group_query_attention:
raise NotImplementedError("Grouped query attention not implemented for cross-attention.")
assert query_projection_size == kv_projection_size
self.key_value = tensor_parallel.ColumnParallelLinear(
args.hidden_size,
2 * projection_size,
bias=args.add_bias_linear,
gather_output=False,
init_method=init_method,
async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
**_args_to_kwargs())
self.query = tensor_parallel.ColumnParallelLinear(
config.hidden_size,
query_projection_size,
config=config,
init_method=config.init_method,
bias=config.add_bias_linear,
gather_output=False)
self.core_attention = CoreAttention(self.layer_number,
self.key_value = tensor_parallel.ColumnParallelLinear(
config.hidden_size,
2 * kv_projection_size,
config=config,
init_method=config.init_method,
bias=config.add_bias_linear,
gather_output=False)
self.core_attention = CoreAttention(self.layer_number, config,
self.attn_mask_type)
self.checkpoint_core_attention = args.recompute_granularity == 'selective'
self.checkpoint_core_attention = config.recompute_granularity == 'selective'
if self.use_flash_attn:
self.core_attention_flash = FlashSelfAttention(
causal=True, attention_dropout=args.attention_dropout
causal=True, attention_dropout=config.attention_dropout
)
# Output.
self.dense = tensor_parallel.RowParallelLinear(
projection_size,
args.hidden_size,
query_projection_size,
config.hidden_size,
config=config,
init_method=config.output_layer_init_method,
bias=args.add_bias_linear,
input_is_parallel=True,
init_method=output_layer_init_method,
skip_bias_add=True,
**_args_to_kwargs())
skip_bias_add=True)
def _checkpointed_attention_forward(self, query_layer, key_layer,
value_layer, attention_mask,
......@@ -510,11 +528,11 @@ class ParallelAttention(MegatronModule):
return hidden_states
def _allocate_memory(self, inference_max_sequence_len, batch_size):
def _allocate_memory(self, inference_max_sequence_len, batch_size, num_attention_heads):
return torch.empty(
inference_max_sequence_len,
batch_size,
self.num_attention_heads_per_partition,
num_attention_heads,
self.hidden_size_per_attention_head,
dtype=self.params_dtype,
device=torch.cuda.current_device())
......@@ -530,12 +548,15 @@ class ParallelAttention(MegatronModule):
is_first_step = False
if inference_params:
if self.layer_number not in inference_params.key_value_memory_dict:
inf_max_seq_len = inference_params.max_sequence_len
inf_max_seq_len = inference_params.max_sequence_length
inf_max_batch_size = inference_params.max_batch_size
inference_key_memory = self._allocate_memory(
inf_max_seq_len, inf_max_batch_size)
inf_max_seq_len, inf_max_batch_size,
self.num_query_groups_per_partition)
inference_value_memory = self._allocate_memory(
inf_max_seq_len, inf_max_batch_size)
inf_max_seq_len, inf_max_batch_size,
self.num_query_groups_per_partition)
inference_params.key_value_memory_dict[self.layer_number] = (
inference_key_memory, inference_value_memory)
is_first_step = True
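# Illustrative sketch (not part of the original file): with grouped-query
# attention the inference KV cache allocated above is sized per query group
# rather than per attention head; hypothetical sizes max_seq=32, batch=2,
# num_query_groups=8, head_dim=16:
import torch

_kv_cache_shape = (32, 2, 8, 16)   # [max_seq_len, batch, num_query_groups, head_dim]
_key_cache = torch.empty(_kv_cache_shape)
_value_cache = torch.empty(_kv_cache_shape)
assert _key_cache.shape == _value_cache.shape == torch.Size(_kv_cache_shape)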
......@@ -546,21 +567,36 @@ class ParallelAttention(MegatronModule):
# =====================
# Query, Key, and Value
# =====================
if self.attention_type == AttnType.self_attn:
# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
# Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn]
mixed_x_layer, _ = self.query_key_value(hidden_states)
# [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
new_tensor_shape = mixed_x_layer.size()[:-1] + \
(self.num_attention_heads_per_partition,
3 * self.hidden_size_per_attention_head)
# [sq, b, hp] --> [sq, b, ng, (np/ng + 2) * hn]
new_tensor_shape = mixed_x_layer.size()[:-1] + (
self.num_query_groups_per_partition,
(
(self.num_attention_heads_per_partition // self.num_query_groups_per_partition + 2)
* self.hidden_size_per_attention_head
),
)
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
# [sq, b, ng, (np/ng + 2) * hn] --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
(query_layer,
key_layer,
value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_x_layer, 3)
key_layer,
value_layer) = torch.split(
mixed_x_layer,
[
(
self.num_attention_heads_per_partition // self.num_query_groups_per_partition
* self.hidden_size_per_attention_head
),
self.hidden_size_per_attention_head,
self.hidden_size_per_attention_head
],
dim=3)
# [sq, b, ng, np/ng * hn] -> [sq, b, np, hn]
query_layer = query_layer.view(query_layer.size(0), query_layer.size(1), -1, self.hidden_size_per_attention_head)
else:
# Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
mixed_kv_layer, _ = self.key_value(encoder_output)
......@@ -568,19 +604,19 @@ class ParallelAttention(MegatronModule):
# [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
new_tensor_shape = mixed_kv_layer.size()[:-1] + \
(self.num_attention_heads_per_partition,
2 * self.hidden_size_per_attention_head)
2 * self.hidden_size_per_attention_head)
mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)
# [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
(key_layer,
value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_kv_layer, 2)
value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_kv_layer, 2)
# Attention head [sq, b, h] --> [sq, b, hp]
query_layer, _ = self.query(hidden_states)
# [sq, b, hp] --> [sq, b, np, hn]
new_tensor_shape = query_layer.size()[:-1] + \
(self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head)
self.hidden_size_per_attention_head)
query_layer = query_layer.view(*new_tensor_shape)
# ==================================
......@@ -632,11 +668,20 @@ class ParallelAttention(MegatronModule):
k_pos_emb = k_pos_emb[:sequence_end, :, :, :]
rotary_pos_emb = (q_pos_emb, k_pos_emb)
# ==================================
# core attention computation
# ==================================
# expand the key_layer and value_layer [sk, b, ng, hn] -> [sk, b, np, hn]
key_layer = key_layer.repeat_interleave(
self.num_attention_heads_per_partition // self.num_query_groups_per_partition,
dim = 2
)
value_layer = value_layer.repeat_interleave(
self.num_attention_heads_per_partition // self.num_query_groups_per_partition,
dim = 2
)
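# Illustrative shape check (not part of the original file) for the grouped-query
# split and group-to-head expansion above, using hypothetical sizes
# np = 8 heads, ng = 2 query groups, hn = 4, sq = 5, b = 3:
import torch

_sq, _b, _np, _ng, _hn = 5, 3, 8, 2, 4
_mixed = torch.randn(_sq, _b, _ng, (_np // _ng + 2) * _hn)       # fused QKV
_q, _k, _v = torch.split(_mixed, [_np // _ng * _hn, _hn, _hn], dim=3)
_q = _q.reshape(_sq, _b, -1, _hn)                                # [sq, b, np, hn]
_k = _k.repeat_interleave(_np // _ng, dim=2)                     # [sq, b, np, hn]
_v = _v.repeat_interleave(_np // _ng, dim=2)
assert _q.shape == _k.shape == _v.shape == (_sq, _b, _np, _hn)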
# apply relative positional encoding (rotary embedding)
if rotary_pos_emb is not None:
q_pos_emb, k_pos_emb = rotary_pos_emb
......@@ -711,10 +756,11 @@ class ParallelTransformerLayer(MegatronModule):
output of the same size.
"""
def __init__(self, init_method, output_layer_init_method,
def __init__(self, config,
layer_number, layer_type=LayerType.encoder,
self_attn_mask_type=AttnMaskType.padding,
drop_path_rate=0.):
# retriever=None):
args = get_args()
super(ParallelTransformerLayer, self).__init__()
......@@ -722,57 +768,59 @@ class ParallelTransformerLayer(MegatronModule):
self.layer_type = layer_type
self.apply_residual_connection_post_layernorm \
= args.apply_residual_connection_post_layernorm
= config.apply_residual_connection_post_layernorm
self.bf16 = args.bf16
self.fp32_residual_connection = args.fp32_residual_connection
self.bf16 = config.bf16
self.fp32_residual_connection = config.fp32_residual_connection
# Layernorm on the input data.
self.input_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon,
config.hidden_size,
eps=config.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm,
sequence_parallel=args.sequence_parallel,
sequence_parallel=config.sequence_parallel,
apply_layernorm_1p=args.apply_layernorm_1p)
# Self attention.
self.self_attention = ParallelAttention(
init_method,
output_layer_init_method,
config,
layer_number,
attention_type=AttnType.self_attn,
attn_mask_type=self_attn_mask_type)
self.hidden_dropout = args.hidden_dropout
self.bias_dropout_fusion = args.bias_dropout_fusion
self.hidden_dropout = config.hidden_dropout
self.bias_dropout_fusion = config.bias_dropout_fusion
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None
# Layernorm on the attention output
self.post_attention_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm,
sequence_parallel=args.sequence_parallel,
config.hidden_size,
eps=config.layernorm_epsilon,
no_persist_layer_norm=not config.persist_layer_norm,
sequence_parallel=config.sequence_parallel,
apply_layernorm_1p=args.apply_layernorm_1p)
if self.layer_type == LayerType.decoder:
# Cross attention.
if self.layer_type in (LayerType.decoder,
LayerType.retro_decoder,
LayerType.retro_decoder_with_retriever,
LayerType.retro_encoder):
self.inter_attention = ParallelAttention(
init_method,
output_layer_init_method,
config,
layer_number,
attention_type=AttnType.cross_attn)
# Layernorm on the attention output.
self.post_inter_attention_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm,
sequence_parallel=args.sequence_parallel,
config.hidden_size,
eps=config.layernorm_epsilon,
no_persist_layer_norm=not config.persist_layer_norm,
sequence_parallel=config.sequence_parallel,
apply_layernorm_1p=args.apply_layernorm_1p)
# MLP
if args.num_experts is not None:
self.mlp = SwitchMLP(init_method, output_layer_init_method)
self.mlp = SwitchMLP(config)
else:
self.mlp = ParallelMLP(init_method, output_layer_init_method)
self.mlp = ParallelMLP(config)
# Set bias+dropout+add fusion grad_enable execution handler.
TORCH_MAJOR = int(torch.__version__.split('.')[0])
......@@ -781,13 +829,245 @@ class ParallelTransformerLayer(MegatronModule):
self.bias_dropout_add_exec_handler = \
nullcontext if use_nvfuser else torch.enable_grad
if args.retro_add_retriever:
retro_args = get_retro_args()
self.retro_num_neighbors = args.retro_num_neighbors
self.retro_chunk_length = retro_args.retro_gpt_chunk_length
self.retro_retrieved_length = retro_args.retro_gpt_retrieved_length
# Retriever (bi-directional transformer with cross attention)
if layer_type == LayerType.retro_decoder_with_retriever:
self.retriever = ParallelTransformer(
config=config,
model_type=ModelType.retro_encoder,
self_attn_mask_type=AttnMaskType.padding,
pre_process=True,
post_process=False,
)
self._retriever_key = 'retriever'
else:
self.retriever = None
def default_decoder_cross_attention(self,
encoder_output,
enc_dec_attn_mask,
layernorm_input,
layernorm_output,
bias_dropout_add_func):
'''Cross attention for a standard encoder-decoder model.'''
# Attention.
attention_output, attention_bias = \
self.inter_attention(layernorm_output,
enc_dec_attn_mask,
encoder_output=encoder_output)
# Residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = layernorm_input
if attention_bias is not None:
attention_bias = attention_bias.expand_as(residual)
# Bias-dropout-add.
with self.bias_dropout_add_exec_handler():
layernorm_input = bias_dropout_add_func(
attention_output,
attention_bias,
residual,
self.hidden_dropout)
# Layer norm.
layernorm_output = self.post_inter_attention_layernorm(layernorm_input)
return layernorm_input, layernorm_output
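# --- Illustrative sketch (not from the patch): a plausible unfused reference for the
# bias-dropout-add pattern that `bias_dropout_add_func` applies above -- add the
# attention projection bias, apply dropout, then add the residual stream. The fused
# Megatron kernels compute the same arithmetic; names here are local to the sketch.
import torch

def bias_dropout_add_reference(x, bias, residual, prob, training=True):
    out = x + bias if bias is not None else x
    out = torch.nn.functional.dropout(out, p=prob, training=training)
    return residual + out

x = torch.randn(4, 2, 8)                      # [s, b, h]
bias = torch.zeros(8)                         # broadcasts over [s, b, h]
residual = torch.randn(4, 2, 8)
y = bias_dropout_add_reference(x, bias, residual, prob=0.1, training=False)
assert torch.allclose(y, residual + x)        # dropout is the identity in eval mode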
def retro_encoder_cross_attention(self,
retriever_output,
layernorm_input,
layernorm_output,
bias_dropout_add_func):
"""Cross attention for Retro encoder.
Notation:
ns : Sequence length.
bs : Batch size.
d : Hidden size.
l : Number of chunks per sample (i.e., seq_length/chunk_length).
k : Number of neighbors.
r : Number of retrieved tokens (neighbors + continuation).
"""
ns, bs, d = layernorm_output.shape # [r, bs * l * k, d]
# Divide sequence dimension into chunks.
chunked_outputs = layernorm_output.reshape(self.retro_retrieved_length,
-1,
self.retro_num_neighbors,
d)
chunked_outputs_before_layer_norm = \
layernorm_input.reshape(self.retro_retrieved_length, -1,
self.retro_num_neighbors, d) # [r, bs*l, k, d]
# Per-chunk attention.
layernorm_inputs = []
layernorm_outputs = []
for k in range(self.retro_num_neighbors):
# Attention.
chunked_output = chunked_outputs[:,:,k].contiguous()
attention_output, attention_bias = \
self.inter_attention(
chunked_output, # Q (neighbor embedding)
None,
encoder_output=retriever_output) # K, V (hidden act)
# Residual connection.
if self.apply_residual_connection_post_layernorm:
residual = chunked_output
else:
residual = chunked_outputs_before_layer_norm[:,:,k]
# Re-enable torch grad to enable fused optimization.
with torch.enable_grad():
layernorm_input = bias_dropout_add_func(
attention_output,
None if attention_bias is None else attention_bias.expand_as(residual),
residual,
self.hidden_dropout)
layernorm_inputs.append(layernorm_input)
# Layer norm.
layernorm_output = \
self.post_inter_attention_layernorm(layernorm_input)
layernorm_outputs.append(layernorm_output)
# Concatenate layer norms.
# layernorm_input : [r, k * bs * l, d]
# layernorm_output : [r, k * bs * l, d]
layernorm_input = \
torch.stack(layernorm_inputs, dim=1).reshape(ns, bs, d)
layernorm_output = \
torch.stack(layernorm_outputs, dim=1).reshape(ns, bs, d)
return layernorm_input, layernorm_output
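# --- Illustrative sketch (not from the patch): the reshape/stack bookkeeping used in
# retro_encoder_cross_attention. Retrieved activations arrive as [r, bs*l*k, d], are
# split per neighbor for cross attention, and the stacked result ends up laid out
# neighbor-major, i.e. [r, k*bs*l, d]. All sizes are toy values.
import torch

r, bs, l, k, d = 5, 2, 3, 4, 8
flat = torch.randn(r, bs * l * k, d)
chunked = flat.reshape(r, -1, k, d)                       # [r, bs*l, k, d]
per_neighbor = [chunked[:, :, n] for n in range(k)]       # k tensors of [r, bs*l, d]
restacked = torch.stack(per_neighbor, dim=1).reshape(r, k * bs * l, d)
# Neighbor n occupies a contiguous [r, bs*l, d] slab in the stacked layout.
assert torch.equal(restacked.reshape(r, k, bs * l, d)[:, 2], chunked[:, :, 2])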
def retro_decoder_cross_attention(self,
retriever_input,
retriever_output,
retriever_attn_mask,
layernorm_input,
layernorm_output,
inference_params,
bias_dropout_add_func):
"""Cross attention for Retro decoder.
Notation:
ns : Sequence length.
bs : Batch size.
d : Hidden size.
l : Number of chunks per sample (i.e., seq_length/chunk_length).
m : Number of tokens per chunk.
k : Number of neighbors.
r : Number of retrieved tokens (neighbors + continuation).
"""
ns, bs, d = layernorm_output.shape
l = int(np.ceil(ns / self.retro_chunk_length))
# Retrieve neighbors.
if self.layer_type == LayerType.retro_decoder_with_retriever:
first_ns = ns % self.retro_chunk_length
if first_ns > 0:
raise Exception("test this case.")
first_chunk, rest_chunk = \
layernorm_output[:first_ns], layernorm_output[first_ns:]
first_chunk = torch.nn.functional.pad(
first_chunk,
(0, 0, 0, 0, 0, self.retro_chunk_length - first_ns),
'constant',
0)
chunked_output = \
torch.cat((first_chunk, rest_chunk), dim=0) # [l * m, bs, d]
else:
chunked_output = layernorm_output # [l * m, bs, d]
chunked_output = chunked_output \
.reshape(l, self.retro_chunk_length, bs, d) \
.permute(1, 2, 0, 3) \
.reshape(self.retro_chunk_length, bs * l, d) \
.contiguous()
# Get Encoder Output
retriever_output = self.retriever(
hidden_states=retriever_input,
attention_mask=retriever_attn_mask,
retriever_output=chunked_output,
retriever_attn_mask=retriever_attn_mask,
inference_params=inference_params) # [r, k * bs * l , d]
retriever_output = retriever_output.reshape(
self.retro_retrieved_length * self.retro_num_neighbors, bs * l, d) # [r * k, bs * l, d]
# Chunks.
pad = (ns - 1) % self.retro_chunk_length
attending_chunks = layernorm_output[pad:]
padded_chunks = torch.nn.functional.pad(
attending_chunks,
(0, 0, 0, 0, 0, self.retro_chunk_length - 1),
'constant', 0)
padded_chunked_output = padded_chunks \
.reshape(l, self.retro_chunk_length, bs, d) \
.permute(1, 2, 0, 3)
padded_chunked_output = padded_chunked_output.reshape(
self.retro_chunk_length, bs * l, d).contiguous()
# Encoder output.
attention_output, attention_bias = \
self.inter_attention(padded_chunked_output,
None,
encoder_output=retriever_output)
# Residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = layernorm_input
# Re-enable torch grad to enable fused optimization.
with torch.enable_grad():
layernorm_input = bias_dropout_add_func(
attention_output,
None if attention_bias is None else attention_bias.expand_as(attention_output),
torch.zeros_like(attention_output),
self.hidden_dropout)
layernorm_input = layernorm_input \
.reshape(self.retro_chunk_length, bs, l, d) \
.permute(2, 0, 1, 3) # [l, m, bs, d]
layernorm_input = layernorm_input.reshape(self.retro_chunk_length * l, bs, d)
layernorm_input = torch.nn.functional.pad(
layernorm_input,
(0, 0, 0, 0, pad, 0),
'constant', 0)[:ns] # [ns, b, d]
layernorm_input = layernorm_input + residual
# Layer norm post the decoder attention
layernorm_output = self.post_inter_attention_layernorm(layernorm_input)
return retriever_output, layernorm_input, layernorm_output
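# --- Illustrative sketch (not from the patch): the chunk bookkeeping used in
# retro_decoder_cross_attention. A sequence of ns tokens is shifted/padded to a
# multiple of the chunk length m and regrouped so each chunk becomes its own batch
# entry for cross attention against the retrieved neighbors. Toy sizes only.
import math
import torch

ns, bs, d, m = 10, 2, 8, 4                    # m plays the role of retro_chunk_length
l = int(math.ceil(ns / m))                    # chunks per sample
hidden = torch.randn(ns, bs, d)
pad = (ns - 1) % m
padded = torch.nn.functional.pad(hidden[pad:], (0, 0, 0, 0, 0, m - 1))   # [l*m, bs, d]
assert padded.shape[0] == l * m
# [l*m, bs, d] -> [m, bs*l, d]: fold the chunk index into the batch dimension.
per_chunk = padded.reshape(l, m, bs, d).permute(1, 2, 0, 3).reshape(m, bs * l, d)
assert per_chunk.shape == (m, bs * l, d)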
def forward(self, hidden_states, attention_mask,
encoder_output=None, enc_dec_attn_mask=None,
inference_params=None, rotary_pos_emb=None):
retriever_input=None,
retriever_output=None,
retriever_attn_mask=None,
inference_params=None,
rotary_pos_emb=None):
# hidden_states: [s, b, h]
# Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states)
# Self attention.
attention_output, attention_bias = \
self.self_attention(
......@@ -832,29 +1112,38 @@ class ParallelTransformerLayer(MegatronModule):
# Layer norm post the self attention.
layernorm_output = self.post_attention_layernorm(layernorm_input)
if self.layer_type == LayerType.decoder:
attention_output, attention_bias = \
self.inter_attention(layernorm_output,
enc_dec_attn_mask,
encoder_output=encoder_output)
# residual connection
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = layernorm_input
if attention_bias is not None:
attention_bias = attention_bias.expand_as(residual)
with self.bias_dropout_add_exec_handler():
layernorm_input = bias_dropout_add_func(
attention_output,
attention_bias,
residual,
self.hidden_dropout)
# Layer norm post the decoder attention
layernorm_output = self.post_inter_attention_layernorm(layernorm_input)
# Cross attention.
if self.layer_type == LayerType.encoder:
pass
elif self.layer_type == LayerType.decoder:
layernorm_input, layernorm_output = \
self.default_decoder_cross_attention(
encoder_output,
enc_dec_attn_mask,
layernorm_input,
layernorm_output,
bias_dropout_add_func)
elif self.layer_type == LayerType.retro_encoder:
layernorm_input, layernorm_output = \
self.retro_encoder_cross_attention(
retriever_output,
layernorm_input,
layernorm_output,
bias_dropout_add_func)
elif self.layer_type in (LayerType.retro_decoder,
LayerType.retro_decoder_with_retriever):
retriever_output, layernorm_input, layernorm_output = \
self.retro_decoder_cross_attention(
retriever_input,
retriever_output,
retriever_attn_mask,
layernorm_input,
layernorm_output,
inference_params,
bias_dropout_add_func)
else:
raise Exception("Unsupported layer type, '%s'." %
self.layer_type.name)
# MLP.
mlp_output, mlp_bias = self.mlp(layernorm_output)
......@@ -893,7 +1182,10 @@ class ParallelTransformerLayer(MegatronModule):
training=self.training)
output = residual + self.drop_path(out)
return output
if self.layer_type == LayerType.retro_decoder_with_retriever:
return output, retriever_output
else:
return output
class NoopTransformerLayer(MegatronModule):
......@@ -922,9 +1214,12 @@ class NoopTransformerLayer(MegatronModule):
return hidden_states.clone()
def _get_num_layers(args, is_encoder_and_decoder_model, is_decoder=False):
def _get_num_layers(args, model_type, is_decoder=False):
"""Compute the number of transformer layers resident on the current rank."""
if mpu.get_pipeline_model_parallel_world_size() > 1:
is_encoder_and_decoder_model = (model_type == ModelType.encoder_and_decoder)
if model_type == ModelType.retro_encoder:
num_layers = args.retro_encoder_layers
elif mpu.get_pipeline_model_parallel_world_size() > 1:
if is_encoder_and_decoder_model:
assert args.pipeline_model_parallel_split_rank is not None
......@@ -974,51 +1269,91 @@ def _get_num_layers(args, is_encoder_and_decoder_model, is_decoder=False):
return num_layers
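# --- Illustrative sketch (not from the patch): in the common symmetric case (no
# encoder/decoder split, no standalone embedding stage), _get_num_layers boils down
# to dividing the total layer count evenly across pipeline ranks, while the Retro
# encoder path above always uses the fixed args.retro_encoder_layers. Toy values.
def layers_on_this_rank(total_layers, pipeline_world_size):
    assert total_layers % pipeline_world_size == 0, \
        'num_layers must be divisible by the pipeline-parallel size'
    return total_layers // pipeline_world_size

assert layers_on_this_rank(24, 4) == 6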
def _get_layer_type(model_type, default_layer_type, retro_layer_numbers,
layer_number):
args = get_args()
if args.retro_add_retriever and layer_number in retro_layer_numbers:
if model_type == ModelType.retro_decoder:
return LayerType.retro_decoder_with_retriever \
if layer_number == retro_layer_numbers[0] \
else LayerType.retro_decoder
elif model_type == ModelType.retro_encoder:
return LayerType.retro_encoder
else:
raise Exception("Unsupported model type, '%s'." % model_type)
else:
return default_layer_type
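# --- Illustrative sketch (not from the patch): which decoder layers become Retro
# cross-attention layers. Combined with the retro_layer_numbers computed later in
# ParallelTransformer.__init__ (start at 6 for models with at most 15 layers, else 9,
# then every third layer), the selection looks like this:
import numpy as np

def retro_decoder_layer_numbers(num_layers):
    start = 6 if num_layers <= 15 else 9
    return np.arange(start, num_layers + 1, 3).tolist()

assert retro_decoder_layer_numbers(12) == [6, 9, 12]
assert retro_decoder_layer_numbers(24) == [9, 12, 15, 18, 21, 24]
# Only the first of these layers owns the retriever (retro_decoder_with_retriever);
# the remaining ones are plain retro_decoder layers.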
class ParallelTransformer(MegatronModule):
"""Transformer class."""
def __init__(self, init_method, output_layer_init_method,
layer_type=LayerType.encoder,
def __init__(self, config,
model_type, layer_type=LayerType.encoder,
self_attn_mask_type=AttnMaskType.padding,
post_layer_norm=True,
pre_process=True, post_process=True,
pre_process=True,
post_process=True,
drop_path_rate=0.0):
super(ParallelTransformer, self).__init__()
args = get_args()
self.layer_type = layer_type
self.model_type = args.model_type
self.bf16 = args.bf16
self.fp32_residual_connection = args.fp32_residual_connection
self.model_type = model_type
self.bf16 = config.bf16
self.fp32_residual_connection = config.fp32_residual_connection
self.post_layer_norm = post_layer_norm
self.pre_process = pre_process
self.post_process = post_process
self.input_tensor = None
self.drop_path_rate = drop_path_rate
self.transformer_impl = args.transformer_impl
self.retro_add_retriever = args.retro_add_retriever
# Store activation checkpointing flag.
self.recompute_granularity = args.recompute_granularity
self.recompute_method = args.recompute_method
self.recompute_num_layers = args.recompute_num_layers
self.recompute_granularity = config.recompute_granularity
self.recompute_method = config.recompute_method
self.recompute_num_layers = config.recompute_num_layers
self.distribute_saved_activations = \
args.distribute_saved_activations and not args.sequence_parallel
config.distribute_saved_activations and not config.sequence_parallel
self.sequence_parallel = args.sequence_parallel
self.sequence_parallel = config.sequence_parallel
# Transformer Engine Init.
self.transformer_engine_v_0_10 = False
self.transformer_engine_v_0_11 = False
self.transformer_engine_v_0_8 = False
if self.transformer_impl == 'transformer_engine':
global transformer_engine
import transformer_engine
self.use_fp8 = args.fp8_e4m3 or args.fp8_hybrid
from importlib.metadata import version
from pkg_resources import packaging
te_version = packaging.version.Version(version("transformer-engine"))
if te_version >= packaging.version.Version("0.8.0"):
self.transformer_engine_v_0_8 = True
if te_version >= packaging.version.Version("0.10.0"):
self.transformer_engine_v_0_10 = True
if te_version >= packaging.version.Version("0.11.0"):
self.transformer_engine_v_0_11 = True
del version, packaging
assert not args.squared_relu, "TransformerEngine does not support squared relu activation."
self.use_fp8 = args.fp8 is not None
self.fp8_recipe = None
self.fp8_group = None
if self.use_fp8:
self.fp8_group = mpu.get_data_parallel_group()
if args.fp8_e4m3:
assert args.transformer_impl == 'transformer_engine', \
'transformer-engine required for fp8 training and inference'
self.fp8_group = mpu.get_amax_reduction_group()
if args.fp8 == "e4m3":
fp8_format = transformer_engine.common.recipe.Format.E4M3
elif args.fp8_hybrid:
elif args.fp8 == "hybrid":
fp8_format = transformer_engine.common.recipe.Format.HYBRID
else:
raise ValueError("The DelayedScaling recipe only supports E4M3 and HYBRID formats.")
self.fp8_recipe = transformer_engine.common.recipe.DelayedScaling(
margin=args.fp8_margin,
interval=args.fp8_interval,
......@@ -1030,63 +1365,87 @@ class ParallelTransformer(MegatronModule):
self.num_microbatches_in_previous_step = -1
self.microbatch_count = 0
self.checkpoint_core_attention = args.recompute_granularity == 'selective'
self.checkpoint_core_attention = config.recompute_granularity == 'selective'
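# --- Illustrative sketch (not from the patch): the FP8 format selection above, kept
# free of any GPU requirement. "e4m3" uses the E4M3 format for both passes, while
# "hybrid" combines E4M3 activations/weights with E5M2 gradients. Requires the
# transformer-engine package for the recipe classes; the margin/interval values here
# are placeholders, not Megatron's configured defaults.
from transformer_engine.common import recipe

def build_fp8_recipe(fp8_mode):
    if fp8_mode == "e4m3":
        fp8_format = recipe.Format.E4M3
    elif fp8_mode == "hybrid":
        fp8_format = recipe.Format.HYBRID
    else:
        raise ValueError("The DelayedScaling recipe only supports E4M3 and HYBRID formats.")
    return recipe.DelayedScaling(margin=0, interval=1, fp8_format=fp8_format)

fp8_recipe_demo = build_fp8_recipe("hybrid")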
# Number of layers.
self.num_layers = _get_num_layers(
args,
args.model_type == ModelType.encoder_and_decoder,
layer_type == LayerType.decoder)
self.num_layers = _get_num_layers(args, model_type,
layer_type==LayerType.decoder)
self.drop_path_rates = [
rate.item() for rate in
torch.linspace(0, self.drop_path_rate, config.num_layers)]
self.drop_path_rates = [rate.item() for rate in torch.linspace(0, self.drop_path_rate, args.num_layers)]
self.retro_layer_numbers = None
if model_type == ModelType.retro_decoder:
retro_layer_start = 6 if config.num_layers <= 15 else 9
self.retro_layer_numbers = \
np.arange(retro_layer_start, args.num_layers + 1, 3).tolist()
if model_type == ModelType.retro_encoder:
self.retro_layer_numbers = [1]
# Transformer layers.
if args.retro_add_retriever:
assert self.recompute_granularity != 'full', \
"Full recompute not supported for Retro."
assert args.transformer_impl == 'local', \
"Transformer engine does not support Retro layers."
def build_layer(layer_number):
if args.transformer_impl == 'local':
current_layer_type = _get_layer_type(
model_type, layer_type, self.retro_layer_numbers,
layer_number)
return ParallelTransformerLayer(
init_method,
output_layer_init_method,
config,
layer_number,
layer_type=layer_type,
layer_type=current_layer_type,
self_attn_mask_type=self_attn_mask_type,
drop_path_rate=self.drop_path_rates[layer_number - 1])
else:
# This argument is only available from TE v0.10 onwards.
extra_transformer_engine_kwargs = {}
if self.transformer_engine_v_0_8:
extra_transformer_engine_kwargs["bias"] = args.add_bias_linear
if self.transformer_engine_v_0_10:
extra_transformer_engine_kwargs["activation"] = "swiglu" if args.swiglu else "gelu"
if self.transformer_engine_v_0_11:
extra_transformer_engine_kwargs["normalization"] = args.normalization
return transformer_engine.pytorch.TransformerLayer(
args.hidden_size,
args.ffn_hidden_size,
args.num_attention_heads,
layernorm_epsilon=args.layernorm_epsilon,
hidden_dropout=args.hidden_dropout,
attention_dropout=args.attention_dropout,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
config.hidden_size,
config.ffn_hidden_size,
config.num_attention_heads,
layernorm_epsilon=config.layernorm_epsilon,
hidden_dropout=config.hidden_dropout,
attention_dropout=config.attention_dropout,
init_method=config.init_method,
output_layer_init_method=config.output_layer_init_method,
layer_number=layer_number,
kv_channels=args.kv_channels,
kv_channels=config.kv_channels,
self_attn_mask_type=self_attn_mask_type.name,
tp_group=mpu.get_tensor_model_parallel_group(),
get_rng_state_tracker=tensor_parallel.get_cuda_rng_tracker,
fuse_wgrad_accumulation=args.gradient_accumulation_fusion,
apply_query_key_layer_scaling=args.apply_query_key_layer_scaling,
attention_softmax_in_fp32=args.attention_softmax_in_fp32,
fuse_wgrad_accumulation=config.gradient_accumulation_fusion,
apply_query_key_layer_scaling=config.apply_query_key_layer_scaling,
attention_softmax_in_fp32=config.attention_softmax_in_fp32,
seq_length=args.seq_length,
micro_batch_size=args.micro_batch_size,
sequence_parallel=args.sequence_parallel,
params_dtype=args.params_dtype,
apply_residual_connection_post_layernorm=args.apply_residual_connection_post_layernorm,
sequence_parallel=config.sequence_parallel,
params_dtype=config.params_dtype,
apply_residual_connection_post_layernorm=config.apply_residual_connection_post_layernorm,
output_layernorm=False,
layer_type="encoder",
drop_path_rate=self.drop_path_rates[layer_number - 1],
set_parallel_mode=True,
fuse_qkv_params=True)
fuse_qkv_params=True,
**extra_transformer_engine_kwargs)
if args.virtual_pipeline_model_parallel_size is not None:
assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
if config.virtual_pipeline_model_parallel_size is not None:
assert config.num_layers % config.virtual_pipeline_model_parallel_size == 0, \
'num_layers_per_stage must be divisible by ' \
'virtual_pipeline_model_parallel_size'
assert args.model_type != ModelType.encoder_and_decoder
# Number of layers in each model chunk is the number of layers in the stage,
# divided by the number of model chunks in a stage.
self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size
self.num_layers = self.num_layers // config.virtual_pipeline_model_parallel_size
# With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
# layers to stages like (each list is a model chunk):
# Stage 0: [0] [2] [4] [6]
......@@ -1096,7 +1455,7 @@ class ParallelTransformer(MegatronModule):
# Stage 0: [0, 1] [4, 5]
# Stage 1: [2, 3] [6, 7]
offset = mpu.get_virtual_pipeline_model_parallel_rank() * (
args.num_layers // args.virtual_pipeline_model_parallel_size) + \
config.num_layers // config.virtual_pipeline_model_parallel_size) + \
(mpu.get_pipeline_model_parallel_rank() * self.num_layers)
else:
# Each stage gets a contiguous set of layers.
......@@ -1126,13 +1485,24 @@ class ParallelTransformer(MegatronModule):
self.layers = torch.nn.ModuleList(
[build_layer(i + 1 + offset) for i in range(self.num_layers)])
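# --- Illustrative sketch (not from the patch): the interleaved layer-to-stage
# assignment described in the comment above (8 layers, 2 pipeline stages, 4 model
# chunks), reproduced with plain integers and no distributed state.
def chunk_layer_offset(num_layers, pipeline_size, virtual_size, pp_rank, vpp_rank):
    layers_per_chunk = num_layers // pipeline_size // virtual_size
    return vpp_rank * (num_layers // virtual_size) + pp_rank * layers_per_chunk

num_layers_demo, pp, vpp = 8, 2, 4
assignment = {
    pp_rank: [chunk_layer_offset(num_layers_demo, pp, vpp, pp_rank, v) for v in range(vpp)]
    for pp_rank in range(pp)
}
# Matches the comment: Stage 0 -> [0] [2] [4] [6], Stage 1 -> [1] [3] [5] [7].
assert assignment == {0: [0, 2, 4, 6], 1: [1, 3, 5, 7]}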
# Update dropout rate for Retro encoder.
if model_type == ModelType.retro_encoder:
for layer in self.layers:
if layer.self_attention.use_flash_attn:
layer.self_attention.core_attention_flash.dropout_p = \
args.retro_encoder_attention_dropout  # dropout_p expects a float probability, not an nn.Dropout module
else:
layer.self_attention.core_attention.attention_dropout.p =\
args.retro_encoder_attention_dropout
layer.hidden_dropout = args.retro_encoder_hidden_dropout
if self.post_process and self.post_layer_norm:
# Final layer norm before output.
self.final_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon,
config.hidden_size,
eps=config.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm,
sequence_parallel=args.sequence_parallel,
sequence_parallel=config.sequence_parallel,
apply_layernorm_1p=args.apply_layernorm_1p)
def _get_layer(self, layer_number):
......@@ -1142,40 +1512,42 @@ class ParallelTransformer(MegatronModule):
encoder_output, enc_dec_attn_mask,
rotary_pos_emb, is_first_microbatch):
"""Forward method with activation checkpointing."""
def custom(start, end, is_transformer_engine=False):
def custom(start, end):
def custom_forward(*args, **kwargs):
x_, *args = args
for index in range(start, end):
layer = self._get_layer(index)
x_ = layer(x_, *args, **kwargs)
return x_
def custom_forward_transformer_engine(*args, **kwargs):
return custom_forward(*args, is_first_microbatch=is_first_microbatch, **kwargs)
if not is_transformer_engine:
return custom_forward
else:
return custom_forward_transformer_engine
return custom_forward
te_forward_kwargs = {}
if self.transformer_impl == 'transformer_engine':
te_forward_kwargs['is_first_microbatch'] = is_first_microbatch
if self.transformer_engine_v_0_10:
te_forward_kwargs['rotary_pos_emb'] = rotary_pos_emb
if self.recompute_method == 'uniform':
# Uniformly divide the total number of Transformer layers and checkpoint
# the input activation of each divided chunk.
# Uniformly divide the total number of Transformer layers and
# checkpoint the input activation of each divided chunk.
# A method to further reduce memory usage by reducing the number of checkpoints.
l = 0
while l < self.num_layers:
if self.transformer_impl == 'transformer_engine':
hidden_states = transformer_engine.pytorch.distributed.checkpoint(
custom(l, l + self.recompute_num_layers, is_transformer_engine=True),
hidden_states = transformer_engine.pytorch.checkpoint(
custom(l, l + self.recompute_num_layers),
self.distribute_saved_activations,
tensor_parallel.get_cuda_rng_tracker,
mpu.get_tensor_model_parallel_group(),
hidden_states, attention_mask, encoder_output,
enc_dec_attn_mask, rotary_pos_emb)
enc_dec_attn_mask, **te_forward_kwargs)
else:
hidden_states = tensor_parallel.checkpoint(
custom(l, l + self.recompute_num_layers),
self.distribute_saved_activations,
hidden_states, attention_mask, encoder_output,
enc_dec_attn_mask, rotary_pos_emb)
hidden_states, attention_mask,
encoder_output, enc_dec_attn_mask,
None, None, None, None, rotary_pos_emb)
l += self.recompute_num_layers
......@@ -1186,28 +1558,30 @@ class ParallelTransformer(MegatronModule):
for l in range(self.num_layers):
if l < self.recompute_num_layers:
if self.transformer_impl == 'transformer_engine':
hidden_states = transformer_engine.pytorch.distributed.checkpoint(
custom(l, l + 1, is_transformer_engine=True),
hidden_states = transformer_engine.pytorch.checkpoint(
custom(l, l + 1),
self.distribute_saved_activations,
tensor_parallel.get_cuda_rng_tracker,
mpu.get_tensor_model_parallel_group(),
hidden_states, attention_mask, encoder_output,
enc_dec_attn_mask, rotary_pos_emb)
enc_dec_attn_mask, **te_forward_kwargs)
else:
hidden_states = tensor_parallel.checkpoint(
custom(l, l + 1),
self.distribute_saved_activations,
hidden_states, attention_mask, encoder_output,
enc_dec_attn_mask, rotary_pos_emb)
hidden_states, attention_mask,
encoder_output, enc_dec_attn_mask,
None, None, None, None, rotary_pos_emb)
else:
if self.transformer_impl == 'transformer_engine':
hidden_states = custom(l, l + 1, is_transformer_engine=True)(
hidden_states = custom(l, l + 1)(
hidden_states, attention_mask, encoder_output,
enc_dec_attn_mask, rotary_pos_emb)
enc_dec_attn_mask, **te_forward_kwargs)
else:
hidden_states = custom(l, l + 1)(
hidden_states, attention_mask, encoder_output,
enc_dec_attn_mask, rotary_pos_emb)
hidden_states, attention_mask,
encoder_output, enc_dec_attn_mask,
None, None, None, None, rotary_pos_emb)
else:
raise ValueError("Invalid activation recompute method.")
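# --- Illustrative sketch (not from the patch): the two recompute schemes dispatched
# above, expressed with torch.utils.checkpoint over a plain ModuleList. 'uniform'
# checkpoints the input of every chunk of layers; 'block' checkpoints only the first
# `recompute_num_layers` layers of the stage and runs the rest normally. Toy model.
import torch
from torch.utils.checkpoint import checkpoint

layers = torch.nn.ModuleList([torch.nn.Linear(16, 16) for _ in range(6)])

def run_chunk(start, end, x):
    for i in range(start, end):
        x = layers[i](x)
    return x

def forward_with_recompute(x, method, recompute_num_layers=2):
    if method == 'uniform':
        l = 0
        while l < len(layers):
            x = checkpoint(run_chunk, l, l + recompute_num_layers, x, use_reentrant=False)
            l += recompute_num_layers
    elif method == 'block':
        for l in range(len(layers)):
            if l < recompute_num_layers:
                x = checkpoint(run_chunk, l, l + 1, x, use_reentrant=False)
            else:
                x = run_chunk(l, l + 1, x)
    else:
        raise ValueError("Invalid activation recompute method.")
    return x

x = torch.randn(4, 16, requires_grad=True)
forward_with_recompute(x, 'uniform').sum().backward()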
......@@ -1225,7 +1599,11 @@ class ParallelTransformer(MegatronModule):
def forward(self, hidden_states, attention_mask,
encoder_output=None, enc_dec_attn_mask=None,
inference_params=None, rotary_pos_emb=None):
retriever_input=None,
retriever_output=None,
retriever_attn_mask=None,
inference_params=None,
rotary_pos_emb=None):
# hidden_states: [s, b, h]
# Checks.
......@@ -1258,11 +1636,13 @@ class ParallelTransformer(MegatronModule):
keep_graph=True,
)
# RNG context.
if self.sequence_parallel:
rng_context = tensor_parallel.get_cuda_rng_tracker().fork()
else:
rng_context = nullcontext()
# Forward layers.
with rng_context:
# The fp8_autocast context manager is a no-op when enabled=False
# The if...else serves to short circuit name resolution for fp8_autocast
......@@ -1290,12 +1670,18 @@ class ParallelTransformer(MegatronModule):
'encoder_output': encoder_output,
'enc_dec_attn_mask': enc_dec_attn_mask,
'inference_params': inference_params,
'rotary_pos_emb': rotary_pos_emb,
}
if self.transformer_impl == 'transformer_engine':
forward_kwargs['is_first_microbatch'] = is_first_microbatch
forward_kwargs['checkpoint_core_attention'] = self.checkpoint_core_attention
if self.transformer_engine_v_0_10:
forward_kwargs['rotary_pos_emb'] = rotary_pos_emb
else:
forward_kwargs['rotary_pos_emb'] = rotary_pos_emb
forward_kwargs['retriever_input'] = retriever_input
forward_kwargs['retriever_output'] = retriever_output
forward_kwargs['retriever_attn_mask'] = retriever_attn_mask
for index in range(self.num_layers):
layer = self._get_layer(index)
......@@ -1305,6 +1691,14 @@ class ParallelTransformer(MegatronModule):
attention_mask,
**forward_kwargs)
# First Retro decoder layer returns both hidden_states
# and retriever_output. Make retriever_output available
# to subsequent Retro layers.
if isinstance(hidden_states, tuple):
assert len(hidden_states) == 2
hidden_states, retriever_output = hidden_states
forward_kwargs["retriever_output"] = retriever_output
# Skip counter update for eval and activation checkpointing
if torch.is_grad_enabled() and self.training:
self.microbatch_count += 1
......
......@@ -13,7 +13,7 @@ from megatron.model.module import MegatronModule
class VitClassificationModel(MegatronModule):
"""Vision Transformer Model."""
def __init__(self, num_classes, finetune=False,
def __init__(self, config, num_classes, finetune=False,
pre_process=True, post_process=True):
super(VitClassificationModel, self).__init__()
args = get_args()
......@@ -24,6 +24,7 @@ class VitClassificationModel(MegatronModule):
self.pre_process = pre_process
self.post_process = post_process
self.backbone = VitBackbone(
config=config,
pre_process=self.pre_process,
post_process=self.post_process,
single_token_output=True
......
......@@ -173,11 +173,12 @@ def cosine_scheduler(base_value, final_value, epochs, niter_per_ep,
return schedule
def get_student_backbone_and_num_features(pre_process=True, post_process=True):
def get_student_backbone_and_num_features(config, pre_process=True, post_process=True):
args = get_args()
if args.vision_backbone_type == 'vit':
student = VitBackbone(pre_process=pre_process,
student = VitBackbone(config,
pre_process=pre_process,
post_process=post_process,
drop_path_rate=0.1,
single_token_output=True)
......@@ -194,11 +195,12 @@ def get_student_backbone_and_num_features(pre_process=True, post_process=True):
return student, num_features
def get_teacher_backbone_and_num_features(pre_process=True, post_process=True):
def get_teacher_backbone_and_num_features(config, pre_process=True, post_process=True):
args = get_args()
if args.vision_backbone_type == 'vit':
teacher = VitBackbone(pre_process=pre_process,
teacher = VitBackbone(config,
pre_process=pre_process,
post_process=post_process,
single_token_output=True)
num_features = args.hidden_size
......@@ -215,7 +217,7 @@ def get_teacher_backbone_and_num_features(pre_process=True, post_process=True):
class DINOPretrainModel(MegatronModule):
def __init__(self, pre_process=True, post_process=True):
def __init__(self, config, pre_process=True, post_process=True):
super(DINOPretrainModel, self).__init__()
args = get_args()
self.out_dim = 65536
......@@ -234,7 +236,7 @@ class DINOPretrainModel(MegatronModule):
self.momentum_teacher = 0.996
student_backbone, num_features = \
get_student_backbone_and_num_features(pre_process, post_process)
get_student_backbone_and_num_features(config, pre_process, post_process)
self.student = MultiCropWrapper(
student_backbone,
......@@ -249,7 +251,7 @@ class DINOPretrainModel(MegatronModule):
)
teacher_backbone, num_features = \
get_teacher_backbone_and_num_features(pre_process, post_process)
get_teacher_backbone_and_num_features(config, pre_process, post_process)
self.teacher = MultiCropWrapper(
teacher_backbone,
DINOHead(num_features, self.out_dim)
......
......@@ -18,14 +18,15 @@ from megatron.model.vision.utils import resize_
class VitInpaintingModel(MegatronModule):
def __init__(self, pre_process=True, post_process=True):
def __init__(self, config, pre_process=True, post_process=True):
super(VitInpaintingModel, self).__init__()
args = get_args()
self.pre_process = pre_process
self.post_process = post_process
self.hidden_size = args.hidden_size
self.hidden_size = config.hidden_size
self.backbone = VitBackbone(
config=config,
pre_process=self.pre_process,
post_process=self.post_process,
class_token=False,
......
# ---------------------------------------------------------------
# Copyright (c) 2021, NVIDIA Corporation. All rights reserved.
#
# This work is licensed under the NVIDIA Source Code License
# found in the LICENSE file in the root directory of this
# source tree.
# ---------------------------------------------------------------
# Copyright (c) 2023, NVIDIA Corporation. All rights reserved.
import math
import torch
import torch.nn as nn
......
......@@ -130,24 +130,17 @@ class VitBackbone(MegatronModule):
"""Vision Transformer Model."""
def __init__(self,
config,
pre_process=True,
post_process=True,
class_token=True,
single_token_output=False,
post_layer_norm=True,
drop_path_rate=0.0):
super(VitBackbone, self).__init__(share_word_embeddings=False)
super(VitBackbone, self).__init__(share_embeddings_and_output_weights=False)
args = get_args()
self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
if args.init_method_xavier_uniform:
self.init_method = torch.nn.init.xavier_uniform_
self.scaled_init_method = torch.nn.init.xavier_uniform_
else:
self.init_method = init_method_normal(args.init_method_std)
self.scaled_init_method = scaled_init_method_normal(
args.init_method_std, args.num_layers
)
self.pre_process = pre_process
self.post_process = post_process
......@@ -202,8 +195,7 @@ class VitBackbone(MegatronModule):
# Transformer
self.transformer = ParallelTransformer(
self.init_method,
self.scaled_init_method,
config,
pre_process=self.pre_process,
post_process=self.post_process,
post_layer_norm=self.post_layer_norm,
......