Unverified Commit 198f22b3 authored by Yuqi Dong's avatar Yuqi Dong Committed by GitHub
Browse files

[Refactor]:Move device_assert from extern_call to intrin_call (#1134)

* update

* Update codegen_cuda.cc
parent e1b12bd0
...@@ -301,5 +301,15 @@ TIR_DEFINE_TL_BUILTIN(atomicadd_elem_op) ...@@ -301,5 +301,15 @@ TIR_DEFINE_TL_BUILTIN(atomicadd_elem_op)
.set_attr<TCallEffectKind>("TCallEffectKind", .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque)); Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(device_assert)
.set_num_inputs(1)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(device_assert_with_msg)
.set_num_inputs(2)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -505,6 +505,7 @@ TVM_DLL const Op &initialize_descriptor(); ...@@ -505,6 +505,7 @@ TVM_DLL const Op &initialize_descriptor();
*/ */
TVM_DLL const Op &increase_descriptor_offset(); TVM_DLL const Op &increase_descriptor_offset();
/*! /*!
* \brief tilelang intrinsic for element-wise atomic addition. * \brief tilelang intrinsic for element-wise atomic addition.
* *
...@@ -513,6 +514,20 @@ TVM_DLL const Op &increase_descriptor_offset(); ...@@ -513,6 +514,20 @@ TVM_DLL const Op &increase_descriptor_offset();
*/ */
TVM_DLL const Op &atomicadd_elem_op(); TVM_DLL const Op &atomicadd_elem_op();
/*!
* \brief tilelang intrinsic for assert on device.
*
* This op is used to represent an assert on device
*/
TVM_DLL const Op &device_assert();
/*!
* \brief tilelang intrinsic for assert on device with additional message.
*
* This op is used to represent an assert on device with additional message.
*/
TVM_DLL const Op &device_assert_with_msg();
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
......
...@@ -2345,6 +2345,16 @@ void CodeGenTileLangCUDA::VisitStmt_(const EvaluateNode *op) { ...@@ -2345,6 +2345,16 @@ void CodeGenTileLangCUDA::VisitStmt_(const EvaluateNode *op) {
stream << " " << vid_global_barrier_expect_ << " = 0;\n"; stream << " " << vid_global_barrier_expect_ << " = 0;\n";
PrintIndent(); PrintIndent();
stream << "}\n"; stream << "}\n";
}
if (call && (call->op.same_as(tvm::tl::device_assert()))) {
std::string cond = PrintExpr(call->args[0]);
this->PrintIndent();
stream << "device_assert(" << cond << ");\n";
} else if (call && call->op.same_as(tvm::tl::device_assert_with_msg())) {
std::string cond = PrintExpr(call->args[0]);
std::string msg_expr = PrintExpr(call->args[1]);
this->PrintIndent();
stream << "device_assert_with_msg(" << cond << ", " << msg_expr << ");\n";
} else { } else {
CodeGenC::VisitStmt_(op); CodeGenC::VisitStmt_(op);
} }
......
...@@ -5,6 +5,7 @@ It includes functionality to print variables, print values in buffers, condition ...@@ -5,6 +5,7 @@ It includes functionality to print variables, print values in buffers, condition
from tvm import tir from tvm import tir
from typing import Any from typing import Any
import tilelang.language as T
from tilelang.language.kernel import get_thread_bindings from tilelang.language.kernel import get_thread_bindings
from tilelang.language import copy, macro, serial, alloc_shared from tilelang.language import copy, macro, serial, alloc_shared
from tilelang.language.utils import index_to_coordinates from tilelang.language.utils import index_to_coordinates
...@@ -148,10 +149,10 @@ def device_assert(condition: tir.PrimExpr, msg: str = ""): ...@@ -148,10 +149,10 @@ def device_assert(condition: tir.PrimExpr, msg: str = ""):
""" """
if _IS_CUDA_AVAILABLE: if _IS_CUDA_AVAILABLE:
if msg == "": if msg == "":
tir.call_extern("void", "device_assert", condition) T.call_intrin("void", tir.op.Op.get("tl.device_assert"), condition)
else: else:
warnings.warn("Non-empty msg may slightly slow down the kernel", stacklevel=2) warnings.warn("Non-empty msg may slightly slow down the kernel", stacklevel=2)
tir.call_extern("void", "device_assert_with_msg", condition, msg) T.call_intrin("void", tir.op.Op.get("tl.device_assert_with_msg"), condition, msg)
def print(obj: Any, msg: str = "", warp_group_id: int = 0, warp_id: int = 0) -> tir.PrimExpr: def print(obj: Any, msg: str = "", warp_group_id: int = 0, warp_id: int = 0) -> tir.PrimExpr:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment