Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
edc154da
Unverified
Commit
edc154da
authored
Apr 09, 2025
by
Dhruv Nair
Committed by
GitHub
Apr 09, 2025
Browse files
Update Ruff to latest Version (#10919)
* update * update * update * update
parent
552cd320
Changes
200
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
49 additions
and
48 deletions
+49
-48
examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py
...nsistency_distillation/train_lcm_distill_lora_sdxl_wds.py
+1
-1
examples/custom_diffusion/retrieve.py
examples/custom_diffusion/retrieve.py
+5
-3
examples/custom_diffusion/train_custom_diffusion.py
examples/custom_diffusion/train_custom_diffusion.py
+12
-12
examples/dreambooth/train_dreambooth.py
examples/dreambooth/train_dreambooth.py
+1
-1
examples/dreambooth/train_dreambooth_lora.py
examples/dreambooth/train_dreambooth_lora.py
+1
-1
examples/dreambooth/train_dreambooth_lora_flux.py
examples/dreambooth/train_dreambooth_lora_flux.py
+1
-1
examples/dreambooth/train_dreambooth_lora_lumina2.py
examples/dreambooth/train_dreambooth_lora_lumina2.py
+1
-1
examples/dreambooth/train_dreambooth_lora_sana.py
examples/dreambooth/train_dreambooth_lora_sana.py
+1
-1
examples/dreambooth/train_dreambooth_lora_sd3.py
examples/dreambooth/train_dreambooth_lora_sd3.py
+1
-1
examples/dreambooth/train_dreambooth_lora_sdxl.py
examples/dreambooth/train_dreambooth_lora_sdxl.py
+2
-2
examples/flux-control/train_control_lora_flux.py
examples/flux-control/train_control_lora_flux.py
+4
-4
examples/model_search/pipeline_easy.py
examples/model_search/pipeline_easy.py
+3
-3
examples/research_projects/anytext/anytext.py
examples/research_projects/anytext/anytext.py
+3
-3
examples/research_projects/anytext/ocr_recog/RecSVTR.py
examples/research_projects/anytext/ocr_recog/RecSVTR.py
+3
-3
examples/research_projects/colossalai/train_dreambooth_colossalai.py
...search_projects/colossalai/train_dreambooth_colossalai.py
+1
-1
examples/research_projects/controlnet/train_controlnet_webdataset.py
...search_projects/controlnet/train_controlnet_webdataset.py
+3
-4
examples/research_projects/diffusion_dpo/train_diffusion_dpo.py
...es/research_projects/diffusion_dpo/train_diffusion_dpo.py
+1
-1
examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py
...search_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py
+1
-1
examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora.py
...projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora.py
+2
-2
examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora_wds.py
...ects/diffusion_orpo/train_diffusion_orpo_sdxl_lora_wds.py
+2
-2
No files found.
examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py
View file @
edc154da
...
@@ -95,7 +95,7 @@ def get_module_kohya_state_dict(module, prefix: str, dtype: torch.dtype, adapter
...
@@ -95,7 +95,7 @@ def get_module_kohya_state_dict(module, prefix: str, dtype: torch.dtype, adapter
# Set alpha parameter
# Set alpha parameter
if
"lora_down"
in
kohya_key
:
if
"lora_down"
in
kohya_key
:
alpha_key
=
f
'
{
kohya_key
.
split
(
"."
)[
0
]
}
.alpha
'
alpha_key
=
f
"
{
kohya_key
.
split
(
'.'
)[
0
]
}
.alpha
"
kohya_ss_state_dict
[
alpha_key
]
=
torch
.
tensor
(
module
.
peft_config
[
adapter_name
].
lora_alpha
).
to
(
dtype
)
kohya_ss_state_dict
[
alpha_key
]
=
torch
.
tensor
(
module
.
peft_config
[
adapter_name
].
lora_alpha
).
to
(
dtype
)
return
kohya_ss_state_dict
return
kohya_ss_state_dict
...
...
examples/custom_diffusion/retrieve.py
View file @
edc154da
...
@@ -50,9 +50,11 @@ def retrieve(class_prompt, class_data_dir, num_class_images):
...
@@ -50,9 +50,11 @@ def retrieve(class_prompt, class_data_dir, num_class_images):
total
=
0
total
=
0
pbar
=
tqdm
(
desc
=
"downloading real regularization images"
,
total
=
num_class_images
)
pbar
=
tqdm
(
desc
=
"downloading real regularization images"
,
total
=
num_class_images
)
with
open
(
f
"
{
class_data_dir
}
/caption.txt"
,
"w"
)
as
f1
,
open
(
f
"
{
class_data_dir
}
/urls.txt"
,
"w"
)
as
f2
,
open
(
with
(
f
"
{
class_data_dir
}
/images.txt"
,
"w"
open
(
f
"
{
class_data_dir
}
/caption.txt"
,
"w"
)
as
f1
,
)
as
f3
:
open
(
f
"
{
class_data_dir
}
/urls.txt"
,
"w"
)
as
f2
,
open
(
f
"
{
class_data_dir
}
/images.txt"
,
"w"
)
as
f3
,
):
while
total
<
num_class_images
:
while
total
<
num_class_images
:
images
=
class_images
[
count
]
images
=
class_images
[
count
]
count
+=
1
count
+=
1
...
...
examples/custom_diffusion/train_custom_diffusion.py
View file @
edc154da
...
@@ -731,18 +731,18 @@ def main(args):
...
@@ -731,18 +731,18 @@ def main(args):
if
not
class_images_dir
.
exists
():
if
not
class_images_dir
.
exists
():
class_images_dir
.
mkdir
(
parents
=
True
,
exist_ok
=
True
)
class_images_dir
.
mkdir
(
parents
=
True
,
exist_ok
=
True
)
if
args
.
real_prior
:
if
args
.
real_prior
:
assert
(
assert
(
class_images_dir
/
"images"
).
exists
(),
(
class_images_dir
/
"
images
"
f
'Please run: python retrieve.py --class_prompt "
{
concept
[
"class_prompt"
]
}
" --class_data_dir
{
class_images_dir
}
--num_class_images
{
args
.
num_class_
images
}
'
)
.
exists
(),
f
"Please run: python retrieve.py --class_prompt
\"
{
concept
[
'class_prompt'
]
}
\"
--class_data_dir
{
class_images_dir
}
--num_class_images
{
args
.
num_class_images
}
"
)
assert
(
assert
len
(
list
((
class_images_dir
/
"images"
).
iterdir
()))
==
args
.
num_class_images
,
(
len
(
list
((
class_images_dir
/
"images"
).
iterdir
()))
==
args
.
num_class_images
f
'Please run: python retrieve.py --class_prompt "
{
concept
[
"class_prompt"
]
}
" --class_data_dir
{
class_images_dir
}
--num_class_images
{
args
.
num_class_images
}
'
)
,
f
"Please run: python retrieve.py --class_prompt
\"
{
concept
[
'class_prompt'
]
}
\"
--class_data_dir
{
class_images_dir
}
--num_class_images
{
args
.
num_class_images
}
"
)
assert
(
assert
(
class_images_dir
/
"caption.txt"
).
exists
(),
(
class_images_dir
/
"caption.txt"
f
'Please run: python retrieve.py --class_prompt "
{
concept
[
"class_prompt"
]
}
" --class_data_dir
{
class_images_dir
}
--num_class_images
{
args
.
num_class_images
}
'
)
.
exists
(),
f
"Please run: python retrieve.py --class_prompt
\"
{
concept
[
'class_prompt'
]
}
\"
--class_data_dir
{
class_images_dir
}
--num_class_images
{
args
.
num_class_images
}
"
)
assert
(
assert
(
class_images_dir
/
"images.txt"
).
exists
(),
(
class_images_dir
/
"images.txt"
f
'Please run: python retrieve.py --class_prompt "
{
concept
[
"class_prompt"
]
}
" --class_data_dir
{
class_images_dir
}
--num_class_images
{
args
.
num_class_images
}
'
)
.
exists
(),
f
"Please run: python retrieve.py --class_prompt
\"
{
concept
[
'class_prompt'
]
}
\"
--class_data_dir
{
class_images_dir
}
--num_class_images
{
args
.
num_class_images
}
"
)
concept
[
"class_prompt"
]
=
os
.
path
.
join
(
class_images_dir
,
"caption.txt"
)
concept
[
"class_prompt"
]
=
os
.
path
.
join
(
class_images_dir
,
"caption.txt"
)
concept
[
"class_data_dir"
]
=
os
.
path
.
join
(
class_images_dir
,
"images.txt"
)
concept
[
"class_data_dir"
]
=
os
.
path
.
join
(
class_images_dir
,
"images.txt"
)
args
.
concepts_list
[
i
]
=
concept
args
.
concepts_list
[
i
]
=
concept
...
...
examples/dreambooth/train_dreambooth.py
View file @
edc154da
...
@@ -1014,7 +1014,7 @@ def main(args):
...
@@ -1014,7 +1014,7 @@ def main(args):
if
args
.
train_text_encoder
and
unwrap_model
(
text_encoder
).
dtype
!=
torch
.
float32
:
if
args
.
train_text_encoder
and
unwrap_model
(
text_encoder
).
dtype
!=
torch
.
float32
:
raise
ValueError
(
raise
ValueError
(
f
"Text encoder loaded as datatype
{
unwrap_model
(
text_encoder
).
dtype
}
.
"
f
"
{
low_precision_error_string
}
"
f
"Text encoder loaded as datatype
{
unwrap_model
(
text_encoder
).
dtype
}
.
{
low_precision_error_string
}
"
)
)
# Enable TF32 for faster training on Ampere GPUs,
# Enable TF32 for faster training on Ampere GPUs,
...
...
examples/dreambooth/train_dreambooth_lora.py
View file @
edc154da
...
@@ -982,7 +982,7 @@ def main(args):
...
@@ -982,7 +982,7 @@ def main(args):
lora_state_dict
,
network_alphas
=
StableDiffusionLoraLoaderMixin
.
lora_state_dict
(
input_dir
)
lora_state_dict
,
network_alphas
=
StableDiffusionLoraLoaderMixin
.
lora_state_dict
(
input_dir
)
unet_state_dict
=
{
f
'
{
k
.
replace
(
"
unet.
"
,
""
)
}
'
:
v
for
k
,
v
in
lora_state_dict
.
items
()
if
k
.
startswith
(
"unet."
)}
unet_state_dict
=
{
f
"
{
k
.
replace
(
'
unet.
'
,
''
)
}
"
:
v
for
k
,
v
in
lora_state_dict
.
items
()
if
k
.
startswith
(
"unet."
)}
unet_state_dict
=
convert_unet_state_dict_to_peft
(
unet_state_dict
)
unet_state_dict
=
convert_unet_state_dict_to_peft
(
unet_state_dict
)
incompatible_keys
=
set_peft_model_state_dict
(
unet_
,
unet_state_dict
,
adapter_name
=
"default"
)
incompatible_keys
=
set_peft_model_state_dict
(
unet_
,
unet_state_dict
,
adapter_name
=
"default"
)
...
...
examples/dreambooth/train_dreambooth_lora_flux.py
View file @
edc154da
...
@@ -1294,7 +1294,7 @@ def main(args):
...
@@ -1294,7 +1294,7 @@ def main(args):
lora_state_dict
=
FluxPipeline
.
lora_state_dict
(
input_dir
)
lora_state_dict
=
FluxPipeline
.
lora_state_dict
(
input_dir
)
transformer_state_dict
=
{
transformer_state_dict
=
{
f
'
{
k
.
replace
(
"
transformer.
"
,
""
)
}
'
:
v
for
k
,
v
in
lora_state_dict
.
items
()
if
k
.
startswith
(
"transformer."
)
f
"
{
k
.
replace
(
'
transformer.
'
,
''
)
}
"
:
v
for
k
,
v
in
lora_state_dict
.
items
()
if
k
.
startswith
(
"transformer."
)
}
}
transformer_state_dict
=
convert_unet_state_dict_to_peft
(
transformer_state_dict
)
transformer_state_dict
=
convert_unet_state_dict_to_peft
(
transformer_state_dict
)
incompatible_keys
=
set_peft_model_state_dict
(
transformer_
,
transformer_state_dict
,
adapter_name
=
"default"
)
incompatible_keys
=
set_peft_model_state_dict
(
transformer_
,
transformer_state_dict
,
adapter_name
=
"default"
)
...
...
examples/dreambooth/train_dreambooth_lora_lumina2.py
View file @
edc154da
...
@@ -1053,7 +1053,7 @@ def main(args):
...
@@ -1053,7 +1053,7 @@ def main(args):
lora_state_dict
=
Lumina2Text2ImgPipeline
.
lora_state_dict
(
input_dir
)
lora_state_dict
=
Lumina2Text2ImgPipeline
.
lora_state_dict
(
input_dir
)
transformer_state_dict
=
{
transformer_state_dict
=
{
f
'
{
k
.
replace
(
"
transformer.
"
,
""
)
}
'
:
v
for
k
,
v
in
lora_state_dict
.
items
()
if
k
.
startswith
(
"transformer."
)
f
"
{
k
.
replace
(
'
transformer.
'
,
''
)
}
"
:
v
for
k
,
v
in
lora_state_dict
.
items
()
if
k
.
startswith
(
"transformer."
)
}
}
transformer_state_dict
=
convert_unet_state_dict_to_peft
(
transformer_state_dict
)
transformer_state_dict
=
convert_unet_state_dict_to_peft
(
transformer_state_dict
)
incompatible_keys
=
set_peft_model_state_dict
(
transformer_
,
transformer_state_dict
,
adapter_name
=
"default"
)
incompatible_keys
=
set_peft_model_state_dict
(
transformer_
,
transformer_state_dict
,
adapter_name
=
"default"
)
...
...
examples/dreambooth/train_dreambooth_lora_sana.py
View file @
edc154da
...
@@ -1064,7 +1064,7 @@ def main(args):
...
@@ -1064,7 +1064,7 @@ def main(args):
lora_state_dict
=
SanaPipeline
.
lora_state_dict
(
input_dir
)
lora_state_dict
=
SanaPipeline
.
lora_state_dict
(
input_dir
)
transformer_state_dict
=
{
transformer_state_dict
=
{
f
'
{
k
.
replace
(
"
transformer.
"
,
""
)
}
'
:
v
for
k
,
v
in
lora_state_dict
.
items
()
if
k
.
startswith
(
"transformer."
)
f
"
{
k
.
replace
(
'
transformer.
'
,
''
)
}
"
:
v
for
k
,
v
in
lora_state_dict
.
items
()
if
k
.
startswith
(
"transformer."
)
}
}
transformer_state_dict
=
convert_unet_state_dict_to_peft
(
transformer_state_dict
)
transformer_state_dict
=
convert_unet_state_dict_to_peft
(
transformer_state_dict
)
incompatible_keys
=
set_peft_model_state_dict
(
transformer_
,
transformer_state_dict
,
adapter_name
=
"default"
)
incompatible_keys
=
set_peft_model_state_dict
(
transformer_
,
transformer_state_dict
,
adapter_name
=
"default"
)
...
...
examples/dreambooth/train_dreambooth_lora_sd3.py
View file @
edc154da
...
@@ -1355,7 +1355,7 @@ def main(args):
...
@@ -1355,7 +1355,7 @@ def main(args):
lora_state_dict
=
StableDiffusion3Pipeline
.
lora_state_dict
(
input_dir
)
lora_state_dict
=
StableDiffusion3Pipeline
.
lora_state_dict
(
input_dir
)
transformer_state_dict
=
{
transformer_state_dict
=
{
f
'
{
k
.
replace
(
"
transformer.
"
,
""
)
}
'
:
v
for
k
,
v
in
lora_state_dict
.
items
()
if
k
.
startswith
(
"transformer."
)
f
"
{
k
.
replace
(
'
transformer.
'
,
''
)
}
"
:
v
for
k
,
v
in
lora_state_dict
.
items
()
if
k
.
startswith
(
"transformer."
)
}
}
transformer_state_dict
=
convert_unet_state_dict_to_peft
(
transformer_state_dict
)
transformer_state_dict
=
convert_unet_state_dict_to_peft
(
transformer_state_dict
)
incompatible_keys
=
set_peft_model_state_dict
(
transformer_
,
transformer_state_dict
,
adapter_name
=
"default"
)
incompatible_keys
=
set_peft_model_state_dict
(
transformer_
,
transformer_state_dict
,
adapter_name
=
"default"
)
...
...
examples/dreambooth/train_dreambooth_lora_sdxl.py
View file @
edc154da
...
@@ -118,7 +118,7 @@ def save_model_card(
...
@@ -118,7 +118,7 @@ def save_model_card(
)
)
model_description
=
f
"""
model_description
=
f
"""
#
{
'
SDXL
'
if
'
playground
'
not
in
base_model
else
'
Playground
'
}
LoRA DreamBooth -
{
repo_id
}
#
{
"
SDXL
"
if
"
playground
"
not
in
base_model
else
"
Playground
"
}
LoRA DreamBooth -
{
repo_id
}
<Gallery />
<Gallery />
...
@@ -1286,7 +1286,7 @@ def main(args):
...
@@ -1286,7 +1286,7 @@ def main(args):
lora_state_dict
,
network_alphas
=
StableDiffusionLoraLoaderMixin
.
lora_state_dict
(
input_dir
)
lora_state_dict
,
network_alphas
=
StableDiffusionLoraLoaderMixin
.
lora_state_dict
(
input_dir
)
unet_state_dict
=
{
f
'
{
k
.
replace
(
"
unet.
"
,
""
)
}
'
:
v
for
k
,
v
in
lora_state_dict
.
items
()
if
k
.
startswith
(
"unet."
)}
unet_state_dict
=
{
f
"
{
k
.
replace
(
'
unet.
'
,
''
)
}
"
:
v
for
k
,
v
in
lora_state_dict
.
items
()
if
k
.
startswith
(
"unet."
)}
unet_state_dict
=
convert_unet_state_dict_to_peft
(
unet_state_dict
)
unet_state_dict
=
convert_unet_state_dict_to_peft
(
unet_state_dict
)
incompatible_keys
=
set_peft_model_state_dict
(
unet_
,
unet_state_dict
,
adapter_name
=
"default"
)
incompatible_keys
=
set_peft_model_state_dict
(
unet_
,
unet_state_dict
,
adapter_name
=
"default"
)
if
incompatible_keys
is
not
None
:
if
incompatible_keys
is
not
None
:
...
...
examples/flux-control/train_control_lora_flux.py
View file @
edc154da
...
@@ -91,9 +91,9 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f
...
@@ -91,9 +91,9 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f
torch_dtype
=
weight_dtype
,
torch_dtype
=
weight_dtype
,
)
)
pipeline
.
load_lora_weights
(
args
.
output_dir
)
pipeline
.
load_lora_weights
(
args
.
output_dir
)
assert
(
assert
pipeline
.
transformer
.
config
.
in_channels
==
initial_channels
*
2
,
(
pipeline
.
transformer
.
config
.
in_channels
==
initial_channels
*
2
f
"
{
pipeline
.
transformer
.
config
.
in_channels
=
}
"
)
,
f
"
{
pipeline
.
transformer
.
config
.
in_channels
=
}
"
)
pipeline
.
to
(
accelerator
.
device
)
pipeline
.
to
(
accelerator
.
device
)
pipeline
.
set_progress_bar_config
(
disable
=
True
)
pipeline
.
set_progress_bar_config
(
disable
=
True
)
...
@@ -954,7 +954,7 @@ def main(args):
...
@@ -954,7 +954,7 @@ def main(args):
lora_state_dict
=
FluxControlPipeline
.
lora_state_dict
(
input_dir
)
lora_state_dict
=
FluxControlPipeline
.
lora_state_dict
(
input_dir
)
transformer_lora_state_dict
=
{
transformer_lora_state_dict
=
{
f
'
{
k
.
replace
(
"
transformer.
"
,
""
)
}
'
:
v
f
"
{
k
.
replace
(
'
transformer.
'
,
''
)
}
"
:
v
for
k
,
v
in
lora_state_dict
.
items
()
for
k
,
v
in
lora_state_dict
.
items
()
if
k
.
startswith
(
"transformer."
)
and
"lora"
in
k
if
k
.
startswith
(
"transformer."
)
and
"lora"
in
k
}
}
...
...
examples/model_search/pipeline_easy.py
View file @
edc154da
...
@@ -1081,9 +1081,9 @@ class AutoConfig:
...
@@ -1081,9 +1081,9 @@ class AutoConfig:
f
"textual_inversion_path:
{
search_word
}
->
{
textual_inversion_path
.
model_status
.
site_url
}
"
f
"textual_inversion_path:
{
search_word
}
->
{
textual_inversion_path
.
model_status
.
site_url
}
"
)
)
pretrained_model_name_or_paths
[
pretrained_model_name_or_paths
[
pretrained_model_name_or_paths
.
index
(
search_word
)]
=
(
pretrained_model_name_or_paths
.
index
(
search_word
)
textual_inversion_path
.
model_path
]
=
textual_inversion_path
.
model_path
)
self
.
load_textual_inversion
(
self
.
load_textual_inversion
(
pretrained_model_name_or_paths
,
token
=
tokens
,
tokenizer
=
tokenizer
,
text_encoder
=
text_encoder
,
**
kwargs
pretrained_model_name_or_paths
,
token
=
tokens
,
tokenizer
=
tokenizer
,
text_encoder
=
text_encoder
,
**
kwargs
...
...
examples/research_projects/anytext/anytext.py
View file @
edc154da
...
@@ -187,9 +187,9 @@ def get_clip_token_for_string(tokenizer, string):
...
@@ -187,9 +187,9 @@ def get_clip_token_for_string(tokenizer, string):
return_tensors
=
"pt"
,
return_tensors
=
"pt"
,
)
)
tokens
=
batch_encoding
[
"input_ids"
]
tokens
=
batch_encoding
[
"input_ids"
]
assert
(
assert
torch
.
count_nonzero
(
tokens
-
49407
)
==
2
,
(
torch
.
count_nonzero
(
tokens
-
49407
)
==
2
f
"String '
{
string
}
' maps to more than a single token. Please use another string"
)
,
f
"String '
{
string
}
' maps to more than a single token. Please use another string"
)
return
tokens
[
0
,
1
]
return
tokens
[
0
,
1
]
...
...
examples/research_projects/anytext/ocr_recog/RecSVTR.py
View file @
edc154da
...
@@ -312,9 +312,9 @@ class PatchEmbed(nn.Module):
...
@@ -312,9 +312,9 @@ class PatchEmbed(nn.Module):
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
B
,
C
,
H
,
W
=
x
.
shape
B
,
C
,
H
,
W
=
x
.
shape
assert
(
assert
H
==
self
.
img_size
[
0
]
and
W
==
self
.
img_size
[
1
],
(
H
==
self
.
img_size
[
0
]
and
W
==
self
.
img_size
[
1
]
f
"Input image size (
{
H
}
*
{
W
}
) doesn't match model (
{
self
.
img_size
[
0
]
}
*
{
self
.
img_size
[
1
]
}
)."
)
,
f
"Input image size (
{
H
}
*
{
W
}
) doesn't match model (
{
self
.
img_size
[
0
]
}
*
{
self
.
img_size
[
1
]
}
)."
)
x
=
self
.
proj
(
x
).
flatten
(
2
).
permute
(
0
,
2
,
1
)
x
=
self
.
proj
(
x
).
flatten
(
2
).
permute
(
0
,
2
,
1
)
return
x
return
x
...
...
examples/research_projects/colossalai/train_dreambooth_colossalai.py
View file @
edc154da
...
@@ -619,7 +619,7 @@ def main(args):
...
@@ -619,7 +619,7 @@ def main(args):
optimizer
.
step
()
optimizer
.
step
()
lr_scheduler
.
step
()
lr_scheduler
.
step
()
logger
.
info
(
f
"max GPU_mem cost is
{
torch
.
cuda
.
max_memory_allocated
()
/
2
**
20
}
MB"
,
ranks
=
[
0
])
logger
.
info
(
f
"max GPU_mem cost is
{
torch
.
cuda
.
max_memory_allocated
()
/
2
**
20
}
MB"
,
ranks
=
[
0
])
# Checks if the accelerator has performed an optimization step behind the scenes
# Checks if the accelerator has performed an optimization step behind the scenes
progress_bar
.
update
(
1
)
progress_bar
.
update
(
1
)
global_step
+=
1
global_step
+=
1
...
...
examples/research_projects/controlnet/train_controlnet_webdataset.py
View file @
edc154da
...
@@ -803,21 +803,20 @@ def parse_args(input_args=None):
...
@@ -803,21 +803,20 @@ def parse_args(input_args=None):
"--control_type"
,
"--control_type"
,
type
=
str
,
type
=
str
,
default
=
"canny"
,
default
=
"canny"
,
help
=
(
"The type of controlnet conditioning image to use. One of `canny`, `depth`
"
"
Defaults to `canny`."
),
help
=
(
"The type of controlnet conditioning image to use. One of `canny`, `depth` Defaults to `canny`."
),
)
)
parser
.
add_argument
(
parser
.
add_argument
(
"--transformer_layers_per_block"
,
"--transformer_layers_per_block"
,
type
=
str
,
type
=
str
,
default
=
None
,
default
=
None
,
help
=
(
"The number of layers per block in the transformer. If None, defaults to
"
"
`args.transformer_layers`."
),
help
=
(
"The number of layers per block in the transformer. If None, defaults to `args.transformer_layers`."
),
)
)
parser
.
add_argument
(
parser
.
add_argument
(
"--old_style_controlnet"
,
"--old_style_controlnet"
,
action
=
"store_true"
,
action
=
"store_true"
,
default
=
False
,
default
=
False
,
help
=
(
help
=
(
"Use the old style controlnet, which is a single transformer layer with"
"Use the old style controlnet, which is a single transformer layer with a single head. Defaults to False."
" a single head. Defaults to False."
),
),
)
)
...
...
examples/research_projects/diffusion_dpo/train_diffusion_dpo.py
View file @
edc154da
...
@@ -86,7 +86,7 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st
...
@@ -86,7 +86,7 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st
def
log_validation
(
args
,
unet
,
accelerator
,
weight_dtype
,
epoch
,
is_final_validation
=
False
):
def
log_validation
(
args
,
unet
,
accelerator
,
weight_dtype
,
epoch
,
is_final_validation
=
False
):
logger
.
info
(
f
"Running validation...
\n
Generating images with prompts:
\n
"
f
"
{
VALIDATION_PROMPTS
}
."
)
logger
.
info
(
f
"Running validation...
\n
Generating images with prompts:
\n
{
VALIDATION_PROMPTS
}
."
)
# create pipeline
# create pipeline
pipeline
=
DiffusionPipeline
.
from_pretrained
(
pipeline
=
DiffusionPipeline
.
from_pretrained
(
...
...
examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py
View file @
edc154da
...
@@ -91,7 +91,7 @@ def import_model_class_from_model_name_or_path(
...
@@ -91,7 +91,7 @@ def import_model_class_from_model_name_or_path(
def
log_validation
(
args
,
unet
,
vae
,
accelerator
,
weight_dtype
,
epoch
,
is_final_validation
=
False
):
def
log_validation
(
args
,
unet
,
vae
,
accelerator
,
weight_dtype
,
epoch
,
is_final_validation
=
False
):
logger
.
info
(
f
"Running validation...
\n
Generating images with prompts:
\n
"
f
"
{
VALIDATION_PROMPTS
}
."
)
logger
.
info
(
f
"Running validation...
\n
Generating images with prompts:
\n
{
VALIDATION_PROMPTS
}
."
)
if
is_final_validation
:
if
is_final_validation
:
if
args
.
mixed_precision
==
"fp16"
:
if
args
.
mixed_precision
==
"fp16"
:
...
...
examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora.py
View file @
edc154da
...
@@ -91,7 +91,7 @@ def import_model_class_from_model_name_or_path(
...
@@ -91,7 +91,7 @@ def import_model_class_from_model_name_or_path(
def
log_validation
(
args
,
unet
,
vae
,
accelerator
,
weight_dtype
,
epoch
,
is_final_validation
=
False
):
def
log_validation
(
args
,
unet
,
vae
,
accelerator
,
weight_dtype
,
epoch
,
is_final_validation
=
False
):
logger
.
info
(
f
"Running validation...
\n
Generating images with prompts:
\n
"
f
"
{
VALIDATION_PROMPTS
}
."
)
logger
.
info
(
f
"Running validation...
\n
Generating images with prompts:
\n
{
VALIDATION_PROMPTS
}
."
)
if
is_final_validation
:
if
is_final_validation
:
if
args
.
mixed_precision
==
"fp16"
:
if
args
.
mixed_precision
==
"fp16"
:
...
@@ -683,7 +683,7 @@ def main(args):
...
@@ -683,7 +683,7 @@ def main(args):
lora_state_dict
,
network_alphas
=
StableDiffusionXLLoraLoaderMixin
.
lora_state_dict
(
input_dir
)
lora_state_dict
,
network_alphas
=
StableDiffusionXLLoraLoaderMixin
.
lora_state_dict
(
input_dir
)
unet_state_dict
=
{
f
'
{
k
.
replace
(
"
unet.
"
,
""
)
}
'
:
v
for
k
,
v
in
lora_state_dict
.
items
()
if
k
.
startswith
(
"unet."
)}
unet_state_dict
=
{
f
"
{
k
.
replace
(
'
unet.
'
,
''
)
}
"
:
v
for
k
,
v
in
lora_state_dict
.
items
()
if
k
.
startswith
(
"unet."
)}
unet_state_dict
=
convert_unet_state_dict_to_peft
(
unet_state_dict
)
unet_state_dict
=
convert_unet_state_dict_to_peft
(
unet_state_dict
)
incompatible_keys
=
set_peft_model_state_dict
(
unet_
,
unet_state_dict
,
adapter_name
=
"default"
)
incompatible_keys
=
set_peft_model_state_dict
(
unet_
,
unet_state_dict
,
adapter_name
=
"default"
)
if
incompatible_keys
is
not
None
:
if
incompatible_keys
is
not
None
:
...
...
examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora_wds.py
View file @
edc154da
...
@@ -89,7 +89,7 @@ def import_model_class_from_model_name_or_path(
...
@@ -89,7 +89,7 @@ def import_model_class_from_model_name_or_path(
def
log_validation
(
args
,
unet
,
vae
,
accelerator
,
weight_dtype
,
epoch
,
is_final_validation
=
False
):
def
log_validation
(
args
,
unet
,
vae
,
accelerator
,
weight_dtype
,
epoch
,
is_final_validation
=
False
):
logger
.
info
(
f
"Running validation...
\n
Generating images with prompts:
\n
"
f
"
{
VALIDATION_PROMPTS
}
."
)
logger
.
info
(
f
"Running validation...
\n
Generating images with prompts:
\n
{
VALIDATION_PROMPTS
}
."
)
if
is_final_validation
:
if
is_final_validation
:
if
args
.
mixed_precision
==
"fp16"
:
if
args
.
mixed_precision
==
"fp16"
:
...
@@ -790,7 +790,7 @@ def main(args):
...
@@ -790,7 +790,7 @@ def main(args):
lora_state_dict
,
network_alphas
=
StableDiffusionXLLoraLoaderMixin
.
lora_state_dict
(
input_dir
)
lora_state_dict
,
network_alphas
=
StableDiffusionXLLoraLoaderMixin
.
lora_state_dict
(
input_dir
)
unet_state_dict
=
{
f
'
{
k
.
replace
(
"
unet.
"
,
""
)
}
'
:
v
for
k
,
v
in
lora_state_dict
.
items
()
if
k
.
startswith
(
"unet."
)}
unet_state_dict
=
{
f
"
{
k
.
replace
(
'
unet.
'
,
''
)
}
"
:
v
for
k
,
v
in
lora_state_dict
.
items
()
if
k
.
startswith
(
"unet."
)}
unet_state_dict
=
convert_unet_state_dict_to_peft
(
unet_state_dict
)
unet_state_dict
=
convert_unet_state_dict_to_peft
(
unet_state_dict
)
incompatible_keys
=
set_peft_model_state_dict
(
unet_
,
unet_state_dict
,
adapter_name
=
"default"
)
incompatible_keys
=
set_peft_model_state_dict
(
unet_
,
unet_state_dict
,
adapter_name
=
"default"
)
if
incompatible_keys
is
not
None
:
if
incompatible_keys
is
not
None
:
...
...
Prev
1
2
3
4
5
6
…
10
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment