Unverified Commit 15ac0191 authored by PanZezhong1725, committed by GitHub

Issue/450 Fix Elementwise Striding Broadcasting Issue (#452)

* issue/450: change indexToReducedOffset() to indexToOffset() in the elementwise framework on CPU, NVIDIA, Cambricon, Metax, Moore, and Kunlun

* issue/450: remove indexToReducedOffset() from all platforms

* issue/450: add test cases to infiniop-test that pinpoint the issue

Parents: 1635fd92 9db54b8f
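Why the removed helper was wrong: indexToReducedOffset() recovers multi-dimensional coordinates by dividing the flat output index by the output strides, which is only valid when those strides are the dense row-major strides of the output shape. Once the output itself is non-contiguous, as in the new test cases below, the decomposition breaks, and every broadcast input read through it lands on the wrong offset. indexToOffset() decomposes by the shape instead, and since a broadcast input carries stride 0 on its broadcast dimensions, it needs no special case. Below is a standalone sketch of the divergence: the indexToReducedOffset() body is copied verbatim from the removed code, the indexToOffset() body is an assumption consistent with its doc comment (its definition is not shown in this diff), and the stride tuples are taken from the new test cases.

```cpp
#include <cstddef>
#include <cstdio>

// Removed helper (verbatim from the diff): decomposes the flat index by the
// OUTPUT strides, which is only correct for a dense, row-major output.
size_t indexToReducedOffset(size_t flat_index, size_t ndim,
                            const ptrdiff_t *broadcasted_strides,
                            const ptrdiff_t *target_strides) {
    size_t res = 0;
    for (size_t i = 0; i < ndim; ++i) {
        res += flat_index / broadcasted_strides[i] * target_strides[i];
        flat_index %= broadcasted_strides[i];
    }
    return res;
}

// Retained helper (body assumed, matching its declared interface): decomposes
// the flat index by the SHAPE, so it is valid for any strides, including
// stride-0 broadcast dimensions.
size_t indexToOffset(size_t flat_index, size_t ndim,
                     const size_t *shape, const ptrdiff_t *strides) {
    size_t res = 0;
    for (size_t i = ndim; i-- > 0;) {
        res += (flat_index % shape[i]) * strides[i];
        flat_index /= shape[i];
    }
    return res;
}

int main() {
    // From the new test case ((13, 16, 2), (128, 4, 1), (0, 2, 1), (64, 4, 1)):
    // (128, 4, 1) is a padded, NON-contiguous layout for shape (13, 16, 2)
    // (contiguous strides would be (32, 2, 1)), and (0, 2, 1) is a broadcast
    // operand. Which tuple plays output vs. input follows the test harness's
    // convention; any such pairing exposes the same divergence.
    size_t shape[3] = {13, 16, 2};
    ptrdiff_t out_strides[3] = {128, 4, 1};
    ptrdiff_t in_strides[3] = {0, 2, 1};
    for (size_t i = 0; i < 13 * 16 * 2; ++i) {
        size_t reduced = indexToReducedOffset(i, 3, out_strides, in_strides);
        size_t correct = indexToOffset(i, 3, shape, in_strides);
        if (reduced != correct) {
            // First divergence: flat index 4 -> reduced = 2, correct = 4.
            printf("flat index %zu: reduced = %zu, correct = %zu\n", i, reduced, correct);
            break;
        }
    }
    return 0;
}
```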
@@ -19,8 +19,8 @@ void calculate(op::binary::BinaryInfo info, void *c, const void *a, const void *
 #pragma omp parallel for
     for (ptrdiff_t i = 0; i < data_size; ++i) {
-        size_t a_index = info.contiguous ? i : (info.broadcasted ? op::common_cpu::indexToReducedOffset(i, info.ndim, info.c_strides.data(), info.a_strides.data()) : op::common_cpu::indexToOffset(i, info.ndim, info.a_shape.data(), info.a_strides.data()));
-        size_t b_index = info.contiguous ? i : (info.broadcasted ? op::common_cpu::indexToReducedOffset(i, info.ndim, info.c_strides.data(), info.b_strides.data()) : op::common_cpu::indexToOffset(i, info.ndim, info.b_shape.data(), info.b_strides.data()));
+        size_t a_index = info.contiguous ? i : op::common_cpu::indexToOffset(i, info.ndim, info.a_shape.data(), info.a_strides.data());
+        size_t b_index = info.contiguous ? i : op::common_cpu::indexToOffset(i, info.ndim, info.b_shape.data(), info.b_strides.data());
         size_t c_index = info.contiguous ? i : (op::common_cpu::indexToOffset(i, info.ndim, info.c_shape.data(), info.c_strides.data()));
         c_[c_index] = BinaryOp{}(a_[a_index], b_[b_index], std::forward<Args>(args)...);
@@ -37,8 +37,8 @@ void calculate(op::binary::BinaryInfo info, void *c, const void *a, const void *
 #pragma omp parallel for
     for (ptrdiff_t i = 0; i < data_size; ++i) {
-        size_t a_index = info.contiguous ? i : (info.broadcasted ? op::common_cpu::indexToReducedOffset(i, info.ndim, info.c_strides.data(), info.a_strides.data()) : op::common_cpu::indexToOffset(i, info.ndim, info.a_shape.data(), info.a_strides.data()));
-        size_t b_index = info.contiguous ? i : (info.broadcasted ? op::common_cpu::indexToReducedOffset(i, info.ndim, info.c_strides.data(), info.b_strides.data()) : op::common_cpu::indexToOffset(i, info.ndim, info.b_shape.data(), info.b_strides.data()));
+        size_t a_index = info.contiguous ? i : op::common_cpu::indexToOffset(i, info.ndim, info.a_shape.data(), info.a_strides.data());
+        size_t b_index = info.contiguous ? i : op::common_cpu::indexToOffset(i, info.ndim, info.b_shape.data(), info.b_strides.data());
         size_t c_index = info.contiguous ? i : (op::common_cpu::indexToOffset(i, info.ndim, info.c_shape.data(), info.c_strides.data()));
         if constexpr (std::is_same_v<Tdata, fp16_t>) {
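With the broadcast branch gone, both variants index every non-contiguous operand through indexToOffset(). Judging from the test cases in this commit, a broadcast input is described with its shape expanded to the output shape and stride 0 on the broadcast dimensions, so the shape-based decomposition already yields the correct element, and, unlike the removed helper, it remains correct when the output strides are non-contiguous.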
@@ -22,35 +22,6 @@ __mlu_device__ half to_half(const T &v) {
     return static_cast<half>(v);
 }
-/**
- * @brief Converts a flattened index to a reduced offset considering broadcasting.
- *
- * This function is used when dealing with broadcasted tensors where the input
- * has been broadcast to match the output shape. It calculates the offset in
- * the original (non-broadcasted) tensor.
- *
- * @param flat_index The flattened index in the output tensor
- * @param ndim Number of dimensions
- * @param broadcasted_strides Strides of the broadcasted tensor
- * @param target_strides Strides of the original (non-broadcasted) tensor
- * @return size_t Offset in the original tensor's memory
- */
-inline __mlu_device__ size_t indexToReducedOffset(
-    size_t flat_index,
-    size_t ndim,
-    const ptrdiff_t *broadcasted_strides,
-    const ptrdiff_t *target_strides) {
-    size_t res = 0;
-    for (size_t i = 0; i < ndim; ++i) {
-        // Calculate contribution from each dimension
-        res += flat_index / broadcasted_strides[i] * target_strides[i];
-        // Remove the contribution from this dimension
-        flat_index %= broadcasted_strides[i];
-    }
-    return res;
-}
 /**
  * @brief Converts a flattened index to a memory offset considering tensor striding.
  *
@@ -106,11 +77,7 @@ struct InputIndexer {
         size_t global_idx = idx + element_idx;
         return input_contiguous[input_id]
                    ? global_idx // Simple case: contiguous memory
-                   : (input_broadcasted[input_id]
-                          // Handle broadcasted case
-                          ? indexToReducedOffset(global_idx, ndim, output_strides, input_strides + input_id * ndim)
-                          // General non-contiguous case
-                          : indexToOffset(global_idx, ndim, input_shapes + input_id * ndim, input_strides + input_id * ndim));
+                   : indexToOffset(global_idx, ndim, input_shapes + input_id * ndim, input_strides + input_id * ndim);
     }
 };
@@ -2,19 +2,6 @@
 namespace op::common_cpu {
-size_t indexToReducedOffset(
-    size_t flat_index,
-    size_t ndim,
-    const ptrdiff_t *broadcasted_strides,
-    const ptrdiff_t *target_strides) {
-    size_t res = 0;
-    for (size_t i = 0; i < ndim; ++i) {
-        res += flat_index / broadcasted_strides[i] * target_strides[i];
-        flat_index %= broadcasted_strides[i];
-    }
-    return res;
-}
 size_t indexToOffset(
     size_t flat_index,
     size_t ndim,
@@ -15,9 +15,6 @@
 namespace op::common_cpu {
-// return the memory offset of original tensor, given the flattened index of broadcasted tensor
-size_t indexToReducedOffset(size_t flat_index, size_t ndim, const ptrdiff_t *broadcasted_strides, const ptrdiff_t *target_strides);
 // return the memory offset of a tensor given its flattened index
 size_t indexToOffset(size_t flat_index, size_t ndim, const size_t *shape, const ptrdiff_t *strides);
@@ -105,27 +105,6 @@ inline __device__ T atomicMax(__shared_ptr__ T *ptr, T value) {
     return old;
 }
-/**
- * @brief Get index of broadcasted input
- * flat_index: flatten index of output tensor
- * ndim: dim of output tensor
- * broadcasted_strides: strides of output tensor
- * target_strides: strides of input tensor
- */
-inline __device__ int indexToReducedOffset(
-    int flat_index,                        // output flatten index
-    int ndim,                              // output dims
-    const _ptrdiff_t *broadcasted_strides, // output strides
-    const _ptrdiff_t *target_strides) {    // strides of inputs
-    int res = 0;
-    for (int i = 0; i < ndim; ++i) {
-        res += flat_index / broadcasted_strides[i].value * target_strides[i].value;
-        flat_index %= broadcasted_strides[i].value;
-    }
-    return res;
-}
 /**
  * @brief Get real offset of input index
  * flat_index: flatten index input
@@ -12,21 +12,6 @@ using cuda_bfloat162 = hpcc_bfloat162;
 namespace device::metax {
-// return the memory offset of original tensor, given the flattened index of broadcasted tensor
-__forceinline__ __device__ __host__ size_t
-indexToReducedOffset(
-    size_t flat_index,
-    size_t ndim,
-    const ptrdiff_t *broadcasted_strides,
-    const ptrdiff_t *target_strides) {
-    size_t res = 0;
-    for (size_t i = 0; i < ndim; ++i) {
-        res += flat_index / broadcasted_strides[i] * target_strides[i];
-        flat_index %= broadcasted_strides[i];
-    }
-    return res;
-}
 // get the memory offset of the given element in a tensor given its flat index
 __forceinline__ __device__ __host__ size_t
 indexToOffset(
@@ -16,21 +16,6 @@ using cuda_bfloat162 = mt_bfloat162;
 namespace device::moore {
-// return the memory offset of original tensor, given the flattened index of broadcasted tensor
-__forceinline__ __device__ __host__ size_t
-indexToReducedOffset(
-    size_t flat_index,
-    size_t ndim,
-    const ptrdiff_t *broadcasted_strides,
-    const ptrdiff_t *target_strides) {
-    size_t res = 0;
-    for (size_t i = 0; i < ndim; ++i) {
-        res += flat_index / broadcasted_strides[i] * target_strides[i];
-        flat_index %= broadcasted_strides[i];
-    }
-    return res;
-}
 // get the memory offset of the given element in a tensor given its flat index
 __forceinline__ __device__ __host__ size_t
 indexToOffset(
@@ -19,20 +19,6 @@ using cuda_bfloat16 = nv_bfloat16;
 using cuda_bfloat162 = nv_bfloat162;
 namespace device::nvidia {
-// return the memory offset of original tensor, given the flattened index of broadcasted tensor
-__forceinline__ __device__ __host__ size_t
-indexToReducedOffset(
-    size_t flat_index,
-    size_t ndim,
-    const ptrdiff_t *broadcasted_strides,
-    const ptrdiff_t *target_strides) {
-    size_t res = 0;
-    for (size_t i = 0; i < ndim; ++i) {
-        res += flat_index / broadcasted_strides[i] * target_strides[i];
-        flat_index %= broadcasted_strides[i];
-    }
-    return res;
-}
 // get the memory offset of the given element in a tensor given its flat index
 __forceinline__ __device__ __host__ size_t
@@ -127,9 +127,7 @@ void calculate_impl(const op::elementwise::ElementwiseInfo &info,
         auto get_input_idx = [&](size_t input_id) {
             return info.getInputContiguous()[input_id]
                        ? i
-                       : (info.getInputBroadcasted()[input_id]
-                              ? op::common_cpu::indexToReducedOffset(i, info.getNdim(), info.getOutputStrides(), info.getInputStrides(input_id))
-                              : op::common_cpu::indexToOffset(i, info.getNdim(), info.getInputShape(input_id), info.getInputStrides(input_id)));
+                       : op::common_cpu::indexToOffset(i, info.getNdim(), info.getInputShape(input_id), info.getInputStrides(input_id));
         };
         out[out_idx] = utils::cast<Tout>(
@@ -162,7 +160,7 @@ void calculate_impl(const op::elementwise::ElementwiseInfo &info,
     std::array<const Tdata *, sizeof...(Is)> ins = {reinterpret_cast<const Tdata *>(inputs[Is])...};
     const ptrdiff_t output_size = info.getOutputSize();
-#pragma omp parallel for
+#pragma omp parallel for if (output_size > 1024)
     for (ptrdiff_t i = 0; i < output_size; ++i) {
         size_t out_idx = info.isOutputContiguous()
                              ? i
@@ -171,9 +169,7 @@ void calculate_impl(const op::elementwise::ElementwiseInfo &info,
         auto get_input_idx = [&](size_t input_id) {
             return info.getInputContiguous()[input_id]
                        ? i
-                       : (info.getInputBroadcasted()[input_id]
-                              ? op::common_cpu::indexToReducedOffset(i, info.getNdim(), info.getOutputStrides(), info.getInputStrides(input_id))
-                              : op::common_cpu::indexToOffset(i, info.getNdim(), info.getInputShape(input_id), info.getInputStrides(input_id)));
+                       : op::common_cpu::indexToOffset(i, info.getNdim(), info.getInputShape(input_id), info.getInputStrides(input_id));
         };
         if constexpr (std::is_same_v<Tdata, fp16_t> || std::is_same_v<Tdata, bf16_t>) {
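The `@@ -162,7 +160,7 @@` hunk above also gates the CPU elementwise loop with an OpenMP `if` clause, so outputs of 1024 elements or fewer are processed serially instead of paying thread-team startup costs. A minimal standalone sketch of the clause's behavior (the 1024 threshold comes from the diff; the rest is illustrative):

```cpp
#include <cstdio>
#include <omp.h>

int main() {
    const long output_size = 512; // below the threshold, so the loop runs serially

// When the if() condition is false, the parallel region still executes,
// but with a team of exactly one thread -- no spawning overhead.
#pragma omp parallel for if (output_size > 1024)
    for (long i = 0; i < output_size; ++i) {
        if (i == 0) {
            printf("threads in region: %d\n", omp_get_num_threads()); // prints 1 here
        }
    }
    return 0;
}
```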
@@ -31,9 +31,7 @@ struct InputIndexer {
     inline __device__ int operator()(int input_id) const {
         return input_contiguous[input_id]
                    ? idx
-                   : (input_broadcasted[input_id]
-                          ? indexToReducedOffset(idx, ndim, output_strides, input_strides + input_id * ndim)
-                          : indexToOffset(idx, ndim, input_shapes + input_id * ndim, input_strides + input_id * ndim));
+                   : indexToOffset(idx, ndim, input_shapes + input_id * ndim, input_strides + input_id * ndim);
     }
 };
@@ -29,9 +29,7 @@ struct InputIndexer {
     __device__ __forceinline__ size_t operator()(size_t input_id) const {
         return input_contiguous[input_id]
                    ? idx
-                   : (input_broadcasted[input_id]
-                          ? device::metax::indexToReducedOffset(idx, ndim, output_strides, input_strides + input_id * ndim)
-                          : device::metax::indexToOffset(idx, ndim, input_shapes + input_id * ndim, input_strides + input_id * ndim));
+                   : device::metax::indexToOffset(idx, ndim, input_shapes + input_id * ndim, input_strides + input_id * ndim);
     }
 };
@@ -29,9 +29,7 @@ struct InputIndexer {
     __device__ __forceinline__ size_t operator()(size_t input_id) const {
         return input_contiguous[input_id]
                    ? idx
-                   : (input_broadcasted[input_id]
-                          ? device::moore::indexToReducedOffset(idx, ndim, output_strides, input_strides + input_id * ndim)
-                          : device::moore::indexToOffset(idx, ndim, input_shapes + input_id * ndim, input_strides + input_id * ndim));
+                   : device::moore::indexToOffset(idx, ndim, input_shapes + input_id * ndim, input_strides + input_id * ndim);
     }
 };
@@ -60,9 +60,7 @@ struct InputIndexer {
     __device__ __forceinline__ size_t operator()(size_t input_id) const {
         return input_contiguous[input_id]
                    ? idx
-                   : (input_broadcasted[input_id]
-                          ? device::nvidia::indexToReducedOffset(idx, ndim, output_strides, input_strides + input_id * ndim)
-                          : device::nvidia::indexToOffset(idx, ndim, input_shapes + input_id * ndim, input_strides + input_id * ndim));
+                   : device::nvidia::indexToOffset(idx, ndim, input_shapes + input_id * ndim, input_strides + input_id * ndim);
     }
 };
@@ -91,6 +91,8 @@ if __name__ == "__main__":
         ((13, 4, 4), (4, 0, 1), (0, 4, 1), None),
         ((16, 5632), None, None, None),
         ((16, 5632), (13312, 1), (13312, 1), (13312, 1)),
+        ((13, 16, 2), (128, 4, 1), (0, 2, 1), (64, 4, 1)),
+        ((13, 16, 2), (128, 4, 1), (2, 0, 1), (64, 4, 1)),
         ((4, 4, 5632), None, None, None),
         ((4, 4, 5632), (45056, 5632, 1), (45056, 5632, 1), (45056, 5632, 1)),
     ]
@@ -33,6 +33,8 @@ _TEST_CASES_ = [
     ((13, 4, 4), (4, 0, 1), (0, 4, 1), None),
     ((16, 5632), None, None, None),
     ((16, 5632), (13312, 1), (13312, 1), (13312, 1)),
+    ((13, 16, 2), (128, 4, 1), (0, 2, 1), (64, 4, 1)),
+    ((13, 16, 2), (128, 4, 1), (2, 0, 1), (64, 4, 1)),
     ((4, 4, 5632), None, None, None),
     ((4, 4, 5632), (45056, 5632, 1), (45056, 5632, 1), (45056, 5632, 1)),
 ]
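Both test files receive the same two cases. Contiguous strides for shape (13, 16, 2) would be (32, 2, 1), so (128, 4, 1) and (64, 4, 1) are padded, non-contiguous layouts, while (0, 2, 1) and (2, 0, 1) carry stride-0 broadcast dimensions. A broadcast operand combined with a non-contiguously strided output is precisely the combination the removed indexToReducedOffset() mishandled, which is what lets these cases pinpoint the issue.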