sglang · Commits · e074d84e

[Minor] more code cleanup (#4077)

Unverified commit e074d84e, authored Mar 04, 2025 by Lianmin Zheng, committed by GitHub on Mar 04, 2025.
Parent: 4725e3f6

Showing 15 changed files with 123 additions and 31 deletions (+123, -31).
Files changed:

  python/pyproject.toml                                    +1   -0
  python/sglang/bench_serving.py                           +6   -1
  python/sglang/srt/function_call_parser.py                +0   -1
  python/sglang/srt/layers/attention/utils.py              +39  -0
  python/sglang/srt/layers/logits_processor.py             +8   -2
  python/sglang/srt/managers/detokenizer_manager.py        +1   -0
  python/sglang/srt/managers/io_struct.py                  +17  -0
  python/sglang/srt/managers/tokenizer_manager.py          +1   -1
  python/sglang/srt/mem_cache/memory_pool.py               +1   -1
  python/sglang/srt/model_executor/cuda_graph_runner.py    +8   -5
  python/sglang/srt/model_executor/forward_batch_info.py   +13  -4
  python/sglang/srt/model_executor/model_runner.py         +20  -7
  python/sglang/srt/models/deepseek_v2.py                  +2   -2
  python/sglang/srt/utils.py                               +5   -7
  sgl-kernel/setup.py                                      +1   -0
python/pyproject.toml

@@ -40,6 +40,7 @@ runtime_common = [
     "transformers==4.48.3",
+    "llguidance>=0.6.15"
 ]
 srt = [
     "sglang[runtime_common]",
     "sgl-kernel==0.0.3.post6",
python/sglang/bench_serving.py

@@ -39,6 +39,7 @@ from transformers import (
 )

 AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
+ASSISTANT_SUFFIX = "Assistant:"

 global args

@@ -635,7 +636,11 @@ def sample_sharegpt_requests(
         # Tokenize the prompts and completions.
         prompt = dataset[i][0]
         if prompt_suffix:
-            prompt = prompt
+            prompt = (
+                remove_suffix(prompt, ASSISTANT_SUFFIX)
+                + prompt_suffix
+                + ASSISTANT_SUFFIX
+            )
         if apply_chat_template:
             prompt = tokenizer.apply_chat_template(
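For illustration only, a small sketch of what the new prompt construction does to a ShareGPT-style prompt that ends with the assistant marker. The remove_suffix helper is assumed here to behave like Python's str.removesuffix (strip the suffix only when the text ends with it); the example prompt and suffix are made up.

import re  # not required; shown only to keep the sketch self-contained

ASSISTANT_SUFFIX = "Assistant:"


def remove_suffix(text: str, suffix: str) -> str:
    # Assumed behavior: equivalent to str.removesuffix.
    return text[: -len(suffix)] if text.endswith(suffix) else text


prompt = "Human: What is 2 + 2?\n\nAssistant:"
prompt_suffix = " Please answer in one word.\n\n"

# The extra suffix is spliced in *before* the trailing "Assistant:" marker,
# so the benchmark prompt still ends with the marker the model expects.
prompt = remove_suffix(prompt, ASSISTANT_SUFFIX) + prompt_suffix + ASSISTANT_SUFFIX
print(prompt)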
python/sglang/srt/function_call_parser.py

 import json
 import logging
 import re
 from abc import ABC, abstractmethod
 from json import JSONDecodeError, JSONDecoder
 from typing import Any, Dict, List, Optional, Tuple
python/sglang/srt/layers/attention/utils.py  (new file, mode 100644)

+import triton
+import triton.language as tl
+
+
+@triton.jit
+def create_flashinfer_kv_indices_triton(
+    req_to_token_ptr,  # [max_batch, max_context_len]
+    req_pool_indices_ptr,
+    page_kernel_lens_ptr,
+    kv_indptr,
+    kv_start_idx,
+    kv_indices_ptr,
+    req_to_token_ptr_stride: tl.constexpr,
+):
+    BLOCK_SIZE: tl.constexpr = 512
+    pid = tl.program_id(axis=0)
+
+    req_pool_index = tl.load(req_pool_indices_ptr + pid)
+    kv_indices_offset = tl.load(kv_indptr + pid)
+
+    kv_start = 0
+    kv_end = 0
+    if kv_start_idx:
+        kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
+        kv_end = kv_start
+    kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
+
+    num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
+    for i in range(num_loop):
+        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
+        mask = offset < kv_end - kv_start
+        data = tl.load(
+            req_to_token_ptr
+            + req_pool_index * req_to_token_ptr_stride
+            + kv_start
+            + offset,
+            mask=mask,
+        )
+        tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)
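This kernel copies, for each request, the KV-cache slot indices from the 2-D req_to_token table into a flat kv_indices buffer at the offsets given by kv_indptr, launching one Triton program per request. As a hedged illustration of a host-side launch (the tensor names, sizes, and values below are invented for the example and are not part of the commit):

import torch

# Hypothetical inputs: 2 requests, req_to_token table with 8 slots per request.
req_to_token = torch.arange(16, dtype=torch.int32, device="cuda").reshape(2, 8)
req_pool_indices = torch.tensor([0, 1], dtype=torch.int32, device="cuda")
seq_lens = torch.tensor([3, 5], dtype=torch.int32, device="cuda")

# kv_indptr[i] = output offset of request i (exclusive prefix sum of lengths).
kv_indptr = torch.zeros(len(seq_lens) + 1, dtype=torch.int32, device="cuda")
kv_indptr[1:] = torch.cumsum(seq_lens, dim=0)
kv_indices = torch.empty(int(seq_lens.sum()), dtype=torch.int32, device="cuda")

# One program per request; kv_start_idx=None means each request starts at slot 0.
create_flashinfer_kv_indices_triton[(len(seq_lens),)](
    req_to_token,
    req_pool_indices,
    seq_lens,
    kv_indptr,
    None,
    kv_indices,
    req_to_token.shape[1],
)
# For this toy layout, kv_indices ends up as [0, 1, 2, 8, 9, 10, 11, 12].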
python/sglang/srt/layers/logits_processor.py

@@ -33,6 +33,7 @@ from sglang.srt.layers.dp_attention import (
     get_attention_dp_size,
 )
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import (
     CaptureHiddenMode,
     ForwardBatch,

@@ -152,6 +153,13 @@ class LogitsMetadata:
             token_ids_logprobs=forward_batch.token_ids_logprobs,
             extend_input_logprob_token_ids_gpu=forward_batch.extend_input_logprob_token_ids_gpu,
             padded_static_len=forward_batch.padded_static_len,
             global_num_tokens_gpu=forward_batch.global_num_tokens_gpu,
             dp_local_start_pos=forward_batch.dp_local_start_pos,
             dp_local_num_tokens=forward_batch.dp_local_num_tokens,
             gathered_buffer=forward_batch.gathered_buffer,
             forward_batch_gathered_buffer=forward_batch.gathered_buffer,
             global_num_tokens_for_logprob_cpu=forward_batch.global_num_tokens_for_logprob_cpu,
             global_num_tokens_for_logprob_gpu=forward_batch.global_num_tokens_for_logprob_gpu,
         )

     def compute_dp_attention_metadata(self, hidden_states: torch.Tensor):

@@ -204,8 +212,6 @@ class LogitsProcessor(nn.Module):
         ):
             self.final_logit_softcapping = None

-        from sglang.srt.managers.schedule_batch import global_server_args_dict
-
         self.debug_tensor_dump_output_folder = global_server_args_dict.get(
             "debug_tensor_dump_output_folder", None
         )
python/sglang/srt/managers/detokenizer_manager.py

@@ -212,6 +212,7 @@ class DetokenizerManager:
                 rids=recv_obj.rids,
                 finished_reasons=recv_obj.finished_reasons,
                 output_strs=output_strs,
+                output_ids=None,
                 prompt_tokens=recv_obj.prompt_tokens,
                 completion_tokens=recv_obj.completion_tokens,
                 cached_tokens=recv_obj.cached_tokens,
python/sglang/srt/managers/io_struct.py

@@ -414,6 +414,12 @@ class BatchTokenIDOut:
 class BatchMultimodalDecodeReq:
     # The request id
     rids: List[str]
     finished_reasons: List[BaseFinishReason]

     # Token counts
     prompt_tokens: List[int]
     completion_tokens: List[int]
     cached_tokens: List[int]


 @dataclass

@@ -424,6 +430,8 @@ class BatchStrOut:
     finished_reasons: List[dict]
     # The output decoded strings
     output_strs: List[str]
+    # The token ids
+    output_ids: Optional[List[int]]
     # Token counts
     prompt_tokens: List[int]

@@ -453,6 +461,15 @@ class BatchStrOut:
 class BatchMultimodalOut:
     # The request id
     rids: List[str]
     # The finish reason
     finished_reasons: List[dict]
     # The outputs
     outputs: List[List[Dict]]

     # Token counts
     prompt_tokens: List[int]
     completion_tokens: List[int]
     cached_tokens: List[int]


 @dataclass
python/sglang/srt/managers/tokenizer_manager.py

@@ -1141,7 +1141,7 @@ async def print_exception_wrapper(func):
 class SignalHandler:
-    def __init__(self, tokenizer_manager):
+    def __init__(self, tokenizer_manager: TokenizerManager):
         self.tokenizer_manager = tokenizer_manager

     def signal_handler(self, signum=None, frame=None):
python/sglang/srt/mem_cache/memory_pool.py

@@ -192,7 +192,7 @@ class MHATokenToKVPool(BaseTokenToKVPool):
         k_size, v_size = self.get_kv_size_bytes()
         logger.info(
-            f"KV Cache is allocated. K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB."
+            f"KV Cache is allocated. #tokens: {size}, K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB"
         )

     def _create_buffers(self):
python/sglang/srt/model_executor/cuda_graph_runner.py

@@ -238,6 +238,9 @@ class CudaGraphRunner:
                 ),
                 dtype=self.model_runner.dtype,
             )
+            self.global_num_tokens_gpu = torch.zeros(
+                (self.dp_size,), dtype=torch.int32
+            )

         # Capture
         try:

@@ -266,9 +269,9 @@ class CudaGraphRunner:
     def can_run(self, forward_batch: ForwardBatch):
         if self.enable_dp_attention:
-            min_num_tokens, max_num_tokens = min(forward_batch.global_num_tokens), max(forward_batch.global_num_tokens)
+            min_num_tokens, max_num_tokens = min(forward_batch.global_num_tokens_cpu), max(forward_batch.global_num_tokens_cpu)

             is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
                 (min_num_tokens == max_num_tokens and max_num_tokens in self.graphs)
                 if self.disable_padding

@@ -360,7 +363,7 @@ class CudaGraphRunner:
             encoder_lens=encoder_lens,
             return_logprob=False,
             positions=positions,
-            global_num_tokens=global_num_tokens,
+            global_num_tokens_cpu=global_num_tokens,
             gathered_buffer=gathered_buffer,
             mrope_positions=mrope_positions,
             spec_algorithm=self.model_runner.spec_algorithm,

@@ -430,7 +433,7 @@ class CudaGraphRunner:
         # Pad
         if self.enable_dp_attention:
             index = bisect.bisect_left(
-                self.capture_bs, max(forward_batch.global_num_tokens)
+                self.capture_bs, max(forward_batch.global_num_tokens_cpu)
             )
         else:
             index = bisect.bisect_left(self.capture_bs, raw_bs)
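A short note on the padding logic in the last hunk: bisect.bisect_left over the sorted capture_bs list selects the smallest captured graph size that fits the current batch. A hedged sketch with made-up capture sizes and token counts:

import bisect

# Hypothetical captured batch sizes, sorted ascending.
capture_bs = [1, 2, 4, 8, 16, 32]

# With DP attention, the padded size is driven by the largest token count
# across DP ranks; otherwise it is driven by the raw batch size.
global_num_tokens_cpu = [3, 5, 4]  # example per-rank token counts
index = bisect.bisect_left(capture_bs, max(global_num_tokens_cpu))
print(capture_bs[index])  # 8: the smallest captured size >= 5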
python/sglang/srt/model_executor/forward_batch_info.py

@@ -190,7 +190,16 @@ class ForwardBatch:
     attn_backend: AttentionBackend = None

     # For DP attention
-    global_num_tokens: Optional[List[int]] = None
+    global_num_tokens_cpu: Optional[List[int]] = None
+    global_num_tokens_gpu: Optional[torch.Tensor] = None
+    # Has to be None when cuda graph is captured.
+    global_num_tokens_for_logprob_cpu: Optional[List[int]] = None
+    global_num_tokens_for_logprob_gpu: Optional[torch.Tensor] = None
     # for extend, local start pos and num tokens is different in logits processor
     # this will be computed in get_dp_local_info
     # this will be recomputed in LogitsMetadata.from_forward_batch
     dp_local_start_pos: Optional[torch.Tensor] = None  # cached info at runtime
     dp_local_num_tokens: Optional[torch.Tensor] = None  # cached info at runtime
     gathered_buffer: Optional[torch.Tensor] = None
     can_run_dp_cuda_graph: bool = False

@@ -234,7 +243,6 @@ class ForwardBatch:
             return_logprob=batch.return_logprob,
             top_logprobs_nums=batch.top_logprobs_nums,
             token_ids_logprobs=batch.token_ids_logprobs,
-            global_num_tokens=batch.global_num_tokens,
             can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
             lora_paths=batch.lora_paths,
             sampling_info=batch.sampling_info,

@@ -248,8 +256,9 @@ class ForwardBatch:
             extend_input_logprob_token_ids_gpu=extend_input_logprob_token_ids_gpu,
         )

-        if ret.global_num_tokens is not None:
-            max_len = max(ret.global_num_tokens)
+        if batch.global_num_tokens is not None:
+            ret.global_num_tokens_cpu = batch.global_num_tokens
+            max_len = max(ret.global_num_tokens_cpu)
             ret.gathered_buffer = torch.zeros(
                 (max_len * model_runner.tp_size, model_runner.model_config.hidden_size),
                 dtype=model_runner.dtype,
python/sglang/srt/model_executor/model_runner.py

@@ -13,7 +13,6 @@
 # ==============================================================================
 """ModelRunner runs the forward passes of the models."""

-import collections
 import datetime
 import gc
 import json

@@ -269,6 +268,7 @@ class ModelRunner:
         elif self.device == "cpu":
             backend = "gloo"

+        before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
         if not self.server_args.enable_p2p_check:
             monkey_patch_p2p_access_check()

@@ -299,20 +299,24 @@ class ModelRunner:
         min_per_gpu_memory = get_available_gpu_memory(
             self.device, self.gpu_id, distributed=self.tp_size > 1
         )
+        local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
         self.tp_group = get_tp_group()
         self.attention_tp_group = get_attention_tp_group()

         # Check memory for tensor parallelism
         if self.tp_size > 1:
-            local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
             if min_per_gpu_memory < local_gpu_memory * 0.9:
                 raise ValueError(
                     "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
                 )

+        logger.info(
+            f"Init torch distributed ends. mem usage={(before_avail_memory - local_gpu_memory):.2f} GB"
+        )
         return min_per_gpu_memory

     def load_model(self):
+        before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
             f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
         )

@@ -382,11 +386,13 @@ class ModelRunner:
         )
         self.dtype = self.model_config.dtype
+        after_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
             f"Load weight end. "
             f"type={type(self.model).__name__}, "
             f"dtype={self.dtype}, "
-            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+            f"avail mem={after_avail_memory:.2f} GB, "
+            f"mem usage={(before_avail_memory - after_avail_memory):.2f} GB."
         )

     def update_weights_from_disk(

@@ -785,12 +791,15 @@ class ModelRunner:
             return

         tic = time.time()
+        before_mem = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
-            f"Capture cuda graph begin. This can take up to several minutes. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+            f"Capture cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
         )
         self.cuda_graph_runner = CudaGraphRunner(self)
+        after_mem = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
-            f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+            f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s. "
+            f"avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB."
         )

     def apply_torch_tp(self):

@@ -806,8 +815,12 @@ class ModelRunner:
             forward_batch.input_ids, forward_batch.positions, forward_batch
         )

-    def forward_extend(self, forward_batch: ForwardBatch):
-        self.attn_backend.init_forward_metadata(forward_batch)
+    def forward_extend(
+        self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False
+    ):
+        if not skip_attn_backend_init:
+            self.attn_backend.init_forward_metadata(forward_batch)
+
         if self.is_generation:
             if forward_batch.input_embeds is None:
                 return self.model.forward(
python/sglang/srt/models/deepseek_v2.py

@@ -818,8 +818,8 @@ def all_gather(
     if world_size == 1:
         return input_tensor

-    all_lens = forward_batch.global_num_tokens
-    max_len = max(forward_batch.global_num_tokens)
+    all_lens = forward_batch.global_num_tokens_cpu
+    max_len = max(forward_batch.global_num_tokens_cpu)

     padded_tensor = torch.nn.functional.pad(
         input_tensor, (0, 0, 0, max_len - input_tensor.shape[0])
     )
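A brief aside on the pad call above, since torch.nn.functional.pad reads its pad tuple from the last dimension backwards: (0, 0, 0, n) leaves the hidden dimension untouched and appends n zero rows along the token dimension. A small hedged sketch with made-up sizes:

import torch
import torch.nn.functional as F

input_tensor = torch.randn(3, 4)   # [num_local_tokens, hidden_size] (example sizes)
max_len = 5                        # max(global_num_tokens_cpu) across DP ranks (example)

# Pad spec is (left, right) for dim -1, then (top, bottom) for dim -2:
# the hidden dim is unchanged and (max_len - 3) zero rows are appended.
padded = F.pad(input_tensor, (0, 0, 0, max_len - input_tensor.shape[0]))
print(padded.shape)                # torch.Size([5, 4])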
python/sglang/srt/utils.py

@@ -741,13 +741,6 @@ def pytorch_profile(name, func, *args, data_size=-1):
     return result


-def first_rank_print(*args, **kwargs):
-    if torch.cuda.current_device() == 0:
-        print(*args, **kwargs)
-    else:
-        pass
-
-
 def get_zmq_socket(
     context: zmq.Context, socket_type: zmq.SocketType, endpoint: str, bind: bool
 ):

@@ -1177,6 +1170,11 @@ def get_bool_env_var(name: str, default: str = "false") -> bool:
     return value.lower() in ("true", "1")


+@lru_cache(maxsize=2)
+def disable_request_logging() -> bool:
+    return get_bool_env_var("SGLANG_DISABLE_REQUEST_LOGGING")
+
+
 @lru_cache(maxsize=8)
 def _cuda_device_count_stateless(cuda_visible_devices: Optional[str] = None) -> int:
     # Note: cuda_visible_devices is not used, but we keep it as an argument for
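The new disable_request_logging helper caches a boolean environment-variable lookup. A hedged usage sketch follows; the call site is illustrative and not taken from this commit.

import logging

from sglang.srt.utils import disable_request_logging

logger = logging.getLogger(__name__)


def log_request(request_id: str, prompt: str) -> None:
    # Skip per-request logging when SGLANG_DISABLE_REQUEST_LOGGING is set to true/1.
    if not disable_request_logging():
        logger.info("request %s: %s", request_id, prompt)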
sgl-kernel/setup.py

@@ -85,6 +85,7 @@ nvcc_flags = [
     "-DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1",
     "-DCUTLASS_DEBUG_TRACE_LEVEL=0",
     "--ptxas-options=-v",
     "--expt-relaxed-constexpr",
     "-Xcompiler=-Wconversion",
     "-Xcompiler=-fno-strict-aliasing",
 ]