"git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "deed945625418b1f2625048e22350e528f796cbc"
Unverified Commit 856331c6 authored by Yasyf Mohamedali's avatar Yasyf Mohamedali Committed by GitHub
Browse files

Support training SD V2 with Flax (#1783)

* Support training SD V2 with Flax

Mostly involves supporting a v_prediction scheduler.

The implementation in #1777 doesn't take into account a recent refactor of `scheduling_utils_flax`, so this should be used instead.

* Add to other top-level files.
parent f7154f85
...@@ -525,28 +525,35 @@ def main(): ...@@ -525,28 +525,35 @@ def main():
)[0] )[0]
# Predict the noise residual # Predict the noise residual
unet_outputs = unet.apply( model_pred = unet.apply(
{"params": params["unet"]}, noisy_latents, timesteps, encoder_hidden_states, train=True {"params": params["unet"]}, noisy_latents, timesteps, encoder_hidden_states, train=True
) ).sample
noise_pred = unet_outputs.sample
# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(noise_scheduler_state, latents, noise, timesteps)
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
if args.with_prior_preservation: if args.with_prior_preservation:
# Chunk the noise and noise_pred into two parts and compute the loss on each part separately. # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
noise_pred, noise_pred_prior = jnp.split(noise_pred, 2, axis=0) model_pred, model_pred_prior = jnp.split(model_pred, 2, axis=0)
noise, noise_prior = jnp.split(noise, 2, axis=0) target, target_prior = jnp.split(target, 2, axis=0)
# Compute instance loss # Compute instance loss
loss = (noise - noise_pred) ** 2 loss = (target - model_pred) ** 2
loss = loss.mean() loss = loss.mean()
# Compute prior loss # Compute prior loss
prior_loss = (noise_prior - noise_pred_prior) ** 2 prior_loss = (target_prior - model_pred_prior) ** 2
prior_loss = prior_loss.mean() prior_loss = prior_loss.mean()
# Add the prior loss to the instance loss. # Add the prior loss to the instance loss.
loss = loss + args.prior_loss_weight * prior_loss loss = loss + args.prior_loss_weight * prior_loss
else: else:
loss = (noise - noise_pred) ** 2 loss = (target - model_pred) ** 2
loss = loss.mean() loss = loss.mean()
return loss return loss
......
...@@ -459,9 +459,19 @@ def main(): ...@@ -459,9 +459,19 @@ def main():
)[0] )[0]
# Predict the noise residual and compute loss # Predict the noise residual and compute loss
unet_outputs = unet.apply({"params": params}, noisy_latents, timesteps, encoder_hidden_states, train=True) model_pred = unet.apply(
noise_pred = unet_outputs.sample {"params": params}, noisy_latents, timesteps, encoder_hidden_states, train=True
loss = (noise - noise_pred) ** 2 ).sample
# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(noise_scheduler_state, latents, noise, timesteps)
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
loss = (target - model_pred) ** 2
loss = loss.mean() loss = loss.mean()
return loss return loss
......
...@@ -536,11 +536,20 @@ def main(): ...@@ -536,11 +536,20 @@ def main():
encoder_hidden_states = state.apply_fn( encoder_hidden_states = state.apply_fn(
batch["input_ids"], params=params, dropout_rng=dropout_rng, train=True batch["input_ids"], params=params, dropout_rng=dropout_rng, train=True
)[0] )[0]
unet_outputs = unet.apply( # Predict the noise residual and compute loss
model_pred = unet.apply(
{"params": unet_params}, noisy_latents, timesteps, encoder_hidden_states, train=False {"params": unet_params}, noisy_latents, timesteps, encoder_hidden_states, train=False
) ).sample
noise_pred = unet_outputs.sample
loss = (noise - noise_pred) ** 2 # Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(noise_scheduler_state, latents, noise, timesteps)
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
loss = (target - model_pred) ** 2
loss = loss.mean() loss = loss.mean()
return loss return loss
......
...@@ -29,6 +29,7 @@ from .scheduling_utils_flax import ( ...@@ -29,6 +29,7 @@ from .scheduling_utils_flax import (
FlaxSchedulerMixin, FlaxSchedulerMixin,
FlaxSchedulerOutput, FlaxSchedulerOutput,
add_noise_common, add_noise_common,
get_velocity_common,
) )
...@@ -301,5 +302,14 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin): ...@@ -301,5 +302,14 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
) -> jnp.ndarray: ) -> jnp.ndarray:
return add_noise_common(state.common, original_samples, noise, timesteps) return add_noise_common(state.common, original_samples, noise, timesteps)
def get_velocity(
    self,
    state: DDIMSchedulerState,
    sample: jnp.ndarray,
    noise: jnp.ndarray,
    timesteps: jnp.ndarray,
) -> jnp.ndarray:
    """Compute the v-prediction target for `sample`/`noise` at `timesteps`.

    Thin wrapper that forwards the scheduler's shared common state to the
    module-level `get_velocity_common` helper.
    """
    common_state = state.common
    return get_velocity_common(common_state, sample, noise, timesteps)
def __len__(self): def __len__(self):
return self.config.num_train_timesteps return self.config.num_train_timesteps
...@@ -29,6 +29,7 @@ from .scheduling_utils_flax import ( ...@@ -29,6 +29,7 @@ from .scheduling_utils_flax import (
FlaxSchedulerMixin, FlaxSchedulerMixin,
FlaxSchedulerOutput, FlaxSchedulerOutput,
add_noise_common, add_noise_common,
get_velocity_common,
) )
...@@ -293,5 +294,14 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin): ...@@ -293,5 +294,14 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
) -> jnp.ndarray: ) -> jnp.ndarray:
return add_noise_common(state.common, original_samples, noise, timesteps) return add_noise_common(state.common, original_samples, noise, timesteps)
def get_velocity(
    self,
    state: DDPMSchedulerState,
    sample: jnp.ndarray,
    noise: jnp.ndarray,
    timesteps: jnp.ndarray,
) -> jnp.ndarray:
    """Compute the v-prediction target for `sample`/`noise` at `timesteps`.

    Thin wrapper that forwards the scheduler's shared common state to the
    module-level `get_velocity_common` helper.
    """
    common_state = state.common
    return get_velocity_common(common_state, sample, noise, timesteps)
def __len__(self): def __len__(self):
return self.config.num_train_timesteps return self.config.num_train_timesteps
...@@ -242,7 +242,7 @@ class CommonSchedulerState: ...@@ -242,7 +242,7 @@ class CommonSchedulerState:
) )
def add_noise_common( def get_sqrt_alpha_prod(
state: CommonSchedulerState, original_samples: jnp.ndarray, noise: jnp.ndarray, timesteps: jnp.ndarray state: CommonSchedulerState, original_samples: jnp.ndarray, noise: jnp.ndarray, timesteps: jnp.ndarray
): ):
alphas_cumprod = state.alphas_cumprod alphas_cumprod = state.alphas_cumprod
...@@ -255,5 +255,18 @@ def add_noise_common( ...@@ -255,5 +255,18 @@ def add_noise_common(
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape) sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape)
return sqrt_alpha_prod, sqrt_one_minus_alpha_prod
def add_noise_common(
    state: CommonSchedulerState, original_samples: jnp.ndarray, noise: jnp.ndarray, timesteps: jnp.ndarray
):
    """Apply forward-diffusion noise to `original_samples` at `timesteps`.

    Uses the per-timestep coefficients from `get_sqrt_alpha_prod` and returns
    sqrt(alpha_prod) * original_samples + sqrt(1 - alpha_prod) * noise.
    """
    # NOTE(review): the scraped diff rendered the last two lines with old/new
    # text fused together; this is the reconstructed, syntactically valid form.
    sqrt_alpha_prod, sqrt_one_minus_alpha_prod = get_sqrt_alpha_prod(state, original_samples, noise, timesteps)
    noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
    return noisy_samples
def get_velocity_common(state: CommonSchedulerState, sample: jnp.ndarray, noise: jnp.ndarray, timesteps: jnp.ndarray):
    """Return the v-prediction target: sqrt(a)*noise - sqrt(1-a)*sample.

    The per-timestep coefficients are broadcast to `sample`'s shape by
    `get_sqrt_alpha_prod`.
    """
    sqrt_a, sqrt_one_minus_a = get_sqrt_alpha_prod(state, sample, noise, timesteps)
    return sqrt_a * noise - sqrt_one_minus_a * sample
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment