"docs/git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "b1de9f1ac2d0da727d96ff3ed4958c461ed704b3"
Unverified commit 123506ee, authored by Sayak Paul and committed by GitHub
Browse files

make parallel loading flag a part of constants. (#12137)

parent 8c48ec05
...@@ -42,9 +42,8 @@ from ..quantizers import DiffusersAutoQuantizer, DiffusersQuantizer ...@@ -42,9 +42,8 @@ from ..quantizers import DiffusersAutoQuantizer, DiffusersQuantizer
from ..quantizers.quantization_config import QuantizationMethod from ..quantizers.quantization_config import QuantizationMethod
from ..utils import ( from ..utils import (
CONFIG_NAME, CONFIG_NAME,
ENV_VARS_TRUE_VALUES,
FLAX_WEIGHTS_NAME, FLAX_WEIGHTS_NAME,
HF_PARALLEL_LOADING_FLAG, HF_ENABLE_PARALLEL_LOADING,
SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME,
SAFETENSORS_WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME,
...@@ -962,7 +961,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): ...@@ -962,7 +961,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None) dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
disable_mmap = kwargs.pop("disable_mmap", False) disable_mmap = kwargs.pop("disable_mmap", False)
is_parallel_loading_enabled = os.environ.get(HF_PARALLEL_LOADING_FLAG, "").upper() in ENV_VARS_TRUE_VALUES is_parallel_loading_enabled = HF_ENABLE_PARALLEL_LOADING
if is_parallel_loading_enabled and not low_cpu_mem_usage: if is_parallel_loading_enabled and not low_cpu_mem_usage:
raise NotImplementedError("Parallel loading is not supported when not using `low_cpu_mem_usage`.") raise NotImplementedError("Parallel loading is not supported when not using `low_cpu_mem_usage`.")
......
...@@ -25,8 +25,8 @@ from .constants import ( ...@@ -25,8 +25,8 @@ from .constants import (
DIFFUSERS_DYNAMIC_MODULE_NAME, DIFFUSERS_DYNAMIC_MODULE_NAME,
FLAX_WEIGHTS_NAME, FLAX_WEIGHTS_NAME,
GGUF_FILE_EXTENSION, GGUF_FILE_EXTENSION,
HF_ENABLE_PARALLEL_LOADING,
HF_MODULES_CACHE, HF_MODULES_CACHE,
HF_PARALLEL_LOADING_FLAG,
HUGGINGFACE_CO_RESOLVE_ENDPOINT, HUGGINGFACE_CO_RESOLVE_ENDPOINT,
MIN_PEFT_VERSION, MIN_PEFT_VERSION,
ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_EXTERNAL_WEIGHTS_NAME,
......
...@@ -44,7 +44,7 @@ DIFFUSERS_REQUEST_TIMEOUT = 60 ...@@ -44,7 +44,7 @@ DIFFUSERS_REQUEST_TIMEOUT = 60
DIFFUSERS_ATTN_BACKEND = os.getenv("DIFFUSERS_ATTN_BACKEND", "native") DIFFUSERS_ATTN_BACKEND = os.getenv("DIFFUSERS_ATTN_BACKEND", "native")
DIFFUSERS_ATTN_CHECKS = os.getenv("DIFFUSERS_ATTN_CHECKS", "0") in ENV_VARS_TRUE_VALUES DIFFUSERS_ATTN_CHECKS = os.getenv("DIFFUSERS_ATTN_CHECKS", "0") in ENV_VARS_TRUE_VALUES
DEFAULT_HF_PARALLEL_LOADING_WORKERS = 8 DEFAULT_HF_PARALLEL_LOADING_WORKERS = 8
HF_PARALLEL_LOADING_FLAG = "HF_ENABLE_PARALLEL_LOADING" HF_ENABLE_PARALLEL_LOADING = os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES
# Below should be `True` if the current version of `peft` and `transformers` are compatible with # Below should be `True` if the current version of `peft` and `transformers` are compatible with
# PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are # PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment