// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "src/turbomind/kernels/gemm_s_f16/gemm_s4_f16.h"
#include "src/turbomind/models/llama/LlamaDenseWeight.h"
#include "src/turbomind/models/llama/llama_kernels.h"
#include "src/turbomind/utils/cublasMMWrapper.h"
#include "src/turbomind/utils/cuda_utils.h"
#include "src/turbomind/utils/logger.h"
#include <type_traits>

namespace turbomind {

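// Dense (linear) layer used by the LLaMA implementation. Dispatches the forward
// pass either to a cuBLAS GEMM (fp16/fp32 weights) or to a fused int4-weight GEMM
// kernel, based on the storage type recorded in LlamaDenseWeight.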
template<typename T>
class LlamaLinear {
public:
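    // Epilogue selector: plain GEMM, or GEMM fused with the SiLU-gated FFN
    // activation. The fused variant is only supported on the int4 weight path.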
    enum Type
    {
        kGemm,
        kFusedSiluFfn
    };

    LlamaLinear(cublasMMWrapper* cublas_wrapper, cudaStream_t stream): cublas_wrapper_(cublas_wrapper), stream_(stream)
    {
    }

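    // Computes output_data[batch_size, output_dims] = input_data[batch_size, input_dims] * weight,
    // dispatching on the weight's storage type.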
    void
    forward(T* output_data, const T* input_data, int batch_size, const LlamaDenseWeight<T>& weight, Type type = kGemm)
    {
        switch (weight.type) {
            case WeightType::kFP16:
            case WeightType::kFP32:
                forwardFp(output_data, input_data, batch_size, weight, type);
                break;
            case WeightType::kINT4:
                forwardInt4(output_data, input_data, batch_size, weight, type);
                break;
            default:
                FT_CHECK(0);
        }
    }

private:
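    // fp16/fp32 path: a single cuBLAS GEMM. In column-major terms this computes
    // C(output_dims x batch_size) = kernel(output_dims x input_dims) * input(input_dims x batch_size),
    // i.e. row-major output = input * kernel. Only the plain kGemm epilogue is supported here.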
    void forwardFp(T* output_data, const T* input_data, int batch_size, const LlamaDenseWeight<T>& weight, Type type)
    {
        FT_CHECK(type == kGemm);
        cublas_wrapper_->Gemm(CUBLAS_OP_N,
                              CUBLAS_OP_N,
                              weight.output_dims,
                              batch_size,
                              weight.input_dims,
                              (const T*)weight.kernel,
                              weight.output_dims,
                              input_data,
                              weight.input_dims,
                              output_data,
                              weight.output_dims);
        sync_check_cuda_error();
    }

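    // int4 path: weights are packed 4-bit values with per-group scales and zero points;
    // GemmS4F16 runs the quantized GEMM and can optionally fuse the SiLU-gated FFN epilogue.
    // Only T == half is implemented.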
    void forwardInt4(T* output_data, const T* input_data, int batch_size, const LlamaDenseWeight<T>& weight, Type type)
    {
        if constexpr (std::is_same_v<T, half>) {
            gemm_s4_f16_.Run(output_data,
                             (const uint*)weight.kernel,
                             input_data,
                             (const half2*)weight.scales_and_zeros,
                             weight.output_dims,
                             batch_size,
                             weight.input_dims,
                             weight.group_size,
                             type == kFusedSiluFfn ? GemmS4F16::kFusedSiluFfn : GemmS4F16::kGemm,
                             -1,
                             stream_);
            sync_check_cuda_error();
        }
        else {
            FT_CHECK_WITH_INFO(0, "Not implemented");
        }
    }

private:
    cublasMMWrapper* cublas_wrapper_;
    cudaStream_t     stream_{};
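    // Int4-weight x fp16-activation GEMM kernel used by forwardInt4().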
    GemmS4F16        gemm_s4_f16_;
};
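
// Usage sketch (illustrative only; the weight and buffer names below are hypothetical,
// not defined in this header):
//   LlamaLinear<half> linear{cublas_wrapper, stream};
//   linear.forward(out, in, token_num, qkv_weight);  // plain GEMM, fp16 or int4 weights
//   linear.forward(ffn_out, ffn_in, token_num, fused_gate_up_weight, LlamaLinear<half>::kFusedSiluFfn);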

}  // namespace turbomind