# rewritten, Copyright (c) 2021, Ming Ding. All rights reserved.
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformer."""

import copy
import math

import torch
import torch.nn.functional as F
from deepspeed.runtime.activation_checkpointing.checkpointing import \
    non_reentrant_checkpoint as checkpoint

from sat import mpu
from sat.model.transformer import BaseTransformerLayer
from sat.mpu import (ColumnParallelLinear, RowParallelLinear,
                     VocabParallelEmbedding, copy_to_model_parallel_region,
                     gather_from_model_parallel_region,
                     get_model_parallel_world_size)
from sat.mpu.utils import (divide, gelu, scaled_init_method, sqrt,
                           unscaled_init_method)
from sat.ops.layernorm import LayerNorm
from sat.transformer_defaults import (HOOKS_DEFAULT,
                                      split_tensor_along_last_dim,
                                      standard_attention)

# checkpoint


class GCBaseTransformer(torch.nn.Module):

    def __init__(self,
                 num_layers,
                 vocab_size,
                 hidden_size,
                 num_attention_heads,
                 max_sequence_length,
                 embedding_dropout_prob=0,
                 attention_dropout_prob=0,
                 output_dropout_prob=0,
                 drop_path=0,
                 checkpoint_activations=False,
                 checkpoint_num_layers=1,
                 checkpoint_skip_layers=0,
                 layernorm_epsilon=1.0e-5,
                 init_method_std=0.02,
                 inner_hidden_size=None,
                 hidden_size_per_attention_head=None,
                 cross_hidden_size_per_attention_head=None,
                 layernorm_order='pre',
                 parallel_output=False,
                 is_decoder=False,
                 cross_attn_hidden_size=None,
                 use_bias=True,
                 use_qkv_bias=False,
                 num_multi_query_heads=0,
                 cross_num_multi_query_heads=0,
                 row_parallel_linear_final_bias=True,
                 activation_func=gelu,
                 is_gated_mlp=False,
                 is_rotary_emb=False,
                 num_experts=1,
                 layernorm=LayerNorm,
                 init_method=None,
                 use_final_layernorm=True,
                 hooks={},
                 params_dtype=torch.float,
                 skip_init=False,
                 device=torch.device('cpu')):
        super().__init__()

        # recording parameters
        self.hidden_size = hidden_size
        self.inner_hidden_size = inner_hidden_size
        self.hidden_size_per_attention_head = hidden_size_per_attention_head
        self.cross_hidden_size_per_attention_head = cross_hidden_size_per_attention_head
        self.is_decoder = is_decoder
        self.cross_attn_hidden_size = cross_attn_hidden_size
        self.cross_num_multi_query_heads = cross_num_multi_query_heads
        if not is_decoder and cross_attn_hidden_size is not None:
            print('warning: cross_attn_hidden_size is set but is_decoder is False')
        self.use_bias = use_bias
        self.use_qkv_bias = use_qkv_bias
        self.num_multi_query_heads = num_multi_query_heads
        self.is_gated_mlp = is_gated_mlp
        self.is_rotary_emb = is_rotary_emb
        self.num_experts = num_experts
        self.use_final_layernorm = use_final_layernorm
        self.layernorm_epsilon = layernorm_epsilon
        self.parallel_output = parallel_output
        self.checkpoint_activations = checkpoint_activations
        self.checkpoint_num_layers = checkpoint_num_layers
        self.checkpoint_skip_layers = checkpoint_skip_layers
        assert checkpoint_skip_layers <= num_layers - checkpoint_num_layers, \
            'checkpoint_skip_layers too large. Please consider removing checkpoint_activations.'
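        # Note: the last `checkpoint_skip_layers` layers of the stack are run
        # without activation checkpointing (see the chunked loop in forward()),
        # which is why it may not exceed num_layers - checkpoint_num_layers.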
        self.max_sequence_length = max_sequence_length
        self.layernorm_order = layernorm_order
        self.row_parallel_linear_final_bias = row_parallel_linear_final_bias
        self.hooks = copy.copy(hooks)  # hooks will be updated each forward
        object.__setattr__(
            self, 'transformer',
            self)  # to give the default hooks the same api as outer hooks

        # create embedding parameters
        self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)

        if vocab_size < 1000:
            self.word_embeddings = torch.nn.Embedding(vocab_size,
                                                      hidden_size,
                                                      dtype=params_dtype,
                                                      device=device)
            torch.nn.init.normal_(self.word_embeddings.weight,
                                  mean=0.0,
                                  std=init_method_std)
        else:
            self.word_embeddings = VocabParallelEmbedding(
                num_embeddings=vocab_size,
                embedding_dim=hidden_size,
                params_dtype=params_dtype,
                skip_init=skip_init,
                device=device)

        if self.is_rotary_emb:
            from sat.model.position_embedding.triton_rotary_embeddings import \
                FastRotaryEmbedding
            self.position_embeddings = FastRotaryEmbedding(hidden_size //
                                                           num_attention_heads)
        else:
            self.position_embeddings = torch.nn.Embedding(
                max_sequence_length, hidden_size)
            torch.nn.init.normal_(self.position_embeddings.weight,
                                  mean=0.0,
                                  std=init_method_std)

        # create all layers
        if init_method is None:
            self.output_layer_init_method = scaled_init_method(
                init_method_std, num_layers)
            self.init_method = unscaled_init_method(init_method_std)
        else:
            self.output_layer_init_method = init_method
            self.init_method = init_method

        def get_layer(layer_id):
            return BaseTransformerLayer(
                hidden_size,
                num_attention_heads,
                attention_dropout_prob,
                output_dropout_prob,
                layernorm_epsilon,
                self.init_method,
                layer_id,
                inner_hidden_size=inner_hidden_size,
                hidden_size_per_attention_head=hidden_size_per_attention_head,
                cross_hidden_size_per_attention_head=cross_hidden_size_per_attention_head,
                output_layer_init_method=self.output_layer_init_method,
                is_decoder=self.is_decoder,
                cross_attn_hidden_size=cross_attn_hidden_size,
                layernorm_order=layernorm_order,
                layernorm=layernorm,
                use_bias=use_bias,
                use_qkv_bias=use_qkv_bias,
                num_multi_query_heads=num_multi_query_heads,
                cross_num_multi_query_heads=cross_num_multi_query_heads,
                row_parallel_linear_final_bias=row_parallel_linear_final_bias,
                drop_path=drop_path,
                activation_func=activation_func,
                is_gated_mlp=is_gated_mlp,
                num_experts=num_experts,
                hooks=self.hooks,
                transformer_pointer=self,
                params_dtype=params_dtype,
                skip_init=skip_init,
                device=device)

        self.layers = torch.nn.ModuleList(
            [get_layer(layer_id) for layer_id in range(num_layers)])

        # Final layer norm before output.
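        # (Skipped entirely when use_final_layernorm is False; forward() then
        # passes the last hidden states to final_forward unnormalized.)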
        if use_final_layernorm:
            self.final_layernorm = layernorm(hidden_size, eps=layernorm_epsilon)

    def forward(self,
                input_ids,
                position_ids,
                attention_mask,
                *,
                output_hidden_states=False,
                **kw_args):
        # sanity check
        assert len(input_ids.shape) >= 2
        batch_size, query_length = input_ids.shape[:2]

        if attention_mask is None:
            # Definition: None means full attention
            attention_mask = torch.ones(1, 1, device=input_ids.device)
        elif isinstance(attention_mask, int) and (attention_mask < 0):
            # Definition: -1 means lower triangular attention mask
            attention_mask = torch.ones(query_length,
                                        query_length,
                                        device=input_ids.device).tril()

        attention_mask = attention_mask.type_as(next(self.parameters()))
        assert len(attention_mask.shape) == 2 or \
            len(attention_mask.shape) == 4 and attention_mask.shape[1] == 1

        # initial output_cross_layer might be generated by word/position_embedding_forward
        output_cross_layer = {}

        # embedding part
        if 'word_embedding_forward' in self.hooks:
            hidden_states = self.hooks['word_embedding_forward'](
                input_ids, output_cross_layer=output_cross_layer, **kw_args)
        else:  # default
            hidden_states = HOOKS_DEFAULT['word_embedding_forward'](
                self, input_ids, output_cross_layer=output_cross_layer, **kw_args)

        # handle position embedding
        if 'position_embedding_forward' in self.hooks:
            position_embeddings = self.hooks['position_embedding_forward'](
                position_ids, output_cross_layer=output_cross_layer, **kw_args)
        else:
            assert len(position_ids.shape) <= 2
            assert position_ids.shape[-1] == hidden_states.shape[1], (
                position_ids.shape, hidden_states.shape)
            position_embeddings = HOOKS_DEFAULT['position_embedding_forward'](
                self, position_ids, output_cross_layer=output_cross_layer, **kw_args)
        if position_embeddings is not None:
            hidden_states = hidden_states + position_embeddings
        hidden_states = self.embedding_dropout(hidden_states)

        output_per_layers = []
        if self.checkpoint_activations:

            # define custom_forward for checkpointing
            def custom(start, end, kw_args_index, cross_layer_index):

                def custom_forward(*inputs):
                    layers_ = self.layers[start:end]
                    x_, mask = inputs[0], inputs[1]

                    # recover kw_args and output_cross_layer
                    flat_inputs = inputs[2:]
                    kw_args, output_cross_layer = {}, {}
                    for k, idx in kw_args_index.items():
                        kw_args[k] = flat_inputs[idx]
                    for k, idx in cross_layer_index.items():
                        output_cross_layer[k] = flat_inputs[idx]
                    # -----------------

                    output_per_layers_part = []
                    for i, layer in enumerate(layers_):
                        output_this_layer_obj, output_cross_layer_obj = {}, {}
                        if 'layer_forward' in self.hooks:
                            layer_ret = self.hooks['layer_forward'](
                                x_,
                                mask,
                                layer_id=layer.layer_id,
                                **kw_args,
                                position_ids=position_ids,
                                **output_cross_layer,
                                output_this_layer=output_this_layer_obj,
                                output_cross_layer=output_cross_layer_obj)
                        else:
                            layer_ret = layer(
                                x_,
                                mask,
                                layer_id=layer.layer_id,
                                **kw_args,
                                position_ids=position_ids,
                                **output_cross_layer,
                                output_this_layer=output_this_layer_obj,
                                output_cross_layer=output_cross_layer_obj)
                        if isinstance(layer_ret, tuple):
                            layer_ret = layer_ret[0]  # for legacy API
                        x_, output_this_layer, output_cross_layer = layer_ret, output_this_layer_obj, output_cross_layer_obj
                        if output_hidden_states:
                            output_this_layer['hidden_states'] = x_
                        output_per_layers_part.append(output_this_layer)

                    # flatten keyword outputs so they can be re-aggregated
                    # after checkpointing
                    flat_outputs = []
                    for output_this_layer in output_per_layers_part:
                        for k in output_this_layer:
                            # TODO add warning for depth>=2 grad tensors
                            flat_outputs.append(output_this_layer[k])
                            output_this_layer[k] = len(flat_outputs) - 1
                    for k in output_cross_layer:
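                        # cross-layer outputs get the same flatten-and-index
                        # treatment as the per-layer outputs above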
                        flat_outputs.append(output_cross_layer[k])
                        output_cross_layer[k] = len(flat_outputs) - 1
                    # --------------------

                    return (x_, output_per_layers_part, output_cross_layer,
                            *flat_outputs)

                return custom_forward

            # prevent losing requires_grad in checkpointing.
            # To save memory when only finetuning the final layers, don't use checkpointing.
            if self.training:
                hidden_states.requires_grad_(True)

            l, num_layers = 0, len(self.layers)
            chunk_length = self.checkpoint_num_layers
            output_this_layer = []
            while l < num_layers:
                args = [hidden_states, attention_mask]
                # flatten kw_args and output_cross_layer
                flat_inputs, kw_args_index, cross_layer_index = [], {}, {}
                for k, v in kw_args.items():
                    flat_inputs.append(v)
                    kw_args_index[k] = len(flat_inputs) - 1
                for k, v in output_cross_layer.items():
                    flat_inputs.append(v)
                    cross_layer_index[k] = len(flat_inputs) - 1
                # --------------------
                if l + self.checkpoint_skip_layers >= num_layers:
                    # no checkpointing
                    hidden_states, output_per_layers_part, output_cross_layer, *flat_outputs = \
                        custom(l, l + chunk_length, kw_args_index, cross_layer_index)(*args, *flat_inputs)
                else:
                    hidden_states, output_per_layers_part, output_cross_layer, *flat_outputs = \
                        checkpoint(custom(l, l + chunk_length, kw_args_index, cross_layer_index),
                                   *args, *flat_inputs)

                # recover output_per_layers_part, output_cross_layer
                for output_this_layer in output_per_layers_part:
                    for k in output_this_layer:
                        output_this_layer[k] = flat_outputs[output_this_layer[k]]
                for k in output_cross_layer:
                    output_cross_layer[k] = flat_outputs[output_cross_layer[k]]
                # --------------------

                output_per_layers.extend(output_per_layers_part)
                l += chunk_length
        else:
            output_this_layer = []
            for i, layer in enumerate(self.layers):
                args = [hidden_states, attention_mask]

                output_this_layer_obj, output_cross_layer_obj = {}, {}

                if 'layer_forward' in self.hooks:  # customized layer_forward
                    layer_ret = self.hooks['layer_forward'](
                        *args,
                        layer_id=torch.tensor(i),
                        **kw_args,
                        position_ids=position_ids,
                        **output_cross_layer,
                        output_this_layer=output_this_layer_obj,
                        output_cross_layer=output_cross_layer_obj)
                else:
                    layer_ret = layer(
                        *args,
                        layer_id=torch.tensor(i),
                        **kw_args,
                        position_ids=position_ids,
                        **output_cross_layer,
                        output_this_layer=output_this_layer_obj,
                        output_cross_layer=output_cross_layer_obj)
                if isinstance(layer_ret, tuple):
                    layer_ret = layer_ret[0]  # for legacy API
                hidden_states, output_this_layer, output_cross_layer = layer_ret, output_this_layer_obj, output_cross_layer_obj

                if output_hidden_states:
                    output_this_layer['hidden_states'] = hidden_states
                output_per_layers.append(output_this_layer)

        # Final layer norm.
        if self.use_final_layernorm:
            logits = self.final_layernorm(hidden_states)
        else:
            logits = hidden_states

        logits = copy_to_model_parallel_region(logits)
        if 'final_forward' in self.hooks:
            logits_parallel = self.hooks['final_forward'](
                logits, **kw_args, parallel_output=self.parallel_output)
        else:
            logits_parallel = HOOKS_DEFAULT['final_forward'](
                self, logits, **kw_args, parallel_output=self.parallel_output)

        outputs = [logits_parallel]
        outputs.extend(output_per_layers)

        return outputs
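

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# It assumes the sat model-parallel environment has already been initialized
# elsewhere (the parallel layers inside BaseTransformerLayer require it) and
# that deepspeed's activation checkpointing works with its defaults; the sizes
# below are arbitrary smoke-test values.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    model = GCBaseTransformer(
        num_layers=2,
        vocab_size=128,  # < 1000, so a plain torch.nn.Embedding is used
        hidden_size=64,
        num_attention_heads=4,
        max_sequence_length=32,
        checkpoint_activations=True,  # exercise the checkpointed branch
        checkpoint_num_layers=1,
    )
    batch_size, seq_len = 2, 16
    input_ids = torch.randint(0, 128, (batch_size, seq_len))
    position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch_size, -1)
    # attention_mask=None requests full attention; -1 would request a
    # lower-triangular (causal) mask instead.
    logits, *per_layer_outputs = model(input_ids, position_ids, None)
    # With the default final_forward hook this is typically
    # [batch_size, seq_len, vocab_size].
    print(logits.shape)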