sglang · Commits · 9cf0a5ba

Unverified commit 9cf0a5ba, authored Aug 10, 2024 by gryffindor-rr, committed by GitHub Aug 09, 2024
Parent: b16e856f

Add skip_tokenizer_init args. (#959)

Co-authored-by: lzhang <zhanglei@modelbest.cn>

Showing 10 changed files with 218 additions and 71 deletions (+218, −71):
  python/sglang/srt/constrained/fsm_cache.py           +12   −2
  python/sglang/srt/managers/detokenizer_manager.py    +13   −5
  python/sglang/srt/managers/schedule_batch.py           +7  −10
  python/sglang/srt/managers/tokenizer_manager.py       +51  −23
  python/sglang/srt/managers/tp_worker.py               +29  −21
  python/sglang/srt/sampling_params.py                    +9   −3
  python/sglang/srt/server.py                           +12   −7
  python/sglang/srt/server_args.py                        +6   −0
  python/sglang/srt/utils.py                              +2   −0
  test/srt/test_skip_tokenizer_srt.py                   +77   −0
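Taken together, the change adds a --skip-tokenizer-init server flag: when it is set, none of the server processes load a tokenizer, requests to /generate must carry input_ids instead of text, and responses return raw token_ids for the client to detokenize. A minimal end-to-end sketch follows; the launch command, port, and prompt token IDs are illustrative, and the request shape mirrors the warmup request in server.py and the new test at the end of this commit.

    # Launch the server with the new flag (illustrative command and port):
    #   python -m sglang.launch_server --model-path <model> --port 30000 --skip-tokenizer-init
    import requests

    resp = requests.post(
        "http://127.0.0.1:30000/generate",
        json={
            # With --skip-tokenizer-init the request carries token IDs, not text.
            "input_ids": [10, 11, 12],
            "sampling_params": {"temperature": 0, "max_new_tokens": 8},
        },
        timeout=600,
    )
    # The response carries "token_ids" (plus "meta_info") instead of decoded text.
    print(resp.json())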
python/sglang/srt/constrained/fsm_cache.py

@@ -20,10 +20,20 @@ from sglang.srt.constrained.base_tool_cache import BaseToolCache
 class FSMCache(BaseToolCache):
-    def __init__(self, tokenizer_path, tokenizer_args_dict, enable=True):
+    def __init__(
+        self,
+        tokenizer_path,
+        tokenizer_args_dict,
+        enable=True,
+        skip_tokenizer_init=False,
+    ):
         super().__init__(enable=enable)
-        if tokenizer_path.endswith(".json") or tokenizer_path.endswith(".model"):
+        if (
+            skip_tokenizer_init
+            or tokenizer_path.endswith(".json")
+            or tokenizer_path.endswith(".model")
+        ):
             # Do not support TiktokenTokenizer or SentencePieceTokenizer
             return
python/sglang/srt/managers/detokenizer_manager.py

@@ -59,11 +59,14 @@ class DetokenizerManager:
         self.send_to_tokenizer = context.socket(zmq.PUSH)
         self.send_to_tokenizer.connect(f"tcp://127.0.0.1:{port_args.tokenizer_port}")
 
-        self.tokenizer = get_tokenizer(
-            server_args.tokenizer_path,
-            tokenizer_mode=server_args.tokenizer_mode,
-            trust_remote_code=server_args.trust_remote_code,
-        )
+        if server_args.skip_tokenizer_init:
+            self.tokenizer = None
+        else:
+            self.tokenizer = get_tokenizer(
+                server_args.tokenizer_path,
+                tokenizer_mode=server_args.tokenizer_mode,
+                trust_remote_code=server_args.trust_remote_code,
+            )
 
         self.decode_status = {}

@@ -85,6 +88,11 @@ class DetokenizerManager:
             assert isinstance(recv_obj, BatchTokenIDOut)
             bs = len(recv_obj.rids)
 
+            if self.tokenizer is None:
+                # Send BatchTokenIDOut if no tokenizer init'ed.
+                self.send_to_tokenizer.send_pyobj(recv_obj)
+                continue
+
             # Initialize decode status
             read_ids, surr_ids = [], []
             for i in range(bs):
python/sglang/srt/managers/schedule_batch.py

@@ -195,6 +195,8 @@ class Req:
         return all_ids[self.surr_offset :], self.read_offset - self.surr_offset
 
     def get_next_inc_detokenization(self):
+        if self.tokenizer is None:
+            return False, ""
         read_ids, read_offset = self.init_incremental_detokenize()
         surr_ids = read_ids[:read_offset]

@@ -225,16 +227,11 @@ class Req:
             return
 
         last_token_id = self.output_ids[-1]
-        if (
-            last_token_id == self.tokenizer.eos_token_id
-            and not self.sampling_params.ignore_eos
-        ):
-            self.finished_reason = FINISH_MATCHED_TOKEN(
-                matched=self.tokenizer.eos_token_id
-            )
-            return
-
-        if last_token_id in self.sampling_params.stop_token_ids:
+        if self.tokenizer is None:
+            matched_eos = last_token_id in self.sampling_params.stop_token_ids
+        else:
+            matched_eos = last_token_id == self.tokenizer.eos_token_id
+        if matched_eos and not self.sampling_params.ignore_eos:
             self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
             return
python/sglang/srt/managers/tokenizer_manager.py

@@ -95,25 +95,28 @@ class TokenizerManager:
         else:
             self.context_len = get_context_length(self.hf_config)
 
-        if is_multimodal_model(self.model_path):
-            self.processor = get_processor(
-                server_args.tokenizer_path,
-                tokenizer_mode=server_args.tokenizer_mode,
-                trust_remote_code=server_args.trust_remote_code,
-            )
-            self.tokenizer = self.processor.tokenizer
-            os.environ["TOKENIZERS_PARALLELISM"] = "false"
-            self.executor = concurrent.futures.ProcessPoolExecutor(
-                initializer=init_global_processor,
-                mp_context=mp.get_context("fork"),
-                initargs=(server_args,),
-            )
+        if server_args.skip_tokenizer_init:
+            self.tokenizer = self.processor = None
         else:
-            self.tokenizer = get_tokenizer(
-                server_args.tokenizer_path,
-                tokenizer_mode=server_args.tokenizer_mode,
-                trust_remote_code=server_args.trust_remote_code,
-            )
+            if is_multimodal_model(self.model_path):
+                self.processor = get_processor(
+                    server_args.tokenizer_path,
+                    tokenizer_mode=server_args.tokenizer_mode,
+                    trust_remote_code=server_args.trust_remote_code,
+                )
+                self.tokenizer = self.processor.tokenizer
+                os.environ["TOKENIZERS_PARALLELISM"] = "false"
+                self.executor = concurrent.futures.ProcessPoolExecutor(
+                    initializer=init_global_processor,
+                    mp_context=mp.get_context("fork"),
+                    initargs=(server_args,),
+                )
+            else:
+                self.tokenizer = get_tokenizer(
+                    server_args.tokenizer_path,
+                    tokenizer_mode=server_args.tokenizer_mode,
+                    trust_remote_code=server_args.trust_remote_code,
+                )
 
         self.to_create_loop = True
         self.rid_to_state: Dict[str, ReqState] = {}

@@ -171,6 +174,7 @@ class TokenizerManager:
                 rid = obj.rid if not_use_index else obj.rid[index]
                 input_text = obj.text if not_use_index else obj.text[index]
                 if obj.input_ids is None:
+                    assert self.tokenizer is not None
                     input_ids = self.tokenizer.encode(input_text)
                 else:
                     input_ids = obj.input_ids if not_use_index else obj.input_ids[index]

@@ -207,7 +211,20 @@ class TokenizerManager:
             else:
                 input_text = obj.text
                 rid = obj.rid[0]
-                input_ids = self.tokenizer.encode(input_text)
+                if self.tokenizer is not None:
+                    input_ids = self.tokenizer.encode(input_text)
+                else:
+                    assert obj.input_ids is not None
+                    input_ids = obj.input_ids
+                    if isinstance(obj.input_ids, list) and isinstance(
+                        obj.input_ids[0], list
+                    ):
+                        # when obj["input_ids"] is List[List[int]]
+                        input_ids = obj.input_ids[index]
+                        rid = obj.rid[index]
+                    else:
+                        input_ids = obj.input_ids
+                        rid = obj.rid[0]
         else:
             input_text = None
             if isinstance(obj.input_ids, list) and isinstance(

@@ -420,7 +437,7 @@ class TokenizerManager:
             # Log requests
             if self.server_args.log_requests and state.finished:
                 if obj.text is None:
-                    in_obj = {"text": self.tokenizer.decode(obj.input_ids)}
+                    in_obj = {"input_ids": obj.input_ids}
                 else:
                     in_obj = {"text": obj.text}
                 logger.info(f"in={in_obj}, out={out}")

@@ -488,11 +505,12 @@ class TokenizerManager:
     async def handle_loop(self):
         while True:
-            recv_obj: Union[BatchStrOut, BatchEmbeddingOut] = (
+            recv_obj: Union[BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut] = (
                 await self.recv_from_detokenizer.recv_pyobj()
             )
-            assert isinstance(recv_obj, (BatchStrOut, BatchEmbeddingOut))
+            assert isinstance(
+                recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)
+            ), f"Unexpected obj received: {type(recv_obj)}"
 
             for i, rid in enumerate(recv_obj.rids):
                 state = self.rid_to_state.get(rid, None)
                 if state is None:

@@ -504,6 +522,15 @@ class TokenizerManager:
                         "text": recv_obj.output_strs[i],
                         "meta_info": recv_obj.meta_info[i],
                     }
+                elif isinstance(recv_obj, BatchTokenIDOut):
+                    read_start = 0 if i == 0 else recv_obj.read_offsets[i - 1]
+                    out_dict = {
+                        "token_ids": recv_obj.decode_ids[
+                            read_start : recv_obj.read_offsets[i]
+                        ],
+                        "meta_info": recv_obj.meta_info[i],
+                    }
                 else:
                     assert isinstance(recv_obj, BatchEmbeddingOut)
                     out_dict = {

@@ -549,6 +576,7 @@ class TokenizerManager:
         if not decode_to_text:
             return [(logprob, token_id, None) for logprob, token_id in token_logprobs]
 
+        assert self.tokenizer is not None
         token_ids = [tid for _, tid in token_logprobs]
         token_texts = self.tokenizer.batch_decode(token_ids)
         return [
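Since the server no longer detokenizes when the flag is set, the handle_loop change above returns "token_ids" sliced out of BatchTokenIDOut.decode_ids. A client holding its own copy of the model's tokenizer can encode the prompt and decode the result locally; this is a minimal sketch assuming a Hugging Face tokenizer, an illustrative model name, and the same local port as in the earlier sketch.

    import requests
    from transformers import AutoTokenizer

    # Use the tokenizer that matches the served model (name is illustrative).
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

    resp = requests.post(
        "http://127.0.0.1:30000/generate",
        json={
            "input_ids": tokenizer.encode("The capital city of France is"),
            "sampling_params": {"temperature": 0, "max_new_tokens": 16},
        },
    )
    # The server returns token IDs; decoding happens on the client side.
    print(tokenizer.decode(resp.json()["token_ids"], skip_special_tokens=True))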
python/sglang/srt/managers/tp_worker.py

@@ -100,20 +100,22 @@ class ModelTpServer:
             nccl_port=nccl_port,
             server_args=server_args,
         )
-        if is_multimodal_model(server_args.model_path):
-            self.processor = get_processor(
-                server_args.tokenizer_path,
-                tokenizer_mode=server_args.tokenizer_mode,
-                trust_remote_code=server_args.trust_remote_code,
-            )
-            self.tokenizer = self.processor.tokenizer
+        if server_args.skip_tokenizer_init:
+            self.tokenizer = self.processor = None
         else:
-            self.tokenizer = get_tokenizer(
-                server_args.tokenizer_path,
-                tokenizer_mode=server_args.tokenizer_mode,
-                trust_remote_code=server_args.trust_remote_code,
-            )
+            if is_multimodal_model(server_args.model_path):
+                self.processor = get_processor(
+                    server_args.tokenizer_path,
+                    tokenizer_mode=server_args.tokenizer_mode,
+                    trust_remote_code=server_args.trust_remote_code,
+                )
+                self.tokenizer = self.processor.tokenizer
+            else:
+                self.tokenizer = get_tokenizer(
+                    server_args.tokenizer_path,
+                    tokenizer_mode=server_args.tokenizer_mode,
+                    trust_remote_code=server_args.trust_remote_code,
+                )
         self.max_total_num_tokens = self.model_runner.max_total_num_tokens
         self.max_prefill_tokens = (
             16384

@@ -182,13 +184,15 @@ class ModelTpServer:
         self.last_stats_tic = time.time()
 
         # Init the FSM cache for constrained generation
-        self.regex_fsm_cache = FSMCache(
-            server_args.tokenizer_path,
-            {
-                "tokenizer_mode": server_args.tokenizer_mode,
-                "trust_remote_code": server_args.trust_remote_code,
-            },
-        )
+        if not server_args.skip_tokenizer_init:
+            self.regex_fsm_cache = FSMCache(
+                server_args.tokenizer_path,
+                {
+                    "tokenizer_mode": server_args.tokenizer_mode,
+                    "trust_remote_code": server_args.trust_remote_code,
+                },
+                skip_tokenizer_init=server_args.skip_tokenizer_init,
+            )
         self.jump_forward_cache = JumpForwardCache()
 
         # Init new token estimation

@@ -466,7 +470,11 @@ class ModelTpServer:
             next_token_ids = next_token_ids.tolist()
         else:
-            next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
+            if self.tokenizer is None:
+                for i, req in enumerate(batch.reqs):
+                    next_token_ids.extend(req.sampling_params.stop_token_ids)
+            else:
+                next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
 
         # Check finish conditions
         pt = 0
python/sglang/srt/sampling_params.py

@@ -111,13 +111,19 @@ class SamplingParams:
         # Process stop strings
         if self.stop_strs is None:
             self.stop_strs = []
-            self.stop_str_max_len = 0
+            if self.stop_token_ids is None:
+                self.stop_str_max_len = 0
+            else:
+                self.stop_str_max_len = 1
         else:
             if isinstance(self.stop_strs, str):
                 self.stop_strs = [self.stop_strs]
 
             stop_str_max_len = 0
             for stop_str in self.stop_strs:
-                stop_str_ids = tokenizer.encode(stop_str, add_special_tokens=False)
-                stop_str_max_len = max(stop_str_max_len, len(stop_str_ids))
+                if tokenizer is not None:
+                    stop_str_ids = tokenizer.encode(stop_str, add_special_tokens=False)
+                    stop_str_max_len = max(stop_str_max_len, len(stop_str_ids))
+                else:
+                    stop_str_max_len = max(stop_str_max_len, len(stop_str))
             self.stop_str_max_len = stop_str_max_len
python/sglang/srt/server.py

@@ -420,17 +420,22 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
     # Send a warmup request
     request_name = "/generate" if model_info["is_generation"] else "/encode"
     max_new_tokens = 8 if model_info["is_generation"] else 1
+    json_data = {
+        "sampling_params": {
+            "temperature": 0,
+            "max_new_tokens": max_new_tokens,
+        },
+    }
+    if server_args.skip_tokenizer_init:
+        json_data["input_ids"] = [10, 11, 12]
+    else:
+        json_data["text"] = "The capital city of France is"
+
     try:
         for _ in range(server_args.dp_size):
             res = requests.post(
                 url + request_name,
-                json={
-                    "text": "The capital city of France is",
-                    "sampling_params": {
-                        "temperature": 0,
-                        "max_new_tokens": max_new_tokens,
-                    },
-                },
+                json=json_data,
                 headers=headers,
                 timeout=600,
             )
python/sglang/srt/server_args.py

@@ -27,6 +27,7 @@ class ServerArgs:
     model_path: str
     tokenizer_path: Optional[str] = None
     tokenizer_mode: str = "auto"
+    skip_tokenizer_init: bool = False
     load_format: str = "auto"
     dtype: str = "auto"
     trust_remote_code: bool = True

@@ -151,6 +152,11 @@ class ServerArgs:
             "tokenizer if available, and 'slow' will "
             "always use the slow tokenizer.",
         )
+        parser.add_argument(
+            "--skip-tokenizer-init",
+            action="store_true",
+            help="If set, skip init tokenizer and pass input_ids in generate request",
+        )
         parser.add_argument(
             "--load-format",
             type=str,
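The same switch is also exposed as a ServerArgs field defaulting to False, so programmatic launches can opt in without the CLI flag. A small sketch, assuming ServerArgs can be constructed from just a model path (the other fields shown in this hunk all carry defaults); the model path is illustrative.

    from sglang.srt.server_args import ServerArgs

    args = ServerArgs(model_path="meta-llama/Meta-Llama-3-8B-Instruct")
    args.skip_tokenizer_init = True  # same effect as passing --skip-tokenizer-init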
python/sglang/srt/utils.py

@@ -197,6 +197,8 @@ def allocate_init_ports(
 def get_int_token_logit_bias(tokenizer, vocab_size):
     """Get the logit bias for integer-only tokens."""
     # a bug when model's vocab size > tokenizer.vocab_size
+    if tokenizer == None:
+        return [-1e5] * vocab_size
     vocab_size = tokenizer.vocab_size
     logit_bias = np.zeros(vocab_size, dtype=np.float32)
     for t_id in range(vocab_size):
test/srt/test_skip_tokenizer_srt.py (new file, 0 → 100644)

import json
import os
import sys
import unittest

import requests

from sglang.srt.utils import kill_child_process
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST, popen_launch_server

# os.environ["CUDA_VISIBLE_DEVICES"] = "1"


class TestSRTEndpoint(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = "http://127.0.0.1:8157"
        cls.process = popen_launch_server(
            cls.model, cls.base_url, timeout=300, other_args=["--skip-tokenizer-init"]
        )

    @classmethod
    def tearDownClass(cls):
        kill_child_process(cls.process.pid)

    def run_decode(
        self, return_logprob=False, top_logprobs_num=0, return_text=False, n=1
    ):
        response = requests.post(
            self.base_url + "/generate",
            json={
                "input_ids": [
                    119689,
                    50650,
                    18291,
                    30061,
                    5316,
                    26951,
                    119690,
                ],  # The capital of France is
                "sampling_params": {
                    "temperature": 0 if n == 1 else 0.5,
                    "max_new_tokens": 32,
                    "n": n,
                    "stop_token_ids": [119690],
                },
                "stream": False,
                "return_logprob": return_logprob,
                "top_logprobs_num": top_logprobs_num,
                "return_text_in_logprobs": return_text,
                "logprob_start_len": 0,
            },
        )
        print(json.dumps(response.json()))
        print("=" * 100)

    def test_simple_decode(self):
        self.run_decode()

    def test_parallel_sample(self):
        self.run_decode(n=3)

    def test_logprob(self):
        for top_logprobs_num in [0, 3]:
            for return_text in [False, False]:
                self.run_decode(
                    return_logprob=True,
                    top_logprobs_num=top_logprobs_num,
                    return_text=return_text,
                )


if __name__ == "__main__":
    unittest.main(warnings="ignore")
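The new test launches a server with --skip-tokenizer-init through popen_launch_server and drives /generate with raw input_ids (per the inline comment, the listed IDs encode "The capital of France is" for the default test model). Assuming a GPU host that can load that model, it can be run directly with `python -m unittest test/srt/test_skip_tokenizer_srt.py`, or by executing the file itself, which calls unittest.main().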