codegen_cpp.h 4.66 KB
Newer Older
root's avatar
init  
root committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file codegen_cpp.h
 * \brief C++ source code generator used by TileLang.
 *
 * This header was adapted from TVM's codegen_c_host.h. It deliberately uses
 * its own include guard: reusing TVM's guard would cause whichever of the two
 * headers is included second in a translation unit to be silently skipped.
 */
#ifndef TVM_TL_TARGET_CODEGEN_CPP_H_
#define TVM_TL_TARGET_CODEGEN_CPP_H_

#include <cstdint>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "target/source/codegen_c.h"
#include "tvm/target/codegen.h"
#include "tvm/tir/expr.h"

namespace tvm {
namespace codegen {

/*!
 * \brief Source code generator derived from CodeGenC.
 *
 * Overrides type printing, the function prefix, broadcast/call/min/max
 * expressions, and assert/allocate statements of the base C generator.
 */
class CodeGenTileLangCPP : public CodeGenC {
public:
  CodeGenTileLangCPP();

  /*!
   * \brief Configure the generator before functions are added.
   * \param output_ssa Whether to print in SSA form (NOTE(review): presumably
   *        forwarded to CodeGenC::Init — confirm in the implementation).
   * \param emit_asserts Whether to emit asserts in the resulting code.
   * \param emit_fwd_func_decl Whether to emit forward function declarations.
   * \param target_str Target description string (TODO: confirm how it is used).
   * \param devices Set of device names (TODO: confirm how it is used).
   */
  void Init(bool output_ssa, bool emit_asserts, bool emit_fwd_func_decl,
            std::string target_str,
            const std::unordered_set<std::string> &devices);

  void InitGlobalContext();
  // Overridden as a workaround for code generation that does not target the
  // TVM runtime.
  void AddFunction(const PrimFunc &f);

  /*!
   * \brief Add functions to the current module in a deterministic order,
   * regardless of the order in which they are supplied. This helps with
   * debugging.
   *
   * \param functions The (GlobalVar, BaseFunc) pairs of the current module,
   *        in arbitrary order.
   */
  void AddFunctionsOrdered(
      std::vector<std::pair<tvm::GlobalVar, tvm::BaseFunc>> functions);
  void DefineModuleName();

  using CodeGenC::PrintType;
  void PrintType(DataType t, std::ostream &os) final; // NOLINT(*)
  void PrintFuncPrefix(std::ostream &os) final;       // NOLINT(*)

  // overload visitor functions
  void VisitExpr_(const BroadcastNode *op, std::ostream &os) final; // NOLINT(*)
  void VisitExpr_(const CallNode *op, std::ostream &os) override;   // NOLINT(*)
  // overload min and max to use the ternary operator, so we don't rely on the
  // standard library implementations
  void VisitExpr_(const MinNode *op, std::ostream &os) final; // NOLINT(*)
  void VisitExpr_(const MaxNode *op, std::ostream &os) final; // NOLINT(*)

  void VisitStmt_(const AssertStmtNode *op) final; // NOLINT(*)
  void VisitStmt_(const AllocateNode *op) final;   // NOLINT(*)

  void GenerateForwardFunctionDeclarations(String global_symbol,
                                           const Array<Type> &arg_types,
                                           const Type &ret_type) override;

  /*! \brief Names of the functions declared in this module. */
  Array<String> GetFunctionNames() const { return function_names_; }

private:
  /*! \brief Internal structure to store information about function calls */
  struct FunctionInfo {
    /*! \brief function name */
    std::string func_name;
    /*! \brief number of arguments required by the function */
    int64_t num_args{0};
    /*! \brief name of resource_handle to pass */
    std::string resource_handle_name;
  };
  std::string module_name_;
  /*! \brief mapping global packed func to the unique name */
  std::unordered_map<std::string, std::string> declared_globals_;
  /*! \brief names of the functions declared in this module */
  Array<String> function_names_;
  /*! \brief whether to emit asserts in the resulting C code */
  bool emit_asserts_{false};
  /*! \brief whether to emit forward function declarations in the resulting C
   * code */
  bool emit_fwd_func_decl_{false};

  FunctionInfo GetFunctionInfo(const CallNode *op, bool has_resource_handle);
  std::string GetPackedName(const CallNode *op);
  void PrintGetFuncFromBackend(const std::string &func_name,
                               const std::string &packed_func_name);
  void PrintFuncCall(const std::string &packed_func_name, int num_args);
  void PrintFuncCallC(const std::string &packed_func_name, int num_args,
                      const std::string &resource_handle_name);

  /*!
   * \brief Print ternary conditional operator implementing binary `op`
   * Forces the operands to be in SSA form.
   * \param op binary operator being expressed
   * \param compare string representation of comparison operator
   * \param os stream reference to print into
   */
  template <typename T>
  inline void PrintTernaryCondExpr(const T *op, const char *compare,
                                   std::ostream &os); // NOLINT(*)
};

} // namespace codegen
} // namespace tvm

#endif // TVM_TL_TARGET_CODEGEN_CPP_H_