# simple_test.py (5.45 KB) -- last committed by yangzhong
# NOTE: the bare numbers that follow are line-number residue from a web code viewer.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import argparse
import threading
import time
import requests
from collections import defaultdict
import random

class TTSStressTester:
    """Multi-threaded load generator for one or more TTS HTTP endpoints.

    Every worker thread POSTs ``requests_per_thread`` JSON payloads, taking
    target URLs round-robin from ``urls``. Latencies, HTTP status codes and
    error types are accumulated in ``self.stats`` and printed as a report
    once ``run()`` has joined all workers.

    Args:
        urls: Non-empty sequence of endpoint URLs to rotate through.
        data: Base JSON payload (a dict). It is copied for every request and
            the copy's ``"text"`` field is replaced by random digits, so the
            caller's dict is never modified.
        concurrency: Number of worker threads to start.
        requests_per_thread: Number of requests each worker sends.
    """

    def __init__(self, urls, data, concurrency, requests_per_thread):
        self.urls = urls
        self.data = data
        self.concurrency = concurrency
        self.requests_per_thread = requests_per_thread
        self.stats = {
            'total': 0,
            'success': 0,
            'fail': 0,
            'durations': [],
            'status_codes': defaultdict(int),
            'errors': defaultdict(int)
        }
        self.lock = threading.Lock()  # guards self.stats
        self.current_url_index = 0
        self.url_lock = threading.Lock()  # guards the round-robin URL index

    def _get_next_url(self):
        """Return the next URL in round-robin order (safe to call from any worker)."""
        with self.url_lock:
            url = self.urls[self.current_url_index]
            self.current_url_index = (self.current_url_index + 1) % len(self.urls)
        return url

    def _build_payload(self):
        """Return a fresh per-request payload with a random ``"text"`` value.

        The text is five comma-separated groups of five random digits; a new
        value per request keeps the server-side (vllm) cache from being hit.
        A copy is returned so concurrent workers never share or overwrite
        each other's payload, and ``self.data`` stays untouched.
        """
        text = ",".join(
            "".join(str(random.randint(0, 9)) for _ in range(5))
            for _ in range(5)
        )
        return {**self.data, "text": text}

    def _record_failure(self, error_name, elapsed):
        """Count one request that raised before a response was obtained."""
        with self.lock:
            # Failed attempts are still requests: keep total == success + fail.
            self.stats['total'] += 1
            self.stats['fail'] += 1
            self.stats['errors'][error_name] += 1
            self.stats['durations'].append(elapsed)

    def _send_request(self):
        """Send one POST request and record its outcome in ``self.stats``.

        A request counts as a success only when the response status is 200
        and its ``Content-Type`` header contains ``audio``. Any exception
        raised while preparing or sending the request is counted as a
        failure under the exception's class name; it is never propagated.
        """
        # perf_counter is monotonic, so intervals are immune to clock changes.
        started = time.perf_counter()
        try:
            payload = self._build_payload()
            target_url = self._get_next_url()
            response = requests.post(target_url, json=payload, timeout=10)
        except Exception as exc:
            # Deliberately broad: in a load test every failure mode
            # (timeout, refused connection, bad URL, ...) is just a data point.
            self._record_failure(type(exc).__name__, time.perf_counter() - started)
            return

        elapsed = time.perf_counter() - started
        with self.lock:
            self.stats['durations'].append(elapsed)
            self.stats['status_codes'][response.status_code] += 1
            self.stats['total'] += 1
            if response.status_code != 200:
                self.stats['fail'] += 1
            elif 'audio' in response.headers.get('Content-Type', ''):
                self.stats['success'] += 1
            else:
                # HTTP 200 without an audio body is still a failed synthesis.
                self.stats['fail'] += 1
                self.stats['errors']['invalid_content_type'] += 1

    def _worker(self):
        """Thread body: send ``requests_per_thread`` requests sequentially."""
        for _ in range(self.requests_per_thread):
            self._send_request()

    def run(self):
        """Start all worker threads, wait for them to finish, print the report."""
        threads = []
        start_time = time.perf_counter()

        for _ in range(self.concurrency):
            thread = threading.Thread(target=self._worker)
            thread.start()
            threads.append(thread)

        for thread in threads:
            thread.join()

        total_time = time.perf_counter() - start_time
        self._generate_report(total_time)

    def _generate_report(self, total_time):
        """Print the aggregated statistics to stdout.

        Args:
            total_time: Wall-clock duration of the whole run in seconds. A
                non-positive value is tolerated and reported as zero
                throughput instead of raising ``ZeroDivisionError``.
        """
        durations = self.stats['durations']
        total_requests = self.stats['total']

        print(f"\n{' 测试报告 ':=^40}")
        print(f"总请求时间: {total_time:.2f}秒")
        print(f"总请求量: {total_requests}")
        print(f"成功请求: {self.stats['success']}")
        print(f"失败请求: {self.stats['fail']}")

        if durations:
            avg_duration = sum(durations) / len(durations)
            max_duration = max(durations)
            min_duration = min(durations)
            print(f"\n响应时间统计:")
            print(f"平均: {avg_duration:.3f}秒")
            print(f"最大: {max_duration:.3f}秒")
            print(f"最小: {min_duration:.3f}秒")

            sorted_durations = sorted(durations)
            for p in [50, 90, 95, 99]:
                # p < 100, so the index is always < len(sorted_durations).
                index = int(p / 100 * len(sorted_durations))
                print(f"P{p}: {sorted_durations[index]:.3f}秒")

        print("\n状态码分布:")
        for code, count in self.stats['status_codes'].items():
            print(f"HTTP {code}: {count}次")

        if self.stats['errors']:
            print("\n错误统计:")
            for error, count in self.stats['errors'].items():
                print(f"{error}: {count}次")

        # Guard against a zero elapsed time (e.g. nothing was run).
        throughput = total_requests / total_time if total_time > 0 else 0.0
        print(f"\n吞吐量: {throughput:.2f} 请求/秒")

if __name__ == "__main__":
    # Script entry point: read the CLI options, echo the effective
    # configuration, then run the stress test until done or interrupted.
    cli = argparse.ArgumentParser(description='TTS服务压力测试脚本')
    cli.add_argument(
        '--urls',
        nargs='+',
        default=['http://localhost:6006/tts'],
        help='TTS服务地址列表(多个用空格分隔)',
    )
    cli.add_argument('--text', type=str, default='测试文本', help='需要合成的文本内容')
    cli.add_argument('--character', type=str, default='jay_klee', help='合成角色名称')
    cli.add_argument('--concurrency', type=int, default=16, help='并发线程数')
    cli.add_argument('--requests', type=int, default=5, help='每个线程的请求数')
    opts = cli.parse_args()

    # Base JSON payload handed to the tester.
    stress_tester = TTSStressTester(
        urls=opts.urls,
        data={"text": opts.text, "character": opts.character},
        concurrency=opts.concurrency,
        requests_per_thread=opts.requests,
    )

    # Startup banner, one entry per printed line.
    banner_lines = (
        "开始压力测试,配置参数:",
        f"目标服务: {', '.join(opts.urls)}",
        f"并发线程: {opts.concurrency}",
        f"单线程请求数: {opts.requests}",
        f"总预计请求量: {opts.concurrency * opts.requests}",
        f"{' 测试启动 ':=^40}",
    )
    for banner_line in banner_lines:
        print(banner_line)

    try:
        stress_tester.run()
    except KeyboardInterrupt:
        # Ctrl-C: report the interruption instead of dumping a traceback.
        print("\n测试被用户中断")