Commit e074d84e (unverified)
Authored Mar 04, 2025 by Lianmin Zheng; committed via GitHub on Mar 04, 2025

[Minor] more code cleanup (#4077)

Parent: 4725e3f6

Showing 15 changed files with 123 additions and 31 deletions (+123, -31)
Files changed:
  python/pyproject.toml                                    +1   -0
  python/sglang/bench_serving.py                           +6   -1
  python/sglang/srt/function_call_parser.py                +0   -1
  python/sglang/srt/layers/attention/utils.py              +39  -0
  python/sglang/srt/layers/logits_processor.py             +8   -2
  python/sglang/srt/managers/detokenizer_manager.py        +1   -0
  python/sglang/srt/managers/io_struct.py                  +17  -0
  python/sglang/srt/managers/tokenizer_manager.py          +1   -1
  python/sglang/srt/mem_cache/memory_pool.py               +1   -1
  python/sglang/srt/model_executor/cuda_graph_runner.py    +8   -5
  python/sglang/srt/model_executor/forward_batch_info.py   +13  -4
  python/sglang/srt/model_executor/model_runner.py         +20  -7
  python/sglang/srt/models/deepseek_v2.py                  +2   -2
  python/sglang/srt/utils.py                               +5   -7
  sgl-kernel/setup.py                                      +1   -0
python/pyproject.toml

@@ -40,6 +40,7 @@ runtime_common = [
     "transformers==4.48.3",
+    "llguidance>=0.6.15"
 ]
 srt = [
     "sglang[runtime_common]",
     "sgl-kernel==0.0.3.post6",
python/sglang/bench_serving.py

@@ -39,6 +39,7 @@ from transformers import (
 )

 AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
+ASSISTANT_SUFFIX = "Assistant:"

 global args

@@ -635,7 +636,11 @@ def sample_sharegpt_requests(
         # Tokenize the prompts and completions.
         prompt = dataset[i][0]
         if prompt_suffix:
-            prompt = prompt
+            prompt = (
+                remove_suffix(prompt, ASSISTANT_SUFFIX)
+                + prompt_suffix
+                + ASSISTANT_SUFFIX
+            )
         if apply_chat_template:
             prompt = tokenizer.apply_chat_template(
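For context, a minimal sketch of a remove_suffix helper with the semantics the hunk above appears to rely on (strip a trailing suffix only if it is present). The helper body and the sample strings are assumptions for illustration, not taken from the commit:

def remove_suffix(text: str, suffix: str) -> str:
    # Equivalent to str.removesuffix on Python 3.9+: drop the suffix only when present.
    return text[: -len(suffix)] if suffix and text.endswith(suffix) else text

# Hypothetical usage mirroring the diff: splice prompt_suffix in before the final
# "Assistant:" marker instead of appending it after the marker.
ASSISTANT_SUFFIX = "Assistant:"
prompt = "Human: hello\nAssistant:"
prompt_suffix = " Think step by step."
prompt = remove_suffix(prompt, ASSISTANT_SUFFIX) + prompt_suffix + ASSISTANT_SUFFIX
print(prompt)  # "Human: hello\n Think step by step.Assistant:"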
python/sglang/srt/function_call_parser.py

 import json
 import logging
 import re
-from abc import ABC, abstractmethod
 from json import JSONDecodeError, JSONDecoder
 from typing import Any, Dict, List, Optional, Tuple
python/sglang/srt/layers/attention/utils.py (new file, 0 → 100644)

import triton
import triton.language as tl


@triton.jit
def create_flashinfer_kv_indices_triton(
    req_to_token_ptr,  # [max_batch, max_context_len]
    req_pool_indices_ptr,
    page_kernel_lens_ptr,
    kv_indptr,
    kv_start_idx,
    kv_indices_ptr,
    req_to_token_ptr_stride: tl.constexpr,
):
    BLOCK_SIZE: tl.constexpr = 512
    pid = tl.program_id(axis=0)

    req_pool_index = tl.load(req_pool_indices_ptr + pid)
    kv_indices_offset = tl.load(kv_indptr + pid)

    kv_start = 0
    kv_end = 0
    if kv_start_idx:
        kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
        kv_end = kv_start
    kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)

    num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
    for i in range(num_loop):
        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
        mask = offset < kv_end - kv_start
        data = tl.load(
            req_to_token_ptr
            + req_pool_index * req_to_token_ptr_stride
            + kv_start
            + offset,
            mask=mask,
        )
        tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)
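A hedged host-side sketch of how a kernel with this signature is typically launched: one Triton program per request, with kv_indptr holding each request's write offset into the flattened kv_indices buffer. The wrapper below (build_kv_indices) is illustrative and not part of the commit:

import torch

def build_kv_indices(req_to_token, req_pool_indices, paged_kernel_lens):
    # kv_indptr is the exclusive prefix sum of the per-request KV lengths.
    bs = req_pool_indices.shape[0]
    kv_indptr = torch.zeros(bs + 1, dtype=torch.int32, device=req_to_token.device)
    kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
    kv_indices = torch.empty(
        int(kv_indptr[-1].item()), dtype=torch.int32, device=req_to_token.device
    )
    create_flashinfer_kv_indices_triton[(bs,)](
        req_to_token,
        req_pool_indices,
        paged_kernel_lens,
        kv_indptr,
        None,  # kv_start_idx: not used in this sketch
        kv_indices,
        req_to_token.shape[1],  # row stride of the [max_batch, max_context_len] table
    )
    return kv_indptr, kv_indices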
python/sglang/srt/layers/logits_processor.py

@@ -33,6 +33,7 @@ from sglang.srt.layers.dp_attention import (
     get_attention_dp_size,
 )
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import (
     CaptureHiddenMode,
     ForwardBatch,

@@ -152,6 +153,13 @@ class LogitsMetadata:
             token_ids_logprobs=forward_batch.token_ids_logprobs,
             extend_input_logprob_token_ids_gpu=forward_batch.extend_input_logprob_token_ids_gpu,
             padded_static_len=forward_batch.padded_static_len,
+            global_num_tokens_gpu=forward_batch.global_num_tokens_gpu,
+            dp_local_start_pos=forward_batch.dp_local_start_pos,
+            dp_local_num_tokens=forward_batch.dp_local_num_tokens,
+            gathered_buffer=forward_batch.gathered_buffer,
+            forward_batch_gathered_buffer=forward_batch.gathered_buffer,
+            global_num_tokens_for_logprob_cpu=forward_batch.global_num_tokens_for_logprob_cpu,
+            global_num_tokens_for_logprob_gpu=forward_batch.global_num_tokens_for_logprob_gpu,
         )

     def compute_dp_attention_metadata(self, hidden_states: torch.Tensor):

@@ -204,8 +212,6 @@ class LogitsProcessor(nn.Module):
         ):
             self.final_logit_softcapping = None

-        from sglang.srt.managers.schedule_batch import global_server_args_dict
-
         self.debug_tensor_dump_output_folder = global_server_args_dict.get(
             "debug_tensor_dump_output_folder", None
         )
python/sglang/srt/managers/detokenizer_manager.py

@@ -212,6 +212,7 @@ class DetokenizerManager:
                 rids=recv_obj.rids,
                 finished_reasons=recv_obj.finished_reasons,
                 output_strs=output_strs,
+                output_ids=None,
                 prompt_tokens=recv_obj.prompt_tokens,
                 completion_tokens=recv_obj.completion_tokens,
                 cached_tokens=recv_obj.cached_tokens,
python/sglang/srt/managers/io_struct.py

@@ -414,6 +414,12 @@ class BatchTokenIDOut:
 class BatchMultimodalDecodeReq:
     # The request id
     rids: List[str]
+    finished_reasons: List[BaseFinishReason]
+
+    # Token counts
+    prompt_tokens: List[int]
+    completion_tokens: List[int]
+    cached_tokens: List[int]


 @dataclass

@@ -424,6 +430,8 @@ class BatchStrOut:
     finished_reasons: List[dict]
     # The output decoded strings
     output_strs: List[str]
+    # The token ids
+    output_ids: Optional[List[int]]
     # Token counts
     prompt_tokens: List[int]

@@ -453,6 +461,15 @@ class BatchStrOut:
 class BatchMultimodalOut:
     # The request id
     rids: List[str]
+    # The finish reason
+    finished_reasons: List[dict]
+    # The outputs
+    outputs: List[List[Dict]]
+
+    # Token counts
+    prompt_tokens: List[int]
+    completion_tokens: List[int]
+    cached_tokens: List[int]


 @dataclass
python/sglang/srt/managers/tokenizer_manager.py

@@ -1141,7 +1141,7 @@ async def print_exception_wrapper(func):
 class SignalHandler:
-    def __init__(self, tokenizer_manager):
+    def __init__(self, tokenizer_manager: TokenizerManager):
         self.tokenizer_manager = tokenizer_manager

     def signal_handler(self, signum=None, frame=None):
python/sglang/srt/mem_cache/memory_pool.py

@@ -192,7 +192,7 @@ class MHATokenToKVPool(BaseTokenToKVPool):
         k_size, v_size = self.get_kv_size_bytes()
         logger.info(
-            f"KV Cache is allocated. K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB."
+            f"KV Cache is allocated. #tokens: {size}, K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB"
         )

     def _create_buffers(self):
python/sglang/srt/model_executor/cuda_graph_runner.py

@@ -238,6 +238,9 @@ class CudaGraphRunner:
                 ),
                 dtype=self.model_runner.dtype,
             )
+            self.global_num_tokens_gpu = torch.zeros(
+                (self.dp_size,), dtype=torch.int32
+            )

         # Capture
         try:

@@ -266,9 +269,9 @@ class CudaGraphRunner:
     def can_run(self, forward_batch: ForwardBatch):
         if self.enable_dp_attention:
-            min_num_tokens, max_num_tokens = min(forward_batch.global_num_tokens), max(
-                forward_batch.global_num_tokens
-            )
+            min_num_tokens, max_num_tokens = min(
+                forward_batch.global_num_tokens_cpu
+            ), max(forward_batch.global_num_tokens_cpu)

             is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
                 (min_num_tokens == max_num_tokens and max_num_tokens in self.graphs)
                 if self.disable_padding

@@ -360,7 +363,7 @@ class CudaGraphRunner:
             encoder_lens=encoder_lens,
             return_logprob=False,
             positions=positions,
-            global_num_tokens=global_num_tokens,
+            global_num_tokens_cpu=global_num_tokens,
             gathered_buffer=gathered_buffer,
             mrope_positions=mrope_positions,
             spec_algorithm=self.model_runner.spec_algorithm,

@@ -430,7 +433,7 @@ class CudaGraphRunner:
         # Pad
         if self.enable_dp_attention:
             index = bisect.bisect_left(
-                self.capture_bs, max(forward_batch.global_num_tokens)
+                self.capture_bs, max(forward_batch.global_num_tokens_cpu)
             )
         else:
             index = bisect.bisect_left(self.capture_bs, raw_bs)
python/sglang/srt/model_executor/forward_batch_info.py

@@ -190,7 +190,16 @@ class ForwardBatch:
     attn_backend: AttentionBackend = None

     # For DP attention
-    global_num_tokens: Optional[List[int]] = None
+    global_num_tokens_cpu: Optional[List[int]] = None
+    global_num_tokens_gpu: Optional[torch.Tensor] = None
+    # Has to be None when cuda graph is captured.
+    global_num_tokens_for_logprob_cpu: Optional[List[int]] = None
+    global_num_tokens_for_logprob_gpu: Optional[torch.Tensor] = None
+    # for extend, local start pos and num tokens is different in logits processor
+    # this will be computed in get_dp_local_info
+    # this will be recomputed in LogitsMetadata.from_forward_batch
+    dp_local_start_pos: Optional[torch.Tensor] = None  # cached info at runtime
+    dp_local_num_tokens: Optional[torch.Tensor] = None  # cached info at runtime
     gathered_buffer: Optional[torch.Tensor] = None
     can_run_dp_cuda_graph: bool = False

@@ -234,7 +243,6 @@ class ForwardBatch:
             return_logprob=batch.return_logprob,
             top_logprobs_nums=batch.top_logprobs_nums,
             token_ids_logprobs=batch.token_ids_logprobs,
-            global_num_tokens=batch.global_num_tokens,
             can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
             lora_paths=batch.lora_paths,
             sampling_info=batch.sampling_info,

@@ -248,8 +256,9 @@ class ForwardBatch:
             extend_input_logprob_token_ids_gpu=extend_input_logprob_token_ids_gpu,
         )

-        if ret.global_num_tokens is not None:
-            max_len = max(ret.global_num_tokens)
+        if batch.global_num_tokens is not None:
+            ret.global_num_tokens_cpu = batch.global_num_tokens
+            max_len = max(ret.global_num_tokens_cpu)
             ret.gathered_buffer = torch.zeros(
                 (max_len * model_runner.tp_size, model_runner.model_config.hidden_size),
                 dtype=model_runner.dtype,
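The renamed fields above split the per-rank token counts into a host-side list and a device tensor. A small sketch of that pattern (the values are made up and a CUDA device is assumed):

import torch

# Host copy: cheap min()/max() and padding decisions without a device sync.
global_num_tokens_cpu = [3, 7, 5]
# Device copy: available to kernels and collectives that need it on the GPU.
global_num_tokens_gpu = torch.tensor(global_num_tokens_cpu, dtype=torch.int32, device="cuda")

max_len = max(global_num_tokens_cpu)  # host-only, no GPU synchronization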
python/sglang/srt/model_executor/model_runner.py

@@ -13,7 +13,6 @@
 # ==============================================================================
 """ModelRunner runs the forward passes of the models."""
-import collections
 import datetime
 import gc
 import json

@@ -269,6 +268,7 @@ class ModelRunner:
         elif self.device == "cpu":
             backend = "gloo"

+        before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
         if not self.server_args.enable_p2p_check:
             monkey_patch_p2p_access_check()

@@ -299,20 +299,24 @@ class ModelRunner:
         min_per_gpu_memory = get_available_gpu_memory(
             self.device, self.gpu_id, distributed=self.tp_size > 1
         )
+        local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
         self.tp_group = get_tp_group()
         self.attention_tp_group = get_attention_tp_group()

         # Check memory for tensor parallelism
         if self.tp_size > 1:
-            local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
             if min_per_gpu_memory < local_gpu_memory * 0.9:
                 raise ValueError(
                     "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
                 )

+        logger.info(
+            f"Init torch distributed ends. mem usage={(before_avail_memory - local_gpu_memory):.2f} GB"
+        )
         return min_per_gpu_memory

     def load_model(self):
+        before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
             f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
         )

@@ -382,11 +386,13 @@ class ModelRunner:
         )
         self.dtype = self.model_config.dtype

+        after_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
             f"Load weight end. "
             f"type={type(self.model).__name__}, "
             f"dtype={self.dtype}, "
-            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+            f"avail mem={after_avail_memory:.2f} GB, "
+            f"mem usage={(before_avail_memory - after_avail_memory):.2f} GB."
         )

     def update_weights_from_disk(

@@ -785,12 +791,15 @@ class ModelRunner:
             return

         tic = time.time()
+        before_mem = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
-            f"Capture cuda graph begin. This can take up to several minutes. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+            f"Capture cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
         )
         self.cuda_graph_runner = CudaGraphRunner(self)
+        after_mem = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
-            f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+            f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s. "
+            f"avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB."
         )

     def apply_torch_tp(self):

@@ -806,8 +815,12 @@ class ModelRunner:
             forward_batch.input_ids, forward_batch.positions, forward_batch
         )

-    def forward_extend(self, forward_batch: ForwardBatch):
-        self.attn_backend.init_forward_metadata(forward_batch)
+    def forward_extend(
+        self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False
+    ):
+        if not skip_attn_backend_init:
+            self.attn_backend.init_forward_metadata(forward_batch)
+
         if self.is_generation:
             if forward_batch.input_embeds is None:
                 return self.model.forward(
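A hedged illustration of the calling pattern the new skip_attn_backend_init flag enables: a caller that has already prepared attention metadata (for example during CUDA graph replay) can avoid redoing it. The classes below are stand-ins, not sglang internals:

class _DummyAttnBackend:
    def init_forward_metadata(self, forward_batch):
        print("init metadata for", forward_batch)

class _DummyRunner:
    def __init__(self):
        self.attn_backend = _DummyAttnBackend()

    def forward_extend(self, forward_batch, skip_attn_backend_init: bool = False):
        if not skip_attn_backend_init:
            self.attn_backend.init_forward_metadata(forward_batch)
        return f"ran extend on {forward_batch}"

runner = _DummyRunner()
runner.forward_extend("batch0")                               # initializes metadata
runner.forward_extend("batch1", skip_attn_backend_init=True)  # reuses existing metadata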
python/sglang/srt/models/deepseek_v2.py

@@ -818,8 +818,8 @@ def all_gather(
     if world_size == 1:
         return input_tensor

-    all_lens = forward_batch.global_num_tokens
-    max_len = max(forward_batch.global_num_tokens)
+    all_lens = forward_batch.global_num_tokens_cpu
+    max_len = max(forward_batch.global_num_tokens_cpu)

     padded_tensor = torch.nn.functional.pad(
         input_tensor, (0, 0, 0, max_len - input_tensor.shape[0])
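A tiny illustration of the padding call in the hunk above: with a pad spec of (0, 0, 0, n), torch.nn.functional.pad leaves the last (hidden) dimension untouched and appends n rows of zeros to the first dimension. The shapes are made up:

import torch

input_tensor = torch.randn(5, 8)  # [num_tokens, hidden_size]
max_len = 9
padded = torch.nn.functional.pad(input_tensor, (0, 0, 0, max_len - input_tensor.shape[0]))
print(padded.shape)  # torch.Size([9, 8])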
python/sglang/srt/utils.py

@@ -741,13 +741,6 @@ def pytorch_profile(name, func, *args, data_size=-1):
     return result


-def first_rank_print(*args, **kwargs):
-    if torch.cuda.current_device() == 0:
-        print(*args, **kwargs)
-    else:
-        pass
-
-
 def get_zmq_socket(
     context: zmq.Context, socket_type: zmq.SocketType, endpoint: str, bind: bool
 ):

@@ -1177,6 +1170,11 @@ def get_bool_env_var(name: str, default: str = "false") -> bool:
     return value.lower() in ("true", "1")


+@lru_cache(maxsize=2)
+def disable_request_logging() -> bool:
+    return get_bool_env_var("SGLANG_DISABLE_REQUEST_LOGGING")
+
+
 @lru_cache(maxsize=8)
 def _cuda_device_count_stateless(cuda_visible_devices: Optional[str] = None) -> int:
     # Note: cuda_visible_devices is not used, but we keep it as an argument for
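A self-contained sketch of the caching behavior introduced above. The comparison in get_bool_env_var is reproduced from the hunk, while the os.getenv lookup is an assumption:

import os
from functools import lru_cache

def get_bool_env_var(name: str, default: str = "false") -> bool:
    value = os.getenv(name, default)  # assumed lookup; the comparison below matches the hunk
    return value.lower() in ("true", "1")

@lru_cache(maxsize=2)
def disable_request_logging() -> bool:
    # lru_cache memoizes the result, so the environment is read only once per process.
    return get_bool_env_var("SGLANG_DISABLE_REQUEST_LOGGING")

os.environ["SGLANG_DISABLE_REQUEST_LOGGING"] = "1"
print(disable_request_logging())  # True on the first call, then served from the cache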
sgl-kernel/setup.py

@@ -85,6 +85,7 @@ nvcc_flags = [
     "-DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1",
     "-DCUTLASS_DEBUG_TRACE_LEVEL=0",
     "--ptxas-options=-v",
+    "--expt-relaxed-constexpr",
     "-Xcompiler=-Wconversion",
     "-Xcompiler=-fno-strict-aliasing",
 ]