import asyncio
import functools
import json
import os

import aiohttp
import requests
from fastapi import APIRouter, UploadFile, File, BackgroundTasks
from fastapi import Depends, HTTPException, status
from pydantic import BaseModel
from starlette.responses import StreamingResponse, FileResponse

from utils.utils import get_admin_user
from utils.misc import calculate_sha256, get_gravatar_url

from config import OLLAMA_BASE_URLS, DATA_DIR, UPLOAD_DIR
from constants import ERROR_MESSAGES

# Router collecting the utility endpoints defined in this module.
router = APIRouter()


class UploadBlobForm(BaseModel):
    """Request model with a single ``filename`` field.

    NOTE(review): not referenced by the endpoints in this module; confirm
    whether it is still used elsewhere.
    """

    # Name of the file the request refers to.
    filename: str


from urllib.parse import urlparse


def parse_huggingface_url(hf_url):
    """Extract the model file name from a Hugging Face download URL.

    For example
    ``https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf``
    yields ``"stablelm-zephyr-3b.Q2_K.gguf"``. Query strings and fragments
    are ignored because only the URL path is inspected.

    Args:
        hf_url: The URL to parse.

    Returns:
        The last component of the URL path, or ``None`` when the URL cannot
        be parsed or does not end in a usable file name (empty, ``"."`` or
        ``".."``). Callers append the result to a directory path, so dot
        segments are rejected to keep it from pointing outside that directory.
    """
    try:
        parsed_url = urlparse(hf_url)
    except ValueError:
        # urlparse raises ValueError for malformed input, e.g. an unbalanced
        # IPv6 bracket in the network location.
        return None

    model_file = parsed_url.path.split("/")[-1]
    if model_file in ("", ".", ".."):
        return None
    return model_file


async def download_file_stream(url, file_path, file_name, chunk_size=1024 * 1024):
    """Download ``url`` to ``file_path`` and register it as an Ollama blob.

    Async generator yielding Server-Sent-Event ``data:`` lines:

    * ``{"progress", "completed", "total"}`` after every received chunk,
    * ``{"done", "blob", "name"}`` once the blob has been created on the
      first server in ``OLLAMA_BASE_URLS``,
    * ``{"error": ...}`` if anything fails (the error is streamed to the
      client instead of being raised mid-response).

    If ``file_path`` already exists, the download resumes with an HTTP
    ``Range`` request. If the server ignores the range and replies with the
    full body (anything other than 206), the local file is truncated and the
    download restarts, so data is never duplicated. If the stream ends before
    ``total`` bytes arrive, the generator stops without a final event and
    keeps the partial file for a later resume.

    Args:
        url: Source URL of the file.
        file_path: Local path the data is written to; removed once the blob
            has been created.
        file_name: Name reported back in the final event.
        chunk_size: Maximum number of bytes read from the response at a time.
    """
    try:
        if os.path.exists(file_path):
            current_size = os.path.getsize(file_path)
        else:
            current_size = 0

        headers = {"Range": f"bytes={current_size}-"} if current_size > 0 else {}

        # NOTE(review): ``total`` limits the whole request, body included, so a
        # download taking longer than 600 s is aborted (a later call resumes
        # it). A ``sock_read`` timeout may be what was intended.
        timeout = aiohttp.ClientTimeout(total=600)

        done = False
        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.get(url, headers=headers) as response:
                # Never write an HTTP error body into the model file.
                response.raise_for_status()

                if current_size > 0 and response.status != 206:
                    # The server ignored the Range header and is sending the
                    # whole file; start over instead of appending a second copy.
                    current_size = 0
                    mode = "wb+"
                else:
                    mode = "ab+"

                # Content-Length counts only the bytes still to be sent.
                total_size = (
                    int(response.headers.get("content-length", 0)) + current_size
                )

                with open(file_path, mode) as file:
                    async for data in response.content.iter_chunked(chunk_size):
                        current_size += len(data)
                        file.write(data)

                        done = current_size == total_size
                        # total_size is 0 when Content-Length is missing.
                        progress = (
                            round((current_size / total_size) * 100, 2)
                            if total_size
                            else 0
                        )
                        res = {
                            "progress": progress,
                            "completed": current_size,
                            "total": total_size,
                        }
                        yield f"data: {json.dumps(res)}\n\n"

                    if not done:
                        return

                    # Hashing and uploading a multi-GB file are blocking calls;
                    # run them in the default executor so the event loop keeps
                    # serving other requests meanwhile.
                    loop = asyncio.get_running_loop()
                    file.seek(0)
                    hashed = await loop.run_in_executor(None, calculate_sha256, file)
                    file.seek(0)

                    blob_url = f"{OLLAMA_BASE_URLS[0]}/blobs/sha256:{hashed}"
                    blob_response = await loop.run_in_executor(
                        None, functools.partial(requests.post, blob_url, data=file)
                    )
                    if not blob_response.ok:
                        raise Exception(
                            "Ollama: Could not create blob, Please try again."
                        )

        # Delete the local copy only after it has been closed.
        os.remove(file_path)
        res = {"done": done, "blob": f"sha256:{hashed}", "name": file_name}
        yield f"data: {json.dumps(res)}\n\n"
    except Exception as e:
        res = {"error": str(e)}
        yield f"data: {json.dumps(res)}\n\n"


@router.get("/download")
async def download(
    url: str,
):
    """Stream the download of a model file into an Ollama blob.

    Example ``url``:
    ``https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf``

    Returns:
        A ``text/event-stream`` response fed by ``download_file_stream``.

    Raises:
        HTTPException: 400 if no file name can be derived from ``url``
            (previously the endpoint answered 200 with a ``null`` body).
    """
    # NOTE(review): unlike /db/download, this endpoint declares no auth
    # dependency, so callers can make the server fetch an arbitrary URL —
    # confirm this is intended.
    file_name = parse_huggingface_url(url)
    if not file_name:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid URL: could not determine a file name to download.",
        )

    file_path = f"{UPLOAD_DIR}/{file_name}"
    return StreamingResponse(
        download_file_stream(url, file_path, file_name),
        media_type="text/event-stream",
    )


@router.post("/upload")
def upload(file: UploadFile = File(...)):
    """Save an uploaded file and register it as an Ollama blob.

    The upload is first written to ``UPLOAD_DIR``. The returned
    ``text/event-stream`` response then reads it back and yields:

    * ``{"progress", "total", "completed"}`` per chunk read,
    * ``{"done", "blob", "name"}`` once the blob has been created on the
      first server in ``OLLAMA_BASE_URLS``,
    * ``{"error": ...}`` if hashing or blob creation fails.

    The local copy is deleted after the blob has been created.

    Raises:
        HTTPException: 400 if the upload has no usable file name.
    """
    chunk_size = 1024 * 1024

    # ``filename`` is client-controlled; keep only its final component so a
    # name such as "../../etc/passwd" cannot escape UPLOAD_DIR.
    safe_name = os.path.basename(file.filename or "")
    if safe_name in ("", ".", ".."):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid file name.",
        )
    file_path = f"{UPLOAD_DIR}/{safe_name}"

    # Save the file in fixed-size chunks. Iterating the file object directly
    # would split binary data on newline bytes instead.
    with open(file_path, "wb+") as f:
        for chunk in iter(lambda: file.file.read(chunk_size), b""):
            f.write(chunk)

    def file_process_stream():
        """Yield progress events while hashing the saved file into a blob."""
        total_size = os.path.getsize(file_path)
        try:
            with open(file_path, "rb") as f:
                total = 0
                # total_size > 0 whenever this loop body runs, so no division
                # by zero is possible here.
                for chunk in iter(lambda: f.read(chunk_size), b""):
                    total += len(chunk)
                    progress = round((total / total_size) * 100, 2)

                    res = {
                        "progress": progress,
                        "total": total_size,
                        "completed": total,
                    }
                    yield f"data: {json.dumps(res)}\n\n"

                f.seek(0)
                hashed = calculate_sha256(f)
                f.seek(0)

                url = f"{OLLAMA_BASE_URLS[0]}/blobs/sha256:{hashed}"
                response = requests.post(url, data=f)
                if not response.ok:
                    raise Exception(
                        "Ollama: Could not create blob, Please try again."
                    )

            # Delete the local copy only after it has been closed.
            os.remove(file_path)
            res = {
                "done": True,
                "blob": f"sha256:{hashed}",
                "name": file.filename,
            }
            yield f"data: {json.dumps(res)}\n\n"
        except Exception as e:
            res = {"error": str(e)}
            yield f"data: {json.dumps(res)}\n\n"

    return StreamingResponse(file_process_stream(), media_type="text/event-stream")


@router.get("/gravatar")
async def get_gravatar(
    email: str,
):
    """Return the Gravatar URL that ``get_gravatar_url`` builds for ``email``."""
    gravatar_url = get_gravatar_url(email)
    return gravatar_url


@router.get("/db/download")
async def download_db(user=Depends(get_admin_user)):
    """Send ``webui.db`` from ``DATA_DIR`` as a binary attachment.

    The route depends on ``get_admin_user`` (presumably restricting it to
    admins — confirm); the resolved ``user`` is not otherwise used.
    """
    db_path = f"{DATA_DIR}/webui.db"
    return FileResponse(
        db_path,
        media_type="application/octet-stream",
        filename="webui.db",
    )