LlamaFfnLayer.h 2.63 KB
Newer Older
Li Zhang's avatar
Li Zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
/*
 * Copyright (c) OpenMMLab. All rights reserved.
 * Copyright (c) 2022-2023, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

 // Modified from https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/layers/FfnLayer.cc

#pragma once

// #include "src/fastertransformer/layers/FfnLayer.h"
#include "src/fastertransformer/models/llama/LlamaDecoderLayerWeight.h"
#include "src/fastertransformer/models/llama/LlamaLinear.h"
#include "src/fastertransformer/utils/custom_ar_comm.h"
#include "src/fastertransformer/utils/nccl_utils.h"
#include <functional>

namespace fastertransformer {

/**
 * Feed-forward network (FFN) layer for the LLaMA model, adapted from
 * FasterTransformer's FfnLayer.
 *
 * Only the declaration lives here; `forward`, `allocateBuffer`, `freeBuffer`
 * and `activation` are defined out of line.
 *
 * The destructor calls `freeBuffer()`, so an instance is the sole owner of
 * whatever that call releases. Copying is therefore disabled: an implicit copy
 * would duplicate the raw buffer pointers and both destructors would try to
 * release them.
 *
 * @tparam T  Element type of the layer's buffers and weights.
 */
template<typename T>
class LlamaFfnLayer {
public:
    /**
     * @param head_num        Number of attention heads; together with
     *                        `size_per_head` it determines `hidden_units_`.
     * @param size_per_head   Size of one attention head.
     * @param inter_size      Full intermediate size; the stored value is
     *                        divided by `tensor_para.world_size_` (integer
     *                        division).
     * @param tensor_para     Tensor-parallel communicator parameters (copied).
     * @param stream          CUDA stream, stored and also passed to `linear_`.
     * @param cublas_wrapper  cuBLAS wrapper passed to `linear_`; not owned here.
     * @param allocator       Allocator pointer, stored as-is; not owned here.
     * @param is_free_buffer_after_forward
     *                        Stored flag; presumably makes `forward` release
     *                        its work buffers before returning -- confirm in
     *                        the implementation file.
     */
    LlamaFfnLayer(size_t           head_num,
                  size_t           size_per_head,
                  size_t           inter_size,
                  NcclParam        tensor_para,
                  cudaStream_t     stream,
                  cublasMMWrapper* cublas_wrapper,
                  IAllocator*      allocator,
                  bool             is_free_buffer_after_forward):
        // NOTE: listed in member declaration order, which is the order the
        // language actually initializes them in (avoids -Wreorder).
        head_num_(head_num),
        size_per_head_(size_per_head),
        inter_size_(inter_size / tensor_para.world_size_),
        hidden_units_(head_num * size_per_head),
        stream_(stream),
        linear_(cublas_wrapper, stream),
        allocator_(allocator),
        is_free_buffer_after_forward_(is_free_buffer_after_forward),
        tensor_para_(tensor_para)
    {
    }

    // Non-copyable: the destructor releases buffers via freeBuffer(), so a
    // member-wise copy would lead to the same buffers being released twice.
    LlamaFfnLayer(const LlamaFfnLayer&) = delete;
    LlamaFfnLayer& operator=(const LlamaFfnLayer&) = delete;

    /// Releases the layer's buffers through `freeBuffer()`.
    ~LlamaFfnLayer()
    {
        freeBuffer();
    }

    /**
     * Runs the layer.
     *
     * @param output_tensors  Map receiving the outputs.
     * @param input_tensors   Map providing the inputs.
     * @param weights         FFN weights to apply.
     *
     * NOTE(review): the tensor keys and shapes expected in the maps are
     * defined by the out-of-line implementation, not by this header.
     */
    void forward(TensorMap* output_tensors, const TensorMap* input_tensors, const LlamaFfnWeight<T>* weights);

private:
    /// Prepares work buffers for `token_num` tokens (defined out of line).
    void allocateBuffer(size_t token_num);

    /// Releases the work buffers (defined out of line); called by the destructor.
    void freeBuffer();

    /// Applies the activation step for `num_token` tokens (defined out of line).
    void activation(int num_token);

    size_t         head_num_;
    size_t         size_per_head_;
    size_t         inter_size_;    // inter_size / tensor_para.world_size_
    size_t         hidden_units_;  // head_num * size_per_head
    cudaStream_t   stream_;
    LlamaLinear<T> linear_;
    IAllocator*    allocator_;  // not owned
    bool           is_free_buffer_after_forward_;

    // Work buffers; presumably managed by allocateBuffer()/freeBuffer() --
    // confirm in the implementation file. Null until allocated.
    T* gating_buf_{};
    T* inter_buf_{};

    NcclParam tensor_para_;

    // Presumably tracks whether the work buffers are currently allocated.
    bool is_allocate_buffer_{};
};

}  // namespace fastertransformer