"googlemock/include/gmock/vscode:/vscode.git/clone" did not exist on "e26771776b26b690c98edd99a949ae6c8638e03b"
Unverified Commit 38524f71 authored by Tim Moon's avatar Tim Moon Committed by GitHub
Browse files

Replace functools cache with lru_cache (#967)



cache was added in Python 3.9.
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>
parent 744624d0
...@@ -4,19 +4,19 @@ ...@@ -4,19 +4,19 @@
"""Installation script.""" """Installation script."""
import functools
import glob
import os import os
import re import re
import glob
import shutil import shutil
import subprocess import subprocess
import sys import sys
from functools import cache
from pathlib import Path from pathlib import Path
from subprocess import CalledProcessError from subprocess import CalledProcessError
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
@cache @functools.lru_cache(maxsize=None)
def debug_build_enabled() -> bool: def debug_build_enabled() -> bool:
"""Whether to build with a debug configuration""" """Whether to build with a debug configuration"""
for arg in sys.argv: for arg in sys.argv:
...@@ -138,7 +138,7 @@ def found_pybind11() -> bool: ...@@ -138,7 +138,7 @@ def found_pybind11() -> bool:
return False return False
@cache @functools.lru_cache(maxsize=None)
def cuda_path() -> Tuple[str, str]: def cuda_path() -> Tuple[str, str]:
"""CUDA root path and NVCC binary path as a tuple. """CUDA root path and NVCC binary path as a tuple.
......
...@@ -2,12 +2,12 @@ ...@@ -2,12 +2,12 @@
# #
# See LICENSE for license information. # See LICENSE for license information.
from typing import List, Tuple, Union import functools
import os
import pytest import pytest
import subprocess import subprocess
import os from dataclasses import asdict, dataclass
from dataclasses import dataclass, asdict from typing import List, Tuple, Union
from functools import lru_cache
import torch import torch
...@@ -46,7 +46,7 @@ te_path = os.getenv("TE_PATH", "/opt/transformerengine") ...@@ -46,7 +46,7 @@ te_path = os.getenv("TE_PATH", "/opt/transformerengine")
mlm_log_dir = os.path.join(te_path, "ci_logs") mlm_log_dir = os.path.join(te_path, "ci_logs")
@lru_cache(maxsize=1) @functools.lru_cache(maxsize=None)
def get_parallel_configs() -> List[Tuple[int, int]]: def get_parallel_configs() -> List[Tuple[int, int]]:
"""Returns valid combinations of (tp, pp).""" """Returns valid combinations of (tp, pp)."""
sizes = [1, 2, 4] sizes = [1, 2, 4]
......
...@@ -2,16 +2,16 @@ ...@@ -2,16 +2,16 @@
# #
# See LICENSE for license information. # See LICENSE for license information.
import math
import functools import functools
from importlib.metadata import version import logging
import math
import os import os
from importlib.metadata import version
from typing import Any, Dict, List, Tuple, Union from typing import Any, Dict, List, Tuple, Union
from pkg_resources import packaging
import pytest import pytest
import torch import torch
import logging from pkg_resources import packaging
from transformer_engine.common import recipe from transformer_engine.common import recipe
from transformer_engine.pytorch import TransformerLayer, fp8_autocast, fp8_model_init from transformer_engine.pytorch import TransformerLayer, fp8_autocast, fp8_model_init
...@@ -148,21 +148,21 @@ def _is_fused_attention_supported( ...@@ -148,21 +148,21 @@ def _is_fused_attention_supported(
return False, backends return False, backends
@functools.cache @functools.lru_cache(maxsize=None)
def _is_flash_attention_2_available() -> bool: def _is_flash_attention_2_available() -> bool:
"""Check if flash-attn 2.0+ is available""" """Check if flash-attn 2.0+ is available"""
Version = packaging.version.Version Version = packaging.version.Version
return Version(version("flash-attn")) >= Version("2") return Version(version("flash-attn")) >= Version("2")
@functools.cache @functools.lru_cache(maxsize=None)
def _is_flash_attention_2_1() -> bool: def _is_flash_attention_2_1() -> bool:
"""Check if flash-attn 2.1+ is available""" """Check if flash-attn 2.1+ is available"""
Version = packaging.version.Version Version = packaging.version.Version
return Version(version("flash-attn")) >= Version("2.1") return Version(version("flash-attn")) >= Version("2.1")
@functools.cache @functools.lru_cache(maxsize=None)
def _is_flash_attention_2_3() -> bool: def _is_flash_attention_2_3() -> bool:
"""Check if flash-attn 2.3+ is available""" """Check if flash-attn 2.3+ is available"""
Version = packaging.version.Version Version = packaging.version.Version
......
...@@ -3,9 +3,10 @@ ...@@ -3,9 +3,10 @@
# See LICENSE for license information. # See LICENSE for license information.
"""Utility functions for Transformer Engine modules""" """Utility functions for Transformer Engine modules"""
import math
import functools import functools
import math
from typing import Any, Callable, Optional, Tuple from typing import Any, Callable, Optional, Tuple
import torch import torch
import transformer_engine.pytorch.cpp_extensions as ext import transformer_engine.pytorch.cpp_extensions as ext
...@@ -243,7 +244,7 @@ def is_bf16_compatible() -> None: ...@@ -243,7 +244,7 @@ def is_bf16_compatible() -> None:
return torch.cuda.get_device_capability()[0] >= 8 return torch.cuda.get_device_capability()[0] >= 8
@functools.cache @functools.lru_cache(maxsize=None)
def get_cudnn_version() -> Tuple[int, int, int]: def get_cudnn_version() -> Tuple[int, int, int]:
"""Runtime cuDNN version (major, minor, patch)""" """Runtime cuDNN version (major, minor, patch)"""
encoded_version = ext.get_cudnn_version() encoded_version = ext.get_cudnn_version()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment