// fused_linear.hpp — fused QKV and Gate/Up column-parallel linear layer declarations.
#pragma once
#include "infinicore/nn/linear.hpp"

#include "../engine/distributed/communication_group.hpp"

namespace infinilm::layers {
// Fused Q/K/V projection built on a single column-parallel linear layer:
// the Q, K and V weights are concatenated along the output dimension so one
// matmul produces all three projections, and forward_split() separates the
// fused result back into (Q, K, V) tensors.
// NOTE(review): per the member comments below, each per-projection output
// size is divided by the tensor-parallel world size — confirm the exact
// sharding layout against ColumnParallelLinear in infinicore.
class QKVParallelLinear : public infinicore::nn::ColumnParallelLinear {
public:
    // General form: Q, K and V may each have distinct per-head dimensions,
    // head counts and bias flags.
    explicit QKVParallelLinear(size_t hidden_size,
                               size_t q_dim, size_t k_dim, size_t v_dim,
                               size_t num_q_head, size_t num_k_head, size_t num_v_head,
                               bool q_bias, bool k_bias, bool v_bias,
                               const infinicore::DataType &dtype = infinicore::DataType::F32,
                               const infinicore::Device &device = infinicore::Device(),
                               engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());

    // A more common case where all heads have the same dimension
    // (q_dim == k_dim == v_dim == head_dim), K and V share one head count,
    // and a single bias flag covers all three projections.
    explicit QKVParallelLinear(size_t hidden_size,
                               size_t head_dim,
                               size_t num_q_head, size_t num_kv_head,
                               bool bias = false,
                               const infinicore::DataType &dtype = infinicore::DataType::F32,
                               const infinicore::Device &device = infinicore::Device(),
                               engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());

    // Runs the fused projection on `input` and returns the result split into
    // separate (Q, K, V) tensors.
    std::tuple<infinicore::Tensor, infinicore::Tensor, infinicore::Tensor>
    forward_split(infinicore::Tensor &input);

    // Accessors for the per-projection portions of the fused weight —
    // presumably views/slices of the underlying fused parameter; confirm in
    // the implementation file.
    infinicore::nn::Parameter get_q_weight() const;
    infinicore::nn::Parameter get_k_weight() const;
    infinicore::nn::Parameter get_v_weight() const;

    // Per-projection bias accessors; only meaningful when the matching
    // has_*_bias() below returns true.
    infinicore::nn::Parameter get_q_bias() const;
    infinicore::nn::Parameter get_k_bias() const;
    infinicore::nn::Parameter get_v_bias() const;

    // Whether the corresponding projection was constructed with a bias.
    bool has_q_bias() const;
    bool has_k_bias() const;
    bool has_v_bias() const;

private:
    size_t q_dim_;      // per-head dimension of the Q projection
    size_t k_dim_;      // per-head dimension of the K projection
    size_t v_dim_;      // per-head dimension of the V projection
    size_t num_q_head_; // total Q heads (before any tensor-parallel split)
    size_t num_k_head_; // total K heads
    size_t num_v_head_; // total V heads
    bool q_bias_;       // bias flag as passed at construction
    bool k_bias_;
    bool v_bias_;
    size_t q_out_size_; // num_q_head * q_dim / tp_size
    size_t k_out_size_; // num_k_head * k_dim / tp_size
    size_t v_out_size_; // num_v_head * v_dim / tp_size
};

// Fused gate/up projection (the two input-side matmuls of a gated MLP such
// as SwiGLU) built on a single column-parallel linear layer: the gate and up
// weights are concatenated along the output dimension and forward_split()
// separates the fused result into (gate, up) tensors.
class GateUpParallelLinear : public infinicore::nn::ColumnParallelLinear {
public:
    // Single bias flag shared by both the gate and up projections.
    GateUpParallelLinear(size_t hidden_size, size_t intermediate_size, bool bias = false,
                         const infinicore::DataType &dtype = infinicore::DataType::F32, const infinicore::Device &device = infinicore::Device(),
                         engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());

    // General form: independent bias flags for the gate and up projections.
    GateUpParallelLinear(size_t hidden_size, size_t intermediate_size, bool gate_bias, bool up_bias,
                         const infinicore::DataType &dtype = infinicore::DataType::F32, const infinicore::Device &device = infinicore::Device(),
                         engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());

    // Runs the fused projection on `input` and returns (gate, up) tensors.
    std::tuple<infinicore::Tensor, infinicore::Tensor> forward_split(infinicore::Tensor &input);

    // Accessors for the gate/up portions of the fused weight — presumably
    // views/slices of the underlying fused parameter; confirm in the
    // implementation file.
    infinicore::nn::Parameter get_gate_weight() const;

    // Gate bias; only meaningful when has_gate_bias() returns true.
    infinicore::nn::Parameter get_gate_bias() const;

    infinicore::nn::Parameter get_up_weight() const;

    // Up bias; only meaningful when has_up_bias() returns true.
    infinicore::nn::Parameter get_up_bias() const;

    // Whether the gate projection was constructed with a bias.
    bool has_gate_bias() const;

    // Whether the up projection was constructed with a bias.
    bool has_up_bias() const;

private:
    bool gate_bias_; // bias flag for the gate projection, as constructed
    bool up_bias_;   // bias flag for the up projection, as constructed
};

// Constructs a fused QKV linear layer as member `name##_` and registers its
// Q/K/V weight (and optional bias) parameters on `this` under the given
// per-projection parameter names. Expands to a statement SEQUENCE, so it must
// be used where multiple statements are valid (e.g. a constructor body).
// It is deliberately not wrapped in do { } while (0): that would force a
// trailing semicolon at every call site, whereas the statement form compiles
// with or without one.
// Fix: the conditional registrations are braced so the macro cannot be
// misused as the single statement of an unbraced `if`, and later edits to a
// branch cannot fall outside the condition.
#define INFINILM_QKV_LINEAR_INIT(name, q_name, k_name, v_name, ...)                         \
    name##_ = std::make_shared<layers::QKVParallelLinear>(__VA_ARGS__);                     \
    this->register_parameter(std::string(q_name) + ".weight", name##_->get_q_weight());     \
    this->register_parameter(std::string(k_name) + ".weight", name##_->get_k_weight());     \
    this->register_parameter(std::string(v_name) + ".weight", name##_->get_v_weight());     \
    if (name##_->has_q_bias()) {                                                            \
        this->register_parameter(std::string(q_name) + ".bias", name##_->get_q_bias());     \
    }                                                                                       \
    if (name##_->has_k_bias()) {                                                            \
        this->register_parameter(std::string(k_name) + ".bias", name##_->get_k_bias());     \
    }                                                                                       \
    if (name##_->has_v_bias()) {                                                            \
        this->register_parameter(std::string(v_name) + ".bias", name##_->get_v_bias());     \
    }

// Constructs a fused gate/up linear layer as member `name##_` and registers
// its gate/up weight (and optional bias) parameters on `this` under the given
// parameter names. Expands to a statement SEQUENCE, so it must be used where
// multiple statements are valid (e.g. a constructor body). It is deliberately
// not wrapped in do { } while (0): that would force a trailing semicolon at
// every call site, whereas the statement form compiles with or without one.
// Fix: the conditional registrations are braced so the macro cannot be
// misused as the single statement of an unbraced `if`, and later edits to a
// branch cannot fall outside the condition.
#define INFINILM_GATE_UP_LINEAR_INIT(name, gate_name, up_name, ...)                           \
    name##_ = std::make_shared<layers::GateUpParallelLinear>(__VA_ARGS__);                    \
    this->register_parameter(std::string(gate_name) + ".weight", name##_->get_gate_weight()); \
    this->register_parameter(std::string(up_name) + ".weight", name##_->get_up_weight());     \
    if (name##_->has_gate_bias()) {                                                           \
        this->register_parameter(std::string(gate_name) + ".bias", name##_->get_gate_bias()); \
    }                                                                                         \
    if (name##_->has_up_bias()) {                                                             \
        this->register_parameter(std::string(up_name) + ".bias", name##_->get_up_bias());     \
    }

} // namespace infinilm::layers