Commit 73af1290 authored by Mohammad Shoeybi's avatar Mohammad Shoeybi Committed by Raul Puri

Major refactoring, combining gpt2 and bert

parent ebbe40cd
# ===========
# base images
# ===========
FROM nvcr.io/nvidia/pytorch:19.05-py3
FROM nvcr.io/nvidia/pytorch:19.09-py3
# ===============
......@@ -27,12 +27,3 @@ RUN pip install --upgrade pip && \
COPY requirements.txt /tmp/
RUN pip install --upgrade --ignore-installed -r /tmp/requirements.txt
# ===========
# latest apex
# ===========
RUN pip uninstall -y apex && \
git clone https://github.com/NVIDIA/apex.git ~/apex && \
cd ~/apex && \
pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" .
Note that, as of now, you need to have PySOL cloned into this directory before building the container.
......@@ -29,7 +29,6 @@ from megatron.fp16 import FP16_Module
from megatron.fp16 import FP16_Optimizer
from megatron.learning_rates import AnnealingLR
from megatron.model import GPT2Model
from megatron.model import gpt2_get_params_for_weight_decay_optimization
from megatron.model import DistributedDataParallel as DDP
from megatron import mpu
from apex.optimizers import FusedAdam as Adam
......
......@@ -26,9 +26,8 @@ import argparse
import time
from arguments import get_args
from megatron.utils import Timers
from pretrain_gpt2 import initialize_distributed
from pretrain_gpt2 import set_random_seed
from pretrain_gpt2 import get_train_val_test_data
from megatron.utils import initialize_distributed
from megatron.utils import set_random_seed
from pretrain_gpt2 import get_masks_and_position_ids
from megatron.utils import load_checkpoint
from megatron.data_utils import make_tokenizer
......@@ -96,7 +95,8 @@ def get_batch(context_tokens, args):
tokens,
args.eod_token,
args.reset_position_ids,
args.reset_attention_mask)
args.reset_attention_mask,
False)
return tokens, attention_mask, position_ids
......@@ -361,7 +361,7 @@ def switch(val1, val2, boolean):
boolean = boolean.type_as(val1)
return (1-boolean)*val1 + boolean*val2
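# A standalone sketch of the blend computed by switch() above, using plain
# torch tensors: positions where `boolean` is 1 take their value from val2,
# positions where it is 0 keep the value from val1.
import torch
_val1 = torch.tensor([1.0, 2.0, 3.0])
_val2 = torch.tensor([10.0, 20.0, 30.0])
_boolean = torch.tensor([0, 1, 0]).type_as(_val1)
assert torch.equal((1 - _boolean) * _val1 + _boolean * _val2,
                   torch.tensor([1.0, 20.0, 3.0]))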
def sample_sequence_batch(model, context_tokens, context_lengths, attention_mask, position_ids, tokenizer, args, maxlen=None):
def sample_sequence_batch(model, context_tokens, context_lengths, attention_mask, position_ids, tokenizer, args, maxlen=None, type_ids=None):
model.eval()
with torch.no_grad():
context_length = context_lengths.min().item()
......@@ -384,16 +384,21 @@ def sample_sequence_batch(model, context_tokens, context_lengths, attention_mask
while context_length <= (maxlen):
if args.recompute:
logits = model(tokens, position_ids, attention_mask)
logits = logits[:, context_length - 1, :]
logits = model(tokens, position_ids, attention_mask, tokentype_ids=type_ids)
logits = logits[:, context_length - 1, :]
else:
types2use = None
if counter == 0:
tokens2use = tokens[:, :context_length]
positions2use = position_ids[:, :context_length]
if type_ids is not None:
types2use = type_ids[:, :context_length]
else:
tokens2use = tokens[:, context_length - 1].view(batch_size, -1)
positions2use = position_ids[:, context_length - 1].view(batch_size, -1)
logits, layer_past = model(tokens2use, positions2use, attention_mask, layer_past=layer_past, get_present=True)
if type_ids is not None:
types2use = type_ids[:, context_length - 1].view(batch_size, -1)
logits, layer_past = model(tokens2use, positions2use, attention_mask, layer_past=layer_past, get_key_value=True, tokentype_ids=types2use)
logits = logits[:, -1].view(batch_size,-1).contiguous()
if args.greedy:
......
......@@ -22,6 +22,9 @@ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from .loss_scaler import DynamicLossScaler, LossScaler
from .fp16util import model_grads_to_master_grads, master_params_to_model_params, clip_grad_norm
from megatron.module import MegatronModule
FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)
......@@ -56,7 +59,7 @@ def fp16_to_fp32(val):
return val
return conversion_helper(val, float_conversion)
class FP16_Module(nn.Module):
class FP16_Module(MegatronModule):
def __init__(self, module):
super(FP16_Module, self).__init__()
self.add_module('module', module.half())
......@@ -67,6 +70,11 @@ class FP16_Module(nn.Module):
def state_dict(self, destination=None, prefix='', keep_vars=False):
return self.module.state_dict(destination, prefix, keep_vars)
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
return self.module.state_dict_for_save_checkpoint(destination, prefix,
keep_vars)
def load_state_dict(self, state_dict, strict=True):
self.module.load_state_dict(state_dict, strict=strict)
......
......@@ -14,7 +14,6 @@
# limitations under the License.
from .distributed import *
from .gpt2_modeling import gpt2_get_params_for_weight_decay_optimization
from .gpt2_modeling import GPT2Model
from .model import BertModel
from .model import get_params_for_weight_decay_optimization
from .bert_model import BertModel
from .gpt2_model import GPT2Model
from .utils import get_params_for_weight_decay_optimization
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BERT model."""
import torch
from megatron.module import MegatronModule
from .language_model import parallel_lm_logits
from .language_model import get_language_model
from .transformer import LayerNorm
from .utils import gelu
from .utils import get_linear_layer
from .utils import init_method_normal
from .utils import scaled_init_method_normal
def bert_attention_mask_func(attention_scores, attention_mask):
attention_scores = attention_scores + attention_mask
return attention_scores
def bert_extended_attention_mask(attention_mask, dtype):
# We create a 3D attention mask from a 2D tensor mask.
# [b, 1, s]
attention_mask_b1s = attention_mask.unsqueeze(1)
# [b, s, 1]
attention_mask_bs1 = attention_mask.unsqueeze(2)
# [b, s, s]
attention_mask_bss = attention_mask_b1s * attention_mask_bs1
# [b, 1, s, s]
extended_attention_mask = attention_mask_bss.unsqueeze(1)
# Since attention_mask is 1.0 for positions we want to attend and 0.0
# for masked positions, this operation will create a tensor which is
# 0.0 for positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
# fp16 compatibility
extended_attention_mask = extended_attention_mask.to(dtype=dtype)
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
return extended_attention_mask
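# A self-contained sketch of the mask construction above, using a plain
# [b, s] padding mask (1 = attend, 0 = padding) and no Megatron dependencies.
import torch
_mask = torch.tensor([[1, 1, 1, 0]])                       # [b=1, s=4]
_mask_bss = _mask.unsqueeze(1) * _mask.unsqueeze(2)        # [1, 4, 4]
_extended = (1.0 - _mask_bss.unsqueeze(1).to(torch.float)) * -10000.0
# Scores involving the padded position are pushed to -10000.0; all other
# positions receive 0.0, so the softmax effectively ignores the padding token.
assert _extended[0, 0, 0, 3] == -10000.0 and _extended[0, 0, 0, 0] == 0.0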
def bert_position_ids(token_ids):
# Create position ids
seq_length = token_ids.size(1)
position_ids = torch.arange(seq_length, dtype=torch.long,
device=token_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(token_ids)
return position_ids
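# Standalone illustration of the position ids produced above: for token_ids
# of shape [b=2, s=3], every row is simply [0, 1, 2].
import torch
_tok = torch.zeros(2, 3, dtype=torch.long)
_pos = torch.arange(3, dtype=torch.long).unsqueeze(0).expand_as(_tok)
assert _pos.tolist() == [[0, 1, 2], [0, 1, 2]]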
class BertLMHead(MegatronModule):
"""Masked LM head for Bert
Arguments:
mpu_vocab_size: model parallel size of vocabulary.
hidden_size: hidden size
init_method: init method for weight initialization
layernorm_epsilon: tolerance for layer norm divisions
parallel_output: whether the output logits are kept model parallel (distributed) or gathered.
"""
def __init__(self, mpu_vocab_size, hidden_size, init_method,
layernorm_epsilon, parallel_output):
super(BertLMHead, self).__init__()
self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
self.bias.model_parallel = True
self.parallel_output = parallel_output
self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
self.layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
def forward(self, hidden_states, word_embeddings_weight):
hidden_states = self.dense(hidden_states)
hidden_states = gelu(hidden_states)
hidden_states = self.layernorm(hidden_states)
output = parallel_lm_logits(hidden_states,
word_embeddings_weight,
self.parallel_output,
bias=self.bias)
return output
class BertModel(MegatronModule):
"""Bert Language model."""
def __init__(self,
num_layers,
vocab_size,
hidden_size,
num_attention_heads,
embedding_dropout_prob,
attention_dropout_prob,
output_dropout_prob,
max_sequence_length,
checkpoint_activations,
checkpoint_num_layers=1,
add_binary_head=False,
layernorm_epsilon=1.0e-5,
init_method_std=0.02,
num_tokentypes=0,
parallel_output=True):
super(BertModel, self).__init__()
self.add_binary_head = add_binary_head
self.parallel_output = parallel_output
init_method = init_method_normal(init_method_std)
self.language_model, self._language_model_key = get_language_model(
num_layers=num_layers,
vocab_size=vocab_size,
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
embedding_dropout_prob=embedding_dropout_prob,
attention_dropout_prob=attention_dropout_prob,
output_dropout_prob=output_dropout_prob,
max_sequence_length=max_sequence_length,
num_tokentypes=num_tokentypes,
add_pooler=self.add_binary_head,
attention_mask_func=bert_attention_mask_func,
checkpoint_activations=checkpoint_activations,
checkpoint_num_layers=checkpoint_num_layers,
layernorm_epsilon=layernorm_epsilon,
init_method=init_method,
scaled_init_method=scaled_init_method_normal(init_method_std,
num_layers),
residual_connection_post_layernorm=True)
self.lm_head = BertLMHead(
self.language_model.embedding.word_embeddings.weight.size(0),
hidden_size, init_method, layernorm_epsilon, parallel_output)
self._lm_head_key = 'lm_head'
if self.add_binary_head:
self.binary_head = get_linear_layer(hidden_size, 2, init_method)
self._binary_head_key = 'binary_head'
def forward(self, input_ids, attention_mask,
tokentype_ids=None):
extended_attention_mask = bert_extended_attention_mask(
attention_mask, next(self.language_model.parameters()).dtype)
position_ids = bert_position_ids(input_ids)
if self.add_binary_head:
lm_output, pooled_output = self.language_model(
input_ids,
position_ids,
extended_attention_mask,
tokentype_ids=tokentype_ids)
else:
lm_output = self.language_model(
input_ids,
position_ids,
extended_attention_mask,
tokentype_ids=tokentype_ids)
# Output.
lm_logits = self.lm_head(
lm_output, self.language_model.embedding.word_embeddings.weight)
if self.add_binary_head:
binary_logits = self.binary_head(pooled_output)
return lm_logits, binary_logits
return lm_logits, None
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
"""For easy load when model is combined with other heads,
add an extra key."""
state_dict_ = {}
state_dict_[self._language_model_key] \
= self.language_model.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
state_dict_[self._lm_head_key] \
= self.lm_head.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
if self.add_binary_head:
state_dict_[self._binary_head_key] \
= self.binary_head.state_dict(destination, prefix, keep_vars)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Customized load."""
self.language_model.load_state_dict(
state_dict[self._language_model_key], strict=strict)
self.lm_head.load_state_dict(state_dict[self._lm_head_key],
strict=strict)
if self.add_binary_head:
self.binary_head.load_state_dict(state_dict[self._binary_head_key],
strict=strict)
......@@ -20,8 +20,10 @@ from torch.nn.modules import Module
from torch.autograd import Variable
from megatron import mpu
from megatron.module import MegatronModule
class DistributedDataParallel(Module):
class DistributedDataParallel(MegatronModule):
def __init__(self, module):
super(DistributedDataParallel, self).__init__()
......@@ -86,6 +88,11 @@ class DistributedDataParallel(Module):
return sd
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
return self.module.state_dict_for_save_checkpoint(destination, prefix,
keep_vars)
def load_state_dict(self, state_dict, strict=True):
self.module.load_state_dict(state_dict, strict=strict)
......
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""GPT-2 model."""
import torch
from megatron.module import MegatronModule
from .language_model import parallel_lm_logits
from .language_model import get_language_model
from .utils import init_method_normal
from .utils import scaled_init_method_normal
def gpt2_attention_mask_func(attention_scores, ltor_mask):
attention_scores = torch.mul(attention_scores, ltor_mask) - \
10000.0 * (1.0 - ltor_mask)
return attention_scores
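# A standalone sketch of the multiplicative causal mask the function above
# expects: `ltor_mask` is 1 for visible (left-to-right) positions and 0 for
# future positions, whose scores are pushed to -10000 before the softmax.
import torch
_s = 4
_scores = torch.randn(1, 1, _s, _s)                        # [b, np, s, s]
_ltor_mask = torch.tril(torch.ones(1, 1, _s, _s))          # lower triangular
_masked = torch.mul(_scores, _ltor_mask) - 10000.0 * (1.0 - _ltor_mask)
_probs = torch.nn.Softmax(dim=-1)(_masked)
# Each position attends only to itself and earlier positions.
assert torch.all(_probs.triu(diagonal=1) < 1e-4)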
class GPT2Model(MegatronModule):
"""GPT-2 Language model."""
def __init__(self,
num_layers,
vocab_size,
hidden_size,
num_attention_heads,
embedding_dropout_prob,
attention_dropout_prob,
output_dropout_prob,
max_sequence_length,
checkpoint_activations,
checkpoint_num_layers=1,
layernorm_epsilon=1.0e-5,
init_method_std=0.02,
num_tokentypes=0,
parallel_output=True):
super(GPT2Model, self).__init__()
self.parallel_output = parallel_output
self.language_model, self._language_model_key = get_language_model(
num_layers=num_layers,
vocab_size=vocab_size,
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
embedding_dropout_prob=embedding_dropout_prob,
attention_dropout_prob=attention_dropout_prob,
output_dropout_prob=output_dropout_prob,
max_sequence_length=max_sequence_length,
num_tokentypes=num_tokentypes,
add_pooler=False,
attention_mask_func=gpt2_attention_mask_func,
checkpoint_activations=checkpoint_activations,
checkpoint_num_layers=checkpoint_num_layers,
layernorm_epsilon=layernorm_epsilon,
init_method=init_method_normal(init_method_std),
scaled_init_method=scaled_init_method_normal(init_method_std,
num_layers),
residual_connection_post_layernorm=False)
def forward(self, input_ids, position_ids, attention_mask,
tokentype_ids=None, layer_past=None, get_key_value=False):
# Language model.
lm_output = self.language_model(input_ids,
position_ids,
attention_mask,
tokentype_ids=tokentype_ids,
layer_past=layer_past,
get_key_value=get_key_value)
if get_key_value:
lm_output, presents = lm_output
# Output.
output = parallel_lm_logits(
lm_output,
self.language_model.embedding.word_embeddings.weight,
self.parallel_output)
if get_key_value:
output = [output, presents]
return output
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
state_dict_ = {}
state_dict_[self._language_model_key] \
= self.language_model.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Customized load."""
if self._language_model_key in state_dict:
state_dict = state_dict[self._language_model_key]
self.language_model.load_state_dict(state_dict, strict=strict)
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""GPT-2 model."""
import torch
import torch.nn.functional as F
from megatron import mpu
def init_method_normal(std=0.02):
"""Init method based on normal distribution.
This is only used for embeddings. The transformer has its
own initializer.
"""
def init_(tensor):
return torch.nn.init.normal_(tensor, mean=0.0, std=std)
return init_
class GPT2Model(torch.nn.Module):
"""GPT-2 Language model.
The output of the forward method is the logits (parallel or
serial depending on the `parallel_output` flag).
"""
def __init__(self,
num_layers,
vocab_size,
hidden_size,
num_attention_heads,
embedding_dropout_prob,
attention_dropout_prob,
output_dropout_prob,
max_sequence_length,
checkpoint_activations,
checkpoint_num_layers=1,
parallel_output=True):
super(GPT2Model, self).__init__()
self.parallel_output = parallel_output
init_method = init_method_normal(std=0.02)
# Word embeddings (parallel).
self.word_embeddings = mpu.VocabParallelEmbedding(
vocab_size, hidden_size, init_method=init_method)
# Position embedding (serial).
self.position_embeddings = torch.nn.Embedding(max_sequence_length,
hidden_size)
# Token type embedding.
# Add this as an optional field that can be added through
# method call so we can load a pretrain model without
# token types and add them as needed.
self.tokentype_embeddings = None
self.hidden_size = hidden_size
# Initialize the position embeddings.
init_method(self.position_embeddings.weight)
# Embeddings dropout
self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
# Transformer
self.transformer = mpu.GPT2ParallelTransformer(num_layers,
hidden_size,
num_attention_heads,
attention_dropout_prob,
output_dropout_prob,
checkpoint_activations,
checkpoint_num_layers)
def add_tokentype_embeddings(self, num_tokentypes):
if self.tokentype_embeddings is not None:
raise Exception('tokentype embeddings is already initialized')
if torch.distributed.get_rank() == 0:
print('adding embedding for {} tokentypes'.format(num_tokentypes),
flush=True)
self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes,
self.hidden_size)
def forward(self, input_ids, position_ids, attention_mask,
layer_past=None, get_present=False, tokentype_ids=None):
# Embeddings.
words_embeddings = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
embeddings = words_embeddings + position_embeddings
if tokentype_ids is not None:
assert self.tokentype_embeddings is not None
embeddings = embeddings + self.tokentype_embeddings(tokentype_ids)
else:
assert self.tokentype_embeddings is None
# Dropout.
embeddings = self.embedding_dropout(embeddings)
# Transformer.
transformer_output = self.transformer(embeddings, attention_mask,
layer_past=layer_past,
get_present=get_present)
if get_present:
transformer_output, presents = transformer_output
# Parallel logits.
transformer_output_parallel = mpu.copy_to_model_parallel_region(
transformer_output)
logits_parallel = F.linear(transformer_output_parallel,
self.word_embeddings.weight)
if self.parallel_output:
output = logits_parallel
else:
output = mpu.gather_from_model_parallel_region(logits_parallel)
if get_present:
output = [output, presents]
return output
def gpt2_get_params_for_weight_decay_optimization(module):
weight_decay_params = {'params': []}
no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
for module_ in module.modules():
if isinstance(module_, (mpu.LayerNorm, torch.nn.LayerNorm)):
no_weight_decay_params['params'].extend(
[p for p in list(module_._parameters.values())
if p is not None])
else:
weight_decay_params['params'].extend(
[p for n, p in list(module_._parameters.items())
if p is not None and n != 'bias'])
no_weight_decay_params['params'].extend(
[p for n, p in list(module_._parameters.items())
if p is not None and n == 'bias'])
return weight_decay_params, no_weight_decay_params
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformer based language model."""
import torch
import torch.nn.functional as F
from megatron import mpu
from megatron.module import MegatronModule
from .transformer import ParallelTransformer
from .transformer import TransformerHyperparameters
from .utils import gelu
from .utils import get_linear_layer
def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
bias=None):
"""LM logits using word embedding weights."""
# Parallel logits.
input_parallel = mpu.copy_to_model_parallel_region(input_)
# Matrix multiply.
if bias is None:
logits_parallel = F.linear(input_parallel, word_embeddings_weight)
else:
logits_parallel = F.linear(input_parallel, word_embeddings_weight, bias)
# Gather if needed.
if parallel_output:
return logits_parallel
else:
return mpu.gather_from_model_parallel_region(logits_parallel)
def get_language_model(num_layers,
vocab_size,
hidden_size,
num_attention_heads,
embedding_dropout_prob,
attention_dropout_prob,
output_dropout_prob,
max_sequence_length,
num_tokentypes,
attention_mask_func,
add_pooler,
checkpoint_activations,
checkpoint_num_layers,
layernorm_epsilon,
init_method,
scaled_init_method,
residual_connection_post_layernorm):
# Transformer hyperparameters.
transformer_hparams = TransformerHyperparameters(
hidden_size=hidden_size,
num_layers=num_layers,
num_attention_heads=num_attention_heads,
attention_dropout_prob=attention_dropout_prob,
output_dropout_prob=output_dropout_prob,
mlp_activation_func=gelu,
layernorm_epsilon=layernorm_epsilon,
init_method=init_method,
output_layer_init_method=scaled_init_method,
checkpoint_activations=checkpoint_activations,
checkpoint_num_layers=checkpoint_num_layers,
apply_residual_connection_post_layernorm=residual_connection_post_layernorm)
# Language model.
language_model = TransformerLanguageModel(
transformer_hparams=transformer_hparams,
attention_mask_func=attention_mask_func,
vocab_size=vocab_size,
max_sequence_length=max_sequence_length,
embedding_dropout_prob=embedding_dropout_prob,
num_tokentypes=num_tokentypes,
add_pooler=add_pooler)
# key used for checkpoints.
language_model_key = 'language_model'
return language_model, language_model_key
class Pooler(MegatronModule):
"""Pooler layer.
Pool hidden states of a specific token (for example start of the
sequence) and add a linear transformation followed by a tanh.
Arguments:
hidden_size: hidden size
init_method: weight initialization method for the linear layer.
bias is set to zero.
"""
def __init__(self, hidden_size, init_method):
super(Pooler, self).__init__()
self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
def forward(self, hidden_states, sequence_index=0):
# hidden_states: [b, s, h]
# sequence_index: index of the token to pool.
pooled = hidden_states[:, sequence_index, :]
pooled = self.dense(pooled)
pooled = torch.tanh(pooled)
return pooled
class Embedding(MegatronModule):
"""Language model embeddings.
Arguments:
hidden_size: hidden size
vocab_size: vocabulary size
max_sequence_length: maximum size of sequence. This
is used for positional embedding
embedding_dropout_prob: dropout probability for embeddings
init_method: weight initialization method
num_tokentypes: size of the token-type embeddings. 0 value
will ignore this embedding
"""
def __init__(self,
hidden_size,
vocab_size,
max_sequence_length,
embedding_dropout_prob,
init_method,
num_tokentypes=0):
super(Embedding, self).__init__()
self.hidden_size = hidden_size
self.init_method = init_method
self.num_tokentypes = num_tokentypes
# Word embeddings (parallel).
self.word_embeddings = mpu.VocabParallelEmbedding(
vocab_size, self.hidden_size, init_method=self.init_method)
self._word_embeddings_key = 'word_embeddings'
# Position embedding (serial).
self.position_embeddings = torch.nn.Embedding(
max_sequence_length, self.hidden_size)
self._position_embeddings_key = 'position_embeddings'
# Initialize the position embeddings.
self.init_method(self.position_embeddings.weight)
# Token type embedding.
# Add this as an optional field that can be added through
# method call so we can load a pretrain model without
# token types and add them as needed.
self._tokentype_embeddings_key = 'tokentype_embeddings'
if self.num_tokentypes > 0:
self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes,
self.hidden_size)
# Initialize the token-type embeddings.
self.init_method(self.tokentype_embeddings.weight)
else:
self.tokentype_embeddings = None
# Embeddings dropout
self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
def add_tokentype_embeddings(self, num_tokentypes):
"""Add token-type embedding. This function is provided so we can add
token-type embeddings in case the pretrained model does not have it.
This allows us to load the model normally and then add this embedding.
"""
if self.tokentype_embeddings is not None:
raise Exception('tokentype embeddings is already initialized')
if torch.distributed.get_rank() == 0:
print('adding embedding for {} tokentypes'.format(num_tokentypes),
flush=True)
self.num_tokentypes = num_tokentypes
self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes,
self.hidden_size)
# Initialize the token-type embeddings.
self.init_method(self.tokentype_embeddings.weight)
def forward(self, input_ids, position_ids, tokentype_ids=None):
# Embeddings.
words_embeddings = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
embeddings = words_embeddings + position_embeddings
if tokentype_ids is not None:
assert self.tokentype_embeddings is not None
embeddings = embeddings + self.tokentype_embeddings(tokentype_ids)
else:
assert self.tokentype_embeddings is None
# Dropout.
embeddings = self.embedding_dropout(embeddings)
return embeddings
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
"""For easy load."""
state_dict_ = {}
state_dict_[self._word_embeddings_key] \
= self.word_embeddings.state_dict(destination, prefix, keep_vars)
state_dict_[self._position_embeddings_key] \
= self.position_embeddings.state_dict(
destination, prefix, keep_vars)
if self.num_tokentypes > 0:
state_dict_[self._tokentype_embeddings_key] \
= self.tokentype_embeddings.state_dict(
destination, prefix, keep_vars)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Customized load."""
# Word embedding.
if self._word_embeddings_key in state_dict:
state_dict_ = state_dict[self._word_embeddings_key]
else:
# for backward compatibility.
state_dict_ = {}
for key in state_dict.keys():
if 'word_embeddings' in key:
state_dict_[key.split('word_embeddings.')[1]] \
= state_dict[key]
self.word_embeddings.load_state_dict(state_dict_, strict=strict)
# Position embedding.
if self._position_embeddings_key in state_dict:
state_dict_ = state_dict[self._position_embeddings_key]
else:
# for backward compatibility.
state_dict_ = {}
for key in state_dict.keys():
if 'position_embeddings' in key:
state_dict_[key.split('position_embeddings.')[1]] \
= state_dict[key]
self.position_embeddings.load_state_dict(state_dict_, strict=strict)
# Tokentype embedding.
if self.num_tokentypes > 0:
state_dict_ = {}
if self._tokentype_embeddings_key in state_dict:
state_dict_ = state_dict[self._tokentype_embeddings_key]
else:
# for backward compatibility.
for key in state_dict.keys():
if 'tokentype_embeddings' in key:
state_dict_[key.split('tokentype_embeddings.')[1]] \
= state_dict[key]
if len(state_dict_.keys()) > 0:
self.tokentype_embeddings.load_state_dict(state_dict_,
strict=strict)
else:
print('***WARNING*** expected tokentype embeddings in the '
'checkpoint but could not find it', flush=True)
class TransformerLanguageModel(MegatronModule):
"""Transformer language model.
Arguments:
transformer_hparams: transformer hyperparameters
attention_mask_func: a function that takes `unmasked-attention-scores`
with size [b, np, s, s] and an `attention-mask` and will apply
the masking. The function should return a masked score of the
same size [b, np, s, s].
masked-attention-scores = attention_mask_func(
unmasked-attention-scores, attention-mask)
vocab_size: vocabulary size
max_sequence_length: maximum size of sequence. This
is used for positional embedding
embedding_dropout_prob: dropout probability for embeddings
num_tokentypes: size of the token-type embeddings. 0 value
will ignore this embedding
"""
def __init__(self,
transformer_hparams,
attention_mask_func,
vocab_size,
max_sequence_length,
embedding_dropout_prob,
num_tokentypes=0,
add_pooler=False):
super(TransformerLanguageModel, self).__init__()
self.hidden_size = transformer_hparams['hidden_size']
self.num_tokentypes = num_tokentypes
self.init_method = transformer_hparams['init_method']
self.add_pooler = add_pooler
# Embeddings
self.embedding = Embedding(self.hidden_size,
vocab_size,
max_sequence_length,
embedding_dropout_prob,
self.init_method,
self.num_tokentypes)
self._embedding_key = 'embedding'
# Transformer
self.transformer = ParallelTransformer(
transformer_hparams,
attention_mask_func)
self._transformer_key = 'transformer'
# Pooler
if self.add_pooler:
self.pooler = Pooler(self.hidden_size, self.init_method)
self._pooler_key = 'pooler'
def forward(self, input_ids, position_ids, attention_mask,
tokentype_ids=None, layer_past=None, get_key_value=False,
pooling_sequence_index=0):
# Embeddings.
embedding_output = self.embedding(input_ids, position_ids,
tokentype_ids=tokentype_ids)
# Transformer.
transformer_output = self.transformer(embedding_output,
attention_mask,
layer_past=layer_past,
get_key_value=get_key_value)
if self.add_pooler:
pooled_output = self.pooler(transformer_output,
pooling_sequence_index)
return transformer_output, pooled_output
return transformer_output
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
"""For easy load."""
state_dict_ = {}
state_dict_[self._embedding_key] \
= self.embedding.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
state_dict_[self._transformer_key] \
= self.transformer.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
if self.add_pooler:
state_dict_[self._pooler_key] \
= self.pooler.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Customized load."""
# Embedding.
if self._embedding_key in state_dict:
state_dict_ = state_dict[self._embedding_key]
else:
# for backward compatibility.
state_dict_ = {}
for key in state_dict.keys():
if '_embeddings' in key:
state_dict_[key] = state_dict[key]
self.embedding.load_state_dict(state_dict_, strict=strict)
# Transformer.
if self._transformer_key in state_dict:
state_dict_ = state_dict[self._transformer_key]
else:
# for backward compatibility.
state_dict_ = {}
for key in state_dict.keys():
if 'transformer.' in key:
state_dict_[key.split('transformer.')[1]] = state_dict[key]
self.transformer.load_state_dict(state_dict_, strict=strict)
# Pooler.
if self.add_pooler:
assert 'pooler' in state_dict, \
'could not find data for pooler in the checkpoint'
self.pooler.load_state_dict(state_dict[self._pooler_key],
strict=strict)
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformer."""
import math
import torch
from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm
from megatron import mpu
from megatron.module import MegatronModule
""" We use the following notation throughout this file:
h: hidden size
n: number of attention heads
p: number of model parallel partitions
np: n/p
hp: h/p
hn: h/n
b: batch size
s: sequence length
l: number of layers
Transformer takes input of size [b, s, h] and returns a
tensor of the same size. We use the following arguments:
hyperparameters: transformer hyperparameters
attention_mask_func: a function that takes `unmasked-attention-scores`
with size [b, np, s, s] and an `attention-mask` and will apply
the masking. The function should return a masked score of the
same size [b, np, s, s].
masked-attention-scores = attention_mask_func(
unmasked-attention-scores, attention-mask)
"""
class TransformerHyperparameters:
"""Hyperparameters used to build and run the transformer.
Arguments:
hidden_size: hidden size (h)
num_layers: number of layers (l)
num_attention_heads: number of attention heads (n)
attention_dropout_prob: dropout probability for the attention
probabilities
output_dropout_prob: dropout probability for the output
layers (attention output and mlp output)
mlp_activation_func: activation function for the mlp layer
layernorm_epsilon: tolerance parameters used for layer norm
divisions
init_method: init method used for all weights except layer
norm and output weights
output_layer_init_method: init method for output weights (
attention output and mlp output)
checkpoint_activations: flag to use activation checkpointing
checkpoint_num_layers: number of layers used in each chunk of
activation checkpointing
apply_residual_connection_post_layernorm: Take the post layer-norm
values for the residual connection. BERT: True, GPT-2: False
"""
def __init__(self,
hidden_size=None,
num_layers=None,
num_attention_heads=None,
attention_dropout_prob=None,
output_dropout_prob=None,
mlp_activation_func=None,
layernorm_epsilon=None,
init_method=None,
output_layer_init_method=None,
checkpoint_activations=None,
checkpoint_num_layers=None,
apply_residual_connection_post_layernorm=None):
self.params_dict = {}
self.params_dict['hidden_size'] = hidden_size
self.params_dict['num_layers'] = num_layers
self.params_dict['num_attention_heads'] = num_attention_heads
self.params_dict['attention_dropout_prob'] = attention_dropout_prob
self.params_dict['output_dropout_prob'] = output_dropout_prob
self.params_dict['mlp_activation_func'] = mlp_activation_func
self.params_dict['layernorm_epsilon'] = layernorm_epsilon
self.params_dict['init_method'] = init_method
self.params_dict['output_layer_init_method'] = output_layer_init_method
self.params_dict['checkpoint_activations'] = checkpoint_activations
self.params_dict['checkpoint_num_layers'] = checkpoint_num_layers
self.params_dict['apply_residual_connection_post_layernorm'] \
= apply_residual_connection_post_layernorm
def __getitem__(self, key):
"""Custom retrieval with error checks."""
try:
value = self.params_dict[key]
except KeyError:
raise Exception(
'could not find {} in transformer hyperparameters'.format(key))
except Exception as e:
print('unexpected error in transformer hyperparameters:', e)
raise Exception()
else:
assert value is not None, \
'parameter value for {} is not set in transformer '\
'hyperparameters'.format(key)
return value
raise Exception('should not be here')
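# A brief usage sketch with illustrative values: hyperparameters are read back
# through item access, and any parameter left unset (None) raises an assertion
# the first time it is read, which surfaces misconfiguration early.
_example_hparams = TransformerHyperparameters(hidden_size=8,
                                              num_layers=2,
                                              num_attention_heads=2)
assert _example_hparams['hidden_size'] == 8
del _example_hparams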
class ParallelMLP(MegatronModule):
"""MLP.
MLP will take the input with h hidden state, project it to 4*h
hidden dimension, perform nonlinear transformation, and project the
state back into h hidden dimension. At the end, dropout is also
applied.
"""
def __init__(self, hyperparameters):
super(ParallelMLP, self).__init__()
# Project to 4h.
self.dense_h_to_4h = mpu.ColumnParallelLinear(
hyperparameters['hidden_size'],
4*hyperparameters['hidden_size'],
gather_output=False,
init_method=hyperparameters['init_method'])
self.activation_func = hyperparameters['mlp_activation_func']
# Project back to h.
self.dense_4h_to_h = mpu.RowParallelLinear(
4*hyperparameters['hidden_size'],
hyperparameters['hidden_size'],
input_is_parallel=True,
init_method=hyperparameters['output_layer_init_method'])
self.dropout = torch.nn.Dropout(hyperparameters['output_dropout_prob'])
def forward(self, hidden_states):
# [b, s, 4hp]
intermediate_parallel = self.dense_h_to_4h(hidden_states)
intermediate_parallel = self.activation_func(intermediate_parallel)
# [b, s, h]
output = self.dense_4h_to_h(intermediate_parallel)
output = self.dropout(output)
return output
class ParallelSelfAttention(MegatronModule):
"""Parallel self-attention layer abstract class.
Self-attention layer takes input with size [b, s, h]
and returns output of the same size.
"""
def __init__(self, hyperparameters, attention_mask_func):
super(ParallelSelfAttention, self).__init__()
self.attention_mask_func = attention_mask_func
# Per attention head and per partition values.
world_size = mpu.get_model_parallel_world_size()
self.hidden_size_per_partition = mpu.divide(
hyperparameters['hidden_size'], world_size)
self.hidden_size_per_attention_head = mpu.divide(
hyperparameters['hidden_size'],
hyperparameters['num_attention_heads'])
self.num_attention_heads_per_partition = mpu.divide(
hyperparameters['num_attention_heads'], world_size)
# Strided linear layer.
self.query_key_value = mpu.ColumnParallelLinear(
hyperparameters['hidden_size'],
3*hyperparameters['hidden_size'],
stride=3,
gather_output=False,
init_method=hyperparameters['init_method'])
# Dropout. Note that for a single iteration, this layer will generate
# different outputs on different number of parallel partitions but
# on average it should not be partition dependent.
self.attention_dropout = torch.nn.Dropout(
hyperparameters['attention_dropout_prob'])
# Output.
self.dense = mpu.RowParallelLinear(
hyperparameters['hidden_size'],
hyperparameters['hidden_size'],
input_is_parallel=True,
init_method=hyperparameters['output_layer_init_method'])
self.output_dropout = torch.nn.Dropout(
hyperparameters['output_dropout_prob'])
def _transpose_for_scores(self, tensor):
"""Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with
size [b, np, s, hn].
"""
new_tensor_shape = tensor.size()[:-1] + \
(self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head)
tensor = tensor.view(*new_tensor_shape)
return tensor.permute(0, 2, 1, 3)
def _get_query_key_value(self, hidden_states):
"""Get query, key, and value and transpose to
get size [b, np, s, hn].
"""
# Attention heads. [b, s, hp]
mixed_x_layer = self.query_key_value(hidden_states)
(mixed_query_layer,
mixed_key_layer,
mixed_value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)
# Reshape and transpose [b, np, s, hn]
query_layer = self._transpose_for_scores(mixed_query_layer)
key_layer = self._transpose_for_scores(mixed_key_layer)
value_layer = self._transpose_for_scores(mixed_value_layer)
return query_layer, key_layer, value_layer
def _get_unmasked_attention_scores(self, query_layer, key_layer):
"""Unmasked attention scores with size [b, np, s, s]."""
norm_factor = math.sqrt(math.sqrt(self.hidden_size_per_attention_head))
# Raw attention scores. [b, np, s, s]
return torch.matmul(query_layer/norm_factor,
key_layer.transpose(-1, -2)/norm_factor)
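# A quick numeric check of the scaling above: dividing both query and key by
# sqrt(sqrt(hn)) before the matmul is equivalent to dividing the raw scores by
# sqrt(hn), while keeping the intermediate operands smaller (which can help
# keep fp16 values in range).
import math
import torch
_q = torch.randn(2, 4, 8, 16)    # [b, np, s, hn]
_k = torch.randn(2, 4, 8, 16)
_norm = math.sqrt(math.sqrt(16))
_scaled = torch.matmul(_q / _norm, _k.transpose(-1, -2) / _norm)
_reference = torch.matmul(_q, _k.transpose(-1, -2)) / math.sqrt(16)
assert torch.allclose(_scaled, _reference, atol=1e-5)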
def _get_attention_probs(self, attention_scores):
"""Attention probabilies with dropout. The output has
the size [b, np, s, s].
"""
# Attention probabilities. [b, np, s, s]
attention_probs = torch.nn.Softmax(dim=-1)(attention_scores)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
with mpu.get_cuda_rng_tracker().fork():
attention_probs = self.attention_dropout(attention_probs)
return attention_probs
def _get_attended_context(self, attention_probs, value_layer):
"""Final attended tesnor and transposed back to [b, s, hp]."""
# Context layer.
# [b, np, s, hn]
context_layer = torch.matmul(attention_probs, value_layer)
# [b, s, np, hn]
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + \
(self.hidden_size_per_partition,)
# [b, s, hp]
context_layer = context_layer.view(*new_context_layer_shape)
return context_layer
def _get_output(self, context_layer):
"""Output layer with dropout."""
# Output. [b, s, h]
output = self.dense(context_layer)
output = self.output_dropout(output)
return output
def forward(self, hidden_states, attention_mask, layer_past=None,
get_key_value=False):
# hidden_states: [b, s, h]
# Attention heads. [b, np, s, hn]
query_layer, key_layer, value_layer = self._get_query_key_value(
hidden_states)
if layer_past is not None:
past_key, past_value = layer_past
key_layer = torch.cat((past_key.type_as(key_layer),
key_layer), dim=-2)
value_layer = torch.cat((past_value.type_as(value_layer),
value_layer), dim=-2)
if get_key_value:
present = (key_layer, value_layer)
# Raw attention scores. [b, np, s, s]
attention_scores = self._get_unmasked_attention_scores(
query_layer, key_layer)
# Apply attention mask. [b, np, s, s]
if get_key_value:
with torch.no_grad():
if layer_past is not None:
attention_mask = attention_mask[
...,
attention_scores.size(3)-1,
:attention_scores.size(3)].unsqueeze(2)
else:
attention_mask = attention_mask[
...,
:attention_scores.size(3),
:attention_scores.size(3)]
attention_scores = self.attention_mask_func(attention_scores,
attention_mask)
# Attention probabilities. [b, np, s, s]
attention_probs = self._get_attention_probs(attention_scores)
# Context layer. [b, s, hp]
context_layer = self._get_attended_context(attention_probs, value_layer)
# Output. [b, s, h]
output = self._get_output(context_layer)
if get_key_value:
output = [output, present]
return output
class ParallelTransformerLayer(MegatronModule):
"""A single transformer layer.
Transformer layer takes input with size [b, s, h] and returns an
output of the same size.
"""
def __init__(self, hyperparameters, attention_mask_func):
super(ParallelTransformerLayer, self).__init__()
self.apply_residual_connection_post_layernorm \
= hyperparameters['apply_residual_connection_post_layernorm']
# Layernorm on the input data.
self.input_layernorm = LayerNorm(
hyperparameters['hidden_size'],
eps=hyperparameters['layernorm_epsilon'])
# Self attention.
self.attention = ParallelSelfAttention(
hyperparameters,
attention_mask_func)
# Layernorm on the attention output.
self.post_attention_layernorm = LayerNorm(
hyperparameters['hidden_size'],
eps=hyperparameters['layernorm_epsilon'])
# MLP
self.mlp = ParallelMLP(hyperparameters)
def forward(self, hidden_states, attention_mask, layer_past=None,
get_key_value=False):
# hidden_states: [b, s, h]
# Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states)
# Self attention.
attention_output = self.attention(layernorm_output,
attention_mask,
layer_past=layer_past,
get_key_value=get_key_value)
if get_key_value:
attention_output, presents = attention_output
# Residual connection.
if self.apply_residual_connection_post_layernorm:
layernorm_input = layernorm_output + attention_output
else:
layernorm_input = hidden_states + attention_output
# Layer norm post the self attention.
layernorm_output = self.post_attention_layernorm(layernorm_input)
# MLP.
mlp_output = self.mlp(layernorm_output)
# Second residual connection.
if self.apply_residual_connection_post_layernorm:
output = layernorm_output + mlp_output
else:
output = layernorm_input + mlp_output
if get_key_value:
output = [output, presents]
return output
class ParallelTransformer(MegatronModule):
"""Transformer class."""
def __init__(self, hyperparameters, attention_mask_func):
super(ParallelTransformer, self).__init__()
# Store activation checkpointing flag.
self.checkpoint_activations = hyperparameters['checkpoint_activations']
self.checkpoint_num_layers = hyperparameters['checkpoint_num_layers']
def get_layer():
return ParallelTransformerLayer(
hyperparameters,
attention_mask_func)
# Transformer layers.
self.layers = torch.nn.ModuleList(
[get_layer() for _ in range(hyperparameters['num_layers'])])
# Final layer norm before output.
self.final_layernorm = LayerNorm(
hyperparameters['hidden_size'],
eps=hyperparameters['layernorm_epsilon'])
def _checkpointed_forward(self, hidden_states, attention_mask):
"""Forward method with activation checkpointing."""
def custom(start, end):
def custom_forward(*inputs):
layers_ = self.layers[start:end]
x_ = inputs[0]
for layer in layers_:
x_ = layer(x_, inputs[1])
return x_
return custom_forward
l = 0
num_layers = len(self.layers)
while l < num_layers:
hidden_states = mpu.checkpoint(
custom(l, l+self.checkpoint_num_layers),
hidden_states, attention_mask)
l += self.checkpoint_num_layers
return hidden_states
def forward(self, hidden_states, attention_mask, layer_past=None,
get_key_value=False):
# Checks
if layer_past is not None:
assert get_key_value, \
'for not None values in layer_past, ' \
'expected get_key_value to be set'
if get_key_value:
assert not self.checkpoint_activations, \
'get_key_value does not work with ' \
'activation checkpointing'
if self.checkpoint_activations:
hidden_states = self._checkpointed_forward(hidden_states,
attention_mask)
else:
if get_key_value:
presents = []
for i, layer in enumerate(self.layers):
past = None
if layer_past is not None:
past = layer_past[i]
hidden_states = layer(hidden_states,
attention_mask,
layer_past=past,
get_key_value=get_key_value)
if get_key_value:
hidden_states, present = hidden_states
presents.append(present)
# Final layer norm.
output = self.final_layernorm(hidden_states)
if get_key_value:
output = [output, presents]
return output
......@@ -13,21 +13,59 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for wrapping BertModel."""
"""Utilities for models."""
import math
import torch
from .modeling import BertConfig
from .modeling import BertForPreTraining, BertForMaskedLM
from .modeling import BertLayerNorm
from .transformer import LayerNorm
def get_params_for_weight_decay_optimization(module):
def init_method_normal(sigma):
"""Init method based on N(0, sigma)."""
def init_(tensor):
return torch.nn.init.normal_(tensor, mean=0.0, std=sigma)
return init_
def scaled_init_method_normal(sigma, num_layers):
"""Init method based on N(0, sigma/sqrt(2*num_layers)."""
std = sigma / math.sqrt(2.0 * num_layers)
def init_(tensor):
return torch.nn.init.normal_(tensor, mean=0.0, std=std)
return init_
def get_linear_layer(rows, columns, init_method):
"""Simple linear layer with weight initialization."""
layer = torch.nn.Linear(rows, columns)
init_method(layer.weight)
with torch.no_grad():
layer.bias.zero_()
return layer
@torch.jit.script
def gelu_impl(x):
"""OpenAI's gelu implementation."""
return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x *
(1.0 + 0.044715 * x * x)))
def gelu(x):
return gelu_impl(x)
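# A small check that the tanh approximation above stays close to the exact,
# erf-based GELU (0.7978845608... is sqrt(2/pi)); the maximum absolute error
# is on the order of 1e-4.
import math
import torch
_x = torch.linspace(-5.0, 5.0, steps=101)
_exact = 0.5 * _x * (1.0 + torch.erf(_x / math.sqrt(2.0)))
assert torch.allclose(gelu(_x), _exact, atol=1e-3)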
def get_params_for_weight_decay_optimization(module):
"""Divide params into with-weight-decay and without-weight-decay groups.
Layernorms and biases will have no weight decay but the rest will.
"""
weight_decay_params = {'params': []}
no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
for module_ in module.modules():
if isinstance(module_, (BertLayerNorm, torch.nn.LayerNorm)):
if isinstance(module_, LayerNorm):
no_weight_decay_params['params'].extend(
[p for p in list(module_._parameters.values())
if p is not None])
......@@ -40,51 +78,3 @@ def get_params_for_weight_decay_optimization(module):
if p is not None and n == 'bias'])
return weight_decay_params, no_weight_decay_params
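# A minimal sketch of how the two groups are consumed, using plain torch
# modules and torch.optim.Adam (the actual training code uses apex FusedAdam):
# weights fall into the first group, while biases, and the parameters of any
# module recognized as LayerNorm above, fall into the weight_decay=0.0 group,
# which overrides the optimizer-level default.
import torch
_model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 2))
_decay, _no_decay = get_params_for_weight_decay_optimization(_model)
_optimizer = torch.optim.Adam([_decay, _no_decay], lr=1.0e-4, weight_decay=0.01)
assert all(p.dim() == 1 for p in _no_decay['params'])   # biases only here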
class BertModel(torch.nn.Module):
def __init__(self, args):
super(BertModel, self).__init__()
if args.pretrained_bert:
self.model = BertForPreTraining.from_pretrained(
args.tokenizer_model_type,
cache_dir=args.cache_dir,
fp32_layernorm=args.fp32_layernorm,
fp32_embedding=args.fp32_embedding,
layernorm_epsilon=args.layernorm_epsilon)
else:
if args.intermediate_size is None:
intermediate_size = 4 * args.hidden_size
else:
intermediate_size = args.intermediate_size
self.config = BertConfig(
args.tokenizer_num_tokens,
hidden_size=args.hidden_size,
num_hidden_layers=args.num_layers,
num_attention_heads=args.num_attention_heads,
intermediate_size=intermediate_size,
hidden_dropout_prob=args.hidden_dropout,
attention_probs_dropout_prob=args.attention_dropout,
max_position_embeddings=args.max_position_embeddings,
type_vocab_size=args.tokenizer_num_type_tokens,
fp32_layernorm=args.fp32_layernorm,
fp32_embedding=args.fp32_embedding,
fp32_tokentypes=args.fp32_tokentypes,
layernorm_epsilon=args.layernorm_epsilon,
deep_init=args.deep_init)
self.model = BertForPreTraining(self.config)
def forward(self, input_tokens, token_type_ids=None,
attention_mask=None, checkpoint_activations=False):
return self.model(
input_tokens, token_type_ids, attention_mask,
checkpoint_activations=checkpoint_activations)
def state_dict(self, destination=None, prefix='', keep_vars=False):
return self.model.state_dict(destination=destination, prefix=prefix,
keep_vars=keep_vars)
def load_state_dict(self, state_dict, strict=True):
return self.model.load_state_dict(state_dict, strict=strict)
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Megatron Module"""
import torch
class MegatronModule(torch.nn.Module):
"""Megatron specific extentions of torch Module."""
def __init__(self):
super(MegatronModule, self).__init__()
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
"""Use this function to override the state dict for
saving checkpoints."""
return self.state_dict(destination, prefix, keep_vars)
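# A hedged usage sketch, assuming `megatron.module` is importable as in the
# rest of this commit: by default the checkpoint state dict is just
# state_dict(), but subclasses (e.g. GPT2Model, BertModel) override the hook
# to nest their submodules under custom keys.
import torch
from megatron.module import MegatronModule
class _TinyModel(MegatronModule):
    def __init__(self):
        super(_TinyModel, self).__init__()
        self.linear = torch.nn.Linear(4, 4)
    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        return {'tiny': self.linear.state_dict(destination, prefix, keep_vars)}
assert 'tiny' in _TinyModel().state_dict_for_save_checkpoint()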
......@@ -46,7 +46,5 @@ from .random import checkpoint
from .random import get_cuda_rng_tracker
from .random import model_parallel_cuda_manual_seed
from .transformer import BertParallelSelfAttention
from .transformer import BertParallelTransformerLayer
from .transformer import GPT2ParallelTransformer
from .transformer import LayerNorm
from .utils import divide
from .utils import split_tensor_along_last_dim
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pretrain utilities"""
from datetime import datetime
import math
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from apex.optimizers import FusedAdam as Adam
from arguments import get_args
from megatron import mpu
from megatron.fp16 import FP16_Module
from megatron.fp16 import FP16_Optimizer
from megatron.learning_rates import AnnealingLR
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import get_params_for_weight_decay_optimization
from megatron.utils import check_adlr_autoresume_termination
from megatron.utils import enable_adlr_autoresume
from megatron.utils import get_tensorboard_writer
from megatron.utils import initialize_distributed
from megatron.utils import load_checkpoint
from megatron.utils import print_args
from megatron.utils import print_rank_0
from megatron.utils import report_memory
from megatron.utils import save_checkpoint
from megatron.utils import set_random_seed
from megatron.utils import Timers
def run(top_level_message, train_val_test_data_provider,
model_provider, forward_step_func):
"""Main training program.
This function will run the following in the order provided:
1) get input arguments.
2) initialize distributed and seeds.
3) call train_val_test_data_provider to get train/val/test datasets.
4) setup model, optimizer and lr schedule using the model_provider.
5) train the model using the forward_step_func.
Arguments:
top_level_message: a message to print at the top of the run.
train_val_test_data_provider: a function that takes `args` as input
and returns `train, val, test` dataloaders. Note that args are
passed and can be modified in case we need to use some parameters
later. For example, we can set vocab size using
args.vocab_size = ...
and later use this value in `model_provider`.
model_provider: a function that takes `args` and returns a vanilla
version of the model. By vanilla we mean a simple model on cpu
with no fp16 or ddp.
forward_step_func: a function that takes a `data iterator`, `model`,
`args`, and `timers` and returns a `loss` scalar with a dictionary
with key:values being the info we would like to monitor during
training, for example `lm-loss: value`. We also require that this
function add `batch generator` to the timers class.
"""
# Timer.
timers = Timers()
# Arguments.
args = get_args()
# Tensorboard writer
writer = get_tensorboard_writer(args)
# Pytorch distributed.
initialize_distributed(args)
if torch.distributed.get_rank() == 0:
print(top_level_message, flush=True)
print_args(args, writer)
# Autoresume.
torch.distributed.barrier()
if args.adlr_autoresume:
enable_adlr_autoresume(args)
# Random seeds for reproducibility.
set_random_seed(args.seed)
# Data stuff.
train_data, val_data, test_data = train_val_test_data_provider(args)
# Model, optimizer, and learning rate.
model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider,
args)
# Train, validation, and test data.
train_data_iterator, val_data_iterator, \
test_data_iterator = get_train_val_test_data_iterators(train_data,
val_data,
test_data,
args)
iteration = 0
if args.train_iters > 0:
if args.do_train:
iteration, _ = train(forward_step_func, model,
optimizer, lr_scheduler,
train_data_iterator, val_data_iterator,
timers, args, writer)
if args.do_valid:
prefix = 'the end of training for val data'
evaluate_and_print_results(prefix, forward_step_func,
val_data_iterator, model,
args, writer, iteration,
timers, False)
if args.save and iteration != 0:
save_checkpoint(iteration, model, optimizer,
lr_scheduler, args)
if args.do_test:
# Run on test data.
prefix = 'the end of training for test data'
evaluate_and_print_results(prefix, forward_step_func,
test_data_iterator, model,
args, None, 0, timers, True)
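# A hedged sketch of the three callables run() expects; the names and bodies
# below are illustrative placeholders that follow the docstring contract
# above, and are not functions defined elsewhere in Megatron.
def _example_train_val_test_data_provider(args):
    # Build and return (train, val, test) data iterables; fields such as
    # args.vocab_size may also be set here for later use by the model provider.
    return None, None, None
def _example_model_provider(args):
    # Return a plain CPU model; fp16 conversion and DDP wrapping are handled
    # later by get_model().
    return torch.nn.Linear(8, 8)
def _example_forward_step(data_iterator, model, args, timers):
    # Time batch generation, run the model, and return (loss, monitored-info).
    timers('batch generator').start()
    batch = next(data_iterator)
    timers('batch generator').stop()
    loss = model(batch).sum()
    return loss, {'lm loss': loss}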
def get_model(model_provider_func, args):
"""Build the model."""
# Build model on cpu.
model = model_provider_func(args)
# Print number of parameters.
if mpu.get_data_parallel_rank() == 0:
print(' > number of parameters on model parallel rank {}: {}'.format(
mpu.get_model_parallel_rank(),
sum([p.nelement() for p in model.parameters()])), flush=True)
# GPU allocation.
model.cuda(torch.cuda.current_device())
# Fp16 conversion.
if args.fp16:
model = FP16_Module(model)
# Wrap model for distributed training.
if args.DDP_impl == 'torch':
i = torch.cuda.current_device()
args.DDP_type = torchDDP
model = args.DDP_type(model, device_ids=[i], output_device=i,
process_group=mpu.get_data_parallel_group())
return model
if args.DDP_impl == 'local':
args.DDP_type = LocalDDP
model = args.DDP_type(model)
return model
print_rank_0('Unknown DDP implementation specified: {}. '
'Exiting.'.format(args.DDP_impl))
exit()
return model
def get_optimizer(model, args):
"""Set up the optimizer."""
# Build parameter groups (weight decay and non-decay).
while isinstance(model, (args.DDP_type, FP16_Module)):
model = model.module
param_groups = get_params_for_weight_decay_optimization(model)
# Add model parallel attribute if it is not set.
for param_group in param_groups:
for param in param_group['params']:
if not hasattr(param, 'model_parallel'):
param.model_parallel = False
# Use Adam.
optimizer = Adam(param_groups,
lr=args.lr, weight_decay=args.weight_decay)
# Wrap into fp16 optimizer.
if args.fp16:
optimizer = FP16_Optimizer(optimizer,
static_loss_scale=args.loss_scale,
dynamic_loss_scale=args.dynamic_loss_scale,
dynamic_loss_args={
'scale_window': args.loss_scale_window,
'min_scale':args.min_scale,
'delayed_shift': args.hysteresis})
return optimizer
def get_learning_rate_scheduler(optimizer, args):
"""Build the learning rate scheduler."""
    # Determine the number of iterations to decay over.
if args.lr_decay_iters is not None:
num_iters = args.lr_decay_iters
else:
num_iters = args.train_iters
num_iters = max(1, num_iters)
init_step = -1
warmup_iter = args.warmup * num_iters
lr_scheduler = AnnealingLR(
optimizer,
start_lr=args.lr,
warmup_iter=warmup_iter,
num_iters=num_iters,
decay_style=args.lr_decay_style,
last_iter=init_step,
min_lr=args.min_lr,
use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler,
override_lr_scheduler=args.override_lr_scheduler)
return lr_scheduler
def setup_model_and_optimizer(model_provider_func, args):
"""Setup model and optimizer."""
model = get_model(model_provider_func, args)
optimizer = get_optimizer(model, args)
lr_scheduler = get_learning_rate_scheduler(optimizer, args)
if args.load is not None:
args.iteration = load_checkpoint(model, optimizer, lr_scheduler, args)
else:
args.iteration = 0
return model, optimizer, lr_scheduler
def backward_step(optimizer, model, loss, args, timers):
"""Backward step."""
# Backward pass.
optimizer.zero_grad()
if args.fp16:
optimizer.backward(loss, update_master_grads=False)
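        # Defer copying grads to the fp32 master copy until after the
        # all-reduce below.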
else:
loss.backward()
# All-reduce if needed.
if args.DDP_impl == 'local':
timers('allreduce').start()
model.allreduce_params(reduce_after=False,
fp32_allreduce=args.fp32_allreduce)
timers('allreduce').stop()
# Update master gradients.
if args.fp16:
optimizer.update_master_grads()
    # Clip gradients to help prevent them from exploding.
if args.clip_grad > 0:
if not args.fp16:
mpu.clip_grad_norm(model.parameters(), args.clip_grad)
else:
optimizer.clip_master_grads(args.clip_grad)
def train_step(forward_step_func, data_iterator, model, optimizer, lr_scheduler,
args, timers):
"""Single training step."""
# Forward model for one step.
timers('forward').start()
loss, loss_reduced = forward_step_func(data_iterator, model, args, timers)
timers('forward').stop()
# Calculate gradients, reduce across processes, and clip.
timers('backward').start()
backward_step(optimizer, model, loss, args, timers)
timers('backward').stop()
# Update parameters.
timers('optimizer').start()
optimizer.step()
timers('optimizer').stop()
# Update learning rate.
skipped_iter = 0
if not (args.fp16 and optimizer.overflow):
lr_scheduler.step()
else:
skipped_iter = 1
return loss_reduced, skipped_iter
def train(forward_step_func, model, optimizer, lr_scheduler,
train_data_iterator, val_data_iterator, timers, args, writer):
"""Train the model function."""
# Turn on training mode which enables dropout.
model.train()
# Tracking loss.
total_loss_dict = {}
# Iterations.
iteration = args.iteration
skipped_iters = 0
timers('interval time').start()
report_memory_flag = True
while iteration < args.train_iters:
loss_dict, skipped_iter = train_step(forward_step_func,
train_data_iterator,
model,
optimizer,
lr_scheduler,
args, timers)
skipped_iters += skipped_iter
iteration += 1
# Update losses.
for key in loss_dict:
total_loss_dict[key] = total_loss_dict.get(key, 0.) + loss_dict[key]
# Logging.
if args.DDP_impl == 'torch':
timers_to_log = ['forward', 'backward', 'optimizer',
'batch generator']
else:
timers_to_log = ['forward', 'backward', 'allreduce', 'optimizer',
'batch generator']
learning_rate = optimizer.param_groups[0]['lr']
if writer and torch.distributed.get_rank() == 0:
writer.add_scalar('learning_rate', learning_rate, iteration)
for key in total_loss_dict:
writer.add_scalar(key, total_loss_dict[key], iteration)
if args.fp16:
writer.add_scalar('loss_scale', optimizer.loss_scale, iteration)
normalizer = iteration % args.log_interval
if normalizer == 0:
normalizer = args.log_interval
timers.write(timers_to_log, writer, iteration,
normalizer=normalizer)
if iteration % args.log_interval == 0:
elapsed_time = timers('interval time').elapsed()
if writer and torch.distributed.get_rank() == 0:
writer.add_scalar('iteration_time',
elapsed_time / args.log_interval, iteration)
log_string = ' iteration {:8d}/{:8d} |'.format(iteration,
args.train_iters)
log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
elapsed_time * 1000.0 / args.log_interval)
log_string += ' learning rate: {:.3E} |'.format(learning_rate)
for key in total_loss_dict:
avg = total_loss_dict[key].item() / args.log_interval
log_string += ' {}: {:.6E} |'.format(key, avg)
total_loss_dict[key] = 0.0
if args.fp16:
log_string += ' loss scale: {:.1f} |'.format(
optimizer.loss_scale)
print_rank_0(log_string)
if report_memory_flag:
report_memory('after {} iterations'.format(iteration))
report_memory_flag = False
timers.log(timers_to_log, normalizer=args.log_interval)
        # Autoresume.
if (iteration % args.adlr_autoresume_interval == 0) and \
args.adlr_autoresume:
check_adlr_autoresume_termination(iteration, model, optimizer,
lr_scheduler, args)
        # Checkpointing.
if args.save and args.save_interval and \
iteration % args.save_interval == 0:
save_checkpoint(iteration, model, optimizer, lr_scheduler, args)
        # Evaluation.
if args.eval_interval and iteration % args.eval_interval == 0 and \
args.do_valid:
prefix = 'iteration {}'.format(iteration)
evaluate_and_print_results(prefix, forward_step_func,
val_data_iterator, model, args,
writer, iteration, timers, False)
if args.exit_interval and iteration % args.exit_interval == 0:
torch.distributed.barrier()
time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
rank = torch.distributed.get_rank()
print('rank: {} | time: {} | exiting the program at iteration {}'.
format(rank, time_str, iteration), flush=True)
exit()
return iteration, skipped_iters
def evaluate(forward_step_func, data_iterator, model,
args, timers, verbose=False):
"""Evaluation."""
# Turn on evaluation mode which disables dropout.
model.eval()
total_loss_dict = {}
with torch.no_grad():
iteration = 0
while iteration < args.eval_iters:
iteration += 1
if verbose and iteration % args.log_interval == 0:
print_rank_0('Evaluating iter {}/{}'.format(iteration,
args.eval_iters))
# Forward evaluation.
_, loss_dict = forward_step_func(data_iterator, model,
args, timers)
            # Accumulate the (already reduced) losses.
for key in loss_dict:
total_loss_dict[key] = total_loss_dict.get(key, 0.) + \
loss_dict[key]
    # Move the model back to train mode.
model.train()
for key in total_loss_dict:
total_loss_dict[key] /= args.eval_iters
return total_loss_dict
def evaluate_and_print_results(prefix, forward_step_func,
data_iterator, model,
args, writer, iteration,
timers, verbose=False):
"""Helper function to evaluate and dump results on screen."""
total_loss_dict = evaluate(forward_step_func, data_iterator, model,
args, timers, verbose)
string = ' validation loss at {} | '.format(prefix)
for key in total_loss_dict:
string += '{} value: {:.6E} | '.format(key, total_loss_dict[key].item())
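        # Clamp the exponent at 20 so a diverging loss cannot overflow the
        # reported perplexity.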
ppl = math.exp(min(20, total_loss_dict[key].item()))
string += '{} PPL: {:.6E} | '.format(key, ppl)
if writer and torch.distributed.get_rank() == 0:
writer.add_scalar('{} value'.format(key),
total_loss_dict[key].item(),
iteration)
writer.add_scalar('{} ppl'.format(key), ppl, iteration)
length = len(string) + 1
print_rank_0('-' * length)
print_rank_0(string)
print_rank_0('-' * length)
def get_train_val_test_data_iterators(train_data, val_data, test_data, args):
"""Build train/validation/test iterators"""
# If resume is on, shift the start iterations.
if args.resume_dataloader:
if train_data is not None:
train_data.batch_sampler.start_iter = args.iteration % \
len(train_data)
print_rank_0('setting training data start iteration to {}'.
format(train_data.batch_sampler.start_iter))
if val_data is not None:
start_iter_val = (args.iteration // args.eval_interval) * \
args.eval_iters
val_data.batch_sampler.start_iter = start_iter_val % \
len(val_data)
print_rank_0('setting validation data start iteration to {}'.
format(val_data.batch_sampler.start_iter))
if train_data is not None:
train_data_iterator = iter(train_data)
else:
train_data_iterator = None
if val_data is not None:
val_data_iterator = iter(val_data)
else:
val_data_iterator = None
if test_data is not None:
test_data_iterator = iter(test_data)
else:
test_data_iterator = None
return train_data_iterator, val_data_iterator, test_data_iterator
......@@ -20,12 +20,29 @@ import random
import time
import numpy as np
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.fp16 import FP16_Optimizer
from apex.optimizers import FusedAdam as Adam
from megatron import mpu
from megatron import model
from megatron.fp16 import FP16_Module
from megatron.fp16 import FP16_Optimizer
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import get_params_for_weight_decay_optimization
def get_tensorboard_writer(args):
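    """Return a TensorBoard SummaryWriter on rank 0 if a log dir is set, else None."""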
writer = None
if args.tensorboard_dir and args.rank == 0:
try:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter(log_dir=args.tensorboard_dir)
except ModuleNotFoundError:
print_rank_0('WARNING: TensorBoard writing requested but is not '
'available (are you using PyTorch 1.1.0 or later?), '
'no TensorBoard logs will be written.')
writer = None
return writer
def print_rank_0(message):
......@@ -39,18 +56,18 @@ def print_rank_0(message):
def enable_adlr_autoresume(args):
print_rank_0('enabling autoresume ...')
import sys
sys.path.append(os.environ.get('SUBMIT_SCRIPTS','.'))
sys.path.append(os.environ.get('SUBMIT_SCRIPTS', '.'))
try:
from userlib.auto_resume import AutoResume
except:
print_rank_0('ADLR autoresume is not available, exiting ...')
exit(0)
exit()
args.AutoResume = AutoResume
args.AutoResume.init()
def check_adlr_autoresume_termination(iteration, model, optimizer,
lr_scheduler, args):
lr_scheduler, args):
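    """Check for a termination request from ADLR autoresume and act on it."""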
    # Add barrier to ensure consistency.
torch.distributed.barrier()
if args.AutoResume.termination_requested():
......@@ -74,6 +91,7 @@ def print_args(args, writer=None):
if writer:
writer.add_text(arg, str(getattr(args, arg)))
def print_params_min_max_norm(optimizer, iteration):
"""Print min, max, and norm of all parameters."""
index = 0
......@@ -220,24 +238,6 @@ def initialize_distributed(args):
mpu.initialize_model_parallel(args.model_parallel_size)
def wrap_model_for_distributed_training(model, args):
"""Wrap model for distributed training."""
if args.DDP_impl == 'torch':
i = torch.cuda.current_device()
args.DDP_type = torchDDP
model = args.DDP_type(model, device_ids=[i], output_device=i,
process_group=mpu.get_data_parallel_group())
return model
elif args.DDP_impl == 'local':
args.DDP_type = LocalDDP
model = args.DDP_type(model)
return model
else:
print_rank_0('Unknown DDP implementation specified: {}. '
'Exiting.'.format(args.DDP_impl))
exit()
def set_random_seed(seed):
"""Set random seed for reproducability."""
......@@ -284,7 +284,7 @@ def save_checkpoint(iteration, model, optimizer,
sd = {}
sd['iteration'] = iteration
sd['model'] = model.state_dict()
sd['model'] = model.state_dict_for_save_checkpoint()
# Optimizer stuff.
if not args.no_save_optim:
......@@ -378,7 +378,6 @@ def load_checkpoint(model, optimizer, lr_scheduler, args):
            print_rank_0('A metadata file exists but unable to load iteration'
' from checkpoint {}, exiting'.format(checkpoint_name))
exit()
# Model.
try:
model.load_state_dict(sd['model'])
......@@ -410,7 +409,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, args):
torch.cuda.set_rng_state(sd['cuda_rng_state'])
mpu.get_cuda_rng_tracker().set_states(sd['rng_tracker_states'])
except KeyError:
print_rank_0('Unable to load optimizer from checkpoint {}, exiting. '
print_rank_0('Unable to load optimizer from checkpoint {}, exiting.'
'Specify --no-load-optim or --finetune to prevent '
'attempting to load the optimizer '
'state.'.format(checkpoint_name))
......@@ -422,6 +421,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, args):
return iteration
def load_weights(src, dst, dst2src=False):
"""
    Loads weights from src to dst via an in-place copy.
......