/*!
 * \file tl/target/utils.cc
 * \brief Helper functions for querying target attributes.
 */

#include "utils.h"

namespace tvm {
namespace tl {

11
12
13
14
15
16
bool TargetIsCuda(Target target) {
  return target->GetTargetDeviceType() == kDLCUDA;
}
/*!
 * \brief Check whether the target's device type is ROCm.
 * \param target The compilation target to inspect.
 * \return true iff target->GetTargetDeviceType() == kDLROCM.
 */
bool TargetIsRocm(Target target) {
  return target->GetTargetDeviceType() == kDLROCM;
}

/*!
 * \brief Parse the numeric architecture from the target's "arch" attribute.
 *
 * The attribute is expected to look like "sm_<digits>[suffix]", e.g. "sm_80"
 * or "sm_90a". Only the run of digits right after "sm_" is parsed, so those
 * two examples yield 80 and 90 respectively.
 *
 * \param target Target expected to carry an "arch" string attribute.
 * \return The integer that follows the "sm_" prefix.
 * \note Fails an ICHECK when the attribute is missing, does not start with
 *       "sm_", or has no digits after the prefix. A malformed value is
 *       rejected rather than silently parsed as 0.
 */
int GetArchInt(Target target) {
  auto s = target->GetAttr<String>("arch");
  ICHECK(s.defined()) << "Target has no \"arch\" attribute";
  const std::string arch = s.value();
  ICHECK(arch.rfind("sm_", 0) == 0)
      << "Expected an arch of the form \"sm_<N>\", but got \"" << arch << "\"";

  // Collect the decimal digits right after "sm_". Any trailing suffix (such as
  // the "a" in "sm_90a") is ignored.
  size_t end = 3;
  while (end < arch.size() && arch[end] >= '0' && arch[end] <= '9') {
    ++end;
  }
  ICHECK(end > 3) << "No architecture number after \"sm_\" in \"" << arch
                  << "\"";
  // std::stoi reports out-of-range values instead of invoking UB like atoi.
  return std::stoi(arch.substr(3, end - 3));
}

/*!
 * \brief Check whether the target is a CUDA target with arch in [70, 75).
 * \param target The compilation target to inspect.
 * \return false for non-CUDA targets; otherwise 70 <= GetArchInt(target) < 75.
 * \note For CUDA targets this calls GetArchInt, which ICHECK-fails when the
 *       "arch" attribute is missing or malformed.
 */
bool TargetIsVolta(Target target) {
  if (!TargetIsCuda(target))
    return false;
  const int arch = GetArchInt(target);
  return arch >= 70 && arch < 75;
}

/*!
 * \brief Check whether the target is a CUDA target with arch in [75, 80).
 * \param target The compilation target to inspect.
 * \return false for non-CUDA targets; otherwise 75 <= GetArchInt(target) < 80.
 * \note For CUDA targets this calls GetArchInt, which ICHECK-fails when the
 *       "arch" attribute is missing or malformed.
 */
bool TargetIsTuring(Target target) {
  if (!TargetIsCuda(target))
    return false;
  const int arch = GetArchInt(target);
  return arch >= 75 && arch < 80;
}

/*!
 * \brief Check whether the target is a CUDA target with arch in [80, 90).
 * \param target The compilation target to inspect.
 * \return false for non-CUDA targets; otherwise 80 <= GetArchInt(target) < 90.
 * \note For CUDA targets this calls GetArchInt, which ICHECK-fails when the
 *       "arch" attribute is missing or malformed.
 */
bool TargetIsAmpere(Target target) {
  if (!TargetIsCuda(target))
    return false;
  const int arch = GetArchInt(target);
  return arch >= 80 && arch < 90;
}

/*!
 * \brief Check whether the target is a CUDA target with arch >= 90.
 * \param target The compilation target to inspect.
 * \return false for non-CUDA targets; otherwise GetArchInt(target) >= 90.
 * \note There is no upper bound, so any CUDA arch of 90 or above (including
 *       later generations) is reported as true. For CUDA targets this calls
 *       GetArchInt, which ICHECK-fails on a missing/malformed "arch".
 */
bool TargetIsHopper(Target target) {
  if (!TargetIsCuda(target))
    return false;
  const int arch = GetArchInt(target);
  return arch >= 90;
}

/*!
 * \brief Check whether the target is a ROCm target whose "mcpu" starts with
 *        "gfx9" (treated here as AMD CDNA).
 * \param target The compilation target to inspect.
 * \return true iff the target is ROCm and its "mcpu" attribute begins with
 *         "gfx9"; false otherwise, including when "mcpu" is not set.
 */
bool TargetIsCDNA(Target target) {
  if (!TargetIsRocm(target))
    return false;
  if (target->attrs.count("mcpu")) {
    std::string mcpu = Downcast<String>(target->attrs.at("mcpu"));
    // Anchored prefix test: rfind(..., 0) only checks position 0, whereas
    // find() would scan the whole string before comparing the index.
    return mcpu.rfind("gfx9", 0) == 0;
  }
  return false;
}

/*!
 * \brief Check whether the target is considered to support async copy.
 *
 * - CUDA targets: true when GetArchInt(target) >= 80.
 * - CDNA targets: true when "mcpu" starts with "gfx9" and the two characters
 *   following "gfx" parse to a number >= 94 (e.g. "gfx942" -> 94).
 * - Everything else: false.
 *
 * \param target The compilation target to inspect.
 */
bool TargetHasAsyncCopy(Target target) {
  if (TargetIsCuda(target))
    return GetArchInt(target) >= 80;

  // Non-CUDA: only CDNA targets with an "mcpu" attribute can qualify.
  if (!TargetIsCDNA(target) || !target->attrs.count("mcpu"))
    return false;

  std::string mcpu = Downcast<String>(target->attrs.at("mcpu"));
  if (mcpu.rfind("gfx9", 0) != 0)
    return false;

  // Two characters starting right after "gfx", e.g. "gfx90a" -> "90".
  const int gfx_version = std::stoi(mcpu.substr(3, 2));
  return gfx_version >= 94;
}
/*!
 * \brief Check whether the target is a CUDA target with arch >= 75, the
 *        threshold this project uses for ldmatrix availability.
 * \param target The compilation target to inspect.
 * \return false for non-CUDA targets; otherwise GetArchInt(target) >= 75.
 */
bool TargetHasLdmatrix(Target target) {
  if (!TargetIsCuda(target))
    return false;
  const int arch = GetArchInt(target);
  return arch >= 75;
}

/*!
 * \brief Check whether the target is a CUDA target with arch >= 90, the
 *        threshold this project uses for stmatrix availability.
 * \param target The compilation target to inspect.
 * \return false for non-CUDA targets; otherwise GetArchInt(target) >= 90.
 */
bool TargetHasStmatrix(Target target) {
  if (!TargetIsCuda(target))
    return false;
  const int arch = GetArchInt(target);
  return arch >= 90;
}

} // namespace tl
} // namespace tvm