Unverified Commit 3694f8f9 authored by Liangsheng Yin, committed by GitHub

Mixed style of chunked prefill (#1013)
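
When `--enable-mixed-chunk` is set together with `--chunked-prefill-size`, the scheduler folds the currently running decode batch into each chunked prefill batch, so a single forward pass both extends the prefill chunks and advances every in-flight decode request by one token. To support this, `PrefillAdder` reserves one token of budget per mixed-in decode request, `ScheduleBatch.mix_with_running` concatenates the two batches, and `InputMetadata` reads prefix lengths from `batch.prefix_lens_cpu` instead of `len(r.prefix_indices)`, because the per-step KV of decode requests is never written to the radix cache.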

parent 5a261bd0
......@@ -111,11 +111,14 @@ class PrefillAdder:
rem_total_tokens: int,
rem_input_tokens: int,
rem_chunk_tokens: Optional[int],
mixed_with_decode_tokens: int = 0,
):
self.tree_cache = tree_cache
self.rem_total_tokens = rem_total_tokens
self.rem_input_tokens = rem_input_tokens
self.rem_total_tokens = rem_total_tokens - mixed_with_decode_tokens
self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
self.rem_chunk_tokens = rem_chunk_tokens
if self.rem_chunk_tokens is not None:
self.rem_chunk_tokens -= mixed_with_decode_tokens
self.can_run_list = []
self.new_inflight_req = None
......
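
As a rough illustration of the budget accounting added above (the concrete numbers are invented):

# Illustrative numbers only: each decode request that will be mixed into the
# prefill batch consumes one token of every budget, so it is reserved up front.
chunked_prefill_size = 32     # rem_chunk_tokens before mixing
num_mixed_running = 8         # running decode requests folded into this batch

rem_chunk_tokens = chunked_prefill_size - num_mixed_running
assert rem_chunk_tokens == 24  # tokens left for new prefill input this step
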
......@@ -329,6 +329,9 @@ class ScheduleBatch:
out_cache_loc: torch.Tensor = None
extend_num_tokens: int = None
# For mixed chunked prefill
prefix_lens_cpu: List[int] = None
# For processing logprobs
return_logprob: bool = False
top_logprobs_nums: List[int] = None
......@@ -462,9 +465,33 @@ class ScheduleBatch:
self.extend_num_tokens = extend_num_tokens
self.out_cache_loc = out_cache_loc
self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
self.prefix_lens_cpu = [len(r.prefix_indices) for r in reqs]
self.batch_sampling_params(vocab_size)
def mix_with_running(self, running_batch: "ScheduleBatch"):
# NOTE: prefix_indices only covers what is in the radix cache; decode steps are not cached,
# so compute the decode requests' prefix lengths from their full sequence lengths instead.
prefix_lens_cpu = [len(r.prefix_indices) for r in self.reqs]
prefix_lens_cpu.extend(
[
len(r.origin_input_ids) + len(r.output_ids) - 1
for r in running_batch.reqs
]
)
for req in running_batch.reqs:
req.fill_ids = req.origin_input_ids + req.output_ids
req.extend_input_len = 1
input_ids = torch.cat([self.input_ids, running_batch.input_ids])
out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])
extend_num_tokens = self.extend_num_tokens + running_batch.batch_size()
self.merge(running_batch)
self.input_ids = input_ids
self.out_cache_loc = out_cache_loc
self.extend_num_tokens = extend_num_tokens
self.prefix_lens_cpu = prefix_lens_cpu
def check_decode_mem(self):
bs = self.batch_size()
if self.token_to_kv_pool.available_size() >= bs:
......
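
A worked example of the bookkeeping in `mix_with_running` (request sizes are invented): for a mixed-in decode request, everything except the single token decoded this step counts as prefix, even though those tokens were never written to the radix cache.

# Hypothetical decode request being merged into a prefill batch.
origin_input_ids = list(range(10))   # 10 prompt tokens
output_ids = [101, 102, 103]         # 3 tokens generated so far

fill_ids = origin_input_ids + output_ids                    # 13 ids in total
extend_input_len = 1                                        # only the newest token is extended
prefix_len = len(origin_input_ids) + len(output_ids) - 1    # 10 + 3 - 1 = 12
assert len(fill_ids) - prefix_len == extend_input_len
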
......@@ -174,6 +174,9 @@ class ModelTpServer:
# Chunked prefill
self.chunked_prefill_size = server_args.chunked_prefill_size
self.current_inflight_req = None
self.is_mixed_chunk = (
self.chunked_prefill_size is not None and server_args.enable_mixed_chunk
)
# Init the FSM cache for constrained generation
if not server_args.skip_tokenizer_init:
......@@ -366,11 +369,14 @@ class ModelTpServer:
# Get priority queue
prefix_computed = self.scheduler.calc_priority(self.waiting_queue)
num_mixed_running = running_bs if self.is_mixed_chunk else 0
adder = PrefillAdder(
self.tree_cache,
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
self.max_prefill_tokens,
self.chunked_prefill_size,
num_mixed_running,
)
if self.running_batch is not None:
......@@ -416,15 +422,27 @@ class ModelTpServer:
)
else:
tree_cache_hit_rate = 0.0
logger.info(
f"[gpu={self.gpu_id}] Prefill batch. "
f"#new-seq: {len(can_run_list)}, "
f"#new-token: {adder.log_input_tokens}, "
f"#cached-token: {adder.log_hit_tokens}, "
f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
f"#running-req: {running_bs}, "
f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
)
if num_mixed_running > 0:
logger.info(
f"[gpu={self.gpu_id}] Prefill batch"
f"(mixed #running-req: {num_mixed_running}). "
f"#new-seq: {len(can_run_list)}, "
f"#new-token: {adder.log_input_tokens}, "
f"#cached-token: {adder.log_hit_tokens}, "
f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
)
else:
logger.info(
f"[gpu={self.gpu_id}] Prefill batch. "
f"#new-seq: {len(can_run_list)}, "
f"#new-token: {adder.log_input_tokens}, "
f"#cached-token: {adder.log_hit_tokens}, "
f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
f"#running-req: {running_bs}, "
f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
)
# Return the new batch
new_batch = ScheduleBatch.init_new(
......@@ -440,6 +458,13 @@ class ModelTpServer:
# Build batch tensors
batch.prepare_for_extend(self.model_config.vocab_size)
decoding_reqs = []
if self.is_mixed_chunk and self.running_batch is not None:
self.running_batch.prepare_for_decode()
batch.mix_with_running(self.running_batch)
decoding_reqs = self.running_batch.reqs
self.running_batch = None
if self.model_runner.is_generation:
# Forward and sample the next tokens
if batch.extend_num_tokens != 0:
......@@ -481,7 +506,8 @@ class ModelTpServer:
if req.finished():
self.tree_cache.cache_finished_req(req)
else:
elif req not in decoding_reqs:
# To reduce overhead, only cache prefill reqs
self.tree_cache.cache_unfinished_req(req)
if req is self.current_inflight_req:
......
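
Putting the scheduler-side hunks together, the prefill step with mixing enabled roughly becomes the following (a simplified sketch of the control flow; sampling, logprob handling, and error paths are omitted):

def forward_prefill_step(self, batch):
    # Build tensors for the new chunked prefill requests.
    batch.prepare_for_extend(self.model_config.vocab_size)

    decoding_reqs = []
    if self.is_mixed_chunk and self.running_batch is not None:
        self.running_batch.prepare_for_decode()     # allocate one KV slot per decode request
        batch.mix_with_running(self.running_batch)  # concatenate input_ids / out_cache_loc
        decoding_reqs = self.running_batch.reqs
        self.running_batch = None                   # the decode reqs now live in `batch`

    # ... run the model on `batch` and sample the next tokens ...

    for req in batch.reqs:
        if req.finished():
            self.tree_cache.cache_finished_req(req)
        elif req not in decoding_reqs:
            # Only prefill requests are written back to the radix cache.
            self.tree_cache.cache_unfinished_req(req)
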
......@@ -88,11 +88,11 @@ class InputMetadata:
self.image_sizes = [r.image_size for r in reqs]
self.image_offsets = [
(
(r.image_offset - len(r.prefix_indices))
(r.image_offset - batch.prefix_lens_cpu[i])
if r.image_offset is not None
else 0
)
for r in reqs
for i, r in enumerate(reqs)
]
def compute_positions(self, batch: ScheduleBatch):
......@@ -109,8 +109,8 @@ class InputMetadata:
self.positions = torch.tensor(
np.concatenate(
[
np.arange(len(req.prefix_indices), len(req.fill_ids))
for req in batch.reqs
np.arange(batch.prefix_lens_cpu[i], len(req.fill_ids))
for i, req in enumerate(batch.reqs)
],
axis=0,
),
......@@ -123,7 +123,7 @@ class InputMetadata:
np.concatenate(
[
np.arange(
len(req.prefix_indices) + position_ids_offsets_cpu[i],
batch.prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
len(req.fill_ids) + position_ids_offsets_cpu[i],
)
for i, req in enumerate(batch.reqs)
......@@ -141,12 +141,13 @@ class InputMetadata:
self.extend_seq_lens = self.extend_start_loc = self.extend_no_prefix = None
else:
extend_lens_cpu = [
len(r.fill_ids) - len(r.prefix_indices) for r in batch.reqs
len(r.fill_ids) - batch.prefix_lens_cpu[i]
for i, r in enumerate(batch.reqs)
]
self.extend_seq_lens = torch.tensor(extend_lens_cpu, device="cuda")
self.extend_start_loc = torch.zeros_like(self.seq_lens)
self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
self.extend_no_prefix = all(len(r.prefix_indices) == 0 for r in batch.reqs)
self.extend_no_prefix = all(l == 0 for l in batch.prefix_lens_cpu)
@classmethod
def from_schedule_batch(
......@@ -180,14 +181,8 @@ class InputMetadata:
if forward_mode != ForwardMode.DECODE:
ret.init_multimuldal_info(batch)
prefix_lens = None
if forward_mode != ForwardMode.DECODE:
prefix_lens = torch.tensor(
[len(r.prefix_indices) for r in batch.reqs], device="cuda"
)
if model_runner.server_args.disable_flashinfer:
ret.init_triton_args(batch, prefix_lens)
ret.init_triton_args(batch)
flashinfer_use_ragged = False
if not model_runner.server_args.disable_flashinfer:
......@@ -198,30 +193,35 @@ class InputMetadata:
):
flashinfer_use_ragged = True
ret.init_flashinfer_handlers(
model_runner, prefix_lens, flashinfer_use_ragged
model_runner, batch.prefix_lens_cpu, flashinfer_use_ragged
)
return ret
def init_triton_args(self, batch: ScheduleBatch, prefix_lens):
def init_triton_args(self, batch: ScheduleBatch):
"""Init auxiliary variables for triton attention backend."""
self.triton_max_seq_len = int(torch.max(self.seq_lens))
self.triton_prefix_lens = prefix_lens
self.triton_start_loc = torch.zeros_like(self.seq_lens, dtype=torch.int32)
self.triton_start_loc[1:] = torch.cumsum(self.seq_lens[:-1], dim=0)
if self.forward_mode == ForwardMode.DECODE:
self.triton_max_extend_len = None
else:
extend_seq_lens = self.seq_lens - prefix_lens
self.triton_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
extend_seq_lens = self.seq_lens - self.triton_prefix_lens
self.triton_max_extend_len = int(torch.max(extend_seq_lens))
def init_flashinfer_handlers(
self,
model_runner,
prefix_lens,
prefix_lens_cpu,
flashinfer_use_ragged,
):
if self.forward_mode != ForwardMode.DECODE:
prefix_lens = torch.tensor(prefix_lens_cpu, device="cuda")
else:
prefix_lens = None
update_flashinfer_indices(
self.forward_mode,
model_runner,
......
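
The switch to `batch.prefix_lens_cpu` matters because a mixed-in decode request's `prefix_indices` is stale (its decode-step KV is never cached). Positions for an extend pass are built per request as `arange(prefix_len, len(fill_ids))`, so a mixed decode request contributes exactly one position. For example (lengths invented):

import numpy as np

prefix_lens_cpu = [4, 12]   # tokens already held in the KV cache, per request
fill_ids_lens = [9, 13]     # total ids per request after this step

positions = np.concatenate(
    [np.arange(p, f) for p, f in zip(prefix_lens_cpu, fill_ids_lens)]
)
# -> [4 5 6 7 8 12]: five new prefill positions plus one decode position
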
......@@ -445,15 +445,6 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
print(f"Initialization failed. warmup error: {last_traceback}", flush=True)
sys.exit(1)
# Print warnings here
if server_args.disable_radix_cache and server_args.chunked_prefill_size is not None:
logger.warning(
"You set both `--disable-radix-cache` and `--chunked-prefill-size`. "
"This combination is an experimental feature and we noticed it can lead to "
"wrong generation results. If you want to use chunked prefill, it is recommended "
"not using `--disable-radix-cache`."
)
logger.info("The server is fired up and ready to roll!")
if pipe_finish_writer is not None:
pipe_finish_writer.send("init ok")
......
......@@ -80,6 +80,7 @@ class ServerArgs:
disable_regex_jump_forward: bool = False
disable_cuda_graph: bool = False
disable_disk_cache: bool = False
enable_mixed_chunk: bool = False
enable_torch_compile: bool = False
enable_p2p_check: bool = False
enable_mla: bool = False
......@@ -396,6 +397,11 @@ class ServerArgs:
action="store_true",
help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
)
parser.add_argument(
"--enable-mixed-chunk",
action="store_true",
help="Enabling mixing prefill and decode in a chunked batch.",
)
parser.add_argument(
"--enable-torch-compile",
action="store_true",
......
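
A hedged usage sketch of the new flag; the `ServerArgs` import path and the model path below are assumptions, only `chunked_prefill_size` and `enable_mixed_chunk` come from this diff. The same configuration can be passed on the launch command line as `--chunked-prefill-size 8192 --enable-mixed-chunk`.

from sglang.srt.server_args import ServerArgs  # assumed module path

args = ServerArgs(
    model_path="meta-llama/Meta-Llama-3-8B-Instruct",  # placeholder model
    chunked_prefill_size=8192,  # mixing only applies when chunked prefill is enabled
    enable_mixed_chunk=True,
)
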
# Adapted from https://github.com/openai/simple-evals/
import base64
import os
import resource
import time
from collections import defaultdict
from dataclasses import dataclass, field
from multiprocessing.pool import ThreadPool
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, Optional, Tuple
import httpx
import jinja2
......@@ -44,8 +43,8 @@ class EvalResult:
Result of running an evaluation (usually consisting of many samples)
"""
score: float | None # top-line metric
metrics: Dict[str, float] | None # other metrics
score: Optional[float] # top-line metric
metrics: Optional[Dict[str, float]] # other metrics
htmls: List[str] # strings of valid HTML
convos: List[MessageList] # sampled conversations
......@@ -56,10 +55,10 @@ class SingleEvalResult:
Result of evaluating a single sample
"""
score: float | None
score: Optional[float]
metrics: Dict[str, float] = field(default_factory=dict)
html: str | None = None
convo: MessageList | None = None # sampled conversation
html: Optional[str] = None
convo: Optional[MessageList] = None # sampled conversation
class Eval:
......@@ -89,8 +88,8 @@ class ChatCompletionSampler(SamplerBase):
def __init__(
self,
base_url: str = None,
model: str | None = None,
system_message: str | None = None,
model: Optional[str] = None,
system_message: Optional[str] = None,
temperature: float = 0.0,
max_tokens: int = 2048,
):
......@@ -272,7 +271,7 @@ def _compute_stat(values: list, stat: str):
def aggregate_results(
single_eval_results: List[SingleEvalResult],
default_stats: Tuple[str] = ("mean", "std"),
name2stats: Dict[str, Tuple[str]] | None = None,
name2stats: Optional[Dict[str, Tuple[str]]] = None,
) -> EvalResult:
"""
Aggregate results from multiple evaluations into a single EvalResult.
......
......@@ -8,6 +8,7 @@ https://arxiv.org/abs/2311.12022
import random
import re
from typing import Optional
import pandas
......@@ -28,7 +29,7 @@ class GPQAEval(Eval):
def __init__(
self,
filename: str,
num_examples: int | None,
num_examples: Optional[int],
num_threads: int,
n_repeats: int = 1,
):
......
......@@ -9,7 +9,7 @@ https://arxiv.org/abs/2107.03374 https://github.com/openai/human-eval/
import random
import re
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, List
from typing import Dict, List, Optional
import tqdm
......@@ -61,7 +61,7 @@ def evaluate_functional_correctness(
class HumanEval(Eval):
def __init__(
self,
num_examples: int | None,
num_examples: Optional[int],
num_threads: int,
num_samples_per_task: int = 5,
ks_passes: List[int] = [1, 2, 5],
......
......@@ -8,6 +8,7 @@ https://arxiv.org/abs/2103.03874
import random
import re
from typing import Optional
import pandas
......@@ -36,7 +37,7 @@ class MathEval(Eval):
self,
filename: str,
equality_checker: SamplerBase,
num_examples: int | None,
num_examples: Optional[int],
num_threads: int,
):
df = pandas.read_csv(filename)
......
......@@ -8,6 +8,7 @@ https://arxiv.org/abs/2009.03300
import random
import re
from typing import Optional
import pandas
......@@ -84,7 +85,7 @@ subject2category = {
class MMLUEval(Eval):
def __init__(self, filename: str, num_examples: int | None, num_threads: int):
def __init__(self, filename: str, num_examples: Optional[int], num_threads: int):
df = pandas.read_csv(filename)
examples = [row.to_dict() for _, row in df.iterrows()]
if num_examples:
......
......@@ -11,11 +11,14 @@ from sglang.test.test_utils import (
class TestChunkedPrefill(unittest.TestCase):
def run_mmlu(self, disable_radix_cache):
def run_mmlu(self, disable_radix_cache, enable_mixed_chunk):
other_args = ["--chunked-prefill-size", "32"]
if disable_radix_cache:
other_args += ["--disable-radix-cache"]
if enable_mixed_chunk:
other_args += ["--enable-mixed-chunk"]
model = DEFAULT_MODEL_NAME_FOR_TEST
base_url = DEFAULT_URL_FOR_UNIT_TEST
process = popen_launch_server(
......@@ -40,10 +43,16 @@ class TestChunkedPrefill(unittest.TestCase):
kill_child_process(process.pid)
def test_chunked_prefill(self):
self.run_mmlu(disable_radix_cache=False)
self.run_mmlu(disable_radix_cache=False, enable_mixed_chunk=False)
def test_mixed_chunked_prefill(self):
self.run_mmlu(disable_radix_cache=False, enable_mixed_chunk=True)
def test_chunked_prefill_without_radix_cache(self):
self.run_mmlu(disable_radix_cache=True)
self.run_mmlu(disable_radix_cache=True, enable_mixed_chunk=False)
def test_mixed_chunked_prefill_without_radix_cache(self):
self.run_mmlu(disable_radix_cache=True, enable_mixed_chunk=True)
if __name__ == "__main__":
......
......@@ -6,7 +6,6 @@ from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_URL_FOR_ACCURACY_TEST,
DEFAULT_URL_FOR_UNIT_TEST,
popen_launch_server,
)
......
import unittest
from types import SimpleNamespace
from sglang.srt.utils import kill_child_process
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_URL_FOR_ACCURACY_TEST,
popen_launch_server,
)
class TestEvalAccuracyLargeChunkedPrefill(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_ACCURACY_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=300,
other_args=[
"--log-level-http",
"warning",
"--chunked-prefill-size",
"256",
"--enable-mixed-chunk",
],
)
@classmethod
def tearDownClass(cls):
kill_child_process(cls.process.pid)
def test_mmlu(self):
args = SimpleNamespace(
base_url=self.base_url,
model=self.model,
eval_name="mmlu",
num_examples=3000,
num_threads=1024,
)
metrics = run_eval(args)
assert metrics["score"] >= 0.71, f"{metrics}"
def test_human_eval(self):
args = SimpleNamespace(
base_url=self.base_url,
model=self.model,
eval_name="humaneval",
num_examples=None,
num_threads=1024,
)
metrics = run_eval(args)
assert metrics["score"] >= 0.64, f"{metrics}"
def test_mgsm_en(self):
args = SimpleNamespace(
base_url=self.base_url,
model=self.model,
eval_name="mgsm_en",
num_examples=None,
num_threads=1024,
)
metrics = run_eval(args)
assert metrics["score"] >= 0.84, f"{metrics}"
if __name__ == "__main__":
unittest.main()