utils.cc 4.61 KB
Newer Older
1
2
3
4
5
6
7
/*!
 * \file tl/target/utils.cc
 * \brief helper functions for target attributes.
 */

#include "utils.h"

8
9
10
#include "../support/ffi_aliases.h"
#include <tvm/node/node.h>

11
12
13
namespace tvm {
namespace tl {

14
15
16
17
18
19
/*!
 * \brief Whether the target's device type is CUDA.
 * \param target Target to inspect.
 * \return true iff GetTargetDeviceType() reports kDLCUDA.
 */
bool TargetIsCuda(Target target) {
  const auto device_type = target->GetTargetDeviceType();
  return device_type == kDLCUDA;
}
/*!
 * \brief Whether the target's device type is ROCm.
 * \param target Target to inspect.
 * \return true iff GetTargetDeviceType() reports kDLROCM.
 */
bool TargetIsRocm(Target target) {
  const auto device_type = target->GetTargetDeviceType();
  return device_type == kDLROCM;
}
20
21

/*!
 * \brief Parse the integer that follows the "sm_" prefix of the target's
 *        "arch" attribute (e.g. "sm_80" -> 80, "sm_90a" -> 90).
 *
 * Trailing non-digit characters (such as the "a" in "sm_90a") are ignored,
 * because std::stoi stops at the first non-digit character.
 *
 * \param target Target whose "arch" attribute is read.
 * \return The numeric part of the arch string.
 * \note Fails an ICHECK if "arch" is missing, does not start with "sm_",
 *       or has no digit immediately after the "sm_" prefix.
 */
int GetArchInt(Target target) {
  auto s = target->GetAttr<tvm::ffi::String>("arch");
  ICHECK(s.has_value()) << "target has no \"arch\" attribute";
  const std::string arch_str = s.value();
  ICHECK(arch_str.size() >= 3);
  ICHECK_EQ(arch_str.compare(0, 3, "sm_"), 0)
      << "arch string must start with sm_";
  // Without this guard, inputs like "sm_" or "sm_x" reach std::stoi, which
  // throws std::invalid_argument instead of giving a readable ICHECK failure.
  ICHECK(arch_str.size() > 3 && arch_str[3] >= '0' && arch_str[3] <= '9')
      << "arch string must have a numeric version after sm_, got "
      << arch_str;
  return std::stoi(arch_str.substr(3));
}

/*!
 * \brief Whether the target is CUDA with arch in [70, 75).
 * \return false for non-CUDA targets.
 */
bool TargetIsVolta(Target target) {
  if (TargetIsCuda(target)) {
    const int sm = GetArchInt(target);
    return 70 <= sm && sm < 75;
  }
  return false;
}

/*!
 * \brief Whether the target is CUDA with arch in [75, 80).
 * \return false for non-CUDA targets.
 */
bool TargetIsTuring(Target target) {
  if (TargetIsCuda(target)) {
    const int sm = GetArchInt(target);
    return 75 <= sm && sm < 80;
  }
  return false;
}

/*!
 * \brief Whether the target is CUDA with arch in [80, 90).
 * \return false for non-CUDA targets.
 */
bool TargetIsAmpere(Target target) {
  if (TargetIsCuda(target)) {
    const int sm = GetArchInt(target);
    return 80 <= sm && sm < 90;
  }
  return false;
}

/*!
 * \brief Whether the target is CUDA with arch in [90, 100).
 * \return false for non-CUDA targets.
 */
bool TargetIsHopper(Target target) {
  if (TargetIsCuda(target)) {
    const int sm = GetArchInt(target);
    return 90 <= sm && sm < 100;
  }
  return false;
}

59
60
61
62
/*!
 * \brief Whether the target is CUDA with arch in [100, 110].
 *
 * Unlike the other arch predicates in this file, the upper bound is
 * inclusive, so "sm_110" also qualifies.
 * NOTE(review): confirm that including arch 110 is intentional.
 *
 * \return false for non-CUDA targets.
 */
bool TargetIsSm100(Target target) {
  if (!TargetIsCuda(target))
    return false;
  int arch = GetArchInt(target);
  // Logical && (previously bitwise &): same truth table for bools, but it
  // states the intent and short-circuits.
  return arch >= 100 && arch <= 110;
}

66
67
68
69
70
/*!
 * \brief Whether the target is CUDA with arch in [120, 130).
 * \return false for non-CUDA targets.
 */
bool TargetIsSM120(Target target) {
  if (TargetIsCuda(target)) {
    const int sm = GetArchInt(target);
    return 120 <= sm && sm < 130;
  }
  return false;
}

/*!
 * \brief Whether the target is ROCm and its "mcpu" attribute starts with
 *        "gfx9" (treated here as CDNA).
 * \return false for non-ROCm targets and for ROCm targets without "mcpu".
 */
bool TargetIsCDNA(Target target) {
  if (!TargetIsRocm(target) || !target->attrs.count("mcpu"))
    return false;
  const std::string mcpu =
      Downcast<tvm::ffi::String>(target->attrs.at("mcpu"));
  // Prefix test: true only when "gfx9" occurs at position 0.
  return mcpu.rfind("gfx9", 0) == 0;
}

Lukinon's avatar
Lukinon committed
84
85
86
87
88
89
90
91
92
93
/*!
 * \brief Whether the target is ROCm and its "mcpu" attribute starts with
 *        "gfx936" (treated here as a DCU device).
 * \return false for non-ROCm targets and for ROCm targets without "mcpu".
 */
bool TargetIsDCU(Target target) {
  if (!TargetIsRocm(target))
    return false;
  if (target->attrs.count("mcpu")) {
    // `mcpu` was previously used here without being declared, so this
    // function did not compile. Read it the same way as TargetIsCDNA does.
    std::string mcpu = Downcast<tvm::ffi::String>(target->attrs.at("mcpu"));
    // if mcpu starts with "gfx936", it is DCU
    return mcpu.rfind("gfx936", 0) == 0;
  }
  return false;
}

94
95
96
97
98
99
/*!
 * \brief Whether the target is reported as having async copy support.
 *
 * - CUDA: arch >= 80.
 * - CDNA (ROCm with "mcpu" starting with "gfx9"): the two characters after
 *   "gfx" are parsed as an integer, and the result must be >= 94
 *   (e.g. "gfx942" -> 94 -> true, "gfx90a" -> 90 -> false).
 * - Anything else: false.
 */
bool TargetHasAsyncCopy(Target target) {
  if (TargetIsCuda(target))
    return GetArchInt(target) >= 80;
  if (!TargetIsCDNA(target) || !target->attrs.count("mcpu"))
    return false;
  const std::string mcpu =
      Downcast<tvm::ffi::String>(target->attrs.at("mcpu"));
  if (mcpu.rfind("gfx9", 0) != 0)
    return false;
  const int gfx_version = std::stoi(mcpu.substr(3, 2));
  return gfx_version >= 94;
}
/*!
 * \brief Whether the target is CUDA with arch >= 75.
 * \return false for non-CUDA targets.
 */
bool TargetHasLdmatrix(Target target) {
  if (TargetIsCuda(target))
    return GetArchInt(target) >= 75;
  return false;
}

/*!
 * \brief Whether the target is CUDA with arch >= 90.
 * \return false for non-CUDA targets.
 */
bool TargetHasStmatrix(Target target) {
  if (TargetIsCuda(target))
    return GetArchInt(target) >= 90;
  return false;
}

127
128
129
130
131
132
/*!
 * \brief Whether the target is in the SM100 family, as decided by
 *        TargetIsSm100 (which itself returns false for non-CUDA targets).
 */
bool TargetHasTmem(Target target) { return TargetIsSm100(target); }

133
134
135
136
137
138
139
/*!
 * \brief Whether the target is CUDA with arch >= 90.
 * \return false for non-CUDA targets.
 */
bool TargetHasBulkCopy(Target target) {
  if (TargetIsCuda(target))
    return GetArchInt(target) >= 90;
  return false;
}

140
141
142
143
144
145
146
/*!
 * \brief Warp (wavefront) width used for the target.
 * \return 64 when TargetIsCDNA(target) holds, otherwise 32.
 */
int TargetGetWarpSize(Target target) {
  return TargetIsCDNA(target) ? 64 : 32;
}

147
TVM_FFI_STATIC_INIT_BLOCK() {
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef()
      .def("tl.TargetIsCuda",
           [](Target target) { return TargetIsCuda(target); })
      .def("tl.TargetIsRocm",
           [](Target target) { return TargetIsRocm(target); })
      .def("tl.TargetIsVolta",
           [](Target target) { return TargetIsVolta(target); })
      .def("tl.TargetIsTuring",
           [](Target target) { return TargetIsTuring(target); })
      .def("tl.TargetIsAmpere",
           [](Target target) { return TargetIsAmpere(target); })
      .def("tl.TargetIsHopper",
           [](Target target) { return TargetIsHopper(target); })
      .def("tl.TargetIsSM120",
           [](Target target) { return TargetIsSM120(target); })
      .def("tl.TargetIsCDNA",
           [](Target target) { return TargetIsCDNA(target); })
      .def("tl.TargetHasAsyncCopy",
           [](Target target) { return TargetHasAsyncCopy(target); })
      .def("tl.TargetHasLdmatrix",
           [](Target target) { return TargetHasLdmatrix(target); })
      .def("tl.TargetHasStmatrix",
           [](Target target) { return TargetHasStmatrix(target); })
      .def("tl.TargetHasBulkCopy",
           [](Target target) { return TargetHasBulkCopy(target); })
      .def("tl.TargetGetWarpSize",
           [](Target target) { return TargetGetWarpSize(target); });
176
}
177

178
179
} // namespace tl
} // namespace tvm