Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
5e5ce13e
Unverified
Commit
5e5ce13e
authored
Mar 03, 2023
by
Alex McKinney
Committed by
GitHub
Mar 03, 2023
Browse files
adds `xformers` support to `train_unconditional.py` (#2520)
parent
7f0f7e1e
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
17 additions
and
0 deletions
+17
-0
examples/unconditional_image_generation/train_unconditional.py
...les/unconditional_image_generation/train_unconditional.py
+17
-0
No files found.
examples/unconditional_image_generation/train_unconditional.py
View file @
5e5ce13e
...
...
@@ -24,6 +24,7 @@ from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
from
diffusers.optimization
import
get_scheduler
from
diffusers.training_utils
import
EMAModel
from
diffusers.utils
import
check_min_version
,
is_accelerate_version
,
is_tensorboard_available
,
is_wandb_available
from
diffusers.utils.import_utils
import
is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
...
...
@@ -259,6 +260,9 @@ def parse_args():
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
),
)
parser
.
add_argument
(
"--enable_xformers_memory_efficient_attention"
,
action
=
"store_true"
,
help
=
"Whether or not to use xformers."
)
args
=
parser
.
parse_args
()
env_local_rank
=
int
(
os
.
environ
.
get
(
"LOCAL_RANK"
,
-
1
))
...
...
@@ -410,6 +414,19 @@ def main(args):
model_config
=
model
.
config
,
)
if
args
.
enable_xformers_memory_efficient_attention
:
if
is_xformers_available
():
import
xformers
xformers_version
=
version
.
parse
(
xformers
.
__version__
)
if
xformers_version
==
version
.
parse
(
"0.0.16"
):
logger
.
warn
(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
model
.
enable_xformers_memory_efficient_attention
()
else
:
raise
ValueError
(
"xformers is not available. Make sure it is installed correctly"
)
# Initialize the scheduler
accepts_prediction_type
=
"prediction_type"
in
set
(
inspect
.
signature
(
DDPMScheduler
.
__init__
).
parameters
.
keys
())
if
accepts_prediction_type
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment