Unverified Commit 53f4bc1d authored by PanZezhong1725's avatar PanZezhong1725 Committed by GitHub
Browse files

Merge pull request #694 from InfiniTensor/issue/682

issue/682 Parameter supports TP, fix copy_from
parents 1f573217 ee43bda6
......@@ -11,7 +11,7 @@
namespace infinicore {
namespace context {
void setDevice(Device device);
void setDevice(Device device, bool force_cpu = false);
Device getDevice();
size_t getDeviceCount(Device::Type type);
......
......@@ -39,6 +39,10 @@ public:
bool operator!=(const Device &other) const;
inline static Device cpu() {
return Device(Type::CPU, 0);
}
private:
Type type_;
......
......@@ -9,8 +9,19 @@ public:
Parameter(const Shape &shape,
const DataType &dtype,
const Device &device);
const Device &device,
Size tp_dim = 0,
Size tp_rank = 0,
Size tp_size = 1);
void load_blob(const void *data);
void load(const Tensor &tensor);
protected:
// Tensor parallel configs
Size tp_dim_; // dimension partitioned
Size tp_rank_; // rank of this partition among tp group
Size tp_size_; // total number of partitions
};
} // namespace infinicore::nn
......@@ -23,13 +23,13 @@ def get_device_count(device_type):
return _infinicore.get_device_count(infinicore.device(device_type)._underlying.type)
def set_device(device):
def set_device(device, force_cpu=False):
"""Set the current active device.
Args:
device: The device to set as active
"""
_infinicore.set_device(device._underlying)
_infinicore.set_device(device._underlying, force_cpu)
def sync_stream():
......
......@@ -709,9 +709,6 @@ TestResult PerformanceTest::testMemoryCopyPerformance() {
return false;
}
// Initialize source data
std::memset(src_memory->data(), 0xAB, data_size);
auto start = std::chrono::high_resolution_clock::now();
// Perform memory copies
......
......@@ -3,6 +3,20 @@
namespace infinicore::test {
// Helper function to format shape for logging
inline std::string formatShape(const std::vector<size_t> &shape) {
std::ostringstream oss;
oss << "[";
for (size_t i = 0; i < shape.size(); ++i) {
if (i > 0) {
oss << ", ";
}
oss << shape[i];
}
oss << "]";
return oss.str();
}
// Test 1: Basic module operations (creation, parameters, state_dict, load_state_dict)
TestResult NNModuleTest::testBasicModuleCreation() {
return measureTime("BasicModuleOperations", [this]() {
......@@ -115,6 +129,174 @@ TestResult NNModuleTest::testBasicModuleCreation() {
});
}
TestResult NNModuleTest::testTensorParallelParameters() {
return measureTime("TensorParallelParameters", [this]() {
try {
spdlog::info("==========================================");
spdlog::info("Testing Tensor Parallel Parameters");
spdlog::info("==========================================");
auto device = infinicore::context::getDevice();
spdlog::info("Test Tensor Parallel Parameter");
// Case 1: Partition along dimension 0 (row-wise partitioning)
infinicore::nn::Parameter param_dim0({8, 4}, infinicore::DataType::F32, device, 0, 0, 2);
if (param_dim0->shape() != std::vector<size_t>({4, 4})) {
spdlog::error("TP dim0: Expected shape [4, 4], got [{}]", formatShape(param_dim0->shape()));
return false;
}
spdlog::info("✓ TP dim0 parameter created with correct partitioned shape");
// Case 2: Partition along dimension 1 (column-wise partitioning)
infinicore::nn::Parameter param_dim1({8, 4}, infinicore::DataType::F32, device, 1, 0, 2);
if (param_dim1->shape() != std::vector<size_t>({8, 2})) {
spdlog::error("TP dim1: Expected shape [8, 2], got [{}]", formatShape(param_dim1->shape()));
return false;
}
spdlog::info("✓ TP dim1 parameter created with correct partitioned shape");
spdlog::info("✓ Parameter creation with tensor parallelism passed");
spdlog::info("Test Tensor Parallel Linear Module");
auto w_data = std::vector<float>(32 * 64);
auto b_data = std::vector<float>(32);
for (size_t i = 0; i < 32; ++i) {
for (size_t j = 0; j < 64; ++j) {
w_data[i * 64 + j] = static_cast<float>(j);
}
b_data[i] = static_cast<float>(i);
}
{
spdlog::info("Test tp_size=4 tp_dim=0");
Size tp_size = 4;
Size tp_dim = 0;
std::vector<std::unique_ptr<MockLinearModule>> tp_modules;
for (Size tp_rank = 0; tp_rank < tp_size; ++tp_rank) {
auto module = std::make_unique<MockLinearModule>(64, 32, device, tp_dim, tp_rank, tp_size);
tp_modules.push_back(std::move(module));
}
// Verify each partition has correct shape
for (size_t i = 0; i < tp_modules.size(); ++i) {
const auto &weight = tp_modules[i]->get_weight();
const auto &bias = tp_modules[i]->get_bias();
// Weight should be partitioned along output dimension (dim 0)
if (weight->shape() != std::vector<size_t>({8, 64})) { // 32/4 = 8
spdlog::error("TP rank {}: Weight shape mismatch. Expected [8, 64], got [{}]",
i, formatShape(weight->shape()));
return false;
}
// Bias should be partitioned along output dimension
if (bias->shape() != std::vector<size_t>({8})) { // 32/4 = 8
spdlog::error("TP rank {}: Bias shape mismatch. Expected [8], got [{}]",
i, formatShape(bias->shape()));
return false;
}
spdlog::debug("TP rank {}: weight shape [{}], bias shape [{}]",
i, formatShape(weight->shape()), formatShape(bias->shape()));
tp_modules[i]->load_parameter_from_blob("weight", w_data.data());
tp_modules[i]->load_parameter_from_blob("bias", b_data.data());
auto weight_loaded = infinicore::Tensor::from_blob(
w_data.data(),
{32, 64},
infinicore::DataType::F32,
infinicore::Device::cpu())
->narrow({{0, i * 8, 8}})
->to(device); // Narrow to get the partition
auto bias_loaded = infinicore::Tensor::from_blob(
b_data.data(),
{32},
infinicore::DataType::F32,
infinicore::Device::cpu())
->narrow({{0, i * 8, 8}})
->to(device); // Narrow to get the partition
if (!tensorsAllClose(tp_modules[i]->get_weight(), weight_loaded, 1e-6, 1e-6)) {
spdlog::error("TP rank {}: Weight values do not match after load_parameter_from_blob", i);
return false;
}
if (!tensorsAllClose(tp_modules[i]->get_bias(), bias_loaded, 1e-6, 1e-6)) {
spdlog::error("TP rank {}: Bias values do not match after load_parameter_from_blob", i);
return false;
}
}
}
{
spdlog::info("Test tp_size=4 tp_dim=1");
Size tp_size = 4;
Size tp_dim = 1;
std::vector<std::unique_ptr<MockLinearModule>> tp_modules;
for (Size tp_rank = 0; tp_rank < tp_size; ++tp_rank) {
auto module = std::make_unique<MockLinearModule>(64, 32, device, tp_dim, tp_rank, tp_size);
tp_modules.push_back(std::move(module));
}
// Verify each partition has correct shape
for (size_t i = 0; i < tp_modules.size(); ++i) {
const auto &weight = tp_modules[i]->get_weight();
const auto &bias = tp_modules[i]->get_bias();
// Weight should be partitioned along output dimension (dim 0)
if (weight->shape() != std::vector<size_t>({32, 16})) { // 64/4 = 16
spdlog::error("TP rank {}: Weight shape mismatch. Expected [32, 16], got [{}]",
i, formatShape(weight->shape()));
return false;
}
// Bias should be partitioned along output dimension
if (bias->shape() != std::vector<size_t>({32})) { // Bias not partitioned when tp_dim=1
spdlog::error("TP rank {}: Bias shape mismatch. Expected [32], got [{}]",
i, formatShape(bias->shape()));
return false;
}
spdlog::debug("TP rank {}: weight shape [{}], bias shape [{}]",
i, formatShape(weight->shape()), formatShape(bias->shape()));
;
tp_modules[i]->load_parameter_from_blob("weight", w_data.data());
tp_modules[i]->load_parameter_from_blob("bias", b_data.data());
auto weight_loaded = infinicore::Tensor::from_blob(
w_data.data(),
{32, 64},
infinicore::DataType::F32,
infinicore::Device::cpu())
->narrow({{1, i * 16, 16}})
->to(device); // Narrow to get the partition
auto bias_loaded = infinicore::Tensor::from_blob(
b_data.data(),
{32},
infinicore::DataType::F32,
infinicore::Device::cpu())
->to(device); // Narrow to get the partition
if (!tensorsAllClose(tp_modules[i]->get_weight(), weight_loaded, 1e-6, 1e-6)) {
spdlog::error("TP rank {}: Weight values do not match after load_parameter_from_blob", i);
return false;
}
if (!tensorsAllClose(tp_modules[i]->get_bias(), bias_loaded, 1e-6, 1e-6)) {
spdlog::error("TP rank {}: Bias values do not match after load_parameter_from_blob", i);
return false;
}
}
}
spdlog::info("=== All Tensor Parallel Parameter Tests Passed ===");
return true;
} catch (const std::exception &e) {
spdlog::error("Exception in testTensorParallelParameters: {}", e.what());
return false;
}
});
}
// Test 2: Advanced load state dict functionality (hierarchical modules)
TestResult NNModuleTest::testLoadStateDict() {
return measureTime("AdvancedLoadStateDict", [this]() {
......@@ -384,6 +566,8 @@ TestResult NNModuleTest::testParameterLoading() {
return false;
}
MockLinearModule module_row_parallel(3, 2, infinicore::Device(), 0, 1, 2);
spdlog::info("Parameter loading test passed");
return true;
} catch (const std::exception &e) {
......@@ -1708,16 +1892,17 @@ TestResult NNModuleTest::run() {
<< "InfiniCore nn::Module Test Suite\n"
<< "==============================================" << std::endl;
results.push_back(testBasicModuleCreation()); // Merged: creation + parameters + state_dict + load
results.push_back(testLoadStateDict()); // Advanced: hierarchical modules
results.push_back(testModuleHierarchy()); // Demonstrates hierarchical construction
results.push_back(testParameterLoading()); // Blob loading
results.push_back(testModuleLinear()); // Linear module comprehensive test
results.push_back(testModuleEmbedding()); // Embedding module test
results.push_back(testModuleRMSNorm()); // RMSNorm module test
results.push_back(testModuleRoPE()); // RoPE module test
results.push_back(testDtypeAssertion()); // Dtype assertion test
results.push_back(testTinyLlamaConstruction()); // Comprehensive: TinyLlama model test
results.push_back(testBasicModuleCreation()); // Merged: creation + parameters + state_dict + load
results.push_back(testTensorParallelParameters()); // Tensor-parallel parameters
results.push_back(testLoadStateDict()); // Advanced: hierarchical modules
results.push_back(testModuleHierarchy()); // Demonstrates hierarchical construction
results.push_back(testParameterLoading()); // Blob loading
results.push_back(testModuleLinear()); // Linear module comprehensive test
results.push_back(testModuleEmbedding()); // Embedding module test
results.push_back(testModuleRMSNorm()); // RMSNorm module test
results.push_back(testModuleRoPE()); // RoPE module test
results.push_back(testDtypeAssertion()); // Dtype assertion test
results.push_back(testTinyLlamaConstruction()); // Comprehensive: TinyLlama model test
// Check if all tests passed
bool all_passed = true;
......
......@@ -26,17 +26,25 @@ public:
INFINICORE_NN_PARAMETER(weight);
INFINICORE_NN_PARAMETER(bias);
MockLinearModule(int input_size, int output_size, const infinicore::Device &device)
: input_size_(input_size), output_size_(output_size), device_(device) {
MockLinearModule(int input_size, int output_size, const infinicore::Device &device,
Size tp_dim = 0, Size tp_rank = 0, Size tp_size = 1)
: input_size_(input_size), output_size_(output_size), device_(device),
tp_dim_(tp_dim), tp_rank_(tp_rank), tp_size_(tp_size) {
// Initialize parameters using macros
INFINICORE_NN_PARAMETER_INIT(weight,
({static_cast<size_t>(output_size), static_cast<size_t>(input_size)},
infinicore::DataType::F32,
device));
device,
tp_dim_,
tp_rank_,
tp_size_));
INFINICORE_NN_PARAMETER_INIT(bias,
({static_cast<size_t>(output_size)},
infinicore::DataType::F32,
device));
device,
0,
tp_dim == 0 ? tp_rank_ : 0,
tp_dim == 0 ? tp_size_ : 1));
}
// Simple forward pass (conceptual - would need actual matrix operations)
......@@ -68,6 +76,10 @@ private:
int input_size_;
int output_size_;
infinicore::Device device_;
Size tp_dim_;
Size tp_rank_;
Size tp_size_;
};
class NNModuleTest : public TestFramework {
......@@ -76,16 +88,17 @@ public:
std::string getName() const override { return "NNModuleTest"; }
private:
TestResult testBasicModuleCreation(); // Merged: creation, parameters, state_dict, load_state_dict
TestResult testLoadStateDict(); // Advanced: hierarchical modules
TestResult testModuleHierarchy(); // Demonstrates proper hierarchical construction pattern
TestResult testParameterLoading(); // Test blob parameter loading
TestResult testModuleLinear(); // Comprehensive Linear module test
TestResult testModuleEmbedding(); // Embedding module test
TestResult testModuleRMSNorm(); // RMSNorm module test
TestResult testModuleRoPE(); // RoPE module test
TestResult testDtypeAssertion(); // Test dtype assertions when loading parameters
TestResult testTinyLlamaConstruction(); // Comprehensive: construction + weight loading + validation
TestResult testBasicModuleCreation(); // Merged: creation, parameters, state_dict, load_state_dict
TestResult testTensorParallelParameters(); // Module with tensor parallel parameters
TestResult testLoadStateDict(); // Advanced: hierarchical modules
TestResult testModuleHierarchy(); // Demonstrates proper hierarchical construction pattern
TestResult testParameterLoading(); // Test blob parameter loading
TestResult testModuleLinear(); // Comprehensive Linear module test
TestResult testModuleEmbedding(); // Embedding module test
TestResult testModuleRMSNorm(); // RMSNorm module test
TestResult testModuleRoPE(); // RoPE module test
TestResult testDtypeAssertion(); // Test dtype assertions when loading parameters
TestResult testTinyLlamaConstruction(); // Comprehensive: construction + weight loading + validation
};
} // namespace infinicore::test
......
......@@ -33,11 +33,15 @@ Runtime *ContextImpl::getCpuRuntime() {
return runtime_table_[int(Device::Type::CPU)][0].get();
}
void ContextImpl::setDevice(Device device) {
void ContextImpl::setDevice(Device device, bool force_cpu) {
if (device == getCurrentRuntime()->device()) {
// Do nothing if the device is already set.
return;
}
if (device == Device(Device::Type::CPU, 0) && !force_cpu) {
// if not forced, no need to switch to CPU device runtime
return;
}
if (runtime_table_[int(device.getType())][device.getIndex()] == nullptr) {
// Lazy initialization of runtime if never set before.
......@@ -83,8 +87,8 @@ ContextImpl::ContextImpl() {
namespace context {
void setDevice(Device device) {
ContextImpl::singleton().setDevice(device);
void setDevice(Device device, bool force_cpu) {
ContextImpl::singleton().setDevice(device, force_cpu);
}
Device getDevice() {
......
......@@ -21,7 +21,7 @@ public:
Runtime *getCpuRuntime();
void setDevice(Device);
void setDevice(Device, bool force_cpu = false);
size_t getDeviceCount(Device::Type type);
......
#include "infinicore/nn/module.hpp"
#include <spdlog/spdlog.h>
#include <stdexcept>
namespace infinicore::nn {
......@@ -21,28 +22,28 @@ void Module::load_state_dict(const std::unordered_map<std::string, Tensor> &_sta
// Look up the corresponding tensor in the input state dict using the full name
auto it = _state_dict.find(param_full_name);
if (it != _state_dict.end()) {
// Assert dtype matches
if (param->dtype() != it->second->dtype()) {
throw std::runtime_error(
"dtype mismatch for parameter '" + param_full_name + "': "
"expected "
+ std::to_string(static_cast<int>(param->dtype())) + ", got " + std::to_string(static_cast<int>(it->second->dtype())));
}
param->copy_from(it->second);
this->load_parameter(param_full_name, it->second);
} else {
spdlog::warn("Parameter '{}' provided but not found in module.", param_full_name);
}
}
}
void Module::load_parameter(const std::string &name, const Tensor &param) {
auto existing_param = parameters_[name];
// Assert dtype matches
if (existing_param->dtype() != param->dtype()) {
throw std::runtime_error(
"dtype mismatch for parameter '" + name + "': "
"expected "
+ std::to_string(static_cast<int>(existing_param->dtype())) + ", got " + std::to_string(static_cast<int>(param->dtype())));
auto it = parameters_.find(name);
if (it != parameters_.end()) {
auto existing_param = it->second;
// Assert dtype matches
if (existing_param->dtype() != param->dtype()) {
throw std::runtime_error(
"dtype mismatch for parameter '" + name + "': "
"expected "
+ std::to_string(static_cast<int>(existing_param->dtype())) + ", got " + std::to_string(static_cast<int>(param->dtype())));
}
existing_param.load(param);
} else {
throw std::runtime_error("Parameter '" + name + "' not found in module.");
}
existing_param->copy_from(param);
}
void Module::load_parameter_from_blob(const std::string &name, const void *data) {
......
......@@ -3,29 +3,64 @@
#include "infinicore/context/context.hpp"
#include <cstring>
#include <stdexcept>
namespace infinicore::nn {
Parameter::Parameter()
: Tensor(Tensor::empty({}, DataType::F32, Device(Device::Type::CPU, 0), false)) {
}
inline Shape get_partipion_shape_(const Shape &shape, Size tp_dim, Size tp_size) {
if (tp_size <= 1) {
return shape;
}
Shape part_shape = shape;
if (tp_dim < shape.size()) {
if (shape[tp_dim] % tp_size != 0) {
throw std::runtime_error("Tensor dimension " + std::to_string(tp_dim) + " with size " + std::to_string(shape[tp_dim]) + " is not divisible by tensor parallel size " + std::to_string(tp_size) + ".");
}
part_shape[tp_dim] = shape[tp_dim] / tp_size;
}
return part_shape;
}
Parameter::Parameter(
const Shape &shape,
const DataType &dtype,
const Device &device)
: Tensor(Tensor::empty(shape, dtype, device, false)) {
const Device &device,
Size tp_dim,
Size tp_rank,
Size tp_size)
: Tensor(Tensor::empty(get_partipion_shape_(shape, tp_dim, tp_size), dtype, device, false)), tp_dim_(tp_dim), tp_rank_(tp_rank), tp_size_(tp_size) {
if (tp_rank_ >= tp_size_) {
throw std::runtime_error("Tensor parallel rank " + std::to_string(tp_rank_) + " must be less than tensor parallel size " + std::to_string(tp_size_) + ".");
}
}
void Parameter::load_blob(const void *data) {
auto buffer = Tensor::empty(impl_->shape(), impl_->dtype(), Device(Device::Type::CPU, 0), true);
Shape expected_shape = Shape(impl_->shape());
expected_shape[tp_dim_] *= tp_size_;
auto buffer = Tensor::empty(expected_shape, impl_->dtype(), Device(Device::Type::CPU, 0), true);
std::memcpy(buffer->data(), data, buffer->nbytes());
this->load(buffer);
}
void Parameter::load(const Tensor &tensor) {
Shape expected_shape = Shape(impl_->shape());
expected_shape[tp_dim_] *= tp_size_;
if (expected_shape != tensor->shape()) {
throw std::runtime_error("Shape mismatch when loading tensor into parameter.");
}
if (impl_->dtype() != tensor->dtype()) {
throw std::runtime_error("Dtype mismatch when loading tensor into parameter.");
}
if (tp_size_ > 1) {
impl_->copy_from(tensor->narrow({{tp_dim_, tp_rank_ * impl_->size(tp_dim_), impl_->size(tp_dim_)}}));
// If parameter is on CPU, use direct memcpy; otherwise use H2D
if (impl_->device().getType() == Device::Type::CPU) {
infinicore::context::memcpyH2H(impl_->data(), buffer->data(), buffer->nbytes());
} else {
infinicore::context::memcpyH2D(impl_->data(), buffer->data(), buffer->nbytes());
infinicore::context::syncStream();
impl_->copy_from(tensor);
}
infinicore::context::syncStream();
}
} // namespace infinicore::nn
......@@ -9,7 +9,7 @@ Tensor embedding(Tensor input, // LongTensor of arbitrary shape containing the i
) {
auto input_shape = input->shape();
auto weight_shape = weight->shape();
auto vocab_size = weight_shape[0];
// auto vocab_size = weight_shape[0];
auto embedding_dim = weight_shape[1];
// Assign memory to out variables
......@@ -23,11 +23,10 @@ Tensor embedding(Tensor input, // LongTensor of arbitrary shape containing the i
void embedding_(Tensor out, Tensor input, Tensor weight) {
assert(infinicore::DataType::I64 == input->dtype() || (infinicore::DataType::I32 == input->dtype()));
assert(infinicore::Device::Type::CPU == input->device());
assert(infinicore::Device::Type::CPU == input->device().getType());
auto input_shape = input->shape();
auto weight_shape = weight->shape();
auto vocab_size = weight_shape[0];
auto embedding_dim = weight_shape[1];
// Calculate the number of token
......@@ -47,7 +46,7 @@ void embedding_(Tensor out, Tensor input, Tensor weight) {
const int64_t *input_arr = reinterpret_cast<const int64_t *>(input->data());
for (Size i = 0; i < counts; ++i) {
int64_t idx = input_arr[i];
assert((idx >= 0) && (idx < vocab_size));
assert((idx >= 0) && (idx < weight_shape[0]));
std::memcpy(out_ptr + i * bytes,
weight_ptr + idx * bytes,
bytes);
......@@ -57,7 +56,7 @@ void embedding_(Tensor out, Tensor input, Tensor weight) {
for (Size i = 0; i < counts; ++i) {
int32_t idx = input_arr[i];
assert((idx >= 0) && (idx < vocab_size));
assert((idx >= 0) && (idx < weight_shape[0]));
std::memcpy(out_ptr + i * bytes,
weight_ptr + idx * bytes,
bytes);
......@@ -69,7 +68,7 @@ void embedding_(Tensor out, Tensor input, Tensor weight) {
const int64_t *input_arr = reinterpret_cast<const int64_t *>(input->data());
for (Size i = 0; i < counts; ++i) {
int64_t idx = input_arr[i];
assert((idx >= 0) && (idx < vocab_size));
assert((idx >= 0) && (idx < weight_shape[0]));
context::memcpyD2D(out_ptr + i * bytes,
weight_ptr + idx * bytes,
bytes);
......@@ -78,7 +77,7 @@ void embedding_(Tensor out, Tensor input, Tensor weight) {
const int32_t *input_arr = reinterpret_cast<const int32_t *>(input->data());
for (Size i = 0; i < counts; ++i) {
int32_t idx = input_arr[i];
assert((idx >= 0) && (idx < vocab_size));
assert((idx >= 0) && (idx < weight_shape[0]));
context::memcpyD2D(out_ptr + i * bytes,
weight_ptr + idx * bytes,
bytes);
......
#include "infinicore/ops/rearrange.hpp"
#include "../../utils.hpp"
namespace infinicore::op {
......@@ -8,7 +9,9 @@ common::OpDispatcher<Rearrange::schema> &Rearrange::dispatcher() {
};
void Rearrange::execute(Tensor y, Tensor x) {
dispatcher().lookup(context::getDevice().getType())(y, x);
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(y, x);
infinicore::context::setDevice(y->device());
dispatcher().lookup(y->device().getType())(y, x);
}
Tensor rearrange(Tensor x) {
......
......@@ -18,8 +18,8 @@ thread_local common::OpCache<size_t, infiniopRearrangeDescriptor_t> caches(
void calculate(Tensor y, Tensor x) {
size_t seed = hash_combine(y, x);
auto device_type = context::getDevice().getType();
auto device_index = context::getDevice().getIndex();
auto device_type = y->device().getType();
auto device_index = y->device().getIndex();
auto &cache = caches.getCache(device_type, device_index);
......
......@@ -16,7 +16,8 @@ inline void bind(py::module &m) {
py::arg("device_type"));
m.def("set_device", &setDevice,
"Set the current active device",
py::arg("device"));
py::arg("device"),
py::arg("force_cpu"));
// Stream and handle management
m.def("get_stream", &getStream, "Get the current stream");
......
......@@ -32,3 +32,15 @@ inline struct SpdlogInitializer {
throw std::runtime_error(#call " failed with error: " + std::string(infini_status_string(ret))); \
} \
} while (false)
#define INFINICORE_ASSERT_TENSORS_SAME_DEVICE(FIRST___, ...) \
do { \
const auto &first_device___ = (FIRST___)->device(); \
for (const auto &tensor___ : {__VA_ARGS__}) { \
if (first_device___ != (tensor___)->device()) { \
throw std::runtime_error("Tensor devices mismatch " \
+ first_device___.toString() + " vs " \
+ (tensor___)->device().toString() + "."); \
} \
} \
} while (0)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment