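"""
Chat client for a Qwen3 model served through vLLM's OpenAI-compatible API.

Sends chat-completion requests, splits <think>...</think> reasoning from the
final answer, and collects the rank-1 logprob of every generated token so the
answer portion can be scored separately.
"""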
import requests
import json
import re

API_BASE_URL = "http://localhost:8000/v1"
# MODEL_NAME = "/home/zwq/model/Qwen3-30B-A3B" 
MODEL_NAME = "/home/zwq/model/Qwen3-30B-A3B-Instruct-2507"
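# This client assumes a vLLM OpenAI-compatible server is already listening at
# API_BASE_URL, launched e.g. with (illustrative; adjust path and port):
#   vllm serve /home/zwq/model/Qwen3-30B-A3B-Instruct-2507 --port 8000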


class Qwen3ChatClient:
    def __init__(self, api_base_url=API_BASE_URL, model_name=MODEL_NAME):
        self.api_base_url = api_base_url
        self.model_name = model_name
        self.history = []  # Reserved for multi-turn history; unused in this script.

    def _parse_response(self, text):
        """
        Parse the model response, separating the thinking content from the
        final answer.
        """
        thinking_content = ""
        main_content = text

        # re.DOTALL makes '.' match any character, including newlines.
        # The pattern captures the text between <think> and </think>,
        # plus everything that follows the closing tag.
        match = re.search(r'<think>(.*?)</think>(.*)', text, re.DOTALL)
        if match:
            thinking_content = match.group(1).strip()
            main_content = match.group(2).strip()
        # Without <think> tags, thinking_content stays empty and main_content
        # is the original text.
        return thinking_content, main_content
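    # Example (illustrative):
    #   _parse_response("<think>plan</think>Hi") -> ("plan", "Hi")
    #   _parse_response("Hi")                    -> ("", "Hi")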

    def generate_response(self, user_input, enable_thinking=True, conversation_history=None):
        """
        Send a request to the Qwen3 model and return its response.
        :param user_input: The user's message.
        :param enable_thinking: Whether to enable thinking mode. True (default)
                                enables it; False disables it.
        :param conversation_history: Optional conversation history, formatted as
                                     [{"role": "...", "content": "..."}]
        :return: (full_assistant_content, list_of_rank1_logprobs, token_texts)
                 full_assistant_content is the model's complete raw answer,
                 including any thinking content; token_texts holds the string
                 form of each generated token, aligned with the logprobs.
        """
        if conversation_history is None:
            conversation_history = [{"role": "system", "content": "You are a helpful assistant."}]

        # Append the user input to the current conversation history.
        current_messages = conversation_history + [{"role": "user", "content": user_input}]

        headers = {
            "Content-Type": "application/json"
        }

        # Fixed sampling parameters for deterministic (or near-deterministic) testing.
        temperature, top_p, top_k = 0.0, 1.0, 1

        payload = {
            "model": self.model_name,
            "messages": current_messages,
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k,
            "max_tokens": 8192,
            "stream": False,
            "logprobs": True,
            # NOTE: "extra_body" is a convention of the OpenAI Python SDK; with
            # a raw HTTP request the extra fields must sit at the top level of
            # the JSON body. vLLM's documented switch for Qwen3's thinking mode
            # is chat_template_kwargs={"enable_thinking": ...} (a no-op for the
            # non-thinking Instruct-2507 variant).
            "chat_template_kwargs": {
                "enable_thinking": enable_thinking
            }
        }
        
        try:
            response = requests.post(f"{self.api_base_url}/chat/completions", headers=headers, json=payload)
            response.raise_for_status()  # Raise on HTTP error status codes.

            response_data = response.json()

            if not response_data.get("choices"):
                print("Error: no choices found in the response.")
                return "", [], []

            full_assistant_content = response_data["choices"][0]["message"]["content"]

            # --- Extract logprobs ---
            # In vLLM's OpenAI-compatible response, the top-level "logprob"
            # field of each entry is the rank-1 logprob. The "token" field is
            # kept as well so callers can align logprobs with text offsets.
            list_of_rank1_logprobs = []
            token_texts = []
            logprobs_data = (response_data["choices"][0].get("logprobs") or {}).get("content") or []
            for token_logprob_info in logprobs_data:
                list_of_rank1_logprobs.append(token_logprob_info.get("logprob"))
                token_texts.append(token_logprob_info.get("token", ""))

            return full_assistant_content, list_of_rank1_logprobs, token_texts

        except requests.exceptions.HTTPError as e:
            print(f"HTTP request failed: {e}")
            print(f"Response body: {e.response.text}")
            return "", [], []
        except requests.exceptions.RequestException as e:
            print(f"Request failed: {e}")
            return "", [], []
        except json.JSONDecodeError as e:
            print(f"JSON decoding failed: {e} - response text: {response.text[:200]}...")
            return "", [], []
        except Exception as e:
            print(f"Unexpected error: {e}")
            return "", [], []

# --- Example usage ---
if __name__ == "__main__":
    chatbot = Qwen3ChatClient()
    print("Welcome to the Qwen3-30B-A3B chat client!")
    print(f"Connected to the vLLM server, using model: {MODEL_NAME}")
    print("--------------------------------------------------")
    
    # Ten hard-coded test questions (kept in Chinese: they are the model inputs).
    test_questions = [
        "介绍一下北京.",
        "写一首关于春天的五言绝句.",
        "请解释一下黑洞的形成原理.",
        "推荐三部值得一看的科幻电影,并简述理由.",
        "如何有效提高编程能力?",
        "给我讲一个关于人工智能的笑话.",
        "你认为未来教育会发展成什么样?",
        "如何制作一道美味的麻婆豆腐?",
        "量子计算的原理是什么?它有哪些潜在应用?",
        "请用英语介绍一下中国长城.",
    ]

    results_to_save = [] 

    for i, question in enumerate(test_questions):
        print(f"\n--- Question {i+1}: {question!r} ---")

        full_content, rank1_logprobs, token_texts = chatbot.generate_response(question, enable_thinking=True)

        thinking_part, main_answer = chatbot._parse_response(full_content)

        print(f"Full answer (including thinking): {full_content!r}")
        if thinking_part:
            print(f"[Thinking]: {thinking_part!r}")
        print(f"[Main answer]: {main_answer!r}")
        
        thinking_end_tag = '</think>'
        logprobs_after_thinking = []

        if thinking_end_tag in full_content and rank1_logprobs:
            # Accumulate the character lengths of the returned token strings
            # until the running offset passes the end of the </think> tag; the
            # remaining logprobs belong to the answer. This assumes the token
            # strings concatenate back to full_content exactly.
            # Example (illustrative): for tokens ["<think>", "ok", "</think>",
            # "\n", "Hi"], end_char_idx is 17, so logprobs are kept from the
            # "\n" token onward.
            end_char_idx = full_content.find(thinking_end_tag) + len(thinking_end_tag)
            current_decoded_length = 0
            for j, token_text in enumerate(token_texts):
                current_decoded_length += len(token_text)
                if current_decoded_length > end_char_idx:
                    logprobs_after_thinking = rank1_logprobs[j:]
                    break

            if not logprobs_after_thinking:
                print("Warning: could not accurately locate logprobs after the </think> tag. Using all logprobs.")
                logprobs_after_thinking = rank1_logprobs
        else:
            logprobs_after_thinking = rank1_logprobs


        print("\n答案部分前10个Token的Rank 1 Logprobs:")
        for k, logprob_val in enumerate(logprobs_after_thinking[:10]):
            print(f"  Step {k}: {logprob_val:.4f}")

       
        results_to_save.append({
            "input": question,
            "output": main_answer, 
            "logprobs_of_rank1_for_the_first_10_tokens": logprobs_after_thinking[:10]
        })
        print("--------------------------------------------------")


    output_filename_client_all_results = './Qwen3-30B-A3B-Instruct-2507_logprobs_K100AI_fp16.json'
    with open(output_filename_client_all_results, 'w', encoding='utf-8') as f:
        json.dump(results_to_save, f, indent=4, ensure_ascii=False)

    print(f"\nAll test results saved to: {output_filename_client_all_results}")