/*************************************************************************
 * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_JAX_CSRC_UTILS_H_
#define TRANSFORMER_ENGINE_JAX_CSRC_UTILS_H_

#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <numeric>
#include <stdexcept>
#include <string>
#include <type_traits>

#include <cuda_runtime_api.h>
#include <pybind11/pybind11.h>

#include "common/util/logging.h"
#include <transformer_engine/fused_attn.h>

namespace transformer_engine {
namespace jax {

24
25
26
int GetCudaRuntimeVersion();
int GetDeviceComputeCapability(int gpu_id);

27
void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q_max_seqlen,
28
29
                           size_t kv_max_seqlen, NVTE_Fused_Attn_Backend backend,
                           cudaStream_t stream);
30

31
class WorkspaceManager {
32
 public:
33
34
    static WorkspaceManager &Instance() {
        static thread_local WorkspaceManager instance;
35
36
37
        return instance;
    }

38
39
    WorkspaceManager() {}
    ~WorkspaceManager() { Clear_(); }
40
41
42
43
44
45

    void *GetWorkspace(size_t size = 4194304) {
        ReallocateIfNeed_(size);
        return workspace_;
    }

46
47
48
49
50
51
52
53
54
55
56
57
58
59
    template <typename... Args>
    inline auto GetWorkspace(Args... args) {
        auto asks = std::array<size_t, sizeof...(Args)>{args...};
        std::array<size_t, sizeof...(Args) + 1> offsets = {0};
        std::array<void *, sizeof...(Args)> workspaces = {nullptr};
        std::transform_inclusive_scan(
            asks.cbegin(), asks.cend(), offsets.begin() + 1, std::plus<size_t>{},
            [=](auto x) { return PadSize_(x); }, 0);
        auto *workspace = GetWorkspace(offsets.back());
        std::transform(offsets.cbegin(), offsets.cend() - 1, workspaces.begin(),
                       [workspace](auto x) { return static_cast<char *>(workspace) + x; });
        return workspaces;
    }

60
61
62
63
 private:
    void *workspace_ = nullptr;
    size_t size_ = 0;

64
65
66
67
68
    size_t PadSize_(size_t size) {
        constexpr size_t alignment = 128;
        return ((size + alignment - 1) / alignment) * alignment;
    }

69
70
71
72
73
74
75
76
77
    void Clear_() {
        if (workspace_ != nullptr) {
            NVTE_CHECK_CUDA(cudaFree(workspace_));
        }
        workspace_ = nullptr;
        size_ = 0;
    }

    void Allocate_(size_t new_size) {
78
        new_size = PadSize_(new_size);
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
        NVTE_CHECK_CUDA(cudaMalloc(&workspace_, new_size));
        size_ = new_size;
    }

    void ReallocateIfNeed_(size_t new_size) {
        if (new_size > size_) {
            Clear_();
            Allocate_(new_size);
        }
    }
};

class cudaDevicePropertiesManager {
 public:
    static cudaDevicePropertiesManager &Instance() {
        static thread_local cudaDevicePropertiesManager instance;
        return instance;
    }

    int GetMultiProcessorCount() {
        if (!prop_queried_) {
            int device_id;
            NVTE_CHECK_CUDA(cudaGetDevice(&device_id));
            cudaGetDeviceProperties(&prop_, device_id);
            prop_queried_ = true;
        }
        return prop_.multiProcessorCount;
    }

108
109
110
111
112
113
114
115
116
117
    int GetMajor() {
        if (!prop_queried_) {
            int device_id;
            NVTE_CHECK_CUDA(cudaGetDevice(&device_id));
            cudaGetDeviceProperties(&prop_, device_id);
            prop_queried_ = true;
        }
        return prop_.major;
    }

118
119
120
121
122
 private:
    bool prop_queried_ = false;
    cudaDeviceProp prop_;
};

123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
class FusedAttnOffsetManager {
 public:
    // Per-thread singleton accessor: repeated calls on one thread yield the same object.
    static FusedAttnOffsetManager &Instance() {
        static thread_local FusedAttnOffsetManager manager;
        return manager;
    }

    // Hands out the current offset and then advances the counter by `increment`
    // (fetch-then-add). The counter is a size_t, so it wraps on overflow.
    size_t GetAndUpdateOffset(size_t increment) {
        const size_t previous = offset_;
        offset_ = previous + increment;
        return previous;
    }

    // Not copyable; obtain the object through Instance().
    FusedAttnOffsetManager(FusedAttnOffsetManager const &) = delete;
    void operator=(FusedAttnOffsetManager const &) = delete;

 private:
    FusedAttnOffsetManager() = default;

    // Next offset to hand out; starts at zero for each thread's instance.
    size_t offset_ = 0;
};

}  // namespace jax
}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_JAX_CSRC_UTILS_H_