#pragma once
#pragma once

#include "../ops.hpp"
#include "../quantization.hpp"
#include "module.hpp"
#include <infiniccl.h>
#include <optional>

namespace infinicore::nn {

class BaseLinear : public Module {
12
public:
13
14
    BaseLinear(size_t in_features, size_t out_features, bool bias = true,
               const DataType &dtype = DataType::F32, const Device &device = Device());
15

qinyiqun's avatar
qinyiqun committed
16
17
18
    BaseLinear(size_t in_features, size_t out_features, std::shared_ptr<infinicore::quantization::BaseQuantization> quantization, bool bias = true,
               const DataType &dtype = DataType::F32, const Device &device = Device());

19
20
21
22
23
24
25
26
27
28
29
    // Forward pass: output = input @ weight.T + bias
    Tensor forward(Tensor &input) const;

    // Forward pass with residual connection (InfiniLM-style)
    // output = input @ weight.T + bias + residual
    Tensor forward(Tensor &input, Tensor &residual) const;

    // Module information
    size_t in_features() const { return in_features_; }
    size_t out_features() const { return out_features_; }
    bool has_bias() const { return has_bias_; }
30
    DataType dtype() const { return dtype_; }
31
32
33
34

    // Accessors for parameters
    Tensor weight() const { return weight_; }
    Tensor bias() const { return bias_; }
qinyiqun's avatar
qinyiqun committed
35
36
    Tensor weight_scale() const { return weight_scale_; }
    Tensor weight_zeros() const { return weight_zeros_; }
37
38
39

protected:
    // Parameters
40
41
    INFINICORE_NN_PARAMETER(weight);
    INFINICORE_NN_PARAMETER(bias);
42

qinyiqun's avatar
qinyiqun committed
43
44
45
    INFINICORE_NN_PARAMETER(weight_scale);
    INFINICORE_NN_PARAMETER(weight_zeros);

46
protected:
47
48
49
50
51
52
    // Helper method for common forward computation
    Tensor compute_linear(Tensor &input) const;

    size_t in_features_;
    size_t out_features_;
    bool has_bias_;
53
    DataType dtype_;
qinyiqun's avatar
qinyiqun committed
54
    std::shared_ptr<infinicore::quantization::BaseQuantization> quantization_ = std::make_shared<infinicore::quantization::NoneQuantization>(nullptr);
55
56
57
};

} // namespace infinicore::nn

namespace infinicore::nn {

class Linear : public BaseLinear {
public:
    Linear(size_t in_features, size_t out_features, bool bias = true,
           const DataType &dtype = DataType::F32, const Device &device = Device());

qinyiqun's avatar
qinyiqun committed
66
67
68
    Linear(size_t in_features, size_t out_features, std::shared_ptr<infinicore::quantization::BaseQuantization> quantization, bool bias = true,
           const DataType &dtype = DataType::F32, const Device &device = Device());

69
70
71
72
73
74
75
76
77
78
79
80
81
    // Forward pass: output = input @ weight.T + bias
    Tensor forward(Tensor &input) const;

    // String representation
    std::string extra_repr() const;
};

class ColumnParallelLinear : public BaseLinear {
public:
    ColumnParallelLinear(size_t in_features, size_t out_features, bool bias = true,
                         const DataType &dtype = DataType::F32, const Device &device = Device(),
                         Size tp_rank = 0, Size tp_size = 1);

qinyiqun's avatar
qinyiqun committed
82
83
84
85
    ColumnParallelLinear(size_t in_features, size_t out_features, std::shared_ptr<infinicore::quantization::BaseQuantization> quantization, bool bias = true,
                         const DataType &dtype = DataType::F32, const Device &device = Device(),
                         Size tp_rank = 0, Size tp_size = 1);

86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
    // Forward pass: output = input @ weight.T + bias
    Tensor forward(Tensor &input) const;

    // String representation
    std::string extra_repr() const;

protected:
    Size tp_rank_ = 0;
    Size tp_size_ = 1;
};

class RowParallelLinear : public BaseLinear {
public:
    RowParallelLinear(size_t in_features, size_t out_features, bool bias = true,
                      const DataType &dtype = DataType::F32, const Device &device = Device(),
                      Size tp_rank = 0, Size tp_size = 1, infinicclComm_t communicator = nullptr);

qinyiqun's avatar
qinyiqun committed
103
104
105
106
    RowParallelLinear(size_t in_features, size_t out_features, std::shared_ptr<infinicore::quantization::BaseQuantization> quantization, bool bias = true,
                      const DataType &dtype = DataType::F32, const Device &device = Device(),
                      Size tp_rank = 0, Size tp_size = 1, infinicclComm_t communicator = nullptr);

107
108
109
110
111
112
113
114
115
116
117
118
119
    // Forward pass: output = input @ weight.T + bias
    Tensor forward(Tensor &input) const;

    // String representation
    std::string extra_repr() const;

protected:
    Size tp_rank_ = 0;
    Size tp_size_ = 1;
    infinicclComm_t communicator_;
};

} // namespace infinicore::nn