Unverified commit 16208b3b, authored by zlsh80826, committed by GitHub

[JAX] Add self_attn_mask_type and replace attn_type (#273)



* Add self_attn_mask_type and replace attn_type

Signed-off-by: Reese Wang <rewang@nvidia.com>

* Refine the keyword style for better readability

Signed-off-by: Reese Wang <rewang@nvidia.com>

* Replace attn_type with attn_mask_type in the praxis transformer

Signed-off-by: Reese Wang <rewang@nvidia.com>

* Fix typos

Signed-off-by: Reese Wang <rewang@nvidia.com>

---------

Signed-off-by: Reese Wang <rewang@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 6cd5b128
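
Before the diff, a minimal before/after sketch of the user-facing change: the `AttentionType` enum is removed from `transformer_engine.jax.flax.transformer`, and mask behaviour is now selected with the string argument `attn_mask_type`. The `head_dim`/`num_heads` values below are illustrative, not taken from this diff.

```python
from transformer_engine.jax.flax import MultiHeadAttention

# Before (<= v0.9): mask behaviour was selected with an enum.
#   from transformer_engine.jax.flax.transformer import AttentionType
#   mha = MultiHeadAttention(head_dim=64, num_heads=16,
#                            attn_type=AttentionType.PADDING)

# After (v0.10+): pass a plain string instead; attn_type is deprecated.
mha = MultiHeadAttention(head_dim=64, num_heads=16, attn_mask_type='padding')
```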
@@ -20,7 +20,6 @@ from transformer_engine.jax.flax import MultiHeadAttention as flax_MultiHeadAtte
 from transformer_engine.jax.flax import RelativePositionBiases as flax_RelativePositionBiases
 from transformer_engine.jax.flax import TransformerLayer as flax_TransformerLayer
 from transformer_engine.jax.flax.module import Softmax
-from transformer_engine.jax.flax.transformer import AttentionType
 from transformer_engine.jax.fp8 import FP8Helper, is_fp8_available
 from transformer_engine.jax.praxis import LayerNorm
 from transformer_engine.jax.praxis import FusedSoftmax, LayerNorm
@@ -666,32 +665,32 @@ class MultiHeadAttnAttr:
         USE_BIAS: True,
         LN_TYPE: 'layernorm',
         ZERO_CEN: False,
-        ATTN_TYPE: AttentionType.PADDING
+        ATTN_TYPE: 'padding'
     }, {
         USE_BIAS: True,
         LN_TYPE: 'layernorm',
         ZERO_CEN: True,
-        ATTN_TYPE: AttentionType.PADDING
+        ATTN_TYPE: 'padding'
     }, {
         USE_BIAS: True,
         LN_TYPE: 'rmsnorm',
         ZERO_CEN: False,
-        ATTN_TYPE: AttentionType.PADDING
+        ATTN_TYPE: 'padding'
     }, {
         USE_BIAS: True,
         LN_TYPE: 'layernorm',
         ZERO_CEN: False,
-        ATTN_TYPE: AttentionType.CAUSAL
+        ATTN_TYPE: 'causal'
     }, {
         USE_BIAS: True,
         LN_TYPE: 'layernorm',
         ZERO_CEN: True,
-        ATTN_TYPE: AttentionType.CAUSAL
+        ATTN_TYPE: 'causal'
     }, {
         USE_BIAS: True,
         LN_TYPE: 'rmsnorm',
         ZERO_CEN: False,
-        ATTN_TYPE: AttentionType.CAUSAL
+        ATTN_TYPE: 'causal'
     }]
...
@@ -197,17 +197,16 @@ def core_attention(query: Array,
 dynamic_vector_slice_in_dim = vmap(lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None))

-class AttentionType(Enum):
-    """TransformerLayerType."""
-    PADDING = AttnMaskType.PADDING_MASK
-    CAUSAL = AttnMaskType.CAUSAL_MASK

 class MultiHeadAttention(nn.Module):
     r"""
     Multi-head Attention (MHA), including Query,
     Key, Value and Output projection.

+    .. warning::
+        Argument :attr:`attn_type` is deprecated and superseded by :attr:`attn_mask_type`.
+        :attr:`attn_type` is ignored in version 0.10 and will be fully removed in version 0.11.

     Parameters
     ----------
     head_dim : int
@@ -245,8 +244,11 @@ class MultiHeadAttention(nn.Module):
         Indicate if apply residual connection with the output of layer normalization.
     output_layernorm : bool, default = False
         Indicate if apply a layer normalization at the end of MHA.
-    attn_type: AttentionType, default = AttentionType.PADDING
-        Indicate the format of the attention mask in the core attention.
+    attn_type: Any, default = None
+        *Deprecated*, will be ignored in v0.10 and fully removed in v0.11.
+        Please use `attn_mask_type` to configure the attention mask.
+    attn_mask_type: {'causal', 'padding'}, default = 'causal'
+        Type of attention mask passed into softmax operation.

     Optimization parameters
     -----------------------
@@ -282,7 +284,9 @@ class MultiHeadAttention(nn.Module):
     bias_init: Initializer = nn.initializers.zeros
     apply_residual_connection_post_layernorm: bool = False
     output_layernorm: bool = False
-    attn_type: AttentionType = AttentionType.PADDING
+    # TODO(rewang): remove attn_type and the related doc after v0.11
+    attn_type: Any = None
+    attn_mask_type: str = 'causal'
     dtype: DType = jnp.float32
     fuse_qkv: bool = True
     transpose_batch_sequence: bool = True
@@ -293,6 +297,14 @@ class MultiHeadAttention(nn.Module):
     def __post_init__(self):
         if self.kernel_init is None:
             self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'normal')
+        # TODO(rewang): remove attn_type after v0.11
+        if self.attn_type is not None:
+            warnings.warn(
+                "The 'attn_type' argument of 'MultiHeadAttention' is deprecated"
+                " in version 0.10 and will be removed in version 0.11."
+                " Any value passed in attn_type is ignored; please use"
+                " `attn_mask_type` to configure the attention mask type.",
+                category=DeprecationWarning)
         super().__post_init__()

     @nn.compact
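
A small sanity check of the deprecation path above. This sketch assumes only the dataclass fields visible in this diff (`head_dim`, `num_heads`, `attn_type`); the field values are illustrative.

```python
import warnings

from transformer_engine.jax.flax import MultiHeadAttention

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # The attn_type value itself is ignored; construction only warns.
    MultiHeadAttention(head_dim=64, num_heads=16, attn_type='padding')

assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```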
@@ -570,9 +582,23 @@ class MultiHeadAttention(nn.Module):
         if use_fused_attn:
             assert mask is not None and mask.ndim == 4    # (b, 1, s_q, s_kv)
             assert not self.transpose_batch_sequence

             # TODO(rewang): make it configurable for pre_scale_bias
             attn_bias_type = AttnBiasType.NO_BIAS if bias is None else AttnBiasType.POST_SCALE_BIAS

+            def canonicalize_attn_mask_type(attn_mask_type):
+                """
+                Convert the string to AttnMaskType
+                """
+                if attn_mask_type == 'causal':
+                    return AttnMaskType.CAUSAL_MASK
+                if attn_mask_type == 'padding':
+                    return AttnMaskType.PADDING_MASK
+                raise ValueError(f"Unsupported {attn_mask_type=}, "
+                                 "supported attn_mask_type = {'causal', 'padding'}")
+
+            attn_mask_type = canonicalize_attn_mask_type(self.attn_mask_type)
+
             if inputs_q is inputs_kv:
                 qkv_proj = qkv_proj.reshape((*qkv_proj.shape[:-1], self.num_heads, self.head_dim))
                 qkv_sharding_constraint = ('batch', 'length', 'qkv_dim', 'heads', 'kv')
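
For quick reference, a self-contained restatement of the mapping `canonicalize_attn_mask_type` implements. The `AttnMaskType` class here is a stand-in for the enum already imported in `transformer.py`; its member names are taken from the diff.

```python
from enum import Enum

class AttnMaskType(Enum):          # stand-in for transformer_engine's enum
    PADDING_MASK = 'padding'
    CAUSAL_MASK = 'causal'

def canonicalize_attn_mask_type(attn_mask_type):
    # Same two-way string -> enum mapping as the helper added in this diff.
    if attn_mask_type == 'causal':
        return AttnMaskType.CAUSAL_MASK
    if attn_mask_type == 'padding':
        return AttnMaskType.PADDING_MASK
    raise ValueError(f"Unsupported {attn_mask_type=}, "
                     "supported attn_mask_type = {'causal', 'padding'}")

assert canonicalize_attn_mask_type('causal') is AttnMaskType.CAUSAL_MASK
assert canonicalize_attn_mask_type('padding') is AttnMaskType.PADDING_MASK
```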
@@ -583,7 +609,7 @@ class MultiHeadAttention(nn.Module):
                     mask,
                     dropout_rng,
                     attn_bias_type=attn_bias_type,
-                    attn_mask_type=self.attn_type.value,
+                    attn_mask_type=attn_mask_type,
                     scaling_factor=scale_factor,
                     dropout_probability=self.dropout_rate,
                     is_training=not deterministic,
@@ -602,18 +628,27 @@ class MultiHeadAttention(nn.Module):
                     mask,
                     dropout_rng,
                     attn_bias_type=attn_bias_type,
-                    attn_mask_type=self.attn_type.value,
+                    attn_mask_type=attn_mask_type,
                     scaling_factor=scale_factor,
                     dropout_probability=self.dropout_rate,
                     is_training=not deterministic,
                     sharding_type=first_sharding_type)
         else:
-            softmax_type = SoftmaxType.SCALED
-            if self.attn_type is AttentionType.PADDING:
-                if mask is not None:
-                    softmax_type = SoftmaxType.SCALED_MASKED
-            else:
-                softmax_type = SoftmaxType.SCALED_UPPER_TRIANG_MASKED
+            def convert_to_softmax_type(attn_mask_type, mask):
+                """
+                Convert the string to SoftmaxType
+                """
+                if attn_mask_type == 'causal':
+                    return SoftmaxType.SCALED_UPPER_TRIANG_MASKED
+                if attn_mask_type == 'padding':
+                    if mask is not None:
+                        return SoftmaxType.SCALED_MASKED
+                    return SoftmaxType.SCALED
+                raise ValueError(f"Unsupported {attn_mask_type=}, "
+                                 "supported attn_mask_type = {'causal', 'padding'}")
+
+            softmax_type = convert_to_softmax_type(self.attn_mask_type, mask)
             x = core_attention(query,
                                key,
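
The unfused path's mask-type to softmax-kernel mapping is three-way, since `'padding'` without an actual mask degenerates to a plain scaled softmax. A self-contained sketch, with `SoftmaxType` standing in for transformer_engine's enum (member names assumed from the diff):

```python
from enum import Enum

class SoftmaxType(Enum):           # stand-in for transformer_engine's enum
    SCALED = 'scaled'
    SCALED_MASKED = 'scaled_masked'
    SCALED_UPPER_TRIANG_MASKED = 'scaled_upper_triang_masked'

def convert_to_softmax_type(attn_mask_type, mask):
    # 'causal' always uses the upper-triangular kernel, mask or not;
    # 'padding' only needs the masked kernel when a mask is actually given.
    if attn_mask_type == 'causal':
        return SoftmaxType.SCALED_UPPER_TRIANG_MASKED
    if attn_mask_type == 'padding':
        return SoftmaxType.SCALED_MASKED if mask is not None else SoftmaxType.SCALED
    raise ValueError(f"Unsupported {attn_mask_type=}")

assert convert_to_softmax_type('causal', None) is SoftmaxType.SCALED_UPPER_TRIANG_MASKED
assert convert_to_softmax_type('padding', None) is SoftmaxType.SCALED
```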
@@ -765,6 +800,18 @@ class TransformerLayer(nn.Module):
     an attention block and a feedforward network (MLP).
     This standard layer is based on the paper “Attention Is All You Need”.

+    .. warning::
+        Argument :attr:`self_attn_mask_type` was introduced in version 0.10.
+        Starting from version 0.11, its default value will be `"causal"`.
+        To ensure compatibility with earlier versions, before 0.11 the default
+        is `"padding"` for the encoder and `"causal"` for the decoder.
+
+    .. note::
+        Argument :attr:`attention_mask` will be ignored when
+        :attr:`self_attn_mask_type` is set to `"causal"`.

     Parameters
     ----------
     hidden_size: int, default = 512
...@@ -825,6 +872,8 @@ class TransformerLayer(nn.Module): ...@@ -825,6 +872,8 @@ class TransformerLayer(nn.Module):
If set to TransformerLayerType.DECODER, an additional cross-attention block If set to TransformerLayerType.DECODER, an additional cross-attention block
is added after self-attention.this can be used for structures like `T5` is added after self-attention.this can be used for structures like `T5`
Transformer in conjunction with the TransformerLayerType.ENCODER option. Transformer in conjunction with the TransformerLayerType.ENCODER option.
self_attn_mask_type: {'causal', 'padding'}, default = 'causal'
Type of attention mask passed into softmax operation.
enable_relative_embedding: bool, default = True enable_relative_embedding: bool, default = True
Whether to enable relative embedding as shifting of attention logits. Whether to enable relative embedding as shifting of attention logits.
relative_embedding: flax.linen.Module, default = None relative_embedding: flax.linen.Module, default = None
@@ -878,6 +927,7 @@ class TransformerLayer(nn.Module):
     output_layernorm: bool = False
     float32_attention_logits: bool = False
     layer_type: TransformerLayerType = TransformerLayerType.ENCODER
+    self_attn_mask_type: str = None    # TODO(rewang): default to 'causal' after 0.11
     enable_relative_embedding: bool = True
     relative_embedding: nn.Module = None
     dtype: DType = jnp.float32
@@ -893,6 +943,19 @@ class TransformerLayer(nn.Module):
         if self.mlp_kernel_init is None:
             self.mlp_kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in',
                                                                     'truncated_normal')
+        # TODO(rewang): default to 'causal' in 0.11 (also update the doc after 0.11)
+        if self.self_attn_mask_type is None:
+            warnings.warn(
+                "The 'self_attn_mask_type' argument of 'TransformerLayer' was"
+                " introduced in version 0.10. Starting from version 0.11, its"
+                " default value will be 'causal'. To ensure compatibility with"
+                " earlier versions, before 0.11 the default is 'padding' for the"
+                " encoder and 'causal' for the decoder.",
+                category=FutureWarning)
+            if self.layer_type == TransformerLayerType.ENCODER:
+                self.self_attn_mask_type = 'padding'
+            else:
+                self.self_attn_mask_type = 'causal'
         super().__post_init__()

     @nn.compact
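
A sketch of the transitional default above: an encoder layer built without an explicit `self_attn_mask_type` warns and falls back to `'padding'`. This assumes `TransformerLayer` is constructible with all defaults, as its field list in this diff suggests.

```python
import warnings

from transformer_engine.jax.flax import TransformerLayer

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    layer = TransformerLayer()    # layer_type defaults to ENCODER

assert layer.self_attn_mask_type == 'padding'
assert any(issubclass(w.category, FutureWarning) for w in caught)
```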
@@ -975,16 +1038,12 @@ class TransformerLayer(nn.Module):
         assert inputs.ndim == 3

-        self_attn_type = None
         # Make the names exactly the same as T5X, since names affect the
         # RNGKey during init and apply. Maybe no need in the future.
         if self.layer_type == TransformerLayerType.ENCODER:
             mha_name = 'attention'
-            self_attn_type = AttentionType.PADDING
         else:
             mha_name = 'self_attention'
-            self_attn_type = AttentionType.CAUSAL
-        assert self_attn_type is not None

         # [batch, length, emb_dim] -> [batch, length, emb_dim]
         x, residual = MultiHeadAttention(
@@ -1002,7 +1061,7 @@ class TransformerLayer(nn.Module):
             zero_centered_gamma=self.zero_centered_gamma,
             apply_residual_connection_post_layernorm=self.apply_residual_connection_post_layernorm,
             output_layernorm=self.output_layernorm,
-            attn_type=self_attn_type,
+            attn_mask_type=self.self_attn_mask_type,
             fuse_qkv=self.fuse_qkv_params,
             kernel_init=self.mha_kernel_init,
             use_bias=self.use_bias,
@@ -1049,7 +1108,7 @@ class TransformerLayer(nn.Module):
             apply_residual_connection_post_layernorm=self.
             apply_residual_connection_post_layernorm,
             output_layernorm=False,    # Must do LayerNorm before MHA.
-            attn_type=AttentionType.PADDING,
+            attn_mask_type='padding',
             float32_logits=self.float32_attention_logits,
             scale_attn_logits=self.scale_attn_logits,
             scaled_query_init=self.scaled_query_init,
...
@@ -5,14 +5,14 @@
 Praxis Modules related to Transformer
 """
 from functools import partial
-from typing import Optional, Sequence, Tuple
+from typing import Any, Optional, Sequence, Tuple

 from praxis import pax_fiddle
 from praxis.base_layer import WeightInit
 from praxis.pytypes import JTensor

 from .module import TransformerEngineBaseLayer
-from ..flax.transformer import AttentionType, TransformerLayerType
+from ..flax.transformer import TransformerLayerType
 from ..flax.transformer import MultiHeadAttention as flax_MultiHeadAttention
 from ..flax.transformer import RelativePositionBiases as flax_RelativePositionBiases
 from ..flax.transformer import TransformerLayer as flax_TransformerLayer
@@ -73,7 +73,9 @@ class MultiHeadAttention(TransformerEngineBaseLayer):
     bias_init: WeightInit = WeightInit.Constant(0.0)
     apply_residual_connection_post_layernorm: bool = False
     output_layernorm: bool = False
-    attn_type: AttentionType = AttentionType.PADDING
+    # TODO(rewang): remove attn_type and the related doc after v0.11
+    attn_type: Any = None
+    attn_mask_type: str = 'causal'
     fuse_qkv: bool = True
     transpose_batch_sequence: bool = True
     scale_attn_logits: bool = False
@@ -99,7 +101,7 @@ class MultiHeadAttention(TransformerEngineBaseLayer):
             bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init),
             apply_residual_connection_post_layernorm=self.apply_residual_connection_post_layernorm,
             output_layernorm=self.output_layernorm,
-            attn_type=self.attn_type,
+            attn_mask_type=self.attn_mask_type,
             fuse_qkv=self.fuse_qkv,
             transpose_batch_sequence=self.transpose_batch_sequence,
             scale_attn_logits=self.scale_attn_logits,
@@ -145,6 +147,7 @@ class TransformerLayer(TransformerEngineBaseLayer):
     output_layernorm: bool = False
     float32_attention_logits: bool = False
     layer_type: TransformerLayerType = TransformerLayerType.ENCODER
+    self_attn_mask_type: str = None    # TODO(rewang): default to 'causal' after 0.11
     enable_relative_embedding: bool = True
     relative_embedding: pax_fiddle.Config[RelativePositionBiases] = pax_fiddle.template_field(None)
     drop_path: float = 0.0
@@ -201,6 +204,7 @@ class TransformerLayer(TransformerEngineBaseLayer):
             output_layernorm=self.output_layernorm,
             float32_attention_logits=self.float32_attention_logits,
             layer_type=self.layer_type,
+            self_attn_mask_type=self.self_attn_mask_type,
             enable_relative_embedding=self.enable_relative_embedding,
             relative_embedding=relative_embedding_flax_module,
             drop_path=self.drop_path,
...
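
To close, a hypothetical praxis-side usage sketch. It assumes `TransformerLayer` is exported from `transformer_engine.jax.praxis` (as `LayerNorm` and `FusedSoftmax` are in the test imports above) and follows the usual `pax_fiddle.Config` workflow; the `name` value is illustrative.

```python
from praxis import pax_fiddle

from transformer_engine.jax.praxis import TransformerLayer

# Build a fiddle template; self_attn_mask_type is forwarded verbatim to the
# underlying flax TransformerLayer (see the diff above). Leaving it unset
# (None) triggers the FutureWarning fallback instead.
layer_tpl = pax_fiddle.Config(TransformerLayer,
                              name='te_transformer_layer',
                              self_attn_mask_type='causal')
```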