"vscode:/vscode.git/clone" did not exist on "473d5e0a4c4e4735f1c9dc9d783e0374328cca9a"
Unverified Commit 6cc38b2b authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

[Minor] Add more type annotations (#1237)

parent 1ece2cda
......@@ -17,6 +17,7 @@ limitations under the License.
import bisect
from contextlib import contextmanager
from typing import Callable, List
import torch
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
......@@ -53,12 +54,12 @@ def _to_torch(model: torch.nn.Module, reverse: bool = False):
@contextmanager
def patch_model(
model: torch.nn.Module, use_compile: bool, tp_group: "GroupCoordinator"
model: torch.nn.Module, enable_compile: bool, tp_group: "GroupCoordinator"
):
backup_ca_comm = None
try:
if use_compile:
if enable_compile:
_to_torch(model)
monkey_patch_vllm_all_gather()
backup_ca_comm = tp_group.ca_comm
......@@ -67,7 +68,7 @@ def patch_model(
else:
yield model.forward
finally:
if use_compile:
if enable_compile:
_to_torch(model, reverse=True)
monkey_patch_vllm_all_gather(reverse=True)
tp_group.ca_comm = backup_ca_comm
......@@ -88,7 +89,7 @@ def set_torch_compile_config():
class CudaGraphRunner:
def __init__(
self,
model_runner,
model_runner: "ModelRunner",
max_batch_size_to_capture: int,
use_torch_compile: bool,
disable_padding: bool,
......@@ -154,13 +155,13 @@ class CudaGraphRunner:
if use_torch_compile:
set_torch_compile_config()
def can_run(self, batch_size):
def can_run(self, batch_size: int):
if self.disable_padding:
return batch_size in self.graphs
else:
return batch_size <= self.max_bs
def capture(self, batch_size_list):
def capture(self, batch_size_list: List[int]):
self.batch_size_list = batch_size_list
with graph_capture() as graph_capture_context:
self.stream = graph_capture_context.stream
......@@ -181,7 +182,7 @@ class CudaGraphRunner:
self.output_buffers[bs] = output_buffers
self.flashinfer_handlers[bs] = flashinfer_handler
def capture_one_batch_size(self, bs, forward):
def capture_one_batch_size(self, bs: int, forward: Callable):
graph = torch.cuda.CUDAGraph()
stream = self.stream
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment