main.py 3.75 KB
Newer Older
1
2
3
from fastapi import FastAPI, Request, Response, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
Timothy J. Baek's avatar
Timothy J. Baek committed
4
5
6

import requests
import json
7
from pydantic import BaseModel
Timothy J. Baek's avatar
Timothy J. Baek committed
8

9
10
from apps.web.models.users import Users
from constants import ERROR_MESSAGES
11
from utils.utils import decode_token, get_current_user
12
from config import OLLAMA_API_BASE_URL, WEBUI_AUTH
Timothy J. Baek's avatar
Timothy J. Baek committed
13

Timothy J. Baek's avatar
Timothy J. Baek committed
14
15
import aiohttp

16
17
18
19
20
21
22
23
# ASGI application for this module; every route below is registered on it.
app = FastAPI()

# Fully permissive CORS: any origin, method and header, credentials allowed.
# NOTE(review): "*" origins combined with allow_credentials=True is very
# permissive — confirm this is intended for this deployment.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)

# Upstream Ollama base URL; admins can change it at runtime via /url/update,
# and the catch-all proxy route reads it on every request.
app.state.OLLAMA_API_BASE_URL = OLLAMA_API_BASE_URL
Timothy J. Baek's avatar
Timothy J. Baek committed
26

27
# TARGET_SERVER_URL = OLLAMA_API_BASE_URL
Timothy J. Baek's avatar
Timothy J. Baek committed
28
29


30
31
32
33
34
@app.get("/url")
async def get_ollama_api_url(user=Depends(get_current_user)):
    """Return the Ollama base URL the proxy currently forwards to.

    Admin-only: any other caller receives HTTP 401 with
    ``ERROR_MESSAGES.ACCESS_PROHIBITED``.
    """
    if not (user and user.role == "admin"):
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
    return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL}
36

Timothy J. Baek's avatar
Timothy J. Baek committed
37

38
39
40
41
42
class UrlUpdateForm(BaseModel):
    """Request body accepted by ``POST /url/update``."""

    url: str  # new Ollama API base URL to forward requests to


@app.post("/url/update")
async def update_ollama_api_url(
    form_data: UrlUpdateForm, user=Depends(get_current_user)
):
    """Replace the Ollama base URL used by the proxy (admin-only).

    The submitted URL is stripped of surrounding whitespace and trailing
    slashes before it is stored: the proxy builds upstream URLs as
    ``f"{base}/{path}"``, so a trailing "/" would put "//" into every
    forwarded request.

    Returns:
        ``{"OLLAMA_API_BASE_URL": <stored url>}``.

    Raises:
        HTTPException: 401 if the caller is not an admin.
    """
    if not (user and user.role == "admin"):
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)

    app.state.OLLAMA_API_BASE_URL = form_data.url.strip().rstrip("/")
    return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL}


# async def fetch_sse(method, target_url, body, headers):
#     async with aiohttp.ClientSession() as session:
#         try:
#             async with session.request(
#                 method, target_url, data=body, headers=headers
#             ) as response:
#                 print(response.status)
#                 async for line in response.content:
#                     yield line
#         except Exception as e:
#             print(e)
#             error_detail = "Ollama WebUI: Server Connection Error"
#             yield json.dumps({"error": error_detail, "message": str(e)}).encode()
66

67

68
69
70
71
72
73
74
75
76
77
# Client request headers that are never forwarded upstream. Starlette's
# request.headers yields lower-case header names, so matching is done on the
# lower-cased name; case-sensitive pops such as headers.pop("Authorization")
# never match and would leak the WebUI token and Host header upstream.
_PROXY_DROPPED_HEADERS = frozenset({"host", "authorization", "origin", "referer"})

# Proxied paths that only the "admin" role may call.
_PROXY_ADMIN_ONLY_PATHS = frozenset({"pull", "delete", "push", "copy", "create"})


def _normalize_proxy_path(path: str) -> str:
    """Return *path* with dot segments, repeated slashes and edge slashes removed.

    The admin-only check must see the same path that is forwarded upstream;
    otherwise variants such as "pull/", "./pull" or "//pull" skip the check
    but may still reach the "pull" endpoint (via URL normalization in the
    HTTP client or a trailing-slash redirect on the upstream server).
    """
    import posixpath

    # Prefixing "/" keeps normpath from returning "." or climbing above root.
    return posixpath.normpath("/" + path).lstrip("/")


def _upstream_headers(request_headers) -> dict:
    """Copy the client's request headers, minus ``_PROXY_DROPPED_HEADERS``."""
    return {
        name: value
        for name, value in dict(request_headers).items()
        if name.lower() not in _PROXY_DROPPED_HEADERS
    }


@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_current_user)):
    """Forward the request to ``{OLLAMA_API_BASE_URL}/{path}`` and stream the reply.

    Only users with role "user" or "admin" may use the proxy; paths in
    ``_PROXY_ADMIN_ONLY_PATHS`` additionally require "admin". The client's
    method, body and headers (minus ``_PROXY_DROPPED_HEADERS``) are forwarded.

    Returns:
        A ``StreamingResponse`` relaying the upstream body line by line with
        the upstream status code.

    Raises:
        HTTPException: 401 when the role checks fail. When the upstream call
            fails: the upstream status (500 if no response was received), with
            Ollama's ``error`` message as detail when one can be parsed.
    """
    path = _normalize_proxy_path(path)

    if user.role not in ("user", "admin"):
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
    if path in _PROXY_ADMIN_ONLY_PATHS and user.role != "admin":
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)

    target_url = f"{app.state.OLLAMA_API_BASE_URL}/{path}"
    body = await request.body()
    headers = _upstream_headers(request.headers)

    session = aiohttp.ClientSession()
    response = None
    try:
        response = await session.request(
            request.method, target_url, data=body, headers=headers
        )

        if not response.ok:
            # Buffer the error body before raise_for_status(), which may
            # release the connection; the handler below parses it again.
            await response.read()
            response.raise_for_status()

        async def generate():
            # finally: close the session even when the client disconnects and
            # the stream is abandoned before the upstream body is exhausted.
            try:
                async for line in response.content:
                    yield line
            finally:
                await session.close()

        return StreamingResponse(generate(), response.status)

    except Exception as e:
        print(e)
        error_detail = "Ollama WebUI: Server Connection Error"

        if response is not None:
            try:
                res = await response.json()
                if "error" in res:
                    error_detail = f"Ollama: {res['error']}"
            except Exception:
                # Non-JSON or unreadable error body: report the original error.
                error_detail = f"Ollama: {e}"

        await session.close()
        raise HTTPException(
            status_code=response.status if response is not None else 500,
            detail=error_detail,
        )