model.py 6.03 KB
Newer Older
wanglch's avatar
wanglch committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import os.path
from typing import Type

import gradio as gr
import json

from swift.llm import MODEL_MAPPING, TEMPLATE_MAPPING, ModelType
from swift.ui.base import BaseUI
from swift.ui.llm_infer.generate import Generate


class Model(BaseUI):
    """Model-selection panel of the LLM inference UI.

    Builds the row holding the model dropdown, the model id/path textbox, the
    prompt-template dropdown and the reset button; the system-prompt textbox;
    the nested ``Generate`` sub-UI; and the row holding the extra-params
    textbox and the load/deploy button. It also wires the callbacks that keep
    the id/path, system prompt and template in sync with the selected model.
    """

    # NOTE(review): presumably the group key BaseUI uses to namespace this UI — confirm against BaseUI.
    llm_train = 'llm_infer'

    # Child UI classes; ``Generate`` is built explicitly inside ``do_build_ui``.
    sub_ui = [Generate]

    # True when USE_INFERENCE is '1' or MODELSCOPE_ENVIRONMENT is 'studio'.
    # Evaluated once, when the class body is executed.
    is_inference = os.environ.get('USE_INFERENCE') == '1' or os.environ.get('MODELSCOPE_ENVIRONMENT') == 'studio'

    # Localized ('zh' / 'en') strings keyed by element id. The 'load_checkpoint'
    # caption depends on ``is_inference`` above.
    locale_dict = {
        'checkpoint': {
            'value': {
                'zh': '训练后的模型',
                'en': 'Trained model'
            }
        },
        'model_type': {
            'label': {
                'zh': '选择模型',
                'en': 'Select Model'
            },
            'info': {
                'zh': 'SWIFT已支持的模型名称',
                'en': 'Base model supported by SWIFT'
            }
        },
        'load_checkpoint': {
            'value': {
                'zh': '加载模型' if is_inference else '部署模型',
                'en': 'Load model' if is_inference else 'Deploy model',
            }
        },
        'model_id_or_path': {
            'label': {
                'zh': '模型id或路径',
                'en': 'Model id or path'
            },
            'info': {
                'zh': '实际的模型id',
                'en': 'The actual model id or model path'
            }
        },
        'template_type': {
            'label': {
                'zh': '模型Prompt模板类型',
                'en': 'Prompt template type'
            },
            'info': {
                'zh': '选择匹配模型的Prompt模板',
                'en': 'Choose the template type of the model'
            }
        },
        'system': {
            'label': {
                'zh': 'system字段',
                'en': 'system'
            },
            'info': {
                'zh': 'system字段支持在加载模型后修改',
                'en': 'system can be modified after the model weights loaded'
            }
        },
        'more_params': {
            'label': {
                'zh': '更多参数',
                'en': 'More params'
            },
            'info': {
                'zh': '以json格式填入',
                'en': 'Fill in with json format'
            }
        },
        'reset': {
            'value': {
                'zh': '恢复初始值',
                'en': 'Reset to default'
            },
        },
    }

    @classmethod
    def do_build_ui(cls, base_tab: Type['BaseUI']):
        """Create the model-selection widgets and register their callbacks.

        Args:
            base_tab: The UI class whose ``locale`` lookup supplies the
                localized placeholder shown as the first dropdown entry and
                which is handed on to ``Generate.build_ui``.
        """
        # Localized placeholder entry meaning "a trained checkpoint" rather
        # than a registered base model; it is NOT a key of MODEL_MAPPING.
        checkpoint_label = base_tab.locale('checkpoint', cls.lang)['value']

        with gr.Row():
            model_type = gr.Dropdown(
                elem_id='model_type',
                choices=[checkpoint_label] + ModelType.get_model_name_list() + cls.get_custom_name_list(),
                value=checkpoint_label,
                scale=20)
            model_id_or_path = gr.Textbox(elem_id='model_id_or_path', lines=1, scale=20, interactive=True)
            template_type = gr.Dropdown(
                elem_id='template_type', choices=list(TEMPLATE_MAPPING.keys()) + ['AUTO'], scale=20)
            reset_btn = gr.Button(elem_id='reset', scale=2)
            # Maps a dropdown choice to the path the user last entered for it.
            model_state = gr.State({})
        with gr.Row():
            system = gr.Textbox(elem_id='system', lines=4, scale=20)
        Generate.build_ui(base_tab)
        with gr.Row():
            gr.Textbox(elem_id='more_params', lines=1, scale=20)
            gr.Button(elem_id='load_checkpoint', scale=2, variant='primary')

        def template_defaults(name):
            """Return ``(default_system, template)`` for a key of MODEL_MAPPING.

            ``default_system`` is ``None`` when the template object has no such
            attribute. Raises ``KeyError`` if ``name`` is not registered.
            """
            template = MODEL_MAPPING[name]['template']
            default_system = getattr(TEMPLATE_MAPPING[template]['template'], 'default_system', None)
            return default_system, template

        def update_input_model(choice, model_state=None):
            """Compute id/path, system prompt and template for a dropdown choice.

            A path remembered in ``model_state`` for ``choice`` takes precedence
            over the registry's ``model_id_or_path``. For the checkpoint
            placeholder there are no registry defaults, so system and template
            are ``None``.

            Returns:
                Tuple ``(model_id_or_path, default_system, template)``.
            """
            remembered = bool(model_state) and choice in model_state
            if choice == checkpoint_label:
                model_id_or_path = model_state[choice] if remembered else None
                return model_id_or_path, None, None

            if remembered:
                model_id_or_path = model_state[choice]
            else:
                model_id_or_path = MODEL_MAPPING[choice]['model_id_or_path']
            default_system, template = template_defaults(choice)
            return model_id_or_path, default_system, template

        def update_model_id_or_path(model_type, path, system, template_type, model_state):
            """React to an edit of the id/path textbox.

            ``system`` and ``template_type`` are received because they are
            listed as event inputs; their incoming values are not used.

            Returns:
                Tuple ``(system, template_type, model_state)`` of new values;
                ``gr.update()`` entries leave the component unchanged.
            """
            unchanged = gr.update(), gr.update(), gr.update()
            # Only local paths that exist are inspected; hub ids and partially
            # typed paths leave every output as it is.
            if not path or not os.path.exists(path):
                return unchanged

            local_path = os.path.join(path, 'sft_args.json')
            if not os.path.exists(local_path):
                # The selected entry may be the checkpoint placeholder (the
                # dropdown's initial value), which has no registry entry;
                # looking it up would raise KeyError inside the callback.
                if model_type not in MODEL_MAPPING:
                    return unchanged
                default_system, template = template_defaults(model_type)
                return default_system, template, model_state

            # Read as UTF-8 explicitly: the stored system prompt may contain
            # non-ASCII text and the platform default encoding may differ.
            with open(local_path, 'r', encoding='utf-8') as f:
                sft_args = json.load(f)

            # Fall back to the base model's defaults; tolerate an sft_args.json
            # that lacks a key or names a model missing from the registry.
            base_model_type = sft_args.get('model_type')
            if base_model_type in MODEL_MAPPING:
                base_system, base_template = template_defaults(base_model_type)
            else:
                base_system, base_template = None, None

            # Remember the path so re-selecting this entry restores it.
            model_state[model_type] = path
            return (sft_args.get('system') or base_system, sft_args.get('template_type') or base_template,
                    model_state)

        model_type.change(
            update_input_model, inputs=[model_type, model_state], outputs=[model_id_or_path, system, template_type])

        model_id_or_path.change(
            update_model_id_or_path,
            inputs=[model_type, model_id_or_path, system, template_type, model_state],
            outputs=[system, template_type, model_state])

        def reset(model_type):
            """Restore the registry defaults for ``model_type`` and clear remembered paths."""
            # No state is passed, so remembered paths are ignored here.
            model_id_or_path, default_system, template = update_input_model(model_type)
            return model_id_or_path, default_system, template, {}

        reset_btn.click(reset, inputs=[model_type], outputs=[model_id_or_path, system, template_type, model_state])