Unverified commit 4fc1bf81, authored by Feng XiaoLong, committed by GitHub
Browse files

[Bugfix] Migrate to REGEX Library to prevent catastrophic backtracking (#18454)


Signed-off-by: Crucifixion-Fxl <xmufxl@gmail.com>
Co-authored-by: Crucifixion-Fxl <xmufxl@gmail.com>
parent f2036734
...@@ -26,7 +26,7 @@ sed -i '/\*\*BEFORE SUBMITTING, PLEASE READ.*\*\*/,$d' "${NEW}" ...@@ -26,7 +26,7 @@ sed -i '/\*\*BEFORE SUBMITTING, PLEASE READ.*\*\*/,$d' "${NEW}"
# Remove HTML <details> section that includes <summary> text of "PR Checklist (Click to Expand)" # Remove HTML <details> section that includes <summary> text of "PR Checklist (Click to Expand)"
python3 - <<EOF python3 - <<EOF
import re import regex as re
with open("${NEW}", "r") as file: with open("${NEW}", "r") as file:
content = file.read() content = file.read()
......
...@@ -672,7 +672,7 @@ async def benchmark( ...@@ -672,7 +672,7 @@ async def benchmark(
def evaluate(ret, args): def evaluate(ret, args):
def _eval_correctness_json(expected, actual): def _eval_correctness_json(expected, actual):
# extract json string from string using regex # extract json string from string using regex
import re import regex as re
actual = actual.replace("\n", "").replace(" ", "").strip() actual = actual.replace("\n", "").replace(" ", "").strip()
try: try:
...@@ -687,7 +687,7 @@ def evaluate(ret, args): ...@@ -687,7 +687,7 @@ def evaluate(ret, args):
return actual in args.choice return actual in args.choice
def _eval_correctness_regex(expected, actual): def _eval_correctness_regex(expected, actual):
import re import regex as re
return re.match(args.regex, actual) is not None return re.match(args.regex, actual) is not None
......
...@@ -2,11 +2,11 @@ ...@@ -2,11 +2,11 @@
import math import math
import pickle import pickle
import re
from collections import defaultdict from collections import defaultdict
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import pandas as pd import pandas as pd
import regex as re
import seaborn as sns import seaborn as sns
from torch.utils.benchmark import Measurement as TMeasurement from torch.utils.benchmark import Measurement as TMeasurement
......
...@@ -20,12 +20,12 @@ python prithvi_geospatial_mae.py ...@@ -20,12 +20,12 @@ python prithvi_geospatial_mae.py
import argparse import argparse
import datetime import datetime
import os import os
import re
from typing import Union from typing import Union
import albumentations import albumentations
import numpy as np import numpy as np
import rasterio import rasterio
import regex as re
import torch import torch
from einops import rearrange from einops import rearrange
from terratorch.datamodules import Sen1Floods11NonGeoDataModule from terratorch.datamodules import Sen1Floods11NonGeoDataModule
......
...@@ -8,6 +8,7 @@ requires = [ ...@@ -8,6 +8,7 @@ requires = [
"setuptools-scm>=8.0", "setuptools-scm>=8.0",
"torch == 2.7.0", "torch == 2.7.0",
"wheel", "wheel",
"regex",
"jinja2", "jinja2",
] ]
build-backend = "setuptools.build_meta" build-backend = "setuptools.build_meta"
......
...@@ -7,3 +7,4 @@ setuptools-scm>=8 ...@@ -7,3 +7,4 @@ setuptools-scm>=8
torch==2.7.0 torch==2.7.0
wheel wheel
jinja2>=3.1.6 jinja2>=3.1.6
regex
regex # Replace re for higher-performance regex matching
cachetools cachetools
psutil psutil
sentencepiece # Required for LLaMA tokenizer. sentencepiece # Required for LLaMA tokenizer.
......
...@@ -5,12 +5,12 @@ import importlib.util ...@@ -5,12 +5,12 @@ import importlib.util
import json import json
import logging import logging
import os import os
import re
import subprocess import subprocess
import sys import sys
from pathlib import Path from pathlib import Path
from shutil import which from shutil import which
import regex as re
import torch import torch
from packaging.version import Version, parse from packaging.version import Version, parse
from setuptools import Extension, setup from setuptools import Extension, setup
...@@ -389,7 +389,6 @@ class repackage_wheel(build_ext): ...@@ -389,7 +389,6 @@ class repackage_wheel(build_ext):
# vllm_flash_attn python code: # vllm_flash_attn python code:
# Regex from # Regex from
# `glob.translate('vllm/vllm_flash_attn/**/*.py', recursive=True)` # `glob.translate('vllm/vllm_flash_attn/**/*.py', recursive=True)`
import re
compiled_regex = re.compile( compiled_regex = re.compile(
r"vllm/vllm_flash_attn/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py") r"vllm/vllm_flash_attn/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py")
file_members += list( file_members += list(
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import json import json
import re
import weakref import weakref
from enum import Enum from enum import Enum
import jsonschema import jsonschema
import pytest import pytest
import regex as re
from pydantic import BaseModel from pydantic import BaseModel
from vllm.distributed import cleanup_dist_env_and_memory from vllm.distributed import cleanup_dist_env_and_memory
......
...@@ -2,13 +2,13 @@ ...@@ -2,13 +2,13 @@
# imports for guided decoding tests # imports for guided decoding tests
import json import json
import re
from typing import Optional from typing import Optional
import jsonschema import jsonschema
import openai # use the official client for correctness check import openai # use the official client for correctness check
import pytest import pytest
import pytest_asyncio import pytest_asyncio
import regex as re
import requests import requests
import torch import torch
from openai import BadRequestError, OpenAI from openai import BadRequestError, OpenAI
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# imports for guided decoding tests # imports for guided decoding tests
import json import json
import re
import shutil import shutil
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
from typing import Optional from typing import Optional
...@@ -11,6 +9,7 @@ import jsonschema ...@@ -11,6 +9,7 @@ import jsonschema
import openai # use the official client for correctness check import openai # use the official client for correctness check
import pytest import pytest
import pytest_asyncio import pytest_asyncio
import regex as re
# downloading lora to test lora requests # downloading lora to test lora requests
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from openai import BadRequestError from openai import BadRequestError
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# imports for guided decoding tests # imports for guided decoding tests
import re
import openai import openai
import pytest import pytest
import regex as re
from ...utils import RemoteOpenAIServer from ...utils import RemoteOpenAIServer
...@@ -32,7 +31,7 @@ async def test_out_of_vocab_token_ids(): ...@@ -32,7 +31,7 @@ async def test_out_of_vocab_token_ids():
client = remote_server.get_async_client() client = remote_server.get_async_client()
with pytest.raises(openai.BadRequestError, with pytest.raises(openai.BadRequestError,
match=re.compile('.*out of vocabulary.*')): match=re.compile('.*out of vocabulary.*').pattern):
await client.completions.create(model=model_name, await client.completions.create(model=model_name,
prompt=[999999], prompt=[999999],
max_tokens=5, max_tokens=5,
...@@ -46,9 +45,10 @@ async def test_reject_multistep_with_guided_decoding(): ...@@ -46,9 +45,10 @@ async def test_reject_multistep_with_guided_decoding():
with RemoteOpenAIServer(model_name, server_args) as remote_server: with RemoteOpenAIServer(model_name, server_args) as remote_server:
client = remote_server.get_async_client() client = remote_server.get_async_client()
with pytest.raises(openai.BadRequestError, with pytest.raises(
openai.BadRequestError,
match=re.compile( match=re.compile(
'.*Guided decoding .* multi-step decoding.*')): '.*Guided decoding .* multi-step decoding.*').pattern):
await client.completions.create( await client.completions.create(
model=model_name, model=model_name,
prompt="Hello", prompt="Hello",
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import os import os
import re
from collections.abc import Sequence from collections.abc import Sequence
from typing import Optional from typing import Optional
import librosa import librosa
import pytest import pytest
import regex as re
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from transformers import AutoTokenizer from transformers import AutoTokenizer
......
...@@ -3,11 +3,11 @@ ...@@ -3,11 +3,11 @@
for manipulating the input / output of HF & vLLM test runners, which are for manipulating the input / output of HF & vLLM test runners, which are
typically specific to a small subset of models. typically specific to a small subset of models.
""" """
import re
import types import types
from pathlib import PosixPath from pathlib import PosixPath
from typing import Optional, Union from typing import Optional, Union
import regex as re
import torch import torch
from PIL.Image import Image from PIL.Image import Image
from transformers import (AutoConfig, AutoTokenizer, BatchFeature, from transformers import (AutoConfig, AutoTokenizer, BatchFeature,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import json import json
import re
from copy import deepcopy from copy import deepcopy
from unittest.mock import MagicMock from unittest.mock import MagicMock
import pytest import pytest
import regex as re
from pydantic import TypeAdapter from pydantic import TypeAdapter
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
......
...@@ -4,12 +4,12 @@ ...@@ -4,12 +4,12 @@
from __future__ import annotations from __future__ import annotations
import json import json
import re
from enum import Enum from enum import Enum
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
import jsonschema import jsonschema
import pytest import pytest
import regex as re
from pydantic import BaseModel from pydantic import BaseModel
from tests.reasoning.utils import run_reasoning_extraction from tests.reasoning.utils import run_reasoning_extraction
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import re
from typing import Optional from typing import Optional
import openai # use the official client for correctness check import openai # use the official client for correctness check
import pytest import pytest
import pytest_asyncio import pytest_asyncio
import regex as re
from openai import BadRequestError from openai import BadRequestError
from tests.utils import RemoteOpenAIServer from tests.utils import RemoteOpenAIServer
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import re
from enum import Enum from enum import Enum
from typing import Optional from typing import Optional
import regex as re
from vllm import CompletionOutput from vllm import CompletionOutput
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment