Commit 289d4002 authored by Ceng23333

issue/545 nn::module::Rope


Signed-off-by: Ceng23333 <441651826@qq.com>
parent 2e5b2342
......@@ -23,6 +23,8 @@ public:
protected:
Tensor register_parameter(const std::string &name, Parameter param);
Tensor register_buffer(const std::string &name, Parameter buffer);
// Add an existing submodule to this module's hierarchy
// Template parameter M must be a type derived from Module
// Returns the submodule for convenience (allows method chaining)
......@@ -72,6 +74,7 @@ protected:
protected:
Device device_;
std::unordered_map<std::string, std::shared_ptr<Module>> submodules_;
std::unordered_map<std::string, Parameter> buffers_;
std::unordered_map<std::string, Parameter> parameters_;
private:
......@@ -134,4 +137,15 @@ private:
name##_ = infinicore::nn::Parameter args; \
this->register_parameter(#name, name##_)
// Declare a buffer member variable
#define INFINICORE_NN_BUFFER(name) \
infinicore::nn::Parameter name##_
// Initialize a buffer in constructor
// Usage: INFINICORE_NN_BUFFER_INIT(name, (shape, dtype, device))
// Example: INFINICORE_NN_BUFFER_INIT(cache, ({max_seq_len, head_dim}, DataType::F32, device))
#define INFINICORE_NN_BUFFER_INIT(name, args) \
name##_ = infinicore::nn::Parameter args; \
this->register_buffer(#name, name##_)
} // namespace infinicore::nn
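A minimal usage sketch for the buffer macros above, as seen from a derived module. The SinTable class, its member name, and the shape arguments are illustrative assumptions for this example only; they are not part of this commit.

class SinTable : public infinicore::nn::Module {
public:
    SinTable(size_t max_seq_len, size_t dim, const infinicore::Device &device) {
        device_ = device;
        // Registers the buffer under the name "table". Buffers are stored in
        // buffers_ rather than parameters_, so they do not appear in state_dict().
        INFINICORE_NN_BUFFER_INIT(table, ({max_seq_len, dim}, infinicore::DataType::F32, device));
    }

protected:
    // Declares the member variable table_ of type infinicore::nn::Parameter.
    INFINICORE_NN_BUFFER(table);
};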
#pragma once
#include "module.hpp"
#include "../context/context.hpp"
#include "../tensor.hpp"
#include <memory>
namespace infinicore::nn {
class RoPE : public Module {
public:
/**
* @brief RoPE algorithm type
*/
enum class Algo {
GPT_J = 0, // GPT-J style RoPE algorithm (interleaved: rotates dimension pairs (2j, 2j+1))
GPT_NEOX = 1, // GPT-NeoX style RoPE algorithm (half-split: rotates dimension pairs (j, j + head_dim/2))
};
/**
* @brief Construct a RoPE layer
*
* @param head_dim Dimension of each attention head (must be even)
* @param max_seq_len Maximum sequence length for pre-computed cache
* @param theta Base frequency for rotary embeddings (default: 10000.0)
* @param algo RoPE algorithm type (default: Algo::GPT_J)
* @param dtype Data type for sin/cos cache (default: DataType::F32)
* @param device Device to create the cache on
*/
RoPE(size_t head_dim,
size_t max_seq_len,
double theta = 10000.0,
Algo algo = Algo::GPT_J,
const DataType &dtype = DataType::F32,
const Device &device = Device());
/**
* @brief Forward pass: apply RoPE to a tensor
*
* @param x Input tensor of shape (..., head_dim) where ... is any number of dimensions
* @param pos Position IDs tensor of shape (*,) typically [seq_len] or [batch, seq_len]
* @return Rotated tensor with same shape as input
*
* Applies rotary position embeddings to the input tensor.
* For attention mechanisms, call this method separately for query and key tensors.
*
* Common input shapes:
* - [batch, num_heads, seq_len, head_dim]
* - [batch, seq_len, num_heads, head_dim]
* - [seq_len, head_dim]
*/
Tensor forward(const Tensor &x, const Tensor &pos) const;
// Module information
size_t head_dim() const { return head_dim_; }
size_t max_seq_len() const { return max_seq_len_; }
double theta() const { return theta_; }
Algo algo() const { return algo_; }
DataType dtype() const { return dtype_; }
// String representation
std::string extra_repr() const;
protected:
// Buffers (sin and cos cache tables) - not exposed in state_dict
INFINICORE_NN_BUFFER(sin_cache);
INFINICORE_NN_BUFFER(cos_cache);
private:
void initialize_cache();
size_t head_dim_; // Dimension of each attention head
size_t max_seq_len_; // Maximum sequence length
double theta_; // Base frequency for rotary embeddings
Algo algo_; // RoPE algorithm type
DataType dtype_; // Data type for cache tables
};
} // namespace infinicore::nn
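A brief usage sketch for the RoPE module declared above, mirroring the calls exercised by the new NNModuleTest::testModuleRoPE test. Includes and error handling are omitted, and the shapes follow the [seq_len, num_heads, head_dim] layout used in the tests.

using namespace infinicore;

// head_dim = 128, sin/cos cache pre-computed for up to 2048 positions.
nn::RoPE rope(128, 2048, 10000.0, nn::RoPE::Algo::GPT_J, DataType::F32, Device());

// Query and key activations of shape [seq_len, num_heads, head_dim].
auto q = Tensor::ones({32, 8, 128}, DataType::F32, Device());
auto k = Tensor::ones({32, 8, 128}, DataType::F32, Device());

// Position ids 0..31 for the 32 tokens.
std::vector<int32_t> pos_ids(32);
for (size_t i = 0; i < pos_ids.size(); ++i) {
    pos_ids[i] = static_cast<int32_t>(i);
}
auto pos = Tensor::from_blob(pos_ids.data(), {32}, DataType::I32, Device());

// Apply the rotation separately to query and key, as recommended in the forward() docs.
auto q_rot = rope.forward(q, pos);
auto k_rot = rope.forward(k, pos);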
......@@ -7,5 +7,6 @@
#include "ops/ones.hpp"
#include "ops/rearrange.hpp"
#include "ops/rms_norm.hpp"
#include "ops/rope.hpp"
#include "ops/silu.hpp"
#include "ops/swiglu.hpp"
#pragma once
#include "../device.hpp"
#include "../tensor.hpp"
#include "../nn/rope.hpp"
#include "common/op.hpp"
namespace infinicore::op {
class RoPE {
public:
using schema = void (*)(Tensor, const Tensor &, const Tensor &, const Tensor &, const Tensor &, infinicore::nn::RoPE::Algo);
static void execute(Tensor x_out, const Tensor &x, const Tensor &pos, const Tensor &sin_cache, const Tensor &cos_cache, infinicore::nn::RoPE::Algo algo);
static common::OpDispatcher<schema> &dispatcher();
};
// Low-level function: writes the rotated result into a caller-provided x_out
void rope_(Tensor x_out, const Tensor &x, const Tensor &pos, const Tensor &sin_cache, const Tensor &cos_cache, infinicore::nn::RoPE::Algo algo);
// Public API that uses infinicore::nn::RoPE::Algo
Tensor rope(const Tensor &x, const Tensor &pos, const Tensor &sin_cache, const Tensor &cos_cache, infinicore::nn::RoPE::Algo algo);
} // namespace infinicore::op
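A hedged sketch of driving the op-level API directly with caller-managed cache tensors; normally nn::RoPE owns and pre-computes these buffers, so the placeholder ones() caches below make the numerical result meaningless, but the call pattern is the same.

using namespace infinicore;

const size_t head_dim = 64, max_seq_len = 1024, seq_len = 16, n_head = 8;

// Placeholder caches of shape [max_seq_len, head_dim / 2]; nn::RoPE would fill
// these with pre-computed sin/cos values.
auto sin_cache = Tensor::ones({max_seq_len, head_dim / 2}, DataType::F32, Device());
auto cos_cache = Tensor::ones({max_seq_len, head_dim / 2}, DataType::F32, Device());

auto x = Tensor::ones({seq_len, n_head, head_dim}, DataType::F32, Device());
std::vector<int32_t> pos_ids(seq_len);
for (size_t i = 0; i < seq_len; ++i) {
    pos_ids[i] = static_cast<int32_t>(i);
}
auto pos = Tensor::from_blob(pos_ids.data(), {seq_len}, DataType::I32, Device());

// Allocating variant: creates the output tensor internally, then dispatches
// to the backend registered for the current device type.
auto y = op::rope(x, pos, sin_cache, cos_cache, nn::RoPE::Algo::GPT_NEOX);

// Out-parameter variant: writes into a caller-provided output tensor.
auto y_out = Tensor::empty(x->shape(), x->dtype(), x->device());
op::rope_(y_out, x, pos, sin_cache, cos_cache, nn::RoPE::Algo::GPT_NEOX);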
......@@ -141,12 +141,8 @@ ParsedArgs parseArgs(int argc, char *argv[]) {
int main(int argc, char *argv[]) {
try {
// Initialize spdlog for debugging
spdlog::set_level(spdlog::level::debug);
spdlog::info("Starting InfiniCore Memory Management Test Suite");
ParsedArgs args = parseArgs(argc, argv);
spdlog::debug("Arguments parsed successfully");
spdlog::info("Arguments parsed successfully");
std::cout << "==============================================\n"
<< "InfiniCore Memory Management Test Suite\n"
......@@ -156,31 +152,25 @@ int main(int argc, char *argv[]) {
<< "Iterations: " << args.iterations << "\n"
<< "==============================================" << std::endl;
spdlog::debug("About to initialize InfiniCore context");
spdlog::info("About to initialize InfiniCore context");
// Initialize InfiniCore context
infinicore::context::setDevice(infinicore::Device(static_cast<infinicore::Device::Type>(args.device_type), 0));
spdlog::debug("InfiniCore context initialized successfully");
spdlog::info("InfiniCore context initialized successfully");
spdlog::debug("Creating test runner");
spdlog::info("Creating test runner");
// Create test runner
infinicore::test::InfiniCoreTestRunner runner;
spdlog::debug("Test runner created successfully");
spdlog::info("Test runner created successfully");
// Add tests based on arguments
if (args.run_basic) {
spdlog::debug("Adding BasicMemoryTest");
runner.addTest(std::make_unique<infinicore::test::BasicMemoryTest>());
spdlog::debug("BasicMemoryTest added successfully");
spdlog::debug("Adding TensorDestructorTest");
runner.addTest(std::make_unique<infinicore::test::TensorDestructorTest>());
spdlog::debug("TensorDestructorTest added successfully");
}
if (args.run_module) {
spdlog::debug("Adding NNModuleTest");
runner.addTest(std::make_unique<infinicore::test::NNModuleTest>());
spdlog::debug("NNModuleTest added successfully");
}
if (args.run_concurrency) {
......@@ -203,10 +193,10 @@ int main(int argc, char *argv[]) {
runner.addTest(std::make_unique<infinicore::test::StressTest>());
}
spdlog::debug("About to run all tests");
spdlog::info("About to run all tests");
// Run all tests
auto results = runner.runAllTests();
spdlog::debug("All tests completed");
spdlog::info("All tests completed");
// Count results and collect failed tests
size_t passed = 0, failed = 0;
......
......@@ -7,7 +7,9 @@ namespace infinicore::test {
TestResult NNModuleTest::testBasicModuleCreation() {
return measureTime("BasicModuleOperations", [this]() {
try {
spdlog::info("=== Testing Basic Module Operations ===");
spdlog::info("==========================================");
spdlog::info("Testing Basic Module Operations");
spdlog::info("==========================================");
// Test 1a: Module creation and parameter registration
spdlog::info("Test 1a: Module creation and parameter registration");
......@@ -117,7 +119,9 @@ TestResult NNModuleTest::testBasicModuleCreation() {
TestResult NNModuleTest::testLoadStateDict() {
return measureTime("AdvancedLoadStateDict", [this]() {
try {
spdlog::info("=== Testing Advanced load_state_dict with Hierarchical Modules ===");
spdlog::info("==========================================");
spdlog::info("Testing Advanced load_state_dict with Hierarchical Modules");
spdlog::info("==========================================");
// Test: Deep nesting (2-level hierarchy)
spdlog::info("Test 4: Testing load_state_dict with 2-level deep nesting");
......@@ -304,13 +308,11 @@ TestResult NNModuleTest::testModuleHierarchy() {
// layer1.sublayer.weight, layer1.sublayer.bias,
// layer1.layer2.sublayer.weight, layer1.layer2.sublayer.bias
if (root_state_dict.size() < 6) {
std::cout << "Error: Expected at least 6 parameters in hierarchy, got "
<< root_state_dict.size() << std::endl;
spdlog::error("Error: Expected at least 6 parameters in hierarchy, got {}", root_state_dict.size());
return false;
}
std::cout << "Module hierarchy test passed. Root state dict has "
<< root_state_dict.size() << " parameters" << std::endl;
spdlog::info("Module hierarchy test passed. Root state dict has {} parameters", root_state_dict.size());
// Print the hierarchy
std::cout << "Module hierarchy:" << std::endl;
......@@ -350,7 +352,7 @@ TestResult NNModuleTest::testModuleHierarchy() {
return true;
} catch (const std::exception &e) {
std::cout << "Exception in testModuleHierarchy: " << e.what() << std::endl;
spdlog::error("Exception in testModuleHierarchy: {}", e.what());
return false;
}
});
......@@ -360,6 +362,9 @@ TestResult NNModuleTest::testModuleHierarchy() {
TestResult NNModuleTest::testParameterLoading() {
return measureTime("ParameterLoading", [this]() {
try {
spdlog::info("==========================================");
spdlog::info("Testing Parameter loading from blob");
spdlog::info("==========================================");
MockLinearModule module(3, 2, infinicore::Device());
// Create test data
......@@ -370,19 +375,19 @@ TestResult NNModuleTest::testParameterLoading() {
module.load_parameter_from_blob("weight", weight_data.data());
module.load_parameter_from_blob("bias", bias_data.data());
std::cout << "Successfully loaded parameters from blob data" << std::endl;
spdlog::info("Successfully loaded parameters from blob data");
// Verify parameters exist
auto state_dict = module.state_dict();
if (state_dict.find("weight") == state_dict.end() || state_dict.find("bias") == state_dict.end()) {
std::cout << "Error: Parameters not found after loading" << std::endl;
spdlog::error("Error: Parameters not found after loading");
return false;
}
std::cout << "Parameter loading test passed" << std::endl;
spdlog::info("Parameter loading test passed");
return true;
} catch (const std::exception &e) {
std::cout << "Exception in testParameterLoading: " << e.what() << std::endl;
spdlog::error("Exception in testParameterLoading: {}", e.what());
return false;
}
});
......@@ -393,7 +398,9 @@ TestResult NNModuleTest::testModuleLinear() {
return measureTime("ModuleLinear", [this]() {
try {
// Test with bias
spdlog::info("==========================================");
spdlog::info("Testing Linear module with bias (8->4 features)");
spdlog::info("==========================================");
infinicore::nn::Linear m1(8, 4, true);
auto sd1 = m1.state_dict();
if (sd1.find("weight") == sd1.end()) {
......@@ -679,7 +686,9 @@ TestResult NNModuleTest::testModuleLinear() {
TestResult NNModuleTest::testModuleEmbedding() {
return measureTime("ModuleEmbedding", [this]() {
try {
spdlog::info("==========================================");
spdlog::info("Testing Embedding module implementation");
spdlog::info("==========================================");
// Test 1: Basic embedding creation
spdlog::info("Test 1: Basic embedding creation (vocab=100, dim=64)");
......@@ -830,7 +839,9 @@ TestResult NNModuleTest::testModuleEmbedding() {
TestResult NNModuleTest::testModuleRMSNorm() {
return measureTime("ModuleRMSNorm", [this]() {
try {
spdlog::info("==========================================");
spdlog::info("Testing RMSNorm module implementation");
spdlog::info("==========================================");
// Test 1: Basic RMSNorm creation
spdlog::info("Test 1: Basic RMSNorm creation (hidden_size=768)");
......@@ -923,8 +934,25 @@ TestResult NNModuleTest::testModuleRMSNorm() {
spdlog::debug("extra_repr test passed");
// Test 7: Different hidden sizes
spdlog::info("Test 7: Testing different hidden sizes");
// Test 7: Input validation - normalized_shape mismatch (op layer handles this)
spdlog::info("Test 7: Testing input validation - normalized_shape mismatch");
auto input_wrong_shape = infinicore::Tensor::ones({4, 512}, infinicore::DataType::F32, infinicore::Device()); // normalized_shape=512, expected 768
try {
norm1.forward(input_wrong_shape);
spdlog::error("Should have thrown exception for normalized_shape mismatch");
return false;
} catch (const std::exception &e) {
spdlog::debug("Correctly caught exception for normalized_shape mismatch (handled by op layer): {}", e.what());
} catch (...) {
spdlog::error("Caught unexpected exception type");
return false;
}
spdlog::debug("Normalized_shape mismatch validation test passed");
// Test 8: Different hidden sizes
spdlog::info("Test 8: Testing different hidden sizes");
infinicore::nn::RMSNorm norm_small(128, 1e-5);
infinicore::nn::RMSNorm norm_large(4096);
......@@ -956,11 +984,316 @@ TestResult NNModuleTest::testModuleRMSNorm() {
});
}
// Test 7.5: RoPE module test
TestResult NNModuleTest::testModuleRoPE() {
return measureTime("ModuleRoPE", [this]() {
try {
spdlog::info("==========================================");
spdlog::info("Testing RoPE module implementation");
spdlog::info("==========================================");
// Test 1: Basic RoPE creation
spdlog::info("Test 1: Basic RoPE creation (head_dim=128, max_seq_len=2048)");
infinicore::nn::RoPE rope1(128, 2048);
auto state1 = rope1.state_dict();
if (rope1.head_dim() != 128) {
spdlog::error("head_dim mismatch. Expected 128, got {}", rope1.head_dim());
return false;
}
if (rope1.max_seq_len() != 2048) {
spdlog::error("max_seq_len mismatch. Expected 2048, got {}", rope1.max_seq_len());
return false;
}
spdlog::debug("Basic RoPE creation passed");
// Test 2: Forward pass - 3D input [seq_len, n_head, head_dim]
spdlog::info("Test 2: Forward pass with 3D input [seq_len, n_head, head_dim]");
auto x_3d = infinicore::Tensor::ones({32, 32, 128}, infinicore::DataType::F32, infinicore::Device());
// Create position tensor [0, 1, 2, ..., 31]
std::vector<int32_t> pos_data(32);
for (size_t i = 0; i < 32; i++) {
pos_data[i] = static_cast<int32_t>(i);
}
auto pos = infinicore::Tensor::from_blob(pos_data.data(), {32}, infinicore::DataType::I32, infinicore::Device());
auto x_out = rope1.forward(x_3d, pos);
if (x_out->shape() != std::vector<size_t>({32, 32, 128})) {
spdlog::error("3D output shape mismatch. Expected {{32, 32, 128}}");
return false;
}
spdlog::debug("3D forward pass passed. Output shape: {{32, 32, 128}}");
// Test 3: Different algorithms
spdlog::info("Test 3: Testing different algorithms");
infinicore::nn::RoPE rope_gptj(64, 1024, 10000.0, infinicore::nn::RoPE::Algo::GPT_J);
infinicore::nn::RoPE rope_gptneox(64, 1024, 10000.0, infinicore::nn::RoPE::Algo::GPT_NEOX);
if (rope_gptj.algo() != infinicore::nn::RoPE::Algo::GPT_J) {
spdlog::error("GPT_J algorithm not set correctly");
return false;
}
if (rope_gptneox.algo() != infinicore::nn::RoPE::Algo::GPT_NEOX) {
spdlog::error("GPT_NEOX algorithm not set correctly");
return false;
}
auto x_test = infinicore::Tensor::ones({10, 32, 64}, infinicore::DataType::F32, infinicore::Device());
std::vector<int32_t> pos_test_data(10);
for (size_t i = 0; i < 10; i++) {
pos_test_data[i] = static_cast<int32_t>(i);
}
auto pos_test = infinicore::Tensor::from_blob(pos_test_data.data(), {10}, infinicore::DataType::I32, infinicore::Device());
auto x_gptj = rope_gptj.forward(x_test, pos_test);
auto x_neox = rope_gptneox.forward(x_test, pos_test);
if (x_gptj->shape() != x_test->shape()) {
spdlog::error("GPT_J forward pass shape mismatch");
return false;
}
if (x_neox->shape() != x_test->shape()) {
spdlog::error("GPT_NEOX forward pass shape mismatch");
return false;
}
spdlog::debug("Different algorithms test passed");
// Test 4: Different theta values
spdlog::info("Test 4: Testing different theta values");
infinicore::nn::RoPE rope_theta1(128, 2048, 1e5);
infinicore::nn::RoPE rope_theta2(128, 2048, 1e4);
if (rope_theta1.theta() != 1e5) {
spdlog::error("theta mismatch. Expected 1e5, got {}", rope_theta1.theta());
return false;
}
if (rope_theta2.theta() != 1e4) {
spdlog::error("theta mismatch. Expected 1e4, got {}", rope_theta2.theta());
return false;
}
spdlog::debug("Different theta values test passed");
// Test 5: load_state_dict
std::unordered_map<std::string, infinicore::Tensor> new_state;
rope1.load_state_dict(new_state);
spdlog::debug("load_state_dict for RoPE passed (no parameters to load)");
// Test 6: extra_repr
spdlog::info("Test 6: Testing extra_repr");
std::string repr = rope1.extra_repr();
spdlog::debug("RoPE repr: {}", repr);
if (repr.find("head_dim=128") == std::string::npos) {
spdlog::error("extra_repr should contain head_dim");
return false;
}
if (repr.find("max_seq_len=2048") == std::string::npos) {
spdlog::error("extra_repr should contain max_seq_len");
return false;
}
if (repr.find("theta=") == std::string::npos) {
spdlog::error("extra_repr should contain theta");
return false;
}
spdlog::debug("extra_repr test passed");
// Test 7: Different head dimensions
spdlog::info("Test 7: Testing different head dimensions");
infinicore::nn::RoPE rope_small(64, 1024);
infinicore::nn::RoPE rope_large(256, 4096);
auto x_small = infinicore::Tensor::ones({10, 32, 64}, infinicore::DataType::F32, infinicore::Device());
std::vector<int32_t> pos_small_data(10);
for (size_t i = 0; i < 10; i++) {
pos_small_data[i] = static_cast<int32_t>(i);
}
auto pos_small = infinicore::Tensor::from_blob(pos_small_data.data(), {10}, infinicore::DataType::I32, infinicore::Device());
auto x_small_out = rope_small.forward(x_small, pos_small);
if (x_small_out->shape() != std::vector<size_t>({10, 32, 64})) {
spdlog::error("Small RoPE output shape mismatch");
return false;
}
auto x_large = infinicore::Tensor::ones({20, 32, 256}, infinicore::DataType::F32, infinicore::Device());
std::vector<int32_t> pos_large_data(20);
for (size_t i = 0; i < 20; i++) {
pos_large_data[i] = static_cast<int32_t>(i);
}
auto pos_large = infinicore::Tensor::from_blob(pos_large_data.data(), {20}, infinicore::DataType::I32, infinicore::Device());
auto x_large_out = rope_large.forward(x_large, pos_large);
if (x_large_out->shape() != std::vector<size_t>({20, 32, 256})) {
spdlog::error("Large RoPE output shape mismatch");
return false;
}
spdlog::debug("Different head dimensions test passed");
// Test 8: Invalid head_dim (odd number)
spdlog::info("Test 8: Testing invalid head_dim (odd number)");
try {
infinicore::nn::RoPE rope_invalid(127, 2048);
spdlog::error("Should have thrown exception for odd head_dim");
return false;
} catch (const std::invalid_argument &e) {
spdlog::debug("Correctly caught exception for odd head_dim: {}", e.what());
} catch (...) {
spdlog::error("Caught unexpected exception type");
return false;
}
spdlog::debug("Invalid head_dim test passed");
// Test 9: Input validation - empty tensor (op layer handles this)
spdlog::info("Test 9: Testing input validation - empty tensor");
auto x_empty = infinicore::Tensor::ones({}, infinicore::DataType::F32, infinicore::Device());
std::vector<int32_t> pos_empty_data(1);
pos_empty_data[0] = 0;
auto pos_empty = infinicore::Tensor::from_blob(pos_empty_data.data(), {1}, infinicore::DataType::I32, infinicore::Device());
try {
rope1.forward(x_empty, pos_empty);
spdlog::error("Should have thrown exception for empty input tensor");
return false;
} catch (const std::exception &e) {
spdlog::debug("Correctly caught exception for empty input (handled by op layer): {}", e.what());
} catch (...) {
spdlog::error("Caught unexpected exception type");
return false;
}
spdlog::debug("Empty tensor validation test passed");
// Test 10: Input validation - head_dim mismatch (op layer handles this)
spdlog::info("Test 10: Testing input validation - head_dim mismatch");
auto x_wrong_dim = infinicore::Tensor::ones({32, 32, 64}, infinicore::DataType::F32, infinicore::Device()); // head_dim=64, expected 128
std::vector<int32_t> pos_wrong_data(32);
for (size_t i = 0; i < 32; i++) {
pos_wrong_data[i] = static_cast<int32_t>(i);
}
auto pos_wrong = infinicore::Tensor::from_blob(pos_wrong_data.data(), {32}, infinicore::DataType::I32, infinicore::Device());
try {
rope1.forward(x_wrong_dim, pos_wrong);
spdlog::error("Should have thrown exception for head_dim mismatch");
return false;
} catch (const std::exception &e) {
spdlog::debug("Correctly caught exception for head_dim mismatch (handled by op layer): {}", e.what());
} catch (...) {
spdlog::error("Caught unexpected exception type");
return false;
}
spdlog::debug("Head_dim mismatch validation test passed");
// Test 11: Different input shapes (from reference test cases)
spdlog::info("Test 11: Testing different input shapes");
// Test shape (1, 32, 128) - single sequence
auto x_single = infinicore::Tensor::ones({1, 32, 128}, infinicore::DataType::F32, infinicore::Device());
std::vector<int32_t> pos_single_data(1);
pos_single_data[0] = 0;
auto pos_single = infinicore::Tensor::from_blob(pos_single_data.data(), {1}, infinicore::DataType::I32, infinicore::Device());
auto x_single_out = rope1.forward(x_single, pos_single);
if (x_single_out->shape() != std::vector<size_t>({1, 32, 128})) {
spdlog::error("Single sequence output shape mismatch");
return false;
}
// Test shape (10, 32, 64) - different head_dim
auto rope_64 = infinicore::nn::RoPE(64, 1024);
auto x_64 = infinicore::Tensor::ones({10, 32, 64}, infinicore::DataType::F32, infinicore::Device());
std::vector<int32_t> pos_64_data(10);
for (size_t i = 0; i < 10; i++) {
pos_64_data[i] = static_cast<int32_t>(i);
}
auto pos_64 = infinicore::Tensor::from_blob(pos_64_data.data(), {10}, infinicore::DataType::I32, infinicore::Device());
auto x_64_out = rope_64.forward(x_64, pos_64);
if (x_64_out->shape() != std::vector<size_t>({10, 32, 64})) {
spdlog::error("Shape (10, 32, 64) output mismatch");
return false;
}
// Test shape (4, 1, 32) - single head
auto rope_32 = infinicore::nn::RoPE(32, 1024);
auto x_1head = infinicore::Tensor::ones({4, 1, 32}, infinicore::DataType::F32, infinicore::Device());
std::vector<int32_t> pos_1head_data(4);
for (size_t i = 0; i < 4; i++) {
pos_1head_data[i] = static_cast<int32_t>(i);
}
auto pos_1head = infinicore::Tensor::from_blob(pos_1head_data.data(), {4}, infinicore::DataType::I32, infinicore::Device());
auto x_1head_out = rope_32.forward(x_1head, pos_1head);
if (x_1head_out->shape() != std::vector<size_t>({4, 1, 32})) {
spdlog::error("Shape (4, 1, 32) output mismatch");
return false;
}
spdlog::debug("Different input shapes test passed");
// Test 12: Position tensor validation
spdlog::info("Test 12: Testing position tensor edge cases");
// Test with seq_len less than max_seq_len
auto x_short = infinicore::Tensor::ones({10, 32, 128}, infinicore::DataType::F32, infinicore::Device());
std::vector<int32_t> pos_short_data(10);
for (size_t i = 0; i < 10; i++) {
pos_short_data[i] = static_cast<int32_t>(i);
}
auto pos_short = infinicore::Tensor::from_blob(pos_short_data.data(), {10}, infinicore::DataType::I32, infinicore::Device());
auto x_short_out = rope1.forward(x_short, pos_short);
if (x_short_out->shape() != std::vector<size_t>({10, 32, 128})) {
spdlog::error("Short sequence output shape mismatch");
return false;
}
spdlog::debug("Position tensor edge cases test passed");
// Test 13: Test that outputs are on the same device as inputs
spdlog::info("Test 13: Testing device consistency");
auto device = x_3d->device();
if (x_out->device() != device) {
spdlog::error("Output tensor not on the same device as input");
return false;
}
spdlog::debug("Device consistency test passed");
spdlog::info("All RoPE module tests passed!");
return true;
} catch (const std::exception &e) {
spdlog::error("Exception in testModuleRoPE: {}", e.what());
return false;
}
});
}
// Test 8: Dtype assertion test
TestResult NNModuleTest::testDtypeAssertion() {
return measureTime("DtypeAssertionTest", [this]() {
try {
spdlog::info("==========================================");
spdlog::info("Testing dtype assertions when loading parameters");
spdlog::info("==========================================");
// Test 1: Successful load with matching dtype
spdlog::info("Test 1: Successful load with matching dtype (F32)");
......@@ -1382,6 +1715,7 @@ TestResult NNModuleTest::run() {
results.push_back(testModuleLinear()); // Linear module comprehensive test
results.push_back(testModuleEmbedding()); // Embedding module test
results.push_back(testModuleRMSNorm()); // RMSNorm module test
results.push_back(testModuleRoPE()); // RoPE module test
results.push_back(testDtypeAssertion()); // Dtype assertion test
results.push_back(testTinyLlamaConstruction()); // Comprehensive: TinyLlama model test
......
......@@ -6,6 +6,7 @@
#include "infinicore/nn/module.hpp"
#include "infinicore/nn/parameter.hpp"
#include "infinicore/nn/rmsnorm.hpp"
#include "infinicore/nn/rope.hpp"
#include "test_runner.h"
#include <algorithm>
#include <cmath>
......@@ -82,6 +83,7 @@ private:
TestResult testModuleLinear(); // Comprehensive Linear module test
TestResult testModuleEmbedding(); // Embedding module test
TestResult testModuleRMSNorm(); // RMSNorm module test
TestResult testModuleRoPE(); // RoPE module test
TestResult testDtypeAssertion(); // Test dtype assertions when loading parameters
TestResult testTinyLlamaConstruction(); // Comprehensive: construction + weight loading + validation
};
......
......@@ -55,6 +55,11 @@ Tensor Module::register_parameter(const std::string &name, Parameter param) {
return param;
}
Tensor Module::register_buffer(const std::string &name, Parameter buffer) {
buffers_[name] = buffer;
return buffer;
}
void Module::collect_all_parameters(std::unordered_map<std::string, Parameter> &all_params, const std::string &prefix) const {
// Add direct parameters with the given prefix
for (const auto &[param_name, param] : parameters_) {
......
......@@ -25,15 +25,8 @@ RMSNorm::RMSNorm(size_t normalized_shape, double eps, const DataType &dtype, con
}
Tensor RMSNorm::forward(const Tensor &x) const {
// Validate input shape - last dimension should match normalized_shape
auto input_shape = x->shape();
if (input_shape.empty() || input_shape.back() != normalized_shape_) {
throw std::invalid_argument(
"Input last dimension " + std::to_string(input_shape.back()) + " doesn't match normalized_shape " + std::to_string(normalized_shape_));
}
// Delegate to InfiniCore op (backed by InfiniRT/InfiniOP)
// y = RMSNorm(x, weight, eps)
// Validation is handled by the op layer
return op::rms_norm(x, weight_, static_cast<float>(eps_));
}
......
#include "infinicore/nn/rope.hpp"
#include "../utils.hpp"
#include "infinicore/ops.hpp"
#include <algorithm>
#include <cmath>
#include <functional>
#include <spdlog/spdlog.h>
#include <stdexcept>
namespace infinicore::nn {
RoPE::RoPE(size_t head_dim,
size_t max_seq_len,
double theta,
Algo algo,
const DataType &dtype,
const Device &device)
: head_dim_(head_dim),
max_seq_len_(max_seq_len),
theta_(theta),
algo_(algo),
dtype_(dtype) {
if (head_dim % 2 != 0) {
throw std::invalid_argument("head_dim must be even for RoPE, got " + std::to_string(head_dim));
}
device_ = device;
// Initialize cache tables
initialize_cache();
spdlog::debug("Created RoPE module: head_dim={}, max_seq_len={}, theta={}, algo={}, dtype={}",
head_dim, max_seq_len, theta, static_cast<int>(algo), static_cast<int>(dtype_));
}
void RoPE::initialize_cache() {
size_t cache_dim = head_dim_ / 2;
// Create sin and cos cache tables: [max_seq_len, cache_dim]
INFINICORE_NN_BUFFER_INIT(sin_cache, ({max_seq_len_, cache_dim}, dtype_, device_));
INFINICORE_NN_BUFFER_INIT(cos_cache, ({max_seq_len_, cache_dim}, dtype_, device_));
// Pre-compute sin and cos values
// The frequency calculation differs based on algorithm:
// - GPT_J: pairs are (2j, 2j+1) for cache entry j, frequency for dimension 2j is theta^(-2j/head_dim)
// - GPT_NEOX: pairs are (j, j+head_dim/2) for cache entry j, frequency for dimension j is theta^(-j/head_dim)
// Compute on CPU first, then copy to device
auto cpu_device = Device(Device::Type::CPU, 0);
// Allocate CPU buffers
std::vector<float> sin_data(max_seq_len_ * cache_dim);
std::vector<float> cos_data(max_seq_len_ * cache_dim);
for (size_t pos = 0; pos < max_seq_len_; pos++) {
for (size_t j = 0; j < cache_dim; j++) {
// Compute inverse frequency based on algorithm
double inv_freq;
if (algo_ == Algo::GPT_J) {
// GPT_J: pairs are (2j, 2j+1) for cache entry j
// Frequency for pair j: theta^(-2j/head_dim)
inv_freq = 1.0 / std::pow(theta_, 2.0 * static_cast<double>(j) / static_cast<double>(head_dim_));
} else if (algo_ == Algo::GPT_NEOX) {
// GPT_NEOX: pairs are (j, j+head_dim/2) for cache entry j
// Frequency for pair j (corresponding to dimension j): theta^(-j/head_dim)
inv_freq = 1.0 / std::pow(theta_, static_cast<double>(j) / static_cast<double>(head_dim_));
} else {
throw std::runtime_error("Unsupported RoPE algorithm: " + std::to_string(static_cast<int>(algo_)));
}
// Compute angle: position * inverse_frequency
double angle = static_cast<double>(pos) * inv_freq;
// Compute sin and cos
sin_data[pos * cache_dim + j] = static_cast<float>(std::sin(angle));
cos_data[pos * cache_dim + j] = static_cast<float>(std::cos(angle));
}
}
// Create CPU tensors and copy data
auto sin_cpu = Tensor::from_blob(sin_data.data(), {max_seq_len_, cache_dim}, DataType::F32, cpu_device);
auto cos_cpu = Tensor::from_blob(cos_data.data(), {max_seq_len_, cache_dim}, DataType::F32, cpu_device);
// Copy to device
// Note: Cache is created with dtype_, but we compute in F32 for precision.
// If dtype_ != F32, copy_from will fail. For now, we only support F32 cache.
// TODO: Add dtype conversion support when cast operation is available
if (dtype_ != DataType::F32) {
throw std::runtime_error(
"RoPE cache dtype conversion not yet supported. Please use DataType::F32 for cache. "
"Requested dtype: "
+ std::to_string(static_cast<int>(dtype_)));
}
// copy_from handles cross-device copying automatically
// Direct copy from CPU to target device avoids double copying
sin_cache_->copy_from(sin_cpu);
cos_cache_->copy_from(cos_cpu);
}
Tensor RoPE::forward(const Tensor &x, const Tensor &pos) const {
// Delegate to InfiniCore op (backed by InfiniRT/InfiniOP)
// Validation is handled by the op layer
return op::rope(x, pos, sin_cache_, cos_cache_, algo_);
}
std::string RoPE::extra_repr() const {
std::string algo_str = (algo_ == Algo::GPT_J) ? "GPT_J" : "GPT_NEOX";
return "RoPE(head_dim=" + std::to_string(head_dim_) + ", max_seq_len=" + std::to_string(max_seq_len_) + ", theta=" + std::to_string(theta_) + ", algo=" + algo_str + ", dtype=" + std::to_string(static_cast<int>(dtype_)) + ")";
}
} // namespace infinicore::nn
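For reference, a compact restatement of what initialize_cache pre-computes above, with p the position index, j the cache column, and d = head_dim:

\theta_j^{\text{GPT-J}} = \theta^{-2j/d}, \qquad \theta_j^{\text{GPT-NeoX}} = \theta^{-j/d}

\text{sin\_cache}[p][j] = \sin(p \, \theta_j), \qquad \text{cos\_cache}[p][j] = \cos(p \, \theta_j), \qquad 0 \le j < d/2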
#include "infinicore/ops/rope.hpp"
#include "infinicore/context/context.hpp"
#include <stdexcept>
namespace infinicore::op {
common::OpDispatcher<RoPE::schema> &RoPE::dispatcher() {
static common::OpDispatcher<RoPE::schema> dispatcher_;
return dispatcher_;
};
void RoPE::execute(Tensor x_out, const Tensor &x, const Tensor &pos, const Tensor &sin_cache, const Tensor &cos_cache, infinicore::nn::RoPE::Algo algo) {
auto device_type = context::getDevice().getType();
auto func = dispatcher().lookup(device_type);
if (func == nullptr) {
throw std::runtime_error("No RoPE implementation found for device type: " + std::to_string(static_cast<int>(device_type)));
}
func(x_out, x, pos, sin_cache, cos_cache, algo);
}
void rope_(Tensor x_out, const Tensor &x, const Tensor &pos, const Tensor &sin_cache, const Tensor &cos_cache, infinicore::nn::RoPE::Algo algo) {
RoPE::execute(x_out, x, pos, sin_cache, cos_cache, algo);
}
Tensor rope(const Tensor &x, const Tensor &pos, const Tensor &sin_cache, const Tensor &cos_cache, infinicore::nn::RoPE::Algo algo) {
Shape shape = x->shape();
auto x_out = Tensor::empty(shape, x->dtype(), x->device());
rope_(x_out, x, pos, sin_cache, cos_cache, algo);
return x_out;
}
} // namespace infinicore::op
#include "../../utils.hpp"
#include "infinicore/common/hash.hpp"
#include "infinicore/ops/common/cache.hpp"
#include "infinicore/ops/rope.hpp"
#include <infiniop.h>
namespace infinicore::op::rope_impl::infiniop {
thread_local common::OpCache<size_t, infiniopRoPEDescriptor_t> caches(
100, // capacity
[](infiniopRoPEDescriptor_t &desc) {
if (desc != nullptr) {
INFINICORE_CHECK_ERROR(infiniopDestroyRoPEDescriptor(desc));
desc = nullptr;
}
});
void calculate(Tensor x_out, const Tensor &x, const Tensor &pos, const Tensor &sin_cache, const Tensor &cos_cache, infinicore::nn::RoPE::Algo algo) {
// Convert infinicore::nn::RoPE::Algo to infiniopRoPEAlgo_t
infiniopRoPEAlgo_t infiniop_algo;
switch (algo) {
case infinicore::nn::RoPE::Algo::GPT_J:
infiniop_algo = INFINIOP_ROPE_ALGO_GPT_J;
break;
case infinicore::nn::RoPE::Algo::GPT_NEOX:
infiniop_algo = INFINIOP_ROPE_ALGO_GPT_NEOX;
break;
default:
throw std::runtime_error("Unsupported RoPE algorithm: " + std::to_string(static_cast<int>(algo)));
}
// Create hash key for descriptor caching
size_t key = hash_combine(x_out, x, pos, sin_cache, cos_cache);
hash_combine(key, std::hash<int>()(static_cast<int>(infiniop_algo)));
auto device_type = context::getDevice().getType();
auto device_index = context::getDevice().getIndex();
auto &cache = caches.getCache(device_type, device_index);
auto desc_opt = cache.get(key);
infiniopRoPEDescriptor_t desc = nullptr;
if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreateRoPEDescriptor(
context::getInfiniopHandle(), &desc,
x_out->desc(), x->desc(),
pos->desc(), sin_cache->desc(), cos_cache->desc(),
infiniop_algo));
cache.put(key, desc);
} else {
desc = *desc_opt;
}
size_t workspace_size = 0;
INFINICORE_CHECK_ERROR(infiniopGetRoPEWorkspaceSize(desc, &workspace_size));
std::shared_ptr<Memory> workspace = context::allocateMemory(workspace_size);
// InfiniOP reads from x and writes to x_out (handles copying internally)
INFINICORE_CHECK_ERROR(infiniopRoPE(
desc, workspace->data(), workspace_size,
x_out->data(), x->data(), pos->data(),
sin_cache->data(), cos_cache->data(), context::getStream()));
}
static bool registered = []() {
RoPE::dispatcher().registerAll(&calculate, false);
return true;
}();
} // namespace infinicore::op::rope_impl::infiniop
......@@ -9,7 +9,7 @@
inline struct SpdlogInitializer {
SpdlogInitializer() {
if (!std::getenv("INFINICORE_LOG_LEVEL")) {
spdlog::set_level(spdlog::level::off);
spdlog::set_level(spdlog::level::info);
} else {
spdlog::cfg::load_env_levels("INFINICORE_LOG_LEVEL");
}
......@@ -21,9 +21,9 @@ inline struct SpdlogInitializer {
#define INFINICORE_CHECK_ERROR(call) \
do { \
spdlog::info("Entering `" #call "` at `" __FILE__ ":" STRINGIZE(__LINE__) "`."); \
spdlog::debug("Entering `" #call "` at `" __FILE__ ":" STRINGIZE(__LINE__) "`."); \
infiniStatus_t ret = (call); \
spdlog::info("Exiting `" #call "` at `" __FILE__ ":" STRINGIZE(__LINE__) "`."); \
spdlog::debug("Exiting `" #call "` at `" __FILE__ ":" STRINGIZE(__LINE__) "`."); \
if (ret != INFINI_STATUS_SUCCESS) { \
throw std::runtime_error(#call " failed with error: " + std::string(infini_status_string(ret))); \
} \
......