Unverified Commit 16208b3b authored by zlsh80826, committed by GitHub

[JAX] Add self_attn_mask_type and replace attn_type (#273)



* Add self_attn_mask_type and replace attn_type
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Refine the keyword style for better readability
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Replace attn_type with attn_mask_type in praxis transformer
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Fix typos
Signed-off-by: Reese Wang <rewang@nvidia.com>

---------
Signed-off-by: Reese Wang <rewang@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 6cd5b128
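
For context, a minimal usage sketch of the renamed arguments (illustrative only: `attn_mask_type` and `self_attn_mask_type` are the keywords introduced by this change, while the other constructor arguments and values shown are assumptions, not taken from this commit):

from transformer_engine.jax.flax import MultiHeadAttention, TransformerLayer

# Before this change: MultiHeadAttention(..., attn_type=AttentionType.PADDING)
# After this change the mask type is a plain string:
mha = MultiHeadAttention(head_dim=64, num_heads=8, attn_mask_type='padding')

# TransformerLayer gains self_attn_mask_type; until 0.11 an unset value
# defaults to 'padding' for encoders and 'causal' for decoders
# (see the FutureWarning added below).
layer = TransformerLayer(hidden_size=512, self_attn_mask_type='causal')
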
......@@ -20,7 +20,6 @@ from transformer_engine.jax.flax import MultiHeadAttention as flax_MultiHeadAtte
from transformer_engine.jax.flax import RelativePositionBiases as flax_RelativePositionBiases
from transformer_engine.jax.flax import TransformerLayer as flax_TransformerLayer
from transformer_engine.jax.flax.module import Softmax
from transformer_engine.jax.flax.transformer import AttentionType
from transformer_engine.jax.fp8 import FP8Helper, is_fp8_available
from transformer_engine.jax.praxis import LayerNorm
from transformer_engine.jax.praxis import FusedSoftmax, LayerNorm
......@@ -666,32 +665,32 @@ class MultiHeadAttnAttr:
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: False,
ATTN_TYPE: AttentionType.PADDING
ATTN_TYPE: 'padding'
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: True,
ATTN_TYPE: AttentionType.PADDING
ATTN_TYPE: 'padding'
}, {
USE_BIAS: True,
LN_TYPE: 'rmsnorm',
ZERO_CEN: False,
ATTN_TYPE: AttentionType.PADDING
ATTN_TYPE: 'padding'
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: False,
ATTN_TYPE: AttentionType.CAUSAL
ATTN_TYPE: 'causal'
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: True,
ATTN_TYPE: AttentionType.CAUSAL
ATTN_TYPE: 'causal'
}, {
USE_BIAS: True,
LN_TYPE: 'rmsnorm',
ZERO_CEN: False,
ATTN_TYPE: AttentionType.CAUSAL
ATTN_TYPE: 'causal'
}]
......
......@@ -197,17 +197,16 @@ def core_attention(query: Array,
dynamic_vector_slice_in_dim = vmap(lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None))
class AttentionType(Enum):
"""TransformerLayerType."""
PADDING = AttnMaskType.PADDING_MASK
CAUSAL = AttnMaskType.CAUSAL_MASK
class MultiHeadAttention(nn.Module):
r"""
Multi-head Attention (MHA), including Query,
Key, Value and Output projection.
.. warning::
Argument :attr:`attn_type` is deprecated and superseded by :attr:`attn_mask_type`.
:attr:`attn_type` is ignored in version 0.10 and will be fully removed in version 0.11.
Parameters
----------
head_dim : int
......@@ -245,8 +244,11 @@ class MultiHeadAttention(nn.Module):
Indicate if apply residual connection with the output of layer normalization.
output_layernorm : bool, default = False
Indicate if apply a layer normalization at the end of MHA.
attn_type: AttentionType, defult = AttentionType.PADDING
Indicate the format of the attention mask in the core attention.
attn_type: Any, default = None
*Deprecated*, will be ignored in v0.10 and fully removed in v0.11.
Please use `attn_mask_type` to configure the attention mask.
attn_mask_type: {'causal', 'padding'}, default = 'causal'
Type of attention mask passed into softmax operation.
Optimization parameters
-----------------------
......@@ -282,7 +284,9 @@ class MultiHeadAttention(nn.Module):
bias_init: Initializer = nn.initializers.zeros
apply_residual_connection_post_layernorm: bool = False
output_layernorm: bool = False
attn_type: AttentionType = AttentionType.PADDING
# TODO(rewang): remove attn_type and the related doc after v0.11
attn_type: Any = None
attn_mask_type: str = 'causal'
dtype: DType = jnp.float32
fuse_qkv: bool = True
transpose_batch_sequence: bool = True
......@@ -293,6 +297,14 @@ class MultiHeadAttention(nn.Module):
def __post_init__(self):
if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'normal')
# TODO(rewang): remove attn_type after v0.11
if self.attn_type is not None:
warnings.warn(
"The 'attn_type' argument in the 'MultiHeadAttention' is"
" deprecated in version 0.10 and will be removed in version 0.11."
" Passing value in attn_type will be ignored, please use `attn_mask_type`"
" to config the attention mask type.",
category=DeprecationWarning)
super().__post_init__()
@nn.compact
......@@ -570,9 +582,23 @@ class MultiHeadAttention(nn.Module):
if use_fused_attn:
assert mask is not None and mask.ndim == 4 # (b, 1, s_q, s_kv)
assert not self.transpose_batch_sequence
# TODO(rewang): make it configurable for pre_scale_bias
attn_bias_type = AttnBiasType.NO_BIAS if bias is None else AttnBiasType.POST_SCALE_BIAS
def canonicalize_attn_mask_type(attn_mask_type):
"""
Convert the string to AttnMaskType
"""
if attn_mask_type == 'causal':
return AttnMaskType.CAUSAL_MASK
if attn_mask_type == 'padding':
return AttnMaskType.PADDING_MASK
raise ValueError(f"Unsupported {attn_mask_type=}, "
"supported attn_mask_type = {'causal', 'padding'}")
attn_mask_type = canonicalize_attn_mask_type(self.attn_mask_type)
if inputs_q is inputs_kv:
qkv_proj = qkv_proj.reshape((*qkv_proj.shape[:-1], self.num_heads, self.head_dim))
qkv_sharding_constraint = ('batch', 'length', 'qkv_dim', 'heads', 'kv')
......@@ -583,7 +609,7 @@ class MultiHeadAttention(nn.Module):
mask,
dropout_rng,
attn_bias_type=attn_bias_type,
attn_mask_type=self.attn_type.value,
attn_mask_type=attn_mask_type,
scaling_factor=scale_factor,
dropout_probability=self.dropout_rate,
is_training=not deterministic,
......@@ -602,18 +628,27 @@ class MultiHeadAttention(nn.Module):
mask,
dropout_rng,
attn_bias_type=attn_bias_type,
attn_mask_type=self.attn_type.value,
attn_mask_type=attn_mask_type,
scaling_factor=scale_factor,
dropout_probability=self.dropout_rate,
is_training=not deterministic,
sharding_type=first_sharding_type)
else:
softmax_type = SoftmaxType.SCALED
if self.attn_type is AttentionType.PADDING:
if mask is not None:
softmax_type = SoftmaxType.SCALED_MASKED
else:
softmax_type = SoftmaxType.SCALED_UPPER_TRIANG_MASKED
def convert_to_softmax_type(attn_mask_type, mask):
"""
Convert the string to SoftmaxType
"""
if attn_mask_type == 'causal':
return SoftmaxType.SCALED_UPPER_TRIANG_MASKED
if attn_mask_type == 'padding':
if mask is not None:
return SoftmaxType.SCALED_MASKED
return SoftmaxType.SCALED
raise ValueError(f"Unsupported {attn_mask_type=}, "
"supported attn_mask_type = {'causal', 'padding'}")
softmax_type = convert_to_softmax_type(self.attn_mask_type, mask)
x = core_attention(query,
key,
......@@ -765,6 +800,18 @@ class TransformerLayer(nn.Module):
an attention block and a feedforward network (MLP).
This standard layer is based on the paper “Attention Is All You Need”.
.. warning::
Argument :attr:`self_attn_mask_type` is introduced in version 0.10.
Starting from version 0.11, the default value will be `"causal"`.
However, to ensure compatibility with earlier versions, before 0.11,
the default value will be `"padding"` for the encoder and `"causal"` for the decoder.
.. note::
Argument :attr:`attention_mask` will be ignored when
:attr:`self_attn_mask_type` is set to `"causal"`.
Parameters
----------
hidden_size: int, default = 512
......@@ -825,6 +872,8 @@ class TransformerLayer(nn.Module):
If set to TransformerLayerType.DECODER, an additional cross-attention block
is added after self-attention. This can be used for structures like `T5`
Transformer in conjunction with the TransformerLayerType.ENCODER option.
self_attn_mask_type: {'causal', 'padding'}, default = 'causal'
Type of attention mask passed into softmax operation.
enable_relative_embedding: bool, default = True
Whether to enable relative embedding as shifting of attention logits.
relative_embedding: flax.linen.Module, default = None
......@@ -878,6 +927,7 @@ class TransformerLayer(nn.Module):
output_layernorm: bool = False
float32_attention_logits: bool = False
layer_type: TransformerLayerType = TransformerLayerType.ENCODER
self_attn_mask_type: str = None # TODO(rewang): default to 'causal' after 0.11
enable_relative_embedding: bool = True
relative_embedding: nn.Module = None
dtype: DType = jnp.float32
......@@ -893,6 +943,19 @@ class TransformerLayer(nn.Module):
if self.mlp_kernel_init is None:
self.mlp_kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in',
'truncated_normal')
# TODO(rewang): default to 'causal' in 0.11 (also update the doc after 0.11)
if self.self_attn_mask_type is None:
warnings.warn(
"The 'self_attn_mask_type' argument in the 'TransformerLayer' is"
" introduced in version 0.10. Starting from version 0.11, the default"
" value will be 'causal'. However, to ensure compatibility with earlier"
" versions, before 0.11, the default value will be 'padding' for the"
" encoder and 'causal' for the decoder.",
category=FutureWarning)
if self.layer_type == TransformerLayerType.ENCODER:
self.self_attn_mask_type = 'padding'
else:
self.self_attn_mask_type = 'causal'
super().__post_init__()
@nn.compact
......@@ -975,16 +1038,12 @@ class TransformerLayer(nn.Module):
assert inputs.ndim == 3
self_attn_type = None
# Make the name exactly the same as in T5X, since names affect the
# RNGKey during init and apply. Maybe not needed in the future.
if self.layer_type == TransformerLayerType.ENCODER:
mha_name = 'attention'
self_attn_type = AttentionType.PADDING
else:
mha_name = 'self_attention'
self_attn_type = AttentionType.CAUSAL
assert self_attn_type is not None
# [batch, length, emb_dim] -> [batch, length, emb_dim]
x, residual = MultiHeadAttention(
......@@ -1002,7 +1061,7 @@ class TransformerLayer(nn.Module):
zero_centered_gamma=self.zero_centered_gamma,
apply_residual_connection_post_layernorm=self.apply_residual_connection_post_layernorm,
output_layernorm=self.output_layernorm,
attn_type=self_attn_type,
attn_mask_type=self.self_attn_mask_type,
fuse_qkv=self.fuse_qkv_params,
kernel_init=self.mha_kernel_init,
use_bias=self.use_bias,
......@@ -1049,7 +1108,7 @@ class TransformerLayer(nn.Module):
apply_residual_connection_post_layernorm=self.
apply_residual_connection_post_layernorm,
output_layernorm=False, # Must do LayerNorm before MHA.
attn_type=AttentionType.PADDING,
attn_mask_type='padding',
float32_logits=self.float32_attention_logits,
scale_attn_logits=self.scale_attn_logits,
scaled_query_init=self.scaled_query_init,
......
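
Taken together, the two helpers added above reduce to a small string dispatch. The following is a condensed, self-contained sketch, not library code: the enums below are hypothetical stand-ins for the AttnMaskType and SoftmaxType referenced in the diff, and the branching mirrors canonicalize_attn_mask_type and convert_to_softmax_type.

from enum import Enum, auto

class MaskType(Enum):       # stand-in for the library's AttnMaskType
    CAUSAL_MASK = auto()
    PADDING_MASK = auto()

class SmType(Enum):         # stand-in for the library's SoftmaxType
    SCALED = auto()
    SCALED_MASKED = auto()
    SCALED_UPPER_TRIANG_MASKED = auto()

def to_fused_attn_mask_type(attn_mask_type):
    # Same dispatch as canonicalize_attn_mask_type in the diff above.
    if attn_mask_type == 'causal':
        return MaskType.CAUSAL_MASK
    if attn_mask_type == 'padding':
        return MaskType.PADDING_MASK
    raise ValueError(f"Unsupported {attn_mask_type=}, "
                     "supported attn_mask_type = {'causal', 'padding'}")

def to_softmax_type(attn_mask_type, mask):
    # Same dispatch as convert_to_softmax_type in the diff above.
    if attn_mask_type == 'causal':
        return SmType.SCALED_UPPER_TRIANG_MASKED
    if attn_mask_type == 'padding':
        return SmType.SCALED_MASKED if mask is not None else SmType.SCALED
    raise ValueError(f"Unsupported {attn_mask_type=}, "
                     "supported attn_mask_type = {'causal', 'padding'}")

In other words, 'causal' always selects the upper-triangular softmax (any user-supplied mask is ignored, matching the note in the TransformerLayer docstring), while 'padding' uses the masked softmax only when a mask is actually provided.
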
......@@ -5,14 +5,14 @@
Praxis Modules related Transformer
"""
from functools import partial
from typing import Optional, Sequence, Tuple
from typing import Any, Optional, Sequence, Tuple
from praxis import pax_fiddle
from praxis.base_layer import WeightInit
from praxis.pytypes import JTensor
from .module import TransformerEngineBaseLayer
from ..flax.transformer import AttentionType, TransformerLayerType
from ..flax.transformer import TransformerLayerType
from ..flax.transformer import MultiHeadAttention as flax_MultiHeadAttention
from ..flax.transformer import RelativePositionBiases as flax_RelativePositionBiases
from ..flax.transformer import TransformerLayer as flax_TransformerLayer
......@@ -73,7 +73,9 @@ class MultiHeadAttention(TransformerEngineBaseLayer):
bias_init: WeightInit = WeightInit.Constant(0.0)
apply_residual_connection_post_layernorm: bool = False
output_layernorm: bool = False
attn_type: AttentionType = AttentionType.PADDING
# TODO(rewang): remove attn_type and the related doc after v0.11
attn_type: Any = None
attn_mask_type: str = 'causal'
fuse_qkv: bool = True
transpose_batch_sequence: bool = True
scale_attn_logits: bool = False
......@@ -99,7 +101,7 @@ class MultiHeadAttention(TransformerEngineBaseLayer):
bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init),
apply_residual_connection_post_layernorm=self.apply_residual_connection_post_layernorm,
output_layernorm=self.output_layernorm,
attn_type=self.attn_type,
attn_mask_type=self.attn_mask_type,
fuse_qkv=self.fuse_qkv,
transpose_batch_sequence=self.transpose_batch_sequence,
scale_attn_logits=self.scale_attn_logits,
......@@ -145,6 +147,7 @@ class TransformerLayer(TransformerEngineBaseLayer):
output_layernorm: bool = False
float32_attention_logits: bool = False
layer_type: TransformerLayerType = TransformerLayerType.ENCODER
self_attn_mask_type: str = None # TODO(rewang): default to 'causal' after 0.11
enable_relative_embedding: bool = True
relative_embedding: pax_fiddle.Config[RelativePositionBiases] = pax_fiddle.template_field(None)
drop_path: float = 0.0
......@@ -201,6 +204,7 @@ class TransformerLayer(TransformerEngineBaseLayer):
output_layernorm=self.output_layernorm,
float32_attention_logits=self.float32_attention_logits,
layer_type=self.layer_type,
self_attn_mask_type=self.self_attn_mask_type,
enable_relative_embedding=self.enable_relative_embedding,
relative_embedding=relative_embedding_flax_module,
drop_path=self.drop_path,
......
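
A hedged sketch of setting the new praxis field follows. Only the field name self_attn_mask_type comes from this diff; the pax_fiddle.Config pattern and the import path are assumptions for illustration.

from praxis import pax_fiddle
from transformer_engine.jax.praxis import TransformerLayer  # assumed import path

# Pin the mask type explicitly. If left unset, versions before 0.11 fall back
# to 'padding' (encoder) or 'causal' (decoder) and emit a FutureWarning,
# as implemented in the flax layer above.
layer_tpl = pax_fiddle.Config(TransformerLayer, self_attn_mask_type='causal')
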