"python/vscode:/vscode.git/clone" did not exist on "d17d19e5b84ec459e8fcce238232781a731ca488"
Commit 5e56e563 authored by Neel Kant

Merge master into realm-mlm

parents 6c0a5bd8 569b3dab
# coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -28,7 +28,9 @@ from .wordpiece import BertTokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP
from .tokenization_gpt2 import GPT2Tokenizer
import regex as re

-def make_tokenizer(tokenizer_type, corpus, model_path=None, vocab_size=None, model_type='bpe', pad_token=0, character_coverage=1.0, command_tokens=None, type_tokens=None, **kwargs):
+def make_tokenizer(tokenizer_type, corpus, model_path=None, vocab_size=None, model_type='bpe',
+                   pad_token=0, character_coverage=1.0, command_tokens=None, type_tokens=None, **kwargs):
    """
    Helper function to instantiate a tokenizer given common combinations of options.
    """
@@ -39,10 +41,11 @@ def make_tokenizer(tokenizer_type, corpus, model_path=None, vocab_size=None, mod
        return BertWordPieceTokenizer(model_type, **kwargs)
    elif tokenizer_class is GPT2BPETokenizer:
        return GPT2BPETokenizer(**kwargs)
    text_tokenizer = tokenizer_class(corpus=corpus, vocab_size=vocab_size, model_path=model_path, model_type=model_type,
                                     pad_token=pad_token, character_coverage=character_coverage)
    return Tokenizer(text_tokenizer, command_tokens, type_tokens)
class Tokenization(object):
    """
    Tokenization object to hold tokenization, (processed text),and original
@@ -54,7 +57,9 @@ class Tokenization(object):
    Several standard array operations are implemented (insert, append, extend).
    """

-    def __init__(self, tokenization, text=None, original_text=None, command_tokens=None, asIds=True):
+    def __init__(self, tokenization, text=None, original_text=None,
+                 command_tokens=None, asIds=True):
        self.tokenization = tokenization
        self.text = text
        if self.text is None:
@@ -91,13 +96,15 @@ class Tokenization(object):
            if idx == 0:
                self.text = other.token + self.text
                self.original_text = other.token + self.original_text
-            elif idx == len(self.tokenization)-1:
+            elif idx == len(self.tokenization) - 1:
                self.text += other.token
                self.original_text += other.token
        elif isinstance(other, Tokenization):
-            self.tokenization = self.tokenization[:idx] + other.tokenization + self.tokenization[idx:]
+            self.tokenization = self.tokenization[:idx] + \
+                other.tokenization + self.tokenization[idx:]
        else:
-            self.tokenization = self.tokenization[:idx] + other.tokenization + self.tokenization[idx:]
+            self.tokenization = self.tokenization[:idx] + \
+                other.tokenization + self.tokenization[idx:]

    def append(self, other):
        if isinstance(other, (CommandToken, TypeToken)):
@@ -129,14 +136,17 @@ class Tokenization(object):
            self.tokenization.extend(other)
        return self
"""define some default command tokens for the tokenizer to use""" """define some default command tokens for the tokenizer to use"""
token_format = "<{0}>" token_format = "<{0}>"
COMMAND_TUPLE = namedtuple('CommandToken', ('name', 'token', 'Id')) COMMAND_TUPLE = namedtuple('CommandToken', ('name', 'token', 'Id'))
def prep_command_tokens(tokenlist, token_format=token_format): def prep_command_tokens(tokenlist, token_format=token_format):
return [CommandToken(tok[0], token_format.format(tok[0]), tok[1]) for tok in tokenlist] return [CommandToken(tok[0], token_format.format(tok[0]), tok[1]) for tok in tokenlist]
class CommandToken(object): class CommandToken(object):
def __init__(self, name, token, Id): def __init__(self, name, token, Id):
self.name = name self.name = name
...@@ -146,15 +156,16 @@ class CommandToken(object): ...@@ -146,15 +156,16 @@ class CommandToken(object):
def __str__(self): def __str__(self):
return str(COMMAND_TUPLE(self.name, self.token, self.Id)) return str(COMMAND_TUPLE(self.name, self.token, self.Id))
DEFAULT_COMMAND_TOKENS = [ DEFAULT_COMMAND_TOKENS = [
('pad', 0), ('pad', 0),
('eos', 1), ('eos', 1),
('bos', 2), ('bos', 2),
('unk', 3), ('unk', 3),
('sep', 4), ('sep', 4),
('L2R', 5), ('L2R', 5),
('ENC', 6), ('ENC', 6),
('MASK', 7), ('MASK', 7),
] ]
DEFAULT_COMMAND_TOKENS = prep_command_tokens(DEFAULT_COMMAND_TOKENS) DEFAULT_COMMAND_TOKENS = prep_command_tokens(DEFAULT_COMMAND_TOKENS)
...@@ -162,9 +173,11 @@ DEFAULT_COMMAND_TOKENS = prep_command_tokens(DEFAULT_COMMAND_TOKENS) ...@@ -162,9 +173,11 @@ DEFAULT_COMMAND_TOKENS = prep_command_tokens(DEFAULT_COMMAND_TOKENS)
TYPE_TUPLE = namedtuple('TypeToken', ('name', 'token', 'Id')) TYPE_TUPLE = namedtuple('TypeToken', ('name', 'token', 'Id'))
def prep_type_tokens(tokenlist, token_format=token_format): def prep_type_tokens(tokenlist, token_format=token_format):
return [TypeToken(tok[0], token_format.format(tok[0]), tok[1]) for tok in tokenlist] return [TypeToken(tok[0], token_format.format(tok[0]), tok[1]) for tok in tokenlist]
class TypeToken(object): class TypeToken(object):
def __init__(self, name, token, Id): def __init__(self, name, token, Id):
self.name = name self.name = name
...@@ -174,21 +187,23 @@ class TypeToken(object): ...@@ -174,21 +187,23 @@ class TypeToken(object):
def __str__(self): def __str__(self):
return str(TYPE_TUPLE(self.name, self.token, self.Id)) return str(TYPE_TUPLE(self.name, self.token, self.Id))
DEFAULT_TYPE_TOKENS = [ DEFAULT_TYPE_TOKENS = [
('function', 0), ('function', 0),
('command', 1), ('command', 1),
('str0', 2), ('str0', 2),
('str1', 3), ('str1', 3),
('str2', 4), ('str2', 4),
('embedding0', 5), ('embedding0', 5),
('embedding1', 6), ('embedding1', 6),
('embedding2', 7), ('embedding2', 7),
('arg0', 8), ('arg0', 8),
('arg1', 9), ('arg1', 9),
('arg2', 10), ('arg2', 10),
] ]
DEFAULT_TYPE_TOKENS = prep_type_tokens(DEFAULT_TYPE_TOKENS) DEFAULT_TYPE_TOKENS = prep_type_tokens(DEFAULT_TYPE_TOKENS)
class Tokenizer(object):
    """
    Tokenizer object that handles text tokenization, command tokens, and type tokens.
@@ -199,6 +214,7 @@ class Tokenizer(object):
    Token types are stored in a separate mapping of size `len(type_tokens)`.
    """

    def __init__(self, text_tokenizer, command_tokens=None, type_tokens=None):
        # set text tokenizer
        self.text_tokenizer = text_tokenizer
@@ -229,18 +245,20 @@ class Tokenizer(object):
        # parse tokens and vocabs from tokenizer
        self._tokens = list(self.command_token_map.keys()) + list(self.text_tokenizer.tokens)
-        self._vocab = {t:Id for Id,t in self.command_id_map.items()}
-        self._vocab.update({t:Id+self.num_command_tokens for t,Id in self.text_tokenizer.vocab.items()})
+        self._vocab = {t: Id for Id, t in self.command_id_map.items()}
+        self._vocab.update({t: Id + self.num_command_tokens for t,
+                            Id in self.text_tokenizer.vocab.items()})

        self._text_tokens = list(self.text_tokenizer.tokens)
-        self._text_token_vocab = {t:Id+self.num_command_tokens for t,Id in self.text_tokenizer.vocab.items()}
+        self._text_token_vocab = {
+            t: Id + self.num_command_tokens for t,
+            Id in self.text_tokenizer.vocab.items()}

        self._command_token_tokens = list(self.command_token_map.keys())
-        self._command_token_vocab = {t:Id for Id,t in self.command_id_map.items()}
+        self._command_token_vocab = {t: Id for Id, t in self.command_id_map.items()}

        self._token_types = list(self.type_token_map.keys())
-        self._token_type_vocab = {t:Id for Id, t in self.type_id_map.items()}
+        self._token_type_vocab = {t: Id for Id, t in self.type_id_map.items()}

    def __call__(self, text, process_fn=None):
        """run preprocessing and encode text as Ids"""
@@ -303,7 +321,7 @@ class Tokenizer(object):
        encode text using text tokenizer and shift Id values for command tokens
        """
        tokenization = self.text_tokenizer.EncodeAsIds(text, process_fn=process_fn)
-        tokenization.tokenization = [t+self.num_command_tokens for t in tokenization.tokenization]
+        tokenization.tokenization = [t + self.num_command_tokens for t in tokenization.tokenization]
        tokenization.set_command_tokens(self._command_tokens)
        return tokenization
@@ -323,7 +341,7 @@ class Tokenizer(object):
            return self.type_id_map[Id].token
        if Id < self.num_command_tokens:
            return self.command_id_map[Id].token
-        return self.text_tokenizer.IdToToken(Id-self.num_command_tokens)
+        return self.text_tokenizer.IdToToken(Id - self.num_command_tokens)

    def TokenToId(self, token, type_token=False):
        """convert token to Id accounting for command and type tokens"""
@@ -333,7 +351,7 @@ class Tokenizer(object):
            return self.type_token_map[token].Id
        if token in self.command_token_map:
            return self.command_token_map[token].Id
-        return self.text_tokenizer.TokenToId(token)+self.num_command_tokens
+        return self.text_tokenizer.TokenToId(token) + self.num_command_tokens

    def DecodeIds(self, Ids, type_token=False):
        """
@@ -341,7 +359,8 @@ class Tokenizer(object):
        are joined and returned as a string.
        """
        if type_token:
-            return ' '.join(Id.token if isinstance(Id, TypeToken) else self.type_id_map[Id].token for Id in Ids)
+            return ' '.join(Id.token if isinstance(Id, TypeToken)
+                            else self.type_id_map[Id].token for Id in Ids)
        rtn_strs = []
        current_str = []
        if isinstance(Ids, Tokenization):
@@ -386,10 +405,12 @@ class Tokenizer(object):
            rtn_strs.append(self.text_tokenizer.DecodeTokens(current_str))
        return ' '.join(rtn_strs)
class TextTokenizer(object):
    """
    Interface for text tokenizer
    """

    def __init__(self):
        if not hasattr(self, 'num_text_tokens'):
            self.num_text_tokens = 0
@@ -450,17 +471,18 @@ class TextTokenizer(object):
    def DecodeTokens(self, Tokens):
        """Convert a list or tokenization object of tokens to a text string"""
        raise NotImplementedError('TextTokenizer DecodeTokens not implemented')

class CharacterLevelTokenizer(TextTokenizer):
    """
    Text tokenizer for ASCII-256 Character Level Tokenization.
    """

    def __init__(self, **kwargs):
        self.num_text_tokens = 256
        super(CharacterLevelTokenizer, self).__init__()
        self._tokens = [self.IdToToken(Id) for Id in range(self.num_text_tokens)]
-        self._vocab = {t: i for i,t in enumerate(self._tokens)}
+        self._vocab = {t: i for i, t in enumerate(self._tokens)}

    def __len__(self):
        return 256
@@ -521,6 +543,7 @@ class CharacterLevelTokenizer(TextTokenizer):
MAX_SENTENCEPIECE_SENTENCES = 100000000

def get_corpus_freq(dataset, filepath, filetype='tsv'):
    """
    Take corpus, split it into sentences, and extract word frequencies.
@@ -556,14 +579,13 @@ def get_corpus_freq(dataset, filepath, filetype='tsv'):
    print("file path for freq " + str(filepath), flush=True)

    freqs_sorted = {}
-    counter=0
+    counter = 0
    for word, count in sorted(freqs.items(), key=lambda x: x[1], reverse=True):
        if counter >= MAX_SENTENCEPIECE_SENTENCES:
            break
-        counter+=1
+        counter += 1
        freqs_sorted[word] = count

    print("length of freqs after trancating " + str(len(freqs_sorted)), flush=True)

    with open(filepath, 'w') as f:
@@ -573,9 +595,12 @@ def get_corpus_freq(dataset, filepath, filetype='tsv'):
    return total_sentence_count, maxlen
class SentencePieceTokenizer(TextTokenizer):
    """Trains and uses sentencepiece for text tokenization"""

-    def __init__(self, model_type='bpe', vocab_size=None, corpus=None, model_path=None, character_coverage=1.0, **kwargs):
+    def __init__(self, model_type='bpe', vocab_size=None, corpus=None,
+                 model_path=None, character_coverage=1.0, **kwargs):
        self.character_coverage = character_coverage
        self.model_type = model_type.lower()
        self.spm_model = model_path
@@ -608,18 +633,18 @@ class SentencePieceTokenizer(TextTokenizer):
        dne = not os.path.exists(model_path)
        # check if path.model exists
        if dne and not model_path.endswith('.model'):
-            dne = not os.path.exists(model_path+'.model')
+            dne = not os.path.exists(model_path + '.model')
        return not dne

    def load_spm_model(self):
        """load sentencepiece model and parse vocab"""
        if not os.path.exists(self.spm_model) and not self.spm_model.endswith('.model'):
-            self.spm_model = self.spm_model+'.model'
+            self.spm_model = self.spm_model + '.model'
        self.sp = spm.SentencePieceProcessor()
        self.sp.Load(self.spm_model)
        self.vocab_size = self.num_text_tokens = len(self.sp)
        self._tokens = [self.IdToToken(t) for t in range(self.vocab_size)]
-        self._vocab = {t: i for i,t in enumerate(self._tokens)}
+        self._vocab = {t: i for i, t in enumerate(self._tokens)}

    def Train(self, corpus, num_text_tokens):
        """train sentencepiece model on corpus using word frequencies"""
@@ -630,7 +655,7 @@ class SentencePieceTokenizer(TextTokenizer):
            use_model_path = random_hash
        if use_model_path.endswith('.model'):
            use_model_path = use_model_path[:use_model_path.rfind('.model')]
-        input_path = use_model_path+'.tsv.'+random_hash
+        input_path = use_model_path + '.tsv.' + random_hash
        line_count, maxlenline = get_corpus_freq(corpus, input_path)
        line_count = min(line_count, MAX_SENTENCEPIECE_SENTENCES)
        print('line count used as input_sentence_size ', line_count, flush=True)
@@ -640,13 +665,13 @@ class SentencePieceTokenizer(TextTokenizer):
            + '--input_sentence_size={input_sentence_size} ' \
            + '--input_format=tsv'
        train_string = train_string.format(file_path=input_path, model_prefix=use_model_path, vocab_size=num_text_tokens,
                                           model_type=self.model_type, character_coverage=self.character_coverage,
-                                           input_sentence_size=int(line_count)) #, #)#,
-        print("calling spm.SentencePieceTrainer.Train(%s)"%(train_string), flush=True)
+                                           input_sentence_size=int(line_count)) # , #)#,
+        print("calling spm.SentencePieceTrainer.Train(%s)" % (train_string), flush=True)
        spm.SentencePieceTrainer.Train(train_string)
        os.remove(input_path)
-        self.spm_model = use_model_path+'.model'
-        print('sentencepiece model written to '+self.spm_model, flush=True)
+        self.spm_model = use_model_path + '.model'
+        print('sentencepiece model written to ' + self.spm_model, flush=True)

    def EncodeAsIds(self, text, process_fn=None):
        """convert text to sentencepiece Ids"""
@@ -684,19 +709,26 @@ class SentencePieceTokenizer(TextTokenizer):
            Tokens = Tokens.tokenization
        return self.sp.DecodeTokens(Tokens)
class BertWordPieceTokenizer(Tokenizer):
    """
    Loads a pretrained WordPiece tokenizer from `cache_dir` for tokenization
    in BERT training. Default to bert-large-uncased tokenizer.
    """

    def __init__(self, tokenizer_model_type=None, cache_dir=None, **kwargs):
        # default to bert-large-uncased tokenizer
        if tokenizer_model_type not in PRETRAINED_VOCAB_ARCHIVE_MAP:
            tokenizer_model_type = 'bert-large-uncased'
        if torch.distributed.get_rank() == 0:
-            print('loading BertWordPieceTokenizer (', tokenizer_model_type, ') from cache_dir ', cache_dir)
+            print(
+                'loading BertWordPieceTokenizer (',
+                tokenizer_model_type,
+                ') from cache_dir ',
+                cache_dir)
        do_lower_case = not ('-cased' in tokenizer_model_type or 'chinese' in tokenizer_model_type)
-        self.text_tokenizer = BertTokenizer.from_pretrained(tokenizer_model_type, do_lower_case=do_lower_case, cache_dir=cache_dir)
+        self.text_tokenizer = BertTokenizer.from_pretrained(
+            tokenizer_model_type, do_lower_case=do_lower_case, cache_dir=cache_dir)
        if torch.distributed.get_rank() == 0:
            print('loaded', tokenizer_model_type)
        # disable max len warnings by increasing max len
@@ -705,7 +737,7 @@ class BertWordPieceTokenizer(Tokenizer):
        # set command tokens from wordpiece tokenizer values
        self.num_command_tokens = 5
        self.num_tokens = len(self.text_tokenizer.vocab)
-        self.num_text_tokens = self.num_tokens-5
+        self.num_text_tokens = self.num_tokens - 5
        self.num_type_tokens = 2

        self._command_tokens = [
@@ -731,16 +763,16 @@ class BertWordPieceTokenizer(Tokenizer):
        # parse tokens and vocabs from tokenizer
        self._tokens = list(self.text_tokenizer.vocab.keys())
-        self._vocab = {k:v for k,v in self.text_tokenizer.vocab.items()}
+        self._vocab = {k: v for k, v in self.text_tokenizer.vocab.items()}

        self._text_tokens = list(self._tokens)
-        self._text_token_vocab = {k:v for k,v in self.text_tokenizer.vocab.items()}
+        self._text_token_vocab = {k: v for k, v in self.text_tokenizer.vocab.items()}

        self._command_token_tokens = list(self.command_token_map.keys())
-        self._command_token_vocab = {t:Id for Id,t in self.command_id_map.items()}
+        self._command_token_vocab = {t: Id for Id, t in self.command_id_map.items()}

        self._token_types = list(self.type_token_map.keys())
-        self._token_type_vocab = {t:Id for Id, t in self.type_id_map.items()}
+        self._token_type_vocab = {t: Id for Id, t in self.type_id_map.items()}

    def EncodeAsIds(self, text, process_fn=None):
        """convert text to wordpiece Ids"""
@@ -778,7 +810,8 @@ class BertWordPieceTokenizer(Tokenizer):
    def DecodeIds(self, Ids, type_token=False):
        """converts ids to wordpiece tokens and joins them as a text string"""
        if type_token:
-            return ' '.join(Id.token if isinstance(Id, TypeToken) else self.type_id_map[Id].token for Id in Ids)
+            return ' '.join(Id.token if isinstance(Id, TypeToken)
+                            else self.type_id_map[Id].token for Id in Ids)
        if isinstance(Ids, Tokenization):
            Ids = Ids.tokenization
        Tokens = []
@@ -795,16 +828,17 @@ class BertWordPieceTokenizer(Tokenizer):
            Tokens = Tokens.tokenization
        return ' '.join(Tokens)
class GPT2BPETokenizer(Tokenizer):
    def __init__(self, cache_dir=None, **kwargs):
        self.text_tokenizer = GPT2Tokenizer.from_pretrained('gpt2',
                                                            cache_dir=cache_dir)

-        #disable max len warnings by increasing max len
+        # disable max len warnings by increasing max len
        self.text_tokenizer.max_len = int(1e12)
        self.num_command_tokens = 2
        self.num_tokens = len(self.text_tokenizer.encoder)
-        self.num_text_tokens = self.num_tokens-1
+        self.num_text_tokens = self.num_tokens - 1
        self.num_type_tokens = 2

        self._command_tokens = [
@@ -824,28 +858,27 @@ class GPT2BPETokenizer(Tokenizer):
        self.type_id_map = {tok.Id: tok for tok in self.type_tokens}

        self._tokens = list(self.text_tokenizer.encoder.keys())
-        self._vocab = {k:v for k,v in self.text_tokenizer.encoder.items()}
+        self._vocab = {k: v for k, v in self.text_tokenizer.encoder.items()}

        self._text_tokens = list(self._tokens)
-        self._text_token_vocab = {k:v for k,v in self.text_tokenizer.encoder.items()}
+        self._text_token_vocab = {k: v for k, v in self.text_tokenizer.encoder.items()}

        self._command_token_tokens = list(self.command_token_map.keys())
-        self._command_token_vocab = {t:Id for Id,t in self.command_id_map.items()}
+        self._command_token_vocab = {t: Id for Id, t in self.command_id_map.items()}

        self._token_types = list(self.type_token_map.keys())
-        self._token_type_vocab = {t:Id for Id, t in self.type_id_map.items()}
+        self._token_type_vocab = {t: Id for Id, t in self.type_id_map.items()}

    def EncodeAsIds(self, text, process_fn=None):
        processed_text = text
        if process_fn is not None:
            processed_text = process_fn(processed_text)
        Ids = self.text_tokenizer.encode(processed_text)
-        #return Tokenization(Ids, processed_text, text)
+        # return Tokenization(Ids, processed_text, text)
        tokenization = Tokenization(Ids, processed_text, text)
        tokenization.set_command_tokens(self._command_tokens)
        return tokenization

    def EncodeAsTokens(self, text, process_fn=None):
        processed_text = text
        if process_fn is not None:
@@ -854,10 +887,10 @@ class GPT2BPETokenizer(Tokenizer):
        for token in re.findall(self.text_tokenizer.pat, processed_text):
            token = ''.join(self.text_tokenizer.bye_encoder[b] for b in token.encode('utf-8'))
            tokens.extend(bpe_token for bpe_token in self.text_tokenizer.bpe(token).split(' '))
-        tokenization=Tokenization(tokens, processed_text, text, asIds=False)
+        tokenization = Tokenization(tokens, processed_text, text, asIds=False)
        tokenization.set_command_tokens(self._command_tokens)
        return tokenization
-        #return Tokenization(tokens, processed_text, text, asIds=False)
+        # return Tokenization(tokens, processed_text, text, asIds=False)

    def IdToToken(self, Id, type_token=False):
        if isinstance(Id, (TypeToken, CommandToken)):
@@ -875,7 +908,8 @@ class GPT2BPETokenizer(Tokenizer):
    def DecodeIds(self, Ids, type_token=False):
        if type_token:
-            return ' '.join(Id.token if isinstance(Id, TypeToken) else self.type_id_map[Id].token for Id in Ids)
+            return ' '.join(Id.token if isinstance(Id, TypeToken)
+                            else self.type_id_map[Id].token for Id in Ids)
        if isinstance(Ids, Tokenization):
            Ids = Ids.tokenization
        return self.text_tokenizer.decode(Ids)
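For orientation, here is a minimal usage sketch of the Tokenizer wrapper defined above. It is illustrative only, not part of the diff; it assumes the character-level backend and the default command/type token lists shown earlier, and that __call__ returns a Tokenization whose Ids can be fed back to DecodeIds.

# Illustrative sketch only (not part of this commit).
text_tokenizer = CharacterLevelTokenizer()
tokenizer = Tokenizer(text_tokenizer,
                      command_tokens=DEFAULT_COMMAND_TOKENS,
                      type_tokens=DEFAULT_TYPE_TOKENS)
encoded = tokenizer("hello world")     # __call__ preprocesses and encodes to Ids
# text-token Ids are shifted up by num_command_tokens, so any Id below that
# threshold refers to a command token such as <pad> or <eos>
print(tokenizer.DecodeIds(encoded))    # decodes back to a string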
...
@@ -27,7 +27,8 @@ try:
    from functools import lru_cache
except ImportError:
    # Just a dummy decorator to get the checks to run on python2
-    # because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now.
+    # because honestly I don't want to support a byte-level unicode BPE
+    # tokenizer on python 2 right now.
    def lru_cache():
        return lambda func: func
@@ -48,6 +49,7 @@ VOCAB_NAME = 'vocab.json'
MERGES_NAME = 'merges.txt'
SPECIAL_TOKENS_NAME = 'special_tokens.txt'

@lru_cache()
def bytes_to_unicode():
    """
@@ -60,17 +62,19 @@ def bytes_to_unicode():
    And avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    _chr = unichr if sys.version_info[0] == 2 else chr
-    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
+    bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + \
+        list(range(ord("®"), ord("ÿ") + 1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
-            cs.append(2**8+n)
+            cs.append(2**8 + n)
            n += 1
    cs = [_chr(n) for n in cs]
    return dict(zip(bs, cs))

def get_pairs(word):
    """Return set of symbol pairs in a word.
@@ -83,6 +87,7 @@ def get_pairs(word):
        prev_char = char
    return pairs
class GPT2Tokenizer(object):
    """
    GPT-2 BPE tokenizer. Peculiarities:
@@ -138,23 +143,31 @@ class GPT2Tokenizer(object):
            special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
        else:
            special_tokens = kwargs.pop('special_tokens', [])
-        tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs)
+        tokenizer = cls(
+            resolved_vocab_file,
+            resolved_merges_file,
+            special_tokens=special_tokens,
+            *inputs,
+            **kwargs)
        return tokenizer

-    def __init__(self, vocab_file, merges_file, errors='replace', special_tokens=None, max_len=None):
+    def __init__(self, vocab_file, merges_file, errors='replace',
+                 special_tokens=None, max_len=None):
        self.max_len = max_len if max_len is not None else int(1e12)
        self.encoder = json.load(open(vocab_file))
-        self.decoder = {v:k for k,v in self.encoder.items()}
+        self.decoder = {v: k for k, v in self.encoder.items()}
        self.errors = errors # how to handle errors in decoding
        self.byte_encoder = bytes_to_unicode()
-        self.byte_decoder = {v:k for k, v in self.byte_encoder.items()}
+        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
        bpe_merges = [tuple(merge.split()) for merge in bpe_data]
        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
        self.cache = {}
-        # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
-        self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
+        # Should haved added re.IGNORECASE so BPE merges can happen for
+        # capitalized versions of contractions
+        self.pat = re.compile(
+            r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

        self.special_tokens = {}
        self.special_tokens_decoder = {}
@@ -172,8 +185,9 @@ class GPT2Tokenizer(object):
            self.special_tokens = {}
            self.special_tokens_decoder = {}
            return
-        self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
-        self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()}
+        self.special_tokens = dict((tok, len(self.encoder) + i)
+                                   for i, tok in enumerate(special_tokens))
+        self.special_tokens_decoder = {v: k for k, v in self.special_tokens.items()}
        logger.info("Special tokens {}".format(self.special_tokens))
    def bpe(self, token):
@@ -186,7 +200,7 @@ class GPT2Tokenizer(object):
            return token

        while True:
-            bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
+            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
@@ -197,12 +211,12 @@ class GPT2Tokenizer(object):
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
-                except:
+                except BaseException:
                    new_word.extend(word[i:])
                    break

-                if word[i] == first and i < len(word)-1 and word[i+1] == second:
-                    new_word.append(first+second)
+                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
+                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
@@ -245,7 +259,8 @@ class GPT2Tokenizer(object):
            logger.warning(
                "Token indices sequence length is longer than the specified maximum "
                " sequence length for this OpenAI GPT model ({} > {}). Running this"
-                " sequence through the model will result in indexing errors".format(len(ids), self.max_len)
+                " sequence through the model will result in indexing errors".format(
+                    len(ids), self.max_len)
            )
        return ids
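A quick illustration of what the bytes_to_unicode helper earlier in this file produces; the values follow directly from the byte ranges in that function, and the snippet is for exposition only, not part of the diff.

# Every byte 0-255 gets a printable unicode stand-in, so the BPE code can treat
# raw bytes as ordinary string characters and map them back afterwards.
byte_encoder = bytes_to_unicode()
byte_decoder = {v: k for k, v in byte_encoder.items()}

assert byte_encoder[ord('A')] == 'A'             # printable ASCII maps to itself
assert byte_encoder[ord(' ')] == chr(256 + 32)   # space becomes 'Ġ'
token = ''.join(byte_encoder[b] for b in 'Hello world'.encode('utf-8'))
assert token == 'HelloĠworld'
assert bytes(byte_decoder[c] for c in token).decode('utf-8') == 'Hello world'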
...
@@ -99,19 +99,19 @@ class BertTokenizer(object):
            [(ids, tok) for tok, ids in self.vocab.items()])
        self.do_basic_tokenize = do_basic_tokenize
        if do_basic_tokenize:
            self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case,
                                                  never_split=never_split)
        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
        self.max_len = max_len if max_len is not None else int(1e12)

    def tokenize(self, text):
        if self.do_basic_tokenize:
            split_tokens = []
            for token in self.basic_tokenizer.tokenize(text):
                for sub_token in self.wordpiece_tokenizer.tokenize(token):
                    split_tokens.append(sub_token)
        else:
            split_tokens = self.wordpiece_tokenizer.tokenize(text)
        return split_tokens

    def convert_tokens_to_ids(self, tokens):
@@ -123,7 +123,8 @@ class BertTokenizer(object):
            logger.warning(
                "Token indices sequence length is longer than the specified maximum "
                " sequence length for this BERT model ({} > {}). Running this"
-                " sequence through BERT will result in indexing errors".format(len(ids), self.max_len)
+                " sequence through BERT will result in indexing errors".format(
+                    len(ids), self.max_len)
            )
        return ids
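For context, tokenize above composes the basic tokenizer with WordPiece: the basic pass lower-cases and splits on whitespace and punctuation, then each token is broken into sub-word pieces. A hypothetical call follows; the exact pieces depend on the vocabulary file that gets loaded, so the commented output is illustrative only.

# Hypothetical sketch (not part of this commit).
bert_tok = BertTokenizer.from_pretrained('bert-large-uncased', do_lower_case=True)
pieces = bert_tok.tokenize("unaffable")
# e.g. ['un', '##aff', '##able']; continuation pieces carry the '##' prefix
ids = bert_tok.convert_tokens_to_ids(pieces)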
...
# coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
# coding=utf-8
-# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -22,21 +22,25 @@ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from .loss_scaler import DynamicLossScaler, LossScaler
from .fp16util import model_grads_to_master_grads, master_params_to_model_params, clip_grad_norm

+from apex.multi_tensor_apply import multi_tensor_applier
+import amp_C

from megatron.module import MegatronModule

FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)

def conversion_helper(val, conversion):
    """Apply conversion to val. Recursively apply conversion if `val` is a nested tuple/list structure."""
    if not isinstance(val, (tuple, list)):
        return conversion(val)
    rtn = [conversion_helper(v, conversion) for v in val]
    if isinstance(val, tuple):
        rtn = tuple(rtn)
    return rtn

def fp32_to_fp16(val):
    """Convert fp32 `val` to fp16"""
    def half_conversion(val):
@@ -48,6 +52,7 @@ def fp32_to_fp16(val):
        return val
    return conversion_helper(val, half_conversion)

def fp16_to_fp32(val):
    """Convert fp16 `val` to fp32"""
    def float_conversion(val):
@@ -59,6 +64,7 @@ def fp16_to_fp32(val):
        return val
    return conversion_helper(val, float_conversion)

class FP16_Module(MegatronModule):
    def __init__(self, module):
        super(FP16_Module, self).__init__()
@@ -79,9 +85,11 @@ class FP16_Module(MegatronModule):
        self.module.load_state_dict(state_dict, strict=strict)

# TODO: Update overflow check + downscale to use Carl's fused kernel.
class FP16_Optimizer(object):
    """
    :class:`FP16_Optimizer` is designed to wrap an existing PyTorch optimizer,
    and manage static or dynamic loss scaling and master weights in a manner transparent to the user.
    For standard use, only two lines must be changed: creating the :class:`FP16_Optimizer` instance,
    and changing the call to ``backward``.
@@ -104,45 +112,45 @@ class FP16_Optimizer(object):
        optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
                                   # optional arg to control dynamic loss scaling behavior
                                   # dynamic_loss_args={'scale_window' : 500})
                                   # Usually, dynamic_loss_args is not necessary.

    Args:
        init_optimizer (torch.optim.optimizer): Existing optimizer created with the parameters to optimize. Internally, :class:`FP16_Optimizer` replaces the passed optimizer's fp16 parameters, if any, with fp32 master parameters copied from the original ones. :class:`FP16_Optimizer` also stores references to the original fp16 parameters, and updates these fp16 parameters from the master fp32 copy at the end of each :attr:`step`.
        static_loss_scale (float, optional, default=1.0): Loss scale used internally to scale gradients computed by the model. Any fp16 gradients will be copied to fp32, then downscaled before being applied to the fp32 master params, so ``static_loss_scale`` should not affect learning rate.
        dynamic_loss_scale (bool, optional, default=False): Use dynamic loss scaling. If True, this will override any ``static_loss_scale`` option.
        dynamic_loss_args (dict, optional, default=None): Dict of kwargs that will be forwarded to the internal :class:`DynamicLossScaler` instance's constructor. Keys of this dict must match kwargs accepted by :class:`DynamicLossScaler`'s constructor. If ``dynamic_loss_args`` is unspecified, :class:`DynamicLossScaler`'s defaults will be used.
        verbose (bool, optional, default=True): By default, FP16_Optimizer's constructor prints out the parameters and parameter groups it is ingesting, as a sanity check. If this becomes annoying (e.g. for large models), it can be disabled by passing ``verbose=False``. ``verbose=False`` will not disable printing when the loss scale is readjusted during dynamic loss scaling.

    ``init_optimizer`` is expected to have been constructed in the ordinary way.
    It is recommended (although not required) that the newly constructed :class:`FP16_Optimizer` instance be
    named to replace ``init_optimizer``, for two reasons:
    First, it means that references to the same name
    later in the file will not have to change.
    Second, :class:`FP16_Optimizer` reserves the right (as an implementation detail) to
    modify ``init_optimizer``. If you do choose a unique name for the new
    :class:`FP16_Optimizer` instance, you should only work with this new instance,
    because the preexisting optimizer might no longer behave as expected.

    ``init_optimizer`` may be any Pytorch optimizer.
    It may contain a mixture of fp16 and fp32 parameters organized into any number of
    ``param_groups`` with different hyperparameters. The :class:`FP16_Optimizer` constructor will
    ingest these ``param_groups`` and remember them.

    Calls to ::

        loss.backward()

    must be replaced with ::

        optimizer.backward(loss)

    because :class:`FP16_Optimizer` requires ownership of the backward pass to implement
    loss scaling and copies to master gradients.

    .. note::
        Loss scaling, either static or dynamic, is orthogonal to learning rate, because gradients
        are downscaled before being applied. This means that adjusting the loss scale, or using
        dynamic loss scaling, should not require retuning the learning rate or any other
        hyperparameters.
@@ -152,7 +160,7 @@ class FP16_Optimizer(object):
    See docstring for :attr:`step`.

    **Gradient clipping**: Use :attr:`clip_master_grads`.

    **Multiple losses**: If your model accumulates gradients from multiple losses,
    this can be made more efficient by supplying ``update_master_grads=False``
    to :attr:`backward`. See docstring for :attr:`backward`.
@@ -163,19 +171,19 @@ class FP16_Optimizer(object):
        optimizer.loss_scale = new_loss_scale

    For static loss scaling, manually adjusting the loss scale over time is a reasonable
    thing to do. During later epochs, gradients may become smaller, and a
    higher loss scale may be required, analogous to scheduling the learning rate. Dynamic loss
    scaling is more subtle (see :class:`DynamicLossScaler`) and in this case, manually adjusting
    the loss scale is not recommended.

    **Multi_GPU training**: If the wrapped ``init_optimizer`` was created from a model wrapped in
    Pytorch DistributedDataParallel or Apex DistributedDataParallel, :class:`FP16_Optimizer`
    should still work as intended.
    """
    def __init__(self,
                 init_optimizer,
                 static_loss_scale=1.0,
                 dynamic_loss_scale=False,
                 dynamic_loss_args=None,
                 verbose=False):
@@ -212,7 +220,7 @@ class FP16_Optimizer(object):
                        # Reset existing state dict key to the new master param.
                        # We still need to recast per-param state tensors, if any, to FP32.
                        if param in self.optimizer.state:
                            self.optimizer.state[master_param] = self.optimizer.state.pop(param)
                    elif param.type() == 'torch.cuda.FloatTensor':
                        self.maybe_print("FP16_Optimizer received torch.cuda.FloatTensor with {}"
                                         .format(param.size()))
@@ -220,9 +228,9 @@ class FP16_Optimizer(object):
                        param_group['params'][i] = param
                    else:
                        raise TypeError("Wrapped parameters must be either "
                                        "torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
                                        "Received {}".format(param.type()))

            self.fp16_groups.append(fp16_params_this_group)
            self.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group)
            self.fp32_from_fp32_groups.append(fp32_params_this_group)
@@ -250,7 +258,7 @@ class FP16_Optimizer(object):
    def maybe_print(self, msg):
        if self.verbose:
            print(msg)

    def __getstate__(self):
        raise RuntimeError("FP16_Optimizer should be serialized using state_dict().")
@@ -265,13 +273,13 @@ class FP16_Optimizer(object):
        # because gradients are copied into the FP32 master params. However, we zero
        # all gradients owned by the optimizer, just to be safe:
        for group in self.optimizer.param_groups:
            for p in group['params']:
                if set_grads_to_None:
                    p.grad = None
                else:
                    if p.grad is not None:
                        p.grad.detach_()
                        p.grad.zero_()

        # Zero fp16 gradients owned by the model:
        for fp16_group in self.fp16_groups:
@@ -280,11 +288,11 @@ class FP16_Optimizer(object):
                    param.grad = None
                else:
                    if param.grad is not None:
                        param.grad.detach_() # as in torch.optim.optimizer.zero_grad()
                        param.grad.zero_()

    def _check_overflow(self):
        params = []
        for group in self.fp16_groups:
            for param in group:
                params.append(param)
@@ -304,8 +312,9 @@ class FP16_Optimizer(object):
        for fp16_group, fp32_from_fp16_group in zip(self.fp16_groups, self.fp32_from_fp16_groups):
            master_params_to_model_params(fp32_from_fp16_group, fp16_group)

    # To consider: Integrate distributed with this wrapper by registering a hook on each variable
-    # that does the overflow check, gradient copy + downscale, and fp32 allreduce in a different stream.
+    # that does the overflow check, gradient copy + downscale, and fp32
+    # allreduce in a different stream.
    def _model_grads_to_master_grads(self):
        for fp16_group, fp32_from_fp16_group in zip(self.fp16_groups, self.fp32_from_fp16_groups):
            model_grads_to_master_grads(fp16_group, fp32_from_fp16_group)
...@@ -313,10 +322,13 @@ class FP16_Optimizer(object): ...@@ -313,10 +322,13 @@ class FP16_Optimizer(object):
def _downscale_master(self): def _downscale_master(self):
if self.loss_scale != 1.0: if self.loss_scale != 1.0:
for group in self.optimizer.param_groups: for group in self.optimizer.param_groups:
for param in group['params']: grads = [p.grad for p in group['params'] if p.grad is not None]
if param.grad is not None: _overflow_buf = torch.cuda.IntTensor([0])
param.grad.data.mul_(1./self.loss_scale) multi_tensor_applier(amp_C.multi_tensor_scale,
_overflow_buf,
[grads, grads],
1./self.loss_scale)
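For comparison, the fused ``amp_C.multi_tensor_scale`` call above replaces the earlier per-tensor loop; an unfused equivalent in plain PyTorch (no ``amp_C`` dependency) would look like this sketch::

    def downscale_master_unfused(optimizer, loss_scale):
        # Divide every master gradient by the loss scale, one tensor at a time.
        # Functionally equivalent to the fused multi_tensor_scale path, just
        # slower because it launches one kernel per gradient.
        if loss_scale != 1.0:
            for group in optimizer.param_groups:
                for param in group['params']:
                    if param.grad is not None:
                        param.grad.data.mul_(1.0 / loss_scale)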
def clip_master_grads(self, max_norm, norm_type=2): def clip_master_grads(self, max_norm, norm_type=2):
""" """
Clips fp32 master gradients via ``torch.nn.utils.clip_grad_norm``. Clips fp32 master gradients via ``torch.nn.utils.clip_grad_norm``.
...@@ -364,9 +376,9 @@ class FP16_Optimizer(object): ...@@ -364,9 +376,9 @@ class FP16_Optimizer(object):
def load_state_dict(self, state_dict): def load_state_dict(self, state_dict):
""" """
Loads a state_dict created by an earlier call to state_dict(). Loads a state_dict created by an earlier call to state_dict().
If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``, If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
whose parameters in turn came from ``model``, it is expected that the user whose parameters in turn came from ``model``, it is expected that the user
will call ``model.load_state_dict()`` before will call ``model.load_state_dict()`` before
``fp16_optimizer_instance.load_state_dict()`` is called. ``fp16_optimizer_instance.load_state_dict()`` is called.
...@@ -387,33 +399,34 @@ class FP16_Optimizer(object): ...@@ -387,33 +399,34 @@ class FP16_Optimizer(object):
self.first_closure_call_this_step = state_dict['first_closure_call_this_step'] self.first_closure_call_this_step = state_dict['first_closure_call_this_step']
self.optimizer.load_state_dict(state_dict['optimizer_state_dict']) self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
# At this point, the optimizer's references to the model's fp32 parameters are up to date. # At this point, the optimizer's references to the model's fp32 parameters are up to date.
# The optimizer's hyperparameters and internal buffers are also up to date. # The optimizer's hyperparameters and internal buffers are also up to date.
# However, the fp32 master copies of the model's fp16 params stored by the optimizer are still # However, the fp32 master copies of the model's fp16 params stored by the optimizer are still
# out of date. There are two options. # out of date. There are two options.
# 1: Refresh the master params from the model's fp16 params. # 1: Refresh the master params from the model's fp16 params.
# This requires less storage but incurs precision loss. # This requires less storage but incurs precision loss.
# 2: Save and restore the fp32 master copies separately. # 2: Save and restore the fp32 master copies separately.
# We choose option 2. # We choose option 2.
# #
# Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and device # Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and device
# of their associated parameters, because it's possible those buffers might not exist yet in # of their associated parameters, because it's possible those buffers might not exist yet in
# the current optimizer instance. In our case, as long as the current FP16_Optimizer has been # the current optimizer instance. In our case, as long as the current FP16_Optimizer has been
# constructed in the same way as the one whose state_dict we are loading, the same master params # constructed in the same way as the one whose state_dict we are loading, the same master params
# are guaranteed to exist, so we can just copy_() from the saved master params. # are guaranteed to exist, so we can just copy_() from the saved master params.
for current_group, saved_group in zip(self.fp32_from_fp16_groups, state_dict['fp32_from_fp16']): for current_group, saved_group in zip(
self.fp32_from_fp16_groups, state_dict['fp32_from_fp16']):
for current, saved in zip(current_group, saved_group): for current, saved in zip(current_group, saved_group):
current.data.copy_(saved.data) current.data.copy_(saved.data)
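A minimal save/restore sketch following the ordering described above (``model``, ``optimizer``, and ``PATH`` are illustrative names)::

    # Saving: both the model and the FP16_Optimizer expose state_dict().
    checkpoint = {'model': model.state_dict(),
                  'optimizer': optimizer.state_dict()}
    torch.save(checkpoint, PATH)

    # Restoring: load the model first, then the optimizer, so the fp32
    # master copies stored in the optimizer state can be copy_()'d back.
    checkpoint = torch.load(PATH)
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])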
def step(self, closure=None): # could add clip option. def step(self, closure=None): # could add clip option.
""" """
If no closure is supplied, :attr:`step` should be called after If no closure is supplied, :attr:`step` should be called after
``fp16_optimizer_obj.backward(loss)``. ``fp16_optimizer_obj.backward(loss)``.
:attr:`step` updates the fp32 master copy of parameters using the optimizer supplied to :attr:`step` updates the fp32 master copy of parameters using the optimizer supplied to
:class:`FP16_Optimizer`'s constructor, then copies the updated fp32 params into the fp16 params :class:`FP16_Optimizer`'s constructor, then copies the updated fp32 params into the fp16 params
originally referenced by :class:`FP16_Optimizer`'s constructor, so the user may immediately run originally referenced by :class:`FP16_Optimizer`'s constructor, so the user may immediately run
another forward pass using their model. another forward pass using their model.
If a closure is supplied, :attr:`step` may be called without a prior call to If a closure is supplied, :attr:`step` may be called without a prior call to
:attr:`backward(loss)`. :attr:`backward(loss)`.
This control flow is identical to `ordinary Pytorch optimizer use`_ with closures. This control flow is identical to `ordinary Pytorch optimizer use`_ with closures.
However, the user should take care that any ``loss.backward()`` call within the closure However, the user should take care that any ``loss.backward()`` call within the closure
...@@ -424,7 +437,7 @@ class FP16_Optimizer(object): ...@@ -424,7 +437,7 @@ class FP16_Optimizer(object):
Example with closure:: Example with closure::
# optimizer is assumed to be an FP16_Optimizer object, previously constructed from an # optimizer is assumed to be an FP16_Optimizer object, previously constructed from an
# existing pytorch optimizer. # existing pytorch optimizer.
for input, target in dataset: for input, target in dataset:
def closure(): def closure():
...@@ -448,9 +461,9 @@ class FP16_Optimizer(object): ...@@ -448,9 +461,9 @@ class FP16_Optimizer(object):
if self.overflow: if self.overflow:
self.maybe_print("OVERFLOW! Skipping step. Attempted loss scale: {}, reducing to {}" self.maybe_print("OVERFLOW! Skipping step. Attempted loss scale: {}, reducing to {}"
.format(scale, self.loss_scale)) .format(scale, self.loss_scale))
return return
if closure is not None: if closure is not None:
retval = self._step_with_closure(closure) retval = self._step_with_closure(closure)
else: else:
...@@ -472,7 +485,7 @@ class FP16_Optimizer(object): ...@@ -472,7 +485,7 @@ class FP16_Optimizer(object):
self.first_closure_call_this_step = False self.first_closure_call_this_step = False
else: else:
# If self.optimizer.step() internally calls wrapped_closure more than once, # If self.optimizer.step() internally calls wrapped_closure more than once,
# it may update the fp32 params after each call. However, self.optimizer # it may update the fp32 params after each call. However, self.optimizer
# doesn't know about the fp16 params at all. If the fp32 params get updated, # doesn't know about the fp16 params at all. If the fp32 params get updated,
# we can't rely on self.optimizer to refresh the fp16 params. We need # we can't rely on self.optimizer to refresh the fp16 params. We need
# to handle that manually: # to handle that manually:
...@@ -480,16 +493,16 @@ class FP16_Optimizer(object): ...@@ -480,16 +493,16 @@ class FP16_Optimizer(object):
# Our API expects the user to give us ownership of the backward() call by # Our API expects the user to give us ownership of the backward() call by
# replacing all calls to loss.backward() with optimizer.backward(loss). # replacing all calls to loss.backward() with optimizer.backward(loss).
# This requirement holds whether or not the call to backward() is made within a closure. # This requirement holds whether or not the call to backward() is made within a closure.
# If the user is properly calling optimizer.backward(loss) within "closure," # If the user is properly calling optimizer.backward(loss) within "closure,"
# calling closure() here will give the fp32 master params fresh gradients # calling closure() here will give the fp32 master params fresh gradients
# for the optimizer to play with, so all wrapped_closure needs to do is call # for the optimizer to play with, so all wrapped_closure needs to do is call
# closure() and return the loss. # closure() and return the loss.
temp_loss = closure() temp_loss = closure()
while(self.overflow): while(self.overflow):
scale = self.loss_scaler.loss_scale scale = self.loss_scaler.loss_scale
self._update_scale(self.overflow) self._update_scale(self.overflow)
self.maybe_print("OVERFLOW within closure! Skipping step. Attempted loss scale: {}, " self.maybe_print("OVERFLOW within closure! Skipping step. Attempted loss scale: {}, "
"reducing to {}".format(scale, self.loss_scale)) "reducing to {}".format(scale, self.loss_scale))
temp_loss = closure() temp_loss = closure()
return temp_loss return temp_loss
...@@ -500,7 +513,7 @@ class FP16_Optimizer(object): ...@@ -500,7 +513,7 @@ class FP16_Optimizer(object):
return retval return retval
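The comments above describe the closure contract: every ``loss.backward()`` inside the closure must be replaced by ``optimizer.backward(loss)`` so the fp32 master grads are refreshed before each inner step. A minimal sketch (``model``, ``criterion``, ``inp`` and ``target`` are illustrative names)::

    def closure():
        optimizer.zero_grad()
        output = model(inp)
        loss = criterion(output, target)
        optimizer.backward(loss)   # not loss.backward(); refreshes fp32 master grads
        return loss

    loss = optimizer.step(closure)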
def backward(self, loss, update_master_grads=True, retain_graph=False): def backward(self, loss, update_master_grads=True, retain_graph=False):
""" """
:attr:`backward` performs the following conceptual steps: :attr:`backward` performs the following conceptual steps:
1. fp32_loss = loss.float() (see first Note below) 1. fp32_loss = loss.float() (see first Note below)
...@@ -514,19 +527,19 @@ class FP16_Optimizer(object): ...@@ -514,19 +527,19 @@ class FP16_Optimizer(object):
.. note:: .. note::
:attr:`backward` internally converts the loss to fp32 before applying the loss scale. :attr:`backward` internally converts the loss to fp32 before applying the loss scale.
This provides some additional safety against overflow if the user has supplied an This provides some additional safety against overflow if the user has supplied an
fp16 loss value. fp16 loss value.
However, for maximum overflow safety, the user should However, for maximum overflow safety, the user should
compute the loss criterion (MSE, cross entropy, etc) in fp32 before supplying it to compute the loss criterion (MSE, cross entropy, etc) in fp32 before supplying it to
:attr:`backward`. :attr:`backward`.
.. warning:: .. warning::
The gradients found in a model's leaves after the call to The gradients found in a model's leaves after the call to
:attr:`backward` should not be regarded as valid in general, :attr:`backward` should not be regarded as valid in general,
because it's possible because it's possible
they have been scaled (and in the case of dynamic loss scaling, they have been scaled (and in the case of dynamic loss scaling,
the scale factor may change over time). the scale factor may change over time).
If the user wants to inspect gradients after a call to :attr:`backward`, If the user wants to inspect gradients after a call to :attr:`backward`,
only the master gradients should be regarded as valid. These can be retrieved via only the master gradients should be regarded as valid. These can be retrieved via
:attr:`inspect_master_grad_data()`. :attr:`inspect_master_grad_data()`.
...@@ -541,54 +554,55 @@ class FP16_Optimizer(object): ...@@ -541,54 +554,55 @@ class FP16_Optimizer(object):
optimizer.backward(loss) optimizer.backward(loss)
# Naive operation with multiple losses (technically valid, but less efficient): # Naive operation with multiple losses (technically valid, but less efficient):
# fp32 grads will be correct after the second call, but # fp32 grads will be correct after the second call, but
# the first call incurs an unnecessary fp16->fp32 grad copy. # the first call incurs an unnecessary fp16->fp32 grad copy.
optimizer.backward(loss1) optimizer.backward(loss1)
optimizer.backward(loss2) optimizer.backward(loss2)
# More efficient way to handle multiple losses: # More efficient way to handle multiple losses:
# The fp16->fp32 grad copy is delayed until fp16 grads from all # The fp16->fp32 grad copy is delayed until fp16 grads from all
# losses have been accumulated. # losses have been accumulated.
optimizer.backward(loss1, update_master_grads=False) optimizer.backward(loss1, update_master_grads=False)
optimizer.backward(loss2, update_master_grads=False) optimizer.backward(loss2, update_master_grads=False)
optimizer.update_master_grads() optimizer.update_master_grads()
""" """
# To consider: try multiple backward passes using retain_grad=True to find # To consider: try multiple backward passes using retain_grad=True to find
# a loss scale that works. After you find a loss scale that works, do a final dummy # a loss scale that works. After you find a loss scale that works, do a final dummy
# backward pass with retain_graph=False to tear down the graph. Doing this would avoid # backward pass with retain_graph=False to tear down the graph. Doing this would avoid
# discarding the iteration, but probably wouldn't improve overall efficiency. # discarding the iteration, but probably wouldn't improve overall efficiency.
self.loss_scaler.backward(loss.float(), retain_graph=retain_graph) self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
if update_master_grads: if update_master_grads:
self.update_master_grads() self.update_master_grads()
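Putting the pieces together, the scaled backward performed here plus the later :attr:`update_master_grads` call amount to the following conceptual sketch (pseudocode, not the fused implementation)::

    fp32_loss = loss.float()                # convert to fp32 for extra overflow safety
    scaled_loss = fp32_loss * loss_scale    # scale so fp16 grads do not underflow
    scaled_loss.backward()                  # fp16 model grads are now scaled
    # later, update_master_grads() copies the fp16 grads into the fp32 masters
    # and divides them by loss_scale before the optimizer step.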
def update_master_grads(self): def update_master_grads(self):
""" """
Copy the ``.grad`` attribute from stored references to fp16 parameters to Copy the ``.grad`` attribute from stored references to fp16 parameters to
the ``.grad`` attribute of the fp32 master parameters that are directly the ``.grad`` attribute of the fp32 master parameters that are directly
updated by the optimizer. :attr:`update_master_grads` only needs to be called if updated by the optimizer. :attr:`update_master_grads` only needs to be called if
``fp16_optimizer_obj.backward`` was called with ``update_master_grads=False``. ``fp16_optimizer_obj.backward`` was called with ``update_master_grads=False``.
""" """
if self.dynamic_loss_scale: if self.dynamic_loss_scale:
self._check_overflow() self._check_overflow()
if self.overflow: return if self.overflow:
return
self._model_grads_to_master_grads() self._model_grads_to_master_grads()
self._downscale_master() self._downscale_master()
def inspect_master_grad_data(self): def inspect_master_grad_data(self):
""" """
When running with :class:`FP16_Optimizer`, When running with :class:`FP16_Optimizer`,
``.grad`` attributes of a model's fp16 leaves should not be ``.grad`` attributes of a model's fp16 leaves should not be
regarded as truthful, because they might be scaled. regarded as truthful, because they might be scaled.
After a call to :attr:`fp16_optimizer_obj.backward(loss)`, if no overflow was encountered, After a call to :attr:`fp16_optimizer_obj.backward(loss)`, if no overflow was encountered,
the fp32 master params' ``.grad`` the fp32 master params' ``.grad``
attributes will contain valid gradients properly divided by the loss scale. However, attributes will contain valid gradients properly divided by the loss scale. However,
because :class:`FP16_Optimizer` flattens some parameters, accessing them may be because :class:`FP16_Optimizer` flattens some parameters, accessing them may be
nonintuitive. :attr:`inspect_master_grad_data` nonintuitive. :attr:`inspect_master_grad_data`
allows those gradients to be viewed with shapes corresponding to their associated model leaves. allows those gradients to be viewed with shapes corresponding to their associated model leaves.
Returns: Returns:
List of lists (one list for each parameter group). The list for each parameter group List of lists (one list for each parameter group). The list for each parameter group
is a list of the ``.grad.data`` attributes of the fp32 master params belonging to that group. is a list of the ``.grad.data`` attributes of the fp32 master params belonging to that group.
""" """
if self.overflow: if self.overflow:
print("Warning: calling FP16_Optimizer.inspect_master_grad_data while in an overflow state. " print("Warning: calling FP16_Optimizer.inspect_master_grad_data while in an overflow state. "
...@@ -607,8 +621,8 @@ class FP16_Optimizer(object): ...@@ -607,8 +621,8 @@ class FP16_Optimizer(object):
master_grads_data.append(master_grads_this_group) master_grads_data.append(master_grads_this_group)
return master_grads_data return master_grads_data
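A sketch of consuming the returned structure (one inner list per parameter group; the ``None`` guard is a defensive assumption)::

    master_grads = optimizer.inspect_master_grad_data()
    for group_idx, group_grads in enumerate(master_grads):
        for param_idx, grad_data in enumerate(group_grads):
            if grad_data is not None:
                print(group_idx, param_idx, grad_data.norm())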
# Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale" # Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale"
def _get_loss_scale(self): def _get_loss_scale(self):
return self.loss_scaler.loss_scale return self.loss_scaler.loss_scale
......
# coding=utf-8 # coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -18,6 +18,9 @@ import torch.nn as nn ...@@ -18,6 +18,9 @@ import torch.nn as nn
from torch.autograd import Variable from torch.autograd import Variable
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from apex.multi_tensor_apply import multi_tensor_applier
import amp_C
from megatron import mpu from megatron import mpu
...@@ -102,6 +105,7 @@ class FP16Model(nn.Module): ...@@ -102,6 +105,7 @@ class FP16Model(nn.Module):
def backwards_debug_hook(grad): def backwards_debug_hook(grad):
raise RuntimeError("master_params recieved a gradient in the backward pass!") raise RuntimeError("master_params recieved a gradient in the backward pass!")
def prep_param_lists(model, flat_master=False): def prep_param_lists(model, flat_master=False):
""" """
Creates a list of FP32 master parameters for a given model, as in Creates a list of FP32 master parameters for a given model, as in
...@@ -131,9 +135,9 @@ def prep_param_lists(model, flat_master=False): ...@@ -131,9 +135,9 @@ def prep_param_lists(model, flat_master=False):
# flatten_dense_tensors returns a contiguous flat array. # flatten_dense_tensors returns a contiguous flat array.
# http://pytorch.org/docs/master/_modules/torch/_utils.html # http://pytorch.org/docs/master/_modules/torch/_utils.html
master_params = _flatten_dense_tensors([param.data for param in model_params]).float() master_params = _flatten_dense_tensors([param.data for param in model_params]).float()
except: except BaseException:
print("Error in prep_param_lists: model may contain a mixture of parameters " print("Error in prep_param_lists: model may contain a mixture of parameters "
"of different types. Use flat_master=False, or use F16_Optimizer.") "of different types. Use flat_master=False, or use F16_Optimizer.")
raise raise
master_params = torch.nn.Parameter(master_params) master_params = torch.nn.Parameter(master_params)
master_params.requires_grad = True master_params.requires_grad = True
...@@ -150,7 +154,7 @@ def prep_param_lists(model, flat_master=False): ...@@ -150,7 +154,7 @@ def prep_param_lists(model, flat_master=False):
def model_grads_to_master_grads(model_params, master_params, flat_master=False): def model_grads_to_master_grads(model_params, master_params, flat_master=False):
""" """
Copy model gradients to master gradients. Copy model gradients to master gradients.
Args: Args:
model_params: List of model parameters created by :func:`prep_param_lists`. model_params: List of model parameters created by :func:`prep_param_lists`.
...@@ -165,9 +169,15 @@ def model_grads_to_master_grads(model_params, master_params, flat_master=False): ...@@ -165,9 +169,15 @@ def model_grads_to_master_grads(model_params, master_params, flat_master=False):
if model.grad is not None: if model.grad is not None:
if master.grad is None: if master.grad is None:
master.grad = Variable(master.data.new(*master.data.size())) master.grad = Variable(master.data.new(*master.data.size()))
master.grad.data.copy_(model.grad.data)
else: else:
master.grad = None master.grad = None
model_grads = [p.grad for p in model_params if p.grad is not None]
master_grads = [p.grad for p in master_params if p.grad is not None]
_overflow_buf = torch.cuda.IntTensor([0])
multi_tensor_applier(amp_C.multi_tensor_scale,
_overflow_buf,
[model_grads, master_grads],
1.0)
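Taken together, the helpers in this file support a manual mixed-precision step along these lines (a sketch of the non-flat path, not the FP16_Optimizer wrapper; ``fp32_optimizer`` is assumed to have been built over ``master_params``)::

    def fp16_train_step(model_params, master_params, fp32_optimizer, loss, loss_scale):
        # One manual mixed-precision step using the helpers above (illustrative).
        (loss.float() * loss_scale).backward()                  # grads land on the fp16 params
        model_grads_to_master_grads(model_params, master_params)
        for master in master_params:
            if master.grad is not None:
                master.grad.data.mul_(1.0 / loss_scale)         # unscale in fp32
        fp32_optimizer.step()                                   # update the fp32 masters
        master_params_to_model_params(model_params, master_params)  # copy back to fp16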
def master_params_to_model_params(model_params, master_params, flat_master=False): def master_params_to_model_params(model_params, master_params, flat_master=False):
...@@ -179,7 +189,7 @@ def master_params_to_model_params(model_params, master_params, flat_master=False ...@@ -179,7 +189,7 @@ def master_params_to_model_params(model_params, master_params, flat_master=False
master_params: List of FP32 master parameters created by :func:`prep_param_lists`. If ``master_params`` was created with ``flat_master=True``, ``flat_master=True`` should also be supplied to :func:`master_params_to_model_params`. master_params: List of FP32 master parameters created by :func:`prep_param_lists`. If ``master_params`` was created with ``flat_master=True``, ``flat_master=True`` should also be supplied to :func:`master_params_to_model_params`.
""" """
if flat_master: if flat_master:
for model, master in zip(model_params, for model, master in zip(model_params,
_unflatten_dense_tensors(master_params[0].data, model_params)): _unflatten_dense_tensors(master_params[0].data, model_params)):
model.data.copy_(master) model.data.copy_(master)
else: else:
...@@ -188,17 +198,19 @@ def master_params_to_model_params(model_params, master_params, flat_master=False ...@@ -188,17 +198,19 @@ def master_params_to_model_params(model_params, master_params, flat_master=False
# Backward compatibility fixes # Backward compatibility fixes
def to_python_float(t): def to_python_float(t):
if hasattr(t, 'item'): if hasattr(t, 'item'):
return t.item() return t.item()
else: else:
return t[0] return t[0]
TORCH_MAJOR = int(torch.__version__.split('.')[0]) TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1]) TORCH_MINOR = int(torch.__version__.split('.')[1])
clip_grad_norm = mpu.clip_grad_norm clip_grad_norm = mpu.clip_grad_norm
#elif TORCH_MAJOR == 0 and TORCH_MINOR <= 4: # elif TORCH_MAJOR == 0 and TORCH_MINOR <= 4:
# clip_grad_norm = torch.nn.utils.clip_grad_norm # clip_grad_norm = torch.nn.utils.clip_grad_norm
#else: # else:
# clip_grad_norm = torch.nn.utils.clip_grad_norm_ # clip_grad_norm = torch.nn.utils.clip_grad_norm_
# coding=utf-8 # coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -14,21 +14,28 @@ ...@@ -14,21 +14,28 @@
# limitations under the License. # limitations under the License.
import torch import torch
from apex.multi_tensor_apply import multi_tensor_applier
import amp_C
from megatron import mpu from megatron import mpu
# item() is a recent addition, so this helps with backward compatibility. # item() is a recent addition, so this helps with backward compatibility.
def to_python_float(t): def to_python_float(t):
if hasattr(t, 'item'): if hasattr(t, 'item'):
return t.item() return t.item()
else: else:
return t[0] return t[0]
class LossScaler: class LossScaler:
""" """
Class that manages a static loss scale. This class is intended to interact with Class that manages a static loss scale. This class is intended to interact with
:class:`FP16_Optimizer`, and should not be directly manipulated by the user. :class:`FP16_Optimizer`, and should not be directly manipulated by the user.
Use of :class:`LossScaler` is enabled via the ``static_loss_scale`` argument to Use of :class:`LossScaler` is enabled via the ``static_loss_scale`` argument to
:class:`FP16_Optimizer`'s constructor. :class:`FP16_Optimizer`'s constructor.
Args: Args:
...@@ -54,16 +61,22 @@ class LossScaler: ...@@ -54,16 +61,22 @@ class LossScaler:
return self.cur_scale return self.cur_scale
def scale_gradient(self, module, grad_in, grad_out): def scale_gradient(self, module, grad_in, grad_out):
return tuple(self.loss_scale * g for g in grad_in) _overflow_buf = torch.cuda.IntTensor([0])
multi_tensor_applier(amp_C.multi_tensor_scale,
_overflow_buf,
[grad_in, grad_in],
self.loss_scale)
return grad_in
def backward(self, loss, retain_graph=False): def backward(self, loss, retain_graph=False):
scaled_loss = loss*self.loss_scale scaled_loss = loss * self.loss_scale
scaled_loss.backward(retain_graph=retain_graph) scaled_loss.backward(retain_graph=retain_graph)
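A sketch of how this static scaler is enabled through the wrapper, per the note above (``inner_optimizer`` and ``loss`` are illustrative)::

    # inner_optimizer is an ordinary torch.optim optimizer built on the fp16 model.
    optimizer = FP16_Optimizer(inner_optimizer, static_loss_scale=128.0)

    optimizer.zero_grad()
    optimizer.backward(loss)   # loss is multiplied by 128.0 internally
    optimizer.step()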
class DynamicLossScaler: class DynamicLossScaler:
""" """
Class that manages dynamic loss scaling. It is recommended to use :class:`DynamicLossScaler` Class that manages dynamic loss scaling. It is recommended to use :class:`DynamicLossScaler`
indirectly, by supplying ``dynamic_loss_scale=True`` to the constructor of indirectly, by supplying ``dynamic_loss_scale=True`` to the constructor of
:class:`FP16_Optimizer`. However, it's important to understand how :class:`DynamicLossScaler` :class:`FP16_Optimizer`. However, it's important to understand how :class:`DynamicLossScaler`
operates, because the default options can be changed using operates, because the default options can be changed using
the ``dynamic_loss_args`` argument to :class:`FP16_Optimizer`'s constructor. the ``dynamic_loss_args`` argument to :class:`FP16_Optimizer`'s constructor.
...@@ -71,18 +84,18 @@ class DynamicLossScaler: ...@@ -71,18 +84,18 @@ class DynamicLossScaler:
Loss scaling is designed to combat the problem of underflowing gradients encountered when Loss scaling is designed to combat the problem of underflowing gradients encountered when
training fp16 networks for long periods. Dynamic loss scaling begins by attempting a very high loss training fp16 networks for long periods. Dynamic loss scaling begins by attempting a very high loss
scale. Ironically, this may result in OVERflowing gradients. If overflowing gradients are scale. Ironically, this may result in OVERflowing gradients. If overflowing gradients are
encountered, :class:`DynamicLossScaler` informs :class:`FP16_Optimizer` that an overflow has encountered, :class:`DynamicLossScaler` informs :class:`FP16_Optimizer` that an overflow has
occurred. occurred.
:class:`FP16_Optimizer` then skips the update step for this particular iteration/minibatch, :class:`FP16_Optimizer` then skips the update step for this particular iteration/minibatch,
and :class:`DynamicLossScaler` adjusts the loss scale to a lower value. and :class:`DynamicLossScaler` adjusts the loss scale to a lower value.
If a certain number of iterations occur without overflowing gradients detected, If a certain number of iterations occur without overflowing gradients detected,
:class:`DynamicLossScaler` increases the loss scale once more. :class:`DynamicLossScaler` increases the loss scale once more.
In this way :class:`DynamicLossScaler` attempts to "ride the edge" of In this way :class:`DynamicLossScaler` attempts to "ride the edge" of
always using the highest loss scale possible without incurring overflow. always using the highest loss scale possible without incurring overflow.
Args: Args:
init_scale (float, optional, default=2**32): Initial loss scale attempted by :class:`DynamicLossScaler`. init_scale (float, optional, default=2**32): Initial loss scale attempted by :class:`DynamicLossScaler`.
scale_factor (float, optional, default=2.0): Factor used when adjusting the loss scale. If an overflow is encountered, the loss scale is readjusted to loss scale/``scale_factor``. If ``scale_window`` consecutive iterations take place without an overflow, the loss scale is readjusted to loss_scale*``scale_factor``. scale_factor (float, optional, default=2.0): Factor used when adjusting the loss scale. If an overflow is encountered, the loss scale is readjusted to loss scale/``scale_factor``. If ``scale_window`` consecutive iterations take place without an overflow, the loss scale is readjusted to loss_scale*``scale_factor``.
scale_window (int, optional, default=1000): Number of consecutive iterations without an overflow to wait before increasing the loss scale. scale_window (int, optional, default=1000): Number of consecutive iterations without an overflow to wait before increasing the loss scale.
""" """
...@@ -122,12 +135,12 @@ class DynamicLossScaler: ...@@ -122,12 +135,12 @@ class DynamicLossScaler:
overflow = overflow_gpu[0].item() overflow = overflow_gpu[0].item()
return bool(overflow) return bool(overflow)
# `x` is a torch.Tensor # `x` is a torch.Tensor
def _has_inf_or_nan(x): def _has_inf_or_nan(x):
try: try:
# if x is half, the .float() incurs an additional deep copy, but it's necessary if # if x is half, the .float() incurs an additional deep copy, but it's necessary if
# Pytorch's .sum() creates a one-element tensor of the same type as x # Pytorch's .sum() creates a one-element tensor of the same type as x
# (which is true for some recent versions of pytorch). # (which is true for some recent versions of pytorch).
cpu_sum = float(x.float().sum()) cpu_sum = float(x.float().sum())
# More efficient version that can be used if .sum() returns a Python scalar # More efficient version that can be used if .sum() returns a Python scalar
...@@ -158,7 +171,7 @@ class DynamicLossScaler: ...@@ -158,7 +171,7 @@ class DynamicLossScaler:
if overflow: if overflow:
# self.cur_scale /= self.scale_factor # self.cur_scale /= self.scale_factor
if self.delayed_shift == 1 or self.cur_hysteresis == 1: if self.delayed_shift == 1 or self.cur_hysteresis == 1:
self.cur_scale = max(self.cur_scale/self.scale_factor, self.min_scale) self.cur_scale = max(self.cur_scale / self.scale_factor, self.min_scale)
else: else:
self.cur_hysteresis -= 1 self.cur_hysteresis -= 1
self.last_overflow_iter = self.cur_iter self.last_overflow_iter = self.cur_iter
...@@ -176,13 +189,19 @@ class DynamicLossScaler: ...@@ -176,13 +189,19 @@ class DynamicLossScaler:
return self.cur_scale return self.cur_scale
def scale_gradient(self, module, grad_in, grad_out): def scale_gradient(self, module, grad_in, grad_out):
return tuple(self.loss_scale * g for g in grad_in) _overflow_buf = torch.cuda.IntTensor([0])
multi_tensor_applier(amp_C.multi_tensor_scale,
_overflow_buf,
[grad_in, grad_in],
self.loss_scale)
return grad_in
def backward(self, loss, retain_graph=False): def backward(self, loss, retain_graph=False):
scaled_loss = loss*self.loss_scale scaled_loss = loss * self.loss_scale
scaled_loss.backward(retain_graph=retain_graph) scaled_loss.backward(retain_graph=retain_graph)
##############################################################
##############################################################
# Example usage below here -- assuming it's in a separate file # Example usage below here -- assuming it's in a separate file
############################################################## ##############################################################
""" """
...@@ -218,10 +237,10 @@ if __name__ == "__main__": ...@@ -218,10 +237,10 @@ if __name__ == "__main__":
# Run backprop # Run backprop
optimizer.zero_grad() optimizer.zero_grad()
loss.backward() loss.backward()
# Check for overflow # Check for overflow
has_overflow = DynamicLossScaler.has_overflow(parameters) has_overflow = DynamicLossScaler.has_overflow(parameters)
# If no overflow, unscale grad and update as usual # If no overflow, unscale grad and update as usual
if not has_overflow: if not has_overflow:
for param in parameters: for param in parameters:
......
# coding=utf-8 # coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -61,22 +61,26 @@ def get_timers(): ...@@ -61,22 +61,26 @@ def get_timers():
return _GLOBAL_TIMERS return _GLOBAL_TIMERS
def set_global_variables(extra_args_provider=None, args_defaults={}): def set_global_variables(extra_args_provider=None, args_defaults={},
ignore_unknown_args=False):
"""Set args, tokenizer, tensorboard-writer, adlr-autoresume, and timers.""" """Set args, tokenizer, tensorboard-writer, adlr-autoresume, and timers."""
args = _parse_args(extra_args_provider=extra_args_provider, args = _parse_args(extra_args_provider=extra_args_provider,
defaults=args_defaults) defaults=args_defaults,
ignore_unknown_args=ignore_unknown_args)
_ = _build_tokenizer(args) _ = _build_tokenizer(args)
_set_tensorboard_writer(args) _set_tensorboard_writer(args)
_set_adlr_autoresume(args) _set_adlr_autoresume(args)
_set_timers() _set_timers()
def _parse_args(extra_args_provider=None, defaults={}): def _parse_args(extra_args_provider=None, defaults={},
ignore_unknown_args=False):
"""Parse entire arguments.""" """Parse entire arguments."""
global _GLOBAL_ARGS global _GLOBAL_ARGS
_ensure_var_is_not_initialized(_GLOBAL_ARGS, 'args') _ensure_var_is_not_initialized(_GLOBAL_ARGS, 'args')
_GLOBAL_ARGS = parse_args(extra_args_provider=extra_args_provider, _GLOBAL_ARGS = parse_args(extra_args_provider=extra_args_provider,
defaults=defaults) defaults=defaults,
ignore_unknown_args=ignore_unknown_args)
return _GLOBAL_ARGS return _GLOBAL_ARGS
...@@ -124,7 +128,7 @@ def _set_adlr_autoresume(args): ...@@ -124,7 +128,7 @@ def _set_adlr_autoresume(args):
sys.path.append(os.environ.get('SUBMIT_SCRIPTS', '.')) sys.path.append(os.environ.get('SUBMIT_SCRIPTS', '.'))
try: try:
from userlib.auto_resume import AutoResume from userlib.auto_resume import AutoResume
except: except BaseException:
print('ADLR autoresume is not available, exiting ...') print('ADLR autoresume is not available, exiting ...')
sys.exit() sys.exit()
......
# coding=utf-8 # coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -28,7 +28,8 @@ from megatron import mpu ...@@ -28,7 +28,8 @@ from megatron import mpu
from megatron.global_vars import set_global_variables from megatron.global_vars import set_global_variables
def initialize_megatron(extra_args_provider=None, args_defaults={}): def initialize_megatron(extra_args_provider=None, args_defaults={},
ignore_unknown_args=False):
"""Set global variables, initialize distributed, and """Set global variables, initialize distributed, and
set autoresume and random seeds.""" set autoresume and random seeds."""
# Make sure cuda is available. # Make sure cuda is available.
...@@ -37,7 +38,8 @@ def initialize_megatron(extra_args_provider=None, args_defaults={}): ...@@ -37,7 +38,8 @@ def initialize_megatron(extra_args_provider=None, args_defaults={}):
# Parse args, build tokenizer, and set adlr-autoresume, # Parse args, build tokenizer, and set adlr-autoresume,
# tensorboard-writer, and timers. # tensorboard-writer, and timers.
set_global_variables(extra_args_provider=extra_args_provider, set_global_variables(extra_args_provider=extra_args_provider,
args_defaults=args_defaults) args_defaults=args_defaults,
ignore_unknown_args=ignore_unknown_args)
# Pytorch distributed. # Pytorch distributed.
_initialize_distributed() _initialize_distributed()
......
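The new ``ignore_unknown_args`` flag threads from ``initialize_megatron`` through ``set_global_variables`` into ``_parse_args``. A call-site sketch, assuming ``extra_args_provider`` takes and returns the argparse parser as in Megatron's pretraining scripts (flag names and default values below are illustrative)::

    def extra_args_provider(parser):
        group = parser.add_argument_group(title='my task')
        group.add_argument('--my-task-flag', action='store_true')
        return parser

    initialize_megatron(extra_args_provider=extra_args_provider,
                        args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'},
                        ignore_unknown_args=True)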
# coding=utf-8 # coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -48,7 +48,6 @@ class AnnealingLR(object): ...@@ -48,7 +48,6 @@ class AnnealingLR(object):
print_rank_0('> learning rate decay style: {}'.format(self.decay_style)) print_rank_0('> learning rate decay style: {}'.format(self.decay_style))
def get_lr(self): def get_lr(self):
"""Learning rate decay functions from: """Learning rate decay functions from:
https://openreview.net/pdf?id=BJYwwY9ll pg. 4""" https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""
...@@ -71,7 +70,6 @@ class AnnealingLR(object): ...@@ -71,7 +70,6 @@ class AnnealingLR(object):
lr = self.start_lr lr = self.start_lr
return max(lr, self.min_lr) return max(lr, self.min_lr)
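Most of :attr:`get_lr` is elided in this diff; as one hedged illustration of the kind of schedule it computes (assuming a linear-warmup-plus-cosine-decay style, which may differ from the actual decay branches)::

    import math

    def cosine_lr(step, warmup_steps, decay_steps, start_lr, min_lr):
        # Linear warmup followed by cosine decay down to min_lr (illustrative only).
        if warmup_steps > 0 and step <= warmup_steps:
            return start_lr * step / warmup_steps
        frac = min(1.0, (step - warmup_steps) / max(1, decay_steps - warmup_steps))
        return min_lr + 0.5 * (start_lr - min_lr) * (1.0 + math.cos(math.pi * frac))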
def step(self, step_num=None): def step(self, step_num=None):
"""Set lr for all parameters groups.""" """Set lr for all parameters groups."""
if step_num is None: if step_num is None:
...@@ -81,7 +79,6 @@ class AnnealingLR(object): ...@@ -81,7 +79,6 @@ class AnnealingLR(object):
for group in self.optimizer.param_groups: for group in self.optimizer.param_groups:
group['lr'] = new_lr group['lr'] = new_lr
def state_dict(self): def state_dict(self):
state_dict = { state_dict = {
'start_lr': self.start_lr, 'start_lr': self.start_lr,
...@@ -93,7 +90,6 @@ class AnnealingLR(object): ...@@ -93,7 +90,6 @@ class AnnealingLR(object):
} }
return state_dict return state_dict
def _check_and_set(self, cls_value, sd_value, name): def _check_and_set(self, cls_value, sd_value, name):
"""Auxiliary function for checking the values in the checkpoint and """Auxiliary function for checking the values in the checkpoint and
setting them.""" setting them."""
...@@ -108,7 +104,6 @@ class AnnealingLR(object): ...@@ -108,7 +104,6 @@ class AnnealingLR(object):
name)) name))
return sd_value return sd_value
def load_state_dict(self, sd): def load_state_dict(self, sd):
self.start_lr = self._check_and_set(self.start_lr, sd['start_lr'], self.start_lr = self._check_and_set(self.start_lr, sd['start_lr'],
......
# coding=utf-8 # coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding=utf-8 # coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -22,16 +22,15 @@ import torch ...@@ -22,16 +22,15 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from megatron import get_args from megatron import get_args
from megatron.model.language_model import parallel_lm_logits
from megatron.model.language_model import get_language_model
from megatron.model.transformer import LayerNorm
from megatron.model.utils import openai_gelu
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal
from megatron.module import MegatronModule from megatron.module import MegatronModule
from .language_model import parallel_lm_logits
from .language_model import get_language_model
from .transformer import LayerNorm
from .utils import gelu
from .utils import get_linear_layer
from .utils import init_method_normal
from .utils import scaled_init_method_normal
def bert_attention_mask_func(attention_scores, attention_mask): def bert_attention_mask_func(attention_scores, attention_mask):
attention_scores = attention_scores + attention_mask attention_scores = attention_scores + attention_mask
...@@ -70,7 +69,6 @@ def bert_position_ids(token_ids): ...@@ -70,7 +69,6 @@ def bert_position_ids(token_ids):
return position_ids return position_ids
class BertLMHead(MegatronModule): class BertLMHead(MegatronModule):
"""Masked LM head for Bert """Masked LM head for Bert
...@@ -81,11 +79,14 @@ class BertLMHead(MegatronModule): ...@@ -81,11 +79,14 @@ class BertLMHead(MegatronModule):
layernorm_epsilon: tolerance for layer norm divisions layernorm_epsilon: tolerance for layer norm divisions
parallel_output: whether the output logits are kept distributed or not. parallel_output: whether the output logits are kept distributed or not.
""" """
def __init__(self, mpu_vocab_size, hidden_size, init_method, def __init__(self, mpu_vocab_size, hidden_size, init_method,
layernorm_epsilon, parallel_output): layernorm_epsilon, parallel_output):
super(BertLMHead, self).__init__() super(BertLMHead, self).__init__()
args = get_args()
self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size)) self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
self.bias.model_parallel = True self.bias.model_parallel = True
self.bias.partition_dim = 0 self.bias.partition_dim = 0
...@@ -94,11 +95,13 @@ class BertLMHead(MegatronModule): ...@@ -94,11 +95,13 @@ class BertLMHead(MegatronModule):
self.dense = get_linear_layer(hidden_size, hidden_size, init_method) self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
self.layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon) self.layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
self.gelu = torch.nn.functional.gelu
if args.openai_gelu:
self.gelu = openai_gelu
def forward(self, hidden_states, word_embeddings_weight): def forward(self, hidden_states, word_embeddings_weight):
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = gelu(hidden_states) hidden_states = self.gelu(hidden_states)
hidden_states = self.layernorm(hidden_states) hidden_states = self.layernorm(hidden_states)
output = parallel_lm_logits(hidden_states, output = parallel_lm_logits(hidden_states,
word_embeddings_weight, word_embeddings_weight,
...@@ -107,7 +110,6 @@ class BertLMHead(MegatronModule): ...@@ -107,7 +110,6 @@ class BertLMHead(MegatronModule):
return output return output
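For reference, the ``openai_gelu`` variant selected by ``args.openai_gelu`` is conventionally the tanh approximation from the original GPT code; a sketch (the actual ``megatron.model.utils.openai_gelu`` may be jit-scripted or differ in detail)::

    import math
    import torch

    def openai_gelu_approx(x):
        # tanh approximation of GELU (illustrative; see megatron.model.utils.openai_gelu)
        return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) *
                                           (x + 0.044715 * torch.pow(x, 3.0))))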
class BertModel(MegatronModule): class BertModel(MegatronModule):
"""Bert Language model.""" """Bert Language model."""
...@@ -184,7 +186,6 @@ class BertModel(MegatronModule): ...@@ -184,7 +186,6 @@ class BertModel(MegatronModule):
return lm_logits, None return lm_logits, None
def state_dict_for_save_checkpoint(self, destination=None, prefix='', def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False): keep_vars=False):
"""For easy load when model is combined with other heads, """For easy load when model is combined with other heads,
...@@ -206,7 +207,6 @@ class BertModel(MegatronModule): ...@@ -206,7 +207,6 @@ class BertModel(MegatronModule):
= self.ict_head.state_dict(destination, prefix, keep_vars) = self.ict_head.state_dict(destination, prefix, keep_vars)
return state_dict_ return state_dict_
def load_state_dict(self, state_dict, strict=True): def load_state_dict(self, state_dict, strict=True):
"""Customized load.""" """Customized load."""
...@@ -224,8 +224,6 @@ class BertModel(MegatronModule): ...@@ -224,8 +224,6 @@ class BertModel(MegatronModule):
class REALMBertModel(MegatronModule): class REALMBertModel(MegatronModule):
# TODO: load BertModel checkpoint
def __init__(self, retriever): def __init__(self, retriever):
super(REALMBertModel, self).__init__() super(REALMBertModel, self).__init__()
bert_args = dict( bert_args = dict(
......
# coding=utf-8 # coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -53,7 +53,6 @@ class Classification(MegatronModule): ...@@ -53,7 +53,6 @@ class Classification(MegatronModule):
init_method) init_method)
self._classification_head_key = 'classification_head' self._classification_head_key = 'classification_head'
def forward(self, input_ids, attention_mask, tokentype_ids): def forward(self, input_ids, attention_mask, tokentype_ids):
extended_attention_mask = bert_extended_attention_mask( extended_attention_mask = bert_extended_attention_mask(
...@@ -74,7 +73,6 @@ class Classification(MegatronModule): ...@@ -74,7 +73,6 @@ class Classification(MegatronModule):
return classification_logits return classification_logits
def state_dict_for_save_checkpoint(self, destination=None, prefix='', def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False): keep_vars=False):
"""For easy load when model is combined with other heads, """For easy load when model is combined with other heads,
...@@ -89,7 +87,6 @@ class Classification(MegatronModule): ...@@ -89,7 +87,6 @@ class Classification(MegatronModule):
destination, prefix, keep_vars) destination, prefix, keep_vars)
return state_dict_ return state_dict_
def load_state_dict(self, state_dict, strict=True): def load_state_dict(self, state_dict, strict=True):
"""Customized load.""" """Customized load."""
......
# coding=utf-8 # coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -31,10 +31,6 @@ class DistributedDataParallel(MegatronModule): ...@@ -31,10 +31,6 @@ class DistributedDataParallel(MegatronModule):
self.module = module self.module = module
self.data_parallel_group = mpu.get_data_parallel_group() self.data_parallel_group = mpu.get_data_parallel_group()
src_rank = mpu.get_model_parallel_rank()
for p in self.module.parameters():
if torch.is_tensor(p):
dist.broadcast(p, src_rank, group=self.data_parallel_group)
def allreduce_params(reduce_after=True, no_scale=False, fp32_allreduce=False): def allreduce_params(reduce_after=True, no_scale=False, fp32_allreduce=False):
if(self.needs_reduction): if(self.needs_reduction):
...@@ -71,8 +67,8 @@ class DistributedDataParallel(MegatronModule): ...@@ -71,8 +67,8 @@ class DistributedDataParallel(MegatronModule):
def allreduce_hook(*unused): def allreduce_hook(*unused):
Variable._execution_engine.queue_callback(allreduce_params) Variable._execution_engine.queue_callback(allreduce_params)
# handle = param.register_hook(allreduce_hook) # handle = param.register_hook(allreduce_hook)
#self.hooks.append(allreduce_hook) # self.hooks.append(allreduce_hook)
#self.hook_handles.append(handle) # self.hook_handles.append(handle)
self.allreduce_params = allreduce_params self.allreduce_params = allreduce_params
def forward(self, *inputs, **kwargs): def forward(self, *inputs, **kwargs):
...@@ -114,4 +110,3 @@ class DistributedDataParallel(MegatronModule): ...@@ -114,4 +110,3 @@ class DistributedDataParallel(MegatronModule):
super(DistributedDataParallel, self).train(mode) super(DistributedDataParallel, self).train(mode)
self.module.train(mode) self.module.train(mode)
''' '''
# coding=utf-8 # coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -27,8 +27,7 @@ from .utils import scaled_init_method_normal ...@@ -27,8 +27,7 @@ from .utils import scaled_init_method_normal
def gpt2_attention_mask_func(attention_scores, ltor_mask): def gpt2_attention_mask_func(attention_scores, ltor_mask):
attention_scores = torch.mul(attention_scores, ltor_mask) - \ attention_scores.masked_fill_(ltor_mask, -10000.0)
10000.0 * (1.0 - ltor_mask)
return attention_scores return attention_scores
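Note the mask convention implied by the new line: ``masked_fill_`` writes ``-10000.0`` wherever ``ltor_mask`` is ``True``, so the mask now marks positions to suppress, whereas the old multiplicative form kept positions where the mask was 1. A small illustrative check::

    import torch

    scores = torch.zeros(1, 1, 4, 4)
    # True above the diagonal = future positions that must not be attended to.
    ltor_mask = torch.triu(torch.ones(4, 4, dtype=torch.bool), diagonal=1).view(1, 1, 4, 4)
    scores.masked_fill_(ltor_mask, -10000.0)
    # scores now holds -10000.0 strictly above the diagonal and 0.0 elsewhere.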
...@@ -49,7 +48,6 @@ class GPT2Model(MegatronModule): ...@@ -49,7 +48,6 @@ class GPT2Model(MegatronModule):
scaled_init_method=scaled_init_method_normal(args.init_method_std, scaled_init_method=scaled_init_method_normal(args.init_method_std,
args.num_layers)) args.num_layers))
def forward(self, input_ids, position_ids, attention_mask, def forward(self, input_ids, position_ids, attention_mask,
tokentype_ids=None, layer_past=None, get_key_value=False, tokentype_ids=None, layer_past=None, get_key_value=False,
forward_method_parallel_output=None): forward_method_parallel_output=None):
...@@ -79,7 +77,6 @@ class GPT2Model(MegatronModule): ...@@ -79,7 +77,6 @@ class GPT2Model(MegatronModule):
return output return output
def state_dict_for_save_checkpoint(self, destination=None, prefix='', def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False): keep_vars=False):
...@@ -89,7 +86,6 @@ class GPT2Model(MegatronModule): ...@@ -89,7 +86,6 @@ class GPT2Model(MegatronModule):
destination, prefix, keep_vars) destination, prefix, keep_vars)
return state_dict_ return state_dict_
def load_state_dict(self, state_dict, strict=True): def load_state_dict(self, state_dict, strict=True):
"""Customized load.""" """Customized load."""
......
# coding=utf-8 # coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -21,9 +21,8 @@ import torch.nn.functional as F ...@@ -21,9 +21,8 @@ import torch.nn.functional as F
from megatron import get_args from megatron import get_args
from megatron import mpu from megatron import mpu
from megatron.module import MegatronModule from megatron.module import MegatronModule
from megatron.model.transformer import ParallelTransformer from megatron.model.transformer import ParallelTransformer
from megatron.model.utils import gelu from megatron.model.utils import openai_gelu
from megatron.model.utils import get_linear_layer from megatron.model.utils import get_linear_layer
...@@ -47,7 +46,13 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, ...@@ -47,7 +46,13 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
def get_language_model(attention_mask_func, num_tokentypes, add_pooler, def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
init_method, scaled_init_method, max_pos_embeds=None): init_method, scaled_init_method, max_pos_embeds=None):
"""Build language model and return along with the key to save.""" """Build language model and return along with the key to save."""
args = get_args()
# Use torch gelu unless otherwise forced.
gelu = F.gelu
if args.openai_gelu:
gelu = openai_gelu
# Language model. # Language model.
language_model = TransformerLanguageModel( language_model = TransformerLanguageModel(
attention_mask_func=attention_mask_func, attention_mask_func=attention_mask_func,
...@@ -63,7 +68,6 @@ def get_language_model(attention_mask_func, num_tokentypes, add_pooler, ...@@ -63,7 +68,6 @@ def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
return language_model, language_model_key return language_model, language_model_key
class Pooler(MegatronModule): class Pooler(MegatronModule):
"""Pooler layer. """Pooler layer.
...@@ -75,11 +79,11 @@ class Pooler(MegatronModule): ...@@ -75,11 +79,11 @@ class Pooler(MegatronModule):
init_method: weight initialization method for the linear layer. init_method: weight initialization method for the linear layer.
bias is set to zero. bias is set to zero.
""" """
def __init__(self, hidden_size, init_method): def __init__(self, hidden_size, init_method):
super(Pooler, self).__init__() super(Pooler, self).__init__()
self.dense = get_linear_layer(hidden_size, hidden_size, init_method) self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
def forward(self, hidden_states, sequence_index=0): def forward(self, hidden_states, sequence_index=0):
# hidden_states: [b, s, h] # hidden_states: [b, s, h]
# sequence_index: index of the token to pool. # sequence_index: index of the token to pool.
...@@ -102,6 +106,7 @@ class Embedding(MegatronModule): ...@@ -102,6 +106,7 @@ class Embedding(MegatronModule):
num_tokentypes: size of the token-type embeddings. 0 value num_tokentypes: size of the token-type embeddings. 0 value
will ignore this embedding will ignore this embedding
""" """
def __init__(self, def __init__(self,
hidden_size, hidden_size,
vocab_size, vocab_size,
...@@ -143,7 +148,6 @@ class Embedding(MegatronModule): ...@@ -143,7 +148,6 @@ class Embedding(MegatronModule):
# Embeddings dropout # Embeddings dropout
self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob) self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
def add_tokentype_embeddings(self, num_tokentypes): def add_tokentype_embeddings(self, num_tokentypes):
"""Add token-type embedding. This function is provided so we can add """Add token-type embedding. This function is provided so we can add
token-type embeddings in case the pretrained model does not have it. token-type embeddings in case the pretrained model does not have it.
...@@ -160,7 +164,6 @@ class Embedding(MegatronModule): ...@@ -160,7 +164,6 @@ class Embedding(MegatronModule):
# Initialize the token-type embeddings. # Initialize the token-type embeddings.
self.init_method(self.tokentype_embeddings.weight) self.init_method(self.tokentype_embeddings.weight)
def forward(self, input_ids, position_ids, tokentype_ids=None): def forward(self, input_ids, position_ids, tokentype_ids=None):
# Embeddings. # Embeddings.
words_embeddings = self.word_embeddings(input_ids) words_embeddings = self.word_embeddings(input_ids)
...@@ -177,7 +180,6 @@ class Embedding(MegatronModule): ...@@ -177,7 +180,6 @@ class Embedding(MegatronModule):
return embeddings return embeddings
    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """For easy load."""
@@ -195,7 +197,6 @@ class Embedding(MegatronModule):
        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""
@@ -224,7 +225,7 @@ class Embedding(MegatronModule):
        self.position_embeddings.load_state_dict(state_dict_, strict=strict)

        # Tokentype embedding.
        if self.num_tokentypes > 0:
            state_dict_ = {}
            if self._tokentype_embeddings_key in state_dict:
                state_dict_ = state_dict[self._tokentype_embeddings_key]
@@ -242,7 +243,6 @@ class Embedding(MegatronModule):
                      'checkpoint but could not find it', flush=True)


class TransformerLanguageModel(MegatronModule):
    """Transformer language model.
@@ -261,6 +261,7 @@ class TransformerLanguageModel(MegatronModule):
        num_tokentypes: size of the token-type embeddings. 0 value
            will ignore this embedding
    """

    def __init__(self,
                 attention_mask_func,
                 mlp_activation_func,
@@ -298,7 +299,6 @@ class TransformerLanguageModel(MegatronModule):
            self.pooler = Pooler(self.hidden_size, self.init_method)
            self._pooler_key = 'pooler'

    def forward(self, input_ids, position_ids, attention_mask,
                tokentype_ids=None, layer_past=None, get_key_value=False,
                pooling_sequence_index=0):
@@ -320,7 +320,6 @@ class TransformerLanguageModel(MegatronModule):
        return transformer_output

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """For easy load."""
@@ -339,7 +338,6 @@ class TransformerLanguageModel(MegatronModule):
        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""
...
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -51,7 +51,6 @@ class MultipleChoice(MegatronModule):
                                                  init_method)
        self._multichoice_head_key = 'multichoice_head'

    def forward(self, input_ids, attention_mask, tokentype_ids):

        # [batch, choices, sequence] --> [batch * choices, sequence] -->
@@ -86,7 +85,6 @@ class MultipleChoice(MegatronModule):
        return multichoice_logits
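The comment above describes the key trick of the multiple-choice head: each choice is scored as an independent sequence by folding the choices into the batch dimension, and the per-sequence logits are then unfolded back to [batch, choices]. A sketch of just the reshape arithmetic, with a random tensor standing in for the encoder output (the encoder and head themselves are not part of this hunk):

import torch

batch, choices, seq_len = 2, 4, 16
input_ids = torch.randint(0, 30522, (batch, choices, seq_len))

# Fold choices into the batch dimension so the encoder sees ordinary
# [batch * choices, sequence] inputs.
flat_ids = input_ids.view(-1, seq_len)          # [8, 16]

# Pretend the encoder + head produced one score per flattened row.
flat_logits = torch.randn(flat_ids.size(0), 1)  # [8, 1]

# Unfold back to one logit per choice for the softmax over choices.
multichoice_logits = flat_logits.view(batch, choices)  # [2, 4]
print(multichoice_logits.shape)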
    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """For easy load when model is combined with other heads,
@@ -101,7 +99,6 @@ class MultipleChoice(MegatronModule):
            destination, prefix, keep_vars)
        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""
...
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -46,6 +46,7 @@ from megatron.module import MegatronModule
                             unmasked-attention-scores, attention-mask)
    """


class ParallelMLP(MegatronModule):
    """MLP.
@@ -63,7 +64,7 @@ class ParallelMLP(MegatronModule):
        # Project to 4h.
        self.dense_h_to_4h = mpu.ColumnParallelLinear(
            args.hidden_size,
            4 * args.hidden_size,
            gather_output=False,
            init_method=init_method)
@@ -71,14 +72,13 @@ class ParallelMLP(MegatronModule):
        # Project back to h.
        self.dense_4h_to_h = mpu.RowParallelLinear(
            4 * args.hidden_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method)

        self.dropout = torch.nn.Dropout(args.hidden_dropout)

    def forward(self, hidden_states):

        # [b, s, 4hp]
@@ -91,13 +91,13 @@ class ParallelMLP(MegatronModule):
        return output
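The MLP projects h to 4h, applies the activation, projects back to h, and applies dropout; under tensor parallelism the first projection is column-parallel and the second row-parallel so that only one all-reduce is needed. A single-device equivalent using plain nn.Linear (GeLU is assumed as the activation, matching the usual Megatron configuration):

import torch
import torch.nn.functional as F

class SimpleMLP(torch.nn.Module):
    """Single-device equivalent of the h -> 4h -> h transformer MLP."""
    def __init__(self, hidden_size, dropout=0.1):
        super().__init__()
        self.dense_h_to_4h = torch.nn.Linear(hidden_size, 4 * hidden_size)
        self.dense_4h_to_h = torch.nn.Linear(4 * hidden_size, hidden_size)
        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, hidden_states):
        # [b, s, h] -> [b, s, 4h] -> [b, s, h]
        intermediate = F.gelu(self.dense_h_to_4h(hidden_states))
        return self.dropout(self.dense_4h_to_h(intermediate))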
class ParallelSelfAttention(MegatronModule):
    """Parallel self-attention layer abstract class.

    Self-attention layer takes input with size [b, s, h]
    and returns output of the same size.
    """

    def __init__(self, attention_mask_func, init_method,
                 output_layer_init_method, layer_number):
        super(ParallelSelfAttention, self).__init__()
@@ -123,7 +123,7 @@ class ParallelSelfAttention(MegatronModule):
        # Strided linear layer.
        self.query_key_value = mpu.ColumnParallelLinear(
            args.hidden_size,
            3 * args.hidden_size,
            stride=3,
            gather_output=False,
            init_method=init_method)
@@ -141,18 +141,16 @@ class ParallelSelfAttention(MegatronModule):
            init_method=output_layer_init_method)
        self.output_dropout = torch.nn.Dropout(args.hidden_dropout)

    def _transpose_for_scores(self, tensor):
        """Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with
        size [b, np, s, hn].
        """
        new_tensor_shape = tensor.size()[:-1] + \
            (self.num_attention_heads_per_partition,
             self.hidden_size_per_attention_head)
        tensor = tensor.view(*new_tensor_shape)
        return tensor.permute(0, 2, 1, 3)
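The reshape splits the last dimension into (heads, head_dim) and moves the head axis ahead of the sequence axis so the per-head matmuls batch cleanly. A quick standalone shape check of the same view/permute (the concrete sizes are arbitrary examples):

import torch

b, s, np_, hn = 2, 5, 4, 16            # batch, seq, heads per partition, head dim
tensor = torch.randn(b, s, np_ * hn)   # [b, s, np*hn]

new_shape = tensor.size()[:-1] + (np_, hn)
tensor = tensor.view(*new_shape)       # [b, s, np, hn]
tensor = tensor.permute(0, 2, 1, 3)    # [b, np, s, hn]
print(tensor.shape)  # torch.Size([2, 4, 5, 16])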
    def _get_query_key_value(self, hidden_states):
        """Get query, key, and value and transpose to
        get size [b, np, s, hn].
@@ -170,7 +168,6 @@ class ParallelSelfAttention(MegatronModule):
        return query_layer, key_layer, value_layer

    def _get_unmasked_attention_scores(self, query_layer, key_layer):
        """Unmasked attention scores with size [b, np, s, s]."""
        coeff = 1
@@ -179,9 +176,8 @@ class ParallelSelfAttention(MegatronModule):
        norm_factor = math.sqrt(coeff *
                                math.sqrt(self.hidden_size_per_attention_head))
        # Raw attention scores. [b, np, s, s]
        return torch.matmul(query_layer / norm_factor,
                            key_layer.transpose(-1, -2) / norm_factor)
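Dividing query and key each by norm_factor scales the product Q K^T by 1 / norm_factor^2, i.e. by 1 / (coeff * sqrt(hn)) — the usual scaled-dot-product factor when coeff is 1 — while keeping the intermediate operands smaller, which helps in fp16. A small numerical check of that equivalence:

import math
import torch

b, np_, s, hn = 2, 4, 5, 16
q = torch.randn(b, np_, s, hn)
k = torch.randn(b, np_, s, hn)

coeff = 1
norm_factor = math.sqrt(coeff * math.sqrt(hn))

# Scaling each operand by 1/norm_factor equals scaling the product by
# 1/(coeff * sqrt(hn)), but the intermediates stay smaller in magnitude.
scores_split = torch.matmul(q / norm_factor, k.transpose(-1, -2) / norm_factor)
scores_direct = torch.matmul(q, k.transpose(-1, -2)) / (coeff * math.sqrt(hn))
print(torch.allclose(scores_split, scores_direct, atol=1e-5))  # True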
    def _get_attention_probs(self, attention_scores):
        """Attention probabilities with dropout. The output has
@@ -198,7 +194,6 @@ class ParallelSelfAttention(MegatronModule):
        return attention_probs

    def _get_attended_context(self, attention_probs, value_layer):
        """Final attended tensor, transposed back to [b, s, hp]."""
        # Context layer.
@@ -207,13 +202,12 @@ class ParallelSelfAttention(MegatronModule):
        # [b, s, np, hn]
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + \
            (self.hidden_size_per_partition,)
        # [b, s, hp]
        context_layer = context_layer.view(*new_context_layer_shape)
        return context_layer
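This step combines the attention probabilities with the values and then undoes the head split performed by _transpose_for_scores. A standalone sketch of the same shape manipulation with example sizes:

import torch

b, np_, s, hn = 2, 4, 5, 16
attention_probs = torch.softmax(torch.randn(b, np_, s, s), dim=-1)
value_layer = torch.randn(b, np_, s, hn)

# [b, np, s, s] x [b, np, s, hn] -> [b, np, s, hn]
context_layer = torch.matmul(attention_probs, value_layer)

# Undo the head split: [b, np, s, hn] -> [b, s, np, hn] -> [b, s, np*hn]
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
context_layer = context_layer.view(b, s, np_ * hn)
print(context_layer.shape)  # torch.Size([2, 5, 64])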
    def _get_output(self, context_layer):
        """Output layer with dropout."""
        # Output. [b, s, h]
@@ -222,7 +216,6 @@ class ParallelSelfAttention(MegatronModule):
        return output

    def forward(self, hidden_states, attention_mask, layer_past=None,
                get_key_value=False):
        # hidden_states: [b, s, h]
@@ -254,7 +247,7 @@ class ParallelSelfAttention(MegatronModule):
            if layer_past is not None:
                attention_mask = attention_mask[
                    ...,
                    attention_scores.size(3) - 1,
                    :attention_scores.size(3)].unsqueeze(2)
            else:
                attention_mask = attention_mask[
@@ -283,13 +276,13 @@ class ParallelSelfAttention(MegatronModule):
        return output
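When layer_past is given (incremental decoding with cached keys and values), only the newest query position is scored against all cached keys, so the mask is sliced to the single row that corresponds to that position. A hedged sketch of just that slicing arithmetic, with a toy causal mask standing in for the real attention_mask:

import torch

total_len = 8  # cached tokens + the one new token
# Toy causal mask for the full sequence: True marks positions to mask out.
full_mask = torch.triu(torch.ones(1, 1, total_len, total_len), diagonal=1).bool()

# attention_scores.size(3) equals total_len for the single new query, so we
# keep only the last row of the mask, matching that query position.
attention_scores_size3 = total_len
step_mask = full_mask[...,
                      attention_scores_size3 - 1,
                      :attention_scores_size3].unsqueeze(2)
print(step_mask.shape)  # torch.Size([1, 1, 1, 8])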
class ParallelTransformerLayer(MegatronModule):
    """A single transformer layer.

    Transformer layer takes input with size [b, s, h] and returns an
    output of the same size.
    """

    def __init__(self, attention_mask_func, mlp_activation_func,
                 init_method, output_layer_init_method, layer_number):
        args = get_args()
@@ -319,7 +312,6 @@ class ParallelTransformerLayer(MegatronModule):
        self.mlp = ParallelMLP(mlp_activation_func, init_method,
                               output_layer_init_method)

    def forward(self, hidden_states, attention_mask, layer_past=None,
                get_key_value=False):
        # hidden_states: [b, s, h]
@@ -375,14 +367,13 @@ class ParallelTransformer(MegatronModule):
        # Transformer layers.
        self.layers = torch.nn.ModuleList(
            [get_layer(i + 1) for i in range(args.num_layers)])

        # Final layer norm before output.
        self.final_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon)

    def _checkpointed_forward(self, hidden_states, attention_mask):
        """Forward method with activation checkpointing."""
        def custom(start, end):
@@ -398,13 +389,12 @@ class ParallelTransformer(MegatronModule):
        num_layers = len(self.layers)
        while l < num_layers:
            hidden_states = mpu.checkpoint(
                custom(l, l + self.checkpoint_num_layers),
                hidden_states, attention_mask)
            l += self.checkpoint_num_layers

        return hidden_states
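The loop checkpoints blocks of checkpoint_num_layers layers at a time, so only the activations at block boundaries are kept and everything inside a block is recomputed during backward. A minimal single-device sketch of the same pattern, using torch.utils.checkpoint.checkpoint as a stand-in for mpu.checkpoint (an assumption; the mpu version additionally handles the parallel RNG state):

import torch
from torch.utils.checkpoint import checkpoint  # stand-in for mpu.checkpoint

def checkpointed_forward(layers, hidden_states, attention_mask, chunk=2):
    """Run `layers` in chunks, recomputing activations inside each chunk."""
    def custom(start, end):
        def custom_forward(x, mask):
            for layer in layers[start:end]:
                x = layer(x, mask)
            return x
        return custom_forward

    l = 0
    while l < len(layers):
        hidden_states = checkpoint(custom(l, l + chunk),
                                   hidden_states, attention_mask)
        l += chunk
    return hidden_states

# Example with trivial "layers" that ignore the mask:
linears = torch.nn.ModuleList([torch.nn.Linear(8, 8) for _ in range(4)])
layers = [lambda x, m, lin=lin: torch.relu(lin(x)) for lin in linears]
out = checkpointed_forward(layers, torch.randn(2, 8, requires_grad=True), None)
print(out.shape)  # torch.Size([2, 8])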
    def forward(self, hidden_states, attention_mask, layer_past=None,
                get_key_value=False):
...
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -33,6 +33,7 @@ def init_method_normal(sigma):
def scaled_init_method_normal(sigma, num_layers):
    """Init method based on N(0, sigma/sqrt(2*num_layers))."""
    std = sigma / math.sqrt(2.0 * num_layers)

    def init_(tensor):
        return torch.nn.init.normal_(tensor, mean=0.0, std=std)
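This is the GPT-2-style depth-scaled initialization: weights on the residual branches are drawn from a normal distribution whose standard deviation shrinks as 1/sqrt(2 * num_layers), keeping the residual stream's variance roughly constant with depth. A quick numeric check of the resulting standard deviation:

import math
import torch

def scaled_init_method_normal(sigma, num_layers):
    std = sigma / math.sqrt(2.0 * num_layers)
    def init_(tensor):
        return torch.nn.init.normal_(tensor, mean=0.0, std=std)
    return init_

# With sigma=0.02 and 24 layers, the effective std is 0.02 / sqrt(48) ≈ 0.00289.
init_fn = scaled_init_method_normal(0.02, 24)
w = torch.empty(1024, 1024)
init_fn(w)
print(w.std().item())  # close to 0.00289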
@@ -53,8 +54,7 @@ def gelu_impl(x):
    """OpenAI's gelu implementation."""
    return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x *
                                       (1.0 + 0.044715 * x * x)))
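The constant 0.7978845608028654 is sqrt(2/pi), so this is the tanh approximation of GeLU. A quick comparison against torch.nn.functional.gelu (the erf-based definition) shows the two agree closely; the exact tolerance below is an empirical observation, not a documented guarantee:

import torch

def gelu_impl(x):
    # 0.7978845608028654 == sqrt(2 / pi)
    return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x *
                                       (1.0 + 0.044715 * x * x)))

x = torch.linspace(-5, 5, steps=101)
approx = gelu_impl(x)
exact = torch.nn.functional.gelu(x)          # erf-based definition
print((approx - exact).abs().max().item())   # on the order of 1e-3 or smaller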
def openai_gelu(x):
def gelu(x):
    return gelu_impl(x)
...
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,11 +21,9 @@ import torch
class MegatronModule(torch.nn.Module):
    """Megatron-specific extensions of torch Module."""

    def __init__(self):
        super(MegatronModule, self).__init__()

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """Use this function to override the state dict for
...