// llama_decoder_layer.cpp
#include "llama_decoder_layer.hpp"
#include "infinicore/nn/rmsnorm.hpp"
#include "infinicore/ops.hpp"

namespace infinilm::models::llama {

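// Builds one decoder layer: pre- and post-attention RMSNorm modules plus the
// self-attention and MLP submodules for this layer.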
LlamaDecoderLayer::LlamaDecoderLayer(const LlamaConfig &config,
                                     const infinicore::Device &device,
                                     size_t layer_idx,
                                     engine::distributed::RankInfo rank_info)
    : layer_idx_(layer_idx), rank_info_(rank_info) {
    const auto &dtype{config.dtype};

    // Initialize layer normalization layers
    INFINICORE_NN_MODULE_INIT(input_layernorm, config.hidden_size, config.rms_norm_eps,
                              dtype, device);
    INFINICORE_NN_MODULE_INIT(post_attention_layernorm, config.hidden_size, config.rms_norm_eps,
                              dtype, device);

    // Initialize attention and MLP modules
    INFINICORE_NN_MODULE_INIT(self_attn, config, device, layer_idx, rank_info_);
    INFINICORE_NN_MODULE_INIT(mlp, config, device, rank_info_);
}

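// Standard pre-norm decoder block:
//   h = x + SelfAttn(RMSNorm(x))
//   y = h + MLP(RMSNorm(h))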
infinicore::Tensor LlamaDecoderLayer::forward(const infinicore::Tensor &hidden_states,
                                              const infinicore::Tensor &position_ids,
                                              std::shared_ptr<infinilm::cache::Cache> kv_cache,
                                              const infinicore::Tensor &cache_positions) const {
    // Save residual for attention
    auto residual = hidden_states;

    // 1. Pre-attention layer normalization
    auto normed_states = input_layernorm_->forward(hidden_states);

    // 2. Self-attention with residual connection
    auto attn_output = self_attn_->forward(normed_states, position_ids, kv_cache, cache_positions);

    // Add residual: output = residual + attn_output
    auto output = infinicore::op::add(residual, attn_output);
    // Save residual for MLP
    residual = output;

    // 3. Post-attention layer normalization
    normed_states = post_attention_layernorm_->forward(output);

    // 4. MLP with residual connection
    auto mlp_output = mlp_->forward(normed_states);

    // Add residual: output = residual + mlp_output
    output = infinicore::op::add(residual, mlp_output);

    return output;
}

} // namespace infinilm::models::llama
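
// Usage sketch (hypothetical call site; construction of config, device,
// rank_info, kv_cache, and the position tensors is assumed, not shown here):
//
//   infinilm::models::llama::LlamaDecoderLayer layer(config, device,
//                                                    /*layer_idx=*/0, rank_info);
//   auto out = layer.forward(hidden_states, position_ids, kv_cache, cache_positions);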