Commit 5e56e563 authored by Neel Kant's avatar Neel Kant
Browse files

Merge master into realm-mlm

parents 6c0a5bd8 569b3dab
...@@ -27,7 +27,8 @@ try: ...@@ -27,7 +27,8 @@ try:
from functools import lru_cache from functools import lru_cache
except ImportError: except ImportError:
# Just a dummy decorator to get the checks to run on python2 # Just a dummy decorator to get the checks to run on python2
# because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now. # because honestly I don't want to support a byte-level unicode BPE
# tokenizer on python 2 right now.
def lru_cache(): def lru_cache():
return lambda func: func return lambda func: func
...@@ -48,6 +49,7 @@ VOCAB_NAME = 'vocab.json' ...@@ -48,6 +49,7 @@ VOCAB_NAME = 'vocab.json'
MERGES_NAME = 'merges.txt' MERGES_NAME = 'merges.txt'
SPECIAL_TOKENS_NAME = 'special_tokens.txt' SPECIAL_TOKENS_NAME = 'special_tokens.txt'
@lru_cache() @lru_cache()
def bytes_to_unicode(): def bytes_to_unicode():
""" """
...@@ -60,17 +62,19 @@ def bytes_to_unicode(): ...@@ -60,17 +62,19 @@ def bytes_to_unicode():
And avoids mapping to whitespace/control characters the bpe code barfs on. And avoids mapping to whitespace/control characters the bpe code barfs on.
""" """
_chr = unichr if sys.version_info[0] == 2 else chr _chr = unichr if sys.version_info[0] == 2 else chr
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + \
list(range(ord("®"), ord("ÿ") + 1))
cs = bs[:] cs = bs[:]
n = 0 n = 0
for b in range(2**8): for b in range(2**8):
if b not in bs: if b not in bs:
bs.append(b) bs.append(b)
cs.append(2**8+n) cs.append(2**8 + n)
n += 1 n += 1
cs = [_chr(n) for n in cs] cs = [_chr(n) for n in cs]
return dict(zip(bs, cs)) return dict(zip(bs, cs))
def get_pairs(word): def get_pairs(word):
"""Return set of symbol pairs in a word. """Return set of symbol pairs in a word.
...@@ -83,6 +87,7 @@ def get_pairs(word): ...@@ -83,6 +87,7 @@ def get_pairs(word):
prev_char = char prev_char = char
return pairs return pairs
class GPT2Tokenizer(object): class GPT2Tokenizer(object):
""" """
GPT-2 BPE tokenizer. Peculiarities: GPT-2 BPE tokenizer. Peculiarities:
...@@ -138,23 +143,31 @@ class GPT2Tokenizer(object): ...@@ -138,23 +143,31 @@ class GPT2Tokenizer(object):
special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1] special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
else: else:
special_tokens = kwargs.pop('special_tokens', []) special_tokens = kwargs.pop('special_tokens', [])
tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs) tokenizer = cls(
resolved_vocab_file,
resolved_merges_file,
special_tokens=special_tokens,
*inputs,
**kwargs)
return tokenizer return tokenizer
def __init__(self, vocab_file, merges_file, errors='replace', special_tokens=None, max_len=None): def __init__(self, vocab_file, merges_file, errors='replace',
special_tokens=None, max_len=None):
self.max_len = max_len if max_len is not None else int(1e12) self.max_len = max_len if max_len is not None else int(1e12)
self.encoder = json.load(open(vocab_file)) self.encoder = json.load(open(vocab_file))
self.decoder = {v:k for k,v in self.encoder.items()} self.decoder = {v: k for k, v in self.encoder.items()}
self.errors = errors # how to handle errors in decoding self.errors = errors # how to handle errors in decoding
self.byte_encoder = bytes_to_unicode() self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v:k for k, v in self.byte_encoder.items()} self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1] bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
bpe_merges = [tuple(merge.split()) for merge in bpe_data] bpe_merges = [tuple(merge.split()) for merge in bpe_data]
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
self.cache = {} self.cache = {}
# Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions # Should haved added re.IGNORECASE so BPE merges can happen for
self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") # capitalized versions of contractions
self.pat = re.compile(
r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
self.special_tokens = {} self.special_tokens = {}
self.special_tokens_decoder = {} self.special_tokens_decoder = {}
...@@ -172,8 +185,9 @@ class GPT2Tokenizer(object): ...@@ -172,8 +185,9 @@ class GPT2Tokenizer(object):
self.special_tokens = {} self.special_tokens = {}
self.special_tokens_decoder = {} self.special_tokens_decoder = {}
return return
self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens)) self.special_tokens = dict((tok, len(self.encoder) + i)
self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()} for i, tok in enumerate(special_tokens))
self.special_tokens_decoder = {v: k for k, v in self.special_tokens.items()}
logger.info("Special tokens {}".format(self.special_tokens)) logger.info("Special tokens {}".format(self.special_tokens))
def bpe(self, token): def bpe(self, token):
...@@ -186,7 +200,7 @@ class GPT2Tokenizer(object): ...@@ -186,7 +200,7 @@ class GPT2Tokenizer(object):
return token return token
while True: while True:
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
if bigram not in self.bpe_ranks: if bigram not in self.bpe_ranks:
break break
first, second = bigram first, second = bigram
...@@ -197,12 +211,12 @@ class GPT2Tokenizer(object): ...@@ -197,12 +211,12 @@ class GPT2Tokenizer(object):
j = word.index(first, i) j = word.index(first, i)
new_word.extend(word[i:j]) new_word.extend(word[i:j])
i = j i = j
except: except BaseException:
new_word.extend(word[i:]) new_word.extend(word[i:])
break break
if word[i] == first and i < len(word)-1 and word[i+1] == second: if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
new_word.append(first+second) new_word.append(first + second)
i += 2 i += 2
else: else:
new_word.append(word[i]) new_word.append(word[i])
...@@ -245,7 +259,8 @@ class GPT2Tokenizer(object): ...@@ -245,7 +259,8 @@ class GPT2Tokenizer(object):
logger.warning( logger.warning(
"Token indices sequence length is longer than the specified maximum " "Token indices sequence length is longer than the specified maximum "
" sequence length for this OpenAI GPT model ({} > {}). Running this" " sequence length for this OpenAI GPT model ({} > {}). Running this"
" sequence through the model will result in indexing errors".format(len(ids), self.max_len) " sequence through the model will result in indexing errors".format(
len(ids), self.max_len)
) )
return ids return ids
......
...@@ -99,19 +99,19 @@ class BertTokenizer(object): ...@@ -99,19 +99,19 @@ class BertTokenizer(object):
[(ids, tok) for tok, ids in self.vocab.items()]) [(ids, tok) for tok, ids in self.vocab.items()])
self.do_basic_tokenize = do_basic_tokenize self.do_basic_tokenize = do_basic_tokenize
if do_basic_tokenize: if do_basic_tokenize:
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case, self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case,
never_split=never_split) never_split=never_split)
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
self.max_len = max_len if max_len is not None else int(1e12) self.max_len = max_len if max_len is not None else int(1e12)
def tokenize(self, text): def tokenize(self, text):
if self.do_basic_tokenize: if self.do_basic_tokenize:
split_tokens = [] split_tokens = []
for token in self.basic_tokenizer.tokenize(text): for token in self.basic_tokenizer.tokenize(text):
for sub_token in self.wordpiece_tokenizer.tokenize(token): for sub_token in self.wordpiece_tokenizer.tokenize(token):
split_tokens.append(sub_token) split_tokens.append(sub_token)
else: else:
split_tokens = self.wordpiece_tokenizer.tokenize(text) split_tokens = self.wordpiece_tokenizer.tokenize(text)
return split_tokens return split_tokens
def convert_tokens_to_ids(self, tokens): def convert_tokens_to_ids(self, tokens):
...@@ -123,7 +123,8 @@ class BertTokenizer(object): ...@@ -123,7 +123,8 @@ class BertTokenizer(object):
logger.warning( logger.warning(
"Token indices sequence length is longer than the specified maximum " "Token indices sequence length is longer than the specified maximum "
" sequence length for this BERT model ({} > {}). Running this" " sequence length for this BERT model ({} > {}). Running this"
" sequence through BERT will result in indexing errors".format(len(ids), self.max_len) " sequence through BERT will result in indexing errors".format(
len(ids), self.max_len)
) )
return ids return ids
......
# coding=utf-8 # coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
This diff is collapsed.
# coding=utf-8 # coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -18,6 +18,9 @@ import torch.nn as nn ...@@ -18,6 +18,9 @@ import torch.nn as nn
from torch.autograd import Variable from torch.autograd import Variable
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from apex.multi_tensor_apply import multi_tensor_applier
import amp_C
from megatron import mpu from megatron import mpu
...@@ -102,6 +105,7 @@ class FP16Model(nn.Module): ...@@ -102,6 +105,7 @@ class FP16Model(nn.Module):
def backwards_debug_hook(grad): def backwards_debug_hook(grad):
raise RuntimeError("master_params recieved a gradient in the backward pass!") raise RuntimeError("master_params recieved a gradient in the backward pass!")
def prep_param_lists(model, flat_master=False): def prep_param_lists(model, flat_master=False):
""" """
Creates a list of FP32 master parameters for a given model, as in Creates a list of FP32 master parameters for a given model, as in
...@@ -131,9 +135,9 @@ def prep_param_lists(model, flat_master=False): ...@@ -131,9 +135,9 @@ def prep_param_lists(model, flat_master=False):
# flatten_dense_tensors returns a contiguous flat array. # flatten_dense_tensors returns a contiguous flat array.
# http://pytorch.org/docs/master/_modules/torch/_utils.html # http://pytorch.org/docs/master/_modules/torch/_utils.html
master_params = _flatten_dense_tensors([param.data for param in model_params]).float() master_params = _flatten_dense_tensors([param.data for param in model_params]).float()
except: except BaseException:
print("Error in prep_param_lists: model may contain a mixture of parameters " print("Error in prep_param_lists: model may contain a mixture of parameters "
"of different types. Use flat_master=False, or use F16_Optimizer.") "of different types. Use flat_master=False, or use F16_Optimizer.")
raise raise
master_params = torch.nn.Parameter(master_params) master_params = torch.nn.Parameter(master_params)
master_params.requires_grad = True master_params.requires_grad = True
...@@ -150,7 +154,7 @@ def prep_param_lists(model, flat_master=False): ...@@ -150,7 +154,7 @@ def prep_param_lists(model, flat_master=False):
def model_grads_to_master_grads(model_params, master_params, flat_master=False): def model_grads_to_master_grads(model_params, master_params, flat_master=False):
""" """
Copy model gradients to master gradients. Copy model gradients to master gradients.
Args: Args:
model_params: List of model parameters created by :func:`prep_param_lists`. model_params: List of model parameters created by :func:`prep_param_lists`.
...@@ -165,9 +169,15 @@ def model_grads_to_master_grads(model_params, master_params, flat_master=False): ...@@ -165,9 +169,15 @@ def model_grads_to_master_grads(model_params, master_params, flat_master=False):
if model.grad is not None: if model.grad is not None:
if master.grad is None: if master.grad is None:
master.grad = Variable(master.data.new(*master.data.size())) master.grad = Variable(master.data.new(*master.data.size()))
master.grad.data.copy_(model.grad.data)
else: else:
master.grad = None master.grad = None
model_grads = [p.grad for p in model_params if p.grad is not None]
master_grads = [p.grad for p in master_params if p.grad is not None]
_overflow_buf = torch.cuda.IntTensor([0])
multi_tensor_applier(amp_C.multi_tensor_scale,
_overflow_buf,
[model_grads, master_grads],
1.0)
def master_params_to_model_params(model_params, master_params, flat_master=False): def master_params_to_model_params(model_params, master_params, flat_master=False):
...@@ -179,7 +189,7 @@ def master_params_to_model_params(model_params, master_params, flat_master=False ...@@ -179,7 +189,7 @@ def master_params_to_model_params(model_params, master_params, flat_master=False
master_params: List of FP32 master parameters created by :func:`prep_param_lists`. If ``master_params`` was created with ``flat_master=True``, ``flat_master=True`` should also be supplied to :func:`master_params_to_model_params`. master_params: List of FP32 master parameters created by :func:`prep_param_lists`. If ``master_params`` was created with ``flat_master=True``, ``flat_master=True`` should also be supplied to :func:`master_params_to_model_params`.
""" """
if flat_master: if flat_master:
for model, master in zip(model_params, for model, master in zip(model_params,
_unflatten_dense_tensors(master_params[0].data, model_params)): _unflatten_dense_tensors(master_params[0].data, model_params)):
model.data.copy_(master) model.data.copy_(master)
else: else:
...@@ -188,17 +198,19 @@ def master_params_to_model_params(model_params, master_params, flat_master=False ...@@ -188,17 +198,19 @@ def master_params_to_model_params(model_params, master_params, flat_master=False
# Backward compatibility fixes # Backward compatibility fixes
def to_python_float(t): def to_python_float(t):
if hasattr(t, 'item'): if hasattr(t, 'item'):
return t.item() return t.item()
else: else:
return t[0] return t[0]
TORCH_MAJOR = int(torch.__version__.split('.')[0]) TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1]) TORCH_MINOR = int(torch.__version__.split('.')[1])
clip_grad_norm = mpu.clip_grad_norm clip_grad_norm = mpu.clip_grad_norm
#elif TORCH_MAJOR == 0 and TORCH_MINOR <= 4: # elif TORCH_MAJOR == 0 and TORCH_MINOR <= 4:
# clip_grad_norm = torch.nn.utils.clip_grad_norm # clip_grad_norm = torch.nn.utils.clip_grad_norm
#else: # else:
# clip_grad_norm = torch.nn.utils.clip_grad_norm_ # clip_grad_norm = torch.nn.utils.clip_grad_norm_
# coding=utf-8 # coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -14,21 +14,28 @@ ...@@ -14,21 +14,28 @@
# limitations under the License. # limitations under the License.
import torch import torch
from apex.multi_tensor_apply import multi_tensor_applier
import amp_C
from megatron import mpu from megatron import mpu
# item() is a recent addition, so this helps with backward compatibility. # item() is a recent addition, so this helps with backward compatibility.
def to_python_float(t): def to_python_float(t):
if hasattr(t, 'item'): if hasattr(t, 'item'):
return t.item() return t.item()
else: else:
return t[0] return t[0]
class LossScaler: class LossScaler:
""" """
Class that manages a static loss scale. This class is intended to interact with Class that manages a static loss scale. This class is intended to interact with
:class:`FP16_Optimizer`, and should not be directly manipulated by the user. :class:`FP16_Optimizer`, and should not be directly manipulated by the user.
Use of :class:`LossScaler` is enabled via the ``static_loss_scale`` argument to Use of :class:`LossScaler` is enabled via the ``static_loss_scale`` argument to
:class:`FP16_Optimizer`'s constructor. :class:`FP16_Optimizer`'s constructor.
Args: Args:
...@@ -54,16 +61,22 @@ class LossScaler: ...@@ -54,16 +61,22 @@ class LossScaler:
return self.cur_scale return self.cur_scale
def scale_gradient(self, module, grad_in, grad_out): def scale_gradient(self, module, grad_in, grad_out):
return tuple(self.loss_scale * g for g in grad_in) _overflow_buf = torch.cuda.IntTensor([0])
multi_tensor_applier(amp_C.multi_tensor_scale,
_overflow_buf,
[grad_in, grad_in],
self.loss_scale)
return grad_in
def backward(self, loss, retain_graph=False): def backward(self, loss, retain_graph=False):
scaled_loss = loss*self.loss_scale scaled_loss = loss * self.loss_scale
scaled_loss.backward(retain_graph=retain_graph) scaled_loss.backward(retain_graph=retain_graph)
class DynamicLossScaler: class DynamicLossScaler:
""" """
Class that manages dynamic loss scaling. It is recommended to use :class:`DynamicLossScaler` Class that manages dynamic loss scaling. It is recommended to use :class:`DynamicLossScaler`
indirectly, by supplying ``dynamic_loss_scale=True`` to the constructor of indirectly, by supplying ``dynamic_loss_scale=True`` to the constructor of
:class:`FP16_Optimizer`. However, it's important to understand how :class:`DynamicLossScaler` :class:`FP16_Optimizer`. However, it's important to understand how :class:`DynamicLossScaler`
operates, because the default options can be changed using the operates, because the default options can be changed using the
the ``dynamic_loss_args`` argument to :class:`FP16_Optimizer`'s constructor. the ``dynamic_loss_args`` argument to :class:`FP16_Optimizer`'s constructor.
...@@ -71,18 +84,18 @@ class DynamicLossScaler: ...@@ -71,18 +84,18 @@ class DynamicLossScaler:
Loss scaling is designed to combat the problem of underflowing gradients encountered at long Loss scaling is designed to combat the problem of underflowing gradients encountered at long
times when training fp16 networks. Dynamic loss scaling begins by attempting a very high loss times when training fp16 networks. Dynamic loss scaling begins by attempting a very high loss
scale. Ironically, this may result in OVERflowing gradients. If overflowing gradients are scale. Ironically, this may result in OVERflowing gradients. If overflowing gradients are
encountered, :class:`DynamicLossScaler` informs :class:`FP16_Optimizer` that an overflow has encountered, :class:`DynamicLossScaler` informs :class:`FP16_Optimizer` that an overflow has
occurred. occurred.
:class:`FP16_Optimizer` then skips the update step for this particular iteration/minibatch, :class:`FP16_Optimizer` then skips the update step for this particular iteration/minibatch,
and :class:`DynamicLossScaler` adjusts the loss scale to a lower value. and :class:`DynamicLossScaler` adjusts the loss scale to a lower value.
If a certain number of iterations occur without overflowing gradients detected, If a certain number of iterations occur without overflowing gradients detected,
:class:`DynamicLossScaler` increases the loss scale once more. :class:`DynamicLossScaler` increases the loss scale once more.
In this way :class:`DynamicLossScaler` attempts to "ride the edge" of In this way :class:`DynamicLossScaler` attempts to "ride the edge" of
always using the highest loss scale possible without incurring overflow. always using the highest loss scale possible without incurring overflow.
Args: Args:
init_scale (float, optional, default=2**32): Initial loss scale attempted by :class:`DynamicLossScaler.` init_scale (float, optional, default=2**32): Initial loss scale attempted by :class:`DynamicLossScaler.`
scale_factor (float, optional, default=2.0): Factor used when adjusting the loss scale. If an overflow is encountered, the loss scale is readjusted to loss scale/``scale_factor``. If ``scale_window`` consecutive iterations take place without an overflow, the loss scale is readjusted to loss_scale*``scale_factor``. scale_factor (float, optional, default=2.0): Factor used when adjusting the loss scale. If an overflow is encountered, the loss scale is readjusted to loss scale/``scale_factor``. If ``scale_window`` consecutive iterations take place without an overflow, the loss scale is readjusted to loss_scale*``scale_factor``.
scale_window (int, optional, default=1000): Number of consecutive iterations without an overflow to wait before increasing the loss scale. scale_window (int, optional, default=1000): Number of consecutive iterations without an overflow to wait before increasing the loss scale.
""" """
...@@ -122,12 +135,12 @@ class DynamicLossScaler: ...@@ -122,12 +135,12 @@ class DynamicLossScaler:
overflow = overflow_gpu[0].item() overflow = overflow_gpu[0].item()
return bool(overflow) return bool(overflow)
# `x` is a torch.Tensor # `x` is a torch.Tensor
def _has_inf_or_nan(x): def _has_inf_or_nan(x):
try: try:
# if x is half, the .float() incurs an additional deep copy, but it's necessary if # if x is half, the .float() incurs an additional deep copy, but it's necessary if
# Pytorch's .sum() creates a one-element tensor of the same type as x # Pytorch's .sum() creates a one-element tensor of the same type as x
# (which is true for some recent version of pytorch). # (which is true for some recent version of pytorch).
cpu_sum = float(x.float().sum()) cpu_sum = float(x.float().sum())
# More efficient version that can be used if .sum() returns a Python scalar # More efficient version that can be used if .sum() returns a Python scalar
...@@ -158,7 +171,7 @@ class DynamicLossScaler: ...@@ -158,7 +171,7 @@ class DynamicLossScaler:
if overflow: if overflow:
# self.cur_scale /= self.scale_factor # self.cur_scale /= self.scale_factor
if self.delayed_shift == 1 or self.cur_hysteresis == 1: if self.delayed_shift == 1 or self.cur_hysteresis == 1:
self.cur_scale = max(self.cur_scale/self.scale_factor, self.min_scale) self.cur_scale = max(self.cur_scale / self.scale_factor, self.min_scale)
else: else:
self.cur_hysteresis -= 1 self.cur_hysteresis -= 1
self.last_overflow_iter = self.cur_iter self.last_overflow_iter = self.cur_iter
...@@ -176,13 +189,19 @@ class DynamicLossScaler: ...@@ -176,13 +189,19 @@ class DynamicLossScaler:
return self.cur_scale return self.cur_scale
def scale_gradient(self, module, grad_in, grad_out): def scale_gradient(self, module, grad_in, grad_out):
return tuple(self.loss_scale * g for g in grad_in) _overflow_buf = torch.cuda.IntTensor([0])
multi_tensor_applier(amp_C.multi_tensor_scale,
_overflow_buf,
[grad_in, grad_in],
self.loss_scale)
return grad_in
def backward(self, loss, retain_graph=False): def backward(self, loss, retain_graph=False):
scaled_loss = loss*self.loss_scale scaled_loss = loss * self.loss_scale
scaled_loss.backward(retain_graph=retain_graph) scaled_loss.backward(retain_graph=retain_graph)
##############################################################
##############################################################
# Example usage below here -- assuming it's in a separate file # Example usage below here -- assuming it's in a separate file
############################################################## ##############################################################
""" """
...@@ -218,10 +237,10 @@ if __name__ == "__main__": ...@@ -218,10 +237,10 @@ if __name__ == "__main__":
# Run backprop # Run backprop
optimizer.zero_grad() optimizer.zero_grad()
loss.backward() loss.backward()
# Check for overflow # Check for overflow
has_overflow = DynamicLossScaler.has_overflow(parameters) has_overflow = DynamicLossScaler.has_overflow(parameters)
# If no overflow, unscale grad and update as usual # If no overflow, unscale grad and update as usual
if not has_overflow: if not has_overflow:
for param in parameters: for param in parameters:
......
# coding=utf-8 # coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -61,22 +61,26 @@ def get_timers(): ...@@ -61,22 +61,26 @@ def get_timers():
return _GLOBAL_TIMERS return _GLOBAL_TIMERS
def set_global_variables(extra_args_provider=None, args_defaults={}): def set_global_variables(extra_args_provider=None, args_defaults={},
ignore_unknown_args=False):
"""Set args, tokenizer, tensorboard-writer, adlr-autoresume, and timers.""" """Set args, tokenizer, tensorboard-writer, adlr-autoresume, and timers."""
args = _parse_args(extra_args_provider=extra_args_provider, args = _parse_args(extra_args_provider=extra_args_provider,
defaults=args_defaults) defaults=args_defaults,
ignore_unknown_args=ignore_unknown_args)
_ = _build_tokenizer(args) _ = _build_tokenizer(args)
_set_tensorboard_writer(args) _set_tensorboard_writer(args)
_set_adlr_autoresume(args) _set_adlr_autoresume(args)
_set_timers() _set_timers()
def _parse_args(extra_args_provider=None, defaults={}): def _parse_args(extra_args_provider=None, defaults={},
ignore_unknown_args=False):
"""Parse entire arguments.""" """Parse entire arguments."""
global _GLOBAL_ARGS global _GLOBAL_ARGS
_ensure_var_is_not_initialized(_GLOBAL_ARGS, 'args') _ensure_var_is_not_initialized(_GLOBAL_ARGS, 'args')
_GLOBAL_ARGS = parse_args(extra_args_provider=extra_args_provider, _GLOBAL_ARGS = parse_args(extra_args_provider=extra_args_provider,
defaults=defaults) defaults=defaults,
ignore_unknown_args=ignore_unknown_args)
return _GLOBAL_ARGS return _GLOBAL_ARGS
...@@ -124,7 +128,7 @@ def _set_adlr_autoresume(args): ...@@ -124,7 +128,7 @@ def _set_adlr_autoresume(args):
sys.path.append(os.environ.get('SUBMIT_SCRIPTS', '.')) sys.path.append(os.environ.get('SUBMIT_SCRIPTS', '.'))
try: try:
from userlib.auto_resume import AutoResume from userlib.auto_resume import AutoResume
except: except BaseException:
print('ADLR autoresume is not available, exiting ...') print('ADLR autoresume is not available, exiting ...')
sys.exit() sys.exit()
......
# coding=utf-8 # coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -28,7 +28,8 @@ from megatron import mpu ...@@ -28,7 +28,8 @@ from megatron import mpu
from megatron.global_vars import set_global_variables from megatron.global_vars import set_global_variables
def initialize_megatron(extra_args_provider=None, args_defaults={}): def initialize_megatron(extra_args_provider=None, args_defaults={},
ignore_unknown_args=False):
"""Set global variables, initialize distributed, and """Set global variables, initialize distributed, and
set autoresume and random seeds.""" set autoresume and random seeds."""
# Make sure cuda is available. # Make sure cuda is available.
...@@ -37,7 +38,8 @@ def initialize_megatron(extra_args_provider=None, args_defaults={}): ...@@ -37,7 +38,8 @@ def initialize_megatron(extra_args_provider=None, args_defaults={}):
# Parse args, build tokenizer, and set adlr-autoresume, # Parse args, build tokenizer, and set adlr-autoresume,
# tensorboard-writer, and timers. # tensorboard-writer, and timers.
set_global_variables(extra_args_provider=extra_args_provider, set_global_variables(extra_args_provider=extra_args_provider,
args_defaults=args_defaults) args_defaults=args_defaults,
ignore_unknown_args=ignore_unknown_args)
# Pytorch distributed. # Pytorch distributed.
_initialize_distributed() _initialize_distributed()
......
# coding=utf-8 # coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -48,7 +48,6 @@ class AnnealingLR(object): ...@@ -48,7 +48,6 @@ class AnnealingLR(object):
print_rank_0('> learning rate decay style: {}'.format(self.decay_style)) print_rank_0('> learning rate decay style: {}'.format(self.decay_style))
def get_lr(self): def get_lr(self):
"""Learning rate decay functions from: """Learning rate decay functions from:
https://openreview.net/pdf?id=BJYwwY9ll pg. 4""" https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""
...@@ -71,7 +70,6 @@ class AnnealingLR(object): ...@@ -71,7 +70,6 @@ class AnnealingLR(object):
lr = self.start_lr lr = self.start_lr
return max(lr, self.min_lr) return max(lr, self.min_lr)
def step(self, step_num=None): def step(self, step_num=None):
"""Set lr for all parameters groups.""" """Set lr for all parameters groups."""
if step_num is None: if step_num is None:
...@@ -81,7 +79,6 @@ class AnnealingLR(object): ...@@ -81,7 +79,6 @@ class AnnealingLR(object):
for group in self.optimizer.param_groups: for group in self.optimizer.param_groups:
group['lr'] = new_lr group['lr'] = new_lr
def state_dict(self): def state_dict(self):
state_dict = { state_dict = {
'start_lr': self.start_lr, 'start_lr': self.start_lr,
...@@ -93,7 +90,6 @@ class AnnealingLR(object): ...@@ -93,7 +90,6 @@ class AnnealingLR(object):
} }
return state_dict return state_dict
def _check_and_set(self, cls_value, sd_value, name): def _check_and_set(self, cls_value, sd_value, name):
"""Auxiliary function for checking the values in the checkpoint and """Auxiliary function for checking the values in the checkpoint and
setting them.""" setting them."""
...@@ -108,7 +104,6 @@ class AnnealingLR(object): ...@@ -108,7 +104,6 @@ class AnnealingLR(object):
name)) name))
return sd_value return sd_value
def load_state_dict(self, sd): def load_state_dict(self, sd):
self.start_lr = self._check_and_set(self.start_lr, sd['start_lr'], self.start_lr = self._check_and_set(self.start_lr, sd['start_lr'],
......
# coding=utf-8 # coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# coding=utf-8 # coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -22,16 +22,15 @@ import torch ...@@ -22,16 +22,15 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from megatron import get_args from megatron import get_args
from megatron.model.language_model import parallel_lm_logits
from megatron.model.language_model import get_language_model
from megatron.model.transformer import LayerNorm
from megatron.model.utils import openai_gelu
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal
from megatron.module import MegatronModule from megatron.module import MegatronModule
from .language_model import parallel_lm_logits
from .language_model import get_language_model
from .transformer import LayerNorm
from .utils import gelu
from .utils import get_linear_layer
from .utils import init_method_normal
from .utils import scaled_init_method_normal
def bert_attention_mask_func(attention_scores, attention_mask): def bert_attention_mask_func(attention_scores, attention_mask):
attention_scores = attention_scores + attention_mask attention_scores = attention_scores + attention_mask
...@@ -70,7 +69,6 @@ def bert_position_ids(token_ids): ...@@ -70,7 +69,6 @@ def bert_position_ids(token_ids):
return position_ids return position_ids
class BertLMHead(MegatronModule): class BertLMHead(MegatronModule):
"""Masked LM head for Bert """Masked LM head for Bert
...@@ -81,11 +79,14 @@ class BertLMHead(MegatronModule): ...@@ -81,11 +79,14 @@ class BertLMHead(MegatronModule):
layernorm_epsilon: tolerance for layer norm divisions layernorm_epsilon: tolerance for layer norm divisions
parallel_output: whether output logits being distributed or not. parallel_output: whether output logits being distributed or not.
""" """
def __init__(self, mpu_vocab_size, hidden_size, init_method, def __init__(self, mpu_vocab_size, hidden_size, init_method,
layernorm_epsilon, parallel_output): layernorm_epsilon, parallel_output):
super(BertLMHead, self).__init__() super(BertLMHead, self).__init__()
args = get_args()
self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size)) self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
self.bias.model_parallel = True self.bias.model_parallel = True
self.bias.partition_dim = 0 self.bias.partition_dim = 0
...@@ -94,11 +95,13 @@ class BertLMHead(MegatronModule): ...@@ -94,11 +95,13 @@ class BertLMHead(MegatronModule):
self.dense = get_linear_layer(hidden_size, hidden_size, init_method) self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
self.layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon) self.layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
self.gelu = torch.nn.functional.gelu
if args.openai_gelu:
self.gelu = openai_gelu
def forward(self, hidden_states, word_embeddings_weight): def forward(self, hidden_states, word_embeddings_weight):
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
hidden_states = gelu(hidden_states) hidden_states = self.gelu(hidden_states)
hidden_states = self.layernorm(hidden_states) hidden_states = self.layernorm(hidden_states)
output = parallel_lm_logits(hidden_states, output = parallel_lm_logits(hidden_states,
word_embeddings_weight, word_embeddings_weight,
...@@ -107,7 +110,6 @@ class BertLMHead(MegatronModule): ...@@ -107,7 +110,6 @@ class BertLMHead(MegatronModule):
return output return output
class BertModel(MegatronModule): class BertModel(MegatronModule):
"""Bert Language model.""" """Bert Language model."""
...@@ -184,7 +186,6 @@ class BertModel(MegatronModule): ...@@ -184,7 +186,6 @@ class BertModel(MegatronModule):
return lm_logits, None return lm_logits, None
def state_dict_for_save_checkpoint(self, destination=None, prefix='', def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False): keep_vars=False):
"""For easy load when model is combined with other heads, """For easy load when model is combined with other heads,
...@@ -206,7 +207,6 @@ class BertModel(MegatronModule): ...@@ -206,7 +207,6 @@ class BertModel(MegatronModule):
= self.ict_head.state_dict(destination, prefix, keep_vars) = self.ict_head.state_dict(destination, prefix, keep_vars)
return state_dict_ return state_dict_
def load_state_dict(self, state_dict, strict=True): def load_state_dict(self, state_dict, strict=True):
"""Customized load.""" """Customized load."""
...@@ -224,8 +224,6 @@ class BertModel(MegatronModule): ...@@ -224,8 +224,6 @@ class BertModel(MegatronModule):
class REALMBertModel(MegatronModule): class REALMBertModel(MegatronModule):
# TODO: load BertModel checkpoint
def __init__(self, retriever): def __init__(self, retriever):
super(REALMBertModel, self).__init__() super(REALMBertModel, self).__init__()
bert_args = dict( bert_args = dict(
......
# coding=utf-8 # coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -53,7 +53,6 @@ class Classification(MegatronModule): ...@@ -53,7 +53,6 @@ class Classification(MegatronModule):
init_method) init_method)
self._classification_head_key = 'classification_head' self._classification_head_key = 'classification_head'
def forward(self, input_ids, attention_mask, tokentype_ids): def forward(self, input_ids, attention_mask, tokentype_ids):
extended_attention_mask = bert_extended_attention_mask( extended_attention_mask = bert_extended_attention_mask(
...@@ -74,7 +73,6 @@ class Classification(MegatronModule): ...@@ -74,7 +73,6 @@ class Classification(MegatronModule):
return classification_logits return classification_logits
def state_dict_for_save_checkpoint(self, destination=None, prefix='', def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False): keep_vars=False):
"""For easy load when model is combined with other heads, """For easy load when model is combined with other heads,
...@@ -89,7 +87,6 @@ class Classification(MegatronModule): ...@@ -89,7 +87,6 @@ class Classification(MegatronModule):
destination, prefix, keep_vars) destination, prefix, keep_vars)
return state_dict_ return state_dict_
def load_state_dict(self, state_dict, strict=True): def load_state_dict(self, state_dict, strict=True):
"""Customized load.""" """Customized load."""
......
# coding=utf-8 # coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -31,10 +31,6 @@ class DistributedDataParallel(MegatronModule): ...@@ -31,10 +31,6 @@ class DistributedDataParallel(MegatronModule):
self.module = module self.module = module
self.data_parallel_group = mpu.get_data_parallel_group() self.data_parallel_group = mpu.get_data_parallel_group()
src_rank = mpu.get_model_parallel_rank()
for p in self.module.parameters():
if torch.is_tensor(p):
dist.broadcast(p, src_rank, group=self.data_parallel_group)
def allreduce_params(reduce_after=True, no_scale=False, fp32_allreduce=False): def allreduce_params(reduce_after=True, no_scale=False, fp32_allreduce=False):
if(self.needs_reduction): if(self.needs_reduction):
...@@ -71,8 +67,8 @@ class DistributedDataParallel(MegatronModule): ...@@ -71,8 +67,8 @@ class DistributedDataParallel(MegatronModule):
def allreduce_hook(*unused): def allreduce_hook(*unused):
Variable._execution_engine.queue_callback(allreduce_params) Variable._execution_engine.queue_callback(allreduce_params)
# handle = param.register_hook(allreduce_hook) # handle = param.register_hook(allreduce_hook)
#self.hooks.append(allreduce_hook) # self.hooks.append(allreduce_hook)
#self.hook_handles.append(handle) # self.hook_handles.append(handle)
self.allreduce_params = allreduce_params self.allreduce_params = allreduce_params
def forward(self, *inputs, **kwargs): def forward(self, *inputs, **kwargs):
...@@ -114,4 +110,3 @@ class DistributedDataParallel(MegatronModule): ...@@ -114,4 +110,3 @@ class DistributedDataParallel(MegatronModule):
super(DistributedDataParallel, self).train(mode) super(DistributedDataParallel, self).train(mode)
self.module.train(mode) self.module.train(mode)
''' '''
# coding=utf-8 # coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -27,8 +27,7 @@ from .utils import scaled_init_method_normal ...@@ -27,8 +27,7 @@ from .utils import scaled_init_method_normal
def gpt2_attention_mask_func(attention_scores, ltor_mask): def gpt2_attention_mask_func(attention_scores, ltor_mask):
attention_scores = torch.mul(attention_scores, ltor_mask) - \ attention_scores.masked_fill_(ltor_mask, -10000.0)
10000.0 * (1.0 - ltor_mask)
return attention_scores return attention_scores
...@@ -49,7 +48,6 @@ class GPT2Model(MegatronModule): ...@@ -49,7 +48,6 @@ class GPT2Model(MegatronModule):
scaled_init_method=scaled_init_method_normal(args.init_method_std, scaled_init_method=scaled_init_method_normal(args.init_method_std,
args.num_layers)) args.num_layers))
def forward(self, input_ids, position_ids, attention_mask, def forward(self, input_ids, position_ids, attention_mask,
tokentype_ids=None, layer_past=None, get_key_value=False, tokentype_ids=None, layer_past=None, get_key_value=False,
forward_method_parallel_output=None): forward_method_parallel_output=None):
...@@ -79,7 +77,6 @@ class GPT2Model(MegatronModule): ...@@ -79,7 +77,6 @@ class GPT2Model(MegatronModule):
return output return output
def state_dict_for_save_checkpoint(self, destination=None, prefix='', def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False): keep_vars=False):
...@@ -89,7 +86,6 @@ class GPT2Model(MegatronModule): ...@@ -89,7 +86,6 @@ class GPT2Model(MegatronModule):
destination, prefix, keep_vars) destination, prefix, keep_vars)
return state_dict_ return state_dict_
def load_state_dict(self, state_dict, strict=True): def load_state_dict(self, state_dict, strict=True):
"""Customized load.""" """Customized load."""
......
# coding=utf-8 # coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -21,9 +21,8 @@ import torch.nn.functional as F ...@@ -21,9 +21,8 @@ import torch.nn.functional as F
from megatron import get_args from megatron import get_args
from megatron import mpu from megatron import mpu
from megatron.module import MegatronModule from megatron.module import MegatronModule
from megatron.model.transformer import ParallelTransformer from megatron.model.transformer import ParallelTransformer
from megatron.model.utils import gelu from megatron.model.utils import openai_gelu
from megatron.model.utils import get_linear_layer from megatron.model.utils import get_linear_layer
...@@ -47,7 +46,13 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, ...@@ -47,7 +46,13 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
def get_language_model(attention_mask_func, num_tokentypes, add_pooler, def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
init_method, scaled_init_method, max_pos_embeds=None): init_method, scaled_init_method, max_pos_embeds=None):
"""Build language model and return along with the key to save.""" """Build language model and return along with the key to save."""
args = get_args()
# Use torch gelu unless otherwise forced.
gelu = F.gelu
if args.openai_gelu:
gelu = openai_gelu
# Language model. # Language model.
language_model = TransformerLanguageModel( language_model = TransformerLanguageModel(
attention_mask_func=attention_mask_func, attention_mask_func=attention_mask_func,
...@@ -63,7 +68,6 @@ def get_language_model(attention_mask_func, num_tokentypes, add_pooler, ...@@ -63,7 +68,6 @@ def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
return language_model, language_model_key return language_model, language_model_key
class Pooler(MegatronModule): class Pooler(MegatronModule):
"""Pooler layer. """Pooler layer.
...@@ -75,11 +79,11 @@ class Pooler(MegatronModule): ...@@ -75,11 +79,11 @@ class Pooler(MegatronModule):
init_method: weight initialization method for the linear layer. init_method: weight initialization method for the linear layer.
bias is set to zero. bias is set to zero.
""" """
def __init__(self, hidden_size, init_method): def __init__(self, hidden_size, init_method):
super(Pooler, self).__init__() super(Pooler, self).__init__()
self.dense = get_linear_layer(hidden_size, hidden_size, init_method) self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
def forward(self, hidden_states, sequence_index=0): def forward(self, hidden_states, sequence_index=0):
# hidden_states: [b, s, h] # hidden_states: [b, s, h]
# sequence_index: index of the token to pool. # sequence_index: index of the token to pool.
...@@ -102,6 +106,7 @@ class Embedding(MegatronModule): ...@@ -102,6 +106,7 @@ class Embedding(MegatronModule):
num_tokentypes: size of the token-type embeddings. 0 value num_tokentypes: size of the token-type embeddings. 0 value
will ignore this embedding will ignore this embedding
""" """
def __init__(self, def __init__(self,
hidden_size, hidden_size,
vocab_size, vocab_size,
...@@ -143,7 +148,6 @@ class Embedding(MegatronModule): ...@@ -143,7 +148,6 @@ class Embedding(MegatronModule):
# Embeddings dropout # Embeddings dropout
self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob) self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
def add_tokentype_embeddings(self, num_tokentypes): def add_tokentype_embeddings(self, num_tokentypes):
"""Add token-type embedding. This function is provided so we can add """Add token-type embedding. This function is provided so we can add
token-type embeddings in case the pretrained model does not have it. token-type embeddings in case the pretrained model does not have it.
...@@ -160,7 +164,6 @@ class Embedding(MegatronModule): ...@@ -160,7 +164,6 @@ class Embedding(MegatronModule):
# Initialize the token-type embeddings. # Initialize the token-type embeddings.
self.init_method(self.tokentype_embeddings.weight) self.init_method(self.tokentype_embeddings.weight)
def forward(self, input_ids, position_ids, tokentype_ids=None): def forward(self, input_ids, position_ids, tokentype_ids=None):
# Embeddings. # Embeddings.
words_embeddings = self.word_embeddings(input_ids) words_embeddings = self.word_embeddings(input_ids)
...@@ -177,7 +180,6 @@ class Embedding(MegatronModule): ...@@ -177,7 +180,6 @@ class Embedding(MegatronModule):
return embeddings return embeddings
def state_dict_for_save_checkpoint(self, destination=None, prefix='', def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False): keep_vars=False):
"""For easy load.""" """For easy load."""
...@@ -195,7 +197,6 @@ class Embedding(MegatronModule): ...@@ -195,7 +197,6 @@ class Embedding(MegatronModule):
return state_dict_ return state_dict_
def load_state_dict(self, state_dict, strict=True): def load_state_dict(self, state_dict, strict=True):
"""Customized load.""" """Customized load."""
...@@ -224,7 +225,7 @@ class Embedding(MegatronModule): ...@@ -224,7 +225,7 @@ class Embedding(MegatronModule):
self.position_embeddings.load_state_dict(state_dict_, strict=strict) self.position_embeddings.load_state_dict(state_dict_, strict=strict)
# Tokentype embedding. # Tokentype embedding.
if self.num_tokentypes > 0: if self.num_tokentypes > 0:
state_dict_ = {} state_dict_ = {}
if self._tokentype_embeddings_key in state_dict: if self._tokentype_embeddings_key in state_dict:
state_dict_ = state_dict[self._tokentype_embeddings_key] state_dict_ = state_dict[self._tokentype_embeddings_key]
...@@ -242,7 +243,6 @@ class Embedding(MegatronModule): ...@@ -242,7 +243,6 @@ class Embedding(MegatronModule):
'checkpoint but could not find it', flush=True) 'checkpoint but could not find it', flush=True)
class TransformerLanguageModel(MegatronModule): class TransformerLanguageModel(MegatronModule):
"""Transformer language model. """Transformer language model.
...@@ -261,6 +261,7 @@ class TransformerLanguageModel(MegatronModule): ...@@ -261,6 +261,7 @@ class TransformerLanguageModel(MegatronModule):
num_tokentypes: size of the token-type embeddings. 0 value num_tokentypes: size of the token-type embeddings. 0 value
will ignore this embedding will ignore this embedding
""" """
def __init__(self, def __init__(self,
attention_mask_func, attention_mask_func,
mlp_activation_func, mlp_activation_func,
...@@ -298,7 +299,6 @@ class TransformerLanguageModel(MegatronModule): ...@@ -298,7 +299,6 @@ class TransformerLanguageModel(MegatronModule):
self.pooler = Pooler(self.hidden_size, self.init_method) self.pooler = Pooler(self.hidden_size, self.init_method)
self._pooler_key = 'pooler' self._pooler_key = 'pooler'
def forward(self, input_ids, position_ids, attention_mask, def forward(self, input_ids, position_ids, attention_mask,
tokentype_ids=None, layer_past=None, get_key_value=False, tokentype_ids=None, layer_past=None, get_key_value=False,
pooling_sequence_index=0): pooling_sequence_index=0):
...@@ -320,7 +320,6 @@ class TransformerLanguageModel(MegatronModule): ...@@ -320,7 +320,6 @@ class TransformerLanguageModel(MegatronModule):
return transformer_output return transformer_output
def state_dict_for_save_checkpoint(self, destination=None, prefix='', def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False): keep_vars=False):
"""For easy load.""" """For easy load."""
...@@ -339,7 +338,6 @@ class TransformerLanguageModel(MegatronModule): ...@@ -339,7 +338,6 @@ class TransformerLanguageModel(MegatronModule):
return state_dict_ return state_dict_
def load_state_dict(self, state_dict, strict=True): def load_state_dict(self, state_dict, strict=True):
"""Customized load.""" """Customized load."""
......
# coding=utf-8 # coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -51,7 +51,6 @@ class MultipleChoice(MegatronModule): ...@@ -51,7 +51,6 @@ class MultipleChoice(MegatronModule):
init_method) init_method)
self._multichoice_head_key = 'multichoice_head' self._multichoice_head_key = 'multichoice_head'
def forward(self, input_ids, attention_mask, tokentype_ids): def forward(self, input_ids, attention_mask, tokentype_ids):
# [batch, choices, sequence] --> [batch * choices, sequence] --> # [batch, choices, sequence] --> [batch * choices, sequence] -->
...@@ -86,7 +85,6 @@ class MultipleChoice(MegatronModule): ...@@ -86,7 +85,6 @@ class MultipleChoice(MegatronModule):
return multichoice_logits return multichoice_logits
def state_dict_for_save_checkpoint(self, destination=None, prefix='', def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False): keep_vars=False):
"""For easy load when model is combined with other heads, """For easy load when model is combined with other heads,
...@@ -101,7 +99,6 @@ class MultipleChoice(MegatronModule): ...@@ -101,7 +99,6 @@ class MultipleChoice(MegatronModule):
destination, prefix, keep_vars) destination, prefix, keep_vars)
return state_dict_ return state_dict_
def load_state_dict(self, state_dict, strict=True): def load_state_dict(self, state_dict, strict=True):
"""Customized load.""" """Customized load."""
......
This diff is collapsed.
# coding=utf-8 # coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -33,6 +33,7 @@ def init_method_normal(sigma): ...@@ -33,6 +33,7 @@ def init_method_normal(sigma):
def scaled_init_method_normal(sigma, num_layers): def scaled_init_method_normal(sigma, num_layers):
"""Init method based on N(0, sigma/sqrt(2*num_layers).""" """Init method based on N(0, sigma/sqrt(2*num_layers)."""
std = sigma / math.sqrt(2.0 * num_layers) std = sigma / math.sqrt(2.0 * num_layers)
def init_(tensor): def init_(tensor):
return torch.nn.init.normal_(tensor, mean=0.0, std=std) return torch.nn.init.normal_(tensor, mean=0.0, std=std)
...@@ -53,8 +54,7 @@ def gelu_impl(x): ...@@ -53,8 +54,7 @@ def gelu_impl(x):
"""OpenAI's gelu implementation.""" """OpenAI's gelu implementation."""
return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x *
(1.0 + 0.044715 * x * x))) (1.0 + 0.044715 * x * x)))
def openai_gelu(x):
def gelu(x):
return gelu_impl(x) return gelu_impl(x)
......
# coding=utf-8 # coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -21,11 +21,9 @@ import torch ...@@ -21,11 +21,9 @@ import torch
class MegatronModule(torch.nn.Module): class MegatronModule(torch.nn.Module):
"""Megatron specific extentions of torch Module.""" """Megatron specific extentions of torch Module."""
def __init__(self): def __init__(self):
super(MegatronModule, self).__init__() super(MegatronModule, self).__init__()
def state_dict_for_save_checkpoint(self, destination=None, prefix='', def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False): keep_vars=False):
"""Use this function to override the state dict for """Use this function to override the state dict for
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment