Commit ebf8b89e authored by Vijay Korthikanti

attention_mask_func cleanup

parent 1aa2e08a
@@ -29,10 +29,6 @@ from megatron.model.utils import init_method_normal
 from megatron.model.utils import scaled_init_method_normal
 from .module import MegatronModule

-def bert_attention_mask_func(attention_scores, attention_mask):
-    attention_scores.masked_fill_(attention_mask, -10000.0)
-    return attention_scores

 def bert_extended_attention_mask(attention_mask):
     # We create a 3D attention mask from a 2D tensor mask.
     # [b, 1, s]
@@ -145,7 +141,6 @@ class BertModelBase(MegatronModule):
                                                       args.num_layers)
         self.language_model, self._language_model_key = get_language_model(
-            attention_mask_func=bert_attention_mask_func,
             num_tokentypes=num_tokentypes,
             add_pooler=self.add_binary_head,
             encoder_attn_mask_type=AttnMaskType.padding,
...
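For context: the unchanged bert_extended_attention_mask above builds the boolean mask that these mask functions consume, expanding a 2D padding mask toward the [b, 1, s, s] shape of the attention scores. The sketch below is an illustrative reconstruction of that kind of expansion; the function name and details are ours, not taken from this commit.

import torch

def extended_padding_mask(padding_mask):
    # padding_mask: [b, s] with 1 for real tokens, 0 for padding.
    # Returns a boolean [b, 1, s, s] mask; True marks score positions that
    # the mask function should suppress (query or key falls on padding).
    mask_b1s = padding_mask.unsqueeze(1)   # [b, 1, s]
    mask_bs1 = padding_mask.unsqueeze(2)   # [b, s, 1]
    mask_bss = mask_b1s * mask_bs1         # [b, s, s]
    return (mask_bss < 0.5).unsqueeze(1)   # [b, 1, s, s]

mask = extended_padding_mask(torch.tensor([[1, 1, 1, 0]]))
# mask[0, 0, :, 3] and mask[0, 0, 3, :] are True: the padded position is
# excluded both as a query and as a key.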
@@ -20,7 +20,7 @@ import torch
 from megatron import get_args, print_rank_last
 from megatron import mpu
 from megatron.model.enums import AttnMaskType
-from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids
+from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids
 from megatron.model.language_model import get_language_model
 from megatron.model.utils import get_linear_layer
 from megatron.model.utils import init_method_normal
@@ -38,7 +38,6 @@ class ClassificationBase(MegatronModule):
         init_method = init_method_normal(args.init_method_std)
         self.language_model, self._language_model_key = get_language_model(
-            attention_mask_func=bert_attention_mask_func,
             num_tokentypes=num_tokentypes,
             add_pooler=True,
             encoder_attn_mask_type=AttnMaskType.padding,
...
@@ -28,11 +28,6 @@ from .utils import init_method_normal
 from .utils import scaled_init_method_normal

-def gpt_attention_mask_func(attention_scores, ltor_mask):
-    attention_scores.masked_fill_(ltor_mask, -10000.0)
-    return attention_scores

 def post_language_model_processing(lm_output, labels, logit_weights,
                                    get_key_value, parallel_output,
                                    forward_method_parallel_output,
@@ -73,7 +68,6 @@ class GPTModelBase(MegatronModule):
         self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
         self.language_model, self._language_model_key = get_language_model(
-            attention_mask_func=gpt_attention_mask_func,
             num_tokentypes=num_tokentypes,
             add_pooler=False,
             encoder_attn_mask_type=AttnMaskType.causal,
...
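The GPT path now relies on the same shared helper; only the mask differs. With AttnMaskType.causal the boolean mask is left-to-right, i.e. it suppresses attention to future positions. A minimal illustrative sketch of such a mask follows (this commit does not touch how the mask itself is built, so names and shapes here are ours).

import torch

def make_ltor_mask(seq_length):
    # True strictly above the diagonal: a query at position i must not
    # attend to keys at positions j > i.
    ltor_mask = torch.triu(torch.ones(seq_length, seq_length), diagonal=1)
    return ltor_mask.bool().view(1, 1, seq_length, seq_length)

mask = make_ltor_mask(4)
# mask[0, 0] is False on and below the diagonal, True above it, so
# masked_fill_(mask, -10000.0) leaves only current and past positions.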
@@ -43,7 +43,7 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
     return mpu.gather_from_tensor_model_parallel_region(logits_parallel)

-def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
+def get_language_model(num_tokentypes, add_pooler,
                        encoder_attn_mask_type, init_method=None,
                        scaled_init_method=None, add_decoder=False,
                        decoder_attn_mask_type=AttnMaskType.causal):
@@ -58,8 +58,7 @@ def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
                                                       args.num_layers)
     # Language model.
-    args = [attention_mask_func, init_method,
-            scaled_init_method, encoder_attn_mask_type]
+    args = [init_method, scaled_init_method, encoder_attn_mask_type]
     kwargs = {}
     cls = None
     if mpu.is_pipeline_first_stage() and mpu.is_pipeline_last_stage():
@@ -269,12 +268,6 @@ class TransformerLanguageModelBase(MegatronModule):
     Arguments:
         transformer_hparams: transformer hyperparameters
-        attention_mask_func: a function that takes `unmaksed-attention-scores`
-            with size [b, np, s, s] and an `attention-mask` and will apply
-            the masking. The function should return a masked score of the
-            same size [b, np, s, s].
-               masked-attention-scores = attention_mask_func(
-                                     unmaksed-attention-scores, attention-mask)
         vocab_size: vocabulary size
         max_sequence_length: maximum size of sequence. This
             is used for positional embedding
@@ -284,7 +277,6 @@ class TransformerLanguageModelBase(MegatronModule):
     """

     def __init__(self,
-                 attention_mask_func,
                  init_method,
                  output_layer_init_method,
                  encoder_attn_mask_type,
@@ -315,7 +307,6 @@ class TransformerLanguageModelBase(MegatronModule):
         # Transformer.
         self.encoder = ParallelTransformer(
-            attention_mask_func,
             self.init_method,
             output_layer_init_method,
             self_attn_mask_type=self.encoder_attn_mask_type)
@@ -326,7 +317,6 @@ class TransformerLanguageModelBase(MegatronModule):
             assert args.pipeline_model_parallel_size == 1, \
                 'pipeline parallelism is not supported in the presence of decoder'
             self.decoder = ParallelTransformer(
-                attention_mask_func,
                 self.init_method,
                 output_layer_init_method,
                 layer_type=LayerType.decoder,
@@ -479,7 +469,6 @@ class TransformerLanguageModel(TransformerLanguageModelBase):
     """

     def __init__(self,
-                 attention_mask_func,
                  init_method,
                  output_layer_init_method,
                  encoder_attn_mask_type,
@@ -488,7 +477,6 @@ class TransformerLanguageModel(TransformerLanguageModelBase):
                  add_decoder=False,
                  add_pooler=False):
         super(TransformerLanguageModel, self).__init__(
-            attention_mask_func,
             init_method,
             output_layer_init_method,
             encoder_attn_mask_type,
@@ -523,13 +511,11 @@ class TransformerLanguageModelFirstStage(TransformerLanguageModelBase):
     """

     def __init__(self,
-                 attention_mask_func,
                  init_method,
                  output_layer_init_method,
                  encoder_attn_mask_type,
                  num_tokentypes=0):
         super(TransformerLanguageModelFirstStage, self).__init__(
-            attention_mask_func,
             init_method,
             output_layer_init_method,
             encoder_attn_mask_type,
@@ -552,12 +538,10 @@ class TransformerLanguageModelIntermediateStage(TransformerLanguageModelBase):
     """

     def __init__(self,
-                 attention_mask_func,
                  init_method,
                  output_layer_init_method,
                  encoder_attn_mask_type):
         super(TransformerLanguageModelIntermediateStage, self).__init__(
-            attention_mask_func,
             init_method,
             output_layer_init_method,
             encoder_attn_mask_type)
@@ -578,13 +562,11 @@ class TransformerLanguageModelLastStage(TransformerLanguageModelBase):
     """

     def __init__(self,
-                 attention_mask_func,
                  init_method,
                  output_layer_init_method,
                  encoder_attn_mask_type,
                  add_pooler=False):
         super(TransformerLanguageModelLastStage, self).__init__(
-            attention_mask_func,
             init_method,
             output_layer_init_method,
             encoder_attn_mask_type,
...
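The docstring removed above described the contract that now lives implicitly in megatron.model.utils.attention_mask_func: take [b, np, s, s] scores plus an attention mask and return masked scores of the same size. Note that the helper works in place (masked_fill_ with a trailing underscore), so the returned tensor is the same object as the input. A quick self-contained check, with illustrative shapes:

import torch

def attention_mask_func(attention_scores, attention_mask):
    attention_scores.masked_fill_(attention_mask, -10000.0)
    return attention_scores

scores = torch.randn(2, 4, 8, 8)                     # [b, np, s, s]
mask = torch.zeros(2, 1, 8, 8, dtype=torch.bool)     # broadcast over heads
out = attention_mask_func(scores, mask)
assert out.shape == scores.shape and out is scores   # same size, in place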
@@ -20,7 +20,7 @@ import torch
 from megatron import get_args, print_rank_last
 from megatron import mpu
 from megatron.model.enums import AttnMaskType
-from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids
+from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids
 from megatron.model.language_model import get_language_model
 from megatron.model.utils import get_linear_layer
 from megatron.model.utils import init_method_normal
@@ -37,7 +37,6 @@ class MultipleChoiceBase(MegatronModule):
         init_method = init_method_normal(args.init_method_std)
         self.language_model, self._language_model_key = get_language_model(
-            attention_mask_func=bert_attention_mask_func,
             num_tokentypes=num_tokentypes,
             add_pooler=True,
             encoder_attn_mask_type=AttnMaskType.padding,
...
@@ -11,7 +11,7 @@ from megatron.model.utils import get_linear_layer
 from megatron.model.utils import init_method_normal
 from megatron.model.language_model import get_language_model
 from megatron.model.utils import scaled_init_method_normal
-from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids
+from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids

 def general_ict_model_provider(only_query_model=False, only_block_model=False):
@@ -157,7 +157,6 @@ class IREncoderBertModel(MegatronModule):
                                                       args.num_layers)
         self.language_model, self._language_model_key = get_language_model(
-            attention_mask_func=bert_attention_mask_func,
             num_tokentypes=num_tokentypes,
             add_pooler=True,
             encoder_attn_mask_type=AttnMaskType.padding,
...
@@ -26,7 +26,7 @@ from megatron.model.enums import AttnMaskType, LayerType, AttnType
 from megatron.model import import_layernorm
 from megatron.model.fused_softmax import FusedScaleMaskSoftmax
 from megatron.model.fused_bias_gelu import bias_gelu_impl
-from megatron.model.utils import openai_gelu, erf_gelu
+from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu

 # flags required to enable jit fusion kernels
 torch._C._jit_set_profiling_mode(False)
@@ -47,12 +47,6 @@ torch._C._jit_override_can_fuse_on_gpu(True)
     Transformer takes input of size [s, b, h] and returns a
     tensor of the same size. We use the following arguments:
         hyperparameters: transformer hyperparameters
-        attention_mask_func: a function that takes `unmaksed-attention-scores`
-            with size [b, np, s, s] and an `attention-mask` and will apply
-            the masking. The function should return a masked score of the
-            same size [b, np, s, s].
-               masked-attention-scores = attention_mask_func(
-                                     unmaksed-attention-scores, attention-mask)
 """

 class ParallelMLP(MegatronModule):
@@ -115,7 +109,7 @@ class ParallelAttention(MegatronModule):
     and returns output of the same size.
     """

-    def __init__(self, attention_mask_func, init_method,
+    def __init__(self, init_method,
                  output_layer_init_method, layer_number,
                  attention_type=AttnType.self_attn,
                  attn_mask_type=AttnMaskType.padding):
@@ -123,7 +117,6 @@ class ParallelAttention(MegatronModule):
         args = get_args()
         self.fp16 = args.fp16
-        self.attention_mask_func = attention_mask_func
         self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
         self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
         if self.apply_query_key_layer_scaling:
@@ -174,7 +167,7 @@ class ParallelAttention(MegatronModule):
             self.fp16,
             self.attn_mask_type,
             args.masked_softmax_fusion,
-            self.attention_mask_func,
+            attention_mask_func,
             self.attention_softmax_in_fp32,
             coeff)
@@ -440,9 +433,8 @@ class ParallelTransformerLayer(MegatronModule):
     output of the same size.
     """

-    def __init__(self, attention_mask_func, init_method,
-                 output_layer_init_method, layer_number,
-                 layer_type=LayerType.encoder,
+    def __init__(self, init_method, output_layer_init_method,
+                 layer_number, layer_type=LayerType.encoder,
                  self_attn_mask_type=AttnMaskType.padding):
         args = get_args()
@@ -461,7 +453,6 @@ class ParallelTransformerLayer(MegatronModule):
         # Self attention.
         self.self_attention = ParallelAttention(
-            attention_mask_func,
             init_method,
             output_layer_init_method,
             layer_number,
@@ -477,7 +468,6 @@ class ParallelTransformerLayer(MegatronModule):
         if self.layer_type == LayerType.decoder:
             self.inter_attention = ParallelAttention(
-                attention_mask_func,
                 init_method,
                 output_layer_init_method,
                 layer_number,
@@ -585,8 +575,7 @@ class ParallelTransformerLayer(MegatronModule):
 class ParallelTransformer(MegatronModule):
     """Transformer class."""

-    def __init__(self, attention_mask_func,
-                 init_method, output_layer_init_method,
+    def __init__(self, init_method, output_layer_init_method,
                  layer_type=LayerType.encoder,
                  self_attn_mask_type=AttnMaskType.padding):
         super(ParallelTransformer, self).__init__()
@@ -606,8 +595,9 @@ class ParallelTransformer(MegatronModule):
         # Transformer layers.
         def build_layer(layer_number):
             return ParallelTransformerLayer(
-                attention_mask_func, init_method,
-                output_layer_init_method, layer_number,
+                init_method,
+                output_layer_init_method,
+                layer_number,
                 layer_type=layer_type,
                 self_attn_mask_type=self_attn_mask_type)
         offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers
...
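After this change, ParallelAttention hands the module-level attention_mask_func straight to FusedScaleMaskSoftmax instead of storing it on the instance. As a rough, unfused sketch of the scale-mask-softmax pattern that class wraps (the fused CUDA kernels are not reproduced here, and the scale value and helper name below are illustrative):

import torch

def scale_mask_softmax(scores, mask, mask_func, scale=None):
    # scale -> mask -> softmax, with the mask function injected by the caller,
    # just as attention_mask_func is passed into FusedScaleMaskSoftmax above.
    if scale is not None:
        scores = scores * scale
    if mask is not None:
        scores = mask_func(scores, mask)
    return torch.nn.functional.softmax(scores, dim=-1)

probs = scale_mask_softmax(
    torch.randn(2, 4, 8, 8),                    # [b, np, s, s] raw scores
    torch.zeros(2, 1, 8, 8, dtype=torch.bool),  # broadcastable boolean mask
    lambda scores, mask: scores.masked_fill_(mask, -10000.0),
    scale=0.5)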
@@ -39,6 +39,11 @@ def scaled_init_method_normal(sigma, num_layers):
     return init_

+def attention_mask_func(attention_scores, attention_mask):
+    attention_scores.masked_fill_(attention_mask, -10000.0)
+    return attention_scores

 def get_linear_layer(rows, columns, init_method):
     """Simple linear layer with weight initialization."""
     layer = torch.nn.Linear(rows, columns)
...
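With the helper now living in megatron/model/utils.py, every model shares one definition. A quick self-contained check of what it does (the helper body is copied from the hunk above; the test values are ours):

import torch

def attention_mask_func(attention_scores, attention_mask):
    attention_scores.masked_fill_(attention_mask, -10000.0)
    return attention_scores

scores = torch.zeros(1, 1, 1, 4)
mask = torch.tensor([[[[False, False, True, True]]]])
probs = torch.softmax(attention_mask_func(scores, mask), dim=-1)
# probs ~= [0.5, 0.5, 0.0, 0.0]: masked keys get (numerically) zero weight.
# -10000.0 is used rather than -inf; it stays finite and is fp16-representable.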