".github/vscode:/vscode.git/clone" did not exist on "c9e6f0542df66301cc1bf77aef6f60a342bd32a8"
model.py 4.29 KB
Newer Older
lvhan028's avatar
lvhan028 committed
1
2
3
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine import Registry

MODELS = Registry('model', locations=['lmdeploy.model'])
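
# A minimal sketch of how this registry is used (assuming mmengine's
# standard Registry API): each chat-template class below registers
# itself under a name via the decorator, and callers fetch and
# instantiate it by that name, e.g.
#
#     template_cls = MODELS.get('vicuna')  # look up the registered class
#     template = template_cls()            # instantiate the chat template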


@MODELS.register_module(name='vicuna')
class Vicuna:
    """Chat template of vicuna model."""

    def __init__(self):
        self.system = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. """  # noqa: E501
        self.user = 'USER'
        self.assistant = 'ASSISTANT'

    def get_prompt(self, prompt, sequence_start=True):
        """Return the prompt that is concatenated with other elements in the
        chat template.

        Args:
            prompt (str): user's input prompt
            sequence_start (bool): indicator for the first round chat of a
               session sequence
        Returns:
            str: the concatenated prompt
        """
        if sequence_start:
            return f'{self.system} {self.user}: {prompt} {self.assistant}:'
        else:
            return f'</s>{self.user}: {prompt} {self.assistant}:'

    @property
    def stop_words(self):
        """Return the stop-words' token ids."""
        return None
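
# A short sketch of what Vicuna.get_prompt yields (the output follows
# directly from the f-strings above; the system prompt is elided):
#
#     >>> vicuna = MODELS.get('vicuna')()
#     >>> vicuna.get_prompt('hi')
#     'A chat between a curious user ... USER: hi ASSISTANT:'
#     >>> vicuna.get_prompt('hi', sequence_start=False)
#     '</s>USER: hi ASSISTANT:'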


@MODELS.register_module(name='internlm')
class InternLM:
    """Chat template of InternLM model."""

    def __init__(self):
        self.system = ''
        self.user = '<|User|>'
        self.eoh = '<eoh>'
        self.eoa = '<eoa>'
        self.assistant = '<|Bot|>'

    def get_prompt(self, prompt, sequence_start=True):
        """Return the prompt that is concatenated with other elements in the
        chat template.

        Args:
            prompt (str): user's input prompt
            sequence_start (bool): indicator for the first round chat of a
               session sequence
        Returns:
            str: the concatenated prompt
        """
        if sequence_start:
            return f'<bos>{self.user}:{prompt}{self.eoh}\n' \
                   f'{self.assistant}:'
        else:
            return f'\n{self.user}:{prompt}{self.eoh}\n' \
                   f'{self.assistant}:'

    @property
    def stop_words(self):
        """Return the stop-words' token ids."""
        return [103027, 103028]
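
# A short sketch of how InternLM prompts chain across rounds (derived
# from the f-strings above: '<bos>' opens a session, later rounds start
# with a newline instead):
#
#     >>> internlm = MODELS.get('internlm')()
#     >>> internlm.get_prompt('hi')
#     '<bos><|User|>:hi<eoh>\n<|Bot|>:'
#     >>> internlm.get_prompt('and you?', sequence_start=False)
#     '\n<|User|>:and you?<eoh>\n<|Bot|>:'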


@MODELS.register_module(name='llama')
class Llama:
    """Chat template of LLaMA model."""

    def __init__(self):
        pass

    def get_prompt(self, prompt, sequence_start=True):
        """Return the prompt that is concatenated with other elements in the
        chat template.

        Args:
            prompt (str): user's input prompt
            sequence_start (bool): indicator for the first round chat of a
               session sequence
        Returns:
            str: the concatenated prompt
        """
        return prompt

    @property
    def stop_words(self):
        """Return the stop-words' token ids."""
        return None


@MODELS.register_module(name='puyu')
class Puyu:
    """Chat template of puyu model.This is only for internal usage in Shanghai
    AI Laboratory."""

    def __init__(self):
        self.system = """meta instruction
You are an AI assistant whose name is InternLM (书生·浦语).
- 书生·浦语 is a conversational language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.
- 书生·浦语 can understand and communicate fluently in the language chosen by the user such as English and 中文.
conversation"""  # noqa: E501
        self.user = '<|Human|>'
        self.eoh = 'െ'
        self.assistant = '<|Assistant|>'

    def get_prompt(self, prompt, sequence_start=True):
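        """Return the prompt concatenated with the other elements of the
        chat template."""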
        if sequence_start:
            return f'<bos>{self.system}\n' \
                   f'{self.user}:{prompt}{self.eoh}\n' \
                   f'{self.assistant}:'
        else:
            return f'\n{self.user}:{prompt}{self.eoh}\n{self.assistant}:'

    @property
    def stop_words(self):
        """Return the stop-words' token ids."""
        return [45623]


def main(model_name: str = 'vicuna'):
    assert model_name in MODELS.module_dict.keys(), \
        f"'{model_name}' is not supported. " \
        f'The supported models are: {list(MODELS.module_dict.keys())}'
    model = MODELS.get(model_name)()
    prompt = model.get_prompt(prompt='hi')
    print(prompt)


if __name__ == '__main__':
    import fire
    fire.Fire(main)
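
# Example invocation (requires the `fire` package; 'vicuna' is one of
# the names registered above):
#
#     python model.py --model_name vicuna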