Unverified Commit a220536f authored by Vincent Zhong's avatar Vincent Zhong Committed by GitHub
Browse files

[perf] Replace json -> orjson in hot path (#11221)


Signed-off-by: vincentzed <207368749+vincentzed@users.noreply.github.com>
parent 7b064f04
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
import enum
import json
import logging
from dataclasses import dataclass, field
from typing import List, Optional, Union
import orjson
from sglang.srt.utils import is_hip
logger = logging.getLogger(__name__)
......@@ -66,7 +67,7 @@ class LoadConfig:
def __post_init__(self):
model_loader_extra_config = self.model_loader_extra_config or {}
if isinstance(model_loader_extra_config, str):
self.model_loader_extra_config = json.loads(model_loader_extra_config)
self.model_loader_extra_config = orjson.loads(model_loader_extra_config)
self._verify_load_format()
if self.ignore_patterns is not None and len(self.ignore_patterns) > 0:
......
......@@ -5,6 +5,8 @@ import logging
from abc import ABC, abstractmethod
from typing import Union
import orjson
logger = logging.getLogger(__name__)
try:
......@@ -148,7 +150,7 @@ class HarmonyContext(ConversationContext):
if isinstance(tool_session, Tool):
return await tool_session.get_result(self)
tool_name = last_msg.recipient.split(".")[1]
args = json.loads(last_msg.content[0].text)
args = orjson.loads(last_msg.content[0].text)
result = await tool_session.call_tool(tool_name, args)
result_str = result.content[0].text
content = TextContent(text=result_str)
......
......@@ -7,6 +7,7 @@ import json
from collections.abc import Iterable
from typing import Literal, Optional, Union
import orjson
from openai.types.responses import (
ResponseOutputItem,
ResponseOutputMessage,
......@@ -228,7 +229,7 @@ def parse_output_message(message: Message):
if len(message.content) != 1:
raise ValueError("Invalid number of contents in browser message")
content = message.content[0]
browser_call = json.loads(content.text)
browser_call = orjson.loads(content.text)
# TODO: translate to url properly!
if recipient == "browser.search":
action = ActionSearch(
......
......@@ -555,7 +555,7 @@ async def generate_request(obj: GenerateReqInput, request: Request):
async def generate_from_file_request(file: UploadFile, request: Request):
"""Handle a generate request, this is purely to work with input_embeds."""
content = await file.read()
input_embeds = json.loads(content.decode("utf-8"))
input_embeds = orjson.loads(content.decode("utf-8"))
obj = GenerateReqInput(
input_embeds=input_embeds,
......
......@@ -6,6 +6,7 @@ import uuid
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Optional, Union
import orjson
from fastapi import HTTPException, Request
from fastapi.responses import ORJSONResponse, StreamingResponse
......@@ -197,7 +198,7 @@ class OpenAIServingBase(ABC):
)
try:
raw_labels = (
json.loads(raw_request.headers.get(header))
orjson.loads(raw_request.headers.get(header))
if raw_request and raw_request.headers.get(header)
else None
)
......
......@@ -7,6 +7,7 @@ import time
import uuid
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Optional, Union
import orjson
from fastapi import Request
from fastapi.responses import ORJSONResponse, StreamingResponse
from jsonschema import Draft202012Validator, SchemaError
......@@ -285,7 +286,7 @@ class OpenAIServingChat(OpenAIServingBase):
if "arguments" in item["function"] and isinstance(
item["function"]["arguments"], str
):
item["function"]["arguments"] = json.loads(
item["function"]["arguments"] = orjson.loads(
item["function"]["arguments"]
)
......@@ -860,7 +861,7 @@ class OpenAIServingChat(OpenAIServingBase):
finish_reason["matched"] = None
try:
# For required tool choice, we expect a JSON array of tool calls
tool_call_data = json.loads(text)
tool_call_data = orjson.loads(text)
tool_calls = []
for i, tool in enumerate(tool_call_data):
# Create a ToolCallItem from the JSON data
......
......@@ -5,7 +5,6 @@ from __future__ import annotations
import asyncio
import copy
import json
import logging
import time
from contextlib import AsyncExitStack
......@@ -14,6 +13,7 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Optional,
import jinja2
import openai.types.responses as openai_responses_types
import orjson
from fastapi import Request
from fastapi.responses import ORJSONResponse
from openai.types.responses import (
......@@ -1061,7 +1061,7 @@ class OpenAIServingResponses(OpenAIServingChat):
):
function_name = previous_item.recipient[len("browser.") :]
action = None
parsed_args = json.loads(previous_item.content[0].text)
parsed_args = orjson.loads(previous_item.content[0].text)
if function_name == "search":
action = openai_responses_types.response_function_web_search.ActionSearch(
type="search",
......
......@@ -3,6 +3,7 @@ import logging
from abc import ABC, abstractmethod
from typing import Any, Dict, List
import orjson
from partial_json_parser.core.exceptions import MalformedJSON
from partial_json_parser.core.options import Allow
......@@ -96,7 +97,7 @@ class BaseFormatDetector(ABC):
Parses the text in one go. Returns success=True if the format matches, otherwise False.
Note that leftover_text here represents "content that this parser will not consume further".
"""
action = json.loads(text)
action = orjson.loads(text)
return StreamingParseResult(calls=self.parse_base_json(action, tools))
def _ends_with_partial_token(self, buffer: str, bot_token: str) -> int:
......
......@@ -3,6 +3,7 @@ from json import JSONDecodeError, JSONDecoder
from json.decoder import WHITESPACE
from typing import Any, List, Literal, Optional, Tuple, Union
import orjson
import partial_json_parser
from partial_json_parser.core.options import Allow
......@@ -51,7 +52,7 @@ def _partial_json_loads(input_str: str, flags: Allow) -> Tuple[Any, int]:
def _is_complete_json(input_str: str) -> bool:
try:
json.loads(input_str)
orjson.loads(input_str)
return True
except JSONDecodeError:
return False
......
......@@ -34,6 +34,7 @@ from http import HTTPStatus
from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union
import fastapi
import orjson
import torch
import uvloop
import zmq
......@@ -157,7 +158,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
self.log_requests = server_args.log_requests
self.log_requests_level = server_args.log_requests_level
self.preferred_sampling_params = (
json.loads(server_args.preferred_sampling_params)
orjson.loads(server_args.preferred_sampling_params)
if server_args.preferred_sampling_params
else None
)
......
......@@ -4,6 +4,7 @@ from functools import lru_cache
from typing import Any, Dict, List, Optional
import dill
import orjson
import torch
......@@ -12,7 +13,7 @@ def _cache_from_str(json_str: str):
"""Deserialize a json string to a Callable object.
This function is cached to avoid redundant deserialization.
"""
data = json.loads(json_str)
data = orjson.loads(json_str)
return dill.loads(bytes.fromhex(data["callable"]))
......
......@@ -22,6 +22,8 @@ import random
import tempfile
from typing import Dict, List, Literal, Optional, Union
import orjson
from sglang.srt.connector import ConnectorType
from sglang.srt.function_call.function_call_parser import FunctionCallParser
from sglang.srt.lora.lora_registry import LoRARef
......@@ -3041,7 +3043,7 @@ class ServerArgs:
self.model_path,
trust_remote_code=self.trust_remote_code,
revision=self.revision,
model_override_args=json.loads(self.json_model_override_args),
model_override_args=orjson.loads(self.json_model_override_args),
**kwargs,
)
return hf_config
......
......@@ -12,7 +12,6 @@
# limitations under the License.
# ==============================================================================
"""Common utilities."""
from __future__ import annotations
import argparse
......@@ -70,6 +69,7 @@ from typing import (
)
import numpy as np
import orjson
import psutil
import pybase64
import requests
......@@ -1112,7 +1112,7 @@ def configure_logger(server_args, prefix: str = ""):
f"{SGLANG_LOGGING_CONFIG_PATH} but it does not exist!"
)
with open(SGLANG_LOGGING_CONFIG_PATH, encoding="utf-8") as file:
custom_config = json.loads(file.read())
custom_config = orjson.loads(file.read())
logging.config.dictConfig(custom_config)
return
format = f"[%(asctime)s{prefix}] %(message)s"
......@@ -2525,9 +2525,9 @@ def log_info_on_rank0(logger, msg):
def load_json_config(data: str):
try:
return json.loads(data)
return orjson.loads(data)
except JSONDecodeError:
return json.loads(Path(data).read_text())
return orjson.loads(Path(data).read_text())
def dispose_tensor(x: torch.Tensor):
......@@ -3236,7 +3236,7 @@ def numa_bind_to_node(node: int):
def json_list_type(value):
try:
return json.loads(value)
return orjson.loads(value)
except json.JSONDecodeError:
raise argparse.ArgumentTypeError(
f"Invalid JSON list: {value}. Please provide a valid JSON list."
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment