Unverified Commit 53f4bc1d authored by PanZezhong1725's avatar PanZezhong1725 Committed by GitHub
Browse files

Merge pull request #694 from InfiniTensor/issue/682

issue/682 Parameter supports TP, fix copy_from
parents 1f573217 ee43bda6
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
namespace infinicore { namespace infinicore {
namespace context { namespace context {
void setDevice(Device device); void setDevice(Device device, bool force_cpu = false);
Device getDevice(); Device getDevice();
size_t getDeviceCount(Device::Type type); size_t getDeviceCount(Device::Type type);
......
...@@ -39,6 +39,10 @@ public: ...@@ -39,6 +39,10 @@ public:
bool operator!=(const Device &other) const; bool operator!=(const Device &other) const;
inline static Device cpu() {
return Device(Type::CPU, 0);
}
private: private:
Type type_; Type type_;
......
...@@ -9,8 +9,19 @@ public: ...@@ -9,8 +9,19 @@ public:
Parameter(const Shape &shape, Parameter(const Shape &shape,
const DataType &dtype, const DataType &dtype,
const Device &device); const Device &device,
Size tp_dim = 0,
Size tp_rank = 0,
Size tp_size = 1);
void load_blob(const void *data); void load_blob(const void *data);
void load(const Tensor &tensor);
protected:
// Tensor parallel configs
Size tp_dim_; // dimension partitioned
Size tp_rank_; // rank of this partition among tp group
Size tp_size_; // total number of partitions
}; };
} // namespace infinicore::nn } // namespace infinicore::nn
...@@ -23,13 +23,13 @@ def get_device_count(device_type): ...@@ -23,13 +23,13 @@ def get_device_count(device_type):
return _infinicore.get_device_count(infinicore.device(device_type)._underlying.type) return _infinicore.get_device_count(infinicore.device(device_type)._underlying.type)
def set_device(device): def set_device(device, force_cpu=False):
"""Set the current active device. """Set the current active device.
Args: Args:
device: The device to set as active device: The device to set as active
""" """
_infinicore.set_device(device._underlying) _infinicore.set_device(device._underlying, force_cpu)
def sync_stream(): def sync_stream():
......
...@@ -709,9 +709,6 @@ TestResult PerformanceTest::testMemoryCopyPerformance() { ...@@ -709,9 +709,6 @@ TestResult PerformanceTest::testMemoryCopyPerformance() {
return false; return false;
} }
// Initialize source data
std::memset(src_memory->data(), 0xAB, data_size);
auto start = std::chrono::high_resolution_clock::now(); auto start = std::chrono::high_resolution_clock::now();
// Perform memory copies // Perform memory copies
......
...@@ -3,6 +3,20 @@ ...@@ -3,6 +3,20 @@
namespace infinicore::test { namespace infinicore::test {
// Helper function to format shape for logging
inline std::string formatShape(const std::vector<size_t> &shape) {
std::ostringstream oss;
oss << "[";
for (size_t i = 0; i < shape.size(); ++i) {
if (i > 0) {
oss << ", ";
}
oss << shape[i];
}
oss << "]";
return oss.str();
}
// Test 1: Basic module operations (creation, parameters, state_dict, load_state_dict) // Test 1: Basic module operations (creation, parameters, state_dict, load_state_dict)
TestResult NNModuleTest::testBasicModuleCreation() { TestResult NNModuleTest::testBasicModuleCreation() {
return measureTime("BasicModuleOperations", [this]() { return measureTime("BasicModuleOperations", [this]() {
...@@ -115,6 +129,174 @@ TestResult NNModuleTest::testBasicModuleCreation() { ...@@ -115,6 +129,174 @@ TestResult NNModuleTest::testBasicModuleCreation() {
}); });
} }
TestResult NNModuleTest::testTensorParallelParameters() {
return measureTime("TensorParallelParameters", [this]() {
try {
spdlog::info("==========================================");
spdlog::info("Testing Tensor Parallel Parameters");
spdlog::info("==========================================");
auto device = infinicore::context::getDevice();
spdlog::info("Test Tensor Parallel Parameter");
// Case 1: Partition along dimension 0 (row-wise partitioning)
infinicore::nn::Parameter param_dim0({8, 4}, infinicore::DataType::F32, device, 0, 0, 2);
if (param_dim0->shape() != std::vector<size_t>({4, 4})) {
spdlog::error("TP dim0: Expected shape [4, 4], got [{}]", formatShape(param_dim0->shape()));
return false;
}
spdlog::info("✓ TP dim0 parameter created with correct partitioned shape");
// Case 2: Partition along dimension 1 (column-wise partitioning)
infinicore::nn::Parameter param_dim1({8, 4}, infinicore::DataType::F32, device, 1, 0, 2);
if (param_dim1->shape() != std::vector<size_t>({8, 2})) {
spdlog::error("TP dim1: Expected shape [8, 2], got [{}]", formatShape(param_dim1->shape()));
return false;
}
spdlog::info("✓ TP dim1 parameter created with correct partitioned shape");
spdlog::info("✓ Parameter creation with tensor parallelism passed");
spdlog::info("Test Tensor Parallel Linear Module");
auto w_data = std::vector<float>(32 * 64);
auto b_data = std::vector<float>(32);
for (size_t i = 0; i < 32; ++i) {
for (size_t j = 0; j < 64; ++j) {
w_data[i * 64 + j] = static_cast<float>(j);
}
b_data[i] = static_cast<float>(i);
}
{
spdlog::info("Test tp_size=4 tp_dim=0");
Size tp_size = 4;
Size tp_dim = 0;
std::vector<std::unique_ptr<MockLinearModule>> tp_modules;
for (Size tp_rank = 0; tp_rank < tp_size; ++tp_rank) {
auto module = std::make_unique<MockLinearModule>(64, 32, device, tp_dim, tp_rank, tp_size);
tp_modules.push_back(std::move(module));
}
// Verify each partition has correct shape
for (size_t i = 0; i < tp_modules.size(); ++i) {
const auto &weight = tp_modules[i]->get_weight();
const auto &bias = tp_modules[i]->get_bias();
// Weight should be partitioned along output dimension (dim 0)
if (weight->shape() != std::vector<size_t>({8, 64})) { // 32/4 = 8
spdlog::error("TP rank {}: Weight shape mismatch. Expected [8, 64], got [{}]",
i, formatShape(weight->shape()));
return false;
}
// Bias should be partitioned along output dimension
if (bias->shape() != std::vector<size_t>({8})) { // 32/4 = 8
spdlog::error("TP rank {}: Bias shape mismatch. Expected [8], got [{}]",
i, formatShape(bias->shape()));
return false;
}
spdlog::debug("TP rank {}: weight shape [{}], bias shape [{}]",
i, formatShape(weight->shape()), formatShape(bias->shape()));
tp_modules[i]->load_parameter_from_blob("weight", w_data.data());
tp_modules[i]->load_parameter_from_blob("bias", b_data.data());
auto weight_loaded = infinicore::Tensor::from_blob(
w_data.data(),
{32, 64},
infinicore::DataType::F32,
infinicore::Device::cpu())
->narrow({{0, i * 8, 8}})
->to(device); // Narrow to get the partition
auto bias_loaded = infinicore::Tensor::from_blob(
b_data.data(),
{32},
infinicore::DataType::F32,
infinicore::Device::cpu())
->narrow({{0, i * 8, 8}})
->to(device); // Narrow to get the partition
if (!tensorsAllClose(tp_modules[i]->get_weight(), weight_loaded, 1e-6, 1e-6)) {
spdlog::error("TP rank {}: Weight values do not match after load_parameter_from_blob", i);
return false;
}
if (!tensorsAllClose(tp_modules[i]->get_bias(), bias_loaded, 1e-6, 1e-6)) {
spdlog::error("TP rank {}: Bias values do not match after load_parameter_from_blob", i);
return false;
}
}
}
{
spdlog::info("Test tp_size=4 tp_dim=1");
Size tp_size = 4;
Size tp_dim = 1;
std::vector<std::unique_ptr<MockLinearModule>> tp_modules;
for (Size tp_rank = 0; tp_rank < tp_size; ++tp_rank) {
auto module = std::make_unique<MockLinearModule>(64, 32, device, tp_dim, tp_rank, tp_size);
tp_modules.push_back(std::move(module));
}
// Verify each partition has correct shape
for (size_t i = 0; i < tp_modules.size(); ++i) {
const auto &weight = tp_modules[i]->get_weight();
const auto &bias = tp_modules[i]->get_bias();
// Weight should be partitioned along output dimension (dim 0)
if (weight->shape() != std::vector<size_t>({32, 16})) { // 64/4 = 16
spdlog::error("TP rank {}: Weight shape mismatch. Expected [32, 16], got [{}]",
i, formatShape(weight->shape()));
return false;
}
// Bias should be partitioned along output dimension
if (bias->shape() != std::vector<size_t>({32})) { // Bias not partitioned when tp_dim=1
spdlog::error("TP rank {}: Bias shape mismatch. Expected [32], got [{}]",
i, formatShape(bias->shape()));
return false;
}
spdlog::debug("TP rank {}: weight shape [{}], bias shape [{}]",
i, formatShape(weight->shape()), formatShape(bias->shape()));
;
tp_modules[i]->load_parameter_from_blob("weight", w_data.data());
tp_modules[i]->load_parameter_from_blob("bias", b_data.data());
auto weight_loaded = infinicore::Tensor::from_blob(
w_data.data(),
{32, 64},
infinicore::DataType::F32,
infinicore::Device::cpu())
->narrow({{1, i * 16, 16}})
->to(device); // Narrow to get the partition
auto bias_loaded = infinicore::Tensor::from_blob(
b_data.data(),
{32},
infinicore::DataType::F32,
infinicore::Device::cpu())
->to(device); // Narrow to get the partition
if (!tensorsAllClose(tp_modules[i]->get_weight(), weight_loaded, 1e-6, 1e-6)) {
spdlog::error("TP rank {}: Weight values do not match after load_parameter_from_blob", i);
return false;
}
if (!tensorsAllClose(tp_modules[i]->get_bias(), bias_loaded, 1e-6, 1e-6)) {
spdlog::error("TP rank {}: Bias values do not match after load_parameter_from_blob", i);
return false;
}
}
}
spdlog::info("=== All Tensor Parallel Parameter Tests Passed ===");
return true;
} catch (const std::exception &e) {
spdlog::error("Exception in testTensorParallelParameters: {}", e.what());
return false;
}
});
}
// Test 2: Advanced load state dict functionality (hierarchical modules) // Test 2: Advanced load state dict functionality (hierarchical modules)
TestResult NNModuleTest::testLoadStateDict() { TestResult NNModuleTest::testLoadStateDict() {
return measureTime("AdvancedLoadStateDict", [this]() { return measureTime("AdvancedLoadStateDict", [this]() {
...@@ -384,6 +566,8 @@ TestResult NNModuleTest::testParameterLoading() { ...@@ -384,6 +566,8 @@ TestResult NNModuleTest::testParameterLoading() {
return false; return false;
} }
MockLinearModule module_row_parallel(3, 2, infinicore::Device(), 0, 1, 2);
spdlog::info("Parameter loading test passed"); spdlog::info("Parameter loading test passed");
return true; return true;
} catch (const std::exception &e) { } catch (const std::exception &e) {
...@@ -1709,6 +1893,7 @@ TestResult NNModuleTest::run() { ...@@ -1709,6 +1893,7 @@ TestResult NNModuleTest::run() {
<< "==============================================" << std::endl; << "==============================================" << std::endl;
results.push_back(testBasicModuleCreation()); // Merged: creation + parameters + state_dict + load results.push_back(testBasicModuleCreation()); // Merged: creation + parameters + state_dict + load
results.push_back(testTensorParallelParameters()); // Tensor-parallel parameters
results.push_back(testLoadStateDict()); // Advanced: hierarchical modules results.push_back(testLoadStateDict()); // Advanced: hierarchical modules
results.push_back(testModuleHierarchy()); // Demonstrates hierarchical construction results.push_back(testModuleHierarchy()); // Demonstrates hierarchical construction
results.push_back(testParameterLoading()); // Blob loading results.push_back(testParameterLoading()); // Blob loading
......
...@@ -26,17 +26,25 @@ public: ...@@ -26,17 +26,25 @@ public:
INFINICORE_NN_PARAMETER(weight); INFINICORE_NN_PARAMETER(weight);
INFINICORE_NN_PARAMETER(bias); INFINICORE_NN_PARAMETER(bias);
MockLinearModule(int input_size, int output_size, const infinicore::Device &device) MockLinearModule(int input_size, int output_size, const infinicore::Device &device,
: input_size_(input_size), output_size_(output_size), device_(device) { Size tp_dim = 0, Size tp_rank = 0, Size tp_size = 1)
: input_size_(input_size), output_size_(output_size), device_(device),
tp_dim_(tp_dim), tp_rank_(tp_rank), tp_size_(tp_size) {
// Initialize parameters using macros // Initialize parameters using macros
INFINICORE_NN_PARAMETER_INIT(weight, INFINICORE_NN_PARAMETER_INIT(weight,
({static_cast<size_t>(output_size), static_cast<size_t>(input_size)}, ({static_cast<size_t>(output_size), static_cast<size_t>(input_size)},
infinicore::DataType::F32, infinicore::DataType::F32,
device)); device,
tp_dim_,
tp_rank_,
tp_size_));
INFINICORE_NN_PARAMETER_INIT(bias, INFINICORE_NN_PARAMETER_INIT(bias,
({static_cast<size_t>(output_size)}, ({static_cast<size_t>(output_size)},
infinicore::DataType::F32, infinicore::DataType::F32,
device)); device,
0,
tp_dim == 0 ? tp_rank_ : 0,
tp_dim == 0 ? tp_size_ : 1));
} }
// Simple forward pass (conceptual - would need actual matrix operations) // Simple forward pass (conceptual - would need actual matrix operations)
...@@ -68,6 +76,10 @@ private: ...@@ -68,6 +76,10 @@ private:
int input_size_; int input_size_;
int output_size_; int output_size_;
infinicore::Device device_; infinicore::Device device_;
Size tp_dim_;
Size tp_rank_;
Size tp_size_;
}; };
class NNModuleTest : public TestFramework { class NNModuleTest : public TestFramework {
...@@ -77,6 +89,7 @@ public: ...@@ -77,6 +89,7 @@ public:
private: private:
TestResult testBasicModuleCreation(); // Merged: creation, parameters, state_dict, load_state_dict TestResult testBasicModuleCreation(); // Merged: creation, parameters, state_dict, load_state_dict
TestResult testTensorParallelParameters(); // Module with tensor parallel parameters
TestResult testLoadStateDict(); // Advanced: hierarchical modules TestResult testLoadStateDict(); // Advanced: hierarchical modules
TestResult testModuleHierarchy(); // Demonstrates proper hierarchical construction pattern TestResult testModuleHierarchy(); // Demonstrates proper hierarchical construction pattern
TestResult testParameterLoading(); // Test blob parameter loading TestResult testParameterLoading(); // Test blob parameter loading
......
...@@ -33,11 +33,15 @@ Runtime *ContextImpl::getCpuRuntime() { ...@@ -33,11 +33,15 @@ Runtime *ContextImpl::getCpuRuntime() {
return runtime_table_[int(Device::Type::CPU)][0].get(); return runtime_table_[int(Device::Type::CPU)][0].get();
} }
void ContextImpl::setDevice(Device device) { void ContextImpl::setDevice(Device device, bool force_cpu) {
if (device == getCurrentRuntime()->device()) { if (device == getCurrentRuntime()->device()) {
// Do nothing if the device is already set. // Do nothing if the device is already set.
return; return;
} }
if (device == Device(Device::Type::CPU, 0) && !force_cpu) {
// if not forced, no need to switch to CPU device runtime
return;
}
if (runtime_table_[int(device.getType())][device.getIndex()] == nullptr) { if (runtime_table_[int(device.getType())][device.getIndex()] == nullptr) {
// Lazy initialization of runtime if never set before. // Lazy initialization of runtime if never set before.
...@@ -83,8 +87,8 @@ ContextImpl::ContextImpl() { ...@@ -83,8 +87,8 @@ ContextImpl::ContextImpl() {
namespace context { namespace context {
void setDevice(Device device) { void setDevice(Device device, bool force_cpu) {
ContextImpl::singleton().setDevice(device); ContextImpl::singleton().setDevice(device, force_cpu);
} }
Device getDevice() { Device getDevice() {
......
...@@ -21,7 +21,7 @@ public: ...@@ -21,7 +21,7 @@ public:
Runtime *getCpuRuntime(); Runtime *getCpuRuntime();
void setDevice(Device); void setDevice(Device, bool force_cpu = false);
size_t getDeviceCount(Device::Type type); size_t getDeviceCount(Device::Type type);
......
#include "infinicore/nn/module.hpp" #include "infinicore/nn/module.hpp"
#include <spdlog/spdlog.h>
#include <stdexcept> #include <stdexcept>
namespace infinicore::nn { namespace infinicore::nn {
...@@ -21,20 +22,17 @@ void Module::load_state_dict(const std::unordered_map<std::string, Tensor> &_sta ...@@ -21,20 +22,17 @@ void Module::load_state_dict(const std::unordered_map<std::string, Tensor> &_sta
// Look up the corresponding tensor in the input state dict using the full name // Look up the corresponding tensor in the input state dict using the full name
auto it = _state_dict.find(param_full_name); auto it = _state_dict.find(param_full_name);
if (it != _state_dict.end()) { if (it != _state_dict.end()) {
// Assert dtype matches this->load_parameter(param_full_name, it->second);
if (param->dtype() != it->second->dtype()) { } else {
throw std::runtime_error( spdlog::warn("Parameter '{}' provided but not found in module.", param_full_name);
"dtype mismatch for parameter '" + param_full_name + "': "
"expected "
+ std::to_string(static_cast<int>(param->dtype())) + ", got " + std::to_string(static_cast<int>(it->second->dtype())));
}
param->copy_from(it->second);
} }
} }
} }
void Module::load_parameter(const std::string &name, const Tensor &param) { void Module::load_parameter(const std::string &name, const Tensor &param) {
auto existing_param = parameters_[name]; auto it = parameters_.find(name);
if (it != parameters_.end()) {
auto existing_param = it->second;
// Assert dtype matches // Assert dtype matches
if (existing_param->dtype() != param->dtype()) { if (existing_param->dtype() != param->dtype()) {
throw std::runtime_error( throw std::runtime_error(
...@@ -42,7 +40,10 @@ void Module::load_parameter(const std::string &name, const Tensor &param) { ...@@ -42,7 +40,10 @@ void Module::load_parameter(const std::string &name, const Tensor &param) {
"expected " "expected "
+ std::to_string(static_cast<int>(existing_param->dtype())) + ", got " + std::to_string(static_cast<int>(param->dtype()))); + std::to_string(static_cast<int>(existing_param->dtype())) + ", got " + std::to_string(static_cast<int>(param->dtype())));
} }
existing_param->copy_from(param); existing_param.load(param);
} else {
throw std::runtime_error("Parameter '" + name + "' not found in module.");
}
} }
void Module::load_parameter_from_blob(const std::string &name, const void *data) { void Module::load_parameter_from_blob(const std::string &name, const void *data) {
......
...@@ -3,29 +3,64 @@ ...@@ -3,29 +3,64 @@
#include "infinicore/context/context.hpp" #include "infinicore/context/context.hpp"
#include <cstring> #include <cstring>
#include <stdexcept>
namespace infinicore::nn { namespace infinicore::nn {
Parameter::Parameter() Parameter::Parameter()
: Tensor(Tensor::empty({}, DataType::F32, Device(Device::Type::CPU, 0), false)) { : Tensor(Tensor::empty({}, DataType::F32, Device(Device::Type::CPU, 0), false)) {
} }
inline Shape get_partipion_shape_(const Shape &shape, Size tp_dim, Size tp_size) {
if (tp_size <= 1) {
return shape;
}
Shape part_shape = shape;
if (tp_dim < shape.size()) {
if (shape[tp_dim] % tp_size != 0) {
throw std::runtime_error("Tensor dimension " + std::to_string(tp_dim) + " with size " + std::to_string(shape[tp_dim]) + " is not divisible by tensor parallel size " + std::to_string(tp_size) + ".");
}
part_shape[tp_dim] = shape[tp_dim] / tp_size;
}
return part_shape;
}
Parameter::Parameter( Parameter::Parameter(
const Shape &shape, const Shape &shape,
const DataType &dtype, const DataType &dtype,
const Device &device) const Device &device,
: Tensor(Tensor::empty(shape, dtype, device, false)) { Size tp_dim,
Size tp_rank,
Size tp_size)
: Tensor(Tensor::empty(get_partipion_shape_(shape, tp_dim, tp_size), dtype, device, false)), tp_dim_(tp_dim), tp_rank_(tp_rank), tp_size_(tp_size) {
if (tp_rank_ >= tp_size_) {
throw std::runtime_error("Tensor parallel rank " + std::to_string(tp_rank_) + " must be less than tensor parallel size " + std::to_string(tp_size_) + ".");
}
} }
void Parameter::load_blob(const void *data) { void Parameter::load_blob(const void *data) {
auto buffer = Tensor::empty(impl_->shape(), impl_->dtype(), Device(Device::Type::CPU, 0), true); Shape expected_shape = Shape(impl_->shape());
expected_shape[tp_dim_] *= tp_size_;
auto buffer = Tensor::empty(expected_shape, impl_->dtype(), Device(Device::Type::CPU, 0), true);
std::memcpy(buffer->data(), data, buffer->nbytes()); std::memcpy(buffer->data(), data, buffer->nbytes());
this->load(buffer);
}
void Parameter::load(const Tensor &tensor) {
Shape expected_shape = Shape(impl_->shape());
expected_shape[tp_dim_] *= tp_size_;
if (expected_shape != tensor->shape()) {
throw std::runtime_error("Shape mismatch when loading tensor into parameter.");
}
if (impl_->dtype() != tensor->dtype()) {
throw std::runtime_error("Dtype mismatch when loading tensor into parameter.");
}
if (tp_size_ > 1) {
impl_->copy_from(tensor->narrow({{tp_dim_, tp_rank_ * impl_->size(tp_dim_), impl_->size(tp_dim_)}}));
// If parameter is on CPU, use direct memcpy; otherwise use H2D
if (impl_->device().getType() == Device::Type::CPU) {
infinicore::context::memcpyH2H(impl_->data(), buffer->data(), buffer->nbytes());
} else { } else {
infinicore::context::memcpyH2D(impl_->data(), buffer->data(), buffer->nbytes()); impl_->copy_from(tensor);
infinicore::context::syncStream();
} }
infinicore::context::syncStream();
} }
} // namespace infinicore::nn } // namespace infinicore::nn
...@@ -9,7 +9,7 @@ Tensor embedding(Tensor input, // LongTensor of arbitrary shape containing the i ...@@ -9,7 +9,7 @@ Tensor embedding(Tensor input, // LongTensor of arbitrary shape containing the i
) { ) {
auto input_shape = input->shape(); auto input_shape = input->shape();
auto weight_shape = weight->shape(); auto weight_shape = weight->shape();
auto vocab_size = weight_shape[0]; // auto vocab_size = weight_shape[0];
auto embedding_dim = weight_shape[1]; auto embedding_dim = weight_shape[1];
// Assign memory to out variables // Assign memory to out variables
...@@ -23,11 +23,10 @@ Tensor embedding(Tensor input, // LongTensor of arbitrary shape containing the i ...@@ -23,11 +23,10 @@ Tensor embedding(Tensor input, // LongTensor of arbitrary shape containing the i
void embedding_(Tensor out, Tensor input, Tensor weight) { void embedding_(Tensor out, Tensor input, Tensor weight) {
assert(infinicore::DataType::I64 == input->dtype() || (infinicore::DataType::I32 == input->dtype())); assert(infinicore::DataType::I64 == input->dtype() || (infinicore::DataType::I32 == input->dtype()));
assert(infinicore::Device::Type::CPU == input->device()); assert(infinicore::Device::Type::CPU == input->device().getType());
auto input_shape = input->shape(); auto input_shape = input->shape();
auto weight_shape = weight->shape(); auto weight_shape = weight->shape();
auto vocab_size = weight_shape[0];
auto embedding_dim = weight_shape[1]; auto embedding_dim = weight_shape[1];
// Calculate the number of token // Calculate the number of token
...@@ -47,7 +46,7 @@ void embedding_(Tensor out, Tensor input, Tensor weight) { ...@@ -47,7 +46,7 @@ void embedding_(Tensor out, Tensor input, Tensor weight) {
const int64_t *input_arr = reinterpret_cast<const int64_t *>(input->data()); const int64_t *input_arr = reinterpret_cast<const int64_t *>(input->data());
for (Size i = 0; i < counts; ++i) { for (Size i = 0; i < counts; ++i) {
int64_t idx = input_arr[i]; int64_t idx = input_arr[i];
assert((idx >= 0) && (idx < vocab_size)); assert((idx >= 0) && (idx < weight_shape[0]));
std::memcpy(out_ptr + i * bytes, std::memcpy(out_ptr + i * bytes,
weight_ptr + idx * bytes, weight_ptr + idx * bytes,
bytes); bytes);
...@@ -57,7 +56,7 @@ void embedding_(Tensor out, Tensor input, Tensor weight) { ...@@ -57,7 +56,7 @@ void embedding_(Tensor out, Tensor input, Tensor weight) {
for (Size i = 0; i < counts; ++i) { for (Size i = 0; i < counts; ++i) {
int32_t idx = input_arr[i]; int32_t idx = input_arr[i];
assert((idx >= 0) && (idx < vocab_size)); assert((idx >= 0) && (idx < weight_shape[0]));
std::memcpy(out_ptr + i * bytes, std::memcpy(out_ptr + i * bytes,
weight_ptr + idx * bytes, weight_ptr + idx * bytes,
bytes); bytes);
...@@ -69,7 +68,7 @@ void embedding_(Tensor out, Tensor input, Tensor weight) { ...@@ -69,7 +68,7 @@ void embedding_(Tensor out, Tensor input, Tensor weight) {
const int64_t *input_arr = reinterpret_cast<const int64_t *>(input->data()); const int64_t *input_arr = reinterpret_cast<const int64_t *>(input->data());
for (Size i = 0; i < counts; ++i) { for (Size i = 0; i < counts; ++i) {
int64_t idx = input_arr[i]; int64_t idx = input_arr[i];
assert((idx >= 0) && (idx < vocab_size)); assert((idx >= 0) && (idx < weight_shape[0]));
context::memcpyD2D(out_ptr + i * bytes, context::memcpyD2D(out_ptr + i * bytes,
weight_ptr + idx * bytes, weight_ptr + idx * bytes,
bytes); bytes);
...@@ -78,7 +77,7 @@ void embedding_(Tensor out, Tensor input, Tensor weight) { ...@@ -78,7 +77,7 @@ void embedding_(Tensor out, Tensor input, Tensor weight) {
const int32_t *input_arr = reinterpret_cast<const int32_t *>(input->data()); const int32_t *input_arr = reinterpret_cast<const int32_t *>(input->data());
for (Size i = 0; i < counts; ++i) { for (Size i = 0; i < counts; ++i) {
int32_t idx = input_arr[i]; int32_t idx = input_arr[i];
assert((idx >= 0) && (idx < vocab_size)); assert((idx >= 0) && (idx < weight_shape[0]));
context::memcpyD2D(out_ptr + i * bytes, context::memcpyD2D(out_ptr + i * bytes,
weight_ptr + idx * bytes, weight_ptr + idx * bytes,
bytes); bytes);
......
#include "infinicore/ops/rearrange.hpp" #include "infinicore/ops/rearrange.hpp"
#include "../../utils.hpp"
namespace infinicore::op { namespace infinicore::op {
...@@ -8,7 +9,9 @@ common::OpDispatcher<Rearrange::schema> &Rearrange::dispatcher() { ...@@ -8,7 +9,9 @@ common::OpDispatcher<Rearrange::schema> &Rearrange::dispatcher() {
}; };
void Rearrange::execute(Tensor y, Tensor x) { void Rearrange::execute(Tensor y, Tensor x) {
dispatcher().lookup(context::getDevice().getType())(y, x); INFINICORE_ASSERT_TENSORS_SAME_DEVICE(y, x);
infinicore::context::setDevice(y->device());
dispatcher().lookup(y->device().getType())(y, x);
} }
Tensor rearrange(Tensor x) { Tensor rearrange(Tensor x) {
......
...@@ -18,8 +18,8 @@ thread_local common::OpCache<size_t, infiniopRearrangeDescriptor_t> caches( ...@@ -18,8 +18,8 @@ thread_local common::OpCache<size_t, infiniopRearrangeDescriptor_t> caches(
void calculate(Tensor y, Tensor x) { void calculate(Tensor y, Tensor x) {
size_t seed = hash_combine(y, x); size_t seed = hash_combine(y, x);
auto device_type = context::getDevice().getType(); auto device_type = y->device().getType();
auto device_index = context::getDevice().getIndex(); auto device_index = y->device().getIndex();
auto &cache = caches.getCache(device_type, device_index); auto &cache = caches.getCache(device_type, device_index);
......
...@@ -16,7 +16,8 @@ inline void bind(py::module &m) { ...@@ -16,7 +16,8 @@ inline void bind(py::module &m) {
py::arg("device_type")); py::arg("device_type"));
m.def("set_device", &setDevice, m.def("set_device", &setDevice,
"Set the current active device", "Set the current active device",
py::arg("device")); py::arg("device"),
py::arg("force_cpu"));
// Stream and handle management // Stream and handle management
m.def("get_stream", &getStream, "Get the current stream"); m.def("get_stream", &getStream, "Get the current stream");
......
...@@ -32,3 +32,15 @@ inline struct SpdlogInitializer { ...@@ -32,3 +32,15 @@ inline struct SpdlogInitializer {
throw std::runtime_error(#call " failed with error: " + std::string(infini_status_string(ret))); \ throw std::runtime_error(#call " failed with error: " + std::string(infini_status_string(ret))); \
} \ } \
} while (false) } while (false)
#define INFINICORE_ASSERT_TENSORS_SAME_DEVICE(FIRST___, ...) \
do { \
const auto &first_device___ = (FIRST___)->device(); \
for (const auto &tensor___ : {__VA_ARGS__}) { \
if (first_device___ != (tensor___)->device()) { \
throw std::runtime_error("Tensor devices mismatch " \
+ first_device___.toString() + " vs " \
+ (tensor___)->device().toString() + "."); \
} \
} \
} while (0)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment