hipblas_gemm.h 4.81 KB
Newer Older
yuguo's avatar
yuguo committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
/*************************************************************************
 * Copyright (c) 2022-2024, S3000 qianyj. All rights reserved.
 ************************************************************************/

/*! \file hipblas_gemm.h
 *  \brief GEMM functions that use BLAS instead of BLASLt for pure GEMM
 */

#ifndef TRANSFORMER_ENGINE_COMMON_HIPBLAS_GEMM_H_
#define TRANSFORMER_ENGINE_COMMON_HIPBLAS_GEMM_H_

#include <hip/hip_runtime.h>
#ifdef USE_HIPBLASLT
#include <hipblas/hipblas.h>
#include <mutex>
#include <unordered_map>
#else
#include <rocblas/rocblas.h>
#include <cassert>
#endif
#include <stdexcept>
#include "../common_hip.h"
#include <iostream>


#ifdef USE_HIPBLASLT
/*!
 * \brief Owns one lazily created hipBLAS handle per device id.
 *
 * A handle is created the first time get() is called for a device id, cached
 * in a map keyed by that id, and destroyed when the manager is destroyed.
 * get() guards the cache with a mutex.
 *
 * The manager is non-copyable: a copy would hold the same raw handles and
 * both destructors would call hipblasDestroy on them.
 */
class HipblasHandleManager {
public:
    HipblasHandleManager() {}

    // Copying would duplicate ownership of the cached handles (double
    // destroy), so it is prohibited. Declaring these also suppresses the
    // implicit move operations.
    HipblasHandleManager(const HipblasHandleManager&) = delete;
    HipblasHandleManager& operator=(const HipblasHandleManager&) = delete;

    /*! \brief Destroys every cached handle (one per device id).
     *
     *  The status returned by hipblasDestroy is not checked.
     */
    ~HipblasHandleManager() {
        for (auto& device_pair : handles_map_) {
            hipblasDestroy(device_pair.second);
        }
    }

    /*!
     * \brief Returns the cached handle for \p device_id, creating it on first use.
     *
     * \param device_id Key under which the handle is cached.
     * \return The hipBLAS handle associated with \p device_id.
     * \throws std::runtime_error if hipblasCreate does not report success.
     *
     * NOTE(review): hipblasCreate is called without selecting \p device_id
     * first; presumably callers have already made it the current device.
     * Confirm against the call sites.
     */
    hipblasHandle_t get(int device_id) {
        // Held for the whole lookup-or-create so concurrent callers cannot
        // create two handles for the same device id.
        std::lock_guard<std::mutex> lock(mutex_);

        // Fast path: a handle was already created for this device id.
        auto device_it = handles_map_.find(device_id);
        if (device_it != handles_map_.end()) {
            return device_it->second;
        }

        // Slow path: create the handle; nothing is cached on failure.
        hipblasHandle_t handle = nullptr;
        hipblasStatus_t status = hipblasCreate(&handle);
        if (status != HIPBLAS_STATUS_SUCCESS) {
            throw std::runtime_error("Failed to create HIPBLAS handle");
        }

        handles_map_[device_id] = handle;
        return handle;
    }

private:
    std::unordered_map<int, hipblasHandle_t> handles_map_;  // device_id -> hipBLAS handle
    std::mutex mutex_;                                      // guards handles_map_ in get()
};

namespace transformer_engine {
    /*!
     * \brief GEMM entry point declared for builds with USE_HIPBLASLT defined.
     *
     * Only the declaration lives in this header; the definition is elsewhere.
     *
     * NOTE(review): the parameter descriptions below are inferred from the
     * parameter names and types only. Confirm them against the implementation.
     *
     * \param inputA, inputB   Presumably the two GEMM operand tensors.
     * \param outputD          Presumably the result tensor.
     * \param inputBias        Presumably an optional bias tensor.
     * \param outputPreGelu    Presumably an optional pre-GELU output tensor.
     * \param m, n, k          Presumably the GEMM problem dimensions.
     * \param lda, ldb, ldd    Presumably the leading dimensions of A, B and D.
     * \param transa, transb   hipBLAS operation applied to A and B.
     * \param workspace, workspaceSize  Caller-provided buffer and its size.
     * \param stream           HIP stream passed through to the implementation.
     */
    void hipblas_gemm(const Tensor *inputA,
                 const Tensor *inputB,
                 Tensor *outputD,
                 const Tensor *inputBias,
                 Tensor *outputPreGelu,
                 int m, int n, int k,
                 int lda, int ldb, int ldd,
                 hipblasOperation_t transa,
                 hipblasOperation_t transb,
                 bool grad,
                 void* workspace,
                 size_t workspaceSize,
                 bool accumulate,
                 bool use_split_accumulator,
                 int math_sm_count,
                 int m_split,
                 int n_split,
                 bool gemm_producer,
                 const Tensor *inputCounter,
                 hipStream_t stream);

    /*!
     * \brief Batched counterpart of hipblas_gemm, declared for builds with
     *        USE_HIPBLASLT defined.
     *
     * Takes the same parameter list as hipblas_gemm with one addition,
     * \p batch_count, placed immediately before \p stream. Only the
     * declaration lives in this header.
     *
     * \param batch_count  Presumably the number of GEMMs in the batch;
     *                     confirm against the implementation.
     */
    void hipblas_batchgemm(const Tensor *inputA,
                 const Tensor *inputB,
                 Tensor *outputD,
                 const Tensor *inputBias,
                 Tensor *outputPreGelu,
                 int m, int n, int k,
                 int lda, int ldb, int ldd,
                 hipblasOperation_t transa,
                 hipblasOperation_t transb,
                 bool grad,
                 void* workspace,
                 size_t workspaceSize,
                 bool accumulate,
                 bool use_split_accumulator,
                 int math_sm_count,
                 int m_split,
                 int n_split,
                 bool gemm_producer,
                 const Tensor *inputCounter,
                 int batch_count,
                 hipStream_t stream);
}
#else

/*!
 * \brief Owns a single rocBLAS handle that is created on first use.
 *
 * The handle starts out null, is created by the first call to get(), and is
 * destroyed together with the manager. Copying is prohibited.
 */
class HipblasHandleManager {
public:
    HipblasHandleManager() : handle_(nullptr) {}

    // The manager owns its handle, so copy construction and assignment are disabled.
    HipblasHandleManager(const HipblasHandleManager&) = delete;
    HipblasHandleManager& operator=(const HipblasHandleManager&) = delete;

    /*! \brief Destroys the handle if one was ever created. */
    ~HipblasHandleManager() {
        if (handle_ == nullptr) {
            return;  // get() was never called successfully; nothing to release
        }
        rocblas_destroy_handle(handle_);
    }

    /*!
     * \brief Returns the handle, creating it on the first call.
     *
     * \return The rocBLAS handle owned by this manager.
     * \throws std::runtime_error if rocblas_create_handle does not report success.
     */
    rocblas_handle get() {
        if (handle_ == nullptr) {
            // First use: create the handle and fail loudly if that does not work.
            const rocblas_status create_status = rocblas_create_handle(&handle_);
            if (create_status != rocblas_status_success) {
                throw std::runtime_error("Failed to create HIPBLAS handle");
            }
        }

        // Sanity check: creation reported success, so the handle must be set.
        assert(handle_ != nullptr && "hipblasHandle should not be null after creation");
        return handle_;
    }

private:
    rocblas_handle handle_;  // null until the first successful get()
};
#endif // #ifdef USE_HIPBLASLT
#endif // TRANSFORMER_ENGINE_COMMON_HIPBLAS_GEMM_H_