# old_main.py (5.46 KB)
from flask import Flask, request, Response, jsonify
from flask_cors import CORS

import requests
import json

from apps.web.models.users import Users
from constants import ERROR_MESSAGES
from utils.utils import decode_token
from config import OLLAMA_API_BASE_URL, WEBUI_AUTH

# Flask application that hosts the Ollama proxy endpoints defined below.
app = Flask(__name__)

# Enable Cross-Origin Resource Sharing (CORS) so browsers on other origins
# can call these endpoints.
CORS(app)

# Base URL of the Ollama API that proxied requests are forwarded to.
# Initialised from config; the proxy reads this module-level name on every
# request when it builds the upstream URL.
TARGET_SERVER_URL = OLLAMA_API_BASE_URL


@app.route("/url", methods=["GET"])
def get_ollama_api_url():
    """Return the Ollama API base URL that the proxy currently forwards to.

    Admin-only. The request must carry an ``Authorization: <scheme> <token>``
    header; the token is decoded with ``decode_token`` and its ``email``
    claim is resolved with ``Users.get_user_by_email``.

    Returns:
        ``{"OLLAMA_API_BASE_URL": <url>}`` with status 200 for an admin;
        ``{"detail": ERROR_MESSAGES.ACCESS_PROHIBITED}`` with status 401 when
        the email resolves to no user or to a non-admin user;
        ``{"detail": ERROR_MESSAGES.UNAUTHORIZED}`` with status 401 when the
        header is missing or malformed, or the token is invalid.
    """
    # Anything other than exactly "<scheme> <token>" is rejected as
    # unauthenticated instead of failing tuple unpacking with a ValueError,
    # which Flask would surface as a 500.
    auth_parts = request.headers.get("Authorization", "").split()
    if len(auth_parts) != 2:
        return jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), 401

    token_data = decode_token(auth_parts[1])
    if token_data is None or "email" not in token_data:
        return jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), 401

    user = Users.get_user_by_email(token_data["email"])
    if not (user and user.role == "admin"):
        # NOTE(review): 403 would be the conventional status here; 401 is
        # kept because existing clients may depend on it.
        return jsonify({"detail": ERROR_MESSAGES.ACCESS_PROHIBITED}), 401

    return jsonify({"OLLAMA_API_BASE_URL": TARGET_SERVER_URL}), 200


@app.route("/url/update", methods=["POST"])
def update_ollama_api_url():
    """Change the Ollama API base URL that the proxy forwards to.

    Admin-only. The request must carry an ``Authorization: <scheme> <token>``
    header whose token decodes (via ``decode_token``) to an ``email`` claim
    belonging to a user with the ``admin`` role. The body must be a JSON
    object with a non-empty string ``url`` field; it is parsed even without
    a JSON Content-Type.

    Returns:
        ``{"OLLAMA_API_BASE_URL": <new url>}`` with status 200 on success;
        ``{"detail": ERROR_MESSAGES.UNAUTHORIZED}`` with status 401 for a
        missing/malformed header or invalid token;
        ``{"detail": ERROR_MESSAGES.ACCESS_PROHIBITED}`` with status 401 when
        the email resolves to no user or to a non-admin user;
        a ``{"detail": ...}`` JSON error with status 400 for a bad body.
    """
    # Without this declaration the assignment below would bind a function
    # local, leaving the module-level URL (read by the proxy) unchanged.
    global TARGET_SERVER_URL

    # Anything other than exactly "<scheme> <token>" is rejected as
    # unauthenticated instead of failing tuple unpacking with a ValueError.
    auth_parts = request.headers.get("Authorization", "").split()
    if len(auth_parts) != 2:
        return jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), 401

    token_data = decode_token(auth_parts[1])
    if token_data is None or "email" not in token_data:
        return jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), 401

    user = Users.get_user_by_email(token_data["email"])
    if not (user and user.role == "admin"):
        return jsonify({"detail": ERROR_MESSAGES.ACCESS_PROHIBITED}), 401

    # Parse the body only after authorising, so unauthenticated callers get
    # a 401 rather than a body-parsing error. ``silent=True`` returns None on
    # invalid JSON so it can be reported as a JSON 400 below.
    data = request.get_json(force=True, silent=True)
    url = data.get("url") if isinstance(data, dict) else None
    if not isinstance(url, str) or not url:
        return (
            jsonify({
                "detail":
                "Request body must be a JSON object with a non-empty 'url' string"
            }),
            400,
        )

    # NOTE(review): this only updates the current process; under a
    # multi-worker server each worker keeps its own value.
    TARGET_SERVER_URL = url
    return jsonify({"OLLAMA_API_BASE_URL": TARGET_SERVER_URL}), 200


# Upstream response headers that must not be copied onto the proxied
# response. Hop-by-hop headers (RFC 7230 section 6.1) may not be emitted by a
# WSGI application (PEP 3333). ``requests`` also transparently decodes
# gzip/deflate bodies in ``iter_content``, so the upstream Content-Encoding
# and Content-Length would not describe the bytes actually streamed out.
_EXCLUDED_RESPONSE_HEADERS = frozenset({
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "te",
    "trailers",
    "transfer-encoding",
    "upgrade",
    "content-encoding",
    "content-length",
})

# Proxied paths that only users with the "admin" role may call.
_ADMIN_ONLY_PATHS = frozenset({"pull", "delete", "push", "copy", "create"})


def _authorize_proxy_request(path):
    """Apply basic RBAC to the current request for the proxied ``path``.

    Returns None when the request may proceed (always, if ``WEBUI_AUTH`` is
    falsy). Otherwise returns a ``(json_response, 401)`` tuple for the caller
    to return unchanged.
    """
    if not WEBUI_AUTH:
        return None

    # Anything other than exactly "<scheme> <token>" is rejected as
    # unauthenticated instead of failing tuple unpacking with a ValueError.
    auth_parts = request.headers.get("Authorization", "").split()
    if len(auth_parts) != 2:
        return jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), 401

    token_data = decode_token(auth_parts[1])
    if token_data is None or "email" not in token_data:
        return jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), 401

    user = Users.get_user_by_email(token_data["email"])
    if not user:
        return jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), 401

    # Only "user" and "admin" roles may use the proxy at all, and only
    # admins may call the paths in _ADMIN_ONLY_PATHS.
    if user.role not in ("user", "admin"):
        return jsonify({"detail": ERROR_MESSAGES.ACCESS_PROHIBITED}), 401
    if path in _ADMIN_ONLY_PATHS and user.role != "admin":
        return jsonify({"detail": ERROR_MESSAGES.ACCESS_PROHIBITED}), 401

    return None


def _upstream_error_response(r, e):
    """Build the ``(json_response, 400)`` returned when proxying fails.

    ``r`` is the upstream response, or None if the request never completed.
    ``e`` is the exception that was raised. If the upstream body is a JSON
    object with an ``error`` key, that message is surfaced in ``detail``.
    """
    error_detail = "Ollama WebUI: Server Connection Error"
    if r is not None:
        print(r.text)
        try:
            res = r.json()
        except ValueError:
            # Upstream error bodies are not always JSON (e.g. plain-text
            # 404 pages); don't let that turn into an unhandled 500.
            res = None
        if isinstance(res, dict) and "error" in res:
            error_detail = f"Ollama: {res['error']}"
        print(res)
        r.close()

    return (
        jsonify({
            "detail": error_detail,
            "message": str(e),
        }),
        400,
    )


@app.route("/",
           defaults={"path": ""},
           methods=["GET", "POST", "PUT", "DELETE"])
@app.route("/<path:path>", methods=["GET", "POST", "PUT", "DELETE"])
def proxy(path):
    """Forward the current request to the Ollama API and stream back the reply.

    The request body, method and headers (minus Host, Authorization, Origin
    and Referer) are sent to ``f"{TARGET_SERVER_URL}/{path}"``. The upstream
    body is streamed back with the upstream status code and end-to-end
    headers.

    Returns:
        The streamed upstream response on a 2xx upstream status; the 401
        tuple from ``_authorize_proxy_request`` when RBAC rejects the caller;
        otherwise (connection failure or non-2xx upstream status) a
        ``{"detail", "message"}`` JSON body with status 400.
    """
    # Combine the base URL of the target server with the requested path.
    target_url = f"{TARGET_SERVER_URL}/{path}"
    print(target_url)

    auth_error = _authorize_proxy_request(path)
    if auth_error is not None:
        return auth_error

    # Forward the original body and headers, minus those describing the
    # client-facing hop and the WebUI's own credentials.
    data = request.get_data()
    headers = dict(request.headers)
    for name in ("Host", "Authorization", "Origin", "Referer"):
        headers.pop(name, None)

    r = None
    try:
        # Make a request to the target server.
        r = requests.request(
            method=request.method,
            url=target_url,
            data=data,
            headers=headers,
            stream=True,  # Enable streaming for server-sent events
        )
        r.raise_for_status()
    except Exception as e:
        print(e)
        return _upstream_error_response(r, e)

    upstream = r

    def generate():
        # Proxy the target server's response to the client chunk by chunk.
        try:
            yield from upstream.iter_content(chunk_size=8192)
        finally:
            # Release the upstream connection even when the client
            # disconnects mid-stream and the generator is closed early.
            upstream.close()

    response = Response(generate(), status=upstream.status_code)

    # Copy end-to-end headers from the target server's response.
    for key, value in upstream.headers.items():
        if key.lower() not in _EXCLUDED_RESPONSE_HEADERS:
            response.headers[key] = value

    return response


def _run_development_server():
    """Start Flask's built-in development server with debug mode enabled.

    Debug mode turns on the interactive Werkzeug debugger, which can execute
    arbitrary code, so this entry point is for local development only.
    """
    app.run(debug=True)


if __name__ == "__main__":
    _run_development_server()