# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""GPT-2 model."""

import torch
import torch.nn.functional as F

import mpu


def init_method_normal(std=0.02):
    """Init method based on a normal distribution.

    This is only used for embeddings. The transformer has its
    own initializer.
    """
    def init_(tensor):
        return torch.nn.init.normal_(tensor, mean=0.0, std=std)
    return init_


class GPT2Model(torch.nn.Module):
    """GPT-2 language model.

    The output of the forward method is the logits (parallel or
    serial, depending on the `parallel_output` flag).
    """

    def __init__(self,
                 num_layers,
                 vocab_size,
                 hidden_size,
                 num_attention_heads,
                 embedding_dropout_prob,
                 attention_dropout_prob,
                 output_dropout_prob,
                 max_sequence_length,
                 checkpoint_activations,
                 checkpoint_num_layers=1,
                 parallel_output=True):

        super(GPT2Model, self).__init__()

        self.parallel_output = parallel_output

        init_method = init_method_normal(std=0.02)

        # Word embeddings (parallel).
        self.word_embeddings = mpu.VocabParallelEmbedding(
            vocab_size, hidden_size, init_method=init_method)

        # Position embeddings (serial).
        self.position_embeddings = torch.nn.Embedding(max_sequence_length,
                                                      hidden_size)

        # Token type embeddings.
        # Added as an optional field that can be created through a
        # method call, so a pretrained model can be loaded without
        # token types and they can be added later as needed.
        self.tokentype_embeddings = None
        self.hidden_size = hidden_size

        # Initialize the position embeddings.
        init_method(self.position_embeddings.weight)

        # Embeddings dropout.
        self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)

        # Transformer.
        self.transformer = mpu.GPT2ParallelTransformer(num_layers,
                                                       hidden_size,
                                                       num_attention_heads,
                                                       attention_dropout_prob,
                                                       output_dropout_prob,
                                                       checkpoint_activations,
                                                       checkpoint_num_layers)

    def add_tokentype_embeddings(self, num_tokentypes):
        if self.tokentype_embeddings is not None:
            raise Exception('tokentype embeddings are already initialized')
        if torch.distributed.get_rank() == 0:
            print('adding embedding for {} tokentypes'.format(num_tokentypes),
                  flush=True)
        self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes,
                                                       self.hidden_size)

    def forward(self, input_ids, position_ids, attention_mask,
                layer_past=None, get_present=False, tokentype_ids=None):

        # Embeddings.
        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        embeddings = words_embeddings + position_embeddings
        if tokentype_ids is not None:
            assert self.tokentype_embeddings is not None
            embeddings = embeddings + self.tokentype_embeddings(tokentype_ids)
        else:
            assert self.tokentype_embeddings is None

        # Dropout.
        embeddings = self.embedding_dropout(embeddings)

        # Transformer.
        transformer_output = self.transformer(embeddings, attention_mask,
                                              layer_past=layer_past,
                                              get_present=get_present)
        if get_present:
            transformer_output, presents = transformer_output

        # Parallel logits.
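        # The output projection is tied to the input word embedding matrix:
        # F.linear below computes hidden_states @ word_embeddings.weight^T,
        # producing logits over this rank's partition of the vocabulary.
        # With `parallel_output=False`, the partitions are gathered so every
        # rank returns logits over the full vocabulary.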
        transformer_output_parallel = mpu.copy_to_model_parallel_region(
            transformer_output)
        logits_parallel = F.linear(transformer_output_parallel,
                                   self.word_embeddings.weight)

        if self.parallel_output:
            output = logits_parallel
        else:
            output = mpu.gather_from_model_parallel_region(logits_parallel)

        if get_present:
            output = [output, presents]

        return output


def gpt2_get_params_for_weight_decay_optimization(module):
    """Split parameters into weight-decay and no-weight-decay groups.

    LayerNorm parameters and biases get no weight decay; everything
    else goes into the weight-decay group.
    """
    weight_decay_params = {'params': []}
    no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
    for module_ in module.modules():
        if isinstance(module_, (mpu.LayerNorm, torch.nn.LayerNorm)):
            no_weight_decay_params['params'].extend(
                [p for p in list(module_._parameters.values())
                 if p is not None])
        else:
            weight_decay_params['params'].extend(
                [p for n, p in list(module_._parameters.items())
                 if p is not None and n != 'bias'])
            no_weight_decay_params['params'].extend(
                [p for n, p in list(module_._parameters.items())
                 if p is not None and n == 'bias'])

    return weight_decay_params, no_weight_decay_params
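

# Example usage: a minimal single-process sketch of how the model and the
# weight-decay parameter groups defined above are typically wired together.
# This is illustrative only; it assumes the surrounding `mpu` package exposes
# `initialize_model_parallel` as in the upstream Megatron-LM repository, and
# the hyperparameters below are arbitrary small values.
if __name__ == '__main__':
    import os

    # Single-process "distributed" setup so the model-parallel utilities in
    # `mpu` can create their process groups.
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '6000')
    torch.distributed.init_process_group(backend='gloo', rank=0, world_size=1)
    mpu.initialize_model_parallel(1)  # assumed mpu API, as in Megatron-LM

    model = GPT2Model(num_layers=2,
                      vocab_size=50304,
                      hidden_size=256,
                      num_attention_heads=8,
                      embedding_dropout_prob=0.1,
                      attention_dropout_prob=0.1,
                      output_dropout_prob=0.1,
                      max_sequence_length=1024,
                      checkpoint_activations=False)

    # Biases and LayerNorm parameters end up in the zero-weight-decay group.
    param_groups = gpt2_get_params_for_weight_decay_optimization(model)
    optimizer = torch.optim.Adam(param_groups, lr=1.5e-4, weight_decay=0.01)
    print('weight-decay params:', len(param_groups[0]['params']))
    print('no-weight-decay params:', len(param_groups[1]['params']))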