utils.h 2.33 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
/*************************************************************************
 * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_JAX_CSRC_UTILS_H_
#define TRANSFORMER_ENGINE_JAX_CSRC_UTILS_H_

#include <pybind11/pybind11.h>

#include <cstdint>
#include <stdexcept>
#include <string>
#include <type_traits>
#include "transformer_engine/logging.h"

namespace transformer_engine {
namespace jax {

/*! \brief Owner of a lazily allocated, grow-only device workspace buffer.
 *
 *  Instance() returns a thread_local object, so each calling thread gets its
 *  own manager and therefore its own buffer. The buffer is allocated with
 *  cudaMalloc on first use, replaced by a larger one when a bigger size is
 *  requested, and released with cudaFree when the manager is destroyed.
 */
class cublasLtMetaManager {
 public:
    /*! \brief Return the calling thread's manager instance. */
    static cublasLtMetaManager &Instance() {
        static thread_local cublasLtMetaManager instance;
        return instance;
    }

    cublasLtMetaManager() {}

    // The manager owns a raw device pointer; a copy would make two objects
    // call cudaFree on the same allocation.
    cublasLtMetaManager(const cublasLtMetaManager &) = delete;
    cublasLtMetaManager &operator=(const cublasLtMetaManager &) = delete;

    ~cublasLtMetaManager() {
        // Destructors are implicitly noexcept, so the checked Clear_() path is
        // deliberately not used here: NVTE_CHECK_CUDA is defined outside this
        // file and presumably reports failures by throwing, which would
        // terminate the process if cudaFree failed during thread or process
        // teardown. Releasing the buffer is best-effort at this point.
        if (workspace_ != nullptr) {
            (void)cudaFree(workspace_);
            workspace_ = nullptr;
            size_ = 0;
        }
    }

    /*! \brief Get a device buffer of at least \p size bytes.
     *
     *  \param[in] size Minimum number of bytes required (default 4194304).
     *  \return Pointer to the buffer. It is only replaced when a later call
     *          asks for more bytes than are currently allocated; in that case
     *          the previously returned pointer has been freed. Returns nullptr
     *          if nothing has been allocated yet and \p size is 0.
     */
    void *GetWorkspace(size_t size = 4194304) {
        ReallocateIfNeed_(size);
        return workspace_;
    }

 private:
    void *workspace_ = nullptr;  // Current device buffer, nullptr when none.
    size_t size_ = 0;            // Size in bytes of workspace_.

    /*! \brief Free the current buffer (if any) and reset the bookkeeping. */
    void Clear_() {
        if (workspace_ != nullptr) {
            // Reset the members before checking the result so that a reported
            // failure cannot leave a stale pointer behind to be freed again.
            void *stale = workspace_;
            workspace_ = nullptr;
            size_ = 0;
            NVTE_CHECK_CUDA(cudaFree(stale));
        }
    }

    /*! \brief Allocate a fresh buffer of \p new_size bytes. */
    void Allocate_(size_t new_size) {
        // Allocate into a local first: the members are only updated once the
        // allocation is known to have succeeded.
        void *fresh = nullptr;
        NVTE_CHECK_CUDA(cudaMalloc(&fresh, new_size));
        workspace_ = fresh;
        size_ = new_size;
    }

    /*! \brief Grow the buffer when \p new_size exceeds the current size. */
    void ReallocateIfNeed_(size_t new_size) {
        if (new_size > size_) {
            Clear_();
            Allocate_(new_size);
        }
    }
};

class cudaDevicePropertiesManager {
 public:
    static cudaDevicePropertiesManager &Instance() {
        static thread_local cudaDevicePropertiesManager instance;
        return instance;
    }

    int GetMultiProcessorCount() {
        if (!prop_queried_) {
            int device_id;
            NVTE_CHECK_CUDA(cudaGetDevice(&device_id));
            cudaGetDeviceProperties(&prop_, device_id);
            prop_queried_ = true;
        }
        return prop_.multiProcessorCount;
    }

78
79
80
81
82
83
84
85
86
87
    int GetMajor() {
        if (!prop_queried_) {
            int device_id;
            NVTE_CHECK_CUDA(cudaGetDevice(&device_id));
            cudaGetDeviceProperties(&prop_, device_id);
            prop_queried_ = true;
        }
        return prop_.major;
    }

88
89
90
91
92
93
94
95
96
 private:
    bool prop_queried_ = false;
    cudaDeviceProp prop_;
};

}  // namespace jax
}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_JAX_CSRC_UTILS_H_