change / sglang · Commit e4d68afc

[Minor] Many cleanup (#1357)

Authored Sep 09, 2024 by Lianmin Zheng; committed via GitHub on Sep 09, 2024 (unverified signature).
Parent commit: c9b75917

Showing 20 changed files (of 24 in this commit) with 376 additions and 254 deletions (+376 -254).
Changed files on this page:

  benchmark/gsm8k/README.md                                             +0   -5
  benchmark/gsm8k/bench_other.py                                       +18  -12
  benchmark/gsm8k/bench_sglang.py                                      +26  -13
  benchmark/gsm8k/download_data.sh                                      +0   -2
  benchmark/hellaswag/README.md                                         +0   -5
  benchmark/hellaswag/bench_other.py                                   +13  -10
  benchmark/hellaswag/bench_sglang.py                                  +14  -10
  examples/frontend_language/usage/llava_video/srt_example_llava_v.py   +2   -1
  python/sglang/bench_serving.py                                       +33  -38
  python/sglang/launch_server.py                                        +1   -2
  python/sglang/launch_server_llavavid.py                               +3   -1
  python/sglang/srt/constrained/fsm_cache.py                           +29  -38
  python/sglang/srt/managers/controller_multi.py                        +1   -5
  python/sglang/srt/managers/controller_single.py                       +0   -5
  python/sglang/srt/managers/tokenizer_manager.py                       +2   -2
  python/sglang/srt/managers/tp_worker.py                              +80  -77
  python/sglang/srt/model_executor/model_runner.py                      +1   -0
  python/sglang/srt/server.py                                           +3   -6
  python/sglang/srt/server_args.py                                     +18  -22
  python/sglang/test/few_shot_gsm8k.py                                +132   -0
benchmark/gsm8k/README.md  (+0 -5)

````diff
-## Download data
-
-```
-bash download_data.sh
-```
 ## Run benchmark
 ### Benchmark sglang
````
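The manual download step disappears from the README because the benchmark scripts now fetch and cache the dataset themselves (via the `download_and_cache_file` helper that appears in the diffs below). A likely invocation after this change, assuming a local sglang server is already running, is simply:

```
python3 bench_sglang.py --num-questions 200
```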
benchmark/gsm8k/bench_other.py  (+18 -12)

```diff
@@ -10,7 +10,7 @@ import numpy as np
 from tqdm import tqdm
 
 from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
-from sglang.utils import dump_state_text, read_jsonl
+from sglang.utils import download_and_cache_file, dump_state_text, read_jsonl
 
 INVALID = -9999999
@@ -41,24 +41,28 @@ def get_answer_value(answer_str):
 
 def main(args):
-    lines = read_jsonl(args.data_path)
+    # Select backend
+    call_generate = get_call_generate(args)
+
+    # Read data
+    url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
+    filename = download_and_cache_file(url)
+    lines = list(read_jsonl(filename))
 
     # Construct prompts
-    k = args.num_shot
-    few_shot_examples = get_few_shot_examples(lines, k)
+    num_questions = args.num_questions
+    num_shots = args.num_shots
+    few_shot_examples = get_few_shot_examples(lines, num_shots)
 
     questions = []
     labels = []
-    for i in range(len(lines[: args.num_questions])):
+    for i in range(len(lines[:num_questions])):
         questions.append(get_one_example(lines, i, False))
         labels.append(get_answer_value(lines[i]["answer"]))
     assert all(l != INVALID for l in labels)
 
     states = [None] * len(labels)
 
-    # Select backend
-    call_generate = get_call_generate(args)
-
     # Run requests
     if args.backend != "lmql":
         # Use thread pool
@@ -113,11 +117,13 @@ def main(args):
     # Compute accuracy
     acc = np.mean(np.array(preds) == np.array(labels))
     invalid = np.mean(np.array(preds) == INVALID)
-    print(f"Latency: {latency:.3f}")
-    print(f"Invalid: {invalid:.3f}")
+
+    # Print results
     print(f"Accuracy: {acc:.3f}")
+    print(f"Invalid: {invalid:.3f}")
+    print(f"Latency: {latency:.3f} s")
 
-    # Write results
+    # Dump results
     dump_state_text(f"tmp_output_{args.backend}.txt", states)
     with open(args.result_file, "a") as fout:
@@ -138,7 +144,7 @@ def main(args):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--num-shot", type=int, default=5)
+    parser.add_argument("--num-shots", type=int, default=5)
     parser.add_argument("--data-path", type=str, default="test.jsonl")
     parser.add_argument("--num-questions", type=int, default=200)
     args = add_common_other_args_and_parse(parser)
```
benchmark/gsm8k/bench_sglang.py  (+26 -13)

```diff
@@ -6,11 +6,12 @@ import time
 import numpy as np
 
+from sglang.api import set_default_backend
 from sglang.test.test_utils import (
     add_common_sglang_args_and_parse,
     select_sglang_backend,
 )
-from sglang.utils import dump_state_text, read_jsonl
+from sglang.utils import download_and_cache_file, dump_state_text, read_jsonl
 
 INVALID = -9999999
@@ -41,15 +42,22 @@ def get_answer_value(answer_str):
 
 def main(args):
-    lines = read_jsonl(args.data_path)
+    # Select backend
+    set_default_backend(select_sglang_backend(args))
+
+    # Read data
+    url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
+    filename = download_and_cache_file(url)
+    lines = list(read_jsonl(filename))
 
     # Construct prompts
-    k = args.num_shot
-    few_shot_examples = get_few_shot_examples(lines, k)
+    num_questions = args.num_questions
+    num_shots = args.num_shots
+    few_shot_examples = get_few_shot_examples(lines, num_shots)
 
     questions = []
     labels = []
-    for i in range(len(lines[: args.num_questions])):
+    for i in range(len(lines[:num_questions])):
         questions.append(get_one_example(lines, i, False))
         labels.append(get_answer_value(lines[i]["answer"]))
     assert all(l != INVALID for l in labels)
@@ -72,15 +80,11 @@ def main(args):
     ########## SGL Program End ##########
     #####################################
 
-    # Select backend
-    backend = select_sglang_backend(args)
-
     # Run requests
     tic = time.time()
     states = few_shot_gsm8k.run_batch(
         arguments,
         temperature=0,
-        backend=backend,
         num_threads=args.parallel,
         progress_bar=True,
     )
@@ -96,11 +100,20 @@ def main(args):
     # Compute accuracy
     acc = np.mean(np.array(preds) == np.array(labels))
     invalid = np.mean(np.array(preds) == INVALID)
-    print(f"Latency: {latency:.3f}")
-    print(f"Invalid: {invalid:.3f}")
+
+    # Compute speed
+    num_output_tokens = sum(
+        s.get_meta_info("answer")["completion_tokens"] for s in states
+    )
+    output_throughput = num_output_tokens / latency
+
+    # Print results
     print(f"Accuracy: {acc:.3f}")
+    print(f"Invalid: {invalid:.3f}")
+    print(f"Latency: {latency:.3f} s")
+    print(f"Output throughput: {output_throughput:.3f} token/s")
 
-    # Write results
+    # Dump results
     dump_state_text(f"tmp_output_{args.backend}.txt", states)
     with open(args.result_file, "a") as fout:
@@ -121,7 +134,7 @@ def main(args):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--num-shot", type=int, default=5)
+    parser.add_argument("--num-shots", type=int, default=5)
     parser.add_argument("--data-path", type=str, default="test.jsonl")
     parser.add_argument("--num-questions", type=int, default=200)
     args = add_common_sglang_args_and_parse(parser)
```
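Both GSM-8K scripts rename the flag `--num-shot` to `--num-shots` (the hellaswag scripts below get the same rename), so existing invocations need updating. A sketch of the new form, with values taken from the defaults in the diff above:

```
python3 bench_sglang.py --num-shots 5 --num-questions 200
```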
benchmark/gsm8k/download_data.sh  (+0 -2, deleted, mode 100755 → 0)

```diff
-wget https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/train.jsonl
-wget https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl
\ No newline at end of file
```
benchmark/hellaswag/README.md  (+0 -5)

````diff
-## Download data
-
-```
-wget https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl
-```
 ## Run benchmark
 ### Benchmark sglang
````
benchmark/hellaswag/bench_other.py  (+13 -10)

```diff
@@ -8,7 +8,7 @@ import numpy as np
 from tqdm import tqdm
 
 from sglang.test.test_utils import add_common_other_args_and_parse, get_call_select
-from sglang.utils import read_jsonl
+from sglang.utils import download_and_cache_file, read_jsonl
 
 
 def get_one_example(lines, i, include_answer):
@@ -26,25 +26,29 @@ def get_few_shot_examples(lines, k):
 
 def main(args):
-    lines = read_jsonl(args.data_path)
+    # Select backend
+    call_select = get_call_select(args)
+
+    # Read data
+    url = "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl"
+    filename = download_and_cache_file(url)
+    lines = list(read_jsonl(filename))
 
     # Construct prompts
-    k = args.num_shot
-    few_shot_examples = get_few_shot_examples(lines, k)
+    num_questions = args.num_questions
+    num_shots = args.num_shots
+    few_shot_examples = get_few_shot_examples(lines, num_shots)
 
     questions = []
     choices = []
     labels = []
-    for i in range(len(lines[: args.num_questions])):
+    for i in range(len(lines[:num_questions])):
         questions.append(get_one_example(lines, i, False))
         choices.append(lines[i]["endings"])
         labels.append(lines[i]["label"])
 
     preds = [None] * len(labels)
 
-    # Select backend
-    call_select = get_call_select(args)
-
     # Run requests
     if args.backend != "lmql":
         # Use thread pool
@@ -65,7 +69,6 @@ def main(args):
                     total=len(questions),
                 )
             )
-
     else:
         # Use asyncio
         async def batched_call(batch_size):
@@ -108,7 +111,7 @@ def main(args):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--num-shot", type=int, default=20)
+    parser.add_argument("--num-shots", type=int, default=20)
     parser.add_argument("--data-path", type=str, default="hellaswag_val.jsonl")
     parser.add_argument("--num-questions", type=int, default=200)
     args = add_common_other_args_and_parse(parser)
```
benchmark/hellaswag/bench_sglang.py  (+14 -10)

```diff
@@ -4,11 +4,12 @@ import time
 import numpy as np
 
+from sglang.api import set_default_backend
 from sglang.test.test_utils import (
     add_common_sglang_args_and_parse,
     select_sglang_backend,
 )
-from sglang.utils import read_jsonl
+from sglang.utils import download_and_cache_file, read_jsonl
 
 
 def get_one_example(lines, i, include_answer):
@@ -26,16 +27,23 @@ def get_few_shot_examples(lines, k):
 
 def main(args):
-    lines = read_jsonl(args.data_path)
+    # Select backend
+    set_default_backend(select_sglang_backend(args))
+
+    # Read data
+    url = "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl"
+    filename = download_and_cache_file(url)
+    lines = list(read_jsonl(filename))
 
     # Construct prompts
-    k = args.num_shot
-    few_shot_examples = get_few_shot_examples(lines, k)
+    num_questions = args.num_questions
+    num_shots = args.num_shots
+    few_shot_examples = get_few_shot_examples(lines, num_shots)
 
     questions = []
     choices = []
     labels = []
-    for i in range(len(lines[: args.num_questions])):
+    for i in range(len(lines[:num_questions])):
         questions.append(get_one_example(lines, i, False))
         choices.append(lines[i]["endings"])
         labels.append(lines[i]["label"])
@@ -56,15 +64,11 @@ def main(args):
     ########## SGL Program End ##########
     #####################################
 
-    # Select backend
-    backend = select_sglang_backend(args)
-
     # Run requests
     tic = time.time()
     rets = few_shot_hellaswag.run_batch(
         arguments,
         temperature=0,
-        backend=backend,
         num_threads=args.parallel,
         progress_bar=True,
     )
@@ -95,7 +99,7 @@ def main(args):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--num-shot", type=int, default=20)
+    parser.add_argument("--num-shots", type=int, default=20)
     parser.add_argument("--data-path", type=str, default="hellaswag_val.jsonl")
     parser.add_argument("--num-questions", type=int, default=200)
     args = add_common_sglang_args_and_parse(parser)
```
examples/frontend_language/usage/llava_video/srt_example_llava_v.py  (+2 -1)

```diff
@@ -7,6 +7,7 @@ python3 srt_example_llava_v.py
 import argparse
 import csv
+import json
 import os
 import time
@@ -223,7 +224,7 @@ if __name__ == "__main__":
         tokenizer_path=tokenizer_path,
         port=cur_port,
         additional_ports=[cur_port + 1, cur_port + 2, cur_port + 3, cur_port + 4],
-        model_override_args=model_override_args,
+        json_model_override_args=json.dumps(model_override_args),
         tp_size=1,
     )
     sgl.set_default_backend(runtime)
```
python/sglang/bench_serving.py  (+33 -38)

```diff
@@ -298,34 +298,41 @@ class BenchmarkMetrics:
     median_e2e_latency_ms: float
 
 
-default_sharegpt_path = "ShareGPT_V3_unfiltered_cleaned_split.json"
+SHAREGPT_URL = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json"
 
 
-def download_sharegpt_dataset(path):
-    url = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json"
+def download_and_cache_file(url: str, filename: Optional[str] = None):
+    """Read and cache a file from a url."""
+    if filename is None:
+        filename = os.path.join("/tmp", url.split("/")[-1])
 
-    print(f"Downloading dataset from {url}")
-    try:
-        response = requests.get(url, stream=True)
-        response.raise_for_status()
+    # Check if the cache file already exists
+    if os.path.exists(filename):
+        return filename
 
-        total_size = int(response.headers.get("content-length", 0))
-        block_size = 8192
+    print(f"Downloading from {url} to {filename}")
 
-        with open(path, "wb") as f, tqdm(
-            desc="Downloading",
-            total=total_size,
-            unit="iB",
-            unit_scale=True,
-            unit_divisor=1024,
-        ) as progress_bar:
-            for data in response.iter_content(block_size):
-                size = f.write(data)
-                progress_bar.update(size)
+    # Stream the response to show the progress bar
+    response = requests.get(url, stream=True)
+    response.raise_for_status()  # Check for request errors
 
-        print(f"Dataset downloaded and saved to {path}")
-    except requests.RequestException as e:
-        raise Exception(f"Failed to download dataset: {e}")
+    # Total size of the file in bytes
+    total_size = int(response.headers.get("content-length", 0))
+    chunk_size = 1024  # Download in chunks of 1KB
+
+    # Use tqdm to display the progress bar
+    with open(filename, "wb") as f, tqdm(
+        desc=filename,
+        total=total_size,
+        unit="B",
+        unit_scale=True,
+        unit_divisor=1024,
+    ) as bar:
+        for chunk in response.iter_content(chunk_size=chunk_size):
+            f.write(chunk)
+            bar.update(len(chunk))
+
+    return filename
 
 
 def sample_sharegpt_requests(
@@ -338,13 +345,8 @@ def sample_sharegpt_requests(
         raise ValueError("output_len too small")
 
     # Download sharegpt if necessary
-    if not os.path.isfile(dataset_path) and not os.path.isfile(default_sharegpt_path):
-        download_sharegpt_dataset(default_sharegpt_path)
-        dataset_path = default_sharegpt_path
-    else:
-        dataset_path = (
-            dataset_path if os.path.isfile(dataset_path) else default_sharegpt_path
-        )
+    if not os.path.isfile(dataset_path):
+        dataset_path = download_and_cache_file(SHAREGPT_URL)
 
     # Load the dataset.
     with open(dataset_path) as f:
@@ -412,15 +414,8 @@ def sample_random_requests(
     # Sample token ids from ShareGPT and repeat/truncate them to satisfy the input_lens
 
     # Download sharegpt if necessary
-    if not os.path.isfile(dataset_path) and not os.path.isfile(default_sharegpt_path):
-        download_sharegpt_dataset(default_sharegpt_path)
-        dataset_path = default_sharegpt_path
-    else:
-        dataset_path = (
-            dataset_path if os.path.isfile(dataset_path) else default_sharegpt_path
-        )
+    if not os.path.isfile(dataset_path):
+        dataset_path = download_and_cache_file(SHAREGPT_URL)
 
     # Load the dataset.
     with open(dataset_path) as f:
```
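The refactor replaces the one-off `download_sharegpt_dataset` with a generic, idempotent `download_and_cache_file`. A minimal sketch of the intended behavior, using only names from the diff above:

```python
# First call streams the file to /tmp/ShareGPT_V3_unfiltered_cleaned_split.json;
# a second call sees the cached file and returns the same path without re-downloading.
path1 = download_and_cache_file(SHAREGPT_URL)
path2 = download_and_cache_file(SHAREGPT_URL)
assert path1 == path2  # cache hit, no second download
```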
python/sglang/launch_server.py  (+1 -2)

```diff
@@ -9,10 +9,9 @@ from sglang.srt.utils import kill_child_process
 if __name__ == "__main__":
     server_args = prepare_server_args(sys.argv[1:])
 
-    model_override_args = server_args.json_model_override_args
     try:
-        launch_server(server_args, model_override_args=model_override_args)
+        launch_server(server_args)
     except Exception as e:
         raise e
     finally:
```
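With the separate `model_override_args` parameter gone, overrides travel inside `ServerArgs` as a JSON string. A hedged example of the CLI shape this enables (the model path and the override key are illustrative placeholders, not taken from this diff):

```
python3 -m sglang.launch_server --model-path <your-model> \
    --json-model-override-args '{"model_max_length": 8192}'
```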
python/sglang/launch_server_llavavid.py  (+3 -1)

```diff
 """Launch the inference server for Llava-video model."""
 
+import json
 import sys
 
 from sglang.srt.server import launch_server, prepare_server_args
@@ -19,5 +20,6 @@ if __name__ == "__main__":
     model_override_args["model_max_length"] = 4096 * 2
 
     if "34b" in server_args.model_path.lower():
         model_override_args["image_token_index"] = 64002
 
-    launch_server(server_args, model_override_args, None)
+    server_args.json_model_override_args = json.dumps(model_override_args)
+    launch_server(server_args)
```
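The launcher now serializes its overrides into `server_args` instead of passing a dict alongside it. The same pattern in isolation, as a sketch (imports added for completeness):

```python
import json
import sys

from sglang.srt.server import launch_server, prepare_server_args

server_args = prepare_server_args(sys.argv[1:])
model_override_args = {"model_max_length": 4096 * 2}  # example override from this file
# Everything downstream reads json.loads(server_args.json_model_override_args),
# so a plain dict is dumped to a JSON string before launching.
server_args.json_model_override_args = json.dumps(model_override_args)
launch_server(server_args)
```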
python/sglang/srt/constrained/fsm_cache.py  (+29 -38)

```diff
@@ -16,6 +16,7 @@ limitations under the License.
 """Cache for the compressed finite state machine."""
 
 from outlines.fsm.json_schema import build_regex_from_schema
+from transformers import AutoTokenizer
 
 from sglang.srt.constrained import RegexGuide, TransformerTokenizer
 from sglang.srt.constrained.base_tool_cache import BaseToolCache
@@ -28,12 +29,9 @@ class FSMCache(BaseToolCache):
         tokenizer_args_dict,
         enable=True,
         skip_tokenizer_init=False,
-        json_schema_mode=False,
     ):
         super().__init__(enable=enable)
 
-        self.json_schema_mode = json_schema_mode
-
         if (
             skip_tokenizer_init
             or tokenizer_path.endswith(".json")
@@ -42,44 +40,37 @@ class FSMCache(BaseToolCache):
             # Do not support TiktokenTokenizer or SentencePieceTokenizer
             return
 
-        from importlib.metadata import version
-
-        if version("outlines") >= "0.0.35":
-            from transformers import AutoTokenizer
-
-            tokenizer_args_dict.setdefault("padding_side", "left")
-            tokenizer = AutoTokenizer.from_pretrained(
-                tokenizer_path, **tokenizer_args_dict
-            )
-            try:
-                self.outlines_tokenizer = TransformerTokenizer(tokenizer)
-            except AttributeError:
-                # FIXME: tmp fix for chatglm2 & chatglm3 (pad_token_id=0)
-                origin_pad_token_id = tokenizer.pad_token_id
-
-                def fset(self, value):
-                    self._value = value
-
-                type(tokenizer).pad_token_id = property(
-                    fget=type(tokenizer).pad_token_id.fget, fset=fset
-                )
-                self.outlines_tokenizer = TransformerTokenizer(tokenizer)
-                self.outlines_tokenizer.tokenizer.pad_token_id = origin_pad_token_id
-                self.outlines_tokenizer.pad_token_id = origin_pad_token_id
-                self.outlines_tokenizer.pad_token = (
-                    self.outlines_tokenizer.tokenizer.pad_token
-                )
-                self.outlines_tokenizer.vocabulary = (
-                    self.outlines_tokenizer.tokenizer.get_vocab()
-                )
-        else:
-            self.outlines_tokenizer = TransformerTokenizer(
-                tokenizer_path, **tokenizer_args_dict
-            )
+        tokenizer_args_dict.setdefault("padding_side", "left")
+        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_args_dict)
+        try:
+            self.outlines_tokenizer = TransformerTokenizer(tokenizer)
+        except AttributeError:
+            # FIXME: tmp fix for chatglm2 & chatglm3 (pad_token_id=0)
+            origin_pad_token_id = tokenizer.pad_token_id
+
+            def fset(self, value):
+                self._value = value
+
+            type(tokenizer).pad_token_id = property(
+                fget=type(tokenizer).pad_token_id.fget, fset=fset
+            )
+            self.outlines_tokenizer = TransformerTokenizer(tokenizer)
+            self.outlines_tokenizer.tokenizer.pad_token_id = origin_pad_token_id
+            self.outlines_tokenizer.pad_token_id = origin_pad_token_id
+            self.outlines_tokenizer.pad_token = (
+                self.outlines_tokenizer.tokenizer.pad_token
+            )
+            self.outlines_tokenizer.vocabulary = (
+                self.outlines_tokenizer.tokenizer.get_vocab()
+            )
 
-    def init_value(self, value):
-        if self.json_schema_mode:
-            regex = build_regex_from_schema(value, whitespace_pattern=r"[\n\t ]*")
-            return RegexGuide(regex, self.outlines_tokenizer), regex
-        else:
-            return RegexGuide(value, self.outlines_tokenizer)
+    def init_value(self, key):
+        key_type, key_string = key
+        if key_type == "json":
+            regex = build_regex_from_schema(key_string, whitespace_pattern=r"[\n\t ]*")
+        elif key_type == "regex":
+            regex = key_string
+        else:
+            raise ValueError(f"Invalid key_type: {key_type}")
+
+        return RegexGuide(regex, self.outlines_tokenizer), regex
```
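A single cache now serves both constraint types by keying entries with a `(key_type, key_string)` tuple, and `init_value` always returns the compiled guide together with the regex it was built from. A sketch of the call sites this implies, mirroring the tp_worker.py diff further down (the regex literal here is illustrative):

```python
# JSON-schema constraint: the schema string is compiled to a regex first.
guide, computed_regex = regex_fsm_cache.query(("json", json_schema_string))

# Raw-regex constraint: the string is used as-is.
guide, computed_regex = regex_fsm_cache.query(("regex", r"(yes|no)"))
```

`query` is inherited from `BaseToolCache` and presumably calls `init_value` on a cache miss, so both kinds of entry share one cache instance.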
python/sglang/srt/managers/controller_multi.py  (+1 -5)

```diff
@@ -71,12 +71,10 @@ class ControllerMulti:
         self,
         server_args: ServerArgs,
         port_args: PortArgs,
-        model_override_args,
     ):
         # Parse args
         self.server_args = server_args
         self.port_args = port_args
-        self.model_override_args = model_override_args
         self.load_balance_method = LoadBalanceMethod.from_str(
             server_args.load_balance_method
         )
@@ -114,7 +112,6 @@ class ControllerMulti:
             self.server_args,
             self.port_args,
             pipe_controller_writer,
-            self.model_override_args,
             True,
             gpu_ids,
             dp_worker_id,
@@ -189,14 +186,13 @@ def start_controller_process(
     server_args: ServerArgs,
     port_args: PortArgs,
     pipe_writer,
-    model_override_args: dict,
 ):
     """Start a controller process."""
     configure_logger(server_args)
 
     try:
-        controller = ControllerMulti(server_args, port_args, model_override_args)
+        controller = ControllerMulti(server_args, port_args)
     except Exception:
         pipe_writer.send(get_exception_traceback())
         raise
```
python/sglang/srt/managers/controller_single.py  (+0 -5)

```diff
@@ -40,7 +40,6 @@ class ControllerSingle:
         self,
         server_args: ServerArgs,
         port_args: PortArgs,
-        model_override_args: dict,
         gpu_ids: List[int],
         is_data_parallel_worker: bool,
         dp_worker_id: int,
@@ -76,7 +75,6 @@ class ControllerSingle:
                 tp_rank_range,
                 server_args,
                 port_args.nccl_ports[dp_worker_id],
-                model_override_args,
             )
 
         # Launch tp rank 0
@@ -85,7 +83,6 @@ class ControllerSingle:
             0,
             server_args,
             port_args.nccl_ports[dp_worker_id],
-            model_override_args,
         )
         self.tp_cpu_group = self.tp_server.model_runner.tp_group.cpu_group
@@ -126,7 +123,6 @@ def start_controller_process(
     server_args: ServerArgs,
     port_args: PortArgs,
     pipe_writer: multiprocessing.connection.Connection,
-    model_override_args: dict,
     is_data_parallel_worker: bool = False,
     gpu_ids: List[int] = None,
     dp_worker_id: int = None,
@@ -149,7 +145,6 @@ def start_controller_process(
         controller = ControllerSingle(
             server_args,
             port_args,
-            model_override_args,
             gpu_ids,
             is_data_parallel_worker,
             dp_worker_id,
```
python/sglang/srt/managers/tokenizer_manager.py  (+2 -2)

```diff
@@ -18,6 +18,7 @@ limitations under the License.
 import asyncio
 import concurrent.futures
 import dataclasses
+import json
 import logging
 import multiprocessing as mp
 import os
@@ -77,7 +78,6 @@ class TokenizerManager:
         self,
         server_args: ServerArgs,
         port_args: PortArgs,
-        model_override_args: dict = None,
     ):
         self.server_args = server_args
@@ -95,7 +95,7 @@ class TokenizerManager:
         self.hf_config = get_config(
             self.model_path,
             trust_remote_code=server_args.trust_remote_code,
-            model_override_args=model_override_args,
+            model_override_args=json.loads(server_args.json_model_override_args),
         )
         self.is_generation = is_generation_model(
             self.hf_config.architectures, self.server_args.is_embedding
```
python/sglang/srt/managers/tp_worker.py  (+80 -77)

```diff
@@ -15,13 +15,14 @@ limitations under the License.
 """A tensor parallel worker."""
 
+import json
 import logging
 import multiprocessing
 import os
 import pickle
 import time
 import warnings
-from typing import Any, List, Optional, Union
+from typing import Any, List, Optional
 
 import torch
 import torch.distributed
@@ -66,6 +67,7 @@ from sglang.utils import get_exception_traceback
 logger = logging.getLogger(__name__)
 
+
 # Crash on warning if we are running CI tests
 crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"
@@ -76,11 +78,10 @@ class ModelTpServer:
         tp_rank: int,
         server_args: ServerArgs,
         nccl_port: int,
-        model_override_args: dict,
     ):
         suppress_other_loggers()
 
-        # Copy arguments
+        # Parse arguments
         self.gpu_id = gpu_id
         self.tp_rank = tp_rank
         self.tp_size = server_args.tp_size
@@ -93,9 +94,8 @@ class ModelTpServer:
             server_args.model_path,
             server_args.trust_remote_code,
             context_length=server_args.context_length,
-            model_override_args=model_override_args,
+            model_override_args=json.loads(server_args.json_model_override_args),
         )
         self.model_runner = ModelRunner(
             model_config=self.model_config,
             mem_fraction_static=server_args.mem_fraction_static,
@@ -136,7 +136,7 @@ class ModelTpServer:
             self.max_total_num_tokens - 1,
         )
 
-        # Sync random seed
+        # Sync random seed across TP workers
         server_args.random_seed = broadcast_recv_input(
             [server_args.random_seed],
             self.tp_rank,
@@ -144,7 +144,7 @@ class ModelTpServer:
         )[0]
         set_random_seed(server_args.random_seed)
 
-        # Print info
+        # Print debug info
         logger.info(
             f"max_total_num_tokens={self.max_total_num_tokens}, "
             f"max_prefill_tokens={self.max_prefill_tokens}, "
@@ -181,7 +181,7 @@ class ModelTpServer:
         self.num_generated_tokens = 0
         self.last_stats_tic = time.time()
 
-        # Chunked prefill
+        # Init chunked prefill
        self.chunked_prefill_size = server_args.chunked_prefill_size
        self.current_inflight_req = None
        self.is_mixed_chunk = (
@@ -197,16 +197,6 @@ class ModelTpServer:
                 "trust_remote_code": server_args.trust_remote_code,
             },
             skip_tokenizer_init=server_args.skip_tokenizer_init,
-            json_schema_mode=False,
-        )
-        self.json_fsm_cache = FSMCache(
-            server_args.tokenizer_path,
-            {
-                "tokenizer_mode": server_args.tokenizer_mode,
-                "trust_remote_code": server_args.trust_remote_code,
-            },
-            skip_tokenizer_init=server_args.skip_tokenizer_init,
-            json_schema_mode=True,
         )
         self.jump_forward_cache = JumpForwardCache()
@@ -227,11 +217,12 @@ class ModelTpServer:
         try:
             # Recv requests
             for recv_req in recv_reqs:
-                if isinstance(
-                    recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
-                ):
+                if isinstance(recv_req, TokenizedGenerateReqInput):
                     self.handle_generate_request(recv_req)
                     self.do_not_get_new_batch = False
+                elif isinstance(recv_req, TokenizedEmbeddingReqInput):
+                    self.handle_embedding_request(recv_req)
+                    self.do_not_get_new_batch = False
                 elif isinstance(recv_req, FlushCacheReq):
                     self.flush_cache()
                 elif isinstance(recv_req, AbortReq):
@@ -331,57 +322,56 @@ class ModelTpServer:
     def handle_generate_request(
         self,
-        recv_req: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
+        recv_req: TokenizedGenerateReqInput,
     ):
         req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
         req.tokenizer = self.tokenizer
         req.sampling_params = recv_req.sampling_params
-        if self.model_runner.is_generation:
-            req.pixel_values = recv_req.pixel_values
-            if req.pixel_values is not None:
-                # Use image hash as fake token_ids, which is then used
-                # for prefix matching
-                image_hash = hash(tuple(recv_req.image_hashes))
-                req.pad_value = [
-                    (image_hash) % self.model_config.vocab_size,
-                    (image_hash >> 16) % self.model_config.vocab_size,
-                    (image_hash >> 32) % self.model_config.vocab_size,
-                    (image_hash >> 64) % self.model_config.vocab_size,
-                ]
-                req.image_sizes = recv_req.image_sizes
-                (
-                    req.origin_input_ids,
-                    req.image_offsets,
-                ) = self.model_runner.model.pad_input_ids(
-                    req.origin_input_ids_unpadded,
-                    req.pad_value,
-                    req.pixel_values,
-                    req.image_sizes,
-                )
-                # Only when pixel values is not None we have modalities
-                req.modalities = recv_req.modalites
-            req.return_logprob = recv_req.return_logprob
-            req.logprob_start_len = recv_req.logprob_start_len
-            req.top_logprobs_num = recv_req.top_logprobs_num
-            req.stream = recv_req.stream
+        req.pixel_values = recv_req.pixel_values
+        if req.pixel_values is not None:
+            # Use image hash as fake token_ids, which is then used
+            # for prefix matching
+            image_hash = hash(tuple(recv_req.image_hashes))
+            req.pad_value = [
+                (image_hash) % self.model_config.vocab_size,
+                (image_hash >> 16) % self.model_config.vocab_size,
+                (image_hash >> 32) % self.model_config.vocab_size,
+                (image_hash >> 64) % self.model_config.vocab_size,
+            ]
+            req.image_sizes = recv_req.image_sizes
+            (
+                req.origin_input_ids,
+                req.image_offsets,
+            ) = self.model_runner.model.pad_input_ids(
+                req.origin_input_ids_unpadded,
+                req.pad_value,
+                req.pixel_values,
+                req.image_sizes,
+            )
+            # Only when pixel values is not None we have modalities
+            req.modalities = recv_req.modalites
+        req.return_logprob = recv_req.return_logprob
+        req.logprob_start_len = recv_req.logprob_start_len
+        req.top_logprobs_num = recv_req.top_logprobs_num
+        req.stream = recv_req.stream
 
-            # Init regex fsm fron json
-            if req.sampling_params.json_schema is not None:
-                req.regex_fsm, computed_regex_string = self.json_fsm_cache.query(
-                    req.sampling_params.json_schema
-                )
-                if not self.disable_regex_jump_forward:
-                    req.jump_forward_map = self.jump_forward_cache.query(
-                        computed_regex_string
-                    )
-            # Init regex fsm
-            elif req.sampling_params.regex is not None:
-                req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
-                if not self.disable_regex_jump_forward:
-                    req.jump_forward_map = self.jump_forward_cache.query(
-                        req.sampling_params.regex
-                    )
+        # Init regex FSM
+        if (
+            req.sampling_params.json_schema is not None
+            or req.sampling_params.regex is not None
+        ):
+            if req.sampling_params.json_schema is not None:
+                req.regex_fsm, computed_regex_string = self.regex_fsm_cache.query(
+                    ("json", req.sampling_params.json_schema)
+                )
+            elif req.sampling_params.regex is not None:
+                req.regex_fsm, computed_regex_string = self.regex_fsm_cache.query(
+                    ("regex", req.sampling_params.regex)
+                )
+            if not self.disable_regex_jump_forward:
+                req.jump_forward_map = self.jump_forward_cache.query(
+                    computed_regex_string
+                )
 
         # Truncate prompts that are too long
         if len(req.origin_input_ids) >= self.max_req_input_len:
@@ -390,16 +380,32 @@ class ModelTpServer:
                 "the max context length. Truncated!!!"
             )
             req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
 
-        req.sampling_params.max_new_tokens = min(
-            (
-                req.sampling_params.max_new_tokens
-                if req.sampling_params.max_new_tokens is not None
-                else 1 << 30
-            ),
-            self.max_req_input_len - 1 - len(req.origin_input_ids),
-        )
+        if self.model_runner.is_generation:
+            req.sampling_params.max_new_tokens = min(
+                (
+                    req.sampling_params.max_new_tokens
+                    if req.sampling_params.max_new_tokens is not None
+                    else 1 << 30
+                ),
+                self.max_req_input_len - 1 - len(req.origin_input_ids),
+            )
 
         self.waiting_queue.append(req)
 
+    def handle_embedding_request(
+        self,
+        recv_req: TokenizedEmbeddingReqInput,
+    ):
+        req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
+        req.tokenizer = self.tokenizer
+        req.sampling_params = recv_req.sampling_params
+
+        # Truncate prompts that are too long
+        if len(req.origin_input_ids) >= self.max_req_input_len:
+            logger.warn(
+                "Request length is longer than the KV cache pool size or "
+                "the max context length. Truncated!!!"
+            )
+            req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
+
+        self.waiting_queue.append(req)
@@ -892,7 +898,6 @@ def run_tp_server(
     tp_rank: int,
     server_args: ServerArgs,
     nccl_port: int,
-    model_override_args: dict,
 ):
     """Run a tensor parallel model server."""
     configure_logger(server_args, prefix=f" TP{tp_rank}")
@@ -903,7 +908,6 @@ def run_tp_server(
         tp_rank,
         server_args,
         nccl_port,
-        model_override_args,
     )
     tp_cpu_group = model_server.model_runner.tp_group.cpu_group
@@ -920,14 +924,13 @@ def launch_tp_servers(
     tp_rank_range: List[int],
     server_args: ServerArgs,
     nccl_port: int,
-    model_override_args: dict,
 ):
     """Launch multiple tensor parallel servers."""
     procs = []
     for i in tp_rank_range:
         proc = multiprocessing.Process(
             target=run_tp_server,
-            args=(gpu_ids[i], i, server_args, nccl_port, model_override_args),
+            args=(gpu_ids[i], i, server_args, nccl_port),
         )
         proc.start()
         procs.append(proc)
```
python/sglang/srt/model_executor/model_runner.py  (+1 -0)

```diff
@@ -18,6 +18,7 @@ limitations under the License.
 import gc
 import importlib
 import importlib.resources
+import json
 import logging
 import pkgutil
 from functools import lru_cache
```
python/sglang/srt/server.py  (+3 -6)

```diff
@@ -272,7 +272,6 @@ async def retrieve_file_content(file_id: str):
 def launch_server(
     server_args: ServerArgs,
-    model_override_args: Optional[dict] = None,
     pipe_finish_writer: Optional[mp.connection.Connection] = None,
 ):
     """Launch an HTTP server."""
@@ -317,7 +316,6 @@ def launch_server(
             tp_rank_range,
             server_args,
             ports[3],
-            model_override_args,
         )
 
         try:
@@ -328,7 +326,7 @@ def launch_server(
         return
 
     # Launch processes
-    tokenizer_manager = TokenizerManager(server_args, port_args, model_override_args)
+    tokenizer_manager = TokenizerManager(server_args, port_args)
     if server_args.chat_template:
         load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template)
 
     pipe_controller_reader, pipe_controller_writer = mp.Pipe(duplex=False)
@@ -341,7 +339,7 @@ def launch_server(
         proc_controller = mp.Process(
             target=start_controller_process,
-            args=(server_args, port_args, pipe_controller_writer, model_override_args),
+            args=(server_args, port_args, pipe_controller_writer),
         )
         proc_controller.start()
@@ -501,7 +499,6 @@ class Runtime:
     def __init__(
         self,
         log_level: str = "error",
-        model_override_args: Optional[dict] = None,
         *args,
         **kwargs,
     ):
@@ -525,7 +522,7 @@ class Runtime:
         proc = mp.Process(
             target=launch_server,
-            args=(self.server_args, model_override_args, pipe_writer),
+            args=(self.server_args, pipe_writer),
         )
         proc.start()
         pipe_writer.close()
```
python/sglang/srt/server_args.py  (+18 -22)

```diff
@@ -76,6 +76,14 @@ class ServerArgs:
     dp_size: int = 1
     load_balance_method: str = "round_robin"
 
+    # Distributed args
+    nccl_init_addr: Optional[str] = None
+    nnodes: int = 1
+    node_rank: Optional[int] = None
+
+    # Model override args in JSON
+    json_model_override_args: str = "{}"
+
     # Optimization/debug options
     disable_flashinfer: bool = False
     disable_flashinfer_sampling: bool = False
@@ -91,14 +99,6 @@ class ServerArgs:
     enable_mla: bool = False
     triton_attention_reduce_in_fp32: bool = False
 
-    # Distributed args
-    nccl_init_addr: Optional[str] = None
-    nnodes: int = 1
-    node_rank: Optional[int] = None
-
-    # Model override args in JSON
-    json_model_override_args: Optional[dict] = None
-
     def __post_init__(self):
         if self.tokenizer_path is None:
             self.tokenizer_path = self.model_path
@@ -385,6 +385,14 @@ class ServerArgs:
         )
         parser.add_argument("--node-rank", type=int, help="The node rank.")
 
+        # Model override args
+        parser.add_argument(
+            "--json-model-override-args",
+            type=str,
+            help="A dictionary in JSON string format used to override default model configurations.",
+            default=ServerArgs.json_model_override_args,
+        )
+
         # Optimization/debug options
         parser.add_argument(
             "--disable-flashinfer",
@@ -459,22 +467,10 @@ class ServerArgs:
             help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).",
         )
 
-        # Model override args
-        parser.add_argument(
-            "--json-model-override-args",
-            type=str,
-            help="A dictionary in JSON string format used to override default model configurations.",
-        )
-
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
         args.tp_size = args.tensor_parallel_size
         args.dp_size = args.data_parallel_size
-        args.json_model_override_args = (
-            json.loads(args.json_model_override_args)
-            if args.json_model_override_args
-            else None
-        )
         attrs = [attr.name for attr in dataclasses.fields(cls)]
         return cls(**{attr: getattr(args, attr) for attr in attrs})
@@ -498,7 +494,7 @@ class ServerArgs:
             self.disable_flashinfer = False
 
 
-def prepare_server_args(args: argparse.Namespace) -> ServerArgs:
+def prepare_server_args(argv: List[str]) -> ServerArgs:
     """
     Prepare the server arguments from the command line arguments.
@@ -511,7 +507,7 @@ def prepare_server_args(args: argparse.Namespace) -> ServerArgs:
     """
     parser = argparse.ArgumentParser()
     ServerArgs.add_cli_args(parser)
-    raw_args = parser.parse_args(args)
+    raw_args = parser.parse_args(argv)
     server_args = ServerArgs.from_cli_args(raw_args)
     return server_args
```
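Two consequences of this diff worth spelling out: `prepare_server_args` now takes the raw `argv` list, and `json_model_override_args` stays a string with a `"{}"` default, so consumers can always call `json.loads` on it without a None check. A sketch:

```python
import json
import sys

from sglang.srt.server_args import prepare_server_args

# argv list in, fully-populated ServerArgs out.
server_args = prepare_server_args(sys.argv[1:])

# Always a valid JSON string; parses to {} when no overrides were given.
overrides = json.loads(server_args.json_model_override_args)
```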
python/sglang/test/few_shot_gsm8k.py  (+132 -0, new file, mode 0 → 100644)

```python
"""
Run few-shot GSM-8K evaluation.

Usage:
python3 -m sglang.test.few_shot_gsm8k --num-questions 200
"""

import argparse
import ast
import re
import time

import numpy as np

from sglang.api import set_default_backend
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.utils import download_and_cache_file, dump_state_text, read_jsonl

INVALID = -9999999


def get_one_example(lines, i, include_answer):
    ret = "Question: " + lines[i]["question"] + "\nAnswer:"
    if include_answer:
        ret += " " + lines[i]["answer"]
    return ret


def get_few_shot_examples(lines, k):
    ret = ""
    for i in range(k):
        ret += get_one_example(lines, i, True) + "\n\n"
    return ret


def get_answer_value(answer_str):
    answer_str = answer_str.replace(",", "")
    numbers = re.findall(r"\d+", answer_str)
    if len(numbers) < 1:
        return INVALID
    try:
        return ast.literal_eval(numbers[-1])
    except SyntaxError:
        return INVALID


def main(args):
    # Select backend
    set_default_backend(RuntimeEndpoint(f"{args.host}:{args.port}"))

    # Read data
    url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
    filename = download_and_cache_file(url)
    lines = list(read_jsonl(filename))

    # Construct prompts
    num_questions = args.num_questions
    num_shots = args.num_shots
    few_shot_examples = get_few_shot_examples(lines, num_shots)

    questions = []
    labels = []
    for i in range(len(lines[:num_questions])):
        questions.append(get_one_example(lines, i, False))
        labels.append(get_answer_value(lines[i]["answer"]))
    assert all(l != INVALID for l in labels)
    arguments = [{"question": q} for q in questions]

    #####################################
    ######### SGL Program Begin #########
    #####################################

    import sglang as sgl

    @sgl.function
    def few_shot_gsm8k(s, question):
        s += few_shot_examples + question
        s += sgl.gen(
            "answer", max_tokens=512, stop=["Question", "Assistant:", "<|separator|>"]
        )

    #####################################
    ########## SGL Program End ##########
    #####################################

    # Run requests
    tic = time.time()
    states = few_shot_gsm8k.run_batch(
        arguments,
        temperature=0,
        num_threads=args.parallel,
        progress_bar=True,
    )
    latency = time.time() - tic

    preds = []
    for i in range(len(states)):
        preds.append(get_answer_value(states[i]["answer"]))

    # print(f"{preds=}")
    # print(f"{labels=}")

    # Compute accuracy
    acc = np.mean(np.array(preds) == np.array(labels))
    invalid = np.mean(np.array(preds) == INVALID)

    # Compute speed
    num_output_tokens = sum(
        s.get_meta_info("answer")["completion_tokens"] for s in states
    )
    output_throughput = num_output_tokens / latency

    # Print results
    print(f"Accuracy: {acc:.3f}")
    print(f"Invalid: {invalid:.3f}")
    print(f"Latency: {latency:.3f} s")
    print(f"Output throughput: {output_throughput:.3f} token/s")

    # Dump results
    dump_state_text("tmp_output_gsm8k.txt", states)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--num-shots", type=int, default=5)
    parser.add_argument("--data-path", type=str, default="test.jsonl")
    parser.add_argument("--num-questions", type=int, default=200)
    parser.add_argument("--parallel", type=int, default=128)
    parser.add_argument("--host", type=str, default="http://127.0.0.1")
    parser.add_argument("--port", type=int, default=30000)
    args = parser.parse_args()
    main(args)
```
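The docstring gives the evaluation command, but a server has to be up first. A plausible two-terminal session, with the model path as an illustrative placeholder:

```
# Terminal 1: serve a model on the script's default port
python3 -m sglang.launch_server --model-path <your-model> --port 30000

# Terminal 2: run the evaluation against it
python3 -m sglang.test.few_shot_gsm8k --num-questions 200
```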