Unverified Commit a5dbf1e2 authored by Frédéric Bastien's avatar Frédéric Bastien Committed by GitHub
Browse files

[JAX] Use the new API when it is available. (#419)



Use the new API when it is available.
Signed-off-by: default avatarFrederic Bastien <fbastien@nvidia.com>
parent 805b9872
......@@ -190,14 +190,27 @@ def custom_caller(name, args, opaque, has_side_effect, **kwargs):
"""
XLA custom call warpper
"""
out = custom_call(name,
args.output_types,
args.operands,
operand_layouts=args.operand_layouts,
result_layouts=args.output_layouts,
backend_config=opaque,
has_side_effect=has_side_effect,
**kwargs)
if hasattr(mlir, "custom_call"):
out = mlir.custom_call(name,
result_types=args.output_types,
operands=args.operands,
operand_layouts=args.operand_layouts,
result_layouts=args.output_layouts,
backend_config=opaque,
has_side_effect=has_side_effect,
**kwargs).results
else:
# Need to disable one pylint error as the second function
# parameter name recenctly in JAX. Otherwise we won't be
# compatible with multiple JAX version.
out = custom_call(name, # pylint: disable=too-many-function-args
args.output_types,
operands=args.operands,
operand_layouts=args.operand_layouts,
result_layouts=args.output_layouts,
backend_config=opaque,
has_side_effect=has_side_effect,
**kwargs)
return out
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment