// model_factory.cpp: InfinilmModelFactory implementation.
#include "model_factory.hpp"
#include "llama/llama.hpp"

namespace infinilm {
5
6
7
8
std::shared_ptr<InfinilmModel> InfinilmModelFactory::createModel(
    const std::any &config,
    engine::distributed::RankInfo rank_info,
    std::shared_ptr<cache::DynamicCache> cache_ptr) {
9
10
11

    if (config.type() == typeid(models::llama::LlamaConfig)) {
        const auto &llama_config = std::any_cast<models::llama::LlamaConfig>(config);
12
13
14
15
16
17
18
19
        auto model = std::make_shared<models::llama::LlamaForCausalLM>(
            llama_config, rank_info.device, infinicore::DataType::BF16, rank_info);

        if (cache_ptr != nullptr) {
            model->model().set_external_cache(cache_ptr);
        }

        return model;
20
21
22
23
24
    } else {
        throw std::invalid_argument("InfinilmModelFactory::createModel: Unsupported model config type");
    }
}
} // namespace infinilm