Unverified Commit 7484e1fc authored by Michael Goin's avatar Michael Goin Committed by GitHub
Browse files

Add cache to cuda get_device_capability (#19436)


Signed-off-by: default avatarmgoin <mgoin64@gmail.com>
parent a2142f01
......@@ -6,7 +6,7 @@ pynvml. However, it should not initialize cuda context.
import os
from datetime import timedelta
from functools import wraps
from functools import cache, wraps
from typing import TYPE_CHECKING, Callable, Optional, TypeVar, Union
import torch
......@@ -389,6 +389,7 @@ class CudaPlatformBase(Platform):
class NvmlCudaPlatform(CudaPlatformBase):
@classmethod
@cache
@with_nvml_context
def get_device_capability(cls,
device_id: int = 0
......@@ -486,6 +487,7 @@ class NvmlCudaPlatform(CudaPlatformBase):
class NonNvmlCudaPlatform(CudaPlatformBase):
@classmethod
@cache
def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
major, minor = torch.cuda.get_device_capability(device_id)
return DeviceCapability(major=major, minor=minor)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment