Unverified Commit 67b7d5b1 authored by Yuan Luo's avatar Yuan Luo Committed by GitHub
Browse files

[PD] Vectorise group_concurrent_contiguous in NumPy (#5834)


Co-authored-by: default avatarluoyuan.luo <luoyuan.luo@antgroup.com>
parent 4322c31e
......@@ -37,25 +37,16 @@ logger = logging.getLogger(__name__)
def group_concurrent_contiguous(
src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
src_groups = []
dst_groups = []
current_src = [src_indices[0]]
current_dst = [dst_indices[0]]
for i in range(1, len(src_indices)):
src_contiguous = src_indices[i] == src_indices[i - 1] + 1
dst_contiguous = dst_indices[i] == dst_indices[i - 1] + 1
if src_contiguous and dst_contiguous:
current_src.append(src_indices[i])
current_dst.append(dst_indices[i])
else:
src_groups.append(current_src)
dst_groups.append(current_dst)
current_src = [src_indices[i]]
current_dst = [dst_indices[i]]
"""Vectorised NumPy implementation."""
if src_indices.size == 0:
return [], []
brk = np.where((np.diff(src_indices) != 1) | (np.diff(dst_indices) != 1))[0] + 1
src_groups = np.split(src_indices, brk)
dst_groups = np.split(dst_indices, brk)
src_groups.append(current_src)
dst_groups.append(current_dst)
src_groups = [g.tolist() for g in src_groups]
dst_groups = [g.tolist() for g in dst_groups]
return src_groups, dst_groups
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment