llama_mlp.hpp 2.69 KB
Newer Older
1
2
#pragma once

3
4
5
#include "../../layers/fused_linear.hpp"
#include "llama_config.hpp"

6
#include "../../config/model_config.hpp"
Your Name's avatar
Your Name committed
7
#include "infinicore/device.hpp"
8
#include "infinicore/nn/linear.hpp"
Your Name's avatar
Your Name committed
9
#include "infinicore/nn/module.hpp"
10
#include "infinicore/tensor.hpp"
Your Name's avatar
Your Name committed
11
12
13
#include "llama_config.hpp"

#include "../../engine/distributed/distributed.hpp"
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36

namespace infinilm::models::llama {

/**
 * @brief MLP (Feed-Forward Network) module for Llama
 *
 * Implements the MLP block with:
 * - Gate projection
 * - Up projection
 * - Down projection
 * - SiLU activation function
 *
 * Formula: down_proj(SiLU(gate_proj(x)) * up_proj(x))
 */
class LlamaMLP : public infinicore::nn::Module {
public:
    /**
     * @brief Construct LlamaMLP module
     *
     * @param config Model configuration
     * @param device Device to create tensors on
     * @param dtype Optional data type for model parameters (defaults to F32)
     */
37
38
39
40
41
42
43
44
45
46
47
48
    /**
     * @deprecated This function is deprecated and will be REMOVED in the next major release (v0.2.0).
     *
     * ⚠️ DEVELOPMENT POLICY:
     *   - NO new development or feature additions permitted on this interface
     *   - Only critical bug fixes (security/stability) allowed until removal
     *   - All new code MUST migrate to the polymorphic overload below
     *
     * Replacement: Use the polymorphic overload of this same function name with updated signature
     * Reason: Legacy signature lacks support for dynamic quantization modes.
     * Removal target: v0.2.0 (Q2 2026)
     */
Your Name's avatar
Your Name committed
49
50
51
    LlamaMLP(const LlamaConfig &config,
             const infinicore::Device &device,
             engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
52

53
54
55
56
    LlamaMLP(std::shared_ptr<infinilm::config::ModelConfig> model_config,
             const infinicore::Device &device,
             engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());

57
58
59
60
61
62
63
64
65
66
67
68
69
    /**
     * @brief Forward pass: compute MLP output
     *
     * @param hidden_states Input tensor of shape [batch, seq_len, hidden_size]
     * @return Output tensor of shape [batch, seq_len, hidden_size]
     */
    infinicore::Tensor forward(const infinicore::Tensor &hidden_states) const;

    // Module information
    size_t hidden_size() const { return hidden_size_; }
    size_t intermediate_size() const { return intermediate_size_; }

protected:
70
    INFINICORE_NN_MODULE(layers::GateUpParallelLinear, gate_up_proj);
Your Name's avatar
Your Name committed
71
72
73
    INFINICORE_NN_MODULE(infinicore::nn::RowParallelLinear, down_proj);

    engine::distributed::RankInfo rank_info_;
74
75
76
    size_t hidden_size_;
    size_t intermediate_size_;
    bool use_bias_;
77
78

    std::shared_ptr<infinilm::config::ModelConfig> model_config_;
79
80
81
};

} // namespace infinilm::models::llama