LlamaLinear.h
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

// #include "src/turbomind/kernels/gemm_s_f16/gemm_s4_f16.h"
#include "src/turbomind/models/llama/LlamaDenseWeight.h"
#include "src/turbomind/models/llama/llama_kernels.h"
#include "src/turbomind/utils/cublasMMWrapper.h"
#include "src/turbomind/utils/cuda_utils.h"
#include "src/turbomind/utils/logger.h"
#include <type_traits>

namespace turbomind {

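// Thin wrapper around cuBLAS for the dense (linear) layers of the LLaMA model.
// Usage sketch (the surrounding object names are illustrative, not defined in this header):
//
//   LlamaLinear<half> linear{cublas_wrapper, stream};
//   linear.forward(dst, src, batch_size, weight);  // defaults to Type::kGemm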
template<typename T>
class LlamaLinear {
public:
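    // kGemm runs a plain GEMM; kFusedSiluFfn requests a GEMM with a fused
    // SiLU-gated FFN epilogue, which only the (currently disabled) INT4 path
    // below understands.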
    enum Type
    {
        kGemm,
        kFusedSiluFfn
    };

    LlamaLinear(cublasMMWrapper* cublas_wrapper, cudaStream_t stream): cublas_wrapper_(cublas_wrapper), stream_(stream)
    {
    }

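    // Dispatches on the weight's storage type: fp32/fp16/bf16 weights go
    // through cuBLAS, INT4 weights through the quantized path.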
    void
    forward(T* output_data, const T* input_data, int batch_size, const LlamaDenseWeight<T>& weight, Type type = kGemm)
    {
        switch (weight.type) {
            case WeightType::kFP16:
            case WeightType::kFP32:
            case WeightType::kBF16:
                forwardFp(output_data, input_data, batch_size, weight, type);
                break;
            case WeightType::kINT4:
                forwardInt4(output_data, input_data, batch_size, weight, type);
                break;
            default:
                FT_CHECK(0);
        }
    }

private:
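    // Full-precision path: a single cuBLAS GEMM. The fused-SiLU variant is not
    // supported here, hence the check below.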
    void forwardFp(T* output_data, const T* input_data, int batch_size, const LlamaDenseWeight<T>& weight, Type type)
    {
        FT_CHECK(type == kGemm);
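        // Column-major view: C(output_dims x batch_size) = W(output_dims x input_dims) * X(input_dims x batch_size),
        // i.e. row-major output[batch, out] = input[batch, in] * kernel[in, out].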
        cublas_wrapper_->Gemm(CUBLAS_OP_N,
                              CUBLAS_OP_N,
                              weight.output_dims,
                              batch_size,
                              weight.input_dims,
                              (const T*)weight.kernel,
                              weight.output_dims,
                              input_data,
                              weight.input_dims,
                              output_data,
                              weight.output_dims);
        sync_check_cuda_error();
    }

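    // W4A16 path: the fused GemmS4F16 kernel is commented out in this build,
    // so INT4 weights currently abort with "Not implemented".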
    void forwardInt4(T* output_data, const T* input_data, int batch_size, const LlamaDenseWeight<T>& weight, Type type)
    {
        // if constexpr (std::is_same_v<T, half>) {
        //     gemm_s4_f16_.Run(output_data,
        //                      (const uint*)weight.kernel,
        //                      input_data,
        //                      (const half2*)weight.scales_and_zeros,
        //                      weight.output_dims,
        //                      batch_size,
        //                      weight.input_dims,
        //                      weight.group_size,
        //                      type == kFusedSiluFfn ? GemmS4F16::kFusedSiluFfn : GemmS4F16::kGemm,
        //                      -1,
        //                      stream_);
        //     sync_check_cuda_error();
        // }
        // else {
            FT_CHECK_WITH_INFO(0, "Not implemented");
        // }
    }

private:
    cublasMMWrapper* cublas_wrapper_;
    cudaStream_t     stream_{};
    // GemmS4F16        gemm_s4_f16_;
};

}  // namespace turbomind