// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

gaoqiong's avatar
gaoqiong committed
5
#include "src/turbomind/kernels/gemm_s_f16/gemm_s4_f16.h"
lvhan028's avatar
lvhan028 committed
6
7
8
9
#include "src/turbomind/models/llama/LlamaDenseWeight.h"
#include "src/turbomind/models/llama/llama_kernels.h"
#include "src/turbomind/utils/cublasMMWrapper.h"
#include "src/turbomind/utils/cuda_utils.h"
10
11
#include "src/turbomind/utils/logger.h"
#include <type_traits>
Li Zhang's avatar
Li Zhang committed
12

lvhan028's avatar
lvhan028 committed
13
namespace turbomind {
Li Zhang's avatar
Li Zhang committed
14
15
16
17

template<typename T>
class LlamaLinear {
public:
18
19
20
21
22
23
    enum Type
    {
        kGemm,
        kFusedSiluFfn
    };

Li Zhang's avatar
Li Zhang committed
24
25
26
27
    LlamaLinear(cublasMMWrapper* cublas_wrapper, cudaStream_t stream): cublas_wrapper_(cublas_wrapper), stream_(stream)
    {
    }

28
29
    void
    forward(T* output_data, const T* input_data, int batch_size, const LlamaDenseWeight<T>& weight, Type type = kGemm)
Li Zhang's avatar
Li Zhang committed
30
31
32
33
    {
        switch (weight.type) {
            case WeightType::kFP16:
            case WeightType::kFP32:
q.yao's avatar
q.yao committed
34
            case WeightType::kBF16:
35
                forwardFp(output_data, input_data, batch_size, weight, type);
Li Zhang's avatar
Li Zhang committed
36
37
                break;
            case WeightType::kINT4:
38
                forwardInt4(output_data, input_data, batch_size, weight, type);
Li Zhang's avatar
Li Zhang committed
39
40
41
42
43
                break;
            default:
                FT_CHECK(0);
        }
    }
gaoqiong's avatar
gaoqiong committed
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
    void forward_ffn(T* output_data,T* output_tmp, const T* input_data, int batch_size, const LlamaDenseWeight<T>& weight, Type type = kGemm)
    {
        switch (weight.type) {
            case WeightType::kFP16:
            case WeightType::kFP32:
            case WeightType::kBF16:
                forwardFp(output_data, input_data, batch_size, weight, type);
                break;
            case WeightType::kINT4:
            {
                if (type == kFusedSiluFfn)
                    forwardInt4_ffn(output_data, output_tmp,input_data, batch_size, weight, type);
                else
                    forwardInt4(output_data, input_data, batch_size, weight, type);
                break;
            }
                        
            default:
                FT_CHECK(0);
        }
    }
Li Zhang's avatar
Li Zhang committed
65
private:
66
    void forwardFp(T* output_data, const T* input_data, int batch_size, const LlamaDenseWeight<T>& weight, Type type)
Li Zhang's avatar
Li Zhang committed
67
    {
68
        FT_CHECK(type == kGemm);
Li Zhang's avatar
Li Zhang committed
69
70
71
72
73
74
75
76
77
78
79
80
81
82
        cublas_wrapper_->Gemm(CUBLAS_OP_N,
                              CUBLAS_OP_N,
                              weight.output_dims,
                              batch_size,
                              weight.input_dims,
                              (const T*)weight.kernel,
                              weight.output_dims,
                              input_data,
                              weight.input_dims,
                              output_data,
                              weight.output_dims);
        sync_check_cuda_error();
    }

83
    void forwardInt4(T* output_data, const T* input_data, int batch_size, const LlamaDenseWeight<T>& weight, Type type)
Li Zhang's avatar
Li Zhang committed
84
    {
gaoqiong's avatar
gaoqiong committed
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
        if constexpr (std::is_same_v<T, half>) {

            if(weight.w4_weight_layout==0) //普通NN模式 rocblas
            {
                //检查DQweight的空间是否足够
                if(batch_size*weight.output_dims>M_max*N_max)
                {
                    FT_CHECK_WITH_INFO(0, "error! batch_size>N_max  ||weight.output_dims>N_max");
                }

                dequant_w4_gemm(stream_, reinterpret_cast<T*>(cublas_wrapper_->deweight_workspace_),(const uint32_t*)weight.kernel,(const half2*)weight.scales_and_zeros,weight.input_dims,weight.output_dims,weight.group_size);
                cublas_wrapper_->Gemm(CUBLAS_OP_N,
                                        CUBLAS_OP_N,
                                        weight.output_dims,//m
                                        batch_size,//n
                                        weight.input_dims,//k
                                        (const T*) cublas_wrapper_->deweight_workspace_, //[]
                                        weight.output_dims,//m
                                        input_data,
                                        weight.input_dims, //k
                                        output_data,
                                        weight.output_dims); //m
            }
            else if(weight.w4_weight_layout==1)//TN模式 padding rocblas
            {

                //检查DQweight的空间是否足够
                if(batch_size*weight.output_dims>M_max*N_max)
                {
                    FT_CHECK_WITH_INFO(0, "error! batch_size>N_max  ||weight.output_dims>N_max");
                }

                //检查xpad空间是否足够
                if(weight.input_dims%4096==0) //需要进行pad
                {
gaoqiong's avatar
gaoqiong committed
120
121
122

                    input_padding(stream_,reinterpret_cast<half*>(cublas_wrapper_->xpading_workspace_),(const T*)input_data,batch_size,weight.input_dims,weight.group_size,weight.w4_pad_size);
                    dequant_w4_gemm_colmajor(stream_,reinterpret_cast<T*>(cublas_wrapper_->deweight_workspace_),(const uint32_t*)weight.kernel,(const half2*)weight.scales_and_zeros,weight.input_dims+weight.w4_pad_size*weight.group_size ,weight.output_dims,weight.group_size);
gaoqiong's avatar
gaoqiong committed
123
124
125
126
127
 
                    cublas_wrapper_->Gemm(CUBLAS_OP_T,
                        CUBLAS_OP_N,
                        weight.output_dims,//m
                        batch_size,//n
gaoqiong's avatar
gaoqiong committed
128
                        weight.input_dims+weight.w4_pad_size*weight.group_size,//k
gaoqiong's avatar
gaoqiong committed
129
                        (const T*) reinterpret_cast<T*>(cublas_wrapper_->deweight_workspace_), //[]
gaoqiong's avatar
gaoqiong committed
130
                        weight.input_dims+weight.w4_pad_size*weight.group_size, //k
gaoqiong's avatar
gaoqiong committed
131
                        (const T*) cublas_wrapper_->xpading_workspace_,
gaoqiong's avatar
gaoqiong committed
132
                        weight.input_dims+weight.w4_pad_size*weight.group_size, //k
gaoqiong's avatar
gaoqiong committed
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
                        output_data,
                        weight.output_dims); //m 
                }
                else //不需要进行pad
                {
                    dequant_w4_gemm_colmajor(stream_,reinterpret_cast<T*>(cublas_wrapper_->deweight_workspace_),(const uint32_t*)weight.kernel,(const half2*)weight.scales_and_zeros,weight.input_dims,weight.output_dims,weight.group_size);
                    cublas_wrapper_->Gemm(CUBLAS_OP_T,
                        CUBLAS_OP_N,
                        weight.output_dims,//m
                        batch_size,//n
                        weight.input_dims,//k
                        (const T*) reinterpret_cast<T*>(cublas_wrapper_->deweight_workspace_), //[]
                        weight.input_dims, //k
                        input_data,
                        weight.input_dims, //k
                        output_data,
                        weight.output_dims); //m 
                }
            }
    
            else if(weight.w4_weight_layout==2) //TN 模式padding ck
            {
                //检查ck workspace 的空间是否足够
                if(weight.input_dims%4096==0)
                {
gaoqiong's avatar
gaoqiong committed
158
                    run_weight_only_gemm(reinterpret_cast<const void*>(input_data), reinterpret_cast<const void*>(weight.kernel), reinterpret_cast<const void*>(weight.scales_and_zeros), reinterpret_cast<void*> (output_data), batch_size, weight.output_dims, (weight.input_dims), (weight.input_dims),(weight.input_dims), (weight.input_dims+weight.w4_pad_size*weight.group_size), weight.output_dims, weight.group_size,reinterpret_cast<void*>(cublas_wrapper_->ck_workspace_),CK_WORKSPACE_SIZE,(hipStream_t)stream_);
gaoqiong's avatar
gaoqiong committed
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
                }
                //                                            A                                            B0                                        B1                                            C                   M                   N                K                 strideA             strideB    strideBpad        strideC           group_size                            
               else{
                    run_weight_only_gemm(reinterpret_cast<const void*>(input_data), reinterpret_cast<const void*>(weight.kernel), reinterpret_cast<const void*>(weight.scales_and_zeros), reinterpret_cast<void*> (output_data), batch_size, weight.output_dims, (weight.input_dims), (weight.input_dims),(weight.input_dims), (weight.input_dims), weight.output_dims, weight.group_size,reinterpret_cast<void*>(cublas_wrapper_->ck_workspace_),CK_WORKSPACE_SIZE,(hipStream_t)stream_);
               }
            
            }
            sync_check_cuda_error();
        }
        else {
            FT_CHECK_WITH_INFO(0, "Not implemented");
        }
    }

    void forwardInt4_ffn(T* output_data,T* output_tmp, const T* input_data, int batch_size, const LlamaDenseWeight<T>& weight, Type type)
    {
        if constexpr (std::is_same_v<T, half>) {
            
            if(weight.w4_weight_layout==0) //普通NN模式 rocblas
            {
                //检查DQweight的空间是否足够
                if(batch_size*weight.output_dims>M_max*N_max)
                {
                    FT_CHECK_WITH_INFO(0, "error! batch_size>N_max  ||weight.output_dims>N_max");
                }

                dequant_w4_gemm(stream_, reinterpret_cast<T*>(cublas_wrapper_->deweight_workspace_),(const uint32_t*)weight.kernel,(const half2*)weight.scales_and_zeros,weight.input_dims,weight.output_dims,weight.group_size);
                cublas_wrapper_->Gemm(CUBLAS_OP_N,
                                        CUBLAS_OP_N,
                                        weight.output_dims,//m
                                        batch_size,//n
                                        weight.input_dims,//k
                                        (const T*) cublas_wrapper_->deweight_workspace_, //[]
                                        weight.output_dims,//m
                                        input_data,
                                        weight.input_dims, //k
                                        output_tmp,
                                        weight.output_dims); //m
            }
            else if(weight.w4_weight_layout==1)//TN模式 padding rocblas
            {

                //检查DQweight的空间是否足够
                if(batch_size*weight.output_dims>M_max*N_max)
                {
                    FT_CHECK_WITH_INFO(0, "error! batch_size>N_max  ||weight.output_dims>N_max");
                }

                //检查xpad空间是否足够
                if(weight.input_dims%4096==0) //需要进行pad
                {
gaoqiong's avatar
gaoqiong committed
210
211
212

                    input_padding<T>(stream_,reinterpret_cast<half*>(cublas_wrapper_->xpading_workspace_),(const T*)input_data,batch_size,weight.input_dims,weight.group_size,weight.w4_pad_size);
                    dequant_w4_gemm_colmajor(stream_,reinterpret_cast<T*>(cublas_wrapper_->deweight_workspace_),(const uint32_t*)weight.kernel,(const half2*)weight.scales_and_zeros,weight.input_dims+weight.w4_pad_size*weight.group_size,weight.output_dims,weight.group_size);
gaoqiong's avatar
gaoqiong committed
213
214
215
216
217
 
                    cublas_wrapper_->Gemm(CUBLAS_OP_T,
                        CUBLAS_OP_N,
                        weight.output_dims,//m
                        batch_size,//n
gaoqiong's avatar
gaoqiong committed
218
                        weight.input_dims+weight.w4_pad_size*weight.group_size,//k
gaoqiong's avatar
gaoqiong committed
219
                        (const T*) reinterpret_cast<T*>(cublas_wrapper_->deweight_workspace_), //[]
gaoqiong's avatar
gaoqiong committed
220
                        weight.input_dims+weight.w4_pad_size*weight.group_size, //k
gaoqiong's avatar
gaoqiong committed
221
                        (const T*) cublas_wrapper_->xpading_workspace_,
gaoqiong's avatar
gaoqiong committed
222
                        weight.input_dims+weight.w4_pad_size*weight.group_size, //k
gaoqiong's avatar
gaoqiong committed
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
                        output_tmp,
                        weight.output_dims); //m 
                }
                else //不需要进行pad
                {
                    dequant_w4_gemm_colmajor(stream_,reinterpret_cast<T*>(cublas_wrapper_->deweight_workspace_),(const uint32_t*)weight.kernel,(const half2*)weight.scales_and_zeros,weight.input_dims,weight.output_dims,weight.group_size);
                    cublas_wrapper_->Gemm(CUBLAS_OP_T,
                        CUBLAS_OP_N,
                        weight.output_dims,//m
                        batch_size,//n
                        weight.input_dims,//k
                        (const T*) reinterpret_cast<T*>(cublas_wrapper_->deweight_workspace_), //[]
                        weight.input_dims, //k
                        input_data,
                        weight.input_dims, //k
                        output_tmp,
                        weight.output_dims); //m 
                }
            }
            else if(weight.w4_weight_layout==2) //TN 模式padding ck
            {
                //检查ck workspace 的空间是否足够

                if(weight.input_dims%4096==0)
                {
gaoqiong's avatar
gaoqiong committed
248
249
                    
                    run_weight_only_gemm(reinterpret_cast<const void*>(input_data), reinterpret_cast<const void*>(weight.kernel), reinterpret_cast<const void*>(weight.scales_and_zeros), reinterpret_cast<void*> (output_tmp), batch_size, weight.output_dims, (weight.input_dims), (weight.input_dims),(weight.input_dims), (weight.input_dims+weight.w4_pad_size*weight.group_size), weight.output_dims, weight.group_size,reinterpret_cast<void*>(cublas_wrapper_->ck_workspace_),CK_WORKSPACE_SIZE,(hipStream_t)stream_);
gaoqiong's avatar
gaoqiong committed
250
251
252
253
254
255
256
257
258
259
                }
                //                                            A                                            B0                                        B1                                            C                   M                   N                K                 strideA             strideB    strideBpad        strideC           group_size                            
               else{
                run_weight_only_gemm(reinterpret_cast<const void*>(input_data), reinterpret_cast<const void*>(weight.kernel), reinterpret_cast<const void*>(weight.scales_and_zeros), reinterpret_cast<void*> (output_tmp), batch_size, weight.output_dims, (weight.input_dims), (weight.input_dims),(weight.input_dims), (weight.input_dims), weight.output_dims, weight.group_size,reinterpret_cast<void*>(cublas_wrapper_->ck_workspace_),CK_WORKSPACE_SIZE,(hipStream_t)stream_);
               }
            }
            addFusedSiluActivation(stream_,output_data,output_tmp,batch_size,weight.output_dims,1);
            sync_check_cuda_error();
        }
        else {
260
            FT_CHECK_WITH_INFO(0, "Not implemented");
gaoqiong's avatar
gaoqiong committed
261
        }
Li Zhang's avatar
Li Zhang committed
262
263
264
265
266
    }

private:
    cublasMMWrapper* cublas_wrapper_;
    cudaStream_t     stream_{};
xiabo's avatar
xiabo committed
267
    // GemmS4F16        gemm_s4_f16_;
Li Zhang's avatar
Li Zhang committed
268
269
};

lvhan028's avatar
lvhan028 committed
270
}  // namespace turbomind