"docs/source/en/_toctree.yml" did not exist on "a127363dcabdc4c0625ef24be0e0d8143de18af2"
model.py 5.44 KB
Newer Older
lvhan028's avatar
lvhan028 committed
1
2
3
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine import Registry

MODELS = Registry('model', locations=['lmdeploy.model'])
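
# A minimal usage sketch of the registry (mirroring what main() at the bottom
# of this file does): a chat template is looked up by its registered name and
# instantiated with no arguments.
#
#   model = MODELS.get('vicuna')()
#   prompt = model.get_prompt('hi')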


@MODELS.register_module(name='llama')
class BaseModel:
    """Base model."""

    def __init__(self):
        self.session_len = 2048  # max number of tokens in a chat session
        self.top_p = 0.8  # nucleus sampling threshold
        self.top_k = None  # top-k sampling disabled by default
        self.temperature = 0.8
        self.repetition_penalty = 1.0  # 1.0 applies no penalty

    @staticmethod
    def get_prompt(prompt, sequence_start=True):
        """Return the prompt that is concatenated with other elements in the
        chat template.

        Args:
            prompt (str): user's input prompt
            sequence_start (bool): indicator for the first round chat of a
                session sequence
        Returns:
            str: the concatenated prompt
        """
        return prompt

    @property
    def stop_words(self):
        """Return the stop-words' token ids."""
        return None


@MODELS.register_module(name='vicuna')
class Vicuna(BaseModel):
    """Chat template of vicuna model."""

    def __init__(self):
        super().__init__()
        self.system = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. """  # noqa: E501
        self.user = 'USER'
        self.assistant = 'ASSISTANT'

    def get_prompt(self, prompt, sequence_start=True):
        """Return the prompt that is concatenated with other elements in the
        chat template.

        Args:
            prompt (str): user's input prompt
            sequence_start (bool): indicator for the first round chat of a
                session sequence
        Returns:
            str: the concatenated prompt
        """
        if sequence_start:
            return f'{self.system} {self.user}: {prompt} {self.assistant}:'
        else:
            return f'</s>{self.user}: {prompt} {self.assistant}:'
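
# Illustrative first-round prompt rendered by Vicuna().get_prompt('hi'),
# derived from the f-strings above (follow-up rounds replace the system
# prompt with the </s> end-of-sequence token):
#
#   A chat between a curious user and an artificial intelligence assistant.
#   The assistant gives helpful, detailed, and polite answers to the user's
#   questions.  USER: hi ASSISTANT: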


@MODELS.register_module(name='internlm')
class InternLM(BaseModel):
    """Template of the InternLM base model."""

    def __init__(self):
        super().__init__()


@MODELS.register_module(name='internlm-chat-7b')
class InternLMChat7B(BaseModel):
    """Chat template of InternLM model."""

    def __init__(self):
        super().__init__()
        self.system = ''
        self.user = '<|User|>'
        self.eoh = '<eoh>'
        self.eoa = '<eoa>'
        self.assistant = '<|Bot|>'

    def get_prompt(self, prompt, sequence_start=True):
        """Return the prompt that is concatenated with other elements in the
        chat template.

        Args:
            prompt (str): user's input prompt
            sequence_start (bool): indicator for the first round chat of a
                session sequence
        Returns:
            str: the concatenated prompt
        """
        if sequence_start:
            return f'<BOS>{self.user}:{prompt}{self.eoh}\n' \
                   f'{self.assistant}:'
        else:
            return f'\n{self.user}:{prompt}{self.eoh}\n' \
                   f'{self.assistant}:'

    @property
    def stop_words(self):
        """Return the stop-words' token ids."""
        return [103027, 103028]
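
# Illustrative first-round prompt rendered by InternLMChat7B().get_prompt('hi'),
# derived from the f-strings above:
#
#   <BOS><|User|>:hi<eoh>
#   <|Bot|>: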


@MODELS.register_module(name='internlm-chat-7b-8k')
class InternLMChat7B8K(InternLMChat7B):
    """Chat template of the InternLM-Chat-7B model with an 8k context."""

    def __init__(self):
        super().__init__()
        self.session_len = 8192  # extend the default 2k session to 8k tokens


@MODELS.register_module(name='baichuan-7b')
class Baichuan7B(BaseModel):
    """Template of the Baichuan-7B base model."""

    def __init__(self):
        super().__init__()
        self.repetition_penalty = 1.1  # values > 1.0 discourage repetition


@MODELS.register_module(name='llama2')
class Llama2(BaseModel):
    """Chat template of LLaMA2 model."""

    def __init__(self):
        super().__init__()
        B_INST, E_INST = '[INST]', '[/INST]'
        B_SYS, E_SYS = '<<SYS>>\n', '\n<</SYS>>\n\n'

        DEFAULT_SYSTEM_PROMPT = """\
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""  # noqa: E501

        self.b_inst = B_INST
        self.e_inst = E_INST
        self.b_sys = B_SYS
        self.e_sys = E_SYS
        self.default_sys_prompt = DEFAULT_SYSTEM_PROMPT
        self.session_len = 4096  # Llama-2 context window

    def get_prompt(self, prompt, sequence_start=True):
        """Return the prompt that is concatenated with other elements in the
        chat template.

        Args:
            prompt (str): user's input prompt
            sequence_start (bool): indicator for the first round chat of a
                session sequence
        Returns:
            str: the concatenated prompt
        """
        if sequence_start:
            return f'<BOS>{self.b_inst} ' \
                   f'{self.b_sys} {self.default_sys_prompt} {self.e_sys}' \
                   f'{prompt} {self.e_inst} '

        return f'{self.b_inst} {prompt} {self.e_inst} '
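
# Illustrative first-round prompt rendered by Llama2().get_prompt('hi'),
# derived (roughly) from the f-strings above:
#
#   <BOS>[INST] <<SYS>>
#    {default system prompt}
#   <</SYS>>
#
#   hi [/INST]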


def main(model_name: str = 'test'):
    """Print the first-round prompt and session length of a chat template."""
    assert model_name in MODELS.module_dict, \
        f"'{model_name}' is not supported. " \
        f'The supported models are: {list(MODELS.module_dict.keys())}'
    model = MODELS.get(model_name)()
    prompt = model.get_prompt(prompt='hi')
    print(prompt)
    print(f'session_len: {model.session_len}')


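# Example invocation (assuming the `fire` package is installed), e.g.:
#
#   python model.py --model_name llama2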
if __name__ == '__main__':
    import fire
    fire.Fire(main)