Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
23950056
Unverified
Commit
23950056
authored
Jan 25, 2024
by
parasol-aser
Committed by
GitHub
Jan 25, 2024
Browse files
support speculative execution for openai API (#48)
Co-authored-by:
Ying Sheng
<
sqy1415@gmail.com
>
parent
93414c82
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
178 additions
and
12 deletions
+178
-12
examples/usage/openai_speculative.py
examples/usage/openai_speculative.py
+19
-0
python/sglang/api.py
python/sglang/api.py
+10
-2
python/sglang/lang/interpreter.py
python/sglang/lang/interpreter.py
+70
-5
python/sglang/lang/ir.py
python/sglang/lang/ir.py
+2
-1
python/sglang/srt/managers/detokenizer_manager.py
python/sglang/srt/managers/detokenizer_manager.py
+3
-1
python/sglang/srt/managers/router/model_rpc.py
python/sglang/srt/managers/router/model_rpc.py
+1
-1
python/sglang/srt/managers/tokenizer_manager.py
python/sglang/srt/managers/tokenizer_manager.py
+1
-0
test/lang/test_openai_spec.py
test/lang/test_openai_spec.py
+68
-0
test/srt/test_fast_forward.py
test/srt/test_fast_forward.py
+2
-1
test/srt/test_robust.py
test/srt/test_robust.py
+2
-1
No files found.
examples/usage/openai_speculative.py
0 → 100644
View file @
23950056
"""Example: speculative execution with the OpenAI API backend.

Passing ``api_num_spec_tokens`` to ``@function`` makes sglang request one
long speculative completion from the API and then carve it up across the
subsequent ``gen`` calls, instead of issuing a separate API call per gen.
"""
from sglang import OpenAI, function, gen, set_default_backend


@function(api_num_spec_tokens=512)
def gen_character_spec(s):
    """Generate a character profile (name / birthday / job) speculatively."""
    s += "Construct a character within the following format:\n"
    s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\n"
    s += "\nPlease generate new Name, Birthday and Job.\n"
    s += "Name:" + gen("name", stop="\n") + "\nBirthday:" + gen("birthday", stop="\n")
    s += "\nJob:" + gen("job", stop="\n") + "\n"


set_default_backend(OpenAI("gpt-3.5-turbo-instruct"))

state = gen_character_spec.run()

print("name:", state["name"])
print("birthday:", state["birthday"])
print("job:", state["job"])
python/sglang/api.py
View file @
23950056
...
@@ -20,8 +20,16 @@ from sglang.lang.ir import (
...
@@ -20,8 +20,16 @@ from sglang.lang.ir import (
)
)
def function(
    func: Optional[Callable] = None, api_num_spec_tokens: Optional[int] = None
):
    """Decorator that wraps a user program in an ``SglFunction``.

    Supports both usages:
      * bare: ``@function`` — ``func`` is passed directly, wrap it now;
      * parameterized: ``@function(api_num_spec_tokens=N)`` — return a
        decorator that captures ``api_num_spec_tokens`` for later wrapping.

    Args:
        func: The program body (present only in the bare usage).
        api_num_spec_tokens: If set, enables speculative execution on API
            backends by pre-generating this many tokens.
    """
    if func:
        return SglFunction(func, api_num_spec_tokens=api_num_spec_tokens)

    def decorator(inner_func):
        return SglFunction(inner_func, api_num_spec_tokens=api_num_spec_tokens)

    return decorator
def
Runtime
(
*
args
,
**
kwargs
):
def
Runtime
(
*
args
,
**
kwargs
):
...
...
python/sglang/lang/interpreter.py
View file @
23950056
...
@@ -51,10 +51,14 @@ def run_program(
...
@@ -51,10 +51,14 @@ def run_program(
if
hasattr
(
backend
,
"endpoint"
):
if
hasattr
(
backend
,
"endpoint"
):
backend
=
backend
.
endpoint
backend
=
backend
.
endpoint
assert
backend
is
not
None
,
"Please specify a backend"
assert
backend
is
not
None
,
"Please specify a backend"
func_kwargs
.
update
(
program
.
bind_arguments
)
func_kwargs
.
update
(
program
.
bind_arguments
)
stream_executor
=
StreamExecutor
(
stream_executor
=
StreamExecutor
(
backend
,
func_kwargs
,
default_sampling_para
,
chat_template
=
None
,
stream
=
stream
backend
,
func_kwargs
,
default_sampling_para
,
chat_template
=
None
,
stream
=
stream
,
api_num_spec_tokens
=
program
.
api_num_spec_tokens
,
)
)
state
=
ProgramState
(
stream_executor
)
state
=
ProgramState
(
stream_executor
)
...
@@ -175,6 +179,7 @@ class StreamExecutor:
...
@@ -175,6 +179,7 @@ class StreamExecutor:
default_sampling_para
,
default_sampling_para
,
chat_template
,
chat_template
,
stream
,
stream
,
api_num_spec_tokens
=
None
,
use_thread
=
True
,
use_thread
=
True
,
):
):
self
.
sid
=
uuid
.
uuid4
().
hex
self
.
sid
=
uuid
.
uuid4
().
hex
...
@@ -182,6 +187,7 @@ class StreamExecutor:
...
@@ -182,6 +187,7 @@ class StreamExecutor:
self
.
arguments
:
Dict
[
str
,
Any
]
=
arguments
self
.
arguments
:
Dict
[
str
,
Any
]
=
arguments
self
.
default_sampling_para
=
default_sampling_para
self
.
default_sampling_para
=
default_sampling_para
self
.
stream
=
stream
self
.
stream
=
stream
self
.
api_num_spec_tokens
=
api_num_spec_tokens
self
.
variables
=
{}
# Dict[name: str -> value: str]
self
.
variables
=
{}
# Dict[name: str -> value: str]
self
.
variable_event
=
{}
# Dict[name: str -> event: threading.Event]
self
.
variable_event
=
{}
# Dict[name: str -> event: threading.Event]
...
@@ -191,6 +197,9 @@ class StreamExecutor:
...
@@ -191,6 +197,9 @@ class StreamExecutor:
# For completion
# For completion
self
.
text_
=
""
# The full text
self
.
text_
=
""
# The full text
# For speculative execution
self
.
speculated_text
=
""
# For chat
# For chat
self
.
messages_
=
[]
# The messages in the OpenAI API format
self
.
messages_
=
[]
# The messages in the OpenAI API format
self
.
chat_template
=
chat_template
or
self
.
backend
.
get_chat_template
()
self
.
chat_template
=
chat_template
or
self
.
backend
.
get_chat_template
()
...
@@ -341,6 +350,10 @@ class StreamExecutor:
...
@@ -341,6 +350,10 @@ class StreamExecutor:
def _execute_fill(self, value: str):
    """Append a constant string to the running completion text.

    When speculative text is buffered, a literal that matches its head is
    consumed from the buffer (the speculation "agreed" with the program);
    any mismatch invalidates the remaining speculation entirely.
    """
    value = str(value)
    if self.speculated_text.startswith(value):
        # Speculation agrees with the literal; advance past it.
        self.speculated_text = self.speculated_text[len(value):]
    else:
        # Speculation diverged; discard what is left of it.
        self.speculated_text = ""
    self.text_ += value
def
_execute_image
(
self
,
expr
:
SglImage
):
def
_execute_image
(
self
,
expr
:
SglImage
):
...
@@ -360,9 +373,61 @@ class StreamExecutor:
...
@@ -360,9 +373,61 @@ class StreamExecutor:
name
=
expr
.
name
name
=
expr
.
name
if
not
self
.
stream
:
if
not
self
.
stream
:
if
self
.
api_num_spec_tokens
is
not
None
:
stop
=
sampling_params
.
stop
max_new_tokens
=
sampling_params
.
max_new_tokens
meta_info
=
{}
def
regen
():
sampling_params
.
max_new_tokens
=
max
(
sampling_params
.
max_new_tokens
,
self
.
api_num_spec_tokens
)
sampling_params
.
stop
=
None
self
.
speculated_text
,
meta_info
=
self
.
backend
.
generate
(
self
,
sampling_params
=
sampling_params
)
def
find_stop
():
if
isinstance
(
stop
,
str
):
return
self
.
speculated_text
.
find
(
stop
),
len
(
stop
)
elif
isinstance
(
stop
,
(
tuple
,
list
)):
pos
=
-
1
stop_len
=
0
for
stop_str
in
stop
:
stop_pos
=
self
.
speculated_text
.
find
(
stop_str
)
if
stop_pos
!=
-
1
and
(
pos
==
-
1
or
stop_pos
<
pos
):
pos
=
stop_pos
stop_len
=
len
(
stop_str
)
return
pos
,
stop_len
else
:
raise
Exception
(
"Wrong type of stop in sampling parameters."
)
if
stop
is
None
:
if
len
(
self
.
speculated_text
)
<
max_new_tokens
:
regen
()
comp
=
self
.
speculated_text
[:
max_new_tokens
]
self
.
speculated_text
=
self
.
speculated_text
[
max_new_tokens
:]
elif
isinstance
(
stop
,
(
str
,
list
,
tuple
)):
if
self
.
speculated_text
==
""
:
regen
()
stop_pos
,
stop_len
=
find_stop
()
if
stop_pos
==
-
1
:
stop_pos
,
stop_len
=
(
min
(
sampling_params
.
max_new_tokens
,
len
(
self
.
speculated_text
),
),
0
,
)
comp
=
self
.
speculated_text
[:
stop_pos
]
self
.
speculated_text
=
self
.
speculated_text
[
stop_pos
:]
else
:
raise
ValueError
(
"Wrong type of stop in sampling parameters."
)
else
:
comp
,
meta_info
=
self
.
backend
.
generate
(
comp
,
meta_info
=
self
.
backend
.
generate
(
self
,
sampling_params
=
sampling_params
self
,
sampling_params
=
sampling_params
)
)
self
.
text_
+=
comp
self
.
text_
+=
comp
self
.
variables
[
name
]
=
comp
self
.
variables
[
name
]
=
comp
...
...
python/sglang/lang/ir.py
View file @
23950056
...
@@ -95,8 +95,9 @@ class SglSamplingParams:
...
@@ -95,8 +95,9 @@ class SglSamplingParams:
class SglFunction:
    """A traceable sglang program produced by the ``@function`` decorator.

    Attributes:
        func: The user-written program body.
        api_num_spec_tokens: When not ``None``, API backends speculatively
            generate this many tokens ahead and split them across gens.
        bind_arguments: Arguments pre-bound at definition time.
        pin_prefix_rid: Request id of a pinned prefix, if any (set later).
    """

    def __init__(self, func, api_num_spec_tokens=None, bind_arguments=None):
        self.func = func
        self.api_num_spec_tokens = api_num_spec_tokens
        # `or {}` gives each instance its own dict (avoids a shared default).
        self.bind_arguments = bind_arguments or {}
        self.pin_prefix_rid = None
...
...
python/sglang/srt/managers/detokenizer_manager.py
View file @
23950056
...
@@ -60,7 +60,9 @@ class DetokenizerManager:
...
@@ -60,7 +60,9 @@ class DetokenizerManager:
if
first_token
.
startswith
(
"▁"
):
if
first_token
.
startswith
(
"▁"
):
output_strs
[
i
]
=
" "
+
output_strs
[
i
]
output_strs
[
i
]
=
" "
+
output_strs
[
i
]
output_strs
[
i
]
=
recv_obj
.
output_and_fast_forward_strs
[
i
]
+
output_strs
[
i
]
output_strs
[
i
]
=
(
recv_obj
.
output_and_fast_forward_strs
[
i
]
+
output_strs
[
i
]
)
self
.
send_to_tokenizer
.
send_pyobj
(
self
.
send_to_tokenizer
.
send_pyobj
(
BatchStrOut
(
BatchStrOut
(
...
...
python/sglang/srt/managers/router/model_rpc.py
View file @
23950056
...
@@ -12,6 +12,7 @@ import rpyc
...
@@ -12,6 +12,7 @@ import rpyc
import
torch
import
torch
from
rpyc.utils.classic
import
obtain
from
rpyc.utils.classic
import
obtain
from
rpyc.utils.server
import
ThreadedServer
from
rpyc.utils.server
import
ThreadedServer
from
sglang.srt.constrained.fast_forward
import
FastForwardCache
from
sglang.srt.constrained.fsm_cache
import
FSMCache
from
sglang.srt.constrained.fsm_cache
import
FSMCache
from
sglang.srt.hf_transformers_utils
import
get_processor
,
get_tokenizer
from
sglang.srt.hf_transformers_utils
import
get_processor
,
get_tokenizer
from
sglang.srt.managers.io_struct
import
BatchTokenIDOut
,
TokenizedGenerateReqInput
from
sglang.srt.managers.io_struct
import
BatchTokenIDOut
,
TokenizedGenerateReqInput
...
@@ -21,7 +22,6 @@ from sglang.srt.managers.router.radix_cache import RadixCache
...
@@ -21,7 +22,6 @@ from sglang.srt.managers.router.radix_cache import RadixCache
from
sglang.srt.managers.router.scheduler
import
Scheduler
from
sglang.srt.managers.router.scheduler
import
Scheduler
from
sglang.srt.model_config
import
ModelConfig
from
sglang.srt.model_config
import
ModelConfig
from
sglang.srt.server_args
import
PortArgs
,
ServerArgs
from
sglang.srt.server_args
import
PortArgs
,
ServerArgs
from
sglang.srt.constrained.fast_forward
import
FastForwardCache
from
sglang.srt.utils
import
(
from
sglang.srt.utils
import
(
get_exception_traceback
,
get_exception_traceback
,
get_int_token_logit_bias
,
get_int_token_logit_bias
,
...
...
python/sglang/srt/managers/tokenizer_manager.py
View file @
23950056
...
@@ -200,6 +200,7 @@ class TokenizerManager:
...
@@ -200,6 +200,7 @@ class TokenizerManager:
)
)
tokenized_obj
=
TokenizedGenerateReqInput
(
tokenized_obj
=
TokenizedGenerateReqInput
(
rid
=
rid
,
rid
=
rid
,
input_text
=
obj
.
text
[
i
],
input_ids
=
input_ids
,
input_ids
=
input_ids
,
pixel_values
=
pixel_values
,
pixel_values
=
pixel_values
,
image_hash
=
image_hash
,
image_hash
=
image_hash
,
...
...
test/lang/test_openai_spec.py
0 → 100644
View file @
23950056
from
sglang
import
OpenAI
,
function
,
gen
,
set_default_backend
@function()
def gen_character_default(s):
    """Baseline program: one API call per ``gen`` (no speculation)."""
    s += "Construct a character within the following format:\n"
    s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\nWelcome.\n"
    s += "\nPlease generate new Name, Birthday and Job.\n"
    s += "Name:" + gen("name", stop="\n") + "\nBirthday:" + gen("birthday", stop="\n")
    s += "\nJob:" + gen("job", stop="\n") + "\nWelcome.\n"
@function(api_num_spec_tokens=512)
def gen_character_spec(s):
    """Speculative variant: single-string stop markers on every gen."""
    s += "Construct a character within the following format:\n"
    s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\nWelcome.\n"
    s += "\nPlease generate new Name, Birthday and Job.\n"
    s += "Name:" + gen("name", stop="\n") + "\nBirthday:" + gen("birthday", stop="\n")
    s += "\nJob:" + gen("job", stop="\n") + "\nWelcome.\n"
@function(api_num_spec_tokens=512)
def gen_character_no_stop(s):
    """Speculative variant: no stop strings (gens bounded by max tokens)."""
    s += "Construct a character within the following format:\n"
    s += "Name: Steve Jobs.\nBirthday: February 24, 1955.\nJob: Apple CEO.\nWelcome.\n"
    s += "\nPlease generate new Name, Birthday and Job.\n"
    s += "Name:" + gen("name") + "\nBirthday:" + gen("birthday")
    s += "\nJob:" + gen("job") + "\nWelcome.\n"
@function(api_num_spec_tokens=512)
def gen_character_multi_stop(s):
    """Speculative variant: a list of stop strings on every gen."""
    s += "Construct a character within the following format:\n"
    s += (
        "Name: Steve Jobs.###Birthday: February 24, 1955.###Job: Apple CEO.\nWelcome.\n"
    )
    s += "\nPlease generate new Name, Birthday and Job.\n"
    s += "Name:" + gen("name", stop=["\n", "###"])
    s += "###Birthday:" + gen("birthday", stop=["\n", "###"])
    s += "###Job:" + gen("job", stop=["\n", "###"]) + "\nWelcome.\n"
# Driver: run each variant against the OpenAI backend and print results.
set_default_backend(OpenAI("gpt-3.5-turbo-instruct"))

# Baseline (no speculation).
state = gen_character_default.run()
print(state.text())
print("=" * 60)

# Speculative, no stop strings.
state = gen_character_no_stop.run()
print("name###", state["name"])
# Fixed label: was "birthday###:" with a stray colon, inconsistent with the
# "name###"/"job###" labels here and with the spec block below.
print("birthday###", state["birthday"])
print("job###", state["job"])
print("=" * 60)

# Speculative, multiple stop strings.
state = gen_character_multi_stop.run()
print(state.text())
print("=" * 60)

# Speculative, single stop strings.
state = gen_character_spec.run()
print(state.text())
print("name###", state["name"])
print("birthday###", state["birthday"])
print("job###", state["job"])
test/srt/test_fast_forward.py
View file @
23950056
import
argparse
import
argparse
from
enum
import
Enum
from
enum
import
Enum
import
sglang
as
sgl
from
pydantic
import
BaseModel
,
constr
from
pydantic
import
BaseModel
,
constr
from
sglang.srt.constrained.json_schema
import
build_regex_from_object
from
sglang.srt.constrained.json_schema
import
build_regex_from_object
from
sglang.test.test_utils
import
(
from
sglang.test.test_utils
import
(
...
@@ -9,6 +8,8 @@ from sglang.test.test_utils import (
...
@@ -9,6 +8,8 @@ from sglang.test.test_utils import (
select_sglang_backend
,
select_sglang_backend
,
)
)
import
sglang
as
sgl
IP_REGEX
=
r
"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
IP_REGEX
=
r
"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
ip_fast_forward
=
(
ip_fast_forward
=
(
...
...
test/srt/test_robust.py
View file @
23950056
...
@@ -2,13 +2,14 @@ import argparse
...
@@ -2,13 +2,14 @@ import argparse
import
random
import
random
import
string
import
string
import
sglang
as
sgl
from
sglang.test.test_utils
import
(
from
sglang.test.test_utils
import
(
add_common_sglang_args_and_parse
,
add_common_sglang_args_and_parse
,
select_sglang_backend
,
select_sglang_backend
,
)
)
from
vllm.transformers_utils.tokenizer
import
get_tokenizer
from
vllm.transformers_utils.tokenizer
import
get_tokenizer
import
sglang
as
sgl
TOKENIZER
=
None
TOKENIZER
=
None
RANDOM_PREFILL_LEN
=
None
RANDOM_PREFILL_LEN
=
None
RANDOM_DECODE_LEN
=
None
RANDOM_DECODE_LEN
=
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment