Unverified Commit 9cf0a5ba authored by gryffindor-rr, committed by GitHub

Add skip_tokenizer_init args. (#959)


Co-authored-by: lzhang <zhanglei@modelbest.cn>
parent b16e856f
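
The change lets a client skip server-side tokenization entirely: requests carry pre-tokenized input_ids, stopping is driven by explicit stop_token_ids, and the detokenizer hands back raw token IDs instead of text. Below is a minimal client sketch of that flow, not taken from this commit; it assumes a server already launched with --skip-tokenizer-init and listening on http://127.0.0.1:30000, and the prompt and stop token IDs are placeholders.

# Minimal client sketch (assumptions: server launched with --skip-tokenizer-init,
# reachable at http://127.0.0.1:30000; token IDs below are placeholders).
import requests

response = requests.post(
    "http://127.0.0.1:30000/generate",
    json={
        # Pre-tokenized prompt; no "text" field, since the server has no tokenizer.
        "input_ids": [10, 11, 12],
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": 8,
            # Without a tokenizer there is no eos_token_id, so stopping relies on
            # explicit stop token IDs.
            "stop_token_ids": [2],
        },
    },
)
# The detokenizer is bypassed, so the result carries raw token IDs instead of text.
out = response.json()
print(out["token_ids"], out["meta_info"])

This mirrors the warmup request and the new test added by this commit, which exercise the same input_ids / stop_token_ids path.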
@@ -20,10 +20,20 @@ from sglang.srt.constrained.base_tool_cache import BaseToolCache
 class FSMCache(BaseToolCache):
-    def __init__(self, tokenizer_path, tokenizer_args_dict, enable=True):
+    def __init__(
+        self,
+        tokenizer_path,
+        tokenizer_args_dict,
+        enable=True,
+        skip_tokenizer_init=False,
+    ):
         super().__init__(enable=enable)
-        if tokenizer_path.endswith(".json") or tokenizer_path.endswith(".model"):
+        if (
+            skip_tokenizer_init
+            or tokenizer_path.endswith(".json")
+            or tokenizer_path.endswith(".model")
+        ):
             # Do not support TiktokenTokenizer or SentencePieceTokenizer
             return
@@ -59,11 +59,14 @@ class DetokenizerManager:
         self.send_to_tokenizer = context.socket(zmq.PUSH)
         self.send_to_tokenizer.connect(f"tcp://127.0.0.1:{port_args.tokenizer_port}")

-        self.tokenizer = get_tokenizer(
-            server_args.tokenizer_path,
-            tokenizer_mode=server_args.tokenizer_mode,
-            trust_remote_code=server_args.trust_remote_code,
-        )
+        if server_args.skip_tokenizer_init:
+            self.tokenizer = None
+        else:
+            self.tokenizer = get_tokenizer(
+                server_args.tokenizer_path,
+                tokenizer_mode=server_args.tokenizer_mode,
+                trust_remote_code=server_args.trust_remote_code,
+            )

         self.decode_status = {}
@@ -85,6 +88,11 @@ class DetokenizerManager:
             assert isinstance(recv_obj, BatchTokenIDOut)
             bs = len(recv_obj.rids)

+            if self.tokenizer is None:
+                # Send BatchTokenIDOut if no tokenizer init'ed.
+                self.send_to_tokenizer.send_pyobj(recv_obj)
+                continue
+
             # Initialize decode status
             read_ids, surr_ids = [], []
             for i in range(bs):
@@ -195,6 +195,8 @@ class Req:
         return all_ids[self.surr_offset :], self.read_offset - self.surr_offset

     def get_next_inc_detokenization(self):
+        if self.tokenizer is None:
+            return False, ""
         read_ids, read_offset = self.init_incremental_detokenize()
         surr_ids = read_ids[:read_offset]
@@ -225,16 +227,11 @@ class Req:
             return

         last_token_id = self.output_ids[-1]
-        if (
-            last_token_id == self.tokenizer.eos_token_id
-            and not self.sampling_params.ignore_eos
-        ):
-            self.finished_reason = FINISH_MATCHED_TOKEN(
-                matched=self.tokenizer.eos_token_id
-            )
-            return
-
-        if last_token_id in self.sampling_params.stop_token_ids:
+        if self.tokenizer is None:
+            matched_eos = last_token_id in self.sampling_params.stop_token_ids
+        else:
+            matched_eos = last_token_id == self.tokenizer.eos_token_id
+        if matched_eos and not self.sampling_params.ignore_eos:
             self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
             return
@@ -95,25 +95,28 @@ class TokenizerManager:
         else:
             self.context_len = get_context_length(self.hf_config)

-        if is_multimodal_model(self.model_path):
-            self.processor = get_processor(
-                server_args.tokenizer_path,
-                tokenizer_mode=server_args.tokenizer_mode,
-                trust_remote_code=server_args.trust_remote_code,
-            )
-            self.tokenizer = self.processor.tokenizer
-            os.environ["TOKENIZERS_PARALLELISM"] = "false"
-            self.executor = concurrent.futures.ProcessPoolExecutor(
-                initializer=init_global_processor,
-                mp_context=mp.get_context("fork"),
-                initargs=(server_args,),
-            )
+        if server_args.skip_tokenizer_init:
+            self.tokenizer = self.processor = None
         else:
-            self.tokenizer = get_tokenizer(
-                server_args.tokenizer_path,
-                tokenizer_mode=server_args.tokenizer_mode,
-                trust_remote_code=server_args.trust_remote_code,
-            )
+            if is_multimodal_model(self.model_path):
+                self.processor = get_processor(
+                    server_args.tokenizer_path,
+                    tokenizer_mode=server_args.tokenizer_mode,
+                    trust_remote_code=server_args.trust_remote_code,
+                )
+                self.tokenizer = self.processor.tokenizer
+                os.environ["TOKENIZERS_PARALLELISM"] = "false"
+                self.executor = concurrent.futures.ProcessPoolExecutor(
+                    initializer=init_global_processor,
+                    mp_context=mp.get_context("fork"),
+                    initargs=(server_args,),
+                )
+            else:
+                self.tokenizer = get_tokenizer(
+                    server_args.tokenizer_path,
+                    tokenizer_mode=server_args.tokenizer_mode,
+                    trust_remote_code=server_args.trust_remote_code,
+                )

         self.to_create_loop = True
         self.rid_to_state: Dict[str, ReqState] = {}
@@ -171,6 +174,7 @@ class TokenizerManager:
             rid = obj.rid if not_use_index else obj.rid[index]
             input_text = obj.text if not_use_index else obj.text[index]
             if obj.input_ids is None:
+                assert self.tokenizer is not None
                 input_ids = self.tokenizer.encode(input_text)
             else:
                 input_ids = obj.input_ids if not_use_index else obj.input_ids[index]
@@ -207,7 +211,20 @@ class TokenizerManager:
                 else:
                     input_text = obj.text
                     rid = obj.rid[0]
-                input_ids = self.tokenizer.encode(input_text)
+                if self.tokenizer is not None:
+                    input_ids = self.tokenizer.encode(input_text)
+                else:
+                    assert obj.input_ids is not None
+                    input_ids = obj.input_ids
+                    if isinstance(obj.input_ids, list) and isinstance(
+                        obj.input_ids[0], list
+                    ):
+                        # when obj["input_ids"] is List[List[int]]
+                        input_ids = obj.input_ids[index]
+                        rid = obj.rid[index]
+                    else:
+                        input_ids = obj.input_ids
+                        rid = obj.rid[0]
             else:
                 input_text = None
                 if isinstance(obj.input_ids, list) and isinstance(
@@ -420,7 +437,7 @@ class TokenizerManager:
         # Log requests
         if self.server_args.log_requests and state.finished:
             if obj.text is None:
-                in_obj = {"text": self.tokenizer.decode(obj.input_ids)}
+                in_obj = {"input_ids": obj.input_ids}
             else:
                 in_obj = {"text": obj.text}
             logger.info(f"in={in_obj}, out={out}")
@@ -488,11 +505,12 @@ class TokenizerManager:
     async def handle_loop(self):
         while True:
-            recv_obj: Union[BatchStrOut, BatchEmbeddingOut] = (
+            recv_obj: Union[BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut] = (
                 await self.recv_from_detokenizer.recv_pyobj()
             )
-            assert isinstance(recv_obj, (BatchStrOut, BatchEmbeddingOut))
+            assert isinstance(
+                recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)
+            ), f"Unexpected obj received: {type(recv_obj)}"

             for i, rid in enumerate(recv_obj.rids):
                 state = self.rid_to_state.get(rid, None)
                 if state is None:
@@ -504,6 +522,15 @@ class TokenizerManager:
                     "text": recv_obj.output_strs[i],
                     "meta_info": recv_obj.meta_info[i],
                 }
+            elif isinstance(recv_obj, BatchTokenIDOut):
+                read_start = 0 if i == 0 else recv_obj.read_offsets[i - 1]
+                out_dict = {
+                    "token_ids": recv_obj.decode_ids[
+                        read_start : recv_obj.read_offsets[i]
+                    ],
+                    "meta_info": recv_obj.meta_info[i],
+                }
             else:
                 assert isinstance(recv_obj, BatchEmbeddingOut)
                 out_dict = {
@@ -549,6 +576,7 @@ class TokenizerManager:
         if not decode_to_text:
             return [(logprob, token_id, None) for logprob, token_id in token_logprobs]

+        assert self.tokenizer is not None
         token_ids = [tid for _, tid in token_logprobs]
         token_texts = self.tokenizer.batch_decode(token_ids)
         return [
@@ -100,20 +100,22 @@ class ModelTpServer:
             nccl_port=nccl_port,
             server_args=server_args,
         )
-
-        if is_multimodal_model(server_args.model_path):
-            self.processor = get_processor(
-                server_args.tokenizer_path,
-                tokenizer_mode=server_args.tokenizer_mode,
-                trust_remote_code=server_args.trust_remote_code,
-            )
-            self.tokenizer = self.processor.tokenizer
+        if server_args.skip_tokenizer_init:
+            self.tokenizer = self.processor = None
         else:
-            self.tokenizer = get_tokenizer(
-                server_args.tokenizer_path,
-                tokenizer_mode=server_args.tokenizer_mode,
-                trust_remote_code=server_args.trust_remote_code,
-            )
+            if is_multimodal_model(server_args.model_path):
+                self.processor = get_processor(
+                    server_args.tokenizer_path,
+                    tokenizer_mode=server_args.tokenizer_mode,
+                    trust_remote_code=server_args.trust_remote_code,
+                )
+                self.tokenizer = self.processor.tokenizer
+            else:
+                self.tokenizer = get_tokenizer(
+                    server_args.tokenizer_path,
+                    tokenizer_mode=server_args.tokenizer_mode,
+                    trust_remote_code=server_args.trust_remote_code,
+                )
         self.max_total_num_tokens = self.model_runner.max_total_num_tokens
         self.max_prefill_tokens = (
             16384
@@ -182,13 +184,15 @@ class ModelTpServer:
         self.last_stats_tic = time.time()

         # Init the FSM cache for constrained generation
-        self.regex_fsm_cache = FSMCache(
-            server_args.tokenizer_path,
-            {
-                "tokenizer_mode": server_args.tokenizer_mode,
-                "trust_remote_code": server_args.trust_remote_code,
-            },
-        )
+        if not server_args.skip_tokenizer_init:
+            self.regex_fsm_cache = FSMCache(
+                server_args.tokenizer_path,
+                {
+                    "tokenizer_mode": server_args.tokenizer_mode,
+                    "trust_remote_code": server_args.trust_remote_code,
+                },
+                skip_tokenizer_init=server_args.skip_tokenizer_init,
+            )
         self.jump_forward_cache = JumpForwardCache()

         # Init new token estimation
@@ -466,7 +470,11 @@ class ModelTpServer:
             next_token_ids = next_token_ids.tolist()
         else:
-            next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
+            if self.tokenizer is None:
+                for i, req in enumerate(batch.reqs):
+                    next_token_ids.extend(req.sampling_params.stop_token_ids)
+            else:
+                next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)

         # Check finish conditions
         pt = 0
@@ -111,13 +111,19 @@ class SamplingParams:
         # Process stop strings
         if self.stop_strs is None:
             self.stop_strs = []
-            self.stop_str_max_len = 0
+            if self.stop_token_ids is None:
+                self.stop_str_max_len = 0
+            else:
+                self.stop_str_max_len = 1
         else:
             if isinstance(self.stop_strs, str):
                 self.stop_strs = [self.stop_strs]

             stop_str_max_len = 0
             for stop_str in self.stop_strs:
-                stop_str_ids = tokenizer.encode(stop_str, add_special_tokens=False)
-                stop_str_max_len = max(stop_str_max_len, len(stop_str_ids))
+                if tokenizer is not None:
+                    stop_str_ids = tokenizer.encode(stop_str, add_special_tokens=False)
+                    stop_str_max_len = max(stop_str_max_len, len(stop_str_ids))
+                else:
+                    stop_str_max_len = max(stop_str_max_len, len(stop_str))
             self.stop_str_max_len = stop_str_max_len
@@ -420,17 +420,22 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
     # Send a warmup request
     request_name = "/generate" if model_info["is_generation"] else "/encode"
     max_new_tokens = 8 if model_info["is_generation"] else 1
+    json_data = {
+        "sampling_params": {
+            "temperature": 0,
+            "max_new_tokens": max_new_tokens,
+        },
+    }
+    if server_args.skip_tokenizer_init:
+        json_data["input_ids"] = [10, 11, 12]
+    else:
+        json_data["text"] = "The capital city of France is"

     try:
         for _ in range(server_args.dp_size):
             res = requests.post(
                 url + request_name,
-                json={
-                    "text": "The capital city of France is",
-                    "sampling_params": {
-                        "temperature": 0,
-                        "max_new_tokens": max_new_tokens,
-                    },
-                },
+                json=json_data,
                 headers=headers,
                 timeout=600,
             )
@@ -27,6 +27,7 @@ class ServerArgs:
     model_path: str
     tokenizer_path: Optional[str] = None
     tokenizer_mode: str = "auto"
+    skip_tokenizer_init: bool = False
     load_format: str = "auto"
     dtype: str = "auto"
     trust_remote_code: bool = True
@@ -151,6 +152,11 @@ class ServerArgs:
             "tokenizer if available, and 'slow' will "
             "always use the slow tokenizer.",
         )
+        parser.add_argument(
+            "--skip-tokenizer-init",
+            action="store_true",
+            help="If set, skip initializing the tokenizer and pass input_ids directly in generate requests.",
+        )
         parser.add_argument(
             "--load-format",
             type=str,
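
The flag is a plain boolean on ServerArgs, so it can also be set programmatically. A small sketch follows, under the assumption that ServerArgs is importable from sglang.srt.server_args and that the usual launch_server entry point is used; the model path is a placeholder, not part of this diff.

# Hedged sketch: setting the new field directly is equivalent to passing
# --skip-tokenizer-init on the command line (module path and launcher assumed).
from sglang.srt.server_args import ServerArgs

args = ServerArgs(model_path="your/model-path", skip_tokenizer_init=True)
assert args.skip_tokenizer_init  # managers then leave tokenizer/processor as None

# Assumed CLI equivalent:
#   python -m sglang.launch_server --model-path your/model-path --skip-tokenizer-init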
@@ -197,6 +197,8 @@ def allocate_init_ports(
 def get_int_token_logit_bias(tokenizer, vocab_size):
     """Get the logit bias for integer-only tokens."""
     # a bug when model's vocab size > tokenizer.vocab_size
+    if tokenizer is None:
+        return [-1e5] * vocab_size
     vocab_size = tokenizer.vocab_size
     logit_bias = np.zeros(vocab_size, dtype=np.float32)
     for t_id in range(vocab_size):
import json
import os
import sys
import unittest

import requests

from sglang.srt.utils import kill_child_process
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST, popen_launch_server

# os.environ["CUDA_VISIBLE_DEVICES"] = "1"


class TestSRTEndpoint(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = "http://127.0.0.1:8157"
        cls.process = popen_launch_server(
            cls.model, cls.base_url, timeout=300, other_args=["--skip-tokenizer-init"]
        )

    @classmethod
    def tearDownClass(cls):
        kill_child_process(cls.process.pid)

    def run_decode(
        self, return_logprob=False, top_logprobs_num=0, return_text=False, n=1
    ):
        response = requests.post(
            self.base_url + "/generate",
            json={
                "input_ids": [
                    119689,
                    50650,
                    18291,
                    30061,
                    5316,
                    26951,
                    119690,
                ],  # The capital of France is
                "sampling_params": {
                    "temperature": 0 if n == 1 else 0.5,
                    "max_new_tokens": 32,
                    "n": n,
                    "stop_token_ids": [119690],
                },
                "stream": False,
                "return_logprob": return_logprob,
                "top_logprobs_num": top_logprobs_num,
                "return_text_in_logprobs": return_text,
                "logprob_start_len": 0,
            },
        )
        print(json.dumps(response.json()))
        print("=" * 100)

    def test_simple_decode(self):
        self.run_decode()

    def test_parallel_sample(self):
        self.run_decode(n=3)

    def test_logprob(self):
        for top_logprobs_num in [0, 3]:
            for return_text in [False, False]:
                self.run_decode(
                    return_logprob=True,
                    top_logprobs_num=top_logprobs_num,
                    return_text=return_text,
                )


if __name__ == "__main__":
    unittest.main(warnings="ignore")