sglang / Commits / 2e6e62e1

Commit 2e6e62e1 (unverified), authored Jun 26, 2024 by Lianmin Zheng; committed via GitHub on Jun 26, 2024.
Parent: a385ee27

Increase the number of thread limitation for tp worker managers. (#567)

Showing 9 changed files with 148 additions and 84 deletions.
Changed files:

  benchmark/latency_throughput/bench_throughput.py          +9   -4
  benchmark/latency_throughput/test_latency.py              +3   -2
  benchmark/mmlu/bench_sglang.py                            +47  -55
  playground/load_tokenizer.py                              +10  -5
  python/sglang/srt/constrained/fsm_cache.py                +2   -1
  python/sglang/srt/hf_transformers_utils.py                +43  -3
  python/sglang/srt/layers/fused_moe.py                     +30  -11
  python/sglang/srt/managers/controller/manager_single.py   +3   -2
  python/sglang/srt/managers/controller/tp_worker.py        +1   -1
benchmark/latency_throughput/bench_throughput.py  (view file @ 2e6e62e1)

@@ -250,9 +250,14 @@ def main(args: argparse.Namespace):
     np.random.seed(args.seed)

     api_url = f"http://{args.host}:{args.port}/generate"
-    tokenizer = AutoTokenizer.from_pretrained(
-        args.tokenizer, trust_remote_code=args.trust_remote_code
-    )
+    if args.tokenizer.endswith(".json") or args.tokenizer.endswith(".model"):
+        from sglang.srt.hf_transformers_utils import get_tokenizer
+
+        tokenizer = get_tokenizer(args.tokenizer)
+    else:
+        tokenizer = AutoTokenizer.from_pretrained(
+            args.tokenizer, trust_remote_code=args.trust_remote_code
+        )
     if args.dataset:
         input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
@@ -272,7 +277,7 @@ def main(args: argparse.Namespace):
     for i in range(args.num_prompts):
         prompt = tokenizer.decode(
             [
-                (offsets[i] + i + j) % tokenizer.vocab_size
+                (offsets[i] + i + j) % (tokenizer.vocab_size - 129) + 128
                 for j in range(input_lens[i])
             ]
         )
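The new modulus keeps every synthetic token id inside [128, vocab_size - 2], so the randomly generated prompts avoid the low ids that many tokenizers reserve for special and control tokens. A minimal sketch of the arithmetic, with an illustrative vocab_size (not taken from the diff):

    # Ids produced by (x % (vocab_size - 129)) + 128, assuming vocab_size = 32000.
    vocab_size = 32000
    ids = [(x % (vocab_size - 129)) + 128 for x in range(100_000)]
    assert min(ids) >= 128               # skips reserved/control token ids
    assert max(ids) <= vocab_size - 2    # stays strictly inside the vocab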
benchmark/latency_throughput/test_latency.py  (view file @ 2e6e62e1)

@@ -8,6 +8,7 @@ if __name__ == "__main__":
     parser.add_argument("--host", type=str, default="http://127.0.0.1")
     parser.add_argument("--port", type=int, default=None)
     parser.add_argument("--backend", type=str, default="srt")
+    parser.add_argument("--batch-size", type=int, default=1)
     parser.add_argument("--max-tokens", type=int, default=256)
     args = parser.parse_args()
@@ -33,7 +34,7 @@ if __name__ == "__main__":
     response = requests.post(
         url + "/generate",
         json={
-            "text": prompt,
+            "text": [prompt] * args.batch_size,
             "sampling_params": {
                 "temperature": 0,
                 "max_new_tokens": max_new_tokens,
@@ -90,5 +91,5 @@ if __name__ == "__main__":
     ret = response.json()
     print(ret)

-    speed = max_new_tokens / latency
+    speed = args.batch_size * max_new_tokens / latency
     print(f"latency: {latency:.2f} s, speed: {speed:.2f} token/s")
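Since /generate now receives a list of prompts, throughput has to count every batch element, which is what the corrected speed formula does. A hedged, self-contained sketch of the request shape (the server URL and prompt are illustrative assumptions, not from the diff):

    import time
    import requests

    batch_size, max_new_tokens = 4, 256
    prompt = "Say this is a test."

    tic = time.time()
    response = requests.post(
        "http://127.0.0.1:30000/generate",  # assumes a local sglang server
        json={
            "text": [prompt] * batch_size,  # one entry per batch element
            "sampling_params": {"temperature": 0, "max_new_tokens": max_new_tokens},
        },
    )
    latency = time.time() - tic
    speed = batch_size * max_new_tokens / latency  # tokens across the whole batch
    print(f"latency: {latency:.2f} s, speed: {speed:.2f} token/s")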
benchmark/mmlu/bench_sglang.py  (view file @ 2e6e62e1)

@@ -48,24 +48,45 @@ def gen_prompt(train_df, subject, k=-1):
     return prompt


-def evaluate(args, subject, dev_df, test_df):
-    prompts = []
+def main(args):
+    subjects = sorted(
+        [
+            f.split("_test.csv")[0]
+            for f in os.listdir(os.path.join(args.data_dir, "test"))
+            if "_test.csv" in f
+        ]
+    )

-    k = args.ntrain
-    few_shot_examples = gen_prompt(dev_df, subject, k)
-    while len(tokenizer.encode(few_shot_examples)) > 1536:
-        k -= 1
+    # Build prompts
+    arguments = []
+    labels = []
+    num_questions = []
+    for subject in subjects[: args.nsub]:
+        dev_df = pd.read_csv(
+            os.path.join(args.data_dir, "dev", subject + "_dev.csv"), header=None
+        )[: args.ntrain]
+        test_df = pd.read_csv(
+            os.path.join(args.data_dir, "test", subject + "_test.csv"), header=None
+        )
+        num_questions.append(test_df.shape[0])
+
+        k = args.ntrain
         few_shot_examples = gen_prompt(dev_df, subject, k)
+        while len(tokenizer.encode(few_shot_examples)) > 1536:
+            k -= 1
+            few_shot_examples = gen_prompt(dev_df, subject, k)

-    for i in range(test_df.shape[0]):
-        prompt_end = format_example(test_df, i, include_answer=False)
-        prompts.append(prompt_end)
-
-    arguments = [{"question": p} for p in prompts]
-
-    label = test_df.iloc[i, test_df.shape[1] - 1]
-    labels.append(label)
+        for i in range(test_df.shape[0]):
+            prompt_end = format_example(test_df, i, include_answer=False)
+            label = test_df.iloc[i, test_df.shape[1] - 1]
+            labels.append(label)
+            arguments.append({
+                "examples": few_shot_examples,
+                "question": prompt_end,
+            })

     #####################################
     ######### SGL Program Begin #########
@@ -93,62 +114,33 @@ def evaluate(args, subject, dev_df, test_df):
     # Select backend
     backend = select_sglang_backend(args)

     # Run
     tic = time.time()
-    states = few_shot_mmlu.bind(examples=few_shot_examples).run_batch(
+    states = few_shot_mmlu.run_batch(
         arguments,
         temperature=0,
         max_new_tokens=1,
         backend=backend,
         num_threads=args.parallel,
         progress_bar=True,
     )
     preds = [
         s["answer"].strip()[0] if len(s["answer"].strip()) > 0 else "" for s in states
     ]
     latency = time.time() - tic

     # Compute accuracy
     cors = [pred == label for pred, label in zip(preds, labels)]
-    acc = np.mean(cors)
-    cors = np.array(cors)
-
-    print(
-        "Average accuracy {:.3f}, latency {:.2f}, #q: {} - {}".format(
-            acc, latency, len(prompts), subject
-        )
-    )
-
-    return cors, acc, latency
-
-
-def main(args):
-    subjects = sorted(
-        [
-            f.split("_test.csv")[0]
-            for f in os.listdir(os.path.join(args.data_dir, "test"))
-            if "_test.csv" in f
-        ]
-    )
-
-    all_cors = []
-    all_latencies = []
-    num_requests = 0
-    for subject in tqdm(subjects[: args.nsub]):
-        dev_df = pd.read_csv(
-            os.path.join(args.data_dir, "dev", subject + "_dev.csv"), header=None
-        )[: args.ntrain]
-        test_df = pd.read_csv(
-            os.path.join(args.data_dir, "test", subject + "_test.csv"), header=None
-        )
-
-        cors, acc, latency = evaluate(args, subject, dev_df, test_df)
-        all_cors.append(cors)
-        all_latencies.append(latency)
-        num_requests += len(test_df)
-
-    total_latency = np.sum(all_latencies)
-    print("Total latency: {:.3f}".format(total_latency))
-
-    weighted_acc = np.mean(np.concatenate(all_cors))
+    pt = 0
+    for subject, num_qs in zip(subjects[: args.nsub], num_questions):
+        print(
+            f"subject: {subject}, #q: {num_qs}, acc: {np.mean(cors[pt : pt + num_qs]):.3f}"
+        )
+        pt += num_qs
+    assert pt == len(cors)
+
+    weighted_acc = np.mean(cors)
+
+    # Print results
+    print("Total latency: {:.3f}".format(latency))
     print("Average accuracy: {:.3f}".format(weighted_acc))

     # Write results
@@ -157,9 +149,9 @@ def main(args):
         "task": "mmlu",
         "backend": args.backend,
         "num_gpus": 1,
-        "latency": round(total_latency, 3),
+        "latency": round(latency, 3),
         "accuracy": round(weighted_acc, 3),
-        "num_requests": num_requests,
+        "num_requests": len(arguments),
         "other": {
             "nsub": args.nsub,
             "parallel": args.parallel,
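The restructuring replaces per-subject evaluate() calls with one run_batch over the questions of all subjects, so the num_threads worker pool spans the whole benchmark rather than being drained and refilled once per subject; per-subject accuracy is then recovered by slicing cors with the recorded num_questions offsets. A hedged, self-contained sketch of the batched pattern (sgl.function and run_batch are sglang frontend APIs; the tiny program and arguments here are illustrative, not the benchmark's few_shot_mmlu):

    import sglang as sgl

    @sgl.function
    def few_shot_qa(s, examples, question):
        s += examples + question
        s += sgl.gen("answer", max_new_tokens=1, temperature=0)

    # One argument dict per question, across all subjects.
    arguments = [
        {"examples": "Q: 1+1? A: 2\n", "question": "Q: 2+2? A:"},
        {"examples": "Q: 1+1? A: 2\n", "question": "Q: 3+3? A:"},
    ]
    # Requires a backend, e.g. sgl.set_default_backend(sgl.RuntimeEndpoint(...)).
    states = few_shot_qa.run_batch(arguments, num_threads=2, progress_bar=True)
    preds = [s["answer"].strip() for s in states]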
playground/load_tokenizer.py  (view file @ 2e6e62e1)

-import transformers
+import argparse
 import code

-#name = "meta-llama/Llama-2-7b-chat-hf"
-name = "meta-llama/Meta-Llama-3-8B-Instruct"
+from sglang.srt.hf_transformers_utils import get_tokenizer

-t = transformers.AutoTokenizer.from_pretrained(name)
-code.interact(local=locals())
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--name", type=str, default="meta-llama/Meta-Llama-3-8B-Instruct"
+    )
+    args = parser.parse_args()
+
+    t = get_tokenizer(args.name)
+    code.interact(local=locals())
\ No newline at end of file
python/sglang/srt/constrained/fsm_cache.py  (view file @ 2e6e62e1)

@@ -8,7 +8,8 @@ class FSMCache(BaseCache):
     def __init__(self, tokenizer_path, tokenizer_args_dict, enable=True):
         super().__init__(enable=enable)
-        if tokenizer_path.endswith(".json"):
+        if tokenizer_path.endswith(".json") or tokenizer_path.endswith(".model"):
+            # Do not support TiktokenTokenizer or SentencePieceTokenizer
             return

         from importlib.metadata import version
python/sglang/srt/hf_transformers_utils.py  (view file @ 2e6e62e1)

@@ -88,6 +88,9 @@ def get_tokenizer(
+    if tokenizer_name.endswith(".json"):
+        return TiktokenTokenizer(tokenizer_name)
+
+    if tokenizer_name.endswith(".model"):
+        return SentencePieceTokenizer(tokenizer_name)
+
     """Gets a tokenizer for the given model name via Huggingface."""
     if is_multimodal_model(tokenizer_name):
         processor = get_processor(
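get_tokenizer now dispatches on the tokenizer path's extension before falling through to the Hugging Face loader. A hedged usage sketch (the file paths are illustrative; the import path is the one used elsewhere in this commit):

    from sglang.srt.hf_transformers_utils import get_tokenizer

    tok_a = get_tokenizer("my_tokenizer.json")   # -> TiktokenTokenizer
    tok_b = get_tokenizer("my_tokenizer.model")  # -> SentencePieceTokenizer
    tok_c = get_tokenizer("meta-llama/Meta-Llama-3-8B-Instruct")  # -> HF tokenizer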
@@ -179,6 +182,7 @@ def get_processor(
 class TiktokenTokenizer:
     def __init__(self, tokenizer_path):
         import tiktoken
+        from jinja2 import Template

         PAT_STR_B = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
@@ -216,6 +220,7 @@ class TiktokenTokenizer:
         tokenizer = tiktoken.Encoding(**kwargs)
         tokenizer._default_allowed_special = default_allowed_special or set()
+        tokenizer._default_allowed_special |= {"<|separator|>"}

         def encode_patched(
             self,
@@ -241,6 +246,9 @@ class TiktokenTokenizer:
         self.tokenizer = tokenizer
         self.eos_token_id = tokenizer._special_tokens["<|eos|>"]
         self.vocab_size = tokenizer.n_vocab
+        self.chat_template = Template(
+            "{% for message in messages %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'system' %}{{ 'System: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + '<|separator|>\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
+        )

     def encode(self, x, add_special_tokens=False):
         return self.tokenizer.encode(x)
@@ -255,7 +263,39 @@ class TiktokenTokenizer:
             batch = [[x] for x in batch]
         return self.tokenizer.decode_batch(batch)

     def convert_ids_to_tokens(self, index):
         return self.tokenizer.decode_single_token_bytes(index).decode(
             "utf-8", errors="ignore"
         )

+    def apply_chat_template(self, messages, tokenize, add_generation_prompt):
+        ret = self.chat_template.render(
+            messages=messages, add_generation_prompt=add_generation_prompt
+        )
+        return self.encode(ret) if tokenize else ret
+
+
+class SentencePieceTokenizer:
+    def __init__(self, tokenizer_path):
+        import sentencepiece as spm
+        from jinja2 import Template
+
+        tokenizer = spm.SentencePieceProcessor(model_file=tokenizer_path)
+
+        # Convert to HF interface
+        self.tokenizer = tokenizer
+        self.eos_token_id = tokenizer.eos_id()
+        self.vocab_size = tokenizer.vocab_size()
+        self.chat_template = Template(
+            "{% for message in messages %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'system' %}{{ 'System: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + '<|separator|>\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
+        )
+
+    def encode(self, x, add_special_tokens=False):
+        return self.tokenizer.encode(x)
+
+    def decode(self, x):
+        return self.tokenizer.decode(x)
+
+    def batch_decode(
+        self, batch, skip_special_tokens=True, spaces_between_special_tokens=False
+    ):
+        if isinstance(batch[0], int):
+            batch = [[x] for x in batch]
+        return self.tokenizer.decode(batch)
+
+    def apply_chat_template(self, messages, tokenize, add_generation_prompt):
+        ret = self.chat_template.render(
+            messages=messages, add_generation_prompt=add_generation_prompt
+        )
+        return self.encode(ret) if tokenize else ret
\ No newline at end of file
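Both wrapper classes share the same Jinja chat template, so apply_chat_template behaves uniformly across them. A small sketch of what the template renders to, using jinja2 directly (no tokenizer involved; the message content is illustrative):

    from jinja2 import Template

    chat_template = Template(
        "{% for message in messages %}"
        "{% if message['role'] == 'user' %}"
        "{{ 'Human: ' + message['content'].strip() + '<|separator|>\n\n' }}"
        "{% elif message['role'] == 'system' %}"
        "{{ 'System: ' + message['content'].strip() + '<|separator|>\n\n' }}"
        "{% elif message['role'] == 'assistant' %}"
        "{{ 'Assistant: ' + message['content'] + '<|separator|>\n\n' }}"
        "{% endif %}{% endfor %}"
        "{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
    )
    out = chat_template.render(
        messages=[{"role": "user", "content": "Hello"}],
        add_generation_prompt=True,
    )
    print(repr(out))  # 'Human: Hello<|separator|>\n\nAssistant:'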
python/sglang/srt/layers/fused_moe.py  (view file @ 2e6e62e1)

@@ -317,19 +317,38 @@ def get_default_config(
     topk: int,
     dtype: Optional[str],
 ) -> Dict[str, int]:
-    config = {
-        'BLOCK_SIZE_M': 64,
-        'BLOCK_SIZE_N': 64,
-        'BLOCK_SIZE_K': 32,
-        'GROUP_SIZE_M': 8
-    }
-    if M <= E:
+    if dtype == "float8":
         config = {
-            'BLOCK_SIZE_M': 16,
-            'BLOCK_SIZE_N': 32,
-            'BLOCK_SIZE_K': 64,
-            'GROUP_SIZE_M': 1
+            'BLOCK_SIZE_M': 128,
+            'BLOCK_SIZE_N': 256,
+            'BLOCK_SIZE_K': 128,
+            'GROUP_SIZE_M': 32,
+            "num_warps": 8,
+            "num_stages": 4,
         }
+        if M <= E:
+            config = {
+                'BLOCK_SIZE_M': 64,
+                'BLOCK_SIZE_N': 128,
+                'BLOCK_SIZE_K': 128,
+                'GROUP_SIZE_M': 1,
+                "num_warps": 4,
+                "num_stages": 4,
+            }
+    else:
+        config = {
+            'BLOCK_SIZE_M': 64,
+            'BLOCK_SIZE_N': 64,
+            'BLOCK_SIZE_K': 32,
+            'GROUP_SIZE_M': 8,
+        }
+        if M <= E:
+            config = {
+                'BLOCK_SIZE_M': 16,
+                'BLOCK_SIZE_N': 32,
+                'BLOCK_SIZE_K': 64,
+                'GROUP_SIZE_M': 1,
+            }
     return config
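The new float8 branch selects larger tiles plus explicit num_warps/num_stages, while M <= E (fewer rows than experts, i.e. small batches) drops to smaller tiles in both branches. A self-contained sketch of just the selection logic from the diff (pick_config is a hypothetical stand-in name):

    from typing import Optional

    def pick_config(M: int, E: int, dtype: Optional[str]) -> dict:
        # Mirrors the branch structure of get_default_config above.
        if dtype == "float8":
            if M <= E:
                return {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128,
                        "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4}
            return {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128,
                    "GROUP_SIZE_M": 32, "num_warps": 8, "num_stages": 4}
        if M <= E:
            return {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64,
                    "GROUP_SIZE_M": 1}
        return {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32,
                "GROUP_SIZE_M": 8}

    assert pick_config(4096, 8, "float8")["num_warps"] == 8
    assert pick_config(1, 8, None)["BLOCK_SIZE_M"] == 16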
python/sglang/srt/managers/controller/manager_single.py  (view file @ 2e6e62e1)

@@ -2,7 +2,7 @@
 import asyncio
 import logging
-import time
+from concurrent.futures import ThreadPoolExecutor

 import uvloop
 import zmq
@@ -91,6 +91,7 @@ def start_controller_process(
     pipe_writer.send("init ok")

     loop = asyncio.new_event_loop()
+    loop.set_default_executor(ThreadPoolExecutor(max_workers=256))
     asyncio.set_event_loop(loop)

     loop.create_task(controller.loop_for_recv_requests())
     try:
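This hunk is the change the commit title describes: without an explicit default executor, asyncio backs run_in_executor with a ThreadPoolExecutor capped at min(32, os.cpu_count() + 4) threads (CPython 3.8+), so blocking calls issued on behalf of many tp workers can queue behind one another. A minimal self-contained sketch of the pattern (the work function is illustrative):

    import asyncio
    import time
    from concurrent.futures import ThreadPoolExecutor

    def blocking_work(i: int) -> int:
        time.sleep(0.1)  # stands in for a blocking call into a worker
        return i

    async def main() -> None:
        loop = asyncio.get_running_loop()
        # Raise the thread cap so many blocking calls can run concurrently.
        loop.set_default_executor(ThreadPoolExecutor(max_workers=256))
        results = await asyncio.gather(
            *(loop.run_in_executor(None, blocking_work, i) for i in range(64))
        )
        print(len(results))  # 64 results in ~0.1 s instead of two 32-thread waves

    asyncio.run(main())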
@@ -98,4 +99,4 @@ def start_controller_process(
     except Exception:
         logger.error("Exception in ControllerSingle:\n" + get_exception_traceback())
     finally:
-        kill_parent_process()
+        kill_parent_process()
\ No newline at end of file
python/sglang/srt/managers/controller/tp_worker.py  (view file @ 2e6e62e1)

@@ -100,7 +100,7 @@ class ModelTpServer:
         self.max_prefill_tokens = (
             max(
                 self.model_config.context_len,
-                min(self.max_total_num_tokens // 6, 65536),
+                min(self.max_total_num_tokens // 6, 32768),
             )
             if server_args.max_prefill_tokens is None
             else server_args.max_prefill_tokens
         )
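Unless --max-prefill-tokens is given, the default prefill budget is now capped at 32768 rather than 65536 tokens. A sketch of the resulting arithmetic (the helper name and values are illustrative):

    def default_max_prefill_tokens(context_len: int, max_total_num_tokens: int) -> int:
        # Mirrors the expression in the diff above.
        return max(context_len, min(max_total_num_tokens // 6, 32768))

    # e.g. an 8192-context model with room for 300k tokens in the KV pool:
    print(default_max_prefill_tokens(8192, 300_000))  # 32768 (50000 under the old 65536 cap)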