Unverified Commit db19a9d9 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[DiffusionPipeline.from_pretrained] add warning when passing unused k… (#870)

[DiffusionPipeline.from_pretrained] add warning when passing unused kwargs
parent 4a76e5d4
...@@ -350,6 +350,7 @@ class DiffusionPipeline(ConfigMixin): ...@@ -350,6 +350,7 @@ class DiffusionPipeline(ConfigMixin):
""" """
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
resume_download = kwargs.pop("resume_download", False) resume_download = kwargs.pop("resume_download", False)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None) proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", False) local_files_only = kwargs.pop("local_files_only", False)
use_auth_token = kwargs.pop("use_auth_token", None) use_auth_token = kwargs.pop("use_auth_token", None)
...@@ -367,6 +368,7 @@ class DiffusionPipeline(ConfigMixin): ...@@ -367,6 +368,7 @@ class DiffusionPipeline(ConfigMixin):
pretrained_model_name_or_path, pretrained_model_name_or_path,
cache_dir=cache_dir, cache_dir=cache_dir,
resume_download=resume_download, resume_download=resume_download,
force_download=force_download,
proxies=proxies, proxies=proxies,
local_files_only=local_files_only, local_files_only=local_files_only,
use_auth_token=use_auth_token, use_auth_token=use_auth_token,
...@@ -439,7 +441,10 @@ class DiffusionPipeline(ConfigMixin): ...@@ -439,7 +441,10 @@ class DiffusionPipeline(ConfigMixin):
expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys()) - set(["self"]) expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys()) - set(["self"])
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs} passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs) init_dict, unused_kwargs = pipeline_class.extract_init_dict(config_dict, **kwargs)
if len(unused_kwargs) > 0:
logger.warning(f"Keyword arguments {unused_kwargs} not recognized.")
init_kwargs = {} init_kwargs = {}
......
import inspect import inspect
import logging
import os import os
import random import random
import re import re
import unittest import unittest
from distutils.util import strtobool from distutils.util import strtobool
from io import StringIO
from pathlib import Path from pathlib import Path
from typing import Union from typing import Union
...@@ -284,3 +286,42 @@ def pytest_terminal_summary_main(tr, id): ...@@ -284,3 +286,42 @@ def pytest_terminal_summary_main(tr, id):
tr._tw = orig_writer tr._tw = orig_writer
tr.reportchars = orig_reportchars tr.reportchars = orig_reportchars
config.option.tbstyle = orig_tbstyle config.option.tbstyle = orig_tbstyle
class CaptureLogger:
"""
Args:
Context manager to capture `logging` streams
logger: 'logging` logger object
Returns:
The captured output is available via `self.out`
Example:
```python
>>> from diffusers import logging
>>> from diffusers.testing_utils import CaptureLogger
>>> msg = "Testing 1, 2, 3"
>>> logging.set_verbosity_info()
>>> logger = logging.get_logger("diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.py")
>>> with CaptureLogger(logger) as cl:
... logger.info(msg)
>>> assert cl.out, msg + "\n"
```
"""
def __init__(self, logger):
self.logger = logger
self.io = StringIO()
self.sh = logging.StreamHandler(self.io)
self.out = ""
def __enter__(self):
self.logger.addHandler(self.sh)
return self
def __exit__(self, *exc):
self.logger.removeHandler(self.sh)
self.out = self.io.getvalue()
def __repr__(self):
return f"captured: {self.out}\n"
...@@ -51,11 +51,12 @@ from diffusers import ( ...@@ -51,11 +51,12 @@ from diffusers import (
UNet2DConditionModel, UNet2DConditionModel,
UNet2DModel, UNet2DModel,
VQModel, VQModel,
logging,
) )
from diffusers.pipeline_utils import DiffusionPipeline from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, load_image, slow, torch_device from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, load_image, slow, torch_device
from diffusers.utils.testing_utils import get_tests_dir from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir
from packaging import version from packaging import version
from PIL import Image from PIL import Image
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPTokenizer from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPTokenizer
...@@ -1473,6 +1474,15 @@ class PipelineTesterMixin(unittest.TestCase): ...@@ -1473,6 +1474,15 @@ class PipelineTesterMixin(unittest.TestCase):
# is not downloaded, but all the expected ones # is not downloaded, but all the expected ones
assert not os.path.isfile(os.path.join(snapshot_dir, "big_array.npy")) assert not os.path.isfile(os.path.join(snapshot_dir, "big_array.npy"))
def test_warning_unused_kwargs(self):
model_id = "hf-internal-testing/unet-pipeline-dummy"
logger = logging.get_logger("diffusers.pipeline_utils")
with tempfile.TemporaryDirectory() as tmpdirname:
with CaptureLogger(logger) as cap_logger:
DiffusionPipeline.from_pretrained(model_id, not_used=True, cache_dir=tmpdirname, force_download=True)
assert cap_logger.out == "Keyword arguments {'not_used': True} not recognized.\n"
@property @property
def dummy_safety_checker(self): def dummy_safety_checker(self):
def check(images, *args, **kwargs): def check(images, *args, **kwargs):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment