Commit b7f1b050 authored by Neel Kant

Lint whole repo

parent c99fa80c
@@ -27,7 +27,8 @@ try:
     from functools import lru_cache
 except ImportError:
     # Just a dummy decorator to get the checks to run on python2
-    # because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now.
+    # because honestly I don't want to support a byte-level unicode BPE
+    # tokenizer on python 2 right now.
     def lru_cache():
         return lambda func: func

@@ -48,6 +49,7 @@ VOCAB_NAME = 'vocab.json'
 MERGES_NAME = 'merges.txt'
 SPECIAL_TOKENS_NAME = 'special_tokens.txt'

+
 @lru_cache()
 def bytes_to_unicode():
     """
@@ -60,17 +62,19 @@ def bytes_to_unicode():
     And avoids mapping to whitespace/control characters the bpe code barfs on.
     """
     _chr = unichr if sys.version_info[0] == 2 else chr
-    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
+    bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + \
+        list(range(ord("®"), ord("ÿ") + 1))
     cs = bs[:]
     n = 0
     for b in range(2**8):
         if b not in bs:
             bs.append(b)
-            cs.append(2**8+n)
+            cs.append(2**8 + n)
             n += 1
     cs = [_chr(n) for n in cs]
     return dict(zip(bs, cs))

+
 def get_pairs(word):
     """Return set of symbol pairs in a word.
@@ -83,6 +87,7 @@ def get_pairs(word):
         prev_char = char
     return pairs

+
 class GPT2Tokenizer(object):
     """
     GPT-2 BPE tokenizer. Peculiarities:
@@ -138,23 +143,31 @@ class GPT2Tokenizer(object):
             special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
         else:
             special_tokens = kwargs.pop('special_tokens', [])
-        tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs)
+        tokenizer = cls(
+            resolved_vocab_file,
+            resolved_merges_file,
+            special_tokens=special_tokens,
+            *inputs,
+            **kwargs)
         return tokenizer

-    def __init__(self, vocab_file, merges_file, errors='replace', special_tokens=None, max_len=None):
+    def __init__(self, vocab_file, merges_file, errors='replace',
+                 special_tokens=None, max_len=None):
         self.max_len = max_len if max_len is not None else int(1e12)
         self.encoder = json.load(open(vocab_file))
-        self.decoder = {v:k for k,v in self.encoder.items()}
+        self.decoder = {v: k for k, v in self.encoder.items()}
         self.errors = errors  # how to handle errors in decoding
         self.byte_encoder = bytes_to_unicode()
-        self.byte_decoder = {v:k for k, v in self.byte_encoder.items()}
+        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
         bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
         bpe_merges = [tuple(merge.split()) for merge in bpe_data]
         self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
         self.cache = {}
-        # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
-        self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
+        # Should haved added re.IGNORECASE so BPE merges can happen for
+        # capitalized versions of contractions
+        self.pat = re.compile(
+            r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
         self.special_tokens = {}
         self.special_tokens_decoder = {}
@@ -172,8 +185,9 @@ class GPT2Tokenizer(object):
             self.special_tokens = {}
             self.special_tokens_decoder = {}
             return
-        self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
-        self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()}
+        self.special_tokens = dict((tok, len(self.encoder) + i)
+                                   for i, tok in enumerate(special_tokens))
+        self.special_tokens_decoder = {v: k for k, v in self.special_tokens.items()}
         logger.info("Special tokens {}".format(self.special_tokens))

     def bpe(self, token):
@@ -186,7 +200,7 @@ class GPT2Tokenizer(object):
             return token
         while True:
-            bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
+            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
             if bigram not in self.bpe_ranks:
                 break
             first, second = bigram
@@ -197,12 +211,12 @@ class GPT2Tokenizer(object):
                     j = word.index(first, i)
                     new_word.extend(word[i:j])
                     i = j
-                except:
+                except BaseException:
                     new_word.extend(word[i:])
                     break
-                if word[i] == first and i < len(word)-1 and word[i+1] == second:
-                    new_word.append(first+second)
+                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
+                    new_word.append(first + second)
                     i += 2
                 else:
                     new_word.append(word[i])
@@ -245,7 +259,8 @@ class GPT2Tokenizer(object):
             logger.warning(
                 "Token indices sequence length is longer than the specified maximum "
                 " sequence length for this OpenAI GPT model ({} > {}). Running this"
-                " sequence through the model will result in indexing errors".format(len(ids), self.max_len)
+                " sequence through the model will result in indexing errors".format(
+                    len(ids), self.max_len)
             )
         return ids
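Aside for readers skimming the tokenizer hunks above: `bytes_to_unicode` builds a reversible map from raw bytes to printable unicode characters so BPE can operate on arbitrary byte sequences, and `get_pairs` enumerates adjacent symbol pairs for the merge loop. A condensed, Python-3-only restatement with a tiny usage example (the diffed file remains the authoritative version):

```python
def bytes_to_unicode():
    bs = (list(range(ord("!"), ord("~") + 1)) +
          list(range(ord("¡"), ord("¬") + 1)) +
          list(range(ord("®"), ord("ÿ") + 1)))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8 + n)  # remap unprintable bytes above the byte range
            n += 1
    return dict(zip(bs, map(chr, cs)))


def get_pairs(word):
    return set(zip(word, word[1:]))


byte_encoder = bytes_to_unicode()
token = ''.join(byte_encoder[b] for b in ' hello'.encode('utf-8'))
print(token)                    # Ġhello -- the space byte 0x20 maps to chr(288)
print(get_pairs(tuple(token)))  # {('Ġ', 'h'), ('h', 'e'), ('e', 'l'), ...}
```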
@@ -123,7 +123,8 @@ class BertTokenizer(object):
             logger.warning(
                 "Token indices sequence length is longer than the specified maximum "
                 " sequence length for this BERT model ({} > {}). Running this"
-                " sequence through BERT will result in indexing errors".format(len(ids), self.max_len)
+                " sequence through BERT will result in indexing errors".format(
+                    len(ids), self.max_len)
             )
         return ids

@@ -28,6 +28,7 @@ from megatron.module import MegatronModule
 FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
 HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)

+
 def conversion_helper(val, conversion):
     """Apply conversion to val. Recursively apply conversion if `val` is a nested tuple/list structure."""
     if not isinstance(val, (tuple, list)):
@@ -37,6 +38,7 @@ def conversion_helper(val, conversion):
         rtn = tuple(rtn)
     return rtn

+
 def fp32_to_fp16(val):
     """Convert fp32 `val` to fp16"""
     def half_conversion(val):
@@ -48,6 +50,7 @@ def fp32_to_fp16(val):
         return val
     return conversion_helper(val, half_conversion)

+
 def fp16_to_fp32(val):
     """Convert fp16 `val` to fp32"""
     def float_conversion(val):
@@ -59,6 +62,7 @@ def fp16_to_fp32(val):
         return val
     return conversion_helper(val, float_conversion)
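For orientation, the three helpers just above implement a small recursive cast: `conversion_helper` walks nested tuples/lists, and the fp32/fp16 wrappers cast only floating-point tensors. A simplified, self-contained sketch of the same pattern (it checks dtypes instead of the FLOAT_TYPES/HALF_TYPES tensor-type tuples used in the diffed file):

```python
import torch


def conversion_helper(val, conversion):
    """Apply `conversion` to val, recursing through nested tuples/lists."""
    if not isinstance(val, (tuple, list)):
        return conversion(val)
    rtn = [conversion_helper(v, conversion) for v in val]
    return tuple(rtn) if isinstance(val, tuple) else rtn


def fp32_to_fp16(val):
    """Cast every fp32 tensor in a nested structure to fp16."""
    def half_conversion(v):
        if torch.is_tensor(v) and v.dtype == torch.float32:
            return v.half()
        return v
    return conversion_helper(val, half_conversion)


# The structure is preserved; only float32 tensors change dtype.
out = fp32_to_fp16((torch.ones(2), [torch.zeros(3), 7]))
print(out[0].dtype, out[1][0].dtype, out[1][1])  # torch.float16 torch.float16 7
```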
 class FP16_Module(MegatronModule):
     def __init__(self, module):
         super(FP16_Module, self).__init__()
@@ -79,6 +83,8 @@ class FP16_Module(MegatronModule):
         self.module.load_state_dict(state_dict, strict=strict)

+
+
 # TODO: Update overflow check + downscale to use Carl's fused kernel.
 class FP16_Optimizer(object):
     """
     :class:`FP16_Optimizer` is designed to wrap an existing PyTorch optimizer,
@@ -305,7 +311,8 @@ class FP16_Optimizer(object):
             master_params_to_model_params(fp32_from_fp16_group, fp16_group)

     # To consider:  Integrate distributed with this wrapper by registering a hook on each variable
-    # that does the overflow check, gradient copy + downscale, and fp32 allreduce in a different stream.
+    # that does the overflow check, gradient copy + downscale, and fp32
+    # allreduce in a different stream.
     def _model_grads_to_master_grads(self):
         for fp16_group, fp32_from_fp16_group in zip(self.fp16_groups, self.fp32_from_fp16_groups):
             model_grads_to_master_grads(fp16_group, fp32_from_fp16_group)
@@ -315,7 +322,7 @@ class FP16_Optimizer(object):
         for group in self.optimizer.param_groups:
             for param in group['params']:
                 if param.grad is not None:
-                    param.grad.data.mul_(1./self.loss_scale)
+                    param.grad.data.mul_(1. / self.loss_scale)

     def clip_master_grads(self, max_norm, norm_type=2):
         """
@@ -400,7 +407,8 @@ class FP16_Optimizer(object):
         # the current optimizer instance.  In our case, as long as the current FP16_Optimizer has been
         # constructed in the same way as the one whose state_dict we are loading, the same master params
         # are guaranteed to exist, so we can just copy_() from the saved master params.
-        for current_group, saved_group in zip(self.fp32_from_fp16_groups, state_dict['fp32_from_fp16']):
+        for current_group, saved_group in zip(
+                self.fp32_from_fp16_groups, state_dict['fp32_from_fp16']):
             for current, saved in zip(current_group, saved_group):
                 current.data.copy_(saved.data)
@@ -570,7 +578,8 @@ class FP16_Optimizer(object):
         """
         if self.dynamic_loss_scale:
             self._check_overflow()
-        if self.overflow: return
+        if self.overflow:
+            return
         self._model_grads_to_master_grads()
         self._downscale_master()
@@ -607,8 +616,8 @@ class FP16_Optimizer(object):
             master_grads_data.append(master_grads_this_group)
         return master_grads_data

     # Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale"
     def _get_loss_scale(self):
         return self.loss_scaler.loss_scale
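As a rough illustration of the `_downscale_master` step referenced above (a sketch only, not the wrapped optimizer's real interface): the backward pass runs on `loss * loss_scale`, so every gradient has to be multiplied by `1 / loss_scale` before the optimizer step.

```python
import torch


def downscale_grads(optimizer, loss_scale):
    """Undo loss scaling in place on all gradients held by the optimizer."""
    for group in optimizer.param_groups:
        for param in group['params']:
            if param.grad is not None:
                param.grad.data.mul_(1.0 / loss_scale)


model = torch.nn.Linear(4, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
loss_scale = 1024.0

loss = model(torch.randn(8, 4)).pow(2).mean()
(loss * loss_scale).backward()    # gradients are now scaled by 1024
downscale_grads(opt, loss_scale)  # bring them back before stepping
opt.step()
```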
@@ -102,6 +102,7 @@ class FP16Model(nn.Module):
 def backwards_debug_hook(grad):
     raise RuntimeError("master_params recieved a gradient in the backward pass!")

+
 def prep_param_lists(model, flat_master=False):
     """
     Creates a list of FP32 master parameters for a given model, as in
@@ -131,7 +132,7 @@ def prep_param_lists(model, flat_master=False):
             # flatten_dense_tensors returns a contiguous flat array.
             # http://pytorch.org/docs/master/_modules/torch/_utils.html
             master_params = _flatten_dense_tensors([param.data for param in model_params]).float()
-        except:
+        except BaseException:
             print("Error in prep_param_lists:  model may contain a mixture of parameters "
                   "of different types.  Use flat_master=False, or use F16_Optimizer.")
             raise
@@ -188,17 +189,19 @@ def master_params_to_model_params(model_params, master_params, flat_master=False
 # Backward compatibility fixes

+
 def to_python_float(t):
     if hasattr(t, 'item'):
         return t.item()
     else:
         return t[0]

+
 TORCH_MAJOR = int(torch.__version__.split('.')[0])
 TORCH_MINOR = int(torch.__version__.split('.')[1])
 clip_grad_norm = mpu.clip_grad_norm
-#elif TORCH_MAJOR == 0 and TORCH_MINOR <= 4:
+# elif TORCH_MAJOR == 0 and TORCH_MINOR <= 4:
 #    clip_grad_norm = torch.nn.utils.clip_grad_norm
-#else:
+# else:
 #    clip_grad_norm = torch.nn.utils.clip_grad_norm_

@@ -17,12 +17,15 @@ import torch
 from megatron import mpu

+
 # item() is a recent addition, so this helps with backward compatibility.
 def to_python_float(t):
     if hasattr(t, 'item'):
         return t.item()
     else:
         return t[0]

+
+
 class LossScaler:
     """
     Class that manages a static loss scale.  This class is intended to interact with
@@ -57,9 +60,10 @@ class LossScaler:
         return tuple(self.loss_scale * g for g in grad_in)

     def backward(self, loss, retain_graph=False):
-        scaled_loss = loss*self.loss_scale
+        scaled_loss = loss * self.loss_scale
         scaled_loss.backward(retain_graph=retain_graph)

+
 class DynamicLossScaler:
     """
     Class that manages dynamic loss scaling.  It is recommended to use :class:`DynamicLossScaler`
@@ -122,8 +126,8 @@ class DynamicLossScaler:
             overflow = overflow_gpu[0].item()
         return bool(overflow)

     # `x` is a torch.Tensor
     def _has_inf_or_nan(x):
         try:
             # if x is half, the .float() incurs an additional deep copy, but it's necessary if
@@ -158,7 +162,7 @@ class DynamicLossScaler:
         if overflow:
             # self.cur_scale /= self.scale_factor
             if self.delayed_shift == 1 or self.cur_hysteresis == 1:
-                self.cur_scale = max(self.cur_scale/self.scale_factor, self.min_scale)
+                self.cur_scale = max(self.cur_scale / self.scale_factor, self.min_scale)
             else:
                 self.cur_hysteresis -= 1
             self.last_overflow_iter = self.cur_iter
@@ -179,9 +183,10 @@ class DynamicLossScaler:
         return tuple(self.loss_scale * g for g in grad_in)

     def backward(self, loss, retain_graph=False):
-        scaled_loss = loss*self.loss_scale
+        scaled_loss = loss * self.loss_scale
         scaled_loss.backward(retain_graph=retain_graph)

+
 ##############################################################
 # Example usage below here -- assuming it's in a separate file
 ##############################################################
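A compact sketch of the policy `DynamicLossScaler` implements above (simplified: hysteresis/delayed shift and the fused overflow check are omitted, and the class name here is illustrative): shrink the scale when gradients overflow, and grow it again after a window of clean iterations.

```python
class SimpleDynamicScaler:
    def __init__(self, init_scale=2.0**16, scale_factor=2.0,
                 scale_window=1000, min_scale=1.0):
        self.cur_scale = init_scale
        self.scale_factor = scale_factor
        self.scale_window = scale_window  # clean iterations before growing
        self.min_scale = min_scale
        self.cur_iter = 0
        self.last_overflow_iter = -1

    def update_scale(self, overflow):
        if overflow:
            # Overflow: back off and remember when it happened.
            self.cur_scale = max(self.cur_scale / self.scale_factor,
                                 self.min_scale)
            self.last_overflow_iter = self.cur_iter
        elif (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0:
            # Long enough without overflow: try a larger scale again.
            self.cur_scale *= self.scale_factor
        self.cur_iter += 1
```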
@@ -124,7 +124,7 @@ def _set_adlr_autoresume(args):
     sys.path.append(os.environ.get('SUBMIT_SCRIPTS', '.'))
     try:
         from userlib.auto_resume import AutoResume
-    except:
+    except BaseException:
         print('ADLR autoresume is not available, exiting ...')
         sys.exit()

@@ -48,7 +48,6 @@ class AnnealingLR(object):
             print_rank_0('> learning rate decay style: {}'.format(self.decay_style))

-
     def get_lr(self):
         """Learning rate decay functions from:
               https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""
@@ -71,7 +70,6 @@ class AnnealingLR(object):
             lr = self.start_lr
         return max(lr, self.min_lr)

-
     def step(self, step_num=None):
         """Set lr for all parameters groups."""
         if step_num is None:
@@ -81,7 +79,6 @@ class AnnealingLR(object):
         for group in self.optimizer.param_groups:
             group['lr'] = new_lr

-
     def state_dict(self):
         state_dict = {
             'start_lr': self.start_lr,
@@ -93,7 +90,6 @@ class AnnealingLR(object):
         }
         return state_dict

-
     def _check_and_set(self, cls_value, sd_value, name):
         """Auxiliary function for checking the values in the checkpoint and
         setting them."""
@@ -108,7 +104,6 @@ class AnnealingLR(object):
                              name))
         return sd_value

-
     def load_state_dict(self, sd):
         self.start_lr = self._check_and_set(self.start_lr, sd['start_lr'],
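For reference, a minimal sketch of the kind of schedule `get_lr` computes (linear and cosine decay following the paper linked in its docstring; warmup handling and attribute names in the diffed class may differ from this stand-alone function):

```python
import math


def annealed_lr(step, max_steps, start_lr, min_lr=0.0,
                warmup_steps=0, decay_style='cosine'):
    """Linear warmup, then decay toward min_lr."""
    if warmup_steps and step < warmup_steps:
        return start_lr * step / warmup_steps
    frac = min(step, max_steps) / max_steps
    if decay_style == 'linear':
        lr = start_lr * (1.0 - frac)
    elif decay_style == 'cosine':
        lr = start_lr * 0.5 * (math.cos(math.pi * frac) + 1.0)
    else:  # constant
        lr = start_lr
    return max(lr, min_lr)


# Halfway through training, cosine decay sits at half the peak rate:
print(annealed_lr(step=500, max_steps=1000, start_lr=1e-4))  # ~5e-05
```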
@@ -66,7 +66,6 @@ def bert_position_ids(token_ids):
     return position_ids

-
 class BertLMHead(MegatronModule):
     """Masked LM head for Bert
@@ -77,6 +76,7 @@ class BertLMHead(MegatronModule):
         layernorm_epsilon: tolerance for layer norm divisions
         parallel_output: wether output logits being distributed or not.
     """
+
     def __init__(self, mpu_vocab_size, hidden_size, init_method,
                  layernorm_epsilon, parallel_output):
@@ -91,7 +91,6 @@ class BertLMHead(MegatronModule):
         self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
         self.layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)

-
     def forward(self, hidden_states, word_embeddings_weight):
         hidden_states = self.dense(hidden_states)
         hidden_states = gelu(hidden_states)
@@ -103,7 +102,6 @@ class BertLMHead(MegatronModule):
         return output

-
 class BertModel(MegatronModule):
     """Bert Language model."""
@@ -136,7 +134,6 @@ class BertModel(MegatronModule):
                                        init_method)
             self._binary_head_key = 'binary_head'

-
     def forward(self, input_ids, attention_mask, tokentype_ids=None):
         extended_attention_mask = bert_extended_attention_mask(
@@ -166,7 +163,6 @@ class BertModel(MegatronModule):
         return lm_logits, None

-
     def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                        keep_vars=False):
         """For easy load when model is combined with other heads,
@@ -184,7 +180,6 @@ class BertModel(MegatronModule):
             = self.binary_head.state_dict(destination, prefix, keep_vars)
         return state_dict_

-
     def load_state_dict(self, state_dict, strict=True):
         """Customized load."""

@@ -53,7 +53,6 @@ class Classification(MegatronModule):
                                             init_method)
         self._classification_head_key = 'classification_head'

-
     def forward(self, input_ids, attention_mask, tokentype_ids):
         extended_attention_mask = bert_extended_attention_mask(
@@ -74,7 +73,6 @@ class Classification(MegatronModule):
         return classification_logits

-
     def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                        keep_vars=False):
         """For easy load when model is combined with other heads,
@@ -89,7 +87,6 @@ class Classification(MegatronModule):
             destination, prefix, keep_vars)
         return state_dict_

-
     def load_state_dict(self, state_dict, strict=True):
         """Customized load."""

@@ -71,8 +71,8 @@ class DistributedDataParallel(MegatronModule):
                 def allreduce_hook(*unused):
                     Variable._execution_engine.queue_callback(allreduce_params)
                 # handle = param.register_hook(allreduce_hook)
-                #self.hooks.append(allreduce_hook)
-                #self.hook_handles.append(handle)
+                # self.hooks.append(allreduce_hook)
+                # self.hook_handles.append(handle)
         self.allreduce_params = allreduce_params

     def forward(self, *inputs, **kwargs):
@@ -114,4 +114,3 @@ class DistributedDataParallel(MegatronModule):
         super(DistributedDataParallel, self).train(mode)
         self.module.train(mode)
     '''
-
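The commented-out `register_hook` lines above refer to triggering the gradient all-reduce from inside the backward pass. A minimal sketch of that idea with plain `torch.distributed` (illustrative only; it assumes an already-initialized process group, and the diffed class additionally coalesces gradients into typed buckets, which is skipped here):

```python
import torch
import torch.distributed as dist


def register_grad_allreduce_hooks(module):
    """Average each parameter's gradient across ranks as soon as it is produced."""
    world_size = dist.get_world_size()

    def hook(grad):
        grad = grad.clone()
        dist.all_reduce(grad)     # sum over all data-parallel ranks
        return grad / world_size  # replace the gradient with the average

    for param in module.parameters():
        if param.requires_grad:
            param.register_hook(hook)
```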
@@ -49,7 +49,6 @@ class GPT2Model(MegatronModule):
             scaled_init_method=scaled_init_method_normal(args.init_method_std,
                                                          args.num_layers))

-
     def forward(self, input_ids, position_ids, attention_mask,
                 tokentype_ids=None, layer_past=None, get_key_value=False,
                 forward_method_parallel_output=None):
@@ -79,7 +78,6 @@ class GPT2Model(MegatronModule):
         return output

-
     def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                        keep_vars=False):
@@ -89,7 +87,6 @@ class GPT2Model(MegatronModule):
             destination, prefix, keep_vars)
         return state_dict_

-
     def load_state_dict(self, state_dict, strict=True):
         """Customized load."""

@@ -62,7 +62,6 @@ def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
     return language_model, language_model_key

-
 class Pooler(MegatronModule):
     """Pooler layer.
@@ -74,11 +73,11 @@ class Pooler(MegatronModule):
         init_method: weight initialization method for the linear layer.
             bias is set to zero.
     """
+
     def __init__(self, hidden_size, init_method):
         super(Pooler, self).__init__()
         self.dense = get_linear_layer(hidden_size, hidden_size, init_method)

-
     def forward(self, hidden_states, sequence_index=0):
         # hidden_states: [b, s, h]
         # sequence_index: index of the token to pool.
@@ -101,6 +100,7 @@ class Embedding(MegatronModule):
         num_tokentypes: size of the token-type embeddings. 0 value
             will ignore this embedding
     """
+
     def __init__(self,
                  hidden_size,
                  vocab_size,
@@ -142,7 +142,6 @@ class Embedding(MegatronModule):
         # Embeddings dropout
         self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)

-
     def add_tokentype_embeddings(self, num_tokentypes):
         """Add token-type embedding. This function is provided so we can add
         token-type embeddings in case the pretrained model does not have it.
@@ -159,7 +158,6 @@ class Embedding(MegatronModule):
         # Initialize the token-type embeddings.
         self.init_method(self.tokentype_embeddings.weight)

-
     def forward(self, input_ids, position_ids, tokentype_ids=None):
         # Embeddings.
         words_embeddings = self.word_embeddings(input_ids)
@@ -176,7 +174,6 @@ class Embedding(MegatronModule):
         return embeddings

-
     def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                        keep_vars=False):
         """For easy load."""
@@ -194,7 +191,6 @@ class Embedding(MegatronModule):
         return state_dict_

-
     def load_state_dict(self, state_dict, strict=True):
         """Customized load."""
@@ -241,7 +237,6 @@ class Embedding(MegatronModule):
                   'checkpoint but could not find it', flush=True)

-
 class TransformerLanguageModel(MegatronModule):
     """Transformer language model.
@@ -260,6 +255,7 @@ class TransformerLanguageModel(MegatronModule):
         num_tokentypes: size of the token-type embeddings. 0 value
             will ignore this embedding
     """
+
     def __init__(self,
                  attention_mask_func,
                  mlp_activation_func,
@@ -295,7 +291,6 @@ class TransformerLanguageModel(MegatronModule):
             self.pooler = Pooler(self.hidden_size, self.init_method)
             self._pooler_key = 'pooler'

-
     def forward(self, input_ids, position_ids, attention_mask,
                 tokentype_ids=None, layer_past=None, get_key_value=False,
                 pooling_sequence_index=0):
@@ -317,7 +312,6 @@ class TransformerLanguageModel(MegatronModule):
         return transformer_output

-
     def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                        keep_vars=False):
         """For easy load."""
@@ -336,7 +330,6 @@ class TransformerLanguageModel(MegatronModule):
         return state_dict_

-
     def load_state_dict(self, state_dict, strict=True):
         """Customized load."""

@@ -51,7 +51,6 @@ class MultipleChoice(MegatronModule):
                                             init_method)
         self._multichoice_head_key = 'multichoice_head'

-
     def forward(self, input_ids, attention_mask, tokentype_ids):
         # [batch, choices, sequence] --> [batch * choices, sequence] -->
@@ -86,7 +85,6 @@ class MultipleChoice(MegatronModule):
         return multichoice_logits

-
     def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                        keep_vars=False):
         """For easy load when model is combined with other heads,
@@ -101,7 +99,6 @@ class MultipleChoice(MegatronModule):
             destination, prefix, keep_vars)
         return state_dict_

-
     def load_state_dict(self, state_dict, strict=True):
         """Customized load."""
@@ -46,6 +46,7 @@ from megatron.module import MegatronModule
        unmaksed-attention-scores, attention-mask)
 """

+
 class ParallelMLP(MegatronModule):
     """MLP.
@@ -63,7 +64,7 @@ class ParallelMLP(MegatronModule):
         # Project to 4h.
         self.dense_h_to_4h = mpu.ColumnParallelLinear(
             args.hidden_size,
-            4*args.hidden_size,
+            4 * args.hidden_size,
             gather_output=False,
             init_method=init_method)
@@ -71,14 +72,13 @@ class ParallelMLP(MegatronModule):
         # Project back to h.
         self.dense_4h_to_h = mpu.RowParallelLinear(
-            4*args.hidden_size,
+            4 * args.hidden_size,
             args.hidden_size,
             input_is_parallel=True,
             init_method=output_layer_init_method)

         self.dropout = torch.nn.Dropout(args.hidden_dropout)

-
     def forward(self, hidden_states):
         # [b, s, 4hp]
@@ -91,13 +91,13 @@ class ParallelMLP(MegatronModule):
         return output
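Setting the tensor-model-parallel splitting aside, the block above is the standard transformer MLP: expand the hidden size by four, apply the activation, project back down, then dropout. A plain single-device sketch (hypothetical module, not the mpu-based one in the diff):

```python
import torch


class PlainMLP(torch.nn.Module):
    def __init__(self, hidden_size, hidden_dropout=0.1):
        super().__init__()
        self.dense_h_to_4h = torch.nn.Linear(hidden_size, 4 * hidden_size)
        self.dense_4h_to_h = torch.nn.Linear(4 * hidden_size, hidden_size)
        self.dropout = torch.nn.Dropout(hidden_dropout)

    def forward(self, hidden_states):           # [b, s, h]
        x = self.dense_h_to_4h(hidden_states)   # [b, s, 4h]
        x = torch.nn.functional.gelu(x)
        x = self.dense_4h_to_h(x)               # [b, s, h]
        return self.dropout(x)


print(PlainMLP(hidden_size=8)(torch.randn(2, 5, 8)).shape)  # torch.Size([2, 5, 8])
```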
 class ParallelSelfAttention(MegatronModule):
     """Parallel self-attention layer abstract class.
     Self-attention layer takes input with size [b, s, h]
     and returns output of the same size.
     """
+
     def __init__(self, attention_mask_func, init_method,
                  output_layer_init_method, layer_number):
         super(ParallelSelfAttention, self).__init__()
@@ -123,7 +123,7 @@ class ParallelSelfAttention(MegatronModule):
         # Strided linear layer.
         self.query_key_value = mpu.ColumnParallelLinear(
             args.hidden_size,
-            3*args.hidden_size,
+            3 * args.hidden_size,
             stride=3,
             gather_output=False,
             init_method=init_method)
@@ -141,7 +141,6 @@ class ParallelSelfAttention(MegatronModule):
             init_method=output_layer_init_method)
         self.output_dropout = torch.nn.Dropout(args.hidden_dropout)

-
     def _transpose_for_scores(self, tensor):
         """Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with
         size [b, np, s, hn].
@@ -152,7 +151,6 @@ class ParallelSelfAttention(MegatronModule):
         tensor = tensor.view(*new_tensor_shape)
         return tensor.permute(0, 2, 1, 3)

-
     def _get_query_key_value(self, hidden_states):
         """Get query, key, and value and transpose to
         get size [b, np, s, hn].
@@ -170,7 +168,6 @@ class ParallelSelfAttention(MegatronModule):
         return query_layer, key_layer, value_layer

-
     def _get_unmasked_attention_scores(self, query_layer, key_layer):
         """Unmasked attention scores with size [b, np, s, s]."""
         coeff = 1
@@ -179,9 +176,8 @@ class ParallelSelfAttention(MegatronModule):
         norm_factor = math.sqrt(coeff *
                                 math.sqrt(self.hidden_size_per_attention_head))
         # Raw attention scores. [b, np, s, s]
-        return torch.matmul(query_layer/norm_factor,
-                            key_layer.transpose(-1, -2)/norm_factor)
+        return torch.matmul(query_layer / norm_factor,
+                            key_layer.transpose(-1, -2) / norm_factor)
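A small numeric check of the trick in `_get_unmasked_attention_scores`: dividing each operand by the square root of the intended divisor is mathematically the same as dividing the product once by the divisor, but it keeps the intermediate values smaller, which matters in fp16.

```python
import math
import torch

b, heads, s, hn = 2, 4, 16, 64
q = torch.randn(b, heads, s, hn)
k = torch.randn(b, heads, s, hn)

divisor = math.sqrt(hn)  # the usual 1/sqrt(head_dim) attention scaling
scores_once = torch.matmul(q, k.transpose(-1, -2)) / divisor
scores_split = torch.matmul(q / math.sqrt(divisor),
                            k.transpose(-1, -2) / math.sqrt(divisor))
print(torch.allclose(scores_once, scores_split, atol=1e-5))  # True
```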
-
     def _get_attention_probs(self, attention_scores):
         """Attention probabilies with dropout. The output has
@@ -198,7 +194,6 @@ class ParallelSelfAttention(MegatronModule):
         return attention_probs

-
     def _get_attended_context(self, attention_probs, value_layer):
         """Final attended tesnor and transposed back to [b, s, hp]."""
         # Context layer.
@@ -213,7 +208,6 @@ class ParallelSelfAttention(MegatronModule):
         return context_layer

-
     def _get_output(self, context_layer):
         """Output layer with dropout."""
         # Output. [b, s, h]
@@ -222,7 +216,6 @@ class ParallelSelfAttention(MegatronModule):
         return output

-
     def forward(self, hidden_states, attention_mask, layer_past=None,
                 get_key_value=False):
         # hidden_states: [b, s, h]
@@ -254,7 +247,7 @@ class ParallelSelfAttention(MegatronModule):
             if layer_past is not None:
                 attention_mask = attention_mask[
                     ...,
-                    attention_scores.size(3)-1,
+                    attention_scores.size(3) - 1,
                     :attention_scores.size(3)].unsqueeze(2)
             else:
                 attention_mask = attention_mask[
@@ -283,13 +276,13 @@ class ParallelSelfAttention(MegatronModule):
         return output

+
 class ParallelTransformerLayer(MegatronModule):
     """A single transformer layer.
     Transformore layer takes input with size [b, s, h] and returns an
     output of the same size.
     """
+
     def __init__(self, attention_mask_func, mlp_activation_func,
                  init_method, output_layer_init_method, layer_number):
         args = get_args()
@@ -319,7 +312,6 @@ class ParallelTransformerLayer(MegatronModule):
         self.mlp = ParallelMLP(mlp_activation_func, init_method,
                                output_layer_init_method)

-
     def forward(self, hidden_states, attention_mask, layer_past=None,
                 get_key_value=False):
         # hidden_states: [b, s, h]
@@ -375,14 +367,13 @@ class ParallelTransformer(MegatronModule):
         # Transformer layers.
         self.layers = torch.nn.ModuleList(
-            [get_layer(i+1) for i in range(args.num_layers)])
+            [get_layer(i + 1) for i in range(args.num_layers)])
         # Final layer norm before output.
         self.final_layernorm = LayerNorm(
             args.hidden_size,
             eps=args.layernorm_epsilon)

-
     def _checkpointed_forward(self, hidden_states, attention_mask):
         """Forward method with activation checkpointing."""
         def custom(start, end):
@@ -398,13 +389,12 @@ class ParallelTransformer(MegatronModule):
         num_layers = len(self.layers)
         while l < num_layers:
             hidden_states = mpu.checkpoint(
-                custom(l, l+self.checkpoint_num_layers),
+                custom(l, l + self.checkpoint_num_layers),
                 hidden_states, attention_mask)
             l += self.checkpoint_num_layers
         return hidden_states

-
     def forward(self, hidden_states, attention_mask, layer_past=None,
                 get_key_value=False):
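The `_checkpointed_forward` loop above runs the layer stack in chunks of `checkpoint_num_layers`, recomputing each chunk's activations during backward to save memory. A standalone sketch using `torch.utils.checkpoint` in place of the mpu wrapper (the layer structure here is only illustrative):

```python
import torch
from torch.utils.checkpoint import checkpoint

layers = torch.nn.ModuleList([torch.nn.Linear(16, 16) for _ in range(8)])
checkpoint_num_layers = 2


def custom(start, end):
    """Return a callable that runs layers[start:end] sequentially."""
    def custom_forward(hidden_states):
        for layer in layers[start:end]:
            hidden_states = torch.relu(layer(hidden_states))
        return hidden_states
    return custom_forward


def checkpointed_forward(hidden_states):
    l = 0
    while l < len(layers):
        # Each chunk's activations are recomputed in backward.
        hidden_states = checkpoint(custom(l, l + checkpoint_num_layers),
                                   hidden_states, use_reentrant=False)
        l += checkpoint_num_layers
    return hidden_states


out = checkpointed_forward(torch.randn(4, 16, requires_grad=True))
out.sum().backward()
```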
@@ -33,6 +33,7 @@ def init_method_normal(sigma):
 def scaled_init_method_normal(sigma, num_layers):
     """Init method based on N(0, sigma/sqrt(2*num_layers)."""
     std = sigma / math.sqrt(2.0 * num_layers)
+
     def init_(tensor):
         return torch.nn.init.normal_(tensor, mean=0.0, std=std)
@@ -54,6 +55,7 @@ def gelu_impl(x):
     return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x *
                                        (1.0 + 0.044715 * x * x)))

+
 def gelu(x):
     return gelu_impl(x)
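`gelu_impl` above is the tanh approximation of GELU, 0.5·x·(1 + tanh(√(2/π)·(x + 0.044715·x³))), where 0.7978845608028654 is √(2/π). A quick check against torch's exact erf-based GELU:

```python
import torch


def gelu_tanh(x):
    return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x *
                                       (1.0 + 0.044715 * x * x)))


x = torch.linspace(-4, 4, steps=9)
diff = (gelu_tanh(x) - torch.nn.functional.gelu(x)).abs().max()
print(diff)  # the two curves agree to within roughly 1e-3 on this range
```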
@@ -21,11 +21,9 @@ import torch
 class MegatronModule(torch.nn.Module):
     """Megatron specific extentions of torch Module."""

-
     def __init__(self):
         super(MegatronModule, self).__init__()

-
     def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                        keep_vars=False):
         """Use this function to override the state dict for

@@ -72,7 +72,6 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
                                      op=torch.distributed.ReduceOp.SUM,
                                      group=get_model_parallel_group())

-
         # Loss = log(sum(exp(logits))) - predicted-logit.
         loss = torch.log(sum_exp_logits) - predicted_logits
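The comment above states the usual log-softmax identity, loss = log Σⱼ exp(zⱼ) − z_target, which is what lets the vocab-parallel version all-reduce a few per-token scalars (the max, the exp-sum, and the target logit) instead of gathering the full logits. A single-process numeric check of the identity:

```python
import torch

logits = torch.randn(3, 11)                 # [tokens, vocab]
target = torch.tensor([1, 4, 7])

predicted = logits[torch.arange(3), target]
loss_manual = torch.log(torch.exp(logits).sum(dim=-1)) - predicted
loss_torch = torch.nn.functional.cross_entropy(logits, target, reduction='none')
print(torch.allclose(loss_manual, loss_torch, atol=1e-6))  # True
```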
@@ -89,6 +89,7 @@ class VocabParallelEmbedding(torch.nn.Module):
         embedding_dim: size of hidden state.
         init_method: method to initialize weights.
     """
+
     def __init__(self, num_embeddings, embedding_dim,
                  init_method=init.xavier_normal_):
         super(VocabParallelEmbedding, self).__init__()
@@ -147,6 +148,7 @@ class ParallelEmbedding(torch.nn.Module):
         embedding_dim: size of hidden state.
         init_method: method to initialize weights.
     """
+
     def __init__(self, num_embeddings, embedding_dim,
                  init_method=init.xavier_normal_,
                  keep_master_weight_for_test=False):
@@ -205,6 +207,7 @@ class ColumnParallelLinear(torch.nn.Module):
                        set to False. It returns the master weights
                        used for initialization.
     """
+
     def __init__(self, input_size, output_size, bias=True, gather_output=True,
                  init_method=init.xavier_normal_, stride=1,
                  keep_master_weight_for_test=False):
@@ -279,6 +282,7 @@ class RowParallelLinear(torch.nn.Module):
                        set to False. It returns the master weights
                        used for initialization.
     """
+
     def __init__(self, input_size, output_size, bias=True,
                  input_is_parallel=False,
                  init_method=init.xavier_normal_, stride=1,
@@ -327,4 +331,3 @@ class RowParallelLinear(torch.nn.Module):
         else:
             output = output_
         return output
-
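A shape-only sketch of how `ColumnParallelLinear` and `RowParallelLinear` above split a weight matrix across tensor-parallel ranks (plain tensors, no communication; the real layers add the all-reduce/gather and bias handling):

```python
import torch

hidden, world_size = 8, 2
x = torch.randn(5, hidden)                   # [tokens, h]

# Column parallel: W is [h, 4h]; each rank owns a slice of the output columns.
w = torch.randn(hidden, 4 * hidden)
w_cols = w.chunk(world_size, dim=1)          # each [h, 4h / world_size]
partials = [x @ wc for wc in w_cols]
assert torch.allclose(torch.cat(partials, dim=-1), x @ w, atol=1e-5)

# Row parallel: W is [4h, h]; each rank owns a slice of the input rows, and
# the partial products are summed (the all-reduce in the real layer).
y = torch.randn(5, 4 * hidden)
w2 = torch.randn(4 * hidden, hidden)
w2_rows = w2.chunk(world_size, dim=0)        # each [4h / world_size, h]
y_chunks = y.chunk(world_size, dim=1)
summed = sum(yc @ wr for yc, wr in zip(y_chunks, w2_rows))
assert torch.allclose(summed, y @ w2, atol=1e-4)
```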
@@ -131,11 +131,14 @@ class _GatherFromModelParallelRegion(torch.autograd.Function):
 def copy_to_model_parallel_region(input_):
     return _CopyToModelParallelRegion.apply(input_)

+
 def reduce_from_model_parallel_region(input_):
     return _ReduceFromModelParallelRegion.apply(input_)

+
 def scatter_to_model_parallel_region(input_):
     return _ScatterToModelParallelRegion.apply(input_)

+
 def gather_from_model_parallel_region(input_):
     return _GatherFromModelParallelRegion.apply(input_)

@@ -73,6 +73,7 @@ class CudaRNGStatesTracker:
     rng state, we can perform operations and return to our starting
     cuda state.
     """
+
     def __init__(self):
         # Map from a string name to the cuda rng state.
         self.states_ = {}