main.py 1.95 KB
Newer Older
1
import os
Timothy J. Baek's avatar
Timothy J. Baek committed
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
from fastapi import (
    FastAPI,
    Request,
    Depends,
    HTTPException,
    status,
    UploadFile,
    File,
    Form,
)
from fastapi.middleware.cors import CORSMiddleware
from faster_whisper import WhisperModel

from constants import ERROR_MESSAGES
from utils.utils import (
    decode_token,
    get_current_user,
    get_verified_user,
    get_admin_user,
)
from utils.misc import calculate_sha256

24
from config import CACHE_DIR, UPLOAD_DIR, WHISPER_MODEL, WHISPER_MODEL_DIR, DEVICE_TYPE
25

26
# Device handed to WhisperModel in transcribe().
# Bind the name on BOTH branches: previously it was only assigned when
# DEVICE_TYPE was not "cuda", so a CUDA configuration left it undefined and
# the first transcription request failed with a NameError.
# Any configured device other than "cuda" falls back to "cpu".
whisper_device_type = "cuda" if DEVICE_TYPE == "cuda" else "cpu"

Timothy J. Baek's avatar
Timothy J. Baek committed
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44

# CORS policy applied to every route of this app: any origin, HTTP method and
# request header is accepted, and credentialed requests are allowed.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}

# FastAPI application object that the route decorators in this module attach to.
app = FastAPI()
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)


@app.post("/transcribe")
def transcribe(
    file: UploadFile = File(...),
    user=Depends(get_current_user),
):
    """Save an uploaded audio file and return its Whisper transcription.

    Args:
        file: The uploaded audio. Only the content types ``audio/mpeg`` and
            ``audio/wav`` are accepted.
        user: Result of the ``get_current_user`` dependency. It is not read
            in the body; presumably it exists to require an authenticated
            caller (confirm against ``utils.utils``).

    Returns:
        A dict ``{"text": <transcript>}`` where the transcript is the
        concatenation of all segment texts with surrounding whitespace
        stripped.

    Raises:
        HTTPException: 400 with ``ERROR_MESSAGES.FILE_NOT_SUPPORTED`` when the
            content type is not accepted, and 400 with
            ``ERROR_MESSAGES.DEFAULT(e)`` when saving or transcribing fails.
    """
    print(file.content_type)

    if file.content_type not in ["audio/mpeg", "audio/wav"]:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED,
        )

    try:
        # The filename is client-controlled. Keep only its final path
        # component (treating backslashes as separators too) so that a name
        # such as "../../etc/x" cannot escape UPLOAD_DIR.
        raw_name = (file.filename or "").replace("\\", "/")
        filename = os.path.basename(raw_name)
        if filename in ("", ".", ".."):
            raise ValueError("Uploaded file has no usable filename")

        file_path = os.path.join(UPLOAD_DIR, filename)
        contents = file.file.read()
        # The context manager closes the file; no explicit close() needed.
        with open(file_path, "wb") as f:
            f.write(contents)

        # NOTE(review): the model is constructed on every request; consider
        # caching it at module level if load time becomes a problem.
        model = WhisperModel(
            WHISPER_MODEL,
            device=whisper_device_type,
            compute_type="int8",
            download_root=WHISPER_MODEL_DIR,
        )

        segments, info = model.transcribe(file_path, beam_size=5)
        print(
            "Detected language '%s' with probability %f"
            % (info.language, info.language_probability)
        )

        # Join straight from the iterable; no intermediate list is required.
        transcript = "".join(segment.text for segment in segments)

        return {"text": transcript.strip()}

    except Exception as e:
        print(e)

        # Any failure while saving or transcribing is reported as a 400;
        # chain the cause so the original traceback is preserved.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        ) from e