"sgl-router/git@developer.sourcefind.cn:change/sglang.git" did not exist on "4a4772ae03c8b29834efbfa1175ba6abeafa77c9"
codegen_c_host.h 4.52 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file codegen_c_host.h
 * \brief Generate C host code (TileLang copy).
 */
#ifndef TL_TARGET_SOURCE_CODEGEN_C_HOST_H_
#define TL_TARGET_SOURCE_CODEGEN_C_HOST_H_

#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "target/source/codegen_c.h"
#include "tvm/target/codegen.h"
#include "tvm/tir/expr.h"

namespace tvm {
namespace tl {

/*!
 * \brief C host code generator for TileLang.
 *
 * TileLang copy of TVM's CodeGenCHost, placed under the tl namespace.
 * Inherits from tvm::codegen::CodeGenC.
 */
class CodeGenCHost : public tvm::codegen::CodeGenC {
public:
  CodeGenCHost();
  /*!
   * \brief Configure the code generator.
   * \param output_ssa Whether to print output in SSA form (presumably forwarded
   *        to CodeGenC::Init -- confirm in the implementation file).
   * \param emit_asserts Whether to emit asserts in the resulting C code.
   * \param emit_fwd_func_decl Whether to emit forward function declarations
   *        in the resulting C code.
   * \param target_str String form of the target (NOTE(review): exact use is
   *        defined by the implementation).
   * \param devices Set of device names (NOTE(review): exact use is defined by
   *        the implementation).
   */
  void Init(bool output_ssa, bool emit_asserts, bool emit_fwd_func_decl,
            std::string target_str,
            const std::unordered_set<std::string> &devices);

  void InitGlobalContext();

  void AddFunction(const tvm::GlobalVar &gvar,
                   const tvm::tir::PrimFunc &f) override;
  /*!
   * \brief Add a function, choosing explicitly whether to emit its forward
   * declaration (presumably overriding the value given to Init -- confirm).
   */
  void AddFunction(const tvm::GlobalVar &gvar, const tvm::tir::PrimFunc &f,
                   bool emit_fwd_func_decl);
  /*!
   * \brief Add functions to the current module in a deterministic order,
   * regardless of the order they are given in. This helps with debugging.
   *
   * \param functions The (global var, function) pairs to add, in any order.
   */
  void AddFunctionsOrdered(
      std::vector<std::pair<tvm::GlobalVar, tvm::BaseFunc>> functions);
  void DefineModuleName();

  using tvm::codegen::CodeGenC::PrintType;
  void PrintType(tvm::DataType t, std::ostream &os) final; // NOLINT(*)
  void PrintFuncPrefix(std::ostream &os) final;            // NOLINT(*)

  // overload visitor functions
  void VisitExpr_(const tvm::tir::BroadcastNode *op,
                  std::ostream &os) final; // NOLINT(*)
  void VisitExpr_(const tvm::tir::CallNode *op,
                  std::ostream &os) override; // NOLINT(*)
  // Overload min and max to use the ternary operator, so we don't rely on the
  // standard library implementations.
  void VisitExpr_(const tvm::tir::MinNode *op,
                  std::ostream &os) final; // NOLINT(*)
  void VisitExpr_(const tvm::tir::MaxNode *op,
                  std::ostream &os) final; // NOLINT(*)

  void VisitStmt_(const tvm::tir::AssertStmtNode *op) final; // NOLINT(*)

  void GenerateForwardFunctionDeclarations(
      tvm::ffi::String global_symbol,
      const tvm::ffi::Array<tvm::Type> &arg_types,
      const tvm::Type &ret_type) override;
  /*!
   * \brief Names of the functions declared in this module.
   * \return A copy of the recorded names (ffi::Array is reference-counted, so
   *         the copy is cheap).
   */
  tvm::ffi::Array<tvm::ffi::String> GetFunctionNames() const {
    return function_names_;
  }

private:
  /*! \brief Name of the generated module. */
  std::string module_name_;
  /*! \brief Mapping from a global packed function to its unique name. */
  std::unordered_map<std::string, std::string> declared_globals_;
  /*! \brief Names of the functions declared in this module. */
  tvm::ffi::Array<tvm::ffi::String> function_names_;
  /*!
   * \brief Whether to emit asserts in the resulting C code.
   * Default-initialized so the value is well-defined before Init() runs.
   */
  bool emit_asserts_ = false;
  /*!
   * \brief Whether to emit forward function declarations in the resulting C
   * code. Default-initialized so the value is well-defined before Init() runs.
   */
  bool emit_fwd_func_decl_ = false;
  /*! \brief Whether to generate the entry function if encountered. */
  bool has_main_func_ = false;

  std::string GetPackedName(const tvm::tir::CallNode *op);
  void PrintGetFuncFromBackend(const std::string &func_name,
                               const std::string &packed_func_name);
  void PrintCallPacked(const tvm::tir::CallNode *op);
  /*!
   * \brief Print ternary conditional operator implementing binary `op`.
   * Forces the operands to be in SSA form.
   * \param op binary operator being expressed
   * \param compare string representation of comparison operator
   * \param os stream reference to print into
   */
  template <typename T>
  inline void PrintTernaryCondExpr(const T *op, const char *compare,
                                   std::ostream &os); // NOLINT(*)
};

} // namespace tl
} // namespace tvm

#endif // TL_TARGET_SOURCE_CODEGEN_C_HOST_H_