Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
4edde134
Unverified
Commit
4edde134
authored
Jun 18, 2024
by
Sayak Paul
Committed by
GitHub
Jun 18, 2024
Browse files
[SD3 training] refactor the density and weighting utilities. (#8591)
refactor the density and weighting utilities.
parent
074a7cc3
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
89 additions
and
48 deletions
+89
-48
examples/dreambooth/train_dreambooth_lora_sd3.py
examples/dreambooth/train_dreambooth_lora_sd3.py
+27
-25
examples/dreambooth/train_dreambooth_sd3.py
examples/dreambooth/train_dreambooth_sd3.py
+23
-23
src/diffusers/training_utils.py
src/diffusers/training_utils.py
+39
-0
No files found.
examples/dreambooth/train_dreambooth_lora_sd3.py
View file @
4edde134
...
...
@@ -53,7 +53,11 @@ from diffusers import (
StableDiffusion3Pipeline
,
)
from
diffusers.optimization
import
get_scheduler
from
diffusers.training_utils
import
cast_training_params
from
diffusers.training_utils
import
(
cast_training_params
,
compute_density_for_timestep_sampling
,
compute_loss_weighting_for_sd3
,
)
from
diffusers.utils
import
(
check_min_version
,
convert_unet_state_dict_to_peft
,
...
...
@@ -473,11 +477,20 @@ def parse_args(input_args=None):
),
)
parser
.
add_argument
(
"--weighting_scheme"
,
type
=
str
,
default
=
"logit_normal"
,
choices
=
[
"sigma_sqrt"
,
"logit_normal"
,
"mode"
]
"--weighting_scheme"
,
type
=
str
,
default
=
"sigma_sqrt"
,
choices
=
[
"sigma_sqrt"
,
"logit_normal"
,
"mode"
,
"cosmap"
]
)
parser
.
add_argument
(
"--logit_mean"
,
type
=
float
,
default
=
0.0
,
help
=
"mean to use when using the `'logit_normal'` weighting scheme."
)
parser
.
add_argument
(
"--logit_std"
,
type
=
float
,
default
=
1.0
,
help
=
"std to use when using the `'logit_normal'` weighting scheme."
)
parser
.
add_argument
(
"--mode_scale"
,
type
=
float
,
default
=
1.29
,
help
=
"Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`."
,
)
parser
.
add_argument
(
"--logit_mean"
,
type
=
float
,
default
=
0.0
)
parser
.
add_argument
(
"--logit_std"
,
type
=
float
,
default
=
1.0
)
parser
.
add_argument
(
"--mode_scale"
,
type
=
float
,
default
=
1.29
)
parser
.
add_argument
(
"--optimizer"
,
type
=
str
,
...
...
@@ -1477,16 +1490,13 @@ def main(args):
# Sample a random timestep for each image
# for weighting schemes where we sample timesteps non-uniformly
if
args
.
weighting_scheme
==
"logit_normal"
:
# See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
u
=
torch
.
normal
(
mean
=
args
.
logit_mean
,
std
=
args
.
logit_std
,
size
=
(
bsz
,),
device
=
"cpu"
)
u
=
torch
.
nn
.
functional
.
sigmoid
(
u
)
elif
args
.
weighting_scheme
==
"mode"
:
u
=
torch
.
rand
(
size
=
(
bsz
,),
device
=
"cpu"
)
u
=
1
-
u
-
args
.
mode_scale
*
(
torch
.
cos
(
math
.
pi
*
u
/
2
)
**
2
-
1
+
u
)
else
:
u
=
torch
.
rand
(
size
=
(
bsz
,),
device
=
"cpu"
)
u
=
compute_density_for_timestep_sampling
(
weighting_scheme
=
args
.
weighting_scheme
,
batch_size
=
bsz
,
logit_mean
=
args
.
logit_mean
,
logit_std
=
args
.
logit_std
,
mode_scale
=
args
.
mode_scale
,
)
indices
=
(
u
*
noise_scheduler_copy
.
config
.
num_train_timesteps
).
long
()
timesteps
=
noise_scheduler_copy
.
timesteps
[
indices
].
to
(
device
=
model_input
.
device
)
...
...
@@ -1507,19 +1517,11 @@ def main(args):
# Preconditioning of the model outputs.
model_pred
=
model_pred
*
(
-
sigmas
)
+
noisy_model_input
# TODO (kashif, sayakpaul): weighting scheme needs to be experimented with :)
# these weighting schemes use a uniform timestep sampling
# and instead post-weight the loss
if
args
.
weighting_scheme
==
"sigma_sqrt"
:
weighting
=
(
sigmas
**-
2.0
).
float
()
elif
args
.
weighting_scheme
==
"cosmap"
:
bot
=
1
-
2
*
sigmas
+
2
*
sigmas
**
2
weighting
=
2
/
(
math
.
pi
*
bot
)
else
:
weighting
=
torch
.
ones_like
(
sigmas
)
weighting
=
compute_loss_weighting_for_sd3
(
weighting_scheme
=
args
.
weighting_scheme
,
sigmas
=
sigmas
)
# simplified flow matching aka 0-rectified flow matching loss
# target = model_input - noise
# flow matching loss
target
=
model_input
if
args
.
with_prior_preservation
:
...
...
examples/dreambooth/train_dreambooth_sd3.py
View file @
4edde134
...
...
@@ -51,6 +51,7 @@ from diffusers import (
StableDiffusion3Pipeline
,
)
from
diffusers.optimization
import
get_scheduler
from
diffusers.training_utils
import
compute_density_for_timestep_sampling
,
compute_loss_weighting_for_sd3
from
diffusers.utils
import
(
check_min_version
,
is_wandb_available
,
...
...
@@ -471,11 +472,20 @@ def parse_args(input_args=None):
),
)
parser
.
add_argument
(
"--weighting_scheme"
,
type
=
str
,
default
=
"logit_normal"
,
choices
=
[
"sigma_sqrt"
,
"logit_normal"
,
"mode"
]
"--weighting_scheme"
,
type
=
str
,
default
=
"sigma_sqrt"
,
choices
=
[
"sigma_sqrt"
,
"logit_normal"
,
"mode"
,
"cosmap"
]
)
parser
.
add_argument
(
"--logit_mean"
,
type
=
float
,
default
=
0.0
,
help
=
"mean to use when using the `'logit_normal'` weighting scheme."
)
parser
.
add_argument
(
"--logit_std"
,
type
=
float
,
default
=
1.0
,
help
=
"std to use when using the `'logit_normal'` weighting scheme."
)
parser
.
add_argument
(
"--mode_scale"
,
type
=
float
,
default
=
1.29
,
help
=
"Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`."
,
)
parser
.
add_argument
(
"--logit_mean"
,
type
=
float
,
default
=
0.0
)
parser
.
add_argument
(
"--logit_std"
,
type
=
float
,
default
=
1.0
)
parser
.
add_argument
(
"--mode_scale"
,
type
=
float
,
default
=
1.29
)
parser
.
add_argument
(
"--optimizer"
,
type
=
str
,
...
...
@@ -1541,16 +1551,13 @@ def main(args):
# Sample a random timestep for each image
# for weighting schemes where we sample timesteps non-uniformly
if
args
.
weighting_scheme
==
"logit_normal"
:
# See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
u
=
torch
.
normal
(
mean
=
args
.
logit_mean
,
std
=
args
.
logit_std
,
size
=
(
bsz
,),
device
=
"cpu"
)
u
=
torch
.
nn
.
functional
.
sigmoid
(
u
)
elif
args
.
weighting_scheme
==
"mode"
:
u
=
torch
.
rand
(
size
=
(
bsz
,),
device
=
"cpu"
)
u
=
1
-
u
-
args
.
mode_scale
*
(
torch
.
cos
(
math
.
pi
*
u
/
2
)
**
2
-
1
+
u
)
else
:
u
=
torch
.
rand
(
size
=
(
bsz
,),
device
=
"cpu"
)
u
=
compute_density_for_timestep_sampling
(
weighting_scheme
=
args
.
weighting_scheme
,
batch_size
=
bsz
,
logit_mean
=
args
.
logit_mean
,
logit_std
=
args
.
logit_std
,
mode_scale
=
args
.
mode_scale
,
)
indices
=
(
u
*
noise_scheduler_copy
.
config
.
num_train_timesteps
).
long
()
timesteps
=
noise_scheduler_copy
.
timesteps
[
indices
].
to
(
device
=
model_input
.
device
)
...
...
@@ -1587,16 +1594,9 @@ def main(args):
model_pred
=
model_pred
*
(
-
sigmas
)
+
noisy_model_input
# these weighting schemes use a uniform timestep sampling
# and instead post-weight the loss
if
args
.
weighting_scheme
==
"sigma_sqrt"
:
weighting
=
(
sigmas
**-
2.0
).
float
()
elif
args
.
weighting_scheme
==
"cosmap"
:
bot
=
1
-
2
*
sigmas
+
2
*
sigmas
**
2
weighting
=
2
/
(
math
.
pi
*
bot
)
else
:
weighting
=
torch
.
ones_like
(
sigmas
)
weighting
=
compute_loss_weighting_for_sd3
(
weighting_scheme
=
args
.
weighting_scheme
,
sigmas
=
sigmas
)
# simplified flow matching aka 0-rectified flow matching loss
# target = model_input - noise
# flow matching loss
target
=
model_input
if
args
.
with_prior_preservation
:
...
...
src/diffusers/training_utils.py
View file @
4edde134
import
contextlib
import
copy
import
math
import
random
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Tuple
,
Union
...
...
@@ -220,6 +221,44 @@ def _set_state_dict_into_text_encoder(
set_peft_model_state_dict
(
text_encoder
,
text_encoder_state_dict
,
adapter_name
=
"default"
)
def compute_density_for_timestep_sampling(
    weighting_scheme: str,
    batch_size: int,
    logit_mean: Optional[float] = None,
    logit_std: Optional[float] = None,
    mode_scale: Optional[float] = None,
    device: Union[torch.device, str] = "cpu",
    generator: Optional[torch.Generator] = None,
):
    """Compute the density for sampling the timesteps when doing SD3 training.

    Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.

    SD3 paper reference: https://arxiv.org/abs/2403.03206v1.

    Args:
        weighting_scheme: Name of the scheme. `"logit_normal"` and `"mode"` draw
            non-uniform samples; any other value falls back to uniform sampling.
        batch_size: Number of samples to draw.
        logit_mean: Mean of the normal distribution that is passed through a
            sigmoid. Must be provided when `weighting_scheme` is `"logit_normal"`.
        logit_std: Standard deviation of that normal distribution. Must be
            provided when `weighting_scheme` is `"logit_normal"`.
        mode_scale: Scale used by the `"mode"` scheme. Must be provided when
            `weighting_scheme` is `"mode"`.
        device: Device on which the samples are created. Defaults to `"cpu"`,
            which is what this function always used before the parameter existed.
        generator: Optional random generator used for the draws, for
            reproducible sampling. When `None`, the global RNG is used.

    Returns:
        A 1-D tensor with `batch_size` elements holding the sampled values `u`.
    """
    if weighting_scheme == "logit_normal":
        # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
        u = torch.normal(
            mean=logit_mean, std=logit_std, size=(batch_size,), device=device, generator=generator
        )
        u = torch.nn.functional.sigmoid(u)
    elif weighting_scheme == "mode":
        u = torch.rand(size=(batch_size,), device=device, generator=generator)
        # Reshape the uniform draw with the mode-sampling transform.
        u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
    else:
        # Every other scheme samples uniformly; any weighting is applied to the loss instead.
        u = torch.rand(size=(batch_size,), device=device, generator=generator)
    return u
def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
    """Return the per-sample loss weights used for SD3 training.

    Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.

    SD3 paper reference: https://arxiv.org/abs/2403.03206v1.

    `"sigma_sqrt"` weights by `sigmas ** -2` (cast to float), `"cosmap"` weights
    by `2 / (pi * (1 - 2 * sigmas + 2 * sigmas ** 2))`, and every other scheme
    yields a tensor of ones shaped like `sigmas`.
    """
    if weighting_scheme == "cosmap":
        denominator = 1 - 2 * sigmas + 2 * sigmas**2
        return 2 / (math.pi * denominator)

    if weighting_scheme == "sigma_sqrt":
        inverse_square = sigmas**-2.0
        return inverse_square.float()

    # Schemes that re-weight through timestep sampling leave the loss unweighted.
    return torch.ones_like(sigmas)
# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
class
EMAModel
:
"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment