Commit 631cfd63 authored by PanZezhong's avatar PanZezhong
Browse files

issue/161/fix 修改检查shape strides相等的宏

parent 0450fb1e
......@@ -31,32 +31,26 @@ public:
}
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32);
auto ndim = y_desc->ndim();
if (ndim != x_desc->ndim()) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
auto shape = y_desc->shape();
CHECK_SAME_SHAPE(shape, x_desc->shape());
auto ndim = y_desc->ndim();
if (ndim != 2 && ndim != 3) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
auto shape = y_desc->shape();
if (!SAME_VEC(y_desc->shape(), x_desc->shape())) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
CHECK_STATUS(INFINI_STATUS_BAD_TENSOR_SHAPE);
}
if (shape[ndim - 1] < shape[ndim - 2]) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
CHECK_STATUS(INFINI_STATUS_BAD_TENSOR_SHAPE);
}
size_t batch_size = 1;
size_t seq_len = shape[ndim - 2];
size_t total_seq_len = shape[ndim - 1];
ptrdiff_t y_stride_b = 0,
x_stride_b = 0;
ptrdiff_t y_stride_i = y_desc->stride(ndim - 2),
y_stride_i = y_desc->stride(ndim - 2),
y_stride_j = y_desc->stride(ndim - 1);
ptrdiff_t x_stride_i = x_desc->stride(ndim - 2),
ptrdiff_t x_stride_b = 0,
x_stride_i = x_desc->stride(ndim - 2),
x_stride_j = x_desc->stride(ndim - 1);
if (ndim == 3) {
......
......@@ -18,9 +18,8 @@ infiniStatus_t Descriptor::create(
const auto &gate_shape = gate_desc->shape();
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);
if (!SAME_VEC(out_shape, up_shape, gate_shape)) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
CHECK_SAME_SHAPE(out_shape, up_shape, gate_shape);
op::binary::BinaryInfo info;
CHECK_STATUS(op::binary::createBinaryInfo(info, out_desc, up_desc, gate_desc));
......
......@@ -31,13 +31,17 @@
return INFINI_STATUS_BAD_TENSOR_DTYPE); \
} while (0)
#define SAME_VEC(...) \
[&] { \
auto &&_vec = std::forward_as_tuple(__VA_ARGS__); \
const auto &_base = std::get<0>(_vec); \
return [&_base](auto &&...args) { \
return ((args == _base) && ...); \
}(__VA_ARGS__); \
}()
#define CHECK_SAME_VEC(ERR, FIRST, ...) \
do { \
for (const auto &shape___ : {__VA_ARGS__}) { \
if (FIRST != shape___) { \
return ERR; \
} \
} \
} while (0)
#define CHECK_SAME_SHAPE(FIRST, ...) CHECK_SAME_VEC(INFINI_STATUS_BAD_TENSOR_SHAPE, FIRST, __VA_ARGS__)
#define CHECK_SAME_STRIDES(FIRST, ...) CHECK_SAME_VEC(INFINI_STATUS_BAD_TENSOR_STRIDES, FIRST, __VA_ARGS__)
#endif // INFINIUTILS_CHECK_H
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment