Unverified Commit d50e3217 authored by Anton Lozhkov, committed by GitHub

Support SD2 attention slicing (#1397)

* Support SD2 attention slicing

* Support SD2 attention slicing

* Add more copies

* Use attn_num_head_channels in blocks

* fix-copies

* Update tests

* fix imports
parent 8e2c4cd5
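For context, the change surfaces through the existing `enable_attention_slicing` pipeline helper. A minimal usage sketch, assuming the SD2 checkpoint "stabilityai/stable-diffusion-2" and an illustrative prompt (neither appears in this commit):

    import torch
    from diffusers import StableDiffusionPipeline

    # Load an SD2 checkpoint; SD2 UNets configure `attention_head_dim` as a
    # list, which is the case this commit adds support for.
    pipe = StableDiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-2", torch_dtype=torch.float16
    ).to("cuda")

    # "auto" now also works for list-valued head dims: an int config is
    # halved, a list falls back to its smallest entry.
    pipe.enable_attention_slicing("auto")
    image = pipe("a photo of an astronaut riding a horse").images[0]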
@@ -404,15 +404,17 @@ class UNetMidBlock2DCrossAttn(nn.Module):
         self.resnets = nn.ModuleList(resnets)
 
     def set_attention_slice(self, slice_size):
-        if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
+        head_dims = self.attn_num_head_channels
+        head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
+        if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
             raise ValueError(
-                f"Make sure slice_size {slice_size} is a divisor of "
-                f"the number of heads used in cross_attention {self.attn_num_head_channels}"
+                f"Make sure slice_size {slice_size} is a common divisor of "
+                f"the number of heads used in cross_attention: {head_dims}"
             )
-        if slice_size is not None and slice_size > self.attn_num_head_channels:
+        if slice_size is not None and slice_size > min(head_dims):
             raise ValueError(
-                f"Chunk_size {slice_size} has to be smaller or equal to "
-                f"the number of heads used in cross_attention {self.attn_num_head_channels}"
+                f"slice_size {slice_size} has to be smaller or equal to "
+                f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
             )
 
         for attn in self.attentions:
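The same normalization recurs in every `set_attention_slice` below; pulled out as a self-contained sketch with illustrative head counts (the hypothetical helper name `validate_slice_size` and the [5, 10, 20, 20] values are assumptions, not read from any config in this diff):

    # Standalone sketch of the validation above; head counts are illustrative.
    def validate_slice_size(slice_size, attn_num_head_channels):
        # Normalize an int config to a one-element list so one code path
        # handles both SD1-style (int) and SD2-style (list) configs.
        head_dims = attn_num_head_channels
        head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
        if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
            raise ValueError(
                f"Make sure slice_size {slice_size} is a common divisor of "
                f"the number of heads used in cross_attention: {head_dims}"
            )
        if slice_size is not None and slice_size > min(head_dims):
            raise ValueError(
                f"slice_size {slice_size} has to be smaller or equal to "
                f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
            )

    validate_slice_size(5, [5, 10, 20, 20])  # passes: 5 divides every entry
    validate_slice_size(4, 8)                # passes: int config is wrapped as [8]
    validate_slice_size(4, [5, 10, 20, 20])  # raises: 4 does not divide 5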
@@ -600,15 +602,17 @@ class CrossAttnDownBlock2D(nn.Module):
         self.gradient_checkpointing = False
 
     def set_attention_slice(self, slice_size):
-        if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
+        head_dims = self.attn_num_head_channels
+        head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
+        if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
             raise ValueError(
-                f"Make sure slice_size {slice_size} is a divisor of "
-                f"the number of heads used in cross_attention {self.attn_num_head_channels}"
+                f"Make sure slice_size {slice_size} is a common divisor of "
+                f"the number of heads used in cross_attention: {head_dims}"
             )
-        if slice_size is not None and slice_size > self.attn_num_head_channels:
+        if slice_size is not None and slice_size > min(head_dims):
             raise ValueError(
-                f"Chunk_size {slice_size} has to be smaller or equal to "
-                f"the number of heads used in cross_attention {self.attn_num_head_channels}"
+                f"slice_size {slice_size} has to be smaller or equal to "
+                f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
             )
 
         for attn in self.attentions:
@@ -1197,15 +1201,17 @@ class CrossAttnUpBlock2D(nn.Module):
         self.gradient_checkpointing = False
 
     def set_attention_slice(self, slice_size):
-        if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
+        head_dims = self.attn_num_head_channels
+        head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
+        if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
             raise ValueError(
-                f"Make sure slice_size {slice_size} is a divisor of "
-                f"the number of heads used in cross_attention {self.attn_num_head_channels}"
+                f"Make sure slice_size {slice_size} is a common divisor of "
+                f"the number of heads used in cross_attention: {head_dims}"
             )
-        if slice_size is not None and slice_size > self.attn_num_head_channels:
+        if slice_size is not None and slice_size > min(head_dims):
             raise ValueError(
-                f"Chunk_size {slice_size} has to be smaller or equal to "
-                f"the number of heads used in cross_attention {self.attn_num_head_channels}"
+                f"slice_size {slice_size} has to be smaller or equal to "
+                f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
             )
 
         for attn in self.attentions:
@@ -229,15 +229,17 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
         self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
 
     def set_attention_slice(self, slice_size):
-        if slice_size is not None and self.config.attention_head_dim % slice_size != 0:
+        head_dims = self.config.attention_head_dim
+        head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
+        if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
             raise ValueError(
-                f"Make sure slice_size {slice_size} is a divisor of "
-                f"the number of heads used in cross_attention {self.config.attention_head_dim}"
+                f"Make sure slice_size {slice_size} is a common divisor of "
+                f"the number of heads used in cross_attention: {head_dims}"
             )
-        if slice_size is not None and slice_size > self.config.attention_head_dim:
+        if slice_size is not None and slice_size > min(head_dims):
             raise ValueError(
-                f"Chunk_size {slice_size} has to be smaller or equal to "
-                f"the number of heads used in cross_attention {self.config.attention_head_dim}"
+                f"slice_size {slice_size} has to be smaller or equal to "
+                f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
             )
 
         for block in self.down_blocks:
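The hunk cuts off at the start of the fan-out loop. Roughly, the model-level setter then forwards the validated slice_size to each attention-bearing block; the following is a reconstruction from the surrounding code, not text shown in this diff:

    # Reconstructed continuation of set_attention_slice (assumption: blocks
    # expose set_attention_slice only when they carry attention layers).
    for block in self.down_blocks:
        if hasattr(block, "attentions") and block.attentions is not None:
            block.set_attention_slice(slice_size)

    self.mid_block.set_attention_slice(slice_size)

    for block in self.up_blocks:
        if hasattr(block, "attentions") and block.attentions is not None:
            block.set_attention_slice(slice_size)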
@@ -198,9 +198,14 @@ class AltDiffusionPipeline(DiffusionPipeline):
                 `attention_head_dim` must be a multiple of `slice_size`.
         """
         if slice_size == "auto":
-            # half the attention head size is usually a good trade-off between
-            # speed and memory
-            slice_size = self.unet.config.attention_head_dim // 2
+            if isinstance(self.unet.config.attention_head_dim, int):
+                # half the attention head size is usually a good trade-off between
+                # speed and memory
+                slice_size = self.unet.config.attention_head_dim // 2
+            else:
+                # if `attention_head_dim` is a list, take the smallest head size
+                slice_size = min(self.unet.config.attention_head_dim)
         self.unet.set_attention_slice(slice_size)
 
     def disable_attention_slicing(self):
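The `"auto"` branch above is duplicated verbatim across the remaining pipelines. Its effect, as a sketch with an assumed SD2-style list config (the values are illustrative):

    # Illustrative values: an SD2-style UNet config.
    attention_head_dim = [5, 10, 20, 20]

    if isinstance(attention_head_dim, int):
        # half the head count is usually a good speed/memory trade-off
        slice_size = attention_head_dim // 2
    else:
        # for a list, take the smallest head size; for [5, 10, 20, 20] this
        # is 5, which divides every entry and so passes validation
        slice_size = min(attention_head_dim)

    print(slice_size)  # 5 for the list above; 4 for an int config of 8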
@@ -193,9 +193,14 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
                 `attention_head_dim` must be a multiple of `slice_size`.
         """
         if slice_size == "auto":
-            # half the attention head size is usually a good trade-off between
-            # speed and memory
-            slice_size = self.unet.config.attention_head_dim // 2
+            if isinstance(self.unet.config.attention_head_dim, int):
+                # half the attention head size is usually a good trade-off between
+                # speed and memory
+                slice_size = self.unet.config.attention_head_dim // 2
+            else:
+                # if `attention_head_dim` is a list, take the smallest head size
+                slice_size = min(self.unet.config.attention_head_dim)
         self.unet.set_attention_slice(slice_size)
 
     def disable_attention_slicing(self):
@@ -224,9 +224,14 @@ class CycleDiffusionPipeline(DiffusionPipeline):
                 `attention_head_dim` must be a multiple of `slice_size`.
         """
         if slice_size == "auto":
-            # half the attention head size is usually a good trade-off between
-            # speed and memory
-            slice_size = self.unet.config.attention_head_dim // 2
+            if isinstance(self.unet.config.attention_head_dim, int):
+                # half the attention head size is usually a good trade-off between
+                # speed and memory
+                slice_size = self.unet.config.attention_head_dim // 2
+            else:
+                # if `attention_head_dim` is a list, take the smallest head size
+                slice_size = min(self.unet.config.attention_head_dim)
         self.unet.set_attention_slice(slice_size)
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
@@ -197,9 +197,14 @@ class StableDiffusionPipeline(DiffusionPipeline):
                 `attention_head_dim` must be a multiple of `slice_size`.
         """
         if slice_size == "auto":
-            # half the attention head size is usually a good trade-off between
-            # speed and memory
-            slice_size = self.unet.config.attention_head_dim // 2
+            if isinstance(self.unet.config.attention_head_dim, int):
+                # half the attention head size is usually a good trade-off between
+                # speed and memory
+                slice_size = self.unet.config.attention_head_dim // 2
+            else:
+                # if `attention_head_dim` is a list, take the smallest head size
+                slice_size = min(self.unet.config.attention_head_dim)
         self.unet.set_attention_slice(slice_size)
 
     def disable_attention_slicing(self):
@@ -169,9 +169,14 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline):
                 `attention_head_dim` must be a multiple of `slice_size`.
         """
         if slice_size == "auto":
-            # half the attention head size is usually a good trade-off between
-            # speed and memory
-            slice_size = self.unet.config.attention_head_dim // 2
+            if isinstance(self.unet.config.attention_head_dim, int):
+                # half the attention head size is usually a good trade-off between
+                # speed and memory
+                slice_size = self.unet.config.attention_head_dim // 2
+            else:
+                # if `attention_head_dim` is a list, take the smallest head size
+                slice_size = min(self.unet.config.attention_head_dim)
         self.unet.set_attention_slice(slice_size)
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
@@ -193,9 +193,14 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
                 `attention_head_dim` must be a multiple of `slice_size`.
         """
         if slice_size == "auto":
-            # half the attention head size is usually a good trade-off between
-            # speed and memory
-            slice_size = self.unet.config.attention_head_dim // 2
+            if isinstance(self.unet.config.attention_head_dim, int):
+                # half the attention head size is usually a good trade-off between
+                # speed and memory
+                slice_size = self.unet.config.attention_head_dim // 2
+            else:
+                # if `attention_head_dim` is a list, take the smallest head size
+                slice_size = min(self.unet.config.attention_head_dim)
         self.unet.set_attention_slice(slice_size)
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
@@ -258,9 +258,14 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
                 `attention_head_dim` must be a multiple of `slice_size`.
         """
         if slice_size == "auto":
-            # half the attention head size is usually a good trade-off between
-            # speed and memory
-            slice_size = self.unet.config.attention_head_dim // 2
+            if isinstance(self.unet.config.attention_head_dim, int):
+                # half the attention head size is usually a good trade-off between
+                # speed and memory
+                slice_size = self.unet.config.attention_head_dim // 2
+            else:
+                # if `attention_head_dim` is a list, take the smallest head size
+                slice_size = min(self.unet.config.attention_head_dim)
         self.unet.set_attention_slice(slice_size)
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
@@ -206,9 +206,14 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
                 `attention_head_dim` must be a multiple of `slice_size`.
         """
         if slice_size == "auto":
-            # half the attention head size is usually a good trade-off between
-            # speed and memory
-            slice_size = self.unet.config.attention_head_dim // 2
+            if isinstance(self.unet.config.attention_head_dim, int):
+                # half the attention head size is usually a good trade-off between
+                # speed and memory
+                slice_size = self.unet.config.attention_head_dim // 2
+            else:
+                # if `attention_head_dim` is a list, take the smallest head size
+                slice_size = min(self.unet.config.attention_head_dim)
         self.unet.set_attention_slice(slice_size)
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
@@ -307,15 +307,17 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         self.conv_out = LinearMultiDim(block_out_channels[0], out_channels, kernel_size=3, padding=1)
 
     def set_attention_slice(self, slice_size):
-        if slice_size is not None and self.config.attention_head_dim % slice_size != 0:
+        head_dims = self.config.attention_head_dim
+        head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
+        if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
             raise ValueError(
-                f"Make sure slice_size {slice_size} is a divisor of "
-                f"the number of heads used in cross_attention {self.config.attention_head_dim}"
+                f"Make sure slice_size {slice_size} is a common divisor of "
+                f"the number of heads used in cross_attention: {head_dims}"
            )
-        if slice_size is not None and slice_size > self.config.attention_head_dim:
+        if slice_size is not None and slice_size > min(head_dims):
             raise ValueError(
-                f"Chunk_size {slice_size} has to be smaller or equal to "
-                f"the number of heads used in cross_attention {self.config.attention_head_dim}"
+                f"slice_size {slice_size} has to be smaller or equal to "
+                f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
             )
 
         for block in self.down_blocks:
@@ -743,15 +745,17 @@ class CrossAttnDownBlockFlat(nn.Module):
         self.gradient_checkpointing = False
 
     def set_attention_slice(self, slice_size):
-        if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
+        head_dims = self.attn_num_head_channels
+        head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
+        if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
             raise ValueError(
-                f"Make sure slice_size {slice_size} is a divisor of "
-                f"the number of heads used in cross_attention {self.attn_num_head_channels}"
+                f"Make sure slice_size {slice_size} is a common divisor of "
+                f"the number of heads used in cross_attention: {head_dims}"
             )
-        if slice_size is not None and slice_size > self.attn_num_head_channels:
+        if slice_size is not None and slice_size > min(head_dims):
             raise ValueError(
-                f"Chunk_size {slice_size} has to be smaller or equal to "
-                f"the number of heads used in cross_attention {self.attn_num_head_channels}"
+                f"slice_size {slice_size} has to be smaller or equal to "
+                f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
             )
 
         for attn in self.attentions:
@@ -954,15 +958,17 @@ class CrossAttnUpBlockFlat(nn.Module):
         self.gradient_checkpointing = False
 
     def set_attention_slice(self, slice_size):
-        if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
+        head_dims = self.attn_num_head_channels
+        head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
+        if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
             raise ValueError(
-                f"Make sure slice_size {slice_size} is a divisor of "
-                f"the number of heads used in cross_attention {self.attn_num_head_channels}"
+                f"Make sure slice_size {slice_size} is a common divisor of "
+                f"the number of heads used in cross_attention: {head_dims}"
             )
-        if slice_size is not None and slice_size > self.attn_num_head_channels:
+        if slice_size is not None and slice_size > min(head_dims):
             raise ValueError(
-                f"Chunk_size {slice_size} has to be smaller or equal to "
-                f"the number of heads used in cross_attention {self.attn_num_head_channels}"
+                f"slice_size {slice_size} has to be smaller or equal to "
+                f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
             )
 
         for attn in self.attentions:
@@ -1101,15 +1107,17 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
         self.resnets = nn.ModuleList(resnets)
 
     def set_attention_slice(self, slice_size):
-        if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
+        head_dims = self.attn_num_head_channels
+        head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
+        if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
             raise ValueError(
-                f"Make sure slice_size {slice_size} is a divisor of "
-                f"the number of heads used in cross_attention {self.attn_num_head_channels}"
+                f"Make sure slice_size {slice_size} is a common divisor of "
+                f"the number of heads used in cross_attention: {head_dims}"
             )
-        if slice_size is not None and slice_size > self.attn_num_head_channels:
+        if slice_size is not None and slice_size > min(head_dims):
             raise ValueError(
-                f"Chunk_size {slice_size} has to be smaller or equal to "
-                f"the number of heads used in cross_attention {self.attn_num_head_channels}"
+                f"slice_size {slice_size} has to be smaller or equal to "
+                f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
             )
 
         for attn in self.attentions:
@@ -178,9 +178,14 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
                 `attention_head_dim` must be a multiple of `slice_size`.
         """
         if slice_size == "auto":
-            # half the attention head size is usually a good trade-off between
-            # speed and memory
-            slice_size = self.image_unet.config.attention_head_dim // 2
+            if isinstance(self.image_unet.config.attention_head_dim, int):
+                # half the attention head size is usually a good trade-off between
+                # speed and memory
+                slice_size = self.image_unet.config.attention_head_dim // 2
+            else:
+                # if `attention_head_dim` is a list, take the smallest head size
+                slice_size = min(self.image_unet.config.attention_head_dim)
         self.image_unet.set_attention_slice(slice_size)
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
@@ -108,9 +108,14 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
                 `attention_head_dim` must be a multiple of `slice_size`.
         """
         if slice_size == "auto":
-            # half the attention head size is usually a good trade-off between
-            # speed and memory
-            slice_size = self.image_unet.config.attention_head_dim // 2
+            if isinstance(self.image_unet.config.attention_head_dim, int):
+                # half the attention head size is usually a good trade-off between
+                # speed and memory
+                slice_size = self.image_unet.config.attention_head_dim // 2
+            else:
+                # if `attention_head_dim` is a list, take the smallest head size
+                slice_size = min(self.image_unet.config.attention_head_dim)
         self.image_unet.set_attention_slice(slice_size)
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
@@ -131,9 +131,14 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
                 `attention_head_dim` must be a multiple of `slice_size`.
         """
         if slice_size == "auto":
-            # half the attention head size is usually a good trade-off between
-            # speed and memory
-            slice_size = self.image_unet.config.attention_head_dim // 2
+            if isinstance(self.image_unet.config.attention_head_dim, int):
+                # half the attention head size is usually a good trade-off between
+                # speed and memory
+                slice_size = self.image_unet.config.attention_head_dim // 2
+            else:
+                # if `attention_head_dim` is a list, take the smallest head size
+                slice_size = min(self.image_unet.config.attention_head_dim)
         self.image_unet.set_attention_slice(slice_size)
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing