Unverified Commit 1606eb99 authored by Anton Lozhkov's avatar Anton Lozhkov Committed by GitHub
Browse files

Fix pipelines user_agent, ignore CI requests (#1058)

* Fix pipelines user_agent, ignore CI requests

* fix circular import

* N/A versions

* N/A versions
parent 82d56cf1
...@@ -10,6 +10,7 @@ concurrency: ...@@ -10,6 +10,7 @@ concurrency:
cancel-in-progress: true cancel-in-progress: true
env: env:
DIFFUSERS_IS_CI: yes
OMP_NUM_THREADS: 8 OMP_NUM_THREADS: 8
MKL_NUM_THREADS: 8 MKL_NUM_THREADS: 8
PYTEST_TIMEOUT: 60 PYTEST_TIMEOUT: 60
......
...@@ -6,6 +6,7 @@ on: ...@@ -6,6 +6,7 @@ on:
- main - main
env: env:
DIFFUSERS_IS_CI: yes
HF_HOME: /mnt/cache HF_HOME: /mnt/cache
OMP_NUM_THREADS: 8 OMP_NUM_THREADS: 8
MKL_NUM_THREADS: 8 MKL_NUM_THREADS: 8
......
...@@ -16,13 +16,25 @@ ...@@ -16,13 +16,25 @@
import os import os
import shutil import shutil
import sys
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Dict, Optional, Union
from uuid import uuid4
from huggingface_hub import HfFolder, Repository, whoami from huggingface_hub import HfFolder, Repository, whoami
from .pipeline_utils import DiffusionPipeline from . import __version__
from .utils import deprecate, is_modelcards_available, logging from .utils import ENV_VARS_TRUE_VALUES, deprecate, logging
from .utils.import_utils import (
_flax_version,
_jax_version,
_onnxruntime_version,
_torch_version,
is_flax_available,
is_modelcards_available,
is_onnx_available,
is_torch_available,
)
if is_modelcards_available(): if is_modelcards_available():
...@@ -33,6 +45,32 @@ logger = logging.get_logger(__name__) ...@@ -33,6 +45,32 @@ logger = logging.get_logger(__name__)
MODEL_CARD_TEMPLATE_PATH = Path(__file__).parent / "utils" / "model_card_template.md" MODEL_CARD_TEMPLATE_PATH = Path(__file__).parent / "utils" / "model_card_template.md"
SESSION_ID = uuid4().hex
DISABLE_TELEMETRY = os.getenv("DISABLE_TELEMETRY", "").upper() in ENV_VARS_TRUE_VALUES
def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
    """
    Formats a user-agent string with basic info about a request.
    """
    # Base identification: library version, Python version, and a per-process session id.
    parts = [f"diffusers/{__version__}", f"python/{sys.version.split()[0]}", f"session_id/{SESSION_ID}"]
    if DISABLE_TELEMETRY:
        # User opted out: report nothing beyond the base string plus the opt-out marker.
        return "; ".join(parts) + "; telemetry/off"
    # Framework versions are appended only for frameworks that are actually installed.
    if is_torch_available():
        parts.append(f"torch/{_torch_version}")
    if is_flax_available():
        parts.append(f"jax/{_jax_version}")
        parts.append(f"flax/{_flax_version}")
    if is_onnx_available():
        parts.append(f"onnxruntime/{_onnxruntime_version}")
    # CI will set this value to True
    if os.environ.get("DIFFUSERS_IS_CI", "").upper() in ENV_VARS_TRUE_VALUES:
        parts.append("is_ci/true")
    # Caller-supplied extras: a dict becomes "key/value" pairs, a string is appended verbatim.
    if isinstance(user_agent, dict):
        parts.extend(f"{key}/{value}" for key, value in user_agent.items())
    elif isinstance(user_agent, str):
        parts.append(user_agent)
    return "; ".join(parts)
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
...@@ -101,7 +139,7 @@ def init_git_repo(args, at_init: bool = False): ...@@ -101,7 +139,7 @@ def init_git_repo(args, at_init: bool = False):
def push_to_hub( def push_to_hub(
args, args,
pipeline: DiffusionPipeline, pipeline,
repo: Repository, repo: Repository,
commit_message: Optional[str] = "End of training", commit_message: Optional[str] = "End of training",
blocking: bool = True, blocking: bool = True,
......
...@@ -29,6 +29,7 @@ from PIL import Image ...@@ -29,6 +29,7 @@ from PIL import Image
from tqdm.auto import tqdm from tqdm.auto import tqdm
from .configuration_utils import ConfigMixin from .configuration_utils import ConfigMixin
from .hub_utils import http_user_agent
from .modeling_flax_utils import FLAX_WEIGHTS_NAME, FlaxModelMixin from .modeling_flax_utils import FLAX_WEIGHTS_NAME, FlaxModelMixin
from .schedulers.scheduling_utils_flax import SCHEDULER_CONFIG_NAME, FlaxSchedulerMixin from .schedulers.scheduling_utils_flax import SCHEDULER_CONFIG_NAME, FlaxSchedulerMixin
from .utils import CONFIG_NAME, DIFFUSERS_CACHE, BaseOutput, is_transformers_available, logging from .utils import CONFIG_NAME, DIFFUSERS_CACHE, BaseOutput, is_transformers_available, logging
...@@ -301,6 +302,13 @@ class FlaxDiffusionPipeline(ConfigMixin): ...@@ -301,6 +302,13 @@ class FlaxDiffusionPipeline(ConfigMixin):
allow_patterns = [os.path.join(k, "*") for k in folder_names] allow_patterns = [os.path.join(k, "*") for k in folder_names]
allow_patterns += [FLAX_WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, cls.config_name] allow_patterns += [FLAX_WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, cls.config_name]
if cls != FlaxDiffusionPipeline:
requested_pipeline_class = cls.__name__
else:
requested_pipeline_class = config_dict.get("_class_name", cls.__name__)
user_agent = {"pipeline_class": requested_pipeline_class}
user_agent = http_user_agent(user_agent)
# download all allow_patterns # download all allow_patterns
cached_folder = snapshot_download( cached_folder = snapshot_download(
pretrained_model_name_or_path, pretrained_model_name_or_path,
...@@ -311,6 +319,7 @@ class FlaxDiffusionPipeline(ConfigMixin): ...@@ -311,6 +319,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
use_auth_token=use_auth_token, use_auth_token=use_auth_token,
revision=revision, revision=revision,
allow_patterns=allow_patterns, allow_patterns=allow_patterns,
user_agent=user_agent,
) )
else: else:
cached_folder = pretrained_model_name_or_path cached_folder = pretrained_model_name_or_path
......
...@@ -30,9 +30,9 @@ from packaging import version ...@@ -30,9 +30,9 @@ from packaging import version
from PIL import Image from PIL import Image
from tqdm.auto import tqdm from tqdm.auto import tqdm
from . import __version__
from .configuration_utils import ConfigMixin from .configuration_utils import ConfigMixin
from .dynamic_modules_utils import get_class_from_dynamic_module from .dynamic_modules_utils import get_class_from_dynamic_module
from .hub_utils import http_user_agent
from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from .utils import ( from .utils import (
CONFIG_NAME, CONFIG_NAME,
...@@ -398,10 +398,14 @@ class DiffusionPipeline(ConfigMixin): ...@@ -398,10 +398,14 @@ class DiffusionPipeline(ConfigMixin):
if custom_pipeline is not None: if custom_pipeline is not None:
allow_patterns += [CUSTOM_PIPELINE_FILE_NAME] allow_patterns += [CUSTOM_PIPELINE_FILE_NAME]
requested_pipeline_class = config_dict.get("_class_name", cls.__name__) if cls != DiffusionPipeline:
user_agent = {"diffusers": __version__, "pipeline_class": requested_pipeline_class} requested_pipeline_class = cls.__name__
else:
requested_pipeline_class = config_dict.get("_class_name", cls.__name__)
user_agent = {"pipeline_class": requested_pipeline_class}
if custom_pipeline is not None: if custom_pipeline is not None:
user_agent["custom_pipeline"] = custom_pipeline user_agent["custom_pipeline"] = custom_pipeline
user_agent = http_user_agent(user_agent)
# download all allow_patterns # download all allow_patterns
cached_folder = snapshot_download( cached_folder = snapshot_download(
......
...@@ -90,7 +90,8 @@ else: ...@@ -90,7 +90,8 @@ else:
logger.info("Disabling Tensorflow because USE_TORCH is set") logger.info("Disabling Tensorflow because USE_TORCH is set")
_tf_available = False _tf_available = False
_jax_version = "N/A"
_flax_version = "N/A"
if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES: if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
_flax_available = importlib.util.find_spec("jax") is not None and importlib.util.find_spec("flax") is not None _flax_available = importlib.util.find_spec("jax") is not None and importlib.util.find_spec("flax") is not None
if _flax_available: if _flax_available:
...@@ -136,6 +137,7 @@ except importlib_metadata.PackageNotFoundError: ...@@ -136,6 +137,7 @@ except importlib_metadata.PackageNotFoundError:
_modelcards_available = False _modelcards_available = False
_onnxruntime_version = "N/A"
_onnx_available = importlib.util.find_spec("onnxruntime") is not None _onnx_available = importlib.util.find_spec("onnxruntime") is not None
if _onnx_available: if _onnx_available:
candidates = ("onnxruntime", "onnxruntime-gpu", "onnxruntime-directml", "onnxruntime-openvino") candidates = ("onnxruntime", "onnxruntime-gpu", "onnxruntime-directml", "onnxruntime-openvino")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment