Unverified Commit 89d8f848 authored by Bagheera, committed by GitHub

Timestep bias for fine-tuning SDXL (#5094)



* Timestep bias for fine-tuning SDXL

* Adjust parameter choices to include "range" and reword the help statements

* Condition our use of weighted timesteps on the value of timestep_bias_strategy

* style

---------
Co-authored-by: bghira <bghira@users.github.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent bdd25446
@@ -325,6 +325,55 @@ def parse_args(input_args=None):
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--timestep_bias_strategy",
        type=str,
        default="none",
        choices=["earlier", "later", "range", "none"],
        help=(
            "The timestep bias strategy, which may help direct the model toward learning low or high frequency details."
            " Choices: ['earlier', 'later', 'range', 'none']."
            " The default is 'none', which means no bias is applied, and training proceeds normally."
            " The value of 'earlier' will increase the frequency of the model's initial training timesteps,"
            " while 'later' will increase the frequency of its final training timesteps."
        ),
    )
    parser.add_argument(
        "--timestep_bias_multiplier",
        type=float,
        default=1.0,
        help=(
            "The multiplier for the bias. Defaults to 1.0, which means no bias is applied."
            " A value of 2.0 will double the weight of the bias, and a value of 0.5 will halve it."
        ),
    )
    parser.add_argument(
        "--timestep_bias_begin",
        type=int,
        default=0,
        help=(
            "When using `--timestep_bias_strategy=range`, the beginning (inclusive) timestep to bias."
            " Defaults to zero, which equates to having no specific bias."
        ),
    )
    parser.add_argument(
        "--timestep_bias_end",
        type=int,
        default=1000,
        help=(
            "When using `--timestep_bias_strategy=range`, the final timestep (inclusive) to bias."
            " Defaults to 1000, which is the number of timesteps that Stable Diffusion is trained on."
        ),
    )
    parser.add_argument(
        "--timestep_bias_portion",
        type=float,
        default=0.25,
        help=(
            "The portion of timesteps to bias. Defaults to 0.25, meaning 25 percent of timesteps will be biased."
            " A value of 0.5 will bias one half of the timesteps. The value provided for `--timestep_bias_strategy` determines"
            " whether the biased portions are in the earlier or later timesteps."
        ),
    )
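As an aside, here is a minimal sketch of how these flags are meant to combine for the 'range' strategy; the values below are purely illustrative, not recommendations from this PR:

```python
# Illustrative only: a hypothetical configuration that up-weights timesteps
# 200-500 so they are sampled twice as often during fine-tuning.
from argparse import Namespace

args = Namespace(
    timestep_bias_strategy="range",  # bias an explicit timestep window
    timestep_bias_multiplier=2.0,    # biased timesteps are drawn 2x as often
    timestep_bias_begin=200,         # first timestep of the biased window
    timestep_bias_end=500,           # last timestep of the biased window
    timestep_bias_portion=0.25,      # consulted only by 'earlier'/'later'
)
```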
    parser.add_argument(
        "--snr_gamma",
        type=float,
@@ -479,6 +528,47 @@ def compute_vae_encodings(batch, vae):
    return {"model_input": model_input.cpu()}
def generate_timestep_weights(args, num_timesteps):
    weights = torch.ones(num_timesteps)

    # Determine the indices to bias
    num_to_bias = int(args.timestep_bias_portion * num_timesteps)

    if args.timestep_bias_strategy == "later":
        bias_indices = slice(-num_to_bias, None)
    elif args.timestep_bias_strategy == "earlier":
        bias_indices = slice(0, num_to_bias)
    elif args.timestep_bias_strategy == "range":
        # Out of the possible 1000 timesteps, we might want to focus on e.g. 200-500.
        range_begin = args.timestep_bias_begin
        range_end = args.timestep_bias_end
        if range_begin < 0:
            raise ValueError(
                "When using the range strategy for timestep bias, you must provide a beginning timestep greater or equal to zero."
            )
        if range_end > num_timesteps:
            raise ValueError(
                "When using the range strategy for timestep bias, you must provide an ending timestep smaller than or equal to the number of timesteps."
            )
        bias_indices = slice(range_begin, range_end)
    else:  # 'none' or any other string
        return weights
    if args.timestep_bias_multiplier <= 0:
        raise ValueError(
            "The parameter --timestep_bias_multiplier is not intended to be used to disable the training of specific timesteps."
            " If it was intended to disable timestep bias, use `--timestep_bias_strategy none` instead."
            " A timestep bias multiplier less than or equal to 0 is not allowed."
        )

    # Apply the bias
    weights[bias_indices] *= args.timestep_bias_multiplier

    # Normalize
    weights /= weights.sum()
    return weights
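A quick sanity check of `generate_timestep_weights` (illustrative, assuming the function above and `torch` are in scope): with the 'later' strategy, a multiplier of 2.0, and a portion of 0.25, the final 250 of 1000 timesteps should each carry twice the sampling probability of the rest, and the weights should sum to one:

```python
import torch
from argparse import Namespace

args = Namespace(
    timestep_bias_strategy="later",
    timestep_bias_multiplier=2.0,
    timestep_bias_begin=0,
    timestep_bias_end=1000,
    timestep_bias_portion=0.25,
)
weights = generate_timestep_weights(args, num_timesteps=1000)
assert torch.isclose(weights.sum(), torch.tensor(1.0))
# 750 unit-weight steps and 250 double-weight steps: mass 1/1250 vs. 2/1250.
print(weights[0].item(), weights[-1].item())  # ~0.0008 vs. ~0.0016
```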
def main(args):
    logging_dir = Path(args.output_dir, args.logging_dir)
@@ -935,11 +1025,18 @@ def main(args):
                )
                bsz = model_input.shape[0]
                if args.timestep_bias_strategy == "none":
                    # Sample a random timestep for each image without bias.
                    timesteps = torch.randint(
                        0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
                    )
                else:
                    # Sample a random timestep for each image, potentially biased by the timestep weights.
                    # Biasing the timestep weights allows us to spend less time training irrelevant timesteps.
                    weights = generate_timestep_weights(args, noise_scheduler.config.num_train_timesteps).to(
                        model_input.device
                    )
                    timesteps = torch.multinomial(weights, bsz, replacement=True).long()

                # Add noise to the model input according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
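And a rough empirical check of the sampling change itself (illustrative only): drawing many timesteps via `torch.multinomial` from a 'later'-style weight vector should send the expected share of draws into the biased region:

```python
import torch

torch.manual_seed(0)
weights = torch.ones(1000)
weights[-250:] *= 2.0  # 'later' strategy with portion=0.25, multiplier=2.0
weights /= weights.sum()

draws = torch.multinomial(weights, 100_000, replacement=True)
frac_biased = (draws >= 750).float().mean().item()
# The 250 biased steps hold 500/1250 = 40% of the mass, so ~40% of draws land there.
print(f"{frac_biased:.3f}")
```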