Unverified Commit bb66cc4c authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

Fix CI && python3.8 compatible (#920)

parent 975adb80
...@@ -57,4 +57,4 @@ jobs: ...@@ -57,4 +57,4 @@ jobs:
cd $HOME && python3 -m sglang.bench_serving --backend sglang --port 8413 --dataset-name random --num-prompts 3000 --random-input 256 --random-output 512 cd $HOME && python3 -m sglang.bench_serving --backend sglang --port 8413 --dataset-name random --num-prompts 3000 --random-input 256 --random-output 512
echo "Stopping server..." echo "Stopping server..."
kill -9 $(ps aux | grep sglang | grep Meta-Llama-3.1-8B-Instruct | grep -v grep | awk '{print $2}') kill -9 $(ps aux | grep sglang | grep Meta-Llama-3.1-8B-Instruct | grep -- "--port 8413" | grep -v grep | awk '{print $2}')
...@@ -71,6 +71,7 @@ ...@@ -71,6 +71,7 @@
"source": [ "source": [
"import json\n", "import json\n",
"import os\n", "import os\n",
"from typing import List\n",
"\n", "\n",
"import chromadb\n", "import chromadb\n",
"\n", "\n",
...@@ -148,7 +149,7 @@ ...@@ -148,7 +149,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"@trace\n", "@trace\n",
"def retrieval(question: str) -> list[str]:\n", "def retrieval(question: str) -> List[str]:\n",
" return collection.query(\n", " return collection.query(\n",
" query_texts=[question],\n", " query_texts=[question],\n",
" n_results=1\n", " n_results=1\n",
...@@ -278,7 +279,7 @@ ...@@ -278,7 +279,7 @@
"\n", "\n",
"\n", "\n",
"@trace(eval_funcs=[context_relevancy_eval, percent_target_supported_by_context])\n", "@trace(eval_funcs=[context_relevancy_eval, percent_target_supported_by_context])\n",
"def retrieval(question: str) -> list[str]:\n", "def retrieval(question: str) -> List[str]:\n",
" return collection.query(\n", " return collection.query(\n",
" query_texts=[question],\n", " query_texts=[question],\n",
" n_results=1\n", " n_results=1\n",
......
...@@ -19,7 +19,7 @@ import functools ...@@ -19,7 +19,7 @@ import functools
import json import json
import os import os
import warnings import warnings
from typing import AbstractSet, Collection, Dict, Literal, Optional, Type, Union from typing import AbstractSet, Collection, Dict, List, Literal, Optional, Type, Union
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from transformers import ( from transformers import (
...@@ -259,7 +259,7 @@ class TiktokenTokenizer: ...@@ -259,7 +259,7 @@ class TiktokenTokenizer:
Literal["all"], AbstractSet[str] Literal["all"], AbstractSet[str]
] = set(), # noqa: B006 ] = set(), # noqa: B006
disallowed_special: Union[Literal["all"], Collection[str]] = "all", disallowed_special: Union[Literal["all"], Collection[str]] = "all",
) -> list[int]: ) -> List[int]:
if isinstance(allowed_special, set): if isinstance(allowed_special, set):
allowed_special |= self._default_allowed_special allowed_special |= self._default_allowed_special
return tiktoken.Encoding.encode( return tiktoken.Encoding.encode(
......
...@@ -7,7 +7,7 @@ import time ...@@ -7,7 +7,7 @@ import time
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass, field from dataclasses import dataclass, field
from multiprocessing.pool import ThreadPool from multiprocessing.pool import ThreadPool
from typing import Any from typing import Any, Dict, List, Tuple
import httpx import httpx
import jinja2 import jinja2
...@@ -24,8 +24,8 @@ OPENAI_SYSTEM_MESSAGE_CHATGPT = ( ...@@ -24,8 +24,8 @@ OPENAI_SYSTEM_MESSAGE_CHATGPT = (
) )
Message = dict[str, Any] # keys role, content Message = Dict[str, Any] # keys role, content
MessageList = list[Message] MessageList = List[Message]
class SamplerBase: class SamplerBase:
...@@ -45,9 +45,9 @@ class EvalResult: ...@@ -45,9 +45,9 @@ class EvalResult:
""" """
score: float | None # top-line metric score: float | None # top-line metric
metrics: dict[str, float] | None # other metrics metrics: Dict[str, float] | None # other metrics
htmls: list[str] # strings of valid HTML htmls: List[str] # strings of valid HTML
convos: list[MessageList] # sampled conversations convos: List[MessageList] # sampled conversations
@dataclass @dataclass
...@@ -57,7 +57,7 @@ class SingleEvalResult: ...@@ -57,7 +57,7 @@ class SingleEvalResult:
""" """
score: float | None score: float | None
metrics: dict[str, float] = field(default_factory=dict) metrics: Dict[str, float] = field(default_factory=dict)
html: str | None = None html: str | None = None
convo: MessageList | None = None # sampled conversation convo: MessageList | None = None # sampled conversation
...@@ -270,9 +270,9 @@ def _compute_stat(values: list, stat: str): ...@@ -270,9 +270,9 @@ def _compute_stat(values: list, stat: str):
def aggregate_results( def aggregate_results(
single_eval_results: list[SingleEvalResult], single_eval_results: List[SingleEvalResult],
default_stats: tuple[str] = ("mean", "std"), default_stats: Tuple[str] = ("mean", "std"),
name2stats: dict[str, tuple[str]] | None = None, name2stats: Dict[str, Tuple[str]] | None = None,
) -> EvalResult: ) -> EvalResult:
""" """
Aggregate results from multiple evaluations into a single EvalResult. Aggregate results from multiple evaluations into a single EvalResult.
...@@ -302,7 +302,7 @@ def aggregate_results( ...@@ -302,7 +302,7 @@ def aggregate_results(
) )
def map_with_progress(f: callable, xs: list[Any], num_threads: int): def map_with_progress(f: callable, xs: List[Any], num_threads: int):
""" """
Apply f to each element of xs, using a ThreadPool, and show progress. Apply f to each element of xs, using a ThreadPool, and show progress.
""" """
...@@ -422,7 +422,7 @@ def make_report(eval_result: EvalResult) -> str: ...@@ -422,7 +422,7 @@ def make_report(eval_result: EvalResult) -> str:
) )
def make_report_from_example_htmls(htmls: list[str]): def make_report_from_example_htmls(htmls: List[str]):
""" """
Create a standalone HTML report from a list of example htmls Create a standalone HTML report from a list of example htmls
""" """
......
...@@ -14,7 +14,7 @@ import re ...@@ -14,7 +14,7 @@ import re
from collections import Counter, defaultdict from collections import Counter, defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed from concurrent.futures import ThreadPoolExecutor, as_completed
from io import BytesIO from io import BytesIO
from typing import Any, Tuple from typing import Any, Dict, List, Tuple
import blobfile as bf import blobfile as bf
import tqdm import tqdm
...@@ -38,8 +38,8 @@ from sglang.test.simple_eval_common import ( ...@@ -38,8 +38,8 @@ from sglang.test.simple_eval_common import (
def evaluate_functional_correctness( def evaluate_functional_correctness(
sample: dict[str, str], sample: Dict[str, str],
completions: list[str], completions: List[str],
n_workers: int = 4, n_workers: int = 4,
timeout: float = 3.0, timeout: float = 3.0,
): ):
...@@ -70,7 +70,7 @@ class HumanEval(Eval): ...@@ -70,7 +70,7 @@ class HumanEval(Eval):
num_examples: int | None, num_examples: int | None,
num_threads: int, num_threads: int,
num_samples_per_task: int = 5, num_samples_per_task: int = 5,
ks_passes: list[int] = [1, 2, 5], ks_passes: List[int] = [1, 2, 5],
timeout: int = 120, timeout: int = 120,
): ):
self.seed = 0 self.seed = 0
...@@ -97,7 +97,7 @@ class HumanEval(Eval): ...@@ -97,7 +97,7 @@ class HumanEval(Eval):
] # remove signature ] # remove signature
return extracted_answer return extracted_answer
def fn(sample: dict[str, str]): def fn(sample: Dict[str, str]):
prompt_messages = [ prompt_messages = [
sampler._pack_message( sampler._pack_message(
role="user", content=instruction + sample["prompt"] role="user", content=instruction + sample["prompt"]
......
...@@ -8,7 +8,7 @@ import threading ...@@ -8,7 +8,7 @@ import threading
import time import time
import unittest import unittest
from functools import partial from functools import partial
from typing import Callable, Optional from typing import Callable, List, Optional
import numpy as np import numpy as np
import requests import requests
...@@ -457,7 +457,7 @@ def run_with_timeout( ...@@ -457,7 +457,7 @@ def run_with_timeout(
return ret_value[0] return ret_value[0]
def run_unittest_files(files: list[str], timeout_per_file: float): def run_unittest_files(files: List[str], timeout_per_file: float):
tic = time.time() tic = time.time()
success = True success = True
......
...@@ -11,7 +11,7 @@ class TestAccuracy(unittest.TestCase): ...@@ -11,7 +11,7 @@ class TestAccuracy(unittest.TestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
cls.model = MODEL_NAME_FOR_TEST cls.model = MODEL_NAME_FOR_TEST
cls.base_url = f"http://localhost:30000" cls.base_url = f"http://localhost:8157"
cls.process = popen_launch_server( cls.process = popen_launch_server(
cls.model, cls.model,
cls.base_url, cls.base_url,
......
...@@ -11,7 +11,7 @@ class TestAccuracy(unittest.TestCase): ...@@ -11,7 +11,7 @@ class TestAccuracy(unittest.TestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
cls.model = MODEL_NAME_FOR_TEST cls.model = MODEL_NAME_FOR_TEST
cls.base_url = f"http://localhost:30000" cls.base_url = f"http://localhost:8157"
cls.process = popen_launch_server(cls.model, cls.base_url, timeout=300) cls.process = popen_launch_server(cls.model, cls.base_url, timeout=300)
@classmethod @classmethod
......
...@@ -12,7 +12,7 @@ class TestOpenAIServer(unittest.TestCase): ...@@ -12,7 +12,7 @@ class TestOpenAIServer(unittest.TestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
cls.model = MODEL_NAME_FOR_TEST cls.model = MODEL_NAME_FOR_TEST
cls.base_url = f"http://localhost:30000" cls.base_url = f"http://localhost:8157"
cls.api_key = "sk-123456" cls.api_key = "sk-123456"
cls.process = popen_launch_server( cls.process = popen_launch_server(
cls.model, cls.base_url, timeout=300, api_key=cls.api_key cls.model, cls.base_url, timeout=300, api_key=cls.api_key
......
...@@ -12,11 +12,9 @@ class TestSRTEndpoint(unittest.TestCase): ...@@ -12,11 +12,9 @@ class TestSRTEndpoint(unittest.TestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
port = 30000
cls.model = MODEL_NAME_FOR_TEST cls.model = MODEL_NAME_FOR_TEST
cls.base_url = f"http://localhost:{port}" cls.base_url = f"http://localhost:{8157}"
cls.process = popen_launch_server(cls.model, port, timeout=300) cls.process = popen_launch_server(cls.model, cls.base_url, timeout=300)
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls):
......
...@@ -11,7 +11,7 @@ class TestAccuracy(unittest.TestCase): ...@@ -11,7 +11,7 @@ class TestAccuracy(unittest.TestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
cls.model = MODEL_NAME_FOR_TEST cls.model = MODEL_NAME_FOR_TEST
cls.base_url = f"http://localhost:30000" cls.base_url = f"http://localhost:8157"
cls.process = popen_launch_server( cls.process = popen_launch_server(
cls.model, cls.base_url, timeout=300, other_args=["--enable-torch-compile"] cls.model, cls.base_url, timeout=300, other_args=["--enable-torch-compile"]
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment