Commit 5e56e563 authored by Neel Kant

Merge master into realm-mlm

parents 6c0a5bd8 569b3dab
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -28,7 +28,9 @@ from .wordpiece import BertTokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP
from .tokenization_gpt2 import GPT2Tokenizer
import regex as re
def make_tokenizer(tokenizer_type, corpus, model_path=None, vocab_size=None, model_type='bpe', pad_token=0, character_coverage=1.0, command_tokens=None, type_tokens=None, **kwargs):
def make_tokenizer(tokenizer_type, corpus, model_path=None, vocab_size=None, model_type='bpe',
pad_token=0, character_coverage=1.0, command_tokens=None, type_tokens=None, **kwargs):
"""
Helper function to instantiate a tokenizer given common combinations of options.
"""
......@@ -43,6 +45,7 @@ def make_tokenizer(tokenizer_type, corpus, model_path=None, vocab_size=None, mod
pad_token=pad_token, character_coverage=character_coverage)
return Tokenizer(text_tokenizer, command_tokens, type_tokens)
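For context, a minimal usage sketch of the helper above. The corpus path, model prefix, and vocabulary size are hypothetical, and it is assumed here that tokenizer_type accepts the tokenizer class name as a string:

# Hypothetical example: build a sentencepiece-backed text tokenizer over a local corpus
# and wrap it with command/type tokens via the Tokenizer class defined below.
tokenizer = make_tokenizer('SentencePieceTokenizer', corpus='my_corpus.txt',
                           model_path='my_spm', vocab_size=32000, model_type='bpe')
ids = tokenizer.EncodeAsIds('hello world').tokenization  # Ids shifted past the command tokens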
class Tokenization(object):
"""
Tokenization object to hold tokenization, (processed text), and original
......@@ -54,7 +57,9 @@ class Tokenization(object):
Several standard array operations are implemented (insert, append, extend).
"""
def __init__(self, tokenization, text=None, original_text=None, command_tokens=None, asIds=True):
def __init__(self, tokenization, text=None, original_text=None,
command_tokens=None, asIds=True):
self.tokenization = tokenization
self.text = text
if self.text is None:
......@@ -91,13 +96,15 @@ class Tokenization(object):
if idx == 0:
self.text = other.token + self.text
self.original_text = other.token + self.original_text
elif idx == len(self.tokenization)-1:
elif idx == len(self.tokenization) - 1:
self.text += other.token
self.original_text += other.token
elif isinstance(other, Tokenization):
self.tokenization = self.tokenization[:idx] + other.tokenization + self.tokenization[idx:]
self.tokenization = self.tokenization[:idx] + \
other.tokenization + self.tokenization[idx:]
else:
self.tokenization = self.tokenization[:idx] + other.tokenization + self.tokenization[idx:]
self.tokenization = self.tokenization[:idx] + \
other.tokenization + self.tokenization[idx:]
def append(self, other):
if isinstance(other, (CommandToken, TypeToken)):
......@@ -129,14 +136,17 @@ class Tokenization(object):
self.tokenization.extend(other)
return self
"""define some default command tokens for the tokenizer to use"""
token_format = "<{0}>"
COMMAND_TUPLE = namedtuple('CommandToken', ('name', 'token', 'Id'))
def prep_command_tokens(tokenlist, token_format=token_format):
return [CommandToken(tok[0], token_format.format(tok[0]), tok[1]) for tok in tokenlist]
class CommandToken(object):
def __init__(self, name, token, Id):
self.name = name
......@@ -146,6 +156,7 @@ class CommandToken(object):
def __str__(self):
return str(COMMAND_TUPLE(self.name, self.token, self.Id))
DEFAULT_COMMAND_TOKENS = [
('pad', 0),
('eos', 1),
......@@ -162,9 +173,11 @@ DEFAULT_COMMAND_TOKENS = prep_command_tokens(DEFAULT_COMMAND_TOKENS)
TYPE_TUPLE = namedtuple('TypeToken', ('name', 'token', 'Id'))
def prep_type_tokens(tokenlist, token_format=token_format):
return [TypeToken(tok[0], token_format.format(tok[0]), tok[1]) for tok in tokenlist]
class TypeToken(object):
def __init__(self, name, token, Id):
self.name = name
......@@ -174,6 +187,7 @@ class TypeToken(object):
def __str__(self):
return str(TYPE_TUPLE(self.name, self.token, self.Id))
DEFAULT_TYPE_TOKENS = [
('function', 0),
('command', 1),
......@@ -189,6 +203,7 @@ DEFAULT_TYPE_TOKENS = [
]
DEFAULT_TYPE_TOKENS = prep_type_tokens(DEFAULT_TYPE_TOKENS)
class Tokenizer(object):
"""
Tokenizer object that handles text tokenization, command tokens, and type tokens.
......@@ -199,6 +214,7 @@ class Tokenizer(object):
Token types are stored in a separate mapping of size `len(type_tokens)`.
"""
def __init__(self, text_tokenizer, command_tokens=None, type_tokens=None):
# set text tokenizer
self.text_tokenizer = text_tokenizer
......@@ -229,18 +245,20 @@ class Tokenizer(object):
# parse tokens and vocabs from tokenizer
self._tokens = list(self.command_token_map.keys()) + list(self.text_tokenizer.tokens)
self._vocab = {t:Id for Id,t in self.command_id_map.items()}
self._vocab.update({t:Id+self.num_command_tokens for t,Id in self.text_tokenizer.vocab.items()})
self._vocab = {t: Id for Id, t in self.command_id_map.items()}
self._vocab.update({t: Id + self.num_command_tokens for t,
Id in self.text_tokenizer.vocab.items()})
self._text_tokens = list(self.text_tokenizer.tokens)
self._text_token_vocab = {t:Id+self.num_command_tokens for t,Id in self.text_tokenizer.vocab.items()}
self._text_token_vocab = {
t: Id + self.num_command_tokens for t,
Id in self.text_tokenizer.vocab.items()}
self._command_token_tokens = list(self.command_token_map.keys())
self._command_token_vocab = {t:Id for Id,t in self.command_id_map.items()}
self._command_token_vocab = {t: Id for Id, t in self.command_id_map.items()}
self._token_types = list(self.type_token_map.keys())
self._token_type_vocab = {t:Id for Id, t in self.type_id_map.items()}
self._token_type_vocab = {t: Id for Id, t in self.type_id_map.items()}
def __call__(self, text, process_fn=None):
"""run preprocessing and encode text as Ids"""
......@@ -303,7 +321,7 @@ class Tokenizer(object):
encode text using text tokenizer and shift Id values for command tokens
"""
tokenization = self.text_tokenizer.EncodeAsIds(text, process_fn=process_fn)
tokenization.tokenization = [t+self.num_command_tokens for t in tokenization.tokenization]
tokenization.tokenization = [t + self.num_command_tokens for t in tokenization.tokenization]
tokenization.set_command_tokens(self._command_tokens)
return tokenization
......@@ -323,7 +341,7 @@ class Tokenizer(object):
return self.type_id_map[Id].token
if Id < self.num_command_tokens:
return self.command_id_map[Id].token
return self.text_tokenizer.IdToToken(Id-self.num_command_tokens)
return self.text_tokenizer.IdToToken(Id - self.num_command_tokens)
def TokenToId(self, token, type_token=False):
"""convert token to Id accounting for command and type tokens"""
......@@ -333,7 +351,7 @@ class Tokenizer(object):
return self.type_token_map[token].Id
if token in self.command_token_map:
return self.command_token_map[token].Id
return self.text_tokenizer.TokenToId(token)+self.num_command_tokens
return self.text_tokenizer.TokenToId(token) + self.num_command_tokens
def DecodeIds(self, Ids, type_token=False):
"""
......@@ -341,7 +359,8 @@ class Tokenizer(object):
are joined and returned as a string.
"""
if type_token:
return ' '.join(Id.token if isinstance(Id, TypeToken) else self.type_id_map[Id].token for Id in Ids)
return ' '.join(Id.token if isinstance(Id, TypeToken)
else self.type_id_map[Id].token for Id in Ids)
rtn_strs = []
current_str = []
if isinstance(Ids, Tokenization):
......@@ -386,10 +405,12 @@ class Tokenizer(object):
rtn_strs.append(self.text_tokenizer.DecodeTokens(current_str))
return ' '.join(rtn_strs)
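The Id layout used throughout this class: command tokens occupy Ids [0, num_command_tokens), type tokens live in their own small mapping, and text-tokenizer Ids are shifted up by num_command_tokens. An illustrative sketch of how that shows up through the public methods (the instance `tok` is hypothetical):

# Illustrative only:
# tok.TokenToId(t)  == tok.text_tokenizer.TokenToId(t) + tok.num_command_tokens   # text tokens
# tok.IdToToken(3)  -> a command token's text, because 3 < tok.num_command_tokens
# tok.IdToToken(i)  == tok.text_tokenizer.IdToToken(i - tok.num_command_tokens)   # otherwise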
class TextTokenizer(object):
"""
Interface for text tokenizer
"""
def __init__(self):
if not hasattr(self, 'num_text_tokens'):
self.num_text_tokens = 0
......@@ -456,11 +477,12 @@ class CharacterLevelTokenizer(TextTokenizer):
"""
Text tokenizer for ASCII-256 Character Level Tokenization.
"""
def __init__(self, **kwargs):
self.num_text_tokens = 256
super(CharacterLevelTokenizer, self).__init__()
self._tokens = [self.IdToToken(Id) for Id in range(self.num_text_tokens)]
self._vocab = {t: i for i,t in enumerate(self._tokens)}
self._vocab = {t: i for i, t in enumerate(self._tokens)}
def __len__(self):
return 256
......@@ -521,6 +543,7 @@ class CharacterLevelTokenizer(TextTokenizer):
MAX_SENTENCEPIECE_SENTENCES = 100000000
def get_corpus_freq(dataset, filepath, filetype='tsv'):
"""
Take corpus, split it into sentences, and extract word frequencies.
......@@ -556,14 +579,13 @@ def get_corpus_freq(dataset, filepath, filetype='tsv'):
print("file path for freq " + str(filepath), flush=True)
freqs_sorted = {}
counter=0
counter = 0
for word, count in sorted(freqs.items(), key=lambda x: x[1], reverse=True):
if counter >= MAX_SENTENCEPIECE_SENTENCES:
break
counter+=1
counter += 1
freqs_sorted[word] = count
print("length of freqs after truncating " + str(len(freqs_sorted)), flush=True)
with open(filepath, 'w') as f:
......@@ -573,9 +595,12 @@ def get_corpus_freq(dataset, filepath, filetype='tsv'):
return total_sentence_count, maxlen
class SentencePieceTokenizer(TextTokenizer):
"""Trains and uses sentencepiece for text tokenization"""
def __init__(self, model_type='bpe', vocab_size=None, corpus=None, model_path=None, character_coverage=1.0, **kwargs):
def __init__(self, model_type='bpe', vocab_size=None, corpus=None,
model_path=None, character_coverage=1.0, **kwargs):
self.character_coverage = character_coverage
self.model_type = model_type.lower()
self.spm_model = model_path
......@@ -608,18 +633,18 @@ class SentencePieceTokenizer(TextTokenizer):
dne = not os.path.exists(model_path)
# check if path.model exists
if dne and not model_path.endswith('.model'):
dne = not os.path.exists(model_path+'.model')
dne = not os.path.exists(model_path + '.model')
return not dne
def load_spm_model(self):
"""load sentencepiece model and parse vocab"""
if not os.path.exists(self.spm_model) and not self.spm_model.endswith('.model'):
self.spm_model = self.spm_model+'.model'
self.spm_model = self.spm_model + '.model'
self.sp = spm.SentencePieceProcessor()
self.sp.Load(self.spm_model)
self.vocab_size = self.num_text_tokens = len(self.sp)
self._tokens = [self.IdToToken(t) for t in range(self.vocab_size)]
self._vocab = {t: i for i,t in enumerate(self._tokens)}
self._vocab = {t: i for i, t in enumerate(self._tokens)}
def Train(self, corpus, num_text_tokens):
"""train sentencepiece model on corpus using word frequencies"""
......@@ -630,7 +655,7 @@ class SentencePieceTokenizer(TextTokenizer):
use_model_path = random_hash
if use_model_path.endswith('.model'):
use_model_path = use_model_path[:use_model_path.rfind('.model')]
input_path = use_model_path+'.tsv.'+random_hash
input_path = use_model_path + '.tsv.' + random_hash
line_count, maxlenline = get_corpus_freq(corpus, input_path)
line_count = min(line_count, MAX_SENTENCEPIECE_SENTENCES)
print('line count used as input_sentence_size ', line_count, flush=True)
......@@ -641,12 +666,12 @@ class SentencePieceTokenizer(TextTokenizer):
+ '--input_format=tsv'
train_string = train_string.format(file_path=input_path, model_prefix=use_model_path, vocab_size=num_text_tokens,
model_type=self.model_type, character_coverage=self.character_coverage,
input_sentence_size=int(line_count)) #, #)#,
print("calling spm.SentencePieceTrainer.Train(%s)"%(train_string), flush=True)
input_sentence_size=int(line_count)) # , #)#,
print("calling spm.SentencePieceTrainer.Train(%s)" % (train_string), flush=True)
spm.SentencePieceTrainer.Train(train_string)
os.remove(input_path)
self.spm_model = use_model_path+'.model'
print('sentencepiece model written to '+self.spm_model, flush=True)
self.spm_model = use_model_path + '.model'
print('sentencepiece model written to ' + self.spm_model, flush=True)
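For reference, the flag string assembled above ends up looking roughly like the following (all values are illustrative); spm.SentencePieceTrainer.Train accepts a single space-separated flag string:

# Roughly, with illustrative values:
# '--input=my_spm.tsv.<hash> --model_prefix=my_spm --vocab_size=32000 '
# '--model_type=bpe --character_coverage=1.0 --input_sentence_size=1000000 --input_format=tsv'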
def EncodeAsIds(self, text, process_fn=None):
"""convert text to sentencepiece Ids"""
......@@ -684,19 +709,26 @@ class SentencePieceTokenizer(TextTokenizer):
Tokens = Tokens.tokenization
return self.sp.DecodeTokens(Tokens)
class BertWordPieceTokenizer(Tokenizer):
"""
Loads a pretrained WordPiece tokenizer from `cache_dir` for tokenization
in BERT training. Defaults to the bert-large-uncased tokenizer.
"""
def __init__(self, tokenizer_model_type=None, cache_dir=None, **kwargs):
# default to bert-large-uncased tokenizer
if tokenizer_model_type not in PRETRAINED_VOCAB_ARCHIVE_MAP:
tokenizer_model_type = 'bert-large-uncased'
if torch.distributed.get_rank() == 0:
print('loading BertWordPieceTokenizer (', tokenizer_model_type, ') from cache_dir ', cache_dir)
print(
'loading BertWordPieceTokenizer (',
tokenizer_model_type,
') from cache_dir ',
cache_dir)
do_lower_case = not ('-cased' in tokenizer_model_type or 'chinese' in tokenizer_model_type)
self.text_tokenizer = BertTokenizer.from_pretrained(tokenizer_model_type, do_lower_case=do_lower_case, cache_dir=cache_dir)
self.text_tokenizer = BertTokenizer.from_pretrained(
tokenizer_model_type, do_lower_case=do_lower_case, cache_dir=cache_dir)
if torch.distributed.get_rank() == 0:
print('loaded', tokenizer_model_type)
# disable max len warnings by increasing max len
......@@ -705,7 +737,7 @@ class BertWordPieceTokenizer(Tokenizer):
# set command tokens from wordpiece tokenizer values
self.num_command_tokens = 5
self.num_tokens = len(self.text_tokenizer.vocab)
self.num_text_tokens = self.num_tokens-5
self.num_text_tokens = self.num_tokens - 5
self.num_type_tokens = 2
self._command_tokens = [
......@@ -731,16 +763,16 @@ class BertWordPieceTokenizer(Tokenizer):
# parse tokens and vocabs from tokenizer
self._tokens = list(self.text_tokenizer.vocab.keys())
self._vocab = {k:v for k,v in self.text_tokenizer.vocab.items()}
self._vocab = {k: v for k, v in self.text_tokenizer.vocab.items()}
self._text_tokens = list(self._tokens)
self._text_token_vocab = {k:v for k,v in self.text_tokenizer.vocab.items()}
self._text_token_vocab = {k: v for k, v in self.text_tokenizer.vocab.items()}
self._command_token_tokens = list(self.command_token_map.keys())
self._command_token_vocab = {t:Id for Id,t in self.command_id_map.items()}
self._command_token_vocab = {t: Id for Id, t in self.command_id_map.items()}
self._token_types = list(self.type_token_map.keys())
self._token_type_vocab = {t:Id for Id, t in self.type_id_map.items()}
self._token_type_vocab = {t: Id for Id, t in self.type_id_map.items()}
def EncodeAsIds(self, text, process_fn=None):
"""convert text to wordpiece Ids"""
......@@ -778,7 +810,8 @@ class BertWordPieceTokenizer(Tokenizer):
def DecodeIds(self, Ids, type_token=False):
"""converts ids to wordpiece tokens and joins them as a text string"""
if type_token:
return ' '.join(Id.token if isinstance(Id, TypeToken) else self.type_id_map[Id].token for Id in Ids)
return ' '.join(Id.token if isinstance(Id, TypeToken)
else self.type_id_map[Id].token for Id in Ids)
if isinstance(Ids, Tokenization):
Ids = Ids.tokenization
Tokens = []
......@@ -795,16 +828,17 @@ class BertWordPieceTokenizer(Tokenizer):
Tokens = Tokens.tokenization
return ' '.join(Tokens)
class GPT2BPETokenizer(Tokenizer):
def __init__(self, cache_dir=None, **kwargs):
self.text_tokenizer = GPT2Tokenizer.from_pretrained('gpt2',
cache_dir=cache_dir)
#disable max len warnings by increasing max len
# disable max len warnings by increasing max len
self.text_tokenizer.max_len = int(1e12)
self.num_command_tokens = 2
self.num_tokens = len(self.text_tokenizer.encoder)
self.num_text_tokens = self.num_tokens-1
self.num_text_tokens = self.num_tokens - 1
self.num_type_tokens = 2
self._command_tokens = [
......@@ -824,28 +858,27 @@ class GPT2BPETokenizer(Tokenizer):
self.type_id_map = {tok.Id: tok for tok in self.type_tokens}
self._tokens = list(self.text_tokenizer.encoder.keys())
self._vocab = {k:v for k,v in self.text_tokenizer.encoder.items()}
self._vocab = {k: v for k, v in self.text_tokenizer.encoder.items()}
self._text_tokens = list(self._tokens)
self._text_token_vocab = {k:v for k,v in self.text_tokenizer.encoder.items()}
self._text_token_vocab = {k: v for k, v in self.text_tokenizer.encoder.items()}
self._command_token_tokens = list(self.command_token_map.keys())
self._command_token_vocab = {t:Id for Id,t in self.command_id_map.items()}
self._command_token_vocab = {t: Id for Id, t in self.command_id_map.items()}
self._token_types = list(self.type_token_map.keys())
self._token_type_vocab = {t:Id for Id, t in self.type_id_map.items()}
self._token_type_vocab = {t: Id for Id, t in self.type_id_map.items()}
def EncodeAsIds(self, text, process_fn=None):
processed_text = text
if process_fn is not None:
processed_text = process_fn(processed_text)
Ids = self.text_tokenizer.encode(processed_text)
#return Tokenization(Ids, processed_text, text)
# return Tokenization(Ids, processed_text, text)
tokenization = Tokenization(Ids, processed_text, text)
tokenization.set_command_tokens(self._command_tokens)
return tokenization
def EncodeAsTokens(self, text, process_fn=None):
processed_text = text
if process_fn is not None:
......@@ -854,10 +887,10 @@ class GPT2BPETokenizer(Tokenizer):
for token in re.findall(self.text_tokenizer.pat, processed_text):
token = ''.join(self.text_tokenizer.byte_encoder[b] for b in token.encode('utf-8'))
tokens.extend(bpe_token for bpe_token in self.text_tokenizer.bpe(token).split(' '))
tokenization=Tokenization(tokens, processed_text, text, asIds=False)
tokenization = Tokenization(tokens, processed_text, text, asIds=False)
tokenization.set_command_tokens(self._command_tokens)
return tokenization
#return Tokenization(tokens, processed_text, text, asIds=False)
# return Tokenization(tokens, processed_text, text, asIds=False)
def IdToToken(self, Id, type_token=False):
if isinstance(Id, (TypeToken, CommandToken)):
......@@ -875,7 +908,8 @@ class GPT2BPETokenizer(Tokenizer):
def DecodeIds(self, Ids, type_token=False):
if type_token:
return ' '.join(Id.token if isinstance(Id, TypeToken) else self.type_id_map[Id].token for Id in Ids)
return ' '.join(Id.token if isinstance(Id, TypeToken)
else self.type_id_map[Id].token for Id in Ids)
if isinstance(Ids, Tokenization):
Ids = Ids.tokenization
return self.text_tokenizer.decode(Ids)
......
......@@ -27,7 +27,8 @@ try:
from functools import lru_cache
except ImportError:
# Just a dummy decorator to get the checks to run on python2
# because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now.
# because honestly I don't want to support a byte-level unicode BPE
# tokenizer on python 2 right now.
def lru_cache():
return lambda func: func
......@@ -48,6 +49,7 @@ VOCAB_NAME = 'vocab.json'
MERGES_NAME = 'merges.txt'
SPECIAL_TOKENS_NAME = 'special_tokens.txt'
@lru_cache()
def bytes_to_unicode():
"""
......@@ -60,17 +62,19 @@ def bytes_to_unicode():
And avoids mapping to whitespace/control characters the bpe code barfs on.
"""
_chr = unichr if sys.version_info[0] == 2 else chr
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + \
list(range(ord("®"), ord("ÿ") + 1))
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8+n)
cs.append(2**8 + n)
n += 1
cs = [_chr(n) for n in cs]
return dict(zip(bs, cs))
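Concretely, printable bytes map to themselves while the remaining bytes are remapped above 255; for example the space byte (0x20) maps to 'Ġ' (U+0120), which is why word-initial GPT-2 BPE tokens carry a leading 'Ġ'. A quick check:

byte_to_unicode = bytes_to_unicode()
assert byte_to_unicode[ord('A')] == 'A'       # printable bytes are unchanged
assert byte_to_unicode[ord(' ')] == '\u0120'  # the space byte becomes 'Ġ'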
def get_pairs(word):
"""Return set of symbol pairs in a word.
......@@ -83,6 +87,7 @@ def get_pairs(word):
prev_char = char
return pairs
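For example, for the symbol tuple ('h', 'e', 'l', 'l', 'o') the function returns the set of adjacent pairs:

pairs = get_pairs(('h', 'e', 'l', 'l', 'o'))
# {('h', 'e'), ('e', 'l'), ('l', 'l'), ('l', 'o')}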
class GPT2Tokenizer(object):
"""
GPT-2 BPE tokenizer. Peculiarities:
......@@ -138,23 +143,31 @@ class GPT2Tokenizer(object):
special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
else:
special_tokens = kwargs.pop('special_tokens', [])
tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs)
tokenizer = cls(
resolved_vocab_file,
resolved_merges_file,
special_tokens=special_tokens,
*inputs,
**kwargs)
return tokenizer
def __init__(self, vocab_file, merges_file, errors='replace', special_tokens=None, max_len=None):
def __init__(self, vocab_file, merges_file, errors='replace',
special_tokens=None, max_len=None):
self.max_len = max_len if max_len is not None else int(1e12)
self.encoder = json.load(open(vocab_file))
self.decoder = {v:k for k,v in self.encoder.items()}
self.decoder = {v: k for k, v in self.encoder.items()}
self.errors = errors # how to handle errors in decoding
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v:k for k, v in self.byte_encoder.items()}
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
bpe_merges = [tuple(merge.split()) for merge in bpe_data]
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
self.cache = {}
# Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
# Should have added re.IGNORECASE so BPE merges can happen for
# capitalized versions of contractions
self.pat = re.compile(
r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
self.special_tokens = {}
self.special_tokens_decoder = {}
......@@ -172,8 +185,9 @@ class GPT2Tokenizer(object):
self.special_tokens = {}
self.special_tokens_decoder = {}
return
self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()}
self.special_tokens = dict((tok, len(self.encoder) + i)
for i, tok in enumerate(special_tokens))
self.special_tokens_decoder = {v: k for k, v in self.special_tokens.items()}
logger.info("Special tokens {}".format(self.special_tokens))
def bpe(self, token):
......@@ -186,7 +200,7 @@ class GPT2Tokenizer(object):
return token
while True:
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
if bigram not in self.bpe_ranks:
break
first, second = bigram
......@@ -197,12 +211,12 @@ class GPT2Tokenizer(object):
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
except:
except BaseException:
new_word.extend(word[i:])
break
if word[i] == first and i < len(word)-1 and word[i+1] == second:
new_word.append(first+second)
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
new_word.append(first + second)
i += 2
else:
new_word.append(word[i])
......@@ -245,7 +259,8 @@ class GPT2Tokenizer(object):
logger.warning(
"Token indices sequence length is longer than the specified maximum "
" sequence length for this OpenAI GPT model ({} > {}). Running this"
" sequence through the model will result in indexing errors".format(len(ids), self.max_len)
" sequence through the model will result in indexing errors".format(
len(ids), self.max_len)
)
return ids
......
......@@ -123,7 +123,8 @@ class BertTokenizer(object):
logger.warning(
"Token indices sequence length is longer than the specified maximum "
" sequence length for this BERT model ({} > {}). Running this"
" sequence through BERT will result in indexing errors".format(len(ids), self.max_len)
" sequence through BERT will result in indexing errors".format(
len(ids), self.max_len)
)
return ids
......
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -22,12 +22,15 @@ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from .loss_scaler import DynamicLossScaler, LossScaler
from .fp16util import model_grads_to_master_grads, master_params_to_model_params, clip_grad_norm
from apex.multi_tensor_apply import multi_tensor_applier
import amp_C
from megatron.module import MegatronModule
FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)
def conversion_helper(val, conversion):
"""Apply conversion to val. Recursively apply conversion if `val` is a nested tuple/list structure."""
if not isinstance(val, (tuple, list)):
......@@ -37,6 +40,7 @@ def conversion_helper(val, conversion):
rtn = tuple(rtn)
return rtn
def fp32_to_fp16(val):
"""Convert fp32 `val` to fp16"""
def half_conversion(val):
......@@ -48,6 +52,7 @@ def fp32_to_fp16(val):
return val
return conversion_helper(val, half_conversion)
def fp16_to_fp32(val):
"""Convert fp16 `val` to fp32"""
def float_conversion(val):
......@@ -59,6 +64,7 @@ def fp16_to_fp32(val):
return val
return conversion_helper(val, float_conversion)
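These helpers preserve arbitrary tuple/list nesting, so whole structures of activations can be converted at once. A minimal sketch (CPU tensors shown purely for illustration):

x = torch.randn(2, 2)              # a FloatTensor
halves = fp32_to_fp16((x, [x]))    # same (tuple, [list]) nesting, float tensors cast to half
floats = fp16_to_fp32(halves)      # cast back to fp32, nesting preserved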
class FP16_Module(MegatronModule):
def __init__(self, module):
super(FP16_Module, self).__init__()
......@@ -79,6 +85,8 @@ class FP16_Module(MegatronModule):
self.module.load_state_dict(state_dict, strict=strict)
# TODO: Update overflow check + downscale to use Carl's fused kernel.
class FP16_Optimizer(object):
"""
:class:`FP16_Optimizer` is designed to wrap an existing PyTorch optimizer,
......@@ -305,7 +313,8 @@ class FP16_Optimizer(object):
master_params_to_model_params(fp32_from_fp16_group, fp16_group)
# To consider: Integrate distributed with this wrapper by registering a hook on each variable
# that does the overflow check, gradient copy + downscale, and fp32 allreduce in a different stream.
# that does the overflow check, gradient copy + downscale, and fp32
# allreduce in a different stream.
def _model_grads_to_master_grads(self):
for fp16_group, fp32_from_fp16_group in zip(self.fp16_groups, self.fp32_from_fp16_groups):
model_grads_to_master_grads(fp16_group, fp32_from_fp16_group)
......@@ -313,9 +322,12 @@ class FP16_Optimizer(object):
def _downscale_master(self):
if self.loss_scale != 1.0:
for group in self.optimizer.param_groups:
for param in group['params']:
if param.grad is not None:
param.grad.data.mul_(1./self.loss_scale)
grads = [p.grad for p in group['params'] if p.grad is not None]
_overflow_buf = torch.cuda.IntTensor([0])
multi_tensor_applier(amp_C.multi_tensor_scale,
_overflow_buf,
[grads, grads],
1./self.loss_scale)
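The fused call above uses apex's multi_tensor_applier: multi_tensor_scale writes src * scale into dst for every tensor in the two lists, so passing the same list twice scales the gradients in place, and the overflow buffer is set non-zero if an inf/NaN is encountered. A minimal sketch of the pattern, reusing the multi_tensor_applier and amp_C imports added above (assumes a CUDA device):

grads = [torch.randn(4, device='cuda'), torch.randn(3, device='cuda')]
overflow_buf = torch.cuda.IntTensor([0])
# Scale every tensor in the list in place by 0.5.
multi_tensor_applier(amp_C.multi_tensor_scale, overflow_buf, [grads, grads], 0.5)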
def clip_master_grads(self, max_norm, norm_type=2):
"""
......@@ -400,7 +412,8 @@ class FP16_Optimizer(object):
# the current optimizer instance. In our case, as long as the current FP16_Optimizer has been
# constructed in the same way as the one whose state_dict we are loading, the same master params
# are guaranteed to exist, so we can just copy_() from the saved master params.
for current_group, saved_group in zip(self.fp32_from_fp16_groups, state_dict['fp32_from_fp16']):
for current_group, saved_group in zip(
self.fp32_from_fp16_groups, state_dict['fp32_from_fp16']):
for current, saved in zip(current_group, saved_group):
current.data.copy_(saved.data)
......@@ -570,7 +583,8 @@ class FP16_Optimizer(object):
"""
if self.dynamic_loss_scale:
self._check_overflow()
if self.overflow: return
if self.overflow:
return
self._model_grads_to_master_grads()
self._downscale_master()
......@@ -607,8 +621,8 @@ class FP16_Optimizer(object):
master_grads_data.append(master_grads_this_group)
return master_grads_data
# Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale"
def _get_loss_scale(self):
return self.loss_scaler.loss_scale
......
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -18,6 +18,9 @@ import torch.nn as nn
from torch.autograd import Variable
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from apex.multi_tensor_apply import multi_tensor_applier
import amp_C
from megatron import mpu
......@@ -102,6 +105,7 @@ class FP16Model(nn.Module):
def backwards_debug_hook(grad):
raise RuntimeError("master_params received a gradient in the backward pass!")
def prep_param_lists(model, flat_master=False):
"""
Creates a list of FP32 master parameters for a given model, as in
......@@ -131,7 +135,7 @@ def prep_param_lists(model, flat_master=False):
# flatten_dense_tensors returns a contiguous flat array.
# http://pytorch.org/docs/master/_modules/torch/_utils.html
master_params = _flatten_dense_tensors([param.data for param in model_params]).float()
except:
except BaseException:
print("Error in prep_param_lists: model may contain a mixture of parameters "
"of different types. Use flat_master=False, or use FP16_Optimizer.")
raise
......@@ -165,9 +169,15 @@ def model_grads_to_master_grads(model_params, master_params, flat_master=False):
if model.grad is not None:
if master.grad is None:
master.grad = Variable(master.data.new(*master.data.size()))
master.grad.data.copy_(model.grad.data)
else:
master.grad = None
model_grads = [p.grad for p in model_params if p.grad is not None]
master_grads = [p.grad for p in master_params if p.grad is not None]
_overflow_buf = torch.cuda.IntTensor([0])
multi_tensor_applier(amp_C.multi_tensor_scale,
_overflow_buf,
[model_grads, master_grads],
1.0)
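With a scale of 1.0 the fused multi_tensor_scale call is in effect a batched copy of the fp16 model gradients into the fp32 master gradients, roughly equivalent to the per-parameter loop it replaces:

# Roughly equivalent unfused form:
# for model_grad, master_grad in zip(model_grads, master_grads):
#     master_grad.data.copy_(model_grad.data)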
def master_params_to_model_params(model_params, master_params, flat_master=False):
......@@ -188,17 +198,19 @@ def master_params_to_model_params(model_params, master_params, flat_master=False
# Backward compatibility fixes
def to_python_float(t):
if hasattr(t, 'item'):
return t.item()
else:
return t[0]
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
clip_grad_norm = mpu.clip_grad_norm
#elif TORCH_MAJOR == 0 and TORCH_MINOR <= 4:
# elif TORCH_MAJOR == 0 and TORCH_MINOR <= 4:
# clip_grad_norm = torch.nn.utils.clip_grad_norm
#else:
# else:
# clip_grad_norm = torch.nn.utils.clip_grad_norm_
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -14,15 +14,22 @@
# limitations under the License.
import torch
from apex.multi_tensor_apply import multi_tensor_applier
import amp_C
from megatron import mpu
# item() is a recent addition, so this helps with backward compatibility.
def to_python_float(t):
if hasattr(t, 'item'):
return t.item()
else:
return t[0]
class LossScaler:
"""
Class that manages a static loss scale. This class is intended to interact with
......@@ -54,12 +61,18 @@ class LossScaler:
return self.cur_scale
def scale_gradient(self, module, grad_in, grad_out):
return tuple(self.loss_scale * g for g in grad_in)
_overflow_buf = torch.cuda.IntTensor([0])
multi_tensor_applier(amp_C.multi_tensor_scale,
_overflow_buf,
[grad_in, grad_in],
self.loss_scale)
return grad_in
def backward(self, loss, retain_graph=False):
scaled_loss = loss*self.loss_scale
scaled_loss = loss * self.loss_scale
scaled_loss.backward(retain_graph=retain_graph)
class DynamicLossScaler:
"""
Class that manages dynamic loss scaling. It is recommended to use :class:`DynamicLossScaler`
......@@ -122,8 +135,8 @@ class DynamicLossScaler:
overflow = overflow_gpu[0].item()
return bool(overflow)
# `x` is a torch.Tensor
def _has_inf_or_nan(x):
try:
# if x is half, the .float() incurs an additional deep copy, but it's necessary if
......@@ -158,7 +171,7 @@ class DynamicLossScaler:
if overflow:
# self.cur_scale /= self.scale_factor
if self.delayed_shift == 1 or self.cur_hysteresis == 1:
self.cur_scale = max(self.cur_scale/self.scale_factor, self.min_scale)
self.cur_scale = max(self.cur_scale / self.scale_factor, self.min_scale)
else:
self.cur_hysteresis -= 1
self.last_overflow_iter = self.cur_iter
......@@ -176,12 +189,18 @@ class DynamicLossScaler:
return self.cur_scale
def scale_gradient(self, module, grad_in, grad_out):
return tuple(self.loss_scale * g for g in grad_in)
_overflow_buf = torch.cuda.IntTensor([0])
multi_tensor_applier(amp_C.multi_tensor_scale,
_overflow_buf,
[grad_in, grad_in],
self.loss_scale)
return grad_in
def backward(self, loss, retain_graph=False):
scaled_loss = loss*self.loss_scale
scaled_loss = loss * self.loss_scale
scaled_loss.backward(retain_graph=retain_graph)
##############################################################
# Example usage below here -- assuming it's in a separate file
##############################################################
......
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -61,22 +61,26 @@ def get_timers():
return _GLOBAL_TIMERS
def set_global_variables(extra_args_provider=None, args_defaults={}):
def set_global_variables(extra_args_provider=None, args_defaults={},
ignore_unknown_args=False):
"""Set args, tokenizer, tensorboard-writer, adlr-autoresume, and timers."""
args = _parse_args(extra_args_provider=extra_args_provider,
defaults=args_defaults)
defaults=args_defaults,
ignore_unknown_args=ignore_unknown_args)
_ = _build_tokenizer(args)
_set_tensorboard_writer(args)
_set_adlr_autoresume(args)
_set_timers()
def _parse_args(extra_args_provider=None, defaults={}):
def _parse_args(extra_args_provider=None, defaults={},
ignore_unknown_args=False):
"""Parse entire arguments."""
global _GLOBAL_ARGS
_ensure_var_is_not_initialized(_GLOBAL_ARGS, 'args')
_GLOBAL_ARGS = parse_args(extra_args_provider=extra_args_provider,
defaults=defaults)
defaults=defaults,
ignore_unknown_args=ignore_unknown_args)
return _GLOBAL_ARGS
......@@ -124,7 +128,7 @@ def _set_adlr_autoresume(args):
sys.path.append(os.environ.get('SUBMIT_SCRIPTS', '.'))
try:
from userlib.auto_resume import AutoResume
except:
except BaseException:
print('ADLR autoresume is not available, exiting ...')
sys.exit()
......
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -28,7 +28,8 @@ from megatron import mpu
from megatron.global_vars import set_global_variables
def initialize_megatron(extra_args_provider=None, args_defaults={}):
def initialize_megatron(extra_args_provider=None, args_defaults={},
ignore_unknown_args=False):
"""Set global variables, initialize distributed, and
set autoresume and random seeds."""
# Make sure cuda is available.
......@@ -37,7 +38,8 @@ def initialize_megatron(extra_args_provider=None, args_defaults={}):
# Parse args, build tokenizer, and set adlr-autoresume,
# tensorboard-writer, and timers.
set_global_variables(extra_args_provider=extra_args_provider,
args_defaults=args_defaults)
args_defaults=args_defaults,
ignore_unknown_args=ignore_unknown_args)
# Pytorch distributed.
_initialize_distributed()
......
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -48,7 +48,6 @@ class AnnealingLR(object):
print_rank_0('> learning rate decay style: {}'.format(self.decay_style))
def get_lr(self):
"""Learning rate decay functions from:
https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""
......@@ -71,7 +70,6 @@ class AnnealingLR(object):
lr = self.start_lr
return max(lr, self.min_lr)
def step(self, step_num=None):
"""Set lr for all parameters groups."""
if step_num is None:
......@@ -81,7 +79,6 @@ class AnnealingLR(object):
for group in self.optimizer.param_groups:
group['lr'] = new_lr
def state_dict(self):
state_dict = {
'start_lr': self.start_lr,
......@@ -93,7 +90,6 @@ class AnnealingLR(object):
}
return state_dict
def _check_and_set(self, cls_value, sd_value, name):
"""Auxiliary function for checking the values in the checkpoint and
setting them."""
......@@ -108,7 +104,6 @@ class AnnealingLR(object):
name))
return sd_value
def load_state_dict(self, sd):
self.start_lr = self._check_and_set(self.start_lr, sd['start_lr'],
......
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -22,16 +22,15 @@ import torch
import torch.nn.functional as F
from megatron import get_args
from megatron.model.language_model import parallel_lm_logits
from megatron.model.language_model import get_language_model
from megatron.model.transformer import LayerNorm
from megatron.model.utils import openai_gelu
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal
from megatron.module import MegatronModule
from .language_model import parallel_lm_logits
from .language_model import get_language_model
from .transformer import LayerNorm
from .utils import gelu
from .utils import get_linear_layer
from .utils import init_method_normal
from .utils import scaled_init_method_normal
def bert_attention_mask_func(attention_scores, attention_mask):
attention_scores = attention_scores + attention_mask
......@@ -70,7 +69,6 @@ def bert_position_ids(token_ids):
return position_ids
class BertLMHead(MegatronModule):
"""Masked LM head for Bert
......@@ -81,11 +79,14 @@ class BertLMHead(MegatronModule):
layernorm_epsilon: tolerance for layer norm divisions
parallel_output: whether the output logits are distributed or not.
"""
def __init__(self, mpu_vocab_size, hidden_size, init_method,
layernorm_epsilon, parallel_output):
super(BertLMHead, self).__init__()
args = get_args()
self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
self.bias.model_parallel = True
self.bias.partition_dim = 0
......@@ -94,11 +95,13 @@ class BertLMHead(MegatronModule):
self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
self.layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
self.gelu = torch.nn.functional.gelu
if args.openai_gelu:
self.gelu = openai_gelu
def forward(self, hidden_states, word_embeddings_weight):
hidden_states = self.dense(hidden_states)
hidden_states = gelu(hidden_states)
hidden_states = self.gelu(hidden_states)
hidden_states = self.layernorm(hidden_states)
output = parallel_lm_logits(hidden_states,
word_embeddings_weight,
......@@ -107,7 +110,6 @@ class BertLMHead(MegatronModule):
return output
class BertModel(MegatronModule):
"""Bert Language model."""
......@@ -184,7 +186,6 @@ class BertModel(MegatronModule):
return lm_logits, None
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
"""For easy load when model is combined with other heads,
......@@ -206,7 +207,6 @@ class BertModel(MegatronModule):
= self.ict_head.state_dict(destination, prefix, keep_vars)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Customized load."""
......@@ -224,8 +224,6 @@ class BertModel(MegatronModule):
class REALMBertModel(MegatronModule):
# TODO: load BertModel checkpoint
def __init__(self, retriever):
super(REALMBertModel, self).__init__()
bert_args = dict(
......
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -53,7 +53,6 @@ class Classification(MegatronModule):
init_method)
self._classification_head_key = 'classification_head'
def forward(self, input_ids, attention_mask, tokentype_ids):
extended_attention_mask = bert_extended_attention_mask(
......@@ -74,7 +73,6 @@ class Classification(MegatronModule):
return classification_logits
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
"""For easy load when model is combined with other heads,
......@@ -89,7 +87,6 @@ class Classification(MegatronModule):
destination, prefix, keep_vars)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Customized load."""
......
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -31,10 +31,6 @@ class DistributedDataParallel(MegatronModule):
self.module = module
self.data_parallel_group = mpu.get_data_parallel_group()
src_rank = mpu.get_model_parallel_rank()
for p in self.module.parameters():
if torch.is_tensor(p):
dist.broadcast(p, src_rank, group=self.data_parallel_group)
def allreduce_params(reduce_after=True, no_scale=False, fp32_allreduce=False):
if(self.needs_reduction):
......@@ -71,8 +67,8 @@ class DistributedDataParallel(MegatronModule):
def allreduce_hook(*unused):
Variable._execution_engine.queue_callback(allreduce_params)
# handle = param.register_hook(allreduce_hook)
#self.hooks.append(allreduce_hook)
#self.hook_handles.append(handle)
# self.hooks.append(allreduce_hook)
# self.hook_handles.append(handle)
self.allreduce_params = allreduce_params
def forward(self, *inputs, **kwargs):
......@@ -114,4 +110,3 @@ class DistributedDataParallel(MegatronModule):
super(DistributedDataParallel, self).train(mode)
self.module.train(mode)
'''
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -27,8 +27,7 @@ from .utils import scaled_init_method_normal
def gpt2_attention_mask_func(attention_scores, ltor_mask):
attention_scores = torch.mul(attention_scores, ltor_mask) - \
10000.0 * (1.0 - ltor_mask)
attention_scores.masked_fill_(ltor_mask, -10000.0)
return attention_scores
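Note the semantics of the replacement: masked_fill_ writes the fill value wherever the mask is True, so ltor_mask is assumed here to be a boolean mask that is True at the positions to block, whereas the removed expression multiplied by a 0/1 keep-mask. A tiny sketch of masked_fill_:

scores = torch.zeros(1, 4)
mask = torch.tensor([[False, False, True, True]])  # True marks positions to block
scores.masked_fill_(mask, -10000.0)
# scores is now tensor([[0., 0., -10000., -10000.]])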
......@@ -49,7 +48,6 @@ class GPT2Model(MegatronModule):
scaled_init_method=scaled_init_method_normal(args.init_method_std,
args.num_layers))
def forward(self, input_ids, position_ids, attention_mask,
tokentype_ids=None, layer_past=None, get_key_value=False,
forward_method_parallel_output=None):
......@@ -79,7 +77,6 @@ class GPT2Model(MegatronModule):
return output
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
......@@ -89,7 +86,6 @@ class GPT2Model(MegatronModule):
destination, prefix, keep_vars)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Customized load."""
......
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -21,9 +21,8 @@ import torch.nn.functional as F
from megatron import get_args
from megatron import mpu
from megatron.module import MegatronModule
from megatron.model.transformer import ParallelTransformer
from megatron.model.utils import gelu
from megatron.model.utils import openai_gelu
from megatron.model.utils import get_linear_layer
......@@ -47,6 +46,12 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
init_method, scaled_init_method, max_pos_embeds=None):
"""Build language model and return along with the key to save."""
args = get_args()
# Use torch gelu unless otherwise forced.
gelu = F.gelu
if args.openai_gelu:
gelu = openai_gelu
# Language model.
language_model = TransformerLanguageModel(
......@@ -63,7 +68,6 @@ def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
return language_model, language_model_key
class Pooler(MegatronModule):
"""Pooler layer.
......@@ -75,11 +79,11 @@ class Pooler(MegatronModule):
init_method: weight initialization method for the linear layer.
bias is set to zero.
"""
def __init__(self, hidden_size, init_method):
super(Pooler, self).__init__()
self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
def forward(self, hidden_states, sequence_index=0):
# hidden_states: [b, s, h]
# sequence_index: index of the token to pool.
......@@ -102,6 +106,7 @@ class Embedding(MegatronModule):
num_tokentypes: size of the token-type embeddings. 0 value
will ignore this embedding
"""
def __init__(self,
hidden_size,
vocab_size,
......@@ -143,7 +148,6 @@ class Embedding(MegatronModule):
# Embeddings dropout
self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
def add_tokentype_embeddings(self, num_tokentypes):
"""Add token-type embedding. This function is provided so we can add
token-type embeddings in case the pretrained model does not have it.
......@@ -160,7 +164,6 @@ class Embedding(MegatronModule):
# Initialize the token-type embeddings.
self.init_method(self.tokentype_embeddings.weight)
def forward(self, input_ids, position_ids, tokentype_ids=None):
# Embeddings.
words_embeddings = self.word_embeddings(input_ids)
......@@ -177,7 +180,6 @@ class Embedding(MegatronModule):
return embeddings
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
"""For easy load."""
......@@ -195,7 +197,6 @@ class Embedding(MegatronModule):
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Customized load."""
......@@ -242,7 +243,6 @@ class Embedding(MegatronModule):
'checkpoint but could not find it', flush=True)
class TransformerLanguageModel(MegatronModule):
"""Transformer language model.
......@@ -261,6 +261,7 @@ class TransformerLanguageModel(MegatronModule):
num_tokentypes: size of the token-type embeddings. 0 value
will ignore this embedding
"""
def __init__(self,
attention_mask_func,
mlp_activation_func,
......@@ -298,7 +299,6 @@ class TransformerLanguageModel(MegatronModule):
self.pooler = Pooler(self.hidden_size, self.init_method)
self._pooler_key = 'pooler'
def forward(self, input_ids, position_ids, attention_mask,
tokentype_ids=None, layer_past=None, get_key_value=False,
pooling_sequence_index=0):
......@@ -320,7 +320,6 @@ class TransformerLanguageModel(MegatronModule):
return transformer_output
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
"""For easy load."""
......@@ -339,7 +338,6 @@ class TransformerLanguageModel(MegatronModule):
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Customized load."""
......
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -51,7 +51,6 @@ class MultipleChoice(MegatronModule):
init_method)
self._multichoice_head_key = 'multichoice_head'
def forward(self, input_ids, attention_mask, tokentype_ids):
# [batch, choices, sequence] --> [batch * choices, sequence] -->
......@@ -86,7 +85,6 @@ class MultipleChoice(MegatronModule):
return multichoice_logits
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
"""For easy load when model is combined with other heads,
......@@ -101,7 +99,6 @@ class MultipleChoice(MegatronModule):
destination, prefix, keep_vars)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Customized load."""
......
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -46,6 +46,7 @@ from megatron.module import MegatronModule
unmasked-attention-scores, attention-mask)
"""
class ParallelMLP(MegatronModule):
"""MLP.
......@@ -63,7 +64,7 @@ class ParallelMLP(MegatronModule):
# Project to 4h.
self.dense_h_to_4h = mpu.ColumnParallelLinear(
args.hidden_size,
4*args.hidden_size,
4 * args.hidden_size,
gather_output=False,
init_method=init_method)
......@@ -71,14 +72,13 @@ class ParallelMLP(MegatronModule):
# Project back to h.
self.dense_4h_to_h = mpu.RowParallelLinear(
4*args.hidden_size,
4 * args.hidden_size,
args.hidden_size,
input_is_parallel=True,
init_method=output_layer_init_method)
self.dropout = torch.nn.Dropout(args.hidden_dropout)
def forward(self, hidden_states):
# [b, s, 4hp]
......@@ -91,13 +91,13 @@ class ParallelMLP(MegatronModule):
return output
class ParallelSelfAttention(MegatronModule):
"""Parallel self-attention layer abstract class.
Self-attention layer takes input with size [b, s, h]
and returns output of the same size.
"""
def __init__(self, attention_mask_func, init_method,
output_layer_init_method, layer_number):
super(ParallelSelfAttention, self).__init__()
......@@ -123,7 +123,7 @@ class ParallelSelfAttention(MegatronModule):
# Strided linear layer.
self.query_key_value = mpu.ColumnParallelLinear(
args.hidden_size,
3*args.hidden_size,
3 * args.hidden_size,
stride=3,
gather_output=False,
init_method=init_method)
......@@ -141,7 +141,6 @@ class ParallelSelfAttention(MegatronModule):
init_method=output_layer_init_method)
self.output_dropout = torch.nn.Dropout(args.hidden_dropout)
def _transpose_for_scores(self, tensor):
"""Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with
size [b, np, s, hn].
......@@ -152,7 +151,6 @@ class ParallelSelfAttention(MegatronModule):
tensor = tensor.view(*new_tensor_shape)
return tensor.permute(0, 2, 1, 3)
def _get_query_key_value(self, hidden_states):
"""Get query, key, and value and transpose to
get size [b, np, s, hn].
......@@ -170,7 +168,6 @@ class ParallelSelfAttention(MegatronModule):
return query_layer, key_layer, value_layer
def _get_unmasked_attention_scores(self, query_layer, key_layer):
"""Unmasked attention scores with size [b, np, s, s]."""
coeff = 1
......@@ -179,9 +176,8 @@ class ParallelSelfAttention(MegatronModule):
norm_factor = math.sqrt(coeff *
math.sqrt(self.hidden_size_per_attention_head))
# Raw attention scores. [b, np, s, s]
return torch.matmul(query_layer/norm_factor,
key_layer.transpose(-1, -2)/norm_factor)
return torch.matmul(query_layer / norm_factor,
key_layer.transpose(-1, -2) / norm_factor)
def _get_attention_probs(self, attention_scores):
"""Attention probabilities with dropout. The output has
......@@ -198,7 +194,6 @@ class ParallelSelfAttention(MegatronModule):
return attention_probs
def _get_attended_context(self, attention_probs, value_layer):
"""Final attended tensor, transposed back to [b, s, hp]."""
# Context layer.
......@@ -213,7 +208,6 @@ class ParallelSelfAttention(MegatronModule):
return context_layer
def _get_output(self, context_layer):
"""Output layer with dropout."""
# Output. [b, s, h]
......@@ -222,7 +216,6 @@ class ParallelSelfAttention(MegatronModule):
return output
def forward(self, hidden_states, attention_mask, layer_past=None,
get_key_value=False):
# hidden_states: [b, s, h]
......@@ -254,7 +247,7 @@ class ParallelSelfAttention(MegatronModule):
if layer_past is not None:
attention_mask = attention_mask[
...,
attention_scores.size(3)-1,
attention_scores.size(3) - 1,
:attention_scores.size(3)].unsqueeze(2)
else:
attention_mask = attention_mask[
......@@ -283,13 +276,13 @@ class ParallelSelfAttention(MegatronModule):
return output
class ParallelTransformerLayer(MegatronModule):
"""A single transformer layer.
Transformer layer takes input with size [b, s, h] and returns an
output of the same size.
"""
def __init__(self, attention_mask_func, mlp_activation_func,
init_method, output_layer_init_method, layer_number):
args = get_args()
......@@ -319,7 +312,6 @@ class ParallelTransformerLayer(MegatronModule):
self.mlp = ParallelMLP(mlp_activation_func, init_method,
output_layer_init_method)
def forward(self, hidden_states, attention_mask, layer_past=None,
get_key_value=False):
# hidden_states: [b, s, h]
......@@ -375,14 +367,13 @@ class ParallelTransformer(MegatronModule):
# Transformer layers.
self.layers = torch.nn.ModuleList(
[get_layer(i+1) for i in range(args.num_layers)])
[get_layer(i + 1) for i in range(args.num_layers)])
# Final layer norm before output.
self.final_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon)
def _checkpointed_forward(self, hidden_states, attention_mask):
"""Forward method with activation checkpointing."""
def custom(start, end):
......@@ -398,13 +389,12 @@ class ParallelTransformer(MegatronModule):
num_layers = len(self.layers)
while l < num_layers:
hidden_states = mpu.checkpoint(
custom(l, l+self.checkpoint_num_layers),
custom(l, l + self.checkpoint_num_layers),
hidden_states, attention_mask)
l += self.checkpoint_num_layers
return hidden_states
def forward(self, hidden_states, attention_mask, layer_past=None,
get_key_value=False):
......
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -33,6 +33,7 @@ def init_method_normal(sigma):
def scaled_init_method_normal(sigma, num_layers):
"""Init method based on N(0, sigma/sqrt(2*num_layers))."""
std = sigma / math.sqrt(2.0 * num_layers)
def init_(tensor):
return torch.nn.init.normal_(tensor, mean=0.0, std=std)
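For example, with hypothetical values sigma = 0.02 and num_layers = 24, the resulting standard deviation is 0.02 / sqrt(48) ≈ 0.00289:

import math
sigma, num_layers = 0.02, 24                  # hypothetical values for illustration
print(sigma / math.sqrt(2.0 * num_layers))    # ~0.0028868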
......@@ -53,8 +54,7 @@ def gelu_impl(x):
"""OpenAI's gelu implementation."""
return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x *
(1.0 + 0.044715 * x * x)))
def gelu(x):
def openai_gelu(x):
return gelu_impl(x)
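For reference, gelu_impl is the usual tanh approximation 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3))); the constant 0.7978845608028654 is sqrt(2/pi). A quick numerical check:

x = torch.tensor([0.0, 1.0])
print(openai_gelu(x))  # ~tensor([0.0000, 0.8412]), close to the exact erf-based GELU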
......
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -21,11 +21,9 @@ import torch
class MegatronModule(torch.nn.Module):
"""Megatron-specific extensions of torch Module."""
def __init__(self):
super(MegatronModule, self).__init__()
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
"""Use this function to override the state dict for
......