Unverified commit 7d671e4a authored by Lianmin Zheng, committed by GitHub

Enable overlap by default (#2067)

parent 699384cb
......@@ -220,7 +220,8 @@ def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
return reqs
def _extend(reqs, model_runner):
@torch.no_grad
def extend(reqs, model_runner):
batch = ScheduleBatch.init_new(
reqs=reqs,
req_to_token_pool=model_runner.req_to_token_pool,
......@@ -236,15 +237,8 @@ def _extend(reqs, model_runner):
return next_token_ids, logits_output.next_token_logits, batch
def extend(reqs, model_runner):
# Disable inference mode for now when torch TP is applied. We can remove
# this workaround once DTensor adds support for inference mode.
use_inf_mode = not model_runner.torch_tp_applied
with torch.inference_mode(use_inf_mode):
return _extend(reqs, model_runner)
def _decode(input_token_ids, batch, model_runner):
@torch.no_grad
def decode(input_token_ids, batch, model_runner):
batch.output_ids = input_token_ids
batch.prepare_for_decode()
model_worker_batch = batch.get_model_worker_batch()
......@@ -254,14 +248,6 @@ def _decode(input_token_ids, batch, model_runner):
return next_token_ids, logits_output.next_token_logits
def decode(input_token_ids, batch, model_runner):
# Disable inference mode for now when torch TP is applied. We can remove
# this workaround once DTensor adds support for inference mode.
use_inf_mode = not model_runner.torch_tp_applied
with torch.inference_mode(use_inf_mode):
return _decode(input_token_ids, batch, model_runner)
def correctness_test(
server_args,
port_args,
......
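Note: the hunk above drops the `_extend`/`_decode` wrappers that toggled `torch.inference_mode` based on `torch_tp_applied` and instead decorates `extend`/`decode` with `torch.no_grad`. A minimal sketch of the decorator form is below, using a toy function rather than anything from the patch; it only illustrates that gradient tracking is off for the whole call.

```python
# Minimal sketch (toy function, not from the patch): torch.no_grad used as a
# decorator disables gradient tracking for the entire call, which is what the
# rewritten extend()/decode() helpers rely on instead of torch.inference_mode.
import torch


@torch.no_grad()
def toy_forward(x: torch.Tensor) -> torch.Tensor:
    # No autograd graph is recorded inside this function.
    return x * 2 + 1


x = torch.ones(3, requires_grad=True)
y = toy_forward(x)
print(y.requires_grad)  # False: the output is detached from the autograd graph
```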
......@@ -87,9 +87,12 @@ class OutlinesGrammar(BaseGrammarObject):
return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device)
def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
tokens = torch.tensor(
self.guide.get_next_instruction(self.state).tokens, dtype=torch.int64
).to(vocab_mask.device, non_blocking=True)
vocab_mask = vocab_mask[idx]
vocab_mask.fill_(1)
vocab_mask[self.guide.get_next_instruction(self.state).tokens] = 0
vocab_mask.scatter_(0, tokens, torch.zeros_like(tokens, dtype=torch.bool))
@staticmethod
def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor):
......
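Note: `fill_vocab_mask` now builds the per-request mask as a boolean tensor: the row is filled with True (blocked) and the allowed token ids returned by the guide are scattered back to False. A minimal sketch with toy sizes follows; the final `masked_fill` step is an assumption about how `apply_vocab_mask` consumes the mask, since its body is not shown in this diff.

```python
# Minimal sketch (toy vocab, hypothetical apply step): build a boolean
# "blocked" mask, scatter False at the allowed token ids, then assume the
# mask is applied by pushing blocked logits to -inf.
import torch

vocab_size = 8
allowed = torch.tensor([1, 4, 6], dtype=torch.int64)  # tokens the guide permits

row = torch.ones(vocab_size, dtype=torch.bool)  # start with everything blocked
row.scatter_(0, allowed, torch.zeros_like(allowed, dtype=torch.bool))

logits = torch.randn(vocab_size)
masked = logits.masked_fill(row, float("-inf"))  # assumed apply step
print(masked)
```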
......@@ -899,10 +899,7 @@ class ScheduleBatch:
self.input_ids = self.output_ids
self.output_ids = None
if self.sampling_info.penalizer_orchestrator:
self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
self.input_ids
)
self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(self.input_ids)
# Alloc mem
bs = len(self.reqs)
......
......@@ -30,7 +30,7 @@ import torch
import zmq
from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.io_struct import (
......@@ -102,7 +102,7 @@ class Scheduler:
self.disable_jump_forward = server_args.disable_jump_forward
self.lora_paths = server_args.lora_paths
self.max_loras_per_batch = server_args.max_loras_per_batch
self.enable_overlap = server_args.enable_overlap_schedule
self.enable_overlap = not server_args.disable_overlap_schedule
self.skip_tokenizer_init = server_args.skip_tokenizer_init
self.enable_metrics = server_args.enable_metrics
......@@ -159,6 +159,23 @@ class Scheduler:
trust_remote_code=server_args.trust_remote_code,
)
# Check whether overlap can be enabled
if not self.is_generation:
self.enable_overlap = False
logger.info("Overlap scheduler is disabled for embedding models.")
if (
server_args.attention_backend == "triton"
or server_args.enable_double_sparsity
or (
self.model_config.attention_arch == AttentionArch.MLA
and not self.server_args.disable_mla
)
):
self.enable_overlap = False
logger.info(
"Overlap scheduler is disabled if using triton attention backend."
)
# Launch a tensor parallel worker
if self.enable_overlap:
TpWorkerClass = TpModelWorkerClient
......@@ -903,6 +920,7 @@ class Scheduler:
self.process_batch_result_prefill(batch, result)
elif batch.forward_mode.is_dummy_first():
batch.next_batch_sampling_info.update_regex_vocab_mask()
torch.cuda.current_stream().synchronize()
batch.next_batch_sampling_info.sampling_info_done.set()
def process_batch_result_prefill(self, batch: ScheduleBatch, result):
......@@ -958,6 +976,7 @@ class Scheduler:
if batch.next_batch_sampling_info:
batch.next_batch_sampling_info.update_regex_vocab_mask()
torch.cuda.current_stream().synchronize()
batch.next_batch_sampling_info.sampling_info_done.set()
else: # embedding or reward model
......@@ -1031,6 +1050,7 @@ class Scheduler:
if batch.next_batch_sampling_info:
batch.next_batch_sampling_info.update_regex_vocab_mask()
torch.cuda.current_stream().synchronize()
batch.next_batch_sampling_info.sampling_info_done.set()
self.stream_output(batch.reqs)
......
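Note: each of the hunks above inserts `torch.cuda.current_stream().synchronize()` between `update_regex_vocab_mask()` and `sampling_info_done.set()`. The mask tensors are moved to the GPU with non-blocking copies on the scheduler's stream, so the stream has to be drained before the event tells the worker thread the data is ready. A minimal sketch of that producer/consumer ordering, with hypothetical names, is below.

```python
# Minimal sketch (hypothetical names): why the stream sync comes before
# event.set(). The producer enqueues an async H2D copy, drains its CUDA
# stream, and only then signals the consumer thread that the tensor is
# safe to read.
import threading
import torch

sampling_info_done = threading.Event()
shared = {}


def producer(mask_cpu: torch.Tensor):
    shared["mask"] = mask_cpu.to("cuda", non_blocking=True)  # async copy
    torch.cuda.current_stream().synchronize()  # make the copy visible
    sampling_info_done.set()                   # now the consumer may proceed


def consumer():
    sampling_info_done.wait()
    # Safe: the producer synchronized its stream before setting the event.
    print(shared["mask"].sum().item())


if torch.cuda.is_available():
    t = threading.Thread(target=consumer)
    t.start()
    producer(torch.ones(1024, dtype=torch.bool).pin_memory())
    t.join()
```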
......@@ -157,14 +157,19 @@ class TpModelWorkerClient:
def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
# A CUDA stream sync here to avoid CUDA illegal memory access errors.
_ = model_worker_batch.seq_lens[0].item()
torch.cuda.current_stream().synchronize()
# Push a new batch to the queue
model_worker_batch.sampling_info = dataclasses.replace(
model_worker_batch.sampling_info,
# Create a new copy of sampling_info because it will be updated in-place by the scheduler for the next batch.
sampling_info = model_worker_batch.sampling_info
sampling_info.update_penalties()
model_worker_batch.sampling_info = self.cur_sampling_info = dataclasses.replace(
sampling_info,
sampling_info_done=threading.Event(),
scaling_penalties=sampling_info.scaling_penalties,
linear_penalties=sampling_info.linear_penalties,
)
self.cur_sampling_info = model_worker_batch.sampling_info
# Push a new batch to the queue
self.input_queue.put((model_worker_batch, self.future_token_ids_ct))
# Allocate output future objects
......
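Note: `forward_batch_generation` now snapshots `sampling_info` with `dataclasses.replace` and attaches a fresh `threading.Event` per batch, because the scheduler keeps mutating the shared object while preparing the next batch. A minimal sketch of that copy pattern, with a toy dataclass rather than the real `SamplingBatchInfo`, is below.

```python
# Minimal sketch (toy dataclass, not the real SamplingBatchInfo): each queued
# batch gets its own dataclass instance and its own completion event, while
# fields that are not replaced still reference the original objects (the copy
# is shallow), which is why the patch passes the penalty tensors explicitly.
import dataclasses
import threading
from typing import Optional


@dataclasses.dataclass
class ToySamplingInfo:
    temperatures: list
    scaling_penalties: Optional[object] = None
    sampling_info_done: Optional[threading.Event] = None


shared = ToySamplingInfo(temperatures=[0.7, 1.0])
snapshot = dataclasses.replace(
    shared,
    sampling_info_done=threading.Event(),        # per-batch event
    scaling_penalties=shared.scaling_penalties,  # explicit carry-over
)

shared.temperatures.append(0.3)  # the scheduler keeps editing the shared copy
print(snapshot is shared)        # False: a new dataclass instance
print(snapshot.temperatures)     # [0.7, 1.0, 0.3]: the list itself is shared
```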
......@@ -116,7 +116,7 @@ class ModelRunner:
)
if self.is_multimodal:
logger.warning(
logger.info(
"Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
)
server_args.chunked_prefill_size = None
......@@ -636,13 +636,11 @@ class ModelRunner:
self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
) -> torch.Tensor:
sampling_info = forward_batch.sampling_info
if sampling_info.sampling_info_done:
# Overlap mode: the function update_regex_vocab_mask was executed
# in process_batch_result of the last batch.
if sampling_info.grammars:
sampling_info.sampling_info_done.wait()
sampling_info.update_penalties()
else:
# Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
sampling_info.update_regex_vocab_mask()
......
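Note: the rewritten sampling path branches on whether grammars are present. With the overlap scheduler, the vocab mask was already built by `process_batch_result` of the previous batch, so the sampler only waits on `sampling_info_done`; otherwise the CPU-heavy `update_regex_vocab_mask()` runs inline and overlaps with the asynchronous GPU forward pass. A toy illustration of that overlap is sketched below.

```python
# Minimal sketch (toy workload): the overlap idea behind the normal-mode
# branch. CUDA kernel launches return immediately, so CPU-heavy prep (such as
# building vocab masks) can run while the GPU forward pass executes; the GPU
# cost is only paid when the result is synchronized on.
import time
import torch

if torch.cuda.is_available():
    x = torch.randn(4096, 4096, device="cuda")

    tic = time.time()
    y = x @ x  # enqueued on the GPU, returns immediately
    cpu_side = sum(i * i for i in range(100_000))  # CPU work overlaps with the matmul
    torch.cuda.synchronize()  # wait for the GPU work here
    print(f"wall time with overlap: {time.time() - tic:.4f}s (cpu result {cpu_side})")
```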
......@@ -132,9 +132,6 @@ class SamplingBatchInfo:
return len(self.temperatures)
def update_penalties(self):
if not self.penalizer_orchestrator:
return
self.scaling_penalties = None
self.linear_penalties = None
......@@ -176,8 +173,7 @@ class SamplingBatchInfo:
grammar.fill_vocab_mask(self.vocab_mask, i)
def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
if self.penalizer_orchestrator:
self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
for item in [
"temperatures",
......@@ -216,8 +212,7 @@ class SamplingBatchInfo:
return None
def merge_batch(self, other: "SamplingBatchInfo"):
if self.penalizer_orchestrator:
self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
for item in [
"temperatures",
......
......@@ -123,7 +123,7 @@ class ServerArgs:
disable_disk_cache: bool = False
disable_custom_all_reduce: bool = False
disable_mla: bool = False
enable_overlap_schedule: bool = False
disable_overlap_schedule: bool = False
enable_mixed_chunk: bool = False
enable_dp_attention: bool = False
enable_torch_compile: bool = False
......@@ -172,9 +172,7 @@ class ServerArgs:
if gpu_mem < 25000:
self.chunked_prefill_size //= 4 # make it 2048
self.cuda_graph_max_bs = 4
logger.warning(
"Automatically adjust --chunked-prefill-size for small GPUs."
)
logger.info("Automatically adjust --chunked-prefill-size for small GPUs.")
if not is_flashinfer_available():
self.attention_backend = "triton"
......@@ -192,15 +190,22 @@ class ServerArgs:
self.chunked_prefill_size = self.chunked_prefill_size // 2
self.cuda_graph_max_bs = min(self.cuda_graph_max_bs, 96)
self.schedule_conservativeness = self.schedule_conservativeness * 0.3
self.enable_overlap_schedule = False
logger.warning(
self.disable_overlap_schedule = True
logger.info(
f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
f"The CUDA graph max batch size is adjusted to {self.cuda_graph_max_bs}. "
f"The schedule conservativeness is adjusted to {self.schedule_conservativeness}. "
"Data parallel size is adjusted to be the same as tensor parallel size."
"Data parallel size is adjusted to be the same as tensor parallel size. "
"Overlap schedule is disabled."
)
if self.enable_mixed_chunk:
logger.info(
"Overlap schedule is disabled because mixed-style chunked prefill is enabled."
)
self.disable_overlap_schedule = True
if self.enable_overlap_schedule:
if not self.disable_overlap_schedule:
self.disable_jump_forward = True
@staticmethod
......@@ -624,9 +629,9 @@ class ServerArgs:
help="Disable the NaN detection for better performance.",
)
parser.add_argument(
"--enable-overlap-schedule",
"--disable-overlap-schedule",
action="store_true",
help="Overlap the CPU scheduler with GPU model worker. Experimental feature.",
help="Disable the overlap scheduler, which overlaps the CPU scheduler with GPU model worker.",
)
parser.add_argument(
"--enable-mixed-chunk",
......@@ -692,6 +697,11 @@ class ServerArgs:
)
# Deprecated arguments
parser.add_argument(
"--enable-overlap-schedule",
action=DeprecatedAction,
help="'--enable-overlap-schedule' is deprecated. It is enabled by default now. Please drop this argument.",
)
parser.add_argument(
"--disable-flashinfer",
action=DeprecatedAction,
......
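Note: the server flag flips from opt-in `--enable-overlap-schedule` to opt-out `--disable-overlap-schedule`, and the old flag is kept only as a `DeprecatedAction` that rejects use with a migration hint. A standalone sketch of that argparse pattern follows; the `DeprecatedAction` body here is a guess, since its implementation is not part of this diff.

```python
# Minimal sketch (standalone argparse; the DeprecatedAction body is a guess,
# not sglang's actual implementation): an opt-out flag plus a retired opt-in
# flag that fails loudly when someone still passes it.
import argparse


class DeprecatedAction(argparse.Action):
    def __init__(self, option_strings, dest, nargs=0, **kwargs):
        super().__init__(option_strings, dest, nargs=nargs, **kwargs)

    def __call__(self, parser, namespace, values, option_string=None):
        raise ValueError(self.help)


parser = argparse.ArgumentParser()
parser.add_argument(
    "--disable-overlap-schedule",
    action="store_true",
    help="Disable the overlap scheduler (it is on by default).",
)
parser.add_argument(
    "--enable-overlap-schedule",
    action=DeprecatedAction,
    help="'--enable-overlap-schedule' is deprecated; it is enabled by default now.",
)

args = parser.parse_args([])
enable_overlap = not args.disable_overlap_schedule  # mirrors Scheduler.enable_overlap
print(enable_overlap)  # True by default
```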
......@@ -670,7 +670,7 @@ def run_and_check_memory_leak(
workload_func,
disable_radix_cache,
enable_mixed_chunk,
enable_overlap,
disable_overlap,
chunked_prefill_size,
):
other_args = ["--chunked-prefill-size", str(chunked_prefill_size)]
......@@ -678,8 +678,8 @@ def run_and_check_memory_leak(
other_args += ["--disable-radix-cache"]
if enable_mixed_chunk:
other_args += ["--enable-mixed-chunk"]
if enable_overlap:
other_args += ["--enable-overlap-schedule"]
if disable_overlap:
other_args += ["--disable-overlap-schedule"]
model = DEFAULT_MODEL_NAME_FOR_TEST
port = random.randint(4000, 5000)
......@@ -731,7 +731,7 @@ def run_and_check_memory_leak(
def run_mmlu_test(
disable_radix_cache=False,
enable_mixed_chunk=False,
enable_overlap=False,
disable_overlap=False,
chunked_prefill_size=32,
):
def workload_func(base_url, model):
......@@ -754,7 +754,7 @@ def run_mmlu_test(
workload_func,
disable_radix_cache,
enable_mixed_chunk,
enable_overlap,
disable_overlap,
chunked_prefill_size,
)
......
......@@ -17,8 +17,8 @@ suites = {
"test_json_constrained.py",
"test_large_max_new_tokens.py",
"test_metrics.py",
"test_non_overlap_scheduler.py",
"test_openai_server.py",
"test_overlap_schedule.py",
"test_pytorch_sampling_backend.py",
"test_radix_attention.py",
"test_retract_decode.py",
......
......@@ -97,8 +97,8 @@ class TestBenchServing(unittest.TestCase):
if is_in_ci():
self.assertLess(res["median_e2e_latency_ms"], 12000)
self.assertLess(res["median_ttft_ms"], 80)
self.assertLess(res["median_itl_ms"], 11)
self.assertLess(res["median_ttft_ms"], 86)
self.assertLess(res["median_itl_ms"], 10)
def test_moe_offline_throughput_default(self):
res = run_bench_serving(
......
......@@ -78,10 +78,11 @@ class TestJSONConstrained(unittest.TestCase):
self.assertIsInstance(js_obj["population"], int)
# Make sure jump forward is triggered
self.assertGreater(
ret["meta_info"]["completion_tokens"],
ret["meta_info"]["completion_tokens_wo_jump_forward"],
)
# NOTE: This is skipped because the overlap scheduler does not support jump forward.
# self.assertGreater(
# ret["meta_info"]["completion_tokens"],
# ret["meta_info"]["completion_tokens_wo_jump_forward"],
# )
def test_json_generate(self):
self.run_decode(json_schema=self.json_schema)
......
......@@ -59,7 +59,7 @@ class TestMoEEvalAccuracyLarge(unittest.TestCase):
)
metrics = run_eval(args)
self.assertGreater(metrics["score"], 0.41)
self.assertGreater(metrics["score"], 0.40)
def test_mgsm_en(self):
args = SimpleNamespace(
......
......@@ -12,22 +12,22 @@ from sglang.test.test_utils import run_mmlu_test
class TestOverlapSchedule(unittest.TestCase):
def test_no_radix_attention_chunked_prefill(self):
run_mmlu_test(
disable_radix_cache=True, chunked_prefill_size=32, enable_overlap=True
disable_radix_cache=True, chunked_prefill_size=32, disable_overlap=True
)
def test_no_radix_attention_no_chunked_prefill(self):
run_mmlu_test(
disable_radix_cache=True, chunked_prefill_size=-1, enable_overlap=True
disable_radix_cache=True, chunked_prefill_size=-1, disable_overlap=True
)
def test_radix_attention_chunked_prefill(self):
run_mmlu_test(
disable_radix_cache=False, chunked_prefill_size=32, enable_overlap=True
disable_radix_cache=False, chunked_prefill_size=32, disable_overlap=True
)
def test_radix_attention_no_chunked_prefill(self):
run_mmlu_test(
disable_radix_cache=False, chunked_prefill_size=-1, enable_overlap=True
disable_radix_cache=False, chunked_prefill_size=-1, disable_overlap=True
)
......
......@@ -107,7 +107,7 @@ class TestRadixCacheLPM(TestRadixCacheFCFS):
)
class TestRadixCacheOverlapLPM(TestRadixCacheFCFS):
class TestRadixCacheNonOverlapLPM(TestRadixCacheFCFS):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
......@@ -117,7 +117,7 @@ class TestRadixCacheOverlapLPM(TestRadixCacheFCFS):
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--enable-overlap-schedule",
"--disable-overlap-schedule",
"--chunked-prefill-size",
"128",
"--max-total-tokens",
......
import time
import unittest
from types import SimpleNamespace
......@@ -56,14 +57,14 @@ class TestTorchCompile(unittest.TestCase):
return response.json()
def test_throughput(self):
import time
# Warmup
res = self.run_decode(16)
max_tokens = 256
tic = time.time()
res = self.run_decode(max_tokens)
tok = time.time()
print(res["text"])
print(f"{res=}")
throughput = max_tokens / (tok - tic)
print(f"Throughput: {throughput} tokens/s")
self.assertGreaterEqual(throughput, 152)
......
import time
import unittest
from types import SimpleNamespace
......@@ -56,10 +57,10 @@ class TestTorchCompile(unittest.TestCase):
return response.json()
def test_throughput(self):
import time
# Warmup
res = self.run_decode(16)
max_tokens = 256
tic = time.time()
res = self.run_decode(max_tokens)
tok = time.time()
......