Unverified Commit e780c05c authored by Sayak Paul, committed by GitHub

[Chore] add set_default_attn_processor to pixart. (#9196)

add set_default_attn_processor to pixart.
parent e649678b
@@ -19,7 +19,7 @@ from torch import nn
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...utils import is_torch_version, logging
 from ..attention import BasicTransformerBlock
-from ..attention_processor import Attention, AttentionProcessor, FusedAttnProcessor2_0
+from ..attention_processor import Attention, AttentionProcessor, AttnProcessor, FusedAttnProcessor2_0
 from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
@@ -247,6 +247,14 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
         for name, module in self.named_children():
             fn_recursive_attn_processor(name, module, processor)
 
+    def set_default_attn_processor(self):
+        """
+        Disables custom attention processors and sets the default attention implementation.
+
+        Safe to just use `AttnProcessor()` as PixArt doesn't have any exotic attention processors in the default model.
+        """
+        self.set_attn_processor(AttnProcessor())
+
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
     def fuse_qkv_projections(self):
         """
......
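
For context, the new helper mirrors `UNet2DConditionModel.set_default_attn_processor`: it resets every attention module in the transformer to the plain `AttnProcessor`. Below is a minimal sketch of calling it on a toy model; the tiny constructor kwargs are illustrative only and not part of this PR.

```python
# Minimal sketch: resetting PixArt's attention processors to the default.
# The small config here is hypothetical; real checkpoints (e.g.
# PixArt-alpha/PixArt-XL-2-1024-MS) use much larger configs.
from diffusers import PixArtTransformer2DModel

model = PixArtTransformer2DModel(
    num_layers=1,
    num_attention_heads=2,
    attention_head_dim=8,
    cross_attention_dim=16,
    sample_size=8,
)

# On torch >= 2.0, attention modules are created with AttnProcessor2_0.
print({type(p).__name__ for p in model.attn_processors.values()})

# The new helper swaps every attention module back to the vanilla AttnProcessor.
model.set_default_attn_processor()
print({type(p).__name__ for p in model.attn_processors.values()})  # {'AttnProcessor'}
```

This is useful for restoring a known-good baseline after experimenting with custom processors via `set_attn_processor`.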