"vscode:/vscode.git/clone" did not exist on "c387d9c0923e5d380237e75d53a9d50651e3c782"
Unverified Commit f6a645a3 authored by PanZezhong1725's avatar PanZezhong1725 Committed by GitHub
Browse files

issue/296 isContigous tolerates length 1 dimension (#297)

parent f88d4ad8
......@@ -44,7 +44,7 @@ __C __export infiniStatus_t infiniopCreateAttentionDescriptor(infiniopHandle_t h
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
if (!out_desc->isContiguous(0, 2)) {
if (!out_desc->isContiguous()) {
return INFINI_STATUS_BAD_TENSOR_STRIDES;
}
......
......@@ -104,11 +104,7 @@ public:
// Last dimension of x and y must be contiguous
CHECK_OR_RETURN(y_desc->stride(2) == 1 && x_desc->stride(2) == 1, INFINI_STATUS_BAD_TENSOR_STRIDES);
// sin table and cos table must be totally contiguous
CHECK_OR_RETURN(sin_desc->stride(1) == 1
&& cos_desc->stride(1) == 1
&& sin_desc->stride(0) == ptrdiff_t(table_dim)
&& cos_desc->stride(0) == ptrdiff_t(table_dim),
INFINI_STATUS_BAD_TENSOR_STRIDES);
CHECK_OR_RETURN(sin_desc->isContiguous() && cos_desc->isContiguous(), INFINI_STATUS_BAD_TENSOR_STRIDES);
return utils::Result<RoPEInfo>(RoPEInfo{
data_type,
......
......@@ -34,8 +34,14 @@ public:
std::vector<ptrdiff_t> strides() const;
ptrdiff_t stride(size_t i) const;
std::vector<ptrdiff_t> getByteStrides() const;
// Whether dimensions in [dim_start, dim_end] can be merged into a single dimension
bool isMergable(size_t dim_start, size_t dim_end) const;
bool isContiguous(size_t dim_) const;
bool isContiguous(size_t dim_start, size_t dim_end) const;
bool isContiguous() const;
// Total number of elements in the tensor
size_t numel() const;
// a dim is broadcasted if it's corresponding stride is 0 but dim > 1
......
......@@ -70,19 +70,63 @@ std::vector<ptrdiff_t> InfiniopTensorDescriptor::getByteStrides() const {
return byte_strides;
}
bool InfiniopTensorDescriptor::isContiguous(size_t dim_start, size_t dim_end) const {
if (ndim() == 0) {
bool InfiniopTensorDescriptor::isContiguous(size_t dim_) const {
if (dim(dim_) == 1) {
return true;
}
return stride(dim_) == ptrdiff_t(1);
}
bool InfiniopTensorDescriptor::isMergable(size_t dim_start, size_t dim_end) const {
if (dim_start > dim_end) {
throw std::invalid_argument("Invalid input");
} else if (dim_start == dim_end) {
return true;
}
for (size_t i = dim_start + 1; i <= dim_end; i++) {
// Slice out shape and strides from dim_start to dim_end, excluding 1-sized dimensions.
// Return false at once if any effective broadcast (0-strided) dimension is found.
std::vector<size_t> shape_;
std::vector<ptrdiff_t> strides_;
for (size_t i = dim_start; i <= dim_end; i++) {
if (dim(i) != 1) {
if (stride(i) == 0) {
return false;
}
shape_.push_back(dim(i));
strides_.push_back(stride(i));
}
}
auto ndim_ = shape_.size();
for (size_t i = 1; i < ndim_; i++) {
if (stride(i - 1) != static_cast<ptrdiff_t>(dim(i)) * stride(i)) {
return false;
}
}
return true;
}
bool InfiniopTensorDescriptor::isContiguous(size_t dim_start, size_t dim_end) const {
if (dim_start > dim_end) {
throw std::invalid_argument("Invalid input");
}
if (!isMergable(dim_start, dim_end)) {
return false;
}
return stride(dim_end) == ptrdiff_t(1);
}
bool InfiniopTensorDescriptor::isContiguous() const {
if (ndim() == 0) {
return true;
}
return isContiguous(0, ndim() - 1);
}
......@@ -118,7 +162,7 @@ utils::Result<infiniopTensorDescriptor_t> InfiniopTensorDescriptor::dimMerge(siz
index++;
}
CHECK_OR_RETURN(isContiguous(dim_start, dim_end), INFINI_STATUS_BAD_PARAM);
CHECK_OR_RETURN(isMergable(dim_start, dim_end), INFINI_STATUS_BAD_PARAM);
new_shape[index] = 1;
for (size_t i = dim_start; i <= dim_end; i++) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment