"src/vscode:/vscode.git/clone" did not exist on "a75846379afc0557352be7df76780c3ac4aa113e"
Unverified commit b01df48c authored by Lianmin Zheng, committed by GitHub

[Fix] Adjust default chunked prefill size and cuda graph max bs according to GPU memory capacity (#2044)
parent c29b98e0
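
In short: on startup the server now probes total GPU memory via nvidia-smi and, on cards reporting less than 25000 MiB, divides the default chunked prefill size by four (the in-code comment notes this lands at 2048, implying a default of 8192) and caps the CUDA graph max batch size at 4. Below is a minimal standalone sketch of that decision; only the 25000 MiB threshold, the //= 4, and the cap of 4 come from the diff, while the function and constant names are illustrative, not part of sglang.

# Hedged sketch of the small-GPU adjustment added to ServerArgs in this commit.
DEFAULT_CHUNKED_PREFILL_SIZE = 8192  # implied by "//= 4  # make it 2048" in the diff
SMALL_GPU_THRESHOLD_MIB = 25000      # nvidia-smi reports memory.total in MiB

def adjust_for_gpu(gpu_mem_mib, chunked_prefill_size=DEFAULT_CHUNKED_PREFILL_SIZE,
                   cuda_graph_max_bs=None):
    """Return (chunked_prefill_size, cuda_graph_max_bs) after the small-GPU adjustment."""
    if gpu_mem_mib < SMALL_GPU_THRESHOLD_MIB:
        chunked_prefill_size //= 4  # 8192 -> 2048
        cuda_graph_max_bs = 4
    return chunked_prefill_size, cuda_graph_max_bs

print(adjust_for_gpu(24576))  # e.g. a 24 GB card -> (2048, 4)
print(adjust_for_gpu(81920))  # e.g. an 80 GB card -> (8192, None)
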
@@ -22,7 +22,12 @@ import random
 import tempfile
 from typing import List, Optional
 
-from sglang.srt.utils import is_flashinfer_available, is_ipv6, is_port_available
+from sglang.srt.utils import (
+    get_gpu_memory_capacity,
+    is_flashinfer_available,
+    is_ipv6,
+    is_port_available,
+)
 
 logger = logging.getLogger(__name__)
@@ -143,6 +148,9 @@ class ServerArgs:
             # Disable chunked prefill
             self.chunked_prefill_size = None
 
+        if self.random_seed is None:
+            self.random_seed = random.randint(0, 1 << 30)
+
         # Mem fraction depends on the tensor parallelism size
         if self.mem_fraction_static is None:
             if self.tp_size >= 16:
@@ -156,8 +164,14 @@
             else:
                 self.mem_fraction_static = 0.88
 
-        if self.random_seed is None:
-            self.random_seed = random.randint(0, 1 << 30)
+        # Adjust for GPUs with small memory capacities
+        gpu_mem = get_gpu_memory_capacity()
+        if gpu_mem < 25000:
+            logger.warning(
+                "Automatically adjust --chunked-prefill-size for small GPUs."
+            )
+            self.chunked_prefill_size //= 4  # make it 2048
+            self.cuda_graph_max_bs = 4
 
         # Deprecation warnings
         if self.disable_flashinfer:
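The remaining hunks add the helper behind that check to sglang.srt.utils (the module the new import above pulls get_gpu_memory_capacity from). It shells out to nvidia-smi --query-gpu=memory.total --format=csv,noheader,nounits, keeps only numeric lines, and returns the minimum across visible GPUs, so a mixed machine is gated by its smallest card. A quick illustration of that parsing on a hypothetical two-GPU output (the sample values are made up):

# Hypothetical nvidia-smi output (MiB, one GPU per line); the filtering and min()
# mirror the expression inside get_gpu_memory_capacity() in the hunk further below.
import re

sample_stdout = "24576\n49140\n"
memory_values = [
    float(mem)
    for mem in sample_stdout.strip().split("\n")
    if re.match(r"^\d+(\.\d+)?$", mem.strip())
]
print(min(memory_values))  # 24576.0 -> below the 25000 MiB threshold, treated as a small GPU
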
@@ -27,6 +27,7 @@ import resource
 import shutil
 import signal
 import socket
+import subprocess
 import tempfile
 import time
 import warnings
@@ -791,3 +792,35 @@ def add_prometheus_middleware(app):
     # Workaround for 307 Redirect for /metrics
     metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
     app.routes.append(metrics_route)
+
+
+def get_gpu_memory_capacity():
+    try:
+        # Run nvidia-smi and capture the output
+        result = subprocess.run(
+            ["nvidia-smi", "--query-gpu=memory.total", "--format=csv,noheader,nounits"],
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+            text=True,
+        )
+
+        if result.returncode != 0:
+            raise RuntimeError(f"nvidia-smi error: {result.stderr.strip()}")
+
+        # Parse the output to extract memory values
+        memory_values = [
+            float(mem)
+            for mem in result.stdout.strip().split("\n")
+            if re.match(r"^\d+(\.\d+)?$", mem.strip())
+        ]
+
+        if not memory_values:
+            raise ValueError("No GPU memory values found.")
+
+        # Return the minimum memory value
+        return min(memory_values)
+
+    except FileNotFoundError:
+        raise RuntimeError(
+            "nvidia-smi not found. Ensure NVIDIA drivers are installed and accessible."
+        )
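
Since the helper returns the smallest per-GPU memory.total in MiB, or raises RuntimeError when nvidia-smi is missing or fails, callers outside ServerArgs could reuse it for similar capacity gating. A small usage sketch, assuming an sglang install where the import path matches the hunk above; the threshold is the one from the diff and the printed messages are illustrative:

# Illustrative reuse of the helper added above.
from sglang.srt.utils import get_gpu_memory_capacity

try:
    gpu_mem = get_gpu_memory_capacity()  # minimum memory.total across visible GPUs, in MiB
except RuntimeError as err:
    print(f"Could not query GPU memory: {err}")
else:
    if gpu_mem < 25000:
        print(f"Small GPU ({gpu_mem:.0f} MiB): chunked prefill drops to 2048, cuda_graph_max_bs to 4.")
    else:
        print(f"GPU memory {gpu_mem:.0f} MiB: defaults are left unchanged.")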