Commit 051f58f1 authored by liangjing

v1

parent 0024a5c6
@@ -80,7 +80,7 @@ def _set_signal_handler():
def set_global_variables(args): def set_global_variables(args, build_tokenizer=True):
"""Set args, tokenizer, tensorboard-writer, adlr-autoresume, and timers.""" """Set args, tokenizer, tensorboard-writer, adlr-autoresume, and timers."""
assert args is not None assert args is not None
@@ -89,7 +89,7 @@ def set_global_variables(args):
set_args(args) set_args(args)
_build_num_microbatches_calculator(args) _build_num_microbatches_calculator(args)
if args.vocab_file or args.tokenizer_model: if build_tokenizer:
_ = _build_tokenizer(args) _ = _build_tokenizer(args)
_set_tensorboard_writer(args) _set_tensorboard_writer(args)
_set_adlr_autoresume(args) _set_adlr_autoresume(args)
...
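# A minimal usage sketch (illustration only, not taken from this commit): the new
# build_tokenizer flag lets a caller that only needs args/timers skip tokenizer
# construction. `args` is assumed to be an already-parsed argument namespace.
from megatron.global_vars import set_global_variables

set_global_variables(args, build_tokenizer=False)  # skip _build_tokenizer(args)
# set_global_variables(args)                       # default: the tokenizer is still built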
@@ -15,36 +15,40 @@ from megatron import get_adlr_autoresume
from megatron import get_args from megatron import get_args
from megatron import get_tensorboard_writer from megatron import get_tensorboard_writer
from megatron.core import mpu, tensor_parallel from megatron.core import mpu, tensor_parallel
from megatron.arguments import (parse_args, validate_args) from megatron.arguments import parse_args, validate_args
from megatron.checkpointing import load_args_from_checkpoint from megatron.checkpointing import load_args_from_checkpoint
from megatron.global_vars import set_global_variables from megatron.global_vars import set_global_variables
from megatron.model.transformer import bias_dropout_add_fused_train from megatron.model.transformer import bias_dropout_add_fused_train
from megatron.model.fused_bias_gelu import bias_gelu from megatron.model.fused_bias_gelu import bias_gelu
-def initialize_megatron(extra_args_provider=None, args_defaults={},
-                        ignore_unknown_args=False, allow_no_cuda=False):
+def initialize_megatron(
+    extra_args_provider=None,
+    args_defaults={},
+    ignore_unknown_args=False,
+    allow_no_cuda=False,
+):
"""Set global variables, initialize distributed, and """Set global variables, initialize distributed, and
set autoresume and random seeds. set autoresume and random seeds.
`allow_no_cuda` should not be set unless using megatron for cpu only `allow_no_cuda` should not be set unless using megatron for cpu only
data processing. In general this arg should not be set unless you know data processing. In general this arg should not be set unless you know
what you are doing. what you are doing.
Returns a function to finalize distributed env initialization Returns a function to finalize distributed env initialization
(optionally, only when args.lazy_mpu_init == True) (optionally, only when args.lazy_mpu_init == True)
""" """
if not allow_no_cuda: if not allow_no_cuda:
# Make sure cuda is available. # Make sure cuda is available.
assert torch.cuda.is_available(), 'Megatron requires CUDA.' assert torch.cuda.is_available(), "Megatron requires CUDA."
# Parse arguments # Parse arguments
args = parse_args(extra_args_provider, ignore_unknown_args) args = parse_args(extra_args_provider, ignore_unknown_args)
if args.use_checkpoint_args or args_defaults.get('use_checkpoint_args', False): if args.use_checkpoint_args or args_defaults.get("use_checkpoint_args", False):
assert args.load is not None, '--use-checkpoint-args requires --load argument' assert args.load is not None, "--use-checkpoint-args requires --load argument"
load_args_from_checkpoint(args) load_args_from_checkpoint(args)
validate_args(args, args_defaults) validate_args(args, args_defaults)
# set global args, build tokenizer, and set adlr-autoresume, # set global args, build tokenizer, and set adlr-autoresume,
# tensorboard-writer, and timers. # tensorboard-writer, and timers.
set_global_variables(args) set_global_variables(args)
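# A minimal call sketch for the entry point above (illustration only; it assumes a
# working Megatron-LM environment and the usual command-line arguments). Setting
# 'use_checkpoint_args' to True in args_defaults, together with --load, would make
# the code above restore arguments from the checkpoint instead.
from megatron.initialize import initialize_megatron

initialize_megatron(
    extra_args_provider=None,
    args_defaults={'use_checkpoint_args': False},
    ignore_unknown_args=False,
)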
@@ -54,16 +58,16 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
args = get_args() args = get_args()
# Pytorch distributed. # Pytorch distributed.
_initialize_distributed() _initialize_distributed()
# Random seeds for reproducibility. # Random seeds for reproducibility.
if args.rank == 0: if args.rank == 0:
print('> setting random seeds to {} ...'.format(args.seed)) print("> setting random seeds to {} ...".format(args.seed))
_set_random_seed(args.seed, args.data_parallel_random_init) _set_random_seed(args.seed, args.data_parallel_random_init)
args = get_args() args = get_args()
if args.lazy_mpu_init: if args.lazy_mpu_init:
# TODO is this still a necessary option? # TODO is this still a necessary option?
args.use_cpu_initialization=True args.use_cpu_initialization = True
# delayed initialization of DDP-related stuff # delayed initialization of DDP-related stuff
# We only set basic DDP globals # We only set basic DDP globals
mpu.set_tensor_model_parallel_world_size(args.tensor_model_parallel_size) mpu.set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
@@ -95,11 +99,15 @@ def _compile_dependencies():
# TODO: move this to ninja # TODO: move this to ninja
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
start_time = time.time() start_time = time.time()
print('> compiling dataset index builder ...') print("> compiling dataset index builder ...")
from megatron.data.dataset_utils import compile_helper from megatron.data.dataset_utils import compile_helper
compile_helper() compile_helper()
print('>>> done with dataset index builder. Compilation time: {:.3f} ' print(
'seconds'.format(time.time() - start_time), flush=True) ">>> done with dataset index builder. Compilation time: {:.3f} "
"seconds".format(time.time() - start_time),
flush=True,
)
# ================== # ==================
# Load fused kernels # Load fused kernels
@@ -107,41 +115,51 @@ def _compile_dependencies():
# Custom kernel constraints check. # Custom kernel constraints check.
seq_len = args.seq_length seq_len = args.seq_length
-    attn_batch_size = \
-        (args.num_attention_heads / args.tensor_model_parallel_size) * \
-        args.micro_batch_size
+    attn_batch_size = (
+        args.num_attention_heads / args.tensor_model_parallel_size
+    ) * args.micro_batch_size
     # Constraints on sequence length and attn_batch_size to enable warp based
     # optimization and upper triangular optimization (for causal mask)
-    custom_kernel_constraint = seq_len > 16 and seq_len <=4096 and \
-        seq_len % 4 == 0 and attn_batch_size % 4 == 0
+    custom_kernel_constraint = (
+        seq_len > 16
+        and seq_len <= 16384
+        and seq_len % 4 == 0
+        and attn_batch_size % 4 == 0
+    )
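# Worked example of the constraint above (illustrative numbers only): with
# seq_length=2048, num_attention_heads=32, tensor_model_parallel_size=4 and
# micro_batch_size=2, attn_batch_size = (32 / 4) * 2 = 16; 16 < 2048 <= 16384 and
# both 2048 % 4 == 0 and 16 % 4 == 0 hold, so the fused masked-softmax kernel can
# be used (given fp16/bf16 and masked_softmax_fusion); otherwise the unfused
# kernel is taken, as the warning below explains.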
# Print a warning. # Print a warning.
if not ((args.fp16 or args.bf16) and if not (
custom_kernel_constraint and (args.fp16 or args.bf16)
args.masked_softmax_fusion): and custom_kernel_constraint
and args.masked_softmax_fusion
):
if args.rank == 0: if args.rank == 0:
print('WARNING: constraints for invoking optimized' print(
' fused softmax kernel are not met. We default' "WARNING: constraints for invoking optimized"
' back to unfused kernel invocations.', flush=True) " fused softmax kernel are not met. We default"
" back to unfused kernel invocations.",
flush=True,
)
# Always build on rank zero first. # Always build on rank zero first.
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
start_time = time.time() start_time = time.time()
print('> compiling and loading fused kernels ...', flush=True) print("> compiling and loading fused kernels ...", flush=True)
fused_kernels.load(args) #fused_kernels.load(args)
torch.distributed.barrier() torch.distributed.barrier()
else: else:
torch.distributed.barrier() torch.distributed.barrier()
fused_kernels.load(args) #fused_kernels.load(args)
# Simple barrier to make sure all ranks have passed the # Simple barrier to make sure all ranks have passed the
# compilation phase successfully before moving on to the # compilation phase successfully before moving on to the
# rest of the program. We think this might ensure that # rest of the program. We think this might ensure that
# the lock is released. # the lock is released.
torch.distributed.barrier() torch.distributed.barrier()
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
print('>>> done with compiling and loading fused kernels. ' print(
'Compilation time: {:.3f} seconds'.format( ">>> done with compiling and loading fused kernels. "
time.time() - start_time), flush=True) "Compilation time: {:.3f} seconds".format(time.time() - start_time),
flush=True,
)
def _initialize_distributed(): def _initialize_distributed():
@@ -152,45 +170,58 @@ def _initialize_distributed():
if torch.distributed.is_initialized(): if torch.distributed.is_initialized():
if args.rank == 0: if args.rank == 0:
print('torch distributed is already initialized, ' print(
'skipping initialization ...', flush=True) "torch distributed is already initialized, "
args.rank = torch.distributed.get_rank() "skipping initialization ...",
args.world_size = torch.distributed.get_world_size() flush=True,
)
#args.rank = torch.distributed.get_rank()
#args.world_size = torch.distributed.get_world_size()
else: else:
if args.rank == 0: if args.rank == 0:
print('> initializing torch distributed ...', flush=True) print("> initializing torch distributed ...", flush=True)
# Manually set the device ids. # Manually set the device ids.
if device_count > 0: if device_count > 0:
device = args.rank % device_count device = args.rank % device_count
if args.local_rank is not None: if args.local_rank is not None:
assert args.local_rank == device, \ assert (
'expected local-rank to be the same as rank % device-count.' args.local_rank == device
), "expected local-rank to be the same as rank % device-count."
else: else:
args.local_rank = device args.local_rank = device
torch.cuda.set_device(device) torch.cuda.set_device(device)
# Call the init process # Call the init process
         torch.distributed.init_process_group(
             backend=args.distributed_backend,
-            world_size=args.world_size, rank=args.rank,
-            timeout=timedelta(minutes=args.distributed_timeout_minutes))
+            world_size=args.world_size,
+            rank=args.rank, init_method=args.dist_url,
+            timeout=timedelta(minutes=args.distributed_timeout_minutes),
+        )
# Set the tensor model-parallel, pipeline model-parallel, and # Set the tensor model-parallel, pipeline model-parallel, and
# data-parallel communicators. # data-parallel communicators.
if device_count > 0: if device_count > 0:
if mpu.model_parallel_is_initialized(): if mpu.model_parallel_is_initialized():
print('model parallel is already initialized') print("model parallel is already initialized")
else: else:
-            mpu.initialize_model_parallel(args.tensor_model_parallel_size,
-                                          args.pipeline_model_parallel_size,
-                                          args.virtual_pipeline_model_parallel_size,
-                                          args.pipeline_model_parallel_split_rank)
+            mpu.initialize_model_parallel(
+                args.tensor_model_parallel_size,
+                args.pipeline_model_parallel_size,
+                args.virtual_pipeline_model_parallel_size,
+                args.pipeline_model_parallel_split_rank,
+                args.fp8 is not None,
+            )
if args.rank == 0: if args.rank == 0:
print(f'> initialized tensor model parallel with size ' print(
f'{mpu.get_tensor_model_parallel_world_size()}') f"> initialized tensor model parallel with size "
print(f'> initialized pipeline model parallel with size ' f"{mpu.get_tensor_model_parallel_world_size()}"
f'{mpu.get_pipeline_model_parallel_world_size()}') )
print(
f"> initialized pipeline model parallel with size "
f"{mpu.get_pipeline_model_parallel_world_size()}"
)
def _init_autoresume(): def _init_autoresume():
@@ -216,7 +247,7 @@ def _set_random_seed(seed_, data_parallel_random_init=False):
if torch.cuda.device_count() > 0: if torch.cuda.device_count() > 0:
tensor_parallel.model_parallel_cuda_manual_seed(seed) tensor_parallel.model_parallel_cuda_manual_seed(seed)
else: else:
raise ValueError('Seed ({}) should be a positive integer.'.format(seed)) raise ValueError("Seed ({}) should be a positive integer.".format(seed))
def write_args_to_tensorboard(): def write_args_to_tensorboard():
@@ -225,15 +256,14 @@ def write_args_to_tensorboard():
writer = get_tensorboard_writer() writer = get_tensorboard_writer()
if writer: if writer:
for arg in vars(args): for arg in vars(args):
writer.add_text(arg, str(getattr(args, arg)), writer.add_text(arg, str(getattr(args, arg)), global_step=args.iteration)
global_step=args.iteration)
def set_jit_fusion_options(): def set_jit_fusion_options():
"""Set PyTorch JIT layer fusion options.""" """Set PyTorch JIT layer fusion options."""
# flags required to enable jit fusion kernels # flags required to enable jit fusion kernels
TORCH_MAJOR = int(torch.__version__.split('.')[0]) TORCH_MAJOR = int(torch.__version__.split(".")[0])
TORCH_MINOR = int(torch.__version__.split('.')[1]) TORCH_MINOR = int(torch.__version__.split(".")[1])
if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10): if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10):
# nvfuser # nvfuser
torch._C._jit_set_profiling_executor(True) torch._C._jit_set_profiling_executor(True)
@@ -241,7 +271,7 @@ def set_jit_fusion_options():
torch._C._jit_override_can_fuse_on_cpu(False) torch._C._jit_override_can_fuse_on_cpu(False)
torch._C._jit_override_can_fuse_on_gpu(False) torch._C._jit_override_can_fuse_on_gpu(False)
torch._C._jit_set_texpr_fuser_enabled(False) torch._C._jit_set_texpr_fuser_enabled(False)
torch._C._jit_set_nvfuser_enabled(True) torch._C._jit_set_nvfuser_enabled(False)
torch._C._debug_set_autodiff_subgraph_inlining(False) torch._C._debug_set_autodiff_subgraph_inlining(False)
else: else:
# legacy pytorch fuser # legacy pytorch fuser
@@ -254,7 +284,7 @@ def set_jit_fusion_options():
def _warmup_jit_function(): def _warmup_jit_function():
""" Compile JIT functions before the main training steps """ """Compile JIT functions before the main training steps"""
args = get_args() args = get_args()
if args.bf16: if args.bf16:
dtype = torch.bfloat16 dtype = torch.bfloat16
@@ -264,11 +294,20 @@ def _warmup_jit_function():
dtype = torch.float32 dtype = torch.float32
# Warmup fused bias+gelu # Warmup fused bias+gelu
bias = torch.rand(args.ffn_hidden_size // args.tensor_model_parallel_size, bias = torch.rand(
dtype=dtype, device='cuda') args.ffn_hidden_size // args.tensor_model_parallel_size,
input = torch.rand((args.seq_length, args.micro_batch_size, dtype=dtype,
args.ffn_hidden_size // args.tensor_model_parallel_size), device="cuda",
dtype=dtype, device='cuda') )
input = torch.rand(
(
args.seq_length,
args.micro_batch_size,
args.ffn_hidden_size // args.tensor_model_parallel_size,
),
dtype=dtype,
device="cuda",
)
# Warmup JIT fusions with the input grad_enable state of both forward # Warmup JIT fusions with the input grad_enable state of both forward
# prop and recomputation # prop and recomputation
for bias_grad, input_grad in zip([True, True], [False, True]): for bias_grad, input_grad in zip([True, True], [False, True]):
@@ -282,15 +321,25 @@ def _warmup_jit_function():
seq_length = args.seq_length // mpu.get_tensor_model_parallel_world_size() seq_length = args.seq_length // mpu.get_tensor_model_parallel_world_size()
else: else:
seq_length = args.seq_length seq_length = args.seq_length
input = torch.rand((seq_length, args.micro_batch_size, args.hidden_size), input = torch.rand(
dtype=dtype, device='cuda') (seq_length, args.micro_batch_size, args.hidden_size),
residual = torch.rand((seq_length, args.micro_batch_size, args.hidden_size), dtype=dtype,
dtype=dtype, device='cuda') device="cuda",
bias = torch.rand((args.hidden_size), dtype=dtype, device='cuda').expand_as(residual) )
residual = torch.rand(
(seq_length, args.micro_batch_size, args.hidden_size),
dtype=dtype,
device="cuda",
)
bias = torch.rand((args.hidden_size), dtype=dtype, device="cuda").expand_as(
residual
)
dropout_rate = 0.1 dropout_rate = 0.1
# Warmup JIT fusions with the input grad_enable state of both forward # Warmup JIT fusions with the input grad_enable state of both forward
# prop and recomputation # prop and recomputation
for input_grad, bias_grad, residual_grad in zip([False, True], [True, True], [True, True]): for input_grad, bias_grad, residual_grad in zip(
[False, True], [True, True], [True, True]
):
input.requires_grad = input_grad input.requires_grad = input_grad
bias.requires_grad = bias_grad bias.requires_grad = bias_grad
residual.requires_grad = residual_grad residual.requires_grad = residual_grad
...
@@ -47,31 +47,27 @@ class BertLMHead(MegatronModule):
"""Masked LM head for Bert """Masked LM head for Bert
Arguments: Arguments:
config: TransformerConfig object
mpu_vocab_size: model parallel size of vocabulary. mpu_vocab_size: model parallel size of vocabulary.
hidden_size: hidden size hidden_size: hidden size
init_method: init method for weight initialization
layernorm_epsilon: tolerance for layer norm divisions
parallel_output: whether output logits being distributed or not. parallel_output: whether output logits being distributed or not.
""" """
-    def __init__(self, mpu_vocab_size, hidden_size, init_method,
-                 layernorm_epsilon, parallel_output):
-        super(BertLMHead, self).__init__()
+    def __init__(self, mpu_vocab_size, hidden_size, config, parallel_output):
+        super().__init__(config=config)
args = get_args() args = get_args()
self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size)) self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
tensor_parallel.set_tensor_model_parallel_attributes(self.bias, True, 0, 1) tensor_parallel.set_tensor_model_parallel_attributes(self.bias, True, 0, 1)
self.parallel_output = parallel_output self.parallel_output = parallel_output
self.dense = get_linear_layer(hidden_size, hidden_size, init_method) self.dense = get_linear_layer(hidden_size, hidden_size, config.init_method)
setattr(self.dense.weight, 'sequence_parallel', args.sequence_parallel) setattr(self.dense.weight, 'sequence_parallel', config.sequence_parallel)
setattr(self.dense.bias, 'sequence_parallel', args.sequence_parallel) setattr(self.dense.bias, 'sequence_parallel', config.sequence_parallel)
self.layernorm = LayerNorm(hidden_size, self.layernorm = LayerNorm(hidden_size,
eps=layernorm_epsilon, eps=config.layernorm_epsilon,
sequence_parallel=args.sequence_parallel) sequence_parallel=config.sequence_parallel)
self.gelu = torch.nn.functional.gelu self.gelu = torch.nn.functional.gelu
if args.openai_gelu: if args.openai_gelu:
self.gelu = openai_gelu self.gelu = openai_gelu
@@ -124,12 +120,13 @@ class BertModel(MegatronModule):
"""Bert Language model.""" """Bert Language model."""
def __init__(self, def __init__(self,
config,
num_tokentypes=2, num_tokentypes=2,
add_binary_head=True, add_binary_head=True,
parallel_output=True, parallel_output=True,
pre_process=True, pre_process=True,
post_process=True): post_process=True):
super(BertModel, self).__init__() super().__init__(config=config)
args = get_args() args = get_args()
# TODO this option is not yet implemented in BERT # TODO this option is not yet implemented in BERT
@@ -145,29 +142,23 @@ class BertModel(MegatronModule):
if self.return_embeddings: if self.return_embeddings:
assert self.post_process and self.add_binary_head assert self.post_process and self.add_binary_head
init_method = init_method_normal(args.init_method_std)
scaled_init_method = scaled_init_method_normal(args.init_method_std,
args.num_layers)
self.language_model, self._language_model_key = get_language_model( self.language_model, self._language_model_key = get_language_model(
config=config,
num_tokentypes=num_tokentypes, num_tokentypes=num_tokentypes,
add_pooler=self.add_binary_head, add_pooler=self.add_binary_head,
encoder_attn_mask_type=AttnMaskType.padding, encoder_attn_mask_type=AttnMaskType.padding,
init_method=init_method,
scaled_init_method=scaled_init_method,
pre_process=self.pre_process, pre_process=self.pre_process,
post_process=self.post_process) post_process=self.post_process)
self.initialize_word_embeddings(init_method_normal) self.initialize_word_embeddings()
if self.post_process: if self.post_process:
-            self.lm_head = BertLMHead(
-                self.word_embeddings_weight().size(0),
-                args.hidden_size, init_method, args.layernorm_epsilon, parallel_output)
+            self.lm_head = BertLMHead(self.shared_embedding_or_output_weight().size(0), config.hidden_size,
+                                      config, parallel_output)
self._lm_head_key = 'lm_head' self._lm_head_key = 'lm_head'
self.binary_head = None self.binary_head = None
if self.add_binary_head: if self.add_binary_head:
self.binary_head = get_linear_layer(args.hidden_size, 2, self.binary_head = get_linear_layer(config.hidden_size, 2,
init_method) config.init_method)
self._binary_head_key = 'binary_head' self._binary_head_key = 'binary_head'
def set_input_tensor(self, input_tensor): def set_input_tensor(self, input_tensor):
@@ -215,7 +206,7 @@ class BertModel(MegatronModule):
return post_language_model_processing(lm_output, pooled_output, return post_language_model_processing(lm_output, pooled_output,
self.lm_head, self.binary_head, self.lm_head, self.binary_head,
lm_labels, lm_labels,
self.word_embeddings_weight(), self.shared_embedding_or_output_weight(),
self.fp16_lm_cross_entropy) self.fp16_lm_cross_entropy)
else: else:
return lm_output return lm_output
...
@@ -17,25 +17,23 @@ from .module import MegatronModule
class Classification(MegatronModule): class Classification(MegatronModule):
def __init__(self, def __init__(self,
config,
num_classes, num_classes,
num_tokentypes=2, num_tokentypes=2,
pre_process=True, pre_process=True,
post_process=True): post_process=True):
super(Classification, self).__init__(share_word_embeddings=False) super().__init__(config=config, share_embeddings_and_output_weights=False)
args = get_args() args = get_args()
self.num_classes = num_classes self.num_classes = num_classes
self.pre_process = pre_process self.pre_process = pre_process
self.post_process = post_process self.post_process = post_process
init_method = init_method_normal(args.init_method_std)
self.language_model, self._language_model_key = get_language_model( self.language_model, self._language_model_key = get_language_model(
config=config,
num_tokentypes=num_tokentypes, num_tokentypes=num_tokentypes,
add_pooler=True, add_pooler=True,
encoder_attn_mask_type=AttnMaskType.padding, encoder_attn_mask_type=AttnMaskType.padding,
init_method=init_method,
scaled_init_method=scaled_init_method_normal(args.init_method_std,
args.num_layers),
pre_process=self.pre_process, pre_process=self.pre_process,
post_process=self.post_process) post_process=self.post_process)
...
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from abc import ABC from abc import ABC
from abc import abstractmethod from abc import abstractmethod
@@ -73,7 +73,7 @@ class DistributedDataParallelBase(MegatronModule, ABC):
class DistributedDataParallel(DistributedDataParallelBase): class DistributedDataParallel(DistributedDataParallelBase):
"""DDP with contiguous buffers options to storre and accumulate gradients. """DDP with contiguous buffers options to store and accumulate gradients.
This class: This class:
- has the potential to reduce memory fragmentation. - has the potential to reduce memory fragmentation.
- provides the option to do the gradient accumulation - provides the option to do the gradient accumulation
...
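# Conceptual sketch of the contiguous-gradient-buffer idea mentioned in the
# docstring above (a toy illustration under simplified assumptions, not this
# class's actual implementation): gradients of all parameters are views into one
# flat buffer, so accumulation and reduction touch a single contiguous allocation.
import torch

params = [torch.nn.Parameter(torch.randn(4, 4)), torch.nn.Parameter(torch.randn(8))]
flat_buffer = torch.zeros(sum(p.numel() for p in params))

offset = 0
for p in params:
    # each parameter's .main_grad is a view into the shared flat buffer
    p.main_grad = flat_buffer[offset:offset + p.numel()].view_as(p)
    offset += p.numel()

# accumulating into p.main_grad writes directly into flat_buffer
params[0].main_grad += torch.ones(4, 4)
assert flat_buffer[:16].sum().item() == 16.0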
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import enum import enum
class LayerType(enum.Enum): class LayerType(enum.Enum):
encoder = 1 encoder = 1
decoder = 2 decoder = 2
retro_encoder = 3
retro_decoder = 4
retro_decoder_with_retriever = 5
class AttnType(enum.Enum): class AttnType(enum.Enum):
self_attn = 1 self_attn = 1
...
@@ -155,12 +155,12 @@ class FusedScaleMaskSoftmax(nn.Module):
if ( if (
self.scaled_masked_softmax_fusion # user want to fuse self.scaled_masked_softmax_fusion # user want to fuse
and self.input_in_float16 # input must be fp16 and self.input_in_float16 # input must be fp16
and 16 < sk <= 4096 # sk must be 16 ~ 2048 and 16 < sk <= 16384 # sk must be 16 ~ 16384
and sq % 4 == 0 # sq must be divisor of 4 and sq % 4 == 0 # sq must be divisor of 4
and sk % 4 == 0 # sk must be divisor of 4 and sk % 4 == 0 # sk must be divisor of 4
and attn_batches % 4 == 0 # np * b must be divisor of 4 and attn_batches % 4 == 0 # np * b must be divisor of 4
): ):
if 0 <= sk <= 4096: if 0 <= sk <= 16384:
batch_per_block = self.get_batch_per_block(sq, sk, b, np) batch_per_block = self.get_batch_per_block(sq, sk, b, np)
if self.attn_mask_type == AttnMaskType.causal: if self.attn_mask_type == AttnMaskType.causal:
...
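# Worked example of the fused-kernel availability check above (illustrative
# numbers only): with sq = sk = 2048, b = 4 and np = 8, attn_batches = np * b = 32;
# if the input is fp16/bf16 and fusion is enabled, then 16 < 2048 <= 16384,
# 2048 % 4 == 0 and 32 % 4 == 0 all hold, so the fused scaled-masked-softmax
# kernel can be selected.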
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
"""GPT-2 model.""" """GPT-2 model."""
@@ -11,8 +11,6 @@ from .module import MegatronModule
from .enums import AttnMaskType from .enums import AttnMaskType
from .language_model import parallel_lm_logits from .language_model import parallel_lm_logits
from .language_model import get_language_model from .language_model import get_language_model
from .utils import init_method_normal
from .utils import scaled_init_method_normal
def post_language_model_processing(lm_output, labels, logit_weights, def post_language_model_processing(lm_output, labels, logit_weights,
@@ -46,12 +44,13 @@ class GPTModel(MegatronModule):
"""GPT-2 Language model.""" """GPT-2 Language model."""
def __init__(self, def __init__(self,
config,
num_tokentypes=0, num_tokentypes=0,
parallel_output=True, parallel_output=True,
pre_process=True, pre_process=True,
post_process=True): post_process=True):
args = get_args() args = get_args()
super(GPTModel, self).__init__(share_word_embeddings=not args.untie_embeddings_and_output_weights) super().__init__(config=config, share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights)
self.parallel_output = parallel_output self.parallel_output = parallel_output
self.pre_process = pre_process self.pre_process = pre_process
@@ -60,39 +59,39 @@ class GPTModel(MegatronModule):
self.untie_embeddings_and_output_weights = args.untie_embeddings_and_output_weights self.untie_embeddings_and_output_weights = args.untie_embeddings_and_output_weights
self.language_model, self._language_model_key = get_language_model( self.language_model, self._language_model_key = get_language_model(
config=config,
num_tokentypes=num_tokentypes, num_tokentypes=num_tokentypes,
add_pooler=False, add_pooler=False,
encoder_attn_mask_type=AttnMaskType.causal, encoder_attn_mask_type=AttnMaskType.causal,
init_method=init_method_normal(args.init_method_std),
scaled_init_method=scaled_init_method_normal(args.init_method_std,
args.num_layers),
pre_process=self.pre_process, pre_process=self.pre_process,
post_process=self.post_process) post_process=self.post_process)
if not args.untie_embeddings_and_output_weights: if not args.untie_embeddings_and_output_weights:
self.initialize_word_embeddings(init_method_normal) self.initialize_word_embeddings()
def set_input_tensor(self, input_tensor): def set_input_tensor(self, input_tensor):
"""See megatron.model.transformer.set_input_tensor()""" """See megatron.model.transformer.set_input_tensor()"""
self.language_model.set_input_tensor(input_tensor) self.language_model.set_input_tensor(input_tensor)
def forward(self, input_ids, position_ids, attention_mask, def forward(self, input_ids, position_ids, attention_mask,
-                ret_input_ids=None, ret_position_ids=None, ret_attn_mask=None,
+                retriever_input_ids=None,
+                retriever_position_ids=None,
+                retriever_attn_mask=None,
labels=None, tokentype_ids=None, inference_params=None): labels=None, tokentype_ids=None, inference_params=None):
lm_output = self.language_model( lm_output = self.language_model(
input_ids, input_ids,
position_ids, position_ids,
attention_mask, attention_mask,
ret_input_ids=ret_input_ids, retriever_input_ids=retriever_input_ids,
ret_position_ids=ret_position_ids, retriever_position_ids=retriever_position_ids,
ret_attn_mask=ret_attn_mask, retriever_attn_mask=retriever_attn_mask,
inference_params=inference_params) inference_params=inference_params)
if self.post_process: if self.post_process:
return post_language_model_processing( return post_language_model_processing(
lm_output, labels, lm_output, labels,
self.language_model.output_layer.weight if self.untie_embeddings_and_output_weights else self.word_embeddings_weight(), self.language_model.output_layer.weight if self.untie_embeddings_and_output_weights else self.shared_embedding_or_output_weight(),
self.parallel_output, self.parallel_output,
self.fp16_lm_cross_entropy) self.fp16_lm_cross_entropy)
else: else:
...
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
"""Transformer based language model.""" """Transformer based language model."""
@@ -7,11 +7,11 @@ import torch.nn.functional as F
from megatron import get_args from megatron import get_args
from megatron.core import mpu, tensor_parallel from megatron.core import mpu, tensor_parallel
from megatron.core.enums import ModelType
from megatron.core.models.common.rotary_pos_embedding import RotaryEmbedding
from .enums import LayerType, AttnMaskType from .enums import AttnMaskType, LayerType
from .module import MegatronModule from .module import MegatronModule
from .retro_transformer import ParallelRetroEncoder, ParallelRetroTransformer
from .rotary_pos_embedding import apply_rotary_pos_emb, RotaryEmbedding
from .transformer import ParallelTransformer from .transformer import ParallelTransformer
from .utils import get_linear_layer from .utils import get_linear_layer
from .utils import init_method_normal, scaled_init_method_normal from .utils import init_method_normal, scaled_init_method_normal
@@ -39,7 +39,7 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
bias=bias, bias=bias,
gradient_accumulation_fusion=args.gradient_accumulation_fusion, gradient_accumulation_fusion=args.gradient_accumulation_fusion,
async_grad_allreduce=async_grad_allreduce, async_grad_allreduce=async_grad_allreduce,
sequence_parallel_enabled=args.sequence_parallel) sequence_parallel=args.sequence_parallel)
# Gather if needed. # Gather if needed.
if parallel_output: if parallel_output:
@@ -48,26 +48,24 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
return tensor_parallel.gather_from_tensor_model_parallel_region(logits_parallel) return tensor_parallel.gather_from_tensor_model_parallel_region(logits_parallel)
-def get_language_model(num_tokentypes, add_pooler,
-                       encoder_attn_mask_type, init_method=None,
-                       scaled_init_method=None, add_encoder=True,
+def get_language_model(config, num_tokentypes, add_pooler,
+                       encoder_attn_mask_type,
+                       add_encoder=True,
                        add_decoder=False,
                        decoder_attn_mask_type=AttnMaskType.causal,
                        pre_process=True, post_process=True):
     """Build language model and return along with the key to save."""
     args = get_args()
-    if init_method is None:
-        init_method = init_method_normal(args.init_method_std)
-    if scaled_init_method is None:
-        scaled_init_method = scaled_init_method_normal(args.init_method_std,
-                                                       args.num_layers)
+    if config.init_method is None:
+        config.init_method = init_method_normal(config.init_method_std)
+    if config.output_layer_init_method is None:
+        config.output_layer_init_method = scaled_init_method_normal(config.init_method_std,
+                                                                    config.num_layers)
     # Language model.
     language_model = TransformerLanguageModel(
-        init_method,
-        scaled_init_method,
+        config,
         encoder_attn_mask_type,
num_tokentypes=num_tokentypes, num_tokentypes=num_tokentypes,
add_encoder=add_encoder, add_encoder=add_encoder,
@@ -131,6 +129,10 @@ class Embedding(MegatronModule):
init_method: weight initialization method init_method: weight initialization method
num_tokentypes: size of the token-type embeddings. 0 value num_tokentypes: size of the token-type embeddings. 0 value
will ignore this embedding will ignore this embedding
embedding_weights_in_fp32: casts word embedding weights to
fp32 before sampling. Required to
maintain reproducibility when
training in bf16.
""" """
def __init__(self, def __init__(self,
@@ -138,28 +140,26 @@ class Embedding(MegatronModule):
vocab_size, vocab_size,
max_sequence_length, max_sequence_length,
embedding_dropout_prob, embedding_dropout_prob,
-                 init_method,
-                 num_tokentypes=0):
+                 config,
+                 num_tokentypes=0,
+                 embedding_weights_in_fp32=False):
super(Embedding, self).__init__() super(Embedding, self).__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.init_method = init_method self.init_method = config.init_method
self.num_tokentypes = num_tokentypes self.num_tokentypes = num_tokentypes
args = get_args() args = get_args()
# Word embeddings (parallel). # Word embeddings (parallel).
self.embedding_weights_in_fp32 = embedding_weights_in_fp32
self.params_dtype = args.params_dtype
self.word_embeddings = tensor_parallel.VocabParallelEmbedding( self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
-            vocab_size, self.hidden_size,
-            init_method=self.init_method,
-            params_dtype=args.params_dtype,
-            use_cpu_initialization=args.use_cpu_initialization,
-            perform_initialization=args.perform_initialization
-        )
+            vocab_size, self.hidden_size, config=config, init_method=config.init_method)
self._word_embeddings_key = 'word_embeddings' self._word_embeddings_key = 'word_embeddings'
# Position embedding (serial). # Position embedding (serial).
self.add_position_embedding = args.add_position_embedding self.add_position_embedding = args.position_embedding_type == 'learned_absolute'
if self.add_position_embedding: if self.add_position_embedding:
self.position_embeddings = torch.nn.Embedding( self.position_embeddings = torch.nn.Embedding(
max_sequence_length, self.hidden_size) max_sequence_length, self.hidden_size)
@@ -182,7 +182,7 @@ class Embedding(MegatronModule):
else: else:
self.tokentype_embeddings = None self.tokentype_embeddings = None
self.fp32_residual_connection = args.fp32_residual_connection self.fp32_residual_connection = args.fp32_residual_connection
self.sequence_parallel = args.sequence_parallel self.sequence_parallel = args.sequence_parallel
# Embeddings dropout # Embeddings dropout
self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob) self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
@@ -217,7 +217,12 @@ class Embedding(MegatronModule):
def forward(self, input_ids, position_ids, tokentype_ids=None): def forward(self, input_ids, position_ids, tokentype_ids=None):
# Embeddings. # Embeddings.
if self.embedding_weights_in_fp32:
self.word_embeddings = self.word_embeddings.to(torch.float32)
words_embeddings = self.word_embeddings(input_ids) words_embeddings = self.word_embeddings(input_ids)
if self.embedding_weights_in_fp32:
words_embeddings = words_embeddings.to(self.params_dtype)
self.word_embeddings = self.word_embeddings.to(self.params_dtype)
if self.add_position_embedding: if self.add_position_embedding:
position_embeddings = self.position_embeddings(position_ids) position_embeddings = self.position_embeddings(position_ids)
embeddings = words_embeddings + position_embeddings embeddings = words_embeddings + position_embeddings
@@ -326,8 +331,7 @@ class TransformerLanguageModel(MegatronModule):
""" """
def __init__(self, def __init__(self,
init_method, config,
output_layer_init_method,
encoder_attn_mask_type, encoder_attn_mask_type,
num_tokentypes=0, num_tokentypes=0,
add_encoder=True, add_encoder=True,
@@ -337,21 +341,22 @@ class TransformerLanguageModel(MegatronModule):
pre_process=True, pre_process=True,
post_process=True): post_process=True):
args = get_args() args = get_args()
# TODO: passing share_word_embeddings=False will not work correctly for T5 and embeddings will not be synced. Fix later for T5. # TODO: passing share_embeddings_and_output_weights=False will not work correctly for T5 and embeddings will not be synced. Fix later for T5.
if args.untie_embeddings_and_output_weights: assert not add_decoder if args.untie_embeddings_and_output_weights: assert not add_decoder
super(TransformerLanguageModel, self).__init__(share_word_embeddings=not args.untie_embeddings_and_output_weights) super(TransformerLanguageModel, self).__init__(share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights)
self.pre_process = pre_process self.pre_process = pre_process
self.post_process = post_process self.post_process = post_process
self.hidden_size = args.hidden_size self.hidden_size = config.hidden_size
self.num_tokentypes = num_tokentypes self.num_tokentypes = num_tokentypes
self.init_method = init_method self.init_method = config.init_method
self.add_encoder = add_encoder self.add_encoder = add_encoder
self.encoder_attn_mask_type = encoder_attn_mask_type self.encoder_attn_mask_type = encoder_attn_mask_type
self.add_decoder = add_decoder self.add_decoder = add_decoder
self.decoder_attn_mask_type = decoder_attn_mask_type self.decoder_attn_mask_type = decoder_attn_mask_type
self.add_pooler = add_pooler self.add_pooler = add_pooler
self.encoder_hidden_state = None self.encoder_hidden_state = None
self.add_retriever = args.retro_add_retriever
self.untie_embeddings_and_output_weights = args.untie_embeddings_and_output_weights self.untie_embeddings_and_output_weights = args.untie_embeddings_and_output_weights
# Embeddings. # Embeddings.
@@ -360,14 +365,15 @@ class TransformerLanguageModel(MegatronModule):
args.padded_vocab_size, args.padded_vocab_size,
args.max_position_embeddings, args.max_position_embeddings,
args.hidden_dropout, args.hidden_dropout,
-                                       self.init_method,
-                                       self.num_tokentypes)
+                                       config,
+                                       self.num_tokentypes,
+                                       args.embedding_weights_in_fp32)
self._embedding_key = 'embedding' self._embedding_key = 'embedding'
# Rotary positional embeddings # Rotary positional embeddings
self.use_rotary_position_embeddings = \ self.use_rotary_position_embeddings = \
args.use_rotary_position_embeddings args.position_embedding_type == 'rope'
if args.use_rotary_position_embeddings: if self.use_rotary_position_embeddings:
self.seq_length = args.seq_length self.seq_length = args.seq_length
rotary_dim = args.hidden_size // args.num_attention_heads \ rotary_dim = args.hidden_size // args.num_attention_heads \
if args.kv_channels is None else args.kv_channels if args.kv_channels is None else args.kv_channels
@@ -378,41 +384,22 @@ class TransformerLanguageModel(MegatronModule):
# partial rotary embeddings, which is better than full rotary # partial rotary embeddings, which is better than full rotary
# Wang and Komatsuzaki et al # Wang and Komatsuzaki et al
# https://github.com/kingoflolz/mesh-transformer-jax/ # https://github.com/kingoflolz/mesh-transformer-jax/
-            self.rotary_pos_emb = RotaryEmbedding(rotary_dim)
-        # Retriever (bi-directional transformer with cross attention)
-        if args.retro_add_retriever:
-            self.retriever = ParallelRetroEncoder(
-                self.init_method,
-                output_layer_init_method,
-                self_attn_mask_type=AttnMaskType.padding,
-                pre_process=self.pre_process,
-                post_process=False,
-            )
-            self._retriever_key = 'retriever'
-        else:
-            self.retriever = None
+            self.rotary_pos_emb = RotaryEmbedding(
+                rotary_dim,
+                seq_len_interpolation_factor=args.rotary_seq_len_interpolation_factor
+            )
# Encoder (usually set to True, False if part of an encoder-decoder # Encoder (usually set to True, False if part of an encoder-decoder
# architecture and in encoder-only stage). # architecture and in encoder-only stage).
if self.add_encoder: if self.add_encoder:
-            if args.retro_add_retriever:
-                self.encoder = ParallelRetroTransformer(
-                    self.init_method,
-                    output_layer_init_method,
-                    self_attn_mask_type=self.encoder_attn_mask_type,
-                    pre_process=self.pre_process,
-                    post_process=self.post_process,
-                    retriever=self.retriever,
-                )
-            else:
-                self.encoder = ParallelTransformer(
-                    self.init_method,
-                    output_layer_init_method,
-                    self_attn_mask_type=self.encoder_attn_mask_type,
-                    pre_process=self.pre_process,
-                    post_process=self.post_process,
-                )
+            self.encoder = ParallelTransformer(
+                config,
+                model_type=args.model_type if not args.retro_add_retriever \
+                    else ModelType.retro_decoder,
+                self_attn_mask_type=self.encoder_attn_mask_type,
+                pre_process=self.pre_process,
+                post_process=self.post_process,
+            )
self._encoder_key = 'encoder' self._encoder_key = 'encoder'
else: else:
self.encoder = None self.encoder = None
@@ -421,8 +408,8 @@ class TransformerLanguageModel(MegatronModule):
# architecture and in decoder-only stage). # architecture and in decoder-only stage).
if self.add_decoder: if self.add_decoder:
self.decoder = ParallelTransformer( self.decoder = ParallelTransformer(
self.init_method, config,
output_layer_init_method, model_type=args.model_type,
layer_type=LayerType.decoder, layer_type=LayerType.decoder,
self_attn_mask_type=self.decoder_attn_mask_type, self_attn_mask_type=self.decoder_attn_mask_type,
pre_process=self.pre_process, pre_process=self.pre_process,
@@ -441,8 +428,9 @@ class TransformerLanguageModel(MegatronModule):
self.output_layer = tensor_parallel.ColumnParallelLinear( self.output_layer = tensor_parallel.ColumnParallelLinear(
args.hidden_size, args.hidden_size,
args.padded_vocab_size, args.padded_vocab_size,
-            bias=False, # Setting bias to False always to keep it consistent with embedding tying that also does not have a bias.
-            init_method=self.init_method)
+            config=config,
+            init_method=self.init_method,
+            bias=False) # Setting bias to False always to keep it consistent with embedding tying that also does not have a bias.
self._output_layer_key = 'output_layer' self._output_layer_key = 'output_layer'
def set_input_tensor(self, input_tensor): def set_input_tensor(self, input_tensor):
@@ -475,19 +463,14 @@ class TransformerLanguageModel(MegatronModule):
def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask, def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None, dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
-                ret_input_ids=None, ret_position_ids=None, ret_attn_mask=None,
+                retriever_input_ids=None,
+                retriever_position_ids=None,
+                retriever_attn_mask=None,
enc_dec_attn_mask=None, tokentype_ids=None, enc_dec_attn_mask=None, tokentype_ids=None,
inference_params=None, inference_params=None,
pooling_sequence_index=0, pooling_sequence_index=0,
enc_hidden_states=None, output_enc_hidden=False): enc_hidden_states=None, output_enc_hidden=False):
# Retriever embedding.
if self.retriever and self.pre_process:
retriever_input = self.embedding(ret_input_ids, ret_position_ids,
tokentype_ids=tokentype_ids)
else:
retriever_input = None
# Encoder embedding. # Encoder embedding.
if self.pre_process: if self.pre_process:
encoder_input = self.embedding(enc_input_ids, enc_position_ids, encoder_input = self.embedding(enc_input_ids, enc_position_ids,
@@ -495,31 +478,33 @@ class TransformerLanguageModel(MegatronModule):
else: else:
encoder_input = None encoder_input = None
# Retriever embedding.
if self.add_retriever and self.pre_process:
retriever_input = self.embedding(retriever_input_ids,
retriever_position_ids,
tokentype_ids=tokentype_ids)
else:
retriever_input = None
# Rotary positional embeddings # Rotary positional embeddings
rotary_pos_emb = None rotary_pos_emb = None
if self.use_rotary_position_embeddings: if self.use_rotary_position_embeddings:
if inference_params is not None: if inference_params is not None:
rotary_pos_emb = \ rotary_pos_emb = \
self.rotary_pos_emb(inference_params.max_sequence_len) self.rotary_pos_emb(inference_params.max_sequence_length)
else: else:
rotary_pos_emb = self.rotary_pos_emb(self.seq_length) rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
# Run encoder. # Run encoder.
if enc_hidden_states is None: if enc_hidden_states is None:
if self.encoder is not None: if self.encoder is not None:
-                if self.retriever:
-                    encoder_output = self.encoder(
-                        encoder_input,
-                        enc_attn_mask,
-                        retriever_output=retriever_input,
-                        retriever_attn_mask=ret_attn_mask,
-                        inference_params=inference_params)
-                else:
-                    encoder_output = self.encoder(
-                        encoder_input,
-                        enc_attn_mask,
-                        inference_params=inference_params,
-                        rotary_pos_emb=rotary_pos_emb)
+                encoder_output = self.encoder(
+                    encoder_input,
+                    enc_attn_mask,
+                    retriever_input=retriever_input,
+                    retriever_attn_mask=retriever_attn_mask,
+                    inference_params=inference_params,
+                    rotary_pos_emb=rotary_pos_emb)
else: else:
encoder_output = self.encoder_hidden_state encoder_output = self.encoder_hidden_state
else: else:
...
@@ -25,9 +25,10 @@ class MegatronModule(torch.nn.Module):
"""Megatron specific extensions of torch Module with support """Megatron specific extensions of torch Module with support
for pipelining.""" for pipelining."""
def __init__(self, share_word_embeddings=True): def __init__(self, config=None, share_embeddings_and_output_weights=True):
super(MegatronModule, self).__init__() super(MegatronModule, self).__init__()
self.share_word_embeddings = share_word_embeddings self.config = config
self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False): def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
@@ -36,21 +37,21 @@ class MegatronModule(torch.nn.Module):
return self.state_dict(prefix=prefix, keep_vars=keep_vars) return self.state_dict(prefix=prefix, keep_vars=keep_vars)
def word_embeddings_weight(self): def shared_embedding_or_output_weight(self):
if self.pre_process: if self.pre_process:
return self.language_model.embedding.word_embeddings.weight return self.language_model.embedding.word_embeddings.weight
else: else:
if not self.share_word_embeddings: if not self.share_embeddings_and_output_weights:
raise Exception('word_embeddings_weight() called for last ' raise Exception('shared_embedding_or_output_weight() called for last '
'stage, but share_word_embeddings is false') 'stage, but share_embeddings_and_output_weights is false')
return self.word_embeddings.weight return self.word_embeddings.weight
def initialize_word_embeddings(self, init_method_normal): def initialize_word_embeddings(self):
args = get_args() args = get_args()
if not self.share_word_embeddings: if not self.share_embeddings_and_output_weights:
raise Exception('initialize_word_embeddings() was called but ' raise Exception('initialize_word_embeddings() was called but '
'share_word_embeddings is false') 'share_embeddings_and_output_weights is false')
# This function just initializes the word embeddings in the final stage # This function just initializes the word embeddings in the final stage
# when we are using pipeline parallelism. Nothing to do if we aren't # when we are using pipeline parallelism. Nothing to do if we aren't
@@ -76,11 +77,8 @@ class MegatronModule(torch.nn.Module):
# set word_embeddings weights to 0 here, then copy first # set word_embeddings weights to 0 here, then copy first
# stage's weights using all_reduce below. # stage's weights using all_reduce below.
self.word_embeddings = tensor_parallel.VocabParallelEmbedding( self.word_embeddings = tensor_parallel.VocabParallelEmbedding(
-                args.padded_vocab_size, args.hidden_size,
-                init_method=init_method_normal(args.init_method_std),
-                params_dtype=args.params_dtype,
-                use_cpu_initialization=args.use_cpu_initialization,
-                perform_initialization=args.perform_initialization)
+                args.padded_vocab_size, self.config.hidden_size,
+                config=self.config, init_method=self.config.init_method)
self.word_embeddings.weight.data.fill_(0) self.word_embeddings.weight.data.fill_(0)
self.word_embeddings.weight.shared = True self.word_embeddings.weight.shared = True
@@ -103,7 +101,7 @@ class MegatronModule(torch.nn.Module):
# Ensure that first and last stages have the same initial parameter # Ensure that first and last stages have the same initial parameter
# values. # values.
if mpu.is_rank_in_embedding_group(): if mpu.is_rank_in_embedding_group():
torch.distributed.all_reduce(self.word_embeddings_weight().data, torch.distributed.all_reduce(self.shared_embedding_or_output_weight().data,
group=mpu.get_embedding_group()) group=mpu.get_embedding_group())
# Ensure that encoder(first stage) and decoder(split stage) position # Ensure that encoder(first stage) and decoder(split stage) position
...
@@ -17,23 +17,21 @@ from .module import MegatronModule
class MultipleChoice(MegatronModule): class MultipleChoice(MegatronModule):
def __init__(self, def __init__(self,
config,
num_tokentypes=2, num_tokentypes=2,
pre_process=True, pre_process=True,
post_process=True): post_process=True):
super(MultipleChoice, self).__init__(share_word_embeddings=False) super(MultipleChoice, self).__init__(share_embeddings_and_output_weights=False)
args = get_args() args = get_args()
init_method = init_method_normal(args.init_method_std)
self.pre_process = pre_process self.pre_process = pre_process
self.post_process = post_process self.post_process = post_process
self.language_model, self._language_model_key = get_language_model( self.language_model, self._language_model_key = get_language_model(
config=config,
num_tokentypes=num_tokentypes, num_tokentypes=num_tokentypes,
add_pooler=True, add_pooler=True,
encoder_attn_mask_type=AttnMaskType.padding, encoder_attn_mask_type=AttnMaskType.padding,
init_method=init_method,
scaled_init_method=scaled_init_method_normal(args.init_method_std,
args.num_layers),
pre_process=self.pre_process, pre_process=self.pre_process,
post_process=self.post_process) post_process=self.post_process)
...
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
"""Retro Transformer.
** Special note about this file **
Many classes and methods in this file directly parallel those in transformer.py
in name and utility. However, due to 1) subtle changes in the code over time
(i.e., transposes and contexts), and 2) other code that is soon to be merged,
this file will *temporarily* remain as is, until a larger integration is
complete.
"""
import math
import numpy as np
import torch
import torch.nn.functional as F
from megatron import get_args, get_retro_args, get_tensorboard_writer
from megatron.core import parallel_state
from megatron.core import tensor_parallel
from megatron.core import utils as core_utils
from megatron.core.enums import ModelType
from megatron.model.enums import AttnMaskType, LayerType, AttnType
from megatron.model import LayerNorm
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu, init_method_normal
from .module import MegatronModule
from .transformer import _get_num_layers, ParallelMLP, NoopTransformerLayer
""" We use the following notation throughout this file:
h: hidden size
n: number of attention heads
p: number of model parallel partitions
np: n/p
hp: h/p
hn: h/n
b: batch size
s: sequence length
l: number of layers
Transformer takes input of size [s, b, h] and returns a
tensor of the same size. We use the following arguments:
hyperparameters: transformer hyperparameters
"""
class DropPath(MegatronModule):
"""Drop paths (Stochastic Depth) per sample
(when applied in main path of residual blocks).
*Note: differs from transformer.py/DropPath in hidden_state transpose.
"""
def __init__(self, drop_prob=0.):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, hidden_state):
if self.drop_prob == 0. or not self.training:
return hidden_state
keep_prob = 1 - self.drop_prob
# work with diff dim tensors, not just 2D ConvNets
shape = (hidden_state.shape[0],) + (1,) * (hidden_state.ndim - 1)
random_tensor = keep_prob + \
torch.rand(shape, dtype=hidden_state.dtype, device=hidden_state.device)
random_tensor.floor_() # binarize
output = hidden_state.div(keep_prob) * random_tensor
return output
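# Minimal usage sketch for the DropPath module above (illustration only; relies on
# the class definition and the torch import in this file): in training mode each
# sample along dim 0 is kept with probability 1 - drop_prob and rescaled by
# 1 / (1 - drop_prob); in eval mode the input is returned unchanged.
drop_path = DropPath(drop_prob=0.1)
hidden = torch.randn(16, 8, 64)      # [samples, ..., hidden] -- example shape
drop_path.train()
out_train = drop_path(hidden)        # some samples zeroed, the rest scaled by 1 / 0.9
drop_path.eval()
out_eval = drop_path(hidden)         # identity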
class SwitchMLP(MegatronModule):
"""
Routes input to one of N MLP "experts"
"""
def __init__(self, init_method, output_layer_init_method):
super(SwitchMLP, self).__init__()
args = get_args()
self.router = torch.nn.Linear(args.hidden_size, args.num_experts)
self.experts = torch.nn.ModuleList()
for i in range(args.num_experts):
self.experts.append(ParallelMLP(init_method, output_layer_init_method))
def forward(self, hidden_states):
# hidden_states: [b, s, h]
b = hidden_states.size(0)
s = hidden_states.size(1)
h = hidden_states.size(2)
route = self.router(hidden_states)
route = torch.nn.functional.softmax(route, dim=2)
max_prob, max_ind = torch.max(route, dim=2)
max_prob = torch.unsqueeze(max_prob, 2) # [b s 1]
# TODO (rprenger) TODO this could be made easier to read
# Converting [b, s, h] to [b*s, h].
# Each vector could be routed differently
hidden_states = hidden_states.view(-1, hidden_states.size(2)) # [b*s h]
max_prob = max_prob.view(-1, max_prob.size(2)) # [b*s 1]
max_ind = max_ind.view(-1) # [b*s]
output_total = torch.empty_like(hidden_states)
output_bias_total = torch.empty_like(hidden_states)
#TODO (rprenger) This does each expert in serial, but it could be parallelized
for expert_num, expert in enumerate(self.experts):
local_indices = (max_ind == expert_num).nonzero()
hidden = hidden_states[local_indices,:]
output, output_bias = expert(hidden)
output_bias = output_bias.expand_as(output)
output_total[local_indices,:] = output
output_bias_total[local_indices,:] = output_bias
output_total = output_total*max_prob
output_bias_total = output_bias_total*max_prob
output_total = output_total.view(b, s, h)
output_bias_total = output_bias_total.view(b, s, h)
return output_total, output_bias_total
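# Toy illustration of the top-1 routing math in SwitchMLP.forward above, using
# plain tensors and stub Linear "experts" instead of Megatron's ParallelMLP
# (illustration only; names and sizes are made up).
import torch

b, s, h, num_experts = 2, 4, 8, 3
hidden_states = torch.randn(b, s, h)
router = torch.nn.Linear(h, num_experts)
experts = torch.nn.ModuleList([torch.nn.Linear(h, h) for _ in range(num_experts)])

route = torch.softmax(router(hidden_states), dim=2)     # [b, s, num_experts]
max_prob, max_ind = torch.max(route, dim=2)             # top-1 expert per token
flat = hidden_states.view(-1, h)                        # [b*s, h]
output = torch.empty_like(flat)
for expert_num, expert in enumerate(experts):
    idx = (max_ind.view(-1) == expert_num).nonzero(as_tuple=True)[0]
    output[idx] = expert(flat[idx])
output = output * max_prob.view(-1, 1)                  # scale by router probability
output = output.view(b, s, h)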
class ParallelAttention(MegatronModule):
"""Parallel self-attention layer abstract class.
Self-attention layer takes input with size [b, s, h]
and returns output of the same size.
"""
def __init__(self, init_method,
output_layer_init_method, layer_number,
attention_type=AttnType.self_attn,
attn_mask_type=AttnMaskType.padding):
super(ParallelAttention, self).__init__()
args = get_args()
self.fp16 = args.fp16
self.bf16 = args.bf16
self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
if self.apply_query_key_layer_scaling:
self.attention_softmax_in_fp32 = True
self.layer_number = max(1, layer_number)
self.attention_type = attention_type
self.attn_mask_type = attn_mask_type
self.params_dtype = args.params_dtype
projection_size = args.kv_channels * args.num_attention_heads
# Per attention head and per partition values.
world_size = parallel_state.get_tensor_model_parallel_world_size()
self.hidden_size_per_partition = core_utils.divide(projection_size,
world_size)
self.hidden_size_per_attention_head = core_utils.divide(
projection_size, args.num_attention_heads)
self.num_attention_heads_per_partition = core_utils.divide(
args.num_attention_heads, world_size)
# Strided linear layer.
if attention_type == AttnType.self_attn:
self.query_key_value = tensor_parallel.ColumnParallelLinear(
args.hidden_size,
3 * projection_size,
gather_output=False,
init_method=init_method)
else:
assert attention_type == AttnType.cross_attn
self.query = tensor_parallel.ColumnParallelLinear(
args.hidden_size,
projection_size,
gather_output=False,
init_method=init_method)
self.key_value = tensor_parallel.ColumnParallelLinear(
args.hidden_size,
2 * projection_size,
gather_output=False,
init_method=init_method)
coeff = None
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
if self.apply_query_key_layer_scaling:
coeff = self.layer_number
self.norm_factor *= coeff
self.scale_mask_softmax = FusedScaleMaskSoftmax(
self.fp16, self.bf16,
self.attn_mask_type,
args.masked_softmax_fusion,
attention_mask_func,
self.attention_softmax_in_fp32,
coeff)
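# Scaling note (descriptive comment): when apply_query_key_layer_scaling is
# enabled, the raw QK^T scores are computed with an extra 1/layer_number
# factor folded into norm_factor above, and FusedScaleMaskSoftmax multiplies
# the scores back by coeff = layer_number, in fp32, before the softmax. The
# math is unchanged; the smaller intermediates simply reduce the risk of
# fp16 overflow in deeper layers.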
# Dropout. Note that for a single iteration, this layer will generate
# different outputs on different number of parallel partitions but
# on average it should not be partition dependent.
self.attention_dropout = torch.nn.Dropout(args.attention_dropout)
# Output.
self.dense = tensor_parallel.RowParallelLinear(
projection_size,
args.hidden_size,
input_is_parallel=True,
init_method=output_layer_init_method,
skip_bias_add=True)
def _allocate_memory(self, inference_max_sequence_len, batch_size):
return torch.empty(
inference_max_sequence_len,
batch_size,
self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head,
dtype=self.params_dtype,
device=torch.cuda.current_device())
def forward(self, hidden_states, attention_mask,
encoder_output=None, inference_params=None):
# hidden_states: [sq, b, h]
# =================================================
# Pre-allocate memory for key-values for inference.
# =================================================
if inference_params:
if self.layer_number not in inference_params.key_value_memory_dict:
inf_max_seq_len = inference_params.max_sequence_len
inf_max_batch_size = inference_params.max_batch_size
inference_key_memory = self._allocate_memory(
inf_max_seq_len, inf_max_batch_size)
inference_value_memory = self._allocate_memory(
inf_max_seq_len, inf_max_batch_size)
inference_params.key_value_memory_dict[self.layer_number] = (
inference_key_memory, inference_value_memory)
else:
inference_key_memory, inference_value_memory = \
inference_params.key_value_memory_dict[self.layer_number]
# =====================
# Query, Key, and Value
# =====================
if self.attention_type == AttnType.self_attn:
# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
mixed_x_layer, _ = self.query_key_value(hidden_states)
# [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
new_tensor_shape = mixed_x_layer.size()[:-1] + \
(self.num_attention_heads_per_partition,
3 * self.hidden_size_per_attention_head)
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
(query_layer,
key_layer,
value_layer) = tensor_parallel \
.split_tensor_along_last_dim(mixed_x_layer, 3)
else:
# Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
mixed_kv_layer, _ = self.key_value(encoder_output)
# [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
new_tensor_shape = mixed_kv_layer.size()[:-1] + \
(self.num_attention_heads_per_partition,
2 * self.hidden_size_per_attention_head)
mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)
# [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
(key_layer,
value_layer) = tensor_parallel \
.split_tensor_along_last_dim(mixed_kv_layer, 2)
# Attention head [sq, b, h] --> [sq, b, hp]
query_layer, _ = self.query(hidden_states)
# [sq, b, hp] --> [sq, b, np, hn]
new_tensor_shape = query_layer.size()[:-1] + \
(self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head)
query_layer = query_layer.view(*new_tensor_shape)
# ==================================
# Adjust key and value for inference
# ==================================
if inference_params:
batch_start = inference_params.batch_size_offset
batch_end = batch_start + key_layer.size(1)
assert batch_end <= inference_key_memory.size(1)
sequence_start = inference_params.sequence_len_offset
sequence_end = sequence_start + key_layer.size(0)
assert sequence_end <= inference_key_memory.size(0)
# Copy key and values.
inference_key_memory[sequence_start:sequence_end,
batch_start:batch_end, ...] = key_layer
inference_value_memory[sequence_start:sequence_end,
batch_start:batch_end, ...] = value_layer
key_layer = inference_key_memory[
:sequence_end, batch_start:batch_end, ...]
value_layer = inference_value_memory[
:sequence_end, batch_start:batch_end, ...]
# ===================================
# Raw attention scores. [b, np, s, s]
# ===================================
# [b, np, sq, sk]
output_size = (query_layer.size(1),
query_layer.size(2),
query_layer.size(0),
key_layer.size(0))
# [sq, b, np, hn] -> [sq, b * np, hn]
query_layer = query_layer.view(output_size[2],
output_size[0] * output_size[1], -1)
# [sk, b, np, hn] -> [sk, b * np, hn]
key_layer = key_layer.view(output_size[3],
output_size[0] * output_size[1], -1)
# preallocating result tensor: [b * np, sq, sk]
matmul_result = torch.empty(
output_size[0]*output_size[1],
output_size[2],
output_size[3],
dtype=query_layer.dtype,
device=torch.cuda.current_device())
# Raw attention scores. [b * np, sq, sk]
matmul_result = torch.baddbmm(
matmul_result,
query_layer.transpose(0, 1), # [b * np, sq, hn]
key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
beta=0.0, alpha=(1.0/self.norm_factor))
# change view to [b, np, sq, sk]
attention_scores = matmul_result.view(*output_size)
# ===========================
# Attention probs and dropout
# ===========================
# attention scores and attention mask [b, np, sq, sk]
attention_probs = self.scale_mask_softmax(attention_scores,
attention_mask)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
with tensor_parallel.get_cuda_rng_tracker().fork():
attention_probs = self.attention_dropout(attention_probs)
# =========================
# Context layer. [sq, b, hp]
# =========================
# value_layer -> context layer.
# [sk, b, np, hn] --> [b, np, sq, hn]
# context layer shape: [b, np, sq, hn]
output_size = (value_layer.size(1),
value_layer.size(2),
query_layer.size(0),
value_layer.size(3))
# change view [sk, b * np, hn]
value_layer = value_layer.view(value_layer.size(0),
output_size[0] * output_size[1], -1)
# change view [b * np, sq, sk]
attention_probs = attention_probs.view(output_size[0] * output_size[1],
output_size[2], -1)
# matmul: [b * np, sq, hn]
context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
# change view [b, np, sq, hn]
context_layer = context_layer.view(*output_size)
# [b, np, sq, hn] --> [sq, b, np, hn]
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
# [sq, b, np, hn] --> [sq, b, hp]
new_context_layer_shape = context_layer.size()[:-2] + \
(self.hidden_size_per_partition,)
context_layer = context_layer.view(*new_context_layer_shape)
# =================
# Output. [sq, b, h]
# =================
output, bias = self.dense(context_layer)
return output, bias
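# Shape walk-through of ParallelAttention.forward (descriptive comment):
# [sq, b, h] -> QKV projection -> [sq, b, np, hn] each for Q/K/V
# -> raw scores via baddbmm -> [b, np, sq, sk]
# -> masked softmax and dropout -> [b, np, sq, sk]
# -> bmm with values -> context [sq, b, hp]
# -> RowParallelLinear -> output [sq, b, h] plus an un-added bias
# (skip_bias_add=True leaves the bias to the caller's fused dropout-add).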
def bias_dropout_add(x, bias, residual, prob, training):
# type: (Tensor, Tensor, Tensor, float, bool) -> Tensor
out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
out = residual + out
return out
def get_bias_dropout_add(training):
def _bias_dropout_add(x, bias, residual, prob):
return bias_dropout_add(x, bias, residual, prob, training)
return _bias_dropout_add
@torch.jit.script
def bias_dropout_add_fused_train(x: torch.Tensor,
bias: torch.Tensor,
residual: torch.Tensor,
prob: float) -> torch.Tensor:
return bias_dropout_add(x, bias, residual, prob, True)
@torch.jit.script
def bias_dropout_add_fused_inference(x: torch.Tensor,
bias: torch.Tensor,
residual: torch.Tensor,
prob: float) -> torch.Tensor:
return bias_dropout_add(x, bias, residual, prob, False)
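# Note on the two jit-scripted wrappers above (descriptive comment): scripting
# the functional form with `training` baked in as a constant lets the fuser
# emit a single fused bias-add + dropout + residual-add kernel per mode,
# which the nn.Module dropout path does not reliably trigger (see the comment
# inside the layer forward methods below).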
class ParallelRetroTransformerEncoderLayer(MegatronModule):
"""A single transformer layer for Retro Decoder with an retriever encoder inside and cross attention.
Transformer layer takes input with size [b, s, h] and returns an
output of the same size.
"""
def __init__(self, init_method, output_layer_init_method,
layer_number, layer_type=LayerType.encoder,
self_attn_mask_type=AttnMaskType.padding,
drop_path_rate=0., retriever=None):
args = get_args()
super(ParallelRetroTransformerEncoderLayer, self).__init__()
self.layer_number = layer_number
self.layer_type = layer_type
self.apply_residual_connection_post_layernorm \
= args.apply_residual_connection_post_layernorm
self.bf16 = args.bf16
self.fp32_residual_connection = args.fp32_residual_connection
# Retro Encoder
self.retriever = retriever
# Layernorm on the input data.
self.input_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm)
# Self attention.
self.self_attention = ParallelAttention(
init_method,
output_layer_init_method,
layer_number,
attention_type=AttnType.self_attn,
attn_mask_type=self_attn_mask_type)
self.hidden_dropout = args.hidden_dropout
self.bias_dropout_fusion = args.bias_dropout_fusion
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 \
else None
# Layernorm on the attention output
self.post_attention_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm)
self.inter_attention = ParallelAttention(
init_method,
output_layer_init_method,
layer_number,
attention_type=AttnType.cross_attn)
# Layernorm on the attention output.
self.post_inter_attention_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm)
# MLP
if args.num_experts is not None:
self.mlp = SwitchMLP(init_method, output_layer_init_method)
else:
self.mlp = ParallelMLP(init_method, output_layer_init_method)
def forward(self, hidden_states, attention_mask,
retriever_output, retriever_attn_mask,
encoder_output=None, enc_dec_attn_mask=None,
inference_params=None):
# hidden_states: [s, b, h]
# Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states)
# Self attention.
attention_output, attention_bias = \
self.self_attention(
layernorm_output,
attention_mask,
inference_params=inference_params)
# Residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = hidden_states
if self.drop_path is None:
# jit scripting for a nn.module (with dropout) does not
# trigger the fusion kernel. For now, we use two
# different nn.functional routines to account for varying
# dropout semantics during training and inference phases.
if self.bias_dropout_fusion:
if self.training:
bias_dropout_add_func = bias_dropout_add_fused_train
else:
bias_dropout_add_func = bias_dropout_add_fused_inference
else:
bias_dropout_add_func = get_bias_dropout_add(self.training)
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
layernorm_input = bias_dropout_add_func(
attention_output,
attention_bias.expand_as(residual),
residual,
self.hidden_dropout)
else:
out = torch.nn.functional.dropout(attention_output + attention_bias,
p=self.hidden_dropout,
training=self.training)
layernorm_input = residual + self.drop_path(out)
# Layer norm post the self attention. Output: [ns, bs, d]
layernorm_output = self.post_attention_layernorm(layernorm_input)
"""
notations:
l: number of chunks
m: number of token per chunk
bs: batch size
d: hidden size
k: number of neighbors
r: number of tokens per neighbors (neighbors + continuation)
"""
args = get_args()
retro_args = get_retro_args()
chunk_length = retro_args.retro_gpt_chunk_length
retrieved_length = retro_args.retro_gpt_retrieved_length
num_neighbors = args.retro_num_neighbors
ns, bs, d = layernorm_output.shape
l = int(np.ceil(ns / chunk_length))
first_ns = ns % chunk_length
if first_ns > 0:
first_chunk, rest_chunk = \
layernorm_output[:first_ns], layernorm_output[first_ns:]
first_chunk = torch.nn.functional.pad(
first_chunk,
(0, 0, 0, 0, 0, chunk_length - first_ns),
'constant',
0)
chunked_output = \
torch.cat((first_chunk, rest_chunk), dim=0) # [l * m, bs, d]
else:
chunked_output = layernorm_output # [l * m, bs, d]
chunked_output = chunked_output \
.reshape(l, chunk_length, bs, d) \
.permute(1, 2, 0, 3) \
.reshape(chunk_length, bs * l, d) \
.contiguous()
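# Chunking sketch for the block above (illustrative numbers, not taken from a
# real config): with ns = 2000 tokens and chunk_length m = 64, l = ceil(2000/64)
# = 32 chunks and first_ns = 2000 % 64 = 16, so the leading 16 tokens are padded
# out to a full chunk and concatenated with the remaining 1984 tokens, giving
# l * m = 2048 rows; the reshape/permute then packs them as [m, bs * l, d],
# one column per (batch, chunk) pair, ready to be fed to the retriever.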
# Get Encoder Output
retriever_output = self.retriever(
retriever_output,
retriever_attn_mask,
retriever_output=chunked_output,
retriever_attn_mask=retriever_attn_mask,
inference_params=inference_params) # [r, k * bs * l , d]
retriever_output = retriever_output.reshape(
retrieved_length * num_neighbors, bs * l, d) # [r * k, bs * l, d]
# Chunked Cross attention with Retriever Encoder
pad = (ns - 1) % chunk_length
attending_chunks = layernorm_output[pad:] # [ns - m + 1, bs, d]
padded_chunks = torch.nn.functional.pad(
attending_chunks,
(0, 0, 0, 0, 0, chunk_length-1),
'constant', 0) # [ns, bs, d]
padded_chunked_output = padded_chunks \
.reshape(l, chunk_length, bs, d) \
.permute(1, 2, 0, 3)
padded_chunked_output = padded_chunked_output.reshape(
chunk_length, bs * l, d).contiguous() # [m, bs * l, d]
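# Offset sketch for the chunked cross attention above (illustrative numbers):
# with ns = 2048 and m = 64, pad = (ns - 1) % m = 63, so the first 63 tokens
# are dropped from the query stream and 63 zero rows are appended at the end,
# again giving l * m = 2048 rows reshaped to [m, bs * l, d]. The queries are
# thereby shifted by m - 1, so each token only cross-attends to neighbors
# retrieved for chunks it has already seen, keeping the decoder causal (the
# chunked cross-attention scheme of RETRO).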
# attention_output: [m, bs * l, d]
# attention_bias: [d]
attention_output, attention_bias = \
self.inter_attention(
padded_chunked_output, # Q: main model embedding
None,
encoder_output=retriever_output) # KV: retriever output embedding
# Residual connection
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = layernorm_input
# Re-enable torch grad to enable fused optimization.
with torch.enable_grad():
layernorm_input = bias_dropout_add_func(
attention_output,
attention_bias.expand_as(attention_output),
torch.zeros_like(attention_output),
self.hidden_dropout)
layernorm_input = layernorm_input \
.reshape(chunk_length, bs, l, d) \
.permute(2, 0, 1, 3) # [l, m, bs, d]
layernorm_input = layernorm_input.reshape(chunk_length * l, bs, d)
layernorm_input = torch.nn.functional.pad(
layernorm_input,
(0, 0, 0, 0, pad, 0),
'constant', 0)[:ns] # [ns, b, d]
layernorm_input = layernorm_input + residual
# Layer norm post the decoder attention
layernorm_output = self.post_inter_attention_layernorm(layernorm_input)
# MLP.
mlp_output, mlp_bias = self.mlp(layernorm_output)
# Second residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = layernorm_input
if self.drop_path is None:
# Re-enable torch grad to enable fused optimization.
with torch.enable_grad():
output = bias_dropout_add_func(
mlp_output,
mlp_bias.expand_as(residual),
residual,
self.hidden_dropout)
else:
out = torch.nn.functional.dropout(mlp_output + mlp_bias,
p=self.hidden_dropout,
training=self.training)
output = residual + self.drop_path(out)
return output, retriever_output
class ParallelRetroTransformerLayer(MegatronModule):
"""A single transformer layer for Retro Decoder with cross attention.
Transformer layer takes input with size [b, s, h] and returns an
output of the same size.
"""
def __init__(self, init_method, output_layer_init_method,
layer_number, layer_type=LayerType.encoder,
self_attn_mask_type=AttnMaskType.padding,
drop_path_rate=0.):
args = get_args()
super(ParallelRetroTransformerLayer, self).__init__()
self.layer_number = layer_number
self.layer_type = layer_type
self.apply_residual_connection_post_layernorm \
= args.apply_residual_connection_post_layernorm
self.bf16 = args.bf16
self.fp32_residual_connection = args.fp32_residual_connection
# Layernorm on the input data.
self.input_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm)
# Self attention.
self.self_attention = ParallelAttention(
init_method,
output_layer_init_method,
layer_number,
attention_type=AttnType.self_attn,
attn_mask_type=self_attn_mask_type)
self.hidden_dropout = args.hidden_dropout
self.bias_dropout_fusion = args.bias_dropout_fusion
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 \
else None
# Layernorm on the attention output
self.post_attention_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm)
self.inter_attention = ParallelAttention(
init_method,
output_layer_init_method,
layer_number,
attention_type=AttnType.cross_attn)
# Layernorm on the attention output.
self.post_inter_attention_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm)
# MLP
if args.num_experts is not None:
self.mlp = SwitchMLP(init_method, output_layer_init_method)
else:
self.mlp = ParallelMLP(init_method, output_layer_init_method)
def forward(self, hidden_states, attention_mask,
retriever_output, retriever_attn_mask,
encoder_output=None, enc_dec_attn_mask=None,
inference_params=None):
# hidden_states: [s, b, h]
# Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states)
# Self attention.
attention_output, attention_bias = \
self.self_attention(
layernorm_output,
attention_mask,
inference_params=inference_params)
# Residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = hidden_states
if self.drop_path is None:
# jit scripting for a nn.module (with dropout) does not
# trigger the fusion kernel. For now, we use two
# different nn.functional routines to account for varying
# dropout semantics during training and inference phases.
if self.bias_dropout_fusion:
if self.training:
bias_dropout_add_func = bias_dropout_add_fused_train
else:
bias_dropout_add_func = bias_dropout_add_fused_inference
else:
bias_dropout_add_func = get_bias_dropout_add(self.training)
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
layernorm_input = bias_dropout_add_func(
attention_output,
attention_bias.expand_as(residual),
residual,
self.hidden_dropout)
else:
out = torch.nn.functional.dropout(attention_output + attention_bias,
p=self.hidden_dropout,
training=self.training)
layernorm_input = residual + self.drop_path(out)
# Layer norm post the self attention.
layernorm_output = self.post_attention_layernorm(layernorm_input)
args = get_args()
retro_args = get_retro_args()
chunk_length = retro_args.retro_gpt_chunk_length
ns, bs, d = layernorm_output.shape
l = int(np.ceil(ns / chunk_length))
pad = (ns - 1) % chunk_length
attending_chunks = layernorm_output[pad:]
padded_chunks = torch.nn.functional.pad(
attending_chunks,
(0, 0, 0, 0, 0, chunk_length - 1),
'constant', 0)
padded_chunked_output = padded_chunks \
.reshape(l, chunk_length, bs, d) \
.permute(1, 2, 0, 3)
padded_chunked_output = padded_chunked_output.reshape(
chunk_length, bs * l, d).contiguous()
# Encoder output.
attention_output, attention_bias = \
self.inter_attention(padded_chunked_output,
None,
encoder_output=retriever_output)
# Residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = layernorm_input
# Re-enable torch grad to enable fused optimization.
with torch.enable_grad():
layernorm_input = bias_dropout_add_func(
attention_output,
attention_bias.expand_as(attention_output),
torch.zeros_like(attention_output),
self.hidden_dropout)
layernorm_input = layernorm_input \
.reshape(chunk_length, bs, l, d) \
.permute(2, 0, 1, 3) # [l, m, bs, d]
layernorm_input = layernorm_input.reshape(chunk_length * l, bs, d)
layernorm_input = torch.nn.functional.pad(
layernorm_input,
(0, 0, 0, 0, pad, 0),
'constant', 0)[:ns]
layernorm_input = layernorm_input + residual
# Layer norm post the decoder attention
layernorm_output = self.post_inter_attention_layernorm(layernorm_input)
# MLP.
mlp_output, mlp_bias = self.mlp(layernorm_output)
# Second residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = layernorm_input
if self.drop_path is None:
# Re-enable torch grad to enable fused optimization.
with torch.enable_grad():
output = bias_dropout_add_func(
mlp_output,
mlp_bias.expand_as(residual),
residual,
self.hidden_dropout)
else:
out = torch.nn.functional.dropout(mlp_output + mlp_bias,
p=self.hidden_dropout,
training=self.training)
output = residual + self.drop_path(out)
return output
class ParallelRetroEncoderTransformerCALayer(MegatronModule):
"""A single transformer layer for Retro Encoder with cross attention.
Transformer layer takes input with size [b, s, h] and returns an
output of the same size.
"""
def __init__(self, init_method, output_layer_init_method,
layer_number, layer_type=LayerType.encoder,
self_attn_mask_type=AttnMaskType.padding,
drop_path_rate=0.):
args = get_args()
super(ParallelRetroEncoderTransformerCALayer, self).__init__()
self.layer_number = layer_number
self.layer_type = layer_type
self.apply_residual_connection_post_layernorm \
= args.apply_residual_connection_post_layernorm
self.bf16 = args.bf16
self.fp32_residual_connection = args.fp32_residual_connection
# Layernorm on the input data.
self.input_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm)
# Self attention.
self.self_attention = ParallelAttention(
init_method,
output_layer_init_method,
layer_number,
attention_type=AttnType.self_attn,
attn_mask_type=self_attn_mask_type)
self.self_attention.attention_dropout = \
torch.nn.Dropout(args.retro_encoder_attention_dropout)
self.hidden_dropout = args.retro_encoder_hidden_dropout
self.bias_dropout_fusion = args.bias_dropout_fusion
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 \
else None
# Layernorm on the attention output
self.post_attention_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm)
self.inter_attention = ParallelAttention(
init_method,
output_layer_init_method,
layer_number,
attention_type=AttnType.cross_attn)
# Layernorm on the attention output.
self.post_inter_attention_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm)
# MLP
if args.num_experts is not None:
self.mlp = SwitchMLP(init_method, output_layer_init_method)
else:
self.mlp = ParallelMLP(init_method, output_layer_init_method)
def forward(self, hidden_states, attention_mask,
retriever_output, retriever_attn_mask,
encoder_output=None, enc_dec_attn_mask=None,
inference_params=None):
# hidden_states: [s, b, h]
# Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states)
# Self attention.
attention_output, attention_bias = \
self.self_attention(
layernorm_output,
attention_mask,
inference_params=inference_params)
# Residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = hidden_states
if self.drop_path is None:
# jit scripting for a nn.module (with dropout) does not
# trigger the fusion kernel. For now, we use two
# different nn.functional routines to account for varying
# dropout semantics during training and inference phases.
if self.bias_dropout_fusion:
if self.training:
bias_dropout_add_func = bias_dropout_add_fused_train
else:
bias_dropout_add_func = bias_dropout_add_fused_inference
else:
bias_dropout_add_func = get_bias_dropout_add(self.training)
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
layernorm_input = bias_dropout_add_func(
attention_output,
attention_bias.expand_as(residual),
residual,
self.hidden_dropout)
else:
out = torch.nn.functional.dropout(attention_output + attention_bias,
p=self.hidden_dropout,
training=self.training)
layernorm_input = residual + self.drop_path(out)
# Layer norm post the self attention.
layernorm_output = self.post_attention_layernorm(layernorm_input)
# Neighbors.
args = get_args()
retro_args = get_retro_args()
retrieved_length = retro_args.retro_gpt_retrieved_length
num_neighbors = args.retro_num_neighbors
ns, bs, d = layernorm_output.shape # [r, bs * l * k, d]
chunked_outputs = layernorm_output.reshape(retrieved_length, -1,
num_neighbors, d)
chunked_outputs_before_layer_norm = \
layernorm_input.reshape(retrieved_length, -1,
num_neighbors, d) # [r, bs * l, k, d]
layernorm_inputs = []
layernorm_outputs = []
for k in range(num_neighbors):
chunked_output = chunked_outputs[:,:,k].contiguous()
attention_output, attention_bias = \
self.inter_attention(
chunked_output, # Q (neighbor embedding)
None,
encoder_output=retriever_output) # K, V (hidden act)
# Residual connection.
if self.apply_residual_connection_post_layernorm:
residual = chunked_output
else:
residual = chunked_outputs_before_layer_norm[:,:,k]
# Re-enable torch grad to enable fused optimization.
with torch.enable_grad():
layernorm_input = bias_dropout_add_func(
attention_output,
attention_bias.expand_as(residual),
residual,
self.hidden_dropout)
layernorm_inputs.append(layernorm_input)
# Layer norm post the decoder attention
layernorm_output = \
self.post_inter_attention_layernorm(layernorm_input)
layernorm_outputs.append(layernorm_output)
# layernorm_input : [r, k * bs * l, d]
# layernorm_output : [r, k * bs * l, d]
layernorm_input = \
torch.stack(layernorm_inputs, dim=1).reshape(ns, bs, d)
layernorm_output = \
torch.stack(layernorm_outputs, dim=1).reshape(ns, bs, d)
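# Per-neighbor cross attention summary (descriptive comment): layernorm_output
# arrives with all k retrieved neighbors flattened into the batch dimension;
# it is viewed with an explicit neighbor axis, each neighbor's tokens
# cross-attend to retriever_output (the decoder's chunked hidden states) one
# neighbor at a time, and the k results are stacked and flattened back to the
# original [ns, bs, d] footprint before the shared MLP below.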
# MLP.
mlp_output, mlp_bias = self.mlp(layernorm_output)
# Second residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = layernorm_input
if self.drop_path is None:
# Re-enable torch grad to enable fused optimization.
with torch.enable_grad():
output = bias_dropout_add_func(
mlp_output,
mlp_bias.expand_as(residual),
residual,
self.hidden_dropout)
else:
out = torch.nn.functional.dropout(mlp_output + mlp_bias,
p=self.hidden_dropout,
training=self.training)
output = residual + self.drop_path(out)
return output
class ParallelTransformerLayer(MegatronModule):
"""A single transformer layer.
Transformer layer takes input with size [b, s, h] and returns an
output of the same size.
"""
def __init__(self, init_method, output_layer_init_method,
layer_number, layer_type=LayerType.encoder,
self_attn_mask_type=AttnMaskType.padding,
drop_path_rate=0.):
args = get_args()
super(ParallelTransformerLayer, self).__init__()
self.layer_number = layer_number
self.layer_type = layer_type
self.apply_residual_connection_post_layernorm \
= args.apply_residual_connection_post_layernorm
self.bf16 = args.bf16
self.fp32_residual_connection = args.fp32_residual_connection
# Layernorm on the input data.
self.input_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm)
# Self attention.
self.self_attention = ParallelAttention(
init_method,
output_layer_init_method,
layer_number,
attention_type=AttnType.self_attn,
attn_mask_type=self_attn_mask_type)
self.hidden_dropout = args.hidden_dropout
self.bias_dropout_fusion = args.bias_dropout_fusion
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 \
else None
# Layernorm on the attention output
self.post_attention_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm)
if self.layer_type == LayerType.decoder:
self.inter_attention = ParallelAttention(
init_method,
output_layer_init_method,
layer_number,
attention_type=AttnType.cross_attn)
# Layernorm on the attention output.
self.post_inter_attention_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm)
# MLP
if args.num_experts is not None:
self.mlp = SwitchMLP(init_method, output_layer_init_method)
else:
self.mlp = ParallelMLP(init_method, output_layer_init_method)
def forward(self, hidden_states, attention_mask,
encoder_output=None, enc_dec_attn_mask=None,
inference_params=None):
# hidden_states: [s, b, h]
# Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states)
# Self attention.
attention_output, attention_bias = \
self.self_attention(
layernorm_output,
attention_mask,
inference_params=inference_params)
# Residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = hidden_states
if self.drop_path is None:
# jit scripting for a nn.module (with dropout) does not
# trigger the fusion kernel. For now, we use two
# different nn.functional routines to account for varying
# dropout semantics during training and inference phases.
if self.bias_dropout_fusion:
if self.training:
bias_dropout_add_func = bias_dropout_add_fused_train
else:
bias_dropout_add_func = bias_dropout_add_fused_inference
else:
bias_dropout_add_func = get_bias_dropout_add(self.training)
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
layernorm_input = bias_dropout_add_func(
attention_output,
attention_bias.expand_as(residual),
residual,
self.hidden_dropout)
else:
out = torch.nn.functional.dropout(attention_output + attention_bias,
p=self.hidden_dropout,
training=self.training)
layernorm_input = residual + self.drop_path(out)
# Layer norm post the self attention.
layernorm_output = self.post_attention_layernorm(layernorm_input)
if self.layer_type == LayerType.decoder:
attention_output, attention_bias = \
self.inter_attention(layernorm_output,
enc_dec_attn_mask,
encoder_output=encoder_output)
# residual connection
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = layernorm_input
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
layernorm_input = bias_dropout_add_func(
attention_output,
attention_bias.expand_as(residual),
residual,
self.hidden_dropout)
# Layer norm post the decoder attention
layernorm_output = self.post_inter_attention_layernorm(layernorm_input)
# MLP.
mlp_output, mlp_bias = self.mlp(layernorm_output)
# Second residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = layernorm_input
if self.drop_path is None:
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
output = bias_dropout_add_func(
mlp_output,
mlp_bias.expand_as(residual),
residual,
self.hidden_dropout)
else:
out = torch.nn.functional.dropout(mlp_output + mlp_bias,
p=self.hidden_dropout,
training=self.training)
output = residual + self.drop_path(out)
return output
class ParallelRetroEncoder(MegatronModule):
""" Retro Transformer class for encoder ."""
def __init__(self, init_method, output_layer_init_method,
layer_type=LayerType.encoder,
self_attn_mask_type=AttnMaskType.padding,
pre_process=True, post_process=True,
drop_path_rate=0.0):
super(ParallelRetroEncoder, self).__init__()
args = get_args()
self.bf16 = args.bf16
self.fp32_residual_connection = args.fp32_residual_connection
self.pre_process = pre_process
self.post_process = post_process
self.input_tensor = None
self.drop_path_rate = drop_path_rate
# Store activation checkpointing flags.
self.recompute_granularity = args.recompute_granularity
self.recompute_method = args.recompute_method
self.recompute_num_layers = args.recompute_num_layers
self.distribute_saved_activations = \
args.distribute_saved_activations and not args.sequence_parallel
self.sequence_parallel = args.sequence_parallel
# Number of layers.
self.num_layers = args.retro_encoder_layers
self.drop_path_rates = [rate.item() for rate in torch.linspace(0, self.drop_path_rate, args.num_layers)]
if args.retro_add_retriever:
self.P = [1]
# Transformer layers.
assert args.retro_add_retriever
def build_layer(layer_number):
if layer_number in self.P:
return ParallelRetroEncoderTransformerCALayer(
init_method,
output_layer_init_method,
layer_number,
layer_type=layer_type,
self_attn_mask_type=self_attn_mask_type,
drop_path_rate=self.drop_path_rates[layer_number - 1])
else:
layer = ParallelTransformerLayer(
init_method,
output_layer_init_method,
layer_number,
layer_type=layer_type,
self_attn_mask_type=self_attn_mask_type,
drop_path_rate=self.drop_path_rates[layer_number - 1])
layer.self_attention.attention_dropout = \
torch.nn.Dropout(args.retro_encoder_attention_dropout)
layer.hidden_dropout = args.retro_encoder_hidden_dropout
return layer
if args.virtual_pipeline_model_parallel_size is not None:
assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
'num_layers_per_stage must be divisible by ' \
'virtual_pipeline_model_parallel_size'
assert args.model_type != ModelType.encoder_and_decoder
# Number of layers in each model chunk is the number of layers in
# the stage, divided by the number of model chunks in a stage.
self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size
# With 8 layers, 2 stages, and 4 model chunks, we want an
# assignment of layers to stages like (each list is a model chunk):
# Stage 0: [0] [2] [4] [6]
# Stage 1: [1] [3] [5] [7]
# With 8 layers, 2 stages, and 2 virtual stages, we want an
# assignment of layers to stages like (each list is a model chunk):
# Stage 0: [0, 1] [4, 5]
# Stage 1: [2, 3] [6, 7]
offset = parallel_state.get_virtual_pipeline_model_parallel_rank() * (
args.num_layers // args.virtual_pipeline_model_parallel_size) + \
(parallel_state.get_pipeline_model_parallel_rank() * self.num_layers)
else:
# Each stage gets a contiguous set of layers.
if args.model_type == ModelType.encoder_and_decoder and \
parallel_state.get_pipeline_model_parallel_world_size() > 1:
pipeline_rank = parallel_state.get_pipeline_model_parallel_rank()
if layer_type == LayerType.encoder:
offset = pipeline_rank * self.num_layers
else:
num_ranks_in_enc = args.pipeline_model_parallel_split_rank
offset = (pipeline_rank - num_ranks_in_enc) * self.num_layers
else:
offset = parallel_state.get_pipeline_model_parallel_rank() * self.num_layers
if self.num_layers == 0:
# When a standalone embedding stage is used (e.g.,
# args.standalone_embedding_stage == True), virtual pipeline ranks
# on pipeline rank 0 will have zero transformer layers assigned to
# them. This results in the model's input and output tensors to be
# the same, which will cause failure for certain output tensor
# optimizations (e.g., pipeline output deallocation). To remedy
# this, we assign a 'no-op' layer on these ranks, which will
# disconnect the input tensor from the output tensor.
self.num_layers = 1
self.layers = torch.nn.ModuleList([ NoopTransformerLayer(1) ])
else:
self.layers = torch.nn.ModuleList(
[build_layer(i + 1 + offset) for i in range(self.num_layers)])
if self.post_process:
# Final layer norm before output.
self.final_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm)
def _get_layer(self, layer_number):
return self.layers[layer_number]
def _checkpointed_forward(self, hidden_states, attention_mask,
encoder_output, enc_dec_attn_mask):
"""Forward method with activation checkpointing."""
def custom(start, end):
def custom_forward(*inputs):
x_ = inputs[0]
attention_mask = inputs[1]
encoder_output = inputs[2]
enc_dec_attn_mask = inputs[3]
for index in range(start, end):
layer = self._get_layer(index)
x_ = layer(x_, attention_mask, encoder_output, enc_dec_attn_mask)
return x_
return custom_forward
if self.recompute_method == 'uniform':
# Uniformly divide the total number of Transformer layers and
# checkpoint the input activation of each divided chunk.
# A method to further reduce memory usage by reducing the number
# of checkpoints.
l = 0
while l < self.num_layers:
hidden_states = parallel_state.checkpoint(
custom(l, l + self.recompute_num_layers),
self.distribute_saved_activations,
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
l += self.recompute_num_layers
elif self.recompute_method == 'block':
# Checkpoint the input activation of only a set number of individual
# Transformer layers and skip the rest.
# A method to fully use the device memory by removing redundant
# re-computation.
for l in range(self.num_layers):
if l < self.recompute_num_layers:
hidden_states = parallel_state.checkpoint(
custom(l, l + 1),
self.distribute_saved_activations,
hidden_states, attention_mask,
encoder_output, enc_dec_attn_mask)
else:
hidden_states = custom(l, l + 1)(
hidden_states, attention_mask,
encoder_output, enc_dec_attn_mask)
else:
raise ValueError("Invalid activation recompute method.")
return hidden_states
def set_input_tensor(self, input_tensor):
"""Set input tensor to be used instead of forward()'s input.
When doing pipeline parallelism the input from the previous
stage comes from communication, not from the input, so the
model's forward_step_func won't have it. This function is thus
used by internal code to bypass the input provided by the
forward_step_func"""
self.input_tensor = input_tensor
def forward(self, hidden_states, attention_mask,
retriever_output, retriever_attn_mask,
encoder_output=None, enc_dec_attn_mask=None,
inference_params=None):
# Checks.
if inference_params:
assert self.recompute_granularity is None, \
'inference does not work with activation checkpointing'
if self.pre_process:
# Data format change to avoid explicit transposes: [b s h] --> [s b h].
# If the input flag for fp32 residual connection is set, convert for float.
if self.fp32_residual_connection:
hidden_states = hidden_states.transpose(0, 1).contiguous().float()
# Otherwise, leave it as is.
else:
hidden_states = hidden_states.transpose(0, 1).contiguous()
else:
# See set_input_tensor()
hidden_states = self.input_tensor
# Viewless tensor.
# - We only need to create a viewless tensor in the case of micro batch
# size (mbs) == 1, since in this case, 'hidden_states.transpose()'
# above creates a view tensor, and '.contiguous()' is a pass-through.
# For mbs >= 2, '.contiguous()' creates a new tensor, eliminating
# the need to make it viewless.
#
# However, we don't explicitly check mbs == 1 here because
# make_viewless_tensor() has negligible overhead when its input
# is already viewless.
#
# - For the 'else' case above, calling make_viewless_tensor() here is
# likely redundant, since p2p_communication.py (likely originator)
# already creates viewless tensors. That said, make_viewless_tensor()
# is called here to be future-proof and corner-case-proof.
hidden_states = core_utils.make_viewless_tensor(
hidden_states,
requires_grad = True,
keep_graph = True,
)
# Transpose encoder output.
if encoder_output is not None:
encoder_output = encoder_output.transpose(0, 1).contiguous()
args = get_args()
assert not args.sequence_parallel, "sequence parallelism is not supported here (an rng tracker context would be required)."
# Forward pass.
if self.recompute_granularity == 'full':
hidden_states = self._checkpointed_forward(hidden_states,
attention_mask,
encoder_output,
enc_dec_attn_mask)
else:
for index in range(self.num_layers):
layer = self._get_layer(index)
if index + 1 in self.P:
hidden_states = layer(
hidden_states,
attention_mask,
retriever_output=retriever_output,
retriever_attn_mask=retriever_attn_mask,
encoder_output=encoder_output,
enc_dec_attn_mask=enc_dec_attn_mask,
inference_params=inference_params)
else:
hidden_states = layer(
hidden_states,
attention_mask,
encoder_output=encoder_output,
enc_dec_attn_mask=enc_dec_attn_mask,
inference_params=inference_params)
# Final layer norm.
if self.post_process:
# Reverting data format change [s b h] --> [b s h].
hidden_states = hidden_states.transpose(0, 1).contiguous()
output = self.final_layernorm(hidden_states)
else:
output = hidden_states
return output
class ParallelRetroTransformer(MegatronModule):
"""Standard GPT Transformer class."""
def __init__(self, init_method, output_layer_init_method,
layer_type=LayerType.encoder,
self_attn_mask_type=AttnMaskType.padding,
pre_process=True, post_process=True,
drop_path_rate=0.0, retriever=None):
super(ParallelRetroTransformer, self).__init__()
args = get_args()
assert pre_process and post_process, \
"pipeline parallelism un-supported."
self.bf16 = args.bf16
self.fp32_residual_connection = args.fp32_residual_connection
self.pre_process = pre_process
self.post_process = post_process
self.input_tensor = None
self.drop_path_rate = drop_path_rate
# Store activation checkpointing flags.
self.recompute_granularity = args.recompute_granularity
self.recompute_method = args.recompute_method
self.recompute_num_layers = args.recompute_num_layers
self.distribute_saved_activations = \
args.distribute_saved_activations and not args.sequence_parallel
self.sequence_parallel = args.sequence_parallel
# Number of layers.
self.num_layers = _get_num_layers(
args, args.model_type == ModelType.encoder_and_decoder)
self.drop_path_rates = [rate.item() for rate in torch.linspace(0, self.drop_path_rate, args.num_layers)]
if args.retro_add_retriever:
if args.num_layers == 12:
self.P = [6, 9, 12]
elif args.num_layers == 24:
self.P = np.arange(9, 25, 3).tolist()
elif args.num_layers == 40:
self.P = np.arange(9, 41, 3).tolist()
self.P.append(40)
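# Retrieval layer placement (descriptive comment): P lists the 1-based layer
# indices that use chunked cross attention. For num_layers == 24 this is
# [9, 12, 15, 18, 21, 24]; the first of them, min(P), becomes the special
# ParallelRetroTransformerEncoderLayer that also runs the retriever encoder
# and produces the neighbor encoding reused by the later P layers.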
self.retriever = retriever
# Transformer layers.
assert args.retro_add_retriever
def build_layer(layer_number):
if layer_number == min(self.P):
return ParallelRetroTransformerEncoderLayer(
init_method,
output_layer_init_method,
layer_number,
layer_type=layer_type,
self_attn_mask_type=self_attn_mask_type,
drop_path_rate=self.drop_path_rates[layer_number - 1],
retriever=retriever
)
elif layer_number in self.P:
return ParallelRetroTransformerLayer(
init_method,
output_layer_init_method,
layer_number,
layer_type=layer_type,
self_attn_mask_type=self_attn_mask_type,
drop_path_rate=self.drop_path_rates[layer_number - 1])
else:
return ParallelTransformerLayer(
init_method,
output_layer_init_method,
layer_number,
layer_type=layer_type,
self_attn_mask_type=self_attn_mask_type,
drop_path_rate=self.drop_path_rates[layer_number - 1])
if args.virtual_pipeline_model_parallel_size is not None:
assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
'num_layers_per_stage must be divisible by ' \
'virtual_pipeline_model_parallel_size'
assert args.model_type != ModelType.encoder_and_decoder
# Number of layers in each model chunk is the number of layers in the stage,
# divided by the number of model chunks in a stage.
self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size
# With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
# layers to stages like (each list is a model chunk):
# Stage 0: [0] [2] [4] [6]
# Stage 1: [1] [3] [5] [7]
# With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
# layers to stages like (each list is a model chunk):
# Stage 0: [0, 1] [4, 5]
# Stage 1: [2, 3] [6, 7]
offset = parallel_state.get_virtual_pipeline_model_parallel_rank() * (
args.num_layers // args.virtual_pipeline_model_parallel_size) + \
(parallel_state.get_pipeline_model_parallel_rank() * self.num_layers)
else:
# Each stage gets a contiguous set of layers.
if args.model_type == ModelType.encoder_and_decoder and \
parallel_state.get_pipeline_model_parallel_world_size() > 1:
pipeline_rank = parallel_state.get_pipeline_model_parallel_rank()
if layer_type == LayerType.encoder:
offset = pipeline_rank * self.num_layers
else:
num_ranks_in_enc = args.pipeline_model_parallel_split_rank
offset = (pipeline_rank - num_ranks_in_enc) * self.num_layers
else:
offset = parallel_state.get_pipeline_model_parallel_rank() * self.num_layers
if self.num_layers == 0:
# When a standalone embedding stage is used (e.g.,
# args.standalone_embedding_stage == True), virtual pipeline ranks
# on pipeline rank 0 will have zero transformer layers assigned to
# them. This results in the model's input and output tensors to be
# the same, which will cause failure for certain output tensor
# optimizations (e.g., pipeline output deallocation). To remedy
# this, we assign a 'no-op' layer on these ranks, which will
# disconnect the input tensor from the output tensor.
self.num_layers = 1
self.layers = torch.nn.ModuleList([ NoopTransformerLayer(1) ])
else:
self.layers = torch.nn.ModuleList(
[build_layer(i + 1 + offset) for i in range(self.num_layers)])
if self.post_process:
# Final layer norm before output.
self.final_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm)
def _get_layer(self, layer_number):
return self.layers[layer_number]
def _checkpointed_forward(self, hidden_states, attention_mask,
encoder_output, enc_dec_attn_mask):
"""Forward method with activation checkpointing."""
def custom(start, end):
def custom_forward(*inputs):
x_ = inputs[0]
attention_mask = inputs[1]
encoder_output = inputs[2]
enc_dec_attn_mask = inputs[3]
for index in range(start, end):
layer = self._get_layer(index)
x_ = layer(x_, attention_mask, encoder_output, enc_dec_attn_mask)
return x_
return custom_forward
if self.recompute_method == 'uniform':
# Uniformly divide the total number of Transformer layers and checkpoint
# the input activation of each divided chunk.
# A method to further reduce memory usage by reducing the number of checkpoints.
l = 0
while l < self.num_layers:
hidden_states = parallel_state.checkpoint(
custom(l, l + self.recompute_num_layers),
self.distribute_saved_activations,
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
l += self.recompute_num_layers
elif self.recompute_method == 'block':
# Checkpoint the input activation of only a set number of individual
# Transformer layers and skip the rest.
# A method to fully use the device memory by removing redundant re-computation.
for l in range(self.num_layers):
if l < self.recompute_num_layers:
hidden_states = parallel_state.checkpoint(
custom(l, l + 1),
self.distribute_saved_activations,
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
else:
hidden_states = custom(l, l + 1)(
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
else:
raise ValueError("Invalid activation recompute method.")
return hidden_states
def set_input_tensor(self, input_tensor):
"""Set input tensor to be used instead of forward()'s input.
When doing pipeline parallelism the input from the previous
stage comes from communication, not from the input, so the
model's forward_step_func won't have it. This function is thus
used by internal code to bypass the input provided by the
forward_step_func"""
self.input_tensor = input_tensor
def forward(self, hidden_states, attention_mask,
retriever_output=None, retriever_attn_mask=None,
encoder_output=None, enc_dec_attn_mask=None,
inference_params=None):
# Checks.
if inference_params:
assert self.recompute_granularity is None, \
'inference does not work with activation checkpointing'
args = get_args()
# Transpose retriever output, to match hidden_states shape.
retriever_output = retriever_output.transpose(0, 1).contiguous()
# Viewless tensor.
# - We only need to create a viewless tensor in the case of micro batch
# size (mbs) == 1, since in this case, 'hidden_states.transpose()'
# above creates a view tensor, and '.contiguous()' is a pass-through.
# For mbs >= 2, '.contiguous()' creates a new tensor, eliminating
# the need to make it viewless.
#
# However, we don't explicitly check mbs == 1 here because
# make_viewless_tensor() has negligible overhead when its input
# is already viewless.
#
# - For the 'else' case above, calling make_viewless_tensor() here is
# likely redundant, since p2p_communication.py (likely originator)
# already creates viewless tensors. That said, make_viewless_tensor()
# is called here to be future-proof and corner-case-proof.
hidden_states = core_utils.make_viewless_tensor(
hidden_states,
requires_grad=True,
keep_graph=True,
)
# Transpose encoder output.
if encoder_output is not None:
encoder_output = encoder_output.transpose(0, 1).contiguous()
# Forward pass.
assert not args.sequence_parallel, "sequence parallelism is not supported here (an rng tracker context would be required)."
if self.recompute_granularity == 'full':
hidden_states = self._checkpointed_forward(hidden_states,
attention_mask,
encoder_output,
enc_dec_attn_mask)
else:
for index in range(self.num_layers):
layer = self._get_layer(index)
if args.retro_add_retriever and index + 1 == min(self.P):
hidden_states, E = layer(
hidden_states,
attention_mask,
retriever_output=retriever_output,
retriever_attn_mask=retriever_attn_mask,
encoder_output=encoder_output,
enc_dec_attn_mask=enc_dec_attn_mask,
inference_params=inference_params)
elif args.retro_add_retriever and index + 1 in self.P:
hidden_states = layer(
hidden_states,
attention_mask,
retriever_output=E,
retriever_attn_mask=retriever_attn_mask,
encoder_output=encoder_output,
enc_dec_attn_mask=enc_dec_attn_mask,
inference_params=inference_params)
else:
hidden_states = layer(
hidden_states,
attention_mask,
encoder_output=encoder_output,
enc_dec_attn_mask=enc_dec_attn_mask,
inference_params=inference_params)
# Final layer norm.
output = self.final_layernorm(hidden_states)
return output
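# Rough end-to-end sketch (comments only; assumes Megatron globals, model
# parallel state and the retro args are already initialized, and the names
# below are illustrative rather than a tested entry point):
#
#   retriever = ParallelRetroEncoder(init_method, output_layer_init_method)
#   decoder = ParallelRetroTransformer(init_method, output_layer_init_method,
#                                      retriever=retriever)
#   out = decoder(hidden_states, attention_mask,
#                 retriever_output=neighbor_hidden_states,
#                 retriever_attn_mask=neighbor_attn_mask)
#
# Inside forward(), the min(self.P) layer runs the retriever over the neighbor
# stream and returns its encoding as E; subsequent layers listed in self.P
# reuse E as the key/value source for chunked cross attention, while all other
# layers behave like standard ParallelTransformerLayer blocks.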
...@@ -11,9 +11,7 @@ from megatron.model.language_model import parallel_lm_logits, get_language_model ...@@ -11,9 +11,7 @@ from megatron.model.language_model import parallel_lm_logits, get_language_model
from megatron.model import LayerNorm from megatron.model import LayerNorm
from megatron.model.utils import ( from megatron.model.utils import (
openai_gelu, openai_gelu,
get_linear_layer, get_linear_layer
init_method_normal,
scaled_init_method_normal
) )
from .module import MegatronModule from .module import MegatronModule
...@@ -43,17 +41,12 @@ class T5LMHead(MegatronModule): ...@@ -43,17 +41,12 @@ class T5LMHead(MegatronModule):
Arguments: Arguments:
mpu_vocab_size: model parallel size of vocabulary. mpu_vocab_size: model parallel size of vocabulary.
hidden_size: hidden size
init_method: init method for weight initialization
layernorm_epsilon: tolerance for layer norm divisions
parallel_output: wether output logits being distributed or not. parallel_output: wether output logits being distributed or not.
""" """
def __init__(self, mpu_vocab_size, parallel_output): def __init__(self, mpu_vocab_size, parallel_output):
super(T5LMHead, self).__init__() super(T5LMHead, self).__init__()
args = get_args()
self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size)) self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
self.bias.model_parallel = True self.bias.model_parallel = True
self.bias.partition_dim = 0 self.bias.partition_dim = 0
...@@ -72,41 +65,38 @@ class T5Model(MegatronModule): ...@@ -72,41 +65,38 @@ class T5Model(MegatronModule):
"""T5 Language model.""" """T5 Language model."""
def __init__(self, def __init__(self,
config,
num_tokentypes=0, num_tokentypes=0,
parallel_output=True, parallel_output=True,
pre_process=True, pre_process=True,
post_process=True, post_process=True,
add_encoder=True, add_encoder=True,
add_decoder=True): add_decoder=True):
super(T5Model, self).__init__() super().__init__(config=config)
args = get_args() args = get_args()
self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
self.parallel_output = parallel_output self.parallel_output = parallel_output
init_method = init_method_normal(args.init_method_std)
scaled_init_method = scaled_init_method_normal(args.init_method_std,
args.num_layers)
self.pre_process = pre_process self.pre_process = pre_process
self.post_process = post_process self.post_process = post_process
self.add_encoder = add_encoder self.add_encoder = add_encoder
self.add_decoder = add_decoder self.add_decoder = add_decoder
self.language_model, self._language_model_key = get_language_model( self.language_model, self._language_model_key = get_language_model(
config=config,
num_tokentypes=num_tokentypes, num_tokentypes=num_tokentypes,
add_pooler=False, add_pooler=False,
add_encoder=add_encoder, add_encoder=add_encoder,
add_decoder=add_decoder, add_decoder=add_decoder,
encoder_attn_mask_type=AttnMaskType.padding, encoder_attn_mask_type=AttnMaskType.padding,
init_method=init_method,
scaled_init_method=scaled_init_method,
pre_process=self.pre_process, pre_process=self.pre_process,
post_process=self.post_process) post_process=self.post_process)
self.initialize_word_embeddings(init_method_normal) self.initialize_word_embeddings()
if self.post_process and self.add_decoder: if self.post_process and self.add_decoder:
self.lm_head = T5LMHead( self.lm_head = T5LMHead(
self.word_embeddings_weight().size(0), self.shared_embedding_or_output_weight().size(0),
parallel_output) parallel_output)
self._lm_head_key = 'lm_head' self._lm_head_key = 'lm_head'
...@@ -139,7 +129,7 @@ class T5Model(MegatronModule): ...@@ -139,7 +129,7 @@ class T5Model(MegatronModule):
decoder_output, encoder_output = lm_output decoder_output, encoder_output = lm_output
# Output. [s, b, h] # Output. [s, b, h]
lm_logits = self.lm_head(decoder_output, lm_logits = self.lm_head(decoder_output,
self.word_embeddings_weight()) self.shared_embedding_or_output_weight())
if lm_labels is None: if lm_labels is None:
# [s b h] => [b s h] # [s b h] => [b s h]
......
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
"""Transformer.""" """Transformer."""
import math
from contextlib import nullcontext from contextlib import nullcontext
import math
import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from typing import Optional from typing import Optional
from megatron import get_timers, get_args, core, get_num_microbatches from megatron import get_timers, get_args, get_retro_args, core, get_num_microbatches
from .module import MegatronModule from .module import MegatronModule
from megatron.core import mpu, tensor_parallel from megatron.core import mpu, tensor_parallel
from megatron.core.enums import ModelType from megatron.core.enums import ModelType
...@@ -15,7 +16,7 @@ from megatron.model import LayerNorm ...@@ -15,7 +16,7 @@ from megatron.model import LayerNorm
from megatron.model.enums import AttnMaskType, LayerType, AttnType from megatron.model.enums import AttnMaskType, LayerType, AttnType
from megatron.model.fused_softmax import FusedScaleMaskSoftmax from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl from megatron.model.fused_bias_gelu import bias_gelu_impl
from megatron.model.rotary_pos_embedding import apply_rotary_pos_emb from megatron.core.models.common.rotary_pos_embedding import apply_rotary_pos_emb
from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
try: try:
...@@ -26,7 +27,10 @@ except ImportError: ...@@ -26,7 +27,10 @@ except ImportError:
try: try:
from flash_attn.flash_attn_interface import flash_attn_unpadded_func from flash_attn.flash_attn_interface import flash_attn_unpadded_func
except ImportError: except ImportError:
flash_attn_unpadded_func = None try:
from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_unpadded_func
except ImportError:
flash_attn_unpadded_func = None
""" We use the following notation throughout this file: """ We use the following notation throughout this file:
h: hidden size h: hidden size
...@@ -65,18 +69,6 @@ class DropPath(MegatronModule): ...@@ -65,18 +69,6 @@ class DropPath(MegatronModule):
output = hidden_state.div(keep_prob) * random_tensor output = hidden_state.div(keep_prob) * random_tensor
return output return output
def _args_to_kwargs():
args = get_args()
common_kwargs = {
"params_dtype": args.params_dtype,
"use_cpu_initialization": args.use_cpu_initialization,
"perform_initialization": args.perform_initialization,
"gradient_accumulation_fusion": args.gradient_accumulation_fusion,
"sequence_parallel_enabled": args.sequence_parallel,
}
return common_kwargs
class ParallelMLP(MegatronModule): class ParallelMLP(MegatronModule):
"""MLP. """MLP.
...@@ -85,22 +77,26 @@ class ParallelMLP(MegatronModule): ...@@ -85,22 +77,26 @@ class ParallelMLP(MegatronModule):
state back into h hidden dimension. state back into h hidden dimension.
""" """
def __init__(self, init_method, output_layer_init_method): def __init__(self, config):
super(ParallelMLP, self).__init__() super(ParallelMLP, self).__init__()
args = get_args() args = get_args()
self.add_bias = args.add_bias_linear self.add_bias = config.add_bias_linear
ffn_hidden_size = config.ffn_hidden_size
if config.gated_linear_unit:
ffn_hidden_size *= 2
# Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
self.dense_h_to_4h = tensor_parallel.ColumnParallelLinear( self.dense_h_to_4h = tensor_parallel.ColumnParallelLinear(
args.hidden_size, config.hidden_size,
args.ffn_hidden_size * 2 if args.swiglu else args.ffn_hidden_size, ffn_hidden_size,
config=config,
init_method=config.init_method,
bias=self.add_bias, bias=self.add_bias,
gather_output=False, gather_output=False,
init_method=init_method,
skip_bias_add=True, skip_bias_add=True,
async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce, )
**_args_to_kwargs())
self.bias_gelu_fusion = False self.bias_gelu_fusion = False
self.activation_func = None self.activation_func = None
...@@ -125,13 +121,13 @@ class ParallelMLP(MegatronModule): ...@@ -125,13 +121,13 @@ class ParallelMLP(MegatronModule):
# Project back to h. # Project back to h.
self.dense_4h_to_h = tensor_parallel.RowParallelLinear( self.dense_4h_to_h = tensor_parallel.RowParallelLinear(
args.ffn_hidden_size, config.ffn_hidden_size,
args.hidden_size, config.hidden_size,
config=config,
init_method=config.output_layer_init_method,
bias=self.add_bias, bias=self.add_bias,
input_is_parallel=True, input_is_parallel=True
init_method=output_layer_init_method, )
skip_bias_add=True,
**_args_to_kwargs())
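The gated_linear_unit / swiglu doubling above follows the GLU-variant formulation referenced in the comment: the h-to-4h projection emits two halves, one of which gates the other through SiLU. A minimal sketch under those assumptions (toy sizes, hypothetical names, not the fused tensor-parallel path):

import torch
import torch.nn.functional as F

hidden, ffn_hidden = 8, 16
x = torch.randn(5, hidden)                                            # [tokens, h]
dense_h_to_4h = torch.nn.Linear(hidden, 2 * ffn_hidden, bias=False)   # doubled width for the gate
dense_4h_to_h = torch.nn.Linear(ffn_hidden, hidden, bias=False)

gate, value = dense_h_to_4h(x).chunk(2, dim=-1)   # split the doubled projection
y = dense_4h_to_h(F.silu(gate) * value)           # SwiGLU: silu(x @ W) * (x @ V), then project back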
def forward(self, hidden_states): def forward(self, hidden_states):
...@@ -155,13 +151,13 @@ class SwitchMLP(MegatronModule): ...@@ -155,13 +151,13 @@ class SwitchMLP(MegatronModule):
""" """
Routes input to one of N MLP "experts" Routes input to one of N MLP "experts"
""" """
def __init__(self, init_method, output_layer_init_method): def __init__(self, config):
super(SwitchMLP, self).__init__() super(SwitchMLP, self).__init__()
args = get_args() args = get_args()
self.router = torch.nn.Linear(args.hidden_size, args.num_experts) self.router = torch.nn.Linear(config.hidden_size, args.num_experts)
self.experts = torch.nn.ModuleList() self.experts = torch.nn.ModuleList()
for i in range(args.num_experts): for i in range(args.num_experts):
self.experts.append(ParallelMLP(init_method, output_layer_init_method)) self.experts.append(ParallelMLP(config))
def forward(self, hidden_states): def forward(self, hidden_states):
# hidden_states: [s, b, h] # hidden_states: [s, b, h]
...@@ -188,45 +184,48 @@ class SwitchMLP(MegatronModule): ...@@ -188,45 +184,48 @@ class SwitchMLP(MegatronModule):
local_indices = (max_ind == expert_num).nonzero() local_indices = (max_ind == expert_num).nonzero()
hidden = hidden_states[local_indices,:] hidden = hidden_states[local_indices,:]
output, output_bias = expert(hidden) output, output_bias = expert(hidden)
output_bias = output_bias.expand_as(output) if output_bias is not None:
output_bias = output_bias.expand_as(output)
output_bias_total[local_indices,:] = output_bias
output_total[local_indices,:] = output output_total[local_indices,:] = output
output_bias_total[local_indices,:] = output_bias
output_total = output_total*max_prob output_total = output_total*max_prob
output_bias_total = output_bias_total*max_prob
output_total = output_total.view(s, b, h) output_total = output_total.view(s, b, h)
output_bias_total = output_bias_total.view(s, b, h) if output_bias is not None:
output_bias_total = output_bias_total*max_prob
output_bias_total = output_bias_total.view(s, b, h)
else:
output_bias_total = None
return output_total, output_bias_total return output_total, output_bias_total
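The SwitchMLP change above only guards the bias path when add_bias_linear is disabled; the top-1 routing itself is unchanged. For context, a minimal routing sketch with a toy router and dense experts (hypothetical names, not the parallel implementation):

import torch

s, b, h, num_experts = 4, 2, 8, 2
hidden_states = torch.randn(s, b, h)
router = torch.nn.Linear(h, num_experts)
experts = torch.nn.ModuleList([torch.nn.Linear(h, h) for _ in range(num_experts)])

route = torch.softmax(router(hidden_states), dim=2)   # [s, b, e]
max_prob, max_ind = torch.max(route, dim=2)           # top-1 expert per token
flat = hidden_states.view(-1, h)
flat_ind = max_ind.view(-1)
output = torch.zeros_like(flat)
for e, expert in enumerate(experts):
    idx = (flat_ind == e).nonzero(as_tuple=True)[0]   # tokens routed to expert e
    output[idx] = expert(flat[idx])
output = (output * max_prob.view(-1, 1)).view(s, b, h)  # scale by routing probability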
class CoreAttention(MegatronModule): class CoreAttention(MegatronModule):
def __init__(self, layer_number, def __init__(self, layer_number, config,
attn_mask_type=AttnMaskType.padding): attn_mask_type=AttnMaskType.padding):
super(CoreAttention, self).__init__() super(CoreAttention, self).__init__()
args = get_args() self.fp16 = config.fp16
self.fp16 = args.fp16 self.bf16 = config.bf16
self.bf16 = args.bf16
self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32 self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
if self.apply_query_key_layer_scaling: if self.apply_query_key_layer_scaling:
self.attention_softmax_in_fp32 = True self.attention_softmax_in_fp32 = True
self.layer_number = max(1, layer_number) self.layer_number = max(1, layer_number)
self.attn_mask_type = attn_mask_type self.attn_mask_type = attn_mask_type
self.sequence_parallel = args.sequence_parallel self.sequence_parallel = config.sequence_parallel
projection_size = args.kv_channels * args.num_attention_heads projection_size = config.kv_channels * config.num_attention_heads
# Per attention head and per partition values. # Per attention head and per partition values.
world_size = mpu.get_tensor_model_parallel_world_size() world_size = mpu.get_tensor_model_parallel_world_size()
self.hidden_size_per_partition = core.utils.divide(projection_size, self.hidden_size_per_partition = core.utils.divide(projection_size,
world_size) world_size)
self.hidden_size_per_attention_head = core.utils.divide( self.hidden_size_per_attention_head = core.utils.divide(
projection_size, args.num_attention_heads) projection_size, config.num_attention_heads)
self.num_attention_heads_per_partition = core.utils.divide( self.num_attention_heads_per_partition = core.utils.divide(
args.num_attention_heads, world_size) config.num_attention_heads, world_size)
coeff = None coeff = None
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
...@@ -237,7 +236,7 @@ class CoreAttention(MegatronModule): ...@@ -237,7 +236,7 @@ class CoreAttention(MegatronModule):
self.scale_mask_softmax = FusedScaleMaskSoftmax( self.scale_mask_softmax = FusedScaleMaskSoftmax(
self.fp16, self.bf16, self.fp16, self.bf16,
self.attn_mask_type, self.attn_mask_type,
args.masked_softmax_fusion, config.masked_softmax_fusion,
attention_mask_func, attention_mask_func,
self.attention_softmax_in_fp32, self.attention_softmax_in_fp32,
coeff) coeff)
...@@ -245,7 +244,7 @@ class CoreAttention(MegatronModule): ...@@ -245,7 +244,7 @@ class CoreAttention(MegatronModule):
# Dropout. Note that for a single iteration, this layer will generate # Dropout. Note that for a single iteration, this layer will generate
# different outputs on different numbers of parallel partitions but # different outputs on different numbers of parallel partitions but
# on average it should not be partition dependent. # on average it should not be partition dependent.
self.attention_dropout = torch.nn.Dropout(args.attention_dropout) self.attention_dropout = torch.nn.Dropout(config.attention_dropout)
def forward(self, query_layer, key_layer, def forward(self, query_layer, key_layer,
value_layer, attention_mask): value_layer, attention_mask):
...@@ -261,8 +260,8 @@ class CoreAttention(MegatronModule): ...@@ -261,8 +260,8 @@ class CoreAttention(MegatronModule):
key_layer.size(0)) key_layer.size(0))
# [sq, b, np, hn] -> [sq, b * np, hn] # [sq, b, np, hn] -> [sq, b * np, hn]
query_layer = query_layer.view(output_size[2], query_layer = query_layer.reshape(output_size[2],
output_size[0] * output_size[1], -1) output_size[0] * output_size[1], -1)
# [sk, b, np, hn] -> [sk, b * np, hn] # [sk, b, np, hn] -> [sk, b * np, hn]
key_layer = key_layer.view(output_size[3], key_layer = key_layer.view(output_size[3],
output_size[0] * output_size[1], -1) output_size[0] * output_size[1], -1)
...@@ -379,17 +378,18 @@ class FlashSelfAttention(torch.nn.Module): ...@@ -379,17 +378,18 @@ class FlashSelfAttention(torch.nn.Module):
is_causal = self.causal is_causal = self.causal
cu_seqlens_k = cu_seqlens_q cu_seqlens_k = cu_seqlens_q
dropout_p = self.dropout_p
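# keep the dropout probability in a local so the inference-time branch below
# can disable it without clobbering the module's training value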
else: else:
# turn off FA causal mask after first inference autoregressive iteration # turn off FA causal mask after first inference autoregressive iteration
# only on first autoregressive step q,k,v have same seqlen # only on first autoregressive step q,k,v have same seqlen
is_causal = seqlen_q == seqlen_k is_causal = seqlen_q == seqlen_k
cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32,
device=q.device) device=q.device)
self.dropout_p = 0 dropout_p = 0
output = flash_attn_unpadded_func( output = flash_attn_unpadded_func(
q, k, v, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k, q, k, v, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
self.dropout_p, dropout_p,
softmax_scale=self.softmax_scale, causal=is_causal softmax_scale=self.softmax_scale, causal=is_causal
) )
...@@ -404,8 +404,7 @@ class ParallelAttention(MegatronModule): ...@@ -404,8 +404,7 @@ class ParallelAttention(MegatronModule):
and returns output of the same size. and returns output of the same size.
""" """
def __init__(self, init_method, def __init__(self, config, layer_number,
output_layer_init_method, layer_number,
attention_type=AttnType.self_attn, attention_type=AttnType.self_attn,
attn_mask_type=AttnMaskType.padding): attn_mask_type=AttnMaskType.padding):
super(ParallelAttention, self).__init__() super(ParallelAttention, self).__init__()
...@@ -413,10 +412,21 @@ class ParallelAttention(MegatronModule): ...@@ -413,10 +412,21 @@ class ParallelAttention(MegatronModule):
self.layer_number = max(1, layer_number) self.layer_number = max(1, layer_number)
self.attention_type = attention_type self.attention_type = attention_type
self.attn_mask_type = attn_mask_type self.attn_mask_type = attn_mask_type
self.params_dtype = args.params_dtype self.params_dtype = config.params_dtype
self.sequence_parallel = args.sequence_parallel self.sequence_parallel = config.sequence_parallel
self.group_query_attention = args.group_query_attention
self.num_query_groups = args.num_query_groups
query_projection_size = config.kv_channels * config.num_attention_heads
if self.group_query_attention:
kv_projection_size = args.kv_channels * args.num_query_groups
else:
kv_projection_size = args.kv_channels * args.num_attention_heads
self.use_flash_attn = args.use_flash_attn self.use_flash_attn = args.use_flash_attn \
and attention_type == AttnType.self_attn \
and self.attn_mask_type == AttnMaskType.causal
if self.use_flash_attn: if self.use_flash_attn:
if flash_attn_unpadded_func is None: if flash_attn_unpadded_func is None:
raise ImportError('FlashAttention is not installed, please install with ' raise ImportError('FlashAttention is not installed, please install with '
...@@ -428,64 +438,72 @@ class ParallelAttention(MegatronModule): ...@@ -428,64 +438,72 @@ class ParallelAttention(MegatronModule):
if rearrange is None: if rearrange is None:
raise ImportError('einops is not installed, please install with pip install einops') raise ImportError('einops is not installed, please install with pip install einops')
projection_size = args.kv_channels * args.num_attention_heads
# Per attention head and per partition values. # Per attention head and per partition values.
world_size = mpu.get_tensor_model_parallel_world_size() world_size = mpu.get_tensor_model_parallel_world_size()
self.hidden_size_per_attention_head = core.utils.divide( self.hidden_size_per_attention_head = core.utils.divide(
projection_size, args.num_attention_heads) query_projection_size, config.num_attention_heads)
self.num_attention_heads_per_partition = core.utils.divide( self.num_attention_heads_per_partition = core.utils.divide(
args.num_attention_heads, world_size) config.num_attention_heads, world_size)
if self.group_query_attention:
if args.num_query_groups % world_size != 0:
raise NotImplementedError('Currently the num_query_groups should be '
'a multiple of the tensor parallel size')
self.num_query_groups_per_partition = core.utils.divide(
args.num_query_groups, world_size)
else:
self.num_query_groups_per_partition = self.num_attention_heads_per_partition
# Strided linear layer. # Strided linear layer.
if attention_type == AttnType.self_attn: if attention_type == AttnType.self_attn:
self.query_key_value = tensor_parallel.ColumnParallelLinear( self.query_key_value = tensor_parallel.ColumnParallelLinear(
args.hidden_size, config.hidden_size,
3 * projection_size, query_projection_size + 2 * kv_projection_size,
config=config,
init_method=config.init_method,
bias=args.add_bias_linear, bias=args.add_bias_linear,
gather_output=False, gather_output=False)
init_method=init_method,
async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
**_args_to_kwargs())
else: else:
assert attention_type == AttnType.cross_attn assert attention_type == AttnType.cross_attn
self.query = tensor_parallel.ColumnParallelLinear(
args.hidden_size,
projection_size,
bias=args.add_bias_linear,
gather_output=False,
init_method=init_method,
async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce,
**_args_to_kwargs())
if self.group_query_attention:
raise NotImplementedError("Grouped query attention not implemented for cross-attention.")
assert query_projection_size == kv_projection_size
self.key_value = tensor_parallel.ColumnParallelLinear( self.query = tensor_parallel.ColumnParallelLinear(
args.hidden_size, config.hidden_size,
2 * projection_size, query_projection_size,
bias=args.add_bias_linear, config=config,
gather_output=False, init_method=config.init_method,
init_method=init_method, bias=config.add_bias_linear,
async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce, gather_output=False)
**_args_to_kwargs())
self.core_attention = CoreAttention(self.layer_number, self.key_value = tensor_parallel.ColumnParallelLinear(
config.hidden_size,
2 * kv_projection_size,
config=config,
init_method=config.init_method,
bias=config.add_bias_linear,
gather_output=False)
self.core_attention = CoreAttention(self.layer_number, config,
self.attn_mask_type) self.attn_mask_type)
self.checkpoint_core_attention = args.recompute_granularity == 'selective' self.checkpoint_core_attention = config.recompute_granularity == 'selective'
if self.use_flash_attn: if self.use_flash_attn:
self.core_attention_flash = FlashSelfAttention( self.core_attention_flash = FlashSelfAttention(
causal=True, attention_dropout=args.attention_dropout causal=True, attention_dropout=config.attention_dropout
) )
# Output. # Output.
self.dense = tensor_parallel.RowParallelLinear( self.dense = tensor_parallel.RowParallelLinear(
projection_size, query_projection_size,
args.hidden_size, config.hidden_size,
config=config,
init_method=config.output_layer_init_method,
bias=args.add_bias_linear, bias=args.add_bias_linear,
input_is_parallel=True, input_is_parallel=True,
init_method=output_layer_init_method, skip_bias_add=True)
skip_bias_add=True,
**_args_to_kwargs())
def _checkpointed_attention_forward(self, query_layer, key_layer, def _checkpointed_attention_forward(self, query_layer, key_layer,
value_layer, attention_mask, value_layer, attention_mask,
...@@ -510,11 +528,11 @@ class ParallelAttention(MegatronModule): ...@@ -510,11 +528,11 @@ class ParallelAttention(MegatronModule):
return hidden_states return hidden_states
def _allocate_memory(self, inference_max_sequence_len, batch_size): def _allocate_memory(self, inference_max_sequence_len, batch_size, num_attention_heads):
return torch.empty( return torch.empty(
inference_max_sequence_len, inference_max_sequence_len,
batch_size, batch_size,
self.num_attention_heads_per_partition, num_attention_heads,
self.hidden_size_per_attention_head, self.hidden_size_per_attention_head,
dtype=self.params_dtype, dtype=self.params_dtype,
device=torch.cuda.current_device()) device=torch.cuda.current_device())
...@@ -530,12 +548,15 @@ class ParallelAttention(MegatronModule): ...@@ -530,12 +548,15 @@ class ParallelAttention(MegatronModule):
is_first_step = False is_first_step = False
if inference_params: if inference_params:
if self.layer_number not in inference_params.key_value_memory_dict: if self.layer_number not in inference_params.key_value_memory_dict:
inf_max_seq_len = inference_params.max_sequence_len inf_max_seq_len = inference_params.max_sequence_length
inf_max_batch_size = inference_params.max_batch_size inf_max_batch_size = inference_params.max_batch_size
inference_key_memory = self._allocate_memory( inference_key_memory = self._allocate_memory(
inf_max_seq_len, inf_max_batch_size) inf_max_seq_len, inf_max_batch_size,
self.num_query_groups_per_partition)
inference_value_memory = self._allocate_memory( inference_value_memory = self._allocate_memory(
inf_max_seq_len, inf_max_batch_size) inf_max_seq_len, inf_max_batch_size,
self.num_query_groups_per_partition)
inference_params.key_value_memory_dict[self.layer_number] = ( inference_params.key_value_memory_dict[self.layer_number] = (
inference_key_memory, inference_value_memory) inference_key_memory, inference_value_memory)
is_first_step = True is_first_step = True
...@@ -546,21 +567,36 @@ class ParallelAttention(MegatronModule): ...@@ -546,21 +567,36 @@ class ParallelAttention(MegatronModule):
# ===================== # =====================
# Query, Key, and Value # Query, Key, and Value
# ===================== # =====================
if self.attention_type == AttnType.self_attn: if self.attention_type == AttnType.self_attn:
# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] # Attention heads [sq, b, h] --> [sq, b, (ng * (np/ng + 2) * hn)]
mixed_x_layer, _ = self.query_key_value(hidden_states) mixed_x_layer, _ = self.query_key_value(hidden_states)
# [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn] # [sq, b, hp] --> [sq, b, ng, (np/ng + 2) * hn]
new_tensor_shape = mixed_x_layer.size()[:-1] + \ new_tensor_shape = mixed_x_layer.size()[:-1] + (
(self.num_attention_heads_per_partition, self.num_query_groups_per_partition,
3 * self.hidden_size_per_attention_head) (
(self.num_attention_heads_per_partition // self.num_query_groups_per_partition + 2)
* self.hidden_size_per_attention_head
),
)
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] # [sq, b, ng, (np/ng + 2) * hn] --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
(query_layer, (query_layer,
key_layer, key_layer,
value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_x_layer, 3) value_layer) = torch.split(
mixed_x_layer,
[
(
self.num_attention_heads_per_partition // self.num_query_groups_per_partition
* self.hidden_size_per_attention_head
),
self.hidden_size_per_attention_head,
self.hidden_size_per_attention_head
],
dim=3)
# [sq, b, ng, np/ng * hn] -> [sq, b, np, hn]
query_layer = query_layer.view(query_layer.size(0), query_layer.size(1), -1, self.hidden_size_per_attention_head)
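# Illustrative sizes (hypothetical): np=32 heads, ng=8 query groups, hn=128 and a
# tensor-parallel size of 1 give a fused projection width of ng * (np/ng + 2) * hn = 6144,
# split per group into [4*128, 128, 128] for query, key and value.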
else: else:
# Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)] # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
mixed_kv_layer, _ = self.key_value(encoder_output) mixed_kv_layer, _ = self.key_value(encoder_output)
...@@ -568,19 +604,19 @@ class ParallelAttention(MegatronModule): ...@@ -568,19 +604,19 @@ class ParallelAttention(MegatronModule):
# [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn] # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
new_tensor_shape = mixed_kv_layer.size()[:-1] + \ new_tensor_shape = mixed_kv_layer.size()[:-1] + \
(self.num_attention_heads_per_partition, (self.num_attention_heads_per_partition,
2 * self.hidden_size_per_attention_head) 2 * self.hidden_size_per_attention_head)
mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape) mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)
# [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn] # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
(key_layer, (key_layer,
value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_kv_layer, 2) value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_kv_layer, 2)
# Attention head [sq, b, h] --> [sq, b, hp] # Attention head [sq, b, h] --> [sq, b, hp]
query_layer, _ = self.query(hidden_states) query_layer, _ = self.query(hidden_states)
# [sq, b, hp] --> [sq, b, np, hn] # [sq, b, hp] --> [sq, b, np, hn]
new_tensor_shape = query_layer.size()[:-1] + \ new_tensor_shape = query_layer.size()[:-1] + \
(self.num_attention_heads_per_partition, (self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head) self.hidden_size_per_attention_head)
query_layer = query_layer.view(*new_tensor_shape) query_layer = query_layer.view(*new_tensor_shape)
# ================================== # ==================================
...@@ -632,11 +668,20 @@ class ParallelAttention(MegatronModule): ...@@ -632,11 +668,20 @@ class ParallelAttention(MegatronModule):
k_pos_emb = k_pos_emb[:sequence_end, :, :, :] k_pos_emb = k_pos_emb[:sequence_end, :, :, :]
rotary_pos_emb = (q_pos_emb, k_pos_emb) rotary_pos_emb = (q_pos_emb, k_pos_emb)
# ================================== # ==================================
# core attention computation # core attention computation
# ================================== # ==================================
# expand the key_layer and value_layer [sk, b, ng, hn] -> [sk, b, np, hn]
key_layer = key_layer.repeat_interleave(
self.num_attention_heads_per_partition // self.num_query_groups_per_partition,
dim = 2
)
value_layer = value_layer.repeat_interleave(
self.num_attention_heads_per_partition // self.num_query_groups_per_partition,
dim = 2
)
# apply relative positional encoding (rotary embedding) # apply relative positional encoding (rotary embedding)
if rotary_pos_emb is not None: if rotary_pos_emb is not None:
q_pos_emb, k_pos_emb = rotary_pos_emb q_pos_emb, k_pos_emb = rotary_pos_emb
...@@ -711,10 +756,11 @@ class ParallelTransformerLayer(MegatronModule): ...@@ -711,10 +756,11 @@ class ParallelTransformerLayer(MegatronModule):
output of the same size. output of the same size.
""" """
def __init__(self, init_method, output_layer_init_method, def __init__(self, config,
layer_number, layer_type=LayerType.encoder, layer_number, layer_type=LayerType.encoder,
self_attn_mask_type=AttnMaskType.padding, self_attn_mask_type=AttnMaskType.padding,
drop_path_rate=0.): drop_path_rate=0.):
# retriever=None):
args = get_args() args = get_args()
super(ParallelTransformerLayer, self).__init__() super(ParallelTransformerLayer, self).__init__()
...@@ -722,57 +768,59 @@ class ParallelTransformerLayer(MegatronModule): ...@@ -722,57 +768,59 @@ class ParallelTransformerLayer(MegatronModule):
self.layer_type = layer_type self.layer_type = layer_type
self.apply_residual_connection_post_layernorm \ self.apply_residual_connection_post_layernorm \
= args.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
self.bf16 = args.bf16 self.bf16 = config.bf16
self.fp32_residual_connection = args.fp32_residual_connection self.fp32_residual_connection = config.fp32_residual_connection
# Layernorm on the input data. # Layernorm on the input data.
self.input_layernorm = LayerNorm( self.input_layernorm = LayerNorm(
args.hidden_size, config.hidden_size,
eps=args.layernorm_epsilon, eps=config.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm, no_persist_layer_norm=args.no_persist_layer_norm,
sequence_parallel=args.sequence_parallel, sequence_parallel=config.sequence_parallel,
apply_layernorm_1p=args.apply_layernorm_1p) apply_layernorm_1p=args.apply_layernorm_1p)
# Self attention. # Self attention.
self.self_attention = ParallelAttention( self.self_attention = ParallelAttention(
init_method, config,
output_layer_init_method,
layer_number, layer_number,
attention_type=AttnType.self_attn, attention_type=AttnType.self_attn,
attn_mask_type=self_attn_mask_type) attn_mask_type=self_attn_mask_type)
self.hidden_dropout = args.hidden_dropout self.hidden_dropout = config.hidden_dropout
self.bias_dropout_fusion = args.bias_dropout_fusion self.bias_dropout_fusion = config.bias_dropout_fusion
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None
# Layernorm on the attention output # Layernorm on the attention output
self.post_attention_layernorm = LayerNorm( self.post_attention_layernorm = LayerNorm(
args.hidden_size, config.hidden_size,
eps=args.layernorm_epsilon, eps=config.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm, no_persist_layer_norm=not config.persist_layer_norm,
sequence_parallel=args.sequence_parallel, sequence_parallel=config.sequence_parallel,
apply_layernorm_1p=args.apply_layernorm_1p) apply_layernorm_1p=args.apply_layernorm_1p)
if self.layer_type == LayerType.decoder: # Cross attention.
if self.layer_type in (LayerType.decoder,
LayerType.retro_decoder,
LayerType.retro_decoder_with_retriever,
LayerType.retro_encoder):
self.inter_attention = ParallelAttention( self.inter_attention = ParallelAttention(
init_method, config,
output_layer_init_method,
layer_number, layer_number,
attention_type=AttnType.cross_attn) attention_type=AttnType.cross_attn)
# Layernorm on the attention output. # Layernorm on the attention output.
self.post_inter_attention_layernorm = LayerNorm( self.post_inter_attention_layernorm = LayerNorm(
args.hidden_size, config.hidden_size,
eps=args.layernorm_epsilon, eps=config.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm, no_persist_layer_norm=not config.persist_layer_norm,
sequence_parallel=args.sequence_parallel, sequence_parallel=config.sequence_parallel,
apply_layernorm_1p=args.apply_layernorm_1p) apply_layernorm_1p=args.apply_layernorm_1p)
# MLP # MLP
if args.num_experts is not None: if args.num_experts is not None:
self.mlp = SwitchMLP(init_method, output_layer_init_method) self.mlp = SwitchMLP(config)
else: else:
self.mlp = ParallelMLP(init_method, output_layer_init_method) self.mlp = ParallelMLP(config)
# Set bias+dropout+add fusion grad_enable execution handler. # Set bias+dropout+add fusion grad_enable execution handler.
TORCH_MAJOR = int(torch.__version__.split('.')[0]) TORCH_MAJOR = int(torch.__version__.split('.')[0])
...@@ -781,13 +829,245 @@ class ParallelTransformerLayer(MegatronModule): ...@@ -781,13 +829,245 @@ class ParallelTransformerLayer(MegatronModule):
self.bias_dropout_add_exec_handler = \ self.bias_dropout_add_exec_handler = \
nullcontext if use_nvfuser else torch.enable_grad nullcontext if use_nvfuser else torch.enable_grad
if args.retro_add_retriever:
retro_args = get_retro_args()
self.retro_num_neighbors = args.retro_num_neighbors
self.retro_chunk_length = retro_args.retro_gpt_chunk_length
self.retro_retrieved_length = retro_args.retro_gpt_retrieved_length
# Retriever (bi-directional transformer with cross attention)
if layer_type == LayerType.retro_decoder_with_retriever:
self.retriever = ParallelTransformer(
config=config,
model_type=ModelType.retro_encoder,
self_attn_mask_type=AttnMaskType.padding,
pre_process=True,
post_process=False,
)
self._retriever_key = 'retriever'
else:
self.retriever = None
def default_decoder_cross_attention(self,
encoder_output,
enc_dec_attn_mask,
layernorm_input,
layernorm_output,
bias_dropout_add_func):
'''Cross attention for a standard encoder-decoder model.'''
# Attention.
attention_output, attention_bias = \
self.inter_attention(layernorm_output,
enc_dec_attn_mask,
encoder_output=encoder_output)
# Residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = layernorm_input
if attention_bias is not None:
attention_bias = attention_bias.expand_as(residual)
# Bias-dropout-add.
with self.bias_dropout_add_exec_handler():
layernorm_input = bias_dropout_add_func(
attention_output,
attention_bias,
residual,
self.hidden_dropout)
# Layer norm.
layernorm_output = self.post_inter_attention_layernorm(layernorm_input)
return layernorm_input, layernorm_output
def retro_encoder_cross_attention(self,
retriever_output,
layernorm_input,
layernorm_output,
bias_dropout_add_func):
"""Cross attention for Retro encoder.
Notation:
ns : Sequence length.
bs : Batch size.
d : Hidden size.
l : Number of chunks per sample (i.e., seq_length/chunk_length).
k : Number of neighbors.
r : Number of retrieved tokens (neighbors + continuation).
"""
ns, bs, d = layernorm_output.shape # [r, bs * l * k, d]
# Divide sequence dimension into chunks.
chunked_outputs = layernorm_output.reshape(self.retro_retrieved_length,
-1,
self.retro_num_neighbors,
d)
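# chunked_outputs: [r, bs*l, k, d], i.e. one slice along dim 2 per retrieved neighbor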
chunked_outputs_before_layer_norm = \
layernorm_input.reshape(self.retro_retrieved_length, -1,
self.retro_num_neighbors, d) # [r, bs*l, k, d]
# Per-chunk attention.
layernorm_inputs = []
layernorm_outputs = []
for k in range(self.retro_num_neighbors):
# Attention.
chunked_output = chunked_outputs[:,:,k].contiguous()
attention_output, attention_bias = \
self.inter_attention(
chunked_output, # Q (neighbor embedding)
None,
encoder_output=retriever_output) # K, V (hidden act)
# Residual connection.
if self.apply_residual_connection_post_layernorm:
residual = chunked_output
else:
residual = chunked_outputs_before_layer_norm[:,:,k]
# Re-enable torch grad to enable fused optimization.
with torch.enable_grad():
layernorm_input = bias_dropout_add_func(
attention_output,
None if attention_bias is None else attention_bias.expand_as(residual),
residual,
self.hidden_dropout)
layernorm_inputs.append(layernorm_input)
# Layer norm.
layernorm_output = \
self.post_inter_attention_layernorm(layernorm_input)
layernorm_outputs.append(layernorm_output)
# Concatenate layer norms.
# layernorm_input : [r, k * bs * l, d]
# layernorm_output : [r, k * bs * l, d]
layernorm_input = \
torch.stack(layernorm_inputs, dim=1).reshape(ns, bs, d)
layernorm_output = \
torch.stack(layernorm_outputs, dim=1).reshape(ns, bs, d)
return layernorm_input, layernorm_output
def retro_decoder_cross_attention(self,
retriever_input,
retriever_output,
retriever_attn_mask,
layernorm_input,
layernorm_output,
inference_params,
bias_dropout_add_func):
"""Cross attention for Retro decoder.
Notation:
ns : Sequence length.
bs : Batch size.
d : Hidden size.
l : Number of chunks per sample (i.e., seq_length/chunk_length).
m : Number of tokens per chunk.
k : Number of neighbors.
r : Number of retrieved tokens (neighbors + continuation).
"""
ns, bs, d = layernorm_output.shape
l = int(np.ceil(ns / self.retro_chunk_length))
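# e.g. (illustrative): ns=2048 tokens with retro_chunk_length=64 gives l=32 chunks per sample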
# Retrieve neighbors.
if self.layer_type == LayerType.retro_decoder_with_retriever:
first_ns = ns % self.retro_chunk_length
if first_ns > 0:
raise Exception("test this case.")
first_chunk, rest_chunk = \
layernorm_output[:first_ns], layernorm_output[first_ns:]
first_chunk = torch.nn.functional.pad(
first_chunk,
(0, 0, 0, 0, 0, self.retro_chunk_length - first_ns),
'constant',
0)
chunked_output = \
torch.cat((first_chunk, rest_chunk), dim=0) # [l * m, bs, d]
else:
chunked_output = layernorm_output # [l * m, bs, d]
chunked_output = chunked_output \
.reshape(l, self.retro_chunk_length, bs, d) \
.permute(1, 2, 0, 3) \
.reshape(self.retro_chunk_length, bs * l, d) \
.contiguous()
# Get Encoder Output
retriever_output = self.retriever(
hidden_states=retriever_input,
attention_mask=retriever_attn_mask,
retriever_output=chunked_output,
retriever_attn_mask=retriever_attn_mask,
inference_params=inference_params) # [r, k * bs * l , d]
retriever_output = retriever_output.reshape(
self.retro_retrieved_length * self.retro_num_neighbors, bs * l, d) # [r * k, bs * l, d]
# Chunks.
pad = (ns - 1) % self.retro_chunk_length
attending_chunks = layernorm_output[pad:]
padded_chunks = torch.nn.functional.pad(
attending_chunks,
(0, 0, 0, 0, 0, self.retro_chunk_length - 1),
'constant', 0)
padded_chunked_output = padded_chunks \
.reshape(l, self.retro_chunk_length, bs, d) \
.permute(1, 2, 0, 3)
padded_chunked_output = padded_chunked_output.reshape(
self.retro_chunk_length, bs * l, d).contiguous()
# Encoder output.
attention_output, attention_bias = \
self.inter_attention(padded_chunked_output,
None,
encoder_output=retriever_output)
# Residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = layernorm_input
# Re-enable torch grad to enable fused optimization.
with torch.enable_grad():
layernorm_input = bias_dropout_add_func(
attention_output,
None if attention_bias is None else attention_bias.expand_as(attention_output),
torch.zeros_like(attention_output),
self.hidden_dropout)
layernorm_input = layernorm_input \
.reshape(self.retro_chunk_length, bs, l, d) \
.permute(2, 0, 1, 3) # [l, m, bs, d]
layernorm_input = layernorm_input.reshape(self.retro_chunk_length * l, bs, d)
layernorm_input = torch.nn.functional.pad(
layernorm_input,
(0, 0, 0, 0, pad, 0),
'constant', 0)[:ns] # [ns, b, d]
layernorm_input = layernorm_input + residual
# Layer norm post the decoder attention
layernorm_output = self.post_inter_attention_layernorm(layernorm_input)
return retriever_output, layernorm_input, layernorm_output
def forward(self, hidden_states, attention_mask, def forward(self, hidden_states, attention_mask,
encoder_output=None, enc_dec_attn_mask=None, encoder_output=None, enc_dec_attn_mask=None,
inference_params=None, rotary_pos_emb=None): retriever_input=None,
retriever_output=None,
retriever_attn_mask=None,
inference_params=None,
rotary_pos_emb=None):
# hidden_states: [s, b, h] # hidden_states: [s, b, h]
# Layer norm at the beginning of the transformer layer. # Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states) layernorm_output = self.input_layernorm(hidden_states)
# Self attention. # Self attention.
attention_output, attention_bias = \ attention_output, attention_bias = \
self.self_attention( self.self_attention(
...@@ -832,29 +1112,38 @@ class ParallelTransformerLayer(MegatronModule): ...@@ -832,29 +1112,38 @@ class ParallelTransformerLayer(MegatronModule):
# Layer norm post the self attention. # Layer norm post the self attention.
layernorm_output = self.post_attention_layernorm(layernorm_input) layernorm_output = self.post_attention_layernorm(layernorm_input)
if self.layer_type == LayerType.decoder: # Cross attention.
attention_output, attention_bias = \ if self.layer_type == LayerType.encoder:
self.inter_attention(layernorm_output, pass
enc_dec_attn_mask, elif self.layer_type == LayerType.decoder:
encoder_output=encoder_output) layernorm_input, layernorm_output = \
# residual connection self.default_decoder_cross_attention(
if self.apply_residual_connection_post_layernorm: encoder_output,
residual = layernorm_output enc_dec_attn_mask,
else: layernorm_input,
residual = layernorm_input layernorm_output,
bias_dropout_add_func)
if attention_bias is not None: elif self.layer_type == LayerType.retro_encoder:
attention_bias = attention_bias.expand_as(residual) layernorm_input, layernorm_output = \
self.retro_encoder_cross_attention(
with self.bias_dropout_add_exec_handler(): retriever_output,
layernorm_input = bias_dropout_add_func( layernorm_input,
attention_output, layernorm_output,
attention_bias, bias_dropout_add_func)
residual, elif self.layer_type in (LayerType.retro_decoder,
self.hidden_dropout) LayerType.retro_decoder_with_retriever):
retriever_output, layernorm_input, layernorm_output = \
# Layer norm post the decoder attention self.retro_decoder_cross_attention(
layernorm_output = self.post_inter_attention_layernorm(layernorm_input) retriever_input,
retriever_output,
retriever_attn_mask,
layernorm_input,
layernorm_output,
inference_params,
bias_dropout_add_func)
else:
raise Exception("Unsupported layer type, '%s'." %
self.layer_type.name)
# MLP. # MLP.
mlp_output, mlp_bias = self.mlp(layernorm_output) mlp_output, mlp_bias = self.mlp(layernorm_output)
...@@ -893,7 +1182,10 @@ class ParallelTransformerLayer(MegatronModule): ...@@ -893,7 +1182,10 @@ class ParallelTransformerLayer(MegatronModule):
training=self.training) training=self.training)
output = residual + self.drop_path(out) output = residual + self.drop_path(out)
return output if self.layer_type == LayerType.retro_decoder_with_retriever:
return output, retriever_output
else:
return output
class NoopTransformerLayer(MegatronModule): class NoopTransformerLayer(MegatronModule):
...@@ -922,9 +1214,12 @@ class NoopTransformerLayer(MegatronModule): ...@@ -922,9 +1214,12 @@ class NoopTransformerLayer(MegatronModule):
return hidden_states.clone() return hidden_states.clone()
def _get_num_layers(args, is_encoder_and_decoder_model, is_decoder=False): def _get_num_layers(args, model_type, is_decoder=False):
"""Compute the number of transformer layers resident on the current rank.""" """Compute the number of transformer layers resident on the current rank."""
if mpu.get_pipeline_model_parallel_world_size() > 1: is_encoder_and_decoder_model = (model_type == ModelType.encoder_and_decoder)
if model_type == ModelType.retro_encoder:
num_layers = args.retro_encoder_layers
elif mpu.get_pipeline_model_parallel_world_size() > 1:
if is_encoder_and_decoder_model: if is_encoder_and_decoder_model:
assert args.pipeline_model_parallel_split_rank is not None assert args.pipeline_model_parallel_split_rank is not None
...@@ -974,51 +1269,91 @@ def _get_num_layers(args, is_encoder_and_decoder_model, is_decoder=False): ...@@ -974,51 +1269,91 @@ def _get_num_layers(args, is_encoder_and_decoder_model, is_decoder=False):
return num_layers return num_layers
def _get_layer_type(model_type, default_layer_type, retro_layer_numbers,
layer_number):
args = get_args()
if args.retro_add_retriever and layer_number in retro_layer_numbers:
if model_type == ModelType.retro_decoder:
return LayerType.retro_decoder_with_retriever \
if layer_number == retro_layer_numbers[0] \
else LayerType.retro_decoder
elif model_type == ModelType.retro_encoder:
return LayerType.retro_encoder
else:
raise Exception("Unsupported model type, '%s'." % model_type)
else:
return default_layer_type
class ParallelTransformer(MegatronModule): class ParallelTransformer(MegatronModule):
"""Transformer class.""" """Transformer class."""
def __init__(self, init_method, output_layer_init_method, def __init__(self, config,
layer_type=LayerType.encoder, model_type, layer_type=LayerType.encoder,
self_attn_mask_type=AttnMaskType.padding, self_attn_mask_type=AttnMaskType.padding,
post_layer_norm=True, post_layer_norm=True,
pre_process=True, post_process=True, pre_process=True,
post_process=True,
drop_path_rate=0.0): drop_path_rate=0.0):
super(ParallelTransformer, self).__init__() super(ParallelTransformer, self).__init__()
args = get_args() args = get_args()
self.layer_type = layer_type self.layer_type = layer_type
self.model_type = args.model_type self.model_type = model_type
self.bf16 = args.bf16 self.bf16 = config.bf16
self.fp32_residual_connection = args.fp32_residual_connection self.fp32_residual_connection = config.fp32_residual_connection
self.post_layer_norm = post_layer_norm self.post_layer_norm = post_layer_norm
self.pre_process = pre_process self.pre_process = pre_process
self.post_process = post_process self.post_process = post_process
self.input_tensor = None self.input_tensor = None
self.drop_path_rate = drop_path_rate self.drop_path_rate = drop_path_rate
self.transformer_impl = args.transformer_impl self.transformer_impl = args.transformer_impl
self.retro_add_retriever = args.retro_add_retriever
# Store activation checkpointing flag. # Store activation checkpointing flag.
self.recompute_granularity = args.recompute_granularity self.recompute_granularity = config.recompute_granularity
self.recompute_method = args.recompute_method self.recompute_method = config.recompute_method
self.recompute_num_layers = args.recompute_num_layers self.recompute_num_layers = config.recompute_num_layers
self.distribute_saved_activations = \ self.distribute_saved_activations = \
args.distribute_saved_activations and not args.sequence_parallel config.distribute_saved_activations and not config.sequence_parallel
self.sequence_parallel = args.sequence_parallel self.sequence_parallel = config.sequence_parallel
# Transformer Engine Init. # Transformer Engine Init.
self.transformer_engine_v_0_10 = False
self.transformer_engine_v_0_11 = False
self.transformer_engine_v_0_8 = False
if self.transformer_impl == 'transformer_engine': if self.transformer_impl == 'transformer_engine':
global transformer_engine global transformer_engine
import transformer_engine import transformer_engine
self.use_fp8 = args.fp8_e4m3 or args.fp8_hybrid from importlib.metadata import version
from pkg_resources import packaging
te_version = packaging.version.Version(version("transformer-engine"))
if te_version >= packaging.version.Version("0.8.0"):
self.transformer_engine_v_0_8 = True
if te_version >= packaging.version.Version("0.10.0"):
self.transformer_engine_v_0_10 = True
if te_version >= packaging.version.Version("0.11.0"):
self.transformer_engine_v_0_11 = True
del version, packaging
assert not args.squared_relu, "TransformerEngine does not support squared relu activation."
self.use_fp8 = args.fp8 is not None
self.fp8_recipe = None self.fp8_recipe = None
self.fp8_group = None self.fp8_group = None
if self.use_fp8: if self.use_fp8:
self.fp8_group = mpu.get_data_parallel_group() assert args.transformer_impl == 'transformer_engine', \
if args.fp8_e4m3: 'transformer-engine required for fp8 training and inference'
self.fp8_group = mpu.get_amax_reduction_group()
if args.fp8 == "e4m3":
fp8_format = transformer_engine.common.recipe.Format.E4M3 fp8_format = transformer_engine.common.recipe.Format.E4M3
elif args.fp8_hybrid: elif args.fp8 == "hybrid":
fp8_format = transformer_engine.common.recipe.Format.HYBRID fp8_format = transformer_engine.common.recipe.Format.HYBRID
else:
raise ValueError("The DelayedScaling recipe only supports E4M3 and HYBRID formats.")
self.fp8_recipe = transformer_engine.common.recipe.DelayedScaling( self.fp8_recipe = transformer_engine.common.recipe.DelayedScaling(
margin=args.fp8_margin, margin=args.fp8_margin,
interval=args.fp8_interval, interval=args.fp8_interval,
...@@ -1030,63 +1365,87 @@ class ParallelTransformer(MegatronModule): ...@@ -1030,63 +1365,87 @@ class ParallelTransformer(MegatronModule):
self.num_microbatches_in_previous_step = -1 self.num_microbatches_in_previous_step = -1
self.microbatch_count = 0 self.microbatch_count = 0
self.checkpoint_core_attention = args.recompute_granularity == 'selective' self.checkpoint_core_attention = config.recompute_granularity == 'selective'
# Number of layers. # Number of layers.
self.num_layers = _get_num_layers( self.num_layers = _get_num_layers(args, model_type,
args, layer_type==LayerType.decoder)
args.model_type == ModelType.encoder_and_decoder,
layer_type == LayerType.decoder) self.drop_path_rates = [
rate.item() for rate in
torch.linspace(0, self.drop_path_rate, config.num_layers)]
self.drop_path_rates = [rate.item() for rate in torch.linspace(0, self.drop_path_rate, args.num_layers)] self.retro_layer_numbers = None
if model_type == ModelType.retro_decoder:
retro_layer_start = 6 if config.num_layers <= 15 else 9
self.retro_layer_numbers = \
np.arange(retro_layer_start, args.num_layers + 1, 3).tolist()
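# e.g. (illustrative): num_layers=24 yields retro_layer_numbers = [9, 12, 15, 18, 21, 24]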
if model_type == ModelType.retro_encoder:
self.retro_layer_numbers = [1]
# Transformer layers. # Transformer layers.
if args.retro_add_retriever:
assert self.recompute_granularity != 'full', \
"Full recompute not supported for Retro."
assert args.transformer_impl == 'local', \
"Transformer engine does not support Retro layers."
def build_layer(layer_number): def build_layer(layer_number):
if args.transformer_impl == 'local': if args.transformer_impl == 'local':
current_layer_type = _get_layer_type(
model_type, layer_type, self.retro_layer_numbers,
layer_number)
return ParallelTransformerLayer( return ParallelTransformerLayer(
init_method, config,
output_layer_init_method,
layer_number, layer_number,
layer_type=layer_type, layer_type=current_layer_type,
self_attn_mask_type=self_attn_mask_type, self_attn_mask_type=self_attn_mask_type,
drop_path_rate=self.drop_path_rates[layer_number - 1]) drop_path_rate=self.drop_path_rates[layer_number - 1])
else: else:
# This argument is only available from TE v0.10 onwards.
extra_transformer_engine_kwargs = {}
if self.transformer_engine_v_0_8:
extra_transformer_engine_kwargs["bias"] = args.add_bias_linear
if self.transformer_engine_v_0_10:
extra_transformer_engine_kwargs["activation"] = "swiglu" if args.swiglu else "gelu"
if self.transformer_engine_v_0_11:
extra_transformer_engine_kwargs["normalization"] = args.normalization
return transformer_engine.pytorch.TransformerLayer( return transformer_engine.pytorch.TransformerLayer(
args.hidden_size, config.hidden_size,
args.ffn_hidden_size, config.ffn_hidden_size,
args.num_attention_heads, config.num_attention_heads,
layernorm_epsilon=args.layernorm_epsilon, layernorm_epsilon=config.layernorm_epsilon,
hidden_dropout=args.hidden_dropout, hidden_dropout=config.hidden_dropout,
attention_dropout=args.attention_dropout, attention_dropout=config.attention_dropout,
init_method=init_method, init_method=config.init_method,
output_layer_init_method=output_layer_init_method, output_layer_init_method=config.output_layer_init_method,
layer_number=layer_number, layer_number=layer_number,
kv_channels=args.kv_channels, kv_channels=config.kv_channels,
self_attn_mask_type=self_attn_mask_type.name, self_attn_mask_type=self_attn_mask_type.name,
tp_group=mpu.get_tensor_model_parallel_group(), tp_group=mpu.get_tensor_model_parallel_group(),
get_rng_state_tracker=tensor_parallel.get_cuda_rng_tracker, get_rng_state_tracker=tensor_parallel.get_cuda_rng_tracker,
fuse_wgrad_accumulation=args.gradient_accumulation_fusion, fuse_wgrad_accumulation=config.gradient_accumulation_fusion,
apply_query_key_layer_scaling=args.apply_query_key_layer_scaling, apply_query_key_layer_scaling=config.apply_query_key_layer_scaling,
attention_softmax_in_fp32=args.attention_softmax_in_fp32, attention_softmax_in_fp32=config.attention_softmax_in_fp32,
seq_length=args.seq_length, seq_length=args.seq_length,
micro_batch_size=args.micro_batch_size, micro_batch_size=args.micro_batch_size,
sequence_parallel=args.sequence_parallel, sequence_parallel=config.sequence_parallel,
params_dtype=args.params_dtype, params_dtype=config.params_dtype,
apply_residual_connection_post_layernorm=args.apply_residual_connection_post_layernorm, apply_residual_connection_post_layernorm=config.apply_residual_connection_post_layernorm,
output_layernorm=False, output_layernorm=False,
layer_type="encoder", layer_type="encoder",
drop_path_rate=self.drop_path_rates[layer_number - 1], drop_path_rate=self.drop_path_rates[layer_number - 1],
set_parallel_mode=True, set_parallel_mode=True,
fuse_qkv_params=True) fuse_qkv_params=True,
**extra_transformer_engine_kwargs)
if args.virtual_pipeline_model_parallel_size is not None: if config.virtual_pipeline_model_parallel_size is not None:
assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \ assert config.num_layers % config.virtual_pipeline_model_parallel_size == 0, \
'num_layers_per_stage must be divisible by ' \ 'num_layers_per_stage must be divisible by ' \
'virtual_pipeline_model_parallel_size' 'virtual_pipeline_model_parallel_size'
assert args.model_type != ModelType.encoder_and_decoder assert args.model_type != ModelType.encoder_and_decoder
# Number of layers in each model chunk is the number of layers in the stage, # Number of layers in each model chunk is the number of layers in the stage,
# divided by the number of model chunks in a stage. # divided by the number of model chunks in a stage.
self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size self.num_layers = self.num_layers // config.virtual_pipeline_model_parallel_size
# With 8 layers, 2 stages, and 4 model chunks, we want an assignment of # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
# layers to stages like (each list is a model chunk): # layers to stages like (each list is a model chunk):
# Stage 0: [0] [2] [4] [6] # Stage 0: [0] [2] [4] [6]
...@@ -1096,7 +1455,7 @@ class ParallelTransformer(MegatronModule): ...@@ -1096,7 +1455,7 @@ class ParallelTransformer(MegatronModule):
# Stage 0: [0, 1] [4, 5] # Stage 0: [0, 1] [4, 5]
# Stage 1: [2, 3] [6, 7] # Stage 1: [2, 3] [6, 7]
offset = mpu.get_virtual_pipeline_model_parallel_rank() * ( offset = mpu.get_virtual_pipeline_model_parallel_rank() * (
args.num_layers // args.virtual_pipeline_model_parallel_size) + \ config.num_layers // config.virtual_pipeline_model_parallel_size) + \
(mpu.get_pipeline_model_parallel_rank() * self.num_layers) (mpu.get_pipeline_model_parallel_rank() * self.num_layers)
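# e.g. for the 8-layer / 2-stage / 4-chunk case above: self.num_layers == 1 and
# offset = virtual_rank * (8 // 4) + pipeline_rank * 1, which reproduces the
# interleaved assignment listed in the comment.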
else: else:
# Each stage gets a contiguous set of layers. # Each stage gets a contiguous set of layers.
...@@ -1126,13 +1485,24 @@ class ParallelTransformer(MegatronModule): ...@@ -1126,13 +1485,24 @@ class ParallelTransformer(MegatronModule):
self.layers = torch.nn.ModuleList( self.layers = torch.nn.ModuleList(
[build_layer(i + 1 + offset) for i in range(self.num_layers)]) [build_layer(i + 1 + offset) for i in range(self.num_layers)])
# Update dropout rate for Retro encoder.
if model_type == ModelType.retro_encoder:
for layer in self.layers:
if layer.self_attention.use_flash_attn:
layer.self_attention.core_attention_flash.dropout_p = \
args.retro_encoder_attention_dropout
else:
layer.self_attention.core_attention.attention_dropout.p =\
args.retro_encoder_attention_dropout
layer.hidden_dropout = args.retro_encoder_hidden_dropout
if self.post_process and self.post_layer_norm: if self.post_process and self.post_layer_norm:
# Final layer norm before output. # Final layer norm before output.
self.final_layernorm = LayerNorm( self.final_layernorm = LayerNorm(
args.hidden_size, config.hidden_size,
eps=args.layernorm_epsilon, eps=config.layernorm_epsilon,
no_persist_layer_norm=args.no_persist_layer_norm, no_persist_layer_norm=args.no_persist_layer_norm,
sequence_parallel=args.sequence_parallel, sequence_parallel=config.sequence_parallel,
apply_layernorm_1p=args.apply_layernorm_1p) apply_layernorm_1p=args.apply_layernorm_1p)
def _get_layer(self, layer_number): def _get_layer(self, layer_number):
...@@ -1142,40 +1512,42 @@ class ParallelTransformer(MegatronModule): ...@@ -1142,40 +1512,42 @@ class ParallelTransformer(MegatronModule):
encoder_output, enc_dec_attn_mask, encoder_output, enc_dec_attn_mask,
rotary_pos_emb, is_first_microbatch): rotary_pos_emb, is_first_microbatch):
"""Forward method with activation checkpointing.""" """Forward method with activation checkpointing."""
def custom(start, end, is_transformer_engine=False): def custom(start, end):
def custom_forward(*args, **kwargs): def custom_forward(*args, **kwargs):
x_, *args = args x_, *args = args
for index in range(start, end): for index in range(start, end):
layer = self._get_layer(index) layer = self._get_layer(index)
x_ = layer(x_, *args, **kwargs) x_ = layer(x_, *args, **kwargs)
return x_ return x_
def custom_forward_transformer_engine(*args, **kwargs): return custom_forward
return custom_forward(*args, is_first_microbatch=is_first_microbatch, **kwargs)
if not is_transformer_engine: te_forward_kwargs = {}
return custom_forward if self.transformer_impl == 'transformer_engine':
else: te_forward_kwargs['is_first_microbatch'] = is_first_microbatch
return custom_forward_transformer_engine if self.transformer_engine_v_0_10:
te_forward_kwargs['rotary_pos_emb'] = rotary_pos_emb
if self.recompute_method == 'uniform': if self.recompute_method == 'uniform':
# Uniformly divide the total number of Transformer layers and checkpoint # Uniformly divide the total number of Transformer layers and
# the input activation of each divided chunk. # checkpoint the input activation of each divided chunk.
# A method to further reduce memory usage by reducing the number of checkpoints. # A method to further reduce memory usage by reducing the number of checkpoints.
l = 0 l = 0
while l < self.num_layers: while l < self.num_layers:
if self.transformer_impl == 'transformer_engine': if self.transformer_impl == 'transformer_engine':
hidden_states = transformer_engine.pytorch.distributed.checkpoint( hidden_states = transformer_engine.pytorch.checkpoint(
custom(l, l + self.recompute_num_layers, is_transformer_engine=True), custom(l, l + self.recompute_num_layers),
self.distribute_saved_activations, self.distribute_saved_activations,
tensor_parallel.get_cuda_rng_tracker, tensor_parallel.get_cuda_rng_tracker,
mpu.get_tensor_model_parallel_group(), mpu.get_tensor_model_parallel_group(),
hidden_states, attention_mask, encoder_output, hidden_states, attention_mask, encoder_output,
enc_dec_attn_mask, rotary_pos_emb) enc_dec_attn_mask, **te_forward_kwargs)
else: else:
hidden_states = tensor_parallel.checkpoint( hidden_states = tensor_parallel.checkpoint(
custom(l, l + self.recompute_num_layers), custom(l, l + self.recompute_num_layers),
self.distribute_saved_activations, self.distribute_saved_activations,
hidden_states, attention_mask, encoder_output, hidden_states, attention_mask,
enc_dec_attn_mask, rotary_pos_emb) encoder_output, enc_dec_attn_mask,
None, None, None, None, rotary_pos_emb)
l += self.recompute_num_layers l += self.recompute_num_layers
...@@ -1186,28 +1558,30 @@ class ParallelTransformer(MegatronModule): ...@@ -1186,28 +1558,30 @@ class ParallelTransformer(MegatronModule):
for l in range(self.num_layers): for l in range(self.num_layers):
if l < self.recompute_num_layers: if l < self.recompute_num_layers:
if self.transformer_impl == 'transformer_engine': if self.transformer_impl == 'transformer_engine':
hidden_states = transformer_engine.pytorch.distributed.checkpoint( hidden_states = transformer_engine.pytorch.checkpoint(
custom(l, l + 1, is_transformer_engine=True), custom(l, l + 1),
self.distribute_saved_activations, self.distribute_saved_activations,
tensor_parallel.get_cuda_rng_tracker, tensor_parallel.get_cuda_rng_tracker,
mpu.get_tensor_model_parallel_group(), mpu.get_tensor_model_parallel_group(),
hidden_states, attention_mask, encoder_output, hidden_states, attention_mask, encoder_output,
enc_dec_attn_mask, rotary_pos_emb) enc_dec_attn_mask, **te_forward_kwargs)
else: else:
hidden_states = tensor_parallel.checkpoint( hidden_states = tensor_parallel.checkpoint(
custom(l, l + 1), custom(l, l + 1),
self.distribute_saved_activations, self.distribute_saved_activations,
hidden_states, attention_mask, encoder_output, hidden_states, attention_mask,
enc_dec_attn_mask, rotary_pos_emb) encoder_output, enc_dec_attn_mask,
None, None, None, None, rotary_pos_emb)
else: else:
if self.transformer_impl == 'transformer_engine': if self.transformer_impl == 'transformer_engine':
hidden_states = custom(l, l + 1, is_transformer_engine=True)( hidden_states = custom(l, l + 1)(
hidden_states, attention_mask, encoder_output, hidden_states, attention_mask, encoder_output,
enc_dec_attn_mask, rotary_pos_emb) enc_dec_attn_mask, **te_forward_kwargs)
else: else:
hidden_states = custom(l, l + 1)( hidden_states = custom(l, l + 1)(
hidden_states, attention_mask, encoder_output, hidden_states, attention_mask,
enc_dec_attn_mask, rotary_pos_emb) encoder_output, enc_dec_attn_mask,
None, None, None, None, rotary_pos_emb)
else: else:
raise ValueError("Invalid activation recompute method.") raise ValueError("Invalid activation recompute method.")
...@@ -1225,7 +1599,11 @@ class ParallelTransformer(MegatronModule): ...@@ -1225,7 +1599,11 @@ class ParallelTransformer(MegatronModule):
def forward(self, hidden_states, attention_mask, def forward(self, hidden_states, attention_mask,
encoder_output=None, enc_dec_attn_mask=None, encoder_output=None, enc_dec_attn_mask=None,
inference_params=None, rotary_pos_emb=None): retriever_input=None,
retriever_output=None,
retriever_attn_mask=None,
inference_params=None,
rotary_pos_emb=None):
# hidden_states: [s, b, h] # hidden_states: [s, b, h]
# Checks. # Checks.
...@@ -1258,11 +1636,13 @@ class ParallelTransformer(MegatronModule): ...@@ -1258,11 +1636,13 @@ class ParallelTransformer(MegatronModule):
keep_graph=True, keep_graph=True,
) )
# RNG context.
if self.sequence_parallel: if self.sequence_parallel:
rng_context = tensor_parallel.get_cuda_rng_tracker().fork() rng_context = tensor_parallel.get_cuda_rng_tracker().fork()
else: else:
rng_context = nullcontext() rng_context = nullcontext()
# Forward layers.
with rng_context: with rng_context:
# The fp8_autocast context manager is a no-op when enabled=True # The fp8_autocast context manager is a no-op when enabled=True
# The if...else serves to short circuit name resolution for fp8_autocast # The if...else serves to short circuit name resolution for fp8_autocast
...@@ -1290,12 +1670,18 @@ class ParallelTransformer(MegatronModule): ...@@ -1290,12 +1670,18 @@ class ParallelTransformer(MegatronModule):
'encoder_output': encoder_output, 'encoder_output': encoder_output,
'enc_dec_attn_mask': enc_dec_attn_mask, 'enc_dec_attn_mask': enc_dec_attn_mask,
'inference_params': inference_params, 'inference_params': inference_params,
'rotary_pos_emb': rotary_pos_emb,
} }
if self.transformer_impl == 'transformer_engine': if self.transformer_impl == 'transformer_engine':
forward_kwargs['is_first_microbatch'] = is_first_microbatch forward_kwargs['is_first_microbatch'] = is_first_microbatch
forward_kwargs['checkpoint_core_attention'] = self.checkpoint_core_attention forward_kwargs['checkpoint_core_attention'] = self.checkpoint_core_attention
if self.transformer_engine_v_0_10:
forward_kwargs['rotary_pos_emb'] = rotary_pos_emb
else:
forward_kwargs['rotary_pos_emb'] = rotary_pos_emb
forward_kwargs['retriever_input'] = retriever_input
forward_kwargs['retriever_output'] = retriever_output
forward_kwargs['retriever_attn_mask'] = retriever_attn_mask
for index in range(self.num_layers): for index in range(self.num_layers):
layer = self._get_layer(index) layer = self._get_layer(index)
...@@ -1305,6 +1691,14 @@ class ParallelTransformer(MegatronModule): ...@@ -1305,6 +1691,14 @@ class ParallelTransformer(MegatronModule):
attention_mask, attention_mask,
**forward_kwargs) **forward_kwargs)
# First Retro decoder layer returns both hidden_states
# and retriever_output. Make retriever_output available
# to subsequent Retro layers.
if isinstance(hidden_states, tuple):
assert len(hidden_states) == 2
hidden_states, retriever_output = hidden_states
forward_kwargs["retriever_output"] = retriever_output
# Skip counter update for eval and activation checkpointing # Skip counter update for eval and activation checkpointing
if torch.is_grad_enabled() and self.training: if torch.is_grad_enabled() and self.training:
self.microbatch_count += 1 self.microbatch_count += 1
......
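Two details in the forward() changes above are easy to misread in the flattened diff. First, the else branch that sets retriever_input, retriever_output, and retriever_attn_mask pairs with the transformer_impl check, not with the transformer_engine_v_0_10 check, so the Retro retriever tensors are only passed to layers on the local (non-Transformer-Engine) path. Second, the first Retro decoder layer returns a (hidden_states, retriever_output) tuple, and the loop stores retriever_output back into the shared kwargs so later Retro layers can consume it. A self-contained sketch of that loop pattern (the layer objects and kwargs names are stand-ins for the ones built in this file):

def run_layers(layers, hidden_states, attention_mask, forward_kwargs):
    # Drive every transformer layer with a shared kwargs dict.
    for layer in layers:
        out = layer(hidden_states, attention_mask, **forward_kwargs)
        if isinstance(out, tuple):
            # The first Retro decoder layer also returns the retriever output;
            # stash it in the shared kwargs so later Retro layers can reuse it.
            hidden_states, retriever_output = out
            forward_kwargs['retriever_output'] = retriever_output
        else:
            hidden_states = out
    return hidden_states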
...@@ -13,7 +13,7 @@ from megatron.model.module import MegatronModule ...@@ -13,7 +13,7 @@ from megatron.model.module import MegatronModule
class VitClassificationModel(MegatronModule): class VitClassificationModel(MegatronModule):
"""Vision Transformer Model.""" """Vision Transformer Model."""
def __init__(self, num_classes, finetune=False, def __init__(self, config, num_classes, finetune=False,
pre_process=True, post_process=True): pre_process=True, post_process=True):
super(VitClassificationModel, self).__init__() super(VitClassificationModel, self).__init__()
args = get_args() args = get_args()
...@@ -24,6 +24,7 @@ class VitClassificationModel(MegatronModule): ...@@ -24,6 +24,7 @@ class VitClassificationModel(MegatronModule):
self.pre_process = pre_process self.pre_process = pre_process
self.post_process = post_process self.post_process = post_process
self.backbone = VitBackbone( self.backbone = VitBackbone(
config=config,
pre_process=self.pre_process, pre_process=self.pre_process,
post_process=self.post_process, post_process=self.post_process,
single_token_output=True single_token_output=True
......
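The vision classification model above now receives a transformer config as its first constructor argument instead of pulling initialization settings from the global args inside the backbone. A hypothetical construction sketch, assuming core_transformer_config_from_args is available in megatron.arguments in this revision and that args.num_classes carries the usual vision classification argument:

from megatron import get_args
from megatron.arguments import core_transformer_config_from_args
from megatron.model.vision.classification import VitClassificationModel

def model_provider(pre_process=True, post_process=True):
    # Build the transformer config once from the parsed args and hand it to
    # the vision model instead of letting the model read init settings itself.
    args = get_args()
    config = core_transformer_config_from_args(args)
    return VitClassificationModel(config=config,
                                  num_classes=args.num_classes,
                                  pre_process=pre_process,
                                  post_process=post_process)

The same config object is threaded into the DINO student/teacher backbones and the inpainting model in the diffs that follow.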
...@@ -173,11 +173,12 @@ def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, ...@@ -173,11 +173,12 @@ def cosine_scheduler(base_value, final_value, epochs, niter_per_ep,
return schedule return schedule
def get_student_backbone_and_num_features(pre_process=True, post_process=True): def get_student_backbone_and_num_features(config, pre_process=True, post_process=True):
args = get_args() args = get_args()
if args.vision_backbone_type == 'vit': if args.vision_backbone_type == 'vit':
student = VitBackbone(pre_process=pre_process, student = VitBackbone(config,
pre_process=pre_process,
post_process=post_process, post_process=post_process,
drop_path_rate=0.1, drop_path_rate=0.1,
single_token_output=True) single_token_output=True)
...@@ -194,11 +195,12 @@ def get_student_backbone_and_num_features(pre_process=True, post_process=True): ...@@ -194,11 +195,12 @@ def get_student_backbone_and_num_features(pre_process=True, post_process=True):
return student, num_features return student, num_features
def get_teacher_backbone_and_num_features(pre_process=True, post_process=True): def get_teacher_backbone_and_num_features(config, pre_process=True, post_process=True):
args = get_args() args = get_args()
if args.vision_backbone_type == 'vit': if args.vision_backbone_type == 'vit':
teacher = VitBackbone(pre_process=pre_process, teacher = VitBackbone(config,
pre_process=pre_process,
post_process=post_process, post_process=post_process,
single_token_output=True) single_token_output=True)
num_features = args.hidden_size num_features = args.hidden_size
...@@ -215,7 +217,7 @@ def get_teacher_backbone_and_num_features(pre_process=True, post_process=True): ...@@ -215,7 +217,7 @@ def get_teacher_backbone_and_num_features(pre_process=True, post_process=True):
class DINOPretrainModel(MegatronModule): class DINOPretrainModel(MegatronModule):
def __init__(self, pre_process=True, post_process=True): def __init__(self, config, pre_process=True, post_process=True):
super(DINOPretrainModel, self).__init__() super(DINOPretrainModel, self).__init__()
args = get_args() args = get_args()
self.out_dim = 65536 self.out_dim = 65536
...@@ -234,7 +236,7 @@ class DINOPretrainModel(MegatronModule): ...@@ -234,7 +236,7 @@ class DINOPretrainModel(MegatronModule):
self.momentum_teacher = 0.996 self.momentum_teacher = 0.996
student_backbone, num_features = \ student_backbone, num_features = \
get_student_backbone_and_num_features(pre_process, post_process) get_student_backbone_and_num_features(config, pre_process, post_process)
self.student = MultiCropWrapper( self.student = MultiCropWrapper(
student_backbone, student_backbone,
...@@ -249,7 +251,7 @@ class DINOPretrainModel(MegatronModule): ...@@ -249,7 +251,7 @@ class DINOPretrainModel(MegatronModule):
) )
teacher_backbone, num_features = \ teacher_backbone, num_features = \
get_teacher_backbone_and_num_features(pre_process, post_process) get_teacher_backbone_and_num_features(config, pre_process, post_process)
self.teacher = MultiCropWrapper( self.teacher = MultiCropWrapper(
teacher_backbone, teacher_backbone,
DINOHead(num_features, self.out_dim) DINOHead(num_features, self.out_dim)
......
...@@ -18,14 +18,15 @@ from megatron.model.vision.utils import resize_ ...@@ -18,14 +18,15 @@ from megatron.model.vision.utils import resize_
class VitInpaintingModel(MegatronModule): class VitInpaintingModel(MegatronModule):
def __init__(self, pre_process=True, post_process=True): def __init__(self, config, pre_process=True, post_process=True):
super(VitInpaintingModel, self).__init__() super(VitInpaintingModel, self).__init__()
args = get_args() args = get_args()
self.pre_process = pre_process self.pre_process = pre_process
self.post_process = post_process self.post_process = post_process
self.hidden_size = args.hidden_size self.hidden_size = config.hidden_size
self.backbone = VitBackbone( self.backbone = VitBackbone(
config=config,
pre_process=self.pre_process, pre_process=self.pre_process,
post_process=self.post_process, post_process=self.post_process,
class_token=False, class_token=False,
......
# --------------------------------------------------------------- # Copyright (c) 2023, NVIDIA Corporation. All rights reserved.
# Copyright (c) 2021, NVIDIA Corporation. All rights reserved.
#
# This work is licensed under the NVIDIA Source Code License
# found in the LICENSE file in the root directory of this
# source tree.
# ---------------------------------------------------------------
import math import math
import torch import torch
import torch.nn as nn import torch.nn as nn
......
...@@ -130,24 +130,17 @@ class VitBackbone(MegatronModule): ...@@ -130,24 +130,17 @@ class VitBackbone(MegatronModule):
"""Vision Transformer Model.""" """Vision Transformer Model."""
def __init__(self, def __init__(self,
config,
pre_process=True, pre_process=True,
post_process=True, post_process=True,
class_token=True, class_token=True,
single_token_output=False, single_token_output=False,
post_layer_norm=True, post_layer_norm=True,
drop_path_rate=0.0): drop_path_rate=0.0):
super(VitBackbone, self).__init__(share_word_embeddings=False) super(VitBackbone, self).__init__(share_embeddings_and_output_weights=False)
args = get_args() args = get_args()
self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
if args.init_method_xavier_uniform:
self.init_method = torch.nn.init.xavier_uniform_
self.scaled_init_method = torch.nn.init.xavier_uniform_
else:
self.init_method = init_method_normal(args.init_method_std)
self.scaled_init_method = scaled_init_method_normal(
args.init_method_std, args.num_layers
)
self.pre_process = pre_process self.pre_process = pre_process
self.post_process = post_process self.post_process = post_process
...@@ -202,8 +195,7 @@ class VitBackbone(MegatronModule): ...@@ -202,8 +195,7 @@ class VitBackbone(MegatronModule):
# Transformer # Transformer
self.transformer = ParallelTransformer( self.transformer = ParallelTransformer(
self.init_method, config,
self.scaled_init_method,
pre_process=self.pre_process, pre_process=self.pre_process,
post_process=self.post_process, post_process=self.post_process,
post_layer_norm=self.post_layer_norm, post_layer_norm=self.post_layer_norm,
......
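With the init-method selection deleted from VitBackbone above, the choice of initializer is expected to arrive on the config object that is handed straight to ParallelTransformer. A sketch of how such a config could be populated, assuming it exposes init_method and output_layer_init_method fields and reusing the helpers from megatron.model.utils that the removed code called:

import torch
from megatron.model.utils import init_method_normal, scaled_init_method_normal

def populate_init_methods(config, args):
    # Mirror of the branch removed from VitBackbone: the initializer choice now
    # lives on the config, and ParallelTransformer reads it from there.
    if args.init_method_xavier_uniform:
        config.init_method = torch.nn.init.xavier_uniform_
        config.output_layer_init_method = torch.nn.init.xavier_uniform_
    else:
        config.init_method = init_method_normal(args.init_method_std)
        config.output_layer_init_method = scaled_init_method_normal(
            args.init_method_std, args.num_layers)
    return config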
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
"""Gradient clipping.""" """Gradient clipping."""
import torch import torch
from torch._six import inf from torch import inf
from apex.multi_tensor_apply import multi_tensor_applier from apex.multi_tensor_apply import multi_tensor_applier
import amp_C import amp_C
......
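The gradient-clipping change swaps the private torch._six shim, which newer PyTorch releases no longer ship, for the public torch.inf symbol. A version-tolerant import sketch for code that still has to run on older PyTorch builds (purely an illustration; the commit itself simply imports inf from torch):

try:
    from torch import inf
except ImportError:
    # Older PyTorch releases only exposed inf through the private torch._six
    # shim, which newer releases have removed.
    from torch._six import inf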