Unverified Commit ca1e4072 authored by Will Berman's avatar Will Berman Committed by GitHub
Browse files

stable diffusion depth batching fix (#2757)

parent b33bd91f
...@@ -474,7 +474,8 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline): ...@@ -474,7 +474,8 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
if depth_map.shape[0] < batch_size: if depth_map.shape[0] < batch_size:
depth_map = depth_map.repeat(batch_size, 1, 1, 1) repeat_by = batch_size // depth_map.shape[0]
depth_map = depth_map.repeat(repeat_by, 1, 1, 1)
depth_map = torch.cat([depth_map] * 2) if do_classifier_free_guidance else depth_map depth_map = torch.cat([depth_map] * 2) if do_classifier_free_guidance else depth_map
return depth_map return depth_map
......
...@@ -64,7 +64,7 @@ class StableDiffusionDepth2ImgPipelineFastTests(PipelineTesterMixin, unittest.Te ...@@ -64,7 +64,7 @@ class StableDiffusionDepth2ImgPipelineFastTests(PipelineTesterMixin, unittest.Te
test_save_load_optional_components = False test_save_load_optional_components = False
params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"} params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"}
required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"} required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS - {"image"} batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
def get_dummy_components(self): def get_dummy_components(self):
torch.manual_seed(0) torch.manual_seed(0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment