import os
from subprocess import PIPE, Popen

from utils.get_run_config import get_command_with_extra, get_model_name
from utils.rule_condition_assert import assert_result
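
# Illustrative input shapes (inferred from how they are used below; not an
# authoritative schema):
#   config    -- dict providing at least 'dst_path', 'model_path' and 'log_path'
#   case_info -- list of single-key dicts mapping a prompt string to the
#                assertion rules consumed by assert_result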


def command_line_test(config,
                      case,
                      case_info,
                      model_case,
                      type,
                      extra: str = None,
                      cuda_prefix: str = None):
    """Build the client/chat command for a converted workspace and run the case."""
    dst_path = config.get('dst_path')

    if type == 'api_client':
        cmd = 'lmdeploy serve api_client ' + extra
    elif type == 'triton_client':
        cmd = 'lmdeploy serve triton_client ' + extra
    else:
        cmd = get_command_with_extra('lmdeploy chat turbomind ' + dst_path +
                                     '/workspace_' + model_case,
                                     config,
                                     model_case,
                                     cuda_prefix=cuda_prefix)
        if 'kvint8' in model_case:
            cmd += ' --quant-policy 4'
            if 'w4' in model_case or '4bits' in model_case:
                cmd += ' --model-format awq'
            else:
                cmd += ' --model-format hf'
        elif 'w4' in model_case or '4bits' in model_case:
            cmd += ' --model-format awq'
        if 'chat' not in model_case.lower():
            cmd += ' --cap completion'
    return command_test(config, [cmd], model_case, case, case_info,
                        type == 'turbomind')


def hf_command_line_test(config,
                         case,
                         case_info,
                         model_case,
                         type,
                         cuda_prefix: str = None):
    """Build the `lmdeploy chat` command for a HF model path and run the case."""
    model_path = config.get('model_path') + '/' + model_case

    cmd = get_command_with_extra(' '.join(['lmdeploy chat', type, model_path]),
                                 config,
                                 model_case,
                                 need_tp=True,
                                 cuda_prefix=cuda_prefix)

    if 'kvint8' in model_case:
        cmd += ' --quant-policy 4'
        if 'w4' in model_case or '4bits' in model_case:
            cmd += ' --model-format awq'
        else:
            cmd += ' --model-format hf'
    elif 'w4' in model_case or '4bits' in model_case:
        cmd += ' --model-format awq'
    return command_test(config, [cmd], model_case,
                        '_'.join(['hf', type, case]), case_info, True)


def command_test(config, cmd, model, case, case_info, need_extract_output):
    """Run the chat command, feed the case prompts via stdin and assert each answer."""
    if 'memory_test' in case and 'chat' not in model.lower():
        return True, None, 'memory case skipped for base model'

    try:
        log_path = config.get('log_path')
        model_name = get_model_name(model)

        if '/' in model:
            chat_log = os.path.join(
                log_path, 'chat_' + model.split('/')[1] + '_' + case + '.log')
        else:
            chat_log = os.path.join(log_path,
                                    'chat_' + model + '_' + case + '.log')

        file = open(chat_log, 'w')

        returncode = -1
        result = True

        print('reproduce command chat: ' + ' '.join(cmd) + '\n')
        file.writelines('reproduce command chat: ' + ' '.join(cmd) + '\n')

        splitter = '\n\n'
        if 'CodeLlama-7b-Instruct-hf' in model:
            splitter = '\n!!\n'
        # join all prompts, plus a final `exit`, into one stdin payload
        prompt = ''
        for item in case_info:
            prompt += list(item.keys())[0] + splitter
        prompt += 'exit' + splitter

        msg = ''

        with Popen(cmd,
                   stdin=PIPE,
                   stdout=PIPE,
                   stderr=PIPE,
                   shell=True,
                   text=True,
                   encoding='utf-8') as proc:
            # file.writelines('prompt:' + prompt + '\n')

            outputs, errors = proc.communicate(input=prompt)
            returncode = proc.returncode
            if returncode != 0:
                file.writelines('error:' + errors + '\n')
                result = False
                file.close()
                return result, chat_log, errors

            outputDialogs = parse_dialogue(outputs, model)
            file.writelines('answersize:' + str(len(outputDialogs)) + '\n')

            # check each answer against its expected result rules
            index = 0
            for prompt_detail in case_info:
                if need_extract_output:
                    output = extract_output(outputDialogs[index], model)
                else:
                    output = outputDialogs[index]
                case_result, reason = assert_result(output,
                                                    prompt_detail.values(),
                                                    model_name)
                file.writelines('prompt:' + list(prompt_detail.keys())[0] +
                                '\n')
                file.writelines('output:' + output + '\n')
                file.writelines('result:' + str(case_result) + ',reason:' +
                                reason + '\n')
                index += 1
                if case_result is False:
                    msg = reason
                result = result & case_result

        file.close()
        return result, chat_log, msg
    except Exception as e:
        return False, None, f'Unknown error: {e}'


# parse each dialogue turn out of the chat command's stdout
def parse_dialogue(inputs: str, model: str):
    """Split stdout on the interactive prompt marker and return one string per turn."""
    if 'CodeLlama-7b-Instruct-hf' in model:
        sep = 'enter !! to end the input >>>'
    else:
        sep = 'double enter to end input >>>'
    dialogues = inputs.strip().split(sep)
    dialogues = [d.strip() for d in dialogues]
    return dialogues[1:-1]  # drop the startup banner and the trailing exit turn


def extract_output(output: str, model: str):
    """Strip the model-specific assistant marker and return only the generated answer."""
    if 'Qwen' in model or 'internlm2' in model:
        if len(output.split('<|im_start|>assistant')) >= 2:
            return output.split('<|im_start|>assistant')[1]
    if 'Baichuan2' in model:
        if len(output.split('<reserved_107>')) >= 2:
            return output.split('<reserved_107>')[1]
    if 'internlm' in model:
        if len(output.split('<|Bot|>: ')) >= 2:
            return output.split('<|Bot|>: ')[1]
    if 'llama' in model or 'Llama' in model:
        if len(output.split('[/INST]')) >= 2:
            return output.split('[/INST]')[1]

    return output
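

if __name__ == '__main__':
    # Minimal, runnable self-check sketch (not part of the test flow): it feeds
    # a synthetic transcript -- an assumed shape of the interactive
    # `lmdeploy chat` stdout -- through the two parsing helpers above.
    sample = ('loading model ...\n'
              'double enter to end input >>> hi\n'
              '<|Bot|>: Hello!\n'
              'double enter to end input >>> bye\n'
              '<|Bot|>: Goodbye!\n'
              'double enter to end input >>> exit\n')
    turns = parse_dialogue(sample, 'internlm-chat-7b')
    assert turns == ['hi\n<|Bot|>: Hello!', 'bye\n<|Bot|>: Goodbye!']
    assert extract_output(turns[0], 'internlm-chat-7b') == 'Hello!'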