Unverified Commit 0f5e66ce authored by PanZezhong1725's avatar PanZezhong1725 Committed by GitHub
Browse files

Merge pull request #721 from InfiniTensor/issue/719-a

 Module 支持单个张量加载 (Module: support loading a single tensor)
parents 9c4d4d1a 420369bd
#pragma once
#include "parameter.hpp"
#include "../tensor.hpp"
#include "parameter.hpp"
#include <unordered_map>
#include <type_traits>
#include <unordered_map>
#include <vector>
namespace infinicore::nn {
......@@ -18,6 +18,8 @@ public:
void load_parameter(const std::string &name, const Tensor &param);
void load_parameter_(const std::string &name, const Tensor &param);
void load_parameter_from_blob(const std::string &name, const void *data);
protected:
......@@ -135,7 +137,7 @@ private:
// Declare and register a parameter member in one step.
// Usage: INFINICORE_NN_PARAMETER_INIT(name, (shape, dtype, device))
// Example: INFINICORE_NN_PARAMETER_INIT(weight, ({out_features, in_features}, DataType::F32, device))
// NOTE: fixed a duplicated assignment line (diff artifact) that constructed
// the Parameter twice before registering it.
#define INFINICORE_NN_PARAMETER_INIT(name, args)  \
    name##_ = infinicore::nn::Parameter args;     \
    this->register_parameter(#name, name##_)
// Declare a buffer member variable
......
......@@ -90,18 +90,18 @@ TestResult NNModuleTest::testBasicModuleCreation() {
auto new_weight = infinicore::Tensor::ones({4, 8}, infinicore::DataType::F32, infinicore::Device());
auto new_bias = infinicore::Tensor::zeros({4}, infinicore::DataType::F32, infinicore::Device());
// Load using load_parameter
module.load_parameter("weight", new_weight);
module.load_parameter("bias", new_bias);
// Load using load_parameter_
module.load_parameter_("weight", new_weight);
module.load_parameter_("bias", new_bias);
// Verify the parameters were updated
auto updated_state_dict = module.state_dict();
if (!tensorsAllClose(updated_state_dict.at("weight"), new_weight, 1e-6, 1e-6)) {
spdlog::error("Weight parameter values do not match after load_parameter");
spdlog::error("Weight parameter values do not match after load_parameter_");
return false;
}
if (!tensorsAllClose(updated_state_dict.at("bias"), new_bias, 1e-6, 1e-6)) {
spdlog::error("Bias parameter values do not match after load_parameter");
spdlog::error("Bias parameter values do not match after load_parameter_");
return false;
}
......@@ -1493,14 +1493,14 @@ TestResult NNModuleTest::testDtypeAssertion() {
linear1.load_state_dict(matching_state);
spdlog::debug("✓ Matching dtype load succeeded");
// Test 2: Failed load with mismatched dtype (load_parameter)
spdlog::info("Test 2: Failed load_parameter with mismatched dtype");
// Test 2: Failed load with mismatched dtype (load_parameter_)
spdlog::info("Test 2: Failed load_parameter_ with mismatched dtype");
infinicore::nn::Linear linear2(8, 4, true);
auto mismatched_weight = infinicore::Tensor::ones({4, 8}, infinicore::DataType::BF16, infinicore::Device());
bool exception_thrown = false;
try {
linear2.load_parameter("weight", mismatched_weight);
linear2.load_parameter_("weight", mismatched_weight);
} catch (const std::runtime_error &e) {
exception_thrown = true;
std::string error_msg = e.what();
......@@ -1512,7 +1512,7 @@ TestResult NNModuleTest::testDtypeAssertion() {
}
if (!exception_thrown) {
spdlog::error("Expected exception for dtype mismatch in load_parameter");
spdlog::error("Expected exception for dtype mismatch in load_parameter_");
return false;
}
......
......@@ -13,14 +13,14 @@ Runtime *ContextImpl::getCurrentRuntime() {
// Try to find the first non-CPU device, fallback to CPU
for (int i = int(Device::Type::COUNT) - 1; i > 0; i--) {
if (!runtime_table_[i].empty() && runtime_table_[i][0] != nullptr) {
current_runtime_ = runtime_table_[i][0].get();
current_runtime_ = runtime_table_[i][0].get()->activate();
spdlog::debug("Lazy init: Set current_runtime_ to {} (ptr={})", current_runtime_->device().toString(), static_cast<void *>(current_runtime_));
return current_runtime_;
}
}
// Fallback to CPU runtime
if (!runtime_table_[0].empty() && runtime_table_[0][0] != nullptr) {
current_runtime_ = runtime_table_[0][0].get();
current_runtime_ = runtime_table_[0][0].get()->activate();
spdlog::debug("Lazy init: Set current_runtime_ to {} (ptr={})", current_runtime_->device().toString(), static_cast<void *>(current_runtime_));
}
} else {
......
......@@ -17,6 +17,22 @@ void Module::load_state_dict(const std::unordered_map<std::string, Tensor> &_sta
}
/// Load values into a single named parameter (out-of-place semantics).
///
/// Looks the name up in the flattened state_dict() (so hierarchical names
/// such as "layer.weight" resolve too, despite the old comment claiming
/// otherwise) and loads `param` into the matching tensor.
///
/// @param name   Fully-qualified parameter name as it appears in state_dict().
/// @param param  Source tensor whose contents are loaded into the parameter.
/// @throws std::runtime_error if no parameter with that name exists.
void Module::load_parameter(const std::string &name, const Tensor &param) {
    // Resolve against the flattened state dict, which covers both direct
    // parameters and those of registered sub-modules.
    auto all_params = state_dict();
    auto it = all_params.find(name);
    if (it != all_params.end()) {
        auto existing_param = it->second;
        existing_param.load(param);
        return;
    }
    // Parameter not found: log with the *correct* function name (the old
    // message said "load_parameter_") and the size of the map we actually
    // searched (all_params, not just the direct parameters_ map).
    spdlog::debug("load_parameter: Parameter '{}' not found. Available: {} params",
                  name, all_params.size());
    throw std::runtime_error("Parameter '" + name + "' not found in module.");
}
void Module::load_parameter_(const std::string &name, const Tensor &param) {
// This function only handles direct parameters (no hierarchical traversal)
auto it = parameters_.find(name);
if (it != parameters_.end()) {
......@@ -33,7 +49,7 @@ void Module::load_parameter(const std::string &name, const Tensor &param) {
}
// Parameter not found
spdlog::debug("load_parameter: Parameter '{}' not found. Available: {} params",
spdlog::debug("load_parameter_: Parameter '{}' not found. Available: {} params",
name, parameters_.size());
throw std::runtime_error("Parameter '" + name + "' not found in module.");
}
......@@ -59,7 +75,7 @@ void Module::load_state_dict_recursively(const std::unordered_map<std::string, T
std::string full_name = prefix.empty() ? param_name : prefix + "." + param_name;
auto it = _state_dict.find(full_name);
if (it != _state_dict.end()) {
load_parameter(param_name, it->second);
load_parameter_(param_name, it->second);
}
}
......
......@@ -12,12 +12,7 @@ RMSNorm::RMSNorm(size_t normalized_shape, double eps, const DataType &dtype, con
device_ = device;
// Initialize parameter using macro
INFINICORE_NN_PARAMETER_INIT(weight, ({normalized_shape}, dtype_, device));
// Initialize weight to ones (standard practice for RMSNorm)
auto ones_tensor = Tensor::ones({normalized_shape}, dtype_, device);
weight_->copy_from(ones_tensor);
}
Tensor RMSNorm::forward(const Tensor &x) const {
......
......@@ -162,6 +162,7 @@ std::string TensorImpl::info() const {
ss << s << " ";
}
ss << "] dtype=" << toString(this->dtype());
ss << " device=" << this->device().toString();
return ss.str();
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment