Unverified Commit c77c1e05 authored by Chayenne, committed by GitHub

fix black in pre-commit (#1940)

parent dca87ec3
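
The hunks below appear to be mechanical output from the black formatter run via pre-commit: black splits calls that exceed the line-length limit across multiple lines and adds a trailing comma after the last argument. As a rough, illustrative sketch (not part of this commit), the same rewrite can be reproduced with black's Python API; note that black.format_str is an internal interface, so treat this as an assumption-laden example rather than the project's tooling:

# Illustrative sketch only: shows the kind of rewrite black applies to the
# long run_and_check_memory_leak(...) call changed in this commit.
# black.format_str is not a stable public API; behavior may vary by version.
import black

source = (
    "run_and_check_memory_leak(workload_func, disable_radix_cache, "
    "enable_mixed_chunk, enable_overlap, chunked_prefill_size)\n"
)

# With the default 88-character line length, black breaks the call across
# lines and appends a trailing comma after the final argument.
print(black.format_str(source, mode=black.FileMode()))
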
@@ -448,7 +448,7 @@ class ServerArgs:
             "--decode-log-interval",
             type=int,
             default=ServerArgs.decode_log_interval,
-            help="The log interval of decode batch"
+            help="The log interval of decode batch",
         )

         # Data parallelism

@@ -742,7 +742,13 @@ def run_mmlu_test(
     finally:
         pass

-    run_and_check_memory_leak(workload_func, disable_radix_cache, enable_mixed_chunk, enable_overlap, chunked_prefill_size)
+    run_and_check_memory_leak(
+        workload_func,
+        disable_radix_cache,
+        enable_mixed_chunk,
+        enable_overlap,
+        chunked_prefill_size,
+    )


 def run_mulit_request_test(

@@ -775,4 +781,10 @@ def run_mulit_request_test(
     with ThreadPoolExecutor(2) as executor:
         list(executor.map(run_one, list(range(4))))

-    run_and_check_memory_leak(workload_func, disable_radix_cache, enable_mixed_chunk, enable_overlap, chunked_prefill_size)
+    run_and_check_memory_leak(
+        workload_func,
+        disable_radix_cache,
+        enable_mixed_chunk,
+        enable_overlap,
+        chunked_prefill_size,
+    )

@@ -349,6 +349,7 @@ def wait_for_server(base_url: str, timeout: int = None) -> None:

 def terminate_process(process):
     from sglang.srt.utils import kill_child_process
+
     kill_child_process(process.pid, include_self=True)

@@ -11,7 +11,7 @@ router = router.Router(
         "http://localhost:30000",
         "http://localhost:30002",
     ],
-    policy="random"
+    policy="random",
 )

 # Start the router - this will block and run the server

@@ -104,15 +104,9 @@ if __name__ == "__main__":
         default="TinyLlama/TinyLlama-1.1B-Chat-v0.4",
         # default="meta-llama/Llama-2-7b-chat-hf",
     )
-    parser.add_argument(
-        "--max-new-tokens",
-        type=int,
-        default=16)
-    parser.add_argument(
-        "--dtype",
-        type=str,
-        default="float16")
+    parser.add_argument("--max-new-tokens", type=int, default=16)
+    parser.add_argument("--dtype", type=str, default="float16")
     args = parser.parse_args()

@@ -56,7 +56,7 @@ ALL_OTHER_MODELS = [
     ModelCase("HuggingFaceTB/SmolLM-135M-Instruct", skip_long_prompt=True),
     ModelCase("allenai/OLMo-1B-0724-hf", decode_tolerance=8e-2, skip_long_prompt=True),
     ModelCase("THUDM/glm-4-9b-chat"),
-    ModelCase("openai-community/gpt2")
+    ModelCase("openai-community/gpt2"),
 ]

 TORCH_DTYPES = [torch.float16]

@@ -3,6 +3,7 @@ python3 -m unittest test_openai_server.TestOpenAIServer.test_batch
 python3 -m unittest test_openai_server.TestOpenAIServer.test_completion
 """
+
 import json
 import time
 import unittest

""" """
python3 -m unittest test_skip_tokenizer_init.TestSkipTokenizerInit.test_parallel_sample python3 -m unittest test_skip_tokenizer_init.TestSkipTokenizerInit.test_parallel_sample
""" """
import json import json
import unittest import unittest
......
@@ -110,7 +110,6 @@ class TestSRTEngine(unittest.TestCase):
     def test_5_prompt_input_ids_consistency(self):
-
         prompt = "The capital of UK is"
         model_path = DEFAULT_MODEL_NAME_FOR_TEST
         engine = sgl.Engine(model_path=model_path, random_seed=42, log_level="error")
         sampling_params = {"temperature": 0, "max_new_tokens": 8}

@@ -118,7 +117,9 @@ class TestSRTEngine(unittest.TestCase):
         tokenizer = get_tokenizer(model_path)
         token_ids = tokenizer.encode(prompt)
-        out2 = engine.generate(input_ids=token_ids, sampling_params=sampling_params)["text"]
+        out2 = engine.generate(input_ids=token_ids, sampling_params=sampling_params)[
+            "text"
+        ]

         engine.shutdown()