#include <atomic>
#include <chrono>
#include <cstdio>
#include <cstdlib>
#include <dlfcn.h>
#include <unistd.h>

// Probe that defines cudaGetDevice/cudaSetDevice and hipGetDevice/hipSetDevice,
// forwards each call to the real runtime implementation, and accumulates a
// per-function call count and total elapsed time. A one-line-per-function
// summary is written when the process exits (see print_summary).
namespace {

// Signatures of the real runtime entry points. The return value (the
// runtime's status code) is passed through as a plain int.
using cuda_get_device_fn = int (*)(int *);
using cuda_set_device_fn = int (*)(int);
using hip_get_device_fn = int (*)(int *);
using hip_set_device_fn = int (*)(int);

// Per-function tallies. unsigned long long matches the %llu conversions used
// in print_one.
struct Stats {
  std::atomic<unsigned long long> calls{0};
  std::atomic<unsigned long long> ns{0};
};

Stats cuda_get_device_stats;
Stats cuda_set_device_stats;
Stats hip_get_device_stats;
Stats hip_set_device_stats;

// Returns the address of `name`, looked up first with dlsym(RTLD_NEXT, ...)
// and then in each library of the nullptr-terminated `libs` list (opened with
// RTLD_LAZY | RTLD_LOCAL). If every lookup fails, prints a
// "probe_missing_symbol,<name>,<error>" line to stderr and aborts, because the
// wrapper would have nothing to forward to.
void *resolve_symbol(const char *name, const char *const *libs) {
  if (void *symbol = dlsym(RTLD_NEXT, name)) {
    return symbol;
  }
  for (const char *const *lib = libs; *lib; ++lib) {
    void *handle = dlopen(*lib, RTLD_LAZY | RTLD_LOCAL);
    if (!handle) {
      continue;
    }
    // The handle is never dlclose()d: on success the returned pointer lives
    // in this library; on failure the (single-entry) lists used below fall
    // through to the abort.
    if (void *symbol = dlsym(handle, name)) {
      return symbol;
    }
  }
  // dlerror() can return nullptr (e.g. if dlsym returned NULL without
  // recording an error); passing nullptr to %s is undefined behaviour.
  const char *error = dlerror();
  std::fprintf(stderr, "probe_missing_symbol,%s,%s\n", name,
               error ? error : "unknown error");
  std::abort();
}

// Resolves a CUDA runtime function, falling back to dlopen("libcudart.so").
template <typename Fn>
Fn cuda_symbol(const char *name) {
  static const char *const libs[] = {
      "libcudart.so",
      nullptr,
  };
  return reinterpret_cast<Fn>(resolve_symbol(name, libs));
}

// Resolves a HIP runtime function, falling back to dlopen("libamdhip64.so").
template <typename Fn>
Fn hip_symbol(const char *name) {
  static const char *const libs[] = {
      "libamdhip64.so",
      nullptr,
  };
  return reinterpret_cast<Fn>(resolve_symbol(name, libs));
}

// Runs `call(fn)`, records one call and its steady_clock duration (in
// nanoseconds) in `stats`, and returns what `call` returned.
template <typename Fn, typename Call>
int measure(Stats &stats, Fn fn, Call call) {
  const auto start = std::chrono::steady_clock::now();
  const int result = call(fn);
  const auto stop = std::chrono::steady_clock::now();
  const auto ns =
      std::chrono::duration_cast<std::chrono::nanoseconds>(stop - start)
          .count();
  // Relaxed ordering: the two counters are independent tallies that are only
  // read back when the summary is printed.
  stats.calls.fetch_add(1, std::memory_order_relaxed);
  stats.ns.fetch_add(static_cast<unsigned long long>(ns),
                     std::memory_order_relaxed);
  return result;
}

// Writes "pid,name,calls,total_ns,avg_ns" to `out`. Writes nothing when the
// function was never called.
void print_one(FILE *out, const char *name, const Stats &stats) {
  const unsigned long long calls = stats.calls.load(std::memory_order_relaxed);
  if (calls == 0) {
    return;
  }
  const unsigned long long ns = stats.ns.load(std::memory_order_relaxed);
  const double avg_ns = static_cast<double>(ns) / static_cast<double>(calls);
  std::fprintf(out, "%d,%s,%llu,%llu,%.3f\n", static_cast<int>(getpid()), name,
               calls, ns, avg_ns);
}

// atexit handler. Appends the summary to the file named by the
// FASTPT_MRE_PROBE_LOG environment variable, or writes it to stderr when the
// variable is unset or the file cannot be opened.
void print_summary() {
  const char *path = std::getenv("FASTPT_MRE_PROBE_LOG");
  FILE *out = path ? std::fopen(path, "a") : stderr;
  if (!out) {
    out = stderr;
  }
  print_one(out, "cudaGetDevice", cuda_get_device_stats);
  print_one(out, "cudaSetDevice", cuda_set_device_stats);
  print_one(out, "hipGetDevice", hip_get_device_stats);
  print_one(out, "hipSetDevice", hip_set_device_stats);
  if (out != stderr) {
    std::fclose(out);
  }
}

// Registers print_summary during static initialization of this translation
// unit. It is defined after the Stats objects it reports on.
struct AtExit {
  AtExit() { std::atexit(print_summary); }
} at_exit;

}  // namespace

// Each wrapper resolves the real implementation once, on its first call
// (function-local static), then forwards the call through measure().

extern "C" int cudaGetDevice(int *device) {
  static auto real = cuda_symbol<cuda_get_device_fn>("cudaGetDevice");
  return measure(cuda_get_device_stats, real,
                 [device](auto fn) { return fn(device); });
}

extern "C" int cudaSetDevice(int device) {
  static auto real = cuda_symbol<cuda_set_device_fn>("cudaSetDevice");
  return measure(cuda_set_device_stats, real,
                 [device](auto fn) { return fn(device); });
}

extern "C" int hipGetDevice(int *device) {
  static auto real = hip_symbol<hip_get_device_fn>("hipGetDevice");
  return measure(hip_get_device_stats, real,
                 [device](auto fn) { return fn(device); });
}

extern "C" int hipSetDevice(int device) {
  static auto real = hip_symbol<hip_set_device_fn>("hipSetDevice");
  return measure(hip_set_device_stats, real,
                 [device](auto fn) { return fn(device); });
}