..
    Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

    See LICENSE for license information.

Jax
=======

Predefined Logical Axis Variables
------------------------------------
The following variables are defined in `transformer_engine.jax.sharding`:

* BATCH_AXES: The logical axis of the batch dimension. It is usually sharded along the DP and FSDP axes of the mesh.
* SEQLEN_AXES: The logical axis of the sequence-length dimension. It is usually not sharded.
* SEQLEN_TP_AXES: The logical axis of the sequence-length dimension. It is usually sharded along the TP axis of the mesh.
* HEAD_AXES: The logical axis of the head dimension in multi-head attention (MHA). It is usually sharded along the TP axis of the mesh.
* HIDDEN_AXES: The logical axis of the hidden dimension. It is usually not sharded.
* HIDDEN_TP_AXES: The logical axis of the hidden dimension. It is usually sharded along the TP axis of the mesh.
* JOINED_AXES: The logical axis of a dimension that has no dedicated logical axis of its own. It is usually not sharded.
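
The usual shardings listed above can be pictured as a table that maps each logical axis to zero or more mesh axes. The following is a minimal pure-JAX sketch under stated assumptions: the mesh axis names ``dp``, ``fsdp``, and ``tp``, the ``USUAL_RULES`` table, and the ``spec_for`` helper are hypothetical, and the dictionary keys are plain strings standing in for the variables of the same name, not the values Transformer Engine actually defines. Mapping ``head_dim`` to ``HIDDEN_AXES`` is likewise an illustrative assumption. The sketch does not use Transformer Engine's own helpers such as ``extend_logical_axis_rules``.

.. code-block:: python

    import jax
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec

    # Hypothetical mesh axis names; rename them to match your own mesh.
    DP, FSDP, TP = "dp", "fsdp", "tp"

    # Hypothetical logical-to-mesh table mirroring the usual sharding listed above.
    # Keys are plain strings standing in for the transformer_engine.jax.sharding
    # variables of the same name; they are NOT the values Transformer Engine defines.
    USUAL_RULES = {
        "BATCH_AXES": (DP, FSDP),  # sharded along DP + FSDP
        "SEQLEN_AXES": None,       # not sharded
        "SEQLEN_TP_AXES": TP,      # sharded along TP
        "HEAD_AXES": TP,           # sharded along TP
        "HIDDEN_AXES": None,       # not sharded
        "HIDDEN_TP_AXES": TP,      # sharded along TP
        "JOINED_AXES": None,       # not sharded
    }


    def spec_for(*logical_axes):
        """Translate logical axis names (one per array dimension) into a PartitionSpec."""
        return PartitionSpec(*(USUAL_RULES[name] for name in logical_axes))


    # Illustrative layouts: a (batch, seqlen, hidden) activation and a
    # (batch, seqlen, heads, head_dim) attention tensor.
    activation_spec = spec_for("BATCH_AXES", "SEQLEN_AXES", "HIDDEN_AXES")
    attn_spec = spec_for("BATCH_AXES", "SEQLEN_AXES", "HEAD_AXES", "HIDDEN_AXES")

    # A 1x1x1 mesh keeps the sketch runnable on a single device.
    devices = np.array(jax.devices()[:1]).reshape(1, 1, 1)
    mesh = Mesh(devices, (DP, FSDP, TP))
    x = jax.device_put(np.zeros((8, 16, 32), np.float32), NamedSharding(mesh, activation_spec))

    print(activation_spec)
    print(attn_spec)

On a real multi-device mesh, the same specs would split the batch dimension across the ``dp`` and ``fsdp`` axes and the head dimension across ``tp``, while leaving the sequence-length and hidden dimensions replicated.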


Checkpointing
------------------------------------
When using activation checkpointing (`jax.checkpoint`, also known as `jax.remat`) with Transformer Engine JAX, check which checkpointing policy is applied to your model. JAX policies that select values to save by matching the `dot_general` primitive, such as `jax.checkpoint_policies.dots_with_no_batch_dims_saveable`, may not recognize GEMMs provided by Transformer Engine, because those GEMMs do not always use the `jax.lax.dot_general` primitive; their outputs may then be recomputed in the backward pass instead of saved.

Instead, use `transformer_engine.jax.checkpoint_policies.dots_and_te_gemms_with_no_batch_dims` or a similar policy designed to handle both Transformer Engine GEMMs and `jax.lax.dot_general` GEMMs. Any JAX policy that does not filter by primitive, such as `jax.checkpoint_policies.save_only_these_names` or `jax.checkpoint_policies.everything_saveable`, also works.
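
A name-based policy is the simplest of these options to illustrate, because it decides what to save by a user-assigned name rather than by which primitive produced a value. The following minimal sketch uses only JAX; the function ``mlp_block``, the name ``"mlp_hidden"``, and all shapes are hypothetical. The comment marks where ``transformer_engine.jax.checkpoint_policies.dots_and_te_gemms_with_no_batch_dims`` could be passed instead, assuming it is supplied to ``jax.checkpoint`` the same way as a JAX policy.

.. code-block:: python

    import jax
    import jax.numpy as jnp
    from jax.ad_checkpoint import checkpoint_name


    def mlp_block(w1, w2, x):
        # Tag the intermediate so a name-based policy can choose to keep it.
        h = checkpoint_name(jnp.tanh(x @ w1), "mlp_hidden")
        return jnp.sum(h @ w2)


    # Saves only values tagged "mlp_hidden", regardless of whether they came from
    # jax.lax.dot_general or from another primitive; everything else is recomputed.
    policy = jax.checkpoint_policies.save_only_these_names("mlp_hidden")
    # For models built on Transformer Engine GEMMs, the text above suggests
    # transformer_engine.jax.checkpoint_policies.dots_and_te_gemms_with_no_batch_dims
    # as a primitive-aware alternative (not imported here).
    remat_block = jax.checkpoint(mlp_block, policy=policy)

    key1, key2, key3 = jax.random.split(jax.random.PRNGKey(0), 3)
    w1 = jax.random.normal(key1, (16, 32))
    w2 = jax.random.normal(key2, (32, 8))
    x = jax.random.normal(key3, (4, 16))

    # Gradients with and without rematerialization.
    grads = jax.grad(remat_block, argnums=(0, 1))(w1, w2, x)
    ref = jax.grad(mlp_block, argnums=(0, 1))(w1, w2, x)

The policy only changes which residuals are kept for the backward pass, trading memory for recomputation; the gradients themselves should match those of the un-checkpointed function.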

Modules
------------------------------------
.. autoapiclass:: transformer_engine.jax.flax.TransformerLayerType
.. autoapiclass:: transformer_engine.jax.MeshResource()


.. autoapifunction:: transformer_engine.jax.fp8_autocast
.. autoapifunction:: transformer_engine.jax.autocast
.. autoapifunction:: transformer_engine.jax.update_collections


.. autoapiclass:: transformer_engine.jax.flax.LayerNorm(epsilon=1e-6, layernorm_type='layernorm', **kwargs)
  :members: __call__

.. autoapiclass:: transformer_engine.jax.flax.DenseGeneral(features, layernorm_type='layernorm', use_bias=False, **kwargs)
  :members: __call__

.. autoapiclass:: transformer_engine.jax.flax.LayerNormDenseGeneral(features, layernorm_type='layernorm', epsilon=1e-6, use_bias=False, **kwargs)
  :members: __call__

.. autoapiclass:: transformer_engine.jax.flax.LayerNormMLP(intermediate_dim=2048, layernorm_type='layernorm', epsilon=1e-6, use_bias=False, **kwargs)
  :members: __call__

.. autoapiclass:: transformer_engine.jax.flax.RelativePositionBiases(num_buckets, max_distance, num_heads, **kwargs)
  :members: __call__

.. autoapiclass:: transformer_engine.jax.flax.DotProductAttention(head_dim, num_heads, **kwargs)
  :members: __call__

.. autoapiclass:: transformer_engine.jax.flax.MultiHeadAttention(head_dim, num_heads, **kwargs)
  :members: __call__

.. autoapiclass:: transformer_engine.jax.flax.TransformerLayer(hidden_size=512, mlp_hidden_size=2048, num_attention_heads=8, **kwargs)
  :members: __call__

.. autoapifunction:: transformer_engine.jax.flax.extend_logical_axis_rules