Unverified commit 5652c565 authored by Lianmin Zheng, committed by GitHub

Update CI threshold & Improve code style (#2159)

parent e3938b2f
@@ -50,7 +50,7 @@ jobs:
         timeout-minutes: 25
         run: |
           cd test/srt
-          python3 run_suite.py --suite minimal --range-begin 0 --range-end 5
+          python3 run_suite.py --suite minimal --range-begin 0 --range-end 6

   unit-test-backend-part-2:
     if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
@@ -67,7 +67,7 @@ jobs:
         timeout-minutes: 25
         run: |
           cd test/srt
-          python3 run_suite.py --suite minimal --range-begin 5 --range-end 14
+          python3 run_suite.py --suite minimal --range-begin 6 --range-end 14

   unit-test-backend-part-3:
     if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
@@ -103,6 +103,31 @@ jobs:
           cd test/srt
           python3 run_suite.py --suite minimal --range-begin 21

+  unit-test-backend-2-gpu-part-1:
+    if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
+    runs-on: 2-gpu-runner
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v3
+
+      - name: Install dependencies
+        run: |
+          bash scripts/ci_install_dependency.sh
+
+      - name: Evaluate data parallelism accuracy (DP=2)
+        timeout-minutes: 10
+        run: |
+          cd test/srt
+          python3 test_data_parallelism.py
+
+      - name: Evaluate MLA accuracy (TP=2)
+        timeout-minutes: 10
+        run: |
+          cd test/srt
+          python3 test_mla.py
+          python3 test_mla_fp8.py
+          python3 test_dp_attention.py
+
   performance-test-1-gpu-part-1:
     if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
     runs-on: 1-gpu-runner
@@ -178,23 +203,23 @@ jobs:
         run: |
           bash scripts/ci_install_dependency.sh

-      - name: Benchmark offline throughput (TP=2)
+      - name: Benchmark single latency (TP=2)
         timeout-minutes: 10
         run: |
           cd test/srt
-          python3 -m unittest test_bench_serving.TestBenchServing.test_moe_offline_throughput_default
+          python3 -m unittest test_bench_one_batch.TestBenchOneBatch.test_moe_default

-      - name: Benchmark offline throughput (w/o RadixAttention) (TP=2)
+      - name: Benchmark offline throughput (TP=2)
         timeout-minutes: 10
         run: |
           cd test/srt
-          python3 -m unittest test_bench_serving.TestBenchServing.test_moe_offline_throughput_without_radix_cache
+          python3 -m unittest test_bench_serving.TestBenchServing.test_moe_offline_throughput_default

-      - name: Benchmark single latency (TP=2)
+      - name: Benchmark offline throughput (w/o RadixAttention) (TP=2)
         timeout-minutes: 10
         run: |
           cd test/srt
-          python3 -m unittest test_bench_one_batch.TestBenchOneBatch.test_moe_default
+          python3 -m unittest test_bench_serving.TestBenchServing.test_moe_offline_throughput_without_radix_cache

   accuracy-test-1-gpu:
     if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
@@ -238,23 +263,10 @@ jobs:
           cd test/srt
           python3 test_moe_eval_accuracy_large.py

-      - name: Evaluate MLA accuracy (TP=2)
-        timeout-minutes: 10
-        run: |
-          cd test/srt
-          python3 test_mla.py
-          python3 test_mla_fp8.py
-          python3 test_dp_attention.py
-
-      - name: Evaluate data parallelism accuracy (DP=2)
-        timeout-minutes: 10
-        run: |
-          cd test/srt
-          python3 test_data_parallelism.py
-
   finish:
     needs: [
       unit-test-frontend, unit-test-backend-part-1, unit-test-backend-part-2, unit-test-backend-part-3, unit-test-backend-part-4,
+      unit-test-backend-2-gpu-part-1,
       performance-test-1-gpu-part-1, performance-test-1-gpu-part-2, performance-test-2-gpu,
       accuracy-test-1-gpu, accuracy-test-2-gpu
     ]
...
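The `--range-begin`/`--range-end` change above only moves the split point between the two backend-test shards from 5 to 6. For context, here is a minimal sketch of how such a range split can shard one suite across CI jobs; this is an illustration only, and the flag handling in the real test/srt/run_suite.py may differ.

# Illustration only (not part of this commit): how a --range-begin/--range-end
# style split can shard one test suite across CI jobs. The real
# test/srt/run_suite.py may implement this differently.
import argparse

# Placeholder test list standing in for the "minimal" suite.
SUITES = {"minimal": [f"test_{i:02d}.py" for i in range(25)]}

def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--suite", default="minimal")
    parser.add_argument("--range-begin", type=int, default=0)
    parser.add_argument("--range-end", type=int, default=None)
    args = parser.parse_args()

    # Each CI job gets a disjoint half-open slice, e.g. [0, 6) and [6, 14).
    shard = SUITES[args.suite][args.range_begin : args.range_end]
    for test_file in shard:
        print(f"would run: python3 {test_file}")

if __name__ == "__main__":
    main()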
@@ -212,6 +212,7 @@ def extend(reqs, model_runner):
         token_to_kv_pool=model_runner.token_to_kv_pool,
         tree_cache=None,
         model_config=model_runner.model_config,
+        enable_overlap=False,
     )
     batch.prepare_for_extend()
     model_worker_batch = batch.get_model_worker_batch()
...
"""
Torch-native implementation for FusedMoE. This is used for torch.compile.
It is based on https://github.com/pytorch-labs/gpt-fast/blob/32971d3129541c5bfb4f715abc33d1c5f408d204/mixtral-moe/model.py#L204
"""
from typing import Callable, Optional from typing import Callable, Optional
import torch import torch
......
@@ -437,9 +437,12 @@ class ScheduleBatch:
     token_to_kv_pool: BaseTokenToKVPool = None
     tree_cache: BasePrefixCache = None

-    # For utility
+    # Batch configs
     model_config: ModelConfig = None
     forward_mode: ForwardMode = None
+    enable_overlap: bool = False

+    # Sampling info
     sampling_info: SamplingBatchInfo = None
     next_batch_sampling_info: SamplingBatchInfo = None
@@ -488,10 +491,11 @@ class ScheduleBatch:
     def init_new(
         cls,
         reqs: List[Req],
-        req_to_token_pool,
-        token_to_kv_pool,
-        tree_cache,
-        model_config,
+        req_to_token_pool: ReqToTokenPool,
+        token_to_kv_pool: ReqToTokenPool,
+        tree_cache: BasePrefixCache,
+        model_config: ModelConfig,
+        enable_overlap: bool,
     ):
         return cls(
             reqs=reqs,
@@ -499,6 +503,7 @@ class ScheduleBatch:
             token_to_kv_pool=token_to_kv_pool,
             tree_cache=tree_cache,
             model_config=model_config,
+            enable_overlap=enable_overlap,
             return_logprob=any(req.return_logprob for req in reqs),
             has_stream=any(req.stream for req in reqs),
             has_grammar=any(req.grammar for req in reqs),
@@ -612,7 +617,7 @@ class ScheduleBatch:
         assert len(self.out_cache_loc) == self.extend_num_tokens

-    def prepare_for_extend(self, enable_overlap_schedule: bool = False):
+    def prepare_for_extend(self):
         self.forward_mode = ForwardMode.EXTEND
         bs = len(self.reqs)
@@ -706,7 +711,7 @@ class ScheduleBatch:
         self.sampling_info = SamplingBatchInfo.from_schedule_batch(
             self,
             self.model_config.vocab_size,
-            enable_overlap_schedule=enable_overlap_schedule,
+            enable_overlap_schedule=self.enable_overlap,
         )

     def mix_with_running(self, running_batch: "ScheduleBatch"):
@@ -897,7 +902,7 @@ class ScheduleBatch:
         self.seq_lens_sum = 0
         self.extend_num_tokens = 0

-    def prepare_for_decode(self, enable_overlap: bool = False):
+    def prepare_for_decode(self):
         self.forward_mode = ForwardMode.DECODE
         self.input_ids = self.output_ids
@@ -914,7 +919,7 @@ class ScheduleBatch:
         else:
             locs = self.seq_lens

-        if enable_overlap:
+        if self.enable_overlap:
             # Do not use in-place operations in the overlap mode
             self.req_to_token_pool.write(
                 (self.req_pool_indices, locs), self.out_cache_loc
...
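Taken together, the ScheduleBatch edits above replace the per-call enable_overlap flags with a field set once at construction. A sketch of the resulting call pattern, mirroring the scheduler call sites later in this diff; the import path and the surrounding pool/cache/config objects are assumptions, not shown here.

# Sketch of the updated ScheduleBatch call pattern (illustrative only; the
# module path is assumed and the pool/cache/config objects come from the
# scheduler that already owns them).
from sglang.srt.managers.schedule_batch import ScheduleBatch

def build_prefill_batch(reqs, req_to_token_pool, token_to_kv_pool,
                        tree_cache, model_config, enable_overlap):
    batch = ScheduleBatch.init_new(
        reqs,
        req_to_token_pool,
        token_to_kv_pool,
        tree_cache,
        model_config,
        enable_overlap,  # new: stored on the batch instead of threaded per call
    )
    batch.prepare_for_extend()  # no longer takes enable_overlap_schedule
    return batch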
@@ -466,6 +466,7 @@ class Scheduler:
             self.token_to_kv_pool,
             self.tree_cache,
             self.model_config,
+            self.enable_overlap,
         )
         idle_batch.prepare_for_idle()
         return idle_batch
@@ -842,14 +843,15 @@ class Scheduler:
             self.token_to_kv_pool,
             self.tree_cache,
             self.model_config,
+            self.enable_overlap,
         )
-        new_batch.prepare_for_extend(self.enable_overlap)
+        new_batch.prepare_for_extend()

         # Mixed-style chunked prefill
         if self.is_mixed_chunk and self.running_batch is not None:
             self.running_batch.filter_batch()
             if not self.running_batch.is_empty():
-                self.running_batch.prepare_for_decode(self.enable_overlap)
+                self.running_batch.prepare_for_decode()
                 new_batch.mix_with_running(self.running_batch)
                 new_batch.decoding_reqs = self.running_batch.reqs
             self.running_batch = None
@@ -900,7 +902,7 @@ class Scheduler:
             self.batch_is_full = False

         # Update batch tensors
-        batch.prepare_for_decode(self.enable_overlap)
+        batch.prepare_for_decode()
         return batch

     def run_batch(self, batch: ScheduleBatch):
@@ -1055,6 +1057,7 @@ class Scheduler:
                 continue

             if self.enable_overlap and req.finished():
+                # Free the one delayed token
                 self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
                 continue
...
@@ -23,7 +23,7 @@ import torch
 from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp

-from sglang.srt.layers.fused_moe.patch import fused_moe_forward_native
+from sglang.srt.layers.fused_moe_patch import fused_moe_forward_native
 from sglang.srt.layers.logits_processor import (
     LogitsMetadata,
     LogitsProcessor,
...
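The hunk above tracks the module move from sglang.srt.layers.fused_moe.patch to sglang.srt.layers.fused_moe_patch. A hypothetical compatibility import for out-of-tree code that has to work across both layouts (not part of this commit) could look like this:

# Hypothetical downstream shim (not part of this commit): prefer the new
# module path introduced here, fall back to the pre-rename one.
try:
    from sglang.srt.layers.fused_moe_patch import fused_moe_forward_native
except ImportError:
    from sglang.srt.layers.fused_moe.patch import fused_moe_forward_native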
@@ -20,7 +20,7 @@ class TestBenchServing(unittest.TestCase):
         )

         if is_in_ci():
-            self.assertGreater(res["output_throughput"], 2850)
+            self.assertGreater(res["output_throughput"], 3350)

     def test_offline_throughput_non_stream_small_batch_size(self):
         res = run_bench_serving(
@@ -47,7 +47,7 @@ class TestBenchServing(unittest.TestCase):
         )

         if is_in_ci():
-            self.assertGreater(res["output_throughput"], 2900)
+            self.assertGreater(res["output_throughput"], 3350)

     def test_offline_throughput_without_chunked_prefill(self):
         res = run_bench_serving(
@@ -74,7 +74,7 @@ class TestBenchServing(unittest.TestCase):
         )

         if is_in_ci():
-            self.assertGreater(res["output_throughput"], 2950)
+            self.assertGreater(res["output_throughput"], 3450)

     def test_offline_throughput_default_fp8(self):
         res = run_bench_serving(
@@ -85,7 +85,7 @@ class TestBenchServing(unittest.TestCase):
         )

         if is_in_ci():
-            self.assertGreater(res["output_throughput"], 3200)
+            self.assertGreater(res["output_throughput"], 3850)

     def test_online_latency_default(self):
         res = run_bench_serving(
@@ -109,7 +109,7 @@ class TestBenchServing(unittest.TestCase):
         )

         if is_in_ci():
-            self.assertGreater(res["output_throughput"], 1900)
+            self.assertGreater(res["output_throughput"], 2150)

     def test_moe_offline_throughput_without_radix_cache(self):
         res = run_bench_serving(
@@ -120,7 +120,7 @@ class TestBenchServing(unittest.TestCase):
         )

         if is_in_ci():
-            self.assertGreater(res["output_throughput"], 1950)
+            self.assertGreater(res["output_throughput"], 2150)

 if __name__ == "__main__":
...
@@ -6,6 +6,7 @@ python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_logprob_with_chunked_
 import json
 import unittest

+import numpy as np
 import requests

 from sglang.srt.utils import kill_child_process
@@ -132,6 +133,7 @@ class TestSRTEndpoint(unittest.TestCase):
         )

     def test_logprob_with_chunked_prefill(self):
+        """Test a long prompt that requests output logprobs will not hit OOM."""
         new_tokens = 4
         prompts = "I have a very good idea on this. " * 8000
@@ -154,6 +156,63 @@ class TestSRTEndpoint(unittest.TestCase):
         self.assertEqual(res["meta_info"]["completion_tokens"], new_tokens)
         self.assertEqual(len(res["meta_info"]["output_token_logprobs"]), new_tokens)
+    def test_logprob_match(self):
+        """Test the output logprobs are close to the input logprobs if we run a prefill again."""
+
+        def run_generate(
+            prompt, return_logprob=False, max_new_tokens=512, logprob_start_len=-1
+        ):
+            if isinstance(prompt, str):
+                prompt_kwargs = {"text": prompt}
+            else:
+                prompt_kwargs = {"input_ids": prompt}
+
+            response = requests.post(
+                self.base_url + "/generate",
+                json={
+                    **prompt_kwargs,
+                    "sampling_params": {
+                        "temperature": 1.0,
+                        "max_new_tokens": max_new_tokens,
+                        "ignore_eos": True,
+                    },
+                    "return_logprob": return_logprob,
+                    "return_text_in_logprobs": True,
+                    "logprob_start_len": logprob_start_len,
+                },
+            )
+            return response.json()
+
+        prompt = "I have a very good idea on how to"
+        gen = run_generate(prompt, return_logprob=True, logprob_start_len=0)
+        output_logprobs = np.array(
+            [x[0] for x in gen["meta_info"]["output_token_logprobs"]]
+        )
+        num_prompts_tokens = gen["meta_info"]["prompt_tokens"]
+
+        input_tokens = [x[1] for x in gen["meta_info"]["input_token_logprobs"]]
+        output_tokens = [x[1] for x in gen["meta_info"]["output_token_logprobs"]]
+        new_prompt = input_tokens + output_tokens
+        score = run_generate(
+            new_prompt, return_logprob=True, logprob_start_len=0, max_new_tokens=0
+        )
+        output_logprobs_score = np.array(
+            [
+                x[0]
+                for x in score["meta_info"]["input_token_logprobs"][num_prompts_tokens:]
+            ]
+        )
+
+        print(f"{output_logprobs[-10:]=}")
+        print(f"{output_logprobs_score[-10:]=}")
+
+        diff = np.abs(output_logprobs - output_logprobs_score)
+        max_diff = np.max(diff)
+        self.assertLess(max_diff, 0.2)
+
     def test_get_server_info(self):
         response = requests.get(self.base_url + "/get_server_info")
         response_json = response.json()
...
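Like the existing checks listed in this file's header comment, the new consistency test can be invoked on its own. A small runner sketch, assuming the usual test/srt environment with the sglang test dependencies installed and a GPU available:

# Sketch: run only the new logprob-consistency check from test/srt
# (assumes sglang's test dependencies are installed and a GPU is available).
import unittest

from test_srt_endpoint import TestSRTEndpoint

if __name__ == "__main__":
    suite = unittest.TestSuite()
    suite.addTest(TestSRTEndpoint("test_logprob_match"))
    unittest.TextTestRunner(verbosity=2).run(suite)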