# Copyright (c) OpenMMLab. All rights reserved.
from mmengine import Registry

# Registry holding every chat template defined in this module. ``locations``
# lists the module path(s) mmengine imports to populate the registry, so the
# ``@MODELS.register_module`` decorators below have run before a lookup.
MODELS = Registry(name='model', locations=['lmdeploy.model'])


@MODELS.register_module(name='vicuna')
class Vicuna:
    """Chat template of vicuna model.

    The first round of a session is rendered as
    ``<system> USER: <prompt> ASSISTANT:``; later rounds start with ``</s>``
    instead of the system text.

    Args:
        system (str): meta instruction placed in front of the first round.
            It must not end with whitespace, because ``get_prompt`` inserts
            the single separating space itself.
        user (str): role name of the user.
        assistant (str): role name of the assistant.
    """

    def __init__(
            self,
            system="""A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.""",  # noqa: E501
            user='USER',
            assistant='ASSISTANT'):
        # The default system text deliberately has no trailing space:
        # get_prompt joins it to the user role with exactly one space, so a
        # trailing space here would yield "questions.  USER:" (double space).
        self.system = system
        self.user = user
        self.assistant = assistant

    def get_prompt(self, prompt, sequence_start=True):
        """Return the prompt that is concatenated with other elements in the
        chat template.

        Args:
            prompt (str): user's input prompt
            sequence_start (bool): indicator for the first round chat of a
               session sequence
        Returns:
            str: the concatenated prompt
        """
        if sequence_start:
            return f'{self.system} {self.user}: {prompt} {self.assistant}:'
        # Later rounds begin with ``</s>`` (presumably terminating the
        # previous assistant answer) rather than repeating the system text.
        return f'</s>{self.user}: {prompt} {self.assistant}:'

    @property
    def stop_words(self):
        """Return the stop-words' token ids."""
        # This template defines no extra stop words.
        return None


@MODELS.register_module(name='internlm')
class InternLM:
    """Chat template of InternLM model.

    One round is rendered as ``<|User|>:<prompt><eoh>``, a newline, then
    ``<|Bot|>:``. The first round of a session is prefixed with ``<BOS>``,
    every later round with a newline.
    """

    def __init__(self):
        # Role tags.
        self.user = '<|User|>'
        self.assistant = '<|Bot|>'
        # Delimiter appended right after the user's prompt.
        self.eoh = '<eoh>'
        # Not referenced by get_prompt in this class.
        self.eoa = '<eoa>'
        # Not referenced by get_prompt in this class.
        self.system = ''

    def get_prompt(self, prompt, sequence_start=True):
        """Wrap one user turn in the InternLM role tags.

        Args:
            prompt (str): user's input prompt
            sequence_start (bool): whether this is the first round of a
               session sequence
        Returns:
            str: the prompt ready to be appended to the session text
        """
        turn = f'{self.user}:{prompt}{self.eoh}\n{self.assistant}:'
        lead = '<BOS>' if sequence_start else '\n'
        return lead + turn

    @property
    def stop_words(self):
        """Return the stop-words' token ids."""
        # NOTE(review): presumably the ids of the <eoh>/<eoa> markers;
        # confirm against the InternLM tokenizer.
        stop_ids = [103027, 103028]
        return stop_ids


@MODELS.register_module(name='llama')
class Llama:
    """Chat template of LLaMA model.

    No chat formatting is applied: the user prompt is forwarded as-is and no
    stop words are defined.
    """

    def __init__(self):
        """The pass-through template holds no state."""

    def get_prompt(self, prompt, sequence_start=True):
        """Return ``prompt`` unchanged.

        Args:
            prompt (str): user's input prompt
            sequence_start (bool): accepted for the same signature as the
               other templates; ignored here
        Returns:
            str: the very same ``prompt`` object
        """
        rendered = prompt
        return rendered

    @property
    def stop_words(self):
        """No stop-words are defined for this template, so ``None``."""
        return None


@MODELS.register_module(name='puyu')
class Puyu:
    """Chat template of puyu model. This is only for internal usage in
    Shanghai AI Laboratory."""

    def __init__(self):
        # Role tags.
        self.user = '<|Human|>'
        self.assistant = '<|Assistant|>'
        # Delimiter appended right after the user's prompt (a single
        # non-ASCII character; keep it byte-exact).
        self.eoh = 'െ'
        # Meta instruction placed in front of the first round of a session.
        self.system = """meta instruction
You are an AI assistant whose name is InternLM (书生·浦语).
- 书生·浦语 is a conversational language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.
- 书生·浦语 can understand and communicate fluently in the language chosen by the user such as English and 中文.
conversation"""  # noqa: E501

    def get_prompt(self, prompt, sequence_start=True):
        """Wrap one user turn in the puyu role tags.

        Args:
            prompt (str): user's input prompt
            sequence_start (bool): whether this is the first round of a
               session sequence; if so, ``<BOS>`` and the system text are
               prepended
        Returns:
            str: the prompt ready to be appended to the session text
        """
        turn = f'{self.user}:{prompt}{self.eoh}\n{self.assistant}:'
        if not sequence_start:
            return '\n' + turn
        return f'<BOS>{self.system}\n' + turn

    @property
    def stop_words(self):
        """Return the stop-words' token ids."""
        # NOTE(review): presumably the id of ``self.eoh``; confirm against
        # the tokenizer.
        stop_ids = [45623]
        return stop_ids


def main(model_name: str = 'test'):
    """Print the first-round prompt a registered chat template renders for
    the input ``'hi'``.

    Args:
        model_name (str): name under which the template is registered in
            ``MODELS`` (e.g. ``'vicuna'``, ``'internlm'``, ``'llama'``).
            NOTE(review): the default ``'test'`` matches none of the names
            registered in this module; confirm the intended default.

    Raises:
        ValueError: if ``model_name`` is not registered in ``MODELS``.
    """
    # Explicit check instead of ``assert``, which is stripped under ``-O``.
    if model_name not in MODELS.module_dict:
        raise ValueError(
            f"'{model_name}' is not supported. "
            f'The supported models are: {MODELS.module_dict.keys()}')
    # Instantiate the template the caller asked for.
    model = MODELS.get(model_name)()
    prompt = model.get_prompt(prompt='hi')
    print(prompt)


if __name__ == '__main__':
    import fire
    fire.Fire(main)