Unverified Commit c77c1e05 authored by Chayenne, committed by GitHub

fix black in pre-commit (#1940)

parent dca87ec3
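
The hunks below are black's output: calls that exceed the default 88-character line length are split one argument per line and given trailing commas, while short multi-line calls without a trailing comma are collapsed back onto one line. As a rough illustration (not part of this commit), the same transformation can be reproduced with black's format_str API, using one of the over-long calls from this diff as input:

# Minimal sketch, not part of this commit: reproduce the reformatting that
# `pre-commit run black` applies in the hunks below. Requires `pip install black`.
import black

# One of the over-long calls from this diff, as a source string.
source = (
    "run_and_check_memory_leak(workload_func, disable_radix_cache, "
    "enable_mixed_chunk, enable_overlap, chunked_prefill_size)\n"
)

# black.Mode() uses the default 88-character line length, so the call is
# exploded one argument per line with a trailing comma, matching the diff.
print(black.format_str(source, mode=black.Mode()))
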
@@ -448,7 +448,7 @@ class ServerArgs:
             "--decode-log-interval",
             type=int,
             default=ServerArgs.decode_log_interval,
-            help="The log interval of decode batch"
+            help="The log interval of decode batch",
         )
         # Data parallelism
......
@@ -742,7 +742,13 @@ def run_mmlu_test(
         finally:
             pass

-    run_and_check_memory_leak(workload_func, disable_radix_cache, enable_mixed_chunk, enable_overlap, chunked_prefill_size)
+    run_and_check_memory_leak(
+        workload_func,
+        disable_radix_cache,
+        enable_mixed_chunk,
+        enable_overlap,
+        chunked_prefill_size,
+    )


 def run_mulit_request_test(
@@ -775,4 +781,10 @@ def run_mulit_request_test(
         with ThreadPoolExecutor(2) as executor:
             list(executor.map(run_one, list(range(4))))

-    run_and_check_memory_leak(workload_func, disable_radix_cache, enable_mixed_chunk, enable_overlap, chunked_prefill_size)
+    run_and_check_memory_leak(
+        workload_func,
+        disable_radix_cache,
+        enable_mixed_chunk,
+        enable_overlap,
+        chunked_prefill_size,
+    )
@@ -349,6 +349,7 @@ def wait_for_server(base_url: str, timeout: int = None) -> None:
 def terminate_process(process):
     from sglang.srt.utils import kill_child_process
     kill_child_process(process.pid, include_self=True)
......
@@ -11,7 +11,7 @@ router = router.Router(
         "http://localhost:30000",
         "http://localhost:30002",
     ],
-    policy="random"
+    policy="random",
 )

 # Start the router - this will block and run the server
......
@@ -104,15 +104,9 @@ if __name__ == "__main__":
         default="TinyLlama/TinyLlama-1.1B-Chat-v0.4",
         # default="meta-llama/Llama-2-7b-chat-hf",
     )
-    parser.add_argument(
-        "--max-new-tokens",
-        type=int,
-        default=16)
+    parser.add_argument("--max-new-tokens", type=int, default=16)
-    parser.add_argument(
-        "--dtype",
-        type=str,
-        default="float16")
+    parser.add_argument("--dtype", type=str, default="float16")
     args = parser.parse_args()
......
@@ -56,7 +56,7 @@ ALL_OTHER_MODELS = [
     ModelCase("HuggingFaceTB/SmolLM-135M-Instruct", skip_long_prompt=True),
     ModelCase("allenai/OLMo-1B-0724-hf", decode_tolerance=8e-2, skip_long_prompt=True),
     ModelCase("THUDM/glm-4-9b-chat"),
-    ModelCase("openai-community/gpt2")
+    ModelCase("openai-community/gpt2"),
 ]

 TORCH_DTYPES = [torch.float16]
......
@@ -3,6 +3,7 @@ python3 -m unittest test_openai_server.TestOpenAIServer.test_batch
 python3 -m unittest test_openai_server.TestOpenAIServer.test_completion
 """
 import json
 import time
 import unittest
......
"""
python3 -m unittest test_skip_tokenizer_init.TestSkipTokenizerInit.test_parallel_sample
"""
import json
import unittest
......
@@ -110,7 +110,6 @@ class TestSRTEngine(unittest.TestCase):
     def test_5_prompt_input_ids_consistency(self):
         prompt = "The capital of UK is"
         model_path = DEFAULT_MODEL_NAME_FOR_TEST
         engine = sgl.Engine(model_path=model_path, random_seed=42, log_level="error")
         sampling_params = {"temperature": 0, "max_new_tokens": 8}
@@ -118,7 +117,9 @@ class TestSRTEngine(unittest.TestCase):
         tokenizer = get_tokenizer(model_path)
         token_ids = tokenizer.encode(prompt)
-        out2 = engine.generate(input_ids=token_ids, sampling_params=sampling_params)["text"]
+        out2 = engine.generate(input_ids=token_ids, sampling_params=sampling_params)[
+            "text"
+        ]
         engine.shutdown()
......