/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_JAX_CSRC_UTILS_H_
#define TRANSFORMER_ENGINE_JAX_CSRC_UTILS_H_

10
11
12
#include <pybind11/pybind11.h>
#include <transformer_engine/fused_attn.h>

13
#include <cstdint>
14
#include <numeric>
15
16
17
#include <stdexcept>
#include <string>
#include <type_traits>
Tim Moon's avatar
Tim Moon committed
18
19

#include "common/util/logging.h"
20
21
22
23

namespace transformer_engine {
namespace jax {

24
25
26
int GetCudaRuntimeVersion();
int GetDeviceComputeCapability(int gpu_id);

27
void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q_max_seqlen,
28
29
                           size_t kv_max_seqlen, NVTE_Fused_Attn_Backend backend,
                           cudaStream_t stream);
30

31
32
class cudaDevicePropertiesManager {
 public:
33
34
35
36
37
38
39
40
41
42
43
  static cudaDevicePropertiesManager &Instance() {
    static thread_local cudaDevicePropertiesManager instance;
    return instance;
  }

  int GetMultiProcessorCount() {
    if (!prop_queried_) {
      int device_id;
      NVTE_CHECK_CUDA(cudaGetDevice(&device_id));
      cudaGetDeviceProperties(&prop_, device_id);
      prop_queried_ = true;
44
    }
45
46
47
48
49
50
51
52
53
    return prop_.multiProcessorCount;
  }

  int GetMajor() {
    if (!prop_queried_) {
      int device_id;
      NVTE_CHECK_CUDA(cudaGetDevice(&device_id));
      cudaGetDeviceProperties(&prop_, device_id);
      prop_queried_ = true;
54
    }
55
56
    return prop_.major;
  }
57

58
 private:
59
60
  bool prop_queried_ = false;
  cudaDeviceProp prop_;
61
62
};

63
64
/*! \brief Per-thread running offset counter, named for fused attention.
 *
 *  Each call to GetAndUpdateOffset() returns the current offset and then
 *  advances it by the requested increment. The counter starts at zero. The
 *  instance is `thread_local`, so every thread has an independent counter.
 *  NOTE(review): the name suggests this feeds the fused-attention RNG offset;
 *  confirm with callers.
 *  The class is non-copyable and can only be obtained through Instance().
 */
class FusedAttnOffsetManager {
 public:
  FusedAttnOffsetManager(FusedAttnOffsetManager const &) = delete;
  void operator=(FusedAttnOffsetManager const &) = delete;

  /*! \brief Return the calling thread's instance. */
  static FusedAttnOffsetManager &Instance() {
    static thread_local FusedAttnOffsetManager singleton;
    return singleton;
  }

  /*! \brief Return the offset before this call, then add `increment` to it. */
  size_t GetAndUpdateOffset(size_t increment) {
    const size_t previous = offset_;
    offset_ = previous + increment;
    return previous;
  }

 private:
  FusedAttnOffsetManager() = default;

  size_t offset_{0};
};

84
85
86
87
}  // namespace jax
}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_JAX_CSRC_UTILS_H_