Commit b7f1b050 authored by Neel Kant's avatar Neel Kant
Browse files

Lint whole repo

parent c99fa80c
......@@ -27,7 +27,8 @@ try:
from functools import lru_cache
except ImportError:
# Just a dummy decorator to get the checks to run on python2
# because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now.
# because honestly I don't want to support a byte-level unicode BPE
# tokenizer on python 2 right now.
def lru_cache():
return lambda func: func
......@@ -48,6 +49,7 @@ VOCAB_NAME = 'vocab.json'
MERGES_NAME = 'merges.txt'
SPECIAL_TOKENS_NAME = 'special_tokens.txt'
@lru_cache()
def bytes_to_unicode():
"""
......@@ -60,17 +62,19 @@ def bytes_to_unicode():
And avoids mapping to whitespace/control characters the bpe code barfs on.
"""
_chr = unichr if sys.version_info[0] == 2 else chr
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + \
list(range(ord("®"), ord("ÿ") + 1))
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8+n)
cs.append(2**8 + n)
n += 1
cs = [_chr(n) for n in cs]
return dict(zip(bs, cs))
def get_pairs(word):
"""Return set of symbol pairs in a word.
......@@ -83,6 +87,7 @@ def get_pairs(word):
prev_char = char
return pairs
class GPT2Tokenizer(object):
"""
GPT-2 BPE tokenizer. Peculiarities:
......@@ -138,23 +143,31 @@ class GPT2Tokenizer(object):
special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
else:
special_tokens = kwargs.pop('special_tokens', [])
tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs)
tokenizer = cls(
resolved_vocab_file,
resolved_merges_file,
special_tokens=special_tokens,
*inputs,
**kwargs)
return tokenizer
def __init__(self, vocab_file, merges_file, errors='replace', special_tokens=None, max_len=None):
def __init__(self, vocab_file, merges_file, errors='replace',
special_tokens=None, max_len=None):
self.max_len = max_len if max_len is not None else int(1e12)
self.encoder = json.load(open(vocab_file))
self.decoder = {v:k for k,v in self.encoder.items()}
self.errors = errors # how to handle errors in decoding
self.decoder = {v: k for k, v in self.encoder.items()}
self.errors = errors # how to handle errors in decoding
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v:k for k, v in self.byte_encoder.items()}
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
bpe_merges = [tuple(merge.split()) for merge in bpe_data]
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
self.cache = {}
# Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
# Should haved added re.IGNORECASE so BPE merges can happen for
# capitalized versions of contractions
self.pat = re.compile(
r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
self.special_tokens = {}
self.special_tokens_decoder = {}
......@@ -172,8 +185,9 @@ class GPT2Tokenizer(object):
self.special_tokens = {}
self.special_tokens_decoder = {}
return
self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()}
self.special_tokens = dict((tok, len(self.encoder) + i)
for i, tok in enumerate(special_tokens))
self.special_tokens_decoder = {v: k for k, v in self.special_tokens.items()}
logger.info("Special tokens {}".format(self.special_tokens))
def bpe(self, token):
......@@ -186,7 +200,7 @@ class GPT2Tokenizer(object):
return token
while True:
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
if bigram not in self.bpe_ranks:
break
first, second = bigram
......@@ -197,12 +211,12 @@ class GPT2Tokenizer(object):
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
except:
except BaseException:
new_word.extend(word[i:])
break
if word[i] == first and i < len(word)-1 and word[i+1] == second:
new_word.append(first+second)
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
new_word.append(first + second)
i += 2
else:
new_word.append(word[i])
......@@ -245,7 +259,8 @@ class GPT2Tokenizer(object):
logger.warning(
"Token indices sequence length is longer than the specified maximum "
" sequence length for this OpenAI GPT model ({} > {}). Running this"
" sequence through the model will result in indexing errors".format(len(ids), self.max_len)
" sequence through the model will result in indexing errors".format(
len(ids), self.max_len)
)
return ids
......
......@@ -99,19 +99,19 @@ class BertTokenizer(object):
[(ids, tok) for tok, ids in self.vocab.items()])
self.do_basic_tokenize = do_basic_tokenize
if do_basic_tokenize:
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case,
never_split=never_split)
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case,
never_split=never_split)
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
self.max_len = max_len if max_len is not None else int(1e12)
def tokenize(self, text):
if self.do_basic_tokenize:
split_tokens = []
for token in self.basic_tokenizer.tokenize(text):
for sub_token in self.wordpiece_tokenizer.tokenize(token):
split_tokens.append(sub_token)
split_tokens = []
for token in self.basic_tokenizer.tokenize(text):
for sub_token in self.wordpiece_tokenizer.tokenize(token):
split_tokens.append(sub_token)
else:
split_tokens = self.wordpiece_tokenizer.tokenize(text)
split_tokens = self.wordpiece_tokenizer.tokenize(text)
return split_tokens
def convert_tokens_to_ids(self, tokens):
......@@ -123,7 +123,8 @@ class BertTokenizer(object):
logger.warning(
"Token indices sequence length is longer than the specified maximum "
" sequence length for this BERT model ({} > {}). Running this"
" sequence through BERT will result in indexing errors".format(len(ids), self.max_len)
" sequence through BERT will result in indexing errors".format(
len(ids), self.max_len)
)
return ids
......
This diff is collapsed.
......@@ -102,6 +102,7 @@ class FP16Model(nn.Module):
def backwards_debug_hook(grad):
raise RuntimeError("master_params recieved a gradient in the backward pass!")
def prep_param_lists(model, flat_master=False):
"""
Creates a list of FP32 master parameters for a given model, as in
......@@ -131,9 +132,9 @@ def prep_param_lists(model, flat_master=False):
# flatten_dense_tensors returns a contiguous flat array.
# http://pytorch.org/docs/master/_modules/torch/_utils.html
master_params = _flatten_dense_tensors([param.data for param in model_params]).float()
except:
except BaseException:
print("Error in prep_param_lists: model may contain a mixture of parameters "
"of different types. Use flat_master=False, or use F16_Optimizer.")
"of different types. Use flat_master=False, or use F16_Optimizer.")
raise
master_params = torch.nn.Parameter(master_params)
master_params.requires_grad = True
......@@ -150,7 +151,7 @@ def prep_param_lists(model, flat_master=False):
def model_grads_to_master_grads(model_params, master_params, flat_master=False):
"""
Copy model gradients to master gradients.
Copy model gradients to master gradients.
Args:
model_params: List of model parameters created by :func:`prep_param_lists`.
......@@ -179,7 +180,7 @@ def master_params_to_model_params(model_params, master_params, flat_master=False
master_params: List of FP32 master parameters created by :func:`prep_param_lists`. If ``master_params`` was created with ``flat_master=True``, ``flat_master=True`` should also be supplied to :func:`master_params_to_model_params`.
"""
if flat_master:
for model, master in zip(model_params,
for model, master in zip(model_params,
_unflatten_dense_tensors(master_params[0].data, model_params)):
model.data.copy_(master)
else:
......@@ -188,17 +189,19 @@ def master_params_to_model_params(model_params, master_params, flat_master=False
# Backward compatibility fixes
def to_python_float(t):
if hasattr(t, 'item'):
return t.item()
else:
return t[0]
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
clip_grad_norm = mpu.clip_grad_norm
#elif TORCH_MAJOR == 0 and TORCH_MINOR <= 4:
# elif TORCH_MAJOR == 0 and TORCH_MINOR <= 4:
# clip_grad_norm = torch.nn.utils.clip_grad_norm
#else:
# else:
# clip_grad_norm = torch.nn.utils.clip_grad_norm_
......@@ -17,18 +17,21 @@ import torch
from megatron import mpu
# item() is a recent addition, so this helps with backward compatibility.
def to_python_float(t):
if hasattr(t, 'item'):
return t.item()
else:
return t[0]
class LossScaler:
"""
Class that manages a static loss scale. This class is intended to interact with
:class:`FP16_Optimizer`, and should not be directly manipulated by the user.
Use of :class:`LossScaler` is enabled via the ``static_loss_scale`` argument to
Use of :class:`LossScaler` is enabled via the ``static_loss_scale`` argument to
:class:`FP16_Optimizer`'s constructor.
Args:
......@@ -57,13 +60,14 @@ class LossScaler:
return tuple(self.loss_scale * g for g in grad_in)
def backward(self, loss, retain_graph=False):
scaled_loss = loss*self.loss_scale
scaled_loss = loss * self.loss_scale
scaled_loss.backward(retain_graph=retain_graph)
class DynamicLossScaler:
"""
Class that manages dynamic loss scaling. It is recommended to use :class:`DynamicLossScaler`
indirectly, by supplying ``dynamic_loss_scale=True`` to the constructor of
indirectly, by supplying ``dynamic_loss_scale=True`` to the constructor of
:class:`FP16_Optimizer`. However, it's important to understand how :class:`DynamicLossScaler`
operates, because the default options can be changed using the
the ``dynamic_loss_args`` argument to :class:`FP16_Optimizer`'s constructor.
......@@ -71,18 +75,18 @@ class DynamicLossScaler:
Loss scaling is designed to combat the problem of underflowing gradients encountered at long
times when training fp16 networks. Dynamic loss scaling begins by attempting a very high loss
scale. Ironically, this may result in OVERflowing gradients. If overflowing gradients are
encountered, :class:`DynamicLossScaler` informs :class:`FP16_Optimizer` that an overflow has
encountered, :class:`DynamicLossScaler` informs :class:`FP16_Optimizer` that an overflow has
occurred.
:class:`FP16_Optimizer` then skips the update step for this particular iteration/minibatch,
and :class:`DynamicLossScaler` adjusts the loss scale to a lower value.
and :class:`DynamicLossScaler` adjusts the loss scale to a lower value.
If a certain number of iterations occur without overflowing gradients detected,
:class:`DynamicLossScaler` increases the loss scale once more.
In this way :class:`DynamicLossScaler` attempts to "ride the edge" of
In this way :class:`DynamicLossScaler` attempts to "ride the edge" of
always using the highest loss scale possible without incurring overflow.
Args:
init_scale (float, optional, default=2**32): Initial loss scale attempted by :class:`DynamicLossScaler.`
scale_factor (float, optional, default=2.0): Factor used when adjusting the loss scale. If an overflow is encountered, the loss scale is readjusted to loss scale/``scale_factor``. If ``scale_window`` consecutive iterations take place without an overflow, the loss scale is readjusted to loss_scale*``scale_factor``.
scale_factor (float, optional, default=2.0): Factor used when adjusting the loss scale. If an overflow is encountered, the loss scale is readjusted to loss scale/``scale_factor``. If ``scale_window`` consecutive iterations take place without an overflow, the loss scale is readjusted to loss_scale*``scale_factor``.
scale_window (int, optional, default=1000): Number of consecutive iterations without an overflow to wait before increasing the loss scale.
"""
......@@ -122,12 +126,12 @@ class DynamicLossScaler:
overflow = overflow_gpu[0].item()
return bool(overflow)
# `x` is a torch.Tensor
def _has_inf_or_nan(x):
try:
# if x is half, the .float() incurs an additional deep copy, but it's necessary if
# Pytorch's .sum() creates a one-element tensor of the same type as x
# if x is half, the .float() incurs an additional deep copy, but it's necessary if
# Pytorch's .sum() creates a one-element tensor of the same type as x
# (which is true for some recent version of pytorch).
cpu_sum = float(x.float().sum())
# More efficient version that can be used if .sum() returns a Python scalar
......@@ -158,7 +162,7 @@ class DynamicLossScaler:
if overflow:
# self.cur_scale /= self.scale_factor
if self.delayed_shift == 1 or self.cur_hysteresis == 1:
self.cur_scale = max(self.cur_scale/self.scale_factor, self.min_scale)
self.cur_scale = max(self.cur_scale / self.scale_factor, self.min_scale)
else:
self.cur_hysteresis -= 1
self.last_overflow_iter = self.cur_iter
......@@ -179,10 +183,11 @@ class DynamicLossScaler:
return tuple(self.loss_scale * g for g in grad_in)
def backward(self, loss, retain_graph=False):
scaled_loss = loss*self.loss_scale
scaled_loss = loss * self.loss_scale
scaled_loss.backward(retain_graph=retain_graph)
##############################################################
##############################################################
# Example usage below here -- assuming it's in a separate file
##############################################################
"""
......@@ -218,10 +223,10 @@ if __name__ == "__main__":
# Run backprop
optimizer.zero_grad()
loss.backward()
# Check for overflow
has_overflow = DynamicLossScaler.has_overflow(parameters)
# If no overflow, unscale grad and update as usual
if not has_overflow:
for param in parameters:
......
......@@ -124,7 +124,7 @@ def _set_adlr_autoresume(args):
sys.path.append(os.environ.get('SUBMIT_SCRIPTS', '.'))
try:
from userlib.auto_resume import AutoResume
except:
except BaseException:
print('ADLR autoresume is not available, exiting ...')
sys.exit()
......
......@@ -48,7 +48,6 @@ class AnnealingLR(object):
print_rank_0('> learning rate decay style: {}'.format(self.decay_style))
def get_lr(self):
"""Learning rate decay functions from:
https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""
......@@ -71,7 +70,6 @@ class AnnealingLR(object):
lr = self.start_lr
return max(lr, self.min_lr)
def step(self, step_num=None):
"""Set lr for all parameters groups."""
if step_num is None:
......@@ -81,7 +79,6 @@ class AnnealingLR(object):
for group in self.optimizer.param_groups:
group['lr'] = new_lr
def state_dict(self):
state_dict = {
'start_lr': self.start_lr,
......@@ -93,7 +90,6 @@ class AnnealingLR(object):
}
return state_dict
def _check_and_set(self, cls_value, sd_value, name):
"""Auxiliary function for checking the values in the checkpoint and
setting them."""
......@@ -108,7 +104,6 @@ class AnnealingLR(object):
name))
return sd_value
def load_state_dict(self, sd):
self.start_lr = self._check_and_set(self.start_lr, sd['start_lr'],
......
......@@ -66,7 +66,6 @@ def bert_position_ids(token_ids):
return position_ids
class BertLMHead(MegatronModule):
"""Masked LM head for Bert
......@@ -77,6 +76,7 @@ class BertLMHead(MegatronModule):
layernorm_epsilon: tolerance for layer norm divisions
parallel_output: wether output logits being distributed or not.
"""
def __init__(self, mpu_vocab_size, hidden_size, init_method,
layernorm_epsilon, parallel_output):
......@@ -91,7 +91,6 @@ class BertLMHead(MegatronModule):
self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
self.layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
def forward(self, hidden_states, word_embeddings_weight):
hidden_states = self.dense(hidden_states)
hidden_states = gelu(hidden_states)
......@@ -103,7 +102,6 @@ class BertLMHead(MegatronModule):
return output
class BertModel(MegatronModule):
"""Bert Language model."""
......@@ -136,7 +134,6 @@ class BertModel(MegatronModule):
init_method)
self._binary_head_key = 'binary_head'
def forward(self, input_ids, attention_mask, tokentype_ids=None):
extended_attention_mask = bert_extended_attention_mask(
......@@ -166,7 +163,6 @@ class BertModel(MegatronModule):
return lm_logits, None
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
"""For easy load when model is combined with other heads,
......@@ -184,7 +180,6 @@ class BertModel(MegatronModule):
= self.binary_head.state_dict(destination, prefix, keep_vars)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Customized load."""
......
......@@ -53,7 +53,6 @@ class Classification(MegatronModule):
init_method)
self._classification_head_key = 'classification_head'
def forward(self, input_ids, attention_mask, tokentype_ids):
extended_attention_mask = bert_extended_attention_mask(
......@@ -74,7 +73,6 @@ class Classification(MegatronModule):
return classification_logits
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
"""For easy load when model is combined with other heads,
......@@ -89,7 +87,6 @@ class Classification(MegatronModule):
destination, prefix, keep_vars)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Customized load."""
......
......@@ -71,8 +71,8 @@ class DistributedDataParallel(MegatronModule):
def allreduce_hook(*unused):
Variable._execution_engine.queue_callback(allreduce_params)
# handle = param.register_hook(allreduce_hook)
#self.hooks.append(allreduce_hook)
#self.hook_handles.append(handle)
# self.hooks.append(allreduce_hook)
# self.hook_handles.append(handle)
self.allreduce_params = allreduce_params
def forward(self, *inputs, **kwargs):
......@@ -114,4 +114,3 @@ class DistributedDataParallel(MegatronModule):
super(DistributedDataParallel, self).train(mode)
self.module.train(mode)
'''
......@@ -28,7 +28,7 @@ from .utils import scaled_init_method_normal
def gpt2_attention_mask_func(attention_scores, ltor_mask):
attention_scores = torch.mul(attention_scores, ltor_mask) - \
10000.0 * (1.0 - ltor_mask)
10000.0 * (1.0 - ltor_mask)
return attention_scores
......@@ -49,7 +49,6 @@ class GPT2Model(MegatronModule):
scaled_init_method=scaled_init_method_normal(args.init_method_std,
args.num_layers))
def forward(self, input_ids, position_ids, attention_mask,
tokentype_ids=None, layer_past=None, get_key_value=False,
forward_method_parallel_output=None):
......@@ -79,7 +78,6 @@ class GPT2Model(MegatronModule):
return output
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
......@@ -89,7 +87,6 @@ class GPT2Model(MegatronModule):
destination, prefix, keep_vars)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Customized load."""
......
......@@ -62,7 +62,6 @@ def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
return language_model, language_model_key
class Pooler(MegatronModule):
"""Pooler layer.
......@@ -74,11 +73,11 @@ class Pooler(MegatronModule):
init_method: weight initialization method for the linear layer.
bias is set to zero.
"""
def __init__(self, hidden_size, init_method):
super(Pooler, self).__init__()
self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
def forward(self, hidden_states, sequence_index=0):
# hidden_states: [b, s, h]
# sequence_index: index of the token to pool.
......@@ -101,6 +100,7 @@ class Embedding(MegatronModule):
num_tokentypes: size of the token-type embeddings. 0 value
will ignore this embedding
"""
def __init__(self,
hidden_size,
vocab_size,
......@@ -142,7 +142,6 @@ class Embedding(MegatronModule):
# Embeddings dropout
self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
def add_tokentype_embeddings(self, num_tokentypes):
"""Add token-type embedding. This function is provided so we can add
token-type embeddings in case the pretrained model does not have it.
......@@ -159,7 +158,6 @@ class Embedding(MegatronModule):
# Initialize the token-type embeddings.
self.init_method(self.tokentype_embeddings.weight)
def forward(self, input_ids, position_ids, tokentype_ids=None):
# Embeddings.
words_embeddings = self.word_embeddings(input_ids)
......@@ -176,7 +174,6 @@ class Embedding(MegatronModule):
return embeddings
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
"""For easy load."""
......@@ -194,7 +191,6 @@ class Embedding(MegatronModule):
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Customized load."""
......@@ -223,7 +219,7 @@ class Embedding(MegatronModule):
self.position_embeddings.load_state_dict(state_dict_, strict=strict)
# Tokentype embedding.
if self.num_tokentypes > 0:
if self.num_tokentypes > 0:
state_dict_ = {}
if self._tokentype_embeddings_key in state_dict:
state_dict_ = state_dict[self._tokentype_embeddings_key]
......@@ -241,7 +237,6 @@ class Embedding(MegatronModule):
'checkpoint but could not find it', flush=True)
class TransformerLanguageModel(MegatronModule):
"""Transformer language model.
......@@ -260,6 +255,7 @@ class TransformerLanguageModel(MegatronModule):
num_tokentypes: size of the token-type embeddings. 0 value
will ignore this embedding
"""
def __init__(self,
attention_mask_func,
mlp_activation_func,
......@@ -295,7 +291,6 @@ class TransformerLanguageModel(MegatronModule):
self.pooler = Pooler(self.hidden_size, self.init_method)
self._pooler_key = 'pooler'
def forward(self, input_ids, position_ids, attention_mask,
tokentype_ids=None, layer_past=None, get_key_value=False,
pooling_sequence_index=0):
......@@ -317,7 +312,6 @@ class TransformerLanguageModel(MegatronModule):
return transformer_output
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
"""For easy load."""
......@@ -336,7 +330,6 @@ class TransformerLanguageModel(MegatronModule):
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Customized load."""
......
......@@ -51,7 +51,6 @@ class MultipleChoice(MegatronModule):
init_method)
self._multichoice_head_key = 'multichoice_head'
def forward(self, input_ids, attention_mask, tokentype_ids):
# [batch, choices, sequence] --> [batch * choices, sequence] -->
......@@ -86,7 +85,6 @@ class MultipleChoice(MegatronModule):
return multichoice_logits
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
"""For easy load when model is combined with other heads,
......@@ -101,7 +99,6 @@ class MultipleChoice(MegatronModule):
destination, prefix, keep_vars)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Customized load."""
......
......@@ -46,6 +46,7 @@ from megatron.module import MegatronModule
unmaksed-attention-scores, attention-mask)
"""
class ParallelMLP(MegatronModule):
"""MLP.
......@@ -63,7 +64,7 @@ class ParallelMLP(MegatronModule):
# Project to 4h.
self.dense_h_to_4h = mpu.ColumnParallelLinear(
args.hidden_size,
4*args.hidden_size,
4 * args.hidden_size,
gather_output=False,
init_method=init_method)
......@@ -71,14 +72,13 @@ class ParallelMLP(MegatronModule):
# Project back to h.
self.dense_4h_to_h = mpu.RowParallelLinear(
4*args.hidden_size,
4 * args.hidden_size,
args.hidden_size,
input_is_parallel=True,
init_method=output_layer_init_method)
self.dropout = torch.nn.Dropout(args.hidden_dropout)
def forward(self, hidden_states):
# [b, s, 4hp]
......@@ -91,13 +91,13 @@ class ParallelMLP(MegatronModule):
return output
class ParallelSelfAttention(MegatronModule):
"""Parallel self-attention layer abstract class.
Self-attention layer takes input with size [b, s, h]
and returns output of the same size.
"""
def __init__(self, attention_mask_func, init_method,
output_layer_init_method, layer_number):
super(ParallelSelfAttention, self).__init__()
......@@ -123,7 +123,7 @@ class ParallelSelfAttention(MegatronModule):
# Strided linear layer.
self.query_key_value = mpu.ColumnParallelLinear(
args.hidden_size,
3*args.hidden_size,
3 * args.hidden_size,
stride=3,
gather_output=False,
init_method=init_method)
......@@ -141,18 +141,16 @@ class ParallelSelfAttention(MegatronModule):
init_method=output_layer_init_method)
self.output_dropout = torch.nn.Dropout(args.hidden_dropout)
def _transpose_for_scores(self, tensor):
"""Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with
size [b, np, s, hn].
"""
new_tensor_shape = tensor.size()[:-1] + \
(self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head)
(self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head)
tensor = tensor.view(*new_tensor_shape)
return tensor.permute(0, 2, 1, 3)
def _get_query_key_value(self, hidden_states):
"""Get query, key, and value and transpose to
get size [b, np, s, hn].
......@@ -170,7 +168,6 @@ class ParallelSelfAttention(MegatronModule):
return query_layer, key_layer, value_layer
def _get_unmasked_attention_scores(self, query_layer, key_layer):
"""Unmasked attention scores with size [b, np, s, s]."""
coeff = 1
......@@ -179,9 +176,8 @@ class ParallelSelfAttention(MegatronModule):
norm_factor = math.sqrt(coeff *
math.sqrt(self.hidden_size_per_attention_head))
# Raw attention scores. [b, np, s, s]
return torch.matmul(query_layer/norm_factor,
key_layer.transpose(-1, -2)/norm_factor)
return torch.matmul(query_layer / norm_factor,
key_layer.transpose(-1, -2) / norm_factor)
def _get_attention_probs(self, attention_scores):
"""Attention probabilies with dropout. The output has
......@@ -198,7 +194,6 @@ class ParallelSelfAttention(MegatronModule):
return attention_probs
def _get_attended_context(self, attention_probs, value_layer):
"""Final attended tesnor and transposed back to [b, s, hp]."""
# Context layer.
......@@ -207,13 +202,12 @@ class ParallelSelfAttention(MegatronModule):
# [b, s, np, hn]
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + \
(self.hidden_size_per_partition,)
(self.hidden_size_per_partition,)
# [b, s, hp]
context_layer = context_layer.view(*new_context_layer_shape)
return context_layer
def _get_output(self, context_layer):
"""Output layer with dropout."""
# Output. [b, s, h]
......@@ -222,7 +216,6 @@ class ParallelSelfAttention(MegatronModule):
return output
def forward(self, hidden_states, attention_mask, layer_past=None,
get_key_value=False):
# hidden_states: [b, s, h]
......@@ -254,7 +247,7 @@ class ParallelSelfAttention(MegatronModule):
if layer_past is not None:
attention_mask = attention_mask[
...,
attention_scores.size(3)-1,
attention_scores.size(3) - 1,
:attention_scores.size(3)].unsqueeze(2)
else:
attention_mask = attention_mask[
......@@ -283,13 +276,13 @@ class ParallelSelfAttention(MegatronModule):
return output
class ParallelTransformerLayer(MegatronModule):
"""A single transformer layer.
Transformore layer takes input with size [b, s, h] and returns an
output of the same size.
"""
def __init__(self, attention_mask_func, mlp_activation_func,
init_method, output_layer_init_method, layer_number):
args = get_args()
......@@ -319,7 +312,6 @@ class ParallelTransformerLayer(MegatronModule):
self.mlp = ParallelMLP(mlp_activation_func, init_method,
output_layer_init_method)
def forward(self, hidden_states, attention_mask, layer_past=None,
get_key_value=False):
# hidden_states: [b, s, h]
......@@ -375,14 +367,13 @@ class ParallelTransformer(MegatronModule):
# Transformer layers.
self.layers = torch.nn.ModuleList(
[get_layer(i+1) for i in range(args.num_layers)])
[get_layer(i + 1) for i in range(args.num_layers)])
# Final layer norm before output.
self.final_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon)
def _checkpointed_forward(self, hidden_states, attention_mask):
"""Forward method with activation checkpointing."""
def custom(start, end):
......@@ -398,13 +389,12 @@ class ParallelTransformer(MegatronModule):
num_layers = len(self.layers)
while l < num_layers:
hidden_states = mpu.checkpoint(
custom(l, l+self.checkpoint_num_layers),
custom(l, l + self.checkpoint_num_layers),
hidden_states, attention_mask)
l += self.checkpoint_num_layers
return hidden_states
def forward(self, hidden_states, attention_mask, layer_past=None,
get_key_value=False):
......
......@@ -33,6 +33,7 @@ def init_method_normal(sigma):
def scaled_init_method_normal(sigma, num_layers):
"""Init method based on N(0, sigma/sqrt(2*num_layers)."""
std = sigma / math.sqrt(2.0 * num_layers)
def init_(tensor):
return torch.nn.init.normal_(tensor, mean=0.0, std=std)
......@@ -54,6 +55,7 @@ def gelu_impl(x):
return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x *
(1.0 + 0.044715 * x * x)))
def gelu(x):
return gelu_impl(x)
......
......@@ -21,11 +21,9 @@ import torch
class MegatronModule(torch.nn.Module):
"""Megatron specific extentions of torch Module."""
def __init__(self):
super(MegatronModule, self).__init__()
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
"""Use this function to override the state dict for
......
......@@ -72,7 +72,6 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
op=torch.distributed.ReduceOp.SUM,
group=get_model_parallel_group())
# Loss = log(sum(exp(logits))) - predicted-logit.
loss = torch.log(sum_exp_logits) - predicted_logits
......
......@@ -89,6 +89,7 @@ class VocabParallelEmbedding(torch.nn.Module):
embedding_dim: size of hidden state.
init_method: method to initialize weights.
"""
def __init__(self, num_embeddings, embedding_dim,
init_method=init.xavier_normal_):
super(VocabParallelEmbedding, self).__init__()
......@@ -108,7 +109,7 @@ class VocabParallelEmbedding(torch.nn.Module):
self.num_embeddings, get_model_parallel_rank(),
get_model_parallel_world_size())
self.num_embeddings_per_partition = self.vocab_end_index - \
self.vocab_start_index
self.vocab_start_index
# Allocate weights.
self.weight = Parameter(torch.Tensor(self.num_embeddings_per_partition,
......@@ -147,6 +148,7 @@ class ParallelEmbedding(torch.nn.Module):
embedding_dim: size of hidden state.
init_method: method to initialize weights.
"""
def __init__(self, num_embeddings, embedding_dim,
init_method=init.xavier_normal_,
keep_master_weight_for_test=False):
......@@ -205,6 +207,7 @@ class ColumnParallelLinear(torch.nn.Module):
set to False. It returns the master weights
used for initialization.
"""
def __init__(self, input_size, output_size, bias=True, gather_output=True,
init_method=init.xavier_normal_, stride=1,
keep_master_weight_for_test=False):
......@@ -279,6 +282,7 @@ class RowParallelLinear(torch.nn.Module):
set to False. It returns the master weights
used for initialization.
"""
def __init__(self, input_size, output_size, bias=True,
input_is_parallel=False,
init_method=init.xavier_normal_, stride=1,
......@@ -327,4 +331,3 @@ class RowParallelLinear(torch.nn.Module):
else:
output = output_
return output
......@@ -131,11 +131,14 @@ class _GatherFromModelParallelRegion(torch.autograd.Function):
def copy_to_model_parallel_region(input_):
return _CopyToModelParallelRegion.apply(input_)
def reduce_from_model_parallel_region(input_):
return _ReduceFromModelParallelRegion.apply(input_)
def scatter_to_model_parallel_region(input_):
return _ScatterToModelParallelRegion.apply(input_)
def gather_from_model_parallel_region(input_):
return _GatherFromModelParallelRegion.apply(input_)
......@@ -73,6 +73,7 @@ class CudaRNGStatesTracker:
rng state, we can perform operations and return to our starting
cuda state.
"""
def __init__(self):
# Map from a string name to the cuda rng state.
self.states_ = {}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment