Unverified Commit 1a6a6d7b authored by Phuong Nguyen, committed by GitHub

[JAX] Deprecate Praxis layers (#1694)



rm pax/praxis
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 39c0e709
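This commit removes the Praxis (Pax) wrapper layers from the JAX extension; the Flax modules they delegated to remain the supported API. A minimal migration sketch, assuming TransformerLayer is re-exported from the transformer_engine.jax.flax namespace that the deleted wrappers import from (shapes and hyperparameters below are illustrative only):

import jax
import jax.numpy as jnp
from transformer_engine.jax.flax import TransformerLayer  # Flax counterpart of the removed Praxis wrapper

# Encoder block configured with the same knobs the Praxis wrapper used to forward.
layer = TransformerLayer(
    hidden_size=512,
    mlp_hidden_size=2048,
    num_attention_heads=8,
    transpose_batch_sequence=False,  # inputs are [batch, seq, hidden]
)
x = jnp.ones((2, 128, 512), dtype=jnp.float32)
params = layer.init(jax.random.PRNGKey(0), x, attention_mask=None, deterministic=True)
y = layer.apply(params, x, attention_mask=None, deterministic=True)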
@@ -24,7 +24,7 @@ pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"
 : ${XML_LOG_DIR:=/logs}
 mkdir -p "$XML_LOG_DIR"
-python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_not_distributed.xml $TE_PATH/tests/jax -k 'not distributed' --ignore=$TE_PATH/tests/jax/test_praxis_layers.py || test_fail "tests/jax/*not_distributed_*"
+python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_not_distributed.xml $TE_PATH/tests/jax -k 'not distributed' || test_fail "tests/jax/*not_distributed_*"
 # Test without custom calls
 NVTE_JAX_UNITTEST_LEVEL="L2" NVTE_CUSTOM_CALLS_RE="" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_custom_call_compute.xml $TE_PATH/tests/jax/test_custom_call_compute.py || test_fail "test_custom_call_compute.py"
@@ -119,7 +119,6 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
         test_reqs.extend(["numpy", "torchvision", "prettytable", "PyYAML"])
     if "jax" in frameworks:
         install_reqs.extend(["jax", "flax>=0.7.1"])
-        # test_reqs.extend(["numpy", "praxis"])
         test_reqs.extend(["numpy"])
     return [remove_dups(reqs) for reqs in [setup_reqs, install_reqs, test_reqs]]
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Praxis related Modules"""
from .module import FusedSoftmax, LayerNorm
from .module import LayerNormLinear, LayerNormMLP, Linear, TransformerEngineBaseLayer
from .transformer import DotProductAttention, MultiHeadAttention
from .transformer import RelativePositionBiases, TransformerLayer
from ..flax.transformer import TransformerLayerType
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Praxis Modules
"""
from dataclasses import field
from functools import partial
from typing import Callable, Iterable, Sequence, Tuple, Union
from praxis import pax_fiddle
from praxis.base_layer import init_var
from praxis.base_layer import BaseLayer, WeightInit, WeightHParams, WeightHParamsCollection
from praxis.layers import flax_adapter
from praxis.pytypes import JTensor
from ..fp8 import FP8Helper
from ..flax.module import DenseGeneral, LayerNormDenseGeneral
from ..flax.module import LayerNorm as flax_LayerNorm
from ..flax.module import LayerNormMLP as flax_LayerNormMLP
from ..flax.module import Softmax
from ..softmax import SoftmaxType
def _generate_ln_scale_init(scale_init):
if scale_init is not None:
return TransformerEngineBaseLayer.generate_params_init("scale", scale_init)
return scale_init
class TransformerEngineBaseLayer(BaseLayer):
"""TransformerEngineBaseLayer"""
logical_axes_rules: Tuple[Tuple, ...] = None
@staticmethod
def generate_params_init(name: str, initializer: WeightInit):
"""generate_params_init"""
def kernel_init(key, shape, dtype):
wp = WeightHParams(shape=shape, init=initializer, dtype=dtype)
return init_var(wp, key, name)
return kernel_init
def create_layer(self, name, flax_module_cls):
"""create_layer"""
fp8_collection_map = {
FP8Helper.FP8_COLLECTION_NAME: [
WeightHParamsCollection.SKIP_LP_REGULARIZATION,
WeightHParamsCollection.OVERWRITE_WITH_GRADIENT,
WeightHParamsCollection.DISALLOW_BFLOAT16_CONVERSION,
]
}
flax_module_p = pax_fiddle.Config(
flax_adapter.FlaxModuleAdapter,
module_factory_method=flax_module_cls,
logical_axes_rules=self.logical_axes_rules,
var_collection_map=fp8_collection_map,
ici_mesh_shape=self.ici_mesh_shape,
dcn_mesh_shape=self.dcn_mesh_shape,
mesh_axis_names=self.mesh_axis_names,
)
self.create_child(name, flax_module_p.clone())
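# Pattern used by every wrapper below: setup() builds a functools.partial of the
# corresponding transformer_engine.jax.flax module and registers it as a child via
# create_layer(), which wraps it in praxis.layers.flax_adapter.FlaxModuleAdapter.
# The var_collection_map above routes variables in the FP8 collection
# (FP8Helper.FP8_COLLECTION_NAME) into Praxis collections that skip LP
# regularization, are overwritten with their gradients, and are never converted
# to bfloat16.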
class LayerNorm(TransformerEngineBaseLayer):
"""LayerNorm"""
epsilon: float = 1e-6
layernorm_type: str = "layernorm"
zero_centered_gamma: bool = False
scale_init: WeightInit = None
scale_axes: Tuple[str, ...] = ()
bias_init: WeightInit = field( # pylint: disable=invalid-field-call
default_factory=partial(WeightInit.Constant, scale=0.0)
)
bias_axes: Tuple[str, ...] = ()
transpose_batch_sequence: bool = False
def setup(self) -> None:
"""setup"""
super().setup()
ln_cls = partial(
flax_LayerNorm,
epsilon=self.epsilon,
layernorm_type=self.layernorm_type,
zero_centered_gamma=self.zero_centered_gamma,
scale_init=_generate_ln_scale_init(self.scale_init),
scale_axes=self.scale_axes,
bias_init=TransformerEngineBaseLayer.generate_params_init("ln_bias", self.bias_init),
bias_axes=self.bias_axes,
dtype=self.dtype,
transpose_batch_sequence=self.transpose_batch_sequence,
)
self.create_layer("layer_norm", ln_cls)
def __call__(self, x: JTensor) -> JTensor:
"""__call__"""
return self.layer_norm(x)
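# Illustrative Pax-side usage (hypothetical experiment config, not part of this file):
# the wrappers are templated with pax_fiddle.Config like any other Praxis layer, e.g.
#
#   ln_tpl = pax_fiddle.Config(
#       LayerNorm,
#       name="te_layer_norm",
#       layernorm_type="rmsnorm",
#       epsilon=1e-6,
#   )
#
# and instantiated by the surrounding Pax model as a child-layer template.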
class FusedSoftmax(TransformerEngineBaseLayer):
"""FusedSoftmax"""
scale_factor: float = 1.0
softmax_type: SoftmaxType = SoftmaxType.SCALED
def setup(self) -> None:
"""setup"""
super().setup()
fused_softmax_cls = partial(
Softmax, scale_factor=self.scale_factor, softmax_type=self.softmax_type
)
self.create_layer("fused_softmax", fused_softmax_cls)
def __call__(self, x: JTensor, mask: JTensor = None, bias: JTensor = None) -> JTensor:
"""__call__"""
return self.fused_softmax(x, mask, bias)
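# SoftmaxType (imported from ..softmax) selects the fused kernel variant; the enum
# includes values such as SCALED, SCALED_MASKED, and SCALED_UPPER_TRIANG_MASKED
# (names assumed from transformer_engine.jax.softmax).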
class Linear(TransformerEngineBaseLayer):
"""Linear"""
out_features: int = 512
kernel_axes: Tuple[str, ...] = ()
use_bias: bool = True
bias_init: WeightInit = field( # pylint: disable=invalid-field-call
default_factory=partial(WeightInit.Constant, scale=0.0)
)
bias_axes: Tuple[str, ...] = ()
enable_low_rank_adaptation: bool = False
low_rank_adaptation_dim: int = 32
low_rank_adaptation_alpha: float = None
axis: Union[Iterable[int], int] = -1
transpose_batch_sequence: bool = False
def setup(self) -> None:
"""setup"""
super().setup()
dense_general_cls = partial(
DenseGeneral,
features=self.out_features,
kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", self.params_init),
kernel_axes=self.kernel_axes,
use_bias=self.use_bias,
bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init),
bias_axes=self.bias_axes,
enable_low_rank_adaptation=self.enable_low_rank_adaptation,
low_rank_adaptation_dim=self.low_rank_adaptation_dim,
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
axis=self.axis,
dtype=self.dtype,
transpose_batch_sequence=self.transpose_batch_sequence,
)
self.create_layer("linear", dense_general_cls)
def __call__(self, x: JTensor) -> JTensor:
"""__call__"""
return self.linear(x)
class LayerNormLinear(TransformerEngineBaseLayer):
"""LayerNormLinear"""
out_features: int = 512
enable_layernorm: bool = True
layernorm_type: str = "layernorm"
epsilon: float = 1e-6
zero_centered_gamma: bool = False
scale_init: WeightInit = None
scale_axes: Tuple[str, ...] = ()
ln_bias_init: WeightInit = field( # pylint: disable=invalid-field-call
default_factory=partial(WeightInit.Constant, scale=1.0)
)
ln_bias_axes: Tuple[str, ...] = ()
kernel_axes: Tuple[str, ...] = ()
use_bias: bool = False
bias_init: WeightInit = field( # pylint: disable=invalid-field-call
default_factory=partial(WeightInit.Constant, scale=0.0)
)
bias_axes: Tuple[str, ...] = ()
enable_low_rank_adaptation: bool = False
low_rank_adaptation_dim: int = 32
low_rank_adaptation_alpha: float = None
return_layernorm_output: bool = True
axis: Union[Iterable[int], int] = -1
transpose_batch_sequence: bool = False
depth_scaling: float = None
def setup(self) -> None:
"""setup"""
super().setup()
ln_dense_general_cls = partial(
LayerNormDenseGeneral,
features=self.out_features,
enable_layernorm=self.enable_layernorm,
layernorm_type=self.layernorm_type,
epsilon=self.epsilon,
zero_centered_gamma=self.zero_centered_gamma,
scale_init=_generate_ln_scale_init(self.scale_init),
scale_axes=self.scale_axes,
ln_bias_init=TransformerEngineBaseLayer.generate_params_init(
"ln_bias", self.ln_bias_init
),
ln_bias_axes=self.ln_bias_axes,
kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", self.params_init),
kernel_axes=self.kernel_axes,
use_bias=self.use_bias,
bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init),
bias_axes=self.bias_axes,
enable_low_rank_adaptation=self.enable_low_rank_adaptation,
low_rank_adaptation_dim=self.low_rank_adaptation_dim,
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
return_layernorm_output=self.return_layernorm_output,
axis=self.axis,
dtype=self.dtype,
transpose_batch_sequence=self.transpose_batch_sequence,
depth_scaling=self.depth_scaling,
)
self.create_layer("ln_linear", ln_dense_general_cls)
def __call__(self, x: JTensor) -> JTensor:
"""__call__"""
return self.ln_linear(x)
class LayerNormMLP(TransformerEngineBaseLayer):
"""LayerNormMLP"""
intermediate_dim: int = 2048
enable_layernorm: bool = True
layernorm_type: str = "layernorm"
epsilon: float = 1e-6
zero_centered_gamma: bool = False
scale_init: WeightInit = None
scale_axes: Tuple[str, ...] = ()
ln_bias_init: WeightInit = field( # pylint: disable=invalid-field-call
default_factory=partial(WeightInit.Constant, scale=1.0)
)
ln_bias_axes: Tuple[str, ...] = ()
kernel_axes_1: Tuple[str, ...] = ()
kernel_axes_2: Tuple[str, ...] = ()
use_bias: bool = False
bias_init: WeightInit = field( # pylint: disable=invalid-field-call
default_factory=partial(WeightInit.Constant, scale=0.0)
)
bias_axes_1: Tuple[str, ...] = ()
bias_axes_2: Tuple[str, ...] = ()
enable_low_rank_adaptation: bool = False
low_rank_adaptation_dim: int = 32
low_rank_adaptation_alpha: float = None
return_layernorm_output: bool = True
activations: Sequence[Union[str, Callable]] = ("relu",)
intermediate_dropout_rate: float = 0.1
intermediate_hidden_dropout_dims: Sequence[int] = ()
axis: Union[Iterable[int], int] = -1
transpose_batch_sequence: bool = False
def setup(self) -> None:
"""setup"""
super().setup()
ln_mlp_cls = partial(
flax_LayerNormMLP,
intermediate_dim=self.intermediate_dim,
enable_layernorm=self.enable_layernorm,
layernorm_type=self.layernorm_type,
epsilon=self.epsilon,
zero_centered_gamma=self.zero_centered_gamma,
scale_init=_generate_ln_scale_init(self.scale_init),
scale_axes=self.scale_axes,
ln_bias_init=TransformerEngineBaseLayer.generate_params_init(
"ln_bias", self.ln_bias_init
),
ln_bias_axes=self.ln_bias_axes,
kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", self.params_init),
kernel_axes_1=self.kernel_axes_1,
kernel_axes_2=self.kernel_axes_2,
use_bias=self.use_bias,
bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init),
bias_axes_1=self.bias_axes_1,
bias_axes_2=self.bias_axes_2,
enable_low_rank_adaptation=self.enable_low_rank_adaptation,
low_rank_adaptation_dim=self.low_rank_adaptation_dim,
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
return_layernorm_output=self.return_layernorm_output,
activations=self.activations,
intermediate_dropout_rate=self.intermediate_dropout_rate,
intermediate_hidden_dropout_dims=self.intermediate_hidden_dropout_dims,
axis=self.axis,
dtype=self.dtype,
transpose_batch_sequence=self.transpose_batch_sequence,
)
self.create_layer("ln_mlp", ln_mlp_cls)
def __call__(self, x: JTensor, deterministic: bool = False) -> JTensor:
"""__call__"""
return self.ln_mlp(x, deterministic)
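The wrappers in this module map one-to-one onto transformer_engine.jax.flax modules, which is the migration path after this removal. A minimal sketch, assuming DenseGeneral (wrapped by Linear above) is re-exported from transformer_engine.jax.flax as the relative imports suggest; shapes are illustrative:

import jax
import jax.numpy as jnp
from transformer_engine.jax.flax import DenseGeneral  # replaces the Praxis Linear wrapper

dense = DenseGeneral(features=1024)       # "out_features" in the wrapper above
x = jnp.ones((4, 128, 512), jnp.float32)  # [batch, seq, hidden]
params = dense.init(jax.random.PRNGKey(0), x)
y = dense.apply(params, x)                # [4, 128, 1024]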
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Praxis modules related to Transformer layers
"""
from dataclasses import field
from functools import partial
from typing import Optional, Sequence, Tuple
import warnings
from praxis import pax_fiddle
from praxis.base_layer import WeightInit
from praxis.pytypes import JTensor
from .module import TransformerEngineBaseLayer
from ..flax.transformer import TransformerLayerType
from ..flax.transformer import DotProductAttention as flax_DotProductAttention
from ..flax.transformer import MultiHeadAttention as flax_MultiHeadAttention
from ..flax.transformer import RelativePositionBiases as flax_RelativePositionBiases
from ..flax.transformer import TransformerLayer as flax_TransformerLayer
from ..attention import AttnBiasType, AttnMaskType
class RelativePositionBiases(TransformerEngineBaseLayer):
"""RelativePositionBiases"""
num_buckets: int = 32
max_distance: int = 128
num_attention_heads: int = 64
embedding_init: WeightInit = None
embedding_axes: Tuple[str, ...] = ()
@staticmethod
def generate_embedding_init(init, num_attention_heads, num_buckets):
"""generate_embedding_init"""
embedding_init = init
if embedding_init is None:
rb_stddev = (num_attention_heads * num_buckets) ** -0.5
embedding_init = WeightInit.Gaussian(rb_stddev)
return embedding_init
def setup(self) -> None:
"""setup"""
super().setup()
embedding_init = RelativePositionBiases.generate_embedding_init(
self.embedding_init, self.num_attention_heads, self.num_buckets
)
rpb_cls = partial(
flax_RelativePositionBiases,
num_buckets=self.num_buckets,
max_distance=self.max_distance,
num_attention_heads=self.num_attention_heads,
embedding_init=TransformerEngineBaseLayer.generate_params_init(
"rel_embedding", embedding_init
),
embedding_axes=self.embedding_axes,
dtype=self.dtype,
)
self.create_layer("relative_position_bias", rpb_cls)
def __call__(self, q_seqlen: JTensor, k_seqlen: JTensor, bidirectional: bool = True) -> JTensor:
"""__call__"""
return self.relative_position_bias(q_seqlen, k_seqlen, bidirectional)
class DotProductAttention(TransformerEngineBaseLayer):
"""DotProductAttention"""
head_dim: int = 0
num_attention_heads: int = 0
num_gqa_groups: Optional[int] = None
attention_dropout: float = 0.0
attn_mask_type: AttnMaskType = "causal"
attn_bias_type: AttnBiasType = None
dropout_rng_name: str = "dropout"
float32_logits: bool = False
qkv_layout: str = "bshd_bshd_bshd"
scale_factor: Optional[float] = None
transpose_batch_sequence: bool = True
window_size: Optional[Tuple[int, int]] = None
def setup(self) -> None:
"""setup"""
super().setup()
assert self.head_dim > 0, f"{self.head_dim=}"
assert self.num_attention_heads > 0, f"{self.num_attention_heads=}"
dpa_cls = partial(
flax_DotProductAttention,
head_dim=self.head_dim,
num_attention_heads=self.num_attention_heads,
num_gqa_groups=self.num_gqa_groups,
attn_mask_type=self.attn_mask_type,
attn_bias_type=self.attn_bias_type,
attention_dropout=self.attention_dropout,
dtype=self.dtype,
dropout_rng_name=self.dropout_rng_name,
float32_logits=self.float32_logits,
qkv_layout=self.qkv_layout,
scale_factor=self.scale_factor,
transpose_batch_sequence=self.transpose_batch_sequence,
window_size=self.window_size,
)
self.create_layer("dot_product_attention", dpa_cls)
def __call__(
self,
query: JTensor,
key: JTensor,
value: JTensor,
mask: Optional[JTensor] = None,
bias: Optional[JTensor] = None,
*,
deterministic: bool = False,
) -> JTensor:
"""__call__"""
return self.dot_product_attention(
query, key, value, mask, bias, deterministic=deterministic
)
class MultiHeadAttention(TransformerEngineBaseLayer):
"""MultiHeadAttention"""
head_dim: int = 0
num_attention_heads: int = 0
num_gqa_groups: Optional[int] = None
attention_dropout: float = 0.0
dropout_rng_name: str = "dropout"
input_layernorm: bool = True
layernorm_type: str = "layernorm"
layernorm_epsilon: float = 1e-6
zero_centered_gamma: bool = False
return_layernorm_output: bool = False
use_bias: bool = False
bias_init: WeightInit = field( # pylint: disable=invalid-field-call
default_factory=partial(WeightInit.Constant, scale=0.0)
)
attn_mask_type: str = "causal"
attn_bias_type: Optional[str] = None
enable_rotary_pos_emb: bool = False
rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
rotary_pos_emb_group_method: str = "consecutive"
low_rank_adaptation_scope: str = "none"
low_rank_adaptation_dim: int = 32
low_rank_adaptation_alpha: float = None
fuse_qkv_params: bool = True
transpose_batch_sequence: bool = True
enable_sequence_parallel: bool = False
scale_attn_logits: bool = False
scaled_query_init: bool = True
float32_logits: bool = False
window_size: Optional[Tuple[int, int]] = None
# Deprecated parameters
num_heads: Optional[int] = None
dropout_rate: Optional[float] = None
output_layernorm: Optional[bool] = None
apply_residual_connection_post_layernorm: Optional[bool] = None
fuse_qkv: Optional[bool] = None
def __post_init__(self):
# Deal with the deprecated parameters
if self.num_heads is not None:
self.num_attention_heads = self.num_heads
warnings.warn(
f"{__class__}.num_heads is deprecated. It will be removed recently. "
f"Please uses {__class__}.num_attention_heads as the new API.",
DeprecationWarning,
)
if self.dropout_rate is not None:
self.attention_dropout = self.dropout_rate
warnings.warn(
f"{__class__}.dropout_rate is deprecated. It will be removed recently. "
f"Please use {__class__}.attention_dropout as the new API.",
DeprecationWarning,
)
if self.apply_residual_connection_post_layernorm is not None:
warnings.warn(
f"{__class__}.apply_residual_connection_post_layernorm is deprecated. "
f"It will be removed recently, please use {__class__}.return_layernorm_output.",
DeprecationWarning,
)
if self.fuse_qkv is not None:
warnings.warn(
f"{__class__}.fuse_qkv is deprecated. It will be removed recently. "
f"Please use {__class__}.fuse_qkv_params as the new API.",
DeprecationWarning,
)
assert self.output_layernorm is None, (
f"{__class__}.output_layernorm is deprecated. It will be removed recently. "
f"Please use {__class__}.input_layernorm for controlling whether to apply layernorm."
)
if self.num_gqa_groups is None:
self.num_gqa_groups = self.num_heads
super().__post_init__()
def setup(self) -> None:
"""setup"""
super().setup()
assert self.head_dim > 0, f"{self.head_dim=}"
assert self.num_attention_heads > 0, f"{self.num_attention_heads=}"
mha_cls = partial(
flax_MultiHeadAttention,
dtype=self.dtype,
head_dim=self.head_dim,
num_attention_heads=self.num_attention_heads,
num_gqa_groups=self.num_gqa_groups,
attention_dropout=self.attention_dropout,
dropout_rng_name=self.dropout_rng_name,
input_layernorm=self.input_layernorm,
layernorm_type=self.layernorm_type,
layernorm_epsilon=self.layernorm_epsilon,
zero_centered_gamma=self.zero_centered_gamma,
return_layernorm_output=self.return_layernorm_output,
kernel_init=TransformerEngineBaseLayer.generate_params_init("kernel", self.params_init),
use_bias=self.use_bias,
bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init),
attn_mask_type=self.attn_mask_type,
attn_bias_type=self.attn_bias_type,
enable_rotary_pos_emb=self.enable_rotary_pos_emb,
rotary_pos_emb_windows=self.rotary_pos_emb_windows,
rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
low_rank_adaptation_scope=self.low_rank_adaptation_scope,
low_rank_adaptation_dim=self.low_rank_adaptation_dim,
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
fuse_qkv_params=self.fuse_qkv_params,
transpose_batch_sequence=self.transpose_batch_sequence,
enable_sequence_parallel=self.enable_sequence_parallel,
scale_attn_logits=self.scale_attn_logits,
scaled_query_init=self.scaled_query_init,
float32_logits=self.float32_logits,
window_size=self.window_size,
)
self.create_layer("multi_head_attn", mha_cls)
def __call__(
self,
inputs_q: JTensor,
inputs_kv: JTensor,
mask: Optional[JTensor] = None,
bias: Optional[JTensor] = None,
*,
decode: bool = False,
deterministic: bool = False,
) -> JTensor:
"""__call__"""
return self.multi_head_attn(
inputs_q, inputs_kv, mask, bias, decode=decode, deterministic=deterministic
)
class TransformerLayer(TransformerEngineBaseLayer):
"""TransformerLayer"""
hidden_size: int = 512
mlp_hidden_size: int = 2048
num_attention_heads: int = 8
num_gqa_groups: Optional[int] = None
layernorm_type: str = "layernorm"
layernorm_epsilon: float = 1e-6
zero_centered_gamma: bool = False
hidden_dropout: float = 0.1
hidden_dropout_dims: Sequence[int] = ()
attention_dropout: float = 0.1
intermediate_dropout: float = 0.1
intermediate_dropout_dims: Sequence[int] = ()
dropout_rng_name: str = "dropout"
mlp_activations: Sequence[str] = ("relu",)
use_bias: bool = False
bias_init: WeightInit = field( # pylint: disable=invalid-field-call
default_factory=partial(WeightInit.Constant, scale=0.0)
)
apply_residual_connection_post_layernorm: bool = False
output_layernorm: bool = False
float32_attention_logits: bool = False
layer_type: TransformerLayerType = TransformerLayerType.ENCODER
self_attn_mask_type: str = "causal"
self_attn_bias_type: Optional[str] = None
enable_rotary_pos_emb: bool = False
rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
rotary_pos_emb_group_method: str = "consecutive"
low_rank_adaptation_scope: str = "none"
low_rank_adaptation_dim: int = 32
low_rank_adaptation_alpha: float = None
enable_relative_embedding: bool = True
relative_embedding: pax_fiddle.Config[RelativePositionBiases] = pax_fiddle.template_field(None)
drop_path: float = 0.0
fuse_qkv_params: bool = True
transpose_batch_sequence: bool = False
enable_sequence_parallel: bool = False
scale_attn_logits: bool = False
scaled_query_init: bool = True
window_size: Optional[Tuple[int, int]] = None
def __post_init__(self):
if self.num_gqa_groups is None:
self.num_gqa_groups = self.num_attention_heads
super().__post_init__()
def setup(self) -> None:
"""setup"""
super().setup()
relative_embedding_flax_module = None
if self.enable_relative_embedding and self.relative_embedding is not None:
assert self.relative_embedding.num_attention_heads == self.num_attention_heads, (
"TransformerLayer.relative_embedding.num_attention_heads shoule be"
"the same as TransformerLayer.num_attention_heads."
)
embedding_init = RelativePositionBiases.generate_embedding_init(
self.relative_embedding.embedding_init,
self.relative_embedding.num_attention_heads,
self.relative_embedding.num_buckets,
)
relative_embedding_flax_module = flax_RelativePositionBiases(
num_buckets=self.relative_embedding.num_buckets,
max_distance=self.relative_embedding.max_distance,
num_attention_heads=self.relative_embedding.num_attention_heads,
embedding_init=TransformerEngineBaseLayer.generate_params_init(
"rel_embedding", embedding_init
),
embedding_axes=self.relative_embedding.embedding_axes,
dtype=self.relative_embedding.dtype,
)
transformerlayer_cls = partial(
flax_TransformerLayer,
dtype=self.dtype,
hidden_size=self.hidden_size,
mlp_hidden_size=self.mlp_hidden_size,
num_attention_heads=self.num_attention_heads,
num_gqa_groups=self.num_gqa_groups,
layernorm_type=self.layernorm_type,
layernorm_epsilon=self.layernorm_epsilon,
zero_centered_gamma=self.zero_centered_gamma,
hidden_dropout=self.hidden_dropout,
hidden_dropout_dims=self.hidden_dropout_dims,
attention_dropout=self.attention_dropout,
intermediate_dropout=self.intermediate_dropout,
intermediate_dropout_dims=self.intermediate_dropout_dims,
dropout_rng_name=self.dropout_rng_name,
mha_kernel_init=TransformerEngineBaseLayer.generate_params_init(
"mha_kernel", self.params_init
),
mlp_kernel_init=TransformerEngineBaseLayer.generate_params_init(
"mlp_kernel", self.params_init
),
mlp_activations=self.mlp_activations,
use_bias=self.use_bias,
bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init),
apply_residual_connection_post_layernorm=self.apply_residual_connection_post_layernorm,
output_layernorm=self.output_layernorm,
float32_attention_logits=self.float32_attention_logits,
layer_type=self.layer_type,
self_attn_mask_type=self.self_attn_mask_type,
self_attn_bias_type=self.self_attn_bias_type,
enable_rotary_pos_emb=self.enable_rotary_pos_emb,
rotary_pos_emb_windows=self.rotary_pos_emb_windows,
rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
low_rank_adaptation_scope=self.low_rank_adaptation_scope,
low_rank_adaptation_dim=self.low_rank_adaptation_dim,
low_rank_adaptation_alpha=self.low_rank_adaptation_alpha,
enable_relative_embedding=self.enable_relative_embedding,
relative_embedding=relative_embedding_flax_module,
drop_path=self.drop_path,
fuse_qkv_params=self.fuse_qkv_params,
transpose_batch_sequence=self.transpose_batch_sequence,
enable_sequence_parallel=self.enable_sequence_parallel,
scale_attn_logits=self.scale_attn_logits,
scaled_query_init=self.scaled_query_init,
window_size=self.window_size,
)
self.create_layer("transformerlayer", transformerlayer_cls)
def __call__(
self,
inputs: JTensor,
encoded: JTensor = None,
attention_mask: JTensor = None,
encoder_decoder_mask: JTensor = None,
deterministic: bool = False,
decode: bool = False,
max_decode_length: bool = None,
) -> JTensor:
"""__call__"""
return self.transformerlayer(
inputs,
encoded,
attention_mask,
encoder_decoder_mask,
deterministic,
decode,
max_decode_length,
)
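For reference, a sketch of how the removed TransformerLayer wrapper was typically templated inside a Pax experiment (hypothetical names and sizes; after this commit the equivalent transformer_engine.jax.flax.TransformerLayer shown earlier is the supported entry point):

from praxis import pax_fiddle
# TransformerLayer / TransformerLayerType as defined in the (now removed) file above.

xformer_tpl = pax_fiddle.Config(
    TransformerLayer,
    name="te_transformer",
    hidden_size=1024,
    mlp_hidden_size=4096,
    num_attention_heads=16,
    layer_type=TransformerLayerType.ENCODER,
    transpose_batch_sequence=False,
)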
@@ -101,7 +101,7 @@ if __name__ == "__main__":
         ext_modules=ext_modules,
         cmdclass={"build_ext": CMakeBuildExtension},
         install_requires=["jax", "flax>=0.7.1"],
-        tests_require=["numpy", "praxis"],
+        tests_require=["numpy"],
     )
     if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")):
         shutil.rmtree(common_headers_dir)