sglang · Commits · 9f662501

Commit 9f662501 (unverified), authored Aug 08, 2024 by Ying Sheng; committed via GitHub on Aug 08, 2024.

Move torch.compile configs into cuda_graph_runner.py (#993)
Parent: ab787594

Showing 4 changed files, with 15 additions and 19 deletions:

  python/sglang/srt/model_executor/cuda_graph_runner.py   +15   -0
  python/sglang/srt/server.py                               +0   -5
  python/sglang/srt/utils.py                                +0  -13
  python/sglang/utils.py                                    +0   -1
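The flags being moved are global torch.compile settings. As a rough, self-contained illustration of how such flags are applied before compiling a model (a toy nn.Linear stands in for the real model; this is illustrative only, not sglang code, and assumes PyTorch 2.x):

    import torch
    import torch._dynamo.config
    import torch._inductor.config

    # The same flags set_torch_compile_config() sets (see the diff below):
    torch._inductor.config.coordinate_descent_tuning = True   # search harder for good kernel configs
    torch._inductor.config.triton.unique_kernel_names = True  # name kernels distinctly for profiling
    torch._inductor.config.fx_graph_cache = True              # cache compiled FX graphs across runs
    torch._dynamo.config.accumulated_cache_size_limit = 1024  # allow many compiled variants

    model = torch.nn.Linear(16, 16)   # toy stand-in for a real model
    compiled = torch.compile(model)   # the global flags above take effect here
    print(compiled(torch.randn(4, 16)).shape)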
python/sglang/srt/model_executor/cuda_graph_runner.py (+15 -0)
@@ -71,6 +71,18 @@ def patch_model(
         tp_group.ca_comm = backup_ca_comm


+def set_torch_compile_config():
+    import torch._dynamo.config
+    import torch._inductor.config
+
+    torch._inductor.config.coordinate_descent_tuning = True
+    torch._inductor.config.triton.unique_kernel_names = True
+    torch._inductor.config.fx_graph_cache = True  # Experimental feature to reduce compilation times, will be on by default in future
+
+    # FIXME: tmp workaround
+    torch._dynamo.config.accumulated_cache_size_limit = 1024
+
+
 class CudaGraphRunner:
     def __init__(self, model_runner, max_batch_size_to_capture, use_torch_compile):
         self.model_runner = model_runner

@@ -112,6 +124,9 @@ class CudaGraphRunner:
         self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if use_torch_compile else []

+        if use_torch_compile:
+            set_torch_compile_config()
+
     def can_run(self, batch_size):
         return batch_size < self.max_bs
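A standalone re-statement of the batch-size gating this hunk touches (runnable on its own, not sglang code; max_bs = 32 is assumed for illustration):

    use_torch_compile = True
    compile_bs = [1, 2, 4, 8, 16, 24, 32] if use_torch_compile else []  # batch sizes to compile

    max_bs = 32  # assume graphs captured up to batch size 32
    def can_run(batch_size):
        return batch_size < max_bs  # strict '<', exactly as in the diff

    assert can_run(8)
    assert not can_run(32)  # a batch equal to max_bs is rejected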
python/sglang/srt/server.py (+0 -5)
@@ -74,7 +74,6 @@ from sglang.srt.utils import (
     enable_show_time_cost,
     kill_child_process,
     maybe_set_triton_cache_manager,
-    set_torch_compile_config,
     set_ulimit,
 )
 from sglang.utils import get_exception_traceback

@@ -347,10 +346,6 @@ def _set_envs_and_config(server_args: ServerArgs):
     # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
     maybe_set_triton_cache_manager()

-    # Set torch compile config
-    if server_args.enable_torch_compile:
-        set_torch_compile_config()
-
     # Set global chat template
     if server_args.chat_template:
         # TODO: replace this with huggingface transformers template
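The net effect of this file's change is that torch.compile configuration no longer happens at server startup; it is applied lazily by the component that actually compiles. A minimal sketch of that pattern (placeholder names throughout; _configure_compile is a hypothetical helper mirroring set_torch_compile_config, not sglang code):

    class Runner:
        def __init__(self, use_torch_compile: bool):
            if use_torch_compile:
                self._configure_compile()  # config applied only when compiling is requested

        def _configure_compile(self):
            import torch._inductor.config
            torch._inductor.config.fx_graph_cache = True  # one of the flags from the diff

    runner = Runner(use_torch_compile=True)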
python/sglang/srt/utils.py (+0 -13)
@@ -622,19 +622,6 @@ def receive_addrs(model_port_args, server_args):
     dist.destroy_process_group()

-
-def set_torch_compile_config():
-    # The following configurations are for torch compile optimizations
-    import torch._dynamo.config
-    import torch._inductor.config
-
-    torch._inductor.config.coordinate_descent_tuning = True
-    torch._inductor.config.triton.unique_kernel_names = True
-    torch._inductor.config.fx_graph_cache = True  # Experimental feature to reduce compilation times, will be on by default in future
-
-    # FIXME: tmp workaround
-    torch._dynamo.config.accumulated_cache_size_limit = 256
-
 def set_ulimit(target_soft_limit=65535):
     resource_type = resource.RLIMIT_NOFILE
     current_soft, current_hard = resource.getrlimit(resource_type)
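Note that the copy removed here used accumulated_cache_size_limit = 256, while the new copy in cuda_graph_runner.py raises it to 1024. A rough illustration of why that limit matters (illustrative, not sglang code; assumes PyTorch 2.x):

    import torch
    import torch._dynamo.config

    # With dynamic=False, each distinct input shape triggers a fresh dynamo
    # compile, and compiled entries count against the accumulated cache limit;
    # too low a limit causes recompilation thrashing when many batch sizes
    # are captured.
    torch._dynamo.config.accumulated_cache_size_limit = 1024  # the new, centralized value

    @torch.compile(dynamic=False)
    def double(x):
        return x * 2

    for bs in (1, 2, 4, 8):  # each batch size compiles a separate entry
        double(torch.randn(bs, 16))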
python/sglang/utils.py (+0 -1)
@@ -6,7 +6,6 @@ import json
 import logging
 import signal
 import sys
-import threading
 import traceback
 import urllib.request
 from concurrent.futures import ThreadPoolExecutor