Unverified commit fcc2e37f authored by Lianmin Zheng, committed by GitHub

Split the __init__ of the scheduler into smaller functions. Improve the EAGLE tests (#4128)

parent 0804dd11
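The commit follows a standard refactor pattern: the long Scheduler.__init__ is decomposed into focused init_* helpers that the constructor calls in order. A minimal, hypothetical sketch of the resulting shape (the helper names mirror the diff below; the bodies here are placeholders, not the real implementation):

class Scheduler:
    def __init__(self, server_args):
        self.server_args = server_args

        # Each construction phase lives in its own helper so it can be read,
        # tested, and profiled in isolation.
        self.init_tokenizer()
        self.init_memory_pool_and_cache()
        self.init_metrics()

    def init_tokenizer(self):
        ...  # build ModelConfig and the tokenizer/processor

    def init_memory_pool_and_cache(self):
        ...  # fetch pools from the TP worker, build the radix/chunk cache

    def init_metrics(self):
        ...  # zero the counters, create the metrics collector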
...@@ -482,6 +482,7 @@ class BatchEmbeddingOut: ...@@ -482,6 +482,7 @@ class BatchEmbeddingOut:
embeddings: List[List[float]] embeddings: List[List[float]]
# Token counts # Token counts
prompt_tokens: List[int] prompt_tokens: List[int]
cached_tokens: List[int]
@dataclass @dataclass
......
...@@ -159,17 +159,6 @@ class Scheduler: ...@@ -159,17 +159,6 @@ class Scheduler:
) )
self.gpu_id = gpu_id self.gpu_id = gpu_id
self.enable_hierarchical_cache = server_args.enable_hierarchical_cache self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
self.decode_mem_cache_buf_multiplier = (
(
self.server_args.speculative_num_draft_tokens
+ (
self.server_args.speculative_eagle_topk
* self.server_args.speculative_num_draft_tokens
)
)
if not self.spec_algorithm.is_none()
else 1
)
# Distributed rank info # Distributed rank info
self.dp_size = server_args.dp_size self.dp_size = server_args.dp_size
...@@ -208,42 +197,12 @@ class Scheduler: ...@@ -208,42 +197,12 @@ class Scheduler:
self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None) self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)
# Init tokenizer # Init tokenizer
self.model_config = ModelConfig( self.init_tokenizer()
server_args.model_path,
trust_remote_code=server_args.trust_remote_code,
revision=server_args.revision,
context_length=server_args.context_length,
model_override_args=server_args.json_model_override_args,
is_embedding=server_args.is_embedding,
dtype=server_args.dtype,
quantization=server_args.quantization,
)
self.is_generation = self.model_config.is_generation
if server_args.skip_tokenizer_init:
self.tokenizer = self.processor = None
else:
if self.model_config.is_multimodal:
self.processor = get_processor(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
revision=server_args.revision,
)
self.tokenizer = self.processor.tokenizer
else:
self.tokenizer = get_tokenizer(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
revision=server_args.revision,
)
# Check whether overlap can be enabled # Check whether overlap can be enabled
if not self.is_generation: if not self.is_generation:
self.enable_overlap = False self.enable_overlap = False
logger.info("Overlap scheduler is disabled for embedding models.") logger.info("Overlap scheduler is disabled for embedding models.")
if self.model_config.is_multimodal: if self.model_config.is_multimodal:
self.enable_overlap = False self.enable_overlap = False
logger.info("Overlap scheduler is disabled for multimodal models.") logger.info("Overlap scheduler is disabled for multimodal models.")
...@@ -307,32 +266,7 @@ class Scheduler: ...@@ -307,32 +266,7 @@ class Scheduler:
) )
# Init memory pool and cache # Init memory pool and cache
self.req_to_token_pool, self.token_to_kv_pool_allocator = ( self.init_memory_pool_and_cache()
self.tp_worker.get_memory_pool()
)
if (
server_args.chunked_prefill_size is not None
and server_args.disable_radix_cache
):
self.tree_cache = ChunkCache(
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
)
else:
if self.enable_hierarchical_cache:
self.tree_cache = HiRadixCache(
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
)
else:
self.tree_cache = RadixCache(
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
disable=server_args.disable_radix_cache,
)
self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)
# Init running status # Init running status
self.waiting_queue: List[Req] = [] self.waiting_queue: List[Req] = []
...@@ -346,25 +280,13 @@ class Scheduler: ...@@ -346,25 +280,13 @@ class Scheduler:
self.forward_ct = 0 self.forward_ct = 0
self.forward_ct_decode = 0 self.forward_ct_decode = 0
self.num_generated_tokens = 0 self.num_generated_tokens = 0
self.spec_num_total_accepted_tokens = 0
self.spec_num_total_forward_ct = 0
self.cum_spec_accept_length = 0
self.cum_spec_accept_count = 0
self.last_decode_stats_tic = time.time() self.last_decode_stats_tic = time.time()
self.return_health_check_ct = 0 self.return_health_check_ct = 0
self.current_stream = torch.get_device_module(self.device).current_stream() self.current_stream = torch.get_device_module(self.device).current_stream()
if self.device == "cpu": if self.device == "cpu":
self.current_stream.synchronize = lambda: None # No-op for CPU self.current_stream.synchronize = lambda: None # No-op for CPU
# For metrics only. # Init session info
# The largest prefill length of a single request
self._largest_prefill_len: int = 0
# The largest context length (prefill + generation) of a single request
self._largest_prefill_decode_len: int = 0
self.last_gen_throughput: float = 0.0
self.step_time_dict = defaultdict(list) # Dict[batch size -> step time]
# Session info
self.sessions: Dict[str, Session] = {} self.sessions: Dict[str, Session] = {}
# Init chunked prefill # Init chunked prefill
...@@ -385,11 +307,11 @@ class Scheduler: ...@@ -385,11 +307,11 @@ class Scheduler:
else: else:
self.grammar_backend = None self.grammar_backend = None
# Init new token estimation # Init schedule policy and new token estimation
self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)
assert ( assert (
server_args.schedule_conservativeness >= 0 server_args.schedule_conservativeness >= 0
), "Invalid schedule_conservativeness" ), "Invalid schedule_conservativeness"
self.init_new_token_ratio = min( self.init_new_token_ratio = min(
global_config.default_init_new_token_ratio global_config.default_init_new_token_ratio
* server_args.schedule_conservativeness, * server_args.schedule_conservativeness,
...@@ -428,14 +350,7 @@ class Scheduler: ...@@ -428,14 +350,7 @@ class Scheduler:
self.profiler_target_forward_ct: Optional[int] = None self.profiler_target_forward_ct: Optional[int] = None
# Init metrics stats # Init metrics stats
self.stats = SchedulerStats() self.init_metrics()
if self.enable_metrics:
self.metrics_collector = SchedulerMetricsCollector(
labels={
"model_name": self.server_args.served_model_name,
# TODO: Add lora name/path in the future,
},
)
# Init request dispatcher # Init request dispatcher
self._request_dispatcher = TypeBasedDispatcher( self._request_dispatcher = TypeBasedDispatcher(
...@@ -458,39 +373,104 @@ class Scheduler: ...@@ -458,39 +373,104 @@ class Scheduler:
(ResumeMemoryOccupationReqInput, self.resume_memory_occupation), (ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
(ProfileReq, self.profile), (ProfileReq, self.profile),
(GetInternalStateReq, self.get_internal_state), (GetInternalStateReq, self.get_internal_state),
(SetInternalStateReq, self.set_internal_state),
] ]
) )
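The _request_dispatcher registered above maps request types to handler methods. A minimal stand-in sketch of how such a dispatcher can work (this is an assumption for illustration, not the actual TypeBasedDispatcher implementation):

class SimpleTypeDispatcher:
    """Illustrative stand-in: route an incoming object to a handler by its type."""

    def __init__(self, mapping):
        # mapping: iterable of (request_class, handler_callable) pairs
        self._handlers = dict(mapping)

    def __call__(self, recv_obj):
        handler = self._handlers.get(type(recv_obj))
        if handler is None:
            raise ValueError(f"Unknown request type: {type(recv_obj)}")
        return handler(recv_obj)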
def watchdog_thread(self): def init_tokenizer(self):
"""A watch dog thread that will try to kill the server itself if one forward batch takes too long.""" server_args = self.server_args
self.watchdog_last_forward_ct = 0
self.watchdog_last_time = time.time()
while True: self.model_config = ModelConfig(
current = time.time() server_args.model_path,
if self.cur_batch is not None: trust_remote_code=server_args.trust_remote_code,
if self.watchdog_last_forward_ct == self.forward_ct: revision=server_args.revision,
if current > self.watchdog_last_time + self.watchdog_timeout: context_length=server_args.context_length,
logger.error(f"Watchdog timeout ({self.watchdog_timeout=})") model_override_args=server_args.json_model_override_args,
break is_embedding=server_args.is_embedding,
else: dtype=server_args.dtype,
self.watchdog_last_forward_ct = self.forward_ct quantization=server_args.quantization,
self.watchdog_last_time = current )
time.sleep(self.watchdog_timeout // 2) self.is_generation = self.model_config.is_generation
# Print batch size and memory pool info to check whether there are de-sync issues. if server_args.skip_tokenizer_init:
logger.error( self.tokenizer = self.processor = None
f"{self.cur_batch.batch_size()=}, " else:
f"{self.cur_batch.reqs=}, " if self.model_config.is_multimodal:
f"{self.token_to_kv_pool_allocator.available_size()=}, " self.processor = get_processor(
f"{self.tree_cache.evictable_size()=}, " server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
revision=server_args.revision,
)
self.tokenizer = self.processor.tokenizer
else:
self.tokenizer = get_tokenizer(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
revision=server_args.revision,
)
def init_memory_pool_and_cache(self):
server_args = self.server_args
self.req_to_token_pool, self.token_to_kv_pool_allocator = (
self.tp_worker.get_memory_pool()
)
if (
server_args.chunked_prefill_size is not None
and server_args.disable_radix_cache
):
self.tree_cache = ChunkCache(
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
)
else:
if self.enable_hierarchical_cache:
self.tree_cache = HiRadixCache(
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
)
else:
self.tree_cache = RadixCache(
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
disable=server_args.disable_radix_cache,
)
self.decode_mem_cache_buf_multiplier = (
1
if self.spec_algorithm.is_none()
else (
server_args.speculative_num_draft_tokens
+ (
server_args.speculative_eagle_topk
* server_args.speculative_num_steps
)
)
) )
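As a worked example of the multiplier above, using the EAGLE settings from the tests further below (num_draft_tokens=8, eagle_topk=4, num_steps=5), and noting that the product term now uses speculative_num_steps instead of speculative_num_draft_tokens:

# Illustrative numbers only.
speculative_num_draft_tokens = 8
speculative_eagle_topk = 4
speculative_num_steps = 5

decode_mem_cache_buf_multiplier = (
    speculative_num_draft_tokens
    + speculative_eagle_topk * speculative_num_steps
)  # 8 + 4 * 5 = 28; the multiplier is 1 when speculative decoding is disabled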
# Wait for some time so that the parent process can print the error.
pyspy_dump_schedulers() def init_metrics(self):
print(file=sys.stderr, flush=True) # The largest prefill length of a single request
print(file=sys.stdout, flush=True) self._largest_prefill_len: int = 0
time.sleep(5) # The largest context length (prefill + generation) of a single request
self.parent_process.send_signal(signal.SIGQUIT) self._largest_prefill_decode_len: int = 0
self.last_gen_throughput: float = 0.0
self.step_time_dict = defaultdict(list) # Dict[batch size -> step time]
self.spec_num_total_accepted_tokens = 0
self.spec_num_total_forward_ct = 0
self.cum_spec_accept_length = 0
self.cum_spec_accept_count = 0
self.stats = SchedulerStats()
if self.enable_metrics:
engine_type = "unified"
self.metrics_collector = SchedulerMetricsCollector(
labels={
"model_name": self.server_args.served_model_name,
"engine_type": engine_type,
},
)
@torch.no_grad() @torch.no_grad()
def event_loop_normal(self): def event_loop_normal(self):
...@@ -1176,6 +1156,7 @@ class Scheduler: ...@@ -1176,6 +1156,7 @@ class Scheduler:
): ):
self.stop_profile() self.stop_profile()
# Run forward
if self.is_generation: if self.is_generation:
if self.spec_algorithm.is_none(): if self.spec_algorithm.is_none():
model_worker_batch = batch.get_model_worker_batch() model_worker_batch = batch.get_model_worker_batch()
...@@ -1196,6 +1177,7 @@ class Scheduler: ...@@ -1196,6 +1177,7 @@ class Scheduler:
self.spec_num_total_forward_ct += batch.batch_size() self.spec_num_total_forward_ct += batch.batch_size()
self.num_generated_tokens += num_accepted_tokens self.num_generated_tokens += num_accepted_tokens
batch.output_ids = next_token_ids batch.output_ids = next_token_ids
# These 2 values are needed for processing the output, but the values can be # These 2 values are needed for processing the output, but the values can be
# modified by overlap schedule. So we have to copy them here so that # modified by overlap schedule. So we have to copy them here so that
# we can use the correct values in output processing. # we can use the correct values in output processing.
...@@ -1229,7 +1211,6 @@ class Scheduler: ...@@ -1229,7 +1211,6 @@ class Scheduler:
result: Union[GenerationBatchResult, EmbeddingBatchResult], result: Union[GenerationBatchResult, EmbeddingBatchResult],
): ):
if batch.forward_mode.is_decode(): if batch.forward_mode.is_decode():
assert isinstance(result, GenerationBatchResult)
self.process_batch_result_decode(batch, result) self.process_batch_result_decode(batch, result)
if batch.is_empty(): if batch.is_empty():
self.running_batch = None self.running_batch = None
...@@ -1481,6 +1462,7 @@ class Scheduler: ...@@ -1481,6 +1462,7 @@ class Scheduler:
batch.next_batch_sampling_info.update_regex_vocab_mask() batch.next_batch_sampling_info.update_regex_vocab_mask()
self.current_stream.synchronize() self.current_stream.synchronize()
batch.next_batch_sampling_info.sampling_info_done.set() batch.next_batch_sampling_info.sampling_info_done.set()
self.stream_output(batch.reqs, batch.return_logprob) self.stream_output(batch.reqs, batch.return_logprob)
self.token_to_kv_pool_allocator.free_group_end() self.token_to_kv_pool_allocator.free_group_end()
...@@ -1584,7 +1566,9 @@ class Scheduler: ...@@ -1584,7 +1566,9 @@ class Scheduler:
req.temp_input_token_ids_logprobs_idx req.temp_input_token_ids_logprobs_idx
) )
for val, idx in zip( for val, idx in zip(
req.temp_input_top_logprobs_val, req.temp_input_top_logprobs_idx req.temp_input_top_logprobs_val,
req.temp_input_top_logprobs_idx,
strict=True,
): ):
req.input_top_logprobs_val.extend(val) req.input_top_logprobs_val.extend(val)
req.input_top_logprobs_idx.extend(idx) req.input_top_logprobs_idx.extend(idx)
...@@ -1809,14 +1793,18 @@ class Scheduler: ...@@ -1809,14 +1793,18 @@ class Scheduler:
else: # embedding or reward model else: # embedding or reward model
embeddings = [] embeddings = []
prompt_tokens = [] prompt_tokens = []
cached_tokens = []
for req in reqs: for req in reqs:
if req.finished(): if req.finished():
rids.append(req.rid) rids.append(req.rid)
finished_reasons.append(req.finished_reason.to_json()) finished_reasons.append(req.finished_reason.to_json())
embeddings.append(req.embedding) embeddings.append(req.embedding)
prompt_tokens.append(len(req.origin_input_ids)) prompt_tokens.append(len(req.origin_input_ids))
cached_tokens.append(req.cached_tokens)
self.send_to_detokenizer.send_pyobj( self.send_to_detokenizer.send_pyobj(
BatchEmbeddingOut(rids, finished_reasons, embeddings, prompt_tokens) BatchEmbeddingOut(
rids, finished_reasons, embeddings, prompt_tokens, cached_tokens
)
) )
def prepare_dp_attn_batch(self, local_batch: ScheduleBatch): def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
...@@ -1902,6 +1890,37 @@ class Scheduler: ...@@ -1902,6 +1890,37 @@ class Scheduler:
self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs]) self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs])
self.grammar_queue = self.grammar_queue[num_ready_reqs:] self.grammar_queue = self.grammar_queue[num_ready_reqs:]
def watchdog_thread(self):
"""A watch dog thread that will try to kill the server itself if one forward batch takes too long."""
self.watchdog_last_forward_ct = 0
self.watchdog_last_time = time.time()
while True:
current = time.time()
if self.cur_batch is not None:
if self.watchdog_last_forward_ct == self.forward_ct:
if current > self.watchdog_last_time + self.watchdog_timeout:
logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
break
else:
self.watchdog_last_forward_ct = self.forward_ct
self.watchdog_last_time = current
time.sleep(self.watchdog_timeout // 2)
# Print batch size and memory pool info to check whether there are de-sync issues.
logger.error(
f"{self.cur_batch.batch_size()=}, "
f"{self.cur_batch.reqs=}, "
f"{self.token_to_kv_pool_allocator.available_size()=}, "
f"{self.tree_cache.evictable_size()=}, "
)
# Wait for some time so that the parent process can print the error.
pyspy_dump_schedulers()
print(file=sys.stderr, flush=True)
print(file=sys.stdout, flush=True)
time.sleep(5)
self.parent_process.send_signal(signal.SIGQUIT)
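The loop above is a classic stalled-counter watchdog: snapshot forward_ct, sleep half the timeout, and raise the alarm if the counter has not advanced within watchdog_timeout. A small, self-contained illustration of the pattern (toy code, not SGLang's):

import threading
import time

class Worker:
    """Toy stand-in: forward_ct should keep advancing while the worker is healthy."""

    def __init__(self, timeout=2.0):
        self.forward_ct = 0
        self.watchdog_timeout = timeout

    def watchdog(self):
        # Fire if forward_ct has not advanced for a full watchdog_timeout.
        last_ct, last_time = self.forward_ct, time.time()
        while True:
            time.sleep(self.watchdog_timeout / 2)
            if self.forward_ct == last_ct:
                if time.time() > last_time + self.watchdog_timeout:
                    print("watchdog timeout: forward loop appears stuck")
                    return
            else:
                last_ct, last_time = self.forward_ct, time.time()

# Typically launched as a daemon thread next to the event loop, e.g.:
#   threading.Thread(target=Worker().watchdog, daemon=True).start()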
def flush_cache_wrapped(self, recv_req: FlushCacheReq): def flush_cache_wrapped(self, recv_req: FlushCacheReq):
self.flush_cache() self.flush_cache()
...@@ -1913,7 +1932,6 @@ class Scheduler: ...@@ -1913,7 +1932,6 @@ class Scheduler:
self.cur_batch = None self.cur_batch = None
self.last_batch = None self.last_batch = None
self.tree_cache.reset() self.tree_cache.reset()
self.tree_cache_metrics = {"total": 0, "hit": 0}
if self.grammar_backend: if self.grammar_backend:
self.grammar_backend.reset() self.grammar_backend.reset()
self.req_to_token_pool.clear() self.req_to_token_pool.clear()
...@@ -2005,6 +2023,9 @@ class Scheduler: ...@@ -2005,6 +2023,9 @@ class Scheduler:
req.to_abort = True req.to_abort = True
break break
def _pause_engine(self) -> Tuple[List[Req], int]:
raise NotImplementedError()
def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput): def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
"""In-place update of the weights from disk.""" """In-place update of the weights from disk."""
success, message = self.tp_worker.update_weights_from_disk(recv_req) success, message = self.tp_worker.update_weights_from_disk(recv_req)
......
...@@ -1068,6 +1068,7 @@ class TokenizerManager: ...@@ -1068,6 +1068,7 @@ class TokenizerManager:
self.metrics_collector.observe_one_finished_request( self.metrics_collector.observe_one_finished_request(
recv_obj.prompt_tokens[i], recv_obj.prompt_tokens[i],
completion_tokens, completion_tokens,
recv_obj.cached_tokens[i],
state.finished_time - state.created_time, state.finished_time - state.created_time,
) )
......
...@@ -121,6 +121,12 @@ class TokenizerMetricsCollector: ...@@ -121,6 +121,12 @@ class TokenizerMetricsCollector:
labelnames=labels.keys(), labelnames=labels.keys(),
) )
self.cached_tokens_total = Counter(
name="sglang:cached_tokens_total",
documentation="Number of cached prompt tokens.",
labelnames=labels.keys(),
)
self.num_requests_total = Counter( self.num_requests_total = Counter(
name="sglang:num_requests_total", name="sglang:num_requests_total",
documentation="Number of requests processed.", documentation="Number of requests processed.",
...@@ -245,10 +251,12 @@ class TokenizerMetricsCollector: ...@@ -245,10 +251,12 @@ class TokenizerMetricsCollector:
self, self,
prompt_tokens: int, prompt_tokens: int,
generation_tokens: int, generation_tokens: int,
cached_tokens: int,
e2e_latency: float, e2e_latency: float,
): ):
self.prompt_tokens_total.labels(**self.labels).inc(prompt_tokens) self.prompt_tokens_total.labels(**self.labels).inc(prompt_tokens)
self.generation_tokens_total.labels(**self.labels).inc(generation_tokens) self.generation_tokens_total.labels(**self.labels).inc(generation_tokens)
self.cached_tokens_total.labels(**self.labels).inc(cached_tokens)
self.num_requests_total.labels(**self.labels).inc(1) self.num_requests_total.labels(**self.labels).inc(1)
self._log_histogram(self.histogram_e2e_request_latency, e2e_latency) self._log_histogram(self.histogram_e2e_request_latency, e2e_latency)
if generation_tokens >= 1: if generation_tokens >= 1:
......
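The new cached-tokens metric uses the standard prometheus_client pattern: declare a labelled Counter once, then increment it for every finished request. A minimal, self-contained sketch (the label values are illustrative):

from prometheus_client import Counter

labels = {"model_name": "demo-model"}  # illustrative label set
cached_tokens_total = Counter(
    name="sglang:cached_tokens_total",
    documentation="Number of cached prompt tokens.",
    labelnames=labels.keys(),
)

# Per finished request: add the number of prompt tokens served from the cache.
cached_tokens_total.labels(**labels).inc(37)  # 37 is an illustrative value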
import multiprocessing as mp import multiprocessing as mp
import os
import random import random
import threading import threading
import time import time
import unittest import unittest
from types import SimpleNamespace from types import SimpleNamespace
from typing import List, Optional
import requests import requests
import torch
import sglang as sgl import sglang as sgl
from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.utils import kill_process_tree from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval from sglang.test.few_shot_gsm8k import run_eval
from sglang.test.runners import DEFAULT_PROMPTS, SRTRunner
from sglang.test.test_utils import ( from sglang.test.test_utils import (
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST, DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST, DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
...@@ -19,7 +23,9 @@ from sglang.test.test_utils import ( ...@@ -19,7 +23,9 @@ from sglang.test.test_utils import (
popen_launch_server, popen_launch_server,
) )
acc_rate_tolerance = 0.15 torch_dtype = torch.float16
prefill_tolerance = 5e-2
decode_tolerance: float = 5e-2
class TestEAGLEEngine(unittest.TestCase): class TestEAGLEEngine(unittest.TestCase):
...@@ -28,51 +34,72 @@ class TestEAGLEEngine(unittest.TestCase): ...@@ -28,51 +34,72 @@ class TestEAGLEEngine(unittest.TestCase):
"speculative_draft_model_path": DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST, "speculative_draft_model_path": DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
"speculative_algorithm": "EAGLE", "speculative_algorithm": "EAGLE",
"speculative_num_steps": 5, "speculative_num_steps": 5,
"speculative_eagle_topk": 8, "speculative_eagle_topk": 4,
"speculative_num_draft_tokens": 64, "speculative_num_draft_tokens": 8,
"mem_fraction_static": 0.7, "mem_fraction_static": 0.7,
"cuda_graph_max_bs": 32, "cuda_graph_max_bs": 5,
} }
NUM_CONFIGS = 3
def setUp(self): def setUp(self):
self.prompt = "Today is a sunny day and I like" self.prompt = "Today is a sunny day and I like"
self.sampling_params = {"temperature": 0, "max_new_tokens": 8} self.sampling_params = {"temperature": 0, "max_new_tokens": 8}
ref_engine = sgl.Engine(model_path=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST) ref_engine = sgl.Engine(
model_path=self.BASE_CONFIG["model_path"], cuda_graph_max_bs=1
)
self.ref_output = ref_engine.generate(self.prompt, self.sampling_params)["text"] self.ref_output = ref_engine.generate(self.prompt, self.sampling_params)["text"]
ref_engine.shutdown() ref_engine.shutdown()
def test_correctness(self): def test_correctness(self):
configs = [ configs = [
# Basic config
self.BASE_CONFIG, self.BASE_CONFIG,
# Disable cuda graph
{**self.BASE_CONFIG, "disable_cuda_graph": True}, {**self.BASE_CONFIG, "disable_cuda_graph": True},
{**self.BASE_CONFIG, "chunked_prefill_size": 2}, # Chunked prefill
{**self.BASE_CONFIG, "chunked_prefill_size": 4},
] ]
for config in configs: for i, config in enumerate(configs[: self.NUM_CONFIGS]):
with self.subTest( with self.subTest(i=i):
cuda_graph=( print(f"{config=}")
"enabled" if len(config) == len(self.BASE_CONFIG) else "disabled" engine = sgl.Engine(**config, log_level="info", decode_log_interval=10)
),
chunked_prefill_size=(
config["chunked_prefill_size"]
if "chunked_prefill_size" in config
else "default"
),
):
engine = sgl.Engine(**config)
try: try:
self._test_basic_generation(engine) self._test_single_generation(engine)
self._test_eos_token(engine)
self._test_batch_generation(engine) self._test_batch_generation(engine)
self._test_eos_token(engine)
self._test_acc_length(engine)
finally: finally:
engine.shutdown() engine.shutdown()
print("=" * 100)
def _test_basic_generation(self, engine): def _test_single_generation(self, engine):
output = engine.generate(self.prompt, self.sampling_params)["text"] output = engine.generate(self.prompt, self.sampling_params)["text"]
print(f"{output=}, {self.ref_output=}") print(f"{output=}, {self.ref_output=}")
self.assertEqual(output, self.ref_output) self.assertEqual(output, self.ref_output)
def _test_batch_generation(self, engine):
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
params = {"temperature": 0, "max_new_tokens": 50}
outputs = engine.generate(prompts, params)
for prompt, output in zip(prompts, outputs):
print(f"Prompt: {prompt}")
print(f"Generated: {output['text']}")
print("-" * 40)
print(f"{engine.get_server_info()=}")
avg_spec_accept_length = engine.get_server_info()["avg_spec_accept_length"]
print(f"{avg_spec_accept_length=}")
self.assertGreater(avg_spec_accept_length, 1.9)
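The avg_spec_accept_length returned by get_server_info() is presumably derived from the scheduler's cumulative speculative counters initialized in init_metrics above (cum_spec_accept_length / cum_spec_accept_count); a hypothetical sketch of that ratio with made-up numbers:

# Hypothetical numbers: 100 verify passes accepted 380 draft tokens in total.
cum_spec_accept_length = 380
cum_spec_accept_count = 100
avg_spec_accept_length = cum_spec_accept_length / cum_spec_accept_count  # 3.8
assert avg_spec_accept_length > 1.9  # the bound used by the test above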
def _test_eos_token(self, engine): def _test_eos_token(self, engine):
prompt = "[INST] <<SYS>>\nYou are a helpful assistant.\n<</SYS>>\nToday is a sunny day and I like [/INST]" prompt = "[INST] <<SYS>>\nYou are a helpful assistant.\n<</SYS>>\nToday is a sunny day and I like [/INST]"
params = { params = {
...@@ -88,32 +115,54 @@ class TestEAGLEEngine(unittest.TestCase): ...@@ -88,32 +115,54 @@ class TestEAGLEEngine(unittest.TestCase):
tokens = tokenizer.encode(output, truncation=False) tokens = tokenizer.encode(output, truncation=False)
self.assertNotIn(tokenizer.eos_token_id, tokens) self.assertNotIn(tokenizer.eos_token_id, tokens)
def _test_batch_generation(self, engine): def _test_acc_length(self, engine):
prompts = [ prompt = [
"Hello, my name is", "Human: Give me a fully functional FastAPI server. Show the python code.\n\nAssistant:"
"The president of the United States is",
"The capital of France is",
"The future of AI is",
] ]
params = {"temperature": 0, "max_new_tokens": 30} sampling_params = {"temperature": 0, "max_new_tokens": 512}
output = engine.generate(prompt, sampling_params)
outputs = engine.generate(prompts, params) output = output[0]
for prompt, output in zip(prompts, outputs):
print(f"Prompt: {prompt}") if "spec_verify_ct" in output["meta_info"]:
print(f"Generated: {output['text']}") acc_length = (
print("-" * 40) output["meta_info"]["completion_tokens"]
/ output["meta_info"]["spec_verify_ct"]
)
else:
acc_length = 1.0
speed = (
output["meta_info"]["completion_tokens"]
/ output["meta_info"]["e2e_latency"]
)
print(f"{acc_length=}")
self.assertGreater(acc_length, 3.6)
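For intuition, acc_length is the average number of tokens committed per verification step, so a run that emits 480 completion tokens over 120 spec_verify_ct steps clears the 3.6 bound (illustrative numbers):

meta_info = {"completion_tokens": 480, "spec_verify_ct": 120}  # illustrative
acc_length = meta_info["completion_tokens"] / meta_info["spec_verify_ct"]  # 4.0
assert acc_length > 3.6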
prompts = [ class TestEAGLEEngineTokenMap(unittest.TestCase):
"[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nToday is a sunny day and I like[/INST]" BASE_CONFIG = {
'[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nWhat are the mental triggers in Jeff Walker\'s Product Launch Formula and "Launch" book?[/INST]', "model_path": "meta-llama/Meta-Llama-3-8B-Instruct",
"[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nSummarize Russell Brunson's Perfect Webinar Script...[/INST]", "speculative_draft_model_path": "lmsys/sglang-EAGLE-LLaMA3-Instruct-8B",
"[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nwho are you?[/INST]", "speculative_algorithm": "EAGLE",
"[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nwhere are you from?[/INST]", "speculative_num_steps": 5,
] "speculative_eagle_topk": 4,
"speculative_num_draft_tokens": 8,
"speculative_token_map": "thunlp/LLaMA3-Instruct-8B-FR-Spec/freq_32768.pt",
"mem_fraction_static": 0.7,
"cuda_graph_max_bs": 5,
}
NUM_CONFIGS = 1
class TestEAGLEServer(unittest.TestCase): class TestEAGLEServer(unittest.TestCase):
PROMPTS = [
"[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nToday is a sunny day and I like[/INST]"
'[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nWhat are the mental triggers in Jeff Walker\'s Product Launch Formula and "Launch" book?[/INST]',
"[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nSummarize Russell Brunson's Perfect Webinar Script...[/INST]",
"[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nwho are you?[/INST]",
"[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nwhere are you from?[/INST]",
]
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
cls.base_url = DEFAULT_URL_FOR_TEST cls.base_url = DEFAULT_URL_FOR_TEST
...@@ -127,17 +176,17 @@ class TestEAGLEServer(unittest.TestCase): ...@@ -127,17 +176,17 @@ class TestEAGLEServer(unittest.TestCase):
"--speculative-draft-model-path", "--speculative-draft-model-path",
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST, DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
"--speculative-num-steps", "--speculative-num-steps",
"5", 5,
"--speculative-eagle-topk", "--speculative-eagle-topk",
"8", 8,
"--speculative-num-draft-tokens", "--speculative-num-draft-tokens",
"64", 64,
"--mem-fraction-static", "--mem-fraction-static",
"0.7", 0.7,
"--chunked-prefill-size", "--chunked-prefill-size",
"128", 128,
"--cuda-graph-max-bs", "--max-running-requests",
"32", 8,
], ],
) )
...@@ -147,7 +196,7 @@ class TestEAGLEServer(unittest.TestCase): ...@@ -147,7 +196,7 @@ class TestEAGLEServer(unittest.TestCase):
def send_request(self): def send_request(self):
time.sleep(random.uniform(0, 2)) time.sleep(random.uniform(0, 2))
for prompt in prompts: for prompt in self.PROMPTS:
url = self.base_url + "/generate" url = self.base_url + "/generate"
data = { data = {
"text": prompt, "text": prompt,
...@@ -160,7 +209,7 @@ class TestEAGLEServer(unittest.TestCase): ...@@ -160,7 +209,7 @@ class TestEAGLEServer(unittest.TestCase):
assert response.status_code == 200 assert response.status_code == 200
def send_requests_abort(self): def send_requests_abort(self):
for prompt in prompts: for prompt in self.PROMPTS:
try: try:
time.sleep(random.uniform(0, 2)) time.sleep(random.uniform(0, 2))
url = self.base_url + "/generate" url = self.base_url + "/generate"
...@@ -192,6 +241,8 @@ class TestEAGLEServer(unittest.TestCase): ...@@ -192,6 +241,8 @@ class TestEAGLEServer(unittest.TestCase):
p.join() p.join()
def test_gsm8k(self): def test_gsm8k(self):
server_info = requests.get(self.base_url + "/flush_cache")
args = SimpleNamespace( args = SimpleNamespace(
num_shots=5, num_shots=5,
data_path=None, data_path=None,
...@@ -201,96 +252,25 @@ class TestEAGLEServer(unittest.TestCase): ...@@ -201,96 +252,25 @@ class TestEAGLEServer(unittest.TestCase):
host="http://127.0.0.1", host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]), port=int(self.base_url.split(":")[-1]),
) )
metrics = run_eval(args) metrics = run_eval(args)
print(f"{metrics=}") print(f"{metrics=}")
self.assertGreater(metrics["accuracy"], 0.20) self.assertGreater(metrics["accuracy"], 0.20)
server_info = requests.get(self.base_url + "/get_server_info")
avg_spec_accept_length = server_info.json()["avg_spec_accept_length"]
print(f"{avg_spec_accept_length=}")
self.assertGreater(avg_spec_accept_length, 2.9)
def measure_acc_rate(engine): # Wait a little bit so that the memory check happens.
tic = time.time() time.sleep(4)
prompt = [
"Human: Give me a fully functional FastAPI server. Show the python code.<|separator|>\n\nAssistant:"
]
sampling_params = {"temperature": 0, "max_new_tokens": 512}
output = engine.generate(prompt, sampling_params)
output = output[0]
latency = time.time() - tic
if "spec_verify_ct" in output["meta_info"]:
base_acc_length = (
output["meta_info"]["completion_tokens"]
/ output["meta_info"]["spec_verify_ct"]
)
else:
base_acc_length = 0.0
base_speed = output["meta_info"]["completion_tokens"] / latency
return base_acc_length, base_speed
class TestEagleAcceptanceRate(unittest.TestCase):
@classmethod
def setUpClass(cls):
mp.set_start_method("spawn", force=True)
ref_engine = sgl.Engine(
model_path=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
speculative_draft_model_path=DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
speculative_algorithm="EAGLE",
speculative_num_steps=5,
speculative_eagle_topk=8,
speculative_num_draft_tokens=64,
mem_fraction_static=0.7,
disable_radix_cache=True,
)
cls.base_acc_length, cls.base_speed = measure_acc_rate(ref_engine)
ref_engine.shutdown()
assert cls.base_acc_length > 4.45
def test_acc_rate(self):
base_acc_length, base_speed = self.base_acc_length, self.base_speed
chunk_engine = sgl.Engine(
model_path=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
speculative_draft_model_path=DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
speculative_algorithm="EAGLE",
speculative_num_steps=5,
speculative_eagle_topk=8,
speculative_num_draft_tokens=64,
mem_fraction_static=0.7,
chunked_prefill_size=2,
disable_radix_cache=True,
)
chunked_acc_length, chunked_base_speed = measure_acc_rate(chunk_engine)
chunk_engine.shutdown()
print(base_acc_length, base_speed)
print(chunked_acc_length, chunked_base_speed)
assert abs(base_acc_length - chunked_acc_length) < acc_rate_tolerance
def test_acc_rate_prefix_caching(self):
base_acc_length, base_speed = self.base_acc_length, self.base_speed
prefix_caching_engine = sgl.Engine(
model_path=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
speculative_draft_model_path=DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
speculative_algorithm="EAGLE",
speculative_num_steps=5,
speculative_eagle_topk=8,
speculative_num_draft_tokens=64,
mem_fraction_static=0.7,
chunked_prefill_size=4,
schedule_policy="lpm",
)
for _ in range(10):
acc_length, _ = measure_acc_rate(prefix_caching_engine)
print(f"{acc_length=}")
assert abs(base_acc_length - acc_length) < acc_rate_tolerance
# The second one should hit the prefix cache.
prefix_caching_engine.shutdown()
class TestEAGLERetract(unittest.TestCase): class TestEAGLERetract(TestEAGLEServer):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
# These config helps find a leak.
os.environ["SGLANG_CI_SMALL_KV_SIZE"] = "4500"
cls.base_url = DEFAULT_URL_FOR_TEST cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server( cls.process = popen_launch_server(
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST, DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
...@@ -302,41 +282,20 @@ class TestEAGLERetract(unittest.TestCase): ...@@ -302,41 +282,20 @@ class TestEAGLERetract(unittest.TestCase):
"--speculative-draft-model-path", "--speculative-draft-model-path",
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST, DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
"--speculative-num-steps", "--speculative-num-steps",
"5", 5,
"--speculative-eagle-topk", "--speculative-eagle-topk",
"8", 8,
"--speculative-num-draft-tokens", "--speculative-num-draft-tokens",
"64", 64,
"--mem-fraction-static", "--mem-fraction-static",
"0.7", 0.7,
"--chunked-prefill-size", "--chunked-prefill-size",
"128", 128,
"--max-running-requests", "--max-running-requests",
"64", 64,
], ],
) )
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def test_gsm8k(self):
args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=200,
max_new_tokens=512,
parallel=128,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)
metrics = run_eval(args)
print(f"{metrics=}")
self.assertGreater(metrics["accuracy"], 0.20)
# Wait a little bit so that the memory check happens.
time.sleep(5)
class TestEAGLEServerTriton(TestEAGLEServer): class TestEAGLEServerTriton(TestEAGLEServer):
@classmethod @classmethod
...@@ -352,73 +311,20 @@ class TestEAGLEServerTriton(TestEAGLEServer): ...@@ -352,73 +311,20 @@ class TestEAGLEServerTriton(TestEAGLEServer):
"--speculative-draft-model-path", "--speculative-draft-model-path",
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST, DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
"--speculative-num-steps", "--speculative-num-steps",
"5", 5,
"--speculative-eagle-topk", "--speculative-eagle-topk",
"4", 8,
"--speculative-num-draft-tokens", "--speculative-num-draft-tokens",
"8", 64,
"--mem-fraction-static", "--mem-fraction-static",
"0.7", 0.7,
"--attention-backend", "--attention-backend",
"triton", "triton",
"--cuda-graph-max-bs", "--max-running-requests",
"16", 8,
], ],
) )
class TestEAGLEEngineTokenMap(unittest.TestCase):
def setUp(self):
self.prompt = "Today is a sunny day and I like"
self.sampling_params = {"temperature": 0, "max_new_tokens": 8}
ref_engine = sgl.Engine(
model_path="meta-llama/Meta-Llama-3-8B-Instruct", cuda_graph_max_bs=2
)
self.ref_output = ref_engine.generate(self.prompt, self.sampling_params)["text"]
ref_engine.shutdown()
def test_correctness(self):
config = {
"model_path": "meta-llama/Meta-Llama-3-8B-Instruct",
"speculative_draft_model_path": "lmsys/sglang-EAGLE-LLaMA3-Instruct-8B",
"speculative_algorithm": "EAGLE",
"speculative_num_steps": 5,
"speculative_eagle_topk": 4,
"speculative_num_draft_tokens": 8,
"speculative_token_map": "thunlp/LLaMA3-Instruct-8B-FR-Spec/freq_32768.pt",
"mem_fraction_static": 0.7,
"cuda_graph_max_bs": 4,
"dtype": "bfloat16",
}
engine = sgl.Engine(**config)
try:
self._test_basic_generation(engine)
self._test_batch_generation(engine)
finally:
engine.shutdown()
def _test_basic_generation(self, engine):
output = engine.generate(self.prompt, self.sampling_params)["text"]
print(f"{output=}, {self.ref_output=}")
self.assertEqual(output, self.ref_output)
def _test_batch_generation(self, engine):
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
params = {"temperature": 0, "max_new_tokens": 30}
outputs = engine.generate(prompts, params)
for prompt, output in zip(prompts, outputs):
print(f"Prompt: {prompt}")
print(f"Generated: {output['text']}")
print("-" * 40)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -59,6 +59,7 @@ class TestEnableMetrics(unittest.TestCase): ...@@ -59,6 +59,7 @@ class TestEnableMetrics(unittest.TestCase):
"sglang:spec_accept_length", "sglang:spec_accept_length",
"sglang:prompt_tokens_total", "sglang:prompt_tokens_total",
"sglang:generation_tokens_total", "sglang:generation_tokens_total",
"sglang:cached_tokens_total",
"sglang:num_requests_total", "sglang:num_requests_total",
"sglang:time_to_first_token_seconds", "sglang:time_to_first_token_seconds",
"sglang:time_per_output_token_seconds", "sglang:time_per_output_token_seconds",
......
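For completeness, a hedged end-to-end check of the new counter, assuming the server was started with --enable-metrics and exposes Prometheus-format text at a /metrics endpoint (the endpoint path and port are assumptions here):

import requests

base_url = "http://127.0.0.1:30000"  # assumed server address
metrics_text = requests.get(base_url + "/metrics").text
assert "sglang:cached_tokens_total" in metrics_text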
...@@ -94,7 +94,7 @@ class TestEpMoEFP8(unittest.TestCase): ...@@ -94,7 +94,7 @@ class TestEpMoEFP8(unittest.TestCase):
) )
metrics = run_eval(args) metrics = run_eval(args)
assert metrics["score"] >= 0.5 self.assertGreaterEqual(metrics["score"], 0.5)
def test_mgsm_en(self): def test_mgsm_en(self):
args = SimpleNamespace( args = SimpleNamespace(
...@@ -106,7 +106,7 @@ class TestEpMoEFP8(unittest.TestCase): ...@@ -106,7 +106,7 @@ class TestEpMoEFP8(unittest.TestCase):
) )
metrics = run_eval(args) metrics = run_eval(args)
assert metrics["score"] >= 0.8 self.assertGreaterEqual(metrics["score"], 0.8)
if __name__ == "__main__": if __name__ == "__main__":
......