You need to sign in or sign up before continuing.
runtime.cc 7.53 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

/*!
 * \file tl/runtime/runtime.cc
 * \brief Runtime functions.
 *
 */

#include "runtime.h"

#include "../target/cuda.h"
#include <tvm/runtime/registry.h>

namespace tvm {
namespace tl {

using namespace runtime;

20
/*!
 * \brief Format the first n elements of an array as "[a, b, c]".
 * \param ptr Pointer to the first element; it is only dereferenced when n > 0.
 * \param n Number of elements to print.
 * \return The elements streamed with operator<<, separated by ", " and wrapped
 *         in square brackets ("[]" when n is 0).
 */
template <typename T> static std::string ArrayToStr(const T *ptr, size_t n) {
  std::stringstream ss;
  ss << "[";
  for (size_t i = 0; i < n; i++) {
    // Separator goes before every element except the first.
    if (i > 0)
      ss << ", ";
    ss << ptr[i];
  }
  ss << "]";
  return ss.str();
}

struct TensorMapArgs {
33
  CUtensorMap *map;
34
35
  CUtensorMapDataType type;
  cuuint32_t tensorRank;
36
  void *globalAddress;
37
38
39
40
41
42
43
44
45
46
47
  cuuint64_t globalDim[5], globalStride[5];
  cuuint32_t boxDim[5], elementStrides[5];
  CUtensorMapInterleave interleave;
  CUtensorMapSwizzle swizzle;
  CUtensorMapL2promotion l2Promotion;
  CUtensorMapFloatOOBfill oobFill;

  static TensorMapArgs Extract(TVMArgs args) {
    TensorMapArgs T;
    int idx = 0;
    ICHECK(args.num_args >= 8);
48
49
50
    T.map = reinterpret_cast<CUtensorMap *>(static_cast<void *>(args[idx++]));
    T.type =
        static_cast<CUtensorMapDataType>(static_cast<int64_t>(args[idx++]));
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
    T.tensorRank = static_cast<cuuint32_t>(static_cast<int64_t>(args[idx++]));
    T.globalAddress = args[idx++];
    ICHECK(T.tensorRank >= 1 && T.tensorRank <= 5);
    ICHECK(args.num_args == static_cast<int>(8 + T.tensorRank * 4));
    for (size_t i = 0; i < T.tensorRank; i++) {
      T.globalDim[i] = static_cast<cuuint64_t>(args[idx++]);
    }
    for (size_t i = 0; i < T.tensorRank; i++) {
      T.globalStride[i] = static_cast<cuuint64_t>(args[idx++]);
    }
    for (size_t i = 0; i < T.tensorRank; i++) {
      T.boxDim[i] = static_cast<cuuint64_t>(args[idx++]);
    }
    for (size_t i = 0; i < T.tensorRank; i++) {
      T.elementStrides[i] = static_cast<cuuint64_t>(args[idx++]);
    }
67
68
69
70
71
72
73
74
    T.interleave =
        static_cast<CUtensorMapInterleave>(static_cast<int64_t>(args[idx++]));
    T.swizzle =
        static_cast<CUtensorMapSwizzle>(static_cast<int64_t>(args[idx++]));
    T.l2Promotion =
        static_cast<CUtensorMapL2promotion>(static_cast<int64_t>(args[idx++]));
    T.oobFill =
        static_cast<CUtensorMapFloatOOBfill>(static_cast<int64_t>(args[idx++]));
75
76
77
78
79
80
81
82
83
84
85
86
    return T;
  }

  std::string ToDebugString() {
    std::stringstream ss;
    ss << "TMA Desc Addr:   " << map << std::endl
       << "format         " << type << std::endl
       << "dim            " << tensorRank << std::endl
       << "gmem_address   " << globalAddress << std::endl
       << "globalDim      " << ArrayToStr(globalDim, tensorRank) << std::endl
       << "globalStrides  " << ArrayToStr(globalStride, tensorRank) << std::endl
       << "boxDim         " << ArrayToStr(boxDim, tensorRank) << std::endl
87
88
       << "elementStrides " << ArrayToStr(elementStrides, tensorRank)
       << std::endl
89
90
91
92
93
94
95
96
97
       << "interleave     " << interleave << std::endl
       << "swizzle        " << swizzle << std::endl
       << "l2Promotion    " << l2Promotion << std::endl
       << "oobFill        " << oobFill << std::endl;
    return ss.str();
  }
};

// Registered packed functions that build TMA descriptors (CUtensorMap)
// through the cuTensorMapEncode* calls.
98
99
100
101
102
103
104
105
106
107
108
109
110
111
/*!
 * \brief Packed function that encodes a tiled TMA descriptor.
 *
 * Unpacks the call with TensorMapArgs::Extract and forwards the fields to
 * cuTensorMapEncodeTiled. A non-success CUresult is reported through LOG_FATAL
 * together with a dump of the arguments; the CUresult is stored in the return
 * value as an int.
 */
TVM_REGISTER_GLOBAL(tvm_tensormap_create_tiled)
    .set_body([](TVMArgs args, TVMRetValue *ret) {
      TensorMapArgs desc = TensorMapArgs::Extract(args);
      // The stride pointer is advanced by one element, so the first extracted
      // stride is not handed to the driver.
      // NOTE(review): presumably the driver expects tensorRank - 1 strides;
      // confirm against the CUDA driver API reference.
      cuuint64_t *strides = desc.globalStride + 1;
      CUresult status = cuTensorMapEncodeTiled(
          desc.map, desc.type, desc.tensorRank, desc.globalAddress,
          desc.globalDim, strides, desc.boxDim, desc.elementStrides,
          desc.interleave, desc.swizzle, desc.l2Promotion, desc.oobFill);
      const bool encoded = (status == CUDA_SUCCESS);
      if (!encoded) {
        LOG_FATAL << "Failed to initialize the TMA descriptor " << status
                  << std::endl
                  << desc.ToDebugString();
      }
      *ret = static_cast<int>(status);
    });
112
113

struct TensorMapIm2ColArgs {
114
  CUtensorMap *map;
115
116
  CUtensorMapDataType type;
  cuuint32_t tensorRank;
117
  void *globalAddress;
118
119
120
121
122
123
124
125
126
127
128
129
130
  cuuint64_t globalDim[5], globalStride[5];
  cuuint32_t elementStrides[5];
  int pixelBoxLowerCorner[3], pixelBoxUpperCorner[3];
  cuuint32_t smem_box_channel, smem_box_pixel;
  CUtensorMapInterleave interleave;
  CUtensorMapSwizzle swizzle;
  CUtensorMapL2promotion l2Promotion;
  CUtensorMapFloatOOBfill oobFill;

  static TensorMapIm2ColArgs Extract(TVMArgs args) {
    TensorMapIm2ColArgs T;
    int idx = 0;
    ICHECK(args.num_args >= 8);
131
132
133
    T.map = reinterpret_cast<CUtensorMap *>(static_cast<void *>(args[idx++]));
    T.type =
        static_cast<CUtensorMapDataType>(static_cast<int64_t>(args[idx++]));
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
    T.tensorRank = static_cast<cuuint32_t>(static_cast<int64_t>(args[idx++]));
    T.globalAddress = args[idx++];
    ICHECK(T.tensorRank >= 3 && T.tensorRank <= 5);
    ICHECK(args.num_args == static_cast<int>(6 + T.tensorRank * 5));
    for (size_t i = 0; i < T.tensorRank; i++) {
      T.globalDim[i] = static_cast<cuuint64_t>(args[idx++]);
    }
    for (size_t i = 0; i < T.tensorRank; i++) {
      T.globalStride[i] = static_cast<cuuint64_t>(args[idx++]);
    }
    for (size_t i = 0; i < T.tensorRank; i++) {
      T.elementStrides[i] = static_cast<cuuint64_t>(args[idx++]);
    }
    for (size_t i = 0; i < T.tensorRank - 2; i++) {
      T.pixelBoxLowerCorner[i] = static_cast<int>(args[idx++]);
    }
    for (size_t i = 0; i < T.tensorRank - 2; i++) {
      T.pixelBoxUpperCorner[i] = static_cast<int>(args[idx++]);
    }
    T.smem_box_pixel = static_cast<cuuint64_t>(args[idx++]);
    T.smem_box_channel = static_cast<cuuint64_t>(args[idx++]);
155
156
157
158
159
160
161
162
    T.interleave =
        static_cast<CUtensorMapInterleave>(static_cast<int64_t>(args[idx++]));
    T.swizzle =
        static_cast<CUtensorMapSwizzle>(static_cast<int64_t>(args[idx++]));
    T.l2Promotion =
        static_cast<CUtensorMapL2promotion>(static_cast<int64_t>(args[idx++]));
    T.oobFill =
        static_cast<CUtensorMapFloatOOBfill>(static_cast<int64_t>(args[idx++]));
163
164
165
166
167
168
169
170
171
172
173
174
175
    return T;
  }

  std::string ToDebugString() {
    std::stringstream ss;
    ss << "TMA Desc Addr:   " << map << std::endl
       << "format         " << type << std::endl
       << "dim            " << tensorRank << std::endl
       << "gmem_address   " << globalAddress << std::endl
       << "globalDim      " << ArrayToStr(globalDim, tensorRank) << std::endl
       << "globalStrides  " << ArrayToStr(globalStride, tensorRank) << std::endl
       << "smem_box_pixel " << smem_box_pixel << std::endl
       << "smem_box_channel " << smem_box_channel << std::endl
176
177
178
179
180
181
       << "pixelBoxLowerCorner  "
       << ArrayToStr(pixelBoxLowerCorner, tensorRank - 2) << std::endl
       << "pixelBoxUpperCorner  "
       << ArrayToStr(pixelBoxUpperCorner, tensorRank - 2) << std::endl
       << "elementStrides " << ArrayToStr(elementStrides, tensorRank)
       << std::endl
182
183
184
185
186
187
188
189
       << "interleave     " << interleave << std::endl
       << "swizzle        " << swizzle << std::endl
       << "l2Promotion    " << l2Promotion << std::endl
       << "oobFill        " << oobFill << std::endl;
    return ss.str();
  }
};

190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
/*!
 * \brief Packed function that encodes an im2col TMA descriptor.
 *
 * Unpacks the call with TensorMapIm2ColArgs::Extract and forwards the fields
 * to cuTensorMapEncodeIm2col. A non-success CUresult is reported through
 * LOG_FATAL together with a dump of the arguments; the CUresult is stored in
 * the return value as an int.
 */
TVM_REGISTER_GLOBAL(tvm_tensormap_create_im2col)
    .set_body([](TVMArgs args, TVMRetValue *ret) {
      TensorMapIm2ColArgs desc = TensorMapIm2ColArgs::Extract(args);
      // The stride pointer is advanced by one element, so the first extracted
      // stride is not handed to the driver.
      // NOTE(review): presumably the driver expects tensorRank - 1 strides;
      // confirm against the CUDA driver API reference.
      cuuint64_t *strides = desc.globalStride + 1;
      // The channel count is passed before the pixel count here, the reverse
      // of the order in which Extract reads them.
      CUresult status = cuTensorMapEncodeIm2col(
          desc.map, desc.type, desc.tensorRank, desc.globalAddress,
          desc.globalDim, strides, desc.pixelBoxLowerCorner,
          desc.pixelBoxUpperCorner, desc.smem_box_channel, desc.smem_box_pixel,
          desc.elementStrides, desc.interleave, desc.swizzle, desc.l2Promotion,
          desc.oobFill);
      const bool encoded = (status == CUDA_SUCCESS);
      if (!encoded) {
        LOG_FATAL << "Failed to initialize the TMA descriptor " << status
                  << std::endl
                  << desc.ToDebugString();
      }
      *ret = static_cast<int>(status);
    });
205

206
207
} // namespace tl
} // namespace tvm