"vscode:/vscode.git/clone" did not exist on "fffa2e1f4b7534d5f86e900838d9a24dfba307c9"
Unverified Commit a5dbf1e2 authored by Frédéric Bastien's avatar Frédéric Bastien Committed by GitHub
Browse files

[JAX] Use the new API when it is available. (#419)



Use the new API when it is available.
Signed-off-by: default avatarFrederic Bastien <fbastien@nvidia.com>
parent 805b9872
...@@ -190,14 +190,27 @@ def custom_caller(name, args, opaque, has_side_effect, **kwargs): ...@@ -190,14 +190,27 @@ def custom_caller(name, args, opaque, has_side_effect, **kwargs):
""" """
XLA custom call warpper XLA custom call warpper
""" """
out = custom_call(name, if hasattr(mlir, "custom_call"):
args.output_types, out = mlir.custom_call(name,
args.operands, result_types=args.output_types,
operand_layouts=args.operand_layouts, operands=args.operands,
result_layouts=args.output_layouts, operand_layouts=args.operand_layouts,
backend_config=opaque, result_layouts=args.output_layouts,
has_side_effect=has_side_effect, backend_config=opaque,
**kwargs) has_side_effect=has_side_effect,
**kwargs).results
else:
# Need to disable one pylint error as the second function
# parameter name recenctly in JAX. Otherwise we won't be
# compatible with multiple JAX version.
out = custom_call(name, # pylint: disable=too-many-function-args
args.output_types,
operands=args.operands,
operand_layouts=args.operand_layouts,
result_layouts=args.output_layouts,
backend_config=opaque,
has_side_effect=has_side_effect,
**kwargs)
return out return out
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment