/*************************************************************************
 * Copyright (c) 2022-2024, S3000 qianyj. All rights reserved.
 ************************************************************************/

#include <hip/hip_runtime.h>
#include "hipblas_gemm.h"
#include "../common_hip.h"
#include "../util/logging.h"

namespace {

// Debug helper: copy `size` elements of a device tensor to host and print
// them to stdout, 16 values per row, each cast to float for readability.
// Intended for ad-hoc debugging only (synchronous copy, host allocation).
template<typename T>
void printTensor(const std::string& str, const T* devTensor, size_t size) {
  T* hostTensor = static_cast<T*>(malloc(size * sizeof(T)));
  if (hostTensor == nullptr) {
    return;  // host allocation failed; nothing to print
  }
  // Blocking copy; any pending kernels writing devTensor must already be done.
  if (hipMemcpy(hostTensor, devTensor, size * sizeof(T),
                hipMemcpyDeviceToHost) != hipSuccess) {
    free(hostTensor);
    return;
  }
  std::cout << str << ": ";
  // Bug fix: `i` was previously declared uninitialized (`for(int i; ...)`),
  // which is undefined behavior; also use size_t to match `size`.
  for (size_t i = 0; i < size; i++) {
    if (i % 16 == 0) {
      std::cout << std::endl;
    }
    std::cout << static_cast<float>(hostTensor[i]) << ", ";
  }
  std::cout << str << ": finish" << std::endl;
  free(hostTensor);
}

// Map a transformer_engine DType to the corresponding hipBLAS data type
// enum. Types without a hipBLAS equivalent abort via NVTE_ERROR.
hipblasDatatype_t get_hip_dtype(const transformer_engine::DType t) {
  using transformer_engine::DType;
  switch (t) {
    case DType::kFloat32:  return HIPBLAS_R_32F;
    case DType::kFloat16:  return HIPBLAS_R_16F;
    case DType::kBFloat16: return HIPBLAS_R_16B;
    case DType::kInt8:     return HIPBLAS_R_8I;
    case DType::kInt32:    return HIPBLAS_R_32I;
    default:
      NVTE_ERROR("Invalid type");
  }
}

}  // namespace

// Define a static handle manager
static HipblasHandleManager handleManager;

namespace transformer_engine {

// Single (non-batched) GEMM via hipblasGemmEx: D = alpha * op(A) * op(B) + beta * C,
// where the output tensor is passed as both the C and D operands, so
// `accumulate == true` adds the product onto the existing contents of outputD.
// Accumulation type is fp32 by default; int8 x int8 -> int32 inputs switch the
// compute type to int32 and reject `accumulate` (beta forced to 0).
// NOTE(review): inputBias, outputPreGelu, grad, workspace(+Size),
// use_split_accumulator, math_sm_count, m/n_split, gemm_producer and
// inputCounter are not used here — presumably kept for interface parity with
// the cublasLt-based path; confirm against callers before removing.
void hipblas_gemm(const Tensor *inputA,
                 const Tensor *inputB,
                 Tensor *outputD,
                 const Tensor *inputBias,
                 Tensor *outputPreGelu,
                 int m, int n, int k,
                 int lda, int ldb, int ldd,
                 hipblasOperation_t transa,
                 hipblasOperation_t transb,
                 bool grad,
                 void* workspace,
                 size_t workspaceSize,
                 bool accumulate,
                 bool use_split_accumulator,
                 int math_sm_count,
                 int m_split,
                 int n_split,
                 bool gemm_producer,
                 const Tensor *inputCounter,
                 hipStream_t stream) {
    // One cached hipBLAS handle per device, provided by the static manager.
    int device_id;
    NVTE_CHECK(hipGetDevice(&device_id) == hipSuccess, "hipGetDevice failed");
    hipblasHandle_t handle = handleManager.get(device_id);

    void *A = inputA->data.dptr;
    void *B = inputB->data.dptr;
    void *D = outputD->data.dptr;  // also used as the C operand (in-place accumulate)

    // Select operand and accumulation precision.
    hipblasDatatype_t A_type = get_hip_dtype(inputA->data.dtype);
    hipblasDatatype_t B_type = get_hip_dtype(inputB->data.dtype);
    hipblasDatatype_t D_type = get_hip_dtype(outputD->data.dtype);
    hipblasDatatype_t computeType = HIPBLAS_R_32F;  // default: fp32 accumulation

    // alpha/beta must be given in the compute type, hence the parallel
    // float and int variants selected below.
    float one = 1.0f;
    float zero = 0.0f;
    float beta = accumulate ? one : zero;

    int int_one = 1;
    int int_zero = 0;
    int int_beta = int_zero;
    bool use_int8 = false;

    if ((A_type == HIPBLAS_R_8I) && (B_type == HIPBLAS_R_8I) && (D_type == HIPBLAS_R_32I)) {
      NVTE_CHECK(!accumulate, "Int8 gemm not support accumulate.");
      use_int8 = true;
      computeType = HIPBLAS_R_32I;  // int8 inputs accumulate in int32
    }

    hipblasSetStream(handle, stream);
    hipblasStatus_t status = hipblasGemmEx(
        handle,
        transa,
        transb,
        m, n, k,
        use_int8 ? static_cast<const void*>(&int_one)
                 : static_cast<const void*>(&one),
        A, A_type, lda,
        B, B_type, ldb,
        use_int8 ? static_cast<const void*>(&int_beta)
                 : static_cast<const void*>(&beta),
        D, D_type, ldd,
        computeType,
        HIPBLAS_GEMM_DEFAULT);

    if (status != HIPBLAS_STATUS_SUCCESS) {
        NVTE_ERROR("hipblasGemmEx execution failed");
    }
}

// Strided batched GEMM via hipblasGemmStridedBatchedEx:
//   D[b] = alpha * op(A[b]) * op(B[b]) + beta * C[b],  b = 0..batch_count-1,
// with the output tensor passed as both the C and D operands, so
// `accumulate == true` adds onto the existing contents of outputD.
// Batches are assumed densely packed: strides are m*k / k*n / m*n elements.
// Accumulation is always fp32 here (no int8 path, unlike hipblas_gemm).
// NOTE(review): inputBias, outputPreGelu, grad, workspace(+Size),
// use_split_accumulator, math_sm_count, m/n_split, gemm_producer and
// inputCounter are not used — presumably interface parity; confirm callers.
void hipblas_batchgemm(const Tensor *inputA,
                 const Tensor *inputB,
                 Tensor *outputD,
                 const Tensor *inputBias,
                 Tensor *outputPreGelu,
                 int m, int n, int k,
                 int lda, int ldb, int ldd,
                 hipblasOperation_t transa,
                 hipblasOperation_t transb,
                 bool grad,
                 void* workspace,
                 size_t workspaceSize,
                 bool accumulate,
                 bool use_split_accumulator,
                 int math_sm_count,
                 int m_split,
                 int n_split,
                 bool gemm_producer,
                 const Tensor *inputCounter,
                 int batch_count,
                 hipStream_t stream) {
    // One cached hipBLAS handle per device, provided by the static manager.
    int device_id;
    NVTE_CHECK(hipGetDevice(&device_id) == hipSuccess, "hipGetDevice failed");
    hipblasHandle_t handle = handleManager.get(device_id);

    void *A = inputA->data.dptr;
    void *B = inputB->data.dptr;
    void *D = outputD->data.dptr;  // also used as the C operand (in-place accumulate)

    // Select operand and accumulation precision.
    hipblasDatatype_t A_type = get_hip_dtype(inputA->data.dtype);
    hipblasDatatype_t B_type = get_hip_dtype(inputB->data.dtype);
    hipblasDatatype_t D_type = get_hip_dtype(outputD->data.dtype);
    hipblasDatatype_t computeType = HIPBLAS_R_32F;  // default: fp32 accumulation

    float one = 1.0f;
    float zero = 0.0f;
    float beta = accumulate ? one : zero;

    hipblasSetStream(handle, stream);

    // Widen before multiplying: the previous `m*k` multiplied in 32-bit int
    // and could overflow for large shapes before the long long conversion.
    const long long int strideA = static_cast<long long int>(m) * k;
    const long long int strideB = static_cast<long long int>(k) * n;
    const long long int strideD = static_cast<long long int>(m) * n;

    hipblasStatus_t status = hipblasGemmStridedBatchedEx(
        handle,
        transa,
        transb,
        m, n, k,
        static_cast<const void*>(&one),
        A, A_type, lda, strideA,
        B, B_type, ldb, strideB,
        static_cast<const void*>(&beta),
        D, D_type, ldd, strideD,
        batch_count,
        computeType,
        HIPBLAS_GEMM_DEFAULT);

    if (status != HIPBLAS_STATUS_SUCCESS) {
        // Bug fix: previous message blamed hipblasGemmEx, the wrong API.
        NVTE_ERROR("hipblasGemmStridedBatchedEx execution failed");
    }
}

}  // namespace transformer_engine