Unverified Commit 0e137883 authored by Reese Wang's avatar Reese Wang Committed by GitHub
Browse files

[JAX] FFI API compatibility with both 0.4 and 0.5 (#1562)



Make ffi compatible with jax 0.4
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>
Co-authored-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
parent 31f32b37
......@@ -5,13 +5,13 @@
from typing import Tuple, Sequence, Union, Callable
import operator
from functools import reduce, partial
from packaging import version
import jax
import jax.numpy as jnp
from jax import dtypes
from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding
from jax import ffi
import transformer_engine_jax
from transformer_engine_jax import NVTE_Activation_Type
......@@ -28,6 +28,11 @@ from .misc import (
from .quantization import _jax_cast_fp8
from ..sharding import all_reduce_max_along_all_axes_except_PP
if version.parse(jax.__version__) >= version.parse("0.5.0"):
from jax import ffi # pylint: disable=ungrouped-imports
else:
from jax.extend import ffi # pylint: disable=ungrouped-imports
__all__ = ["act_lu", "dact_lu", "act_lu_fp8"]
......
......@@ -2,12 +2,13 @@
#
# See LICENSE for license information.
"""JAX/TE custom ops for attention"""
from dataclasses import dataclass, replace
from functools import partial, reduce
import operator
import os
from typing import Optional, Tuple
import warnings
from dataclasses import dataclass, replace
from functools import partial, reduce
from typing import Optional, Tuple
from packaging import version
import jax
import jax.numpy as jnp
......@@ -15,8 +16,6 @@ from jax import dtypes, lax
from jax.interpreters import mlir
from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding
from jax import ffi
import transformer_engine_jax
from transformer_engine_jax import NVTE_Fused_Attn_Backend
......@@ -50,6 +49,12 @@ from ..sharding import (
)
if version.parse(jax.__version__) >= version.parse("0.5.0"):
from jax import ffi # pylint: disable=ungrouped-imports
else:
from jax.extend import ffi # pylint: disable=ungrouped-imports
__all__ = [
"FusedAttnHelper",
"fused_attn_fwd",
......
......@@ -4,13 +4,19 @@
"""JAX/TE custom call"""
from dataclasses import dataclass
from enum import IntEnum
from packaging import version
import jax
from jax.interpreters import mlir
import transformer_engine_jax
import transformer_engine_jax
from .misc import is_ffi_enabled
if version.parse(jax.__version__) >= version.parse("0.5.0"):
from jax import ffi # pylint: disable=ungrouped-imports
else:
from jax.extend import ffi # pylint: disable=ungrouped-imports
try:
from jaxlib.hlo_helpers import custom_call
except ImportError:
......@@ -29,11 +35,11 @@ class CustomCallAPIVersion(IntEnum):
for _name, _value in transformer_engine_jax.registrations().items():
if _name.endswith("_ffi"):
if is_ffi_enabled():
jax.ffi.register_ffi_target(
ffi.register_ffi_target(
_name, _value, platform="CUDA", api_version=CustomCallAPIVersion.FFI.value
)
else:
jax.ffi.register_ffi_target(
ffi.register_ffi_target(
_name, _value, platform="CUDA", api_version=CustomCallAPIVersion.OPAQUE.value
)
......
......@@ -2,10 +2,11 @@
#
# See LICENSE for license information.
"""JAX/TE custom ops for normalization"""
from functools import partial, reduce, cache
import operator
import os
import warnings
from functools import partial, reduce, cache
from packaging import version
import jax
import jax.numpy as jnp
......@@ -13,7 +14,6 @@ from jax import dtypes
from jax.interpreters import mlir
from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding
from jax import ffi
import transformer_engine_jax
......@@ -30,6 +30,11 @@ from .misc import (
from .quantization import _jax_cast_fp8
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
if version.parse(jax.__version__) >= version.parse("0.5.0"):
from jax import ffi # pylint: disable=ungrouped-imports
else:
from jax.extend import ffi # pylint: disable=ungrouped-imports
__all__ = [
"layernorm_fwd",
......
......@@ -3,13 +3,13 @@
# See LICENSE for license information.
"""JAX/TE custom ops for quantization"""
from typing import Tuple
from packaging import version
import jax
import jax.numpy as jnp
from jax import dtypes
from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding
from jax import ffi
import transformer_engine_jax
from transformer_engine_jax import DType as TEDType
......@@ -25,6 +25,11 @@ from .misc import (
)
from ..sharding import all_reduce_max_along_all_axes_except_PP
if version.parse(jax.__version__) >= version.parse("0.5.0"):
from jax import ffi # pylint: disable=ungrouped-imports
else:
from jax.extend import ffi # pylint: disable=ungrouped-imports
__all__ = ["cast_fp8"]
......
......@@ -6,13 +6,13 @@ from abc import abstractmethod
from functools import partial, reduce
import operator
import warnings
from packaging import version
import jax
import jax.numpy as jnp
from jax import dtypes
from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding
from jax import ffi
import transformer_engine_jax
......@@ -21,6 +21,11 @@ from .custom_call import custom_caller, CustomCallArgsWrapper
from .misc import get_padded_spec, check_valid_batch_dims, jax_dtype_to_te_dtype, is_ffi_enabled
from ..softmax import SoftmaxType
if version.parse(jax.__version__) >= version.parse("0.5.0"):
from jax import ffi # pylint: disable=ungrouped-imports
else:
from jax.extend import ffi # pylint: disable=ungrouped-imports
__all__ = [
"scaled_softmax_fwd",
......
......@@ -2,16 +2,16 @@
#
# See LICENSE for license information.
"""JAX/TE custom ops for transpose"""
import operator
from functools import partial, reduce
from typing import Tuple, Sequence, Union, Callable
import operator
from packaging import version
import jax
import jax.numpy as jnp
from jax import dtypes
from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding
from jax import ffi
import transformer_engine_jax
from transformer_engine_jax import DType as TEDType
......@@ -33,6 +33,11 @@ from .activation import _jax_act_lu
from .quantization import _jax_cast_fp8
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
if version.parse(jax.__version__) >= version.parse("0.5.0"):
from jax import ffi # pylint: disable=ungrouped-imports
else:
from jax.extend import ffi # pylint: disable=ungrouped-imports
__all__ = [
"transpose",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment