Commit baafb916 authored by Zimin Li
Browse files

issue/46: change BinaryInfo create method, update binary contiguous check

parent 377f6e20
#ifndef __INFINIOP_BINARY_H__
#define __INFINIOP_BINARY_H__
#include "../devices/cpu/common_cpu.h"
#include "../operator.h"
#include "../tensor.h"
#include <algorithm>
#include <numeric>
/**
......@@ -52,24 +52,9 @@ namespace op::binary {
// Stores metadata for binary operations on CPU
struct BinaryInfo {
private:
BinaryInfo(infiniopTensorDescriptor_t c_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc)
: ndim(c_desc->ndim()),
c_shape(std::move(c_desc->shape())),
a_shape(std::move(a_desc->shape())),
b_shape(std::move(b_desc->shape())),
c_strides(std::move(c_desc->strides())),
a_strides(std::move(a_desc->strides())),
b_strides(std::move(b_desc->strides())) {
this->c_data_size = std::accumulate(c_shape.begin(), c_shape.end(), size_t(1), std::multiplies<size_t>());
this->broadcasted = (a_strides != c_strides) || (b_strides != c_strides);
}
public:
size_t c_data_size;
size_t ndim;
bool contiguous;
bool broadcasted;
std::vector<size_t> c_shape;
std::vector<size_t> a_shape;
......@@ -77,20 +62,54 @@ public:
std::vector<ptrdiff_t> c_strides;
std::vector<ptrdiff_t> a_strides;
std::vector<ptrdiff_t> b_strides;
};
static infiniStatus_t create(
BinaryInfo **instance,
inline infiniStatus_t createBinaryInfo(BinaryInfo &info,
infiniopTensorDescriptor_t c_desc,
infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc) {
if (!c_desc || !a_desc || !b_desc) {
return INFINI_STATUS_BAD_PARAM;
}
*instance = new BinaryInfo(c_desc, a_desc, b_desc);
return INFINI_STATUS_SUCCESS;
const auto &c_shape = c_desc->shape();
const auto &a_shape = a_desc->shape();
const auto &b_shape = b_desc->shape();
const auto &c_strides = c_desc->strides();
const auto &a_strides = a_desc->strides();
const auto &b_strides = b_desc->strides();
info.c_data_size = std::accumulate(c_shape.begin(), c_shape.end(), size_t(1), std::multiplies<size_t>());
info.ndim = c_desc->ndim();
info.contiguous = c_desc->isContiguous() && a_desc->isContiguous() && b_desc->isContiguous();
// Reports whether a tensor layout is broadcasted, i.e. whether any dimension
// whose extent is not 1 is traversed with a stride of 0.
//
// shape   - extent of each dimension
// strides - stride of each dimension, indexed in parallel with `shape`
// returns - true if at least one such dimension exists, false otherwise
//           (including for empty inputs)
//
// The index advances on every dimension. The previous form incremented it
// inside the right-hand side of `&&`, so a dimension of extent 1 short-circuited
// the increment and every later dimension was compared against the wrong entry.
auto isBroadcasted = [](const std::vector<size_t> &shape, const std::vector<ptrdiff_t> &strides) {
    // Only inspect dimensions present in both vectors so a length mismatch
    // cannot read past the end of either one.
    const size_t ndim = shape.size() < strides.size() ? shape.size() : strides.size();
    for (size_t dim = 0; dim < ndim; ++dim) {
        if (shape[dim] != 1 && strides[dim] == 0) {
            return true;
        }
    }
    return false;
};
// Destination cannot have broadcast setup
if (isBroadcasted(c_shape, c_strides)) {
return INFINI_STATUS_BAD_TENSOR_STRIDES;
}
};
const bool ndim_match = (c_desc->ndim() == a_desc->ndim()) && (c_desc->ndim() == b_desc->ndim());
info.broadcasted = !info.contiguous && (!ndim_match || isBroadcasted(a_shape, a_strides) || isBroadcasted(b_shape, b_strides));
info.c_shape = std::move(c_shape);
info.a_shape = std::move(a_shape);
info.b_shape = std::move(b_shape);
info.c_strides = std::move(c_strides);
info.a_strides = std::move(a_strides);
info.b_strides = std::move(b_strides);
return INFINI_STATUS_SUCCESS;
}
} // namespace op::binary
#endif // __INFINIOP_BINARY_H__
#ifndef __INFINIOP_BINARY_CPU_H__
#define __INFINIOP_BINARY_CPU_H__
#include "../../devices/cpu/common_cpu.h"
#include "../binary.h"
#include <utility>
......@@ -18,11 +19,9 @@ void calculate(op::binary::BinaryInfo info, void *c, const void *a, const void *
#pragma omp parallel for
for (ptrdiff_t i = 0; i < data_size; ++i) {
size_t a_index = info.broadcasted ? op::common_cpu::indexToReducedOffset(i, info.ndim, info.c_strides.data(), info.a_strides.data())
: op::common_cpu::indexToOffset(i, info.ndim, info.a_shape.data(), info.a_strides.data());
size_t b_index = info.broadcasted ? op::common_cpu::indexToReducedOffset(i, info.ndim, info.c_strides.data(), info.b_strides.data())
: op::common_cpu::indexToOffset(i, info.ndim, info.b_shape.data(), info.b_strides.data());
size_t c_index = op::common_cpu::indexToOffset(i, info.ndim, info.c_shape.data(), info.c_strides.data());
size_t a_index = info.contiguous ? i : (info.broadcasted ? op::common_cpu::indexToReducedOffset(i, info.ndim, info.c_strides.data(), info.a_strides.data()) : op::common_cpu::indexToOffset(i, info.ndim, info.a_shape.data(), info.a_strides.data()));
size_t b_index = info.contiguous ? i : (info.broadcasted ? op::common_cpu::indexToReducedOffset(i, info.ndim, info.c_strides.data(), info.b_strides.data()) : op::common_cpu::indexToOffset(i, info.ndim, info.b_shape.data(), info.b_strides.data()));
size_t c_index = info.contiguous ? i : (op::common_cpu::indexToOffset(i, info.ndim, info.c_shape.data(), info.c_strides.data()));
c_[c_index] = BinaryOp{}(a_[a_index], b_[b_index], std::forward<Args>(args)...);
}
......@@ -38,11 +37,9 @@ void calculate(op::binary::BinaryInfo info, void *c, const void *a, const void *
#pragma omp parallel for
for (ptrdiff_t i = 0; i < data_size; ++i) {
size_t a_index = info.broadcasted ? op::common_cpu::indexToReducedOffset(i, info.ndim, info.c_strides.data(), info.a_strides.data())
: op::common_cpu::indexToOffset(i, info.ndim, info.a_shape.data(), info.a_strides.data());
size_t b_index = info.broadcasted ? op::common_cpu::indexToReducedOffset(i, info.ndim, info.c_strides.data(), info.b_strides.data())
: op::common_cpu::indexToOffset(i, info.ndim, info.b_shape.data(), info.b_strides.data());
size_t c_index = op::common_cpu::indexToOffset(i, info.ndim, info.c_shape.data(), info.c_strides.data());
size_t a_index = info.contiguous ? i : (info.broadcasted ? op::common_cpu::indexToReducedOffset(i, info.ndim, info.c_strides.data(), info.a_strides.data()) : op::common_cpu::indexToOffset(i, info.ndim, info.a_shape.data(), info.a_strides.data()));
size_t b_index = info.contiguous ? i : (info.broadcasted ? op::common_cpu::indexToReducedOffset(i, info.ndim, info.c_strides.data(), info.b_strides.data()) : op::common_cpu::indexToOffset(i, info.ndim, info.b_shape.data(), info.b_strides.data()));
size_t c_index = info.contiguous ? i : (op::common_cpu::indexToOffset(i, info.ndim, info.c_shape.data(), info.c_strides.data()));
if constexpr (std::is_same_v<Tdata, fp16_t>) {
float a_val = utils::cast<float>(a_[a_index]);
......
......@@ -22,18 +22,17 @@ infiniStatus_t Descriptor::create(
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
op::binary::BinaryInfo *info = nullptr;
CHECK_STATUS(op::binary::BinaryInfo::create(&info, out_desc, up_desc, gate_desc));
op::binary::BinaryInfo info;
CHECK_STATUS(op::binary::createBinaryInfo(info, out_desc, up_desc, gate_desc));
// Create descriptor
*desc_ptr = new Descriptor(
dtype,
*info,
std::move(info),
nullptr,
handle->device,
handle->device_id);
delete info;
return INFINI_STATUS_SUCCESS;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment