Unverified Commit e4c99b03 authored by Reese Wang's avatar Reese Wang Committed by GitHub
Browse files

[JAX] Use default factory for not sharing mutable default values (#1364)



* Bug Fix: Use default factory for not sharing mutable default values
---------
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>
Co-authored-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
parent 3102fdd1
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
""" """
Praxis Modules Praxis Modules
""" """
from dataclasses import field
from functools import partial from functools import partial
from typing import Callable, Iterable, Sequence, Tuple, Union from typing import Callable, Iterable, Sequence, Tuple, Union
...@@ -74,7 +75,9 @@ class LayerNorm(TransformerEngineBaseLayer): ...@@ -74,7 +75,9 @@ class LayerNorm(TransformerEngineBaseLayer):
zero_centered_gamma: bool = False zero_centered_gamma: bool = False
scale_init: WeightInit = None scale_init: WeightInit = None
scale_axes: Tuple[str, ...] = () scale_axes: Tuple[str, ...] = ()
bias_init: WeightInit = WeightInit.Constant(0.0) bias_init: WeightInit = field( # pylint: disable=invalid-field-call
default_factory=partial(WeightInit.Constant, scale=0.0)
)
bias_axes: Tuple[str, ...] = () bias_axes: Tuple[str, ...] = ()
transpose_batch_sequence: bool = False transpose_batch_sequence: bool = False
...@@ -129,7 +132,9 @@ class Linear(TransformerEngineBaseLayer): ...@@ -129,7 +132,9 @@ class Linear(TransformerEngineBaseLayer):
out_features: int = 512 out_features: int = 512
kernel_axes: Tuple[str, ...] = () kernel_axes: Tuple[str, ...] = ()
use_bias: bool = True use_bias: bool = True
bias_init: WeightInit = WeightInit.Constant(0.0) bias_init: WeightInit = field( # pylint: disable=invalid-field-call
default_factory=partial(WeightInit.Constant, scale=0.0)
)
bias_axes: Tuple[str, ...] = () bias_axes: Tuple[str, ...] = ()
enable_low_rank_adaptation: bool = False enable_low_rank_adaptation: bool = False
low_rank_adaptation_dim: int = 32 low_rank_adaptation_dim: int = 32
...@@ -174,11 +179,15 @@ class LayerNormLinear(TransformerEngineBaseLayer): ...@@ -174,11 +179,15 @@ class LayerNormLinear(TransformerEngineBaseLayer):
zero_centered_gamma: bool = False zero_centered_gamma: bool = False
scale_init: WeightInit = None scale_init: WeightInit = None
scale_axes: Tuple[str, ...] = () scale_axes: Tuple[str, ...] = ()
ln_bias_init: WeightInit = WeightInit.Constant(1.0) ln_bias_init: WeightInit = field( # pylint: disable=invalid-field-call
default_factory=partial(WeightInit.Constant, scale=1.0)
)
ln_bias_axes: Tuple[str, ...] = () ln_bias_axes: Tuple[str, ...] = ()
kernel_axes: Tuple[str, ...] = () kernel_axes: Tuple[str, ...] = ()
use_bias: bool = False use_bias: bool = False
bias_init: WeightInit = WeightInit.Constant(0.0) bias_init: WeightInit = field( # pylint: disable=invalid-field-call
default_factory=partial(WeightInit.Constant, scale=0.0)
)
bias_axes: Tuple[str, ...] = () bias_axes: Tuple[str, ...] = ()
enable_low_rank_adaptation: bool = False enable_low_rank_adaptation: bool = False
low_rank_adaptation_dim: int = 32 low_rank_adaptation_dim: int = 32
...@@ -237,12 +246,16 @@ class LayerNormMLP(TransformerEngineBaseLayer): ...@@ -237,12 +246,16 @@ class LayerNormMLP(TransformerEngineBaseLayer):
zero_centered_gamma: bool = False zero_centered_gamma: bool = False
scale_init: WeightInit = None scale_init: WeightInit = None
scale_axes: Tuple[str, ...] = () scale_axes: Tuple[str, ...] = ()
ln_bias_init: WeightInit = WeightInit.Constant(1.0) ln_bias_init: WeightInit = field( # pylint: disable=invalid-field-call
default_factory=partial(WeightInit.Constant, scale=1.0)
)
ln_bias_axes: Tuple[str, ...] = () ln_bias_axes: Tuple[str, ...] = ()
kernel_axes_1: Tuple[str, ...] = () kernel_axes_1: Tuple[str, ...] = ()
kernel_axes_2: Tuple[str, ...] = () kernel_axes_2: Tuple[str, ...] = ()
use_bias: bool = False use_bias: bool = False
bias_init: WeightInit = WeightInit.Constant(0.0) bias_init: WeightInit = field( # pylint: disable=invalid-field-call
default_factory=partial(WeightInit.Constant, scale=0.0)
)
bias_axes_1: Tuple[str, ...] = () bias_axes_1: Tuple[str, ...] = ()
bias_axes_2: Tuple[str, ...] = () bias_axes_2: Tuple[str, ...] = ()
enable_low_rank_adaptation: bool = False enable_low_rank_adaptation: bool = False
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
""" """
Praxis Modules related Transformer Praxis Modules related Transformer
""" """
from dataclasses import field
from functools import partial from functools import partial
from typing import Optional, Sequence, Tuple from typing import Optional, Sequence, Tuple
import warnings import warnings
...@@ -138,7 +139,9 @@ class MultiHeadAttention(TransformerEngineBaseLayer): ...@@ -138,7 +139,9 @@ class MultiHeadAttention(TransformerEngineBaseLayer):
zero_centered_gamma: bool = False zero_centered_gamma: bool = False
return_layernorm_output: bool = False return_layernorm_output: bool = False
use_bias: bool = False use_bias: bool = False
bias_init: WeightInit = WeightInit.Constant(0.0) bias_init: WeightInit = field( # pylint: disable=invalid-field-call
default_factory=partial(WeightInit.Constant, scale=0.0)
)
attn_mask_type: str = "causal" attn_mask_type: str = "causal"
attn_bias_type: Optional[str] = None attn_bias_type: Optional[str] = None
enable_rotary_pos_emb: bool = False enable_rotary_pos_emb: bool = False
...@@ -275,7 +278,9 @@ class TransformerLayer(TransformerEngineBaseLayer): ...@@ -275,7 +278,9 @@ class TransformerLayer(TransformerEngineBaseLayer):
dropout_rng_name: str = "dropout" dropout_rng_name: str = "dropout"
mlp_activations: Sequence[str] = ("relu",) mlp_activations: Sequence[str] = ("relu",)
use_bias: bool = False use_bias: bool = False
bias_init: WeightInit = WeightInit.Constant(0.0) bias_init: WeightInit = field( # pylint: disable=invalid-field-call
default_factory=partial(WeightInit.Constant, scale=0.0)
)
apply_residual_connection_post_layernorm: bool = False apply_residual_connection_post_layernorm: bool = False
output_layernorm: bool = False output_layernorm: bool = False
float32_attention_logits: bool = False float32_attention_logits: bool = False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment