Unverified Commit ffb3d553 authored by Xinyu Chen's avatar Xinyu Chen Committed by GitHub
Browse files

[Model Runner V2] Init cuda graph pool when necessary (#33217)


Signed-off-by: default avatarXinyu Chen <xinyu1.chen@intel.com>
parent fa7e0bfa
...@@ -45,7 +45,9 @@ class CudaGraphManager: ...@@ -45,7 +45,9 @@ class CudaGraphManager:
) )
self.graphs: dict[int, torch.cuda.CUDAGraph] = {} self.graphs: dict[int, torch.cuda.CUDAGraph] = {}
self.pool = torch.cuda.graph_pool_handle() self.pool = None
if self.cudagraph_mode != CUDAGraphMode.NONE:
self.pool = torch.cuda.graph_pool_handle()
self.hidden_states: torch.Tensor | None = None self.hidden_states: torch.Tensor | None = None
def needs_capture(self) -> bool: def needs_capture(self) -> bool:
......
...@@ -44,7 +44,9 @@ class EagleCudaGraphManager: ...@@ -44,7 +44,9 @@ class EagleCudaGraphManager:
) )
self.graphs: dict[int, torch.cuda.CUDAGraph] = {} self.graphs: dict[int, torch.cuda.CUDAGraph] = {}
self.pool = torch.cuda.graph_pool_handle() self.pool = None
if self.cudagraph_mode != CUDAGraphMode.NONE:
self.pool = torch.cuda.graph_pool_handle()
def get_cudagraph_size(self, num_tokens: int) -> int | None: def get_cudagraph_size(self, num_tokens: int) -> int | None:
return self.cudagraph_sizes.get(num_tokens) return self.cudagraph_sizes.get(num_tokens)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment