Commit 69c1c352 authored by Ceng23333

feat: implement neural network module system with PyTorch-like API

- Implement core modules: Linear, Embedding, RMSNorm
- Add PyTorch-like macros for module and parameter definition
  - INFINICORE_NN_MODULE for single module declaration
  - INFINICORE_NN_MODULE_VEC for module vectors
  - INFINICORE_NN_PARAMETER for parameter declaration
  - Corresponding INIT macros for initialization
- Implement hierarchical module system with dynamic path generation
- Add state_dict() and load_state_dict() support
- Refactor module design: make registration methods protected, remove path_ member
- Add comprehensive test suite including TinyLlama integration
- Make all parameters protected, exposing them through public accessors
parent 99e19cc8
#pragma once
#include "infinicore/nn.hpp"
#include "infinicore/ops.hpp"
#include "infinicore/tensor.hpp"
#pragma once
#include "nn/embedding.hpp"
#include "nn/linear.hpp"
#include "nn/rmsnorm.hpp"
#pragma once
#include "module.hpp"
#include "../ops.hpp"
#include <optional>
namespace infinicore::nn {
/**
* @brief Embedding layer that maps indices to dense vectors
*
* A simple lookup table that stores embeddings of a fixed dictionary and size.
* This module is often used to store word embeddings and retrieve them using indices.
* The input to the module is a tensor of indices, and the output is the corresponding
* embedding vectors.
*
* Similar to PyTorch's nn.Embedding:
* https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html
*
* Example:
* @code
* // Create embedding: 10000 words, 300-dimensional embeddings
* auto embedding = Embedding(10000, 300);
*
* // Input: tensor of indices [batch_size, seq_len]
* auto indices = Tensor::from_data({2, 5}, {3, 5, 12, 8, 99, 0, 1, 45, 67, 23});
*
* // Output: [batch_size, seq_len, embedding_dim] = [2, 5, 300]
* auto embeddings = embedding.forward(indices);
* @endcode
*/
class Embedding : public Module {
public:
/**
* @brief Construct an Embedding layer
*
* @param num_embeddings Size of the dictionary of embeddings (vocabulary size)
* @param embedding_dim The size of each embedding vector
* @param padding_idx If specified, the entries at padding_idx do not contribute to gradient
* and the embedding vector at padding_idx is not updated during training
* @param dtype Data type for the embedding weights (default: DataType::F32)
* @param device Device to create the embedding weight on
*/
Embedding(size_t num_embeddings,
size_t embedding_dim,
std::optional<int64_t> padding_idx = std::nullopt,
const DataType &dtype = DataType::F32,
const Device &device = Device());
/**
* @brief Forward pass: lookup embeddings for given indices
*
* @param indices Tensor containing indices into the embedding matrix.
* Can be any shape (*), typically [batch_size] or [batch_size, seq_len]
* @return Tensor containing the embedding vectors.
* Shape: (*, embedding_dim) where * matches the input shape
*
* Example:
* Input shape: [2, 3] -> Output shape: [2, 3, embedding_dim]
* Input shape: [10] -> Output shape: [10, embedding_dim]
*/
Tensor forward(const Tensor &indices) const;
// Module information
size_t num_embeddings() const { return num_embeddings_; }
size_t embedding_dim() const { return embedding_dim_; }
std::optional<int64_t> padding_idx() const { return padding_idx_; }
DataType dtype() const { return dtype_; }
// String representation
std::string extra_repr() const;
// Accessors for parameters
Tensor weight() const { return weight_; }
protected:
// Parameters
Parameter weight_;
private:
size_t num_embeddings_; // Vocabulary size
size_t embedding_dim_; // Embedding dimension
std::optional<int64_t> padding_idx_; // Optional padding index
DataType dtype_; // Data type for embedding weights
};
} // namespace infinicore::nn
#pragma once
#include "module.hpp"
#include "../ops.hpp"
namespace infinicore::nn {
class Linear : public Module {
public:
Linear(size_t in_features, size_t out_features, bool bias = true, const Device &device = Device());
// Forward pass: output = input @ weight.T + bias
Tensor forward(Tensor &input) const;
// Forward pass with residual connection (InfiniLM-style)
// output = input @ weight.T + bias + residual
Tensor forward(Tensor &input, Tensor &residual) const;
// Module information
size_t in_features() const { return in_features_; }
size_t out_features() const { return out_features_; }
bool has_bias() const { return has_bias_; }
// String representation
std::string extra_repr() const;
// Accessors for parameters
Tensor weight() const { return weight_; }
Tensor bias() const { return bias_; }
protected:
// Parameters
Parameter weight_;
Parameter bias_;
private:
// Helper method for common forward computation
Tensor compute_linear(Tensor &input) const;
size_t in_features_;
size_t out_features_;
bool has_bias_;
};
} // namespace infinicore::nn
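A minimal usage sketch for Linear (illustrative only: the shapes, the default Device, and the use of Tensor::empty to fabricate inputs are assumptions, not part of this commit):

// Hypothetical example: project a [batch, 768] activation down to [batch, 256].
infinicore::nn::Linear proj(768, 256, /*bias=*/true, infinicore::Device());
auto input = infinicore::Tensor::empty({4, 768}, infinicore::DataType::F32, infinicore::Device());
auto hidden = proj.forward(input); // [4, 256] = input @ weight.T + bias

// Residual variant (InfiniLM-style): output = input @ weight.T + bias + residual
auto residual = infinicore::Tensor::empty({4, 256}, infinicore::DataType::F32, infinicore::Device());
auto out = proj.forward(input, residual);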
#pragma once
#include "parameter.hpp"
#include "../tensor.hpp"
#include <memory>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <vector>
namespace infinicore::nn {
class Module {
public:
Module() = default;
const std::unordered_map<std::string, Parameter> &state_dict() const;
void load_state_dict(const std::unordered_map<std::string, Tensor> &_state_dict);
......@@ -15,35 +20,118 @@ public:
void load_parameter_from_blob(const std::string &name, const void *data);
protected:
Tensor register_parameter(const std::string &name, Parameter param);
// Add an existing submodule to this module's hierarchy
// Template parameter M must be a type derived from Module
// Returns the submodule for convenience (allows method chaining)
template <typename M>
std::shared_ptr<M> add_module(const std::string &name, std::shared_ptr<M> submodule) {
// Ensure M is derived from Module (compile-time check)
static_assert(std::is_base_of<Module, M>::value,
"Template parameter M must be derived from infinicore::nn::Module");
// Store in the submodules map (std::shared_ptr<M> automatically converts to std::shared_ptr<Module>)
submodules_[name] = submodule;
for (auto &p : submodule->parameters_) {
parameters_[name + "." + p.first] = p.second;
}
return submodule;
}
// Create and register a new submodule by constructing it with the given arguments
// Template parameter M must be a type derived from Module
// Args are forwarded to M's constructor
template <typename M, typename... Args>
std::shared_ptr<M> register_module(const std::string &name, Args &&...args) {
// Ensure M is derived from Module (compile-time check)
static_assert(std::is_base_of<Module, M>::value,
"Template parameter M must be derived from infinicore::nn::Module");
// Construct the submodule
auto submodule = std::make_shared<M>(std::forward<Args>(args)...);
return add_module(name, submodule);
}
// Create and register multiple submodules of the same type
// Each submodule is named as "name.0", "name.1", etc.
// Template parameter M must be a type derived from Module
template <typename M, typename... Args>
std::vector<std::shared_ptr<M>> register_modules(size_t count, const std::string &name, Args &&...args) {
static_assert(std::is_base_of<Module, M>::value,
"Template parameter M must be derived from infinicore::nn::Module");
std::vector<std::shared_ptr<M>> modules;
modules.reserve(count);
for (size_t i = 0; i < count; i++) {
modules.push_back(register_module<M>(name + "." + std::to_string(i), std::forward<Args>(args)...));
}
return modules;
}
protected:
Device device_;
std::unordered_map<std::string, std::shared_ptr<Module>> submodules_;
std::unordered_map<std::string, Parameter> parameters_;
private:
void collect_all_parameters(std::unordered_map<std::string, Parameter> &all_params, const std::string &prefix = "") const;
};
// ============================================================================
// PyTorch-like Macros for Convenient Module Registration
// ============================================================================
/**
* @brief Register submodules with automatic name inference from variable name
*
* Usage:
* @code
* class MyModel : public Module {
* protected:
* INFINICORE_NN_MODULE(Linear, layer1);
* INFINICORE_NN_MODULE(Linear, layer2);
* INFINICORE_NN_MODULE_VEC(Linear, layers);
* INFINICORE_NN_PARAMETER(scaling_factor);
*
* public:
* MyModel() {
* INFINICORE_NN_MODULE_INIT(layer1, 128, 64);
* INFINICORE_NN_MODULE_INIT(layer2, 64, 32);
* INFINICORE_NN_MODULE_VEC_INIT(layers, 3, Linear, 32, 16);
* INFINICORE_NN_PARAMETER_INIT(scaling_factor, ({1}, DataType::F32, Device()));
* }
* };
* @endcode
*/
// Declare a single module member variable
#define INFINICORE_NN_MODULE(ModuleType, name) \
std::shared_ptr<ModuleType> name##_
// Declare a vector of modules member variable
#define INFINICORE_NN_MODULE_VEC(ModuleType, name) \
std::vector<std::shared_ptr<ModuleType>> name##_
// Initialize a module in constructor
#define INFINICORE_NN_MODULE_INIT(name, ...) \
name##_ = this->register_module<std::remove_reference<decltype(*name##_)>::type>(#name, ##__VA_ARGS__)
// Initialize a vector of modules in constructor
// Usage: INFINICORE_NN_MODULE_VEC_INIT(layers, count, ModuleType, ctor_args...)
// Example: INFINICORE_NN_MODULE_VEC_INIT(layers, 3, Linear, 128, 64)
#define INFINICORE_NN_MODULE_VEC_INIT(name, count, ModuleType, ...) \
name##_ = this->register_modules<ModuleType>(count, #name, ##__VA_ARGS__)
// Declare a parameter member variable
#define INFINICORE_NN_PARAMETER(name) \
Parameter name##_
// Initialize a parameter in constructor
// Usage: INFINICORE_NN_PARAMETER_INIT(name, (shape, dtype, device))
// Example: INFINICORE_NN_PARAMETER_INIT(weight, ({out_features, in_features}, DataType::F32, device))
#define INFINICORE_NN_PARAMETER_INIT(name, args) \
name##_ = Parameter args; \
this->register_parameter(#name, name##_)
} // namespace infinicore::nn
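A hedged sketch of the state_dict()/load_state_dict() round trip, reusing the hypothetical MyModel from the macro documentation above (checkpoint loading itself is elided):

MyModel model;

// Keys are fully qualified hierarchical names, e.g. "layer1.weight",
// "layer1.bias", "layers.0.weight", ..., "scaling_factor".
const auto &params = model.state_dict();

// load_state_dict copies each tensor whose key matches a parameter's full
// path; parameters missing from the map are left unchanged.
std::unordered_map<std::string, infinicore::Tensor> checkpoint; // filled elsewhere
model.load_state_dict(checkpoint);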
......@@ -5,6 +5,8 @@
namespace infinicore::nn {
class Parameter : public Tensor {
public:
Parameter();
Parameter(const Shape &shape,
const DataType &dtype,
const Device &device);
......
#pragma once
#include "module.hpp"
#include "../ops.hpp"
namespace infinicore::nn {
/**
* @brief Root Mean Square Layer Normalization (RMSNorm)
*
* Applies Root Mean Square Layer Normalization over the last dimension.
* Unlike LayerNorm, RMSNorm doesn't subtract mean and doesn't use bias.
*
* Formula: y = (x / RMS(x)) * weight
* where RMS(x) = sqrt(mean(x^2) + eps)
*
* Used in LLaMA, Galactica, and other modern language models as a
* simpler and faster alternative to LayerNorm.
*
* Example:
* @code
* // Create RMSNorm for hidden size 4096
* auto norm = RMSNorm(4096);
*
* // Input: [batch, seq_len, hidden_size]
* auto input = Tensor::randn({2, 10, 4096});
*
* // Output: [batch, seq_len, hidden_size]
* auto output = norm.forward(input);
* @endcode
*/
class RMSNorm : public Module {
public:
/**
* @brief Construct an RMSNorm layer
*
* @param normalized_shape Size of the feature dimension to normalize (typically hidden_size)
* @param eps Small constant for numerical stability (default: 1e-6)
* @param device Device to create the weight on
*/
RMSNorm(size_t normalized_shape,
double eps = 1e-6,
const Device &device = Device());
/**
* @brief Forward pass: apply RMSNorm
*
* @param x Input tensor of shape (*, normalized_shape) where * is any number of dimensions
* @return Normalized tensor with same shape as input
*
* The normalization is applied over the last dimension.
* For example:
* Input: [batch, seq_len, hidden_size] -> normalize over hidden_size
* Input: [batch, hidden_size] -> normalize over hidden_size
*/
Tensor forward(const Tensor &x) const;
// Module information
size_t normalized_shape() const { return normalized_shape_; }
double eps() const { return eps_; }
// String representation
std::string extra_repr() const;
// Accessors for parameters
Tensor weight() const { return weight_; }
protected:
// Parameters
Parameter weight_;
private:
size_t normalized_shape_; // Size of the feature dimension
double eps_; // Epsilon for numerical stability
};
} // namespace infinicore::nn
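For reference, the RMSNorm formula documented above reduces to the scalar loop below. This is a sketch of the math only; the actual forward pass delegates to op::rms_norm:

#include <cmath>
#include <cstddef>
#include <vector>

// y[i] = x[i] / sqrt(mean(x^2) + eps) * w[i], applied over one row.
std::vector<float> rms_norm_reference(const std::vector<float> &x,
                                      const std::vector<float> &w,
                                      double eps = 1e-6) {
    double mean_sq = 0.0;
    for (float v : x) {
        mean_sq += static_cast<double>(v) * v;
    }
    mean_sq /= static_cast<double>(x.size());
    const float rms = static_cast<float>(std::sqrt(mean_sq + eps));
    std::vector<float> y(x.size());
    for (std::size_t i = 0; i < x.size(); ++i) {
        y[i] = x[i] / rms * w[i];
    }
    return y;
}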
#include "memory_test.h"
#include "test_nn_module.h"
#include "test_runner.h"
#include "test_tensor_destructor.h"
#include <iostream>
#include <memory>
......@@ -13,6 +15,7 @@ struct ParsedArgs {
bool run_memory_leak = true;
bool run_performance = true;
bool run_stress = true;
bool run_module = false;
int num_threads = 4;
int iterations = 1000;
};
......@@ -23,7 +26,7 @@ void printUsage() {
<< std::endl
<< "Options:" << std::endl
<< " --<device> Specify the device type (default: cpu)" << std::endl
<< " --test <name> Run specific test (basic|concurrency|exception|leak|performance|stress|all)" << std::endl
<< " --test <name> Run specific test (basic|concurrency|exception|leak|performance|stress|module|all)" << std::endl
<< " --threads <num> Number of threads for concurrency tests (default: 4)" << std::endl
<< " --iterations <num> Number of iterations for stress tests (default: 1000)" << std::endl
<< " --help Show this help message" << std::endl
......@@ -46,6 +49,7 @@ void printUsage() {
<< " leak - Memory leak detection tests" << std::endl
<< " performance - Performance and benchmark tests" << std::endl
<< " stress - Stress tests with high load" << std::endl
<< " module - Neural network module tests" << std::endl
<< " all - Run all tests (default)" << std::endl
<< std::endl;
exit(EXIT_SUCCESS);
......@@ -84,7 +88,7 @@ ParsedArgs parseArgs(int argc, char *argv[]) {
}
std::string test_name = argv[++i];
args.run_basic = args.run_concurrency = args.run_exception_safety = args.run_memory_leak = args.run_performance = args.run_stress = args.run_module = false;
if (test_name == "basic") {
args.run_basic = true;
......@@ -98,8 +102,10 @@ ParsedArgs parseArgs(int argc, char *argv[]) {
args.run_performance = true;
} else if (test_name == "stress") {
args.run_stress = true;
} else if (test_name == "module") {
args.run_module = true;
} else if (test_name == "all") {
args.run_basic = args.run_concurrency = args.run_exception_safety = args.run_memory_leak = args.run_performance = args.run_stress = args.run_module = true;
} else {
std::cerr << "Error: Unknown test name: " << test_name << std::endl;
exit(EXIT_FAILURE);
......@@ -157,7 +163,7 @@ int main(int argc, char *argv[]) {
spdlog::debug("Creating test runner");
// Create test runner
infinicore::test::InfiniCoreTestRunner runner;
spdlog::debug("Test runner created successfully");
// Add tests based on arguments
......@@ -171,6 +177,12 @@ int main(int argc, char *argv[]) {
spdlog::debug("TensorDestructorTest added successfully");
}
if (args.run_module) {
spdlog::debug("Adding NNModuleTest");
runner.addTest(std::make_unique<infinicore::test::NNModuleTest>());
spdlog::debug("NNModuleTest added successfully");
}
if (args.run_concurrency) {
runner.addTest(std::make_unique<infinicore::test::ConcurrencyTest>());
}
......@@ -196,13 +208,29 @@ int main(int argc, char *argv[]) {
auto results = runner.runAllTests();
spdlog::debug("All tests completed");
// Count results and collect failed tests
size_t passed = 0, failed = 0;
std::vector<infinicore::test::TestResult> failed_tests;
for (const auto &result : results) {
if (result.passed) {
passed++;
} else {
failed++;
failed_tests.push_back(result);
}
}
// Print list of failed tests if any
if (!failed_tests.empty()) {
std::cout << "\n==============================================\n"
<< "❌ FAILED TESTS\n"
<< "==============================================" << std::endl;
for (const auto &test : failed_tests) {
std::cout << " • " << test.test_name;
if (!test.error_message.empty()) {
std::cout << "\n Error: " << test.error_message;
}
std::cout << "\n Duration: " << test.duration.count() << "μs" << std::endl;
}
}
......@@ -217,7 +245,7 @@ int main(int argc, char *argv[]) {
// Exit with appropriate code
if (failed > 0) {
std::cout << "\n❌ Some tests failed. Please review the output above." << std::endl;
std::cout << "\n❌ Some tests failed. Please review the failed tests list above." << std::endl;
return EXIT_FAILURE;
} else {
std::cout << "\n✅ All tests passed!" << std::endl;
......
......@@ -2,72 +2,17 @@
#define __INFINICORE_MEMORY_TEST_H__
#include "../infinicore/context/allocators/memory_allocator.hpp"
#include "test_runner.h"
#include <atomic>
#include <cassert>
#include <chrono>
#include <exception>
#include <future>
#include <infinicore.hpp>
#include <iostream>
#include <memory>
#include <mutex>
#include <queue>
#include <spdlog/spdlog.h>
#include <thread>
#include <unordered_map>
#include <vector>
namespace infinicore::test {
// Test result structure
struct TestResult {
std::string test_name;
bool passed;
std::string error_message;
std::chrono::microseconds duration;
TestResult(const std::string &name, bool pass, const std::string &error = "",
std::chrono::microseconds dur = std::chrono::microseconds(0))
: test_name(name), passed(pass), error_message(error), duration(dur) {}
};
// Test framework base class
class MemoryTestFramework {
public:
virtual ~MemoryTestFramework() = default;
virtual TestResult run() = 0;
virtual std::string getName() const = 0;
protected:
void logTestStart(const std::string &test_name) {
std::cout << "[TEST] Starting: " << test_name << std::endl;
}
void logTestResult(const TestResult &result) {
std::cout << "[TEST] " << (result.passed ? "PASSED" : "FAILED")
<< ": " << result.test_name;
if (!result.passed && !result.error_message.empty()) {
std::cout << " - " << result.error_message;
}
std::cout << " (Duration: " << result.duration.count() << "μs)" << std::endl;
}
template <typename Func>
TestResult measureTime(const std::string &test_name, Func &&func) {
auto start = std::chrono::high_resolution_clock::now();
try {
bool result = func();
auto end = std::chrono::high_resolution_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::microseconds>(end - start);
return TestResult(test_name, result, "", duration);
} catch (const std::exception &e) {
auto end = std::chrono::high_resolution_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::microseconds>(end - start);
return TestResult(test_name, false, e.what(), duration);
}
}
};
// Mock allocator for testing exception safety
class MockAllocator : public infinicore::MemoryAllocator {
public:
......@@ -149,13 +94,13 @@ private:
};
// Test categories
class BasicMemoryTest : public TestFramework {
public:
TestResult run() override;
std::string getName() const override { return "BasicMemoryTest"; }
};
class ConcurrencyTest : public TestFramework {
public:
TestResult run() override;
std::string getName() const override { return "ConcurrencyTest"; }
......@@ -166,7 +111,7 @@ private:
TestResult testMemoryAllocationRace();
};
class ExceptionSafetyTest : public TestFramework {
public:
TestResult run() override;
std::string getName() const override { return "ExceptionSafetyTest"; }
......@@ -177,7 +122,7 @@ private:
TestResult testContextSwitchException();
};
class MemoryLeakTest : public TestFramework {
public:
TestResult run() override;
std::string getName() const override { return "MemoryLeakTest"; }
......@@ -188,7 +133,7 @@ private:
TestResult testExceptionLeakDetection();
};
class PerformanceTest : public TestFramework {
public:
TestResult run() override;
std::string getName() const override { return "PerformanceTest"; }
......@@ -199,7 +144,7 @@ private:
TestResult testMemoryCopyPerformance();
};
class StressTest : public TestFramework {
public:
TestResult run() override;
std::string getName() const override { return "StressTest"; }
......@@ -210,67 +155,6 @@ private:
TestResult testCrossDeviceStress();
};
// Test runner
class MemoryTestRunner {
public:
void addTest(std::unique_ptr<MemoryTestFramework> test) {
tests_.push_back(std::move(test));
}
std::vector<TestResult> runAllTests() {
std::vector<TestResult> results;
std::cout << "==============================================\n"
<< "InfiniCore Memory Management Test Suite\n"
<< "==============================================" << std::endl;
for (auto &test : tests_) {
logTestStart(test->getName());
TestResult result = test->run();
logTestResult(result);
results.push_back(result);
}
printSummary(results);
return results;
}
private:
std::vector<std::unique_ptr<MemoryTestFramework>> tests_;
void logTestStart(const std::string &test_name) {
std::cout << "\n[SUITE] Running: " << test_name << std::endl;
}
void logTestResult(const TestResult &result) {
std::cout << "[SUITE] " << (result.passed ? "PASSED" : "FAILED")
<< ": " << result.test_name << std::endl;
}
void printSummary(const std::vector<TestResult> &results) {
size_t passed = 0, failed = 0;
std::chrono::microseconds total_time(0);
for (const auto &result : results) {
if (result.passed) {
passed++;
} else {
failed++;
}
total_time += result.duration;
}
std::cout << "\n==============================================\n"
<< "Test Summary\n"
<< "==============================================\n"
<< "Total Tests: " << results.size() << "\n"
<< "Passed: " << passed << "\n"
<< "Failed: " << failed << "\n"
<< "Total Time: " << total_time.count() << "μs\n"
<< "==============================================" << std::endl;
}
};
} // namespace infinicore::test
#endif // __INFINICORE_MEMORY_TEST_H__
#ifndef __INFINICORE_TEST_NN_MODULE_H__
#define __INFINICORE_TEST_NN_MODULE_H__
#include "infinicore/device.hpp"
#include "infinicore/nn/embedding.hpp"
#include "infinicore/nn/module.hpp"
#include "infinicore/nn/parameter.hpp"
#include "infinicore/nn/rmsnorm.hpp"
#include "test_runner.h"
#include <algorithm>
#include <cmath>
#include <fstream>
#include <iostream>
#include <memory>
#include <sstream>
#include <sys/stat.h>
#include <vector>
namespace infinicore::test {
// Simple test module that mimics torch.nn.Linear
class MockLinearModule : public infinicore::nn::Module {
public:
MockLinearModule(int input_size, int output_size, const infinicore::Device &device)
: input_size_(input_size), output_size_(output_size), device_(device) {
// Initialize weight parameter (similar to torch.nn.Linear.weight)
register_parameter("weight",
infinicore::nn::Parameter({static_cast<size_t>(output_size), static_cast<size_t>(input_size)}, infinicore::DataType::F32, device));
// Initialize bias parameter (similar to torch.nn.Linear.bias)
register_parameter("bias",
infinicore::nn::Parameter({static_cast<size_t>(output_size)}, infinicore::DataType::F32, device));
}
// Simple forward pass (conceptual - would need actual matrix operations)
infinicore::Tensor forward(const infinicore::Tensor &input) {
// This is a placeholder - in a real implementation, you'd do matrix multiplication
// For testing purposes, we'll just return the input
return input;
}
infinicore::Tensor get_weight() const {
auto state_dict = this->state_dict();
auto it = state_dict.find("weight");
if (it != state_dict.end()) {
return it->second;
}
throw std::runtime_error("Weight parameter not found");
}
infinicore::Tensor get_bias() const {
auto state_dict = this->state_dict();
auto it = state_dict.find("bias");
if (it != state_dict.end()) {
return it->second;
}
throw std::runtime_error("Bias parameter not found");
}
private:
int input_size_;
int output_size_;
infinicore::Device device_;
};
class NNModuleTest : public TestFramework {
public:
TestResult run() override;
std::string getName() const override { return "NNModuleTest"; }
private:
TestResult testBasicModuleCreation(); // Merged: creation, parameters, state_dict, load_state_dict
TestResult testLoadStateDict(); // Advanced: hierarchical modules
TestResult testModuleHierarchy(); // Demonstrates proper hierarchical construction pattern
TestResult testParameterLoading(); // Test blob parameter loading
TestResult testModuleLinear(); // Comprehensive Linear module test
TestResult testModuleEmbedding(); // Embedding module test
TestResult testModuleRMSNorm(); // RMSNorm module test
TestResult testTinyLlamaConstruction(); // Comprehensive: construction + weight loading + validation
};
} // namespace infinicore::test
#endif // __INFINICORE_TEST_NN_MODULE_H__
#ifndef __INFINICORE_TEST_RUNNER_H__
#define __INFINICORE_TEST_RUNNER_H__
#include <chrono>
#include <cmath>
#include <exception>
#include <infinicore.hpp>
#include <iostream>
#include <memory>
#include <spdlog/spdlog.h>
#include <sstream>
#include <string>
#include <vector>
namespace infinicore::test {
// ============================================================================
// Common Test Utilities
// ============================================================================
/**
* @brief Compare two InfiniCore tensors elementwise with tolerance
*
* Compares two tensors for approximate equality, useful for testing numerical
* computations where exact equality is not expected due to floating-point arithmetic.
*
* @param actual The actual tensor result
* @param expected The expected tensor result
* @param rtol Relative tolerance (default: 1e-5)
* @param atol Absolute tolerance (default: 1e-5)
* @return true if tensors are approximately equal, false otherwise
*
* @note Currently only supports F32 dtype
* @note Tensors are automatically moved to CPU for comparison
* @note Reports up to 10 mismatches with detailed coordinates
*/
inline bool tensorsAllClose(const infinicore::Tensor &actual,
const infinicore::Tensor &expected,
double rtol = 1e-5,
double atol = 1e-5) {
if (actual->shape() != expected->shape()) {
spdlog::error("Shape mismatch: actual vs expected");
return false;
}
auto cpu = infinicore::Device(infinicore::Device::Type::CPU, 0);
auto a_cpu = actual->to(cpu);
a_cpu = a_cpu->contiguous();
auto b_cpu = expected->to(cpu);
b_cpu = b_cpu->contiguous();
if (a_cpu->dtype() != b_cpu->dtype()) {
spdlog::error("DType mismatch");
return false;
}
// Only support F32 in this test
if (a_cpu->dtype() != infinicore::DataType::F32) {
spdlog::error("Unsupported dtype for comparison; only F32 supported in test");
return false;
}
size_t n = a_cpu->numel();
const auto &shape = a_cpu->shape();
// Precompute strides for index -> coords mapping
std::vector<size_t> stride(shape.size(), 1);
for (int i = static_cast<int>(shape.size()) - 2; i >= 0; --i) {
stride[i] = stride[i + 1] * shape[i + 1];
}
const float *ap = reinterpret_cast<const float *>(a_cpu->data());
const float *bp = reinterpret_cast<const float *>(b_cpu->data());
size_t max_diff_index = 0;
float max_diff = 0.0f;
size_t num_fail_reported = 0;
for (size_t i = 0; i < n; ++i) {
float av = ap[i];
float bv = bp[i];
float diff = std::fabs(av - bv);
if (diff > static_cast<float>(atol + rtol * std::fabs(bv))) {
if (diff > max_diff) {
max_diff = diff;
max_diff_index = i;
}
if (num_fail_reported < 10) {
// Convert linear index to coordinates
std::vector<size_t> coords(shape.size(), 0);
size_t t = i;
for (size_t d = 0; d < shape.size(); ++d) {
coords[d] = t / stride[d];
t -= coords[d] * stride[d];
}
std::stringstream ss;
ss << "[";
for (size_t d = 0; d < coords.size(); ++d) {
ss << coords[d] << (d + 1 < coords.size() ? "," : "]");
}
double tol = atol + rtol * std::fabs(bv);
spdlog::error("Mismatch at index {} coords {}: actual={} expected={} diff={} tol={}",
i, ss.str(), av, bv, diff, tol);
num_fail_reported++;
}
}
}
if (num_fail_reported > 0) {
// Report summary with max diff
std::vector<size_t> coords(shape.size(), 0);
size_t t = max_diff_index;
for (size_t d = 0; d < shape.size(); ++d) {
coords[d] = t / stride[d];
t -= coords[d] * stride[d];
}
std::stringstream ss;
ss << "[";
for (size_t d = 0; d < coords.size(); ++d) {
ss << coords[d] << (d + 1 < coords.size() ? "," : "]");
}
spdlog::error("Max diff {} at linear index {} coords {}", max_diff, max_diff_index, ss.str());
return false;
}
return true;
}
// ============================================================================
// Test Framework Classes
// ============================================================================
// Test result structure
struct TestResult {
std::string test_name;
bool passed;
std::string error_message;
std::chrono::microseconds duration;
TestResult(const std::string &name, bool pass, const std::string &error = "",
std::chrono::microseconds dur = std::chrono::microseconds(0))
: test_name(name), passed(pass), error_message(error), duration(dur) {}
};
// Test framework base class
class TestFramework {
public:
virtual ~TestFramework() = default;
virtual TestResult run() = 0;
virtual std::string getName() const = 0;
protected:
void logTestStart(const std::string &test_name) {
std::cout << "[TEST] Starting: " << test_name << std::endl;
}
void logTestResult(const TestResult &result) {
std::cout << "[TEST] " << (result.passed ? "PASSED" : "FAILED")
<< ": " << result.test_name;
if (!result.passed && !result.error_message.empty()) {
std::cout << " - " << result.error_message;
}
std::cout << " (Duration: " << result.duration.count() << "μs)" << std::endl;
}
template <typename Func>
TestResult measureTime(const std::string &test_name, Func &&func) {
auto start = std::chrono::high_resolution_clock::now();
try {
bool result = func();
auto end = std::chrono::high_resolution_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::microseconds>(end - start);
return TestResult(test_name, result, "", duration);
} catch (const std::exception &e) {
auto end = std::chrono::high_resolution_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::microseconds>(end - start);
return TestResult(test_name, false, e.what(), duration);
}
}
};
// Test runner
class InfiniCoreTestRunner {
public:
void addTest(std::unique_ptr<TestFramework> test) {
tests_.push_back(std::move(test));
}
std::vector<TestResult> runAllTests() {
std::vector<TestResult> results;
std::cout << "==============================================\n"
<< "InfiniCore Test Suite\n"
<< "==============================================" << std::endl;
for (auto &test : tests_) {
logTestStart(test->getName());
TestResult result = test->run();
logTestResult(result);
results.push_back(result);
}
printSummary(results);
return results;
}
private:
std::vector<std::unique_ptr<TestFramework>> tests_;
void logTestStart(const std::string &test_name) {
std::cout << "\n[SUITE] Running: " << test_name << std::endl;
}
void logTestResult(const TestResult &result) {
std::cout << "[SUITE] " << (result.passed ? "PASSED" : "FAILED")
<< ": " << result.test_name << std::endl;
}
void printSummary(const std::vector<TestResult> &results) {
size_t passed = 0, failed = 0;
std::chrono::microseconds total_time(0);
std::vector<TestResult> failed_tests;
for (const auto &result : results) {
if (result.passed) {
passed++;
} else {
failed++;
failed_tests.push_back(result);
}
total_time += result.duration;
}
// Print list of failed tests if any
if (!failed_tests.empty()) {
std::cout << "\n==============================================\n"
<< "❌ FAILED TESTS\n"
<< "==============================================" << std::endl;
for (const auto &test : failed_tests) {
std::cout << " • " << test.test_name;
if (!test.error_message.empty()) {
std::cout << "\n Error: " << test.error_message;
}
std::cout << "\n Duration: " << test.duration.count() << "μs" << std::endl;
}
}
std::cout << "\n==============================================\n"
<< "Test Summary\n"
<< "==============================================\n"
<< "Total Tests: " << results.size() << "\n"
<< "Passed: " << passed << "\n"
<< "Failed: " << failed << "\n"
<< "Total Time: " << total_time.count() << "μs\n"
<< "==============================================" << std::endl;
}
};
} // namespace infinicore::test
#endif // __INFINICORE_TEST_RUNNER_H__
......@@ -4,13 +4,14 @@
#include "infinicore/context/context.hpp"
#include "infinicore/tensor.hpp"
#include "memory_test.h"
#include "test_runner.h"
#include <iostream>
#include <memory>
#include <vector>
namespace infinicore::test {
class TensorDestructorTest : public TestFramework {
public:
TestResult run() override;
std::string getName() const override { return "TensorDestructorTest"; }
......
#include "infinicore/nn/embedding.hpp"
#include "infinicore/context/context.hpp"
#include "infinicore/ops.hpp"
#include <spdlog/spdlog.h>
#include <stdexcept>
namespace infinicore::nn {
Embedding::Embedding(size_t num_embeddings,
size_t embedding_dim,
std::optional<int64_t> padding_idx,
const DataType &dtype,
const Device &device)
: num_embeddings_(num_embeddings),
embedding_dim_(embedding_dim),
padding_idx_(padding_idx),
dtype_(dtype) {
device_ = device;
// Validate padding_idx
if (padding_idx_.has_value()) {
int64_t idx = padding_idx_.value();
if (idx < 0 || idx >= static_cast<int64_t>(num_embeddings)) {
throw std::invalid_argument(
"padding_idx must be within num_embeddings range, got " + std::to_string(idx) + " for num_embeddings=" + std::to_string(num_embeddings));
}
}
// Initialize parameter using macro
INFINICORE_NN_PARAMETER_INIT(weight, ({num_embeddings, embedding_dim}, dtype_, device));
// If padding_idx is specified, initialize that row to zeros
if (padding_idx_.has_value()) {
// TODO: Set weight[padding_idx] to zeros
// This would require a slice operation
}
spdlog::debug("Created Embedding module: num_embeddings={}, embedding_dim={}, dtype={}, padding_idx={}",
num_embeddings, embedding_dim, static_cast<int>(dtype_),
padding_idx_.has_value() ? std::to_string(padding_idx_.value()) : "None");
}
Tensor Embedding::forward(const Tensor &indices) const {
// Get the shape of indices
auto indices_shape = indices->shape();
// Output shape: indices_shape + [embedding_dim]
std::vector<size_t> output_shape = indices_shape;
output_shape.push_back(embedding_dim_);
// Create output tensor on the same device as weight
auto out = Tensor::empty(output_shape, weight_->dtype(), weight_->device());
// Flatten indices for sequential row copies
auto cpu_device = Device(Device::Type::CPU, 0);
auto indices_cpu = indices->to(cpu_device)->contiguous();
const auto *indices_data = reinterpret_cast<const int64_t *>(indices_cpu->data());
// Calculate total number of lookups
size_t num_lookups = 1;
for (auto dim : indices_shape) {
num_lookups *= dim;
}
const size_t elem_size = (weight_->dtype() == DataType::BF16) ? sizeof(uint16_t) : sizeof(float);
const size_t row_bytes = embedding_dim_ * elem_size;
// Source and destination base pointers
auto *weight_base = weight_->data();
auto *out_base = out->data();
if (weight_->device().getType() == Device::Type::CPU) {
// CPU path: memcpy row by row
for (size_t i = 0; i < num_lookups; ++i) {
int64_t idx = indices_data[i];
if (idx < 0 || idx >= static_cast<int64_t>(num_embeddings_)) {
throw std::out_of_range(
"Index out of range: " + std::to_string(idx) + " (num_embeddings=" + std::to_string(num_embeddings_) + ")");
}
std::memcpy(out_base + i * row_bytes, weight_base + idx * row_bytes, row_bytes);
}
} else {
// Device path: use stream-ordered D2D copies
for (size_t i = 0; i < num_lookups; ++i) {
int64_t idx = indices_data[i];
if (idx < 0 || idx >= static_cast<int64_t>(num_embeddings_)) {
throw std::out_of_range(
"Index out of range: " + std::to_string(idx) + " (num_embeddings=" + std::to_string(num_embeddings_) + ")");
}
context::memcpyD2D(out_base + i * row_bytes, weight_base + idx * row_bytes, row_bytes);
}
}
return out;
}
std::string Embedding::extra_repr() const {
std::string repr = "Embedding(num_embeddings=" + std::to_string(num_embeddings_) + ", embedding_dim=" + std::to_string(embedding_dim_) + ", dtype=" + std::to_string(static_cast<int>(dtype_));
if (padding_idx_.has_value()) {
repr += ", padding_idx=" + std::to_string(padding_idx_.value());
}
repr += ")";
return repr;
}
} // namespace infinicore::nn
#include "infinicore/nn/linear.hpp"
#include "infinicore/ops.hpp"
#include <spdlog/spdlog.h>
namespace infinicore::nn {
Linear::Linear(size_t in_features, size_t out_features, bool bias, const Device &device)
: in_features_(in_features),
out_features_(out_features),
has_bias_(bias) {
device_ = device;
// Initialize parameters using macro
INFINICORE_NN_PARAMETER_INIT(weight, ({out_features, in_features}, DataType::F32, device));
// Register bias parameter if requested
if (bias) {
INFINICORE_NN_PARAMETER_INIT(bias, ({out_features}, DataType::F32, device));
} else {
bias_ = Parameter(); // Default constructed empty parameter
}
spdlog::debug("Created Linear module: in_features={}, out_features={}, bias={}",
in_features, out_features, bias);
}
Tensor Linear::compute_linear(Tensor &input) const {
// Create output tensor with shape [batch_size, out_features]
auto output_shape = input->shape();
output_shape[output_shape.size() - 1] = out_features_;
auto output = Tensor::empty(output_shape, input->dtype(), input->device());
// Transpose weight: [out_features, in_features] -> [in_features, out_features]
auto weight_t = weight_->permute({1, 0});
if (has_bias_) {
// Broadcast bias to output shape
size_t ndim_diff = output->ndim() - 1;
std::vector<Stride> strides(ndim_diff, 0);
strides.push_back(bias_->stride(0));
auto bias_view = bias_->as_strided(output->shape(), strides);
// First set output to bias (broadcasted)
infinicore::op::rearrange_(output, bias_view);
// Compute matmul result separately, then add to output
auto matmul_result = infinicore::op::matmul(input, weight_t);
infinicore::op::add_(output, output, matmul_result);
} else {
// No bias: just compute output = input @ weight_t
infinicore::op::matmul_(output, input, weight_t);
}
return output;
}
Tensor Linear::forward(Tensor &input) const {
return compute_linear(input);
}
Tensor Linear::forward(Tensor &input, Tensor &residual) const {
auto output = compute_linear(input);
// Add residual: output = output + residual
infinicore::op::add_(output, output, residual);
return output;
}
std::string Linear::extra_repr() const {
return "Linear(in_features=" + std::to_string(in_features_) + ", out_features=" + std::to_string(out_features_) + ", bias=" + (has_bias_ ? "true" : "false") + ")";
}
} // namespace infinicore::nn
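A note on the bias broadcast in compute_linear: as_strided with zero leading strides makes every (batch, seq) index alias the single stored bias row. The index arithmetic it relies on, as a standalone sketch independent of the Tensor API:

#include <cstddef>
#include <vector>

// With strides {0, 0, 1}, element (b, s, j) of the broadcast view reads bias[j].
float broadcast_read(const std::vector<float> &bias, std::size_t b, std::size_t s, std::size_t j) {
    const std::size_t strides[3] = {0, 0, 1}; // leading dims contribute nothing
    return bias[b * strides[0] + s * strides[1] + j * strides[2]];
}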
......@@ -2,12 +2,26 @@
namespace infinicore::nn {
const std::unordered_map<std::string, Parameter> &Module::state_dict() const {
static std::unordered_map<std::string, Parameter> result;
result.clear();
collect_all_parameters(result, "");
return result;
}
void Module::load_state_dict(const std::unordered_map<std::string, Tensor> &_state_dict) {
// Collect all parameters from this module and its submodules with their full hierarchical names
std::unordered_map<std::string, Parameter> all_params;
collect_all_parameters(all_params, "");
// For each parameter in this module hierarchy, load from the state dict
for (auto &[param_full_name, param] : all_params) {
// Look up the corresponding tensor in the input state dict using the full name
auto it = _state_dict.find(param_full_name);
if (it != _state_dict.end()) {
param->copy_from(it->second);
}
}
}
......@@ -25,4 +39,18 @@ Tensor Module::register_parameter(const std::string &name, Parameter param) {
return param;
}
void Module::collect_all_parameters(std::unordered_map<std::string, Parameter> &all_params, const std::string &prefix) const {
// Add direct parameters with the given prefix
for (const auto &[param_name, param] : parameters_) {
std::string full_name = prefix.empty() ? param_name : prefix + "." + param_name;
all_params[full_name] = param;
}
// Recursively collect parameters from submodules with extended prefix
for (const auto &[sub_name, submodule] : submodules_) {
std::string sub_prefix = prefix.empty() ? sub_name : prefix + "." + sub_name;
submodule->collect_all_parameters(all_params, sub_prefix);
}
}
} // namespace infinicore::nn
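As an illustration of the naming scheme collect_all_parameters produces (hypothetical module tree, assuming a `model` built with the registration helpers above and <iostream> included):

// A root module registering an RMSNorm as "norm" plus two Linear layers via
// register_modules(2, "layers", ...) yields state_dict() keys such as:
//   "norm.weight"
//   "layers.0.weight", "layers.0.bias"
//   "layers.1.weight", "layers.1.bias"
for (const auto &[name, param] : model.state_dict()) {
    std::cout << name << "\n";
}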
......@@ -5,6 +5,10 @@
#include <cstring>
namespace infinicore::nn {
Parameter::Parameter()
: Tensor(Tensor::empty({}, DataType::F32, Device(Device::Type::CPU, 0), false)) {
}
Parameter::Parameter(
const Shape &shape,
const DataType &dtype,
......
#include "infinicore/nn/rmsnorm.hpp"
#include "infinicore/ops.hpp"
#include <cmath>
#include <spdlog/spdlog.h>
#include <stdexcept>
namespace infinicore::nn {
RMSNorm::RMSNorm(size_t normalized_shape, double eps, const Device &device)
: normalized_shape_(normalized_shape),
eps_(eps) {
device_ = device;
// Initialize parameter using macro
INFINICORE_NN_PARAMETER_INIT(weight, ({normalized_shape}, DataType::F32, device));
// Initialize weight to ones (standard practice for RMSNorm)
auto ones_tensor = Tensor::ones({normalized_shape}, DataType::F32, device);
weight_->copy_from(ones_tensor);
spdlog::debug("Created RMSNorm module: normalized_shape={}, eps={}",
normalized_shape, eps);
}
Tensor RMSNorm::forward(const Tensor &x) const {
// Validate input shape - last dimension should match normalized_shape
auto input_shape = x->shape();
if (input_shape.empty()) {
throw std::invalid_argument("Input must have at least one dimension for RMSNorm");
}
if (input_shape.back() != normalized_shape_) {
throw std::invalid_argument(
"Input last dimension " + std::to_string(input_shape.back()) + " doesn't match normalized_shape " + std::to_string(normalized_shape_));
}
// Delegate to InfiniCore op (backed by InfiniRT/InfiniOP)
// y = RMSNorm(x, weight, eps)
return op::rms_norm(x, weight_, static_cast<float>(eps_));
}
std::string RMSNorm::extra_repr() const {
return "RMSNorm(normalized_shape=" + std::to_string(normalized_shape_) + ", eps=" + std::to_string(eps_) + ")";
}
} // namespace infinicore::nn
......@@ -86,6 +86,7 @@ target("infinicore-test")
add_files(os.projectdir().."/src/infinicore/context/*/*.cc")
add_files(os.projectdir().."/src/infinicore/tensor/*.cc")
add_files(os.projectdir().."/src/infinicore/ops/*/*.cc")
add_files(os.projectdir().."/src/infinicore/nn/*.cc")
add_files(os.projectdir().."/src/infinicore-test/*.cc")
......