Unverified Commit 074a7cc3 authored by Bagheera's avatar Bagheera Committed by GitHub
Browse files

SD3: update default training timestep / loss weighting distribution to logit_normal (#8592)


Co-authored-by: default avatarbghira <bghira@users.github.com>
Co-authored-by: default avatarKashif Rasul <kashif.rasul@gmail.com>
parent 6bfd13f0
......@@ -473,7 +473,7 @@ def parse_args(input_args=None):
),
)
parser.add_argument(
"--weighting_scheme", type=str, default="sigma_sqrt", choices=["sigma_sqrt", "logit_normal", "mode"]
"--weighting_scheme", type=str, default="logit_normal", choices=["sigma_sqrt", "logit_normal", "mode"]
)
parser.add_argument("--logit_mean", type=float, default=0.0)
parser.add_argument("--logit_std", type=float, default=1.0)
......
......@@ -471,7 +471,7 @@ def parse_args(input_args=None):
),
)
parser.add_argument(
"--weighting_scheme", type=str, default="sigma_sqrt", choices=["sigma_sqrt", "logit_normal", "mode"]
"--weighting_scheme", type=str, default="logit_normal", choices=["sigma_sqrt", "logit_normal", "mode"]
)
parser.add_argument("--logit_mean", type=float, default=0.0)
parser.add_argument("--logit_std", type=float, default=1.0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment