Unverified Commit 20478c4d authored by Simon Mo's avatar Simon Mo Committed by GitHub
Browse files

Use lru_cache for some environment detection utils (#3508)

parent 63e8b28a
...@@ -11,7 +11,7 @@ from packaging.version import parse, Version ...@@ -11,7 +11,7 @@ from packaging.version import parse, Version
import psutil import psutil
import torch import torch
import asyncio import asyncio
from functools import partial from functools import partial, lru_cache
from typing import ( from typing import (
Awaitable, Awaitable,
Callable, Callable,
...@@ -120,6 +120,7 @@ def is_hip() -> bool: ...@@ -120,6 +120,7 @@ def is_hip() -> bool:
return torch.version.hip is not None return torch.version.hip is not None
@lru_cache(maxsize=None)
def is_neuron() -> bool: def is_neuron() -> bool:
try: try:
import transformers_neuronx import transformers_neuronx
...@@ -128,6 +129,7 @@ def is_neuron() -> bool: ...@@ -128,6 +129,7 @@ def is_neuron() -> bool:
return transformers_neuronx is not None return transformers_neuronx is not None
@lru_cache(maxsize=None)
def get_max_shared_memory_bytes(gpu: int = 0) -> int: def get_max_shared_memory_bytes(gpu: int = 0) -> int:
"""Returns the maximum shared memory per thread block in bytes.""" """Returns the maximum shared memory per thread block in bytes."""
# NOTE: This import statement should be executed lazily since # NOTE: This import statement should be executed lazily since
...@@ -151,6 +153,7 @@ def random_uuid() -> str: ...@@ -151,6 +153,7 @@ def random_uuid() -> str:
return str(uuid.uuid4().hex) return str(uuid.uuid4().hex)
@lru_cache(maxsize=None)
def in_wsl() -> bool: def in_wsl() -> bool:
# Reference: https://github.com/microsoft/WSL/issues/4071 # Reference: https://github.com/microsoft/WSL/issues/4071
return "microsoft" in " ".join(uname()).lower() return "microsoft" in " ".join(uname()).lower()
...@@ -225,6 +228,7 @@ def set_cuda_visible_devices(device_ids: List[int]) -> None: ...@@ -225,6 +228,7 @@ def set_cuda_visible_devices(device_ids: List[int]) -> None:
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, device_ids)) os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, device_ids))
@lru_cache(maxsize=None)
def get_nvcc_cuda_version() -> Optional[Version]: def get_nvcc_cuda_version() -> Optional[Version]:
cuda_home = os.environ.get('CUDA_HOME') cuda_home = os.environ.get('CUDA_HOME')
if not cuda_home: if not cuda_home:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment