Unverified Commit a37e1247 authored by Brayden Zhong's avatar Brayden Zhong Committed by GitHub
Browse files

[Multimodal][Perf] Use `pybase64` instead of `base64` (#7724)

parent 136c6e04
......@@ -38,6 +38,7 @@ runtime_common = [
"psutil",
"pydantic",
"pynvml",
"pybase64",
"python-multipart",
"pyzmq>=25.1.2",
"soundfile==0.13.1",
......
......@@ -814,9 +814,9 @@ def sample_mmmu_requests(
List of tuples (prompt, prompt_token_len, output_token_len).
"""
try:
import base64
import io
import pybase64
from datasets import load_dataset
except ImportError:
raise ImportError("Please install datasets: pip install datasets")
......@@ -867,7 +867,7 @@ def sample_mmmu_requests(
# Encode image to base64
buffered = io.BytesIO()
image.save(buffered, format="JPEG")
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
img_str = pybase64.b64encode(buffered.getvalue()).decode("utf-8")
image_data = f"data:image/jpeg;base64,{img_str}"
else:
continue
......
import base64
import copy
import dataclasses
import multiprocessing
......@@ -7,6 +6,7 @@ import threading
import time
from typing import Any, Dict, List, Optional, Tuple, Union
import pybase64
import requests
import torch
import torch.distributed as dist
......
......@@ -28,12 +28,12 @@ LLaVA-Onevision : https://arxiv.org/pdf/2408.03326
"""
import ast
import base64
import math
import re
from io import BytesIO
import numpy as np
import pybase64
from PIL import Image
from sglang.srt.utils import flatten_nested_list
......@@ -252,7 +252,7 @@ def process_anyres_image(image, processor, grid_pinpoints):
def load_image_from_base64(image):
return Image.open(BytesIO(base64.b64decode(image)))
return Image.open(BytesIO(pybase64.b64decode(image, validate=True)))
def expand2square(pil_img, background_color):
......
......@@ -15,7 +15,6 @@
from __future__ import annotations
import base64
import builtins
import ctypes
import dataclasses
......@@ -68,6 +67,7 @@ from typing import (
import numpy as np
import psutil
import pybase64
import requests
import torch
import torch.distributed
......@@ -616,7 +616,7 @@ def decode_video_base64(video_base64):
from PIL import Image
# Decode the base64 string
video_bytes = base64.b64decode(video_base64)
video_bytes = pybase64.b64decode(video_base64, validate=True)
# Placeholder for the start indices of each PNG image
img_starts = []
......@@ -702,7 +702,9 @@ def load_audio(audio_file: str, sr: int = 16000, mono: bool = True) -> np.ndarra
audio, original_sr = sf.read(BytesIO(audio_file))
elif audio_file.startswith("data:"):
audio_file = audio_file.split(",")[1]
audio, original_sr = sf.read(BytesIO(base64.b64decode(audio_file)))
audio, original_sr = sf.read(
BytesIO(pybase64.b64decode(audio_file, validate=True))
)
elif audio_file.startswith("http://") or audio_file.startswith("https://"):
timeout = int(os.getenv("REQUEST_TIMEOUT", "5"))
response = requests.get(audio_file, stream=True, timeout=timeout)
......@@ -771,12 +773,12 @@ def load_image(
image = Image.open(image_file)
elif image_file.startswith("data:"):
image_file = image_file.split(",")[1]
image = Image.open(BytesIO(base64.b64decode(image_file)))
image = Image.open(BytesIO(pybase64.b64decode(image_file, validate=True)))
elif image_file.startswith("video:"):
image_file = image_file.replace("video:", "")
image, image_size = decode_video_base64(image_file)
elif isinstance(image_file, str):
image = Image.open(BytesIO(base64.b64decode(image_file)))
image = Image.open(BytesIO(pybase64.b64decode(image_file, validate=True)))
else:
raise ValueError(f"Invalid image: {image}")
......@@ -1866,7 +1868,7 @@ class MultiprocessingSerializer:
if output_str:
# Convert bytes to base64-encoded string
output = base64.b64encode(output).decode("utf-8")
output = pybase64.b64encode(output).decode("utf-8")
return output
......@@ -1883,7 +1885,7 @@ class MultiprocessingSerializer:
"""
if isinstance(data, str):
# Decode base64 string to bytes
data = base64.b64decode(data)
data = pybase64.b64decode(data, validate=True)
return ForkingPickler.loads(data)
......
"""Common utilities"""
import base64
import importlib
import json
import logging
......@@ -20,6 +19,7 @@ from json import dumps
from typing import Any, Callable, List, Optional, Tuple, Type, Union
import numpy as np
import pybase64
import requests
from IPython.display import HTML, display
from pydantic import BaseModel
......@@ -148,15 +148,15 @@ def encode_image_base64(image_path: Union[str, bytes]):
if isinstance(image_path, str):
with open(image_path, "rb") as image_file:
data = image_file.read()
return base64.b64encode(data).decode("utf-8")
return pybase64.b64encode(data).decode("utf-8")
elif isinstance(image_path, bytes):
return base64.b64encode(image_path).decode("utf-8")
return pybase64.b64encode(image_path).decode("utf-8")
else:
# image_path is PIL.WebPImagePlugin.WebPImageFile
image = image_path
buffered = BytesIO()
image.save(buffered, format="PNG")
return base64.b64encode(buffered.getvalue()).decode("utf-8")
return pybase64.b64encode(buffered.getvalue()).decode("utf-8")
def encode_frame(frame):
......@@ -223,7 +223,7 @@ def encode_video_base64(video_path: str, num_frames: int = 16):
video_bytes = b"".join(encoded_frames)
# Encode the concatenated bytes to base64
video_base64 = "video:" + base64.b64encode(video_bytes).decode("utf-8")
video_base64 = "video:" + pybase64.b64encode(video_bytes).decode("utf-8")
return video_base64
......
import base64
import copy
import io
import json
import os
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment