Unverified commit d1a08632 authored by Lianmin Zheng, committed by GitHub

Add a test case for cached_tokens (#3145)

parent f8b28e46
@@ -19,16 +19,16 @@
 | [**Slides**](https://github.com/sgl-project/sgl-learning-materials?tab=readme-ov-file#slides) |

 ## News
-- [2024/12] 🔥 SGLang v0.4: Zero-Overhead Batch Scheduler, Cache-Aware Load Balancer, Faster Structured Outputs ([blog](https://lmsys.org/blog/2024-12-04-sglang-v0-4/)).
-- [2024/10] 🔥 The First SGLang Online Meetup ([slides](https://github.com/sgl-project/sgl-learning-materials?tab=readme-ov-file#the-first-sglang-online-meetup)).
-- [2024/09] SGLang v0.3 Release: 7x Faster DeepSeek MLA, 1.5x Faster torch.compile, Multi-Image/Video LLaVA-OneVision ([blog](https://lmsys.org/blog/2024-09-04-sglang-v0-3/)).
-- [2024/07] Faster Llama3 Serving with SGLang Runtime (vs. TensorRT-LLM, vLLM) ([blog](https://lmsys.org/blog/2024-07-25-sglang-llama3/)).
+- [2025/01] 🔥 SGLang provides day one support for DeepSeek V3/R1 models on NVIDIA and AMD GPUs with DeepSeek-specific optimizations. ([instructions](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3), [AMD blog](https://www.amd.com/en/developer/resources/technical-articles/amd-instinct-gpus-power-deepseek-v3-revolutionizing-ai-development-with-sglang.html))
+- [2024/12] 🔥 v0.4 Release: Zero-Overhead Batch Scheduler, Cache-Aware Load Balancer, Faster Structured Outputs ([blog](https://lmsys.org/blog/2024-12-04-sglang-v0-4/)).
+- [2024/09] v0.3 Release: 7x Faster DeepSeek MLA, 1.5x Faster torch.compile, Multi-Image/Video LLaVA-OneVision ([blog](https://lmsys.org/blog/2024-09-04-sglang-v0-3/)).
+- [2024/07] v0.2 Release: Faster Llama3 Serving with SGLang Runtime (vs. TensorRT-LLM, vLLM) ([blog](https://lmsys.org/blog/2024-07-25-sglang-llama3/)).

 <details>
 <summary>More</summary>

+- [2024/10] The First SGLang Online Meetup ([slides](https://github.com/sgl-project/sgl-learning-materials?tab=readme-ov-file#the-first-sglang-online-meetup)).
 - [2024/02] SGLang enables **3x faster JSON decoding** with compressed finite state machine ([blog](https://lmsys.org/blog/2024-02-05-compressed-fsm/)).
-- [2024/04] SGLang is used by the official **LLaVA-NeXT (video)** release ([blog](https://llava-vl.github.io/blog/2024-04-30-llava-next-video/)).
 - [2024/01] SGLang provides up to **5x faster inference** with RadixAttention ([blog](https://lmsys.org/blog/2024-01-17-sglang/)).
 - [2024/01] SGLang powers the serving of the official **LLaVA v1.6** release demo ([usage](https://github.com/haotian-liu/LLaVA?tab=readme-ov-file#demo)).
...
@@ -331,6 +331,7 @@ class Req:

         # The number of cached tokens, that were already cached in the KV cache
         self.cached_tokens = 0
+        self.already_computed = 0

     def extend_image_inputs(self, image_inputs):
         if self.image_inputs is None:
@@ -750,13 +751,6 @@ class ScheduleBatch:

         pt = 0
         for i, req in enumerate(reqs):
-            already_computed = (
-                req.extend_logprob_start_len + 1 + req.cached_tokens
-                if req.extend_logprob_start_len > 0
-                else 0
-            )
-            req.cached_tokens += len(req.prefix_indices) - already_computed
-
             req.req_pool_idx = req_pool_indices[i]
             pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
             seq_lens.append(seq_len)
@@ -772,15 +766,20 @@ class ScheduleBatch:
                 # If req.input_embeds is already a list, append its content directly
                 input_embeds.extend(req.input_embeds)  # Use extend to avoid nesting

-            # Compute the relative logprob_start_len in an extend batch
-            if req.logprob_start_len >= pre_len:
-                extend_logprob_start_len = min(
-                    req.logprob_start_len - pre_len, req.extend_input_len - 1
-                )
-            else:
-                extend_logprob_start_len = req.extend_input_len - 1
-
-            req.extend_logprob_start_len = extend_logprob_start_len
+            if req.return_logprob:
+                # Compute the relative logprob_start_len in an extend batch
+                if req.logprob_start_len >= pre_len:
+                    extend_logprob_start_len = min(
+                        req.logprob_start_len - pre_len, req.extend_input_len - 1
+                    )
+                else:
+                    raise RuntimeError(
+                        f"This should never happen. {req.logprob_start_len=}, {pre_len=}"
+                    )
+                req.extend_logprob_start_len = extend_logprob_start_len
+
+            req.cached_tokens += pre_len - req.already_computed
+            req.already_computed = seq_len

             req.is_retracted = False
             pre_lens.append(pre_len)
...
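For context on the new accounting: `already_computed` remembers how many tokens of a request were already counted in earlier prefill chunks, so `pre_len - already_computed` adds only the genuinely cache-hit prefix to `cached_tokens`, exactly once, even when a long prompt is prefilled in several chunks. Below is a minimal standalone sketch of that behavior; `FakeReq` and `prepare_chunk` are illustrative stand-ins, not SGLang code.

```python
# Minimal sketch (hypothetical names) of the cached_tokens / already_computed
# bookkeeping across chunked prefill.

class FakeReq:
    def __init__(self):
        self.cached_tokens = 0      # tokens ever served from the KV cache
        self.already_computed = 0   # tokens accounted for by earlier chunks

def prepare_chunk(req, pre_len, seq_len):
    # pre_len: tokens already in the KV cache for this request (radix-cache hit
    #          plus tokens produced by previous chunks); seq_len: prefix + new chunk.
    req.cached_tokens += pre_len - req.already_computed
    req.already_computed = seq_len

req = FakeReq()
prepare_chunk(req, pre_len=100, seq_len=4096)    # 100-token radix-cache hit
prepare_chunk(req, pre_len=4096, seq_len=8192)   # second chunk: no new cache hit
prepare_chunk(req, pre_len=8192, seq_len=10000)  # final chunk
assert req.cached_tokens == 100                  # only the real prefix hit is counted
```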
@@ -660,24 +660,23 @@ class Scheduler:
             self.waiting_queue.append(req)
             return

-        # Copy more attributes
-        req.logprob_start_len = recv_req.logprob_start_len
-
-        if req.logprob_start_len == -1:
-            # By default, only return the logprobs for output tokens
-            req.logprob_start_len = len(req.origin_input_ids) - 1
-
         # Validate prompts length
         error_msg = validate_input_length(
             req,
             self.max_req_input_len,
             self.server_args.allow_auto_truncate,
         )
-
         if error_msg:
             self.waiting_queue.append(req)
             return

+        # Copy more attributes
+        if recv_req.logprob_start_len == -1:
+            # By default, only return the logprobs for output tokens
+            req.logprob_start_len = len(req.origin_input_ids) - 1
+        else:
+            req.logprob_start_len = recv_req.logprob_start_len
+
         req.sampling_params.max_new_tokens = min(
             (
                 req.sampling_params.max_new_tokens
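One plausible reason for moving the attribute copy after `validate_input_length` is that auto-truncation can shorten `req.origin_input_ids`, so the default `logprob_start_len = len(req.origin_input_ids) - 1` has to be computed on the post-validation prompt. A small illustration with hypothetical numbers (not SGLang code):

```python
# Why the default logprob_start_len is derived only after length validation,
# which may auto-truncate the prompt. Numbers are illustrative.
origin_input_ids = list(range(6000))       # prompt before validation
max_req_input_len = 4096                   # assumed server limit

stale_start = len(origin_input_ids) - 1    # 5999, computed too early

# Auto-truncation (as allow_auto_truncate would permit) shortens the prompt:
origin_input_ids = origin_input_ids[:max_req_input_len]
fresh_start = len(origin_input_ids) - 1    # 4095, the last remaining prompt token

assert stale_start >= len(origin_input_ids)  # stale value points past the prompt
assert fresh_start == max_req_input_len - 1
```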
@@ -725,12 +724,17 @@ class Scheduler:
         req.tokenizer = self.tokenizer

         # Validate prompts length
-        validate_input_length(
+        error_msg = validate_input_length(
             req,
             self.max_req_input_len,
             self.server_args.allow_auto_truncate,
         )
+        if error_msg:
+            self.waiting_queue.append(req)
+            return
+
+        # Copy more attributes
+        req.logprob_start_len = len(req.origin_input_ids) - 1

         self.waiting_queue.append(req)

     def log_prefill_stats(self, adder, can_run_list, running_bs, has_being_chunked):
@@ -1044,26 +1048,23 @@ class Scheduler:
         self.forward_ct += 1

         if self.is_generation:
-            if batch.forward_mode.is_decode_or_idle() or batch.extend_num_tokens != 0:
-                if self.spec_algorithm.is_none():
-                    model_worker_batch = batch.get_model_worker_batch()
-                    logits_output, next_token_ids = (
-                        self.tp_worker.forward_batch_generation(model_worker_batch)
-                    )
-                else:
-                    (
-                        logits_output,
-                        next_token_ids,
-                        model_worker_batch,
-                        num_accepted_tokens,
-                    ) = self.draft_worker.forward_batch_speculative_generation(batch)
-                    self.spec_num_total_accepted_tokens += (
-                        num_accepted_tokens + batch.batch_size()
-                    )
-                    self.spec_num_total_forward_ct += batch.batch_size()
-                    self.num_generated_tokens += num_accepted_tokens
-            else:
-                assert False, "batch.extend_num_tokens == 0, this is unexpected!"
+            if self.spec_algorithm.is_none():
+                model_worker_batch = batch.get_model_worker_batch()
+                logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
+                    model_worker_batch
+                )
+            else:
+                (
+                    logits_output,
+                    next_token_ids,
+                    model_worker_batch,
+                    num_accepted_tokens,
+                ) = self.draft_worker.forward_batch_speculative_generation(batch)
+                self.spec_num_total_accepted_tokens += (
+                    num_accepted_tokens + batch.batch_size()
+                )
+                self.spec_num_total_forward_ct += batch.batch_size()
+                self.num_generated_tokens += num_accepted_tokens
             batch.output_ids = next_token_ids

             ret = GenerationBatchResult(
@@ -1072,7 +1073,6 @@ class Scheduler:
                 bid=model_worker_batch.bid,
             )
         else:  # embedding or reward model
-            assert batch.extend_num_tokens != 0
             model_worker_batch = batch.get_model_worker_batch()
             embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
             ret = EmbeddingBatchResult(
...
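The `run_batch` change above drops the `extend_num_tokens` guard and its assertions, so generation batches branch only on whether speculative decoding is enabled. A compact sketch of the resulting control flow, using hypothetical stand-ins rather than the real `Scheduler`, `tp_worker`, or `draft_worker` objects:

```python
# Hypothetical stand-ins (not SGLang classes) showing the simplified branch:
# every generation batch goes through one of two paths, chosen only by whether
# speculative decoding is enabled.

class FakeBatch:
    def get_model_worker_batch(self):
        return {"bid": 1}

def run_generation(batch, spec_enabled, tp_forward, draft_forward):
    if not spec_enabled:
        model_worker_batch = batch.get_model_worker_batch()
        logits_output, next_token_ids = tp_forward(model_worker_batch)
    else:
        logits_output, next_token_ids, model_worker_batch, num_accepted = draft_forward(batch)
    return next_token_ids

# Trivial stub workers for the two branches:
tp_forward = lambda mwb: ("logits", [42])
draft_forward = lambda b: ("logits", [42, 43], b.get_model_worker_batch(), 1)

assert run_generation(FakeBatch(), False, tp_forward, draft_forward) == [42]
assert run_generation(FakeBatch(), True, tp_forward, draft_forward) == [42, 43]
```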
@@ -18,7 +18,6 @@ suites = {
         "test_eagle_infer.py",
         "test_embedding_openai_server.py",
         "test_eval_accuracy_mini.py",
-        "test_get_weights_by_name.py",
         "test_gguf.py",
         "test_input_embeddings.py",
         "test_json_constrained.py",
...
@@ -236,12 +236,5 @@ class TestEBNFConstrained(unittest.TestCase):
         )


-class TestJumpForward(TestEBNFConstrained):
-    @classmethod
-    def setUpClass(cls):
-        setup_class(cls, disable_overlap=True)
-        cls.check_jump_forward = True
-
-
 if __name__ == "__main__":
     unittest.main()
@@ -5,6 +5,7 @@ python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_logprob_with_chunked_

 import json
 import random
+import time
 import unittest
 from concurrent.futures import ThreadPoolExecutor
 from typing import Optional
@@ -317,12 +318,6 @@ class TestSRTEndpoint(unittest.TestCase):
         """Test custom logit processor with a single request."""
         self.run_custom_logit_processor(target_token_id=5)

-    def test_custom_logit_processor_batch(self):
-        """Test custom logit processor with a batch of requests."""
-        target_token_ids = list(range(32))
-        with ThreadPoolExecutor(len(target_token_ids)) as executor:
-            list(executor.map(self.run_custom_logit_processor, target_token_ids))
-
     def test_custom_logit_processor_batch_mixed(self):
         """Test a batch of requests mixed of requests with and without custom logit processor."""
         target_token_ids = list(range(32)) + [None] * 16
@@ -330,6 +325,31 @@ class TestSRTEndpoint(unittest.TestCase):
         with ThreadPoolExecutor(len(target_token_ids)) as executor:
             list(executor.map(self.run_custom_logit_processor, target_token_ids))

+    def test_cache_tokens(self):
+        for _ in range(2):
+            time.sleep(1)
+            response = requests.post(self.base_url + "/flush_cache")
+            assert response.status_code == 200
+
+        def send_and_check_cached_tokens(input_ids):
+            response = requests.post(
+                self.base_url + "/generate",
+                json={
+                    "input_ids": list(input_ids),
+                    "sampling_params": {
+                        "max_new_tokens": 1,
+                    },
+                },
+            )
+            response_json = response.json()
+            return response_json["meta_info"]["cached_tokens"]
+
+        self.assertEqual(send_and_check_cached_tokens(range(0, 100)), 0)
+        self.assertEqual(send_and_check_cached_tokens(range(0, 10000)), 100)
+        self.assertEqual(send_and_check_cached_tokens(range(0, 10000)), 9999)
+        self.assertEqual(send_and_check_cached_tokens(range(0, 1000)), 999)
+        self.assertEqual(send_and_check_cached_tokens(range(0, 11000)), 10000)
+
     def test_get_server_info(self):
         response = requests.get(self.base_url + "/get_server_info")
         response_json = response.json()
...
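The assertions in `test_cache_tokens` follow from prefix (radix) caching: `cached_tokens` reports how many prompt tokens were served from the KV cache, and a fully cached prompt still appears to recompute its final token so the model can emit logits for the next token, which is why the repeated 10000-token prompt reports 9999. A sketch of that expectation; the `expected_cached_tokens` helper is hypothetical and not part of SGLang:

```python
# Sketch of the expected values, assuming the cache matches the longest
# previously seen prefix and the last prompt token is always recomputed.

def expected_cached_tokens(prompt_len, longest_cached_prefix):
    # hypothetical helper, illustrating the assertions above
    return min(longest_cached_prefix, prompt_len - 1)

cache = 0                                             # right after /flush_cache
assert expected_cached_tokens(100, cache) == 0
cache = max(cache, 100)                               # first prompt is now cached
assert expected_cached_tokens(10000, cache) == 100
cache = max(cache, 10000)
assert expected_cached_tokens(10000, cache) == 9999   # last token recomputed
assert expected_cached_tokens(1000, cache) == 999
assert expected_cached_tokens(11000, cache) == 10000
```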