segments.py 2.62 KB
Newer Older
drbh's avatar
drbh committed
1
2
3
4
5
6
7
8
9
# Origin:   https://github.com/predibase/lorax
# Path:     lorax/server/lorax_server/utils/segments.py
# License:  Apache License Version 2.0, January 2004

from typing import List, Tuple, Union

import torch


10
# FIXME: this should be optimized
drbh's avatar
drbh committed
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
def find_segments(
    adapter_indices: Union[torch.Tensor, List[int]]
) -> Tuple[List[int], List[int]]:
    segments = [0]
    segment_indices = []

    if isinstance(adapter_indices, torch.Tensor):
        # Calling .item() repeatedly on CUDA tensor is very slow, so we move it to CPU first
        adapter_indices = adapter_indices.cpu().tolist()

    start_index = 0
    for i in range(1, len(adapter_indices)):
        if adapter_indices[i] != adapter_indices[i - 1]:
            segments.append(i)
            segment_indices.append(adapter_indices[i - 1])
            start_index = i

    # Handle the last segment
    if start_index < len(adapter_indices):
        segments.append(len(adapter_indices))
        segment_indices.append(adapter_indices[-1])

    return segments, segment_indices


class SegmentConcatBuilder:
    def __init__(self):
        self.adapter_segment_indices = []
        self.adapter_segment_tensors = []

    def concat(self, adapter_segments: torch.Tensor, segment_indices: List[int]):
        # Update adapter segments
        if self.adapter_segment_tensors:
            # Because we have already processed at least one batch, remove the 0 start index
            # from this batch denoting the beginning of the segment, then offset all segment
            # positions by the value of the last segment in the previous batch to account for
            # the concatenation.
            adapter_segments = (
                adapter_segments[1:] + self.adapter_segment_tensors[-1][-1]
            )

        if (
            self.adapter_segment_indices
            and self.adapter_segment_indices[-1] == segment_indices[0]
        ):
            # If the last segment in the previous batch is the same as the first segment in this batch,
            # then we merge them together into a single segment. In effect, this means removing it from
            # the segment indices of this batch, and extending the segment span by removing the segment
            # end index from the previous batch.
            segment_indices = segment_indices[1:]
            self.adapter_segment_tensors[-1] = self.adapter_segment_tensors[-1][:-1]

        self.adapter_segment_indices.extend(segment_indices)
        self.adapter_segment_tensors.append(adapter_segments)

    def build(self) -> Tuple[torch.Tensor, List[int]]:
        return torch.concat(self.adapter_segment_tensors), self.adapter_segment_indices