#include "llama_mlp.hpp"
#include "infinicore/nn/linear.hpp"
#include "infinicore/ops.hpp"

namespace infinilm::models::llama {
/**
 * @deprecated This function is deprecated and will be REMOVED in the next major release (v0.2.0).
 *
 * ⚠️ DEVELOPMENT POLICY:
 *   - NO new development or feature additions permitted on this interface
 *   - Only critical bug fixes (security/stability) allowed until removal
 *   - All new code MUST migrate to the polymorphic overload below
 *
 * Replacement: Use the polymorphic overload of this same function name with updated signature
 * Reason: Legacy signature lacks support for dynamic quantization modes.
 * Removal target: v0.2.0 (Q2 2026)
 */
/// Legacy constructor (deprecated, see notice above): builds the MLP from a
/// typed LlamaConfig. Always uses the unquantized linear layers — this path
/// has no quantization-scheme dispatch (the reason for its deprecation).
///
/// @param config    Static Llama hyperparameters (hidden_size, intermediate_size, mlp_bias, dtype).
/// @param device    Device the projection weights are created on.
/// @param rank_info Tensor-parallel rank/size and communicator used to shard the projections.
LlamaMLP::LlamaMLP(const LlamaConfig &config,
                   const infinicore::Device &device,
                   engine::distributed::RankInfo rank_info)
    : hidden_size_(config.hidden_size),
      intermediate_size_(config.intermediate_size),
      use_bias_(config.mlp_bias), rank_info_(rank_info) {
    const auto &dtype{config.dtype};

    int tp_rank = rank_info.tp_rank;
    int tp_size = rank_info.tp_size;

    // Initialize projection layers.
    // NOTE(review): both INIT macros presumably declare/register the member
    // modules (gate_up_proj_ / down_proj_ used in forward()) — confirm against
    // the macro definitions in infinicore/nn.
    // gate_proj and up_proj are fused into a single gate_up module; weight
    // names "gate_proj"/"up_proj" are kept for checkpoint loading.
    INFINILM_GATE_UP_LINEAR_INIT(gate_up_proj, "gate_proj", "up_proj", hidden_size_, intermediate_size_, use_bias_,
                                 dtype, device, rank_info_);
    // down_proj is sharded over tp_rank/tp_size and reduces via rank_info.comm.
    INFINICORE_NN_MODULE_INIT(down_proj, intermediate_size_, hidden_size_, use_bias_,
                              dtype, device, tp_rank, tp_size, rank_info.comm);
}

/// Builds the MLP from a generic ModelConfig, dispatching the gate/up
/// projection implementation on the model's quantization scheme.
///
/// @param model_config Shared model configuration; queried for sizes, dtype,
///                     bias flag, quantization scheme and method. Retained in
///                     model_config_ for the lifetime of the module.
/// @param device       Device the projection weights are created on.
/// @param rank_info    Tensor-parallel rank/size and communicator used to shard the projections.
LlamaMLP::LlamaMLP(std::shared_ptr<infinilm::config::ModelConfig> model_config,
                   const infinicore::Device &device,
                   engine::distributed::RankInfo rank_info)
    : model_config_(model_config), hidden_size_(model_config->get<size_t>("hidden_size")),
      intermediate_size_(model_config->get<size_t>("intermediate_size")),
      use_bias_(model_config->get_or<bool>("mlp_bias", false)), rank_info_(rank_info) {

    const auto &dtype{model_config_->get_dtype()};

    int tp_rank = rank_info.tp_rank;
    int tp_size = rank_info.tp_size;

    // Select the fused gate/up projection variant for the configured
    // quantization scheme; unknown/unquantized schemes fall back to the
    // plain (unquantized) linear in the default branch.
    auto quant_scheme = this->model_config_->get_quant_scheme();
    switch (quant_scheme) {
    case infinicore::quantization::QuantScheme::COMPRESSED_TENSOR_W8A8I8:
        INFINILM_GATE_UP_LINEAR_W8A8_INIT(gate_up_proj, "gate_proj", "up_proj", hidden_size_, intermediate_size_, this->model_config_->get_quantization_method(), use_bias_,
                                          dtype, device, rank_info_);
        break;
    case infinicore::quantization::QuantScheme::AWQ_W4A16:
        INFINILM_GATE_UP_LINEAR_W4A16AWQ_INIT(gate_up_proj, "gate_proj", "up_proj", hidden_size_, intermediate_size_, this->model_config_->get_quantization_method(), use_bias_,
                                              dtype, device, rank_info_);
        break;
    default:
        INFINILM_GATE_UP_LINEAR_INIT(gate_up_proj, "gate_proj", "up_proj", hidden_size_, intermediate_size_, this->model_config_->get_quantization_method(), use_bias_,
                                     dtype, device, rank_info_);
        break;
    }

    // The down projection initialization was byte-identical in every switch
    // branch, so it is hoisted here and done exactly once.
    INFINICORE_NN_MODULE_INIT(down_proj, intermediate_size_, hidden_size_, this->model_config_->get_quantization_method(), use_bias_,
                              dtype, device, tp_rank, tp_size, rank_info.comm);
}

/// SwiGLU MLP forward pass: down_proj(silu(gate(x)) * up(x)).
///
/// @param hidden_states Input activations of the block.
/// @return Output activations after the down projection.
infinicore::Tensor LlamaMLP::forward(const infinicore::Tensor &hidden_states) const {
    // 1. Project to gate and up (single fused gate_up matmul).
    // forward_split takes a non-const argument, so bind through a local copy
    // first — presumably a cheap handle copy, not a deep tensor copy; confirm
    // infinicore::Tensor copy semantics.
    auto hidden_states_mutable = hidden_states;
    auto [gate, up] = gate_up_proj_->forward_split(hidden_states_mutable);

    // 2. Apply SwiGLU: silu(gate) * up
    // Note: swiglu kernel expects (up, gate) and computes gate * sigmoid(gate) * up
    // So we pass (up, gate) to get the correct result: gate * sigmoid(gate) * up
    auto intermediate = infinicore::op::swiglu(up, gate);

    // 3. Project down back to hidden_size (reduces across TP ranks inside down_proj_).
    auto output = down_proj_->forward(intermediate);

    return output;
}

} // namespace infinilm::models::llama