/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_JAX_CSRC_UTILS_H_
#define TRANSFORMER_ENGINE_JAX_CSRC_UTILS_H_

10
11
12
#include <pybind11/pybind11.h>
#include <transformer_engine/fused_attn.h>

13
#include <cstdint>
14
#include <numeric>
15
16
17
#include <stdexcept>
#include <string>
#include <type_traits>
Tim Moon's avatar
Tim Moon committed
18
19

#include "common/util/logging.h"
20
21
22
23

namespace transformer_engine {
namespace jax {

24
25
26
int GetCudaRuntimeVersion();
int GetDeviceComputeCapability(int gpu_id);

27
void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q_max_seqlen,
28
29
                           size_t kv_max_seqlen, NVTE_Fused_Attn_Backend backend,
                           cudaStream_t stream);
30

31
32
uint32_t GetRuntimeNumSegments(void *cu_seqlen, void *workspace, size_t len, cudaStream_t stream);

33
34
class cudaDevicePropertiesManager {
 public:
35
36
37
38
39
40
41
42
43
44
45
  static cudaDevicePropertiesManager &Instance() {
    static thread_local cudaDevicePropertiesManager instance;
    return instance;
  }

  int GetMultiProcessorCount() {
    if (!prop_queried_) {
      int device_id;
      NVTE_CHECK_CUDA(cudaGetDevice(&device_id));
      cudaGetDeviceProperties(&prop_, device_id);
      prop_queried_ = true;
46
    }
47
48
49
50
51
52
53
54
55
    return prop_.multiProcessorCount;
  }

  int GetMajor() {
    if (!prop_queried_) {
      int device_id;
      NVTE_CHECK_CUDA(cudaGetDevice(&device_id));
      cudaGetDeviceProperties(&prop_, device_id);
      prop_queried_ = true;
56
    }
57
58
    return prop_.major;
  }
59

60
 private:
61
62
  bool prop_queried_ = false;
  cudaDeviceProp prop_;
63
64
};

65
66
// Per-thread running offset counter.
//
// Each thread has its own instance (the singleton is thread_local), starting
// at offset 0. The class is non-copyable and can only be obtained through
// Instance().
class FusedAttnOffsetManager {
 public:
  // Returns the calling thread's instance.
  static FusedAttnOffsetManager &Instance() {
    static thread_local FusedAttnOffsetManager instance;
    return instance;
  }

  // Returns the offset as it was before this call, then advances the stored
  // offset by `increment`. Unsigned arithmetic: the addition wraps on overflow.
  size_t GetAndUpdateOffset(size_t increment) {
    const size_t previous = offset_;
    offset_ = previous + increment;
    return previous;
  }

  FusedAttnOffsetManager(FusedAttnOffsetManager const &) = delete;
  void operator=(FusedAttnOffsetManager const &) = delete;

 private:
  // Construction is restricted to Instance().
  FusedAttnOffsetManager() {}

  size_t offset_ = 0;  // value handed out by the next GetAndUpdateOffset call
};

86
87
88
89
}  // namespace jax
}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_JAX_CSRC_UTILS_H_