"...git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "87bfc3485eb14061a3e3ef0f2ce44a3297af6049"
Unverified Commit a3e4e611 authored by zlsh80826, committed by GitHub

[JAX] Fully remove attn_type and set self_attn_mask_type default to 'causal' (#324)



* Fully remove attn_type and set self_attn_mask_type default to 'causal'
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Fix tests with new arguments
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Explicit self_attn_mask_type for examples
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Update transformer_engine/jax/flax/transformer.py
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: zlsh80826 <rewang@nvidia.com>

* Update transformer_engine/jax/flax/transformer.py
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: zlsh80826 <rewang@nvidia.com>

---------
Signed-off-by: Reese Wang <rewang@nvidia.com>
Signed-off-by: zlsh80826 <rewang@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 21162d06
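For code written against v0.10, the user-facing impact is twofold: the deprecated attn_type argument of MultiHeadAttention is removed outright, and TransformerLayer no longer falls back to 'padding' for encoders. A minimal before/after sketch of the rename (head_dim and num_heads are illustrative required sizes, not taken from this diff):

import transformer_engine.jax.flax as te_flax

# v0.10: attn_type was accepted but ignored, with a DeprecationWarning.
#   attn = te_flax.MultiHeadAttention(head_dim=64, num_heads=16, attn_type='padding')
# From this commit on, only attn_mask_type exists:
attn = te_flax.MultiHeadAttention(head_dim=64, num_heads=16, attn_mask_type='padding')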
@@ -48,6 +48,7 @@ class Net(nn.Module):
     attention_dropout=0.1,
     dropout_rng_name=DROPOUT_KEY,
     layer_type=te_flax.TransformerLayerType.ENCODER,
+    self_attn_mask_type='padding',
     enable_relative_embedding=False,
     dtype=jnp.bfloat16)
 x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)
...
@@ -45,6 +45,7 @@ class Net(nn.Module):
     attention_dropout=0.1,
     dropout_rng_name=DROPOUT_KEY,
     layer_type=te_flax.TransformerLayerType.ENCODER,
+    self_attn_mask_type='padding',
     enable_relative_embedding=False,
     dtype=jnp.bfloat16)
 x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)
...
@@ -51,6 +51,7 @@ class Net(nn.Module):
     attention_dropout=0.1,
     dropout_rng_name=DROPOUT_KEY,
     layer_type=te_flax.TransformerLayerType.ENCODER,
+    self_attn_mask_type='padding',
     enable_relative_embedding=False,
     dtype=jnp.bfloat16)
 x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)
...
@@ -40,6 +40,7 @@ class Net(nn.Module):
     attention_dropout=0.1,
     dropout_rng_name=DROPOUT_KEY,
     layer_type=te_flax.TransformerLayerType.ENCODER,
+    self_attn_mask_type='padding',
     enable_relative_embedding=False,
     dtype=jnp.bfloat16)
 x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)
...
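Put together, each of the four example encoders now reads roughly as below once the mask type is made explicit. A minimal standalone sketch, assuming batch-major (batch, seqlen, hidden) inputs, an all-zeros placeholder mask, and illustrative sizes; hidden_size, num_attention_heads, and transpose_batch_sequence do not appear in the hunks above and are assumptions here:

import jax
import jax.numpy as jnp
import transformer_engine.jax.flax as te_flax

batch, seqlen, hidden = 2, 16, 512                         # illustrative sizes
x = jnp.zeros((batch, seqlen, hidden), dtype=jnp.bfloat16)
# Placeholder mask; the real examples derive it from the padding positions of each batch.
mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8)

te_encoder = te_flax.TransformerLayer(
    hidden_size=hidden,                                    # assumed attribute name
    num_attention_heads=8,                                 # assumed attribute name
    layer_type=te_flax.TransformerLayerType.ENCODER,
    self_attn_mask_type='padding',                         # explicit, since the default is now 'causal'
    enable_relative_embedding=False,
    transpose_batch_sequence=False,                        # batch-major layout in this sketch
    dtype=jnp.bfloat16)

params = te_encoder.init(jax.random.PRNGKey(0), x, attention_mask=mask, deterministic=True)
y = te_encoder.apply(params, x, attention_mask=mask, deterministic=True)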
@@ -171,6 +171,7 @@ class TestEncoderLayer:
     layer_cls = partial(TransformerLayer,
                         hidden_dropout_dims=(sequence_dim,),
                         layer_type=TransformerLayerType.ENCODER,
+                        self_attn_mask_type='padding',
                         dtype=dtype,
                         **te_layer_attrs)
@@ -215,6 +216,7 @@ class TestEncoderLayer:
     layer_cls = partial(TransformerLayer,
                         hidden_dropout_dims=(sequence_dim,),
                         layer_type=TransformerLayerType.ENCODER,
+                        self_attn_mask_type='padding',
                         dtype=dtype,
                         **te_layer_attrs)
     ref_layer, ref_params, ref_others = generate_layer(ref_layer_cls, init_rng, inputs,
...
@@ -659,38 +659,38 @@ class TestRelativePositionBias(TestLayer):
 class MultiHeadAttnAttr:
     USE_BIAS = 'use_bias'
     LN_TYPE = 'layernorm_type'
-    ATTN_TYPE = 'attn_type'
+    ATTN_MASK_TYPE = 'attn_mask_type'
     ZERO_CEN = 'zero_centered_gamma'
     ATTRS = [{
         USE_BIAS: True,
         LN_TYPE: 'layernorm',
         ZERO_CEN: False,
-        ATTN_TYPE: 'padding'
+        ATTN_MASK_TYPE: 'padding'
     }, {
         USE_BIAS: True,
         LN_TYPE: 'layernorm',
         ZERO_CEN: True,
-        ATTN_TYPE: 'padding'
+        ATTN_MASK_TYPE: 'padding'
     }, {
         USE_BIAS: True,
         LN_TYPE: 'rmsnorm',
         ZERO_CEN: False,
-        ATTN_TYPE: 'padding'
+        ATTN_MASK_TYPE: 'padding'
     }, {
         USE_BIAS: True,
         LN_TYPE: 'layernorm',
         ZERO_CEN: False,
-        ATTN_TYPE: 'causal'
+        ATTN_MASK_TYPE: 'causal'
     }, {
         USE_BIAS: True,
         LN_TYPE: 'layernorm',
         ZERO_CEN: True,
-        ATTN_TYPE: 'causal'
+        ATTN_MASK_TYPE: 'causal'
     }, {
         USE_BIAS: True,
         LN_TYPE: 'rmsnorm',
         ZERO_CEN: False,
-        ATTN_TYPE: 'causal'
+        ATTN_MASK_TYPE: 'causal'
     }]
@@ -714,7 +714,7 @@ class TestMultiHeadAttn(TestLayer):
     bias_init = WeightInit.Constant(0.0)
     apply_residual_connection_post_layernorm = False
     output_layernorm = False
-    attn_type = attrs[MultiHeadAttnAttr.ATTN_TYPE]
+    attn_mask_type = attrs[MultiHeadAttnAttr.ATTN_MASK_TYPE]
     fuse_qkv: bool = True
     transpose_batch_sequence = True
     scale_attn_logits = False
@@ -734,7 +734,7 @@ class TestMultiHeadAttn(TestLayer):
     bias_init=bias_init,
     apply_residual_connection_post_layernorm=apply_residual_connection_post_layernorm,
     output_layernorm=output_layernorm,
-    attn_type=attn_type,
+    attn_mask_type=attn_mask_type,
     fuse_qkv=fuse_qkv,
     transpose_batch_sequence=transpose_batch_sequence,
     scale_attn_logits=scale_attn_logits,
@@ -752,7 +752,7 @@ class TestMultiHeadAttn(TestLayer):
     bias_init=TransformerEngineBaseLayer.generate_params_init("bias", bias_init),
     apply_residual_connection_post_layernorm=apply_residual_connection_post_layernorm,
     output_layernorm=output_layernorm,
-    attn_type=attn_type,
+    attn_mask_type=attn_mask_type,
     fuse_qkv=fuse_qkv,
     transpose_batch_sequence=transpose_batch_sequence,
     scale_attn_logits=scale_attn_logits,
...
@@ -202,10 +202,10 @@ class MultiHeadAttention(nn.Module):
     Multi-head Attention (MHA), including Query,
     Key, Value and Output projection.

-    .. warning::
-        Argument :attr:`attn_type` is deprecated and superseded by :attr:`attn_mask_type`.
-        :attr:`attn_type` is ignored in version 0.10 and will be fully removed in version 0.11.
+    .. note::
+        Argument :attr:`mask` will be ignored when
+        :attr:`attn_mask_type` is set to `"causal"`.

     Parameters
     ----------
@@ -244,11 +244,9 @@ class MultiHeadAttention(nn.Module):
         Indicate if apply residual connection with the output of layer normalization.
     output_layernorm : bool, default = False
         Indicate if apply a layer normalization at the end of MHA.
-    attn_type: Any, defult = None
-        *Deprecated*, will be ignored in v0.10 and be fully removed in v0.11.
-        Please use `attn_mask_type` to config the attention mask.
     attn_mask_type: {'causal', 'padding'}, default = 'causal'
         Type of attention mask passed into softmax operation.
+        Introduced in v0.10.0.

     Optimization parameters
     -----------------------
@@ -284,8 +282,6 @@ class MultiHeadAttention(nn.Module):
     bias_init: Initializer = nn.initializers.zeros
     apply_residual_connection_post_layernorm: bool = False
     output_layernorm: bool = False
-    # TODO(rewang): remove attn_type and the related doc after v0.11
-    attn_type: Any = None
     attn_mask_type: str = 'causal'
     dtype: DType = jnp.float32
     fuse_qkv: bool = True
@@ -297,14 +293,6 @@ class MultiHeadAttention(nn.Module):
     def __post_init__(self):
         if self.kernel_init is None:
             self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'normal')
-        # TODO(rewang): remove attn_type after v0.11
-        if self.attn_type is not None:
-            warnings.warn(
-                "The 'attn_type' argument in the 'MultiHeadAttention' is"
-                " deprecated in version 0.10 and will be removed in version 0.11."
-                " Passing value in attn_type will be ignored, please use `attn_mask_type`"
-                " to config the attention mask type.",
-                category=DeprecationWarning)
         super().__post_init__()

     @nn.compact
@@ -803,13 +791,6 @@ class TransformerLayer(nn.Module):
     an attention block and a feedforward network (MLP).
     This standard layer is based on the paper “Attention Is All You Need”.

-    .. warning::
-        Argument :attr:`self_attn_mask_type` is introduced in version 0.10.
-        Starting from version 0.11, the default value will be `"causal"`.
-        However, to ensure compatibility with earlier versions, before 0.11,
-        the default value will be `"padding"` for the encoder and `"causal"` for the decoder.

     .. note::
         Argument :attr:`attention_mask` will be ignored when
@@ -877,6 +858,7 @@ class TransformerLayer(nn.Module):
         Transformer in conjunction with the TransformerLayerType.ENCODER option.
     self_attn_mask_type: {'causal', 'padding'}, default = 'causal'
         Type of attention mask passed into softmax operation.
+        Introduced in v0.10.0.
     enable_relative_embedding: bool, default = True
         Whether to enable relative embedding as shifting of attention logits.
     relative_embedding: flax.linen.Module, default = None
@@ -930,7 +912,7 @@ class TransformerLayer(nn.Module):
     output_layernorm: bool = False
     float32_attention_logits: bool = False
     layer_type: TransformerLayerType = TransformerLayerType.ENCODER
-    self_attn_mask_type: str = None    # TODO(rewang): default to 'causal' after 0.11
+    self_attn_mask_type: str = 'causal'
     enable_relative_embedding: bool = True
     relative_embedding: nn.Module = None
     dtype: DType = jnp.float32
@@ -946,19 +928,6 @@ class TransformerLayer(nn.Module):
         if self.mlp_kernel_init is None:
             self.mlp_kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in',
                                                                      'truncated_normal')
-        # TODO(rewang): default to 'causal' in 0.11 (also updated the doc after 0.11)
-        if self.self_attn_mask_type is None:
-            warnings.warn(
-                "The 'self_attn_mask_type' argument in the 'TransformerLayer' is"
-                " introduced in version 0.10. Starting from version 0.11, the default"
-                " value will be 'causal'. However, to ensure compatibility with earlier"
-                " versions, before 0.11, the default value will be 'padding' for the"
-                " encoder and 'causal' for the decoder.",
-                category=FutureWarning)
-            if self.layer_type == TransformerLayerType.ENCODER:
-                self.self_attn_mask_type = 'padding'
-            else:
-                self.self_attn_mask_type = 'causal'
         super().__post_init__()

     @nn.compact
...
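The retained .. note:: is the practical consequence of the new default: under 'causal' the self-attention applies a causal mask internally and whatever is passed as attention_mask is ignored, which suits decoder-style stacks, while encoder-style layers that attend bidirectionally over padded batches must opt back in with self_attn_mask_type='padding' (as in the encoder sketch after the example hunks above). A constructor-only sketch of the decoder case, all other arguments left at their defaults:

import transformer_engine.jax.flax as te_flax

# Default behavior after this commit: causal self-attention for every layer_type,
# so any attention_mask supplied at call time is a no-op for this layer.
decoder = te_flax.TransformerLayer(
    layer_type=te_flax.TransformerLayerType.DECODER)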
@@ -5,7 +5,7 @@
 Praxis Modules related Transformer
 """
 from functools import partial
-from typing import Any, Optional, Sequence, Tuple
+from typing import Optional, Sequence, Tuple

 from praxis import pax_fiddle
 from praxis.base_layer import WeightInit
@@ -73,8 +73,6 @@ class MultiHeadAttention(TransformerEngineBaseLayer):
     bias_init: WeightInit = WeightInit.Constant(0.0)
     apply_residual_connection_post_layernorm: bool = False
     output_layernorm: bool = False
-    # TODO(rewang): remove attn_type and the related doc after v0.11
-    attn_type: Any = None
     attn_mask_type: str = 'causal'
     fuse_qkv: bool = True
     transpose_batch_sequence: bool = True
@@ -147,7 +145,7 @@ class TransformerLayer(TransformerEngineBaseLayer):
     output_layernorm: bool = False
     float32_attention_logits: bool = False
     layer_type: TransformerLayerType = TransformerLayerType.ENCODER
-    self_attn_mask_type: str = None    # TODO(rewang): default to 'causal' after 0.11
+    self_attn_mask_type: str = 'causal'
     enable_relative_embedding: bool = True
     relative_embedding: pax_fiddle.Config[RelativePositionBiases] = pax_fiddle.template_field(None)
     drop_path: float = 0.0
...