main.py 3.83 KB
Newer Older
1
2
3
import json
import posixpath

import aiohttp
import requests
from fastapi import FastAPI, Request, Response, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from apps.web.models.users import Users
from constants import ERROR_MESSAGES
from utils.utils import decode_token, get_current_user
from config import OLLAMA_API_BASE_URL, WEBUI_AUTH

app = FastAPI()

# CORS is fully open: any origin, any method, any header, with credentials.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)

# Upstream Ollama API base URL. Seeded from config, readable via GET /url,
# replaceable at runtime via POST /url/update, and used by the catch-all proxy.
app.state.OLLAMA_API_BASE_URL = OLLAMA_API_BASE_URL


@app.get("/url")
async def get_ollama_api_url(user=Depends(get_current_user)):
    """Return the Ollama API base URL the proxy currently targets.

    Only a user whose role is ``"admin"`` may read it.

    Raises:
        HTTPException: 401 with ``ERROR_MESSAGES.ACCESS_PROHIBITED`` when the
            caller is missing or not an admin.
    """
    if not (user and user.role == "admin"):
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)

    return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL}


class UrlUpdateForm(BaseModel):
    """Request body for ``POST /url/update``: the new Ollama API base URL."""

    url: str


@app.post("/url/update")
async def update_ollama_api_url(
    form_data: UrlUpdateForm, user=Depends(get_current_user)
):
    """Point the proxy at a new Ollama API base URL (admins only).

    The value is stored on ``app.state`` as-is (no validation) and echoed back.

    Raises:
        HTTPException: 401 with ``ERROR_MESSAGES.ACCESS_PROHIBITED`` when the
            caller is missing or not an admin.
    """
    if not (user and user.role == "admin"):
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)

    new_base_url = form_data.url
    app.state.OLLAMA_API_BASE_URL = new_base_url
    return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL}


# async def fetch_sse(method, target_url, body, headers):
#     async with aiohttp.ClientSession() as session:
#         try:
#             async with session.request(
#                 method, target_url, data=body, headers=headers
#             ) as response:
#                 print(response.status)
#                 async for line in response.content:
#                     yield line
#         except Exception as e:
#             print(e)
#             error_detail = "Ollama WebUI: Server Connection Error"
#             yield json.dumps({"error": error_detail, "message": str(e)}).encode()
66

67

68
69
70
# Ollama endpoints that modify the model store; only admins may call them.
_ADMIN_ONLY_PATHS = {"pull", "delete", "push", "copy", "create"}

# Request headers that must not be forwarded upstream (the WebUI's own
# credentials and the browser's host/origin information). Stored lowercase and
# matched case-insensitively: ASGI delivers header names in lowercase, so a
# case-sensitive lookup for e.g. "Authorization" would never match.
_STRIPPED_REQUEST_HEADERS = {"host", "authorization", "origin", "referer"}


def _normalize_path(path: str) -> str:
    """Collapse ``.``/``..``, duplicate and surrounding slashes in *path*.

    Used only for the admin-only check, so that variants such as ``pull/`` or
    ``./pull`` cannot slip past it while still reaching the same endpoint.
    """
    return posixpath.normpath("/" + path).strip("/")


def _forwardable_headers(headers) -> dict:
    """Return a plain-dict copy of *headers* minus ``_STRIPPED_REQUEST_HEADERS``.

    *headers* is any mapping with an ``items()`` method (here the incoming
    request's headers); names are compared case-insensitively.
    """
    return {
        name: value
        for name, value in headers.items()
        if name.lower() not in _STRIPPED_REQUEST_HEADERS
    }


async def _upstream_error_detail(response) -> str:
    """Build the HTTPException detail for a non-2xx upstream *response*.

    Uses Ollama's ``{"error": ...}`` message when the body is such a JSON
    object, the parse error when the body is not JSON, and a generic message
    otherwise.
    """
    try:
        payload = await response.json()
    except Exception as e:
        return f"Ollama: {e}"
    if isinstance(payload, dict) and "error" in payload:
        return f"Ollama: {payload['error']}"
    return "Ollama WebUI: Server Connection Error"


@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_current_user)):
    """Forward a request to the Ollama API and stream the response body back.

    Callers with role ``"user"`` or ``"admin"`` may proxy requests; the paths
    in ``_ADMIN_ONLY_PATHS`` are restricted to admins. Method, body, query
    string and headers (minus ``_STRIPPED_REQUEST_HEADERS``) are forwarded to
    ``app.state.OLLAMA_API_BASE_URL`` + ``/`` + *path*.

    Raises:
        HTTPException: 401 with ``ERROR_MESSAGES.ACCESS_PROHIBITED`` when the
            caller is not allowed; the upstream status with an ``"Ollama: ..."``
            detail when Ollama answers with an error; 500 when the request to
            Ollama fails outright.
    """
    # Authorize before doing any work on the request.
    if user is None or user.role not in ("user", "admin"):
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
    if _normalize_path(path) in _ADMIN_ONLY_PATHS and user.role != "admin":
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)

    # rstrip avoids a "//" when the configured base URL ends with a slash.
    target_url = f"{app.state.OLLAMA_API_BASE_URL.rstrip('/')}/{path}"
    if request.url.query:
        target_url = f"{target_url}?{request.url.query}"

    body = await request.body()
    headers = _forwardable_headers(request.headers)

    # NOTE(review): no explicit timeout is set, so aiohttp's default
    # ClientTimeout applies; it may cut off long streams such as model
    # pulls — confirm this is intended.
    session = aiohttp.ClientSession()
    try:
        response = await session.request(
            request.method, target_url, data=body, headers=headers
        )
    except Exception as e:
        print(e)
        await session.close()
        raise HTTPException(
            status_code=500, detail="Ollama WebUI: Server Connection Error"
        )

    if not response.ok:
        # Read the error body *before* releasing the connection.
        detail = await _upstream_error_detail(response)
        response.release()
        await session.close()
        raise HTTPException(status_code=response.status, detail=detail)

    async def stream_upstream():
        # finally runs on normal completion, on a mid-stream error, and when
        # the generator is closed early, so the session is never leaked.
        try:
            async for chunk in response.content:
                yield chunk
        finally:
            response.release()
            await session.close()

    return StreamingResponse(stream_upstream(), status_code=response.status)