// llama_mlp.cpp — Llama MLP (SwiGLU feed-forward) block implementation.
#include "llama_mlp.hpp"
#include "infinicore/nn/linear.hpp"
#include "infinicore/ops.hpp"

namespace infinilm::models::llama {

Your Name's avatar
Your Name committed
7
8
9
10
LlamaMLP::LlamaMLP(const LlamaConfig &config,
                   const infinicore::Device &device,
                   infinicore::DataType dtype,
                   engine::distributed::RankInfo rank_info)
11
12
    : hidden_size_(config.hidden_size),
      intermediate_size_(config.intermediate_size),
Your Name's avatar
Your Name committed
13
14
15
16
17
      use_bias_(config.mlp_bias), rank_info_(rank_info) {

    int tp_rank = rank_info.tp_rank;
    int tp_size = rank_info.tp_size;

18
19
    // Initialize projection layers
    INFINICORE_NN_MODULE_INIT(gate_proj, hidden_size_, intermediate_size_, use_bias_,
Your Name's avatar
Your Name committed
20
                              dtype, device, tp_rank, tp_size);
21
    INFINICORE_NN_MODULE_INIT(up_proj, hidden_size_, intermediate_size_, use_bias_,
Your Name's avatar
Your Name committed
22
                              dtype, device, tp_rank, tp_size);
23
    INFINICORE_NN_MODULE_INIT(down_proj, intermediate_size_, hidden_size_, use_bias_,
Your Name's avatar
Your Name committed
24
                              dtype, device, tp_rank, tp_size, rank_info.comm);
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
}

infinicore::Tensor LlamaMLP::forward(const infinicore::Tensor &hidden_states) const {
    // 1. Project to gate and up
    auto hidden_states_mutable = hidden_states;
    auto gate = gate_proj_->forward(hidden_states_mutable);

    auto up = up_proj_->forward(hidden_states_mutable);

    // 2. Apply SwiGLU: silu(gate) * up
    // Note: swiglu kernel expects (up, gate) and computes gate * sigmoid(gate) * up
    // So we pass (up, gate) to get the correct result: gate * sigmoid(gate) * up
    auto intermediate = infinicore::op::swiglu(up, gate);

    // 3. Project down
    auto output = down_proj_->forward(intermediate);

    return output;
}

} // namespace infinilm::models::llama