Commit 420369bd authored by PanZezhong

issue/719 Module: support loading a single tensor

parent e1c836b8
 #pragma once
-#include "parameter.hpp"
 #include "../tensor.hpp"
+#include "parameter.hpp"
-#include <unordered_map>
 #include <type_traits>
+#include <unordered_map>
 #include <vector>

 namespace infinicore::nn {
@@ -18,6 +18,8 @@ public:
     void load_parameter(const std::string &name, const Tensor &param);
+    void load_parameter_(const std::string &name, const Tensor &param);
     void load_parameter_from_blob(const std::string &name, const void *data);

 protected:
@@ -135,7 +137,7 @@ private:
 // Usage: INFINICORE_NN_PARAMETER_INIT(name, (shape, dtype, device))
 // Example: INFINICORE_NN_PARAMETER_INIT(weight, ({out_features, in_features}, DataType::F32, device))
 #define INFINICORE_NN_PARAMETER_INIT(name, args) \
     name##_ = infinicore::nn::Parameter args;    \
     this->register_parameter(#name, name##_)

 // Declare a buffer member variable
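For reference, this is the pattern the macro comments describe, in a minimal custom module. This is a sketch, not part of the commit: `MyLayer` is hypothetical, the `module.hpp` include path is assumed, and `Parameter` is assumed to be default-constructible.

```cpp
#include "module.hpp" // assumed location of infinicore::nn::Module

// Hypothetical module illustrating INFINICORE_NN_PARAMETER_INIT.
class MyLayer : public infinicore::nn::Module {
public:
    MyLayer(size_t out_features, size_t in_features,
            infinicore::DataType dtype, infinicore::Device device) {
        // Expands to: weight_ = infinicore::nn::Parameter({...}, dtype, device);
        //             this->register_parameter("weight", weight_);
        INFINICORE_NN_PARAMETER_INIT(weight, ({out_features, in_features}, dtype, device));
    }

private:
    infinicore::nn::Parameter weight_; // the macro targets the name##_ member
};
```

Once constructed, a single tensor can be pushed into such a module with `layer.load_parameter_("weight", tensor)`, the entry point this commit adds.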
...
@@ -90,18 +90,18 @@ TestResult NNModuleTest::testBasicModuleCreation() {
     auto new_weight = infinicore::Tensor::ones({4, 8}, infinicore::DataType::F32, infinicore::Device());
     auto new_bias = infinicore::Tensor::zeros({4}, infinicore::DataType::F32, infinicore::Device());

-    // Load using load_parameter
-    module.load_parameter("weight", new_weight);
-    module.load_parameter("bias", new_bias);
+    // Load using load_parameter_
+    module.load_parameter_("weight", new_weight);
+    module.load_parameter_("bias", new_bias);

     // Verify the parameters were updated
     auto updated_state_dict = module.state_dict();
     if (!tensorsAllClose(updated_state_dict.at("weight"), new_weight, 1e-6, 1e-6)) {
-        spdlog::error("Weight parameter values do not match after load_parameter");
+        spdlog::error("Weight parameter values do not match after load_parameter_");
         return false;
     }
     if (!tensorsAllClose(updated_state_dict.at("bias"), new_bias, 1e-6, 1e-6)) {
-        spdlog::error("Bias parameter values do not match after load_parameter");
+        spdlog::error("Bias parameter values do not match after load_parameter_");
         return false;
     }
@@ -1493,14 +1493,14 @@ TestResult NNModuleTest::testDtypeAssertion() {
     linear1.load_state_dict(matching_state);
     spdlog::debug("✓ Matching dtype load succeeded");

-    // Test 2: Failed load with mismatched dtype (load_parameter)
-    spdlog::info("Test 2: Failed load_parameter with mismatched dtype");
+    // Test 2: Failed load with mismatched dtype (load_parameter_)
+    spdlog::info("Test 2: Failed load_parameter_ with mismatched dtype");
     infinicore::nn::Linear linear2(8, 4, true);
     auto mismatched_weight = infinicore::Tensor::ones({4, 8}, infinicore::DataType::BF16, infinicore::Device());
     bool exception_thrown = false;
     try {
-        linear2.load_parameter("weight", mismatched_weight);
+        linear2.load_parameter_("weight", mismatched_weight);
     } catch (const std::runtime_error &e) {
         exception_thrown = true;
         std::string error_msg = e.what();
@@ -1512,7 +1512,7 @@ TestResult NNModuleTest::testDtypeAssertion() {
     }

     if (!exception_thrown) {
-        spdlog::error("Expected exception for dtype mismatch in load_parameter");
+        spdlog::error("Expected exception for dtype mismatch in load_parameter_");
         return false;
     }
...
@@ -13,14 +13,14 @@ Runtime *ContextImpl::getCurrentRuntime() {
         // Try to find the first non-CPU device, fallback to CPU
         for (int i = int(Device::Type::COUNT) - 1; i > 0; i--) {
             if (!runtime_table_[i].empty() && runtime_table_[i][0] != nullptr) {
-                current_runtime_ = runtime_table_[i][0].get();
+                current_runtime_ = runtime_table_[i][0].get()->activate();
                 spdlog::debug("Lazy init: Set current_runtime_ to {} (ptr={})", current_runtime_->device().toString(), static_cast<void *>(current_runtime_));
                 return current_runtime_;
             }
         }
         // Fallback to CPU runtime
         if (!runtime_table_[0].empty() && runtime_table_[0][0] != nullptr) {
-            current_runtime_ = runtime_table_[0][0].get();
+            current_runtime_ = runtime_table_[0][0].get()->activate();
             spdlog::debug("Lazy init: Set current_runtime_ to {} (ptr={})", current_runtime_->device().toString(), static_cast<void *>(current_runtime_));
         }
     } else {
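The lazy-init path now calls `activate()` on the runtime before caching it, so fetching the current runtime also makes its device context current. A rough sketch of the contract this diff implies; this is an assumption, not the actual `Runtime` implementation:

```cpp
// Hypothetical sketch: activate() binds the runtime's device context to
// the calling thread and returns `this`, so that
// `runtime_table_[i][0].get()->activate()` selects and activates in one step.
struct Runtime {
    Runtime *activate() {
        // A backend-specific device binding (e.g. a cudaSetDevice-style
        // call) would happen here; omitted in this sketch.
        return this;
    }
};
```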
...
@@ -17,6 +17,22 @@ void Module::load_state_dict(const std::unordered_map<std::string, Tensor> &_sta
 }

 void Module::load_parameter(const std::string &name, const Tensor &param) {
+    // Resolve the name against the full state_dict (may include submodule parameters)
+    auto all_params = state_dict();
+    auto it = all_params.find(name);
+    if (it != all_params.end()) {
+        auto existing_param = it->second;
+        existing_param.load(param);
+        return;
+    }
+
+    // Parameter not found
+    spdlog::debug("load_parameter: Parameter '{}' not found. Available: {} params",
+                  name, parameters_.size());
+    throw std::runtime_error("Parameter '" + name + "' not found in module.");
+}
+
+void Module::load_parameter_(const std::string &name, const Tensor &param) {
     // This function only handles direct parameters (no hierarchical traversal)
     auto it = parameters_.find(name);
     if (it != parameters_.end()) {

@@ -33,7 +49,7 @@ void Module::load_parameter(const std::string &name, const Tensor &param) {
     }

     // Parameter not found
-    spdlog::debug("load_parameter: Parameter '{}' not found. Available: {} params",
+    spdlog::debug("load_parameter_: Parameter '{}' not found. Available: {} params",
                   name, parameters_.size());
     throw std::runtime_error("Parameter '" + name + "' not found in module.");
 }
@@ -59,7 +75,7 @@ void Module::load_state_dict_recursively(const std::unordered_map<std::string, T
         std::string full_name = prefix.empty() ? param_name : prefix + "." + param_name;
         auto it = _state_dict.find(full_name);
         if (it != _state_dict.end()) {
-            load_parameter(param_name, it->second);
+            load_parameter_(param_name, it->second);
         }
     }
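Taken together, the two entry points now split the work: `load_parameter` resolves the name through the full `state_dict()`, while `load_parameter_` (used by `load_state_dict_recursively` above) consults only the module's own `parameters_` map. A usage sketch mirroring the tests in this commit; the hierarchical-name behavior is an assumption based on the dotted `full_name` construction above:

```cpp
infinicore::nn::Linear linear(8, 4, true); // constructor form taken from the tests
auto w = infinicore::Tensor::ones({4, 8}, infinicore::DataType::F32,
                                  infinicore::Device());

linear.load_parameter_("weight", w); // direct lookup in parameters_
linear.load_parameter("weight", w);  // resolved via state_dict()

// An unknown name throws on both paths:
// std::runtime_error("Parameter '...' not found in module.")
```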
...
@@ -12,12 +12,7 @@ RMSNorm::RMSNorm(size_t normalized_shape, double eps, const DataType &dtype, con
     device_ = device;

-    // Initialize parameter using macro
     INFINICORE_NN_PARAMETER_INIT(weight, ({normalized_shape}, dtype_, device));
-
-    // Initialize weight to ones (standard practice for RMSNorm)
-    auto ones_tensor = Tensor::ones({normalized_shape}, dtype_, device);
-    weight_->copy_from(ones_tensor);
 }

 Tensor RMSNorm::forward(const Tensor &x) const {
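With the explicit ones-initialization removed, an RMSNorm weight keeps whatever the `Parameter` constructor produces until something loads it. A caller that relied on the old default can restore it with the new single-tensor API; a sketch under assumptions (the constructor arguments follow the truncated signature in the hunk header, and the defaults shown are guesses):

```cpp
// Sketch: re-create the old ones-initialized weight from the outside.
infinicore::nn::RMSNorm norm(1024, 1e-6,
                             infinicore::DataType::F32, infinicore::Device());
auto ones = infinicore::Tensor::ones({1024}, infinicore::DataType::F32,
                                     infinicore::Device());
norm.load_parameter_("weight", ones);
```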
...
@@ -162,6 +162,7 @@ std::string TensorImpl::info() const {
         ss << s << " ";
     }
     ss << "] dtype=" << toString(this->dtype());
+    ss << " device=" << this->device().toString();
     return ss.str();
 }
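With the device appended, `info()` now reports shape, dtype, and device on one line. A sketch of what a caller sees; the exact device string depends on `Device::toString()`, and the `->` accessor on `Tensor` is an assumption:

```cpp
auto t = infinicore::Tensor::ones({4, 8}, infinicore::DataType::F32,
                                  infinicore::Device());
spdlog::info(t->info()); // e.g. "[ 4 8 ] dtype=F32 device=cpu"
```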