"...text-generation-inference.git" did not exist on "f171bdc82313e58bf82a05fee1d5c01d2f3f2be0"
Unverified commit a816a87a authored by Benjamin Lefaudeux, committed by GitHub

[refactor] Making the xformers mem-efficient attention activation recursive (#1493)



* Moving the mem efficient attention activation to the top + recursive

* black, too bad there's no pre-commit?
Co-authored-by: Benjamin Lefaudeux <benjamin@photoroom.com>
parent f21415d1
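In short, the refactor replaces the per-pipeline enable/disable methods, which only toggled self.unet (or, in the Versatile Diffusion pipelines, only self.image_unet), with a single recursive setter on the DiffusionPipeline base class. A condensed, standalone sketch of the two behaviours for comparison; the function names are illustrative and the module discovery is simplified, the real implementation is in the pipeline_utils.py hunk below:

import torch

# Before: each pipeline defined its own method and only the UNet was toggled.
def enable_xformers_old(pipeline) -> None:
    pipeline.unet.set_use_memory_efficient_attention_xformers(True)

# After: one method on DiffusionPipeline walks every child module recursively and
# toggles any module that exposes the setter (module discovery simplified here).
def enable_xformers_new(pipeline) -> None:
    def fn_recursive_set_mem_eff(module: torch.nn.Module) -> None:
        if hasattr(module, "set_use_memory_efficient_attention_xformers"):
            module.set_use_memory_efficient_attention_xformers(True)
        for child in module.children():
            fn_recursive_set_mem_eff(child)

    for component in vars(pipeline).values():
        if isinstance(component, torch.nn.Module):
            fn_recursive_set_mem_eff(component)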
@@ -488,24 +488,6 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
feature_extractor=feature_extractor,
)
def enable_xformers_memory_efficient_attention(self):
r"""
Enable memory efficient attention as implemented in xformers.
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
time. Speed up at training time is not guaranteed.
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.
"""
self.unet.set_use_memory_efficient_attention_xformers(True)
def disable_xformers_memory_efficient_attention(self):
r"""
Disable memory efficient attention as implemented in xformers.
"""
self.unet.set_use_memory_efficient_attention_xformers(False)
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.
...
@@ -106,24 +106,6 @@ class StableDiffusionPipeline(DiffusionPipeline):
sampling = getattr(library, "sampling")
self.sampler = getattr(sampling, scheduler_type)
def enable_xformers_memory_efficient_attention(self):
r"""
Enable memory efficient attention as implemented in xformers.
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
time. Speed up at training time is not guaranteed.
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.
"""
self.unet.set_use_memory_efficient_attention_xformers(True)
def disable_xformers_memory_efficient_attention(self):
r"""
Disable memory efficient attention as implemented in xformers.
"""
self.unet.set_use_memory_efficient_attention_xformers(False)
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.
...
@@ -183,24 +183,6 @@ class TextInpainting(DiffusionPipeline):
return torch.device(module._hf_hook.execution_device)
return self.device
def enable_xformers_memory_efficient_attention(self):
r"""
Enable memory efficient attention as implemented in xformers.
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
time. Speed up at training time is not guaranteed.
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.
"""
self.unet.set_use_memory_efficient_attention_xformers(True)
def disable_xformers_memory_efficient_attention(self):
r"""
Disable memory efficient attention as implemented in xformers.
"""
self.unet.set_use_memory_efficient_attention_xformers(False)
@torch.no_grad()
def __call__(
self,
...
@@ -246,10 +246,6 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
return Transformer2DModelOutput(sample=output)
def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
for block in self.transformer_blocks:
block._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
class AttentionBlock(nn.Module):
"""
@@ -414,7 +410,7 @@ class BasicTransformerBlock(nn.Module):
# if xformers is installed try to use memory_efficient_attention by default
if is_xformers_available():
try:
- self._set_use_memory_efficient_attention_xformers(True)
+ self.set_use_memory_efficient_attention_xformers(True)
except Exception as e:
warnings.warn(
"Could not enable memory efficient attention. Make sure xformers is installed"
@@ -425,7 +421,7 @@ class BasicTransformerBlock(nn.Module):
self.attn1._slice_size = slice_size
self.attn2._slice_size = slice_size
- def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
+ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
if not is_xformers_available():
print("Here is how to install it")
raise ModuleNotFoundError(
@@ -835,11 +831,3 @@ class DualTransformer2DModel(nn.Module):
return (output_states,)
return Transformer2DModelOutput(sample=output_states)
def _set_attention_slice(self, slice_size):
for transformer in self.transformers:
transformer._set_attention_slice(slice_size)
def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
for transformer in self.transformers:
transformer._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
@@ -418,10 +418,6 @@ class UNetMidBlock2DCrossAttn(nn.Module):
for attn in self.attentions:
attn._set_attention_slice(slice_size)
def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
for attn in self.attentions:
attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
@@ -616,10 +612,6 @@ class CrossAttnDownBlock2D(nn.Module):
for attn in self.attentions:
attn._set_attention_slice(slice_size)
def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
for attn in self.attentions:
attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
output_states = ()
@@ -1217,10 +1209,6 @@ class CrossAttnUpBlock2D(nn.Module):
self.gradient_checkpointing = False
def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
for attn in self.attentions:
attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
def forward(
self,
hidden_states,
...
@@ -252,17 +252,6 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
if hasattr(block, "attentions") and block.attentions is not None:
block.set_attention_slice(slice_size)
def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
for block in self.down_blocks:
if hasattr(block, "attentions") and block.attentions is not None:
block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
self.mid_block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
for block in self.up_blocks:
if hasattr(block, "attentions") and block.attentions is not None:
block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
module.gradient_checkpointing = value
...
@@ -789,3 +789,38 @@ class DiffusionPipeline(ConfigMixin):
def set_progress_bar_config(self, **kwargs):
self._progress_bar_config = kwargs
def enable_xformers_memory_efficient_attention(self):
r"""
Enable memory efficient attention as implemented in xformers.
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
time. Speed up at training time is not guaranteed.
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.
"""
self.set_use_memory_efficient_attention_xformers(True)
def disable_xformers_memory_efficient_attention(self):
r"""
Disable memory efficient attention as implemented in xformers.
"""
self.set_use_memory_efficient_attention_xformers(False)
def set_use_memory_efficient_attention_xformers(self, valid: bool) -> None:
# Recursively walk through all the children.
# Any children which exposes the set_use_memory_efficient_attention_xformers method
# gets the message
def fn_recursive_set_mem_eff(module: torch.nn.Module):
if hasattr(module, "set_use_memory_efficient_attention_xformers"):
module.set_use_memory_efficient_attention_xformers(valid)
for child in module.children():
fn_recursive_set_mem_eff(child)
module_names, _, _ = self.extract_init_dict(dict(self.config))
for module_name in module_names:
module = getattr(self, module_name)
if isinstance(module, torch.nn.Module):
fn_recursive_set_mem_eff(module)
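For context, a minimal usage sketch of the new entry point; the checkpoint id is illustrative and xformers must be installed for the flag to have any effect:

import torch
from diffusers import StableDiffusionPipeline

# Illustrative checkpoint; any DiffusionPipeline subclass behaves the same way now.
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

# With this change the call recurses through every registered nn.Module (unet, vae,
# text_encoder, ...) and toggles each submodule that exposes
# set_use_memory_efficient_attention_xformers, instead of touching only self.unet.
pipe.enable_xformers_memory_efficient_attention()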
@@ -166,24 +166,6 @@ class AltDiffusionPipeline(DiffusionPipeline):
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)
def enable_xformers_memory_efficient_attention(self):
r"""
Enable memory efficient attention as implemented in xformers.
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
time. Speed up at training time is not guaranteed.
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.
"""
self.unet.set_use_memory_efficient_attention_xformers(True)
def disable_xformers_memory_efficient_attention(self):
r"""
Disable memory efficient attention as implemented in xformers.
"""
self.unet.set_use_memory_efficient_attention_xformers(False)
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.
...
@@ -251,24 +251,6 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
return torch.device(module._hf_hook.execution_device)
return self.device
def enable_xformers_memory_efficient_attention(self):
r"""
Enable memory efficient attention as implemented in xformers.
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
time. Speed up at training time is not guaranteed.
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.
"""
self.unet.set_use_memory_efficient_attention_xformers(True)
def disable_xformers_memory_efficient_attention(self):
r"""
Disable memory efficient attention as implemented in xformers.
"""
self.unet.set_use_memory_efficient_attention_xformers(False)
def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
r"""
Encodes the prompt into text encoder hidden states.
...
@@ -285,26 +285,6 @@ class CycleDiffusionPipeline(DiffusionPipeline):
return torch.device(module._hf_hook.execution_device)
return self.device
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention
def enable_xformers_memory_efficient_attention(self):
r"""
Enable memory efficient attention as implemented in xformers.
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
time. Speed up at training time is not guaranteed.
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.
"""
self.unet.set_use_memory_efficient_attention_xformers(True)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_xformers_memory_efficient_attention
def disable_xformers_memory_efficient_attention(self):
r"""
Disable memory efficient attention as implemented in xformers.
"""
self.unet.set_use_memory_efficient_attention_xformers(False)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
r"""
...
@@ -165,24 +165,6 @@ class StableDiffusionPipeline(DiffusionPipeline):
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)
def enable_xformers_memory_efficient_attention(self):
r"""
Enable memory efficient attention as implemented in xformers.
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
time. Speed up at training time is not guaranteed.
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.
"""
self.unet.set_use_memory_efficient_attention_xformers(True)
def disable_xformers_memory_efficient_attention(self):
r"""
Disable memory efficient attention as implemented in xformers.
"""
self.unet.set_use_memory_efficient_attention_xformers(False)
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.
...
@@ -134,26 +134,6 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline):
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention
def enable_xformers_memory_efficient_attention(self):
r"""
Enable memory efficient attention as implemented in xformers.
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
time. Speed up at training time is not guaranteed.
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.
"""
self.unet.set_use_memory_efficient_attention_xformers(True)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_xformers_memory_efficient_attention
def disable_xformers_memory_efficient_attention(self):
r"""
Disable memory efficient attention as implemented in xformers.
"""
self.unet.set_use_memory_efficient_attention_xformers(False)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
...
@@ -254,26 +254,6 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
return torch.device(module._hf_hook.execution_device)
return self.device
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention
def enable_xformers_memory_efficient_attention(self):
r"""
Enable memory efficient attention as implemented in xformers.
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
time. Speed up at training time is not guaranteed.
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.
"""
self.unet.set_use_memory_efficient_attention_xformers(True)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_xformers_memory_efficient_attention
def disable_xformers_memory_efficient_attention(self):
r"""
Disable memory efficient attention as implemented in xformers.
"""
self.unet.set_use_memory_efficient_attention_xformers(False)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
r"""
...
@@ -300,26 +300,6 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
# fix by only offloading self.safety_checker for now
cpu_offload(self.safety_checker.vision_model, device)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention
def enable_xformers_memory_efficient_attention(self):
r"""
Enable memory efficient attention as implemented in xformers.
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
time. Speed up at training time is not guaranteed.
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.
"""
self.unet.set_use_memory_efficient_attention_xformers(True)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_xformers_memory_efficient_attention
def disable_xformers_memory_efficient_attention(self):
r"""
Disable memory efficient attention as implemented in xformers.
"""
self.unet.set_use_memory_efficient_attention_xformers(False)
@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
def _execution_device(self):
...
@@ -248,26 +248,6 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
# fix by only offloading self.safety_checker for now
cpu_offload(self.safety_checker.vision_model, device)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention
def enable_xformers_memory_efficient_attention(self):
r"""
Enable memory efficient attention as implemented in xformers.
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
time. Speed up at training time is not guaranteed.
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.
"""
self.unet.set_use_memory_efficient_attention_xformers(True)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_xformers_memory_efficient_attention
def disable_xformers_memory_efficient_attention(self):
r"""
Disable memory efficient attention as implemented in xformers.
"""
self.unet.set_use_memory_efficient_attention_xformers(False)
@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
def _execution_device(self):
...
@@ -143,26 +143,6 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
if cpu_offloaded_model is not None:
cpu_offload(cpu_offloaded_model, device)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention
def enable_xformers_memory_efficient_attention(self):
r"""
Enable memory efficient attention as implemented in xformers.
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
time. Speed up at training time is not guaranteed.
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.
"""
self.unet.set_use_memory_efficient_attention_xformers(True)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_xformers_memory_efficient_attention
def disable_xformers_memory_efficient_attention(self):
r"""
Disable memory efficient attention as implemented in xformers.
"""
self.unet.set_use_memory_efficient_attention_xformers(False)
@property
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
def _execution_device(self):
...
@@ -182,24 +182,6 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
"""
self._safety_text_concept = concept
def enable_xformers_memory_efficient_attention(self):
r"""
Enable memory efficient attention as implemented in xformers.
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
time. Speed up at training time is not guaranteed.
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.
"""
self.unet.set_use_memory_efficient_attention_xformers(True)
def disable_xformers_memory_efficient_attention(self):
r"""
Disable memory efficient attention as implemented in xformers.
"""
self.unet.set_use_memory_efficient_attention_xformers(False)
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.
...
@@ -330,17 +330,6 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
if hasattr(block, "attentions") and block.attentions is not None:
block.set_attention_slice(slice_size)
def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
for block in self.down_blocks:
if hasattr(block, "attentions") and block.attentions is not None:
block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
self.mid_block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
for block in self.up_blocks:
if hasattr(block, "attentions") and block.attentions is not None:
block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (CrossAttnDownBlockFlat, DownBlockFlat, CrossAttnUpBlockFlat, UpBlockFlat)):
module.gradient_checkpointing = value
@@ -761,10 +750,6 @@ class CrossAttnDownBlockFlat(nn.Module):
for attn in self.attentions:
attn._set_attention_slice(slice_size)
def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
for attn in self.attentions:
attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
output_states = ()
@@ -976,10 +961,6 @@ class CrossAttnUpBlockFlat(nn.Module):
self.gradient_checkpointing = False
def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
for attn in self.attentions:
attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
def forward(
self,
hidden_states,
@@ -1122,10 +1103,6 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
for attn in self.attentions:
attn._set_attention_slice(slice_size)
def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
for attn in self.attentions:
attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
...
@@ -147,26 +147,6 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
self.image_unet.register_to_config(dual_cross_attention=False)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention with unet->image_unet
def enable_xformers_memory_efficient_attention(self):
r"""
Enable memory efficient attention as implemented in xformers.
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
time. Speed up at training time is not guaranteed.
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.
"""
self.image_unet.set_use_memory_efficient_attention_xformers(True)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_xformers_memory_efficient_attention with unet->image_unet
def disable_xformers_memory_efficient_attention(self):
r"""
Disable memory efficient attention as implemented in xformers.
"""
self.image_unet.set_use_memory_efficient_attention_xformers(False)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing with unet->image_unet
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
...
@@ -73,26 +73,6 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention with unet->image_unet
def enable_xformers_memory_efficient_attention(self):
r"""
Enable memory efficient attention as implemented in xformers.
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
time. Speed up at training time is not guaranteed.
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.
"""
self.image_unet.set_use_memory_efficient_attention_xformers(True)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_xformers_memory_efficient_attention with unet->image_unet
def disable_xformers_memory_efficient_attention(self):
r"""
Disable memory efficient attention as implemented in xformers.
"""
self.image_unet.set_use_memory_efficient_attention_xformers(False)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing with unet->image_unet
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
...
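One side effect of the hasattr-based recursion worth noting: a custom component held by a pipeline can opt in simply by exposing the setter, without being listed anywhere. A minimal sketch with a hypothetical module (the actual xformers kernel switch is elided):

import torch


class MyCrossAttention(torch.nn.Module):
    # Hypothetical attention module that the recursive walk in
    # DiffusionPipeline.set_use_memory_efficient_attention_xformers can discover.

    def __init__(self):
        super().__init__()
        self._use_memory_efficient_attention_xformers = False

    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
        # A real module would switch its attention computation to
        # xformers.ops.memory_efficient_attention here; this sketch only records the flag.
        self._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers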