Unverified Commit bb66cc4c authored by Liangsheng Yin, committed by GitHub

Fix CI && python3.8 compatible (#920)

parent 975adb80
@@ -57,4 +57,4 @@ jobs:
cd $HOME && python3 -m sglang.bench_serving --backend sglang --port 8413 --dataset-name random --num-prompts 3000 --random-input 256 --random-output 512
echo "Stopping server..."
-kill -9 $(ps aux | grep sglang | grep Meta-Llama-3.1-8B-Instruct | grep -v grep | awk '{print $2}')
+kill -9 $(ps aux | grep sglang | grep Meta-Llama-3.1-8B-Instruct | grep -- "--port 8413" | grep -v grep | awk '{print $2}')
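The extra grep -- "--port 8413" filter narrows the kill to the benchmark server launched on port 8413 (the "--" keeps grep from parsing "--port" as one of its own options), so other sglang processes on the runner are left alone. A rough Python equivalent of that narrowed kill, assuming psutil is available on the runner; this helper is illustrative only and is not part of the commit:

# Illustrative helper (not in this commit): kill only sglang processes whose
# command line mentions the benchmark port, mirroring the narrowed grep chain.
import signal

import psutil  # assumption: psutil is installed on the CI runner

def kill_server_on_port(port: int = 8413) -> None:
    needle = f"--port {port}"
    for proc in psutil.process_iter(["pid", "cmdline"]):
        cmdline = " ".join(proc.info["cmdline"] or [])
        if "sglang" in cmdline and needle in cmdline:
            proc.send_signal(signal.SIGKILL)  # same effect as kill -9

kill_server_on_port(8413)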
@@ -71,6 +71,7 @@
"source": [
"import json\n",
"import os\n",
"from typing import List\n",
"\n",
"import chromadb\n",
"\n",
@@ -148,7 +149,7 @@
"outputs": [],
"source": [
"@trace\n",
"def retrieval(question: str) -> list[str]:\n",
"def retrieval(question: str) -> List[str]:\n",
" return collection.query(\n",
" query_texts=[question],\n",
" n_results=1\n",
@@ -278,7 +279,7 @@
"\n",
"\n",
"@trace(eval_funcs=[context_relevancy_eval, percent_target_supported_by_context])\n",
"def retrieval(question: str) -> list[str]:\n",
"def retrieval(question: str) -> List[str]:\n",
" return collection.query(\n",
" query_texts=[question],\n",
" n_results=1\n",
......
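The notebook hunks above are the heart of the Python 3.8 fix: built-in generics such as list[str] only arrived in Python 3.9 (PEP 585), so evaluating that annotation on 3.8 raises "TypeError: 'type' object is not subscriptable", while typing.List works on both versions. A minimal illustration of the difference, not taken from the notebook itself:

import sys
from typing import List

# On 3.8, subscripting the built-in list type fails when the annotation is
# evaluated; the typing.List spelling below is accepted by 3.8 and 3.9+ alike.
if sys.version_info < (3, 9):
    try:
        list[str]
    except TypeError as exc:
        print("3.8 rejects built-in generics:", exc)

def retrieval(question: str) -> List[str]:
    # stand-in body; the real notebook queries a chromadb collection
    return [question]

print(retrieval("what is RAG?"))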
@@ -19,7 +19,7 @@ import functools
import json
import os
import warnings
-from typing import AbstractSet, Collection, Dict, Literal, Optional, Type, Union
+from typing import AbstractSet, Collection, Dict, List, Literal, Optional, Type, Union
from huggingface_hub import snapshot_download
from transformers import (
@@ -259,7 +259,7 @@ class TiktokenTokenizer:
Literal["all"], AbstractSet[str]
] = set(), # noqa: B006
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
-) -> list[int]:
+) -> List[int]:
if isinstance(allowed_special, set):
allowed_special |= self._default_allowed_special
return tiktoken.Encoding.encode(
......
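The same substitution lands in the encode wrapper of TiktokenTokenizer. An alternative this commit does not take is PEP 563's postponed evaluation, which keeps annotations as strings so list[int] stops failing on 3.8; a small sketch of that route, purely for comparison:

from __future__ import annotations  # annotations stay as unevaluated strings

# Because the return annotation is never evaluated, this definition also
# succeeds on Python 3.8 despite using the built-in generic.
def encode(text: str) -> list[int]:
    return [ord(ch) for ch in text]

print(encode("hi"))

That only covers annotations, though; the next file shows why the typing generics are still needed elsewhere.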
@@ -7,7 +7,7 @@ import time
from collections import defaultdict
from dataclasses import dataclass, field
from multiprocessing.pool import ThreadPool
-from typing import Any
+from typing import Any, Dict, List, Tuple
import httpx
import jinja2
@@ -24,8 +24,8 @@ OPENAI_SYSTEM_MESSAGE_CHATGPT = (
)
-Message = dict[str, Any] # keys role, content
-MessageList = list[Message]
+Message = Dict[str, Any] # keys role, content
+MessageList = List[Message]
class SamplerBase:
@@ -45,9 +45,9 @@ class EvalResult:
"""
score: float | None # top-line metric
-metrics: dict[str, float] | None # other metrics
-htmls: list[str] # strings of valid HTML
-convos: list[MessageList] # sampled conversations
+metrics: Dict[str, float] | None # other metrics
+htmls: List[str] # strings of valid HTML
+convos: List[MessageList] # sampled conversations
@dataclass
@@ -57,7 +57,7 @@ class SingleEvalResult:
"""
score: float | None
-metrics: dict[str, float] = field(default_factory=dict)
+metrics: Dict[str, float] = field(default_factory=dict)
html: str | None = None
convo: MessageList | None = None # sampled conversation
@@ -270,9 +270,9 @@ def _compute_stat(values: list, stat: str):
def aggregate_results(
-single_eval_results: list[SingleEvalResult],
-default_stats: tuple[str] = ("mean", "std"),
-name2stats: dict[str, tuple[str]] | None = None,
+single_eval_results: List[SingleEvalResult],
+default_stats: Tuple[str] = ("mean", "std"),
+name2stats: Dict[str, Tuple[str]] | None = None,
) -> EvalResult:
"""
Aggregate results from multiple evaluations into a single EvalResult.
@@ -302,7 +302,7 @@
)
-def map_with_progress(f: callable, xs: list[Any], num_threads: int):
+def map_with_progress(f: callable, xs: List[Any], num_threads: int):
"""
Apply f to each element of xs, using a ThreadPool, and show progress.
"""
@@ -422,7 +422,7 @@ def make_report(eval_result: EvalResult) -> str:
)
-def make_report_from_example_htmls(htmls: list[str]):
+def make_report_from_example_htmls(htmls: List[str]):
"""
Create a standalone HTML report from a list of example htmls
"""
......
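In simple_eval_common the generics also appear outside annotations: Message and MessageList are module-level type aliases, and aliases are evaluated eagerly at import time, so even postponed evaluation would not rescue dict[str, Any] on 3.8. The typing spellings from the diff are the portable form; a small mirror of those aliases:

from typing import Any, Dict, List

# Mirrors the aliases from the diff; evaluated at import time, so they must
# use typing generics to stay importable on Python 3.8.
Message = Dict[str, Any]  # keys: role, content
MessageList = List[Message]

convo: MessageList = [{"role": "user", "content": "hello"}]
print(convo)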
@@ -14,7 +14,7 @@ import re
from collections import Counter, defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed
from io import BytesIO
-from typing import Any, Tuple
+from typing import Any, Dict, List, Tuple
import blobfile as bf
import tqdm
@@ -38,8 +38,8 @@ from sglang.test.simple_eval_common import (
def evaluate_functional_correctness(
-sample: dict[str, str],
-completions: list[str],
+sample: Dict[str, str],
+completions: List[str],
n_workers: int = 4,
timeout: float = 3.0,
):
@@ -70,7 +70,7 @@ class HumanEval(Eval):
num_examples: int | None,
num_threads: int,
num_samples_per_task: int = 5,
-ks_passes: list[int] = [1, 2, 5],
+ks_passes: List[int] = [1, 2, 5],
timeout: int = 120,
):
self.seed = 0
@@ -97,7 +97,7 @@ class HumanEval(Eval):
] # remove signature
return extracted_answer
-def fn(sample: dict[str, str]):
+def fn(sample: Dict[str, str]):
prompt_messages = [
sampler._pack_message(
role="user", content=instruction + sample["prompt"]
......
@@ -8,7 +8,7 @@ import threading
import time
import unittest
from functools import partial
-from typing import Callable, Optional
+from typing import Callable, List, Optional
import numpy as np
import requests
@@ -457,7 +457,7 @@ def run_with_timeout(
return ret_value[0]
-def run_unittest_files(files: list[str], timeout_per_file: float):
+def run_unittest_files(files: List[str], timeout_per_file: float):
tic = time.time()
success = True
......
@@ -11,7 +11,7 @@ class TestAccuracy(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.model = MODEL_NAME_FOR_TEST
cls.base_url = f"http://localhost:30000"
cls.base_url = f"http://localhost:8157"
cls.process = popen_launch_server(
cls.model,
cls.base_url,
......
@@ -11,7 +11,7 @@ class TestAccuracy(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.model = MODEL_NAME_FOR_TEST
cls.base_url = f"http://localhost:30000"
cls.base_url = f"http://localhost:8157"
cls.process = popen_launch_server(cls.model, cls.base_url, timeout=300)
@classmethod
......
@@ -12,7 +12,7 @@ class TestOpenAIServer(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.model = MODEL_NAME_FOR_TEST
cls.base_url = f"http://localhost:30000"
cls.base_url = f"http://localhost:8157"
cls.api_key = "sk-123456"
cls.process = popen_launch_server(
cls.model, cls.base_url, timeout=300, api_key=cls.api_key
......
@@ -12,11 +12,9 @@ class TestSRTEndpoint(unittest.TestCase):
@classmethod
def setUpClass(cls):
-port = 30000
cls.model = MODEL_NAME_FOR_TEST
-cls.base_url = f"http://localhost:{port}"
-cls.process = popen_launch_server(cls.model, port, timeout=300)
+cls.base_url = f"http://localhost:{8157}"
+cls.process = popen_launch_server(cls.model, cls.base_url, timeout=300)
@classmethod
def tearDownClass(cls):
......
@@ -11,7 +11,7 @@ class TestAccuracy(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.model = MODEL_NAME_FOR_TEST
cls.base_url = f"http://localhost:30000"
cls.base_url = f"http://localhost:8157"
cls.process = popen_launch_server(
cls.model, cls.base_url, timeout=300, other_args=["--enable-torch-compile"]
)
......
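The remaining hunks move the unit tests from the hard-coded port 30000 to 8157 and, in TestSRTEndpoint, pass the full base URL to popen_launch_server as the other tests already do, presumably so parallel CI jobs stop colliding on a single port. If hard-coding any port is undesirable, one option not used by this commit is to let the OS hand out a free one:

import socket

# Hypothetical helper (not in this commit): bind to port 0 so the kernel
# picks an unused port, then build the server base URL from it.
def find_free_port() -> int:
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("localhost", 0))
        return sock.getsockname()[1]

base_url = f"http://localhost:{find_free_port()}"
print(base_url)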