Commit 4edde134 (unverified)
Authored Jun 18, 2024 by Sayak Paul; committed by GitHub on Jun 18, 2024
Parent: 074a7cc3

[SD3 training] refactor the density and weighting utilities. (#8591)

refactor the density and weighting utilities.
Showing 3 changed files with 89 additions and 48 deletions (+89 -48)
examples/dreambooth/train_dreambooth_lora_sd3.py   +27 -25
examples/dreambooth/train_dreambooth_sd3.py        +23 -23
src/diffusers/training_utils.py                    +39 -0
examples/dreambooth/train_dreambooth_lora_sd3.py
@@ -53,7 +53,11 @@ from diffusers import (
     StableDiffusion3Pipeline,
 )
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import cast_training_params
+from diffusers.training_utils import (
+    cast_training_params,
+    compute_density_for_timestep_sampling,
+    compute_loss_weighting_for_sd3,
+)
 from diffusers.utils import (
     check_min_version,
     convert_unet_state_dict_to_peft,
@@ -473,11 +477,20 @@ def parse_args(input_args=None):
         ),
     )
     parser.add_argument(
-        "--weighting_scheme", type=str, default="logit_normal", choices=["sigma_sqrt", "logit_normal", "mode"]
+        "--weighting_scheme", type=str, default="sigma_sqrt", choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"]
     )
-    parser.add_argument("--logit_mean", type=float, default=0.0)
-    parser.add_argument("--logit_std", type=float, default=1.0)
-    parser.add_argument("--mode_scale", type=float, default=1.29)
+    parser.add_argument(
+        "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
+    )
+    parser.add_argument(
+        "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme."
+    )
+    parser.add_argument(
+        "--mode_scale",
+        type=float,
+        default=1.29,
+        help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
+    )
     parser.add_argument(
         "--optimizer",
         type=str,
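For reference, a minimal runnable sketch of the new argument surface (the parser lines mirror the hunk above; the sample invocation is illustrative, not part of this commit):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--weighting_scheme", type=str, default="sigma_sqrt", choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"]
    )
    # --logit_mean/--logit_std only affect "logit_normal"; --mode_scale only affects "mode".
    parser.add_argument("--logit_mean", type=float, default=0.0)
    parser.add_argument("--logit_std", type=float, default=1.0)
    parser.add_argument("--mode_scale", type=float, default=1.29)

    # e.g. opting into the newly exposed "cosmap" scheme:
    args = parser.parse_args(["--weighting_scheme", "cosmap"])
    print(args.weighting_scheme)  # cosmap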
@@ -1477,16 +1490,13 @@ def main(args):
                 # Sample a random timestep for each image
                 # for weighting schemes where we sample timesteps non-uniformly
-                if args.weighting_scheme == "logit_normal":
-                    # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
-                    u = torch.normal(mean=args.logit_mean, std=args.logit_std, size=(bsz,), device="cpu")
-                    u = torch.nn.functional.sigmoid(u)
-                elif args.weighting_scheme == "mode":
-                    u = torch.rand(size=(bsz,), device="cpu")
-                    u = 1 - u - args.mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
-                else:
-                    u = torch.rand(size=(bsz,), device="cpu")
+                u = compute_density_for_timestep_sampling(
+                    weighting_scheme=args.weighting_scheme,
+                    batch_size=bsz,
+                    logit_mean=args.logit_mean,
+                    logit_std=args.logit_std,
+                    mode_scale=args.mode_scale,
+                )
                 indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
                 timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
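As a quick sanity check (a sketch, assuming a diffusers build that includes this commit), the extracted utility reproduces what the inlined branches did: "logit_normal" draws in logit space and squashes through a sigmoid, so u lands in (0, 1) and the derived indices concentrate mid-schedule:

    import torch
    from diffusers.training_utils import compute_density_for_timestep_sampling

    num_train_timesteps = 1000  # illustrative; the scripts read this from the scheduler config
    u = compute_density_for_timestep_sampling(
        weighting_scheme="logit_normal", batch_size=8, logit_mean=0.0, logit_std=1.0
    )
    # u is in (0, 1); the scripts turn it into discrete scheduler indices:
    indices = (u * num_train_timesteps).long()
    print(u.min().item(), u.max().item(), indices.tolist())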
@@ -1507,19 +1517,11 @@ def main(args):
                 # Preconditioning of the model outputs.
                 model_pred = model_pred * (-sigmas) + noisy_model_input
 
-                # TODO (kashif, sayakpaul): weighting sceme needs to be experimented with :)
                 # these weighting schemes use a uniform timestep sampling
                 # and instead post-weight the loss
-                if args.weighting_scheme == "sigma_sqrt":
-                    weighting = (sigmas**-2.0).float()
-                elif args.weighting_scheme == "cosmap":
-                    bot = 1 - 2 * sigmas + 2 * sigmas**2
-                    weighting = 2 / (math.pi * bot)
-                else:
-                    weighting = torch.ones_like(sigmas)
+                weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
 
-                # simplified flow matching aka 0-rectified flow matching loss
-                # target = model_input - noise
+                # flow matching loss
                 target = model_input
 
                 if args.with_prior_preservation:
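Downstream of this hunk (not shown in the diff), the scripts reduce the loss using this weighting. A hedged sketch of that pattern, with stand-in tensors in place of the real model output and latents:

    import torch
    from diffusers.training_utils import compute_loss_weighting_for_sd3

    sigmas = torch.rand(4, 1, 1, 1) + 0.1    # stand-in per-sample sigmas, broadcastable over latent dims
    model_pred = torch.randn(4, 16, 64, 64)  # stand-in for the preconditioned transformer output
    target = torch.randn(4, 16, 64, 64)      # stand-in for model_input (the flow matching target)

    weighting = compute_loss_weighting_for_sd3(weighting_scheme="cosmap", sigmas=sigmas)
    # per-sample mean of the weighted squared error, then a batch mean
    loss = torch.mean(
        (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), 1
    ).mean()
    print(loss.item())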
examples/dreambooth/train_dreambooth_sd3.py
@@ -51,6 +51,7 @@ from diffusers import (
     StableDiffusion3Pipeline,
 )
 from diffusers.optimization import get_scheduler
+from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3
 from diffusers.utils import (
     check_min_version,
     is_wandb_available,
@@ -471,11 +472,20 @@ def parse_args(input_args=None):
         ),
     )
     parser.add_argument(
-        "--weighting_scheme", type=str, default="logit_normal", choices=["sigma_sqrt", "logit_normal", "mode"]
+        "--weighting_scheme", type=str, default="sigma_sqrt", choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"]
     )
-    parser.add_argument("--logit_mean", type=float, default=0.0)
-    parser.add_argument("--logit_std", type=float, default=1.0)
-    parser.add_argument("--mode_scale", type=float, default=1.29)
+    parser.add_argument(
+        "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
+    )
+    parser.add_argument(
+        "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme."
+    )
+    parser.add_argument(
+        "--mode_scale",
+        type=float,
+        default=1.29,
+        help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
+    )
     parser.add_argument(
         "--optimizer",
         type=str,
@@ -1541,16 +1551,13 @@ def main(args):
                 # Sample a random timestep for each image
                 # for weighting schemes where we sample timesteps non-uniformly
-                if args.weighting_scheme == "logit_normal":
-                    # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
-                    u = torch.normal(mean=args.logit_mean, std=args.logit_std, size=(bsz,), device="cpu")
-                    u = torch.nn.functional.sigmoid(u)
-                elif args.weighting_scheme == "mode":
-                    u = torch.rand(size=(bsz,), device="cpu")
-                    u = 1 - u - args.mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
-                else:
-                    u = torch.rand(size=(bsz,), device="cpu")
+                u = compute_density_for_timestep_sampling(
+                    weighting_scheme=args.weighting_scheme,
+                    batch_size=bsz,
+                    logit_mean=args.logit_mean,
+                    logit_std=args.logit_std,
+                    mode_scale=args.mode_scale,
+                )
                 indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
                 timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
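The "mode" branch that this call replaces implements the mode-heavy mapping from the SD3 paper, f(u) = 1 - u - s * (cos^2(pi*u/2) - 1 + u). A small standalone check (a sketch, plain torch only) that at the scripts' default scale s = 1.29 the mapping keeps samples inside [0, 1] up to float rounding:

    import math
    import torch

    mode_scale = 1.29                    # the scripts' --mode_scale default
    u = torch.linspace(0.0, 1.0, 1001)   # dense grid instead of random draws
    mapped = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
    print(mapped.min().item(), mapped.max().item())  # ~0.0 and ~1.0 on this grid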
@@ -1587,16 +1594,9 @@ def main(args):
                 model_pred = model_pred * (-sigmas) + noisy_model_input
 
                 # these weighting schemes use a uniform timestep sampling
                 # and instead post-weight the loss
-                if args.weighting_scheme == "sigma_sqrt":
-                    weighting = (sigmas**-2.0).float()
-                elif args.weighting_scheme == "cosmap":
-                    bot = 1 - 2 * sigmas + 2 * sigmas**2
-                    weighting = 2 / (math.pi * bot)
-                else:
-                    weighting = torch.ones_like(sigmas)
+                weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
 
-                # simplified flow matching aka 0-rectified flow matching loss
-                # target = model_input - noise
+                # flow matching loss
                 target = model_input
 
                 if args.with_prior_preservation:
src/diffusers/training_utils.py
 import contextlib
 import copy
+import math
 import random
 from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
@@ -220,6 +221,44 @@ def _set_state_dict_into_text_encoder(
...
@@ -220,6 +221,44 @@ def _set_state_dict_into_text_encoder(
set_peft_model_state_dict
(
text_encoder
,
text_encoder_state_dict
,
adapter_name
=
"default"
)
set_peft_model_state_dict
(
text_encoder
,
text_encoder_state_dict
,
adapter_name
=
"default"
)
def
compute_density_for_timestep_sampling
(
weighting_scheme
:
str
,
batch_size
:
int
,
logit_mean
:
float
=
None
,
logit_std
:
float
=
None
,
mode_scale
:
float
=
None
):
"""Compute the density for sampling the timesteps when doing SD3 training.
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
"""
if
weighting_scheme
==
"logit_normal"
:
# See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
u
=
torch
.
normal
(
mean
=
logit_mean
,
std
=
logit_std
,
size
=
(
batch_size
,),
device
=
"cpu"
)
u
=
torch
.
nn
.
functional
.
sigmoid
(
u
)
elif
weighting_scheme
==
"mode"
:
u
=
torch
.
rand
(
size
=
(
batch_size
,),
device
=
"cpu"
)
u
=
1
-
u
-
mode_scale
*
(
torch
.
cos
(
math
.
pi
*
u
/
2
)
**
2
-
1
+
u
)
else
:
u
=
torch
.
rand
(
size
=
(
batch_size
,),
device
=
"cpu"
)
return
u
def
compute_loss_weighting_for_sd3
(
weighting_scheme
:
str
,
sigmas
=
None
):
"""Computes loss weighting scheme for SD3 training.
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
"""
if
weighting_scheme
==
"sigma_sqrt"
:
weighting
=
(
sigmas
**-
2.0
).
float
()
elif
weighting_scheme
==
"cosmap"
:
bot
=
1
-
2
*
sigmas
+
2
*
sigmas
**
2
weighting
=
2
/
(
math
.
pi
*
bot
)
else
:
weighting
=
torch
.
ones_like
(
sigmas
)
return
weighting
# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
class
EMAModel
:
class
EMAModel
:
"""
"""
...
...
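Putting the two new utilities together, a minimal end-to-end sketch of the pattern the refactored scripts now share (the scheduler choice and shapes here are illustrative assumptions, not part of this commit):

    import torch
    from diffusers import FlowMatchEulerDiscreteScheduler
    from diffusers.training_utils import (
        compute_density_for_timestep_sampling,
        compute_loss_weighting_for_sd3,
    )

    scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)
    bsz = 4

    # 1) Sample the timestep density; "cosmap" samples uniformly and relies on post-weighting.
    u = compute_density_for_timestep_sampling(weighting_scheme="cosmap", batch_size=bsz)
    indices = (u * scheduler.config.num_train_timesteps).long()
    timesteps = scheduler.timesteps[indices]

    # 2) Post-weight the loss for the same scheme.
    sigmas = scheduler.sigmas[indices].reshape(bsz, 1, 1, 1)
    weighting = compute_loss_weighting_for_sd3(weighting_scheme="cosmap", sigmas=sigmas)
    print(timesteps.tolist(), weighting.flatten().tolist())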