Unverified Commit a220536f authored by Vincent Zhong, committed by GitHub

[perf] Replace json -> orjson in hot path (#11221)


Signed-off-by: vincentzed <207368749+vincentzed@users.noreply.github.com>
parent 7b064f04
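
The diff below is a mechanical swap of the standard-library json decoder for orjson at call sites that sit on the request path. A minimal sketch of the pattern, assuming only that orjson is installed; the payload below is illustrative and not taken from the diff:

import json
import orjson

payload = '{"temperature": 0.7, "top_p": 0.9}'

# Before: stdlib decoder.
params_json = json.loads(payload)

# After: orjson.loads accepts str or bytes and returns the same Python objects,
# so call sites that only decode need no other changes.
params_orjson = orjson.loads(payload)
assert params_json == params_orjson

# Encoding paths are untouched by this commit; note that orjson.dumps returns
# bytes rather than str, so dump sites would need an explicit .decode().
serialized = orjson.dumps({"ok": True}).decode()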
 # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
 import enum
-import json
 import logging
 from dataclasses import dataclass, field
 from typing import List, Optional, Union
+import orjson
 from sglang.srt.utils import is_hip
 logger = logging.getLogger(__name__)
@@ -66,7 +67,7 @@ class LoadConfig:
     def __post_init__(self):
         model_loader_extra_config = self.model_loader_extra_config or {}
         if isinstance(model_loader_extra_config, str):
-            self.model_loader_extra_config = json.loads(model_loader_extra_config)
+            self.model_loader_extra_config = orjson.loads(model_loader_extra_config)
         self._verify_load_format()
         if self.ignore_patterns is not None and len(self.ignore_patterns) > 0:
...
@@ -5,6 +5,8 @@ import logging
 from abc import ABC, abstractmethod
 from typing import Union
+import orjson
 logger = logging.getLogger(__name__)
 try:
@@ -148,7 +150,7 @@ class HarmonyContext(ConversationContext):
         if isinstance(tool_session, Tool):
             return await tool_session.get_result(self)
         tool_name = last_msg.recipient.split(".")[1]
-        args = json.loads(last_msg.content[0].text)
+        args = orjson.loads(last_msg.content[0].text)
         result = await tool_session.call_tool(tool_name, args)
         result_str = result.content[0].text
         content = TextContent(text=result_str)
...
@@ -7,6 +7,7 @@ import json
 from collections.abc import Iterable
 from typing import Literal, Optional, Union
+import orjson
 from openai.types.responses import (
     ResponseOutputItem,
     ResponseOutputMessage,
@@ -228,7 +229,7 @@ def parse_output_message(message: Message):
         if len(message.content) != 1:
             raise ValueError("Invalid number of contents in browser message")
         content = message.content[0]
-        browser_call = json.loads(content.text)
+        browser_call = orjson.loads(content.text)
         # TODO: translate to url properly!
         if recipient == "browser.search":
             action = ActionSearch(
...
@@ -555,7 +555,7 @@ async def generate_request(obj: GenerateReqInput, request: Request):
 async def generate_from_file_request(file: UploadFile, request: Request):
     """Handle a generate request, this is purely to work with input_embeds."""
     content = await file.read()
-    input_embeds = json.loads(content.decode("utf-8"))
+    input_embeds = orjson.loads(content.decode("utf-8"))
     obj = GenerateReqInput(
         input_embeds=input_embeds,
...
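
Side note, not part of this commit: orjson.loads also accepts UTF-8 bytes directly, so the explicit decode step at this call site could in principle be dropped. A tiny illustration with an assumed payload:

content = b'[[0.1, 0.2], [0.3, 0.4]]'   # illustrative file contents
input_embeds = orjson.loads(content)    # equivalent to orjson.loads(content.decode("utf-8"))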
@@ -6,6 +6,7 @@ import uuid
 from abc import ABC, abstractmethod
 from typing import TYPE_CHECKING, Any, Optional, Union
+import orjson
 from fastapi import HTTPException, Request
 from fastapi.responses import ORJSONResponse, StreamingResponse
@@ -197,7 +198,7 @@ class OpenAIServingBase(ABC):
             )
         try:
             raw_labels = (
-                json.loads(raw_request.headers.get(header))
+                orjson.loads(raw_request.headers.get(header))
                 if raw_request and raw_request.headers.get(header)
                 else None
             )
...
@@ -7,6 +7,7 @@ import time
 import uuid
 from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Optional, Union
+import orjson
 from fastapi import Request
 from fastapi.responses import ORJSONResponse, StreamingResponse
 from jsonschema import Draft202012Validator, SchemaError
@@ -285,7 +286,7 @@ class OpenAIServingChat(OpenAIServingBase):
             if "arguments" in item["function"] and isinstance(
                 item["function"]["arguments"], str
             ):
-                item["function"]["arguments"] = json.loads(
+                item["function"]["arguments"] = orjson.loads(
                     item["function"]["arguments"]
                 )
@@ -860,7 +861,7 @@ class OpenAIServingChat(OpenAIServingBase):
                 finish_reason["matched"] = None
             try:
                 # For required tool choice, we expect a JSON array of tool calls
-                tool_call_data = json.loads(text)
+                tool_call_data = orjson.loads(text)
                 tool_calls = []
                 for i, tool in enumerate(tool_call_data):
                     # Create a ToolCallItem from the JSON data
...
@@ -5,7 +5,6 @@ from __future__ import annotations
 import asyncio
 import copy
-import json
 import logging
 import time
 from contextlib import AsyncExitStack
@@ -14,6 +13,7 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Optional,
 import jinja2
 import openai.types.responses as openai_responses_types
+import orjson
 from fastapi import Request
 from fastapi.responses import ORJSONResponse
 from openai.types.responses import (
@@ -1061,7 +1061,7 @@ class OpenAIServingResponses(OpenAIServingChat):
            ):
                function_name = previous_item.recipient[len("browser.") :]
                action = None
-               parsed_args = json.loads(previous_item.content[0].text)
+               parsed_args = orjson.loads(previous_item.content[0].text)
                if function_name == "search":
                    action = openai_responses_types.response_function_web_search.ActionSearch(
                        type="search",
...
@@ -3,6 +3,7 @@ import logging
 from abc import ABC, abstractmethod
 from typing import Any, Dict, List
+import orjson
 from partial_json_parser.core.exceptions import MalformedJSON
 from partial_json_parser.core.options import Allow
@@ -96,7 +97,7 @@ class BaseFormatDetector(ABC):
         Parses the text in one go. Returns success=True if the format matches, otherwise False.
         Note that leftover_text here represents "content that this parser will not consume further".
         """
-        action = json.loads(text)
+        action = orjson.loads(text)
         return StreamingParseResult(calls=self.parse_base_json(action, tools))
     def _ends_with_partial_token(self, buffer: str, bot_token: str) -> int:
...
@@ -3,6 +3,7 @@ from json import JSONDecodeError, JSONDecoder
 from json.decoder import WHITESPACE
 from typing import Any, List, Literal, Optional, Tuple, Union
+import orjson
 import partial_json_parser
 from partial_json_parser.core.options import Allow
@@ -51,7 +52,7 @@ def _partial_json_loads(input_str: str, flags: Allow) -> Tuple[Any, int]:
 def _is_complete_json(input_str: str) -> bool:
     try:
-        json.loads(input_str)
+        orjson.loads(input_str)
         return True
     except JSONDecodeError:
         return False
...
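
The hunk above keeps except JSONDecodeError: (imported from the stdlib json module) around the swapped call. That still catches failures because orjson raises orjson.JSONDecodeError, which is a subclass of json.JSONDecodeError and ValueError; the same reasoning applies to the except clauses in the utils.py hunks further down. A minimal check, illustrative and not part of the diff:

import json
import orjson

try:
    orjson.loads("{not valid json")
except json.JSONDecodeError as e:
    # orjson.JSONDecodeError subclasses json.JSONDecodeError (and ValueError),
    # so pre-existing except clauses keep working after the swap.
    assert isinstance(e, orjson.JSONDecodeError)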
@@ -34,6 +34,7 @@ from http import HTTPStatus
 from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union
 import fastapi
+import orjson
 import torch
 import uvloop
 import zmq
@@ -157,7 +158,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         self.log_requests = server_args.log_requests
         self.log_requests_level = server_args.log_requests_level
         self.preferred_sampling_params = (
-            json.loads(server_args.preferred_sampling_params)
+            orjson.loads(server_args.preferred_sampling_params)
             if server_args.preferred_sampling_params
             else None
         )
...
@@ -4,6 +4,7 @@ from functools import lru_cache
 from typing import Any, Dict, List, Optional
 import dill
+import orjson
 import torch
@@ -12,7 +13,7 @@ def _cache_from_str(json_str: str):
     """Deserialize a json string to a Callable object.
    This function is cached to avoid redundant deserialization.
    """
-    data = json.loads(json_str)
+    data = orjson.loads(json_str)
     return dill.loads(bytes.fromhex(data["callable"]))
...
@@ -22,6 +22,8 @@ import random
 import tempfile
 from typing import Dict, List, Literal, Optional, Union
+import orjson
 from sglang.srt.connector import ConnectorType
 from sglang.srt.function_call.function_call_parser import FunctionCallParser
 from sglang.srt.lora.lora_registry import LoRARef
@@ -3041,7 +3043,7 @@ class ServerArgs:
             self.model_path,
             trust_remote_code=self.trust_remote_code,
             revision=self.revision,
-            model_override_args=json.loads(self.json_model_override_args),
+            model_override_args=orjson.loads(self.json_model_override_args),
             **kwargs,
         )
         return hf_config
...
@@ -12,7 +12,6 @@
 # limitations under the License.
 # ==============================================================================
 """Common utilities."""
 from __future__ import annotations
 import argparse
@@ -70,6 +69,7 @@ from typing import (
 )
 import numpy as np
+import orjson
 import psutil
 import pybase64
 import requests
@@ -1112,7 +1112,7 @@ def configure_logger(server_args, prefix: str = ""):
                 f"{SGLANG_LOGGING_CONFIG_PATH} but it does not exist!"
             )
         with open(SGLANG_LOGGING_CONFIG_PATH, encoding="utf-8") as file:
-            custom_config = json.loads(file.read())
+            custom_config = orjson.loads(file.read())
         logging.config.dictConfig(custom_config)
         return
     format = f"[%(asctime)s{prefix}] %(message)s"
@@ -2525,9 +2525,9 @@ def log_info_on_rank0(logger, msg):
 def load_json_config(data: str):
     try:
-        return json.loads(data)
+        return orjson.loads(data)
     except JSONDecodeError:
-        return json.loads(Path(data).read_text())
+        return orjson.loads(Path(data).read_text())
 def dispose_tensor(x: torch.Tensor):
@@ -3236,7 +3236,7 @@ def numa_bind_to_node(node: int):
 def json_list_type(value):
     try:
-        return json.loads(value)
+        return orjson.loads(value)
     except json.JSONDecodeError:
         raise argparse.ArgumentTypeError(
             f"Invalid JSON list: {value}. Please provide a valid JSON list."
...