utils.py 2.24 KB
Newer Older
chenzk's avatar
v1.0  
chenzk committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
# Copyright (c) Alibaba, Inc. and its affiliates.

import time
import subprocess
from modelscope import snapshot_download as ms_snapshot_download
import multiprocessing as mp
import os
# Absolute path of the parent of the directory that contains this file
# (two `dirname` steps up from this file's absolute path).
project_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))


def max_retries(max_attempts):
    """Build a decorator that retries a function when it raises.

    The decorated function is called up to ``max_attempts`` times. Every
    failure is printed, and the wrapper pauses one second between
    attempts. A successful call returns its result immediately.

    Args:
        max_attempts: Maximum number of times the function is called.

    Returns:
        A decorator that wraps a callable with the retry loop.

    Raises:
        Exception: Raised by the wrapper once every attempt has failed.
            The last underlying error is attached as ``__cause__``.
    """
    # Imported locally so this block does not depend on a file-level import.
    import functools

    def decorator(func):
        # Preserve the wrapped function's name and docstring.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            attempts = 0
            last_error = None
            while attempts < max_attempts:
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    attempts += 1
                    last_error = e
                    print(f"Retry {attempts}/{max_attempts}: {e}")
                    # Only pause when another attempt will follow; sleeping
                    # after the final failure just delays the error.
                    if attempts < max_attempts:
                        time.sleep(1)
            # Chain the last failure so the root cause stays in the traceback.
            raise Exception(
                f"Max retries ({max_attempts}) exceeded."
            ) from last_error
        return wrapper
    return decorator


@max_retries(3)
def snapshot_download(*args, **kwargs):
    """Call ModelScope's ``snapshot_download`` under the ``max_retries(3)`` decorator.

    All positional and keyword arguments are forwarded unchanged, and the
    underlying call's return value is handed back as-is.
    """
    downloaded = ms_snapshot_download(*args, **kwargs)
    return downloaded


def pre_download_models():
    """Download every model snapshot listed below, one after another.

    Each entry is fetched with ``snapshot_download`` using the pinned
    revision, in the order listed. Any exception raised by a download
    propagates to the caller.
    """
    # (model id, revision) pairs. Each repository appears exactly once;
    # 'damo/face_chain_control_model' used to be requested twice.
    models = (
        ('ly261666/cv_portrait_model', 'v4.0'),
        ('YorickHe/majicmixRealistic_v6', 'v1.0.0'),
        ('damo/face_chain_control_model', 'v1.0.1'),
        ('ly261666/cv_wanx_style_model', 'v1.0.3'),
        ('Cherrytest/zjz_mj_jiyi_small_addtxt_fromleo', 'v1.0.0'),
        ('Cherrytest/rot_bgr', 'v1.0.0'),
        ('damo/face_frombase_c4', 'v1.0.0'),
    )
    for model_id, revision in models:
        snapshot_download(model_id, revision=revision)


def set_spawn_method():
    """Force the multiprocessing start method to ``'spawn'``.

    Prints a notice instead of raising if ``set_start_method`` reports a
    ``RuntimeError``.
    NOTE(review): with ``force=True`` that branch looks unreachable - confirm.
    """
    start_method = 'spawn'
    try:
        mp.set_start_method(start_method, force=True)
    except RuntimeError:
        print("spawn method already set")

def check_install(*args):
    """Report whether a command can be launched and exits successfully.

    Args:
        *args: The program name followed by its arguments, e.g.
            ``check_install("ffmpeg", "-version")``.

    Returns:
        True if the command started and exited with status 0. False if no
        command was given, the program could not be launched (``OSError``,
        e.g. it is not on PATH), or it exited with a non-zero status.
    """
    if not args:
        # Nothing to run; avoid a platform-dependent crash inside Popen.
        return False
    try:
        # Output is irrelevant for a yes/no probe, so discard it.
        completed = subprocess.run(
            args,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
    except OSError:
        return False
    # A non-zero exit means the probe failed; report False instead of raising.
    return completed.returncode == 0

def check_ffmpeg():
    """
    Return the result of probing for ffmpeg with ``ffmpeg -version``.
    """
    probe_command = ("ffmpeg", "-version")
    return check_install(*probe_command)


def get_worker_data_dir() -> str:
    """
    Return the path of the "worker_data" directory under ``project_dir``.

    Only the path is built; the directory is not created or checked here.
    """
    worker_data_path = os.path.join(project_dir, "worker_data")
    return worker_data_path


def join_worker_data_dir(*sub_paths) -> str:
    """
    Join the worker data directory with the given path components.

    Args:
        *sub_paths: Zero or more path components appended, in order, to
            the worker data directory. These are positional arguments.

    Returns:
        The combined path. With no components, the worker data directory
        itself is returned.
    """
    return os.path.join(get_worker_data_dir(), *sub_paths)