Commit 777b3233 authored by Ceng23333's avatar Ceng23333
Browse files

Perform dtype assertion in load_parameter and update Module definition with macros


Signed-off-by: default avatarCeng23333 <441651826@qq.com>
parent 69c1c352
...@@ -75,7 +75,7 @@ public: ...@@ -75,7 +75,7 @@ public:
protected: protected:
// Parameters // Parameters
Parameter weight_; INFINICORE_NN_PARAMETER(weight);
private: private:
size_t num_embeddings_; // Vocabulary size size_t num_embeddings_; // Vocabulary size
......
...@@ -7,7 +7,7 @@ namespace infinicore::nn { ...@@ -7,7 +7,7 @@ namespace infinicore::nn {
class Linear : public Module { class Linear : public Module {
public: public:
Linear(size_t in_features, size_t out_features, bool bias = true, const Device &device = Device()); Linear(size_t in_features, size_t out_features, bool bias = true, const DataType &dtype = DataType::F32, const Device &device = Device());
// Forward pass: output = input @ weight.T + bias // Forward pass: output = input @ weight.T + bias
Tensor forward(Tensor &input) const; Tensor forward(Tensor &input) const;
...@@ -20,6 +20,7 @@ public: ...@@ -20,6 +20,7 @@ public:
size_t in_features() const { return in_features_; } size_t in_features() const { return in_features_; }
size_t out_features() const { return out_features_; } size_t out_features() const { return out_features_; }
bool has_bias() const { return has_bias_; } bool has_bias() const { return has_bias_; }
DataType dtype() const { return dtype_; }
// String representation // String representation
std::string extra_repr() const; std::string extra_repr() const;
...@@ -30,8 +31,8 @@ public: ...@@ -30,8 +31,8 @@ public:
protected: protected:
// Parameters // Parameters
Parameter weight_; INFINICORE_NN_PARAMETER(weight);
Parameter bias_; INFINICORE_NN_PARAMETER(bias);
private: private:
// Helper method for common forward computation // Helper method for common forward computation
...@@ -40,6 +41,7 @@ private: ...@@ -40,6 +41,7 @@ private:
size_t in_features_; size_t in_features_;
size_t out_features_; size_t out_features_;
bool has_bias_; bool has_bias_;
DataType dtype_;
}; };
} // namespace infinicore::nn } // namespace infinicore::nn
...@@ -125,13 +125,13 @@ private: ...@@ -125,13 +125,13 @@ private:
// Declare a parameter member variable // Declare a parameter member variable
#define INFINICORE_NN_PARAMETER(name) \ #define INFINICORE_NN_PARAMETER(name) \
Parameter name##_ infinicore::nn::Parameter name##_
// Initialize a parameter in constructor // Initialize a parameter in constructor
// Usage: INFINICORE_NN_PARAMETER_INIT(name, (shape, dtype, device)) // Usage: INFINICORE_NN_PARAMETER_INIT(name, (shape, dtype, device))
// Example: INFINICORE_NN_PARAMETER_INIT(weight, ({out_features, in_features}, DataType::F32, device)) // Example: INFINICORE_NN_PARAMETER_INIT(weight, ({out_features, in_features}, DataType::F32, device))
#define INFINICORE_NN_PARAMETER_INIT(name, args) \ #define INFINICORE_NN_PARAMETER_INIT(name, args) \
name##_ = Parameter args; \ name##_ = infinicore::nn::Parameter args; \
this->register_parameter(#name, name##_) this->register_parameter(#name, name##_)
} // namespace infinicore::nn } // namespace infinicore::nn
...@@ -36,10 +36,12 @@ public: ...@@ -36,10 +36,12 @@ public:
* *
* @param normalized_shape Size of the feature dimension to normalize (typically hidden_size) * @param normalized_shape Size of the feature dimension to normalize (typically hidden_size)
* @param eps Small constant for numerical stability (default: 1e-6) * @param eps Small constant for numerical stability (default: 1e-6)
* @param dtype Data type for the weight (default: DataType::F32)
* @param device Device to create the weight on * @param device Device to create the weight on
*/ */
RMSNorm(size_t normalized_shape, RMSNorm(size_t normalized_shape,
double eps = 1e-6, double eps = 1e-6,
const DataType &dtype = DataType::F32,
const Device &device = Device()); const Device &device = Device());
/** /**
...@@ -58,6 +60,7 @@ public: ...@@ -58,6 +60,7 @@ public:
// Module information // Module information
size_t normalized_shape() const { return normalized_shape_; } size_t normalized_shape() const { return normalized_shape_; }
double eps() const { return eps_; } double eps() const { return eps_; }
DataType dtype() const { return dtype_; }
// String representation // String representation
std::string extra_repr() const; std::string extra_repr() const;
...@@ -67,11 +70,12 @@ public: ...@@ -67,11 +70,12 @@ public:
protected: protected:
// Parameters // Parameters
Parameter weight_; INFINICORE_NN_PARAMETER(weight);
private: private:
size_t normalized_shape_; // Size of the feature dimension size_t normalized_shape_; // Size of the feature dimension
double eps_; // Epsilon for numerical stability double eps_; // Epsilon for numerical stability
DataType dtype_; // Data type for weight
}; };
} // namespace infinicore::nn } // namespace infinicore::nn
...@@ -394,7 +394,7 @@ TestResult NNModuleTest::testModuleLinear() { ...@@ -394,7 +394,7 @@ TestResult NNModuleTest::testModuleLinear() {
try { try {
// Test with bias // Test with bias
spdlog::info("Testing Linear module with bias (8->4 features)"); spdlog::info("Testing Linear module with bias (8->4 features)");
infinicore::nn::Linear m1(8, 4, true, infinicore::Device()); infinicore::nn::Linear m1(8, 4, true);
auto sd1 = m1.state_dict(); auto sd1 = m1.state_dict();
if (sd1.find("weight") == sd1.end()) { if (sd1.find("weight") == sd1.end()) {
spdlog::error("weight missing"); spdlog::error("weight missing");
...@@ -440,7 +440,7 @@ TestResult NNModuleTest::testModuleLinear() { ...@@ -440,7 +440,7 @@ TestResult NNModuleTest::testModuleLinear() {
// Test without bias // Test without bias
spdlog::info("Testing Linear module without bias (16->3 features)"); spdlog::info("Testing Linear module without bias (16->3 features)");
infinicore::nn::Linear m2(16, 3, false, infinicore::Device()); infinicore::nn::Linear m2(16, 3, false);
auto sd2 = m2.state_dict(); auto sd2 = m2.state_dict();
if (sd2.find("weight") == sd2.end()) { if (sd2.find("weight") == sd2.end()) {
spdlog::error("weight missing (no-bias)"); spdlog::error("weight missing (no-bias)");
...@@ -834,7 +834,7 @@ TestResult NNModuleTest::testModuleRMSNorm() { ...@@ -834,7 +834,7 @@ TestResult NNModuleTest::testModuleRMSNorm() {
// Test 1: Basic RMSNorm creation // Test 1: Basic RMSNorm creation
spdlog::info("Test 1: Basic RMSNorm creation (hidden_size=768)"); spdlog::info("Test 1: Basic RMSNorm creation (hidden_size=768)");
infinicore::nn::RMSNorm norm1(768, 1e-6, infinicore::Device()); infinicore::nn::RMSNorm norm1(768);
auto state1 = norm1.state_dict(); auto state1 = norm1.state_dict();
if (state1.find("weight") == state1.end()) { if (state1.find("weight") == state1.end()) {
...@@ -925,8 +925,8 @@ TestResult NNModuleTest::testModuleRMSNorm() { ...@@ -925,8 +925,8 @@ TestResult NNModuleTest::testModuleRMSNorm() {
// Test 7: Different hidden sizes // Test 7: Different hidden sizes
spdlog::info("Test 7: Testing different hidden sizes"); spdlog::info("Test 7: Testing different hidden sizes");
infinicore::nn::RMSNorm norm_small(128, 1e-5, infinicore::Device()); infinicore::nn::RMSNorm norm_small(128, 1e-5);
infinicore::nn::RMSNorm norm_large(4096, 1e-6, infinicore::Device()); infinicore::nn::RMSNorm norm_large(4096);
auto input_small = infinicore::Tensor::ones({2, 128}, infinicore::DataType::F32, infinicore::Device()); auto input_small = infinicore::Tensor::ones({2, 128}, infinicore::DataType::F32, infinicore::Device());
auto output_small = norm_small.forward(input_small); auto output_small = norm_small.forward(input_small);
...@@ -956,7 +956,130 @@ TestResult NNModuleTest::testModuleRMSNorm() { ...@@ -956,7 +956,130 @@ TestResult NNModuleTest::testModuleRMSNorm() {
}); });
} }
// Test 8: Comprehensive Tiny-Llama model test (construction + weight loading + validation) // Test 8: Dtype assertion test
TestResult NNModuleTest::testDtypeAssertion() {
    return measureTime("DtypeAssertionTest", [this]() {
        try {
            spdlog::info("Testing dtype assertions when loading parameters");

            // Helper: invoke `action` and require that it throws a
            // std::runtime_error whose message contains every substring in
            // `required`. Returns true iff the expected failure occurred,
            // logging the reason otherwise. Replaces the repeated
            // try/catch + `exception_thrown` flag boilerplate.
            auto expect_dtype_mismatch = [](auto &&action, const char *context,
                                            std::initializer_list<const char *> required) -> bool {
                try {
                    action();
                } catch (const std::runtime_error &e) {
                    const std::string error_msg = e.what();
                    for (const char *needle : required) {
                        if (error_msg.find(needle) == std::string::npos) {
                            spdlog::error("Exception message doesn't contain '{}' in {}", needle, context);
                            return false;
                        }
                    }
                    spdlog::debug("✓ Mismatched dtype exception caught in {}: {}", context, error_msg);
                    return true;
                }
                // `action` completed without throwing — the assertion is missing.
                spdlog::error("Expected exception for dtype mismatch in {}", context);
                return false;
            };

            // Test 1: Successful load with matching dtype (module default F32).
            spdlog::info("Test 1: Successful load with matching dtype (F32)");
            infinicore::nn::Linear linear1(8, 4, true);
            auto matching_weight = infinicore::Tensor::ones({4, 8}, infinicore::DataType::F32, infinicore::Device());
            auto matching_bias = infinicore::Tensor::ones({4}, infinicore::DataType::F32, infinicore::Device());
            std::unordered_map<std::string, infinicore::Tensor> matching_state;
            matching_state.emplace("weight", matching_weight);
            matching_state.emplace("bias", matching_bias);
            // Must not throw: tensor dtypes agree with the module's parameters.
            linear1.load_state_dict(matching_state);
            spdlog::debug("✓ Matching dtype load succeeded");

            // Test 2: load_parameter must reject a tensor with the wrong dtype.
            spdlog::info("Test 2: Failed load_parameter with mismatched dtype");
            infinicore::nn::Linear linear2(8, 4, true);
            auto mismatched_weight = infinicore::Tensor::ones({4, 8}, infinicore::DataType::BF16, infinicore::Device());
            if (!expect_dtype_mismatch(
                    [&]() { linear2.load_parameter("weight", mismatched_weight); },
                    "load_parameter", {"dtype mismatch"})) {
                return false;
            }

            // Test 3: load_state_dict must reject a mismatched dtype and name
            // the offending parameter in the error message.
            spdlog::info("Test 3: Failed load_state_dict with mismatched dtype");
            infinicore::nn::Embedding embedding1(100, 64);
            auto mismatched_embed_weight = infinicore::Tensor::ones({100, 64}, infinicore::DataType::BF16, infinicore::Device());
            std::unordered_map<std::string, infinicore::Tensor> mismatched_state;
            mismatched_state.emplace("weight", mismatched_embed_weight);
            if (!expect_dtype_mismatch(
                    [&]() { embedding1.load_state_dict(mismatched_state); },
                    "load_state_dict", {"dtype mismatch", "weight"})) {
                return false;
            }

            // Test 4: same rejection path for RMSNorm.
            spdlog::info("Test 4: Failed load_state_dict with mismatched dtype (RMSNorm)");
            infinicore::nn::RMSNorm norm1(768);
            auto mismatched_norm_weight = infinicore::Tensor::ones({768}, infinicore::DataType::BF16, infinicore::Device());
            std::unordered_map<std::string, infinicore::Tensor> mismatched_norm_state;
            mismatched_norm_state.emplace("weight", mismatched_norm_weight);
            if (!expect_dtype_mismatch(
                    [&]() { norm1.load_state_dict(mismatched_norm_state); },
                    "RMSNorm load_state_dict", {"dtype mismatch"})) {
                return false;
            }

            // Test 5: a module constructed with BF16 accepts BF16 tensors.
            spdlog::info("Test 5: Successful load with BF16 dtype (module created with BF16)");
            infinicore::nn::Linear linear3(8, 4, true, infinicore::DataType::BF16);
            auto bf16_weight = infinicore::Tensor::ones({4, 8}, infinicore::DataType::BF16, infinicore::Device());
            auto bf16_bias = infinicore::Tensor::ones({4}, infinicore::DataType::BF16, infinicore::Device());
            std::unordered_map<std::string, infinicore::Tensor> bf16_state;
            bf16_state.emplace("weight", bf16_weight);
            bf16_state.emplace("bias", bf16_bias);
            // Must not throw: module dtype and tensor dtype are both BF16.
            linear3.load_state_dict(bf16_state);
            spdlog::debug("✓ BF16 dtype load succeeded");

            spdlog::info("All dtype assertion tests passed!");
            return true;
        } catch (const std::exception &e) {
            // Any unexpected exception (construction failure, etc.) fails the test.
            spdlog::error("Exception in testDtypeAssertion: {}", e.what());
            return false;
        }
    });
}
// Test 9: Comprehensive Tiny-Llama model test (construction + weight loading + validation)
TestResult NNModuleTest::testTinyLlamaConstruction() { TestResult NNModuleTest::testTinyLlamaConstruction() {
return measureTime("TinyLlamaModelTest", [this]() { return measureTime("TinyLlamaModelTest", [this]() {
try { try {
...@@ -1007,10 +1130,10 @@ TestResult NNModuleTest::testTinyLlamaConstruction() { ...@@ -1007,10 +1130,10 @@ TestResult NNModuleTest::testTinyLlamaConstruction() {
INFINICORE_NN_MODULE(infinicore::nn::Linear, o_proj); INFINICORE_NN_MODULE(infinicore::nn::Linear, o_proj);
SelfAttn(size_t hidden_size, size_t kv_dim, const infinicore::Device &device) { SelfAttn(size_t hidden_size, size_t kv_dim, const infinicore::Device &device) {
INFINICORE_NN_MODULE_INIT(q_proj, hidden_size, hidden_size, false, device); INFINICORE_NN_MODULE_INIT(q_proj, hidden_size, hidden_size, false, infinicore::DataType::F32, device);
INFINICORE_NN_MODULE_INIT(k_proj, hidden_size, kv_dim, false, device); INFINICORE_NN_MODULE_INIT(k_proj, hidden_size, kv_dim, false, infinicore::DataType::F32, device);
INFINICORE_NN_MODULE_INIT(v_proj, hidden_size, kv_dim, false, device); INFINICORE_NN_MODULE_INIT(v_proj, hidden_size, kv_dim, false, infinicore::DataType::F32, device);
INFINICORE_NN_MODULE_INIT(o_proj, hidden_size, hidden_size, false, device); INFINICORE_NN_MODULE_INIT(o_proj, hidden_size, hidden_size, false, infinicore::DataType::F32, device);
} }
}; };
...@@ -1021,9 +1144,9 @@ TestResult NNModuleTest::testTinyLlamaConstruction() { ...@@ -1021,9 +1144,9 @@ TestResult NNModuleTest::testTinyLlamaConstruction() {
INFINICORE_NN_MODULE(infinicore::nn::Linear, down_proj); INFINICORE_NN_MODULE(infinicore::nn::Linear, down_proj);
MLP(size_t hidden_size, size_t intermediate_size, const infinicore::Device &device) { MLP(size_t hidden_size, size_t intermediate_size, const infinicore::Device &device) {
INFINICORE_NN_MODULE_INIT(gate_proj, hidden_size, intermediate_size, false, device); INFINICORE_NN_MODULE_INIT(gate_proj, hidden_size, intermediate_size, false, infinicore::DataType::F32, device);
INFINICORE_NN_MODULE_INIT(up_proj, hidden_size, intermediate_size, false, device); INFINICORE_NN_MODULE_INIT(up_proj, hidden_size, intermediate_size, false, infinicore::DataType::F32, device);
INFINICORE_NN_MODULE_INIT(down_proj, intermediate_size, hidden_size, false, device); INFINICORE_NN_MODULE_INIT(down_proj, intermediate_size, hidden_size, false, infinicore::DataType::F32, device);
} }
}; };
...@@ -1036,9 +1159,9 @@ TestResult NNModuleTest::testTinyLlamaConstruction() { ...@@ -1036,9 +1159,9 @@ TestResult NNModuleTest::testTinyLlamaConstruction() {
Block(const TinyLlamaConfig &cfg, const infinicore::Device &device) { Block(const TinyLlamaConfig &cfg, const infinicore::Device &device) {
size_t kv_dim = cfg.hidden_size * cfg.num_key_value_heads / cfg.num_attention_heads; size_t kv_dim = cfg.hidden_size * cfg.num_key_value_heads / cfg.num_attention_heads;
INFINICORE_NN_MODULE_INIT(input_layernorm, cfg.hidden_size, cfg.rms_norm_eps, device); INFINICORE_NN_MODULE_INIT(input_layernorm, cfg.hidden_size, cfg.rms_norm_eps, infinicore::DataType::F32, device);
INFINICORE_NN_MODULE_INIT(self_attn, cfg.hidden_size, kv_dim, device); INFINICORE_NN_MODULE_INIT(self_attn, cfg.hidden_size, kv_dim, device);
INFINICORE_NN_MODULE_INIT(post_attention_layernorm, cfg.hidden_size, cfg.rms_norm_eps, device); INFINICORE_NN_MODULE_INIT(post_attention_layernorm, cfg.hidden_size, cfg.rms_norm_eps, infinicore::DataType::F32, device);
INFINICORE_NN_MODULE_INIT(mlp, cfg.hidden_size, cfg.intermediate_size, device); INFINICORE_NN_MODULE_INIT(mlp, cfg.hidden_size, cfg.intermediate_size, device);
} }
}; };
...@@ -1051,7 +1174,7 @@ TestResult NNModuleTest::testTinyLlamaConstruction() { ...@@ -1051,7 +1174,7 @@ TestResult NNModuleTest::testTinyLlamaConstruction() {
TinyLlamaModel(const TinyLlamaConfig &config, const infinicore::Device &device) { TinyLlamaModel(const TinyLlamaConfig &config, const infinicore::Device &device) {
INFINICORE_NN_MODULE_INIT(embed_tokens, config.vocab_size, config.hidden_size, std::nullopt, infinicore::DataType::F32, device); INFINICORE_NN_MODULE_INIT(embed_tokens, config.vocab_size, config.hidden_size, std::nullopt, infinicore::DataType::F32, device);
INFINICORE_NN_MODULE_VEC_INIT(layers, config.num_hidden_layers, Block, config, device); INFINICORE_NN_MODULE_VEC_INIT(layers, config.num_hidden_layers, Block, config, device);
INFINICORE_NN_MODULE_INIT(norm, config.hidden_size, config.rms_norm_eps, device); INFINICORE_NN_MODULE_INIT(norm, config.hidden_size, config.rms_norm_eps, infinicore::DataType::F32, device);
} }
}; };
...@@ -1259,6 +1382,7 @@ TestResult NNModuleTest::run() { ...@@ -1259,6 +1382,7 @@ TestResult NNModuleTest::run() {
results.push_back(testModuleLinear()); // Linear module comprehensive test results.push_back(testModuleLinear()); // Linear module comprehensive test
results.push_back(testModuleEmbedding()); // Embedding module test results.push_back(testModuleEmbedding()); // Embedding module test
results.push_back(testModuleRMSNorm()); // RMSNorm module test results.push_back(testModuleRMSNorm()); // RMSNorm module test
results.push_back(testDtypeAssertion()); // Dtype assertion test
results.push_back(testTinyLlamaConstruction()); // Comprehensive: TinyLlama model test results.push_back(testTinyLlamaConstruction()); // Comprehensive: TinyLlama model test
// Check if all tests passed // Check if all tests passed
......
...@@ -21,16 +21,21 @@ namespace infinicore::test { ...@@ -21,16 +21,21 @@ namespace infinicore::test {
// Simple test module that mimics torch.nn.Linear // Simple test module that mimics torch.nn.Linear
class MockLinearModule : public infinicore::nn::Module { class MockLinearModule : public infinicore::nn::Module {
public: public:
// Declare parameters using macros (torch-like style)
INFINICORE_NN_PARAMETER(weight);
INFINICORE_NN_PARAMETER(bias);
MockLinearModule(int input_size, int output_size, const infinicore::Device &device) MockLinearModule(int input_size, int output_size, const infinicore::Device &device)
: input_size_(input_size), output_size_(output_size), device_(device) { : input_size_(input_size), output_size_(output_size), device_(device) {
// Initialize parameters using macros
// Initialize weight parameter (similar to torch.nn.Linear.weight) INFINICORE_NN_PARAMETER_INIT(weight,
register_parameter("weight", ({static_cast<size_t>(output_size), static_cast<size_t>(input_size)},
infinicore::nn::Parameter({static_cast<size_t>(output_size), static_cast<size_t>(input_size)}, infinicore::DataType::F32, device)); infinicore::DataType::F32,
device));
// Initialize bias parameter (similar to torch.nn.Linear.bias) INFINICORE_NN_PARAMETER_INIT(bias,
register_parameter("bias", ({static_cast<size_t>(output_size)},
infinicore::nn::Parameter({static_cast<size_t>(output_size)}, infinicore::DataType::F32, device)); infinicore::DataType::F32,
device));
} }
// Simple forward pass (conceptual - would need actual matrix operations) // Simple forward pass (conceptual - would need actual matrix operations)
...@@ -77,6 +82,7 @@ private: ...@@ -77,6 +82,7 @@ private:
TestResult testModuleLinear(); // Comprehensive Linear module test TestResult testModuleLinear(); // Comprehensive Linear module test
TestResult testModuleEmbedding(); // Embedding module test TestResult testModuleEmbedding(); // Embedding module test
TestResult testModuleRMSNorm(); // RMSNorm module test TestResult testModuleRMSNorm(); // RMSNorm module test
TestResult testDtypeAssertion(); // Test dtype assertions when loading parameters
TestResult testTinyLlamaConstruction(); // Comprehensive: construction + weight loading + validation TestResult testTinyLlamaConstruction(); // Comprehensive: construction + weight loading + validation
}; };
......
...@@ -4,25 +4,26 @@ ...@@ -4,25 +4,26 @@
namespace infinicore::nn { namespace infinicore::nn {
Linear::Linear(size_t in_features, size_t out_features, bool bias, const Device &device) Linear::Linear(size_t in_features, size_t out_features, bool bias, const DataType &dtype, const Device &device)
: in_features_(in_features), : in_features_(in_features),
out_features_(out_features), out_features_(out_features),
has_bias_(bias) { has_bias_(bias),
dtype_(dtype) {
device_ = device; device_ = device;
// Initialize parameters using macro // Initialize parameters using macro
INFINICORE_NN_PARAMETER_INIT(weight, ({out_features, in_features}, DataType::F32, device)); INFINICORE_NN_PARAMETER_INIT(weight, ({out_features, in_features}, dtype_, device));
// Register bias parameter if requested // Register bias parameter if requested
if (bias) { if (bias) {
INFINICORE_NN_PARAMETER_INIT(bias, ({out_features}, DataType::F32, device)); INFINICORE_NN_PARAMETER_INIT(bias, ({out_features}, dtype_, device));
} else { } else {
bias_ = Parameter(); // Default constructed empty parameter bias_ = Parameter(); // Default constructed empty parameter
} }
spdlog::debug("Created Linear module: in_features={}, out_features={}, bias={}", spdlog::debug("Created Linear module: in_features={}, out_features={}, bias={}, dtype={}",
in_features, out_features, bias); in_features, out_features, bias, static_cast<int>(dtype_));
} }
Tensor Linear::compute_linear(Tensor &input) const { Tensor Linear::compute_linear(Tensor &input) const {
...@@ -41,12 +42,9 @@ Tensor Linear::compute_linear(Tensor &input) const { ...@@ -41,12 +42,9 @@ Tensor Linear::compute_linear(Tensor &input) const {
strides.push_back(bias_->stride(0)); strides.push_back(bias_->stride(0));
auto bias_view = bias_->as_strided(output->shape(), strides); auto bias_view = bias_->as_strided(output->shape(), strides);
// First set output to bias (broadcasted)
infinicore::op::rearrange_(output, bias_view);
// Compute matmul result separately, then add to output // Compute matmul result separately, then add to output
auto matmul_result = infinicore::op::matmul(input, weight_t); infinicore::op::matmul_(output, input, weight_t);
infinicore::op::add_(output, output, matmul_result); infinicore::op::add_(output, output, bias_view);
} else { } else {
// No bias: just compute output = input @ weight_t // No bias: just compute output = input @ weight_t
infinicore::op::matmul_(output, input, weight_t); infinicore::op::matmul_(output, input, weight_t);
...@@ -69,7 +67,7 @@ Tensor Linear::forward(Tensor &input, Tensor &residual) const { ...@@ -69,7 +67,7 @@ Tensor Linear::forward(Tensor &input, Tensor &residual) const {
} }
std::string Linear::extra_repr() const { std::string Linear::extra_repr() const {
return "Linear(in_features=" + std::to_string(in_features_) + ", out_features=" + std::to_string(out_features_) + ", bias=" + (has_bias_ ? "true" : "false") + ")"; return "Linear(in_features=" + std::to_string(in_features_) + ", out_features=" + std::to_string(out_features_) + ", bias=" + (has_bias_ ? "true" : "false") + ", dtype=" + std::to_string(static_cast<int>(dtype_)) + ")";
} }
} // namespace infinicore::nn } // namespace infinicore::nn
#include "infinicore/nn/module.hpp" #include "infinicore/nn/module.hpp"
#include <stdexcept>
namespace infinicore::nn { namespace infinicore::nn {
const std::unordered_map<std::string, Parameter> &Module::state_dict() const { const std::unordered_map<std::string, Parameter> &Module::state_dict() const {
...@@ -20,13 +21,28 @@ void Module::load_state_dict(const std::unordered_map<std::string, Tensor> &_sta ...@@ -20,13 +21,28 @@ void Module::load_state_dict(const std::unordered_map<std::string, Tensor> &_sta
// Look up the corresponding tensor in the input state dict using the full name // Look up the corresponding tensor in the input state dict using the full name
auto it = _state_dict.find(param_full_name); auto it = _state_dict.find(param_full_name);
if (it != _state_dict.end()) { if (it != _state_dict.end()) {
// Assert dtype matches
if (param->dtype() != it->second->dtype()) {
throw std::runtime_error(
"dtype mismatch for parameter '" + param_full_name + "': "
"expected "
+ std::to_string(static_cast<int>(param->dtype())) + ", got " + std::to_string(static_cast<int>(it->second->dtype())));
}
param->copy_from(it->second); param->copy_from(it->second);
} }
} }
} }
void Module::load_parameter(const std::string &name, const Tensor &param) { void Module::load_parameter(const std::string &name, const Tensor &param) {
parameters_[name]->copy_from(param); auto existing_param = parameters_[name];
// Assert dtype matches
if (existing_param->dtype() != param->dtype()) {
throw std::runtime_error(
"dtype mismatch for parameter '" + name + "': "
"expected "
+ std::to_string(static_cast<int>(existing_param->dtype())) + ", got " + std::to_string(static_cast<int>(param->dtype())));
}
existing_param->copy_from(param);
} }
void Module::load_parameter_from_blob(const std::string &name, const void *data) { void Module::load_parameter_from_blob(const std::string &name, const void *data) {
......
...@@ -6,21 +6,22 @@ ...@@ -6,21 +6,22 @@
namespace infinicore::nn { namespace infinicore::nn {
RMSNorm::RMSNorm(size_t normalized_shape, double eps, const Device &device) RMSNorm::RMSNorm(size_t normalized_shape, double eps, const DataType &dtype, const Device &device)
: normalized_shape_(normalized_shape), : normalized_shape_(normalized_shape),
eps_(eps) { eps_(eps),
dtype_(dtype) {
device_ = device; device_ = device;
// Initialize parameter using macro // Initialize parameter using macro
INFINICORE_NN_PARAMETER_INIT(weight, ({normalized_shape}, DataType::F32, device)); INFINICORE_NN_PARAMETER_INIT(weight, ({normalized_shape}, dtype_, device));
// Initialize weight to ones (standard practice for RMSNorm) // Initialize weight to ones (standard practice for RMSNorm)
auto ones_tensor = Tensor::ones({normalized_shape}, DataType::F32, device); auto ones_tensor = Tensor::ones({normalized_shape}, dtype_, device);
weight_->copy_from(ones_tensor); weight_->copy_from(ones_tensor);
spdlog::debug("Created RMSNorm module: normalized_shape={}, eps={}", spdlog::debug("Created RMSNorm module: normalized_shape={}, eps={}, dtype={}",
normalized_shape, eps); normalized_shape, eps, static_cast<int>(dtype_));
} }
Tensor RMSNorm::forward(const Tensor &x) const { Tensor RMSNorm::forward(const Tensor &x) const {
...@@ -37,7 +38,7 @@ Tensor RMSNorm::forward(const Tensor &x) const { ...@@ -37,7 +38,7 @@ Tensor RMSNorm::forward(const Tensor &x) const {
} }
std::string RMSNorm::extra_repr() const { std::string RMSNorm::extra_repr() const {
return "RMSNorm(normalized_shape=" + std::to_string(normalized_shape_) + ", eps=" + std::to_string(eps_) + ")"; return "RMSNorm(normalized_shape=" + std::to_string(normalized_shape_) + ", eps=" + std::to_string(eps_) + ", dtype=" + std::to_string(static_cast<int>(dtype_)) + ")";
} }
} // namespace infinicore::nn } // namespace infinicore::nn
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment