download_rife.py 3.56 KB
Newer Older
PengGao's avatar
PengGao committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
#!/usr/bin/env python3
# coding: utf-8

import argparse
import os
import shutil
import sys
import tempfile
import zipfile
from pathlib import Path

import requests


def get_base_dir():
    """Return the directory two levels above this script file.

    The script treats that directory as the project root.
    """
    script_file = Path(__file__)
    containing_dir = script_file.parent
    return containing_dir.parent


def download_file(url, save_path, timeout=30):
    """Stream the resource at ``url`` into the file ``save_path``.

    Progress is printed as a percentage whenever the server reports a
    ``content-length`` header.

    Args:
        url: Address of the file to download.
        save_path: Destination path; the file is opened in binary write mode,
            so existing content is overwritten.
        timeout: Value passed to ``requests.get`` as its ``timeout`` argument
            (seconds, or a ``(connect, read)`` tuple). It bounds each wait on
            the socket, not the total download time.

    Raises:
        requests.HTTPError: If the server answers with an error status
            (raised by ``raise_for_status``).
        requests.RequestException: On connection problems or timeouts.
        OSError: If ``save_path`` cannot be written.
    """
    print(f"Starting download: {url}")

    # With stream=True the connection stays open until the body is fully
    # consumed or the response is closed; the with-block guarantees the
    # connection is released even when writing to disk fails midway.
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()

        # A missing header yields 0, which disables the progress output.
        total_size = int(response.headers.get("content-length", 0))
        downloaded_size = 0

        with open(save_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                if not chunk:
                    continue
                f.write(chunk)
                downloaded_size += len(chunk)
                if total_size > 0:
                    progress = (downloaded_size / total_size) * 100
                    print(f"\rDownload progress: {progress:.1f}%", end="", flush=True)

    print(f"\nDownload completed: {save_path}")


def extract_zip(zip_path, extract_to):
    """Unpack every member of the archive ``zip_path`` into ``extract_to``.

    A message is printed before and after the extraction.
    """
    print(f"Starting extraction: {zip_path}")
    archive = zipfile.ZipFile(zip_path, "r")
    with archive:
        archive.extractall(path=extract_to)
    print(f"Extraction completed: {extract_to}")


def find_flownet_pkl(extract_dir):
    """Locate ``flownet.pkl`` anywhere below ``extract_dir``.

    The tree is traversed with ``os.walk``; the path of the first match is
    returned as a string, or ``None`` when no such file exists.
    """
    wanted = "flownet.pkl"
    candidates = (
        os.path.join(folder, wanted)
        for folder, _subdirs, filenames in os.walk(extract_dir)
        if wanted in filenames
    )
    return next(candidates, None)


def main(argv=None):
    """Download the RIFE archive and install ``flownet.pkl`` into a directory.

    The archive is downloaded and extracted inside a uniquely named scratch
    directory created under the directory returned by ``get_base_dir()``.
    ``flownet.pkl`` is then copied into the target directory given on the
    command line. The scratch directory is removed on every exit path, so
    neither the archive nor the extracted files are left behind on failure.

    Args:
        argv: Optional list of command-line arguments. ``None`` (the default)
            lets ``argparse`` read ``sys.argv``.

    Returns:
        ``0`` on success, ``1`` if ``flownet.pkl`` is not in the archive or
        any step raises an exception.
    """
    parser = argparse.ArgumentParser(description="Download RIFE model to specified directory")
    parser.add_argument("target_directory", help="Target directory path")

    args = parser.parse_args(argv)

    # Relative targets are interpreted against the current working directory.
    target_dir = Path(args.target_directory)
    if not target_dir.is_absolute():
        target_dir = Path.cwd() / target_dir

    zip_url = "https://huggingface.co/hzwer/RIFE/resolve/main/RIFEv4.26_0921.zip"
    zip_name = "RIFEv4.26_0921.zip"

    try:
        target_dir.mkdir(parents=True, exist_ok=True)

        # A fresh, uniquely named directory means the search below can only
        # see files from this archive, and the cleanup can only delete files
        # this run created. The context manager removes it even on errors.
        with tempfile.TemporaryDirectory(prefix="_temp_", dir=get_base_dir()) as scratch:
            temp_dir = Path(scratch)
            zip_path = temp_dir / zip_name

            # Download zip file
            download_file(zip_url, zip_path)

            # Extract file
            extract_zip(zip_path, temp_dir)

            # Find flownet.pkl file
            flownet_pkl = find_flownet_pkl(temp_dir)
            if not flownet_pkl:
                print("Error: flownet.pkl file not found")
                return 1

            # Copy flownet.pkl to target directory
            target_file = target_dir / "flownet.pkl"
            shutil.copy2(flownet_pkl, target_file)
            print(f"flownet.pkl copied to: {target_file}")

            # Leaving the with-block deletes the archive and extracted files.
            print("Cleaning up temporary files...")
    except Exception as e:
        # Top-level boundary: report the failure and signal it via exit code.
        print(f"Error: {e}")
        return 1

    print("RIFE model download and installation completed!")
    return 0

if __name__ == "__main__":
    # Run only when executed as a script; main()'s return value becomes the
    # process exit status.
    raise SystemExit(main())