/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /*! * \file storage_access.h * \brief Common data structure for storage access analysis. */ #ifndef TVM_TIR_TRANSFORMS_STORAGE_ACCESS_H_ #define TVM_TIR_TRANSFORMS_STORAGE_ACCESS_H_ #include #include #include #include #include #include #include "arith/ir_visitor_with_analyzer.h" #include "runtime/thread_storage_scope.h" namespace tvm { namespace tl { using namespace tir; using arith::IRVisitorWithAnalyzer; using runtime::StorageRank; using runtime::StorageScope; /*! * \brief Base class of storage access analysis */ class TileLangStorageAccessVisitor : public IRVisitorWithAnalyzer { public: /*! \brief Storage access type */ enum AccessType : uint8_t { kRead, kWrite, kSync, kAlloc, // acquired version of read, only need to handle WAR dep. kReadAcquire }; /*! \brief An access entry */ struct AccessEntry { /*! \brief The thread index that access this entry */ Array threads; /*! \brief The touched thread range */ Map thread_range; /*! \brief The buffer variable, if any */ Array buffer_indices; /*! \brief The buffer ranges for pointer access */ Array buffer_ranges; Var buffer = NullValue(); /*! \brief The access data type */ DataType dtype; /*! \brief The touched access range * * Has one IntSet for each index in the buffer being accessed. */ Array touched; /*! \brief The type of access */ AccessType type; /*! \brief The storage scope */ StorageScope scope; /*! \brief Whether the access is double buffer write */ bool double_buffer_write = false; /*! \brief Whether the access is pointer access */ bool is_pointer_access = false; }; /*! \brief Access pattern about a single statement */ struct StmtEntry { /*! \brief The statement */ const Object *stmt{}; /*! \brief access patterns in the statement */ std::vector access; }; // override visitor pattern void VisitExpr_(const BufferLoadNode *op) final; void VisitStmt_(const BufferStoreNode *op) final; void VisitStmt_(const EvaluateNode *op) final; void VisitStmt_(const LetStmtNode *op) final; void VisitStmt_(const AttrStmtNode *op) override; void VisitStmt_(const ForNode *op) final; void VisitStmt_(const IfThenElseNode *op) final; void VisitStmt_(const WhileNode *op) final; void VisitExpr_(const CallNode *op) final; void VisitStmt_(const BlockNode *op) final; void SetBufferDataToBuffer(const Var &buffer_var, const Buffer &buffer) { buffer_data_to_buffer_.Set(buffer_var, buffer); } protected: TileLangStorageAccessVisitor() { scope_.push_back(std::vector()); } /*! \return number of conditions in the current scope. */ int condition_counter() const { return condition_counter_; } /*! \return whether we are in device environment. */ bool in_device_env() const { return in_device_env_; } /*! \return environment threads */ const Array &env_threads() const { return env_threads_; } /*! * \brief Whether we need analyze the buffer in current scope. * \param buffer The buffer to be checked * \param scope The scope of the buffer. * \return Whether the analysis of buffer is enabled. */ virtual bool Enabled(const VarNode *buffer, const StorageScope &scope) const { return true; } /*! * \brief Summarize the sequence of operations into parent. * * Insert synchronization if necessary and remove un-necessary * memory access which are already synced. * * \param seq The sequence of the access operations. * \param loop Pass loop node if it is a loop, otherwise nullptr. * \return The summarized sequence that represent access that * the parent should taken care of to synchronize. */ virtual std::vector Summarize(std::vector seq, const ForNode *loop) = 0; /*! * \brief Compute the thread range for the given threads. * \param threads The threads to compute the range for. * \return The thread range. */ Map ComputeThreadRange(const Array &threads); /*! * \brief Get the scope of the buffer array. * \return The scope of the final buffer array. */ StorageScope GetScope(const Var &buffer_var) const; // access scope std::vector> scope_; private: // whether access appending is enabled. bool allow_append_{false}; // Whether we are in device environment bool in_device_env_{false}; // Whether we are inside condition. int condition_counter_{0}; // The current double buffer write scope. const VarNode *double_buffer_write_{nullptr}; // the current free stmt entry. StmtEntry curr_stmt_; // The involving threads Array env_threads_; // The buffer map Map buffer_data_to_buffer_; }; } // namespace tl } // namespace tvm #endif // TVM_TL_TRANSFORMS_STORAGE_ACCESS_H_