utils.py 2.58 KB
Newer Older
WRH's avatar
WRH committed
1
2
# Copyright (c) OpenMMLab. All rights reserved.

3
import logging
WRH's avatar
WRH committed
4
5
6

from transformers.generation.streamers import BaseStreamer

7
from .dist import get_rank, master_only, master_only_and_broadcast_general
WRH's avatar
WRH committed
8

9
10
11
12
try:
    import readline  # To support command line history # noqa: F401
except ImportError:  # readline not available
    pass
13

14
logger = logging.getLogger(__name__)
WRH's avatar
WRH committed
15
16


17
18
class TerminalIO:
    """Terminal input and output."""

    # NOTE(review): not read inside this class; presumably consumed by callers
    # (e.g. as a streamer's end-of-output marker) — confirm at call sites.
    end_of_output = '\n'

    @master_only_and_broadcast_general
    def input(self):
        """Read multi-line input from the terminal.

        Lines are read until an empty line is entered (i.e. the user presses
        enter twice). On EOF (e.g. Ctrl-D) a message is printed and the
        process exits.

        The decorator comes from ``.dist``; judging by its name it presumably
        reads on the master rank only and broadcasts the result — confirm
        there.

        Returns:
            str: The entered lines joined with newlines.
        """
        print('\ndouble enter to end input >>> ', end='')
        sentinel = ''  # ends when this string is seen
        try:
            # ``input`` here is the builtin: class attributes (including this
            # method) are not in scope inside a method body.
            return '\n'.join(iter(input, sentinel))
        except EOFError:
            print('Detect EOF, exit')
            # Use sys.exit rather than the site-module ``exit`` helper, which
            # is intended for the interactive shell and is absent when Python
            # runs without the site module (``python -S``).
            import sys
            sys.exit()

    @master_only
    def output(self, string):
        """Write ``string`` to stdout without a trailing newline and flush.

        Flushing makes streamed, partial output appear immediately.
        """
        print(string, end='', flush=True)
WRH's avatar
WRH committed
39

40

41
42
43
44
45
46
47
48
49
50
51
52
53
class BasicStreamer(BaseStreamer):
    """Streamer that decodes each pushed value and hands the text to a sink.

    Args:
        decode_func: Callable that converts the value given to ``put`` into
            text.
        output_func: Callable that receives every piece of text to emit.
        end_of_output (str): Text emitted by ``end()`` when a generation
            finishes. Defaults to a newline.
        skip_prompt (bool): When true, the first value pushed after
            construction or after ``end()`` is neither decoded nor emitted
            (presumably the prompt tokens — confirm with the caller).
    """

    def __init__(self,
                 decode_func,
                 output_func,
                 end_of_output='\n',
                 skip_prompt=True):
        # Count of put() calls since the last end(); 0 marks the first call.
        self.gen_len = 0
        self.skip_prompt = skip_prompt
        self.end_of_output = end_of_output
        self.output = output_func
        self.decode = decode_func

    def put(self, value):
        """Decode ``value`` and emit it, unless it is a skipped first push."""
        should_emit = not (self.skip_prompt and self.gen_len == 0)
        if should_emit:
            self.output(self.decode(value))
        self.gen_len += 1

    def end(self):
        """Emit the end-of-output text and reset the push counter."""
        self.output(self.end_of_output)
        self.gen_len = 0


72
73
def control(prompt, gen_config, sm):
    """Allow user to control generation config and session manager.

    Recognized commands:

    * ``exit`` -- terminate the process with exit status 0.
    * ``clear`` -- start a new session via ``sm.new_session()``.
    * ``config set <key>=<value>`` -- evaluate ``<value>`` as a Python
      expression and assign it to attribute ``<key>`` of ``gen_config``.
      Only the last whitespace-separated token is parsed, so the
      ``key=value`` pair must not contain spaces.

    Args:
        prompt (str): Raw text entered by the user.
        gen_config: Object whose attributes ``config set`` overwrites.
        sm: Session manager; only ``sm.new_session()`` is used.

    Return:
        True if control command applied, False otherwise. This includes a
        ``config set`` command that could not be parsed or applied; that
        prompt is then treated as normal conversation.
    """
    if prompt == 'exit':
        # sys.exit instead of the site-module ``exit`` helper, which is meant
        # for the interactive shell and is missing under ``python -S``.
        import sys
        sys.exit(0)

    if prompt == 'clear':
        sm.new_session()
        logger.info('Session cleared')
        return True

    # Re-config during runtime
    if prompt.startswith('config set'):
        try:
            keqv = prompt.split()[-1]
            # Split on the first '=' only, so values may contain '='.
            k, v = keqv.split('=', 1)
            # NOTE(review): eval executes arbitrary Python typed at the
            # prompt. Acceptable for a local interactive CLI; never route
            # untrusted input into this function.
            v = eval(v)
            setattr(gen_config, k, v)
        except Exception:
            # Not a bare except: Ctrl-C and SystemExit must still propagate.
            logger.info(
                'illegal instruction, treated as normal conversation. ')
        else:
            # Logged outside the try so a logging failure is not misreported
            # as an illegal instruction after the config was already changed.
            logger.info(f'Worker {get_rank()} set {k} to {repr(v)}')
            logger.info(f'Generator config changed to: {gen_config}')
            return True

    return False