Commit e2ca3642 authored by Zimin Li's avatar Zimin Li
Browse files

issue/46: abstract out binaryInfo, remove binary check, add SAME_VEC macro, fix misc., etc.

parent c686f0e8
#ifndef __INFINIOP_BINARY_H__
#define __INFINIOP_BINARY_H__
#include "../tensor.h"
#include <numeric>
namespace op::binary {
// Stores metadata for binary operations on CPU
// Metadata shared by elementwise binary operations (c = op(a, b)):
// shapes and strides of the three tensors plus two derived facts —
// the total output element count and whether any input is broadcasted.
struct BinaryInfo {
private:
    // Builds the info from the three tensor descriptors.
    // The accessor results are plain-copied: the previous std::move on
    // shape()/strides() was either a no-op (const& return) or redundant
    // (by-value return), so the copy is the honest spelling.
    BinaryInfo(infiniopTensorDescriptor_t c_desc,
               infiniopTensorDescriptor_t a_desc,
               infiniopTensorDescriptor_t b_desc)
        : ndim(c_desc->ndim()),
          c_shape(c_desc->shape()),
          a_shape(a_desc->shape()),
          b_shape(b_desc->shape()),
          c_strides(c_desc->strides()),
          a_strides(a_desc->strides()),
          b_strides(b_desc->strides()) {
        // Total number of elements in the output tensor c.
        this->c_data_size = std::accumulate(c_shape.begin(), c_shape.end(), 1ULL, std::multiplies<size_t>());
        // Any stride mismatch against the output implies broadcasting.
        this->broadcasted = (a_strides != c_strides) || (b_strides != c_strides);
    }

public:
    size_t c_data_size; // element count of the output tensor c
    size_t ndim;        // rank of c (inputs presumed rank-compatible — verify at call sites)
    bool broadcasted;   // true when a's or b's strides differ from c's
    std::vector<size_t> c_shape;
    std::vector<size_t> a_shape;
    std::vector<size_t> b_shape;
    std::vector<ptrdiff_t> c_strides;
    std::vector<ptrdiff_t> a_strides;
    std::vector<ptrdiff_t> b_strides;

    // Factory: validates the descriptors and heap-allocates a BinaryInfo.
    // On success, *instance owns the new object (the caller must delete it).
    // Returns BAD_PARAM on a null descriptor, INTERNAL_ERROR if
    // construction throws.
    static infiniStatus_t create(
        BinaryInfo **instance,
        infiniopTensorDescriptor_t c_desc,
        infiniopTensorDescriptor_t a_desc,
        infiniopTensorDescriptor_t b_desc) {
        if (!c_desc || !a_desc || !b_desc) {
            return INFINI_STATUS_BAD_PARAM;
        }
        try {
            *instance = new BinaryInfo(c_desc, a_desc, b_desc);
            return INFINI_STATUS_SUCCESS;
        } catch (const std::exception &) {
            return INFINI_STATUS_INTERNAL_ERROR;
        }
    }
};
} // namespace op::binary
#endif // __INFINIOP_BINARY_H__
#ifndef __INFINIOP_BINARY_CPU_H__ #ifndef __INFINIOP_BINARY_H__
#define __INFINIOP_BINARY_CPU_H__ #define __INFINIOP_BINARY_H__
#include "../../devices/cpu/common_cpu.h" #include "../tensor.h"
#include "../../operator.h"
#include "../../tensor.h"
#include <array>
#include <numeric> #include <numeric>
#include <utility>
/** namespace op::binary {
* 该类的设计基于 matmul.h 中 YdrMaster 设计的 DESCRIPTOR 宏。
*/
#define BINARY_DESCRIPTOR(OP, NAMESPACE) \
\
namespace op::OP::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \
struct Opaque; \
Opaque *_opaque; \
infiniDtype_t _dtype; \
op::common_cpu::binary_op::BinaryCpuInfo _info; \
\
Descriptor( \
infiniDtype_t dtype, \
op::common_cpu::binary_op::BinaryCpuInfo info, \
Opaque *opaque, \
infiniDevice_t device_type, \
int device_id) \
: InfiniopDescriptor{device_type, device_id}, \
_opaque(opaque), \
_dtype(dtype), \
_info(info) {} \
\
public: \
~Descriptor(); \
\
static infiniStatus_t create( \
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
infiniopTensorDescriptor_t c_desc, \
infiniopTensorDescriptor_t a_desc, \
infiniopTensorDescriptor_t b_desc); \
\
infiniStatus_t calculate( \
void *c, \
const void *a, \
const void *b, \
void *stream) const; \
}; \
}
namespace op::common_cpu {
namespace binary_op {
// Stores metadata for binary operations on CPU // Stores metadata for binary operations on CPU
struct BinaryCpuInfo { struct BinaryInfo {
size_t c_data_size; private:
size_t ndim; BinaryInfo(infiniopTensorDescriptor_t c_desc,
bool broadcasted;
std::vector<size_t> c_shape;
std::vector<size_t> a_shape;
std::vector<size_t> b_shape;
std::vector<ptrdiff_t> c_strides;
std::vector<ptrdiff_t> a_strides;
std::vector<ptrdiff_t> b_strides;
BinaryCpuInfo(size_t c_data_size,
size_t ndim,
bool broadcasted,
std::vector<size_t> c_shape,
std::vector<size_t> a_shape,
std::vector<size_t> b_shape,
std::vector<ptrdiff_t> c_strides,
std::vector<ptrdiff_t> a_strides,
std::vector<ptrdiff_t> b_strides)
: c_data_size(c_data_size),
ndim(ndim),
broadcasted(broadcasted),
c_shape(std::move(c_shape)),
a_shape(std::move(a_shape)),
b_shape(std::move(b_shape)),
c_strides(std::move(c_strides)),
a_strides(std::move(a_strides)),
b_strides(std::move(b_strides)) {}
BinaryCpuInfo(infiniopTensorDescriptor_t c_desc,
infiniopTensorDescriptor_t a_desc, infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc) infiniopTensorDescriptor_t b_desc)
: ndim(c_desc->ndim()), : ndim(c_desc->ndim()),
...@@ -98,85 +22,35 @@ struct BinaryCpuInfo { ...@@ -98,85 +22,35 @@ struct BinaryCpuInfo {
this->c_data_size = std::accumulate(c_shape.begin(), c_shape.end(), 1ULL, std::multiplies<size_t>()); this->c_data_size = std::accumulate(c_shape.begin(), c_shape.end(), 1ULL, std::multiplies<size_t>());
this->broadcasted = (a_strides != c_strides) || (b_strides != c_strides); this->broadcasted = (a_strides != c_strides) || (b_strides != c_strides);
} }
};
// Helper function for compile-time optimized checks public:
template <size_t N> size_t c_data_size;
bool isDtypeSupported(infiniDtype_t dtype, const std::array<infiniDtype_t, N> &supported_dtypes) { size_t ndim;
for (size_t i = 0; i < N; ++i) { bool broadcasted;
if (dtype == supported_dtypes[i]) { std::vector<size_t> c_shape;
return true; std::vector<size_t> a_shape;
} std::vector<size_t> b_shape;
} std::vector<ptrdiff_t> c_strides;
return false; std::vector<ptrdiff_t> a_strides;
} std::vector<ptrdiff_t> b_strides;
// Checks if the tensors are compatible for binary operations based on dtype and shape requirements. static infiniStatus_t create(
template <size_t N> BinaryInfo **instance,
infiniStatus_t check(infiniopTensorDescriptor_t c_desc, infiniopTensorDescriptor_t c_desc,
infiniopTensorDescriptor_t a_desc, infiniopTensorDescriptor_t a_desc,
infiniopTensorDescriptor_t b_desc, infiniopTensorDescriptor_t b_desc) {
const std::array<infiniDtype_t, N> &supported_dtypes, if (!c_desc || !a_desc || !b_desc) {
bool require_same_dtype, return INFINI_STATUS_BAD_PARAM;
bool require_same_shape) {
const auto dtype = c_desc->dtype();
if (!isDtypeSupported(dtype, supported_dtypes)) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
// check dtype match if required
if (require_same_dtype && (a_desc->dtype() != dtype || b_desc->dtype() != dtype)) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
// check shape compatibility if required
if (require_same_shape) {
const auto ndim = c_desc->ndim();
if (a_desc->ndim() != ndim || b_desc->ndim() != ndim) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
const auto &c_shape = c_desc->shape();
const auto &a_shape = a_desc->shape();
const auto &b_shape = b_desc->shape();
for (size_t i = 0; i < ndim; ++i) {
if (c_shape[i] != a_shape[i] || c_shape[i] != b_shape[i]) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
}
} }
try {
*instance = new BinaryInfo(c_desc, a_desc, b_desc);
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} } catch (const std::exception &e) {
return INFINI_STATUS_INTERNAL_ERROR;
// Perform binary computation
template <typename Tdata, typename BinaryOp, typename... Args>
void calculate(BinaryCpuInfo info, void *c, const void *a, const void *b, Args &&...args) {
auto a_ = reinterpret_cast<const Tdata *>(a);
auto b_ = reinterpret_cast<const Tdata *>(b);
auto c_ = reinterpret_cast<Tdata *>(c);
ptrdiff_t data_size = info.c_data_size;
#pragma omp parallel for
for (ptrdiff_t i = 0; i < data_size; ++i) {
size_t a_index = info.broadcasted ? indexToReducedOffset(i, info.ndim, info.c_strides.data(), info.a_strides.data())
: indexToOffset(i, info.ndim, info.a_shape.data(), info.a_strides.data());
size_t b_index = info.broadcasted ? indexToReducedOffset(i, info.ndim, info.c_strides.data(), info.b_strides.data())
: indexToOffset(i, info.ndim, info.b_shape.data(), info.b_strides.data());
size_t c_index = indexToOffset(i, info.ndim, info.c_shape.data(), info.c_strides.data());
if constexpr (std::is_same_v<Tdata, fp16_t>) {
float a_val = utils::cast<float>(a_[a_index]);
float b_val = utils::cast<float>(b_[b_index]);
c_[c_index] = utils::cast<fp16_t>(BinaryOp{}(a_val, b_val, std::forward<Args>(args)...));
} else {
c_[c_index] = BinaryOp{}(a_[a_index], b_[b_index], std::forward<Args>(args)...);
} }
} }
} };
} // namespace op::binary
} // namespace binary_op
} // namespace op::common_cpu
#endif // __INFINIOP_BINARY_CPU_H__ #endif // __INFINIOP_BINARY_H__
#ifndef __INFINIOP_BINARY_CPU_H__
#define __INFINIOP_BINARY_CPU_H__
#include "../../devices/cpu/common_cpu.h"
#include "../../operator.h"
#include "../../tensor.h"
#include "../binary.h"
#include <utility>
/**
 * Declares a binary-operator descriptor class inside namespace
 * op::OP::NAMESPACE. Design based on YdrMaster's DESCRIPTOR macro
 * in matmul.h.
 *
 * The generated Descriptor holds:
 *   - _opaque : backend-private state (Opaque is defined elsewhere)
 *   - _dtype  : element dtype of the output tensor
 *   - _info   : shape/stride metadata for the binary op; moved into
 *               place since BinaryInfo owns several std::vectors and
 *               the by-value parameter is a local copy anyway.
 */
#define BINARY_DESCRIPTOR(OP, NAMESPACE)                       \
                                                               \
    namespace op::OP::NAMESPACE {                              \
    class Descriptor final : public InfiniopDescriptor {       \
        struct Opaque;                                         \
        Opaque *_opaque;                                       \
        infiniDtype_t _dtype;                                  \
        op::binary::BinaryInfo _info;                          \
                                                               \
        Descriptor(                                            \
            infiniDtype_t dtype,                               \
            op::binary::BinaryInfo info,                       \
            Opaque *opaque,                                    \
            infiniDevice_t device_type,                        \
            int device_id)                                     \
            : InfiniopDescriptor{device_type, device_id},      \
              _opaque(opaque),                                 \
              _dtype(dtype),                                   \
              _info(std::move(info)) {}                        \
                                                               \
    public:                                                    \
        ~Descriptor();                                         \
                                                               \
        static infiniStatus_t create(                          \
            infiniopHandle_t handle,                           \
            Descriptor **desc_ptr,                             \
            infiniopTensorDescriptor_t c_desc,                 \
            infiniopTensorDescriptor_t a_desc,                 \
            infiniopTensorDescriptor_t b_desc);                \
                                                               \
        infiniStatus_t calculate(                              \
            void *c,                                           \
            const void *a,                                     \
            const void *b,                                     \
            void *stream) const;                               \
    };                                                         \
    }
namespace op::common_cpu {
namespace binary_op {
// Perform binary computation when inputs and the output can have different dtypes
template <typename Tc, typename Ta, typename Tb, typename BinaryOp, typename... Args>
void calculate(op::binary::BinaryInfo info, void *c, const void *a, const void *b, Args &&...args) {
auto a_ = reinterpret_cast<const Ta *>(a);
auto b_ = reinterpret_cast<const Tb *>(b);
auto c_ = reinterpret_cast<Tc *>(c);
ptrdiff_t data_size = info.c_data_size;
#pragma omp parallel for
for (ptrdiff_t i = 0; i < data_size; ++i) {
size_t a_index = info.broadcasted ? op::common_cpu::indexToReducedOffset(i, info.ndim, info.c_strides.data(), info.a_strides.data())
: op::common_cpu::indexToOffset(i, info.ndim, info.a_shape.data(), info.a_strides.data());
size_t b_index = info.broadcasted ? op::common_cpu::indexToReducedOffset(i, info.ndim, info.c_strides.data(), info.b_strides.data())
: op::common_cpu::indexToOffset(i, info.ndim, info.b_shape.data(), info.b_strides.data());
size_t c_index = op::common_cpu::indexToOffset(i, info.ndim, info.c_shape.data(), info.c_strides.data());
c_[c_index] = BinaryOp{}(a_[a_index], b_[b_index], std::forward<Args>(args)...);
}
}
// Perform binary computation when all inputs and the output share the same dtype
template <typename Tdata, typename BinaryOp, typename... Args>
void calculate(op::binary::BinaryInfo info, void *c, const void *a, const void *b, Args &&...args) {
auto a_ = reinterpret_cast<const Tdata *>(a);
auto b_ = reinterpret_cast<const Tdata *>(b);
auto c_ = reinterpret_cast<Tdata *>(c);
ptrdiff_t data_size = info.c_data_size;
#pragma omp parallel for
for (ptrdiff_t i = 0; i < data_size; ++i) {
size_t a_index = info.broadcasted ? op::common_cpu::indexToReducedOffset(i, info.ndim, info.c_strides.data(), info.a_strides.data())
: op::common_cpu::indexToOffset(i, info.ndim, info.a_shape.data(), info.a_strides.data());
size_t b_index = info.broadcasted ? op::common_cpu::indexToReducedOffset(i, info.ndim, info.c_strides.data(), info.b_strides.data())
: op::common_cpu::indexToOffset(i, info.ndim, info.b_shape.data(), info.b_strides.data());
size_t c_index = op::common_cpu::indexToOffset(i, info.ndim, info.c_shape.data(), info.c_strides.data());
if constexpr (std::is_same_v<Tdata, fp16_t>) {
float a_val = utils::cast<float>(a_[a_index]);
float b_val = utils::cast<float>(b_[b_index]);
c_[c_index] = utils::cast<fp16_t>(BinaryOp{}(a_val, b_val, std::forward<Args>(args)...));
} else {
c_[c_index] = BinaryOp{}(a_[a_index], b_[b_index], std::forward<Args>(args)...);
}
}
}
} // namespace binary_op
} // namespace op::common_cpu
#endif // __INFINIOP_BINARY_CPU_H__
#include "common_cpu.h" #include "common_cpu.h"
namespace op::common_cpu {
size_t indexToReducedOffset( size_t indexToReducedOffset(
size_t flat_index, size_t flat_index,
size_t ndim, size_t ndim,
...@@ -48,3 +50,5 @@ std::vector<size_t> getPaddedShape( ...@@ -48,3 +50,5 @@ std::vector<size_t> getPaddedShape(
} }
return padded_shape; return padded_shape;
} }
} // namespace op::common_cpu
...@@ -13,6 +13,8 @@ ...@@ -13,6 +13,8 @@
#include <omp.h> #include <omp.h>
#endif #endif
namespace op::common_cpu {
// return the memory offset of original tensor, given the flattened index of broadcasted tensor // return the memory offset of original tensor, given the flattened index of broadcasted tensor
size_t indexToReducedOffset(size_t flat_index, size_t ndim, const ptrdiff_t *broadcasted_strides, const ptrdiff_t *target_strides); size_t indexToReducedOffset(size_t flat_index, size_t ndim, const ptrdiff_t *broadcasted_strides, const ptrdiff_t *target_strides);
...@@ -28,4 +30,6 @@ size_t getPaddedSize(size_t ndim, size_t *shape, const size_t *pads); ...@@ -28,4 +30,6 @@ size_t getPaddedSize(size_t ndim, size_t *shape, const size_t *pads);
// calculate the padded shape and store the result in padded_shape // calculate the padded shape and store the result in padded_shape
std::vector<size_t> getPaddedShape(size_t ndim, const size_t *shape, const size_t *pads); std::vector<size_t> getPaddedShape(size_t ndim, const size_t *shape, const size_t *pads);
} // namespace op::common_cpu
#endif // __INFINIOP__COMMON_CPU_H__ #endif // __INFINIOP__COMMON_CPU_H__
...@@ -12,23 +12,28 @@ infiniStatus_t Descriptor::create( ...@@ -12,23 +12,28 @@ infiniStatus_t Descriptor::create(
infiniopTensorDescriptor_t gate_desc) { infiniopTensorDescriptor_t gate_desc) {
auto handle = reinterpret_cast<device::cpu::Handle *>(handle_); auto handle = reinterpret_cast<device::cpu::Handle *>(handle_);
constexpr std::array<infiniDtype_t, 3> SUPPORTED_DTYPES = { auto dtype = out_desc->dtype();
INFINI_DTYPE_F16, const auto &out_shape = out_desc->shape();
INFINI_DTYPE_F32, const auto &up_shape = up_desc->shape();
INFINI_DTYPE_F64, const auto &gate_shape = gate_desc->shape();
};
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);
if (!SAME_VEC(out_shape, up_shape, gate_shape)) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
// Perform generic binary operator check op::binary::BinaryInfo *info = nullptr;
CHECK_STATUS(op::common_cpu::binary_op::check(out_desc, up_desc, gate_desc, SUPPORTED_DTYPES, true, true)); CHECK_STATUS(op::binary::BinaryInfo::create(&info, out_desc, up_desc, gate_desc));
// Create descriptor // Create descriptor
*desc_ptr = new Descriptor( *desc_ptr = new Descriptor(
out_desc->dtype(), dtype,
{out_desc, up_desc, gate_desc}, *info,
nullptr, nullptr,
handle->device, handle->device,
handle->device_id); handle->device_id);
delete info;
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
......
#ifndef __SWIGLU_CPU_H__ #ifndef __SWIGLU_CPU_H__
#define __SWIGLU_CPU_H__ #define __SWIGLU_CPU_H__
#include "../../../binary/cpu/binary.h" #include "../../../binary/cpu/binary_cpu.h"
BINARY_DESCRIPTOR(swiglu, cpu) BINARY_DESCRIPTOR(swiglu, cpu)
......
#ifndef INFINIUTILS_CHECK_H #ifndef INFINIUTILS_CHECK_H
#define INFINIUTILS_CHECK_H #define INFINIUTILS_CHECK_H
#include <iostream> #include <iostream>
#include <tuple>
#define CHECK_API_OR(API, EXPECT, ACTION) \ #define CHECK_API_OR(API, EXPECT, ACTION) \
do { \ do { \
...@@ -30,4 +31,13 @@ ...@@ -30,4 +31,13 @@
return INFINI_STATUS_BAD_TENSOR_DTYPE); \ return INFINI_STATUS_BAD_TENSOR_DTYPE); \
} while (0) } while (0)
// SAME_VEC(v0, v1, ..., vn) -> true iff every argument compares equal to the
// first (a single argument is trivially "all equal"). Each argument is
// evaluated exactly once: the previous implementation expanded __VA_ARGS__
// twice (into forward_as_tuple and again into the inner lambda call),
// double-evaluating any argument expression with side effects. It also
// redundantly compared the first element against itself; the fold below
// compares only the remaining arguments against it.
#define SAME_VEC(...)                                 \
    [](const auto &_base, const auto &..._rest) {     \
        return ((_rest == _base) && ...);             \
    }(__VA_ARGS__)
#endif // INFINIUTILS_CHECK_H #endif // INFINIUTILS_CHECK_H
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment