Commit 5e581b8e authored by Ziminli
Browse files

issue/450: change indexToReducedOffset() to indexToOffset in elementwise...

issue/450: change indexToReducedOffset() to indexToOffset() in the elementwise framework on CPU, NVIDIA, Cambricon, Metax, Moore, and Kunlun
parent 1635fd92
......@@ -106,11 +106,7 @@ struct InputIndexer {
size_t global_idx = idx + element_idx;
return input_contiguous[input_id]
? global_idx // Simple case: contiguous memory
: (input_broadcasted[input_id]
// Handle broadcasted case
? indexToReducedOffset(global_idx, ndim, output_strides, input_strides + input_id * ndim)
// General non-contiguous case
: indexToOffset(global_idx, ndim, input_shapes + input_id * ndim, input_strides + input_id * ndim));
: indexToOffset(global_idx, ndim, input_shapes + input_id * ndim, input_strides + input_id * ndim);
}
};
......
......@@ -127,9 +127,7 @@ void calculate_impl(const op::elementwise::ElementwiseInfo &info,
auto get_input_idx = [&](size_t input_id) {
return info.getInputContiguous()[input_id]
? i
: (info.getInputBroadcasted()[input_id]
? op::common_cpu::indexToReducedOffset(i, info.getNdim(), info.getOutputStrides(), info.getInputStrides(input_id))
: op::common_cpu::indexToOffset(i, info.getNdim(), info.getInputShape(input_id), info.getInputStrides(input_id)));
: op::common_cpu::indexToOffset(i, info.getNdim(), info.getInputShape(input_id), info.getInputStrides(input_id));
};
out[out_idx] = utils::cast<Tout>(
......@@ -162,7 +160,7 @@ void calculate_impl(const op::elementwise::ElementwiseInfo &info,
std::array<const Tdata *, sizeof...(Is)> ins = {reinterpret_cast<const Tdata *>(inputs[Is])...};
const ptrdiff_t output_size = info.getOutputSize();
#pragma omp parallel for
#pragma omp parallel for if (output_size > 1024)
for (ptrdiff_t i = 0; i < output_size; ++i) {
size_t out_idx = info.isOutputContiguous()
? i
......@@ -171,9 +169,7 @@ void calculate_impl(const op::elementwise::ElementwiseInfo &info,
auto get_input_idx = [&](size_t input_id) {
return info.getInputContiguous()[input_id]
? i
: (info.getInputBroadcasted()[input_id]
? op::common_cpu::indexToReducedOffset(i, info.getNdim(), info.getOutputStrides(), info.getInputStrides(input_id))
: op::common_cpu::indexToOffset(i, info.getNdim(), info.getInputShape(input_id), info.getInputStrides(input_id)));
: op::common_cpu::indexToOffset(i, info.getNdim(), info.getInputShape(input_id), info.getInputStrides(input_id));
};
if constexpr (std::is_same_v<Tdata, fp16_t> || std::is_same_v<Tdata, bf16_t>) {
......
......@@ -31,9 +31,7 @@ struct InputIndexer {
inline __device__ int operator()(int input_id) const {
return input_contiguous[input_id]
? idx
: (input_broadcasted[input_id]
? indexToReducedOffset(idx, ndim, output_strides, input_strides + input_id * ndim)
: indexToOffset(idx, ndim, input_shapes + input_id * ndim, input_strides + input_id * ndim));
: indexToOffset(idx, ndim, input_shapes + input_id * ndim, input_strides + input_id * ndim);
}
};
......
......@@ -29,9 +29,7 @@ struct InputIndexer {
__device__ __forceinline__ size_t operator()(size_t input_id) const {
return input_contiguous[input_id]
? idx
: (input_broadcasted[input_id]
? device::metax::indexToReducedOffset(idx, ndim, output_strides, input_strides + input_id * ndim)
: device::metax::indexToOffset(idx, ndim, input_shapes + input_id * ndim, input_strides + input_id * ndim));
: device::metax::indexToOffset(idx, ndim, input_shapes + input_id * ndim, input_strides + input_id * ndim);
}
};
......
......@@ -29,9 +29,7 @@ struct InputIndexer {
__device__ __forceinline__ size_t operator()(size_t input_id) const {
return input_contiguous[input_id]
? idx
: (input_broadcasted[input_id]
? device::moore::indexToReducedOffset(idx, ndim, output_strides, input_strides + input_id * ndim)
: device::moore::indexToOffset(idx, ndim, input_shapes + input_id * ndim, input_strides + input_id * ndim));
: device::moore::indexToOffset(idx, ndim, input_shapes + input_id * ndim, input_strides + input_id * ndim);
}
};
......
......@@ -60,9 +60,7 @@ struct InputIndexer {
__device__ __forceinline__ size_t operator()(size_t input_id) const {
return input_contiguous[input_id]
? idx
: (input_broadcasted[input_id]
? device::nvidia::indexToReducedOffset(idx, ndim, output_strides, input_strides + input_id * ndim)
: device::nvidia::indexToOffset(idx, ndim, input_shapes + input_id * ndim, input_strides + input_id * ndim));
: device::nvidia::indexToOffset(idx, ndim, input_shapes + input_id * ndim, input_strides + input_id * ndim);
}
};
......
......@@ -33,6 +33,8 @@ _TEST_CASES_ = [
((13, 4, 4), (4, 0, 1), (0, 4, 1), None),
((16, 5632), None, None, None),
((16, 5632), (13312, 1), (13312, 1), (13312, 1)),
((13, 16, 2), (128, 4, 1), (0, 2, 1), (64, 4, 1)),
((13, 16, 2), (128, 4, 1), (2, 0, 1), (64, 4, 1)),
((4, 4, 5632), None, None, None),
((4, 4, 5632), (45056, 5632, 1), (45056, 5632, 1), (45056, 5632, 1)),
]
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment