Unverified commit f73d6237 authored by PanZezhong1725, committed by GitHub
Browse files

Merge pull request #742 from InfiniTensor/issue/741

issue/741  暴露 infinicore_cpp_api 类继承
parents 7a48b0de c74dfaea
......@@ -7,6 +7,11 @@ class Parameter : public Tensor {
public:
Parameter();
Parameter(const Tensor &tensor,
Size tp_dim = 0,
Size tp_rank = 0,
Size tp_size = 1);
Parameter(const Shape &shape,
const DataType &dtype,
const Device &device,
......
......@@ -84,6 +84,8 @@ public:
TensorImpl *operator->();
const TensorImpl *operator->() const;
operator bool() const;
protected:
Tensor(std::shared_ptr<TensorImpl> impl) : impl_(std::move(impl)) {}
std::shared_ptr<TensorImpl> impl_;
......
......@@ -22,7 +22,11 @@ void Module::load_parameter(const std::string &name, const Tensor &param) {
auto it = all_params.find(name);
if (it != all_params.end()) {
auto existing_param = it->second;
try {
existing_param.load(param);
} catch (const std::exception &e) {
throw std::runtime_error("Error loading parameter '" + name + "'. \n" + e.what());
}
return;
}
......@@ -37,14 +41,11 @@ void Module::load_parameter_(const std::string &name, const Tensor &param) {
auto it = parameters_.find(name);
if (it != parameters_.end()) {
auto existing_param = it->second;
// Assert dtype matches
if (existing_param->dtype() != param->dtype()) {
throw std::runtime_error(
"dtype mismatch for parameter '" + name + "': "
"expected "
+ std::to_string(static_cast<int>(existing_param->dtype())) + ", got " + std::to_string(static_cast<int>(param->dtype())));
}
try {
existing_param.load(param);
} catch (const std::exception &e) {
throw std::runtime_error("Error loading parameter '" + name + "'. \n" + e.what());
}
return;
}
......
......@@ -7,7 +7,7 @@
namespace infinicore::nn {
// Default-constructed parameter: an empty handle with no backing storage,
// i.e. the same state as a default-constructed Tensor. (Previously this
// eagerly allocated an empty F32 CPU tensor; the delegating form avoids
// that allocation.)
Parameter::Parameter()
    : Tensor() {
}
inline Shape get_partipion_shape_(const Shape &shape, Size tp_dim, Size tp_size) {
......@@ -24,6 +24,12 @@ inline Shape get_partipion_shape_(const Shape &shape, Size tp_dim, Size tp_size)
return part_shape;
}
// Wraps an existing tensor as a parameter, recording its tensor-parallel
// partition info (partition dim, this rank, world size).
// Throws std::runtime_error when tp_rank is not a valid rank for tp_size.
Parameter::Parameter(const Tensor &tensor, Size tp_dim, Size tp_rank, Size tp_size)
    : Tensor(tensor),
      tp_dim_(tp_dim),
      tp_rank_(tp_rank),
      tp_size_(tp_size) {
    // A valid rank must lie in [0, tp_size); reject anything else up front.
    if (!(tp_rank_ < tp_size_)) {
        const std::string message = "Tensor parallel rank " + std::to_string(tp_rank_)
                                  + " must be less than tensor parallel size "
                                  + std::to_string(tp_size_) + ".";
        throw std::runtime_error(message);
    }
}
// Allocates an uninitialized parameter holding this rank's partition of a
// weight whose full logical shape is `shape` (get_partipion_shape_ presumably
// shrinks the tp_dim axis by tp_size — confirm against its definition).
// Delegates to the Tensor-wrapping constructor, which performs the
// tp_rank < tp_size validation, so the check is not duplicated here.
Parameter::Parameter(
    const Shape &shape,
    const DataType &dtype,
    const Device &device,
    Size tp_dim,
    Size tp_rank,
    Size tp_size)
    : Parameter(Tensor::empty(get_partipion_shape_(shape, tp_dim, tp_size), dtype, device, false), tp_dim, tp_rank, tp_size) {
}
void Parameter::load_blob(const void *data) {
......@@ -50,10 +53,10 @@ void Parameter::load(const Tensor &tensor) {
expected_shape[tp_dim_] *= tp_size_;
if (expected_shape != tensor->shape()) {
throw std::runtime_error("Shape mismatch when loading tensor into parameter.");
throw std::runtime_error("Shape mismatch when loading tensor into parameter. Weight: " + impl_->info() + ", Tensor: " + tensor->info() + ".");
}
if (impl_->dtype() != tensor->dtype()) {
throw std::runtime_error("Dtype mismatch when loading tensor into parameter.");
throw std::runtime_error("Dtype mismatch when loading tensor into parameter. Weight: " + impl_->info() + ", Tensor: " + tensor->info() + ".");
}
if (tp_size_ > 1) {
impl_->copy_from(tensor->narrow({{tp_dim_, tp_rank_ * impl_->size(tp_dim_), impl_->size(tp_dim_)}}));
......
......@@ -60,6 +60,10 @@ Tensor Tensor::strided_from_blob(void *raw_ptr, const Shape &shape, const Stride
return Tensor{TensorImpl::strided_from_blob(raw_ptr, shape, strides, dtype, device)};
}
// True when this handle owns an underlying TensorImpl; false for an
// empty (default-constructed) handle.
Tensor::operator bool() const {
    return static_cast<bool>(impl_);
}
TensorMetaData::TensorMetaData(const Shape &_shape, const Strides &_strides, const DataType &_dtype)
: shape(_shape), strides(_strides), dtype(_dtype) {
INFINICORE_CHECK_ERROR(infiniopCreateTensorDescriptor(&desc, shape.size(), shape.data(), strides.data(), (infiniDtype_t)dtype));
......
......@@ -3,6 +3,7 @@
#include "infinicore/tensor.hpp"
#include <spdlog/spdlog.h>
#include <stdexcept>
namespace infinicore {
Tensor TensorImpl::unsqueeze(size_t dim) const {
......@@ -26,8 +27,10 @@ Tensor TensorImpl::narrow(const std::vector<TensorSliceParams> &slices) const {
size_t offset = data_.offset;
for (const auto &slice : slices) {
assert(slice.len > 0);
assert(meta_.shape[slice.dim] >= slice.start + slice.len);
if (meta_.shape[slice.dim] < slice.start + slice.len) {
spdlog::error("Invalid slice [dim={}, start={}, len={}] on {}.", slice.dim, slice.start, slice.len, this->info());
throw std::runtime_error("Invalid slice on tensor.");
}
new_shape[slice.dim] = slice.len;
offset += slice.start * meta_.strides[slice.dim] * dsize(meta_.dtype);
}
......
......@@ -348,8 +348,6 @@ target("infiniccl")
set_installdir(os.getenv("INFINI_ROOT") or (os.getenv(is_host("windows") and "HOMEPATH" or "HOME") .. "/.infini"))
target_end()
target("infinicore_c_api")
target("infinicore_c_api")
set_kind("phony")
add_deps("infiniop", "infinirt", "infiniccl")
......@@ -360,6 +358,7 @@ target("infinicore_cpp_api")
set_kind("shared")
add_deps("infiniop", "infinirt", "infiniccl")
set_languages("cxx17")
set_symbols("visibility")
local INFINI_ROOT = os.getenv("INFINI_ROOT") or (os.getenv(is_host("windows") and "HOMEPATH" or "HOME") .. "/.infini")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment