Unverified Commit e0be70d6 authored by Alp Dener, committed by GitHub

[JAX] Fixing custom op test failures due to changes in JAX lowering internals (#566)



Applied the Google-advised fix to register custom op primitives with the device dispatch list.
Signed-off-by: Alp Dener <adener@nvidia.com>
parent e10997bf
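The diff below does two things: it moves the `transformer_engine_jax` and `.sharding` imports ahead of the JAX imports, and it adds each custom op primitive to `jax._src.dispatch.prim_requires_devices_during_lowering`, which tells JAX that these primitives need the concrete devices to be available while they are lowered. As a reference, here is a minimal sketch of that registration pattern, using a hypothetical `te_example` primitive rather than any of the library's real ops:

```python
# A minimal sketch, assuming a hypothetical primitive named "te_example";
# this is not Transformer Engine code, it only illustrates the fix applied
# in the diff: after defining a custom op primitive, also add it to JAX's
# device dispatch list so devices are available when the primitive is lowered.
from jax import core
from jax._src import dispatch


def _example_abstract(x):
    # Trivial abstract eval: the output mirrors the input's shape and dtype.
    return core.ShapedArray(x.shape, x.dtype)


example_p = core.Primitive("te_example")
example_p.multiple_results = False
example_p.def_abstract_eval(_example_abstract)

# The Google-advised fix: register the primitive with the device dispatch
# list (jax._src.dispatch.prim_requires_devices_during_lowering).
dispatch.prim_requires_devices_during_lowering.add(example_p)
```

In the diff this registration is done twice inside `register_primitive`: once for the inner primitive and once for the outer wrapper primitive.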
@@ -2,7 +2,6 @@
#
# See LICENSE for license information.
"""JAX te custom call"""
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from typing import Tuple
@@ -11,6 +10,13 @@ import operator
import os
import warnings
import transformer_engine_jax
from transformer_engine_jax import DType as TEDType
from transformer_engine_jax import NVTE_Bias_Type
from transformer_engine_jax import NVTE_Mask_Type
from transformer_engine_jax import NVTE_QKV_Layout
from transformer_engine_jax import NVTE_Fused_Attn_Backend
import numpy as np
import jax.numpy as jnp
from jax.lib import xla_client
@@ -20,6 +26,12 @@ from jax.experimental.custom_partitioning import custom_partitioning
from jax.interpreters.mlir import ir, dtype_to_ir_type
from jax.sharding import PartitionSpec, NamedSharding
from jax._src.interpreters import batching
from jax._src import dispatch
from .sharding import all_reduce_max_along_all_axes_except_PP
from .sharding import all_reduce_sum_along_dp_fsdp
from .sharding import get_all_mesh_axes, num_of_devices
from .sharding import get_padded_spec as te_get_padded_spec
try:
from jaxlib.hlo_helpers import custom_call
@@ -28,18 +40,6 @@ except ImportError:
# version, so we still need this import.
pass
import transformer_engine_jax
from transformer_engine_jax import DType as TEDType
from transformer_engine_jax import NVTE_Bias_Type
from transformer_engine_jax import NVTE_Mask_Type
from transformer_engine_jax import NVTE_QKV_Layout
from transformer_engine_jax import NVTE_Fused_Attn_Backend
from .sharding import all_reduce_max_along_all_axes_except_PP
from .sharding import all_reduce_sum_along_dp_fsdp
from .sharding import get_all_mesh_axes, num_of_devices
from .sharding import get_padded_spec as te_get_padded_spec
for _name, _value in transformer_engine_jax.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="CUDA")
@@ -185,6 +185,7 @@ def register_primitive(cls):
return cls.name + "_wrapper"
inner_p = core.Primitive(cls.name)
dispatch.prim_requires_devices_during_lowering.add(inner_p)
inner_p.multiple_results = cls.multiple_results
inner_p.def_impl(partial(xla.apply_primitive, inner_p))
inner_p.def_abstract_eval(cls.abstract)
@@ -192,6 +193,7 @@ def register_primitive(cls):
cls.inner_primitive = inner_p
outer_p = core.Primitive(name_of_wrapper_p())
dispatch.prim_requires_devices_during_lowering.add(outer_p)
outer_p.multiple_results = cls.multiple_results
outer_p.def_impl(cls.impl)
outer_p.def_abstract_eval(cls.abstract)