Unverified Commit 57b4d7bc authored by Phuong Nguyen's avatar Phuong Nguyen Committed by GitHub
Browse files

[JAX] Remove import jax.extend.ffi (#2193)



* remove import jax.extend.ffi
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
parent 5b3092a0
......@@ -5,11 +5,10 @@
from typing import Sequence, Union, Callable, Optional, Tuple
import operator
from functools import reduce, partial
from packaging import version
import jax
import jax.numpy as jnp
from jax import dtypes
from jax import dtypes, ffi
from jax.experimental.custom_partitioning import SdyShardingRule
from jax.sharding import PartitionSpec
......@@ -37,10 +36,6 @@ from ..quantize import (
ScalingMode,
)
if version.parse(jax.__version__) >= version.parse("0.5.0"):
from jax import ffi # pylint: disable=ungrouped-imports
else:
from jax.extend import ffi # pylint: disable=ungrouped-imports
__all__ = ["act_lu", "dact_lu", "quantize_dact_dbias"]
......
......@@ -8,11 +8,10 @@ import warnings
from dataclasses import dataclass, replace
from functools import partial, reduce
from typing import Optional, Tuple
from packaging import version
import jax
import jax.numpy as jnp
from jax import dtypes, lax
from jax import dtypes, lax, ffi
from jax.sharding import PartitionSpec, NamedSharding
from jax.experimental.custom_partitioning import SdyShardingRule
......@@ -49,12 +48,6 @@ from ..sharding import (
)
if version.parse(jax.__version__) >= version.parse("0.5.0"):
from jax import ffi # pylint: disable=ungrouped-imports
else:
from jax.extend import ffi # pylint: disable=ungrouped-imports
__all__ = [
"FusedAttnHelper",
"fused_attn_fwd",
......
......@@ -7,22 +7,16 @@ import re
import warnings
from abc import ABCMeta, abstractmethod
from functools import partial
from packaging import version
from jax.extend import core
from jax.interpreters import xla, mlir
from jax.experimental.custom_partitioning import custom_partitioning
from jax._src.interpreters import batching
from jax._src import dispatch
from jax import ffi
import jax
import transformer_engine_jax
if version.parse(jax.__version__) >= version.parse("0.5.0"):
from jax import ffi # pylint: disable=ungrouped-imports
else:
from jax.extend import ffi # pylint: disable=ungrouped-imports
class BasePrimitive(metaclass=ABCMeta):
"""
......
......@@ -7,11 +7,10 @@ import warnings
import operator
from functools import partial, cache, reduce
from typing import Optional, Union
from packaging import version
import jax
import jax.numpy as jnp
from jax import dtypes
from jax import dtypes, ffi
from jax.experimental.custom_partitioning import SdyShardingRule
from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec
......@@ -38,11 +37,6 @@ from ..quantize import (
ScalingMode,
)
if version.parse(jax.__version__) >= version.parse("0.5.0"):
from jax import ffi # pylint: disable=ungrouped-imports
else:
from jax.extend import ffi # pylint: disable=ungrouped-imports
__all__ = [
"layernorm_fwd",
......
......@@ -6,11 +6,10 @@ import operator
from functools import reduce
from typing import Tuple, Optional, Union
import math
from packaging import version
import jax
import jax.numpy as jnp
from jax import dtypes
from jax import dtypes, ffi
from jax.experimental.custom_partitioning import SdyShardingRule
from jax.sharding import PartitionSpec
......@@ -41,11 +40,6 @@ from ..quantize import (
NoScaleTensor,
)
if version.parse(jax.__version__) >= version.parse("0.5.0"):
from jax import ffi # pylint: disable=ungrouped-imports
else:
from jax.extend import ffi # pylint: disable=ungrouped-imports
__all__ = ["quantize", "quantize_dbias", "grouped_quantize", "grouped_dbias"]
......
......@@ -6,22 +6,16 @@ from abc import abstractmethod
from functools import partial, reduce
import operator
import warnings
from packaging import version
import jax
import jax.numpy as jnp
from jax import dtypes
from jax import dtypes, ffi
from jax.sharding import PartitionSpec, NamedSharding
from .base import BasePrimitive, register_primitive
from .misc import get_padded_spec, check_valid_batch_dims
from ..softmax import SoftmaxType
if version.parse(jax.__version__) >= version.parse("0.5.0"):
from jax import ffi # pylint: disable=ungrouped-imports
else:
from jax.extend import ffi # pylint: disable=ungrouped-imports
__all__ = [
"scaled_softmax_fwd",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment