Commit 5e56e563 authored by Neel Kant's avatar Neel Kant
Browse files

Merge master into realm-mlm

parents 6c0a5bd8 569b3dab
......@@ -27,7 +27,8 @@ try:
from functools import lru_cache
except ImportError:
# Just a dummy decorator to get the checks to run on python2
# because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now.
# because honestly I don't want to support a byte-level unicode BPE
# tokenizer on python 2 right now.
def lru_cache():
return lambda func: func
......@@ -48,6 +49,7 @@ VOCAB_NAME = 'vocab.json'
MERGES_NAME = 'merges.txt'
SPECIAL_TOKENS_NAME = 'special_tokens.txt'
@lru_cache()
def bytes_to_unicode():
"""
......@@ -60,17 +62,19 @@ def bytes_to_unicode():
And avoids mapping to whitespace/control characters the bpe code barfs on.
"""
_chr = unichr if sys.version_info[0] == 2 else chr
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + \
list(range(ord("®"), ord("ÿ") + 1))
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8+n)
cs.append(2**8 + n)
n += 1
cs = [_chr(n) for n in cs]
return dict(zip(bs, cs))
def get_pairs(word):
"""Return set of symbol pairs in a word.
......@@ -83,6 +87,7 @@ def get_pairs(word):
prev_char = char
return pairs
class GPT2Tokenizer(object):
"""
GPT-2 BPE tokenizer. Peculiarities:
......@@ -138,23 +143,31 @@ class GPT2Tokenizer(object):
special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
else:
special_tokens = kwargs.pop('special_tokens', [])
tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs)
tokenizer = cls(
resolved_vocab_file,
resolved_merges_file,
special_tokens=special_tokens,
*inputs,
**kwargs)
return tokenizer
def __init__(self, vocab_file, merges_file, errors='replace', special_tokens=None, max_len=None):
def __init__(self, vocab_file, merges_file, errors='replace',
special_tokens=None, max_len=None):
self.max_len = max_len if max_len is not None else int(1e12)
self.encoder = json.load(open(vocab_file))
self.decoder = {v:k for k,v in self.encoder.items()}
self.decoder = {v: k for k, v in self.encoder.items()}
self.errors = errors # how to handle errors in decoding
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v:k for k, v in self.byte_encoder.items()}
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
bpe_merges = [tuple(merge.split()) for merge in bpe_data]
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
self.cache = {}
# Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
# Should haved added re.IGNORECASE so BPE merges can happen for
# capitalized versions of contractions
self.pat = re.compile(
r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
self.special_tokens = {}
self.special_tokens_decoder = {}
......@@ -172,8 +185,9 @@ class GPT2Tokenizer(object):
self.special_tokens = {}
self.special_tokens_decoder = {}
return
self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()}
self.special_tokens = dict((tok, len(self.encoder) + i)
for i, tok in enumerate(special_tokens))
self.special_tokens_decoder = {v: k for k, v in self.special_tokens.items()}
logger.info("Special tokens {}".format(self.special_tokens))
def bpe(self, token):
......@@ -186,7 +200,7 @@ class GPT2Tokenizer(object):
return token
while True:
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
if bigram not in self.bpe_ranks:
break
first, second = bigram
......@@ -197,12 +211,12 @@ class GPT2Tokenizer(object):
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
except:
except BaseException:
new_word.extend(word[i:])
break
if word[i] == first and i < len(word)-1 and word[i+1] == second:
new_word.append(first+second)
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
new_word.append(first + second)
i += 2
else:
new_word.append(word[i])
......@@ -245,7 +259,8 @@ class GPT2Tokenizer(object):
logger.warning(
"Token indices sequence length is longer than the specified maximum "
" sequence length for this OpenAI GPT model ({} > {}). Running this"
" sequence through the model will result in indexing errors".format(len(ids), self.max_len)
" sequence through the model will result in indexing errors".format(
len(ids), self.max_len)
)
return ids
......
......@@ -123,7 +123,8 @@ class BertTokenizer(object):
logger.warning(
"Token indices sequence length is longer than the specified maximum "
" sequence length for this BERT model ({} > {}). Running this"
" sequence through BERT will result in indexing errors".format(len(ids), self.max_len)
" sequence through BERT will result in indexing errors".format(
len(ids), self.max_len)
)
return ids
......
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -22,12 +22,15 @@ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from .loss_scaler import DynamicLossScaler, LossScaler
from .fp16util import model_grads_to_master_grads, master_params_to_model_params, clip_grad_norm
from apex.multi_tensor_apply import multi_tensor_applier
import amp_C
from megatron.module import MegatronModule
FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)
def conversion_helper(val, conversion):
"""Apply conversion to val. Recursively apply conversion if `val` is a nested tuple/list structure."""
if not isinstance(val, (tuple, list)):
......@@ -37,6 +40,7 @@ def conversion_helper(val, conversion):
rtn = tuple(rtn)
return rtn
def fp32_to_fp16(val):
"""Convert fp32 `val` to fp16"""
def half_conversion(val):
......@@ -48,6 +52,7 @@ def fp32_to_fp16(val):
return val
return conversion_helper(val, half_conversion)
def fp16_to_fp32(val):
"""Convert fp16 `val` to fp32"""
def float_conversion(val):
......@@ -59,6 +64,7 @@ def fp16_to_fp32(val):
return val
return conversion_helper(val, float_conversion)
class FP16_Module(MegatronModule):
def __init__(self, module):
super(FP16_Module, self).__init__()
......@@ -79,6 +85,8 @@ class FP16_Module(MegatronModule):
self.module.load_state_dict(state_dict, strict=strict)
# TODO: Update overflow check + downscale to use Carl's fused kernel.
class FP16_Optimizer(object):
"""
:class:`FP16_Optimizer` is designed to wrap an existing PyTorch optimizer,
......@@ -305,7 +313,8 @@ class FP16_Optimizer(object):
master_params_to_model_params(fp32_from_fp16_group, fp16_group)
# To consider: Integrate distributed with this wrapper by registering a hook on each variable
# that does the overflow check, gradient copy + downscale, and fp32 allreduce in a different stream.
# that does the overflow check, gradient copy + downscale, and fp32
# allreduce in a different stream.
def _model_grads_to_master_grads(self):
for fp16_group, fp32_from_fp16_group in zip(self.fp16_groups, self.fp32_from_fp16_groups):
model_grads_to_master_grads(fp16_group, fp32_from_fp16_group)
......@@ -313,9 +322,12 @@ class FP16_Optimizer(object):
def _downscale_master(self):
if self.loss_scale != 1.0:
for group in self.optimizer.param_groups:
for param in group['params']:
if param.grad is not None:
param.grad.data.mul_(1./self.loss_scale)
grads = [p.grad for p in group['params'] if p.grad is not None]
_overflow_buf = torch.cuda.IntTensor([0])
multi_tensor_applier(amp_C.multi_tensor_scale,
_overflow_buf,
[grads, grads],
1./self.loss_scale)
def clip_master_grads(self, max_norm, norm_type=2):
"""
......@@ -400,7 +412,8 @@ class FP16_Optimizer(object):
# the current optimizer instance. In our case, as long as the current FP16_Optimizer has been
# constructed in the same way as the one whose state_dict we are loading, the same master params
# are guaranteed to exist, so we can just copy_() from the saved master params.
for current_group, saved_group in zip(self.fp32_from_fp16_groups, state_dict['fp32_from_fp16']):
for current_group, saved_group in zip(
self.fp32_from_fp16_groups, state_dict['fp32_from_fp16']):
for current, saved in zip(current_group, saved_group):
current.data.copy_(saved.data)
......@@ -570,7 +583,8 @@ class FP16_Optimizer(object):
"""
if self.dynamic_loss_scale:
self._check_overflow()
if self.overflow: return
if self.overflow:
return
self._model_grads_to_master_grads()
self._downscale_master()
......@@ -607,8 +621,8 @@ class FP16_Optimizer(object):
master_grads_data.append(master_grads_this_group)
return master_grads_data
# Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale"
def _get_loss_scale(self):
return self.loss_scaler.loss_scale
......
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -18,6 +18,9 @@ import torch.nn as nn
from torch.autograd import Variable
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from apex.multi_tensor_apply import multi_tensor_applier
import amp_C
from megatron import mpu
......@@ -102,6 +105,7 @@ class FP16Model(nn.Module):
def backwards_debug_hook(grad):
raise RuntimeError("master_params recieved a gradient in the backward pass!")
def prep_param_lists(model, flat_master=False):
"""
Creates a list of FP32 master parameters for a given model, as in
......@@ -131,7 +135,7 @@ def prep_param_lists(model, flat_master=False):
# flatten_dense_tensors returns a contiguous flat array.
# http://pytorch.org/docs/master/_modules/torch/_utils.html
master_params = _flatten_dense_tensors([param.data for param in model_params]).float()
except:
except BaseException:
print("Error in prep_param_lists: model may contain a mixture of parameters "
"of different types. Use flat_master=False, or use F16_Optimizer.")
raise
......@@ -165,9 +169,15 @@ def model_grads_to_master_grads(model_params, master_params, flat_master=False):
if model.grad is not None:
if master.grad is None:
master.grad = Variable(master.data.new(*master.data.size()))
master.grad.data.copy_(model.grad.data)
else:
master.grad = None
model_grads = [p.grad for p in model_params if p.grad is not None]
master_grads = [p.grad for p in master_params if p.grad is not None]
_overflow_buf = torch.cuda.IntTensor([0])
multi_tensor_applier(amp_C.multi_tensor_scale,
_overflow_buf,
[model_grads, master_grads],
1.0)
def master_params_to_model_params(model_params, master_params, flat_master=False):
......@@ -188,17 +198,19 @@ def master_params_to_model_params(model_params, master_params, flat_master=False
# Backward compatibility fixes
def to_python_float(t):
if hasattr(t, 'item'):
return t.item()
else:
return t[0]
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
clip_grad_norm = mpu.clip_grad_norm
#elif TORCH_MAJOR == 0 and TORCH_MINOR <= 4:
# elif TORCH_MAJOR == 0 and TORCH_MINOR <= 4:
# clip_grad_norm = torch.nn.utils.clip_grad_norm
#else:
# else:
# clip_grad_norm = torch.nn.utils.clip_grad_norm_
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -14,15 +14,22 @@
# limitations under the License.
import torch
from apex.multi_tensor_apply import multi_tensor_applier
import amp_C
from megatron import mpu
# item() is a recent addition, so this helps with backward compatibility.
def to_python_float(t):
if hasattr(t, 'item'):
return t.item()
else:
return t[0]
class LossScaler:
"""
Class that manages a static loss scale. This class is intended to interact with
......@@ -54,12 +61,18 @@ class LossScaler:
return self.cur_scale
def scale_gradient(self, module, grad_in, grad_out):
return tuple(self.loss_scale * g for g in grad_in)
_overflow_buf = torch.cuda.IntTensor([0])
multi_tensor_applier(amp_C.multi_tensor_scale,
_overflow_buf,
[grad_in, grad_in],
self.loss_scale)
return grad_in
def backward(self, loss, retain_graph=False):
scaled_loss = loss*self.loss_scale
scaled_loss = loss * self.loss_scale
scaled_loss.backward(retain_graph=retain_graph)
class DynamicLossScaler:
"""
Class that manages dynamic loss scaling. It is recommended to use :class:`DynamicLossScaler`
......@@ -122,8 +135,8 @@ class DynamicLossScaler:
overflow = overflow_gpu[0].item()
return bool(overflow)
# `x` is a torch.Tensor
def _has_inf_or_nan(x):
try:
# if x is half, the .float() incurs an additional deep copy, but it's necessary if
......@@ -158,7 +171,7 @@ class DynamicLossScaler:
if overflow:
# self.cur_scale /= self.scale_factor
if self.delayed_shift == 1 or self.cur_hysteresis == 1:
self.cur_scale = max(self.cur_scale/self.scale_factor, self.min_scale)
self.cur_scale = max(self.cur_scale / self.scale_factor, self.min_scale)
else:
self.cur_hysteresis -= 1
self.last_overflow_iter = self.cur_iter
......@@ -176,12 +189,18 @@ class DynamicLossScaler:
return self.cur_scale
def scale_gradient(self, module, grad_in, grad_out):
return tuple(self.loss_scale * g for g in grad_in)
_overflow_buf = torch.cuda.IntTensor([0])
multi_tensor_applier(amp_C.multi_tensor_scale,
_overflow_buf,
[grad_in, grad_in],
self.loss_scale)
return grad_in
def backward(self, loss, retain_graph=False):
scaled_loss = loss*self.loss_scale
scaled_loss = loss * self.loss_scale
scaled_loss.backward(retain_graph=retain_graph)
##############################################################
# Example usage below here -- assuming it's in a separate file
##############################################################
......
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -61,22 +61,26 @@ def get_timers():
return _GLOBAL_TIMERS
def set_global_variables(extra_args_provider=None, args_defaults={}):
def set_global_variables(extra_args_provider=None, args_defaults={},
ignore_unknown_args=False):
"""Set args, tokenizer, tensorboard-writer, adlr-autoresume, and timers."""
args = _parse_args(extra_args_provider=extra_args_provider,
defaults=args_defaults)
defaults=args_defaults,
ignore_unknown_args=ignore_unknown_args)
_ = _build_tokenizer(args)
_set_tensorboard_writer(args)
_set_adlr_autoresume(args)
_set_timers()
def _parse_args(extra_args_provider=None, defaults={}):
def _parse_args(extra_args_provider=None, defaults={},
ignore_unknown_args=False):
"""Parse entire arguments."""
global _GLOBAL_ARGS
_ensure_var_is_not_initialized(_GLOBAL_ARGS, 'args')
_GLOBAL_ARGS = parse_args(extra_args_provider=extra_args_provider,
defaults=defaults)
defaults=defaults,
ignore_unknown_args=ignore_unknown_args)
return _GLOBAL_ARGS
......@@ -124,7 +128,7 @@ def _set_adlr_autoresume(args):
sys.path.append(os.environ.get('SUBMIT_SCRIPTS', '.'))
try:
from userlib.auto_resume import AutoResume
except:
except BaseException:
print('ADLR autoresume is not available, exiting ...')
sys.exit()
......
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -28,7 +28,8 @@ from megatron import mpu
from megatron.global_vars import set_global_variables
def initialize_megatron(extra_args_provider=None, args_defaults={}):
def initialize_megatron(extra_args_provider=None, args_defaults={},
ignore_unknown_args=False):
"""Set global variables, initialize distributed, and
set autoresume and random seeds."""
# Make sure cuda is available.
......@@ -37,7 +38,8 @@ def initialize_megatron(extra_args_provider=None, args_defaults={}):
# Parse args, build tokenizer, and set adlr-autoresume,
# tensorboard-writer, and timers.
set_global_variables(extra_args_provider=extra_args_provider,
args_defaults=args_defaults)
args_defaults=args_defaults,
ignore_unknown_args=ignore_unknown_args)
# Pytorch distributed.
_initialize_distributed()
......
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -48,7 +48,6 @@ class AnnealingLR(object):
print_rank_0('> learning rate decay style: {}'.format(self.decay_style))
def get_lr(self):
"""Learning rate decay functions from:
https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""
......@@ -71,7 +70,6 @@ class AnnealingLR(object):
lr = self.start_lr
return max(lr, self.min_lr)
def step(self, step_num=None):
"""Set lr for all parameters groups."""
if step_num is None:
......@@ -81,7 +79,6 @@ class AnnealingLR(object):
for group in self.optimizer.param_groups:
group['lr'] = new_lr
def state_dict(self):
state_dict = {
'start_lr': self.start_lr,
......@@ -93,7 +90,6 @@ class AnnealingLR(object):
}
return state_dict
def _check_and_set(self, cls_value, sd_value, name):
"""Auxiliary function for checking the values in the checkpoint and
setting them."""
......@@ -108,7 +104,6 @@ class AnnealingLR(object):
name))
return sd_value
def load_state_dict(self, sd):
self.start_lr = self._check_and_set(self.start_lr, sd['start_lr'],
......
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -22,16 +22,15 @@ import torch
import torch.nn.functional as F
from megatron import get_args
from megatron.model.language_model import parallel_lm_logits
from megatron.model.language_model import get_language_model
from megatron.model.transformer import LayerNorm
from megatron.model.utils import openai_gelu
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal
from megatron.module import MegatronModule
from .language_model import parallel_lm_logits
from .language_model import get_language_model
from .transformer import LayerNorm
from .utils import gelu
from .utils import get_linear_layer
from .utils import init_method_normal
from .utils import scaled_init_method_normal
def bert_attention_mask_func(attention_scores, attention_mask):
attention_scores = attention_scores + attention_mask
......@@ -70,7 +69,6 @@ def bert_position_ids(token_ids):
return position_ids
class BertLMHead(MegatronModule):
"""Masked LM head for Bert
......@@ -81,11 +79,14 @@ class BertLMHead(MegatronModule):
layernorm_epsilon: tolerance for layer norm divisions
parallel_output: whether output logits being distributed or not.
"""
def __init__(self, mpu_vocab_size, hidden_size, init_method,
layernorm_epsilon, parallel_output):
super(BertLMHead, self).__init__()
args = get_args()
self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
self.bias.model_parallel = True
self.bias.partition_dim = 0
......@@ -94,11 +95,13 @@ class BertLMHead(MegatronModule):
self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
self.layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
self.gelu = torch.nn.functional.gelu
if args.openai_gelu:
self.gelu = openai_gelu
def forward(self, hidden_states, word_embeddings_weight):
hidden_states = self.dense(hidden_states)
hidden_states = gelu(hidden_states)
hidden_states = self.gelu(hidden_states)
hidden_states = self.layernorm(hidden_states)
output = parallel_lm_logits(hidden_states,
word_embeddings_weight,
......@@ -107,7 +110,6 @@ class BertLMHead(MegatronModule):
return output
class BertModel(MegatronModule):
"""Bert Language model."""
......@@ -184,7 +186,6 @@ class BertModel(MegatronModule):
return lm_logits, None
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
"""For easy load when model is combined with other heads,
......@@ -206,7 +207,6 @@ class BertModel(MegatronModule):
= self.ict_head.state_dict(destination, prefix, keep_vars)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Customized load."""
......@@ -224,8 +224,6 @@ class BertModel(MegatronModule):
class REALMBertModel(MegatronModule):
# TODO: load BertModel checkpoint
def __init__(self, retriever):
super(REALMBertModel, self).__init__()
bert_args = dict(
......
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -53,7 +53,6 @@ class Classification(MegatronModule):
init_method)
self._classification_head_key = 'classification_head'
def forward(self, input_ids, attention_mask, tokentype_ids):
extended_attention_mask = bert_extended_attention_mask(
......@@ -74,7 +73,6 @@ class Classification(MegatronModule):
return classification_logits
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
"""For easy load when model is combined with other heads,
......@@ -89,7 +87,6 @@ class Classification(MegatronModule):
destination, prefix, keep_vars)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Customized load."""
......
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -31,10 +31,6 @@ class DistributedDataParallel(MegatronModule):
self.module = module
self.data_parallel_group = mpu.get_data_parallel_group()
src_rank = mpu.get_model_parallel_rank()
for p in self.module.parameters():
if torch.is_tensor(p):
dist.broadcast(p, src_rank, group=self.data_parallel_group)
def allreduce_params(reduce_after=True, no_scale=False, fp32_allreduce=False):
if(self.needs_reduction):
......@@ -71,8 +67,8 @@ class DistributedDataParallel(MegatronModule):
def allreduce_hook(*unused):
Variable._execution_engine.queue_callback(allreduce_params)
# handle = param.register_hook(allreduce_hook)
#self.hooks.append(allreduce_hook)
#self.hook_handles.append(handle)
# self.hooks.append(allreduce_hook)
# self.hook_handles.append(handle)
self.allreduce_params = allreduce_params
def forward(self, *inputs, **kwargs):
......@@ -114,4 +110,3 @@ class DistributedDataParallel(MegatronModule):
super(DistributedDataParallel, self).train(mode)
self.module.train(mode)
'''
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -27,8 +27,7 @@ from .utils import scaled_init_method_normal
def gpt2_attention_mask_func(attention_scores, ltor_mask):
attention_scores = torch.mul(attention_scores, ltor_mask) - \
10000.0 * (1.0 - ltor_mask)
attention_scores.masked_fill_(ltor_mask, -10000.0)
return attention_scores
......@@ -49,7 +48,6 @@ class GPT2Model(MegatronModule):
scaled_init_method=scaled_init_method_normal(args.init_method_std,
args.num_layers))
def forward(self, input_ids, position_ids, attention_mask,
tokentype_ids=None, layer_past=None, get_key_value=False,
forward_method_parallel_output=None):
......@@ -79,7 +77,6 @@ class GPT2Model(MegatronModule):
return output
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
......@@ -89,7 +86,6 @@ class GPT2Model(MegatronModule):
destination, prefix, keep_vars)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Customized load."""
......
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -21,9 +21,8 @@ import torch.nn.functional as F
from megatron import get_args
from megatron import mpu
from megatron.module import MegatronModule
from megatron.model.transformer import ParallelTransformer
from megatron.model.utils import gelu
from megatron.model.utils import openai_gelu
from megatron.model.utils import get_linear_layer
......@@ -47,6 +46,12 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
init_method, scaled_init_method, max_pos_embeds=None):
"""Build language model and return along with the key to save."""
args = get_args()
# Use torch gelu unless otherwise forced.
gelu = F.gelu
if args.openai_gelu:
gelu = openai_gelu
# Language model.
language_model = TransformerLanguageModel(
......@@ -63,7 +68,6 @@ def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
return language_model, language_model_key
class Pooler(MegatronModule):
"""Pooler layer.
......@@ -75,11 +79,11 @@ class Pooler(MegatronModule):
init_method: weight initialization method for the linear layer.
bias is set to zero.
"""
def __init__(self, hidden_size, init_method):
super(Pooler, self).__init__()
self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
def forward(self, hidden_states, sequence_index=0):
# hidden_states: [b, s, h]
# sequence_index: index of the token to pool.
......@@ -102,6 +106,7 @@ class Embedding(MegatronModule):
num_tokentypes: size of the token-type embeddings. 0 value
will ignore this embedding
"""
def __init__(self,
hidden_size,
vocab_size,
......@@ -143,7 +148,6 @@ class Embedding(MegatronModule):
# Embeddings dropout
self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
def add_tokentype_embeddings(self, num_tokentypes):
"""Add token-type embedding. This function is provided so we can add
token-type embeddings in case the pretrained model does not have it.
......@@ -160,7 +164,6 @@ class Embedding(MegatronModule):
# Initialize the token-type embeddings.
self.init_method(self.tokentype_embeddings.weight)
def forward(self, input_ids, position_ids, tokentype_ids=None):
# Embeddings.
words_embeddings = self.word_embeddings(input_ids)
......@@ -177,7 +180,6 @@ class Embedding(MegatronModule):
return embeddings
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
"""For easy load."""
......@@ -195,7 +197,6 @@ class Embedding(MegatronModule):
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Customized load."""
......@@ -242,7 +243,6 @@ class Embedding(MegatronModule):
'checkpoint but could not find it', flush=True)
class TransformerLanguageModel(MegatronModule):
"""Transformer language model.
......@@ -261,6 +261,7 @@ class TransformerLanguageModel(MegatronModule):
num_tokentypes: size of the token-type embeddings. 0 value
will ignore this embedding
"""
def __init__(self,
attention_mask_func,
mlp_activation_func,
......@@ -298,7 +299,6 @@ class TransformerLanguageModel(MegatronModule):
self.pooler = Pooler(self.hidden_size, self.init_method)
self._pooler_key = 'pooler'
def forward(self, input_ids, position_ids, attention_mask,
tokentype_ids=None, layer_past=None, get_key_value=False,
pooling_sequence_index=0):
......@@ -320,7 +320,6 @@ class TransformerLanguageModel(MegatronModule):
return transformer_output
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
"""For easy load."""
......@@ -339,7 +338,6 @@ class TransformerLanguageModel(MegatronModule):
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Customized load."""
......
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -51,7 +51,6 @@ class MultipleChoice(MegatronModule):
init_method)
self._multichoice_head_key = 'multichoice_head'
def forward(self, input_ids, attention_mask, tokentype_ids):
# [batch, choices, sequence] --> [batch * choices, sequence] -->
......@@ -86,7 +85,6 @@ class MultipleChoice(MegatronModule):
return multichoice_logits
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
"""For easy load when model is combined with other heads,
......@@ -101,7 +99,6 @@ class MultipleChoice(MegatronModule):
destination, prefix, keep_vars)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Customized load."""
......
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -46,6 +46,7 @@ from megatron.module import MegatronModule
unmaksed-attention-scores, attention-mask)
"""
class ParallelMLP(MegatronModule):
"""MLP.
......@@ -63,7 +64,7 @@ class ParallelMLP(MegatronModule):
# Project to 4h.
self.dense_h_to_4h = mpu.ColumnParallelLinear(
args.hidden_size,
4*args.hidden_size,
4 * args.hidden_size,
gather_output=False,
init_method=init_method)
......@@ -71,14 +72,13 @@ class ParallelMLP(MegatronModule):
# Project back to h.
self.dense_4h_to_h = mpu.RowParallelLinear(
4*args.hidden_size,
4 * args.hidden_size,
args.hidden_size,
input_is_parallel=True,
init_method=output_layer_init_method)
self.dropout = torch.nn.Dropout(args.hidden_dropout)
def forward(self, hidden_states):
# [b, s, 4hp]
......@@ -91,13 +91,13 @@ class ParallelMLP(MegatronModule):
return output
class ParallelSelfAttention(MegatronModule):
"""Parallel self-attention layer abstract class.
Self-attention layer takes input with size [b, s, h]
and returns output of the same size.
"""
def __init__(self, attention_mask_func, init_method,
output_layer_init_method, layer_number):
super(ParallelSelfAttention, self).__init__()
......@@ -123,7 +123,7 @@ class ParallelSelfAttention(MegatronModule):
# Strided linear layer.
self.query_key_value = mpu.ColumnParallelLinear(
args.hidden_size,
3*args.hidden_size,
3 * args.hidden_size,
stride=3,
gather_output=False,
init_method=init_method)
......@@ -141,7 +141,6 @@ class ParallelSelfAttention(MegatronModule):
init_method=output_layer_init_method)
self.output_dropout = torch.nn.Dropout(args.hidden_dropout)
def _transpose_for_scores(self, tensor):
"""Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with
size [b, np, s, hn].
......@@ -152,7 +151,6 @@ class ParallelSelfAttention(MegatronModule):
tensor = tensor.view(*new_tensor_shape)
return tensor.permute(0, 2, 1, 3)
def _get_query_key_value(self, hidden_states):
"""Get query, key, and value and transpose to
get size [b, np, s, hn].
......@@ -170,7 +168,6 @@ class ParallelSelfAttention(MegatronModule):
return query_layer, key_layer, value_layer
def _get_unmasked_attention_scores(self, query_layer, key_layer):
"""Unmasked attention scores with size [b, np, s, s]."""
coeff = 1
......@@ -179,9 +176,8 @@ class ParallelSelfAttention(MegatronModule):
norm_factor = math.sqrt(coeff *
math.sqrt(self.hidden_size_per_attention_head))
# Raw attention scores. [b, np, s, s]
return torch.matmul(query_layer/norm_factor,
key_layer.transpose(-1, -2)/norm_factor)
return torch.matmul(query_layer / norm_factor,
key_layer.transpose(-1, -2) / norm_factor)
def _get_attention_probs(self, attention_scores):
"""Attention probabilies with dropout. The output has
......@@ -198,7 +194,6 @@ class ParallelSelfAttention(MegatronModule):
return attention_probs
def _get_attended_context(self, attention_probs, value_layer):
"""Final attended tesnor and transposed back to [b, s, hp]."""
# Context layer.
......@@ -213,7 +208,6 @@ class ParallelSelfAttention(MegatronModule):
return context_layer
def _get_output(self, context_layer):
"""Output layer with dropout."""
# Output. [b, s, h]
......@@ -222,7 +216,6 @@ class ParallelSelfAttention(MegatronModule):
return output
def forward(self, hidden_states, attention_mask, layer_past=None,
get_key_value=False):
# hidden_states: [b, s, h]
......@@ -254,7 +247,7 @@ class ParallelSelfAttention(MegatronModule):
if layer_past is not None:
attention_mask = attention_mask[
...,
attention_scores.size(3)-1,
attention_scores.size(3) - 1,
:attention_scores.size(3)].unsqueeze(2)
else:
attention_mask = attention_mask[
......@@ -283,13 +276,13 @@ class ParallelSelfAttention(MegatronModule):
return output
class ParallelTransformerLayer(MegatronModule):
"""A single transformer layer.
Transformore layer takes input with size [b, s, h] and returns an
output of the same size.
"""
def __init__(self, attention_mask_func, mlp_activation_func,
init_method, output_layer_init_method, layer_number):
args = get_args()
......@@ -319,7 +312,6 @@ class ParallelTransformerLayer(MegatronModule):
self.mlp = ParallelMLP(mlp_activation_func, init_method,
output_layer_init_method)
def forward(self, hidden_states, attention_mask, layer_past=None,
get_key_value=False):
# hidden_states: [b, s, h]
......@@ -375,14 +367,13 @@ class ParallelTransformer(MegatronModule):
# Transformer layers.
self.layers = torch.nn.ModuleList(
[get_layer(i+1) for i in range(args.num_layers)])
[get_layer(i + 1) for i in range(args.num_layers)])
# Final layer norm before output.
self.final_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon)
def _checkpointed_forward(self, hidden_states, attention_mask):
"""Forward method with activation checkpointing."""
def custom(start, end):
......@@ -398,13 +389,12 @@ class ParallelTransformer(MegatronModule):
num_layers = len(self.layers)
while l < num_layers:
hidden_states = mpu.checkpoint(
custom(l, l+self.checkpoint_num_layers),
custom(l, l + self.checkpoint_num_layers),
hidden_states, attention_mask)
l += self.checkpoint_num_layers
return hidden_states
def forward(self, hidden_states, attention_mask, layer_past=None,
get_key_value=False):
......
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -33,6 +33,7 @@ def init_method_normal(sigma):
def scaled_init_method_normal(sigma, num_layers):
"""Init method based on N(0, sigma/sqrt(2*num_layers)."""
std = sigma / math.sqrt(2.0 * num_layers)
def init_(tensor):
return torch.nn.init.normal_(tensor, mean=0.0, std=std)
......@@ -53,8 +54,7 @@ def gelu_impl(x):
"""OpenAI's gelu implementation."""
return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x *
(1.0 + 0.044715 * x * x)))
def gelu(x):
def openai_gelu(x):
return gelu_impl(x)
......
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -21,11 +21,9 @@ import torch
class MegatronModule(torch.nn.Module):
"""Megatron specific extentions of torch Module."""
def __init__(self):
super(MegatronModule, self).__init__()
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
"""Use this function to override the state dict for
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment