// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

#include "src/fastertransformer/models/llama/LlamaDenseWeight.h"
#include "src/fastertransformer/models/llama/llama_kernels.h"
#include "src/fastertransformer/utils/cublasMMWrapper.h"
#include "src/fastertransformer/utils/cuda_utils.h"

namespace fastertransformer {

template<typename T>
class LlamaLinear {
public:
    // Thin forward-pass wrapper for a dense (linear) layer. Dispatches on the
    // weight's storage type: fp16/fp32 go through cuBLAS; int4 is a placeholder.
    LlamaLinear(cublasMMWrapper* cublas_wrapper, cudaStream_t stream): cublas_wrapper_(cublas_wrapper), stream_(stream)
    {
    }

    // output_data = input_data * W for `batch_size` rows, routed by weight.type.
    // Asserts on any weight type other than fp16/fp32/int4.
    void forward(T* output_data, const T* input_data, int batch_size, const LlamaDenseWeight<T>& weight)
    {
        const auto wtype = weight.type;
        if (wtype == WeightType::kFP16 || wtype == WeightType::kFP32) {
            forwardFp(output_data, input_data, batch_size, weight);
        }
        else if (wtype == WeightType::kINT4) {
            forwardInt4(output_data, input_data, batch_size, weight);
        }
        else {
            FT_CHECK(0);  // unsupported weight type
        }
    }

private:
    // Floating-point GEMM path. Argument order follows the cuBLAS column-major
    // convention: m = output_dims, n = batch_size, k = input_dims, with the
    // weight kernel as the A operand and leading dimensions matching each
    // operand's row count.
    void forwardFp(T* output_data, const T* input_data, int batch_size, const LlamaDenseWeight<T>& weight)
    {
        cublas_wrapper_->Gemm(CUBLAS_OP_N,
                              CUBLAS_OP_N,
                              weight.output_dims,       // m
                              batch_size,               // n
                              weight.input_dims,        // k
                              (const T*)weight.kernel,  // A
                              weight.output_dims,       // lda
                              input_data,               // B
                              weight.input_dims,        // ldb
                              output_data,              // C
                              weight.output_dims);      // ldc
        sync_check_cuda_error();
    }

    // Quantized int4 path — not implemented yet; always asserts.
    void forwardInt4(T* output_data, const T* input_data, int batch_size, const LlamaDenseWeight<T>& weight)
    {
        FT_CHECK_WITH_INFO(0, "Not implemented");
    }

    cublasMMWrapper* cublas_wrapper_;  // non-owning; supplies the Gemm call
    cudaStream_t     stream_{};        // stored at construction; not referenced in this class
};

}  // namespace fastertransformer