Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
24563ca6
Unverified
Commit
24563ca6
authored
Sep 20, 2023
by
Bagheera
Committed by
GitHub
Sep 20, 2023
Browse files
SNR gamma fixes for v_prediction training (#5106)
Co-authored-by:
bghira
<
bghira@users.github.com
>
parent
914586f5
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
15 additions
and
0 deletions
+15
-0
examples/controlnet/train_controlnet_flax.py
examples/controlnet/train_controlnet_flax.py
+3
-0
examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py
...projects/onnxruntime/text_to_image/train_text_to_image.py
+3
-0
examples/text_to_image/train_text_to_image.py
examples/text_to_image/train_text_to_image.py
+3
-0
examples/text_to_image/train_text_to_image_lora.py
examples/text_to_image/train_text_to_image_lora.py
+3
-0
examples/text_to_image/train_text_to_image_lora_sdxl.py
examples/text_to_image/train_text_to_image_lora_sdxl.py
+3
-0
No files found.
examples/controlnet/train_controlnet_flax.py
View file @
24563ca6
...
@@ -908,6 +908,9 @@ def main():
...
@@ -908,6 +908,9 @@ def main():
if
args
.
snr_gamma
is
not
None
:
if
args
.
snr_gamma
is
not
None
:
snr
=
jnp
.
array
(
compute_snr
(
timesteps
))
snr
=
jnp
.
array
(
compute_snr
(
timesteps
))
snr_loss_weights
=
jnp
.
where
(
snr
<
args
.
snr_gamma
,
snr
,
jnp
.
ones_like
(
snr
)
*
args
.
snr_gamma
)
/
snr
snr_loss_weights
=
jnp
.
where
(
snr
<
args
.
snr_gamma
,
snr
,
jnp
.
ones_like
(
snr
)
*
args
.
snr_gamma
)
/
snr
if
noise_scheduler
.
config
.
prediction_type
==
"v_prediction"
:
# velocity objective prediction requires SNR weights to be floored to a min value of 1.
snr_loss_weights
=
snr_loss_weights
+
1
loss
=
loss
*
snr_loss_weights
loss
=
loss
*
snr_loss_weights
loss
=
loss
.
mean
()
loss
=
loss
.
mean
()
...
...
examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py
View file @
24563ca6
...
@@ -875,6 +875,9 @@ def main():
...
@@ -875,6 +875,9 @@ def main():
mse_loss_weights
=
(
mse_loss_weights
=
(
torch
.
stack
([
snr
,
args
.
snr_gamma
*
torch
.
ones_like
(
timesteps
)],
dim
=
1
).
min
(
dim
=
1
)[
0
]
/
snr
torch
.
stack
([
snr
,
args
.
snr_gamma
*
torch
.
ones_like
(
timesteps
)],
dim
=
1
).
min
(
dim
=
1
)[
0
]
/
snr
)
)
if
noise_scheduler
.
config
.
prediction_type
==
"v_prediction"
:
# velocity objective prediction requires SNR weights to be floored to a min value of 1.
mse_loss_weights
=
mse_loss_weights
+
1
# We first calculate the original loss. Then we mean over the non-batch dimensions and
# We first calculate the original loss. Then we mean over the non-batch dimensions and
# rebalance the sample-wise losses with their respective loss weights.
# rebalance the sample-wise losses with their respective loss weights.
# Finally, we take the mean of the rebalanced loss.
# Finally, we take the mean of the rebalanced loss.
...
...
examples/text_to_image/train_text_to_image.py
View file @
24563ca6
...
@@ -955,6 +955,9 @@ def main():
...
@@ -955,6 +955,9 @@ def main():
mse_loss_weights
=
(
mse_loss_weights
=
(
torch
.
stack
([
snr
,
args
.
snr_gamma
*
torch
.
ones_like
(
timesteps
)],
dim
=
1
).
min
(
dim
=
1
)[
0
]
/
snr
torch
.
stack
([
snr
,
args
.
snr_gamma
*
torch
.
ones_like
(
timesteps
)],
dim
=
1
).
min
(
dim
=
1
)[
0
]
/
snr
)
)
if
noise_scheduler
.
config
.
prediction_type
==
"v_prediction"
:
# velocity objective prediction requires SNR weights to be floored to a min value of 1.
mse_loss_weights
=
mse_loss_weights
+
1
# We first calculate the original loss. Then we mean over the non-batch dimensions and
# We first calculate the original loss. Then we mean over the non-batch dimensions and
# rebalance the sample-wise losses with their respective loss weights.
# rebalance the sample-wise losses with their respective loss weights.
# Finally, we take the mean of the rebalanced loss.
# Finally, we take the mean of the rebalanced loss.
...
...
examples/text_to_image/train_text_to_image_lora.py
View file @
24563ca6
...
@@ -786,6 +786,9 @@ def main():
...
@@ -786,6 +786,9 @@ def main():
mse_loss_weights
=
(
mse_loss_weights
=
(
torch
.
stack
([
snr
,
args
.
snr_gamma
*
torch
.
ones_like
(
timesteps
)],
dim
=
1
).
min
(
dim
=
1
)[
0
]
/
snr
torch
.
stack
([
snr
,
args
.
snr_gamma
*
torch
.
ones_like
(
timesteps
)],
dim
=
1
).
min
(
dim
=
1
)[
0
]
/
snr
)
)
if
noise_scheduler
.
config
.
prediction_type
==
"v_prediction"
:
# velocity objective prediction requires SNR weights to be floored to a min value of 1.
mse_loss_weights
=
mse_loss_weights
+
1
# We first calculate the original loss. Then we mean over the non-batch dimensions and
# We first calculate the original loss. Then we mean over the non-batch dimensions and
# rebalance the sample-wise losses with their respective loss weights.
# rebalance the sample-wise losses with their respective loss weights.
# Finally, we take the mean of the rebalanced loss.
# Finally, we take the mean of the rebalanced loss.
...
...
examples/text_to_image/train_text_to_image_lora_sdxl.py
View file @
24563ca6
...
@@ -1075,6 +1075,9 @@ def main(args):
...
@@ -1075,6 +1075,9 @@ def main(args):
mse_loss_weights
=
(
mse_loss_weights
=
(
torch
.
stack
([
snr
,
args
.
snr_gamma
*
torch
.
ones_like
(
timesteps
)],
dim
=
1
).
min
(
dim
=
1
)[
0
]
/
snr
torch
.
stack
([
snr
,
args
.
snr_gamma
*
torch
.
ones_like
(
timesteps
)],
dim
=
1
).
min
(
dim
=
1
)[
0
]
/
snr
)
)
if
noise_scheduler
.
config
.
prediction_type
==
"v_prediction"
:
# velocity objective prediction requires SNR weights to be floored to a min value of 1.
mse_loss_weights
=
mse_loss_weights
+
1
# We first calculate the original loss. Then we mean over the non-batch dimensions and
# We first calculate the original loss. Then we mean over the non-batch dimensions and
# rebalance the sample-wise losses with their respective loss weights.
# rebalance the sample-wise losses with their respective loss weights.
# Finally, we take the mean of the rebalanced loss.
# Finally, we take the mean of the rebalanced loss.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment