# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

"""Retro Transformer.

** Special note about this file **

Many classes and methods in this file directly parallel those in
transformer.py in name and utility. However, due to 1) subtle changes in the
code over time (i.e., transposes and contexts), and 2) other code that is soon
to be merged, this file will *temporarily* remain as is, until a larger
integration is complete.
"""

import math
import numpy as np
import torch
import torch.nn.functional as F

from megatron import get_args, get_retro_args, get_tensorboard_writer
from megatron.core import parallel_state
from megatron.core import tensor_parallel
from megatron.core import utils as core_utils
from megatron.core.enums import ModelType
from megatron.model.enums import AttnMaskType, LayerType, AttnType
from megatron.model import LayerNorm
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu, init_method_normal

from .module import MegatronModule
from .transformer import _get_num_layers, ParallelMLP, NoopTransformerLayer

""" We use the following notation throughout this file:
     h: hidden size
     n: number of attention heads
     p: number of model parallel partitions
     np: n/p
     hp: h/p
     hn: h/n
     b: batch size
     s: sequence length
     l: number of layers
    Transformer takes input of size [s, b, h] and returns a
    tensor of the same size. We use the following arguments:
        hyperparameters: transformer hyperparameters
"""


class DropPath(MegatronModule):
    """Drop paths (Stochastic Depth) per sample
    (when applied in main path of residual blocks).

    *Note: differs from transformer.py/DropPath in hidden_state transpose.
    """

    def __init__(self, drop_prob=0.):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_state):
        if self.drop_prob == 0. or not self.training:
            return hidden_state
        keep_prob = 1 - self.drop_prob
        # work with diff dim tensors, not just 2D ConvNets
        shape = (hidden_state.shape[0],) + (1,) * (hidden_state.ndim - 1)
        random_tensor = keep_prob + \
            torch.rand(shape, dtype=hidden_state.dtype, device=hidden_state.device)
        random_tensor.floor_()  # binarize
        output = hidden_state.div(keep_prob) * random_tensor
        return output


class SwitchMLP(MegatronModule):
    """
    Routes input to one of N MLP "experts"
    """

    def __init__(self, init_method, output_layer_init_method):
        super(SwitchMLP, self).__init__()
        args = get_args()
        self.router = torch.nn.Linear(args.hidden_size, args.num_experts)
        self.experts = torch.nn.ModuleList()
        for i in range(args.num_experts):
            self.experts.append(ParallelMLP(init_method, output_layer_init_method))

    def forward(self, hidden_states):
        # hidden_states: [b, s, h]
        b = hidden_states.size(0)
        s = hidden_states.size(1)
        h = hidden_states.size(2)
        route = self.router(hidden_states)
        route = torch.nn.functional.softmax(route, dim=2)
        max_prob, max_ind = torch.max(route, dim=2)
        max_prob = torch.unsqueeze(max_prob, 2)  # [b s 1]
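        # Top-1 ("switch") routing: each token is dispatched to the single
        # expert with the highest router probability, and the expert output is
        # rescaled by that probability so the router still receives gradient
        # through the routing decision.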
        # TODO (rprenger) TODO this could be made easier to read
        # Converting [b, s, h] to [b*s, h].
        # Each vector could be routed differently
        hidden_states = hidden_states.view(-1, hidden_states.size(2))  # [b*s h]
        max_prob = max_prob.view(-1, max_prob.size(2))  # [b*s 1]
        max_ind = max_ind.view(-1)  # [b*s]

        output_total = torch.empty_like(hidden_states)
        output_bias_total = torch.empty_like(hidden_states)
        # TODO (rprenger) This does each expert in serial, but it could be parallelized
        for expert_num, expert in enumerate(self.experts):
            local_indices = (max_ind == expert_num).nonzero()
            hidden = hidden_states[local_indices, :]
            output, output_bias = expert(hidden)
            output_bias = output_bias.expand_as(output)
            output_total[local_indices, :] = output
            output_bias_total[local_indices, :] = output_bias

        output_total = output_total * max_prob
        output_bias_total = output_bias_total * max_prob
        output_total = output_total.view(b, s, h)
        output_bias_total = output_bias_total.view(b, s, h)

        return output_total, output_bias_total


class ParallelAttention(MegatronModule):
    """Parallel self-attention layer abstract class.

    Self-attention layer takes input with size [s, b, h]
    and returns output of the same size.
    """

    def __init__(self, init_method,
                 output_layer_init_method, layer_number,
                 attention_type=AttnType.self_attn,
                 attn_mask_type=AttnMaskType.padding):
        super(ParallelAttention, self).__init__()
        args = get_args()
        self.fp16 = args.fp16
        self.bf16 = args.bf16

        self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True
        self.layer_number = max(1, layer_number)
        self.attention_type = attention_type
        self.attn_mask_type = attn_mask_type
        self.params_dtype = args.params_dtype

        projection_size = args.kv_channels * args.num_attention_heads

        # Per attention head and per partition values.
        world_size = parallel_state.get_tensor_model_parallel_world_size()
        self.hidden_size_per_partition = core_utils.divide(projection_size,
                                                           world_size)
        self.hidden_size_per_attention_head = core_utils.divide(
            projection_size, args.num_attention_heads)
        self.num_attention_heads_per_partition = core_utils.divide(
            args.num_attention_heads, world_size)

        # Strided linear layer.
        if attention_type == AttnType.self_attn:
            self.query_key_value = tensor_parallel.ColumnParallelLinear(
                args.hidden_size,
                3 * projection_size,
                gather_output=False,
                init_method=init_method)
        else:
            assert attention_type == AttnType.cross_attn
            self.query = tensor_parallel.ColumnParallelLinear(
                args.hidden_size,
                projection_size,
                gather_output=False,
                init_method=init_method)

            self.key_value = tensor_parallel.ColumnParallelLinear(
                args.hidden_size,
                2 * projection_size,
                gather_output=False,
                init_method=init_method)

        coeff = None
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        if self.apply_query_key_layer_scaling:
            coeff = self.layer_number
            self.norm_factor *= coeff

        self.scale_mask_softmax = FusedScaleMaskSoftmax(
            self.fp16, self.bf16,
            self.attn_mask_type,
            args.masked_softmax_fusion,
            attention_mask_func,
            self.attention_softmax_in_fp32,
            coeff)

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different number of parallel partitions but
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(args.attention_dropout)
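        # The per-rank context is [sq, b, hp]; the RowParallelLinear below
        # projects it back to the full hidden size and reduces the partial
        # results across tensor-parallel ranks. skip_bias_add defers the bias
        # so it can be fused into the dropout-add in the enclosing layer.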
        # Output.
        self.dense = tensor_parallel.RowParallelLinear(
            projection_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True)

    def _allocate_memory(self, inference_max_sequence_len, batch_size):
        return torch.empty(
            inference_max_sequence_len,
            batch_size,
            self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
            dtype=self.params_dtype,
            device=torch.cuda.current_device())

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, inference_params=None):
        # hidden_states: [sq, b, h]

        # =================================================
        # Pre-allocate memory for key-values for inference.
        # =================================================
        if inference_params:
            if self.layer_number not in inference_params.key_value_memory_dict:
                inf_max_seq_len = inference_params.max_sequence_len
                inf_max_batch_size = inference_params.max_batch_size
                inference_key_memory = self._allocate_memory(
                    inf_max_seq_len, inf_max_batch_size)
                inference_value_memory = self._allocate_memory(
                    inf_max_seq_len, inf_max_batch_size)
                inference_params.key_value_memory_dict[self.layer_number] = (
                    inference_key_memory, inference_value_memory)
            else:
                inference_key_memory, inference_value_memory = \
                    inference_params.key_value_memory_dict[self.layer_number]

        # =====================
        # Query, Key, and Value
        # =====================

        if self.attention_type == AttnType.self_attn:
            # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
            mixed_x_layer, _ = self.query_key_value(hidden_states)

            # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
            new_tensor_shape = mixed_x_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 3 * self.hidden_size_per_attention_head)
            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

            # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
            (query_layer,
             key_layer,
             value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_x_layer, 3)
        else:
            # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
            mixed_kv_layer, _ = self.key_value(encoder_output)

            # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
            new_tensor_shape = mixed_kv_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 2 * self.hidden_size_per_attention_head)
            mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)

            # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
            (key_layer,
             value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_kv_layer, 2)

            # Attention head [sq, b, h] --> [sq, b, hp]
            query_layer, _ = self.query(hidden_states)
            # [sq, b, hp] --> [sq, b, np, hn]
            new_tensor_shape = query_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 self.hidden_size_per_attention_head)
            query_layer = query_layer.view(*new_tensor_shape)

        # ==================================
        # Adjust key and value for inference
        # ==================================

        if inference_params:
            batch_start = inference_params.batch_size_offset
            batch_end = batch_start + key_layer.size(1)
            assert batch_end <= inference_key_memory.size(1)
            sequence_start = inference_params.sequence_len_offset
            sequence_end = sequence_start + key_layer.size(0)
            assert sequence_end <= inference_key_memory.size(0)
            # Copy key and values.
            inference_key_memory[sequence_start:sequence_end,
                                 batch_start:batch_end, ...] = key_layer
            inference_value_memory[sequence_start:sequence_end,
                                   batch_start:batch_end, ...] = value_layer
            key_layer = inference_key_memory[
                :sequence_end, batch_start:batch_end, ...]
            value_layer = inference_value_memory[
                :sequence_end, batch_start:batch_end, ...]
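        # At this point query_layer is [sq, b, np, hn] and key_layer /
        # value_layer are [sk, b, np, hn] (with sk >= sq when serving from the
        # inference KV cache above). The reshapes below flatten the batch and
        # head dimensions so a single batched matmul computes all [b * np]
        # attention score matrices at once.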
        # ===================================
        # Raw attention scores. [b, np, s, s]
        # ===================================

        # [b, np, sq, sk]
        output_size = (query_layer.size(1),
                       query_layer.size(2),
                       query_layer.size(0),
                       key_layer.size(0))

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.view(output_size[2],
                                       output_size[0] * output_size[1], -1)
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.view(output_size[3],
                                   output_size[0] * output_size[1], -1)

        # preallocating result tensor: [b * np, sq, sk]
        matmul_result = torch.empty(
            output_size[0] * output_size[1],
            output_size[2],
            output_size[3],
            dtype=query_layer.dtype,
            device=torch.cuda.current_device())

        # Raw attention scores. [b * np, sq, sk]
        matmul_result = torch.baddbmm(
            matmul_result,
            query_layer.transpose(0, 1),   # [b * np, sq, hn]
            key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            beta=0.0, alpha=(1.0 / self.norm_factor))

        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)

        # ===========================
        # Attention probs and dropout
        # ===========================

        # attention scores and attention mask [b, np, sq, sk]
        attention_probs = self.scale_mask_softmax(attention_scores,
                                                  attention_mask)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        with tensor_parallel.get_cuda_rng_tracker().fork():
            attention_probs = self.attention_dropout(attention_probs)

        # =========================
        # Context layer. [sq, b, hp]
        # =========================

        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]

        # context layer shape: [b, np, sq, hn]
        output_size = (value_layer.size(1),
                       value_layer.size(2),
                       query_layer.size(0),
                       value_layer.size(3))

        # change view [sk, b * np, hn]
        value_layer = value_layer.view(value_layer.size(0),
                                       output_size[0] * output_size[1], -1)

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1],
                                               output_size[2], -1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)

        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp]
        new_context_layer_shape = context_layer.size()[:-2] + \
            (self.hidden_size_per_partition,)
        context_layer = context_layer.view(*new_context_layer_shape)

        # =================
        # Output. [sq, b, h]
        # =================

        output, bias = self.dense(context_layer)

        return output, bias


def bias_dropout_add(x, bias, residual, prob, training):
    # type: (Tensor, Tensor, Tensor, float, bool) -> Tensor
    out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
    out = residual + out
    return out


def get_bias_dropout_add(training):
    def _bias_dropout_add(x, bias, residual, prob):
        return bias_dropout_add(x, bias, residual, prob, training)
    return _bias_dropout_add


@torch.jit.script
def bias_dropout_add_fused_train(x: torch.Tensor,
                                 bias: torch.Tensor,
                                 residual: torch.Tensor,
                                 prob: float) -> torch.Tensor:
    return bias_dropout_add(x, bias, residual, prob, True)


@torch.jit.script
def bias_dropout_add_fused_inference(x: torch.Tensor,
                                     bias: torch.Tensor,
                                     residual: torch.Tensor,
                                     prob: float) -> torch.Tensor:
    return bias_dropout_add(x, bias, residual, prob, False)


class ParallelRetroTransformerEncoderLayer(MegatronModule):
    """A single transformer layer for Retro Decoder with a retriever encoder
    inside and cross attention.
    Transformer layer takes input with size [s, b, h] and returns an
    output of the same size.
    """

    def __init__(self, init_method, output_layer_init_method,
                 layer_number, layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding,
                 drop_path_rate=0., retriever=None):
        args = get_args()

        super(ParallelRetroTransformerEncoderLayer, self).__init__()
        self.layer_number = layer_number
        self.layer_type = layer_type

        self.apply_residual_connection_post_layernorm \
            = args.apply_residual_connection_post_layernorm

        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection

        # Retro Encoder
        self.retriever = retriever

        # Layernorm on the input data.
        self.input_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon,
            no_persist_layer_norm=args.no_persist_layer_norm)

        # Self attention.
        self.self_attention = ParallelAttention(
            init_method,
            output_layer_init_method,
            layer_number,
            attention_type=AttnType.self_attn,
            attn_mask_type=self_attn_mask_type)
        self.hidden_dropout = args.hidden_dropout
        self.bias_dropout_fusion = args.bias_dropout_fusion
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 \
            else None

        # Layernorm on the attention output
        self.post_attention_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon,
            no_persist_layer_norm=args.no_persist_layer_norm)

        self.inter_attention = ParallelAttention(
            init_method,
            output_layer_init_method,
            layer_number,
            attention_type=AttnType.cross_attn)
        # Layernorm on the attention output.
        self.post_inter_attention_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon,
            no_persist_layer_norm=args.no_persist_layer_norm)

        # MLP
        if args.num_experts is not None:
            self.mlp = SwitchMLP(init_method, output_layer_init_method)
        else:
            self.mlp = ParallelMLP(init_method, output_layer_init_method)

    def forward(self, hidden_states, attention_mask,
                retriever_output, retriever_attn_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None):
        # hidden_states: [s, b, h]

        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)

        # Self attention.
        attention_output, attention_bias = \
            self.self_attention(
                layernorm_output,
                attention_mask,
                inference_params=inference_params)

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        if self.drop_path is None:
            # jit scripting for a nn.module (with dropout) is not
            # triggering the fusion kernel. For now, we use two
            # different nn.functional routines to account for varying
            # dropout semantics during training and inference phases.
            if self.bias_dropout_fusion:
                if self.training:
                    bias_dropout_add_func = bias_dropout_add_fused_train
                else:
                    bias_dropout_add_func = bias_dropout_add_fused_inference
            else:
                bias_dropout_add_func = get_bias_dropout_add(self.training)

            # re-enable torch grad to enable fused optimization.
            with torch.enable_grad():
                layernorm_input = bias_dropout_add_func(
                    attention_output,
                    attention_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)
        else:
            out = torch.nn.functional.dropout(attention_output + attention_bias,
                                              p=self.hidden_dropout,
                                              training=self.training)
            layernorm_input = residual + self.drop_path(out)

        # Layer norm post the self attention.
        # [ns, bs, d]
        layernorm_output = self.post_attention_layernorm(layernorm_input)
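        # Illustrative example of the chunking below (values assumed, not taken
        # from the training config): with ns = 2048 tokens per sample and
        # chunk_length m = 64, the sequence splits into l = 32 chunks; each
        # chunk cross-attends to k retrieved neighbors of r tokens each, which
        # the retriever encoder returns as a [r, k * bs * l, d] tensor.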
        """ notations:
              l: number of chunks
              m: number of tokens per chunk
              bs: batch size
              d: hidden size
              k: number of neighbors
              r: number of tokens per neighbor (neighbor + continuation)
        """
        args = get_args()
        retro_args = get_retro_args()
        chunk_length = retro_args.retro_gpt_chunk_length
        retrieved_length = retro_args.retro_gpt_retrieved_length
        num_neighbors = args.retro_num_neighbors

        ns, bs, d = layernorm_output.shape
        l = int(np.ceil(ns / chunk_length))
        first_ns = ns % chunk_length
        if first_ns > 0:
            first_chunk, rest_chunk = \
                layernorm_output[:first_ns], layernorm_output[first_ns:]
            first_chunk = torch.nn.functional.pad(
                first_chunk,
                (0, 0, 0, 0, 0, chunk_length - first_ns),
                'constant', 0)
            chunked_output = \
                torch.cat((first_chunk, rest_chunk), dim=0)  # [l * m, bs, d]
        else:
            chunked_output = layernorm_output  # [l * m, bs, d]
        chunked_output = chunked_output \
            .reshape(l, chunk_length, bs, d) \
            .permute(1, 2, 0, 3) \
            .reshape(chunk_length, bs * l, d) \
            .contiguous()

        # Get Encoder Output
        retriever_output = self.retriever(
            retriever_output,
            retriever_attn_mask,
            retriever_output=chunked_output,
            retriever_attn_mask=retriever_attn_mask,
            inference_params=inference_params)  # [r, k * bs * l , d]
        retriever_output = retriever_output.reshape(
            retrieved_length * num_neighbors, bs * l, d)  # [r * k, bs * l, d]

        # Chunked Cross attention with Retriever Encoder
        # The shift below keeps chunked cross-attention autoregressive: the
        # attending tokens are offset so a position only attends to neighbors
        # retrieved for chunks that are already complete, and the removed
        # prefix is re-added as padding after the attention.
        pad = (ns - 1) % chunk_length
        attending_chunks = layernorm_output[pad:]  # [ns - m + 1, bs, d]
        padded_chunks = torch.nn.functional.pad(
            attending_chunks,
            (0, 0, 0, 0, 0, chunk_length - 1),
            'constant', 0)  # [ns, bs, d]
        padded_chunked_output = padded_chunks \
            .reshape(l, chunk_length, bs, d) \
            .permute(1, 2, 0, 3)
        padded_chunked_output = padded_chunked_output.reshape(
            chunk_length, bs * l, d).contiguous()  # [m, bs * l, d]

        # attention_output: [m, bs * l, d]
        # attention_bias: [d]
        attention_output, attention_bias = \
            self.inter_attention(
                padded_chunked_output,            # Q: main model embedding
                None,
                encoder_output=retriever_output)  # KV: retriever output embedding

        # Residual connection
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        # Re-enable torch grad to enable fused optimization.
        with torch.enable_grad():
            layernorm_input = bias_dropout_add_func(
                attention_output,
                attention_bias.expand_as(attention_output),
                torch.zeros_like(attention_output),
                self.hidden_dropout)

        # Reverse the chunked layout back to [ns, bs, d] and re-apply the
        # causal shift padding before adding the residual.
        layernorm_input = layernorm_input \
            .reshape(chunk_length, bs, l, d) \
            .permute(2, 0, 1, 3)  # [l, m, bs, d]
        layernorm_input = layernorm_input.reshape(chunk_length * l, bs, d)
        layernorm_input = torch.nn.functional.pad(
            layernorm_input,
            (0, 0, 0, 0, pad, 0),
            'constant', 0)[:ns]  # [ns, b, d]
        layernorm_input = layernorm_input + residual

        # Layer norm post the decoder attention
        layernorm_output = self.post_inter_attention_layernorm(layernorm_input)

        # MLP.
        mlp_output, mlp_bias = self.mlp(layernorm_output)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        if self.drop_path is None:
            # Re-enable torch grad to enable fused optimization.
            with torch.enable_grad():
                output = bias_dropout_add_func(
                    mlp_output,
                    mlp_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)
        else:
            out = torch.nn.functional.dropout(mlp_output + mlp_bias,
                                              p=self.hidden_dropout,
                                              training=self.training)
            output = residual + self.drop_path(out)

        return output, retriever_output


class ParallelRetroTransformerLayer(MegatronModule):
    """A single transformer layer for Retro Decoder with cross attention.

    Transformer layer takes input with size [s, b, h] and returns an
    output of the same size.
    """

    def __init__(self, init_method, output_layer_init_method,
                 layer_number, layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding,
                 drop_path_rate=0.):
        args = get_args()

        super(ParallelRetroTransformerLayer, self).__init__()
        self.layer_number = layer_number
        self.layer_type = layer_type

        self.apply_residual_connection_post_layernorm \
            = args.apply_residual_connection_post_layernorm

        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection

        # Layernorm on the input data.
        self.input_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon,
            no_persist_layer_norm=args.no_persist_layer_norm)

        # Self attention.
        self.self_attention = ParallelAttention(
            init_method,
            output_layer_init_method,
            layer_number,
            attention_type=AttnType.self_attn,
            attn_mask_type=self_attn_mask_type)
        self.hidden_dropout = args.hidden_dropout
        self.bias_dropout_fusion = args.bias_dropout_fusion
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 \
            else None

        # Layernorm on the attention output
        self.post_attention_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon,
            no_persist_layer_norm=args.no_persist_layer_norm)

        self.inter_attention = ParallelAttention(
            init_method,
            output_layer_init_method,
            layer_number,
            attention_type=AttnType.cross_attn)
        # Layernorm on the attention output.
        self.post_inter_attention_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon,
            no_persist_layer_norm=args.no_persist_layer_norm)

        # MLP
        if args.num_experts is not None:
            self.mlp = SwitchMLP(init_method, output_layer_init_method)
        else:
            self.mlp = ParallelMLP(init_method, output_layer_init_method)

    def forward(self, hidden_states, attention_mask,
                retriever_output, retriever_attn_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None):
        # hidden_states: [s, b, h]

        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)

        # Self attention.
        attention_output, attention_bias = \
            self.self_attention(
                layernorm_output,
                attention_mask,
                inference_params=inference_params)

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        if self.drop_path is None:
            # jit scripting for a nn.module (with dropout) is not
            # triggering the fusion kernel. For now, we use two
            # different nn.functional routines to account for varying
            # dropout semantics during training and inference phases.
            if self.bias_dropout_fusion:
                if self.training:
                    bias_dropout_add_func = bias_dropout_add_fused_train
                else:
                    bias_dropout_add_func = bias_dropout_add_fused_inference
            else:
                bias_dropout_add_func = get_bias_dropout_add(self.training)

            # re-enable torch grad to enable fused optimization.
            with torch.enable_grad():
                layernorm_input = bias_dropout_add_func(
                    attention_output,
                    attention_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)
        else:
            out = torch.nn.functional.dropout(attention_output + attention_bias,
                                              p=self.hidden_dropout,
                                              training=self.training)
            layernorm_input = residual + self.drop_path(out)

        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        args = get_args()
        retro_args = get_retro_args()
        chunk_length = retro_args.retro_gpt_chunk_length

        ns, bs, d = layernorm_output.shape
        l = int(np.ceil(ns / chunk_length))
        pad = (ns - 1) % chunk_length
        attending_chunks = layernorm_output[pad:]
        padded_chunks = torch.nn.functional.pad(
            attending_chunks,
            (0, 0, 0, 0, 0, chunk_length - 1),
            'constant', 0)
        padded_chunked_output = padded_chunks \
            .reshape(l, chunk_length, bs, d) \
            .permute(1, 2, 0, 3)
        padded_chunked_output = padded_chunked_output.reshape(
            chunk_length, bs * l, d).contiguous()

        # Encoder output.
        attention_output, attention_bias = \
            self.inter_attention(padded_chunked_output,
                                 None,
                                 encoder_output=retriever_output)

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        # Re-enable torch grad to enable fused optimization.
        with torch.enable_grad():
            layernorm_input = bias_dropout_add_func(
                attention_output,
                attention_bias.expand_as(attention_output),
                torch.zeros_like(attention_output),
                self.hidden_dropout)

        layernorm_input = layernorm_input \
            .reshape(chunk_length, bs, l, d) \
            .permute(2, 0, 1, 3)  # [l, m, bs, d]
        layernorm_input = layernorm_input.reshape(chunk_length * l, bs, d)
        layernorm_input = torch.nn.functional.pad(
            layernorm_input,
            (0, 0, 0, 0, pad, 0),
            'constant', 0)[:ns]
        layernorm_input = layernorm_input + residual

        # Layer norm post the decoder attention
        layernorm_output = self.post_inter_attention_layernorm(layernorm_input)

        # MLP.
        mlp_output, mlp_bias = self.mlp(layernorm_output)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        if self.drop_path is None:
            # Re-enable torch grad to enable fused optimization.
            with torch.enable_grad():
                output = bias_dropout_add_func(
                    mlp_output,
                    mlp_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)
        else:
            out = torch.nn.functional.dropout(mlp_output + mlp_bias,
                                              p=self.hidden_dropout,
                                              training=self.training)
            output = residual + self.drop_path(out)

        return output


class ParallelRetroEncoderTransformerCALayer(MegatronModule):
    """A single transformer layer for Retro Encoder with cross attention.

    Transformer layer takes input with size [s, b, h] and returns an
    output of the same size.
    """

    def __init__(self, init_method, output_layer_init_method,
                 layer_number, layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding,
                 drop_path_rate=0.):
        args = get_args()

        super(ParallelRetroEncoderTransformerCALayer, self).__init__()
        self.layer_number = layer_number
        self.layer_type = layer_type

        self.apply_residual_connection_post_layernorm \
            = args.apply_residual_connection_post_layernorm

        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection

        # Layernorm on the input data.
        self.input_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon,
            no_persist_layer_norm=args.no_persist_layer_norm)
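        # Note: the self attention constructed below reuses the standard
        # ParallelAttention, but its dropout modules are replaced with the
        # Retro-encoder-specific rates (retro_encoder_attention_dropout and
        # retro_encoder_hidden_dropout) rather than the decoder defaults.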
        # Self attention.
        self.self_attention = ParallelAttention(
            init_method,
            output_layer_init_method,
            layer_number,
            attention_type=AttnType.self_attn,
            attn_mask_type=self_attn_mask_type)
        self.self_attention.attention_dropout = \
            torch.nn.Dropout(args.retro_encoder_attention_dropout)
        self.hidden_dropout = args.retro_encoder_hidden_dropout
        self.bias_dropout_fusion = args.bias_dropout_fusion
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 \
            else None

        # Layernorm on the attention output
        self.post_attention_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon,
            no_persist_layer_norm=args.no_persist_layer_norm)

        self.inter_attention = ParallelAttention(
            init_method,
            output_layer_init_method,
            layer_number,
            attention_type=AttnType.cross_attn)
        # Layernorm on the attention output.
        self.post_inter_attention_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon,
            no_persist_layer_norm=args.no_persist_layer_norm)

        # MLP
        if args.num_experts is not None:
            self.mlp = SwitchMLP(init_method, output_layer_init_method)
        else:
            self.mlp = ParallelMLP(init_method, output_layer_init_method)

    def forward(self, hidden_states, attention_mask,
                retriever_output, retriever_attn_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None):
        # hidden_states: [s, b, h]

        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)

        # Self attention.
        attention_output, attention_bias = \
            self.self_attention(
                layernorm_output,
                attention_mask,
                inference_params=inference_params)

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        if self.drop_path is None:
            # jit scripting for a nn.module (with dropout) is not
            # triggering the fusion kernel. For now, we use two
            # different nn.functional routines to account for varying
            # dropout semantics during training and inference phases.
            if self.bias_dropout_fusion:
                if self.training:
                    bias_dropout_add_func = bias_dropout_add_fused_train
                else:
                    bias_dropout_add_func = bias_dropout_add_fused_inference
            else:
                bias_dropout_add_func = get_bias_dropout_add(self.training)

            # re-enable torch grad to enable fused optimization.
            with torch.enable_grad():
                layernorm_input = bias_dropout_add_func(
                    attention_output,
                    attention_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)
        else:
            out = torch.nn.functional.dropout(attention_output + attention_bias,
                                              p=self.hidden_dropout,
                                              training=self.training)
            layernorm_input = residual + self.drop_path(out)

        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        # Neighbors.
        args = get_args()
        retro_args = get_retro_args()
        retrieved_length = retro_args.retro_gpt_retrieved_length
        num_neighbors = args.retro_num_neighbors

        ns, bs, d = layernorm_output.shape

        # [r, bs * l * k, d]
        chunked_outputs = layernorm_output.reshape(retrieved_length,
                                                   -1,
                                                   num_neighbors,
                                                   d)
        chunked_outputs_before_layer_norm = \
            layernorm_input.reshape(retrieved_length, -1,
                                    num_neighbors, d)  # [r, bs * l, k, d]

        layernorm_inputs = []
        layernorm_outputs = []
        for k in range(num_neighbors):
            chunked_output = chunked_outputs[:, :, k].contiguous()
            attention_output, attention_bias = \
                self.inter_attention(
                    chunked_output,                   # Q (neighbor embedding)
                    None,
                    encoder_output=retriever_output)  # K, V (hidden act)

            # Residual connection.
            if self.apply_residual_connection_post_layernorm:
                residual = chunked_output
            else:
                residual = chunked_outputs_before_layer_norm[:, :, k]

            # Re-enable torch grad to enable fused optimization.
            with torch.enable_grad():
                layernorm_input = bias_dropout_add_func(
                    attention_output,
                    attention_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)
            layernorm_inputs.append(layernorm_input)

            # Layer norm post the decoder attention
            layernorm_output = \
                self.post_inter_attention_layernorm(layernorm_input)
            layernorm_outputs.append(layernorm_output)

        # layernorm_input : [r, k * bs * l, d]
        # layernorm_output : [r, k * bs * l, d]
        layernorm_input = \
            torch.stack(layernorm_inputs, dim=1).reshape(ns, bs, d)
        layernorm_output = \
            torch.stack(layernorm_outputs, dim=1).reshape(ns, bs, d)

        # MLP.
        mlp_output, mlp_bias = self.mlp(layernorm_output)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        if self.drop_path is None:
            # Re-enable torch grad to enable fused optimization.
            with torch.enable_grad():
                output = bias_dropout_add_func(
                    mlp_output,
                    mlp_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)
        else:
            out = torch.nn.functional.dropout(mlp_output + mlp_bias,
                                              p=self.hidden_dropout,
                                              training=self.training)
            output = residual + self.drop_path(out)

        return output


class ParallelTransformerLayer(MegatronModule):
    """A single transformer layer.

    Transformer layer takes input with size [s, b, h] and returns an
    output of the same size.
    """

    def __init__(self, init_method, output_layer_init_method,
                 layer_number, layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding,
                 drop_path_rate=0.):
        args = get_args()

        super(ParallelTransformerLayer, self).__init__()
        self.layer_number = layer_number
        self.layer_type = layer_type

        self.apply_residual_connection_post_layernorm \
            = args.apply_residual_connection_post_layernorm

        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection

        # Layernorm on the input data.
        self.input_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon,
            no_persist_layer_norm=args.no_persist_layer_norm)

        # Self attention.
        self.self_attention = ParallelAttention(
            init_method,
            output_layer_init_method,
            layer_number,
            attention_type=AttnType.self_attn,
            attn_mask_type=self_attn_mask_type)
        self.hidden_dropout = args.hidden_dropout
        self.bias_dropout_fusion = args.bias_dropout_fusion
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 \
            else None

        # Layernorm on the attention output
        self.post_attention_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon,
            no_persist_layer_norm=args.no_persist_layer_norm)

        if self.layer_type == LayerType.decoder:
            self.inter_attention = ParallelAttention(
                init_method,
                output_layer_init_method,
                layer_number,
                attention_type=AttnType.cross_attn)
            # Layernorm on the attention output.
            self.post_inter_attention_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon,
                no_persist_layer_norm=args.no_persist_layer_norm)

        # MLP
        if args.num_experts is not None:
            self.mlp = SwitchMLP(init_method, output_layer_init_method)
        else:
            self.mlp = ParallelMLP(init_method, output_layer_init_method)

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None):
        # hidden_states: [s, b, h]

        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)

        # Self attention.
        attention_output, attention_bias = \
            self.self_attention(
                layernorm_output,
                attention_mask,
                inference_params=inference_params)

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        if self.drop_path is None:
            # jit scripting for a nn.module (with dropout) is not
            # triggering the fusion kernel. For now, we use two
            # different nn.functional routines to account for varying
            # dropout semantics during training and inference phases.
            if self.bias_dropout_fusion:
                if self.training:
                    bias_dropout_add_func = bias_dropout_add_fused_train
                else:
                    bias_dropout_add_func = bias_dropout_add_fused_inference
            else:
                bias_dropout_add_func = get_bias_dropout_add(self.training)

            # re-enable torch grad to enable fused optimization.
            with torch.enable_grad():
                layernorm_input = bias_dropout_add_func(
                    attention_output,
                    attention_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)
        else:
            out = torch.nn.functional.dropout(attention_output + attention_bias,
                                              p=self.hidden_dropout,
                                              training=self.training)
            layernorm_input = residual + self.drop_path(out)

        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        if self.layer_type == LayerType.decoder:
            attention_output, attention_bias = \
                self.inter_attention(layernorm_output,
                                     enc_dec_attn_mask,
                                     encoder_output=encoder_output)
            # residual connection
            if self.apply_residual_connection_post_layernorm:
                residual = layernorm_output
            else:
                residual = layernorm_input

            # re-enable torch grad to enable fused optimization.
            with torch.enable_grad():
                layernorm_input = bias_dropout_add_func(
                    attention_output,
                    attention_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)

            # Layer norm post the decoder attention
            layernorm_output = self.post_inter_attention_layernorm(layernorm_input)

        # MLP.
        mlp_output, mlp_bias = self.mlp(layernorm_output)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        if self.drop_path is None:
            # re-enable torch grad to enable fused optimization.
            with torch.enable_grad():
                output = bias_dropout_add_func(
                    mlp_output,
                    mlp_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)
        else:
            out = torch.nn.functional.dropout(mlp_output + mlp_bias,
                                              p=self.hidden_dropout,
                                              training=self.training)
            output = residual + self.drop_path(out)

        return output


class ParallelRetroEncoder(MegatronModule):
    """Retro Transformer class for encoder."""

    def __init__(self, init_method, output_layer_init_method,
                 layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding,
                 pre_process=True, post_process=True,
                 drop_path_rate=0.0):
        super(ParallelRetroEncoder, self).__init__()
        args = get_args()

        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection
        self.pre_process = pre_process
        self.post_process = post_process
        self.input_tensor = None
        self.drop_path_rate = drop_path_rate

        # Store activation checkpointing flags.
        self.recompute_granularity = args.recompute_granularity
        self.recompute_method = args.recompute_method
        self.recompute_num_layers = args.recompute_num_layers
        self.distribute_saved_activations = \
            args.distribute_saved_activations and not args.sequence_parallel

        self.sequence_parallel = args.sequence_parallel

        # Number of layers.
        self.num_layers = args.retro_encoder_layers

        self.drop_path_rates = [rate.item() for rate in
                                torch.linspace(0, self.drop_path_rate, args.num_layers)]

        if args.retro_add_retriever:
            self.P = [1]
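            # self.P holds the 1-based indices of encoder layers that use
            # chunked cross-attention; for the Retro encoder, only the first
            # layer cross-attends to the decoder's chunk hidden states.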
        # Transformer layers.
        assert args.retro_add_retriever

        def build_layer(layer_number):
            if layer_number in self.P:
                return ParallelRetroEncoderTransformerCALayer(
                    init_method,
                    output_layer_init_method,
                    layer_number,
                    layer_type=layer_type,
                    self_attn_mask_type=self_attn_mask_type,
                    drop_path_rate=self.drop_path_rates[layer_number - 1])
            else:
                layer = ParallelTransformerLayer(
                    init_method,
                    output_layer_init_method,
                    layer_number,
                    layer_type=layer_type,
                    self_attn_mask_type=self_attn_mask_type,
                    drop_path_rate=self.drop_path_rates[layer_number - 1])
                layer.self_attention.attention_dropout = \
                    torch.nn.Dropout(args.retro_encoder_attention_dropout)
                layer.hidden_dropout = args.retro_encoder_hidden_dropout
                return layer

        if args.virtual_pipeline_model_parallel_size is not None:
            assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
                'num_layers_per_stage must be divisible by ' \
                'virtual_pipeline_model_parallel_size'
            assert args.model_type != ModelType.encoder_and_decoder
            # Number of layers in each model chunk is the number of layers in
            # the stage, divided by the number of model chunks in a stage.
            self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size
            # With 8 layers, 2 stages, and 4 model chunks, we want an
            # assignment of layers to stages like (each list is a model chunk):
            # Stage 0: [0]  [2]  [4]  [6]
            # Stage 1: [1]  [3]  [5]  [7]
            # With 8 layers, 2 stages, and 2 virtual stages, we want an
            # assignment of layers to stages like (each list is a model chunk):
            # Stage 0: [0, 1]  [4, 5]
            # Stage 1: [2, 3]  [6, 7]
            offset = parallel_state.get_virtual_pipeline_model_parallel_rank() * (
                args.num_layers // args.virtual_pipeline_model_parallel_size) + \
                (parallel_state.get_pipeline_model_parallel_rank() * self.num_layers)
        else:
            # Each stage gets a contiguous set of layers.
            if args.model_type == ModelType.encoder_and_decoder and \
                    parallel_state.get_pipeline_model_parallel_world_size() > 1:
                pipeline_rank = parallel_state.get_pipeline_model_parallel_rank()
                if layer_type == LayerType.encoder:
                    offset = pipeline_rank * self.num_layers
                else:
                    num_ranks_in_enc = args.pipeline_model_parallel_split_rank
                    offset = (pipeline_rank - num_ranks_in_enc) * self.num_layers
            else:
                offset = parallel_state.get_pipeline_model_parallel_rank() * self.num_layers

        if self.num_layers == 0:
            # When a standalone embedding stage is used (e.g.,
            # args.standalone_embedding_stage == True), virtual pipeline ranks
            # on pipeline rank 0 will have zero transformer layers assigned to
            # them. This results in the model's input and output tensors being
            # the same, which will cause failure for certain output tensor
            # optimizations (e.g., pipeline output deallocation). To remedy
            # this, we assign a 'no-op' layer on these ranks, which will
            # disconnect the input tensor from the output tensor.
            self.num_layers = 1
            self.layers = torch.nn.ModuleList([ NoopTransformerLayer(1) ])
        else:
            self.layers = torch.nn.ModuleList(
                [build_layer(i + 1 + offset) for i in range(self.num_layers)])

        if self.post_process:
            # Final layer norm before output.
            self.final_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon,
                no_persist_layer_norm=args.no_persist_layer_norm)

    def _get_layer(self, layer_number):
        return self.layers[layer_number]

    def _checkpointed_forward(self, hidden_states, attention_mask,
                              encoder_output, enc_dec_attn_mask):
        """Forward method with activation checkpointing."""
        def custom(start, end):
            def custom_forward(*inputs):
                x_ = inputs[0]
                attention_mask = inputs[1]
                encoder_output = inputs[2]
                enc_dec_attn_mask = inputs[3]
                for index in range(start, end):
                    layer = self._get_layer(index)
                    x_ = layer(x_, attention_mask, encoder_output, enc_dec_attn_mask)
                return x_
            return custom_forward

        if self.recompute_method == 'uniform':
            # Uniformly divide the total number of Transformer layers and
            # checkpoint the input activation of each divided chunk.
            # This trades extra recomputation for lower memory usage.
            l = 0
            while l < self.num_layers:
                hidden_states = tensor_parallel.checkpoint(
                    custom(l, l + self.recompute_num_layers),
                    self.distribute_saved_activations,
                    hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                l += self.recompute_num_layers
        elif self.recompute_method == 'block':
            # Checkpoint the input activation of only a set number of individual
            # Transformer layers and skip the rest, making fuller use of device
            # memory by avoiding redundant recomputation.
            for l in range(self.num_layers):
                if l < self.recompute_num_layers:
                    hidden_states = tensor_parallel.checkpoint(
                        custom(l, l + 1),
                        self.distribute_saved_activations,
                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                else:
                    hidden_states = custom(l, l + 1)(
                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
        else:
            raise ValueError("Invalid activation checkpoint method.")

        return hidden_states

    def set_input_tensor(self, input_tensor):
        """Set input tensor to be used instead of forward()'s input.

        When doing pipeline parallelism the input from the previous stage
        comes from communication, not from the input, so the model's
        forward_step_func won't have it. This function is thus used by
        internal code to bypass the input provided by the forward_step_func"""
        self.input_tensor = input_tensor

    def forward(self, hidden_states, attention_mask,
                retriever_output, retriever_attn_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None):

        # Checks.
        if inference_params:
            assert self.recompute_granularity is None, \
                'inference does not work with activation checkpointing'

        if self.pre_process:
            # Data format change to avoid explicit transposes: [b s h] --> [s b h].
            # If the input flag for fp32 residual connection is set, convert to float.
            if self.fp32_residual_connection:
                hidden_states = hidden_states.transpose(0, 1).contiguous().float()
            # Otherwise, leave it as is.
            else:
                hidden_states = hidden_states.transpose(0, 1).contiguous()
        else:
            # See set_input_tensor()
            hidden_states = self.input_tensor

        # Viewless tensor.
        # - We only need to create a viewless tensor in the case of micro batch
        #   size (mbs) == 1, since in this case, 'hidden_states.transpose()'
        #   above creates a view tensor, and '.contiguous()' is a pass-through.
        #   For mbs >= 2, '.contiguous()' creates a new tensor, eliminating
        #   the need to make it viewless.
        #
        #   However, we don't explicitly check mbs == 1 here because
        #   make_viewless_tensor() has negligible overhead when its input
        #   is already viewless.
        #
        # - For the 'else' case above, calling make_viewless_tensor() here is
        #   likely redundant, since p2p_communication.py (likely originator)
        #   already creates viewless tensors. That said, make_viewless_tensor()
        #   is called here to be future-proof and corner-case-proof.
        hidden_states = core_utils.make_viewless_tensor(
            hidden_states,
            requires_grad=True,
            keep_graph=True,
        )

        # Transpose encoder output.
        if encoder_output is not None:
            encoder_output = encoder_output.transpose(0, 1).contiguous()

        args = get_args()
        assert not args.sequence_parallel, "if SP, need rng context."

        # Forward pass.
        if self.recompute_granularity == 'full':
            hidden_states = self._checkpointed_forward(hidden_states,
                                                       attention_mask,
                                                       encoder_output,
                                                       enc_dec_attn_mask)
        else:
            for index in range(self.num_layers):
                layer = self._get_layer(index)
                if index + 1 in self.P:
                    hidden_states = layer(
                        hidden_states,
                        attention_mask,
                        retriever_output=retriever_output,
                        retriever_attn_mask=retriever_attn_mask,
                        encoder_output=encoder_output,
                        enc_dec_attn_mask=enc_dec_attn_mask,
                        inference_params=inference_params)
                else:
                    hidden_states = layer(
                        hidden_states,
                        attention_mask,
                        encoder_output=encoder_output,
                        enc_dec_attn_mask=enc_dec_attn_mask,
                        inference_params=inference_params)

        # Final layer norm.
        if self.post_process:
            # Reverting data format change [s b h] --> [b s h].
            hidden_states = hidden_states.transpose(0, 1).contiguous()
            output = self.final_layernorm(hidden_states)
        else:
            output = hidden_states

        return output


class ParallelRetroTransformer(MegatronModule):
    """Standard GPT Transformer class, extended with Retro chunked
    cross-attention layers."""

    def __init__(self, init_method, output_layer_init_method,
                 layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding,
                 pre_process=True, post_process=True,
                 drop_path_rate=0.0, retriever=None):
        super(ParallelRetroTransformer, self).__init__()
        args = get_args()

        assert pre_process and post_process, \
            "pipeline parallelism un-supported."

        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection
        self.pre_process = pre_process
        self.post_process = post_process
        self.input_tensor = None
        self.drop_path_rate = drop_path_rate

        # Store activation checkpointing flags.
        self.recompute_granularity = args.recompute_granularity
        self.recompute_method = args.recompute_method
        self.recompute_num_layers = args.recompute_num_layers
        self.distribute_saved_activations = \
            args.distribute_saved_activations and not args.sequence_parallel

        self.sequence_parallel = args.sequence_parallel

        # Number of layers.
        self.num_layers = _get_num_layers(
            args, args.model_type == ModelType.encoder_and_decoder)

        self.drop_path_rates = [rate.item() for rate in
                                torch.linspace(0, self.drop_path_rate, args.num_layers)]

        if args.retro_add_retriever:
            if args.num_layers == 12:
                self.P = [6, 9, 12]
            elif args.num_layers == 24:
                self.P = np.arange(9, 25, 3).tolist()
            elif args.num_layers == 40:
                self.P = np.arange(9, 41, 3).tolist()
                self.P.append(40)
            self.retriever = retriever
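            # self.P lists the 1-based decoder layers that perform chunked
            # cross-attention against the retrieved neighbors (every third
            # layer, starting at layer 6 or 9 depending on depth). The first
            # of these, min(self.P), additionally hosts the retriever encoder.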
        # Transformer layers.
        assert args.retro_add_retriever

        def build_layer(layer_number):
            if layer_number == min(self.P):
                return ParallelRetroTransformerEncoderLayer(
                    init_method,
                    output_layer_init_method,
                    layer_number,
                    layer_type=layer_type,
                    self_attn_mask_type=self_attn_mask_type,
                    drop_path_rate=self.drop_path_rates[layer_number - 1],
                    retriever=retriever)
            elif layer_number in self.P:
                return ParallelRetroTransformerLayer(
                    init_method,
                    output_layer_init_method,
                    layer_number,
                    layer_type=layer_type,
                    self_attn_mask_type=self_attn_mask_type,
                    drop_path_rate=self.drop_path_rates[layer_number - 1])
            else:
                return ParallelTransformerLayer(
                    init_method,
                    output_layer_init_method,
                    layer_number,
                    layer_type=layer_type,
                    self_attn_mask_type=self_attn_mask_type,
                    drop_path_rate=self.drop_path_rates[layer_number - 1])

        if args.virtual_pipeline_model_parallel_size is not None:
            assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
                'num_layers_per_stage must be divisible by ' \
                'virtual_pipeline_model_parallel_size'
            assert args.model_type != ModelType.encoder_and_decoder
            # Number of layers in each model chunk is the number of layers in the stage,
            # divided by the number of model chunks in a stage.
            self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size
            # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0]  [2]  [4]  [6]
            # Stage 1: [1]  [3]  [5]  [7]
            # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0, 1]  [4, 5]
            # Stage 1: [2, 3]  [6, 7]
            offset = parallel_state.get_virtual_pipeline_model_parallel_rank() * (
                args.num_layers // args.virtual_pipeline_model_parallel_size) + \
                (parallel_state.get_pipeline_model_parallel_rank() * self.num_layers)
        else:
            # Each stage gets a contiguous set of layers.
            if args.model_type == ModelType.encoder_and_decoder and \
                    parallel_state.get_pipeline_model_parallel_world_size() > 1:
                pipeline_rank = parallel_state.get_pipeline_model_parallel_rank()
                if layer_type == LayerType.encoder:
                    offset = pipeline_rank * self.num_layers
                else:
                    num_ranks_in_enc = args.pipeline_model_parallel_split_rank
                    offset = (pipeline_rank - num_ranks_in_enc) * self.num_layers
            else:
                offset = parallel_state.get_pipeline_model_parallel_rank() * self.num_layers

        if self.num_layers == 0:
            # When a standalone embedding stage is used (e.g.,
            # args.standalone_embedding_stage == True), virtual pipeline ranks
            # on pipeline rank 0 will have zero transformer layers assigned to
            # them. This results in the model's input and output tensors being
            # the same, which will cause failure for certain output tensor
            # optimizations (e.g., pipeline output deallocation). To remedy
            # this, we assign a 'no-op' layer on these ranks, which will
            # disconnect the input tensor from the output tensor.
            self.num_layers = 1
            self.layers = torch.nn.ModuleList([ NoopTransformerLayer(1) ])
        else:
            self.layers = torch.nn.ModuleList(
                [build_layer(i + 1 + offset) for i in range(self.num_layers)])

        if self.post_process:
            # Final layer norm before output.
            self.final_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon,
                no_persist_layer_norm=args.no_persist_layer_norm)

    def _get_layer(self, layer_number):
        return self.layers[layer_number]

    def _checkpointed_forward(self, hidden_states, attention_mask,
                              encoder_output, enc_dec_attn_mask):
        """Forward method with activation checkpointing."""
        def custom(start, end):
            def custom_forward(*inputs):
                x_ = inputs[0]
                attention_mask = inputs[1]
                encoder_output = inputs[2]
                enc_dec_attn_mask = inputs[3]
                for index in range(start, end):
                    layer = self._get_layer(index)
                    x_ = layer(x_, attention_mask, encoder_output, enc_dec_attn_mask)
                return x_
            return custom_forward

        if self.recompute_method == 'uniform':
            # Uniformly divide the total number of Transformer layers and checkpoint
            # the input activation of each divided chunk.
            # This trades extra recomputation for lower memory usage.
            l = 0
            while l < self.num_layers:
                hidden_states = tensor_parallel.checkpoint(
                    custom(l, l + self.recompute_num_layers),
                    self.distribute_saved_activations,
                    hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                l += self.recompute_num_layers
        elif self.recompute_method == 'block':
            # Checkpoint the input activation of only a set number of individual
            # Transformer layers and skip the rest, making fuller use of device
            # memory by avoiding redundant recomputation.
            for l in range(self.num_layers):
                if l < self.recompute_num_layers:
                    hidden_states = tensor_parallel.checkpoint(
                        custom(l, l + 1),
                        self.distribute_saved_activations,
                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                else:
                    hidden_states = custom(l, l + 1)(
                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
        else:
            raise ValueError("Invalid activation checkpoint method.")

        return hidden_states

    def set_input_tensor(self, input_tensor):
        """Set input tensor to be used instead of forward()'s input.

        When doing pipeline parallelism the input from the previous stage
        comes from communication, not from the input, so the model's
        forward_step_func won't have it. This function is thus used by
        internal code to bypass the input provided by the forward_step_func"""
        self.input_tensor = input_tensor

    def forward(self, hidden_states, attention_mask,
                retriever_output=None, retriever_attn_mask=None,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None):

        # Checks.
        if inference_params:
            assert self.recompute_granularity is None, \
                'inference does not work with activation checkpointing'

        args = get_args()

        # Transpose retriever output, to match hidden_states shape.
        retriever_output = retriever_output.transpose(0, 1).contiguous()

        # Viewless tensor.
        # - We only need to create a viewless tensor in the case of micro batch
        #   size (mbs) == 1, since in this case, 'hidden_states.transpose()'
        #   above creates a view tensor, and '.contiguous()' is a pass-through.
        #   For mbs >= 2, '.contiguous()' creates a new tensor, eliminating
        #   the need to make it viewless.
        #
        #   However, we don't explicitly check mbs == 1 here because
        #   make_viewless_tensor() has negligible overhead when its input
        #   is already viewless.
        #
        # - For the 'else' case above, calling make_viewless_tensor() here is
        #   likely redundant, since p2p_communication.py (likely originator)
        #   already creates viewless tensors. That said, make_viewless_tensor()
        #   is called here to be future-proof and corner-case-proof.
        hidden_states = core_utils.make_viewless_tensor(
            hidden_states,
            requires_grad=True,
            keep_graph=True,
        )

        # Transpose encoder output.
        if encoder_output is not None:
            encoder_output = encoder_output.transpose(0, 1).contiguous()

        # Forward pass.
        assert not args.sequence_parallel, "if SP, need rng context."
        if self.recompute_granularity == 'full':
            hidden_states = self._checkpointed_forward(hidden_states,
                                                       attention_mask,
                                                       encoder_output,
                                                       enc_dec_attn_mask)
        else:
            for index in range(self.num_layers):
                layer = self._get_layer(index)
                if args.retro_add_retriever and index + 1 == min(self.P):
                    # The first cross-attention layer also runs the retriever
                    # encoder; 'E' is the encoded neighbor representation that
                    # the remaining cross-attention layers reuse.
                    hidden_states, E = layer(
                        hidden_states,
                        attention_mask,
                        retriever_output=retriever_output,
                        retriever_attn_mask=retriever_attn_mask,
                        encoder_output=encoder_output,
                        enc_dec_attn_mask=enc_dec_attn_mask,
                        inference_params=inference_params)
                elif args.retro_add_retriever and index + 1 in self.P:
                    hidden_states = layer(
                        hidden_states,
                        attention_mask,
                        retriever_output=E,
                        retriever_attn_mask=retriever_attn_mask,
                        encoder_output=encoder_output,
                        enc_dec_attn_mask=enc_dec_attn_mask,
                        inference_params=inference_params)
                else:
                    hidden_states = layer(
                        hidden_states,
                        attention_mask,
                        encoder_output=encoder_output,
                        enc_dec_attn_mask=enc_dec_attn_mask,
                        inference_params=inference_params)

        # Final layer norm.
        output = self.final_layernorm(hidden_states)

        return output