Commit 651cc107 authored by Mostofa Patwary

Merging megatron with ICT

parents e919dd8e c601d751
......@@ -3,17 +3,17 @@ import torch
import sys
from megatron import get_args, print_rank_0
from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
from megatron.module import MegatronModule
from megatron.checkpointing import fix_query_key_value_ordering
from megatron.checkpointing import get_checkpoint_tracker_filename
from megatron.checkpointing import get_checkpoint_name
from megatron import mpu, get_tokenizer
from megatron.model.bert_model import bert_attention_mask_func
from megatron.model.bert_model import bert_extended_attention_mask
from megatron.model.bert_model import bert_position_ids
from megatron.model.enums import AttnMaskType
from megatron.model.language_model import get_language_model
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal
from .module import MegatronModule
def biencoder_model_provider(only_query_model=False,
only_context_model=False,
......@@ -165,16 +165,32 @@ class BiEncoderModel(MegatronModule):
print('global rank {} is loading BERT checkpoint {}'.format(
torch.distributed.get_rank(), checkpoint_name))
# Load the checkpoint.
try:
state_dict = torch.load(checkpoint_name, map_location='cpu')
except ModuleNotFoundError:
from megatron.fp16_deprecated import loss_scaler
# For backward compatibility.
print_rank_0(' > deserializing using the old code structure ...')
sys.modules['fp16.loss_scaler'] = sys.modules[
'megatron.fp16_deprecated.loss_scaler']
sys.modules['megatron.fp16.loss_scaler'] = sys.modules[
'megatron.fp16_deprecated.loss_scaler']
state_dict = torch.load(checkpoint_name, map_location='cpu')
sys.modules.pop('fp16.loss_scaler', None)
sys.modules.pop('megatron.fp16.loss_scaler', None)
except BaseException:
raise ValueError("Could not load BERT checkpoint")
print_rank_0('could not load the BERT checkpoint')
sys.exit()
checkpoint_version = state_dict.get('checkpoint_version', 0)
# load the LM state dict into each model
model_dict = state_dict['model']['language_model']
if self.shared_query_context_model:
self.model.language_model.load_state_dict(model_dict)
fix_query_key_value_ordering(self.model, checkpoint_version)
else:
if self.use_query_model:
self.query_model.language_model.load_state_dict(model_dict)
......@@ -183,11 +199,14 @@ class BiEncoderModel(MegatronModule):
query_proj_state_dict = \
self.state_dict_for_save_checkpoint()\
[self._query_key]['projection_enc']
fix_query_key_value_ordering(self.query_model, checkpoint_version)
if self.use_context_model:
self.context_model.language_model.load_state_dict(model_dict)
if self.query_model is not None and self.projection_dim > 0:
self.context_model.projection_enc.load_state_dict\
(query_proj_state_dict)
fix_query_key_value_ordering(self.context_model, checkpoint_version)
class PretrainedBertModel(MegatronModule):
......@@ -209,9 +228,9 @@ class PretrainedBertModel(MegatronModule):
args.init_method_std, args.num_layers)
self.language_model, self._language_model_key = get_language_model(
attention_mask_func=bert_attention_mask_func,
num_tokentypes=num_tokentypes,
add_pooler=False,
encoder_attn_mask_type=AttnMaskType.padding,
init_method=init_method,
scaled_init_method=scaled_init_method)
......
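Note on the checkpoint-loading fallback above: the `sys.modules` aliasing lets `torch.load` unpickle objects whose module path has since moved. A minimal, self-contained sketch of the same trick (function and module names here are illustrative, not Megatron's exact layout):

# Sketch: load a checkpoint whose pickled objects reference a module path
# that has since moved (e.g. an old 'fp16.loss_scaler' location).
import sys
import torch

def load_with_module_alias(path, old_name, new_module):
    sys.modules[old_name] = new_module   # let unpickling resolve the old path
    try:
        state_dict = torch.load(path, map_location='cpu')
    finally:
        sys.modules.pop(old_name, None)  # always drop the temporary alias
    return state_dict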
......@@ -19,15 +19,16 @@ import torch
from megatron import get_args, print_rank_last
from megatron import mpu
from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids
from megatron.model.enums import AttnMaskType
from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids
from megatron.model.language_model import get_language_model
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal
from megatron.module import PipelinedMegatronModule
from .module import MegatronModule
class ClassificationBase(PipelinedMegatronModule):
class ClassificationBase(MegatronModule):
def __init__(self, num_classes, num_tokentypes=2):
super(ClassificationBase, self).__init__(share_word_embeddings=False)
......@@ -37,9 +38,9 @@ class ClassificationBase(PipelinedMegatronModule):
init_method = init_method_normal(args.init_method_std)
self.language_model, self._language_model_key = get_language_model(
attention_mask_func=bert_attention_mask_func,
num_tokentypes=num_tokentypes,
add_pooler=True,
encoder_attn_mask_type=AttnMaskType.padding,
init_method=init_method,
scaled_init_method=scaled_init_method_normal(args.init_method_std,
args.num_layers))
......
......@@ -20,7 +20,7 @@ from torch.nn.modules import Module
from torch.autograd import Variable
from megatron import mpu
from megatron.module import MegatronModule
from .module import MegatronModule
class DistributedDataParallel(MegatronModule):
......
......@@ -12,19 +12,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .fp16util import (
BN_convert_float,
network_to_half,
prep_param_lists,
model_grads_to_master_grads,
master_params_to_model_params,
tofp16,
to_python_float,
clip_grad_norm,
convert_module,
convert_network,
FP16Model,
)
from .fp16 import *
from .loss_scaler import *
import enum
class LayerType(enum.Enum):
encoder = 1
decoder = 2
class AttnType(enum.Enum):
self_attn = 1
cross_attn = 2
class AttnMaskType(enum.Enum):
padding = 1
causal = 2
File mode changed from 100755 to 100644
......@@ -14,103 +14,127 @@
# limitations under the License.
import torch
from megatron.model.enums import AttnMaskType
class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function) :
class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
"""
Fused operation which performs the following three operations in sequence
1. Scale the tensor.
2. Apply upper triangular mask (typically used in gpt models).
3. Perform softmax.
"""
@staticmethod
def forward(ctx, inputs, scale):
import scaled_upper_triang_masked_softmax_cuda
scale_t = torch.tensor([scale])
softmax_results = \
scaled_upper_triang_masked_softmax_cuda.forward(inputs, scale_t[0])
softmax_results = scaled_upper_triang_masked_softmax_cuda.forward(
inputs, scale_t[0]
)
ctx.save_for_backward(softmax_results, scale_t)
return softmax_results
@staticmethod
def backward(ctx, output_grads):
import scaled_upper_triang_masked_softmax_cuda
softmax_results, scale_t = ctx.saved_tensors
input_grads = \
scaled_upper_triang_masked_softmax_cuda.backward(output_grads,
softmax_results,
scale_t[0])
input_grads = scaled_upper_triang_masked_softmax_cuda.backward(
output_grads, softmax_results, scale_t[0]
)
return input_grads, None
class ScaledMaskedSoftmax(torch.autograd.Function) :
class ScaledMaskedSoftmax(torch.autograd.Function):
"""
Fused operation which performs the following three operations in sequence
1. Scale the tensor.
2. Apply the mask.
3. Perform softmax.
"""
@staticmethod
def forward(ctx, inputs, mask, scale):
import scaled_masked_softmax_cuda
scale_t = torch.tensor([scale])
softmax_results = \
scaled_masked_softmax_cuda.forward(inputs, mask, scale_t[0])
softmax_results = scaled_masked_softmax_cuda.forward(
inputs, mask, scale_t[0]
)
ctx.save_for_backward(softmax_results, scale_t)
return softmax_results
@staticmethod
def backward(ctx, output_grads):
import scaled_masked_softmax_cuda
softmax_results, scale_t = ctx.saved_tensors
input_grads = \
scaled_masked_softmax_cuda.backward(output_grads,
softmax_results,
scale_t[0])
input_grads = scaled_masked_softmax_cuda.backward(
output_grads, softmax_results, scale_t[0]
)
return input_grads, None, None
class FusedScaleMaskSoftmax(torch.nn.Module):
"""
fused operation: scaling + mask + softmax
Arguments:
input_in_fp16: flag to indicate if input is in fp16 data format.
upper_triang_mask: if true, apply upper triangular masking.
(used in gpt family networks)
attn_mask_type: attention mask type (pad or causal)
mask_func: mask function to be applied.
softmax_in_fp32: if true, softmax is performed in fp32 precision.
scale: scaling factor used in input tensor scaling.
"""
def __init__(self, input_in_fp16, upper_triang_mask_fusion,
general_mask_fusion, mask_func, softmax_in_fp32, scale):
def __init__(
self,
input_in_fp16,
attn_mask_type,
scaled_masked_softmax_fusion,
mask_func,
softmax_in_fp32,
scale,
):
super(FusedScaleMaskSoftmax, self).__init__()
self.input_in_fp16 = input_in_fp16
self.upper_triang_mask_fusion = upper_triang_mask_fusion
self.general_mask_fusion = general_mask_fusion
self.attn_mask_type = attn_mask_type
self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
self.mask_func = mask_func
self.softmax_in_fp32 = softmax_in_fp32
self.scale = scale
assert self.scale is None or softmax_in_fp32, \
'softmax should be in fp32 when scaled'
assert (
self.scale is None or softmax_in_fp32
), "softmax should be in fp32 when scaled"
def forward(self, input, mask):
# [b, np, s, s]
# [b, np, sq, sk]
data_size = input.size()
query_seq_len = data_size[-2]
key_seq_len = data_size[-1]
assert input.dim() == 4
# invoke custom kernel
if self.input_in_fp16 and data_size[-1] <= 2048 and \
(self.upper_triang_mask_fusion or self.general_mask_fusion) and \
input.size()[2] == input.size()[3]:
if self.input_in_fp16 and key_seq_len <= 2048 and mask is not None and \
query_seq_len % 4 == 0 and self.scaled_masked_softmax_fusion:
scale = self.scale if self.scale is not None else 1.0
if self.upper_triang_mask_fusion:
input = input.view(-1, data_size[2], data_size[3])
if self.attn_mask_type == AttnMaskType.causal:
assert query_seq_len == key_seq_len, \
"causal mask is only for self attention"
input = input.view(-1, query_seq_len, key_seq_len)
probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
probs = probs.view(*data_size)
else:
assert self.attn_mask_type == AttnMaskType.padding
probs = ScaledMaskedSoftmax.apply(input, mask, scale)
else:
if self.input_in_fp16 and self.softmax_in_fp32:
......@@ -118,7 +142,7 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
if self.scale is not None:
input = input * self.scale
mask_output = self.mask_func(input, mask)
mask_output = self.mask_func(input, mask) if mask is not None else input
probs = torch.nn.Softmax(dim=-1)(mask_output)
if self.input_in_fp16 and self.softmax_in_fp32:
......
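As a reference for what the fused CUDA kernels above compute, here is an unfused sketch of the scale + mask + softmax path, using the additive -10000 masking convention applied by `attention_mask_func` (illustrative only, not the fused implementation):

# Unfused sketch of scale -> mask -> softmax over attention scores.
# scores: [b, np, sq, sk]; mask: bool tensor broadcastable to scores,
# with True marking positions to mask out. Illustrative only.
import torch

def scale_mask_softmax_reference(scores, mask, scale=None):
    if scale is not None:
        scores = scores * scale
    if mask is not None:
        scores = scores.masked_fill(mask, -10000.0)
    return torch.nn.functional.softmax(scores, dim=-1)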
......@@ -19,19 +19,15 @@ import torch
from megatron import get_args
from megatron import mpu
from megatron.module import PipelinedMegatronModule
from .module import MegatronModule
from .enums import AttnMaskType
from .language_model import parallel_lm_logits
from .language_model import get_language_model
from .utils import init_method_normal
from .utils import scaled_init_method_normal
def gpt2_attention_mask_func(attention_scores, ltor_mask):
attention_scores.masked_fill_(ltor_mask, -10000.0)
return attention_scores
def post_language_model_processing(lm_output, labels, logit_weights,
get_key_value, parallel_output,
forward_method_parallel_output,
......@@ -61,37 +57,37 @@ def post_language_model_processing(lm_output, labels, logit_weights,
return loss
class GPT2ModelBase(PipelinedMegatronModule):
class GPTModelBase(MegatronModule):
"""GPT-2 Language model."""
def __init__(self, num_tokentypes=0, parallel_output=True):
super(GPT2ModelBase, self).__init__()
super(GPTModelBase, self).__init__()
args = get_args()
self.parallel_output = parallel_output
self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
self.language_model, self._language_model_key = get_language_model(
attention_mask_func=gpt2_attention_mask_func,
num_tokentypes=num_tokentypes,
add_pooler=False,
encoder_attn_mask_type=AttnMaskType.causal,
init_method=init_method_normal(args.init_method_std),
scaled_init_method=scaled_init_method_normal(args.init_method_std,
args.num_layers))
self.initialize_word_embeddings(init_method_normal)
def forward(self, gpt2_model_input, attention_mask, labels=None,
def forward(self, gpt_model_input, attention_mask, labels=None,
tokentype_ids=None, layer_past=None, get_key_value=False,
forward_method_parallel_output=None):
kwargs = {'layer_past': layer_past, 'get_key_value': get_key_value}
if mpu.is_pipeline_first_stage():
(input_ids, position_ids) = gpt2_model_input
(input_ids, position_ids) = gpt_model_input
args = [input_ids, position_ids, attention_mask]
kwargs['tokentype_ids'] = tokentype_ids
else:
args = [gpt2_model_input, attention_mask]
args = [gpt_model_input, attention_mask]
lm_output = self.language_model(*args, **kwargs)
if mpu.is_pipeline_last_stage():
......@@ -130,17 +126,17 @@ class GPT2ModelBase(PipelinedMegatronModule):
self.language_model.load_state_dict(state_dict, strict=strict)
class GPT2Model(GPT2ModelBase):
class GPTModel(GPTModelBase):
def __init__(self, num_tokentypes=0, parallel_output=True):
super(GPT2Model, self).__init__(
super(GPTModel, self).__init__(
num_tokentypes=num_tokentypes,
parallel_output=parallel_output)
def forward(self, input_ids, position_ids, attention_mask, labels=None,
tokentype_ids=None, layer_past=None, get_key_value=False,
forward_method_parallel_output=None):
return super(GPT2Model, self).forward(
return super(GPTModel, self).forward(
(input_ids, position_ids),
attention_mask,
labels=labels,
......@@ -150,15 +146,15 @@ class GPT2Model(GPT2ModelBase):
forward_method_parallel_output=forward_method_parallel_output)
class GPT2ModelFirstStage(GPT2ModelBase):
class GPTModelFirstStage(GPTModelBase):
def __init__(self, num_tokentypes=0):
super(GPT2ModelFirstStage, self).__init__(
super(GPTModelFirstStage, self).__init__(
num_tokentypes=num_tokentypes)
def forward(self, input_ids, position_ids, attention_mask,
tokentype_ids=None, layer_past=None, get_key_value=False):
return super(GPT2ModelFirstStage, self).forward(
return super(GPTModelFirstStage, self).forward(
(input_ids, position_ids),
attention_mask,
tokentype_ids=tokentype_ids,
......@@ -166,32 +162,32 @@ class GPT2ModelFirstStage(GPT2ModelBase):
get_key_value=get_key_value)
class GPT2ModelIntermediateStage(GPT2ModelBase):
class GPTModelIntermediateStage(GPTModelBase):
def __init__(self, num_tokentypes=0):
super(GPT2ModelIntermediateStage, self).__init__(
super(GPTModelIntermediateStage, self).__init__(
num_tokentypes=num_tokentypes)
def forward(self, hidden_state, attention_mask,
layer_past=None, get_key_value=False):
return super(GPT2ModelIntermediateStage, self).forward(
return super(GPTModelIntermediateStage, self).forward(
hidden_state,
attention_mask,
layer_past=layer_past,
get_key_value=get_key_value)
class GPT2ModelLastStage(GPT2ModelBase):
class GPTModelLastStage(GPTModelBase):
def __init__(self, num_tokentypes=0, parallel_output=True):
super(GPT2ModelLastStage, self).__init__(
super(GPTModelLastStage, self).__init__(
num_tokentypes=num_tokentypes,
parallel_output=parallel_output)
def forward(self, hidden_state, attention_mask, labels=None,
layer_past=None, get_key_value=False,
forward_method_parallel_output=None):
return super(GPT2ModelLastStage, self).forward(
return super(GPTModelLastStage, self).forward(
hidden_state,
attention_mask,
labels=labels,
......
......@@ -20,7 +20,8 @@ import torch.nn.functional as F
from megatron import get_args
from megatron import mpu
from megatron.module import MegatronModule
from .module import MegatronModule
from megatron.model.enums import LayerType, AttnMaskType
from megatron.model.transformer import ParallelTransformer
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal, scaled_init_method_normal
......@@ -42,8 +43,10 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
return mpu.gather_from_tensor_model_parallel_region(logits_parallel)
def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
init_method=None, scaled_init_method=None):
def get_language_model(num_tokentypes, add_pooler,
encoder_attn_mask_type, init_method=None,
scaled_init_method=None, add_decoder=False,
decoder_attn_mask_type=AttnMaskType.causal):
"""Build language model and return along with the key to save."""
args = get_args()
......@@ -51,15 +54,18 @@ def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
init_method = init_method_normal(args.init_method_std)
if scaled_init_method is None:
scaled_init_method = scaled_init_method_normal(args.init_method_std, args.num_layers)
scaled_init_method = scaled_init_method_normal(args.init_method_std,
args.num_layers)
# Language model.
args = [attention_mask_func, init_method, scaled_init_method]
args = [init_method, scaled_init_method, encoder_attn_mask_type]
kwargs = {}
cls = None
if mpu.is_pipeline_first_stage() and mpu.is_pipeline_last_stage():
cls = TransformerLanguageModel
kwargs['num_tokentypes'] = num_tokentypes
kwargs['add_decoder'] = add_decoder
kwargs['decoder_attn_mask_type'] = decoder_attn_mask_type
kwargs['add_pooler'] = add_pooler
elif mpu.is_pipeline_first_stage() and not mpu.is_pipeline_last_stage():
cls = TransformerLanguageModelFirstStage
......@@ -262,12 +268,6 @@ class TransformerLanguageModelBase(MegatronModule):
Arguments:
transformer_hparams: transformer hyperparameters
attention_mask_func: a function that takes `unmaksed-attention-scores`
with size [b, np, s, s] and an `attention-mask` and will apply
the masking. The function should return a masked score of the
same size [b, np, s, s].
masked-attention-scores = attention_mask_func(
unmaksed-attention-scores, attention-mask)
vocab_size: vocabulary size
max_sequence_length: maximum size of sequence. This
is used for positional embedding
......@@ -277,10 +277,12 @@ class TransformerLanguageModelBase(MegatronModule):
"""
def __init__(self,
attention_mask_func,
init_method,
output_layer_init_method,
encoder_attn_mask_type,
num_tokentypes=0,
add_decoder=False,
decoder_attn_mask_type=AttnMaskType.causal,
add_pooler=False):
super(TransformerLanguageModelBase, self).__init__()
args = get_args()
......@@ -288,6 +290,9 @@ class TransformerLanguageModelBase(MegatronModule):
self.hidden_size = args.hidden_size
self.num_tokentypes = num_tokentypes
self.init_method = init_method
self.encoder_attn_mask_type = encoder_attn_mask_type
self.add_decoder = add_decoder
self.decoder_attn_mask_type = decoder_attn_mask_type
self.add_pooler = add_pooler
# Embeddings.
......@@ -301,41 +306,83 @@ class TransformerLanguageModelBase(MegatronModule):
self._embedding_key = 'embedding'
# Transformer.
self.transformer = ParallelTransformer(
attention_mask_func, self.init_method,
output_layer_init_method)
self._transformer_key = 'transformer'
self.encoder = ParallelTransformer(
self.init_method,
output_layer_init_method,
self_attn_mask_type=self.encoder_attn_mask_type)
self._encoder_key = 'encoder'
# Decoder
if self.add_decoder:
assert args.pipeline_model_parallel_size == 1, \
'pipeline parallelism is not supported in the presence of decoder'
self.decoder = ParallelTransformer(
self.init_method,
output_layer_init_method,
layer_type=LayerType.decoder,
self_attn_mask_type=self.decoder_attn_mask_type)
self._decoder_key = 'decoder'
if mpu.is_pipeline_last_stage():
# Pooler.
if mpu.is_pipeline_last_stage() and self.add_pooler:
if self.add_pooler:
self.pooler = Pooler(self.hidden_size, self.init_method)
self._pooler_key = 'pooler'
def forward(self, language_model_input, attention_mask,
tokentype_ids=None, layer_past=None, get_key_value=False,
pooling_sequence_index=0):
def forward(self, enc_language_model_input, enc_attn_mask,
dec_language_model_input=None, dec_attn_mask=None,
enc_dec_attn_mask=None, tokentype_ids=None, layer_past=None,
get_key_value=False, pooling_sequence_index=0,
enc_hidden_states=None, output_enc_hidden=False):
# Embeddings.
if mpu.is_pipeline_first_stage():
(input_ids, position_ids) = language_model_input
(input_ids, position_ids) = enc_language_model_input
embedding_output = self.embedding(input_ids, position_ids,
tokentype_ids=tokentype_ids)
transformer_input = embedding_output
encoder_input = embedding_output
else:
transformer_input = language_model_input
encoder_input = enc_language_model_input
# Transformer.
transformer_output = self.transformer(transformer_input,
attention_mask,
# encoder.
if enc_hidden_states is None:
encoder_output = self.encoder(encoder_input,
enc_attn_mask,
layer_past=layer_past,
get_key_value=get_key_value)
else:
encoder_output = enc_hidden_states.to(encoder_input.dtype)
if mpu.is_pipeline_last_stage() and self.add_pooler:
pooled_output = self.pooler(transformer_output,
if mpu.is_pipeline_last_stage():
if self.add_pooler:
pooled_output = self.pooler(encoder_output,
pooling_sequence_index)
return transformer_output, pooled_output
return transformer_output
# output_enc_hidden is used when only the encoder's output is
# needed, e.g. to compute similarity between two sequences by
# average pooling (see the pooling sketch after this method)
if not self.add_decoder or output_enc_hidden:
if self.add_pooler and mpu.is_pipeline_last_stage():
return encoder_output, pooled_output
else:
return encoder_output
# Decoder Embedding
(dec_input_ids, dec_position_ids) = dec_language_model_input
dec_embedding_output = self.embedding(dec_input_ids,
dec_position_ids)
# decoder
decoder_output = self.decoder(dec_embedding_output,
dec_attn_mask,
layer_past=layer_past,
get_key_value=get_key_value,
encoder_output=encoder_output,
enc_dec_attn_mask=enc_dec_attn_mask)
if self.add_pooler and mpu.is_pipeline_last_stage():
return decoder_output, encoder_output, pooled_output
else:
return decoder_output, encoder_output
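A hedged sketch of the average pooling mentioned in the `output_enc_hidden` comment above, assuming the encoder output is laid out as [b, s, h] and a padding mask is available (both are assumptions for illustration, not part of this diff):

# Masked average pooling over the encoder output.
# encoder_output: [b, s, h]; pad_mask: [b, s] with 1.0 for real tokens,
# 0.0 for padding. Layout assumption is illustrative.
import torch

def average_pool(encoder_output, pad_mask):
    pad_mask = pad_mask.unsqueeze(-1).to(encoder_output.dtype)  # [b, s, 1]
    summed = (encoder_output * pad_mask).sum(dim=1)             # [b, h]
    counts = pad_mask.sum(dim=1).clamp(min=1.0)                 # [b, 1]
    return summed / counts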
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
......@@ -346,13 +393,18 @@ class TransformerLanguageModelBase(MegatronModule):
state_dict_[self._embedding_key] \
= self.embedding.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
state_dict_[self._transformer_key] \
= self.transformer.state_dict_for_save_checkpoint(
state_dict_[self._encoder_key] \
= self.encoder.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
if mpu.is_pipeline_last_stage() and self.add_pooler:
if mpu.is_pipeline_last_stage():
if self.add_pooler:
state_dict_[self._pooler_key] \
= self.pooler.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
if self.add_decoder:
state_dict_[self._decoder_key] \
= self.decoder.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
return state_dict_
......@@ -371,36 +423,44 @@ class TransformerLanguageModelBase(MegatronModule):
state_dict_[key] = state_dict[key]
self.embedding.load_state_dict(state_dict_, strict=strict)
# Transformer.
if self._transformer_key in state_dict:
state_dict_ = state_dict[self._transformer_key]
# for compatiability with t5 architecture
# this is temporary unless t5_main is merged
elif 'encoder' in state_dict:
state_dict_ = state_dict['encoder']
# for forward compatibility for t5 architecture
state_dict_attention = {}
for key in state_dict_.keys():
if '.self_attention.' in key:
state_dict_attention[key.replace(".self_attention.",
".attention.")] = state_dict_[key]
else:
state_dict_attention[key] = state_dict_[key]
state_dict_ = state_dict_attention
# Encoder.
if self._encoder_key in state_dict:
state_dict_ = state_dict[self._encoder_key]
# for backward compatibility.
elif 'transformer' in state_dict:
state_dict_ = state_dict['transformer']
else:
# for backward compatibility.
state_dict_ = {}
for key in state_dict.keys():
if 'transformer.' in key:
state_dict_[key.split('transformer.')[1]] = state_dict[key]
self.transformer.load_state_dict(state_dict_, strict=strict)
# Pooler.
if mpu.is_pipeline_last_stage() and self.add_pooler:
# for backward compatibility.
state_dict_self_attention = {}
for key in state_dict_.keys():
if '.attention.' in key:
state_dict_self_attention[key.replace(".attention.",
".self_attention.")] = state_dict_[key]
else:
state_dict_self_attention[key] = state_dict_[key]
state_dict_ = state_dict_self_attention
self.encoder.load_state_dict(state_dict_, strict=strict)
if mpu.is_pipeline_last_stage():
# pooler
if self.add_pooler:
assert 'pooler' in state_dict, \
'could not find data for pooler in the checkpoint'
self.pooler.load_state_dict(state_dict[self._pooler_key],
strict=strict)
# decoder
if self.add_decoder:
assert 'decoder' in state_dict, \
'could not find data for decoder in the checkpoint'
self.decoder.load_state_dict(state_dict[self._decoder_key],
strict=strict)
class TransformerLanguageModel(TransformerLanguageModelBase):
......@@ -409,28 +469,39 @@ class TransformerLanguageModel(TransformerLanguageModelBase):
"""
def __init__(self,
attention_mask_func,
init_method,
output_layer_init_method,
encoder_attn_mask_type,
num_tokentypes=0,
decoder_attn_mask_type=AttnMaskType.causal,
add_decoder=False,
add_pooler=False):
super(TransformerLanguageModel, self).__init__(
attention_mask_func,
init_method,
output_layer_init_method,
encoder_attn_mask_type,
num_tokentypes=num_tokentypes,
add_decoder=add_decoder,
decoder_attn_mask_type=decoder_attn_mask_type,
add_pooler=add_pooler)
def forward(self, input_ids, position_ids, attention_mask,
tokentype_ids=None, layer_past=None, get_key_value=False,
pooling_sequence_index=0):
def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
enc_dec_attn_mask=None, tokentype_ids=None, layer_past=None,
get_key_value=False, pooling_sequence_index=0,
enc_hidden_states=None, output_enc_hidden=False):
return super(TransformerLanguageModel, self).forward(
(input_ids, position_ids),
attention_mask,
(enc_input_ids, enc_position_ids),
enc_attn_mask,
dec_language_model_input=(dec_input_ids, dec_position_ids),
dec_attn_mask=dec_attn_mask,
enc_dec_attn_mask=enc_dec_attn_mask,
tokentype_ids=tokentype_ids,
layer_past=layer_past,
get_key_value=get_key_value,
pooling_sequence_index=pooling_sequence_index
pooling_sequence_index=pooling_sequence_index,
enc_hidden_states=enc_hidden_states,
output_enc_hidden=output_enc_hidden
)
......@@ -440,14 +511,14 @@ class TransformerLanguageModelFirstStage(TransformerLanguageModelBase):
"""
def __init__(self,
attention_mask_func,
init_method,
output_layer_init_method,
encoder_attn_mask_type,
num_tokentypes=0):
super(TransformerLanguageModelFirstStage, self).__init__(
attention_mask_func,
init_method,
output_layer_init_method,
encoder_attn_mask_type,
num_tokentypes=num_tokentypes)
def forward(self, input_ids, position_ids, attention_mask,
......@@ -467,13 +538,13 @@ class TransformerLanguageModelIntermediateStage(TransformerLanguageModelBase):
"""
def __init__(self,
attention_mask_func,
init_method,
output_layer_init_method):
output_layer_init_method,
encoder_attn_mask_type):
super(TransformerLanguageModelIntermediateStage, self).__init__(
attention_mask_func,
init_method,
output_layer_init_method)
output_layer_init_method,
encoder_attn_mask_type)
def forward(self, hidden_states, attention_mask,
layer_past=None, get_key_value=False):
......@@ -491,14 +562,14 @@ class TransformerLanguageModelLastStage(TransformerLanguageModelBase):
"""
def __init__(self,
attention_mask_func,
init_method,
output_layer_init_method,
encoder_attn_mask_type,
add_pooler=False):
super(TransformerLanguageModelLastStage, self).__init__(
attention_mask_func,
init_method,
output_layer_init_method,
encoder_attn_mask_type,
add_pooler=add_pooler)
def forward(self, hidden_states, attention_mask,
......@@ -509,5 +580,5 @@ class TransformerLanguageModelLastStage(TransformerLanguageModelBase):
attention_mask,
layer_past=layer_past,
get_key_value=get_key_value,
pooling_sequence_index=pooling_sequence_index
pooling_sequence_index=pooling_sequence_index,
)
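To illustrate the new `get_language_model` signature, a minimal sketch of building an encoder-decoder model with it (Megatron's global args and model-parallel state are assumed to be initialized; note that `add_decoder=True` currently requires pipeline_model_parallel_size == 1):

# Sketch: building an encoder-decoder language model through the new
# get_language_model() signature. Keyword names are taken from the
# definition above; the wrapper function itself is hypothetical.
from megatron import get_args
from megatron.model.enums import AttnMaskType
from megatron.model.language_model import get_language_model
from megatron.model.utils import init_method_normal, scaled_init_method_normal

def build_encoder_decoder_lm():
    args = get_args()
    lm, lm_key = get_language_model(
        num_tokentypes=0,
        add_pooler=False,
        encoder_attn_mask_type=AttnMaskType.padding,
        init_method=init_method_normal(args.init_method_std),
        scaled_init_method=scaled_init_method_normal(args.init_method_std,
                                                     args.num_layers),
        add_decoder=True,
        decoder_attn_mask_type=AttnMaskType.causal)
    return lm, lm_key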
......@@ -16,16 +16,31 @@
"""Megatron Module"""
import torch
from torch.autograd import Variable
from torch.nn.parameter import Parameter
from megatron import get_args
from megatron import mpu
_FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
_HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)
def param_is_not_shared(param):
return not hasattr(param, 'shared') or not param.shared
class MegatronModule(torch.nn.Module):
"""Megatron specific extensions of torch Module."""
"""Megatron specific extensions of torch Module with support
for pipelining."""
def __init__(self):
def __init__(self, share_word_embeddings=True):
super(MegatronModule, self).__init__()
self.share_word_embeddings = share_word_embeddings
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
......@@ -34,52 +49,127 @@ class MegatronModule(torch.nn.Module):
return self.state_dict(destination, prefix, keep_vars)
class PipelinedMegatronModule(MegatronModule):
"""Pipelining specific extensions of MegatronModule."""
def __init__(self, share_word_embeddings=True):
super(PipelinedMegatronModule, self).__init__()
args = get_args()
self.share_word_embeddings = share_word_embeddings
def word_embeddings_weight(self):
if mpu.is_pipeline_first_stage():
return self.language_model.embedding.word_embeddings.weight
if mpu.is_pipeline_last_stage():
if not self.share_word_embeddings:
raise Exception('word_embeddings_weight() called for last stage, '
'but share_word_embeddings is false')
raise Exception('word_embeddings_weight() called for last '
'stage, but share_word_embeddings is false')
return self.word_embeddings.weight
raise Exception('word_embeddings_weight() should be '
'called for first and last stage only')
def initialize_word_embeddings(self, init_method_normal):
args = get_args()
if not self.share_word_embeddings:
raise Exception('initialize_word_embeddings() was called but '
'share_word_embeddings is false')
# Parameters are shared between the word embeddings layer, and the heads at
# the end of the model. In a pipelined setup with more than one stage, the
# initial embedding layer and the head are on different workers, so we do
# the following:
# 1. Create a second copy of word_embeddings on the last stage, with initial
# parameters of 0.0.
# 2. Do an all-reduce between the first and last stage to ensure that the
# two copies of word_embeddings start off with the same parameter values.
# 3. In the training loop, before an all-reduce between the grads of the two
# word_embeddings layers to ensure that every applied weight update is the
# same on both stages.
# This function just initializes the word embeddings in the final stage
# when we are using pipeline parallelism. If we aren't using pipeline
# parallelism there is nothing to do.
if args.pipeline_model_parallel_size == 1:
return
# Parameters are shared between the word embeddings layer, and the
# heads at the end of the model. In a pipelined setup with more than
# one stage, the initial embedding layer and the head are on different
# workers, so we do the following:
# 1. Create a second copy of word_embeddings on the last stage, with
# initial parameters of 0.0.
# 2. Do an all-reduce between the first and last stage to ensure that
# the two copies of word_embeddings start off with the same
# parameter values.
# 3. In the training loop, perform an all-reduce between the grads of
# the two word_embeddings layers to ensure that every applied weight
# update is the same on both stages (see the sketch after this method).
if mpu.is_pipeline_last_stage():
if not mpu.is_pipeline_first_stage():
assert not mpu.is_pipeline_first_stage()
self._word_embeddings_for_head_key = 'word_embeddings_for_head'
# If first and last stages are different, set word_embeddings
# weights to 0 here, then copy first stage's weights using all_reduce
# below.
# set word_embeddings weights to 0 here, then copy first
# stage's weights using all_reduce below.
self.word_embeddings = mpu.VocabParallelEmbedding(
args.padded_vocab_size, args.hidden_size,
init_method=init_method_normal(args.init_method_std))
self.word_embeddings.weight.data.fill_(0)
# Ensure that first and last stages have the same initial parameter values.
self.word_embeddings.weight.shared = True
# Ensure that first and last stages have the same initial parameter
# values.
if torch.distributed.is_initialized():
if mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage():
torch.distributed.all_reduce(self.word_embeddings_weight().data,
group=mpu.get_embedding_group())
else:
print("WARNING! Distributed processes aren't initialized, so "
"word embeddings in the last layer are not initialized. "
"If you are just manipulating a model this is fine, but "
"this needs to be handled manually. If you are training "
"something is definitely wrong.")
def conversion_helper(val, conversion):
"""Apply conversion to val. Recursively apply conversion if `val`
is a nested tuple/list structure."""
if not isinstance(val, (tuple, list)):
return conversion(val)
rtn = [conversion_helper(v, conversion) for v in val]
if isinstance(val, tuple):
rtn = tuple(rtn)
return rtn
def fp32_to_fp16(val):
"""Convert fp32 `val` to fp16"""
def half_conversion(val):
val_typecheck = val
if isinstance(val_typecheck, (Parameter, Variable)):
val_typecheck = val.data
if isinstance(val_typecheck, _FLOAT_TYPES):
val = val.half()
return val
return conversion_helper(val, half_conversion)
def fp16_to_fp32(val):
"""Convert fp16 `val` to fp32"""
def float_conversion(val):
val_typecheck = val
if isinstance(val_typecheck, (Parameter, Variable)):
val_typecheck = val.data
if isinstance(val_typecheck, _HALF_TYPES):
val = val.float()
return val
return conversion_helper(val, float_conversion)
class FP16Module(MegatronModule):
def __init__(self, module):
super(FP16Module, self).__init__()
self.add_module('module', module.half())
def forward(self, *inputs, **kwargs):
if mpu.is_pipeline_first_stage():
inputs = fp32_to_fp16(inputs)
outputs = self.module(*inputs, **kwargs)
if mpu.is_pipeline_last_stage():
outputs = fp16_to_fp32(outputs)
return outputs
def state_dict(self, destination=None, prefix='', keep_vars=False):
return self.module.state_dict(destination, prefix, keep_vars)
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
return self.module.state_dict_for_save_checkpoint(destination, prefix,
keep_vars)
def load_state_dict(self, state_dict, strict=True):
self.module.load_state_dict(state_dict, strict=strict)
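A minimal usage sketch for `FP16Module` defined above (the `build_model` callable is hypothetical): inputs are cast to half on the first pipeline stage and outputs back to float on the last stage, mirroring `FP16Module.forward`.

# Usage sketch: wrap a MegatronModule for fp16 execution.
def wrap_for_fp16(build_model):
    args = get_args()
    model = build_model()
    if args.fp16:
        model = FP16Module(model)
    return model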
......@@ -19,15 +19,16 @@ import torch
from megatron import get_args, print_rank_last
from megatron import mpu
from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids
from megatron.model.enums import AttnMaskType
from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids
from megatron.model.language_model import get_language_model
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal
from megatron.module import PipelinedMegatronModule
from .module import MegatronModule
class MultipleChoiceBase(PipelinedMegatronModule):
class MultipleChoiceBase(MegatronModule):
def __init__(self, num_tokentypes=2):
super(MultipleChoiceBase, self).__init__(share_word_embeddings=False)
......@@ -36,9 +37,9 @@ class MultipleChoiceBase(PipelinedMegatronModule):
init_method = init_method_normal(args.init_method_std)
self.language_model, self._language_model_key = get_language_model(
attention_mask_func=bert_attention_mask_func,
num_tokentypes=num_tokentypes,
add_pooler=True,
encoder_attn_mask_type=AttnMaskType.padding,
init_method=init_method,
scaled_init_method=scaled_init_method_normal(args.init_method_std,
args.num_layers))
......
......@@ -4,13 +4,14 @@ import torch
from megatron import get_args, print_rank_0
from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
from megatron.model import BertModel
from megatron.module import MegatronModule
from .module import MegatronModule
from megatron import mpu
from megatron.model.enums import AttnMaskType
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.language_model import get_language_model
from megatron.model.utils import scaled_init_method_normal
from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids
from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids
def general_ict_model_provider(only_query_model=False, only_block_model=False):
......@@ -156,9 +157,9 @@ class IREncoderBertModel(MegatronModule):
args.num_layers)
self.language_model, self._language_model_key = get_language_model(
attention_mask_func=bert_attention_mask_func,
num_tokentypes=num_tokentypes,
add_pooler=True,
encoder_attn_mask_type=AttnMaskType.padding,
init_method=init_method,
scaled_init_method=scaled_init_method)
......
......@@ -14,19 +14,18 @@
# limitations under the License.
"""Transformer."""
import math
import torch
import torch.nn.functional as F
from megatron import get_args
from megatron import mpu
from megatron.module import MegatronModule
from megatron.checkpointing import get_checkpoint_version
from .module import MegatronModule
from megatron.model.enums import AttnMaskType, LayerType, AttnType
from megatron.model import import_layernorm
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
from megatron.model.utils import openai_gelu, erf_gelu
from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
# flags required to enable jit fusion kernels
torch._C._jit_set_profiling_mode(False)
......@@ -47,12 +46,6 @@ torch._C._jit_override_can_fuse_on_gpu(True)
Transformer takes input of size [s, b, h] and returns a
tensor of the same size. We use the following arguments:
hyperparameters: transformer hyperparameters
attention_mask_func: a function that takes `unmaksed-attention-scores`
with size [b, np, s, s] and an `attention-mask` and will apply
the masking. The function should return a masked score of the
same size [b, np, s, s].
masked-attention-scores = attention_mask_func(
unmaksed-attention-scores, attention-mask)
"""
class ParallelMLP(MegatronModule):
......@@ -71,7 +64,7 @@ class ParallelMLP(MegatronModule):
# Project to 4h.
self.dense_h_to_4h = mpu.ColumnParallelLinear(
args.hidden_size,
4 * args.hidden_size,
args.ffn_hidden_size,
gather_output=False,
init_method=init_method,
skip_bias_add=True)
......@@ -85,7 +78,7 @@ class ParallelMLP(MegatronModule):
# Project back to h.
self.dense_4h_to_h = mpu.RowParallelLinear(
4 * args.hidden_size,
args.ffn_hidden_size,
args.hidden_size,
input_is_parallel=True,
init_method=output_layer_init_method,
......@@ -109,39 +102,58 @@ class ParallelMLP(MegatronModule):
return output, output_bias
class ParallelSelfAttention(MegatronModule):
class ParallelAttention(MegatronModule):
"""Parallel self-attention layer abstract class.
Self-attention layer takes input with size [b, s, h]
and returns output of the same size.
"""
def __init__(self, attention_mask_func, init_method,
output_layer_init_method, layer_number):
super(ParallelSelfAttention, self).__init__()
def __init__(self, init_method,
output_layer_init_method, layer_number,
attention_type=AttnType.self_attn,
attn_mask_type=AttnMaskType.padding):
super(ParallelAttention, self).__init__()
args = get_args()
self.fp16 = args.fp16
self.attention_mask_func = attention_mask_func
self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
if self.apply_query_key_layer_scaling:
self.attention_softmax_in_fp32 = True
self.layer_number = max(1, layer_number)
self.attention_type = attention_type
self.attn_mask_type = attn_mask_type
projection_size = args.kv_channels * args.num_attention_heads
# Per attention head and per partition values.
world_size = mpu.get_tensor_model_parallel_world_size()
self.hidden_size_per_partition = mpu.divide(args.hidden_size,
self.hidden_size_per_partition = mpu.divide(projection_size,
world_size)
self.hidden_size_per_attention_head = mpu.divide(
args.hidden_size, args.num_attention_heads)
projection_size, args.num_attention_heads)
self.num_attention_heads_per_partition = mpu.divide(
args.num_attention_heads, world_size)
# Strided linear layer.
if attention_type == AttnType.self_attn:
self.query_key_value = mpu.ColumnParallelLinear(
args.hidden_size,
3 * args.hidden_size,
3 * projection_size,
gather_output=False,
init_method=init_method)
else:
assert attention_type == AttnType.cross_attn
self.query = mpu.ColumnParallelLinear(
args.hidden_size,
projection_size,
gather_output=False,
init_method=init_method)
self.key_value = mpu.ColumnParallelLinear(
args.hidden_size,
2 * projection_size,
gather_output=False,
init_method=init_method)
......@@ -153,9 +165,9 @@ class ParallelSelfAttention(MegatronModule):
self.scale_mask_softmax = FusedScaleMaskSoftmax(
self.fp16,
args.scaled_upper_triang_masked_softmax_fusion,
args.scaled_masked_softmax_fusion,
self.attention_mask_func,
self.attn_mask_type,
args.masked_softmax_fusion,
attention_mask_func,
self.attention_softmax_in_fp32,
coeff)
......@@ -166,65 +178,24 @@ class ParallelSelfAttention(MegatronModule):
# Output.
self.dense = mpu.RowParallelLinear(
args.hidden_size,
projection_size,
args.hidden_size,
input_is_parallel=True,
init_method=output_layer_init_method,
skip_bias_add=True)
def _transpose_last_dim(self, mixed_layer, num_splits, num_splits_first):
input_shape = mixed_layer.size();
if num_splits_first:
"""[s, b, num_splits * np * hn]
-->(view) [s, b, num_splits, np, hn]
-->(tranpose) [s, b, np, num_splits, hn]
-->(view) [s, b, np * num_splits * hn] """
intermediate_shape = input_shape[:-1] +\
(num_splits, self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head)
mixed_layer = mixed_layer.view(*intermediate_shape)
mixed_layer = mixed_layer.transpose(-2, -3).contiguous()
else:
"""[s, b, np * hn * num_splits]
-->(view) [s, b, np, hn, num_splits]
-->(tranpose) [s, b, np, num_splits, hn]
-->(view) [s, b, np * num_splits * hn] """
intermediate_shape = input_shape[:-1] +\
(self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head, num_splits)
mixed_layer = mixed_layer.view(*intermediate_shape)
mixed_layer = mixed_layer.transpose(-1, -2).contiguous()
mixed_layer = mixed_layer.view(*input_shape)
return mixed_layer
def forward(self, hidden_states, attention_mask, layer_past=None,
get_key_value=False):
get_key_value=False, encoder_output=None):
# hidden_states: [sq, b, h]
# =====================
# Query, Key, and Value
# =====================
if self.attention_type == AttnType.self_attn:
# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
mixed_x_layer, _ = self.query_key_value(hidden_states)
checkpoint_version = get_checkpoint_version()
if get_args().override_checkpoint_version is not None:
checkpoint_version = get_args().override_checkpoint_version
if checkpoint_version is not None:
if checkpoint_version == 0:
# [s, b, (3 * np * hn)] --> [s, b, (np * 3 * hn)]
mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, True)
elif checkpoint_version == 1.0:
# [s, b, (np * hn * 3)] --> [s, b, (np * 3 * hn)]
mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, False)
# [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
new_tensor_shape = mixed_x_layer.size()[:-1] + \
(self.num_attention_heads_per_partition,
......@@ -235,6 +206,27 @@ class ParallelSelfAttention(MegatronModule):
(query_layer,
key_layer,
value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)
else:
# Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
mixed_kv_layer, _ = self.key_value(encoder_output)
# [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
new_tensor_shape = mixed_kv_layer.size()[:-1] + \
(self.num_attention_heads_per_partition,
2 * self.hidden_size_per_attention_head)
mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)
# [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
(key_layer,
value_layer) = mpu.split_tensor_along_last_dim(mixed_kv_layer, 2)
# Attention head [sq, b, h] --> [sq, b, hp]
query_layer, _ = self.query(hidden_states)
# [sq, b, hp] --> [sq, b, np, hn]
new_tensor_shape = query_layer.size()[:-1] + \
(self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head)
query_layer = query_layer.view(*new_tensor_shape)
# ==================================
# Adjust key and value for inference
......@@ -249,7 +241,6 @@ class ParallelSelfAttention(MegatronModule):
if get_key_value:
present = (key_layer, value_layer)
# ===================================
# Raw attention scores. [b, np, s, s]
# ===================================
......@@ -263,6 +254,7 @@ class ParallelSelfAttention(MegatronModule):
# [sq, b, np, hn] -> [sq, b * np, hn]
query_layer = query_layer.view(output_size[2],
output_size[0] * output_size[1], -1)
# [sk, b, np, hn] -> [sk, b * np, hn]
key_layer = key_layer.view(output_size[3],
output_size[0] * output_size[1], -1)
......@@ -275,15 +267,15 @@ class ParallelSelfAttention(MegatronModule):
device=torch.cuda.current_device())
# Raw attention scores. [b * np, sq, sk]
matmul_result = torch.baddbmm(matmul_result,
matmul_result = torch.baddbmm(
matmul_result,
query_layer.transpose(0, 1), # [b * np, sq, hn]
key_layer.transpose(0,1).transpose(1, 2), #[b * np, hn, sk]
key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
beta=0.0, alpha=(1.0/self.norm_factor))
# change view to [b, np, sq, sk]
attention_scores = matmul_result.view(*output_size)
# ==================================================
# Update attention mask for inference. [b, np, sq, sk]
# ==================================================
......@@ -301,7 +293,6 @@ class ParallelSelfAttention(MegatronModule):
:attention_scores.size(3),
:attention_scores.size(3)]
# ===========================
# Attention probs and dropout
# ===========================
......@@ -315,7 +306,6 @@ class ParallelSelfAttention(MegatronModule):
with mpu.get_cuda_rng_tracker().fork():
attention_probs = self.attention_dropout(attention_probs)
# =========================
# Context layer. [sq, b, hp]
# =========================
......@@ -338,7 +328,7 @@ class ParallelSelfAttention(MegatronModule):
output_size[2], -1)
# matmul: [b * np, sq, hn]
context_layer = torch.bmm(attention_probs, value_layer.transpose(0,1))
context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
# change view [b, np, sq, hn]
context_layer = context_layer.view(*output_size)
......@@ -351,7 +341,6 @@ class ParallelSelfAttention(MegatronModule):
(self.hidden_size_per_partition,)
context_layer = context_layer.view(*new_context_layer_shape)
# =================
# Output. [sq, b, h]
# =================
......@@ -364,7 +353,7 @@ class ParallelSelfAttention(MegatronModule):
return output, bias
def bias_dropout_add(x, bias, residual, prob, training) :
def bias_dropout_add(x, bias, residual, prob, training):
# type: (Tensor, Tensor, Tensor, float, bool) -> Tensor
out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
out = residual + out
......@@ -378,13 +367,13 @@ def get_bias_dropout_add(training):
@torch.jit.script
def bias_dropout_add_fused_train(x, bias, residual, prob) :
def bias_dropout_add_fused_train(x, bias, residual, prob):
# type: (Tensor, Tensor, Tensor, float) -> Tensor
return bias_dropout_add(x, bias, residual, prob, True)
@torch.jit.script
def bias_dropout_add_fused_inference(x, bias, residual, prob) :
def bias_dropout_add_fused_inference(x, bias, residual, prob):
# type: (Tensor, Tensor, Tensor, float) -> Tensor
return bias_dropout_add(x, bias, residual, prob, False)
......@@ -392,16 +381,18 @@ def bias_dropout_add_fused_inference(x, bias, residual, prob) :
class ParallelTransformerLayer(MegatronModule):
"""A single transformer layer.
Transformore layer takes input with size [b, s, h] and returns an
Transformer layer takes input with size [b, s, h] and returns an
output of the same size.
"""
def __init__(self, attention_mask_func, init_method,
output_layer_init_method, layer_number):
def __init__(self, init_method, output_layer_init_method,
layer_number, layer_type=LayerType.encoder,
self_attn_mask_type=AttnMaskType.padding):
args = get_args()
super(ParallelTransformerLayer, self).__init__()
self.layer_number = layer_number
self.layer_type = layer_type
self.apply_residual_connection_post_layernorm \
= args.apply_residual_connection_post_layernorm
......@@ -413,30 +404,45 @@ class ParallelTransformerLayer(MegatronModule):
eps=args.layernorm_epsilon)
# Self attention.
self.attention = ParallelSelfAttention(attention_mask_func, init_method,
self.self_attention = ParallelAttention(
init_method,
output_layer_init_method,
layer_number)
layer_number,
attention_type=AttnType.self_attn,
attn_mask_type=self_attn_mask_type)
self.hidden_dropout = args.hidden_dropout
self.bias_dropout_fusion = args.bias_dropout_fusion
# Layernorm on the input data.
# Layernorm on the attention output
self.post_attention_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon)
if self.layer_type == LayerType.decoder:
self.inter_attention = ParallelAttention(
init_method,
output_layer_init_method,
layer_number,
attention_type=AttnType.cross_attn)
# Layernorm on the attention output.
self.post_inter_attention_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon)
# MLP
self.mlp = ParallelMLP(init_method,
output_layer_init_method)
def forward(self, hidden_states, attention_mask, layer_past=None,
get_key_value=False):
def forward(self, hidden_states, attention_mask,
encoder_output=None, enc_dec_attn_mask=None,
layer_past=None, get_key_value=False):
# hidden_states: [b, s, h]
# Layer norm at the begining of the transformer layer.
# Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states)
# Self attention.
attention_output, attention_bias = \
self.attention(layernorm_output,
self.self_attention(layernorm_output,
attention_mask,
layer_past=layer_past,
get_key_value=get_key_value)
......@@ -462,7 +468,7 @@ class ParallelTransformerLayer(MegatronModule):
else:
bias_dropout_add_func = get_bias_dropout_add(self.training)
#re-enable torch grad to enable fused optimization.
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
layernorm_input = bias_dropout_add_func(
attention_output,
......@@ -473,6 +479,28 @@ class ParallelTransformerLayer(MegatronModule):
# Layer norm post the self attention.
layernorm_output = self.post_attention_layernorm(layernorm_input)
if self.layer_type == LayerType.decoder:
attention_output, attention_bias = \
self.inter_attention(layernorm_output,
enc_dec_attn_mask,
encoder_output=encoder_output)
# residual connection
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = layernorm_input
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
layernorm_input = bias_dropout_add_func(
attention_output,
attention_bias.expand_as(residual),
residual,
self.hidden_dropout)
# Layer norm post the decoder attention
layernorm_output = self.post_inter_attention_layernorm(layernorm_input)
# MLP.
mlp_output, mlp_bias = self.mlp(layernorm_output)
......@@ -482,7 +510,7 @@ class ParallelTransformerLayer(MegatronModule):
else:
residual = layernorm_input
#re-enable torch grad to enable fused optimization.
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
output = bias_dropout_add_func(
mlp_output,
......@@ -499,8 +527,9 @@ class ParallelTransformerLayer(MegatronModule):
class ParallelTransformer(MegatronModule):
"""Transformer class."""
def __init__(self, attention_mask_func,
init_method, output_layer_init_method):
def __init__(self, init_method, output_layer_init_method,
layer_type=LayerType.encoder,
self_attn_mask_type=AttnMaskType.padding):
super(ParallelTransformer, self).__init__()
args = get_args()
......@@ -518,8 +547,11 @@ class ParallelTransformer(MegatronModule):
# Transformer layers.
def build_layer(layer_number):
return ParallelTransformerLayer(
attention_mask_func, init_method,
output_layer_init_method, layer_number)
init_method,
output_layer_init_method,
layer_number,
layer_type=layer_type,
self_attn_mask_type=self_attn_mask_type)
offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers
self.layers = torch.nn.ModuleList(
[build_layer(i + 1 + offset) for i in range(self.num_layers)])
......@@ -534,14 +566,18 @@ class ParallelTransformer(MegatronModule):
def _get_layer(self, layer_number):
return self.layers[layer_number]
def _checkpointed_forward(self, hidden_states, attention_mask):
def _checkpointed_forward(self, hidden_states, attention_mask,
encoder_output, enc_dec_attn_mask):
"""Forward method with activation checkpointing."""
def custom(start, end):
def custom_forward(*inputs):
x_ = inputs[0]
attention_mask = inputs[1]
encoder_output = inputs[2]
enc_dec_attn_mask = inputs[3]
for index in range(start, end):
layer = self._get_layer(index)
x_ = layer(x_, inputs[1])
x_ = layer(x_, attention_mask, encoder_output, enc_dec_attn_mask)
return x_
return custom_forward
......@@ -551,13 +587,13 @@ class ParallelTransformer(MegatronModule):
while l < self.num_layers:
hidden_states = mpu.checkpoint(
custom(l, l + self.checkpoint_num_layers),
hidden_states, attention_mask)
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
l += self.checkpoint_num_layers
return hidden_states
def forward(self, hidden_states, attention_mask, layer_past=None,
get_key_value=False):
get_key_value=False, encoder_output=None, enc_dec_attn_mask=None):
# Checks.
if layer_past is not None:
......@@ -578,9 +614,14 @@ class ParallelTransformer(MegatronModule):
else:
hidden_states = hidden_states.transpose(0, 1).contiguous()
if encoder_output is not None:
encoder_output = encoder_output.transpose(0, 1).contiguous()
if self.checkpoint_activations:
hidden_states = self._checkpointed_forward(hidden_states,
attention_mask)
attention_mask,
encoder_output,
enc_dec_attn_mask)
else:
if get_key_value:
presents = []
......@@ -591,6 +632,8 @@ class ParallelTransformer(MegatronModule):
past = layer_past[index]
hidden_states = layer(hidden_states,
attention_mask,
encoder_output=encoder_output,
enc_dec_attn_mask=enc_dec_attn_mask,
layer_past=past,
get_key_value=get_key_value)
if get_key_value:
......
......@@ -20,7 +20,6 @@ import math
import torch
from megatron import get_args
from megatron.model import import_layernorm
def init_method_normal(sigma):
"""Init method based on N(0, sigma)."""
......@@ -40,6 +39,11 @@ def scaled_init_method_normal(sigma, num_layers):
return init_
def attention_mask_func(attention_scores, attention_mask):
attention_scores.masked_fill_(attention_mask, -10000.0)
return attention_scores
def get_linear_layer(rows, columns, init_method):
"""Simple linear layer with weight initialization."""
layer = torch.nn.Linear(rows, columns)
......@@ -60,28 +64,3 @@ def openai_gelu(x):
@torch.jit.script
def erf_gelu(x):
return x * 0.5 * (torch.erf(x / 1.41421).to(dtype=x.dtype)+torch.ones_like(x).to(dtype=x.dtype))
def get_params_for_weight_decay_optimization(module):
"""Divide params into with-weight-decay and without-weight-decay groups.
Layernorms and baises will have no weight decay but the rest will.
"""
args = get_args()
LayerNorm = import_layernorm(args.fp32_residual_connection)
weight_decay_params = {'params': []}
no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
for module_ in module.modules():
if isinstance(module_, LayerNorm):
no_weight_decay_params['params'].extend(
[p for p in list(module_._parameters.values())
if p is not None])
else:
weight_decay_params['params'].extend(
[p for n, p in list(module_._parameters.items())
if p is not None and n != 'bias'])
no_weight_decay_params['params'].extend(
[p for n, p in list(module_._parameters.items())
if p is not None and n == 'bias'])
return weight_decay_params, no_weight_decay_params
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Vision Transformer(VIT) model."""
import math
import einops
import torch
import torch.nn.functional as F
from megatron import get_args
from megatron.model.transformer import ParallelTransformer
from megatron.model.utils import (
get_linear_layer,
init_method_normal,
scaled_init_method_normal,
)
from .module import MegatronModule
class VitMlpHead(MegatronModule):
"""Pooler layer.
Pool hidden states of a specific token (for example start of the
sequence) and add a linear transformation followed by a tanh.
Arguments:
hidden_size: hidden size of the input states.
num_classes: number of output classes.
"""
def __init__(self, hidden_size, num_classes):
super(VitMlpHead, self).__init__()
self.dense_in = torch.nn.Linear(hidden_size, hidden_size)
self.dense_out = torch.nn.Linear(hidden_size, num_classes)
torch.nn.init.constant_(self.dense_out.bias, -10)
def forward(self, hidden_states, sequence_index=0):
# hidden_states: [b, s, h]
# sequence_index: index of the token to pool.
x = hidden_states[:, sequence_index, :]
x = self.dense_in(x)
x = torch.tanh(x)
x = self.dense_out(x)
return x
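A shape-level sketch of the head above (illustrative only; the sizes are arbitrary and not taken from this commit):
# Illustrative only: pool the [CLS] position (index 0) of a [b, s, h] tensor
# and map it to class logits.
import torch

head = VitMlpHead(hidden_size=64, num_classes=10)
hidden_states = torch.randn(2, 197, 64)   # [batch, seq, hidden]
logits = head(hidden_states)              # uses sequence_index=0 by default
assert logits.shape == (2, 10)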
def twod_interpolate_position_embeddings_hook(
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
args = get_args()
num_patches_per_dim = args.img_dim // args.patch_dim
num_patches = num_patches_per_dim ** 2
seq_length = num_patches + 1
hidden_size = args.hidden_size
key = prefix + "weight"
assert key in state_dict
if key in state_dict:
input_param = state_dict[key]
assert input_param.shape[1] == hidden_size
if input_param.shape[0] != seq_length:
# update input_param and load it to state_dict[key]
num_tok_input = input_param.shape[0] - 1
num_tok_new = seq_length - 1
input_param_tok, input_param_grid = (
input_param[:1, :],
input_param[1:, :],
)
gs_input = int(math.sqrt(num_tok_input))
gs_new = int(math.sqrt(num_tok_new))
input_param_grid = input_param_grid.transpose(0, 1).contiguous()
input_param_grid = input_param_grid.reshape(
(1, -1, gs_input, gs_input)
)
input_param_grid = input_param_grid.float()
scale_factor = gs_new / gs_input
input_param_grid = F.interpolate(
input_param_grid, scale_factor=scale_factor, mode="bilinear"
)
input_param_grid = input_param_grid.half()
input_param_grid = input_param_grid.reshape((-1, gs_new * gs_new))
input_param_grid = input_param_grid.transpose(0, 1).contiguous()
assert input_param_grid.shape[1] == hidden_size
input_param = torch.cat((input_param_tok, input_param_grid), dim=0)
assert (
input_param.shape[0] == seq_length
and input_param.shape[1] == hidden_size
)
state_dict[key] = input_param
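The hook above only fires when the checkpointed position table does not match the current image/patch configuration; its core step is a bilinear resize of the grid (non-class-token) embeddings. A standalone sketch of that resize with made-up sizes, using an explicit target size rather than the hook's scale_factor:
# Illustrative only: resize position embeddings from a 14x14 patch grid
# to 16x16, keeping the class-token embedding untouched.
import torch
import torch.nn.functional as F

hidden = 64
old_grid, new_grid = 14, 16
pos = torch.randn(old_grid * old_grid + 1, hidden)        # [cls + grid, h]
tok, grid = pos[:1], pos[1:]
grid = grid.transpose(0, 1).reshape(1, hidden, old_grid, old_grid)
grid = F.interpolate(grid, size=(new_grid, new_grid), mode="bilinear")
grid = grid.reshape(hidden, new_grid * new_grid).transpose(0, 1)
resized = torch.cat((tok, grid), dim=0)                    # [1 + 16*16, h]
assert resized.shape == (new_grid * new_grid + 1, hidden)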
class VitModel(MegatronModule):
"""Vision Transformer Model."""
def __init__(self, num_classes, finetune=False):
super(VitModel, self).__init__()
args = get_args()
self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
if args.init_method_xavier_uniform:
self.init_method = torch.nn.init.xavier_uniform_
self.scaled_init_method = torch.nn.init.xavier_uniform_
else:
self.init_method = init_method_normal(args.init_method_std)
self.scaled_init_method = scaled_init_method_normal(
args.init_method_std, args.num_layers
)
self.hidden_size = args.hidden_size
self.num_classes = num_classes
self.patch_dim = args.patch_dim
self.img_dim = args.img_dim
self.finetune = finetune
assert self.img_dim % self.patch_dim == 0
self.num_patches_per_dim = self.img_dim // self.patch_dim
self.num_patches = self.num_patches_per_dim ** 2
self.seq_length = self.num_patches + 1
self.flatten_dim = self.patch_dim * self.patch_dim * args.num_channels
# cls_token
self.cls_token = torch.nn.Parameter(torch.randn(1, 1, self.hidden_size))
torch.nn.init.zeros_(self.cls_token)
# Linear encoder
self.linear_encoder = torch.nn.Linear(
self.flatten_dim, self.hidden_size
)
# embedding
self.position_embeddings = torch.nn.Embedding(
self.seq_length, self.hidden_size
)
init_method_normal(args.init_method_std)(
self.position_embeddings.weight
)
self.position_ids = torch.arange(self.seq_length).expand(1, -1).cuda()
self.position_embeddings._register_load_state_dict_pre_hook(
twod_interpolate_position_embeddings_hook
)
self.embedding_dropout = torch.nn.Dropout(args.hidden_dropout)
# Transformer
self.transformer = ParallelTransformer(
self.init_method, self.scaled_init_method
)
# MLP head
if not self.finetune:
self.mlp_head = VitMlpHead(self.hidden_size, self.num_classes)
else:
self.class_head = get_linear_layer(
self.hidden_size, num_classes, torch.nn.init.zeros_
)
def forward(self, x):
x = einops.rearrange(
x,
"b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
p1=self.patch_dim,
p2=self.patch_dim,
)
assert x.dtype == torch.half
x = self.linear_encoder(x)
cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
x = x + self.position_embeddings(self.position_ids)
x = self.embedding_dropout(x)
x = self.transformer(x, None)
if not self.finetune:
x = self.mlp_head(x)
else:
x = self.class_head(x[:, 0, :])
return x
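For reference (not part of this commit), the rearrange in forward turns an image batch into a sequence of flattened patches; with 224x224 RGB inputs and 16x16 patches:
# Illustrative only: 14*14 = 196 patches, each flattened to 16*16*3 = 768 values.
import einops
import torch

imgs = torch.randn(2, 3, 224, 224)
patches = einops.rearrange(imgs, "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
                           p1=16, p2=16)
assert patches.shape == (2, 196, 768)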
......@@ -19,8 +19,6 @@ from .cross_entropy import vocab_parallel_cross_entropy
from .data import broadcast_data
from .grads import clip_grad_norm
from .initialize import is_unitialized
from .initialize import destroy_model_parallel
from .initialize import get_data_parallel_group
......@@ -46,6 +44,9 @@ from .initialize import model_parallel_is_initialized
from .layers import ColumnParallelLinear
from .layers import RowParallelLinear
from .layers import VocabParallelEmbedding
from .layers import (set_tensor_model_parallel_attributes,
set_defaults_if_not_set_tensor_model_parallel_attributes,
copy_tensor_model_parallel_attributes)
from .mappings import copy_to_tensor_model_parallel_region
from .mappings import gather_from_tensor_model_parallel_region
......
......@@ -20,7 +20,7 @@ from .initialize import get_tensor_model_parallel_rank
from .initialize import get_tensor_model_parallel_src_rank
_MAX_DATA_DIM = 4
_MAX_DATA_DIM = 5
def _check_data_types(keys, data, target_dtype):
......
......@@ -37,13 +37,53 @@ from .utils import split_tensor_along_last_dim
from .utils import VocabUtility
from megatron import get_args
_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {'tensor_model_parallel': False,
'partition_dim': -1,
'partition_stride': 1}
def param_is_not_tensor_parallel_duplicate(param):
return (hasattr(param, 'tensor_model_parallel') and
param.tensor_model_parallel) or (
get_tensor_model_parallel_rank() == 0)
def set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride):
# Make sure the attributes are not set.
for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
assert not hasattr(tensor, attribute)
# Set the attributes.
setattr(tensor, 'tensor_model_parallel', is_parallel)
setattr(tensor, 'partition_dim', dim)
setattr(tensor, 'partition_stride', stride)
def set_defaults_if_not_set_tensor_model_parallel_attributes(tensor):
def maybe_set(attribute, value):
if not hasattr(tensor, attribute):
setattr(tensor, attribute, value)
for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
maybe_set(attribute, _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS[attribute])
def copy_tensor_model_parallel_attributes(destination_tensor, source_tensor):
def maybe_copy(attribute):
if hasattr(source_tensor, attribute):
setattr(destination_tensor, attribute,
getattr(source_tensor, attribute))
for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
maybe_copy(attribute)
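A small sketch, not from this commit, of how these helpers tag a parameter and propagate the partitioning metadata onto a copy (for example an fp32 master weight):
# Illustrative only: tag a freshly created weight as tensor-model-parallel,
# then copy the same metadata onto a clone.
import torch

weight = torch.nn.Parameter(torch.empty(128, 64))
set_tensor_model_parallel_attributes(weight, is_parallel=True, dim=0, stride=1)

master = torch.nn.Parameter(weight.detach().clone().float())
copy_tensor_model_parallel_attributes(master, weight)
assert master.tensor_model_parallel and master.partition_dim == 0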
def _initialize_affine_weight_gpu(weight, init_method,
partition_dim, stride=1):
"""Initialize affine weight for model parallel on GPU."""
weight.tensor_model_parallel = True
weight.partition_dim = partition_dim
weight.partition_stride = stride
set_tensor_model_parallel_attributes(tensor=weight,
is_parallel=True,
dim=partition_dim,
stride=stride)
with get_cuda_rng_tracker().fork():
init_method(weight)
......@@ -58,9 +98,10 @@ def _initialize_affine_weight_cpu(weight, output_size, input_size,
Build the master weight on all processes and scatter
the relevant chunk."""
weight.tensor_model_parallel = True
weight.partition_dim = partition_dim
weight.partition_stride = stride
set_tensor_model_parallel_attributes(tensor=weight,
is_parallel=True,
dim=partition_dim,
stride=stride)
# Initialize master weight
master_weight = torch.empty(output_size, input_size,
......@@ -74,7 +115,7 @@ def _initialize_affine_weight_cpu(weight, output_size, input_size,
per_partition_per_stride_size = divide(per_partition_size, stride)
weight_list = torch.split(master_weight, per_partition_per_stride_size,
dim=partition_dim)
rank = get_model_parallel_rank()
rank = get_tensor_model_parallel_rank()
world_size = get_tensor_model_parallel_world_size()
my_weight_list = weight_list[rank::world_size]
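Outside the diff, the rank::world_size slicing above can be illustrated on a plain tensor (sizes are made up): with stride > 1 the master weight is cut into world_size * stride chunks along the partition dimension, and rank r keeps chunks r, r + world_size, and so on.
# Illustrative only: 16 rows, 4 ranks, stride 2 -> 8 chunks of 2 rows each;
# rank 1 takes chunks 1 and 5 (rows starting at 2 and 10).
import torch

world_size, stride = 4, 2
master = torch.arange(16).reshape(16, 1)
chunks = torch.split(master, 16 // (world_size * stride), dim=0)
rank = 1
mine = chunks[rank::world_size]
assert [c[0, 0].item() for c in mine] == [2, 10]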
......@@ -225,9 +266,7 @@ class ColumnParallelLinear(torch.nn.Module):
self.output_size_per_partition,
device=torch.cuda.current_device(),
dtype=args.params_dtype))
self.bias.tensor_model_parallel = True
self.bias.partition_dim = 0
self.bias.stride = stride
set_tensor_model_parallel_attributes(self.bias, True, 0, stride)
# Always initialize bias to zero.
with torch.no_grad():
self.bias.zero_()
......
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from apex.optimizers import FusedAdam as Adam
from apex.optimizers import FusedSGD as SGD
from megatron import get_args
from megatron.model import import_layernorm
from .grad_scaler import ConstantGradScaler, DynamicGradScaler
from .optimizer import FP16OptimizerWithFP16Params, FP32Optimizer
def _get_params_for_weight_decay_optimization(module):
"""Divide params into with-weight-decay and without-weight-decay groups.
Layernorms and biases will have no weight decay but the rest will.
"""
args = get_args()
LayerNorm = import_layernorm(args.fp32_residual_connection)
weight_decay_params = {'params': []}
no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
for module_ in module.modules():
if isinstance(module_, LayerNorm):
no_weight_decay_params['params'].extend(
[p for p in list(module_._parameters.values())
if p is not None])
else:
weight_decay_params['params'].extend(
[p for n, p in list(module_._parameters.items())
if p is not None and n != 'bias'])
no_weight_decay_params['params'].extend(
[p for n, p in list(module_._parameters.items())
if p is not None and n == 'bias'])
return weight_decay_params, no_weight_decay_params
def get_megatron_optimizer(model):
args = get_args()
# Base optimizer.
param_groups = _get_params_for_weight_decay_optimization(model)
if args.optimizer == 'adam':
optimizer = Adam(param_groups,
lr=args.lr,
weight_decay=args.weight_decay,
betas=(args.adam_beta1, args.adam_beta2),
eps=args.adam_eps)
elif args.optimizer == 'sgd':
optimizer = SGD(param_groups,
lr=args.lr,
weight_decay=args.weight_decay,
momentum=args.sgd_momentum)
else:
raise Exception('{} optimizer is not supported.'.format(
args.optimizer))
if args.fp16:
# Constant loss scale.
if args.loss_scale:
grad_scaler = ConstantGradScaler(args.loss_scale)
# Dynamic loss scale.
else:
grad_scaler = DynamicGradScaler(
initial_scale=args.initial_loss_scale,
min_scale=args.min_loss_scale,
growth_factor=2.0,
backoff_factor=0.5,
growth_interval=args.loss_scale_window,
hysteresis=args.hysteresis)
# Megatron optimizer.
return FP16OptimizerWithFP16Params(optimizer, grad_scaler,
args.clip_grad)
# FP32.
return FP32Optimizer(optimizer, args.clip_grad)
......@@ -13,67 +13,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch
"""Gradient clipping."""
import torch
from torch._six import inf
try:
from apex.multi_tensor_apply import multi_tensor_applier
import amp_C
except Exception as e:
print('WARNING: APEX is not installed, multi_tensor_applier will not be available.')
from .initialize import is_pipeline_first_stage
from .initialize import get_model_parallel_group
from .initialize import get_tensor_model_parallel_rank
from apex.multi_tensor_apply import multi_tensor_applier
import amp_C
def l2_grad_clipper(parameters, max_norm):
"""Efficient L2 norm gradient clipping."""
overflow_buf = torch.zeros(1, dtype=torch.int, device='cuda')
# Make sure we have an iterable.
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
# Filter parameters with gradients.
parameters_with_grads = list(filter(
lambda p: p.grad is not None, parameters))
# Filter parameters for norm calculations.
mp_rank_is_zero = (get_tensor_model_parallel_rank() == 0)
parameters_for_norm = list(filter(
lambda p: p.tensor_model_parallel or mp_rank_is_zero, parameters_with_grads))
# Calculate L2 norm.
norm, _ = multi_tensor_applier(
amp_C.multi_tensor_l2norm,
overflow_buf,
[parameters_for_norm],
False # no per-parameter norm
)
# Sum across all model parallel GPUs.
norm_2 = norm * norm
torch.distributed.all_reduce(norm_2,
op=torch.distributed.ReduceOp.SUM,
group=get_model_parallel_group())
total_norm = norm_2.item() ** 0.5
# Scale to get max_norm.
clip_coef = float(max_norm) / (total_norm + 1.0e-6)
grads = [p.grad for p in parameters_with_grads]
if clip_coef < 1.0:
multi_tensor_applier(
amp_C.multi_tensor_scale,
overflow_buf,
[grads, grads],
clip_coef)
return total_norm
from megatron import mpu
from megatron.model.module import param_is_not_shared
from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
def clip_grad_norm(parameters, max_norm, norm_type=2, parameter_names=None):
"""Clips gradient norm of an iterable of parameters.
def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
"""Clips gradient norm of an iterable of parameters whose gradients
are in fp32.
This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
added functionality to handle model parallel parameters. Note that
......@@ -89,51 +44,77 @@ def clip_grad_norm(parameters, max_norm, norm_type=2, parameter_names=None):
Returns:
Total norm of the parameters (viewed as a single vector).
"""
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
if parameter_names is not None:
filtered_parameters = []
assert len(parameters) == len(parameter_names), \
'length of parameters and parameter_names should be the same'
for p, n in zip(parameters, parameter_names):
if p.grad is not None:
# TODO: Bit hacky; is there a cleaner way to do this?
# Count embedding layer only once (in first stage).
# Don't count the weights a second time in the last stage.
if "embedding" not in n or \
is_pipeline_first_stage():
filtered_parameters.append(p)
parameters = filtered_parameters
else:
parameters = list(filter(lambda p: p.grad is not None, parameters))
# Filter parameters based on:
# - grad should not be none
# - parameter should not be shared
# - should not be a replica due to tensor model parallelism
grads = []
grads_for_norm = []
for param in parameters:
grad_not_none = param.grad is not None
is_not_shared = param_is_not_shared(param)
is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
grad = param.grad.detach()
if grad_not_none:
# Make sure the grads are in fp32
assert param.grad.type() == 'torch.cuda.FloatTensor'
grads.append(grad)
if grad_not_none and is_not_shared and is_not_tp_duplicate:
grads_for_norm.append(grad)
# Norm parameters.
max_norm = float(max_norm)
norm_type = float(norm_type)
total_norm = 0.0
# Calculate norm.
if norm_type == inf:
total_norm = max(p.grad.data.abs().max() for p in parameters)
total_norm = max(grad.abs().max() for grad in grads_for_norm)
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
# Take max across all model-parallel GPUs.
torch.distributed.all_reduce(total_norm_cuda,
op=torch.distributed.ReduceOp.MAX,
group=get_model_parallel_group())
group=mpu.get_model_parallel_group())
total_norm = total_norm_cuda[0].item()
clip_coef = max_norm / (total_norm + 1e-6)
if clip_coef < 1:
for p in parameters:
p.grad.data.mul_(clip_coef)
else:
total_norm = 0
for p in parameters:
if p.tensor_model_parallel or (get_tensor_model_parallel_rank() == 0):
param_norm = torch.linalg.norm(p.grad.data.flatten(), norm_type)
total_norm += param_norm.item() ** norm_type
if norm_type == 2.0:
dummy_overflow_buf = torch.cuda.IntTensor([0])
# Use apex's multi-tensor applier for efficiency reasons.
# Multi-tensor applier takes a function and a list of list
# and performs the operation on that list all in one kernel.
grad_norm, _ = multi_tensor_applier(
amp_C.multi_tensor_l2norm,
dummy_overflow_buf,
[grads_for_norm],
False # no per-parameter norm
)
# Since we will be summing across data parallel groups,
# we need the pow(norm-type).
total_norm = grad_norm ** norm_type
else:
for grad in grads_for_norm:
grad_norm = torch.norm(grad, norm_type)
total_norm += grad_norm ** norm_type
# Sum across all model-parallel GPUs.
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
torch.distributed.all_reduce(total_norm_cuda,
torch.distributed.all_reduce(total_norm,
op=torch.distributed.ReduceOp.SUM,
group=get_model_parallel_group())
total_norm = total_norm_cuda[0].item() ** (1. / norm_type)
clip_coef = max_norm / (total_norm + 1e-6)
if clip_coef < 1:
for p in parameters:
p.grad.data.mul_(clip_coef)
group=mpu.get_model_parallel_group())
total_norm = total_norm.item() ** (1.0 / norm_type)
# Scale.
clip_coeff = max_norm / (total_norm + 1.0e-6)
if clip_coeff < 1.0:
dummy_overflow_buf = torch.cuda.IntTensor([0])
multi_tensor_applier(amp_C.multi_tensor_scale,
dummy_overflow_buf,
[grads, grads],
clip_coeff)
return total_norm
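A single-process sketch (illustrative only, with no model parallelism or apex) of the clipping arithmetic used above: accumulate the norm raised to norm_type, take the root, and scale every grad by max_norm / total_norm when that coefficient is below one.
# Illustrative only: the same clipping math on plain tensors.
import torch

grads = [torch.randn(10), torch.randn(10)]
max_norm, norm_type = 1.0, 2.0
total_norm = sum(torch.norm(g, norm_type) ** norm_type for g in grads)
total_norm = total_norm.item() ** (1.0 / norm_type)
clip_coeff = max_norm / (total_norm + 1.0e-6)
if clip_coeff < 1.0:
    for g in grads:
        g.mul_(clip_coeff)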
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Megatron grad scaler."""
from abc import ABC
from abc import abstractmethod
import torch
class MegatronGradScaler(ABC):
def __init__(self, initial_scale):
"""Initialize scale value with the input initial scale."""
assert initial_scale > 0.0
self._scale = torch.cuda.FloatTensor([initial_scale])
@property
def scale(self):
return self._scale
@property
def inv_scale(self):
return self._scale.double().reciprocal().float()
@abstractmethod
def update(self, found_inf):
pass
@abstractmethod
def state_dict(self):
pass
@abstractmethod
def load_state_dict(self, state_dict):
pass
class ConstantGradScaler(MegatronGradScaler):
def update(self, found_inf):
pass
def state_dict(self):
return dict()
def load_state_dict(self, state_dict):
pass
class DynamicGradScaler(MegatronGradScaler):
def __init__(self, initial_scale, min_scale,
growth_factor, backoff_factor,
growth_interval, hysteresis):
""""Grad scaler with dynamic scale that gets adjusted
during training."""
super(DynamicGradScaler, self).__init__(initial_scale)
# Lower bound on the scale.
assert min_scale > 0.0
assert min_scale <= initial_scale
self.min_scale = torch.cuda.FloatTensor([min_scale])
# Growth and backoff factors for the scale.
assert growth_factor > 1.0
self.growth_factor = torch.cuda.FloatTensor([growth_factor])
assert backoff_factor < 1.0
assert backoff_factor > 0.0
self.backoff_factor = torch.cuda.FloatTensor([backoff_factor])
# Interval over which if we don't see any inf/nan,
# we will scale the grad scale by the growth factor.
assert growth_interval > 0
self.growth_interval = growth_interval
# Number of inf/nans we should see before scaling down
# the grad scale by the backoff factor.
assert hysteresis > 0
self.hysteresis = hysteresis
# Trackers.
self._growth_tracker = 0
self._hysteresis_tracker = self.hysteresis
def update(self, found_inf):
# If we have an inf/nan, growth tracker is set to 0
# and hysteresis tracker is reduced by 1.
if found_inf:
self._growth_tracker = 0
self._hysteresis_tracker -= 1
# Now if we are out of hysteresis count, scale down the loss scale.
if self._hysteresis_tracker <= 0:
self._scale = torch.max(self._scale * self.backoff_factor,
self.min_scale)
else:
# If there is no nan/inf, increment the growth tracker.
self._growth_tracker += 1
# If we have had enough consecutive intervals with no nan/inf:
if self._growth_tracker == self.growth_interval:
# Reset the tracker and hysteresis trackers,
self._growth_tracker = 0
self._hysteresis_tracker = self.hysteresis
# and scale up the loss scale.
self._scale = self._scale * self.growth_factor
def state_dict(self):
state_dict = {}
state_dict['scale'] = self._scale
state_dict['growth_tracker'] = self._growth_tracker
state_dict['hysteresis_tracker'] = self._hysteresis_tracker
return state_dict
def load_state_dict(self, state_dict):
self._scale = state_dict['scale'].cuda(torch.cuda.current_device())
self._growth_tracker = state_dict['growth_tracker']
self._hysteresis_tracker = state_dict['hysteresis_tracker']
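A toy driver loop (illustrative only; the values are arbitrary, and a CUDA device is required since the scaler keeps its state in torch.cuda.FloatTensor) showing how the scale reacts to overflows and to a long stretch without one:
# Illustrative only: with hysteresis=2 the first overflow is absorbed,
# the second halves the scale, and a full growth interval without
# overflow scales it back up by the growth factor.
scaler = DynamicGradScaler(initial_scale=2.0 ** 16, min_scale=1.0,
                           growth_factor=2.0, backoff_factor=0.5,
                           growth_interval=1000, hysteresis=2)
scaler.update(found_inf=True)     # absorbed by hysteresis
scaler.update(found_inf=True)     # scale is halved
for _ in range(1000):             # no overflow for a full growth interval
    scaler.update(found_inf=False)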