"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "e592e8fccb390073a51c61a529d4a52529c44aa2"
Unverified Commit bce65cd1 authored by Suraj Patil, committed by GitHub

[refactor] make set_attention_slice recursive (#1532)



* make attn slice recursive

* remove set_attention_slice from blocks

* fix copies

* make enable_attention_slicing base class method of DiffusionPipeline

* fix set_attention_slice

* fix set_attention_slice

* fix copies

* add tests

* up

* up

* up

* update

* up

* uP
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent e2899989
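
For context, the user-facing behaviour after this refactor: attention slicing is toggled once on the pipeline through the shared `DiffusionPipeline.enable_attention_slicing`, which forwards the request to every component that exposes `set_attention_slice`, and the UNet now accepts `"auto"`, `"max"`, an `int`, or a per-layer list. A minimal usage sketch (the checkpoint id and prompt are illustrative placeholders, not taken from this diff):

import torch
from diffusers import StableDiffusionPipeline

# illustrative checkpoint id
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")

# "auto" halves every sliceable head dim, "max" runs one slice at a time,
# an int or a per-layer list is forwarded unchanged to unet.set_attention_slice
pipe.enable_attention_slicing("auto")
image = pipe("an astronaut riding a horse").images[0]

# restore single-step attention
pipe.disable_attention_slicing()
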
@@ -174,10 +174,6 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
self.norm_out = nn.LayerNorm(inner_dim)
self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
def _set_attention_slice(self, slice_size):
for block in self.transformer_blocks:
block._set_attention_slice(slice_size)
def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
"""
Args:
@@ -448,10 +444,6 @@ class BasicTransformerBlock(nn.Module):
f" correctly and a GPU is available: {e}"
)
def _set_attention_slice(self, slice_size):
self.attn1._slice_size = slice_size
self.attn2._slice_size = slice_size
def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
if not is_xformers_available():
print("Here is how to install it")
@@ -534,6 +526,7 @@ class CrossAttention(nn.Module):
# for slice_size > 0 the attention score computation
# is split across the batch axis to save memory
# You can set slice_size with `set_attention_slice`
self.sliceable_head_dim = heads
self._slice_size = None
self._use_memory_efficient_attention_xformers = False
@@ -559,6 +552,12 @@ class CrossAttention(nn.Module):
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
return tensor
def set_attention_slice(self, slice_size):
if slice_size is not None and slice_size > self.sliceable_head_dim:
raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
self._slice_size = slice_size
def forward(self, hidden_states, context=None, mask=None):
batch_size, sequence_length, _ = hidden_states.shape
...
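
The removed `_set_attention_slice` helpers above only forwarded a slice size down to `CrossAttention`, which now validates it against the new `sliceable_head_dim` attribute. To make the memory trade-off concrete, here is a simplified, self-contained sketch of sliced attention (not the diffusers implementation): the score matrix is materialized for only `slice_size` rows of the batch*heads axis at a time.

import torch

def sliced_attention(query, key, value, slice_size):
    # query/key/value: (batch * heads, seq_len, head_dim)
    batch_heads, seq_len, head_dim = query.shape
    scale = head_dim ** -0.5
    out = torch.empty_like(query)
    for start in range(0, batch_heads, slice_size):
        end = min(start + slice_size, batch_heads)
        # only an (end - start, seq_len, seq_len) block of scores exists at once
        scores = torch.bmm(query[start:end], key[start:end].transpose(1, 2)) * scale
        out[start:end] = torch.bmm(scores.softmax(dim=-1), value[start:end])
    return out

# slicing changes peak memory, not the result
q = k = v = torch.randn(8, 64, 40)
assert torch.allclose(sliced_attention(q, k, v, 2), sliced_attention(q, k, v, 8), atol=1e-5)
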
@@ -401,23 +401,6 @@ class UNetMidBlock2DCrossAttn(nn.Module):
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
def set_attention_slice(self, slice_size):
head_dims = self.attn_num_head_channels
head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
raise ValueError(
f"Make sure slice_size {slice_size} is a common divisor of "
f"the number of heads used in cross_attention: {head_dims}"
)
if slice_size is not None and slice_size > min(head_dims):
raise ValueError(
f"slice_size {slice_size} has to be smaller or equal to "
f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
)
for attn in self.attentions:
attn._set_attention_slice(slice_size)
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
@@ -595,23 +578,6 @@ class CrossAttnDownBlock2D(nn.Module):
self.gradient_checkpointing = False
def set_attention_slice(self, slice_size):
head_dims = self.attn_num_head_channels
head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
raise ValueError(
f"Make sure slice_size {slice_size} is a common divisor of "
f"the number of heads used in cross_attention: {head_dims}"
)
if slice_size is not None and slice_size > min(head_dims):
raise ValueError(
f"slice_size {slice_size} has to be smaller or equal to "
f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
)
for attn in self.attentions:
attn._set_attention_slice(slice_size)
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
output_states = ()
@@ -1190,25 +1156,6 @@ class CrossAttnUpBlock2D(nn.Module):
self.gradient_checkpointing = False
def set_attention_slice(self, slice_size):
head_dims = self.attn_num_head_channels
head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
raise ValueError(
f"Make sure slice_size {slice_size} is a common divisor of "
f"the number of heads used in cross_attention: {head_dims}"
)
if slice_size is not None and slice_size > min(head_dims):
raise ValueError(
f"slice_size {slice_size} has to be smaller or equal to "
f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
)
for attn in self.attentions:
attn._set_attention_slice(slice_size)
self.gradient_checkpointing = False
def forward(
self,
hidden_states,
...
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional, Tuple, Union
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -229,28 +229,69 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
-    def set_attention_slice(self, slice_size):
-        head_dims = self.config.attention_head_dim
-        head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
-        if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
-            raise ValueError(
-                f"Make sure slice_size {slice_size} is a common divisor of "
-                f"the number of heads used in cross_attention: {head_dims}"
-            )
-        if slice_size is not None and slice_size > min(head_dims):
-            raise ValueError(
-                f"slice_size {slice_size} has to be smaller or equal to "
-                f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
-            )
-
-        for block in self.down_blocks:
-            if hasattr(block, "attentions") and block.attentions is not None:
-                block.set_attention_slice(slice_size)
-
-        self.mid_block.set_attention_slice(slice_size)
-
-        for block in self.up_blocks:
-            if hasattr(block, "attentions") and block.attentions is not None:
-                block.set_attention_slice(slice_size)
+    def set_attention_slice(self, slice_size):
+        r"""
+        Enable sliced attention computation.
+
+        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
+        in several steps. This is useful to save some memory in exchange for a small speed decrease.
+
+        Args:
+            slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
+                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
+                `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
+                provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
+                must be a multiple of `slice_size`.
+        """
+        sliceable_head_dims = []
+
+        def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
+            if hasattr(module, "set_attention_slice"):
+                sliceable_head_dims.append(module.sliceable_head_dim)
+
+            for child in module.children():
+                fn_recursive_retrieve_slicable_dims(child)
+
+        # retrieve number of attention layers
+        for module in self.children():
+            fn_recursive_retrieve_slicable_dims(module)
+
+        num_slicable_layers = len(sliceable_head_dims)
+
+        if slice_size == "auto":
+            # half the attention head size is usually a good trade-off between
+            # speed and memory
+            slice_size = [dim // 2 for dim in sliceable_head_dims]
+        elif slice_size == "max":
+            # make smallest slice possible
+            slice_size = num_slicable_layers * [1]
+
+        slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
+
+        if len(slice_size) != len(sliceable_head_dims):
+            raise ValueError(
+                f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
+                f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
+            )
+
+        for i in range(len(slice_size)):
+            size = slice_size[i]
+            dim = sliceable_head_dims[i]
+            if size is not None and size > dim:
+                raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
+
+        # Recursively walk through all the children.
+        # Any children which exposes the set_attention_slice method
+        # gets the message
+        def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
+            if hasattr(module, "set_attention_slice"):
+                module.set_attention_slice(slice_size.pop())
+
+            for child in module.children():
+                fn_recursive_set_attention_slice(child, slice_size)
+
+        reversed_slice_size = list(reversed(slice_size))
+        for module in self.children():
+            fn_recursive_set_attention_slice(module, reversed_slice_size)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
...
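
The heart of the refactor is the pair of recursive helpers added above: one walk collects `sliceable_head_dim` from every descendant that exposes `set_attention_slice`, and a second walk hands each of those modules its own entry from the per-layer list via `pop()` on a reversed copy. A toy sketch of the same pattern with made-up module classes (not diffusers code) shows why the list is reversed before popping:

from typing import List

import torch

class ToyAttention(torch.nn.Module):
    def __init__(self, heads):
        super().__init__()
        self.sliceable_head_dim = heads
        self.slice_size = None

    def set_attention_slice(self, slice_size):
        self.slice_size = slice_size

class ToyUNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = torch.nn.ModuleList([ToyAttention(8), ToyAttention(16)])

    def set_attention_slice(self, slice_size: List[int]):
        # pop() removes from the end, so reversing first hands the slices
        # out in the original traversal order of the modules
        reversed_sizes = list(reversed(slice_size))

        def recurse(module: torch.nn.Module):
            if hasattr(module, "set_attention_slice"):
                module.set_attention_slice(reversed_sizes.pop())
            for child in module.children():
                recurse(child)

        for child in self.children():
            recurse(child)

unet = ToyUNet()
unet.set_attention_slice([4, 8])  # one entry per sliceable layer
assert [a.slice_size for a in unet.blocks] == [4, 8]
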
@@ -839,3 +839,34 @@ class DiffusionPipeline(ConfigMixin):
module = getattr(self, module_name)
if isinstance(module, torch.nn.Module):
fn_recursive_set_mem_eff(module)
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
in several steps. This is useful to save some memory in exchange for a small speed decrease.
Args:
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
`"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
must be a multiple of `slice_size`.
"""
self.set_attention_slice(slice_size)
def disable_attention_slicing(self):
r"""
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
back to computing attention in one step.
"""
# set slice_size = `None` to disable `attention slicing`
self.enable_attention_slicing(None)
def set_attention_slice(self, slice_size: Optional[int]):
module_names, _, _ = self.extract_init_dict(dict(self.config))
for module_name in module_names:
module = getattr(self, module_name)
if isinstance(module, torch.nn.Module) and hasattr(module, "set_attention_slice"):
module.set_attention_slice(slice_size)
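
With the base-class methods above, `enable_attention_slicing` no longer has to be copy-pasted into every pipeline: `DiffusionPipeline.set_attention_slice` simply forwards `slice_size` to every registered component that exposes `set_attention_slice` (in practice the UNet, or both UNets in Versatile Diffusion). As a worked example of how the UNet then expands the argument, assuming its recursive walk found sliceable head dims `[8, 16]` (the same values the new unit test configures via `attention_head_dim=(8, 16)`):

sliceable_head_dims = [8, 16]  # collected by the recursive walk
num_layers = len(sliceable_head_dims)

def expand(slice_size):
    # mirrors the normalization in UNet2DConditionModel.set_attention_slice
    if slice_size == "auto":
        return [dim // 2 for dim in sliceable_head_dims]
    if slice_size == "max":
        return num_layers * [1]
    return num_layers * [slice_size] if not isinstance(slice_size, list) else slice_size

print(expand("auto"))   # [4, 8]  -> two slices per attention layer
print(expand("max"))    # [1, 1]  -> one head at a time, minimal memory
print(expand(2))        # [2, 2]
print(expand([2, 3]))   # [2, 3]  -> per-layer control
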
@@ -166,38 +166,6 @@ class AltDiffusionPipeline(DiffusionPipeline):
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
in several steps. This is useful to save some memory in exchange for a small speed decrease.
Args:
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
`attention_head_dim` must be a multiple of `slice_size`.
"""
if slice_size == "auto":
if isinstance(self.unet.config.attention_head_dim, int):
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
else:
# if `attention_head_dim` is a list, take the smallest head size
slice_size = min(self.unet.config.attention_head_dim)
self.unet.set_attention_slice(slice_size)
def disable_attention_slicing(self):
r"""
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
back to computing attention in one step.
"""
# set slice_size = `None` to disable `attention slicing`
self.enable_attention_slicing(None)
def enable_vae_slicing(self):
r"""
Enable sliced VAE decoding.
...
@@ -179,38 +179,6 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
in several steps. This is useful to save some memory in exchange for a small speed decrease.
Args:
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
`attention_head_dim` must be a multiple of `slice_size`.
"""
if slice_size == "auto":
if isinstance(self.unet.config.attention_head_dim, int):
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
else:
# if `attention_head_dim` is a list, take the smallest head size
slice_size = min(self.unet.config.attention_head_dim)
self.unet.set_attention_slice(slice_size)
def disable_attention_slicing(self):
r"""
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
back to computing attention in one step.
"""
# set slice_size = `None` to disable `attention slicing`
self.enable_attention_slicing(None)
def enable_sequential_cpu_offload(self, gpu_id=0):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
...
@@ -209,40 +209,6 @@ class CycleDiffusionPipeline(DiffusionPipeline):
)
self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
in several steps. This is useful to save some memory in exchange for a small speed decrease.
Args:
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
`attention_head_dim` must be a multiple of `slice_size`.
"""
if slice_size == "auto":
if isinstance(self.unet.config.attention_head_dim, int):
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
else:
# if `attention_head_dim` is a list, take the smallest head size
slice_size = min(self.unet.config.attention_head_dim)
self.unet.set_attention_slice(slice_size)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
def disable_attention_slicing(self):
r"""
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
back to computing attention in one step.
"""
# set slice_size = `None` to disable `attention slicing`
self.enable_attention_slicing(None)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
def enable_sequential_cpu_offload(self, gpu_id=0):
r"""
...
@@ -165,38 +165,6 @@ class StableDiffusionPipeline(DiffusionPipeline):
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
in several steps. This is useful to save some memory in exchange for a small speed decrease.
Args:
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
`attention_head_dim` must be a multiple of `slice_size`.
"""
if slice_size == "auto":
if isinstance(self.unet.config.attention_head_dim, int):
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
else:
# if `attention_head_dim` is a list, take the smallest head size
slice_size = min(self.unet.config.attention_head_dim)
self.unet.set_attention_slice(slice_size)
def disable_attention_slicing(self):
r"""
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
back to computing attention in one step.
"""
# set slice_size = `None` to disable `attention slicing`
self.enable_attention_slicing(None)
def enable_vae_slicing(self):
r"""
Enable sliced VAE decoding.
...
@@ -134,40 +134,6 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline):
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
in several steps. This is useful to save some memory in exchange for a small speed decrease.
Args:
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
`attention_head_dim` must be a multiple of `slice_size`.
"""
if slice_size == "auto":
if isinstance(self.unet.config.attention_head_dim, int):
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
else:
# if `attention_head_dim` is a list, take the smallest head size
slice_size = min(self.unet.config.attention_head_dim)
self.unet.set_attention_slice(slice_size)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
def disable_attention_slicing(self):
r"""
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
back to computing attention in one step.
"""
# set slice_size = `None` to disable `attention slicing`
self.enable_attention_slicing(None)
def enable_sequential_cpu_offload(self, gpu_id=0):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
...
@@ -178,40 +178,6 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
in several steps. This is useful to save some memory in exchange for a small speed decrease.
Args:
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
`attention_head_dim` must be a multiple of `slice_size`.
"""
if slice_size == "auto":
if isinstance(self.unet.config.attention_head_dim, int):
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
else:
# if `attention_head_dim` is a list, take the smallest head size
slice_size = min(self.unet.config.attention_head_dim)
self.unet.set_attention_slice(slice_size)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
def disable_attention_slicing(self):
r"""
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
back to computing attention in one step.
"""
# set slice_size = `None` to disable `attention slicing`
self.enable_attention_slicing(None)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
def enable_sequential_cpu_offload(self, gpu_id=0):
r"""
...
@@ -243,40 +243,6 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
in several steps. This is useful to save some memory in exchange for a small speed decrease.
Args:
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
`attention_head_dim` must be a multiple of `slice_size`.
"""
if slice_size == "auto":
if isinstance(self.unet.config.attention_head_dim, int):
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
else:
# if `attention_head_dim` is a list, take the smallest head size
slice_size = min(self.unet.config.attention_head_dim)
self.unet.set_attention_slice(slice_size)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
def disable_attention_slicing(self):
r"""
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
back to computing attention in one step.
"""
# set slice_size = `None` to disable `attention slicing`
self.enable_attention_slicing(None)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
def enable_sequential_cpu_offload(self, gpu_id=0):
r"""
...
@@ -191,40 +191,6 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
in several steps. This is useful to save some memory in exchange for a small speed decrease.
Args:
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
`attention_head_dim` must be a multiple of `slice_size`.
"""
if slice_size == "auto":
if isinstance(self.unet.config.attention_head_dim, int):
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
else:
# if `attention_head_dim` is a list, take the smallest head size
slice_size = min(self.unet.config.attention_head_dim)
self.unet.set_attention_slice(slice_size)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
def disable_attention_slicing(self):
r"""
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
back to computing attention in one step.
"""
# set slice_size = `None` to disable `attention slicing`
self.enable_attention_slicing(None)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
def enable_sequential_cpu_offload(self, gpu_id=0):
r"""
...
@@ -92,40 +92,6 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
)
self.register_to_config(max_noise_level=max_noise_level)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
in several steps. This is useful to save some memory in exchange for a small speed decrease.
Args:
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
`attention_head_dim` must be a multiple of `slice_size`.
"""
if slice_size == "auto":
if isinstance(self.unet.config.attention_head_dim, int):
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
else:
# if `attention_head_dim` is a list, take the smallest head size
slice_size = min(self.unet.config.attention_head_dim)
self.unet.set_attention_slice(slice_size)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
def disable_attention_slicing(self):
r"""
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
back to computing attention in one step.
"""
# set slice_size = `None` to disable `attention slicing`
self.enable_attention_slicing(None)
def enable_sequential_cpu_offload(self, gpu_id=0):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
...
@@ -182,33 +182,6 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
"""
self._safety_text_concept = concept
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
in several steps. This is useful to save some memory in exchange for a small speed decrease.
Args:
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
`attention_head_dim` must be a multiple of `slice_size`.
"""
if slice_size == "auto":
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.unet.config.attention_head_dim // 2
self.unet.set_attention_slice(slice_size)
def disable_attention_slicing(self):
r"""
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
back to computing attention in one step.
"""
# set slice_size = `None` to disable `attention slicing`
self.enable_attention_slicing(None)
def enable_sequential_cpu_offload(self):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
...
from typing import Optional, Tuple, Union
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
@@ -307,28 +307,69 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
self.conv_out = LinearMultiDim(block_out_channels[0], out_channels, kernel_size=3, padding=1)
-    def set_attention_slice(self, slice_size):
-        head_dims = self.config.attention_head_dim
-        head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
-        if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
-            raise ValueError(
-                f"Make sure slice_size {slice_size} is a common divisor of "
-                f"the number of heads used in cross_attention: {head_dims}"
-            )
-        if slice_size is not None and slice_size > min(head_dims):
-            raise ValueError(
-                f"slice_size {slice_size} has to be smaller or equal to "
-                f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
-            )
-
-        for block in self.down_blocks:
-            if hasattr(block, "attentions") and block.attentions is not None:
-                block.set_attention_slice(slice_size)
-
-        self.mid_block.set_attention_slice(slice_size)
-
-        for block in self.up_blocks:
-            if hasattr(block, "attentions") and block.attentions is not None:
-                block.set_attention_slice(slice_size)
+    def set_attention_slice(self, slice_size):
+        r"""
+        Enable sliced attention computation.
+
+        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
+        in several steps. This is useful to save some memory in exchange for a small speed decrease.
+
+        Args:
+            slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
+                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
+                `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
+                provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
+                must be a multiple of `slice_size`.
+        """
+        sliceable_head_dims = []
+
+        def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
+            if hasattr(module, "set_attention_slice"):
+                sliceable_head_dims.append(module.sliceable_head_dim)
+
+            for child in module.children():
+                fn_recursive_retrieve_slicable_dims(child)
+
+        # retrieve number of attention layers
+        for module in self.children():
+            fn_recursive_retrieve_slicable_dims(module)
+
+        num_slicable_layers = len(sliceable_head_dims)
+
+        if slice_size == "auto":
+            # half the attention head size is usually a good trade-off between
+            # speed and memory
+            slice_size = [dim // 2 for dim in sliceable_head_dims]
+        elif slice_size == "max":
+            # make smallest slice possible
+            slice_size = num_slicable_layers * [1]
+
+        slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
+
+        if len(slice_size) != len(sliceable_head_dims):
+            raise ValueError(
+                f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
+                f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
+            )
+
+        for i in range(len(slice_size)):
+            size = slice_size[i]
+            dim = sliceable_head_dims[i]
+            if size is not None and size > dim:
+                raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
+
+        # Recursively walk through all the children.
+        # Any children which exposes the set_attention_slice method
+        # gets the message
+        def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
+            if hasattr(module, "set_attention_slice"):
+                module.set_attention_slice(slice_size.pop())
+
+            for child in module.children():
+                fn_recursive_set_attention_slice(child, slice_size)
+
+        reversed_slice_size = list(reversed(slice_size))
+        for module in self.children():
+            fn_recursive_set_attention_slice(module, reversed_slice_size)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (CrossAttnDownBlockFlat, DownBlockFlat, CrossAttnUpBlockFlat, UpBlockFlat)):
@@ -739,23 +780,6 @@ class CrossAttnDownBlockFlat(nn.Module):
self.gradient_checkpointing = False
def set_attention_slice(self, slice_size):
head_dims = self.attn_num_head_channels
head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
raise ValueError(
f"Make sure slice_size {slice_size} is a common divisor of "
f"the number of heads used in cross_attention: {head_dims}"
)
if slice_size is not None and slice_size > min(head_dims):
raise ValueError(
f"slice_size {slice_size} has to be smaller or equal to "
f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
)
for attn in self.attentions:
attn._set_attention_slice(slice_size)
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
output_states = ()
@@ -948,25 +972,6 @@ class CrossAttnUpBlockFlat(nn.Module):
self.gradient_checkpointing = False
def set_attention_slice(self, slice_size):
head_dims = self.attn_num_head_channels
head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
raise ValueError(
f"Make sure slice_size {slice_size} is a common divisor of "
f"the number of heads used in cross_attention: {head_dims}"
)
if slice_size is not None and slice_size > min(head_dims):
raise ValueError(
f"slice_size {slice_size} has to be smaller or equal to "
f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
)
for attn in self.attentions:
attn._set_attention_slice(slice_size)
self.gradient_checkpointing = False
def forward(
self,
hidden_states,
@@ -1092,23 +1097,6 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
def set_attention_slice(self, slice_size):
head_dims = self.attn_num_head_channels
head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
raise ValueError(
f"Make sure slice_size {slice_size} is a common divisor of "
f"the number of heads used in cross_attention: {head_dims}"
)
if slice_size is not None and slice_size > min(head_dims):
raise ValueError(
f"slice_size {slice_size} has to be smaller or equal to "
f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
)
for attn in self.attentions:
attn._set_attention_slice(slice_size)
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
...
@@ -80,34 +80,6 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
in several steps. This is useful to save some memory in exchange for a small speed decrease.
Args:
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
`attention_head_dim` must be a multiple of `slice_size`.
"""
if slice_size == "auto":
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.image_unet.config.attention_head_dim // 2
self.image_unet.set_attention_slice(slice_size)
self.text_unet.set_attention_slice(slice_size)
def disable_attention_slicing(self):
r"""
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
back to computing attention in one step.
"""
# set slice_size = `None` to disable `attention slicing`
self.enable_attention_slicing(None)
@torch.no_grad()
def image_variation(
self,
...
@@ -147,40 +147,6 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
self.image_unet.register_to_config(dual_cross_attention=False)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing with unet->image_unet
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
in several steps. This is useful to save some memory in exchange for a small speed decrease.
Args:
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
`attention_head_dim` must be a multiple of `slice_size`.
"""
if slice_size == "auto":
if isinstance(self.image_unet.config.attention_head_dim, int):
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.image_unet.config.attention_head_dim // 2
else:
# if `attention_head_dim` is a list, take the smallest head size
slice_size = min(self.image_unet.config.attention_head_dim)
self.image_unet.set_attention_slice(slice_size)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
def disable_attention_slicing(self):
r"""
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
back to computing attention in one step.
"""
# set slice_size = `None` to disable `attention slicing`
self.enable_attention_slicing(None)
def enable_sequential_cpu_offload(self, gpu_id=0):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
...
@@ -73,40 +73,6 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing with unet->image_unet
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
in several steps. This is useful to save some memory in exchange for a small speed decrease.
Args:
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
`attention_head_dim` must be a multiple of `slice_size`.
"""
if slice_size == "auto":
if isinstance(self.image_unet.config.attention_head_dim, int):
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.image_unet.config.attention_head_dim // 2
else:
# if `attention_head_dim` is a list, take the smallest head size
slice_size = min(self.image_unet.config.attention_head_dim)
self.image_unet.set_attention_slice(slice_size)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
def disable_attention_slicing(self):
r"""
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
back to computing attention in one step.
"""
# set slice_size = `None` to disable `attention slicing`
self.enable_attention_slicing(None)
def enable_sequential_cpu_offload(self, gpu_id=0):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
...
@@ -98,40 +98,6 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
def remove_unused_weights(self):
self.register_modules(text_unet=None)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing with unet->image_unet
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
in several steps. This is useful to save some memory in exchange for a small speed decrease.
Args:
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
`attention_head_dim` must be a multiple of `slice_size`.
"""
if slice_size == "auto":
if isinstance(self.image_unet.config.attention_head_dim, int):
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = self.image_unet.config.attention_head_dim // 2
else:
# if `attention_head_dim` is a list, take the smallest head size
slice_size = min(self.image_unet.config.attention_head_dim)
self.image_unet.set_attention_slice(slice_size)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
def disable_attention_slicing(self):
r"""
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
back to computing attention in one step.
"""
# set slice_size = `None` to disable `attention slicing`
self.enable_attention_slicing(None)
def enable_sequential_cpu_offload(self, gpu_id=0):
r"""
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
...
@@ -334,6 +334,48 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_model_attention_slicing(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["attention_head_dim"] = (8, 16)
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
model.set_attention_slice("auto")
with torch.no_grad():
output = model(**inputs_dict)
assert output is not None
model.set_attention_slice("max")
with torch.no_grad():
output = model(**inputs_dict)
assert output is not None
model.set_attention_slice(2)
with torch.no_grad():
output = model(**inputs_dict)
assert output is not None
def test_model_slicable_head_dim(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["attention_head_dim"] = (8, 16)
model = self.model_class(**init_dict)
def check_slicable_dim_attr(module: torch.nn.Module):
if hasattr(module, "set_attention_slice"):
assert isinstance(module.sliceable_head_dim, int)
for child in module.children():
check_slicable_dim_attr(child)
# retrieve number of attention layers
for module in model.children():
check_slicable_dim_attr(module)
class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
model_class = UNet2DModel
@@ -479,6 +521,84 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
return model
def test_set_attention_slice_auto(self):
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()
torch.cuda.reset_peak_memory_stats()
unet = self.get_unet_model()
unet.set_attention_slice("auto")
latents = self.get_latents(33)
encoder_hidden_states = self.get_encoder_hidden_states(33)
timestep = 1
with torch.no_grad():
_ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample
mem_bytes = torch.cuda.max_memory_allocated()
assert mem_bytes < 5 * 10**9
def test_set_attention_slice_max(self):
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()
torch.cuda.reset_peak_memory_stats()
unet = self.get_unet_model()
unet.set_attention_slice("max")
latents = self.get_latents(33)
encoder_hidden_states = self.get_encoder_hidden_states(33)
timestep = 1
with torch.no_grad():
_ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample
mem_bytes = torch.cuda.max_memory_allocated()
assert mem_bytes < 5 * 10**9
def test_set_attention_slice_int(self):
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()
torch.cuda.reset_peak_memory_stats()
unet = self.get_unet_model()
unet.set_attention_slice(2)
latents = self.get_latents(33)
encoder_hidden_states = self.get_encoder_hidden_states(33)
timestep = 1
with torch.no_grad():
_ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample
mem_bytes = torch.cuda.max_memory_allocated()
assert mem_bytes < 5 * 10**9
def test_set_attention_slice_list(self):
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()
torch.cuda.reset_peak_memory_stats()
# there are 32 slicable layers
slice_list = 16 * [2, 3]
unet = self.get_unet_model()
unet.set_attention_slice(slice_list)
latents = self.get_latents(33)
encoder_hidden_states = self.get_encoder_hidden_states(33)
timestep = 1
with torch.no_grad():
_ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample
mem_bytes = torch.cuda.max_memory_allocated()
assert mem_bytes < 5 * 10**9
def get_encoder_hidden_states(self, seed=0, shape=(4, 77, 768), fp16=False):
dtype = torch.float16 if fp16 else torch.float32
hidden_states = torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
@@ -500,6 +620,8 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
latents = self.get_latents(seed)
encoder_hidden_states = self.get_encoder_hidden_states(seed)
timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device)
with torch.no_grad():
sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample
@@ -526,6 +648,8 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
latents = self.get_latents(seed, fp16=True)
encoder_hidden_states = self.get_encoder_hidden_states(seed, fp16=True)
timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device)
with torch.no_grad():
sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample
@@ -552,6 +676,8 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
latents = self.get_latents(seed)
encoder_hidden_states = self.get_encoder_hidden_states(seed)
timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device)
with torch.no_grad():
sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample
@@ -578,6 +704,8 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
latents = self.get_latents(seed, fp16=True)
encoder_hidden_states = self.get_encoder_hidden_states(seed, fp16=True)
timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device)
with torch.no_grad():
sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample
@@ -604,6 +732,8 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
latents = self.get_latents(seed, shape=(4, 9, 64, 64))
encoder_hidden_states = self.get_encoder_hidden_states(seed)
timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device)
with torch.no_grad():
sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample
@@ -630,6 +760,8 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
latents = self.get_latents(seed, shape=(4, 9, 64, 64), fp16=True)
encoder_hidden_states = self.get_encoder_hidden_states(seed, fp16=True)
timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device)
with torch.no_grad():
sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample
@@ -656,6 +788,8 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
latents = self.get_latents(seed, shape=(4, 4, 96, 96), fp16=True)
encoder_hidden_states = self.get_encoder_hidden_states(seed, shape=(4, 77, 1024), fp16=True)
timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device)
with torch.no_grad():
sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample
...