Unverified Commit 7484e1fc authored by Michael Goin's avatar Michael Goin Committed by GitHub
Browse files

Add cache to cuda get_device_capability (#19436)


Signed-off-by: default avatarmgoin <mgoin64@gmail.com>
parent a2142f01
...@@ -6,7 +6,7 @@ pynvml. However, it should not initialize cuda context. ...@@ -6,7 +6,7 @@ pynvml. However, it should not initialize cuda context.
import os import os
from datetime import timedelta from datetime import timedelta
from functools import wraps from functools import cache, wraps
from typing import TYPE_CHECKING, Callable, Optional, TypeVar, Union from typing import TYPE_CHECKING, Callable, Optional, TypeVar, Union
import torch import torch
...@@ -389,6 +389,7 @@ class CudaPlatformBase(Platform): ...@@ -389,6 +389,7 @@ class CudaPlatformBase(Platform):
class NvmlCudaPlatform(CudaPlatformBase): class NvmlCudaPlatform(CudaPlatformBase):
@classmethod @classmethod
@cache
@with_nvml_context @with_nvml_context
def get_device_capability(cls, def get_device_capability(cls,
device_id: int = 0 device_id: int = 0
...@@ -486,6 +487,7 @@ class NvmlCudaPlatform(CudaPlatformBase): ...@@ -486,6 +487,7 @@ class NvmlCudaPlatform(CudaPlatformBase):
class NonNvmlCudaPlatform(CudaPlatformBase): class NonNvmlCudaPlatform(CudaPlatformBase):
@classmethod @classmethod
@cache
def get_device_capability(cls, device_id: int = 0) -> DeviceCapability: def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
major, minor = torch.cuda.get_device_capability(device_id) major, minor = torch.cuda.get_device_capability(device_id)
return DeviceCapability(major=major, minor=minor) return DeviceCapability(major=major, minor=minor)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment