Unverified Commit 856331c6 authored by Yasyf Mohamedali's avatar Yasyf Mohamedali Committed by GitHub
Browse files

Support training SD V2 with Flax (#1783)

* Support training SD V2 with Flax

Mostly involves supporting a v_prediction scheduler.

The implementation in #1777 doesn't take into account a recent refactor of `scheduling_utils_flax`, so this should be used instead.

* Add to other top-level files.
parent f7154f85
......@@ -525,28 +525,35 @@ def main():
)[0]
# Predict the noise residual
unet_outputs = unet.apply(
model_pred = unet.apply(
{"params": params["unet"]}, noisy_latents, timesteps, encoder_hidden_states, train=True
)
noise_pred = unet_outputs.sample
).sample
# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(noise_scheduler_state, latents, noise, timesteps)
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
if args.with_prior_preservation:
# Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
noise_pred, noise_pred_prior = jnp.split(noise_pred, 2, axis=0)
noise, noise_prior = jnp.split(noise, 2, axis=0)
model_pred, model_pred_prior = jnp.split(model_pred, 2, axis=0)
target, target_prior = jnp.split(target, 2, axis=0)
# Compute instance loss
loss = (noise - noise_pred) ** 2
loss = (target - model_pred) ** 2
loss = loss.mean()
# Compute prior loss
prior_loss = (noise_prior - noise_pred_prior) ** 2
prior_loss = (target_prior - model_pred_prior) ** 2
prior_loss = prior_loss.mean()
# Add the prior loss to the instance loss.
loss = loss + args.prior_loss_weight * prior_loss
else:
loss = (noise - noise_pred) ** 2
loss = (target - model_pred) ** 2
loss = loss.mean()
return loss
......
......@@ -459,9 +459,19 @@ def main():
)[0]
# Predict the noise residual and compute loss
unet_outputs = unet.apply({"params": params}, noisy_latents, timesteps, encoder_hidden_states, train=True)
noise_pred = unet_outputs.sample
loss = (noise - noise_pred) ** 2
model_pred = unet.apply(
{"params": params}, noisy_latents, timesteps, encoder_hidden_states, train=True
).sample
# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(noise_scheduler_state, latents, noise, timesteps)
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
loss = (target - model_pred) ** 2
loss = loss.mean()
return loss
......
......@@ -536,11 +536,20 @@ def main():
encoder_hidden_states = state.apply_fn(
batch["input_ids"], params=params, dropout_rng=dropout_rng, train=True
)[0]
unet_outputs = unet.apply(
# Predict the noise residual and compute loss
model_pred = unet.apply(
{"params": unet_params}, noisy_latents, timesteps, encoder_hidden_states, train=False
)
noise_pred = unet_outputs.sample
loss = (noise - noise_pred) ** 2
).sample
# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(noise_scheduler_state, latents, noise, timesteps)
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
loss = (target - model_pred) ** 2
loss = loss.mean()
return loss
......
......@@ -29,6 +29,7 @@ from .scheduling_utils_flax import (
FlaxSchedulerMixin,
FlaxSchedulerOutput,
add_noise_common,
get_velocity_common,
)
......@@ -301,5 +302,14 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
) -> jnp.ndarray:
return add_noise_common(state.common, original_samples, noise, timesteps)
def get_velocity(
self,
state: DDIMSchedulerState,
sample: jnp.ndarray,
noise: jnp.ndarray,
timesteps: jnp.ndarray,
) -> jnp.ndarray:
return get_velocity_common(state.common, sample, noise, timesteps)
def __len__(self):
return self.config.num_train_timesteps
......@@ -29,6 +29,7 @@ from .scheduling_utils_flax import (
FlaxSchedulerMixin,
FlaxSchedulerOutput,
add_noise_common,
get_velocity_common,
)
......@@ -293,5 +294,14 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
) -> jnp.ndarray:
return add_noise_common(state.common, original_samples, noise, timesteps)
def get_velocity(
self,
state: DDPMSchedulerState,
sample: jnp.ndarray,
noise: jnp.ndarray,
timesteps: jnp.ndarray,
) -> jnp.ndarray:
return get_velocity_common(state.common, sample, noise, timesteps)
def __len__(self):
return self.config.num_train_timesteps
......@@ -242,7 +242,7 @@ class CommonSchedulerState:
)
def add_noise_common(
def get_sqrt_alpha_prod(
state: CommonSchedulerState, original_samples: jnp.ndarray, noise: jnp.ndarray, timesteps: jnp.ndarray
):
alphas_cumprod = state.alphas_cumprod
......@@ -255,5 +255,18 @@ def add_noise_common(
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape)
return sqrt_alpha_prod, sqrt_one_minus_alpha_prod
def add_noise_common(
state: CommonSchedulerState, original_samples: jnp.ndarray, noise: jnp.ndarray, timesteps: jnp.ndarray
):
sqrt_alpha_prod, sqrt_one_minus_alpha_prod = get_sqrt_alpha_prod(state, original_samples, noise, timesteps)
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples
def get_velocity_common(state: CommonSchedulerState, sample: jnp.ndarray, noise: jnp.ndarray, timesteps: jnp.ndarray):
sqrt_alpha_prod, sqrt_one_minus_alpha_prod = get_sqrt_alpha_prod(state, sample, noise, timesteps)
velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
return velocity
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment