Unverified commit 11456de2, authored by Tong WU, committed by GitHub
Browse files

[Feature] Add `tl.infinity` operator for infinity handling of bfloat16 (#1175)



* Update dependency version for apache-tvm-ffi in pyproject.toml to fix CI

* [Math] Add `tl.infinity` operation and update Python interface for infinity handling

- Implemented `infinity_op` in C++ to return infinity values for supported data types.
- Registered new operation `tl.infinity` with appropriate attributes.
- Updated Python interface to call the new `tl.infinity` operation instead of the previous method.

* Add unit tests for `tl.infinity` operation in TileLang

- Introduced a new test file `test_tilelang_language_infinity.py` to validate the behavior of the `tl.infinity` operation across multiple data types (float16, bfloat16, float32, float64).
- Implemented a kernel to fill a tensor with infinity values and asserted the correctness of the output against PyTorch's `torch.inf`.

* lint

---------
Co-authored-by: Zhiwen Mo <zm125@ic.ac.uk>
parent c67d66a3
...@@ -27,7 +27,7 @@ classifiers = [ ...@@ -27,7 +27,7 @@ classifiers = [
] ]
dynamic = ["version"] dynamic = ["version"]
dependencies = [ dependencies = [
"apache-tvm-ffi~=0.1.0", "apache-tvm-ffi==0.1.0",
"cloudpickle", "cloudpickle",
"ml-dtypes", "ml-dtypes",
"numpy>=1.23.5", "numpy>=1.23.5",
......
...@@ -35,5 +35,31 @@ TVM_REGISTER_OP("tl.pow_of_int") ...@@ -35,5 +35,31 @@ TVM_REGISTER_OP("tl.pow_of_int")
.set_attr<TScriptPrinterName>("TScriptPrinterName", "pow_of_int") .set_attr<TScriptPrinterName>("TScriptPrinterName", "pow_of_int")
.set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", pow_of_int_op); .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", pow_of_int_op);
// Lowers a `tl.infinity` call to a constant holding positive infinity.
//
// `args` must be a call node; the call's result dtype picks the type of the
// constant that is produced, and the call's span is attached to it. Only
// single-lane dtypes are accepted. Supported dtypes are 16-, 32- and 64-bit
// floats and bfloat16; any other dtype is reported through LOG(FATAL).
PrimExpr infinity_op(PrimExpr args) {
  const CallNode *node = args.as<CallNode>();
  CHECK(node != nullptr);
  const DataType &dtype = node->dtype;
  ICHECK_EQ(dtype.lanes(), 1);

  // Decide up front whether this dtype is one we can build an infinity for.
  bool has_infinity;
  if (dtype.is_float()) {
    const auto width = dtype.bits();
    has_infinity = width == 64 || width == 32 || width == 16;
  } else {
    has_infinity = dtype.is_bfloat16();
  }

  if (!has_infinity) {
    LOG(FATAL) << "Cannot decide infinity for type " << dtype;
    throw; // Unreachable, keeps compiler happy
  }

  // NOTE(wt): Codegen for PrintConst:Inf will handle this based on dtype, so
  // the same float infinity is used as the value for every supported dtype.
  return FloatImm(dtype, std::numeric_limits<float>::infinity(), node->span);
}
// Registers the `tl.infinity` op: it takes one input, is marked as a pure
// call (CallEffectKind::kPure), is printed under the script name `infinity`,
// and uses infinity_op as its "cuda.FLowerIntrinsic" lowering function.
TVM_REGISTER_OP("tl.infinity")
.set_num_inputs(1)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kPure))
.set_attr<TScriptPrinterName>("TScriptPrinterName", "infinity")
.set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", infinity_op);
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
import torch
import tilelang
import tilelang.language as T
@tilelang.jit(out_idx=-1)
def get_inf_kernel(dtype: str):
    """Build a kernel that fills a 32-element tensor of ``dtype`` with infinity.

    Args:
        dtype: Element type name, used both for the tensor and for the
            infinity constant.

    Returns:
        The ``T.prim_func`` defined below. NOTE(review): the caller invokes the
        jitted result with no arguments, so ``out_idx=-1`` presumably makes
        the tensor an output that is allocated and returned -- confirm.
    """
    num_elems = 32

    @T.prim_func
    def main(A: T.Tensor((num_elems,), dtype)):
        # A single kernel instance with 32 threads writes infinity into A.
        with T.Kernel(1, threads=32):
            T.fill(A, T.infinity(dtype))

    return main
def _test_infinity(dtype: str):
    """Run the infinity-fill kernel for ``dtype`` and check every element.

    Args:
        dtype: Element type name passed through to ``get_inf_kernel``.

    Raises:
        AssertionError: If any element of the kernel output differs from
            ``torch.inf``.
    """
    result = get_inf_kernel(dtype)()
    every_element_is_inf = torch.all(result == torch.inf)
    assert every_element_is_inf, f'check failed for {dtype=}'
@tilelang.testing.requires_cuda
def test_infinity():
    """Check the infinity-fill kernel for each floating-point dtype under test."""
    # Same dtypes, in the same order, as the individual calls they replace.
    for dtype in ("float16", "bfloat16", "float32", "float64"):
        _test_infinity(dtype)
# When executed directly as a script, hand control to the tilelang test runner.
if __name__ == "__main__":
    tilelang.testing.main()
...@@ -2000,7 +2000,7 @@ def infinity(dtype: str, span: Span | None = None) -> Any: ...@@ -2000,7 +2000,7 @@ def infinity(dtype: str, span: Span | None = None) -> Any:
value : tvm.Expr value : tvm.Expr
The infinity value of dtype. The infinity value of dtype.
""" """
return _tvm_op.infinity(dtype, span) return call_intrin(dtype, _tvm_op.Op.get("tl.infinity"), dtype, span=span)
def reinterpret(dtype, value, span: Span | None = None) -> Any: def reinterpret(dtype, value, span: Span | None = None) -> Any:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment