Unverified Commit 557f0cb5 authored by Przemyslaw Tredak's avatar Przemyslaw Tredak Committed by GitHub
Browse files

Use versioned flavor of get driver entrypoint function (#1835)



* Use versioned flavor of get driver entrypoint function
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* Update the check to call the versioned API starting with CUDA 12.5 where
it was added
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* Dynamically find entrypoint functions
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Error checking
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* Lint fix
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

---------
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent f64d1459
...@@ -13,10 +13,32 @@ namespace transformer_engine { ...@@ -13,10 +13,32 @@ namespace transformer_engine {
namespace cuda_driver { namespace cuda_driver {
void *get_symbol(const char *symbol) { typedef cudaError_t (*VersionedGetEntryPoint)(const char *, void **, unsigned int,
void *entry_point; unsigned long long, // NOLINT(*)
cudaDriverEntryPointQueryResult *);
typedef cudaError_t (*GetEntryPoint)(const char *, void **, unsigned long long, // NOLINT(*)
cudaDriverEntryPointQueryResult *);
void *get_symbol(const char *symbol, int cuda_version) {
constexpr char driver_entrypoint[] = "cudaGetDriverEntryPoint";
constexpr char driver_entrypoint_versioned[] = "cudaGetDriverEntryPointByVersion";
// We link to the libcudart.so already, so can search for it in the current context
static GetEntryPoint driver_entrypoint_fun =
reinterpret_cast<GetEntryPoint>(dlsym(RTLD_DEFAULT, driver_entrypoint));
static VersionedGetEntryPoint driver_entrypoint_versioned_fun =
reinterpret_cast<VersionedGetEntryPoint>(dlsym(RTLD_DEFAULT, driver_entrypoint_versioned));
cudaDriverEntryPointQueryResult driver_result; cudaDriverEntryPointQueryResult driver_result;
NVTE_CHECK_CUDA(cudaGetDriverEntryPoint(symbol, &entry_point, cudaEnableDefault, &driver_result)); void *entry_point = nullptr;
if (driver_entrypoint_versioned_fun != nullptr) {
// Found versioned entrypoint function
NVTE_CHECK_CUDA(driver_entrypoint_versioned_fun(symbol, &entry_point, cuda_version,
cudaEnableDefault, &driver_result));
} else {
NVTE_CHECK(driver_entrypoint_fun != nullptr, "Error finding the CUDA Runtime-Driver interop.");
// Versioned entrypoint function not found
NVTE_CHECK_CUDA(driver_entrypoint_fun(symbol, &entry_point, cudaEnableDefault, &driver_result));
}
NVTE_CHECK(driver_result == cudaDriverEntryPointSuccess, NVTE_CHECK(driver_result == cudaDriverEntryPointSuccess,
"Could not find CUDA driver entry point for ", symbol); "Could not find CUDA driver entry point for ", symbol);
return entry_point; return entry_point;
......
...@@ -19,7 +19,7 @@ namespace transformer_engine { ...@@ -19,7 +19,7 @@ namespace transformer_engine {
namespace cuda_driver { namespace cuda_driver {
/*! \brief Get pointer corresponding to symbol in CUDA driver library */ /*! \brief Get pointer corresponding to symbol in CUDA driver library */
void *get_symbol(const char *symbol); void *get_symbol(const char *symbol, int cuda_version = 12010);
/*! \brief Call function in CUDA driver library /*! \brief Call function in CUDA driver library
* *
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment