model_manager.py 21.7 KB
Newer Older
wangkaixiong's avatar
init  
wangkaixiong committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import sys
import time
import json
import shutil
import glob
import requests
import subprocess
from pathlib import Path
from typing import Optional, Callable, Dict, Any

# 尝试导入modelscope和pycsghub
try:
    from modelscope.hub.snapshot_download import snapshot_download
    has_modelscope = True
except ImportError:
    print("Warning: modelscope not installed, download functionality will be limited")
    has_modelscope = False

try:
    from pycsghub.upload_large_folder.main import upload_large_folder_internal, create_repo
    from pycsghub.csghub_api import CsgHubApi
    has_pycsghub = True
except ImportError:
    print("Warning: pycsghub not installed, upload functionality will be limited")
    has_pycsghub = False


class ModelManager:
    """模型管理器,用于下载和上传模型"""
    
    def __init__(self):
        """初始化模型管理器"""
        self.default_download_path = os.path.expanduser("~/models")
        self.csghub_config = {
            "base_url": "http://10.17.27.227:4997",
            "token": "f5dad38a9426410aa861155cd184f84a",
            "repo_type": "model",
            "revision": "main"
        }
        
        # 确保默认下载路径存在
        os.makedirs(self.default_download_path, exist_ok=True)
    
    def download_model(self, model_id: str, local_path: str = None, 
                      progress_callback: Optional[Callable] = None,
                      cancel_check: Optional[Callable] = None) -> str:
        """
        从ModelScope下载模型
        
        Args:
            model_id: 模型ID,格式为"组织/模型名"
            local_path: 本地保存路径,默认为~/models
            progress_callback: 进度回调函数,接收(progress, detail)参数
            cancel_check: 取消检查函数,返回True表示已取消
        
        Returns:
            str: 下载的模型路径
        
        Raises:
            Exception: 下载失败时抛出异常
        """
        if not model_id:
            raise ValueError("模型ID不能为空")
        
        # 设置本地路径
        if not local_path:
            local_path = self.default_download_path
        
        # 确保本地路径存在
        os.makedirs(local_path, exist_ok=True)
        
        # 模型保存的完整路径
        model_path = os.path.join(local_path, model_id)
        
        # 打印下载信息到终端
        print(f"\n{'='*50}")
        print(f"开始下载模型")
        print(f"模型ID: {model_id}")
        print(f"本地路径: {model_path}")
        print(f"{'='*50}")
        
        # 如果模型已存在,先删除
        if os.path.exists(model_path):
            print(f"模型已存在,正在删除: {model_path}")
            shutil.rmtree(model_path)
        
        # 调用回调函数
        if progress_callback:
            progress_callback(0, f"开始下载模型 {model_id}")
        
        try:
            if has_modelscope:
                # 使用modelscope下载
                print(f"使用modelscope下载模型: {model_id}")
                print(f"下载目标路径: {model_path}")
                
                # 检查是否已取消
                if cancel_check and cancel_check():
                    print(f"下载任务已取消: {model_id}")
                    if progress_callback:
                        progress_callback(-1, {"error": "下载已取消"})
                    raise Exception("下载已取消")
                
                # 注意:modelscope不直接支持进度回调,我们将在下载后计算文件数量
                # 使用进程池执行snapshot_download,以便可以强制终止
                import multiprocessing
                
                # 定义下载函数
                def download_func():
                    try:
                        return snapshot_download(
                            model_id=model_id,
                            cache_dir=local_path,
                            revision="master"
                        )
                    except Exception as e:
                        print(f"下载出错: {e}")
                        raise
                
                # 创建进程
                process = multiprocessing.Process(target=download_func)
                process.daemon = True
                process.start()
                
                # 定期检查是否已取消
                while process.is_alive():
                    if cancel_check and cancel_check():
                        print(f"下载任务已取消: {model_id}")
                        # 强制终止进程
                        process.terminate()
                        process.join(timeout=1)
                        if process.is_alive():
                            process.kill()
                        if progress_callback:
                            progress_callback(-1, {"error": "下载已取消"})
                        raise Exception("下载已取消")
                    time.sleep(0.1)
                
                # 检查进程是否正常退出
                if process.exitcode != 0:
                    raise Exception("下载进程异常退出")
                
                print(f"modelscope下载完成,正在处理文件...")
                
                # 下载完成后,计算文件数量并更新进度
                if os.path.exists(model_path):
                    # 获取文件列表
                    all_files = []
                    for root, dirs, files in os.walk(model_path):
                        # 检查是否已取消
                        if cancel_check and cancel_check():
                            print(f"下载任务已取消: {model_id}")
                            if progress_callback:
                                progress_callback(-1, {"error": "下载已取消"})
                            raise Exception("下载已取消")
                        for file in files:
                            all_files.append(os.path.join(root, file))
                    
                    file_count = len(all_files)
                    print(f"发现 {file_count} 个文件")
                    
                    # 按文件数量更新进度
                    for i, file_path in enumerate(all_files):
                        # 检查是否已取消
                        if cancel_check and cancel_check():
                            print(f"下载任务已取消: {model_id}")
                            if progress_callback:
                                progress_callback(-1, {"error": "下载已取消"})
                            raise Exception("下载已取消")
                        progress = int((i + 1) / file_count * 100)
                        rel_path = os.path.relpath(file_path, model_path)
                        file_size = os.path.getsize(file_path)
                        print(f"[{progress}%] 已下载: {rel_path} ({self.get_dir_size(file_path)})")
                        if progress_callback:
                            progress_callback(progress, {
                                "file_count": i + 1,
                                "total_files": file_count,
                                "current_file": rel_path,
                                "file_size": file_size
                            })
                        time.sleep(0.05)  # 减少延迟
            else:
                # 直接使用modelscope下载(不使用模拟模式)
                print(f"modelscope未安装,无法下载模型: {model_id}")
                raise Exception("modelscope未安装,无法下载模型")
            
            if progress_callback:
                progress_callback(100, {
                    "file_count": file_count,
                    "total_files": file_count,
                    "current_file": "完成",
                    "message": f"模型 {model_id} 下载完成"
                })
            
            # 打印下载完成信息
            print(f"\n{'='*50}")
            print(f"模型下载完成!")
            print(f"模型ID: {model_id}")
            print(f"下载路径: {model_path}")
            print(f"文件数量: {file_count}")
            print(f"{'='*50}")
            
            return model_path
            
        except Exception as e:
            error_msg = str(e)
            if progress_callback:
                progress_callback(-1, {"error": error_msg})
            
            # 打印错误信息
            print(f"\n{'='*50}")
            print(f"下载失败!")
            print(f"模型ID: {model_id}")
            print(f"错误信息: {error_msg}")
            print(f"{'='*50}")
            
            raise Exception(f"下载模型 {model_id} 失败: {error_msg}")
    
    def upload_model(self, local_path: str, repo_id: str, 
                    create_repo_flag: bool = True,
                    progress_callback: Optional[Callable] = None) -> Dict[str, Any]:
        """
        上传模型到CsgHub
        
        Args:
            local_path: 本地模型路径
            repo_id: 仓库ID
            create_repo_flag: 是否创建仓库
            progress_callback: 进度回调函数,接收(progress, detail)参数
        
        Returns:
            Dict[str, Any]: 上传结果
        
        Raises:
            Exception: 上传失败时抛出异常
        """
        if not local_path or not os.path.exists(local_path):
            raise ValueError(f"本地路径 {local_path} 不存在")
        
        if not repo_id:
            raise ValueError("仓库ID不能为空")
        
        # 调用回调函数
        if progress_callback:
            progress_callback(0, f"开始上传模型到仓库 {repo_id}")
        
        try:
            # 首先获取所有文件列表,用于计算进度
            all_files = []
            for root, dirs, files in os.walk(local_path):
                for file in files:
                    file_path = os.path.join(root, file)
                    all_files.append(file_path)
            
            file_count = len(all_files)
            
            if file_count == 0:
                raise ValueError(f"本地路径 {local_path} 中没有文件")
            
            if has_pycsghub:
                # 使用pycsghub上传
                csg_api = CsgHubApi()
                use_full_repo_id = f"root/{repo_id}"
                
                # 创建仓库
                if create_repo_flag:
                    if progress_callback:
                        progress_callback(5, "正在创建仓库...")
                    
                    create_repo(
                        api=csg_api,
                        repo_id=use_full_repo_id,
                        repo_type=self.csghub_config["repo_type"],
                        revision=self.csghub_config["revision"],
                        endpoint=self.csghub_config["base_url"],
                        token=self.csghub_config["token"]
                    )
                
                # 上传模型
                if progress_callback:
                    progress_callback(10, f"准备上传 {file_count} 个文件...")
                
                # 创建一个自定义的进度回调函数
                def custom_upload_callback(current_file_index, current_file_path, total_files):
                    """自定义上传进度回调"""
                    progress = int((current_file_index + 1) / total_files * 90) + 10  # 10% - 100%
                    rel_path = os.path.relpath(current_file_path, local_path)
                    if progress_callback:
                        progress_callback(progress, f"上传中 {current_file_index + 1}/{total_files}: {rel_path}")
                
                # 执行上传 - 注意:pycsghub可能不直接支持文件级别的进度回调
                # 这里我们将在上传完成后模拟文件级别的进度
                upload_large_folder_internal(
                    repo_id=use_full_repo_id,
                    local_path=local_path,
                    repo_type=self.csghub_config["repo_type"],
                    revision=self.csghub_config["revision"],
                    endpoint=self.csghub_config["base_url"],
                    token=self.csghub_config["token"],
                    allow_patterns=None,
                    ignore_patterns=None,
                    num_workers=1,
                    print_report=False,
                    print_report_every=1,
                )
                
                # 上传完成后,模拟文件级别的进度更新
                for i, file_path in enumerate(all_files):
                    progress = int((i + 1) / file_count * 90) + 10  # 10% - 100%
                    rel_path = os.path.relpath(file_path, local_path)
                    if progress_callback:
                        progress_callback(progress, f"已上传 {i + 1}/{file_count}: {rel_path}")
                    time.sleep(0.05)  # 模拟处理延迟
            else:
                # 直接使用pycsghub上传(不使用模拟模式)
                print(f"pycsghub未安装,无法上传模型: {repo_id}")
                raise Exception("pycsghub未安装,无法上传模型")
            
            if progress_callback:
                progress_callback(100, f"模型上传完成,仓库ID: {repo_id},共上传 {file_count} 个文件")
            
            return {
                "success": True,
                "repo_id": repo_id,
                "file_count": file_count,
                "message": f"模型上传成功,共上传 {file_count} 个文件"
            }
            
        except Exception as e:
            if progress_callback:
                progress_callback(-1, f"上传失败: {str(e)}")
            raise Exception(f"上传模型失败: {str(e)}")
    


    def list_models(self, local_path: str = None) -> list:
        """
        列出本地模型
        
        Args:
            local_path: 本地模型路径,默认为~/models
        
        Returns:
            list: 模型列表
        """
        if not local_path:
            local_path = self.default_download_path
        
        if not os.path.exists(local_path):
            print(f"Model path does not exist: {local_path}")
            return []
        
        models = []
        
        try:
            print(f"Listing models from: {local_path}")
            items = os.listdir(local_path)
            print(f"Found {len(items)} items in directory")
            
            # 遍历一级目录
            for item in items:
                item_path = os.path.join(local_path, item)
                print(f"Checking item: {item} (type: {'dir' if os.path.isdir(item_path) else 'file'})")
                
                if os.path.isdir(item_path):
                    # 检查一级目录下的二级子目录
                    try:
                        sub_items = os.listdir(item_path)
                        print(f"  Found {len(sub_items)} sub-items in {item}")
                        
                        for sub_item in sub_items:
                            sub_item_path = os.path.join(item_path, sub_item)
                            
                            if os.path.isdir(sub_item_path):
                                print(f"  Checking sub-directory: {sub_item}")
                                
                                # 检查是否有 README.md
                                has_readme = os.path.exists(os.path.join(sub_item_path, "README.md"))
                                
                                # 检查是否有 .safetensors 或 .bin 文件
                                has_safetensors_or_bin = False
                                try:
                                    for file in os.listdir(sub_item_path):
                                        if file.endswith('.safetensors') or file.endswith('.bin'):
                                            has_safetensors_or_bin = True
                                            break
                                except Exception as e:
                                    print(f"    Error checking files in {sub_item_path}: {e}")
                                    continue
                                
                                print(f"    - README.md: {has_readme}")
                                print(f"    - has .safetensors or .bin: {has_safetensors_or_bin}")
                                
                                # 判断是否为模型目录(必须有README.md和.safetensors/.bin文件)
                                if has_readme and has_safetensors_or_bin:
                                    # 获取模型信息,使用前端期望的字段名
                                    model_info = {
                                        "id": sub_item,  # 使用二级目录名作为id
                                        "path": sub_item_path,
                                        "size": self.get_dir_size(sub_item_path),
                                        "status": "downloaded",  # 默认状态
                                        "downloadTime": self.get_dir_creation_time(sub_item_path),
                                        "uploadTime": None,
                                        "upload_repo_id": None,
                                        "file_count": 0  # 计算文件数量
                                    }
                                    
                                    # 计算文件数量
                                    file_count = 0
                                    for root, dirs, files in os.walk(sub_item_path):
                                        file_count += len(files)
                                    model_info["file_count"] = file_count
                                    
                                    models.append(model_info)
                                    print(f"    + Added as model: {sub_item} (in {item}/{sub_item})")
                                else:
                                    print(f"    - Skipped (missing required files)")
                            else:
                                print(f"    - Skipped (not a directory): {sub_item}")
                                
                    except Exception as e:
                        print(f"  Error processing sub-directories in {item_path}: {e}")
                        continue
                else:
                    print(f"  - Skipped (not a directory): {item}")
            
            print(f"Total models found: {len(models)}")
            
        except Exception as e:
            print(f"Error listing models: {e}")
        
        return models
    
    def get_dir_size(self, path: str) -> str:
        """
        获取目录大小
        
        Args:
            path: 目录路径
        
        Returns:
            str: 格式化的大小字符串
        """
        total_size = 0
        
        for root, dirs, files in os.walk(path):
            for file in files:
                file_path = os.path.join(root, file)
                total_size += os.path.getsize(file_path)
        
        # 格式化大小
        if total_size < 1024:
            return f"{total_size}B"
        elif total_size < 1024 * 1024:
            return f"{total_size / 1024:.1f}KB"
        elif total_size < 1024 * 1024 * 1024:
            return f"{total_size / (1024 * 1024):.1f}MB"
        else:
            return f"{total_size / (1024 * 1024 * 1024):.1f}GB"
    
    def get_dir_creation_time(self, path: str) -> str:
        """
        获取目录创建时间
        
        Args:
            path: 目录路径
        
        Returns:
            str: 格式化的时间字符串
        """
        try:
            # 获取目录创建时间
            stat = os.stat(path)
            
            # 尝试获取创建时间,不同系统可能有不同的属性
            if hasattr(stat, 'st_birthtime'):  # macOS
                creation_time = stat.st_birthtime
            else:  # Linux
                creation_time = stat.st_mtime  # 使用修改时间作为创建时间
            
            return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(creation_time))
        except Exception:
            return "未知"
    
    def delete_model(self, model_path: str) -> bool:
        """
        删除本地模型
        
        Args:
            model_path: 模型完整路径
        
        Returns:
            bool: 是否删除成功
        """
        if not model_path or not os.path.exists(model_path):
            print(f"[DEBUG] 模型路径不存在: {model_path}")
            return False
        
        try:
            print(f"[DEBUG] 开始删除模型目录: {model_path}")
            shutil.rmtree(model_path)
            print(f"[DEBUG] 模型目录删除成功: {model_path}")
            return True
        except Exception as e:
            print(f"[DEBUG] 删除模型失败: {str(e)}")
            return False


# Smoke test
if __name__ == "__main__":
    # A stand-in manager is defined right here (rather than pulling in the
    # real one) to avoid a circular import.
    class TestModelManager:
        """Stub model manager that returns canned data."""

        def list_models(self):
            """Return a fixed list of sample model records."""
            sample_rows = (
                ("test-model-1", "1.2GB", "2024-01-01 10:00:00"),
                ("test-model-2", "800MB", "2024-01-02 14:30:00"),
            )
            return [
                {"model_id": name, "size": size, "created_at": created}
                for name, size, created in sample_rows
            ]

    manager = TestModelManager()

    # Exercise the listing call and print each record.
    print("测试列出模型...")
    models = manager.list_models()
    for model in models:
        model_id, size, created_at = model['model_id'], model['size'], model['created_at']
        print(f"模型: {model_id}, 大小: {size}, 创建时间: {created_at}")

    print("\n注意: 这是一个简化的测试模式,完整功能需要通过app.py运行")