/*!
 * \file tl/op/op.cc
 *
 * Define operators used in tile library.
 */

// NOTE(review): in this copy of the file, everything that would sit between
// angle brackets is missing -- the header names of the three #include
// directives below and the template arguments of std::unique_ptr, static_cast,
// .as(), GetRef, Array, GetAttrMap and set_attr. Those spots are reproduced
// exactly as found; restore them from the upstream file before building.

#include "op.h"

#include
#include
#include

namespace tvm {
namespace tl {

using namespace tir;

// Register the `region` operator. The input count is set to -1 (presumably
// "variadic" -- confirm against the op registry) and the call effect kind is
// kPure.
TIR_REGISTER_TL_OP(RegionOp, region)
    .set_num_inputs(-1)
    .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure));

/*!
 * \brief Build the tile-library operator that corresponds to a call.
 *
 * Looks the callee up in the "TLOpBuilder" attribute map. When a builder is
 * registered for it, the builder is invoked with the call arguments and
 * \p vmap, and the resulting pointer is handed to a std::unique_ptr.
 *
 * \param call The call expression to inspect.
 * \param vmap Buffer map forwarded to the builder.
 * \return The constructed operator, or nullptr when no builder is registered
 *         for the callee.
 */
std::unique_ptr ParseOperator(Call call, BufferMap vmap) {
  auto op_map = Op::GetAttrMap("TLOpBuilder");
  // NOTE(review): `.value()` is taken unconditionally; what happens for a
  // callee that is not an Op depends on the optional type returned by `as()`
  // -- confirm whether such calls can reach this function.
  Op op = call->op.as().value();
  if (op_map.count(op)) {
    Operator *ptr = static_cast(op_map[op](call->args, vmap));
    ICHECK(ptr != nullptr);
    // Ownership of the builder's result moves to the returned unique_ptr.
    return std::unique_ptr(ptr);
  }
  return nullptr;
}

/*!
 * \brief Statement overload of ParseOperator.
 *
 * Forwards to the Call overload when both `as()` probes below succeed
 * (presumably: the statement is an evaluate node whose value is a call --
 * the probed types are missing from this copy, confirm).
 *
 * \param stmt The statement to inspect.
 * \param vmap Buffer map forwarded to the builder.
 * \return The constructed operator, or nullptr when the statement does not
 *         match or no builder is registered.
 */
std::unique_ptr ParseOperator(Stmt stmt, BufferMap vmap) {
  if (stmt.as() && stmt.as()->value.as()) {
    auto call = stmt.as()->value.as();
    return ParseOperator(GetRef(call), vmap);
  }
  return nullptr;
}

/*!
 * \brief Extract the variable argument of a tvm_access_ptr call.
 *
 * \param expr Expression that must be a call to builtin::tvm_access_ptr().
 * \return The variable found at argument index 1 of that call.
 *
 * Fails an ICHECK when \p expr is not such a call or when argument 1 is not a
 * variable.
 */
Var GetVarFromAccessPtr(const PrimExpr &expr) {
  auto call = expr.as();
  ICHECK(call);
  ICHECK(call->op.same_as(builtin::tvm_access_ptr()));
  auto var = call->args[1].as();
  ICHECK(var);
  return GetRef(var);
}

/*!
 * \brief Construct a region description from the operator's call arguments.
 *
 * Argument layout as read by this constructor:
 *   - args[0]      : a load; its buffer becomes buffer_ and its indices become
 *                    the per-dimension range minimums.
 *   - args[1]      : a constant integer stored as access_mask_.
 *   - args[2 + i]  : extent of dimension i.
 *
 * \param args Call arguments laid out as above.
 * \param vmap Unused by this constructor.
 */
RegionOp::RegionOp(Array args, BufferMap vmap) {
  size_t n = args.size();
  // Guard the subtraction below: with fewer than two arguments `n - 2` would
  // wrap around to a huge unsigned value.
  ICHECK(n >= 2) << "region expects at least 2 arguments, got " << n;
  size_t ndim = n - 2;
  auto load = args[0].as();
  ICHECK(load);
  // Report the size, matching the label in the message.
  ICHECK(load->indices.size() == ndim)
      << "load->indices.size() = " << load->indices.size()
      << " ndim = " << ndim;
  buffer_ = load->buffer;
  // Check the result of as_const_int before dereferencing it.
  auto mask = as_const_int(args[1]);
  ICHECK(mask) << "region access mask must be a constant integer, got "
               << args[1];
  access_mask_ = static_cast(*mask);
  for (size_t i = 0; i < ndim; i++) {
    PrimExpr min = load->indices[i];
    PrimExpr extent = args[2 + i];
    ranges_.push_back(Range::FromMinExtent(min, extent));
  }
}

/*!
 * \brief Check whether the region spans the entire buffer.
 *
 * \return true when every range starts at zero and its extent is structurally
 *         equal to the buffer shape in the same dimension; false otherwise.
 */
bool RegionOp::IsFullRegion() const {
  for (size_t i = 0; i < ranges_.size(); i++) {
    if (!is_zero(ranges_[i]->min))
      return false;
    if (!StructuralEqual()(ranges_[i]->extent, buffer_->shape[i]))
      return false;
  }
  return true;
}

/*!
 * \brief Default lowering: always fails the ICHECK with
 *        "Not Implemented Lower method.".
 *
 * The trailing return only gives the function a value on every path.
 */
Stmt Operator::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
  ICHECK(0) << "Not Implemented Lower method.";
  return Evaluate(0);
}

/*!
 * \brief Default canonicalization: returns a default-constructed Stmt.
 */
Stmt Operator::Canonialize(const CanonializeArgs &T,
                           arith::Analyzer *analyzer) const {
  return {};
}

/*!
 * \brief Default layout inference: returns an empty LayoutMap.
 */
LayoutMap Operator::InferLayout(const LayoutInferArgs &T, InferLevel level) {
  return {};
}

} // namespace tl
} // namespace tvm