"graphbolt/vscode:/vscode.git/clone" did not exist on "2c03fe9952cfd3419fb4325a830a6958ed455c3d"
Unverified Commit 67b7d5b1 authored by Yuan Luo's avatar Yuan Luo Committed by GitHub
Browse files

[PD] Vectorise group_concurrent_contiguous in NumPy (#5834)


Co-authored-by: default avatarluoyuan.luo <luoyuan.luo@antgroup.com>
parent 4322c31e
...@@ -37,25 +37,16 @@ logger = logging.getLogger(__name__) ...@@ -37,25 +37,16 @@ logger = logging.getLogger(__name__)
def group_concurrent_contiguous( def group_concurrent_contiguous(
src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64] src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]: ) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
src_groups = [] """Vectorised NumPy implementation."""
dst_groups = [] if src_indices.size == 0:
current_src = [src_indices[0]] return [], []
current_dst = [dst_indices[0]]
brk = np.where((np.diff(src_indices) != 1) | (np.diff(dst_indices) != 1))[0] + 1
for i in range(1, len(src_indices)): src_groups = np.split(src_indices, brk)
src_contiguous = src_indices[i] == src_indices[i - 1] + 1 dst_groups = np.split(dst_indices, brk)
dst_contiguous = dst_indices[i] == dst_indices[i - 1] + 1
if src_contiguous and dst_contiguous:
current_src.append(src_indices[i])
current_dst.append(dst_indices[i])
else:
src_groups.append(current_src)
dst_groups.append(current_dst)
current_src = [src_indices[i]]
current_dst = [dst_indices[i]]
src_groups.append(current_src) src_groups = [g.tolist() for g in src_groups]
dst_groups.append(current_dst) dst_groups = [g.tolist() for g in dst_groups]
return src_groups, dst_groups return src_groups, dst_groups
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment