#pragma once

#include <cstddef>
#include <memory>
#include <optional>
#include <tuple>

#include "infinicore/device.hpp"
#include "infinicore/nn/module.hpp"
#include "infinicore/nn/rmsnorm.hpp"
#include "infinicore/tensor.hpp"
#include "llama_attention.hpp"
#include "llama_config.hpp"
#include "llama_mlp.hpp"

#include "../../engine/distributed/distributed.hpp"

namespace infinilm::models::llama {

/**
 * @brief Single decoder layer (transformer block) for Llama
 *
 * Each decoder layer consists of:
 * - Input layer normalization (RMSNorm)
 * - Self-attention mechanism
 * - Post-attention layer normalization (RMSNorm)
 * - MLP feed-forward network
 *
 * Residual connections are applied around both attention and MLP blocks.
 */
class LlamaDecoderLayer : public infinicore::nn::Module {
public:
    /**
     * @brief Construct LlamaDecoderLayer module
     *
     * @param config Model configuration
     * @param device Device to create tensors on
Ceng's avatar
Ceng committed
33
     * @param layer_idx Layer index for cache management and debugging
34
35
     * @param dtype Optional data type for model parameters (defaults to F32)
     */
36
37
38
39
40
41
42
43
44
45
46
47
    /**
     * @deprecated This function is deprecated and will be REMOVED in the next major release (v0.2.0).
     *
     * ⚠️ DEVELOPMENT POLICY:
     *   - NO new development or feature additions permitted on this interface
     *   - Only critical bug fixes (security/stability) allowed until removal
     *   - All new code MUST migrate to the polymorphic overload below
     *
     * Replacement: Use the polymorphic overload of this same function name with updated signature
     * Reason: Legacy signature lacks support for dynamic quantization modes.
     * Removal target: v0.2.0 (Q2 2026)
     */
Your Name's avatar
Your Name committed
48
49
50
    LlamaDecoderLayer(const LlamaConfig &config,
                      const infinicore::Device &device,
                      size_t layer_idx,
51
52
                      engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(),
                      backends::AttentionBackend attention_backend = backends::AttentionBackend::Default);
53

54
55
56
    LlamaDecoderLayer(std::shared_ptr<infinilm::config::ModelConfig> model_config,
                      const infinicore::Device &device,
                      size_t layer_idx,
57
58
                      engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(),
                      backends::AttentionBackend attention_backend = backends::AttentionBackend::Default);
59

60
61
62
    /**
     * @brief Forward pass: process one decoder layer
     *
63
64
     * @param hidden_states [batch, seq_len, hidden_size], will be modified
     * @param residual [batch, seq_len, hidden_size], will be modified
65
66
67
     * @param position_ids Position IDs tensor of shape [batch, seq_len] or [seq_len]
     * @param kv_cache Optional KV cache for incremental decoding
     * @return Output tensor of shape [batch, seq_len, hidden_size]
68
     *         Updated residual tensor of shape [batch, seq_len, hidden_size]
69
     */
70
71
72
73
74
75
76
77
    std::tuple<infinicore::Tensor, infinicore::Tensor>
    forward(infinicore::Tensor &hidden_states,
            infinicore::Tensor &residual,
            const infinicore::Tensor &position_ids,
            std::shared_ptr<infinilm::cache::Cache> kv_cache,
            std::optional<infinicore::Tensor> past_sequence_lengths,
            std::optional<infinicore::Tensor> total_sequence_lengths,
            std::optional<infinicore::Tensor> input_offsets,
78
            std::optional<infinicore::Tensor> cu_seqlens,
79
80
            std::optional<infinicore::Tensor> block_tables,
            std::optional<infinicore::Tensor> slot_mappin) const;
81

Ceng's avatar
Ceng committed
82
83
84
85
86
    /**
     * @brief Get the layer index
     */
    size_t layer_idx() const { return layer_idx_; }

87
88
89
90
91
92
93
94
95
96
97
98
99
100
    void set_rotary_emb(const std::shared_ptr<infinicore::nn::RoPE> &rotary_emb) {
        if (self_attn_) {
            self_attn_->set_rotary_emb(rotary_emb);
        }
    }

protected:
    // Layer normalization
    INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, input_layernorm);
    INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, post_attention_layernorm);

    // Attention and MLP
    INFINICORE_NN_MODULE(LlamaAttention, self_attn);
    INFINICORE_NN_MODULE(LlamaMLP, mlp);
Your Name's avatar
Your Name committed
101
    engine::distributed::RankInfo rank_info_;
102
    std::shared_ptr<infinilm::config::ModelConfig> model_config_;
Ceng's avatar
Ceng committed
103
104

private:
Your Name's avatar
Your Name committed
105
    size_t layer_idx_; // Layer index for cache management and debugging
106
107
108
};

} // namespace infinilm::models::llama