Unverified commit 6cc38b2b authored by Lianmin Zheng, committed by GitHub

[Minor] Add more type annotations (#1237)

parent 1ece2cda
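
For context before the diff: the annotations added below use two standard typing idioms. First, forward references written as string literals (`"ModelRunner"`, `"GroupCoordinator"`), which type checkers resolve by name so the annotated class never has to be imported at runtime. Second, generic aliases from the `typing` module (`List[int]`, `Callable`), which is why the first hunk adds `from typing import Callable, List`. A minimal sketch of the forward-reference idiom; the module name `model_runner` here is a hypothetical stand-in, not the real import path:

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Evaluated only by type checkers, so no circular import at runtime.
        from model_runner import ModelRunner  # hypothetical path

    class CudaGraphRunner:
        def __init__(self, model_runner: "ModelRunner"):
            # The string annotation is resolved lazily by the type checker.
            self.model_runner = model_runner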
@@ -17,6 +17,7 @@ limitations under the License.

 import bisect
 from contextlib import contextmanager
+from typing import Callable, List

 import torch
 from flashinfer import BatchDecodeWithPagedKVCacheWrapper
@@ -53,12 +54,12 @@ def _to_torch(model: torch.nn.Module, reverse: bool = False):

 @contextmanager
 def patch_model(
-    model: torch.nn.Module, use_compile: bool, tp_group: "GroupCoordinator"
+    model: torch.nn.Module, enable_compile: bool, tp_group: "GroupCoordinator"
 ):
     backup_ca_comm = None

     try:
-        if use_compile:
+        if enable_compile:
             _to_torch(model)
             monkey_patch_vllm_all_gather()
             backup_ca_comm = tp_group.ca_comm
@@ -67,7 +68,7 @@ def patch_model(
         else:
             yield model.forward
     finally:
-        if use_compile:
+        if enable_compile:
             _to_torch(model, reverse=True)
             monkey_patch_vllm_all_gather(reverse=True)
             tp_group.ca_comm = backup_ca_comm
@@ -88,7 +89,7 @@ def set_torch_compile_config():
 class CudaGraphRunner:
     def __init__(
         self,
-        model_runner,
+        model_runner: "ModelRunner",
         max_batch_size_to_capture: int,
         use_torch_compile: bool,
         disable_padding: bool,
@@ -154,13 +155,13 @@ class CudaGraphRunner:
         if use_torch_compile:
             set_torch_compile_config()

-    def can_run(self, batch_size):
+    def can_run(self, batch_size: int):
         if self.disable_padding:
             return batch_size in self.graphs
         else:
             return batch_size <= self.max_bs

-    def capture(self, batch_size_list):
+    def capture(self, batch_size_list: List[int]):
         self.batch_size_list = batch_size_list
         with graph_capture() as graph_capture_context:
             self.stream = graph_capture_context.stream
@@ -181,7 +182,7 @@
             self.output_buffers[bs] = output_buffers
             self.flashinfer_handlers[bs] = flashinfer_handler

-    def capture_one_batch_size(self, bs, forward):
+    def capture_one_batch_size(self, bs: int, forward: Callable):
         graph = torch.cuda.CUDAGraph()
         stream = self.stream
......
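
An aside on the `can_run` hunk above: when `disable_padding` is true, a graph can be reused only for batch sizes that were captured exactly; otherwise any batch size up to `max_bs` is accepted and presumably padded up to the nearest captured size (the file's `bisect` import fits that pattern). A self-contained sketch of that policy under those assumptions; apart from `can_run`, `disable_padding`, `graphs`, `max_bs`, and `batch_size_list`, all names here are hypothetical:

    import bisect
    from typing import Dict, List

    class PaddingPolicySketch:
        # Illustrative stand-in, not the real CudaGraphRunner.
        def __init__(self, batch_size_list: List[int], disable_padding: bool):
            self.batch_size_list = sorted(batch_size_list)
            # Dummy objects stand in for captured CUDA graphs.
            self.graphs: Dict[int, object] = {bs: object() for bs in self.batch_size_list}
            self.max_bs = self.batch_size_list[-1]
            self.disable_padding = disable_padding

        def can_run(self, batch_size: int) -> bool:
            if self.disable_padding:
                return batch_size in self.graphs
            else:
                return batch_size <= self.max_bs

        def padded_batch_size(self, batch_size: int) -> int:
            # Round up to the nearest captured batch size.
            index = bisect.bisect_left(self.batch_size_list, batch_size)
            return self.batch_size_list[index]

For example, with captured sizes [1, 2, 4, 8] and padding enabled, a batch of 3 can run and would be padded to 4; with `disable_padding=True`, only batches of exactly 1, 2, 4, or 8 are accepted.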