// model_factory.cpp
#include "model_factory.hpp"
#include "llama/llama.hpp"

namespace infinilm {
5
std::shared_ptr<InfinilmModel> InfinilmModelFactory::createModel(
Jiacheng Huang's avatar
Jiacheng Huang committed
6
    const InfinilmModel::Config &config,
7
8
    engine::distributed::RankInfo rank_info,
    std::shared_ptr<cache::DynamicCache> cache_ptr) {
9

Jiacheng Huang's avatar
Jiacheng Huang committed
10
11
    if (const auto llama_config_ptr = dynamic_cast<const models::llama::LlamaConfig *>(&config)) {
        const auto &llama_config = *llama_config_ptr;
12
13
14
15
16
17
18
19
        auto model = std::make_shared<models::llama::LlamaForCausalLM>(
            llama_config, rank_info.device, infinicore::DataType::BF16, rank_info);

        if (cache_ptr != nullptr) {
            model->model().set_external_cache(cache_ptr);
        }

        return model;
20
21
22
23
24
    } else {
        throw std::invalid_argument("InfinilmModelFactory::createModel: Unsupported model config type");
    }
}
} // namespace infinilm