utils.py 3.89 KB
Newer Older
1
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
Timothy J. Baek's avatar
Timothy J. Baek committed
2
from fastapi import HTTPException, status, Depends, Request
3
from sqlalchemy.orm import Session
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
4

5
from apps.webui.internal.db import get_db
6
from apps.webui.models.users import Users
liu.vaayne's avatar
liu.vaayne committed
7

8
9
from pydantic import BaseModel
from typing import Union, Optional
10
from constants import ERROR_MESSAGES
11
12
13
14
from passlib.context import CryptContext
from datetime import datetime, timedelta
import requests
import jwt
liu.vaayne's avatar
liu.vaayne committed
15
import uuid
Timothy J. Baek's avatar
Timothy J. Baek committed
16
import logging
17
18
import config

Timothy J. Baek's avatar
Timothy J. Baek committed
19
20
21
# Raise the "passlib" logger's threshold to ERROR so that its lower-severity
# records (warnings, info, debug) are not emitted.
logging.getLogger("passlib").setLevel(logging.ERROR)


22
# Secret key taken from config.WEBUI_SECRET_KEY; it is the key passed to
# jwt.encode / jwt.decode by create_token and decode_token in this module.
SESSION_SECRET = config.WEBUI_SECRET_KEY
23
24
25
26
27
28
# JWT signing algorithm handed to jwt.encode (as `algorithm`) and to
# jwt.decode (as the only accepted entry in `algorithms`).
ALGORITHM = "HS256"

##############
# Auth Utils
##############

Timothy J. Baek's avatar
Timothy J. Baek committed
29
# Bearer-token dependency used by get_current_user. auto_error=False is set so
# the dependency can yield None instead of failing the request by itself;
# get_current_user checks for None and then falls back to the "token" cookie.
bearer_security = HTTPBearer(auto_error=False)
30
31
32
33
# passlib context used by verify_password and get_password_hash; bcrypt is the
# only configured hashing scheme.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")


def verify_password(plain_password, hashed_password):
    """Check a plaintext password against a stored hash.

    Args:
        plain_password: The candidate plaintext password.
        hashed_password: The stored hash to compare against; may be empty
            or ``None``.

    Returns:
        The result of ``pwd_context.verify`` when a hash is present, or
        ``None`` when ``hashed_password`` is falsy (the verifier is not
        called in that case).
    """
    # Nothing to compare against: keep the original "None" result rather
    # than passing an empty hash to the verifier.
    if not hashed_password:
        return None
    return pwd_context.verify(plain_password, hashed_password)
37
38
39
40
41
42


def get_password_hash(password):
    """Hash ``password`` with the module-level ``pwd_context`` and return it."""
    hashed = pwd_context.hash(password)
    return hashed


Timothy J. Baek's avatar
Timothy J. Baek committed
43
def create_token(data: dict, expires_delta: Union[timedelta, None] = None) -> str:
    """Encode ``data`` as a signed JWT.

    Args:
        data: Claims to place in the token. A shallow copy is made, so the
            caller's dict is not modified.
        expires_delta: Lifetime of the token. When truthy, an ``exp`` claim
            set to "now (UTC) + expires_delta" is added. When ``None`` or a
            zero ``timedelta`` (both falsy), no ``exp`` claim is added.

    Returns:
        The token produced by ``jwt.encode`` using ``SESSION_SECRET`` and
        ``ALGORITHM``.
    """
    # Local import: the file-level import only brings in datetime/timedelta.
    from datetime import timezone

    payload = data.copy()

    if expires_delta:
        # Timezone-aware "now" replaces the deprecated, naive
        # datetime.utcnow(); both denote the same UTC instant.
        payload["exp"] = datetime.now(timezone.utc) + expires_delta

    return jwt.encode(payload, SESSION_SECRET, algorithm=ALGORITHM)


def decode_token(token: str) -> Optional[dict]:
    """Decode and verify a JWT created by this module.

    Args:
        token: The encoded JWT string.

    Returns:
        The decoded claims from ``jwt.decode`` (verified with
        ``SESSION_SECRET`` and restricted to ``ALGORITHM``), or ``None`` if
        decoding raises any exception.
    """
    try:
        return jwt.decode(token, SESSION_SECRET, algorithms=[ALGORITHM])
    except Exception:
        # Deliberate best-effort: any decoding failure is reported to the
        # caller as "no claims" rather than propagated.
        return None


def extract_token_from_auth_header(auth_header: str):
    """Return the part of ``auth_header`` that follows a ``"Bearer "`` prefix.

    The header is not validated: the first ``len("Bearer ")`` characters are
    dropped unconditionally, whatever they are.

    Args:
        auth_header: The raw ``Authorization`` header value.

    Returns:
        ``auth_header`` with its first seven characters removed.
    """
    prefix_length = len("Bearer ")
    return auth_header[prefix_length:]
64
65


liu.vaayne's avatar
liu.vaayne committed
66
67
68
69
70
def create_api_key():
    """Generate a new API key: ``"sk-"`` followed by a random UUID4 as 32 hex digits."""
    # UUID.hex is the dash-free hexadecimal form of the UUID.
    return "sk-" + uuid.uuid4().hex


Timothy J. Baek's avatar
Timothy J. Baek committed
71
72
73
def get_http_authorization_cred(auth_header: str):
    """Parse an ``Authorization`` header into ``HTTPAuthorizationCredentials``.

    Args:
        auth_header: Header value expected to be exactly two parts separated
            by a single space, i.e. ``"<scheme> <credentials>"``.

    Returns:
        An ``HTTPAuthorizationCredentials`` built from the two parts.

    Raises:
        ValueError: With ``ERROR_MESSAGES.INVALID_TOKEN`` if the header
            cannot be split into exactly two parts or the credentials
            object cannot be built.
    """
    try:
        scheme, credentials = auth_header.split(" ")
        return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
    except Exception as exc:
        # `except Exception` (not a bare `except`) so KeyboardInterrupt and
        # SystemExit are not turned into ValueError.
        raise ValueError(ERROR_MESSAGES.INVALID_TOKEN) from exc


Timothy J. Baek's avatar
Timothy J. Baek committed
79
def get_current_user(
    request: Request,
    auth_token: HTTPAuthorizationCredentials = Depends(bearer_security),
    db=Depends(get_db),
):
    """Resolve the user making the request from a bearer token or cookie.

    The token is taken from the bearer credentials when present, otherwise
    from the ``"token"`` cookie. Tokens starting with ``"sk-"`` are treated
    as API keys; anything else is decoded as a JWT whose ``"id"`` claim
    identifies the user. On success the user's last-active marker is
    updated via ``Users.update_user_last_active_by_id``.

    Args:
        request: The incoming request; only its cookies are read.
        auth_token: Bearer credentials from ``bearer_security``, or ``None``.
        db: Database session provided by ``get_db``.

    Returns:
        The user object returned by the ``Users`` lookup.

    Raises:
        HTTPException: 403 when no token is supplied at all; 401 with
            ``ERROR_MESSAGES.UNAUTHORIZED`` when the JWT cannot be decoded
            or has no ``"id"`` claim; 401 with
            ``ERROR_MESSAGES.INVALID_TOKEN`` when no matching user exists.
    """
    token = auth_token.credentials if auth_token is not None else None

    # Fall back to the cookie; .get() yields None when the cookie is absent.
    if token is None:
        token = request.cookies.get("token")

    if token is None:
        raise HTTPException(status_code=403, detail="Not authenticated")

    # auth by api key
    if token.startswith("sk-"):
        return get_current_user_by_api_key(db, token)

    # auth by jwt token
    data = decode_token(token)
    if data is None or "id" not in data:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.UNAUTHORIZED,
        )

    user = Users.get_user_by_id(db, data["id"])
    if user is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.INVALID_TOKEN,
        )

    Users.update_user_last_active_by_id(db, user.id)
    return user
116

Timothy J. Baek's avatar
Timothy J. Baek committed
117

118
119
def get_current_user_by_api_key(db: Session, api_key: str):
    """Look up the user owning ``api_key`` and record the activity.

    Args:
        db: Database session passed through to the ``Users`` helpers.
        api_key: The API key presented by the caller.

    Returns:
        The user returned by ``Users.get_user_by_api_key``.

    Raises:
        HTTPException: 401 with ``ERROR_MESSAGES.INVALID_TOKEN`` when no
            user matches the key.
    """
    user = Users.get_user_by_api_key(db, api_key)

    if user is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.INVALID_TOKEN,
        )

    Users.update_user_last_active_by_id(db, user.id)
    return user
130

Timothy J. Baek's avatar
Timothy J. Baek committed
131

Timothy J. Baek's avatar
Timothy J. Baek committed
132
def get_verified_user(user=Depends(get_current_user)):
    """Dependency that admits only users whose role is ``"user"`` or ``"admin"``.

    Args:
        user: The authenticated user resolved by ``get_current_user``.

    Returns:
        The same ``user`` object when its role is allowed.

    Raises:
        HTTPException: 401 with ``ERROR_MESSAGES.ACCESS_PROHIBITED`` for any
            other role.
    """
    if user.role not in {"user", "admin"}:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
        )
    return user
139
140


Timothy J. Baek's avatar
Timothy J. Baek committed
141
def get_admin_user(user=Depends(get_current_user)):
    """Dependency that admits only users whose role is ``"admin"``.

    Args:
        user: The authenticated user resolved by ``get_current_user``.

    Returns:
        The same ``user`` object when it is an admin.

    Raises:
        HTTPException: 401 with ``ERROR_MESSAGES.ACCESS_PROHIBITED`` for any
            non-admin role.
    """
    if user.role != "admin":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
        )
    return user