"src/webui/vscode:/vscode.git/clone" did not exist on "87ed70cd5fc0e88551899d07a2614def1964a84a"
Commit 651cc107 authored by Mostofa Patwary

Merging megatron with ICT

parents e919dd8e c601d751
@@ -3,17 +3,17 @@ import torch
 import sys

 from megatron import get_args, print_rank_0
-from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
-from megatron.module import MegatronModule
+from megatron.checkpointing import fix_query_key_value_ordering
+from megatron.checkpointing import get_checkpoint_tracker_filename
+from megatron.checkpointing import get_checkpoint_name
 from megatron import mpu, get_tokenizer
-from megatron.model.bert_model import bert_attention_mask_func
-from megatron.model.bert_model import bert_extended_attention_mask
 from megatron.model.bert_model import bert_position_ids
+from megatron.model.enums import AttnMaskType
 from megatron.model.language_model import get_language_model
 from megatron.model.utils import get_linear_layer
 from megatron.model.utils import init_method_normal
 from megatron.model.utils import scaled_init_method_normal
+from .module import MegatronModule


 def biencoder_model_provider(only_query_model=False,
                              only_context_model=False,
@@ -165,16 +165,32 @@ class BiEncoderModel(MegatronModule):
         print('global rank {} is loading BERT checkpoint {}'.format(
             torch.distributed.get_rank(), checkpoint_name))

+        # Load the checkpoint.
         try:
             state_dict = torch.load(checkpoint_name, map_location='cpu')
+        except ModuleNotFoundError:
+            from megatron.fp16_deprecated import loss_scaler
+            # For backward compatibility.
+            print_rank_0(' > deserializing using the old code structure ...')
+            sys.modules['fp16.loss_scaler'] = sys.modules[
+                'megatron.fp16_deprecated.loss_scaler']
+            sys.modules['megatron.fp16.loss_scaler'] = sys.modules[
+                'megatron.fp16_deprecated.loss_scaler']
+            state_dict = torch.load(checkpoint_name, map_location='cpu')
+            sys.modules.pop('fp16.loss_scaler', None)
+            sys.modules.pop('megatron.fp16.loss_scaler', None)
         except BaseException:
-            raise ValueError("Could not load BERT checkpoint")
+            print_rank_0('could not load the BERT checkpoint')
+            sys.exit()
+
+        checkpoint_version = state_dict.get('checkpoint_version', 0)

         # load the LM state dict into each model
         model_dict = state_dict['model']['language_model']

         if self.shared_query_context_model:
             self.model.language_model.load_state_dict(model_dict)
+            fix_query_key_value_ordering(self.model, checkpoint_version)
         else:
             if self.use_query_model:
                 self.query_model.language_model.load_state_dict(model_dict)
@@ -183,11 +199,14 @@ class BiEncoderModel(MegatronModule):
                 query_proj_state_dict = \
                     self.state_dict_for_save_checkpoint()\
                     [self._query_key]['projection_enc']
+                fix_query_key_value_ordering(self.query_model, checkpoint_version)

             if self.use_context_model:
                 self.context_model.language_model.load_state_dict(model_dict)
                 if self.query_model is not None and self.projection_dim > 0:
                     self.context_model.projection_enc.load_state_dict\
                         (query_proj_state_dict)
+                fix_query_key_value_ordering(self.context_model, checkpoint_version)


 class PretrainedBertModel(MegatronModule):
@@ -209,9 +228,9 @@ class PretrainedBertModel(MegatronModule):
             args.init_method_std, args.num_layers)

         self.language_model, self._language_model_key = get_language_model(
-            attention_mask_func=bert_attention_mask_func,
             num_tokentypes=num_tokentypes,
             add_pooler=False,
+            encoder_attn_mask_type=AttnMaskType.padding,
             init_method=init_method,
             scaled_init_method=scaled_init_method)
...
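Note on the checkpoint-loading change above: torch.load unpickles a checkpoint by importing the module paths that were recorded when it was saved, so temporarily aliasing the retired fp16.loss_scaler paths onto megatron.fp16_deprecated.loss_scaler lets old checkpoints deserialize against the new package layout, and popping the aliases afterwards keeps them from leaking. A minimal standalone sketch of the same pattern (the helper name and arguments are illustrative, not part of this merge):

import sys
import torch


def load_with_legacy_module_alias(path, legacy_names, replacement_module):
    """Load a pickled checkpoint whose objects reference renamed modules."""
    # Point each retired import path at the module that now defines the classes.
    for name in legacy_names:
        sys.modules[name] = replacement_module
    try:
        return torch.load(path, map_location='cpu')
    finally:
        # Remove the aliases so later imports are unaffected.
        for name in legacy_names:
            sys.modules.pop(name, None)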
@@ -19,15 +19,16 @@ import torch

 from megatron import get_args, print_rank_last
 from megatron import mpu
-from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids
+from megatron.model.enums import AttnMaskType
+from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids
 from megatron.model.language_model import get_language_model
 from megatron.model.utils import get_linear_layer
 from megatron.model.utils import init_method_normal
 from megatron.model.utils import scaled_init_method_normal
-from megatron.module import PipelinedMegatronModule
+from .module import MegatronModule


-class ClassificationBase(PipelinedMegatronModule):
+class ClassificationBase(MegatronModule):

     def __init__(self, num_classes, num_tokentypes=2):
         super(ClassificationBase, self).__init__(share_word_embeddings=False)
@@ -37,9 +38,9 @@ class ClassificationBase(PipelinedMegatronModule):
         init_method = init_method_normal(args.init_method_std)

         self.language_model, self._language_model_key = get_language_model(
-            attention_mask_func=bert_attention_mask_func,
             num_tokentypes=num_tokentypes,
             add_pooler=True,
+            encoder_attn_mask_type=AttnMaskType.padding,
             init_method=init_method,
             scaled_init_method=scaled_init_method_normal(args.init_method_std,
                                                          args.num_layers))
...
@@ -20,7 +20,7 @@ from torch.nn.modules import Module
 from torch.autograd import Variable

 from megatron import mpu
-from megatron.module import MegatronModule
+from .module import MegatronModule


 class DistributedDataParallel(MegatronModule):
...
@@ -12,19 +12,17 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from .fp16util import (
-    BN_convert_float,
-    network_to_half,
-    prep_param_lists,
-    model_grads_to_master_grads,
-    master_params_to_model_params,
-    tofp16,
-    to_python_float,
-    clip_grad_norm,
-    convert_module,
-    convert_network,
-    FP16Model,
-)
-
-from .fp16 import *
-from .loss_scaler import *
+import enum
+
+
+class LayerType(enum.Enum):
+    encoder = 1
+    decoder = 2
+
+
+class AttnType(enum.Enum):
+    self_attn = 1
+    cross_attn = 2
+
+
+class AttnMaskType(enum.Enum):
+    padding = 1
+    causal = 2
File mode changed from 100755 to 100644
@@ -14,103 +14,127 @@
 # limitations under the License.

 import torch
+from megatron.model.enums import AttnMaskType


-class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function) :
+class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
     """
     Fused operation which performs following three operations in sequence
     1. Scale the tensor.
     2. Apply upper triangular mask (typically used in gpt models).
     3. Perform softmax.
     """

     @staticmethod
     def forward(ctx, inputs, scale):
         import scaled_upper_triang_masked_softmax_cuda
         scale_t = torch.tensor([scale])

-        softmax_results = \
-            scaled_upper_triang_masked_softmax_cuda.forward(inputs, scale_t[0])
+        softmax_results = scaled_upper_triang_masked_softmax_cuda.forward(
+            inputs, scale_t[0]
+        )
         ctx.save_for_backward(softmax_results, scale_t)
         return softmax_results

     @staticmethod
     def backward(ctx, output_grads):
         import scaled_upper_triang_masked_softmax_cuda
         softmax_results, scale_t = ctx.saved_tensors

-        input_grads = \
-            scaled_upper_triang_masked_softmax_cuda.backward(output_grads,
-                                                             softmax_results,
-                                                             scale_t[0])
+        input_grads = scaled_upper_triang_masked_softmax_cuda.backward(
+            output_grads, softmax_results, scale_t[0]
+        )
         return input_grads, None


-class ScaledMaskedSoftmax(torch.autograd.Function) :
+class ScaledMaskedSoftmax(torch.autograd.Function):
     """
     Fused operation which performs following three operations in sequence
     1. Scale the tensor.
     2. Apply the mask.
     3. Perform softmax.
     """

     @staticmethod
     def forward(ctx, inputs, mask, scale):
         import scaled_masked_softmax_cuda
         scale_t = torch.tensor([scale])

-        softmax_results = \
-            scaled_masked_softmax_cuda.forward(inputs, mask, scale_t[0])
+        softmax_results = scaled_masked_softmax_cuda.forward(
+            inputs, mask, scale_t[0]
+        )
         ctx.save_for_backward(softmax_results, scale_t)
         return softmax_results

     @staticmethod
     def backward(ctx, output_grads):
         import scaled_masked_softmax_cuda
         softmax_results, scale_t = ctx.saved_tensors

-        input_grads = \
-            scaled_masked_softmax_cuda.backward(output_grads,
-                                                softmax_results,
-                                                scale_t[0])
+        input_grads = scaled_masked_softmax_cuda.backward(
+            output_grads, softmax_results, scale_t[0]
+        )
         return input_grads, None, None


 class FusedScaleMaskSoftmax(torch.nn.Module):
     """
     fused operation: scaling + mask + softmax
     Arguments:
         input_in_fp16: flag to indicate if input in fp16 data format.
-        upper_triang_mask: if true, apply upper triangular masking.
-                           (used in gpt family networks)
+        attn_mask_type: attention mask type (pad or causal)
         mask_func: mask function to be applied.
         softmax_in_fp32: if true, softmax in performed at fp32 precision.
         scale: scaling factor used in input tensor scaling.
     """

-    def __init__(self, input_in_fp16, upper_triang_mask_fusion,
-                 general_mask_fusion, mask_func, softmax_in_fp32, scale):
+    def __init__(
+        self,
+        input_in_fp16,
+        attn_mask_type,
+        scaled_masked_softmax_fusion,
+        mask_func,
+        softmax_in_fp32,
+        scale,
+    ):
         super(FusedScaleMaskSoftmax, self).__init__()
         self.input_in_fp16 = input_in_fp16
-        self.upper_triang_mask_fusion = upper_triang_mask_fusion
-        self.general_mask_fusion = general_mask_fusion
+        self.attn_mask_type = attn_mask_type
+        self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
         self.mask_func = mask_func
         self.softmax_in_fp32 = softmax_in_fp32
         self.scale = scale

-        assert self.scale is None or softmax_in_fp32, \
-            'softmax should be in fp32 when scaled'
+        assert (
+            self.scale is None or softmax_in_fp32
+        ), "softmax should be in fp32 when scaled"

     def forward(self, input, mask):
-        # [b, np, s, s]
+        # [b, np, sq, sk]
         data_size = input.size()
-        assert input.dim() == 4
+        query_seq_len = data_size[-2]
+        key_seq_len = data_size[-1]
+        assert input.dim() == 4

         # invoke custom kernel
-        if self.input_in_fp16 and data_size[-1] <= 2048 and \
-                (self.upper_triang_mask_fusion or self.general_mask_fusion) and \
-                input.size()[2] == input.size()[3]:
+        if self.input_in_fp16 and key_seq_len <= 2048 and mask is not None and \
+                query_seq_len % 4 == 0 and self.scaled_masked_softmax_fusion:
             scale = self.scale if self.scale is not None else 1.0

-            if self.upper_triang_mask_fusion:
-                input = input.view(-1, data_size[2], data_size[3])
+            if self.attn_mask_type == AttnMaskType.causal:
+                assert query_seq_len == key_seq_len, \
+                    "causal mask is only for self attention"
+                input = input.view(-1, query_seq_len, key_seq_len)
                 probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
                 probs = probs.view(*data_size)
             else:
+                assert self.attn_mask_type == AttnMaskType.padding
                 probs = ScaledMaskedSoftmax.apply(input, mask, scale)
         else:
             if self.input_in_fp16 and self.softmax_in_fp32:
@@ -118,7 +142,7 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
             if self.scale is not None:
                 input = input * self.scale
-            mask_output = self.mask_func(input, mask)
+            mask_output = self.mask_func(input, mask) if mask is not None else input
             probs = torch.nn.Softmax(dim=-1)(mask_output)

         if self.input_in_fp16 and self.softmax_in_fp32:
...
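Note: the fused kernels above compute scale -> mask -> softmax in a single pass over [b, np, sq, sk] attention scores; the causal path builds the upper-triangular mask inside the kernel, while the padding path takes an explicit mask. An unfused PyTorch reference of the same computation, written only as an illustration (not part of this merge) and assuming the Megatron convention that mask entries set to True are filled with -10000.0:

import torch


def reference_scale_mask_softmax(scores, mask=None, scale=None, causal=False):
    # scores: [b, np, sq, sk] attention scores before softmax.
    if scale is not None:
        scores = scores * scale
    if causal:
        sq, sk = scores.shape[-2], scores.shape[-1]
        # Mask out positions above the diagonal (future tokens).
        future = torch.triu(torch.ones(sq, sk, dtype=torch.bool,
                                       device=scores.device), diagonal=1)
        scores = scores.masked_fill(future, -10000.0)
    elif mask is not None:
        scores = scores.masked_fill(mask.bool(), -10000.0)
    return torch.nn.Softmax(dim=-1)(scores)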
@@ -19,19 +19,15 @@ import torch

 from megatron import get_args
 from megatron import mpu
-from megatron.module import PipelinedMegatronModule
+from .module import MegatronModule
+from .enums import AttnMaskType
 from .language_model import parallel_lm_logits
 from .language_model import get_language_model
 from .utils import init_method_normal
 from .utils import scaled_init_method_normal


-def gpt2_attention_mask_func(attention_scores, ltor_mask):
-    attention_scores.masked_fill_(ltor_mask, -10000.0)
-    return attention_scores
-
-
 def post_language_model_processing(lm_output, labels, logit_weights,
                                    get_key_value, parallel_output,
                                    forward_method_parallel_output,
@@ -61,37 +57,37 @@ def post_language_model_processing(lm_output, labels, logit_weights,
     return loss


-class GPT2ModelBase(PipelinedMegatronModule):
+class GPTModelBase(MegatronModule):
     """GPT-2 Language model."""

     def __init__(self, num_tokentypes=0, parallel_output=True):
-        super(GPT2ModelBase, self).__init__()
+        super(GPTModelBase, self).__init__()
         args = get_args()

         self.parallel_output = parallel_output
         self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy

         self.language_model, self._language_model_key = get_language_model(
-            attention_mask_func=gpt2_attention_mask_func,
             num_tokentypes=num_tokentypes,
             add_pooler=False,
+            encoder_attn_mask_type=AttnMaskType.causal,
             init_method=init_method_normal(args.init_method_std),
             scaled_init_method=scaled_init_method_normal(args.init_method_std,
                                                          args.num_layers))

         self.initialize_word_embeddings(init_method_normal)

-    def forward(self, gpt2_model_input, attention_mask, labels=None,
+    def forward(self, gpt_model_input, attention_mask, labels=None,
                 tokentype_ids=None, layer_past=None, get_key_value=False,
                 forward_method_parallel_output=None):

         kwargs = {'layer_past': layer_past, 'get_key_value': get_key_value}
         if mpu.is_pipeline_first_stage():
-            (input_ids, position_ids) = gpt2_model_input
+            (input_ids, position_ids) = gpt_model_input
             args = [input_ids, position_ids, attention_mask]
             kwargs['tokentype_ids'] = tokentype_ids
         else:
-            args = [gpt2_model_input, attention_mask]
+            args = [gpt_model_input, attention_mask]
         lm_output = self.language_model(*args, **kwargs)

         if mpu.is_pipeline_last_stage():
@@ -130,17 +126,17 @@ class GPT2ModelBase(PipelinedMegatronModule):
         self.language_model.load_state_dict(state_dict, strict=strict)


-class GPT2Model(GPT2ModelBase):
+class GPTModel(GPTModelBase):

     def __init__(self, num_tokentypes=0, parallel_output=True):
-        super(GPT2Model, self).__init__(
+        super(GPTModel, self).__init__(
             num_tokentypes=num_tokentypes,
             parallel_output=parallel_output)

     def forward(self, input_ids, position_ids, attention_mask, labels=None,
                 tokentype_ids=None, layer_past=None, get_key_value=False,
                 forward_method_parallel_output=None):
-        return super(GPT2Model, self).forward(
+        return super(GPTModel, self).forward(
             (input_ids, position_ids),
             attention_mask,
             labels=labels,
@@ -150,15 +146,15 @@ class GPT2Model(GPT2ModelBase):
             forward_method_parallel_output=forward_method_parallel_output)


-class GPT2ModelFirstStage(GPT2ModelBase):
+class GPTModelFirstStage(GPTModelBase):

     def __init__(self, num_tokentypes=0):
-        super(GPT2ModelFirstStage, self).__init__(
+        super(GPTModelFirstStage, self).__init__(
             num_tokentypes=num_tokentypes)

     def forward(self, input_ids, position_ids, attention_mask,
                 tokentype_ids=None, layer_past=None, get_key_value=False):
-        return super(GPT2ModelFirstStage, self).forward(
+        return super(GPTModelFirstStage, self).forward(
             (input_ids, position_ids),
             attention_mask,
             tokentype_ids=tokentype_ids,
@@ -166,32 +162,32 @@ class GPT2ModelFirstStage(GPT2ModelBase):
             get_key_value=get_key_value)


-class GPT2ModelIntermediateStage(GPT2ModelBase):
+class GPTModelIntermediateStage(GPTModelBase):

     def __init__(self, num_tokentypes=0):
-        super(GPT2ModelIntermediateStage, self).__init__(
+        super(GPTModelIntermediateStage, self).__init__(
             num_tokentypes=num_tokentypes)

     def forward(self, hidden_state, attention_mask,
                 layer_past=None, get_key_value=False):
-        return super(GPT2ModelIntermediateStage, self).forward(
+        return super(GPTModelIntermediateStage, self).forward(
             hidden_state,
             attention_mask,
             layer_past=layer_past,
             get_key_value=get_key_value)


-class GPT2ModelLastStage(GPT2ModelBase):
+class GPTModelLastStage(GPTModelBase):

     def __init__(self, num_tokentypes=0, parallel_output=True):
-        super(GPT2ModelLastStage, self).__init__(
+        super(GPTModelLastStage, self).__init__(
             num_tokentypes=num_tokentypes,
             parallel_output=parallel_output)

     def forward(self, hidden_state, attention_mask, labels=None,
                 layer_past=None, get_key_value=False,
                 forward_method_parallel_output=None):
-        return super(GPT2ModelLastStage, self).forward(
+        return super(GPTModelLastStage, self).forward(
             hidden_state,
             attention_mask,
             labels=labels,
...
@@ -20,7 +20,8 @@ import torch.nn.functional as F

 from megatron import get_args
 from megatron import mpu
-from megatron.module import MegatronModule
+from .module import MegatronModule
+from megatron.model.enums import LayerType, AttnMaskType
 from megatron.model.transformer import ParallelTransformer
 from megatron.model.utils import get_linear_layer
 from megatron.model.utils import init_method_normal, scaled_init_method_normal
@@ -42,8 +43,10 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
     return mpu.gather_from_tensor_model_parallel_region(logits_parallel)


-def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
-                       init_method=None, scaled_init_method=None):
+def get_language_model(num_tokentypes, add_pooler,
+                       encoder_attn_mask_type, init_method=None,
+                       scaled_init_method=None, add_decoder=False,
+                       decoder_attn_mask_type=AttnMaskType.causal):
     """Build language model and return along with the key to save."""
     args = get_args()

@@ -51,15 +54,18 @@ def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
         init_method = init_method_normal(args.init_method_std)

     if scaled_init_method is None:
-        scaled_init_method = scaled_init_method_normal(args.init_method_std, args.num_layers)
+        scaled_init_method = scaled_init_method_normal(args.init_method_std,
+                                                       args.num_layers)

     # Language model.
-    args = [attention_mask_func, init_method, scaled_init_method]
+    args = [init_method, scaled_init_method, encoder_attn_mask_type]
     kwargs = {}
     cls = None
     if mpu.is_pipeline_first_stage() and mpu.is_pipeline_last_stage():
         cls = TransformerLanguageModel
         kwargs['num_tokentypes'] = num_tokentypes
+        kwargs['add_decoder'] = add_decoder
+        kwargs['decoder_attn_mask_type'] = decoder_attn_mask_type
         kwargs['add_pooler'] = add_pooler
     elif mpu.is_pipeline_first_stage() and not mpu.is_pipeline_last_stage():
         cls = TransformerLanguageModelFirstStage
@@ -262,12 +268,6 @@ class TransformerLanguageModelBase(MegatronModule):

     Arguments:
         transformer_hparams: transformer hyperparameters
-        attention_mask_func: a function that takes `unmaksed-attention-scores`
-            with size [b, np, s, s] and an `attention-mask` and will apply
-            the masking. The function should return a masked score of the
-            same size [b, np, s, s].
-               masked-attention-scores = attention_mask_func(
-                                unmaksed-attention-scores, attention-mask)
         vocab_size: vocabulary size
         max_sequence_length: maximum size of sequence. This
                              is used for positional embedding
@@ -277,10 +277,12 @@ class TransformerLanguageModelBase(MegatronModule):
     """

     def __init__(self,
-                 attention_mask_func,
                  init_method,
                  output_layer_init_method,
+                 encoder_attn_mask_type,
                  num_tokentypes=0,
+                 add_decoder=False,
+                 decoder_attn_mask_type=AttnMaskType.causal,
                  add_pooler=False):
         super(TransformerLanguageModelBase, self).__init__()
         args = get_args()
@@ -288,6 +290,9 @@ class TransformerLanguageModelBase(MegatronModule):
         self.hidden_size = args.hidden_size
         self.num_tokentypes = num_tokentypes
         self.init_method = init_method
+        self.encoder_attn_mask_type = encoder_attn_mask_type
+        self.add_decoder = add_decoder
+        self.decoder_attn_mask_type = decoder_attn_mask_type
         self.add_pooler = add_pooler

         # Embeddings.
@@ -301,41 +306,83 @@ class TransformerLanguageModelBase(MegatronModule):
         self._embedding_key = 'embedding'

         # Transformer.
-        self.transformer = ParallelTransformer(
-            attention_mask_func, self.init_method,
-            output_layer_init_method)
-        self._transformer_key = 'transformer'
+        self.encoder = ParallelTransformer(
+            self.init_method,
+            output_layer_init_method,
+            self_attn_mask_type=self.encoder_attn_mask_type)
+        self._encoder_key = 'encoder'

-        # Pooler.
-        if mpu.is_pipeline_last_stage() and self.add_pooler:
-            self.pooler = Pooler(self.hidden_size, self.init_method)
-            self._pooler_key = 'pooler'
+        # Decoder
+        if self.add_decoder:
+            assert args.pipeline_model_parallel_size == 1, \
+                'pipeline parallelism is not supported in the presence of decoder'
+            self.decoder = ParallelTransformer(
+                self.init_method,
+                output_layer_init_method,
+                layer_type=LayerType.decoder,
+                self_attn_mask_type=self.decoder_attn_mask_type)
+            self._decoder_key = 'decoder'

-    def forward(self, language_model_input, attention_mask,
-                tokentype_ids=None, layer_past=None, get_key_value=False,
-                pooling_sequence_index=0):
+        if mpu.is_pipeline_last_stage():
+            # Pooler.
+            if self.add_pooler:
+                self.pooler = Pooler(self.hidden_size, self.init_method)
+                self._pooler_key = 'pooler'
+
+    def forward(self, enc_language_model_input, enc_attn_mask,
+                dec_language_model_input=None, dec_attn_mask=None,
+                enc_dec_attn_mask=None, tokentype_ids=None, layer_past=None,
+                get_key_value=False, pooling_sequence_index=0,
+                enc_hidden_states=None, output_enc_hidden=False):

         # Embeddings.
         if mpu.is_pipeline_first_stage():
-            (input_ids, position_ids) = language_model_input
+            (input_ids, position_ids) = enc_language_model_input
             embedding_output = self.embedding(input_ids, position_ids,
                                               tokentype_ids=tokentype_ids)
-            transformer_input = embedding_output
+            encoder_input = embedding_output
         else:
-            transformer_input = language_model_input
+            encoder_input = enc_language_model_input

-        # Transformer.
-        transformer_output = self.transformer(transformer_input,
-                                              attention_mask,
-                                              layer_past=layer_past,
-                                              get_key_value=get_key_value)
+        # encoder.
+        if enc_hidden_states is None:
+            encoder_output = self.encoder(encoder_input,
+                                          enc_attn_mask,
+                                          layer_past=layer_past,
+                                          get_key_value=get_key_value)
+        else:
+            encoder_output = enc_hidden_states.to(encoder_input.dtype)

-        if mpu.is_pipeline_last_stage() and self.add_pooler:
-            pooled_output = self.pooler(transformer_output,
-                                        pooling_sequence_index)
-            return transformer_output, pooled_output
+        if mpu.is_pipeline_last_stage():
+            if self.add_pooler:
+                pooled_output = self.pooler(encoder_output,
+                                            pooling_sequence_index)

-        return transformer_output
+        # output_enc_hidden refers to when we just need the encoder's
+        # output. For example, it is helpful to compute
+        # similarity between two sequences by average pooling
+        if not self.add_decoder or output_enc_hidden:
+            if self.add_pooler and mpu.is_pipeline_last_stage():
+                return encoder_output, pooled_output
+            else:
+                return encoder_output
+
+        # Decoder Embedding
+        (dec_input_ids, dec_position_ids) = dec_language_model_input
+        dec_embedding_output = self.embedding(dec_input_ids,
+                                              dec_position_ids)
+
+        # decoder
+        decoder_output = self.decoder(dec_embedding_output,
+                                      dec_attn_mask,
+                                      layer_past=layer_past,
+                                      get_key_value=get_key_value,
+                                      encoder_output=encoder_output,
+                                      enc_dec_attn_mask=enc_dec_attn_mask)
+
+        if self.add_pooler and mpu.is_pipeline_last_stage():
+            return decoder_output, encoder_output, pooled_output
+        else:
+            return decoder_output, encoder_output

     def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                        keep_vars=False):
@@ -346,12 +393,17 @@ class TransformerLanguageModelBase(MegatronModule):
         state_dict_[self._embedding_key] \
             = self.embedding.state_dict_for_save_checkpoint(
                 destination, prefix, keep_vars)
-        state_dict_[self._transformer_key] \
-            = self.transformer.state_dict_for_save_checkpoint(
-                destination, prefix, keep_vars)
-        if mpu.is_pipeline_last_stage() and self.add_pooler:
-            state_dict_[self._pooler_key] \
-                = self.pooler.state_dict_for_save_checkpoint(
-                    destination, prefix, keep_vars)
+        state_dict_[self._encoder_key] \
+            = self.encoder.state_dict_for_save_checkpoint(
+                destination, prefix, keep_vars)
+        if mpu.is_pipeline_last_stage():
+            if self.add_pooler:
+                state_dict_[self._pooler_key] \
+                    = self.pooler.state_dict_for_save_checkpoint(
+                        destination, prefix, keep_vars)
+        if self.add_decoder:
+            state_dict_[self._decoder_key] \
+                = self.decoder.state_dict_for_save_checkpoint(
+                    destination, prefix, keep_vars)

         return state_dict_
@@ -371,36 +423,44 @@ class TransformerLanguageModelBase(MegatronModule):
                 state_dict_[key] = state_dict[key]
         self.embedding.load_state_dict(state_dict_, strict=strict)

-        # Transformer.
-        if self._transformer_key in state_dict:
-            state_dict_ = state_dict[self._transformer_key]
-        # for compatiability with t5 architecture
-        # this is temporary unless t5_main is merged
-        elif 'encoder' in state_dict:
-            state_dict_ = state_dict['encoder']
-            # for forward compatibility for t5 architecture
-            state_dict_attention = {}
-            for key in state_dict_.keys():
-                if '.self_attention.' in key:
-                    state_dict_attention[key.replace(".self_attention.",
-                        ".attention.")] = state_dict_[key]
-                else:
-                    state_dict_attention[key] = state_dict_[key]
-            state_dict_ = state_dict_attention
+        # Encoder.
+        if self._encoder_key in state_dict:
+            state_dict_ = state_dict[self._encoder_key]
+        # for backward compatibility.
+        elif 'transformer' in state_dict:
+            state_dict_ = state_dict['transformer']
         else:
             # for backward compatibility.
             state_dict_ = {}
             for key in state_dict.keys():
                 if 'transformer.' in key:
                     state_dict_[key.split('transformer.')[1]] = state_dict[key]
-        self.transformer.load_state_dict(state_dict_, strict=strict)

-        # Pooler.
-        if mpu.is_pipeline_last_stage() and self.add_pooler:
-            assert 'pooler' in state_dict, \
-                'could not find data for pooler in the checkpoint'
-            self.pooler.load_state_dict(state_dict[self._pooler_key],
-                                        strict=strict)
+        # for backward compatibility.
+        state_dict_self_attention = {}
+        for key in state_dict_.keys():
+            if '.attention.' in key:
+                state_dict_self_attention[key.replace(".attention.",
+                    ".self_attention.")] = state_dict_[key]
+            else:
+                state_dict_self_attention[key] = state_dict_[key]
+        state_dict_ = state_dict_self_attention
+
+        self.encoder.load_state_dict(state_dict_, strict=strict)
+
+        if mpu.is_pipeline_last_stage():
+            # pooler
+            if self.add_pooler:
+                assert 'pooler' in state_dict, \
+                    'could not find data for pooler in the checkpoint'
+                self.pooler.load_state_dict(state_dict[self._pooler_key],
+                                            strict=strict)
+        # decoder
+        if self.add_decoder:
+            assert 'decoder' in state_dict, \
+                'could not find data for pooler in the checkpoint'
+            self.decoder.load_state_dict(state_dict[self._decoder_key],
+                                         strict=strict)


 class TransformerLanguageModel(TransformerLanguageModelBase):
@@ -409,28 +469,39 @@ class TransformerLanguageModel(TransformerLanguageModelBase):
     """

     def __init__(self,
-                 attention_mask_func,
                  init_method,
                  output_layer_init_method,
+                 encoder_attn_mask_type,
                  num_tokentypes=0,
+                 decoder_attn_mask_type=AttnMaskType.causal,
+                 add_decoder=False,
                  add_pooler=False):
         super(TransformerLanguageModel, self).__init__(
-            attention_mask_func,
             init_method,
             output_layer_init_method,
+            encoder_attn_mask_type,
             num_tokentypes=num_tokentypes,
+            add_decoder=add_decoder,
+            decoder_attn_mask_type=decoder_attn_mask_type,
             add_pooler=add_pooler)

-    def forward(self, input_ids, position_ids, attention_mask,
-                tokentype_ids=None, layer_past=None, get_key_value=False,
-                pooling_sequence_index=0):
+    def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
+                dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
+                enc_dec_attn_mask=None, tokentype_ids=None, layer_past=None,
+                get_key_value=False, pooling_sequence_index=0,
+                enc_hidden_states=None, output_enc_hidden=False):
         return super(TransformerLanguageModel, self).forward(
-            (input_ids, position_ids),
-            attention_mask,
+            (enc_input_ids, enc_position_ids),
+            enc_attn_mask,
+            dec_language_model_input=(dec_input_ids, dec_position_ids),
+            dec_attn_mask=dec_attn_mask,
+            enc_dec_attn_mask=enc_dec_attn_mask,
             tokentype_ids=tokentype_ids,
             layer_past=layer_past,
             get_key_value=get_key_value,
-            pooling_sequence_index=pooling_sequence_index
+            pooling_sequence_index=pooling_sequence_index,
+            enc_hidden_states=enc_hidden_states,
+            output_enc_hidden=output_enc_hidden
         )


@@ -440,14 +511,14 @@ class TransformerLanguageModelFirstStage(TransformerLanguageModelBase):
     """

     def __init__(self,
-                 attention_mask_func,
                  init_method,
                  output_layer_init_method,
+                 encoder_attn_mask_type,
                  num_tokentypes=0):
         super(TransformerLanguageModelFirstStage, self).__init__(
-            attention_mask_func,
             init_method,
             output_layer_init_method,
+            encoder_attn_mask_type,
             num_tokentypes=num_tokentypes)

     def forward(self, input_ids, position_ids, attention_mask,
@@ -467,13 +538,13 @@ class TransformerLanguageModelIntermediateStage(TransformerLanguageModelBase):
     """

     def __init__(self,
-                 attention_mask_func,
                  init_method,
-                 output_layer_init_method):
+                 output_layer_init_method,
+                 encoder_attn_mask_type):
         super(TransformerLanguageModelIntermediateStage, self).__init__(
-            attention_mask_func,
             init_method,
-            output_layer_init_method)
+            output_layer_init_method,
+            encoder_attn_mask_type)

     def forward(self, hidden_states, attention_mask,
                 layer_past=None, get_key_value=False):
@@ -491,14 +562,14 @@ class TransformerLanguageModelLastStage(TransformerLanguageModelBase):
     """

     def __init__(self,
-                 attention_mask_func,
                  init_method,
                  output_layer_init_method,
+                 encoder_attn_mask_type,
                  add_pooler=False):
         super(TransformerLanguageModelLastStage, self).__init__(
-            attention_mask_func,
             init_method,
             output_layer_init_method,
+            encoder_attn_mask_type,
             add_pooler=add_pooler)

     def forward(self, hidden_states, attention_mask,
@@ -509,5 +580,5 @@ class TransformerLanguageModelLastStage(TransformerLanguageModelBase):
             attention_mask,
             layer_past=layer_past,
             get_key_value=get_key_value,
-            pooling_sequence_index=pooling_sequence_index
+            pooling_sequence_index=pooling_sequence_index,
         )
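Note: callers of get_language_model now pick masking behaviour with an AttnMaskType enum instead of passing an attention_mask_func, and can request an additional decoder stack. A hedged sketch of the new call shape, mirroring the BERT-style callers elsewhere in this diff (argument values are illustrative and assume Megatron's global args are already initialized):

from megatron import get_args
from megatron.model.enums import AttnMaskType
from megatron.model.language_model import get_language_model
from megatron.model.utils import init_method_normal, scaled_init_method_normal


def build_padding_mask_lm(num_tokentypes=2, add_decoder=False):
    args = get_args()
    # BERT-style callers use AttnMaskType.padding and a pooler;
    # GPT-style callers pass AttnMaskType.causal and add_pooler=False.
    return get_language_model(
        num_tokentypes=num_tokentypes,
        add_pooler=True,
        encoder_attn_mask_type=AttnMaskType.padding,
        add_decoder=add_decoder,
        init_method=init_method_normal(args.init_method_std),
        scaled_init_method=scaled_init_method_normal(args.init_method_std,
                                                     args.num_layers))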
@@ -16,16 +16,31 @@
 """Megatron Module"""

 import torch
+from torch.autograd import Variable
+from torch.nn.parameter import Parameter

 from megatron import get_args
 from megatron import mpu


+_FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
+_HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)
+
+
+def param_is_not_shared(param):
+    return not hasattr(param, 'shared') or not param.shared
+
+
 class MegatronModule(torch.nn.Module):
-    """Megatron specific extensions of torch Module."""
+    """Megatron specific extensions of torch Module with support
+    for pipelining."""

-    def __init__(self):
+    def __init__(self, share_word_embeddings=True):
         super(MegatronModule, self).__init__()
+        self.share_word_embeddings = share_word_embeddings

     def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                        keep_vars=False):
@@ -34,52 +49,127 @@ class MegatronModule(torch.nn.Module):
         return self.state_dict(destination, prefix, keep_vars)

-class PipelinedMegatronModule(MegatronModule):
-    """Pipelining specific extensions of MegatronModule."""
-
-    def __init__(self, share_word_embeddings=True):
-        super(PipelinedMegatronModule, self).__init__()
-        args = get_args()
-        self.share_word_embeddings = share_word_embeddings
-
     def word_embeddings_weight(self):
         if mpu.is_pipeline_first_stage():
             return self.language_model.embedding.word_embeddings.weight
         if mpu.is_pipeline_last_stage():
             if not self.share_word_embeddings:
-                raise Exception('word_embeddings_weight() called for last stage, '
-                                'but share_word_embeddings is false')
+                raise Exception('word_embeddings_weight() called for last '
+                                'stage, but share_word_embeddings is false')
             return self.word_embeddings.weight
         raise Exception('word_embeddings_weight() should be '
                         'called for first and last stage only')

     def initialize_word_embeddings(self, init_method_normal):
         args = get_args()
         if not self.share_word_embeddings:
             raise Exception('initialize_word_embeddings() was called but '
                             'share_word_embeddings is false')

-        # Parameters are shared between the word embeddings layer, and the heads at
-        # the end of the model. In a pipelined setup with more than one stage, the
-        # initial embedding layer and the head are on different workers, so we do
-        # the following:
-        # 1. Create a second copy of word_embeddings on the last stage, with initial
-        #    parameters of 0.0.
-        # 2. Do an all-reduce between the first and last stage to ensure that the
-        #    two copies of word_embeddings start off with the same parameter values.
-        # 3. In the training loop, before an all-reduce between the grads of the two
-        #    word_embeddings layers to ensure that every applied weight update is the
-        #    same on both stages.
+        # This function just initializes the word embeddings in the final stage
+        # when we are using pipeline parallelism. If we aren't using pipeline
+        # parallelism there is nothing to do.
+        if args.pipeline_model_parallel_size == 1:
+            return
+
+        # Parameters are shared between the word embeddings layer, and the
+        # heads at the end of the model. In a pipelined setup with more than
+        # one stage, the initial embedding layer and the head are on different
+        # workers, so we do the following:
+        # 1. Create a second copy of word_embeddings on the last stage, with
+        #    initial parameters of 0.0.
+        # 2. Do an all-reduce between the first and last stage to ensure that
+        #    the two copies of word_embeddings start off with the same
+        #    parameter values.
+        # 3. In the training loop, before an all-reduce between the grads of
+        #    the two word_embeddings layers to ensure that every applied weight
+        #    update is the same on both stages.
         if mpu.is_pipeline_last_stage():
-            if not mpu.is_pipeline_first_stage():
-                self._word_embeddings_for_head_key = 'word_embeddings_for_head'
-                # If first and last stages are different, set word_embeddings
-                # weights to 0 here, then copy first stage's weights using all_reduce
-                # below.
-                self.word_embeddings = mpu.VocabParallelEmbedding(
-                    args.padded_vocab_size, args.hidden_size,
-                    init_method=init_method_normal(args.init_method_std))
-                self.word_embeddings.weight.data.fill_(0)
+            assert not mpu.is_pipeline_first_stage()
+            self._word_embeddings_for_head_key = 'word_embeddings_for_head'
+            # set word_embeddings weights to 0 here, then copy first
+            # stage's weights using all_reduce below.
+            self.word_embeddings = mpu.VocabParallelEmbedding(
+                args.padded_vocab_size, args.hidden_size,
+                init_method=init_method_normal(args.init_method_std))
+            self.word_embeddings.weight.data.fill_(0)
+            self.word_embeddings.weight.shared = True

-        # Ensure that first and last stages have the same initial parameter values.
-        if mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage():
-            torch.distributed.all_reduce(self.word_embeddings_weight().data,
-                                         group=mpu.get_embedding_group())
+        # Ensure that first and last stages have the same initial parameter
+        # values.
+        if torch.distributed.is_initialized():
+            if mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage():
+                torch.distributed.all_reduce(self.word_embeddings_weight().data,
+                                             group=mpu.get_embedding_group())
+        else:
+            print("WARNING! Distributed processes aren't initialized, so "
+                  "word embeddings in the last layer are not initialized. "
+                  "If you are just manipulating a model this is fine, but "
+                  "this needs to be handled manually. If you are training "
+                  "something is definitely wrong.")

+def conversion_helper(val, conversion):
+    """Apply conversion to val. Recursively apply conversion if `val`
+    #is a nested tuple/list structure."""
+    if not isinstance(val, (tuple, list)):
+        return conversion(val)
+    rtn = [conversion_helper(v, conversion) for v in val]
+    if isinstance(val, tuple):
+        rtn = tuple(rtn)
+    return rtn
+
+
+def fp32_to_fp16(val):
+    """Convert fp32 `val` to fp16"""
+    def half_conversion(val):
+        val_typecheck = val
+        if isinstance(val_typecheck, (Parameter, Variable)):
+            val_typecheck = val.data
+        if isinstance(val_typecheck, _FLOAT_TYPES):
+            val = val.half()
+        return val
+    return conversion_helper(val, half_conversion)
+
+
+def fp16_to_fp32(val):
+    """Convert fp16 `val` to fp32"""
+    def float_conversion(val):
+        val_typecheck = val
+        if isinstance(val_typecheck, (Parameter, Variable)):
+            val_typecheck = val.data
+        if isinstance(val_typecheck, _HALF_TYPES):
+            val = val.float()
+        return val
+    return conversion_helper(val, float_conversion)
+
+
+class FP16Module(MegatronModule):
+
+    def __init__(self, module):
+        super(FP16Module, self).__init__()
+        self.add_module('module', module.half())
+
+    def forward(self, *inputs, **kwargs):
+        if mpu.is_pipeline_first_stage():
+            inputs = fp32_to_fp16(inputs)
+        outputs = self.module(*inputs, **kwargs)
+        if mpu.is_pipeline_last_stage():
+            outputs = fp16_to_fp32(outputs)
+        return outputs
+
+    def state_dict(self, destination=None, prefix='', keep_vars=False):
+        return self.module.state_dict(destination, prefix, keep_vars)
+
+    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
+                                       keep_vars=False):
+        return self.module.state_dict_for_save_checkpoint(destination, prefix,
+                                                          keep_vars)
+
+    def load_state_dict(self, state_dict, strict=True):
+        self.module.load_state_dict(state_dict, strict=strict)
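Note on initialize_word_embeddings above: with pipeline parallelism the input embedding (first stage) and the output head (last stage) live on different ranks, so the last stage allocates a zero-initialized copy, an all-reduce over the embedding group copies the first stage's values into it, and during training the gradients of the two copies are all-reduced so every weight update matches. A toy sketch of that tying scheme in plain torch.distributed terms (process-group setup omitted; embedding_group is assumed to contain exactly the first- and last-stage ranks):

import torch.distributed as dist


def tie_embedding_copies(weight, is_first_stage, embedding_group):
    """Start both embedding copies from the first stage's initialization."""
    if not is_first_stage:
        # Zeros on the last stage, so the sum below equals the first stage's values.
        weight.data.fill_(0)
    dist.all_reduce(weight.data, group=embedding_group)


def sync_embedding_grads(weight, embedding_group):
    # Called each step before the optimizer so both copies apply the same update.
    if weight.grad is not None:
        dist.all_reduce(weight.grad, group=embedding_group)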
@@ -19,15 +19,16 @@ import torch

 from megatron import get_args, print_rank_last
 from megatron import mpu
-from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids
+from megatron.model.enums import AttnMaskType
+from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids
 from megatron.model.language_model import get_language_model
 from megatron.model.utils import get_linear_layer
 from megatron.model.utils import init_method_normal
 from megatron.model.utils import scaled_init_method_normal
-from megatron.module import PipelinedMegatronModule
+from .module import MegatronModule


-class MultipleChoiceBase(PipelinedMegatronModule):
+class MultipleChoiceBase(MegatronModule):

     def __init__(self, num_tokentypes=2):
         super(MultipleChoiceBase, self).__init__(share_word_embeddings=False)
@@ -36,9 +37,9 @@ class MultipleChoiceBase(PipelinedMegatronModule):
         init_method = init_method_normal(args.init_method_std)

         self.language_model, self._language_model_key = get_language_model(
-            attention_mask_func=bert_attention_mask_func,
             num_tokentypes=num_tokentypes,
             add_pooler=True,
+            encoder_attn_mask_type=AttnMaskType.padding,
             init_method=init_method,
             scaled_init_method=scaled_init_method_normal(args.init_method_std,
                                                          args.num_layers))
...
@@ -4,13 +4,14 @@ import torch

 from megatron import get_args, print_rank_0
 from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
 from megatron.model import BertModel
-from megatron.module import MegatronModule
+from .module import MegatronModule
 from megatron import mpu
+from megatron.model.enums import AttnMaskType
 from megatron.model.utils import get_linear_layer
 from megatron.model.utils import init_method_normal
 from megatron.model.language_model import get_language_model
 from megatron.model.utils import scaled_init_method_normal
-from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids
+from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids


 def general_ict_model_provider(only_query_model=False, only_block_model=False):
@@ -156,9 +157,9 @@ class IREncoderBertModel(MegatronModule):
                                                       args.num_layers)

         self.language_model, self._language_model_key = get_language_model(
-            attention_mask_func=bert_attention_mask_func,
             num_tokentypes=num_tokentypes,
             add_pooler=True,
+            encoder_attn_mask_type=AttnMaskType.padding,
             init_method=init_method,
             scaled_init_method=scaled_init_method)
...
@@ -14,19 +14,18 @@
 # limitations under the License.

 """Transformer."""
 import math
 import torch
 import torch.nn.functional as F

 from megatron import get_args
 from megatron import mpu
-from megatron.module import MegatronModule
-from megatron.checkpointing import get_checkpoint_version
+from .module import MegatronModule
+from megatron.model.enums import AttnMaskType, LayerType, AttnType
 from megatron.model import import_layernorm
 from megatron.model.fused_softmax import FusedScaleMaskSoftmax
 from megatron.model.fused_bias_gelu import bias_gelu_impl
-from megatron.model.utils import openai_gelu, erf_gelu
+from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu

 # flags required to enable jit fusion kernels
 torch._C._jit_set_profiling_mode(False)
@@ -47,12 +46,6 @@ torch._C._jit_override_can_fuse_on_gpu(True)
     Transformer takes input of size [s, b, h] and returns a
     tensor of the same size. We use the following arguments:
         hyperparameters: transformer hyperparameters
-        attention_mask_func: a function that takes `unmaksed-attention-scores`
-            with size [b, np, s, s] and an `attention-mask` and will apply
-            the masking. The function should return a masked score of the
-            same size [b, np, s, s].
-               masked-attention-scores = attention_mask_func(
-                                unmaksed-attention-scores, attention-mask)
 """

 class ParallelMLP(MegatronModule):
@@ -71,7 +64,7 @@ class ParallelMLP(MegatronModule):
         # Project to 4h.
         self.dense_h_to_4h = mpu.ColumnParallelLinear(
             args.hidden_size,
-            4 * args.hidden_size,
+            args.ffn_hidden_size,
             gather_output=False,
             init_method=init_method,
             skip_bias_add=True)
@@ -85,12 +78,12 @@ class ParallelMLP(MegatronModule):

         # Project back to h.
         self.dense_4h_to_h = mpu.RowParallelLinear(
-            4 * args.hidden_size,
+            args.ffn_hidden_size,
             args.hidden_size,
             input_is_parallel=True,
             init_method=output_layer_init_method,
             skip_bias_add=True)

     def forward(self, hidden_states):
@@ -109,41 +102,60 @@ class ParallelMLP(MegatronModule):
         return output, output_bias

-class ParallelSelfAttention(MegatronModule):
+class ParallelAttention(MegatronModule):
     """Parallel self-attention layer abstract class.

     Self-attention layer takes input with size [b, s, h]
     and returns output of the same size.
     """

-    def __init__(self, attention_mask_func, init_method,
-                 output_layer_init_method, layer_number):
-        super(ParallelSelfAttention, self).__init__()
+    def __init__(self, init_method,
+                 output_layer_init_method, layer_number,
+                 attention_type=AttnType.self_attn,
+                 attn_mask_type=AttnMaskType.padding):
+        super(ParallelAttention, self).__init__()
         args = get_args()
         self.fp16 = args.fp16

-        self.attention_mask_func = attention_mask_func
         self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
         self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
         if self.apply_query_key_layer_scaling:
             self.attention_softmax_in_fp32 = True
         self.layer_number = max(1, layer_number)
+        self.attention_type = attention_type
+        self.attn_mask_type = attn_mask_type
+
+        projection_size = args.kv_channels * args.num_attention_heads

         # Per attention head and per partition values.
         world_size = mpu.get_tensor_model_parallel_world_size()
-        self.hidden_size_per_partition = mpu.divide(args.hidden_size,
+        self.hidden_size_per_partition = mpu.divide(projection_size,
                                                     world_size)
         self.hidden_size_per_attention_head = mpu.divide(
-            args.hidden_size, args.num_attention_heads)
+            projection_size, args.num_attention_heads)
         self.num_attention_heads_per_partition = mpu.divide(
             args.num_attention_heads, world_size)

         # Strided linear layer.
-        self.query_key_value = mpu.ColumnParallelLinear(
-            args.hidden_size,
-            3 * args.hidden_size,
-            gather_output=False,
-            init_method=init_method)
+        if attention_type == AttnType.self_attn:
+            self.query_key_value = mpu.ColumnParallelLinear(
+                args.hidden_size,
+                3 * projection_size,
+                gather_output=False,
+                init_method=init_method)
+        else:
+            assert attention_type == AttnType.cross_attn
+            self.query = mpu.ColumnParallelLinear(
+                args.hidden_size,
+                projection_size,
+                gather_output=False,
+                init_method=init_method)
+
+            self.key_value = mpu.ColumnParallelLinear(
+                args.hidden_size,
+                2 * projection_size,
+                gather_output=False,
+                init_method=init_method)

         coeff = None
         self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
@@ -153,9 +165,9 @@ class ParallelSelfAttention(MegatronModule):
         self.scale_mask_softmax = FusedScaleMaskSoftmax(
             self.fp16,
-            args.scaled_upper_triang_masked_softmax_fusion,
-            args.scaled_masked_softmax_fusion,
-            self.attention_mask_func,
+            self.attn_mask_type,
+            args.masked_softmax_fusion,
+            attention_mask_func,
             self.attention_softmax_in_fp32,
             coeff)
...@@ -166,75 +178,55 @@ class ParallelSelfAttention(MegatronModule): ...@@ -166,75 +178,55 @@ class ParallelSelfAttention(MegatronModule):
# Output. # Output.
self.dense = mpu.RowParallelLinear( self.dense = mpu.RowParallelLinear(
args.hidden_size, projection_size,
args.hidden_size, args.hidden_size,
input_is_parallel=True, input_is_parallel=True,
init_method=output_layer_init_method, init_method=output_layer_init_method,
skip_bias_add=True) skip_bias_add=True)
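To make the new attention_type switch concrete, here is a shape-only sketch of the two projection paths, with plain nn.Linear standing in for the parallel layers and hypothetical sizes: self-attention uses one fused QKV projection of width 3 * projection_size, while cross-attention projects queries from the decoder states and keys/values from the encoder output.

import torch

hidden_size, num_heads, kv_channels = 1024, 16, 64
projection_size = kv_channels * num_heads  # 1024 in this example

# self-attention: one fused projection
query_key_value = torch.nn.Linear(hidden_size, 3 * projection_size)

# cross-attention: queries from decoder states, keys/values from encoder output
query = torch.nn.Linear(hidden_size, projection_size)
key_value = torch.nn.Linear(hidden_size, 2 * projection_size)

sq, sk, b = 10, 12, 4
decoder_states = torch.randn(sq, b, hidden_size)
encoder_output = torch.randn(sk, b, hidden_size)

qkv = query_key_value(decoder_states)  # [sq, b, 3 * projection_size]
q = query(decoder_states)              # [sq, b, projection_size]
kv = key_value(encoder_output)         # [sk, b, 2 * projection_size]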
def _transpose_last_dim(self, mixed_layer, num_splits, num_splits_first):
input_shape = mixed_layer.size();
if num_splits_first:
"""[s, b, num_splits * np * hn]
-->(view) [s, b, num_splits, np, hn]
-->(transpose) [s, b, np, num_splits, hn]
-->(view) [s, b, np * num_splits * hn] """
intermediate_shape = input_shape[:-1] +\
(num_splits, self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head)
mixed_layer = mixed_layer.view(*intermediate_shape)
mixed_layer = mixed_layer.transpose(-2, -3).contiguous()
else:
"""[s, b, np * hn * num_splits]
-->(view) [s, b, np, hn, num_splits]
-->(transpose) [s, b, np, num_splits, hn]
-->(view) [s, b, np * num_splits * hn] """
intermediate_shape = input_shape[:-1] +\
(self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head, num_splits)
mixed_layer = mixed_layer.view(*intermediate_shape)
mixed_layer = mixed_layer.transpose(-1, -2).contiguous()
mixed_layer = mixed_layer.view(*input_shape)
return mixed_layer
def forward(self, hidden_states, attention_mask, layer_past=None, def forward(self, hidden_states, attention_mask, layer_past=None,
get_key_value=False): get_key_value=False, encoder_output=None):
# hidden_states: [sq, b, h] # hidden_states: [sq, b, h]
# ===================== # =====================
# Query, Key, and Value # Query, Key, and Value
# ===================== # =====================
# Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] if self.attention_type == AttnType.self_attn:
mixed_x_layer, _ = self.query_key_value(hidden_states) # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
mixed_x_layer, _ = self.query_key_value(hidden_states)
checkpoint_version = get_checkpoint_version() # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
if get_args().override_checkpoint_version is not None: new_tensor_shape = mixed_x_layer.size()[:-1] + \
checkpoint_version = get_args().override_checkpoint_version (self.num_attention_heads_per_partition,
3 * self.hidden_size_per_attention_head)
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
if checkpoint_version is not None: # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
if checkpoint_version == 0: (query_layer,
# [s, b, (3 * np * hn)] --> [s, b, (np * 3 * hn)] key_layer,
mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, True) value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)
elif checkpoint_version == 1.0: else:
# [s, b, (np * hn * 3)] --> [s, b, (np * 3 * hn)] # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, False) mixed_kv_layer, _ = self.key_value(encoder_output)
# [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn] # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
new_tensor_shape = mixed_x_layer.size()[:-1] + \ new_tensor_shape = mixed_kv_layer.size()[:-1] + \
(self.num_attention_heads_per_partition, (self.num_attention_heads_per_partition,
3 * self.hidden_size_per_attention_head) 2 * self.hidden_size_per_attention_head)
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)
# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
(query_layer, (key_layer,
key_layer, value_layer) = mpu.split_tensor_along_last_dim(mixed_kv_layer, 2)
value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)
# Attention head [sq, b, h] --> [sq, b, hp]
query_layer, _ = self.query(hidden_states)
# [sq, b, hp] --> [sq, b, np, hn]
new_tensor_shape = query_layer.size()[:-1] + \
(self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head)
query_layer = query_layer.view(*new_tensor_shape)
# ================================== # ==================================
# Adjust key and value for inference # Adjust key and value for inference
...@@ -249,41 +241,41 @@ class ParallelSelfAttention(MegatronModule): ...@@ -249,41 +241,41 @@ class ParallelSelfAttention(MegatronModule):
if get_key_value: if get_key_value:
present = (key_layer, value_layer) present = (key_layer, value_layer)
# =================================== # ===================================
# Raw attention scores. [b, np, s, s] # Raw attention scores. [b, np, s, s]
# =================================== # ===================================
# [b, np, sq, sk] # [b, np, sq, sk]
output_size = (query_layer.size(1), output_size = (query_layer.size(1),
query_layer.size(2), query_layer.size(2),
query_layer.size(0), query_layer.size(0),
key_layer.size(0)) key_layer.size(0))
# [sq, b, np, hn] -> [sq, b * np, hn] # [sq, b, np, hn] -> [sq, b * np, hn]
query_layer = query_layer.view(output_size[2], query_layer = query_layer.view(output_size[2],
output_size[0] * output_size[1], -1) output_size[0] * output_size[1], -1)
# [sk, b, np, hn] -> [sk, b * np, hn]
key_layer = key_layer.view(output_size[3], key_layer = key_layer.view(output_size[3],
output_size[0] * output_size[1], -1) output_size[0] * output_size[1], -1)
# preallocating result tensor: [b * np, sq, sk] # preallocating result tensor: [b * np, sq, sk]
matmul_result = torch.empty( matmul_result = torch.empty(
output_size[0]*output_size[1], output_size[0]*output_size[1],
output_size[2], output_size[2],
output_size[3], output_size[3],
dtype=query_layer.dtype, dtype=query_layer.dtype,
device=torch.cuda.current_device()) device=torch.cuda.current_device())
# Raw attention scores. [b * np, sq, sk] # Raw attention scores. [b * np, sq, sk]
matmul_result = torch.baddbmm(matmul_result, matmul_result = torch.baddbmm(
matmul_result,
query_layer.transpose(0, 1), # [b * np, sq, hn] query_layer.transpose(0, 1), # [b * np, sq, hn]
key_layer.transpose(0,1).transpose(1, 2), #[b * np, hn, sk] key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
beta=0.0, alpha=(1.0/self.norm_factor)) beta=0.0, alpha=(1.0/self.norm_factor))
# change view to [b, np, sq, sk] # change view to [b, np, sq, sk]
attention_scores = matmul_result.view(*output_size) attention_scores = matmul_result.view(*output_size)
# ================================================== # ==================================================
# Update attention mask for inference. [b, np, sq, sk] # Update attention mask for inference. [b, np, sq, sk]
# ================================================== # ==================================================
...@@ -301,7 +293,6 @@ class ParallelSelfAttention(MegatronModule): ...@@ -301,7 +293,6 @@ class ParallelSelfAttention(MegatronModule):
:attention_scores.size(3), :attention_scores.size(3),
:attention_scores.size(3)] :attention_scores.size(3)]
# =========================== # ===========================
# Attention probs and dropout # Attention probs and dropout
# =========================== # ===========================
...@@ -315,7 +306,6 @@ class ParallelSelfAttention(MegatronModule): ...@@ -315,7 +306,6 @@ class ParallelSelfAttention(MegatronModule):
with mpu.get_cuda_rng_tracker().fork(): with mpu.get_cuda_rng_tracker().fork():
attention_probs = self.attention_dropout(attention_probs) attention_probs = self.attention_dropout(attention_probs)
# ========================= # =========================
# Context layer. [sq, b, hp] # Context layer. [sq, b, hp]
# ========================= # =========================
...@@ -324,21 +314,21 @@ class ParallelSelfAttention(MegatronModule): ...@@ -324,21 +314,21 @@ class ParallelSelfAttention(MegatronModule):
# [sk, b, np, hn] --> [b, np, sq, hn] # [sk, b, np, hn] --> [b, np, sq, hn]
# context layer shape: [b, np, sq, hn] # context layer shape: [b, np, sq, hn]
output_size = (value_layer.size(1), output_size = (value_layer.size(1),
value_layer.size(2), value_layer.size(2),
query_layer.size(0), query_layer.size(0),
value_layer.size(3)) value_layer.size(3))
# change view [sk, b * np, hn] # change view [sk, b * np, hn]
value_layer = value_layer.view(value_layer.size(0), value_layer = value_layer.view(value_layer.size(0),
output_size[0] * output_size[1], -1) output_size[0] * output_size[1], -1)
# change view [b * np, sq, sk] # change view [b * np, sq, sk]
attention_probs = attention_probs.view(output_size[0] * output_size[1], attention_probs = attention_probs.view(output_size[0] * output_size[1],
output_size[2], -1) output_size[2], -1)
# matmul: [b * np, sq, hn] # matmul: [b * np, sq, hn]
context_layer = torch.bmm(attention_probs, value_layer.transpose(0,1)) context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
# change view [b, np, sq, hn] # change view [b, np, sq, hn]
context_layer = context_layer.view(*output_size) context_layer = context_layer.view(*output_size)
...@@ -351,7 +341,6 @@ class ParallelSelfAttention(MegatronModule): ...@@ -351,7 +341,6 @@ class ParallelSelfAttention(MegatronModule):
(self.hidden_size_per_partition,) (self.hidden_size_per_partition,)
context_layer = context_layer.view(*new_context_layer_shape) context_layer = context_layer.view(*new_context_layer_shape)
# ================= # =================
# Output. [sq, b, h] # Output. [sq, b, h]
# ================= # =================
...@@ -364,7 +353,7 @@ class ParallelSelfAttention(MegatronModule): ...@@ -364,7 +353,7 @@ class ParallelSelfAttention(MegatronModule):
return output, bias return output, bias
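A minimal sketch of the score/context arithmetic in the forward pass above, with plain tensors, hypothetical shapes, and no tensor parallelism or fused softmax: the heads are folded into the batch dimension, baddbmm computes the scaled raw scores, and bmm applies the probabilities to the values.

import math
import torch

sq, sk, b, np_, hn = 10, 12, 4, 16, 64
query_layer = torch.randn(sq, b, np_, hn)
key_layer = torch.randn(sk, b, np_, hn)
value_layer = torch.randn(sk, b, np_, hn)

q = query_layer.view(sq, b * np_, hn).transpose(0, 1)      # [b*np, sq, hn]
k = key_layer.view(sk, b * np_, hn).transpose(0, 1)        # [b*np, sk, hn]
scores = torch.baddbmm(
    torch.empty(b * np_, sq, sk), q, k.transpose(1, 2),
    beta=0.0, alpha=1.0 / math.sqrt(hn))                   # [b*np, sq, sk]

probs = torch.softmax(scores, dim=-1)                      # mask/dropout omitted here
v = value_layer.view(sk, b * np_, hn).transpose(0, 1)      # [b*np, sk, hn]
context = torch.bmm(probs, v)                              # [b*np, sq, hn]
context = context.view(b, np_, sq, hn).permute(2, 0, 1, 3) # [sq, b, np, hn]
context = context.reshape(sq, b, np_ * hn)                 # [sq, b, hp]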
def bias_dropout_add(x, bias, residual, prob, training) : def bias_dropout_add(x, bias, residual, prob, training):
# type: (Tensor, Tensor, Tensor, float, bool) -> Tensor # type: (Tensor, Tensor, Tensor, float, bool) -> Tensor
out = torch.nn.functional.dropout(x + bias, p=prob, training=training) out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
out = residual + out out = residual + out
...@@ -378,13 +367,13 @@ def get_bias_dropout_add(training): ...@@ -378,13 +367,13 @@ def get_bias_dropout_add(training):
@torch.jit.script @torch.jit.script
def bias_dropout_add_fused_train(x, bias, residual, prob) : def bias_dropout_add_fused_train(x, bias, residual, prob):
# type: (Tensor, Tensor, Tensor, float) -> Tensor # type: (Tensor, Tensor, Tensor, float) -> Tensor
return bias_dropout_add(x, bias, residual, prob, True) return bias_dropout_add(x, bias, residual, prob, True)
@torch.jit.script @torch.jit.script
def bias_dropout_add_fused_inference(x, bias, residual, prob) : def bias_dropout_add_fused_inference(x, bias, residual, prob):
# type: (Tensor, Tensor, Tensor, float) -> Tensor # type: (Tensor, Tensor, Tensor, float) -> Tensor
return bias_dropout_add(x, bias, residual, prob, False) return bias_dropout_add(x, bias, residual, prob, False)
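A tiny usage sketch for the fused helpers above, assuming their definitions are in scope and using hypothetical tensor sizes; in the inference variant dropout is a no-op, so the result reduces to residual + x + bias.

import torch

x = torch.randn(10, 4, 1024)
bias = torch.randn(1024).expand_as(x)
residual = torch.randn(10, 4, 1024)

out_train = bias_dropout_add_fused_train(x, bias, residual, 0.1)
out_eval = bias_dropout_add_fused_inference(x, bias, residual, 0.1)
# in eval mode dropout does nothing, so the output is exactly residual + x + bias
assert torch.allclose(out_eval, residual + x + bias)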
...@@ -392,16 +381,18 @@ def bias_dropout_add_fused_inference(x, bias, residual, prob) : ...@@ -392,16 +381,18 @@ def bias_dropout_add_fused_inference(x, bias, residual, prob) :
class ParallelTransformerLayer(MegatronModule): class ParallelTransformerLayer(MegatronModule):
"""A single transformer layer. """A single transformer layer.
Transformore layer takes input with size [b, s, h] and returns an Transformer layer takes input with size [b, s, h] and returns an
output of the same size. output of the same size.
""" """
def __init__(self, attention_mask_func, init_method, def __init__(self, init_method, output_layer_init_method,
output_layer_init_method, layer_number): layer_number, layer_type=LayerType.encoder,
self_attn_mask_type=AttnMaskType.padding):
args = get_args() args = get_args()
super(ParallelTransformerLayer, self).__init__() super(ParallelTransformerLayer, self).__init__()
self.layer_number = layer_number self.layer_number = layer_number
self.layer_type = layer_type
self.apply_residual_connection_post_layernorm \ self.apply_residual_connection_post_layernorm \
= args.apply_residual_connection_post_layernorm = args.apply_residual_connection_post_layernorm
...@@ -413,45 +404,60 @@ class ParallelTransformerLayer(MegatronModule): ...@@ -413,45 +404,60 @@ class ParallelTransformerLayer(MegatronModule):
eps=args.layernorm_epsilon) eps=args.layernorm_epsilon)
# Self attention. # Self attention.
self.attention = ParallelSelfAttention(attention_mask_func, init_method, self.self_attention = ParallelAttention(
output_layer_init_method, init_method,
layer_number) output_layer_init_method,
layer_number,
attention_type=AttnType.self_attn,
attn_mask_type=self_attn_mask_type)
self.hidden_dropout = args.hidden_dropout self.hidden_dropout = args.hidden_dropout
self.bias_dropout_fusion = args.bias_dropout_fusion self.bias_dropout_fusion = args.bias_dropout_fusion
# Layernorm on the input data. # Layernorm on the attention output
self.post_attention_layernorm = LayerNorm( self.post_attention_layernorm = LayerNorm(
args.hidden_size, args.hidden_size,
eps=args.layernorm_epsilon) eps=args.layernorm_epsilon)
if self.layer_type == LayerType.decoder:
self.inter_attention = ParallelAttention(
init_method,
output_layer_init_method,
layer_number,
attention_type=AttnType.cross_attn)
# Layernorm on the attention output.
self.post_inter_attention_layernorm = LayerNorm(
args.hidden_size,
eps=args.layernorm_epsilon)
# MLP # MLP
self.mlp = ParallelMLP(init_method, self.mlp = ParallelMLP(init_method,
output_layer_init_method) output_layer_init_method)
def forward(self, hidden_states, attention_mask, layer_past=None, def forward(self, hidden_states, attention_mask,
get_key_value=False): encoder_output=None, enc_dec_attn_mask=None,
layer_past=None, get_key_value=False):
# hidden_states: [b, s, h] # hidden_states: [b, s, h]
# Layer norm at the begining of the transformer layer. # Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states) layernorm_output = self.input_layernorm(hidden_states)
# Self attention. # Self attention.
attention_output, attention_bias = \ attention_output, attention_bias = \
self.attention(layernorm_output, self.self_attention(layernorm_output,
attention_mask, attention_mask,
layer_past=layer_past, layer_past=layer_past,
get_key_value=get_key_value) get_key_value=get_key_value)
if get_key_value: if get_key_value:
attention_output, presents = attention_output attention_output, presents = attention_output
# Residual connection. # Residual connection.
if self.apply_residual_connection_post_layernorm: if self.apply_residual_connection_post_layernorm:
residual = layernorm_output residual = layernorm_output
else: else:
residual = hidden_states residual = hidden_states
# jit scripting for a nn.module (with dropout) is not # jit scripting for a nn.module (with dropout) is not
# triggering the fusion kernel. For now, we use two # triggering the fusion kernel. For now, we use two
# different nn.functional routines to account for varying # different nn.functional routines to account for varying
# dropout semantics during training and inference phases. # dropout semantics during training and inference phases.
if self.bias_dropout_fusion: if self.bias_dropout_fusion:
...@@ -462,7 +468,7 @@ class ParallelTransformerLayer(MegatronModule): ...@@ -462,7 +468,7 @@ class ParallelTransformerLayer(MegatronModule):
else: else:
bias_dropout_add_func = get_bias_dropout_add(self.training) bias_dropout_add_func = get_bias_dropout_add(self.training)
#re-enable torch grad to enable fused optimization. # re-enable torch grad to enable fused optimization.
with torch.enable_grad(): with torch.enable_grad():
layernorm_input = bias_dropout_add_func( layernorm_input = bias_dropout_add_func(
attention_output, attention_output,
...@@ -473,6 +479,28 @@ class ParallelTransformerLayer(MegatronModule): ...@@ -473,6 +479,28 @@ class ParallelTransformerLayer(MegatronModule):
# Layer norm post the self attention. # Layer norm post the self attention.
layernorm_output = self.post_attention_layernorm(layernorm_input) layernorm_output = self.post_attention_layernorm(layernorm_input)
if self.layer_type == LayerType.decoder:
attention_output, attention_bias = \
self.inter_attention(layernorm_output,
enc_dec_attn_mask,
encoder_output=encoder_output)
# residual connection
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = layernorm_input
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
layernorm_input = bias_dropout_add_func(
attention_output,
attention_bias.expand_as(residual),
residual,
self.hidden_dropout)
# Layer norm post the decoder attention
layernorm_output = self.post_inter_attention_layernorm(layernorm_input)
# MLP. # MLP.
mlp_output, mlp_bias = self.mlp(layernorm_output) mlp_output, mlp_bias = self.mlp(layernorm_output)
...@@ -482,7 +510,7 @@ class ParallelTransformerLayer(MegatronModule): ...@@ -482,7 +510,7 @@ class ParallelTransformerLayer(MegatronModule):
else: else:
residual = layernorm_input residual = layernorm_input
#re-enable torch grad to enable fused optimization. # re-enable torch grad to enable fused optimization.
with torch.enable_grad(): with torch.enable_grad():
output = bias_dropout_add_func( output = bias_dropout_add_func(
mlp_output, mlp_output,
...@@ -499,8 +527,9 @@ class ParallelTransformerLayer(MegatronModule): ...@@ -499,8 +527,9 @@ class ParallelTransformerLayer(MegatronModule):
class ParallelTransformer(MegatronModule): class ParallelTransformer(MegatronModule):
"""Transformer class.""" """Transformer class."""
def __init__(self, attention_mask_func, def __init__(self, init_method, output_layer_init_method,
init_method, output_layer_init_method): layer_type=LayerType.encoder,
self_attn_mask_type=AttnMaskType.padding):
super(ParallelTransformer, self).__init__() super(ParallelTransformer, self).__init__()
args = get_args() args = get_args()
...@@ -518,8 +547,11 @@ class ParallelTransformer(MegatronModule): ...@@ -518,8 +547,11 @@ class ParallelTransformer(MegatronModule):
# Transformer layers. # Transformer layers.
def build_layer(layer_number): def build_layer(layer_number):
return ParallelTransformerLayer( return ParallelTransformerLayer(
attention_mask_func, init_method, init_method,
output_layer_init_method, layer_number) output_layer_init_method,
layer_number,
layer_type=layer_type,
self_attn_mask_type=self_attn_mask_type)
offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers
self.layers = torch.nn.ModuleList( self.layers = torch.nn.ModuleList(
[build_layer(i + 1 + offset) for i in range(self.num_layers)]) [build_layer(i + 1 + offset) for i in range(self.num_layers)])
...@@ -534,14 +566,18 @@ class ParallelTransformer(MegatronModule): ...@@ -534,14 +566,18 @@ class ParallelTransformer(MegatronModule):
def _get_layer(self, layer_number): def _get_layer(self, layer_number):
return self.layers[layer_number] return self.layers[layer_number]
def _checkpointed_forward(self, hidden_states, attention_mask): def _checkpointed_forward(self, hidden_states, attention_mask,
encoder_output, enc_dec_attn_mask):
"""Forward method with activation checkpointing.""" """Forward method with activation checkpointing."""
def custom(start, end): def custom(start, end):
def custom_forward(*inputs): def custom_forward(*inputs):
x_ = inputs[0] x_ = inputs[0]
attention_mask = inputs[1]
encoder_output = inputs[2]
enc_dec_attn_mask = inputs[3]
for index in range(start, end): for index in range(start, end):
layer = self._get_layer(index) layer = self._get_layer(index)
x_ = layer(x_, inputs[1]) x_ = layer(x_, attention_mask, encoder_output, enc_dec_attn_mask)
return x_ return x_
return custom_forward return custom_forward
...@@ -551,13 +587,13 @@ class ParallelTransformer(MegatronModule): ...@@ -551,13 +587,13 @@ class ParallelTransformer(MegatronModule):
while l < self.num_layers: while l < self.num_layers:
hidden_states = mpu.checkpoint( hidden_states = mpu.checkpoint(
custom(l, l + self.checkpoint_num_layers), custom(l, l + self.checkpoint_num_layers),
hidden_states, attention_mask) hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
l += self.checkpoint_num_layers l += self.checkpoint_num_layers
return hidden_states return hidden_states
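The checkpointed forward above re-runs chunks of layers during backward instead of storing their activations. A minimal sketch of the same chunking pattern with torch.utils.checkpoint (Megatron's mpu.checkpoint is the RNG-state-aware variant); the layer stack and sizes below are hypothetical.

import torch
from torch.utils.checkpoint import checkpoint

layers = torch.nn.ModuleList([torch.nn.Linear(64, 64) for _ in range(8)])
chunk = 2  # plays the role of checkpoint_num_layers

def custom(start, end):
    def custom_forward(x):
        for index in range(start, end):
            x = layers[index](x)
        return x
    return custom_forward

x = torch.randn(4, 64, requires_grad=True)
l = 0
while l < len(layers):
    x = checkpoint(custom(l, l + chunk), x)  # activations inside the chunk are recomputed in backward
    l += chunk
x.sum().backward()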
def forward(self, hidden_states, attention_mask, layer_past=None, def forward(self, hidden_states, attention_mask, layer_past=None,
get_key_value=False): get_key_value=False, encoder_output=None, enc_dec_attn_mask=None):
# Checks. # Checks.
if layer_past is not None: if layer_past is not None:
...@@ -578,9 +614,14 @@ class ParallelTransformer(MegatronModule): ...@@ -578,9 +614,14 @@ class ParallelTransformer(MegatronModule):
else: else:
hidden_states = hidden_states.transpose(0, 1).contiguous() hidden_states = hidden_states.transpose(0, 1).contiguous()
if encoder_output is not None:
encoder_output = encoder_output.transpose(0, 1).contiguous()
if self.checkpoint_activations: if self.checkpoint_activations:
hidden_states = self._checkpointed_forward(hidden_states, hidden_states = self._checkpointed_forward(hidden_states,
attention_mask) attention_mask,
encoder_output,
enc_dec_attn_mask)
else: else:
if get_key_value: if get_key_value:
presents = [] presents = []
...@@ -591,12 +632,14 @@ class ParallelTransformer(MegatronModule): ...@@ -591,12 +632,14 @@ class ParallelTransformer(MegatronModule):
past = layer_past[index] past = layer_past[index]
hidden_states = layer(hidden_states, hidden_states = layer(hidden_states,
attention_mask, attention_mask,
encoder_output=encoder_output,
enc_dec_attn_mask=enc_dec_attn_mask,
layer_past=past, layer_past=past,
get_key_value=get_key_value) get_key_value=get_key_value)
if get_key_value: if get_key_value:
hidden_states, present = hidden_states hidden_states, present = hidden_states
presents.append(present) presents.append(present)
# Final layer norm. # Final layer norm.
if mpu.is_pipeline_last_stage(): if mpu.is_pipeline_last_stage():
# Reverting data format change [s b h] --> [b s h]. # Reverting data format change [s b h] --> [b s h].
......
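Shape-convention sketch for the updated forward signature: the transformer works internally in [s, b, h], so hidden states arriving as [b, s, h] (and, for a decoder, the encoder output) are transposed on the way in and transposed back at the last pipeline stage; sizes below are hypothetical.

import torch

b, s, h = 4, 10, 1024
hidden_states = torch.randn(b, s, h)
encoder_output = torch.randn(b, 12, h)

hidden_states = hidden_states.transpose(0, 1).contiguous()    # [s, b, h]
encoder_output = encoder_output.transpose(0, 1).contiguous()  # [s_k, b, h]
# ... transformer layers consume hidden_states (and encoder_output for decoders) here ...
hidden_states = hidden_states.transpose(0, 1).contiguous()    # back to [b, s, h]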
...@@ -20,7 +20,6 @@ import math ...@@ -20,7 +20,6 @@ import math
import torch import torch
from megatron import get_args from megatron import get_args
from megatron.model import import_layernorm
def init_method_normal(sigma): def init_method_normal(sigma):
"""Init method based on N(0, sigma).""" """Init method based on N(0, sigma)."""
...@@ -40,6 +39,11 @@ def scaled_init_method_normal(sigma, num_layers): ...@@ -40,6 +39,11 @@ def scaled_init_method_normal(sigma, num_layers):
return init_ return init_
def attention_mask_func(attention_scores, attention_mask):
attention_scores.masked_fill_(attention_mask, -10000.0)
return attention_scores
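Tiny usage sketch for the new helper: positions where the boolean mask is True are pushed to -10000 before the softmax, so they receive (almost) no attention weight.

import torch

scores = torch.zeros(1, 1, 2, 2)                    # [b, np, sq, sk]
mask = torch.tensor([[[[False, True],
                       [False, False]]]])           # True = masked out
masked = attention_mask_func(scores, mask)
probs = torch.softmax(masked, dim=-1)
# probs[0, 0, 0] is ~[1.0, 0.0]: the masked key gets essentially no weight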
def get_linear_layer(rows, columns, init_method): def get_linear_layer(rows, columns, init_method):
"""Simple linear layer with weight initialization.""" """Simple linear layer with weight initialization."""
layer = torch.nn.Linear(rows, columns) layer = torch.nn.Linear(rows, columns)
...@@ -60,28 +64,3 @@ def openai_gelu(x): ...@@ -60,28 +64,3 @@ def openai_gelu(x):
@torch.jit.script @torch.jit.script
def erf_gelu(x): def erf_gelu(x):
return x * 0.5 * (torch.erf(x / 1.41421).to(dtype=x.dtype)+torch.ones_like(x).to(dtype=x.dtype)) return x * 0.5 * (torch.erf(x / 1.41421).to(dtype=x.dtype)+torch.ones_like(x).to(dtype=x.dtype))
def get_params_for_weight_decay_optimization(module):
"""Divide params into with-weight-decay and without-weight-decay groups.
Layernorms and biases will have no weight decay but the rest will.
"""
args = get_args()
LayerNorm = import_layernorm(args.fp32_residual_connection)
weight_decay_params = {'params': []}
no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
for module_ in module.modules():
if isinstance(module_, LayerNorm):
no_weight_decay_params['params'].extend(
[p for p in list(module_._parameters.values())
if p is not None])
else:
weight_decay_params['params'].extend(
[p for n, p in list(module_._parameters.items())
if p is not None and n != 'bias'])
no_weight_decay_params['params'].extend(
[p for n, p in list(module_._parameters.items())
if p is not None and n == 'bias'])
return weight_decay_params, no_weight_decay_params
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Vision Transformer(VIT) model."""
import math
import einops
import torch
import torch.nn.functional as F
from megatron import get_args
from megatron.model.transformer import ParallelTransformer
from megatron.model.utils import (
get_linear_layer,
init_method_normal,
scaled_init_method_normal,
)
from .module import MegatronModule
class VitMlpHead(MegatronModule):
"""Pooler layer.
Pool hidden states of a specific token (for example start of the
sequence) and add a linear transformation followed by a tanh.
Arguments:
hidden_size: hidden size
num_classes: number of output classes.
"""
def __init__(self, hidden_size, num_classes):
super(VitMlpHead, self).__init__()
self.dense_in = torch.nn.Linear(hidden_size, hidden_size)
self.dense_out = torch.nn.Linear(hidden_size, num_classes)
torch.nn.init.constant_(self.dense_out.bias, -10)
def forward(self, hidden_states, sequence_index=0):
# hidden_states: [b, s, h]
# sequence_index: index of the token to pool.
x = hidden_states[:, sequence_index, :]
x = self.dense_in(x)
x = torch.tanh(x)
x = self.dense_out(x)
return x
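Usage sketch with hypothetical sizes (196 patch tokens plus one class token), assuming the class definition above is importable: pool the token at sequence_index and map it to class logits.

import torch

head = VitMlpHead(hidden_size=768, num_classes=1000)
hidden_states = torch.randn(4, 197, 768)     # [b, s, h], s = 196 patches + 1 cls token
logits = head(hidden_states, sequence_index=0)
assert logits.shape == (4, 1000)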
def twod_interpolate_position_embeddings_hook(
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
args = get_args()
num_patches_per_dim = args.img_dim // args.patch_dim
num_patches = num_patches_per_dim ** 2
seq_length = num_patches + 1
hidden_size = args.hidden_size
key = prefix + "weight"
# import pdb
# pdb.set_trace()
assert key in state_dict
if key in state_dict:
input_param = state_dict[key]
assert input_param.shape[1] == hidden_size
if input_param.shape[0] != seq_length:
# update input_param and load it to state_dict[key]
num_tok_input = input_param.shape[0] - 1
num_tok_new = seq_length - 1
input_param_tok, input_param_grid = (
input_param[:1, :],
input_param[1:, :],
)
gs_input = int(math.sqrt(num_tok_input))
gs_new = int(math.sqrt(num_tok_new))
input_param_grid = input_param_grid.transpose(0, 1).contiguous()
input_param_grid = input_param_grid.reshape(
(1, -1, gs_input, gs_input)
)
input_param_grid = input_param_grid.float()
scale_factor = gs_new / gs_input
input_param_grid = F.interpolate(
input_param_grid, scale_factor=scale_factor, mode="bilinear"
)
input_param_grid = input_param_grid.half()
input_param_grid = input_param_grid.reshape((-1, gs_new * gs_new))
input_param_grid = input_param_grid.transpose(0, 1).contiguous()
assert input_param_grid.shape[1] == hidden_size
input_param = torch.cat((input_param_tok, input_param_grid), dim=0)
assert (
input_param.shape[0] == seq_length
and input_param.shape[1] == hidden_size
)
state_dict[key] = input_param
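The hook above rescales pretrained position embeddings when the patch grid changes: the grid part is reshaped to [1, h, gs, gs], bilinearly resized, and flattened back in front of the class-token embedding. A minimal fp32 sketch of that resizing step with hypothetical sizes, using an explicit target size instead of the scale factor:

import torch
import torch.nn.functional as F

hidden_size, gs_input, gs_new = 768, 14, 16
grid = torch.randn(gs_input * gs_input, hidden_size)        # grid tokens only, no cls token

grid = grid.transpose(0, 1).reshape(1, hidden_size, gs_input, gs_input)
grid = F.interpolate(grid, size=(gs_new, gs_new), mode="bilinear")
grid = grid.reshape(hidden_size, gs_new * gs_new).transpose(0, 1)
assert grid.shape == (gs_new * gs_new, hidden_size)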
class VitModel(MegatronModule):
"""Vision Transformer Model."""
def __init__(self, num_classes, finetune=False):
super(VitModel, self).__init__()
args = get_args()
self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
if args.init_method_xavier_uniform:
self.init_method = torch.nn.init.xavier_uniform_
self.scaled_init_method = torch.nn.init.xavier_uniform_
else:
self.init_method = init_method_normal(args.init_method_std)
self.scaled_init_method = scaled_init_method_normal(
args.init_method_std, args.num_layers
)
self.hidden_size = args.hidden_size
self.num_classes = num_classes
self.patch_dim = args.patch_dim
self.img_dim = args.img_dim
self.finetune = finetune
assert self.img_dim % self.patch_dim == 0
self.num_patches_per_dim = self.img_dim // self.patch_dim
self.num_patches = self.num_patches_per_dim ** 2
self.seq_length = self.num_patches + 1
self.flatten_dim = self.patch_dim * self.patch_dim * args.num_channels
# cls_token
self.cls_token = torch.nn.Parameter(torch.randn(1, 1, self.hidden_size))
torch.nn.init.zeros_(self.cls_token)
# Linear encoder
self.linear_encoder = torch.nn.Linear(
self.flatten_dim, self.hidden_size
)
# embedding
self.position_embeddings = torch.nn.Embedding(
self.seq_length, self.hidden_size
)
init_method_normal(args.init_method_std)(
self.position_embeddings.weight
)
self.position_ids = torch.arange(self.seq_length).expand(1, -1).cuda()
self.position_embeddings._register_load_state_dict_pre_hook(
twod_interpolate_position_embeddings_hook
)
self.embedding_dropout = torch.nn.Dropout(args.hidden_dropout)
# Transformer
self.transformer = ParallelTransformer(
self.init_method, self.scaled_init_method
)
# MLP head
if not self.finetune:
self.mlp_head = VitMlpHead(self.hidden_size, self.num_classes)
else:
self.class_head = get_linear_layer(
self.hidden_size, num_classes, torch.nn.init.zeros_
)
def forward(self, x):
x = einops.rearrange(
x,
"b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
p1=self.patch_dim,
p2=self.patch_dim,
)
assert x.dtype == torch.half
x = self.linear_encoder(x)
cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
x = x + self.position_embeddings(self.position_ids)
x = self.embedding_dropout(x)
x = self.transformer(x, None)
if not self.finetune:
x = self.mlp_head(x)
else:
x = self.class_head(x[:, 0, :])
return x
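Sketch of the patchification step in the forward above with hypothetical sizes: einops.rearrange turns a [b, c, H, W] image batch into a [b, num_patches, patch_dim * patch_dim * c] token sequence.

import einops
import torch

b, c, img_dim, patch_dim = 2, 3, 224, 16
x = torch.randn(b, c, img_dim, img_dim)
patches = einops.rearrange(
    x, "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=patch_dim, p2=patch_dim)
assert patches.shape == (b, (img_dim // patch_dim) ** 2, patch_dim * patch_dim * c)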
...@@ -19,8 +19,6 @@ from .cross_entropy import vocab_parallel_cross_entropy ...@@ -19,8 +19,6 @@ from .cross_entropy import vocab_parallel_cross_entropy
from .data import broadcast_data from .data import broadcast_data
from .grads import clip_grad_norm
from .initialize import is_unitialized from .initialize import is_unitialized
from .initialize import destroy_model_parallel from .initialize import destroy_model_parallel
from .initialize import get_data_parallel_group from .initialize import get_data_parallel_group
...@@ -46,7 +44,10 @@ from .initialize import model_parallel_is_initialized ...@@ -46,7 +44,10 @@ from .initialize import model_parallel_is_initialized
from .layers import ColumnParallelLinear from .layers import ColumnParallelLinear
from .layers import RowParallelLinear from .layers import RowParallelLinear
from .layers import VocabParallelEmbedding from .layers import VocabParallelEmbedding
from .layers import (set_tensor_model_parallel_attributes,
set_defaults_if_not_set_tensor_model_parallel_attributes,
copy_tensor_model_parallel_attributes)
from .mappings import copy_to_tensor_model_parallel_region from .mappings import copy_to_tensor_model_parallel_region
from .mappings import gather_from_tensor_model_parallel_region from .mappings import gather_from_tensor_model_parallel_region
from .mappings import reduce_from_tensor_model_parallel_region from .mappings import reduce_from_tensor_model_parallel_region
......
...@@ -20,7 +20,7 @@ from .initialize import get_tensor_model_parallel_rank ...@@ -20,7 +20,7 @@ from .initialize import get_tensor_model_parallel_rank
from .initialize import get_tensor_model_parallel_src_rank from .initialize import get_tensor_model_parallel_src_rank
_MAX_DATA_DIM = 4 _MAX_DATA_DIM = 5
def _check_data_types(keys, data, target_dtype): def _check_data_types(keys, data, target_dtype):
......
...@@ -37,14 +37,54 @@ from .utils import split_tensor_along_last_dim ...@@ -37,14 +37,54 @@ from .utils import split_tensor_along_last_dim
from .utils import VocabUtility from .utils import VocabUtility
from megatron import get_args from megatron import get_args
_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {'tensor_model_parallel': False,
'partition_dim': -1,
'partition_stride': 1}
def param_is_not_tensor_parallel_duplicate(param):
return (hasattr(param, 'tensor_model_parallel') and
param.tensor_model_parallel) or (
get_tensor_model_parallel_rank() == 0)
def set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride):
# Make sure the attributes are not set.
for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
assert not hasattr(tensor, attribute)
# Set the attributes.
setattr(tensor, 'tensor_model_parallel', is_parallel)
setattr(tensor, 'partition_dim', dim)
setattr(tensor, 'partition_stride', stride)
def set_defaults_if_not_set_tensor_model_parallel_attributes(tensor):
def maybe_set(attribute, value):
if not hasattr(tensor, attribute):
setattr(tensor, attribute, value)
for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
maybe_set(attribute, _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS[attribute])
def copy_tensor_model_parallel_attributes(destination_tensor, source_tensor):
def maybe_copy(attribute):
if hasattr(source_tensor, attribute):
setattr(destination_tensor, attribute,
getattr(source_tensor, attribute))
for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
maybe_copy(attribute)
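Usage sketch for the attribute helpers above, on toy parameters outside of any parallel layer: tag a sharded weight, copy the tags to a clone, and give a plain parameter the defaults.

import torch

weight = torch.nn.Parameter(torch.empty(8, 4))
set_tensor_model_parallel_attributes(weight, is_parallel=True, dim=0, stride=1)

clone = torch.nn.Parameter(torch.empty_like(weight))
copy_tensor_model_parallel_attributes(clone, weight)
assert clone.tensor_model_parallel and clone.partition_dim == 0

plain = torch.nn.Parameter(torch.empty(4))
set_defaults_if_not_set_tensor_model_parallel_attributes(plain)
assert plain.tensor_model_parallel is False   # defaults mark it as non-parallel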
def _initialize_affine_weight_gpu(weight, init_method, def _initialize_affine_weight_gpu(weight, init_method,
partition_dim, stride=1): partition_dim, stride=1):
"""Initialize affine weight for model parallel on GPU.""" """Initialize affine weight for model parallel on GPU."""
weight.tensor_model_parallel = True set_tensor_model_parallel_attributes(tensor=weight,
weight.partition_dim = partition_dim is_parallel=True,
weight.partition_stride = stride dim=partition_dim,
stride=stride)
with get_cuda_rng_tracker().fork(): with get_cuda_rng_tracker().fork():
init_method(weight) init_method(weight)
...@@ -58,9 +98,10 @@ def _initialize_affine_weight_cpu(weight, output_size, input_size, ...@@ -58,9 +98,10 @@ def _initialize_affine_weight_cpu(weight, output_size, input_size,
Build the master weight on all processes and scatter Build the master weight on all processes and scatter
the relevant chunk.""" the relevant chunk."""
weight.tensor_model_parallel = True set_tensor_model_parallel_attributes(tensor=weight,
weight.partition_dim = partition_dim is_parallel=True,
weight.partition_stride = stride dim=partition_dim,
stride=stride)
# Initialize master weight # Initialize master weight
master_weight = torch.empty(output_size, input_size, master_weight = torch.empty(output_size, input_size,
...@@ -74,7 +115,7 @@ def _initialize_affine_weight_cpu(weight, output_size, input_size, ...@@ -74,7 +115,7 @@ def _initialize_affine_weight_cpu(weight, output_size, input_size,
per_partition_per_stride_size = divide(per_partition_size, stride) per_partition_per_stride_size = divide(per_partition_size, stride)
weight_list = torch.split(master_weight, per_partition_per_stride_size, weight_list = torch.split(master_weight, per_partition_per_stride_size,
dim=partition_dim) dim=partition_dim)
rank = get_model_parallel_rank() rank = get_tensor_model_parallel_rank()
world_size = get_tensor_model_parallel_world_size() world_size = get_tensor_model_parallel_world_size()
my_weight_list = weight_list[rank::world_size] my_weight_list = weight_list[rank::world_size]
...@@ -225,9 +266,7 @@ class ColumnParallelLinear(torch.nn.Module): ...@@ -225,9 +266,7 @@ class ColumnParallelLinear(torch.nn.Module):
self.output_size_per_partition, self.output_size_per_partition,
device=torch.cuda.current_device(), device=torch.cuda.current_device(),
dtype=args.params_dtype)) dtype=args.params_dtype))
self.bias.tensor_model_parallel = True set_tensor_model_parallel_attributes(self.bias, True, 0, stride)
self.bias.partition_dim = 0
self.bias.stride = stride
# Always initialize bias to zero. # Always initialize bias to zero.
with torch.no_grad(): with torch.no_grad():
self.bias.zero_() self.bias.zero_()
......
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from apex.optimizers import FusedAdam as Adam
from apex.optimizers import FusedSGD as SGD
from megatron import get_args
from megatron.model import import_layernorm
from .grad_scaler import ConstantGradScaler, DynamicGradScaler
from .optimizer import FP16OptimizerWithFP16Params, FP32Optimizer
def _get_params_for_weight_decay_optimization(module):
"""Divide params into with-weight-decay and without-weight-decay groups.
Layernorms and biases will have no weight decay but the rest will.
"""
args = get_args()
LayerNorm = import_layernorm(args.fp32_residual_connection)
weight_decay_params = {'params': []}
no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
for module_ in module.modules():
if isinstance(module_, LayerNorm):
no_weight_decay_params['params'].extend(
[p for p in list(module_._parameters.values())
if p is not None])
else:
weight_decay_params['params'].extend(
[p for n, p in list(module_._parameters.items())
if p is not None and n != 'bias'])
no_weight_decay_params['params'].extend(
[p for n, p in list(module_._parameters.items())
if p is not None and n == 'bias'])
return weight_decay_params, no_weight_decay_params
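The helper above yields two optimizer parameter groups, with weight decay disabled for LayerNorm parameters and biases. A stand-in sketch of the same grouping on a plain torch model (torch.nn.LayerNorm instead of Megatron's fused LayerNorm, AdamW instead of apex's FusedAdam):

import torch

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.LayerNorm(8))
weight_decay_params = {'params': []}
no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
for module_ in model.modules():
    if isinstance(module_, torch.nn.LayerNorm):
        no_weight_decay_params['params'].extend(
            p for p in module_._parameters.values() if p is not None)
    else:
        weight_decay_params['params'].extend(
            p for n, p in module_._parameters.items()
            if p is not None and n != 'bias')
        no_weight_decay_params['params'].extend(
            p for n, p in module_._parameters.items()
            if p is not None and n == 'bias')

optimizer = torch.optim.AdamW(
    [weight_decay_params, no_weight_decay_params], lr=1e-3, weight_decay=0.01)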
def get_megatron_optimizer(model):
args = get_args()
# Base optimizer.
param_groups = _get_params_for_weight_decay_optimization(model)
if args.optimizer == 'adam':
optimizer = Adam(param_groups,
lr=args.lr,
weight_decay=args.weight_decay,
betas=(args.adam_beta1, args.adam_beta2),
eps=args.adam_eps)
elif args.optimizer == 'sgd':
optimizer = SGD(param_groups,
lr=args.lr,
weight_decay=args.weight_decay,
momentum=args.sgd_momentum)
else:
raise Exception('{} optimizer is not supported.'.format(
args.optimizer))
if args.fp16:
# Constant loss scale.
if args.loss_scale:
grad_scaler = ConstantGradScaler(args.loss_scale)
# Dynamic loss scale.
else:
grad_scaler = DynamicGradScaler(
initial_scale=args.initial_loss_scale,
min_scale=args.min_loss_scale,
growth_factor=2.0,
backoff_factor=0.5,
growth_interval=args.loss_scale_window,
hysteresis=args.hysteresis)
# Megatron optimizer.
return FP16OptimizerWithFP16Params(optimizer, grad_scaler,
args.clip_grad)
# FP32.
return FP32Optimizer(optimizer, args.clip_grad)
...@@ -13,67 +13,22 @@ ...@@ -13,67 +13,22 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Gradient clipping."""
# Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch
import torch import torch
from torch._six import inf from torch._six import inf
try: from apex.multi_tensor_apply import multi_tensor_applier
from apex.multi_tensor_apply import multi_tensor_applier import amp_C
import amp_C
except Exception as e:
print('WARNING: APEX is not installed, multi_tensor_applier will not be available.')
from .initialize import is_pipeline_first_stage from megatron import mpu
from .initialize import get_model_parallel_group from megatron.model.module import param_is_not_shared
from .initialize import get_tensor_model_parallel_rank from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
def l2_grad_clipper(parameters, max_norm): def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
"""Efficient L2 norm gradient clipping.""" """Clips gradient norm of an iterable of parameters whose gradients
are in fp32.
overflow_buf = torch.zeros(1, dtype=torch.int, device='cuda')
# Make sure we have an iterable.
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
# Filter parameters with gradients.
parameters_with_grads = list(filter(
lambda p: p.grad is not None, parameters))
# Filter parameters for norm calculations.
mp_rank_is_zero = (get_tensor_model_parallel_rank() == 0)
parameters_for_norm = list(filter(
lambda p: p.tensor_model_parallel or mp_rank_is_zero, parameters_with_grads))
# Calculate L2 norm.
norm, _ = multi_tensor_applier(
amp_C.multi_tensor_l2norm,
overflow_buf,
[parameters_for_norm],
False # no per-parameter norm
)
# Sum across all model parallel GPUs.
norm_2 = norm * norm
torch.distributed.all_reduce(norm_2,
op=torch.distributed.ReduceOp.SUM,
group=get_model_parallel_group())
total_norm = norm_2.item() ** 0.5
# Scale to get max_norm.
clip_coef = float(max_norm) / (total_norm + 1.0e-6)
grads = [p.grad for p in parameters_with_grads]
if clip_coef < 1.0:
multi_tensor_applier(
amp_C.multi_tensor_scale,
overflow_buf,
[grads, grads],
clip_coef)
return total_norm
def clip_grad_norm(parameters, max_norm, norm_type=2, parameter_names=None):
"""Clips gradient norm of an iterable of parameters.
This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
added functionality to handle model parallel parameters. Note that added functionality to handle model parallel parameters. Note that
...@@ -89,51 +44,77 @@ def clip_grad_norm(parameters, max_norm, norm_type=2, parameter_names=None): ...@@ -89,51 +44,77 @@ def clip_grad_norm(parameters, max_norm, norm_type=2, parameter_names=None):
Returns: Returns:
Total norm of the parameters (viewed as a single vector). Total norm of the parameters (viewed as a single vector).
""" """
if isinstance(parameters, torch.Tensor): if isinstance(parameters, torch.Tensor):
parameters = [parameters] parameters = [parameters]
if parameter_names is not None:
filtered_parameters = [] # Filter parameters based on:
assert len(parameters) == len(parameter_names), \ # - grad should not be none
'length of parameters and parameter_names should be the same' # - parameter should not be shared
for p, n in zip(parameters, parameter_names): # - should not be a replica due to tensor model parallelism
if p.grad is not None: grads = []
# TODO: Bit hacky; is there a cleaner way to do this? grads_for_norm = []
# Count embedding layer only once (in first stage). for param in parameters:
# Don't count the weights a second time in the last stage. grad_not_none = param.grad is not None
if "embedding" not in n or \ is_not_shared = param_is_not_shared(param)
is_pipeline_first_stage(): is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
filtered_parameters.append(p) grad = param.grad.detach()
parameters = filtered_parameters if grad_not_none:
else: # Make sure the grads are in fp32
parameters = list(filter(lambda p: p.grad is not None, parameters)) assert param.grad.type() == 'torch.cuda.FloatTensor'
grads.append(grad)
if grad_not_none and is_not_shared and is_not_tp_duplicate:
grads_for_norm.append(grad)
# Norm parameters.
max_norm = float(max_norm) max_norm = float(max_norm)
norm_type = float(norm_type) norm_type = float(norm_type)
total_norm = 0.0
# Calculate norm.
if norm_type == inf: if norm_type == inf:
total_norm = max(p.grad.data.abs().max() for p in parameters) total_norm = max(grad.abs().max() for grad in grads_for_norm)
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
# Take max across all model-parallel GPUs. # Take max across all model-parallel GPUs.
torch.distributed.all_reduce(total_norm_cuda, torch.distributed.all_reduce(total_norm_cuda,
op=torch.distributed.ReduceOp.MAX, op=torch.distributed.ReduceOp.MAX,
group=get_model_parallel_group()) group=mpu.get_model_parallel_group())
total_norm = total_norm_cuda[0].item() total_norm = total_norm_cuda[0].item()
clip_coef = max_norm / (total_norm + 1e-6)
if clip_coef < 1:
for p in parameters:
p.grad.data.mul_(clip_coef)
else: else:
total_norm = 0 if norm_type == 2.0:
for p in parameters: dummy_overflow_buf = torch.cuda.IntTensor([0])
if p.tensor_model_parallel or (get_tensor_model_parallel_rank() == 0): # Use apex's multi-tensor applier for efficiency reasons.
param_norm = torch.linalg.norm(p.grad.data.flatten(), norm_type) # Multi-tensor applier takes a function and a list of list
total_norm += param_norm.item() ** norm_type # and performs the operation on that list all in one kernel.
grad_norm, _ = multi_tensor_applier(
amp_C.multi_tensor_l2norm,
dummy_overflow_buf,
[grads_for_norm],
False # no per-parameter norm
)
# Since we will be summing across data parallel groups,
# we need the pow(norm-type).
total_norm = grad_norm ** norm_type
else:
for grad in grads_for_norm:
grad_norm = torch.norm(grad, norm_type)
total_norm += grad_norm ** norm_type
# Sum across all model-parallel GPUs. # Sum across all model-parallel GPUs.
total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) torch.distributed.all_reduce(total_norm,
torch.distributed.all_reduce(total_norm_cuda,
op=torch.distributed.ReduceOp.SUM, op=torch.distributed.ReduceOp.SUM,
group=get_model_parallel_group()) group=mpu.get_model_parallel_group())
total_norm = total_norm_cuda[0].item() ** (1. / norm_type) total_norm = total_norm.item() ** (1.0 / norm_type)
clip_coef = max_norm / (total_norm + 1e-6)
if clip_coef < 1: # Scale.
for p in parameters: clip_coeff = max_norm / (total_norm + 1.0e-6)
p.grad.data.mul_(clip_coef) if clip_coeff < 1.0:
dummy_overflow_buf = torch.cuda.IntTensor([0])
multi_tensor_applier(amp_C.multi_tensor_scale,
dummy_overflow_buf,
[grads, grads],
clip_coeff)
return total_norm return total_norm
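Single-process sketch of the clipping logic above for the default L2 case, using plain torch in place of apex's multi_tensor_applier and without the model-parallel all-reduce or the shared/duplicate filtering:

import torch

params = [torch.nn.Parameter(torch.randn(4, 4)) for _ in range(3)]
for p in params:
    p.grad = torch.randn_like(p)

max_norm, norm_type = 1.0, 2.0
grads = [p.grad.detach() for p in params]
total_norm = sum(torch.norm(g, norm_type) ** norm_type for g in grads)
total_norm = total_norm.item() ** (1.0 / norm_type)

clip_coeff = max_norm / (total_norm + 1.0e-6)
if clip_coeff < 1.0:
    for g in grads:
        g.mul_(clip_coeff)   # scale gradients in place, as the fused kernel does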
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Megatron grad scaler."""
from abc import ABC
from abc import abstractmethod
import torch
class MegatronGradScaler(ABC):
def __init__(self, initial_scale):
"""Initialize scale value with the input initial scale."""
assert initial_scale > 0.0
self._scale = torch.cuda.FloatTensor([initial_scale])
@property
def scale(self):
return self._scale
@property
def inv_scale(self):
return self._scale.double().reciprocal().float()
@abstractmethod
def update(self, found_inf):
pass
@abstractmethod
def state_dict(self):
pass
@abstractmethod
def load_state_dict(self, state_dict):
pass
class ConstantGradScaler(MegatronGradScaler):
def update(self, found_inf):
pass
def state_dict(self):
return dict()
def load_state_dict(self, state_dict):
pass
class DynamicGradScaler(MegatronGradScaler):
def __init__(self, initial_scale, min_scale,
growth_factor, backoff_factor,
growth_interval, hysteresis):
""""Grad scaler with dynamic scale that gets adjusted
during training."""
super(DynamicGradScaler, self).__init__(initial_scale)
# Lower bound on the scale.
assert min_scale > 0.0
assert min_scale <= initial_scale
self.min_scale = torch.cuda.FloatTensor([min_scale])
# Growth and backoff factors for the scale.
assert growth_factor > 1.0
self.growth_factor = torch.cuda.FloatTensor([growth_factor])
assert backoff_factor < 1.0
assert backoff_factor > 0.0
self.backoff_factor = torch.cuda.FloatTensor([backoff_factor])
# Interval over which if we don't see any inf/nan,
# we will scale the grad scale by the growth factor.
assert growth_interval > 0
self.growth_interval = growth_interval
# Number of inf/nans we should see before scaling down
# the grad scale by the backoff factor.
assert hysteresis > 0
self.hysteresis = hysteresis
# Trackers.
self._growth_tracker = 0
self._hysteresis_tracker = self.hysteresis
def update(self, found_inf):
# If we have an inf/nan, growth tracker is set to 0
# and hysteresis tracker is reduced by 1.
if found_inf:
self._growth_tracker = 0
self._hysteresis_tracker -= 1
# Now if we are out of hysteresis count, scale down the loss.
if self._hysteresis_tracker <= 0:
self._scale = torch.max(self._scale * self.backoff_factor,
self.min_scale)
else:
# If there is no nan/inf, increment the growth tracker.
self._growth_tracker += 1
# If we have had enough consecutive intervals with no nan/inf:
if self._growth_tracker == self.growth_interval:
# Reset the tracker and hysteresis trackers,
self._growth_tracker = 0
self._hysteresis_tracker = self.hysteresis
# and scale up the loss scale.
self._scale = self._scale * self.growth_factor
def state_dict(self):
state_dict = {}
state_dict['scale'] = self._scale
state_dict['growth_tracker'] = self._growth_tracker
state_dict['hysteresis_tracker'] = self._hysteresis_tracker
return state_dict
def load_state_dict(self, state_dict):
self._scale = state_dict['scale'].cuda(torch.cuda.current_device())
self._growth_tracker = state_dict['growth_tracker']
self._hysteresis_tracker = state_dict['hysteresis_tracker']
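Behavior sketch for DynamicGradScaler, assuming its definition above is in scope and a CUDA device is available (the scaler keeps its scale on the GPU); with hysteresis=2 the first overflow only consumes hysteresis, the second halves the scale, and a full clean growth_interval doubles it again.

scaler = DynamicGradScaler(initial_scale=2.0**16, min_scale=1.0,
                           growth_factor=2.0, backoff_factor=0.5,
                           growth_interval=1000, hysteresis=2)

scaler.update(found_inf=True)    # first overflow only burns hysteresis
scaler.update(found_inf=True)    # second overflow: scale is halved
print(scaler.scale)              # tensor([32768.], device='cuda:0')

for _ in range(1000):            # a full clean interval scales back up
    scaler.update(found_inf=False)
print(scaler.scale)              # tensor([65536.], device='cuda:0')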