Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
97ee6169
Unverified
Commit
97ee6169
authored
Jan 31, 2024
by
Kashif Rasul
Committed by
GitHub
Jan 31, 2024
Browse files
add ipo, hinge and cpo loss to dpo trainer (#6788)
add ipo and hinge loss to dpo trainer
parent
0fc62d17
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
19 additions
and
6 deletions
+19
-6
examples/research_projects/diffusion_dpo/train_diffusion_dpo.py
...es/research_projects/diffusion_dpo/train_diffusion_dpo.py
+19
-6
No files found.
examples/research_projects/diffusion_dpo/train_diffusion_dpo.py
View file @
97ee6169
...
@@ -299,9 +299,15 @@ def parse_args(input_args=None):
...
@@ -299,9 +299,15 @@ def parse_args(input_args=None):
parser
.
add_argument
(
parser
.
add_argument
(
"--beta_dpo"
,
"--beta_dpo"
,
type
=
int
,
type
=
int
,
default
=
5
0
00
,
default
=
2
500
,
help
=
"DPO KL Divergence penalty."
,
help
=
"DPO KL Divergence penalty."
,
)
)
parser
.
add_argument
(
"--loss_type"
,
type
=
str
,
default
=
"sigmoid"
,
help
=
"DPO loss type. Can be one of 'sigmoid' (default), 'ipo', or 'cpo'"
,
)
parser
.
add_argument
(
parser
.
add_argument
(
"--learning_rate"
,
"--learning_rate"
,
type
=
float
,
type
=
float
,
...
@@ -858,12 +864,19 @@ def main(args):
...
@@ -858,12 +864,19 @@ def main(args):
accelerator
.
unwrap_model
(
unet
).
enable_adapters
()
accelerator
.
unwrap_model
(
unet
).
enable_adapters
()
# Final loss.
# Final loss.
scale_term
=
-
0.5
*
args
.
beta_dpo
logits
=
ref_diff
-
model_diff
inside_term
=
scale_term
*
(
model_diff
-
ref_diff
)
if
args
.
loss_type
==
"sigmoid"
:
loss
=
-
1
*
F
.
logsigmoid
(
inside_term
).
mean
()
loss
=
-
1
*
F
.
logsigmoid
(
args
.
beta_dpo
*
logits
).
mean
()
elif
args
.
loss_type
==
"hinge"
:
loss
=
torch
.
relu
(
1
-
args
.
beta_dpo
*
logits
).
mean
()
elif
args
.
loss_type
==
"ipo"
:
losses
=
(
logits
-
1
/
(
2
*
args
.
beta
))
**
2
loss
=
losses
.
mean
()
else
:
raise
ValueError
(
f
"Unknown loss type
{
args
.
loss_type
}
"
)
implicit_acc
=
(
inside_term
>
0
).
sum
().
float
()
/
inside_term
.
size
(
0
)
implicit_acc
=
(
logits
>
0
).
sum
().
float
()
/
logits
.
size
(
0
)
implicit_acc
+=
0.5
*
(
inside_term
==
0
).
sum
().
float
()
/
inside_term
.
size
(
0
)
implicit_acc
+=
0.5
*
(
logits
==
0
).
sum
().
float
()
/
logits
.
size
(
0
)
accelerator
.
backward
(
loss
)
accelerator
.
backward
(
loss
)
if
accelerator
.
sync_gradients
:
if
accelerator
.
sync_gradients
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment