Commit 0c6803c6 authored by Ceng23333
Browse files

issue/697: fix load_state_dict


Signed-off-by: Ceng23333 <441651826@qq.com>
parent 986bb179
...@@ -78,6 +78,7 @@ protected: ...@@ -78,6 +78,7 @@ protected:
std::unordered_map<std::string, Parameter> parameters_; std::unordered_map<std::string, Parameter> parameters_;
private: private:
void load_state_dict_recursively(const std::unordered_map<std::string, Tensor> &_state_dict, const std::string &prefix = "");
void collect_all_parameters(std::unordered_map<std::string, Parameter> &all_params, const std::string &prefix = "") const; void collect_all_parameters(std::unordered_map<std::string, Parameter> &all_params, const std::string &prefix = "") const;
}; };
......
...@@ -13,23 +13,11 @@ const std::unordered_map<std::string, Parameter> &Module::state_dict() const { ...@@ -13,23 +13,11 @@ const std::unordered_map<std::string, Parameter> &Module::state_dict() const {
} }
void Module::load_state_dict(const std::unordered_map<std::string, Tensor> &_state_dict) { void Module::load_state_dict(const std::unordered_map<std::string, Tensor> &_state_dict) {
// Collect all parameters from this module and its submodules with their full hierarchical names load_state_dict_recursively(_state_dict, "");
std::unordered_map<std::string, Parameter> all_params;
collect_all_parameters(all_params, "");
// For each parameter in this module hierarchy, load from the state dict
for (auto &[param_full_name, param] : all_params) {
// Look up the corresponding tensor in the input state dict using the full name
auto it = _state_dict.find(param_full_name);
if (it != _state_dict.end()) {
this->load_parameter(param_full_name, it->second);
} else {
spdlog::warn("Parameter '{}' provided but not found in module.", param_full_name);
}
}
} }
void Module::load_parameter(const std::string &name, const Tensor &param) { void Module::load_parameter(const std::string &name, const Tensor &param) {
// This function only handles direct parameters (no hierarchical traversal)
auto it = parameters_.find(name); auto it = parameters_.find(name);
if (it != parameters_.end()) { if (it != parameters_.end()) {
auto existing_param = it->second; auto existing_param = it->second;
...@@ -41,9 +29,13 @@ void Module::load_parameter(const std::string &name, const Tensor &param) { ...@@ -41,9 +29,13 @@ void Module::load_parameter(const std::string &name, const Tensor &param) {
+ std::to_string(static_cast<int>(existing_param->dtype())) + ", got " + std::to_string(static_cast<int>(param->dtype()))); + std::to_string(static_cast<int>(existing_param->dtype())) + ", got " + std::to_string(static_cast<int>(param->dtype())));
} }
existing_param.load(param); existing_param.load(param);
} else { return;
throw std::runtime_error("Parameter '" + name + "' not found in module.");
} }
// Parameter not found
spdlog::debug("load_parameter: Parameter '{}' not found. Available: {} params",
name, parameters_.size());
throw std::runtime_error("Parameter '" + name + "' not found in module.");
} }
void Module::load_parameter_from_blob(const std::string &name, const void *data) { void Module::load_parameter_from_blob(const std::string &name, const void *data) {
...@@ -61,6 +53,23 @@ Tensor Module::register_buffer(const std::string &name, Parameter buffer) { ...@@ -61,6 +53,23 @@ Tensor Module::register_buffer(const std::string &name, Parameter buffer) {
return buffer; return buffer;
} }
void Module::load_state_dict_recursively(const std::unordered_map<std::string, Tensor> &_state_dict, const std::string &prefix) {
// Load direct parameters with the given prefix
for (const auto &[param_name, param] : parameters_) {
std::string full_name = prefix.empty() ? param_name : prefix + "." + param_name;
auto it = _state_dict.find(full_name);
if (it != _state_dict.end()) {
load_parameter(param_name, it->second);
}
}
// Recursively load parameters from submodules with extended prefix
for (const auto &[sub_name, submodule] : submodules_) {
std::string sub_prefix = prefix.empty() ? sub_name : prefix + "." + sub_name;
submodule->load_state_dict_recursively(_state_dict, sub_prefix);
}
}
void Module::collect_all_parameters(std::unordered_map<std::string, Parameter> &all_params, const std::string &prefix) const { void Module::collect_all_parameters(std::unordered_map<std::string, Parameter> &all_params, const std::string &prefix) const {
// Add direct parameters with the given prefix // Add direct parameters with the given prefix
for (const auto &[param_name, param] : parameters_) { for (const auto &[param_name, param] : parameters_) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment