Unverified Commit e4c99b03 authored by Reese Wang, committed by GitHub

[JAX] Use default factory to avoid sharing mutable default values (#1364)



* Bug Fix: Use default factory to avoid sharing mutable default values
---------
Signed-off-by: Reese Wang <rewang@nvidia.com>
Co-authored-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent 3102fdd1
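
For context, here is a minimal, self-contained sketch of the bug class this commit fixes. `Init` is a hypothetical stand-in for praxis's `WeightInit`, used only so the snippet runs without praxis installed:

from dataclasses import dataclass, field
from functools import partial


class Init:
    """Hypothetical stand-in for praxis's WeightInit: a mutable config object."""

    def __init__(self, scale: float = 0.0):
        self.scale = scale


@dataclass
class Bad:
    # The old pattern: one Init instance is created when the class body is
    # evaluated, and that single object is shared by every Bad() instance.
    bias_init: Init = Init(0.0)


@dataclass
class Good:
    # The fixed pattern: default_factory builds a fresh Init per instance.
    bias_init: Init = field(default_factory=partial(Init, scale=0.0))


a, b = Bad(), Bad()
a.bias_init.scale = 1.0
assert b.bias_init.scale == 1.0  # mutation leaks across instances

c, d = Good(), Good()
c.bias_init.scale = 1.0
assert d.bias_init.scale == 0.0  # each instance owns its own default
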
@@ -4,6 +4,7 @@
 """
 Praxis Modules
 """
+from dataclasses import field
 from functools import partial
 from typing import Callable, Iterable, Sequence, Tuple, Union
@@ -74,7 +75,9 @@ class LayerNorm(TransformerEngineBaseLayer):
     zero_centered_gamma: bool = False
     scale_init: WeightInit = None
     scale_axes: Tuple[str, ...] = ()
-    bias_init: WeightInit = WeightInit.Constant(0.0)
+    bias_init: WeightInit = field(  # pylint: disable=invalid-field-call
+        default_factory=partial(WeightInit.Constant, scale=0.0)
+    )
     bias_axes: Tuple[str, ...] = ()
     transpose_batch_sequence: bool = False
@@ -129,7 +132,9 @@ class Linear(TransformerEngineBaseLayer):
     out_features: int = 512
     kernel_axes: Tuple[str, ...] = ()
     use_bias: bool = True
-    bias_init: WeightInit = WeightInit.Constant(0.0)
+    bias_init: WeightInit = field(  # pylint: disable=invalid-field-call
+        default_factory=partial(WeightInit.Constant, scale=0.0)
+    )
     bias_axes: Tuple[str, ...] = ()
     enable_low_rank_adaptation: bool = False
     low_rank_adaptation_dim: int = 32
@@ -174,11 +179,15 @@ class LayerNormLinear(TransformerEngineBaseLayer):
     zero_centered_gamma: bool = False
     scale_init: WeightInit = None
     scale_axes: Tuple[str, ...] = ()
-    ln_bias_init: WeightInit = WeightInit.Constant(1.0)
+    ln_bias_init: WeightInit = field(  # pylint: disable=invalid-field-call
+        default_factory=partial(WeightInit.Constant, scale=1.0)
+    )
     ln_bias_axes: Tuple[str, ...] = ()
     kernel_axes: Tuple[str, ...] = ()
     use_bias: bool = False
-    bias_init: WeightInit = WeightInit.Constant(0.0)
+    bias_init: WeightInit = field(  # pylint: disable=invalid-field-call
+        default_factory=partial(WeightInit.Constant, scale=0.0)
+    )
     bias_axes: Tuple[str, ...] = ()
     enable_low_rank_adaptation: bool = False
     low_rank_adaptation_dim: int = 32
@@ -237,12 +246,16 @@ class LayerNormMLP(TransformerEngineBaseLayer):
     zero_centered_gamma: bool = False
     scale_init: WeightInit = None
     scale_axes: Tuple[str, ...] = ()
-    ln_bias_init: WeightInit = WeightInit.Constant(1.0)
+    ln_bias_init: WeightInit = field(  # pylint: disable=invalid-field-call
+        default_factory=partial(WeightInit.Constant, scale=1.0)
+    )
     ln_bias_axes: Tuple[str, ...] = ()
     kernel_axes_1: Tuple[str, ...] = ()
     kernel_axes_2: Tuple[str, ...] = ()
     use_bias: bool = False
-    bias_init: WeightInit = WeightInit.Constant(0.0)
+    bias_init: WeightInit = field(  # pylint: disable=invalid-field-call
+        default_factory=partial(WeightInit.Constant, scale=0.0)
+    )
     bias_axes_1: Tuple[str, ...] = ()
     bias_axes_2: Tuple[str, ...] = ()
     enable_low_rank_adaptation: bool = False
@@ -4,6 +4,7 @@
 """
 Praxis Modules related Transformer
 """
+from dataclasses import field
 from functools import partial
 from typing import Optional, Sequence, Tuple
 import warnings
@@ -138,7 +139,9 @@ class MultiHeadAttention(TransformerEngineBaseLayer):
     zero_centered_gamma: bool = False
     return_layernorm_output: bool = False
     use_bias: bool = False
-    bias_init: WeightInit = WeightInit.Constant(0.0)
+    bias_init: WeightInit = field(  # pylint: disable=invalid-field-call
+        default_factory=partial(WeightInit.Constant, scale=0.0)
+    )
     attn_mask_type: str = "causal"
     attn_bias_type: Optional[str] = None
     enable_rotary_pos_emb: bool = False
@@ -275,7 +278,9 @@ class TransformerLayer(TransformerEngineBaseLayer):
     dropout_rng_name: str = "dropout"
     mlp_activations: Sequence[str] = ("relu",)
     use_bias: bool = False
-    bias_init: WeightInit = WeightInit.Constant(0.0)
+    bias_init: WeightInit = field(  # pylint: disable=invalid-field-call
+        default_factory=partial(WeightInit.Constant, scale=0.0)
+    )
     apply_residual_connection_post_layernorm: bool = False
     output_layernorm: bool = False
     float32_attention_logits: bool = False