main.py 2.73 KB
Newer Older
1
import os
2
import logging
Timothy J. Baek's avatar
Timothy J. Baek committed
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
from fastapi import (
    FastAPI,
    Request,
    Depends,
    HTTPException,
    status,
    UploadFile,
    File,
    Form,
)
from fastapi.middleware.cors import CORSMiddleware
from faster_whisper import WhisperModel

from constants import ERROR_MESSAGES
from utils.utils import (
    decode_token,
    get_current_user,
    get_verified_user,
    get_admin_user,
)
from utils.misc import calculate_sha256

Timothy J. Baek's avatar
Timothy J. Baek committed
25
26
27
28
29
30
from config import (
    SRC_LOG_LEVELS,
    CACHE_DIR,
    UPLOAD_DIR,
    WHISPER_MODEL,
    WHISPER_MODEL_DIR,
31
    WHISPER_MODEL_AUTO_UPDATE,
Jannik Streidl's avatar
Jannik Streidl committed
32
    DEVICE_TYPE,
Timothy J. Baek's avatar
Timothy J. Baek committed
33
)
34
35
36

# Module logger; its verbosity comes from the "AUDIO" entry of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["AUDIO"])

# Sub-application exposing the audio endpoints, with fully open CORS.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; confirm this is intended for this service.
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)

# Whisper runs on CUDA only when DEVICE_TYPE is exactly "cuda"; any other
# value (including an unset/empty one) falls back to the CPU.
if DEVICE_TYPE == "cuda":
    whisper_device_type = "cuda"
else:
    whisper_device_type = "cpu"
log.info("whisper_device_type: %s", whisper_device_type)

Timothy J. Baek's avatar
Timothy J. Baek committed
51
52
53
54
55
56

def _save_upload(file: UploadFile) -> str:
    """Persist an uploaded file into UPLOAD_DIR and return the written path.

    Only the final path component of the client-supplied filename is kept, so a
    name such as ``../../etc/passwd`` cannot escape UPLOAD_DIR.

    Raises:
        ValueError: if the filename is missing or reduces to an empty/dot name.
    """
    import shutil  # local import: stream the upload instead of buffering it all

    safe_name = os.path.basename(file.filename or "")
    if safe_name in ("", ".", ".."):
        raise ValueError(f"Invalid upload filename: {file.filename!r}")

    file_path = os.path.join(UPLOAD_DIR, safe_name)
    with open(file_path, "wb") as out:
        shutil.copyfileobj(file.file, out)
    return file_path


def _load_whisper_model() -> WhisperModel:
    """Build a WhisperModel from the module's configuration.

    The first attempt sets ``local_files_only`` to the inverse of
    WHISPER_MODEL_AUTO_UPDATE. If construction fails, it retries once with
    ``local_files_only=False``; an error from the retry propagates.
    """
    whisper_kwargs = {
        "model_size_or_path": WHISPER_MODEL,
        "device": whisper_device_type,
        "compute_type": "int8",
        "download_root": WHISPER_MODEL_DIR,
        "local_files_only": not WHISPER_MODEL_AUTO_UPDATE,
    }
    log.debug("whisper_kwargs: %s", whisper_kwargs)

    try:
        return WhisperModel(**whisper_kwargs)
    except Exception:
        # Catch Exception (not a bare except) so KeyboardInterrupt/SystemExit
        # still propagate instead of triggering a download attempt.
        log.warning(
            "WhisperModel initialization failed, attempting download with local_files_only=False"
        )
        whisper_kwargs["local_files_only"] = False
        return WhisperModel(**whisper_kwargs)


@app.post("/transcribe")
def transcribe(
    file: UploadFile = File(...),
    user=Depends(get_current_user),
):
    """Transcribe an uploaded MPEG/WAV audio file with faster-whisper.

    The upload is saved under UPLOAD_DIR, a WhisperModel is built from the
    configured settings, and the text of all segments is concatenated.

    Args:
        file: the uploaded audio; its content type must be ``audio/mpeg`` or
            ``audio/wav``.
        user: resolved by ``get_current_user``; not otherwise used in the body.

    Returns:
        ``{"text": transcript}`` with surrounding whitespace stripped.

    Raises:
        HTTPException: 400 with ``FILE_NOT_SUPPORTED`` for other content types,
            or 400 with ``ERROR_MESSAGES.DEFAULT(e)`` if saving the file,
            loading the model, or transcribing fails.
    """
    log.info("file.content_type: %s", file.content_type)

    if file.content_type not in ("audio/mpeg", "audio/wav"):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED,
        )

    try:
        file_path = _save_upload(file)
        model = _load_whisper_model()

        segments, info = model.transcribe(file_path, beam_size=5)
        log.info(
            "Detected language '%s' with probability %f",
            info.language,
            info.language_probability,
        )

        # Iterate `segments` directly; no intermediate list is needed to join.
        transcript = "".join(segment.text for segment in segments)
        return {"text": transcript.strip()}

    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        ) from e