/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file make_packed_api.cc
 * \brief Lower PrimFunc to use the packed function API.
 */
#include <tvm/ffi/extra/module.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/module.h>
#include <tvm/target/target.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/buffer.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include <unordered_set>
#include <utility>
#include <vector>

#include "../op/builtin.h"
#include "arg_binder.h"
#include "tir/transforms/ir_utils.h"

namespace tvm {
namespace tl {
using namespace tir;
using namespace ffi;

namespace {
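/*!
 * \brief Rewrite tir.ret(value) statements so that the returned value is
 *  written into the packed-call result slot (ret_var_, a TVMFFIAny*) and the
 *  function returns error code 0.
 */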
class ReturnRewriter : public StmtMutator {
public:
  explicit ReturnRewriter(Var ret_var) : ret_var_(ret_var) {}

  Stmt VisitStmt_(const ForNode *node) override {
    if (node->kind == ForKind::kParallel)
      in_parallel_ += 1;
    Stmt ret = StmtMutator::VisitStmt_(node);
    if (node->kind == ForKind::kParallel)
      in_parallel_ -= 1;
    return ret;
  }

  Stmt VisitStmt_(const EvaluateNode *node) override {
    Stmt ret = StmtMutator::VisitStmt_(node);
    const EvaluateNode *eval = ret.as<EvaluateNode>();
    ICHECK(eval);
    if (const CallNode *call = eval->value.as<CallNode>()) {
      if (call->op.same_as(builtin::ret())) {
        ICHECK_EQ(in_parallel_, 0)
            << "tir.ret cannot be used in parallel scope.";
        ICHECK_EQ(call->args.size(), 1) << "tir.ret expects a single argument.";
        ret = WriteToOut(call->args[0]);
      }
    }
    return ret;
  }

private:
  struct ConvertedInfo {
    int type_index{-1};
    PrimExpr expr;
  };

  ConvertedInfo ConvertForFFI(const PrimExpr &val) {
    ConvertedInfo info;

    // convert val's data type to the corresponding FFI type index and value
    DataType dtype = val.dtype();
    if (dtype.is_bool()) {
      info.type_index = ffi::TypeIndex::kTVMFFIBool;
      info.expr = Cast(DataType::Int(64), val);
    } else if (dtype.is_int() || dtype.is_uint()) {
      info.type_index = ffi::TypeIndex::kTVMFFIInt;
      info.expr = Cast(DataType::Int(64), val);
    } else if (dtype.is_float()) {
      info.type_index = ffi::TypeIndex::kTVMFFIFloat;
      info.expr = Cast(DataType::Float(64), val);
    } else if (dtype.is_void()) {
      info.type_index = ffi::TypeIndex::kTVMFFINone;
      info.expr = val;
    } else {
      LOG(FATAL) << "data type " << dtype << " not supported yet";
    }

    return info;
  }

  Stmt WriteToOut(PrimExpr val) {
    auto info = ConvertForFFI(val);
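    // Populate the TVMFFIAny result slot: type index, zero padding, then the
    // converted value, followed by returning error code 0.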
    Stmt store_tindex = tir::Evaluate(
        tir::Call(DataType::Int(32), tir::builtin::tvm_struct_set(),
                  {ret_var_, IntImm(DataType::Int(32), 0),
                   IntImm(DataType::Int(32), tir::builtin::kTVMFFIAnyTypeIndex),
                   IntImm(DataType::Int(32), info.type_index)}));
    Stmt store_zero_padding = tir::Evaluate(tir::Call(
        DataType::Int(32), tir::builtin::tvm_struct_set(),
        {ret_var_, IntImm(DataType::Int(32), 0),
         IntImm(DataType::Int(32), tir::builtin::kTVMFFIAnyZeroPadding),
         IntImm(DataType::Int(32), 0)}));
    Stmt store_val = tir::Evaluate(tir::Call(
        DataType::Int(32), tir::builtin::tvm_struct_set(),
        {ret_var_, IntImm(DataType::Int(32), 0),
         IntImm(DataType::Int(32), tir::builtin::kTVMFFIAnyUnionValue),
         info.expr}));
    Stmt ret_zero = Evaluate(tvm::ret(0));
    return SeqStmt({store_tindex, store_zero_padding, store_val, ret_zero});
  }

  Var ret_var_;
  int in_parallel_{0};
};

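/*!
 * \brief Rewrite calls to GlobalVars that will receive the packed API,
 *  replacing them with tvm_call_cpacked calls to the corresponding global
 *  symbol (with a trailing null resource handle).
 */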
class SubroutineCallRewriter : public StmtExprMutator {
public:
  static ffi::Optional<Stmt>
  Apply(const ffi::Map<GlobalVar, ffi::String> &packed_func_methods,
        Stmt stmt) {
    SubroutineCallRewriter rewriter(packed_func_methods);
    stmt = rewriter.VisitStmt(stmt);
    if (rewriter.made_change_) {
      return stmt;
    } else {
      return std::nullopt;
    }
  }

private:
  explicit SubroutineCallRewriter(
      const ffi::Map<GlobalVar, ffi::String> &packed_func_methods)
      : packed_func_methods(packed_func_methods) {}

  PrimExpr VisitExpr_(const CallNode *op) override {
    auto node = Downcast<Call>(StmtExprMutator::VisitExpr_(op));

    if (auto *gvar_ptr = node->op.as<GlobalVarNode>()) {
      auto gvar = ffi::GetRef<GlobalVar>(gvar_ptr);
      if (auto symbol = packed_func_methods.Get(gvar)) {
        ffi::Array<PrimExpr> cpacked_args;
        cpacked_args.push_back(tir::StringImm(symbol.value()));
        for (auto arg : node->args) {
          cpacked_args.push_back(arg);
        }

        // push an empty handle to be compatible with current cpacked convention
        cpacked_args.push_back(tir::make_zero(DataType::Handle()));
        made_change_ = true;
        return tir::Call(node->dtype, tir::builtin::tvm_call_cpacked(),
                         cpacked_args);
      }
    }

    return node;
  }
  const ffi::Map<GlobalVar, ffi::String> &packed_func_methods;
  bool made_change_{false};
};

} // namespace

inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) {
  return AssertStmt(lhs == rhs, tvm::tir::StringImm(msg), Evaluate(0));
}

inline Stmt MakeAssertNotNull(PrimExpr ptr, std::string msg) {
  Call isnull(DataType::Bool(), builtin::isnullptr(), {ptr});
  return AssertStmt(!isnull, tvm::tir::StringImm(msg), Evaluate(0));
}

/*! \brief Return the global_symbol of the function, if it should be updated
 *
 * \param func The function to be inspected
 *
 * \returns The global_symbol to be used for the function at call
 *  sites, or std::nullopt if the function is to remain unchanged.
 */
Optional<String> RequiresPackedAPI(const PrimFunc &func) {
  // A function with an explicit calling convention has already been
  // lowered, and should not be modified.
  if (auto opt = func->GetAttr<Integer>(tvm::attr::kCallingConv)) {
    if (CallingConv(opt.value()->value) != CallingConv::kDefault) {
      return std::nullopt;
    }
  }

  // Internal function calls do not need the PackedFunc API
  auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
  if (!global_symbol) {
    return std::nullopt;
  }

  return global_symbol;
}

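/*!
 * \brief Lower a PrimFunc to the packed calling convention
 *  (self_handle, TVMFFIAny* args, int num_args, TVMFFIAny* result),
 *  unpacking and checking each argument against the original signature and
 *  binding DLTensor arguments to their buffer declarations.
 */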
PrimFunc MakePackedAPI(PrimFunc func) {
  auto global_symbol = RequiresPackedAPI(func);
  if (!global_symbol) {
    return func;
  }
  std::string name_hint = global_symbol.value();

  Target target = [&]() {
    auto opt = func->GetAttr<Target>(tvm::attr::kTarget);
    ICHECK(opt) << "MakePackedAPI requires the function to be annotated with "
                   "tvm::attr::kTarget ("
                << tvm::attr::kTarget
                << "), but the function only has attributes " << func->attrs;
    return opt.value();
  }();
  int target_device_type = target->GetTargetDeviceType();

  // A function without a host target has already been lowered.
  Target target_host;
  if (auto opt = target->GetHost()) {
    target_host = opt.value();
  } else {
    return func;
  }

  auto *func_ptr = func.CopyOnWrite();
  // set the global symbol to the packed function name
  const Stmt nop = Evaluate(0);
  int num_args = static_cast<int>(func_ptr->params.size());

  // Data field definitions
  // The packed fields
  Var v_self_handle("self_handle", DataType::Handle());
  Var v_packed_args("args", DataType::Handle());
  Var v_num_packed_args("num_args", DataType::Int(32));
  Var v_result("result", PointerType(PrimType(DataType::Void())));

  // The device context
  Var device_id("dev_id");
  Integer device_type(target_device_type);
  // seq_init gives sequence of initialization
  // seq_check gives sequence of later checks after init
  std::vector<Stmt> seq_init, seq_check, arg_buffer_declarations;
  std::unordered_map<const VarNode *, PrimExpr> vmap;
  ArgBinder binder(&vmap);

  // ---------------------------
  // local function definitions
  // load i-th argument as type t
  auto f_load_arg_value = [&](DataType arg_type, int i) {
    ffi::Array<PrimExpr> call_args{
        v_packed_args, IntImm(DataType::Int(32), i),
        IntImm(DataType::Int(32), builtin::kTVMFFIAnyUnionValue)};
    // load 64 bit version
    DataType api_type = APIType(arg_type);
    PrimExpr res = Call(api_type, builtin::tvm_struct_get(), call_args);
    // cast to the target version.
    if (api_type != arg_type) {
      res = Cast(arg_type, res);
    }
    return res;
  };

  // Assert correct type codes for each argument.  This must be done
  // *before* any initialization steps produced by
  // `binder.BindDLTensor()`.  The validity of those initialization
  // steps depends on the correct types being present, so they must not
  // run before the type codes are actually checked.
  seq_init.push_back(
      MakeAssertEQ(v_num_packed_args, num_args, [&]() -> std::string {
        std::ostringstream error_message;
        error_message << name_hint << ": num_args should be " << num_args;
        return error_message.str();
      }()));

  if (num_args > 0) {
    seq_init.push_back(
        MakeAssertNotNull(v_packed_args, name_hint + ": args pointer is NULL"));
  }

  // Need to delay binding of the buffers, in case some arguments also
  // appear in the buffer definitions (e.g. in shapes or strides).
  std::vector<std::pair<PrimExpr, Var>> var_def;
  std::vector<std::pair<Var, Buffer>> buffer_def;

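  // For each parameter: load its FFI type index, assert that the incoming
  // value is compatible with the parameter's dtype, and extract the value.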
  for (int i = 0; i < static_cast<int>(func_ptr->params.size()); ++i) {
    Var param = func_ptr->params[i];
    PrimExpr arg_value;
    // type index checks
    Var type_index(param->name_hint + ".type_index", DataType::Int(32));
    seq_init.push_back(LetStmt(
        type_index,
        tir::Call(DataType::Int(32), builtin::tvm_struct_get(),
                  {v_packed_args, IntImm(DataType::Int(32), i),
                   IntImm(DataType::Int(32), builtin::kTVMFFIAnyTypeIndex)}),
        nop));
    DataType dtype = param.dtype();
    if (dtype.is_handle()) {
      std::ostringstream msg;
      msg << name_hint << ": Expect arg[" << i << "] to be pointer";
      seq_init.emplace_back(
          AssertStmt(type_index == ffi::TypeIndex::kTVMFFINone ||
                         type_index == ffi::TypeIndex::kTVMFFIOpaquePtr ||
                         type_index == ffi::TypeIndex::kTVMFFIDLTensorPtr ||
                         type_index >= ffi::TypeIndex::kTVMFFIStaticObjectBegin,
                     tvm::tir::StringImm(msg.str()), nop));
      // if type_index is Tensor, we need to add the byte offset of the object
      // header (sizeof(TVMFFIObject)) to reach the DLTensor fields; this
      // ensures that T.handle always shows up as a DLTensor*
      const int64_t object_cell_offset = sizeof(TVMFFIObject);
      static_assert(object_cell_offset == 24);
      arg_value = f_load_arg_value(param.dtype(), i);
      PrimExpr handle_from_tensor =
          Call(DataType::Handle(), tir::builtin::handle_add_byte_offset(),
               {arg_value, IntImm(DataType::Int(32), object_cell_offset)});
      arg_value = Select(type_index == ffi::TypeIndex::kTVMFFITensor,
                         handle_from_tensor, arg_value);
    } else if (dtype.is_bool()) {
      std::ostringstream msg;
      msg << name_hint << ": Expect arg[" << i << "] to be boolean";
      seq_init.emplace_back(
          AssertStmt(type_index == ffi::TypeIndex::kTVMFFIBool ||
                         type_index == ffi::TypeIndex::kTVMFFIInt,
                     tvm::tir::StringImm(msg.str()), nop));
      arg_value =
          Cast(DataType::Bool(), f_load_arg_value(DataType::Int(64), i));

    } else if (dtype.is_int() || dtype.is_uint()) {
      std::ostringstream msg;
      msg << name_hint << ": Expect arg[" << i << "] to be int";
      seq_init.emplace_back(
          AssertStmt(type_index == ffi::TypeIndex::kTVMFFIInt ||
                         type_index == ffi::TypeIndex::kTVMFFIBool,
                     tvm::tir::StringImm(msg.str()), nop));
      arg_value = f_load_arg_value(param.dtype(), i);
    } else {
      ICHECK(dtype.is_float());
      std::ostringstream msg;
      msg << name_hint << ": Expect arg[" << i << "] to be float";
      seq_init.emplace_back(
          AssertStmt(type_index == ffi::TypeIndex::kTVMFFIFloat ||
                         type_index == ffi::TypeIndex::kTVMFFIInt ||
                         type_index == ffi::TypeIndex::kTVMFFIBool,
                     tvm::tir::StringImm(msg.str()), nop));
      // use select so we can also handle int conversion to bool
      arg_value = tir::Select(
          type_index == ffi::TypeIndex::kTVMFFIFloat,
          /* true_value = */ f_load_arg_value(param.dtype(), i),
          /* false_value = */
          Cast(param.dtype(), f_load_arg_value(DataType::Int(64), i)));
    }
    var_def.emplace_back(arg_value, param);
    if (func_ptr->buffer_map.count(param)) {
      // buffer binding now depends on type index
      // if the index is Tensor handle, we need to offset to get the DLTensor*
      buffer_def.emplace_back(param, func_ptr->buffer_map[param]);
    }
  }

  // signature: (void* handle, TVMFFIAny* packed_args, int num_args,
  //             TVMFFIAny* v_result)
  ffi::Array<Var> args{v_self_handle, v_packed_args, v_num_packed_args,
                       v_result};

  // Arg definitions are bound before buffer binding to avoid
  // use-before-def errors.
  //
  // For example, for auto broadcasting, checks are required to guarantee that
  // either 0 or the original stride will be correctly used. Checks here have
  // to use args that may have no let binding yet. Therefore, the let bindings
  // for args are hoisted before the buffer declarations.
  for (const auto &[expr, param] : var_def) {
    binder.Bind(param, expr, name_hint + "." + param->name_hint, true);
  }

  for (const auto &[var, buffer] : buffer_def) {
    binder.BindDLTensor(buffer, device_type, device_id, var,
                        name_hint + "." + var->name_hint);
    arg_buffer_declarations.push_back(DeclBuffer(buffer, nop));
  }
  // reset global symbol to attach prefix
  func = WithAttrs(
      std::move(func),
      {{tvm::attr::kCallingConv, static_cast<int>(CallingConv::kCPackedFunc)},
       {tvm::attr::kTarget, target_host},
       {tvm::attr::kGlobalSymbol,
        ffi::symbol::tvm_ffi_symbol_prefix + global_symbol.value()}});

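  // Rewrite tir.ret statements to write into v_result, then wrap the body in
  // this function's compute scope.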
  Stmt body = ReturnRewriter(v_result)(func_ptr->body);
  body = AttrStmt(make_zero(DataType::Int(32)), tir::attr::compute_scope,
                  StringImm(name_hint + "_compute_"), body);
  // Set device context
  if (vmap.count(device_id.get())) {
    ffi::Any node = ffi::String("default");
    seq_check.push_back(AttrStmt(node, tir::attr::device_id, device_id, nop));
    seq_check.push_back(
        AttrStmt(node, tir::attr::device_type, device_type, nop));

    if (runtime::DeviceAPI::NeedSetDevice(target_device_type)) {
      Stmt set_device =
          Evaluate(Call(DataType::Int(32), tir::builtin::tvm_call_packed(),
                        {StringImm(runtime::symbol::tvm_set_device),
                         device_type, device_id}));
      body = SeqStmt({set_device, body});
    }
  }

  // Return error code of zero on success
  body = SeqStmt({body, Evaluate(ret(Integer(0)))});

  body = MergeNest({seq_init, binder.init_nest(), seq_check, binder.asserts(),
                    arg_buffer_declarations},
                   body);
  func_ptr->body = body;
  func_ptr->params = args;

  ffi::Array<Var> undefined = UndefinedVars(body, func_ptr->params);

  ICHECK_EQ(undefined.size(), 0)
      << "In PrimFunc " << name_hint << " variables " << undefined
      << " are used, but are not passed in as API arguments";

  func_ptr->buffer_map = ffi::Map<Var, Buffer>();
  func_ptr->ret_type = PrimType(DataType::Int(32));

  // return the function.
  return func;
}

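/*!
 * \brief Module pass that applies MakePackedAPI to every PrimFunc with an
 *  externally visible global symbol and rewrites internal call sites to use
 *  the packed convention.
 */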
tvm::transform::Pass MakePackedAPI() {
  using tvm::transform::Pass;
  auto pass_func = [](IRModule mod, const tvm::transform::PassContext &ctx) {
    Map<GlobalVar, String> packed_func_methods;
    for (const auto &[gvar, base_func] : mod->functions) {
      if (auto opt = base_func.as<PrimFunc>()) {
        const auto &prim_func = opt.value();
        if (auto global_symbol = RequiresPackedAPI(prim_func)) {
          packed_func_methods.Set(gvar, global_symbol.value());
        }
      }
    }

    IRModuleNode *mptr = mod.CopyOnWrite();
    IRModule updates;

    for (const auto &[gvar, base_func] : mptr->functions) {
      if (auto opt = base_func.as<PrimFunc>()) {
        auto func = opt.value();
        auto orig_func = func;

        if (auto body = SubroutineCallRewriter::Apply(packed_func_methods,
                                                      func->body)) {
          func.CopyOnWrite()->body = body.value();
        }
        func = MakePackedAPI(std::move(func));

        if (!func.same_as(orig_func)) {
          updates->Add(gvar, func);
        }
      }
    }

    if (!updates->functions.empty()) {
      mod.CopyOnWrite()->Update(updates);
    }
    return mod;
  };

  return tvm::transform::CreateModulePass(pass_func, 0, "tl.MakePackedAPI", {});
}

TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.transform.MakePackedAPI",
                        []() { return MakePackedAPI(); });
}

} // namespace tl
} // namespace tvm