#include "model_factory.hpp"

#include "llama/llama.hpp"

#include <memory>
#include <stdexcept>

namespace infinilm {
5
std::shared_ptr<InfinilmModel> InfinilmModelFactory::createModel(
Jiacheng Huang's avatar
Jiacheng Huang committed
6
    const InfinilmModel::Config &config,
7
    engine::distributed::RankInfo rank_info,
PanZezhong's avatar
PanZezhong committed
8
    const cache::CacheConfig *cache) {
9

PanZezhong's avatar
PanZezhong committed
10
    std::shared_ptr<InfinilmModel> model;
Jiacheng Huang's avatar
Jiacheng Huang committed
11
12
    if (const auto llama_config_ptr = dynamic_cast<const models::llama::LlamaConfig *>(&config)) {
        const auto &llama_config = *llama_config_ptr;
PanZezhong's avatar
PanZezhong committed
13
        model = std::make_shared<models::llama::LlamaForCausalLM>(
14
            llama_config, rank_info.device, rank_info);
15
16
17
    } else {
        throw std::invalid_argument("InfinilmModelFactory::createModel: Unsupported model config type");
    }
PanZezhong's avatar
PanZezhong committed
18
19
20
21
22
23

    if (cache) {
        model->reset_cache(cache);
    }

    return model;
24
25
}
} // namespace infinilm