"vscode:/vscode.git/clone" did not exist on "e3055164f7c5ee8be5e8b27c38fe1ec8450c06b8"
Unverified Commit 9cf0a5ba authored by gryffindor-rr's avatar gryffindor-rr Committed by GitHub
Browse files

Add skip_tokenizer_init args. (#959)


Co-authored-by: default avatarlzhang <zhanglei@modelbest.cn>
parent b16e856f
......@@ -20,10 +20,20 @@ from sglang.srt.constrained.base_tool_cache import BaseToolCache
class FSMCache(BaseToolCache):
def __init__(self, tokenizer_path, tokenizer_args_dict, enable=True):
def __init__(
self,
tokenizer_path,
tokenizer_args_dict,
enable=True,
skip_tokenizer_init=False,
):
super().__init__(enable=enable)
if tokenizer_path.endswith(".json") or tokenizer_path.endswith(".model"):
if (
skip_tokenizer_init
or tokenizer_path.endswith(".json")
or tokenizer_path.endswith(".model")
):
# Do not support TiktokenTokenizer or SentencePieceTokenizer
return
......
......@@ -59,11 +59,14 @@ class DetokenizerManager:
self.send_to_tokenizer = context.socket(zmq.PUSH)
self.send_to_tokenizer.connect(f"tcp://127.0.0.1:{port_args.tokenizer_port}")
self.tokenizer = get_tokenizer(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
)
if server_args.skip_tokenizer_init:
self.tokenizer = None
else:
self.tokenizer = get_tokenizer(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
)
self.decode_status = {}
......@@ -85,6 +88,11 @@ class DetokenizerManager:
assert isinstance(recv_obj, BatchTokenIDOut)
bs = len(recv_obj.rids)
if self.tokenizer is None:
# Send BatchTokenIDOut if no tokenizer init'ed.
self.send_to_tokenizer.send_pyobj(recv_obj)
continue
# Initialize decode status
read_ids, surr_ids = [], []
for i in range(bs):
......
......@@ -195,6 +195,8 @@ class Req:
return all_ids[self.surr_offset :], self.read_offset - self.surr_offset
def get_next_inc_detokenization(self):
if self.tokenizer is None:
return False, ""
read_ids, read_offset = self.init_incremental_detokenize()
surr_ids = read_ids[:read_offset]
......@@ -225,16 +227,11 @@ class Req:
return
last_token_id = self.output_ids[-1]
if (
last_token_id == self.tokenizer.eos_token_id
and not self.sampling_params.ignore_eos
):
self.finished_reason = FINISH_MATCHED_TOKEN(
matched=self.tokenizer.eos_token_id
)
return
if last_token_id in self.sampling_params.stop_token_ids:
if self.tokenizer is None:
matched_eos = last_token_id in self.sampling_params.stop_token_ids
else:
matched_eos = last_token_id == self.tokenizer.eos_token_id
if matched_eos and not self.sampling_params.ignore_eos:
self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
return
......
......@@ -95,25 +95,28 @@ class TokenizerManager:
else:
self.context_len = get_context_length(self.hf_config)
if is_multimodal_model(self.model_path):
self.processor = get_processor(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
)
self.tokenizer = self.processor.tokenizer
os.environ["TOKENIZERS_PARALLELISM"] = "false"
self.executor = concurrent.futures.ProcessPoolExecutor(
initializer=init_global_processor,
mp_context=mp.get_context("fork"),
initargs=(server_args,),
)
if server_args.skip_tokenizer_init:
self.tokenizer = self.processor = None
else:
self.tokenizer = get_tokenizer(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
)
if is_multimodal_model(self.model_path):
self.processor = get_processor(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
)
self.tokenizer = self.processor.tokenizer
os.environ["TOKENIZERS_PARALLELISM"] = "false"
self.executor = concurrent.futures.ProcessPoolExecutor(
initializer=init_global_processor,
mp_context=mp.get_context("fork"),
initargs=(server_args,),
)
else:
self.tokenizer = get_tokenizer(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
)
self.to_create_loop = True
self.rid_to_state: Dict[str, ReqState] = {}
......@@ -171,6 +174,7 @@ class TokenizerManager:
rid = obj.rid if not_use_index else obj.rid[index]
input_text = obj.text if not_use_index else obj.text[index]
if obj.input_ids is None:
assert self.tokenizer is not None
input_ids = self.tokenizer.encode(input_text)
else:
input_ids = obj.input_ids if not_use_index else obj.input_ids[index]
......@@ -207,7 +211,20 @@ class TokenizerManager:
else:
input_text = obj.text
rid = obj.rid[0]
input_ids = self.tokenizer.encode(input_text)
if self.tokenizer is not None:
input_ids = self.tokenizer.encode(input_text)
else:
assert obj.input_ids is not None
input_ids = obj.input_ids
if isinstance(obj.input_ids, list) and isinstance(
obj.input_ids[0], list
):
# when obj["input_ids"] is List[List[int]]
input_ids = obj.input_ids[index]
rid = obj.rid[index]
else:
input_ids = obj.input_ids
rid = obj.rid[0]
else:
input_text = None
if isinstance(obj.input_ids, list) and isinstance(
......@@ -420,7 +437,7 @@ class TokenizerManager:
# Log requests
if self.server_args.log_requests and state.finished:
if obj.text is None:
in_obj = {"text": self.tokenizer.decode(obj.input_ids)}
in_obj = {"input_ids": obj.input_ids}
else:
in_obj = {"text": obj.text}
logger.info(f"in={in_obj}, out={out}")
......@@ -488,11 +505,12 @@ class TokenizerManager:
async def handle_loop(self):
while True:
recv_obj: Union[BatchStrOut, BatchEmbeddingOut] = (
recv_obj: Union[BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut] = (
await self.recv_from_detokenizer.recv_pyobj()
)
assert isinstance(recv_obj, (BatchStrOut, BatchEmbeddingOut))
assert isinstance(
recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)
), f"Unexpected obj received: {type(recv_obj)}"
for i, rid in enumerate(recv_obj.rids):
state = self.rid_to_state.get(rid, None)
if state is None:
......@@ -504,6 +522,15 @@ class TokenizerManager:
"text": recv_obj.output_strs[i],
"meta_info": recv_obj.meta_info[i],
}
elif isinstance(recv_obj, BatchTokenIDOut):
read_start = 0 if i == 0 else recv_obj.read_offsets[i - 1]
out_dict = {
"token_ids": recv_obj.decode_ids[
read_start : recv_obj.read_offsets[i]
],
"meta_info": recv_obj.meta_info[i],
}
else:
assert isinstance(recv_obj, BatchEmbeddingOut)
out_dict = {
......@@ -549,6 +576,7 @@ class TokenizerManager:
if not decode_to_text:
return [(logprob, token_id, None) for logprob, token_id in token_logprobs]
assert self.tokenizer is not None
token_ids = [tid for _, tid in token_logprobs]
token_texts = self.tokenizer.batch_decode(token_ids)
return [
......
......@@ -100,20 +100,22 @@ class ModelTpServer:
nccl_port=nccl_port,
server_args=server_args,
)
if is_multimodal_model(server_args.model_path):
self.processor = get_processor(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
)
self.tokenizer = self.processor.tokenizer
if server_args.skip_tokenizer_init:
self.tokenizer = self.processor = None
else:
self.tokenizer = get_tokenizer(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
)
if is_multimodal_model(server_args.model_path):
self.processor = get_processor(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
)
self.tokenizer = self.processor.tokenizer
else:
self.tokenizer = get_tokenizer(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
trust_remote_code=server_args.trust_remote_code,
)
self.max_total_num_tokens = self.model_runner.max_total_num_tokens
self.max_prefill_tokens = (
16384
......@@ -182,13 +184,15 @@ class ModelTpServer:
self.last_stats_tic = time.time()
# Init the FSM cache for constrained generation
self.regex_fsm_cache = FSMCache(
server_args.tokenizer_path,
{
"tokenizer_mode": server_args.tokenizer_mode,
"trust_remote_code": server_args.trust_remote_code,
},
)
if not server_args.skip_tokenizer_init:
self.regex_fsm_cache = FSMCache(
server_args.tokenizer_path,
{
"tokenizer_mode": server_args.tokenizer_mode,
"trust_remote_code": server_args.trust_remote_code,
},
skip_tokenizer_init=server_args.skip_tokenizer_init,
)
self.jump_forward_cache = JumpForwardCache()
# Init new token estimation
......@@ -466,7 +470,11 @@ class ModelTpServer:
next_token_ids = next_token_ids.tolist()
else:
next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
if self.tokenizer is None:
for i, req in enumerate(batch.reqs):
next_token_ids.extend(req.sampling_params.stop_token_ids)
else:
next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
# Check finish conditions
pt = 0
......
......@@ -111,13 +111,19 @@ class SamplingParams:
# Process stop strings
if self.stop_strs is None:
self.stop_strs = []
self.stop_str_max_len = 0
if self.stop_token_ids is None:
self.stop_str_max_len = 0
else:
self.stop_str_max_len = 1
else:
if isinstance(self.stop_strs, str):
self.stop_strs = [self.stop_strs]
stop_str_max_len = 0
for stop_str in self.stop_strs:
stop_str_ids = tokenizer.encode(stop_str, add_special_tokens=False)
stop_str_max_len = max(stop_str_max_len, len(stop_str_ids))
if tokenizer is not None:
stop_str_ids = tokenizer.encode(stop_str, add_special_tokens=False)
stop_str_max_len = max(stop_str_max_len, len(stop_str_ids))
else:
stop_str_max_len = max(stop_str_max_len, len(stop_str))
self.stop_str_max_len = stop_str_max_len
......@@ -420,17 +420,22 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
# Send a warmup request
request_name = "/generate" if model_info["is_generation"] else "/encode"
max_new_tokens = 8 if model_info["is_generation"] else 1
json_data = {
"sampling_params": {
"temperature": 0,
"max_new_tokens": max_new_tokens,
},
}
if server_args.skip_tokenizer_init:
json_data["input_ids"] = [10, 11, 12]
else:
json_data["text"] = "The capital city of France is"
try:
for _ in range(server_args.dp_size):
res = requests.post(
url + request_name,
json={
"text": "The capital city of France is",
"sampling_params": {
"temperature": 0,
"max_new_tokens": max_new_tokens,
},
},
json=json_data,
headers=headers,
timeout=600,
)
......
......@@ -27,6 +27,7 @@ class ServerArgs:
model_path: str
tokenizer_path: Optional[str] = None
tokenizer_mode: str = "auto"
skip_tokenizer_init: bool = False
load_format: str = "auto"
dtype: str = "auto"
trust_remote_code: bool = True
......@@ -151,6 +152,11 @@ class ServerArgs:
"tokenizer if available, and 'slow' will "
"always use the slow tokenizer.",
)
parser.add_argument(
"--skip-tokenizer-init",
action="store_true",
help="If set, skip init tokenizer and pass input_ids in generate request",
)
parser.add_argument(
"--load-format",
type=str,
......
......@@ -197,6 +197,8 @@ def allocate_init_ports(
def get_int_token_logit_bias(tokenizer, vocab_size):
"""Get the logit bias for integer-only tokens."""
# a bug when model's vocab size > tokenizer.vocab_size
if tokenizer == None:
return [-1e5] * vocab_size
vocab_size = tokenizer.vocab_size
logit_bias = np.zeros(vocab_size, dtype=np.float32)
for t_id in range(vocab_size):
......
import json
import os
import sys
import unittest
import requests
from sglang.srt.utils import kill_child_process
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST, popen_launch_server
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
class TestSRTEndpoint(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
cls.base_url = "http://127.0.0.1:8157"
cls.process = popen_launch_server(
cls.model, cls.base_url, timeout=300, other_args=["--skip-tokenizer-init"]
)
@classmethod
def tearDownClass(cls):
kill_child_process(cls.process.pid)
def run_decode(
self, return_logprob=False, top_logprobs_num=0, return_text=False, n=1
):
response = requests.post(
self.base_url + "/generate",
json={
"input_ids": [
119689,
50650,
18291,
30061,
5316,
26951,
119690,
], # The capital of France is
"sampling_params": {
"temperature": 0 if n == 1 else 0.5,
"max_new_tokens": 32,
"n": n,
"stop_token_ids": [119690],
},
"stream": False,
"return_logprob": return_logprob,
"top_logprobs_num": top_logprobs_num,
"return_text_in_logprobs": return_text,
"logprob_start_len": 0,
},
)
print(json.dumps(response.json()))
print("=" * 100)
def test_simple_decode(self):
self.run_decode()
def test_parallel_sample(self):
self.run_decode(n=3)
def test_logprob(self):
for top_logprobs_num in [0, 3]:
for return_text in [False, False]:
self.run_decode(
return_logprob=True,
top_logprobs_num=top_logprobs_num,
return_text=return_text,
)
if __name__ == "__main__":
unittest.main(warnings="ignore")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment