frontend_legalize.cc 2.83 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file frontend_legalize.cc
 * \brief Legalize the program from frontend
 */

25
#include <tvm/ffi/reflection/registry.h>
26
27
28
29
30
31
32
33
34
35
36
37
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include "arith/ir_mutator_with_analyzer.h"

namespace tvm {
namespace tl {

using namespace tir;

class FrontendLegalizer : public arith::IRMutatorWithAnalyzer {
38
public:
39
40
41
  static PrimFunc Substitute(PrimFunc f) {
    arith::Analyzer analyzer;
    FrontendLegalizer substituter(&analyzer);
42
    PrimFuncNode *fptr = f.CopyOnWrite();
43
44
45
46
    fptr->body = substituter.VisitStmt(f->body);
    return f;
  }

47
private:
48
49
  using arith::IRMutatorWithAnalyzer::IRMutatorWithAnalyzer;

50
  Stmt VisitStmt_(const ForNode *node) final {
51
52
53
54
55
56
57
58
59
60
    if (node->kind == ForKind::kParallel) {
      parallel_for_scope_++;
    }
    auto n = StmtExprMutator::VisitStmt_(node);
    if (node->kind == ForKind::kParallel) {
      parallel_for_scope_--;
    }
    return n;
  }

61
  PrimExpr VisitExpr_(const VarNode *node) final {
62
63
64
65
66
67
68
    if (let_bindings_.count(node)) {
      return arith::IRMutatorWithAnalyzer::VisitExpr(let_bindings_[node]);
    } else {
      return arith::IRMutatorWithAnalyzer::VisitExpr_(node);
    }
  }

69
  Stmt VisitStmt_(const LetStmtNode *node) final {
70
71
72
73
    let_bindings_[node->var.get()] = node->value;
    return arith::IRMutatorWithAnalyzer::VisitStmt(node->body);
  }

74
  PrimExpr VisitExpr_(const LetNode *node) final {
75
76
77
78
79
    let_bindings_[node->var.get()] = node->value;
    return arith::IRMutatorWithAnalyzer::VisitExpr(node->body);
  }

  int parallel_for_scope_ = 0;
80
  std::unordered_map<const VarNode *, PrimExpr> let_bindings_;
81
82
83
84
85
86
87
88
89
90
91
};

using namespace tir::transform;

/*!
 * \brief Create the "tl.FrontendLegalize" PrimFunc pass.
 * \return A pass that applies FrontendLegalizer::Substitute to each PrimFunc.
 */
Pass FrontendLegalize() {
  constexpr int kOptLevel = 0;
  // The module and pass context arguments are not needed by the rewrite.
  auto legalize = [](PrimFunc func, IRModule, PassContext) -> PrimFunc {
    return FrontendLegalizer::Substitute(std::move(func));
  };
  return CreatePrimFuncPass(legalize, kOptLevel, "tl.FrontendLegalize", {});
}

92
93
94
95
// Register the pass constructor in the FFI global registry under the name
// "tl.transform.FrontendLegalize".
TVM_FFI_STATIC_INIT_BLOCK({
  tvm::ffi::reflection::GlobalDef().def("tl.transform.FrontendLegalize",
                                        FrontendLegalize);
});
96

97
98
} // namespace tl
} // namespace tvm