"vscode:/vscode.git/clone" did not exist on "db90effe0ccd1f49c1baab7592c70ee6c7857e45"
Commit ebf8b89e authored by Vijay Korthikanti

attention_mask_func cleanup

parent 1aa2e08a
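
In short, this cleanup deletes the duplicated bert_attention_mask_func and gpt_attention_mask_func helpers, adds a single attention_mask_func to megatron.model.utils, and stops threading the callable through get_language_model, the TransformerLanguageModel* constructors, ParallelTransformer, ParallelTransformerLayer and ParallelAttention; the attention code now imports the shared helper directly. Below is a minimal, self-contained sketch of what the shared helper does: the function body matches the utils hunk at the end of this diff, while the toy tensor shapes and the causal mask are illustrative assumptions only.

import torch

def attention_mask_func(attention_scores, attention_mask):
    # Fill masked positions with a large negative value so softmax zeroes them out.
    attention_scores.masked_fill_(attention_mask, -10000.0)
    return attention_scores

# Toy [b, np, sq, sk] scores and a causal mask; True marks positions that must
# not be attended to.
scores = torch.zeros(1, 1, 4, 4)
mask = torch.triu(torch.ones(4, 4, dtype=torch.bool), diagonal=1).view(1, 1, 4, 4)
probs = torch.softmax(attention_mask_func(scores, mask), dim=-1)
print(probs[0, 0])  # row i spreads probability evenly over positions 0..i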
@@ -29,10 +29,6 @@ from megatron.model.utils import init_method_normal
 from megatron.model.utils import scaled_init_method_normal
 from .module import MegatronModule

-def bert_attention_mask_func(attention_scores, attention_mask):
-    attention_scores.masked_fill_(attention_mask, -10000.0)
-    return attention_scores

 def bert_extended_attention_mask(attention_mask):
     # We create a 3D attention mask from a 2D tensor mask.
     # [b, 1, s]
@@ -145,7 +141,6 @@ class BertModelBase(MegatronModule):
             args.num_layers)

         self.language_model, self._language_model_key = get_language_model(
-            attention_mask_func=bert_attention_mask_func,
             num_tokentypes=num_tokentypes,
             add_pooler=self.add_binary_head,
             encoder_attn_mask_type=AttnMaskType.padding,
......
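
The first hunk above is truncated right after bert_extended_attention_mask begins. For context, here is a hypothetical sketch (not the exact Megatron-LM implementation; the helper name and shapes are assumptions) of how a 2D padding mask is commonly expanded into the boolean [b, 1, s, s] mask that attention_mask_func consumes:

import torch

def extended_attention_mask_sketch(padding_mask_2d):
    # padding_mask_2d: [b, s], 1 for real tokens and 0 for padding.
    mask_b1s = padding_mask_2d.unsqueeze(1)        # [b, 1, s]
    mask_bs1 = padding_mask_2d.unsqueeze(2)        # [b, s, 1]
    mask_bss = mask_b1s * mask_bs1                 # [b, s, s]
    # True marks positions that attention_mask_func will fill with -10000.0.
    return (mask_bss.unsqueeze(1) < 0.5)           # [b, 1, s, s], bool

print(extended_attention_mask_sketch(torch.tensor([[1, 1, 1, 0]])).shape)
# torch.Size([1, 1, 4, 4])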
@@ -20,7 +20,7 @@ import torch
 from megatron import get_args, print_rank_last
 from megatron import mpu
 from megatron.model.enums import AttnMaskType
-from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids
+from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids
 from megatron.model.language_model import get_language_model
 from megatron.model.utils import get_linear_layer
 from megatron.model.utils import init_method_normal
@@ -38,7 +38,6 @@ class ClassificationBase(MegatronModule):
         init_method = init_method_normal(args.init_method_std)

         self.language_model, self._language_model_key = get_language_model(
-            attention_mask_func=bert_attention_mask_func,
             num_tokentypes=num_tokentypes,
             add_pooler=True,
             encoder_attn_mask_type=AttnMaskType.padding,
......
@@ -28,11 +28,6 @@ from .utils import init_method_normal
 from .utils import scaled_init_method_normal

-def gpt_attention_mask_func(attention_scores, ltor_mask):
-    attention_scores.masked_fill_(ltor_mask, -10000.0)
-    return attention_scores

 def post_language_model_processing(lm_output, labels, logit_weights,
                                    get_key_value, parallel_output,
                                    forward_method_parallel_output,
@@ -73,7 +68,6 @@ class GPTModelBase(MegatronModule):
         self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy

         self.language_model, self._language_model_key = get_language_model(
-            attention_mask_func=gpt_attention_mask_func,
             num_tokentypes=num_tokentypes,
             add_pooler=False,
             encoder_attn_mask_type=AttnMaskType.causal,
......
@@ -43,7 +43,7 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
     return mpu.gather_from_tensor_model_parallel_region(logits_parallel)

-def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
+def get_language_model(num_tokentypes, add_pooler,
                        encoder_attn_mask_type, init_method=None,
                        scaled_init_method=None, add_decoder=False,
                        decoder_attn_mask_type=AttnMaskType.causal):
@@ -58,8 +58,7 @@ def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
             args.num_layers)

     # Language model.
-    args = [attention_mask_func, init_method,
-            scaled_init_method, encoder_attn_mask_type]
+    args = [init_method, scaled_init_method, encoder_attn_mask_type]
     kwargs = {}
     cls = None
     if mpu.is_pipeline_first_stage() and mpu.is_pipeline_last_stage():
@@ -269,12 +268,6 @@ class TransformerLanguageModelBase(MegatronModule):

     Arguments:
         transformer_hparams: transformer hyperparameters
-        attention_mask_func: a function that takes `unmaksed-attention-scores`
-            with size [b, np, s, s] and an `attention-mask` and will apply
-            the masking. The function should return a masked score of the
-            same size [b, np, s, s].
-                masked-attention-scores = attention_mask_func(
-                    unmaksed-attention-scores, attention-mask)
         vocab_size: vocabulary size
         max_sequence_length: maximum size of sequence. This
             is used for positional embedding
@@ -284,7 +277,6 @@ class TransformerLanguageModelBase(MegatronModule):
     """

     def __init__(self,
-                 attention_mask_func,
                  init_method,
                  output_layer_init_method,
                  encoder_attn_mask_type,
@@ -315,7 +307,6 @@ class TransformerLanguageModelBase(MegatronModule):

         # Transformer.
         self.encoder = ParallelTransformer(
-            attention_mask_func,
             self.init_method,
             output_layer_init_method,
             self_attn_mask_type=self.encoder_attn_mask_type)
@@ -326,7 +317,6 @@ class TransformerLanguageModelBase(MegatronModule):
             assert args.pipeline_model_parallel_size == 1, \
                 'pipeline parallelism is not supported in the presence of decoder'
             self.decoder = ParallelTransformer(
-                attention_mask_func,
                 self.init_method,
                 output_layer_init_method,
                 layer_type=LayerType.decoder,
@@ -479,7 +469,6 @@ class TransformerLanguageModel(TransformerLanguageModelBase):
     """

     def __init__(self,
-                 attention_mask_func,
                  init_method,
                  output_layer_init_method,
                  encoder_attn_mask_type,
@@ -488,7 +477,6 @@ class TransformerLanguageModel(TransformerLanguageModelBase):
                  add_decoder=False,
                  add_pooler=False):
         super(TransformerLanguageModel, self).__init__(
-            attention_mask_func,
             init_method,
             output_layer_init_method,
             encoder_attn_mask_type,
@@ -523,13 +511,11 @@ class TransformerLanguageModelFirstStage(TransformerLanguageModelBase):
     """

     def __init__(self,
-                 attention_mask_func,
                  init_method,
                  output_layer_init_method,
                  encoder_attn_mask_type,
                  num_tokentypes=0):
         super(TransformerLanguageModelFirstStage, self).__init__(
-            attention_mask_func,
             init_method,
             output_layer_init_method,
             encoder_attn_mask_type,
@@ -552,12 +538,10 @@ class TransformerLanguageModelIntermediateStage(TransformerLanguageModelBase):
     """

     def __init__(self,
-                 attention_mask_func,
                  init_method,
                  output_layer_init_method,
                  encoder_attn_mask_type):
         super(TransformerLanguageModelIntermediateStage, self).__init__(
-            attention_mask_func,
             init_method,
             output_layer_init_method,
             encoder_attn_mask_type)
@@ -578,13 +562,11 @@ class TransformerLanguageModelLastStage(TransformerLanguageModelBase):
     """

     def __init__(self,
-                 attention_mask_func,
                  init_method,
                  output_layer_init_method,
                  encoder_attn_mask_type,
                  add_pooler=False):
         super(TransformerLanguageModelLastStage, self).__init__(
-            attention_mask_func,
             init_method,
             output_layer_init_method,
             encoder_attn_mask_type,
......
@@ -20,7 +20,7 @@ import torch
 from megatron import get_args, print_rank_last
 from megatron import mpu
 from megatron.model.enums import AttnMaskType
-from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids
+from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids
 from megatron.model.language_model import get_language_model
 from megatron.model.utils import get_linear_layer
 from megatron.model.utils import init_method_normal
@@ -37,7 +37,6 @@ class MultipleChoiceBase(MegatronModule):
         init_method = init_method_normal(args.init_method_std)

         self.language_model, self._language_model_key = get_language_model(
-            attention_mask_func=bert_attention_mask_func,
             num_tokentypes=num_tokentypes,
             add_pooler=True,
             encoder_attn_mask_type=AttnMaskType.padding,
......
@@ -11,7 +11,7 @@ from megatron.model.utils import get_linear_layer
 from megatron.model.utils import init_method_normal
 from megatron.model.language_model import get_language_model
 from megatron.model.utils import scaled_init_method_normal
-from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids
+from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids

 def general_ict_model_provider(only_query_model=False, only_block_model=False):
@@ -157,7 +157,6 @@ class IREncoderBertModel(MegatronModule):
             args.num_layers)

         self.language_model, self._language_model_key = get_language_model(
-            attention_mask_func=bert_attention_mask_func,
             num_tokentypes=num_tokentypes,
             add_pooler=True,
             encoder_attn_mask_type=AttnMaskType.padding,
......
@@ -26,7 +26,7 @@ from megatron.model.enums import AttnMaskType, LayerType, AttnType
 from megatron.model import import_layernorm
 from megatron.model.fused_softmax import FusedScaleMaskSoftmax
 from megatron.model.fused_bias_gelu import bias_gelu_impl
-from megatron.model.utils import openai_gelu, erf_gelu
+from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu

 # flags required to enable jit fusion kernels
 torch._C._jit_set_profiling_mode(False)
@@ -47,12 +47,6 @@ torch._C._jit_override_can_fuse_on_gpu(True)
     Transformer takes input of size [s, b, h] and returns a
     tensor of the same size. We use the following arguments:
         hyperparameters: transformer hyperparameters
-        attention_mask_func: a function that takes `unmaksed-attention-scores`
-            with size [b, np, s, s] and an `attention-mask` and will apply
-            the masking. The function should return a masked score of the
-            same size [b, np, s, s].
-                masked-attention-scores = attention_mask_func(
-                    unmaksed-attention-scores, attention-mask)
 """

 class ParallelMLP(MegatronModule):
@@ -115,7 +109,7 @@ class ParallelAttention(MegatronModule):
     and returns output of the same size.
     """

-    def __init__(self, attention_mask_func, init_method,
+    def __init__(self, init_method,
                  output_layer_init_method, layer_number,
                  attention_type=AttnType.self_attn,
                  attn_mask_type=AttnMaskType.padding):
@@ -123,7 +117,6 @@ class ParallelAttention(MegatronModule):
         args = get_args()
         self.fp16 = args.fp16
-        self.attention_mask_func = attention_mask_func
         self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
         self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
         if self.apply_query_key_layer_scaling:
@@ -174,7 +167,7 @@
             self.fp16,
             self.attn_mask_type,
             args.masked_softmax_fusion,
-            self.attention_mask_func,
+            attention_mask_func,
             self.attention_softmax_in_fp32,
             coeff)
@@ -440,9 +433,8 @@ class ParallelTransformerLayer(MegatronModule):
     output of the same size.
     """

-    def __init__(self, attention_mask_func, init_method,
-                 output_layer_init_method, layer_number,
-                 layer_type=LayerType.encoder,
+    def __init__(self, init_method, output_layer_init_method,
+                 layer_number, layer_type=LayerType.encoder,
                  self_attn_mask_type=AttnMaskType.padding):
         args = get_args()
@@ -461,7 +453,6 @@

         # Self attention.
         self.self_attention = ParallelAttention(
-            attention_mask_func,
             init_method,
             output_layer_init_method,
             layer_number,
@@ -477,7 +468,6 @@
         if self.layer_type == LayerType.decoder:
             self.inter_attention = ParallelAttention(
-                attention_mask_func,
                 init_method,
                 output_layer_init_method,
                 layer_number,
@@ -585,8 +575,7 @@

 class ParallelTransformer(MegatronModule):
     """Transformer class."""

-    def __init__(self, attention_mask_func,
-                 init_method, output_layer_init_method,
+    def __init__(self, init_method, output_layer_init_method,
                  layer_type=LayerType.encoder,
                  self_attn_mask_type=AttnMaskType.padding):
         super(ParallelTransformer, self).__init__()
@@ -606,8 +595,9 @@
         # Transformer layers.
         def build_layer(layer_number):
             return ParallelTransformerLayer(
-                attention_mask_func, init_method,
-                output_layer_init_method, layer_number,
+                init_method,
+                output_layer_init_method,
+                layer_number,
                 layer_type=layer_type,
                 self_attn_mask_type=self_attn_mask_type)
         offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers
......
@@ -39,6 +39,11 @@ def scaled_init_method_normal(sigma, num_layers):
     return init_

+def attention_mask_func(attention_scores, attention_mask):
+    attention_scores.masked_fill_(attention_mask, -10000.0)
+    return attention_scores

 def get_linear_layer(rows, columns, init_method):
     """Simple linear layer with weight initialization."""
     layer = torch.nn.Linear(rows, columns)
......