hyper.py 4.79 KB
Newer Older
wanglch's avatar
wanglch committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
from typing import Type

import gradio as gr

from swift.ui.base import BaseUI


class Hyper(BaseUI):
    """Hyper-parameter section of the training UI.

    ``locale_dict`` holds the Chinese (``zh``) and English (``en``) ``label``
    and ``info`` texts, keyed by the ``elem_id`` of each component created in
    :meth:`do_build_ui`.
    """

    # NOTE(review): presumably the UI group this section is registered under;
    # the code that consumes it is not in this file -- confirm against BaseUI.
    group = 'llm_train'

    locale_dict = {
        'hyper_param': {
            'label': {
                'zh': '超参数',
                'en': 'Hyper settings',
            },
        },
        'batch_size': {
            'label': {
                'zh': '训练batch size',
                'en': 'Train batch size',
            },
            'info': {
                'zh': '训练的batch size',
                'en': 'Set the train batch size',
            }
        },
        'eval_batch_size': {
            'label': {
                'zh': '验证batch size',
                'en': 'Val batch size',
            },
            'info': {
                'zh': '验证的batch size',
                'en': 'Set the val batch size',
            }
        },
        'learning_rate': {
            'label': {
                'zh': '学习率',
                'en': 'Learning rate',
            },
            'info': {
                'zh': '设置学习率',
                'en': 'Set the learning rate',
            }
        },
        'eval_steps': {
            'label': {
                'zh': '交叉验证步数',
                'en': 'Eval steps',
            },
            'info': {
                'zh': '设置每隔多少步数进行一次验证',
                'en': 'Set the step interval to validate',
            }
        },
        'num_train_epochs': {
            'label': {
                'zh': '数据集迭代轮次',
                'en': 'Train epoch',
            },
            'info': {
                'zh': '设置对数据集训练多少轮次',
                'en': 'Set the max train epoch',
            }
        },
        'max_steps': {
            'label': {
                'zh': '最大迭代步数',
                'en': 'Max steps',
            },
            'info': {
                'zh': '设置最大迭代步数,该值如果大于零则数据集迭代次数不生效',
                'en': 'Set the max steps, if the value > 0 then num_train_epochs has no effects',
            }
        },
        'gradient_accumulation_steps': {
            'label': {
                'zh': '梯度累计步数',
                'en': 'Gradient accumulation steps',
            },
            'info': {
                'zh': '设置梯度累计步数以减小显存占用',
                'en': 'Set the gradient accumulation steps',
            }
        },
        'max_grad_norm': {
            'label': {
                'zh': '梯度裁剪',
                'en': 'Max grad norm',
            },
            'info': {
                'zh': '设置梯度裁剪',
                'en': 'Set the max grad norm',
            }
        },
        'predict_with_generate': {
            'label': {
                'zh': '使用生成指标代替loss',
                'en': 'Use generate metric instead of loss',
            },
            'info': {
                'zh': '验证时使用generate/Rouge代替loss',
                'en': 'Use model.generate/Rouge instead of loss',
            }
        },
        'use_flash_attn': {
            'label': {
                'zh': '使用Flash Attention',
                'en': 'Use Flash Attention',
            },
            'info': {
                'zh': '使用Flash Attention减小显存占用',
                'en': 'Use Flash Attention to reduce memory',
            }
        },
    }

    @classmethod
    def do_build_ui(cls, base_tab: Type['BaseUI']):
        """Create the hyper-parameter widgets and wire the learning-rate default.

        Builds an open accordion containing two rows of inputs (sliders,
        textboxes and checkboxes) and registers a ``change`` listener on the
        ``sft_type`` element of ``base_tab`` that rewrites the learning-rate
        textbox.

        Args:
            base_tab: UI class whose ``element('sft_type')`` component drives
                the learning-rate default.
        """
        with gr.Accordion(elem_id='hyper_param', open=True):
            with gr.Blocks():
                with gr.Row():
                    # step=1: with minimum=1 a step of 2 only yields odd values
                    # (1, 3, 5, ...), so even batch sizes could not be selected.
                    gr.Slider(elem_id='batch_size', minimum=1, maximum=256, step=1, scale=20)
                    learning_rate = gr.Textbox(elem_id='learning_rate', value='1e-4', lines=1, scale=20)
                    gr.Textbox(elem_id='num_train_epochs', lines=1, scale=20)
                    gr.Textbox(elem_id='max_steps', lines=1, scale=20)
                    # step=1 also keeps the default of 16 on the slider grid.
                    gr.Slider(elem_id='gradient_accumulation_steps', minimum=1, maximum=256, step=1, value=16, scale=20)
                with gr.Row():
                    gr.Slider(elem_id='eval_batch_size', minimum=1, maximum=256, step=1, scale=20)
                    gr.Textbox(elem_id='eval_steps', lines=1, value='500', scale=20)
                    gr.Textbox(elem_id='max_grad_norm', lines=1, scale=20)
                    gr.Checkbox(elem_id='predict_with_generate', scale=20)
                    gr.Checkbox(elem_id='use_flash_attn', scale=20)

            def update_lr(sft_type):
                """Return the default learning rate text for ``sft_type``.

                Strings are returned (not floats) because the target is a
                textbox initialised with ``'1e-4'``; a float would be shown as
                ``0.0001`` / ``1e-05`` instead of the same notation.
                """
                return '1e-5' if sft_type == 'full' else '1e-4'

            sft_type_element = base_tab.element('sft_type')
            sft_type_element.change(update_lr, inputs=[sft_type_element], outputs=[learning_rate])