"docs/XcodeGuide.md" did not exist on "642acbd61235dc68f606237193cf7e7c4a61af67"
Unverified Commit f6a645a3 authored by PanZezhong1725's avatar PanZezhong1725 Committed by GitHub
Browse files

issue/296 isContigous tolerates length 1 dimension (#297)

parent f88d4ad8
...@@ -44,7 +44,7 @@ __C __export infiniStatus_t infiniopCreateAttentionDescriptor(infiniopHandle_t h ...@@ -44,7 +44,7 @@ __C __export infiniStatus_t infiniopCreateAttentionDescriptor(infiniopHandle_t h
return INFINI_STATUS_BAD_TENSOR_SHAPE; return INFINI_STATUS_BAD_TENSOR_SHAPE;
} }
if (!out_desc->isContiguous(0, 2)) { if (!out_desc->isContiguous()) {
return INFINI_STATUS_BAD_TENSOR_STRIDES; return INFINI_STATUS_BAD_TENSOR_STRIDES;
} }
......
...@@ -104,11 +104,7 @@ public: ...@@ -104,11 +104,7 @@ public:
// Last dimension of x and y must be contiguous // Last dimension of x and y must be contiguous
CHECK_OR_RETURN(y_desc->stride(2) == 1 && x_desc->stride(2) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES); CHECK_OR_RETURN(y_desc->stride(2) == 1 && x_desc->stride(2) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES);
// sin table and cos table must be totally contiguous // sin table and cos table must be totally contiguous
CHECK_OR_RETURN(sin_desc->stride(1) == 1 CHECK_OR_RETURN(sin_desc->isContiguous() && cos_desc->isContiguous(), INFINI_STATUS_BAD_TENSOR_STRIDES);
&& cos_desc->stride(1) == 1
&& sin_desc->stride(0) == ptrdiff_t(table_dim)
&& cos_desc->stride(0) == ptrdiff_t(table_dim),
INFINI_STATUS_BAD_TENSOR_STRIDES);
return utils::Result<RoPEInfo>(RoPEInfo{ return utils::Result<RoPEInfo>(RoPEInfo{
data_type, data_type,
......
...@@ -34,8 +34,14 @@ public: ...@@ -34,8 +34,14 @@ public:
std::vector<ptrdiff_t> strides() const; std::vector<ptrdiff_t> strides() const;
ptrdiff_t stride(size_t i) const; ptrdiff_t stride(size_t i) const;
std::vector<ptrdiff_t> getByteStrides() const; std::vector<ptrdiff_t> getByteStrides() const;
// Whether dimensions in [dim_start, dim_end] can be merged into a single dimension
bool isMergable(size_t dim_start, size_t dim_end) const;
bool isContiguous(size_t dim_) const;
bool isContiguous(size_t dim_start, size_t dim_end) const; bool isContiguous(size_t dim_start, size_t dim_end) const;
bool isContiguous() const; bool isContiguous() const;
// Total number of elements in the tensor
size_t numel() const; size_t numel() const;
// a dim is broadcasted if its corresponding stride is 0 but dim > 1 // a dim is broadcasted if its corresponding stride is 0 but dim > 1
......
...@@ -70,19 +70,63 @@ std::vector<ptrdiff_t> InfiniopTensorDescriptor::getByteStrides() const { ...@@ -70,19 +70,63 @@ std::vector<ptrdiff_t> InfiniopTensorDescriptor::getByteStrides() const {
return byte_strides; return byte_strides;
} }
bool InfiniopTensorDescriptor::isContiguous(size_t dim_start, size_t dim_end) const { bool InfiniopTensorDescriptor::isContiguous(size_t dim_) const {
if (ndim() == 0) { if (dim(dim_) == 1) {
return true;
}
return stride(dim_) == ptrdiff_t(1);
}
bool InfiniopTensorDescriptor::isMergable(size_t dim_start, size_t dim_end) const {
if (dim_start > dim_end) {
throw std::invalid_argument("Invalid input");
} else if (dim_start == dim_end) {
return true; return true;
} }
for (size_t i = dim_start + 1; i <= dim_end; i++) {
// Slice out shape and strides from dim_start to dim_end, excluding 1-sized dimensions.
// Return false at once if any effective broadcast (0-strided) dimension is found.
std::vector<size_t> shape_;
std::vector<ptrdiff_t> strides_;
for (size_t i = dim_start; i <= dim_end; i++) {
if (dim(i) != 1) {
if (stride(i) == 0) {
return false;
}
shape_.push_back(dim(i));
strides_.push_back(stride(i));
}
}
auto ndim_ = shape_.size();
for (size_t i = 1; i < ndim_; i++) {
if (stride(i - 1) != static_cast<ptrdiff_t>(dim(i)) * stride(i)) { if (stride(i - 1) != static_cast<ptrdiff_t>(dim(i)) * stride(i)) {
return false; return false;
} }
} }
return true; return true;
} }
bool InfiniopTensorDescriptor::isContiguous(size_t dim_start, size_t dim_end) const {
if (dim_start > dim_end) {
throw std::invalid_argument("Invalid input");
}
if (!isMergable(dim_start, dim_end)) {
return false;
}
return stride(dim_end) == ptrdiff_t(1);
}
bool InfiniopTensorDescriptor::isContiguous() const { bool InfiniopTensorDescriptor::isContiguous() const {
if (ndim() == 0) {
return true;
}
return isContiguous(0, ndim() - 1); return isContiguous(0, ndim() - 1);
} }
...@@ -118,7 +162,7 @@ utils::Result<infiniopTensorDescriptor_t> InfiniopTensorDescriptor::dimMerge(siz ...@@ -118,7 +162,7 @@ utils::Result<infiniopTensorDescriptor_t> InfiniopTensorDescriptor::dimMerge(siz
index++; index++;
} }
CHECK_OR_RETURN(isContiguous(dim_start, dim_end), INFINI_STATUS_BAD_PARAM); CHECK_OR_RETURN(isMergable(dim_start, dim_end), INFINI_STATUS_BAD_PARAM);
new_shape[index] = 1; new_shape[index] = 1;
for (size_t i = dim_start; i <= dim_end; i++) { for (size_t i = dim_start; i <= dim_end; i++) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment