model_config.h 3.27 KB
Newer Older
1
2
3
4
#ifndef __MODEL_CONFIG_HPP_
#define __MODEL_CONFIG_HPP_

#include <cstddef>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <map>
#include <string>

#include "nlohmann/json.hpp"

// Unsigned size type used for model dimension fields such as hidden_size.
using DimSize = size_t;
// Text type for reference links (see QuantConfig::reference).
using URL = std::string;
// Key type of the model_configs registry.
using ModelName = std::string;

// We must assure this can be load by config.json
class ModelConfig {
16
public:
17
18
19
20
21
22
23
24
25
  DimSize hidden_size;
  DimSize intermediate_size;
  size_t max_position_embeddings;
  std::string model_type;
  size_t num_attention_heads;
  size_t num_hidden_layers;
  size_t num_key_value_heads;
  size_t vocab_size;

26
27
28
29
  NLOHMANN_DEFINE_TYPE_INTRUSIVE(ModelConfig, hidden_size, intermediate_size,
                                 max_position_embeddings, model_type,
                                 num_attention_heads, num_hidden_layers,
                                 num_key_value_heads, vocab_size);
30
31

  void load_from(std::filesystem::path path) {
32
    std::cout << "Load from " << path << std::endl;
33
34
35
36
37
38
39
40
41
42
43
    std::ifstream i(path);
    nlohmann::json j;
    i >> j;
    *this = j.get<ModelConfig>();
  }
};

// Name of a quantization type.
using QuantType = std::string;
// Sentinel meaning "no quantization type" (the empty string).
// Declared `inline` rather than `static` so that every translation unit
// including this header shares one object; a `static` copy per translation
// unit that is odr-used from inline functions is an ODR hazard.
inline const QuantType NoQuantType = "";

class QuantConfig {
44
public:
45
46
47
48
  QuantType name;

  // For GEMV
  QuantType type_of_dot_vector = NoQuantType;
49
50
51
  inline bool can_be_used_as_matrix() {
    return type_of_dot_vector != NoQuantType;
  }
52
53
54
55
56
57
58
59
60
61
62
63

  bool can_be_used_as_vector;

  double bytes_per_element;
  bool has_scale;
  bool has_min;

  size_t block_element_count;
  size_t block_element_size;

  URL reference = "";

64
65
66
67
68
  NLOHMANN_DEFINE_TYPE_INTRUSIVE_WITH_DEFAULT(QuantConfig, name,
                                              type_of_dot_vector,
                                              can_be_used_as_vector,
                                              bytes_per_element, has_scale,
                                              has_min, block_element_count,
69
70
71
72
73
74
75
76
                                              block_element_size, reference);
};

// Global registry of quantization configs keyed by quantization type name.
// Replaced as a whole by load_quant_configs() and written out by
// dump_quant_configs().
inline std::map<QuantType, QuantConfig> quant_configs;
// Global registry of model configs keyed by model name.
// Replaced as a whole by load_model_configs() and written out by
// dump_model_configs().
inline std::map<ModelName, ModelConfig> model_configs;

inline void load_quant_configs(std::filesystem::path path) {
  nlohmann::json j;
77
78
79
80
81
82
83
84
85
86
87
  if (std::filesystem::exists(path)) {
    std::cout << __FUNCTION__ << " from " << path << std::endl;
    std::ifstream i(path);
    i >> j;
    quant_configs = j.get<std::map<QuantType, QuantConfig>>();
    std::cout << "Loaded Quant Configs" << std::endl;
    for (auto &[k, v] : quant_configs) {
      std::cout << " - " << k << std::endl;
    }
  } else {
    std::cout << __FUNCTION__ << " no file at " << path << std::endl;
88
89
90
91
92
93
94
95
96
97
98
  }
}

inline void dump_quant_configs(std::filesystem::path path) {
  std::ofstream o(path);
  nlohmann::json j = quant_configs;
  o << j.dump(4);
}

inline void load_model_configs(std::filesystem::path path) {
  nlohmann::json j;
99
100
101
102
103
104
105
106
107
108
109
  if (std::filesystem::exists(path)) {
    std::cout << __FUNCTION__ << " from " << path << std::endl;
    std::ifstream i(path);
    i >> j;
    model_configs = j.get<std::map<ModelName, ModelConfig>>();
    std::cout << "Loaded Model Configs" << std::endl;
    for (auto &[k, v] : model_configs) {
      std::cout << " - " << k << std::endl;
    }
  } else {
    std::cout << __FUNCTION__ << " no file at " << path << std::endl;
110
111
112
113
114
115
116
117
118
119
  }
}

inline void dump_model_configs(std::filesystem::path path) {
  std::ofstream o(path);
  nlohmann::json j = model_configs;
  o << j.dump(4);
}

#endif