utils.py 3.73 KB
Newer Older
1
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
Timothy J. Baek's avatar
Timothy J. Baek committed
2
from fastapi import HTTPException, status, Depends, Request
Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
3

4
from apps.webui.models.users import Users
liu.vaayne's avatar
liu.vaayne committed
5

6
7
from pydantic import BaseModel
from typing import Union, Optional
8
from constants import ERROR_MESSAGES
9
10
11
12
from passlib.context import CryptContext
from datetime import datetime, timedelta
import requests
import jwt
liu.vaayne's avatar
liu.vaayne committed
13
import uuid
Timothy J. Baek's avatar
Timothy J. Baek committed
14
import logging
15
16
import config

Timothy J. Baek's avatar
Timothy J. Baek committed
17
18
19
# Only let passlib log at ERROR and above (presumably to hide its non-fatal
# warnings — confirm).
logging.getLogger("passlib").setLevel(logging.ERROR)

# Secret used to sign (create_token) and verify (decode_token) session JWTs.
SESSION_SECRET = config.WEBUI_SECRET_KEY
# JWT signing algorithm: HMAC with SHA-256, keyed by SESSION_SECRET.
ALGORITHM = "HS256"

##############
# Auth Utils
##############

# auto_error=False: when no bearer credentials are sent, the dependency yields
# None instead of failing, so get_current_user can fall back to the cookie.
bearer_security = HTTPBearer(auto_error=False)

# Password hashing context: bcrypt only; deprecated="auto" marks every
# non-default scheme as deprecated.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")


def verify_password(plain_password, hashed_password):
    """Check ``plain_password`` against a stored ``hashed_password``.

    Returns:
        The result of ``pwd_context.verify`` when a hash is stored, or ``None``
        when ``hashed_password`` is empty/``None``. Callers should treat the
        result as truthy (match) / falsy (no match).
    """
    if not hashed_password:
        # Nothing to compare against; keep the historical falsy ``None`` result.
        return None
    return pwd_context.verify(plain_password, hashed_password)
36
37
38
39
40
41


def get_password_hash(password):
    """Hash ``password`` with the module-level ``pwd_context`` for storage."""
    hashed = pwd_context.hash(password)
    return hashed


Timothy J. Baek's avatar
Timothy J. Baek committed
42
def create_token(data: dict, expires_delta: Union[timedelta, None] = None) -> str:
    """Sign ``data`` as a JWT with ``SESSION_SECRET`` using ``ALGORITHM``.

    Args:
        data: Claims to embed. The dict is copied, so the caller's object is
            not mutated.
        expires_delta: Token lifetime. When given (including ``timedelta(0)``),
            an ``exp`` claim of "now + expires_delta" is added. When ``None``,
            the token carries no ``exp`` claim.

    Returns:
        The encoded JWT.
    """
    from datetime import timezone

    payload = data.copy()

    # Compare against None rather than truthiness: timedelta(0) is falsy and
    # would otherwise silently yield a token that never expires.
    if expires_delta is not None:
        # Use an aware UTC "now" instead of the deprecated datetime.utcnow().
        # PyJWT converts datetime claims to a UTC epoch timestamp either way.
        payload["exp"] = datetime.now(timezone.utc) + expires_delta

    return jwt.encode(payload, SESSION_SECRET, algorithm=ALGORITHM)


def decode_token(token: str) -> Optional[dict]:
    """Verify and decode a JWT signed with ``SESSION_SECRET``.

    Returns:
        The decoded claims dict, or ``None`` if PyJWT rejects the token
        (malformed, bad signature, expired, wrong algorithm, ...).
    """
    try:
        return jwt.decode(token, SESSION_SECRET, algorithms=[ALGORITHM])
    except jwt.PyJWTError:
        # PyJWTError is the base of every error jwt.decode raises for a bad
        # token. Report it as "no valid token" without masking unrelated bugs.
        return None


def extract_token_from_auth_header(auth_header: str):
    """Return the credentials part of a ``"Bearer <token>"`` header value.

    The first ``len("Bearer ")`` characters are dropped unconditionally. The
    scheme is NOT validated, so callers must only pass Bearer header values.
    """
    return auth_header[len("Bearer ") :]
63
64


liu.vaayne's avatar
liu.vaayne committed
65
66
67
68
69
def create_api_key():
    """Generate a random API key: ``"sk-"`` followed by 32 lowercase hex digits."""
    # UUID.hex is the canonical hyphenated UUID string with the hyphens removed.
    return "sk-" + uuid.uuid4().hex


Timothy J. Baek's avatar
Timothy J. Baek committed
70
71
72
def get_http_authorization_cred(auth_header: str):
    """Parse an ``Authorization`` header value into HTTPAuthorizationCredentials.

    The value must be exactly ``"<scheme> <credentials>"``, separated by a
    single space.

    Raises:
        ValueError: with ``ERROR_MESSAGES.INVALID_TOKEN`` if the value does not
            split into exactly two parts or the credentials object cannot be
            built.
    """
    try:
        scheme, credentials = auth_header.split(" ")
        return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
    except Exception as err:
        # Narrowed from a bare ``except:`` so KeyboardInterrupt/SystemExit are
        # no longer turned into ValueError. The original error is kept as
        # __cause__ for debugging.
        raise ValueError(ERROR_MESSAGES.INVALID_TOKEN) from err


Timothy J. Baek's avatar
Timothy J. Baek committed
78
def get_current_user(
    request: Request,
    auth_token: HTTPAuthorizationCredentials = Depends(bearer_security),
):
    """Resolve the authenticated user for the current request.

    A bearer token from the Authorization header takes precedence over the
    ``token`` cookie. Tokens starting with ``sk-`` are treated as API keys and
    delegated to ``get_current_user_by_api_key``. Any other token is decoded as
    a JWT whose payload must contain an ``id`` claim.

    Raises:
        HTTPException: 403 if neither a bearer token nor a ``token`` cookie is
            present. 401 if the JWT is invalid or lacks ``id``
            (``ERROR_MESSAGES.UNAUTHORIZED``), or names an unknown user
            (``ERROR_MESSAGES.INVALID_TOKEN``). API-key failures raise whatever
            ``get_current_user_by_api_key`` raises.
    """
    # get token from cookie
    token = request.cookies.get("token")

    if auth_token is None and token is None:
        raise HTTPException(status_code=403, detail="Not authenticated")

    # The Authorization header wins when both header and cookie are sent.
    if auth_token is not None:
        token = auth_token.credentials

    # auth by api key
    if token.startswith("sk-"):
        return get_current_user_by_api_key(token)

    # auth by jwt token
    data = decode_token(token)
    if data is None or "id" not in data:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.UNAUTHORIZED,
        )

    user = Users.get_user_by_id(data["id"])
    if user is None:
        # Validly signed token, but its user no longer exists.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.INVALID_TOKEN,
        )

    Users.update_user_last_active_by_id(user.id)
    return user
112

Timothy J. Baek's avatar
Timothy J. Baek committed
113

liu.vaayne's avatar
liu.vaayne committed
114
def get_current_user_by_api_key(api_key: str):
    """Return the user that owns ``api_key`` and mark them as recently active.

    Raises:
        HTTPException: 401 with ``ERROR_MESSAGES.INVALID_TOKEN`` if no user
            owns this key.
    """
    user = Users.get_user_by_api_key(api_key)

    if user is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.INVALID_TOKEN,
        )

    # Same activity bookkeeping as the JWT path in get_current_user.
    Users.update_user_last_active_by_id(user.id)
    return user
126

Timothy J. Baek's avatar
Timothy J. Baek committed
127

Timothy J. Baek's avatar
Timothy J. Baek committed
128
def get_verified_user(user=Depends(get_current_user)):
    """FastAPI dependency: return the current user if their role is ``user`` or ``admin``.

    Raises:
        HTTPException: 401 with ``ERROR_MESSAGES.ACCESS_PROHIBITED`` for any
            other role.
    """
    # NOTE(review): an authenticated-but-unauthorized user would usually get a
    # 403. The 401 is kept because clients may depend on it.
    if user.role not in {"user", "admin"}:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
        )
    return user
135
136


Timothy J. Baek's avatar
Timothy J. Baek committed
137
def get_admin_user(user=Depends(get_current_user)):
    """FastAPI dependency: return the current user only if their role is ``admin``.

    Raises:
        HTTPException: 401 with ``ERROR_MESSAGES.ACCESS_PROHIBITED`` for any
            non-admin role.
    """
    # NOTE(review): status mirrors get_verified_user (401 rather than 403).
    if user.role != "admin":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
        )
    return user