sglang / Commits / 1b355479

Organize `server_args` (#277)

Unverified commit 1b355479, authored Mar 11, 2024 by Liangsheng Yin, committed by GitHub on Mar 11, 2024.
Parent: faba293a

Showing 12 changed files with 92 additions and 34 deletions (+92, −34):
- docs/flashinfer.md (+2, −2)
- python/sglang/api.py (+11, −0)
- python/sglang/backend/base_backend.py (+6, −0)
- python/sglang/backend/runtime_endpoint.py (+16, −0)
- python/sglang/srt/layers/radix_attention.py (+2, −4)
- python/sglang/srt/layers/token_attention.py (+2, −2)
- python/sglang/srt/managers/router/model_rpc.py (+10, −4)
- python/sglang/srt/managers/router/model_runner.py (+5, −5)
- python/sglang/srt/managers/tokenizer_manager.py (+2, −0)
- python/sglang/srt/server.py (+6, −2)
- python/sglang/srt/server_args.py (+24, −9)
- test/srt/model/bench_llama_low_api.py (+6, −6)
docs/flashinfer.md

````diff
@@ -16,10 +16,10 @@ please build it from source (the compilation takes a long time).
 
 ### Run a Server With Flashinfer Mode
-Add `--model-mode flashinfer` argument to enable flashinfer when launching a server.
+Add `--enable-flashinfer` argument to enable flashinfer when launching a server.
 
 Example:
 ```bash
-python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --model-mode flashinfer
+python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --enable-flashinfer
 ```
````
python/sglang/api.py

```diff
@@ -43,6 +43,17 @@ def Runtime(*args, **kwargs):
 def set_default_backend(backend: BaseBackend):
     global_config.default_backend = backend
 
 
+def flush_cache(backend: BaseBackend = None):
+    backend = backend or global_config.default_backend
+    if backend is None:
+        return False
+    return backend.flush_cache()
+
+
+def get_server_args(backend: BaseBackend = None):
+    backend = backend or global_config.default_backend
+    if backend is None:
+        return None
+    return backend.get_server_args()
+
+
 def gen(
     name: Optional[str] = None,
```
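The two new frontend helpers delegate to whichever backend is active. A minimal usage sketch, assuming the helpers are re-exported at the package level like the other `api.py` functions (the server URL is illustrative):

```python
import sglang as sgl
from sglang.backend.runtime_endpoint import RuntimeEndpoint

# Point the frontend at a running SRT server (URL is an assumption).
sgl.set_default_backend(RuntimeEndpoint("http://127.0.0.1:30000"))

print(sgl.get_server_args())  # dict of the server's ServerArgs, or None if no backend
print(sgl.flush_cache())      # True on success, False if no backend is set
```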
python/sglang/backend/base_backend.py

```diff
@@ -72,3 +72,9 @@ class BaseBackend:
 
     def shutdown(self):
         pass
+
+    def flush_cache(self):
+        pass
+
+    def get_server_args(self):
+        pass
```
python/sglang/backend/runtime_endpoint.py

```diff
@@ -35,6 +35,22 @@ class RuntimeEndpoint(BaseBackend):
 
     def get_model_name(self):
         return self.model_info["model_path"]
 
+    def flush_cache(self):
+        res = http_request(
+            self.base_url + "/flush_cache",
+            auth_token=self.auth_token,
+            verify=self.verify,
+        )
+        return res.status_code == 200
+
+    def get_server_args(self):
+        res = http_request(
+            self.base_url + "/get_server_args",
+            auth_token=self.auth_token,
+            verify=self.verify,
+        )
+        return res.json()
+
     def get_chat_template(self):
         return self.chat_template
```
python/sglang/srt/layers/radix_attention.py

```diff
@@ -15,11 +15,9 @@ class RadixAttention(nn.Module):
         self.head_dim = head_dim
         self.layer_id = layer_id
 
-        from sglang.srt.managers.router.model_runner import global_server_args
+        from sglang.srt.managers.router.model_runner import global_server_args_dict
 
-        self.use_flashinfer = "flashinfer" in global_server_args.model_mode
-
-        if self.use_flashinfer:
+        if global_server_args_dict["enable_flashinfer"]:
             self.prefill_forward = self.prefill_forward_flashinfer
             self.extend_forward = self.prefill_forward_flashinfer
             self.decode_forward = self.decode_forward_flashinfer
```
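Note that the layer selects its kernels once at construction time by rebinding methods, rather than branching on every forward call. A standalone sketch of that dispatch pattern (class and method names here are illustrative, not the real layer):

```python
class AttentionLayer:
    def __init__(self, enable_flashinfer: bool):
        # Choose the kernel implementation once, at construction time,
        # instead of checking the flag on every forward pass.
        if enable_flashinfer:
            self.decode_forward = self._decode_flashinfer
        else:
            self.decode_forward = self._decode_triton

    def _decode_flashinfer(self, x):
        return f"flashinfer({x})"

    def _decode_triton(self, x):
        return f"triton({x})"

layer = AttentionLayer(enable_flashinfer=True)
print(layer.decode_forward("q"))  # flashinfer(q)
```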
python/sglang/srt/layers/token_attention.py

```diff
@@ -4,10 +4,10 @@
 import torch
 import triton
 import triton.language as tl
 
-from sglang.srt.managers.router.model_runner import global_server_args
+from sglang.srt.managers.router.model_runner import global_server_args_dict
 from sglang.srt.utils import wrap_kernel_launcher
 
-if global_server_args.attention_reduce_in_fp32:
+if global_server_args_dict["attention_reduce_in_fp32"]:
     REDUCE_TRITON_TYPE = tl.float32
     REDUCE_TORCH_TYPE = torch.float32
 else:
```
python/sglang/srt/managers/router/model_rpc.py

```diff
@@ -46,7 +46,6 @@ class ModelRpcServer(rpyc.Service):
         server_args, port_args = [obtain(x) for x in [server_args, port_args]]
 
         # Copy arguments
-        self.model_mode = server_args.model_mode
         self.tp_rank = tp_rank
         self.tp_size = server_args.tp_size
         self.schedule_heuristic = server_args.schedule_heuristic
@@ -61,15 +60,22 @@ class ModelRpcServer(rpyc.Service):
             server_args.trust_remote_code,
             context_length=server_args.context_length,
         )
+
+        # for model end global settings
+        server_args_dict = {
+            "enable_flashinfer": server_args.enable_flashinfer,
+            "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
+        }
+
         self.model_runner = ModelRunner(
             model_config=self.model_config,
             mem_fraction_static=server_args.mem_fraction_static,
             tp_rank=tp_rank,
             tp_size=server_args.tp_size,
             nccl_port=port_args.nccl_port,
-            server_args=server_args,
             load_format=server_args.load_format,
             trust_remote_code=server_args.trust_remote_code,
+            server_args_dict=server_args_dict,
         )
         if is_multimodal_model(server_args.model_path):
             self.processor = get_processor(
@@ -104,11 +110,11 @@ class ModelRpcServer(rpyc.Service):
             f"max_total_num_token={self.max_total_num_token}, "
             f"max_prefill_num_token={self.max_prefill_num_token}, "
             f"context_len={self.model_config.context_len}, "
-            f"model_mode={self.model_mode}"
         )
+        logger.info(server_args.get_optional_modes_logging())
 
         # Init cache
-        self.tree_cache = RadixCache(disable="no-cache" in self.model_mode)
+        self.tree_cache = RadixCache(server_args.disable_radix_cache)
         self.tree_cache_metrics = {"total": 0, "hit": 0}
         self.scheduler = Scheduler(
             self.schedule_heuristic,
```
python/sglang/srt/managers/router/model_runner.py

```diff
@@ -23,7 +23,7 @@ logger = logging.getLogger("model_runner")
 
 # for server args in model endpoints
-global_server_args = None
+global_server_args_dict: dict = None
 
 
 @lru_cache()
@@ -222,7 +222,7 @@ class InputMetadata:
         if forward_mode == ForwardMode.EXTEND:
             ret.init_extend_args()
 
-        if "flashinfer" in global_server_args.model_mode:
+        if global_server_args_dict["enable_flashinfer"]:
             ret.init_flashinfer_args(tp_size)
 
         return ret
@@ -236,9 +236,9 @@ class ModelRunner:
         tp_rank,
         tp_size,
         nccl_port,
-        server_args,
         load_format="auto",
         trust_remote_code=True,
+        server_args_dict: dict = {},
     ):
         self.model_config = model_config
         self.mem_fraction_static = mem_fraction_static
@@ -248,8 +248,8 @@ class ModelRunner:
         self.load_format = load_format
         self.trust_remote_code = trust_remote_code
 
-        global global_server_args
-        global_server_args = server_args
+        global global_server_args_dict
+        global_server_args_dict = server_args_dict
 
         # Init torch distributed
         torch.cuda.set_device(self.tp_rank)
```
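The net effect of this plumbing: instead of stashing the whole `ServerArgs` object as a module-level global, the router now publishes a small dict containing only the model-level settings that layers actually read. A condensed sketch of the pattern, stripped of the real classes (function names here are illustrative):

```python
# Module-level settings published by the model runner; layers read this
# instead of importing the full ServerArgs dataclass.
global_server_args_dict: dict = None

def init_model_runner(server_args_dict: dict):
    # Called once per process, before any layer is constructed.
    global global_server_args_dict
    global_server_args_dict = server_args_dict

def build_attention_backend():
    # A layer reads only the flag it cares about.
    if global_server_args_dict["enable_flashinfer"]:
        return "flashinfer"
    return "triton"

init_model_runner({"enable_flashinfer": True, "attention_reduce_in_fp32": False})
print(build_attention_backend())  # flashinfer
```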
python/sglang/srt/managers/tokenizer_manager.py

```diff
@@ -82,6 +82,8 @@ class TokenizerManager:
         server_args: ServerArgs,
         port_args: PortArgs,
     ):
+        self.server_args = server_args
+
         context = zmq.asyncio.Context(2)
         self.recv_from_detokenizer = context.socket(zmq.PULL)
         self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{port_args.tokenizer_port}")
```
python/sglang/srt/server.py

```diff
 """SRT: SGLang Runtime"""
 
 import asyncio
+import dataclasses
 import json
 import multiprocessing as mp
 import os
@@ -86,6 +87,11 @@ async def get_model_info():
     return result
 
 
+@app.get("/get_server_args")
+async def get_server_args():
+    return dataclasses.asdict(tokenizer_manager.server_args)
+
+
 @app.get("/flush_cache")
 async def flush_cache():
     await tokenizer_manager.flush_cache()
@@ -548,7 +554,6 @@ class Runtime:
         max_prefill_num_token: int = ServerArgs.max_prefill_num_token,
         context_length: int = ServerArgs.context_length,
         tp_size: int = 1,
-        model_mode: List[str] = (),
         schedule_heuristic: str = "lpm",
         attention_reduce_in_fp32: bool = False,
         random_seed: int = 42,
@@ -571,7 +576,6 @@ class Runtime:
             max_prefill_num_token=max_prefill_num_token,
             context_length=context_length,
             tp_size=tp_size,
-            model_mode=model_mode,
             schedule_heuristic=schedule_heuristic,
             attention_reduce_in_fp32=attention_reduce_in_fp32,
             random_seed=random_seed,
```
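The two HTTP routes are easy to exercise by hand. A hedged sketch using the `requests` library (not part of this diff; the port matches the docs example above):

```python
import requests

base = "http://127.0.0.1:30000"

# Dump the full ServerArgs the server was launched with.
print(requests.get(f"{base}/get_server_args").json())

# Flush the radix cache; the backend treats HTTP 200 as success.
r = requests.get(f"{base}/flush_cache")
print("flushed:", r.status_code == 200)
```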
python/sglang/srt/server_args.py

```diff
@@ -18,7 +18,6 @@ class ServerArgs:
     max_prefill_num_token: Optional[int] = None
     context_length: Optional[int] = None
     tp_size: int = 1
-    model_mode: List[str] = ()
     schedule_heuristic: str = "lpm"
     schedule_conservativeness: float = 1.0
     attention_reduce_in_fp32: bool = False
@@ -27,6 +26,10 @@ class ServerArgs:
     disable_log_stats: bool = False
     log_stats_interval: int = 10
     log_level: str = "info"
+    # optional modes
+    disable_radix_cache: bool = False
+    enable_flashinfer: bool = False
+    disable_regex_jump_forward: bool = False
     disable_disk_cache: bool = False
@@ -131,14 +134,6 @@ class ServerArgs:
             default=ServerArgs.tp_size,
             help="Tensor parallelism degree.",
         )
-        parser.add_argument(
-            "--model-mode",
-            type=str,
-            default=[],
-            nargs="+",
-            choices=["flashinfer", "no-cache"],
-            help="Model mode: [flashinfer, no-cache]",
-        )
         parser.add_argument(
             "--schedule-heuristic",
             type=str,
@@ -185,6 +180,17 @@ class ServerArgs:
             default=ServerArgs.log_stats_interval,
             help="Log stats interval in second.",
         )
+        # optional modes
+        parser.add_argument(
+            "--disable-radix-cache",
+            action="store_true",
+            help="Disable RadixAttention",
+        )
+        parser.add_argument(
+            "--enable-flashinfer",
+            action="store_true",
+            help="Enable flashinfer inference kernels",
+        )
         parser.add_argument(
             "--disable-regex-jump-forward",
             action="store_true",
@@ -204,6 +210,15 @@ class ServerArgs:
     def url(self):
         return f"http://{self.host}:{self.port}"
 
+    def get_optional_modes_logging(self):
+        return (
+            f"disable_radix_cache={self.disable_radix_cache}, "
+            f"enable_flashinfer={self.enable_flashinfer}, "
+            f"disable_regex_jump_forward={self.disable_regex_jump_forward}, "
+            f"disable_disk_cache={self.disable_disk_cache}, "
+            f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}"
+        )
+
 
 @dataclasses.dataclass
 class PortArgs:
```
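With the catch-all `--model-mode` list gone, each optional behavior becomes an explicit boolean field with a matching `store_true` flag. A hedged usage sketch, assuming `model_path` is the only required field of the dataclass:

```python
from sglang.srt.server_args import ServerArgs

# Programmatic construction mirrors the new CLI flags
# (--enable-flashinfer, --disable-radix-cache, ...).
args = ServerArgs(
    model_path="meta-llama/Llama-2-7b-chat-hf",
    enable_flashinfer=True,
    disable_radix_cache=True,
)
print(args.get_optional_modes_logging())
# disable_radix_cache=True, enable_flashinfer=True, disable_regex_jump_forward=False, ...
```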
test/srt/model/bench_llama_low_api.py

```diff
@@ -151,7 +151,7 @@ def bench_generate_worker(
     shared_len,
     unique_len,
     decode_len,
-    model_mode,
+    server_args_dict,
 ):
     assert unique_num % shared_num == 0
@@ -162,7 +162,7 @@ def bench_generate_worker(
         tp_rank=tp_rank,
         tp_size=tp_size,
         nccl_port=28888,
-        model_mode=model_mode,
+        server_args_dict=server_args_dict,
     )
 
     batch = BenchBatch(model_runner)
@@ -227,7 +227,7 @@ def bench_generate(
     shared_len,
     unique_len,
     decode_len,
-    model_mode,
+    server_args_dict,
 ):
     print(
         f"tp_size: {tp_size}, "
@@ -236,7 +236,7 @@ def bench_generate(
         f"shared_len: {shared_len}, "
         f"unique_len: {unique_len}, "
         f"decode_len: {decode_len}, "
-        f"model_mode: {model_mode}"
+        f"server_args: {server_args_dict}"
     )
 
     workers = []
     for tp_rank in range(tp_size):
@@ -251,7 +251,7 @@ def bench_generate(
                 shared_len,
                 unique_len,
                 decode_len,
-                model_mode,
+                server_args_dict,
             ),
         )
         proc.start()
@@ -270,5 +270,5 @@ if __name__ == "__main__":
         shared_len=256,
         unique_len=256,
         decode_len=8,
-        model_mode=[],
+        server_args_dict={},
     )
```