#pragma once

#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
#include <vector>

#include "../infinilm_model.hpp"

#include <infinicore/nn/rope.hpp>

namespace infinilm::models::llama {

/**
 * @brief Configuration structure for Llama model architecture
 *
 * This struct holds all hyperparameters needed to construct a Llama model.
 * It follows the same structure as HuggingFace's LlamaConfig.
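 *
 * Usage sketch (illustrative only; the overridden values below are
 * hypothetical and not required by this header):
 * @code
 * LlamaConfig config;
 * config.num_key_value_heads = 8;  // hypothetical GQA setting
 * config.rope_theta = 500000.0;    // hypothetical long-context base
 * if (!config.validate()) {
 *     // reject the configuration before constructing the model
 * }
 * size_t kv = config.kv_dim();     // 4096 * 8 / 32 = 1024
 * @endcode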
 */
struct LlamaConfig : public InfinilmModel::Config {
    // Data type
    infinicore::DataType dtype = infinicore::DataType::F32;

    // Vocabulary and embedding
    size_t vocab_size = 32000;        // Vocabulary size
    size_t hidden_size = 4096;        // Hidden dimension size
    size_t intermediate_size = 11008; // MLP intermediate dimension

    // Architecture
    size_t num_hidden_layers = 32;   // Number of decoder layers
    size_t num_attention_heads = 32; // Number of attention heads
    size_t num_key_value_heads = 32; // Number of key-value heads (for GQA)
    size_t head_dim = 128;           // Attention head dimension (hidden_size / num_attention_heads)

    // Position embeddings
    size_t max_position_embeddings = 2048; // Maximum sequence length
    double rope_theta = 10000.0;           // RoPE base frequency

    std::shared_ptr<infinicore::nn::RoPE::ScalingConfig> rope_scaling = nullptr; // RoPE scaling configuration (nullptr = no scaling)

    // Normalization
    double rms_norm_eps = 1e-6; // RMSNorm epsilon

    // Activation
    std::string hidden_act = "silu";  // Activation function (typically "silu")
    std::string model_type = "llama"; // Model type identifier (matches HF configs)

    // Optional features
    bool use_cache = true;              // Whether to use KV cache
    bool attention_bias = true;         // Whether to use bias in Q/K/V projections (default true for 9G7B compatibility)
    bool attention_output_bias = false; // Whether to use bias in output projection (o_proj)
    bool mlp_bias = false;              // Whether to use bias in MLP projections
    bool tie_word_embeddings = false;   // Whether to tie input/output embeddings

    // Training/initialization parameters
    double attention_dropout = 0.0;  // Dropout ratio for attention probabilities
    double initializer_range = 0.02; // Standard deviation for weight initialization
    size_t pretraining_tp = 1;       // Tensor parallelism rank used during pretraining

    // Model metadata
    std::string name_or_path = ""; // Model name or path identifier

    // Token IDs
    int64_t pad_token_id = -1;               // Padding token ID (-1 means not set)
    std::vector<int64_t> bos_token_id = {1}; // Beginning of sequence token ID(s)
    std::vector<int64_t> eos_token_id = {2}; // End of sequence token ID(s)

    /**
     * @brief Compute key-value dimension for Grouped Query Attention (GQA)
     * @return The dimension for key/value projections
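     *
     * For example, with the defaults above the key/value dimension equals the
     * hidden size: 4096 * 32 / 32 = 4096. With num_key_value_heads = 8 (a
     * hypothetical GQA setting) it would be 4096 * 8 / 32 = 1024.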
     */
    size_t kv_dim() const {
        return hidden_size * num_key_value_heads / num_attention_heads;
    }

    /**
     * @brief Validate configuration parameters
     * @return true if configuration is valid
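     *
     * For example, a configuration with hidden_size = 4096,
     * num_attention_heads = 32, and head_dim = 64 (hypothetical values) fails
     * the last check, since hidden_size / num_attention_heads = 128.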
     */
    bool validate() const {
        if (hidden_size % num_attention_heads != 0) {
            return false;
        }
        if (num_attention_heads % num_key_value_heads != 0) {
            return false;
        }
        if (head_dim != hidden_size / num_attention_heads) {
            return false;
        }
        return true;
    }
};

} // namespace infinilm::models::llama