Unverified commit fcc2e37f authored by Lianmin Zheng, committed by GitHub

Split the __init__ of the scheduler into smaller functions. Improve the EAGLE tests (#4128)

parent 0804dd11
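For readers skimming the diff below: the scheduler change applies the usual refactoring pattern of splitting one long __init__ into focused init_* helpers that __init__ then calls in order, so each concern (tokenizer, memory pool and cache, metrics) can be read on its own. A minimal sketch of that pattern, with hypothetical names and placeholder bodies rather than the actual Scheduler code:

class WorkerSketch:
    """Hypothetical class illustrating the __init__-splitting pattern; not the real Scheduler."""

    def __init__(self, server_args):
        self.server_args = server_args
        # __init__ now only orders the setup steps.
        self.init_tokenizer()
        self.init_memory_pool_and_cache()
        self.init_metrics()

    def init_tokenizer(self):
        # Placeholder: load model config and tokenizer from self.server_args.
        self.tokenizer = None

    def init_memory_pool_and_cache(self):
        # Placeholder: build the memory pools and the prefix cache.
        self.tree_cache = {}

    def init_metrics(self):
        # Placeholder: counters that previously lived inline in __init__.
        self.num_generated_tokens = 0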
......@@ -482,6 +482,7 @@ class BatchEmbeddingOut:
embeddings: List[List[float]]
# Token counts
prompt_tokens: List[int]
cached_tokens: List[int]
@dataclass
......
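The io_struct hunk above adds a per-request cached_tokens list to BatchEmbeddingOut, parallel to prompt_tokens. A hedged sketch of how the parallel lists line up when the scheduler builds this object (field order follows the constructor call later in this diff; the values are hypothetical):

from dataclasses import dataclass
from typing import Dict, List

@dataclass
class BatchEmbeddingOutSketch:
    # Mirrors only the fields visible in this diff, not the full upstream class.
    rids: List[str]
    finished_reasons: List[Dict]
    embeddings: List[List[float]]
    prompt_tokens: List[int]
    cached_tokens: List[int]

# Index i of every list describes the same finished request rids[i].
out = BatchEmbeddingOutSketch(
    rids=["req-0"],
    finished_reasons=[{"type": "stop"}],
    embeddings=[[0.1, 0.2, 0.3]],
    prompt_tokens=[12],   # total prompt length
    cached_tokens=[8],    # how many prompt tokens were served from the prefix cache
)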
......@@ -159,17 +159,6 @@ class Scheduler:
)
self.gpu_id = gpu_id
self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
self.decode_mem_cache_buf_multiplier = (
(
self.server_args.speculative_num_draft_tokens
+ (
self.server_args.speculative_eagle_topk
* self.server_args.speculative_num_draft_tokens
)
)
if not self.spec_algorithm.is_none()
else 1
)
# Distributed rank info
self.dp_size = server_args.dp_size
......@@ -208,42 +197,12 @@ class Scheduler:
self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)
# Init tokenizer
self.model_config = ModelConfig(
server_args.model_path,
trust_remote_code=server_args.trust_remote_code,
revision=server_args.revision,
context_length=server_args.context_length,
model_override_args=server_args.json_model_override_args,
is_embedding=server_args.is_embedding,
dtype=server_args.dtype,
quantization=server_args.quantization,
)
self.is_generation = self.model_config.is_generation
if server_args.skip_tokenizer_init:
self.tokenizer = self.processor = None
else:
if self.model_config.is_multimodal:
self.processor = get_processor(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
revision=server_args.revision,
)
self.tokenizer = self.processor.tokenizer
else:
self.tokenizer = get_tokenizer(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
revision=server_args.revision,
)
self.init_tokenizer()
# Check whether overlap can be enabled
if not self.is_generation:
self.enable_overlap = False
logger.info("Overlap scheduler is disabled for embedding models.")
if self.model_config.is_multimodal:
self.enable_overlap = False
logger.info("Overlap scheduler is disabled for multimodal models.")
......@@ -307,32 +266,7 @@ class Scheduler:
)
# Init memory pool and cache
self.req_to_token_pool, self.token_to_kv_pool_allocator = (
self.tp_worker.get_memory_pool()
)
if (
server_args.chunked_prefill_size is not None
and server_args.disable_radix_cache
):
self.tree_cache = ChunkCache(
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
)
else:
if self.enable_hierarchical_cache:
self.tree_cache = HiRadixCache(
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
)
else:
self.tree_cache = RadixCache(
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
disable=server_args.disable_radix_cache,
)
self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)
self.init_memory_pool_and_cache()
# Init running status
self.waiting_queue: List[Req] = []
......@@ -346,25 +280,13 @@ class Scheduler:
self.forward_ct = 0
self.forward_ct_decode = 0
self.num_generated_tokens = 0
self.spec_num_total_accepted_tokens = 0
self.spec_num_total_forward_ct = 0
self.cum_spec_accept_length = 0
self.cum_spec_accept_count = 0
self.last_decode_stats_tic = time.time()
self.return_health_check_ct = 0
self.current_stream = torch.get_device_module(self.device).current_stream()
if self.device == "cpu":
self.current_stream.synchronize = lambda: None # No-op for CPU
# For metrics only.
# The largest prefill length of a single request
self._largest_prefill_len: int = 0
# The largest context length (prefill + generation) of a single request
self._largest_prefill_decode_len: int = 0
self.last_gen_throughput: float = 0.0
self.step_time_dict = defaultdict(list) # Dict[batch size -> step time]
# Session info
# Init session info
self.sessions: Dict[str, Session] = {}
# Init chunked prefill
......@@ -385,11 +307,11 @@ class Scheduler:
else:
self.grammar_backend = None
# Init new token estimation
# Init schedule policy and new token estimation
self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)
assert (
server_args.schedule_conservativeness >= 0
), "Invalid schedule_conservativeness"
self.init_new_token_ratio = min(
global_config.default_init_new_token_ratio
* server_args.schedule_conservativeness,
......@@ -428,14 +350,7 @@ class Scheduler:
self.profiler_target_forward_ct: Optional[int] = None
# Init metrics stats
self.stats = SchedulerStats()
if self.enable_metrics:
self.metrics_collector = SchedulerMetricsCollector(
labels={
"model_name": self.server_args.served_model_name,
# TODO: Add lora name/path in the future,
},
)
self.init_metrics()
# Init request dispatcher
self._request_dispatcher = TypeBasedDispatcher(
......@@ -458,39 +373,104 @@ class Scheduler:
(ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
(ProfileReq, self.profile),
(GetInternalStateReq, self.get_internal_state),
(SetInternalStateReq, self.set_internal_state),
]
)
def watchdog_thread(self):
"""A watch dog thread that will try to kill the server itself if one forward batch takes too long."""
self.watchdog_last_forward_ct = 0
self.watchdog_last_time = time.time()
def init_tokenizer(self):
server_args = self.server_args
while True:
current = time.time()
if self.cur_batch is not None:
if self.watchdog_last_forward_ct == self.forward_ct:
if current > self.watchdog_last_time + self.watchdog_timeout:
logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
break
else:
self.watchdog_last_forward_ct = self.forward_ct
self.watchdog_last_time = current
time.sleep(self.watchdog_timeout // 2)
self.model_config = ModelConfig(
server_args.model_path,
trust_remote_code=server_args.trust_remote_code,
revision=server_args.revision,
context_length=server_args.context_length,
model_override_args=server_args.json_model_override_args,
is_embedding=server_args.is_embedding,
dtype=server_args.dtype,
quantization=server_args.quantization,
)
self.is_generation = self.model_config.is_generation
# Print batch size and memory pool info to check whether there are de-sync issues.
logger.error(
f"{self.cur_batch.batch_size()=}, "
f"{self.cur_batch.reqs=}, "
f"{self.token_to_kv_pool_allocator.available_size()=}, "
f"{self.tree_cache.evictable_size()=}, "
if server_args.skip_tokenizer_init:
self.tokenizer = self.processor = None
else:
if self.model_config.is_multimodal:
self.processor = get_processor(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
revision=server_args.revision,
)
self.tokenizer = self.processor.tokenizer
else:
self.tokenizer = get_tokenizer(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
revision=server_args.revision,
)
def init_memory_pool_and_cache(self):
server_args = self.server_args
self.req_to_token_pool, self.token_to_kv_pool_allocator = (
self.tp_worker.get_memory_pool()
)
if (
server_args.chunked_prefill_size is not None
and server_args.disable_radix_cache
):
self.tree_cache = ChunkCache(
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
)
else:
if self.enable_hierarchical_cache:
self.tree_cache = HiRadixCache(
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
)
else:
self.tree_cache = RadixCache(
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
disable=server_args.disable_radix_cache,
)
self.decode_mem_cache_buf_multiplier = (
1
if self.spec_algorithm.is_none()
else (
server_args.speculative_num_draft_tokens
+ (
server_args.speculative_eagle_topk
* server_args.speculative_num_steps
)
)
)
# Wait for some time so that the parent process can print the error.
pyspy_dump_schedulers()
print(file=sys.stderr, flush=True)
print(file=sys.stdout, flush=True)
time.sleep(5)
self.parent_process.send_signal(signal.SIGQUIT)
def init_metrics(self):
# The largest prefill length of a single request
self._largest_prefill_len: int = 0
# The largest context length (prefill + generation) of a single request
self._largest_prefill_decode_len: int = 0
self.last_gen_throughput: float = 0.0
self.step_time_dict = defaultdict(list) # Dict[batch size -> step time]
self.spec_num_total_accepted_tokens = 0
self.spec_num_total_forward_ct = 0
self.cum_spec_accept_length = 0
self.cum_spec_accept_count = 0
self.stats = SchedulerStats()
if self.enable_metrics:
engine_type = "unified"
self.metrics_collector = SchedulerMetricsCollector(
labels={
"model_name": self.server_args.served_model_name,
"engine_type": engine_type,
},
)
@torch.no_grad()
def event_loop_normal(self):
......@@ -1176,6 +1156,7 @@ class Scheduler:
):
self.stop_profile()
# Run forward
if self.is_generation:
if self.spec_algorithm.is_none():
model_worker_batch = batch.get_model_worker_batch()
......@@ -1196,6 +1177,7 @@ class Scheduler:
self.spec_num_total_forward_ct += batch.batch_size()
self.num_generated_tokens += num_accepted_tokens
batch.output_ids = next_token_ids
# These two values are needed for processing the output, but they can be
# modified by the overlap scheduler, so we copy them here to make sure
# output processing uses the correct values.
......@@ -1229,7 +1211,6 @@ class Scheduler:
result: Union[GenerationBatchResult, EmbeddingBatchResult],
):
if batch.forward_mode.is_decode():
assert isinstance(result, GenerationBatchResult)
self.process_batch_result_decode(batch, result)
if batch.is_empty():
self.running_batch = None
......@@ -1481,6 +1462,7 @@ class Scheduler:
batch.next_batch_sampling_info.update_regex_vocab_mask()
self.current_stream.synchronize()
batch.next_batch_sampling_info.sampling_info_done.set()
self.stream_output(batch.reqs, batch.return_logprob)
self.token_to_kv_pool_allocator.free_group_end()
......@@ -1584,7 +1566,9 @@ class Scheduler:
req.temp_input_token_ids_logprobs_idx
)
for val, idx in zip(
req.temp_input_top_logprobs_val, req.temp_input_top_logprobs_idx
req.temp_input_top_logprobs_val,
req.temp_input_top_logprobs_idx,
strict=True,
):
req.input_top_logprobs_val.extend(val)
req.input_top_logprobs_idx.extend(idx)
......@@ -1809,14 +1793,18 @@ class Scheduler:
else: # embedding or reward model
embeddings = []
prompt_tokens = []
cached_tokens = []
for req in reqs:
if req.finished():
rids.append(req.rid)
finished_reasons.append(req.finished_reason.to_json())
embeddings.append(req.embedding)
prompt_tokens.append(len(req.origin_input_ids))
cached_tokens.append(req.cached_tokens)
self.send_to_detokenizer.send_pyobj(
BatchEmbeddingOut(rids, finished_reasons, embeddings, prompt_tokens)
BatchEmbeddingOut(
rids, finished_reasons, embeddings, prompt_tokens, cached_tokens
)
)
def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
......@@ -1902,6 +1890,37 @@ class Scheduler:
self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs])
self.grammar_queue = self.grammar_queue[num_ready_reqs:]
def watchdog_thread(self):
"""A watch dog thread that will try to kill the server itself if one forward batch takes too long."""
self.watchdog_last_forward_ct = 0
self.watchdog_last_time = time.time()
while True:
current = time.time()
if self.cur_batch is not None:
if self.watchdog_last_forward_ct == self.forward_ct:
if current > self.watchdog_last_time + self.watchdog_timeout:
logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
break
else:
self.watchdog_last_forward_ct = self.forward_ct
self.watchdog_last_time = current
time.sleep(self.watchdog_timeout // 2)
# Print batch size and memory pool info to check whether there are de-sync issues.
logger.error(
f"{self.cur_batch.batch_size()=}, "
f"{self.cur_batch.reqs=}, "
f"{self.token_to_kv_pool_allocator.available_size()=}, "
f"{self.tree_cache.evictable_size()=}, "
)
# Wait for some time so that the parent process can print the error.
pyspy_dump_schedulers()
print(file=sys.stderr, flush=True)
print(file=sys.stdout, flush=True)
time.sleep(5)
self.parent_process.send_signal(signal.SIGQUIT)
def flush_cache_wrapped(self, recv_req: FlushCacheReq):
self.flush_cache()
......@@ -1913,7 +1932,6 @@ class Scheduler:
self.cur_batch = None
self.last_batch = None
self.tree_cache.reset()
self.tree_cache_metrics = {"total": 0, "hit": 0}
if self.grammar_backend:
self.grammar_backend.reset()
self.req_to_token_pool.clear()
......@@ -2005,6 +2023,9 @@ class Scheduler:
req.to_abort = True
break
def _pause_engine(self) -> Tuple[List[Req], int]:
raise NotImplementedError()
def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
"""In-place update of the weights from disk."""
success, message = self.tp_worker.update_weights_from_disk(recv_req)
......
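One detail worth calling out in the scheduler hunk above: decode_mem_cache_buf_multiplier now multiplies speculative_eagle_topk by speculative_num_steps, whereas the removed code multiplied it by speculative_num_draft_tokens. A small standalone illustration of the numerical difference, plugging in the EAGLE settings that appear in the test configs elsewhere in this diff (num_steps=5, topk=8, num_draft_tokens=64); this only shows the arithmetic, not the multiplier's semantics:

# Hypothetical standalone calculation; the values come from the test configs in this diff.
speculative_num_steps = 5
speculative_eagle_topk = 8
speculative_num_draft_tokens = 64

old_multiplier = speculative_num_draft_tokens + speculative_eagle_topk * speculative_num_draft_tokens
new_multiplier = speculative_num_draft_tokens + speculative_eagle_topk * speculative_num_steps

print(old_multiplier)  # 576
print(new_multiplier)  # 104 (and the multiplier is simply 1 when speculative decoding is off)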
......@@ -1068,6 +1068,7 @@ class TokenizerManager:
self.metrics_collector.observe_one_finished_request(
recv_obj.prompt_tokens[i],
completion_tokens,
recv_obj.cached_tokens[i],
state.finished_time - state.created_time,
)
......
......@@ -121,6 +121,12 @@ class TokenizerMetricsCollector:
labelnames=labels.keys(),
)
self.cached_tokens_total = Counter(
name="sglang:cached_tokens_total",
documentation="Number of cached prompt tokens.",
labelnames=labels.keys(),
)
self.num_requests_total = Counter(
name="sglang:num_requests_total",
documentation="Number of requests processed.",
......@@ -245,10 +251,12 @@ class TokenizerMetricsCollector:
self,
prompt_tokens: int,
generation_tokens: int,
cached_tokens: int,
e2e_latency: float,
):
self.prompt_tokens_total.labels(**self.labels).inc(prompt_tokens)
self.generation_tokens_total.labels(**self.labels).inc(generation_tokens)
self.cached_tokens_total.labels(**self.labels).inc(cached_tokens)
self.num_requests_total.labels(**self.labels).inc(1)
self._log_histogram(self.histogram_e2e_request_latency, e2e_latency)
if generation_tokens >= 1:
......
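The new sglang:cached_tokens_total metric follows the same prometheus_client Counter pattern as the surrounding counters and is incremented once per finished request in observe_one_finished_request. A minimal self-contained sketch of that pattern (the label set here is illustrative, not the exact upstream labels):

from prometheus_client import Counter

# Same name and documentation as in the diff; labelnames are illustrative.
cached_tokens_total = Counter(
    name="sglang:cached_tokens_total",
    documentation="Number of cached prompt tokens.",
    labelnames=["model_name"],
)

# When a request finishes, add the number of its prompt tokens that hit the cache.
cached_tokens_total.labels(model_name="demo-model").inc(8)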
import multiprocessing as mp
import os
import random
import threading
import time
import unittest
from types import SimpleNamespace
from typing import List, Optional
import requests
import torch
import sglang as sgl
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval
from sglang.test.runners import DEFAULT_PROMPTS, SRTRunner
from sglang.test.test_utils import (
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
......@@ -19,7 +23,9 @@ from sglang.test.test_utils import (
popen_launch_server,
)
acc_rate_tolerance = 0.15
torch_dtype = torch.float16
prefill_tolerance = 5e-2
decode_tolerance: float = 5e-2
class TestEAGLEEngine(unittest.TestCase):
......@@ -28,51 +34,72 @@ class TestEAGLEEngine(unittest.TestCase):
"speculative_draft_model_path": DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
"speculative_algorithm": "EAGLE",
"speculative_num_steps": 5,
"speculative_eagle_topk": 8,
"speculative_num_draft_tokens": 64,
"speculative_eagle_topk": 4,
"speculative_num_draft_tokens": 8,
"mem_fraction_static": 0.7,
"cuda_graph_max_bs": 32,
"cuda_graph_max_bs": 5,
}
NUM_CONFIGS = 3
def setUp(self):
self.prompt = "Today is a sunny day and I like"
self.sampling_params = {"temperature": 0, "max_new_tokens": 8}
ref_engine = sgl.Engine(model_path=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST)
ref_engine = sgl.Engine(
model_path=self.BASE_CONFIG["model_path"], cuda_graph_max_bs=1
)
self.ref_output = ref_engine.generate(self.prompt, self.sampling_params)["text"]
ref_engine.shutdown()
def test_correctness(self):
configs = [
# Basic config
self.BASE_CONFIG,
# Disable cuda graph
{**self.BASE_CONFIG, "disable_cuda_graph": True},
{**self.BASE_CONFIG, "chunked_prefill_size": 2},
# Chunked prefill
{**self.BASE_CONFIG, "chunked_prefill_size": 4},
]
for config in configs:
with self.subTest(
cuda_graph=(
"enabled" if len(config) == len(self.BASE_CONFIG) else "disabled"
),
chunked_prefill_size=(
config["chunked_prefill_size"]
if "chunked_prefill_size" in config
else "default"
),
):
engine = sgl.Engine(**config)
for i, config in enumerate(configs[: self.NUM_CONFIGS]):
with self.subTest(i=i):
print(f"{config=}")
engine = sgl.Engine(**config, log_level="info", decode_log_interval=10)
try:
self._test_basic_generation(engine)
self._test_eos_token(engine)
self._test_single_generation(engine)
self._test_batch_generation(engine)
self._test_eos_token(engine)
self._test_acc_length(engine)
finally:
engine.shutdown()
print("=" * 100)
def _test_basic_generation(self, engine):
def _test_single_generation(self, engine):
output = engine.generate(self.prompt, self.sampling_params)["text"]
print(f"{output=}, {self.ref_output=}")
self.assertEqual(output, self.ref_output)
def _test_batch_generation(self, engine):
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
params = {"temperature": 0, "max_new_tokens": 50}
outputs = engine.generate(prompts, params)
for prompt, output in zip(prompts, outputs):
print(f"Prompt: {prompt}")
print(f"Generated: {output['text']}")
print("-" * 40)
print(f"{engine.get_server_info()=}")
avg_spec_accept_length = engine.get_server_info()["avg_spec_accept_length"]
print(f"{avg_spec_accept_length=}")
self.assertGreater(avg_spec_accept_length, 1.9)
def _test_eos_token(self, engine):
prompt = "[INST] <<SYS>>\nYou are a helpful assistant.\n<</SYS>>\nToday is a sunny day and I like [/INST]"
params = {
......@@ -88,32 +115,54 @@ class TestEAGLEEngine(unittest.TestCase):
tokens = tokenizer.encode(output, truncation=False)
self.assertNotIn(tokenizer.eos_token_id, tokens)
def _test_batch_generation(self, engine):
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
def _test_acc_length(self, engine):
prompt = [
"Human: Give me a fully functional FastAPI server. Show the python code.\n\nAssistant:"
]
params = {"temperature": 0, "max_new_tokens": 30}
outputs = engine.generate(prompts, params)
for prompt, output in zip(prompts, outputs):
print(f"Prompt: {prompt}")
print(f"Generated: {output['text']}")
print("-" * 40)
sampling_params = {"temperature": 0, "max_new_tokens": 512}
output = engine.generate(prompt, sampling_params)
output = output[0]
if "spec_verify_ct" in output["meta_info"]:
acc_length = (
output["meta_info"]["completion_tokens"]
/ output["meta_info"]["spec_verify_ct"]
)
else:
acc_length = 1.0
speed = (
output["meta_info"]["completion_tokens"]
/ output["meta_info"]["e2e_latency"]
)
print(f"{acc_length=}")
self.assertGreater(acc_length, 3.6)
prompts = [
"[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nToday is a sunny day and I like[/INST]"
'[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nWhat are the mental triggers in Jeff Walker\'s Product Launch Formula and "Launch" book?[/INST]',
"[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nSummarize Russell Brunson's Perfect Webinar Script...[/INST]",
"[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nwho are you?[/INST]",
"[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nwhere are you from?[/INST]",
]
class TestEAGLEEngineTokenMap(TestEAGLEEngine):
BASE_CONFIG = {
"model_path": "meta-llama/Meta-Llama-3-8B-Instruct",
"speculative_draft_model_path": "lmsys/sglang-EAGLE-LLaMA3-Instruct-8B",
"speculative_algorithm": "EAGLE",
"speculative_num_steps": 5,
"speculative_eagle_topk": 4,
"speculative_num_draft_tokens": 8,
"speculative_token_map": "thunlp/LLaMA3-Instruct-8B-FR-Spec/freq_32768.pt",
"mem_fraction_static": 0.7,
"cuda_graph_max_bs": 5,
}
NUM_CONFIGS = 1
class TestEAGLEServer(unittest.TestCase):
PROMPTS = [
"[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nToday is a sunny day and I like[/INST]"
'[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nWhat are the mental triggers in Jeff Walker\'s Product Launch Formula and "Launch" book?[/INST]',
"[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nSummarize Russell Brunson's Perfect Webinar Script...[/INST]",
"[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nwho are you?[/INST]",
"[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nwhere are you from?[/INST]",
]
@classmethod
def setUpClass(cls):
cls.base_url = DEFAULT_URL_FOR_TEST
......@@ -127,17 +176,17 @@ class TestEAGLEServer(unittest.TestCase):
"--speculative-draft-model-path",
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
"--speculative-num-steps",
"5",
5,
"--speculative-eagle-topk",
"8",
8,
"--speculative-num-draft-tokens",
"64",
64,
"--mem-fraction-static",
"0.7",
0.7,
"--chunked-prefill-size",
"128",
"--cuda-graph-max-bs",
"32",
128,
"--max-running-requests",
8,
],
)
......@@ -147,7 +196,7 @@ class TestEAGLEServer(unittest.TestCase):
def send_request(self):
time.sleep(random.uniform(0, 2))
for prompt in prompts:
for prompt in self.PROMPTS:
url = self.base_url + "/generate"
data = {
"text": prompt,
......@@ -160,7 +209,7 @@ class TestEAGLEServer(unittest.TestCase):
assert response.status_code == 200
def send_requests_abort(self):
for prompt in prompts:
for prompt in self.PROMPTS:
try:
time.sleep(random.uniform(0, 2))
url = self.base_url + "/generate"
......@@ -192,6 +241,8 @@ class TestEAGLEServer(unittest.TestCase):
p.join()
def test_gsm8k(self):
server_info = requests.get(self.base_url + "/flush_cache")
args = SimpleNamespace(
num_shots=5,
data_path=None,
......@@ -201,96 +252,25 @@ class TestEAGLEServer(unittest.TestCase):
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)
metrics = run_eval(args)
print(f"{metrics=}")
self.assertGreater(metrics["accuracy"], 0.20)
server_info = requests.get(self.base_url + "/get_server_info")
avg_spec_accept_length = server_info.json()["avg_spec_accept_length"]
print(f"{avg_spec_accept_length=}")
self.assertGreater(avg_spec_accept_length, 2.9)
def measure_acc_rate(engine):
tic = time.time()
prompt = [
"Human: Give me a fully functional FastAPI server. Show the python code.<|separator|>\n\nAssistant:"
]
sampling_params = {"temperature": 0, "max_new_tokens": 512}
output = engine.generate(prompt, sampling_params)
output = output[0]
latency = time.time() - tic
if "spec_verify_ct" in output["meta_info"]:
base_acc_length = (
output["meta_info"]["completion_tokens"]
/ output["meta_info"]["spec_verify_ct"]
)
else:
base_acc_length = 0.0
base_speed = output["meta_info"]["completion_tokens"] / latency
return base_acc_length, base_speed
class TestEagleAcceptanceRate(unittest.TestCase):
@classmethod
def setUpClass(cls):
mp.set_start_method("spawn", force=True)
ref_engine = sgl.Engine(
model_path=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
speculative_draft_model_path=DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
speculative_algorithm="EAGLE",
speculative_num_steps=5,
speculative_eagle_topk=8,
speculative_num_draft_tokens=64,
mem_fraction_static=0.7,
disable_radix_cache=True,
)
cls.base_acc_length, cls.base_speed = measure_acc_rate(ref_engine)
ref_engine.shutdown()
assert cls.base_acc_length > 4.45
def test_acc_rate(self):
base_acc_length, base_speed = self.base_acc_length, self.base_speed
chunk_engine = sgl.Engine(
model_path=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
speculative_draft_model_path=DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
speculative_algorithm="EAGLE",
speculative_num_steps=5,
speculative_eagle_topk=8,
speculative_num_draft_tokens=64,
mem_fraction_static=0.7,
chunked_prefill_size=2,
disable_radix_cache=True,
)
chunked_acc_length, chunked_base_speed = measure_acc_rate(chunk_engine)
chunk_engine.shutdown()
print(base_acc_length, base_speed)
print(chunked_acc_length, chunked_base_speed)
assert abs(base_acc_length - chunked_acc_length) < acc_rate_tolerance
def test_acc_rate_prefix_caching(self):
base_acc_length, base_speed = self.base_acc_length, self.base_speed
prefix_caching_engine = sgl.Engine(
model_path=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
speculative_draft_model_path=DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
speculative_algorithm="EAGLE",
speculative_num_steps=5,
speculative_eagle_topk=8,
speculative_num_draft_tokens=64,
mem_fraction_static=0.7,
chunked_prefill_size=4,
schedule_policy="lpm",
)
for _ in range(10):
acc_length, _ = measure_acc_rate(prefix_caching_engine)
print(f"{acc_length=}")
assert abs(base_acc_length - acc_length) < acc_rate_tolerance
# The second one should hit the prefix cache.
prefix_caching_engine.shutdown()
# Wait a little bit so that the memory check happens.
time.sleep(4)
class TestEAGLERetract(unittest.TestCase):
class TestEAGLERetract(TestEAGLEServer):
@classmethod
def setUpClass(cls):
# This config helps find a leak.
os.environ["SGLANG_CI_SMALL_KV_SIZE"] = "4500"
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
......@@ -302,41 +282,20 @@ class TestEAGLERetract(unittest.TestCase):
"--speculative-draft-model-path",
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
"--speculative-num-steps",
"5",
5,
"--speculative-eagle-topk",
"8",
8,
"--speculative-num-draft-tokens",
"64",
64,
"--mem-fraction-static",
"0.7",
0.7,
"--chunked-prefill-size",
"128",
128,
"--max-running-requests",
"64",
64,
],
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def test_gsm8k(self):
args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=200,
max_new_tokens=512,
parallel=128,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)
metrics = run_eval(args)
print(f"{metrics=}")
self.assertGreater(metrics["accuracy"], 0.20)
# Wait a little bit so that the memory check happens.
time.sleep(5)
class TestEAGLEServerTriton(TestEAGLEServer):
@classmethod
......@@ -352,73 +311,20 @@ class TestEAGLEServerTriton(TestEAGLEServer):
"--speculative-draft-model-path",
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
"--speculative-num-steps",
"5",
5,
"--speculative-eagle-topk",
"4",
8,
"--speculative-num-draft-tokens",
"8",
64,
"--mem-fraction-static",
"0.7",
0.7,
"--attention-backend",
"triton",
"--cuda-graph-max-bs",
"16",
"--max-running-requests",
8,
],
)
class TestEAGLEEngineTokenMap(unittest.TestCase):
def setUp(self):
self.prompt = "Today is a sunny day and I like"
self.sampling_params = {"temperature": 0, "max_new_tokens": 8}
ref_engine = sgl.Engine(
model_path="meta-llama/Meta-Llama-3-8B-Instruct", cuda_graph_max_bs=2
)
self.ref_output = ref_engine.generate(self.prompt, self.sampling_params)["text"]
ref_engine.shutdown()
def test_correctness(self):
config = {
"model_path": "meta-llama/Meta-Llama-3-8B-Instruct",
"speculative_draft_model_path": "lmsys/sglang-EAGLE-LLaMA3-Instruct-8B",
"speculative_algorithm": "EAGLE",
"speculative_num_steps": 5,
"speculative_eagle_topk": 4,
"speculative_num_draft_tokens": 8,
"speculative_token_map": "thunlp/LLaMA3-Instruct-8B-FR-Spec/freq_32768.pt",
"mem_fraction_static": 0.7,
"cuda_graph_max_bs": 4,
"dtype": "bfloat16",
}
engine = sgl.Engine(**config)
try:
self._test_basic_generation(engine)
self._test_batch_generation(engine)
finally:
engine.shutdown()
def _test_basic_generation(self, engine):
output = engine.generate(self.prompt, self.sampling_params)["text"]
print(f"{output=}, {self.ref_output=}")
self.assertEqual(output, self.ref_output)
def _test_batch_generation(self, engine):
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
params = {"temperature": 0, "max_new_tokens": 30}
outputs = engine.generate(prompts, params)
for prompt, output in zip(prompts, outputs):
print(f"Prompt: {prompt}")
print(f"Generated: {output['text']}")
print("-" * 40)
if __name__ == "__main__":
unittest.main()
......@@ -59,6 +59,7 @@ class TestEnableMetrics(unittest.TestCase):
"sglang:spec_accept_length",
"sglang:prompt_tokens_total",
"sglang:generation_tokens_total",
"sglang:cached_tokens_total",
"sglang:num_requests_total",
"sglang:time_to_first_token_seconds",
"sglang:time_per_output_token_seconds",
......
......@@ -94,7 +94,7 @@ class TestEpMoEFP8(unittest.TestCase):
)
metrics = run_eval(args)
assert metrics["score"] >= 0.5
self.assertGreaterEqual(metrics["score"], 0.5)
def test_mgsm_en(self):
args = SimpleNamespace(
......@@ -106,7 +106,7 @@ class TestEpMoEFP8(unittest.TestCase):
)
metrics = run_eval(args)
assert metrics["score"] >= 0.8
self.assertGreaterEqual(metrics["score"], 0.8)
if __name__ == "__main__":
......
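The last two hunks replace bare assert statements with unittest's assertGreaterEqual, which reports both operands on failure and is not stripped when Python runs with -O. A tiny self-contained illustration with hypothetical values:

import unittest

class ThresholdDemo(unittest.TestCase):
    def test_score_threshold(self):
        score = 0.42  # hypothetical metric value
        # Failure message reads "0.42 not greater than or equal to 0.5",
        # while a bare `assert score >= 0.5` only raises AssertionError.
        self.assertGreaterEqual(score, 0.5)

if __name__ == "__main__":
    unittest.main()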