session.py 2.61 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
# Copyright (c) OpenMMLab. All rights reserved.
import logging

import torch
from transformers.generation.utils import ModelOutput

logger = logging.getLogger(__name__)


class BasicSessionManager:
    """No-op session manager that keeps no chat history."""

    def prepend_history(self, input_ids):
        """Return ``input_ids`` unchanged; there is no history to prepend."""
        return input_ids

    def add_to_history(self, output):
        """Discard ``output``; this manager stores nothing."""
        pass


class BasicSessionManagerWithHistory:
    """Session manager that keeps chat history as token ids.

    History ids (initialized to ``start_ids``) are prepended to every new
    prompt, separated by ``sep_ids``.  When the combined sequence exceeds
    ``max_session_len``, the oldest tokens are trimmed and the leading
    ``start_ids`` are restored at the front.

    Args:
        max_session_len (int): Maximum number of tokens allowed for all chat
            sessions. Default: 2048.
        reduce_size (int): Number of tokens to be trimmed when reaching maximum
            session length. Default: 256.
        start_ids (Sequence[int]): Sequences of ids at the start of the chat
            session. Default: ``(1,)``.
        sep_ids (Sequence[int]): Sequences of ids separating chat sessions.
            Default: ``(13,)``.
    """   # noqa: E501
    bs = 1  # batch size; history tensors keep this leading dimension

    def __init__(self,
                 max_session_len=2048,
                 reduce_size=256,
                 start_ids=(1,),
                 sep_ids=(13,)) -> None:
        # Tuple defaults replace the original mutable list defaults,
        # avoiding the shared-mutable-default-argument pitfall.
        self.start_ids = torch.tensor(start_ids, dtype=torch.long)
        self.sep_ids = torch.tensor(sep_ids, dtype=torch.long)

        assert self.start_ids.ndim == 1
        assert self.sep_ids.ndim == 1

        # The session must be able to hold at least the start ids.
        self.max_session_len = max(len(start_ids), max_session_len)
        # Never trim so much that the start ids could not fit back in.
        self.reduce_size = min(reduce_size, max_session_len - len(start_ids))

        assert self.max_session_len > self.reduce_size

        self.new_session()

    def new_session(self):
        """Reset the history to the start-of-session ids."""
        self.history_ids = self.start_ids.repeat(self.bs, 1)

    def prepend_history(self, input_ids: torch.Tensor):
        """Prepend history ids to input ids and trim if over-length.

        Args:
            input_ids (torch.Tensor): New prompt ids, shape ``(bs, seq_len)``.

        Returns:
            torch.Tensor: ``[history, sep, input]`` concatenated along dim 1.
            When the result exceeds ``max_session_len``, only the newest
            ``max_session_len - reduce_size`` tokens are kept, with
            ``start_ids`` written back over the leading positions.
        """

        input_ids = input_ids.to(self.history_ids.device).long()
        sep_ids = self.sep_ids.to(self.history_ids.device).long().repeat(1, 1)
        input_ids = torch.cat([self.history_ids, sep_ids, input_ids], dim=1)

        if input_ids.shape[1] > self.max_session_len:
            # Keep the newest (max_session_len - reduce_size) tokens ...
            input_ids = input_ids[:,
                                  (self.reduce_size - self.max_session_len):]
            # ... and restore the start-of-session ids at the front.
            input_ids[:, :len(self.start_ids)] = self.start_ids.repeat(
                self.bs, 1)
        return input_ids

    def add_to_history(self, output):
        """Save history output ids.

        Note:
            Output returned by HuggingFace generator contains both input
            and output ids.

        Args:
            output (torch.Tensor | ModelOutput): Generation result; either a
                plain id tensor or a HuggingFace ``ModelOutput`` whose
                ``sequences`` attribute holds the ids.

        Raises:
            ValueError: If ``output`` is neither of the supported types.
        """

        # Check the tensor case first so the common path never evaluates the
        # transformers ``ModelOutput`` symbol.  Behavior is unchanged: a
        # torch.Tensor is never a ModelOutput.
        if isinstance(output, torch.Tensor):
            self.history_ids = output
        elif isinstance(output, ModelOutput):
            self.history_ids = output.sequences
        else:
            raise ValueError(f'Unknown output type {type(output)}')