Unverified Commit 4ede6770 authored by Lianmin Zheng, committed by GitHub

Fix retract for page size > 1 (#4914)

parent b26bc86b
@@ -87,53 +87,11 @@ jobs:
         run: |
           bash scripts/ci_install_dependency.sh

-      - name: Test data parallelism (DP=2)
-        timeout-minutes: 10
-        run: |
-          cd test/srt
-          python3 test_data_parallelism.py
-
-      - name: Test data parallelism attention (DP=2)
-        timeout-minutes: 10
-        run: |
-          cd test/srt
-          python3 test_dp_attention.py
-
-      - name: Test update weights from distributed
-        timeout-minutes: 10
-        run: |
-          cd test/srt
-          python3 test_update_weights_from_distributed.py
-
-      - name: Test VerlEngine
-        timeout-minutes: 10
-        run: |
-          cd test/srt
-          python3 test_verl_engine.py
-
-      - name: Test Patch Torch
-        timeout-minutes: 10
-        run: |
-          cd test/srt
-          python3 test_patch_torch.py
-
-      - name: Test expert parallelism (EP=2)
-        timeout-minutes: 10
-        run: |
-          cd test/srt
-          python3 test_moe_ep.py
-
-      - name: Test torch compile (TP=2)
-        timeout-minutes: 10
-        run: |
-          cd test/srt
-          python3 test_mla_tp.py
-
-      - name: Test lora tensor parallelism (TP=2)
-        timeout-minutes: 10
-        run: |
-          cd test/srt/models/lora
-          python3 test_lora_tp.py
+      - name: Run test
+        timeout-minutes: 10
+        run: |
+          cd test/srt
+          python3 run_suite.py --suite per-commit-2-gpu

   performance-test-1-gpu-part-1:
     if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') &&
...
@@ -169,7 +169,9 @@ class BaseGrammarBackend(ABC):
         self.cache.clear()


-def create_grammar_backend(server_args: ServerArgs, tokenizer, vocab_size):
+def create_grammar_backend(
+    server_args: ServerArgs, tokenizer, vocab_size: int
+) -> Optional[BaseGrammarBackend]:
     if server_args.grammar_backend == "outlines":
         from sglang.srt.constrained.outlines_backend import OutlinesGrammarBackend

@@ -188,6 +190,8 @@ def create_grammar_backend(server_args: ServerArgs, tokenizer, vocab_size):
             tokenizer=tokenizer,
             whitespace_pattern=server_args.constrained_json_whitespace_pattern,
         )
+    elif server_args.grammar_backend == "none":
+        return None
     else:
         raise ValueError(f"Invalid grammar backend: {server_args.grammar_backend}")
...
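The new "none" value short-circuits backend construction and returns None instead of raising. A minimal stand-alone sketch of that dispatch (the stand-in class and function names below are illustrative, not the real module):

```python
# Minimal sketch of the dispatch added above: "none" disables grammar-guided
# decoding by returning None instead of raising ValueError. Names are illustrative.
from typing import Optional


class FakeGrammarBackend:
    """Stand-in for a concrete BaseGrammarBackend subclass."""


def pick_grammar_backend(name: str) -> Optional[FakeGrammarBackend]:
    if name in ("xgrammar", "outlines", "llguidance"):
        return FakeGrammarBackend()  # the real code constructs the chosen backend here
    elif name == "none":
        return None  # callers must now handle an Optional backend
    else:
        raise ValueError(f"Invalid grammar backend: {name}")


assert pick_grammar_backend("none") is None
```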
@@ -599,6 +599,7 @@ class Req:
         self.extend_logprob_start_len = 0
         self.is_chunked = 0
         self.req_pool_idx = None
+        self.already_computed = 0

     def __repr__(self):
         return (
@@ -960,8 +961,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
                 # If req.input_embeds is already a list, append its content directly
                 input_embeds.extend(req.input_embeds)  # Use extend to avoid nesting

-            if req.is_retracted:
-                req.already_computed = 0
             req.cached_tokens += pre_len - req.already_computed
             req.already_computed = seq_len
             req.is_retracted = False
@@ -1189,7 +1188,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
                 self.req_to_token_pool.free(req.req_pool_idx)
             else:
                 # TODO: apply more fine-grained retraction
-                last_uncached_pos = len(req.prefix_indices)
+                last_uncached_pos = (
+                    (len(req.prefix_indices) + server_args.page_size - 1)
+                    // server_args.page_size
+                    * server_args.page_size
+                )
                 token_indices = self.req_to_token_pool.req_to_token[
                     req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
                 ]
...
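The last hunk is the actual retract fix: with page_size > 1, len(req.prefix_indices) is not necessarily a page boundary, so the position from which KV-cache slots are freed is rounded up to the next multiple of page_size, presumably so a partially cached page is never freed out from under the prefix cache. A worked sketch of the rounding, with illustrative values:

```python
# Worked sketch of the page-alignment arithmetic used in the retract fix above.
# Values are illustrative; the real code reads page_size from server_args.
def page_aligned_start(prefix_len: int, page_size: int) -> int:
    """Round prefix_len up to the next multiple of page_size."""
    return (prefix_len + page_size - 1) // page_size * page_size


assert page_aligned_start(13, 4) == 16  # 13 cached tokens, 4-token pages -> free from 16
assert page_aligned_start(16, 4) == 16  # already page-aligned: unchanged
assert page_aligned_start(13, 1) == 13  # page_size == 1 reproduces the old behavior
```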
@@ -33,7 +33,7 @@ class SchedulerMetricsCollector:
     def __init__(self, labels: Dict[str, str]) -> None:
         # We need to import prometheus_client after setting the env variable `PROMETHEUS_MULTIPROC_DIR`
-        from prometheus_client import Gauge
+        from prometheus_client import Gauge, Histogram

         self.labels = labels
         self.last_log_time = time.time()
@@ -139,10 +139,10 @@ class TokenizerMetricsCollector:
             labelnames=labels.keys(),
             buckets=[
                 0.1,
-                0.3,
-                0.5,
-                0.7,
-                0.9,
+                0.2,
+                0.4,
+                0.6,
+                0.8,
                 1,
                 2,
                 4,
@@ -153,36 +153,9 @@ class TokenizerMetricsCollector:
                 40,
                 60,
                 80,
-                120,
-                160,
-            ],
-        )
-
-        self.histogram_time_per_output_token = Histogram(
-            name="sglang:time_per_output_token_seconds",
-            documentation="Histogram of time per output token in seconds.",
-            labelnames=labels.keys(),
-            buckets=[
-                0.002,
-                0.005,
-                0.010,
-                0.020,
-                0.030,
-                0.040,
-                0.050,
-                0.060,
-                0.070,
-                0.080,
-                0.090,
-                0.100,
-                0.150,
-                0.200,
-                0.300,
-                0.400,
-                0.600,
-                0.800,
-                1.000,
-                2.000,
+                100,
+                200,
+                400,
             ],
         )
@@ -202,17 +175,18 @@ class TokenizerMetricsCollector:
                 0.030,
                 0.035,
                 0.040,
-                0.050,
-                0.075,
+                0.060,
+                0.080,
                 0.100,
-                0.150,
                 0.200,
-                0.300,
                 0.400,
-                0.500,
-                0.750,
+                0.600,
+                0.800,
                 1.000,
                 2.000,
+                4.000,
+                6.000,
+                8.000,
             ],
         )
@@ -224,23 +198,22 @@ class TokenizerMetricsCollector:
                 0.1,
                 0.2,
                 0.4,
+                0.6,
                 0.8,
                 1,
                 2,
-                5,
+                4,
+                6,
+                8,
                 10,
                 20,
                 40,
                 60,
                 80,
                 100,
-                150,
                 200,
-                250,
-                300,
-                350,
-                500,
-                1000,
+                400,
+                800,
             ],
         )
@@ -256,13 +229,10 @@ class TokenizerMetricsCollector:
     ):
         self.prompt_tokens_total.labels(**self.labels).inc(prompt_tokens)
         self.generation_tokens_total.labels(**self.labels).inc(generation_tokens)
-        self.cached_tokens_total.labels(**self.labels).inc(cached_tokens)
+        if cached_tokens > 0:
+            self.cached_tokens_total.labels(**self.labels).inc(cached_tokens)
         self.num_requests_total.labels(**self.labels).inc(1)
         self._log_histogram(self.histogram_e2e_request_latency, e2e_latency)
-        if generation_tokens >= 1:
-            self.histogram_time_per_output_token.labels(**self.labels).observe(
-                e2e_latency / generation_tokens
-            )

     def observe_time_to_first_token(self, value: float):
         self.histogram_time_to_first_token.labels(**self.labels).observe(value)
...
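The bucket edits above swap in coarser, roughly log-spaced boundaries and drop the per-output-token histogram. For readers unfamiliar with the pattern, here is a self-contained sketch of an explicit-bucket Histogram with prometheus_client; the metric name, label, and bucket boundaries below are illustrative, not the ones in sglang:

```python
# Self-contained sketch of the prometheus_client pattern used above; the metric
# name, label, and bucket boundaries here are illustrative.
from prometheus_client import CollectorRegistry, Histogram

registry = CollectorRegistry()
latency = Histogram(
    name="demo_request_latency_seconds",
    documentation="Histogram of request latency in seconds.",
    labelnames=["model_name"],
    buckets=[0.1, 0.2, 0.4, 0.6, 0.8, 1, 2, 4, 6, 8, 10],
    registry=registry,
)

# Prometheus buckets are cumulative: a 0.35 s observation increments every
# bucket whose upper bound is >= 0.35 (0.4, 0.6, ..., +Inf).
latency.labels(model_name="demo").observe(0.35)
```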
@@ -128,7 +128,7 @@ class ServerArgs:
     # Kernel backend
     attention_backend: Optional[str] = None
     sampling_backend: Optional[str] = None
-    grammar_backend: Optional[str] = "xgrammar"
+    grammar_backend: Optional[str] = None

     # Speculative decoding
     speculative_algorithm: Optional[str] = None
@@ -193,6 +193,13 @@ class ServerArgs:
     disaggregation_bootstrap_port: int = 8998

     def __post_init__(self):
+        # Expert parallelism
+        if self.enable_ep_moe:
+            self.ep_size = self.tp_size
+            logger.info(
+                f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
+            )
+
         # Set missing default values
         if self.tokenizer_path is None:
             self.tokenizer_path = self.model_path
@@ -274,12 +281,9 @@ class ServerArgs:
             )
             self.disable_cuda_graph = True

-        # Expert parallelism
-        if self.enable_ep_moe:
-            self.ep_size = self.tp_size
-            logger.info(
-                f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
-            )
+        # Choose grammar backend
+        if self.grammar_backend is None:
+            self.grammar_backend = "xgrammar"

         # Data parallelism attention
         if self.enable_dp_attention:
@@ -813,7 +817,7 @@ class ServerArgs:
         parser.add_argument(
             "--grammar-backend",
             type=str,
-            choices=["xgrammar", "outlines", "llguidance"],
+            choices=["xgrammar", "outlines", "llguidance", "none"],
             default=ServerArgs.grammar_backend,
             help="Choose the backend for grammar-guided decoding.",
         )
...
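The dataclass default moves from "xgrammar" to None, and __post_init__ now fills in "xgrammar" only when the user made no explicit choice, which is what lets --grammar-backend none flow through to create_grammar_backend. A small stand-alone sketch of that resolution pattern (MiniServerArgs is illustrative, not the real sglang ServerArgs):

```python
# Stand-alone mirror of the default-resolution pattern above; MiniServerArgs is
# illustrative and not the real sglang ServerArgs.
from dataclasses import dataclass
from typing import Optional


@dataclass
class MiniServerArgs:
    grammar_backend: Optional[str] = None

    def __post_init__(self):
        # Only fill in the default when the user made no explicit choice.
        if self.grammar_backend is None:
            self.grammar_backend = "xgrammar"


assert MiniServerArgs().grammar_backend == "xgrammar"
assert MiniServerArgs(grammar_backend="none").grammar_backend == "none"
```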
@@ -1012,9 +1012,6 @@ def run_logprob_check(self: unittest.TestCase, arg: Tuple):
 class CustomTestCase(unittest.TestCase):
-    pass
-
-    """
     def _callTestMethod(self, method):
         max_retry = int(
             os.environ.get("SGLANG_TEST_MAX_RETRY", "2" if is_in_ci() else "0")
@@ -1023,4 +1020,3 @@ class CustomTestCase(unittest.TestCase):
             lambda: super(CustomTestCase, self)._callTestMethod(method),
             max_retry=max_retry,
         )
-    """
@@ -33,6 +33,9 @@ CI_LORA_MODELS = [
         ],
         max_loras_per_batch=1,
     ),
+]
+
+ALL_OTHER_LORA_MODELS = [
     LoRAModelCase(
         base="meta-llama/Llama-3.1-8B-Instruct",
         adaptors=[
@@ -43,9 +46,6 @@ CI_LORA_MODELS = [
         ],
         max_loras_per_batch=1,
     ),
-]
-
-ALL_OTHER_LORA_MODELS = [
     LoRAModelCase(
         base="meta-llama/Llama-2-7b-hf",
         adaptors=[LoRAAdaptor(name="winddude/wizardLM-LlaMA-LoRA-7B")],
...
@@ -16,7 +16,7 @@ suites = {
         TestFile("models/lora/test_lora.py", 76),
         TestFile("models/lora/test_lora_backend.py", 420),
         TestFile("models/lora/test_multi_lora_backend.py", 144),
-        TestFile("models/test_embedding_models.py", 119),
+        TestFile("models/test_embedding_models.py", 35),
         TestFile("models/test_generation_models.py", 103),
         TestFile("models/test_grok_models.py", 60),
         TestFile("models/test_qwen_models.py", 82),
@@ -38,7 +38,7 @@ suites = {
         TestFile("test_metrics.py", 32),
         TestFile("test_mla.py", 92),
         TestFile("test_mla_deepseek_v3.py", 221),
-        TestFile("test_mla_int8_deepseek_v3.py", 421),
+        TestFile("test_mla_int8_deepseek_v3.py", 522),
         TestFile("test_mla_flashinfer.py", 395),
         TestFile("test_mla_fp8.py", 93),
         TestFile("test_no_chunked_prefill.py", 126),
@@ -59,7 +59,7 @@ suites = {
         TestFile("test_srt_endpoint.py", 94),
         TestFile("test_torch_compile.py", 76),
         TestFile("test_torch_compile_moe.py", 85),
-        TestFile("test_torch_native_attention_backend.py", 149),
+        TestFile("test_torch_native_attention_backend.py", 123),
         TestFile("test_torchao.py", 70),
         TestFile("test_triton_attention_kernels.py", 4),
         TestFile("test_triton_attention_backend.py", 134),
@@ -76,6 +76,16 @@ suites = {
         TestFile("test_hicache.py", 60),
         TestFile("test_hicache_mla.py", 90),
     ],
+    "per-commit-2-gpu": [
+        TestFile("test_data_parallelism.py", 90),
+        TestFile("test_dp_attention.py", 90),
+        TestFile("test_update_weights_from_distributed.py", 100),
+        TestFile("test_verl_engine.py", 100),
+        TestFile("test_patch_torch.py", 30),
+        TestFile("test_moe_ep.py", 220),
+        TestFile("test_mla_tp.py", 420),
+        TestFile("test_lora_tp.py", 300),
+    ],
     "nightly": [
         TestFile("test_nightly_gsm8k_eval.py"),
     ],
...
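This new suite backs the single "Run test" CI step from the workflow hunk at the top (cd test/srt && python3 run_suite.py --suite per-commit-2-gpu). A small sketch, assuming the second TestFile argument is an estimated per-file runtime in seconds, of the suite's rough wall-clock budget:

```python
# Rough budget for the consolidated 2-GPU job, assuming the second TestFile
# argument is an estimated runtime in seconds (values copied from the diff above).
per_commit_2_gpu = {
    "test_data_parallelism.py": 90,
    "test_dp_attention.py": 90,
    "test_update_weights_from_distributed.py": 100,
    "test_verl_engine.py": 100,
    "test_patch_torch.py": 30,
    "test_moe_ep.py": 220,
    "test_mla_tp.py": 420,
    "test_lora_tp.py": 300,
}
print(sum(per_commit_2_gpu.values()))  # 1350 seconds, i.e. about 22.5 minutes
```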
@@ -60,3 +60,7 @@ class TestDPAttentionDP2TP2(CustomTestCase):
         metrics = run_eval(args)
         print(f"{metrics=}")
         self.assertGreater(metrics["score"], 0.8)
+
+
+if __name__ == "__main__":
+    unittest.main()
@@ -63,7 +63,6 @@ class TestEnableMetrics(CustomTestCase):
             "sglang:cached_tokens_total",
             "sglang:num_requests_total",
             "sglang:time_to_first_token_seconds",
-            "sglang:time_per_output_token_seconds",
             "sglang:inter_token_latency_seconds",
             "sglang:e2e_request_latency_seconds",
         ]
...