Unverified Commit ca1e4072 authored by Will Berman's avatar Will Berman Committed by GitHub
Browse files

stable diffusion depth batching fix (#2757)

parent b33bd91f
......@@ -474,7 +474,8 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
if depth_map.shape[0] < batch_size:
depth_map = depth_map.repeat(batch_size, 1, 1, 1)
repeat_by = batch_size // depth_map.shape[0]
depth_map = depth_map.repeat(repeat_by, 1, 1, 1)
depth_map = torch.cat([depth_map] * 2) if do_classifier_free_guidance else depth_map
return depth_map
......
......@@ -64,7 +64,7 @@ class StableDiffusionDepth2ImgPipelineFastTests(PipelineTesterMixin, unittest.Te
test_save_load_optional_components = False
params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"}
required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS - {"image"}
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
def get_dummy_components(self):
torch.manual_seed(0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment