Unverified Commit d097883e authored by Przemyslaw Tredak's avatar Przemyslaw Tredak Committed by GitHub
Browse files

Better way of checking cuDNN version (#485)



* Ability to check cuDNN version from Python
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* Modify the fused attention test to not use the CUDNN_VERSION env
variable which is specific to NGC containers
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

---------
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>
parent 136acacb
......@@ -44,7 +44,15 @@ from test_numerics import get_dummy_cuda_rng_tracker, reset_rng_states
fp8_available, reason_for_no_fp8 = fp8.FP8GlobalStateManager.is_fp8_available()
_flash_attn_version = packaging.version.Version(version("flash-attn"))
_flash_attn_2_available = _flash_attn_version >= packaging.version.Version("2")
_cudnn_version = [int(i) for i in os.environ['CUDNN_VERSION'].split('.')]
def _get_cudnn_version():
cudnn_version_encoded = ext.get_cudnn_version()
cudnn_major = cudnn_version_encoded // 1000
cudnn_minor = (cudnn_version_encoded - cudnn_major * 1000) // 100
cudnn_patch = cudnn_version_encoded - 1000 * cudnn_major - 100 * cudnn_minor
return [cudnn_major, cudnn_minor, cudnn_patch]
_cudnn_version = _get_cudnn_version()
class ModelConfig:
......
......@@ -31,6 +31,7 @@
#include <cuda_runtime.h>
#include <cuda_bf16.h>
#include <cublasLt.h>
#include <cudnn.h>
#include <stdexcept>
#include <memory>
#include <iomanip>
......
......@@ -524,6 +524,8 @@ at::Tensor scaled_upper_triang_masked_softmax_backward(at::Tensor output_grads_,
size_t get_cublasLt_version();
size_t get_cudnn_version();
bool userbuf_comm_available();
void placeholder();
......@@ -13,6 +13,10 @@ size_t get_cublasLt_version() {
return cublasLtGetVersion();
}
size_t get_cudnn_version() {
return cudnnGetVersion();
}
bool userbuf_comm_available() { // TODO(ksivamani) check on python side
#ifdef NVTE_WITH_USERBUFFERS
......
......@@ -77,6 +77,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// Misc
m.def("get_cublasLt_version", &get_cublasLt_version, "Get cublasLt version");
m.def("get_cudnn_version", &get_cudnn_version, "Get cuDNN version");
m.def("userbuf_comm_available", &userbuf_comm_available, "If userbuf backend is available");
// Data structures
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment