Commit 73af1290 authored by Mohammad Shoeybi, committed by Raul Puri

Major refactoring, combining gpt2 and bert

parent ebbe40cd
# ===========
# base images
# ===========
FROM nvcr.io/nvidia/pytorch:19.05-py3
FROM nvcr.io/nvidia/pytorch:19.09-py3
# ===============
...@@ -27,12 +27,3 @@ RUN pip install --upgrade pip && \
COPY requirements.txt /tmp/
RUN pip install --upgrade --ignore-installed -r /tmp/requirements.txt
# ===========
# latest apex
# ===========
RUN pip uninstall -y apex && \
git clone https://github.com/NVIDIA/apex.git ~/apex && \
cd ~/apex && \
pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" .
Note that, as of now, you need to have PySOL cloned into this directory before building the container.
...@@ -29,7 +29,6 @@ from megatron.fp16 import FP16_Module
from megatron.fp16 import FP16_Optimizer
from megatron.learning_rates import AnnealingLR
from megatron.model import GPT2Model
from megatron.model import gpt2_get_params_for_weight_decay_optimization
from megatron.model import DistributedDataParallel as DDP
from megatron import mpu
from apex.optimizers import FusedAdam as Adam
...
...@@ -26,9 +26,8 @@ import argparse
import time
from arguments import get_args
from megatron.utils import Timers
from pretrain_gpt2 import initialize_distributed
from megatron.utils import initialize_distributed
from pretrain_gpt2 import set_random_seed
from megatron.utils import set_random_seed
from pretrain_gpt2 import get_train_val_test_data
from pretrain_gpt2 import get_masks_and_position_ids
from megatron.utils import load_checkpoint
from megatron.data_utils import make_tokenizer
...@@ -96,7 +95,8 @@ def get_batch(context_tokens, args):
tokens,
args.eod_token,
args.reset_position_ids,
args.reset_attention_mask)
args.reset_attention_mask,
False)
return tokens, attention_mask, position_ids
...@@ -361,7 +361,7 @@ def switch(val1, val2, boolean):
boolean = boolean.type_as(val1)
return (1-boolean)*val1 + boolean*val2
def sample_sequence_batch(model, context_tokens, context_lengths, attention_mask, position_ids, tokenizer, args, maxlen=None):
def sample_sequence_batch(model, context_tokens, context_lengths, attention_mask, position_ids, tokenizer, args, maxlen=None, type_ids=None):
model.eval()
with torch.no_grad():
context_length = context_lengths.min().item()
...@@ -384,16 +384,21 @@ def sample_sequence_batch(model, context_tokens, context_lengths, attention_mask
while context_length <= (maxlen):
if args.recompute:
logits = model(tokens, position_ids, attention_mask)
logits = model(tokens, position_ids, attention_mask, tokentype_ids=type_ids)
logits = logits[:, context_length - 1, :]
else:
types2use = None
if counter == 0:
tokens2use = tokens[:, :context_length]
positions2use = position_ids[:, :context_length]
if type_ids is not None:
types2use = type_ids[:, :context_length]
else:
tokens2use = tokens[:, context_length - 1].view(batch_size, -1)
positions2use = position_ids[:, context_length - 1].view(batch_size, -1)
logits, layer_past = model(tokens2use, positions2use, attention_mask, layer_past=layer_past, get_present=True)
if type_ids is not None:
types2use = type_ids[:, context_length - 1].view(batch_size, -1)
logits, layer_past = model(tokens2use, positions2use, attention_mask, layer_past=layer_past, get_key_value=True, tokentype_ids=types2use)
logits = logits[:, -1].view(batch_size,-1).contiguous()
if args.greedy:
...
...@@ -22,6 +22,9 @@ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from .loss_scaler import DynamicLossScaler, LossScaler
from .fp16util import model_grads_to_master_grads, master_params_to_model_params, clip_grad_norm
from megatron.module import MegatronModule
FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)
...@@ -56,7 +59,7 @@ def fp16_to_fp32(val):
return val
return conversion_helper(val, float_conversion)
class FP16_Module(nn.Module):
class FP16_Module(MegatronModule):
def __init__(self, module):
super(FP16_Module, self).__init__()
self.add_module('module', module.half())
...@@ -67,6 +70,11 @@ class FP16_Module(nn.Module):
def state_dict(self, destination=None, prefix='', keep_vars=False):
return self.module.state_dict(destination, prefix, keep_vars)
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
return self.module.state_dict_for_save_checkpoint(destination, prefix,
keep_vars)
def load_state_dict(self, state_dict, strict=True):
self.module.load_state_dict(state_dict, strict=strict)
...
...@@ -14,7 +14,6 @@
# limitations under the License.
from .distributed import *
from .gpt2_modeling import gpt2_get_params_for_weight_decay_optimization
from .bert_model import BertModel
from .gpt2_modeling import GPT2Model
from .gpt2_model import GPT2Model
from .model import BertModel
from .utils import get_params_for_weight_decay_optimization
from .model import get_params_for_weight_decay_optimization
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BERT model."""
import torch
from megatron.module import MegatronModule
from .language_model import parallel_lm_logits
from .language_model import get_language_model
from .transformer import LayerNorm
from .utils import gelu
from .utils import get_linear_layer
from .utils import init_method_normal
from .utils import scaled_init_method_normal
def bert_attention_mask_func(attention_scores, attention_mask):
attention_scores = attention_scores + attention_mask
return attention_scores
def bert_extended_attention_mask(attention_mask, dtype):
# We create a 3D attention mask from a 2D tensor mask.
# [b, 1, s]
attention_mask_b1s = attention_mask.unsqueeze(1)
# [b, s, 1]
attention_mask_bs1 = attention_mask.unsqueeze(2)
# [b, s, s]
attention_mask_bss = attention_mask_b1s * attention_mask_bs1
# [b, 1, s, s]
extended_attention_mask = attention_mask_bss.unsqueeze(1)
# Since attention_mask is 1.0 for positions we want to attend and 0.0
# for masked positions, this operation will create a tensor which is
# 0.0 for positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
# fp16 compatibility
extended_attention_mask = extended_attention_mask.to(dtype=dtype)
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
return extended_attention_mask
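A brief usage sketch, not part of the commit, to make the shapes concrete; it assumes the helper above is importable from this new bert_model.py:
import torch
mask = torch.tensor([[1, 1, 1, 0]])                  # [b, s], last position is padding
extended = bert_extended_attention_mask(mask, torch.float32)
assert extended.shape == (1, 1, 4, 4)
# every query sees -10000.0 for the padded key, 0.0 for attended pairs
assert torch.all(extended[0, 0, :, 3] == -10000.0)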
def bert_position_ids(token_ids):
# Create position ids
seq_length = token_ids.size(1)
position_ids = torch.arange(seq_length, dtype=torch.long,
device=token_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(token_ids)
return position_ids
class BertLMHead(MegatronModule):
"""Masked LM head for Bert
Arguments:
mpu_vocab_size: model parallel size of vocabulary.
hidden_size: hidden size
init_method: init method for weight initialization
layernorm_epsilon: tolerance for layer norm divisions
parallel_output: whether the output logits are kept model parallel (distributed) or gathered.
"""
def __init__(self, mpu_vocab_size, hidden_size, init_method,
layernorm_epsilon, parallel_output):
super(BertLMHead, self).__init__()
self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
self.bias.model_parallel = True
self.parallel_output = parallel_output
self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
self.layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
def forward(self, hidden_states, word_embeddings_weight):
hidden_states = self.dense(hidden_states)
hidden_states = gelu(hidden_states)
hidden_states = self.layernorm(hidden_states)
output = parallel_lm_logits(hidden_states,
word_embeddings_weight,
self.parallel_output,
bias=self.bias)
return output
class BertModel(MegatronModule):
"""Bert Language model."""
def __init__(self,
num_layers,
vocab_size,
hidden_size,
num_attention_heads,
embedding_dropout_prob,
attention_dropout_prob,
output_dropout_prob,
max_sequence_length,
checkpoint_activations,
checkpoint_num_layers=1,
add_binary_head=False,
layernorm_epsilon=1.0e-5,
init_method_std=0.02,
num_tokentypes=0,
parallel_output=True):
super(BertModel, self).__init__()
self.add_binary_head = add_binary_head
self.parallel_output = parallel_output
init_method = init_method_normal(init_method_std)
self.language_model, self._language_model_key = get_language_model(
num_layers=num_layers,
vocab_size=vocab_size,
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
embedding_dropout_prob=embedding_dropout_prob,
attention_dropout_prob=attention_dropout_prob,
output_dropout_prob=output_dropout_prob,
max_sequence_length=max_sequence_length,
num_tokentypes=num_tokentypes,
add_pooler=self.add_binary_head,
attention_mask_func=bert_attention_mask_func,
checkpoint_activations=checkpoint_activations,
checkpoint_num_layers=checkpoint_num_layers,
layernorm_epsilon=layernorm_epsilon,
init_method=init_method,
scaled_init_method=scaled_init_method_normal(init_method_std,
num_layers),
residual_connection_post_layernorm=True)
self.lm_head = BertLMHead(
self.language_model.embedding.word_embeddings.weight.size(0),
hidden_size, init_method, layernorm_epsilon, parallel_output)
self._lm_head_key = 'lm_head'
if self.add_binary_head:
self.binary_head = get_linear_layer(hidden_size, 2, init_method)
self._binary_head_key = 'binary_head'
def forward(self, input_ids, attention_mask,
tokentype_ids=None):
extended_attention_mask = bert_extended_attention_mask(
attention_mask, next(self.language_model.parameters()).dtype)
position_ids = bert_position_ids(input_ids)
if self.add_binary_head:
lm_output, pooled_output = self.language_model(
input_ids,
position_ids,
extended_attention_mask,
tokentype_ids=tokentype_ids)
else:
lm_output = self.language_model(
input_ids,
position_ids,
extended_attention_mask,
tokentype_ids=tokentype_ids)
# Output.
lm_logits = self.lm_head(
lm_output, self.language_model.embedding.word_embeddings.weight)
if self.add_binary_head:
binary_logits = self.binary_head(pooled_output)
return lm_logits, binary_logits
return lm_logits, None
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
"""For easy load when model is combined with other heads,
add an extra key."""
state_dict_ = {}
state_dict_[self._language_model_key] \
= self.language_model.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
state_dict_[self._lm_head_key] \
= self.lm_head.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
if self.add_binary_head:
state_dict_[self._binary_head_key] \
= self.binary_head.state_dict(destination, prefix, keep_vars)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Customized load."""
self.language_model.load_state_dict(
state_dict[self._language_model_key], strict=strict)
self.lm_head.load_state_dict(state_dict[self._lm_head_key],
strict=strict)
if self.add_binary_head:
self.binary_head.load_state_dict(state_dict[self._binary_head_key],
strict=strict)
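A minimal usage sketch (not from the commit) for the class above; it assumes mpu model parallelism has already been initialized, and every hyperparameter value below is made up for illustration:
import torch

bert = BertModel(num_layers=2, vocab_size=30528, hidden_size=128,
                 num_attention_heads=4, embedding_dropout_prob=0.1,
                 attention_dropout_prob=0.1, output_dropout_prob=0.1,
                 max_sequence_length=128, checkpoint_activations=False,
                 add_binary_head=True, num_tokentypes=2)
input_ids = torch.randint(0, 30528, (2, 16))          # [b, s]
attention_mask = torch.ones(2, 16, dtype=torch.long)  # 1 = attend
tokentype_ids = torch.zeros(2, 16, dtype=torch.long)  # all segment A
lm_logits, binary_logits = bert(input_ids, attention_mask,
                                tokentype_ids=tokentype_ids)
# lm_logits cover the local vocabulary partition (parallel_output=True);
# binary_logits is the [b, 2] output of the binary (next-sentence) head.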
...@@ -20,8 +20,10 @@ from torch.nn.modules import Module
from torch.autograd import Variable
from megatron import mpu
from megatron.module import MegatronModule
class DistributedDataParallel(Module):
class DistributedDataParallel(MegatronModule):
def __init__(self, module):
super(DistributedDataParallel, self).__init__()
...@@ -86,6 +88,11 @@ class DistributedDataParallel(Module):
return sd
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
return self.module.state_dict_for_save_checkpoint(destination, prefix,
keep_vars)
def load_state_dict(self, state_dict, strict=True):
self.module.load_state_dict(state_dict, strict=strict)
...
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""GPT-2 model."""
import torch
from megatron.module import MegatronModule
from .language_model import parallel_lm_logits
from .language_model import get_language_model
from .utils import init_method_normal
from .utils import scaled_init_method_normal
def gpt2_attention_mask_func(attention_scores, ltor_mask):
attention_scores = torch.mul(attention_scores, ltor_mask) - \
10000.0 * (1.0 - ltor_mask)
return attention_scores
class GPT2Model(MegatronModule):
"""GPT-2 Language model."""
def __init__(self,
num_layers,
vocab_size,
hidden_size,
num_attention_heads,
embedding_dropout_prob,
attention_dropout_prob,
output_dropout_prob,
max_sequence_length,
checkpoint_activations,
checkpoint_num_layers=1,
layernorm_epsilon=1.0e-5,
init_method_std=0.02,
num_tokentypes=0,
parallel_output=True):
super(GPT2Model, self).__init__()
self.parallel_output = parallel_output
self.language_model, self._language_model_key = get_language_model(
num_layers=num_layers,
vocab_size=vocab_size,
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
embedding_dropout_prob=embedding_dropout_prob,
attention_dropout_prob=attention_dropout_prob,
output_dropout_prob=output_dropout_prob,
max_sequence_length=max_sequence_length,
num_tokentypes=num_tokentypes,
add_pooler=False,
attention_mask_func=gpt2_attention_mask_func,
checkpoint_activations=checkpoint_activations,
checkpoint_num_layers=checkpoint_num_layers,
layernorm_epsilon=layernorm_epsilon,
init_method=init_method_normal(init_method_std),
scaled_init_method=scaled_init_method_normal(init_method_std,
num_layers),
residual_connection_post_layernorm=False)
def forward(self, input_ids, position_ids, attention_mask,
tokentype_ids=None, layer_past=None, get_key_value=False):
# Language model.
lm_output = self.language_model(input_ids,
position_ids,
attention_mask,
tokentype_ids=tokentype_ids,
layer_past=layer_past,
get_key_value=get_key_value)
if get_key_value:
lm_output, presents = lm_output
# Output.
output = parallel_lm_logits(
lm_output,
self.language_model.embedding.word_embeddings.weight,
self.parallel_output)
if get_key_value:
output = [output, presents]
return output
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
state_dict_ = {}
state_dict_[self._language_model_key] \
= self.language_model.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Customized load."""
if self._language_model_key in state_dict:
state_dict = state_dict[self._language_model_key]
self.language_model.load_state_dict(state_dict, strict=strict)
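A short sketch (not in the commit) of how the cached-decoding path of GPT2Model is driven; again this assumes mpu is initialized and the sizes are placeholders:
import torch

gpt2 = GPT2Model(num_layers=2, vocab_size=50304, hidden_size=128,
                 num_attention_heads=4, embedding_dropout_prob=0.1,
                 attention_dropout_prob=0.1, output_dropout_prob=0.1,
                 max_sequence_length=1024, checkpoint_activations=False)
tokens = torch.randint(0, 50304, (1, 8))
position_ids = torch.arange(8).unsqueeze(0)
ltor_mask = torch.tril(torch.ones(1, 1, 8, 8))  # 1 = keep, as gpt2_attention_mask_func expects
logits, presents = gpt2(tokens, position_ids, ltor_mask, get_key_value=True)
# `presents` holds the per-layer key/value caches; feeding it back through
# `layer_past=` lets later calls process only the newest token, as the
# sampling loop in generate_samples does.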
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""GPT-2 model."""
import torch
import torch.nn.functional as F
from megatron import mpu
def init_method_normal(std=0.02):
"""Init method based on normal distribution.
This is only used for embeddings. The transformer has its
own initializer.
"""
def init_(tensor):
return torch.nn.init.normal_(tensor, mean=0.0, std=std)
return init_
class GPT2Model(torch.nn.Module):
"""GPT-2 Language model.
The output of the forward method is the logits (parallel or
serial depending on the `parallel_output` flag).
"""
def __init__(self,
num_layers,
vocab_size,
hidden_size,
num_attention_heads,
embedding_dropout_prob,
attention_dropout_prob,
output_dropout_prob,
max_sequence_length,
checkpoint_activations,
checkpoint_num_layers=1,
parallel_output=True):
super(GPT2Model, self).__init__()
self.parallel_output = parallel_output
init_method = init_method_normal(std=0.02)
# Word embeddings (parallel).
self.word_embeddings = mpu.VocabParallelEmbedding(
vocab_size, hidden_size, init_method=init_method)
# Position embedding (serial).
self.position_embeddings = torch.nn.Embedding(max_sequence_length,
hidden_size)
# Token type embedding.
# Add this as an optional field that can be added through
# method call so we can load a pretrained model without
# token types and add them as needed.
self.tokentype_embeddings = None
self.hidden_size = hidden_size
# Initialize the position embeddings.
init_method(self.position_embeddings.weight)
# Embeddings dropout
self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
# Transformer
self.transformer = mpu.GPT2ParallelTransformer(num_layers,
hidden_size,
num_attention_heads,
attention_dropout_prob,
output_dropout_prob,
checkpoint_activations,
checkpoint_num_layers)
def add_tokentype_embeddings(self, num_tokentypes):
if self.tokentype_embeddings is not None:
raise Exception('tokentype embeddings is already initialized')
if torch.distributed.get_rank() == 0:
print('adding embedding for {} tokentypes'.format(num_tokentypes),
flush=True)
self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes,
self.hidden_size)
def forward(self, input_ids, position_ids, attention_mask,
layer_past=None, get_present=False, tokentype_ids=None):
# Embeddings.
words_embeddings = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
embeddings = words_embeddings + position_embeddings
if tokentype_ids is not None:
assert self.tokentype_embeddings is not None
embeddings = embeddings + self.tokentype_embeddings(tokentype_ids)
else:
assert self.tokentype_embeddings is None
# Dropout.
embeddings = self.embedding_dropout(embeddings)
# Transformer.
transformer_output = self.transformer(embeddings, attention_mask,
layer_past=layer_past,
get_present=get_present)
if get_present:
transformer_output, presents = transformer_output
# Parallel logits.
transformer_output_parallel = mpu.copy_to_model_parallel_region(
transformer_output)
logits_parallel = F.linear(transformer_output_parallel,
self.word_embeddings.weight)
if self.parallel_output:
output = logits_parallel
else:
output = mpu.gather_from_model_parallel_region(logits_parallel)
if get_present:
output = [output, presents]
return output
def gpt2_get_params_for_weight_decay_optimization(module):
weight_decay_params = {'params': []}
no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
for module_ in module.modules():
if isinstance(module_, (mpu.LayerNorm, torch.nn.LayerNorm)):
no_weight_decay_params['params'].extend(
[p for p in list(module_._parameters.values())
if p is not None])
else:
weight_decay_params['params'].extend(
[p for n, p in list(module_._parameters.items())
if p is not None and n != 'bias'])
no_weight_decay_params['params'].extend(
[p for n, p in list(module_._parameters.items())
if p is not None and n == 'bias'])
return weight_decay_params, no_weight_decay_params
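An illustrative sketch (not part of this file) of how the two groups returned above are consumed; `model` stands for any GPT2Model instance, and plain torch.optim.Adam stands in for the Apex FusedAdam used in training:
import torch

weight_decay_params, no_weight_decay_params = \
    gpt2_get_params_for_weight_decay_optimization(model)
optimizer = torch.optim.Adam(
    [weight_decay_params, no_weight_decay_params],
    lr=1.5e-4, weight_decay=0.01)
# LayerNorm weights and all biases land in the group with weight_decay=0.0;
# every other parameter uses the optimizer-level default of 0.01.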
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformer based language model."""
import torch
import torch.nn.functional as F
from megatron import mpu
from megatron.module import MegatronModule
from .transformer import ParallelTransformer
from .transformer import TransformerHyperparameters
from .utils import gelu
from .utils import get_linear_layer
def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
bias=None):
"""LM logits using word embedding weights."""
# Parallel logits.
input_parallel = mpu.copy_to_model_parallel_region(input_)
# Matrix multiply.
if bias is None:
logits_parallel = F.linear(input_parallel, word_embeddings_weight)
else:
logits_parallel = F.linear(input_parallel, word_embeddings_weight, bias)
# Gather if needed.
if parallel_output:
return logits_parallel
else:
return mpu.gather_from_model_parallel_region(logits_parallel)
def get_language_model(num_layers,
vocab_size,
hidden_size,
num_attention_heads,
embedding_dropout_prob,
attention_dropout_prob,
output_dropout_prob,
max_sequence_length,
num_tokentypes,
attention_mask_func,
add_pooler,
checkpoint_activations,
checkpoint_num_layers,
layernorm_epsilon,
init_method,
scaled_init_method,
residual_connection_post_layernorm):
# Transformer hyperparameters.
transformer_hparams = TransformerHyperparameters(
hidden_size=hidden_size,
num_layers=num_layers,
num_attention_heads=num_attention_heads,
attention_dropout_prob=attention_dropout_prob,
output_dropout_prob=output_dropout_prob,
mlp_activation_func=gelu,
layernorm_epsilon=layernorm_epsilon,
init_method=init_method,
output_layer_init_method=scaled_init_method,
checkpoint_activations=checkpoint_activations,
checkpoint_num_layers=checkpoint_num_layers,
apply_residual_connection_post_layernorm=residual_connection_post_layernorm)
# Language model.
language_model = TransformerLanguageModel(
transformer_hparams=transformer_hparams,
attention_mask_func=attention_mask_func,
vocab_size=vocab_size,
max_sequence_length=max_sequence_length,
embedding_dropout_prob=embedding_dropout_prob,
num_tokentypes=num_tokentypes,
add_pooler=add_pooler)
# key used for checkpoints.
language_model_key = 'language_model'
return language_model, language_model_key
class Pooler(MegatronModule):
"""Pooler layer.
Pool hidden states of a specific token (for example start of the
sequence) and add a linear transformation followed by a tanh.
Arguments:
hidden_size: hidden size
init_method: weight initialization method for the linear layer.
bias is set to zero.
"""
def __init__(self, hidden_size, init_method):
super(Pooler, self).__init__()
self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
def forward(self, hidden_states, sequence_index=0):
# hidden_states: [b, s, h]
# sequence_index: index of the token to pool.
pooled = hidden_states[:, sequence_index, :]
pooled = self.dense(pooled)
pooled = torch.tanh(pooled)
return pooled
class Embedding(MegatronModule):
"""Language model embeddings.
Arguments:
hidden_size: hidden size
vocab_size: vocabulary size
max_sequence_length: maximum size of sequence. This
is used for positional embedding
embedding_dropout_prob: dropout probability for embeddings
init_method: weight initialization method
num_tokentypes: size of the token-type embeddings. 0 value
will ignore this embedding
"""
def __init__(self,
hidden_size,
vocab_size,
max_sequence_length,
embedding_dropout_prob,
init_method,
num_tokentypes=0):
super(Embedding, self).__init__()
self.hidden_size = hidden_size
self.init_method = init_method
self.num_tokentypes = num_tokentypes
# Word embeddings (parallel).
self.word_embeddings = mpu.VocabParallelEmbedding(
vocab_size, self.hidden_size, init_method=self.init_method)
self._word_embeddings_key = 'word_embeddings'
# Position embedding (serial).
self.position_embeddings = torch.nn.Embedding(
max_sequence_length, self.hidden_size)
self._position_embeddings_key = 'position_embeddings'
# Initialize the position embeddings.
self.init_method(self.position_embeddings.weight)
# Token type embedding.
# Add this as an optional field that can be added through
# method call so we can load a pretrained model without
# token types and add them as needed.
self._tokentype_embeddings_key = 'tokentype_embeddings'
if self.num_tokentypes > 0:
self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes,
self.hidden_size)
# Initialize the token-type embeddings.
self.init_method(self.tokentype_embeddings.weight)
else:
self.tokentype_embeddings = None
# Embeddings dropout
self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
def add_tokentype_embeddings(self, num_tokentypes):
"""Add token-type embedding. This function is provided so we can add
token-type embeddings in case the pretrained model does not have it.
This allows us to load the model normally and then add this embedding.
"""
if self.tokentype_embeddings is not None:
raise Exception('tokentype embeddings is already initialized')
if torch.distributed.get_rank() == 0:
print('adding embedding for {} tokentypes'.format(num_tokentypes),
flush=True)
self.num_tokentypes = num_tokentypes
self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes,
self.hidden_size)
# Initialize the token-type embeddings.
self.init_method(self.tokentype_embeddings.weight)
def forward(self, input_ids, position_ids, tokentype_ids=None):
# Embeddings.
words_embeddings = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
embeddings = words_embeddings + position_embeddings
if tokentype_ids is not None:
assert self.tokentype_embeddings is not None
embeddings = embeddings + self.tokentype_embeddings(tokentype_ids)
else:
assert self.tokentype_embeddings is None
# Dropout.
embeddings = self.embedding_dropout(embeddings)
return embeddings
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
"""For easy load."""
state_dict_ = {}
state_dict_[self._word_embeddings_key] \
= self.word_embeddings.state_dict(destination, prefix, keep_vars)
state_dict_[self._position_embeddings_key] \
= self.position_embeddings.state_dict(
destination, prefix, keep_vars)
if self.num_tokentypes > 0:
state_dict_[self._tokentype_embeddings_key] \
= self.tokentype_embeddings.state_dict(
destination, prefix, keep_vars)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Customized load."""
# Word embedding.
if self._word_embeddings_key in state_dict:
state_dict_ = state_dict[self._word_embeddings_key]
else:
# for backward compatibility.
state_dict_ = {}
for key in state_dict.keys():
if 'word_embeddings' in key:
state_dict_[key.split('word_embeddings.')[1]] \
= state_dict[key]
self.word_embeddings.load_state_dict(state_dict_, strict=strict)
# Position embedding.
if self._position_embeddings_key in state_dict:
state_dict_ = state_dict[self._position_embeddings_key]
else:
# for backward compatibility.
state_dict_ = {}
for key in state_dict.keys():
if 'position_embeddings' in key:
state_dict_[key.split('position_embeddings.')[1]] \
= state_dict[key]
self.position_embeddings.load_state_dict(state_dict_, strict=strict)
# Tokentype embedding.
if self.num_tokentypes > 0:
state_dict_ = {}
if self._tokentype_embeddings_key in state_dict:
state_dict_ = state_dict[self._tokentype_embeddings_key]
else:
# for backward compatibility.
for key in state_dict.keys():
if 'tokentype_embeddings' in key:
state_dict_[key.split('tokentype_embeddings.')[1]] \
= state_dict[key]
if len(state_dict_.keys()) > 0:
self.tokentype_embeddings.load_state_dict(state_dict_,
strict=strict)
else:
print('***WARNING*** expected tokentype embeddings in the '
'checkpoint but could not find it', flush=True)
class TransformerLanguageModel(MegatronModule):
"""Transformer language model.
Arguments:
transformer_hparams: transformer hyperparameters
attention_mask_func: a function that takes `unmasked-attention-scores`
with size [b, np, s, s] and an `attention-mask` and will apply
the masking. The function should return a masked score of the
same size [b, np, s, s].
masked-attention-scores = attention_mask_func(
unmasked-attention-scores, attention-mask)
vocab_size: vocabulary size
max_sequence_length: maximum size of sequence. This
is used for positional embedding
embedding_dropout_prob: dropout probability for embeddings
num_tokentypes: size of the token-type embeddings. 0 value
will ignore this embedding
"""
def __init__(self,
transformer_hparams,
attention_mask_func,
vocab_size,
max_sequence_length,
embedding_dropout_prob,
num_tokentypes=0,
add_pooler=False):
super(TransformerLanguageModel, self).__init__()
self.hidden_size = transformer_hparams['hidden_size']
self.num_tokentypes = num_tokentypes
self.init_method = transformer_hparams['init_method']
self.add_pooler = add_pooler
# Embeddings
self.embedding = Embedding(self.hidden_size,
vocab_size,
max_sequence_length,
embedding_dropout_prob,
self.init_method,
self.num_tokentypes)
self._embedding_key = 'embedding'
# Transformer
self.transformer = ParallelTransformer(
transformer_hparams,
attention_mask_func)
self._transformer_key = 'transformer'
# Pooler
if self.add_pooler:
self.pooler = Pooler(self.hidden_size, self.init_method)
self._pooler_key = 'pooler'
def forward(self, input_ids, position_ids, attention_mask,
tokentype_ids=None, layer_past=None, get_key_value=False,
pooling_sequence_index=0):
# Embeddings.
embedding_output = self.embedding(input_ids, position_ids,
tokentype_ids=tokentype_ids)
# Transformer.
transformer_output = self.transformer(embedding_output,
attention_mask,
layer_past=layer_past,
get_key_value=get_key_value)
if self.add_pooler:
pooled_output = self.pooler(transformer_output,
pooling_sequence_index)
return transformer_output, pooled_output
return transformer_output
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
"""For easy load."""
state_dict_ = {}
state_dict_[self._embedding_key] \
= self.embedding.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
state_dict_[self._transformer_key] \
= self.transformer.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
if self.add_pooler:
state_dict_[self._pooler_key] \
= self.pooler.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Customized load."""
# Embedding.
if self._embedding_key in state_dict:
state_dict_ = state_dict[self._embedding_key]
else:
# for backward compatibility.
state_dict_ = {}
for key in state_dict.keys():
if '_embeddings' in key:
state_dict_[key] = state_dict[key]
self.embedding.load_state_dict(state_dict_, strict=strict)
# Transformer.
if self._transformer_key in state_dict:
state_dict_ = state_dict[self._transformer_key]
else:
# for backward compatibility.
state_dict_ = {}
for key in state_dict.keys():
if 'transformer.' in key:
state_dict_[key.split('transformer.')[1]] = state_dict[key]
self.transformer.load_state_dict(state_dict_, strict=strict)
# Pooler.
if self.add_pooler:
assert 'pooler' in state_dict, \
'could not find data for pooler in the checkpoint'
self.pooler.load_state_dict(state_dict[self._pooler_key],
strict=strict)
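To round out the file, a compact sketch (not in the commit) of how BertModel and GPT2Model drive this module: build the backbone via get_language_model, then checkpoint and restore it under the returned key. The hyperparameters and init functions below are placeholders, and mpu is assumed to be initialized:
import torch

lm, lm_key = get_language_model(
    num_layers=2, vocab_size=50304, hidden_size=128, num_attention_heads=4,
    embedding_dropout_prob=0.1, attention_dropout_prob=0.1,
    output_dropout_prob=0.1, max_sequence_length=1024, num_tokentypes=0,
    attention_mask_func=lambda scores, mask: scores + mask,
    add_pooler=False, checkpoint_activations=False, checkpoint_num_layers=1,
    layernorm_epsilon=1.0e-5,
    init_method=torch.nn.init.xavier_normal_,
    scaled_init_method=torch.nn.init.xavier_normal_,
    residual_connection_post_layernorm=False)
state = {lm_key: lm.state_dict_for_save_checkpoint()}
lm.load_state_dict(state[lm_key])   # customized load keyed by 'language_model'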
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BERT model."""
from __future__ import absolute_import, division, print_function, unicode_literals
import os
import copy
import json
import math
import logging
import tarfile
import tempfile
import shutil
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
#from torch.utils.checkpoint import checkpoint
from megatron.data_utils.file_utils import cached_path
from megatron import mpu
def normal_init_method(mean, std):
def init_(tensor):
return torch.nn.init.normal_(tensor, mean=mean, std=std)
return init_
def scaled_init_method(mean, std, num_layers):
"""Init method based on N(0, sigma/sqrt(2*num_layers)."""
std = std / math.sqrt(2.0 * num_layers)
def init_(tensor):
return torch.nn.init.normal_(tensor, mean=mean, std=std)
return init_
logger = logging.getLogger(__name__)
PRETRAINED_MODEL_ARCHIVE_MAP = {
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz",
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz",
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz",
'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz",
'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz",
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz",
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz",
}
CONFIG_NAME = 'bert_config.json'
WEIGHTS_NAME = 'pytorch_model.bin'
TF_WEIGHTS_NAME = 'model.ckpt'
def load_tf_weights_in_bert(model, tf_checkpoint_path):
""" Load tf checkpoints in a pytorch model
"""
try:
import re
import numpy as np
import tensorflow as tf
except ImportError:
print("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions.")
raise
tf_path = os.path.abspath(tf_checkpoint_path)
print("Converting TensorFlow checkpoint from {}".format(tf_path))
# Load weights from TF model
init_vars = tf.train.list_variables(tf_path)
names = []
arrays = []
for name, shape in init_vars:
print("Loading TF weight {} with shape {}".format(name, shape))
array = tf.train.load_variable(tf_path, name)
names.append(name)
arrays.append(array)
for name, array in zip(names, arrays):
name = name.split('/')
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v
# which are not required for using pretrained model
if any(n in ["adam_v", "adam_m"] for n in name):
print("Skipping {}".format("/".join(name)))
continue
pointer = model
for m_name in name:
if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
l = re.split(r'_(\d+)', m_name)
else:
l = [m_name]
if l[0] == 'kernel' or l[0] == 'gamma':
pointer = getattr(pointer, 'weight')
elif l[0] == 'output_bias' or l[0] == 'beta':
pointer = getattr(pointer, 'bias')
elif l[0] == 'output_weights':
pointer = getattr(pointer, 'weight')
else:
pointer = getattr(pointer, l[0])
if len(l) >= 2:
num = int(l[1])
pointer = pointer[num]
if m_name[-11:] == '_embeddings':
pointer = getattr(pointer, 'weight')
elif m_name == 'kernel':
array = np.transpose(array)
try:
assert pointer.shape == array.shape
except AssertionError as e:
e.args += (pointer.shape, array.shape)
raise
print("Initialize PyTorch weight {}".format(name))
pointer.data = torch.from_numpy(array)
return model
def gelu(x):
"""Implementation of the gelu activation function.
For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
"""
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
def swish(x):
return x * torch.sigmoid(x)
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
class BertConfig(object):
"""Configuration class to store the configuration of a `BertModel`.
"""
def __init__(self,
vocab_size_or_config_json_file,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02,
deep_init=False,
fp32_layernorm=False,
fp32_embedding=False,
fp32_tokentypes=False,
layernorm_epsilon=1e-12):
"""Constructs BertConfig.
Args:
vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`.
hidden_size: Size of the encoder layers and the pooler layer.
num_hidden_layers: Number of hidden layers in the Transformer encoder.
num_attention_heads: Number of attention heads for each attention layer in
the Transformer encoder.
intermediate_size: The size of the "intermediate" (i.e., feed-forward)
layer in the Transformer encoder.
hidden_act: The non-linear activation function (function or string) in the
encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
hidden_dropout_prob: The dropout probability for all fully connected
layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob: The dropout ratio for the attention
probabilities.
max_position_embeddings: The maximum sequence length that this model might
ever be used with. Typically set this to something large just in case
(e.g., 512 or 1024 or 2048).
type_vocab_size: The vocabulary size of the `token_type_ids` passed into
`BertModel`.
initializer_range: The stddev of the truncated_normal_initializer for
initializing all weight matrices.
"""
if isinstance(vocab_size_or_config_json_file, str):
with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
json_config = json.loads(reader.read())
for key, value in json_config.items():
self.__dict__[key] = value
elif isinstance(vocab_size_or_config_json_file, int):
self.vocab_size = vocab_size_or_config_json_file
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
self.deep_init = deep_init
self.fp32_layernorm = fp32_layernorm
self.fp32_embedding = fp32_embedding
self.layernorm_epsilon = layernorm_epsilon
self.fp32_tokentypes = fp32_tokentypes
else:
raise ValueError("First argument must be either a vocabulary size (int)"
"or the path to a pretrained model config file (str)")
@classmethod
def from_dict(cls, json_object):
"""Constructs a `BertConfig` from a Python dictionary of parameters."""
config = BertConfig(vocab_size_or_config_json_file=-1)
for key, value in json_object.items():
config.__dict__[key] = value
return config
@classmethod
def from_json_file(cls, json_file):
"""Constructs a `BertConfig` from a json file of parameters."""
with open(json_file, "r", encoding='utf-8') as reader:
text = reader.read()
return cls.from_dict(json.loads(text))
def __repr__(self):
return str(self.to_json_string())
def to_dict(self):
"""Serializes this instance to a Python dictionary."""
output = copy.deepcopy(self.__dict__)
return output
def to_json_string(self):
"""Serializes this instance to a JSON string."""
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
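A small usage sketch, not part of the original file, showing the construction paths BertConfig supports; it relies on the `json` import already at the top of this module:
config = BertConfig(vocab_size_or_config_json_file=30522,
                    hidden_size=768, num_hidden_layers=12,
                    num_attention_heads=12)
round_trip = BertConfig.from_dict(json.loads(config.to_json_string()))
assert round_trip.hidden_size == config.hidden_size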
try:
from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm
except ImportError:
print("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.")
class BertLayerNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-12):
"""Construct a layernorm module in the TF style (epsilon inside the square root).
"""
super(BertLayerNorm, self).__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.bias = nn.Parameter(torch.zeros(hidden_size))
self.variance_epsilon = eps
def forward(self, x):
u = x.mean(-1, keepdim=True)
s = (x - u).pow(2).mean(-1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.variance_epsilon)
return self.weight * x + self.bias
class BertEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings.
"""
def __init__(self, config):
super(BertEmbeddings, self).__init__()
#self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
self.word_embeddings = mpu.VocabParallelEmbedding(
config.vocab_size, config.hidden_size,
init_method=normal_init_method(mean=0.0,
std=config.initializer_range))
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
# any TensorFlow checkpoint file
self.fp32_layernorm = config.fp32_layernorm
self.fp32_embedding = config.fp32_embedding
self.fp32_tokentypes = config.fp32_tokentypes
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layernorm_epsilon)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, input_ids, token_type_ids=None):
seq_length = input_ids.size(1)
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
if token_type_ids is None:
token_type_ids = torch.zeros_like(input_ids)
words_embeddings = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
token_type_embeddings = self.token_type_embeddings(token_type_ids)
if not self.fp32_tokentypes:
embeddings = words_embeddings + position_embeddings + token_type_embeddings
if self.fp32_embedding and not self.fp32_layernorm:
embeddings = embeddings.half()
previous_type = embeddings.type()
if self.fp32_layernorm:
embeddings = embeddings.float()
embeddings = self.LayerNorm(embeddings)
if self.fp32_layernorm:
if self.fp32_embedding:
embeddings = embeddings.half()
else:
embeddings = embeddings.type(previous_type)
else:
embeddings = words_embeddings.float() + position_embeddings.float() + token_type_embeddings.float()
if self.fp32_tokentypes and not self.fp32_layernorm:
embeddings = embeddings.half()
previous_type = embeddings.type()
if self.fp32_layernorm:
embeddings = embeddings.float()
embeddings = self.LayerNorm(embeddings)
if self.fp32_layernorm:
if self.fp32_tokentypes:
embeddings = embeddings.half()
else:
embeddings = embeddings.type(previous_type)
embeddings = self.dropout(embeddings)
return embeddings
class BertSelfAttention(nn.Module):
def __init__(self, config):
super(BertSelfAttention, self).__init__()
if config.hidden_size % config.num_attention_heads != 0:
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (config.hidden_size, config.num_attention_heads))
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = nn.Linear(config.hidden_size, self.all_head_size)
self.key = nn.Linear(config.hidden_size, self.all_head_size)
self.value = nn.Linear(config.hidden_size, self.all_head_size)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)
def forward(self, hidden_states, attention_mask):
mixed_query_layer = self.query(hidden_states)
mixed_key_layer = self.key(hidden_states)
mixed_value_layer = self.value(hidden_states)
query_layer = self.transpose_for_scores(mixed_query_layer)
key_layer = self.transpose_for_scores(mixed_key_layer)
value_layer = self.transpose_for_scores(mixed_value_layer)
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
# Apply the attention mask (precomputed for all layers in BertModel forward() function)
attention_scores = attention_scores + attention_mask
# Normalize the attention scores to probabilities.
attention_probs = nn.Softmax(dim=-1)(attention_scores)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
previous_type = attention_probs.type()
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
return context_layer
class BertSelfOutput(nn.Module):
def __init__(self, config):
super(BertSelfOutput, self).__init__()
if hasattr(config, 'deep_init') and config.deep_init:
init_method = scaled_init_method(mean=0.0,
std=config.initializer_range,
num_layers=config.num_hidden_layers)
else:
init_method = normal_init_method(mean=0.0,
std=config.initializer_range)
self.dense = mpu.RowParallelLinear(
input_size=config.hidden_size,
output_size=config.hidden_size,
bias=True,
input_is_parallel=True,
stride=1,
init_method=init_method)
self.fp32_layernorm = config.fp32_layernorm
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layernorm_epsilon)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
ln_input = hidden_states + input_tensor
previous_type = ln_input.type()
if self.fp32_layernorm:
ln_input = ln_input.float()
hidden_states = self.LayerNorm(ln_input)
if self.fp32_layernorm:
hidden_states = hidden_states.type(previous_type)
return hidden_states
class BertAttention(nn.Module):
def __init__(self, config):
super(BertAttention, self).__init__()
self.self = mpu.BertParallelSelfAttention(
hidden_size=config.hidden_size,
num_attention_heads=config.num_attention_heads,
dropout_prob=config.attention_probs_dropout_prob,
output_parallel=True,
init_method=normal_init_method(mean=0.0,
std=config.initializer_range))
self.output = BertSelfOutput(config)
def forward(self, input_tensor, attention_mask):
self_output = self.self(input_tensor, attention_mask)
attention_output = self.output(self_output, input_tensor)
return attention_output
class BertIntermediate(nn.Module):
def __init__(self, config):
super(BertIntermediate, self).__init__()
self.dense = mpu.ColumnParallelLinear(
input_size=config.hidden_size,
output_size=config.intermediate_size,
bias=True,
gather_output=False,
stride=1,
init_method=normal_init_method(mean=0.0,
std=config.initializer_range))
self.intermediate_act_fn = ACT2FN[config.hidden_act] \
if isinstance(config.hidden_act, str) else config.hidden_act
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
class BertOutput(nn.Module):
def __init__(self, config):
super(BertOutput, self).__init__()
if hasattr(config, 'deep_init') and config.deep_init:
init_method = scaled_init_method(mean=0.0,
std=config.initializer_range,
num_layers=config.num_hidden_layers)
else:
init_method = normal_init_method(mean=0.0,
std=config.initializer_range)
self.dense = mpu.RowParallelLinear(
input_size=config.intermediate_size,
output_size=config.hidden_size,
bias=True,
input_is_parallel=True,
stride=1,
init_method=init_method)
self.fp32_layernorm = config.fp32_layernorm
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layernorm_epsilon)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
ln_input = hidden_states + input_tensor
previous_type = ln_input.type()
if self.fp32_layernorm:
ln_input = ln_input.float()
hidden_states = self.LayerNorm(ln_input)
if self.fp32_layernorm:
hidden_states = hidden_states.type(previous_type)
return hidden_states
class BertLayer(nn.Module):
def __init__(self, config):
super(BertLayer, self).__init__()
self.attention = BertAttention(config)
self.intermediate = BertIntermediate(config)
self.output = BertOutput(config)
def forward(self, hidden_states, attention_mask):
attention_output = self.attention(hidden_states, attention_mask)
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output, attention_output)
return layer_output
class BertEncoder(nn.Module):
def __init__(self, config):
super(BertEncoder, self).__init__()
#layer = BertLayer(config)
#self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
# def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
# all_encoder_layers = []
# for layer_module in self.layer:
# hidden_states = layer_module(hidden_states, attention_mask)
# if output_all_encoded_layers:
# all_encoder_layers.append(hidden_states)
# if not output_all_encoded_layers:
# all_encoder_layers.append(hidden_states)
# return all_encoder_layers
def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True, checkpoint_activations=False):
all_encoder_layers = []
def custom(start, end):
def custom_forward(*inputs):
layers = self.layer[start:end]
x_ = inputs[0]
for layer in layers:
x_ = layer(x_, inputs[1])
return x_
return custom_forward
if checkpoint_activations:
l = 0
num_layers = len(self.layer)
chunk_length = 1 #math.ceil(math.sqrt(num_layers))
while l < num_layers:
hidden_states = mpu.checkpoint(custom(l, l+chunk_length), hidden_states, attention_mask*1)
l += chunk_length
# decoder layers
else:
for i,layer_module in enumerate(self.layer):
hidden_states = layer_module(hidden_states, attention_mask)
if output_all_encoded_layers:
all_encoder_layers.append(hidden_states)
if not output_all_encoded_layers or checkpoint_activations:
all_encoder_layers.append(hidden_states)
return all_encoder_layers
class BertPooler(nn.Module):
def __init__(self, config):
super(BertPooler, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()
def forward(self, hidden_states):
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(first_token_tensor)
pooled_output = self.activation(pooled_output)
return pooled_output
class BertPredictionHeadTransform(nn.Module):
def __init__(self, config):
super(BertPredictionHeadTransform, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.transform_act_fn = ACT2FN[config.hidden_act] \
if isinstance(config.hidden_act, str) else config.hidden_act
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layernorm_epsilon)
self.fp32_layernorm = config.fp32_layernorm
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states)
previous_type = hidden_states.type()
if self.fp32_layernorm:
hidden_states = hidden_states.float()
hidden_states = self.LayerNorm(hidden_states)
if self.fp32_layernorm:
hidden_states = hidden_states.type(previous_type)
return hidden_states
class BertLMPredictionHead(nn.Module):
def __init__(self, config, bert_model_embedding_weights):
super(BertLMPredictionHead, self).__init__()
self.transform = BertPredictionHeadTransform(config)
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
#self.decoder = nn.Linear(bert_model_embedding_weights.size(1),
# bert_model_embedding_weights.size(0),
# bias=False)
self.decoder_weight = bert_model_embedding_weights
self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0)))
self.bias.model_parallel = True
self.fp32_embedding = config.fp32_embedding
self.fp32_layernorm = config.fp32_layernorm
def convert_to_type(tensor):
if self.fp32_embedding:
return tensor.half()
else:
return tensor
self.type_converter = convert_to_type
self.converted = False
def forward(self, hidden_states):
if not self.converted:
self.converted = True
if self.fp32_embedding:
self.transform.half()
if self.fp32_layernorm:
self.transform.LayerNorm.float()
hidden_states = self.transform(self.type_converter(hidden_states))
# hidden_states = self.decoder(hidden_states) + self.bias
hidden_states = mpu.copy_to_model_parallel_region(hidden_states)
hidden_states = F.linear(self.type_converter(hidden_states),
self.type_converter(self.decoder_weight),
self.type_converter(self.bias))
return hidden_states
class BertOnlyMLMHead(nn.Module):
def __init__(self, config, bert_model_embedding_weights):
super(BertOnlyMLMHead, self).__init__()
self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights)
def forward(self, sequence_output):
prediction_scores = self.predictions(sequence_output)
return prediction_scores
class BertOnlyNSPHead(nn.Module):
def __init__(self, config):
super(BertOnlyNSPHead, self).__init__()
self.seq_relationship = nn.Linear(config.hidden_size, 2)
def forward(self, pooled_output):
seq_relationship_score = self.seq_relationship(pooled_output)
return seq_relationship_score
class BertPreTrainingHeads(nn.Module):
def __init__(self, config, bert_model_embedding_weights):
super(BertPreTrainingHeads, self).__init__()
self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights)
self.seq_relationship = nn.Linear(config.hidden_size, 2)
def forward(self, sequence_output, pooled_output):
prediction_scores = self.predictions(sequence_output)
for p in self.seq_relationship.parameters():
if p is None:
continue
pooled_output = pooled_output.type_as(p)
seq_relationship_score = self.seq_relationship(pooled_output)
return prediction_scores, seq_relationship_score
class PreTrainedBertModel(nn.Module):
""" An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
def __init__(self, config, *inputs, **kwargs):
super(PreTrainedBertModel, self).__init__()
if not isinstance(config, BertConfig):
raise ValueError(
"Parameter config in `{}(config)` should be an instance of class `BertConfig`. "
"To create a model from a Google pretrained model use "
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
self.__class__.__name__, self.__class__.__name__
))
self.config = config
def init_bert_weights(self, module):
""" Initialize the weights.
"""
if isinstance(module, (nn.Linear, nn.Embedding)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
elif isinstance(module, BertLayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
@classmethod
def from_pretrained(cls, pretrained_model_name, state_dict=None, cache_dir=None,
fp32_layernorm=False, fp32_embedding=False, layernorm_epsilon=1e-12,
fp32_tokentypes=False, *inputs, **kwargs):
"""
Instantiate a PreTrainedBertModel from a pre-trained model file or a pytorch state dict.
Download and cache the pre-trained model file if needed.
Params:
pretrained_model_name: either:
- a str with the name of a pre-trained model to load selected in the list of:
. `bert-base-uncased`
. `bert-large-uncased`
. `bert-base-cased`
. `bert-large-cased`
. `bert-base-multilingual-uncased`
. `bert-base-multilingual-cased`
. `bert-base-chinese`
- a path or url to a pretrained model archive containing:
. `bert_config.json` a configuration file for the model
. `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of Google pre-trained models
*inputs, **kwargs: additional input for the specific Bert class
(ex: num_labels for BertForSequenceClassification)
"""
if pretrained_model_name in PRETRAINED_MODEL_ARCHIVE_MAP:
archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name]
else:
archive_file = pretrained_model_name
# redirect to the cache, if necessary
try:
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
except FileNotFoundError:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url.".format(
pretrained_model_name,
', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
archive_file))
return None
if resolved_archive_file == archive_file:
logger.info("loading archive file {}".format(archive_file))
else:
logger.info("loading archive file {} from cache at {}".format(
archive_file, resolved_archive_file))
tempdir = None
if os.path.isdir(resolved_archive_file):
serialization_dir = resolved_archive_file
else:
# Extract archive to temp dir
tempdir = tempfile.mkdtemp()
logger.info("extracting archive file {} to temp dir {}".format(
resolved_archive_file, tempdir))
with tarfile.open(resolved_archive_file, 'r:gz') as archive:
archive.extractall(tempdir)
serialization_dir = tempdir
# Load config
config_file = os.path.join(serialization_dir, CONFIG_NAME)
config = BertConfig.from_json_file(config_file)
config.fp32_layernorm = fp32_layernorm
config.fp32_embedding = fp32_embedding
config.layernorm_epsilon = layernorm_epsilon
config.fp32_tokentypes = fp32_tokentypes
logger.info("Model config {}".format(config))
# Instantiate model.
model = cls(config, *inputs, **kwargs)
if state_dict is None:
weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
state_dict = torch.load(weights_path)
old_keys = []
new_keys = []
for key in state_dict.keys():
new_key = None
if 'gamma' in key:
new_key = key.replace('gamma', 'weight')
if 'beta' in key:
new_key = key.replace('beta', 'bias')
if new_key:
old_keys.append(key)
new_keys.append(new_key)
for old_key, new_key in zip(old_keys, new_keys):
state_dict[new_key] = state_dict.pop(old_key)
missing_keys = []
unexpected_keys = []
error_msgs = []
# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, '_metadata', None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata
def load(module, prefix=''):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
module._load_from_state_dict(
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + '.')
load(model, prefix='' if hasattr(model, 'bert') else 'bert.')
if len(missing_keys) > 0:
logger.info("Weights of {} not initialized from pretrained model: {}".format(
model.__class__.__name__, missing_keys))
if len(unexpected_keys) > 0:
logger.info("Weights from pretrained model not used in {}: {}".format(
model.__class__.__name__, unexpected_keys))
if tempdir:
# Clean up temp dir
shutil.rmtree(tempdir)
return model
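# Hedged usage sketch (illustrative): loading a pretrained checkpoint through
# the classmethod above. The keyword arguments mirror the `from_pretrained`
# signature; if the name is not in PRETRAINED_MODEL_ARCHIVE_MAP it is treated
# as a path or URL to an archive containing `bert_config.json` and
# `pytorch_model.bin`.
def _from_pretrained_sketch():
    model = BertForPreTraining.from_pretrained(
        'bert-base-uncased',
        cache_dir=None,
        fp32_layernorm=True,
        fp32_embedding=False)
    return model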
class BertModel(PreTrainedBertModel):
"""BERT model ("Bidirectional Embedding Representations from a Transformer").
Params:
config: a BertConfig class instance with the configuration to build a new model
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts
`extract_features.py`, `run_classifier.py` and `run_squad.py`)
`token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
a `sentence B` token (see BERT paper for more details).
`attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
input sequence length in the current batch. It's the mask that we typically use for attention when
a batch has varying length sentences.
`output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`.
Outputs: Tuple of (encoded_layers, pooled_output)
`encoded_layers`: controlled by `output_all_encoded_layers` argument:
- `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end
of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each
encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
- `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding
to the last attention block of shape [batch_size, sequence_length, hidden_size],
`pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
classifier pretrained on top of the hidden state associated with the first token of the
input (`CLS`) to train on the Next-Sentence task (see BERT's paper).
Example usage:
```python
# Already been converted into WordPiece token ids
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
model = modeling.BertModel(config=config)
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config):
super(BertModel, self).__init__(config)
self.embeddings = BertEmbeddings(config)
self.encoder = BertEncoder(config)
self.pooler = BertPooler(config)
self.apply(self.init_bert_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True, checkpoint_activations=False):
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
if token_type_ids is None:
token_type_ids = torch.zeros_like(input_ids)
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
extended_attention_mask = extended_attention_mask.to(dtype=next(self.encoder.parameters()).dtype) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
embedding_output = self.embeddings(input_ids, token_type_ids)
encoded_layers = self.encoder(embedding_output,
extended_attention_mask,
output_all_encoded_layers=output_all_encoded_layers,
checkpoint_activations=checkpoint_activations)
sequence_output = encoded_layers[-1]
for p in self.pooler.parameters():
if p is None:
continue
sequence_output = sequence_output.type_as(p)
break
pooled_output = self.pooler(sequence_output)
if not output_all_encoded_layers or checkpoint_activations:
encoded_layers = encoded_layers[-1]
return encoded_layers, pooled_output
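# Illustrative sketch of the additive attention-mask trick in
# BertModel.forward above: a [b, s] padding mask of ones and zeros is
# broadcast to [b, 1, 1, s] and turned into 0.0 / -10000.0 so it can simply be
# added to the raw attention scores before the softmax.
def _extended_attention_mask_sketch(attention_mask, dtype=torch.float32):
    extended = attention_mask.unsqueeze(1).unsqueeze(2).to(dtype=dtype)
    return (1.0 - extended) * -10000.0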
class BertForPreTraining(PreTrainedBertModel):
"""BERT model with pre-training heads.
This module comprises the BERT model followed by the two pre-training heads:
- the masked language modeling head, and
- the next sentence classification head.
Params:
config: a BertConfig class instance with the configuration to build a new model.
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts
`extract_features.py`, `run_classifier.py` and `run_squad.py`)
`token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
a `sentence B` token (see BERT paper for more details).
`attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
input sequence length in the current batch. It's the mask that we typically use for attention when
a batch has varying length sentences.
`masked_lm_labels`: masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length]
with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss
is only computed for the labels set in [0, ..., vocab_size]
`next_sentence_label`: next sentence classification loss: torch.LongTensor of shape [batch_size]
with indices selected in [0, 1].
0 => next sentence is the continuation, 1 => next sentence is a random sentence.
Outputs:
if `masked_lm_labels` and `next_sentence_label` are not `None`:
Outputs the total_loss which is the sum of the masked language modeling loss and the next
sentence classification loss.
if `masked_lm_labels` or `next_sentence_label` is `None`:
Outputs a tuple comprising
- the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and
- the next sentence classification logits of shape [batch_size, 2].
Example usage:
```python
# Already been converted into WordPiece token ids
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
model = BertForPreTraining(config)
masked_lm_logits_scores, seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config):
super(BertForPreTraining, self).__init__(config)
self.bert = BertModel(config)
self.cls = BertPreTrainingHeads(config, self.bert.embeddings.word_embeddings.weight)
self.apply(self.init_bert_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, next_sentence_label=None, checkpoint_activations=False):
sequence_output, pooled_output = self.bert(input_ids, token_type_ids, attention_mask,
output_all_encoded_layers=False, checkpoint_activations=checkpoint_activations)
prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
if masked_lm_labels is not None and next_sentence_label is not None:
loss_fct = CrossEntropyLoss(ignore_index=-1)
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size).float(), masked_lm_labels.view(-1))
next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2).float(), next_sentence_label.view(-1))
total_loss = masked_lm_loss + next_sentence_loss
return total_loss
else:
return prediction_scores, seq_relationship_score
class BertForMaskedLM(PreTrainedBertModel):
"""BERT model with the masked language modeling head.
This module comprises the BERT model followed by the masked language modeling head.
Params:
config: a BertConfig class instance with the configuration to build a new model.
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts
`extract_features.py`, `run_classifier.py` and `run_squad.py`)
`token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
a `sentence B` token (see BERT paper for more details).
`attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
input sequence length in the current batch. It's the mask that we typically use for attention when
a batch has varying length sentences.
`masked_lm_labels`: masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length]
with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss
is only computed for the labels set in [0, ..., vocab_size]
Outputs:
if `masked_lm_labels` is not `None`:
Outputs the masked language modeling loss.
if `masked_lm_labels` is `None`:
Outputs the masked language modeling logits of shape [batch_size, sequence_length, vocab_size].
Example usage:
```python
# Already been converted into WordPiece token ids
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
model = BertForMaskedLM(config)
masked_lm_logits_scores = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config):
super(BertForMaskedLM, self).__init__(config)
self.bert = BertModel(config)
self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight)
self.apply(self.init_bert_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, checkpoint_activations=False):
sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask,
output_all_encoded_layers=False, checkpoint_activations=checkpoint_activations)
prediction_scores = self.cls(sequence_output)
if masked_lm_labels is not None:
loss_fct = CrossEntropyLoss(ignore_index=-1)
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
return masked_lm_loss
else:
return prediction_scores
class BertForNextSentencePrediction(PreTrainedBertModel):
"""BERT model with next sentence prediction head.
This module comprises the BERT model followed by the next sentence classification head.
Params:
config: a BertConfig class instance with the configuration to build a new model.
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts
`extract_features.py`, `run_classifier.py` and `run_squad.py`)
`token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
a `sentence B` token (see BERT paper for more details).
`attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
input sequence length in the current batch. It's the mask that we typically use for attention when
a batch has varying length sentences.
`next_sentence_label`: next sentence classification loss: torch.LongTensor of shape [batch_size]
with indices selected in [0, 1].
0 => next sentence is the continuation, 1 => next sentence is a random sentence.
Outputs:
if `next_sentence_label` is not `None`:
Outputs the total_loss which is the sum of the masked language modeling loss and the next
sentence classification loss.
if `next_sentence_label` is `None`:
Outputs the next sentence classification logits of shape [batch_size, 2].
Example usage:
```python
# Already been converted into WordPiece token ids
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
model = BertForNextSentencePrediction(config)
seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config):
super(BertForNextSentencePrediction, self).__init__(config)
self.bert = BertModel(config)
self.cls = BertOnlyNSPHead(config)
self.apply(self.init_bert_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, next_sentence_label=None, checkpoint_activations=False):
_, pooled_output = self.bert(input_ids, token_type_ids, attention_mask,
output_all_encoded_layers=False, checkpoint_activations=checkpoint_activations)
seq_relationship_score = self.cls(pooled_output)
if next_sentence_label is not None:
loss_fct = CrossEntropyLoss(ignore_index=-1)
next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
return next_sentence_loss
else:
return seq_relationship_score
class BertForSequenceClassification(PreTrainedBertModel):
"""BERT model for classification.
This module is composed of the BERT model with a linear layer on top of
the pooled output.
Params:
`config`: a BertConfig class instance with the configuration to build a new model.
`num_labels`: the number of classes for the classifier. Default = 2.
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts
`extract_features.py`, `run_classifier.py` and `run_squad.py`)
`token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
a `sentence B` token (see BERT paper for more details).
`attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
input sequence length in the current batch. It's the mask that we typically use for attention when
a batch has varying length sentences.
`labels`: labels for the classification output: torch.LongTensor of shape [batch_size]
with indices selected in [0, ..., num_labels].
Outputs:
if `labels` is not `None`:
Outputs the CrossEntropy classification loss of the output with the labels.
if `labels` is `None`:
Outputs the classification logits of shape [batch_size, num_labels].
Example usage:
```python
# Already been converted into WordPiece token ids
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
num_labels = 2
model = BertForSequenceClassification(config, num_labels)
logits = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config, num_labels=2):
super(BertForSequenceClassification, self).__init__(config)
self.num_labels = num_labels
self.bert = BertModel(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, num_labels)
self.apply(self.init_bert_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, checkpoint_activations=False):
_, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False, checkpoint_activations=checkpoint_activations)
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
return loss
else:
return logits
class BertForMultipleChoice(PreTrainedBertModel):
"""BERT model for multiple choice tasks.
This module is composed of the BERT model with a linear layer on top of
the pooled output.
Params:
`config`: a BertConfig class instance with the configuration to build a new model.
`num_choices`: the number of classes for the classifier. Default = 2.
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length]
with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts
`extract_features.py`, `run_classifier.py` and `run_squad.py`)
`token_type_ids`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length]
with the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A`
and type 1 corresponds to a `sentence B` token (see BERT paper for more details).
`attention_mask`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length] with indices
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
input sequence length in the current batch. It's the mask that we typically use for attention when
a batch has varying length sentences.
`labels`: labels for the classification output: torch.LongTensor of shape [batch_size]
with indices selected in [0, ..., num_choices].
Outputs:
if `labels` is not `None`:
Outputs the CrossEntropy classification loss of the output with the labels.
if `labels` is `None`:
Outputs the classification logits of shape [batch_size, num_labels].
Example usage:
```python
# Already been converted into WordPiece token ids
input_ids = torch.LongTensor([[[31, 51, 99], [15, 5, 0]], [[12, 16, 42], [14, 28, 57]]])
input_mask = torch.LongTensor([[[1, 1, 1], [1, 1, 0]],[[1,1,0], [1, 0, 0]]])
token_type_ids = torch.LongTensor([[[0, 0, 1], [0, 1, 0]],[[0, 1, 1], [0, 0, 1]]])
config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
num_choices = 2
model = BertForMultipleChoice(config, num_choices)
logits = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config, num_choices=2):
super(BertForMultipleChoice, self).__init__(config)
self.num_choices = num_choices
self.bert = BertModel(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, 1)
self.apply(self.init_bert_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, checkpoint_activations=False):
flat_input_ids = input_ids.view(-1, input_ids.size(-1))
flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1))
_, pooled_output = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask, output_all_encoded_layers=False, checkpoint_activations=checkpoint_activations)
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
reshaped_logits = logits.view(-1, self.num_choices)
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(reshaped_logits, labels)
return loss
else:
return reshaped_logits
class BertForTokenClassification(PreTrainedBertModel):
"""BERT model for token-level classification.
This module is composed of the BERT model with a linear layer on top of
the full hidden state of the last layer.
Params:
`config`: a BertConfig class instance with the configuration to build a new model.
`num_labels`: the number of classes for the classifier. Default = 2.
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts
`extract_features.py`, `run_classifier.py` and `run_squad.py`)
`token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
a `sentence B` token (see BERT paper for more details).
`attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
input sequence length in the current batch. It's the mask that we typically use for attention when
a batch has varying length sentences.
`labels`: labels for the classification output: torch.LongTensor of shape [batch_size]
with indices selected in [0, ..., num_labels].
Outputs:
if `labels` is not `None`:
Outputs the CrossEntropy classification loss of the output with the labels.
if `labels` is `None`:
Outputs the classification logits of shape [batch_size, sequence_length, num_labels].
Example usage:
```python
# Already been converted into WordPiece token ids
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
num_labels = 2
model = BertForTokenClassification(config, num_labels)
logits = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config, num_labels=2):
super(BertForTokenClassification, self).__init__(config)
self.num_labels = num_labels
self.bert = BertModel(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
#self.classifier = nn.Linear(config.hidden_size, num_labels)
self.classifier = mpu.RowParallelLinear(
input_size=config.hidden_size,
output_size=num_labels,
bias=True,
input_is_parallel=True,
stride=1,
init_method=normal_init_method(mean=0.0,
std=config.initializer_range))
self.apply(self.init_bert_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, checkpoint_activations=False):
sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False, checkpoint_activations=checkpoint_activations)
with mpu.get_cuda_rng_tracker().fork():
sequence_output = self.dropout(sequence_output)
logits = self.classifier(sequence_output)
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
return loss
else:
return logits
class BertForQuestionAnswering(PreTrainedBertModel):
"""BERT model for Question Answering (span extraction).
This module is composed of the BERT model with a linear layer on top of
the sequence output that computes start_logits and end_logits
Params:
`config`: a BertConfig class instance with the configuration to build a new model.
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts
`extract_features.py`, `run_classifier.py` and `run_squad.py`)
`token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
a `sentence B` token (see BERT paper for more details).
`attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
input sequence length in the current batch. It's the mask that we typically use for attention when
a batch has varying length sentences.
`start_positions`: position of the first token for the labeled span: torch.LongTensor of shape [batch_size].
Positions are clamped to the length of the sequence and position outside of the sequence are not taken
into account for computing the loss.
`end_positions`: position of the last token for the labeled span: torch.LongTensor of shape [batch_size].
Positions are clamped to the length of the sequence and position outside of the sequence are not taken
into account for computing the loss.
Outputs:
if `start_positions` and `end_positions` are not `None`:
Outputs the total_loss which is the sum of the CrossEntropy loss for the start and end token positions.
if `start_positions` or `end_positions` is `None`:
Outputs a tuple of start_logits, end_logits which are the logits respectively for the start and end
position tokens of shape [batch_size, sequence_length].
Example usage:
```python
# Already been converted into WordPiece token ids
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
model = BertForQuestionAnswering(config)
start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config):
super(BertForQuestionAnswering, self).__init__(config)
self.bert = BertModel(config)
# TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version
# self.dropout = nn.Dropout(config.hidden_dropout_prob)
#self.qa_outputs = nn.Linear(config.hidden_size, 2)
self.qa_outputs = mpu.RowParallelLinear(
input_size=config.hidden_size,
output_size=2,
bias=True,
input_is_parallel=True,
stride=1,
init_method=normal_init_method(mean=0.0,
std=config.initializer_range))
self.apply(self.init_bert_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None, checkpoint_activations=False):
sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False, checkpoint_activations=checkpoint_activations)
logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1)
if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, the positions may carry an extra dimension; squeeze it
if len(start_positions.size()) > 1:
start_positions = start_positions.squeeze(-1)
if len(end_positions.size()) > 1:
end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index)
end_positions.clamp_(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
return total_loss
else:
return start_logits, end_logits
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformer."""
import math
import torch
from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm
from megatron import mpu
from megatron.module import MegatronModule
""" We use the following notation throughout this file:
h: hidden size
n: number of attention heads
p: number of model parallel partitions
np: n/p
hp: h/p
hn: h/n
b: batch size
s: sequence length
l: number of layers
Transformer takes input of size [b, s, h] and returns a
tensor of the same size. We use the following arguments:
hyperparameters: transformer hyperparameters
attention_mask_func: a function that takes `unmasked-attention-scores`
with size [b, np, s, s] and an `attention-mask` and will apply
the masking. The function should return a masked score of the
same size [b, np, s, s].
masked-attention-scores = attention_mask_func(
unmasked-attention-scores, attention-mask)
"""
class TransformerHyperparameters:
"""Hyperparameters used to build and run the transformer.
Arguments:
hidden_size: hidden size (h)
num_layers: number of layers (l)
num_attention_heads: number of attention heads (n)
attention_dropout_prob: dropout probability for the attention
probabilities
output_dropout_prob: dropout probability for the output
layers (attention output and mlp output)
mlp_activation_func: activation function for the mlp layer
layernorm_epsilon: tolerance parameter used in the layer norm
division
init_method: init method used for all weights except layer
norm and output weights
output_layer_init_method: init method for output weights (
attention output and mlp output)
checkpoint_activations: flag to use activation checkpointing
checkpoint_num_layers: number of layers use in each chunk of
activation checkpointing
apply_residual_connection_post_layernorm: take the post-layer-norm
values for the residual connection. BERT: True, GPT-2: False
"""
def __init__(self,
hidden_size=None,
num_layers=None,
num_attention_heads=None,
attention_dropout_prob=None,
output_dropout_prob=None,
mlp_activation_func=None,
layernorm_epsilon=None,
init_method=None,
output_layer_init_method=None,
checkpoint_activations=None,
checkpoint_num_layers=None,
apply_residual_connection_post_layernorm=None):
self.params_dict = {}
self.params_dict['hidden_size'] = hidden_size
self.params_dict['num_layers'] = num_layers
self.params_dict['num_attention_heads'] = num_attention_heads
self.params_dict['attention_dropout_prob'] = attention_dropout_prob
self.params_dict['output_dropout_prob'] = output_dropout_prob
self.params_dict['mlp_activation_func'] = mlp_activation_func
self.params_dict['layernorm_epsilon'] = layernorm_epsilon
self.params_dict['init_method'] = init_method
self.params_dict['output_layer_init_method'] = output_layer_init_method
self.params_dict['checkpoint_activations'] = checkpoint_activations
self.params_dict['checkpoint_num_layers'] = checkpoint_num_layers
self.params_dict['apply_residual_connection_post_layernorm'] \
= apply_residual_connection_post_layernorm
def __getitem__(self, key):
"""Custom retrieval with error checks."""
try:
value = self.params_dict[key]
except KeyError:
raise Exception(
'could not find {} in transformer hyperparameters'.format(key))
except Exception as e:
print('unexpected error in transformer hyperparameters:', e)
raise Exception()
else:
assert value is not None, \
'parameter value for {} is not set in transformer '\
'hyperparameters'.format(key)
return value
raise Exception('should not be here')
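# Illustrative construction of the hyperparameter container above. The numbers
# are example values only; the import path of `gelu`, `init_method_normal`,
# and `scaled_init_method_normal` (assumed here to be megatron.model.utils,
# matching the utilities added later in this commit) is an assumption.
def _example_hyperparameters():
    from megatron.model.utils import (gelu, init_method_normal,
                                      scaled_init_method_normal)
    num_layers = 24
    return TransformerHyperparameters(
        hidden_size=1024,
        num_layers=num_layers,
        num_attention_heads=16,
        attention_dropout_prob=0.1,
        output_dropout_prob=0.1,
        mlp_activation_func=gelu,
        layernorm_epsilon=1.0e-5,
        init_method=init_method_normal(0.02),
        output_layer_init_method=scaled_init_method_normal(0.02, num_layers),
        checkpoint_activations=False,
        checkpoint_num_layers=1,
        apply_residual_connection_post_layernorm=True)  # True: BERT, False: GPT-2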
class ParallelMLP(MegatronModule):
"""MLP.
MLP will take the input with h hidden state, project it to 4*h
hidden dimension, perform nonlinear transformation, and project the
state back into h hidden dimension. At the end, dropout is also
applied.
"""
def __init__(self, hyperparameters):
super(ParallelMLP, self).__init__()
# Project to 4h.
self.dense_h_to_4h = mpu.ColumnParallelLinear(
hyperparameters['hidden_size'],
4*hyperparameters['hidden_size'],
gather_output=False,
init_method=hyperparameters['init_method'])
self.activation_func = hyperparameters['mlp_activation_func']
# Project back to h.
self.dense_4h_to_h = mpu.RowParallelLinear(
4*hyperparameters['hidden_size'],
hyperparameters['hidden_size'],
input_is_parallel=True,
init_method=hyperparameters['output_layer_init_method'])
self.dropout = torch.nn.Dropout(hyperparameters['output_dropout_prob'])
def forward(self, hidden_states):
# [b, s, 4hp]
intermediate_parallel = self.dense_h_to_4h(hidden_states)
intermediate_parallel = self.activation_func(intermediate_parallel)
# [b, s, h]
output = self.dense_4h_to_h(intermediate_parallel)
output = self.dropout(output)
return output
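# Reference sketch (illustrative, single-GPU): the same h -> 4h -> activation
# -> h -> dropout computation as ParallelMLP, written with plain nn.Linear so
# the column/row-parallel split above is easier to see. Not used by Megatron.
class _PlainMLPSketch(torch.nn.Module):
    def __init__(self, hidden_size, activation_func, output_dropout_prob):
        super(_PlainMLPSketch, self).__init__()
        self.dense_h_to_4h = torch.nn.Linear(hidden_size, 4 * hidden_size)
        self.dense_4h_to_h = torch.nn.Linear(4 * hidden_size, hidden_size)
        self.activation_func = activation_func
        self.dropout = torch.nn.Dropout(output_dropout_prob)

    def forward(self, hidden_states):
        intermediate = self.activation_func(self.dense_h_to_4h(hidden_states))
        return self.dropout(self.dense_4h_to_h(intermediate))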
class ParallelSelfAttention(MegatronModule):
"""Parallel self-attention layer abstract class.
Self-attention layer takes input with size [b, s, h]
and returns output of the same size.
"""
def __init__(self, hyperparameters, attention_mask_func):
super(ParallelSelfAttention, self).__init__()
self.attention_mask_func = attention_mask_func
# Per attention head and per partition values.
world_size = mpu.get_model_parallel_world_size()
self.hidden_size_per_partition = mpu.divide(
hyperparameters['hidden_size'], world_size)
self.hidden_size_per_attention_head = mpu.divide(
hyperparameters['hidden_size'],
hyperparameters['num_attention_heads'])
self.num_attention_heads_per_partition = mpu.divide(
hyperparameters['num_attention_heads'], world_size)
# Strided linear layer.
self.query_key_value = mpu.ColumnParallelLinear(
hyperparameters['hidden_size'],
3*hyperparameters['hidden_size'],
stride=3,
gather_output=False,
init_method=hyperparameters['init_method'])
# Dropout. Note that for a single iteration, this layer will generate
# different outputs for different numbers of parallel partitions, but
# on average it should not be partition dependent.
self.attention_dropout = torch.nn.Dropout(
hyperparameters['attention_dropout_prob'])
# Output.
self.dense = mpu.RowParallelLinear(
hyperparameters['hidden_size'],
hyperparameters['hidden_size'],
input_is_parallel=True,
init_method=hyperparameters['output_layer_init_method'])
self.output_dropout = torch.nn.Dropout(
hyperparameters['output_dropout_prob'])
def _transpose_for_scores(self, tensor):
"""Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with
size [b, np, s, hn].
"""
new_tensor_shape = tensor.size()[:-1] + \
(self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head)
tensor = tensor.view(*new_tensor_shape)
return tensor.permute(0, 2, 1, 3)
def _get_query_key_value(self, hidden_states):
"""Get query, key, and value and transpose to
get size [b, np, s, hn].
"""
# Attention heads. [b, s, hp]
mixed_x_layer = self.query_key_value(hidden_states)
(mixed_query_layer,
mixed_key_layer,
mixed_value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)
# Reshape and transpose [b, np, s, hn]
query_layer = self._transpose_for_scores(mixed_query_layer)
key_layer = self._transpose_for_scores(mixed_key_layer)
value_layer = self._transpose_for_scores(mixed_value_layer)
return query_layer, key_layer, value_layer
def _get_unmasked_attention_scores(self, query_layer, key_layer):
"""Unmasked attention scores with size [b, np, s, s]."""
norm_factor = math.sqrt(math.sqrt(self.hidden_size_per_attention_head))
# Raw attention scores. [b, np, s, s]
return torch.matmul(query_layer/norm_factor,
key_layer.transpose(-1, -2)/norm_factor)
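# Note on the scaling above (explanatory comment, not original): dividing
# both query and key by sqrt(sqrt(hn)) is algebraically identical to the
# usual 1/sqrt(hn) scaling of the scores,
#     (q / hn**0.25) @ (k / hn**0.25).T == (q @ k.T) / sqrt(hn),
# but scaling the operands keeps intermediate values smaller, which is
# friendlier to fp16.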
def _get_attention_probs(self, attention_scores):
"""Attention probabilies with dropout. The output has
the size [b, np, s, s].
"""
# Attention probabilities. [b, np, s, s]
attention_probs = torch.nn.Softmax(dim=-1)(attention_scores)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
with mpu.get_cuda_rng_tracker().fork():
attention_probs = self.attention_dropout(attention_probs)
return attention_probs
def _get_attended_context(self, attention_probs, value_layer):
"""Final attended tesnor and transposed back to [b, s, hp]."""
# Context layer.
# [b, np, s, hn]
context_layer = torch.matmul(attention_probs, value_layer)
# [b, s, np, hn]
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + \
(self.hidden_size_per_partition,)
# [b, s, hp]
context_layer = context_layer.view(*new_context_layer_shape)
return context_layer
def _get_output(self, context_layer):
"""Output layer with dropout."""
# Output. [b, s, h]
output = self.dense(context_layer)
output = self.output_dropout(output)
return output
def forward(self, hidden_states, attention_mask, layer_past=None,
get_key_value=False):
# hidden_states: [b, s, h]
# Attention heads. [b, np, s, hn]
query_layer, key_layer, value_layer = self._get_query_key_value(
hidden_states)
if layer_past is not None:
past_key, past_value = layer_past
key_layer = torch.cat((past_key.type_as(key_layer),
key_layer), dim=-2)
value_layer = torch.cat((past_value.type_as(value_layer),
value_layer), dim=-2)
if get_key_value:
present = (key_layer, value_layer)
# Raw attention scores. [b, np, s, s]
attention_scores = self._get_unmasked_attention_scores(
query_layer, key_layer)
# Apply attention mask. [b, np, s, s]
if get_key_value:
with torch.no_grad():
if layer_past is not None:
attention_mask = attention_mask[
...,
attention_scores.size(3)-1,
:attention_scores.size(3)].unsqueeze(2)
else:
attention_mask = attention_mask[
...,
:attention_scores.size(3),
:attention_scores.size(3)]
attention_scores = self.attention_mask_func(attention_scores,
attention_mask)
# Attention probabilities. [b, np, s, s]
attention_probs = self._get_attention_probs(attention_scores)
# Context layer. [b, s, hp]
context_layer = self._get_attended_context(attention_probs, value_layer)
# Output. [b, s, h]
output = self._get_output(context_layer)
if get_key_value:
output = [output, present]
return output
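# Illustrative decoding sketch: how `layer_past` / `get_key_value` above are
# meant to be used for incremental generation. `attn` stands for an
# already-built ParallelSelfAttention, `hidden_steps` is an iterable of
# [b, 1, h] tensors (one new position per step), and `attention_mask` is the
# full [1, 1, s, s] mask that forward() slices internally.
def _incremental_attention_sketch(attn, hidden_steps, attention_mask):
    past = None
    outputs = []
    for hidden in hidden_steps:
        out, past = attn(hidden, attention_mask,
                         layer_past=past, get_key_value=True)
        outputs.append(out)
    return torch.cat(outputs, dim=1), past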
class ParallelTransformerLayer(MegatronModule):
"""A single transformer layer.
A transformer layer takes input with size [b, s, h] and returns an
output of the same size.
"""
def __init__(self, hyperparameters, attention_mask_func):
super(ParallelTransformerLayer, self).__init__()
self.apply_residual_connection_post_layernorm \
= hyperparameters['apply_residual_connection_post_layernorm']
# Layernorm on the input data.
self.input_layernorm = LayerNorm(
hyperparameters['hidden_size'],
eps=hyperparameters['layernorm_epsilon'])
# Self attention.
self.attention = ParallelSelfAttention(
hyperparameters,
attention_mask_func)
# Layernorm after the self attention.
self.post_attention_layernorm = LayerNorm(
hyperparameters['hidden_size'],
eps=hyperparameters['layernorm_epsilon'])
# MLP
self.mlp = ParallelMLP(hyperparameters)
def forward(self, hidden_states, attention_mask, layer_past=None,
get_key_value=False):
# hidden_states: [b, s, h]
# Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states)
# Self attention.
attention_output = self.attention(layernorm_output,
attention_mask,
layer_past=layer_past,
get_key_value=get_key_value)
if get_key_value:
attention_output, presents = attention_output
# Residual connection.
if self.apply_residual_connection_post_layernorm:
layernorm_input = layernorm_output + attention_output
else:
layernorm_input = hidden_states + attention_output
# Layer norm post the self attention.
layernorm_output = self.post_attention_layernorm(layernorm_input)
# MLP.
mlp_output = self.mlp(layernorm_output)
# Second residual connection.
if self.apply_residual_connection_post_layernorm:
output = layernorm_output + mlp_output
else:
output = layernorm_input + mlp_output
if get_key_value:
output = [output, presents]
return output
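# Residual-connection note (explanatory comment, not original): with
# `apply_residual_connection_post_layernorm=True` (BERT-style) the layer
# computes
#     x = LN1(x) + Attn(LN1(x));   out = LN2(x) + MLP(LN2(x))
# whereas with `False` (GPT-2-style, pre-LN) it computes
#     x = x + Attn(LN1(x));        out = x + MLP(LN2(x))
# i.e. in the GPT-2 case the skip connection bypasses the layer norms.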
class ParallelTransformer(MegatronModule):
"""Transformer class."""
def __init__(self, hyperparameters, attention_mask_func):
super(ParallelTransformer, self).__init__()
# Store activation checkpointing flag.
self.checkpoint_activations = hyperparameters['checkpoint_activations']
self.checkpoint_num_layers = hyperparameters['checkpoint_num_layers']
def get_layer():
return ParallelTransformerLayer(
hyperparameters,
attention_mask_func)
# Transformer layers.
self.layers = torch.nn.ModuleList(
[get_layer() for _ in range(hyperparameters['num_layers'])])
# Final layer norm before output.
self.final_layernorm = LayerNorm(
hyperparameters['hidden_size'],
eps=hyperparameters['layernorm_epsilon'])
def _checkpointed_forward(self, hidden_states, attention_mask):
"""Forward method with activation checkpointing."""
def custom(start, end):
def custom_forward(*inputs):
layers_ = self.layers[start:end]
x_ = inputs[0]
for layer in layers_:
x_ = layer(x_, inputs[1])
return x_
return custom_forward
l = 0
num_layers = len(self.layers)
while l < num_layers:
hidden_states = mpu.checkpoint(
custom(l, l+self.checkpoint_num_layers),
hidden_states, attention_mask)
l += self.checkpoint_num_layers
return hidden_states
def forward(self, hidden_states, attention_mask, layer_past=None,
get_key_value=False):
# Checks
if layer_past is not None:
assert get_key_value, \
'for not None values in layer_past, ' \
'expected get_key_value to be set'
if get_key_value:
assert not self.checkpoint_activations, \
'get_key_value does not work with ' \
'activation checkpointing'
if self.checkpoint_activations:
hidden_states = self._checkpointed_forward(hidden_states,
attention_mask)
else:
if get_key_value:
presents = []
for i, layer in enumerate(self.layers):
past = None
if layer_past is not None:
past = layer_past[i]
hidden_states = layer(hidden_states,
attention_mask,
layer_past=past,
get_key_value=get_key_value)
if get_key_value:
hidden_states, present = hidden_states
presents.append(present)
# Final layer norm.
output = self.final_layernorm(hidden_states)
if get_key_value:
output = [output, presents]
return output
...@@ -13,21 +13,59 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for wrapping BertModel.""" """Utilities for models."""
import math
import torch
from .modeling import BertConfig from .transformer import LayerNorm
from .modeling import BertForPreTraining, BertForMaskedLM
from .modeling import BertLayerNorm
def get_params_for_weight_decay_optimization(module): def init_method_normal(sigma):
"""Init method based on N(0, sigma)."""
def init_(tensor):
return torch.nn.init.normal_(tensor, mean=0.0, std=sigma)
return init_
def scaled_init_method_normal(sigma, num_layers):
"""Init method based on N(0, sigma/sqrt(2*num_layers)."""
std = sigma / math.sqrt(2.0 * num_layers)
def init_(tensor):
return torch.nn.init.normal_(tensor, mean=0.0, std=std)
return init_
def get_linear_layer(rows, columns, init_method):
"""Simple linear layer with weight initialization."""
layer = torch.nn.Linear(rows, columns)
init_method(layer.weight)
with torch.no_grad():
layer.bias.zero_()
return layer
@torch.jit.script
def gelu_impl(x):
"""OpenAI's gelu implementation."""
return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x *
(1.0 + 0.044715 * x * x)))
def gelu(x):
return gelu_impl(x)
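# Explanatory note (not original): 0.7978845608028654 is sqrt(2/pi), so
# gelu_impl is the standard tanh approximation of GELU,
#     gelu(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3))),
# with the cubic term written as x * (1 + 0.044715 * x * x) inside the tanh.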
def get_params_for_weight_decay_optimization(module):
"""Divide params into with-weight-decay and without-weight-decay groups.
Layernorms and biases will have no weight decay but the rest will.
"""
weight_decay_params = {'params': []}
no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
for module_ in module.modules():
if isinstance(module_, (BertLayerNorm, torch.nn.LayerNorm)): if isinstance(module_, LayerNorm):
no_weight_decay_params['params'].extend(
[p for p in list(module_._parameters.values())
if p is not None])
...@@ -40,51 +78,3 @@ def get_params_for_weight_decay_optimization(module):
if p is not None and n == 'bias'])
return weight_decay_params, no_weight_decay_params
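# Usage sketch (illustrative): the two groups returned above are meant to be
# handed straight to an optimizer constructor. The optimizer, learning rate,
# and weight-decay value below are placeholders (Megatron itself uses apex's
# FusedAdam); the group built from layer norms and biases carries
# `weight_decay=0.0` and therefore overrides the optimizer-level default.
def _optimizer_from_param_groups_sketch(model):
    param_groups = get_params_for_weight_decay_optimization(model)
    return torch.optim.AdamW(param_groups, lr=1.5e-4, weight_decay=0.01)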
class BertModel(torch.nn.Module):
def __init__(self, args):
super(BertModel, self).__init__()
if args.pretrained_bert:
self.model = BertForPreTraining.from_pretrained(
args.tokenizer_model_type,
cache_dir=args.cache_dir,
fp32_layernorm=args.fp32_layernorm,
fp32_embedding=args.fp32_embedding,
layernorm_epsilon=args.layernorm_epsilon)
else:
if args.intermediate_size is None:
intermediate_size = 4 * args.hidden_size
else:
intermediate_size = args.intermediate_size
self.config = BertConfig(
args.tokenizer_num_tokens,
hidden_size=args.hidden_size,
num_hidden_layers=args.num_layers,
num_attention_heads=args.num_attention_heads,
intermediate_size=intermediate_size,
hidden_dropout_prob=args.hidden_dropout,
attention_probs_dropout_prob=args.attention_dropout,
max_position_embeddings=args.max_position_embeddings,
type_vocab_size=args.tokenizer_num_type_tokens,
fp32_layernorm=args.fp32_layernorm,
fp32_embedding=args.fp32_embedding,
fp32_tokentypes=args.fp32_tokentypes,
layernorm_epsilon=args.layernorm_epsilon,
deep_init=args.deep_init)
self.model = BertForPreTraining(self.config)
def forward(self, input_tokens, token_type_ids=None,
attention_mask=None, checkpoint_activations=False):
return self.model(
input_tokens, token_type_ids, attention_mask,
checkpoint_activations=checkpoint_activations)
def state_dict(self, destination=None, prefix='', keep_vars=False):
return self.model.state_dict(destination=destination, prefix=prefix,
keep_vars=keep_vars)
def load_state_dict(self, state_dict, strict=True):
return self.model.load_state_dict(state_dict, strict=strict)
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Megatron Module"""
import torch
class MegatronModule(torch.nn.Module):
"""Megatron specific extentions of torch Module."""
def __init__(self):
super(MegatronModule, self).__init__()
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
"""Use this function to override the state dict for
saving checkpoints."""
return self.state_dict(destination, prefix, keep_vars)
...@@ -46,7 +46,5 @@ from .random import checkpoint
from .random import get_cuda_rng_tracker
from .random import model_parallel_cuda_manual_seed
from .transformer import BertParallelSelfAttention from .utils import divide
from .transformer import BertParallelTransformerLayer from .utils import split_tensor_along_last_dim
from .transformer import GPT2ParallelTransformer
from .transformer import LayerNorm
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformer."""
import math
import torch
import torch.nn.init as init
from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm
from .initialize import get_model_parallel_world_size
from .layers import ColumnParallelLinear
from .layers import RowParallelLinear
from .mappings import gather_from_model_parallel_region
from .random import checkpoint
from .random import get_cuda_rng_tracker
from .utils import divide
from .utils import split_tensor_along_last_dim
class GPT2ParallelSelfAttention(torch.nn.Module):
"""Parallel self-attention layer for GPT2.
Self-attention layer takes input with size [b, s, h] where b is
the batch size, s is the sequence length, and h is the hidden size
and creates output of the same size.
Arguments:
hidden_size: total hidden size of the layer (h).
num_attention_heads: number of attention heads (n). Note that we
require n to be divisible by the number of GPUs
used to parallelize the model. Also, we
require hidden size to be divisible by n.
dropout_prob: dropout probability for the attention scores.
init_method: weight initialization.
output_layer_init_method: output layer initialization. If None, use
`init_method`.
We use the following notation:
h: hidden_size
n: num_attention_heads
p: number of partitions
np: n/p
hp: h/p
hn: h/n
b: batch size
s: sequence length
"""
def __init__(self, hidden_size, num_attention_heads,
attention_dropout_prob, output_dropout_prob,
init_method, output_layer_init_method=None):
super(GPT2ParallelSelfAttention, self).__init__()
# Set output layer initialization if not provided.
if output_layer_init_method is None:
output_layer_init_method = init_method
# Per attention head and per partition values.
world_size = get_model_parallel_world_size()
self.hidden_size_per_partition = divide(hidden_size, world_size)
self.hidden_size_per_attention_head = divide(hidden_size,
num_attention_heads)
self.num_attention_heads_per_partition = divide(num_attention_heads,
world_size)
# Strided linear layer.
self.query_key_value = ColumnParallelLinear(hidden_size, 3*hidden_size,
stride=3,
gather_output=False,
init_method=init_method)
# Dropout. Note that for a single iteration, this layer will generate
# different outputs for different numbers of parallel partitions, but
# on average it should not be partition dependent.
self.attention_dropout = torch.nn.Dropout(attention_dropout_prob)
# Output.
self.dense = RowParallelLinear(hidden_size,
hidden_size,
input_is_parallel=True,
init_method=output_layer_init_method)
self.output_dropout = torch.nn.Dropout(output_dropout_prob)
def _transpose_for_scores(self, tensor):
"""Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with
size [b, np, s, hn].
"""
new_tensor_shape = tensor.size()[:-1] + \
(self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head)
tensor = tensor.view(*new_tensor_shape)
return tensor.permute(0, 2, 1, 3)
def forward(self, hidden_states, ltor_mask, layer_past=None, get_present=False):
# hidden_states: [b, s, h]
# ltor_mask: [1, 1, s, s]
# Attention heads. [b, s, hp]
mixed_x_layer = self.query_key_value(hidden_states)
(mixed_query_layer,
mixed_key_layer,
mixed_value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
# Reshape and transpose [b, np, s, hn]
query_layer = self._transpose_for_scores(mixed_query_layer)
key_layer = self._transpose_for_scores(mixed_key_layer)
value_layer = self._transpose_for_scores(mixed_value_layer)
if layer_past is not None:
past_key, past_value = layer_past
key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=-2)
value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=-2)
present = (key_layer, value_layer)
# Raw attention scores. [b, np, s, s]
norm_factor = math.sqrt(math.sqrt(self.hidden_size_per_attention_head))
attention_scores = torch.matmul(query_layer/norm_factor,
key_layer.transpose(-1, -2)/norm_factor)
# Apply the left to right attention mask.
if get_present:
with torch.no_grad():
if layer_past is not None:
ltor_mask = ltor_mask[...,attention_scores.size(3)-1, :attention_scores.size(3)].unsqueeze(2)
else:
ltor_mask = ltor_mask[...,:attention_scores.size(3), :attention_scores.size(3)]
attention_scores = torch.mul(attention_scores, ltor_mask) - \
10000.0 * (1.0 - ltor_mask)
# Attention probabilities. [b, np, s, s]
attention_probs = torch.nn.Softmax(dim=-1)(attention_scores)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
with get_cuda_rng_tracker().fork():
attention_probs = self.attention_dropout(attention_probs)
# Context layer.
# [b, np, s, hn]
context_layer = torch.matmul(attention_probs, value_layer)
# [b, s, np, hn]
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + \
(self.hidden_size_per_partition,)
# [b, s, hp]
context_layer = context_layer.view(*new_context_layer_shape)
# Output. [b, s, h]
output = self.dense(context_layer)
output = self.output_dropout(output)
if get_present:
output = [output, present]
return output
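# Illustrative sketch (not part of the original module): dividing both the
# query and the key by sqrt(sqrt(hn)), as done above, is mathematically the
# same as the usual 1/sqrt(hn) attention scaling, but keeps the intermediate
# values smaller, which is friendlier to fp16. The sizes below are
# hypothetical.
def _attention_scaling_example():
    batch, heads, seq, head_size = 2, 4, 8, 16
    query = torch.randn(batch, heads, seq, head_size)
    key = torch.randn(batch, heads, seq, head_size)
    norm_factor = math.sqrt(math.sqrt(head_size))
    split_scaling = torch.matmul(query / norm_factor,
                                 key.transpose(-1, -2) / norm_factor)
    plain_scaling = torch.matmul(query,
                                 key.transpose(-1, -2)) / math.sqrt(head_size)
    assert torch.allclose(split_scaling, plain_scaling, atol=1e-5)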
@torch.jit.script
def gelu_impl(x):
"""OpenAI's gelu implementation."""
return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x *
(1.0 + 0.044715 * x * x)))
def gelu(x):
return gelu_impl(x)
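# Illustrative sketch (not part of the original module): the tanh-based
# approximation above closely tracks the exact erf-based GeLU; the tolerance
# below is an assumption chosen for demonstration only.
def _gelu_approximation_example():
    x = torch.linspace(-5.0, 5.0, steps=101)
    exact = x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
    assert torch.allclose(gelu(x), exact, atol=1e-3)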
class GPT2ParallelMLP(torch.nn.Module):
"""MLP for GPT2.
The MLP takes an input with hidden size h, projects it to a 4*h
hidden dimension, applies a gelu transformation, and projects the
state back to hidden size h. At the end, dropout is also
applied.
Arguments:
hidden_size: The hidden size of the self attention.
output_dropout_prob: dropout probability for the outputs
after self attention and final output.
init_method: initialization method used for the weights. Note
that all biases are initialized to zero and
layernorm weights are initialized to one.
output_layer_init_method: output layer initialization. If None,
use `init_method`.
"""
def __init__(self, hidden_size, output_dropout_prob, init_method,
output_layer_init_method=None):
super(GPT2ParallelMLP, self).__init__()
# Set output layer initialization if not provided.
if output_layer_init_method is None:
output_layer_init_method = init_method
# Project to 4h.
self.dense_h_to_4h = ColumnParallelLinear(hidden_size, 4*hidden_size,
gather_output=False,
init_method=init_method)
# Project back to h.
self.dense_4h_to_h = RowParallelLinear(
4*hidden_size,
hidden_size,
input_is_parallel=True,
init_method=output_layer_init_method)
self.dropout = torch.nn.Dropout(output_dropout_prob)
def forward(self, hidden_states):
# [b, s, 4hp]
intermediate_parallel = self.dense_h_to_4h(hidden_states)
intermediate_parallel = gelu(intermediate_parallel)
# [b, s, h]
output = self.dense_4h_to_h(intermediate_parallel)
output = self.dropout(output)
return output
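# Illustrative sketch (not part of the original module): splitting the h->4h
# weight by columns and the 4h->h weight by rows lets each partition apply
# gelu to its own slice independently; summing the per-partition outputs
# reproduces the unsplit MLP, which is why no communication is needed between
# the two projections. Two partitions are simulated on one device with
# hypothetical sizes.
def _column_row_split_example():
    h = 8
    x = torch.randn(2, 3, h)
    w1 = torch.randn(h, 4 * h)   # column-parallel weight
    w2 = torch.randn(4 * h, h)   # row-parallel weight
    full = gelu(x @ w1) @ w2
    partial = [gelu(x @ w1[:, i::2]) @ w2[i::2, :] for i in range(2)]
    assert torch.allclose(full, partial[0] + partial[1], atol=1e-3)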
class GPT2ParallelTransformerLayer(torch.nn.Module):
"""A single layer transformer for GPT2.
We use the following notation:
h: hidden size
n: number of attention heads
b: batch size
s: sequence length
Transformer layer takes input with size [b, s, h] and returns an
output of the same size.
Arguments:
hidden_size: The hidden size of the self attention.
num_attention_heads: number of attention heads in the self
attention.
attention_dropout_prob: dropout probability of the attention
score in self attention.
output_dropout_prob: dropout probability for the outputs
after self attention and final output.
layernorm_epsilon: epsilon used in layernorm to avoid
division by zero.
init_method: initialization method used for the weights. Note
that all biases are initialized to zero and
layernorm weights are initialized to one.
output_layer_init_method: output layers (attention output and
mlp output) initialization. If None,
use `init_method`.
"""
def __init__(self,
hidden_size,
num_attention_heads,
attention_dropout_prob,
output_dropout_prob,
layernorm_epsilon,
init_method,
output_layer_init_method=None):
super(GPT2ParallelTransformerLayer, self).__init__()
# Set output layer initialization if not provided.
if output_layer_init_method is None:
output_layer_init_method = init_method
# Layernorm on the input data.
self.input_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
# Self attention.
self.attention = GPT2ParallelSelfAttention(
hidden_size,
num_attention_heads,
attention_dropout_prob,
output_dropout_prob,
init_method,
output_layer_init_method=output_layer_init_method)
# Layernorm on the attention output.
self.post_attention_layernorm = LayerNorm(hidden_size,
eps=layernorm_epsilon)
# MLP
self.mlp = GPT2ParallelMLP(
hidden_size,
output_dropout_prob,
init_method,
output_layer_init_method=output_layer_init_method)
def forward(self, hidden_states, ltor_mask, layer_past=None, get_present=False):
# hidden_states: [b, s, h]
# ltor_mask: [1, 1, s, s]
# Layer norm at the begining of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states)
# Self attention.
attention_output = self.attention(layernorm_output, ltor_mask, layer_past=layer_past, get_present=get_present)
if get_present:
attention_output, presents = attention_output
# Residual connection.
layernorm_input = hidden_states + attention_output
# Layer norm post the self attention.
layernorm_output = self.post_attention_layernorm(layernorm_input)
# MLP.
mlp_output = self.mlp(layernorm_output)
# Second residual connection.
output = layernorm_input + mlp_output
if get_present:
output = [output, presents]
return output
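# Illustrative sketch (not part of the original module): the layer above uses
# the pre-LayerNorm (GPT-2 style) ordering, i.e. the residual stream is
# updated as x = x + Attn(LN(x)) and then x = x + MLP(LN(x)). A standalone
# version with plain torch modules standing in for the parallel attention and
# MLP (hypothetical sizes):
def _prenorm_ordering_example():
    hidden = 32
    x = torch.randn(2, 8, hidden)
    ln1, ln2 = torch.nn.LayerNorm(hidden), torch.nn.LayerNorm(hidden)
    attention_proxy = torch.nn.Linear(hidden, hidden)
    mlp_proxy = torch.nn.Linear(hidden, hidden)
    x = x + attention_proxy(ln1(x))   # residual around the attention block
    x = x + mlp_proxy(ln2(x))         # residual around the MLP block
    return x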
def unscaled_init_method(sigma):
"""Init method based on N(0, sigma)."""
def init_(tensor):
return torch.nn.init.normal_(tensor, mean=0.0, std=sigma)
return init_
def scaled_init_method(sigma, num_layers):
"""Init method based on N(0, sigma/sqrt(2*num_layers)."""
std = sigma / math.sqrt(2.0 * num_layers)
def init_(tensor):
return torch.nn.init.normal_(tensor, mean=0.0, std=std)
return init_
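# Illustrative sketch (not part of the original module): with sigma = 0.02
# and 24 layers (hypothetical values), output-layer weights are drawn from
# N(0, 0.02 / sqrt(48)), so the sample standard deviation should land close
# to the expected one.
def _scaled_init_example():
    init_fn = scaled_init_method(sigma=0.02, num_layers=24)
    weight = torch.empty(64, 64)
    init_fn(weight)
    expected_std = 0.02 / math.sqrt(2.0 * 24)
    return weight.std().item(), expected_std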
class GPT2ParallelTransformer(torch.nn.Module):
"""GPT-2 transformer.
This module takes input from the embedding layer and its output can
be used directly by a logit layer. It consists of L (num-layers)
blocks of:
layer norm
self attention
residual connection
layer norm
mlp
residual connection
followed by a final layer norm.
Arguments:
num_layers: Number of transformer layers.
hidden_size: The hidden size of the self attention.
num_attention_heads: number of attention heads in the self
attention.
attention_dropout_prob: dropout probability of the attention
score in self attention.
output_dropout_prob: dropout probability for the outputs
after self attention and final output.
checkpoint_activations: if True, checkpoint activations.
checkpoint_num_layers: number of layers to checkpoint. This
is basically the chunk size in checkpointing.
layernorm_epsilon: epsilon used in layernorm to avoid
division by zero.
init_method_std: standard deviation of the init method which has
the form N(0, std).
use_scaled_init_for_output_weights: If True, use 1/sqrt(2*num_layers)
scaling for the output weights (
output of self attention and mlp).
"""
def __init__(self,
num_layers,
hidden_size,
num_attention_heads,
attention_dropout_prob,
output_dropout_prob,
checkpoint_activations,
checkpoint_num_layers=1,
layernorm_epsilon=1.0e-5,
init_method_std=0.02,
use_scaled_init_for_output_weights=True):
super(GPT2ParallelTransformer, self).__init__()
# Store activation checkpointing flag.
self.checkpoint_activations = checkpoint_activations
self.checkpoint_num_layers = checkpoint_num_layers
output_layer_init_method = None
if use_scaled_init_for_output_weights:
output_layer_init_method = scaled_init_method(init_method_std,
num_layers)
def get_layer():
return GPT2ParallelTransformerLayer(
hidden_size,
num_attention_heads,
attention_dropout_prob,
output_dropout_prob,
layernorm_epsilon,
unscaled_init_method(init_method_std),
output_layer_init_method=output_layer_init_method)
# Transformer layers.
self.layers = torch.nn.ModuleList(
[get_layer() for _ in range(num_layers)])
# Final layer norm before output.
self.final_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
def forward(self, hidden_states, attention_mask, layer_past=None, get_present=False):
def custom(start, end):
def custom_forward(*inputs):
layers_ = self.layers[start:end]
x_ = inputs[0]
for layer in layers_:
x_ = layer(x_, inputs[1])
return x_
return custom_forward
if self.checkpoint_activations and not get_present:
l = 0
num_layers = len(self.layers)
chunk_length = self.checkpoint_num_layers
while l < num_layers:
hidden_states = checkpoint(custom(l, l+chunk_length),
hidden_states, attention_mask)
l += chunk_length
else:
presents = []
for i, layer in enumerate(self.layers):
past = None
if layer_past is not None:
past = layer_past[i]
hidden_states = layer(hidden_states, attention_mask, layer_past=past, get_present=get_present)
if get_present:
hidden_states, present = hidden_states
presents.append(present)
# Final layer norm.
output = self.final_layernorm(hidden_states)
if get_present:
output = [output, presents]
return output
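# Illustrative sketch (not part of the original module): activation
# checkpointing runs the layers in chunks of `checkpoint_num_layers`,
# re-computing each chunk's activations during the backward pass instead of
# storing them. A standalone analogue using torch.utils.checkpoint and plain
# Linear layers in place of the mpu checkpoint and parallel layers
# (hypothetical sizes):
def _chunked_checkpoint_example():
    import torch.utils.checkpoint as cp
    layers = torch.nn.ModuleList([torch.nn.Linear(16, 16) for _ in range(8)])
    def run_chunk(start, end):
        def custom_forward(x):
            for layer in layers[start:end]:
                x = layer(x)
            return x
        return custom_forward
    x = torch.randn(4, 16, requires_grad=True)
    chunk_length = 2
    for start in range(0, len(layers), chunk_length):
        x = cp.checkpoint(run_chunk(start, start + chunk_length), x)
    return x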
class BertParallelSelfAttention(torch.nn.Module):
"""Parallel self-attention layer for BERT.
Self-attention layer takes input with size [b, s, h] where b is
the batch size, s is the sequence length, and h is the hidden size
and creates output of the same size.
Arguments:
hidden_size: total hidden size of the layer (h).
num_attention_heads: number of attention heads (n). Note that we
require n to be divisible by the number of GPUs
used to parallelize the model. Also, we
require the hidden size to be divisible by n.
dropout_prob: dropout probability for the attention scores.
output_parallel: If true, no all-gather is done on the output and
the output values will be per partition.
We use the following notation:
h: hidden_size
n: num_attention_heads
p: number of partitions
np: n/p
hp: h/p
hn: h/n
b: batch size
s: sequence length
"""
def __init__(self, hidden_size, num_attention_heads,
dropout_prob, output_parallel=False,
init_method=init.xavier_normal_):
super(BertParallelSelfAttention, self).__init__()
# Input configuration.
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads
self.dropout_prob = dropout_prob
self.output_parallel = output_parallel
# Per attention head and per partition values.
world_size = get_model_parallel_world_size()
self.hidden_size_per_partition = divide(hidden_size, world_size)
self.hidden_size_per_attention_head = divide(hidden_size,
num_attention_heads)
self.num_attention_heads_per_partition = divide(num_attention_heads,
world_size)
# Strided linear layer.
self.query_key_value = ColumnParallelLinear(hidden_size, 3*hidden_size,
stride=3,
gather_output=False,
init_method=init_method)
# Dropout. Note that for a single iteration, this layer will generate
# different outputs for different numbers of parallel partitions, but
# on average it should not be partition dependent.
self.dropout = torch.nn.Dropout(dropout_prob)
def _transpose_for_scores(self, tensor):
"""Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with
size [b, np, s, hn].
"""
new_tensor_shape = tensor.size()[:-1] + \
(self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head)
tensor = tensor.view(*new_tensor_shape)
return tensor.permute(0, 2, 1, 3)
def forward(self, hidden_states, attention_mask):
# Attention heads. [b, s, hp]
mixed_x_layer = self.query_key_value(hidden_states)
(mixed_query_layer,
mixed_key_layer,
mixed_value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
# Reshape and transpose [b, np, s, hn]
query_layer = self._transpose_for_scores(mixed_query_layer)
key_layer = self._transpose_for_scores(mixed_key_layer)
value_layer = self._transpose_for_scores(mixed_value_layer)
# Raw attention scores. [b, np, s, s]
norm_factor = math.sqrt(math.sqrt(self.hidden_size_per_attention_head))
attention_scores = torch.matmul(query_layer/norm_factor,
key_layer.transpose(-1, -2)/norm_factor)
# Apply the attention mask.
attention_scores += attention_mask
# Attention probabilities. [b, np, s, s]
attention_probs = torch.nn.Softmax(dim=-1)(attention_scores)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
with get_cuda_rng_tracker().fork():
attention_probs = self.dropout(attention_probs)
# Context layer.
# [b, np, s, hn]
context_layer = torch.matmul(attention_probs, value_layer)
# [b, s, np, hn]
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + \
(self.hidden_size_per_partition,)
# [b, s, hp]
context_layer = context_layer.view(*new_context_layer_shape)
# Output. [b, s, h]
if self.output_parallel:
output = context_layer
else:
output = gather_from_model_parallel_region(context_layer)
return output
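# Illustrative sketch (not part of the original module): unlike the GPT-2
# attention above, which multiplies the scores by a 0/1 left-to-right mask
# and subtracts 10000 from the masked positions, this layer expects an
# additive mask (zero for visible tokens, a large negative value for
# padding). Hypothetical sizes:
def _additive_mask_example():
    scores = torch.randn(1, 1, 4, 4)
    pad_mask = torch.tensor([1.0, 1.0, 1.0, 0.0])    # last token is padding
    additive_mask = (1.0 - pad_mask) * -10000.0
    probs = torch.nn.Softmax(dim=-1)(scores + additive_mask.view(1, 1, 1, 4))
    return probs   # the padding column receives ~zero attention probability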
class BertParallelTransformerOutput(torch.nn.Module):
"""The output layer used after self attention and intermediate
parts of transformer layer."""
def __init__(self, input_size, output_size, dropout_prob,
layernorm_epsilon=1.0e-12, input_is_parallel=False,
init_method=init.xavier_normal_):
super(BertParallelTransformerOutput, self).__init__()
# Components.
self.dense = RowParallelLinear(input_size,
output_size,
input_is_parallel=input_is_parallel,
init_method=init_method)
self.dropout = torch.nn.Dropout(dropout_prob)
self.layernorm = LayerNorm(output_size, eps=layernorm_epsilon)
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
layernorm_input = hidden_states + input_tensor
hidden_states = self.layernorm(layernorm_input)
return hidden_states
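# Illustrative sketch (not part of the original module): this output block is
# the post-LayerNorm (BERT style) pattern, LN(residual + Dropout(Dense(h))),
# in contrast to the pre-LayerNorm ordering of the GPT-2 layers above. A
# standalone version with plain torch (hypothetical sizes):
def _postnorm_output_example():
    hidden = 32
    residual = torch.randn(2, 8, hidden)
    sub_layer_output = torch.randn(2, 8, hidden)
    dense = torch.nn.Linear(hidden, hidden)
    dropout = torch.nn.Dropout(0.1)
    layernorm = torch.nn.LayerNorm(hidden)
    return layernorm(residual + dropout(dense(sub_layer_output)))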
class BertParallelTransformerLayer(torch.nn.Module):
"""A single layer transformer for Bert.
We use the following notation:
h: hidden size
n: number of attention heads
b: batch size
s: sequence length
Transformer layer takes input with size [b, s, h] and returns an
output of the same size.
Arguments:
hidden_size: The hidden size of the self attention.
intermediate_size: size of the intermediate state after
self attention. In both BERT and GPT
this is set to be 4 times the hidden
size.
num_attention_heads: number of attention heads in the self
attention.
attention_dropout_prob: dropout probability of the attention
score in self attention.
output_dropout_prob: dropout probability for the outputs
after self attention and final output.
intermediate_activation_fn: activation function for output
of intermediate.
layernorm_epsilon: epsilon used in layernorm to avoid
division by zero.
init_method: initialization method used for the weights. Note
that all biases are initialized to zero and
layernorm weights are initialized to one.
"""
def __init__(self,
hidden_size,
intermediate_size,
num_attention_heads,
attention_dropout_prob,
output_dropout_prob,
intermediate_activation_fn,
layernorm_epsilon,
init_method=init.xavier_normal_):
super(BertParallelTransformerLayer, self).__init__()
# Self attention.
self.attention = BertParallelSelfAttention(hidden_size,
num_attention_heads,
attention_dropout_prob,
output_parallel=True,
init_method=init_method)
# Self attention output.
self.self_output = BertParallelTransformerOutput(
hidden_size, hidden_size, output_dropout_prob,
layernorm_epsilon=layernorm_epsilon,
input_is_parallel=True,
init_method=init_method)
# Intermediate.
self.intermediate = ColumnParallelLinear(hidden_size, intermediate_size,
gather_output=False,
init_method=init_method)
self.intermediate_activation_fn = intermediate_activation_fn
# Output.
self.output = BertParallelTransformerOutput(
intermediate_size, hidden_size, output_dropout_prob,
layernorm_epsilon=layernorm_epsilon,
input_is_parallel=True,
init_method=init_method)
def forward(self, hidden_states, attention_mask):
# [b, s, hp]
attention_output_parallel = self.attention(hidden_states,
attention_mask)
# [b, s, h]
attention_self_output = self.self_output(attention_output_parallel,
hidden_states)
# [b, s, ip]
intermediate_output_parallel = self.intermediate(attention_self_output)
intermediate_output_parallel = self.intermediate_activation_fn(
intermediate_output_parallel)
# [b, s, h]
layer_output = self.output(intermediate_output_parallel,
attention_self_output)
return layer_output
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pretrain utilities"""
from datetime import datetime
import math
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from apex.optimizers import FusedAdam as Adam
from arguments import get_args
from megatron import mpu
from megatron.fp16 import FP16_Module
from megatron.fp16 import FP16_Optimizer
from megatron.learning_rates import AnnealingLR
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import get_params_for_weight_decay_optimization
from megatron.utils import check_adlr_autoresume_termination
from megatron.utils import enable_adlr_autoresume
from megatron.utils import get_tensorboard_writer
from megatron.utils import initialize_distributed
from megatron.utils import load_checkpoint
from megatron.utils import print_args
from megatron.utils import print_rank_0
from megatron.utils import report_memory
from megatron.utils import save_checkpoint
from megatron.utils import set_random_seed
from megatron.utils import Timers
def run(top_level_message, train_val_test_data_provider,
model_provider, forward_step_func):
"""Main training program.
This function will run the following in the order provided:
1) get input arguments.
2) initialize distributed and seeds.
3) call train_val_test_data_provider to get train/val/test datasets.
4) setup model, optimizer and lr schedule using the model_provider.
5) train the model using the forward_step_func.
Arguments:
top_level_message: a message to print at the top of the run.
train_val_test_data_provider: a function that takes `args` as input
and returns `train, val, test` dataloaders. Note that args are
passed and can be modified in case we need to use some parameters
later. For example, we can set vocab size using
args.vocab_size = ...
and later use this value in `model_provider`.
model_provider: a function that takes `args` and returns a vanilla
version of the model. By vanilla we mean a simple model on cpu
with no fp16 or ddp.
forward_step_func: a function that takes a `data iterator`, `model`,
`args`, and `timers` and returns a `loss` scalar together with a
dictionary whose key/value pairs are the quantities we would like to
monitor during training, for example `lm-loss: value`. We also require
that this function adds the `batch generator` timer to the timers class.
"""
# Timer.
timers = Timers()
# Arguments.
args = get_args()
# Tensorboard writer
writer = get_tensorboard_writer(args)
# Pytorch distributed.
initialize_distributed(args)
if torch.distributed.get_rank() == 0:
print(top_level_message, flush=True)
print_args(args, writer)
# Autoresume.
torch.distributed.barrier()
if args.adlr_autoresume:
enable_adlr_autoresume(args)
# Random seeds for reproducibility.
set_random_seed(args.seed)
# Data stuff.
train_data, val_data, test_data = train_val_test_data_provider(args)
# Model, optimizer, and learning rate.
model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider,
args)
# Train, validation, and test data.
train_data_iterator, val_data_iterator, \
test_data_iterator = get_train_val_test_data_iterators(train_data,
val_data,
test_data,
args)
iteration = 0
if args.train_iters > 0:
if args.do_train:
iteration, _ = train(forward_step_func, model,
optimizer, lr_scheduler,
train_data_iterator, val_data_iterator,
timers, args, writer)
if args.do_valid:
prefix = 'the end of training for val data'
evaluate_and_print_results(prefix, forward_step_func,
val_data_iterator, model,
args, writer, iteration,
timers, False)
if args.save and iteration != 0:
save_checkpoint(iteration, model, optimizer,
lr_scheduler, args)
if args.do_test:
# Run on test data.
prefix = 'the end of training for test data'
evaluate_and_print_results(prefix, forward_step_func,
test_data_iterator, model,
args, None, 0, timers, True)
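# Illustrative sketch (not part of the original module): the callback
# contract that `run` expects, shown with hypothetical toy providers. A real
# pretraining script (e.g. pretrain_bert.py below) supplies Megatron models
# and dataloaders instead.
def _toy_data_provider(args):
    args.vocab_size = 128            # providers may record values on args
    return None, None, None          # train / val / test dataloaders
def _toy_model_provider(args):
    return torch.nn.Linear(8, 8)     # plain cpu model; fp16/DDP wrapping happens later
def _toy_forward_step(data_iterator, model, args, timers):
    timers('batch generator').start()
    batch = torch.randn(4, 8)        # a real step would pull from data_iterator
    timers('batch generator').stop()
    loss = model(batch).pow(2).mean()
    return loss, {'lm-loss': loss.clone().detach()}
# run('Toy pretraining', _toy_data_provider, _toy_model_provider, _toy_forward_step)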
def get_model(model_provider_func, args):
"""Build the model."""
# Build model on cpu.
model = model_provider_func(args)
# Print number of parameters.
if mpu.get_data_parallel_rank() == 0:
print(' > number of parameters on model parallel rank {}: {}'.format(
mpu.get_model_parallel_rank(),
sum([p.nelement() for p in model.parameters()])), flush=True)
# GPU allocation.
model.cuda(torch.cuda.current_device())
# Fp16 conversion.
if args.fp16:
model = FP16_Module(model)
# Wrap model for distributed training.
if args.DDP_impl == 'torch':
i = torch.cuda.current_device()
args.DDP_type = torchDDP
model = args.DDP_type(model, device_ids=[i], output_device=i,
process_group=mpu.get_data_parallel_group())
return model
if args.DDP_impl == 'local':
args.DDP_type = LocalDDP
model = args.DDP_type(model)
return model
print_rank_0('Unknown DDP implementation specified: {}. '
'Exiting.'.format(args.DDP_impl))
exit()
return model
def get_optimizer(model, args):
"""Set up the optimizer."""
# Build parameter groups (weight decay and non-decay).
while isinstance(model, (args.DDP_type, FP16_Module)):
model = model.module
param_groups = get_params_for_weight_decay_optimization(model)
# Add model parallel attribute if it is not set.
for param_group in param_groups:
for param in param_group['params']:
if not hasattr(param, 'model_parallel'):
param.model_parallel = False
# Use Adam.
optimizer = Adam(param_groups,
lr=args.lr, weight_decay=args.weight_decay)
# Wrap into fp16 optimizer.
if args.fp16:
optimizer = FP16_Optimizer(optimizer,
static_loss_scale=args.loss_scale,
dynamic_loss_scale=args.dynamic_loss_scale,
dynamic_loss_args={
'scale_window': args.loss_scale_window,
'min_scale':args.min_scale,
'delayed_shift': args.hysteresis})
return optimizer
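# Illustrative sketch (not part of the original module): the weight-decay
# grouping keeps biases and LayerNorm parameters out of weight decay (the
# real grouping lives in megatron.model.get_params_for_weight_decay_optimization).
# A standalone approximation with plain torch and a vanilla Adam in place of
# apex FusedAdam; the name-based check below is a simplification.
def _weight_decay_groups_example(model, weight_decay=0.01, lr=1.5e-4):
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if name.endswith('bias') or 'layernorm' in name.lower():
            no_decay.append(param)
        else:
            decay.append(param)
    return torch.optim.Adam([{'params': decay, 'weight_decay': weight_decay},
                             {'params': no_decay, 'weight_decay': 0.0}], lr=lr)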
def get_learning_rate_scheduler(optimizer, args):
"""Build the learning rate scheduler."""
# Add linear learning rate scheduler.
if args.lr_decay_iters is not None:
num_iters = args.lr_decay_iters
else:
num_iters = args.train_iters
num_iters = max(1, num_iters)
init_step = -1
warmup_iter = args.warmup * num_iters
lr_scheduler = AnnealingLR(
optimizer,
start_lr=args.lr,
warmup_iter=warmup_iter,
num_iters=num_iters,
decay_style=args.lr_decay_style,
last_iter=init_step,
min_lr=args.min_lr,
use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler,
override_lr_scheduler=args.override_lr_scheduler)
return lr_scheduler
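# Illustrative sketch (not part of the original module): `args.warmup` is a
# fraction of the decay iterations, so a warmup of 0.01 over 320000 decay
# iterations (hypothetical values) gives 3200 linear warmup steps before the
# chosen decay style takes over.
def _warmup_iterations_example(warmup=0.01, lr_decay_iters=320000):
    return int(warmup * lr_decay_iters)   # 3200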
def setup_model_and_optimizer(model_provider_func, args):
"""Setup model and optimizer."""
model = get_model(model_provider_func, args)
optimizer = get_optimizer(model, args)
lr_scheduler = get_learning_rate_scheduler(optimizer, args)
if args.load is not None:
args.iteration = load_checkpoint(model, optimizer, lr_scheduler, args)
else:
args.iteration = 0
return model, optimizer, lr_scheduler
def backward_step(optimizer, model, loss, args, timers):
"""Backward step."""
# Backward pass.
optimizer.zero_grad()
if args.fp16:
optimizer.backward(loss, update_master_grads=False)
else:
loss.backward()
# All-reduce if needed.
if args.DDP_impl == 'local':
timers('allreduce').start()
model.allreduce_params(reduce_after=False,
fp32_allreduce=args.fp32_allreduce)
timers('allreduce').stop()
# Update master gradients.
if args.fp16:
optimizer.update_master_grads()
# Clipping gradients helps prevent exploding gradients.
if args.clip_grad > 0:
if not args.fp16:
mpu.clip_grad_norm(model.parameters(), args.clip_grad)
else:
optimizer.clip_master_grads(args.clip_grad)
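# Illustrative sketch (not part of the original module): in the fp32 branch
# the clipping above is the usual global-norm clip (mpu.clip_grad_norm is its
# model-parallel-aware counterpart). A standalone equivalent with plain torch
# and a hypothetical model and threshold:
def _grad_clip_example():
    model = torch.nn.Linear(8, 8)
    model(torch.randn(4, 8)).sum().backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)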
def train_step(forward_step_func, data_iterator, model, optimizer, lr_scheduler,
args, timers):
"""Single training step."""
# Forward model for one step.
timers('forward').start()
loss, loss_reduced = forward_step_func(data_iterator, model, args, timers)
timers('forward').stop()
# Calculate gradients, reduce across processes, and clip.
timers('backward').start()
backward_step(optimizer, model, loss, args, timers)
timers('backward').stop()
# Update parameters.
timers('optimizer').start()
optimizer.step()
timers('optimizer').stop()
# Update learning rate.
skipped_iter = 0
if not (args.fp16 and optimizer.overflow):
lr_scheduler.step()
else:
skipped_iter = 1
return loss_reduced, skipped_iter
def train(forward_step_func, model, optimizer, lr_scheduler,
train_data_iterator, val_data_iterator, timers, args, writer):
"""Train the model function."""
# Turn on training mode which enables dropout.
model.train()
# Tracking loss.
total_loss_dict = {}
# Iterations.
iteration = args.iteration
skipped_iters = 0
timers('interval time').start()
report_memory_flag = True
while iteration < args.train_iters:
loss_dict, skipped_iter = train_step(forward_step_func,
train_data_iterator,
model,
optimizer,
lr_scheduler,
args, timers)
skipped_iters += skipped_iter
iteration += 1
# Update losses.
for key in loss_dict:
total_loss_dict[key] = total_loss_dict.get(key, 0.) + loss_dict[key]
# Logging.
if args.DDP_impl == 'torch':
timers_to_log = ['forward', 'backward', 'optimizer',
'batch generator']
else:
timers_to_log = ['forward', 'backward', 'allreduce', 'optimizer',
'batch generator']
learning_rate = optimizer.param_groups[0]['lr']
if writer and torch.distributed.get_rank() == 0:
writer.add_scalar('learning_rate', learning_rate, iteration)
for key in total_loss_dict:
writer.add_scalar(key, total_loss_dict[key], iteration)
if args.fp16:
writer.add_scalar('loss_scale', optimizer.loss_scale, iteration)
normalizer = iteration % args.log_interval
if normalizer == 0:
normalizer = args.log_interval
timers.write(timers_to_log, writer, iteration,
normalizer=normalizer)
if iteration % args.log_interval == 0:
elapsed_time = timers('interval time').elapsed()
if writer and torch.distributed.get_rank() == 0:
writer.add_scalar('iteration_time',
elapsed_time / args.log_interval, iteration)
log_string = ' iteration {:8d}/{:8d} |'.format(iteration,
args.train_iters)
log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
elapsed_time * 1000.0 / args.log_interval)
log_string += ' learning rate: {:.3E} |'.format(learning_rate)
for key in total_loss_dict:
avg = total_loss_dict[key].item() / args.log_interval
log_string += ' {}: {:.6E} |'.format(key, avg)
total_loss_dict[key] = 0.0
if args.fp16:
log_string += ' loss scale: {:.1f} |'.format(
optimizer.loss_scale)
print_rank_0(log_string)
if report_memory_flag:
report_memory('after {} iterations'.format(iteration))
report_memory_flag = False
timers.log(timers_to_log, normalizer=args.log_interval)
# Autoresume
if (iteration % args.adlr_autoresume_interval == 0) and \
args.adlr_autoresume:
check_adlr_autoresume_termination(iteration, model, optimizer,
lr_scheduler, args)
# Checkpointing
if args.save and args.save_interval and \
iteration % args.save_interval == 0:
save_checkpoint(iteration, model, optimizer, lr_scheduler, args)
# Evaluation
if args.eval_interval and iteration % args.eval_interval == 0 and \
args.do_valid:
prefix = 'iteration {}'.format(iteration)
evaluate_and_print_results(prefix, forward_step_func,
val_data_iterator, model, args,
writer, iteration, timers, False)
if args.exit_interval and iteration % args.exit_interval == 0:
torch.distributed.barrier()
time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
rank = torch.distributed.get_rank()
print('rank: {} | time: {} | exiting the program at iteration {}'.
format(rank, time_str, iteration), flush=True)
exit()
return iteration, skipped_iters
def evaluate(forward_step_func, data_iterator, model,
args, timers, verbose=False):
"""Evaluation."""
# Turn on evaluation mode which disables dropout.
model.eval()
total_loss_dict = {}
with torch.no_grad():
iteration = 0
while iteration < args.eval_iters:
iteration += 1
if verbose and iteration % args.log_interval == 0:
print_rank_0('Evaluating iter {}/{}'.format(iteration,
args.eval_iters))
# Forward evaluation.
_, loss_dict = forward_step_func(data_iterator, model,
args, timers)
# Reduce across processes.
for key in loss_dict:
total_loss_dict[key] = total_loss_dict.get(key, 0.) + \
loss_dict[key]
# Move model back to the train mode.
model.train()
for key in total_loss_dict:
total_loss_dict[key] /= args.eval_iters
return total_loss_dict
def evaluate_and_print_results(prefix, forward_step_func,
data_iterator, model,
args, writer, iteration,
timers, verbose=False):
"""Helper function to evaluate and dump results on screen."""
total_loss_dict = evaluate(forward_step_func, data_iterator, model,
args, timers, verbose)
string = ' validation loss at {} | '.format(prefix)
for key in total_loss_dict:
string += '{} value: {:.6E} | '.format(key, total_loss_dict[key].item())
ppl = math.exp(min(20, total_loss_dict[key].item()))
string += '{} PPL: {:.6E} | '.format(key, ppl)
if writer and torch.distributed.get_rank() == 0:
writer.add_scalar('{} value'.format(key),
total_loss_dict[key].item(),
iteration)
writer.add_scalar('{} ppl'.format(key), ppl, iteration)
length = len(string) + 1
print_rank_0('-' * length)
print_rank_0(string)
print_rank_0('-' * length)
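# Illustrative sketch (not part of the original module): the reported
# perplexity is exp(loss), with the loss capped at 20 above purely to avoid
# overflow in the printout; e.g. an lm loss of about 3.0 corresponds to a
# perplexity of about 20.
def _perplexity_example(lm_loss=2.9957):
    return math.exp(min(20, lm_loss))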
def get_train_val_test_data_iterators(train_data, val_data, test_data, args):
"""Build train/validation/test iterators"""
# If resume is on, shift the start iterations.
if args.resume_dataloader:
if train_data is not None:
train_data.batch_sampler.start_iter = args.iteration % \
len(train_data)
print_rank_0('setting training data start iteration to {}'.
format(train_data.batch_sampler.start_iter))
if val_data is not None:
start_iter_val = (args.iteration // args.eval_interval) * \
args.eval_iters
val_data.batch_sampler.start_iter = start_iter_val % \
len(val_data)
print_rank_0('setting validation data start iteration to {}'.
format(val_data.batch_sampler.start_iter))
if train_data is not None:
train_data_iterator = iter(train_data)
else:
train_data_iterator = None
if val_data is not None:
val_data_iterator = iter(val_data)
else:
val_data_iterator = None
if test_data is not None:
test_data_iterator = iter(test_data)
else:
test_data_iterator = None
return train_data_iterator, val_data_iterator, test_data_iterator
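# Illustrative sketch (not part of the original module): when resuming, the
# training sampler is fast-forwarded modulo the dataloader length, and the
# validation sampler by one eval pass per completed eval interval. The
# numbers below are hypothetical.
def _resume_start_iter_example(iteration=100500, train_len=10000,
                               eval_interval=1000, eval_iters=100,
                               val_len=5000):
    train_start = iteration % train_len                                # 500
    val_start = ((iteration // eval_interval) * eval_iters) % val_len  # 0
    return train_start, val_start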
...@@ -20,12 +20,29 @@ import random ...@@ -20,12 +20,29 @@ import random
import time import time
import numpy as np import numpy as np
import torch import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.fp16 import FP16_Optimizer from apex.optimizers import FusedAdam as Adam
from megatron import mpu from megatron import mpu
from megatron import model from megatron.fp16 import FP16_Module
from megatron.fp16 import FP16_Optimizer
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import get_params_for_weight_decay_optimization
def get_tensorboard_writer(args):
writer = None
if args.tensorboard_dir and args.rank == 0:
try:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter(log_dir=args.tensorboard_dir)
except ModuleNotFoundError:
print_rank_0('WARNING: TensorBoard writing requested but is not '
'available (are you using PyTorch 1.1.0 or later?), '
'no TensorBoard logs will be written.')
writer = None
return writer
def print_rank_0(message): def print_rank_0(message):
...@@ -39,18 +56,18 @@ def print_rank_0(message): ...@@ -39,18 +56,18 @@ def print_rank_0(message):
def enable_adlr_autoresume(args): def enable_adlr_autoresume(args):
print_rank_0('enabling autoresume ...') print_rank_0('enabling autoresume ...')
import sys import sys
sys.path.append(os.environ.get('SUBMIT_SCRIPTS','.')) sys.path.append(os.environ.get('SUBMIT_SCRIPTS', '.'))
try: try:
from userlib.auto_resume import AutoResume from userlib.auto_resume import AutoResume
except: except:
print_rank_0('ADLR autoresume is not available, exiting ...') print_rank_0('ADLR autoresume is not available, exiting ...')
exit(0) exit()
args.AutoResume = AutoResume args.AutoResume = AutoResume
args.AutoResume.init() args.AutoResume.init()
def check_adlr_autoresume_termination(iteration, model, optimizer, def check_adlr_autoresume_termination(iteration, model, optimizer,
lr_scheduler, args): lr_scheduler, args):
# Add barrier to ensure consistency. # Add barrier to ensure consistency.
torch.distributed.barrier() torch.distributed.barrier()
if args.AutoResume.termination_requested(): if args.AutoResume.termination_requested():
...@@ -74,6 +91,7 @@ def print_args(args, writer=None): ...@@ -74,6 +91,7 @@ def print_args(args, writer=None):
if writer: if writer:
writer.add_text(arg, str(getattr(args, arg))) writer.add_text(arg, str(getattr(args, arg)))
def print_params_min_max_norm(optimizer, iteration): def print_params_min_max_norm(optimizer, iteration):
"""Print min, max, and norm of all parameters.""" """Print min, max, and norm of all parameters."""
index = 0 index = 0
...@@ -220,24 +238,6 @@ def initialize_distributed(args): ...@@ -220,24 +238,6 @@ def initialize_distributed(args):
mpu.initialize_model_parallel(args.model_parallel_size) mpu.initialize_model_parallel(args.model_parallel_size)
def wrap_model_for_distributed_training(model, args):
"""Wrap model for distributed training."""
if args.DDP_impl == 'torch':
i = torch.cuda.current_device()
args.DDP_type = torchDDP
model = args.DDP_type(model, device_ids=[i], output_device=i,
process_group=mpu.get_data_parallel_group())
return model
elif args.DDP_impl == 'local':
args.DDP_type = LocalDDP
model = args.DDP_type(model)
return model
else:
print_rank_0('Unknown DDP implementation specified: {}. '
'Exiting.'.format(args.DDP_impl))
exit()
def set_random_seed(seed): def set_random_seed(seed):
"""Set random seed for reproducability.""" """Set random seed for reproducability."""
...@@ -284,7 +284,7 @@ def save_checkpoint(iteration, model, optimizer, ...@@ -284,7 +284,7 @@ def save_checkpoint(iteration, model, optimizer,
sd = {} sd = {}
sd['iteration'] = iteration sd['iteration'] = iteration
sd['model'] = model.state_dict() sd['model'] = model.state_dict_for_save_checkpoint()
# Optimizer stuff. # Optimizer stuff.
if not args.no_save_optim: if not args.no_save_optim:
...@@ -378,7 +378,6 @@ def load_checkpoint(model, optimizer, lr_scheduler, args): ...@@ -378,7 +378,6 @@ def load_checkpoint(model, optimizer, lr_scheduler, args):
print_rank_0('A metadata file exists but Unable to load iteration ' print_rank_0('A metadata file exists but Unable to load iteration '
' from checkpoint {}, exiting'.format(checkpoint_name)) ' from checkpoint {}, exiting'.format(checkpoint_name))
exit() exit()
# Model. # Model.
try: try:
model.load_state_dict(sd['model']) model.load_state_dict(sd['model'])
...@@ -410,7 +409,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, args): ...@@ -410,7 +409,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, args):
torch.cuda.set_rng_state(sd['cuda_rng_state']) torch.cuda.set_rng_state(sd['cuda_rng_state'])
mpu.get_cuda_rng_tracker().set_states(sd['rng_tracker_states']) mpu.get_cuda_rng_tracker().set_states(sd['rng_tracker_states'])
except KeyError: except KeyError:
print_rank_0('Unable to load optimizer from checkpoint {}, exiting. ' print_rank_0('Unable to load optimizer from checkpoint {}, exiting.'
'Specify --no-load-optim or --finetune to prevent ' 'Specify --no-load-optim or --finetune to prevent '
'attempting to load the optimizer ' 'attempting to load the optimizer '
'state.'.format(checkpoint_name)) 'state.'.format(checkpoint_name))
...@@ -422,6 +421,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, args): ...@@ -422,6 +421,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, args):
return iteration return iteration
def load_weights(src, dst, dst2src=False): def load_weights(src, dst, dst2src=False):
""" """
Loads weights from src to dst via in place copy. Loads weights from src to dst via in place copy.
......
...@@ -15,166 +15,43 @@ ...@@ -15,166 +15,43 @@
"""Pretrain BERT""" """Pretrain BERT"""
from datetime import datetime
import os
import random
import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from arguments import get_args
from configure_data import configure_data from configure_data import configure_data
from megatron.fp16 import FP16_Module
from megatron.fp16 import FP16_Optimizer
from megatron.learning_rates import AnnealingLR
from megatron.model import BertModel
from megatron.model import get_params_for_weight_decay_optimization
from megatron.model import gpt2_get_params_for_weight_decay_optimization
from megatron import mpu from megatron import mpu
from apex.optimizers import FusedAdam as Adam from megatron.model import BertModel
from megatron.utils import Timers
from megatron.utils import save_checkpoint
from megatron.utils import load_checkpoint
from megatron.utils import report_memory
from megatron.utils import print_args
from megatron.utils import print_params_min_max_norm
from megatron.utils import print_rank_0 from megatron.utils import print_rank_0
from megatron.utils import enable_adlr_autoresume
from megatron.utils import check_adlr_autoresume_termination
from megatron.utils import initialize_distributed
from megatron.utils import set_random_seed
from megatron.utils import wrap_model_for_distributed_training
from megatron.utils import vocab_size_with_padding from megatron.utils import vocab_size_with_padding
from megatron.training import run
def get_model(args): def model_provider(args):
"""Build the model.""" """Build the model."""
print_rank_0('building BERT model ...') print_rank_0('building BERT model ...')
model = BertModel(args)
if mpu.get_data_parallel_rank() == 0:
print(' > number of parameters on model parallel rank {}: {}'.format(
mpu.get_model_parallel_rank(),
sum([p.nelement() for p in model.parameters()])), flush=True)
# GPU allocation.
model.cuda(torch.cuda.current_device())
# Fp16 conversion.
if args.fp16:
model = FP16_Module(model)
if args.fp32_embedding:
model.module.model.bert.embeddings.word_embeddings.float()
model.module.model.bert.embeddings.position_embeddings.float()
model.module.model.bert.embeddings.token_type_embeddings.float()
if args.fp32_tokentypes:
model.module.model.bert.embeddings.token_type_embeddings.float()
if args.fp32_layernorm:
for name, _module in model.named_modules():
if 'LayerNorm' in name:
_module.float()
# Wrap model for distributed training.
model = wrap_model_for_distributed_training(model, args)
return model
model = BertModel(
num_layers=args.num_layers,
vocab_size=args.vocab_size,
hidden_size=args.hidden_size,
num_attention_heads=args.num_attention_heads,
embedding_dropout_prob=args.hidden_dropout,
attention_dropout_prob=args.attention_dropout,
output_dropout_prob=args.hidden_dropout,
max_sequence_length=args.max_position_embeddings,
checkpoint_activations=args.checkpoint_activations,
checkpoint_num_layers=args.checkpoint_num_layers,
add_binary_head=True,
layernorm_epsilon=args.layernorm_epsilon,
num_tokentypes=args.tokentype_size,
parallel_output=True)
def get_optimizer(model, args): return model
"""Set up the optimizer."""
# Build parameter groups (weight decay and non-decay).
while isinstance(model, (args.DDP_type, FP16_Module)):
model = model.module
layers = model.model.bert.encoder.layer
pooler = model.model.bert.pooler
lmheads = model.model.cls.predictions
nspheads = model.model.cls.seq_relationship
embeddings = model.model.bert.embeddings
param_groups = []
param_groups += list(get_params_for_weight_decay_optimization(layers))
param_groups += list(get_params_for_weight_decay_optimization(pooler))
param_groups += list(get_params_for_weight_decay_optimization(nspheads))
param_groups += list(get_params_for_weight_decay_optimization(embeddings))
param_groups += list(get_params_for_weight_decay_optimization(
lmheads.transform))
param_groups[1]['params'].append(lmheads.bias)
# Add model parallel attribute if it is not set.
for param_group in param_groups:
for param in param_group['params']:
if not hasattr(param, 'model_parallel'):
param.model_parallel = False
# Use Adam.
betas = (0.9, 0.999)
optimizer = Adam(param_groups, betas=betas,
lr=args.lr, weight_decay=args.weight_decay)
# Wrap into fp16 optimizer.
if args.fp16:
optimizer = FP16_Optimizer(optimizer,
static_loss_scale=args.loss_scale,
dynamic_loss_scale=args.dynamic_loss_scale,
dynamic_loss_args={
'scale_window': args.loss_scale_window,
'min_scale':args.min_scale,
'delayed_shift': args.hysteresis})
return optimizer
def get_learning_rate_scheduler(optimizer, args):
"""Build the learning rate scheduler."""
# Add linear learning rate scheduler.
if args.lr_decay_iters is not None:
num_iters = args.lr_decay_iters
else:
num_iters = args.train_iters
init_step = -1
warmup_iter = args.warmup * num_iters
lr_scheduler = AnnealingLR(optimizer,
start_lr=args.lr,
warmup_iter=warmup_iter,
num_iters=num_iters,
decay_style=args.lr_decay_style,
last_iter=init_step,
min_lr=args.min_lr,
use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler,
override_lr_scheduler=args.override_lr_scheduler)
return lr_scheduler
def setup_model_and_optimizer(args):
"""Setup model and optimizer."""
model = get_model(args)
optimizer = get_optimizer(model, args)
lr_scheduler = get_learning_rate_scheduler(optimizer, args)
if args.load is not None:
args.iteration = load_checkpoint(model, optimizer, lr_scheduler, args)
else:
args.iteration = 0
return model, optimizer, lr_scheduler
def get_batch(data_iterator, timers): def get_batch(data_iterator, timers):
''' get_batch subdivides the source data into chunks of
length args.seq_length. If source is equal to the example
output of the data loading example, with a seq_length limit
of 2, we'd get the following two Variables for i = 0:
┌ a g m s ┐ ┌ b h n t ┐
└ b h n t ┘ └ c i o u ┘
Note that despite the name of the function, the subdivision of data is not
done along the batch dimension (i.e. dimension 1), since that was handled
by the data loader. The chunks are along dimension 0, corresponding
to the seq_len dimension in the LSTM. A Variable representing an appropriate
shard reset mask of the same dimensions is also returned.
'''
# Items and their type. # Items and their type.
keys = ['text', 'types', 'is_random', 'mask', 'mask_labels', 'pad_mask'] keys = ['text', 'types', 'is_random', 'mask', 'mask_labels', 'pad_mask']
datatype = torch.int64 datatype = torch.int64
...@@ -204,266 +81,32 @@ def forward_step(data_iterator, model, args, timers): ...@@ -204,266 +81,32 @@ def forward_step(data_iterator, model, args, timers):
# Get the batch. # Get the batch.
timers('batch generator').start() timers('batch generator').start()
tokens, types, next_sentence, loss_mask, lm_labels, \ tokens, types, next_sentence, loss_mask, lm_labels, padding_mask \
padding_mask = get_batch(data_iterator, timers) = get_batch(data_iterator, timers)
timers('batch generator').stop() timers('batch generator').stop()
# Forward model. # Forward model.
output, nsp = model(tokens, types, 1-padding_mask, lm_logits, nsp_logits = model(tokens, 1-padding_mask, tokentype_ids=types)
checkpoint_activations=args.checkpoint_activations)
nsp_loss = F.cross_entropy(nsp.view(-1, 2).contiguous().float(), nsp_loss = F.cross_entropy(nsp_logits.view(-1, 2).contiguous().float(),
next_sentence.view(-1).contiguous(), next_sentence.view(-1).contiguous(),
ignore_index=-1) ignore_index=-1)
losses = mpu.vocab_parallel_cross_entropy( lm_loss_ = mpu.vocab_parallel_cross_entropy(lm_logits.contiguous().float(),
output.contiguous().float(), lm_labels.contiguous()) lm_labels.contiguous())
loss_mask = loss_mask.contiguous()
lm_loss = torch.sum( lm_loss = torch.sum(
losses.view(-1) * loss_mask.view(-1).float()) / loss_mask.sum() lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
return lm_loss, nsp_loss
def backward_step(optimizer, model, lm_loss, nsp_loss, args, timers):
"""Backward step."""
# Total loss.
loss = lm_loss + nsp_loss loss = lm_loss + nsp_loss
# Backward pass. reduced_losses = torch.cat((lm_loss.clone().detach().view(1),
optimizer.zero_grad() nsp_loss.clone().detach().view(1)))
if args.fp16: torch.distributed.all_reduce(reduced_losses)
optimizer.backward(loss, update_master_grads=False) reduced_losses = reduced_losses / torch.distributed.get_world_size()
else:
loss.backward()
# Reduce across processes.
lm_loss_reduced = lm_loss
nsp_loss_reduced = nsp_loss
reduced_losses = torch.cat((lm_loss.view(1), nsp_loss.view(1)))
torch.distributed.all_reduce(reduced_losses.data)
reduced_losses.data = reduced_losses.data / args.world_size
if args.DDP_impl == 'local':
timers('allreduce').start()
model.allreduce_params(reduce_after=False,
fp32_allreduce=args.fp32_allreduce)
timers('allreduce').stop()
lm_loss_reduced = reduced_losses[0] lm_loss_reduced = reduced_losses[0]
nsp_loss_reduced = reduced_losses[1] nsp_loss_reduced = reduced_losses[1]
# Update master gradients. return loss, {'lm loss': lm_loss_reduced, 'nsp loss': nsp_loss_reduced}
if args.fp16:
optimizer.update_master_grads()
# Clipping gradients helps prevent exploding gradients.
if args.clip_grad > 0:
if not args.fp16:
mpu.clip_grad_norm(model.parameters(), args.clip_grad)
else:
optimizer.clip_master_grads(args.clip_grad)
return lm_loss_reduced, nsp_loss_reduced
def train_step(data_iterator, model, optimizer, lr_scheduler,
args, timers):
"""Single training step."""
# Forward model for one step.
timers('forward').start()
lm_loss, nsp_loss = forward_step(data_iterator, model,
args, timers)
timers('forward').stop()
# Calculate gradients, reduce across processes, and clip.
timers('backward').start()
lm_loss_reduced, nsp_loss_reduced = backward_step(optimizer, model, lm_loss,
nsp_loss, args, timers)
timers('backward').stop()
# Update parameters.
timers('optimizer').start()
optimizer.step()
timers('optimizer').stop()
# Update learning rate.
skipped_iter = 0
if not (args.fp16 and optimizer.overflow):
lr_scheduler.step()
else:
skipped_iter = 1
return lm_loss_reduced, nsp_loss_reduced, skipped_iter
def train(model, optimizer, lr_scheduler,
train_data_iterator, val_data_iterator, timers, args, writer):
"""Train the model."""
# Turn on training mode which enables dropout.
model.train()
# Tracking loss.
total_lm_loss = 0.0
total_nsp_loss = 0.0
# Iterations.
iteration = args.iteration
skipped_iters = 0
timers('interval time').start()
report_memory_flag = True
while iteration < args.train_iters:
lm_loss, nsp_loss, skipped_iter = train_step(train_data_iterator,
model,
optimizer,
lr_scheduler,
args, timers)
skipped_iters += skipped_iter
iteration += 1
# Update losses.
current_lm_loss = lm_loss.data.detach().float()
current_nsp_loss = nsp_loss.data.detach().float()
total_lm_loss += current_lm_loss
total_nsp_loss += current_nsp_loss
# Logging.
if args.DDP_impl == 'torch':
timers_to_log = ['forward', 'backward', 'optimizer',
'batch generator', 'data loader']
else:
timers_to_log = ['forward', 'backward', 'allreduce', 'optimizer',
'batch generator', 'data loader']
learning_rate = optimizer.param_groups[0]['lr']
if writer and args.rank == 0:
writer.add_scalar('learning_rate', learning_rate, iteration)
writer.add_scalar('lm_loss', current_lm_loss, iteration)
writer.add_scalar('nsp_loss', current_nsp_loss, iteration)
if args.fp16:
writer.add_scalar('loss_scale', optimizer.loss_scale, iteration)
normalizer = iteration % args.log_interval
if normalizer == 0:
normalizer = args.log_interval
timers.write(timers_to_log, writer, iteration,
normalizer=normalizer)
if iteration % args.log_interval == 0:
avg_nsp_loss = total_nsp_loss.item() / args.log_interval
avg_lm_loss = total_lm_loss.item() / args.log_interval
elapsed_time = timers('interval time').elapsed()
if writer and args.rank == 0:
writer.add_scalar('iteration_time',
elapsed_time / args.log_interval, iteration)
log_string = ' iteration {:8d}/{:8d} |'.format(iteration,
args.train_iters)
log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
elapsed_time * 1000.0 / args.log_interval)
log_string += ' learning rate {:.3E} |'.format(learning_rate)
log_string += ' lm loss {:.6E} |'.format(avg_lm_loss)
log_string += ' nsp loss {:.6E} |'.format(avg_nsp_loss)
if args.fp16:
log_string += ' loss scale {:.1f} |'.format(
optimizer.loss_scale)
print_rank_0(log_string)
total_nsp_loss = 0.0
total_lm_loss = 0.0
if report_memory_flag:
report_memory('after {} iterations'.format(iteration))
report_memory_flag = False
timers.log(timers_to_log, normalizer=args.log_interval)
# Autoresume
if (iteration % args.adlr_autoresume_interval == 0) and args.adlr_autoresume:
check_adlr_autoresume_termination(iteration, model, optimizer,
lr_scheduler, args)
# Checkpointing
if args.save and args.save_interval and iteration % args.save_interval == 0:
save_checkpoint(iteration, model, optimizer, lr_scheduler, args)
# Evaluation
if args.eval_interval and iteration % args.eval_interval == 0 and args.do_valid:
prefix = 'iteration {}'.format(iteration)
evaluate_and_print_results(prefix, val_data_iterator, model, args,
writer, iteration, timers, False)
if args.exit_interval and iteration % args.exit_interval == 0:
torch.distributed.barrier()
time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
rank = torch.distributed.get_rank()
print('rank: {} | time: {} | exiting the program at iteration {}'.
format(rank, time_str, iteration), flush=True)
exit()
return iteration, skipped_iters
def evaluate(data_iterator, model, args, timers, verbose = False):
"""Evaluation."""
# Turn on evaluation mode which disables dropout.
model.eval()
total_lm_loss = 0
total_nsp_loss = 0
with torch.no_grad():
iteration = 0
while iteration < args.eval_iters:
iteration += 1
if verbose and iteration % args.log_interval == 0:
print_rank_0('Evaluating iter {}/{}'.format(iteration, args.eval_iters))
# Forward evaluation.
lm_loss, nsp_loss = forward_step(data_iterator, model,
args, timers)
# Reduce across processes.
if isinstance(model, args.DDP_type):
reduced_losses = torch.cat((lm_loss.view(1), nsp_loss.view(1)))
torch.distributed.all_reduce(reduced_losses.data)
reduced_losses.data = reduced_losses.data/args.world_size
lm_loss = reduced_losses[0]
nsp_loss = reduced_losses[1]
total_lm_loss += lm_loss.data.detach().float().item()
total_nsp_loss += nsp_loss.data.detach().float().item()
# Move model back to the train mode.
model.train()
total_lm_loss /= args.eval_iters
total_nsp_loss /= args.eval_iters
return total_lm_loss, total_nsp_loss
def evaluate_and_print_results(prefix, data_iterator, model,
args, writer, iteration,
timers, verbose=False):
"""Helper function to evaluate and dump results on screen."""
lm_loss, nsp_loss = evaluate(data_iterator, model,
args, timers, verbose)
val_loss = lm_loss + nsp_loss
print_rank_0('-' * 100)
string = ' validation loss at {} | '.format(prefix)
string += 'LM loss: {:.6E} | '.format(lm_loss)
string += 'NSP loss: {:.6E} | '.format(nsp_loss)
string += 'total loss: {:.6E}'.format(val_loss)
length = len(string) + 1
print_rank_0('-' * length)
print_rank_0(string)
print_rank_0('-' * length)
if writer and args.rank == 0:
writer.add_scalar('val_lm_loss', lm_loss, iteration)
writer.add_scalar('val_nsp_loss', nsp_loss, iteration)
writer.add_scalar('val_total_loss', val_loss, iteration)
return val_loss
def get_train_val_test_data(args): def get_train_val_test_data(args):
...@@ -497,104 +140,13 @@ def get_train_val_test_data(args): ...@@ -497,104 +140,13 @@ def get_train_val_test_data(args):
args.do_valid = token_counts[3].item() args.do_valid = token_counts[3].item()
args.do_test = token_counts[4].item() args.do_test = token_counts[4].item()
return train_data, val_data, test_data, num_tokens, num_type_tokens args.vocab_size = num_tokens
args.tokentype_size = num_type_tokens
def main():
"""Main training program."""
# Disable CuDNN.
torch.backends.cudnn.enabled = False
# Timer.
timers = Timers()
# Arguments.
args = get_args()
writer = None
if args.tensorboard_dir and args.rank == 0:
try:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter(log_dir = args.tensorboard_dir)
except ModuleNotFoundError:
print_rank_0('WARNING: TensorBoard writing requested but is not '
'available (are you using PyTorch 1.1.0 or later?), '
'no TensorBoard logs will be written.')
writer = None
# Pytorch distributed.
initialize_distributed(args)
if torch.distributed.get_rank() == 0:
print('Pretrain BERT model')
print_args(args, writer)
# Autoresume.
torch.distributed.barrier()
if args.adlr_autoresume:
enable_adlr_autoresume(args)
# Random seeds for reproducibility.
set_random_seed(args.seed)
# Data stuff.
train_data, val_data, test_data, args.tokenizer_num_tokens, \
args.tokenizer_num_type_tokens = get_train_val_test_data(args)
# Model, optimizer, and learning rate.
model, optimizer, lr_scheduler = setup_model_and_optimizer(args)
if args.resume_dataloader:
if train_data is not None:
train_data.batch_sampler.start_iter = args.iteration % \
len(train_data)
print_rank_0('setting training data start iteration to {}'.
format(train_data.batch_sampler.start_iter))
if val_data is not None:
start_iter_val = (args.iteration // args.eval_interval) * \
args.eval_iters
val_data.batch_sampler.start_iter = start_iter_val % \
len(val_data)
print_rank_0('setting validation data start iteration to {}'.
format(val_data.batch_sampler.start_iter))
if train_data is not None:
train_data_iterator = iter(train_data)
else:
train_data_iterator = None
if val_data is not None:
val_data_iterator = iter(val_data)
else:
val_data_iterator = None
iteration = 0
if args.train_iters > 0:
if args.do_train:
iteration, skipped = train(model, optimizer,
lr_scheduler,
train_data_iterator,
val_data_iterator,
timers, args, writer)
if args.do_valid:
prefix = 'the end of training for val data'
val_loss = evaluate_and_print_results(prefix, val_data_iterator,
model, args, writer, iteration,
timers, False)
if args.save and iteration != 0:
save_checkpoint(iteration, model, optimizer, lr_scheduler, args)
if test_data is not None:
test_data_iterator = iter(test_data)
else:
test_data_iterator = None
if args.do_test: return train_data, val_data, test_data
# Run on test data.
prefix = 'the end of training for test data'
evaluate_and_print_results(prefix, test_data_iterator,
model, args, None, 0, timers, True)
if __name__ == "__main__": if __name__ == "__main__":
main()
run('Pretrain BERT model', get_train_val_test_data,
model_provider, forward_step)