Unverified Commit 97ee6169 authored by Kashif Rasul's avatar Kashif Rasul Committed by GitHub
Browse files

add ipo, hinge and cpo loss to dpo trainer (#6788)

add ipo and hinge loss to dpo trainer
parent 0fc62d17
...@@ -299,9 +299,15 @@ def parse_args(input_args=None): ...@@ -299,9 +299,15 @@ def parse_args(input_args=None):
parser.add_argument( parser.add_argument(
"--beta_dpo", "--beta_dpo",
type=int, type=int,
default=5000, default=2500,
help="DPO KL Divergence penalty.", help="DPO KL Divergence penalty.",
) )
parser.add_argument(
"--loss_type",
type=str,
default="sigmoid",
help="DPO loss type. Can be one of 'sigmoid' (default), 'ipo', or 'cpo'",
)
parser.add_argument( parser.add_argument(
"--learning_rate", "--learning_rate",
type=float, type=float,
...@@ -858,12 +864,19 @@ def main(args): ...@@ -858,12 +864,19 @@ def main(args):
accelerator.unwrap_model(unet).enable_adapters() accelerator.unwrap_model(unet).enable_adapters()
# Final loss. # Final loss.
scale_term = -0.5 * args.beta_dpo logits = ref_diff - model_diff
inside_term = scale_term * (model_diff - ref_diff) if args.loss_type == "sigmoid":
loss = -1 * F.logsigmoid(inside_term).mean() loss = -1 * F.logsigmoid(args.beta_dpo * logits).mean()
elif args.loss_type == "hinge":
loss = torch.relu(1 - args.beta_dpo * logits).mean()
elif args.loss_type == "ipo":
losses = (logits - 1 / (2 * args.beta)) ** 2
loss = losses.mean()
else:
raise ValueError(f"Unknown loss type {args.loss_type}")
implicit_acc = (inside_term > 0).sum().float() / inside_term.size(0) implicit_acc = (logits > 0).sum().float() / logits.size(0)
implicit_acc += 0.5 * (inside_term == 0).sum().float() / inside_term.size(0) implicit_acc += 0.5 * (logits == 0).sum().float() / logits.size(0)
accelerator.backward(loss) accelerator.backward(loss)
if accelerator.sync_gradients: if accelerator.sync_gradients:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment