Unverified Commit 38524f71 authored by Tim Moon, committed by GitHub

Replace functools cache with lru_cache (#967)

functools.cache was only added in Python 3.9; functools.lru_cache(maxsize=None) is an equivalent replacement that also works on earlier Python versions.

Signed-off-by: Tim Moon <tmoon@nvidia.com>
parent 744624d0
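For context: on Python 3.9+, functools.cache(fn) returns the same wrapper as functools.lru_cache(maxsize=None)(fn), so the swap in the hunks below keeps behavior identical while restoring Python 3.8 compatibility. A minimal sketch of the equivalence, using a hypothetical expensive_lookup helper that is not part of this change:

    import functools
    import sys

    # Works on Python 3.8+; on 3.9+ this is exactly what functools.cache does.
    @functools.lru_cache(maxsize=None)
    def expensive_lookup(key: str) -> int:
        """Hypothetical cached helper; the body runs once per distinct key."""
        print(f"computing {key}", file=sys.stderr)
        return len(key)

    expensive_lookup("cudnn")             # miss: body executes
    expensive_lookup("cudnn")             # hit: cached result, no recompute
    print(expensive_lookup.cache_info())  # CacheInfo(hits=1, misses=1, maxsize=None, currsize=1)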
@@ -4,19 +4,19 @@
 """Installation script."""
+import functools
+import glob
 import os
 import re
-import glob
 import shutil
 import subprocess
 import sys
-from functools import cache
 from pathlib import Path
 from subprocess import CalledProcessError
 from typing import List, Optional, Tuple
-@cache
+@functools.lru_cache(maxsize=None)
 def debug_build_enabled() -> bool:
     """Whether to build with a debug configuration"""
     for arg in sys.argv:
@@ -138,7 +138,7 @@ def found_pybind11() -> bool:
     return False
-@cache
+@functools.lru_cache(maxsize=None)
 def cuda_path() -> Tuple[str, str]:
     """CUDA root path and NVCC binary path as a tuple.
@@ -2,12 +2,12 @@
 #
 # See LICENSE for license information.
-from typing import List, Tuple, Union
+import functools
+import os
 import pytest
 import subprocess
-import os
-from dataclasses import dataclass, asdict
-from functools import lru_cache
+from dataclasses import asdict, dataclass
+from typing import List, Tuple, Union
 import torch
@@ -46,7 +46,7 @@ te_path = os.getenv("TE_PATH", "/opt/transformerengine")
 mlm_log_dir = os.path.join(te_path, "ci_logs")
-@lru_cache(maxsize=1)
+@functools.lru_cache(maxsize=None)
 def get_parallel_configs() -> List[Tuple[int, int]]:
     """Returns valid combinations of (tp, pp)."""
     sizes = [1, 2, 4]
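Note the one semantic tweak in the hunk above: @lru_cache(maxsize=1) on get_parallel_configs becomes @functools.lru_cache(maxsize=None). Since the function takes no arguments, there is only ever one argument tuple, so both settings cache at most a single entry and the change is behavior-preserving. A small illustrative sketch (the demo function name is hypothetical, not part of this commit):

    import functools

    @functools.lru_cache(maxsize=None)
    def parallel_configs_demo() -> list:
        """Hypothetical zero-argument helper: one argument tuple, one cache slot."""
        return [(1, 1), (2, 1), (1, 2)]

    parallel_configs_demo()                    # miss: body runs once
    parallel_configs_demo()                    # hit: cached list returned
    info = parallel_configs_demo.cache_info()
    assert info.currsize == 1                  # a no-arg function never exceeds one entry
    parallel_configs_demo.cache_clear()        # lru_cache introspection helpers stay available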
@@ -2,16 +2,16 @@
 #
 # See LICENSE for license information.
-import math
 import functools
-from importlib.metadata import version
+import logging
+import math
 import os
+from importlib.metadata import version
 from typing import Any, Dict, List, Tuple, Union
-from pkg_resources import packaging
 import pytest
 import torch
-import logging
+from pkg_resources import packaging
 from transformer_engine.common import recipe
 from transformer_engine.pytorch import TransformerLayer, fp8_autocast, fp8_model_init
@@ -148,21 +148,21 @@ def _is_fused_attention_supported(
     return False, backends
-@functools.cache
+@functools.lru_cache(maxsize=None)
 def _is_flash_attention_2_available() -> bool:
     """Check if flash-attn 2.0+ is available"""
     Version = packaging.version.Version
     return Version(version("flash-attn")) >= Version("2")
-@functools.cache
+@functools.lru_cache(maxsize=None)
 def _is_flash_attention_2_1() -> bool:
     """Check if flash-attn 2.1+ is available"""
     Version = packaging.version.Version
     return Version(version("flash-attn")) >= Version("2.1")
-@functools.cache
+@functools.lru_cache(maxsize=None)
 def _is_flash_attention_2_3() -> bool:
     """Check if flash-attn 2.3+ is available"""
     Version = packaging.version.Version
@@ -3,9 +3,10 @@
 # See LICENSE for license information.
 """Utility functions for Transformer Engine modules"""
-import math
 import functools
+import math
 from typing import Any, Callable, Optional, Tuple
 import torch
 import transformer_engine.pytorch.cpp_extensions as ext
@@ -243,7 +244,7 @@ def is_bf16_compatible() -> None:
     return torch.cuda.get_device_capability()[0] >= 8
-@functools.cache
+@functools.lru_cache(maxsize=None)
 def get_cudnn_version() -> Tuple[int, int, int]:
     """Runtime cuDNN version (major, minor, patch)"""
     encoded_version = ext.get_cudnn_version()