/*
 * Copyright (c) 2019-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "FfnINT8Weight.h"
#include "src/fastertransformer/kernels/activation_int8_kernels.h"
#include "src/fastertransformer/layers/BaseLayer.h"
#include "src/fastertransformer/utils/ScaleList.h"
#include "src/fastertransformer/utils/Tensor.h"
#include "src/fastertransformer/utils/allocator.h"
#include "src/fastertransformer/utils/cublasINT8MMWrapper.h"
#include "src/fastertransformer/utils/memory_utils.h"
#include <vector>

namespace fastertransformer {

template<typename T>
class GeluFfnLayerINT8;

template<typename T>
class ReluFfnLayerINT8;

template<typename T>
class FfnLayerINT8: public BaseLayer {
private:
    // buffer handling
    size_t max_token_num_ = 0;

    // meta data
    size_t head_num_;
    size_t size_per_head_;

    // calculated data
    size_t hidden_units_;

    void allocateBuffer() override;
    void freeBuffer() override;
    bool isValidTokenNum(size_t token_num);

protected:
    size_t inter_size_;
    int    int8_mode_;
    bool   sparse_;

    int*         inter_int_buf_;  // INT32 accumulator output of the first (hidden -> inter) GEMM
    int8_t*      inter_buf_;      // re-quantized INT8 activations fed into the second GEMM
    // Fused add-bias + activation + re-quantization; specialized per activation type by subclasses.
    virtual void invokeAddBiasActivation(const int m, const T* bias, ScaleList* scale_list) = 0;

public:
    FfnLayerINT8(size_t           max_batch_size,
                 size_t           max_seq_len,
                 size_t           head_num,
                 size_t           size_per_head,
                 size_t           inter_size,
                 int              int8_mode,
                 cudaStream_t     stream,
                 cublasMMWrapper* cublas_wrapper,
                 IAllocator*      allocator,
                 bool             is_free_buffer_after_forward,
                 bool             sparse = false);

    FfnLayerINT8(FfnLayerINT8<T> const& ffn_layer);

    ~FfnLayerINT8();

    // input_tensors:  ffn_input  [token_num, hidden_units]
    // output_tensors: ffn_output [token_num, hidden_units]
    void forward(std::vector<fastertransformer::Tensor>*       output_tensors,
                 const std::vector<fastertransformer::Tensor>* input_tensors,
                 const FfnWeight<T>*                           ffn_weights);

    friend GeluFfnLayerINT8<T>;
    friend ReluFfnLayerINT8<T>;
};

template<typename T>
class GeluFfnLayerINT8: public FfnLayerINT8<T> {
public:
    GeluFfnLayerINT8(size_t           max_batch_size,
                     size_t           max_seq_len,
                     size_t           head_num,
                     size_t           size_per_head,
                     size_t           inter_size,
                     int              int8_mode,
                     cudaStream_t     stream,
                     cublasMMWrapper* cublas_wrapper,
                     IAllocator*      allocator,
                     bool             is_free_buffer_after_forward,
                     bool             sparse = false);

    GeluFfnLayerINT8(GeluFfnLayerINT8<T> const& ffn_layer);

    ~GeluFfnLayerINT8() = default;

private:
    using FfnLayerINT8<T>::inter_int_buf_;
    using FfnLayerINT8<T>::inter_buf_;
    using FfnLayerINT8<T>::inter_size_;
    using FfnLayerINT8<T>::stream_;
    using FfnLayerINT8<T>::int8_mode_;
    using FfnLayerINT8<T>::sparse_;
    using FfnLayerINT8<T>::hidden_units_;
    void invokeAddBiasActivation(const int m, const T* bias, ScaleList* scale_list) override;
};

template<typename T>
class ReluFfnLayerINT8: public FfnLayerINT8<T> {
public:
    ReluFfnLayerINT8(size_t           max_batch_size,
                     size_t           max_seq_len,
                     size_t           head_num,
                     size_t           size_per_head,
                     size_t           inter_size,
                     int              int8_mode,
                     cudaStream_t     stream,
                     cublasMMWrapper* cublas_wrapper,
                     IAllocator*      allocator,
                     bool             is_free_buffer_after_forward);

    ReluFfnLayerINT8(ReluFfnLayerINT8<T> const& ffn_layer);

    ~ReluFfnLayerINT8() = default;

private:
    using FfnLayerINT8<T>::inter_int_buf_;
    using FfnLayerINT8<T>::inter_buf_;
    using FfnLayerINT8<T>::inter_size_;
    using FfnLayerINT8<T>::stream_;
    using FfnLayerINT8<T>::int8_mode_;
    using FfnLayerINT8<T>::hidden_units_;
    void invokeAddBiasActivation(const int m, const T* bias, ScaleList* scale_list) override;
};

}  // namespace fastertransformer
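
// ---------------------------------------------------------------------------
// Usage sketch (illustrative only, not part of the original header). It shows
// how a caller might construct a GeluFfnLayerINT8 and run one forward pass,
// using only the constructor and forward() signatures declared above.
// `stream`, `cublas_wrapper`, `allocator`, `ffn_weights`, and the tensor
// contents are assumed to be set up elsewhere by the caller:
//
//   using namespace fastertransformer;
//
//   GeluFfnLayerINT8<half> ffn_layer(max_batch_size,
//                                    max_seq_len,
//                                    head_num,
//                                    size_per_head,
//                                    inter_size,
//                                    int8_mode,
//                                    stream,
//                                    cublas_wrapper,
//                                    allocator,
//                                    false /* is_free_buffer_after_forward */);
//
//   std::vector<Tensor>       output_tensors{ffn_output_tensor};
//   const std::vector<Tensor> input_tensors{ffn_input_tensor};
//   ffn_layer.forward(&output_tensors, &input_tensors, &ffn_weights);
// ---------------------------------------------------------------------------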