# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass


@dataclass
class AdapterMapping:
    """Pair of integer mappings that associate tokens with adapters.

    ``index_mapping`` has one entry per token in ``input_ids``.
    ``prompt_mapping`` has one entry per sampled token. What each integer
    denotes (e.g. an adapter slot or ID) is defined by the callers that build
    this object; it is not visible here.

    Both fields may be passed as any iterable of ints (list, range,
    generator, ...). They are always stored as tuples.
    """

    # Per every token in input_ids:
    index_mapping: tuple[int, ...]
    # Per sampled token:
    prompt_mapping: tuple[int, ...]

    def __post_init__(self):
        # Copy into tuples so the stored mappings are immutable. Later
        # mutation of a list the caller passed in therefore cannot change
        # this object. tuple(x) on an existing tuple returns it unchanged.
        self.index_mapping = tuple(self.index_mapping)
        self.prompt_mapping = tuple(self.prompt_mapping)