Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
e780c05c
Unverified
Commit
e780c05c
authored
Aug 16, 2024
by
Sayak Paul
Committed by
GitHub
Aug 16, 2024
Browse files
[Chore] add set_default_attn_processor to pixart. (#9196)
add set_default_attn_processor to pixart.
parent
e649678b
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
9 additions
and
1 deletion
+9
-1
src/diffusers/models/transformers/pixart_transformer_2d.py
src/diffusers/models/transformers/pixart_transformer_2d.py
+9
-1
No files found.
src/diffusers/models/transformers/pixart_transformer_2d.py
View file @
e780c05c
...
@@ -19,7 +19,7 @@ from torch import nn
...
@@ -19,7 +19,7 @@ from torch import nn
from
...configuration_utils
import
ConfigMixin
,
register_to_config
from
...configuration_utils
import
ConfigMixin
,
register_to_config
from
...utils
import
is_torch_version
,
logging
from
...utils
import
is_torch_version
,
logging
from
..attention
import
BasicTransformerBlock
from
..attention
import
BasicTransformerBlock
from
..attention_processor
import
Attention
,
AttentionProcessor
,
FusedAttnProcessor2_0
from
..attention_processor
import
Attention
,
AttentionProcessor
,
AttnProcessor
,
FusedAttnProcessor2_0
from
..embeddings
import
PatchEmbed
,
PixArtAlphaTextProjection
from
..embeddings
import
PatchEmbed
,
PixArtAlphaTextProjection
from
..modeling_outputs
import
Transformer2DModelOutput
from
..modeling_outputs
import
Transformer2DModelOutput
from
..modeling_utils
import
ModelMixin
from
..modeling_utils
import
ModelMixin
...
@@ -247,6 +247,14 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
...
@@ -247,6 +247,14 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
for
name
,
module
in
self
.
named_children
():
for
name
,
module
in
self
.
named_children
():
fn_recursive_attn_processor
(
name
,
module
,
processor
)
fn_recursive_attn_processor
(
name
,
module
,
processor
)
def
set_default_attn_processor
(
self
):
"""
Disables custom attention processors and sets the default attention implementation.
Safe to just use `AttnProcessor()` as PixArt doesn't have any exotic attention processors in default model.
"""
self
.
set_attn_processor
(
AttnProcessor
())
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
def
fuse_qkv_projections
(
self
):
def
fuse_qkv_projections
(
self
):
"""
"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment