Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
62b3812b
"vscode:/vscode.git/clone" did not exist on "fac152d22b3e39cd8e9b8c94cf7eb55f175ddbd5"
Unverified
Commit
62b3812b
authored
Apr 09, 2024
by
Liangsheng Yin
Committed by
GitHub
Apr 09, 2024
Browse files
Time cost utils (#355)
parent
550a4f78
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
64 additions
and
45 deletions
+64
-45
python/sglang/backend/openai.py
python/sglang/backend/openai.py
+1
-2
python/sglang/srt/constrained/fsm_cache.py
python/sglang/srt/constrained/fsm_cache.py
+1
-0
python/sglang/srt/server.py
python/sglang/srt/server.py
+5
-1
python/sglang/srt/server_args.py
python/sglang/srt/server_args.py
+14
-7
python/sglang/srt/utils.py
python/sglang/srt/utils.py
+37
-29
test/srt/model/bench_llama_low_api.py
test/srt/model/bench_llama_low_api.py
+6
-6
No files found.
python/sglang/backend/openai.py
View file @
62b3812b
...
@@ -9,9 +9,8 @@ from sglang.lang.interpreter import StreamExecutor
...
@@ -9,9 +9,8 @@ from sglang.lang.interpreter import StreamExecutor
from
sglang.lang.ir
import
SglSamplingParams
from
sglang.lang.ir
import
SglSamplingParams
try
:
try
:
import
tiktoken
import
openai
import
openai
import
tiktoken
except
ImportError
as
e
:
except
ImportError
as
e
:
openai
=
tiktoken
=
e
openai
=
tiktoken
=
e
...
...
python/sglang/srt/constrained/fsm_cache.py
View file @
62b3812b
...
@@ -7,6 +7,7 @@ class FSMCache(BaseCache):
...
@@ -7,6 +7,7 @@ class FSMCache(BaseCache):
super
().
__init__
(
enable
=
enable
)
super
().
__init__
(
enable
=
enable
)
from
importlib.metadata
import
version
from
importlib.metadata
import
version
if
version
(
"outlines"
)
>=
"0.0.35"
:
if
version
(
"outlines"
)
>=
"0.0.35"
:
from
transformers
import
AutoTokenizer
from
transformers
import
AutoTokenizer
...
...
python/sglang/srt/server.py
View file @
62b3812b
...
@@ -53,7 +53,7 @@ from sglang.srt.managers.openai_protocol import (
...
@@ -53,7 +53,7 @@ from sglang.srt.managers.openai_protocol import (
from
sglang.srt.managers.router.manager
import
start_router_process
from
sglang.srt.managers.router.manager
import
start_router_process
from
sglang.srt.managers.tokenizer_manager
import
TokenizerManager
from
sglang.srt.managers.tokenizer_manager
import
TokenizerManager
from
sglang.srt.server_args
import
PortArgs
,
ServerArgs
from
sglang.srt.server_args
import
PortArgs
,
ServerArgs
from
sglang.srt.utils
import
handle_port_init
from
sglang.srt.utils
import
enable_show_time_cost
,
handle_port_init
from
starlette.middleware.base
import
BaseHTTPMiddleware
from
starlette.middleware.base
import
BaseHTTPMiddleware
from
starlette.responses
import
JSONResponse
from
starlette.responses
import
JSONResponse
...
@@ -503,6 +503,10 @@ def launch_server(server_args, pipe_finish_writer):
...
@@ -503,6 +503,10 @@ def launch_server(server_args, pipe_finish_writer):
global
tokenizer_manager
global
tokenizer_manager
global
chat_template_name
global
chat_template_name
# start show time thread
if
server_args
.
show_time_cost
:
enable_show_time_cost
()
# disable disk cache if needed
# disable disk cache if needed
if
server_args
.
disable_disk_cache
:
if
server_args
.
disable_disk_cache
:
disable_cache
()
disable_cache
()
...
...
python/sglang/srt/server_args.py
View file @
62b3812b
...
@@ -26,13 +26,14 @@ class ServerArgs:
...
@@ -26,13 +26,14 @@ class ServerArgs:
disable_log_stats
:
bool
=
False
disable_log_stats
:
bool
=
False
log_stats_interval
:
int
=
10
log_stats_interval
:
int
=
10
log_level
:
str
=
"info"
log_level
:
str
=
"info"
api_key
:
str
=
""
show_time_cost
:
bool
=
False
# optional modes
# optional modes
disable_radix_cache
:
bool
=
False
disable_radix_cache
:
bool
=
False
enable_flashinfer
:
bool
=
False
enable_flashinfer
:
bool
=
False
disable_regex_jump_forward
:
bool
=
False
disable_regex_jump_forward
:
bool
=
False
disable_disk_cache
:
bool
=
False
disable_disk_cache
:
bool
=
False
api_key
:
str
=
""
def
__post_init__
(
self
):
def
__post_init__
(
self
):
if
self
.
tokenizer_path
is
None
:
if
self
.
tokenizer_path
is
None
:
...
@@ -181,6 +182,18 @@ class ServerArgs:
...
@@ -181,6 +182,18 @@ class ServerArgs:
default
=
ServerArgs
.
log_stats_interval
,
default
=
ServerArgs
.
log_stats_interval
,
help
=
"Log stats interval in second."
,
help
=
"Log stats interval in second."
,
)
)
parser
.
add_argument
(
"--api-key"
,
type
=
str
,
default
=
ServerArgs
.
api_key
,
help
=
"Set API Key"
,
)
parser
.
add_argument
(
"--show-time-cost"
,
action
=
"store_true"
,
help
=
"Show time cost of custom marks"
,
)
# optional modes
# optional modes
parser
.
add_argument
(
parser
.
add_argument
(
"--disable-radix-cache"
,
"--disable-radix-cache"
,
...
@@ -202,12 +215,6 @@ class ServerArgs:
...
@@ -202,12 +215,6 @@ class ServerArgs:
action
=
"store_true"
,
action
=
"store_true"
,
help
=
"Disable disk cache to avoid possible crashes related to file system or high concurrency."
,
help
=
"Disable disk cache to avoid possible crashes related to file system or high concurrency."
,
)
)
parser
.
add_argument
(
"--api-key"
,
type
=
str
,
default
=
ServerArgs
.
api_key
,
help
=
"Set API Key"
,
)
@
classmethod
@
classmethod
def
from_cli_args
(
cls
,
args
:
argparse
.
Namespace
):
def
from_cli_args
(
cls
,
args
:
argparse
.
Namespace
):
...
...
python/sglang/srt/utils.py
View file @
62b3812b
...
@@ -11,48 +11,56 @@ from typing import List, Optional
...
@@ -11,48 +11,56 @@ from typing import List, Optional
import
numpy
as
np
import
numpy
as
np
import
requests
import
requests
import
torch
import
torch
import
torch.distributed
as
dist
is_show_cost_time
=
False
show_time_cost
=
False
time_infos
=
{}
def
mark_cost_time
(
func_name
):
def
enable_show_time_cost
():
def
inner_func
(
func
):
global
show_time_cost
def
time_func
(
*
args
,
**
kwargs
):
show_time_cost
=
True
if
dist
.
get_rank
()
in
[
0
,
1
]
and
is_show_cost_time
:
torch
.
cuda
.
synchronize
()
start_time
=
time
.
time
()
ans
=
func
(
*
args
,
**
kwargs
)
torch
.
cuda
.
synchronize
()
print
(
func_name
,
"cost time:"
,
(
time
.
time
()
-
start_time
)
*
1000
)
return
ans
else
:
torch
.
cuda
.
synchronize
()
ans
=
func
(
*
args
,
**
kwargs
)
torch
.
cuda
.
synchronize
()
return
ans
return
time_func
return
inner_func
class
TimeInfo
:
def
__init__
(
self
,
name
,
interval
=
0.1
,
color
=
0
,
indent
=
0
):
self
.
name
=
name
self
.
interval
=
interval
self
.
color
=
color
self
.
indent
=
indent
self
.
acc_time
=
0
self
.
last_acc_time
=
0
time_mark
=
{}
def
check
(
self
):
if
self
.
acc_time
-
self
.
last_acc_time
>
self
.
interval
:
self
.
last_acc_time
=
self
.
acc_time
return
True
return
False
def
pretty_print
(
self
):
print
(
f
"
\x1b
[
{
self
.
color
}
m"
,
end
=
""
)
print
(
"-"
*
self
.
indent
*
2
,
end
=
""
)
print
(
f
"
{
self
.
name
}
:
{
self
.
acc_time
:.
3
f
}
s
\x1b
[0m"
)
def
mark_start
(
key
):
torch
.
cuda
.
synchronize
()
def
mark_start
(
name
,
interval
=
0.1
,
color
=
0
,
indent
=
0
):
global
time_
mark
global
time_
infos
,
show_time_cost
time_mark
[
key
]
=
time
.
time
()
if
not
show_time_cost
:
return
return
torch
.
cuda
.
synchronize
()
if
time_infos
.
get
(
name
,
None
)
is
None
:
time_infos
[
name
]
=
TimeInfo
(
name
,
interval
,
color
,
indent
)
time_infos
[
name
].
acc_time
-=
time
.
time
()
def
mark_end
(
key
,
print_min_cost
=
0.0
):
def
mark_end
(
name
):
global
time_infos
,
show_time_cost
if
not
show_time_cost
:
return
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
global
time_mark
time_infos
[
name
].
acc_time
+=
time
.
time
()
cost_time
=
(
time
.
time
()
-
time_mark
[
key
])
*
1000
if
time_infos
[
name
].
check
():
if
cost_time
>
print_min_cost
:
time_infos
[
name
].
pretty_print
()
print
(
f
"cost
{
key
}
:"
,
cost_time
)
def
calculate_time
(
show
=
False
,
min_cost_ms
=
0.0
):
def
calculate_time
(
show
=
False
,
min_cost_ms
=
0.0
):
...
...
test/srt/model/bench_llama_low_api.py
View file @
62b3812b
...
@@ -66,9 +66,9 @@ class BenchBatch:
...
@@ -66,9 +66,9 @@ class BenchBatch:
p_idx
=
prefix_req_idx
[
i
//
fork_num
].
item
()
p_idx
=
prefix_req_idx
[
i
//
fork_num
].
item
()
n_idx
=
self
.
req_pool_indices
[
i
].
item
()
n_idx
=
self
.
req_pool_indices
[
i
].
item
()
req_to_token
[
n_idx
,
:
prefix_len
]
=
req_to_token
[
p_idx
,
:
prefix_len
]
req_to_token
[
n_idx
,
:
prefix_len
]
=
req_to_token
[
p_idx
,
:
prefix_len
]
req_to_token
[
req_to_token
[
n_idx
,
prefix_len
:
prefix_len
+
extend_len
]
=
(
n_idx
,
prefix_len
:
prefix_len
+
extend_len
self
.
out_cache_loc
[
i
*
extend_len
:
(
i
+
1
)
*
extend_len
]
]
=
self
.
out_cache_loc
[
i
*
extend_len
:
(
i
+
1
)
*
extend_len
]
)
def
update_decode
(
self
,
predict_ids
,
batch_size
):
def
update_decode
(
self
,
predict_ids
,
batch_size
):
assert
predict_ids
.
shape
[
0
]
==
batch_size
assert
predict_ids
.
shape
[
0
]
==
batch_size
...
@@ -81,9 +81,9 @@ class BenchBatch:
...
@@ -81,9 +81,9 @@ class BenchBatch:
self
.
out_cache_cont_start
,
self
.
out_cache_cont_start
,
self
.
out_cache_cont_end
,
self
.
out_cache_cont_end
,
)
=
self
.
token_to_kv_pool
.
alloc_contiguous
(
batch_size
)
)
=
self
.
token_to_kv_pool
.
alloc_contiguous
(
batch_size
)
self
.
req_to_token_pool
.
req_to_token
[
self
.
req_to_token_pool
.
req_to_token
[
self
.
req_pool_indices
,
self
.
seq_lens
]
=
(
self
.
req_pool_indices
,
self
.
seq_lens
self
.
out_cache_loc
]
=
self
.
out_cache_loc
)
self
.
seq_lens
.
add_
(
1
)
self
.
seq_lens
.
add_
(
1
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment