Unverified Commit 8d6487f3 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

Fix some failing tests (#1041)

* up

* up

* up

* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

* Apply suggestions from code review
parent d2d9764f
...@@ -662,6 +662,8 @@ class LDMBertEncoder(LDMBertPreTrainedModel): ...@@ -662,6 +662,8 @@ class LDMBertEncoder(LDMBertPreTrainedModel):
class LDMBertModel(LDMBertPreTrainedModel): class LDMBertModel(LDMBertPreTrainedModel):
_no_split_modules = []
def __init__(self, config: LDMBertConfig): def __init__(self, config: LDMBertConfig):
super().__init__(config) super().__init__(config)
self.model = LDMBertEncoder(config) self.model = LDMBertEncoder(config)
......
...@@ -208,7 +208,6 @@ class StableDiffusionPipeline(DiffusionPipeline): ...@@ -208,7 +208,6 @@ class StableDiffusionPipeline(DiffusionPipeline):
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`. (nsfw) content, according to the `safety_checker`.
""" """
if isinstance(prompt, str): if isinstance(prompt, str):
batch_size = 1 batch_size = 1
elif isinstance(prompt, list): elif isinstance(prompt, list):
......
...@@ -740,7 +740,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase): ...@@ -740,7 +740,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
start_time = time.time() start_time = time.time()
pipeline_normal_load = StableDiffusionPipeline.from_pretrained( pipeline_normal_load = StableDiffusionPipeline.from_pretrained(
pipeline_id, revision="fp16", torch_dtype=torch.float16, device_map="auto" pipeline_id, revision="fp16", torch_dtype=torch.float16
) )
pipeline_normal_load.to(torch_device) pipeline_normal_load.to(torch_device)
normal_load_time = time.time() - start_time normal_load_time = time.time() - start_time
...@@ -761,9 +761,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase): ...@@ -761,9 +761,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
pipeline_id = "CompVis/stable-diffusion-v1-4" pipeline_id = "CompVis/stable-diffusion-v1-4"
prompt = "Andromeda galaxy in a bottle" prompt = "Andromeda galaxy in a bottle"
pipeline = StableDiffusionPipeline.from_pretrained( pipeline = StableDiffusionPipeline.from_pretrained(pipeline_id, revision="fp16", torch_dtype=torch.float16)
pipeline_id, revision="fp16", torch_dtype=torch.float16, device_map="auto"
)
pipeline.enable_attention_slicing(1) pipeline.enable_attention_slicing(1)
pipeline.enable_sequential_cpu_offload() pipeline.enable_sequential_cpu_offload()
......
...@@ -77,6 +77,7 @@ class CustomPipelineTests(unittest.TestCase): ...@@ -77,6 +77,7 @@ class CustomPipelineTests(unittest.TestCase):
pipeline = DiffusionPipeline.from_pretrained( pipeline = DiffusionPipeline.from_pretrained(
"google/ddpm-cifar10-32", custom_pipeline="hf-internal-testing/diffusers-dummy-pipeline" "google/ddpm-cifar10-32", custom_pipeline="hf-internal-testing/diffusers-dummy-pipeline"
) )
pipeline = pipeline.to(torch_device)
# NOTE that `"CustomPipeline"` is not a class that is defined in this library, but solely on the Hub # NOTE that `"CustomPipeline"` is not a class that is defined in this library, but solely on the Hub
# under https://huggingface.co/hf-internal-testing/diffusers-dummy-pipeline/blob/main/pipeline.py#L24 # under https://huggingface.co/hf-internal-testing/diffusers-dummy-pipeline/blob/main/pipeline.py#L24
assert pipeline.__class__.__name__ == "CustomPipeline" assert pipeline.__class__.__name__ == "CustomPipeline"
...@@ -85,6 +86,7 @@ class CustomPipelineTests(unittest.TestCase): ...@@ -85,6 +86,7 @@ class CustomPipelineTests(unittest.TestCase):
pipeline = DiffusionPipeline.from_pretrained( pipeline = DiffusionPipeline.from_pretrained(
"google/ddpm-cifar10-32", custom_pipeline="hf-internal-testing/diffusers-dummy-pipeline" "google/ddpm-cifar10-32", custom_pipeline="hf-internal-testing/diffusers-dummy-pipeline"
) )
pipeline = pipeline.to(torch_device)
images, output_str = pipeline(num_inference_steps=2, output_type="np") images, output_str = pipeline(num_inference_steps=2, output_type="np")
assert images[0].shape == (1, 32, 32, 3) assert images[0].shape == (1, 32, 32, 3)
...@@ -96,6 +98,7 @@ class CustomPipelineTests(unittest.TestCase): ...@@ -96,6 +98,7 @@ class CustomPipelineTests(unittest.TestCase):
pipeline = DiffusionPipeline.from_pretrained( pipeline = DiffusionPipeline.from_pretrained(
"google/ddpm-cifar10-32", custom_pipeline=local_custom_pipeline_path "google/ddpm-cifar10-32", custom_pipeline=local_custom_pipeline_path
) )
pipeline = pipeline.to(torch_device)
images, output_str = pipeline(num_inference_steps=2, output_type="np") images, output_str = pipeline(num_inference_steps=2, output_type="np")
assert pipeline.__class__.__name__ == "CustomLocalPipeline" assert pipeline.__class__.__name__ == "CustomLocalPipeline"
...@@ -109,7 +112,7 @@ class CustomPipelineTests(unittest.TestCase): ...@@ -109,7 +112,7 @@ class CustomPipelineTests(unittest.TestCase):
clip_model_id = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K" clip_model_id = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
feature_extractor = CLIPFeatureExtractor.from_pretrained(clip_model_id, device_map="auto") feature_extractor = CLIPFeatureExtractor.from_pretrained(clip_model_id, device_map="auto")
clip_model = CLIPModel.from_pretrained(clip_model_id, torch_dtype=torch.float16, device_map="auto") clip_model = CLIPModel.from_pretrained(clip_model_id, torch_dtype=torch.float16)
pipeline = DiffusionPipeline.from_pretrained( pipeline = DiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", "CompVis/stable-diffusion-v1-4",
...@@ -380,10 +383,11 @@ class PipelineSlowTests(unittest.TestCase): ...@@ -380,10 +383,11 @@ class PipelineSlowTests(unittest.TestCase):
scheduler = DDPMScheduler(num_train_timesteps=10) scheduler = DDPMScheduler(num_train_timesteps=10)
ddpm = DDPMPipeline.from_pretrained(model_path, scheduler=scheduler, device_map="auto") ddpm = DDPMPipeline.from_pretrained(model_path, scheduler=scheduler, device_map="auto")
ddpm.to(torch_device) ddpm = ddpm.to(torch_device)
ddpm.set_progress_bar_config(disable=None) ddpm.set_progress_bar_config(disable=None)
ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler, device_map="auto") ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler, device_map="auto")
ddpm_from_hub.to(torch_device) ddpm_from_hub = ddpm_from_hub.to(torch_device)
ddpm_from_hub.set_progress_bar_config(disable=None) ddpm_from_hub.set_progress_bar_config(disable=None)
generator = torch.manual_seed(0) generator = torch.manual_seed(0)
...@@ -404,11 +408,11 @@ class PipelineSlowTests(unittest.TestCase): ...@@ -404,11 +408,11 @@ class PipelineSlowTests(unittest.TestCase):
ddpm_from_hub_custom_model = DiffusionPipeline.from_pretrained( ddpm_from_hub_custom_model = DiffusionPipeline.from_pretrained(
model_path, unet=unet, scheduler=scheduler, device_map="auto" model_path, unet=unet, scheduler=scheduler, device_map="auto"
) )
ddpm_from_hub_custom_model.to(torch_device) ddpm_from_hub_custom_model = ddpm_from_hub_custom_model.to(torch_device)
ddpm_from_hub_custom_model.set_progress_bar_config(disable=None) ddpm_from_hub_custom_model.set_progress_bar_config(disable=None)
ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler, device_map="auto") ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler, device_map="auto")
ddpm_from_hub.to(torch_device) ddpm_from_hub = ddpm_from_hub.to(torch_device)
ddpm_from_hub_custom_model.set_progress_bar_config(disable=None) ddpm_from_hub_custom_model.set_progress_bar_config(disable=None)
generator = torch.manual_seed(0) generator = torch.manual_seed(0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment