Commit efb2b1d5 authored by penguin_wwy's avatar penguin_wwy Committed by GitHub
Browse files

[Bugfix] Make quickstart work properly on cu118 (#193)

parent 79ea77e8
...@@ -17,6 +17,7 @@ namespace tl { ...@@ -17,6 +17,7 @@ namespace tl {
using namespace runtime; using namespace runtime;
#if (__CUDACC_VER_MAJOR__ >= 12)
template <typename T> static std::string ArrayToStr(const T *ptr, size_t n) { template <typename T> static std::string ArrayToStr(const T *ptr, size_t n) {
std::stringstream ss; std::stringstream ss;
ss << "["; ss << "[";
...@@ -202,6 +203,7 @@ TVM_REGISTER_GLOBAL(tvm_tensormap_create_im2col) ...@@ -202,6 +203,7 @@ TVM_REGISTER_GLOBAL(tvm_tensormap_create_im2col)
} }
*ret = static_cast<int>(result); *ret = static_cast<int>(result);
}); });
#endif // (__CUDACC_VER_MAJOR__ >= 12)
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -13,10 +13,12 @@ ...@@ -13,10 +13,12 @@
namespace tvm { namespace tvm {
namespace tl { namespace tl {
#if (__CUDACC_VER_MAJOR__ >= 12)
constexpr const char *tvm_tensormap_create_tiled = constexpr const char *tvm_tensormap_create_tiled =
"__tvm_tensormap_create_tiled"; "__tvm_tensormap_create_tiled";
constexpr const char *tvm_tensormap_create_im2col = constexpr const char *tvm_tensormap_create_im2col =
"__tvm_tensormap_create_im2col"; "__tvm_tensormap_create_im2col";
#endif // (__CUDACC_VER_MAJOR__ >= 12)
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
......
...@@ -36,6 +36,7 @@ namespace tl { ...@@ -36,6 +36,7 @@ namespace tl {
using namespace tir; using namespace tir;
#if (__CUDACC_VER_MAJOR__ >= 12)
class LowerHopperIntrin : public StmtExprMutator { class LowerHopperIntrin : public StmtExprMutator {
public: public:
static PrimFunc Substitute(PrimFunc &f) { static PrimFunc Substitute(PrimFunc &f) {
...@@ -168,6 +169,7 @@ tvm::transform::Pass LowerHopperIntrin() { ...@@ -168,6 +169,7 @@ tvm::transform::Pass LowerHopperIntrin() {
TVM_REGISTER_GLOBAL("tl.transform.LowerHopperIntrin") TVM_REGISTER_GLOBAL("tl.transform.LowerHopperIntrin")
.set_body_typed(LowerHopperIntrin); .set_body_typed(LowerHopperIntrin);
#endif // (__CUDACC_VER_MAJOR__ >= 12)
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -81,7 +81,8 @@ def LowerHopperIntrin(): ...@@ -81,7 +81,8 @@ def LowerHopperIntrin():
fpass : tvm.transform.Pass fpass : tvm.transform.Pass
The result pass The result pass
""" """
return _ffi_api.LowerHopperIntrin() # type: ignore return _ffi_api.LowerHopperIntrin() \
if hasattr(_ffi_api, "LowerHopperIntrin") else lambda f: f # type: ignore
def WarpSpecializedPipeline(): def WarpSpecializedPipeline():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment