Unverified commit 0e7409ad, authored by Lianmin Zheng and committed by GitHub

Fix the overlap for xgrammar (#2377)

parent 3cde5eb6
@@ -106,4 +106,4 @@ def import_new_model_classes():
 ModelRegistry.models.update(import_new_model_classes())
 launch_server(server_args)
 ```
\ No newline at end of file
@@ -42,6 +42,7 @@ class OutlinesGrammar(BaseGrammarObject):
         self.guide = guide
         self.jump_forward_map = jump_forward_map
         self.state = 0
+        self.finished = False

     def accept_token(self, token: int):
         self.state = self.guide.get_next_state(self.state, token)
@@ -84,6 +85,10 @@ class OutlinesGrammar(BaseGrammarObject):
     ) -> torch.Tensor:
         return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device)

+    @staticmethod
+    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
+        return vocab_mask
+
     def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
         tokens = torch.tensor(
             self.guide.get_next_instruction(self.state).tokens, dtype=torch.int64
...
@@ -45,6 +45,7 @@ class XGrammarGrammar(BaseGrammarObject):
         self.matcher = matcher
         self.vocab_size = vocab_size
         self.ctx = ctx
+        self.finished = False

     def accept_token(self, token: int):
         assert self.matcher.accept_token(token)
@@ -85,12 +86,11 @@ class XGrammarGrammar(BaseGrammarObject):
         self.matcher.fill_next_token_bitmask(vocab_mask, idx)

     @staticmethod
-    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
-        if vocab_mask.device.type != logits.device.type:
-            # vocab_mask must then be on the same device as logits
-            # when applying the token bitmask, so we check and move if needed
-            vocab_mask = vocab_mask.to(logits.device)
+    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
+        return vocab_mask.to(device, non_blocking=True)

+    @staticmethod
+    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
         apply_token_bitmask_inplace(logits, vocab_mask)

     def copy(self):
...
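The hunks above split the old `apply_vocab_mask` (which moved the mask onto the logits device on the fly) into an explicit `move_vocab_mask` step: the Outlines backend allocates its mask on the target device and returns it unchanged, while the xgrammar backend keeps its bitmask off-device (typically on CPU) and copies it over once per batch with `non_blocking=True`. A minimal sketch of that contract with plain torch tensors; the helper functions below are illustrative stand-ins, not the real backend classes:

```python
import torch


def outlines_move(vocab_mask: torch.Tensor, device) -> torch.Tensor:
    # Outlines allocates the mask on `device` up front, so moving is a no-op.
    return vocab_mask


def xgrammar_move(vocab_mask: torch.Tensor, device) -> torch.Tensor:
    # xgrammar fills its bitmask off-device; copy it to the logits device once
    # per batch, non-blocking so the copy can overlap with other stream work.
    return vocab_mask.to(device, non_blocking=True)


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    gpu_mask = torch.zeros(4, 8, dtype=torch.bool, device=device)
    cpu_mask = torch.zeros(4, 8, dtype=torch.bool)
    print(outlines_move(gpu_mask, device).device)  # stays where it was allocated
    print(xgrammar_move(cpu_mask, device).device)  # now on the logits device
```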
@@ -114,9 +114,6 @@ class Scheduler:
         self.skip_tokenizer_init = server_args.skip_tokenizer_init
         self.enable_metrics = server_args.enable_metrics

-        # Session info
-        self.sessions = {}
-
         # Init inter-process communication
         context = zmq.Context(2)
@@ -259,6 +256,10 @@ class Scheduler:
         self.num_generated_tokens = 0
         self.last_decode_stats_tic = time.time()
         self.stream_interval = server_args.stream_interval
+        self.current_stream = torch.get_device_module(self.device).current_stream()
+
+        # Session info
+        self.sessions = {}

         # Init chunked prefill
         self.chunked_prefill_size = server_args.chunked_prefill_size
@@ -356,6 +357,7 @@ class Scheduler:
         )

     def watchdog_thread(self):
+        """A watch dog thread that will try to kill the server itself if one batch takes too long."""
         self.watchdog_last_forward_ct = 0
         self.watchdog_last_time = time.time()
@@ -433,61 +435,6 @@ class Scheduler:
         self.last_batch = batch

-    def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
-        # Check if other DP workers have running batches
-        if local_batch is None:
-            num_tokens = 0
-        elif local_batch.forward_mode.is_decode():
-            num_tokens = local_batch.batch_size()
-        else:
-            num_tokens = local_batch.extend_num_tokens
-
-        local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64)
-        global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64)
-        torch.distributed.all_gather_into_tensor(
-            global_num_tokens,
-            local_num_tokens,
-            group=self.tp_cpu_group,
-        )
-
-        if local_batch is None and global_num_tokens.max().item() > 0:
-            local_batch = self.get_idle_batch()
-
-        if local_batch is not None:
-            local_batch.global_num_tokens = global_num_tokens.tolist()
-
-            # Check forward mode for cuda graph
-            if not self.server_args.disable_cuda_graph:
-                forward_mode_state = torch.tensor(
-                    (
-                        1
-                        if local_batch.forward_mode.is_decode()
-                        or local_batch.forward_mode.is_idle()
-                        else 0
-                    ),
-                    dtype=torch.int32,
-                )
-                torch.distributed.all_reduce(
-                    forward_mode_state,
-                    op=torch.distributed.ReduceOp.MIN,
-                    group=self.tp_cpu_group,
-                )
-                local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1
-
-        return local_batch
-
-    def get_idle_batch(self):
-        idle_batch = ScheduleBatch.init_new(
-            [],
-            self.req_to_token_pool,
-            self.token_to_kv_pool,
-            self.tree_cache,
-            self.model_config,
-            self.enable_overlap,
-        )
-        idle_batch.prepare_for_idle()
-        return idle_batch
-
     def recv_requests(self):
         if self.tp_rank == 0 or self.server_args.enable_dp_attention:
             recv_reqs = []
@@ -993,7 +940,7 @@ class Scheduler:
                 self.process_batch_result_prefill(batch, result)
         elif batch.forward_mode.is_dummy_first():
             batch.next_batch_sampling_info.update_regex_vocab_mask()
-            torch.get_device_module(self.device).current_stream().synchronize()
+            self.current_stream.synchronize()
             batch.next_batch_sampling_info.sampling_info_done.set()

     def process_batch_result_prefill(self, batch: ScheduleBatch, result):
@@ -1049,13 +996,14 @@ class Scheduler:
                     if req.grammar is not None:
                         req.grammar.accept_token(next_token_id)
+                        req.grammar.finished = req.finished()
                 else:
                     # being chunked reqs' prefill is not finished
                     req.is_being_chunked -= 1

             if batch.next_batch_sampling_info:
                 batch.next_batch_sampling_info.update_regex_vocab_mask()
-                torch.get_device_module(self.device).current_stream().synchronize()
+                self.current_stream.synchronize()
                 batch.next_batch_sampling_info.sampling_info_done.set()

         else:  # embedding or reward model
@@ -1127,10 +1075,11 @@ class Scheduler:
             if req.grammar is not None:
                 req.grammar.accept_token(next_token_id)
+                req.grammar.finished = req.finished()

         if batch.next_batch_sampling_info:
             batch.next_batch_sampling_info.update_regex_vocab_mask()
-            torch.get_device_module(self.device).current_stream().synchronize()
+            self.current_stream.synchronize()
             batch.next_batch_sampling_info.sampling_info_done.set()

         self.stream_output(batch.reqs)
@@ -1328,6 +1277,61 @@ class Scheduler:
             )
         )

+    def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
+        # Check if other DP workers have running batches
+        if local_batch is None:
+            num_tokens = 0
+        elif local_batch.forward_mode.is_decode():
+            num_tokens = local_batch.batch_size()
+        else:
+            num_tokens = local_batch.extend_num_tokens
+
+        local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64)
+        global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64)
+        torch.distributed.all_gather_into_tensor(
+            global_num_tokens,
+            local_num_tokens,
+            group=self.tp_cpu_group,
+        )
+
+        if local_batch is None and global_num_tokens.max().item() > 0:
+            local_batch = self.get_idle_batch()
+
+        if local_batch is not None:
+            local_batch.global_num_tokens = global_num_tokens.tolist()
+
+            # Check forward mode for cuda graph
+            if not self.server_args.disable_cuda_graph:
+                forward_mode_state = torch.tensor(
+                    (
+                        1
+                        if local_batch.forward_mode.is_decode()
+                        or local_batch.forward_mode.is_idle()
+                        else 0
+                    ),
+                    dtype=torch.int32,
+                )
+                torch.distributed.all_reduce(
+                    forward_mode_state,
+                    op=torch.distributed.ReduceOp.MIN,
+                    group=self.tp_cpu_group,
+                )
+                local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1
+
+        return local_batch
+
+    def get_idle_batch(self):
+        idle_batch = ScheduleBatch.init_new(
+            [],
+            self.req_to_token_pool,
+            self.token_to_kv_pool,
+            self.tree_cache,
+            self.model_config,
+            self.enable_overlap,
+        )
+        idle_batch.prepare_for_idle()
+        return idle_batch
+
     def move_ready_grammar_requests(self):
         """Move requests whose grammar objects are ready from grammar_queue to waiting_queue."""
         num_ready_reqs = 0
@@ -1469,10 +1473,6 @@ def run_scheduler_process(
     dp_rank: Optional[int],
     pipe_writer,
 ):
-    # set cpu affinity to this gpu process
-    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
-        set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
-
     # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
     if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
         dp_rank = int(os.environ["SGLANG_DP_RANK"])
@@ -1482,6 +1482,10 @@ def run_scheduler_process(
     else:
         configure_logger(server_args, prefix=f" DP{dp_rank} TP{tp_rank}")

+    # set cpu affinity to this gpu process
+    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
+        set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
+
     suppress_other_loggers()

     parent_process = psutil.Process().parent()
...
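`prepare_dp_attn_batch` and `get_idle_batch` are moved later in the file rather than rewritten. The logic they carry: every DP rank all-gathers its scheduled token count so idle ranks can run a dummy batch, and a 0/1 forward-mode flag is MIN-reduced so CUDA graphs are only used when all ranks are decoding or idle. Below is a self-contained sketch of those two collectives on a single-process gloo group; the group setup is only there to make the sketch runnable, whereas the scheduler uses its existing `tp_cpu_group`:

```python
from typing import List

import torch
import torch.distributed as dist


def gather_token_counts(local_num_tokens: int, group=None) -> List[int]:
    """All-gather every rank's scheduled token count, as prepare_dp_attn_batch does."""
    world_size = dist.get_world_size(group)
    local = torch.tensor([local_num_tokens], dtype=torch.int64)
    global_counts = torch.empty(world_size, dtype=torch.int64)
    dist.all_gather_into_tensor(global_counts, local, group=group)
    return global_counts.tolist()


def all_ranks_decode_or_idle(is_decode_or_idle: bool, group=None) -> bool:
    """MIN-reduce a 0/1 flag: CUDA graphs are only safe if every rank reports 1."""
    flag = torch.tensor(1 if is_decode_or_idle else 0, dtype=torch.int32)
    dist.all_reduce(flag, op=dist.ReduceOp.MIN, group=group)
    return flag.item() == 1


if __name__ == "__main__":
    # A single-process gloo group, only so the sketch runs on its own.
    dist.init_process_group(
        backend="gloo", init_method="tcp://127.0.0.1:29511", rank=0, world_size=1
    )
    print(gather_token_counts(5))          # -> [5]
    print(all_ranks_decode_or_idle(True))  # -> True
    dist.destroy_process_group()
```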
@@ -80,6 +80,7 @@ class TpModelWorkerClient:
         )
         self.forward_thread.start()
         self.parent_process = psutil.Process().parent()
+        self.scheduler_stream = torch.get_device_module(self.device).current_stream()

     def get_worker_info(self):
         return self.worker.get_worker_info()
@@ -191,7 +192,7 @@ class TpModelWorkerClient:
         )

         # A cuda stream sync here to avoid the cuda illegal memory access error.
-        torch.get_device_module(self.device).current_stream().synchronize()
+        self.scheduler_stream.synchronize()

         # Push a new batch to the queue
         self.input_queue.put((model_worker_batch, self.future_token_ids_ct))
...
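Both the Scheduler (`self.current_stream`) and TpModelWorkerClient (`self.scheduler_stream`) now resolve the device's current stream once in `__init__` and reuse the handle, instead of calling `torch.get_device_module(self.device).current_stream()` at every synchronization point. A small sketch of that pattern; the `Worker` class is illustrative, and `torch.get_device_module` requires a recent PyTorch:

```python
import torch


class Worker:
    """Caches the device's current stream once, then reuses it per batch."""

    def __init__(self, device: str):
        self.device = device
        # Resolve the stream a single time instead of on every batch.
        self.current_stream = torch.get_device_module(device).current_stream()

    def finish_batch(self) -> None:
        # Same effect as torch.get_device_module(device).current_stream().synchronize(),
        # without re-resolving the device module and stream each call.
        self.current_stream.synchronize()


if __name__ == "__main__":
    if torch.cuda.is_available():
        w = Worker("cuda")
        a = torch.randn(1024, 1024, device="cuda")
        b = torch.randn(1024, 1024, device="cuda")
        c = a @ b            # queued on the current CUDA stream
        w.finish_batch()     # block until the matmul has finished
        print(c.sum().item())
```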
@@ -158,22 +158,23 @@ class SamplingBatchInfo:
             return

         # find a grammar from the list
-        grammar = next(grammar for grammar in self.grammars if grammar)
+        first_grammar = next(grammar for grammar in self.grammars if grammar)

         # maybe we can reuse the existing mask?
-        self.vocab_mask = grammar.allocate_vocab_mask(
+        self.vocab_mask = first_grammar.allocate_vocab_mask(
             vocab_size=self.vocab_size,
             batch_size=len(self.temperatures),
             device=self.device,
         )
-        self.apply_mask = type(grammar).apply_vocab_mask  # force to use static method
+        self.apply_mask = first_grammar.apply_vocab_mask  # force to use static method

+        # Apply the mask
         for i, grammar in enumerate(self.grammars):
-            if grammar is not None:
-                try:
-                    grammar.fill_vocab_mask(self.vocab_mask, i)
-                except RuntimeError:
-                    continue
+            if grammar and not grammar.finished:
+                grammar.fill_vocab_mask(self.vocab_mask, i)
+
+        # Move the mask to the device if needed
+        self.vocab_mask = first_grammar.move_vocab_mask(self.vocab_mask, self.device)

     def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
         self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
...
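Putting the pieces together, `update_regex_vocab_mask` now allocates one mask for the whole batch via the first grammar, fills one row per unfinished grammar, moves the mask to the sampling device once, and lets the sampler apply it. A toy end-to-end version with a stub grammar in place of OutlinesGrammar/XGrammarGrammar; `StubGrammar` and `update_vocab_mask` are illustrative names, not SGLang APIs:

```python
from typing import List, Optional

import torch


class StubGrammar:
    """Illustrative stand-in that allows only a fixed set of token IDs."""

    def __init__(self, allowed: List[int], vocab_size: int):
        self.allowed = allowed
        self.vocab_size = vocab_size
        self.finished = False

    @staticmethod
    def allocate_vocab_mask(vocab_size: int, batch_size: int, device) -> torch.Tensor:
        # Filled off-device here (like the xgrammar backend); `device` is
        # deliberately ignored because the mask is relocated later.
        return torch.zeros(batch_size, vocab_size, dtype=torch.bool)

    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
        vocab_mask[idx] = True                 # block everything ...
        vocab_mask[idx, self.allowed] = False  # ... except the allowed tokens

    @staticmethod
    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
        return vocab_mask.to(device, non_blocking=True)

    @staticmethod
    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
        logits.masked_fill_(vocab_mask, float("-inf"))


def update_vocab_mask(grammars: List[Optional[StubGrammar]], logits: torch.Tensor):
    first_grammar = next((g for g in grammars if g), None)
    if first_grammar is None:
        return
    vocab_mask = first_grammar.allocate_vocab_mask(
        vocab_size=logits.shape[1], batch_size=logits.shape[0], device=logits.device
    )
    for i, grammar in enumerate(grammars):
        if grammar and not grammar.finished:   # finished requests are skipped
            grammar.fill_vocab_mask(vocab_mask, i)
    vocab_mask = first_grammar.move_vocab_mask(vocab_mask, logits.device)
    first_grammar.apply_vocab_mask(logits, vocab_mask)


if __name__ == "__main__":
    logits = torch.randn(2, 6)
    update_vocab_mask([StubGrammar([0, 1], 6), None], logits)
    print(logits)  # row 0 keeps only tokens 0 and 1; row 1 is untouched
```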
""" """
-python3 -m unittest test_json_constrained.TestJSONConstrained.test_json_generate
+python3 -m unittest test_json_constrained.TestJSONConstrainedOutlinesBackend.test_json_generate
+python3 -m unittest test_json_constrained.TestJSONConstrainedXGrammarBackend.test_json_generate
 """

 import json
@@ -11,38 +12,50 @@ import requests
 from sglang.srt.utils import kill_process_tree
 from sglang.test.test_utils import (
-    DEFAULT_MODEL_NAME_FOR_TEST,
+    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
     DEFAULT_URL_FOR_TEST,
     popen_launch_server,
 )

+def setup_class(cls, backend: str, disable_overlap: bool):
+    cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
+    cls.base_url = DEFAULT_URL_FOR_TEST
+    cls.json_schema = json.dumps(
+        {
+            "type": "object",
+            "properties": {
+                "name": {"type": "string", "pattern": "^[\\w]+$"},
+                "population": {"type": "integer"},
+            },
+            "required": ["name", "population"],
+        }
+    )
+
+    other_args = [
+        "--max-running-requests",
+        "10",
+        "--grammar-backend",
+        backend,
+    ]
+
+    if disable_overlap:
+        other_args += ["--disable-overlap-schedule"]
+
+    cls.process = popen_launch_server(
+        cls.model,
+        cls.base_url,
+        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+        other_args=other_args,
+    )
+
 class TestJSONConstrainedOutlinesBackend(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
-        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
-        cls.base_url = DEFAULT_URL_FOR_TEST
-        cls.json_schema = json.dumps(
-            {
-                "type": "object",
-                "properties": {
-                    "name": {"type": "string", "pattern": "^[\\w]+$"},
-                    "population": {"type": "integer"},
-                },
-                "required": ["name", "population"],
-            }
-        )
-        cls.process = popen_launch_server(
-            cls.model,
-            cls.base_url,
-            timeout=300,
-            other_args=[
-                "--max-running-requests",
-                "10",
-                "--grammar-backend",
-                "outlines",
-            ],
-        )
+        setup_class(cls, backend="outlines", disable_overlap=False)
+        cls.check_jump_forward = False

     @classmethod
     def tearDownClass(cls):
@@ -83,11 +96,13 @@ class TestJSONConstrainedOutlinesBackend(unittest.TestCase):
         self.assertIsInstance(js_obj["population"], int)

         # Make sure jump forward is triggered
-        # NOTE: This is skipped because overlap scheduler does not support jump forward
-        # self.assertGreater(
-        #     ret["meta_info"]["completion_tokens"],
-        #     ret["meta_info"]["completion_tokens_wo_jump_forward"],
-        # )
+        # NOTE: The overlap scheduler does not support jump forward so we only do this test
+        # when --disable-overlap-schedule is set.
+        if self.check_jump_forward:
+            self.assertGreater(
+                ret["meta_info"]["completion_tokens"],
+                ret["meta_info"]["completion_tokens_wo_jump_forward"],
+            )

     def test_json_generate(self):
         self.run_decode(json_schema=self.json_schema)
@@ -126,32 +141,18 @@ class TestJSONConstrainedOutlinesBackend(unittest.TestCase):
         list(executor.map(self.run_decode, json_schemas))

+class TestJumpForwardOutlinesBackend(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        setup_class(cls, backend="outlines", disable_overlap=True)
+        cls.check_jump_forward = True
+
 class TestJSONConstrainedXGrammarBackend(TestJSONConstrainedOutlinesBackend):
     @classmethod
     def setUpClass(cls):
-        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
-        cls.base_url = DEFAULT_URL_FOR_TEST
-        cls.json_schema = json.dumps(
-            {
-                "type": "object",
-                "properties": {
-                    "name": {"type": "string"},
-                    "population": {"type": "integer"},
-                },
-                "required": ["name", "population"],
-            }
-        )
-        cls.process = popen_launch_server(
-            cls.model,
-            cls.base_url,
-            timeout=300,
-            other_args=[
-                "--max-running-requests",
-                "10",
-                "--grammar-backend",
-                "xgrammar",
-            ],
-        )
+        setup_class(cls, backend="xgrammar", disable_overlap=False)
+        cls.check_jump_forward = False

 if __name__ == "__main__":
...
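For context on what `run_decode` exercises, a request against a server launched the way `setup_class` does might look like the sketch below. The payload shape (`text`, `sampling_params`, `json_schema`) follows SGLang's `/generate` endpoint as these tests use it; the prompt text, port, and parameter values are assumptions, so adjust them to your setup:

```python
import json

import requests

BASE_URL = "http://127.0.0.1:30000"  # assumes a server launched as in setup_class

json_schema = json.dumps(
    {
        "type": "object",
        "properties": {
            "name": {"type": "string", "pattern": "^[\\w]+$"},
            "population": {"type": "integer"},
        },
        "required": ["name", "population"],
    }
)

response = requests.post(
    f"{BASE_URL}/generate",
    json={
        "text": "Here is the information of the capital of France in the JSON format.\n",
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": 64,
            "json_schema": json_schema,
        },
    },
)
ret = response.json()
print(ret["text"])              # output is constrained by json_schema
print(json.loads(ret["text"]))  # should parse as the requested object
```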