Unverified Commit 3694f8f9, authored by Liangsheng Yin, committed by GitHub

Mixed style of chunked prefill (#1013)

parent 5a261bd0
@@ -111,11 +111,14 @@ class PrefillAdder:
         rem_total_tokens: int,
         rem_input_tokens: int,
         rem_chunk_tokens: Optional[int],
+        mixed_with_decode_tokens: int = 0,
     ):
         self.tree_cache = tree_cache
-        self.rem_total_tokens = rem_total_tokens
-        self.rem_input_tokens = rem_input_tokens
+        self.rem_total_tokens = rem_total_tokens - mixed_with_decode_tokens
+        self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
         self.rem_chunk_tokens = rem_chunk_tokens
+        if self.rem_chunk_tokens is not None:
+            self.rem_chunk_tokens -= mixed_with_decode_tokens

         self.can_run_list = []
         self.new_inflight_req = None
...
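For intuition, a minimal stand-in for the budget accounting introduced by `mixed_with_decode_tokens` (the class below is illustrative, not the real `PrefillAdder`): each in-flight decode request reserves one token out of every prefill budget before new requests are admitted.

```python
# Illustrative stand-in for the budget math above, not the actual PrefillAdder.
from typing import Optional


class MixedPrefillBudget:
    def __init__(
        self,
        rem_total_tokens: int,
        rem_input_tokens: int,
        rem_chunk_tokens: Optional[int],
        mixed_with_decode_tokens: int = 0,
    ):
        # Every running decode request generates one token in this step,
        # so the prefill side starts with a correspondingly smaller allowance.
        self.rem_total_tokens = rem_total_tokens - mixed_with_decode_tokens
        self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
        self.rem_chunk_tokens = rem_chunk_tokens
        if self.rem_chunk_tokens is not None:
            self.rem_chunk_tokens -= mixed_with_decode_tokens


budget = MixedPrefillBudget(4096, 4096, rem_chunk_tokens=512, mixed_with_decode_tokens=8)
assert budget.rem_chunk_tokens == 504  # 8 decode tokens share the 512-token chunk
```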
@@ -329,6 +329,9 @@ class ScheduleBatch:
     out_cache_loc: torch.Tensor = None
     extend_num_tokens: int = None

+    # For mixed chunked prefill
+    prefix_lens_cpu: List[int] = None
+
     # For processing logprobs
     return_logprob: bool = False
     top_logprobs_nums: List[int] = None
@@ -462,9 +465,33 @@ class ScheduleBatch:
         self.extend_num_tokens = extend_num_tokens
         self.out_cache_loc = out_cache_loc
         self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
+        self.prefix_lens_cpu = [len(r.prefix_indices) for r in reqs]

         self.batch_sampling_params(vocab_size)

+    def mix_with_running(self, running_batch: "ScheduleBatch"):
+        # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
+        prefix_lens_cpu = [len(r.prefix_indices) for r in self.reqs]
+        prefix_lens_cpu.extend(
+            [
+                len(r.origin_input_ids) + len(r.output_ids) - 1
+                for r in running_batch.reqs
+            ]
+        )
+
+        for req in running_batch.reqs:
+            req.fill_ids = req.origin_input_ids + req.output_ids
+            req.extend_input_len = 1
+
+        input_ids = torch.cat([self.input_ids, running_batch.input_ids])
+        out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])
+        extend_num_tokens = self.extend_num_tokens + running_batch.batch_size()
+
+        self.merge(running_batch)
+        self.input_ids = input_ids
+        self.out_cache_loc = out_cache_loc
+        self.extend_num_tokens = extend_num_tokens
+        self.prefix_lens_cpu = prefix_lens_cpu
+
     def check_decode_mem(self):
         bs = self.batch_size()
         if self.token_to_kv_pool.available_size() >= bs:
...
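A toy example (made-up token ids) of the prefix-length rule used in `mix_with_running`: a decoding request extends by exactly one token, and everything before it counts as already-cached prefix.

```python
# Toy decoding request (made-up token ids) illustrating the bookkeeping in
# mix_with_running(): extend length is 1, prefix covers all earlier tokens.
origin_input_ids = [11, 12, 13, 14]  # 4 prompt tokens
output_ids = [21, 22, 23]            # 3 tokens generated so far

fill_ids = origin_input_ids + output_ids                  # full sequence, 7 tokens
extend_input_len = 1                                      # only the newest token is extended
prefix_len = len(origin_input_ids) + len(output_ids) - 1  # 6 "cached" positions

assert prefix_len + extend_input_len == len(fill_ids)
```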
@@ -174,6 +174,9 @@ class ModelTpServer:
         # Chunked prefill
         self.chunked_prefill_size = server_args.chunked_prefill_size
         self.current_inflight_req = None
+        self.is_mixed_chunk = (
+            self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
+        )

         # Init the FSM cache for constrained generation
         if not server_args.skip_tokenizer_init:
@@ -366,11 +369,14 @@ class ModelTpServer:
         # Get priority queue
         prefix_computed = self.scheduler.calc_priority(self.waiting_queue)

+        num_mixed_running = running_bs if self.is_mixed_chunk else 0
+
         adder = PrefillAdder(
             self.tree_cache,
             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
             self.max_prefill_tokens,
             self.chunked_prefill_size,
+            num_mixed_running,
         )

         if self.running_batch is not None:
@@ -416,15 +422,27 @@ class ModelTpServer:
             )
         else:
             tree_cache_hit_rate = 0.0
-        logger.info(
-            f"[gpu={self.gpu_id}] Prefill batch. "
-            f"#new-seq: {len(can_run_list)}, "
-            f"#new-token: {adder.log_input_tokens}, "
-            f"#cached-token: {adder.log_hit_tokens}, "
-            f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
-            f"#running-req: {running_bs}, "
-            f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
-        )
+
+        if num_mixed_running > 0:
+            logger.info(
+                f"[gpu={self.gpu_id}] Prefill batch"
+                f"(mixed #running-req: {num_mixed_running}). "
+                f"#new-seq: {len(can_run_list)}, "
+                f"#new-token: {adder.log_input_tokens}, "
+                f"#cached-token: {adder.log_hit_tokens}, "
+                f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
+                f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
+            )
+        else:
+            logger.info(
+                f"[gpu={self.gpu_id}] Prefill batch. "
+                f"#new-seq: {len(can_run_list)}, "
+                f"#new-token: {adder.log_input_tokens}, "
+                f"#cached-token: {adder.log_hit_tokens}, "
+                f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
+                f"#running-req: {running_bs}, "
+                f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
+            )

         # Return the new batch
         new_batch = ScheduleBatch.init_new(
@@ -440,6 +458,13 @@ class ModelTpServer:
         # Build batch tensors
         batch.prepare_for_extend(self.model_config.vocab_size)

+        decoding_reqs = []
+        if self.is_mixed_chunk and self.running_batch is not None:
+            self.running_batch.prepare_for_decode()
+            batch.mix_with_running(self.running_batch)
+            decoding_reqs = self.running_batch.reqs
+            self.running_batch = None
+
         if self.model_runner.is_generation:
             # Forward and sample the next tokens
             if batch.extend_num_tokens != 0:
@@ -481,7 +506,8 @@ class ModelTpServer:
                 if req.finished():
                     self.tree_cache.cache_finished_req(req)
-                else:
+                elif req not in decoding_reqs:
+                    # To reduce overhead, only cache prefill reqs
                     self.tree_cache.cache_unfinished_req(req)

                 if req is self.current_inflight_req:
...
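A condensed, hypothetical sketch of the control flow added above; `new_batch` and `running_batch` stand in for the real `ScheduleBatch` objects, and only the ordering of the calls matters.

```python
# Hypothetical simplification of the mixed-chunk step in the server loop;
# the arguments stand in for the real ScheduleBatch objects.
def mix_step(new_batch, running_batch, is_mixed_chunk):
    decoding_reqs = []
    if is_mixed_chunk and running_batch is not None:
        running_batch.prepare_for_decode()         # allocate this step's decode token slots
        new_batch.mix_with_running(running_batch)  # fold decode reqs into the extend batch
        decoding_reqs = running_batch.reqs         # remember them: skip tree-cache updates later
        running_batch = None                       # the old decode batch has been consumed
    return new_batch, running_batch, decoding_reqs
```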
@@ -88,11 +88,11 @@ class InputMetadata:
             self.image_sizes = [r.image_size for r in reqs]
             self.image_offsets = [
                 (
-                    (r.image_offset - len(r.prefix_indices))
+                    (r.image_offset - batch.prefix_lens_cpu[i])
                     if r.image_offset is not None
                     else 0
                 )
-                for r in reqs
+                for i, r in enumerate(reqs)
             ]

     def compute_positions(self, batch: ScheduleBatch):
@@ -109,8 +109,8 @@ class InputMetadata:
             self.positions = torch.tensor(
                 np.concatenate(
                     [
-                        np.arange(len(req.prefix_indices), len(req.fill_ids))
-                        for req in batch.reqs
+                        np.arange(batch.prefix_lens_cpu[i], len(req.fill_ids))
+                        for i, req in enumerate(batch.reqs)
                     ],
                     axis=0,
                 ),
@@ -123,7 +123,7 @@ class InputMetadata:
                 np.concatenate(
                     [
                         np.arange(
-                            len(req.prefix_indices) + position_ids_offsets_cpu[i],
+                            batch.prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
                             len(req.fill_ids) + position_ids_offsets_cpu[i],
                         )
                         for i, req in enumerate(batch.reqs)
@@ -141,12 +141,13 @@ class InputMetadata:
             self.extend_seq_lens = self.extend_start_loc = self.extend_no_prefix = None
         else:
             extend_lens_cpu = [
-                len(r.fill_ids) - len(r.prefix_indices) for r in batch.reqs
+                len(r.fill_ids) - batch.prefix_lens_cpu[i]
+                for i, r in enumerate(batch.reqs)
             ]
             self.extend_seq_lens = torch.tensor(extend_lens_cpu, device="cuda")
             self.extend_start_loc = torch.zeros_like(self.seq_lens)
             self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
-            self.extend_no_prefix = all(len(r.prefix_indices) == 0 for r in batch.reqs)
+            self.extend_no_prefix = all(l == 0 for l in batch.prefix_lens_cpu)

     @classmethod
     def from_schedule_batch(
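A small worked example (made-up lengths) of the position and extend-length computation above, driven by `prefix_lens_cpu`: in a mixed batch, the extended positions are exactly the non-cached ones, including the single new token of each decode request.

```python
# Made-up lengths: request 0 is a chunked prefill (2 new tokens), request 1 is
# a decode request folded into the same batch (1 new token).
import numpy as np

fill_lens = [6, 9]        # len(req.fill_ids) per request
prefix_lens_cpu = [4, 8]  # cached tokens per request

positions = np.concatenate(
    [np.arange(prefix_lens_cpu[i], fill_lens[i]) for i in range(len(fill_lens))]
)
extend_lens_cpu = [fill_lens[i] - prefix_lens_cpu[i] for i in range(len(fill_lens))]

print(positions.tolist())   # [4, 5, 8]
print(extend_lens_cpu)      # [2, 1]
```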
@@ -180,14 +181,8 @@ class InputMetadata:
         if forward_mode != ForwardMode.DECODE:
             ret.init_multimuldal_info(batch)

-        prefix_lens = None
-        if forward_mode != ForwardMode.DECODE:
-            prefix_lens = torch.tensor(
-                [len(r.prefix_indices) for r in batch.reqs], device="cuda"
-            )
-
         if model_runner.server_args.disable_flashinfer:
-            ret.init_triton_args(batch, prefix_lens)
+            ret.init_triton_args(batch)

         flashinfer_use_ragged = False
         if not model_runner.server_args.disable_flashinfer:
@@ -198,30 +193,35 @@ class InputMetadata:
             ):
                 flashinfer_use_ragged = True
             ret.init_flashinfer_handlers(
-                model_runner, prefix_lens, flashinfer_use_ragged
+                model_runner, batch.prefix_lens_cpu, flashinfer_use_ragged
             )

         return ret

-    def init_triton_args(self, batch: ScheduleBatch, prefix_lens):
+    def init_triton_args(self, batch: ScheduleBatch):
         """Init auxiliary variables for triton attention backend."""
         self.triton_max_seq_len = int(torch.max(self.seq_lens))
-        self.triton_prefix_lens = prefix_lens
         self.triton_start_loc = torch.zeros_like(self.seq_lens, dtype=torch.int32)
         self.triton_start_loc[1:] = torch.cumsum(self.seq_lens[:-1], dim=0)

         if self.forward_mode == ForwardMode.DECODE:
             self.triton_max_extend_len = None
         else:
-            extend_seq_lens = self.seq_lens - prefix_lens
+            self.triton_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
+            extend_seq_lens = self.seq_lens - self.triton_prefix_lens
             self.triton_max_extend_len = int(torch.max(extend_seq_lens))

     def init_flashinfer_handlers(
         self,
         model_runner,
-        prefix_lens,
+        prefix_lens_cpu,
         flashinfer_use_ragged,
     ):
+        if self.forward_mode != ForwardMode.DECODE:
+            prefix_lens = torch.tensor(prefix_lens_cpu, device="cuda")
+        else:
+            prefix_lens = None
+
         update_flashinfer_indices(
             self.forward_mode,
             model_runner,
...
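A toy stand-in for the branch added to `init_flashinfer_handlers`: the CPU prefix-length list is only materialized as a device tensor for non-decode passes. Strings replace `ForwardMode` and `device="cpu"` replaces `"cuda"` so the sketch runs anywhere.

```python
# Toy stand-in for the branch added to init_flashinfer_handlers(); "decode"
# replaces ForwardMode.DECODE and device="cpu" replaces "cuda" for portability.
import torch


def prefix_lens_for_pass(forward_mode: str, prefix_lens_cpu, device: str = "cpu"):
    if forward_mode != "decode":
        return torch.tensor(prefix_lens_cpu, device=device)
    return None  # pure decode passes never need per-request prefix lengths


print(prefix_lens_for_pass("extend", [4, 8]))  # tensor([4, 8])
print(prefix_lens_for_pass("decode", [4, 8]))  # None
```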
@@ -445,15 +445,6 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
         print(f"Initialization failed. warmup error: {last_traceback}", flush=True)
         sys.exit(1)

-    # Print warnings here
-    if server_args.disable_radix_cache and server_args.chunked_prefill_size is not None:
-        logger.warning(
-            "You set both `--disable-radix-cache` and `--chunked-prefill-size`. "
-            "This combination is an experimental feature and we noticed it can lead to "
-            "wrong generation results. If you want to use chunked prefill, it is recommended "
-            "not using `--disable-radix-cache`."
-        )
-
     logger.info("The server is fired up and ready to roll!")
     if pipe_finish_writer is not None:
         pipe_finish_writer.send("init ok")
...
@@ -80,6 +80,7 @@ class ServerArgs:
     disable_regex_jump_forward: bool = False
     disable_cuda_graph: bool = False
     disable_disk_cache: bool = False
+    enable_mixed_chunk: bool = False
     enable_torch_compile: bool = False
     enable_p2p_check: bool = False
     enable_mla: bool = False
@@ -396,6 +397,11 @@ class ServerArgs:
             action="store_true",
             help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
         )
+        parser.add_argument(
+            "--enable-mixed-chunk",
+            action="store_true",
+            help="Enable mixing prefill and decode in a chunked batch.",
+        )
         parser.add_argument(
             "--enable-torch-compile",
             action="store_true",
...
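A hedged usage sketch of the new flag; the model path and chunk size below are placeholders, and only `enable_mixed_chunk` and `chunked_prefill_size` come from this diff. The flag only takes effect when chunked prefill is enabled, since `is_mixed_chunk` also requires `chunked_prefill_size` to be set.

```python
# Hypothetical usage sketch: constructing ServerArgs with the new flag enabled.
# The model path and chunk size are placeholders; other fields keep their defaults.
from sglang.srt.server_args import ServerArgs

args = ServerArgs(
    model_path="/path/to/model",   # placeholder
    chunked_prefill_size=8192,     # placeholder chunk budget; must not be None
    enable_mixed_chunk=True,       # equivalent to passing --enable-mixed-chunk
)
print(args.chunked_prefill_size, args.enable_mixed_chunk)
```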
 # Adapted from https://github.com/openai/simple-evals/

-import base64
 import os
 import resource
 import time
 from collections import defaultdict
 from dataclasses import dataclass, field
 from multiprocessing.pool import ThreadPool
-from typing import Any, Dict, List, Tuple
+from typing import Any, Dict, List, Optional, Tuple

 import httpx
 import jinja2
@@ -44,8 +43,8 @@ class EvalResult:
     Result of running an evaluation (usually consisting of many samples)
     """

-    score: float | None  # top-line metric
-    metrics: Dict[str, float] | None  # other metrics
+    score: Optional[float]  # top-line metric
+    metrics: Optional[Dict[str, float]]  # other metrics
     htmls: List[str]  # strings of valid HTML
     convos: List[MessageList]  # sampled conversations
@@ -56,10 +55,10 @@ class SingleEvalResult:
     Result of evaluating a single sample
     """

-    score: float | None
+    score: Optional[float]
     metrics: Dict[str, float] = field(default_factory=dict)
-    html: str | None = None
-    convo: MessageList | None = None  # sampled conversation
+    html: Optional[str] = None
+    convo: Optional[MessageList] = None  # sampled conversation

 class Eval:
@@ -89,8 +88,8 @@ class ChatCompletionSampler(SamplerBase):
     def __init__(
         self,
         base_url: str = None,
-        model: str | None = None,
-        system_message: str | None = None,
+        model: Optional[str] = None,
+        system_message: Optional[str] = None,
         temperature: float = 0.0,
         max_tokens: int = 2048,
     ):
@@ -272,7 +271,7 @@ def _compute_stat(values: list, stat: str):
 def aggregate_results(
     single_eval_results: List[SingleEvalResult],
     default_stats: Tuple[str] = ("mean", "std"),
-    name2stats: Dict[str, Tuple[str]] | None = None,
+    name2stats: Optional[Dict[str, Tuple[str]]] = None,
 ) -> EvalResult:
     """
     Aggregate results from multiple evaluations into a single EvalResult.
...
@@ -8,6 +8,7 @@ https://arxiv.org/abs/2311.12022
 import random
 import re
+from typing import Optional

 import pandas
@@ -28,7 +29,7 @@ class GPQAEval(Eval):
     def __init__(
         self,
         filename: str,
-        num_examples: int | None,
+        num_examples: Optional[int],
         num_threads: int,
         n_repeats: int = 1,
     ):
...
@@ -9,7 +9,7 @@ https://arxiv.org/abs/2107.03374 https://github.com/openai/human-eval/
 import random
 import re
 from concurrent.futures import ThreadPoolExecutor, as_completed
-from typing import Dict, List
+from typing import Dict, List, Optional

 import tqdm
@@ -61,7 +61,7 @@ def evaluate_functional_correctness(
 class HumanEval(Eval):
     def __init__(
         self,
-        num_examples: int | None,
+        num_examples: Optional[int],
         num_threads: int,
         num_samples_per_task: int = 5,
         ks_passes: List[int] = [1, 2, 5],
...
@@ -8,6 +8,7 @@ https://arxiv.org/abs/2103.03874
 import random
 import re
+from typing import Optional

 import pandas
@@ -36,7 +37,7 @@ class MathEval(Eval):
         self,
         filename: str,
         equality_checker: SamplerBase,
-        num_examples: int | None,
+        num_examples: Optional[int],
         num_threads: int,
     ):
         df = pandas.read_csv(filename)
...
@@ -8,6 +8,7 @@ https://arxiv.org/abs/2009.03300
 import random
 import re
+from typing import Optional

 import pandas
@@ -84,7 +85,7 @@ subject2category = {
 class MMLUEval(Eval):
-    def __init__(self, filename: str, num_examples: int | None, num_threads: int):
+    def __init__(self, filename: str, num_examples: Optional[int], num_threads: int):
         df = pandas.read_csv(filename)
         examples = [row.to_dict() for _, row in df.iterrows()]
         if num_examples:
...
@@ -11,11 +11,14 @@ from sglang.test.test_utils import (
 class TestChunkedPrefill(unittest.TestCase):
-    def run_mmlu(self, disable_radix_cache):
+    def run_mmlu(self, disable_radix_cache, enable_mixed_chunk):
         other_args = ["--chunked-prefill-size", "32"]
         if disable_radix_cache:
             other_args += ["--disable-radix-cache"]
+        if enable_mixed_chunk:
+            other_args += ["--enable-mixed-chunk"]
+
         model = DEFAULT_MODEL_NAME_FOR_TEST
         base_url = DEFAULT_URL_FOR_UNIT_TEST
         process = popen_launch_server(
@@ -40,10 +43,16 @@ class TestChunkedPrefill(unittest.TestCase):
         kill_child_process(process.pid)

     def test_chunked_prefill(self):
-        self.run_mmlu(disable_radix_cache=False)
+        self.run_mmlu(disable_radix_cache=False, enable_mixed_chunk=False)
+
+    def test_mixed_chunked_prefill(self):
+        self.run_mmlu(disable_radix_cache=False, enable_mixed_chunk=True)

     def test_chunked_prefill_without_radix_cache(self):
-        self.run_mmlu(disable_radix_cache=True)
+        self.run_mmlu(disable_radix_cache=True, enable_mixed_chunk=False)
+
+    def test_mixed_chunked_prefill_without_radix_cache(self):
+        self.run_mmlu(disable_radix_cache=True, enable_mixed_chunk=True)

 if __name__ == "__main__":
...
@@ -6,7 +6,6 @@ from sglang.test.run_eval import run_eval
 from sglang.test.test_utils import (
     DEFAULT_MODEL_NAME_FOR_TEST,
     DEFAULT_URL_FOR_ACCURACY_TEST,
-    DEFAULT_URL_FOR_UNIT_TEST,
     popen_launch_server,
 )
...
+import unittest
+from types import SimpleNamespace
+
+from sglang.srt.utils import kill_child_process
+from sglang.test.run_eval import run_eval
+from sglang.test.test_utils import (
+    DEFAULT_MODEL_NAME_FOR_TEST,
+    DEFAULT_URL_FOR_ACCURACY_TEST,
+    popen_launch_server,
+)
+
+
+class TestEvalAccuracyLargeChunkedPrefill(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+        cls.base_url = DEFAULT_URL_FOR_ACCURACY_TEST
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=300,
+            other_args=[
+                "--log-level-http",
+                "warning",
+                "--chunked-prefill-size",
+                "256",
+                "--enable-mixed-chunk",
+            ],
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_child_process(cls.process.pid)
+
+    def test_mmlu(self):
+        args = SimpleNamespace(
+            base_url=self.base_url,
+            model=self.model,
+            eval_name="mmlu",
+            num_examples=3000,
+            num_threads=1024,
+        )
+        metrics = run_eval(args)
+        assert metrics["score"] >= 0.71, f"{metrics}"
+
+    def test_human_eval(self):
+        args = SimpleNamespace(
+            base_url=self.base_url,
+            model=self.model,
+            eval_name="humaneval",
+            num_examples=None,
+            num_threads=1024,
+        )
+        metrics = run_eval(args)
+        assert metrics["score"] >= 0.64, f"{metrics}"
+
+    def test_mgsm_en(self):
+        args = SimpleNamespace(
+            base_url=self.base_url,
+            model=self.model,
+            eval_name="mgsm_en",
+            num_examples=None,
+            num_threads=1024,
+        )
+        metrics = run_eval(args)
+        assert metrics["score"] >= 0.84, f"{metrics}"
+
+
+if __name__ == "__main__":
+    unittest.main()