sglang · Commit 2cea6146 (unverified)

Improve logging & add logit cap (#471)

Authored May 24, 2024 by Lianmin Zheng; committed by GitHub on May 24, 2024.
Parent: 44c998fc
Showing 12 changed files with 106 additions and 24 deletions (+106 −24).
benchmark/latency_throughput/test_latency.py (+1 −1)
python/sglang/srt/constrained/fsm_cache.py (+3 −0)
python/sglang/srt/hf_transformers_utils.py (+24 −0)
python/sglang/srt/layers/extend_attention.py (+17 −0)
python/sglang/srt/layers/radix_attention.py (+8 −1)
python/sglang/srt/layers/token_attention.py (+15 −0)
python/sglang/srt/managers/detokenizer_manager.py (+4 −1)
python/sglang/srt/managers/router/model_rpc.py (+3 −4)
python/sglang/srt/managers/router/model_runner.py (+15 −13)
python/sglang/srt/server.py (+1 −2)
python/sglang/srt/utils.py (+1 −1)
python/sglang/utils.py (+14 −1)
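The recurring change in the kernel diffs below is a soft cap on attention logits, controlled by a new `logit_cap` parameter (default `-1`, i.e. disabled). A minimal sketch in plain PyTorch (not part of the diff; the cap value of 30 is illustrative) of what the Triton kernels apply:

```python
# Soft-capping as applied to attention logits in this commit: values are
# squashed into (-logit_cap, logit_cap) while staying almost linear near zero.
import torch

def soft_cap(qk: torch.Tensor, logit_cap: float) -> torch.Tensor:
    if logit_cap > 0:
        return logit_cap * torch.tanh(qk / logit_cap)
    return qk  # logit_cap <= 0 (the default -1) disables capping

qk = torch.tensor([-100.0, -1.0, 0.0, 1.0, 100.0])
print(soft_cap(qk, 30.0))  # ≈ [-29.92, -1.00, 0.00, 1.00, 29.92]
```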
benchmark/latency_throughput/test_latency.py

```diff
@@ -30,7 +30,7 @@ if __name__ == "__main__":
     response = requests.post(
         url + "/generate",
         json={
-            "text": f"{a},",
+            "text": f"The capital of France is",
             # "input_ids": [[2] * 256] * 196,
             "sampling_params": {
                 "temperature": 0,
```
python/sglang/srt/constrained/fsm_cache.py

```diff
@@ -6,6 +6,9 @@ class FSMCache(BaseCache):
     def __init__(self, tokenizer_path, tokenizer_args_dict, enable=True):
         super().__init__(enable=enable)
+        if tokenizer_path.endswith(".json"):
+            return
+
         from importlib.metadata import version

         if version("outlines") >= "0.0.35":
```
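An aside, not part of the commit: `version("outlines") >= "0.0.35"` compares version strings lexicographically, which happens to work for nearby releases but misorders components of different widths (e.g. `"0.0.9"` sorts above `"0.0.35"`). A sketch of the stricter comparison with `packaging`:

```python
# Lexicographic vs. numeric version comparison (illustrative only).
from packaging.version import Version

assert "0.0.9" >= "0.0.35"                   # string comparison: True, misleading
assert Version("0.0.9") < Version("0.0.35")  # numeric comparison: correct order
```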
python/sglang/srt/hf_transformers_utils.py

```diff
@@ -84,6 +84,9 @@ def get_tokenizer(
     tokenizer_revision: Optional[str] = None,
     **kwargs,
 ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
+    if tokenizer_name.endswith(".json"):
+        return TiktokenTokenizer(tokenizer_name)
+
     """Gets a tokenizer for the given model name via Huggingface."""
     if is_multimodal_model(tokenizer_name):
         processor = get_processor(
@@ -170,3 +173,24 @@ def get_processor(
         **kwargs,
     )
     return processor
+
+
+class TiktokenTokenizer:
+    def __init__(self, tokenizer_path):
+        import xlm.tokenizers.tiktoken_wrapper as tiktoken_wrapper
+        tokenizer = tiktoken_wrapper.Encoding.from_xtok_json("xtok-json", tokenizer_path)
+        self.tokenizer = tokenizer
+        self.eos_token_id = tokenizer.eos_token
+        self.vocab_size = tokenizer.n_vocab
+
+    def encode(self, x):
+        return self.tokenizer.encode(x)
+
+    def decode(self, x):
+        return self.tokenizer.decode(x)
+
+    def batch_decode(self, batch, skip_special_tokens, spaces_between_special_tokens):
+        return self.tokenizer.decode_batch(batch)
+
+    def convert_ids_to_tokens(self, index):
+        return self.tokenizer.decode_single_token_bytes(index).decode("utf-8", errors="ignore")
\ No newline at end of file
```
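`TiktokenTokenizer` mirrors just the tokenizer surface the runtime touches: `encode`, `decode`, `batch_decode`, `convert_ids_to_tokens`, plus `eos_token_id` and `vocab_size`. The commit backs it with an internal `xlm.tokenizers.tiktoken_wrapper` module; a hedged sketch of the same duck-typed interface on top of the public `tiktoken` package (an assumption for illustration, not what the commit uses):

```python
import tiktoken

class MiniTiktokenTokenizer:
    """Illustrative stand-in exposing the same duck-typed surface."""

    def __init__(self, encoding_name: str = "cl100k_base"):
        self.tokenizer = tiktoken.get_encoding(encoding_name)
        self.eos_token_id = self.tokenizer.eot_token  # tiktoken's end-of-text id
        self.vocab_size = self.tokenizer.n_vocab

    def encode(self, x):
        return self.tokenizer.encode(x)

    def decode(self, x):
        return self.tokenizer.decode(x)

    def batch_decode(self, batch, skip_special_tokens=True, spaces_between_special_tokens=True):
        # The two flags exist only for HF-API compatibility and are ignored,
        # exactly as in the commit's wrapper.
        return self.tokenizer.decode_batch(batch)
```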
python/sglang/srt/layers/extend_attention.py

```diff
@@ -8,6 +8,12 @@ from sglang.srt.utils import wrap_kernel_launcher
 CUDA_CAPABILITY = torch.cuda.get_device_capability()


+@triton.jit
+def tanh(x):
+    # Tanh is just a scaled sigmoid
+    return 2 * tl.sigmoid(2 * x) - 1
+
+
 @triton.jit
 def _fwd_kernel(
     Q_Extend,
@@ -39,6 +45,7 @@ def _fwd_kernel(
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
+    logit_cap: tl.constexpr,
 ):
     cur_seq = tl.program_id(0)
     cur_head = tl.program_id(1)
@@ -90,6 +97,10 @@ def _fwd_kernel(
         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
         qk += tl.dot(q, k)
         qk *= sm_scale
+
+        if logit_cap > 0:
+            qk = logit_cap * tanh(qk / logit_cap)
+
         qk = tl.where(mask_m[:, None] & mask_n[None, :], qk, float("-inf"))
         n_e_max = tl.maximum(tl.max(qk, 1), e_max)
@@ -126,6 +137,10 @@ def _fwd_kernel(
         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
         qk += tl.dot(q, k)
         qk *= sm_scale
+
+        if logit_cap > 0:
+            qk = logit_cap * tanh(qk / logit_cap)
+
         mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (start_n + offs_n[None, :])
@@ -176,6 +191,7 @@ def extend_attention_fwd(
     b_seq_len_extend,
     max_len_in_batch,
     max_len_extend,
+    logit_cap=-1,
 ):
     """
     q_extend, k_extend, v_extend, o_extend: contiguous tensors
@@ -271,6 +287,7 @@ def extend_attention_fwd(
         BLOCK_N=BLOCK_N,
         num_warps=num_warps,
         num_stages=num_stages,
+        logit_cap=logit_cap,
     )
     cached_kernel = wrap_kernel_launcher(_fwd_kernel)
```
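The `tanh` helper (duplicated in `token_attention.py` below) exists because Triton's `tl` namespace exposes `sigmoid` but, at the time, no `tanh`; the kernels rely on the identity tanh(x) = 2·sigmoid(2x) − 1. A quick numeric check of that identity:

```python
# Verify tanh(x) == 2 * sigmoid(2x) - 1, the identity behind the Triton helper.
import torch

x = torch.linspace(-5.0, 5.0, steps=11)
assert torch.allclose(torch.tanh(x), 2 * torch.sigmoid(2 * x) - 1, atol=1e-6)
```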
python/sglang/srt/layers/radix_attention.py

```diff
 import torch
+import numpy as np
 from torch import nn
 from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
@@ -8,13 +9,16 @@ from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata
 class RadixAttention(nn.Module):
-    def __init__(self, num_heads, head_dim, scaling, num_kv_heads, layer_id):
+    def __init__(self, num_heads, head_dim, scaling, num_kv_heads, layer_id, logit_cap=-1):
         super().__init__()
         self.tp_q_head_num = num_heads
         self.tp_k_head_num = num_kv_heads
         self.tp_v_head_num = num_kv_heads
         self.head_dim = head_dim
         self.layer_id = layer_id
+        self.logit_cap = logit_cap
+
+        assert np.allclose(scaling, 1.0 / (head_dim**0.5))

         from sglang.srt.managers.router.model_runner import global_server_args_dict
@@ -38,6 +42,7 @@ class RadixAttention(nn.Module):
             input_metadata.start_loc,
             input_metadata.seq_lens,
             input_metadata.max_seq_len,
+            self.logit_cap,
         )
         self.store_kv_cache(k, v, input_metadata)
@@ -62,6 +67,7 @@ class RadixAttention(nn.Module):
             input_metadata.extend_seq_lens,
             input_metadata.max_seq_len,
             input_metadata.max_extend_len,
+            self.logit_cap,
         )
         return o
@@ -82,6 +88,7 @@ class RadixAttention(nn.Module):
             input_metadata.max_seq_len,
             input_metadata.other_kv_index,
             input_metadata.total_num_tokens,
+            self.logit_cap,
         )
         return o
```
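A hypothetical call site (an assumption for illustration, not code from this commit): a model definition opts into capping by forwarding `logit_cap` when constructing its attention layers, while the default of `-1` leaves behavior unchanged:

```python
# Inside a hypothetical model's layer constructor; all values illustrative.
self.attn = RadixAttention(
    num_heads=num_heads,
    head_dim=head_dim,
    scaling=head_dim**-0.5,  # must satisfy the np.allclose assertion above
    num_kv_heads=num_kv_heads,
    layer_id=layer_id,
    logit_cap=30.0,          # any positive value enables the soft cap
)
```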
python/sglang/srt/layers/token_attention.py

```diff
@@ -16,6 +16,12 @@ else:
     REDUCE_TORCH_TYPE = torch.float16


+@triton.jit
+def tanh(x):
+    # Tanh is just a scaled sigmoid
+    return 2 * tl.sigmoid(2 * x) - 1
+
+
 @triton.jit
 def _fwd_kernel_stage1(
     Q,
@@ -35,6 +41,7 @@ def _fwd_kernel_stage1(
     kv_group_num: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_N: tl.constexpr,
+    logit_cap: tl.constexpr,
 ):
     cur_batch = tl.program_id(0)
     cur_head = tl.program_id(1)
@@ -77,6 +84,10 @@ def _fwd_kernel_stage1(
         ).to(REDUCE_TRITON_TYPE)
         att_value = tl.sum(q[None, :] * k, 1)
         att_value *= sm_scale
+
+        if logit_cap > 0:
+            att_value = logit_cap * tanh(att_value / logit_cap)
+
         off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n)
         tl.store(Att_Out + off_o, att_value, mask=offs_n_new < cur_batch_end_index)
@@ -165,6 +176,7 @@ def _token_att_m_fwd(
     B_Start_Loc,
     B_Seqlen,
     max_len_in_batch,
+    logit_cap,
 ):
     BLOCK = 32
     # shape constraints
@@ -223,6 +235,7 @@ def _token_att_m_fwd(
         kv_group_num=kv_group_num,
         BLOCK_DMODEL=Lk,
         BLOCK_N=BLOCK,
+        logit_cap=logit_cap,
         num_warps=num_warps,
         num_stages=1,
     )
@@ -304,6 +317,7 @@ def token_attention_fwd(
     max_len_in_batch,
     other_kv_index,
     total_num_tokens,
+    logit_cap=-1,
     att_m=None,
 ):
     if att_m is None:
@@ -320,6 +334,7 @@ def token_attention_fwd(
         b_start_loc,
         b_seq_len,
         max_len_in_batch,
+        logit_cap,
     )
     _token_softmax_reducev_fwd(
         att_m,
```
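Because `tanh` is monotonic, capping the stage-1 scores never reorders keys; it only keeps extreme logits from saturating the softmax. A small demonstration (illustrative numbers, not from the commit):

```python
import torch

scores = torch.tensor([1.0, 2.0, 8.0])
capped = 3.0 * torch.tanh(scores / 3.0)        # cap of 3 chosen for visibility

assert torch.argmax(capped) == torch.argmax(scores)  # ranking preserved
print(torch.softmax(scores, dim=0))  # ≈ [0.001, 0.002, 0.997]: near one-hot
print(torch.softmax(capped, dim=0))  # ≈ [0.094, 0.206, 0.700]: softer
```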
python/sglang/srt/managers/detokenizer_manager.py

```diff
 import asyncio
+import inspect
 import uvloop
 import zmq
@@ -7,7 +8,7 @@ import zmq.asyncio
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.utils import get_exception_traceback
+from sglang.utils import get_exception_traceback, graceful_registry

 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@@ -85,6 +86,8 @@ def start_detokenizer_process(
     port_args: PortArgs,
     pipe_writer,
 ):
+    graceful_registry(inspect.currentframe().f_code.co_name)
+
     try:
         manager = DetokenizerManager(server_args, port_args)
     except Exception as e:
```
python/sglang/srt/managers/router/model_rpc.py

```diff
@@ -106,8 +106,7 @@ class ModelRpcServer:
         set_random_seed(server_args.random_seed)

-        # Print info
-        logger.info(f"Rank {self.tp_rank}: "
+        logger.info(f"[rank={self.tp_rank}] "
             f"max_total_num_token={self.max_total_num_token}, "
             f"max_prefill_num_token={self.max_prefill_num_token}, "
             f"context_len={self.model_config.context_len}, "
@@ -752,7 +751,7 @@ def _init_service(port):
         protocol_config={
             "allow_public_attrs": True,
             "allow_pickle": True,
-            "sync_request_timeout": 1800,
+            "sync_request_timeout": 3600,
         },
     )
     t.start()
@@ -772,7 +771,7 @@ def start_model_process(port):
             config={
                 "allow_public_attrs": True,
                 "allow_pickle": True,
-                "sync_request_timeout": 1800,
+                "sync_request_timeout": 3600,
             },
        )
        break
```
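The `sync_request_timeout` bump from 1800 to 3600 seconds governs how long rpyc blocks on a synchronous call before raising. A client talking to this service would typically mirror the setting; a hedged sketch (host and port are hypothetical; the config keys are standard rpyc protocol options):

```python
import rpyc

conn = rpyc.connect(
    "localhost",
    10000,  # hypothetical port
    config={
        "allow_public_attrs": True,
        "allow_pickle": True,
        "sync_request_timeout": 3600,  # match the server, or slow calls
                                       # (e.g. weight loading) time out client-side
    },
)
```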
python/sglang/srt/managers/router/model_runner.py

```diff
@@ -235,8 +235,8 @@ class ModelRunner:
         }

         # Init torch distributed
-        logger.debug("Init torch begin.")
         torch.cuda.set_device(self.tp_rank)
+        logger.info(f"[rank={self.tp_rank}] Init torch begin. Avail mem={get_available_gpu_memory(self.tp_rank):.2f} GB")
         torch.distributed.init_process_group(
             backend="nccl",
             world_size=self.tp_size,
@@ -244,20 +244,22 @@ class ModelRunner:
             init_method=f"tcp://127.0.0.1:{self.nccl_port}",
         )
         initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
-        logger.debug("Init torch end.")
+        logger.info(f"[rank={self.tp_rank}] Init torch end.")

+        total_gpu_memory = get_available_gpu_memory(
+            self.tp_rank, distributed=self.tp_size > 1
+        )
+
         if self.tp_size > 1:
             total_local_gpu_memory = get_available_gpu_memory(self.tp_rank)
             if total_local_gpu_memory < total_gpu_memory * 0.9:
                 raise ValueError(
                     "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
                 )
-        total_gpu_memory = get_available_gpu_memory(
-            self.tp_rank, distributed=self.tp_size > 1
-        ) * (1 << 30)

         # logger.info(f"Before: {get_available_gpu_memory(self.tp_rank, False):.2f} GB")
         self.load_model()
         # logger.info(f"After: {get_available_gpu_memory(self.tp_rank, False):.2f} GB")
         self.init_memory_pool(total_gpu_memory)

         self.is_multimodal_model = is_multimodal_model(self.model_config)

     def load_model(self):
-        logger.info(f"Rank {self.tp_rank}: load weight begin.")
+        logger.info(f"[rank={self.tp_rank}] Load weight begin.")
         device_config = DeviceConfig()
         load_config = LoadConfig(load_format=self.server_args.load_format)
@@ -283,19 +285,19 @@ class ModelRunner:
             parallel_config=None,
             scheduler_config=None,
         )
-        logger.info(f"Rank {self.tp_rank}: load weight end. {type(self.model)}")
+        logger.info(
+            f"[rank={self.tp_rank}] Load weight end. "
+            f"Type={type(self.model).__name__}. "
+            f"Avail mem={get_available_gpu_memory(self.tp_rank):.2f} GB"
+        )

     def profile_max_num_token(self, total_gpu_memory):
-        available_gpu_memory = get_available_gpu_memory(
-            self.tp_rank, distributed=self.tp_size > 1
-        ) * (1 << 30)
+        available_gpu_memory = get_available_gpu_memory(
+            self.tp_rank, distributed=self.tp_size > 1
+        )
         head_dim = self.model_config.head_dim
         head_num = self.model_config.num_key_value_heads // self.tp_size
         cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * 2
         rest_memory = available_gpu_memory - total_gpu_memory * (1 - self.mem_fraction_static)
-        max_num_token = int(rest_memory // cell_size)
+        max_num_token = int(rest_memory * (1 << 30) // cell_size)
         return max_num_token

     def init_memory_pool(self, total_gpu_memory):
```
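For reference, the arithmetic in `profile_max_num_token` after this change: `get_available_gpu_memory` now reports GB, so the byte conversion happens once at the end. A worked example with hypothetical 7B-class numbers (not from the commit):

```python
# Per-token KV-cache cost: kv_heads * head_dim * layers * 2 (K and V) * 2 (fp16 bytes).
kv_heads, head_dim, layers = 32, 128, 32          # hypothetical model, tp_size=1
cell_size = kv_heads * head_dim * layers * 2 * 2  # 524288 bytes = 512 KiB per token

rest_memory_gb = 60.0                             # hypothetical GB left for the KV cache
max_num_token = int(rest_memory_gb * (1 << 30) // cell_size)
print(max_num_token)                              # 122880 tokens
```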
python/sglang/srt/server.py

```diff
@@ -203,7 +203,6 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
         time.sleep(0.5)
         try:
             requests.get(url + "/get_model_info", timeout=5, headers=headers)
-            success = True  # Set flag to True if request succeeds
             break
         except requests.exceptions.RequestException as e:
             pass
@@ -213,7 +212,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     res = requests.post(
         url + "/generate",
         json={
-            "text": "Say this is a warmup request.",
+            "text": "The capital city of France is",
             "sampling_params": {
                 "temperature": 0,
                 "max_new_tokens": 16,
```
python/sglang/srt/utils.py

```diff
@@ -92,7 +92,7 @@ def calculate_time(show=False, min_cost_ms=0.0):
     return wrapper

-def get_available_gpu_memory(gpu_id, distributed=True):
+def get_available_gpu_memory(gpu_id, distributed=False):
     """
     Get available memory for cuda:gpu_id device.
     When distributed is True, the available memory is the minimum available memory of all GPUs.
```
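The default flip from `distributed=True` to `distributed=False` means single-process callers no longer hit the collective path. Based on the docstring, a minimal sketch of what the helper plausibly does internally (an assumption, not the actual implementation):

```python
import torch

def get_available_gpu_memory_sketch(gpu_id: int, distributed: bool = False) -> float:
    """Free memory of cuda:gpu_id in GB; min across ranks when distributed."""
    free_bytes, _total_bytes = torch.cuda.mem_get_info(gpu_id)
    free_gb = free_bytes / (1 << 30)
    if distributed:
        t = torch.tensor(free_gb, dtype=torch.float32, device=f"cuda:{gpu_id}")
        torch.distributed.all_reduce(t, op=torch.distributed.ReduceOp.MIN)
        free_gb = t.item()
    return free_gb
```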
python/sglang/utils.py

```diff
@@ -2,7 +2,8 @@
 import base64
 import json
-import os
+import logging
+import signal
 import sys
 import threading
 import traceback
@@ -15,6 +16,9 @@ import numpy as np
 import requests

+logger = logging.getLogger(__name__)
+
+
 def get_exception_traceback():
     etype, value, tb = sys.exc_info()
     err_str = "".join(traceback.format_exception(etype, value, tb))
@@ -247,3 +251,12 @@ def run_with_timeout(func, args=(), kwargs=None, timeout=None):
         raise RuntimeError()

     return ret_value[0]
+
+
+def graceful_registry(sub_module_name):
+    def graceful_shutdown(signum, frame):
+        logger.info(f"{sub_module_name} Received signal to shutdown. Performing graceful shutdown...")
+        if signum == signal.SIGTERM:
+            logger.info(f"{sub_module_name} recive sigterm")
+
+    signal.signal(signal.SIGTERM, graceful_shutdown)
\ No newline at end of file
```
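Usage mirrors `start_detokenizer_process` in the detokenizer diff above: each subprocess entry point registers under its own function name, so a SIGTERM is logged per module instead of killing the process silently. A hedged sketch:

```python
import inspect

def start_worker_process(server_args, port_args):  # hypothetical entry point
    graceful_registry(inspect.currentframe().f_code.co_name)
    # ... run the worker's loop; a SIGTERM now lands in graceful_shutdown ...
```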