Unverified Commit 0e137883 authored by Reese Wang's avatar Reese Wang Committed by GitHub
Browse files

[JAX] FFI API compatibility with both 0.4 and 0.5 (#1562)



Make ffi compatible with jax 0.4
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>
Co-authored-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
parent 31f32b37
...@@ -5,13 +5,13 @@ ...@@ -5,13 +5,13 @@
from typing import Tuple, Sequence, Union, Callable from typing import Tuple, Sequence, Union, Callable
import operator import operator
from functools import reduce, partial from functools import reduce, partial
from packaging import version
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
from jax import dtypes from jax import dtypes
from jax.interpreters.mlir import ir from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding from jax.sharding import PartitionSpec, NamedSharding
from jax import ffi
import transformer_engine_jax import transformer_engine_jax
from transformer_engine_jax import NVTE_Activation_Type from transformer_engine_jax import NVTE_Activation_Type
...@@ -28,6 +28,11 @@ from .misc import ( ...@@ -28,6 +28,11 @@ from .misc import (
from .quantization import _jax_cast_fp8 from .quantization import _jax_cast_fp8
from ..sharding import all_reduce_max_along_all_axes_except_PP from ..sharding import all_reduce_max_along_all_axes_except_PP
if version.parse(jax.__version__) >= version.parse("0.5.0"):
from jax import ffi # pylint: disable=ungrouped-imports
else:
from jax.extend import ffi # pylint: disable=ungrouped-imports
__all__ = ["act_lu", "dact_lu", "act_lu_fp8"] __all__ = ["act_lu", "dact_lu", "act_lu_fp8"]
......
...@@ -2,12 +2,13 @@ ...@@ -2,12 +2,13 @@
# #
# See LICENSE for license information. # See LICENSE for license information.
"""JAX/TE custom ops for attention""" """JAX/TE custom ops for attention"""
from dataclasses import dataclass, replace
from functools import partial, reduce
import operator import operator
import os import os
from typing import Optional, Tuple
import warnings import warnings
from dataclasses import dataclass, replace
from functools import partial, reduce
from typing import Optional, Tuple
from packaging import version
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
...@@ -15,8 +16,6 @@ from jax import dtypes, lax ...@@ -15,8 +16,6 @@ from jax import dtypes, lax
from jax.interpreters import mlir from jax.interpreters import mlir
from jax.interpreters.mlir import ir from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding from jax.sharding import PartitionSpec, NamedSharding
from jax import ffi
import transformer_engine_jax import transformer_engine_jax
from transformer_engine_jax import NVTE_Fused_Attn_Backend from transformer_engine_jax import NVTE_Fused_Attn_Backend
...@@ -50,6 +49,12 @@ from ..sharding import ( ...@@ -50,6 +49,12 @@ from ..sharding import (
) )
if version.parse(jax.__version__) >= version.parse("0.5.0"):
from jax import ffi # pylint: disable=ungrouped-imports
else:
from jax.extend import ffi # pylint: disable=ungrouped-imports
__all__ = [ __all__ = [
"FusedAttnHelper", "FusedAttnHelper",
"fused_attn_fwd", "fused_attn_fwd",
......
...@@ -4,13 +4,19 @@ ...@@ -4,13 +4,19 @@
"""JAX/TE custom call""" """JAX/TE custom call"""
from dataclasses import dataclass from dataclasses import dataclass
from enum import IntEnum from enum import IntEnum
from packaging import version
import jax import jax
from jax.interpreters import mlir from jax.interpreters import mlir
import transformer_engine_jax
import transformer_engine_jax
from .misc import is_ffi_enabled from .misc import is_ffi_enabled
if version.parse(jax.__version__) >= version.parse("0.5.0"):
from jax import ffi # pylint: disable=ungrouped-imports
else:
from jax.extend import ffi # pylint: disable=ungrouped-imports
try: try:
from jaxlib.hlo_helpers import custom_call from jaxlib.hlo_helpers import custom_call
except ImportError: except ImportError:
...@@ -29,11 +35,11 @@ class CustomCallAPIVersion(IntEnum): ...@@ -29,11 +35,11 @@ class CustomCallAPIVersion(IntEnum):
for _name, _value in transformer_engine_jax.registrations().items(): for _name, _value in transformer_engine_jax.registrations().items():
if _name.endswith("_ffi"): if _name.endswith("_ffi"):
if is_ffi_enabled(): if is_ffi_enabled():
jax.ffi.register_ffi_target( ffi.register_ffi_target(
_name, _value, platform="CUDA", api_version=CustomCallAPIVersion.FFI.value _name, _value, platform="CUDA", api_version=CustomCallAPIVersion.FFI.value
) )
else: else:
jax.ffi.register_ffi_target( ffi.register_ffi_target(
_name, _value, platform="CUDA", api_version=CustomCallAPIVersion.OPAQUE.value _name, _value, platform="CUDA", api_version=CustomCallAPIVersion.OPAQUE.value
) )
......
...@@ -2,10 +2,11 @@ ...@@ -2,10 +2,11 @@
# #
# See LICENSE for license information. # See LICENSE for license information.
"""JAX/TE custom ops for normalization""" """JAX/TE custom ops for normalization"""
from functools import partial, reduce, cache
import operator import operator
import os import os
import warnings import warnings
from functools import partial, reduce, cache
from packaging import version
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
...@@ -13,7 +14,6 @@ from jax import dtypes ...@@ -13,7 +14,6 @@ from jax import dtypes
from jax.interpreters import mlir from jax.interpreters import mlir
from jax.interpreters.mlir import ir from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding from jax.sharding import PartitionSpec, NamedSharding
from jax import ffi
import transformer_engine_jax import transformer_engine_jax
...@@ -30,6 +30,11 @@ from .misc import ( ...@@ -30,6 +30,11 @@ from .misc import (
from .quantization import _jax_cast_fp8 from .quantization import _jax_cast_fp8
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
if version.parse(jax.__version__) >= version.parse("0.5.0"):
from jax import ffi # pylint: disable=ungrouped-imports
else:
from jax.extend import ffi # pylint: disable=ungrouped-imports
__all__ = [ __all__ = [
"layernorm_fwd", "layernorm_fwd",
......
...@@ -3,13 +3,13 @@ ...@@ -3,13 +3,13 @@
# See LICENSE for license information. # See LICENSE for license information.
"""JAX/TE custom ops for quantization""" """JAX/TE custom ops for quantization"""
from typing import Tuple from typing import Tuple
from packaging import version
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
from jax import dtypes from jax import dtypes
from jax.interpreters.mlir import ir from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding from jax.sharding import PartitionSpec, NamedSharding
from jax import ffi
import transformer_engine_jax import transformer_engine_jax
from transformer_engine_jax import DType as TEDType from transformer_engine_jax import DType as TEDType
...@@ -25,6 +25,11 @@ from .misc import ( ...@@ -25,6 +25,11 @@ from .misc import (
) )
from ..sharding import all_reduce_max_along_all_axes_except_PP from ..sharding import all_reduce_max_along_all_axes_except_PP
if version.parse(jax.__version__) >= version.parse("0.5.0"):
from jax import ffi # pylint: disable=ungrouped-imports
else:
from jax.extend import ffi # pylint: disable=ungrouped-imports
__all__ = ["cast_fp8"] __all__ = ["cast_fp8"]
......
...@@ -6,13 +6,13 @@ from abc import abstractmethod ...@@ -6,13 +6,13 @@ from abc import abstractmethod
from functools import partial, reduce from functools import partial, reduce
import operator import operator
import warnings import warnings
from packaging import version
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
from jax import dtypes from jax import dtypes
from jax.interpreters.mlir import ir from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding from jax.sharding import PartitionSpec, NamedSharding
from jax import ffi
import transformer_engine_jax import transformer_engine_jax
...@@ -21,6 +21,11 @@ from .custom_call import custom_caller, CustomCallArgsWrapper ...@@ -21,6 +21,11 @@ from .custom_call import custom_caller, CustomCallArgsWrapper
from .misc import get_padded_spec, check_valid_batch_dims, jax_dtype_to_te_dtype, is_ffi_enabled from .misc import get_padded_spec, check_valid_batch_dims, jax_dtype_to_te_dtype, is_ffi_enabled
from ..softmax import SoftmaxType from ..softmax import SoftmaxType
if version.parse(jax.__version__) >= version.parse("0.5.0"):
from jax import ffi # pylint: disable=ungrouped-imports
else:
from jax.extend import ffi # pylint: disable=ungrouped-imports
__all__ = [ __all__ = [
"scaled_softmax_fwd", "scaled_softmax_fwd",
......
...@@ -2,16 +2,16 @@ ...@@ -2,16 +2,16 @@
# #
# See LICENSE for license information. # See LICENSE for license information.
"""JAX/TE custom ops for transpose""" """JAX/TE custom ops for transpose"""
import operator
from functools import partial, reduce from functools import partial, reduce
from typing import Tuple, Sequence, Union, Callable from typing import Tuple, Sequence, Union, Callable
import operator from packaging import version
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
from jax import dtypes from jax import dtypes
from jax.interpreters.mlir import ir from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding from jax.sharding import PartitionSpec, NamedSharding
from jax import ffi
import transformer_engine_jax import transformer_engine_jax
from transformer_engine_jax import DType as TEDType from transformer_engine_jax import DType as TEDType
...@@ -33,6 +33,11 @@ from .activation import _jax_act_lu ...@@ -33,6 +33,11 @@ from .activation import _jax_act_lu
from .quantization import _jax_cast_fp8 from .quantization import _jax_cast_fp8
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
if version.parse(jax.__version__) >= version.parse("0.5.0"):
from jax import ffi # pylint: disable=ungrouped-imports
else:
from jax.extend import ffi # pylint: disable=ungrouped-imports
__all__ = [ __all__ = [
"transpose", "transpose",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment