sglang · Commit e69a2190 (Unverified)

Enhance GPU memory settings (#5604)

Authored Apr 22, 2025 by Liangsheng Yin; committed via GitHub Apr 21, 2025
Parent: bf98d2e3
Showing 4 changed files with 59 additions and 31 deletions (+59 −31)
python/sglang/srt/mem_cache/memory_pool.py            +12 −0
python/sglang/srt/model_executor/cuda_graph_runner.py +18 −7
python/sglang/srt/server_args.py                      +15 −24
python/sglang/srt/utils.py                            +14 −0
python/sglang/srt/mem_cache/memory_pool.py

@@ -448,6 +448,18 @@ class MLATokenToKVPool(KVCache):
         self.layer_transfer_counter = None
         self.page_size = page_size
+
+        kv_size = self.get_kv_size_bytes()
+        logger.info(
+            f"KV Cache is allocated. #tokens: {size}, KV size: {kv_size / GB:.2f} GB"
+        )
+
+    def get_kv_size_bytes(self):
+        assert hasattr(self, "kv_buffer")
+        kv_size_bytes = 0
+        for kv_cache in self.kv_buffer:
+            kv_size_bytes += np.prod(kv_cache.shape) * kv_cache.dtype.itemsize
+        return kv_size_bytes
 
     # for disagg
     def get_contiguous_buf_infos(self):
         # MLA has only one kv_buffer, so only the information of this buffer needs to be returned.
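The new get_kv_size_bytes() simply sums element count times element width over every per-layer buffer, and the log line divides by GB to report the total. Below is a minimal standalone sketch of that arithmetic; the layer count, shapes, and dtype are made-up illustration values, not taken from the commit.

import numpy as np

GB = 1 << 30  # bytes per GiB, matching the "kv_size / GB" in the new log line

# Hypothetical MLA-style KV buffer: one array per layer, 1024 token slots,
# 576 compressed KV elements per token, fp16. Illustration values only.
kv_buffer = [np.zeros((1024, 1, 576), dtype=np.float16) for _ in range(4)]

kv_size_bytes = 0
for kv_cache in kv_buffer:
    # element count times bytes per element, as in get_kv_size_bytes()
    kv_size_bytes += np.prod(kv_cache.shape) * kv_cache.dtype.itemsize

print(f"KV Cache is allocated. #tokens: {kv_buffer[0].shape[0]}, "
      f"KV size: {kv_size_bytes / GB:.4f} GB")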
python/sglang/srt/model_executor/cuda_graph_runner.py

@@ -35,7 +35,11 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardMode,
 )
 from sglang.srt.patch_torch import monkey_patch_torch_compile
-from sglang.srt.utils import get_available_gpu_memory, is_hip
+from sglang.srt.utils import (
+    get_available_gpu_memory,
+    get_whatever_gpu_memory_capacity,
+    is_hip,
+)
 
 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner

@@ -132,6 +136,11 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
         if _is_hip:
             capture_bs += list(range(160, 257, 8))
 
+    gpu_mem = get_whatever_gpu_memory_capacity() / 1024
+    if gpu_mem is not None and gpu_mem > 120:
+        capture_bs += list(range(160, 256, 8))
+
     if max(capture_bs) > model_runner.req_to_token_pool.size:
         # In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests
         # is very small. We add more values here to make sure we capture the maximum bs.

@@ -140,12 +149,13 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
     ]
     capture_bs = list(sorted(set(capture_bs)))
-    assert len(capture_bs) > 0 and capture_bs[0] > 0
-    capture_bs = [bs for bs in capture_bs if bs <= model_runner.req_to_token_pool.size]
-    if server_args.cuda_graph_max_bs:
-        capture_bs = [bs for bs in capture_bs if bs <= server_args.cuda_graph_max_bs]
+    capture_bs = [
+        bs
+        for bs in capture_bs
+        if bs <= model_runner.req_to_token_pool.size
+        and bs <= server_args.cuda_graph_max_bs
+    ]
     compile_bs = (
         [bs for bs in capture_bs if bs <= server_args.torch_compile_max_bs]
         if server_args.enable_torch_compile

@@ -186,6 +196,7 @@ class CudaGraphRunner:
         # Batch sizes to capture
         self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
         self.capture_forward_mode = ForwardMode.DECODE
         self.capture_hidden_mode = CaptureHiddenMode.NULL
         self.num_tokens_per_bs = 1
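The effect of the cuda_graph_runner.py changes is easier to see in isolation: GPUs reporting more than roughly 120 GB get extra capture batch sizes, and the candidate list is then clipped by both the request pool size and --cuda-graph-max-bs in a single comprehension. Below is a rough standalone sketch of that selection logic; pick_capture_bs and the concrete argument values are illustrative stand-ins, not part of sglang.

def pick_capture_bs(base_candidates, gpu_mem_gb, req_pool_size, cuda_graph_max_bs):
    capture_bs = list(base_candidates)
    # New in this commit: on large-memory GPUs, also capture larger batch sizes.
    if gpu_mem_gb is not None and gpu_mem_gb > 120:
        capture_bs += list(range(160, 256, 8))
    capture_bs = list(sorted(set(capture_bs)))
    # New: both limits applied in one comprehension instead of two passes.
    return [
        bs
        for bs in capture_bs
        if bs <= req_pool_size and bs <= cuda_graph_max_bs
    ]

# An 80 GB card keeps the default candidates; a hypothetical 141 GB card adds
# 160..248, still clipped by --cuda-graph-max-bs (here 200) and the pool size.
print(pick_capture_bs(range(1, 161, 8), gpu_mem_gb=141, req_pool_size=4096, cuda_graph_max_bs=200))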
python/sglang/srt/server_args.py

@@ -26,11 +26,8 @@ from sglang.srt.hf_transformers_utils import check_gguf_file
 from sglang.srt.reasoning_parser import ReasoningParser
 from sglang.srt.utils import (
     configure_ipv6,
-    get_amdgpu_memory_capacity,
     get_device,
-    get_hpu_memory_capacity,
-    get_nvgpu_memory_capacity,
-    is_cuda,
+    get_whatever_gpu_memory_capacity,
     is_flashinfer_available,
     is_hip,
     is_port_available,

@@ -221,28 +218,24 @@ class ServerArgs:
         if self.random_seed is None:
             self.random_seed = random.randint(0, 1 << 30)
 
-        if is_cuda():
-            gpu_mem = get_nvgpu_memory_capacity()
-        elif is_hip():
-            gpu_mem = get_amdgpu_memory_capacity()
-        elif self.device == "hpu":
-            gpu_mem = get_hpu_memory_capacity()
-        else:
-            # GPU memory is not known yet or no GPU is available.
-            gpu_mem = None
+        gpu_mem = get_whatever_gpu_memory_capacity(self.device)
 
         # Set mem fraction static, which depends on the tensor parallelism size
         if self.mem_fraction_static is None:
-            if self.tp_size >= 16:
-                self.mem_fraction_static = 0.79
-            elif self.tp_size >= 8:
-                self.mem_fraction_static = 0.81
-            elif self.tp_size >= 4:
-                self.mem_fraction_static = 0.85
-            elif self.tp_size >= 2:
-                self.mem_fraction_static = 0.87
-            else:
-                self.mem_fraction_static = 0.88
+            if gpu_mem <= 81920:
+                if self.tp_size >= 16:
+                    self.mem_fraction_static = 0.79
+                elif self.tp_size >= 8:
+                    self.mem_fraction_static = 0.81
+                elif self.tp_size >= 4:
+                    self.mem_fraction_static = 0.85
+                elif self.tp_size >= 2:
+                    self.mem_fraction_static = 0.87
+                else:
+                    self.mem_fraction_static = 0.88
+            else:
+                # FIXME: more fine grained auto-selection polices
+                self.mem_fraction_static = (gpu_mem - 1024 * 13) / gpu_mem
 
         # Set chunked prefill size, which depends on the gpu memory capacity
         if self.chunked_prefill_size is None:

@@ -271,8 +264,6 @@ class ServerArgs:
                     self.cuda_graph_max_bs = 8
                 else:
                     self.cuda_graph_max_bs = 80
-            else:
-                self.cuda_graph_max_bs = 160
 
         # Set kernel backends for hpu device
         if self.device == "hpu":
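The practical effect of the server_args.py change: the tp-size table is now used only for GPUs at or below 81920 MB, while larger GPUs reserve a flat 13 GiB and use the rest for the static pool. A small worked sketch of that policy follows; the 143360 MB figure and the helper name are made-up examples, not values from the commit.

# Hedged re-statement of the new auto-selection; numbers below are illustrative.
def auto_mem_fraction_static(gpu_mem_mb, tp_size):
    if gpu_mem_mb <= 81920:
        # tp-size-based table, unchanged from the previous behavior
        for threshold, frac in ((16, 0.79), (8, 0.81), (4, 0.85), (2, 0.87)):
            if tp_size >= threshold:
                return frac
        return 0.88
    # FIXME in the commit: more fine grained auto-selection polices
    return (gpu_mem_mb - 1024 * 13) / gpu_mem_mb

print(auto_mem_fraction_static(81920, tp_size=8))   # 0.81: old table still applies
print(auto_mem_fraction_static(143360, tp_size=8))  # ~0.907: (143360 - 13312) / 143360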
python/sglang/srt/utils.py

@@ -1160,6 +1160,20 @@ def get_hpu_memory_capacity():
     )
 
 
+def get_whatever_gpu_memory_capacity(device: str = None):
+    if is_cuda():
+        gpu_mem = get_nvgpu_memory_capacity()
+    elif is_hip():
+        gpu_mem = get_amdgpu_memory_capacity()
+    elif device == "hpu":
+        gpu_mem = get_hpu_memory_capacity()
+    else:
+        # GPU memory is not known yet or no GPU is available.
+        gpu_mem = None
+    return gpu_mem
+
+
 # Copy from pytorch and OpenRLHF to allow creating multiple main groups.
 # https://github.com/pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py
 # https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/utils/distributed_util.py
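For reference, this is roughly how the new helper is consumed by the callers in this commit: server_args.py passes self.device so the HPU branch can trigger, while cuda_graph_runner.py calls it with no argument and divides by 1024. A hedged usage sketch; the megabyte unit follows the existing capacity helpers, and the print strings are not from sglang.

from sglang.srt.utils import get_whatever_gpu_memory_capacity

gpu_mem = get_whatever_gpu_memory_capacity("cuda")  # capacity in MB, or None
if gpu_mem is None:
    print("GPU memory is not known yet or no GPU is available.")
else:
    print(f"Detected roughly {gpu_mem / 1024:.1f} GB of GPU memory")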