# Origin:   https://github.com/predibase/lorax
# Path:     lorax/server/lorax_server/utils/segments.py
# License:  Apache License Version 2.0, January 2004

from typing import List, Tuple, Union

import torch
import numpy as np


def find_segments(
    adapter_indices: Union[torch.Tensor, List[int]]
) -> Tuple[List[int], List[int]]:
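    """Find contiguous runs of equal adapter indices.

    Returns a tuple ``(segments, segment_indices)``: ``segments`` contains the
    start offset of each run followed by a final end offset equal to
    ``len(adapter_indices)``, and ``segment_indices`` contains the adapter index
    of each run. For example, ``find_segments([0, 0, 1, 1, 1, 2])`` returns
    ``([0, 2, 5, 6], [0, 1, 2])``.
    """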
    if isinstance(adapter_indices, torch.Tensor):
        adapter_indices = adapter_indices.cpu().numpy()
    elif isinstance(adapter_indices, list):
        adapter_indices = np.array(adapter_indices)

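    # Prepending a value that differs from the first element makes position 0 a
    # "change" as well, so every run boundary (including the first) is captured.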
    change_mask = np.diff(adapter_indices, prepend=adapter_indices[0] - 1)
    change_indices = np.nonzero(change_mask)[0]

    segments = [0]
    segments.extend(change_indices[1:].tolist())
    segments.append(len(adapter_indices))

    segment_indices = adapter_indices[change_indices].tolist()

    return segments, segment_indices


class SegmentConcatBuilder:
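    """Incrementally build adapter segment metadata across concatenated batches.

    Segment boundaries from each new batch are offset by the end of the
    previous batch, and adjacent segments that share the same adapter index
    are merged into a single segment.
    """
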
    def __init__(self):
        self.adapter_segment_indices = []
        self.adapter_segment_tensors = []

    def concat(self, adapter_segments: torch.Tensor, segment_indices: List[int]):
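        """Append one batch's segment boundaries and adapter indices."""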
        # Update adapter segments
        if self.adapter_segment_tensors:
            # Because we have already processed at least one batch, remove the 0 start index
            # from this batch denoting the beginning of the segment, then offset all segment
            # positions by the value of the last segment in the previous batch to account for
            # the concatenation.
            adapter_segments = (
                adapter_segments[1:] + self.adapter_segment_tensors[-1][-1]
            )

        if (
            self.adapter_segment_indices
            and self.adapter_segment_indices[-1] == segment_indices[0]
        ):
            # If the last segment in the previous batch is the same as the first segment in this batch,
            # then we merge them together into a single segment. In effect, this means removing it from
            # the segment indices of this batch, and extending the segment span by removing the segment
            # end index from the previous batch.
            segment_indices = segment_indices[1:]
            self.adapter_segment_tensors[-1] = self.adapter_segment_tensors[-1][:-1]

        self.adapter_segment_indices.extend(segment_indices)
        self.adapter_segment_tensors.append(adapter_segments)

    def build(self) -> Tuple[torch.Tensor, List[int]]:
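        """Return the merged segment boundary tensor and its adapter indices."""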
        return torch.concat(self.adapter_segment_tensors), self.adapter_segment_indices
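

# Illustrative usage sketch (not part of the original lorax module); the adapter
# index values below are made up for the example.
if __name__ == "__main__":
    builder = SegmentConcatBuilder()

    # Batch 1: adapter 0 for two positions, then adapter 1 for one position.
    segments, indices = find_segments([0, 0, 1])  # ([0, 2, 3], [0, 1])
    builder.concat(torch.tensor(segments), indices)

    # Batch 2: starts with adapter 1 again, so its first segment merges with the
    # trailing adapter-1 segment of batch 1 during concat.
    segments, indices = find_segments([1, 1, 2])  # ([0, 2, 3], [1, 2])
    builder.concat(torch.tensor(segments), indices)

    adapter_segments, segment_indices = builder.build()
    print(adapter_segments)  # tensor([0, 2, 5, 6])
    print(segment_indices)   # [0, 1, 2]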