Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
5ea4be86
Unverified
Commit
5ea4be86
authored
Jan 20, 2023
by
Lucain
Committed by
GitHub
Jan 20, 2023
Browse files
Create repo before cloning in examples (#2047)
* Create repo before cloning in examples * code quality
parent
e5ff7554
Changes
13
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
38 additions
and
26 deletions
+38
-26
examples/dreambooth/train_dreambooth.py
examples/dreambooth/train_dreambooth.py
+3
-2
examples/dreambooth/train_dreambooth_flax.py
examples/dreambooth/train_dreambooth_flax.py
+3
-2
examples/dreambooth/train_dreambooth_lora.py
examples/dreambooth/train_dreambooth_lora.py
+2
-2
examples/research_projects/colossalai/train_dreambooth_colossalai.py
...search_projects/colossalai/train_dreambooth_colossalai.py
+3
-2
examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint.py
...h_projects/dreambooth_inpaint/train_dreambooth_inpaint.py
+3
-2
examples/research_projects/intel_opts/textual_inversion/textual_inversion_bf16.py
...ts/intel_opts/textual_inversion/textual_inversion_bf16.py
+3
-2
examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py
...ulti_subject_dreambooth/train_multi_subject_dreambooth.py
+3
-2
examples/text_to_image/train_text_to_image.py
examples/text_to_image/train_text_to_image.py
+3
-2
examples/text_to_image/train_text_to_image_flax.py
examples/text_to_image/train_text_to_image_flax.py
+3
-2
examples/textual_inversion/textual_inversion.py
examples/textual_inversion/textual_inversion.py
+3
-2
examples/textual_inversion/textual_inversion_flax.py
examples/textual_inversion/textual_inversion_flax.py
+3
-2
examples/unconditional_image_generation/train_unconditional.py
...les/unconditional_image_generation/train_unconditional.py
+3
-2
examples/unconditional_image_generation/train_unconditional_ort.py
...unconditional_image_generation/train_unconditional_ort.py
+3
-2
No files found.
examples/dreambooth/train_dreambooth.py
View file @
5ea4be86
...
@@ -38,7 +38,7 @@ from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DCon
...
@@ -38,7 +38,7 @@ from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DCon
from
diffusers.optimization
import
get_scheduler
from
diffusers.optimization
import
get_scheduler
from
diffusers.utils
import
check_min_version
from
diffusers.utils
import
check_min_version
from
diffusers.utils.import_utils
import
is_xformers_available
from
diffusers.utils.import_utils
import
is_xformers_available
from
huggingface_hub
import
HfFolder
,
Repository
,
whoami
from
huggingface_hub
import
HfFolder
,
Repository
,
create_repo
,
whoami
from
PIL
import
Image
from
PIL
import
Image
from
torchvision
import
transforms
from
torchvision
import
transforms
from
tqdm.auto
import
tqdm
from
tqdm.auto
import
tqdm
...
@@ -551,7 +551,8 @@ def main(args):
...
@@ -551,7 +551,8 @@ def main(args):
repo_name
=
get_full_repo_name
(
Path
(
args
.
output_dir
).
name
,
token
=
args
.
hub_token
)
repo_name
=
get_full_repo_name
(
Path
(
args
.
output_dir
).
name
,
token
=
args
.
hub_token
)
else
:
else
:
repo_name
=
args
.
hub_model_id
repo_name
=
args
.
hub_model_id
repo
=
Repository
(
args
.
output_dir
,
clone_from
=
repo_name
)
create_repo
(
repo_name
,
exist_ok
=
True
,
token
=
args
.
hub_token
)
repo
=
Repository
(
args
.
output_dir
,
clone_from
=
repo_name
,
token
=
args
.
hub_token
)
with
open
(
os
.
path
.
join
(
args
.
output_dir
,
".gitignore"
),
"w+"
)
as
gitignore
:
with
open
(
os
.
path
.
join
(
args
.
output_dir
,
".gitignore"
),
"w+"
)
as
gitignore
:
if
"step_*"
not
in
gitignore
:
if
"step_*"
not
in
gitignore
:
...
...
examples/dreambooth/train_dreambooth_flax.py
View file @
5ea4be86
...
@@ -27,7 +27,7 @@ from diffusers.utils import check_min_version
...
@@ -27,7 +27,7 @@ from diffusers.utils import check_min_version
from
flax
import
jax_utils
from
flax
import
jax_utils
from
flax.training
import
train_state
from
flax.training
import
train_state
from
flax.training.common_utils
import
shard
from
flax.training.common_utils
import
shard
from
huggingface_hub
import
HfFolder
,
Repository
,
whoami
from
huggingface_hub
import
HfFolder
,
Repository
,
create_repo
,
whoami
from
jax.experimental.compilation_cache
import
compilation_cache
as
cc
from
jax.experimental.compilation_cache
import
compilation_cache
as
cc
from
PIL
import
Image
from
PIL
import
Image
from
torchvision
import
transforms
from
torchvision
import
transforms
...
@@ -387,7 +387,8 @@ def main():
...
@@ -387,7 +387,8 @@ def main():
repo_name
=
get_full_repo_name
(
Path
(
args
.
output_dir
).
name
,
token
=
args
.
hub_token
)
repo_name
=
get_full_repo_name
(
Path
(
args
.
output_dir
).
name
,
token
=
args
.
hub_token
)
else
:
else
:
repo_name
=
args
.
hub_model_id
repo_name
=
args
.
hub_model_id
repo
=
Repository
(
args
.
output_dir
,
clone_from
=
repo_name
)
create_repo
(
repo_name
,
exist_ok
=
True
,
token
=
args
.
hub_token
)
repo
=
Repository
(
args
.
output_dir
,
clone_from
=
repo_name
,
token
=
args
.
hub_token
)
with
open
(
os
.
path
.
join
(
args
.
output_dir
,
".gitignore"
),
"w+"
)
as
gitignore
:
with
open
(
os
.
path
.
join
(
args
.
output_dir
,
".gitignore"
),
"w+"
)
as
gitignore
:
if
"step_*"
not
in
gitignore
:
if
"step_*"
not
in
gitignore
:
...
...
examples/dreambooth/train_dreambooth_lora.py
View file @
5ea4be86
...
@@ -599,8 +599,8 @@ def main(args):
...
@@ -599,8 +599,8 @@ def main(args):
else
:
else
:
repo_name
=
args
.
hub_model_id
repo_name
=
args
.
hub_model_id
repo_name
=
create_repo
(
repo_name
,
exist_ok
=
True
)
create_repo
(
repo_name
,
exist_ok
=
True
,
token
=
args
.
hub_token
)
repo
=
Repository
(
args
.
output_dir
,
clone_from
=
repo_name
)
repo
=
Repository
(
args
.
output_dir
,
clone_from
=
repo_name
,
token
=
args
.
hub_token
)
with
open
(
os
.
path
.
join
(
args
.
output_dir
,
".gitignore"
),
"w+"
)
as
gitignore
:
with
open
(
os
.
path
.
join
(
args
.
output_dir
,
".gitignore"
),
"w+"
)
as
gitignore
:
if
"step_*"
not
in
gitignore
:
if
"step_*"
not
in
gitignore
:
...
...
examples/research_projects/colossalai/train_dreambooth_colossalai.py
View file @
5ea4be86
...
@@ -20,7 +20,7 @@ from colossalai.utils import get_current_device
...
@@ -20,7 +20,7 @@ from colossalai.utils import get_current_device
from
colossalai.utils.model.colo_init_context
import
ColoInitContext
from
colossalai.utils.model.colo_init_context
import
ColoInitContext
from
diffusers
import
AutoencoderKL
,
DDPMScheduler
,
DiffusionPipeline
,
UNet2DConditionModel
from
diffusers
import
AutoencoderKL
,
DDPMScheduler
,
DiffusionPipeline
,
UNet2DConditionModel
from
diffusers.optimization
import
get_scheduler
from
diffusers.optimization
import
get_scheduler
from
huggingface_hub
import
HfFolder
,
Repository
,
whoami
from
huggingface_hub
import
HfFolder
,
Repository
,
create_repo
,
whoami
from
PIL
import
Image
from
PIL
import
Image
from
torchvision
import
transforms
from
torchvision
import
transforms
from
tqdm.auto
import
tqdm
from
tqdm.auto
import
tqdm
...
@@ -420,7 +420,8 @@ def main(args):
...
@@ -420,7 +420,8 @@ def main(args):
repo_name
=
get_full_repo_name
(
Path
(
args
.
output_dir
).
name
,
token
=
args
.
hub_token
)
repo_name
=
get_full_repo_name
(
Path
(
args
.
output_dir
).
name
,
token
=
args
.
hub_token
)
else
:
else
:
repo_name
=
args
.
hub_model_id
repo_name
=
args
.
hub_model_id
repo
=
Repository
(
args
.
output_dir
,
clone_from
=
repo_name
)
create_repo
(
repo_name
,
exist_ok
=
True
,
token
=
args
.
hub_token
)
repo
=
Repository
(
args
.
output_dir
,
clone_from
=
repo_name
,
token
=
args
.
hub_token
)
with
open
(
os
.
path
.
join
(
args
.
output_dir
,
".gitignore"
),
"w+"
)
as
gitignore
:
with
open
(
os
.
path
.
join
(
args
.
output_dir
,
".gitignore"
),
"w+"
)
as
gitignore
:
if
"step_*"
not
in
gitignore
:
if
"step_*"
not
in
gitignore
:
...
...
examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint.py
View file @
5ea4be86
...
@@ -25,7 +25,7 @@ from diffusers import (
...
@@ -25,7 +25,7 @@ from diffusers import (
)
)
from
diffusers.optimization
import
get_scheduler
from
diffusers.optimization
import
get_scheduler
from
diffusers.utils
import
check_min_version
from
diffusers.utils
import
check_min_version
from
huggingface_hub
import
HfFolder
,
Repository
,
whoami
from
huggingface_hub
import
HfFolder
,
Repository
,
create_repo
,
whoami
from
PIL
import
Image
,
ImageDraw
from
PIL
import
Image
,
ImageDraw
from
torchvision
import
transforms
from
torchvision
import
transforms
from
tqdm.auto
import
tqdm
from
tqdm.auto
import
tqdm
...
@@ -471,7 +471,8 @@ def main():
...
@@ -471,7 +471,8 @@ def main():
repo_name
=
get_full_repo_name
(
Path
(
args
.
output_dir
).
name
,
token
=
args
.
hub_token
)
repo_name
=
get_full_repo_name
(
Path
(
args
.
output_dir
).
name
,
token
=
args
.
hub_token
)
else
:
else
:
repo_name
=
args
.
hub_model_id
repo_name
=
args
.
hub_model_id
repo
=
Repository
(
args
.
output_dir
,
clone_from
=
repo_name
)
create_repo
(
repo_name
,
exist_ok
=
True
,
token
=
args
.
hub_token
)
repo
=
Repository
(
args
.
output_dir
,
clone_from
=
repo_name
,
token
=
args
.
hub_token
)
with
open
(
os
.
path
.
join
(
args
.
output_dir
,
".gitignore"
),
"w+"
)
as
gitignore
:
with
open
(
os
.
path
.
join
(
args
.
output_dir
,
".gitignore"
),
"w+"
)
as
gitignore
:
if
"step_*"
not
in
gitignore
:
if
"step_*"
not
in
gitignore
:
...
...
examples/research_projects/intel_opts/textual_inversion/textual_inversion_bf16.py
View file @
5ea4be86
...
@@ -21,7 +21,7 @@ from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusi
...
@@ -21,7 +21,7 @@ from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusi
from
diffusers.optimization
import
get_scheduler
from
diffusers.optimization
import
get_scheduler
from
diffusers.pipelines.stable_diffusion
import
StableDiffusionSafetyChecker
from
diffusers.pipelines.stable_diffusion
import
StableDiffusionSafetyChecker
from
diffusers.utils
import
check_min_version
from
diffusers.utils
import
check_min_version
from
huggingface_hub
import
HfFolder
,
Repository
,
whoami
from
huggingface_hub
import
HfFolder
,
Repository
,
create_repo
,
whoami
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
from
packaging
import
version
from
packaging
import
version
...
@@ -393,7 +393,8 @@ def main():
...
@@ -393,7 +393,8 @@ def main():
repo_name
=
get_full_repo_name
(
Path
(
args
.
output_dir
).
name
,
token
=
args
.
hub_token
)
repo_name
=
get_full_repo_name
(
Path
(
args
.
output_dir
).
name
,
token
=
args
.
hub_token
)
else
:
else
:
repo_name
=
args
.
hub_model_id
repo_name
=
args
.
hub_model_id
repo
=
Repository
(
args
.
output_dir
,
clone_from
=
repo_name
)
create_repo
(
repo_name
,
exist_ok
=
True
,
token
=
args
.
hub_token
)
repo
=
Repository
(
args
.
output_dir
,
clone_from
=
repo_name
,
token
=
args
.
hub_token
)
with
open
(
os
.
path
.
join
(
args
.
output_dir
,
".gitignore"
),
"w+"
)
as
gitignore
:
with
open
(
os
.
path
.
join
(
args
.
output_dir
,
".gitignore"
),
"w+"
)
as
gitignore
:
if
"step_*"
not
in
gitignore
:
if
"step_*"
not
in
gitignore
:
...
...
examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py
View file @
5ea4be86
...
@@ -23,7 +23,7 @@ from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DCon
...
@@ -23,7 +23,7 @@ from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DCon
from
diffusers.optimization
import
get_scheduler
from
diffusers.optimization
import
get_scheduler
from
diffusers.utils
import
check_min_version
from
diffusers.utils
import
check_min_version
from
diffusers.utils.import_utils
import
is_xformers_available
from
diffusers.utils.import_utils
import
is_xformers_available
from
huggingface_hub
import
HfFolder
,
Repository
,
whoami
from
huggingface_hub
import
HfFolder
,
Repository
,
create_repo
,
whoami
from
PIL
import
Image
from
PIL
import
Image
from
torchvision
import
transforms
from
torchvision
import
transforms
from
tqdm.auto
import
tqdm
from
tqdm.auto
import
tqdm
...
@@ -570,7 +570,8 @@ def main(args):
...
@@ -570,7 +570,8 @@ def main(args):
repo_name
=
get_full_repo_name
(
Path
(
args
.
output_dir
).
name
,
token
=
args
.
hub_token
)
repo_name
=
get_full_repo_name
(
Path
(
args
.
output_dir
).
name
,
token
=
args
.
hub_token
)
else
:
else
:
repo_name
=
args
.
hub_model_id
repo_name
=
args
.
hub_model_id
repo
=
Repository
(
args
.
output_dir
,
clone_from
=
repo_name
)
create_repo
(
repo_name
,
exist_ok
=
True
,
token
=
args
.
hub_token
)
repo
=
Repository
(
args
.
output_dir
,
clone_from
=
repo_name
,
token
=
args
.
hub_token
)
with
open
(
os
.
path
.
join
(
args
.
output_dir
,
".gitignore"
),
"w+"
)
as
gitignore
:
with
open
(
os
.
path
.
join
(
args
.
output_dir
,
".gitignore"
),
"w+"
)
as
gitignore
:
if
"step_*"
not
in
gitignore
:
if
"step_*"
not
in
gitignore
:
...
...
examples/text_to_image/train_text_to_image.py
View file @
5ea4be86
...
@@ -38,7 +38,7 @@ from diffusers.optimization import get_scheduler
...
@@ -38,7 +38,7 @@ from diffusers.optimization import get_scheduler
from
diffusers.training_utils
import
EMAModel
from
diffusers.training_utils
import
EMAModel
from
diffusers.utils
import
check_min_version
from
diffusers.utils
import
check_min_version
from
diffusers.utils.import_utils
import
is_xformers_available
from
diffusers.utils.import_utils
import
is_xformers_available
from
huggingface_hub
import
HfFolder
,
Repository
,
whoami
from
huggingface_hub
import
HfFolder
,
Repository
,
create_repo
,
whoami
from
torchvision
import
transforms
from
torchvision
import
transforms
from
tqdm.auto
import
tqdm
from
tqdm.auto
import
tqdm
from
transformers
import
CLIPTextModel
,
CLIPTokenizer
from
transformers
import
CLIPTextModel
,
CLIPTokenizer
...
@@ -343,7 +343,8 @@ def main():
...
@@ -343,7 +343,8 @@ def main():
repo_name
=
get_full_repo_name
(
Path
(
args
.
output_dir
).
name
,
token
=
args
.
hub_token
)
repo_name
=
get_full_repo_name
(
Path
(
args
.
output_dir
).
name
,
token
=
args
.
hub_token
)
else
:
else
:
repo_name
=
args
.
hub_model_id
repo_name
=
args
.
hub_model_id
repo
=
Repository
(
args
.
output_dir
,
clone_from
=
repo_name
)
create_repo
(
repo_name
,
exist_ok
=
True
,
token
=
args
.
hub_token
)
repo
=
Repository
(
args
.
output_dir
,
clone_from
=
repo_name
,
token
=
args
.
hub_token
)
with
open
(
os
.
path
.
join
(
args
.
output_dir
,
".gitignore"
),
"w+"
)
as
gitignore
:
with
open
(
os
.
path
.
join
(
args
.
output_dir
,
".gitignore"
),
"w+"
)
as
gitignore
:
if
"step_*"
not
in
gitignore
:
if
"step_*"
not
in
gitignore
:
...
...
examples/text_to_image/train_text_to_image_flax.py
View file @
5ea4be86
...
@@ -27,7 +27,7 @@ from diffusers.utils import check_min_version
...
@@ -27,7 +27,7 @@ from diffusers.utils import check_min_version
from
flax
import
jax_utils
from
flax
import
jax_utils
from
flax.training
import
train_state
from
flax.training
import
train_state
from
flax.training.common_utils
import
shard
from
flax.training.common_utils
import
shard
from
huggingface_hub
import
HfFolder
,
Repository
,
whoami
from
huggingface_hub
import
HfFolder
,
Repository
,
create_repo
,
whoami
from
torchvision
import
transforms
from
torchvision
import
transforms
from
tqdm.auto
import
tqdm
from
tqdm.auto
import
tqdm
from
transformers
import
CLIPFeatureExtractor
,
CLIPTokenizer
,
FlaxCLIPTextModel
,
set_seed
from
transformers
import
CLIPFeatureExtractor
,
CLIPTokenizer
,
FlaxCLIPTextModel
,
set_seed
...
@@ -255,7 +255,8 @@ def main():
...
@@ -255,7 +255,8 @@ def main():
repo_name
=
get_full_repo_name
(
Path
(
args
.
output_dir
).
name
,
token
=
args
.
hub_token
)
repo_name
=
get_full_repo_name
(
Path
(
args
.
output_dir
).
name
,
token
=
args
.
hub_token
)
else
:
else
:
repo_name
=
args
.
hub_model_id
repo_name
=
args
.
hub_model_id
repo
=
Repository
(
args
.
output_dir
,
clone_from
=
repo_name
)
create_repo
(
repo_name
,
exist_ok
=
True
,
token
=
args
.
hub_token
)
repo
=
Repository
(
args
.
output_dir
,
clone_from
=
repo_name
,
token
=
args
.
hub_token
)
with
open
(
os
.
path
.
join
(
args
.
output_dir
,
".gitignore"
),
"w+"
)
as
gitignore
:
with
open
(
os
.
path
.
join
(
args
.
output_dir
,
".gitignore"
),
"w+"
)
as
gitignore
:
if
"step_*"
not
in
gitignore
:
if
"step_*"
not
in
gitignore
:
...
...
examples/textual_inversion/textual_inversion.py
View file @
5ea4be86
...
@@ -38,7 +38,7 @@ from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNe
...
@@ -38,7 +38,7 @@ from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNe
from
diffusers.optimization
import
get_scheduler
from
diffusers.optimization
import
get_scheduler
from
diffusers.utils
import
check_min_version
from
diffusers.utils
import
check_min_version
from
diffusers.utils.import_utils
import
is_xformers_available
from
diffusers.utils.import_utils
import
is_xformers_available
from
huggingface_hub
import
HfFolder
,
Repository
,
whoami
from
huggingface_hub
import
HfFolder
,
Repository
,
create_repo
,
whoami
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
from
packaging
import
version
from
packaging
import
version
...
@@ -464,7 +464,8 @@ def main():
...
@@ -464,7 +464,8 @@ def main():
repo_name
=
get_full_repo_name
(
Path
(
args
.
output_dir
).
name
,
token
=
args
.
hub_token
)
repo_name
=
get_full_repo_name
(
Path
(
args
.
output_dir
).
name
,
token
=
args
.
hub_token
)
else
:
else
:
repo_name
=
args
.
hub_model_id
repo_name
=
args
.
hub_model_id
repo
=
Repository
(
args
.
output_dir
,
clone_from
=
repo_name
)
create_repo
(
repo_name
,
exist_ok
=
True
,
token
=
args
.
hub_token
)
repo
=
Repository
(
args
.
output_dir
,
clone_from
=
repo_name
,
token
=
args
.
hub_token
)
with
open
(
os
.
path
.
join
(
args
.
output_dir
,
".gitignore"
),
"w+"
)
as
gitignore
:
with
open
(
os
.
path
.
join
(
args
.
output_dir
,
".gitignore"
),
"w+"
)
as
gitignore
:
if
"step_*"
not
in
gitignore
:
if
"step_*"
not
in
gitignore
:
...
...
examples/textual_inversion/textual_inversion_flax.py
View file @
5ea4be86
...
@@ -28,7 +28,7 @@ from diffusers.utils import check_min_version
...
@@ -28,7 +28,7 @@ from diffusers.utils import check_min_version
from
flax
import
jax_utils
from
flax
import
jax_utils
from
flax.training
import
train_state
from
flax.training
import
train_state
from
flax.training.common_utils
import
shard
from
flax.training.common_utils
import
shard
from
huggingface_hub
import
HfFolder
,
Repository
,
whoami
from
huggingface_hub
import
HfFolder
,
Repository
,
create_repo
,
whoami
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
from
packaging
import
version
from
packaging
import
version
...
@@ -372,7 +372,8 @@ def main():
...
@@ -372,7 +372,8 @@ def main():
repo_name
=
get_full_repo_name
(
Path
(
args
.
output_dir
).
name
,
token
=
args
.
hub_token
)
repo_name
=
get_full_repo_name
(
Path
(
args
.
output_dir
).
name
,
token
=
args
.
hub_token
)
else
:
else
:
repo_name
=
args
.
hub_model_id
repo_name
=
args
.
hub_model_id
repo
=
Repository
(
args
.
output_dir
,
clone_from
=
repo_name
)
create_repo
(
repo_name
,
exist_ok
=
True
,
token
=
args
.
hub_token
)
repo
=
Repository
(
args
.
output_dir
,
clone_from
=
repo_name
,
token
=
args
.
hub_token
)
with
open
(
os
.
path
.
join
(
args
.
output_dir
,
".gitignore"
),
"w+"
)
as
gitignore
:
with
open
(
os
.
path
.
join
(
args
.
output_dir
,
".gitignore"
),
"w+"
)
as
gitignore
:
if
"step_*"
not
in
gitignore
:
if
"step_*"
not
in
gitignore
:
...
...
examples/unconditional_image_generation/train_unconditional.py
View file @
5ea4be86
...
@@ -19,7 +19,7 @@ from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
...
@@ -19,7 +19,7 @@ from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
from
diffusers.optimization
import
get_scheduler
from
diffusers.optimization
import
get_scheduler
from
diffusers.training_utils
import
EMAModel
from
diffusers.training_utils
import
EMAModel
from
diffusers.utils
import
check_min_version
from
diffusers.utils
import
check_min_version
from
huggingface_hub
import
HfFolder
,
Repository
,
whoami
from
huggingface_hub
import
HfFolder
,
Repository
,
create_repo
,
whoami
from
torchvision.transforms
import
(
from
torchvision.transforms
import
(
CenterCrop
,
CenterCrop
,
Compose
,
Compose
,
...
@@ -287,7 +287,8 @@ def main(args):
...
@@ -287,7 +287,8 @@ def main(args):
repo_name
=
get_full_repo_name
(
Path
(
args
.
output_dir
).
name
,
token
=
args
.
hub_token
)
repo_name
=
get_full_repo_name
(
Path
(
args
.
output_dir
).
name
,
token
=
args
.
hub_token
)
else
:
else
:
repo_name
=
args
.
hub_model_id
repo_name
=
args
.
hub_model_id
repo
=
Repository
(
args
.
output_dir
,
clone_from
=
repo_name
)
create_repo
(
repo_name
,
exist_ok
=
True
,
token
=
args
.
hub_token
)
repo
=
Repository
(
args
.
output_dir
,
clone_from
=
repo_name
,
token
=
args
.
hub_token
)
with
open
(
os
.
path
.
join
(
args
.
output_dir
,
".gitignore"
),
"w+"
)
as
gitignore
:
with
open
(
os
.
path
.
join
(
args
.
output_dir
,
".gitignore"
),
"w+"
)
as
gitignore
:
if
"step_*"
not
in
gitignore
:
if
"step_*"
not
in
gitignore
:
...
...
examples/unconditional_image_generation/train_unconditional_ort.py
View file @
5ea4be86
...
@@ -15,7 +15,7 @@ from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
...
@@ -15,7 +15,7 @@ from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
from
diffusers.optimization
import
get_scheduler
from
diffusers.optimization
import
get_scheduler
from
diffusers.training_utils
import
EMAModel
from
diffusers.training_utils
import
EMAModel
from
diffusers.utils
import
check_min_version
from
diffusers.utils
import
check_min_version
from
huggingface_hub
import
HfFolder
,
Repository
,
whoami
from
huggingface_hub
import
HfFolder
,
Repository
,
create_repo
,
whoami
from
onnxruntime.training.ortmodule
import
ORTModule
from
onnxruntime.training.ortmodule
import
ORTModule
from
torchvision.transforms
import
(
from
torchvision.transforms
import
(
CenterCrop
,
CenterCrop
,
...
@@ -371,7 +371,8 @@ def main(args):
...
@@ -371,7 +371,8 @@ def main(args):
repo_name
=
get_full_repo_name
(
Path
(
args
.
output_dir
).
name
,
token
=
args
.
hub_token
)
repo_name
=
get_full_repo_name
(
Path
(
args
.
output_dir
).
name
,
token
=
args
.
hub_token
)
else
:
else
:
repo_name
=
args
.
hub_model_id
repo_name
=
args
.
hub_model_id
repo
=
Repository
(
args
.
output_dir
,
clone_from
=
repo_name
)
create_repo
(
repo_name
,
exist_ok
=
True
,
token
=
args
.
hub_token
)
repo
=
Repository
(
args
.
output_dir
,
clone_from
=
repo_name
,
token
=
args
.
hub_token
)
with
open
(
os
.
path
.
join
(
args
.
output_dir
,
".gitignore"
),
"w+"
)
as
gitignore
:
with
open
(
os
.
path
.
join
(
args
.
output_dir
,
".gitignore"
),
"w+"
)
as
gitignore
:
if
"step_*"
not
in
gitignore
:
if
"step_*"
not
in
gitignore
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment