Commit 8ddf66c6 authored by sunxx1

Merge branch 'hepj-test' into 'main'

Update the README, add training scripts, and improve the model conversion code

See merge request dcutoolkit/deeplearing/dlexamples_new!38
parents 0200794c bedf3c0c
## DL params
export BATCHSIZE=27
export LR=4.0e-4
export GRADIENT_STEPS=1
export MAX_STEPS=8103
export WARMUP_PROPORTION=0.0
export PHASE=2
export MAX_SAMPLES_TERMINATION=4500000
export EXTRA_PARAMS="--unpad"
## System run parms
export DGXNNODES=2
export DGXSYSTEM=$(basename $(readlink -f ${BASH_SOURCE[0]}) | sed 's/^config_//' | sed 's/\.sh$//' )
export WALLTIME=02:00:00
## System config params
export DGXNGPU=8
export DGXSOCKETCORES=24
export DGXNSOCKET=2
export DGXHT=2  # HT on = 2, HT off = 1
## DL params
export BATCHSIZE=20
export LR=1.7e-3
export GRADIENT_STEPS=1
export MAX_STEPS=990
export WARMUP_PROPORTION=0.0
export OPT_LAMB_BETA_1=0.87
export OPT_LAMB_BETA_2=0.97
export PHASE=2
export MAX_SAMPLES_TERMINATION=6500000
export EXTRA_PARAMS=""
## System run parms
export DGXNNODES=32
export DGXSYSTEM=$(basename $(readlink -f ${BASH_SOURCE[0]}) | sed 's/^config_//' | sed 's/\.sh$//' )
export WALLTIME=01:40:00
## System config params
export DGXNGPU=8
export DGXSOCKETCORES=24
export DGXNSOCKET=2
export DGXHT=2  # HT on = 2, HT off = 1
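# The DGXSYSTEM exports above derive the system name from the sourced config
# file's own name by stripping the "config_" prefix and ".sh" suffix. A minimal
# Python sketch of that basename/sed pipeline, using a hypothetical file name:
import re

def dgx_system_from_filename(path):
    # Mirrors: basename "${BASH_SOURCE[0]}" | sed 's/^config_//' | sed 's/\.sh$//'
    name = path.rsplit("/", 1)[-1]
    name = re.sub(r"^config_", "", name)
    return re.sub(r"\.sh$", "", name)

print(dgx_system_from_filename("configs/config_DGXA100_2x8x27.sh"))  # -> DGXA100_2x8x27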
......@@ -14,9 +14,7 @@
import torch
import argparse
#from modeling import BertForPreTraining, BertConfig
from model import BertForPreTraining, BertConfig
def parse_arguments():
parser = argparse.ArgumentParser()
......
# coding=utf-8
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
# Copyright 2020 MLBenchmark Group. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Create masked LM/next sentence masked_lm TF examples for BERT."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import random
import tokenization
import tensorflow as tf
import h5py
import numpy as np
hdf5_compression_method = None
flags = tf.flags
FLAGS = flags.FLAGS
flags.DEFINE_string("input_file", None,
"Input raw text file (or comma-separated list of files).")
flags.DEFINE_string(
"output_file", None,
"Output TF example file (or comma-separated list of files).")
flags.DEFINE_string("vocab_file", None,
"The vocabulary file that the BERT model was trained on.")
flags.DEFINE_bool(
"do_lower_case", True,
"Whether to lower case the input text. Should be True for uncased "
"models and False for cased models.")
flags.DEFINE_integer("max_seq_length", 128, "Maximum sequence length.")
flags.DEFINE_integer("max_predictions_per_seq", 20,
"Maximum number of masked LM predictions per sequence.")
flags.DEFINE_integer("random_seed", 12345, "Random seed for data generation.")
flags.DEFINE_integer(
"dupe_factor", 10,
"Number of times to duplicate the input data (with different masks).")
flags.DEFINE_float("masked_lm_prob", 0.15, "Masked LM probability.")
flags.DEFINE_float(
"short_seq_prob", 0.1,
"Probability of creating sequences which are shorter than the "
"maximum length.")
class TrainingInstance(object):
"""A single training instance (sentence pair)."""
def __init__(self, tokens, segment_ids, masked_lm_positions, masked_lm_labels,
is_random_next):
self.tokens = tokens
self.segment_ids = segment_ids
self.is_random_next = is_random_next
self.masked_lm_positions = masked_lm_positions
self.masked_lm_labels = masked_lm_labels
def __str__(self):
s = ""
s += "tokens: %s\n" % (" ".join(
[tokenization.printable_text(x) for x in self.tokens]))
s += "segment_ids: %s\n" % (" ".join([str(x) for x in self.segment_ids]))
s += "is_random_next: %s\n" % self.is_random_next
s += "masked_lm_positions: %s\n" % (" ".join(
[str(x) for x in self.masked_lm_positions]))
s += "masked_lm_labels: %s\n" % (" ".join(
[tokenization.printable_text(x) for x in self.masked_lm_labels]))
s += "\n"
return s
def __repr__(self):
return self.__str__()
def write_instance_to_example_files(instances, tokenizer, max_seq_length,
max_predictions_per_seq, output_files):
"""Create TF example files from `TrainingInstance`s."""
writers = []
h5_writers = []
expected_instances_per_file = len(instances) // len(output_files) + 500 # Over-allocation to avoid resizing
for output_file in output_files:
h5_writers.append({
'handle' : h5py.File(output_file + ".hdf5", 'w'),
'input_ids' : np.zeros([expected_instances_per_file, max_seq_length], dtype="int32"),
'input_mask' : np.zeros([expected_instances_per_file, max_seq_length], dtype="int32"),
'segment_ids' : np.zeros([expected_instances_per_file, max_seq_length], dtype="int32"),
'masked_lm_positions' : np.zeros([expected_instances_per_file, max_predictions_per_seq], dtype="int32"),
'masked_lm_ids' : np.zeros([expected_instances_per_file, max_predictions_per_seq], dtype="int32"),
'next_sentence_labels' : np.zeros(expected_instances_per_file, dtype="int32"),
'len' : 0 })
writer_index = 0
total_written = 0
features_h5 = collections.OrderedDict()
for (inst_index, instance) in enumerate(instances):
input_ids = tokenizer.convert_tokens_to_ids(instance.tokens)
input_mask = [1] * len(input_ids)
segment_ids = list(instance.segment_ids)
assert len(input_ids) <= max_seq_length
while len(input_ids) < max_seq_length:
input_ids.append(0)
input_mask.append(0)
segment_ids.append(0)
assert len(input_ids) == max_seq_length
assert len(input_mask) == max_seq_length
assert len(segment_ids) == max_seq_length
masked_lm_positions = list(instance.masked_lm_positions)
masked_lm_ids = tokenizer.convert_tokens_to_ids(instance.masked_lm_labels)
masked_lm_weights = [1.0] * len(masked_lm_ids)
while len(masked_lm_positions) < max_predictions_per_seq:
masked_lm_positions.append(0)
masked_lm_ids.append(0)
masked_lm_weights.append(0.0)
next_sentence_label = 1 if instance.is_random_next else 0
h5_writers[writer_index]['input_ids'][inst_index] = input_ids
h5_writers[writer_index]['input_mask'][inst_index] = input_mask
h5_writers[writer_index]['segment_ids'][inst_index] = segment_ids
h5_writers[writer_index]['masked_lm_positions'][inst_index] = masked_lm_positions
h5_writers[writer_index]['masked_lm_ids'][inst_index] = masked_lm_ids
h5_writers[writer_index]['next_sentence_labels'][inst_index] = next_sentence_label
h5_writers[writer_index]['len'] += 1
writer_index = (writer_index + 1) % len(h5_writers)
total_written += 1
if inst_index < 20:
tf.logging.info("*** Example ***")
tf.logging.info("tokens: %s" % " ".join(
[tokenization.printable_text(x) for x in instance.tokens]))
print("saving data")
for h5_writer in h5_writers:
my_size = h5_writer['len']
h5_writer['handle'].create_dataset('input_ids', data=h5_writer['input_ids'][:my_size], dtype='i4', compression=hdf5_compression_method)
h5_writer['handle'].create_dataset('input_mask', data=h5_writer['input_mask'][:my_size], dtype='i1', compression=hdf5_compression_method)
h5_writer['handle'].create_dataset('segment_ids', data=h5_writer['segment_ids'][:my_size], dtype='i1', compression=hdf5_compression_method)
h5_writer['handle'].create_dataset('masked_lm_positions', data=h5_writer['masked_lm_positions'][:my_size], dtype='i4', compression=hdf5_compression_method)
h5_writer['handle'].create_dataset('masked_lm_ids', data=h5_writer['masked_lm_ids'][:my_size], dtype='i4', compression=hdf5_compression_method)
h5_writer['handle'].create_dataset('next_sentence_labels', data=h5_writer['next_sentence_labels'][:my_size], dtype='i1', compression=hdf5_compression_method)
h5_writer['handle'].flush()
h5_writer['handle'].close()
tf.logging.info("Wrote %d total instances", total_written)
def create_int_feature(values):
feature = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
return feature
def create_float_feature(values):
feature = tf.train.Feature(float_list=tf.train.FloatList(value=list(values)))
return feature
def create_training_instances(input_files, tokenizer, max_seq_length,
dupe_factor, short_seq_prob, masked_lm_prob,
max_predictions_per_seq, rng):
"""Create `TrainingInstance`s from raw text."""
all_documents = [[]]
# Input file format:
# (1) One sentence per line. These should ideally be actual sentences, not
# entire paragraphs or arbitrary spans of text. (Because we use the
# sentence boundaries for the "next sentence prediction" task).
# (2) Blank lines between documents. Document boundaries are needed so
# that the "next sentence prediction" task doesn't span between documents.
for input_file in input_files:
with tf.gfile.GFile(input_file, "r") as reader:
while True:
line = tokenization.convert_to_unicode(reader.readline())
if not line:
break
line = line.strip()
# Empty lines are used as document delimiters
if not line:
all_documents.append([])
tokens = tokenizer.tokenize(line)
if tokens:
all_documents[-1].append(tokens)
# Remove empty documents
all_documents = [x for x in all_documents if x]
rng.shuffle(all_documents)
vocab_words = list(tokenizer.vocab.keys())
instances = []
for _ in range(dupe_factor):
for document_index in range(len(all_documents)):
instances.extend(
create_instances_from_document(
all_documents, document_index, max_seq_length, short_seq_prob,
masked_lm_prob, max_predictions_per_seq, vocab_words, rng))
rng.shuffle(instances)
return instances
def create_instances_from_document(
all_documents, document_index, max_seq_length, short_seq_prob,
masked_lm_prob, max_predictions_per_seq, vocab_words, rng):
"""Creates `TrainingInstance`s for a single document."""
document = all_documents[document_index]
# Account for [CLS], [SEP], [SEP]
max_num_tokens = max_seq_length - 3
# We *usually* want to fill up the entire sequence since we are padding
# to `max_seq_length` anyways, so short sequences are generally wasted
# computation. However, we *sometimes*
# (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter
# sequences to minimize the mismatch between pre-training and fine-tuning.
# The `target_seq_length` is just a rough target however, whereas
# `max_seq_length` is a hard limit.
target_seq_length = max_num_tokens
if rng.random() < short_seq_prob:
target_seq_length = rng.randint(2, max_num_tokens)
# We DON'T just concatenate all of the tokens from a document into a long
# sequence and choose an arbitrary split point because this would make the
# next sentence prediction task too easy. Instead, we split the input into
# segments "A" and "B" based on the actual "sentences" provided by the user
# input.
instances = []
current_chunk = []
current_length = 0
i = 0
while i < len(document):
segment = document[i]
current_chunk.append(segment)
current_length += len(segment)
if i == len(document) - 1 or current_length >= target_seq_length:
if current_chunk:
# `a_end` is how many segments from `current_chunk` go into the `A`
# (first) sentence.
a_end = 1
if len(current_chunk) >= 2:
a_end = rng.randint(1, len(current_chunk) - 1)
tokens_a = []
for j in range(a_end):
tokens_a.extend(current_chunk[j])
tokens_b = []
# Random next
is_random_next = False
if len(current_chunk) == 1 or rng.random() < 0.5:
is_random_next = True
target_b_length = target_seq_length - len(tokens_a)
# This should rarely go for more than one iteration for large
# corpora. However, just to be careful, we try to make sure that
# the random document is not the same as the document
# we're processing.
for _ in range(10):
random_document_index = rng.randint(0, len(all_documents) - 1)
if random_document_index != document_index:
break
random_document = all_documents[random_document_index]
random_start = rng.randint(0, len(random_document) - 1)
for j in range(random_start, len(random_document)):
tokens_b.extend(random_document[j])
if len(tokens_b) >= target_b_length:
break
# We didn't actually use these segments so we "put them back" so
# they don't go to waste.
num_unused_segments = len(current_chunk) - a_end
i -= num_unused_segments
# Actual next
else:
is_random_next = False
for j in range(a_end, len(current_chunk)):
tokens_b.extend(current_chunk[j])
truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng)
assert len(tokens_a) >= 1
assert len(tokens_b) >= 1
tokens = []
segment_ids = []
tokens.append("[CLS]")
segment_ids.append(0)
for token in tokens_a:
tokens.append(token)
segment_ids.append(0)
tokens.append("[SEP]")
segment_ids.append(0)
for token in tokens_b:
tokens.append(token)
segment_ids.append(1)
tokens.append("[SEP]")
segment_ids.append(1)
(tokens, masked_lm_positions,
masked_lm_labels) = create_masked_lm_predictions(
tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng)
instance = TrainingInstance(
tokens=tokens,
segment_ids=segment_ids,
is_random_next=is_random_next,
masked_lm_positions=masked_lm_positions,
masked_lm_labels=masked_lm_labels)
instances.append(instance)
current_chunk = []
current_length = 0
i += 1
return instances
MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
["index", "label"])
def create_masked_lm_predictions(tokens, masked_lm_prob,
max_predictions_per_seq, vocab_words, rng):
"""Creates the predictions for the masked LM objective."""
cand_indexes = []
for (i, token) in enumerate(tokens):
if token == "[CLS]" or token == "[SEP]":
continue
cand_indexes.append(i)
rng.shuffle(cand_indexes)
output_tokens = list(tokens)
num_to_predict = min(max_predictions_per_seq,
max(1, int(round(len(tokens) * masked_lm_prob))))
masked_lms = []
covered_indexes = set()
for index in cand_indexes:
if len(masked_lms) >= num_to_predict:
break
if index in covered_indexes:
continue
covered_indexes.add(index)
masked_token = None
# 80% of the time, replace with [MASK]
if rng.random() < 0.8:
masked_token = "[MASK]"
else:
# 10% of the time, keep original
if rng.random() < 0.5:
masked_token = tokens[index]
# 10% of the time, replace with random word
else:
masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)]
output_tokens[index] = masked_token
masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))
masked_lms = sorted(masked_lms, key=lambda x: x.index)
masked_lm_positions = []
masked_lm_labels = []
for p in masked_lms:
masked_lm_positions.append(p.index)
masked_lm_labels.append(p.label)
return (output_tokens, masked_lm_positions, masked_lm_labels)
def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng):
"""Truncates a pair of sequences to a maximum sequence length."""
while True:
total_length = len(tokens_a) + len(tokens_b)
if total_length <= max_num_tokens:
break
trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
assert len(trunc_tokens) >= 1
# We want to sometimes truncate from the front and sometimes from the
# back to add more randomness and avoid biases.
if rng.random() < 0.5:
del trunc_tokens[0]
else:
trunc_tokens.pop()
def main(_):
tf.logging.set_verbosity(tf.logging.INFO)
tokenizer = tokenization.FullTokenizer(
vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)
input_files = []
for input_pattern in FLAGS.input_file.split(","):
input_files.extend(tf.gfile.Glob(input_pattern))
tf.logging.info("*** Reading from input files ***")
for input_file in input_files:
tf.logging.info(" %s", input_file)
rng = random.Random(FLAGS.random_seed)
instances = create_training_instances(
input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor,
FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq,
rng)
output_files = FLAGS.output_file.split(",")
tf.logging.info("*** Writing to output files ***")
for output_file in output_files:
tf.logging.info(" %s", output_file)
write_instance_to_example_files(instances, tokenizer, FLAGS.max_seq_length,
FLAGS.max_predictions_per_seq, output_files)
if __name__ == "__main__":
flags.mark_flag_as_required("input_file")
flags.mark_flag_as_required("output_file")
flags.mark_flag_as_required("vocab_file")
tf.app.run()
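# A minimal sketch (hypothetical shard name) of reading back one HDF5 shard
# written by write_instance_to_example_files above; the dataset names and the
# fixed-length padding match what that function writes.
import h5py

with h5py.File("pretrain_shard_0000.hdf5", "r") as f:
    for name in ["input_ids", "input_mask", "segment_ids",
                 "masked_lm_positions", "masked_lm_ids", "next_sentence_labels"]:
        print(name, f[name].shape, f[name].dtype)
    first_example_ids = f["input_ids"][0]  # one sequence, padded to max_seq_length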
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Extract pre-computed feature vectors from a PyTorch BERT model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import collections
import logging
import json
import re
import torch
from torch.utils.data import TensorDataset, DataLoader, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from tokenization import BertTokenizer
from modeling import BertModel
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S',
level = logging.INFO)
logger = logging.getLogger(__name__)
class InputExample(object):
def __init__(self, unique_id, text_a, text_b):
self.unique_id = unique_id
self.text_a = text_a
self.text_b = text_b
class InputFeatures(object):
"""A single set of features of data."""
def __init__(self, unique_id, tokens, input_ids, input_mask, input_type_ids):
self.unique_id = unique_id
self.tokens = tokens
self.input_ids = input_ids
self.input_mask = input_mask
self.input_type_ids = input_type_ids
def convert_examples_to_features(examples, seq_length, tokenizer):
"""Loads a data file into a list of `InputBatch`s."""
features = []
for (ex_index, example) in enumerate(examples):
tokens_a = tokenizer.tokenize(example.text_a)
tokens_b = None
if example.text_b:
tokens_b = tokenizer.tokenize(example.text_b)
if tokens_b:
# Modifies `tokens_a` and `tokens_b` in place so that the total
# length is less than the specified length.
# Account for [CLS], [SEP], [SEP] with "- 3"
_truncate_seq_pair(tokens_a, tokens_b, seq_length - 3)
else:
# Account for [CLS] and [SEP] with "- 2"
if len(tokens_a) > seq_length - 2:
tokens_a = tokens_a[0:(seq_length - 2)]
# The convention in BERT is:
# (a) For sequence pairs:
# tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
# type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
# (b) For single sequences:
# tokens: [CLS] the dog is hairy . [SEP]
# type_ids: 0 0 0 0 0 0 0
#
# Where "type_ids" are used to indicate whether this is the first
# sequence or the second sequence. The embedding vectors for `type=0` and
# `type=1` were learned during pre-training and are added to the wordpiece
# embedding vector (and position vector). This is not *strictly* necessary
# since the [SEP] token unambiguously separates the sequences, but it makes
# it easier for the model to learn the concept of sequences.
#
# For classification tasks, the first vector (corresponding to [CLS]) is
# used as as the "sentence vector". Note that this only makes sense because
# the entire model is fine-tuned.
tokens = []
input_type_ids = []
tokens.append("[CLS]")
input_type_ids.append(0)
for token in tokens_a:
tokens.append(token)
input_type_ids.append(0)
tokens.append("[SEP]")
input_type_ids.append(0)
if tokens_b:
for token in tokens_b:
tokens.append(token)
input_type_ids.append(1)
tokens.append("[SEP]")
input_type_ids.append(1)
input_ids = tokenizer.convert_tokens_to_ids(tokens)
# The mask has 1 for real tokens and 0 for padding tokens. Only real
# tokens are attended to.
input_mask = [1] * len(input_ids)
# Zero-pad up to the sequence length.
while len(input_ids) < seq_length:
input_ids.append(0)
input_mask.append(0)
input_type_ids.append(0)
assert len(input_ids) == seq_length
assert len(input_mask) == seq_length
assert len(input_type_ids) == seq_length
if ex_index < 5:
logger.info("*** Example ***")
logger.info("unique_id: %s" % (example.unique_id))
logger.info("tokens: %s" % " ".join([str(x) for x in tokens]))
logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
logger.info(
"input_type_ids: %s" % " ".join([str(x) for x in input_type_ids]))
features.append(
InputFeatures(
unique_id=example.unique_id,
tokens=tokens,
input_ids=input_ids,
input_mask=input_mask,
input_type_ids=input_type_ids))
return features
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
"""Truncates a sequence pair in place to the maximum length."""
# This is a simple heuristic which will always truncate the longer sequence
# one token at a time. This makes more sense than truncating an equal percent
# of tokens from each, since if one sequence is very short then each token
# that's truncated likely contains more information than a longer sequence.
while True:
total_length = len(tokens_a) + len(tokens_b)
if total_length <= max_length:
break
if len(tokens_a) > len(tokens_b):
tokens_a.pop()
else:
tokens_b.pop()
def read_examples(input_file):
"""Read a list of `InputExample`s from an input file."""
examples = []
unique_id = 0
with open(input_file, "r", encoding='utf-8') as reader:
while True:
line = reader.readline()
if not line:
break
line = line.strip()
text_a = None
text_b = None
m = re.match(r"^(.*) \|\|\| (.*)$", line)
if m is None:
text_a = line
else:
text_a = m.group(1)
text_b = m.group(2)
examples.append(
InputExample(unique_id=unique_id, text_a=text_a, text_b=text_b))
unique_id += 1
return examples
def main():
parser = argparse.ArgumentParser()
## Required parameters
parser.add_argument("--input_file", default=None, type=str, required=True)
parser.add_argument("--output_file", default=None, type=str, required=True)
parser.add_argument("--bert_model", default=None, type=str, required=True,
help="Bert pre-trained model selected in the list: bert-base-uncased, "
"bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")
## Other parameters
parser.add_argument("--do_lower_case", action='store_true', help="Set this flag if you are using an uncased model.")
parser.add_argument("--layers", default="-1,-2,-3,-4", type=str)
parser.add_argument("--max_seq_length", default=128, type=int,
help="The maximum total input sequence length after WordPiece tokenization. Sequences longer "
"than this will be truncated, and sequences shorter than this will be padded.")
parser.add_argument("--batch_size", default=32, type=int, help="Batch size for predictions.")
parser.add_argument("--local_rank",
type=int,
default=-1,
help = "local_rank for distributed training on gpus")
parser.add_argument("--no_cuda",
action='store_true',
help="Whether not to use CUDA when available")
args = parser.parse_args()
if args.local_rank == -1 or args.no_cuda:
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
n_gpu = torch.cuda.device_count()
else:
device = torch.device("cuda", args.local_rank)
n_gpu = 1
# Initializes the distributed backend which will take care of synchronizing nodes/GPUs
torch.distributed.init_process_group(backend='nccl')
logger.info("device: {} n_gpu: {} distributed training: {}".format(device, n_gpu, bool(args.local_rank != -1)))
layer_indexes = [int(x) for x in args.layers.split(",")]
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
examples = read_examples(args.input_file)
features = convert_examples_to_features(
examples=examples, seq_length=args.max_seq_length, tokenizer=tokenizer)
unique_id_to_feature = {}
for feature in features:
unique_id_to_feature[feature.unique_id] = feature
model = BertModel.from_pretrained(args.bert_model)
model.to(device)
if args.local_rank != -1:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
output_device=args.local_rank)
elif n_gpu > 1:
model = torch.nn.DataParallel(model)
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
eval_data = TensorDataset(all_input_ids, all_input_mask, all_example_index)
if args.local_rank == -1:
eval_sampler = SequentialSampler(eval_data)
else:
eval_sampler = DistributedSampler(eval_data)
eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.batch_size)
model.eval()
with open(args.output_file, "w", encoding='utf-8') as writer:
for input_ids, input_mask, example_indices in eval_dataloader:
input_ids = input_ids.to(device)
input_mask = input_mask.to(device)
all_encoder_layers, _ = model(input_ids, token_type_ids=None, attention_mask=input_mask)
for b, example_index in enumerate(example_indices):
feature = features[example_index.item()]
unique_id = int(feature.unique_id)
# feature = unique_id_to_feature[unique_id]
output_json = collections.OrderedDict()
output_json["linex_index"] = unique_id
all_out_features = []
for (i, token) in enumerate(feature.tokens):
all_layers = []
for (j, layer_index) in enumerate(layer_indexes):
layer_output = all_encoder_layers[int(layer_index)].detach().cpu().numpy()
layer_output = layer_output[b]
layers = collections.OrderedDict()
layers["index"] = layer_index
layers["values"] = [
round(x.item(), 6) for x in layer_output[i]
]
all_layers.append(layers)
out_features = collections.OrderedDict()
out_features["token"] = token
out_features["layers"] = all_layers
all_out_features.append(out_features)
output_json["features"] = all_out_features
writer.write(json.dumps(output_json) + "\n")
if __name__ == "__main__":
main()
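# A minimal sketch (hypothetical file names) of the I/O conventions above:
# read_examples() takes one example per line, optionally "text_a ||| text_b",
# and main() writes one JSON object per input example with per-token,
# per-layer feature vectors.
import json

with open("input.txt", "w", encoding="utf-8") as f:
    f.write("Who wrote Hamlet ? ||| Shakespeare wrote it .\n")  # sentence pair
    f.write("A single sentence with no pair .\n")               # single sentence

# After running the script with --input_file input.txt --output_file output.jsonl:
with open("output.jsonl", encoding="utf-8") as f:
    for line in f:
        record = json.loads(line)
        first_token = record["features"][0]          # entry for "[CLS]"
        vec = first_token["layers"][0]["values"]     # hidden states of layer index -1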
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utilities for working with the local dataset cache.
This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
Copyright by the AllenNLP authors.
"""
from __future__ import (absolute_import, division, print_function, unicode_literals)
import json
import logging
import os
import shutil
import tempfile
from functools import wraps
from hashlib import sha256
import sys
from io import open
import boto3
import requests
from botocore.exceptions import ClientError
from tqdm import tqdm
try:
from urllib.parse import urlparse
except ImportError:
from urlparse import urlparse
try:
from pathlib import Path
PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
Path.home() / '.pytorch_pretrained_bert'))
except AttributeError:
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert'))
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
def url_to_filename(url, etag=None):
"""
Convert `url` into a hashed filename in a repeatable way.
If `etag` is specified, append its hash to the url's, delimited
by a period.
"""
url_bytes = url.encode('utf-8')
url_hash = sha256(url_bytes)
filename = url_hash.hexdigest()
if etag:
etag_bytes = etag.encode('utf-8')
etag_hash = sha256(etag_bytes)
filename += '.' + etag_hash.hexdigest()
return filename
def filename_to_url(filename, cache_dir=None):
"""
Return the url and etag (which may be ``None``) stored for `filename`.
Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
"""
if cache_dir is None:
cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
cache_path = os.path.join(cache_dir, filename)
if not os.path.exists(cache_path):
raise EnvironmentError("file {} not found".format(cache_path))
meta_path = cache_path + '.json'
if not os.path.exists(meta_path):
raise EnvironmentError("file {} not found".format(meta_path))
with open(meta_path, encoding="utf-8") as meta_file:
metadata = json.load(meta_file)
url = metadata['url']
etag = metadata['etag']
return url, etag
def cached_path(url_or_filename, cache_dir=None, from_tf=False):
"""
Given something that might be a URL (or might be a local path),
determine which. If it's a URL, download the file and cache it, and
return the path to the cached file. If it's already a local path,
make sure the file exists and then return the path.
"""
if cache_dir is None:
cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
if sys.version_info[0] == 3 and isinstance(url_or_filename, Path):
url_or_filename = str(url_or_filename)
if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
parsed = urlparse(url_or_filename)
# if not os.path.exists(url_or_filename):
# raise ValueError("Local cached file does not exist: {}".format(parsed))
if parsed.scheme in ('http', 'https', 's3'):
# URL, so get it from the cache (downloading if necessary)
return get_from_cache(url_or_filename, cache_dir)
elif os.path.exists(url_or_filename):
# File, and it exists.
return url_or_filename
elif from_tf and os.path.exists(url_or_filename + ".meta"):
# TF checkpoint exists
return url_or_filename
elif parsed.scheme == '':
# File, but it doesn't exist.
raise EnvironmentError("file {} not found".format(url_or_filename))
else:
# Something unknown
raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
def split_s3_path(url):
"""Split a full s3 path into the bucket name and path."""
parsed = urlparse(url)
if not parsed.netloc or not parsed.path:
raise ValueError("bad s3 path {}".format(url))
bucket_name = parsed.netloc
s3_path = parsed.path
# Remove '/' at beginning of path.
if s3_path.startswith("/"):
s3_path = s3_path[1:]
return bucket_name, s3_path
def s3_request(func):
"""
Wrapper function for s3 requests in order to create more helpful error
messages.
"""
@wraps(func)
def wrapper(url, *args, **kwargs):
try:
return func(url, *args, **kwargs)
except ClientError as exc:
if int(exc.response["Error"]["Code"]) == 404:
raise EnvironmentError("file {} not found".format(url))
else:
raise
return wrapper
@s3_request
def s3_etag(url):
"""Check ETag on S3 object."""
s3_resource = boto3.resource("s3")
bucket_name, s3_path = split_s3_path(url)
s3_object = s3_resource.Object(bucket_name, s3_path)
return s3_object.e_tag
@s3_request
def s3_get(url, temp_file):
"""Pull a file directly from S3."""
s3_resource = boto3.resource("s3")
bucket_name, s3_path = split_s3_path(url)
s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
def http_get(url, temp_file):
req = requests.get(url, stream=True)
content_length = req.headers.get('Content-Length')
total = int(content_length) if content_length is not None else None
progress = tqdm(unit="B", total=total)
for chunk in req.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks
progress.update(len(chunk))
temp_file.write(chunk)
progress.close()
def get_from_cache(url, cache_dir=None):
"""
Given a URL, look for the corresponding dataset in the local cache.
If it's not there, download it. Then return the path to the cached file.
"""
if cache_dir is None:
cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
if not os.path.exists(cache_dir):
os.makedirs(cache_dir)
# Get eTag to add to filename, if it exists.
if url.startswith("s3://"):
etag = s3_etag(url)
else:
response = requests.head(url, allow_redirects=True)
if response.status_code != 200:
raise IOError("HEAD request failed for url {} with status code {}"
.format(url, response.status_code))
etag = response.headers.get("ETag")
filename = url_to_filename(url, etag)
# get cache path to put the file
cache_path = os.path.join(cache_dir, filename)
if not os.path.exists(cache_path):
raise ValueError("local cached file {} doesn't exist".format(cache_path))
# Download to temporary file, then copy to cache dir once finished.
# Otherwise you get corrupt cache entries if the download gets interrupted.
with tempfile.NamedTemporaryFile() as temp_file:
logger.info("%s not found in cache, downloading to %s", url, temp_file.name)
# GET file object
if url.startswith("s3://"):
s3_get(url, temp_file)
else:
http_get(url, temp_file)
# we are copying the file before closing it, so flush to avoid truncation
temp_file.flush()
# shutil.copyfileobj() starts at the current position, so go to the start
temp_file.seek(0)
logger.info("copying %s to cache at %s", temp_file.name, cache_path)
with open(cache_path, 'wb') as cache_file:
shutil.copyfileobj(temp_file, cache_file)
logger.info("creating metadata file for %s", cache_path)
meta = {'url': url, 'etag': etag}
meta_path = cache_path + '.json'
with open(meta_path, 'w', encoding="utf-8") as meta_file:
json.dump(meta, meta_file)
logger.info("removing temp file %s", temp_file.name)
return cache_path
def read_set_from_file(filename):
'''
Extract a de-duped collection (set) of text from a file.
Expected file format is one item per line.
'''
collection = set()
with open(filename, 'r', encoding='utf-8') as file_:
for line in file_:
collection.add(line.rstrip())
return collection
def get_file_extension(path, dot=True, lower=True):
ext = os.path.splitext(path)[1]
ext = ext if dot else ext[1:]
return ext.lower() if lower else ext
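# A small sketch exercising the pure helpers above (hypothetical URL/paths): the
# cache file name is the SHA-256 of the URL, optionally suffixed with the hash
# of the ETag, and s3 URLs split into (bucket, key).
if __name__ == "__main__":
    url = "https://example.com/bert-base-uncased.tar.gz"
    print(url_to_filename(url))                    # sha256(url)
    print(url_to_filename(url, etag='"abc123"'))   # sha256(url) + "." + sha256(etag)
    print(split_s3_path("s3://my-bucket/models/bert/pytorch_model.bin"))
    # cached_path(url) would resolve this URL to a file under PYTORCH_PRETRAINED_BERT_CACHE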
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" BERT inference script. Does not depend on dataset. """
from __future__ import absolute_import, division, print_function
import argparse
import collections
import json
import logging
import math
import os
import random
import sys
from io import open
import numpy as np
import torch
from tqdm import tqdm, trange
from types import SimpleNamespace
from file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from modeling import BertForQuestionAnswering, BertConfig, WEIGHTS_NAME, CONFIG_NAME
from tokenization import (BasicTokenizer, BertTokenizer, whitespace_tokenize)
if sys.version_info[0] == 2:
import cPickle as pickle
else:
import pickle
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt='%m/%d/%Y %H:%M:%S',
level=logging.INFO)
logger = logging.getLogger(__name__)
def preprocess_tokenized_text(doc_tokens, query_tokens, tokenizer,
max_seq_length, max_query_length):
""" converts an example into a feature """
if len(query_tokens) > max_query_length:
query_tokens = query_tokens[0:max_query_length]
tok_to_orig_index = []
orig_to_tok_index = []
all_doc_tokens = []
for (i, token) in enumerate(doc_tokens):
orig_to_tok_index.append(len(all_doc_tokens))
sub_tokens = tokenizer.tokenize(token)
for sub_token in sub_tokens:
tok_to_orig_index.append(i)
all_doc_tokens.append(sub_token)
# The -3 accounts for [CLS], [SEP] and [SEP]
max_tokens_for_doc = max_seq_length - len(query_tokens) - 3
# truncate if too long
length = len(all_doc_tokens)
length = min(length, max_tokens_for_doc)
tokens = []
token_to_orig_map = {}
token_is_max_context = {}
segment_ids = []
tokens.append("[CLS]")
segment_ids.append(0)
for token in query_tokens:
tokens.append(token)
segment_ids.append(0)
tokens.append("[SEP]")
segment_ids.append(0)
for i in range(length):
token_to_orig_map[len(tokens)] = tok_to_orig_index[i]
token_is_max_context[len(tokens)] = True
tokens.append(all_doc_tokens[i])
segment_ids.append(1)
tokens.append("[SEP]")
segment_ids.append(1)
input_ids = tokenizer.convert_tokens_to_ids(tokens)
# The mask has 1 for real tokens and 0 for padding tokens. Only real
# tokens are attended to.
input_mask = [1] * len(input_ids)
# Zero-pad up to the sequence length.
while len(input_ids) < max_seq_length:
input_ids.append(0)
input_mask.append(0)
segment_ids.append(0)
assert len(input_ids) == max_seq_length
assert len(input_mask) == max_seq_length
assert len(segment_ids) == max_seq_length
tensors_for_inference = {
'input_ids': input_ids,
'input_mask': input_mask,
'segment_ids': segment_ids
}
tensors_for_inference = SimpleNamespace(**tensors_for_inference)
tokens_for_postprocessing = {
'tokens': tokens,
'token_to_orig_map': token_to_orig_map,
'token_is_max_context': token_is_max_context
}
tokens_for_postprocessing = SimpleNamespace(**tokens_for_postprocessing)
return tensors_for_inference, tokens_for_postprocessing
RawResult = collections.namedtuple("RawResult", ["start_logits", "end_logits"])
def get_predictions(doc_tokens, tokens_for_postprocessing,
start_logits, end_logits, n_best_size,
max_answer_length, do_lower_case,
can_give_negative_answer, null_score_diff_threshold):
""" Write final predictions to the json file and log-odds of null if needed. """
result = RawResult(start_logits=start_logits, end_logits=end_logits)
_PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
"PrelimPrediction",
["start_index", "end_index", "start_logit", "end_logit"])
prelim_predictions = []
# keep track of the minimum score of null start+end of position 0
score_null = 1000000 # large and positive
null_start_logit = 0 # the start logit at the slice with min null score
null_end_logit = 0 # the end logit at the slice with min null score
start_indices = _get_indices_of_largest_logits(result.start_logits)
end_indices = _get_indices_of_largest_logits(result.end_logits)
# if we could have irrelevant answers, get the min score of irrelevant
if can_give_negative_answer:
feature_null_score = result.start_logits[0] + result.end_logits[0]
if feature_null_score < score_null:
score_null = feature_null_score
null_start_logit = result.start_logits[0]
null_end_logit = result.end_logits[0]
for start_index in start_indices:
for end_index in end_indices:
# We could hypothetically create invalid predictions, e.g., predict
# that the start of the span is in the question. We throw out all
# invalid predictions.
if start_index >= len(tokens_for_postprocessing.tokens):
continue
if end_index >= len(tokens_for_postprocessing.tokens):
continue
if start_index not in tokens_for_postprocessing.token_to_orig_map:
continue
if end_index not in tokens_for_postprocessing.token_to_orig_map:
continue
if not tokens_for_postprocessing.token_is_max_context.get(start_index, False):
continue
if end_index < start_index:
continue
length = end_index - start_index + 1
if length > max_answer_length:
continue
prelim_predictions.append(
_PrelimPrediction(
start_index=start_index,
end_index=end_index,
start_logit=result.start_logits[start_index],
end_logit=result.end_logits[end_index]
)
)
if can_give_negative_answer:
prelim_predictions.append(
_PrelimPrediction(
start_index=0,
end_index=0,
start_logit=null_start_logit,
end_logit=null_end_logit
)
)
prelim_predictions = sorted(
prelim_predictions,
key=lambda x: (x.start_logit + x.end_logit),
reverse=True
)
_NbestPrediction = collections.namedtuple("NbestPrediction", ["text", "start_logit", "end_logit"])
seen_predictions = {}
nbest = []
for pred in prelim_predictions:
if len(nbest) >= n_best_size:
break
if pred.start_index > 0: # this is a non-null prediction
tok_tokens = tokens_for_postprocessing.tokens[pred.start_index:(pred.end_index + 1)]
orig_doc_start = tokens_for_postprocessing.token_to_orig_map[pred.start_index]
orig_doc_end = tokens_for_postprocessing.token_to_orig_map[pred.end_index]
orig_tokens = doc_tokens[orig_doc_start:(orig_doc_end + 1)]
tok_text = " ".join(tok_tokens)
# de-tokenize WordPieces that have been split off.
tok_text = tok_text.replace(" ##", "")
tok_text = tok_text.replace("##", "")
# clean whitespace
tok_text = tok_text.strip()
tok_text = " ".join(tok_text.split())
orig_text = " ".join(orig_tokens)
# get final text
final_text = get_final_text(tok_text, orig_text, do_lower_case)
if final_text in seen_predictions:
continue
# mark it
seen_predictions[final_text] = True
else: # this is a null prediction
final_text = ""
seen_predictions[final_text] = True
nbest.append(
_NbestPrediction(
text=final_text,
start_logit=pred.start_logit,
end_logit=pred.end_logit
)
)
# if we didn't include the empty option in the n-best, include it
if can_give_negative_answer:
if "" not in seen_predictions:
nbest.append(
_NbestPrediction(
text="",
start_logit=null_start_logit,
end_logit=null_end_logit
)
)
# In very rare edge cases we could only have single null prediction.
# So we just create a nonce prediction in this case to avoid failure.
if len(nbest) == 1:
nbest.insert(0, _NbestPrediction(text="", start_logit=0.0, end_logit=0.0))
# In very rare edge cases we could have no valid predictions. So we
# just create a nonce prediction in this case to avoid failure.
if not nbest:
nbest.append(_NbestPrediction(text="", start_logit=0.0, end_logit=0.0))
assert len(nbest) >= 1
# scoring
total_scores = []
best_non_null_entry = None
for entry in nbest:
total_scores.append(entry.start_logit + entry.end_logit)
if not best_non_null_entry:
if entry.text:
best_non_null_entry = entry
# get probabilities
probs = _compute_softmax(total_scores)
# nbest predictions into json format
nbest_json = []
for (i, entry) in enumerate(nbest):
output = collections.OrderedDict()
output["text"] = entry.text
output["probability"] = probs[i]
output["start_logit"] = entry.start_logit
output["end_logit"] = entry.end_logit
nbest_json.append(output)
assert len(nbest_json) >= 1
if can_give_negative_answer:
# predict "unknown" iff ((score_null - score_of_best_non-null_entry) > threshold)
score = best_non_null_entry.start_logit + best_non_null_entry.end_logit
score_diff = score_null - score
if score_diff > null_score_diff_threshold:
nbest_json[0]['text'] = "unknown"
# best_non_null_entry.text = "unknown"
#
return nbest_json
def get_final_text(pred_text, orig_text, do_lower_case):
"""Project the tokenized prediction back to the original text."""
# When we created the data, we kept track of the alignment between original
# (whitespace tokenized) tokens and our WordPiece tokenized tokens. So
# now `orig_text` contains the span of our original text corresponding to the
# span that we predicted.
#
# However, `orig_text` may contain extra characters that we don't want in
# our prediction.
#
# For example, let's say:
# pred_text = steve smith
# orig_text = Steve Smith's
#
# We don't want to return `orig_text` because it contains the extra "'s".
#
# We don't want to return `pred_text` because it's already been normalized
# (the SQuAD eval script also does punctuation stripping/lower casing but
# our tokenizer does additional normalization like stripping accent
# characters).
#
# What we really want to return is "Steve Smith".
#
# Therefore, we have to apply a semi-complicated alignment heuristic between
# `pred_text` and `orig_text` to get a character-to-character alignment. This
# can fail in certain cases in which case we just return `orig_text`.
def _strip_spaces(text):
ns_chars = []
ns_to_s_map = collections.OrderedDict()
for (i, c) in enumerate(text):
if c == " ":
continue
ns_to_s_map[len(ns_chars)] = i
ns_chars.append(c)
ns_text = "".join(ns_chars)
return (ns_text, ns_to_s_map)
# We first tokenize `orig_text`, strip whitespace from the result
# and `pred_text`, and check if they are the same length. If they are
# NOT the same length, the heuristic has failed. If they are the same
# length, we assume the characters are one-to-one aligned.
tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
tok_text = " ".join(tokenizer.tokenize(orig_text))
start_position = tok_text.find(pred_text)
if start_position == -1:
return orig_text
end_position = start_position + len(pred_text) - 1
(orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text)
(tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text)
if len(orig_ns_text) != len(tok_ns_text):
return orig_text
# We then project the characters in `pred_text` back to `orig_text` using
# the character-to-character alignment.
tok_s_to_ns_map = {}
for (i, tok_index) in tok_ns_to_s_map.items():
tok_s_to_ns_map[tok_index] = i
orig_start_position = None
if start_position in tok_s_to_ns_map:
ns_start_position = tok_s_to_ns_map[start_position]
if ns_start_position in orig_ns_to_s_map:
orig_start_position = orig_ns_to_s_map[ns_start_position]
if orig_start_position is None:
return orig_text
orig_end_position = None
if end_position in tok_s_to_ns_map:
ns_end_position = tok_s_to_ns_map[end_position]
if ns_end_position in orig_ns_to_s_map:
orig_end_position = orig_ns_to_s_map[ns_end_position]
if orig_end_position is None:
return orig_text
output_text = orig_text[orig_start_position:(orig_end_position + 1)]
return output_text
def _compute_softmax(scores):
"""Compute softmax probability over raw logits."""
if not scores:
return []
max_score = None
for score in scores:
if max_score is None or score > max_score:
max_score = score
exp_scores = []
total_sum = 0.0
for score in scores:
x = math.exp(score - max_score)
exp_scores.append(x)
total_sum += x
probs = []
for score in exp_scores:
probs.append(score / total_sum)
return probs
def _get_indices_of_largest_logits(logits):
""" sort logits and return the indices of the sorted array """
indices_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)
indices = map(lambda x: x[0], indices_and_score)
indices = list(indices)
return indices
def main():
parser = argparse.ArgumentParser()
## Required parameters
parser.add_argument("--bert_model", default=None, type=str, required=True,
help="Bert pre-trained model selected in the list: bert-base-uncased, "
"bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
"bert-base-multilingual-cased, bert-base-chinese.")
parser.add_argument("--init_checkpoint",
default=None,
type=str,
required=True,
help="The checkpoint file from pretraining")
## Other parameters
parser.add_argument("--seed", default=1, type=int)
parser.add_argument("--question", default="Most antibiotics target bacteria and don't affect what class of organisms? ",
type=str, help="question")
parser.add_argument("--context", default="Within the genitourinary and gastrointestinal tracts, commensal flora serve as biological barriers by competing with pathogenic bacteria for food and space and, in some cases, by changing the conditions in their environment, such as pH or available iron. This reduces the probability that pathogens will reach sufficient numbers to cause illness. However, since most antibiotics non-specifically target bacteria and do not affect fungi, oral antibiotics can lead to an overgrowth of fungi and cause conditions such as a vaginal candidiasis (a yeast infection). There is good evidence that re-introduction of probiotic flora, such as pure cultures of the lactobacilli normally found in unpasteurized yogurt, helps restore a healthy balance of microbial populations in intestinal infections in children and encouraging preliminary data in studies on bacterial gastroenteritis, inflammatory bowel diseases, urinary tract infection and post-surgical infections. ",
type=str, help="context")
parser.add_argument("--max_seq_length", default=384, type=int,
help="The maximum total input sequence length after WordPiece tokenization. Sequences "
"longer than this will be truncated, and sequences shorter than this will be padded.")
parser.add_argument("--max_query_length", default=64, type=int,
help="The maximum number of tokens for the question. Questions longer than this will "
"be truncated to this length.")
parser.add_argument("--n_best_size", default=1, type=int,
help="The total number of n-best predictions to generate. ")
parser.add_argument("--max_answer_length", default=30, type=int,
help="The maximum length of an answer that can be generated. This is needed because the start "
"and end predictions are not conditioned on one another.")
parser.add_argument("--no_cuda",
action='store_true',
help="Whether not to use CUDA when available")
parser.add_argument("--do_lower_case",
action='store_true',
help="Whether to lower case the input text. True for uncased models, False for cased models.")
parser.add_argument('--can_give_negative_answer',
action='store_true',
help='If true, then the model can reply with "unknown". ')
parser.add_argument('--null_score_diff_threshold',
type=float, default=-11.0,
help="If null_score - best_non_null is greater than the threshold predict 'unknown'. ")
parser.add_argument('--vocab_file',
type=str, default=None, required=True,
help="Vocabulary mapping/file BERT was pretrainined on")
parser.add_argument("--config_file",
default=None,
type=str,
required=True,
help="The BERT model config")
parser.add_argument('--fp16',
action='store_true',
help="use mixed-precision")
parser.add_argument("--local_rank", default=-1, help="ordinal of the GPU to use")
args = parser.parse_args()
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
if args.local_rank == -1 or args.no_cuda:
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
else:
torch.cuda.set_device(args.local_rank)
device = torch.device("cuda", args.local_rank)
tokenizer = BertTokenizer(args.vocab_file, do_lower_case=args.do_lower_case, max_len=512) # for bert large
# Prepare model
config = BertConfig.from_json_file(args.config_file)
# Padding for divisibility by 8
if config.vocab_size % 8 != 0:
config.vocab_size += 8 - (config.vocab_size % 8)
# initialize model
model = BertForQuestionAnswering(config)
model.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu')["model"])
model.to(device)
if args.fp16:
model.half()
model.eval()
print("question: ", args.question)
print("context: ", args.context)
print()
# preprocessing
doc_tokens = args.context.split()
query_tokens = tokenizer.tokenize(args.question)
feature = preprocess_tokenized_text(doc_tokens,
query_tokens,
tokenizer,
max_seq_length=args.max_seq_length,
max_query_length=args.max_query_length)
tensors_for_inference, tokens_for_postprocessing = feature
input_ids = torch.tensor(tensors_for_inference.input_ids, dtype=torch.long).unsqueeze(0)
segment_ids = torch.tensor(tensors_for_inference.segment_ids, dtype=torch.long).unsqueeze(0)
input_mask = torch.tensor(tensors_for_inference.input_mask, dtype=torch.long).unsqueeze(0)
# load tensors to device
input_ids = input_ids.to(device)
input_mask = input_mask.to(device)
segment_ids = segment_ids.to(device)
# run prediction
with torch.no_grad():
start_logits, end_logits = model(input_ids, segment_ids, input_mask)
# post-processing
start_logits = start_logits[0].detach().cpu().tolist()
end_logits = end_logits[0].detach().cpu().tolist()
answer = get_predictions(doc_tokens, tokens_for_postprocessing,
start_logits, end_logits, args.n_best_size,
args.max_answer_length, args.do_lower_case,
args.can_give_negative_answer,
args.null_score_diff_threshold)
# print result
print(json.dumps(answer, indent=4))
if __name__ == "__main__":
main()
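# A small numeric check of _compute_softmax above (values rounded): the max
# logit is subtracted for numerical stability and the probabilities sum to 1.
#
#   _compute_softmax([2.0, 1.0, 0.1]) -> [0.6590, 0.2424, 0.0986]   (sum = 1.0)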
import torch
from .fused_gelu import bias_gelu_impl
__all__ = ["bias_gelu_impl"]
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)
# 1/sqrt(2*pi)-> 0.3989423
# 1/sqrt(2) -> 0.70710678
# sqrt(2/pi) -> 0.79788456
# this function is tanh approximation of gelu
# actual gelu is:
# x * 0.5 * (1.0 + torch.erf(x * 0.70710678))
@torch.jit.script
def bias_gelu(bias, y):
x = bias + y
return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@torch.jit.script
def bias_gelu_back(g, bias, y):
x = bias + y
tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
# sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
return ff*g
class GeLUFunction(torch.autograd.Function):
@staticmethod
# bias is an optional argument
def forward(ctx, input, bias):
ctx.save_for_backward(input, bias)
return bias_gelu(bias, input)
@staticmethod
def backward(ctx, grad_output):
input, bias = ctx.saved_tensors
tmp = bias_gelu_back(grad_output, bias, input)
return tmp, tmp
bias_gelu_impl = GeLUFunction.apply
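# A minimal sketch (not part of the original file) checking the tanh
# approximation above against the exact erf-based GELU quoted in the comments;
# bias_gelu_impl wires bias_gelu / bias_gelu_back into autograd via GeLUFunction.
def _gelu_approx_check():
    x = torch.randn(4, 8)
    b = torch.randn(8)
    approx = bias_gelu(b, x)                                         # tanh approximation
    exact = (x + b) * 0.5 * (1.0 + torch.erf((x + b) * 0.70710678))  # exact GELU
    print((approx - exact).abs().max())                              # small, roughly <= 1e-3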
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from apex.contrib.multihead_attn import fast_mask_softmax_dropout_func
from bmm1 import *
from bmm2 import *
from padding import *
from softmax import *
class FastUnpadBertSelfAttention(nn.Module):
def __init__(self, config, enable_stream=True, enable_sync=True, fuse_mask=True, fuse_scale=True, fuse_qkv=True, fuse_dropout=True, apex_softmax=True, pad=True):
super(FastUnpadBertSelfAttention, self).__init__()
if config.hidden_size % config.num_attention_heads != 0:
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (config.hidden_size, config.num_attention_heads))
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.hidden_size = config.hidden_size
self.fuse_qkv = fuse_qkv
self.fuse_scale = fuse_scale
self.fuse_mask = fuse_mask
self.fuse_dropout = fuse_dropout
self.apex_softmax = apex_softmax
self.pad = pad
self.enable_stream = enable_stream
self.query = nn.Linear(config.hidden_size, self.all_head_size)
self.key = nn.Linear(config.hidden_size, self.all_head_size)
self.value = nn.Linear(config.hidden_size, self.all_head_size)
if self.fuse_qkv:
self.bmm1 = Bmm1Strided(None,None,self.num_attention_heads,self.attention_head_size, scale=self.fuse_scale, stream=enable_stream, sync=enable_sync, timer=False)
self.bmm2 = Bmm2Strided(None,None,self.num_attention_heads,self.attention_head_size, stream=enable_stream, sync=enable_sync, timer=False)
else:
self.bmm1 = Bmm1(None,None,self.num_attention_heads,self.attention_head_size, scale=self.fuse_scale, stream=enable_stream, sync=enable_sync)
self.bmm2 = Bmm2(None,None,self.num_attention_heads,self.attention_head_size, stream=enable_stream, sync=enable_sync)
if self.fuse_dropout == False:
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
if self.fuse_mask == True and self.fuse_dropout == True:
self.softmax = FastMaskSoftmaxDropout(dim=-1, dropout_prob=config.attention_probs_dropout_prob,stream=enable_stream, sync=(not self.pad), timer=False)
elif self.fuse_mask == True:
self.softmax = FastMaskSoftmax(dim=-1, stream=enable_stream, sync=enable_sync, timer=False)
else:
self.softmax = FastSoftmax(dim=-1, stream=enable_stream, sync=enable_sync, timer=False)
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = torch.reshape(x, new_x_shape)
return x.permute(0, 2, 1, 3)
def transpose_key_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = torch.reshape(x, new_x_shape)
return x.permute(0, 2, 3, 1)
def pytorch_softmax(self,attention_scores, batch, seqlen, heads):
ntokens2 = 0
for i in range(batch):
ntokens2 += seqlen[i]*seqlen[i]*self.num_attention_heads
attention_probs = torch.zeros(ntokens2, device="cuda", dtype=torch.float16)
ntokens2 = 0
for i in range(batch):
tokens2 = seqlen[i]*seqlen[i]*self.num_attention_heads
attention_probs[ntokens2:ntokens2+tokens2] = F.softmax(attention_scores[ntokens2:ntokens2+tokens2].view(1,self.num_attention_heads,seqlen[i],seqlen[i]), dim=-1).flatten().contiguous()
ntokens2 += tokens2
return attention_probs
def forward(self, hidden_states, attention_mask, seqlen, batch, is_training=True):
self.batch = batch
# QKV
if self.fuse_qkv:
weight = torch.cat([self.query.weight.view(self.num_attention_heads,self.attention_head_size,1,self.hidden_size), self.key.weight.view(self.num_attention_heads,self.attention_head_size,1,self.hidden_size), self.value.weight.view(self.num_attention_heads,self.attention_head_size,1,self.hidden_size)], dim=1).reshape(self.all_head_size*3,self.hidden_size).contiguous()
bias = torch.cat([self.query.bias.view(self.num_attention_heads,1,self.attention_head_size), self.key.bias.view(self.num_attention_heads,1,self.attention_head_size), self.value.bias.view(self.num_attention_heads,1,self.attention_head_size)],dim=1).reshape(3*self.hidden_size).contiguous()
mixed_x_layer = torch.addmm(bias, hidden_states, weight.t())
else:
query_layer = self.query(hidden_states)
key_layer = self.key(hidden_states)
value_layer = self.value(hidden_states)
# BMM1.
if self.enable_stream: torch.cuda.synchronize()
if self.fuse_qkv:
attention_scores, qkv_layer = self.bmm1(mixed_x_layer, self.batch, seqlen)
else:
attention_scores = self.bmm1(query_layer, key_layer, self.batch, seqlen)
if self.enable_stream: torch.cuda.synchronize()
if self.fuse_scale == False:
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
# Softmax.
if self.enable_stream: torch.cuda.synchronize()
if self.fuse_mask ==True and self.fuse_dropout == True:
attention_probs = self.softmax(attention_scores, attention_mask, self.batch, seqlen, self.num_attention_heads, is_training)
elif self.fuse_mask == True:
attention_probs = self.softmax(attention_scores, attention_mask, self.batch, seqlen, self.num_attention_heads)
else:
attention_scores = attention_scores + attention_mask.view(-1)
if self.apex_softmax == True:
attention_probs = self.softmax(attention_scores, self.batch, seqlen, self.num_attention_heads)
else:
if self.pad == True:
attention_probs = F.softmax(attention_scores.view(batch,self.num_attention_heads,seqlen[0],seqlen[0]), dim=-1).flatten().contiguous()
else:
attention_probs = self.pytorch_softmax(attention_scores, self.batch, seqlen, self.num_attention_heads)
# Dropout.
if self.enable_stream: torch.cuda.synchronize()
if self.fuse_dropout == False:
attention_probs = self.dropout(attention_probs)
# BMM2.
if self.enable_stream: torch.cuda.synchronize()
if self.fuse_qkv:
context_layer = self.bmm2(attention_probs, qkv_layer, self.batch, seqlen)
else:
context_layer = self.bmm2(attention_probs, value_layer, self.batch, seqlen)
if self.enable_stream: torch.cuda.synchronize()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = torch.reshape(context_layer, new_context_layer_shape)
return context_layer
#include <vector>
#include <iostream>
#include <ATen/ATen.h>
#include <cuda.h>
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
#include "THC/THC.h"
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <math.h>
#include "/opt/pytorch/apex/apex/contrib/csrc/multihead_attn/softmax.h"
#define nstreams 16
// global variables.
cudaStream_t stream[nstreams];
cublasHandle_t handle;
///////////////////////////////////////////////////////////////////////////////////////////////////
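// FastBmm1Fprop_ launches, for each sequence i in the batch, a per-head batched GEMM that
// forms the attention scores (query x key^T over seqlen[i] tokens), pre-scaled by
// 1/sqrt(embed) when `scale` is set. With `strided`, the A/B pointers walk a packed QKV
// buffer (token stride 3*embed, key at offset embed); with `enable_stream`, each
// sequence's GEMM is issued on one of the `nstreams` CUDA streams and optionally
// synchronized at the end. The Bmm2/Dgrad variants below follow the same per-sequence
// pointer-walking pattern for the context GEMM and the backward GEMMs.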
void FastBmm1Fprop_(torch::Tensor &A,
torch::Tensor &B,
torch::Tensor &C,
int batch,
torch::Tensor &seq_len,
int heads,
int embed,
bool scale,
bool strided,
bool enable_stream,
bool sync)
{
float one = 1.0, zero = 0.0, alpha = 1.0 / sqrt(static_cast<float>(embed));
int *seqlen = static_cast<int*>(seq_len.data_ptr());
void *ptrA = static_cast<void*>(static_cast<half*>(A.data_ptr()) + (strided ? embed : 0)); // key
void *ptrB = static_cast<void*>(static_cast<half*>(B.data_ptr())); // query
void *ptrC = static_cast<void*>(static_cast<half*>(C.data_ptr())); // output
for(int i = 0; i < (enable_stream ? batch : 1); i++) {
cublasSetStream(handle, enable_stream ? stream[i%nstreams]: at::cuda::getCurrentCUDAStream());
cublasGemmStridedBatchedEx(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
seqlen[i],
seqlen[i],
embed,
static_cast<const void*>(scale ? &alpha : &one),
ptrA,
CUDA_R_16F,
(enable_stream ? 1 : batch) * (strided ? heads*3*embed : heads*embed),
strided ? 3*embed : embed,
ptrB,
CUDA_R_16F,
(enable_stream ? 1 : batch) * (strided ? heads*3*embed : heads*embed),
strided ? 3*embed : embed,
static_cast<const void*>(&zero),
ptrC,
CUDA_R_16F,
seqlen[i],
seqlen[i]*seqlen[i],
enable_stream ? heads : batch*heads,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
ptrA = static_cast<void*>(static_cast<half*>(ptrA) + (strided ? seqlen[i]*heads*3*embed : seqlen[i]*heads*embed));
ptrB = static_cast<void*>(static_cast<half*>(ptrB) + (strided ? seqlen[i]*heads*3*embed : seqlen[i]*heads*embed));
ptrC = static_cast<void*>(static_cast<half*>(ptrC) + heads*seqlen[i]*seqlen[i]);
}
for(int i = 0; i < (enable_stream ? nstreams : 0); i++) {
if(sync) cudaStreamSynchronize(stream[i]);
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////////////
void FastBmm2Fprop_(torch::Tensor &A,
torch::Tensor &B,
torch::Tensor &C,
int batch,
torch::Tensor &seq_len,
int heads,
int embed,
bool scale,
bool strided,
bool enable_stream,
bool sync)
{
float one = 1.0, zero = 0.0;
int *seqlen = static_cast<int*>(seq_len.data_ptr());
void *ptrA = static_cast<void*>(static_cast<half*>(A.data_ptr()) + (strided ? 2*embed : 0)); // value
void *ptrB = static_cast<void*>(static_cast<half*>(B.data_ptr())); // query*key
void *ptrC = static_cast<void*>(static_cast<half*>(C.data_ptr())); // output
for(int i = 0; i < (enable_stream ? batch : 1); i++) {
cublasSetStream(handle, enable_stream ? stream[i%nstreams]: at::cuda::getCurrentCUDAStream());
cublasGemmStridedBatchedEx(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed,
seqlen[i],
seqlen[i],
static_cast<const void*>(&one),
ptrA,
CUDA_R_16F,
(enable_stream ? 1 : batch) * (strided ? heads*3*embed : heads*embed),
strided ? 3*embed : embed,
ptrB,
CUDA_R_16F,
seqlen[i],
seqlen[i]*seqlen[i],
static_cast<const void*>(&zero),
ptrC,
CUDA_R_16F,
enable_stream ? heads*embed : batch*heads*embed,
embed,
enable_stream ? heads : batch*heads,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
ptrA = static_cast<void*>(static_cast<half*>(ptrA) + (strided ? seqlen[i]*heads*3*embed : seqlen[i]*heads*embed));
ptrB = static_cast<void*>(static_cast<half*>(ptrB) + heads*seqlen[i]*seqlen[i]);
ptrC = static_cast<void*>(static_cast<half*>(ptrC) + seqlen[i]*heads*embed);
}
for(int i = 0; i < (enable_stream ? nstreams : 0); i++) {
if(sync) cudaStreamSynchronize(stream[i]);
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////////////
void FastBmm1Dgrad1_(torch::Tensor &A,
torch::Tensor &B,
torch::Tensor &C,
int batch,
torch::Tensor &seq_len,
int heads,
int embed,
bool scale,
bool strided,
bool enable_stream,
bool sync)
{
float one = 1.0, zero = 0.0, alpha = 1.0 / sqrt(static_cast<float>(embed));
int *seqlen = static_cast<int*>(seq_len.data_ptr());
void *ptrA = static_cast<void*>(static_cast<half*>(A.data_ptr())); // query
void *ptrB = static_cast<void*>(static_cast<half*>(B.data_ptr()));
void *ptrC = static_cast<void*>(static_cast<half*>(C.data_ptr()) + (strided ? embed : 0)); // grad_key
for(int i = 0; i < (enable_stream ? batch : 1); i++) {
cublasSetStream(handle, enable_stream ? stream[i%nstreams] : at::cuda::getCurrentCUDAStream());
cublasGemmStridedBatchedEx(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed,
seqlen[i],
seqlen[i],
static_cast<const void*>(scale ? &alpha : &one),
ptrA,
CUDA_R_16F,
(enable_stream ? 1 : batch) * (strided ? heads*3*embed : heads*embed),
strided ? 3*embed : embed,
ptrB,
CUDA_R_16F,
seqlen[i],
seqlen[i]*seqlen[i],
static_cast<const void*>(&zero),
ptrC,
CUDA_R_16F,
(enable_stream ? 1 : batch) * (strided ? heads*3*embed : heads*embed),
strided ? 3*embed : embed,
enable_stream ? heads : heads*batch,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
ptrA = static_cast<void*>(static_cast<half*>(ptrA) + (strided ? seqlen[i]*heads*3*embed : seqlen[i]*heads*embed));
ptrB = static_cast<void*>(static_cast<half*>(ptrB) + heads*seqlen[i]*seqlen[i]);
ptrC = static_cast<void*>(static_cast<half*>(ptrC) + (strided ? seqlen[i]*heads*3*embed : seqlen[i]*heads*embed));
}
for(int i = 0; i < (enable_stream ? nstreams : 0); i++) {
if(sync) cudaStreamSynchronize(stream[i]);
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////////////
void FastBmm2Dgrad1_(torch::Tensor &A,
torch::Tensor &B,
torch::Tensor &C,
int batch,
torch::Tensor &seq_len,
int heads,
int embed,
bool scale,
bool strided,
bool enable_stream,
bool sync)
{
float one = 1.0, zero = 0.0;
int *seqlen = static_cast<int*>(seq_len.data_ptr());
void *ptrA = static_cast<void*>(static_cast<half*>(A.data_ptr()) + (strided ? 2*embed : 0)); // value
void *ptrB = static_cast<void*>(static_cast<half*>(B.data_ptr()));
void *ptrC = static_cast<void*>(static_cast<half*>(C.data_ptr()));
for(int i = 0; i < (enable_stream ? batch : 1); i++) {
cublasSetStream(handle, enable_stream ? stream[i%nstreams] : at::cuda::getCurrentCUDAStream());
cublasGemmStridedBatchedEx(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
seqlen[i],
seqlen[i],
embed,
static_cast<const void*>(&one),
ptrA,
CUDA_R_16F,
(enable_stream ? 1 : batch) * (strided ? heads*3*embed : heads*embed),
strided ? 3*embed : embed,
ptrB,
CUDA_R_16F,
enable_stream ? heads*embed : batch*heads*embed,
embed,
static_cast<const void*>(&zero),
ptrC,
CUDA_R_16F,
seqlen[i],
seqlen[i]*seqlen[i],
enable_stream ? heads : batch*heads,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
ptrA = static_cast<void*>(static_cast<half*>(ptrA) + (strided ? seqlen[i]*heads*3*embed : seqlen[i]*heads*embed));
ptrB = static_cast<void*>(static_cast<half*>(ptrB) + seqlen[i]*heads*embed);
ptrC = static_cast<void*>(static_cast<half*>(ptrC) + heads*seqlen[i]*seqlen[i]);
}
for(int i = 0; i < (enable_stream ? nstreams : 0); i++) {
if(sync) cudaStreamSynchronize(stream[i]);
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////////////
void FastBmm1Dgrad2_(torch::Tensor &A,
torch::Tensor &B,
torch::Tensor &C,
int batch,
torch::Tensor &seq_len,
int heads,
int embed,
bool scale,
bool strided,
bool enable_stream,
bool sync)
{
float one = 1.0, zero = 0.0, alpha = 1.0 / sqrt(static_cast<float>(embed));
int *seqlen = static_cast<int*>(seq_len.data_ptr());
void *ptrA = static_cast<void*>(static_cast<half*>(A.data_ptr()) + (strided ? embed : 0)); // key
void *ptrB = static_cast<void*>(static_cast<half*>(B.data_ptr()));
void *ptrC = static_cast<void*>(static_cast<half*>(C.data_ptr())); // grad query
for(int i = 0; i < (enable_stream ? batch : 1); i++) {
cublasSetStream(handle, enable_stream ? stream[i%nstreams] : at::cuda::getCurrentCUDAStream());
cublasGemmStridedBatchedEx(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed,
seqlen[i],
seqlen[i],
static_cast<const void*>(scale ? &alpha : &one),
ptrA,
CUDA_R_16F,
(enable_stream ? 1 : batch) * (strided ? heads*3*embed : heads*embed),
strided ? 3*embed : embed,
ptrB,
CUDA_R_16F,
seqlen[i],
seqlen[i]*seqlen[i],
static_cast<const void*>(&zero),
ptrC,
CUDA_R_16F,
(enable_stream ? 1 : batch) * (strided ? heads*3*embed : heads*embed),
strided ? 3*embed : embed,
enable_stream ? heads : batch*heads,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
ptrA = static_cast<void*>(static_cast<half*>(ptrA) + (strided ? seqlen[i]*heads*3*embed : seqlen[i]*heads*embed));
ptrB = static_cast<void*>(static_cast<half*>(ptrB) + heads*seqlen[i]*seqlen[i]);
ptrC = static_cast<void*>(static_cast<half*>(ptrC) + (strided ? seqlen[i]*heads*3*embed : seqlen[i]*heads*embed));
}
for(int i = 0; i < (enable_stream ? nstreams : 0); i++) {
if(sync) cudaStreamSynchronize(stream[i]);
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////////////
void FastBmm2Dgrad2_(torch::Tensor &A,
torch::Tensor &B,
torch::Tensor &C,
int batch,
torch::Tensor &seq_len,
int heads,
int embed,
bool scale,
bool strided,
bool enable_stream,
bool sync)
{
float one = 1.0, zero = 0.0;
int *seqlen = static_cast<int*>(seq_len.data_ptr());
void *ptrA = static_cast<void*>(static_cast<half*>(A.data_ptr()));
void *ptrB = static_cast<void*>(static_cast<half*>(B.data_ptr()));
void *ptrC = static_cast<void*>(static_cast<half*>(C.data_ptr()) + (strided ? 2*embed : 0)); // grad-value
for(int i = 0; i < (enable_stream ? batch : 1); i++) {
cublasSetStream(handle, enable_stream ? stream[i%nstreams] : at::cuda::getCurrentCUDAStream());
cublasGemmStridedBatchedEx(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed,
seqlen[i],
seqlen[i],
static_cast<const void*>(&one),
ptrA,
CUDA_R_16F,
enable_stream ? heads*embed : batch*heads*embed,
embed,
ptrB,
CUDA_R_16F,
seqlen[i],
seqlen[i]*seqlen[i],
static_cast<const void*>(&zero),
ptrC,
CUDA_R_16F,
(enable_stream ? 1 : batch) * (strided ? heads*3*embed : heads*embed),
strided ? 3*embed : embed,
enable_stream ? heads : batch*heads,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
ptrA = static_cast<void*>(static_cast<half*>(ptrA) + seqlen[i]*heads*embed);
ptrB = static_cast<void*>(static_cast<half*>(ptrB) + heads*seqlen[i]*seqlen[i]);
ptrC = static_cast<void*>(static_cast<half*>(ptrC) + (strided ? seqlen[i]*heads*3*embed : seqlen[i]*heads*embed));
}
for(int i = 0; i < (enable_stream ? nstreams : 0); i++) {
if(sync) cudaStreamSynchronize(stream[i]);
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////////////
void FastSoftmaxFprop_(torch::Tensor &input,
int batch,
torch::Tensor &seq_len,
int heads,
bool enable_stream,
bool sync)
{
int *seqlen = static_cast<int*>(seq_len.data_ptr());
void *ptrIn = static_cast<void*>(input.data_ptr());
for(int i = 0; i < (enable_stream ? batch : 1); i++) {
dispatch_softmax<half, half, float>(
reinterpret_cast<half*>(ptrIn),
reinterpret_cast<const half*>(ptrIn),
seqlen[i],
seqlen[i],
enable_stream ? heads*seqlen[i] : batch*heads*seqlen[i]);
ptrIn = static_cast<void*>(static_cast<half*>(ptrIn) + heads*seqlen[i]*seqlen[i]);
}
for(int i = 0; i < (enable_stream ? nstreams : 0); i++) {
if(sync) cudaStreamSynchronize(stream[i]);
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////////////
void FastSoftmaxBprop_(torch::Tensor &input,
torch::Tensor &output,
int batch,
torch::Tensor &seq_len,
int heads,
bool enable_stream,
bool sync)
{
int *seqlen = static_cast<int*>(seq_len.data_ptr());
void *ptrIn = static_cast<void*>(input.data_ptr());
void *ptrOut = static_cast<void*>(output.data_ptr());
for(int i = 0; i < (enable_stream ? batch : 1); i++) {
dispatch_softmax_backward_stream<half, half, float>(
static_cast<half*>(ptrOut),
static_cast<half*>(ptrOut),
reinterpret_cast<half const*>(ptrIn),
seqlen[i],
seqlen[i],
enable_stream ? heads*seqlen[i] : batch*heads*seqlen[i],
enable_stream ? stream[i%nstreams] : at::cuda::getCurrentCUDAStream());
ptrIn = static_cast<void*>(static_cast<half*>(ptrIn) + heads*seqlen[i]*seqlen[i]);
ptrOut = static_cast<void*>(static_cast<half*>(ptrOut) + heads*seqlen[i]*seqlen[i]);
}
for(int i = 0; i < (enable_stream ? nstreams : 0); i++) {
if(sync) cudaStreamSynchronize(stream[i]);
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////////////
void FastMaskSoftmaxFprop_(torch::Tensor &input,
torch::Tensor &mask,
int batch,
torch::Tensor &seq_len,
int heads,
bool enable_stream,
bool sync)
{
int *seqlen = static_cast<int*>(seq_len.data_ptr());
void *ptrIn = static_cast<void*>(input.data_ptr());
void *ptrMask = static_cast<void*>(mask.data_ptr());
for(int i = 0; i < (enable_stream ? batch : 1); i++) {
dispatch_additive_masked_softmax_stream<half, half, float>(
reinterpret_cast<half*>(ptrIn),
reinterpret_cast<const half*>(ptrIn),
reinterpret_cast<const half*>(ptrMask),
seqlen[i],
seqlen[i],
enable_stream ? heads*seqlen[i] : batch*heads*seqlen[i],
enable_stream ? heads*seqlen[i] : heads*seqlen[i],
enable_stream ? stream[i%nstreams] : at::cuda::getCurrentCUDAStream());
ptrIn = static_cast<void*>(static_cast<half*>(ptrIn) + heads*seqlen[i]*seqlen[i]);
ptrMask = static_cast<void*>(static_cast<half*>(ptrMask) + seqlen[i]);
}
for(int i = 0; i < (enable_stream ? nstreams : 0); i++) {
if(sync) cudaStreamSynchronize(stream[i]);
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////////////
std::vector<torch::Tensor> FastMaskSoftmaxDropoutFprop_(torch::Tensor &input,
torch::Tensor &mask,
int batch,
torch::Tensor &seq_len,
int heads,
float dropout_prob,
bool enable_stream,
bool sync,
bool is_training)
{
int *seqlen = static_cast<int*>(seq_len.data_ptr());
void *ptrIn = static_cast<void*>(input.data_ptr());
void *ptrMask = static_cast<void*>(mask.data_ptr());
for(int i = 0; i < (enable_stream ? batch : 1); i++) {
dispatch_additive_masked_softmax_stream<half, half, float>(
reinterpret_cast<half*>(ptrIn),
reinterpret_cast<const half*>(ptrIn),
reinterpret_cast<const half*>(ptrMask),
seqlen[i],
seqlen[i],
enable_stream ? heads*seqlen[i] : batch*heads*seqlen[i],
enable_stream ? heads*seqlen[i] : heads*seqlen[i],
enable_stream ? stream[i%nstreams] : at::cuda::getCurrentCUDAStream());
ptrIn = static_cast<void*>(static_cast<half*>(ptrIn) + heads*seqlen[i]*seqlen[i]);
ptrMask = static_cast<void*>(static_cast<half*>(ptrMask) + seqlen[i]);
}
for(int i = 0; i < (enable_stream ? nstreams : 0); i++) {
if(sync) cudaStreamSynchronize(stream[i]);
}
int ntokens = seqlen[0];
for(int i = 1; i < (enable_stream ? batch : 2); i++) {
ntokens += seqlen[i];
}
auto act_options = input.options().requires_grad(false);
auto mask_options = act_options.dtype(torch::kUInt8);
torch::Tensor dropout_results = torch::empty({batch*heads, ntokens}, act_options);
torch::Tensor dropout_mask = torch::empty({batch*heads, ntokens}, mask_options);
//torch::Tensor dropout_results = torch::empty({batch*heads, seqlen[0], seqlen[0]}, act_options);
//torch::Tensor dropout_mask = torch::empty({batch*heads, seqlen[0], seqlen[0]}, mask_options);
if (is_training) {
//use at:: function so that C++ version generates the same random mask as python version
auto dropout_tuple = at::_fused_dropout(input, 1.0f-dropout_prob);
dropout_results = std::get<0>(dropout_tuple);
dropout_mask = std::get<1>(dropout_tuple);
}
return {dropout_results, dropout_mask};
}
/////////////////////////////////////////////////////////////////////////////////////////////////////////
void FastMaskSoftmaxDropoutBprop_(torch::Tensor &input,
torch::Tensor &output,
torch::Tensor &dropout_mask,
int batch,
torch::Tensor &seq_len,
int heads,
float dropout_prob,
bool enable_stream,
bool sync)
{
int *seqlen = static_cast<int*>(seq_len.data_ptr());
void *ptrIn = static_cast<void*>(input.data_ptr());
void *ptrOut = static_cast<void*>(output.data_ptr());
void *ptrDropoutMask = static_cast<void*>(dropout_mask.data_ptr());
for(int i = 0; i < (enable_stream ? batch : 1); i++) {
dispatch_masked_scale_softmax_backward_stream<half, half, float, false>(
static_cast<half*>(ptrOut),
static_cast<half*>(ptrOut),
reinterpret_cast<half const*>(ptrIn),
reinterpret_cast<uint8_t const*>(ptrDropoutMask),
1.0/(1.0-dropout_prob),
seqlen[i],
seqlen[i],
enable_stream ? heads*seqlen[i] : batch*heads*seqlen[i],
enable_stream ? stream[i%nstreams] : at::cuda::getCurrentCUDAStream());
ptrIn = static_cast<void*>(static_cast<half*>(ptrIn) + heads*seqlen[i]*seqlen[i]);
ptrOut = static_cast<void*>(static_cast<half*>(ptrOut) + heads*seqlen[i]*seqlen[i]);
}
for(int i = 0; i < (enable_stream ? nstreams : 0); i++) {
if(sync) cudaStreamSynchronize(stream[i]);
}
}
/////////////////////////////////////////////////////////////////////////////////////////////////////////
void init_mha_cuda_extension()
{
// CUDA Stream.
for(int i = 0; i < nstreams; i++) {
cudaStreamCreate(&stream[i]);
}
// CuBlas Handle.
cublasCreate(&handle);
cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH);
}
/////////////////////////////////////////////////////////////////////////////////////////////////////////
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("InitMHACUDAExtension", &init_mha_cuda_extension, "InitMHACUDAExtension");
m.def("FastBmm1Fprop", &FastBmm1Fprop_, "FastBmm1Fprop");
m.def("FastBmm1Dgrad1", &FastBmm1Dgrad1_, "FastBmm1Dgrad1");
m.def("FastBmm1Dgrad2", &FastBmm1Dgrad2_, "FastBmm1Dgrad2");
m.def("FastBmm2Fprop", &FastBmm2Fprop_, "FastBmm2Fprop");
m.def("FastBmm2Dgrad1", &FastBmm2Dgrad1_, "FastBmm2Dgrad1");
m.def("FastBmm2Dgrad2", &FastBmm2Dgrad2_, "FastBmm2Dgrad2");
m.def("FastSoftmaxFprop", &FastSoftmaxFprop_, "FastSoftmaxFprop");
m.def("FastSoftmaxBprop", &FastSoftmaxBprop_, "FastSoftmaxBprop");
m.def("FastMaskSoftmaxFprop", &FastMaskSoftmaxFprop_, "FastMaskSoftmaxFprop");
m.def("FastMaskSoftmaxDropoutFprop", &FastMaskSoftmaxDropoutFprop_, "FastMaskSoftmaxDropoutFprop");
m.def("FastMaskSoftmaxDropoutBprop", &FastMaskSoftmaxDropoutBprop_, "FastMaskSoftmaxDropoutBprop");
}
import torch
import setuptools
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
setup(
name='mhalib',
ext_modules=[
CUDAExtension(
name='mhalib',
sources=['mha_funcs.cu'],
extra_compile_args={
'cxx': ['-O3',],
'nvcc':['-O3','-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', "--expt-relaxed-constexpr", "-ftemplate-depth=1024", '-gencode arch=compute_70,code=sm_70','-gencode arch=compute_80,code=sm_80','-gencode arch=compute_80,code=compute_80']
}
)
],
cmdclass={
'build_ext': BuildExtension
})
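# Build sketch (assumptions: run from the directory containing mha_funcs.cu and this
# setup.py, with a CUDA toolkit matching the gencode flags above):
#   python setup.py install
# After installation, the extension should be initialized once per process so that the
# CUDA streams and cuBLAS handle used by the bindings exist:
#   import mhalib
#   mhalib.InitMHACUDAExtension()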
import collections
import os
import subprocess
import torch
from mlperf_logging import mllog
from mlperf_logging.mllog import constants
mllogger = mllog.get_mllogger()
def log_start(*args, **kwargs):
_log_print(mllogger.start, *args, **kwargs)
def log_end(*args, **kwargs):
_log_print(mllogger.end, *args, **kwargs)
def log_event(*args, **kwargs):
_log_print(mllogger.event, *args, **kwargs)
def _log_print(logger, *args, **kwargs):
if kwargs.pop('sync', False):
barrier()
if 'stack_offset' not in kwargs:
kwargs['stack_offset'] = 3
if 'value' not in kwargs:
kwargs['value'] = None
if kwargs.pop('log_all_ranks', False):
log = True
else:
log = (get_rank() == 0)
if log:
logger(*args, **kwargs)
def barrier():
"""
Works as a temporary distributed barrier; PyTorch does not currently
implement a barrier for the NCCL backend.
Calls all_reduce on a dummy tensor and synchronizes with the GPU.
"""
if torch.distributed.is_available() and torch.distributed.is_initialized():
torch.distributed.all_reduce(torch.cuda.FloatTensor(1))
torch.cuda.synchronize()
def get_rank():
"""
Gets distributed rank or returns zero if distributed is not initialized.
"""
if torch.distributed.is_available() and torch.distributed.is_initialized():
rank = torch.distributed.get_rank()
else:
rank = 0
return rank
def mlperf_submission_log(benchmark):
num_nodes = os.environ.get('SLURM_NNODES', 1)
mllog.config(filename=os.path.join(os.path.dirname(os.path.abspath(__file__)), f'{benchmark}.log'))
mllogger = mllog.get_mllogger()
mllogger.logger.propagate = False
log_event(
key=constants.SUBMISSION_BENCHMARK,
value=benchmark,
)
log_event(
key=constants.SUBMISSION_ORG,
value='NVIDIA')
log_event(
key=constants.SUBMISSION_DIVISION,
value='closed')
log_event(
key=constants.SUBMISSION_STATUS,
value='onprem')
log_event(
key=constants.SUBMISSION_PLATFORM,
value=f'{num_nodes}xSUBMISSION_PLATFORM_PLACEHOLDER')
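# Usage sketch (assumptions: this helper module is importable by the training script and
# the constant names below are illustrative MLPerf keys):
#   mlperf_submission_log('bert')
#   log_start(key=constants.INIT_START, log_all_ranks=True)
#   log_event(key=constants.SEED, value=42)
#   log_end(key=constants.INIT_STOP, sync=True)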
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BERT model."""
from __future__ import absolute_import, division, print_function, unicode_literals
import copy
import json
import logging
import math
import os
import shutil
import tarfile
import tempfile
import sys
from io import open
from operator import mul
from functools import reduce
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.utils import checkpoint
from apex.contrib.multihead_attn import SelfMultiheadAttn
from file_utils import cached_path
from layers.fused_gelu import bias_gelu_impl as bias_gelu
from utils import get_rank
import mhalib
from mha import *
logger = logging.getLogger(__name__)
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)
def remap_attn_names_tf(name):
if 'attention' in name:
ind = name.index("attention")
if 'self' in name and 'query' in name and 'kernel' in name:
name = name[:(ind+1)] + ['multi_head_attention', 'q_weight']
if 'self' in name and 'query' in name and 'bias' in name:
name = name[:(ind+1)] + ['multi_head_attention', 'q_bias']
if 'self' in name and 'key' in name and 'kernel' in name:
name = name[:(ind+1)] + ['multi_head_attention', 'k_weight']
if 'self' in name and 'key' in name and 'bias' in name:
name = name[:(ind+1)] + ['multi_head_attention', 'k_bias']
if 'self' in name and 'value' in name and 'kernel' in name:
name = name[:(ind+1)] + ['multi_head_attention', 'v_weight']
if 'self' in name and 'value' in name and 'bias' in name:
name = name[:(ind+1)] + ['multi_head_attention', 'v_bias']
if 'output' in name and 'dense' in name and 'kernel' in name:
name = name[:(ind+1)] + ['multi_head_attention', 'out_proj_weight']
if 'output' in name and 'dense' in name and 'bias' in name:
name = name[:(ind+1)] + ['multi_head_attention', 'out_proj_bias']
if 'output' in name and 'LayerNorm' in name:
name = name[:(ind+1)] + ['layer_norm'] + name[-1:]
return name
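# Illustrative remapping (names are TF variable paths split on '/'):
#   ['bert', 'encoder', 'layer_0', 'attention', 'self', 'query', 'kernel']
#     -> ['bert', 'encoder', 'layer_0', 'attention', 'multi_head_attention', 'q_weight']
#   ['bert', 'encoder', 'layer_0', 'attention', 'output', 'LayerNorm', 'gamma']
#     -> ['bert', 'encoder', 'layer_0', 'attention', 'layer_norm', 'gamma']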
def load_tf_weights_in_bert(model, tf_checkpoint_path, use_fast_mha=False):
""" Load tf checkpoints in a pytorch model
"""
try:
import re
import numpy as np
import tensorflow as tf
except ImportError:
print("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions.")
raise
tf_path = os.path.abspath(tf_checkpoint_path)
if get_rank() == 0:
print("Converting TensorFlow checkpoint from {}".format(tf_path))
# Load weights from TF model
init_vars = tf.train.list_variables(tf_path)
names = []
arrays = []
for name, shape in init_vars:
if get_rank() == 0:
print("Loading TF weight {} with shape {}".format(name, shape))
array = tf.train.load_variable(tf_path, name)
names.append(name)
arrays.append(array)
# MHA params need to be treated separately
if use_fast_mha:
mha_params = ['q_weight', 'q_bias', 'k_weight', 'k_bias', 'v_weight', 'v_bias', 'out_proj_weight', 'out_proj_bias']
else:
mha_params = []
for name, array in zip(names, arrays):
name = name.split('/')
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v,
# which are not required for using the pretrained model
if any(n in ["adam_v", "adam_m", "global_step", "LAMB", "LAMB_1", "beta1_power", "beta2_power"] for n in name):
if get_rank() == 0:
print("Skipping {}".format("/".join(name)))
continue
if use_fast_mha:
name = remap_attn_names_tf(name)
pointer = model
for m_name in name:
if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
l = re.split(r'_(\d+)', m_name)
else:
l = [m_name]
if l[0] in mha_params:
pointer = getattr(pointer, l[0])
elif l[0] == 'kernel' or l[0] == 'gamma':
pointer = getattr(pointer, 'weight')
elif l[0] == 'output_bias' or l[0] == 'beta':
pointer = getattr(pointer, 'bias')
elif l[0] == 'output_weights':
pointer = getattr(pointer, 'weight')
else:
pointer = getattr(pointer, l[0])
if len(l) >= 2:
num = int(l[1])
pointer = pointer[num]
if m_name[-11:] == '_embeddings':
pointer = getattr(pointer, 'weight')
elif m_name == 'kernel' or (m_name in mha_params and 'bias' not in m_name):
array = np.ascontiguousarray(np.transpose(array))
try:
assert pointer.shape == array.shape
except AssertionError as e:
# If copying smaller into larger, assume padded and ok
if reduce(mul, pointer.shape) > reduce(mul, array.shape):
if get_rank() == 0:
print("Initialize padded PyTorch weight {}".format(name))
pointer.data.zero_()
def generate_slices():
slices = []
for i in range(array.ndim):
slices.append(slice(0, array.shape[i], 1))
return slices
pointer.data[generate_slices()] = torch.from_numpy(array)
else:
e.args += (pointer.shape, array.shape)
raise
else:
if get_rank() == 0:
print("Initialize PyTorch weight {}".format(name))
pointer.data = torch.from_numpy(array)
return model
def swish(x):
return x * torch.sigmoid(x)
def fast_gelu(x):
pi = 3.1415926535897932
cdf = 0.5 * (1.0 + torch.tanh((math.sqrt(2 / pi) * (x + 0.044715 * torch.pow(x, 3)))))
return x*cdf
#torch.nn.functional.gelu(x) # Breaks ONNX export
#ACT2FN = {"gelu": torch.nn.functional.gelu, "bias_gelu": bias_gelu, "relu": torch.nn.functional.relu, "swish": swish}
ACT2FN = {"gelu": fast_gelu, "bias_gelu": bias_gelu, "relu": torch.nn.functional.relu, "swish": swish}
class LinearActivation(torch.nn.Linear):
r"""Fused Linear and activation Module.
"""
__constants__ = ['bias']
def __init__(self, in_features, out_features, act='gelu', bias=True):
super(LinearActivation, self).__init__(in_features, out_features, bias)
self.act_fn = nn.Identity() #
self.biased_act_fn = None #
if isinstance(act, str) or (sys.version_info[0] == 2 and isinstance(act, unicode)): # For TorchScript
if bias and not 'bias' in act: # compatibility
act = 'bias_' + act #
self.biased_act_fn = ACT2FN[act] #
else:
self.act_fn = ACT2FN[act]
else:
self.act_fn = act
def forward(self, input):
if not self.bias is None:
return self.biased_act_fn(self.bias, nn.functional.linear(input, self.weight, None))
else:
return self.act_fn(nn.functional.linear(input, self.weight, self.bias))
class BertConfig(object):
"""Configuration class to store the configuration of a `BertModel`.
"""
def __init__(self,
vocab_size_or_config_json_file,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02):
"""Constructs BertConfig.
Args:
vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`.
hidden_size: Size of the encoder layers and the pooler layer.
num_hidden_layers: Number of hidden layers in the Transformer encoder.
num_attention_heads: Number of attention heads for each attention layer in
the Transformer encoder.
intermediate_size: The size of the "intermediate" (i.e., feed-forward)
layer in the Transformer encoder.
hidden_act: The non-linear activation function (function or string) in the
encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
hidden_dropout_prob: The dropout probability for all fully connected
layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob: The dropout ratio for the attention
probabilities.
max_position_embeddings: The maximum sequence length that this model might
ever be used with. Typically set this to something large just in case
(e.g., 512 or 1024 or 2048).
type_vocab_size: The vocabulary size of the `token_type_ids` passed into
`BertModel`.
initializer_range: The stddev of the truncated_normal_initializer for
initializing all weight matrices.
"""
if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
and isinstance(vocab_size_or_config_json_file, unicode)):
with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
json_config = json.loads(reader.read())
for key, value in json_config.items():
self.__dict__[key] = value
elif isinstance(vocab_size_or_config_json_file, int):
self.vocab_size = vocab_size_or_config_json_file
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
else:
raise ValueError("First argument must be either a vocabulary size (int)"
"or the path to a pretrained model config file (str)")
@classmethod
def from_dict(cls, json_object):
"""Constructs a `BertConfig` from a Python dictionary of parameters."""
config = BertConfig(vocab_size_or_config_json_file=-1)
for key, value in json_object.items():
config.__dict__[key] = value
return config
@classmethod
def from_json_file(cls, json_file):
"""Constructs a `BertConfig` from a json file of parameters."""
with open(json_file, "r", encoding='utf-8') as reader:
text = reader.read()
return cls.from_dict(json.loads(text))
def __repr__(self):
return str(self.to_json_string())
def to_dict(self):
"""Serializes this instance to a Python dictionary."""
output = copy.deepcopy(self.__dict__)
return output
def to_json_string(self):
"""Serializes this instance to a JSON string."""
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
try:
import apex
#apex.amp.register_half_function(apex.normalization.fused_layer_norm, 'FusedLayerNorm')
import apex.normalization
#apex.amp.register_float_function(apex.normalization.FusedLayerNorm, 'forward')
BertLayerNorm = apex.normalization.FusedLayerNorm
except ImportError:
print("Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.")
class BertLayerNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-12):
"""Construct a layernorm module in the TF style (epsilon inside the square root).
"""
super(BertLayerNorm, self).__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.bias = nn.Parameter(torch.zeros(hidden_size))
self.variance_epsilon = eps
def forward(self, x):
u = x.mean(-1, keepdim=True)
s = (x - u).pow(2).mean(-1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.variance_epsilon)
return self.weight * x + self.bias
class BertEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings.
"""
def __init__(self, config):
super(BertEmbeddings, self).__init__()
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
# any TensorFlow checkpoint file
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, input_ids, token_type_ids=None):
seq_length = input_ids.size(1)
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
if token_type_ids is None:
token_type_ids = torch.zeros_like(input_ids)
words_embeddings = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
token_type_embeddings = self.token_type_embeddings(token_type_ids)
embeddings = words_embeddings + position_embeddings + token_type_embeddings
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings
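# Reference (non-fused) self-attention. Shape flow for hidden_states of
# [batch, seq_len, hidden_size]: the query/key/value projections are reshaped to
# [batch, heads, seq_len, head_size] (the key is further transposed to
# [batch, heads, head_size, seq_len]), the attention scores are
# [batch, heads, seq_len, seq_len], and the context layer is merged back to
# [batch, seq_len, hidden_size].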
class BertSelfAttention(nn.Module):
def __init__(self, config):
super(BertSelfAttention, self).__init__()
if config.hidden_size % config.num_attention_heads != 0:
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (config.hidden_size, config.num_attention_heads))
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size
self.query = nn.Linear(config.hidden_size, self.all_head_size)
self.key = nn.Linear(config.hidden_size, self.all_head_size)
self.value = nn.Linear(config.hidden_size, self.all_head_size)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.softmax = nn.Softmax(dim=-1)
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)
def transpose_key_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)
return x.permute(0, 2, 3, 1)
def forward(self, hidden_states, attention_mask):
mixed_query_layer = self.query(hidden_states)
mixed_key_layer = self.key(hidden_states)
mixed_value_layer = self.value(hidden_states)
query_layer = self.transpose_for_scores(mixed_query_layer)
key_layer = self.transpose_key_for_scores(mixed_key_layer)
value_layer = self.transpose_for_scores(mixed_value_layer)
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer)
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
# Apply the attention mask is (precomputed for all layers in BertModel forward() function)
attention_scores = attention_scores + attention_mask.unsqueeze(1).unsqueeze(2)
# Normalize the attention scores to probabilities.
attention_probs = self.softmax(attention_scores)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
return context_layer
class BertSelfOutput(nn.Module):
def __init__(self, config):
super(BertSelfOutput, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
# This module uses Apex C++ multihead attention implementation with fusions.
class FastBertAttention(nn.Module):
def __init__(self, config):
super(FastBertAttention, self).__init__()
self.multi_head_attention = SelfMultiheadAttn(config.hidden_size, config.num_attention_heads, dropout = config.attention_probs_dropout_prob, bias=True, include_norm_add=False, impl='fast', separate_qkv_params=True, mask_additive=True)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.layer_norm = BertLayerNorm(config.hidden_size, eps=1e-12)
def forward(self, input_tensor, attention_mask):
residual=input_tensor
multi_head_attention_output,_ = self.multi_head_attention(query = input_tensor, key = input_tensor, value = input_tensor, key_padding_mask=attention_mask, need_weights=True,attn_mask = None, is_training = self.training)
attention_output = self.dropout(multi_head_attention_output)
attention_output = self.layer_norm(attention_output + residual)
return attention_output
class FastUnpadBertAttention(nn.Module):
def __init__(self, config):
super(FastUnpadBertAttention, self).__init__()
self.self = FastUnpadBertSelfAttention(config, enable_stream=config.enable_stream, enable_sync=False, fuse_mask=config.fuse_mask, fuse_scale=config.fuse_scale, fuse_qkv=config.fuse_qkv, fuse_dropout=config.fuse_dropout, apex_softmax=config.apex_softmax, pad=config.pad)
self.output = BertSelfOutput(config)
def forward(self, input_tensor, attention_mask, seqlen, batch):
self_output = self.self(input_tensor, attention_mask, seqlen, batch, is_training = self.training)
attention_output = self.output(self_output, input_tensor)
return attention_output
class BertAttention(nn.Module):
def __init__(self, config):
super(BertAttention, self).__init__()
self.self = BertSelfAttention(config)
self.output = BertSelfOutput(config)
def forward(self, input_tensor, attention_mask):
self_output = self.self(input_tensor, attention_mask)
attention_output = self.output(self_output, input_tensor)
return attention_output
class BertIntermediate(nn.Module):
def __init__(self, config):
super(BertIntermediate, self).__init__()
self.fused_gelu_bias = config.fused_gelu_bias
if config.fused_gelu_bias:
self.dense = LinearActivation(config.hidden_size, config.intermediate_size, act=config.hidden_act)
else:
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
self.intermediate_act_fn = ACT2FN[config.hidden_act]
else:
self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
if not self.fused_gelu_bias:
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
class BertOutput(nn.Module):
def __init__(self, config):
super(BertOutput, self).__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states
class BertLayer(nn.Module):
def __init__(self, config):
super(BertLayer, self).__init__()
self.unpad = config.unpad
if config.fused_mha:
self.attention = FastBertAttention(config)
elif config.unpad:
self.attention = FastUnpadBertAttention(config)
else:
self.attention = BertAttention(config)
self.intermediate = BertIntermediate(config)
self.output = BertOutput(config)
def forward(self, hidden_states, attention_mask, seqlen, batch):
if self.unpad:
attention_output = self.attention(hidden_states, attention_mask, seqlen, batch)
else:
attention_output = self.attention(hidden_states, attention_mask)
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output, attention_output)
return layer_output
class BertEncoder(nn.Module):
def __init__(self, config):
super(BertEncoder, self).__init__()
layer = BertLayer(config)
self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
self.num_attention_heads = config.num_attention_heads
self.fused_mha=config.fused_mha
self.unpad=config.unpad
self.pad = config.pad
self.fuse_mask = config.fuse_mask
self.enable_stream = config.enable_stream
# def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
# all_encoder_layers = []
# for layer_module in self.layer:
# hidden_states = layer_module(hidden_states, attention_mask)
# if output_all_encoded_layers:
# all_encoder_layers.append(hidden_states)
# if not output_all_encoded_layers:
# all_encoder_layers.append(hidden_states)
# return all_encoder_layers
def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True, checkpoint_activations=False):
# Unpad inputs and mask. It will remove tokens that are padded. Assume ntokens is total number of tokens (padded and non-padded)
# and ntokens_unpad is total number of non-padded tokens. Then unpadding performs the following compression of the inputs:
# hidden_states[ntokens,hidden] -> hidden_states[ntokens_unpad,hidden]
batch = None
seqlen = None
if self.unpad:
batch = hidden_states.shape[0]
maxseqlen = hidden_states.shape[1]
hidden_size = hidden_states.shape[2]
attention_indices, attention_mask, seqlen, ntokens = generate_mask(attention_mask, self.num_attention_heads, pad=self.pad, fuse_mask=self.fuse_mask)
if self.pad == True and self.enable_stream == False:
hidden_states = hidden_states.view(batch,maxseqlen,hidden_size).permute(1,0,2).contiguous().view(batch*maxseqlen,hidden_size).contiguous()
if self.pad == True and self.enable_stream == True:
hidden_states = hidden_states.view(batch*maxseqlen,hidden_size)
if self.pad == False:
hidden_states = UnpadInput.apply(hidden_states.view(batch*maxseqlen, hidden_size).contiguous(), attention_indices, batch, maxseqlen, hidden_size, ntokens)
all_encoder_layers = []
def custom(start, end):
def custom_forward(*inputs):
layers = self.layer[start:end]
x_ = inputs[0]
for layer in layers:
x_ = layer(x_, inputs[1])
return x_
return custom_forward
if checkpoint_activations:
l = 0
num_layers = len(self.layer)
chunk_length = math.ceil(math.sqrt(num_layers))
while l < num_layers:
hidden_states = checkpoint.checkpoint(custom(l, l+chunk_length), hidden_states, attention_mask*1)
l += chunk_length
# decoder layers
else:
if self.fused_mha:
hidden_states = hidden_states.permute(1,0,2).contiguous()
for i,layer_module in enumerate(self.layer):
hidden_states = layer_module(hidden_states, attention_mask, seqlen, batch)
if output_all_encoded_layers:
if self.fused_mha:
all_encoder_layers.append(hidden_states.permute(1,0,2).contiguous())
else:
all_encoder_layers.append(hidden_states)
# Pad inputs and mask. It will insert back zero-padded tokens. Assume ntokens is total number of tokens (padded and non-padded)
# and ntokens_unpad is total number of non-padded tokens. Then padding performs the following de-compression:
# hidden_states[ntokens_unpad,hidden] -> hidden_states[ntokens,hidden]
if self.unpad:
if self.pad == True and self.enable_stream == False:
hidden_states = hidden_states.view(maxseqlen,batch,hidden_size).permute(1,0,2).contiguous().view(batch,maxseqlen,hidden_size).contiguous()
if self.pad == True and self.enable_stream == True:
hidden_states = hidden_states.view(batch,maxseqlen,hidden_size)
if self.pad == False:
hidden_states = PadInput.apply(hidden_states, attention_indices, batch, maxseqlen, hidden_size, ntokens).view(batch, maxseqlen, hidden_size).contiguous()
if not output_all_encoded_layers or checkpoint_activations:
if self.fused_mha:
all_encoder_layers.append(hidden_states.permute(1,0,2).contiguous())
else:
all_encoder_layers.append(hidden_states)
return all_encoder_layers
#class BertEncoder(nn.Module):
# def __init__(self, config):
# super(BertEncoder, self).__init__()
# layer = BertLayer(config)
# self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
#
# def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
# all_encoder_layers = []
# for layer_module in self.layer:
# hidden_states = layer_module(hidden_states, attention_mask)
# if output_all_encoded_layers:
# all_encoder_layers.append(hidden_states)
# if not output_all_encoded_layers:
# all_encoder_layers.append(hidden_states)
# return all_encoder_layers
class BertPooler(nn.Module):
def __init__(self, config):
super(BertPooler, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()
def forward(self, hidden_states):
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(first_token_tensor)
pooled_output = self.activation(pooled_output)
return pooled_output
class BertPredictionHeadTransform(nn.Module):
def __init__(self, config):
super(BertPredictionHeadTransform, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
self.transform_act_fn = ACT2FN[config.hidden_act]
else:
self.transform_act_fn = config.hidden_act
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(hidden_states)
return hidden_states
class BertLMPredictionHead(nn.Module):
def __init__(self, config, bert_model_embedding_weights):
super(BertLMPredictionHead, self).__init__()
self.transform = BertPredictionHeadTransform(config)
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self.decoder = nn.Linear(bert_model_embedding_weights.size(1),
bert_model_embedding_weights.size(0),
bias=False)
self.decoder.weight = bert_model_embedding_weights
self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0)))
def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states) + self.bias
return hidden_states
class BertOnlyMLMHead(nn.Module):
def __init__(self, config, bert_model_embedding_weights):
super(BertOnlyMLMHead, self).__init__()
self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights)
def forward(self, sequence_output):
prediction_scores = self.predictions(sequence_output)
return prediction_scores
class BertOnlyNSPHead(nn.Module):
def __init__(self, config):
super(BertOnlyNSPHead, self).__init__()
self.seq_relationship = nn.Linear(config.hidden_size, 2)
def forward(self, pooled_output):
seq_relationship_score = self.seq_relationship(pooled_output)
return seq_relationship_score
class BertPreTrainingHeads(nn.Module):
def __init__(self, config, bert_model_embedding_weights):
super(BertPreTrainingHeads, self).__init__()
self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights)
self.seq_relationship = nn.Linear(config.hidden_size, 2)
self.dense_seq_output = config.dense_seq_output
def forward(self, sequence_output, pooled_output, masked_lm_labels):
if self.dense_seq_output:
# We are masking out elements that won't contribute to loss because of masked lm labels
sequence_flattened = torch.index_select(sequence_output.view(-1,sequence_output.shape[-1]), 0, torch.nonzero(masked_lm_labels.view(-1) != -1, as_tuple=False).squeeze())
sequence_output = sequence_flattened
prediction_scores = self.predictions(sequence_output)
seq_relationship_score = self.seq_relationship(pooled_output)
return prediction_scores, seq_relationship_score
class BertPreTrainedModel(nn.Module):
""" An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
def __init__(self, config, *inputs, **kwargs):
super(BertPreTrainedModel, self).__init__()
if not isinstance(config, BertConfig):
raise ValueError(
"Parameter config in `{}(config)` should be an instance of class `BertConfig`. "
"To create a model from a Google pretrained model use "
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
self.__class__.__name__, self.__class__.__name__
))
self.config = config
# we want to make sure vocab size is padded to % 8 == 0
if self.config.vocab_size % 8 != 0:
self.config.vocab_size += 8 - (self.config.vocab_size % 8)
if get_rank() == 0:
print(f'Padded vocab_size to : {self.config.vocab_size}')
def init_bert_weights(self, module):
""" Initialize the weights.
"""
if isinstance(module, (nn.Linear, nn.Embedding)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
elif isinstance(module, BertLayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
@classmethod
def from_pretrained(cls, pretrained_checkpoint, state_dict=None, cache_dir=None,
from_tf=False, config=None, *inputs, **kwargs):
"""
Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict.
Download and cache the pre-trained model file if needed.
Params:
pretrained_checkpoint: either:
- a path or url to a pretrained model archive containing:
. `bert_config.json` a configuration file for the model
. `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
- a path or url to a pretrained model archive containing:
. `bert_config.json` a configuration file for the model
. `model.chkpt` a TensorFlow checkpoint
from_tf: should we load the weights from a locally saved TensorFlow checkpoint
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of Google pre-trained models
*inputs, **kwargs: additional input for the specific Bert class
(ex: num_labels for BertForSequenceClassification)
"""
logger.info("loading archive file {}".format(pretrained_checkpoint))
assert config, "BERT configuration file must be provided to from_pretraining()"
logger.info("Model config {}".format(config))
# Instantiate model.
model = cls(config, *inputs, **kwargs)
if state_dict is None and not from_tf:
state_dict = torch.load(pretrained_checkpoint, map_location='cpu' if not torch.cuda.is_available() else None)
if from_tf:
# Directly load from a TensorFlow checkpoint
return load_tf_weights_in_bert(model, pretrained_checkpoint, use_fast_mha=config.fused_mha)
# Load from a PyTorch state_dict
old_keys = []
new_keys = []
# print(f'loading keys: {state_dict.keys()}')
for key in state_dict.keys():
new_key = None
if 'gamma' in key:
new_key = key.replace('gamma', 'weight')
if 'beta' in key:
new_key = key.replace('beta', 'bias')
if new_key:
old_keys.append(key)
new_keys.append(new_key)
for old_key, new_key in zip(old_keys, new_keys):
state_dict[new_key] = state_dict.pop(old_key)
missing_keys = []
unexpected_keys = []
error_msgs = []
# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, '_metadata', None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata
def load(module, prefix=''):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
module._load_from_state_dict(
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + '.')
start_prefix = ''
if not hasattr(model, 'bert') and any(s.startswith('bert.') for s in state_dict.keys()):
start_prefix = 'bert.'
load(model, prefix=start_prefix)
if len(missing_keys) > 0:
logger.info("Weights of {} not initialized from pretrained model: {}".format(
model.__class__.__name__, missing_keys))
if len(unexpected_keys) > 0:
logger.info("Weights from pretrained model not used in {}: {}".format(
model.__class__.__name__, unexpected_keys))
if len(error_msgs) > 0:
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
model.__class__.__name__, "\n\t".join(error_msgs)))
return model
class BertModel(BertPreTrainedModel):
"""BERT model ("Bidirectional Embedding Representations from a Transformer").
Params:
config: a BertConfig class instance with the configuration to build a new model
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts
`extract_features.py`, `run_classifier.py` and `run_squad.py`)
`token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
a `sentence B` token (see BERT paper for more details).
`attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
input sequence length in the current batch. It's the mask that we typically use for attention when
a batch has varying length sentences.
`output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`.
Outputs: Tuple of (encoded_layers, pooled_output)
`encoded_layers`: controlled by `output_all_encoded_layers` argument:
- `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end
of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each
encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
- `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding
to the last attention block of shape [batch_size, sequence_length, hidden_size],
`pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
classifier pretrained on top of the hidden state associated with the first token of the
input (`CLS`), trained on the Next-Sentence task (see BERT's paper).
Example usage:
```python
# Already been converted into WordPiece token ids
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
model = modeling.BertModel(config=config)
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config):
super(BertModel, self).__init__(config)
self.embeddings = BertEmbeddings(config)
self.encoder = BertEncoder(config)
self.pooler = BertPooler(config)
self.apply(self.init_bert_weights)
self.unpad = config.unpad
def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True, checkpoint_activations=False):
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
if token_type_ids is None:
token_type_ids = torch.zeros_like(input_ids)
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
extended_attention_mask = attention_mask#.unsqueeze(1).unsqueeze(2)
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
if self.unpad == False:
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
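# e.g. an attention_mask row of [1, 1, 0] becomes an additive bias of [0.0, 0.0, -10000.0]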
embedding_output = self.embeddings(input_ids, token_type_ids)
encoded_layers = self.encoder(embedding_output,
extended_attention_mask,
output_all_encoded_layers=output_all_encoded_layers, checkpoint_activations=checkpoint_activations)
sequence_output = encoded_layers[-1]
pooled_output = self.pooler(sequence_output)
if not output_all_encoded_layers:
encoded_layers = encoded_layers[-1]
return encoded_layers, pooled_output
class BertForPreTraining(BertPreTrainedModel):
"""BERT model with pre-training heads.
This module comprises the BERT model followed by the two pre-training heads:
- the masked language modeling head, and
- the next sentence classification head.
Params:
config: a BertConfig class instance with the configuration to build a new model.
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
with the word token indices in the vocabulary (see the token preprocessing logic in the scripts
`extract_features.py`, `run_classifier.py` and `run_squad.py`)
`token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
a `sentence B` token (see BERT paper for more details).
`attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
input sequence length in the current batch. It's the mask that we typically use for attention when
a batch has varying length sentences.
`masked_lm_labels`: optional masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length]
with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked); the loss
is only computed for the labels set in [0, ..., vocab_size].
`next_sentence_label`: optional next sentence classification loss: torch.LongTensor of shape [batch_size]
with indices selected in [0, 1].
0 => next sentence is the continuation, 1 => next sentence is a random sentence.
Outputs:
if `masked_lm_labels` and `next_sentence_label` are not `None`:
Outputs the total_loss which is the sum of the masked language modeling loss and the next
sentence classification loss.
if `masked_lm_labels` or `next_sentence_label` is `None`:
Outputs a tuple comprising
- the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and
- the next sentence classification logits of shape [batch_size, 2].
Example usage:
```python
# Already been converted into WordPiece token ids
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
model = BertForPreTraining(config)
masked_lm_logits_scores, seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config):
super(BertForPreTraining, self).__init__(config)
self.bert = BertModel(config)
self.cls = BertPreTrainingHeads(config, self.bert.embeddings.word_embeddings.weight)
self.apply(self.init_bert_weights)
self.dense_seq_output = config.dense_seq_output
def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, next_sentence_label=None, checkpoint_activations=False):
sequence_output, pooled_output = self.bert(input_ids, token_type_ids, attention_mask,
output_all_encoded_layers=False, checkpoint_activations=checkpoint_activations)
# if dense_seq_output, the prediction scores returned by the head are already restricted to the positions selected by masked_lm_labels, and their first dimension is flattened
prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output, masked_lm_labels)
if self.dense_seq_output:
masked_lm_labels_flat = masked_lm_labels.view(-1)
mlm_labels = masked_lm_labels_flat[masked_lm_labels_flat != -1]
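# keep only the label ids at masked positions; with dense_seq_output the head above has already returned scores for these positions only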
if masked_lm_labels is not None and next_sentence_label is not None:
loss_fct = CrossEntropyLoss(ignore_index=-1)
if self.dense_seq_output:
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), mlm_labels.view(-1))
else:
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
#print("loss is {} {}".format(masked_lm_loss, next_sentence_loss))
total_loss = masked_lm_loss + next_sentence_loss
# Masked Language Model Accuracy
if not self.dense_seq_output:
prediction_scores_flat = prediction_scores.view(-1, prediction_scores.shape[-1])
masked_lm_labels_flat = masked_lm_labels.view(-1)
mlm_predictions_scores = prediction_scores_flat[masked_lm_labels_flat != -1]
mlm_predictions = mlm_predictions_scores.argmax(dim=-1)
mlm_labels = masked_lm_labels_flat[masked_lm_labels_flat != -1]
else:
mlm_predictions = prediction_scores.argmax(dim=-1)
mlm_acc = (mlm_predictions == mlm_labels).sum(dtype=torch.float)/mlm_labels.numel()
return total_loss, mlm_acc, mlm_labels.numel()
else: #TODO: Handle this path for dense sequence output as well
return prediction_scores, seq_relationship_score
class BertForMaskedLM(BertPreTrainedModel):
"""BERT model with the masked language modeling head.
This module comprises the BERT model followed by the masked language modeling head.
Params:
config: a BertConfig class instance with the configuration to build a new model.
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
with the word token indices in the vocabulary (see the token preprocessing logic in the scripts
`extract_features.py`, `run_classifier.py` and `run_squad.py`)
`token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
a `sentence B` token (see BERT paper for more details).
`attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
input sequence length in the current batch. It's the mask that we typically use for attention when
a batch has varying length sentences.
`masked_lm_labels`: masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length]
with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked); the loss
is only computed for the labels set in [0, ..., vocab_size].
Outputs:
if `masked_lm_labels` is not `None`:
Outputs the masked language modeling loss.
if `masked_lm_labels` is `None`:
Outputs the masked language modeling logits of shape [batch_size, sequence_length, vocab_size].
Example usage:
```python
# Already been converted into WordPiece token ids
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
model = BertForMaskedLM(config)
masked_lm_logits_scores = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config):
super(BertForMaskedLM, self).__init__(config)
self.bert = BertModel(config)
self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight)
self.apply(self.init_bert_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, checkpoint_activations=False):
sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask,
output_all_encoded_layers=False)
prediction_scores = self.cls(sequence_output)
if masked_lm_labels is not None:
loss_fct = CrossEntropyLoss(ignore_index=-1)
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
return masked_lm_loss
else:
return prediction_scores
class BertForNextSentencePrediction(BertPreTrainedModel):
"""BERT model with next sentence prediction head.
This module comprises the BERT model followed by the next sentence classification head.
Params:
config: a BertConfig class instance with the configuration to build a new model.
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
with the word token indices in the vocabulary (see the token preprocessing logic in the scripts
`extract_features.py`, `run_classifier.py` and `run_squad.py`)
`token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
a `sentence B` token (see BERT paper for more details).
`attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
input sequence length in the current batch. It's the mask that we typically use for attention when
a batch has varying length sentences.
`next_sentence_label`: next sentence classification loss: torch.LongTensor of shape [batch_size]
with indices selected in [0, 1].
0 => next sentence is the continuation, 1 => next sentence is a random sentence.
Outputs:
if `next_sentence_label` is not `None`:
Outputs the next sentence classification loss.
if `next_sentence_label` is `None`:
Outputs the next sentence classification logits of shape [batch_size, 2].
Example usage:
```python
# Already been converted into WordPiece token ids
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
model = BertForNextSentencePrediction(config)
seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config):
super(BertForNextSentencePrediction, self).__init__(config)
self.bert = BertModel(config)
self.cls = BertOnlyNSPHead(config)
self.apply(self.init_bert_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, next_sentence_label=None, checkpoint_activations=False):
_, pooled_output = self.bert(input_ids, token_type_ids, attention_mask,
output_all_encoded_layers=False)
seq_relationship_score = self.cls(pooled_output)
if next_sentence_label is not None:
loss_fct = CrossEntropyLoss(ignore_index=-1)
next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
return next_sentence_loss
else:
return seq_relationship_score
class BertForSequenceClassification(BertPreTrainedModel):
"""BERT model for classification.
This module is composed of the BERT model with a linear layer on top of
the pooled output.
Params:
`config`: a BertConfig class instance with the configuration to build a new model.
`num_labels`: the number of classes for the classifier. Default = 2.
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
with the word token indices in the vocabulary (see the token preprocessing logic in the scripts
`extract_features.py`, `run_classifier.py` and `run_squad.py`)
`token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
a `sentence B` token (see BERT paper for more details).
`attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
input sequence length in the current batch. It's the mask that we typically use for attention when
a batch has varying length sentences.
`labels`: labels for the classification output: torch.LongTensor of shape [batch_size]
with indices selected in [0, ..., num_labels].
Outputs:
if `labels` is not `None`:
Outputs the CrossEntropy classification loss of the output with the labels.
if `labels` is `None`:
Outputs the classification logits of shape [batch_size, num_labels].
Example usage:
```python
# Already been converted into WordPiece token ids
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
num_labels = 2
model = BertForSequenceClassification(config, num_labels)
logits = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config, num_labels):
super(BertForSequenceClassification, self).__init__(config)
self.num_labels = num_labels
self.bert = BertModel(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, num_labels)
self.apply(self.init_bert_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, checkpoint_activations=False):
_, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
return loss
else:
return logits
class BertForMultipleChoice(BertPreTrainedModel):
"""BERT model for multiple choice tasks.
This module is composed of the BERT model with a linear layer on top of
the pooled output.
Params:
`config`: a BertConfig class instance with the configuration to build a new model.
`num_choices`: the number of classes for the classifier. Default = 2.
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length]
with the word token indices in the vocabulary (see the token preprocessing logic in the scripts
`extract_features.py`, `run_classifier.py` and `run_squad.py`)
`token_type_ids`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length]
with the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A`
and type 1 corresponds to a `sentence B` token (see BERT paper for more details).
`attention_mask`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length] with indices
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
input sequence length in the current batch. It's the mask that we typically use for attention when
a batch has varying length sentences.
`labels`: labels for the classification output: torch.LongTensor of shape [batch_size]
with indices selected in [0, ..., num_choices].
Outputs:
if `labels` is not `None`:
Outputs the CrossEntropy classification loss of the output with the labels.
if `labels` is `None`:
Outputs the classification logits of shape [batch_size, num_choices].
Example usage:
```python
# Already been converted into WordPiece token ids
input_ids = torch.LongTensor([[[31, 51, 99], [15, 5, 0]], [[12, 16, 42], [14, 28, 57]]])
input_mask = torch.LongTensor([[[1, 1, 1], [1, 1, 0]],[[1,1,0], [1, 0, 0]]])
token_type_ids = torch.LongTensor([[[0, 0, 1], [0, 1, 0]],[[0, 1, 1], [0, 0, 1]]])
config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
num_choices = 2
model = BertForMultipleChoice(config, num_choices)
logits = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config, num_choices):
super(BertForMultipleChoice, self).__init__(config)
self.num_choices = num_choices
self.bert = BertModel(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, 1)
self.apply(self.init_bert_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, checkpoint_activations=False):
flat_input_ids = input_ids.view(-1, input_ids.size(-1))
flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1))
_, pooled_output = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask, output_all_encoded_layers=False)
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
reshaped_logits = logits.view(-1, self.num_choices)
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(reshaped_logits, labels)
return loss
else:
return reshaped_logits
class BertForTokenClassification(BertPreTrainedModel):
"""BERT model for token-level classification.
This module is composed of the BERT model with a linear layer on top of
the full hidden state of the last layer.
Params:
`config`: a BertConfig class instance with the configuration to build a new model.
`num_labels`: the number of classes for the classifier. Default = 2.
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
with the word token indices in the vocabulary (see the token preprocessing logic in the scripts
`extract_features.py`, `run_classifier.py` and `run_squad.py`)
`token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
a `sentence B` token (see BERT paper for more details).
`attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
input sequence length in the current batch. It's the mask that we typically use for attention when
a batch has varying length sentences.
`labels`: labels for the classification output: torch.LongTensor of shape [batch_size, sequence_length]
with indices selected in [0, ..., num_labels].
Outputs:
if `labels` is not `None`:
Outputs the CrossEntropy classification loss of the output with the labels.
if `labels` is `None`:
Outputs the classification logits of shape [batch_size, sequence_length, num_labels].
Example usage:
```python
# Already been converted into WordPiece token ids
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
num_labels = 2
model = BertForTokenClassification(config, num_labels)
logits = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config, num_labels):
super(BertForTokenClassification, self).__init__(config)
self.num_labels = num_labels
self.bert = BertModel(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, num_labels)
self.apply(self.init_bert_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, checkpoint_activations=False):
sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
sequence_output = self.dropout(sequence_output)
logits = self.classifier(sequence_output)
if labels is not None:
loss_fct = CrossEntropyLoss()
# Only keep active parts of the loss
if attention_mask is not None:
active_loss = attention_mask.view(-1) == 1
active_logits = logits.view(-1, self.num_labels)[active_loss]
active_labels = labels.view(-1)[active_loss]
loss = loss_fct(active_logits, active_labels)
else:
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
return loss
else:
return logits
class BertForQuestionAnswering(BertPreTrainedModel):
"""BERT model for Question Answering (span extraction).
This module is composed of the BERT model with a linear layer on top of
the sequence output that computes start_logits and end_logits
Params:
`config`: a BertConfig class instance with the configuration to build a new model.
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
with the word token indices in the vocabulary (see the token preprocessing logic in the scripts
`extract_features.py`, `run_classifier.py` and `run_squad.py`)
`token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
a `sentence B` token (see BERT paper for more details).
`attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
input sequence length in the current batch. It's the mask that we typically use for attention when
a batch has varying length sentences.
`start_positions`: position of the first token for the labeled span: torch.LongTensor of shape [batch_size].
Positions are clamped to the length of the sequence and position outside of the sequence are not taken
into account for computing the loss.
`end_positions`: position of the last token for the labeled span: torch.LongTensor of shape [batch_size].
Positions are clamped to the length of the sequence and position outside of the sequence are not taken
into account for computing the loss.
Outputs:
if `start_positions` and `end_positions` are not `None`:
Outputs the total_loss which is the sum of the CrossEntropy loss for the start and end token positions.
if `start_positions` or `end_positions` is `None`:
Outputs a tuple of start_logits, end_logits which are the logits respectively for the start and end
position tokens of shape [batch_size, sequence_length].
Example usage:
```python
# Already been converted into WordPiece token ids
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
model = BertForQuestionAnswering(config)
start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config):
super(BertForQuestionAnswering, self).__init__(config)
self.bert = BertModel(config)
# TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version
# self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.qa_outputs = nn.Linear(config.hidden_size, 2)
self.apply(self.init_bert_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None, checkpoint_activations=False):
sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1)
if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, split adds a dimension
if len(start_positions.size()) > 1:
start_positions = start_positions.squeeze(-1)
if len(end_positions.size()) > 1:
end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs; we ignore these terms
ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index)
end_positions.clamp_(0, ignored_index)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
return total_loss
else:
return start_logits, end_logits
# coding=utf-8
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch optimization for BERT model."""
import math
import torch
from torch.optim import Optimizer
from torch.optim.optimizer import required
from torch.nn.utils import clip_grad_norm_
#from fused_adam_local import FusedAdam
from apex.optimizers import FusedAdam
from apex.multi_tensor_apply import multi_tensor_applier
import amp_C
from utils import is_main_process
import mlperf_logger
multi_tensor_l2norm = amp_C.multi_tensor_l2norm
lamb_compute_update = amp_C.multi_tensor_lamb_stage1_cuda
lamb_apply_update = amp_C.multi_tensor_lamb_stage2_cuda
scale = amp_C.multi_tensor_scale
def warmup_cosine(x, warmup=0.002):
if x < warmup:
return x/warmup
return 0.5 * (1.0 + math.cos(math.pi * x))  # x is a Python float, so use math.cos rather than torch.cos
def warmup_constant(x, warmup=0.002):
if x < warmup:
return x/warmup
return 1.0
def warmup_linear(x, warmup=0.002):
if x < warmup:
return x/warmup
return max((x - 1.) / (warmup - 1.), 0.)
def warmup_poly(x, warmup=0.002, degree=0.5):
if x < warmup:
return x/warmup
return (1.0 - x)**degree
SCHEDULES = {
'warmup_cosine':warmup_cosine,
'warmup_constant':warmup_constant,
'warmup_linear':warmup_linear,
'warmup_poly':warmup_poly,
}
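# Illustrative examples (not part of the original file): each schedule maps training
# progress x in [0, 1] to a multiplier on the base learning rate, e.g. with warmup=0.1:
#   warmup_linear(0.05, warmup=0.1)  -> 0.5   (halfway through warmup)
#   warmup_linear(0.55, warmup=0.1)  -> 0.5   (halfway through the linear decay)
#   warmup_linear(1.0,  warmup=0.1)  -> 0.0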
class BertAdam(Optimizer):
"""Implements BERT version of Adam algorithm with weight decay fix.
Params:
lr: learning rate
warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1
t_total: total number of training steps for the learning
rate schedule, -1 means constant learning rate. Default: -1
schedule: schedule to use for the warmup (see above). Default: 'warmup_linear'
b1: Adams b1. Default: 0.9
b2: Adams b2. Default: 0.999
e: Adams epsilon. Default: 1e-6
weight_decay: Weight decay. Default: 0.01
max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0
"""
def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear',
b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01,
max_grad_norm=1.0):
if lr is not required and lr < 0.0:
raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
if schedule not in SCHEDULES:
raise ValueError("Invalid schedule parameter: {}".format(schedule))
if not 0.0 <= warmup < 1.0 and not warmup == -1:
raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
if not 0.0 <= b1 < 1.0:
raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))
if not 0.0 <= b2 < 1.0:
raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2))
if not e >= 0.0:
raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total,
b1=b1, b2=b2, e=e, weight_decay=weight_decay,
max_grad_norm=max_grad_norm)
super(BertAdam, self).__init__(params, defaults)
def get_lr(self):
lr = []
for group in self.param_groups:
for p in group['params']:
state = self.state[p]
if len(state) == 0:
return [0]
if group['t_total'] != -1:
schedule_fct = SCHEDULES[group['schedule']]
lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
else:
lr_scheduled = group['lr']
lr.append(lr_scheduled)
return lr
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
grad = p.grad.data
if grad.is_sparse:
raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
state = self.state[p]
# State initialization
if len(state) == 0:
state['step'] = 0
# Exponential moving average of gradient values
state['next_m'] = torch.zeros_like(p.data)
# Exponential moving average of squared gradient values
state['next_v'] = torch.zeros_like(p.data)
next_m, next_v = state['next_m'], state['next_v']
beta1, beta2 = group['b1'], group['b2']
# Add grad clipping
if group['max_grad_norm'] > 0:
clip_grad_norm_(p, group['max_grad_norm'])
# Decay the first and second moment running average coefficient
# In-place operations to update the averages at the same time
next_m.mul_(beta1).add_(1 - beta1, grad)
next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad)
update = next_m / (next_v.sqrt() + group['e'])
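# Note: unlike torch.optim.Adam, no bias correction is applied to the moment
# estimates here (the original BERT optimizer omits it as well).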
# Just adding the square of the weights to the loss function is *not*
# the correct way of using L2 regularization/weight decay with Adam,
# since that will interact with the m and v parameters in strange ways.
#
# Instead we want to decay the weights in a manner that doesn't interact
# with the m/v parameters. This is equivalent to adding the square
# of the weights to the loss with plain (non-momentum) SGD.
if group['weight_decay'] > 0.0:
update += group['weight_decay'] * p.data
if group['t_total'] != -1:
schedule_fct = SCHEDULES[group['schedule']]
lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
else:
lr_scheduled = group['lr']
update_with_lr = lr_scheduled * update
p.data.add_(-update_with_lr)
state['step'] += 1
return loss
import torch
import math
#######################################################################################################################################################################
def unpad_input(out_, in_, indices):
out_[:,:] = in_[indices[:],:]
def pad_input(out_, in_, indices):
out_[indices[:],:] = in_[:,:]
def unpad_mask(out_, in_, indices):
out_[:] = in_.flatten()[indices[:]]
#######################################################################################################################################################################
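# generate_mask(): from a [batch, seq_len] attention mask compute
#   indices - flat positions of the tokens kept by the unpadded path,
#   mask    - an additive attention bias (0 for kept tokens, -10000 for masked ones),
#   seqlen  - per-sample lengths (rounded up to a multiple of 16, minimum 16, when pad=False),
#   ntokens - total number of kept tokens in the batch.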
def generate_mask(attention_mask, heads, pad=False, fuse_mask=True):
seqlen = attention_mask.sum(dim=1).float().cpu()
if pad == False:
seqlen[:] = ((seqlen[:] + 16 - 1) / 16).floor()*16
seqlen[seqlen < 16] = 16
seqlen = seqlen.int()
ntokens = seqlen.sum().item()
else:
batch = attention_mask.shape[0]
maxseqlen = attention_mask.shape[1]
seqlen.fill_(maxseqlen)
seqlen = seqlen.int()
ntokens = batch * maxseqlen
padded_mask = attention_mask.clone()
for i in range(len(seqlen)):
padded_mask[i,:seqlen[i]] = 1
indices = torch.nonzero(padded_mask.flatten(), as_tuple=False).flatten()
if pad==False and fuse_mask == True:
mask = torch.zeros([ntokens], device="cuda", dtype=torch.float16)
unpad_mask(mask, attention_mask, indices)
mask = (1 - mask) * -10000.0
elif pad==False and fuse_mask == False:
padded_mask = (padded_mask.unsqueeze(1) * padded_mask.unsqueeze(2)).unsqueeze(1).half().repeat(1, heads, 1, 1)
indices_mask = torch.nonzero(padded_mask.flatten(), as_tuple=False).flatten()
mask = torch.zeros([len(indices_mask)], device="cuda", dtype=torch.float16)
unpad_mask(mask, padded_mask, indices_mask)
mask = (1 - mask) * -10000.0
elif pad==True and fuse_mask == True:
mask = -10000.0 * (1 - attention_mask).half().view(-1)
elif pad==True and fuse_mask == False:
mask = -10000.0 * (1 - (attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2))).unsqueeze(1).half().repeat(1, heads, 1, 1).view(-1)
return indices, mask, seqlen, ntokens
#######################################################################################################################################################################
class PadInput(torch.autograd.Function):
@staticmethod
def forward(ctx, input, indices, batch, maxseqlen, hidden, ntokens):
ctx.save_for_backward(indices)
ctx.hidden = hidden
ctx.ntokens = ntokens
ntokens = batch*maxseqlen
output = torch.zeros([ntokens,hidden], device="cuda", dtype=torch.float16)
pad_input(output, input, indices)
return output[:ntokens]
@staticmethod
def backward(ctx, grad_output):
indices, = ctx.saved_tensors
grad_input = torch.zeros([ctx.ntokens,ctx.hidden], device="cuda", dtype=torch.float16)
unpad_input(grad_input, grad_output, indices)
return grad_input[:ctx.ntokens], None, None, None, None, None
#######################################################################################################################################################################
class UnpadInput(torch.autograd.Function):
@staticmethod
def forward(ctx, input, indices, batch, maxseqlen, hidden, ntokens):
ctx.save_for_backward(indices)
ctx.hidden = hidden
ctx.ntokens = batch*maxseqlen
output = torch.zeros([ntokens, hidden], device="cuda", dtype=torch.float16)
unpad_input(output, input, indices)
return output[:ntokens]
@staticmethod
def backward(ctx, grad_output):
indices, = ctx.saved_tensors
grad_input = torch.zeros([ctx.ntokens,ctx.hidden], device="cuda", dtype=torch.float16)
pad_input(grad_input, grad_output, indices)
return grad_input[:ctx.ntokens], None, None, None, None, None
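# Illustrative usage sketch (assumes a CUDA device and fp16 activations; not part of the
# original file): drop padding before the encoder and restore it afterwards.
#   indices, mask, seqlen, ntokens = generate_mask(attention_mask, heads=16)
#   flat = hidden_states.view(batch * maxseqlen, hidden_size)
#   unpadded = UnpadInput.apply(flat, indices, batch, maxseqlen, hidden_size, ntokens)
#   ...  # run the unpadded encoder on `unpadded`, using `mask` as the attention bias
#   restored = PadInput.apply(unpadded, indices, batch, maxseqlen, hidden_size, ntokens)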
#######################################################################################################################################################################
# progress bars in model download and training scripts
boto3==1.14.0
h5py==2.10.0
html2text==2020.1.16
ipdb==0.13.2
nltk==3.5
onnxruntime==1.3.0
progressbar==2.5
requests==2.23.0
six==1.15.0
tensorflow==1.13.0rc0
#!/bin/bash
#SBATCH --exclusive
#SBATCH --mem=0
#SBATCH --overcommit
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -eux
# The following variables need to be set
# Base container to be used
# Location of dataset for phase 1
#readonly DATADIR="<machine_specific_path_here>"
# Location of dataset for phase 2
#readonly DATADIR_PHASE2="<machine_specific_path_here>"
# Path to where trained checkpoints will be saved on the system
#readonly CHECKPOINTDIR="<machine_specific_path_here>"
# Path to pretrained Phase1 checkpoint
#readonly CHECKPOINTDIR_PHASE1="<machine_specific_path_here>"
# Vars without defaults
: "${DGXSYSTEM:?DGXSYSTEM not set}"
: "${CONT:?CONT not set}"
# Vars with defaults
: "${NEXP:=5}"
: "${DATESTAMP:=$(date +'%y%m%d%H%M%S%N')}"
: "${LOGDIR:=./results}"
: "${CLEAR_CACHES:=1}"
# Other vars
readonly _logfile_base="${LOGDIR}/${DATESTAMP}"
readonly _cont_name=language_model
readonly _cont_mounts="${DATADIR}:/workspace/data,${DATADIR_PHASE2}:/workspace/data_phase2,${CHECKPOINTDIR}:/results,${CHECKPOINTDIR_PHASE1}:/workspace/phase1,${EVALDIR}:/workspace/evaldata"
srun --ntasks="${SLURM_JOB_NUM_NODES}" --ntasks-per-node=1 mkdir -p "${CHECKPOINTDIR}"
# If the THROUGHPUT_RUN env variable is non-empty, run only a few steps to measure throughput; otherwise stop based on the mlm_accuracy threshold (reached before the very large MAX_STEPS default below)
THROUGHPUT_RUN=${THROUGHPUT_RUN:-""}
if [ -z "$THROUGHPUT_RUN" ]
then
MAX_STEPS=${MAX_STEPS:-1536000}
else
MAX_STEPS=4
fi
PHASE1="\
--train_batch_size=$BATCHSIZE \
--learning_rate=${LR:-6e-3} \
--warmup_proportion=${WARMUP_PROPORTION:-0.0} \
--max_steps=7038 \
--num_steps_per_checkpoint=2500 \
--max_seq_length=128 \
--max_predictions_per_seq=20 \
--input_dir=/workspace/data \
"
PHASE2="\
--train_batch_size=$BATCHSIZE \
--learning_rate=${LR:-4e-3} \
--opt_lamb_beta_1=${OPT_LAMB_BETA_1:-0.9} \
--opt_lamb_beta_2=${OPT_LAMB_BETA_2:-0.999} \
--warmup_proportion=${WARMUP_PROPORTION:-0.0} \
--warmup_steps=${WARMUP_STEPS:-0.0} \
--start_warmup_step=${START_WARMUP_STEP:-0.0} \
--max_steps=$MAX_STEPS \
--phase2 \
--max_seq_length=512 \
--max_predictions_per_seq=76 \
--input_dir=/workspace/data_phase2 \
--init_checkpoint=/workspace/phase1/model.ckpt-28252.pt \
"
PHASES=( "$PHASE1" "$PHASE2" )
PHASE=${PHASE:-2}
echo "***** Running Phase $PHASE *****"
echo "***** SLURM_NNODES: $SLURM_NNODES *****"
echo "***** SLURM_NTASKS: $SLURM_NTASKS *****"
cluster=''
if [[ "${DGXSYSTEM}" == DGX2* ]]; then
cluster='circe'
fi
if [[ "${DGXSYSTEM}" == DGXA100* ]]; then
cluster='selene'
fi
MAX_SAMPLES_TERMINATION=${MAX_SAMPLES_TERMINATION:-14000000}
EVAL_ITER_START_SAMPLES=${EVAL_ITER_START_SAMPLES:-3000000}
EVAL_ITER_SAMPLES=${EVAL_ITER_SAMPLES:-500000}
GRADIENT_STEPS=${GRADIENT_STEPS:-2}
USE_DDP=${USE_DDP:-0}
# Run fixed number of training samples
BERT_CMD="\
./bind.sh --cpu=exclusive --ib=single --cluster=${cluster} -- \
python -u /workspace/bert/run_pretraining.py \
$PHASE2 \
--do_train \
--skip_checkpoint \
--train_mlm_accuracy_window_size=0 \
--target_mlm_accuracy=${TARGET_MLM_ACCURACY:-0.712} \
--weight_decay_rate=${WEIGHT_DECAY_RATE:-0.01} \
--max_samples_termination=${MAX_SAMPLES_TERMINATION} \
--eval_iter_start_samples=${EVAL_ITER_START_SAMPLES} --eval_iter_samples=${EVAL_ITER_SAMPLES} \
--eval_batch_size=16 --eval_dir=/workspace/evaldata \
--cache_eval_data \
--output_dir=/results \
--fp16 --fused_gelu_bias --dense_seq_output --fused_mha ${EXTRA_PARAMS} \
--gradient_accumulation_steps=${GRADIENT_STEPS} \
--log_freq=1 \
--local_rank=\${SLURM_LOCALID} \
--bert_config_path=/workspace/phase1/bert_config.json"
if [[ $USE_DDP != 1 || $GRADIENT_STEPS != 1 ]]; then
BERT_CMD="${BERT_CMD} --allreduce_post_accumulation --allreduce_post_accumulation_fp16"
fi
CONTAINER_PRELOAD_LUSTRE=${CONTAINER_PRELOAD_LUSTRE:-0}
if [[ $CONTAINER_PRELOAD_LUSTRE -gt 0 ]]; then
CONT_FILE="/lustre/fsw/containers/${SLURM_JOBID}_$(basename ${CONT}).squashfs"
# Prepull container to LUSTRE
srun --ntasks=1 enroot import --output ${CONT_FILE} docker://${CONT}
else
CONT_FILE=${CONT}
fi
# Setup container
srun --ntasks="$(( SLURM_JOB_NUM_NODES))" --container-image="${CONT_FILE}" --container-name="${_cont_name}" true
# Run experiments
for _experiment_index in $(seq 1 "${NEXP}"); do
(
echo "Beginning trial ${_experiment_index} of ${NEXP}"
# Clear caches
if [ "${CLEAR_CACHES}" -eq 1 ]; then
srun --ntasks="${SLURM_JOB_NUM_NODES}" bash -c "echo -n 'Clearing cache on ' && hostname && sync && sudo /sbin/sysctl vm.drop_caches=3"
srun --ntasks="${SLURM_JOB_NUM_NODES}" --container-name="${_cont_name}" python -c "
import mlperf_logger
mlperf_logger.log_event(key=mlperf_logger.constants.CACHE_CLEAR, value=True)"
fi
# Run experiment
srun -l --mpi=none --ntasks="$(( SLURM_JOB_NUM_NODES * DGXNGPU ))" --ntasks-per-node="${DGXNGPU}" --container-name="${_cont_name}" --container-mounts="${_cont_mounts}" sh -c "/workspace/bert/run_and_time.sh \"${BERT_CMD}\" ${SEED:-$RANDOM} "
) |& tee "${_logfile_base}_${_experiment_index}.log"
done
# Cleanup
if [[ $CONTAINER_PRELOAD_LUSTRE -gt 0 ]]; then
srun --ntasks=1 rm ${CONT_FILE}
fi
#!/bin/bash
SLURM_NTASKS_PER_NODE=${SLURM_NTASKS_PER_NODE:-$DGXNGPU}
SLURM_JOB_ID=${SLURM_JOB_ID:-$RANDOM}
MULTI_NODE=${MULTI_NODE:-''}
echo "Run vars: id $SLURM_JOB_ID gpus $SLURM_NTASKS_PER_NODE mparams $MULTI_NODE"
# Start timing
START=$(date +%s)
START_FMT=$(date +%Y-%m-%d\ %r)
echo "STARTING TIMING RUN AT ${START_FMT}"
BERT_CMD=${1}
SEED=${2}
# Options
set -x
eval "${BERT_CMD} --seed=${SEED}"
# End timing
END=$(date +%s)
END_FMT=$(date +%Y-%m-%d\ %r)
echo "ENDING TIMING RUN AT ${END_FMT}"
# Report result
RESULT=$(( ${END} - ${START} ))
RESULT_NAME="bert"
echo "RESULT,${RESULT_NAME},${SEED},${RESULT},${USER},${START_FMT}"
set +x
# coding=utf-8
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
# Copyright 2020 MLBenchmark Group. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BERT Pretraining"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import csv
import h5py
import os
import glob
import numpy as np
import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, Dataset
from torch.utils.data.distributed import DistributedSampler
import logging
import math
import multiprocessing
import numpy as np
import os
import random
import re
import time
from collections import OrderedDict
from concurrent.futures import ProcessPoolExecutor
from modeling import BertForPreTraining, BertConfig
from apex.optimizers import FusedLAMB
from schedulers import LinearWarmupPolyDecayScheduler
import utils
import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, Dataset
from torch.utils.data.distributed import DistributedSampler
import amp_C
import apex_C
from apex import amp
from apex.amp import _amp_state
from apex.optimizers import FusedLAMB
from apex.parallel import DistributedDataParallel as DDP
from apex.parallel.distributed import flat_dist_call
from file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from modeling import BertForPreTraining, BertConfig
from schedulers import LinearWarmUpScheduler, LinearWarmupPolyDecayScheduler
import mlperf_logger
from mhalib import *
# Global variables
skipped_steps = 0
cached_batches = []
class WorkerInitObj(object):
def __init__(self, seed):
self.seed = seed
def __call__(self, id):
np.random.seed(seed=self.seed + id)
random.seed(self.seed + id)
def create_pretraining_dataset(input_file, max_pred_length, shared_list, args, worker_init_fn):
train_data = pretraining_dataset(input_file=input_file, max_pred_length=max_pred_length)
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler,
batch_size=args.train_batch_size, num_workers=4, worker_init_fn=worker_init_fn,
pin_memory=True)
return train_dataloader, input_file
def create_eval_dataset(args, worker_init_fn):
eval_data = []
for eval_file in sorted(os.listdir(args.eval_dir)):
eval_file_path = os.path.join(args.eval_dir, eval_file)
if os.path.isfile(eval_file_path) and 'part' in eval_file_path:
eval_data.extend(pretraining_dataset(eval_file_path, max_pred_length=args.max_predictions_per_seq))
if len(eval_data) > args.num_eval_examples:
eval_data = eval_data[:args.num_eval_examples]
break
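# Shard the eval examples across ranks: each rank takes a contiguous chunk, and the
# first `remainder` ranks take one extra example so every example is used exactly once.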
if torch.distributed.is_initialized():
chunk_size = args.num_eval_examples // torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
remainder = args.num_eval_examples % torch.distributed.get_world_size()
if rank<remainder:
eval_data = eval_data[(chunk_size+1)*rank : (chunk_size+1)*(rank+1)]
else:
eval_data = eval_data[chunk_size*rank+remainder : chunk_size*(rank+1)+remainder]
eval_sampler = SequentialSampler(eval_data)
eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size,
num_workers=4, worker_init_fn=worker_init_fn, pin_memory=True)
return eval_dataloader
class pretraining_dataset(Dataset):
def __init__(self, input_file, max_pred_length):
self.input_file = input_file
self.max_pred_length = max_pred_length
f = h5py.File(input_file, "r")
keys = ['input_ids', 'input_mask', 'segment_ids', 'masked_lm_positions', 'masked_lm_ids',
'next_sentence_labels']
self.inputs = [np.asarray(f[key][:]) for key in keys]
f.close()
def __len__(self):
'Denotes the total number of samples'
return len(self.inputs[0])
def __getitem__(self, index):
[input_ids, input_mask, segment_ids, masked_lm_positions, masked_lm_ids, next_sentence_labels] = [
torch.from_numpy(input[index].astype(np.int64)) if indice < 5 else torch.from_numpy(
np.asarray(input[index].astype(np.int64))) for indice, input in enumerate(self.inputs)]
masked_lm_labels = torch.ones(input_ids.shape, dtype=torch.long) * -1
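# positions that are not masked keep the label -1 and are ignored by the loss (ignore_index=-1)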
index = self.max_pred_length
# store number of masked tokens in index
padded_mask_indices = (masked_lm_positions == 0).nonzero()
if len(padded_mask_indices) != 0:
index = padded_mask_indices[0].item()
masked_lm_labels[masked_lm_positions[:index]] = masked_lm_ids[:index]
return [input_ids, segment_ids, input_mask,
masked_lm_labels, next_sentence_labels]
def parse_arguments():
parser = argparse.ArgumentParser()
## Required parameters
parser.add_argument("--input_dir",
default=None,
type=str,
required=True,
help="The input data dir. Should contain .hdf5 files for the task.")
parser.add_argument("--bert_model", default="bert-large-uncased", type=str,
help="Bert pre-trained model selected in the list: bert-base-uncased, "
"bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")
parser.add_argument("--output_dir",
default=None,
type=str,
required=True,
help="The output directory where the model checkpoints will be written.")
## Other parameters
parser.add_argument("--eval_dir",
default=None,
type=str,
help="The eval data dir. Should contain .hdf5 files for the task.")
parser.add_argument("--eval_iter_start_samples",
default=3000000,
type=int,
help="Sample to begin performing eval.")
parser.add_argument("--eval_iter_samples",
default=-1,
type=int,
help="If set to -1, disable eval, \
else evaluate every eval_iter_samples during training")
parser.add_argument("--num_eval_examples",
default=10000,
type=int,
help="number of eval examples to run eval on")
parser.add_argument("--cache_eval_data",
default=False,
action='store_true',
help="whether to cache evaluation data on GPU")
parser.add_argument("--init_checkpoint",
default=None,
type=str,
help="The initial checkpoint to start training from.")
parser.add_argument("--init_tf_checkpoint",
default=None,
type=str,
help="The initial TF checkpoint to start training from.")
parser.add_argument("--max_seq_length",
default=512,
type=int,
help="The maximum total input sequence length after WordPiece tokenization. \n"
"Sequences longer than this will be truncated, and sequences shorter \n"
"than this will be padded.")
parser.add_argument("--max_predictions_per_seq",
default=76,
type=int,
help="The maximum total of masked tokens in input sequence")
parser.add_argument("--train_batch_size",
default=18,
type=int,
help="Total batch size for training.")
parser.add_argument("--eval_batch_size",
default=128,
type=int,
help="Total batch size for training.")
parser.add_argument("--learning_rate",
default=4e-5,
type=float,
help="The initial learning rate for LAMB.")
parser.add_argument("--weight_decay_rate",
default=0.01,
type=float,
help="weight decay rate for LAMB.")
parser.add_argument("--opt_lamb_beta_1",
default=0.9,
type=float,
help="LAMB beta1.")
parser.add_argument("--opt_lamb_beta_2",
default=0.999,
type=float,
help="LAMB beta2.")
parser.add_argument("--max_steps",
default=1536,
type=float,
help="Total number of training steps to perform.")
parser.add_argument("--max_samples_termination",
default=14000000,
type=float,
help="Total number of training samples to run.")
parser.add_argument("--warmup_proportion",
default=0.01,
type=float,
help="Proportion of optimizer update steps to perform linear learning rate warmup for. "
"Typically 1/8th of steps for Phase2")
parser.add_argument("--warmup_steps",
default=0,
type=float,
help="Number of optimizer update steps to perform linear learning rate warmup for. "
"Typically 1/8th of steps for Phase2")
parser.add_argument("--start_warmup_step",
default=0,
type=float,
help="Starting step for warmup. ")
parser.add_argument("--local_rank",
type=int,
default=-1,
help="local_rank for distributed training on gpus")
parser.add_argument('--seed',
type=int,
default=42,
help="random seed for initialization")
parser.add_argument('--gradient_accumulation_steps',
type=int,
default=1,
help="Number of updates steps to accumualte before performing a backward/update pass.")
parser.add_argument('--fp16',
default=False,
action='store_true',
help="Whether to use 16-bit float precision instead of 32-bit")
parser.add_argument('--loss_scale',
type=float, default=0.0,
help='Loss scaling, positive power of 2 values can improve fp16 convergence.')
parser.add_argument('--log_freq',
type=float, default=1.0,
help='frequency of logging loss.')
parser.add_argument('--checkpoint_activations',
default=False,
action='store_true',
help="Whether to use gradient checkpointing")
parser.add_argument("--resume_from_checkpoint",
default=False,
action='store_true',
help="Whether to resume training from checkpoint. If set, precedes init_checkpoint/init_tf_checkpoint")
parser.add_argument('--keep_n_most_recent_checkpoints',
type=int,
default=20,
help="Number of checkpoints to keep (rolling basis).")
parser.add_argument('--num_samples_per_checkpoint',
type=int,
default=500000,
help="Number of update steps until a model checkpoint is saved to disk.")
parser.add_argument('--min_samples_to_start_checkpoints',
type=int,
default=3000000,
help="Number of update steps until model checkpoints start saving to disk.")
parser.add_argument('--skip_checkpoint',
default=False,
action='store_true',
help="Whether to save checkpoints")
parser.add_argument('--phase2',
default=False,
action='store_true',
help="Only required for checkpoint saving format")
parser.add_argument('--allreduce_post_accumulation',
default=False,
action='store_true',
help="Whether to do allreduces during gradient accumulation steps.")
parser.add_argument('--allreduce_post_accumulation_fp16',
default=False,
action='store_true',
help="Whether to do fp16 allreduce post accumulation.")
parser.add_argument("--do_train",
default=False,
action='store_true',
help="Whether to run training.")
parser.add_argument("--unpad",
default=False,
action='store_true',
help="Whether to run with unpadding.")
parser.add_argument("--pad",
default=False,
action='store_true',
help="Whether to pad tokens.")
parser.add_argument("--enable_fuse_dropout",
default=False,
action='store_true',
help="Whether to disable fusion of attention mask to softmax and dropout.")
parser.add_argument("--disable_fuse_mask",
default=False,
action='store_true',
help="Whether to disable fusion of the attention mask to softmax.")
parser.add_argument("--disable_fuse_scale",
default=False,
action='store_true',
help="Whether to disable fusion of the scaling to BMM1.")
parser.add_argument("--disable_fuse_qkv",
default=False,
action='store_true',
help="Whether to disable fusion of the QKV GEMMs.")
parser.add_argument("--disable_apex_softmax",
default=False,
action='store_true',
help="Whether to disable apex softmax.")
parser.add_argument("--enable_stream",
default=False,
action='store_true',
help="Enable use of streams for pad case.")
parser.add_argument("--fused_mha",
default=False,
action='store_true',
help="Whether to run with optimizations.")
parser.add_argument("--fused_gelu_bias",
default=False,
action='store_true',
help="Whether to run with optimizations.")
parser.add_argument("--dense_seq_output",
default=False,
action='store_true',
help="Whether to run with optimizations.")
parser.add_argument("--use_env",
action='store_true',
help="Whether to read local rank from ENVVAR")
parser.add_argument('--bert_config_path',
type=str,
default="/workspace/phase1",
help="Path bert_config.json is located in")
parser.add_argument('--target_mlm_accuracy',
type=float,
default=0.0,
help="Stop training after reaching this Masked-LM accuracy")
parser.add_argument('--train_mlm_accuracy_window_size',
type=int,
default=0,
help="Average accuracy over this amount of batches before performing a stopping criterion test")
parser.add_argument('--num_epochs_to_generate_seeds_for',
type=int,
default=2,
help="Number of epochs to plan seeds for. Same set across all workers.")
args = parser.parse_args()
# Check we've been given a checkpoint
assert args.init_checkpoint is not None or args.init_tf_checkpoint is not None or found_resume_checkpoint(args), \
"Must specify --init_checkpoint, --init_tf_checkpoint or have ckpt to resume from in --output_dir of the form *.pt"
assert not (args.init_checkpoint is not None and args.init_tf_checkpoint is not None), \
"Can only specify one of --init_checkpoint and --init_tf_checkpoint"
return args
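# Example invocation (illustrative; paths are placeholders, mirroring the launch script above):
#   python run_pretraining.py --do_train --phase2 --fp16 --unpad \
#       --input_dir=/workspace/data_phase2 --output_dir=/results \
#       --bert_config_path=/workspace/phase1/bert_config.json \
#       --init_checkpoint=/workspace/phase1/model.ckpt-28252.pt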
# Returns true only if resuming from a checkpoint found in output_dir.
# init_checkpoint and init_tf_checkpoint are not considered
def found_resume_checkpoint(args):
if args.phase2:
checkpoint_str = "phase2_ckpt*.pt"
else:
checkpoint_str = "phase1_ckpt*.pt"
return args.resume_from_checkpoint and len(glob.glob(os.path.join(args.output_dir, checkpoint_str))) > 0
def setup_training(args):
assert (torch.cuda.is_available())
if args.local_rank == -1:
device = torch.device("cuda")
args.n_gpu = torch.cuda.device_count()
else:
torch.cuda.set_device(args.local_rank)
device = torch.device("cuda", args.local_rank)
# Initializes the distributed backend which will take care of synchronizing nodes/GPUs
torch.distributed.init_process_group(backend='nccl', init_method='env://')
args.n_gpu = torch.distributed.get_world_size()
print("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
device, args.n_gpu, bool(args.local_rank != -1), args.fp16))
if args.gradient_accumulation_steps < 1:
raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
args.gradient_accumulation_steps))
if args.train_batch_size % args.gradient_accumulation_steps != 0:
raise ValueError("Invalid gradient_accumulation_steps parameter: {}, batch size {} should be divisible".format(
args.gradient_accumulation_steps, args.train_batch_size))
args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps
if not (args.do_train or (args.eval_dir and args.eval_iter_samples <= 0)):
raise ValueError("Either `do_train` must be set, or the script must be in offline eval mode (`eval_dir` set and `eval_iter_samples` <= 0)")
if not args.resume_from_checkpoint or not os.path.exists(args.output_dir):
os.makedirs(args.output_dir, exist_ok=True)
return device, args
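# Renames standard BERT attention state-dict keys (self.query/key/value, output.dense,
# output.LayerNorm) to the names expected by the fused multi-head attention module.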
def remap_attn_parameters(model_dict):
res_dict = OrderedDict()
for k in model_dict:
if 'attention' in k:
if 'self.query.weight' in k:
new_k = k.replace('self.query.weight', 'multi_head_attention.q_weight')
elif 'self.key.weight' in k:
new_k = k.replace('self.key.weight', 'multi_head_attention.k_weight')
elif 'self.value.weight' in k:
new_k = k.replace('self.value.weight', 'multi_head_attention.v_weight')
elif 'self.query.bias' in k:
new_k = k.replace('self.query.bias', 'multi_head_attention.q_bias')
elif 'self.key.bias' in k:
new_k = k.replace('self.key.bias', 'multi_head_attention.k_bias')
elif 'self.value.bias' in k:
new_k = k.replace('self.value.bias', 'multi_head_attention.v_bias')
elif 'output.dense.weight' in k:
new_k = k.replace('output.dense.weight', 'multi_head_attention.out_proj_weight')
elif 'output.dense.bias' in k:
new_k = k.replace('output.dense.bias', 'multi_head_attention.out_proj_bias')
elif 'output.LayerNorm.weight' in k:
new_k = k.replace('output.LayerNorm.weight', 'layer_norm.weight')
elif 'output.LayerNorm.bias' in k:
new_k = k.replace('output.LayerNorm.bias', 'layer_norm.bias')
else:
new_k = k
else:
new_k = k
res_dict[new_k] = model_dict[k]
model_dict.clear()
return res_dict
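# Builds BertForPreTraining from bert_config.json with the requested fusion flags, loads an
# init/TF/resume checkpoint, and sets up FusedLAMB, the warmup/poly-decay LR scheduler,
# AMP (O2), and distributed data parallelism.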
def prepare_model_and_optimizer(args, device):
global_step = 0
args.resume_step = 0
checkpoint = None
config = BertConfig.from_json_file(args.bert_config_path)
config.fused_mha = args.fused_mha
config.fused_gelu_bias = args.fused_gelu_bias
config.dense_seq_output = args.dense_seq_output
config.unpad = args.unpad
config.pad = args.pad
config.fuse_qkv = not args.disable_fuse_qkv
config.fuse_scale = not args.disable_fuse_scale
config.fuse_mask = not args.disable_fuse_mask
config.fuse_dropout = args.enable_fuse_dropout
config.apex_softmax = not args.disable_apex_softmax
config.enable_stream = args.enable_stream
if config.fuse_mask: config.apex_softmax = True
if not config.pad: config.enable_stream = True
if config.unpad: config.fused_mha = False
# Padding for divisibility by 8
if config.vocab_size % 8 != 0:
config.vocab_size += 8 - (config.vocab_size % 8)
# Load from Pyt checkpoint - either given as init_checkpoint, or picked up from output_dir if found
if args.init_checkpoint is not None or found_resume_checkpoint(args):
# Prepare model
model = BertForPreTraining(config)
if args.init_checkpoint is None: # finding checkpoint in output_dir
checkpoint_str = "phase2_ckpt_*.pt" if args.phase2 else "phase1_ckpt_*.pt"
model_names = [f for f in glob.glob(os.path.join(args.output_dir, checkpoint_str))]
global_step = max([int(x.split('.pt')[0].split('_')[-1].strip()) for x in model_names])
args.resume_step = global_step #used for throughput computation
resume_init_checkpoint = os.path.join(args.output_dir, checkpoint_str.replace("*", str(global_step)))
print("Setting init checkpoint to %s - which is the latest in %s" %(resume_init_checkpoint, args.output_dir))
checkpoint=torch.load(resume_init_checkpoint, map_location="cpu")
else:
checkpoint=torch.load(args.init_checkpoint, map_location="cpu")["model"]
# Fused MHA requires a remapping of checkpoint parameters
if config.fused_mha:
checkpoint_remapped = remap_attn_parameters(checkpoint)
model.load_state_dict(checkpoint_remapped, strict=False)
else:
model.load_state_dict(checkpoint, strict=True)
else: #Load from TF Checkpoint
model = BertForPreTraining.from_pretrained(args.init_tf_checkpoint, from_tf=True, config=config)
model.to(device)
param_optimizer = list(model.named_parameters())
no_decay = ['bias', 'gamma', 'beta', 'LayerNorm']
optimizer_grouped_parameters = [
{'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay_rate},
{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}]
mlperf_logger.log_event(key=mlperf_logger.constants.OPT_BASE_LR,
value=args.learning_rate, sync=False)
optimizer = FusedLAMB(optimizer_grouped_parameters,
lr=args.learning_rate,
betas=(args.opt_lamb_beta_1, args.opt_lamb_beta_2))
mlperf_logger.log_event(key='opt_epsilon', value=optimizer.defaults['eps'],
sync=False)
b1, b2 = optimizer.defaults['betas']
mlperf_logger.log_event(key='opt_lamb_beta_1', value=b1, sync=False)
mlperf_logger.log_event(key='opt_lamb_beta_2', value=b2, sync=False)
mlperf_logger.log_event(key='opt_lamb_weight_decay_rate',
value=optimizer.defaults['weight_decay'],
sync=False)
if args.warmup_steps == 0:
warmup_steps = int(args.max_steps * args.warmup_proportion)
warmup_start = 0
else:
warmup_steps = args.warmup_steps
warmup_start = args.start_warmup_step
lr_scheduler = LinearWarmupPolyDecayScheduler(optimizer, start_warmup_steps=warmup_start, warmup_steps=warmup_steps,
total_steps=args.max_steps, end_learning_rate=0.0, degree=1.0)
if args.fp16:
if args.loss_scale == 0:
model, optimizer = amp.initialize(model, optimizer, opt_level="O2", loss_scale="dynamic")
else:
model, optimizer = amp.initialize(model, optimizer, opt_level="O2", loss_scale=args.loss_scale)
amp._amp_state.loss_scalers[0]._loss_scale = float(os.getenv("INIT_LOSS_SCALE", 2**20))
if found_resume_checkpoint(args):
optimizer.load_state_dict(checkpoint['optimizer']) #restores m,v states (only if resuming checkpoint, not for init_checkpoint and init_tf_checkpoint for now)
# Restore AMP master parameters
if args.fp16:
optimizer._lazy_init_maybe_master_weights()
optimizer._amp_stash.lazy_init_called = True
optimizer.load_state_dict(checkpoint['optimizer'])
for param, saved_param in zip(amp.master_params(optimizer), checkpoint['master params']):
param.data.copy_(saved_param.data)
if args.local_rank != -1:
if not args.allreduce_post_accumulation:
model = DDP(model, message_size=250000000, gradient_predivide_factor=torch.distributed.get_world_size())
else:
flat_dist_call([param.data for param in model.parameters()], torch.distributed.broadcast, (0,) )
return model, optimizer, lr_scheduler, checkpoint, global_step
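# Performs one optimizer step. With --allreduce_post_accumulation, gradients are flattened,
# pre-divided by (world_size * accumulation steps), all-reduced, unscaled, and the step is
# skipped when an overflow is detected.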
def take_optimizer_step(args, optimizer, model, overflow_buf, global_step):
global skipped_steps
if args.allreduce_post_accumulation:
# manually allreduce gradients after all accumulation steps
# check for Inf/NaN
# 1. allocate an uninitialized buffer for flattened gradient
scaler = _amp_state.loss_scalers[0]
master_grads = [p.grad for p in amp.master_params(optimizer) if p.grad is not None]
flat_grad_size = sum(p.numel() for p in master_grads)
allreduce_dtype = torch.float16 if args.allreduce_post_accumulation_fp16 else torch.float32
flat_raw = torch.empty(flat_grad_size, device='cuda', dtype=allreduce_dtype)
# 2. combine unflattening and predivision of unscaled 'raw' gradient
allreduced_views = apex_C.unflatten(flat_raw, master_grads)
overflow_buf.zero_()
amp_C.multi_tensor_scale(65536,
overflow_buf,
[master_grads, allreduced_views],
scaler.loss_scale() / (torch.distributed.get_world_size() * args.gradient_accumulation_steps))
# 3. sum gradient across ranks. Because of the predivision, this averages the gradient
torch.distributed.all_reduce(flat_raw)
# 4. combine unscaling and unflattening of allreduced gradient
overflow_buf.zero_()
amp_C.multi_tensor_scale(65536,
overflow_buf,
[allreduced_views, master_grads],
1./scaler.loss_scale())
# 5. update loss scale
scaler = _amp_state.loss_scalers[0]
old_overflow_buf = scaler._overflow_buf
scaler._overflow_buf = overflow_buf
had_overflow = scaler.update_scale()
scaler._overflow_buf = old_overflow_buf
# 6. call optimizer step function
if had_overflow == 0:
optimizer.step()
global_step += 1
else:
# Overflow detected, print message and clear gradients
skipped_steps += 1
if _amp_state.opt_properties.master_weights:
for param in optimizer._amp_stash.all_fp32_from_fp16_params:
param.grad = None
for param in model.parameters():
param.grad = None
else:
optimizer.step()
for param in model.parameters():
param.grad = None
global_step += 0 if _amp_state.loss_scalers[0]._has_overflow else 1
return global_step
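# Evaluates MLM loss/accuracy over the eval dataloader, weighting each batch by its number of
# masked tokens and all-reducing the totals across ranks before averaging.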
def run_eval(model, eval_dataloader, device, num_eval_examples, first_eval=False, use_cache=False):
model.eval()
total_eval_loss, total_eval_mlm_acc = 0.0, 0.0
total_masked = 0
# on first eval, load and cache data on GPU
if first_eval and use_cache:
for batch in eval_dataloader:
cached_batches.append([t.to(device) for t in batch])
with torch.no_grad():
for batch in cached_batches if use_cache else eval_dataloader:
if not use_cache: batch = [t.to(device) for t in batch]
input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels = batch
loss, mlm_acc, num_masked = model(input_ids=input_ids, token_type_ids=segment_ids, attention_mask=input_mask,
masked_lm_labels=masked_lm_labels, next_sentence_label=next_sentence_labels)
total_eval_loss += loss * num_masked
total_eval_mlm_acc += mlm_acc * num_masked
total_masked += num_masked
model.train()
#total_eval_mlm_acc and total_eval_loss are already tensors, total_masked is not
total_masked = torch.tensor(total_masked, device=device, dtype=torch.int64)
if torch.distributed.is_initialized():
#Collect total scores from all ranks
torch.distributed.all_reduce(total_eval_mlm_acc, op=torch.distributed.ReduceOp.SUM)
torch.distributed.all_reduce(total_eval_loss, op=torch.distributed.ReduceOp.SUM)
torch.distributed.all_reduce(total_masked, op=torch.distributed.ReduceOp.SUM)
# Average by number of examples
total_eval_mlm_acc /= total_masked
total_eval_loss /= total_masked
return total_eval_loss.item(), total_eval_mlm_acc.item()
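# Training driver: seeds every worker, prepares the model and optimizer, cycles through the
# shuffled shard files, periodically evaluates against --target_mlm_accuracy, checkpoints,
# and emits MLPerf log events.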
def main():
args = parse_arguments()
status = 'aborted' # later set to 'success' if termination criteria met
mlperf_logger.log_start(key=mlperf_logger.constants.INIT_START,
log_all_ranks=True, sync=False)
if args.use_env and 'LOCAL_RANK' in os.environ:
args.local_rank = int(os.environ['LOCAL_RANK'])
device, args = setup_training(args)
mlperf_logger.mlperf_submission_log('bert')
worker_seeds, shuffling_seeds = utils.setup_seeds(args.seed, args.num_epochs_to_generate_seeds_for, device)
worker_seed = worker_seeds[torch.distributed.get_rank()]
random.seed(worker_seed)
np.random.seed(worker_seed)
torch.manual_seed(worker_seed)
worker_init = WorkerInitObj(worker_seed)
mlperf_logger.log_event(key=mlperf_logger.constants.SEED, value=args.seed,
sync=False)
mlperf_logger.log_event(key=mlperf_logger.constants.GLOBAL_BATCH_SIZE,
value=global_batch_size(args), sync=False)
mlperf_logger.log_event(key='opt_gradient_accumulation_steps',
value=args.gradient_accumulation_steps, sync=False)
mlperf_logger.log_event(key='max_predictions_per_seq',
value=args.max_predictions_per_seq, sync=False)
mlperf_logger.log_event(key='opt_learning_rate_training_steps',
value=args.max_steps, sync=False)
mlperf_logger.log_event(key='num_warmup_steps',
value=int(args.warmup_proportion*args.max_steps) if args.warmup_steps==0 else args.warmup_steps,
sync=False)
if utils.is_main_process():
print("parsed args:")
print(args)
# Prepare optimizer
model, optimizer, lr_scheduler, checkpoint, global_step = prepare_model_and_optimizer(args, device)
samples_trained = global_step * args.train_batch_size * args.gradient_accumulation_steps * args.n_gpu
if args.unpad:
torch.cuda.synchronize()
InitMHACUDAExtension()
torch.cuda.synchronize()
final_loss = float("inf")
train_time_raw = float("inf")
raw_train_start = time.time()
if args.do_train:
model.train()
most_recent_ckpts_paths = []
average_loss = 0.0 # averaged loss every args.log_freq steps
epoch = 1
training_steps = 0
end_training, converged = False, False
samples_trained_prev = 0
eval_count = 0
pool = ProcessPoolExecutor(1)
if args.target_mlm_accuracy:
if args.train_mlm_accuracy_window_size > 0:
accuracy_scores = []
avg_mlm_accuracy = torch.Tensor([0]).cuda()
first_epoch = True
if found_resume_checkpoint(args):
f_start_id = checkpoint['files'][0]
files = checkpoint['files'][1:]
num_files = len(files)
else:
files = [os.path.join(args.input_dir, f) for f in os.listdir(args.input_dir) if
os.path.isfile(os.path.join(args.input_dir, f)) and 'part' in f]
files.sort()
num_files = len(files)
random.Random(shuffling_seeds[epoch]).shuffle(files)
f_start_id = 0
mlperf_logger.log_end(key=mlperf_logger.constants.INIT_STOP, sync=False)
mlperf_logger.log_start(key=mlperf_logger.constants.RUN_START, sync=True)
mlperf_logger.barrier()
# Start prefetching eval dataset
if args.eval_dir:
eval_dataset_future = pool.submit(create_eval_dataset, args, worker_init_fn=worker_init)
while global_step < args.max_steps and not end_training:
mlperf_logger.log_start(key=mlperf_logger.constants.EPOCH_START,
metadata={'epoch_num': epoch}, sync=False)
mlperf_logger.log_start(key=mlperf_logger.constants.BLOCK_START,
metadata={'first_epoch_num': epoch,
'epoch_count': 1},
sync=False)
if utils.is_main_process():
print("parsed args:")
print(args)
now_time = time.time()
now_step = global_step
now_skipped = skipped_steps
print("epoch:", epoch)
thread = None
# Reshuffle file list on subsequent epochs
if not first_epoch:
files = [os.path.join(args.input_dir, f) for f in os.listdir(args.input_dir) if
os.path.isfile(os.path.join(args.input_dir, f)) and 'part' in f]
files.sort()
num_files = len(files)
random.Random(shuffling_seeds[epoch]).shuffle(files)
f_start_id = 0
first_epoch = False
shared_file_list = {}
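# Pick this rank's shard for the first file round; with more ranks than files, indices wrap
# around modulo num_files with an extra per-round offset.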
if torch.distributed.is_initialized() and torch.distributed.get_world_size() > num_files:
remainder = torch.distributed.get_world_size() % num_files
data_file = files[(f_start_id*torch.distributed.get_world_size() + torch.distributed.get_rank() +
remainder * f_start_id) % num_files]
else:
data_file = files[(f_start_id*torch.distributed.get_world_size() + torch.distributed.get_rank()) % num_files]
previous_file = data_file
train_data = pretraining_dataset(data_file, args.max_predictions_per_seq)
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler,
batch_size=args.train_batch_size, num_workers=4, worker_init_fn=worker_init, pin_memory=True)
overflow_buf = None
if args.allreduce_post_accumulation:
overflow_buf = torch.cuda.IntTensor([0])
for f_id in range(f_start_id + 1, len(files)):
if torch.distributed.get_world_size() > num_files:
data_file = files[(f_id*torch.distributed.get_world_size() + torch.distributed.get_rank() +
remainder * f_id) % num_files]
else:
data_file = files[(f_id*torch.distributed.get_world_size() + torch.distributed.get_rank())%num_files]
previous_file = data_file
dataset_future = pool.submit(create_pretraining_dataset, data_file, args.max_predictions_per_seq, shared_file_list, args, worker_init_fn=worker_init)
for step, batch in enumerate(train_dataloader):
training_steps += 1
update_step = training_steps % args.gradient_accumulation_steps == 0
batch = [t.to(device) for t in batch]
input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels = batch
loss, mlm_acc, _ = model(input_ids=input_ids, token_type_ids=segment_ids, attention_mask=input_mask,
masked_lm_labels=masked_lm_labels, next_sentence_label=next_sentence_labels,
checkpoint_activations=args.checkpoint_activations)
divisor = args.gradient_accumulation_steps
if args.gradient_accumulation_steps > 1:
if not args.allreduce_post_accumulation:
# this division was merged into predivision
loss = loss / args.gradient_accumulation_steps
divisor = 1.0
if args.fp16:
with amp.scale_loss(loss, optimizer, delay_overflow_check=args.allreduce_post_accumulation) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()
average_loss += loss.item()
if update_step:
lr_scheduler.step() # learning rate warmup
global_step = take_optimizer_step(args, optimizer, model, overflow_buf, global_step)
samples_trained = global_step * args.train_batch_size * args.gradient_accumulation_steps * args.n_gpu
if (args.eval_dir and args.eval_iter_samples > 0 and
samples_trained >= args.eval_iter_start_samples + eval_count * args.eval_iter_samples):
# on first eval, get eval_dataloader
if eval_count == 0:
eval_dataloader = eval_dataset_future.result(timeout=None)
samples_trained_prev = samples_trained
eval_avg_loss, eval_avg_mlm_accuracy = run_eval(model, eval_dataloader, device, args.num_eval_examples,
first_eval=(eval_count == 0), use_cache=args.cache_eval_data)
if utils.is_main_process():
mlperf_logger.log_event(key=mlperf_logger.constants.EVAL_ACCURACY, value=eval_avg_mlm_accuracy, metadata={'epoch_num': epoch}, sync=False)
print({"global_steps": global_step, "eval_loss": eval_avg_loss, "eval_mlm_accuracy":eval_avg_mlm_accuracy})
if args.target_mlm_accuracy:
if eval_avg_mlm_accuracy >= args.target_mlm_accuracy:
end_training, converged = True, True
if utils.is_main_process():
print("%f >= %f, Target MLM Accuracy reached at %d"%(eval_avg_mlm_accuracy, args.target_mlm_accuracy, global_step))
eval_count += 1
if args.target_mlm_accuracy and args.train_mlm_accuracy_window_size > 0:
accuracy_scores.append(mlm_acc)
if update_step:
accuracy_scores = accuracy_scores[-args.train_mlm_accuracy_window_size * args.gradient_accumulation_steps:]
avg_mlm_accuracy[0] = sum(accuracy_scores) / len(accuracy_scores)
torch.distributed.all_reduce(avg_mlm_accuracy, op=torch.distributed.ReduceOp.SUM)
avg_mlm_accuracy /= torch.distributed.get_world_size()
if training_steps % (args.log_freq * args.gradient_accumulation_steps) == 0:
samples_trained = global_step * args.train_batch_size * args.gradient_accumulation_steps * args.n_gpu
if utils.is_main_process():
time_interval = time.time() - now_time
step_interval = global_step - now_step
skip_interval = skipped_steps - now_skipped
now_time = time.time()
now_step = global_step
now_skipped = skipped_steps
training_perf = args.train_batch_size * args.gradient_accumulation_steps * args.n_gpu \
* (step_interval + skip_interval) / time_interval
if args.train_mlm_accuracy_window_size > 0:
print({"training_steps": training_steps,
"average_loss": average_loss / (args.log_freq * divisor),
"step_loss": loss.item() * args.gradient_accumulation_steps / divisor,
"learning_rate": optimizer.param_groups[0]['lr'],
"seq/s": training_perf,
"global_steps": now_step,
"samples_trained": samples_trained,
"skipped_steps": now_skipped,
"timestamp": now_time,
"mlm_accuracy": avg_mlm_accuracy[0].item()})
else:
print({"training_steps": training_steps,
"average_loss": average_loss / (args.log_freq * divisor),
"step_loss": loss.item() * args.gradient_accumulation_steps / divisor,
"learning_rate": optimizer.param_groups[0]['lr'],
"seq/s": training_perf,
"global_steps": now_step,
"samples_trained": samples_trained,
"skipped_steps": now_skipped,
"timestamp": now_time})
average_loss = 0
if global_step >= args.max_steps or end_training:
status = 'success' if converged else 'aborted'
end_training = True
train_time_raw = time.time() - raw_train_start
last_num_steps = int(training_steps / args.gradient_accumulation_steps) % args.log_freq
last_num_steps = args.log_freq if last_num_steps == 0 else last_num_steps
average_loss = torch.tensor(average_loss, dtype=torch.float32).cuda()
average_loss = average_loss / (last_num_steps * divisor)
if (torch.distributed.is_initialized()):
average_loss /= torch.distributed.get_world_size()
torch.distributed.all_reduce(average_loss)
final_loss = average_loss.item()
if utils.is_main_process():
if args.train_mlm_accuracy_window_size > 0:
print((epoch, training_steps / args.gradient_accumulation_steps, ), {"final_loss": final_loss,
"final_mlm_accuracy": avg_mlm_accuracy[0].item()})
else:
print((epoch, training_steps / args.gradient_accumulation_steps, ), {"final_loss": final_loss})
if end_training or (samples_trained - samples_trained_prev >= args.num_samples_per_checkpoint and samples_trained >= args.min_samples_to_start_checkpoints):
samples_trained_prev = samples_trained
if utils.is_main_process() and not args.skip_checkpoint:
# Save a trained model
model_to_save = model.module if hasattr(model,
'module') else model  # Only save the model itself
if args.phase2:
output_save_file = os.path.join(args.output_dir, "phase2_ckpt_{}.pt".format(samples_trained))
else:
output_save_file = os.path.join(args.output_dir, "phase1_ckpt_{}.pt".format(samples_trained))
if args.do_train:
torch.save({'model': model_to_save.state_dict(),
'optimizer': optimizer.state_dict(),
'master params': list(amp.master_params(optimizer)),
'files': [f_id] + files}, output_save_file)
most_recent_ckpts_paths.append(output_save_file)
if len(most_recent_ckpts_paths) > args.keep_n_most_recent_checkpoints:
ckpt_to_be_removed = most_recent_ckpts_paths.pop(0)
os.remove(ckpt_to_be_removed)
if samples_trained >= args.max_samples_termination or end_training:
status = 'success' if converged else 'aborted'
end_training = True
break
del train_dataloader
if samples_trained >= args.max_samples_termination or end_training:
status = 'success' if converged else 'aborted'
end_training = True
break
train_dataloader, data_file = dataset_future.result(timeout=None)
mlperf_logger.log_end(key=mlperf_logger.constants.BLOCK_STOP,
metadata={'first_epoch_num': epoch},
sync=False)
mlperf_logger.log_end(key=mlperf_logger.constants.EPOCH_STOP,
metadata={'epoch_num': epoch}, sync=False)
epoch += 1
mlperf_logger.log_event(key=mlperf_logger.constants.TRAIN_SAMPLES,
value=samples_trained,
sync=False)
mlperf_logger.log_event(key=mlperf_logger.constants.EVAL_SAMPLES,
value=args.num_eval_examples,
sync=False)
mlperf_logger.log_end(key=mlperf_logger.constants.RUN_STOP,
metadata={'status': status}, sync=False)
return args, final_loss, train_time_raw
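# Global batch size = per-GPU micro-batch size * gradient accumulation steps * number of GPUs.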
def global_batch_size(args):
return args.train_batch_size * args.gradient_accumulation_steps * args.n_gpu
if __name__ == "__main__":
now = time.time()
args, final_loss, train_time_raw = main()
gpu_count = args.n_gpu
if torch.distributed.is_initialized():
gpu_count = torch.distributed.get_world_size()
if utils.is_main_process():
e2e_time = time.time() - now
training_perf = global_batch_size(args) \
* (args.max_steps - args.resume_step + skipped_steps) / train_time_raw
if args.do_train:
print({"e2e_time": e2e_time, "training_sequences_per_second": training_perf,
"final_loss": final_loss, "raw_train_time": train_time_raw })
else:
print({"e2e_time": e2e_time})