utils.cc 2.31 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

/*!
 * \file tl/target/utils.cc
 * \brief helper functions for target attributes.
 */

#include "utils.h"

namespace tvm {
namespace tl {

14
15
16
17
18
19
/*!
 * \brief Whether \p target lowers to a CUDA device.
 * \param target The compilation target to inspect.
 * \return True iff the target's device type is kDLCUDA.
 */
bool TargetIsCuda(Target target) {
  const auto device_type = target->GetTargetDeviceType();
  return device_type == kDLCUDA;
}
/*!
 * \brief Whether \p target lowers to a ROCm device.
 * \param target The compilation target to inspect.
 * \return True iff the target's device type is kDLROCM.
 */
bool TargetIsRocm(Target target) {
  const auto device_type = target->GetTargetDeviceType();
  return device_type == kDLROCM;
}
20
21
22
23

/*!
 * \brief Parse the integer compute capability from the target's "arch"
 *        attribute, which must have the form "sm_<N>..." (e.g. "sm_80" -> 80).
 *
 * Any trailing non-digit suffix after the number is ignored by atoi
 * (e.g. "sm_90a" -> 90).
 *
 * \param target The target whose "arch" attribute is read. The attribute
 *        must be present; a missing or malformed value aborts via ICHECK.
 * \return The integer that follows the "sm_" prefix.
 */
int GetArchInt(Target target) {
  auto s = target->GetAttr<String>("arch");
  ICHECK(s.defined()) << "Target has no \"arch\" attribute";
  // Keep the String alive in a named local so the c_str() pointer below
  // cannot outlive its owner.
  const String arch = s.value();
  const char *arch_str = arch.c_str();
  // Short-circuit evaluation guarantees we never read past the terminating
  // '\0' of a shorter string.
  const bool has_sm_prefix =
      arch_str[0] == 's' && arch_str[1] == 'm' && arch_str[2] == '_';
  ICHECK(has_sm_prefix) << "Expected target arch of the form \"sm_<N>\", got \""
                        << arch_str << "\"";
  // Without this check atoi silently returns 0 for e.g. "sm_" or "sm_x",
  // which would make every arch predicate quietly evaluate to false.
  ICHECK(arch_str[3] >= '0' && arch_str[3] <= '9')
      << "Expected a numeric compute capability after \"sm_\", got \""
      << arch_str << "\"";
  return atoi(&arch_str[3]);
}

/*!
 * \brief Whether \p target is a CUDA target with arch in [70, 75).
 * \param target The compilation target to inspect.
 * \return False for non-CUDA targets; otherwise whether 70 <= arch < 75.
 */
bool TargetIsVolta(Target target) {
  if (TargetIsCuda(target)) {
    const int sm = GetArchInt(target);
    return 70 <= sm && sm < 75;
  }
  return false;
}

/*!
 * \brief Whether \p target is a CUDA target with arch in [75, 80).
 * \param target The compilation target to inspect.
 * \return False for non-CUDA targets; otherwise whether 75 <= arch < 80.
 */
bool TargetIsTuring(Target target) {
  if (TargetIsCuda(target)) {
    const int sm = GetArchInt(target);
    return 75 <= sm && sm < 80;
  }
  return false;
}

/*!
 * \brief Whether \p target is a CUDA target with arch in [80, 90).
 * \param target The compilation target to inspect.
 * \return False for non-CUDA targets; otherwise whether 80 <= arch < 90.
 */
bool TargetIsAmpere(Target target) {
  if (TargetIsCuda(target)) {
    const int sm = GetArchInt(target);
    return 80 <= sm && sm < 90;
  }
  return false;
}

/*!
 * \brief Whether \p target is a CUDA target with arch >= 90.
 * \param target The compilation target to inspect.
 * \return False for non-CUDA targets; otherwise whether arch >= 90.
 */
bool TargetIsHopper(Target target) {
  if (TargetIsCuda(target)) {
    // NOTE(review): the range is open-ended, so any newer arch (e.g. sm_100)
    // is also reported as Hopper. Confirm this is intended before relying on
    // it for Hopper-only features.
    const int sm = GetArchInt(target);
    return sm >= 90;
  }
  return false;
}

/*!
 * \brief Whether \p target is a ROCm target whose "mcpu" attribute starts
 *        with "gfx9", which this codebase treats as CDNA.
 * \param target The compilation target to inspect.
 * \return False for non-ROCm targets or when "mcpu" is absent; otherwise
 *         whether the mcpu string has the "gfx9" prefix.
 */
bool TargetIsCDNA(Target target) {
  if (!TargetIsRocm(target) || !target->attrs.count("mcpu")) {
    return false;
  }
  std::string mcpu = Downcast<String>(target->attrs.at("mcpu"));
  // Anchored prefix test: rfind starting at position 0 can only match there.
  // NOTE(review): the gfx9 family also contains non-CDNA (Vega) parts;
  // confirm the prefix test is precise enough for callers.
  return mcpu.rfind("gfx9", 0) == 0;
}

/*!
 * \brief Whether \p target is considered to support asynchronous copies.
 *
 * CUDA targets qualify when arch >= 80. CDNA targets qualify when the
 * two-digit gfx version parsed from "mcpu" (e.g. "gfx942" -> 94) is >= 94.
 * All other targets return false.
 *
 * \param target The compilation target to inspect.
 * \return True if the target meets the thresholds above.
 */
bool TargetHasAsyncCopy(Target target) {
  if (TargetIsCuda(target)) {
    return GetArchInt(target) >= 80;
  }
  if (!TargetIsCDNA(target) || !target->attrs.count("mcpu")) {
    return false;
  }
  std::string mcpu = Downcast<String>(target->attrs.at("mcpu"));
  if (mcpu.rfind("gfx9", 0) != 0) {
    return false;
  }
  // The two characters at index 3 are '9' followed by the next digit, so
  // "gfx942" yields "94" and "gfx90a" yields "90".
  const int gfx_version = std::stoi(mcpu.substr(3, 2));
  return gfx_version >= 94;
}
/*!
 * \brief Whether \p target is a CUDA target with arch >= 75, the threshold
 *        used here for ldmatrix availability.
 * \param target The compilation target to inspect.
 * \return False for non-CUDA targets; otherwise whether arch >= 75.
 */
bool TargetHasLdmatrix(Target target) {
  if (TargetIsCuda(target)) {
    const int sm = GetArchInt(target);
    return sm >= 75;
  }
  return false;
}

/*!
 * \brief Whether \p target is a CUDA target with arch >= 90, the threshold
 *        used here for stmatrix availability.
 * \param target The compilation target to inspect.
 * \return False for non-CUDA targets; otherwise whether arch >= 90.
 */
bool TargetHasStmatrix(Target target) {
  if (TargetIsCuda(target)) {
    const int sm = GetArchInt(target);
    return sm >= 90;
  }
  return false;
}

103
104
} // namespace tl
} // namespace tvm