"src/git@developer.sourcefind.cn:jerrrrry/infinicore.git" did not exist on "7a8339879aacf15138361f77e9f7d7caca2078ab"
utils.h 2.33 KB
Newer Older
1
/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_JAX_CSRC_UTILS_H_
#define TRANSFORMER_ENGINE_JAX_CSRC_UTILS_H_

#include <cstdint>
11
#include <numeric>
12
13
14
#include <stdexcept>
#include <string>
#include <type_traits>
Tim Moon's avatar
Tim Moon committed
15
16
17
18
19

#include <pybind11/pybind11.h>

#include "common/util/logging.h"
#include <transformer_engine/fused_attn.h>
20
21
22
23

namespace transformer_engine {
namespace jax {

24
25
26
// Declared here, defined elsewhere. Returns an int; presumably the CUDA
// runtime version number -- TODO confirm encoding against the definition.
int GetCudaRuntimeVersion();
// Declared here, defined elsewhere. Takes a GPU index and returns an int;
// presumably the device's compute capability -- TODO confirm how major/minor
// are packed into the single return value.
int GetDeviceComputeCapability(int gpu_id);

27
void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q_max_seqlen,
28
29
                           size_t kv_max_seqlen, NVTE_Fused_Attn_Backend backend,
                           cudaStream_t stream);
30

31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
class cudaDevicePropertiesManager {
 public:
    static cudaDevicePropertiesManager &Instance() {
        static thread_local cudaDevicePropertiesManager instance;
        return instance;
    }

    int GetMultiProcessorCount() {
        if (!prop_queried_) {
            int device_id;
            NVTE_CHECK_CUDA(cudaGetDevice(&device_id));
            cudaGetDeviceProperties(&prop_, device_id);
            prop_queried_ = true;
        }
        return prop_.multiProcessorCount;
    }

48
49
50
51
52
53
54
55
56
57
    int GetMajor() {
        if (!prop_queried_) {
            int device_id;
            NVTE_CHECK_CUDA(cudaGetDevice(&device_id));
            cudaGetDeviceProperties(&prop_, device_id);
            prop_queried_ = true;
        }
        return prop_.major;
    }

58
59
60
61
62
 private:
    bool prop_queried_ = false;
    cudaDeviceProp prop_;
};

63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
// Per-thread running offset counter.
//
// Each thread gets its own instance through Instance(); the counter starts at
// zero and only moves via GetAndUpdateOffset(). The type is non-copyable and
// cannot be constructed outside of Instance().
class FusedAttnOffsetManager {
 public:
    FusedAttnOffsetManager(FusedAttnOffsetManager const &) = delete;
    void operator=(FusedAttnOffsetManager const &) = delete;

    // Returns the calling thread's counter object.
    static FusedAttnOffsetManager &Instance() {
        static thread_local FusedAttnOffsetManager per_thread_manager;
        return per_thread_manager;
    }

    // Advances the counter by `increment` and returns the value it held
    // before the advance.
    size_t GetAndUpdateOffset(size_t increment) {
        const size_t previous = offset_;
        offset_ = previous + increment;
        return previous;
    }

 private:
    FusedAttnOffsetManager() = default;

    size_t offset_ = 0;
};

84
85
86
87
}  // namespace jax
}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_JAX_CSRC_UTILS_H_