model.py 1.96 KB
Newer Older
lvhan028's avatar
lvhan028 committed
1
2
3
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine import Registry

lvhan028's avatar
lvhan028 committed
4
# Registry named 'model'. The prompt-template classes in this file add
# themselves to it through `@MODELS.register_module(name=...)`, and `main`
# looks them up by that name.
# NOTE(review): `locations` presumably lists the modules mmengine imports to
# populate the registry on lookup -- confirm against the mmengine docs.
MODELS = Registry('model', locations=['lmdeploy.model'])
lvhan028's avatar
lvhan028 committed
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25


@MODELS.register_module(name='vicuna')
class Vicuna:
    """Prompt template registered under the name 'vicuna'."""

    def __init__(self):
        # System preamble placed in front of the first turn. The trailing
        # space is part of the text and is kept as-is.
        self.system = (
            'A chat between a curious user and an artificial intelligence '
            'assistant. The assistant gives helpful, detailed, and polite '
            "answers to the user's questions. ")
        # Role labels written before the user text and the model reply.
        self.user = 'USER'
        self.assistant = 'ASSISTANT'

    def get_prompt(self, prompt, sequence_start=True):
        """Wrap ``prompt`` in the template's role markers.

        Args:
            prompt (str): The user's message.
            sequence_start (bool): When truthy, the system preamble followed
                by a space leads the result; otherwise the result starts
                with '</s>'.

        Returns:
            str: The decorated prompt, ending with the assistant label and a
                colon.
        """
        # Only the leading part differs between the first and later turns.
        leading = f'{self.system} ' if sequence_start else '</s>'
        return f'{leading}{self.user}: {prompt} {self.assistant}:'

    @property
    def stop_words(self):
        """Stop words of this template; always ``None`` here."""
        return None


26
27
@MODELS.register_module(name='internlm')
class InternLM:
    """Prompt template registered under the name 'internlm'."""

    def __init__(self):
        # Empty system text: a first-turn prompt therefore begins with just
        # the newline that follows it.
        self.system = ''
        # Role tags and the marker appended after the user text.
        self.user = '<|User|>'
        self.eoh = '<eoh>'
        # Not used by `get_prompt`; kept as an attribute of the template.
        self.eoa = '<eoa>'
        self.assistant = '<|Bot|>'

    def get_prompt(self, prompt, sequence_start=True):
        """Wrap ``prompt`` in the template's role markers.

        Args:
            prompt (str): The user's message.
            sequence_start (bool): When truthy, the system text is placed
                before the leading newline; otherwise nothing precedes it.

        Returns:
            str: The decorated prompt, ending with the assistant tag and a
                colon.
        """
        # Both cases share everything from the first newline onward.
        preamble = self.system if sequence_start else ''
        return (f'{preamble}\n'
                f'{self.user}:{prompt}{self.eoh}\n'
                f'{self.assistant}:')

    @property
    def stop_words(self):
        """Return the stop-word ids as a fresh list on every access.

        NOTE(review): presumably these are the tokenizer ids of the end
        markers -- confirm against the model's vocabulary.
        """
        return [103027, 103028]


@MODELS.register_module(name='llama')
class Llama:
    """Pass-through prompt template registered under the name 'llama'."""

    def __init__(self):
        """Nothing to set up: this template holds no state."""

    def get_prompt(self, prompt, sequence_start=True):
        """Return ``prompt`` unchanged.

        Args:
            prompt (str): The user's message.
            sequence_start (bool): Accepted for a signature parallel to the
                other templates; it has no effect on the result.

        Returns:
            str: The very same ``prompt`` that was passed in.
        """
        return prompt

    @property
    def stop_words(self):
        """Stop words of this template; always ``None`` here."""
        return None
lvhan028's avatar
lvhan028 committed
62
63
64
65
66
67
68
69
70
71
72
73
74
75


def main(model_name: str = 'test'):
    """Print the prompt that the template named ``model_name`` builds for 'hi'.

    Args:
        model_name (str): Key under which a template class is registered in
            ``MODELS``. The default, 'test', is not among the names
            registered in this file, so it only passes the check below if
            such a template is registered elsewhere.

    Raises:
        AssertionError: If ``model_name`` is not a key of
            ``MODELS.module_dict``.
    """
    assert model_name in MODELS.module_dict.keys(), \
        f"'{model_name}' is not supported. " \
        f'The supported models are: {MODELS.module_dict.keys()}'
    # Instantiate the template that was actually requested (and validated
    # above) rather than a fixed, hard-coded registry key.
    model = MODELS.get(model_name)()
    prompt = model.get_prompt(prompt='hi')
    print(prompt)


if __name__ == '__main__':
    # Script entry point: hand `main` to python-fire. The import stays local
    # so that `fire` is only required when this file is run directly.
    from fire import Fire

    Fire(main)