Commit 651cc107 authored by Mostofa Patwary

Merging megatron with ICT

parents e919dd8e c601d751
...@@ -3,17 +3,17 @@ import torch ...@@ -3,17 +3,17 @@ import torch
import sys import sys
from megatron import get_args, print_rank_0 from megatron import get_args, print_rank_0
from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name from megatron.checkpointing import fix_query_key_value_ordering
from megatron.module import MegatronModule from megatron.checkpointing import get_checkpoint_tracker_filename
from megatron.checkpointing import get_checkpoint_name
from megatron import mpu, get_tokenizer from megatron import mpu, get_tokenizer
from megatron.model.bert_model import bert_attention_mask_func
from megatron.model.bert_model import bert_extended_attention_mask
from megatron.model.bert_model import bert_position_ids from megatron.model.bert_model import bert_position_ids
from megatron.model.enums import AttnMaskType
from megatron.model.language_model import get_language_model from megatron.model.language_model import get_language_model
from megatron.model.utils import get_linear_layer from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal from megatron.model.utils import scaled_init_method_normal
from .module import MegatronModule
def biencoder_model_provider(only_query_model=False, def biencoder_model_provider(only_query_model=False,
only_context_model=False, only_context_model=False,
...@@ -165,16 +165,32 @@ class BiEncoderModel(MegatronModule): ...@@ -165,16 +165,32 @@ class BiEncoderModel(MegatronModule):
print('global rank {} is loading BERT checkpoint {}'.format( print('global rank {} is loading BERT checkpoint {}'.format(
torch.distributed.get_rank(), checkpoint_name)) torch.distributed.get_rank(), checkpoint_name))
# Load the checkpoint.
try: try:
state_dict = torch.load(checkpoint_name, map_location='cpu') state_dict = torch.load(checkpoint_name, map_location='cpu')
except ModuleNotFoundError:
from megatron.fp16_deprecated import loss_scaler
# For backward compatibility.
print_rank_0(' > deserializing using the old code structure ...')
sys.modules['fp16.loss_scaler'] = sys.modules[
'megatron.fp16_deprecated.loss_scaler']
sys.modules['megatron.fp16.loss_scaler'] = sys.modules[
'megatron.fp16_deprecated.loss_scaler']
state_dict = torch.load(checkpoint_name, map_location='cpu')
sys.modules.pop('fp16.loss_scaler', None)
sys.modules.pop('megatron.fp16.loss_scaler', None)
except BaseException: except BaseException:
raise ValueError("Could not load BERT checkpoint") print_rank_0('could not load the BERT checkpoint')
sys.exit()
checkpoint_version = state_dict.get('checkpoint_version', 0)
# load the LM state dict into each model # load the LM state dict into each model
model_dict = state_dict['model']['language_model'] model_dict = state_dict['model']['language_model']
if self.shared_query_context_model: if self.shared_query_context_model:
self.model.language_model.load_state_dict(model_dict) self.model.language_model.load_state_dict(model_dict)
fix_query_key_value_ordering(self.model, checkpoint_version)
else: else:
if self.use_query_model: if self.use_query_model:
self.query_model.language_model.load_state_dict(model_dict) self.query_model.language_model.load_state_dict(model_dict)
...@@ -183,11 +199,14 @@ class BiEncoderModel(MegatronModule): ...@@ -183,11 +199,14 @@ class BiEncoderModel(MegatronModule):
query_proj_state_dict = \ query_proj_state_dict = \
self.state_dict_for_save_checkpoint()\ self.state_dict_for_save_checkpoint()\
[self._query_key]['projection_enc'] [self._query_key]['projection_enc']
fix_query_key_value_ordering(self.query_model, checkpoint_version)
if self.use_context_model: if self.use_context_model:
self.context_model.language_model.load_state_dict(model_dict) self.context_model.language_model.load_state_dict(model_dict)
if self.query_model is not None and self.projection_dim > 0: if self.query_model is not None and self.projection_dim > 0:
self.context_model.projection_enc.load_state_dict\ self.context_model.projection_enc.load_state_dict\
(query_proj_state_dict) (query_proj_state_dict)
fix_query_key_value_ordering(self.context_model, checkpoint_version)
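The checkpoint-loading fallback earlier in this hunk temporarily aliases the deprecated fp16 loss-scaler module paths in sys.modules so that torch.load can unpickle checkpoints written before the module moved to megatron.fp16_deprecated. A self-contained sketch of that aliasing trick follows; the package and class names are stand-ins, not Megatron's real ones.

```python
import sys
import types

# Stand-in for the relocated module (in this merge it is
# megatron.fp16_deprecated.loss_scaler).
relocated = types.ModuleType('relocated_pkg.loss_scaler')

class LossScaler:
    """Placeholder for the class that old checkpoints pickled under its
    original module path."""

relocated.LossScaler = LossScaler
sys.modules['relocated_pkg.loss_scaler'] = relocated

# Alias the stale import paths so pickle can resolve them while torch.load
# deserializes an old checkpoint.
sys.modules['old_pkg.loss_scaler'] = relocated
sys.modules['old_pkg.fp16.loss_scaler'] = relocated

# state_dict = torch.load(checkpoint_name, map_location='cpu')  # would now work

# Drop the aliases once the checkpoint has been loaded.
sys.modules.pop('old_pkg.loss_scaler', None)
sys.modules.pop('old_pkg.fp16.loss_scaler', None)
```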
class PretrainedBertModel(MegatronModule): class PretrainedBertModel(MegatronModule):
...@@ -209,9 +228,9 @@ class PretrainedBertModel(MegatronModule): ...@@ -209,9 +228,9 @@ class PretrainedBertModel(MegatronModule):
args.init_method_std, args.num_layers) args.init_method_std, args.num_layers)
self.language_model, self._language_model_key = get_language_model( self.language_model, self._language_model_key = get_language_model(
attention_mask_func=bert_attention_mask_func,
num_tokentypes=num_tokentypes, num_tokentypes=num_tokentypes,
add_pooler=False, add_pooler=False,
encoder_attn_mask_type=AttnMaskType.padding,
init_method=init_method, init_method=init_method,
scaled_init_method=scaled_init_method) scaled_init_method=scaled_init_method)
......
...@@ -19,15 +19,16 @@ import torch ...@@ -19,15 +19,16 @@ import torch
from megatron import get_args, print_rank_last from megatron import get_args, print_rank_last
from megatron import mpu from megatron import mpu
from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids from megatron.model.enums import AttnMaskType
from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids
from megatron.model.language_model import get_language_model from megatron.model.language_model import get_language_model
from megatron.model.utils import get_linear_layer from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal from megatron.model.utils import scaled_init_method_normal
from megatron.module import PipelinedMegatronModule from .module import MegatronModule
class ClassificationBase(PipelinedMegatronModule): class ClassificationBase(MegatronModule):
def __init__(self, num_classes, num_tokentypes=2): def __init__(self, num_classes, num_tokentypes=2):
super(ClassificationBase, self).__init__(share_word_embeddings=False) super(ClassificationBase, self).__init__(share_word_embeddings=False)
...@@ -37,9 +38,9 @@ class ClassificationBase(PipelinedMegatronModule): ...@@ -37,9 +38,9 @@ class ClassificationBase(PipelinedMegatronModule):
init_method = init_method_normal(args.init_method_std) init_method = init_method_normal(args.init_method_std)
self.language_model, self._language_model_key = get_language_model( self.language_model, self._language_model_key = get_language_model(
attention_mask_func=bert_attention_mask_func,
num_tokentypes=num_tokentypes, num_tokentypes=num_tokentypes,
add_pooler=True, add_pooler=True,
encoder_attn_mask_type=AttnMaskType.padding,
init_method=init_method, init_method=init_method,
scaled_init_method=scaled_init_method_normal(args.init_method_std, scaled_init_method=scaled_init_method_normal(args.init_method_std,
args.num_layers)) args.num_layers))
......
...@@ -20,7 +20,7 @@ from torch.nn.modules import Module ...@@ -20,7 +20,7 @@ from torch.nn.modules import Module
from torch.autograd import Variable from torch.autograd import Variable
from megatron import mpu from megatron import mpu
from megatron.module import MegatronModule from .module import MegatronModule
class DistributedDataParallel(MegatronModule): class DistributedDataParallel(MegatronModule):
......
...@@ -12,19 +12,17 @@ ...@@ -12,19 +12,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .fp16util import (
BN_convert_float,
network_to_half,
prep_param_lists,
model_grads_to_master_grads,
master_params_to_model_params,
tofp16,
to_python_float,
clip_grad_norm,
convert_module,
convert_network,
FP16Model,
)
from .fp16 import * import enum
from .loss_scaler import *
class LayerType(enum.Enum):
encoder = 1
decoder = 2
class AttnType(enum.Enum):
self_attn = 1
cross_attn = 2
class AttnMaskType(enum.Enum):
padding = 1
causal = 2
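These enums (new file megatron/model/enums.py) replace the per-model attention-mask functions and boolean fusion flags used before this merge; downstream code now branches on the enum value. A standalone illustration of the dispatch style, with the kernel names taken from the fused-softmax changes later in this diff:

```python
import enum

class AttnMaskType(enum.Enum):
    padding = 1
    causal = 2

def select_softmax_kernel(attn_mask_type):
    # Causal (GPT-style) masks go to the upper-triangular fused kernel,
    # padding (BERT-style) masks to the general masked kernel.
    if attn_mask_type == AttnMaskType.causal:
        return 'scaled_upper_triang_masked_softmax_cuda'
    return 'scaled_masked_softmax_cuda'

assert select_softmax_kernel(AttnMaskType.causal) == \
    'scaled_upper_triang_masked_softmax_cuda'
assert select_softmax_kernel(AttnMaskType.padding) == 'scaled_masked_softmax_cuda'
```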
File mode changed from 100755 to 100644
...@@ -14,103 +14,127 @@ ...@@ -14,103 +14,127 @@
# limitations under the License. # limitations under the License.
import torch import torch
from megatron.model.enums import AttnMaskType
class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function) :
class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
""" """
Fused operation which performs following three operations in sequence Fused operation which performs following three operations in sequence
1. Scale the tensor. 1. Scale the tensor.
2. Apply upper triangular mask (typically used in gpt models). 2. Apply upper triangular mask (typically used in gpt models).
3. Perform softmax. 3. Perform softmax.
""" """
@staticmethod @staticmethod
def forward(ctx, inputs, scale): def forward(ctx, inputs, scale):
import scaled_upper_triang_masked_softmax_cuda import scaled_upper_triang_masked_softmax_cuda
scale_t = torch.tensor([scale]) scale_t = torch.tensor([scale])
softmax_results = \ softmax_results = scaled_upper_triang_masked_softmax_cuda.forward(
scaled_upper_triang_masked_softmax_cuda.forward(inputs, scale_t[0]) inputs, scale_t[0]
)
ctx.save_for_backward(softmax_results, scale_t) ctx.save_for_backward(softmax_results, scale_t)
return softmax_results return softmax_results
@staticmethod @staticmethod
def backward(ctx, output_grads): def backward(ctx, output_grads):
import scaled_upper_triang_masked_softmax_cuda import scaled_upper_triang_masked_softmax_cuda
softmax_results, scale_t = ctx.saved_tensors softmax_results, scale_t = ctx.saved_tensors
input_grads = \ input_grads = scaled_upper_triang_masked_softmax_cuda.backward(
scaled_upper_triang_masked_softmax_cuda.backward(output_grads, output_grads, softmax_results, scale_t[0]
softmax_results, )
scale_t[0])
return input_grads, None return input_grads, None
class ScaledMaskedSoftmax(torch.autograd.Function) :
class ScaledMaskedSoftmax(torch.autograd.Function):
""" """
Fused operation which performs following three operations in sequence Fused operation which performs following three operations in sequence
1. Scale the tensor. 1. Scale the tensor.
2. Apply the mask. 2. Apply the mask.
3. Perform softmax. 3. Perform softmax.
""" """
@staticmethod @staticmethod
def forward(ctx, inputs, mask, scale): def forward(ctx, inputs, mask, scale):
import scaled_masked_softmax_cuda import scaled_masked_softmax_cuda
scale_t = torch.tensor([scale]) scale_t = torch.tensor([scale])
softmax_results = \ softmax_results = scaled_masked_softmax_cuda.forward(
scaled_masked_softmax_cuda.forward(inputs, mask, scale_t[0]) inputs, mask, scale_t[0]
)
ctx.save_for_backward(softmax_results, scale_t) ctx.save_for_backward(softmax_results, scale_t)
return softmax_results return softmax_results
@staticmethod @staticmethod
def backward(ctx, output_grads): def backward(ctx, output_grads):
import scaled_masked_softmax_cuda import scaled_masked_softmax_cuda
softmax_results, scale_t = ctx.saved_tensors softmax_results, scale_t = ctx.saved_tensors
input_grads = \ input_grads = scaled_masked_softmax_cuda.backward(
scaled_masked_softmax_cuda.backward(output_grads, output_grads, softmax_results, scale_t[0]
softmax_results, )
scale_t[0])
return input_grads, None, None return input_grads, None, None
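Both autograd functions wrap hand-written CUDA kernels and import the compiled extension lazily inside forward/backward, so the module can still be imported on a machine without the kernels built. For reference only, a plain-PyTorch version of the operation they fuse is sketched below; it is not part of this merge and relies on autograd for the backward pass instead of the custom backward kernels.

```python
import torch

def scaled_masked_softmax_ref(inputs, mask, scale):
    """Unfused reference: scale, then mask, then softmax over the last dim.

    inputs: [b, np, sq, sk] attention scores
    mask:   boolean mask, True where attention is disallowed
    """
    scores = inputs * scale
    scores = scores.masked_fill(mask, -10000.0)
    return torch.nn.Softmax(dim=-1)(scores)

def scaled_upper_triang_masked_softmax_ref(inputs, scale):
    """Causal variant; assumes square scores (self-attention)."""
    seq_len = inputs.size(-1)
    causal_mask = torch.triu(
        torch.ones(seq_len, seq_len, dtype=torch.bool, device=inputs.device),
        diagonal=1)
    return scaled_masked_softmax_ref(inputs, causal_mask, scale)

probs = scaled_upper_triang_masked_softmax_ref(torch.randn(2, 4, 8, 8), scale=1.0)
```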
class FusedScaleMaskSoftmax(torch.nn.Module): class FusedScaleMaskSoftmax(torch.nn.Module):
""" """
fused operation: scaling + mask + softmax fused operation: scaling + mask + softmax
Arguments: Arguments:
input_in_fp16: flag to indicate if input in fp16 data format. input_in_fp16: flag to indicate if input in fp16 data format.
upper_triang_mask: if true, apply upper triangular masking. attn_mask_type: attention mask type (pad or causal)
(used in gpt family networks) mask_func: mask function to be applied.
mask_func: mask function to be applied. softmax_in_fp32: if true, softmax is performed at fp32 precision.
softmax_in_fp32: if true, softmax is performed at fp32 precision. scale: scaling factor used in input tensor scaling.
scale: scaling factor used in input tensor scaling.
""" """
def __init__(self, input_in_fp16, upper_triang_mask_fusion,
general_mask_fusion, mask_func, softmax_in_fp32, scale): def __init__(
self,
input_in_fp16,
attn_mask_type,
scaled_masked_softmax_fusion,
mask_func,
softmax_in_fp32,
scale,
):
super(FusedScaleMaskSoftmax, self).__init__() super(FusedScaleMaskSoftmax, self).__init__()
self.input_in_fp16 = input_in_fp16 self.input_in_fp16 = input_in_fp16
self.upper_triang_mask_fusion = upper_triang_mask_fusion self.attn_mask_type = attn_mask_type
self.general_mask_fusion = general_mask_fusion self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
self.mask_func = mask_func self.mask_func = mask_func
self.softmax_in_fp32 = softmax_in_fp32 self.softmax_in_fp32 = softmax_in_fp32
self.scale = scale self.scale = scale
assert self.scale is None or softmax_in_fp32, \ assert (
'softmax should be in fp32 when scaled' self.scale is None or softmax_in_fp32
), "softmax should be in fp32 when scaled"
def forward(self, input, mask): def forward(self, input, mask):
# [b, np, s, s] # [b, np, sq, sk]
data_size = input.size() data_size = input.size()
assert input.dim() == 4 query_seq_len = data_size[-2]
key_seq_len = data_size[-1]
assert input.dim() == 4
# invoke custom kernel # invoke custom kernel
if self.input_in_fp16 and data_size[-1] <= 2048 and \ if self.input_in_fp16 and key_seq_len <= 2048 and mask is not None and \
(self.upper_triang_mask_fusion or self.general_mask_fusion) and \ query_seq_len % 4 == 0 and self.scaled_masked_softmax_fusion:
input.size()[2] == input.size()[3]:
scale = self.scale if self.scale is not None else 1.0 scale = self.scale if self.scale is not None else 1.0
if self.upper_triang_mask_fusion:
input = input.view(-1, data_size[2], data_size[3]) if self.attn_mask_type == AttnMaskType.causal:
assert query_seq_len == key_seq_len, \
"causal mask is only for self attention"
input = input.view(-1, query_seq_len, key_seq_len)
probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale) probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
probs = probs.view(*data_size) probs = probs.view(*data_size)
else: else:
assert self.attn_mask_type == AttnMaskType.padding
probs = ScaledMaskedSoftmax.apply(input, mask, scale) probs = ScaledMaskedSoftmax.apply(input, mask, scale)
else: else:
if self.input_in_fp16 and self.softmax_in_fp32: if self.input_in_fp16 and self.softmax_in_fp32:
...@@ -118,7 +142,7 @@ class FusedScaleMaskSoftmax(torch.nn.Module): ...@@ -118,7 +142,7 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
if self.scale is not None: if self.scale is not None:
input = input * self.scale input = input * self.scale
mask_output = self.mask_func(input, mask) mask_output = self.mask_func(input, mask) if mask is not None else input
probs = torch.nn.Softmax(dim=-1)(mask_output) probs = torch.nn.Softmax(dim=-1)(mask_output)
if self.input_in_fp16 and self.softmax_in_fp32: if self.input_in_fp16 and self.softmax_in_fp32:
......
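The rewritten forward gates the fused path on several conditions: fp16 input, the fusion flag, a non-None mask, at most 2048 keys, a query length divisible by 4, and (for the causal kernel) a square score matrix. A tiny predicate restating those checks, purely for illustration:

```python
def fused_softmax_eligible(input_in_fp16, fusion_enabled, mask,
                           query_seq_len, key_seq_len, causal):
    """Mirror of the eligibility test in FusedScaleMaskSoftmax.forward;
    anything that fails falls back to the unfused masked softmax."""
    if not (input_in_fp16 and fusion_enabled and mask is not None):
        return False
    if key_seq_len > 2048 or query_seq_len % 4 != 0:
        return False
    if causal and query_seq_len != key_seq_len:
        # The real module asserts here: "causal mask is only for self attention".
        return False
    return True

assert fused_softmax_eligible(True, True, object(), 1024, 1024, causal=True)
assert not fused_softmax_eligible(True, True, None, 1024, 1024, causal=False)
```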
...@@ -19,19 +19,15 @@ import torch ...@@ -19,19 +19,15 @@ import torch
from megatron import get_args from megatron import get_args
from megatron import mpu from megatron import mpu
from megatron.module import PipelinedMegatronModule from .module import MegatronModule
from .enums import AttnMaskType
from .language_model import parallel_lm_logits from .language_model import parallel_lm_logits
from .language_model import get_language_model from .language_model import get_language_model
from .utils import init_method_normal from .utils import init_method_normal
from .utils import scaled_init_method_normal from .utils import scaled_init_method_normal
def gpt2_attention_mask_func(attention_scores, ltor_mask):
attention_scores.masked_fill_(ltor_mask, -10000.0)
return attention_scores
def post_language_model_processing(lm_output, labels, logit_weights, def post_language_model_processing(lm_output, labels, logit_weights,
get_key_value, parallel_output, get_key_value, parallel_output,
forward_method_parallel_output, forward_method_parallel_output,
...@@ -61,37 +57,37 @@ def post_language_model_processing(lm_output, labels, logit_weights, ...@@ -61,37 +57,37 @@ def post_language_model_processing(lm_output, labels, logit_weights,
return loss return loss
class GPT2ModelBase(PipelinedMegatronModule): class GPTModelBase(MegatronModule):
"""GPT-2 Language model.""" """GPT-2 Language model."""
def __init__(self, num_tokentypes=0, parallel_output=True): def __init__(self, num_tokentypes=0, parallel_output=True):
super(GPT2ModelBase, self).__init__() super(GPTModelBase, self).__init__()
args = get_args() args = get_args()
self.parallel_output = parallel_output self.parallel_output = parallel_output
self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
self.language_model, self._language_model_key = get_language_model( self.language_model, self._language_model_key = get_language_model(
attention_mask_func=gpt2_attention_mask_func,
num_tokentypes=num_tokentypes, num_tokentypes=num_tokentypes,
add_pooler=False, add_pooler=False,
encoder_attn_mask_type=AttnMaskType.causal,
init_method=init_method_normal(args.init_method_std), init_method=init_method_normal(args.init_method_std),
scaled_init_method=scaled_init_method_normal(args.init_method_std, scaled_init_method=scaled_init_method_normal(args.init_method_std,
args.num_layers)) args.num_layers))
self.initialize_word_embeddings(init_method_normal) self.initialize_word_embeddings(init_method_normal)
def forward(self, gpt2_model_input, attention_mask, labels=None, def forward(self, gpt_model_input, attention_mask, labels=None,
tokentype_ids=None, layer_past=None, get_key_value=False, tokentype_ids=None, layer_past=None, get_key_value=False,
forward_method_parallel_output=None): forward_method_parallel_output=None):
kwargs = {'layer_past': layer_past, 'get_key_value': get_key_value} kwargs = {'layer_past': layer_past, 'get_key_value': get_key_value}
if mpu.is_pipeline_first_stage(): if mpu.is_pipeline_first_stage():
(input_ids, position_ids) = gpt2_model_input (input_ids, position_ids) = gpt_model_input
args = [input_ids, position_ids, attention_mask] args = [input_ids, position_ids, attention_mask]
kwargs['tokentype_ids'] = tokentype_ids kwargs['tokentype_ids'] = tokentype_ids
else: else:
args = [gpt2_model_input, attention_mask] args = [gpt_model_input, attention_mask]
lm_output = self.language_model(*args, **kwargs) lm_output = self.language_model(*args, **kwargs)
if mpu.is_pipeline_last_stage(): if mpu.is_pipeline_last_stage():
...@@ -130,17 +126,17 @@ class GPT2ModelBase(PipelinedMegatronModule): ...@@ -130,17 +126,17 @@ class GPT2ModelBase(PipelinedMegatronModule):
self.language_model.load_state_dict(state_dict, strict=strict) self.language_model.load_state_dict(state_dict, strict=strict)
class GPT2Model(GPT2ModelBase): class GPTModel(GPTModelBase):
def __init__(self, num_tokentypes=0, parallel_output=True): def __init__(self, num_tokentypes=0, parallel_output=True):
super(GPT2Model, self).__init__( super(GPTModel, self).__init__(
num_tokentypes=num_tokentypes, num_tokentypes=num_tokentypes,
parallel_output=parallel_output) parallel_output=parallel_output)
def forward(self, input_ids, position_ids, attention_mask, labels=None, def forward(self, input_ids, position_ids, attention_mask, labels=None,
tokentype_ids=None, layer_past=None, get_key_value=False, tokentype_ids=None, layer_past=None, get_key_value=False,
forward_method_parallel_output=None): forward_method_parallel_output=None):
return super(GPT2Model, self).forward( return super(GPTModel, self).forward(
(input_ids, position_ids), (input_ids, position_ids),
attention_mask, attention_mask,
labels=labels, labels=labels,
...@@ -150,15 +146,15 @@ class GPT2Model(GPT2ModelBase): ...@@ -150,15 +146,15 @@ class GPT2Model(GPT2ModelBase):
forward_method_parallel_output=forward_method_parallel_output) forward_method_parallel_output=forward_method_parallel_output)
class GPT2ModelFirstStage(GPT2ModelBase): class GPTModelFirstStage(GPTModelBase):
def __init__(self, num_tokentypes=0): def __init__(self, num_tokentypes=0):
super(GPT2ModelFirstStage, self).__init__( super(GPTModelFirstStage, self).__init__(
num_tokentypes=num_tokentypes) num_tokentypes=num_tokentypes)
def forward(self, input_ids, position_ids, attention_mask, def forward(self, input_ids, position_ids, attention_mask,
tokentype_ids=None, layer_past=None, get_key_value=False): tokentype_ids=None, layer_past=None, get_key_value=False):
return super(GPT2ModelFirstStage, self).forward( return super(GPTModelFirstStage, self).forward(
(input_ids, position_ids), (input_ids, position_ids),
attention_mask, attention_mask,
tokentype_ids=tokentype_ids, tokentype_ids=tokentype_ids,
...@@ -166,32 +162,32 @@ class GPT2ModelFirstStage(GPT2ModelBase): ...@@ -166,32 +162,32 @@ class GPT2ModelFirstStage(GPT2ModelBase):
get_key_value=get_key_value) get_key_value=get_key_value)
class GPT2ModelIntermediateStage(GPT2ModelBase): class GPTModelIntermediateStage(GPTModelBase):
def __init__(self, num_tokentypes=0): def __init__(self, num_tokentypes=0):
super(GPT2ModelIntermediateStage, self).__init__( super(GPTModelIntermediateStage, self).__init__(
num_tokentypes=num_tokentypes) num_tokentypes=num_tokentypes)
def forward(self, hidden_state, attention_mask, def forward(self, hidden_state, attention_mask,
layer_past=None, get_key_value=False): layer_past=None, get_key_value=False):
return super(GPT2ModelIntermediateStage, self).forward( return super(GPTModelIntermediateStage, self).forward(
hidden_state, hidden_state,
attention_mask, attention_mask,
layer_past=layer_past, layer_past=layer_past,
get_key_value=get_key_value) get_key_value=get_key_value)
class GPT2ModelLastStage(GPT2ModelBase): class GPTModelLastStage(GPTModelBase):
def __init__(self, num_tokentypes=0, parallel_output=True): def __init__(self, num_tokentypes=0, parallel_output=True):
super(GPT2ModelLastStage, self).__init__( super(GPTModelLastStage, self).__init__(
num_tokentypes=num_tokentypes, num_tokentypes=num_tokentypes,
parallel_output=parallel_output) parallel_output=parallel_output)
def forward(self, hidden_state, attention_mask, labels=None, def forward(self, hidden_state, attention_mask, labels=None,
layer_past=None, get_key_value=False, layer_past=None, get_key_value=False,
forward_method_parallel_output=None): forward_method_parallel_output=None):
return super(GPT2ModelLastStage, self).forward( return super(GPTModelLastStage, self).forward(
hidden_state, hidden_state,
attention_mask, attention_mask,
labels=labels, labels=labels,
......
...@@ -20,7 +20,8 @@ import torch.nn.functional as F ...@@ -20,7 +20,8 @@ import torch.nn.functional as F
from megatron import get_args from megatron import get_args
from megatron import mpu from megatron import mpu
from megatron.module import MegatronModule from .module import MegatronModule
from megatron.model.enums import LayerType, AttnMaskType
from megatron.model.transformer import ParallelTransformer from megatron.model.transformer import ParallelTransformer
from megatron.model.utils import get_linear_layer from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal, scaled_init_method_normal from megatron.model.utils import init_method_normal, scaled_init_method_normal
...@@ -42,8 +43,10 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, ...@@ -42,8 +43,10 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
return mpu.gather_from_tensor_model_parallel_region(logits_parallel) return mpu.gather_from_tensor_model_parallel_region(logits_parallel)
def get_language_model(attention_mask_func, num_tokentypes, add_pooler, def get_language_model(num_tokentypes, add_pooler,
init_method=None, scaled_init_method=None): encoder_attn_mask_type, init_method=None,
scaled_init_method=None, add_decoder=False,
decoder_attn_mask_type=AttnMaskType.causal):
"""Build language model and return along with the key to save.""" """Build language model and return along with the key to save."""
args = get_args() args = get_args()
...@@ -51,15 +54,18 @@ def get_language_model(attention_mask_func, num_tokentypes, add_pooler, ...@@ -51,15 +54,18 @@ def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
init_method = init_method_normal(args.init_method_std) init_method = init_method_normal(args.init_method_std)
if scaled_init_method is None: if scaled_init_method is None:
scaled_init_method = scaled_init_method_normal(args.init_method_std, args.num_layers) scaled_init_method = scaled_init_method_normal(args.init_method_std,
args.num_layers)
# Language model. # Language model.
args = [attention_mask_func, init_method, scaled_init_method] args = [init_method, scaled_init_method, encoder_attn_mask_type]
kwargs = {} kwargs = {}
cls = None cls = None
if mpu.is_pipeline_first_stage() and mpu.is_pipeline_last_stage(): if mpu.is_pipeline_first_stage() and mpu.is_pipeline_last_stage():
cls = TransformerLanguageModel cls = TransformerLanguageModel
kwargs['num_tokentypes'] = num_tokentypes kwargs['num_tokentypes'] = num_tokentypes
kwargs['add_decoder'] = add_decoder
kwargs['decoder_attn_mask_type'] = decoder_attn_mask_type
kwargs['add_pooler'] = add_pooler kwargs['add_pooler'] = add_pooler
elif mpu.is_pipeline_first_stage() and not mpu.is_pipeline_last_stage(): elif mpu.is_pipeline_first_stage() and not mpu.is_pipeline_last_stage():
cls = TransformerLanguageModelFirstStage cls = TransformerLanguageModelFirstStage
...@@ -262,12 +268,6 @@ class TransformerLanguageModelBase(MegatronModule): ...@@ -262,12 +268,6 @@ class TransformerLanguageModelBase(MegatronModule):
Arguments: Arguments:
transformer_hparams: transformer hyperparameters transformer_hparams: transformer hyperparameters
attention_mask_func: a function that takes `unmaksed-attention-scores`
with size [b, np, s, s] and an `attention-mask` and will apply
the masking. The function should return a masked score of the
same size [b, np, s, s].
masked-attention-scores = attention_mask_func(
unmaksed-attention-scores, attention-mask)
vocab_size: vocabulary size vocab_size: vocabulary size
max_sequence_length: maximum size of sequence. This max_sequence_length: maximum size of sequence. This
is used for positional embedding is used for positional embedding
...@@ -277,10 +277,12 @@ class TransformerLanguageModelBase(MegatronModule): ...@@ -277,10 +277,12 @@ class TransformerLanguageModelBase(MegatronModule):
""" """
def __init__(self, def __init__(self,
attention_mask_func,
init_method, init_method,
output_layer_init_method, output_layer_init_method,
encoder_attn_mask_type,
num_tokentypes=0, num_tokentypes=0,
add_decoder=False,
decoder_attn_mask_type=AttnMaskType.causal,
add_pooler=False): add_pooler=False):
super(TransformerLanguageModelBase, self).__init__() super(TransformerLanguageModelBase, self).__init__()
args = get_args() args = get_args()
...@@ -288,6 +290,9 @@ class TransformerLanguageModelBase(MegatronModule): ...@@ -288,6 +290,9 @@ class TransformerLanguageModelBase(MegatronModule):
self.hidden_size = args.hidden_size self.hidden_size = args.hidden_size
self.num_tokentypes = num_tokentypes self.num_tokentypes = num_tokentypes
self.init_method = init_method self.init_method = init_method
self.encoder_attn_mask_type = encoder_attn_mask_type
self.add_decoder = add_decoder
self.decoder_attn_mask_type = decoder_attn_mask_type
self.add_pooler = add_pooler self.add_pooler = add_pooler
# Embeddings. # Embeddings.
...@@ -301,41 +306,83 @@ class TransformerLanguageModelBase(MegatronModule): ...@@ -301,41 +306,83 @@ class TransformerLanguageModelBase(MegatronModule):
self._embedding_key = 'embedding' self._embedding_key = 'embedding'
# Transformer. # Transformer.
self.transformer = ParallelTransformer( self.encoder = ParallelTransformer(
attention_mask_func, self.init_method, self.init_method,
output_layer_init_method) output_layer_init_method,
self._transformer_key = 'transformer' self_attn_mask_type=self.encoder_attn_mask_type)
self._encoder_key = 'encoder'
# Pooler.
if mpu.is_pipeline_last_stage() and self.add_pooler: # Decoder
self.pooler = Pooler(self.hidden_size, self.init_method) if self.add_decoder:
self._pooler_key = 'pooler' assert args.pipeline_model_parallel_size == 1, \
'pipeline parallelism is not supported in the presence of decoder'
def forward(self, language_model_input, attention_mask, self.decoder = ParallelTransformer(
tokentype_ids=None, layer_past=None, get_key_value=False, self.init_method,
pooling_sequence_index=0): output_layer_init_method,
layer_type=LayerType.decoder,
self_attn_mask_type=self.decoder_attn_mask_type)
self._decoder_key = 'decoder'
if mpu.is_pipeline_last_stage():
# Pooler.
if self.add_pooler:
self.pooler = Pooler(self.hidden_size, self.init_method)
self._pooler_key = 'pooler'
def forward(self, enc_language_model_input, enc_attn_mask,
dec_language_model_input=None, dec_attn_mask=None,
enc_dec_attn_mask=None, tokentype_ids=None, layer_past=None,
get_key_value=False, pooling_sequence_index=0,
enc_hidden_states=None, output_enc_hidden=False):
# Embeddings. # Embeddings.
if mpu.is_pipeline_first_stage(): if mpu.is_pipeline_first_stage():
(input_ids, position_ids) = language_model_input (input_ids, position_ids) = enc_language_model_input
embedding_output = self.embedding(input_ids, position_ids, embedding_output = self.embedding(input_ids, position_ids,
tokentype_ids=tokentype_ids) tokentype_ids=tokentype_ids)
transformer_input = embedding_output encoder_input = embedding_output
else: else:
transformer_input = language_model_input encoder_input = enc_language_model_input
# Transformer. # encoder.
transformer_output = self.transformer(transformer_input, if enc_hidden_states is None:
attention_mask, encoder_output = self.encoder(encoder_input,
layer_past=layer_past, enc_attn_mask,
get_key_value=get_key_value) layer_past=layer_past,
get_key_value=get_key_value)
if mpu.is_pipeline_last_stage() and self.add_pooler: else:
pooled_output = self.pooler(transformer_output, encoder_output = enc_hidden_states.to(encoder_input.dtype)
pooling_sequence_index)
return transformer_output, pooled_output if mpu.is_pipeline_last_stage():
if self.add_pooler:
return transformer_output pooled_output = self.pooler(encoder_output,
pooling_sequence_index)
# output_enc_hidden refers to when we just need the encoder's
# output. For example, it is helpful to compute
# similarity between two sequences by average pooling
if not self.add_decoder or output_enc_hidden:
if self.add_pooler and mpu.is_pipeline_last_stage():
return encoder_output, pooled_output
else:
return encoder_output
# Decoder Embedding
(dec_input_ids, dec_position_ids) = dec_language_model_input
dec_embedding_output = self.embedding(dec_input_ids,
dec_position_ids)
# decoder
decoder_output = self.decoder(dec_embedding_output,
dec_attn_mask,
layer_past=layer_past,
get_key_value=get_key_value,
encoder_output=encoder_output,
enc_dec_attn_mask=enc_dec_attn_mask)
if self.add_pooler and mpu.is_pipeline_last_stage():
return decoder_output, encoder_output, pooled_output
else:
return decoder_output, encoder_output
def state_dict_for_save_checkpoint(self, destination=None, prefix='', def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False): keep_vars=False):
...@@ -346,12 +393,17 @@ class TransformerLanguageModelBase(MegatronModule): ...@@ -346,12 +393,17 @@ class TransformerLanguageModelBase(MegatronModule):
state_dict_[self._embedding_key] \ state_dict_[self._embedding_key] \
= self.embedding.state_dict_for_save_checkpoint( = self.embedding.state_dict_for_save_checkpoint(
destination, prefix, keep_vars) destination, prefix, keep_vars)
state_dict_[self._transformer_key] \ state_dict_[self._encoder_key] \
= self.transformer.state_dict_for_save_checkpoint( = self.encoder.state_dict_for_save_checkpoint(
destination, prefix, keep_vars) destination, prefix, keep_vars)
if mpu.is_pipeline_last_stage() and self.add_pooler: if mpu.is_pipeline_last_stage():
state_dict_[self._pooler_key] \ if self.add_pooler:
= self.pooler.state_dict_for_save_checkpoint( state_dict_[self._pooler_key] \
= self.pooler.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
if self.add_decoder:
state_dict_[self._decoder_key] \
= self.decoder.state_dict_for_save_checkpoint(
destination, prefix, keep_vars) destination, prefix, keep_vars)
return state_dict_ return state_dict_
...@@ -371,36 +423,44 @@ class TransformerLanguageModelBase(MegatronModule): ...@@ -371,36 +423,44 @@ class TransformerLanguageModelBase(MegatronModule):
state_dict_[key] = state_dict[key] state_dict_[key] = state_dict[key]
self.embedding.load_state_dict(state_dict_, strict=strict) self.embedding.load_state_dict(state_dict_, strict=strict)
# Transformer. # Encoder.
if self._transformer_key in state_dict: if self._encoder_key in state_dict:
state_dict_ = state_dict[self._transformer_key] state_dict_ = state_dict[self._encoder_key]
# for compatibility with t5 architecture # for backward compatibility.
# this is temporary unless t5_main is merged elif 'transformer' in state_dict:
elif 'encoder' in state_dict: state_dict_ = state_dict['transformer']
state_dict_ = state_dict['encoder']
# for forward compatibility for t5 architecture
state_dict_attention = {}
for key in state_dict_.keys():
if '.self_attention.' in key:
state_dict_attention[key.replace(".self_attention.",
".attention.")] = state_dict_[key]
else:
state_dict_attention[key] = state_dict_[key]
state_dict_ = state_dict_attention
else: else:
# for backward compatibility. # for backward compatibility.
state_dict_ = {} state_dict_ = {}
for key in state_dict.keys(): for key in state_dict.keys():
if 'transformer.' in key: if 'transformer.' in key:
state_dict_[key.split('transformer.')[1]] = state_dict[key] state_dict_[key.split('transformer.')[1]] = state_dict[key]
self.transformer.load_state_dict(state_dict_, strict=strict)
# Pooler. # for backward compatibility.
if mpu.is_pipeline_last_stage() and self.add_pooler: state_dict_self_attention = {}
assert 'pooler' in state_dict, \ for key in state_dict_.keys():
if '.attention.' in key:
state_dict_self_attention[key.replace(".attention.",
".self_attention.")] = state_dict_[key]
else:
state_dict_self_attention[key] = state_dict_[key]
state_dict_ = state_dict_self_attention
self.encoder.load_state_dict(state_dict_, strict=strict)
if mpu.is_pipeline_last_stage():
# pooler
if self.add_pooler:
assert 'pooler' in state_dict, \
'could not find data for pooler in the checkpoint'
self.pooler.load_state_dict(state_dict[self._pooler_key],
strict=strict)
# decoder
if self.add_decoder:
assert 'decoder' in state_dict, \
'could not find data for pooler in the checkpoint' 'could not find data for the decoder in the checkpoint'
self.pooler.load_state_dict(state_dict[self._pooler_key], self.decoder.load_state_dict(state_dict[self._decoder_key],
strict=strict) strict=strict)
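The load path above now accepts three checkpoint layouts — the new 'encoder' key, the older 'transformer' key, and the oldest flat 'transformer.'-prefixed keys — and then rewrites '.attention.' to '.self_attention.' to match the refactored transformer. A standalone sketch of that key migration with toy keys:

```python
def migrate_encoder_keys(state_dict_):
    """Rename pre-merge attention keys to the new self_attention layout."""
    migrated = {}
    for key, value in state_dict_.items():
        if '.attention.' in key:
            migrated[key.replace('.attention.', '.self_attention.')] = value
        else:
            migrated[key] = value
    return migrated

old_style = {'layers.0.attention.query_key_value.weight': 'qkv',
             'layers.0.mlp.dense_h_to_4h.weight': 'mlp'}
print(migrate_encoder_keys(old_style))
# {'layers.0.self_attention.query_key_value.weight': 'qkv',
#  'layers.0.mlp.dense_h_to_4h.weight': 'mlp'}
```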
class TransformerLanguageModel(TransformerLanguageModelBase): class TransformerLanguageModel(TransformerLanguageModelBase):
...@@ -409,28 +469,39 @@ class TransformerLanguageModel(TransformerLanguageModelBase): ...@@ -409,28 +469,39 @@ class TransformerLanguageModel(TransformerLanguageModelBase):
""" """
def __init__(self, def __init__(self,
attention_mask_func,
init_method, init_method,
output_layer_init_method, output_layer_init_method,
encoder_attn_mask_type,
num_tokentypes=0, num_tokentypes=0,
decoder_attn_mask_type=AttnMaskType.causal,
add_decoder=False,
add_pooler=False): add_pooler=False):
super(TransformerLanguageModel, self).__init__( super(TransformerLanguageModel, self).__init__(
attention_mask_func,
init_method, init_method,
output_layer_init_method, output_layer_init_method,
encoder_attn_mask_type,
num_tokentypes=num_tokentypes, num_tokentypes=num_tokentypes,
add_decoder=add_decoder,
decoder_attn_mask_type=decoder_attn_mask_type,
add_pooler=add_pooler) add_pooler=add_pooler)
def forward(self, input_ids, position_ids, attention_mask, def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
tokentype_ids=None, layer_past=None, get_key_value=False, dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
pooling_sequence_index=0): enc_dec_attn_mask=None, tokentype_ids=None, layer_past=None,
get_key_value=False, pooling_sequence_index=0,
enc_hidden_states=None, output_enc_hidden=False):
return super(TransformerLanguageModel, self).forward( return super(TransformerLanguageModel, self).forward(
(input_ids, position_ids), (enc_input_ids, enc_position_ids),
attention_mask, enc_attn_mask,
dec_language_model_input=(dec_input_ids, dec_position_ids),
dec_attn_mask=dec_attn_mask,
enc_dec_attn_mask=enc_dec_attn_mask,
tokentype_ids=tokentype_ids, tokentype_ids=tokentype_ids,
layer_past=layer_past, layer_past=layer_past,
get_key_value=get_key_value, get_key_value=get_key_value,
pooling_sequence_index=pooling_sequence_index pooling_sequence_index=pooling_sequence_index,
enc_hidden_states=enc_hidden_states,
output_enc_hidden=output_enc_hidden
) )
...@@ -440,14 +511,14 @@ class TransformerLanguageModelFirstStage(TransformerLanguageModelBase): ...@@ -440,14 +511,14 @@ class TransformerLanguageModelFirstStage(TransformerLanguageModelBase):
""" """
def __init__(self, def __init__(self,
attention_mask_func,
init_method, init_method,
output_layer_init_method, output_layer_init_method,
encoder_attn_mask_type,
num_tokentypes=0): num_tokentypes=0):
super(TransformerLanguageModelFirstStage, self).__init__( super(TransformerLanguageModelFirstStage, self).__init__(
attention_mask_func,
init_method, init_method,
output_layer_init_method, output_layer_init_method,
encoder_attn_mask_type,
num_tokentypes=num_tokentypes) num_tokentypes=num_tokentypes)
def forward(self, input_ids, position_ids, attention_mask, def forward(self, input_ids, position_ids, attention_mask,
...@@ -467,13 +538,13 @@ class TransformerLanguageModelIntermediateStage(TransformerLanguageModelBase): ...@@ -467,13 +538,13 @@ class TransformerLanguageModelIntermediateStage(TransformerLanguageModelBase):
""" """
def __init__(self, def __init__(self,
attention_mask_func,
init_method, init_method,
output_layer_init_method): output_layer_init_method,
encoder_attn_mask_type):
super(TransformerLanguageModelIntermediateStage, self).__init__( super(TransformerLanguageModelIntermediateStage, self).__init__(
attention_mask_func,
init_method, init_method,
output_layer_init_method) output_layer_init_method,
encoder_attn_mask_type)
def forward(self, hidden_states, attention_mask, def forward(self, hidden_states, attention_mask,
layer_past=None, get_key_value=False): layer_past=None, get_key_value=False):
...@@ -491,14 +562,14 @@ class TransformerLanguageModelLastStage(TransformerLanguageModelBase): ...@@ -491,14 +562,14 @@ class TransformerLanguageModelLastStage(TransformerLanguageModelBase):
""" """
def __init__(self, def __init__(self,
attention_mask_func,
init_method, init_method,
output_layer_init_method, output_layer_init_method,
encoder_attn_mask_type,
add_pooler=False): add_pooler=False):
super(TransformerLanguageModelLastStage, self).__init__( super(TransformerLanguageModelLastStage, self).__init__(
attention_mask_func,
init_method, init_method,
output_layer_init_method, output_layer_init_method,
encoder_attn_mask_type,
add_pooler=add_pooler) add_pooler=add_pooler)
def forward(self, hidden_states, attention_mask, def forward(self, hidden_states, attention_mask,
...@@ -509,5 +580,5 @@ class TransformerLanguageModelLastStage(TransformerLanguageModelBase): ...@@ -509,5 +580,5 @@ class TransformerLanguageModelLastStage(TransformerLanguageModelBase):
attention_mask, attention_mask,
layer_past=layer_past, layer_past=layer_past,
get_key_value=get_key_value, get_key_value=get_key_value,
pooling_sequence_index=pooling_sequence_index pooling_sequence_index=pooling_sequence_index,
) )
...@@ -16,16 +16,31 @@ ...@@ -16,16 +16,31 @@
"""Megatron Module""" """Megatron Module"""
import torch import torch
from torch.autograd import Variable
from torch.nn.parameter import Parameter
from megatron import get_args from megatron import get_args
from megatron import mpu from megatron import mpu
_FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
_HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)
def param_is_not_shared(param):
return not hasattr(param, 'shared') or not param.shared
class MegatronModule(torch.nn.Module): class MegatronModule(torch.nn.Module):
"""Megatron specific extensions of torch Module.""" """Megatron specific extensions of torch Module with support
for pipelining."""
def __init__(self): def __init__(self, share_word_embeddings=True):
super(MegatronModule, self).__init__() super(MegatronModule, self).__init__()
self.share_word_embeddings = share_word_embeddings
def state_dict_for_save_checkpoint(self, destination=None, prefix='', def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False): keep_vars=False):
...@@ -34,52 +49,127 @@ class MegatronModule(torch.nn.Module): ...@@ -34,52 +49,127 @@ class MegatronModule(torch.nn.Module):
return self.state_dict(destination, prefix, keep_vars) return self.state_dict(destination, prefix, keep_vars)
class PipelinedMegatronModule(MegatronModule):
"""Pipelining specific extensions of MegatronModule."""
def __init__(self, share_word_embeddings=True):
super(PipelinedMegatronModule, self).__init__()
args = get_args()
self.share_word_embeddings = share_word_embeddings
def word_embeddings_weight(self): def word_embeddings_weight(self):
if mpu.is_pipeline_first_stage(): if mpu.is_pipeline_first_stage():
return self.language_model.embedding.word_embeddings.weight return self.language_model.embedding.word_embeddings.weight
if mpu.is_pipeline_last_stage(): if mpu.is_pipeline_last_stage():
if not self.share_word_embeddings: if not self.share_word_embeddings:
raise Exception('word_embeddings_weight() called for last stage, ' raise Exception('word_embeddings_weight() called for last '
'but share_word_embeddings is false') 'stage, but share_word_embeddings is false')
return self.word_embeddings.weight return self.word_embeddings.weight
raise Exception('word_embeddings_weight() should be ' raise Exception('word_embeddings_weight() should be '
'called for first and last stage only') 'called for first and last stage only')
def initialize_word_embeddings(self, init_method_normal): def initialize_word_embeddings(self, init_method_normal):
args = get_args() args = get_args()
if not self.share_word_embeddings: if not self.share_word_embeddings:
raise Exception('initialize_word_embeddings() was called but ' raise Exception('initialize_word_embeddings() was called but '
'share_word_embeddings is false') 'share_word_embeddings is false')
# Parameters are shared between the word embeddings layer, and the heads at
# the end of the model. In a pipelined setup with more than one stage, the # This function just initializes the word embeddings in the final stage
# initial embedding layer and the head are on different workers, so we do # when we are using pipeline parallelism. If we aren't using pipeline
# the following: # parallelism there is nothing to do.
# 1. Create a second copy of word_embeddings on the last stage, with initial if args.pipeline_model_parallel_size == 1:
# parameters of 0.0. return
# 2. Do an all-reduce between the first and last stage to ensure that the
# two copies of word_embeddings start off with the same parameter values. # Parameters are shared between the word embeddings layer, and the
# 3. In the training loop, before an all-reduce between the grads of the two # heads at the end of the model. In a pipelined setup with more than
# word_embeddings layers to ensure that every applied weight update is the # one stage, the initial embedding layer and the head are on different
# same on both stages. # workers, so we do the following:
# 1. Create a second copy of word_embeddings on the last stage, with
# initial parameters of 0.0.
# 2. Do an all-reduce between the first and last stage to ensure that
# the two copies of word_embeddings start off with the same
# parameter values.
# 3. In the training loop, before an all-reduce between the grads of
# the two word_embeddings layers to ensure that every applied weight
# update is the same on both stages.
if mpu.is_pipeline_last_stage(): if mpu.is_pipeline_last_stage():
if not mpu.is_pipeline_first_stage(): assert not mpu.is_pipeline_first_stage()
self._word_embeddings_for_head_key = 'word_embeddings_for_head' self._word_embeddings_for_head_key = 'word_embeddings_for_head'
# If first and last stages are different, set word_embeddings # set word_embeddings weights to 0 here, then copy first
# weights to 0 here, then copy first stage's weights using all_reduce # stage's weights using all_reduce below.
# below. self.word_embeddings = mpu.VocabParallelEmbedding(
self.word_embeddings = mpu.VocabParallelEmbedding( args.padded_vocab_size, args.hidden_size,
args.padded_vocab_size, args.hidden_size, init_method=init_method_normal(args.init_method_std))
init_method=init_method_normal(args.init_method_std)) self.word_embeddings.weight.data.fill_(0)
self.word_embeddings.weight.data.fill_(0) self.word_embeddings.weight.shared = True
# Ensure that first and last stages have the same initial parameter values.
if mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage(): # Ensure that first and last stages have the same initial parameter
torch.distributed.all_reduce(self.word_embeddings_weight().data, # values.
group=mpu.get_embedding_group()) if torch.distributed.is_initialized():
if mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage():
torch.distributed.all_reduce(self.word_embeddings_weight().data,
group=mpu.get_embedding_group())
else:
print("WARNING! Distributed processes aren't initialized, so "
"word embeddings in the last layer are not initialized. "
"If you are just manipulating a model this is fine, but "
"this needs to be handled manually. If you are training "
"something is definitely wrong.")
def conversion_helper(val, conversion):
"""Apply conversion to val. Recursively apply conversion if `val`
is a nested tuple/list structure."""
if not isinstance(val, (tuple, list)):
return conversion(val)
rtn = [conversion_helper(v, conversion) for v in val]
if isinstance(val, tuple):
rtn = tuple(rtn)
return rtn
def fp32_to_fp16(val):
"""Convert fp32 `val` to fp16"""
def half_conversion(val):
val_typecheck = val
if isinstance(val_typecheck, (Parameter, Variable)):
val_typecheck = val.data
if isinstance(val_typecheck, _FLOAT_TYPES):
val = val.half()
return val
return conversion_helper(val, half_conversion)
def fp16_to_fp32(val):
"""Convert fp16 `val` to fp32"""
def float_conversion(val):
val_typecheck = val
if isinstance(val_typecheck, (Parameter, Variable)):
val_typecheck = val.data
if isinstance(val_typecheck, _HALF_TYPES):
val = val.float()
return val
return conversion_helper(val, float_conversion)
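conversion_helper walks arbitrarily nested tuples and lists, so the two converters can be applied directly to whatever structure a pipeline stage passes along; non-float values are returned untouched. A quick usage example, assuming the three helpers above are in scope:

```python
import torch

batch = (torch.randn(2, 3), [torch.randn(4), 7])  # mixed, nested structure
half_batch = fp32_to_fp16(batch)                  # float tensors -> fp16
restored = fp16_to_fp32(half_batch)               # fp16 tensors -> fp32

assert half_batch[0].dtype == torch.half
assert restored[1][0].dtype == torch.float
assert restored[1][1] == 7                        # the int passes through
```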
class FP16Module(MegatronModule):
def __init__(self, module):
super(FP16Module, self).__init__()
self.add_module('module', module.half())
def forward(self, *inputs, **kwargs):
if mpu.is_pipeline_first_stage():
inputs = fp32_to_fp16(inputs)
outputs = self.module(*inputs, **kwargs)
if mpu.is_pipeline_last_stage():
outputs = fp16_to_fp32(outputs)
return outputs
def state_dict(self, destination=None, prefix='', keep_vars=False):
return self.module.state_dict(destination, prefix, keep_vars)
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
return self.module.state_dict_for_save_checkpoint(destination, prefix,
keep_vars)
def load_state_dict(self, state_dict, strict=True):
self.module.load_state_dict(state_dict, strict=strict)
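FP16Module depends on mpu pipeline-stage queries, so it only runs inside an initialized Megatron process group. The stripped-down, single-process variant below (hypothetical, not part of this merge) shows the same wrap-and-convert pattern without the pipeline logic; fp16 compute is best exercised on a CUDA device.

```python
import torch

class SimpleFP16Wrapper(torch.nn.Module):
    """Single-process sketch of the FP16Module idea: hold the wrapped module
    in half precision and convert floating-point inputs/outputs at the
    boundary."""

    def __init__(self, module):
        super().__init__()
        self.add_module('module', module.half())

    def forward(self, *inputs, **kwargs):
        inputs = tuple(
            x.half() if torch.is_tensor(x) and x.is_floating_point() else x
            for x in inputs)
        output = self.module(*inputs, **kwargs)
        if torch.is_tensor(output) and output.is_floating_point():
            return output.float()
        return output
```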
...@@ -19,15 +19,16 @@ import torch ...@@ -19,15 +19,16 @@ import torch
from megatron import get_args, print_rank_last from megatron import get_args, print_rank_last
from megatron import mpu from megatron import mpu
from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids from megatron.model.enums import AttnMaskType
from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids
from megatron.model.language_model import get_language_model from megatron.model.language_model import get_language_model
from megatron.model.utils import get_linear_layer from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal from megatron.model.utils import scaled_init_method_normal
from megatron.module import PipelinedMegatronModule from .module import MegatronModule
class MultipleChoiceBase(PipelinedMegatronModule): class MultipleChoiceBase(MegatronModule):
def __init__(self, num_tokentypes=2): def __init__(self, num_tokentypes=2):
super(MultipleChoiceBase, self).__init__(share_word_embeddings=False) super(MultipleChoiceBase, self).__init__(share_word_embeddings=False)
...@@ -36,9 +37,9 @@ class MultipleChoiceBase(PipelinedMegatronModule): ...@@ -36,9 +37,9 @@ class MultipleChoiceBase(PipelinedMegatronModule):
init_method = init_method_normal(args.init_method_std) init_method = init_method_normal(args.init_method_std)
self.language_model, self._language_model_key = get_language_model( self.language_model, self._language_model_key = get_language_model(
attention_mask_func=bert_attention_mask_func,
num_tokentypes=num_tokentypes, num_tokentypes=num_tokentypes,
add_pooler=True, add_pooler=True,
encoder_attn_mask_type=AttnMaskType.padding,
init_method=init_method, init_method=init_method,
scaled_init_method=scaled_init_method_normal(args.init_method_std, scaled_init_method=scaled_init_method_normal(args.init_method_std,
args.num_layers)) args.num_layers))
......
...@@ -4,13 +4,14 @@ import torch ...@@ -4,13 +4,14 @@ import torch
from megatron import get_args, print_rank_0 from megatron import get_args, print_rank_0
from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
from megatron.model import BertModel from megatron.model import BertModel
from megatron.module import MegatronModule from .module import MegatronModule
from megatron import mpu from megatron import mpu
from megatron.model.enums import AttnMaskType
from megatron.model.utils import get_linear_layer from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal from megatron.model.utils import init_method_normal
from megatron.model.language_model import get_language_model from megatron.model.language_model import get_language_model
from megatron.model.utils import scaled_init_method_normal from megatron.model.utils import scaled_init_method_normal
from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids
def general_ict_model_provider(only_query_model=False, only_block_model=False): def general_ict_model_provider(only_query_model=False, only_block_model=False):
...@@ -156,9 +157,9 @@ class IREncoderBertModel(MegatronModule): ...@@ -156,9 +157,9 @@ class IREncoderBertModel(MegatronModule):
args.num_layers) args.num_layers)
self.language_model, self._language_model_key = get_language_model( self.language_model, self._language_model_key = get_language_model(
attention_mask_func=bert_attention_mask_func,
num_tokentypes=num_tokentypes, num_tokentypes=num_tokentypes,
add_pooler=True, add_pooler=True,
encoder_attn_mask_type=AttnMaskType.padding,
init_method=init_method, init_method=init_method,
scaled_init_method=scaled_init_method) scaled_init_method=scaled_init_method)
......
This diff is collapsed.
...@@ -20,7 +20,6 @@ import math ...@@ -20,7 +20,6 @@ import math
import torch import torch
from megatron import get_args from megatron import get_args
from megatron.model import import_layernorm
def init_method_normal(sigma): def init_method_normal(sigma):
"""Init method based on N(0, sigma).""" """Init method based on N(0, sigma)."""
...@@ -40,6 +39,11 @@ def scaled_init_method_normal(sigma, num_layers): ...@@ -40,6 +39,11 @@ def scaled_init_method_normal(sigma, num_layers):
return init_ return init_
def attention_mask_func(attention_scores, attention_mask):
attention_scores.masked_fill_(attention_mask, -10000.0)
return attention_scores
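attention_mask_func is now defined once here in megatron/model/utils.py, replacing the per-model variants removed elsewhere in this merge (gpt2_attention_mask_func had the same body). It fills disallowed positions with a large negative value in place, before the softmax. A minimal demonstration:

```python
import torch

scores = torch.zeros(1, 1, 4, 4)                                    # [b, np, sq, sk]
causal = torch.triu(torch.ones(4, 4, dtype=torch.bool), diagonal=1)
masked = attention_mask_func(scores, causal)                        # defined above
probs = torch.nn.Softmax(dim=-1)(masked)
# probs[..., i, j] is ~0 for j > i: future positions are masked out.
```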
def get_linear_layer(rows, columns, init_method): def get_linear_layer(rows, columns, init_method):
"""Simple linear layer with weight initialization.""" """Simple linear layer with weight initialization."""
layer = torch.nn.Linear(rows, columns) layer = torch.nn.Linear(rows, columns)
...@@ -60,28 +64,3 @@ def openai_gelu(x): ...@@ -60,28 +64,3 @@ def openai_gelu(x):
@torch.jit.script @torch.jit.script
def erf_gelu(x): def erf_gelu(x):
return x * 0.5 * (torch.erf(x / 1.41421).to(dtype=x.dtype)+torch.ones_like(x).to(dtype=x.dtype)) return x * 0.5 * (torch.erf(x / 1.41421).to(dtype=x.dtype)+torch.ones_like(x).to(dtype=x.dtype))
def get_params_for_weight_decay_optimization(module):
"""Divide params into with-weight-decay and without-weight-decay groups.
Layernorms and biases will have no weight decay but the rest will.
"""
args = get_args()
LayerNorm = import_layernorm(args.fp32_residual_connection)
weight_decay_params = {'params': []}
no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
for module_ in module.modules():
if isinstance(module_, LayerNorm):
no_weight_decay_params['params'].extend(
[p for p in list(module_._parameters.values())
if p is not None])
else:
weight_decay_params['params'].extend(
[p for n, p in list(module_._parameters.items())
if p is not None and n != 'bias'])
no_weight_decay_params['params'].extend(
[p for n, p in list(module_._parameters.items())
if p is not None and n == 'bias'])
return weight_decay_params, no_weight_decay_params
This diff is collapsed.
...@@ -19,8 +19,6 @@ from .cross_entropy import vocab_parallel_cross_entropy ...@@ -19,8 +19,6 @@ from .cross_entropy import vocab_parallel_cross_entropy
from .data import broadcast_data from .data import broadcast_data
from .grads import clip_grad_norm
from .initialize import is_unitialized from .initialize import is_unitialized
from .initialize import destroy_model_parallel from .initialize import destroy_model_parallel
from .initialize import get_data_parallel_group from .initialize import get_data_parallel_group
...@@ -46,7 +44,10 @@ from .initialize import model_parallel_is_initialized ...@@ -46,7 +44,10 @@ from .initialize import model_parallel_is_initialized
from .layers import ColumnParallelLinear from .layers import ColumnParallelLinear
from .layers import RowParallelLinear from .layers import RowParallelLinear
from .layers import VocabParallelEmbedding from .layers import VocabParallelEmbedding
from .layers import (set_tensor_model_parallel_attributes,
set_defaults_if_not_set_tensor_model_parallel_attributes,
copy_tensor_model_parallel_attributes)
from .mappings import copy_to_tensor_model_parallel_region from .mappings import copy_to_tensor_model_parallel_region
from .mappings import gather_from_tensor_model_parallel_region from .mappings import gather_from_tensor_model_parallel_region
from .mappings import reduce_from_tensor_model_parallel_region from .mappings import reduce_from_tensor_model_parallel_region
......
...@@ -20,7 +20,7 @@ from .initialize import get_tensor_model_parallel_rank ...@@ -20,7 +20,7 @@ from .initialize import get_tensor_model_parallel_rank
from .initialize import get_tensor_model_parallel_src_rank from .initialize import get_tensor_model_parallel_src_rank
_MAX_DATA_DIM = 4 _MAX_DATA_DIM = 5
def _check_data_types(keys, data, target_dtype): def _check_data_types(keys, data, target_dtype):
......
...@@ -37,14 +37,54 @@ from .utils import split_tensor_along_last_dim ...@@ -37,14 +37,54 @@ from .utils import split_tensor_along_last_dim
from .utils import VocabUtility from .utils import VocabUtility
from megatron import get_args from megatron import get_args
_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {'tensor_model_parallel': False,
'partition_dim': -1,
'partition_stride': 1}
def param_is_not_tensor_parallel_duplicate(param):
return (hasattr(param, 'tensor_model_parallel') and
param.tensor_model_parallel) or (
get_tensor_model_parallel_rank() == 0)
def set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride):
# Make sure the attributes are not set.
for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
assert not hasattr(tensor, attribute)
# Set the attributes.
setattr(tensor, 'tensor_model_parallel', is_parallel)
setattr(tensor, 'partition_dim', dim)
setattr(tensor, 'partition_stride', stride)
def set_defaults_if_not_set_tensor_model_parallel_attributes(tensor):
def maybe_set(attribute, value):
if not hasattr(tensor, attribute):
setattr(tensor, attribute, value)
for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
maybe_set(attribute, _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS[attribute])
def copy_tensor_model_parallel_attributes(destination_tensor, source_tensor):
def maybe_copy(attribute):
if hasattr(source_tensor, attribute):
setattr(destination_tensor, attribute,
getattr(source_tensor, attribute))
for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
maybe_copy(attribute)
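These three helpers centralize how tensor-parallel metadata is attached to parameters; previously the attributes were set ad hoc in the weight-initialization functions (see the changes just below). A short usage sketch, assuming the helpers are importable from megatron/mpu/layers.py where they are defined:

```python
import torch

weight = torch.nn.Parameter(torch.empty(16, 8))
set_tensor_model_parallel_attributes(weight, is_parallel=True, dim=0, stride=1)
assert weight.tensor_model_parallel is True
assert weight.partition_dim == 0 and weight.partition_stride == 1

bias = torch.nn.Parameter(torch.zeros(16))
set_defaults_if_not_set_tensor_model_parallel_attributes(bias)
assert bias.tensor_model_parallel is False     # defaults: not partitioned
assert bias.partition_dim == -1

clone = torch.nn.Parameter(torch.empty(16, 8))
copy_tensor_model_parallel_attributes(clone, weight)
assert clone.partition_dim == weight.partition_dim
```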
def _initialize_affine_weight_gpu(weight, init_method, def _initialize_affine_weight_gpu(weight, init_method,
partition_dim, stride=1): partition_dim, stride=1):
"""Initialize affine weight for model parallel on GPU.""" """Initialize affine weight for model parallel on GPU."""
weight.tensor_model_parallel = True set_tensor_model_parallel_attributes(tensor=weight,
weight.partition_dim = partition_dim is_parallel=True,
weight.partition_stride = stride dim=partition_dim,
stride=stride)
with get_cuda_rng_tracker().fork(): with get_cuda_rng_tracker().fork():
init_method(weight) init_method(weight)
...@@ -58,9 +98,10 @@ def _initialize_affine_weight_cpu(weight, output_size, input_size, ...@@ -58,9 +98,10 @@ def _initialize_affine_weight_cpu(weight, output_size, input_size,
Build the master weight on all processes and scatter Build the master weight on all processes and scatter
the relevant chunk.""" the relevant chunk."""
weight.tensor_model_parallel = True set_tensor_model_parallel_attributes(tensor=weight,
weight.partition_dim = partition_dim is_parallel=True,
weight.partition_stride = stride dim=partition_dim,
stride=stride)
# Initialize master weight # Initialize master weight
master_weight = torch.empty(output_size, input_size, master_weight = torch.empty(output_size, input_size,
...@@ -74,7 +115,7 @@ def _initialize_affine_weight_cpu(weight, output_size, input_size, ...@@ -74,7 +115,7 @@ def _initialize_affine_weight_cpu(weight, output_size, input_size,
per_partition_per_stride_size = divide(per_partition_size, stride) per_partition_per_stride_size = divide(per_partition_size, stride)
weight_list = torch.split(master_weight, per_partition_per_stride_size, weight_list = torch.split(master_weight, per_partition_per_stride_size,
dim=partition_dim) dim=partition_dim)
rank = get_model_parallel_rank() rank = get_tensor_model_parallel_rank()
world_size = get_tensor_model_parallel_world_size() world_size = get_tensor_model_parallel_world_size()
my_weight_list = weight_list[rank::world_size] my_weight_list = weight_list[rank::world_size]
...@@ -225,9 +266,7 @@ class ColumnParallelLinear(torch.nn.Module): ...@@ -225,9 +266,7 @@ class ColumnParallelLinear(torch.nn.Module):
self.output_size_per_partition, self.output_size_per_partition,
device=torch.cuda.current_device(), device=torch.cuda.current_device(),
dtype=args.params_dtype)) dtype=args.params_dtype))
self.bias.tensor_model_parallel = True set_tensor_model_parallel_attributes(self.bias, True, 0, stride)
self.bias.partition_dim = 0
self.bias.stride = stride
# Always initialize bias to zero. # Always initialize bias to zero.
with torch.no_grad(): with torch.no_grad():
self.bias.zero_() self.bias.zero_()
......
This diff is collapsed.
This diff is collapsed.