Commit 777b3233 authored by Ceng23333's avatar Ceng23333
Browse files

do assertion at load_parameter && update Module definition with macros


Signed-off-by: default avatarCeng23333 <441651826@qq.com>
parent 69c1c352
......@@ -75,7 +75,7 @@ public:
protected:
// Parameters
Parameter weight_;
INFINICORE_NN_PARAMETER(weight);
private:
size_t num_embeddings_; // Vocabulary size
......
......@@ -7,7 +7,7 @@ namespace infinicore::nn {
class Linear : public Module {
public:
Linear(size_t in_features, size_t out_features, bool bias = true, const Device &device = Device());
Linear(size_t in_features, size_t out_features, bool bias = true, const DataType &dtype = DataType::F32, const Device &device = Device());
// Forward pass: output = input @ weight.T + bias
Tensor forward(Tensor &input) const;
......@@ -20,6 +20,7 @@ public:
size_t in_features() const { return in_features_; }
size_t out_features() const { return out_features_; }
bool has_bias() const { return has_bias_; }
DataType dtype() const { return dtype_; }
// String representation
std::string extra_repr() const;
......@@ -30,8 +31,8 @@ public:
protected:
// Parameters
Parameter weight_;
Parameter bias_;
INFINICORE_NN_PARAMETER(weight);
INFINICORE_NN_PARAMETER(bias);
private:
// Helper method for common forward computation
......@@ -40,6 +41,7 @@ private:
size_t in_features_;
size_t out_features_;
bool has_bias_;
DataType dtype_;
};
} // namespace infinicore::nn
......@@ -125,13 +125,13 @@ private:
// Declare a parameter member variable
#define INFINICORE_NN_PARAMETER(name) \
Parameter name##_
infinicore::nn::Parameter name##_
// Initialize a parameter in constructor
// Usage: INFINICORE_NN_PARAMETER_INIT(name, (shape, dtype, device))
// Example: INFINICORE_NN_PARAMETER_INIT(weight, ({out_features, in_features}, DataType::F32, device))
#define INFINICORE_NN_PARAMETER_INIT(name, args) \
name##_ = Parameter args; \
name##_ = infinicore::nn::Parameter args; \
this->register_parameter(#name, name##_)
} // namespace infinicore::nn
......@@ -36,10 +36,12 @@ public:
*
* @param normalized_shape Size of the feature dimension to normalize (typically hidden_size)
* @param eps Small constant for numerical stability (default: 1e-6)
* @param dtype Data type for the weight (default: DataType::F32)
* @param device Device to create the weight on
*/
RMSNorm(size_t normalized_shape,
double eps = 1e-6,
const DataType &dtype = DataType::F32,
const Device &device = Device());
/**
......@@ -58,6 +60,7 @@ public:
// Module information
size_t normalized_shape() const { return normalized_shape_; }
double eps() const { return eps_; }
DataType dtype() const { return dtype_; }
// String representation
std::string extra_repr() const;
......@@ -67,11 +70,12 @@ public:
protected:
// Parameters
Parameter weight_;
INFINICORE_NN_PARAMETER(weight);
private:
size_t normalized_shape_; // Size of the feature dimension
double eps_; // Epsilon for numerical stability
DataType dtype_; // Data type for weight
};
} // namespace infinicore::nn
......@@ -394,7 +394,7 @@ TestResult NNModuleTest::testModuleLinear() {
try {
// Test with bias
spdlog::info("Testing Linear module with bias (8->4 features)");
infinicore::nn::Linear m1(8, 4, true, infinicore::Device());
infinicore::nn::Linear m1(8, 4, true);
auto sd1 = m1.state_dict();
if (sd1.find("weight") == sd1.end()) {
spdlog::error("weight missing");
......@@ -440,7 +440,7 @@ TestResult NNModuleTest::testModuleLinear() {
// Test without bias
spdlog::info("Testing Linear module without bias (16->3 features)");
infinicore::nn::Linear m2(16, 3, false, infinicore::Device());
infinicore::nn::Linear m2(16, 3, false);
auto sd2 = m2.state_dict();
if (sd2.find("weight") == sd2.end()) {
spdlog::error("weight missing (no-bias)");
......@@ -834,7 +834,7 @@ TestResult NNModuleTest::testModuleRMSNorm() {
// Test 1: Basic RMSNorm creation
spdlog::info("Test 1: Basic RMSNorm creation (hidden_size=768)");
infinicore::nn::RMSNorm norm1(768, 1e-6, infinicore::Device());
infinicore::nn::RMSNorm norm1(768);
auto state1 = norm1.state_dict();
if (state1.find("weight") == state1.end()) {
......@@ -925,8 +925,8 @@ TestResult NNModuleTest::testModuleRMSNorm() {
// Test 7: Different hidden sizes
spdlog::info("Test 7: Testing different hidden sizes");
infinicore::nn::RMSNorm norm_small(128, 1e-5, infinicore::Device());
infinicore::nn::RMSNorm norm_large(4096, 1e-6, infinicore::Device());
infinicore::nn::RMSNorm norm_small(128, 1e-5);
infinicore::nn::RMSNorm norm_large(4096);
auto input_small = infinicore::Tensor::ones({2, 128}, infinicore::DataType::F32, infinicore::Device());
auto output_small = norm_small.forward(input_small);
......@@ -956,7 +956,130 @@ TestResult NNModuleTest::testModuleRMSNorm() {
});
}
// Test 8: Comprehensive Tiny-Llama model test (construction + weight loading + validation)
// Test 8: Dtype assertion test
TestResult NNModuleTest::testDtypeAssertion() {
    return measureTime("DtypeAssertionTest", [this]() {
        try {
            spdlog::info("Testing dtype assertions when loading parameters");

            // Small helper used by every negative test below.
            const auto contains = [](const std::string &haystack, const char *needle) {
                return haystack.find(needle) != std::string::npos;
            };

            // Test 1: Successful load with matching dtype
            spdlog::info("Test 1: Successful load with matching dtype (F32)");
            infinicore::nn::Linear linear1(8, 4, true);
            std::unordered_map<std::string, infinicore::Tensor> matching_state;
            matching_state.emplace("weight", infinicore::Tensor::ones({4, 8}, infinicore::DataType::F32, infinicore::Device()));
            matching_state.emplace("bias", infinicore::Tensor::ones({4}, infinicore::DataType::F32, infinicore::Device()));

            // This should succeed without throwing
            linear1.load_state_dict(matching_state);
            spdlog::debug("✓ Matching dtype load succeeded");

            // Test 2: Failed load with mismatched dtype (load_parameter)
            spdlog::info("Test 2: Failed load_parameter with mismatched dtype");
            infinicore::nn::Linear linear2(8, 4, true);
            auto bf16_linear_weight = infinicore::Tensor::ones({4, 8}, infinicore::DataType::BF16, infinicore::Device());

            bool caught = false;
            try {
                linear2.load_parameter("weight", bf16_linear_weight);
            } catch (const std::runtime_error &e) {
                caught = true;
                const std::string error_msg{e.what()};
                if (!contains(error_msg, "dtype mismatch")) {
                    spdlog::error("Exception message doesn't contain 'dtype mismatch'");
                    return false;
                }
                spdlog::debug("✓ Mismatched dtype exception caught: {}", error_msg);
            }
            if (!caught) {
                spdlog::error("Expected exception for dtype mismatch in load_parameter");
                return false;
            }

            // Test 3: Failed load with mismatched dtype (load_state_dict)
            spdlog::info("Test 3: Failed load_state_dict with mismatched dtype");
            infinicore::nn::Embedding embedding1(100, 64);
            std::unordered_map<std::string, infinicore::Tensor> bad_embed_state;
            bad_embed_state.emplace("weight", infinicore::Tensor::ones({100, 64}, infinicore::DataType::BF16, infinicore::Device()));

            caught = false;
            try {
                embedding1.load_state_dict(bad_embed_state);
            } catch (const std::runtime_error &e) {
                caught = true;
                const std::string error_msg{e.what()};
                if (!contains(error_msg, "dtype mismatch")) {
                    spdlog::error("Exception message doesn't contain 'dtype mismatch'");
                    return false;
                }
                // The message must also name the offending parameter.
                if (!contains(error_msg, "weight")) {
                    spdlog::error("Exception message doesn't contain parameter name 'weight'");
                    return false;
                }
                spdlog::debug("✓ Mismatched dtype exception caught: {}", error_msg);
            }
            if (!caught) {
                spdlog::error("Expected exception for dtype mismatch in load_state_dict");
                return false;
            }

            // Test 4: Failed load with mismatched dtype (RMSNorm)
            spdlog::info("Test 4: Failed load_state_dict with mismatched dtype (RMSNorm)");
            infinicore::nn::RMSNorm norm1(768);
            std::unordered_map<std::string, infinicore::Tensor> bad_norm_state;
            bad_norm_state.emplace("weight", infinicore::Tensor::ones({768}, infinicore::DataType::BF16, infinicore::Device()));

            caught = false;
            try {
                norm1.load_state_dict(bad_norm_state);
            } catch (const std::runtime_error &e) {
                caught = true;
                const std::string error_msg{e.what()};
                if (!contains(error_msg, "dtype mismatch")) {
                    spdlog::error("Exception message doesn't contain 'dtype mismatch'");
                    return false;
                }
                spdlog::debug("✓ Mismatched dtype exception caught for RMSNorm: {}", error_msg);
            }
            if (!caught) {
                spdlog::error("Expected exception for dtype mismatch in RMSNorm load_state_dict");
                return false;
            }

            // Test 5: Successful load with different module dtypes
            spdlog::info("Test 5: Successful load with BF16 dtype (module created with BF16)");
            infinicore::nn::Linear linear3(8, 4, true, infinicore::DataType::BF16);
            std::unordered_map<std::string, infinicore::Tensor> bf16_state;
            bf16_state.emplace("weight", infinicore::Tensor::ones({4, 8}, infinicore::DataType::BF16, infinicore::Device()));
            bf16_state.emplace("bias", infinicore::Tensor::ones({4}, infinicore::DataType::BF16, infinicore::Device()));

            // This should succeed
            linear3.load_state_dict(bf16_state);
            spdlog::debug("✓ BF16 dtype load succeeded");

            spdlog::info("All dtype assertion tests passed!");
            return true;
        } catch (const std::exception &e) {
            spdlog::error("Exception in testDtypeAssertion: {}", e.what());
            return false;
        }
    });
}
// Test 9: Comprehensive Tiny-Llama model test (construction + weight loading + validation)
TestResult NNModuleTest::testTinyLlamaConstruction() {
return measureTime("TinyLlamaModelTest", [this]() {
try {
......@@ -1007,10 +1130,10 @@ TestResult NNModuleTest::testTinyLlamaConstruction() {
INFINICORE_NN_MODULE(infinicore::nn::Linear, o_proj);
SelfAttn(size_t hidden_size, size_t kv_dim, const infinicore::Device &device) {
INFINICORE_NN_MODULE_INIT(q_proj, hidden_size, hidden_size, false, device);
INFINICORE_NN_MODULE_INIT(k_proj, hidden_size, kv_dim, false, device);
INFINICORE_NN_MODULE_INIT(v_proj, hidden_size, kv_dim, false, device);
INFINICORE_NN_MODULE_INIT(o_proj, hidden_size, hidden_size, false, device);
INFINICORE_NN_MODULE_INIT(q_proj, hidden_size, hidden_size, false, infinicore::DataType::F32, device);
INFINICORE_NN_MODULE_INIT(k_proj, hidden_size, kv_dim, false, infinicore::DataType::F32, device);
INFINICORE_NN_MODULE_INIT(v_proj, hidden_size, kv_dim, false, infinicore::DataType::F32, device);
INFINICORE_NN_MODULE_INIT(o_proj, hidden_size, hidden_size, false, infinicore::DataType::F32, device);
}
};
......@@ -1021,9 +1144,9 @@ TestResult NNModuleTest::testTinyLlamaConstruction() {
INFINICORE_NN_MODULE(infinicore::nn::Linear, down_proj);
MLP(size_t hidden_size, size_t intermediate_size, const infinicore::Device &device) {
INFINICORE_NN_MODULE_INIT(gate_proj, hidden_size, intermediate_size, false, device);
INFINICORE_NN_MODULE_INIT(up_proj, hidden_size, intermediate_size, false, device);
INFINICORE_NN_MODULE_INIT(down_proj, intermediate_size, hidden_size, false, device);
INFINICORE_NN_MODULE_INIT(gate_proj, hidden_size, intermediate_size, false, infinicore::DataType::F32, device);
INFINICORE_NN_MODULE_INIT(up_proj, hidden_size, intermediate_size, false, infinicore::DataType::F32, device);
INFINICORE_NN_MODULE_INIT(down_proj, intermediate_size, hidden_size, false, infinicore::DataType::F32, device);
}
};
......@@ -1036,9 +1159,9 @@ TestResult NNModuleTest::testTinyLlamaConstruction() {
Block(const TinyLlamaConfig &cfg, const infinicore::Device &device) {
size_t kv_dim = cfg.hidden_size * cfg.num_key_value_heads / cfg.num_attention_heads;
INFINICORE_NN_MODULE_INIT(input_layernorm, cfg.hidden_size, cfg.rms_norm_eps, device);
INFINICORE_NN_MODULE_INIT(input_layernorm, cfg.hidden_size, cfg.rms_norm_eps, infinicore::DataType::F32, device);
INFINICORE_NN_MODULE_INIT(self_attn, cfg.hidden_size, kv_dim, device);
INFINICORE_NN_MODULE_INIT(post_attention_layernorm, cfg.hidden_size, cfg.rms_norm_eps, device);
INFINICORE_NN_MODULE_INIT(post_attention_layernorm, cfg.hidden_size, cfg.rms_norm_eps, infinicore::DataType::F32, device);
INFINICORE_NN_MODULE_INIT(mlp, cfg.hidden_size, cfg.intermediate_size, device);
}
};
......@@ -1051,7 +1174,7 @@ TestResult NNModuleTest::testTinyLlamaConstruction() {
TinyLlamaModel(const TinyLlamaConfig &config, const infinicore::Device &device) {
INFINICORE_NN_MODULE_INIT(embed_tokens, config.vocab_size, config.hidden_size, std::nullopt, infinicore::DataType::F32, device);
INFINICORE_NN_MODULE_VEC_INIT(layers, config.num_hidden_layers, Block, config, device);
INFINICORE_NN_MODULE_INIT(norm, config.hidden_size, config.rms_norm_eps, device);
INFINICORE_NN_MODULE_INIT(norm, config.hidden_size, config.rms_norm_eps, infinicore::DataType::F32, device);
}
};
......@@ -1259,6 +1382,7 @@ TestResult NNModuleTest::run() {
results.push_back(testModuleLinear()); // Linear module comprehensive test
results.push_back(testModuleEmbedding()); // Embedding module test
results.push_back(testModuleRMSNorm()); // RMSNorm module test
results.push_back(testDtypeAssertion()); // Dtype assertion test
results.push_back(testTinyLlamaConstruction()); // Comprehensive: TinyLlama model test
// Check if all tests passed
......
......@@ -21,16 +21,21 @@ namespace infinicore::test {
// Simple test module that mimics torch.nn.Linear
class MockLinearModule : public infinicore::nn::Module {
public:
// Declare parameters using macros (torch-like style)
INFINICORE_NN_PARAMETER(weight);
INFINICORE_NN_PARAMETER(bias);
MockLinearModule(int input_size, int output_size, const infinicore::Device &device)
: input_size_(input_size), output_size_(output_size), device_(device) {
// Initialize weight parameter (similar to torch.nn.Linear.weight)
register_parameter("weight",
infinicore::nn::Parameter({static_cast<size_t>(output_size), static_cast<size_t>(input_size)}, infinicore::DataType::F32, device));
// Initialize bias parameter (similar to torch.nn.Linear.bias)
register_parameter("bias",
infinicore::nn::Parameter({static_cast<size_t>(output_size)}, infinicore::DataType::F32, device));
// Initialize parameters using macros
INFINICORE_NN_PARAMETER_INIT(weight,
({static_cast<size_t>(output_size), static_cast<size_t>(input_size)},
infinicore::DataType::F32,
device));
INFINICORE_NN_PARAMETER_INIT(bias,
({static_cast<size_t>(output_size)},
infinicore::DataType::F32,
device));
}
// Simple forward pass (conceptual - would need actual matrix operations)
......@@ -77,6 +82,7 @@ private:
TestResult testModuleLinear(); // Comprehensive Linear module test
TestResult testModuleEmbedding(); // Embedding module test
TestResult testModuleRMSNorm(); // RMSNorm module test
TestResult testDtypeAssertion(); // Test dtype assertions when loading parameters
TestResult testTinyLlamaConstruction(); // Comprehensive: construction + weight loading + validation
};
......
......@@ -4,25 +4,26 @@
namespace infinicore::nn {
Linear::Linear(size_t in_features, size_t out_features, bool bias, const Device &device)
Linear::Linear(size_t in_features, size_t out_features, bool bias, const DataType &dtype, const Device &device)
: in_features_(in_features),
out_features_(out_features),
has_bias_(bias) {
has_bias_(bias),
dtype_(dtype) {
device_ = device;
// Initialize parameters using macro
INFINICORE_NN_PARAMETER_INIT(weight, ({out_features, in_features}, DataType::F32, device));
INFINICORE_NN_PARAMETER_INIT(weight, ({out_features, in_features}, dtype_, device));
// Register bias parameter if requested
if (bias) {
INFINICORE_NN_PARAMETER_INIT(bias, ({out_features}, DataType::F32, device));
INFINICORE_NN_PARAMETER_INIT(bias, ({out_features}, dtype_, device));
} else {
bias_ = Parameter(); // Default constructed empty parameter
}
spdlog::debug("Created Linear module: in_features={}, out_features={}, bias={}",
in_features, out_features, bias);
spdlog::debug("Created Linear module: in_features={}, out_features={}, bias={}, dtype={}",
in_features, out_features, bias, static_cast<int>(dtype_));
}
Tensor Linear::compute_linear(Tensor &input) const {
......@@ -41,12 +42,9 @@ Tensor Linear::compute_linear(Tensor &input) const {
strides.push_back(bias_->stride(0));
auto bias_view = bias_->as_strided(output->shape(), strides);
// First set output to bias (broadcasted)
infinicore::op::rearrange_(output, bias_view);
// Compute matmul result separately, then add to output
auto matmul_result = infinicore::op::matmul(input, weight_t);
infinicore::op::add_(output, output, matmul_result);
infinicore::op::matmul_(output, input, weight_t);
infinicore::op::add_(output, output, bias_view);
} else {
// No bias: just compute output = input @ weight_t
infinicore::op::matmul_(output, input, weight_t);
......@@ -69,7 +67,7 @@ Tensor Linear::forward(Tensor &input, Tensor &residual) const {
}
std::string Linear::extra_repr() const {
return "Linear(in_features=" + std::to_string(in_features_) + ", out_features=" + std::to_string(out_features_) + ", bias=" + (has_bias_ ? "true" : "false") + ")";
return "Linear(in_features=" + std::to_string(in_features_) + ", out_features=" + std::to_string(out_features_) + ", bias=" + (has_bias_ ? "true" : "false") + ", dtype=" + std::to_string(static_cast<int>(dtype_)) + ")";
}
} // namespace infinicore::nn
#include "infinicore/nn/module.hpp"
#include <stdexcept>
namespace infinicore::nn {
const std::unordered_map<std::string, Parameter> &Module::state_dict() const {
......@@ -20,13 +21,28 @@ void Module::load_state_dict(const std::unordered_map<std::string, Tensor> &_sta
// Look up the corresponding tensor in the input state dict using the full name
auto it = _state_dict.find(param_full_name);
if (it != _state_dict.end()) {
// Assert dtype matches
if (param->dtype() != it->second->dtype()) {
throw std::runtime_error(
"dtype mismatch for parameter '" + param_full_name + "': "
"expected "
+ std::to_string(static_cast<int>(param->dtype())) + ", got " + std::to_string(static_cast<int>(it->second->dtype())));
}
param->copy_from(it->second);
}
}
}
void Module::load_parameter(const std::string &name, const Tensor &param) {
    // Look up the target parameter explicitly. Using parameters_[name] would
    // default-insert a null Parameter for an unknown name and then dereference
    // it below (undefined behavior), while also polluting the parameter map.
    auto it = parameters_.find(name);
    if (it == parameters_.end()) {
        throw std::runtime_error("unknown parameter '" + name + "'");
    }
    auto existing_param = it->second;

    // Assert dtype matches before copying data into the parameter storage.
    if (existing_param->dtype() != param->dtype()) {
        throw std::runtime_error(
            "dtype mismatch for parameter '" + name + "': "
            "expected "
            + std::to_string(static_cast<int>(existing_param->dtype())) + ", got " + std::to_string(static_cast<int>(param->dtype())));
    }

    existing_param->copy_from(param);
}
void Module::load_parameter_from_blob(const std::string &name, const void *data) {
......
......@@ -6,21 +6,22 @@
namespace infinicore::nn {
RMSNorm::RMSNorm(size_t normalized_shape, double eps, const Device &device)
RMSNorm::RMSNorm(size_t normalized_shape, double eps, const DataType &dtype, const Device &device)
: normalized_shape_(normalized_shape),
eps_(eps) {
eps_(eps),
dtype_(dtype) {
device_ = device;
// Initialize parameter using macro
INFINICORE_NN_PARAMETER_INIT(weight, ({normalized_shape}, DataType::F32, device));
INFINICORE_NN_PARAMETER_INIT(weight, ({normalized_shape}, dtype_, device));
// Initialize weight to ones (standard practice for RMSNorm)
auto ones_tensor = Tensor::ones({normalized_shape}, DataType::F32, device);
auto ones_tensor = Tensor::ones({normalized_shape}, dtype_, device);
weight_->copy_from(ones_tensor);
spdlog::debug("Created RMSNorm module: normalized_shape={}, eps={}",
normalized_shape, eps);
spdlog::debug("Created RMSNorm module: normalized_shape={}, eps={}, dtype={}",
normalized_shape, eps, static_cast<int>(dtype_));
}
Tensor RMSNorm::forward(const Tensor &x) const {
......@@ -37,7 +38,7 @@ Tensor RMSNorm::forward(const Tensor &x) const {
}
std::string RMSNorm::extra_repr() const {
return "RMSNorm(normalized_shape=" + std::to_string(normalized_shape_) + ", eps=" + std::to_string(eps_) + ")";
return "RMSNorm(normalized_shape=" + std::to_string(normalized_shape_) + ", eps=" + std::to_string(eps_) + ", dtype=" + std::to_string(static_cast<int>(dtype_)) + ")";
}
} // namespace infinicore::nn
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment