Unverified Commit fea62e92 authored by drbh's avatar drbh Committed by GitHub
Browse files

fix: improve find_segments via numpy diff (#2686)

parent 52e48739
...@@ -5,30 +5,25 @@ ...@@ -5,30 +5,25 @@
from typing import List, Tuple, Union from typing import List, Tuple, Union
import torch import torch
import numpy as np
def find_segments(
    adapter_indices: Union[torch.Tensor, List[int]]
) -> Tuple[List[int], List[int]]:
    """Split a sequence into runs of equal consecutive values.

    Args:
        adapter_indices: A 1-D ``torch.Tensor`` or a sequence of integers.
            Tensors are copied to the CPU before processing.

    Returns:
        A ``(segments, segment_indices)`` tuple where ``segments`` holds the
        run boundaries (always starting with ``0`` and, for non-empty input,
        ending with ``len(adapter_indices)``) and ``segment_indices`` holds
        the value of each run, so run ``k`` spans
        ``segments[k]:segments[k + 1]``.  Empty input yields ``([0], [])``.
    """
    if isinstance(adapter_indices, torch.Tensor):
        # Move to host memory once instead of reading elements one by one.
        adapter_indices = adapter_indices.cpu().numpy()
    else:
        adapter_indices = np.asarray(adapter_indices)

    total = len(adapter_indices)
    if total == 0:
        # There is no first element to compare against; report no runs.
        return [0], []

    # Position i (i >= 1) starts a new run when it differs from position
    # i - 1.  Comparing neighbours directly avoids the sentinel arithmetic
    # (`first - 1`) that can wrap around for unsigned or minimum-value dtypes.
    boundaries = np.flatnonzero(adapter_indices[1:] != adapter_indices[:-1]) + 1

    run_starts = [0, *boundaries.tolist()]
    segments = [*run_starts, total]
    segment_indices = adapter_indices[run_starts].tolist()
    return segments, segment_indices
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment