utils.py 5.03 KB
Newer Older
Timothy J. Baek's avatar
Timothy J. Baek committed
1
2
3
4
5
6
7
8
from fastapi import APIRouter, UploadFile, File, BackgroundTasks
from fastapi import Depends, HTTPException, status
from starlette.responses import StreamingResponse

from pydantic import BaseModel

import requests
import os
9
import aiohttp
Timothy J. Baek's avatar
Timothy J. Baek committed
10
import json
11
12
13

from utils.misc import calculate_sha256

14
from config import OLLAMA_API_BASE_URL, DATA_DIR, UPLOAD_DIR
Timothy J. Baek's avatar
Timothy J. Baek committed
15
16
from constants import ERROR_MESSAGES

Timothy J. Baek's avatar
Timothy J. Baek committed
17
18
19
20
21
22
23
24

# Router instance that the endpoint decorators in this module
# (`@router.get("/download")`, `@router.post("/upload")`) register on.
router = APIRouter()


# Pydantic model with a single required string field, `filename`.
# NOTE(review): not referenced by any endpoint visible in this part of the
# file -- presumably a request-body schema for a blob-upload route defined
# elsewhere; confirm before removing or changing it.
class UploadBlobForm(BaseModel):
    filename: str


Timothy J. Baek's avatar
Timothy J. Baek committed
25
26
27
28
from urllib.parse import urlparse


def parse_huggingface_url(hf_url):
    """Return the file name found at the end of a model URL's path.

    Only the path component of the URL is inspected, so a query string or
    fragment never becomes part of the result.

    Args:
        hf_url: URL of the model file, e.g.
            ``https://huggingface.co/<user>/<repo>/resolve/main/<file>``.

    Returns:
        The last ``/``-separated segment of the URL path. This is an empty
        string when the path is empty or ends with ``/``. ``None`` is
        returned when the URL cannot be parsed, or when the last segment is
        the ``.`` / ``..`` directory reference.
    """
    try:
        # urlparse is the only call here that can raise ValueError
        # (for example on a malformed IPv6 host).
        parsed_url = urlparse(hf_url)
    except ValueError:
        return None

    # Get the path and keep only its final component.
    model_file = parsed_url.path.split("/")[-1]

    # The /download handler joins this name onto UPLOAD_DIR; a dot segment
    # would resolve to a directory rather than a file inside it.
    if model_file in (".", ".."):
        return None

    return model_file
Timothy J. Baek's avatar
Timothy J. Baek committed
43
44


Timothy J. Baek's avatar
Timothy J. Baek committed
45
async def download_file_stream(url, file_path, file_name, chunk_size=1024 * 1024):
    """Download ``url`` into ``file_path`` and yield server-sent-event strings.

    If ``file_path`` already exists, its current size is sent as the start of
    an HTTP ``Range`` header and new data is appended to the file. After each
    received chunk a ``data: {...}`` event with ``progress``, ``completed``
    and ``total`` is yielded. Once the byte count matches the expected total,
    the file is hashed with ``calculate_sha256``, posted to
    ``{OLLAMA_API_BASE_URL}/blobs/sha256:<hash>``, removed from disk, and a
    final event with ``done``, ``blob`` and ``name`` is yielded.

    Args:
        url: Source URL to download.
        file_path: Local path the data is appended to.
        file_name: Name reported back in the final event.
        chunk_size: Number of bytes requested per read from the response.

    Yields:
        Strings of the form ``"data: <json>\\n\\n"``.

    Raises:
        Exception: If the blob POST does not return an OK status.
    """
    done = False

    # Resume support: start the request at the number of bytes already on disk.
    if os.path.exists(file_path):
        current_size = os.path.getsize(file_path)
    else:
        current_size = 0

    headers = {"Range": f"bytes={current_size}-"} if current_size > 0 else {}

    timeout = aiohttp.ClientTimeout(total=600)  # Set the timeout

    async with aiohttp.ClientSession(timeout=timeout) as session:
        async with session.get(url, headers=headers) as response:
            # content-length covers only the bytes still to be sent, so the
            # bytes already on disk are added to get the full size.
            total_size = int(response.headers.get("content-length", 0)) + current_size

            with open(file_path, "ab+") as file:
                async for data in response.content.iter_chunked(chunk_size):
                    current_size += len(data)
                    file.write(data)

                    done = current_size == total_size
                    # total_size is 0 when the server sends no content-length
                    # on a fresh download; report 0 instead of dividing by it.
                    if total_size:
                        progress = round((current_size / total_size) * 100, 2)
                    else:
                        progress = 0
                    yield f'data: {{"progress": {progress}, "completed": {current_size}, "total": {total_size}}}\n\n'

                if done:
                    file.seek(0)
                    hashed = calculate_sha256(file)
                    file.seek(0)

                    # NOTE(review): requests.post is a synchronous call made
                    # from inside this async generator.
                    blob_url = f"{OLLAMA_API_BASE_URL}/blobs/sha256:{hashed}"
                    blob_response = requests.post(blob_url, data=file)

                    if blob_response.ok:
                        res = {
                            "done": done,
                            "blob": f"sha256:{hashed}",
                            "name": file_name,
                        }
                        os.remove(file_path)

                        yield f"data: {json.dumps(res)}\n\n"
                    else:
                        # Raising a bare string is itself a TypeError in
                        # Python 3; raise a real exception instead.
                        raise Exception(
                            "Ollama: Could not create blob, Please try again."
                        )
Timothy J. Baek's avatar
Timothy J. Baek committed
89
90
91


@router.get("/download")
async def download(
    url: str,
):
    """Stream the download of the model file at ``url`` as server-sent events.

    The file name is taken from the URL with ``parse_huggingface_url``; the
    file is stored under ``UPLOAD_DIR`` with that name, and the events come
    from ``download_file_stream``.

    Example ``url``:
    https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf

    Returns:
        A ``StreamingResponse`` with media type ``text/event-stream``, or
        ``None`` when no file name could be derived from ``url``.
    """
    file_name = parse_huggingface_url(url)

    # Nothing to download without a usable file name.
    if not file_name:
        return None

    target_path = str(UPLOAD_DIR / file_name)
    event_stream = download_file_stream(url, target_path, file_name)

    return StreamingResponse(event_stream, media_type="text/event-stream")
Timothy J. Baek's avatar
Timothy J. Baek committed
107
108


Timothy J. Baek's avatar
Timothy J. Baek committed
109
@router.post("/upload")
def upload(file: UploadFile = File(...)):
    """Store an uploaded file, then stream its registration as a blob.

    The upload is first written under ``UPLOAD_DIR``. The returned
    ``text/event-stream`` response then re-reads the stored file in 1 MiB
    chunks, yielding a ``progress`` / ``total`` / ``completed`` event per
    chunk. It then hashes the file with ``calculate_sha256``, posts it to
    ``{OLLAMA_API_BASE_URL}/blobs/sha256:<hash>``, removes the local copy
    and yields a final event with ``done``, ``blob`` and ``name``. Any
    exception raised while streaming is reported as an ``error`` event.

    Args:
        file: The uploaded file.

    Returns:
        A ``StreamingResponse`` of server-sent events.

    Raises:
        HTTPException: 400 when the upload carries no usable file name.
    """
    # The file name is supplied by the client. Keep only its final path
    # component so a name such as "../../x" cannot land outside UPLOAD_DIR.
    safe_name = os.path.basename(file.filename or "")
    if safe_name in ("", ".", ".."):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid file name.",
        )

    file_path = UPLOAD_DIR / safe_name
    chunk_size = 1024 * 1024

    # Save file in chunks. Iterating the file object directly would split on
    # newline bytes rather than on a fixed chunk size.
    with file_path.open("wb+") as f:
        for chunk in iter(lambda: file.file.read(chunk_size), b""):
            f.write(chunk)

    def file_process_stream():
        # Emits the SSE events described in the enclosing function's docstring.
        total_size = os.path.getsize(str(file_path))
        try:
            with open(file_path, "rb") as f:
                total = 0

                # An empty read marks end of file and ends the loop.
                for chunk in iter(lambda: f.read(chunk_size), b""):
                    total += len(chunk)
                    progress = round((total / total_size) * 100, 2)

                    res = {
                        "progress": progress,
                        "total": total_size,
                        "completed": total,
                    }
                    yield f"data: {json.dumps(res)}\n\n"

                # Rewind before hashing and again before sending the body.
                f.seek(0)
                hashed = calculate_sha256(f)
                f.seek(0)

                url = f"{OLLAMA_API_BASE_URL}/blobs/sha256:{hashed}"
                response = requests.post(url, data=f)

                if response.ok:
                    res = {
                        "done": True,
                        "blob": f"sha256:{hashed}",
                        "name": file.filename,
                    }
                    os.remove(file_path)
                    yield f"data: {json.dumps(res)}\n\n"
                else:
                    raise Exception(
                        "Ollama: Could not create blob, Please try again."
                    )

        except Exception as e:
            # The response has already started streaming, so failures are
            # reported in-band as an event.
            res = {"error": str(e)}
            yield f"data: {json.dumps(res)}\n\n"

    return StreamingResponse(file_process_stream(), media_type="text/event-stream")