Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
6b68afd8
Unverified
Commit
6b68afd8
authored
Dec 09, 2022
by
Patrick von Platen
Committed by
GitHub
Dec 09, 2022
Browse files
do not automatically enable xformers (#1640)
* do not automatically enable xformers * uP
parent
63c49449
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
30 additions
and
11 deletions
+30
-11
examples/dreambooth/train_dreambooth.py
examples/dreambooth/train_dreambooth.py
+10
-0
examples/text_to_image/train_text_to_image.py
examples/text_to_image/train_text_to_image.py
+10
-0
examples/textual_inversion/textual_inversion.py
examples/textual_inversion/textual_inversion.py
+10
-0
src/diffusers/models/attention.py
src/diffusers/models/attention.py
+0
-11
No files found.
examples/dreambooth/train_dreambooth.py
View file @
6b68afd8
...
@@ -17,6 +17,7 @@ from accelerate.utils import set_seed
...
@@ -17,6 +17,7 @@ from accelerate.utils import set_seed
from
diffusers
import
AutoencoderKL
,
DDPMScheduler
,
DiffusionPipeline
,
UNet2DConditionModel
from
diffusers
import
AutoencoderKL
,
DDPMScheduler
,
DiffusionPipeline
,
UNet2DConditionModel
from
diffusers.optimization
import
get_scheduler
from
diffusers.optimization
import
get_scheduler
from
diffusers.utils
import
check_min_version
from
diffusers.utils
import
check_min_version
from
diffusers.utils.import_utils
import
is_xformers_available
from
huggingface_hub
import
HfFolder
,
Repository
,
whoami
from
huggingface_hub
import
HfFolder
,
Repository
,
whoami
from
PIL
import
Image
from
PIL
import
Image
from
torchvision
import
transforms
from
torchvision
import
transforms
...
@@ -488,6 +489,15 @@ def main(args):
...
@@ -488,6 +489,15 @@ def main(args):
revision
=
args
.
revision
,
revision
=
args
.
revision
,
)
)
if
is_xformers_available
():
try
:
unet
.
enable_xformers_memory_efficient_attention
(
True
)
except
Exception
as
e
:
logger
.
warning
(
"Could not enable memory efficient attention. Make sure xformers is installed"
f
" correctly and a GPU is available:
{
e
}
"
)
vae
.
requires_grad_
(
False
)
vae
.
requires_grad_
(
False
)
if
not
args
.
train_text_encoder
:
if
not
args
.
train_text_encoder
:
text_encoder
.
requires_grad_
(
False
)
text_encoder
.
requires_grad_
(
False
)
...
...
examples/text_to_image/train_text_to_image.py
View file @
6b68afd8
...
@@ -18,6 +18,7 @@ from datasets import load_dataset
...
@@ -18,6 +18,7 @@ from datasets import load_dataset
from
diffusers
import
AutoencoderKL
,
DDPMScheduler
,
StableDiffusionPipeline
,
UNet2DConditionModel
from
diffusers
import
AutoencoderKL
,
DDPMScheduler
,
StableDiffusionPipeline
,
UNet2DConditionModel
from
diffusers.optimization
import
get_scheduler
from
diffusers.optimization
import
get_scheduler
from
diffusers.utils
import
check_min_version
from
diffusers.utils
import
check_min_version
from
diffusers.utils.import_utils
import
is_xformers_available
from
huggingface_hub
import
HfFolder
,
Repository
,
whoami
from
huggingface_hub
import
HfFolder
,
Repository
,
whoami
from
torchvision
import
transforms
from
torchvision
import
transforms
from
tqdm.auto
import
tqdm
from
tqdm.auto
import
tqdm
...
@@ -364,6 +365,15 @@ def main():
...
@@ -364,6 +365,15 @@ def main():
revision
=
args
.
revision
,
revision
=
args
.
revision
,
)
)
if
is_xformers_available
():
try
:
unet
.
enable_xformers_memory_efficient_attention
(
True
)
except
Exception
as
e
:
logger
.
warning
(
"Could not enable memory efficient attention. Make sure xformers is installed"
f
" correctly and a GPU is available:
{
e
}
"
)
# Freeze vae and text_encoder
# Freeze vae and text_encoder
vae
.
requires_grad_
(
False
)
vae
.
requires_grad_
(
False
)
text_encoder
.
requires_grad_
(
False
)
text_encoder
.
requires_grad_
(
False
)
...
...
examples/textual_inversion/textual_inversion.py
View file @
6b68afd8
...
@@ -20,6 +20,7 @@ from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusi
...
@@ -20,6 +20,7 @@ from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusi
from
diffusers.optimization
import
get_scheduler
from
diffusers.optimization
import
get_scheduler
from
diffusers.pipelines.stable_diffusion
import
StableDiffusionSafetyChecker
from
diffusers.pipelines.stable_diffusion
import
StableDiffusionSafetyChecker
from
diffusers.utils
import
check_min_version
from
diffusers.utils
import
check_min_version
from
diffusers.utils.import_utils
import
is_xformers_available
from
huggingface_hub
import
HfFolder
,
Repository
,
whoami
from
huggingface_hub
import
HfFolder
,
Repository
,
whoami
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
...
@@ -439,6 +440,15 @@ def main():
...
@@ -439,6 +440,15 @@ def main():
revision
=
args
.
revision
,
revision
=
args
.
revision
,
)
)
if
is_xformers_available
():
try
:
unet
.
enable_xformers_memory_efficient_attention
(
True
)
except
Exception
as
e
:
logger
.
warning
(
"Could not enable memory efficient attention. Make sure xformers is installed"
f
" correctly and a GPU is available:
{
e
}
"
)
# Resize the token embeddings as we are adding new special tokens to the tokenizer
# Resize the token embeddings as we are adding new special tokens to the tokenizer
text_encoder
.
resize_token_embeddings
(
len
(
tokenizer
))
text_encoder
.
resize_token_embeddings
(
len
(
tokenizer
))
...
...
src/diffusers/models/attention.py
View file @
6b68afd8
...
@@ -12,7 +12,6 @@
...
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
math
import
math
import
warnings
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Optional
from
typing
import
Optional
...
@@ -447,16 +446,6 @@ class BasicTransformerBlock(nn.Module):
...
@@ -447,16 +446,6 @@ class BasicTransformerBlock(nn.Module):
# 3. Feed-forward
# 3. Feed-forward
self
.
norm3
=
nn
.
LayerNorm
(
dim
)
self
.
norm3
=
nn
.
LayerNorm
(
dim
)
# if xformers is installed try to use memory_efficient_attention by default
if
is_xformers_available
():
try
:
self
.
set_use_memory_efficient_attention_xformers
(
True
)
except
Exception
as
e
:
warnings
.
warn
(
"Could not enable memory efficient attention. Make sure xformers is installed"
f
" correctly and a GPU is available:
{
e
}
"
)
def
set_use_memory_efficient_attention_xformers
(
self
,
use_memory_efficient_attention_xformers
:
bool
):
def
set_use_memory_efficient_attention_xformers
(
self
,
use_memory_efficient_attention_xformers
:
bool
):
if
not
is_xformers_available
():
if
not
is_xformers_available
():
print
(
"Here is how to install it"
)
print
(
"Here is how to install it"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment