#ifndef WEIGHTS_LOADER_HPP #define WEIGHTS_LOADER_HPP #include "../tensor.hpp" #include #include namespace infinicore { class WeightsLoader { protected: std::vector>> _weights; infiniDevice_t _device; std::vector _dev_ids; std::vector _streams; public: WeightsLoader(infiniDevice_t, const std::vector &dev_ids); void resigter(const std::string &name, std::shared_ptr tensor, int rank = 0); void load_weight(const std::string &name, const void *host_data); void load_distributed_weight(const std::string &name, const void *host_data, const std::vector &ranks); void load_rank_weight(const std::string &name, const void *host_data, int rank); void finalize(); std::shared_ptr get(const std::string &name, int rank = 0); const std::vector &dev_ids() const { return _dev_ids; } infiniDevice_t device() const { return _device; } }; } // namespace infinicore #endif // WEIGHTS_LOADER_HPP