/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_JAX_CSRC_UTILS_H_
#define TRANSFORMER_ENGINE_JAX_CSRC_UTILS_H_

10
11
12
#include <pybind11/pybind11.h>
#include <transformer_engine/fused_attn.h>

13
#include <cstdint>
14
#include <numeric>
15
16
17
#include <stdexcept>
#include <string>
#include <type_traits>
Tim Moon's avatar
Tim Moon committed
18
19

#include "common/util/logging.h"
20
21
22
23

namespace transformer_engine {
namespace jax {

24
// Declaration only; the definition is not part of this header.
// NOTE(review): presumably reports the CUDA runtime version as an int -- confirm the
// exact encoding against the definition.
int GetCudaRuntimeVersion();
25
// Declaration only; the definition is not part of this header.
// NOTE(review): presumably reports the cuDNN runtime version as a size_t -- confirm the
// exact encoding against the definition.
size_t GetCudnnRuntimeVersion();
26
27
// Declaration only; the definition is not part of this header.
// NOTE(review): presumably reports the compute capability of the device `gpu_id` as a
// single int -- confirm how major/minor are combined against the definition.
int GetDeviceComputeCapability(int gpu_id);

28
void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q_max_seqlen,
29
30
                           size_t kv_max_seqlen, NVTE_Fused_Attn_Backend backend,
                           cudaStream_t stream);
31

32
33
// Declaration only; the definition is not part of this header.
// NOTE(review): presumably derives a segment count from `cu_seqlen` (`len` entries) using
// `workspace` as scratch on `stream` -- confirm buffer sizes, memory spaces and whether the
// call synchronizes before returning against the definition.
uint32_t GetRuntimeNumSegments(void *cu_seqlen, void *workspace, size_t len, cudaStream_t stream);

34
35
class cudaDevicePropertiesManager {
 public:
36
37
38
39
40
41
42
43
44
45
46
  static cudaDevicePropertiesManager &Instance() {
    static thread_local cudaDevicePropertiesManager instance;
    return instance;
  }

  int GetMultiProcessorCount() {
    if (!prop_queried_) {
      int device_id;
      NVTE_CHECK_CUDA(cudaGetDevice(&device_id));
      cudaGetDeviceProperties(&prop_, device_id);
      prop_queried_ = true;
47
    }
48
49
50
51
52
53
54
55
56
    return prop_.multiProcessorCount;
  }

  int GetMajor() {
    if (!prop_queried_) {
      int device_id;
      NVTE_CHECK_CUDA(cudaGetDevice(&device_id));
      cudaGetDeviceProperties(&prop_, device_id);
      prop_queried_ = true;
57
    }
58
59
    return prop_.major;
  }
60

61
 private:
62
63
  bool prop_queried_ = false;
  cudaDeviceProp prop_;
64
65
};

66
67
// Per-thread running offset counter.
//
// Each thread obtains its own manager through Instance() (the instance is thread_local).
// The counter starts at 0 and only ever changes through GetAndUpdateOffset().
class FusedAttnOffsetManager {
 public:
  // Copying is disabled; the only instances are the ones handed out by Instance().
  FusedAttnOffsetManager(FusedAttnOffsetManager const &) = delete;
  void operator=(FusedAttnOffsetManager const &) = delete;

  // Returns the manager owned by the calling thread, creating it on first use.
  static FusedAttnOffsetManager &Instance() {
    static thread_local FusedAttnOffsetManager per_thread_manager;
    return per_thread_manager;
  }

  // Advances the stored offset by `increment` and returns the value the offset held
  // before the advance.
  size_t GetAndUpdateOffset(size_t increment) {
    const size_t previous = offset_;
    offset_ = previous + increment;
    return previous;
  }

 private:
  // Private so that instances can only be created from inside the class.
  FusedAttnOffsetManager() = default;

  size_t offset_ = 0;
};

87
88
89
90
}  // namespace jax
}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_JAX_CSRC_UTILS_H_