Unverified Commit 68af2159 authored by Zhengju Tang's avatar Zhengju Tang Committed by GitHub
Browse files

[BugFix] Refactor the op check in LowerTileOp pass using the member function...

[BugFix] Refactor the op check in LowerTileOp pass using the member function instead of string match (#771)

* [BugFix] Refactor the op check in LowerTileOp pass using the member function instead of string match

* [Lint]
parent 03f21987
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
#include "../layout/layout.h" #include "../layout/layout.h"
#include "../layout/utils.h" #include "../layout/utils.h"
#include "../op/builtin.h" #include "../op/builtin.h"
#include "../op/gemm.h"
#include "../op/gemm_sp.h"
#include "../op/operator.h" #include "../op/operator.h"
#include "arith/ir_mutator_with_analyzer.h" #include "arith/ir_mutator_with_analyzer.h"
...@@ -84,7 +86,7 @@ public: ...@@ -84,7 +86,7 @@ public:
private: private:
void VisitStmt_(const EvaluateNode *op) { void VisitStmt_(const EvaluateNode *op) {
auto call = Downcast<Call>(op->value); auto call = Downcast<Call>(op->value);
if (call->op.same_as(Op::Get("tl.gemm"))) { if (call->op.same_as(Gemm::Get())) {
auto srcA_buffer_access_ptr = Downcast<Call>(call->args[0]); auto srcA_buffer_access_ptr = Downcast<Call>(call->args[0]);
ICHECK(srcA_buffer_access_ptr->op.same_as(builtin::tvm_access_ptr())); ICHECK(srcA_buffer_access_ptr->op.same_as(builtin::tvm_access_ptr()));
auto srcA_buffer_var = Downcast<Var>(srcA_buffer_access_ptr->args[1]); auto srcA_buffer_var = Downcast<Var>(srcA_buffer_access_ptr->args[1]);
...@@ -97,7 +99,7 @@ private: ...@@ -97,7 +99,7 @@ private:
buffer_var_gemm_.push_back(srcA_buffer_var); buffer_var_gemm_.push_back(srcA_buffer_var);
buffer_var_gemm_.push_back(srcB_buffer_var); buffer_var_gemm_.push_back(srcB_buffer_var);
buffer_var_gemm_.push_back(dst_buffer_var); buffer_var_gemm_.push_back(dst_buffer_var);
} else if (call->op.same_as(Op::Get("tl.gemm_sp"))) { } else if (call->op.same_as(GemmSP::Get())) {
auto srcA_buffer_access_ptr = Downcast<Call>(call->args[0]); auto srcA_buffer_access_ptr = Downcast<Call>(call->args[0]);
ICHECK(srcA_buffer_access_ptr->op.same_as(builtin::tvm_access_ptr())); ICHECK(srcA_buffer_access_ptr->op.same_as(builtin::tvm_access_ptr()));
auto srcA_buffer_var = Downcast<Var>(srcA_buffer_access_ptr->args[1]); auto srcA_buffer_var = Downcast<Var>(srcA_buffer_access_ptr->args[1]);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment