runtime.cc 7.56 KB
Newer Older
1
2
3
4
5
6
7
8
9
/*!
 * \file tl/runtime/runtime.cc
 * \brief Runtime functions that build CUDA TMA (Tensor Memory Accelerator)
 *        descriptors through the CUDA driver API.
 *
 */

#include "runtime.h"

#include "../target/cuda.h"
10
11
#include <tvm/ffi/function.h>
#include <tvm/node/node.h>
12
13
14
15

namespace tvm {
namespace tl {

16
#if (CUDA_MAJOR_VERSION >= 12)
/*!
 * \brief Format the first \p n elements of \p ptr as "[e0, e1, ...]".
 *
 * Each element is written with its stream insertion operator; an empty
 * range yields "[]".
 */
template <typename T> static std::string ArrayToStr(const T *ptr, size_t n) {
  std::ostringstream out;
  out << '[';
  const char *sep = "";
  for (const T *it = ptr, *end = ptr + n; it != end; ++it) {
    out << sep << *it;
    sep = ", ";
  }
  out << ']';
  return out.str();
}

struct TensorMapArgs {
30
  CUtensorMap *map;
31
32
  CUtensorMapDataType type;
  cuuint32_t tensorRank;
33
  void *globalAddress;
34
35
36
37
38
39
40
  cuuint64_t globalDim[5], globalStride[5];
  cuuint32_t boxDim[5], elementStrides[5];
  CUtensorMapInterleave interleave;
  CUtensorMapSwizzle swizzle;
  CUtensorMapL2promotion l2Promotion;
  CUtensorMapFloatOOBfill oobFill;

41
  static TensorMapArgs Extract(PackedArgs args) {
42
43
    TensorMapArgs T;
    int idx = 0;
44
45
46
47
48
    ICHECK(args.size() >= 8);
    T.map = reinterpret_cast<CUtensorMap *>(args[idx++].cast<void *>());
    T.type = static_cast<CUtensorMapDataType>(args[idx++].cast<int64_t>());
    T.tensorRank = static_cast<cuuint32_t>(args[idx++].cast<int64_t>());
    T.globalAddress = args[idx++].cast<void *>();
49
    ICHECK(T.tensorRank >= 1 && T.tensorRank <= 5);
50
    ICHECK(args.size() == static_cast<int>(8 + T.tensorRank * 4));
51
    for (size_t i = 0; i < T.tensorRank; i++) {
52
      T.globalDim[i] = args[idx++].cast<cuuint64_t>();
53
54
    }
    for (size_t i = 0; i < T.tensorRank; i++) {
55
      T.globalStride[i] = args[idx++].cast<cuuint64_t>();
56
57
    }
    for (size_t i = 0; i < T.tensorRank; i++) {
58
      T.boxDim[i] = args[idx++].cast<cuuint64_t>();
59
60
    }
    for (size_t i = 0; i < T.tensorRank; i++) {
61
      T.elementStrides[i] = args[idx++].cast<cuuint64_t>();
62
    }
63
    T.interleave =
64
65
        static_cast<CUtensorMapInterleave>(args[idx++].cast<int64_t>());
    T.swizzle = static_cast<CUtensorMapSwizzle>(args[idx++].cast<int64_t>());
66
    T.l2Promotion =
67
        static_cast<CUtensorMapL2promotion>(args[idx++].cast<int64_t>());
68
    T.oobFill =
69
        static_cast<CUtensorMapFloatOOBfill>(args[idx++].cast<int64_t>());
70
71
72
73
74
75
76
77
78
79
80
81
    return T;
  }

  std::string ToDebugString() {
    std::stringstream ss;
    ss << "TMA Desc Addr:   " << map << std::endl
       << "format         " << type << std::endl
       << "dim            " << tensorRank << std::endl
       << "gmem_address   " << globalAddress << std::endl
       << "globalDim      " << ArrayToStr(globalDim, tensorRank) << std::endl
       << "globalStrides  " << ArrayToStr(globalStride, tensorRank) << std::endl
       << "boxDim         " << ArrayToStr(boxDim, tensorRank) << std::endl
82
83
       << "elementStrides " << ArrayToStr(elementStrides, tensorRank)
       << std::endl
84
85
86
87
88
89
90
91
92
       << "interleave     " << interleave << std::endl
       << "swizzle        " << swizzle << std::endl
       << "l2Promotion    " << l2Promotion << std::endl
       << "oobFill        " << oobFill << std::endl;
    return ss.str();
  }
};

// set device api
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
TVM_FFI_STATIC_INIT_BLOCK({
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def_packed(
      "tvm_tensormap_create_tiled", [](PackedArgs args, Any *ret) {
        TensorMapArgs T = TensorMapArgs::Extract(args);
        CUresult result = cuTensorMapEncodeTiled(
            T.map, T.type, T.tensorRank, T.globalAddress, T.globalDim,
            T.globalStride + 1, T.boxDim, T.elementStrides, T.interleave,
            T.swizzle, T.l2Promotion, T.oobFill);
        if (result != CUDA_SUCCESS) {
          LOG_FATAL << "Failed to initialize the TMA descriptor " << result
                    << std::endl
                    << T.ToDebugString();
        }
        *ret = static_cast<int>(result);
      });
});
110
111

struct TensorMapIm2ColArgs {
112
  CUtensorMap *map;
113
114
  CUtensorMapDataType type;
  cuuint32_t tensorRank;
115
  void *globalAddress;
116
117
118
119
120
121
122
123
124
  cuuint64_t globalDim[5], globalStride[5];
  cuuint32_t elementStrides[5];
  int pixelBoxLowerCorner[3], pixelBoxUpperCorner[3];
  cuuint32_t smem_box_channel, smem_box_pixel;
  CUtensorMapInterleave interleave;
  CUtensorMapSwizzle swizzle;
  CUtensorMapL2promotion l2Promotion;
  CUtensorMapFloatOOBfill oobFill;

125
  static TensorMapIm2ColArgs Extract(PackedArgs args) {
126
127
    TensorMapIm2ColArgs T;
    int idx = 0;
128
129
130
131
132
    ICHECK(args.size() >= 8);
    T.map = reinterpret_cast<CUtensorMap *>(args[idx++].cast<void *>());
    T.type = static_cast<CUtensorMapDataType>(args[idx++].cast<int64_t>());
    T.tensorRank = static_cast<cuuint32_t>(args[idx++].cast<int64_t>());
    T.globalAddress = args[idx++].cast<void *>();
133
    ICHECK(T.tensorRank >= 3 && T.tensorRank <= 5);
134
    ICHECK(args.size() == static_cast<int>(6 + T.tensorRank * 5));
135
    for (size_t i = 0; i < T.tensorRank; i++) {
136
      T.globalDim[i] = args[idx++].cast<cuuint64_t>();
137
138
    }
    for (size_t i = 0; i < T.tensorRank; i++) {
139
      T.globalStride[i] = args[idx++].cast<cuuint64_t>();
140
141
    }
    for (size_t i = 0; i < T.tensorRank; i++) {
142
      T.elementStrides[i] = args[idx++].cast<cuuint64_t>();
143
144
    }
    for (size_t i = 0; i < T.tensorRank - 2; i++) {
145
      T.pixelBoxLowerCorner[i] = args[idx++].cast<int>();
146
147
    }
    for (size_t i = 0; i < T.tensorRank - 2; i++) {
148
      T.pixelBoxUpperCorner[i] = args[idx++].cast<int>();
149
    }
150
151
    T.smem_box_pixel = args[idx++].cast<cuuint64_t>();
    T.smem_box_channel = args[idx++].cast<cuuint64_t>();
152
    T.interleave =
153
154
        static_cast<CUtensorMapInterleave>(args[idx++].cast<int64_t>());
    T.swizzle = static_cast<CUtensorMapSwizzle>(args[idx++].cast<int64_t>());
155
    T.l2Promotion =
156
        static_cast<CUtensorMapL2promotion>(args[idx++].cast<int64_t>());
157
    T.oobFill =
158
        static_cast<CUtensorMapFloatOOBfill>(args[idx++].cast<int64_t>());
159
160
161
162
163
164
165
166
167
168
169
170
171
    return T;
  }

  std::string ToDebugString() {
    std::stringstream ss;
    ss << "TMA Desc Addr:   " << map << std::endl
       << "format         " << type << std::endl
       << "dim            " << tensorRank << std::endl
       << "gmem_address   " << globalAddress << std::endl
       << "globalDim      " << ArrayToStr(globalDim, tensorRank) << std::endl
       << "globalStrides  " << ArrayToStr(globalStride, tensorRank) << std::endl
       << "smem_box_pixel " << smem_box_pixel << std::endl
       << "smem_box_channel " << smem_box_channel << std::endl
172
173
174
175
176
177
       << "pixelBoxLowerCorner  "
       << ArrayToStr(pixelBoxLowerCorner, tensorRank - 2) << std::endl
       << "pixelBoxUpperCorner  "
       << ArrayToStr(pixelBoxUpperCorner, tensorRank - 2) << std::endl
       << "elementStrides " << ArrayToStr(elementStrides, tensorRank)
       << std::endl
178
179
180
181
182
183
184
185
       << "interleave     " << interleave << std::endl
       << "swizzle        " << swizzle << std::endl
       << "l2Promotion    " << l2Promotion << std::endl
       << "oobFill        " << oobFill << std::endl;
    return ss.str();
  }
};

186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
TVM_FFI_STATIC_INIT_BLOCK({
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def_packed(
      "tvm_tensormap_create_im2col", [](PackedArgs args, Any *ret) {
        TensorMapIm2ColArgs T = TensorMapIm2ColArgs::Extract(args);
        CUresult result = cuTensorMapEncodeIm2col(
            T.map, T.type, T.tensorRank, T.globalAddress, T.globalDim,
            T.globalStride + 1, T.pixelBoxLowerCorner, T.pixelBoxUpperCorner,
            T.smem_box_channel, T.smem_box_pixel, T.elementStrides,
            T.interleave, T.swizzle, T.l2Promotion, T.oobFill);
        if (result != CUDA_SUCCESS) {
          LOG_FATAL << "Failed to initialize the TMA descriptor " << result
                    << std::endl
                    << T.ToDebugString();
        }
        *ret = static_cast<int>(result);
      });
});

205
#endif // (CUDA_MAJOR_VERSION >= 12)
206

207
208
} // namespace tl
} // namespace tvm