Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
62b3812b
Unverified
Commit
62b3812b
authored
Apr 09, 2024
by
Liangsheng Yin
Committed by
GitHub
Apr 09, 2024
Browse files
Time cost utils (#355)
parent
550a4f78
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
64 additions
and
45 deletions
+64
-45
python/sglang/backend/openai.py
python/sglang/backend/openai.py
+1
-2
python/sglang/srt/constrained/fsm_cache.py
python/sglang/srt/constrained/fsm_cache.py
+1
-0
python/sglang/srt/server.py
python/sglang/srt/server.py
+5
-1
python/sglang/srt/server_args.py
python/sglang/srt/server_args.py
+14
-7
python/sglang/srt/utils.py
python/sglang/srt/utils.py
+37
-29
test/srt/model/bench_llama_low_api.py
test/srt/model/bench_llama_low_api.py
+6
-6
No files found.
python/sglang/backend/openai.py
View file @
62b3812b
...
@@ -9,9 +9,8 @@ from sglang.lang.interpreter import StreamExecutor
...
@@ -9,9 +9,8 @@ from sglang.lang.interpreter import StreamExecutor
from
sglang.lang.ir
import
SglSamplingParams
from
sglang.lang.ir
import
SglSamplingParams
try
:
try
:
import
tiktoken
import
openai
import
openai
import
tiktoken
except
ImportError
as
e
:
except
ImportError
as
e
:
openai
=
tiktoken
=
e
openai
=
tiktoken
=
e
...
...
python/sglang/srt/constrained/fsm_cache.py
View file @
62b3812b
...
@@ -7,6 +7,7 @@ class FSMCache(BaseCache):
...
@@ -7,6 +7,7 @@ class FSMCache(BaseCache):
super
().
__init__
(
enable
=
enable
)
super
().
__init__
(
enable
=
enable
)
from
importlib.metadata
import
version
from
importlib.metadata
import
version
if
version
(
"outlines"
)
>=
"0.0.35"
:
if
version
(
"outlines"
)
>=
"0.0.35"
:
from
transformers
import
AutoTokenizer
from
transformers
import
AutoTokenizer
...
...
python/sglang/srt/server.py
View file @
62b3812b
...
@@ -53,7 +53,7 @@ from sglang.srt.managers.openai_protocol import (
...
@@ -53,7 +53,7 @@ from sglang.srt.managers.openai_protocol import (
from
sglang.srt.managers.router.manager
import
start_router_process
from
sglang.srt.managers.router.manager
import
start_router_process
from
sglang.srt.managers.tokenizer_manager
import
TokenizerManager
from
sglang.srt.managers.tokenizer_manager
import
TokenizerManager
from
sglang.srt.server_args
import
PortArgs
,
ServerArgs
from
sglang.srt.server_args
import
PortArgs
,
ServerArgs
from
sglang.srt.utils
import
handle_port_init
from
sglang.srt.utils
import
enable_show_time_cost
,
handle_port_init
from
starlette.middleware.base
import
BaseHTTPMiddleware
from
starlette.middleware.base
import
BaseHTTPMiddleware
from
starlette.responses
import
JSONResponse
from
starlette.responses
import
JSONResponse
...
@@ -503,6 +503,10 @@ def launch_server(server_args, pipe_finish_writer):
...
@@ -503,6 +503,10 @@ def launch_server(server_args, pipe_finish_writer):
global
tokenizer_manager
global
tokenizer_manager
global
chat_template_name
global
chat_template_name
# start show time thread
if
server_args
.
show_time_cost
:
enable_show_time_cost
()
# disable disk cache if needed
# disable disk cache if needed
if
server_args
.
disable_disk_cache
:
if
server_args
.
disable_disk_cache
:
disable_cache
()
disable_cache
()
...
...
python/sglang/srt/server_args.py
View file @
62b3812b
...
@@ -26,13 +26,14 @@ class ServerArgs:
...
@@ -26,13 +26,14 @@ class ServerArgs:
disable_log_stats
:
bool
=
False
disable_log_stats
:
bool
=
False
log_stats_interval
:
int
=
10
log_stats_interval
:
int
=
10
log_level
:
str
=
"info"
log_level
:
str
=
"info"
api_key
:
str
=
""
show_time_cost
:
bool
=
False
# optional modes
# optional modes
disable_radix_cache
:
bool
=
False
disable_radix_cache
:
bool
=
False
enable_flashinfer
:
bool
=
False
enable_flashinfer
:
bool
=
False
disable_regex_jump_forward
:
bool
=
False
disable_regex_jump_forward
:
bool
=
False
disable_disk_cache
:
bool
=
False
disable_disk_cache
:
bool
=
False
api_key
:
str
=
""
def
__post_init__
(
self
):
def
__post_init__
(
self
):
if
self
.
tokenizer_path
is
None
:
if
self
.
tokenizer_path
is
None
:
...
@@ -181,6 +182,18 @@ class ServerArgs:
...
@@ -181,6 +182,18 @@ class ServerArgs:
default
=
ServerArgs
.
log_stats_interval
,
default
=
ServerArgs
.
log_stats_interval
,
help
=
"Log stats interval in second."
,
help
=
"Log stats interval in second."
,
)
)
parser
.
add_argument
(
"--api-key"
,
type
=
str
,
default
=
ServerArgs
.
api_key
,
help
=
"Set API Key"
,
)
parser
.
add_argument
(
"--show-time-cost"
,
action
=
"store_true"
,
help
=
"Show time cost of custom marks"
,
)
# optional modes
# optional modes
parser
.
add_argument
(
parser
.
add_argument
(
"--disable-radix-cache"
,
"--disable-radix-cache"
,
...
@@ -202,12 +215,6 @@ class ServerArgs:
...
@@ -202,12 +215,6 @@ class ServerArgs:
action
=
"store_true"
,
action
=
"store_true"
,
help
=
"Disable disk cache to avoid possible crashes related to file system or high concurrency."
,
help
=
"Disable disk cache to avoid possible crashes related to file system or high concurrency."
,
)
)
parser
.
add_argument
(
"--api-key"
,
type
=
str
,
default
=
ServerArgs
.
api_key
,
help
=
"Set API Key"
,
)
@
classmethod
@
classmethod
def
from_cli_args
(
cls
,
args
:
argparse
.
Namespace
):
def
from_cli_args
(
cls
,
args
:
argparse
.
Namespace
):
...
...
python/sglang/srt/utils.py
View file @
62b3812b
...
@@ -11,48 +11,56 @@ from typing import List, Optional
...
@@ -11,48 +11,56 @@ from typing import List, Optional
import
numpy
as
np
import
numpy
as
np
import
requests
import
requests
import
torch
import
torch
import
torch.distributed
as
dist
is_show_cost_time
=
False
show_time_cost
=
False
time_infos
=
{}
def
mark_cost_time
(
func_name
):
def
enable_show_time_cost
():
def
inner_func
(
func
):
global
show_time_cost
def
time_func
(
*
args
,
**
kwargs
):
show_time_cost
=
True
if
dist
.
get_rank
()
in
[
0
,
1
]
and
is_show_cost_time
:
torch
.
cuda
.
synchronize
()
start_time
=
time
.
time
()
ans
=
func
(
*
args
,
**
kwargs
)
torch
.
cuda
.
synchronize
()
print
(
func_name
,
"cost time:"
,
(
time
.
time
()
-
start_time
)
*
1000
)
return
ans
else
:
torch
.
cuda
.
synchronize
()
ans
=
func
(
*
args
,
**
kwargs
)
torch
.
cuda
.
synchronize
()
return
ans
return
time_func
return
inner_func
class
TimeInfo
:
def
__init__
(
self
,
name
,
interval
=
0.1
,
color
=
0
,
indent
=
0
):
self
.
name
=
name
self
.
interval
=
interval
self
.
color
=
color
self
.
indent
=
indent
self
.
acc_time
=
0
self
.
last_acc_time
=
0
def
check
(
self
):
if
self
.
acc_time
-
self
.
last_acc_time
>
self
.
interval
:
self
.
last_acc_time
=
self
.
acc_time
return
True
return
False
time_mark
=
{}
def
pretty_print
(
self
):
print
(
f
"
\x1b
[
{
self
.
color
}
m"
,
end
=
""
)
print
(
"-"
*
self
.
indent
*
2
,
end
=
""
)
print
(
f
"
{
self
.
name
}
:
{
self
.
acc_time
:.
3
f
}
s
\x1b
[0m"
)
def
mark_start
(
key
):
def
mark_start
(
name
,
interval
=
0.1
,
color
=
0
,
indent
=
0
):
global
time_infos
,
show_time_cost
if
not
show_time_cost
:
return
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
global
time_mark
if
time_infos
.
get
(
name
,
None
)
is
None
:
time_
mark
[
key
]
=
t
ime
.
time
(
)
time_
infos
[
name
]
=
T
ime
Info
(
name
,
interval
,
color
,
indent
)
return
time_infos
[
name
].
acc_time
-=
time
.
time
()
def
mark_end
(
key
,
print_min_cost
=
0.0
):
def
mark_end
(
name
):
global
time_infos
,
show_time_cost
if
not
show_time_cost
:
return
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
global
time_mark
time_infos
[
name
].
acc_time
+=
time
.
time
()
cost_time
=
(
time
.
time
()
-
time_mark
[
key
])
*
1000
if
time_infos
[
name
].
check
():
if
cost_time
>
print_min_cost
:
time_infos
[
name
].
pretty_print
()
print
(
f
"cost
{
key
}
:"
,
cost_time
)
def
calculate_time
(
show
=
False
,
min_cost_ms
=
0.0
):
def
calculate_time
(
show
=
False
,
min_cost_ms
=
0.0
):
...
...
test/srt/model/bench_llama_low_api.py
View file @
62b3812b
...
@@ -66,9 +66,9 @@ class BenchBatch:
...
@@ -66,9 +66,9 @@ class BenchBatch:
p_idx
=
prefix_req_idx
[
i
//
fork_num
].
item
()
p_idx
=
prefix_req_idx
[
i
//
fork_num
].
item
()
n_idx
=
self
.
req_pool_indices
[
i
].
item
()
n_idx
=
self
.
req_pool_indices
[
i
].
item
()
req_to_token
[
n_idx
,
:
prefix_len
]
=
req_to_token
[
p_idx
,
:
prefix_len
]
req_to_token
[
n_idx
,
:
prefix_len
]
=
req_to_token
[
p_idx
,
:
prefix_len
]
req_to_token
[
req_to_token
[
n_idx
,
prefix_len
:
prefix_len
+
extend_len
]
=
(
n_idx
,
prefix_len
:
prefix_len
+
extend_len
self
.
out_cache_loc
[
i
*
extend_len
:
(
i
+
1
)
*
extend_len
]
]
=
self
.
out_cache_loc
[
i
*
extend_len
:
(
i
+
1
)
*
extend_len
]
)
def
update_decode
(
self
,
predict_ids
,
batch_size
):
def
update_decode
(
self
,
predict_ids
,
batch_size
):
assert
predict_ids
.
shape
[
0
]
==
batch_size
assert
predict_ids
.
shape
[
0
]
==
batch_size
...
@@ -81,9 +81,9 @@ class BenchBatch:
...
@@ -81,9 +81,9 @@ class BenchBatch:
self
.
out_cache_cont_start
,
self
.
out_cache_cont_start
,
self
.
out_cache_cont_end
,
self
.
out_cache_cont_end
,
)
=
self
.
token_to_kv_pool
.
alloc_contiguous
(
batch_size
)
)
=
self
.
token_to_kv_pool
.
alloc_contiguous
(
batch_size
)
self
.
req_to_token_pool
.
req_to_token
[
self
.
req_to_token_pool
.
req_to_token
[
self
.
req_pool_indices
,
self
.
seq_lens
]
=
(
self
.
req_pool_indices
,
self
.
seq_lens
self
.
out_cache_loc
]
=
self
.
out_cache_loc
)
self
.
seq_lens
.
add_
(
1
)
self
.
seq_lens
.
add_
(
1
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment