model.py 3.03 KB
Newer Older
wanglch's avatar
wanglch committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import os.path
from typing import Type

import gradio as gr

from swift.llm import MODEL_MAPPING, ModelType
from swift.ui.base import BaseUI


class Model(BaseUI):
    """Model-selection row of the ``llm_eval`` UI group.

    Renders a model-type dropdown, a free-text "model id or path" box, and a
    reset button. It wires them so that:

    * choosing a model type fills the id/path box with a remembered value or
      a default;
    * typing a path that exists on the local filesystem remembers it for the
      currently selected model type;
    * reset clears the remembered values and restores the default.
    """

    group = 'llm_eval'

    # UI strings, keyed by elem_id -> component attribute -> language code.
    locale_dict = {
        'checkpoint': {
            'value': {
                'zh': '训练后的模型',
                'en': 'Trained model'
            }
        },
        'model_type': {
            'label': {
                'zh': '选择模型',
                'en': 'Select Model'
            },
            'info': {
                'zh': 'SWIFT已支持的模型名称',
                'en': 'Base model supported by SWIFT'
            }
        },
        'model_id_or_path': {
            'label': {
                'zh': '模型id或路径',
                'en': 'Model id or path'
            },
            'info': {
                'zh': '实际的模型id',
                'en': 'The actual model id or model path'
            }
        },
        'reset': {
            'value': {
                'zh': '恢复初始值',
                'en': 'Reset to default'
            },
        },
    }

    @classmethod
    def do_build_ui(cls, base_tab: Type['BaseUI']):
        """Create the model-selection widgets and attach their event handlers.

        Args:
            base_tab: Tab class whose ``locale`` lookup supplies the localised
                text of the special "trained model" (checkpoint) dropdown entry.
        """

        def checkpoint_label():
            # Looked up on every use so the current value of ``cls.lang`` is
            # always honoured.
            return base_tab.locale('checkpoint', cls.lang)['value']

        with gr.Row():
            # The checkpoint entry comes first, followed by the built-in and
            # user-registered model names.
            type_choices = [checkpoint_label()] + ModelType.get_model_name_list() + cls.get_custom_name_list()
            model_type = gr.Dropdown(
                elem_id='model_type',
                choices=type_choices,
                value=checkpoint_label(),
                allow_custom_value=True,
                scale=20)
            model_id_or_path = gr.Textbox(elem_id='model_id_or_path', lines=1, scale=20, interactive=True)
            reset_btn = gr.Button(elem_id='reset', scale=2)
            # Per-session mapping: model type choice -> path the user entered.
            model_state = gr.State({})

        def update_input_model(choice, model_state=None):
            """Return the id/path to display for ``choice``.

            A remembered entry in ``model_state`` takes priority. Otherwise the
            checkpoint entry yields None, and any other choice yields the
            ``model_id_or_path`` registered in ``MODEL_MAPPING``. That is None
            for names not present in the mapping.
            """
            is_checkpoint = choice == checkpoint_label()
            if model_state and choice in model_state:
                return model_state[choice]
            if is_checkpoint:
                return None
            return MODEL_MAPPING.get(choice, {}).get('model_id_or_path')

        def update_model_id_or_path(model_type, path, model_state):
            """Remember ``path`` for ``model_type`` when it exists locally.

            Any other input (empty, or not an existing local path) returns a
            no-op ``gr.update()``, so the stored state is left as it was.
            """
            if path and os.path.exists(path):
                model_state[model_type] = path
                return model_state
            return gr.update()

        def reset(model_type):
            """Forget every remembered path and show the default for the current choice."""
            return update_input_model(model_type), {}

        model_type.change(update_input_model, inputs=[model_type, model_state], outputs=[model_id_or_path])
        model_id_or_path.change(
            update_model_id_or_path, inputs=[model_type, model_id_or_path, model_state], outputs=[model_state])
        reset_btn.click(reset, inputs=[model_type], outputs=[model_id_or_path, model_state])