sglang / Commits / a0e58740

"vscode:/vscode.git/clone" did not exist on "d96cbacacdb1e0a49ab436c3f647d002e0f411bf"
Unverified commit a0e58740
Authored Nov 27, 2024 by Lianmin Zheng; committed by GitHub on Nov 27, 2024

Use an env var SGLANG_SET_CPU_AFFINITY to set cpu affinity; turn it off by default (#2217)
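
With this change, CPU core pinning for the scheduler processes is opt-in: it runs only when SGLANG_SET_CPU_AFFINITY is set to a truthy value ("true" or "1") and is skipped by default. To opt back in, the variable can be set at launch time, for example (launch command shown for illustration; adjust the flags to your setup):

SGLANG_SET_CPU_AFFINITY=true python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B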

Parent: 37c8a576

Showing 6 changed files with 23 additions and 20 deletions
python/sglang/bench_one_batch_server.py                     +2  -2
python/sglang/srt/configs/model_config.py                   +2  -6
python/sglang/srt/layers/attention/flashinfer_backend.py    +3  -3
python/sglang/srt/managers/scheduler.py                     +5  -3
python/sglang/srt/utils.py                                  +9  -4
python/sglang/test/test_utils.py                            +2  -2

python/sglang/bench_one_batch_server.py

@@ -5,9 +5,9 @@ This script launches a server and uses the HTTP interface.
 It accepts server arguments (the same as launch_server.py) and benchmark arguments (e.g., batch size, input lengths).
 
 Usage:
-    python3 -m sglang.bench_server_latency --model meta-llama/Meta-Llama-3.1-8B --batch-size 1 16 64 --input-len 1024 --output-len 8
+    python3 -m sglang.bench_one_batch_server --model meta-llama/Meta-Llama-3.1-8B --batch-size 1 16 64 --input-len 1024 --output-len 8
 
-    python3 -m sglang.bench_server_latency --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8
+    python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8
 """
 
 import argparse

python/sglang/srt/configs/model_config.py

@@ -14,13 +14,13 @@
 import json
 import logging
-import os
 from enum import IntEnum, auto
 from typing import List, Optional
 
 from transformers import PretrainedConfig
 
 from sglang.srt.hf_transformers_utils import get_config, get_context_length
+from sglang.srt.utils import get_bool_env_var
 
 logger = logging.getLogger(__name__)

@@ -59,13 +59,9 @@ class ModelConfig:
 
         # Derive context length
         derived_context_len = get_context_length(self.hf_text_config)
-        allow_long_context = os.environ.get(
-            "SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN", None
-        )
-
         if context_length is not None:
             if context_length > derived_context_len:
-                if allow_long_context:
+                if get_bool_env_var("SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"):
                     logger.warning(
                         f"Warning: User-specified context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). "
                         f"This may lead to incorrect model outputs or CUDA errors."
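
A behavioral nuance in this hunk: the old check treated any non-empty value of SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN as enabling the override, whereas get_bool_env_var accepts only "true" or "1" (case-insensitive). A minimal sketch of the difference; the helper names below are illustrative, not part of the diff:

import os

def old_allow() -> bool:
    # pre-change: any non-empty string, even "0" or "no", counted as enabled
    return bool(os.environ.get("SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN", None))

def new_allow() -> bool:
    # post-change: mirrors get_bool_env_var("SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN")
    return os.getenv("SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN", "false").lower() in ("true", "1")

os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] = "0"
print(old_allow(), new_allow())  # True False: "0" no longer enables the override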

python/sglang/srt/layers/attention/flashinfer_backend.py

@@ -18,7 +18,7 @@ import triton.language as tl
 from sglang.global_config import global_config
 from sglang.srt.layers.attention import AttentionBackend
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import is_flashinfer_available
+from sglang.srt.utils import get_bool_env_var, is_flashinfer_available
 
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention

@@ -47,8 +47,8 @@ class FlashInferAttnBackend(AttentionBackend):
 
         # Parse constants
         if "SGLANG_FLASHINFER_USE_TENSOR_CORE" in os.environ:
-            self.decode_use_tensor_cores = (
-                os.environ["SGLANG_FLASHINFER_USE_TENSOR_CORE"].lower() == "true"
+            self.decode_use_tensor_cores = get_bool_env_var(
+                "SGLANG_FLASHINFER_USE_TENSOR_CORE"
             )
         else:
             if not _grouped_size_compiled_for_decode_kernels(

python/sglang/srt/managers/scheduler.py

@@ -71,9 +71,10 @@ from sglang.srt.utils import (
     broadcast_pyobj,
     configure_logger,
     crash_on_warnings,
+    get_bool_env_var,
     get_zmq_socket,
-    gpu_proc_affinity,
     kill_parent_process,
+    set_gpu_proc_affinity,
     set_random_seed,
     suppress_other_loggers,
 )

@@ -82,7 +83,7 @@ from sglang.utils import get_exception_traceback
 logger = logging.getLogger(__name__)
 
 # Test retract decode
-test_retract = os.getenv("SGLANG_TEST_RETRACT", "false").lower() == "true"
+test_retract = get_bool_env_var("SGLANG_TEST_RETRACT")
 
 
 class Scheduler:

@@ -1405,7 +1406,8 @@ def run_scheduler_process(
     pipe_writer,
 ):
     # set cpu affinity to this gpu process
-    gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
+    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
+        set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
 
     # [For Router] if env var "DP_RANK" exist, set dp_rank to the value of the env var
     if dp_rank is None and "DP_RANK" in os.environ:
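
The body of set_gpu_proc_affinity is not shown in this view. As a rough sketch of the gating pattern the last hunk introduces; the CPU-partitioning logic below is an assumption for illustration, not the actual sglang implementation, and psutil's cpu_affinity is only supported on Linux and Windows:

import os
import psutil

def maybe_pin_process(tp_size: int, gpu_id: int) -> None:
    # Off by default, matching this commit: only pin when the env var is truthy.
    if os.getenv("SGLANG_SET_CPU_AFFINITY", "false").lower() not in ("true", "1"):
        return
    total_cpus = psutil.cpu_count(logical=True)
    per_rank = max(1, total_cpus // tp_size)
    start = (gpu_id % tp_size) * per_rank
    bind_cpu_ids = list(range(start, min(start + per_rank, total_cpus)))
    p = psutil.Process()
    p.cpu_affinity(bind_cpu_ids)  # pin this scheduler process to its slice of CPU cores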

python/sglang/srt/utils.py

@@ -72,7 +72,7 @@ def is_flashinfer_available():
     Check whether flashinfer is available.
     As of Oct. 6, 2024, it is only available on NVIDIA GPUs.
     """
-    if os.environ.get("SGLANG_IS_FLASHINFER_AVAILABLE", "true") == "false":
+    if not get_bool_env_var("SGLANG_IS_FLASHINFER_AVAILABLE", default="true"):
         return False
     return torch.cuda.is_available() and not is_hip()
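
Because the default here is "true", the check is inverted relative to the other call sites: flashinfer stays enabled unless the variable is explicitly set to something falsy. A quick way to exercise the off switch, assuming sglang is installed:

import os

os.environ["SGLANG_IS_FLASHINFER_AVAILABLE"] = "false"

from sglang.srt.utils import is_flashinfer_available
print(is_flashinfer_available())  # False: the env var short-circuits the hardware check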

@@ -626,7 +626,7 @@ def add_api_key_middleware(app, api_key: str):
 
 def prepare_model_and_tokenizer(model_path: str, tokenizer_path: str):
-    if "SGLANG_USE_MODELSCOPE" in os.environ:
+    if get_bool_env_var("SGLANG_USE_MODELSCOPE"):
         if not os.path.exists(model_path):
             from modelscope import snapshot_download

@@ -931,7 +931,7 @@ def get_nvgpu_memory_capacity():
 
 def crash_on_warnings():
     # Crash on warning if we are running CI tests
-    return os.getenv("SGLANG_IS_IN_CI", "false").lower() == "true"
+    return get_bool_env_var("SGLANG_IS_IN_CI")
 
 
 def get_device_name(device_id: int = 0) -> str:

@@ -990,7 +990,7 @@ def direct_register_custom_op(
     my_lib._register_fake(op_name, fake_impl)
 
 
-def gpu_proc_affinity(
+def set_gpu_proc_affinity(
     tp_size: int,
     nnodes: int,
     gpu_id: int,

@@ -1022,3 +1022,8 @@ def gpu_proc_affinity(
     # set cpu_affinity to current process
     p.cpu_affinity(bind_cpu_ids)
     logger.info(f"Process {pid} gpu_id {gpu_id} is running on CPUs: {p.cpu_affinity()}")
+
+
+def get_bool_env_var(name: str, default: str = "false") -> bool:
+    value = os.getenv(name, default)
+    return value.lower() in ("true", "1")
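
The new helper treats only "true" and "1" (case-insensitive) as enabled; any other value, including "yes" or "0", reads as disabled. A short demonstration, assuming sglang is importable; SOME_UNSET_VAR is an arbitrary placeholder:

import os
from sglang.srt.utils import get_bool_env_var

os.environ["SGLANG_IS_IN_CI"] = "1"
print(get_bool_env_var("SGLANG_IS_IN_CI"))          # True: "1" is accepted

os.environ["SGLANG_SET_CPU_AFFINITY"] = "yes"
print(get_bool_env_var("SGLANG_SET_CPU_AFFINITY"))  # False: "yes" is not in ("true", "1")

print(get_bool_env_var("SOME_UNSET_VAR", default="true"))  # True when unset and the default is "true"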

python/sglang/test/test_utils.py

@@ -22,7 +22,7 @@ from sglang.bench_serving import run_benchmark
 from sglang.global_config import global_config
 from sglang.lang.backend.openai import OpenAI
 from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
-from sglang.srt.utils import kill_child_process
+from sglang.srt.utils import get_bool_env_var, kill_child_process
 from sglang.test.run_eval import run_eval
 from sglang.utils import get_exception_traceback

@@ -44,7 +44,7 @@ DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_QUANT_TP1 = "hugging-quants/Meta-Llama-3.1-8
 
 def is_in_ci():
     """Return whether it is in CI runner."""
-    return os.getenv("SGLANG_IS_IN_CI", "false").lower() == "true"
+    return get_bool_env_var("SGLANG_IS_IN_CI")
 
 
 if is_in_ci():