change / sglang · Commits

Commit e4d68afc (Unverified)
Authored Sep 09, 2024 by Lianmin Zheng; committed via GitHub on Sep 09, 2024
[Minor] Many cleanup (#1357)
Parent: c9b75917 · Changes: 24
Showing 20 changed files (page 1 of 2) with 376 additions and 254 deletions (+376 −254):

benchmark/gsm8k/README.md (+0 −5)
benchmark/gsm8k/bench_other.py (+18 −12)
benchmark/gsm8k/bench_sglang.py (+26 −13)
benchmark/gsm8k/download_data.sh (+0 −2)
benchmark/hellaswag/README.md (+0 −5)
benchmark/hellaswag/bench_other.py (+13 −10)
benchmark/hellaswag/bench_sglang.py (+14 −10)
examples/frontend_language/usage/llava_video/srt_example_llava_v.py (+2 −1)
python/sglang/bench_serving.py (+33 −38)
python/sglang/launch_server.py (+1 −2)
python/sglang/launch_server_llavavid.py (+3 −1)
python/sglang/srt/constrained/fsm_cache.py (+29 −38)
python/sglang/srt/managers/controller_multi.py (+1 −5)
python/sglang/srt/managers/controller_single.py (+0 −5)
python/sglang/srt/managers/tokenizer_manager.py (+2 −2)
python/sglang/srt/managers/tp_worker.py (+80 −77)
python/sglang/srt/model_executor/model_runner.py (+1 −0)
python/sglang/srt/server.py (+3 −6)
python/sglang/srt/server_args.py (+18 −22)
python/sglang/test/few_shot_gsm8k.py (+132 −0)
benchmark/gsm8k/README.md  (+0 −5)

-## Download data
-```
-bash download_data.sh
-```
-
 ## Run benchmark
 ### Benchmark sglang
 ...
benchmark/gsm8k/bench_other.py  (+18 −12)

@@ -10,7 +10,7 @@ import numpy as np
 from tqdm import tqdm

 from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
-from sglang.utils import dump_state_text, read_jsonl
+from sglang.utils import download_and_cache_file, dump_state_text, read_jsonl

 INVALID = -9999999

@@ -41,24 +41,28 @@ def get_answer_value(answer_str):
 def main(args):
-    lines = read_jsonl(args.data_path)
+    # Select backend
+    call_generate = get_call_generate(args)
+
+    # Read data
+    url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
+    filename = download_and_cache_file(url)
+    lines = list(read_jsonl(filename))

     # Construct prompts
-    k = args.num_shot
-    few_shot_examples = get_few_shot_examples(lines, k)
+    num_questions = args.num_questions
+    num_shots = args.num_shots
+    few_shot_examples = get_few_shot_examples(lines, num_shots)

     questions = []
     labels = []
-    for i in range(len(lines[:args.num_questions])):
+    for i in range(len(lines[:num_questions])):
         questions.append(get_one_example(lines, i, False))
         labels.append(get_answer_value(lines[i]["answer"]))
     assert all(l != INVALID for l in labels)
     states = [None] * len(labels)

-    # Select backend
-    call_generate = get_call_generate(args)
-
     # Run requests
     if args.backend != "lmql":
         # Use thread pool
         ...

@@ -113,11 +117,13 @@ def main(args):
     # Compute accuracy
     acc = np.mean(np.array(preds) == np.array(labels))
     invalid = np.mean(np.array(preds) == INVALID)
-    print(f"Latency: {latency:.3f}")
-    print(f"Invalid: {invalid:.3f}")
+
+    # Print results
     print(f"Accuracy: {acc:.3f}")
+    print(f"Invalid: {invalid:.3f}")
+    print(f"Latency: {latency:.3f} s")

-    # Write results
+    # Dump results
     dump_state_text(f"tmp_output_{args.backend}.txt", states)
     with open(args.result_file, "a") as fout:
         ...

@@ -138,7 +144,7 @@ def main(args):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--num-shot", type=int, default=5)
+    parser.add_argument("--num-shots", type=int, default=5)
     parser.add_argument("--data-path", type=str, default="test.jsonl")
     parser.add_argument("--num-questions", type=int, default=200)
     args = add_common_other_args_and_parse(parser)
     ...
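The `list(read_jsonl(filename))` wrapper matters here: the benchmark slices and indexes `lines` (`lines[:num_questions]`, `lines[i]["answer"]`), which a one-shot generator cannot support. A minimal sketch of the assumed behavior — `read_jsonl_sketch` is an illustrative stand-in for `sglang.utils.read_jsonl`, which is presumed to yield one parsed JSON record per line:

```python
import json

def read_jsonl_sketch(path):
    # Illustrative stand-in: yields one parsed record per line, like a generator.
    with open(path) as f:
        for line in f:
            yield json.loads(line)

rows = read_jsonl_sketch("test.jsonl")  # assumes the GSM-8K test file exists locally
rows = list(rows)                       # materialize so slicing/indexing work
subset = rows[:200]                     # e.g., the first num_questions records
```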
benchmark/gsm8k/bench_sglang.py  (+26 −13)

@@ -6,11 +6,12 @@ import time
 import numpy as np

+from sglang.api import set_default_backend
 from sglang.test.test_utils import (
     add_common_sglang_args_and_parse,
     select_sglang_backend,
 )
-from sglang.utils import dump_state_text, read_jsonl
+from sglang.utils import download_and_cache_file, dump_state_text, read_jsonl

 INVALID = -9999999

@@ -41,15 +42,22 @@ def get_answer_value(answer_str):
 def main(args):
-    lines = read_jsonl(args.data_path)
+    # Select backend
+    set_default_backend(select_sglang_backend(args))
+
+    # Read data
+    url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
+    filename = download_and_cache_file(url)
+    lines = list(read_jsonl(filename))

     # Construct prompts
-    k = args.num_shot
-    few_shot_examples = get_few_shot_examples(lines, k)
+    num_questions = args.num_questions
+    num_shots = args.num_shots
+    few_shot_examples = get_few_shot_examples(lines, num_shots)

     questions = []
     labels = []
-    for i in range(len(lines[:args.num_questions])):
+    for i in range(len(lines[:num_questions])):
         questions.append(get_one_example(lines, i, False))
         labels.append(get_answer_value(lines[i]["answer"]))
     assert all(l != INVALID for l in labels)
     ...

@@ -72,15 +80,11 @@ def main(args):
     ########## SGL Program End ##########
     #####################################

-    # Select backend
-    backend = select_sglang_backend(args)
-
     # Run requests
     tic = time.time()
     states = few_shot_gsm8k.run_batch(
         arguments,
         temperature=0,
-        backend=backend,
         num_threads=args.parallel,
         progress_bar=True,
     )

@@ -96,11 +100,20 @@ def main(args):
     # Compute accuracy
     acc = np.mean(np.array(preds) == np.array(labels))
     invalid = np.mean(np.array(preds) == INVALID)
-    print(f"Latency: {latency:.3f}")
-    print(f"Invalid: {invalid:.3f}")
+
+    # Compute speed
+    num_output_tokens = sum(
+        s.get_meta_info("answer")["completion_tokens"] for s in states
+    )
+    output_throughput = num_output_tokens / latency
+
+    # Print results
     print(f"Accuracy: {acc:.3f}")
+    print(f"Invalid: {invalid:.3f}")
+    print(f"Latency: {latency:.3f} s")
+    print(f"Output throughput: {output_throughput:.3f} token/s")

-    # Write results
+    # Dump results
     dump_state_text(f"tmp_output_{args.backend}.txt", states)
     with open(args.result_file, "a") as fout:
         ...

@@ -121,7 +134,7 @@ def main(args):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--num-shot", type=int, default=5)
+    parser.add_argument("--num-shots", type=int, default=5)
     parser.add_argument("--data-path", type=str, default="test.jsonl")
     parser.add_argument("--num-questions", type=int, default=200)
     args = add_common_sglang_args_and_parse(parser)
     ...
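With `set_default_backend(select_sglang_backend(args))` called up front, `run_batch` no longer needs an explicit `backend=` argument; it falls back to the globally registered backend. A hedged sketch of the pattern (the endpoint URL is an assumption; any running sglang server works):

```python
import sglang as sgl
from sglang.api import set_default_backend
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint

# Register the backend once, globally (assumes a server at this address).
set_default_backend(RuntimeEndpoint("http://127.0.0.1:30000"))

@sgl.function
def one_question(s, question):
    s += question
    s += sgl.gen("answer", max_tokens=64)

# No backend= needed: run_batch resolves the default backend registered above.
states = one_question.run_batch([{"question": "Question: 1+1=?\nAnswer:"}], temperature=0)
print(states[0]["answer"])
```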
benchmark/gsm8k/download_data.sh  (deleted, mode 100755 → 0)

-wget https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/train.jsonl
-wget https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl
\ No newline at end of file
benchmark/hellaswag/README.md  (+0 −5)

-## Download data
-```
-wget https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl
-```
-
 ## Run benchmark
 ### Benchmark sglang
 ...
benchmark/hellaswag/bench_other.py  (+13 −10)

@@ -8,7 +8,7 @@ import numpy as np
 from tqdm import tqdm

 from sglang.test.test_utils import add_common_other_args_and_parse, get_call_select
-from sglang.utils import read_jsonl
+from sglang.utils import download_and_cache_file, read_jsonl


 def get_one_example(lines, i, include_answer):
     ...

@@ -26,25 +26,29 @@ def get_few_shot_examples(lines, k):
 def main(args):
-    lines = read_jsonl(args.data_path)
+    # Select backend
+    call_select = get_call_select(args)
+
+    # Read data
+    url = "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl"
+    filename = download_and_cache_file(url)
+    lines = list(read_jsonl(filename))

     # Construct prompts
-    k = args.num_shot
-    few_shot_examples = get_few_shot_examples(lines, k)
+    num_questions = args.num_questions
+    num_shots = args.num_shots
+    few_shot_examples = get_few_shot_examples(lines, num_shots)

     questions = []
     choices = []
     labels = []
-    for i in range(len(lines[:args.num_questions])):
+    for i in range(len(lines[:num_questions])):
         questions.append(get_one_example(lines, i, False))
         choices.append(lines[i]["endings"])
         labels.append(lines[i]["label"])
     preds = [None] * len(labels)

-    # Select backend
-    call_select = get_call_select(args)
-
     # Run requests
     if args.backend != "lmql":
         # Use thread pool
         ...

@@ -65,7 +69,6 @@ def main(args):
                 total=len(questions),
             )
         )
     else:
         # Use asyncio
-
         async def batched_call(batch_size):
             ...

@@ -108,7 +111,7 @@ def main(args):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--num-shot", type=int, default=20)
+    parser.add_argument("--num-shots", type=int, default=20)
     parser.add_argument("--data-path", type=str, default="hellaswag_val.jsonl")
     parser.add_argument("--num-questions", type=int, default=200)
     args = add_common_other_args_and_parse(parser)
     ...
benchmark/hellaswag/bench_sglang.py  (+14 −10)

@@ -4,11 +4,12 @@ import time
 import numpy as np

+from sglang.api import set_default_backend
 from sglang.test.test_utils import (
     add_common_sglang_args_and_parse,
     select_sglang_backend,
 )
-from sglang.utils import read_jsonl
+from sglang.utils import download_and_cache_file, read_jsonl


 def get_one_example(lines, i, include_answer):
     ...

@@ -26,16 +27,23 @@ def get_few_shot_examples(lines, k):
 def main(args):
-    lines = read_jsonl(args.data_path)
+    # Select backend
+    set_default_backend(select_sglang_backend(args))
+
+    # Read data
+    url = "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl"
+    filename = download_and_cache_file(url)
+    lines = list(read_jsonl(filename))

     # Construct prompts
-    k = args.num_shot
-    few_shot_examples = get_few_shot_examples(lines, k)
+    num_questions = args.num_questions
+    num_shots = args.num_shots
+    few_shot_examples = get_few_shot_examples(lines, num_shots)

     questions = []
     choices = []
     labels = []
-    for i in range(len(lines[:args.num_questions])):
+    for i in range(len(lines[:num_questions])):
         questions.append(get_one_example(lines, i, False))
         choices.append(lines[i]["endings"])
         labels.append(lines[i]["label"])
     ...

@@ -56,15 +64,11 @@ def main(args):
     ########## SGL Program End ##########
     #####################################

-    # Select backend
-    backend = select_sglang_backend(args)
-
     # Run requests
     tic = time.time()
     rets = few_shot_hellaswag.run_batch(
         arguments,
         temperature=0,
-        backend=backend,
         num_threads=args.parallel,
         progress_bar=True,
     )

@@ -95,7 +99,7 @@ def main(args):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--num-shot", type=int, default=20)
+    parser.add_argument("--num-shots", type=int, default=20)
     parser.add_argument("--data-path", type=str, default="hellaswag_val.jsonl")
     parser.add_argument("--num-questions", type=int, default=200)
     args = add_common_sglang_args_and_parse(parser)
     ...
examples/frontend_language/usage/llava_video/srt_example_llava_v.py  (+2 −1)

@@ -7,6 +7,7 @@ python3 srt_example_llava_v.py
 import argparse
 import csv
+import json
 import os
 import time
 ...

@@ -223,7 +224,7 @@ if __name__ == "__main__":
             tokenizer_path=tokenizer_path,
             port=cur_port,
             additional_ports=[cur_port + 1, cur_port + 2, cur_port + 3, cur_port + 4],
-            model_override_args=model_override_args,
+            json_model_override_args=json.dumps(model_override_args),
             tp_size=1,
         )
         sgl.set_default_backend(runtime)
         ...
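The `model_override_args=...` keyword is replaced by `json_model_override_args=json.dumps(...)`: overrides now travel through `ServerArgs` as a JSON string and are parsed back with `json.loads` at the point of use (see the `tokenizer_manager.py` and `tp_worker.py` hunks below). A small round-trip sketch with a hypothetical override key:

```python
import json

# Producer side: serialize the override dict into a JSON string.
model_override_args = {"model_max_length": 4096 * 2}  # hypothetical override
json_str = json.dumps(model_override_args)

# Consumer side: parse at the point of use.
assert json.loads(json_str) == model_override_args
# The new ServerArgs default is the string "{}", which parses to an empty dict:
assert json.loads("{}") == {}
```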
python/sglang/bench_serving.py  (+33 −38)

@@ -298,34 +298,41 @@ class BenchmarkMetrics:
     median_e2e_latency_ms: float


-default_sharegpt_path = "ShareGPT_V3_unfiltered_cleaned_split.json"
+SHAREGPT_URL = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json"


-def download_sharegpt_dataset(path):
-    url = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json"
+def download_and_cache_file(url: str, filename: Optional[str] = None):
+    """Read and cache a file from a url."""
+    if filename is None:
+        filename = os.path.join("/tmp", url.split("/")[-1])

-    print(f"Downloading dataset from {url}")
+    # Check if the cache file already exists
+    if os.path.exists(filename):
+        return filename

-    try:
-        response = requests.get(url, stream=True)
-        response.raise_for_status()
+    print(f"Downloading from {url} to {filename}")

-        total_size = int(response.headers.get("content-length", 0))
-        block_size = 8192
+    # Stream the response to show the progress bar
+    response = requests.get(url, stream=True)
+    response.raise_for_status()  # Check for request errors

-        with open(path, "wb") as f, tqdm(
-            desc="Downloading",
-            total=total_size,
-            unit="iB",
-            unit_scale=True,
-            unit_divisor=1024,
-        ) as progress_bar:
-            for data in response.iter_content(block_size):
-                size = f.write(data)
-                progress_bar.update(size)
+    # Total size of the file in bytes
+    total_size = int(response.headers.get("content-length", 0))
+    chunk_size = 1024  # Download in chunks of 1KB

-        print(f"Dataset downloaded and saved to {path}")
-    except requests.RequestException as e:
-        raise Exception(f"Failed to download dataset: {e}")
+    # Use tqdm to display the progress bar
+    with open(filename, "wb") as f, tqdm(
+        desc=filename,
+        total=total_size,
+        unit="B",
+        unit_scale=True,
+        unit_divisor=1024,
+    ) as bar:
+        for chunk in response.iter_content(chunk_size=chunk_size):
+            f.write(chunk)
+            bar.update(len(chunk))
+
+    return filename


 def sample_sharegpt_requests(
     ...

@@ -338,13 +345,8 @@ def sample_sharegpt_requests(
         raise ValueError("output_len too small")

     # Download sharegpt if necessary
-    if not os.path.isfile(dataset_path) and not os.path.isfile(default_sharegpt_path):
-        download_sharegpt_dataset(default_sharegpt_path)
-        dataset_path = default_sharegpt_path
-    else:
-        dataset_path = (
-            dataset_path if os.path.isfile(dataset_path) else default_sharegpt_path
-        )
+    if not os.path.isfile(dataset_path):
+        dataset_path = download_and_cache_file(SHAREGPT_URL)

     # Load the dataset.
     with open(dataset_path) as f:
         ...

@@ -412,15 +414,8 @@ def sample_random_requests(
         # Sample token ids from ShareGPT and repeat/truncate them to satisfy the input_lens

         # Download sharegpt if necessary
-        if not os.path.isfile(dataset_path) and not os.path.isfile(
-            default_sharegpt_path
-        ):
-            download_sharegpt_dataset(default_sharegpt_path)
-            dataset_path = default_sharegpt_path
-        else:
-            dataset_path = (
-                dataset_path if os.path.isfile(dataset_path) else default_sharegpt_path
-            )
+        if not os.path.isfile(dataset_path):
+            dataset_path = download_and_cache_file(SHAREGPT_URL)

         # Load the dataset.
         with open(dataset_path) as f:
             ...
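A usage sketch of the new helper: because the download is cached under a filename derived from the URL (in /tmp by default), a second call is a pure cache hit. This assumes the module is importable as `sglang.bench_serving` in your environment and the first download succeeds:

```python
from sglang.bench_serving import SHAREGPT_URL, download_and_cache_file

path_a = download_and_cache_file(SHAREGPT_URL)  # downloads to /tmp on first use
path_b = download_and_cache_file(SHAREGPT_URL)  # cache hit: returns immediately
assert path_a == path_b
```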
python/sglang/launch_server.py  (+1 −2)

@@ -9,10 +9,9 @@ from sglang.srt.utils import kill_child_process

 if __name__ == "__main__":
     server_args = prepare_server_args(sys.argv[1:])
-    model_override_args = server_args.json_model_override_args

     try:
-        launch_server(server_args, model_override_args=model_override_args)
+        launch_server(server_args)
     except Exception as e:
         raise e
     finally:
         ...
python/sglang/launch_server_llavavid.py  (+3 −1)

 """Launch the inference server for Llava-video model."""

+import json
 import sys

 from sglang.srt.server import launch_server, prepare_server_args
 ...

@@ -19,5 +20,6 @@ if __name__ == "__main__":
     model_override_args["model_max_length"] = 4096 * 2
     if "34b" in server_args.model_path.lower():
         model_override_args["image_token_index"] = 64002
+    server_args.json_model_override_args = json.dumps(model_override_args)

-    launch_server(server_args, model_override_args, None)
+    launch_server(server_args)
python/sglang/srt/constrained/fsm_cache.py  (+29 −38)

@@ -16,6 +16,7 @@ limitations under the License.
 """Cache for the compressed finite state machine."""

 from outlines.fsm.json_schema import build_regex_from_schema
+from transformers import AutoTokenizer

 from sglang.srt.constrained import RegexGuide, TransformerTokenizer
 from sglang.srt.constrained.base_tool_cache import BaseToolCache
 ...

@@ -28,12 +29,9 @@ class FSMCache(BaseToolCache):
         tokenizer_args_dict,
         enable=True,
         skip_tokenizer_init=False,
-        json_schema_mode=False,
     ):
         super().__init__(enable=enable)

-        self.json_schema_mode = json_schema_mode
-
         if (
             skip_tokenizer_init
             or tokenizer_path.endswith(".json")
             ...

@@ -42,15 +40,8 @@ class FSMCache(BaseToolCache):
             # Do not support TiktokenTokenizer or SentencePieceTokenizer
             return

-        from importlib.metadata import version
-
-        if version("outlines") >= "0.0.35":
-            from transformers import AutoTokenizer
-
-            tokenizer_args_dict.setdefault("padding_side", "left")
-            tokenizer = AutoTokenizer.from_pretrained(
-                tokenizer_path, **tokenizer_args_dict
-            )
-            try:
-                self.outlines_tokenizer = TransformerTokenizer(tokenizer)
-            except AttributeError:
+        tokenizer_args_dict.setdefault("padding_side", "left")
+        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_args_dict)
+        try:
+            self.outlines_tokenizer = TransformerTokenizer(tokenizer)
+        except AttributeError:
             ...

@@ -72,14 +63,14 @@ class FSMCache(BaseToolCache):
             self.outlines_tokenizer.vocabulary = (
                 self.outlines_tokenizer.tokenizer.get_vocab()
             )
-        else:
-            self.outlines_tokenizer = TransformerTokenizer(
-                tokenizer_path, **tokenizer_args_dict
-            )

-    def init_value(self, value):
-        if self.json_schema_mode:
-            regex = build_regex_from_schema(value, whitespace_pattern=r"[\n\t ]*")
-            return RegexGuide(regex, self.outlines_tokenizer), regex
-        else:
-            return RegexGuide(value, self.outlines_tokenizer)
+    def init_value(self, key):
+        key_type, key_string = key
+        if key_type == "json":
+            regex = build_regex_from_schema(key_string, whitespace_pattern=r"[\n\t ]*")
+        elif key_type == "regex":
+            regex = key_string
+        else:
+            raise ValueError(f"Invalid key_type: {key_type}")
+
+        return RegexGuide(regex, self.outlines_tokenizer), regex
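After this change a single `FSMCache` serves both constraint types: callers key it with a `(key_type, key_string)` tuple and always get back `(guide, regex)`, so the computed regex can feed the jump-forward cache in either mode. A hedged usage sketch — the tokenizer path is an assumption, and instantiating the cache downloads that tokenizer:

```python
from sglang.srt.constrained.fsm_cache import FSMCache

# Hypothetical instantiation; any HF tokenizer path should work here.
cache = FSMCache("meta-llama/Meta-Llama-3-8B-Instruct", {}, enable=True)

# One cache, two key types, one return shape:
guide, regex = cache.query(("json", '{"type": "object", "properties": {}}'))
guide, regex = cache.query(("regex", r"(yes|no)"))
```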
python/sglang/srt/managers/controller_multi.py  (+1 −5)

@@ -71,12 +71,10 @@ class ControllerMulti:
         self,
         server_args: ServerArgs,
         port_args: PortArgs,
-        model_override_args,
     ):
         # Parse args
         self.server_args = server_args
         self.port_args = port_args
-        self.model_override_args = model_override_args
         self.load_balance_method = LoadBalanceMethod.from_str(
             server_args.load_balance_method
         )

@@ -114,7 +112,6 @@ class ControllerMulti:
             self.server_args,
             self.port_args,
             pipe_controller_writer,
-            self.model_override_args,
             True,
             gpu_ids,
             dp_worker_id,
             ...

@@ -189,14 +186,13 @@ def start_controller_process(
     server_args: ServerArgs,
     port_args: PortArgs,
     pipe_writer,
-    model_override_args: dict,
 ):
     """Start a controller process."""
     configure_logger(server_args)

     try:
-        controller = ControllerMulti(server_args, port_args, model_override_args)
+        controller = ControllerMulti(server_args, port_args)
     except Exception:
         pipe_writer.send(get_exception_traceback())
         raise
         ...
python/sglang/srt/managers/controller_single.py  (+0 −5)

@@ -40,7 +40,6 @@ class ControllerSingle:
         self,
         server_args: ServerArgs,
         port_args: PortArgs,
-        model_override_args: dict,
         gpu_ids: List[int],
         is_data_parallel_worker: bool,
         dp_worker_id: int,
         ...

@@ -76,7 +75,6 @@ class ControllerSingle:
                 tp_rank_range,
                 server_args,
                 port_args.nccl_ports[dp_worker_id],
-                model_override_args,
             )

         # Launch tp rank 0

@@ -85,7 +83,6 @@ class ControllerSingle:
             0,
             server_args,
             port_args.nccl_ports[dp_worker_id],
-            model_override_args,
         )
         self.tp_cpu_group = self.tp_server.model_runner.tp_group.cpu_group
         ...

@@ -126,7 +123,6 @@ def start_controller_process(
     server_args: ServerArgs,
     port_args: PortArgs,
     pipe_writer: multiprocessing.connection.Connection,
-    model_override_args: dict,
     is_data_parallel_worker: bool = False,
     gpu_ids: List[int] = None,
     dp_worker_id: int = None,
     ...

@@ -149,7 +145,6 @@ def start_controller_process(
         controller = ControllerSingle(
             server_args,
             port_args,
-            model_override_args,
            gpu_ids,
             is_data_parallel_worker,
             dp_worker_id,
             ...
python/sglang/srt/managers/tokenizer_manager.py  (+2 −2)

@@ -18,6 +18,7 @@ limitations under the License.
 import asyncio
 import concurrent.futures
 import dataclasses
+import json
 import logging
 import multiprocessing as mp
 import os
 ...

@@ -77,7 +78,6 @@ class TokenizerManager:
         self,
         server_args: ServerArgs,
         port_args: PortArgs,
-        model_override_args: dict = None,
     ):
         self.server_args = server_args
         ...

@@ -95,7 +95,7 @@ class TokenizerManager:
         self.hf_config = get_config(
             self.model_path,
             trust_remote_code=server_args.trust_remote_code,
-            model_override_args=model_override_args,
+            model_override_args=json.loads(server_args.json_model_override_args),
         )
         self.is_generation = is_generation_model(
             self.hf_config.architectures, self.server_args.is_embedding
         ...
python/sglang/srt/managers/tp_worker.py  (+80 −77)

@@ -15,13 +15,14 @@ limitations under the License.
 """A tensor parallel worker."""

+import json
 import logging
 import multiprocessing
 import os
 import pickle
 import time
 import warnings
-from typing import Any, List, Optional, Union
+from typing import Any, List, Optional

 import torch
 import torch.distributed
 ...

@@ -66,6 +67,7 @@ from sglang.utils import get_exception_traceback
 logger = logging.getLogger(__name__)

+# Crash on warning if we are running CI tests
 crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"
 ...

@@ -76,11 +78,10 @@ class ModelTpServer:
         tp_rank: int,
         server_args: ServerArgs,
         nccl_port: int,
-        model_override_args: dict,
     ):
         suppress_other_loggers()

-        # Copy arguments
+        # Parse arguments
         self.gpu_id = gpu_id
         self.tp_rank = tp_rank
         self.tp_size = server_args.tp_size
         ...

@@ -93,9 +94,8 @@ class ModelTpServer:
             server_args.model_path,
             server_args.trust_remote_code,
             context_length=server_args.context_length,
-            model_override_args=model_override_args,
+            model_override_args=json.loads(server_args.json_model_override_args),
         )
         self.model_runner = ModelRunner(
             model_config=self.model_config,
             mem_fraction_static=server_args.mem_fraction_static,
             ...

@@ -136,7 +136,7 @@ class ModelTpServer:
             self.max_total_num_tokens - 1,
         )

-        # Sync random seed
+        # Sync random seed across TP workers
         server_args.random_seed = broadcast_recv_input(
             [server_args.random_seed],
             self.tp_rank,
             ...

@@ -144,7 +144,7 @@ class ModelTpServer:
         )[0]
         set_random_seed(server_args.random_seed)

-        # Print info
+        # Print debug info
         logger.info(
             f"max_total_num_tokens={self.max_total_num_tokens}, "
             f"max_prefill_tokens={self.max_prefill_tokens}, "
             ...

@@ -181,7 +181,7 @@ class ModelTpServer:
         self.num_generated_tokens = 0
         self.last_stats_tic = time.time()

-        # Chunked prefill
+        # Init chunked prefill
         self.chunked_prefill_size = server_args.chunked_prefill_size
         self.current_inflight_req = None
         self.is_mixed_chunk = (
             ...

@@ -197,16 +197,6 @@ class ModelTpServer:
                 "trust_remote_code": server_args.trust_remote_code,
             },
             skip_tokenizer_init=server_args.skip_tokenizer_init,
-            json_schema_mode=False,
-        )
-        self.json_fsm_cache = FSMCache(
-            server_args.tokenizer_path,
-            {
-                "tokenizer_mode": server_args.tokenizer_mode,
-                "trust_remote_code": server_args.trust_remote_code,
-            },
-            skip_tokenizer_init=server_args.skip_tokenizer_init,
-            json_schema_mode=True,
         )
         self.jump_forward_cache = JumpForwardCache()
         ...

@@ -227,11 +217,12 @@ class ModelTpServer:
         try:
             # Recv requests
             for recv_req in recv_reqs:
-                if isinstance(
-                    recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
-                ):
+                if isinstance(recv_req, TokenizedGenerateReqInput):
                     self.handle_generate_request(recv_req)
                     self.do_not_get_new_batch = False
+                elif isinstance(recv_req, TokenizedEmbeddingReqInput):
+                    self.handle_embedding_request(recv_req)
+                    self.do_not_get_new_batch = False
                 elif isinstance(recv_req, FlushCacheReq):
                     self.flush_cache()
                 elif isinstance(recv_req, AbortReq):
                     ...

@@ -331,12 +322,11 @@ class ModelTpServer:
     def handle_generate_request(
         self,
-        recv_req: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
+        recv_req: TokenizedGenerateReqInput,
     ):
         req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
         req.tokenizer = self.tokenizer
         req.sampling_params = recv_req.sampling_params
-        if self.model_runner.is_generation:
-            req.pixel_values = recv_req.pixel_values
+        req.pixel_values = recv_req.pixel_values
         if req.pixel_values is not None:
             # Use image hash as fake token_ids, which is then used
             ...

@@ -365,22 +355,22 @@ class ModelTpServer:
         req.top_logprobs_num = recv_req.top_logprobs_num
         req.stream = recv_req.stream

-        # Init regex fsm fron json
-        if req.sampling_params.json_schema is not None:
-            req.regex_fsm, computed_regex_string = self.json_fsm_cache.query(
-                req.sampling_params.json_schema
-            )
-            if not self.disable_regex_jump_forward:
-                req.jump_forward_map = self.jump_forward_cache.query(
-                    computed_regex_string
-                )
-
-        # Init regex fsm
-        elif req.sampling_params.regex is not None:
-            req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
-            if not self.disable_regex_jump_forward:
-                req.jump_forward_map = self.jump_forward_cache.query(
-                    req.sampling_params.regex
-                )
+        # Init regex FSM
+        if (
+            req.sampling_params.json_schema is not None
+            or req.sampling_params.regex is not None
+        ):
+            if req.sampling_params.json_schema is not None:
+                req.regex_fsm, computed_regex_string = self.regex_fsm_cache.query(
+                    ("json", req.sampling_params.json_schema)
+                )
+            elif req.sampling_params.regex is not None:
+                req.regex_fsm, computed_regex_string = self.regex_fsm_cache.query(
+                    ("regex", req.sampling_params.regex)
+                )
+            if not self.disable_regex_jump_forward:
+                req.jump_forward_map = self.jump_forward_cache.query(
+                    computed_regex_string
+                )

         # Truncate prompts that are too long
         ...

@@ -390,8 +380,6 @@ class ModelTpServer:
                 "the max context length. Truncated!!!"
             )
             req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]

-        if self.model_runner.is_generation:
-            req.sampling_params.max_new_tokens = min(
-                (
-                    req.sampling_params.max_new_tokens
+        req.sampling_params.max_new_tokens = min(
+            (
+                req.sampling_params.max_new_tokens
                 ...

@@ -403,6 +391,24 @@ class ModelTpServer:
         self.waiting_queue.append(req)

+    def handle_embedding_request(
+        self,
+        recv_req: TokenizedEmbeddingReqInput,
+    ):
+        req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
+        req.tokenizer = self.tokenizer
+        req.sampling_params = recv_req.sampling_params
+
+        # Truncate prompts that are too long
+        if len(req.origin_input_ids) >= self.max_req_input_len:
+            logger.warn(
+                "Request length is longer than the KV cache pool size or "
+                "the max context length. Truncated!!!"
+            )
+            req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
+
+        self.waiting_queue.append(req)
+
     def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
         running_bs = (
             len(self.running_batch.reqs) if self.running_batch is not None else 0
         )
         ...

@@ -892,7 +898,6 @@ def run_tp_server(
     tp_rank: int,
     server_args: ServerArgs,
     nccl_port: int,
-    model_override_args: dict,
 ):
     """Run a tensor parallel model server."""
     configure_logger(server_args, prefix=f" TP{tp_rank}")
     ...

@@ -903,7 +908,6 @@ def run_tp_server(
         tp_rank,
         server_args,
         nccl_port,
-        model_override_args,
     )
     tp_cpu_group = model_server.model_runner.tp_group.cpu_group
     ...

@@ -920,14 +924,13 @@ def launch_tp_servers(
     tp_rank_range: List[int],
     server_args: ServerArgs,
     nccl_port: int,
-    model_override_args: dict,
 ):
     """Launch multiple tensor parallel servers."""
     procs = []
     for i in tp_rank_range:
         proc = multiprocessing.Process(
             target=run_tp_server,
-            args=(gpu_ids[i], i, server_args, nccl_port, model_override_args),
+            args=(gpu_ids[i], i, server_args, nccl_port),
         )
         proc.start()
         procs.append(proc)
         ...
python/sglang/srt/model_executor/model_runner.py  (+1 −0)

@@ -18,6 +18,7 @@ limitations under the License.
 import gc
 import importlib
 import importlib.resources
+import json
 import logging
 import pkgutil
 from functools import lru_cache
 ...
python/sglang/srt/server.py  (+3 −6)

@@ -272,7 +272,6 @@ async def retrieve_file_content(file_id: str):
 def launch_server(
     server_args: ServerArgs,
-    model_override_args: Optional[dict] = None,
     pipe_finish_writer: Optional[mp.connection.Connection] = None,
 ):
     """Launch an HTTP server."""
     ...

@@ -317,7 +316,6 @@ def launch_server(
             tp_rank_range,
             server_args,
             ports[3],
-            model_override_args,
         )

         try:
             ...

@@ -328,7 +326,7 @@ def launch_server(
             return

     # Launch processes
-    tokenizer_manager = TokenizerManager(server_args, port_args, model_override_args)
+    tokenizer_manager = TokenizerManager(server_args, port_args)
     if server_args.chat_template:
         load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template)
     pipe_controller_reader, pipe_controller_writer = mp.Pipe(duplex=False)
     ...

@@ -341,7 +339,7 @@ def launch_server(
     proc_controller = mp.Process(
         target=start_controller_process,
-        args=(server_args, port_args, pipe_controller_writer, model_override_args),
+        args=(server_args, port_args, pipe_controller_writer),
     )
     proc_controller.start()
     ...

@@ -501,7 +499,6 @@ class Runtime:
     def __init__(
         self,
         log_level: str = "error",
-        model_override_args: Optional[dict] = None,
         *args,
         **kwargs,
     ):
         ...

@@ -525,7 +522,7 @@ class Runtime:
         proc = mp.Process(
             target=launch_server,
-            args=(self.server_args, model_override_args, pipe_writer),
+            args=(self.server_args, pipe_writer),
         )
         proc.start()
         pipe_writer.close()
         ...
python/sglang/srt/server_args.py  (+18 −22)

@@ -76,6 +76,14 @@ class ServerArgs:
     dp_size: int = 1
     load_balance_method: str = "round_robin"

+    # Distributed args
+    nccl_init_addr: Optional[str] = None
+    nnodes: int = 1
+    node_rank: Optional[int] = None
+
+    # Model override args in JSON
+    json_model_override_args: str = "{}"
+
     # Optimization/debug options
     disable_flashinfer: bool = False
     disable_flashinfer_sampling: bool = False
     ...

@@ -91,14 +99,6 @@ class ServerArgs:
     enable_mla: bool = False
     triton_attention_reduce_in_fp32: bool = False

-    # Distributed args
-    nccl_init_addr: Optional[str] = None
-    nnodes: int = 1
-    node_rank: Optional[int] = None
-
-    # Model override args in JSON
-    json_model_override_args: Optional[dict] = None
-
     def __post_init__(self):
         if self.tokenizer_path is None:
             self.tokenizer_path = self.model_path
             ...

@@ -385,6 +385,14 @@ class ServerArgs:
         )
         parser.add_argument("--node-rank", type=int, help="The node rank.")

+        # Model override args
+        parser.add_argument(
+            "--json-model-override-args",
+            type=str,
+            help="A dictionary in JSON string format used to override default model configurations.",
+            default=ServerArgs.json_model_override_args,
+        )
+
         # Optimization/debug options
         parser.add_argument(
             "--disable-flashinfer",
             ...

@@ -459,22 +467,10 @@ class ServerArgs:
             help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).",
         )

-        # Model override args
-        parser.add_argument(
-            "--json-model-override-args",
-            type=str,
-            help="A dictionary in JSON string format used to override default model configurations.",
-        )
-
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
         args.tp_size = args.tensor_parallel_size
         args.dp_size = args.data_parallel_size
-        args.json_model_override_args = (
-            json.loads(args.json_model_override_args)
-            if args.json_model_override_args
-            else None
-        )
         attrs = [attr.name for attr in dataclasses.fields(cls)]
         return cls(**{attr: getattr(args, attr) for attr in attrs})
         ...

@@ -498,7 +494,7 @@ class ServerArgs:
         self.disable_flashinfer = False


-def prepare_server_args(args: argparse.Namespace) -> ServerArgs:
+def prepare_server_args(argv: List[str]) -> ServerArgs:
     """
     Prepare the server arguments from the command line arguments.
     ...

@@ -511,7 +507,7 @@ def prepare_server_args(args: argparse.Namespace) -> ServerArgs:
     """
     parser = argparse.ArgumentParser()
     ServerArgs.add_cli_args(parser)
-    raw_args = parser.parse_args(args)
+    raw_args = parser.parse_args(argv)
     server_args = ServerArgs.from_cli_args(raw_args)
     return server_args
     ...
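Two consequences worth noting: `prepare_server_args` now takes an argv list instead of a pre-parsed namespace, and `--json-model-override-args` keeps its string form inside `ServerArgs` (default `"{}"`), deferring `json.loads` to the consumers. A sketch, where the model path and override value are hypothetical:

```python
from sglang.srt.server_args import prepare_server_args

server_args = prepare_server_args(
    [
        "--model-path", "meta-llama/Meta-Llama-3-8B-Instruct",        # hypothetical
        "--json-model-override-args", '{"model_max_length": 8192}',  # hypothetical
    ]
)
# The field stays a JSON string; consumers call json.loads on it when needed.
assert isinstance(server_args.json_model_override_args, str)
```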
python/sglang/test/few_shot_gsm8k.py  (new file, mode 0 → 100644, +132 −0)

"""
Run few-shot GSM-8K evaluation.

Usage:
python3 -m sglang.test.few_shot_gsm8k --num-questions 200
"""

import argparse
import ast
import re
import time

import numpy as np

from sglang.api import set_default_backend
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.utils import download_and_cache_file, dump_state_text, read_jsonl

INVALID = -9999999


def get_one_example(lines, i, include_answer):
    ret = "Question: " + lines[i]["question"] + "\nAnswer:"
    if include_answer:
        ret += " " + lines[i]["answer"]
    return ret


def get_few_shot_examples(lines, k):
    ret = ""
    for i in range(k):
        ret += get_one_example(lines, i, True) + "\n\n"
    return ret


def get_answer_value(answer_str):
    answer_str = answer_str.replace(",", "")
    numbers = re.findall(r"\d+", answer_str)
    if len(numbers) < 1:
        return INVALID
    try:
        return ast.literal_eval(numbers[-1])
    except SyntaxError:
        return INVALID


def main(args):
    # Select backend
    set_default_backend(RuntimeEndpoint(f"{args.host}:{args.port}"))

    # Read data
    url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
    filename = download_and_cache_file(url)
    lines = list(read_jsonl(filename))

    # Construct prompts
    num_questions = args.num_questions
    num_shots = args.num_shots
    few_shot_examples = get_few_shot_examples(lines, num_shots)

    questions = []
    labels = []
    for i in range(len(lines[:num_questions])):
        questions.append(get_one_example(lines, i, False))
        labels.append(get_answer_value(lines[i]["answer"]))
    assert all(l != INVALID for l in labels)
    arguments = [{"question": q} for q in questions]

    #####################################
    ######### SGL Program Begin #########
    #####################################

    import sglang as sgl

    @sgl.function
    def few_shot_gsm8k(s, question):
        s += few_shot_examples + question
        s += sgl.gen(
            "answer", max_tokens=512, stop=["Question", "Assistant:", "<|separator|>"]
        )

    #####################################
    ########## SGL Program End ##########
    #####################################

    # Run requests
    tic = time.time()
    states = few_shot_gsm8k.run_batch(
        arguments,
        temperature=0,
        num_threads=args.parallel,
        progress_bar=True,
    )
    latency = time.time() - tic

    preds = []
    for i in range(len(states)):
        preds.append(get_answer_value(states[i]["answer"]))

    # print(f"{preds=}")
    # print(f"{labels=}")

    # Compute accuracy
    acc = np.mean(np.array(preds) == np.array(labels))
    invalid = np.mean(np.array(preds) == INVALID)

    # Compute speed
    num_output_tokens = sum(
        s.get_meta_info("answer")["completion_tokens"] for s in states
    )
    output_throughput = num_output_tokens / latency

    # Print results
    print(f"Accuracy: {acc:.3f}")
    print(f"Invalid: {invalid:.3f}")
    print(f"Latency: {latency:.3f} s")
    print(f"Output throughput: {output_throughput:.3f} token/s")

    # Dump results
    dump_state_text("tmp_output_gsm8k.txt", states)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--num-shots", type=int, default=5)
    parser.add_argument("--data-path", type=str, default="test.jsonl")
    parser.add_argument("--num-questions", type=int, default=200)
    parser.add_argument("--parallel", type=int, default=128)
    parser.add_argument("--host", type=str, default="http://127.0.0.1")
    parser.add_argument("--port", type=int, default=30000)
    args = parser.parse_args()
    main(args)
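Beyond the documented CLI entry point (`python3 -m sglang.test.few_shot_gsm8k --num-questions 200`), `main` only reads a handful of attributes off `args`, so it can also be driven programmatically. A hedged sketch, assuming an sglang server is already running at the default address:

```python
from types import SimpleNamespace

from sglang.test.few_shot_gsm8k import main

# Mirror the argparse defaults; a launched sglang server is assumed.
args = SimpleNamespace(
    num_shots=5,
    data_path="test.jsonl",
    num_questions=200,
    parallel=128,
    host="http://127.0.0.1",
    port=30000,
)
main(args)
```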