Commit 1f5eb492 authored by Lei Wang's avatar Lei Wang Committed by LeiWang1999
Browse files

[Feature] Implement fast integer power operation and related API (#466)

* [Refactor] Enhance TMA barrier validation and support for additional architectures (#463)

* Updated the TMA barrier validation in `inject_tma_barrier.cc` to check for non-empty `barrier_id_to_range_` before raising an error for missing `create_list_of_mbarrier`.
* Refactored architecture checks in `phase.py` to utilize a new constant `SUPPORTED_TMA_ARCHS`, allowing for easier updates and improved readability in the target architecture validation logic.

* [Feature] Implement fast integer power operation and related API

* Added a new math operation `tl.power_of_int` in `math.cc` for efficient integer exponentiation.
* Introduced a corresponding Python API `pow_of_int` in `tir/op.py` to facilitate usage in TileLang.
* Enhanced `common.h` with a template function for integer power calculations.
* Updated documentation to reflect the new functionality and usage examples.
parent 2ffbd369
/*!
* \file tl/op/math.cc
* \brief Math operations.
*
*/
#include <tvm/runtime/registry.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
namespace tvm {
namespace tl {
using namespace tir;
PrimExpr power_of_int_op(PrimExpr args) {
const CallNode *call = args.as<CallNode>();
CHECK(call != nullptr);
const Array<PrimExpr> &arg = call->args;
ICHECK_EQ(arg.size(), 2);
PrimExpr base = arg[0];
PrimExpr exp = arg[1];
String power_of_int_name =
"tl::power_of_int<" + std::to_string(exp.as<IntImmNode>()->value) + ">";
return tir::Call(base.dtype(), tir::builtin::call_extern(),
{StringImm(power_of_int_name), base});
}
TVM_REGISTER_OP("tl.power_of_int")
.set_num_inputs(2)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kPure))
.set_attr<TScriptPrinterName>("TScriptPrinterName", "power_of_int")
.set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", power_of_int_op);
} // namespace tl
} // namespace tvm
......@@ -208,4 +208,14 @@ template <typename T> TL_DEVICE bool All(T *a, int size) {
}
return true;
}
// Pow of int
template <int y = 1, typename T> TL_DEVICE T pow_of_int(T x) {
T result = x;
for (int i = 1; i < y; i++) {
result *= x;
}
return result;
}
} // namespace tl
......@@ -2600,6 +2600,21 @@ def isinf(x, span=None):
return _tvm_op.isinf(x, span)
def pow_of_int(x: PrimExpr, y: int) -> PrimExpr:
"""Fast power operation than pow(float, float).
Args:
x (PrimExpr): Base value
y (int): Exponent value
"""
return call_intrin(
x.dtype,
tvm.tir.op.Op.get("tl.power_of_int"),
x,
y,
)
def power(x, y, span=None):
"""x power y
......@@ -2619,6 +2634,8 @@ def power(x, y, span=None):
z : PrimExpr
The result.
"""
if isinstance(y, (int, IntImm)):
return pow_of_int(x, y)
return _tvm_op.power(x, y, span)
......@@ -2641,6 +2658,8 @@ def pow(x, y, span=None):
z : PrimExpr
The result.
"""
if isinstance(y, (int, IntImm)):
return pow_of_int(x, y)
return _tvm_op.pow(x, y, span)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment