utils.h 1.4 KB
Newer Older
1
/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

7
8
9
#include <pybind11/pybind11.h>
#include <transformer_engine/fused_attn.h>

10
#include <cstdint>
11
#include <numeric>
12
13
14
#include <stdexcept>
#include <string>
#include <type_traits>
Tim Moon's avatar
Tim Moon committed
15
16

#include "common/util/logging.h"
17
18
19
20

namespace transformer_engine {
namespace jax {

21
// Declared here; the definition lives in another translation unit.
// NOTE(review): presumably returns the CUDA runtime version as an int —
// confirm the exact encoding against the definition.
int GetCudaRuntimeVersion();
22
// Declared here; the definition lives in another translation unit.
// NOTE(review): presumably returns the cuDNN library version loaded at
// runtime (as a size_t) — confirm against the definition.
size_t GetCudnnRuntimeVersion();
23
24
// Declared here; the definition lives in another translation unit.
// NOTE(review): presumably returns the compute capability of device `gpu_id`
// packed into a single int (e.g. major * 10 + minor) — confirm the encoding
// against the definition before relying on it.
int GetDeviceComputeCapability(int gpu_id);

25
26
/*
 * Lazily-populated, per-thread cache of the cudaDeviceProp for the CUDA device
 * that is current on the calling thread when the first getter is invoked.
 *
 * Instance() hands out a thread_local object, so every host thread owns its
 * own copy of the cache. Properties are fetched on the first call to any
 * getter and reused for every later call on that thread.
 *
 * NOTE(review): the cache is never invalidated. If a thread queries a getter
 * and later switches devices with cudaSetDevice, the getters keep reporting
 * the first device's properties. Confirm callers never switch devices on a
 * thread after querying.
 */
class cudaDevicePropertiesManager {
 public:
  // Returns the calling thread's manager instance (one per host thread).
  static cudaDevicePropertiesManager &Instance() {
    static thread_local cudaDevicePropertiesManager instance;
    return instance;
  }

  // Number of streaming multiprocessors on the cached device.
  int GetMultiProcessorCount() {
    EnsurePropertiesQueried();
    return prop_.multiProcessorCount;
  }

  // Major compute-capability number of the cached device.
  int GetMajor() {
    EnsurePropertiesQueried();
    return prop_.major;
  }

 private:
  // Queries the current device's properties once per instance.
  // Both CUDA calls go through NVTE_CHECK_CUDA. Previously the
  // cudaGetDeviceProperties result was ignored, so a failed query still marked
  // the cache valid and garbage was returned for the thread's lifetime.
  void EnsurePropertiesQueried() {
    if (prop_queried_) {
      return;
    }
    int device_id = 0;
    NVTE_CHECK_CUDA(cudaGetDevice(&device_id));
    NVTE_CHECK_CUDA(cudaGetDeviceProperties(&prop_, device_id));
    // Set only after both calls succeed. This assumes NVTE_CHECK_CUDA does
    // not return normally on error — TODO confirm.
    prop_queried_ = true;
  }

  bool prop_queried_ = false;
  cudaDeviceProp prop_{};  // value-initialized so it is never read indeterminate
};

}  // namespace jax
}  // namespace transformer_engine