runtime_probe.cpp 3.64 KB
Newer Older
one's avatar
one committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
#include <atomic>
#include <chrono>
#include <cstdio>
#include <cstdlib>
#include <dlfcn.h>
#include <unistd.h>

namespace {

// Signatures of the interposed runtime entry points; each returns the
// runtime's integer status code unchanged.
using cuda_get_device_fn = int (*)(int *);
using cuda_set_device_fn = int (*)(int);
using hip_get_device_fn = int (*)(int *);
using hip_set_device_fn = int (*)(int);

// Per-entry-point counters: number of forwarded calls and their summed
// duration in nanoseconds. Updated with relaxed atomics, so any thread may
// record into them concurrently.
struct Stats {
  std::atomic<unsigned long long> calls{0};
  std::atomic<unsigned long long> ns{0};
};

Stats cuda_get_device_stats;
Stats cuda_set_device_stats;
Stats hip_get_device_stats;
Stats hip_set_device_stats;

// Returns the address of `name`, aborting the process if it cannot be found.
//
// Lookup order:
//   1. dlsym(RTLD_NEXT, name): the next definition after this object in the
//      symbol search order (presumably the real runtime when this probe is
//      preloaded ahead of it).
//   2. Each entry of the nullptr-terminated `libs` list, opened with
//      RTLD_LAZY | RTLD_LOCAL. A library that provides the symbol stays open
//      (the returned pointer lives inside it); one that does not is closed.
//
// On failure prints "probe_missing_symbol,<name>,<reason>" to stderr, where
// <reason> is the most recent dynamic-linker error message (or a fixed
// fallback when none was reported), then calls std::abort().
void *resolve_symbol(const char *name, const char *const *libs) {
  char reason[256] = "symbol not found";
  // dlerror() may return NULL, and its string is only valid until the next
  // dl* call (the dlclose() below included), so keep a private copy.
  const auto remember_error = [&reason] {
    if (const char *err = dlerror()) {
      std::snprintf(reason, sizeof reason, "%s", err);
    }
  };

  if (void *symbol = dlsym(RTLD_NEXT, name)) {
    return symbol;
  }
  remember_error();

  for (const char *const *lib = libs; *lib; ++lib) {
    void *handle = dlopen(*lib, RTLD_LAZY | RTLD_LOCAL);
    if (!handle) {
      remember_error();
      continue;
    }
    if (void *symbol = dlsym(handle, name)) {
      return symbol;
    }
    remember_error();
    // Drop the reference taken by dlopen(); nothing from this library is used.
    dlclose(handle);
  }

  std::fprintf(stderr, "probe_missing_symbol,%s,%s\n", name, reason);
  std::abort();
}

// Resolves a CUDA runtime entry point (RTLD_NEXT first, then libcudart.so).
template <typename Fn> Fn cuda_symbol(const char *name) {
  static const char *const libs[] = {
      "libcudart.so",
      nullptr,
  };
  return reinterpret_cast<Fn>(resolve_symbol(name, libs));
}

// Resolves a HIP runtime entry point (RTLD_NEXT first, then libamdhip64.so).
template <typename Fn> Fn hip_symbol(const char *name) {
  static const char *const libs[] = {
      "libamdhip64.so",
      nullptr,
  };
  return reinterpret_cast<Fn>(resolve_symbol(name, libs));
}

// Runs call(fn), timing it with steady_clock, then adds one call and the
// elapsed nanoseconds to `stats`. Returns call(fn)'s result unchanged.
template <typename Fn, typename Call>
int measure(Stats &stats, Fn fn, Call call) {
  const auto start = std::chrono::steady_clock::now();
  const int result = call(fn);
  const auto stop = std::chrono::steady_clock::now();
  // steady_clock is monotonic, so the difference is never negative.
  const auto ns =
      std::chrono::duration_cast<std::chrono::nanoseconds>(stop - start)
          .count();
  stats.calls.fetch_add(1, std::memory_order_relaxed);
  stats.ns.fetch_add(static_cast<unsigned long long>(ns),
                     std::memory_order_relaxed);
  return result;
}

// Writes one CSV row "pid,name,calls,total_ns,avg_ns" to `out`. Entry points
// that were never called produce no output.
void print_one(FILE *out, const char *name, const Stats &stats) {
  const auto calls = stats.calls.load(std::memory_order_relaxed);
  if (calls == 0) {
    return;
  }
  const auto ns = stats.ns.load(std::memory_order_relaxed);
  const double avg_ns = static_cast<double>(ns) / calls;
  std::fprintf(out, "%d,%s,%llu,%llu,%.3f\n", static_cast<int>(getpid()), name,
               calls, ns, avg_ns);
}

// Exit handler: appends one row per intercepted entry point to the file named
// by FASTPT_MRE_PROBE_LOG, or to stderr when the variable is unset or the
// file cannot be opened.
void print_summary() {
  const char *path = std::getenv("FASTPT_MRE_PROBE_LOG");
  FILE *out = path ? std::fopen(path, "a") : stderr;
  if (!out) {
    out = stderr;
  }

  print_one(out, "cudaGetDevice", cuda_get_device_stats);
  print_one(out, "cudaSetDevice", cuda_set_device_stats);
  print_one(out, "hipGetDevice", hip_get_device_stats);
  print_one(out, "hipSetDevice", hip_set_device_stats);

  if (out != stderr) {
    std::fclose(out);
  }
}

// Registers print_summary with std::atexit while this object is being
// initialized, so the summary is emitted when the process exits normally.
struct AtExit {
  AtExit() { std::atexit(print_summary); }
} at_exit;

} // namespace

extern "C" int cudaGetDevice(int *device) {
  static auto real = cuda_symbol<cuda_get_device_fn>("cudaGetDevice");
  return measure(cuda_get_device_stats, real,
                 [device](auto fn) { return fn(device); });
}

extern "C" int cudaSetDevice(int device) {
  static auto real = cuda_symbol<cuda_set_device_fn>("cudaSetDevice");
  return measure(cuda_set_device_stats, real,
                 [device](auto fn) { return fn(device); });
}

extern "C" int hipGetDevice(int *device) {
  static auto real = hip_symbol<hip_get_device_fn>("hipGetDevice");
  return measure(hip_get_device_stats, real,
                 [device](auto fn) { return fn(device); });
}

extern "C" int hipSetDevice(int device) {
  static auto real = hip_symbol<hip_set_device_fn>("hipSetDevice");
  return measure(hip_set_device_stats, real,
                 [device](auto fn) { return fn(device); });
}