Unverified commit a3e4e611 authored by zlsh80826, committed by GitHub

[JAX] Fully remove attn_type and set self_attn_mask_type default to 'causal' (#324)



* Fully remove attn_type and set self_attn_mask_type default to 'causal'
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Fix tests with new arguments
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Explicit self_attn_mask_type for examples
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Update transformer_engine/jax/flax/transformer.py
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: zlsh80826 <rewang@nvidia.com>

* Update transformer_engine/jax/flax/transformer.py
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: zlsh80826 <rewang@nvidia.com>

---------
Signed-off-by: Reese Wang <rewang@nvidia.com>
Signed-off-by: zlsh80826 <rewang@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 21162d06
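In short: the deprecated `attn_type` argument is removed outright and `self_attn_mask_type` now defaults to `'causal'` for every layer type, so encoders that rely on padding masks must ask for them explicitly. A minimal sketch of the updated flax usage (constructor arguments are taken from the example hunks below; the `DECODER` variant and the surrounding model code are assumptions, not part of this diff):

```python
import jax.numpy as jnp
import transformer_engine.jax.flax as te_flax

# Encoder: 'causal' is now the default self_attn_mask_type, so request
# 'padding' explicitly when the attention_mask should be honoured.
encoder = te_flax.TransformerLayer(
    layer_type=te_flax.TransformerLayerType.ENCODER,
    self_attn_mask_type='padding',
    enable_relative_embedding=False,
    dtype=jnp.bfloat16)

# Decoder: unchanged in practice, since its default was already 'causal'.
decoder = te_flax.TransformerLayer(
    layer_type=te_flax.TransformerLayerType.DECODER,
    self_attn_mask_type='causal',
    dtype=jnp.bfloat16)
```

The hunks below make exactly this change in the bundled examples and rename the corresponding test attributes and Praxis fields.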
......@@ -48,6 +48,7 @@ class Net(nn.Module):
attention_dropout=0.1,
dropout_rng_name=DROPOUT_KEY,
layer_type=te_flax.TransformerLayerType.ENCODER,
+self_attn_mask_type='padding',
enable_relative_embedding=False,
dtype=jnp.bfloat16)
x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)
......
......@@ -45,6 +45,7 @@ class Net(nn.Module):
attention_dropout=0.1,
dropout_rng_name=DROPOUT_KEY,
layer_type=te_flax.TransformerLayerType.ENCODER,
+self_attn_mask_type='padding',
enable_relative_embedding=False,
dtype=jnp.bfloat16)
x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)
......
......@@ -51,6 +51,7 @@ class Net(nn.Module):
attention_dropout=0.1,
dropout_rng_name=DROPOUT_KEY,
layer_type=te_flax.TransformerLayerType.ENCODER,
+self_attn_mask_type='padding',
enable_relative_embedding=False,
dtype=jnp.bfloat16)
x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)
......
......@@ -40,6 +40,7 @@ class Net(nn.Module):
attention_dropout=0.1,
dropout_rng_name=DROPOUT_KEY,
layer_type=te_flax.TransformerLayerType.ENCODER,
+self_attn_mask_type='padding',
enable_relative_embedding=False,
dtype=jnp.bfloat16)
x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)
......
......@@ -171,6 +171,7 @@ class TestEncoderLayer:
layer_cls = partial(TransformerLayer,
hidden_dropout_dims=(sequence_dim,),
layer_type=TransformerLayerType.ENCODER,
+self_attn_mask_type='padding',
dtype=dtype,
**te_layer_attrs)
......@@ -215,6 +216,7 @@ class TestEncoderLayer:
layer_cls = partial(TransformerLayer,
hidden_dropout_dims=(sequence_dim,),
layer_type=TransformerLayerType.ENCODER,
+self_attn_mask_type='padding',
dtype=dtype,
**te_layer_attrs)
ref_layer, ref_params, ref_others = generate_layer(ref_layer_cls, init_rng, inputs,
......
......@@ -659,38 +659,38 @@ class TestRelativePositionBias(TestLayer):
class MultiHeadAttnAttr:
USE_BIAS = 'use_bias'
LN_TYPE = 'layernorm_type'
-ATTN_TYPE = 'attn_type'
+ATTN_MASK_TYPE = 'attn_mask_type'
ZERO_CEN = 'zero_centered_gamma'
ATTRS = [{
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: False,
-ATTN_TYPE: 'padding'
+ATTN_MASK_TYPE: 'padding'
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: True,
-ATTN_TYPE: 'padding'
+ATTN_MASK_TYPE: 'padding'
}, {
USE_BIAS: True,
LN_TYPE: 'rmsnorm',
ZERO_CEN: False,
-ATTN_TYPE: 'padding'
+ATTN_MASK_TYPE: 'padding'
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: False,
-ATTN_TYPE: 'causal'
+ATTN_MASK_TYPE: 'causal'
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: True,
-ATTN_TYPE: 'causal'
+ATTN_MASK_TYPE: 'causal'
}, {
USE_BIAS: True,
LN_TYPE: 'rmsnorm',
ZERO_CEN: False,
-ATTN_TYPE: 'causal'
+ATTN_MASK_TYPE: 'causal'
}]
......@@ -714,7 +714,7 @@ class TestMultiHeadAttn(TestLayer):
bias_init = WeightInit.Constant(0.0)
apply_residual_connection_post_layernorm = False
output_layernorm = False
-attn_type = attrs[MultiHeadAttnAttr.ATTN_TYPE]
+attn_mask_type = attrs[MultiHeadAttnAttr.ATTN_MASK_TYPE]
fuse_qkv: bool = True
transpose_batch_sequence = True
scale_attn_logits = False
......@@ -734,7 +734,7 @@ class TestMultiHeadAttn(TestLayer):
bias_init=bias_init,
apply_residual_connection_post_layernorm=apply_residual_connection_post_layernorm,
output_layernorm=output_layernorm,
-attn_type=attn_type,
+attn_mask_type=attn_mask_type,
fuse_qkv=fuse_qkv,
transpose_batch_sequence=transpose_batch_sequence,
scale_attn_logits=scale_attn_logits,
......@@ -752,7 +752,7 @@ class TestMultiHeadAttn(TestLayer):
bias_init=TransformerEngineBaseLayer.generate_params_init("bias", bias_init),
apply_residual_connection_post_layernorm=apply_residual_connection_post_layernorm,
output_layernorm=output_layernorm,
-attn_type=attn_type,
+attn_mask_type=attn_mask_type,
fuse_qkv=fuse_qkv,
transpose_batch_sequence=transpose_batch_sequence,
scale_attn_logits=scale_attn_logits,
......
......@@ -202,10 +202,10 @@ class MultiHeadAttention(nn.Module):
Multi-head Attention (MHA), including Query,
Key, Value and Output projection.
-.. warning::
+.. note::
-Argument :attr:`attn_type` is deprecated and superseded by :attr:`attn_mask_type`.
-:attr:`attn_type` is ignored in version 0.10 and will be fully removed in version 0.11.
+Argument :attr:`mask` will be ignored when
+:attr:`attn_mask_type` is set to `"causal"`.
Parameters
----------
......@@ -244,11 +244,9 @@ class MultiHeadAttention(nn.Module):
Indicate if apply residual connection with the output of layer normalization.
output_layernorm : bool, default = False
Indicate if apply a layer normalization at the end of MHA.
-attn_type: Any, defult = None
-*Deprecated*, will be ignored in v0.10 and be fully removed in v0.11.
-Please use `attn_mask_type` to config the attention mask.
attn_mask_type: {'causal', 'padding'}, default = 'causal'
Type of attention mask passed into softmax operation.
Introduced in v0.10.0.
Optimization parameters
-----------------------
......@@ -284,8 +282,6 @@ class MultiHeadAttention(nn.Module):
bias_init: Initializer = nn.initializers.zeros
apply_residual_connection_post_layernorm: bool = False
output_layernorm: bool = False
-# TODO(rewang): remove attn_type and the related doc after v0.11
-attn_type: Any = None
attn_mask_type: str = 'causal'
dtype: DType = jnp.float32
fuse_qkv: bool = True
......@@ -297,14 +293,6 @@ class MultiHeadAttention(nn.Module):
def __post_init__(self):
if self.kernel_init is None:
self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'normal')
-# TODO(rewang): remove attn_type after v0.11
-if self.attn_type is not None:
-warnings.warn(
-"The 'attn_type' argument in the 'MultiHeadAttention' is"
-" deprecated in version 0.10 and will be removed in version 0.11."
-" Passing value in attn_type will be ignored, please use `attn_mask_type`"
-" to config the attention mask type.",
-category=DeprecationWarning)
super().__post_init__()
@nn.compact
......@@ -803,13 +791,6 @@ class TransformerLayer(nn.Module):
an attention block and a feedforward network (MLP).
This standard layer is based on the paper “Attention Is All You Need”.
-.. warning::
-Argument :attr:`self_attn_mask_type` is introduced in version 0.10.
-Starting from version 0.11, the default value will be `"causal"`.
-However, to ensure compatibility with earlier versions, before 0.11,
-the default value will be `"padding"` for the encoder and `"causal"` for the decoder.
.. note::
Argument :attr:`attention_mask` will be ignored when
......@@ -877,6 +858,7 @@ class TransformerLayer(nn.Module):
Transformer in conjunction with the TransformerLayerType.ENCODER option.
self_attn_mask_type: {'causal', 'padding'}, default = 'causal'
Type of attention mask passed into softmax operation.
Introduced in v0.10.0.
enable_relative_embedding: bool, default = True
Whether to enable relative embedding as shifting of attention logits.
relative_embedding: flax.linen.Module, default = None
......@@ -930,7 +912,7 @@ class TransformerLayer(nn.Module):
output_layernorm: bool = False
float32_attention_logits: bool = False
layer_type: TransformerLayerType = TransformerLayerType.ENCODER
-self_attn_mask_type: str = None # TODO(rewang): default to 'causal' after 0.11
+self_attn_mask_type: str = 'causal'
enable_relative_embedding: bool = True
relative_embedding: nn.Module = None
dtype: DType = jnp.float32
......@@ -946,19 +928,6 @@ class TransformerLayer(nn.Module):
if self.mlp_kernel_init is None:
self.mlp_kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in',
'truncated_normal')
-# TODO(rewang): default to 'causal' in 0.11 (also updated the doc after 0.11)
-if self.self_attn_mask_type is None:
-warnings.warn(
-"The 'self_attn_mask_type' argument in the 'TransformerLayer' is"
-" introduced in version 0.10. Starting from version 0.11, the default"
-" value will be 'causal'. However, to ensure compatibility with earlier"
-" versions, before 0.11, the default value will be 'padding' for the"
-" encoder and 'causal' for the decoder.",
-category=FutureWarning)
-if self.layer_type == TransformerLayerType.ENCODER:
-self.self_attn_mask_type = 'padding'
-else:
-self.self_attn_mask_type = 'causal'
super().__post_init__()
@nn.compact
......
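With the deprecation shim removed from `transformer_engine/jax/flax/transformer.py`, `attn_type` is no longer a field at all, so passing it fails at construction instead of emitting a `DeprecationWarning`. A construction-only sketch of the new surface (the `head_dim` and `num_heads` field names are illustrative assumptions and do not appear in this diff):

```python
import jax.numpy as jnp
from transformer_engine.jax.flax.transformer import MultiHeadAttention

# attn_mask_type is now the only mask knob. 'causal' is the default and,
# per the updated docstring, causes the mask passed at call time to be
# ignored, so modules that need padding masking select 'padding' explicitly.
mha = MultiHeadAttention(
    head_dim=64,               # assumed field name, for illustration only
    num_heads=8,               # assumed field name, for illustration only
    attn_mask_type='padding',
    dtype=jnp.bfloat16)

# MultiHeadAttention(attn_type='padding', ...) would now raise a TypeError,
# since the field has been removed entirely.
```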
......@@ -5,7 +5,7 @@
Praxis Modules related Transformer
"""
from functools import partial
-from typing import Any, Optional, Sequence, Tuple
+from typing import Optional, Sequence, Tuple
from praxis import pax_fiddle
from praxis.base_layer import WeightInit
......@@ -73,8 +73,6 @@ class MultiHeadAttention(TransformerEngineBaseLayer):
bias_init: WeightInit = WeightInit.Constant(0.0)
apply_residual_connection_post_layernorm: bool = False
output_layernorm: bool = False
-# TODO(rewang): remove attn_type and the related doc after v0.11
-attn_type: Any = None
attn_mask_type: str = 'causal'
fuse_qkv: bool = True
transpose_batch_sequence: bool = True
......@@ -147,7 +145,7 @@ class TransformerLayer(TransformerEngineBaseLayer):
output_layernorm: bool = False
float32_attention_logits: bool = False
layer_type: TransformerLayerType = TransformerLayerType.ENCODER
-self_attn_mask_type: str = None # TODO(rewang): default to 'causal' after 0.11
+self_attn_mask_type: str = 'causal'
enable_relative_embedding: bool = True
relative_embedding: pax_fiddle.Config[RelativePositionBiases] = pax_fiddle.template_field(None)
drop_path: float = 0.0
......
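The Praxis wrappers get the same two changes. A rough pax_fiddle sketch, assuming `TransformerLayer` and `TransformerLayerType` are importable from `transformer_engine.jax.praxis` (the import path and any fields beyond those shown in the hunk above are assumptions):

```python
from praxis import pax_fiddle
from transformer_engine.jax.praxis import TransformerLayer, TransformerLayerType

# Hypothetical encoder template: as with the flax module, 'causal' is now the
# default self_attn_mask_type, so padding masking has to be requested explicitly.
encoder_tpl = pax_fiddle.Config(
    TransformerLayer,
    layer_type=TransformerLayerType.ENCODER,
    self_attn_mask_type='padding',
    enable_relative_embedding=False)
```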