Commit ee68b55d authored by Zimin Li's avatar Zimin Li
Browse files

issue/46: Move tensor broadcast check functions to infiniopTensorDescriptor_t

parent baafb916
......@@ -3,7 +3,6 @@
#include "../operator.h"
#include "../tensor.h"
#include <algorithm>
#include <numeric>
/**
......@@ -73,39 +72,23 @@ inline infiniStatus_t createBinaryInfo(BinaryInfo &info,
return INFINI_STATUS_BAD_PARAM;
}
const auto &c_shape = c_desc->shape();
const auto &a_shape = a_desc->shape();
const auto &b_shape = b_desc->shape();
const auto &c_strides = c_desc->strides();
const auto &a_strides = a_desc->strides();
const auto &b_strides = b_desc->strides();
info.c_data_size = std::accumulate(c_shape.begin(), c_shape.end(), size_t(1), std::multiplies<size_t>());
info.c_data_size = c_desc->numel();
info.ndim = c_desc->ndim();
info.contiguous = c_desc->isContiguous() && a_desc->isContiguous() && b_desc->isContiguous();
// Check if a tensor is broadcasted by checking its shape and strides
auto isBroadcasted = [](const std::vector<size_t> &shape, const std::vector<ptrdiff_t> &strides) {
return std::any_of(
shape.begin(), shape.end(),
[&, i = 0](const auto &) mutable {
return shape[i] != 1 && strides[i++] == 0;
});
};
// Destination cannot have broadcast setup
if (isBroadcasted(c_shape, c_strides)) {
if (c_desc->hasBroadcastDim()) {
return INFINI_STATUS_BAD_TENSOR_STRIDES;
}
const bool ndim_match = (c_desc->ndim() == a_desc->ndim()) && (c_desc->ndim() == b_desc->ndim());
info.broadcasted = !info.contiguous && (!ndim_match || isBroadcasted(a_shape, a_strides) || isBroadcasted(b_shape, b_strides));
info.broadcasted = !info.contiguous && (!ndim_match || a_desc->hasBroadcastDim() || b_desc->hasBroadcastDim());
info.c_shape = std::move(c_shape);
info.a_shape = std::move(a_shape);
info.b_shape = std::move(b_shape);
info.c_strides = std::move(c_strides);
info.a_strides = std::move(a_strides);
info.b_strides = std::move(b_strides);
info.c_shape = std::move(c_desc->shape());
info.a_shape = std::move(a_desc->shape());
info.b_shape = std::move(b_desc->shape());
info.c_strides = std::move(c_desc->strides());
info.a_strides = std::move(a_desc->strides());
info.b_strides = std::move(b_desc->strides());
return INFINI_STATUS_SUCCESS;
}
......
......@@ -28,6 +28,10 @@ public:
bool isContiguous() const;
size_t numel() const;
// A dim is broadcast if its corresponding stride is 0 while its extent is > 1.
bool hasBroadcastDim() const;
std::vector<size_t> getBroadcastDim() const;
infiniopTensorDescriptor_t dimMerge(size_t dim_start, size_t dim_end) const;
infiniopTensorDescriptor_t dimSplit(size_t axis, const std::vector<size_t> &dims) const;
infiniopTensorDescriptor_t dimPermute(const std::vector<size_t> &order) const;
......
#include "../utils.h"
#include "tensor.h"
#include <algorithm>
#include <cstring>
#include <functional>
#include <numeric>
......@@ -85,6 +86,24 @@ bool InfiniopTensorDescriptor::isContiguous() const {
return isContiguous(0, ndim() - 1);
}
bool InfiniopTensorDescriptor::hasBroadcastDim() const {
return std::any_of(
_shape.begin(), _shape.end(),
[&, i = 0](const auto &) mutable {
return _shape[i] != 1 && _strides[i++] == 0;
});
}
/// @brief Collect the indices of all broadcast dimensions.
/// @return Indices (in ascending order) of every dim whose extent is
///         not 1 yet whose stride is 0.
std::vector<size_t> InfiniopTensorDescriptor::getBroadcastDim() const {
    std::vector<size_t> broadcast_axes;
    const size_t rank = ndim();
    for (size_t axis = 0; axis < rank; ++axis) {
        // A broadcast axis repeats the same element: extent > 1, stride == 0.
        const bool is_broadcast = (_shape[axis] != 1) && (_strides[axis] == 0);
        if (is_broadcast) {
            broadcast_axes.push_back(axis);
        }
    }
    return broadcast_axes;
}
infiniopTensorDescriptor_t InfiniopTensorDescriptor::dimMerge(size_t dim_start, size_t dim_end) const {
if (dim_start > dim_end || dim_end >= ndim()) {
return nullptr;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment