change / sglang · Commits · Commit 08ab2a16 (unverified)

Json Decode && Mutl-Turns (#4)

Authored Jan 15, 2024 by Liangsheng Yin; committed via GitHub on Jan 15, 2024. Parent: f652494d.

Showing 20 of 27 changed files, with 704 additions and 33 deletions.
- `3rdparty/flashinfer` (+1, -1)
- `benchmark/json_regex_decode/README.md` (+61, -0)
- `benchmark/json_regex_decode/bench_other.py` (+125, -0)
- `benchmark/json_regex_decode/bench_sglang.py` (+100, -0)
- `benchmark/json_regex_decode/build_dataset.py` (+58, -0)
- `benchmark/multi_turns/README.md` (+66, -0)
- `benchmark/multi_turns/bench_other.py` (+133, -0)
- `benchmark/multi_turns/bench_sglang.py` (+77, -0)
- `benchmark/multi_turns/data_gen.py` (+29, -0)
- `python/sglang/api.py` (+6, -0)
- `python/sglang/backend/anthropic.py` (+3, -3)
- `python/sglang/backend/base_backend.py` (+3, -3)
- `python/sglang/backend/openai.py` (+3, -3)
- `python/sglang/backend/runtime_endpoint.py` (+3, -3)
- `python/sglang/backend/tgi.py` (+3, -3)
- `python/sglang/lang/compiler.py` (+3, -3)
- `python/sglang/lang/interpreter.py` (+3, -4)
- `python/sglang/lang/ir.py` (+13, -9)
- `python/sglang/lang/tracer.py` (+2, -1)
- `python/sglang/srt/backend_config.py` (+12, -0)
`3rdparty/flashinfer`: submodule updated to 88b9496e (compare 00cf5f46...88b9496e)

```diff
-Subproject commit 00cf5f46fdbb4f1dbd9277fe3b842621c1d9e7dc
+Subproject commit 88b9496e1a726ddb353eb42887cfc0ab32c99460
```
`benchmark/json_regex_decode/README.md` (new file, mode 100644)
## Run benchmark

### Build dataset

```
pip install wikipedia
python3 build_dataset.py
```

### Dependencies

```
llama_cpp_python 0.2.19
guidance         0.1.10
vllm             0.2.5
outlines         0.0.22
```

### Benchmark sglang

Run llama-7b

```
python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```

Run mixtral-8x7b (if you hit a CUDA out-of-memory error, try reducing `--mem-fraction-static`)

```
python3 -m sglang.launch_server --model-path mistralai/Mixtral-8x7B-Instruct-v0.1 --port 30000 --tp-size 8
```

Benchmark

```
python3 bench_sglang.py --num-questions 10
```

### Benchmark vLLM

Run llama-7b (served via `outlines`, which wraps vLLM)

```
python3 -m outlines.serve.serve --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
```

Benchmark

```
python3 bench_other.py --backend vllm --num-questions 10
```

### Benchmark guidance

Run llama-7b and benchmark

```
python3 bench_other.py --backend guidance --num-questions 10 --parallel 1
```
`benchmark/json_regex_decode/bench_other.py` (new file, mode 100644)

```python
import argparse
import json
import time
from concurrent.futures import ThreadPoolExecutor
from functools import partial

from sglang.test.test_utils import (
    add_common_other_args_and_parse,
    call_generate_outlines,
)
from sglang.utils import dump_state_text, read_jsonl
from sglang.lang.ir import REGEX_INT, REGEX_STRING, REGEX_FLOAT
from tqdm import tqdm

REGEX_LIST = r"\[(" + REGEX_STRING + ", )*" + REGEX_STRING + r"\]"


# fmt: off
def json_decode(document, generate):
    s = "Please extract the information of a city from the following wikipedia page.\n"
    s += "Page begin.\n" + document + "Page end.\n"
    s += "Here is the name, country, and symbol of the city in JSON format.\n"
    s += "{\n"
    s += ' "name": '
    s += generate(s, max_tokens=8, regex=REGEX_STRING + ",") + "\n"
    s += ' "country": '
    s += generate(s, max_tokens=8, regex=REGEX_STRING + ",") + "\n"
    s += ' "latitude": '
    s += generate(s, max_tokens=8, regex=REGEX_FLOAT + ",") + "\n"
    s += ' "population": '
    s += generate(s, max_tokens=8, regex=REGEX_INT + ",") + "\n"
    s += ' "top 3 landmarks": '
    s += generate(s, max_tokens=24, regex=REGEX_LIST) + "\n"
    s += "}\n"
    return s
# fmt: on


def main(args):
    lines = read_jsonl(args.data_path)
    arguments = []
    for i in range(len(lines[: args.num_questions])):
        arguments.append(
            {
                "document": lines[i]["document"],
            }
        )
    states = [None] * len(arguments)

    # Select backend
    if args.backend == "vllm":
        url = f"{args.host}:{args.port}/generate"
        generate = partial(call_generate_outlines, url=url, temperature=0)
    elif args.backend == "guidance":
        from guidance import gen, models

        model = models.LlamaCpp(
            "/home/ubuntu/model_weights/Llama-2-7b-chat-hf/ggml-model-f16.gguf",
            n_gpu_layers=-1,
            n_ctx=4096,
        )

        def generate(prompt, max_tokens, stop=None, regex=None):
            out = (
                model
                + prompt
                + gen(
                    name="answer",
                    max_tokens=max_tokens,
                    temperature=0,
                    stop=stop,
                    regex=regex,
                )
            )
            return out["answer"]

        # warmup
        for _ in range(3):
            generate("Hello!" * 10, max_tokens=64, stop=None)
    else:
        raise ValueError(f"Invalid backend: {args.backend}")

    # Run requests
    def get_one_answer(i):
        states[i] = json_decode(generate=generate, **arguments[i])

    tic = time.time()
    if args.parallel == 1:
        for i in tqdm(range(len(arguments))):
            get_one_answer(i)
    else:
        with ThreadPoolExecutor(args.parallel) as executor:
            rets = executor.map(get_one_answer, list(range(len(arguments))))
            for _ in rets:
                pass
    latency = time.time() - tic

    # Compute accuracy
    print(f"Latency: {latency:.3f}")

    # Write results
    dump_state_text(f"tmp_output_{args.backend}.txt", states)

    with open(args.result_file, "a") as fout:
        value = {
            "task": "json_regex_decode",
            "backend": args.backend,
            "num_gpus": 1,
            "latency": round(latency, 3),
            "num_requests": args.num_questions,
            "other": {
                "parallel": args.parallel,
            },
        }
        fout.write(json.dumps(value) + "\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-path", type=str, default="questions.jsonl")
    parser.add_argument("--num-questions", type=int, default=20)
    args = add_common_other_args_and_parse(parser)
    main(args)
```
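As a quick sanity check of the composed list pattern, here is a sketch using Python's `re` (the benchmark itself relies on the backend's constrained decoding, not `re`; `REGEX_STRING` is taken from `python/sglang/lang/ir.py`, shown in the diff below):

```python
import re

REGEX_STRING = r"\"[\w\d\s]*\""  # from python/sglang/lang/ir.py
REGEX_LIST = r"\[(" + REGEX_STRING + ", )*" + REGEX_STRING + r"\]"

# A well-formed JSON list of simple strings matches the pattern...
assert re.fullmatch(REGEX_LIST, '["Eiffel Tower", "Louvre", "Arc de Triomphe"]')
# ...while punctuation outside [\w\d\s] (here, the hyphen) does not.
assert not re.fullmatch(REGEX_LIST, '["Saint-Denis"]')
```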
`benchmark/json_regex_decode/bench_sglang.py` (new file, mode 100644)

```python
import argparse
import json
import time

import sglang as sgl
from sglang.lang.ir import REGEX_INT, REGEX_STRING, REGEX_FLOAT
from sglang.test.test_utils import (
    add_common_sglang_args_and_parse,
    select_sglang_backend,
)
from sglang.utils import dump_state_text, read_jsonl

REGEX_LIST = r"\[(" + REGEX_STRING + ", )*" + REGEX_STRING + r"\]"


# fmt: off
@sgl.function
def json_warm_up(s):
    s += "The information about Hogwarts is in the following JSON format.\n"
    with s.var_scope("json_output"):
        s += "{\n"
        s += ' "name": ' + sgl.gen("name", max_tokens=8, regex=REGEX_STRING + ",") + "\n"
        s += ' "country": ' + sgl.gen("country", max_tokens=8, regex=REGEX_STRING + ",") + "\n"
        s += ' "latitude": ' + sgl.gen("latitude", max_tokens=8, regex=REGEX_FLOAT + ",") + "\n"
        s += ' "population": ' + sgl.gen("population", max_tokens=8, regex=REGEX_INT + ",") + "\n"
        s += ' "top 3 landmarks": ' + sgl.gen("landmarks", max_tokens=24, regex=REGEX_LIST) + "\n"
        s += "}\n"
    print(f'The warm up json result is:\n{s["json_output"]}')
# fmt: on


# fmt: off
@sgl.function
def json_decode(s, document):
    s += "Please extract the information of a city from the following wikipedia page.\n"
    s += "Page begin.\n" + document + "Page end.\n"
    s += "Here is the name, country, and symbol of the city in JSON format.\n"
    with s.var_scope("json_output"):
        s += "{\n"
        s += ' "name": ' + sgl.gen("name", max_tokens=8, regex=REGEX_STRING + ",") + "\n"
        s += ' "country": ' + sgl.gen("country", max_tokens=8, regex=REGEX_STRING + ",") + "\n"
        s += ' "latitude": ' + sgl.gen("latitude", max_tokens=8, regex=REGEX_FLOAT + ",") + "\n"
        s += ' "population": ' + sgl.gen("population", max_tokens=8, regex=REGEX_INT + ",") + "\n"
        s += ' "top 3 landmarks": ' + sgl.gen("landmarks", max_tokens=24, regex=REGEX_LIST) + "\n"
        s += "}\n"
# fmt: on


def main(args):
    lines = read_jsonl(args.data_path)
    arguments = []
    for i in range(len(lines[: args.num_questions])):
        arguments.append(
            {
                "document": lines[i]["document"],
            }
        )

    # Select backend
    backend = select_sglang_backend(args)
    sgl.set_default_backend(backend)

    # Warm up
    json_warm_up.run().sync()

    # Run requests
    tic = time.time()
    states = json_decode.run_batch(arguments, temperature=0, num_threads=args.parallel)
    for state in states:
        state.sync()
    latency = time.time() - tic

    # Compute accuracy
    print(f"Latency: {latency:.3f}")

    # Write results
    dump_state_text(f"tmp_output_{args.backend}.txt", states)

    with open(f"tmp_{args.backend}_json_results.txt", "w") as fout:
        for state in states:
            fout.write(state["json_output"] + "\n")

    with open(args.result_file, "a") as fout:
        value = {
            "task": "json_regex_decode",
            "backend": args.backend,
            "num_gpus": 1,
            "latency": round(latency, 3),
            "num_requests": args.num_questions,
            "other": {
                "parallel": args.parallel,
            },
        }
        fout.write(json.dumps(value) + "\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-path", type=str, default="questions.jsonl")
    parser.add_argument("--num-questions", type=int, default=20)
    args = add_common_sglang_args_and_parse(parser)
    main(args)
```
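Because every field is regex-constrained and the braces are written as constant text, each captured `json_output` should parse as JSON. A minimal check, as a sketch that could be appended to `main` (assuming every state completed all fields):

```python
import json

for state in states:
    city = json.loads(state["json_output"])  # raises if a field went off-format
    print(city["name"], city["population"])
```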
`benchmark/json_regex_decode/build_dataset.py` (new file, mode 100644)

```python
import json

import transformers
import wikipedia

model_path = "meta-llama/Llama-2-7b-chat-hf"
t = transformers.AutoTokenizer.from_pretrained(model_path)

city_names = [
    "los angeles",
    "london",
    "tokyo",
    "beijing",
    "singapore",
    "paris",
    "dubai",
    "sydney",
    "moscow",
    "rome",
    "toronto",
    "rio de janeiro",
    "istanbul",
    "berlin",
    "auckland",
    "buenos aires",
    "mexico city",
    "mumbai",
    "seoul",
    "bangkok",
    "cairo",
    "athens",
    "jerusalem",
]


def get_content(city_name):
    content = str(wikipedia.page(city_name).content)
    content = content.replace("\n\n", "\n")

    tokens = t.encode(content)
    expected_tokens = 3000
    truncate_len = int((expected_tokens / len(tokens)) * len(content))
    truncate_content = content[:truncate_len]
    truncate_tokens = t.encode(truncate_content)

    # Count tokens
    print(
        f"city_name: {city_name}, #tokens: {len(tokens)}, "
        f"#truncate tokens: {len(truncate_tokens)}"
    )

    return truncate_content


if __name__ == "__main__":
    with open("questions.jsonl", "w") as fout:
        for city_name in city_names:
            truncate_content = get_content(city_name)
            fout.write(json.dumps({"document": truncate_content}) + "\n")
```
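The truncation is proportional rather than exact: it scales the character length by the ratio of the 3000-token budget to the full token count, so the result only lands near 3000 tokens. A worked example with illustrative numbers (not from the source):

```python
# Suppose a page is 48,000 characters long and tokenizes to 12,000 tokens.
expected_tokens = 3000
truncate_len = int((expected_tokens / 12_000) * 48_000)  # = 12,000 characters
# Re-encoding those 12,000 characters yields roughly, not exactly, 3000 tokens,
# since token density varies across the document.
```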
`benchmark/multi_turns/README.md` (new file, mode 100644)
### Benchmark sglang

Run llama-7b

```
python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
```

Run mixtral-8x7b (if you hit a CUDA out-of-memory error, try reducing `--mem-fraction-static`)

```
python3 -m sglang.launch_server --model-path mistralai/Mixtral-8x7B-Instruct-v0.1 --port 30000 --tp-size 8
```

Benchmark (short output)

```
python3 bench_sglang.py --tokenizer meta-llama/Llama-2-7b-chat-hf
```

Benchmark (long output)

```
python3 bench_sglang.py --tokenizer meta-llama/Llama-2-7b-chat-hf --long
```

### Benchmark vLLM

Run llama-7b

```
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model meta-llama/Llama-2-7b-chat-hf --disable-log-requests --port 21000
```

Run mixtral-8x7b

```
python3 -m vllm.entrypoints.api_server --tokenizer-mode auto --model mistralai/Mixtral-8x7B-Instruct-v0.1 --disable-log-requests --port 21000 --tensor-parallel-size 8
```

Benchmark (short output)

```
python3 bench_other.py --tokenizer meta-llama/Llama-2-7b-chat-hf --backend vllm
```

Benchmark (long output)

```
python3 bench_other.py --tokenizer meta-llama/Llama-2-7b-chat-hf --backend vllm --long
```

### Benchmark guidance

Benchmark llama-7b (short output)

```
python3 bench_other.py --tokenizer meta-llama/Llama-2-7b-chat-hf --backend guidance --parallel 1
```

Benchmark llama-7b (long output)

```
python3 bench_other.py --tokenizer meta-llama/Llama-2-7b-chat-hf --backend guidance --parallel 1 --long
```
`benchmark/multi_turns/bench_other.py` (new file, mode 100644)

```python
import json
import time
from argparse import ArgumentParser
from concurrent.futures import ThreadPoolExecutor

import requests
from sglang.test.test_utils import add_common_other_args_and_parse
from sglang.utils import dump_state_text
from tqdm import tqdm
from vllm.transformers_utils.tokenizer import get_tokenizer

from data_gen import gen_arguments


def get_generate(args):
    # Select backend
    if args.backend == "vllm":
        url = f"{args.host}:{args.port}/generate"

        def generate(prompt, max_tokens, stop=None, temperature=0, url=url, n=1):
            data = {
                "prompt": prompt,
                "temperature": temperature,
                "max_tokens": max_tokens,
                "ignore_eos": True,
                "stop": stop,
                "stream": False,
                "n": n,
            }
            res = requests.post(url, json=data)
            assert res.status_code == 200
            return res.json()["text"][0][len(prompt):]

    elif args.backend == "guidance":
        from guidance import gen, models

        model = models.LlamaCpp(
            "/home/ubuntu/model_weights/Llama-2-7b-chat-hf/ggml-model-f16.gguf",
            n_gpu_layers=-1,
            n_ctx=4096,
        )

        def generate(prompt, max_tokens, stop=None):
            out = (
                model
                + prompt
                + gen(name="answer", max_tokens=max_tokens, temperature=0, stop=stop)
            )
            return out["answer"]

        # warmup
        for _ in range(3):
            generate("Hello!" * 10, max_tokens=64, stop=None)
    else:
        raise ValueError(f"Invalid backend: {args.backend}")

    return generate


def multi_turns(generate, qas):
    s = ""
    for qa in qas:
        s += qa["prompt"]
        s += generate(s, max_tokens=qa["new_tokens"])

    return s


def main(args):
    print(args)

    tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)

    multi_qas = gen_arguments(args, tokenizer)

    states = [None] * args.num_qa

    generate = get_generate(args)

    def get_one_answer(i):
        states[i] = multi_turns(generate=generate, **multi_qas[i])

    tic = time.time()
    if args.parallel == 1:
        for i in tqdm(range(len(multi_qas))):
            get_one_answer(i)
    else:
        with ThreadPoolExecutor(args.parallel) as executor:
            rets = executor.map(get_one_answer, list(range(len(multi_qas))))
            for _ in rets:
                pass

    latency = time.time() - tic

    # Compute accuracy
    print(f"Latency: {latency:.3f}")

    dump_state_text(f"tmp_output_{args.backend}.txt", states)

    with open(args.result_file, "a") as fout:
        value = {
            "task": "multi_turns",
            "backend": args.backend,
            "num_gpus": 1,
            "latency": round(latency, 3),
            "num_requests": args.num_qa,
            "num_turns": args.turns,
            "other": {
                "parallel": args.parallel,
                "output_mode": "long" if args.long else "short",
            },
        }
        fout.write(json.dumps(value) + "\n")


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--turns", type=int, default=4)
    parser.add_argument("--num-qa", type=int, default=20)
    parser.add_argument("--min-len-q", type=int, default=256)
    parser.add_argument("--max-len-q", type=int, default=512)
    parser.add_argument("--min-len-a", type=int, default=4)
    parser.add_argument("--max-len-a", type=int, default=8)
    parser.add_argument("--tokenizer", type=str, required=True)
    parser.add_argument("--trust-remote-code", action="store_true")
    parser.add_argument("--long", action="store_true")
    args = add_common_other_args_and_parse(parser)

    if args.long:
        args.min_len_a = 256
        args.max_len_a = 512
        args.num_qa = 20
    main(args)
```
`benchmark/multi_turns/bench_sglang.py` (new file, mode 100644)

```python
import json
import time
from argparse import ArgumentParser

import sglang as sgl
from sglang.test.test_utils import (
    add_common_sglang_args_and_parse,
    select_sglang_backend,
)
from sglang.utils import dump_state_text
from vllm.transformers_utils.tokenizer import get_tokenizer

from data_gen import gen_arguments


@sgl.function
def multi_turns(s, qas):
    for qa in qas:
        s += qa["prompt"]
        s += sgl.gen(max_tokens=qa["new_tokens"], ignore_eos=True)


def main(args):
    print(args)

    tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)

    multi_qas = gen_arguments(args, tokenizer)

    backend = select_sglang_backend(args)

    tic = time.time()
    states = multi_turns.run_batch(
        multi_qas, temperature=0, backend=backend, num_threads=args.parallel
    )
    for state in states:
        state.sync()
    latency = time.time() - tic

    print(f"Latency: {latency:.3f}")

    dump_state_text(f"tmp_output_{args.backend}.txt", states)

    with open(args.result_file, "a") as fout:
        value = {
            "task": "multi_turns",
            "backend": args.backend,
            "num_gpus": 1,
            "latency": round(latency, 3),
            "num_requests": args.num_qa,
            "num_turns": args.turns,
            "other": {
                "parallel": args.parallel,
                "output_mode": "long" if args.long else "short",
            },
        }
        fout.write(json.dumps(value) + "\n")


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--turns", type=int, default=4)
    parser.add_argument("--num-qa", type=int, default=20)
    parser.add_argument("--min-len-q", type=int, default=256)
    parser.add_argument("--max-len-q", type=int, default=512)
    parser.add_argument("--min-len-a", type=int, default=4)
    parser.add_argument("--max-len-a", type=int, default=8)
    parser.add_argument("--tokenizer", type=str, required=True)
    parser.add_argument("--trust-remote-code", action="store_true")
    parser.add_argument("--long", action="store_true")
    args = add_common_sglang_args_and_parse(parser)

    if args.long:
        args.min_len_a = 256
        args.max_len_a = 512
        args.num_qa = 20
    main(args)
```
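One detail worth noting across the two scripts above: the vLLM request payload and the `sgl.gen` call both set `ignore_eos=True`, so each turn emits exactly `new_tokens` tokens even if the model would naturally stop earlier, which keeps per-backend latencies comparable; the guidance path in `bench_other.py` has no such flag in this script.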
`benchmark/multi_turns/data_gen.py` (new file, mode 100644)

```python
import random
import string

random.seed(42)


def gen_prompt(tokenizer, token_num):
    cha_set = string.ascii_letters + string.digits
    ret = "".join(random.choices(cha_set, k=token_num))
    while len(tokenizer(ret).input_ids) < token_num:
        ret += random.choice(cha_set)

    return ret


def gen_arguments(args, tokenizer):
    multi_qas = [{"qas": []} for _ in range(args.num_qa)]
    for i in range(args.num_qa):
        qas = multi_qas[i]["qas"]
        for _ in range(args.turns):
            prompt_len = random.randint(args.min_len_q, args.max_len_q)
            new_tokens = random.randint(args.min_len_a, args.max_len_a)
            qas.append(
                {
                    "prompt": gen_prompt(tokenizer, prompt_len),
                    "new_tokens": new_tokens,
                }
            )
    return multi_qas
```
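`gen_prompt` starts from `token_num` random characters and keeps appending one character at a time until the string encodes to at least `token_num` tokens, so prompt lengths are controlled in tokens rather than characters. A quick way to see that, as a sketch run from the benchmark directory (any HuggingFace tokenizer works):

```python
from transformers import AutoTokenizer

from data_gen import gen_prompt

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
prompt = gen_prompt(tokenizer, 256)
# At least 256 tokens; appending a character can add more than one token,
# so the final count may slightly overshoot the target.
print(len(tokenizer(prompt).input_ids))
```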
`python/sglang/api.py` (+6, -0): threads a new `regex` argument through `gen()` into `SglGen`; `gen_int` and `gen_string` pass `None` for it.

```diff
@@ -37,6 +37,7 @@ def gen(
     top_k: Optional[int] = None,
     frequency_penalty: Optional[float] = None,
     presence_penalty: Optional[float] = None,
     ignore_eos: Optional[bool] = None,
     dtype: Optional[type] = None,
     choices: Optional[List[str]] = None,
+    regex: Optional[str] = None,
@@ -60,6 +61,7 @@ def gen(
         top_k,
         frequency_penalty,
         presence_penalty,
         ignore_eos,
         dtype,
+        regex,
     )
@@ -74,6 +76,7 @@ def gen_int(
     top_k: Optional[int] = None,
     frequency_penalty: Optional[float] = None,
     presence_penalty: Optional[float] = None,
     ignore_eos: Optional[bool] = None,
 ):
     return SglGen(
         name,
@@ -84,6 +87,7 @@ def gen_int(
         top_k,
         frequency_penalty,
         presence_penalty,
         ignore_eos,
         int,
+        None,
     )
@@ -98,6 +102,7 @@ def gen_string(
     top_k: Optional[int] = None,
     frequency_penalty: Optional[float] = None,
     presence_penalty: Optional[float] = None,
     ignore_eos: Optional[bool] = None,
 ):
     return SglGen(
         name,
@@ -108,6 +113,7 @@ def gen_string(
         top_k,
         frequency_penalty,
         presence_penalty,
         ignore_eos,
         str,
+        None,
     )
```
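With this change, `sgl.gen` accepts a `regex` keyword that constrains decoding, which is exactly what the `json_regex_decode` benchmark above relies on. A minimal usage sketch (`REGEX_STRING` comes from `python/sglang/lang/ir.py`):

```python
import sglang as sgl
from sglang.lang.ir import REGEX_STRING


@sgl.function
def name_only(s, document):
    s += document + '\nThe city name in quotes: '
    # Decoding is constrained: the output must match the given regex.
    s += sgl.gen("name", max_tokens=8, regex=REGEX_STRING)
```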
`python/sglang/backend/anthropic.py` (+3, -3): renames `SamplingParams` to `SglSamplingParams` in the import and type hints.

```diff
@@ -4,7 +4,7 @@ import numpy as np
 from sglang.backend.base_backend import BaseBackend
 from sglang.lang.chat_template import get_chat_template
 from sglang.lang.interpreter import StreamExecutor
-from sglang.lang.ir import SamplingParams
+from sglang.lang.ir import SglSamplingParams

 try:
     import anthropic
@@ -28,7 +28,7 @@ class Anthropic(BaseBackend):
     def generate(
         self,
         s: StreamExecutor,
-        sampling_params: SamplingParams,
+        sampling_params: SglSamplingParams,
     ):
         prompt = s.text_
         ret = anthropic.Anthropic().completions.create(
@@ -43,7 +43,7 @@ class Anthropic(BaseBackend):
     def generate_stream(
         self,
         s: StreamExecutor,
-        sampling_params: SamplingParams,
+        sampling_params: SglSamplingParams,
     ):
         prompt = s.text_
         generator = anthropic.Anthropic().completions.create(
```
`python/sglang/backend/base_backend.py` (+3, -3): the same rename in the abstract backend interface.

```diff
@@ -2,7 +2,7 @@ from typing import Callable, List, Optional, Union
 from sglang.lang.chat_template import get_chat_template
 from sglang.lang.interpreter import StreamExecutor
-from sglang.lang.ir import SamplingParams
+from sglang.lang.ir import SglSamplingParams


 class BaseBackend:
@@ -48,14 +48,14 @@ class BaseBackend:
     def generate(
         self,
         s: StreamExecutor,
-        sampling_params: SamplingParams,
+        sampling_params: SglSamplingParams,
     ):
         raise NotImplementedError()

     def generate_stream(
         self,
         s: StreamExecutor,
-        sampling_params: SamplingParams,
+        sampling_params: SglSamplingParams,
     ):
         raise NotImplementedError()
```
`python/sglang/backend/openai.py` (+3, -3): the same rename.

```diff
@@ -4,7 +4,7 @@ import numpy as np
 from sglang.backend.base_backend import BaseBackend
 from sglang.lang.chat_template import get_chat_template
 from sglang.lang.interpreter import StreamExecutor
-from sglang.lang.ir import SamplingParams
+from sglang.lang.ir import SglSamplingParams

 try:
     import openai
@@ -73,7 +73,7 @@ class OpenAI(BaseBackend):
     def generate(
         self,
         s: StreamExecutor,
-        sampling_params: SamplingParams,
+        sampling_params: SglSamplingParams,
     ):
         if sampling_params.dtype is None:
             if self.is_chat_model:
@@ -122,7 +122,7 @@ class OpenAI(BaseBackend):
     def generate_stream(
         self,
         s: StreamExecutor,
-        sampling_params: SamplingParams,
+        sampling_params: SglSamplingParams,
     ):
         if sampling_params.dtype is None:
             if self.is_chat_model:
```
`python/sglang/backend/runtime_endpoint.py` (+3, -3): the same rename.

```diff
@@ -7,7 +7,7 @@ from sglang.backend.base_backend import BaseBackend
 from sglang.global_config import global_config
 from sglang.lang.chat_template import get_chat_template_by_model_path
 from sglang.lang.interpreter import StreamExecutor
-from sglang.lang.ir import SamplingParams, SglArgument
+from sglang.lang.ir import SglSamplingParams, SglArgument
 from sglang.utils import encode_image_base64, find_printable_text, http_request
@@ -55,7 +55,7 @@ class RuntimeEndpoint(BaseBackend):
     def generate(
         self,
         s: StreamExecutor,
-        sampling_params: SamplingParams,
+        sampling_params: SglSamplingParams,
     ):
         if sampling_params.dtype is None:
             data = {
@@ -87,7 +87,7 @@ class RuntimeEndpoint(BaseBackend):
     def generate_stream(
         self,
         s: StreamExecutor,
-        sampling_params: SamplingParams,
+        sampling_params: SglSamplingParams,
     ):
         if sampling_params.dtype is None:
             data = {
```
`python/sglang/backend/tgi.py` (+3, -3): the same rename.

```diff
@@ -7,7 +7,7 @@ from typing import List, Optional, Union
 from sglang.backend.base_backend import BaseBackend
 from sglang.lang.chat_template import get_chat_template_by_model_path
 from sglang.lang.interpreter import StreamExecutor
-from sglang.lang.ir import SamplingParams
+from sglang.lang.ir import SglSamplingParams
 from sglang.utils import http_request
@@ -138,7 +138,7 @@ class TGI(BaseBackend):
         self,
         s: StreamExecutor,
         choices: List[str],
-        sampling_params: SamplingParams,
+        sampling_params: SglSamplingParams,
     ):
         decision = self.retry_for_expected(
             s.text_,
@@ -152,7 +152,7 @@ class TGI(BaseBackend):
         s: StreamExecutor,
         max_tokens: int,
         stop: Union[str, List[str]],
-        sampling_params: SamplingParams,
+        sampling_params: SglSamplingParams,
         dtype: Optional[str] = None,
     ):
         if dtype is None:
```
`python/sglang/lang/compiler.py` (+3, -3): the same rename at the import and both `CompiledFunction` call sites.

```diff
@@ -6,7 +6,7 @@ from typing import List, Union
 from sglang.global_config import global_config
 from sglang.lang.interpreter import ProgramState, StreamExecutor, pin_program
 from sglang.lang.ir import (
-    SamplingParams,
+    SglSamplingParams,
     SglArgument,
     SglConstantText,
     SglExpr,
@@ -140,7 +140,7 @@ class CompiledFunction:
         kwargs = {k: SglArgument(k, v) for k, v in kwargs.items()}
         kwargs.update(self.function.bind_arguments)
-        default_sampling_para = SamplingParams(
+        default_sampling_para = SglSamplingParams(
             max_new_tokens=max_new_tokens,
             stop=stop,
             temperature=temperature,
@@ -173,7 +173,7 @@ class CompiledFunction:
         backend = backend or global_config.default_backend
-        default_sampling_para = SamplingParams(
+        default_sampling_para = SglSamplingParams(
             max_new_tokens=max_new_tokens,
             stop=stop,
             temperature=temperature,
```
`python/sglang/lang/interpreter.py` (+3, -4): drops the `SglArgument` special cases and forwards `regex` with the other sampling arguments.

```diff
@@ -292,7 +292,7 @@ class StreamExecutor:
         assert isinstance(other, SglExpr), f"{other}"

-        if isinstance(other, (SglConstantText, SglArgument)):
+        if isinstance(other, SglConstantText):
             self._execute_fill(other.value)
         elif isinstance(other, SglGen):
             self._execute_gen(other)
@@ -332,8 +332,6 @@ class StreamExecutor:
     def _execute_image(self, expr: SglImage):
         path = expr.path
-        if isinstance(path, SglArgument):
-            path = path.value

         base64_data = encode_image_base64(path)
@@ -419,7 +417,7 @@ class StreamExecutor:
                 "role": expr.role,
                 "content": [{"type": "text", "text": new_text}],
             }
-            for (image_path, image_base64_data) in self.cur_images:
+            for image_path, image_base64_data in self.cur_images:
                 last_msg["content"].append(
                     {
                         "type": "image_url",
@@ -480,6 +478,7 @@ class StreamExecutor:
             "top_k",
             "frequency_penalty",
             "presence_penalty",
             "ignore_eos",
             "dtype",
+            "regex",
         ]:
```
`python/sglang/lang/ir.py` (+13, -9): renames the dataclass `SamplingParams` to `SglSamplingParams` and adds a `regex` field that is forwarded through `clone()`, the runtime kwargs, and `SglGen`.

```diff
@@ -13,7 +13,7 @@ REGEX_STRING = r"\"[\w\d\s]*\""  # bugs with regex r"\".*\"" in interegular pkg
 @dataclasses.dataclass
-class SamplingParams:
+class SglSamplingParams:
     max_new_tokens: int = 16
     stop: Union[str, List[str]] = ()
     temperature: float = 1.0
@@ -21,13 +21,14 @@ class SglSamplingParams:
     top_k: int = -1  # -1 means disable
     frequency_penalty: float = 0.0
     presence_penalty: float = 0.0
     ignore_eos: bool = False
     # for constrained generation, not included in to_xxx_kwargs
     dtype: Optional[str] = None
+    regex: Optional[str] = None

     def clone(self):
-        return SamplingParams(
+        return SglSamplingParams(
             self.max_new_tokens,
             self.stop,
             self.temperature,
@@ -67,6 +68,7 @@ class SglSamplingParams:
             "top_k": self.top_k,
             "frequency_penalty": self.frequency_penalty,
             "presence_penalty": self.presence_penalty,
             "ignore_eos": self.ignore_eos,
+            "regex": self.regex,
         }
@@ -98,13 +100,14 @@ class SglFunction:
         top_k: int = -1,
         frequency_penalty: float = 0.0,
         presence_penalty: float = 0.0,
         ignore_eos: bool = False,
         stream: bool = False,
         backend=None,
         **kwargs,
     ):
         from sglang.lang.interpreter import run_program

-        default_sampling_para = SamplingParams(
+        default_sampling_para = SglSamplingParams(
             max_new_tokens=max_new_tokens,
             stop=stop,
             temperature=temperature,
@@ -112,9 +115,9 @@ class SglFunction:
             top_k=top_k,
             frequency_penalty=frequency_penalty,
             presence_penalty=presence_penalty,
             ignore_eos=ignore_eos,
         )
         backend = backend or global_config.default_backend
         kwargs = {k: SglArgument(k, v) for k, v in kwargs.items()}
         return run_program(self, backend, args, kwargs, default_sampling_para, stream)

     def run_batch(
@@ -128,6 +131,7 @@ class SglFunction:
         top_k: int = -1,
         frequency_penalty: float = 0.0,
         presence_penalty: float = 0.0,
         ignore_eos: bool = False,
         backend=None,
         num_threads: Union[str, int] = "auto",
         progress_bar: bool = False,
@@ -139,7 +143,7 @@ class SglFunction:
             return []
         assert isinstance(batch_kwargs[0], dict)

-        default_sampling_para = SamplingParams(
+        default_sampling_para = SglSamplingParams(
             max_new_tokens=max_new_tokens,
             stop=stop,
             temperature=temperature,
@@ -147,11 +151,9 @@ class SglFunction:
             top_k=top_k,
             frequency_penalty=frequency_penalty,
             presence_penalty=presence_penalty,
             ignore_eos=ignore_eos,
         )
         backend = backend or global_config.default_backend
         batch_kwargs = [
             {k: SglArgument(k, v) for k, v in kwargs.items()} for kwargs in batch_kwargs
         ]
         return run_program_batch(
             self,
             backend,
@@ -321,12 +323,13 @@ class SglGen(SglExpr):
         top_k,
         frequency_penalty,
         presence_penalty,
         ignore_eos,
         dtype,
+        regex,
     ):
         super().__init__()
         self.name = name
-        self.sampling_params = SamplingParams(
+        self.sampling_params = SglSamplingParams(
             max_new_tokens=max_new_tokens,
             stop=stop,
             temperature=temperature,
@@ -334,6 +337,7 @@ class SglGen(SglExpr):
             top_k=top_k,
             frequency_penalty=frequency_penalty,
             presence_penalty=presence_penalty,
             ignore_eos=ignore_eos,
             dtype=dtype,
+            regex=regex,
         )
```
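Downstream, `regex` travels with the rest of the sampling arguments. A sketch of the caller-side object (field names taken from the diff above; the illustrative integer pattern is not from the source, and the assumption is that the kwargs dict in the `@@ -67` hunk belongs to `to_srt_kwargs()`):

```python
from sglang.lang.ir import SglSamplingParams

params = SglSamplingParams(max_new_tokens=8, temperature=0.0, regex=r"[-+]?[0-9]+")
# Per the comment in the dataclass, dtype stays out of the to_xxx_kwargs dicts,
# but regex is forwarded to the SRT runtime with the other sampling arguments.
print(params.to_srt_kwargs()["regex"])
```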
`python/sglang/lang/tracer.py` (+2, -1): also swallows `TypeError` while tracing.

```diff
@@ -40,7 +40,8 @@ def extract_prefix_by_tracing(program, backend):
     try:
         with TracingScope(tracer):
             tracer.ret_value = program.func(tracer, **arguments)
-    except StopTracing:
+    except (StopTracing, TypeError):
+        # Some exceptions may not be caught
         pass

     # Run and cache prefix
```
`python/sglang/srt/backend_config.py` (new file, mode 100644)

```python
"""
Backend configurations, may vary with different serving platforms.
"""

from dataclasses import dataclass


@dataclass
class BackendConfig:
    extend_dependency_time: float = 0.03


GLOBAL_BACKEND_CONFIG = BackendConfig()
```
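A minimal usage sketch (assuming the module-level singleton is the intended access pattern; nothing shown on this page reads the field yet):

```python
from sglang.srt.backend_config import GLOBAL_BACKEND_CONFIG

# Read the platform-tunable delay; mutate the singleton to override it.
print(GLOBAL_BACKEND_CONFIG.extend_dependency_time)  # 0.03 by default
GLOBAL_BACKEND_CONFIG.extend_dependency_time = 0.05
```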