cublasFP8MMWrapper.h 7.02 KB
Newer Older
Li Zhang's avatar
Li Zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
/*
 * Copyright (c) 2022-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "3rdparty/fp8_qgmma_1x1/fp8_qgmma_1x1_utils.h"
#include "cuda_utils.h"
#include "src/fastertransformer/utils/cublasAlgoMap.h"
#include "src/fastertransformer/utils/cublasMMWrapper.h"
#include "src/fastertransformer/utils/cuda_fp8_utils.h"
#include <cublasLt.h>
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <map>
#include <mutex>
#include <string>

#pragma once

namespace fastertransformer {

class cublasFP8MMWrapper: public cublasMMWrapper {
public:
    cublasFP8MMWrapper(cublasLtHandle_t cublaslt_handle_,
                       cudaStream_t     stream,
                       cublasAlgoMap*   map,
                       std::mutex*      mu,
                       IAllocator*      allocator);

    cublasFP8MMWrapper(cublasHandle_t   cublas_handle,
                       cublasLtHandle_t cublaslt_handle,
                       cudaStream_t     stream,
                       cublasAlgoMap*   map,
                       std::mutex*      mu,
                       IAllocator*      allocator);

    virtual ~cublasFP8MMWrapper();

    cublasFP8MMWrapper(const cublasFP8MMWrapper& wrapper);

    virtual void cublasVersionCheck() override;

    void Gemm(__nv_bfloat16*       res,
              int                  batchCount,
              int                  m,
              int                  n,
              int                  k,
              int64_t              stridea,
              int64_t              strideb,
              int64_t              stridec,
              const float*         alpha,
              const float*         beta,
              const __nv_fp8_e4m3* input,
              const __nv_fp8_e4m3* kernel,
              const float*         input_scale,
              const float*         kernel_scale);

    void Gemm(__nv_bfloat16*       res,
              int                  batchCount,
              int                  m,
              int                  n,
              int                  k,
              int64_t              stridea,
              int64_t              strideb,
              int64_t              stridec,
              const float*         alpha,
              const float*         beta,
              const __nv_fp8_e4m3* input,
              const __nv_fp8_e4m3* kernel,
              const float*         input_scale,
              const float*         kernel_scale,
              cudaStream_t         stream,
              bool                 fastAccum = true);

    void Gemm(__nv_fp8_e4m3*       res,
              int                  batchCount,
              int                  m,
              int                  n,
              int                  k,
              int64_t              stridea,
              int64_t              strideb,
              int64_t              stridec,
              const float*         alpha,
              const float*         beta,
              const __nv_fp8_e4m3* input,
              const __nv_fp8_e4m3* kernel,
              const float*         input_scale,
              const float*         kernel_scale,
              const float*         output_scale);

    void Gemm(__nv_fp8_e4m3*       res,
              int                  batchCount,
              int                  m,
              int                  n,
              int                  k,
              int64_t              stridea,
              int64_t              strideb,
              int64_t              stridec,
              const float*         alpha,
              const float*         beta,
              const __nv_fp8_e4m3* input,
              const __nv_fp8_e4m3* kernel,
              const float*         input_scale,
              const float*         kernel_scale,
              const float*         output_scale,
              cudaStream_t         stream,
              bool                 fastAccum = true);

    template<bool RELU, bool GELU>
    void Conv1x1Gemm(__nv_fp8_e4m3*       res,
                     int                  m,
                     int                  n,
                     int                  k,
                     const __nv_fp8_e4m3* input,
                     const __nv_fp8_e4m3* kernel,
                     const __nv_bfloat16* bias,
                     const float          input_scale,
                     const float          kernel_scale,
                     const float          output_scale,
                     cudaStream_t         stream);

    template<bool RELU, bool GELU>
    void Gemm_Bias_Act(__nv_bfloat16*       res,
                       int                  batchCount,
                       int                  m,
                       int                  n,
                       int                  k,
                       int64_t              stridea,
                       int64_t              strideb,
                       int64_t              stridec,
                       const float*         alpha,
                       const float*         beta,
                       const __nv_fp8_e4m3* input,
                       const __nv_fp8_e4m3* kernel,
                       const float*         input_scale,
                       const float*         kernel_scale,
                       const __nv_bfloat16* bias,
                       const float*         output_scale,
                       cudaStream_t         stream);

    template<bool RELU, bool GELU>
    void Gemm_Bias_Act(__nv_fp8_e4m3*       res,
                       int                  batchCount,
                       int                  m,
                       int                  n,
                       int                  k,
                       int64_t              stridea,
                       int64_t              strideb,
                       int64_t              stridec,
                       const float*         alpha,
                       const float*         beta,
                       const __nv_fp8_e4m3* input,
                       const __nv_fp8_e4m3* kernel,
                       const float*         input_scale,
                       const float*         kernel_scale,
                       const __nv_bfloat16* bias,
                       const float*         output_scale,
                       cudaStream_t         stream);

private:
    int                                 version_major_, version_minor_, version_patch_;
    fastertransformer::qgmma1x1Launcher qgmmaLauncher;
    void*                               cublas_workspace_qgemm_ = nullptr;
};

}  // namespace fastertransformer