import json
import logging
import posixpath

import aiohttp
import requests
from fastapi import Depends, FastAPI, HTTPException, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from apps.web.models.users import Users
from config import OLLAMA_API_BASE_URL, WEBUI_AUTH
from constants import ERROR_MESSAGES
from utils.utils import decode_token, get_current_user
# FastAPI application serving the Ollama URL settings endpoints and a
# catch-all proxy to the Ollama API.
app = FastAPI()
# NOTE(review): wildcard origins combined with allow_credentials=True accepts
# credentialed cross-origin requests from any site; confirm this is intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Base URL of the upstream Ollama API. Kept on app.state rather than used as a
# module constant so that POST /url/update can change it at runtime.
app.state.OLLAMA_API_BASE_URL = OLLAMA_API_BASE_URL


@app.get("/url")
async def get_ollama_api_url(user=Depends(get_current_user)):
    """Return the upstream Ollama base URL currently in use (admin only).

    Returns:
        dict: ``{"OLLAMA_API_BASE_URL": <url>}``.

    Raises:
        HTTPException: 401 if there is no current user or the user is not an
            admin.
    """
    if not (user and user.role == "admin"):
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
    return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL}


class UrlUpdateForm(BaseModel):
    """Request body for POST /url/update."""

    url: str  # new Ollama API base URL


@app.post("/url/update")
async def update_ollama_api_url(
    form_data: UrlUpdateForm, user=Depends(get_current_user)
):
    """Replace the upstream Ollama base URL at runtime (admin only).

    A trailing ``/`` is stripped from the submitted URL: the proxy builds
    upstream URLs as ``f"{base}/{path}"``, so a trailing slash would put
    ``//`` into every proxied request path.

    Returns:
        dict: ``{"OLLAMA_API_BASE_URL": <stored url>}``.

    Raises:
        HTTPException: 401 if there is no current user or the user is not an
            admin.
    """
    if not (user and user.role == "admin"):
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)

    app.state.OLLAMA_API_BASE_URL = form_data.url.rstrip("/")
    return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL}


logger = logging.getLogger(__name__)

# Ollama endpoints that modify the model store; only admins may call them.
_ADMIN_ONLY_PATHS = frozenset({"pull", "delete", "push", "copy", "create"})

# Incoming headers that must not reach Ollama: the WebUI's own Host, the
# user's WebUI credentials, and browser origin information. Matched
# case-insensitively because ASGI delivers header names in lower case, so a
# literal pop("Host") / pop("Authorization") would never match.
_STRIPPED_REQUEST_HEADERS = frozenset({"host", "authorization", "origin", "referer"})


def _normalize_path(path: str) -> str:
    """Collapse ``.``/``..`` segments and redundant slashes in *path*.

    The result has no leading or trailing slash, so spellings such as
    ``"pull/"``, ``"//pull"`` and ``"x/../pull"`` all reduce to ``"pull"``.
    The same normalized value is used for the admin-only check and for the
    upstream URL, so what is authorized is exactly what is requested.
    """
    return posixpath.normpath("/" + path).strip("/")


def _forwardable_headers(request: Request) -> dict:
    """Copy *request*'s headers, omitting those in ``_STRIPPED_REQUEST_HEADERS``."""
    return {
        name: value
        for name, value in dict(request.headers).items()
        if name.lower() not in _STRIPPED_REQUEST_HEADERS
    }


@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_current_user)):
    """Forward a request to the configured Ollama API and stream the reply back.

    Users with role "user" or "admin" may proxy requests; endpoints in
    ``_ADMIN_ONLY_PATHS`` additionally require "admin". The method and body
    are forwarded unchanged; headers are forwarded minus
    ``_STRIPPED_REQUEST_HEADERS``. The query string is not forwarded.

    Returns:
        StreamingResponse relaying the upstream body with the upstream status.

    Raises:
        HTTPException: 401 if the user may not access the endpoint; the
            upstream status (or 500 when no response was received) if the
            upstream request fails, with an "Ollama: ..." or
            "Open WebUI: Server Connection Error" detail.
    """
    if user is None or user.role not in ("user", "admin"):
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)

    normalized_path = _normalize_path(path)
    if normalized_path in _ADMIN_ONLY_PATHS and user.role != "admin":
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)

    target_url = f"{app.state.OLLAMA_API_BASE_URL}/{normalized_path}"
    logger.debug("Proxying %s %s", request.method, target_url)

    body = await request.body()
    headers = _forwardable_headers(request)

    session = aiohttp.ClientSession()
    response = None
    try:
        response = await session.request(
            request.method, target_url, data=body, headers=headers
        )
        if not response.ok:
            # Buffer the error body *before* raise_for_status(), which releases
            # the connection; the handler below parses the cached body.
            await response.read()
            response.raise_for_status()
    except Exception as e:
        error_detail = "Open WebUI: Server Connection Error"
        if response is not None:
            try:
                res = await response.json()
                if "error" in res:
                    error_detail = f"Ollama: {res['error']}"
            except Exception:
                error_detail = f"Ollama: {e}"
        logger.warning("Ollama proxy request to %s failed: %s", target_url, e)
        await session.close()
        raise HTTPException(
            status_code=response.status if response is not None else 500,
            detail=error_detail,
        ) from e

    async def stream_body():
        # Close the session however streaming ends (completed, failed, or
        # cancelled on client disconnect) so upstream connections don't leak.
        try:
            async for chunk in response.content:
                yield chunk
        finally:
            await session.close()

    return StreamingResponse(stream_body(), response.status)