Unverified Commit 7212f35d authored by Sayak Paul, committed by GitHub

[single file] enable telemetry for single file loading when using GGUF. (#11284)

* enable telemetry for single file loading when using GGUF.

* quality
parent 3252d7ad
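
This change threads a telemetry-bearing `user_agent` dict through the single-file loading path, so Hub downloads report the diffusers version, the `single_file` file type, and, when a quantization config is present, the quantization method (e.g. `gguf`). A sketch of the user-facing call path this instruments; the repo and file name below are illustrative, not taken from the commit:

```python
import torch

from diffusers import FluxTransformer2DModel, GGUFQuantizationConfig

# Loading a GGUF single-file checkpoint; with this commit, the Hub request
# made under the hood carries {"diffusers": <version>, "file_type": "single_file",
# "framework": "pytorch", "quant": "gguf"} as its user-agent payload.
transformer = FluxTransformer2DModel.from_single_file(
    "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf",
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)
```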
src/diffusers/loaders/single_file_model.py

```diff
@@ -21,6 +21,7 @@ import torch
 from huggingface_hub.utils import validate_hf_hub_args
 from typing_extensions import Self
 
+from .. import __version__
 from ..quantizers import DiffusersAutoQuantizer
 from ..utils import deprecate, is_accelerate_available, logging
 from .single_file_utils import (
@@ -260,6 +261,11 @@ class FromOriginalModelMixin:
         device = kwargs.pop("device", None)
         disable_mmap = kwargs.pop("disable_mmap", False)
 
+        user_agent = {"diffusers": __version__, "file_type": "single_file", "framework": "pytorch"}
+        # In order to ensure popular quantization methods are supported. Can be disabled with `disable_telemetry`.
+        if quantization_config is not None:
+            user_agent["quant"] = quantization_config.quant_method.value
+
         if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
             torch_dtype = torch.float32
             logger.warning(
@@ -278,6 +284,7 @@ class FromOriginalModelMixin:
             local_files_only=local_files_only,
             revision=revision,
             disable_mmap=disable_mmap,
+            user_agent=user_agent,
         )
 
         if quantization_config is not None:
             hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config)
```
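
A minimal standalone sketch of the tagging logic added above; `GGUFQuantizationConfig` is the real diffusers config class, and its `quant_method.value` resolves to the plain string `"gguf"`:

```python
import torch

from diffusers import GGUFQuantizationConfig, __version__

# Mirrors the snippet added to from_single_file above.
user_agent = {"diffusers": __version__, "file_type": "single_file", "framework": "pytorch"}
quantization_config = GGUFQuantizationConfig(compute_dtype=torch.bfloat16)
if quantization_config is not None:
    user_agent["quant"] = quantization_config.quant_method.value

print(user_agent["quant"])  # "gguf"
```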
src/diffusers/loaders/single_file_utils.py

```diff
@@ -405,13 +405,16 @@ def load_single_file_checkpoint(
     local_files_only=None,
     revision=None,
     disable_mmap=False,
+    user_agent=None,
 ):
+    if user_agent is None:
+        user_agent = {"file_type": "single_file", "framework": "pytorch"}
+
     if os.path.isfile(pretrained_model_link_or_path):
         pretrained_model_link_or_path = pretrained_model_link_or_path
     else:
         repo_id, weights_name = _extract_repo_id_and_weights_name(pretrained_model_link_or_path)
-        user_agent = {"file_type": "single_file", "framework": "pytorch"}
         pretrained_model_link_or_path = _get_model_file(
             repo_id,
             weights_name=weights_name,
```
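
Downstream, `_get_model_file` forwards `user_agent` to `huggingface_hub`, which folds the dict into the HTTP `user-agent` header. A hedged illustration using the public `build_hf_headers` helper; the exact rendered string depends on the installed versions (the version numbers below are made up), and telemetry can be suppressed with the `HF_HUB_DISABLE_TELEMETRY` environment variable:

```python
from huggingface_hub.utils import build_hf_headers

# Dict entries are appended to the user-agent header as "key/value" pairs.
headers = build_hf_headers(
    user_agent={"diffusers": "0.33.0", "file_type": "single_file", "framework": "pytorch", "quant": "gguf"}
)
print(headers["user-agent"])
# e.g. "unknown/None; hf_hub/0.30.2; python/3.10; diffusers/0.33.0; file_type/single_file; framework/pytorch; quant/gguf"
```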