audio.py 2.11 KB
Newer Older
PengGao's avatar
PengGao committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
from typing import Dict

from .base import MediaHandler


class AudioHandler(MediaHandler):
    """Audio-specific ``MediaHandler`` that recognises formats by magic bytes.

    The class caches its first instance in ``_instance``; every later
    ``AudioHandler()`` call returns that same object.
    """

    _instance = None  # cached singleton, created on first instantiation

    def __new__(cls):
        """Return the cached instance, creating it on the first call."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def get_media_signatures(self) -> Dict[bytes, str]:
        """Return a mapping of leading byte signatures to file extensions.

        Only formats identifiable by a plain prefix are listed here; WAV and
        M4A need a structural check and are handled in ``_sniff_extension``.
        """
        return {
            b"ID3": "mp3",
            b"\xff\xfb": "mp3",
            b"\xff\xf3": "mp3",
            b"\xff\xf2": "mp3",
            b"OggS": "ogg",
            b"fLaC": "flac",
        }

    def get_data_url_prefix(self) -> str:
        """Return the prefix that marks a string as an audio data URL."""
        return "data:audio/"

    def get_data_url_pattern(self) -> str:
        """Return a regex capturing the audio subtype and the base64 payload."""
        return r"data:audio/(\w+);base64,(.+)"

    def get_default_extension(self) -> str:
        """Return the extension used when the format cannot be detected."""
        return "mp3"

    def _sniff_extension(self, data: bytes) -> "str | None":
        """Return the extension implied by the header of *data*, else ``None``.

        Single source of truth for signature matching, shared by
        ``is_base64`` and ``detect_extension`` so the two cannot drift apart.
        """
        for signature, ext in self.get_media_signatures().items():
            if data.startswith(signature):
                return ext
        # RIFF container: "RIFF" <4-byte size> "WAVE".
        if data.startswith(b"RIFF") and b"WAVE" in data[:12]:
            return "wav"
        # ISO base media files start with a box: <4-byte size> "ftyp", so the
        # tag lives at bytes 4..8 whatever the box size is. The older checks
        # (tag at offset 0, or the two common box sizes) are kept so anything
        # that was recognised before is still recognised.
        if data[4:8] == b"ftyp" or data[:4] in (
            b"ftyp",
            b"\x00\x00\x00\x20",
            b"\x00\x00\x00\x18",
        ):
            return "m4a"
        return None

    def is_base64(self, data: str) -> bool:
        """Report whether *data* looks like base64-encoded audio.

        A string starting with the audio data-URL prefix is accepted without
        inspecting its payload. Otherwise the string must be strictly valid
        base64 whose decoded bytes begin with a known audio signature.

        Args:
            data: Candidate string (data URL or raw base64).

        Returns:
            ``True`` if the string is accepted as audio, ``False`` otherwise.
        """
        if data.startswith(self.get_data_url_prefix()):
            return True

        # Strictly valid (padded) base64 always has a length divisible by 4.
        if len(data) % 4 != 0:
            return False

        import base64

        try:
            decoded = base64.b64decode(data, validate=True)
        except ValueError:
            # binascii.Error (invalid base64) is a ValueError subclass;
            # non-ASCII input raises a plain ValueError.
            return False

        return self._sniff_extension(decoded) is not None

    def detect_extension(self, data: bytes) -> str:
        """Return the file extension for raw audio bytes.

        Args:
            data: Decoded audio content (only the header is inspected).

        Returns:
            The detected extension, or ``get_default_extension()`` when no
            known signature matches.
        """
        detected = self._sniff_extension(data)
        if detected is None:
            return self.get_default_extension()
        return detected


# Shared handler instance used by the module-level convenience functions below.
_handler = AudioHandler()


def is_base64_audio(data: str) -> bool:
    """Return whether the shared audio handler accepts *data* as base64 audio.

    Thin module-level wrapper around ``_handler.is_base64``.
    """
    verdict = _handler.is_base64(data)
    return verdict


def save_base64_audio(base64_data: str, output_dir: str) -> str:
    """Forward *base64_data* and *output_dir* to the shared handler's ``save_base64``.

    Returns:
        Whatever ``_handler.save_base64`` returns (annotated as ``str``;
        presumably the path of the written file -- confirm against
        ``MediaHandler.save_base64``).
    """
    saved = _handler.save_base64(base64_data, output_dir)
    return saved