sglang · Commits · 6cc38b2b

Commit 6cc38b2b (Unverified)
Authored Aug 28, 2024 by Lianmin Zheng; committed by GitHub, Aug 28, 2024
Parent: 1ece2cda

[Minor] Add more type annotations (#1237)
1 changed file with 8 additions and 7 deletions:

python/sglang/srt/model_executor/cuda_graph_runner.py (+8, -7)
@@ -17,6 +17,7 @@ limitations under the License.
 import bisect
 from contextlib import contextmanager
+from typing import Callable, List

 import torch
 from flashinfer import BatchDecodeWithPagedKVCacheWrapper
@@ -53,12 +54,12 @@ def _to_torch(model: torch.nn.Module, reverse: bool = False):
 @contextmanager
 def patch_model(
-    model: torch.nn.Module, use_compile: bool, tp_group: "GroupCoordinator"
+    model: torch.nn.Module, enable_compile: bool, tp_group: "GroupCoordinator"
 ):
     backup_ca_comm = None

     try:
-        if use_compile:
+        if enable_compile:
             _to_torch(model)
             monkey_patch_vllm_all_gather()
             backup_ca_comm = tp_group.ca_comm
@@ -67,7 +68,7 @@ def patch_model(
         else:
             yield model.forward
     finally:
-        if use_compile:
+        if enable_compile:
             _to_torch(model, reverse=True)
             monkey_patch_vllm_all_gather(reverse=True)
             tp_group.ca_comm = backup_ca_comm
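The renamed `enable_compile` flag guards a patch-and-restore context manager: state is mutated before the `yield` and undone in `finally`, so an exception in the caller cannot leave the model half-patched. A minimal sketch of the same shape (the `patched_forward` name and the `torch.compile` call are illustrative assumptions; the real `patch_model` also swaps vLLM's all-gather and the `tp_group.ca_comm` communicator, omitted here):

from contextlib import contextmanager

import torch

@contextmanager
def patched_forward(model: torch.nn.Module, enable_compile: bool):
    # Keep a handle to the original forward so it can always be restored.
    original_forward = model.forward
    try:
        if enable_compile:
            # torch.compile wraps forward; the wrapped version shadows the
            # original as an instance attribute until the finally block runs.
            model.forward = torch.compile(original_forward)
        yield model.forward
    finally:
        # Executes on normal exit and on exceptions alike.
        model.forward = original_forward

model = torch.nn.Linear(4, 4)
with patched_forward(model, enable_compile=False) as fwd:
    out = fwd(torch.randn(2, 4))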
@@ -88,7 +89,7 @@ def set_torch_compile_config():
 class CudaGraphRunner:
     def __init__(
         self,
-        model_runner,
+        model_runner: "ModelRunner",
         max_batch_size_to_capture: int,
         use_torch_compile: bool,
         disable_padding: bool,
@@ -154,13 +155,13 @@ class CudaGraphRunner:
         if use_torch_compile:
             set_torch_compile_config()

-    def can_run(self, batch_size):
+    def can_run(self, batch_size: int):
         if self.disable_padding:
             return batch_size in self.graphs
         else:
             return batch_size <= self.max_bs

-    def capture(self, batch_size_list):
+    def capture(self, batch_size_list: List[int]):
         self.batch_size_list = batch_size_list
         with graph_capture() as graph_capture_context:
             self.stream = graph_capture_context.stream
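The newly annotated `can_run` makes the padding contract explicit: with `disable_padding`, only batch sizes captured exactly (keys of `self.graphs`) can run; otherwise any batch up to `max_bs` is runnable, because inputs can be padded up to the next captured size. Given the `bisect` import at the top of the file, the lookup plausibly works like this sketch (`padded_batch_size` is a hypothetical helper, not part of this diff):

import bisect
from typing import List

def padded_batch_size(batch_size_list: List[int], batch_size: int) -> int:
    # batch_size_list is sorted ascending; find the smallest captured
    # size >= the requested one, then pad inputs up to it and replay
    # that size's graph.
    index = bisect.bisect_left(batch_size_list, batch_size)
    return batch_size_list[index]

assert padded_batch_size([1, 2, 4, 8, 16], 3) == 4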
@@ -181,7 +182,7 @@ class CudaGraphRunner:
             self.output_buffers[bs] = output_buffers
             self.flashinfer_handlers[bs] = flashinfer_handler

-    def capture_one_batch_size(self, bs, forward):
+    def capture_one_batch_size(self, bs: int, forward: Callable):
         graph = torch.cuda.CUDAGraph()
         stream = self.stream
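`capture_one_batch_size` records one `torch.cuda.CUDAGraph` per batch size, which is why input and output buffers are keyed by `bs`: replay reuses fixed memory addresses, so only buffer contents may change between runs. The standard PyTorch capture/replay recipe is sketched below, with a toy model standing in for the real forward (the runner's warm-up details and flashinfer handler wiring are omitted):

import torch

model = torch.nn.Linear(16, 16).cuda()
static_input = torch.zeros(8, 16, device="cuda")  # fixed-address buffer

# Warm up on a side stream so one-time setup is not recorded into the graph.
side_stream = torch.cuda.Stream()
side_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side_stream):
    for _ in range(3):
        model(static_input)
torch.cuda.current_stream().wait_stream(side_stream)

# Capture: kernels launched in this block are recorded, not executed.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_output = model(static_input)

# Replay: refill the static buffer in place, then relaunch the whole graph.
static_input.copy_(torch.randn(8, 16, device="cuda"))
graph.replay()
print(static_output)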