Commit 22c0e300 authored by Jared Casper's avatar Jared Casper
Browse files

Merge branch 'staging' into 'master'

April 2020 Release

See merge request ADLR/megatron-lm!69
parents 70174ae3 acf8780b
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -28,7 +28,9 @@ from .wordpiece import BertTokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP
from .tokenization_gpt2 import GPT2Tokenizer
import regex as re
def make_tokenizer(tokenizer_type, corpus, model_path=None, vocab_size=None, model_type='bpe', pad_token=0, character_coverage=1.0, command_tokens=None, type_tokens=None, **kwargs):
def make_tokenizer(tokenizer_type, corpus, model_path=None, vocab_size=None, model_type='bpe',
pad_token=0, character_coverage=1.0, command_tokens=None, type_tokens=None, **kwargs):
"""
Helper function to instantiate a tokenizer given common combinations of options.
"""
......@@ -43,6 +45,7 @@ def make_tokenizer(tokenizer_type, corpus, model_path=None, vocab_size=None, mod
pad_token=pad_token, character_coverage=character_coverage)
return Tokenizer(text_tokenizer, command_tokens, type_tokens)
class Tokenization(object):
"""
Tokenization object to hold tokenization, (processed text), and original
......@@ -54,7 +57,9 @@ class Tokenization(object):
Several standard array operations are implemented (insert, append, extend).
"""
def __init__(self, tokenization, text=None, original_text=None, command_tokens=None, asIds=True):
def __init__(self, tokenization, text=None, original_text=None,
command_tokens=None, asIds=True):
self.tokenization = tokenization
self.text = text
if self.text is None:
......@@ -91,13 +96,15 @@ class Tokenization(object):
if idx == 0:
self.text = other.token + self.text
self.original_text = other.token + self.original_text
elif idx == len(self.tokenization)-1:
elif idx == len(self.tokenization) - 1:
self.text += other.token
self.original_text += other.token
elif isinstance(other, Tokenization):
self.tokenization = self.tokenization[:idx] + other.tokenization + self.tokenization[idx:]
self.tokenization = self.tokenization[:idx] + \
other.tokenization + self.tokenization[idx:]
else:
self.tokenization = self.tokenization[:idx] + other.tokenization + self.tokenization[idx:]
self.tokenization = self.tokenization[:idx] + \
other.tokenization + self.tokenization[idx:]
def append(self, other):
if isinstance(other, (CommandToken, TypeToken)):
......@@ -129,14 +136,17 @@ class Tokenization(object):
self.tokenization.extend(other)
return self
"""define some default command tokens for the tokenizer to use"""
token_format = "<{0}>"
COMMAND_TUPLE = namedtuple('CommandToken', ('name', 'token', 'Id'))
def prep_command_tokens(tokenlist, token_format=token_format):
    """Build a list of CommandToken objects from (name, Id) pairs.

    Each entry's token string is produced by formatting the name with
    `token_format` (default "<{0}>").
    """
    prepared = []
    for entry in tokenlist:
        token_str = token_format.format(entry[0])
        prepared.append(CommandToken(entry[0], token_str, entry[1]))
    return prepared
class CommandToken(object):
def __init__(self, name, token, Id):
self.name = name
......@@ -146,6 +156,7 @@ class CommandToken(object):
def __str__(self):
return str(COMMAND_TUPLE(self.name, self.token, self.Id))
DEFAULT_COMMAND_TOKENS = [
('pad', 0),
('eos', 1),
......@@ -162,9 +173,11 @@ DEFAULT_COMMAND_TOKENS = prep_command_tokens(DEFAULT_COMMAND_TOKENS)
TYPE_TUPLE = namedtuple('TypeToken', ('name', 'token', 'Id'))
def prep_type_tokens(tokenlist, token_format=token_format):
    """Build a list of TypeToken objects from (name, Id) pairs.

    Mirrors `prep_command_tokens`: the token string is the name wrapped
    via `token_format`.
    """
    result = []
    for entry in tokenlist:
        result.append(TypeToken(entry[0], token_format.format(entry[0]), entry[1]))
    return result
class TypeToken(object):
def __init__(self, name, token, Id):
self.name = name
......@@ -174,6 +187,7 @@ class TypeToken(object):
def __str__(self):
return str(TYPE_TUPLE(self.name, self.token, self.Id))
DEFAULT_TYPE_TOKENS = [
('function', 0),
('command', 1),
......@@ -189,6 +203,7 @@ DEFAULT_TYPE_TOKENS = [
]
DEFAULT_TYPE_TOKENS = prep_type_tokens(DEFAULT_TYPE_TOKENS)
class Tokenizer(object):
"""
Tokenizer object that handles text tokenization, command tokens, and type tokens.
......@@ -199,6 +214,7 @@ class Tokenizer(object):
Token types are stored in a separate mapping of size `len(type_tokens)`.
"""
def __init__(self, text_tokenizer, command_tokens=None, type_tokens=None):
# set text tokenizer
self.text_tokenizer = text_tokenizer
......@@ -229,18 +245,20 @@ class Tokenizer(object):
# parse tokens and vocabs from tokenizer
self._tokens = list(self.command_token_map.keys()) + list(self.text_tokenizer.tokens)
self._vocab = {t:Id for Id,t in self.command_id_map.items()}
self._vocab.update({t:Id+self.num_command_tokens for t,Id in self.text_tokenizer.vocab.items()})
self._vocab = {t: Id for Id, t in self.command_id_map.items()}
self._vocab.update({t: Id + self.num_command_tokens for t,
Id in self.text_tokenizer.vocab.items()})
self._text_tokens = list(self.text_tokenizer.tokens)
self._text_token_vocab = {t:Id+self.num_command_tokens for t,Id in self.text_tokenizer.vocab.items()}
self._text_token_vocab = {
t: Id + self.num_command_tokens for t,
Id in self.text_tokenizer.vocab.items()}
self._command_token_tokens = list(self.command_token_map.keys())
self._command_token_vocab = {t:Id for Id,t in self.command_id_map.items()}
self._command_token_vocab = {t: Id for Id, t in self.command_id_map.items()}
self._token_types = list(self.type_token_map.keys())
self._token_type_vocab = {t:Id for Id, t in self.type_id_map.items()}
self._token_type_vocab = {t: Id for Id, t in self.type_id_map.items()}
def __call__(self, text, process_fn=None):
"""run preprocessing and encode text as Ids"""
......@@ -303,7 +321,7 @@ class Tokenizer(object):
encode text using text tokenizer and shift Id values for command tokens
"""
tokenization = self.text_tokenizer.EncodeAsIds(text, process_fn=process_fn)
tokenization.tokenization = [t+self.num_command_tokens for t in tokenization.tokenization]
tokenization.tokenization = [t + self.num_command_tokens for t in tokenization.tokenization]
tokenization.set_command_tokens(self._command_tokens)
return tokenization
......@@ -323,7 +341,7 @@ class Tokenizer(object):
return self.type_id_map[Id].token
if Id < self.num_command_tokens:
return self.command_id_map[Id].token
return self.text_tokenizer.IdToToken(Id-self.num_command_tokens)
return self.text_tokenizer.IdToToken(Id - self.num_command_tokens)
def TokenToId(self, token, type_token=False):
"""convert token to Id accounting for command and type tokens"""
......@@ -333,7 +351,7 @@ class Tokenizer(object):
return self.type_token_map[token].Id
if token in self.command_token_map:
return self.command_token_map[token].Id
return self.text_tokenizer.TokenToId(token)+self.num_command_tokens
return self.text_tokenizer.TokenToId(token) + self.num_command_tokens
def DecodeIds(self, Ids, type_token=False):
"""
......@@ -341,7 +359,8 @@ class Tokenizer(object):
are joined and returned as a string.
"""
if type_token:
return ' '.join(Id.token if isinstance(Id, TypeToken) else self.type_id_map[Id].token for Id in Ids)
return ' '.join(Id.token if isinstance(Id, TypeToken)
else self.type_id_map[Id].token for Id in Ids)
rtn_strs = []
current_str = []
if isinstance(Ids, Tokenization):
......@@ -386,10 +405,12 @@ class Tokenizer(object):
rtn_strs.append(self.text_tokenizer.DecodeTokens(current_str))
return ' '.join(rtn_strs)
class TextTokenizer(object):
"""
Interface for text tokenizer
"""
def __init__(self):
if not hasattr(self, 'num_text_tokens'):
self.num_text_tokens = 0
......@@ -456,11 +477,12 @@ class CharacterLevelTokenizer(TextTokenizer):
"""
Text tokenizer for ASCII-256 Character Level Tokenization.
"""
def __init__(self, **kwargs):
self.num_text_tokens = 256
super(CharacterLevelTokenizer, self).__init__()
self._tokens = [self.IdToToken(Id) for Id in range(self.num_text_tokens)]
self._vocab = {t: i for i,t in enumerate(self._tokens)}
self._vocab = {t: i for i, t in enumerate(self._tokens)}
def __len__(self):
return 256
......@@ -521,6 +543,7 @@ class CharacterLevelTokenizer(TextTokenizer):
MAX_SENTENCEPIECE_SENTENCES = 100000000
def get_corpus_freq(dataset, filepath, filetype='tsv'):
"""
Take corpus, split it into sentences, and extract word frequencies.
......@@ -556,14 +579,13 @@ def get_corpus_freq(dataset, filepath, filetype='tsv'):
print("file path for freq " + str(filepath), flush=True)
freqs_sorted = {}
counter=0
counter = 0
for word, count in sorted(freqs.items(), key=lambda x: x[1], reverse=True):
if counter >= MAX_SENTENCEPIECE_SENTENCES:
break
counter+=1
counter += 1
freqs_sorted[word] = count
print("length of freqs after trancating " + str(len(freqs_sorted)), flush=True)
with open(filepath, 'w') as f:
......@@ -573,9 +595,12 @@ def get_corpus_freq(dataset, filepath, filetype='tsv'):
return total_sentence_count, maxlen
class SentencePieceTokenizer(TextTokenizer):
"""Trains and uses sentencepiece for text tokenization"""
def __init__(self, model_type='bpe', vocab_size=None, corpus=None, model_path=None, character_coverage=1.0, **kwargs):
def __init__(self, model_type='bpe', vocab_size=None, corpus=None,
model_path=None, character_coverage=1.0, **kwargs):
self.character_coverage = character_coverage
self.model_type = model_type.lower()
self.spm_model = model_path
......@@ -608,18 +633,18 @@ class SentencePieceTokenizer(TextTokenizer):
dne = not os.path.exists(model_path)
# check if path.model exists
if dne and not model_path.endswith('.model'):
dne = not os.path.exists(model_path+'.model')
dne = not os.path.exists(model_path + '.model')
return not dne
def load_spm_model(self):
"""load sentencepiece model and parse vocab"""
if not os.path.exists(self.spm_model) and not self.spm_model.endswith('.model'):
self.spm_model = self.spm_model+'.model'
self.spm_model = self.spm_model + '.model'
self.sp = spm.SentencePieceProcessor()
self.sp.Load(self.spm_model)
self.vocab_size = self.num_text_tokens = len(self.sp)
self._tokens = [self.IdToToken(t) for t in range(self.vocab_size)]
self._vocab = {t: i for i,t in enumerate(self._tokens)}
self._vocab = {t: i for i, t in enumerate(self._tokens)}
def Train(self, corpus, num_text_tokens):
"""train sentencepiece model on corpus using word frequencies"""
......@@ -630,7 +655,7 @@ class SentencePieceTokenizer(TextTokenizer):
use_model_path = random_hash
if use_model_path.endswith('.model'):
use_model_path = use_model_path[:use_model_path.rfind('.model')]
input_path = use_model_path+'.tsv.'+random_hash
input_path = use_model_path + '.tsv.' + random_hash
line_count, maxlenline = get_corpus_freq(corpus, input_path)
line_count = min(line_count, MAX_SENTENCEPIECE_SENTENCES)
print('line count used as input_sentence_size ', line_count, flush=True)
......@@ -641,12 +666,12 @@ class SentencePieceTokenizer(TextTokenizer):
+ '--input_format=tsv'
train_string = train_string.format(file_path=input_path, model_prefix=use_model_path, vocab_size=num_text_tokens,
model_type=self.model_type, character_coverage=self.character_coverage,
input_sentence_size=int(line_count)) #, #)#,
print("calling spm.SentencePieceTrainer.Train(%s)"%(train_string), flush=True)
input_sentence_size=int(line_count)) # , #)#,
print("calling spm.SentencePieceTrainer.Train(%s)" % (train_string), flush=True)
spm.SentencePieceTrainer.Train(train_string)
os.remove(input_path)
self.spm_model = use_model_path+'.model'
print('sentencepiece model written to '+self.spm_model, flush=True)
self.spm_model = use_model_path + '.model'
print('sentencepiece model written to ' + self.spm_model, flush=True)
def EncodeAsIds(self, text, process_fn=None):
"""convert text to sentencepiece Ids"""
......@@ -684,19 +709,26 @@ class SentencePieceTokenizer(TextTokenizer):
Tokens = Tokens.tokenization
return self.sp.DecodeTokens(Tokens)
class BertWordPieceTokenizer(Tokenizer):
"""
Loads a pretrained WordPiece tokenizer from `cache_dir` for tokenization
in BERT training. Default to bert-large-uncased tokenizer.
"""
def __init__(self, tokenizer_model_type=None, cache_dir=None, **kwargs):
# default to bert-large-uncased tokenizer
if tokenizer_model_type not in PRETRAINED_VOCAB_ARCHIVE_MAP:
tokenizer_model_type = 'bert-large-uncased'
if torch.distributed.get_rank() == 0:
print('loading BertWordPieceTokenizer (', tokenizer_model_type, ') from cache_dir ', cache_dir)
print(
'loading BertWordPieceTokenizer (',
tokenizer_model_type,
') from cache_dir ',
cache_dir)
do_lower_case = not ('-cased' in tokenizer_model_type or 'chinese' in tokenizer_model_type)
self.text_tokenizer = BertTokenizer.from_pretrained(tokenizer_model_type, do_lower_case=do_lower_case, cache_dir=cache_dir)
self.text_tokenizer = BertTokenizer.from_pretrained(
tokenizer_model_type, do_lower_case=do_lower_case, cache_dir=cache_dir)
if torch.distributed.get_rank() == 0:
print('loaded', tokenizer_model_type)
# disable max len warnings by increasing max len
......@@ -705,7 +737,7 @@ class BertWordPieceTokenizer(Tokenizer):
# set command tokens from wordpiece tokenizer values
self.num_command_tokens = 5
self.num_tokens = len(self.text_tokenizer.vocab)
self.num_text_tokens = self.num_tokens-5
self.num_text_tokens = self.num_tokens - 5
self.num_type_tokens = 2
self._command_tokens = [
......@@ -731,16 +763,16 @@ class BertWordPieceTokenizer(Tokenizer):
# parse tokens and vocabs from tokenizer
self._tokens = list(self.text_tokenizer.vocab.keys())
self._vocab = {k:v for k,v in self.text_tokenizer.vocab.items()}
self._vocab = {k: v for k, v in self.text_tokenizer.vocab.items()}
self._text_tokens = list(self._tokens)
self._text_token_vocab = {k:v for k,v in self.text_tokenizer.vocab.items()}
self._text_token_vocab = {k: v for k, v in self.text_tokenizer.vocab.items()}
self._command_token_tokens = list(self.command_token_map.keys())
self._command_token_vocab = {t:Id for Id,t in self.command_id_map.items()}
self._command_token_vocab = {t: Id for Id, t in self.command_id_map.items()}
self._token_types = list(self.type_token_map.keys())
self._token_type_vocab = {t:Id for Id, t in self.type_id_map.items()}
self._token_type_vocab = {t: Id for Id, t in self.type_id_map.items()}
def EncodeAsIds(self, text, process_fn=None):
"""convert text to wordpiece Ids"""
......@@ -778,7 +810,8 @@ class BertWordPieceTokenizer(Tokenizer):
def DecodeIds(self, Ids, type_token=False):
"""converts ids to wordpiece tokens and joins them as a text string"""
if type_token:
return ' '.join(Id.token if isinstance(Id, TypeToken) else self.type_id_map[Id].token for Id in Ids)
return ' '.join(Id.token if isinstance(Id, TypeToken)
else self.type_id_map[Id].token for Id in Ids)
if isinstance(Ids, Tokenization):
Ids = Ids.tokenization
Tokens = []
......@@ -795,16 +828,17 @@ class BertWordPieceTokenizer(Tokenizer):
Tokens = Tokens.tokenization
return ' '.join(Tokens)
class GPT2BPETokenizer(Tokenizer):
def __init__(self, cache_dir=None, **kwargs):
self.text_tokenizer = GPT2Tokenizer.from_pretrained('gpt2',
cache_dir=cache_dir)
#disable max len warnings by increasing max len
# disable max len warnings by increasing max len
self.text_tokenizer.max_len = int(1e12)
self.num_command_tokens = 2
self.num_tokens = len(self.text_tokenizer.encoder)
self.num_text_tokens = self.num_tokens-1
self.num_text_tokens = self.num_tokens - 1
self.num_type_tokens = 2
self._command_tokens = [
......@@ -824,28 +858,27 @@ class GPT2BPETokenizer(Tokenizer):
self.type_id_map = {tok.Id: tok for tok in self.type_tokens}
self._tokens = list(self.text_tokenizer.encoder.keys())
self._vocab = {k:v for k,v in self.text_tokenizer.encoder.items()}
self._vocab = {k: v for k, v in self.text_tokenizer.encoder.items()}
self._text_tokens = list(self._tokens)
self._text_token_vocab = {k:v for k,v in self.text_tokenizer.encoder.items()}
self._text_token_vocab = {k: v for k, v in self.text_tokenizer.encoder.items()}
self._command_token_tokens = list(self.command_token_map.keys())
self._command_token_vocab = {t:Id for Id,t in self.command_id_map.items()}
self._command_token_vocab = {t: Id for Id, t in self.command_id_map.items()}
self._token_types = list(self.type_token_map.keys())
self._token_type_vocab = {t:Id for Id, t in self.type_id_map.items()}
self._token_type_vocab = {t: Id for Id, t in self.type_id_map.items()}
def EncodeAsIds(self, text, process_fn=None):
processed_text = text
if process_fn is not None:
processed_text = process_fn(processed_text)
Ids = self.text_tokenizer.encode(processed_text)
#return Tokenization(Ids, processed_text, text)
# return Tokenization(Ids, processed_text, text)
tokenization = Tokenization(Ids, processed_text, text)
tokenization.set_command_tokens(self._command_tokens)
return tokenization
def EncodeAsTokens(self, text, process_fn=None):
processed_text = text
if process_fn is not None:
......@@ -854,10 +887,10 @@ class GPT2BPETokenizer(Tokenizer):
for token in re.findall(self.text_tokenizer.pat, processed_text):
token = ''.join(self.text_tokenizer.bye_encoder[b] for b in token.encode('utf-8'))
tokens.extend(bpe_token for bpe_token in self.text_tokenizer.bpe(token).split(' '))
tokenization=Tokenization(tokens, processed_text, text, asIds=False)
tokenization = Tokenization(tokens, processed_text, text, asIds=False)
tokenization.set_command_tokens(self._command_tokens)
return tokenization
#return Tokenization(tokens, processed_text, text, asIds=False)
# return Tokenization(tokens, processed_text, text, asIds=False)
def IdToToken(self, Id, type_token=False):
if isinstance(Id, (TypeToken, CommandToken)):
......@@ -875,7 +908,8 @@ class GPT2BPETokenizer(Tokenizer):
def DecodeIds(self, Ids, type_token=False):
if type_token:
return ' '.join(Id.token if isinstance(Id, TypeToken) else self.type_id_map[Id].token for Id in Ids)
return ' '.join(Id.token if isinstance(Id, TypeToken)
else self.type_id_map[Id].token for Id in Ids)
if isinstance(Ids, Tokenization):
Ids = Ids.tokenization
return self.text_tokenizer.decode(Ids)
......
......@@ -27,7 +27,8 @@ try:
from functools import lru_cache
except ImportError:
# Just a dummy decorator to get the checks to run on python2
# because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now.
# because honestly I don't want to support a byte-level unicode BPE
# tokenizer on python 2 right now.
def lru_cache():
return lambda func: func
......@@ -48,6 +49,7 @@ VOCAB_NAME = 'vocab.json'
MERGES_NAME = 'merges.txt'
SPECIAL_TOKENS_NAME = 'special_tokens.txt'
@lru_cache()
def bytes_to_unicode():
"""
......@@ -60,17 +62,19 @@ def bytes_to_unicode():
And avoids mapping to whitespace/control characters the bpe code barfs on.
"""
_chr = unichr if sys.version_info[0] == 2 else chr
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + \
list(range(ord("®"), ord("ÿ") + 1))
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8+n)
cs.append(2**8 + n)
n += 1
cs = [_chr(n) for n in cs]
return dict(zip(bs, cs))
def get_pairs(word):
"""Return set of symbol pairs in a word.
......@@ -83,6 +87,7 @@ def get_pairs(word):
prev_char = char
return pairs
class GPT2Tokenizer(object):
"""
GPT-2 BPE tokenizer. Peculiarities:
......@@ -138,23 +143,31 @@ class GPT2Tokenizer(object):
special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
else:
special_tokens = kwargs.pop('special_tokens', [])
tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs)
tokenizer = cls(
resolved_vocab_file,
resolved_merges_file,
special_tokens=special_tokens,
*inputs,
**kwargs)
return tokenizer
def __init__(self, vocab_file, merges_file, errors='replace', special_tokens=None, max_len=None):
def __init__(self, vocab_file, merges_file, errors='replace',
special_tokens=None, max_len=None):
self.max_len = max_len if max_len is not None else int(1e12)
self.encoder = json.load(open(vocab_file))
self.decoder = {v:k for k,v in self.encoder.items()}
self.decoder = {v: k for k, v in self.encoder.items()}
self.errors = errors # how to handle errors in decoding
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v:k for k, v in self.byte_encoder.items()}
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
bpe_merges = [tuple(merge.split()) for merge in bpe_data]
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
self.cache = {}
# Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
        # Should have added re.IGNORECASE so BPE merges can happen for
        # capitalized versions of contractions
self.pat = re.compile(
r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
self.special_tokens = {}
self.special_tokens_decoder = {}
......@@ -172,8 +185,9 @@ class GPT2Tokenizer(object):
self.special_tokens = {}
self.special_tokens_decoder = {}
return
self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()}
self.special_tokens = dict((tok, len(self.encoder) + i)
for i, tok in enumerate(special_tokens))
self.special_tokens_decoder = {v: k for k, v in self.special_tokens.items()}
logger.info("Special tokens {}".format(self.special_tokens))
def bpe(self, token):
......@@ -186,7 +200,7 @@ class GPT2Tokenizer(object):
return token
while True:
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
if bigram not in self.bpe_ranks:
break
first, second = bigram
......@@ -197,12 +211,12 @@ class GPT2Tokenizer(object):
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
except:
except BaseException:
new_word.extend(word[i:])
break
if word[i] == first and i < len(word)-1 and word[i+1] == second:
new_word.append(first+second)
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
new_word.append(first + second)
i += 2
else:
new_word.append(word[i])
......@@ -245,7 +259,8 @@ class GPT2Tokenizer(object):
logger.warning(
"Token indices sequence length is longer than the specified maximum "
" sequence length for this OpenAI GPT model ({} > {}). Running this"
" sequence through the model will result in indexing errors".format(len(ids), self.max_len)
" sequence through the model will result in indexing errors".format(
len(ids), self.max_len)
)
return ids
......
......@@ -123,7 +123,8 @@ class BertTokenizer(object):
logger.warning(
"Token indices sequence length is longer than the specified maximum "
" sequence length for this BERT model ({} > {}). Running this"
" sequence through BERT will result in indexing errors".format(len(ids), self.max_len)
" sequence through BERT will result in indexing errors".format(
len(ids), self.max_len)
)
return ids
......
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -22,12 +22,15 @@ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from .loss_scaler import DynamicLossScaler, LossScaler
from .fp16util import model_grads_to_master_grads, master_params_to_model_params, clip_grad_norm
from apex.multi_tensor_apply import multi_tensor_applier
import amp_C
from megatron.module import MegatronModule
FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)
def conversion_helper(val, conversion):
"""Apply conversion to val. Recursively apply conversion if `val` is a nested tuple/list structure."""
if not isinstance(val, (tuple, list)):
......@@ -37,6 +40,7 @@ def conversion_helper(val, conversion):
rtn = tuple(rtn)
return rtn
def fp32_to_fp16(val):
"""Convert fp32 `val` to fp16"""
def half_conversion(val):
......@@ -48,6 +52,7 @@ def fp32_to_fp16(val):
return val
return conversion_helper(val, half_conversion)
def fp16_to_fp32(val):
"""Convert fp16 `val` to fp32"""
def float_conversion(val):
......@@ -59,6 +64,7 @@ def fp16_to_fp32(val):
return val
return conversion_helper(val, float_conversion)
class FP16_Module(MegatronModule):
def __init__(self, module):
super(FP16_Module, self).__init__()
......@@ -79,6 +85,8 @@ class FP16_Module(MegatronModule):
self.module.load_state_dict(state_dict, strict=strict)
# TODO: Update overflow check + downscale to use Carl's fused kernel.
class FP16_Optimizer(object):
"""
:class:`FP16_Optimizer` is designed to wrap an existing PyTorch optimizer,
......@@ -305,7 +313,8 @@ class FP16_Optimizer(object):
master_params_to_model_params(fp32_from_fp16_group, fp16_group)
# To consider: Integrate distributed with this wrapper by registering a hook on each variable
# that does the overflow check, gradient copy + downscale, and fp32 allreduce in a different stream.
# that does the overflow check, gradient copy + downscale, and fp32
# allreduce in a different stream.
def _model_grads_to_master_grads(self):
for fp16_group, fp32_from_fp16_group in zip(self.fp16_groups, self.fp32_from_fp16_groups):
model_grads_to_master_grads(fp16_group, fp32_from_fp16_group)
......@@ -313,9 +322,12 @@ class FP16_Optimizer(object):
def _downscale_master(self):
if self.loss_scale != 1.0:
for group in self.optimizer.param_groups:
for param in group['params']:
if param.grad is not None:
param.grad.data.mul_(1./self.loss_scale)
grads = [p.grad for p in group['params'] if p.grad is not None]
_overflow_buf = torch.cuda.IntTensor([0])
multi_tensor_applier(amp_C.multi_tensor_scale,
_overflow_buf,
[grads, grads],
1./self.loss_scale)
def clip_master_grads(self, max_norm, norm_type=2):
"""
......@@ -400,7 +412,8 @@ class FP16_Optimizer(object):
# the current optimizer instance. In our case, as long as the current FP16_Optimizer has been
# constructed in the same way as the one whose state_dict we are loading, the same master params
# are guaranteed to exist, so we can just copy_() from the saved master params.
for current_group, saved_group in zip(self.fp32_from_fp16_groups, state_dict['fp32_from_fp16']):
for current_group, saved_group in zip(
self.fp32_from_fp16_groups, state_dict['fp32_from_fp16']):
for current, saved in zip(current_group, saved_group):
current.data.copy_(saved.data)
......@@ -570,7 +583,8 @@ class FP16_Optimizer(object):
"""
if self.dynamic_loss_scale:
self._check_overflow()
if self.overflow: return
if self.overflow:
return
self._model_grads_to_master_grads()
self._downscale_master()
......@@ -607,8 +621,8 @@ class FP16_Optimizer(object):
master_grads_data.append(master_grads_this_group)
return master_grads_data
# Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale"
def _get_loss_scale(self):
return self.loss_scaler.loss_scale
......
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -18,6 +18,9 @@ import torch.nn as nn
from torch.autograd import Variable
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from apex.multi_tensor_apply import multi_tensor_applier
import amp_C
from megatron import mpu
......@@ -102,6 +105,7 @@ class FP16Model(nn.Module):
def backwards_debug_hook(grad):
    """Gradient hook that should never fire.

    Master params are updated manually from model grads, so receiving a
    gradient here indicates a wiring bug; fail loudly. (Fixes the
    'recieved' typo in the error message.)
    """
    raise RuntimeError("master_params received a gradient in the backward pass!")
def prep_param_lists(model, flat_master=False):
"""
Creates a list of FP32 master parameters for a given model, as in
......@@ -131,7 +135,7 @@ def prep_param_lists(model, flat_master=False):
# flatten_dense_tensors returns a contiguous flat array.
# http://pytorch.org/docs/master/_modules/torch/_utils.html
master_params = _flatten_dense_tensors([param.data for param in model_params]).float()
except:
except BaseException:
print("Error in prep_param_lists: model may contain a mixture of parameters "
"of different types. Use flat_master=False, or use F16_Optimizer.")
raise
......@@ -165,9 +169,15 @@ def model_grads_to_master_grads(model_params, master_params, flat_master=False):
if model.grad is not None:
if master.grad is None:
master.grad = Variable(master.data.new(*master.data.size()))
master.grad.data.copy_(model.grad.data)
else:
master.grad = None
model_grads = [p.grad for p in model_params if p.grad is not None]
master_grads = [p.grad for p in master_params if p.grad is not None]
_overflow_buf = torch.cuda.IntTensor([0])
multi_tensor_applier(amp_C.multi_tensor_scale,
_overflow_buf,
[model_grads, master_grads],
1.0)
def master_params_to_model_params(model_params, master_params, flat_master=False):
......@@ -188,17 +198,19 @@ def master_params_to_model_params(model_params, master_params, flat_master=False
# Backward compatibility fixes
def to_python_float(t):
    """Convert a scalar tensor (or 1-element indexable) to a Python number.

    Backward-compatibility helper: older tensor types lack ``.item()``,
    in which case the single element is indexed out instead.
    """
    if not hasattr(t, 'item'):
        # Old-style tensors: grab the lone element by index.
        return t[0]
    return t.item()
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
clip_grad_norm = mpu.clip_grad_norm
#elif TORCH_MAJOR == 0 and TORCH_MINOR <= 4:
# elif TORCH_MAJOR == 0 and TORCH_MINOR <= 4:
# clip_grad_norm = torch.nn.utils.clip_grad_norm
#else:
# else:
# clip_grad_norm = torch.nn.utils.clip_grad_norm_
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -14,15 +14,22 @@
# limitations under the License.
import torch
from apex.multi_tensor_apply import multi_tensor_applier
import amp_C
from megatron import mpu
# item() is a recent addition, so this helps with backward compatibility.
def to_python_float(t):
    """Return the Python scalar held by ``t``.

    ``item()`` is a relatively recent addition, so fall back to indexing
    the first element for tensor types that do not provide it.
    """
    try:
        return t.item()
    except AttributeError:
        return t[0]
class LossScaler:
"""
Class that manages a static loss scale. This class is intended to interact with
......@@ -54,12 +61,18 @@ class LossScaler:
return self.cur_scale
def scale_gradient(self, module, grad_in, grad_out):
return tuple(self.loss_scale * g for g in grad_in)
_overflow_buf = torch.cuda.IntTensor([0])
multi_tensor_applier(amp_C.multi_tensor_scale,
_overflow_buf,
[grad_in, grad_in],
self.loss_scale)
return grad_in
def backward(self, loss, retain_graph=False):
scaled_loss = loss*self.loss_scale
scaled_loss = loss * self.loss_scale
scaled_loss.backward(retain_graph=retain_graph)
class DynamicLossScaler:
"""
Class that manages dynamic loss scaling. It is recommended to use :class:`DynamicLossScaler`
......@@ -122,8 +135,8 @@ class DynamicLossScaler:
overflow = overflow_gpu[0].item()
return bool(overflow)
# `x` is a torch.Tensor
def _has_inf_or_nan(x):
try:
# if x is half, the .float() incurs an additional deep copy, but it's necessary if
......@@ -158,7 +171,7 @@ class DynamicLossScaler:
if overflow:
# self.cur_scale /= self.scale_factor
if self.delayed_shift == 1 or self.cur_hysteresis == 1:
self.cur_scale = max(self.cur_scale/self.scale_factor, self.min_scale)
self.cur_scale = max(self.cur_scale / self.scale_factor, self.min_scale)
else:
self.cur_hysteresis -= 1
self.last_overflow_iter = self.cur_iter
......@@ -176,12 +189,18 @@ class DynamicLossScaler:
return self.cur_scale
def scale_gradient(self, module, grad_in, grad_out):
return tuple(self.loss_scale * g for g in grad_in)
_overflow_buf = torch.cuda.IntTensor([0])
multi_tensor_applier(amp_C.multi_tensor_scale,
_overflow_buf,
[grad_in, grad_in],
self.loss_scale)
return grad_in
def backward(self, loss, retain_graph=False):
scaled_loss = loss*self.loss_scale
scaled_loss = loss * self.loss_scale
scaled_loss.backward(retain_graph=retain_graph)
##############################################################
# Example usage below here -- assuming it's in a separate file
##############################################################
......
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Megatron global variables."""
import os
import sys
import time
import torch
from megatron.tokenizer import build_tokenizer
from .arguments import parse_args
_GLOBAL_ARGS = None
_GLOBAL_TOKENIZER = None
_GLOBAL_TENSORBOARD_WRITER = None
_GLOBAL_ADLR_AUTORESUME = None
_GLOBAL_TIMERS = None
def get_args():
    """Return the globally cached argument namespace.

    Asserts that ``set_global_variables`` has already parsed the args.
    """
    args = _GLOBAL_ARGS
    _ensure_var_is_initialized(args, 'args')
    return args
def get_tokenizer():
    """Return the globally cached tokenizer.

    Asserts that the tokenizer has been built already.
    """
    tokenizer = _GLOBAL_TOKENIZER
    _ensure_var_is_initialized(tokenizer, 'tokenizer')
    return tokenizer
def get_tensorboard_writer():
    """Return the tensorboard writer.

    May legitimately be None (tensorboard disabled, or a non-zero rank),
    so callers do not need an initialization check.
    """
    return _GLOBAL_TENSORBOARD_WRITER
def get_adlr_autoresume():
    """Return the ADLR autoresume object.

    May legitimately be None (autoresume disabled), so callers do not
    need an initialization check.
    """
    return _GLOBAL_ADLR_AUTORESUME
def get_timers():
    """Return the globally cached ``Timers`` registry.

    Asserts that the timers have been initialized already.
    """
    timers = _GLOBAL_TIMERS
    _ensure_var_is_initialized(timers, 'timers')
    return timers
def set_global_variables(extra_args_provider=None, args_defaults=None):
    """Set args, tokenizer, tensorboard-writer, adlr-autoresume, and timers.

    Args:
        extra_args_provider: optional callable that registers extra
            command-line arguments before parsing.
        args_defaults: optional dict of default values overriding the
            parser defaults; treated as an empty dict when None.
    """
    # Use None as the default instead of a mutable {} default argument
    # (a shared mutable default is a classic Python pitfall).
    if args_defaults is None:
        args_defaults = {}
    args = _parse_args(extra_args_provider=extra_args_provider,
                       defaults=args_defaults)
    _ = _build_tokenizer(args)
    _set_tensorboard_writer(args)
    _set_adlr_autoresume(args)
    _set_timers()
def _parse_args(extra_args_provider=None, defaults=None):
    """Parse entire arguments and cache them in ``_GLOBAL_ARGS``.

    Args:
        extra_args_provider: optional callable that registers extra
            command-line arguments before parsing.
        defaults: optional dict of default overrides; treated as an
            empty dict when None.

    Returns:
        The parsed (and now globally cached) argument namespace.
    """
    global _GLOBAL_ARGS
    _ensure_var_is_not_initialized(_GLOBAL_ARGS, 'args')
    # Use None instead of a mutable {} default argument.
    if defaults is None:
        defaults = {}
    _GLOBAL_ARGS = parse_args(extra_args_provider=extra_args_provider,
                              defaults=defaults)
    return _GLOBAL_ARGS
def _build_tokenizer(args):
    """Build the tokenizer from `args` and cache it globally.

    Asserts that no tokenizer has been built yet.
    """
    global _GLOBAL_TOKENIZER
    _ensure_var_is_not_initialized(_GLOBAL_TOKENIZER, 'tokenizer')
    tokenizer = build_tokenizer(args)
    _GLOBAL_TOKENIZER = tokenizer
    return tokenizer
def rebuild_tokenizer(args):
    """Discard any cached tokenizer and build a fresh one from `args`."""
    global _GLOBAL_TOKENIZER
    # Clearing the cache lets _build_tokenizer's "not initialized" check pass.
    _GLOBAL_TOKENIZER = None
    return _build_tokenizer(args)
def _set_tensorboard_writer(args):
    """Create the tensorboard SummaryWriter on rank 0 when requested."""
    global _GLOBAL_TENSORBOARD_WRITER
    _ensure_var_is_not_initialized(_GLOBAL_TENSORBOARD_WRITER,
                                   'tensorboard writer')
    # Only rank 0 writes, and only when a log directory was configured.
    tensorboard_dir = getattr(args, 'tensorboard_dir', None)
    if not tensorboard_dir or args.rank != 0:
        return
    try:
        from torch.utils.tensorboard import SummaryWriter
    except ModuleNotFoundError:
        # SummaryWriter shipped with PyTorch starting in 1.1.0.
        print('WARNING: TensorBoard writing requested but is not '
              'available (are you using PyTorch 1.1.0 or later?), '
              'no TensorBoard logs will be written.', flush=True)
        return
    print('> setting tensorboard ...')
    _GLOBAL_TENSORBOARD_WRITER = SummaryWriter(
        log_dir=args.tensorboard_dir)
def _set_adlr_autoresume(args):
    """Initialize ADLR autoresume.

    When ``args.adlr_autoresume`` is set, imports the cluster-provided
    ``userlib.auto_resume`` module and caches its ``AutoResume`` class
    (the class itself, not an instance) in ``_GLOBAL_ADLR_AUTORESUME``.
    Exits the process if the library is unavailable.
    """
    global _GLOBAL_ADLR_AUTORESUME
    _ensure_var_is_not_initialized(_GLOBAL_ADLR_AUTORESUME, 'adlr autoresume')

    if args.adlr_autoresume:
        if args.rank == 0:
            print('enabling autoresume ...', flush=True)
        # The userlib package lives under $SUBMIT_SCRIPTS on the cluster;
        # fall back to the current directory when the env var is unset.
        sys.path.append(os.environ.get('SUBMIT_SCRIPTS', '.'))
        try:
            from userlib.auto_resume import AutoResume
        # NOTE(review): BaseException is very broad — presumably intended to
        # catch any import-time failure mode; confirm before narrowing.
        except BaseException:
            print('ADLR autoresume is not available, exiting ...')
            sys.exit()
        _GLOBAL_ADLR_AUTORESUME = AutoResume
def _set_timers():
    """Initialize the global ``Timers`` registry (exactly once)."""
    global _GLOBAL_TIMERS
    _ensure_var_is_not_initialized(_GLOBAL_TIMERS, 'timers')
    _GLOBAL_TIMERS = Timers()
def _ensure_var_is_initialized(var, name):
"""Make sure the input variable is not None."""
assert var is not None, '{} is not initialized.'.format(name)
def _ensure_var_is_not_initialized(var, name):
"""Make sure the input variable is not None."""
assert var is None, '{} is already initialized.'.format(name)
class _Timer:
"""Timer."""
def __init__(self, name):
self.name_ = name
self.elapsed_ = 0.0
self.started_ = False
self.start_time = time.time()
def start(self):
"""Start the timer."""
assert not self.started_, 'timer has already been started'
torch.cuda.synchronize()
self.start_time = time.time()
self.started_ = True
def stop(self):
"""Stop the timer."""
assert self.started_, 'timer is not started'
torch.cuda.synchronize()
self.elapsed_ += (time.time() - self.start_time)
self.started_ = False
def reset(self):
"""Reset timer."""
self.elapsed_ = 0.0
self.started_ = False
def elapsed(self, reset=True):
"""Calculate the elapsed time."""
started_ = self.started_
# If the timing in progress, end it first.
if self.started_:
self.stop()
# Get the elapsed time.
elapsed_ = self.elapsed_
# Reset the elapsed time
if reset:
self.reset()
# If timing was in progress, set it back.
if started_:
self.start()
return elapsed_
class Timers:
    """Registry of named ``_Timer`` objects, created lazily on first use."""

    def __init__(self):
        self.timers = {}

    def __call__(self, name):
        """Return the timer registered under `name`, creating it if absent."""
        timer = self.timers.get(name)
        if timer is None:
            timer = _Timer(name)
            self.timers[name] = timer
        return timer

    def write(self, names, writer, iteration, normalizer=1.0, reset=False):
        """Write timers to a tensorboard writer"""
        # currently when using add_scalars,
        # torch.utils.add_scalars makes each timer its own run, which
        # polutes the runs list, so we just add each as a scalar
        assert normalizer > 0.0
        for name in names:
            scaled = self.timers[name].elapsed(reset=reset) / normalizer
            writer.add_scalar(name + '_time', scaled, iteration)

    def log(self, names, normalizer=1.0, reset=True):
        """Log a group of timers."""
        assert normalizer > 0.0
        string = 'time (ms)'
        for name in names:
            millis = self.timers[name].elapsed(reset=reset) * 1000.0 / normalizer
            string += ' | {}: {:.2f}'.format(name, millis)
        # Only rank 0 prints under torch.distributed; otherwise always print.
        should_print = True
        if torch.distributed.is_initialized():
            should_print = torch.distributed.get_rank() == 0
        if should_print:
            print(string, flush=True)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Megatron initialization."""
import random
import os
import numpy as np
import torch
from megatron import get_adlr_autoresume
from megatron import get_args
from megatron import get_tensorboard_writer
from megatron import mpu
from megatron.global_vars import set_global_variables
def initialize_megatron(extra_args_provider=None, args_defaults=None):
    """Set global variables, initialize distributed, and
    set autoresume and random seeds.

    Args:
        extra_args_provider: optional callable that registers extra
            command-line arguments before parsing.
        args_defaults: optional dict of argument default overrides;
            treated as an empty dict when None.
    """
    # Use None instead of a mutable {} default argument (shared mutable
    # defaults are a classic Python pitfall).
    if args_defaults is None:
        args_defaults = {}

    # Make sure cuda is available.
    assert torch.cuda.is_available(), 'Megatron requires CUDA.'

    # Parse args, build tokenizer, and set adlr-autoresume,
    # tensorboard-writer, and timers.
    set_global_variables(extra_args_provider=extra_args_provider,
                         args_defaults=args_defaults)

    # Pytorch distributed.
    _initialize_distributed()

    # Autoresume.
    _init_autoresume()

    # Random seeds for reproducibility.
    args = get_args()
    if args.rank == 0:
        print('> setting random seeds to {} ...'.format(args.seed))
    _set_random_seed(args.seed)

    # Write arguments to tensorboard.
    _write_args_to_tensorboard()
def _initialize_distributed():
    """Initialize torch.distributed and mpu.

    Two cases: torch.distributed was already initialized by an external
    launcher (reuse its rank/world-size), or it was not (pick the device
    from the rank and call init_process_group via TCP rendezvous).
    Finally sets up the model/data-parallel groups through mpu.
    """
    args = get_args()

    device_count = torch.cuda.device_count()
    if torch.distributed.is_initialized():
        if args.rank == 0:
            print('torch distributed is already initialized, '
                  'skipping initialization ...', flush=True)
        # Trust the existing process group for rank and world size.
        args.rank = torch.distributed.get_rank()
        args.world_size = torch.distributed.get_world_size()
        if device_count > 0:
            # Sanity-check that the current device matches rank % device-count.
            device = torch.cuda.current_device()
            local_rank = args.rank % device_count
            assert local_rank == device, \
                'expected local-rank to be the same as rank % device-count.'
    else:
        if args.rank == 0:
            print('> initializing torch distributed ...', flush=True)
        # Manually set the device ids.
        if device_count > 0:
            device = args.rank % device_count
            if args.local_rank is not None:
                assert args.local_rank == device, \
                    'expected local-rank to be the same as rank % device-count.'
            else:
                args.local_rank = device
            torch.cuda.set_device(device)
        # Call the init process
        # Rendezvous via MASTER_ADDR/MASTER_PORT, defaulting to localhost:6000.
        init_method = 'tcp://'
        master_ip = os.getenv('MASTER_ADDR', 'localhost')
        master_port = os.getenv('MASTER_PORT', '6000')
        init_method += master_ip + ':' + master_port
        torch.distributed.init_process_group(
            backend=args.distributed_backend,
            world_size=args.world_size, rank=args.rank,
            init_method=init_method)
    # Set the model-parallel / data-parallel communicators.
    if device_count > 0:
        mpu.initialize_model_parallel(args.model_parallel_size)
def _init_autoresume():
    """Set autoresume start time.

    No-op when autoresume is disabled (get_adlr_autoresume() is None).
    """
    autoresume = get_adlr_autoresume()
    if autoresume:
        # Barriers keep all ranks in lock-step around init().
        torch.distributed.barrier()
        autoresume.init()
        torch.distributed.barrier()
def _set_random_seed(seed):
"""Set random seed for reproducability."""
if seed is not None and seed > 0:
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.device_count() > 0:
mpu.model_parallel_cuda_manual_seed(seed)
else:
raise ValueError('Seed ({}) should be a positive integer.'.format(seed))
def _write_args_to_tensorboard():
    """Record every parsed argument as a text entry in tensorboard."""
    args = get_args()
    writer = get_tensorboard_writer()
    # The writer is None on non-zero ranks or when tensorboard is disabled.
    if not writer:
        return
    for arg_name in vars(args):
        writer.add_text(arg_name, str(getattr(args, arg_name)))
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,59 +12,66 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch DataLoader for TFRecords"""
import torch
from torch.optim.lr_scheduler import _LRScheduler
import math
"""Learning rate decay functions."""
from megatron.utils import print_rank_0
import math
from megatron import print_rank_0
class AnnealingLR(_LRScheduler):
"""Anneals the learning rate"""
DECAY_STYLES = ['linear', 'cosine', 'exponential', 'constant', 'None']
class AnnealingLR(object):
"""Anneals the learning rate."""
def __init__(self, optimizer, start_lr, warmup_iter, num_iters,
decay_style=None, last_iter=-1, min_lr=0.0,
def __init__(self, optimizer, start_lr,
warmup_iter, total_iters,
decay_style, last_iter, min_lr=0.0,
use_checkpoint_lr_scheduler=True,
override_lr_scheduler=False):
# Class values.
self.optimizer = optimizer
self.start_lr = start_lr
self.min_lr = min_lr
self.warmup_iter = warmup_iter
self.num_iters = last_iter + 1
self.end_iter = num_iters
self.decay_style = decay_style.lower() if isinstance(decay_style, str) \
else None
self.num_iters = last_iter
self.end_iter = total_iters
assert self.end_iter > 0
self.decay_style = decay_style
self.override_lr_scheduler = override_lr_scheduler
self.use_checkpoint_lr_scheduler = use_checkpoint_lr_scheduler
if self.override_lr_scheduler:
assert not self.use_checkpoint_lr_scheduler, 'both override and '\
'use-checkpoint are set.'
# Set the learning rate
self.step(self.num_iters)
if torch.distributed.get_rank() == 0:
print('learning rate decaying', decay_style)
print_rank_0('> learning rate decay style: {}'.format(self.decay_style))
def get_lr(self):
# https://openreview.net/pdf?id=BJYwwY9ll pg. 4
"""Learning rate decay functions from:
https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""
num_iters_ = min(self.num_iters, self.end_iter - self.warmup_iter)
# Warmup.
if self.warmup_iter > 0 and self.num_iters <= self.warmup_iter:
return float(self.start_lr) * num_iters_ / self.warmup_iter
else:
if self.decay_style == self.DECAY_STYLES[0]:
lr = self.start_lr * ((self.end_iter - (num_iters_ - self.warmup_iter)) / self.end_iter)
elif self.decay_style == self.DECAY_STYLES[1]:
lr = self.start_lr / 2.0 * (math.cos(math.pi * (num_iters_ - self.warmup_iter) / self.end_iter) + 1)
elif self.decay_style == self.DECAY_STYLES[2]:
num_iters_ = num_iters_ - self.warmup_iter
if self.decay_style == 'linear':
lr = self.start_lr * (self.end_iter - num_iters_) / self.end_iter
elif self.decay_style == 'cosine':
lr = self.start_lr / 2.0 * (math.cos(
math.pi * num_iters_ / self.end_iter) + 1)
elif self.decay_style == 'exponential':
# exp(-0.693) = 1/2
lr = self.start_lr * math.exp(-0.693 * (num_iters_ - self.warmup_iter) / self.end_iter)
lr = self.start_lr * math.exp(-0.693 * num_iters_ / self.end_iter)
else:
lr = self.start_lr
return max(lr, self.min_lr)
def step(self, step_num=None):
"""Set lr for all parameters groups."""
if step_num is None:
step_num = self.num_iters + 1
self.num_iters = step_num
......@@ -73,7 +80,7 @@ class AnnealingLR(_LRScheduler):
group['lr'] = new_lr
def state_dict(self):
sd = {
state_dict = {
'start_lr': self.start_lr,
'warmup_iter': self.warmup_iter,
'num_iters': self.num_iters,
......@@ -81,14 +88,15 @@ class AnnealingLR(_LRScheduler):
'end_iter': self.end_iter,
'min_lr': self.min_lr
}
return sd
return state_dict
def check_and_set_(self, cls_value, sd_value, name):
def _check_and_set(self, cls_value, sd_value, name):
"""Auxiliary function for checking the values in the checkpoint and
setting them."""
if self.override_lr_scheduler:
print_rank_0(' > overriding {} value to {}'.format(name, cls_value))
return cls_value
else:
if not self.use_checkpoint_lr_scheduler:
assert cls_value == sd_value, 'AnnealingLR: class input value' \
'and checkpoint values for {} do not match'.format(name)
......@@ -98,16 +106,16 @@ class AnnealingLR(_LRScheduler):
def load_state_dict(self, sd):
self.start_lr = self.check_and_set_(self.start_lr, sd['start_lr'],
self.start_lr = self._check_and_set(self.start_lr, sd['start_lr'],
'learning rate')
self.min_lr = self.check_and_set_(self.min_lr, sd['min_lr'],
self.min_lr = self._check_and_set(self.min_lr, sd['min_lr'],
'minimum learning rate')
self.warmup_iter = self.check_and_set_(self.warmup_iter,
self.warmup_iter = self._check_and_set(self.warmup_iter,
sd['warmup_iter'],
'warmup iterations')
self.end_iter = self.check_and_set_(self.end_iter, sd['end_iter'],
self.end_iter = self._check_and_set(self.end_iter, sd['end_iter'],
'total number of iterations')
self.decay_style = self.check_and_set_(self.decay_style,
self.decay_style = self._check_and_set(self.decay_style,
sd['decay_style'],
'decay style')
......
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -17,16 +17,16 @@
import torch
from megatron import get_args
from megatron.model.language_model import parallel_lm_logits
from megatron.model.language_model import get_language_model
from megatron.model.transformer import LayerNorm
from megatron.model.utils import openai_gelu
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal
from megatron.module import MegatronModule
from .language_model import parallel_lm_logits
from .language_model import get_language_model
from .transformer import LayerNorm
from .utils import gelu
from .utils import get_linear_layer
from .utils import init_method_normal
from .utils import scaled_init_method_normal
def bert_attention_mask_func(attention_scores, attention_mask):
attention_scores = attention_scores + attention_mask
......@@ -65,7 +65,6 @@ def bert_position_ids(token_ids):
return position_ids
class BertLMHead(MegatronModule):
"""Masked LM head for Bert
......@@ -76,11 +75,14 @@ class BertLMHead(MegatronModule):
layernorm_epsilon: tolerance for layer norm divisions
parallel_output: wether output logits being distributed or not.
"""
def __init__(self, mpu_vocab_size, hidden_size, init_method,
layernorm_epsilon, parallel_output):
super(BertLMHead, self).__init__()
args = get_args()
self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
self.bias.model_parallel = True
self.bias.partition_dim = 0
......@@ -89,11 +91,13 @@ class BertLMHead(MegatronModule):
self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
self.layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
self.gelu = torch.nn.functional.gelu
if args.openai_gelu:
self.gelu = openai_gelu
def forward(self, hidden_states, word_embeddings_weight):
hidden_states = self.dense(hidden_states)
hidden_states = gelu(hidden_states)
hidden_states = self.gelu(hidden_states)
hidden_states = self.layernorm(hidden_states)
output = parallel_lm_logits(hidden_states,
word_embeddings_weight,
......@@ -102,69 +106,39 @@ class BertLMHead(MegatronModule):
return output
class BertModel(MegatronModule):
"""Bert Language model."""
def __init__(self,
num_layers,
vocab_size,
hidden_size,
num_attention_heads,
embedding_dropout_prob,
attention_dropout_prob,
output_dropout_prob,
max_sequence_length,
checkpoint_activations,
checkpoint_num_layers=1,
add_binary_head=False,
layernorm_epsilon=1.0e-5,
init_method_std=0.02,
num_tokentypes=0,
parallel_output=True,
apply_query_key_layer_scaling=False,
attention_softmax_in_fp32=False):
def __init__(self, num_tokentypes=2, add_binary_head=True,
parallel_output=True):
super(BertModel, self).__init__()
args = get_args()
self.add_binary_head = add_binary_head
self.parallel_output = parallel_output
init_method = init_method_normal(init_method_std)
init_method = init_method_normal(args.init_method_std)
scaled_init_method = scaled_init_method_normal(args.init_method_std,
args.num_layers)
self.language_model, self._language_model_key = get_language_model(
num_layers=num_layers,
vocab_size=vocab_size,
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
embedding_dropout_prob=embedding_dropout_prob,
attention_dropout_prob=attention_dropout_prob,
output_dropout_prob=output_dropout_prob,
max_sequence_length=max_sequence_length,
attention_mask_func=bert_attention_mask_func,
num_tokentypes=num_tokentypes,
add_pooler=self.add_binary_head,
attention_mask_func=bert_attention_mask_func,
checkpoint_activations=checkpoint_activations,
checkpoint_num_layers=checkpoint_num_layers,
layernorm_epsilon=layernorm_epsilon,
init_method=init_method,
scaled_init_method=scaled_init_method_normal(init_method_std,
num_layers),
residual_connection_post_layernorm=False,
apply_query_key_layer_scaling=apply_query_key_layer_scaling,
attention_softmax_in_fp32=attention_softmax_in_fp32)
scaled_init_method=scaled_init_method)
self.lm_head = BertLMHead(
self.language_model.embedding.word_embeddings.weight.size(0),
hidden_size, init_method, layernorm_epsilon, parallel_output)
args.hidden_size, init_method, args.layernorm_epsilon,
parallel_output)
self._lm_head_key = 'lm_head'
if self.add_binary_head:
self.binary_head = get_linear_layer(hidden_size, 2, init_method)
self.binary_head = get_linear_layer(args.hidden_size, 2,
init_method)
self._binary_head_key = 'binary_head'
def forward(self, input_ids, attention_mask,
tokentype_ids=None):
def forward(self, input_ids, attention_mask, tokentype_ids=None):
extended_attention_mask = bert_extended_attention_mask(
attention_mask, next(self.language_model.parameters()).dtype)
......@@ -193,7 +167,6 @@ class BertModel(MegatronModule):
return lm_logits, None
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
"""For easy load when model is combined with other heads,
......@@ -211,7 +184,6 @@ class BertModel(MegatronModule):
= self.binary_head.state_dict(destination, prefix, keep_vars)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Customized load."""
......
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Classification model."""
import torch
from megatron import get_args
from megatron.model.bert_model import bert_attention_mask_func
from megatron.model.bert_model import bert_extended_attention_mask
from megatron.model.bert_model import bert_position_ids
from megatron.model.language_model import get_language_model
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal
from megatron.module import MegatronModule
from megatron import print_rank_0
class Classification(MegatronModule):
    """Sequence classification head on top of the shared language model.

    Runs the transformer language model with pooling enabled and feeds
    the pooled output through dropout and a linear classification head.
    """

    def __init__(self, num_classes, num_tokentypes=2):
        """
        Args:
            num_classes: number of output classes for the linear head.
            num_tokentypes: size of the token-type (segment) embedding
                table handed to the language model.
        """
        super(Classification, self).__init__()
        args = get_args()

        self.num_classes = num_classes
        init_method = init_method_normal(args.init_method_std)

        # Shared backbone; add_pooler=True so a pooled representation is
        # returned alongside the hidden states.
        self.language_model, self._language_model_key = get_language_model(
            attention_mask_func=bert_attention_mask_func,
            num_tokentypes=num_tokentypes,
            add_pooler=True,
            init_method=init_method,
            scaled_init_method=scaled_init_method_normal(args.init_method_std,
                                                         args.num_layers))

        # Multi-choice head.
        self.classification_dropout = torch.nn.Dropout(args.hidden_dropout)
        self.classification_head = get_linear_layer(args.hidden_size,
                                                    self.num_classes,
                                                    init_method)
        self._classification_head_key = 'classification_head'

    def forward(self, input_ids, attention_mask, tokentype_ids):
        """Return classification logits of shape [-1, num_classes].

        NOTE(review): input_ids/attention_mask are presumably
        [batch, seq] token-id / mask tensors — confirm against callers.
        """
        extended_attention_mask = bert_extended_attention_mask(
            attention_mask, next(self.language_model.parameters()).dtype)
        position_ids = bert_position_ids(input_ids)

        _, pooled_output = self.language_model(input_ids,
                                               position_ids,
                                               extended_attention_mask,
                                               tokentype_ids=tokentype_ids)

        # Output.
        classification_output = self.classification_dropout(pooled_output)
        classification_logits = self.classification_head(classification_output)

        # Reshape back to separate choices.
        classification_logits = classification_logits.view(-1, self.num_classes)

        return classification_logits

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """For easy load when model is combined with other heads,
        add an extra key."""

        state_dict_ = {}
        # Each sub-module is stored under its own key so heads can be
        # loaded/skipped independently.
        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)
        state_dict_[self._classification_head_key] \
            = self.classification_head.state_dict(
                destination, prefix, keep_vars)
        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load: the classification head is optional in the
        checkpoint and falls back to its random initialization."""

        self.language_model.load_state_dict(
            state_dict[self._language_model_key], strict=strict)
        if self._classification_head_key in state_dict:
            self.classification_head.load_state_dict(
                state_dict[self._classification_head_key], strict=strict)
        else:
            print_rank_0('***WARNING*** could not find {} in the checkpoint, '
                         'initializing to random'.format(
                             self._classification_head_key))
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -71,8 +71,8 @@ class DistributedDataParallel(MegatronModule):
def allreduce_hook(*unused):
Variable._execution_engine.queue_callback(allreduce_params)
# handle = param.register_hook(allreduce_hook)
#self.hooks.append(allreduce_hook)
#self.hook_handles.append(handle)
# self.hooks.append(allreduce_hook)
# self.hook_handles.append(handle)
self.allreduce_params = allreduce_params
def forward(self, *inputs, **kwargs):
......@@ -114,4 +114,3 @@ class DistributedDataParallel(MegatronModule):
super(DistributedDataParallel, self).train(mode)
self.module.train(mode)
'''
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -17,6 +17,7 @@
import torch
from megatron import get_args
from megatron.module import MegatronModule
from .language_model import parallel_lm_logits
......@@ -26,61 +27,30 @@ from .utils import scaled_init_method_normal
def gpt2_attention_mask_func(attention_scores, ltor_mask):
attention_scores = torch.mul(attention_scores, ltor_mask) - \
10000.0 * (1.0 - ltor_mask)
attention_scores.masked_fill_(ltor_mask, -10000.0)
return attention_scores
class GPT2Model(MegatronModule):
"""GPT-2 Language model."""
def __init__(self,
num_layers,
vocab_size,
hidden_size,
num_attention_heads,
embedding_dropout_prob,
attention_dropout_prob,
output_dropout_prob,
max_sequence_length,
checkpoint_activations,
checkpoint_num_layers=1,
layernorm_epsilon=1.0e-5,
init_method_std=0.02,
num_tokentypes=0,
parallel_output=True,
apply_query_key_layer_scaling=False,
attention_softmax_in_fp32=False):
def __init__(self, num_tokentypes=0, parallel_output=True):
super(GPT2Model, self).__init__()
args = get_args()
self.parallel_output = parallel_output
self.language_model, self._language_model_key = get_language_model(
num_layers=num_layers,
vocab_size=vocab_size,
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
embedding_dropout_prob=embedding_dropout_prob,
attention_dropout_prob=attention_dropout_prob,
output_dropout_prob=output_dropout_prob,
max_sequence_length=max_sequence_length,
attention_mask_func=gpt2_attention_mask_func,
num_tokentypes=num_tokentypes,
add_pooler=False,
attention_mask_func=gpt2_attention_mask_func,
checkpoint_activations=checkpoint_activations,
checkpoint_num_layers=checkpoint_num_layers,
layernorm_epsilon=layernorm_epsilon,
init_method=init_method_normal(init_method_std),
scaled_init_method=scaled_init_method_normal(init_method_std,
num_layers),
residual_connection_post_layernorm=False,
apply_query_key_layer_scaling=apply_query_key_layer_scaling,
attention_softmax_in_fp32=attention_softmax_in_fp32)
init_method=init_method_normal(args.init_method_std),
scaled_init_method=scaled_init_method_normal(args.init_method_std,
args.num_layers))
def forward(self, input_ids, position_ids, attention_mask,
tokentype_ids=None, layer_past=None, get_key_value=False):
tokentype_ids=None, layer_past=None, get_key_value=False,
forward_method_parallel_output=None):
# Language model.
lm_output = self.language_model(input_ids,
......@@ -94,17 +64,19 @@ class GPT2Model(MegatronModule):
lm_output, presents = lm_output
# Output.
parallel_output = self.parallel_output
if forward_method_parallel_output is not None:
parallel_output = forward_method_parallel_output
output = parallel_lm_logits(
lm_output,
self.language_model.embedding.word_embeddings.weight,
self.parallel_output)
parallel_output)
if get_key_value:
output = [output, presents]
return output
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
......@@ -114,7 +86,6 @@ class GPT2Model(MegatronModule):
destination, prefix, keep_vars)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Customized load."""
......
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -18,13 +18,12 @@
import torch
import torch.nn.functional as F
from megatron import get_args
from megatron import mpu
from megatron.module import MegatronModule
from .transformer import ParallelTransformer
from .transformer import TransformerHyperparameters
from .utils import gelu
from .utils import get_linear_layer
from megatron.model.transformer import ParallelTransformer
from megatron.model.utils import openai_gelu
from megatron.model.utils import get_linear_layer
def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
......@@ -40,52 +39,26 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
# Gather if needed.
if parallel_output:
return logits_parallel
else:
return mpu.gather_from_model_parallel_region(logits_parallel)
def get_language_model(num_layers,
vocab_size,
hidden_size,
num_attention_heads,
embedding_dropout_prob,
attention_dropout_prob,
output_dropout_prob,
max_sequence_length,
num_tokentypes,
attention_mask_func,
add_pooler,
checkpoint_activations,
checkpoint_num_layers,
layernorm_epsilon,
init_method,
scaled_init_method,
residual_connection_post_layernorm,
apply_query_key_layer_scaling,
attention_softmax_in_fp32):
# Transformer hyperparameters.
transformer_hparams = TransformerHyperparameters(
hidden_size=hidden_size,
num_layers=num_layers,
num_attention_heads=num_attention_heads,
attention_dropout_prob=attention_dropout_prob,
output_dropout_prob=output_dropout_prob,
mlp_activation_func=gelu,
layernorm_epsilon=layernorm_epsilon,
init_method=init_method,
output_layer_init_method=scaled_init_method,
checkpoint_activations=checkpoint_activations,
checkpoint_num_layers=checkpoint_num_layers,
apply_residual_connection_post_layernorm=residual_connection_post_layernorm,
apply_query_key_layer_scaling=apply_query_key_layer_scaling,
attention_softmax_in_fp32=attention_softmax_in_fp32)
def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
init_method, scaled_init_method):
"""Build language model and return along with the key to save."""
args = get_args()
# Use torch gelu unless otherwise forced.
gelu = F.gelu
if args.openai_gelu:
gelu = openai_gelu
# Language model.
language_model = TransformerLanguageModel(
transformer_hparams=transformer_hparams,
attention_mask_func=attention_mask_func,
vocab_size=vocab_size,
max_sequence_length=max_sequence_length,
embedding_dropout_prob=embedding_dropout_prob,
mlp_activation_func=gelu,
init_method=init_method,
output_layer_init_method=scaled_init_method,
num_tokentypes=num_tokentypes,
add_pooler=add_pooler)
# key used for checkpoints.
......@@ -94,7 +67,6 @@ def get_language_model(num_layers,
return language_model, language_model_key
class Pooler(MegatronModule):
"""Pooler layer.
......@@ -106,11 +78,11 @@ class Pooler(MegatronModule):
init_method: weight initialization method for the linear layer.
bias is set to zero.
"""
def __init__(self, hidden_size, init_method):
super(Pooler, self).__init__()
self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
def forward(self, hidden_states, sequence_index=0):
# hidden_states: [b, s, h]
# sequence_index: index of the token to pool.
......@@ -133,6 +105,7 @@ class Embedding(MegatronModule):
num_tokentypes: size of the token-type embeddings. 0 value
will ignore this embedding
"""
def __init__(self,
hidden_size,
vocab_size,
......@@ -174,7 +147,6 @@ class Embedding(MegatronModule):
# Embeddings dropout
self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
def add_tokentype_embeddings(self, num_tokentypes):
"""Add token-type embedding. This function is provided so we can add
token-type embeddings in case the pretrained model does not have it.
......@@ -191,7 +163,6 @@ class Embedding(MegatronModule):
# Initialize the token-type embeddings.
self.init_method(self.tokentype_embeddings.weight)
def forward(self, input_ids, position_ids, tokentype_ids=None):
# Embeddings.
words_embeddings = self.word_embeddings(input_ids)
......@@ -208,7 +179,6 @@ class Embedding(MegatronModule):
return embeddings
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
"""For easy load."""
......@@ -226,7 +196,6 @@ class Embedding(MegatronModule):
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Customized load."""
......@@ -273,7 +242,6 @@ class Embedding(MegatronModule):
'checkpoint but could not find it', flush=True)
class TransformerLanguageModel(MegatronModule):
"""Transformer language model.
......@@ -292,34 +260,35 @@ class TransformerLanguageModel(MegatronModule):
num_tokentypes: size of the token-type embeddings. 0 value
will ignore this embedding
"""
def __init__(self,
transformer_hparams,
attention_mask_func,
vocab_size,
max_sequence_length,
embedding_dropout_prob,
mlp_activation_func,
init_method,
output_layer_init_method,
num_tokentypes=0,
add_pooler=False):
super(TransformerLanguageModel, self).__init__()
args = get_args()
self.hidden_size = transformer_hparams['hidden_size']
self.hidden_size = args.hidden_size
self.num_tokentypes = num_tokentypes
self.init_method = transformer_hparams['init_method']
self.init_method = init_method
self.add_pooler = add_pooler
# Embeddings
self.embedding = Embedding(self.hidden_size,
vocab_size,
max_sequence_length,
embedding_dropout_prob,
args.padded_vocab_size,
args.max_position_embeddings,
args.hidden_dropout,
self.init_method,
self.num_tokentypes)
self._embedding_key = 'embedding'
# Transformer
self.transformer = ParallelTransformer(
transformer_hparams,
attention_mask_func)
attention_mask_func, mlp_activation_func,
self.init_method, output_layer_init_method)
self._transformer_key = 'transformer'
# Pooler
......@@ -327,7 +296,6 @@ class TransformerLanguageModel(MegatronModule):
self.pooler = Pooler(self.hidden_size, self.init_method)
self._pooler_key = 'pooler'
def forward(self, input_ids, position_ids, attention_mask,
tokentype_ids=None, layer_past=None, get_key_value=False,
pooling_sequence_index=0):
......@@ -349,7 +317,6 @@ class TransformerLanguageModel(MegatronModule):
return transformer_output
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
"""For easy load."""
......@@ -368,7 +335,6 @@ class TransformerLanguageModel(MegatronModule):
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Customized load."""
......
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Multiple choice model."""
import torch
from megatron import get_args
from megatron.model.bert_model import bert_attention_mask_func
from megatron.model.bert_model import bert_extended_attention_mask
from megatron.model.bert_model import bert_position_ids
from megatron.model.language_model import get_language_model
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal
from megatron.module import MegatronModule
from megatron import print_rank_0
class MultipleChoice(MegatronModule):
    """Multiple-choice head on top of the pooled language model.

    Inputs carry a [batch, choices, sequence] layout; the choice
    dimension is folded into the batch for the transformer pass, and the
    pooled output of each candidate sequence is projected to one logit.
    """

    def __init__(self, num_tokentypes=2):
        super(MultipleChoice, self).__init__()
        args = get_args()

        init_method = init_method_normal(args.init_method_std)
        output_layer_init_method = scaled_init_method_normal(
            args.init_method_std, args.num_layers)

        # Backbone transformer, with pooler so we get one vector per
        # (flattened) candidate sequence.
        self.language_model, self._language_model_key = get_language_model(
            attention_mask_func=bert_attention_mask_func,
            num_tokentypes=num_tokentypes,
            add_pooler=True,
            init_method=init_method,
            scaled_init_method=output_layer_init_method)

        # Multi-choice head: dropout followed by an [h -> 1] projection.
        self.multichoice_dropout = torch.nn.Dropout(args.hidden_dropout)
        self.multichoice_head = get_linear_layer(args.hidden_size, 1,
                                                 init_method)
        self._multichoice_head_key = 'multichoice_head'

    def forward(self, input_ids, attention_mask, tokentype_ids):
        # [batch, choices, sequence] --> [batch * choices, sequence] -->
        # transformer --> [batch, choices] --> softmax (applied by the loss).

        # All three inputs must be rank-3: [batch-size, choices, sequence].
        assert input_ids.dim() == 3
        assert attention_mask.dim() == 3
        assert tokentype_ids.dim() == 3

        # Fold the choice dimension into the batch dimension.
        num_choices = input_ids.shape[1]
        flat_input_ids = input_ids.view(-1, input_ids.size(-1))
        flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1))
        flat_tokentype_ids = tokentype_ids.view(-1, tokentype_ids.size(-1))

        extended_attention_mask = bert_extended_attention_mask(
            flat_attention_mask, next(self.language_model.parameters()).dtype)
        position_ids = bert_position_ids(flat_input_ids)

        _, pooled_output = self.language_model(
            flat_input_ids,
            position_ids,
            extended_attention_mask,
            tokentype_ids=flat_tokentype_ids)

        # Project each pooled vector to a logit, then unfold the choices.
        multichoice_output = self.multichoice_dropout(pooled_output)
        multichoice_logits = self.multichoice_head(multichoice_output)
        return multichoice_logits.view(-1, num_choices)

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """For easy load when model is combined with other heads,
        add an extra key."""
        return {
            self._language_model_key:
                self.language_model.state_dict_for_save_checkpoint(
                    destination, prefix, keep_vars),
            self._multichoice_head_key:
                self.multichoice_head.state_dict(
                    destination, prefix, keep_vars),
        }

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""
        self.language_model.load_state_dict(
            state_dict[self._language_model_key], strict=strict)
        if self._multichoice_head_key not in state_dict:
            print_rank_0('***WARNING*** could not find {} in the checkpoint, '
                         'initializing to random'.format(
                             self._multichoice_head_key))
        else:
            self.multichoice_head.load_state_dict(
                state_dict[self._multichoice_head_key], strict=strict)
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -20,6 +20,7 @@ import math
import torch
from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm
from megatron import get_args
from megatron import mpu
from megatron.module import MegatronModule
......@@ -46,84 +47,6 @@ from megatron.module import MegatronModule
"""
class TransformerHyperparameters:
    """Hyperparameters used to build and run the transformer.

    Arguments:
        hidden_size: hidden size (h)
        num_layers: number of layers (l)
        num_attention_heads: number of attention heads (n)
        attention_dropout_prob: dropout probability for the attention
            probabilities
        output_dropout_prob: dropout probability for the output
            layers (attention output and mlp output)
        mlp_activation_func: activation function for the mlp layer
        layernorm_epsilon: tolerance parameter used for layer norm
            divisions
        init_method: init method used for all weights except layer
            norm and output weights
        output_layer_init_method: init method for output weights (
            attention output and mlp output)
        checkpoint_activations: flag to use activation checkpointing
        checkpoint_num_layers: number of layers used in each chunk of
            activation checkpointing
        apply_residual_connection_post_layernorm: take the post layer-norm
            values for the residual connection. BERT: True, GPT-2: False
        apply_query_key_layer_scaling: if set, scale query-key attention
            scores by the layer number (presumably for numerical
            stability in fp16 -- confirm with consumers)
        attention_softmax_in_fp32: run the attention softmax in fp32

    All values default to None; `__getitem__` raises if a consumer reads
    a parameter that was never provided.
    """

    def __init__(self,
                 hidden_size=None,
                 num_layers=None,
                 num_attention_heads=None,
                 attention_dropout_prob=None,
                 output_dropout_prob=None,
                 mlp_activation_func=None,
                 layernorm_epsilon=None,
                 init_method=None,
                 output_layer_init_method=None,
                 checkpoint_activations=None,
                 checkpoint_num_layers=None,
                 apply_residual_connection_post_layernorm=None,
                 apply_query_key_layer_scaling=None,
                 attention_softmax_in_fp32=None):
        # Keep everything in a plain dict so __getitem__ can verify that
        # a parameter was explicitly set (not left at the None default)
        # before it is consumed.
        self.params_dict = {
            'hidden_size': hidden_size,
            'num_layers': num_layers,
            'num_attention_heads': num_attention_heads,
            'attention_dropout_prob': attention_dropout_prob,
            'output_dropout_prob': output_dropout_prob,
            'mlp_activation_func': mlp_activation_func,
            'layernorm_epsilon': layernorm_epsilon,
            'init_method': init_method,
            'output_layer_init_method': output_layer_init_method,
            'checkpoint_activations': checkpoint_activations,
            'checkpoint_num_layers': checkpoint_num_layers,
            'apply_residual_connection_post_layernorm':
                apply_residual_connection_post_layernorm,
            'apply_query_key_layer_scaling': apply_query_key_layer_scaling,
            'attention_softmax_in_fp32': attention_softmax_in_fp32,
        }

    def __getitem__(self, key):
        """Retrieve a hyperparameter with error checks.

        Raises Exception when the key is unknown, or when the key exists
        but its value was never set (still None).
        """
        if key not in self.params_dict:
            raise Exception(
                'could not find {} in transformer hyperparameters'.format(key))
        value = self.params_dict[key]
        # Explicit raise instead of the original assert: asserts are
        # stripped under `python -O`, but this check must always run.
        if value is None:
            raise Exception('parameter value for {} is not set in transformer '
                            'hyperparameters'.format(key))
        return value
class ParallelMLP(MegatronModule):
"""MLP.
......@@ -133,27 +56,28 @@ class ParallelMLP(MegatronModule):
applied.
"""
def __init__(self, hyperparameters):
def __init__(self, mlp_activation_func, init_method,
output_layer_init_method):
super(ParallelMLP, self).__init__()
args = get_args()
# Project to 4h.
self.dense_h_to_4h = mpu.ColumnParallelLinear(
hyperparameters['hidden_size'],
4*hyperparameters['hidden_size'],
args.hidden_size,
4 * args.hidden_size,
gather_output=False,
init_method=hyperparameters['init_method'])
init_method=init_method)
self.activation_func = hyperparameters['mlp_activation_func']
self.activation_func = mlp_activation_func
# Project back to h.
self.dense_4h_to_h = mpu.RowParallelLinear(
4*hyperparameters['hidden_size'],
hyperparameters['hidden_size'],
4 * args.hidden_size,
args.hidden_size,
input_is_parallel=True,
init_method=hyperparameters['output_layer_init_method'])
self.dropout = torch.nn.Dropout(hyperparameters['output_dropout_prob'])
init_method=output_layer_init_method)
self.dropout = torch.nn.Dropout(args.hidden_dropout)
def forward(self, hidden_states):
......@@ -167,7 +91,6 @@ class ParallelMLP(MegatronModule):
return output
class ParallelSelfAttention(MegatronModule):
"""Parallel self-attention layer abstract class.
......@@ -175,51 +98,48 @@ class ParallelSelfAttention(MegatronModule):
and returns output of the same size.
"""
def __init__(self, hyperparameters, attention_mask_func, layer_number):
def __init__(self, attention_mask_func, init_method,
output_layer_init_method, layer_number):
super(ParallelSelfAttention, self).__init__()
args = get_args()
self.fp16 = args.fp16
self.attention_mask_func = attention_mask_func
self.apply_query_key_layer_scaling \
= hyperparameters['apply_query_key_layer_scaling']
self.attention_softmax_in_fp32 \
= hyperparameters['attention_softmax_in_fp32']
self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
if self.apply_query_key_layer_scaling:
self.attention_softmax_in_fp32 = True
self.layer_number = max(1, layer_number)
# Per attention head and per partition values.
world_size = mpu.get_model_parallel_world_size()
self.hidden_size_per_partition = mpu.divide(
hyperparameters['hidden_size'], world_size)
self.hidden_size_per_partition = mpu.divide(args.hidden_size,
world_size)
self.hidden_size_per_attention_head = mpu.divide(
hyperparameters['hidden_size'],
hyperparameters['num_attention_heads'])
args.hidden_size, args.num_attention_heads)
self.num_attention_heads_per_partition = mpu.divide(
hyperparameters['num_attention_heads'], world_size)
args.num_attention_heads, world_size)
# Strided linear layer.
self.query_key_value = mpu.ColumnParallelLinear(
hyperparameters['hidden_size'],
3*hyperparameters['hidden_size'],
args.hidden_size,
3 * args.hidden_size,
stride=3,
gather_output=False,
init_method=hyperparameters['init_method'])
init_method=init_method)
# Dropout. Note that for a single iteration, this layer will generate
# different outputs on different number of parallel partitions but
# on average it should not be partition dependent.
self.attention_dropout = torch.nn.Dropout(
hyperparameters['attention_dropout_prob'])
self.attention_dropout = torch.nn.Dropout(args.attention_dropout)
# Output.
self.dense = mpu.RowParallelLinear(
hyperparameters['hidden_size'],
hyperparameters['hidden_size'],
args.hidden_size,
args.hidden_size,
input_is_parallel=True,
init_method=hyperparameters['output_layer_init_method'])
self.output_dropout = torch.nn.Dropout(
hyperparameters['output_dropout_prob'])
init_method=output_layer_init_method)
self.output_dropout = torch.nn.Dropout(args.hidden_dropout)
def _transpose_for_scores(self, tensor):
"""Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with
......@@ -231,7 +151,6 @@ class ParallelSelfAttention(MegatronModule):
tensor = tensor.view(*new_tensor_shape)
return tensor.permute(0, 2, 1, 3)
def _get_query_key_value(self, hidden_states):
"""Get query, key, and value and transpose to
get size [b, np, s, hn].
......@@ -249,7 +168,6 @@ class ParallelSelfAttention(MegatronModule):
return query_layer, key_layer, value_layer
def _get_unmasked_attention_scores(self, query_layer, key_layer):
"""Unmasked attention scores with size [b, np, s, s]."""
coeff = 1
......@@ -258,9 +176,8 @@ class ParallelSelfAttention(MegatronModule):
norm_factor = math.sqrt(coeff *
math.sqrt(self.hidden_size_per_attention_head))
# Raw attention scores. [b, np, s, s]
return torch.matmul(query_layer/norm_factor,
key_layer.transpose(-1, -2)/norm_factor)
return torch.matmul(query_layer / norm_factor,
key_layer.transpose(-1, -2) / norm_factor)
def _get_attention_probs(self, attention_scores):
"""Attention probabilies with dropout. The output has
......@@ -277,7 +194,6 @@ class ParallelSelfAttention(MegatronModule):
return attention_probs
def _get_attended_context(self, attention_probs, value_layer):
"""Final attended tesnor and transposed back to [b, s, hp]."""
# Context layer.
......@@ -292,7 +208,6 @@ class ParallelSelfAttention(MegatronModule):
return context_layer
def _get_output(self, context_layer):
"""Output layer with dropout."""
# Output. [b, s, h]
......@@ -301,7 +216,6 @@ class ParallelSelfAttention(MegatronModule):
return output
def forward(self, hidden_states, attention_mask, layer_past=None,
get_key_value=False):
# hidden_states: [b, s, h]
......@@ -324,7 +238,7 @@ class ParallelSelfAttention(MegatronModule):
query_layer, key_layer)
# fp32 conversion.
if self.attention_softmax_in_fp32:
if self.fp16 and self.attention_softmax_in_fp32:
attention_scores = attention_scores.float()
# Apply attention mask. [b, np, s, s]
......@@ -333,7 +247,7 @@ class ParallelSelfAttention(MegatronModule):
if layer_past is not None:
attention_mask = attention_mask[
...,
attention_scores.size(3)-1,
attention_scores.size(3) - 1,
:attention_scores.size(3)].unsqueeze(2)
else:
attention_mask = attention_mask[
......@@ -347,7 +261,7 @@ class ParallelSelfAttention(MegatronModule):
attention_probs = self._get_attention_probs(attention_scores)
# fp16 conversion
if self.attention_softmax_in_fp32:
if self.fp16 and self.attention_softmax_in_fp32:
attention_probs = attention_probs.half()
# Context layer. [b, s, hp]
......@@ -362,38 +276,41 @@ class ParallelSelfAttention(MegatronModule):
return output
class ParallelTransformerLayer(MegatronModule):
"""A single transformer layer.
Transformer layer takes input with size [b, s, h] and returns an
output of the same size.
"""
def __init__(self, hyperparameters, attention_mask_func, layer_number):
def __init__(self, attention_mask_func, mlp_activation_func,
init_method, output_layer_init_method, layer_number):
args = get_args()
super(ParallelTransformerLayer, self).__init__()
self.layer_number = layer_number
self.apply_residual_connection_post_layernorm \
= hyperparameters['apply_residual_connection_post_layernorm']
= args.apply_residual_connection_post_layernorm
# Layernorm on the input data.
self.input_layernorm = LayerNorm(
hyperparameters['hidden_size'],
eps=hyperparameters['layernorm_epsilon'])
args.hidden_size,
eps=args.layernorm_epsilon)
# Self attention.
self.attention = ParallelSelfAttention(
hyperparameters, attention_mask_func, layer_number)
self.attention = ParallelSelfAttention(attention_mask_func, init_method,
output_layer_init_method,
layer_number)
# Layernorm on the input data.
self.post_attention_layernorm = LayerNorm(
hyperparameters['hidden_size'],
eps=hyperparameters['layernorm_epsilon'])
args.hidden_size,
eps=args.layernorm_epsilon)
# MLP
self.mlp = ParallelMLP(hyperparameters)
self.mlp = ParallelMLP(mlp_activation_func, init_method,
output_layer_init_method)
def forward(self, hidden_states, attention_mask, layer_past=None,
get_key_value=False):
......@@ -434,26 +351,28 @@ class ParallelTransformerLayer(MegatronModule):
class ParallelTransformer(MegatronModule):
"""Transformer class."""
def __init__(self, hyperparameters, attention_mask_func):
def __init__(self, attention_mask_func, mlp_activation_func,
init_method, output_layer_init_method):
super(ParallelTransformer, self).__init__()
args = get_args()
# Store activation checkpoiting flag.
self.checkpoint_activations = hyperparameters['checkpoint_activations']
self.checkpoint_num_layers = hyperparameters['checkpoint_num_layers']
self.checkpoint_activations = args.checkpoint_activations
self.checkpoint_num_layers = args.checkpoint_num_layers
def get_layer(layer_number):
return ParallelTransformerLayer(
hyperparameters, attention_mask_func, layer_number)
attention_mask_func, mlp_activation_func,
init_method, output_layer_init_method, layer_number)
# Transformer layers.
self.layers = torch.nn.ModuleList(
[get_layer(i+1) for i in range(hyperparameters['num_layers'])])
[get_layer(i + 1) for i in range(args.num_layers)])
# Final layer norm before output.
self.final_layernorm = LayerNorm(
hyperparameters['hidden_size'],
eps=hyperparameters['layernorm_epsilon'])
args.hidden_size,
eps=args.layernorm_epsilon)
def _checkpointed_forward(self, hidden_states, attention_mask):
"""Forward method with activation checkpointing."""
......@@ -470,13 +389,12 @@ class ParallelTransformer(MegatronModule):
num_layers = len(self.layers)
while l < num_layers:
hidden_states = mpu.checkpoint(
custom(l, l+self.checkpoint_num_layers),
custom(l, l + self.checkpoint_num_layers),
hidden_states, attention_mask)
l += self.checkpoint_num_layers
return hidden_states
def forward(self, hidden_states, attention_mask, layer_past=None,
get_key_value=False):
......
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -33,6 +33,7 @@ def init_method_normal(sigma):
def scaled_init_method_normal(sigma, num_layers):
"""Init method based on N(0, sigma/sqrt(2*num_layers)."""
std = sigma / math.sqrt(2.0 * num_layers)
def init_(tensor):
return torch.nn.init.normal_(tensor, mean=0.0, std=std)
......@@ -53,8 +54,7 @@ def gelu_impl(x):
"""OpenAI's gelu implementation."""
return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x *
(1.0 + 0.044715 * x * x)))
def gelu(x):
def openai_gelu(x):
    """OpenAI-style (tanh-approximation) GeLU; thin wrapper over gelu_impl."""
    return gelu_impl(x)
......
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Megatron Module"""
import torch
......@@ -22,11 +21,9 @@ import torch
class MegatronModule(torch.nn.Module):
"""Megatron specific extentions of torch Module."""
def __init__(self):
    # No extra state of its own; just initialize the underlying torch Module.
    super(MegatronModule, self).__init__()
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
"""Use this function to override the state dict for
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment