change / sglang · Commits · 2cea6146

Improve logging & add logit cap (#471)

Unverified commit 2cea6146, authored May 24, 2024 by Lianmin Zheng, committed by GitHub on May 24, 2024.
Parent: 44c998fc

Showing 12 changed files with 106 additions and 24 deletions (+106, -24).
benchmark/latency_throughput/test_latency.py        +1   -1
python/sglang/srt/constrained/fsm_cache.py          +3   -0
python/sglang/srt/hf_transformers_utils.py          +24  -0
python/sglang/srt/layers/extend_attention.py        +17  -0
python/sglang/srt/layers/radix_attention.py         +8   -1
python/sglang/srt/layers/token_attention.py         +15  -0
python/sglang/srt/managers/detokenizer_manager.py   +4   -1
python/sglang/srt/managers/router/model_rpc.py      +3   -4
python/sglang/srt/managers/router/model_runner.py   +15  -13
python/sglang/srt/server.py                         +1   -2
python/sglang/srt/utils.py                          +1   -1
python/sglang/utils.py                              +14  -1
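The central change is a soft cap on attention logits: wherever the Triton kernels compute attention scores, a score tensor qk is replaced by logit_cap * tanh(qk / logit_cap) when logit_cap > 0, which bounds every score to (-logit_cap, logit_cap) while leaving small scores nearly unchanged. A minimal PyTorch sketch of the transform, as a reference for reading the kernels below (not code from this commit):

    import torch

    def soft_cap(scores: torch.Tensor, logit_cap: float) -> torch.Tensor:
        # When the cap is disabled (the new APIs default to logit_cap=-1),
        # scores pass through untouched.
        if logit_cap <= 0:
            return scores
        return logit_cap * torch.tanh(scores / logit_cap)

    scores = torch.tensor([-100.0, -1.0, 0.0, 1.0, 100.0])
    print(soft_cap(scores, 30.0))  # ≈ [-29.92, -1.00, 0.00, 1.00, 29.92]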
benchmark/latency_throughput/test_latency.py

@@ -30,7 +30,7 @@ if __name__ == "__main__":
     response = requests.post(
         url + "/generate",
         json={
-            "text": f"{a},",
+            "text": f"The capital of France is",
             # "input_ids": [[2] * 256] * 196,
             "sampling_params": {
                 "temperature": 0,
python/sglang/srt/constrained/fsm_cache.py

@@ -6,6 +6,9 @@ class FSMCache(BaseCache):
     def __init__(self, tokenizer_path, tokenizer_args_dict, enable=True):
         super().__init__(enable=enable)
+        if tokenizer_path.endswith(".json"):
+            return
+
         from importlib.metadata import version

         if version("outlines") >= "0.0.35":
python/sglang/srt/hf_transformers_utils.py

@@ -84,6 +84,9 @@ def get_tokenizer(
     tokenizer_revision: Optional[str] = None,
     **kwargs,
 ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
+    if tokenizer_name.endswith(".json"):
+        return TiktokenTokenizer(tokenizer_name)
+
     """Gets a tokenizer for the given model name via Huggingface."""
     if is_multimodal_model(tokenizer_name):
         processor = get_processor(

@@ -170,3 +173,24 @@ def get_processor(
         **kwargs,
     )
     return processor
+
+
+class TiktokenTokenizer:
+    def __init__(self, tokenizer_path):
+        import xlm.tokenizers.tiktoken_wrapper as tiktoken_wrapper
+
+        tokenizer = tiktoken_wrapper.Encoding.from_xtok_json("xtok-json", tokenizer_path)
+        self.tokenizer = tokenizer
+        self.eos_token_id = tokenizer.eos_token
+        self.vocab_size = tokenizer.n_vocab
+
+    def encode(self, x):
+        return self.tokenizer.encode(x)
+
+    def decode(self, x):
+        return self.tokenizer.decode(x)
+
+    def batch_decode(self, batch, skip_special_tokens, spaces_between_special_tokens):
+        return self.tokenizer.decode_batch(batch)
+
+    def convert_ids_to_tokens(self, index):
+        return self.tokenizer.decode_single_token_bytes(index).decode("utf-8", errors="ignore")
\ No newline at end of file
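TiktokenTokenizer only has to quack like the Hugging Face tokenizers the rest of sglang consumes: encode, decode, batch_decode, convert_ids_to_tokens, plus eos_token_id and vocab_size. Note that the xlm.tokenizers.tiktoken_wrapper module it imports is an external dependency not added by this commit. A rough equivalent of the same duck-typed interface built on the public tiktoken package (an illustration of the interface, not the wrapper the commit imports):

    import tiktoken

    class TiktokenLikeTokenizer:
        # Duck-typed stand-in exposing the methods sglang calls on a tokenizer.

        def __init__(self, encoding_name: str = "cl100k_base"):
            self.tokenizer = tiktoken.get_encoding(encoding_name)
            self.vocab_size = self.tokenizer.n_vocab
            self.eos_token_id = self.tokenizer.eot_token  # tiktoken's end-of-text id

        def encode(self, x):
            return self.tokenizer.encode(x)

        def decode(self, ids):
            return self.tokenizer.decode(ids)

        def batch_decode(self, batch, skip_special_tokens=True,
                         spaces_between_special_tokens=False):
            return self.tokenizer.decode_batch(batch)

        def convert_ids_to_tokens(self, index):
            return self.tokenizer.decode_single_token_bytes(index).decode(
                "utf-8", errors="ignore"
            )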
python/sglang/srt/layers/extend_attention.py

@@ -8,6 +8,12 @@ from sglang.srt.utils import wrap_kernel_launcher
 CUDA_CAPABILITY = torch.cuda.get_device_capability()


+@triton.jit
+def tanh(x):
+    # Tanh is just a scaled sigmoid
+    return 2 * tl.sigmoid(2 * x) - 1
+
+
 @triton.jit
 def _fwd_kernel(
     Q_Extend,

@@ -39,6 +45,7 @@ def _fwd_kernel(
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
+    logit_cap: tl.constexpr,
 ):
     cur_seq = tl.program_id(0)
     cur_head = tl.program_id(1)

@@ -90,6 +97,10 @@ def _fwd_kernel(
         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
         qk += tl.dot(q, k)
         qk *= sm_scale
+
+        if logit_cap > 0:
+            qk = logit_cap * tanh(qk / logit_cap)
+
         qk = tl.where(mask_m[:, None] & mask_n[None, :], qk, float("-inf"))
         n_e_max = tl.maximum(tl.max(qk, 1), e_max)

@@ -126,6 +137,10 @@ def _fwd_kernel(
         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
         qk += tl.dot(q, k)
         qk *= sm_scale
+
+        if logit_cap > 0:
+            qk = logit_cap * tanh(qk / logit_cap)
+
         mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
             start_n + offs_n[None, :]
         )

@@ -176,6 +191,7 @@ def extend_attention_fwd(
     b_seq_len_extend,
     max_len_in_batch,
     max_len_extend,
+    logit_cap=-1,
 ):
     """
     q_extend, k_extend, v_extend, o_extend: contiguous tensors

@@ -271,6 +287,7 @@ def extend_attention_fwd(
         BLOCK_N=BLOCK_N,
         num_warps=num_warps,
         num_stages=num_stages,
+        logit_cap=logit_cap,
     )
     cached_kernel = wrap_kernel_launcher(_fwd_kernel)
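The tanh helper builds on the identity tanh(x) = 2·sigmoid(2x) − 1, expressing the cap through tl.sigmoid, which Triton provides as a primitive. A quick sanity check of the identity in plain PyTorch:

    import torch

    x = torch.linspace(-5.0, 5.0, steps=11)
    assert torch.allclose(2 * torch.sigmoid(2 * x) - 1, torch.tanh(x), atol=1e-6)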
python/sglang/srt/layers/radix_attention.py

 import torch
+import numpy as np
 from torch import nn

 from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd

@@ -8,13 +9,16 @@ from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata
 class RadixAttention(nn.Module):
-    def __init__(self, num_heads, head_dim, scaling, num_kv_heads, layer_id):
+    def __init__(self, num_heads, head_dim, scaling, num_kv_heads, layer_id, logit_cap=-1):
         super().__init__()
         self.tp_q_head_num = num_heads
         self.tp_k_head_num = num_kv_heads
         self.tp_v_head_num = num_kv_heads
         self.head_dim = head_dim
         self.layer_id = layer_id
+        self.logit_cap = logit_cap
+
+        assert np.allclose(scaling, 1.0 / (head_dim**0.5))

         from sglang.srt.managers.router.model_runner import global_server_args_dict

@@ -38,6 +42,7 @@ class RadixAttention(nn.Module):
             input_metadata.start_loc,
             input_metadata.seq_lens,
             input_metadata.max_seq_len,
+            self.logit_cap,
         )
         self.store_kv_cache(k, v, input_metadata)

@@ -62,6 +67,7 @@ class RadixAttention(nn.Module):
             input_metadata.extend_seq_lens,
             input_metadata.max_seq_len,
             input_metadata.max_extend_len,
+            self.logit_cap,
         )
         return o

@@ -82,6 +88,7 @@ class RadixAttention(nn.Module):
             input_metadata.max_seq_len,
             input_metadata.other_kv_index,
             input_metadata.total_num_tokens,
+            self.logit_cap,
         )
         return o
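RadixAttention now owns the cap and threads it through all three code paths (prefill, extend, and decode), so a model sets it once at layer construction. A hypothetical call site (the shapes and the cap value are illustrative, not from this commit):

    # Hypothetical model code; RadixAttention is the layer changed above.
    self.attn = RadixAttention(
        num_heads=32,
        head_dim=128,
        scaling=128 ** -0.5,  # the new assert requires scaling == 1/sqrt(head_dim)
        num_kv_heads=8,
        layer_id=layer_id,
        logit_cap=30.0,       # > 0 enables qk = logit_cap * tanh(qk / logit_cap)
    )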
python/sglang/srt/layers/token_attention.py

@@ -16,6 +16,12 @@ else:
     REDUCE_TORCH_TYPE = torch.float16


+@triton.jit
+def tanh(x):
+    # Tanh is just a scaled sigmoid
+    return 2 * tl.sigmoid(2 * x) - 1
+
+
 @triton.jit
 def _fwd_kernel_stage1(
     Q,

@@ -35,6 +41,7 @@ def _fwd_kernel_stage1(
     kv_group_num: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_N: tl.constexpr,
+    logit_cap: tl.constexpr,
 ):
     cur_batch = tl.program_id(0)
     cur_head = tl.program_id(1)

@@ -77,6 +84,10 @@ def _fwd_kernel_stage1(
         ).to(REDUCE_TRITON_TYPE)
         att_value = tl.sum(q[None, :] * k, 1)
         att_value *= sm_scale
+
+        if logit_cap > 0:
+            att_value = logit_cap * tanh(att_value / logit_cap)
+
         off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n)
         tl.store(Att_Out + off_o, att_value, mask=offs_n_new < cur_batch_end_index)

@@ -165,6 +176,7 @@ def _token_att_m_fwd(
     B_Start_Loc,
     B_Seqlen,
     max_len_in_batch,
+    logit_cap,
 ):
     BLOCK = 32
     # shape constraints

@@ -223,6 +235,7 @@ def _token_att_m_fwd(
         kv_group_num=kv_group_num,
         BLOCK_DMODEL=Lk,
         BLOCK_N=BLOCK,
+        logit_cap=logit_cap,
         num_warps=num_warps,
         num_stages=1,
     )

@@ -304,6 +317,7 @@ def token_attention_fwd(
     max_len_in_batch,
     other_kv_index,
     total_num_tokens,
+    logit_cap=-1,
     att_m=None,
 ):
     if att_m is None:

@@ -320,6 +334,7 @@ def token_attention_fwd(
         b_start_loc,
         b_seq_len,
         max_len_in_batch,
+        logit_cap,
     )
     _token_softmax_reducev_fwd(
         att_m,
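In both kernels logit_cap is declared tl.constexpr, so Triton specializes the compiled kernel per value and the `if logit_cap > 0` branch is resolved at compile time rather than per element. As a readable reference for what stage 1 of the decode path computes per cached token, here is a plain PyTorch sketch (not the Triton code):

    import torch

    def token_attention_scores(q, k, sm_scale, logit_cap=-1.0):
        # One scaled dot-product score per cached key vector, optionally soft-capped.
        att = (q[None, :] * k).sum(dim=1) * sm_scale
        if logit_cap > 0:
            att = logit_cap * torch.tanh(att / logit_cap)
        return att

    q = torch.randn(128)       # a single query head
    k = torch.randn(512, 128)  # 512 cached keys
    scores = token_attention_scores(q, k, sm_scale=128 ** -0.5, logit_cap=30.0)
    assert scores.abs().max() < 30.0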
python/sglang/srt/managers/detokenizer_manager.py

 import asyncio
+import inspect

 import uvloop
 import zmq

@@ -7,7 +8,7 @@ import zmq.asyncio
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.utils import get_exception_traceback
+from sglang.utils import get_exception_traceback, graceful_registry

 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

@@ -85,6 +86,8 @@ def start_detokenizer_process(
     port_args: PortArgs,
     pipe_writer,
 ):
+    graceful_registry(inspect.currentframe().f_code.co_name)
+
     try:
         manager = DetokenizerManager(server_args, port_args)
     except Exception as e:
python/sglang/srt/managers/router/model_rpc.py

@@ -106,8 +106,7 @@ class ModelRpcServer:
         set_random_seed(server_args.random_seed)

         # Print info
         logger.info(
-            f"[rank={self.tp_rank}] "
+            f"Rank {self.tp_rank}: "
             f"max_total_num_token={self.max_total_num_token}, "
             f"max_prefill_num_token={self.max_prefill_num_token}, "
             f"context_len={self.model_config.context_len}, "

@@ -752,7 +751,7 @@ def _init_service(port):
         protocol_config={
             "allow_public_attrs": True,
             "allow_pickle": True,
-            "sync_request_timeout": 1800,
+            "sync_request_timeout": 3600,
         },
     )
     t.start()

@@ -772,7 +771,7 @@ def start_model_process(port):
         config={
             "allow_public_attrs": True,
             "allow_pickle": True,
-            "sync_request_timeout": 1800,
+            "sync_request_timeout": 3600,
         },
     )
     break
python/sglang/srt/managers/router/model_runner.py

@@ -235,8 +235,8 @@ class ModelRunner:
         }

         # Init torch distributed
-        logger.debug("Init torch begin.")
         torch.cuda.set_device(self.tp_rank)
+        logger.info(f"[rank={self.tp_rank}] Init torch begin. Avail mem={get_available_gpu_memory(self.tp_rank):.2f} GB")
         torch.distributed.init_process_group(
             backend="nccl",
             world_size=self.tp_size,

@@ -244,20 +244,22 @@ class ModelRunner:
             init_method=f"tcp://127.0.0.1:{self.nccl_port}",
         )
         initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
-        logger.debug("Init torch end.")
+        logger.info(f"[rank={self.tp_rank}] Init torch end.")

-        total_gpu_memory = get_available_gpu_memory(
-            self.tp_rank, distributed=self.tp_size > 1
-        ) * (1 << 30)
+        total_gpu_memory = get_available_gpu_memory(self.tp_rank, distributed=self.tp_size > 1)
+        if self.tp_size > 1:
+            total_local_gpu_memory = get_available_gpu_memory(self.tp_rank)
+            if total_local_gpu_memory < total_gpu_memory * 0.9:
+                raise ValueError("The memory capacity is unbalanced. Some GPUs may be occupied by other processes.")

-        # logger.info(f"Before: {get_available_gpu_memory(self.tp_rank, False):.2f} GB")
         self.load_model()
-        # logger.info(f"After: {get_available_gpu_memory(self.tp_rank, False):.2f} GB")
         self.init_memory_pool(total_gpu_memory)

         self.is_multimodal_model = is_multimodal_model(self.model_config)

     def load_model(self):
-        logger.info(f"Rank {self.tp_rank}: load weight begin.")
+        logger.info(f"[rank={self.tp_rank}] Load weight begin.")
         device_config = DeviceConfig()
         load_config = LoadConfig(load_format=self.server_args.load_format)

@@ -283,19 +285,19 @@ class ModelRunner:
             parallel_config=None,
             scheduler_config=None,
         )
-        logger.info(f"Rank {self.tp_rank}: load weight end. {type(self.model)}")
+        logger.info(
+            f"[rank={self.tp_rank}] Load weight end. "
+            f"Type={type(self.model).__name__}. "
+            f"Avail mem={get_available_gpu_memory(self.tp_rank):.2f} GB"
+        )

     def profile_max_num_token(self, total_gpu_memory):
         available_gpu_memory = get_available_gpu_memory(
             self.tp_rank, distributed=self.tp_size > 1
-        ) * (1 << 30)
+        )
         head_dim = self.model_config.head_dim
         head_num = self.model_config.num_key_value_heads // self.tp_size
         cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * 2
         rest_memory = available_gpu_memory - total_gpu_memory * (
             1 - self.mem_fraction_static
         )
-        max_num_token = int(rest_memory // cell_size)
+        max_num_token = int(rest_memory * (1 << 30) // cell_size)
         return max_num_token

     def init_memory_pool(self, total_gpu_memory):
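The memory bookkeeping also standardizes on gigabytes: the `* (1 << 30)` conversions move out of the get_available_gpu_memory call sites and into the final token count. Each KV-cache cell holds head_num * head_dim values per layer, twice (K and V), at 2 bytes each in fp16, which is the `* 2 * 2` in cell_size. A back-of-envelope sketch with illustrative numbers (not taken from the commit):

    def profile_max_num_token(available_gb, total_gb, mem_fraction_static,
                              head_num, head_dim, num_layers):
        # Bytes per cached token: K and V (x2) in fp16 (2 bytes) for every layer.
        cell_size = head_num * head_dim * num_layers * 2 * 2
        # Memory left for the KV cache after the static share (weights etc.), in GB.
        rest_gb = available_gb - total_gb * (1 - mem_fraction_static)
        return int(rest_gb * (1 << 30) // cell_size)

    # E.g., 32 layers and 32 KV heads of dim 128 on an 80 GB GPU with 66 GB free:
    print(profile_max_num_token(66.0, 80.0, 0.8, 32, 128, 32))  # 102400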
python/sglang/srt/server.py

@@ -203,7 +203,6 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
         time.sleep(0.5)
         try:
             requests.get(url + "/get_model_info", timeout=5, headers=headers)
-            success = True  # Set flag to True if request succeeds
             break
         except requests.exceptions.RequestException as e:
             pass

@@ -213,7 +212,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     res = requests.post(
         url + "/generate",
         json={
-            "text": "Say this is a warmup request.",
+            "text": "The capital city of France is",
             "sampling_params": {
                 "temperature": 0,
                 "max_new_tokens": 16,
python/sglang/srt/utils.py

@@ -92,7 +92,7 @@ def calculate_time(show=False, min_cost_ms=0.0):
     return wrapper


-def get_available_gpu_memory(gpu_id, distributed=True):
+def get_available_gpu_memory(gpu_id, distributed=False):
    """
    Get available memory for cuda:gpu_id device.
    When distributed is True, the available memory is the minimum available memory of all GPUs.
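Flipping the default to distributed=False matters because the new log lines call get_available_gpu_memory(self.tp_rank) with no second argument: the local query is cheap, while distributed=True implies a collective across ranks. A sketch of the two behaviors (the helper's body is not shown in this diff, so its internals here are an assumption):

    import torch
    import torch.distributed as dist

    def available_gpu_memory_gb(gpu_id: int, distributed: bool = False) -> float:
        # Free memory on cuda:gpu_id in GB (assumed unit, matching the new log lines).
        free_bytes, _total = torch.cuda.mem_get_info(gpu_id)
        if distributed:
            # Assumed behavior: every rank agrees on the smallest free value.
            t = torch.tensor(float(free_bytes), device=f"cuda:{gpu_id}")
            dist.all_reduce(t, op=dist.ReduceOp.MIN)
            free_bytes = t.item()
        return free_bytes / (1 << 30)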
python/sglang/utils.py

@@ -2,7 +2,8 @@
 import base64
 import json
-import os
+import logging
+import signal
 import sys
 import threading
 import traceback

@@ -15,6 +16,9 @@ import numpy as np
 import requests

+logger = logging.getLogger(__name__)
+
+
 def get_exception_traceback():
     etype, value, tb = sys.exc_info()
     err_str = "".join(traceback.format_exception(etype, value, tb))

@@ -247,3 +251,12 @@ def run_with_timeout(func, args=(), kwargs=None, timeout=None):
         raise RuntimeError()
     return ret_value[0]
+
+
+def graceful_registry(sub_module_name):
+    def graceful_shutdown(signum, frame):
+        logger.info(f"{sub_module_name} Received signal to shutdown. Performing graceful shutdown...")
+        if signum == signal.SIGTERM:
+            logger.info(f"{sub_module_name} recive sigterm")
+
+    signal.signal(signal.SIGTERM, graceful_shutdown)
\ No newline at end of file
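graceful_registry installs a SIGTERM handler that logs on shutdown; detokenizer_manager registers it at process start with the current function's name. A minimal standalone check (a throwaway demo, not part of the commit):

    import logging
    import os
    import signal
    import time

    from sglang.utils import graceful_registry

    logging.basicConfig(level=logging.INFO)
    graceful_registry("demo_process")
    os.kill(os.getpid(), signal.SIGTERM)  # the handler logs instead of killing us
    time.sleep(0.1)
    print("still alive: the handler replaced the default SIGTERM action")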