Unverified commit f84a472a, authored by Sage Moore, committed by GitHub
Browse files

Suppress benign cuBLAS warning when capturing cudagraphs with DBO (#25596)


Signed-off-by: Sage Moore <sage@neuralmagic.com>
parent 54e42b72
...@@ -104,6 +104,7 @@ class UBatchWrapper: ...@@ -104,6 +104,7 @@ class UBatchWrapper:
self.graph_pool = current_platform.get_global_graph_pool() self.graph_pool = current_platform.get_global_graph_pool()
self.sm_control = self._create_sm_control_context(vllm_config) self.sm_control = self._create_sm_control_context(vllm_config)
self.device = device
@staticmethod @staticmethod
def _create_sm_control_context(vllm_config: VllmConfig): def _create_sm_control_context(vllm_config: VllmConfig):
...@@ -168,6 +169,7 @@ class UBatchWrapper: ...@@ -168,6 +169,7 @@ class UBatchWrapper:
@torch.inference_mode() @torch.inference_mode()
def _capture_ubatch_thread(results, ubatch_metadata): def _capture_ubatch_thread(results, ubatch_metadata):
torch.cuda.set_device(self.device)
ubatch_context = ubatch_metadata.context ubatch_context = ubatch_metadata.context
with torch.cuda.stream(ubatch_context.compute_stream): with torch.cuda.stream(ubatch_context.compute_stream):
_ = torch.cuda.current_blas_handle() _ = torch.cuda.current_blas_handle()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment