Unverified Commit cf4b97b2 authored by Charles's avatar Charles Committed by GitHub
Browse files

[perf] Cache version checks (#12399)

parent 7f3e9b86
...@@ -21,6 +21,7 @@ import operator as op ...@@ -21,6 +21,7 @@ import operator as op
import os import os
import sys import sys
from collections import OrderedDict, defaultdict from collections import OrderedDict, defaultdict
from functools import lru_cache as cache
from itertools import chain from itertools import chain
from types import ModuleType from types import ModuleType
from typing import Any, Tuple, Union from typing import Any, Tuple, Union
...@@ -673,6 +674,7 @@ def compare_versions(library_or_version: Union[str, Version], operation: str, re ...@@ -673,6 +674,7 @@ def compare_versions(library_or_version: Union[str, Version], operation: str, re
# This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L338 # This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L338
@cache
def is_torch_version(operation: str, version: str): def is_torch_version(operation: str, version: str):
""" """
Compares the current PyTorch version to a given reference with an operation. Compares the current PyTorch version to a given reference with an operation.
...@@ -686,6 +688,7 @@ def is_torch_version(operation: str, version: str): ...@@ -686,6 +688,7 @@ def is_torch_version(operation: str, version: str):
return compare_versions(parse(_torch_version), operation, version) return compare_versions(parse(_torch_version), operation, version)
@cache
def is_torch_xla_version(operation: str, version: str): def is_torch_xla_version(operation: str, version: str):
""" """
Compares the current torch_xla version to a given reference with an operation. Compares the current torch_xla version to a given reference with an operation.
...@@ -701,6 +704,7 @@ def is_torch_xla_version(operation: str, version: str): ...@@ -701,6 +704,7 @@ def is_torch_xla_version(operation: str, version: str):
return compare_versions(parse(_torch_xla_version), operation, version) return compare_versions(parse(_torch_xla_version), operation, version)
@cache
def is_transformers_version(operation: str, version: str): def is_transformers_version(operation: str, version: str):
""" """
Compares the current Transformers version to a given reference with an operation. Compares the current Transformers version to a given reference with an operation.
...@@ -716,6 +720,7 @@ def is_transformers_version(operation: str, version: str): ...@@ -716,6 +720,7 @@ def is_transformers_version(operation: str, version: str):
return compare_versions(parse(_transformers_version), operation, version) return compare_versions(parse(_transformers_version), operation, version)
@cache
def is_hf_hub_version(operation: str, version: str): def is_hf_hub_version(operation: str, version: str):
""" """
Compares the current Hugging Face Hub version to a given reference with an operation. Compares the current Hugging Face Hub version to a given reference with an operation.
...@@ -731,6 +736,7 @@ def is_hf_hub_version(operation: str, version: str): ...@@ -731,6 +736,7 @@ def is_hf_hub_version(operation: str, version: str):
return compare_versions(parse(_hf_hub_version), operation, version) return compare_versions(parse(_hf_hub_version), operation, version)
@cache
def is_accelerate_version(operation: str, version: str): def is_accelerate_version(operation: str, version: str):
""" """
Compares the current Accelerate version to a given reference with an operation. Compares the current Accelerate version to a given reference with an operation.
...@@ -746,6 +752,7 @@ def is_accelerate_version(operation: str, version: str): ...@@ -746,6 +752,7 @@ def is_accelerate_version(operation: str, version: str):
return compare_versions(parse(_accelerate_version), operation, version) return compare_versions(parse(_accelerate_version), operation, version)
@cache
def is_peft_version(operation: str, version: str): def is_peft_version(operation: str, version: str):
""" """
Compares the current PEFT version to a given reference with an operation. Compares the current PEFT version to a given reference with an operation.
...@@ -761,6 +768,7 @@ def is_peft_version(operation: str, version: str): ...@@ -761,6 +768,7 @@ def is_peft_version(operation: str, version: str):
return compare_versions(parse(_peft_version), operation, version) return compare_versions(parse(_peft_version), operation, version)
@cache
def is_bitsandbytes_version(operation: str, version: str): def is_bitsandbytes_version(operation: str, version: str):
""" """
Args: Args:
...@@ -775,6 +783,7 @@ def is_bitsandbytes_version(operation: str, version: str): ...@@ -775,6 +783,7 @@ def is_bitsandbytes_version(operation: str, version: str):
return compare_versions(parse(_bitsandbytes_version), operation, version) return compare_versions(parse(_bitsandbytes_version), operation, version)
@cache
def is_gguf_version(operation: str, version: str): def is_gguf_version(operation: str, version: str):
""" """
Compares the current Accelerate version to a given reference with an operation. Compares the current Accelerate version to a given reference with an operation.
...@@ -790,6 +799,7 @@ def is_gguf_version(operation: str, version: str): ...@@ -790,6 +799,7 @@ def is_gguf_version(operation: str, version: str):
return compare_versions(parse(_gguf_version), operation, version) return compare_versions(parse(_gguf_version), operation, version)
@cache
def is_torchao_version(operation: str, version: str): def is_torchao_version(operation: str, version: str):
""" """
Compares the current torchao version to a given reference with an operation. Compares the current torchao version to a given reference with an operation.
...@@ -805,6 +815,7 @@ def is_torchao_version(operation: str, version: str): ...@@ -805,6 +815,7 @@ def is_torchao_version(operation: str, version: str):
return compare_versions(parse(_torchao_version), operation, version) return compare_versions(parse(_torchao_version), operation, version)
@cache
def is_k_diffusion_version(operation: str, version: str): def is_k_diffusion_version(operation: str, version: str):
""" """
Compares the current k-diffusion version to a given reference with an operation. Compares the current k-diffusion version to a given reference with an operation.
...@@ -820,6 +831,7 @@ def is_k_diffusion_version(operation: str, version: str): ...@@ -820,6 +831,7 @@ def is_k_diffusion_version(operation: str, version: str):
return compare_versions(parse(_k_diffusion_version), operation, version) return compare_versions(parse(_k_diffusion_version), operation, version)
@cache
def is_optimum_quanto_version(operation: str, version: str): def is_optimum_quanto_version(operation: str, version: str):
""" """
Compares the current Accelerate version to a given reference with an operation. Compares the current Accelerate version to a given reference with an operation.
...@@ -835,6 +847,7 @@ def is_optimum_quanto_version(operation: str, version: str): ...@@ -835,6 +847,7 @@ def is_optimum_quanto_version(operation: str, version: str):
return compare_versions(parse(_optimum_quanto_version), operation, version) return compare_versions(parse(_optimum_quanto_version), operation, version)
@cache
def is_nvidia_modelopt_version(operation: str, version: str): def is_nvidia_modelopt_version(operation: str, version: str):
""" """
Compares the current Nvidia ModelOpt version to a given reference with an operation. Compares the current Nvidia ModelOpt version to a given reference with an operation.
...@@ -850,6 +863,7 @@ def is_nvidia_modelopt_version(operation: str, version: str): ...@@ -850,6 +863,7 @@ def is_nvidia_modelopt_version(operation: str, version: str):
return compare_versions(parse(_nvidia_modelopt_version), operation, version) return compare_versions(parse(_nvidia_modelopt_version), operation, version)
@cache
def is_xformers_version(operation: str, version: str): def is_xformers_version(operation: str, version: str):
""" """
Compares the current xformers version to a given reference with an operation. Compares the current xformers version to a given reference with an operation.
...@@ -865,6 +879,7 @@ def is_xformers_version(operation: str, version: str): ...@@ -865,6 +879,7 @@ def is_xformers_version(operation: str, version: str):
return compare_versions(parse(_xformers_version), operation, version) return compare_versions(parse(_xformers_version), operation, version)
@cache
def is_sageattention_version(operation: str, version: str): def is_sageattention_version(operation: str, version: str):
""" """
Compares the current sageattention version to a given reference with an operation. Compares the current sageattention version to a given reference with an operation.
...@@ -880,6 +895,7 @@ def is_sageattention_version(operation: str, version: str): ...@@ -880,6 +895,7 @@ def is_sageattention_version(operation: str, version: str):
return compare_versions(parse(_sageattention_version), operation, version) return compare_versions(parse(_sageattention_version), operation, version)
@cache
def is_flash_attn_version(operation: str, version: str): def is_flash_attn_version(operation: str, version: str):
""" """
Compares the current flash-attention version to a given reference with an operation. Compares the current flash-attention version to a given reference with an operation.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment