Commit baafb916 authored by Zimin Li's avatar Zimin Li
Browse files

issue/46: change BinaryInfo create method, update binary contiguous check

parent 377f6e20
#ifndef __INFINIOP_BINARY_H__ #ifndef __INFINIOP_BINARY_H__
#define __INFINIOP_BINARY_H__ #define __INFINIOP_BINARY_H__
#include "../devices/cpu/common_cpu.h"
#include "../operator.h" #include "../operator.h"
#include "../tensor.h" #include "../tensor.h"
#include <algorithm>
#include <numeric> #include <numeric>
/** /**
...@@ -52,24 +52,9 @@ namespace op::binary { ...@@ -52,24 +52,9 @@ namespace op::binary {
// Stores metadata for binary operations on CPU // Stores metadata for binary operations on CPU
struct BinaryInfo { struct BinaryInfo {
private:
BinaryInfo(infiniopTensorDescriptor_t c_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc)
: ndim(c_desc->ndim()),
c_shape(std::move(c_desc->shape())),
a_shape(std::move(a_desc->shape())),
b_shape(std::move(b_desc->shape())),
c_strides(std::move(c_desc->strides())),
a_strides(std::move(a_desc->strides())),
b_strides(std::move(b_desc->strides())) {
this->c_data_size = std::accumulate(c_shape.begin(), c_shape.end(), size_t(1), std::multiplies<size_t>());
this->broadcasted = (a_strides != c_strides) || (b_strides != c_strides);
}
public:
size_t c_data_size; size_t c_data_size;
size_t ndim; size_t ndim;
bool contiguous;
bool broadcasted; bool broadcasted;
std::vector<size_t> c_shape; std::vector<size_t> c_shape;
std::vector<size_t> a_shape; std::vector<size_t> a_shape;
...@@ -77,20 +62,54 @@ public: ...@@ -77,20 +62,54 @@ public:
std::vector<ptrdiff_t> c_strides; std::vector<ptrdiff_t> c_strides;
std::vector<ptrdiff_t> a_strides; std::vector<ptrdiff_t> a_strides;
std::vector<ptrdiff_t> b_strides; std::vector<ptrdiff_t> b_strides;
};
static infiniStatus_t create( inline infiniStatus_t createBinaryInfo(BinaryInfo &info,
BinaryInfo **instance,
infiniopTensorDescriptor_t c_desc, infiniopTensorDescriptor_t c_desc,
infiniopTensorDescriptor_t a_desc, infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc) { infiniopTensorDescriptor_t b_desc) {
if (!c_desc || !a_desc || !b_desc) { if (!c_desc || !a_desc || !b_desc) {
return INFINI_STATUS_BAD_PARAM; return INFINI_STATUS_BAD_PARAM;
} }
*instance = new BinaryInfo(c_desc, a_desc, b_desc); const auto &c_shape = c_desc->shape();
return INFINI_STATUS_SUCCESS; const auto &a_shape = a_desc->shape();
const auto &b_shape = b_desc->shape();
const auto &c_strides = c_desc->strides();
const auto &a_strides = a_desc->strides();
const auto &b_strides = b_desc->strides();
info.c_data_size = std::accumulate(c_shape.begin(), c_shape.end(), size_t(1), std::multiplies<size_t>());
info.ndim = c_desc->ndim();
info.contiguous = c_desc->isContiguous() && a_desc->isContiguous() && b_desc->isContiguous();
// Check if a tensor is broadcasted by checking its shape and strides
auto isBroadcasted = [](const std::vector<size_t> &shape, const std::vector<ptrdiff_t> &strides) {
return std::any_of(
shape.begin(), shape.end(),
[&, i = 0](const auto &) mutable {
return shape[i] != 1 && strides[i++] == 0;
});
};
// Destination cannot have broadcast setup
if (isBroadcasted(c_shape, c_strides)) {
return INFINI_STATUS_BAD_TENSOR_STRIDES;
} }
}; const bool ndim_match = (c_desc->ndim() == a_desc->ndim()) && (c_desc->ndim() == b_desc->ndim());
info.broadcasted = !info.contiguous && (!ndim_match || isBroadcasted(a_shape, a_strides) || isBroadcasted(b_shape, b_strides));
info.c_shape = std::move(c_shape);
info.a_shape = std::move(a_shape);
info.b_shape = std::move(b_shape);
info.c_strides = std::move(c_strides);
info.a_strides = std::move(a_strides);
info.b_strides = std::move(b_strides);
return INFINI_STATUS_SUCCESS;
}
} // namespace op::binary } // namespace op::binary
#endif // __INFINIOP_BINARY_H__ #endif // __INFINIOP_BINARY_H__
#ifndef __INFINIOP_BINARY_CPU_H__ #ifndef __INFINIOP_BINARY_CPU_H__
#define __INFINIOP_BINARY_CPU_H__ #define __INFINIOP_BINARY_CPU_H__
#include "../../devices/cpu/common_cpu.h"
#include "../binary.h" #include "../binary.h"
#include <utility> #include <utility>
...@@ -18,11 +19,9 @@ void calculate(op::binary::BinaryInfo info, void *c, const void *a, const void * ...@@ -18,11 +19,9 @@ void calculate(op::binary::BinaryInfo info, void *c, const void *a, const void *
#pragma omp parallel for #pragma omp parallel for
for (ptrdiff_t i = 0; i < data_size; ++i) { for (ptrdiff_t i = 0; i < data_size; ++i) {
size_t a_index = info.broadcasted ? op::common_cpu::indexToReducedOffset(i, info.ndim, info.c_strides.data(), info.a_strides.data()) size_t a_index = info.contiguous ? i : (info.broadcasted ? op::common_cpu::indexToReducedOffset(i, info.ndim, info.c_strides.data(), info.a_strides.data()) : op::common_cpu::indexToOffset(i, info.ndim, info.a_shape.data(), info.a_strides.data()));
: op::common_cpu::indexToOffset(i, info.ndim, info.a_shape.data(), info.a_strides.data()); size_t b_index = info.contiguous ? i : (info.broadcasted ? op::common_cpu::indexToReducedOffset(i, info.ndim, info.c_strides.data(), info.b_strides.data()) : op::common_cpu::indexToOffset(i, info.ndim, info.b_shape.data(), info.b_strides.data()));
size_t b_index = info.broadcasted ? op::common_cpu::indexToReducedOffset(i, info.ndim, info.c_strides.data(), info.b_strides.data()) size_t c_index = info.contiguous ? i : (op::common_cpu::indexToOffset(i, info.ndim, info.c_shape.data(), info.c_strides.data()));
: op::common_cpu::indexToOffset(i, info.ndim, info.b_shape.data(), info.b_strides.data());
size_t c_index = op::common_cpu::indexToOffset(i, info.ndim, info.c_shape.data(), info.c_strides.data());
c_[c_index] = BinaryOp{}(a_[a_index], b_[b_index], std::forward<Args>(args)...); c_[c_index] = BinaryOp{}(a_[a_index], b_[b_index], std::forward<Args>(args)...);
} }
...@@ -38,11 +37,9 @@ void calculate(op::binary::BinaryInfo info, void *c, const void *a, const void * ...@@ -38,11 +37,9 @@ void calculate(op::binary::BinaryInfo info, void *c, const void *a, const void *
#pragma omp parallel for #pragma omp parallel for
for (ptrdiff_t i = 0; i < data_size; ++i) { for (ptrdiff_t i = 0; i < data_size; ++i) {
size_t a_index = info.broadcasted ? op::common_cpu::indexToReducedOffset(i, info.ndim, info.c_strides.data(), info.a_strides.data()) size_t a_index = info.contiguous ? i : (info.broadcasted ? op::common_cpu::indexToReducedOffset(i, info.ndim, info.c_strides.data(), info.a_strides.data()) : op::common_cpu::indexToOffset(i, info.ndim, info.a_shape.data(), info.a_strides.data()));
: op::common_cpu::indexToOffset(i, info.ndim, info.a_shape.data(), info.a_strides.data()); size_t b_index = info.contiguous ? i : (info.broadcasted ? op::common_cpu::indexToReducedOffset(i, info.ndim, info.c_strides.data(), info.b_strides.data()) : op::common_cpu::indexToOffset(i, info.ndim, info.b_shape.data(), info.b_strides.data()));
size_t b_index = info.broadcasted ? op::common_cpu::indexToReducedOffset(i, info.ndim, info.c_strides.data(), info.b_strides.data()) size_t c_index = info.contiguous ? i : (op::common_cpu::indexToOffset(i, info.ndim, info.c_shape.data(), info.c_strides.data()));
: op::common_cpu::indexToOffset(i, info.ndim, info.b_shape.data(), info.b_strides.data());
size_t c_index = op::common_cpu::indexToOffset(i, info.ndim, info.c_shape.data(), info.c_strides.data());
if constexpr (std::is_same_v<Tdata, fp16_t>) { if constexpr (std::is_same_v<Tdata, fp16_t>) {
float a_val = utils::cast<float>(a_[a_index]); float a_val = utils::cast<float>(a_[a_index]);
......
...@@ -22,18 +22,17 @@ infiniStatus_t Descriptor::create( ...@@ -22,18 +22,17 @@ infiniStatus_t Descriptor::create(
return INFINI_STATUS_BAD_TENSOR_SHAPE; return INFINI_STATUS_BAD_TENSOR_SHAPE;
} }
op::binary::BinaryInfo *info = nullptr; op::binary::BinaryInfo info;
CHECK_STATUS(op::binary::BinaryInfo::create(&info, out_desc, up_desc, gate_desc)); CHECK_STATUS(op::binary::createBinaryInfo(info, out_desc, up_desc, gate_desc));
// Create descriptor // Create descriptor
*desc_ptr = new Descriptor( *desc_ptr = new Descriptor(
dtype, dtype,
*info, std::move(info),
nullptr, nullptr,
handle->device, handle->device,
handle->device_id); handle->device_id);
delete info;
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment