Unverified commit e0be70d6 authored by Alp Dener, committed by GitHub
Browse files

[JAX] Fixing custom op test failures due to changes in JAX lowering internals (#566)



applied Google-advised fix to register custom op primitives with the device dispatch list
Signed-off-by: Alp Dener <adener@nvidia.com>
parent e10997bf
...@@ -2,7 +2,6 @@
#
# See LICENSE for license information.
"""JAX te custom call"""
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from typing import Tuple
...@@ -11,6 +10,13 @@ import operator
import os
import warnings
import transformer_engine_jax
from transformer_engine_jax import DType as TEDType
from transformer_engine_jax import NVTE_Bias_Type
from transformer_engine_jax import NVTE_Mask_Type
from transformer_engine_jax import NVTE_QKV_Layout
from transformer_engine_jax import NVTE_Fused_Attn_Backend
import numpy as np
import jax.numpy as jnp
from jax.lib import xla_client
...@@ -20,6 +26,12 @@ from jax.experimental.custom_partitioning import custom_partitioning
from jax.interpreters.mlir import ir, dtype_to_ir_type
from jax.sharding import PartitionSpec, NamedSharding
from jax._src.interpreters import batching from jax._src.interpreters import batching
from jax._src import dispatch
from .sharding import all_reduce_max_along_all_axes_except_PP
from .sharding import all_reduce_sum_along_dp_fsdp
from .sharding import get_all_mesh_axes, num_of_devices
from .sharding import get_padded_spec as te_get_padded_spec
try:
    from jaxlib.hlo_helpers import custom_call
...@@ -28,18 +40,6 @@ except ImportError:
    # version, so we still need this import.
    pass
import transformer_engine_jax
from transformer_engine_jax import DType as TEDType
from transformer_engine_jax import NVTE_Bias_Type
from transformer_engine_jax import NVTE_Mask_Type
from transformer_engine_jax import NVTE_QKV_Layout
from transformer_engine_jax import NVTE_Fused_Attn_Backend
from .sharding import all_reduce_max_along_all_axes_except_PP
from .sharding import all_reduce_sum_along_dp_fsdp
from .sharding import get_all_mesh_axes, num_of_devices
from .sharding import get_padded_spec as te_get_padded_spec
for _name, _value in transformer_engine_jax.registrations().items():
    xla_client.register_custom_call_target(_name, _value, platform="CUDA")
...@@ -185,6 +185,7 @@ def register_primitive(cls):
        return cls.name + "_wrapper"
    inner_p = core.Primitive(cls.name)
dispatch.prim_requires_devices_during_lowering.add(inner_p)
    inner_p.multiple_results = cls.multiple_results
    inner_p.def_impl(partial(xla.apply_primitive, inner_p))
    inner_p.def_abstract_eval(cls.abstract)
...@@ -192,6 +193,7 @@ def register_primitive(cls):
    cls.inner_primitive = inner_p
    outer_p = core.Primitive(name_of_wrapper_p())
dispatch.prim_requires_devices_during_lowering.add(outer_p)
    outer_p.multiple_results = cls.multiple_results
    outer_p.def_impl(cls.impl)
    outer_p.def_abstract_eval(cls.abstract)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment