Unverified Commit 198f22b3 authored by Yuqi Dong's avatar Yuqi Dong Committed by GitHub
Browse files

[Refactor]:Move device_assert from extern_call to intrin_call (#1134)

* update

* Update codegen_cuda.cc
parent e1b12bd0
......@@ -301,5 +301,15 @@ TIR_DEFINE_TL_BUILTIN(atomicadd_elem_op)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(device_assert)
.set_num_inputs(1)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(device_assert_with_msg)
.set_num_inputs(2)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
} // namespace tl
} // namespace tvm
......@@ -505,6 +505,7 @@ TVM_DLL const Op &initialize_descriptor();
*/
TVM_DLL const Op &increase_descriptor_offset();
/*!
* \brief tilelang intrinsic for element-wise atomic addition.
*
......@@ -513,6 +514,20 @@ TVM_DLL const Op &increase_descriptor_offset();
*/
TVM_DLL const Op &atomicadd_elem_op();
/*!
* \brief tilelang intrinsic for assert on device.
*
* This op is used to represent an assert on device
*/
TVM_DLL const Op &device_assert();
/*!
* \brief tilelang intrinsic for assert on device with additional message.
*
* This op is used to represent an assert on device with additional message.
*/
TVM_DLL const Op &device_assert_with_msg();
} // namespace tl
} // namespace tvm
......
......@@ -2345,6 +2345,16 @@ void CodeGenTileLangCUDA::VisitStmt_(const EvaluateNode *op) {
stream << " " << vid_global_barrier_expect_ << " = 0;\n";
PrintIndent();
stream << "}\n";
}
if (call && (call->op.same_as(tvm::tl::device_assert()))) {
std::string cond = PrintExpr(call->args[0]);
this->PrintIndent();
stream << "device_assert(" << cond << ");\n";
} else if (call && call->op.same_as(tvm::tl::device_assert_with_msg())) {
std::string cond = PrintExpr(call->args[0]);
std::string msg_expr = PrintExpr(call->args[1]);
this->PrintIndent();
stream << "device_assert_with_msg(" << cond << ", " << msg_expr << ");\n";
} else {
CodeGenC::VisitStmt_(op);
}
......
......@@ -5,6 +5,7 @@ It includes functionality to print variables, print values in buffers, condition
from tvm import tir
from typing import Any
import tilelang.language as T
from tilelang.language.kernel import get_thread_bindings
from tilelang.language import copy, macro, serial, alloc_shared
from tilelang.language.utils import index_to_coordinates
......@@ -148,10 +149,10 @@ def device_assert(condition: tir.PrimExpr, msg: str = ""):
"""
if _IS_CUDA_AVAILABLE:
if msg == "":
tir.call_extern("void", "device_assert", condition)
T.call_intrin("void", tir.op.Op.get("tl.device_assert"), condition)
else:
warnings.warn("Non-empty msg may slightly slow down the kernel", stacklevel=2)
tir.call_extern("void", "device_assert_with_msg", condition, msg)
T.call_intrin("void", tir.op.Op.get("tl.device_assert_with_msg"), condition, msg)
def print(obj: Any, msg: str = "", warp_group_id: int = 0, warp_id: int = 0) -> tir.PrimExpr:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment